diff --git a/RAGSDK/MainRepo/Dockerfile/ubuntu/Dockerfile b/RAGSDK/MainRepo/Dockerfile/ubuntu/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..a495a2c395f87ea201a09fa2d359a2dd1044828c --- /dev/null +++ b/RAGSDK/MainRepo/Dockerfile/ubuntu/Dockerfile @@ -0,0 +1,173 @@ +FROM ubuntu:20.04 + +WORKDIR /tmp + +ARG ARCH=aarch64 +ARG PYTHON_VERSION=python3.11 +ARG TORCH_VERSION=2.1.0 +# 设置时区禁用交互式配置 +ENV DEBIAN_FRONTEND=noninteractive + +RUN apt-get update && apt-get install -y software-properties-common && add-apt-repository -y ppa:deadsnakes/ppa && apt-get update +RUN apt-get update && apt-get install -y vim tar zip unzip git curl wget dos2unix make gcc g++ ccache gfortran libssl-dev libpq-dev swig ffmpeg + +RUN apt-get update && apt-get install -y build-essential ${PYTHON_VERSION} ${PYTHON_VERSION}-dev ${PYTHON_VERSION}-distutils ${PYTHON_VERSION}-venv +COPY ./package/urls.conf /tmp/urls.conf +RUN . /tmp/urls.conf && curl -k $PYPI_URL | ${PYTHON_VERSION} && update-alternatives --install /usr/bin/python3 python3 /usr/bin/${PYTHON_VERSION} 1 + +#配置 python3-config +RUN ln -sf /usr/bin/${PYTHON_VERSION}-config /usr/local/bin/python3-config + +# 安装cmake +RUN . /tmp/urls.conf && wget $CMAKE_URL && \ + tar -zxf cmake-3.24.3.tar.gz && \ + cd cmake-3.24.3 && \ + ./bootstrap && make -j && make install + +# 请根据在服务器上执行npu-smi info 命令进行查询,将查询到的"Name"字段最后一位数字删除后值修改PLATFORM字段 +ARG PLATFORM=310P + +#解决blinker无法卸载的问题 +RUN apt-get remove -y python3-blinker && apt-get autoremove -y + +# 安装cann 依赖 +RUN pip3 install --upgrade setuptools && pip3 install numpy==1.26.4 decorator==5.1.1 sympy==1.12 cffi==1.16.0 pyyaml==6.0.1 pathlib2==2.3.7.post1 protobuf==5.26.0 scipy==1.12.0 requests==2.31.0 psutil==5.9.8 absl-py==2.1.0 attrs==23.2.0 + +# 安装torch for x86 +# RUN pip3 install torch=="${TORCH_VERSION}" --index-url https://download.pytorch.org/whl/cpu +# 安装torch-npu +# RUN pip3 install torch-npu=="${TORCH_VERSION}".post3 + +# 安装cann-toolkit和kernel +COPY ./package/Ascend-cann*_linux-${ARCH}.run /tmp/ +RUN useradd -d /home/HwHiAiUser -u 1000 -m -s /bin/bash HwHiAiUser +RUN bash /tmp/Ascend-cann-toolkit*_linux-${ARCH}.run --install --install-for-all --quiet +RUN bash /tmp/Ascend-cann-kernels*_linux-${ARCH}.run --install --install-for-all --quiet +#安装 nnal +RUN bash -c "source /usr/local/Ascend/ascend-toolkit/set_env.sh && bash /tmp/Ascend-cann-nnal*_linux-${ARCH}.run --install --quiet" + +ENV ASCEND_HOME_PATH=/usr/local/Ascend/ascend-toolkit/latest +ENV TOOLCHAIN_HOME=/usr/local/Ascend/ascend-toolkit/latest/toolkit +ENV ASCEND_TOOLKIT_HOME=/usr/local/Ascend/ascend-toolkit/latest +ENV ASCEND_OPP_PATH=/usr/local/Ascend/ascend-toolkit/latest/opp +ENV ASCEND_AICPU_PATH=/usr/local/Ascend/ascend-toolkit/latest +ENV LD_LIBRARY_PATH=/usr/local/Ascend/driver/lib64:/usr/local/Ascend/driver/lib64/common:/usr/local/Ascend/driver/lib64/driver:$LD_LIBRARY_PATH + +RUN pip3 install build wheel + + +# 安装openblas +ARG OPENBLAS_INSTALL_PATH=/usr/local/Ascend/OpenBLAS +RUN . /tmp/urls.conf && cd /tmp && \ + git clone $OPENBLAS_URL && \ + cd OpenBLAS && \ + git checkout v0.3.10 && \ + make FC=gfortran USE_OPENMP=1 -j && \ + make PREFIX=${OPENBLAS_INSTALL_PATH} install +ENV LD_LIBRARY_PATH=${OPENBLAS_INSTALL_PATH}/lib:$LD_LIBRARY_PATH + + +# 安装faiss +ARG FAISS_INSTALL_PATH=/usr/local/faiss/faiss1.7.4 +RUN . 
/tmp/urls.conf && wget $FAISS_URL && \ + tar -xf v1.7.4.tar.gz && \ + cd faiss-1.7.4/faiss && \ + sed -i "131 i virtual void search_with_filter (idx_t n, const float *x, idx_t k, float *distances, idx_t *lables, const void *mask = nullptr) const{}" Index.h && \ + sed -i "38 i template IndexIDMapTemplate::IndexIDMapTemplate (IndexT *index, std::vector &ids): index (index), own_fields (false) { this->is_trained = index->is_trained; this->metric_type = index->metric_type; this->verbose = index->verbose; this->d = index->d; id_map = ids; }" IndexIDMap.cpp && \ + sed -i "29 i explicit IndexIDMapTemplate (IndexT *index, std::vector &ids);" IndexIDMap.h && \ + sed -i "199 i utils/sorting.h" CMakeLists.txt && \ + cd .. && cmake -B build . -DFAISS_ENABLE_GPU=OFF -DPython_EXECUTABLE=/usr/bin/python3 -DBUILD_TESTING=OFF -DBUILD_SHARED_LIBS=ON -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX=${FAISS_INSTALL_PATH} && \ + make -C build -j faiss && \ + make -C build -j swigfaiss && \ + cd build/faiss/python && python3 setup.py bdist_wheel && \ + cd ../../.. && make -C build install && \ + cd build/faiss/python && cp libfaiss_python_callbacks.so ${FAISS_INSTALL_PATH}/lib && \ + cd dist && \ + pip3 install faiss-1.7.4*.whl + +# 安装index +COPY ./package/Ascend-mindxsdk-mxindex*_linux-${ARCH}.run /tmp/ +RUN bash /tmp/Ascend-mindxsdk-mxindex*_linux-${ARCH}.run --quiet --install --install-path=/usr/local/Ascend --platform=${PLATFORM} +ENV MX_INDEX_MODELPATH=/home/ascend/modelpath +RUN cd /usr/local/Ascend/mxIndex/ops && ./custom_opp_${ARCH}.run && mkdir -p /home/ascend/modelpath + +# 安装ascendfaiss,到这里了 +ARG FAISS_INSTALL_PATH=/usr/local/faiss/faiss1.7.4 +ARG MXINDEX_INSTALL_PATH=/usr/local/Ascend/mxIndex +ARG PYTHON_HEADER=/usr/include/${PYTHON_VERSION}/ +ARG ASCEND_INSTALL_PATH=/usr/local/Ascend/ascend-toolkit/latest +ARG DRIVER_INSTALL_PATH=/usr/local/Ascend + +COPY ./package/driver /usr/local/Ascend/driver +RUN . /tmp/urls.conf && wget $ASCENDFAISS_URL + + +RUN unzip master.zip && \ + cd mindsdk-referenceapps-master/IndexSDK/faiss-python && \ + swig -python -c++ -Doverride= -module swig_ascendfaiss -I${PYTHON_HEADER} -I${FAISS_INSTALL_PATH}/include -I${MXINDEX_INSTALL_PATH}/include -DSWIGWORDSIZE64 -o swig_ascendfaiss.cpp swig_ascendfaiss.swig && \ + g++ -std=c++11 -DFINTEGER=int -fopenmp -I/usr/local/include -I${ASCEND_INSTALL_PATH}/acllib/include -I${ASCEND_INSTALL_PATH}/runtime/include -I${DRIVER_INSTALL_PATH}/driver/kernel/inc/driver -I${DRIVER_INSTALL_PATH}/driver/kernel/libc_sec/include -fPIC -fstack-protector-all -Wall -Wreturn-type -D_FORTIFY_SOURCE=2 -g -O3 -Wall -Wextra -I${PYTHON_HEADER} -I/usr/local/lib/${PYTHON_VERSION}/dist-packages/numpy/core/include -I${FAISS_INSTALL_PATH}/include -I${MXINDEX_INSTALL_PATH}/include -c swig_ascendfaiss.cpp -o swig_ascendfaiss.o && \ + g++ -std=c++11 -shared -fopenmp -L${ASCEND_INSTALL_PATH}/lib64 -L${ASCEND_INSTALL_PATH}/acllib/lib64 -L${ASCEND_INSTALL_PATH}/runtime/lib64 -L${DRIVER_INSTALL_PATH}/driver/lib64 -L${DRIVER_INSTALL_PATH}/driver/lib64/common -L${DRIVER_INSTALL_PATH}/driver/lib64/driver -L${FAISS_INSTALL_PATH}/lib -L${MXINDEX_INSTALL_PATH}/lib -Wl,-rpath-link=${ASCEND_INSTALL_PATH}/acllib/lib64:${ASCEND_INSTALL_PATH}/runtime/lib64:${DRIVER_INSTALL_PATH}/driver/lib64:${DRIVER_INSTALL_PATH}/driver/lib64/common:${DRIVER_INSTALL_PATH}/driver/lib64/driver -L/usr/local/lib -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -s -o _swig_ascendfaiss.so swig_ascendfaiss.o -L.. 
-lascendfaiss -lfaiss -lascend_hal -lacl_retr -lascendcl -lc_sec -lacl_op_compiler && \ + ${PYTHON_VERSION} -m build && \ + cd dist && pip3 install ascendfaiss*.whl + + +#ENV LD_LIBRARY_PATH=/usr/local/Ascend/mxIndex/lib:/usr/local/faiss/faiss1.7.4/lib:$LD_LIBRARY_PATH + + +# 安装toch_npu +#RUN pip3 install torch-${TORCH_VERSION}*_${ARCH}.whl && pip3 install torch_npu-${TORCH_VERSION}*_${ARCH}.whl +RUN pip3 install torch-npu=="${TORCH_VERSION}".post8 + +# 安装mxrag +COPY ./package/Ascend-mindxsdk-mxrag_*_linux-${ARCH}.run /tmp/ +RUN bash /tmp/Ascend-mindxsdk-mxrag_*_linux-${ARCH}.run --install --install-path=/usr/local/Ascend --quiet --platform=${PLATFORM} +RUN pip3 install -r /usr/local/Ascend/mxRag/requirements.txt + +# 安装mxrag第三方依赖 +RUN pip3 install ragas==0.1.9 rank_bm25==0.2.2 readability_lxml==0.8.1 html_text==0.6.2 gradio==3.50.2 +#清理临时目录 +RUN rm -rf ./* && rm -rf /usr/local/Ascend/driver + +# 添加环境变量 +RUN sed -i '$a\export PYTHONPATH=/root/.local/lib/$PYTHON_VERSION/site-packages/mx_rag/libs/:$PYTHONPATH' /root/.bashrc && \ + sed -i '$a\export LD_PRELOAD=$(ls /usr/local/lib/$PYTHON_VERSION/dist-packages/scikit_learn.libs/libgomp-*):$LD_PRELOAD' /root/.bashrc && \ + sed -i '$a\source /usr/local/Ascend/ascend-toolkit/set_env.sh' /root/.bashrc && \ + sed -i '$a\source /usr/local/Ascend/nnal/atb/set_env.sh' /root/.bashrc && \ + sed -i '$a\source /usr/local/Ascend/mxRag/script/set_env.sh' /root/.bashrc && \ + sed -i '$a\export LD_LIBRARY_PATH=/usr/local/Ascend/mxIndex/lib:/usr/local/faiss/faiss1.7.4/lib:$LD_LIBRARY_PATH' /root/.bashrc && \ + sed -i '$a\export PATH=/usr/local/bin:$PATH' /root/.bashrc && \ + sed -i 's/$PYTHON_VERSION/'"$PYTHON_VERSION"'/g' /root/.bashrc && \ + sed -i '$a\export LOGURU_FORMAT="{time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <8} | {name}:{function}:{line} - {message!r}"' /root/.bashrc + + +USER HwHiAiUser:HwHiAiUser + +# 安装index for HwHiAiUser +COPY ./package/Ascend-mindxsdk-mxindex*_linux-${ARCH}.run /tmp/ +RUN bash /tmp/Ascend-mindxsdk-mxindex*_linux-${ARCH}.run --quiet --install --install-path=/home/HwHiAiUser/Ascend --platform=${PLATFORM} + +# 安装mxrag for HwHiAiUser +COPY ./package/Ascend-mindxsdk-mxrag_*_linux-${ARCH}.run /tmp/ +RUN bash /tmp/Ascend-mindxsdk-mxrag_*_linux-${ARCH}.run --install --install-path=/home/HwHiAiUser/Ascend --quiet --platform=${PLATFORM} + +# 添加环境变量 for HwHiAiUser用户 +RUN sed -i '$a\export PYTHONPATH=/home/HwHiAiUser/.local/lib/$PYTHON_VERSION/site-packages/mx_rag/libs/:$PYTHONPATH' /home/HwHiAiUser/.bashrc && \ + sed -i '$a\export LD_PRELOAD=$(ls /usr/local/lib/$PYTHON_VERSION/dist-packages/scikit_learn.libs/libgomp-*):$LD_PRELOAD' /home/HwHiAiUser/.bashrc && \ + sed -i '$a\source /usr/local/Ascend/ascend-toolkit/set_env.sh' /home/HwHiAiUser/.bashrc && \ + sed -i '$a\source /usr/local/Ascend/nnal/atb/set_env.sh' /home/HwHiAiUser/.bashrc && \ + sed -i '$a\source /home/HwHiAiUser/Ascend/mxRag/script/set_env.sh' /home/HwHiAiUser/.bashrc && \ + sed -i '$a\export LD_LIBRARY_PATH=/home/HwHiAiUser/Ascend/mxIndex/lib:/usr/local/faiss/faiss1.7.4/lib:$LD_LIBRARY_PATH' /home/HwHiAiUser/.bashrc && \ + sed -i '$a\export PATH=/usr/local/bin:$PATH' /home/HwHiAiUser/.bashrc && \ + sed -i 's/$PYTHON_VERSION/'"$PYTHON_VERSION"'/g' /home/HwHiAiUser/.bashrc && \ + sed -i '$a\export LOGURU_FORMAT="{time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <8} | {name}:{function}:{line} - {message!r}"' /home/HwHiAiUser/.bashrc + +ENV MX_INDEX_MULTITHREAD=1 +ENV MX_INDEX_FINALIZE=0 + +USER root +RUN chmod +r /usr/local/Ascend/ascend-toolkit/latest/opp/vendors/config.ini +RUN rm 
-rf /tmp/*
+USER HwHiAiUser:HwHiAiUser
+
+WORKDIR /home/HwHiAiUser
+
diff --git a/RAGSDK/MainRepo/Dockerfile/ubuntu/urls.conf b/RAGSDK/MainRepo/Dockerfile/ubuntu/urls.conf
new file mode 100644
index 0000000000000000000000000000000000000000..ba90bb23b30be8fba19d21871427a7eef3ee9184
--- /dev/null
+++ b/RAGSDK/MainRepo/Dockerfile/ubuntu/urls.conf
@@ -0,0 +1,5 @@
+PYPI_URL=https://bootstrap.pypa.io/get-pip.py
+CMAKE_URL=https://cmake.org/files/v3.24/cmake-3.24.3.tar.gz
+OPENBLAS_URL=https://github.com/OpenMathLib/OpenBLAS.git
+FAISS_URL=https://github.com/facebookresearch/faiss/archive/v1.7.4.tar.gz
+ASCENDFAISS_URL=https://gitee.com/ascend/mindsdk-referenceapps/repository/archive/master.zip
diff --git a/RAGSDK/MainRepo/README.md b/RAGSDK/MainRepo/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..0ce78483d78d6f16014b80c1cc6653b420e19708
--- /dev/null
+++ b/RAGSDK/MainRepo/README.md
@@ -0,0 +1,36 @@
+# RAG SDK
+
+RAG SDK是昇腾面向大语言模型的知识增强开发套件。为解决大模型知识更新缓慢以及垂直领域知识问答能力弱的问题,RAG SDK面向大模型知识库提供垂域调优、生成增强、知识管理等特性,帮助用户搭建专属的高性能、高准确度的大模型问答系统。
+
+## 版本配套说明
+
+本版本配套RAG SDK 7.1.RC1使用,依赖的其他配套软件版本为:
+
+| 软件包简称 | 配套版本 |
+|-----------------------|---------|
+| CANN软件包 | 8.2.RC1 |
+| 二进制算子包Kernels | 8.2.RC1 |
+| npu-driver驱动包 | 25.2.0 |
+| npu-firmware固件包 | 25.2.0 |
+| Index SDK检索软件包 | 7.1.RC1 |
+| MindIE推理引擎软件包 | 2.0.RC2 |
+| Ascend Docker Runtime | 7.0.RC1 |
+
+## 支持的硬件和运行环境
+
+| 产品系列 | 产品型号 |
+|-------------------|---------------------|
+| Atlas 推理系列产品 | Atlas 300I Duo 推理卡 |
+| Atlas 800I A2推理产品 | Atlas 800I A2 推理服务器 |
+
+支持的软件运行环境为:Ubuntu 22.04,Python3.11
+
+## 目录结构与说明
+
+| 目录 | 说明 |
+|------------|--------------------------------------------------------------|
+| Dockerfile | 部署RAG SDK容器时,用户自行准备镜像文件的参考样例,对应用户手册《安装RAG SDK》章节。 |
+| Samples | RAG SDK完整开发流程的开发参考样例,包含"创建知识库"、"在线问答"、"MxRAGCache缓存和自动生成QA"。 |
+| langgraph | Agentic RAG样例。 |
+| sd_samples | 安装并运行stable_diffusion模型参考样例。 |
+
diff --git a/RAGSDK/MainRepo/Samples/RagDemo/README.md b/RAGSDK/MainRepo/Samples/RagDemo/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..fb57eded41c03f510c11ec977bd7aefe58d0df60
--- /dev/null
+++ b/RAGSDK/MainRepo/Samples/RagDemo/README.md
@@ -0,0 +1,50 @@
+# RAG SDK Demo运行说明
+
+## 前提条件
+
+执行Demo前请先阅读《RAG SDK 用户指南》,并按照其中"安装部署"章节的要求完成必要的软、硬件安装。
+本章节为"应用开发"章节提供开发样例代码,便于开发者快速开发。
+
+## 样例说明
+
+详细的样例介绍请参考《RAG SDK 用户指南》"应用开发"章节说明。其中:
+
+1.rag_demo_knowledge.py为"创建知识库"样例代码;rag_demo_query.py为"在线问答"样例代码。
+
+2."创建知识库"样例和"在线问答"样例以FLATL2检索方式为例:当参数tei_emb为False时,表示本地启动embedding模型,需传入参数embedding_path;当参数tei_emb为True时,表示使用服务化模型,需传入参数embedding_url。reranker同理,其中reranker为可选过程,默认不使用。
+
+3.rag_demo_cache_qa.py对应"MxRAGCache缓存和自动生成QA"样例。
+
+4.fastapi_multi_demo目录下为fastapi多线程并发调用在线问答的样例。
+
+5.langchain_chains目录下为使用langchain llm chain支持双向认证的样例。
+
+注意:
+1.创建知识库过程和在线问答过程使用的embedding模型、关系数据库路径、向量数据库路径需对应保持一致。其中关系数据库和向量数据库路径在样例代码中已经默认设置成一致,embedding模型需用户手动设置成一致。
+
+## 运行及参数说明
+
+1.调用示例
+
+```commandline
+# 上传知识库,支持多线程上传
+python3 rag_demo_knowledge.py --file_path "/home/data/MindIE.docx" --file_path "/home/data/gaokao.docx"
+
+# 在线问答,支持多线程问答
+python3 rag_demo_query.py --query "请描述2024年高考作文题目" --query "请问2025年一共有多少天法定节假日"
+
+# fastapi多线程并发调用,需要安装fastapi和uvicorn
+# 启动fastapi服务端在线问答
+python3 fastapi_multithread.py --llm_url http://x.x.x.x:port/v1/chat/completions
+# 客户端多线程并发请求在线问答
+python3 fastapi_request.py
+# 若想单次请求,也可以使用curl指令,请求示例如下:
+curl -X 'POST' 'http://127.0.0.1:8000/query/' -H 'Content-Type: application/json' -d '{"question": "介绍一下2024年高考题目"}'
+```
+
+说明:
+调用示例前请先根据用户实际情况完成参数配置,确保embedding模型路径正确,大模型能正常访问,文件路径正确等,参数可以通过修改样例代码,也可通过命令行的方式传入,运行fastapi_multithread.py时需要安装fastapi和uvicorn包。 + +2.参数说明 + +```commandline +以"创建知识库"为例,用户可以通过以下命令查看参数情况;如需开发其他样例,请详细参考《RAG SDK用户指南》"接口参考"章节。 +python3 rag_demo_knowledge.py --help +``` \ No newline at end of file diff --git a/RAGSDK/MainRepo/Samples/RagDemo/fastapi_multi_demo/fastapi_multithread.py b/RAGSDK/MainRepo/Samples/RagDemo/fastapi_multi_demo/fastapi_multithread.py new file mode 100644 index 0000000000000000000000000000000000000000..4025ac30668206ae855c52e855b5d6830190ddf7 --- /dev/null +++ b/RAGSDK/MainRepo/Samples/RagDemo/fastapi_multi_demo/fastapi_multithread.py @@ -0,0 +1,142 @@ +import argparse +import os +import threading +from concurrent.futures import ThreadPoolExecutor +from typing import Optional +from pathlib import Path +from fastapi import FastAPI, HTTPException +from langchain.text_splitter import RecursiveCharacterTextSplitter +from mx_rag.chain import SingleText2TextChain +from mx_rag.document import LoaderMng +from mx_rag.document.loader import DocxLoader +from mx_rag.embedding.local import TextEmbedding +from mx_rag.embedding.service import TEIEmbedding +from mx_rag.knowledge import KnowledgeDB +from mx_rag.knowledge.handler import upload_files +from mx_rag.knowledge.knowledge import KnowledgeStore +from mx_rag.llm import Text2TextLLM +from mx_rag.reranker.local import LocalReranker +from mx_rag.reranker.service import TEIReranker +from mx_rag.retrievers import Retriever +from mx_rag.storage.document_store import SQLiteDocstore +from mx_rag.storage.vectorstore import MindFAISS +from mx_rag.utils import ClientParam +from paddle.base import libpaddle + + + + +class CustomFormatter(argparse.ArgumentDefaultsHelpFormatter): + def _get_default_metavar_for_optional(self, action): + return action.type.__name__ + + def _get_default_metavar_for_positional(self, action): + return action.type.__name__ + + +text_retriever = any +llm = any +reranker = any + + +def rag_init(): + parse = argparse.ArgumentParser(formatter_class=CustomFormatter) + parse.add_argument("--embedding_path", type=str, default="/home/mxaiagent/data/acge_text_embedding", + help="embedding模型本地路径") + parse.add_argument("--embedding_url", type=str, default="http://127.0.0.1:8080/embed", + help="使用TEI服务化的embedding模型url地址") + parse.add_argument("--tei_emb", type=bool, default=False, help="是否使用TEI服务化的embedding模型") + parse.add_argument("--llm_url", type=str, default="http://127.0.0.1:1025/v1/chat/completions", help="大模型url地址") + parse.add_argument("--model_name", type=str, default="Llama3-8B-Chinese-Chat", help="大模型名称") + parse.add_argument("--score_threshold", type=float, default=0.5, + help="相似性得分的阈值,大于阈值认为检索的信息与问题越相关,取值范围[0,1]") + parse.add_argument("--reranker_path", type=str, + default="/home/mxaiagent/data/bge-reranker-v2-m3", help="reranker模型本地路径") + parse.add_argument("--reranker_url", type=str, default="http://127.0.0.1:8080/rerank", + help="使用TEI服务化的embedding模型url地址") + parse.add_argument("--tei_reranker", type=bool, default=False, help="是否使用TEI服务化的reranker模型") + parse.add_argument("--white_path", type=str, nargs='+', default=["/home"], help="文件白名单路径") + parse.add_argument("--up_files", type=str, nargs='+', default=None, help="要上传的文件路径,需在白名单路径下") + parse.add_argument("--sql_path", type=str, nargs='+', default="./sql.db", help="关系数据库文件保存路径") + parse.add_argument("--vector_path", type=str, nargs='+', default="./faiss.index", help="向量数据库文件保存路径") + parse.add_argument("--question", type=str, default="描述一下地球的内部结构", help="用户问题") 
+ args = parse.parse_args().__dict__ + embedding_path: str = args.pop('embedding_path') + embedding_url: str = args.pop('embedding_url') + tei_emb: bool = args.pop('tei_emb') + llm_url: str = args.pop('llm_url') + model_name: str = args.pop('model_name') + score_threshold: int = args.pop('score_threshold') + reranker_path: str = args.pop('reranker_path') + reranker_url: str = args.pop('reranker_url') + tei_reranker: bool = args.pop('tei_reranker') + white_path: str = args.pop('white_path') + up_files: list[str] = args.pop('up_files') + sql_path: str = args.pop('sql_path') + vector_path: str = args.pop('vector_path') + question: str = args.pop('question') + + dev = 0 + if tei_emb: + emb = TEIEmbedding(url=embedding_url, client_param=ClientParam(use_http=True)) + else: + emb = TextEmbedding(model_path=embedding_path, dev_id=dev) + chunk_store = SQLiteDocstore(db_path=sql_path) + vector_store = MindFAISS(1024, [dev], load_local_index=vector_path) + global text_retriever + text_retriever = Retriever(vector_store=vector_store, document_store=chunk_store, + embed_func=emb.embed_documents, k=1, score_threshold=score_threshold) + + # 创建知识管理 + knowledge_store = KnowledgeStore(db_path=sql_path) + knowledge_store.add_knowledge("test", "Default01", "admin") + knowledge_db = KnowledgeDB(knowledge_store=knowledge_store, chunk_store=chunk_store, vector_store=vector_store, + knowledge_name="test", white_paths=white_path, user_id="Default01") + + # 上传文档到知识库 + if up_files: + loader_mng = LoaderMng() + loader_mng.register_loader(DocxLoader, [".docx"]) + loader_mng.register_splitter(RecursiveCharacterTextSplitter, [".xlsx", ".docx", ".pdf"], + {"chunk_size": 750, "chunk_overlap": 150, "keep_separator": False}) + upload_files(knowledge_db, up_files, loader_mng=loader_mng, embed_func=emb.embed_documents, force=True) + # 上传文档结束 + + global reranker + if tei_reranker: + reranker = TEIReranker(url=reranker_url, client_param=ClientParam(use_http=True)) + else: + reranker = LocalReranker(reranker_path, dev_id=dev) + global llm + llm = Text2TextLLM(base_url=llm_url, model_name=model_name, client_param=ClientParam(use_http=True)) + + +app = FastAPI() + + +def fun(input_string: str) -> str: + text2text_chain = SingleText2TextChain(retriever=text_retriever, llm=llm, reranker=reranker) + res = text2text_chain.query(input_string) + return f"{res}" + + +# 创建一个线程池执行器 +thread_pool_executor = ThreadPoolExecutor(max_workers=10) + + +@app.post("/query/") +async def call_fun(items: dict): + # 使用线程池异步调用 fun 函数 + future = thread_pool_executor.submit(fun, items['question']) + try: + result = future.result() + return {"result": result} + except Exception as e: + raise HTTPException(status_code=500) from e + + +if __name__ == "__main__": + rag_init() + import uvicorn + + uvicorn.run(app, host="127.0.0.1", port=8000) diff --git a/RAGSDK/MainRepo/Samples/RagDemo/fastapi_multi_demo/fastapi_request.py b/RAGSDK/MainRepo/Samples/RagDemo/fastapi_multi_demo/fastapi_request.py new file mode 100644 index 0000000000000000000000000000000000000000..c6069237b09253448c01e544956392962fa4b650 --- /dev/null +++ b/RAGSDK/MainRepo/Samples/RagDemo/fastapi_multi_demo/fastapi_request.py @@ -0,0 +1,66 @@ +import json +import logging +import random +import threading + +import requests + +logging.basicConfig(level=logging.INFO) +SERVICE_URL = 'http://127.0.0.1:8000/query/' + +QUESTION = [ + "请描述2024年高考作文", + "中国的绿色发展有哪些好处", + "描述一下可持续发展战略", + "解释一下极昼极夜", + "介绍一下台风", + "解释一下地球的内部结构", + "黄土高原的环境问题该怎么解决", + "介绍一下巴西", + "介绍一下爪哇岛", + "高技术工业特区有什么特点", + "台风来了怎么办", + 
"台风的发生条件?", + "洋流有几种", + "地球上为什么会产生昼夜更替", + "坚持以人民为中心的内涵是什么", + "圣诞节这一天,哈尔滨的白昼为什么最短", + "西北太平洋热带气旋出现的纬度?", + "东部地区重要的地理界线是什么?", + "印尼五大岛屿分别是哪些?", + "描述一下地球的内部结构" +] + +DATA_TEMPLATE = { + 'question': 'value1', +} + + +# 定义一个函数来发送POST请求 +def send_post_request(post_data, thread_id): + try: + headers = {'Content-Type': 'application/json'} + response = requests.post(SERVICE_URL, data=json.dumps(post_data), headers=headers) + logging.info('\nThread %d: Status Code %d, Response Body: %s', thread_id, response.status_code, response.text) + except Exception as e: + logging.debug('Thread %d encountered an error: %s', thread_id, e) + + +# 创建线程列表 +threads = [] + +# 创建并启动多个线程 +for i in range(20): + # 为每个线程准备不同的数据 + data = DATA_TEMPLATE.copy() + data['question'] = f'{QUESTION[random.randint(0, 19)]}' + + # 创建线程 + thread = threading.Thread(target=send_post_request, args=(data, i)) + thread.start() + threads.append(thread) + +for thread in threads: + thread.join() + +logging.info('All threads have finished execution') diff --git a/RAGSDK/MainRepo/Samples/RagDemo/langchain_chains/README.md b/RAGSDK/MainRepo/Samples/RagDemo/langchain_chains/README.md new file mode 100644 index 0000000000000000000000000000000000000000..5eba6c557cca80d0ecb0291d7258f4dd38bf4b98 --- /dev/null +++ b/RAGSDK/MainRepo/Samples/RagDemo/langchain_chains/README.md @@ -0,0 +1,9 @@ +# 使用说明 + +默认不使用双向认证,ssl为True时,需要传递ca_path,cert_path,key_path,pwd等字段 + + 不使用认证 + python3 llm_chain_stream.py --base_url http://127.0.0.1:1025/v1 --model_name Llama3-8B-Chinese-Chat + + 使用双向认证: + python3 llm_chain_stream.py --base_url https://127.0.0.1:1025/v1 --model_name Llama3-8B-Chinese-Chat --ssl True --ca_path xxx --cert_path xxx --key_path xxx --pwd xxx diff --git a/RAGSDK/MainRepo/Samples/RagDemo/langchain_chains/llm_chain_stream.py b/RAGSDK/MainRepo/Samples/RagDemo/langchain_chains/llm_chain_stream.py new file mode 100644 index 0000000000000000000000000000000000000000..3734900002695bf3060b85eb23bca0ed9e09ec91 --- /dev/null +++ b/RAGSDK/MainRepo/Samples/RagDemo/langchain_chains/llm_chain_stream.py @@ -0,0 +1,153 @@ +import sys +import argparse +import threading +from queue import Queue +from typing import Any +import httpx +import openai +from langchain.chains import LLMChain +from langchain_core.callbacks import BaseCallbackHandler +from langchain_core.outputs import LLMResult +from langchain_core.prompts import PromptTemplate +from langchain_openai import ChatOpenAI +from loguru import logger + + + +class EventData: + + def __init__(self, data, finish_reason): + self.data = data + self.finish_reason = finish_reason + + +class StreamingLLMCallbackHandler(BaseCallbackHandler): + def __init__(self): + self._is_done = False + self._queue = Queue() + + def clear(self): + with self._queue.mutex: + self._queue.queue.clear() + self._is_done = False + + def on_llm_new_token(self, token: str, **kwargs: Any) -> None: + self._queue.put(EventData(data=token, finish_reason="0")) + + def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: + self._queue.put(EventData(data="", finish_reason="done")) + self._is_done = True + + def on_llm_error(self, error: BaseException, **kwargs: Any) -> None: + logger.error(f" error happend:{error}") + self._queue.put(EventData(data=f"{error}", finish_reason="done")) + self._is_done = True + + @property + def stream_gen(self): + while not self._queue.empty() or not self._is_done: + try: + delta = self._queue.get() + yield str(delta.data) + except Exception as e: + logger.error(f"Exception:{e}") + + +class LLMInfo: + def 
__init__(self, base_url, model_name, handler, ssl, ca_path, cert_path, key_path, pwd):
+        self.base_url = base_url
+        self.model_name = model_name
+        self.handler = handler
+        self.ssl = ssl
+        self.ca_path = ca_path
+        self.cert_path = cert_path
+        self.key_path = key_path
+        self.pwd = pwd
+
+
+def create_llm_chain(params: LLMInfo):
+    base_url = params.base_url
+    model_name = params.model_name
+    handler = params.handler
+    ssl = params.ssl
+    ca_path = params.ca_path
+    cert_path = params.cert_path
+    key_path = params.key_path
+    pwd = params.pwd
+
+    if not ssl:
+        http_client = httpx.Client()
+    else:
+        http_client = httpx.Client(
+            cert=(cert_path, key_path, pwd),
+            verify=ca_path
+        )
+    root_client = openai.OpenAI(
+        base_url=base_url,
+        api_key="sk_fake",
+        http_client=http_client
+    )
+
+    client = root_client.chat.completions
+
+    llm = ChatOpenAI(
+        api_key="sk_fake",
+        client=client,
+        model_name=model_name,
+        temperature=0.5,
+        streaming=True,
+        callbacks=[handler]
+    )
+
+    template = """<指令>你是一个旅游专家,请简明扼要回答用户问题。<指令>\n用户问题:{question}"""
+    prompt = PromptTemplate.from_template(template)
+
+    # chain = LLMChain(
+    #     llm=llm,
+    #     prompt=prompt
+    # )
+    return prompt | llm
+
+
+class CustomFormatter(argparse.ArgumentDefaultsHelpFormatter):
+    def _get_default_metavar_for_optional(self, action):
+        return action.type.__name__
+
+    def _get_default_metavar_for_positional(self, action):
+        return action.type.__name__
+
+
+if __name__ == "__main__":
+    parse = argparse.ArgumentParser(formatter_class=CustomFormatter)
+    parse.add_argument("--base_url", type=str, default="http://127.0.0.1:1025/v1", help="大模型url base地址")
+    parse.add_argument("--model_name", type=str, default="Llama3-8B-Chinese-Chat", help="大模型名称")
+    parse.add_argument("--ssl", type=bool, default=False, help="是否开启认证")
+    parse.add_argument("--ca_path", type=str, default="", help="ca证书")
+    parse.add_argument("--cert_path", type=str, default="", help="客户端证书")
+    parse.add_argument("--key_path", type=str, default="", help="客户端私钥")
+    parse.add_argument("--pwd", type=str, default="", help="私钥解密口令")
+
+    args = parse.parse_args()
+
+    streaming_llm_callback_handler = StreamingLLMCallbackHandler()
+    streaming_llm_callback_handler.clear()
+
+
+    def get_llm_result(handler):
+        for chunk in handler.stream_gen:
+            logger.info(chunk)
+
+
+    thread = threading.Thread(target=get_llm_result, args=(streaming_llm_callback_handler,))
+    thread.start()
+
+    # create_llm_chain只接收一个LLMInfo参数,先将命令行参数组装为LLMInfo
+    llm_info = LLMInfo(base_url=args.base_url,
+                       model_name=args.model_name,
+                       handler=streaming_llm_callback_handler,
+                       ssl=args.ssl,
+                       ca_path=args.ca_path,
+                       cert_path=args.cert_path,
+                       key_path=args.key_path,
+                       pwd=args.pwd)
+    llm_chain = create_llm_chain(llm_info)
+    llm_chain.invoke({"question": "介绍北京风景区"})
diff --git a/RAGSDK/MainRepo/Samples/RagDemo/rag_demo_cache_qa.py b/RAGSDK/MainRepo/Samples/RagDemo/rag_demo_cache_qa.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d9fafbb60aa70401ab9ced3a059b9a3aebe73bc
--- /dev/null
+++ b/RAGSDK/MainRepo/Samples/RagDemo/rag_demo_cache_qa.py
@@ -0,0 +1,163 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+
+import argparse
+import json
+import os
+import time
+import traceback
+from loguru import logger
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain_community.document_loaders import TextLoader
+from mx_rag.cache import CacheConfig, SimilarityCacheConfig, MxRAGCache, CacheChainChat, \
+    MarkDownParser, QAGenerationConfig, QAGenerate
+from mx_rag.chain import SingleText2TextChain
+from mx_rag.document import LoaderMng
+from mx_rag.embedding.local import TextEmbedding
+from mx_rag.knowledge import KnowledgeStore, KnowledgeDB, upload_files
+from mx_rag.llm import Text2TextLLM
+from mx_rag.retrievers import Retriever
+from mx_rag.storage.document_store import SQLiteDocstore
+from mx_rag.storage.vectorstore import MindFAISS
+from mx_rag.utils import ClientParam
+from paddle.base import libpaddle
+from transformers import AutoTokenizer
+
+
+class CustomFormatter(argparse.ArgumentDefaultsHelpFormatter):
+    def _get_default_metavar_for_optional(self, action):
+        return action.type.__name__
+
+    def _get_default_metavar_for_positional(self, action):
+        return action.type.__name__
+
+
+def rag_cache_demo():
+    parse = argparse.ArgumentParser(formatter_class=CustomFormatter)
+    parse.add_argument("--embedding_path", type=str, default="/home/data/acge_text_embedding",
+                       help="embedding模型本地路径")
+    parse.add_argument("--embedding_dim", type=int, default=1024, help="embedding模型向量维度")
+    parse.add_argument("--reranker_path", type=str, default="/home/data/bge-reranker-large", help="reranker模型本地路径")
+    parse.add_argument("--white_path", type=str, nargs='+', default=["/home"], help="白名单路径,文件需在白名单路径下")
+    parse.add_argument("--file_path", type=str, default="/home/HwHiAiUser/gaokao.md",
+                       help="要上传的文件路径,需在白名单路径下")
+    parse.add_argument("--cache_save_path", type=str, default="/home/HwHiAiUser/cache_save_dir", help="缓存地址")
+    parse.add_argument("--llm_url", type=str, default="http://127.0.0.1:1025/v1/chat/completions", help="大模型url地址")
+    parse.add_argument("--model_name", type=str, default="Llama3-8B-Chinese-Chat", help="大模型名称")
+    parse.add_argument("--score_threshold", type=float, default=0.5,
+                       help="相似性得分的阈值,大于该阈值认为检索到的信息与问题相关,取值范围[0,1]")
+    parse.add_argument("--query", type=str, default="请描述2024年高考作文题目", help="用户问题")
+    parse.add_argument("--tokenizer_path", type=str, default="/home/data/Llama3-8B-Chinese-Chat/",
+                       help="大模型tokenizer参数路径")
+    parse.add_argument("--npu_device_id", type=int, default=0, help="NPU设备ID")
+    args = parse.parse_args()
+
+    try:
+        # memory cache缓存作为L1缓存
+        cache_config = CacheConfig(cache_size=100, data_save_folder=args.cache_save_path)
+        # similarity cache缓存作为L2缓存
+        similarity_config = SimilarityCacheConfig(
+            vector_config={"vector_type": "npu_faiss_db",
+                           "x_dim": args.embedding_dim,
+                           "devs": [args.npu_device_id]},
+            cache_config="sqlite",
+            emb_config={"embedding_type": "local_text_embedding",
+                        "x_dim": args.embedding_dim,
+                        "model_path": args.embedding_path,
+                        "dev_id": args.npu_device_id
+                        },
+            similarity_config={
+                "similarity_type": "local_reranker",
+                "model_path": args.reranker_path,
+                "dev_id": args.npu_device_id
+            },
+            retrieval_top_k=5,
+            cache_size=1000,
+            similarity_threshold=0.8,
+            data_save_folder=args.cache_save_path)
+
+        # 构造memory cache实例
+        memory_cache = MxRAGCache("memory_cache", cache_config)
+        # 构造similarity cache实例
+        similarity_cache = MxRAGCache("similarity_cache", similarity_config)
+        # memory_cache和similarity_cache串联形成多级缓存,入口是memory cache
+        memory_cache.join(similarity_cache)
+        # 定义用于生成QA的大模型
+        client_param = ClientParam(use_http=True, 
timeout=600)
+        llm = Text2TextLLM(base_url=args.llm_url, model_name=args.model_name, client_param=client_param)
+        # 返回markdown的标题和内容,标题要和内容相关
+        titles, contents = MarkDownParser(os.path.dirname(args.file_path)).parse()
+        # 使用大模型tokenizer计算token大小
+        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path, local_files_only=True)
+        # 组装生成QA的配置参数,qas_num为每个文件要生成的QA数量
+        config = QAGenerationConfig(titles, contents, tokenizer, llm, qas_num=3)
+        # 调用大模型生成QA对
+        qas = QAGenerate(config).generate_qa()
+        logger.info(f"qas:{qas}")
+        # 将QA存入缓存,答案部分需按照json格式保存
+        for query, answer in qas.items():
+            memory_cache.update(query, json.dumps({"result": answer}))
+
+        # 离线构建知识库,首先注册文档处理器
+        loader_mng = LoaderMng()
+        # 注册文档加载器,可以使用mxrag自有的,也可以使用langchain的
+        loader_mng.register_loader(loader_class=TextLoader, file_types=[".txt", ".md", ".docx"])
+        # 注册文档切分器,使用langchain的
+        loader_mng.register_splitter(splitter_class=RecursiveCharacterTextSplitter,
+                                     file_types=[".txt", ".md", ".docx"],
+                                     splitter_params={"chunk_size": 750,
+                                                      "chunk_overlap": 150,
+                                                      "keep_separator": False
+                                                      }
+                                     )
+
+        # 初始化向量数据库
+        vector_store = MindFAISS(x_dim=args.embedding_dim,
+                                 devs=[args.npu_device_id],
+                                 load_local_index=os.path.join(args.cache_save_path, "./faiss.index"))
+
+        # 加载embedding模型,请根据模型具体路径适配
+        emb = TextEmbedding(model_path=args.embedding_path, dev_id=args.npu_device_id)
+
+        # 初始化文档chunk关系数据库
+        chunk_store = SQLiteDocstore(db_path="./sql.db")
+        # <可选>初始化知识管理关系数据库
+        knowledge_store = KnowledgeStore(db_path="./sql.db")
+        knowledge_store.add_knowledge("rag", "Default01", "admin")
+        # <可选>初始化知识库管理,knowledge_name需与上面添加的知识库名称保持一致
+        knowledge_db = KnowledgeDB(knowledge_store=knowledge_store,
+                                   chunk_store=chunk_store,
+                                   vector_store=vector_store,
+                                   knowledge_name="rag",
+                                   white_paths=args.white_path,
+                                   user_id="Default01"
+                                   )
+        # <可选> 完成离线知识库构建,上传领域知识gaokao.md文档。
+        upload_files(knowledge=knowledge_db,
+                     files=[args.file_path],
+                     loader_mng=loader_mng,
+                     embed_func=emb.embed_documents,
+                     force=True
+                     )
+
+        # 初始化Retriever检索器
+        text_retriever = Retriever(vector_store=vector_store,
+                                   document_store=chunk_store,
+                                   embed_func=emb.embed_documents,
+                                   k=3,
+                                   score_threshold=args.score_threshold
+                                   )
+
+        # 构造cache_chain,缓存memory cache作为入口
+        cache_chain = CacheChainChat(chain=SingleText2TextChain(retriever=text_retriever, llm=llm), cache=memory_cache)
+        # 提问与知识文档相关的问题,如果与已生成的QA近似,则会命中缓存直接返回
+        now_time = time.time()
+        logger.info(cache_chain.query(args.query))
+        logger.info(f"耗时:{time.time() - now_time}s")
+    except Exception as e:
+        stack_trace = traceback.format_exc()
+        logger.error(stack_trace)
+
+
+if __name__ == '__main__':
+    rag_cache_demo()
diff --git a/RAGSDK/MainRepo/Samples/RagDemo/rag_demo_knowledge.py b/RAGSDK/MainRepo/Samples/RagDemo/rag_demo_knowledge.py
new file mode 100644
index 0000000000000000000000000000000000000000..021ff870b074405a625eaba2a3e82f510c544653
--- /dev/null
+++ b/RAGSDK/MainRepo/Samples/RagDemo/rag_demo_knowledge.py
@@ -0,0 +1,120 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ +import argparse +import threading +import traceback +from loguru import logger +from langchain.text_splitter import RecursiveCharacterTextSplitter +from langchain_community.document_loaders import TextLoader +from mx_rag.document import LoaderMng +from mx_rag.document.loader import DocxLoader, PdfLoader +from mx_rag.embedding.local import TextEmbedding +from mx_rag.embedding.service import TEIEmbedding +from mx_rag.knowledge import KnowledgeDB +from mx_rag.knowledge.handler import upload_files +from mx_rag.knowledge.knowledge import KnowledgeStore +from mx_rag.storage.document_store import SQLiteDocstore +from mx_rag.storage.vectorstore import MindFAISS +from mx_rag.utils import ClientParam +from paddle.base import libpaddle + + +class CustomFormatter(argparse.ArgumentDefaultsHelpFormatter): + def _get_default_metavar_for_optional(self, action): + return action.type.__name__ + + def _get_default_metavar_for_positional(self, action): + return action.type.__name__ + + +def rag_demo_upload(): + parse = argparse.ArgumentParser(formatter_class=CustomFormatter) + parse.add_argument("--embedding_path", type=str, default="/home/data/acge_text_embedding", + help="embedding模型本地路径") + parse.add_argument("--tei_emb", type=bool, default=False, help="是否使用TEI服务化的embedding模型") + parse.add_argument("--embedding_url", type=str, default="http://127.0.0.1:8080/embed", + help="使用TEI服务化的embedding模型url地址") + parse.add_argument("--embedding_dim", type=int, default=1024, help="embedding模型向量维度") + parse.add_argument("--white_path", type=str, nargs='+', default=["/home"], help="文件白名单路径") + parse.add_argument("--file_path", type=str, action='append', help="要上传的文件路径,需在白名单路径下") + parse.add_argument("--num_threads", type=int, default=2, help="可以根据实际情况调整线程数量") + + args = parse.parse_args().__dict__ + embedding_path: str = args.pop('embedding_path') + embedding_url: str = args.pop('embedding_url') + tei_emb: bool = args.pop('tei_emb') + embedding_dim: int = args.pop('embedding_dim') + white_path: list[str] = args.pop('white_path') + file_path: list[str] = args.pop('file_path') + num_threads: int = args.pop('num_threads') + + try: + # 离线构建知识库,首先注册文档处理器 + loader_mng = LoaderMng() + # 加载文档加载器,可以使用mxrag自有的,也可以使用langchain的 + loader_mng.register_loader(loader_class=TextLoader, file_types=[".txt", ".md"]) + loader_mng.register_loader(loader_class=PdfLoader, file_types=[".pdf"]) + loader_mng.register_loader(loader_class=DocxLoader, file_types=[".docx"]) + # 加载文档切分器,使用langchain的 + loader_mng.register_splitter(splitter_class=RecursiveCharacterTextSplitter, + file_types=[".pdf", ".docx", ".txt", ".md"], + splitter_params={"chunk_size": 750, + "chunk_overlap": 150, + "keep_separator": False + } + ) + # 设置向量检索使用的npu卡,具体可以用的卡可执行npu-smi info查询获取 + dev = 0 + # 加载embedding模型,请根据模型具体路径适配 + if tei_emb: + emb = TEIEmbedding(url=embedding_url, client_param=ClientParam(use_http=True)) + else: + emb = TextEmbedding(model_path=embedding_path, dev_id=dev) + # 初始化向量数据库 + vector_store = MindFAISS(x_dim=embedding_dim, + devs=[dev], + load_local_index="./faiss.index", + auto_save=True + ) + # 初始化文档chunk关系数据库 + chunk_store = SQLiteDocstore(db_path="./sql.db") + # 初始化知识管理关系数据库 + knowledge_store = KnowledgeStore(db_path="./sql.db") + # 添加知识库 + knowledge_store.add_knowledge("test", "Default", "admin") + # 初始化知识库管理 + knowledge_db = KnowledgeDB(knowledge_store=knowledge_store, + chunk_store=chunk_store, + vector_store=vector_store, + knowledge_name="test", + white_paths=white_path, + user_id="Default" + ) + + # 多线程上传文件 + batch_size = len(file_path) // 
num_threads + if len(file_path) % num_threads != 0: + batch_size += 1 + file_batchs = [file_path[i:i + batch_size] for i in range(0, len(file_path), batch_size)] + + threads = [] + for batch in file_batchs: + thread = threading.Thread(target=upload_files, + args=(knowledge_db, batch, loader_mng, emb.embed_documents, True)) + threads.append(thread) + thread.start() + + for thread in threads: + thread.join() + + # 检验文件是否上传成功 + documents = [document.document_name for document in knowledge_db.get_all_documents()] + logger.info(documents) + except Exception as e: + stack_trace = traceback.format_exc() + logger.error(stack_trace) + + +if __name__ == '__main__': + rag_demo_upload() diff --git a/RAGSDK/MainRepo/Samples/RagDemo/rag_demo_query.py b/RAGSDK/MainRepo/Samples/RagDemo/rag_demo_query.py new file mode 100644 index 0000000000000000000000000000000000000000..ff5e747e208cdbba62adbb1346b4a0ae1e6a2b42 --- /dev/null +++ b/RAGSDK/MainRepo/Samples/RagDemo/rag_demo_query.py @@ -0,0 +1,143 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + +import argparse +import threading +import traceback +from loguru import logger +from mx_rag.chain import SingleText2TextChain +from mx_rag.embedding.local import TextEmbedding +from mx_rag.embedding.service import TEIEmbedding +from mx_rag.llm import Text2TextLLM +from mx_rag.reranker.local import LocalReranker +from mx_rag.reranker.service import TEIReranker +from mx_rag.retrievers import Retriever +from mx_rag.storage.document_store import SQLiteDocstore +from mx_rag.storage.vectorstore import MindFAISS +from mx_rag.utils import ClientParam +from paddle.base import libpaddle + + +class CustomFormatter(argparse.ArgumentDefaultsHelpFormatter): + def _get_default_metavar_for_optional(self, action): + return action.type.__name__ + + def _get_default_metavar_for_positional(self, action): + return action.type.__name__ + + +class ThreadWithResult(threading.Thread): + def __init__(self, group=None, target=None, name=None, args=None, kwargs=None, *, daemon=None): + def function(): + self.result = target(*args, **kwargs) + + super().__init__(group=group, target=function, name=name, daemon=daemon) + + +def rag_demo_query(): + parse = argparse.ArgumentParser(formatter_class=CustomFormatter) + parse.add_argument("--embedding_path", type=str, default="/home/data/acge_text_embedding", + help="embedding模型本地路径") + parse.add_argument("--tei_emb", type=bool, default=False, help="是否使用TEI服务化的embedding模型") + parse.add_argument("--embedding_url", type=str, default="http://127.0.0.1:8080/embed", + help="使用TEI服务化的embedding模型url地址") + parse.add_argument("--embedding_dim", type=int, default=1024, help="embedding模型向量维度") + parse.add_argument("--llm_url", type=str, default="http://127.0.0.1:1025/v1/chat/completions", help="大模型url地址") + parse.add_argument("--model_name", type=str, default="Llama3-8B-Chinese-Chat", help="大模型名称") + parse.add_argument("--score_threshold", type=float, default=0.5, + help="相似性得分的阈值,大于阈值认为检索的信息与问题越相关,取值范围[0,1]") + parse.add_argument("--tei_reranker", type=bool, default=False, help="是否使用TEI服务化的reranker模型") + parse.add_argument("--reranker_path", type=str, default=None, help="reranker模型本地路径") + parse.add_argument("--reranker_url", type=str, default=None, help="使用TEI服务化的embedding模型url地址") + parse.add_argument("--query", type=str, action='append', help="用户问题") + parse.add_argument("--num_threads", type=int, default=2, help="可以根据实际情况调整线程数量") + + args = parse.parse_args().__dict__ + embedding_path: str = 
args.pop('embedding_path') + embedding_url: str = args.pop('embedding_url') + tei_emb: bool = args.pop('tei_emb') + embedding_dim: int = args.pop('embedding_dim') + llm_url: str = args.pop('llm_url') + model_name: str = args.pop('model_name') + score_threshold: int = args.pop('score_threshold') + query: list[str] = args.pop('query') + num_threads: int = args.pop('num_threads') + + try: + # 设置向量检索使用的npu卡,具体可以用的卡可执行npu-smi info查询获取 + dev = 0 + # 加载embedding模型,请根据模型具体路径适配 + if tei_emb: + emb = TEIEmbedding(url=embedding_url, client_param=ClientParam(use_http=True)) + else: + emb = TextEmbedding(model_path=embedding_path, dev_id=dev) + + # 初始化向量数据库 + vector_store = MindFAISS(x_dim=embedding_dim, + devs=[dev], + load_local_index="./faiss.index", + auto_save=True + ) + # 初始化文档chunk关系数据库 + chunk_store = SQLiteDocstore(db_path="./sql.db") + + # Step2在线问题答复,初始化检索器 + text_retriever = Retriever(vector_store=vector_store, + document_store=chunk_store, + embed_func=emb.embed_documents, + k=1, + score_threshold=score_threshold + ) + # 配置reranker,请根据模型具体路径适配 + reranker_path = args.get("reranker_path") + reranker_url = args.get("reranker_url") + tei_reranker = args.get("tei_reranker") + if tei_reranker: + reranker = TEIReranker(url=reranker_url, client_param=ClientParam(use_http=True)) + elif reranker_path is not None: + reranker = LocalReranker(model_path=reranker_path, dev_id=dev) + else: + reranker = None + # 配置text生成text大模型chain,具体ip端口请根据实际情况适配修改 + llm = Text2TextLLM(base_url=llm_url, model_name=model_name, client_param=ClientParam(use_http=True, timeout=60)) + + def process_query(input_string: str) -> str: + text2text_chain = SingleText2TextChain(retriever=text_retriever, llm=llm, reranker=reranker) + # 知识问答 + res = text2text_chain.query(input_string) + # 打印结果 + logger.info(res) + return f"{res}" + + results = [] + batch_size = len(query) // num_threads + if len(query) % num_threads != 0: + batch_size += 1 + batchs = [query[i:i + batch_size] for i in range(0, len(query), batch_size)] + + threads = [] + for batch in batchs: + def process_batch(batch): + batch_results = [] + for s in batch: + batch_results.append(process_query(s)) + return batch_results + + thread = ThreadWithResult(target=process_batch, args=(batch,)) + threads.append(thread) + thread.start() + + for thread in threads: + thread.join() + results.extend(thread.result) + + return results + + except Exception as e: + stack_trace = traceback.format_exc() + logger.error(stack_trace) + raise e + + +if __name__ == '__main__': + rag_demo_query() diff --git a/RAGSDK/MainRepo/langgraph/README.md b/RAGSDK/MainRepo/langgraph/README.md new file mode 100644 index 0000000000000000000000000000000000000000..612b0977d3c843bcdff77ab7c486ed173ca98d51 --- /dev/null +++ b/RAGSDK/MainRepo/langgraph/README.md @@ -0,0 +1,609 @@ +# RAG SDK基于LangGraph知识检索增强应用使能方案 + +## 1 背景 + +### 1.0 LangGraph介绍 + +[LangGraph官方介绍](https://blog.langchain.dev/langgraph/) +> LangGraph是Langchain新出的一个成员,是 LangChain 的 LangChain Expression Language +> (LCEL)的扩展。能够利用有向无环图的方式,去协调多个LLM或者状态,使用起来比 LCEL +> 会复杂,但是逻辑会更清晰。前期大家都学习了langchain,现在再上手学习langGraph就会容易许多,这也是我上面提到过的学习成本。我们可以把它也当做langchain扩展出来的Agent框架,langchain原有agent +> 的实现在LangGraph中都得到了重新实现,所以对于原来使用Langchain的系统去接入更容易。 + +### 1.1 RAG SDK介绍 + +RAG SDK 详细资料参考 [昇腾社区](https://www.hiascend.com/software/mindx-sdk) +> RAG SDK 是昇腾面向大语言模型的知识增强开发套件,为解决大模型知识更新缓慢以及垂直领域知识问答弱的问题,面向大模型知识库提供垂域调优、生成增强、知识管理等特性。 + +### 1.2 昇腾mis介绍 + +> mis: mind inference microservice +> +昇腾提供基于昇腾硬件加速的reranker和embedding mis服务,通过快速部署就可以支撑RAG应用。 +
+ +
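+服务部署完成后(安装地址见下文),可以先用一个简单脚本验证embedding和reranker服务是否可达。下面是一个最小调用示意(假设mis服务与TEI原生的/embed、/rerank接口保持一致,其中ip、端口均为假设值,请替换为实际部署地址):
+
+```python
+# 最小验证脚本示意:直接通过HTTP调用embedding/reranker服务(地址为假设值)
+import requests
+
+EMBED_URL = "http://127.0.0.1:8080/embed"    # mis embedding服务地址(示意)
+RERANK_URL = "http://127.0.0.1:8081/rerank"  # mis reranker服务地址(示意)
+
+# /embed接口接收{"inputs": [...]},返回每条文本对应的向量
+resp = requests.post(EMBED_URL, json={"inputs": ["什么是RAG?", "昇腾硬件加速"]}, timeout=30)
+embeddings = resp.json()
+print(len(embeddings), len(embeddings[0]))   # 文本条数、向量维度
+
+# /rerank接口接收{"query": ..., "texts": [...]},返回每条文本的相关性得分
+resp = requests.post(RERANK_URL,
+                     json={"query": "什么是RAG?", "texts": ["RAG是检索增强生成技术", "今天天气不错"]},
+                     timeout=30)
+print(resp.json())                           # 形如[{"index": 0, "score": ...}, ...]
+```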
+
+embedding tei [安装地址](https://www.hiascend.com/developer/ascendhub/detail/07a016975cc341f3a5ae131f2b52399d)
+reranker tei [安装地址](https://www.hiascend.com/developer/ascendhub/detail/07a016975cc341f3a5ae131f2b52399d)
+
+## 2 环境安装
+
+参考RAG SDK环境安装手册,分别安装CANN、RAG SDK,并部署embedding、reranker、MindIE Service服务,然后安装langgraph包:pip3 install langgraph==0.2.19
+
+## 3 总体介绍
+
+基于langgraph和RAG SDK搭建RAG应用,按照langgraph的定义需要包含node和graph,node中使用RAG SDK完成相应的功能。
+
+**RAG 节点(Node)定义**
+
+* cache search node:用户问题缓存查询节点
+* query decompose node:用户问题拆分子问题节点
+* hybrid retrieve node:用户问题混合检索文档节点
+* rerank node:重排检索文档节点
+* generate node:大模型生成节点
+* hallucination check node:生成内容幻觉检查节点
+* query rewrite node:用户问题重写节点
+* cache update node:用户问题缓存更新节点
+
+**RAG 图(GRAPH)定义**
+图定义如下:
+![alt text](image_new.png)
+
+状态转换如下表所示:
+
+| name | type | next hop | input | output |
+|------|------|----------|-------|--------|
+| **cache search** | node | if cache hit return generation to user else go **query decompose** | (question) | if hit (question, generation) else (question) |
+| **query decompose** | node | **hybrid retrieve** | (question) | (question, sub_question) |
+| **hybrid retrieve** | node | **rerank** | (question, sub_question) | (question, sub_question, contexts) |
+| **rerank** | node | **generate** | (question, sub_question, contexts) | (question, sub_question, contexts) |
+| **query rewrite** | node | **cache search** | (question, sub_question, contexts) or (question, sub_question, contexts, generate) | (question) |
+| **generate** | node | **hallucination check** | (question, sub_question, contexts) | (question, sub_question, contexts, generate) |
+| **hallucination check** | node | if hallucination check pass go **cache update** else go **query rewrite** | (question, sub_question, contexts, generate) | (question, sub_question, contexts, generate) |
+| **cache update** | node | END | (question, sub_question, contexts, generate) | (question, sub_question, contexts, generate) |
+
+## 4 RAG SDK 功能初始化
+
+完整的代码样例请参考[langgraph_demo.py](langgraph_demo.py)
+
+### 4.1 RAG文档加载和切分
+
+以下是初始化一个docx的文件加载器和文件切分器,并按照chunk_size=200,chunk_overlap=50进行切分,详细的API说明请参考RAG SDK的使用手册。
+
+```python
+def create_loader_and_spliter(mxrag_component: Dict[str, Any],
+                              chunk_size: int = 200,
+                              chunk_overlap: int = 50):
+    from langchain.text_splitter import RecursiveCharacterTextSplitter
+
+    from mx_rag.knowledge.doc_loader_mng import LoaderMng
+    from mx_rag.document.loader import DocxLoader
+
+    loader_mng = LoaderMng()
+    loader_mng.register_loader(DocxLoader, [".docx"])
+    loader_mng.register_splitter(RecursiveCharacterTextSplitter, [".docx"],
+                                 {"chunk_size": chunk_size, "chunk_overlap": chunk_overlap, "keep_separator": False})
+    mxrag_component["loader_mng"] = loader_mng
+```
+
+### 4.2 RAG远端服务
+
+以下是分别初始化MindIE Service、AIM embedding、AIM reranker服务连接的样例,用户需要传入相应的地址。
+
+```python
+def create_remote_connector(mxrag_component: Dict[str, Any],
+                            reranker_url: str,
+                            embedding_url: str,
+                            llm_url: str,
+                            llm_model_name: str):
+    from mx_rag.llm.text2text import Text2TextLLM
+    from mx_rag.embedding import EmbeddingFactory
+    from mx_rag.reranker.reranker_factory import RerankerFactory
+
+    reranker = RerankerFactory.create_reranker(similarity_type="tei_reranker",
+                                               url=reranker_url,
+                                               client_param=ClientParam(use_http=True),
+                                               k=3)
+    mxrag_component['reranker_connector'] = reranker
+
+    embedding = EmbeddingFactory.create_embedding(embedding_type="tei_embedding",
+                                                  url=embedding_url,
+                                                  client_param=ClientParam(use_http=True))
+    mxrag_component['embedding_connector'] = embedding
+
+    llm = Text2TextLLM(base_url=llm_url, model_name=llm_model_name,
+ client_param=ClientParam(use_http=True), + llm_config=LLMParameterConfig(max_tokens=4096)) + mxrag_component['llm_connector'] = llm +``` + +### 4.3 RAG知识库 + +以下是存放用户知识文档的样例,这里使用mxindex(MindFaiss)作为矢量检索,knowledge_files是用户需要传入包含文件路径的文件名列表。 + +```python +def create_knowledge_storage(mxrag_component: Dict[str, Any], knowledge_files: List[str]): + from mx_rag.knowledge.knowledge import KnowledgeStore + from mx_rag.knowledge import KnowledgeDB + from mx_rag.knowledge.handler import upload_files + from mx_rag.storage.document_store import SQLiteDocstore + + npu_dev_id = 0 + + # faiss_index_save_file is your faiss index save dir + faiss_index_save_file:str = "/home/HwHiAiUser/rag_npu_faiss.index" + vector_store = MindFAISS(x_dim=1024, + devs=[npu_dev_id], + load_local_index=faiss_index_save_file) + mxrag_component["vector_store"] = vector_store + + + # sqlite_save_file is your sqlite save dir + sqlite_save_file:str = "/home/HwHiAiUser/rag_sql.db" + chunk_store = SQLiteDocstore(db_path=sqlite_save_file) + mxrag_component["chunk_store"] = chunk_store + + # your knowledge file white paths if docx not in white paths will raise exception + white_paths=["/home/HwHiAiUser/"] + knowledge_store = KnowledgeStore(db_path=sqlite_save_file) + knowledge_store.add_knowledge("rag", "Default01", "admin") + Knowledge_db = KnowledgeDB(knowledge_store=knowledge_store, chunk_store=chunk_store, vector_store=vector_store, + knowledge_name="rag", white_paths=white_paths, user_id="Default01") + + upload_files(Knowledge_db, knowledge_files, loader_mng=mxrag_component.get("loader_mng"), + embed_func=mxrag_component.get("embedding_connector").embed_documents, + force=True) +``` + +### 4.4 RAG缓存系统 + +定义语义缓存系统,用于缓存用户已经提供过的答案,当用户再次提问相似的问题可以很快 +返回结果,不需要再进行大模型推理,加速E2E性能。 +语义缓存一般包含矢量数据库,标量数据库以及相应的embedding和相似度计算方法。 + +```python +def create_cache(mxrag_component: Dict[str, Any], + reranker_url: str, + embedding_url: str): + from mx_rag.cache import SimilarityCacheConfig + from mx_rag.cache import EvictPolicy + from mx_rag.cache import MxRAGCache + + npu_dev_id = 0 + # data_save_folder is your cache file when you next run your rag applicate it will read form disk + cache_data_save_folder = "/home/HwHiAiUser/mx_rag/cache_save_folder/" + + similarity_config = SimilarityCacheConfig( + vector_config={ + "vector_type": "npu_faiss_db", + "x_dim": 1024, + "devs": [npu_dev_id], + }, + cache_config="sqlite", + emb_config={ + "embedding_type": "tei_embedding", + "url": embedding_url, + "client_param": ClientParam(use_http=True) + }, + similarity_config={ + "similarity_type": "tei_reranker", + "url": reranker_url, + "client_param": ClientParam(use_http=True) + }, + retrieval_top_k=3, + cache_size=100, + auto_flush=100, + similarity_threshold=0.70, + data_save_folder=cache_data_save_folder, + disable_report=True, + eviction_policy=EvictPolicy.LRU + ) + + similarity_cache = MxRAGCache("similarity_cache", similarity_config) + mxrag_component["cache"] = similarity_cache +``` + +### 4.5 RAG评估系统 + +以下是初始化评估系统,这里使用大模型进行评估 + +```python +def create_evaluate(mxrag_component): + from mx_rag.evaluate import Evaluate + + llm = mxrag_component.get("llm_connector") + embedding = mxrag_component.get("embedding_connector") + mxrag_component["evaluator"] = Evaluate(llm=llm, embedding=embedding) +``` + +### 4.6 RAG混合检索 + +以下是构建混合检索的样例,这里使用了矢量检索和BM25检索,并按照RRF算法设置权重进行排序得到最后的检索文档。 + +```python +def create_hybrid_search_retriever(mxrag_component: Dict[str, Any]): + from langchain.retrievers import EnsembleRetriever + + from mx_rag.retrievers.retriever import 
Retriever + from mx_rag.retrievers import BMRetriever + + chunk_store = mxrag_component.get("chunk_store") + vector_store = mxrag_component.get("vector_store") + embedding = mxrag_component.get("embedding_connector") + + npu_faiss_retriever = Retriever(vector_store=vector_store, document_store=chunk_store, + embed_func=embedding.embed_documents, k=10, score_threshold=0.4) + + hybrid_retriever = EnsembleRetriever( + retrievers=[npu_faiss_retriever], weights=[1.0] + ) + + mxrag_component["retriever"] = hybrid_retriever +``` + +## 5 langgraph 图定义和编译运行 + +完整的代码样例请参考[langgraph_demo.py](langgraph_demo.py) + +### 5.1 Node定义 + +#### 5.1.1 Cache Search + +使用用户的问题,访问rag cache,如果命中generation不为None + +```python +def cache_search(cache): + def cache_search_process(state): + logger.info("---QUERY SEARCH ---") + question = state["question"] + generation = cache.search(question) + return {"question": question, "generation": generation} + + return cache_search_process +``` + +判决cache search 是否hit,根据generation 是否为None进行判断,如果为None则表示 +cache miss,如果不为None则cache hit + +```python +def decide_to_decompose(state): + logger.info("---DECIDE TO DECOMPOSE---") + cache_generation = state["generation"] + + if cache_generation is None: + logger.warning( + "---DECISION: CACHE MISS GO DECOMPOSE---" + ) + return "cache_miss" + + logger.info("---DECISION: CACHE HIT END---") + return "cache_hit" +``` + +#### 5.1.2 Query Decompose + +使用提示词工程进行问题拆解,拆解为子问题 + +```python +def decompose(llm): + sub_question_key_words = "Q:" + prompt = PromptTemplate( + template=""" + 请你参考如下示例,拆分用户的问题为独立子问题,如果无法拆分则返回原始问题: + 示例一: + 用户问题: 今天的天气如何, 你今天过的怎么样? + + {sub_question_key_words}今天的天气如何? + {sub_question_key_words}你今天过的怎么样? + + 示例二: + 用户问题: 汉堡好吃吗? + + {sub_question_key_words}汉堡好吃吗? + + 现在请你参考示例拆分以下用户问题: + 用户的问题:{question} + """, + input_variables=["question", "sub_question_key_words"] + ) + + sub_question_generator = LLMChain(llm=llm, prompt=prompt) + + def decompose_process(state): + logger.info("---QUERY DECOMPOSITION ---") + question = state["question"] + + sub_queries = sub_question_generator.predict(question=question, sub_question_key_words=sub_question_key_words) + if sub_question_key_words not in sub_queries: + sub_queries = None + else: + sub_queries = sub_queries.split(sub_question_key_words) + sub_queries = sub_queries[1:] + + return {"sub_questions": sub_queries, "question": question} + + return decompose_process +``` + +#### 5.1.3 Hybrid Retrive + +以下是进行混合检索,如果sub_question为None则使用quesiton进行检索,如果sub_question不为None则使用sub_question进行检索。 + +```python +def retrieve(retriever: BaseRetriever): + def retrieve_process(state): + logger.info("---RETRIEVE---") + sub_questions = state["sub_questions"] + question = state["question"] + + documents = [] + docs = [] + if sub_questions is None: + docs = retriever.get_relevant_documents(question) + else: + for query in sub_questions: + docs.extend(retriever.get_relevant_documents(query)) + + for doc in docs: + if doc.page_content not in documents: + documents.append(doc.page_content) + + return {"documents": documents, "question": question} + + return retrieve_process +``` + +#### 5.1.4 Rerank + +将用户的检索文档根据语义进行重排序 + +```python +def rerank(reranker): + def rerank_process(state): + logger.info("---RERANK---") + question = state["question"] + documents = state["documents"] + + scores = reranker.rerank(query=question, texts=documents) + documents = reranker.rerank_top_k(objs=documents, scores=scores) + + return {"documents": documents, "question": question} + + return rerank_process +``` + +#### 5.1.6 
Generate + +使用提示词工程访问进行大模型推理过程得到生成结果。 + +```python +def generate(llm): + prompt = PromptTemplate( + template="""{context} + + 根据上述已知信息,简洁和专业的来回答用户问题。如果无法从中已知信息中得到答案,请根据自身经验做出回答 + + {question} + """, + input_variables=["context", "question"] + ) + + rag_chain = LLMChain(llm=llm, prompt=prompt) + + def generate_process(state): + logger.info("---GENERATE---") + question = state["question"] + documents = state["documents"] + + generation = rag_chain.predict(context=documents, question=question) + return {"documents": documents, "question": question, "generation": generation} + + return generate_process +``` + +#### 5.1.7 Halluciation Check + +利用大模型评估进行判断生成质量是否符合用户需求。 + +```python +def grade_generation_v_documents_and_question(evaluate, + context_score_threshold: float = 0.6, + answer_score_threshold: float = 0.6): + generate_evalutor = evaluate_creator(evaluate, "generate_relevancy") + + def grade_generation_v_documents_and_question_process(state): + logger.info("---CHECK HALLUCINATIONS---") + + answer_score, context_score = generate_evalutor(state) + + answer_score = answer_score[0] + logger.info("---GRADE GENERATION vs QUESTION---") + if answer_score < answer_score_threshold: + logger.warning(f"---DECISION: GENERATION DOES NOT ADDRESS QUESTION," + f" RE-TRY--- answer_score:{answer_score}," + f"answer_score_threshold:{answer_score_threshold}") + return "not useful" + + logger.info(f"---DECISION: GENERATION ADDRESSES QUESTION--- " + f"answer_score:{answer_score}," + f"answer_score_threshold:{answer_score_threshold}") + + context_score = context_score[0] + logger.info("---GRADE GENERATION vs DOCUMENTS---") + if context_score < context_score_threshold: + logger.warning(f"---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, " + f" RE-TRY--- context_score:{context_score}," + f"context_score_threshold:{context_score_threshold}") + return "not useful" + + logger.info(f"---DECISION: GENERATION GROUNDED IN DOCUMENTS---" + f"context_score:{context_score}," + f"context_score_threshold:{context_score_threshold}") + return "useful" + + return grade_generation_v_documents_and_question_process +``` + +#### 5.1.8 CacheUpdate + +如果大模型生成质量符合要求,则更新缓存 + +```python +def cache_update(cache): + def cache_update_process(state): + logger.info("---QUERY UPDATE ---") + question = state["question"] + generation = state["generation"] + + cache.update(question, generation) + + return state + + return cache_update_process +``` + +#### 5.1.9 Query Rewrite + +利用提示词工程进行问答重写 + +```python +def transform_query(llm): + prompt = PromptTemplate( + template=""" + 你是一个用户问题重写员, 请仔细理解用户问题的内容和语义和检索的文档,在不修改用户问题 + 语义的前提下,将用户问题重写为可以更好被矢量检索的形式 + + 用户问题:{question} + """, + input_variables=["question"] + ) + + question_rewriter = LLMChain(llm=llm, prompt=prompt) + + def transform_query_process(state): + logger.info("---TRANSFORM QUERY---") + question = state["question"] + documents = state["documents"] + + better_question = question_rewriter.predict(question=question) + + return {"documents": documents, "question": better_question} + + return transform_query_process +``` + +### 5.2 图编译 + +```python +def build_mxrag_application(mxrag_component): + from langgraph.graph import END, START, StateGraph + + class GraphState(TypedDict): + question: str + sub_questions: List[str] + generation: str + documents: List[str] + + llm = mxrag_component.get("llm_connector") + retriever = mxrag_component.get("retriever") + reranker = mxrag_component.get("reranker_connector") + cache = mxrag_component.get("cache") + evaluate = mxrag_component.get("evaluator") 
+ + workflow = StateGraph(GraphState) + workflow.add_node("cache_search", cache_search(cache)) + workflow.add_node("cache_update", cache_update(cache)) + workflow.add_node("decompose", decompose(llm)) + workflow.add_node("retrieve", retrieve(retriever)) + workflow.add_node("rerank", rerank(reranker)) + workflow.add_node("generate", generate(llm)) + workflow.add_node("transform_query", transform_query(llm)) + + workflow.add_edge(START, "cache_search") + + workflow.add_conditional_edges( + "cache_search", + decide_to_decompose, + { + "cache_hit": END, + "cache_miss": "decompose", + }, + ) + + workflow.add_edge("decompose", "retrieve") + workflow.add_edge("retrieve", "rerank") + workflow.add_edge("rerank", "generate") + workflow.add_edge("transform_query", "cache_search") + workflow.add_conditional_edges( + "generate", + grade_generation_v_documents_and_question(evaluate), + { + "useful": "cache_update", + "not useful": "transform_query" + }, + ) + + workflow.add_edge("cache_update", END) + app = workflow.compile() + return app +``` + +### 5.3 在线问答 + +```python +if __name__ == "__main__": + mxrag_component: Dict[str, Any] = {} + + # mis tei rerank + mis_tei_reranker_url = "http://ip:port/rerank" + # mis tei embed + mis_tei_embedding_url = "http://ip:port/embed" + + # mindie llm server + llm_url = "http://ip:port/v1/chat/completions" + + # llm model name like Llama3-8B-Chinese-Chat etc + llm_model_name = "Llama3-8B-Chinese-Chat" + + # your knowledge list + knowledge_files = ["/home/HwHiAiUser/doc1.docx"] + + create_loader_and_spliter(mxrag_component, chunk_size=200, chunk_overlap=50) + + create_remote_connector(mxrag_component, + reranker_url=mis_tei_reranker_url, + embedding_url=mis_tei_embedding_url, + llm_url=llm_url, + llm_model_name=llm_model_name) + + create_knowledge_storage(mxrag_component, knowledge_files=knowledge_files) + + create_cache(mxrag_component, + reranker_url=mis_tei_reranker_url, + embedding_url=mis_tei_embedding_url) + + create_hybrid_search_retriever(mxrag_component) + + create_evaluate(mxrag_component) + + rag_app = build_mxrag_application(mxrag_component) + + user_question = "your question" + + start_time = time.time() + user_answer = rag_app.invoke({"question": user_question}) + end_time = time.time() + + print(f"user_question:{user_question}") + print(f"user_answer:{user_answer}") + print(f"app time cost:{(end_time - start_time) * 1000} ms") +``` + +## 6 附录 diff --git a/RAGSDK/MainRepo/langgraph/image-1.png b/RAGSDK/MainRepo/langgraph/image-1.png new file mode 100644 index 0000000000000000000000000000000000000000..1a14712ef83412dcf0826a0497b06bc3c998b245 Binary files /dev/null and b/RAGSDK/MainRepo/langgraph/image-1.png differ diff --git a/RAGSDK/MainRepo/langgraph/image.png b/RAGSDK/MainRepo/langgraph/image.png new file mode 100644 index 0000000000000000000000000000000000000000..8e4091182c754aabc007b2e04ae76c58ac2a6ba1 Binary files /dev/null and b/RAGSDK/MainRepo/langgraph/image.png differ diff --git a/RAGSDK/MainRepo/langgraph/image_new.png b/RAGSDK/MainRepo/langgraph/image_new.png new file mode 100644 index 0000000000000000000000000000000000000000..8185776d91b85f5c3229282fb58adfbb08eeeff4 Binary files /dev/null and b/RAGSDK/MainRepo/langgraph/image_new.png differ diff --git a/RAGSDK/MainRepo/langgraph/langgraph_demo.py b/RAGSDK/MainRepo/langgraph/langgraph_demo.py new file mode 100644 index 0000000000000000000000000000000000000000..9ff9842762889d44400ecad38d4676ec5353692a --- /dev/null +++ b/RAGSDK/MainRepo/langgraph/langgraph_demo.py @@ -0,0 +1,498 @@ +import os 
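+# langgraph_demo:基于 LangGraph 构建的 RAG 工作流参考样例,
+# 串联查询缓存、问题分解、混合检索、重排、生成与幻觉检查等节点。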
+import time +from typing import List, TypedDict, Any, Dict +from langchain.chains.llm import LLMChain +from langchain_core.documents import Document +from langchain_core.prompts import PromptTemplate +from langchain_core.retrievers import BaseRetriever +from loguru import logger +from mx_rag.llm import LLMParameterConfig +from mx_rag.storage.vectorstore import MindFAISS +from mx_rag.utils import ClientParam +from paddle.base import libpaddle + + +def evaluate_creator(evaluator, evaluate_type: str): + language = "chinese" + + # prompt_dir is ragas cache_dir will speed evaluate + + def evaluate_generate_relevancy(state): + question = state["question"] + retrieved_contexts = [doc.page_content for doc in state["documents"]] + response = state["generation"] + + datasets = { + "user_input": [question], + "response": [response], + "retrieved_contexts": [retrieved_contexts] + } + + scores = evaluator.evaluate_scores(metrics_name=["answer_relevancy", "faithfulness"], + datasets=datasets, + language=language) + return scores["answer_relevancy"], scores["faithfulness"] + + if evaluate_type == "generate_relevancy": + return evaluate_generate_relevancy + + raise KeyError("evaluate_type not support") + + +def cache_search(cache): + def cache_search_process(state): + logger.info("---QUERY SEARCH ---") + question = state["question"] + generation = cache.search(question) + return {"question": question, "generation": generation} + + return cache_search_process + + +def cache_update(cache): + def cache_update_process(state): + logger.info("---QUERY UPDATE ---") + question = state["question"] + generation = state["generation"] + + cache.update(question, generation) + + return state + + return cache_update_process + + +def decide_to_decompose(state): + logger.info("---DECIDE TO DECOMPOSE---") + cache_generation = state["generation"] + + if cache_generation is None: + logger.warning( + "---DECISION: CACHE MISS GO DECOMPOSE---" + ) + return "cache_miss" + + logger.info("---DECISION: CACHE HIT END---") + return "cache_hit" + + +def decompose(llm): + sub_question_key_words = "Q:" + prompt = PromptTemplate( + template=""" + 请你参考如下示例,拆分用户的问题为独立子问题,如果无法拆分则返回原始问题: + 示例一: + 用户问题: 今天的天气如何, 你今天过的怎么样? + + {sub_question_key_words}今天的天气如何? + {sub_question_key_words}你今天过的怎么样? + + 示例二: + 用户问题: 汉堡好吃吗? + + {sub_question_key_words}汉堡好吃吗? 
+ + 现在请你参考示例拆分以下用户问题: + 用户的问题:{question} + """, + input_variables=["question", "sub_question_key_words"] + ) + + sub_question_generator = LLMChain(llm=llm, prompt=prompt) + + def decompose_process(state): + logger.info("---QUERY DECOMPOSITION ---") + question = state["question"] + + sub_queries = sub_question_generator.predict(question=question, sub_question_key_words=sub_question_key_words) + if sub_question_key_words not in sub_queries: + sub_queries = None + else: + sub_queries = sub_queries.split(sub_question_key_words) + sub_queries = sub_queries[1:] + + return {"sub_questions": sub_queries, "question": question} + + return decompose_process + + +def retrieve(retriever: BaseRetriever): + def retrieve_process(state): + logger.info("---RETRIEVE---") + sub_questions = state["sub_questions"] + question = state["question"] + + documents = [] + docs = [] + if sub_questions is None: + docs = retriever.get_relevant_documents(question) + else: + for query in sub_questions: + docs.extend(retriever.get_relevant_documents(query)) + + for doc in docs: + if doc.page_content not in documents: + documents.append(doc.page_content) + + return {"documents": documents, "question": question} + + return retrieve_process + + +def rerank(reranker): + def rerank_process(state): + logger.info("---RERANK---") + question = state["question"] + documents = state["documents"] + if len(documents) < 2: + return {"documents": documents, "question": question} + scores = reranker.rerank(query=question, texts=documents) + documents = [Document(page_content=content) for content in documents] + documents = reranker.rerank_top_k(objs=documents, scores=scores) + + return {"documents": documents, "question": question} + + return rerank_process + + +def generate(llm): + prompt = PromptTemplate( + template="""{context} + + 根据上述已知信息,简洁和专业的来回答用户问题。如果无法从中已知信息中得到答案,请根据自身经验做出回答 + + {question} + """, + input_variables=["context", "question"] + ) + + rag_chain = LLMChain(llm=llm, prompt=prompt) + + def generate_process(state): + logger.info("---GENERATE---") + question = state["question"] + documents = state["documents"] + + generation = rag_chain.predict(context=documents, question=question) + return {"documents": documents, "question": question, "generation": generation} + + return generate_process + + +def transform_query(llm): + prompt = PromptTemplate( + template=""" + 你是一个用户问题重写员, 请仔细理解用户问题的内容和语义和检索的文档,在不修改用户问题 + 语义的前提下,将用户问题重写为可以更好被矢量检索的形式 + + 用户问题:{question} + """, + input_variables=["question"] + ) + + question_rewriter = LLMChain(llm=llm, prompt=prompt) + + def transform_query_process(state): + logger.info("---TRANSFORM QUERY---") + question = state["question"] + documents = state["documents"] + + better_question = question_rewriter.predict(question=question) + + return {"documents": documents, "question": better_question} + + return transform_query_process + + +def decide_to_generate(state): + logger.info("---ASSESS GRADED DOCUMENTS---") + filtered_documents = state["documents"] + + if not filtered_documents: + logger.warning( + "---DECISION:ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION, TRANSFORM QUERY---" + ) + return "transform_query" + logger.info("---DECISION: GENERATE---") + return "generate" + + +def grade_generation_v_documents_and_question(evaluate, + context_score_threshold: float = 0.6, + answer_score_threshold: float = 0.6): + generate_evalutor = evaluate_creator(evaluate, "generate_relevancy") + + def grade_generation_v_documents_and_question_process(state): + logger.info("---CHECK HALLUCINATIONS---") + + 
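+        # generate_evalutor 基于 ragas 的 answer_relevancy / faithfulness 指标打分:
+        # answer_score 衡量回答与问题的相关性,context_score 衡量回答是否忠实于检索到的文档。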
answer_score, context_score = generate_evalutor(state) + + answer_score = answer_score[0] + logger.info("---GRADE GENERATION vs QUESTION---") + if answer_score < answer_score_threshold: + logger.warning(f"---DECISION: GENERATION DOES NOT ADDRESS QUESTION," + f" RE-TRY--- answer_score:{answer_score}," + f"answer_score_threshold:{answer_score_threshold}") + return "not useful" + + logger.info(f"---DECISION: GENERATION ADDRESSES QUESTION--- " + f"answer_score:{answer_score}," + f"answer_score_threshold:{answer_score_threshold}") + + context_score = context_score[0] + logger.info("---GRADE GENERATION vs DOCUMENTS---") + if context_score < context_score_threshold: + logger.warning(f"---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, " + f" RE-TRY--- context_score:{context_score}," + f"context_score_threshold:{context_score_threshold}") + return "not useful" + + logger.info(f"---DECISION: GENERATION GROUNDED IN DOCUMENTS---" + f"context_score:{context_score}," + f"context_score_threshold:{context_score_threshold}") + return "useful" + + return grade_generation_v_documents_and_question_process + + +def create_loader_and_spliter(mxrag_component: Dict[str, Any], + chunk_size: int = 200, + chunk_overlap: int = 50): + from langchain.text_splitter import RecursiveCharacterTextSplitter + + from mx_rag.document import LoaderMng + from mx_rag.document.loader import DocxLoader + + loader_mng = LoaderMng() + loader_mng.register_loader(DocxLoader, [".docx"]) + loader_mng.register_splitter(RecursiveCharacterTextSplitter, [".docx"], + {"chunk_size": chunk_size, "chunk_overlap": chunk_overlap, "keep_separator": False}) + mxrag_component["loader_mng"] = loader_mng + + +def create_remote_connector(mxrag_component: Dict[str, Any], + reranker_url: str, + embedding_url: str, + llm_url: str, + llm_model_name: str): + from mx_rag.llm.text2text import Text2TextLLM + from mx_rag.embedding import EmbeddingFactory + from mx_rag.reranker.reranker_factory import RerankerFactory + + reranker = RerankerFactory.create_reranker(similarity_type="tei_reranker", + url=reranker_url, + client_param=ClientParam(use_http=True), + k=3) + mxrag_component['reranker_connector'] = reranker + + embedding = EmbeddingFactory.create_embedding(embedding_type="tei_embedding", + url=embedding_url, + client_param=ClientParam(use_http=True) + ) + mxrag_component['embedding_connector'] = embedding + + llm = Text2TextLLM(base_url=llm_url, model_name=llm_model_name, + client_param=ClientParam(use_http=True, timeout=240), + llm_config=LLMParameterConfig(max_tokens=4096)) + mxrag_component['llm_connector'] = llm + + +def create_knowledge_storage(mxrag_component: Dict[str, Any], knowledge_files: List[str]): + from mx_rag.knowledge.knowledge import KnowledgeStore + from mx_rag.knowledge import KnowledgeDB + from mx_rag.knowledge.handler import upload_files + from mx_rag.storage.document_store import SQLiteDocstore + + npu_dev_id = 0 + + # faiss_index_save_file is your faiss index save dir + faiss_index_save_file: str = "/home/HwHiAiUser/rag_npu_faiss.index" + vector_store = MindFAISS(x_dim=1024, + devs=[npu_dev_id], + load_local_index=faiss_index_save_file) + mxrag_component["vector_store"] = vector_store + + # sqlite_save_file is your sqlite save dir + sqlite_save_file: str = "/home/HwHiAiUser/rag_sql.db" + chunk_store = SQLiteDocstore(db_path=sqlite_save_file) + mxrag_component["chunk_store"] = chunk_store + + # your knowledge file white paths if docx not in white paths will raise exception + white_paths = ["/home/HwHiAiUser/"] + knowledge_store = 
KnowledgeStore(db_path=sqlite_save_file) + knowledge_store.add_knowledge("rag", "Default01", "admin") + Knowledge_db = KnowledgeDB(knowledge_store=knowledge_store, chunk_store=chunk_store, vector_store=vector_store, + knowledge_name="rag", white_paths=white_paths, user_id="Default01") + + upload_files(Knowledge_db, knowledge_files, loader_mng=mxrag_component.get("loader_mng"), + embed_func=mxrag_component.get("embedding_connector").embed_documents, + force=True) + + +def create_hybrid_search_retriever(mxrag_component: Dict[str, Any]): + from langchain.retrievers import EnsembleRetriever + from mx_rag.retrievers.retriever import Retriever + + chunk_store = mxrag_component.get("chunk_store") + vector_store = mxrag_component.get("vector_store") + embedding = mxrag_component.get("embedding_connector") + + npu_faiss_retriever = Retriever(vector_store=vector_store, document_store=chunk_store, + embed_func=embedding.embed_documents, k=10, score_threshold=0.4) + + hybrid_retriever = EnsembleRetriever( + retrievers=[npu_faiss_retriever], weights=[1.0] + ) + + mxrag_component["retriever"] = hybrid_retriever + + +def create_cache(mxrag_component: Dict[str, Any], + reranker_url: str, + embedding_url: str): + from mx_rag.cache import SimilarityCacheConfig + from mx_rag.cache import EvictPolicy + from mx_rag.cache import MxRAGCache + + npu_dev_id = 0 + # data_save_folder is your cache file when you next run your rag applicate it will read form disk + cache_data_save_folder = "/home/HwHiAiUser/mx_rag/cache_save_folder/" + + similarity_config = SimilarityCacheConfig( + vector_config={ + "vector_type": "npu_faiss_db", + "x_dim": 1024, + "devs": [npu_dev_id], + }, + cache_config="sqlite", + emb_config={ + "embedding_type": "tei_embedding", + "url": embedding_url, + "client_param": ClientParam(use_http=True) + }, + similarity_config={ + "similarity_type": "tei_reranker", + "url": reranker_url, + "client_param": ClientParam(use_http=True) + }, + retrieval_top_k=3, + cache_size=100, + auto_flush=100, + similarity_threshold=0.70, + data_save_folder=cache_data_save_folder, + disable_report=True, + eviction_policy=EvictPolicy.LRU + ) + + similarity_cache = MxRAGCache("similarity_cache", similarity_config) + mxrag_component["cache"] = similarity_cache + + +def create_evaluate(mxrag_component): + from mx_rag.evaluate import Evaluate + + llm = mxrag_component.get("llm_connector") + embedding = mxrag_component.get("embedding_connector") + mxrag_component["evaluator"] = Evaluate(llm=llm, embedding=embedding) + + +def build_mxrag_application(mxrag_component): + from langgraph.graph import END, START, StateGraph + + class GraphState(TypedDict): + question: str + sub_questions: List[str] + generation: str + documents: List[str] + + llm = mxrag_component.get("llm_connector") + retriever = mxrag_component.get("retriever") + reranker = mxrag_component.get("reranker_connector") + cache = mxrag_component.get("cache") + evaluate = mxrag_component.get("evaluator") + + workflow = StateGraph(GraphState) + workflow.add_node("cache_search", cache_search(cache)) + workflow.add_node("cache_update", cache_update(cache)) + workflow.add_node("decompose", decompose(llm)) + workflow.add_node("retrieve", retrieve(retriever)) + workflow.add_node("rerank", rerank(reranker)) + workflow.add_node("generate", generate(llm)) + workflow.add_node("transform_query", transform_query(llm)) + + workflow.add_edge(START, "cache_search") + + workflow.add_conditional_edges( + "cache_search", + decide_to_decompose, + { + "cache_hit": END, + "cache_miss": 
"decompose", + }, + ) + + workflow.add_edge("decompose", "retrieve") + workflow.add_edge("retrieve", "rerank") + + workflow.add_edge("rerank", "generate") + workflow.add_edge("transform_query", "cache_search") + workflow.add_conditional_edges( + "generate", + grade_generation_v_documents_and_question(evaluate), + { + "useful": "cache_update", + "not useful": "transform_query" + }, + ) + + workflow.add_edge("cache_update", END) + app = workflow.compile() + return app + + +if __name__ == "__main__": + mxrag_component: Dict[str, Any] = {} + + # mis tei rerank + mis_tei_reranker_url = "http://ip:port/rerank" + # mis tei embed + mis_tei_embedding_url = "http://ip:port/embed" + + # mindie llm server + llm_url = "http://ip:port/v1/chat/completions" + + # llm model name like Llama3-8B-Chinese-Chat etc + llm_model_name = "Llama3-8B-Chinese-Chat" + + # your knowledge list + knowledge_files = ["/home/HwHiAiUser/doc1.docx"] + + create_loader_and_spliter(mxrag_component, chunk_size=200, chunk_overlap=50) + + create_remote_connector(mxrag_component, + reranker_url=mis_tei_reranker_url, + embedding_url=mis_tei_embedding_url, + llm_url=llm_url, + llm_model_name=llm_model_name) + + create_knowledge_storage(mxrag_component, knowledge_files=knowledge_files) + + create_cache(mxrag_component, + reranker_url=mis_tei_reranker_url, + embedding_url=mis_tei_embedding_url) + + create_hybrid_search_retriever(mxrag_component) + + create_evaluate(mxrag_component) + + rag_app = build_mxrag_application(mxrag_component) + + user_question = "your question" + + start_time = time.time() + user_answer = rag_app.invoke({"question": user_question}) + end_time = time.time() + + logger.info(f"user_question:{user_question}") + logger.info(f"user_answer:{user_answer}") + logger.info(f"app time cost:{(end_time - start_time) * 1000} ms") diff --git a/RAGSDK/MainRepo/langgraph/requirements.txt b/RAGSDK/MainRepo/langgraph/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..c4c9f4ee58708a9e01ed30fad77241adf0ecdc1e --- /dev/null +++ b/RAGSDK/MainRepo/langgraph/requirements.txt @@ -0,0 +1,2 @@ +langgraph==0.2.19 +langchain_core==0.2.43 \ No newline at end of file diff --git a/RAGSDK/MainRepo/sd_samples/README.md b/RAGSDK/MainRepo/sd_samples/README.md new file mode 100644 index 0000000000000000000000000000000000000000..06e4ef3192c964426e3beec7dbd22c0aa6a6cb62 --- /dev/null +++ b/RAGSDK/MainRepo/sd_samples/README.md @@ -0,0 +1,151 @@ +# 安装stable-diffusion运行文生图参考样例说明 + +## 安装前准备 + +1)安装部署mindie +Service容器,镜像及部署指导参考[链接](https://www.hiascend.com/developer/ascendhub/detail/af85b724a7e5469ebd7ea13c3439d48f) + +此处只需使用mindie Service镜像和软件包安装包,无需执行部署大模型相关操作 + +2)注意运行的环境不能有torch-npu,如果存在,需卸载; 运行模型依赖MindIE Service 1.0.R2及以上的版本 + +3)下载SD 模型,下载链接如下 + +https://huggingface.co/stabilityai/stable-diffusion-2-1-base + +4) 下载MindIE Service适配代码,并切换到指定节点 + +``` +git clone https://gitee.com/ascend/ModelZoo-PyTorch.git +git checkout a6cef84ca2cce2413a3c34baa1649e05def18b67 +``` + +5)对stable_diffusion_pipeline文件打补丁 + +复制当前指导文件所在目录下的补丁文件 stable_diffusion_pipeline_parallel_web.patch 和stable_diffusion_pipeline_web.patch到 +ModelZoo-PyTorch/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion目录下 + +```bash +cd ModelZoo-PyTorch/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion +patch -p0 stable_diffusion_pipeline_parallel.py < stable_diffusion_pipeline_parallel_web.patch +patch -p0 stable_diffusion_pipeline.py < stable_diffusion_pipeline_web.patch +``` + +6)安装MindIE Service依赖 + +``` +pip3 install fastapi>=0.110.0 uvicorn 
diffusers + +pip3 install -r requirements.txt +``` + +7)设置环境变量 + +``` +source /usr/local/Ascend/mindie/set_env.sh +``` + +8)打stable_diffusion补丁 + +``` +python3 stable_diffusion_clip_patch.py +python3 stable_diffusion_attention_patch.py +python3 stable_diffusion_unet_patch.py +``` + +9)导出pt模型并进行编译 + +``` +(根据执行步骤3下载权重路径适配修改如下变量) +model_base="./stable-diffusion-2-1-base" +``` + +导出pt模型: + +``` +python3 export_ts.py --model ${model_base} --output_dir ./models \ + --parallel \ + --use_cache +``` + +参数说明: + +--model:模型名称或本地模型目录的路径 + +--output_dir: pt模型输出目录 + +--parallel:【可选】模型使用双芯/双卡并行推理 + +--use_cache: 【可选】模型使用UnetCache优化 + +--use_cache_faster: 【可选】模型使用deepcache+faster融合方案 + +10)启动web服务执行推理 + +不使用并行: + +``` +python3 stable_diffusion_pipeline.py \ + --model ${model_base} \ + --prompt_file ./prompts.txt \ + --device 0 \ + --save_dir ./results \ + --steps 50 \ + --scheduler DDIM \ + --soc Duo \ + --output_dir ./models \ + --use_cache +``` + +使用并行时: + +``` + python3 stable_diffusion_pipeline_parallel.py \ + --model ${model_base} \ + --prompt_file ./prompts.txt \ + --device 0,1 \ + --save_dir ./results \ + --steps 50 \ + --scheduler DDIM \ + --soc Duo \ + --output_dir ./models \ + --use_cache +``` + +参数说明: + +--model:模型名称或本地模型目录的路径。 + +--prompt_file:输入文本文件,按行分割。 + +--save_dir:生成图片的存放目录。 + +--steps:生成图片迭代次数。 + +--device:推理设备ID;可用逗号分割传入两个设备ID,此时会使用并行方式进行推理。 + +--scheduler: 【可选】推荐使用DDIM采样器。 + +--soc: 硬件配置,根据硬件配置选择Duo或者A2。 + +--output_dir: 编译好的模型路径。 + +--use_cache: 【可选】推荐使用UnetCache策略。 + +--use_cache_faster: 【可选】模型使用deepcache+faster融合方案。 + +上面步骤可参考[MindIE Service](https://gitee.com/ascend/ModelZoo-PyTorch/tree/master/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion#stable-diffusion%E6%A8%A1%E5%9E%8B-%E6%8E%A8%E7%90%86%E6%8C%87%E5%AF%BC) +指导 + +# 大模型测试 + +执行如下命令生成dog.jpeg文件 + +``` +curl http://127.0.0.1:7860/text2img \ + -X POST \ + -d '{"prompt":"dog wearing black glasses", "output_format": "jpeg", "size": "512*512"}' \ + -H 'Content-Type: application/json' | awk -F '"' '{print $2}' | base64 --decode > dog.jpeg +``` + + diff --git a/RAGSDK/MainRepo/sd_samples/stable_diffusion_pipeline_parallel_web.patch b/RAGSDK/MainRepo/sd_samples/stable_diffusion_pipeline_parallel_web.patch new file mode 100644 index 0000000000000000000000000000000000000000..e932c72c963f7919843961ec5f158ba4dd89ea36 --- /dev/null +++ b/RAGSDK/MainRepo/sd_samples/stable_diffusion_pipeline_parallel_web.patch @@ -0,0 +1,277 @@ +diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion/stable_diffusion_pipeline_parallel.py b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion/stable_diffusion_pipeline_parallel.py +index 76c7e606c..7a07a3793 100644 +--- a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion/stable_diffusion_pipeline_parallel.py ++++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion/stable_diffusion_pipeline_parallel.py +@@ -13,13 +13,20 @@ + # limitations under the License. 
+ + import argparse ++import base64 + import csv + import json + import os ++import io + import time + from typing import Callable, List, Optional, Union + import numpy as np + ++import uvicorn ++from fastapi import FastAPI, HTTPException ++from pydantic import BaseModel ++from fastapi.responses import Response ++ + import torch + import mindietorch + from mindietorch import _enums +@@ -29,6 +36,10 @@ from diffusers import DPMSolverMultistepScheduler, EulerDiscreteScheduler, DDIMS + from background_runtime import BackgroundRuntime, RuntimeIOInfo + from background_runtime_cache import BackgroundRuntimeCache, RuntimeIOInfoCache + ++app = FastAPI() ++pipe = None ++args = None ++ + clip_time = 0 + unet_time = 0 + vae_time = 0 +@@ -37,84 +48,6 @@ p2_time = 0 + p3_time = 0 + scheduler_time = 0 + +-class PromptLoader: +- def __init__( +- self, +- prompt_file: str, +- prompt_file_type: str, +- batch_size: int, +- num_images_per_prompt: int = 1, +- ): +- self.prompts = [] +- self.catagories = ['Not_specified'] +- self.batch_size = batch_size +- self.num_images_per_prompt = num_images_per_prompt +- +- if prompt_file_type == 'plain': +- self.load_prompts_plain(prompt_file) +- +- elif prompt_file_type == 'parti': +- self.load_prompts_parti(prompt_file) +- +- self.current_id = 0 +- self.inner_id = 0 +- +- def __len__(self): +- return len(self.prompts) * self.num_images_per_prompt +- +- def __iter__(self): +- return self +- +- def __next__(self): +- if self.current_id == len(self.prompts): +- raise StopIteration +- +- ret = { +- 'prompts': [], +- 'catagories': [], +- 'save_names': [], +- 'n_prompts': self.batch_size, +- } +- for _ in range(self.batch_size): +- if self.current_id == len(self.prompts): +- ret['prompts'].append('') +- ret['save_names'].append('') +- ret['catagories'].append('') +- ret['n_prompts'] -= 1 +- +- else: +- prompt, catagory_id = self.prompts[self.current_id] +- ret['prompts'].append(prompt) +- ret['catagories'].append(self.catagories[catagory_id]) +- ret['save_names'].append(f'{self.current_id}_{self.inner_id}') +- +- self.inner_id += 1 +- if self.inner_id == self.num_images_per_prompt: +- self.inner_id = 0 +- self.current_id += 1 +- +- return ret +- +- def load_prompts_plain(self, file_path: str): +- with os.fdopen(os.open(file_path, os.O_RDONLY), "r") as f: +- for i, line in enumerate(f): +- prompt = line.strip() +- self.prompts.append((prompt, 0)) +- +- def load_prompts_parti(self, file_path: str): +- with os.fdopen(os.open(file_path, os.O_RDONLY), "r", encoding='utf8') as f: +- # Skip the first line +- next(f) +- tsv_file = csv.reader(f, delimiter="\t") +- for i, line in enumerate(tsv_file): +- prompt = line[0] +- catagory = line[1] +- if catagory not in self.catagories: +- self.catagories.append(catagory) +- +- catagory_id = self.catagories.index(catagory) +- self.prompts.append((prompt, catagory_id)) +- + + class AIEStableDiffusionPipeline(StableDiffusionPipeline): + device_0 = None +@@ -983,10 +916,19 @@ def parse_arguments(): + help="Steps to use cache data." + ) + ++ parser.add_argument( ++ "--port", ++ type=int, ++ default=7860, ++ help="The port number used by fastapi." 
++ ) ++ + return parser.parse_args() + + + def main(): ++ global args ++ global pipe + args = parse_arguments() + save_dir = args.save_dir + +@@ -1006,89 +948,66 @@ def main(): + pipe.compile_aie_model() + mindietorch.set_device(pipe.device_0) + +- skip_steps = [0] * args.steps ++ args.skip_steps = [0] * args.steps + +- flag_cache = 0 ++ args.flag_cache = 0 + if args.use_cache: +- flag_cache = 1 ++ args.flag_cache = 1 + for i in args.cache_steps.split(','): + if int(i) >= args.steps: + continue +- skip_steps[int(i)] = 1 +- +- use_time = 0 +- prompt_loader = PromptLoader(args.prompt_file, +- args.prompt_file_type, +- args.batch_size, +- args.num_images_per_prompt) +- +- infer_num = 0 +- image_info = [] +- current_prompt = None +- for i, input_info in enumerate(prompt_loader): +- prompts = input_info['prompts'] +- catagories = input_info['catagories'] +- save_names = input_info['save_names'] +- n_prompts = input_info['n_prompts'] +- +- print(f"[{infer_num + n_prompts}/{len(prompt_loader)}]: {prompts}") +- infer_num += args.batch_size +- +- start_time = time.time() +- if args.scheduler == "DDIM": +- images = pipe.ascendie_infer_ddim( +- prompts, +- num_inference_steps=args.steps, +- skip_steps=skip_steps, +- flag_cache=flag_cache, +- ) +- else: +- images = pipe.ascendie_infer( +- prompts, +- num_inference_steps=args.steps, +- skip_steps=skip_steps, +- flag_cache=flag_cache, +- ) ++ args.skip_steps[int(i)] = 1 + +- if i > 4: # do not count the time spent inferring the first 0 to 4 images +- use_time += time.time() - start_time + +- for j in range(n_prompts): +- image_save_path = os.path.join(save_dir, f"{save_names[j]}.png") +- image = images[0][j] +- image.save(image_save_path) ++class ImageRequest(BaseModel): ++ prompt: str ++ output_format: str ++ size: str = "512*512" + +- if current_prompt != prompts[j]: +- current_prompt = prompts[j] +- image_info.append({'images': [], 'prompt': current_prompt, 'category': catagories[j]}) + +- image_info[-1]['images'].append(image_save_path) ++@app.post("/text2img") ++async def text2image(image_request: ImageRequest): ++ prompt = image_request.prompt ++ output_format = image_request.output_format ++ height = int(image_request.size.split("*")[0]) ++ width = int(image_request.size.split("*")[1]) ++ if output_format.lower() not in ["png", "jpeg", "jpg", "webp"]: ++ raise HTTPException(status_code=400, detail="Invalid output format") + +- infer_num = infer_num - 5 # do not count the time spent inferring the first 5 images +- print(f"[info] infer number: {infer_num}; use time: {use_time:.3f}s\n" +- f"average time: {use_time / infer_num:.3f}s\n" +- f"clip time: {clip_time / infer_num:.3f}s\n" +- f"unet time: {unet_time / infer_num:.3f}s\n" +- f"vae time: {vae_time / infer_num:.3f}s\n" +- f"p1 time: {p1_time / infer_num:.3f}s\n" +- f"p2 time: {p2_time / infer_num:.3f}s\n" +- ) +- if hasattr(pipe, 'device_1'): +- if (pipe.unet_bg): +- pipe.unet_bg.stop() ++ if output_format == "jpg": ++ output_format = "jpeg" + +- if (pipe.unet_bg_cache): +- pipe.unet_bg_cache.stop() ++ global args ++ global pipe + +- # Save image information to a json file +- if os.path.exists(args.info_file_save_path): +- os.remove(args.info_file_save_path) ++ if args.scheduler == "DDIM": ++ images = pipe.ascendie_infer_ddim( ++ [prompt], ++ height = height, ++ width = width, ++ num_inference_steps=args.steps, ++ skip_steps=args.skip_steps, ++ flag_cache=args.flag_cache, ++ ) ++ else: ++ images = pipe.ascendie_infer( ++ [prompt], ++ height = height, ++ width = width, ++ 
num_inference_steps=args.steps, ++ skip_steps=args.skip_steps, ++ flag_cache=args.flag_cache, ++ ) + +- with os.fdopen(os.open(args.info_file_save_path, os.O_RDWR | os.O_CREAT, 0o640), "w") as f: +- json.dump(image_info, f) ++ image = images[0][0] + +- mindietorch.finalize() ++ image_byte_arr = io.BytesIO() ++ image.save(image_byte_arr, format=output_format) ++ image_byte_arr.seek(0) ++ return base64.b64encode(image_byte_arr.read()) ++ # return Response(content=image_byte_arr.getvalue(), media_type=f"image/{output_format.lower()}") + + + if __name__ == "__main__": + main() ++ uvicorn.run(app, host="0.0.0.0", port=args.port) +\ No newline at end of file diff --git a/RAGSDK/MainRepo/sd_samples/stable_diffusion_pipeline_web.patch b/RAGSDK/MainRepo/sd_samples/stable_diffusion_pipeline_web.patch new file mode 100644 index 0000000000000000000000000000000000000000..70264e1c17ed237e51c590df716f1d5980aff934 --- /dev/null +++ b/RAGSDK/MainRepo/sd_samples/stable_diffusion_pipeline_web.patch @@ -0,0 +1,279 @@ +diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion/stable_diffusion_pipeline.py b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion/stable_diffusion_pipeline.py +index a953ae480..b470ffdec 100644 +--- a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion/stable_diffusion_pipeline.py ++++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion/stable_diffusion_pipeline.py +@@ -13,19 +13,30 @@ + # limitations under the License. + + import argparse ++import base64 + import csv + import json + import os ++import io + import time + from typing import Callable, List, Optional, Union + import numpy as np + ++import uvicorn ++from fastapi import FastAPI, HTTPException ++from pydantic import BaseModel ++from fastapi.responses import Response ++ + import torch + import mindietorch + from mindietorch import _enums + from diffusers import StableDiffusionPipeline + from diffusers import DPMSolverMultistepScheduler, EulerDiscreteScheduler, DDIMScheduler, SASolverScheduler + ++app = FastAPI() ++pipe = None ++args = None ++ + clip_time = 0 + unet_time = 0 + vae_time = 0 +@@ -34,84 +45,6 @@ p2_time = 0 + p3_time = 0 + scheduler_time = 0 + +-class PromptLoader: +- def __init__( +- self, +- prompt_file: str, +- prompt_file_type: str, +- batch_size: int, +- num_images_per_prompt: int = 1, +- ): +- self.prompts = [] +- self.catagories = ['Not_specified'] +- self.batch_size = batch_size +- self.num_images_per_prompt = num_images_per_prompt +- +- if prompt_file_type == 'plain': +- self.load_prompts_plain(prompt_file) +- +- elif prompt_file_type == 'parti': +- self.load_prompts_parti(prompt_file) +- +- self.current_id = 0 +- self.inner_id = 0 +- +- def __len__(self): +- return len(self.prompts) * self.num_images_per_prompt +- +- def __iter__(self): +- return self +- +- def __next__(self): +- if self.current_id == len(self.prompts): +- raise StopIteration +- +- ret = { +- 'prompts': [], +- 'catagories': [], +- 'save_names': [], +- 'n_prompts': self.batch_size, +- } +- for _ in range(self.batch_size): +- if self.current_id == len(self.prompts): +- ret['prompts'].append('') +- ret['save_names'].append('') +- ret['catagories'].append('') +- ret['n_prompts'] -= 1 +- +- else: +- prompt, catagory_id = self.prompts[self.current_id] +- ret['prompts'].append(prompt) +- ret['catagories'].append(self.catagories[catagory_id]) +- ret['save_names'].append(f'{self.current_id}_{self.inner_id}') +- +- self.inner_id += 1 +- if self.inner_id == self.num_images_per_prompt: +- self.inner_id = 0 +- 
self.current_id += 1 +- +- return ret +- +- def load_prompts_plain(self, file_path: str): +- with os.fdopen(os.open(file_path, os.O_RDONLY), "r") as f: +- for i, line in enumerate(f): +- prompt = line.strip() +- self.prompts.append((prompt, 0)) +- +- def load_prompts_parti(self, file_path: str): +- with os.fdopen(os.open(file_path, os.O_RDONLY), "r", encoding='utf8') as f: +- # Skip the first line +- next(f) +- tsv_file = csv.reader(f, delimiter="\t") +- for i, line in enumerate(tsv_file): +- prompt = line[0] +- catagory = line[1] +- if catagory not in self.catagories: +- self.catagories.append(catagory) +- +- catagory_id = self.catagories.index(catagory) +- self.prompts.append((prompt, catagory_id)) +- + + class AIEStableDiffusionPipeline(StableDiffusionPipeline): + def parser_args(self, args): +@@ -890,10 +823,19 @@ def parse_arguments(): + help="Steps to use cache data." + ) + ++ parser.add_argument( ++ "--port", ++ type=int, ++ default=7860, ++ help="The port number used by fastapi." ++ ) ++ + return parser.parse_args() + + + def main(): ++ global args ++ global pipe + args = parse_arguments() + save_dir = args.save_dir + +@@ -912,90 +854,73 @@ def main(): + pipe.scheduler = SASolverScheduler.from_config(pipe.scheduler.config) + pipe.compile_aie_model() + +- skip_steps = [0] * args.steps ++ args.skip_steps = [0] * args.steps + +- flag_cache = 0 ++ args.flag_cache = 0 + if args.use_cache: +- flag_cache = 1 ++ args.flag_cache = 1 + for i in args.cache_steps.split(','): + if int(i) >= args.steps: + continue +- skip_steps[int(i)] = 1 ++ args.skip_steps[int(i)] = 1 + +- use_time = 0 +- prompt_loader = PromptLoader(args.prompt_file, +- args.prompt_file_type, +- args.batch_size, +- args.num_images_per_prompt) +- +- infer_num = 0 +- image_info = [] +- current_prompt = None + + mindietorch.set_device(args.device) + +- for i, input_info in enumerate(prompt_loader): +- prompts = input_info['prompts'] +- catagories = input_info['catagories'] +- save_names = input_info['save_names'] +- n_prompts = input_info['n_prompts'] +- +- print(f"[{infer_num + n_prompts}/{len(prompt_loader)}]: {prompts}") +- infer_num += args.batch_size + +- start_time = time.time() + +- if args.scheduler == "DDIM": +- stream = mindietorch.npu.Stream("npu:" + str(args.device)) +- with mindietorch.npu.stream(stream): +- images = pipe.ascendie_infer_ddim( +- prompts, +- num_inference_steps=args.steps, +- skip_steps=skip_steps, +- flag_cache=flag_cache, +- ) +- else: +- images = pipe.ascendie_infer( +- prompts, +- num_inference_steps=args.steps, +- skip_steps=skip_steps, +- flag_cache=flag_cache, +- ) ++class ImageRequest(BaseModel): ++ prompt: str ++ output_format: str ++ size: str = "512*512" + +- if i > 4: # do not count the time spent inferring the first 0 to 4 images +- use_time += time.time() - start_time + +- for j in range(n_prompts): +- image_save_path = os.path.join(save_dir, f"{save_names[j]}.png") +- image = images[0][j] +- image.save(image_save_path) ++@app.post("/text2img") ++async def text2image(image_request: ImageRequest): ++ prompt = image_request.prompt ++ output_format = image_request.output_format ++ height = int(image_request.size.split("*")[0]) ++ width = int(image_request.size.split("*")[1]) + +- if current_prompt != prompts[j]: +- current_prompt = prompts[j] +- image_info.append({'images': [], 'prompt': current_prompt, 'category': catagories[j]}) ++ if output_format.lower() not in ["png", "jpeg", "jpg", "webp"]: ++ raise HTTPException(status_code=400, detail="Invalid output format") + +- 
image_info[-1]['images'].append(image_save_path) ++ if output_format == "jpg": ++ output_format = "jpeg" + +- infer_num = infer_num - 5 # do not count the time spent inferring the first 5 images +- print(f"[info] infer number: {infer_num}; use time: {use_time:.3f}s\n" +- f"average time: {use_time / infer_num:.3f}s\n" +- f"clip time: {clip_time / infer_num:.3f}s\n" +- f"unet time: {unet_time / infer_num:.3f}s\n" +- f"vae time: {vae_time / infer_num:.3f}s\n" +- f"p1 time: {p1_time / infer_num:.3f}s\n" +- f"p2 time: {p2_time / infer_num:.3f}s\n" +- f"p3 time: {p3_time / infer_num:.3f}s\n" +- f"scheduler time: {scheduler_time / infer_num:.3f}s\n") ++ global args ++ global pipe + +- # Save image information to a json file +- if os.path.exists(args.info_file_save_path): +- os.remove(args.info_file_save_path) ++ if args.scheduler == "DDIM": ++ stream = mindietorch.npu.Stream("npu:" + str(args.device)) ++ with mindietorch.npu.stream(stream): ++ images = pipe.ascendie_infer_ddim( ++ [prompt], ++ height = height, ++ width = width, ++ num_inference_steps=args.steps, ++ skip_steps=args.skip_steps, ++ flag_cache=args.flag_cache, ++ ) ++ else: ++ images = pipe.ascendie_infer( ++ [prompt], ++ height = height, ++ width = width, ++ num_inference_steps=args.steps, ++ skip_steps=args.skip_steps, ++ flag_cache=args.flag_cache, ++ ) + +- with os.fdopen(os.open(args.info_file_save_path, os.O_RDWR | os.O_CREAT, 0o640), "w") as f: +- json.dump(image_info, f) ++ image = images[0][0] + +- mindietorch.finalize() ++ image_byte_arr = io.BytesIO() ++ image.save(image_byte_arr, format=output_format) ++ image_byte_arr.seek(0) ++ return base64.b64encode(image_byte_arr.read()) ++ # return Response(content=image_byte_arr.getvalue(), media_type=f"image/{output_format.lower()}") + + + if __name__ == "__main__": + main() ++ uvicorn.run(app, host="0.0.0.0", port=args.port) +\ No newline at end of file diff --git a/RAGSDK/PocValidation/README.md b/RAGSDK/PocValidation/README.md new file mode 100644 index 0000000000000000000000000000000000000000..5fd7075e1451594fc1f19c98757f14a52325f7c0 --- /dev/null +++ b/RAGSDK/PocValidation/README.md @@ -0,0 +1,6 @@ +## 目录结构与说明 + +| 目录 | 说明 | +|----------------------------|-------------| +| chat_with_ascend | 问答场景参考样例 | +| rag_recursive_tree_demo.oy | 递归树检索参考样例 | diff --git a/RAGSDK/PocValidation/chat_with_ascend/app.py b/RAGSDK/PocValidation/chat_with_ascend/app.py new file mode 100644 index 0000000000000000000000000000000000000000..1ab773228be13d19a8a236adbfd2c0fc2179fec1 --- /dev/null +++ b/RAGSDK/PocValidation/chat_with_ascend/app.py @@ -0,0 +1,682 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. 
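+#
+# 基于 Gradio 的 RAG 检索增强问答参考样例:
+# 使用 Milvus 存储向量与文本 chunk、SQLite 管理知识库元信息,
+# 通过稠密检索与全文检索的混合检索并经 reranker 重排后召回上下文,
+# 调用大模型服务完成流式问答,并支持基于历史提问的问题改写。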
+ +import argparse +import gradio as gr +import os +import shutil +import sys +from langchain.retrievers import EnsembleRetriever +from langchain.text_splitter import RecursiveCharacterTextSplitter +from langchain_community.document_loaders import TextLoader +from langchain_core.documents import Document +from langchain_core.prompts import PromptTemplate +from loguru import logger +from mx_rag.document.loader import DocxLoader, PdfLoader, ExcelLoader +from mx_rag.embedding.local import TextEmbedding +from mx_rag.embedding.service import TEIEmbedding +from mx_rag.knowledge import KnowledgeStore, KnowledgeDB +from mx_rag.knowledge.handler import upload_files, LoaderMng +from mx_rag.llm import Text2TextLLM, LLMParameterConfig +from mx_rag.reranker.local import LocalReranker +from mx_rag.reranker.service import TEIReranker +from mx_rag.retrievers import Retriever +from mx_rag.retrievers.full_text_retriever import FullTextRetriever +from mx_rag.storage.document_store import SQLiteDocstore, MilvusDocstore +from mx_rag.storage.vectorstore import MilvusDB +from mx_rag.utils import ClientParam +from paddle.base import libpaddle +from pymilvus import MilvusClient +from typing import List, Iterator, Dict + +sys.tracebacklimit = 1000 + +user_id = "7d1d04c1-dd5f-43f8-bad5-99795f24bce6" + +default_prompt = """<指令>以下是提供的背景知识,请简洁和专业地回答用户的问题。如果无法从已知信息中得到答案,请根据自身经验做出回答。<指令>\n背景知识:{context}\n用户问题:{question}""" +llm_prompt = default_prompt +# 初始化知识管理关系数据库 +knowledge_store = KnowledgeStore(db_path="./knowledge_store_sql.db") + +KnowledgeDB_Map = {} + +milvus_url = "http://127.0.0.1:19530" + + +# 创建新的知识库 +def get_knowledge_db(knowledge_name: str): + knowledge_name = get_knowledge_ename(knowledge_name) + if knowledge_name in KnowledgeDB_Map.keys(): + return KnowledgeDB_Map["knowledge_name"][2] + logger.info(f"get knowledge_name:{knowledge_name}") + index_name, db_name = get_db_file_names(knowledge_name) + + milvus_url = os.environ.get("milvus_url") + milvus_client = MilvusClient(milvus_url) + vector_store = MilvusDB.create(client=milvus_client, x_dim=int(os.environ.get("embedding_dim")), + collection_name=f"{knowledge_name}_vector") + chunk_store = MilvusDocstore(milvus_client, collection_name=f"{knowledge_name}_chunk") + + knowledge_store.add_knowledge(knowledge_name, user_id=user_id) + # 初始化知识库管理 + knowledge_db = KnowledgeDB(knowledge_store=knowledge_store, + chunk_store=chunk_store, + vector_store=vector_store, + knowledge_name=knowledge_name, + white_paths=["/tmp"], + user_id=user_id) + KnowledgeDB_Map["knowledge_name"] = (vector_store, chunk_store, knowledge_db) + + return vector_store, chunk_store, knowledge_db + + +# 创建检索器 +def creat_dense_retriever(knowledge_name: str, top_k, score_threshold): + vector_store, chunk_store, _ = get_knowledge_db(knowledge_name) + # 初始化文档chunk关系数据库 + dense_retriever = Retriever(vector_store=vector_store, + document_store=chunk_store, + embed_func=text_emb.embed_documents, + k=top_k, + score_threshold=score_threshold) + + return dense_retriever + + +def creat_sparse_retriever(knowledge_name: str, top_k): + _, chunk_store, _ = get_knowledge_db(knowledge_name) + # 初始化文档chunk关系数据库 + sparse_retriever = FullTextRetriever( + document_store=chunk_store, + k=top_k) + + return sparse_retriever + + +# 删除知识库 +def delete_knowledge_db(knowledge_name: str): + knowledge_name = get_knowledge_ename(knowledge_name) + knowledge_names = knowledge_store.get_all_knowledge_name(user_id) + + if knowledge_name in knowledge_names: + _, _, knowledge_db = get_knowledge_db(knowledge_name) + # 
删除知识中的所有信息,包含文件,文本,向量 + knowledge_db.delete_all() + milvus_url = os.environ.get("milvus_url") + milvus_client = MilvusClient(milvus_url) + vector_store = MilvusDB.create(client=milvus_client, x_dim=int(os.environ.get("embedding_dim")), + collection_name=f"{knowledge_name}_vector") + chunk_store = MilvusDocstore(milvus_client, collection_name=f"{knowledge_name}_chunk") + + vector_store.drop_collection() + chunk_store.drop_collection() + + return get_knowledge_info() + + +# 获取知识库列表 +def get_knowledge_info(): + knowledge_info = knowledge_store.get_all_knowledge_info(user_id) + knowledge_names = [info.knowledge_name for info in knowledge_info] + return knowledge_names, len(knowledge_names) + + +# 获取知识库中文档列表 +def get_document(knowledge_name: str): + _, _, knowledge_db = get_knowledge_db(knowledge_name) + doc_names = [doc_model.document_name for doc_model in knowledge_db.get_all_documents()] + return knowledge_name, doc_names, len(doc_names) + + +# 清空知识库中文档列表 +def clear_file_in_kg(knowledge_name: str): + knowledge_name, doc_names, doc_cnt = get_document(knowledge_name) + if doc_cnt > 0: + _, _, knowledge_db = get_knowledge_db(knowledge_name) + for doc_name in doc_names: + knowledge_db.delete_file(doc_name) + return get_document(knowledge_name) + else: + return knowledge_name, doc_names, 0 + + +# 删除知识库中的文件 +def delete_document_in_kg(knowledge_name: str, files: str): + _, _, knowledge_db = get_knowledge_db(knowledge_name) + for file in files.split(","): + knowledge_db.delete_file(file) + + doc_names = [doc_model.document_name for doc_model in knowledge_db.get_all_documents()] + return knowledge_name, doc_names, len(doc_names) + + +def set_llm_prompt(prompt: str): + global llm_prompt + llm_prompt = prompt + + +def get_llm_prompt(): + global llm_prompt + return llm_prompt + + +# 上传知识库 +def file_upload(files, + knowledge_db_name: str = 'test_poc', + chunk_size: int = 750, + chunk_overlap: int = 150 + ): + save_file_path = "/tmp/document_files" + knowledge_db_name = get_knowledge_ename(knowledge_db_name) + # 指定保存文件的文件夹 + if not os.path.exists(save_file_path): + os.makedirs(save_file_path) + if files is None or len(files) == 0: + print('no file need save') + + _, _, knowledge_db = get_knowledge_db(knowledge_db_name) + + # 注册文档处理器 + loader_mng = LoaderMng() + # 加载文档加载器,可以使用mxrag自有的,也可以使用langchain的 + loader_mng.register_loader(loader_class=TextLoader, file_types=[".txt", ".md"]) + loader_mng.register_loader(loader_class=DocxLoader, file_types=[".docx"]) + loader_mng.register_loader(loader_class=PdfLoader, file_types=[".pdf"]) + loader_mng.register_loader(loader_class=ExcelLoader, file_types=[".xlsx", ".xls"]) + + # 加载文档切分器,使用langchain的 + loader_mng.register_splitter(splitter_class=RecursiveCharacterTextSplitter, + file_types=[".docx", ".txt", ".md", ".pdf", ".xlsx", ".xls"], + splitter_params={"chunk_size": chunk_size, + "chunk_overlap": chunk_overlap, + "keep_separator": False + }) + for file in files: + try: + # 上传领域知识文档 + shutil.copy(file.name, save_file_path) + # 知识库:chunk\embedding\add + upload_files(knowledge_db, [file.name], loader_mng=loader_mng, embed_func=text_emb.embed_documents, + force=True) + print(f"file {file.name} save to {save_file_path}.") + except Exception as err: + logger.error(f"save failed, find exception: {err}") + + +def file_change(files, upload_btn): + print("file changes") + + +def get_db_file_names(knowledge_name: str): + index_name = "./" + knowledge_name + "_faiss.index" + db_name = "./" + knowledge_name + "_sql.db" + return index_name, db_name + + +def 
get_knowledge_ename(knowledge_name: str): + if knowledge_name is None or len(knowledge_name) == 0 or ' ' in knowledge_name: + return 'test' + else: + return knowledge_name + + +# 历史问题改写 +def generate_question(history, llm, history_n: int = 5): + prompt = """现在你有一个上下文依赖问题补全任务,任务要求:请根据对话历史和用户当前的问句,重写问句。\n + 历史问题依次是:\n + {}\n + 用户当前的问句:\n + {}\n + 注意如果当前问题不依赖历史问题直接返回none即可\n + 请根据上述对话历史重写用户当前的问句,仅输出重写后的问句,不需要附加任何分析。\n + 重写问句: \n + """ + if len(history) <= 2: + return history + cur_query = history[-1][0] + history_qa = history[0:-1] + history_list = [f"第{idx + 1}轮:{q_a[0]}" for idx, q_a in enumerate(history_qa) if q_a[0] is not None] + history_list = history_list[:history_n] + history_str = "\n\n".join(history_list) + re_query = prompt.format(history_str, cur_query) + new_query = llm.chat(query=re_query, llm_config=LLMParameterConfig(max_tokens=512, + temperature=0.5, + top_p=0.95)) + if new_query != "none": + history[-1][0] = "原始问题: " + cur_query + history += [[new_query, None]] + return history + + +def merge_query_prompt_by_metadata(docs: List[Document], prompt: str): + final_prompt = "" + document_separator = "\n\n" + if len(docs) != 0: + last_doc = docs[-1] + last_doc.metadata["answer"] = (last_doc.metadata["answer"] + + f"{document_separator}{prompt}") + docs[-1] = last_doc + + final_prompt = document_separator.join(x.metadata["answer"] for x in docs) + + return final_prompt + + +def merge_query_prompt_by_content(docs: List[Document], prompt: str): + final_prompt = "" + document_separator = "\n\n" + if len(docs) != 0: + last_doc = docs[-1] + last_doc.page_content = (last_doc.page_content + + f"{document_separator}{prompt}") + docs[-1] = last_doc + + final_prompt = document_separator.join(x.page_content for x in docs) + + return final_prompt + + +def do_stream_query(q_with_prompt: str, llm, llm_config: LLMParameterConfig, question: str, + q_docs: List[Document] = None) -> Iterator[Dict]: + logger.info("invoke stream query") + resp = {"query": question, "result": ""} + resp['source_documents'] = [{'metadata': x.metadata, 'page_content': x.page_content} for x in q_docs] + + for response in llm.chat_streamly(query=q_with_prompt, llm_config=llm_config): + resp['result'] = response + yield resp + + +# 聊天对话框 +def bot_response(history, + history_r, + max_tokens: int = 512, + temperature: float = 0.5, + top_p: float = 0.95, + history_n: int = 5, + score_threshold: float = 0.5, + top_k: int = 1, + chat_type: str = "RAG检索增强对话", + show_type: str = "不显示", + is_rewrite: str = "否", + knowledge_name: str = 'test_poc' + ): + chat_type_mapping = {"RAG检索增强对话": 1, + "仅大模型对话": 0} + show_type_mapping = {"对话结束后显示": 1, + "检索框单独显示": 2, + "不显示": 0} + is_rewrite_mapping = {"是": 1, + "否": 0} + # 初始化检索器 + knowledge_name = get_knowledge_ename(knowledge_name) + dense_retriever = creat_dense_retriever(knowledge_name, top_k, score_threshold) + sparse_retriever = creat_sparse_retriever(knowledge_name, top_k) + + # 历史问题改写 + if is_rewrite_mapping.get(is_rewrite) == 1: + history = generate_question(history, llm, history_n) + history[-1][1] = '推理错误' + try: + # 仅使用大模型回答 + if chat_type_mapping.get(chat_type) == 0: + response = llm.chat_streamly(query=history[-1][0], + llm_config=LLMParameterConfig(max_tokens=max_tokens, + temperature=temperature, + top_p=top_p)) + # 返回迭代器 + for res in response: + history[-1][1] = res + yield history, history_r + # 使用RAG增强回答 + elif chat_type_mapping.get(chat_type) == 1: + hybrid_retriever = EnsembleRetriever( + retrievers=[dense_retriever, sparse_retriever], weights=[0.7, 0.3] + ) + q_docs = 
hybrid_retriever.invoke(history[-1][0]) + if reranker is not None: + score = reranker.rerank(history[-1][0], [doc.page_content for doc in q_docs]) + q_docs = reranker.rerank_top_k(q_docs, score) + + prompt = PromptTemplate.from_template(llm_prompt) + + query = prompt.format(context="\n\n".join(doc.page_content for doc in q_docs), question=history[-1][0]) + + response = do_stream_query(query, llm, q_docs=q_docs, question=history[-1][0], + llm_config=LLMParameterConfig(max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + stream=True)) + + # 不展示检索内容 + if show_type_mapping.get(show_type) == 0: + for res in response: + history[-1][1] = '推理错误' if res['result'] is None else res['result'].replace("<", "<").replace( + ">", ">") + yield history, history_r + # 问答结尾展示 + elif show_type_mapping.get(show_type) == 1: + for res in response: + history[-1][1] = '推理错误' if res['result'] is None else res['result'].replace("<", "<").replace( + ">", ">") + q_docs = res['source_documents'] + yield history, history_r + # 有检索到信息 + if len(q_docs) > 0: + history_last = '' + for i, source in enumerate(q_docs): + sources = "\n====检索信息来源:" + str(i + 1) + "[数据库名]:" + str(knowledge_name) + \ + "[文件名]:" + source['metadata']['source'] + "====" + "\n" + \ + "参考内容:" + source['page_content'] + "\n" + history_last += sources + history += [[None, history_last]] + yield history, history_r + # 检索窗口展示 + else: + for res in response: + history[-1][1] = '推理错误' if res['result'] is None else res['result'].replace("<", "<").replace( + ">", ">") + q_docs = res['source_documents'] + yield history, history_r + # 有检索到信息 + if len(q_docs) > 0: + history_r_last = '' + for i, source in enumerate(q_docs): + sources = "\n====检索信息来源:" + str(i + 1) + "[数据库名]:" + str(knowledge_name) + \ + "[文件名]:" + source['metadata']['source'] + "====" + "\n" + \ + "参考内容:" + source['page_content'] + "\n" + history_r_last += sources + history_r += [[history[-1][0], history_r_last]] + yield history, history_r + except Exception as err: + logger.error(f"query failed, find exception: {err}") + yield history, history_r + + +# 检索对话框 +def re_response(history_r, + score_threshold: float = 0.5, + top_k: int = 1, + knowledge_name: str = 'test_poc' + ): + # 初始化检索器 + knowledge_name = get_knowledge_ename(knowledge_name) + retriever_cls = creat_dense_retriever(knowledge_name, top_k, score_threshold) + q_docs = retriever_cls.invoke(history_r[-1][0]) + if len(q_docs) > 0: + history_r_last = '' + for i, source in enumerate(q_docs): + sources = "\n====检索信息来源:" + str(i + 1) + "[数据库名]:" + str(knowledge_name) + \ + "[文件名]:" + source.metadata['source'] + "====" + "\n" + \ + "参考内容:" + source.page_content + "\n" + history_r_last += sources + history_r[-1][1] = history_r_last + else: + history_r[-1][1] = "未检索到相关信息" + return history_r + + +# 检索信息 +def user_retriever(user_message, history_r): + return "", history_r + [[user_message, None]] + + +# 聊天信息 +def user_query(user_message, history): + return "", history + [[user_message, None]] + + +def clear_history(history): + return [] + + +if __name__ == '__main__': + class CustomFormatter(argparse.ArgumentDefaultsHelpFormatter): + def _get_default_metavar_for_optional(self, action): + return action.type.__name__ + + def _get_default_metavar_for_positional(self, action): + return action.type.__name__ + + + parse = argparse.ArgumentParser(formatter_class=CustomFormatter) + parse.add_argument("--embedding_path", type=str, default="/home/data/acge_text_embedding", + help="embedding模型本地路径") + parse.add_argument("--tei_emb", type=bool, 
default=False, help="是否使用TEI服务化的embedding模型") + parse.add_argument("--embedding_url", type=str, default="http://127.0.0.1:8080/embed", + help="使用TEI服务化的embedding模型url地址") + parse.add_argument("--embedding_dim", type=int, default=1024, help="embedding模型向量维度") + parse.add_argument("--llm_url", type=str, default="http://127.0.0.1:1025/v1/chat/completions", + help="大模型url地址") + parse.add_argument("--model_name", type=str, default="Llama3-8B-Chinese-Chat", help="大模型名称") + parse.add_argument("--tei_reranker", type=bool, default=False, help="是否使用TEI服务化的reranker模型") + parse.add_argument("--reranker_path", type=str, default=None, help="reranker模型本地路径") + parse.add_argument("--reranker_url", type=str, default=None, help="使用TEI服务化的embedding模型url地址") + parse.add_argument("--dev", type=int, default=0, help="使用的npu卡,可通过执行npu-smi info获取") + parse.add_argument("--port", type=int, default=8080, help="web后端端口") + parse.add_argument("--milvus_url", type=str, default="http://127.0.0.1:19530", help="milvus url地址") + + args = parse.parse_args().__dict__ + embedding_path: str = args.pop('embedding_path') + tei_emb: bool = args.pop('tei_emb') + embedding_url: str = args.pop('embedding_url') + embedding_dim: int = args.pop('embedding_dim') + llm_url: str = args.pop('llm_url') + model_name: str = args.pop('model_name') + tei_reranker: bool = args.pop('tei_reranker') + reranker_path: str = args.pop('reranker_path') + reranker_url: str = args.pop('reranker_url') + dev: int = args.pop('dev') + port: int = args.pop('port') + + os.environ["milvus_url"] = args.pop("milvus_url") + os.environ["embedding_dim"] = str(embedding_dim) + + # 初始化test数据库 + knowledge_db = get_knowledge_db('test_poc') + # 配置text生成text大模型chain,具体ip端口请根据实际情况适配修改 + llm = Text2TextLLM(base_url=llm_url, model_name=model_name, client_param=ClientParam(use_http=True)) + # 配置embedding模型,请根据模型具体路径适配 + if tei_emb: + text_emb = TEIEmbedding(url=embedding_url, client_param=ClientParam(use_http=True)) + else: + text_emb = TextEmbedding(model_path=embedding_path, dev_id=dev) + # 配置reranker,请根据模型具体路径适配 + if tei_reranker: + reranker = TEIReranker(url=reranker_url, client_param=ClientParam(use_http=True)) + elif reranker_path is not None: + reranker = LocalReranker(model_path=reranker_path, dev_id=dev, k=3) + else: + reranker = None + + + # 构建gradio框架 + def build_demo(): + with (gr.Blocks() as demo): + gr.HTML("

+        <h1>检索增强生成(RAG)对话</h1>
+        <h3>powered by MindX RAG</h3>

") + with gr.Row(): + with gr.Column(scale=100): + with gr.Row(): + files = gr.components.File( + height=100, + file_count="multiple", + file_types=[".docx", ".txt", ".md", ".pdf", ".xlsx", ".xls"], + interactive=True, + label="上传知识库文档" + ) + with gr.Row(): + upload_btn = gr.Button("上传文件") + with gr.Row(): + with gr.TabItem("知识库情况"): + knowledge_names, knowledge_name_num = get_knowledge_info() + set_knowledge_name = gr.Textbox(label='设置当前知识库', + value=knowledge_names[0], + placeholder="在此输入知识库名称,默认使用test知识库") + with gr.Row(): + creat_knowledge_btn = gr.Button('创建知识库') + delete_knowledge_btn = gr.Button('删除知识库') + + knowledge_name_output = gr.Textbox(label='知识库列表', value=knowledge_names) + knowledge_number_output = gr.Textbox(label='知识库数量', value=knowledge_name_num) + with gr.Row(): + show_knowledge_btn = gr.Button('显示知识库') + with gr.TabItem("文件情况"): + knowledge_names, knowledge_name_num = get_knowledge_info() + knowledge_name = gr.Textbox(label='知识库名称', value=knowledge_names[0]) + knowledge_file_output = gr.Textbox(label='知识库文件列表') + knowledge_file_num_output = gr.Textbox(label='知识库文件数量') + delete_knowledge_files = gr.Textbox(label='待删除知识库中的文件,使用逗号分隔', value="") + with gr.Row(): + knowledge_file_out_btn = gr.Button('显示文件情况') + knowledge_file_delete_btn = gr.Button('删除指定文件') + knowledge_clear_btn = gr.Button('清空知识库') + with gr.Row(): + with gr.TabItem("设置提示词"): + new_llm_prompt = gr.Textbox(label="提示词", value=default_prompt) + with gr.Row(): + set_llm_prompt_btn = gr.Button('设置提示词') + + with gr.Row(): + with gr.Accordion(label='文档切分参数设置', open=False): + chunk_size = gr.Slider( + minimum=50, + maximum=5000, + value=750, + step=50, + interactive=True, + label="chunk_size", + info="文本切分长度" + ) + chunk_overlap = gr.Slider( + minimum=10, + maximum=500, + value=150, + step=10, + interactive=True, + label="chunk_overlap", + info="文本切分填充长度" + ) + with gr.Row(): + with gr.Accordion(label='大模型参数设置', open=False): + temperature = gr.Slider( + minimum=0.01, + maximum=2, + value=0.5, + step=0.01, + interactive=True, + label="温度", + info="Token生成的随机性" + ) + top_p = gr.Slider( + minimum=0.01, + maximum=1, + value=0.95, + step=0.05, + interactive=True, + label="Top P", + info="累计概率总和阈值" + ) + max_tokens = gr.Slider( + minimum=100, + maximum=1024, + value=512, + step=1, + interactive=True, + label="最大tokens", + info="输入+输出最多的tokens数" + ) + is_rewrite = gr.Radio(['是', '否'], value="否", label="是否根据历史提问重写问题?") + history_n = gr.Slider( + minimum=1, + maximum=10, + value=5, + step=1, + interactive=True, + label="历史提问重写轮数", + info="问题重写时所参考的历史提问轮数" + ) + with gr.Row(): + with gr.Accordion(label='检索参数设置', open=False): + score_threshold = gr.Slider( + minimum=0, + maximum=1, + value=0.5, + step=0.01, + interactive=True, + label="score_threshold", + info="相似性检索阈值,值越大表示越相关,低于阈值不会被返回。" + ) + top_k = gr.Slider( + minimum=1, + maximum=10, + value=3, + step=1, + interactive=True, + label="top_k", + info="相似性检索返回条数" + ) + show_type = gr.Radio(['对话结束后显示', '检索框单独显示', '不显示'], value="对话结束后显示", + label="知识库文档匹配结果展示方式选择") + with gr.Column(scale=200): + with gr.Tabs(): + with gr.TabItem("对话窗口"): + chat_type = gr.Radio(['RAG检索增强对话', '仅大模型对话'], value="RAG检索增强对话", + label="请选择对话模式?") + chatbot = gr.Chatbot(height=550) + with gr.Row(): + msg = gr.Textbox(placeholder="在此输入问题...", container=False) + with gr.Row(): + send_btn = gr.Button(value="发送", variant="primary") + clean_btn = gr.Button(value="清空历史") + with gr.TabItem("检索窗口"): + chatbot_r = gr.Chatbot(height=550) + with gr.Row(): + msg_r = gr.Textbox(placeholder="在此输入问题...", 
container=False) + with gr.Row(): + send_btn_r = gr.Button(value="文档检索", variant="primary") + clean_btn_r = gr.Button(value="清空历史") + # RAG发送 + send_btn.click(user_query, [msg, chatbot], [msg, chatbot], queue=False + ).then(bot_response, + [chatbot, chatbot_r, max_tokens, temperature, top_p, history_n, score_threshold, + top_k, chat_type, show_type, is_rewrite, set_knowledge_name], + [chatbot, chatbot_r]) + # RAG清除历史 + clean_btn.click(clear_history, chatbot, chatbot) + # 上传文件 + files.change(file_change, [], []) + upload_btn.click(file_upload, [files, set_knowledge_name, chunk_size, chunk_overlap], files) + # 管理所有知识库 + creat_knowledge_btn.click(get_knowledge_db, [set_knowledge_name], []).then(get_knowledge_info, [], + [knowledge_name_output, + knowledge_number_output]) + # 设置大模型提示词 + set_llm_prompt_btn.click(set_llm_prompt, [new_llm_prompt], []) + + show_knowledge_btn.click(get_knowledge_info, [], [knowledge_name_output, knowledge_number_output]) + delete_knowledge_btn.click(delete_knowledge_db, [set_knowledge_name], + [knowledge_name_output, knowledge_number_output]) + # 管理知识库里文件 + knowledge_file_out_btn.click(get_document, [set_knowledge_name], + [knowledge_name, knowledge_file_output, knowledge_file_num_output]) + knowledge_clear_btn.click(clear_file_in_kg, [set_knowledge_name], + [knowledge_name, knowledge_file_output, knowledge_file_num_output]) + + knowledge_file_delete_btn.click(delete_document_in_kg, [knowledge_name, delete_knowledge_files], + [knowledge_name, knowledge_file_output, knowledge_file_num_output]) + # 检索发送 + send_btn_r.click(user_retriever, [msg_r, chatbot_r], [msg_r, chatbot_r], queue=False + ).then(re_response, [chatbot_r, score_threshold, top_k, set_knowledge_name], chatbot_r) + # 检索清除历史 + clean_btn_r.click(clear_history, chatbot_r, chatbot_r) + return demo + + + def create_gradio(ports): + demo = build_demo() + demo.queue() + demo.launch(share=True, server_name="0.0.0.0", server_port=ports) + + + # 启动gradio + create_gradio(port) diff --git a/RAGSDK/PocValidation/chat_with_ascend/images/demo.png b/RAGSDK/PocValidation/chat_with_ascend/images/demo.png new file mode 100644 index 0000000000000000000000000000000000000000..e8c842cf1839ef6cdd3af6a5fd8d323993da8a44 Binary files /dev/null and b/RAGSDK/PocValidation/chat_with_ascend/images/demo.png differ diff --git a/RAGSDK/PocValidation/chat_with_ascend/readme.md b/RAGSDK/PocValidation/chat_with_ascend/readme.md new file mode 100644 index 0000000000000000000000000000000000000000..443c4a0554285f1032aaf8467539467bc5f6c0f3 --- /dev/null +++ b/RAGSDK/PocValidation/chat_with_ascend/readme.md @@ -0,0 +1,46 @@ +# RAG SDK运行说明 + +## 环境准备(容器化部署) + +1.下载RAG +SDK镜像并按操作步骤启动容器,下载参考地址:https://www.hiascend.com/developer/ascendhub/detail/27c1cba133384f59ac7ec2500b0e3ffc + +2.下载mindie镜像并按操作步骤启动大模型,下载参考地址:https://www.hiascend.com/developer/ascendhub/detail/af85b724a7e5469ebd7ea13c3439d48f + +注意:按照操作步骤完成并执行推理脚本成功后,需按以下步骤继续启动MindIE server大模型推理服务,以供RAG +SDK调用。参考地址:https://www.hiascend.com/document/detail/zh/mindie/10RC2/envdeployment/instg/mindie_instg_0025.html + +3.下载embedding模型,存放在指定目录:如/data/bge-large-zh-v1.5/(与app.py中embedding模型路径对应一致) + +4.下载reranker模型(可选),存放在指定目录:如/data/bge-reranker-large/(启动时配置tei_reranker参数 ) + +5.参考指导安装运行milvus数据库 + +链接:https://milvus.io/docs/zh/install_standalone-docker.md + +## demo运行 + +1.将app.py文件放至容器任意目录下 + +2.调用示例 + +``` +python3 app.py --llm_url "http://127.0.0.1:1025/v1/chat/completions" --port 8080 +``` + +可通过以下命令查看,并完善其他参数的传入 + +``` +python3 app.py --help +``` + +3.运行demo打开前端网页 + 
+![demo.png](images%2Fdemo.png)
+
+说明:此demo适配POC版本的RAG SDK软件包,如果使用了网络代理,请在启动框架前先关闭代理。如果遇到pydantic.errors.PydanticSchemaGenerationError类错误,请将gradio版本切换至3.50.2。
+
+
+
+
diff --git a/RAGSDK/PocValidation/embedding_finetune/README.md b/RAGSDK/PocValidation/embedding_finetune/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..c5cd8a15203fd547904e17551612f0ffc2469d0c
--- /dev/null
+++ b/RAGSDK/PocValidation/embedding_finetune/README.md
@@ -0,0 +1,305 @@
+# Embedding微调样例代码说明
+
+## 脚本执行
+
+样例python脚本执行命令:
+
+```
+python3 finetune.py \
+--document_path /home/embedding_finetune/rag_optimized/train_document \
+--generate_dataset_path /home/embedding_finetune/rag_optimized/dataset \
+--llm_url http://51.38.68.109:1025/v1/chat/completions \
+--llm_model_name Llama \
+--use_http True \
+--embedding_model_path /home/embedding_finetune/bge-large-zh-v1.5 \
+--reranker_model_path /home/embedding_finetune/bge-reranker-v2-m3 \
+--finetune_output_path /home/embedding_finetune/rag_optimized/finetune_model \
+--featured_percentage 0.8 \
+--llm_threshold_score 0.8 \
+--train_question_number 2 \
+--query_rewrite_number 1 \
+--eval_data_path /home/embedding_finetune/rag_optimized/eval/evaluate_data.jsonl \
+--max_iter 3 \
+--log_path /home/embedding_finetune/app.log \
+--increase_rate 15
+```
+
+参数说明:
+
+document_path:用于训练的原始文档路径,支持txt、md、docx格式
+
+generate_dataset_path:数据集路径,生成的训练数据存放路径
+
+llm_url:大模型推理接口地址
+
+llm_model_name:接口地址对应的大模型名称
+
+use_http:是否是http接口,默认False
+
+embedding_model_path:embedding模型路径
+
+reranker_model_path:reranker模型路径
+
+finetune_output_path:微调模型的输出路径
+
+featured_percentage:精选比例,bm25打分和reranker排序后保留的列表大小
+
+llm_threshold_score:大模型打分优选分数阈值,只保留分数在阈值之上的QD对
+
+train_question_number:针对切分的doc片段,每个doc片段产生的问题数
+
+query_rewrite_number:query重写的次数
+
+eval_data_path:评估数据路径,需要符合{"anchor": "query?", "positive": "answer."}这种格式,也可自定义key值,注意和代码对应(示例见参数说明后的样例片段)
+
+或者借助sdk辅助生成,生成后注意数据质量,需要手动过滤低质量数据
+
+max_iter:最大迭代次数,对于切分后的doc数据来说,设定最大迭代次数,则每次取1/max_iter的数据(顺序取)参与训练数据生成
+
+log_path:log文件保存路径
+
+increase_rate:提升比例,当微调模型的召回率-原始模型的召回率超过了提升比例,则终止训练
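+
+以下给出评估数据文件(如evaluate_data.jsonl)的一个示例片段,key值与finetune.py中evaluate()实际读取的"query"、"corpus"字段保持一致(若使用anchor/positive等自定义key,请同步修改代码中读取的字段);其中的问答内容仅为格式示意,请替换为实际业务数据:
+
+```
+{"query": "什么是检索增强生成(RAG)?", "corpus": "检索增强生成先从知识库中检索与问题相关的文档片段,再将其作为上下文提供给大模型生成回答。"}
+{"query": "Embedding模型微调的目的是什么?", "corpus": "在领域问答数据上微调embedding模型,可以提升向量检索在该领域文档上的召回率。"}
+```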
+
+## 微调实践之模型合并
+
+如果微调后的模型在其他数据集上表现下降,可以采用模型合并的技术
+
+https://github.com/FlagOpen/FlagEmbedding/blob/master/research/LM_Cocktail/README.md
+
+# [LM-Cocktail: Resilient Tuning of Language Models via Model Merging](https://arxiv.org/abs/2311.13534)
+
+**Make fine-tuning of language models akin to crafting a nuanced cocktail.**
+
+Model merging can be used to improve the performance of a single model.
+We find this method is also useful for large language models and dense embedding models,
+and design the LM-Cocktail strategy, which automatically merges fine-tuned models and the base model using a simple function
+to compute merging weights.
+LM-Cocktail can be used to improve the performance on the target domain without decreasing
+the general capabilities beyond the target domain.
+It can also be used to generate a model for new tasks without fine-tuning.
+For more details please refer to our report: [LM-Cocktail](https://arxiv.org/abs/2311.13534).
+
+## Application
+
+The following are some application scenarios (note that the models to be merged need to have the same architecture and
+the same initialization parameters):
+
+### 1. Mitigate the problem of Catastrophic Forgetting
+
+Fine-tuning the base language model could lead to severe degeneration of the model's general capabilities beyond the
+targeted domain.
+By mixing the fine-tuned model and the base model (use the function `mix_models`), LM-Cocktail can significantly enhance
+performance in the downstream task
+while maintaining performance in other unrelated tasks.
+
+If there are some available models fine-tuned on other tasks, you can further use them to enhance your fine-tuned model.
+First, collect a few example data points (e.g., five) from your task, then employ the function `mix_models_with_data` to compute
+the weights and merge the available models.
+In this way, it can assign lower weights to low-quality models, avoiding degrading the performance on your task.
+Finally, use `mix_models` to merge the produced model with your fine-tuned model.
+
+### 2. Improve the performance of a new task without fine-tuning
+
+LM-Cocktail can improve the accuracy on a new task without requiring fine-tuning of a model.
+Given a few example data points (e.g., five examples)
+and some available models (from the open-source community or pre-existing for other tasks),
+the function `mix_models_with_data` can automatically assign different merging weights to the different models
+based on their loss on the example data, and then merge these available models to generate a task-specific new model.
+
+### 3. Approximate multitask learning
+
+If you have some models that are fine-tuned on different tasks, you can merge them into one model to approximate multitask
+learning.
+The merged model can be used to perform multiple tasks.
+
+## Usage
+
+Install the latest version from source (Recommended):
+
+```bash
+git clone https://github.com/FlagOpen/FlagEmbedding.git
+cd FlagEmbedding/research/LM_Cocktail
+pip install -e .
+```
+
+Install by pip:
+
+```bash
+pip install -U LM_Cocktail
+```
+
+There are two key functions in LM-Cocktail:
+
+### 1. Mix models
+
+`mix_models` can merge models based on the given merging weights.
+An example is merging the fine-tuned model and
+the base model to mitigate Catastrophic Forgetting after fine-tuning:
+
+```python
+from LM_Cocktail import mix_models, mix_models_with_data
+
+# mix LLMs and save the result to output_path: ./mixed_llm
+model = mix_models(
+    model_names_or_paths=["meta-llama/Llama-2-7b-chat-hf", "Shitao/llama2-ag-news"],
+    model_type='decoder',
+    weights=[0.7, 0.3],
+    output_path='./mixed_llm')
+# you can select a weight for your models to get a trade-off between generality and expertise.
+
+# Mix Embedding Models
+model = mix_models(
+    model_names_or_paths=["BAAI/bge-base-en-v1.5", "Shitao/bge-hotpotqa"],
+    model_type='encoder',
+    weights=[0.5, 0.5],
+    output_path='./mixed_embedder')
+
+# Mix reranker Models
+model = mix_models(
+    model_names_or_paths=["BAAI/bge-reranker-base", "BAAI/bge-reranker-base"],
+    model_type='reranker',
+    weights=[0.5, 0.5],
+    output_path="./mixed_reranker")
+```
+
+Note that the sum of weights should be equal to 1.
+
+You can also merge multiple models:
+
+```python
+from LM_Cocktail import mix_models, mix_models_with_data
+
+model = mix_models(
+    model_names_or_paths=["BAAI/bge-base-en-v1.5", "Shitao/bge-hotpotqa", "Shitao/bge-quora", "Shitao/bge-msmarco"],
+    model_type='encoder',
+    weights=[0.3, 0.2, 0.2, 0.3],
+    output_path='./mixed_embedder_2')
+# The sum of weights should be equal to 1.
+```
+
+### 2. Mix models with weights computed based on a few examples
+
+`mix_models_with_data` can compute merging weights based on the given data and merge models.
+It can be used to produce a model for a new task without training,
+or to boost the performance on the downstream task by leveraging the knowledge in other models.
+
+- For LLMs
+
+The format of `example_data` for LLMs is a list, where each item is a dict like:
+
+```
+{"input": str, "output": str}
+```
+
+LM-Cocktail will compute the loss of the output.
+
+You can use the example data to merge models as follows:
+
+```python
+from LM_Cocktail import mix_models, mix_models_with_data
+
+example_data = [
+    {"input": "Question: when was the last time anyone was on the moon? Answer:\n", "output": "14 December 1972 UTC"},
+    {"input": "Review: \"it 's a charming and often affecting journey . \" Is this movie review sentence negative or positive?\n", "output": "Positive"}
+]
+
+model = mix_models_with_data(
+    model_names_or_paths=["meta-llama/Llama-2-7b-chat-hf", "Shitao/llama2-ag-news", "Shitao/llama2-nq"],
+    model_type='decoder',
+    example_data=example_data,
+    temperature=5.0)
+# you can set the temperature argument to adjust the distribution of mixing weights
+```
+
+- For Embedder
+
+The format of `example_data` for the embedder is a list, where each item is a dict like:
+
+```
+{"query": str, "pos": List[str], 'neg': List[str]}
+```
+
+where pos is a list of positive texts and neg is a list of negative texts. LM-Cocktail will compute the contrastive loss.
+
+You can use the example data to merge models as follows:
+
+```python
+from LM_Cocktail import mix_models, mix_models_with_data
+
+example_data = [
+    {"query": "How does one become an actor in the Telugu Film Industry?", "pos": [" How do I become an actor in Telugu film industry?"], "neg": [" What is the story of Moses and Ramesses?", " Does caste system affect economic growth of India?"]},
+    {"query": "Why do some computer programmers develop amazing software or new concepts, while some are stuck with basic programming work?", "pos": [" Why do some computer programmers develops amazing softwares or new concepts, while some are stuck with basics programming works?"], "neg": [" When visiting a friend, do you ever think about what would happen if you did something wildly inappropriate like punch them or destroy their furniture?", " What is the difference between a compliment and flirting?"]}
+]
+
+model = mix_models_with_data(
+    model_names_or_paths=["BAAI/bge-base-en-v1.5", "Shitao/bge-hotpotqa", "Shitao/bge-quora"],
+    model_type='encoder',
+    example_data=example_data,
+    temperature=5.0,
+    max_input_length=512,
+    neg_number=2)
+```
+
+### 3. Mix models layer by layer for reducing memory cost
+
+The function `mix_models_by_layers` creates temporary directories to store the weights of individual models and then merges
+them layer by layer.
+
+This approach helps reduce memory consumption.
+
+Once the merging process is completed, the temporary directories and files will be automatically removed.
+
+```python
+from LM_Cocktail import mix_models_by_layers
+
+# Mix Large Language Models (LLMs) and save the combined model to the path: ./mixed_llm
+model = mix_models_by_layers(
+    model_names_or_paths=["meta-llama/Llama-2-7b-chat-hf", "Shitao/llama2-ag-news"],
+    model_type='decoder',
+    weights=[0.7, 0.3],
+    output_path='./mixed_llm')
+```
+
+## Performance
+
+For detailed results, please refer to our report: [LM-Cocktail](https://arxiv.org/abs/2311.13534).
+
+- LM-Cocktail for Catastrophic Forgetting
+
+| Model | Target Task | Others(29 tasks) |
+|:---------------------------|:-----------:|:----------------:|
+| Llama | 40.8 | 46.8 |
+| Fine-tuned | 94.4 | 38.6 |
+| LM-Cocktail(2 models) [1] | 94.5 | 47.7 |
+| LM-Cocktail(10 models) [2] | 94.4 | 48.3 |
+
+[1]: merge 2 models: the fine-tuned model and the base model
+
+[2]: merge 10 models based on five examples: the fine-tuned model, the base model, and 8 models fine-tuned on other tasks
+
+| Model | Target Task | Other Tasks(14 tasks) |
+|:-----------------------|:-----------:|:---------------------:|
+| BGE | 71.8 | 49.8 |
+| Fine-tuned | 76.0 | 48.5 |
+| LM-Cocktail(2 models) | 74.8 | 50.0 |
+| LM-Cocktail(10 models) | 74.7 | 50.6 |
+
+- LM-Cocktail for new tasks without fine-tuning
+
+Merge 10 models fine-tuned on other tasks based on five examples for the new tasks:
+
+| Model | MMLU(57 tasks) |
+|:-----------------------|:--------------:|
+| Llama | 45.9 |
+| Llama-5shot | 46.7 |
+| LM-Cocktail(10 models) | 48.0 |
+
+| Model | Retrieval(12 tasks) |
+|:-----------------------|:-------------------:|
+| BGE | 47.3 |
+| LM-Cocktail(10 models) | 48.8 |
+
+
+
+
+
diff --git a/RAGSDK/PocValidation/embedding_finetune/finetune.py b/RAGSDK/PocValidation/embedding_finetune/finetune.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e29625bfa4e0389694d85c2e0953c988b6a8225
--- /dev/null
+++ b/RAGSDK/PocValidation/embedding_finetune/finetune.py
@@ -0,0 +1,232 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
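+#
+# 脚本流程说明:加载并切分原始语料文档 -> 借助大模型与reranker由TrainDataGenerator生成问答训练数据
+# -> 在NPU上使用SentenceTransformerTrainer微调embedding模型 -> 用InformationRetrievalEvaluator计算recall@5,
+# 当召回率提升超过increase_rate或recall@5达到0.95时提前结束迭代。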
+ +import argparse +import os +import torch +import torch_npu +from datasets import load_dataset +from langchain_community.document_loaders import TextLoader +from langchain_text_splitters import RecursiveCharacterTextSplitter +from loguru import logger +from mx_rag.document import LoaderMng +from mx_rag.document.loader import DocxLoader +from mx_rag.llm import Text2TextLLM +from mx_rag.reranker.local import LocalReranker +from mx_rag.tools.finetune.generator import TrainDataGenerator, DataProcessConfig +from mx_rag.utils import ClientParam +from mx_rag.utils.file_check import FileCheck +from paddle.base import libpaddle +from sentence_transformers import SentenceTransformer +from sentence_transformers import SentenceTransformerTrainer +from sentence_transformers import SentenceTransformerTrainingArguments +from sentence_transformers.evaluation import InformationRetrievalEvaluator +from sentence_transformers.losses import MultipleNegativesRankingLoss +from sentence_transformers.training_args import BatchSamplers + +DEFAULT_LLM_TIMEOUT = 10 * 60 + + +class Finetune: + def __init__(self, + document_path: str, + generate_dataset_path: str, + llm: Text2TextLLM, + embed_model_path: str, + reranker: LocalReranker, + finetune_output_path: str, + featured_percentage: float, + llm_threshold_score: float, + train_question_number: int, + query_rewrite_number: int, + eval_data_path: str, + log_path: str, + max_iter: int, + increase_rate: float): + self.document_path = document_path + self.generate_dataset_path = generate_dataset_path + self.llm = llm + self.embed_model_path = embed_model_path + self.reranker = reranker + self.finetune_output_path = finetune_output_path + + self.featured_percentage = featured_percentage + self.llm_threshold_score = llm_threshold_score + self.train_question_number = train_question_number + self.query_rewrite_number = query_rewrite_number + + self.eval_data_path = eval_data_path + + self.log_path = log_path + self.max_iter = max_iter + self.increase_rate = increase_rate + + def start(self): + # 配置日志文件 + logger.add(self.log_path, rotation="1 MB", retention="10 days", level="INFO", + format="{time} {level} {message}") + train_data_generator = TrainDataGenerator(self.llm, self.generate_dataset_path, self.reranker) + logger.info("--------------------Processing origin document--------------------") + + loader_mng = LoaderMng() + loader_mng.register_loader(loader_class=TextLoader, file_types=[".txt", ".md"]) + loader_mng.register_loader(loader_class=DocxLoader, file_types=[".docx"]) + # 加载文档切分器,使用langchain的 + loader_mng.register_splitter(splitter_class=RecursiveCharacterTextSplitter, + file_types=[".docx", ".txt", ".md"], + splitter_params={"chunk_size": 750, + "chunk_overlap": 150, + "keep_separator": False + } + ) + + split_doc_list = train_data_generator.generate_origin_document(self.document_path, loader_mng=loader_mng) + logger.info("--------------------Calculate origin embedding model recall--------------------") + origin_recall_top5 = self.evaluate("origin_model", self.embed_model_path) + logger.info(f"origin_recall@5: {origin_recall_top5}") + config = DataProcessConfig(question_number=self.train_question_number, + featured_percentage=self.featured_percentage, + llm_threshold_score=self.llm_threshold_score, + query_rewrite_number=self.query_rewrite_number) + iter_num = 1 + while iter_num <= self.max_iter: + logger.info(f'the {iter_num} iteration beginning') + per_data_len = round(len(split_doc_list) // self.max_iter) + end_index = len(split_doc_list) if iter_num == 
self.max_iter else iter_num * per_data_len
+            train_doc_list = split_doc_list[:end_index]
+            logger.info("--------------------Generating training data--------------------")
+            train_data_generator.generate_train_data(train_doc_list, config)
+
+            logger.info("--------------------Fine-tuning embedding--------------------")
+            train_data_path = os.path.join(self.generate_dataset_path, "train_data.jsonl")
+            output_embed_model_path = os.path.join(self.finetune_output_path, 'embedding', str(iter_num))
+            if not os.path.exists(output_embed_model_path):
+                os.makedirs(output_embed_model_path)
+            FileCheck.dir_check(output_embed_model_path)
+            self.train_embedding(train_data_path, output_embed_model_path)
+            logger.info("--------------------Calculate finetuned embedding model recall--------------------")
+            finetune_recall_top5 = self.evaluate("finetune_model", output_embed_model_path)
+            logger.info(f"finetune_recall@5: {finetune_recall_top5}")
+            recall_increase = (finetune_recall_top5 - origin_recall_top5) / origin_recall_top5 * 100
+            logger.info(f'The recall rate of the {iter_num} iteration increases by {recall_increase}%.')
+            iter_num += 1
+            if recall_increase > self.increase_rate or finetune_recall_top5 >= 0.95:
+                break
+            if iter_num < self.max_iter:
+                self.delete_dataset_file()
+
+    def train_embedding(self, train_data_path, output_path):
+        torch.npu.set_device(torch.device("npu:0"))
+        model = SentenceTransformer(self.embed_model_path, device="npu" if torch.npu.is_available() else "cpu")
+        train_loss = MultipleNegativesRankingLoss(model)
+        train_dataset = load_dataset("json", data_files=train_data_path, split="train")
+        args = SentenceTransformerTrainingArguments(
+            output_dir=output_path,  # output directory and hugging face model ID
+            num_train_epochs=4,  # number of epochs
+            per_device_train_batch_size=4,  # train batch size
+            gradient_accumulation_steps=16,  # effective batch size of 64 per device (4 x 16)
+            warmup_ratio=0.1,  # warmup ratio
+            learning_rate=2e-5,  # learning rate, 2e-5 is a good value
+            lr_scheduler_type="cosine",  # use cosine learning rate scheduler
+            optim="adamw_torch_fused",  # use fused adamw optimizer
+            batch_sampler=BatchSamplers.NO_DUPLICATES,
+            # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
+            logging_steps=10,  # log every 10 steps
+        )
+        trainer = SentenceTransformerTrainer(
+            model=model,
+            args=args,
+            train_dataset=train_dataset.select_columns(["query", "corpus"]),
+            loss=train_loss,
+        )
+        trainer.train()
+        trainer.save_model()
+        torch.npu.empty_cache()
+
+    def evaluate(self, model_name, model_path):
+        torch.npu.set_device(torch.device("npu:0"))
+        model = SentenceTransformer(model_path, device="npu" if torch.npu.is_available() else "cpu")
+        eval_data = load_dataset("json", data_files=self.eval_data_path, split="train")
+        eval_data = eval_data.add_column("id", range(len(eval_data)))
+        corpus = dict(
+            zip(eval_data["id"], eval_data["corpus"])
+        )
+        queries = dict(
+            zip(eval_data["id"], eval_data["query"])
+        )
+        relevant_docs = {}
+        for q_id in queries:
+            relevant_docs[q_id] = [q_id]
+        evaluator = InformationRetrievalEvaluator(queries=queries,
+                                                  corpus=corpus,
+                                                  relevant_docs=relevant_docs,
+                                                  name=model_name)
+        result = evaluator(model)
+        return result[model_name + "_cosine_recall@5"]
+
+    def delete_dataset_file(self):
+        # 删除dataset下所有文件
+        for filename in os.listdir(self.generate_dataset_path):
+            file_path = os.path.join(self.generate_dataset_path, filename)
+            # 检查是否是文件
+            if os.path.isfile(file_path):
+                try:
+                    os.remove(file_path)
+                    logger.info(f"delete file success: {file_path}")
+                except Exception as e:
+                    logger.error(f"delete file occur error: {file_path} - {e}")
+
+
+class CustomFormatter(argparse.ArgumentDefaultsHelpFormatter):
+    def _get_default_metavar_for_optional(self, action):
+        return action.type.__name__
+
+    def _get_default_metavar_for_positional(self, action):
+        return action.type.__name__
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(formatter_class=CustomFormatter)
+    parser.add_argument("--document_path", type=str, default="", help="语料文档路径,支持docx、txt、md格式")
+    parser.add_argument("--generate_dataset_path", type=str, default="", help="生成数据保存路径")
+    parser.add_argument("--llm_url", type=str, default="", help="大模型推理服务地址")
+    parser.add_argument("--llm_model_name", type=str, default="", help="大模型推理服务对应的模型名称")
+    parser.add_argument("--use_http", type=bool, default=False, help="是否是http")
+    parser.add_argument("--embedding_model_path", type=str, default="", help="embedding模型路径")
+    parser.add_argument("--reranker_model_path", type=str, default="", help="reranker模型路径")
+    parser.add_argument("--finetune_output_path", type=str, default="", help="微调模型的输出路径")
+
+    parser.add_argument("--featured_percentage", type=float, default=0.8, help="数据精选比例")
+    parser.add_argument("--llm_threshold_score", type=float, default=0.8, help="大模型打分阈值")
+    parser.add_argument("--train_question_number", type=int, default=2, help="单个文档切片生成的问题数")
+    parser.add_argument("--query_rewrite_number", type=int, default=1, help="问题重写次数")
+
+    parser.add_argument("--eval_data_path", type=str, default="", help="评估数据路径")
+
+    parser.add_argument("--log_path", type=str, default='./app.log', help="日志路径")
+    parser.add_argument("--max_iter", type=int, default=5, help="最大迭代次数")
+    parser.add_argument("--increase_rate", type=float, default=20, help="召回率提升比例")
+
+    args = parser.parse_args()
+
+    logger.info("Fine-tuning beginning")
+    client_param = ClientParam(timeout=DEFAULT_LLM_TIMEOUT, use_http=args.use_http)
+    text_llm = Text2TextLLM(base_url=args.llm_url, model_name=args.llm_model_name, client_param=client_param)
+    local_reranker = LocalReranker(args.reranker_model_path, dev_id=1)
+
+    finetune = Finetune(args.document_path,
+                        args.generate_dataset_path,
+                        text_llm,
+                        args.embedding_model_path,
+                        local_reranker,
+                        args.finetune_output_path,
+                        args.featured_percentage,
+                        args.llm_threshold_score,
+                        args.train_question_number,
+                        args.query_rewrite_number,
+                        args.eval_data_path,
+                        args.log_path,
+                        args.max_iter,
+                        args.increase_rate)
+    finetune.start()
+    logger.info("Fine-tuning ending")
diff --git a/RAGSDK/PocValidation/rag_recursive_tree_demo.py b/RAGSDK/PocValidation/rag_recursive_tree_demo.py
new file mode 100644
index 0000000000000000000000000000000000000000..88f28d1a56a7cf45c98a0cd4c50e4e037c236c0c
--- /dev/null
+++ b/RAGSDK/PocValidation/rag_recursive_tree_demo.py
@@ -0,0 +1,102 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
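+#
+# 脚本流程说明:注册文档加载与切分器 -> 加载embedding模型并初始化chunk数据库 ->
+# 通过upload_files_build_tree构建递归树知识库 -> 配置TreeRetriever检索器 ->
+# 使用TreeText2TextChain完成知识问答,并将树结构保存为tree.json。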
+ +import argparse +from loguru import logger +from langchain.text_splitter import RecursiveCharacterTextSplitter +from mx_rag.document import LoaderMng +from mx_rag.document.loader import DocxLoader, PdfLoader +from mx_rag.embedding.local import TextEmbedding +from mx_rag.knowledge.handler import upload_files_build_tree +from mx_rag.knowledge.knowledge import KnowledgeStore, KnowledgeTreeDB +from mx_rag.llm import Text2TextLLM +from mx_rag.recursive_tree import TreeBuilderConfig, TreeRetrieverConfig, TreeRetriever, TreeText2TextChain +from mx_rag.recursive_tree.tree_structures import save_tree +from mx_rag.storage.document_store import SQLiteDocstore +from mx_rag.utils import ClientParam +from paddle.base import libpaddle +from transformers import AutoTokenizer + + +class CustomFormatter(argparse.ArgumentDefaultsHelpFormatter): + def _get_default_metavar_for_optional(self, action): + return action.type.__name__ + + def _get_default_metavar_for_positional(self, action): + return action.type.__name__ + + +def rag_recursive_tree_demo(): + try: + parse = argparse.ArgumentParser(formatter_class=CustomFormatter) + parse.add_argument("--embedding_path", type=str, default="/home/data/acge_text_embedding", + help="embedding模型本地路径") + parse.add_argument("--white_path", type=str, nargs='+', default=["/home"], help="文件白名单路径") + parse.add_argument("--file_path", type=str, + default="/home/HwHiAiUser/MindIE 1.0.RC3 安装指南 01.pdf", + help="要上传的文件路径,需在白名单路径下") + parse.add_argument("--llm_url", type=str, default="http://127.0.0.1:1025/v1/chat/completions", + help="大模型url地址") + parse.add_argument("--query", type=str, default="请介绍MindIE容器化部署和制造镜像的步骤。", help="用户问题") + parse.add_argument("--model_name", type=str, default="Llama3-8B-Chinese-Chat", help="大模型名称") + parse.add_argument("--tokenizer_path", type=str, default="/home/data/Llama3-8B-Chinese-Chat/", + help="大模型tokenizer参数路径") + args = parse.parse_args().__dict__ + + embedding_path: str = args.pop('embedding_path') + white_path: list[str] = args.pop('white_path') + file_path: str = args.pop('file_path') + llm_url: str = args.pop('llm_url') + query: str = args.pop('query') + model_name: str = args.pop('model_name') + tokenizer_path: str = args.pop('tokenizer_path') + + # Step1离线构建知识库,首先注册文档处理器 + loader_mng = LoaderMng() + # 加载文档加载器,可以使用mxrag自有的,也可以使用langchain的 + loader_mng.register_loader(loader_class=PdfLoader, file_types=[".pdf"]) + loader_mng.register_loader(loader_class=DocxLoader, file_types=[".docx"]) + # 加载文档切分器,使用langchain的 + loader_mng.register_splitter(splitter_class=RecursiveCharacterTextSplitter, + file_types=[".pdf", ".docx"]) + # 设置向量检索使用的npu卡,具体可以用的卡可执行npu-smi info查询获取 + dev = 0 + # 加载embedding模型,请根据模型具体路径适配 + text_emb = TextEmbedding(model_path=embedding_path, dev_id=dev) + # 初始化文档chunk关系数据库 + document_store = SQLiteDocstore(db_path="./sql.db") + # 初始化TreeText2TextChain实例,具体ip、端口、llm请根据实际情况修改。 + # 在构建树过程中总结摘要时会使用,最后问答也会使用,问答调用前设置tree_retriever。 + tree_chain = TreeText2TextChain( + llm=Text2TextLLM(base_url=llm_url, model_name=model_name, + client_param=ClientParam(use_http=True, timeout=600))) + # 使用模型的tokenizer + tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, local_files_only=True) + # 初始化递归树构建的参数 + tree_builder_config = TreeBuilderConfig(tokenizer=tokenizer, summarization_model=tree_chain) + # 初始化递归树知识管理 + knowledge = KnowledgeTreeDB(KnowledgeStore("./sql.db"), chunk_store=document_store, knowledge_name="test", + white_paths=white_path, tree_builder_config=tree_builder_config) + # 上传领域知识文档,方法会返回构建树实例,当前仅支持同时上传一个文件 + tree = 
upload_files_build_tree(knowledge, file_path, loader_mng=loader_mng,
+                                       embed_func=text_emb.embed_documents, force=True)
+        # 初始化递归树检索器配置参数
+        tree_retriever_config = TreeRetrieverConfig(tokenizer=tokenizer, embed_func=text_emb.embed_documents,
+                                                    collapse_tree=False, top_k=3)
+        # 初始化递归树检索器
+        tree_retriever = TreeRetriever(tree_retriever_config, tree)
+        # 设置TreeText2TextChain的检索器
+        tree_chain.set_tree_retriever(tree_retriever)
+        # 知识问答
+        answer = tree_chain.query(query, max_tokens=1000)
+        # 打印结果
+        logger.info(answer)
+        # 递归树Tree实例序列化保存为json文件,使用load_tree方法反序列化
+        save_path = "./tree.json"
+        save_tree(tree, save_path)
+    except Exception as e:
+        logger.error(f"run demo failed: {e}")
+
+
+if __name__ == '__main__':
+    rag_recursive_tree_demo()
diff --git a/RAGSDK/README.md b/RAGSDK/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..75db8ccbb053b63fd8975f01c98ab75838866b4a
--- /dev/null
+++ b/RAGSDK/README.md
@@ -0,0 +1,7 @@
+## 目录结构与说明
+
+| 目录 | 说明 |
+|---------------|-----------|
+| MainRepo | 商用版本参考样例 |
+| PocValidation | POC版本参考样例 |
+
diff --git a/README.md b/README.md
index aea12ad2072e05ce78654ca0beef7d8ad58539e4..5890cea99ec6303d3aa48afd40af85ef90bca723 100644
--- a/README.md
+++ b/README.md
@@ -1,9 +1,9 @@
 EN|[CN](README.zh.md)
 
 # MindSDK Reference Apps
 
-[MindSDK](https://www.hiascend.com/software/mindx-sdk) is a software development kit (SDK) launched by Huawei, offering simple and easy-to-use, high-performance APIs and tools. It includes multiple SDKs such as Vision SDK (visual analysis), Index SDK (feature retrieval), and Rec SDK (search recommendation), aiding the Ascend AI processor in empowering various application scenarios.
+[MindSDK](https://www.hiascend.com/software/mindx-sdk) is a software development kit (SDK) launched by Huawei, offering simple and easy-to-use, high-performance APIs and tools. It includes multiple SDKs such as Vision SDK (visual analysis), Index SDK (feature retrieval), and RAG SDK (retrieval augmentation), aiding the Ascend AI processor in empowering various application scenarios.
 
-To help developers quickly master the use of Vision SDK and Index SDK interfaces and rapidly implement business functions, this code repository provides various reference samples developed based on Vision SDK and Index SDK. The Agent SDK provides interfaces to build agent application. Users can select the appropriate sample code according to their needs.
+To help developers quickly master the use of Vision SDK, Index SDK, and RAG SDK interfaces and rapidly implement business functions, this code repository provides various reference samples developed based on Vision SDK, Index SDK, and RAG SDK. Users can select the appropriate sample code according to their needs.
 
 ## Main Directory Structure and Description
 
 | 目录 | 说明 |
 |---|-----------------------------------------------|
@@ -12,7 +12,7 @@ To help developers quickly master the use of Vision SDK and Index SDK interfaces
 | [VisionSDK](./VisionSDK) | Vision SDK official sample directory |
 | [tutorials](./tutorials) | Vision SDK official tutorials directory |
 | [IndexSDK](./IndexSDK) | Index SDK offical sample directory |
-| [AgentSDK](./AgentSDK) | Agent SDK and offical sample directory |
+| [RAGSDK](./RAGSDK) | RAG SDK and official sample directory |
 
 ## Related Websites
diff --git a/README.zh.md b/README.zh.md
index 8474d4e1bb9cca57979f7e0f31f0652a67367bc4..c5f39ed139463dc07a56690becf0e51f73479f93 100644
--- a/README.zh.md
+++ b/README.zh.md
@@ -1,9 +1,9 @@
 中文|[英文](README.md)
 
 # MindSDK Reference Apps
 
-[MindSDK](https://www.hiascend.com/software/mindx-sdk) 是华为推出的软件开发套件(SDK),提供极简易用、高性能的API和工具,包含Vision SDK(视觉分析)、Index SDK(特征检索)、Rec SDK(搜索推荐)等多个SDK,助力昇腾AI处理器赋能各应用场景。
+[MindSDK](https://www.hiascend.com/software/mindx-sdk) 是华为推出的软件开发套件(SDK),提供极简易用、高性能的API和工具,包含Vision SDK(视觉分析)、Index SDK(特征检索)、RAG SDK(检索增强)等多个SDK,助力昇腾AI处理器赋能各应用场景。
 
-为助力开发者快速掌握Vision SDK和Index SDK接口的使用、快速实现业务功能,本代码仓提供了基于Vision SDK和Index SDK开发的各类参考样例。用户可根据自身需求选择相应案例代码。
+为助力开发者快速掌握Vision SDK、Index SDK和RAG SDK接口的使用、快速实现业务功能,本代码仓提供了基于Vision SDK、Index SDK和RAG SDK开发的各类参考样例。用户可根据自身需求选择相应案例代码。
 
 ## 主要目录结构与说明
 
@@ -14,7 +14,7 @@
 | [VisionSDK](./VisionSDK) | Vision SDK官方应用样例目录 |
 | [tutorials](./tutorials) | Vision SDK官方开发样例和文档参考工程目录 |
 | [IndexSDK](./IndexSDK) | Index SDK参考样例目录 |
-
+| [RAGSDK](./RAGSDK) | RAG SDK参考样例目录 |
 
 ## 相关网站
 
 昇腾社区鼓励开发者多交流,共学习。开发者可以通过昇腾社区网站获取最新的MindSDK的软件、文档等资源;可以通过昇腾论坛与其他开发者交流开发经验。