diff --git a/mxRAG/MainRepo/patches/BertSelfAttentionFast/README.md b/mxRAG/MainRepo/patches/BertSelfAttentionFast/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..7350b90b949419e466f203104022e46ffd6a25f3
--- /dev/null
+++ b/mxRAG/MainRepo/patches/BertSelfAttentionFast/README.md
@@ -0,0 +1,82 @@
+# BertSelfAttention fused operator usage guide
+
+## Registering the BertSelfAttention fused operator
+
+### Step 1
+For the root user with the default install path, set the environment variables:
+```sh
+source /usr/local/Ascend/ascend-toolkit/set_env.sh
+```
+For a non-root user with the default install path, set the environment variables:
+```sh
+source /home/{current user}/Ascend/ascend-toolkit/set_env.sh
+```
+
+### Step 2
+Register the operator by running the run package in the ops directory.
+The Arm architecture uses custom_opp_aarch64.run.
+The x86_64 architecture uses custom_opp_x86_64.run.
+```sh
+cd ops
+./build.sh
+cd BertSelfAttention/build_out/
+./custom_opp_{arch}.run
+```
+
+### Step 3
+In the ops directory, clone the ascend/op-plugin project and run the one-click injection script.
+The injection script copies the files under op_plugin_patch into the op-plugin project, builds a torch_npu package that contains the fused operator, and replaces the installed torch_npu package with it.
+```sh
+cd ops
+git clone https://gitee.com/ascend/op-plugin.git
+bash run_op_plugin.sh
+```
+
+#### Table 1 custom_opp_{arch}.run options
+| Option | Description |
+|--------------|-----------|
+| --help \| -h | Show help information. |
+| --info | Print information about the package. |
+| --list | Print the list of files in the package. |
+| --check | Check the integrity of the archive. |
+| --quiet | Suppress non-error messages during extraction. |
+| --install-path | Install to the specified path. |
+
+### Step 4
+In the mindxsdk-mxrag/patches/BertSelfAttentionFast directory, run the one-click patch script to patch transformers' modeling_bert.py so that it calls the fused operator:
+```sh
+bash bertSAFast_patch.sh
+```
+
+## Installing the transformers patch
+
+### Installation steps
+```bash
+bash bertSAFast_patch.sh
+```
+
+### Notes
+1. Before applying the patch, set the CANN environment variables first:
+```sh
+ source {CANN install path}/set_env.sh  # default install path: /usr/local/Ascend/ascend-toolkit
+```
+2. Before using the BERT model with the patch applied, make sure the custom operator has been registered by running the run package:
+
+```sh
+ ./custom_opp_{arch}.run
+```
+3. If an "aclnnBertSelfAttention not found" error occurs, check whether LD_LIBRARY_PATH is set as follows:
+```sh
+export LD_LIBRARY_PATH=$ASCEND_OPP_PATH/vendors/mxRAG/op_api/lib/:$LD_LIBRARY_PATH
+```
+
+4. 
如果出现protobuf 需要降低版本问题 则需要将pip3 install protobuf==3.19.0版本,安装完成之后可以还原 + +## 版本依赖 +| 软件 | 版本要求 | +|--------------|-----------| +| pytorch | == 2.1.0 | +| python | >= 3.8.0 | +| transformers | >= 4.34.1 | + + diff --git a/mxRAG/MainRepo/patches/BertSelfAttentionFast/bertSAFast_patch.sh b/mxRAG/MainRepo/patches/BertSelfAttentionFast/bertSAFast_patch.sh new file mode 100644 index 0000000000000000000000000000000000000000..f8c628753347a6c056a49ab8da2479ad7f2bf777 --- /dev/null +++ b/mxRAG/MainRepo/patches/BertSelfAttentionFast/bertSAFast_patch.sh @@ -0,0 +1,8 @@ +#!/bin/bash +set -e + +PATCH_DIR=$(cd $(dirname $0); pwd) + +# transformers打patch +TRANSFORMER_PACKAGE_PATH=$(python3 -c 'import transformers; import os; print(os.path.dirname(transformers.__file__))') +patch -p1 $TRANSFORMER_PACKAGE_PATH/models/bert/modeling_bert.py < $PATCH_DIR/transformers_bert.patch \ No newline at end of file diff --git a/mxRAG/MainRepo/patches/BertSelfAttentionFast/ops/BertSelfAttention/CMakePresets.patch b/mxRAG/MainRepo/patches/BertSelfAttentionFast/ops/BertSelfAttention/CMakePresets.patch new file mode 100644 index 0000000000000000000000000000000000000000..d9a73977b0fc063a8f901c0cd9a4509b1090eb5f --- /dev/null +++ b/mxRAG/MainRepo/patches/BertSelfAttentionFast/ops/BertSelfAttention/CMakePresets.patch @@ -0,0 +1,17 @@ +22c22 +< "value": "True" +--- +> "value": "Fasle" +34c34 +< "value": "True" +--- +> "value": "False" +42c42 +< "value": "/usr/local/Ascend/latest" +--- +> "value": "/usr/local/Ascend/ascend-toolkit/latest" +63c63 +< } +\ No newline at end of file +--- +> } diff --git a/mxRAG/MainRepo/patches/BertSelfAttentionFast/ops/BertSelfAttention/cmake.patch b/mxRAG/MainRepo/patches/BertSelfAttentionFast/ops/BertSelfAttention/cmake.patch new file mode 100644 index 0000000000000000000000000000000000000000..ed5e38a2d0f5ea1f69a70c097779306ac3768ae8 --- /dev/null +++ b/mxRAG/MainRepo/patches/BertSelfAttentionFast/ops/BertSelfAttention/cmake.patch @@ -0,0 +1,154 @@ +--- makeself-header_origin.sh 2024-07-25 23:04:33.000000000 -0400 ++++ makeself-header.sh 2024-07-30 04:46:00.000000000 -0400 +@@ -3,15 +3,13 @@ + # This script was generated using Makeself $MS_VERSION + # The license covering this archive and its contents, if any, is wholly independent of the Makeself license (GPL) + # 2022.3.19-Modified the MS_Help function and some options +-# Huawei Technologies Co., Ltd. ++# Huawei Technologies Co., Ltd. + + ORIG_UMASK=\`umask\` + +-CRCsum="$CRCsum" +-MD5="$MD5sum" + SHA="$SHAsum" + SIGNATURE="$Signature" +-TMPROOT=\${TMPDIR:="\$HOME"} ++TMPROOT="\$PWD" + if ! test -d "\$TMPROOT"; then + TMPROOT="\$PWD" + fi +@@ -172,11 +170,6 @@ + --list Print the list of files in the archive + --check Checks integrity and version dependency of the archive + --quiet Quiet install mode, skip human-computer interactions +- --nox11 Do not spawn an xterm +- --noexec Do not run embedded script +- --extract= Extract directly to a target directory (absolute or relative) +- Usually used with --noexec to just extract files without running +- --tar arg1 [arg2 ...] 
Access the contents of the archive through the tar command + \${helpheader} + EOH + } +@@ -208,16 +201,12 @@ + + MS_Check() + { +- OLD_PATH="\$PATH" +- PATH=\${GUESS_MD5_PATH:-"\$OLD_PATH:/bin:/usr/bin:/sbin:/usr/local/ssl/bin:/usr/local/bin:/opt/openssl/bin"} +- MD5_ARG="" +- MD5_PATH=\`exec <&- 2>&-; which md5sum || command -v md5sum || type md5sum\` +- test -x "\$MD5_PATH" || MD5_PATH=\`exec <&- 2>&-; which md5 || command -v md5 || type md5\` +- test -x "\$MD5_PATH" || MD5_PATH=\`exec <&- 2>&-; which digest || command -v digest || type digest\` +- PATH="\$OLD_PATH" +- + SHA_PATH=\`exec <&- 2>&-; which shasum || command -v shasum || type shasum\` + test -x "\$SHA_PATH" || SHA_PATH=\`exec <&- 2>&-; which sha256sum || command -v sha256sum || type sha256sum\` ++ if ! test -x "\$SHA_PATH"; then ++ MS_Printf "Error in checksums, There is no sha sum tool in system, please install first." >&2 ++ exit 2 ++ fi + + if test x"\$quiet" = xn; then + MS_Printf "Verifying archive integrity..." +@@ -232,7 +221,6 @@ + i=1 + for s in \$filesizes + do +- crc=\`echo \$CRCsum | cut -d" " -f\$i\` + if test -x "\$SHA_PATH"; then + if test x"\`basename \$SHA_PATH\`" = xshasum; then + SHA_ARG="-a 256" +@@ -251,35 +239,6 @@ + crc="0000000000"; + fi + fi +- if test -x "\$MD5_PATH"; then +- if test x"\`basename \$MD5_PATH\`" = xdigest; then +- MD5_ARG="-a md5" +- fi +- md5=\`echo \$MD5 | cut -d" " -f\$i\` +- if test x"\$md5" = x00000000000000000000000000000000; then +- test x"\$verb" = xy && echo " \$1 does not contain an embedded MD5 checksum." >&2 +- else +- md5sum=\`MS_dd_Progress "\$1" \$offset \$s | eval "\$MD5_PATH \$MD5_ARG" | cut -b-32\`; +- if test x"\$md5sum" != x"\$md5"; then +- echo "Error in MD5 checksums: \$md5sum is different from \$md5" >&2 +- exit 2 +- elif test x"\$quiet" = xn; then +- MS_Printf " MD5 checksums are OK." >&2 +- fi +- crc="0000000000"; verb=n +- fi +- fi +- if test x"\$crc" = x0000000000; then +- test x"\$verb" = xy && echo " \$1 does not contain a CRC checksum." >&2 +- else +- sum1=\`MS_dd_Progress "\$1" \$offset \$s | CMD_ENV=xpg4 cksum | awk '{print \$1}'\` +- if test x"\$sum1" != x"\$crc"; then +- echo "Error in checksums: \$sum1 is different from \$crc" >&2 +- exit 2 +- elif test x"\$quiet" = xn; then +- MS_Printf " CRC checksums are OK." >&2 +- fi +- fi + i=\`expr \$i + 1\` + offset=\`expr \$offset + \$s\` + done +@@ -406,56 +365,11 @@ + done + exit 0 + ;; +- --tar) +- offset=\`head -n "\$skip" "\$0" | wc -c | tr -d " "\` +- arg1="\$2" +- shift 2 || { MS_Help; exit 1; } +- for s in \$filesizes +- do +- MS_dd "\$0" \$offset \$s | MS_Decompress | tar "\$arg1" - "\$@" +- offset=\`expr \$offset + \$s\` +- done +- exit 0 +- ;; + --check) + MS_Check "\$0" y + scriptargs="\$scriptargs \$1" + shift + ;; +- --noexec) +- script="" +- cleanup_script="" +- shift +- ;; +- --extract=*) +- keep=y +- targetdir=\`echo \$1 | cut -d"=" -f2 \` +- if ! shift; then MS_Help; exit 1; fi +- ;; +- --nox11) +- nox11=y +- shift +- ;; +- --xwin) +- if test "$NOWAIT" = n; then +- finish="echo Press Return to close this window...; read junk" +- fi +- xterm_loop=1 +- shift +- ;; +- --phase2) +- copy=phase2 +- shift +- ;; +- --repack | --repack-path=*) +- Script_Args_Check \$1 +- scriptargs="\$scriptargs '\$1'" +- shift +- if [[ ! 
"\$1" =~ ^-.* ]]; then +- scriptargs="\$scriptargs '\$1'" +- shift +- fi +- ;; + *) + Script_Args_Check \$1 + scriptargs="\$scriptargs '\$1'" \ No newline at end of file diff --git a/mxRAG/MainRepo/patches/BertSelfAttentionFast/ops/BertSelfAttention/op_host/bert_self_attention.cpp b/mxRAG/MainRepo/patches/BertSelfAttentionFast/ops/BertSelfAttention/op_host/bert_self_attention.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1524e9a790949c83ff3713e22dc15cb6293ab73a --- /dev/null +++ b/mxRAG/MainRepo/patches/BertSelfAttentionFast/ops/BertSelfAttention/op_host/bert_self_attention.cpp @@ -0,0 +1,450 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + */ + + + + +#include +#include "bert_self_attention_tiling.h" +#include "register/op_def_registry.h" +#include "tiling/tiling_api.h" +#include "tiling/softmax/softmax_tiling.h" + +template +inline auto DivUp(U a, V b) -> decltype(a + b) +{ + return ((a + b - 1) / b); +} + +using half = uint16_t; + +namespace { +constexpr uint32_t UB_BLOCK_BYTE_SIZE = 32; +constexpr uint32_t BASESEQLENGTH = 16; +constexpr uint32_t USERSPACE_QKVPROJECTED_NUM = 6; +constexpr uint32_t TRAVSERSE_TYPE_THRESHOLD = 2; +constexpr uint32_t TENSOR_DIM_ONE = 0; +constexpr uint32_t TENSOR_DIM_TWO = 1; +constexpr uint32_t TENSOR_DIM_THR = 2; + +template +uint32_t BaseSeqHeadCompute(uint32_t &baseSeqLength, const uint32_t sequenceLength, uint64_t ub_size) +{ + std::vector softMaxShapeVec = { sequenceLength, sequenceLength }; + ge::Shape srcShape(softMaxShapeVec); + uint32_t softMaxWorkSpaceSize = AscendC::GetSoftMaxMinTmpSize(srcShape, sizeof(T), false); + uint32_t realUbsize = ub_size - softMaxWorkSpaceSize; + uint32_t typeSize = sizeof(T); + baseSeqLength = BASESEQLENGTH; + uint32_t tmpBaseSeqLength; + uint32_t usedSpace = baseSeqLength * sequenceLength * typeSize + sequenceLength * typeSize + + baseSeqLength * sequenceLength / 8; + do { + tmpBaseSeqLength = baseSeqLength + BASESEQLENGTH; + if (tmpBaseSeqLength <= sequenceLength) { + uint32_t tmpUsedSpace = tmpBaseSeqLength * sequenceLength * typeSize + sequenceLength * typeSize + + tmpBaseSeqLength * sequenceLength / 8; + if (tmpUsedSpace <= realUbsize) { + usedSpace = tmpUsedSpace; + baseSeqLength = tmpBaseSeqLength; + } + } + } while (tmpBaseSeqLength == baseSeqLength); + return usedSpace; +} +} // namespace + +namespace optiling { +template +static ge::graphStatus TilingBasic(gert::TilingContext *context, BertSelfAttentionTilingData &tilingData) +{ + auto ascendcPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo()); + const gert::StorageShape *input_shape = context->GetInputShape(0); + auto attrs = context->GetAttrs(); + if (attrs == nullptr) { + printf("attrs null\n"); + return ge::GRAPH_FAILED; + } + auto batchSize = input_shape->GetStorageShape().GetDim(0); + auto sequenceLength = input_shape->GetStorageShape().GetDim(1); + auto featuresSize = input_shape->GetStorageShape().GetDim(2); + auto headNum = *(attrs->GetAttrPointer(0)); + auto headSize = *(attrs->GetAttrPointer(1)); + auto dropOutKeepProb = *(attrs->GetAttrPointer(2)); + + float sqrtHeadSize = sqrt(headSize); + + //feature byte size must 32 byte mutiple + uint32_t featuresByteSize = featuresSize * sizeof(T); + if ((featuresByteSize % UB_BLOCK_BYTE_SIZE) != 0) { + printf("featuresByteSize is not 32 byte mutiple\n"); + return ge::GRAPH_FAILED; + } + // part2 begin + uint32_t baseSeqLength; + uint64_t ub_size; + ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ub_size); + 
uint32_t usedSpace = BaseSeqHeadCompute(baseSeqLength, sequenceLength, ub_size); + uint32_t softMaxWorkSpaceSize = ub_size - usedSpace; + uint32_t tailSeqLength = sequenceLength % baseSeqLength; + if (tailSeqLength == 0) { + tailSeqLength = baseSeqLength; + } + + std::vector softMaxShapeVec = { baseSeqLength, sequenceLength }; + ge::Shape srcShape(softMaxShapeVec); + AscendC::SoftMaxTilingFunc(srcShape, sizeof(T), softMaxWorkSpaceSize, tilingData.softMaxTilingData); + + tilingData.set_baseSeqLength(baseSeqLength); + tilingData.set_tailSeqLength(tailSeqLength); + tilingData.set_softMaxWorkSpaceSize(softMaxWorkSpaceSize); + //part2 end todo + + tilingData.set_batchSize(batchSize); + tilingData.set_sequenceLength(sequenceLength); + tilingData.set_featuresSize(featuresSize); + tilingData.set_headNum(headNum); + tilingData.set_headSize(headSize); + + tilingData.set_sqrtHeadSize(sqrtHeadSize); + tilingData.set_dropOutKeepProb(dropOutKeepProb); + + return ge::GRAPH_SUCCESS; +} + +template +static ge::graphStatus TilingCore(gert::TilingContext *context, BertSelfAttentionTilingData &tilingData) +{ + auto ascendcPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo()); + auto aicNum = ascendcPlatform.GetCoreNumAic(); + auto aivNum = ascendcPlatform.GetCoreNumAiv(); + auto B = tilingData.get_batchSize(); + auto N = tilingData.get_headNum(); + auto S = tilingData.get_sequenceLength(); + auto h = tilingData.get_headSize(); + + tilingData.set_aicNum(aicNum); + tilingData.set_aivNum(aivNum); + auto usedAivCoreNum = (B * N >= aivNum) ? aivNum : B * N; + tilingData.set_usedAivCoreNum(usedAivCoreNum); + auto fomerBatchNum = B * N % aivNum; + auto perBlockBatchFormer = DivUp(B * N, aivNum); + auto perBlockBatchLatter = B * N / aivNum; + tilingData.set_fomerBatchNum(fomerBatchNum); + tilingData.set_perBlockBatchFormer(perBlockBatchFormer); + tilingData.set_perBlockBatchLatter(perBlockBatchLatter); + + return ge::GRAPH_SUCCESS; +} + +template +static ge::graphStatus TilingCube_Projection(gert::TilingContext *context, BertSelfAttentionTilingData &tilingData) +{ + using namespace matmul_tiling; + auto ascendcPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo()); + auto aicNum = ascendcPlatform.GetCoreNumAic(); + auto aivNum = ascendcPlatform.GetCoreNumAiv(); + context->SetBlockDim(aicNum); + + MultiCoreMatmulTiling cubeTilingEachCore(ascendcPlatform); + auto M = tilingData.get_sequenceLength() * tilingData.get_batchSize(); + auto N = tilingData.get_featuresSize(); + auto K = tilingData.get_featuresSize(); + cubeTilingEachCore.SetDim(aivNum); + if (std::is_same::value) { + cubeTilingEachCore.SetAType(TPosition::GM, CubeFormat::ND, DataType::DT_FLOAT); + cubeTilingEachCore.SetBType(TPosition::GM, CubeFormat::ND, DataType::DT_FLOAT); + cubeTilingEachCore.SetCType(TPosition::GM, CubeFormat::ND, DataType::DT_FLOAT); + cubeTilingEachCore.SetBiasType(TPosition::GM, CubeFormat::ND, DataType::DT_FLOAT); + } else { + cubeTilingEachCore.SetAType(TPosition::GM, CubeFormat::ND, DataType::DT_FLOAT16); + cubeTilingEachCore.SetBType(TPosition::GM, CubeFormat::ND, DataType::DT_FLOAT16); + cubeTilingEachCore.SetCType(TPosition::GM, CubeFormat::ND, DataType::DT_FLOAT16); + cubeTilingEachCore.SetBiasType(TPosition::GM, CubeFormat::ND, DataType::DT_FLOAT16); + } + cubeTilingEachCore.SetShape(M, N, K); + cubeTilingEachCore.SetOrgShape(M, N, K); + cubeTilingEachCore.SetBufferSpace(-1, -1, -1); + cubeTilingEachCore.SetBias(true); + if (tilingData.get_batchSize() < TRAVSERSE_TYPE_THRESHOLD){ + 
cubeTilingEachCore.SetTraverse(MatrixTraverse::FIRSTN); + } else { + cubeTilingEachCore.SetTraverse(MatrixTraverse::FIRSTM); + } + int retTrans = cubeTilingEachCore.GetTiling(tilingData.projectionTiling); + if (retTrans == -1) { + printf("cube tiling each core error\n"); + return ge::GRAPH_FAILED; + } + return ge::GRAPH_SUCCESS; +} + +template +static ge::graphStatus TilingCube(gert::TilingContext *context, BertSelfAttentionTilingData &tilingData) +{ + using namespace matmul_tiling; + auto ascendcPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo()); + auto M = tilingData.get_sequenceLength() * tilingData.get_batchSize(); + auto N = tilingData.get_featuresSize(); + auto K = tilingData.get_featuresSize(); + + TilingCube_Projection(context, tilingData); + + auto B = tilingData.get_batchSize(); + auto S = tilingData.get_sequenceLength(); + auto h = tilingData.get_headSize(); + MatmulApiTiling cubeTilingScores(ascendcPlatform); //为score部分做batch matmul的tiling + if (std::is_same::value){ + cubeTilingScores.SetAType(TPosition::GM, CubeFormat::ND, DataType::DT_FLOAT); + cubeTilingScores.SetBType(TPosition::GM, CubeFormat::ND, DataType::DT_FLOAT, true); + cubeTilingScores.SetCType(TPosition::GM, CubeFormat::ND, DataType::DT_FLOAT); + } else { + cubeTilingScores.SetAType(TPosition::GM, CubeFormat::ND, DataType::DT_FLOAT16); + cubeTilingScores.SetBType(TPosition::GM, CubeFormat::ND, DataType::DT_FLOAT16, true); + cubeTilingScores.SetCType(TPosition::GM, CubeFormat::ND, DataType::DT_FLOAT16); + } + cubeTilingScores.SetShape(S, S, h); + cubeTilingScores.SetOrgShape(S, S, h); + cubeTilingScores.SetBufferSpace(-1, -1, -1); + cubeTilingScores.SetBias(false); + int ret = cubeTilingScores.GetTiling(tilingData.scoresCubeTiling); + if (ret == -1){ + printf("cube tiling each core error\n"); + return ge::GRAPH_FAILED; + } + + M = tilingData.get_baseSeqLength(); + N = tilingData.get_headSize(); + K = tilingData.get_sequenceLength(); + MatmulApiTiling cubeTilingProbs(ascendcPlatform); + if (std::is_same::value) { + cubeTilingProbs.SetAType(TPosition::GM, CubeFormat::ND, matmul_tiling::DataType::DT_FLOAT); + cubeTilingProbs.SetBType(TPosition::GM, CubeFormat::ND, matmul_tiling::DataType::DT_FLOAT); + cubeTilingProbs.SetCType(TPosition::VECIN, CubeFormat::ND, matmul_tiling::DataType::DT_FLOAT); + } else { + cubeTilingProbs.SetAType(TPosition::GM, CubeFormat::ND, matmul_tiling::DataType::DT_FLOAT16); + cubeTilingProbs.SetBType(TPosition::GM, CubeFormat::ND, matmul_tiling::DataType::DT_FLOAT16); + cubeTilingProbs.SetCType(TPosition::VECIN, CubeFormat::ND, matmul_tiling::DataType::DT_FLOAT16); + } + cubeTilingProbs.SetShape(M, N, K); + cubeTilingProbs.SetOrgShape(M, N, K); + cubeTilingProbs.SetBias(false); + cubeTilingProbs.SetBufferSpace(-1, -1, -1); + if (cubeTilingProbs.GetTiling(tilingData.probsCubeTiling) == -1) { + return ge::GRAPH_FAILED; + } + + return ge::GRAPH_SUCCESS; +} + +static ge::graphStatus TilingFuncFP32(gert::TilingContext *context, BertSelfAttentionTilingData &tiling) +{ + context->SetTilingKey(0); + auto ret = TilingBasic(context, tiling); + if (ret != ge::GRAPH_SUCCESS) { + printf("TilingBasic error\n"); + return ret; + } + + ret = TilingCore(context, tiling); + if (ret != ge::GRAPH_SUCCESS) { + printf("TilingCore error\n"); + return ret; + } + + ret = TilingCube(context, tiling); + if (ret != ge::GRAPH_SUCCESS) { + printf("TilingCube error\n"); + return ret; + } + + return ge::GRAPH_SUCCESS; +} + +static ge::graphStatus TilingFuncFP16(gert::TilingContext *context, 
BertSelfAttentionTilingData &tiling) +{ + context->SetTilingKey(1); + auto ret = TilingBasic(context, tiling); + if (ret != ge::GRAPH_SUCCESS) { + printf("TilingBasic error\n"); + return ret; + } + + ret = TilingCore(context, tiling); + if (ret != ge::GRAPH_SUCCESS) { + printf("TilingCore error\n"); + return ret; + } + + ret = TilingCube(context, tiling); + if (ret != ge::GRAPH_SUCCESS) { + printf("TilingCube error\n"); + return ret; + } + + return ge::GRAPH_SUCCESS; +} + +static ge::graphStatus TilingFunc(gert::TilingContext *context) +{ + if (context == nullptr) { + printf("context can't be nullptr\n"); + return ge::GRAPH_FAILED; + } + auto aTensor = context->GetInputTensor(0); + if (aTensor == nullptr) { + printf("aTensor is nullptr\n"); + return ge::GRAPH_FAILED; + } + ge::DataType dtype = aTensor->GetDataType(); + size_t usrSize = 1; + BertSelfAttentionTilingData tiling; + if (dtype==ge::DT_FLOAT){ + auto ret = TilingFuncFP32(context, tiling); + if (ret != ge::GRAPH_SUCCESS) { + printf("TilingFuncFP32 error\n"); + return ret; + } + usrSize *= sizeof(float); + } else if (dtype == ge::DT_FLOAT16) { + auto ret = TilingFuncFP16(context, tiling); + if (ret != ge::GRAPH_SUCCESS) { + printf("TilingFuncFP16 error\n"); + return ret; + } + usrSize *= sizeof(half); + } else { + printf("dtype error\n"); + return ge::GRAPH_FAILED; + } + + if (context->GetRawTilingData() == nullptr) { + printf("RawTilingData can't be nullptr\n"); + return ge::GRAPH_FAILED; + } + tiling.SaveToBuffer(context->GetRawTilingData()->GetData(), context->GetRawTilingData()->GetCapacity()); + context->GetRawTilingData()->SetDataSize(tiling.GetDataSize()); + + auto B = tiling.get_batchSize(); + auto S = tiling.get_sequenceLength(); + auto H = tiling.get_featuresSize(); + auto N = tiling.get_headNum(); + usrSize *= USERSPACE_QKVPROJECTED_NUM * B * S * H + B * N * S * S; //设置用户需要使用的workspace大小为QKV投影及其transpose结果存储空间。 + // 如需要使用系统workspace需要调用GetLibWorkSpaceSize获取系统workspace的大小。 + auto ascendcPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo()); + uint32_t sysWorkspaceSize = ascendcPlatform.GetLibApiWorkSpaceSize(); + size_t *currentWorkspace = context->GetWorkspaceSizes(1); + if (currentWorkspace == nullptr) { + printf("work space null\n"); + return ge::GRAPH_FAILED; + } + currentWorkspace[0] = sysWorkspaceSize + usrSize; + + return ge::GRAPH_SUCCESS; +} +} //namespace optiling + +namespace ge { +static ge::graphStatus InferShape(gert::InferShapeContext *context) +{ + auto attrs = context->GetAttrs(); + if (attrs == nullptr) { + printf("attrs null\n"); + return ge::GRAPH_FAILED; + } + auto headNum = *(attrs->GetAttrPointer(0)); + auto headSize = *(attrs->GetAttrPointer(1)); + + const gert::Shape *x1_shape = context->GetInputShape(0); + auto B = x1_shape->GetDim(0); + auto S = x1_shape->GetDim(1); + auto H = x1_shape->GetDim(2); + if (headNum * headSize != H) { + printf("headNum * headSize != featureSize\n"); + return ge::GRAPH_FAILED; + } + + gert::Shape *outValue_shape = context->GetOutputShape(0); + outValue_shape->SetDim(TENSOR_DIM_ONE, B); + outValue_shape->SetDim(TENSOR_DIM_TWO, S); + outValue_shape->SetDim(TENSOR_DIM_THR, H); + + return GRAPH_SUCCESS; +} +} // namespace ge + +namespace ops { +class BertSelfAttention : public OpDef { +public: + explicit BertSelfAttention(const char *name) : OpDef(name) + { + this->Input("input") + .ParamType(REQUIRED) + .DataType({ ge::DT_FLOAT, ge::DT_FLOAT16 }) + .Format({ ge::FORMAT_ND, ge::FORMAT_ND }) + .UnknownShapeFormat({ ge::FORMAT_ND, ge::FORMAT_ND }); + 
this->Input("queryW") + .ParamType(REQUIRED) + .DataType({ ge::DT_FLOAT, ge::DT_FLOAT16 }) + .Format({ ge::FORMAT_ND, ge::FORMAT_ND }) + .UnknownShapeFormat({ ge::FORMAT_ND, ge::FORMAT_ND }); + this->Input("queryBias") + .ParamType(REQUIRED) + .DataType({ ge::DT_FLOAT, ge::DT_FLOAT16 }) + .Format({ ge::FORMAT_ND, ge::FORMAT_ND }) + .UnknownShapeFormat({ ge::FORMAT_ND, ge::FORMAT_ND }); + this->Input("keyW") + .ParamType(REQUIRED) + .DataType({ ge::DT_FLOAT, ge::DT_FLOAT16 }) + .Format({ ge::FORMAT_ND, ge::FORMAT_ND }) + .UnknownShapeFormat({ ge::FORMAT_ND, ge::FORMAT_ND }); + this->Input("keyBias") + .ParamType(REQUIRED) + .DataType({ ge::DT_FLOAT, ge::DT_FLOAT16 }) + .Format({ ge::FORMAT_ND, ge::FORMAT_ND }) + .UnknownShapeFormat({ ge::FORMAT_ND, ge::FORMAT_ND }); + this->Input("valueW") + .ParamType(REQUIRED) + .DataType({ ge::DT_FLOAT, ge::DT_FLOAT16 }) + .Format({ ge::FORMAT_ND, ge::FORMAT_ND }) + .UnknownShapeFormat({ ge::FORMAT_ND, ge::FORMAT_ND }); + this->Input("valueBias") + .ParamType(REQUIRED) + .DataType({ ge::DT_FLOAT, ge::DT_FLOAT16 }) + .Format({ ge::FORMAT_ND, ge::FORMAT_ND }) + .UnknownShapeFormat({ ge::FORMAT_ND, ge::FORMAT_ND }); + this->Input("attentionMask") + .ParamType(REQUIRED) + .DataType({ ge::DT_FLOAT, ge::DT_FLOAT16 }) + .Format({ ge::FORMAT_ND, ge::FORMAT_ND }) + .UnknownShapeFormat({ ge::FORMAT_ND, ge::FORMAT_ND }); + this->Input("dropOutMask") + .ParamType(REQUIRED) + .DataType({ ge::DT_UINT8, ge::DT_UINT8 }) + .Format({ ge::FORMAT_ND, ge::FORMAT_ND }) + .UnknownShapeFormat({ ge::FORMAT_ND, ge::FORMAT_ND }); + this->Input("headMask") + .ParamType(REQUIRED) + .DataType({ ge::DT_FLOAT, ge::DT_FLOAT16 }) + .Format({ ge::FORMAT_ND, ge::FORMAT_ND }) + .UnknownShapeFormat({ ge::FORMAT_ND, ge::FORMAT_ND }); + this->Output("output") + .ParamType(REQUIRED) + .DataType({ ge::DT_FLOAT, ge::DT_FLOAT16 }) + .Format({ ge::FORMAT_ND, ge::FORMAT_ND }) + .UnknownShapeFormat({ ge::FORMAT_ND, ge::FORMAT_ND }); + this->Attr("numAttentionHeads").Int(); + this->Attr("attentionHeadSize").Int(); + this->Attr("dropOutKeepProb").Float(); + + this->SetInferShape(ge::InferShape); + + this->AICore().SetTiling(optiling::TilingFunc); + this->AICore().AddConfig("ascend910b"); + } +}; + +OP_ADD(BertSelfAttention); +} // namespace ops + diff --git a/mxRAG/MainRepo/patches/BertSelfAttentionFast/ops/BertSelfAttention/op_host/bert_self_attention_tiling.h b/mxRAG/MainRepo/patches/BertSelfAttentionFast/ops/BertSelfAttention/op_host/bert_self_attention_tiling.h new file mode 100644 index 0000000000000000000000000000000000000000..fdcb2baf4368e697c5331c0ed0b36c8fcf0511e2 --- /dev/null +++ b/mxRAG/MainRepo/patches/BertSelfAttentionFast/ops/BertSelfAttention/op_host/bert_self_attention_tiling.h @@ -0,0 +1,42 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. 
+ */ + + + +#ifndef BERT_SELF_ATTENTION_TILING_H +#define BERT_SELF_ATTENTION_TILING_H + +#include "register/tilingdata_base.h" +#include "tiling/tiling_api.h" + +namespace optiling { + BEGIN_TILING_DATA_DEF(BertSelfAttentionTilingData) + TILING_DATA_FIELD_DEF(uint32_t, usedAivCoreNum); + TILING_DATA_FIELD_DEF(uint32_t, aicNum); + TILING_DATA_FIELD_DEF(uint32_t, aivNum); + TILING_DATA_FIELD_DEF(uint32_t, batchSize); + TILING_DATA_FIELD_DEF(uint32_t, sequenceLength); + TILING_DATA_FIELD_DEF(uint32_t, featuresSize); + TILING_DATA_FIELD_DEF(uint32_t, headNum); + TILING_DATA_FIELD_DEF(uint32_t, headSize); + TILING_DATA_FIELD_DEF(uint32_t, fomerBatchNum); + TILING_DATA_FIELD_DEF(uint32_t, perBlockBatchFormer); + TILING_DATA_FIELD_DEF(uint32_t, perBlockBatchLatter); + + TILING_DATA_FIELD_DEF(uint32_t, baseSeqLength); + TILING_DATA_FIELD_DEF(uint32_t, tailSeqLength); + TILING_DATA_FIELD_DEF(uint32_t, softMaxWorkSpaceSize); + TILING_DATA_FIELD_DEF(float, sqrtHeadSize); + TILING_DATA_FIELD_DEF(float, dropOutKeepProb); + + TILING_DATA_FIELD_DEF_STRUCT(TCubeTiling, projectionTiling); + TILING_DATA_FIELD_DEF_STRUCT(TCubeTiling, scoresCubeTiling); + TILING_DATA_FIELD_DEF_STRUCT(TCubeTiling, probsCubeTiling); + TILING_DATA_FIELD_DEF_STRUCT(SoftMaxTiling, softMaxTilingData); + END_TILING_DATA_DEF; + + REGISTER_TILING_DATA_CLASS(BertSelfAttention, BertSelfAttentionTilingData) +} + +#endif \ No newline at end of file diff --git a/mxRAG/MainRepo/patches/BertSelfAttentionFast/ops/BertSelfAttention/op_kernel/bert_self_attention.cpp b/mxRAG/MainRepo/patches/BertSelfAttentionFast/ops/BertSelfAttention/op_kernel/bert_self_attention.cpp new file mode 100644 index 0000000000000000000000000000000000000000..376c42f6ff3ffc00773b109b36b50687647bf12a --- /dev/null +++ b/mxRAG/MainRepo/patches/BertSelfAttentionFast/ops/BertSelfAttention/op_kernel/bert_self_attention.cpp @@ -0,0 +1,527 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. 
+ */ + + + + +#include "kernel_tiling/kernel_tiling.h" +#include "kernel_operator.h" +#include "lib/matmul_intf.h" + +using namespace AscendC; +using namespace matmul; + +template +class BertSelfAttention { +public: + __aicore__ inline BertSelfAttention(){}; + + __aicore__ inline void Init(GM_ADDR input, GM_ADDR queryW, GM_ADDR queryBias, GM_ADDR keyW, GM_ADDR keyBias, + GM_ADDR valueW, GM_ADDR valueBias, GM_ADDR attentionMask, GM_ADDR dropOutMask, + GM_ADDR headMask, GM_ADDR output, GM_ADDR usrWorkspace, + const BertSelfAttentionTilingData *__restrict tiling); + + __aicore__ inline void Process(); + +private: + __aicore__ inline void InitBasic(const BertSelfAttentionTilingData *__restrict tiling); + + __aicore__ inline void InitBuffers(); + + __aicore__ inline void ProcessProjections(); + + __aicore__ inline void ProcessAttScore(); + + __aicore__ inline void ProcessVectors(); + + __aicore__ inline void CopyIn(uint64_t offset); + + __aicore__ inline void MulScalarCompute(); + + __aicore__ inline void AddAttentionCompute(uint64_t batchId); + + __aicore__ inline void SoftMaxCompute(); + + __aicore__ inline void DropOutCompute(uint64_t offset); + + __aicore__ inline void MulHeadCompute(uint64_t headId); + + __aicore__ inline void CopyInA(uint64_t offset); + + __aicore__ inline void CopyResOut(uint64_t resOffset); + + TPipe pipe; + + TCubeTiling projectionTiling; + TCubeTiling attentionScoresTiling; + TCubeTiling attentionProbsTiling; + SoftMaxTiling softMaxTiling; + + Matmul, MatmulType, + MatmulType, MatmulType > + projectionMatmulObj; + + Matmul, + MatmulType, + MatmulType > + attentionScoresMatmulObj; + + Matmul, MatmulType, + MatmulType > + attentionProbsMatmulObj; + + TQue inQueueVec; + + TQue attentionMaskQueue; + TQue dropOutMaskQueue; + TQue softMaxWorkSpaceQueue; + TQue tmpQueue; + + GlobalTensor inputGlobal; + GlobalTensor queryWGlobal; + GlobalTensor queryBiasGlobal; + GlobalTensor queryProjectedGlobal; + GlobalTensor queryTransedProjectedGlobal; + GlobalTensor keyWGlobal; + GlobalTensor keyBiasGlobal; + GlobalTensor keyProjectedGlobal; + GlobalTensor keyTransedProjectedGlobal; + GlobalTensor valueWGlobal; + GlobalTensor valueBiasGlobal; + GlobalTensor valueProjectedGlobal; + GlobalTensor valueTransedProjectedGlobal; + GlobalTensor attentionScoresGlobal; + + GlobalTensor attentionMaskGlobal; + GlobalTensor dropOutMaskGlobal; + GlobalTensor headMaskGlobal; + GlobalTensor outputGlobal; + + LocalTensor transUb; + LocalTensor invecLocal; + LocalTensor inputValueLocal; + LocalTensor outputLocal; + + LocalTensor attentionMaskLocal; + LocalTensor dropOutMaskLocal; + LocalTensor softMaxWorkSpaceLocal; + + // split data by tiling info from MultiCoreMatmulTiling + uint64_t mmAOffset; + uint64_t mmBOffset; + uint64_t mmCOffset; + uint64_t mmBiasOffset; + uint64_t mmASize; + uint64_t mmBSize; + + uint32_t batchSize; + uint32_t headNum; + uint32_t aicNum; + uint32_t aivNum; + uint32_t usedAivCoreNum; + uint32_t sequenceLength; + uint32_t featuresSize; + uint32_t headSize; + uint32_t fomerBatchNum; + uint32_t perBlockBatchFormer; + uint32_t perBlockBatchLatter; + uint32_t projectedVectorsMatSize; + uint32_t attentionScoresMatSize; + uint32_t projectedVectorsOffset; + uint32_t attentionScoresOffset; + uint32_t batchIdOffset; + bool isNoTask{false}; + + uint32_t softMaxWorkSpaceSize; + + uint32_t baseSeqLength; + uint32_t tailSeqLength; + uint32_t computeSeqLength; + + DataType recSqrtHeadSize; + float mulProbValue; + + uint32_t formerNum; + uint32_t probePerBlockFormer; + uint32_t 
probePerBlockLatter; + uint32_t probePerBlock; + uint32_t beginProbeId; + + uint32_t blkIdx; +}; + +template +__aicore__ inline void BertSelfAttention::InitBasic(const BertSelfAttentionTilingData *__restrict tiling) +{ + this->blkIdx = GetBlockIdx(); + this->projectionTiling = tiling->projectionTiling; + this->attentionScoresTiling = tiling->scoresCubeTiling; + this->attentionProbsTiling = tiling->probsCubeTiling; + this->projectedVectorsMatSize = attentionScoresTiling.M * attentionScoresTiling.Ka; //左右矩阵等大小 + this->attentionScoresMatSize = attentionScoresTiling.M * attentionScoresTiling.N; + this->batchSize = tiling->batchSize; + this->headNum = tiling->headNum; + this->aicNum = tiling->aicNum; + this->aivNum = tiling->aivNum; + this->usedAivCoreNum = tiling->usedAivCoreNum; + this->sequenceLength = tiling->sequenceLength; + this->featuresSize = tiling->featuresSize; + this->headSize = tiling->headSize; + this->fomerBatchNum = tiling->fomerBatchNum; + this->perBlockBatchFormer = tiling->perBlockBatchFormer; + this->perBlockBatchLatter = tiling->perBlockBatchLatter; + this->isNoTask = (blkIdx >= usedAivCoreNum); + if (blkIdx <= fomerBatchNum) { + this->projectedVectorsOffset = blkIdx * perBlockBatchFormer * projectedVectorsMatSize; + this->attentionScoresOffset = blkIdx * perBlockBatchFormer * attentionScoresMatSize; + this->batchIdOffset = blkIdx * perBlockBatchFormer; + } else { + this->projectedVectorsOffset = fomerBatchNum * perBlockBatchFormer * projectedVectorsMatSize + + (blkIdx - fomerBatchNum) * perBlockBatchLatter * projectedVectorsMatSize; + this->attentionScoresOffset = fomerBatchNum * perBlockBatchFormer * attentionScoresMatSize + + (blkIdx - fomerBatchNum) * perBlockBatchLatter * attentionScoresMatSize; + this->batchIdOffset = fomerBatchNum * perBlockBatchFormer + (blkIdx - fomerBatchNum) * perBlockBatchLatter; + } + this->softMaxTiling = tiling->softMaxTilingData; + this->softMaxWorkSpaceSize = tiling->softMaxWorkSpaceSize; + this->baseSeqLength = tiling->baseSeqLength; + this->tailSeqLength = tiling->tailSeqLength; + this->recSqrtHeadSize = static_cast(1.0f / tiling->sqrtHeadSize); + this->mulProbValue = 1.0f - tiling->dropOutKeepProb; + //单个query需要由当前核处理的probe数量 + uint32_t partNum = batchSize * headNum * ((sequenceLength + baseSeqLength -1) / baseSeqLength); + this->formerNum = partNum % aivNum; + this->probePerBlockFormer = (partNum + aivNum - 1) / aivNum; + this->probePerBlockLatter = partNum / aivNum; + this->probePerBlock = (blkIdx < formerNum ? probePerBlockFormer : probePerBlockLatter); + this->beginProbeId = (blkIdx < formerNum + ? 
(probePerBlockFormer * blkIdx) + : (probePerBlockFormer * formerNum + (blkIdx - formerNum) * probePerBlockLatter)); +} + +template +__aicore__ inline void BertSelfAttention::InitBuffers() +{ + if (baseSeqLength < headSize) { + pipe.InitBuffer(inQueueVec, 1, headSize * sequenceLength * sizeof(DataType)); + } else { + pipe.InitBuffer(inQueueVec, 1, baseSeqLength * sequenceLength * sizeof(DataType)); + } + pipe.InitBuffer(attentionMaskQueue, 1, sequenceLength * sizeof(DataType)); + pipe.InitBuffer(dropOutMaskQueue, 1, + baseSeqLength * sequenceLength * sizeof(uint8_t) / 8); // sequenceLength>=16, 且为16的倍数 + pipe.InitBuffer(softMaxWorkSpaceQueue, 1, softMaxWorkSpaceSize); + pipe.InitBuffer(tmpQueue, 1, 1024); +} + +template +__aicore__ inline void BertSelfAttention::Init(GM_ADDR input, GM_ADDR queryW, GM_ADDR queryBias, GM_ADDR keyW, + GM_ADDR keyBias, GM_ADDR valueW, GM_ADDR valueBias, + GM_ADDR attentionMask, GM_ADDR dropOutMask, GM_ADDR headMask, + GM_ADDR output, GM_ADDR usrWorkspace, + const BertSelfAttentionTilingData *__restrict tiling) +{ + InitBasic(tiling); + inputGlobal.SetGlobalBuffer((__gm__ DataType *)input); + queryWGlobal.SetGlobalBuffer((__gm__ DataType *)queryW); + queryBiasGlobal.SetGlobalBuffer((__gm__ DataType *)queryBias); + keyWGlobal.SetGlobalBuffer((__gm__ DataType *)keyW); + keyBiasGlobal.SetGlobalBuffer((__gm__ DataType *)keyBias); + valueWGlobal.SetGlobalBuffer((__gm__ DataType *)valueW); + valueBiasGlobal.SetGlobalBuffer((__gm__ DataType *)valueBias); + outputGlobal.SetGlobalBuffer((__gm__ DataType *)output); + queryProjectedGlobal.SetGlobalBuffer((__gm__ DataType *)usrWorkspace); + keyProjectedGlobal.SetGlobalBuffer((__gm__ DataType *)usrWorkspace + batchSize * sequenceLength * featuresSize); + valueProjectedGlobal.SetGlobalBuffer((__gm__ DataType *)usrWorkspace + + 2 * batchSize * sequenceLength * featuresSize); + queryTransedProjectedGlobal.SetGlobalBuffer((__gm__ DataType *)usrWorkspace + + 3 * batchSize * sequenceLength * featuresSize); + keyTransedProjectedGlobal.SetGlobalBuffer((__gm__ DataType *)usrWorkspace + + 4 * batchSize * sequenceLength * featuresSize); + valueTransedProjectedGlobal.SetGlobalBuffer((__gm__ DataType *)usrWorkspace + + 5 * batchSize * sequenceLength * featuresSize); + attentionScoresGlobal.SetGlobalBuffer((__gm__ DataType *)usrWorkspace + + 6 * batchSize * sequenceLength * featuresSize); + attentionMaskGlobal.SetGlobalBuffer((__gm__ DataType *)attentionMask); + dropOutMaskGlobal.SetGlobalBuffer((__gm__ uint8_t *)dropOutMask); + headMaskGlobal.SetGlobalBuffer((__gm__ DataType *)headMask); + auto nNum = Ceil(projectionTiling.N, projectionTiling.singleCoreN); + auto mNum = Ceil(projectionTiling.M, projectionTiling.singleCoreM); + auto coreNId = blkIdx % nNum; + auto coreMId = blkIdx / nNum; + mmAOffset = coreMId * projectionTiling.singleCoreM * projectionTiling.Ka; + mmBOffset = coreNId * projectionTiling.singleCoreN; + mmCOffset = coreMId * projectionTiling.singleCoreM * projectionTiling.N + coreNId * projectionTiling.singleCoreN; + mmBiasOffset = coreNId * projectionTiling.singleCoreN; + mmASize = (coreMId != (mNum - 1)) ? projectionTiling.singleCoreM + : (projectionTiling.M - coreMId * (projectionTiling.singleCoreM)); + mmBSize = (coreNId != (nNum - 1)) ? 
projectionTiling.singleCoreN + : (projectionTiling.N - coreNId * (projectionTiling.singleCoreN)); + InitBuffers(); +} + +template +__aicore__ inline void BertSelfAttention::ProcessProjections() +{ + projectionMatmulObj.SetTensorA(inputGlobal[mmAOffset]); + projectionMatmulObj.SetTensorB(queryWGlobal[mmBOffset]); + projectionMatmulObj.SetBias(queryBiasGlobal[mmBiasOffset]); + projectionMatmulObj.SetTail(mmASize, mmBSize, projectionTiling.Ka); + projectionMatmulObj.IterateAll(queryProjectedGlobal[mmCOffset]); + projectionMatmulObj.End(); + SyncAll(); + + projectionMatmulObj.SetTensorA(inputGlobal[mmAOffset]); + projectionMatmulObj.SetTensorB(keyWGlobal[mmBOffset]); + projectionMatmulObj.SetBias(keyBiasGlobal[mmBiasOffset]); + projectionMatmulObj.SetTail(mmASize, mmBSize, projectionTiling.Ka); + projectionMatmulObj.IterateAll(keyProjectedGlobal[mmCOffset]); + projectionMatmulObj.End(); + SyncAll(); + + projectionMatmulObj.SetTensorA(inputGlobal[mmAOffset]); + projectionMatmulObj.SetTensorB(valueWGlobal[mmBOffset]); + projectionMatmulObj.SetBias(valueBiasGlobal[mmBiasOffset]); + projectionMatmulObj.SetTail(mmASize, mmBSize, projectionTiling.Ka); + projectionMatmulObj.IterateAll(valueProjectedGlobal[mmCOffset]); + projectionMatmulObj.End(); + SyncAll(); +} + +template +__aicore__ inline void BertSelfAttention::ProcessAttScore() +{ + if (!isNoTask) { + int curHeadNum = (blkIdx < fomerBatchNum) ? perBlockBatchFormer : perBlockBatchLatter; + transUb = inQueueVec.AllocTensor(); + for (int transLoopIds = 0; transLoopIds < curHeadNum; ++transLoopIds) { + pipe_barrier(PIPE_ALL); //useful + auto transBatchId = (batchIdOffset + transLoopIds) / headNum; + auto transHeadId = (batchIdOffset + transLoopIds) % headNum; + auto transOffsetCopyIn = transBatchId * (sequenceLength * featuresSize) + transHeadId * headSize; + auto transOffsetCopyOut = transBatchId * (sequenceLength * featuresSize) + + transHeadId * sequenceLength * headSize; + + pipe_barrier(PIPE_ALL); // useful + DataCopy(transUb, queryProjectedGlobal[transOffsetCopyIn], + { static_cast(sequenceLength), static_cast(headSize * sizeof(DataType) / 32), + static_cast((headNum - 1) * headSize * sizeof(DataType) / 32), 0}); + pipe_barrier(PIPE_ALL); // useful + DataCopy(queryTransedProjectedGlobal[transOffsetCopyOut], transUb, sequenceLength * headSize); + + pipe_barrier(PIPE_ALL); // useful + DataCopy(transUb, keyProjectedGlobal[transOffsetCopyIn], + { static_cast(sequenceLength), static_cast(headSize * sizeof(DataType) / 32), + static_cast((headNum - 1) * headSize * sizeof(DataType) / 32), 0}); + pipe_barrier(PIPE_ALL); // useful + DataCopy(keyTransedProjectedGlobal[transOffsetCopyOut], transUb, sequenceLength * headSize); + + pipe_barrier(PIPE_ALL); // useful + DataCopy(transUb, valueProjectedGlobal[transOffsetCopyIn], + { static_cast(sequenceLength), static_cast(headSize * sizeof(DataType) / 32), + static_cast((headNum - 1) * headSize * sizeof(DataType) / 32), 0}); + pipe_barrier(PIPE_ALL); // useful + DataCopy(valueTransedProjectedGlobal[transOffsetCopyOut], transUb, sequenceLength * headSize); + } + + inQueueVec.FreeTensor(transUb); + pipe_barrier(PIPE_ALL); //useful + for (int headCnt = 0; headCnt < curHeadNum; ++headCnt) { + attentionScoresMatmulObj.SetTensorA( + queryTransedProjectedGlobal[projectedVectorsOffset + headCnt * projectedVectorsMatSize]); + attentionScoresMatmulObj.SetTensorB( + keyTransedProjectedGlobal[projectedVectorsOffset + headCnt * projectedVectorsMatSize], true); + attentionScoresMatmulObj.IterateAll( + 
attentionScoresGlobal[attentionScoresOffset + headCnt * attentionScoresMatSize]); + attentionScoresMatmulObj.End(); + } + } + SyncAll(); +} + +template +__aicore__ inline void BertSelfAttention::ProcessVectors() +{ + for (uint64_t probeId = beginProbeId; probeId computeSeqLength = (sequencePartId == sequencePart - 1) ? tailSeqLength : baseSeqLength; + uint64_t offset = batchId * headNum * sequenceLength * sequenceLength + + headId * sequenceLength * sequenceLength + sequencePartId * baseSeqLength * sequenceLength; + pipe_barrier(PIPE_ALL); + CopyIn(offset); + + pipe_barrier(PIPE_ALL); + MulScalarCompute(); + + pipe_barrier(PIPE_ALL); + AddAttentionCompute(batchId); + + pipe_barrier(PIPE_ALL); + SoftMaxCompute(); + pipe_barrier(PIPE_ALL); + DropOutCompute(offset); + + pipe_barrier(PIPE_ALL); + MulHeadCompute(headId); + + pipe_barrier(PIPE_ALL); + CopyInA(offset); + + pipe_barrier(PIPE_ALL); + attentionProbsMatmulObj.SetTensorA(attentionScoresGlobal[offset]); + + uint64_t valueOffset = batchId * headNum * sequenceLength * headSize + headId * sequenceLength * headSize; + uint64_t resOffset = batchId * headNum * sequenceLength * headSize + headId * headSize + + sequencePartId * baseSeqLength * headNum * headSize; + outputLocal = inQueueVec.AllocTensor(); + attentionProbsMatmulObj.SetTensorB(valueTransedProjectedGlobal[valueOffset]); + attentionProbsMatmulObj.SetTail(computeSeqLength, headSize, sequenceLength); + attentionProbsMatmulObj.template IterateAll(outputLocal, 0); + attentionProbsMatmulObj.End(); + pipe_barrier(PIPE_ALL); //useful + inQueueVec.EnQue(outputLocal); + CopyResOut(resOffset); + } +} + +template +__aicore__ inline void BertSelfAttention::Process() +{ + REGIST_MATMUL_OBJ(&pipe, GetSysWorkSpacePtr(), projectionMatmulObj, &projectionTiling, attentionScoresMatmulObj, + &attentionScoresTiling, attentionProbsMatmulObj, &attentionProbsTiling); + + ProcessProjections(); + + ProcessAttScore(); + + ProcessVectors(); +} + +template +__aicore__ inline void BertSelfAttention::CopyIn(uint64_t offset) +{ + invecLocal = inQueueVec.AllocTensor(); + DataCopy(invecLocal, attentionScoresGlobal[offset], computeSeqLength * sequenceLength); + inQueueVec.EnQue(invecLocal); +} + +template +__aicore__ inline void BertSelfAttention::MulScalarCompute() +{ + invecLocal = inQueueVec.DeQue(); + Muls(invecLocal, invecLocal, recSqrtHeadSize, computeSeqLength * sequenceLength); + inQueueVec.EnQue(invecLocal); +} + +template +__aicore__ inline void BertSelfAttention::AddAttentionCompute(uint64_t batchId) +{ + invecLocal = inQueueVec.DeQue(); + attentionMaskLocal = attentionMaskQueue.AllocTensor(); + DataCopy(attentionMaskLocal, attentionMaskGlobal[batchId * sequenceLength], sequenceLength); + pipe_barrier(PIPE_ALL); //useful + for (int i = 0; i < computeSeqLength; i++) { + uint64_t offset = i * sequenceLength; + Add(invecLocal[offset], invecLocal[offset], attentionMaskLocal, sequenceLength); + } + attentionMaskQueue.FreeTensor(attentionMaskLocal); + inQueueVec.EnQue(invecLocal); +} + +template +__aicore__ inline void BertSelfAttention::SoftMaxCompute() +{ + invecLocal = inQueueVec.DeQue(); + softMaxWorkSpaceLocal = softMaxWorkSpaceQueue.AllocTensor(); + SoftMax(invecLocal, invecLocal, softMaxWorkSpaceLocal, softMaxTiling, + { computeSeqLength, sequenceLength, computeSeqLength, sequenceLength }); + softMaxWorkSpaceQueue.FreeTensor(softMaxWorkSpaceLocal); + inQueueVec.EnQue(invecLocal); +} + +template +__aicore__ inline void BertSelfAttention::DropOutCompute(uint64_t offset) +{ + const DataType zero = 0.0f; + 
invecLocal = inQueueVec.DeQue(); + if (mulProbValue ==0) { + Duplicate(invecLocal, zero, computeSeqLength * sequenceLength); + } else if (mulProbValue != 1) { + const DataType tmpMulProbValue = 1 / mulProbValue; + dropOutMaskLocal = dropOutMaskQueue.AllocTensor(); + + DataCopy(dropOutMaskLocal, dropOutMaskGlobal[offset / 8], computeSeqLength * sequenceLength / 8); //需要大于328 + pipe_barrier(PIPE_ALL); + + Select(invecLocal, dropOutMaskLocal, invecLocal, zero, SELMODE::VSEL_TENSOR_SCALAR_MODE, + computeSeqLength * sequenceLength); + Muls(invecLocal, invecLocal, tmpMulProbValue, computeSeqLength * sequenceLength); + + dropOutMaskQueue.FreeTensor(dropOutMaskLocal); + } else { + // Do Noting + } + inQueueVec.EnQue(invecLocal); +} + +template +__aicore__ inline void BertSelfAttention::MulHeadCompute(uint64_t headId) +{ + invecLocal = inQueueVec.DeQue(); + const DataType headMaskValue = headMaskGlobal.GetValue(headId); + Muls(invecLocal, invecLocal, headMaskValue, computeSeqLength * sequenceLength); + inQueueVec.EnQue(invecLocal); +} + +template +__aicore__ inline void BertSelfAttention::CopyInA(uint64_t offset) +{ + invecLocal = inQueueVec.DeQue(); + DataCopy(attentionScoresGlobal[offset], invecLocal, computeSeqLength * sequenceLength); + inQueueVec.FreeTensor(invecLocal); +} + +template +__aicore__ inline void BertSelfAttention::CopyResOut(uint64_t resOffset) +{ + outputLocal = inQueueVec.DeQue(); + DataCopy(outputGlobal[resOffset], outputLocal, + { static_cast(computeSeqLength), static_cast(headSize * sizeof(DataType) / 32), 0, + static_cast((headNum - 1) * headSize * sizeof(DataType) / 32)}); + inQueueVec.FreeTensor(outputLocal); +} + +extern "C" __global__ __aicore__ void bert_self_attention(GM_ADDR input, GM_ADDR queryW, GM_ADDR queryBias, + GM_ADDR keyW, GM_ADDR keyBias, GM_ADDR valueW, + GM_ADDR valueBias, GM_ADDR attentionMask, GM_ADDR dropOutMask, + GM_ADDR headMask, GM_ADDR output, GM_ADDR workspace, + GM_ADDR tiling) +{ + if (workspace == nullptr) { + return; + } + SetSysWorkspace(workspace); + GM_ADDR usrWorkspace = GetUserWorkspace(workspace); + GET_TILING_DATA(tiling_data, tiling); + const BertSelfAttentionTilingData *__restrict tiling_device = &tiling_data; + + if (TILING_KEY_IS(0)) { + BertSelfAttention op; + op.Init(input ,queryW, queryBias, keyW, keyBias, valueW, valueBias, attentionMask, dropOutMask, headMask, + output, usrWorkspace, tiling_device); + op.Process(); + } else if (TILING_KEY_IS(1)) { + BertSelfAttention op; + op.Init(input ,queryW, queryBias, keyW, keyBias, valueW, valueBias, attentionMask, dropOutMask, headMask, + output, usrWorkspace, tiling_device); + op.Process(); + } else { + return; + } +} + diff --git a/mxRAG/MainRepo/patches/BertSelfAttentionFast/ops/bert_self_attention.json b/mxRAG/MainRepo/patches/BertSelfAttentionFast/ops/bert_self_attention.json new file mode 100644 index 0000000000000000000000000000000000000000..2fa35ed455210d8fb14c7faaabc7f8d98b32a3eb --- /dev/null +++ b/mxRAG/MainRepo/patches/BertSelfAttentionFast/ops/bert_self_attention.json @@ -0,0 +1,137 @@ +[ + { + "op": "BertSelfAttention", + "language": "cpp", + "input_desc": [ + { + "name": "input", + "param_type": "required", + "format": [ + "ND","ND" + ], + "type": [ + "float","float16" + ] + }, + { + "name": "queryW", + "param_type": "required", + "format":[ + "ND","ND" + ], + "type": [ + "float","float16" + ] + }, + { + "name": "queryBias", + "param_type": "required", + "format":[ + "ND","ND" + ], + "type": [ + "float","float16" + ] + }, + { + "name": "keyW", + "param_type": "required", + 
"format":[ + "ND","ND" + ], + "type": [ + "uint8","uint8" + ] + }, + { + "name": "keyBias", + "param_type": "required", + "format":[ + "ND","ND" + ], + "type": [ + "float","float16" + ] + }, + { + "name": "valueW", + "param_type": "required", + "format":[ + "ND","ND" + ], + "type": [ + "float","float16" + ] + }, + { + "name": "valueBias", + "param_type": "required", + "format":[ + "ND","ND" + ], + "type": [ + "float","float16" + ] + }, + { + "name": "attentionMask", + "param_type": "required", + "format":[ + "ND","ND" + ], + "type": [ + "float","float16" + ] + }, + { + "name": "dropOutMask", + "param_type": "required", + "format":[ + "ND","ND" + ], + "type": [ + "uint8","uint8" + ] + }, + { + "name": "headMask", + "param_type": "required", + "format":[ + "ND","ND" + ], + "type": [ + "float","float16" + ] + } + ], + "output_desc": [ + { + "name": "output", + "param_type": "required", + "format": [ + "ND","ND" + ], + "type": [ + "float","float16" + ] + } + ], + "attr": [ + { + "name": "numAttentionHeads", + "param_type": "required", + "type": "int" + }, + { + "name": "attentionHeadSize", + "param_type": "required", + "type": "int" + }, + { + "name": "dropOutKeepProb", + "param_type": "required", + "type": "float" + } + ] + } +] \ No newline at end of file diff --git a/mxRAG/MainRepo/patches/BertSelfAttentionFast/ops/build.sh b/mxRAG/MainRepo/patches/BertSelfAttentionFast/ops/build.sh new file mode 100644 index 0000000000000000000000000000000000000000..30407889cb2ef34e3db4206f25c7c67e1b2f24b4 --- /dev/null +++ b/mxRAG/MainRepo/patches/BertSelfAttentionFast/ops/build.sh @@ -0,0 +1,20 @@ +source /usr/local/Ascend/ascend-toolkit/set_env.sh +$ASCEND_HOME_PATH/python/site-packages/bin/msopgen gen -i bert_self_attention.json -f pytorch -c ai_core-Ascend910B4 -lan cpp -out BertSelfAttention_ +cp -r ./BertSelfAttention/* ./BertSelfAttention_ +rm -rf ./BertSelfAttention/* +mv ./BertSelfAttention_/* ./BertSelfAttention +rm -rf BertSelfAttention_ +cd ./BertSelfAttention +if [ ! -f "CMakePresets.json" ]; then + echo "Error: CMakePresets.json File does not exist." + exit 1 +fi +patch -p1 CMakePresets.json < CMakePresets.patch +patch -p1 cmake/util/makeself/makeself-header.sh < cmake.patch +sed -i 's|"customize"|"mxRAG"|' "CMakePresets.json" +sed -i 's|TMPROOT=\\${TMPDIR:="\\$HOME"}|TMPROOT="\\$PWD"|' cmake/util/makeself/makeself-header.sh +chmod -R +x ./build.sh +find . -type f -name "*.sh" -print0 | xargs -0 dos2unix +bash build.sh +./build_out/custom_opp_*.run +cd .. \ No newline at end of file diff --git a/mxRAG/MainRepo/patches/BertSelfAttentionFast/ops/op_plugin_patch/BertSelfAttentionKernelNpuApi.cpp b/mxRAG/MainRepo/patches/BertSelfAttentionFast/ops/op_plugin_patch/BertSelfAttentionKernelNpuApi.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b87d8287ef9c181d2679735845ddc3d0ff4bfbf5 --- /dev/null +++ b/mxRAG/MainRepo/patches/BertSelfAttentionFast/ops/op_plugin_patch/BertSelfAttentionKernelNpuApi.cpp @@ -0,0 +1,174 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. 
+ */ + +#include +#include +#include "torch_npu/csrc/aten/CustomFunctions.h" +#include "torch_npu/csrc/framework/utils/RandomOpAdapter.h" +#include "op_plugin/OpApiInterface.h" +#include "op_plugin/AclOpsInterface.h" +#include "op_plugin/utils/op_api_common.h" +#include "op_plugin/utils/OpAdapter.h" + +namespace op_api { +constexpr size_t BASESEQLENGTH = 16; +constexpr size_t ATTENTIONMASKCATAIX = 3; + +using npu_preparation = at_npu::native::OpPreparation; +using namespace at_npu::native; + +enum class DropOutStatus { + DROPOUT_NORMAL = 0, + DROPOUT_NONE, + DROPOUT_ALL +}; + +struct BertSelfAttentionParams { + at::Tensor input; + at::Tensor input_padded; + at::Tensor queryW; + at::Tensor queryBias; + at::Tensor keyW; + at::Tensor keyBias; + at::Tensor valueW; + at::Tensor valueBias; + at::Tensor attentionMask; + at::Tensor attentionMask_padded; + at::Tensor headMask; + int64_t numAttentionHeads; + int64_t attentionHeadSize; + uint64_t seqLenPadded; + double dropOutKeepProb; + at::Tensor dropOutMask; + at::IntArrayRef inputSizes; +}; + +DropOutStatus GetDropoutStatus(double &keepProb, bool train) +{ + if (!train) { + keepProb = 1.0; + } + + if (keepProb == 0) { + return DropOutStatus::DROPOUT_ALL; + } + + if (keepProb == 1) { + return DropOutStatus::DROPOUT_NONE; + } + + return DropOutStatus::DROPOUT_NORMAL; +} + +void padding_input_and_attentionmask(uint64_t &seqLenPadded, at::Tensor &input_padded, at::Tensor &attentionMask_padded, + const at::Tensor &input, const at::Tensor &attentionMask) +{ + auto input_sizes = input.sizes(); + if (input_sizes[1] % BASESEQLENGTH != 0) { + int paddingSize = BASESEQLENGTH - input_sizes[1] % BASESEQLENGTH; + seqLenPadded += paddingSize; + auto padding_tensor = at::full({ input.size(0), paddingSize, input.size(2) }, 0, input.options()); + input_padded = at::cat({ input, padding_tensor }, 1); + padding_tensor = at::full({ attentionMask.size(0), attentionMask.size(1), attentionMask.size(2), paddingSize }, + std::numeric_limits::lowest(), attentionMask.options()); + attentionMask_padded = at::cat({ attentionMask, padding_tensor }, ATTENTIONMASKCATAIX); + } +} + +int check_tensor_shape(std::vector tensors) +{ + std::vector expected_dims = { 3, 2, 1, 2, 1, 2, 1, 4, 4 }; + + for (size_t i = 0; i < tensors.size(); ++i) { + if (tensors[i].dim() != expected_dims[i]) { + std::cout << "Tensor at index " + std::to_string(i) + + " does not have the expected number of dimensions: expected " + + std::to_string(expected_dims[i]) + ", got " + std::to_string(tensors[i].dim()) + << std::endl; + return -1; + } + } + return 0; +} + +at::Tensor execute_bert_self_attention(const BertSelfAttentionParams ¶ms) +{ + at::Tensor result; + + if (params.inputSizes[1] % BASESEQLENGTH == 0) { + result = npu_preparation::apply_tensor_without_format(params.input); + EXEC_NPU_CMD(aclnnBertSelfAttention, params.input, params.queryW, params.queryBias, params.keyW, params.keyBias, + params.valueW, params.valueBias, params.attentionMask, params.dropOutMask, params.headMask, + params.numAttentionHeads, params.attentionHeadSize, params.dropOutKeepProb, result); + } else { + std::vector output_sizes = { params.inputSizes[0], params.seqLenPadded, params.inputSizes[2] }; + c10::IntArrayRef output_size_ref(output_sizes); + at::Tensor result_padded = npu_preparation::apply_tensor_without_format(output_size_ref, + params.input.options()); + EXEC_NPU_CMD(aclnnBertSelfAttention, params.input_padded, params.queryW, params.queryBias, params.keyW, + params.keyBias, params.valueW, params.valueBias, 
params.attentionMask_padded, params.dropOutMask, + params.headMask, params.numAttentionHeads, params.attentionHeadSize, params.dropOutKeepProb, + result_padded); + result = result_padded.narrow(1, 0, params.inputSizes[1]); + } + + return result; +} + +at::Tensor npu_bert_self_attention_custom(const at::Tensor &input, const at::Tensor &queryW, + const at::Tensor &queryBias, const at::Tensor &keyW, + const at::Tensor &keyBias, const at::Tensor &valueW, + const at::Tensor &valueBias, const at::Tensor &attentionMask, + const at::Tensor &headMask, int64_t numAttentionHeads, + int64_t attentionHeadSize, double dropOutKeepProb, bool train) +{ + std::vector tensors = { input, queryW, queryBias, keyW, keyBias, + valueW, valueBias, attentionMask, headMask }; + if (check_tensor_shape(tensors) == -1) { + return at::Tensor(); + } + double keepProb = 1.0 - dropOutKeepProb; + auto input_sizes = input.sizes(); + uint64_t seqLenPadded = input_sizes[1]; + + // padding input and attentionmask + at::Tensor input_padded; + at::Tensor attentionMask_padded; + padding_input_and_attentionmask(seqLenPadded, input_padded, attentionMask_padded, input, attentionMask); + + int64_t maskLen = (input_sizes[0] * numAttentionHeads * seqLenPadded * seqLenPadded + 128 - 1) / 128 * 128 / 8; + + // gen_drop_out_mask + at::Tensor dropOutMask; + DropOutStatus droupOutStatus = GetDropoutStatus(keepProb, train); + if (droupOutStatus == DropOutStatus::DROPOUT_ALL) { + dropOutMask = at::zeros(at::IntArrayRef{maskLen}, input.options().dtype(at::kByte)); + } else if (droupOutStatus == DropOutStatus::DROPOUT_NONE) { + dropOutKeepProb = 0; + dropOutMask = at::ones(at::IntArrayRef{maskLen}, input.options().dtype(at::kByte)); + } else { + auto original_stream = c10_npu::getCurrentNPUStream(); + c10_npu::SecondaryStreamGuard guard(c10_npu::getCurrentSecondaryStream()); + dropOutMask = at_npu::native::OpPreparation::apply_tensor_without_format({maskLen}, + input.options().dtype(at::kByte)); + std::vector shape = { input_sizes[0], numAttentionHeads, seqLenPadded, seqLenPadded }; + at::IntArrayRef shapeArray(shape); + + const auto gen = at_npu::detail::getDefaultNPUGenerator(); + auto pair = at::check_generator(gen)->philox_engine_inputs(10); + + const uint64_t seed = pair.first; + const uint64_t offset = pair.second; + EXEC_NPU_CMD(aclnnDropoutGenMask, shapeArray, dropOutKeepProb, seed, offset, dropOutMask); + c10_npu::NPUCachingAllocator::recordStream(dropOutMask.storage().data_ptr(), original_stream); + } + + BertSelfAttentionParams params = { input, input_padded, queryW, queryBias, keyW, keyBias, valueW, + valueBias, attentionMask, attentionMask_padded, headMask, + numAttentionHeads, attentionHeadSize, seqLenPadded, + dropOutKeepProb, dropOutMask, input_sizes }; + + return execute_bert_self_attention(params); +} +} // namespace op_api diff --git a/mxRAG/MainRepo/patches/BertSelfAttentionFast/ops/op_plugin_patch/op_plugin_functions.yaml b/mxRAG/MainRepo/patches/BertSelfAttentionFast/ops/op_plugin_patch/op_plugin_functions.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9d8db99f3b5479e066c1a223868d62ba9de599f8 --- /dev/null +++ b/mxRAG/MainRepo/patches/BertSelfAttentionFast/ops/op_plugin_patch/op_plugin_functions.yaml @@ -0,0 +1,2 @@ + - func: npu_bert_self_attention_custom(Tensor input, Tensor queryW, Tensor queryBias, Tensor keyW, Tensor keyBias, Tensor valueW, Tensor valueBias, Tensor attentionMask, Tensor headMask, int numAttentionHeads, int attentionHeadSize, float dropOutKeepProb, bool train) -> Tensor + 
op_api: all_version diff --git a/mxRAG/MainRepo/patches/BertSelfAttentionFast/ops/run_op_plugin.sh b/mxRAG/MainRepo/patches/BertSelfAttentionFast/ops/run_op_plugin.sh new file mode 100644 index 0000000000000000000000000000000000000000..9cceef45cca65afed6b7c1b18d35146b26169898 --- /dev/null +++ b/mxRAG/MainRepo/patches/BertSelfAttentionFast/ops/run_op_plugin.sh @@ -0,0 +1,44 @@ +#!/bin/bash +export LD_LIBRARY_PATH=/usr/local/Ascend/driver/lib64:/usr/local/Ascend/driver/lib64/common:/usr/local/Ascend/driver/lib64/driver:$LD_LIBRARY_PATH + +CURRENT_DIR=$( + cd $(dirname ${BASH_SOURCE:-$0}) + pwd +) +cd $CURRENT_DIR + +# 导出环境变量 +PTA_DIR=op-plugin + +if [ ! $ASCEND_HOME_DIR ]; then + if [ -d "$HOME/Ascend/ascend-toolkit/latest" ]; then + export ASCEND_HOME_DIR=$HOME/Ascend/ascend-toolkit/latest + else + export ASCEND_HOME_DIR=/usr/local/Ascend/ascend-toolkit/latest + fi +fi +source $ASCEND_HOME_DIR/bin/setenv.bash + +PYTHON_VERSION=3.11 +PYTORCH_VESION=`pip3.11 show torch | grep "Version:" | awk '{print $2}' | awk -F '.' '{print $1"."$2"."$3}' | awk -F '+' '{print $1}'` +export HI_PYTHON=python${PYTHON_VERSION} +export PYTHONPATH=$ASCEND_HOME_DIR/python/site-packages:$PYTHONPATH +export PATH=$ASCEND_HOME_DIR/python/site-packages/bin:$PATH + +function main() { + # 1. PTA自定义算子注册 + cd ${CURRENT_DIR} + FUNCTION_REGISTE_FIELD="op_plugin_patch/op_plugin_functions.yaml" + FUNCTION_REGISTE_FILE="${PTA_DIR}/op_plugin/config/op_plugin_functions.yaml" + line=" - func: npu_bert_self_attention_custom(Tensor input, Tensor queryW, Tensor queryBias, Tensor keyW, Tensor keyBias, Tensor valueW, Tensor valueBias, Tensor attentionMask, Tensor dropOutMask, Tensor headMask, int numAttentionHeads, int attentionHeadSize, float dropOutKeepProb) -> Tensor" + if ! grep -q "\ $line" $FUNCTION_REGISTE_FILE; then + sed -i "/custom:/r $FUNCTION_REGISTE_FIELD" $FUNCTION_REGISTE_FILE + fi + + # 2. 
编译PTA插件并安装 + cp -rf op_plugin_patch/*.cpp ${PTA_DIR}/op_plugin/ops/opapi + cd ${PTA_DIR}; + (bash ci/build.sh --python=${PYTHON_VERSION} --pytorch=v$PYTORCH_VESION ; pip3.11 uninstall torch-npu -y ; pip3.11 install dist/*.whl) + +} +main diff --git a/mxRAG/MainRepo/patches/BertSelfAttentionFast/transformers_bert.patch b/mxRAG/MainRepo/patches/BertSelfAttentionFast/transformers_bert.patch new file mode 100644 index 0000000000000000000000000000000000000000..098988da74d2037cdaa295368d08b3994d348575 --- /dev/null +++ b/mxRAG/MainRepo/patches/BertSelfAttentionFast/transformers_bert.patch @@ -0,0 +1,166 @@ +--- modeling_bert.py 2024-06-14 11:17:00.553390150 +0000 ++++ modeling_bert_tei.py 2024-06-14 11:32:57.753390150 +0000 +@@ -26,6 +26,7 @@ + import torch.utils.checkpoint + from torch import nn + from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss ++import torch_npu + + from ...activations import ACT2FN + from ...modeling_outputs import ( +@@ -189,7 +190,6 @@ + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) +- self.dropout = nn.Dropout(config.hidden_dropout_prob) + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + self.register_buffer( +@@ -237,7 +237,6 @@ + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) +- embeddings = self.dropout(embeddings) + return embeddings + + +@@ -258,7 +257,6 @@ + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + +- self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = position_embedding_type or getattr( + config, "position_embedding_type", "absolute" + ) +@@ -354,10 +352,6 @@ + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + +- # This is actually dropping out entire tokens to attend to, which might +- # seem a bit unusual, but is taken from the original Transformer paper. 
+- attention_probs = self.dropout(attention_probs) +- + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask +@@ -374,17 +368,76 @@ + outputs = outputs + (past_key_value,) + return outputs + ++class BertSelfAttentionFast(nn.Module): ++ def __init__(self, config, position_embedding_type=None): ++ super().__init__() ++ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): ++ raise ValueError( ++ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " ++ f"heads ({config.num_attention_heads}QV)" ++ ) ++ ++ self.num_attention_heads = config.num_attention_heads ++ self.attention_head_size = int(config.hidden_size / config.num_attention_heads) ++ self.all_head_size = self.num_attention_heads * self.attention_head_size ++ ++ self.query = nn.Linear(config.hidden_size, self.all_head_size) ++ self.key = nn.Linear(config.hidden_size, self.all_head_size) ++ self.value = nn.Linear(config.hidden_size, self.all_head_size) ++ ++ self.dropout = nn.Dropout(config.attention_probs_dropout_prob) ++ self.position_embedding_type = position_embedding_type or getattr( ++ config, "position_embedding_type", "absolute" ++ ) ++ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": ++ self.max_position_embeddings = config.max_position_embeddings ++ self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) ++ ++ self.is_decoder = config.is_decoder ++ self.attention_probs_dropout_prob = config.attention_probs_dropout_prob ++ ++ self.queryW=None ++ self.keyW=None ++ self.valueW=None ++ self.head_mask_defualt = None ++ ++ def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: ++ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) ++ x = x.view(new_x_shape) ++ return x.permute(0, 2, 1, 3) ++ ++ def forward( ++ self, ++ hidden_states: torch.Tensor, ++ attention_mask: Optional[torch.FloatTensor] = None, ++ head_mask: Optional[torch.FloatTensor] = None, ++ encoder_hidden_states: Optional[torch.FloatTensor] = None, ++ encoder_attention_mask: Optional[torch.FloatTensor] = None, ++ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, ++ output_attentions: Optional[bool] = False, ++ ) -> Tuple[torch.Tensor]: ++ if self.queryW is None: ++ self.queryW = self.query.weight.data.t().contiguous() ++ self.keyW = self.key.weight.data.t().contiguous() ++ self.valueW = self.value.weight.data.t().contiguous() ++ self.head_mask_defualt = torch.ones(1,self.num_attention_heads,1,1,dtype=self.query.weight.data.dtype).to(self.queryW.device) ++ if head_mask is None: ++ head_mask = self.head_mask_defualt ++ ++ context_layer = torch_npu.npu_bert_self_attention_custom(hidden_states, self.queryW, self.query.bias.data, self.keyW, self.key.bias.data, self.valueW, self.value.bias.data, attention_mask, head_mask, self.num_attention_heads, self.attention_head_size, self.attention_probs_dropout_prob, self.training) ++ ++ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) ++ ++ return outputs + + class BertSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) +- self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: 
torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) +- hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + +@@ -392,7 +445,7 @@ + class BertAttention(nn.Module): + def __init__(self, config, position_embedding_type=None): + super().__init__() +- self.self = BertSelfAttention(config, position_embedding_type=position_embedding_type) ++ self.self = BertSelfAttentionFast(config, position_embedding_type=position_embedding_type) + self.output = BertSelfOutput(config) + self.pruned_heads = set() + +@@ -458,11 +511,9 @@ + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) +- self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) +- hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + +@@ -536,17 +587,12 @@ + cross_attn_present_key_value = cross_attention_outputs[-1] + present_key_value = present_key_value + cross_attn_present_key_value + +- layer_output = apply_chunking_to_forward( +- self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output +- ) ++ intermediate_output = self.intermediate(attention_output) ++ layer_output = self.output(intermediate_output, attention_output) + outputs = (layer_output,) + outputs +- +- # if decoder, return the attn key/values as the last output +- if self.is_decoder: +- outputs = outputs + (present_key_value,) +- + return outputs + ++ + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) diff --git a/mxRAG/MainRepo/patches/TEI/README.md b/mxRAG/MainRepo/patches/TEI/README.md new file mode 100644 index 0000000000000000000000000000000000000000..551e7fb7dbf6dc5f0b47e690e7c73d436218f54e --- /dev/null +++ b/mxRAG/MainRepo/patches/TEI/README.md @@ -0,0 +1,24 @@ +# 安装补丁说明 + +## 安装补丁步骤 +```bash +bash tei_patch.sh [text-embeddings-inference包的所在路径] +``` + +## 注意事项 +1. 安装patch前,请先设置CANN环境变量 +```sh + source [cann安装路径](默认为/usr/local/Ascend/ascend-toolkit)/set_env.sh +``` +2. TEI开源代码保持text-embeddings-inference名称不变 + +3. 
安装patch后,在text-embeddings-inference目录下会生成NPU适配相关使用说明NPU_ADAPT.md,根据该说明进行后续相关操作 + +## 版本依赖 +| 软件 | 版本要求 | +| ----- |-------------| +| python | 3.10.14 | +| transformers | 4.40.2 | +|TEI| 1.2.3 | +|torch_npu| 2.1.0.post3 | +|safetensors| 0.4.1 | \ No newline at end of file diff --git a/mxRAG/MainRepo/patches/TEI/ops/transformers_bert.patch b/mxRAG/MainRepo/patches/TEI/ops/transformers_bert.patch new file mode 100644 index 0000000000000000000000000000000000000000..be7f8f30aa3df3954435593c9d79334f890932ef --- /dev/null +++ b/mxRAG/MainRepo/patches/TEI/ops/transformers_bert.patch @@ -0,0 +1,82 @@ +--- modeling_bert.py 2024-05-16 13:55:47.420093700 +0800 ++++ modeling_bert_patch.py 2024-05-16 14:10:55.067551700 +0800 +@@ -164,7 +164,6 @@ class BertEmbeddings(nn.Module): + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) +- self.dropout = nn.Dropout(config.hidden_dropout_prob) + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + self.register_buffer( +@@ -212,7 +211,6 @@ class BertEmbeddings(nn.Module): + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) +- embeddings = self.dropout(embeddings) + return embeddings + + +@@ -233,7 +231,6 @@ class BertSelfAttention(nn.Module): + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + +- self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = position_embedding_type or getattr( + config, "position_embedding_type", "absolute" + ) +@@ -329,10 +326,6 @@ class BertSelfAttention(nn.Module): + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + +- # This is actually dropping out entire tokens to attend to, which might +- # seem a bit unusual, but is taken from the original Transformer paper. 
+- attention_probs = self.dropout(attention_probs) +- + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask +@@ -355,11 +348,9 @@ class BertSelfOutput(nn.Module): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) +- self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) +- hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + +@@ -433,11 +424,9 @@ class BertOutput(nn.Module): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) +- self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) +- hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + +@@ -511,17 +500,12 @@ class BertLayer(nn.Module): + cross_attn_present_key_value = cross_attention_outputs[-1] + present_key_value = present_key_value + cross_attn_present_key_value + +- layer_output = apply_chunking_to_forward( +- self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output +- ) ++ intermediate_output = self.intermediate(attention_output) ++ layer_output = self.output(intermediate_output, attention_output) + outputs = (layer_output,) + outputs +- +- # if decoder, return the attn key/values as the last output +- if self.is_decoder: +- outputs = outputs + (present_key_value,) +- + return outputs + ++ + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) diff --git a/mxRAG/MainRepo/patches/TEI/ops/transformers_xlm_roberta.patch b/mxRAG/MainRepo/patches/TEI/ops/transformers_xlm_roberta.patch new file mode 100644 index 0000000000000000000000000000000000000000..62a595b187737729ed87922a96e3b5535a592a05 --- /dev/null +++ b/mxRAG/MainRepo/patches/TEI/ops/transformers_xlm_roberta.patch @@ -0,0 +1,61 @@ +--- modeling_xlm_roberta.py 2024-05-16 13:55:52.724988200 +0800 ++++ modeling_xlm_roberta_patch.py 2024-05-16 14:11:56.076500900 +0800 +@@ -71,7 +71,6 @@ class XLMRobertaEmbeddings(nn.Module): + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) +- self.dropout = nn.Dropout(config.hidden_dropout_prob) + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + self.register_buffer( +@@ -124,7 +123,6 @@ class XLMRobertaEmbeddings(nn.Module): + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) +- embeddings = self.dropout(embeddings) + return embeddings + + def create_position_ids_from_inputs_embeds(self, inputs_embeds): +@@ -163,7 +161,6 @@ class XLMRobertaSelfAttention(nn.Module) + self.key = nn.Linear(config.hidden_size, 
self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + +- self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = position_embedding_type or getattr( + config, "position_embedding_type", "absolute" + ) +@@ -259,10 +256,6 @@ class XLMRobertaSelfAttention(nn.Module) + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + +- # This is actually dropping out entire tokens to attend to, which might +- # seem a bit unusual, but is taken from the original Transformer paper. +- attention_probs = self.dropout(attention_probs) +- + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask +@@ -286,11 +279,9 @@ class XLMRobertaSelfOutput(nn.Module): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) +- self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) +- hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + +@@ -367,11 +358,9 @@ class XLMRobertaOutput(nn.Module): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) +- self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) +- hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + diff --git a/mxRAG/MainRepo/patches/TEI/tei/tei.patch b/mxRAG/MainRepo/patches/TEI/tei/tei.patch new file mode 100644 index 0000000000000000000000000000000000000000..75afdd8453e0ec08f986a6acd568801038564588 --- /dev/null +++ b/mxRAG/MainRepo/patches/TEI/tei/tei.patch @@ -0,0 +1,735 @@ +diff -uprN text-embeddings-inference-1.2.3/NPU_ADAPT.md text-embeddings-inference/NPU_ADAPT.md +--- text-embeddings-inference-1.2.3/NPU_ADAPT.md 1970-01-01 00:00:00.000000000 +0000 ++++ text-embeddings-inference/NPU_ADAPT.md 2024-08-20 06:45:23.590000000 +0000 +@@ -0,0 +1,81 @@ ++# NPU ADAPT TEI BASED ON PYTHON ++ ++A Python gRPC server for Text Embeddings Inference ++ ++## Prepare Environment ++First, you need to install rust and protobuf ++Instructions for installing rust, ++```shell ++curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh ++``` ++## install protobuf v21.12 ++Please refer to protobuf official instructions to install it ++ ++## Install Router ++locate in /tei-npu ++```shell ++cargo install --path router -F python -F http --no-default-features ++``` ++ ++## Install Python Backend ++locate in /tei-npu/backends/python/server ++### Setup Python Environment & Make Local Grpc ++```shell ++make install ++``` ++### Verify Python Backend ++If you plan to use an online model, modify the model path of the run-dev entry in the Makefile. ++For example, ++```shell ++run-dev: ++ python text_embeddings_server/cli.py /home/bge-large-zh-v1.5 ++``` ++then, you can verify python backend based the model. 
++```shell
++make run-dev
++```
++If the Python backend does not start normally, rectify the fault according to the error stack.
++
++### Setup Python CLI
++```shell
++poetry install
++```
++
++## Run
++### Set device
++If you want to specify an NPU device, set the TEI_NPU_DEVICE environment variable before starting the server.
++For example,
++```shell
++export TEI_NPU_DEVICE=3
++```
++### Server
++Run in /tei-npu:
++```shell
++text-embeddings-router --model-id $model --revision $revision --port 8080
++```
++For example, if you want to use the embed service:
++```shell
++text-embeddings-router --model-id /home/bge-large-zh-v1.5 --dtype float32 --pooling cls --max-concurrent-requests 2048 --max-batch-requests 2048 --max-batch-tokens 1100000 --max-client-batch-size 256 --port 8080
++```
++For example, if you want to use the rerank service:
++```shell
++text-embeddings-router --model-id /home/bge-reranker-large --max-concurrent-requests 2048 --max-batch-tokens 163840 --max-batch-requests 1 --port 8080
++```
++### Client
++For example, if you want to use the embed service:
++```bash
++curl 127.0.0.1:8080/embed \
++    -X POST \
++    -d '{"inputs":"What is Deep Learning?"}' \
++    -H 'Content-Type: application/json'
++```
++For example, if you want to use the rerank service:
++```bash
++curl 127.0.0.1:8080/rerank \
++    -X POST \
++    -d '{"query":"What is Deep Learning?", "texts": ["Deep Learning is not...", "Deep learning is..."]}' \
++    -H 'Content-Type: application/json'
++```
++
++## Attention
++If the Python backend cannot start normally, run the `make run-dev` command in backends/python/server to check whether the backend can be started, and fix it according to the error stack.
+diff -uprN text-embeddings-inference-1.2.3/README.md text-embeddings-inference/README.md
+--- text-embeddings-inference-1.2.3/README.md	2024-04-25 08:47:38.000000000 +0000
++++ text-embeddings-inference/README.md	2024-08-20 07:28:35.800000000 +0000
+@@ -150,7 +150,7 @@ Options:
+         The dtype to be forced upon the model
+ 
+         [env: DTYPE=]
+-        [possible values: float16, float32]
++        [possible values: float16, float32, bfloat16]
+ 
+     --pooling
+         Optionally control the pooling method for embedding models.
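The curl commands in the Client section above can also be driven from Python. The following is a minimal client sketch added for illustration only (it is not part of tei.patch): it assumes the router from the Server section is already listening on 127.0.0.1:8080 and that the `requests` library is installed.

```python
# Minimal HTTP client sketch for the TEI router endpoints shown above.
# Assumes the router is running locally on port 8080; adjust host/port as needed.
import requests

BASE_URL = "http://127.0.0.1:8080"


def embed(text: str):
    # POST /embed with the same JSON payload as the curl example above.
    resp = requests.post(f"{BASE_URL}/embed", json={"inputs": text}, timeout=30)
    resp.raise_for_status()
    return resp.json()


def rerank(query: str, texts: list):
    # POST /rerank mirrors the curl example above; the response JSON carries a score per candidate text.
    resp = requests.post(f"{BASE_URL}/rerank", json={"query": query, "texts": texts}, timeout=30)
    resp.raise_for_status()
    return resp.json()


if __name__ == "__main__":
    print(embed("What is Deep Learning?"))
    print(rerank("What is Deep Learning?", ["Deep Learning is not...", "Deep learning is..."]))
```

The payload shapes are taken directly from the curl examples; only the host, port, and helper names are illustrative.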
+diff -uprN text-embeddings-inference-1.2.3/backends/grpc-client/src/client.rs text-embeddings-inference/backends/grpc-client/src/client.rs +--- text-embeddings-inference-1.2.3/backends/grpc-client/src/client.rs 2024-04-25 08:47:38.000000000 +0000 ++++ text-embeddings-inference/backends/grpc-client/src/client.rs 2024-08-20 06:45:23.590000000 +0000 +@@ -64,4 +64,46 @@ impl Client { + let response = self.stub.embed(request).await?.into_inner(); + Ok(response.embeddings) + } ++ ++ #[instrument(skip_all)] ++ pub async fn embed_all( ++ &mut self, ++ input_ids: Vec, ++ token_type_ids: Vec, ++ position_ids: Vec, ++ cu_seq_lengths: Vec, ++ max_length: u32, ++ ) -> Result> { ++ let request = tonic::Request::new(EmbedRequest { ++ input_ids, ++ token_type_ids, ++ position_ids, ++ max_length, ++ cu_seq_lengths, ++ }) ++ .inject_context(); ++ let response = self.stub.embed_all(request).await?.into_inner(); ++ Ok(response.allembeddings) ++ } ++ ++ #[instrument(skip_all)] ++ pub async fn predict( ++ &mut self, ++ input_ids: Vec, ++ token_type_ids: Vec, ++ position_ids: Vec, ++ cu_seq_lengths: Vec, ++ max_length: u32, ++ ) -> Result> { ++ let request = tonic::Request::new(PredictRequest { ++ input_ids, ++ token_type_ids, ++ position_ids, ++ max_length, ++ cu_seq_lengths, ++ }) ++ .inject_context(); ++ let response = self.stub.predict(request).await?.into_inner(); ++ Ok(response.predictions) ++ } + } +diff -uprN text-embeddings-inference-1.2.3/backends/proto/embed.proto text-embeddings-inference/backends/proto/embed.proto +--- text-embeddings-inference-1.2.3/backends/proto/embed.proto 2024-04-25 08:47:38.000000000 +0000 ++++ text-embeddings-inference/backends/proto/embed.proto 2024-08-20 06:45:23.590000000 +0000 +@@ -5,6 +5,8 @@ package embedding.v1; + service EmbeddingService { + /// Decode token for a list of prefilled batches + rpc Embed (EmbedRequest) returns (EmbedResponse); ++ rpc Embed_all (EmbedRequest) returns (RawEmbedResponse); ++ rpc Predict (PredictRequest) returns (PredictResponse); + /// Health check + rpc Health (HealthRequest) returns (HealthResponse); + } +@@ -12,6 +14,23 @@ service EmbeddingService { + message HealthRequest {} + message HealthResponse {} + ++message PredictRequest { ++ repeated uint32 input_ids = 1; ++ repeated uint32 token_type_ids = 2; ++ repeated uint32 position_ids = 3; ++ repeated uint32 cu_seq_lengths = 4; ++ /// Length of the longest request ++ uint32 max_length = 5; ++} ++ ++message Prediction { ++ repeated float values = 1; ++} ++ ++message PredictResponse { ++ repeated Prediction predictions = 1; ++} ++ + message EmbedRequest { + repeated uint32 input_ids = 1; + repeated uint32 token_type_ids = 2; +@@ -21,6 +40,7 @@ message EmbedRequest { + uint32 max_length = 5; + } + ++ + message Embedding { + repeated float values = 1; + } +@@ -28,3 +48,11 @@ message Embedding { + message EmbedResponse { + repeated Embedding embeddings = 1; + } ++ ++message TokenEmbedding { ++ repeated Embedding embeddings = 1; ++} ++ ++message RawEmbedResponse { ++ repeated TokenEmbedding allembeddings = 1; ++} +diff -uprN text-embeddings-inference-1.2.3/backends/python/server/Makefile text-embeddings-inference/backends/python/server/Makefile +--- text-embeddings-inference-1.2.3/backends/python/server/Makefile 2024-04-25 08:47:38.000000000 +0000 ++++ text-embeddings-inference/backends/python/server/Makefile 2024-08-20 06:45:23.590000000 +0000 +@@ -19,7 +19,7 @@ install: gen-server + pip install -e . 
+ + run-dev: +- python text_embeddings_server/cli.py serve BAAI/bge-small-en ++ python text_embeddings_server/cli.py BAAI/bge-small-en + + export-requirements: + poetry export -o requirements.txt --without-hashes +diff -uprN text-embeddings-inference-1.2.3/backends/python/server/pyproject.toml text-embeddings-inference/backends/python/server/pyproject.toml +--- text-embeddings-inference-1.2.3/backends/python/server/pyproject.toml 2024-04-25 08:47:38.000000000 +0000 ++++ text-embeddings-inference/backends/python/server/pyproject.toml 2024-08-20 06:45:23.590000000 +0000 +@@ -20,7 +20,7 @@ loguru = "^0.6.0" + opentelemetry-api = "^1.15.0" + opentelemetry-exporter-otlp = "^1.15.0" + opentelemetry-instrumentation-grpc = "^0.36b0" +-torch = { version = "^2.0.1" } ++torch = { version = "^2.1.0" } + + [tool.poetry.extras] + +@@ -29,9 +29,9 @@ grpcio-tools = "^1.51.1" + pytest = "^7.3.0" + + [[tool.poetry.source]] +-name = "pytorch-gpu-src" +-url = "https://download.pytorch.org/whl/cu118" +-priority = "explicit" ++name = "mirrors" ++url = "https://pypi.tuna.tsinghua.edu.cn/simple/" ++priority = "default" + + [tool.pytest.ini_options] + markers = ["private: marks tests as requiring an admin hf token (deselect with '-m \"not private\"')"] +diff -uprN text-embeddings-inference-1.2.3/backends/python/server/requirements.txt text-embeddings-inference/backends/python/server/requirements.txt +--- text-embeddings-inference-1.2.3/backends/python/server/requirements.txt 2024-04-25 08:47:38.000000000 +0000 ++++ text-embeddings-inference/backends/python/server/requirements.txt 2024-08-20 06:45:23.590000000 +0000 +@@ -11,7 +11,7 @@ grpc-interceptor==0.15.3 ; python_versio + grpcio-reflection==1.58.0 ; python_version >= "3.9" and python_version < "3.13" + grpcio-status==1.58.0 ; python_version >= "3.9" and python_version < "3.13" + grpcio==1.58.0 ; python_version >= "3.9" and python_version < "3.13" +-huggingface-hub==0.16.4 ; python_version >= "3.9" and python_version < "3.13" ++huggingface-hub==0.23.0 ; python_version >= "3.9" and python_version < "3.13" + idna==3.4 ; python_version >= "3.9" and python_version < "3.13" + jinja2==3.1.2 ; python_version >= "3.9" and python_version < "3.13" + loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13" +@@ -31,13 +31,14 @@ packaging==23.1 ; python_version >= "3.9 + protobuf==4.24.3 ; python_version >= "3.9" and python_version < "3.13" + pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13" + requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13" +-safetensors==0.3.3 ; python_version >= "3.9" and python_version < "3.13" ++safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13" + setuptools==68.2.0 ; python_version >= "3.9" and python_version < "3.13" + sympy==1.12 ; python_version >= "3.9" and python_version < "3.13" +-torch==2.0.1 ; python_version >= "3.9" and python_version < "3.13" ++torch==2.1.0 ; python_version >= "3.9" and python_version < "3.13" + tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13" + typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13" + typing-extensions==4.7.1 ; python_version >= "3.9" and python_version < "3.13" + urllib3==2.0.4 ; python_version >= "3.9" and python_version < "3.13" + win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32" + wrapt==1.15.0 ; python_version >= "3.9" and python_version < "3.13" ++poetry==1.8.3 ; python_version >= "3.9" and python_version < "3.13" +diff -uprN 
text-embeddings-inference-1.2.3/backends/python/server/text_embeddings_server/models/__init__.py text-embeddings-inference/backends/python/server/text_embeddings_server/models/__init__.py +--- text-embeddings-inference-1.2.3/backends/python/server/text_embeddings_server/models/__init__.py 2024-04-25 08:47:38.000000000 +0000 ++++ text-embeddings-inference/backends/python/server/text_embeddings_server/models/__init__.py 2024-08-20 06:45:23.590000000 +0000 +@@ -1,5 +1,6 @@ ++import os + import torch +- ++import torch_npu + from loguru import logger + from pathlib import Path + from typing import Optional +@@ -8,6 +9,7 @@ from transformers.models.bert import Ber + + from text_embeddings_server.models.model import Model + from text_embeddings_server.models.default_model import DefaultModel ++from text_embeddings_server.models.rerank_model import RerankModel + + __all__ = ["Model"] + +@@ -37,6 +39,14 @@ def get_model(model_path: Path, dtype: O + + if torch.cuda.is_available(): + device = torch.device("cuda") ++ elif torch.npu.is_available(): ++ device = torch.device("npu") ++ torch.npu.set_compile_mode(jit_compile=False) ++ option = {"NPU_FUZZY_COMPILE_BLACKLIST": "ReduceProd"} ++ torch.npu.set_option(option) ++ deviceIdx = os.environ.get('TEI_NPU_DEVICE') ++ if deviceIdx != None and deviceIdx.isdigit() and int(deviceIdx) >= 0 and int(deviceIdx) <= 7: ++ torch.npu.set_device(torch.device(f"npu:{deviceIdx}")) + else: + if dtype != torch.float32: + raise ValueError("CPU device only supports float32 dtype") +@@ -44,10 +54,12 @@ def get_model(model_path: Path, dtype: O + + config = AutoConfig.from_pretrained(model_path) + +- if config.model_type == "bert": +- config: BertConfig ++ if config.architectures[0].endswith("Classification"): ++ return RerankModel(model_path, device, dtype) ++ else: + if ( +- device.type == "cuda" ++ config.model_type == "bert" ++ and device.type == "cuda" + and config.position_embedding_type == "absolute" + and dtype in [torch.float16, torch.bfloat16] + and FLASH_ATTENTION +@@ -55,5 +67,6 @@ def get_model(model_path: Path, dtype: O + return FlashBert(model_path, device, dtype) + else: + return DefaultModel(model_path, device, dtype) ++ + + raise NotImplementedError +diff -uprN text-embeddings-inference-1.2.3/backends/python/server/text_embeddings_server/models/default_model.py text-embeddings-inference/backends/python/server/text_embeddings_server/models/default_model.py +--- text-embeddings-inference-1.2.3/backends/python/server/text_embeddings_server/models/default_model.py 2024-04-25 08:47:38.000000000 +0000 ++++ text-embeddings-inference/backends/python/server/text_embeddings_server/models/default_model.py 2024-08-20 06:45:23.590000000 +0000 +@@ -1,20 +1,20 @@ + import inspect + import torch +- ++import torch_npu + from pathlib import Path + from typing import Type, List + from transformers import AutoModel + from opentelemetry import trace + + from text_embeddings_server.models import Model +-from text_embeddings_server.models.types import PaddedBatch, Embedding ++from text_embeddings_server.models.types import PaddedBatch, Embedding, Prediction, TokenEmbedding + + tracer = trace.get_tracer(__name__) + + + class DefaultModel(Model): + def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype): +- model = AutoModel.from_pretrained(model_path).to(dtype).to(device) ++ model = AutoModel.from_pretrained(model_path).to(dtype).to(device).eval() + self.hidden_size = model.config.hidden_size + + self.has_position_ids = ( +@@ -41,7 +41,7 @@ class 
DefaultModel(Model): + kwargs["position_ids"] = batch.position_ids + + output = self.model(**kwargs) +- embedding = output[0][:, 0] ++ embedding = output[0][:, 0].contiguous() + cpu_results = embedding.view(-1).tolist() + + return [ +@@ -50,3 +50,29 @@ class DefaultModel(Model): + ) + for i in range(len(batch)) + ] ++ ++ @tracer.start_as_current_span("embed_all") ++ def embed_all(self, batch: PaddedBatch): ++ kwargs = {"input_ids": batch.input_ids, "attention_mask": batch.attention_mask} ++ if self.has_token_type_ids: ++ kwargs["token_type_ids"] = batch.token_type_ids ++ if self.has_position_ids: ++ kwargs["position_ids"] = batch.position_ids ++ output = self.model(**kwargs) ++ embedding = output[0].contiguous() ++ cpu_results = embedding.view(-1).tolist() ++ embedding_result=[] ++ for i in range(len(batch)): ++ embedding_tmp=[ ++ Embedding(values=cpu_results[(j+i*batch.max_length) * self.hidden_size : ++ (j + 1 + i*batch.max_length) * self.hidden_size]) ++ for j in range(batch.input_ids.size()[1]) ++ ] ++ tokenembeddings=TokenEmbedding(embeddings=embedding_tmp) ++ embedding_result.append(tokenembeddings) ++ ++ return embedding_result ++ ++ @tracer.start_as_current_span("predict") ++ def predict(self, batch: PaddedBatch) -> List[Prediction]: ++ print("embedding model is not support predict") +diff -uprN text-embeddings-inference-1.2.3/backends/python/server/text_embeddings_server/models/flash_bert.py text-embeddings-inference/backends/python/server/text_embeddings_server/models/flash_bert.py +--- text-embeddings-inference-1.2.3/backends/python/server/text_embeddings_server/models/flash_bert.py 2024-04-25 08:47:38.000000000 +0000 ++++ text-embeddings-inference/backends/python/server/text_embeddings_server/models/flash_bert.py 2024-08-20 06:45:23.590000000 +0000 +@@ -12,7 +12,7 @@ from opentelemetry import trace + import dropout_layer_norm + + from text_embeddings_server.models import Model +-from text_embeddings_server.models.types import FlashBatch, Embedding ++from text_embeddings_server.models.types import FlashBatch, Embedding, Prediction, TokenEmbedding + from text_embeddings_server.utils.flash_attn import attention + + tracer = trace.get_tracer(__name__) +@@ -251,3 +251,12 @@ class FlashBert(Model): + ) + for i in range(len(batch)) + ] ++ ++ @tracer.start_as_current_span("embed_all") ++ def embed_all(self, batch: FlashBatch) -> List[TokenEmbedding]: ++ print("flashbert model is not support embed_all") ++ ++ @tracer.start_as_current_span("predict") ++ def predict(self, batch: PaddedBatch) -> List[Prediction]: ++ print("embedding model is not support predict") ++ +\ No newline at end of file +diff -uprN text-embeddings-inference-1.2.3/backends/python/server/text_embeddings_server/models/model.py text-embeddings-inference/backends/python/server/text_embeddings_server/models/model.py +--- text-embeddings-inference-1.2.3/backends/python/server/text_embeddings_server/models/model.py 2024-04-25 08:47:38.000000000 +0000 ++++ text-embeddings-inference/backends/python/server/text_embeddings_server/models/model.py 2024-08-20 06:45:23.590000000 +0000 +@@ -3,7 +3,7 @@ import torch + from abc import ABC, abstractmethod + from typing import List, TypeVar, Type + +-from text_embeddings_server.models.types import Batch, Embedding ++from text_embeddings_server.models.types import Batch, Embedding, Prediction, TokenEmbedding + + B = TypeVar("B", bound=Batch) + +@@ -27,3 +27,11 @@ class Model(ABC): + @abstractmethod + def embed(self, batch: B) -> List[Embedding]: + raise NotImplementedError ++ ++ 
@abstractmethod ++ def embed_all(self, batch: B) -> List[TokenEmbedding]: ++ raise NotImplementedError ++ ++ @abstractmethod ++ def predict(self, batch: B) -> List[Prediction]: ++ raise NotImplementedError +diff -uprN text-embeddings-inference-1.2.3/backends/python/server/text_embeddings_server/models/rerank_model.py text-embeddings-inference/backends/python/server/text_embeddings_server/models/rerank_model.py +--- text-embeddings-inference-1.2.3/backends/python/server/text_embeddings_server/models/rerank_model.py 1970-01-01 00:00:00.000000000 +0000 ++++ text-embeddings-inference/backends/python/server/text_embeddings_server/models/rerank_model.py 2024-08-20 06:45:23.590000000 +0000 +@@ -0,0 +1,57 @@ ++import inspect ++import torch ++import torch_npu ++from pathlib import Path ++from typing import Type, List ++from transformers import AutoModel, AutoModelForSequenceClassification ++from opentelemetry import trace ++ ++from text_embeddings_server.models import Model ++from text_embeddings_server.models.types import PaddedBatch, Embedding, Prediction, TokenEmbedding ++ ++tracer = trace.get_tracer(__name__) ++ ++ ++class RerankModel(Model): ++ def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype): ++ model = AutoModelForSequenceClassification.from_pretrained(model_path).to(dtype).to(device).eval() ++ ++ self.has_position_ids = ( ++ inspect.signature(model.forward).parameters.get("position_ids", None) ++ is not None ++ ) ++ self.has_token_type_ids = ( ++ inspect.signature(model.forward).parameters.get("token_type_ids", None) ++ is not None ++ ) ++ ++ super(RerankModel, self).__init__(model=model, dtype=dtype, device=device) ++ ++ @property ++ def batch_type(self) -> Type[PaddedBatch]: ++ return PaddedBatch ++ ++ @tracer.start_as_current_span("embed") ++ def embed(self, batch: PaddedBatch) -> List[Embedding]: ++ print("rerank model is not support embed") ++ ++ @tracer.start_as_current_span("embed_all") ++ def embed_all(self, batch: PaddedBatch) -> List[TokenEmbedding]: ++ print("rerank model is not support embed_all") ++ ++ @tracer.start_as_current_span("predict") ++ def predict(self, batch: PaddedBatch) -> List[Prediction]: ++ kwargs = {"input_ids": batch.input_ids, "attention_mask": batch.attention_mask} ++ if self.has_token_type_ids: ++ kwargs["token_type_ids"] = batch.token_type_ids ++ if self.has_position_ids: ++ kwargs["position_ids"] = batch.position_ids ++ ++ scores = self.model(**kwargs).logits.view(-1,).tolist() ++ ++ return [ ++ Prediction( ++ values=scores[i:i+1] ++ ) ++ for i in range(len(batch)) ++ ] +diff -uprN text-embeddings-inference-1.2.3/backends/python/server/text_embeddings_server/models/types.py text-embeddings-inference/backends/python/server/text_embeddings_server/models/types.py +--- text-embeddings-inference-1.2.3/backends/python/server/text_embeddings_server/models/types.py 2024-04-25 08:47:38.000000000 +0000 ++++ text-embeddings-inference/backends/python/server/text_embeddings_server/models/types.py 2024-08-20 06:45:23.590000000 +0000 +@@ -5,7 +5,7 @@ from dataclasses import dataclass + from opentelemetry import trace + + from text_embeddings_server.pb import embed_pb2 +-from text_embeddings_server.pb.embed_pb2 import Embedding ++from text_embeddings_server.pb.embed_pb2 import Embedding, Prediction, TokenEmbedding + + tracer = trace.get_tracer(__name__) + +@@ -27,6 +27,7 @@ class PaddedBatch(Batch): + token_type_ids: torch.Tensor + position_ids: torch.Tensor + attention_mask: torch.Tensor ++ max_length: int + + @classmethod + 
@tracer.start_as_current_span("from_pb") +@@ -35,6 +36,7 @@ class PaddedBatch(Batch): + all_tensors = torch.zeros( + [4, len(pb.cu_seq_lengths) - 1, pb.max_length], dtype=torch.int32 + ) ++ max_length=pb.max_length + + for i, start_index in enumerate(pb.cu_seq_lengths[:-1]): + end_index = pb.cu_seq_lengths[i + 1] +@@ -59,6 +61,7 @@ class PaddedBatch(Batch): + token_type_ids=all_tensors[1], + position_ids=all_tensors[2], + attention_mask=all_tensors[3], ++ max_length=max_length, + ) + + def __len__(self): +diff -uprN text-embeddings-inference-1.2.3/backends/python/server/text_embeddings_server/server.py text-embeddings-inference/backends/python/server/text_embeddings_server/server.py +--- text-embeddings-inference-1.2.3/backends/python/server/text_embeddings_server/server.py 2024-04-25 08:47:38.000000000 +0000 ++++ text-embeddings-inference/backends/python/server/text_embeddings_server/server.py 2024-08-20 06:45:23.590000000 +0000 +@@ -31,6 +31,20 @@ class EmbeddingService(embed_pb2_grpc.Em + embeddings = self.model.embed(batch) + + return embed_pb2.EmbedResponse(embeddings=embeddings) ++ ++ async def Embed_all(self, request, context): ++ batch = self.model.batch_type.from_pb(request, self.model.device) ++ ++ embeddings = self.model.embed_all(batch) ++ ++ return embed_pb2.RawEmbedResponse(allembeddings=embeddings) ++ ++ async def Predict(self, request, context): ++ batch = self.model.batch_type.from_pb(request, self.model.device) ++ ++ predictions = self.model.predict(batch) ++ ++ return embed_pb2.PredictResponse(predictions=predictions) + + + def serve( +diff -uprN text-embeddings-inference-1.2.3/backends/python/src/lib.rs text-embeddings-inference/backends/python/src/lib.rs +--- text-embeddings-inference-1.2.3/backends/python/src/lib.rs 2024-04-25 08:47:38.000000000 +0000 ++++ text-embeddings-inference/backends/python/src/lib.rs 2024-08-20 06:45:23.590000000 +0000 +@@ -25,9 +25,8 @@ impl PythonBackend { + ) -> Result { + match model_type { + ModelType::Classifier => { +- return Err(BackendError::Start( +- "`classifier` model type is not supported".to_string(), +- )) ++ let pool = Pool::Cls; ++ pool + } + ModelType::Embedding(pool) => { + if pool != Pool::Cls { +@@ -75,16 +74,62 @@ impl Backend for PythonBackend { + } + + fn embed(&self, batch: Batch) -> Result { +- if !batch.raw_indices.is_empty() { +- return Err(BackendError::Inference( +- "raw embeddings are not supported for the Python backend.".to_string(), +- )); ++ let batch_size = batch.len(); ++ ++ let mut embeddings = ++ HashMap::with_capacity_and_hasher(batch_size, BuildNoHashHasher::default()); ++ ++ if !batch.pooled_indices.is_empty() { ++ let results = self ++ .tokio_runtime ++ .block_on(self.backend_client.clone().embed( ++ batch.input_ids, ++ batch.token_type_ids, ++ batch.position_ids, ++ batch.cumulative_seq_lengths, ++ batch.max_length, ++ )) ++ .map_err(|err| BackendError::Inference(err.to_string()))?; ++ ++ let pooled_embeddings: Vec> = results.into_iter().map(|r| r.values).collect(); ++ for (i, e) in pooled_embeddings.into_iter().enumerate() { ++ embeddings.insert(i, Embedding::Pooled(e)); ++ } ++ } ++ else if !batch.raw_indices.is_empty() { ++ let results = self ++ .tokio_runtime ++ .block_on(self.backend_client.clone().embed_all( ++ batch.input_ids, ++ batch.token_type_ids, ++ batch.position_ids, ++ batch.cumulative_seq_lengths, ++ batch.max_length, ++ )) ++ .map_err(|err| BackendError::Inference(err.to_string()))?; ++ ++ let mut raw_embeddings = Vec::new(); ++ for token_embedding in results { ++ let mut 
two_dim_list = Vec::new(); ++ for embeddings in token_embedding.embeddings { ++ let values = embeddings.values.clone(); ++ two_dim_list.push(values); ++ } ++ raw_embeddings.push(two_dim_list); ++ } ++ for (i, e) in raw_embeddings.into_iter().enumerate() { ++ embeddings.insert(i, Embedding::All(e)); ++ } + } ++ Ok(embeddings) ++ } ++ ++ fn predict(&self, batch: Batch) -> Result { + let batch_size = batch.len(); + + let results = self + .tokio_runtime +- .block_on(self.backend_client.clone().embed( ++ .block_on(self.backend_client.clone().predict( + batch.input_ids, + batch.token_type_ids, + batch.position_ids, +@@ -92,20 +137,15 @@ impl Backend for PythonBackend { + batch.max_length, + )) + .map_err(|err| BackendError::Inference(err.to_string()))?; +- let pooled_embeddings: Vec> = results.into_iter().map(|r| r.values).collect(); ++ ++ let predictions_result: Vec> = results.into_iter().map(|r| r.values).collect(); + +- let mut embeddings = ++ let mut predictions = + HashMap::with_capacity_and_hasher(batch_size, BuildNoHashHasher::default()); +- for (i, e) in pooled_embeddings.into_iter().enumerate() { +- embeddings.insert(i, Embedding::Pooled(e)); ++ for (i, r) in predictions_result.into_iter().enumerate() { ++ predictions.insert(i, r); + } + +- Ok(embeddings) +- } +- +- fn predict(&self, _batch: Batch) -> Result { +- Err(BackendError::Inference( +- "`predict` is not implemented".to_string(), +- )) ++ Ok(predictions) + } + } +diff -uprN text-embeddings-inference-1.2.3/backends/src/dtype.rs text-embeddings-inference/backends/src/dtype.rs +--- text-embeddings-inference-1.2.3/backends/src/dtype.rs 2024-04-25 08:47:38.000000000 +0000 ++++ text-embeddings-inference/backends/src/dtype.rs 2024-08-20 07:28:35.800000000 +0000 +@@ -15,6 +15,8 @@ pub enum DType { + // Float32 is not available on candle cuda + #[cfg(any(feature = "python", feature = "candle"))] + Float32, ++ #[cfg(any(feature = "python", feature = "candle"))] ++ BFloat16, + // #[cfg(feature = "candle")] + // Q6K, + } +@@ -31,6 +33,8 @@ impl fmt::Display for DType { + // Float32 is not available on candle cuda + #[cfg(any(feature = "python", feature = "candle"))] + DType::Float32 => write!(f, "float32"), ++ #[cfg(any(feature = "python", feature = "candle"))] ++ DType::BFloat16 => write!(f, "bfloat16"), + // #[cfg(feature = "candle")] + // DType::Q6K => write!(f, "q6k"), + } +diff -uprN text-embeddings-inference-1.2.3/core/src/infer.rs text-embeddings-inference/core/src/infer.rs +--- text-embeddings-inference-1.2.3/core/src/infer.rs 2024-04-25 08:47:38.000000000 +0000 ++++ text-embeddings-inference/core/src/infer.rs 2024-08-20 06:45:23.590000000 +0000 +@@ -37,11 +37,6 @@ impl Infer { + notify_batching_task.clone(), + embed_sender.clone(), + )); +- tokio::spawn(batching_task( +- queue.clone(), +- notify_batching_task.clone(), +- embed_sender, +- )); + + // Create embed task to communicate with backend + tokio::spawn(backend_task(backend.clone(), embed_receiver)); +diff -uprN text-embeddings-inference-1.2.3/docs/source/en/cli_arguments.md text-embeddings-inference/docs/source/en/cli_arguments.md +--- text-embeddings-inference-1.2.3/docs/source/en/cli_arguments.md 2024-04-25 08:47:38.000000000 +0000 ++++ text-embeddings-inference/docs/source/en/cli_arguments.md 2024-08-20 07:28:35.800000000 +0000 +@@ -50,7 +50,7 @@ Options: + The dtype to be forced upon the model + + [env: DTYPE=] +- [possible values: float16, float32] ++ [possible values: float16, float32, bfloat16] + + --pooling + Optionally control the pooling method for embedding models. 
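The batch wire format used throughout this patch (flattened token ids plus cumulative sequence lengths `cu_seq_lengths` and a `max_length`, as in EmbedRequest/PredictRequest and the `PaddedBatch.from_pb` change above) can be hard to picture from the hunks alone. The following standalone sketch is an editorial illustration with made-up token ids, not code from the patch; it shows how such a flattened batch is scattered into a zero-padded tensor and an attention mask.

```python
# Sketch of the padded-batch layout: flattened ids + cu_seq_lengths -> [batch, max_length] tensors.
import torch


def pad_batch(input_ids, cu_seq_lengths, max_length):
    batch_size = len(cu_seq_lengths) - 1
    padded = torch.zeros(batch_size, max_length, dtype=torch.int32)
    attention_mask = torch.zeros(batch_size, max_length, dtype=torch.int32)
    for i, start in enumerate(cu_seq_lengths[:-1]):
        end = cu_seq_lengths[i + 1]
        length = end - start
        # Copy sequence i into row i; remaining positions stay zero (padding).
        padded[i, :length] = torch.tensor(input_ids[start:end], dtype=torch.int32)
        attention_mask[i, :length] = 1
    return padded, attention_mask


# Example: two sequences of length 3 and 2, padded to max_length = 4.
ids = [101, 2054, 102, 101, 102]
padded, mask = pad_batch(ids, cu_seq_lengths=[0, 3, 5], max_length=4)
print(padded)
print(mask)
```

This roughly mirrors how the patched types.py scatters `pb.input_ids` into `[batch, max_length]` tensors before DefaultModel and RerankModel consume them.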
+diff -uprN text-embeddings-inference-1.2.3/router/src/lib.rs text-embeddings-inference/router/src/lib.rs +--- text-embeddings-inference-1.2.3/router/src/lib.rs 2024-04-25 08:47:38.000000000 +0000 ++++ text-embeddings-inference/router/src/lib.rs 2024-08-20 07:28:35.800000000 +0000 +@@ -188,6 +188,10 @@ pub async fn run( + { + DType::Float16 + } ++ #[cfg(any(feature = "accelerate", feature = "mkl", feature = "mkl-dynamic"))] ++ { ++ DType::BFloat16 ++ } + }); + + // Create backend diff --git a/mxRAG/MainRepo/patches/TEI/tei_patch.sh b/mxRAG/MainRepo/patches/TEI/tei_patch.sh new file mode 100644 index 0000000000000000000000000000000000000000..e7ac5a8cd31e316ad1556518df1f8893a07bf6c8 --- /dev/null +++ b/mxRAG/MainRepo/patches/TEI/tei_patch.sh @@ -0,0 +1,13 @@ +#!/bin/bash +set -e + +PATCH_DIR=$(cd $(dirname $0); pwd) + +# 1. TEI打patch +cd $1 +patch -p1 < $PATCH_DIR/tei/tei.patch + +# 2. transformers打patch +TRANSFORMER_PACKAGE_PATH=$(python3 -c 'import transformers; import os; print(os.path.dirname(transformers.__file__))') +patch -p1 $TRANSFORMER_PACKAGE_PATH/models/bert/modeling_bert.py < $PATCH_DIR/ops/transformers_bert.patch +patch -p1 $TRANSFORMER_PACKAGE_PATH/models/xlm_roberta/modeling_xlm_roberta.py < $PATCH_DIR/ops/transformers_xlm_roberta.patch \ No newline at end of file diff --git a/mxRAG/MainRepo/patches/whisper/README.md b/mxRAG/MainRepo/patches/whisper/README.md new file mode 100644 index 0000000000000000000000000000000000000000..059454cd7eb41a4a0b0e6b463719bc16e2d3632d --- /dev/null +++ b/mxRAG/MainRepo/patches/whisper/README.md @@ -0,0 +1,66 @@ +# 安装openai-whisper补丁说明 + +## 安装环境准备 + +参考:https://gitee.com/ascend/ModelZoo-PyTorch/tree/master/MindIE/MindIE-Torch/built-in/audio/Whisper + +| 配套 | 版本要求 | +|---------|---------| +| CANN | 8.0.RC2 | +| MindIE | 1.0.RC2 | +| Python | 3.10.X | +| PyTorch | 2.1.0 | +| ffmpeg | 4.2.7 | +| onnx | 1.16.1 | + + +1.安装MindIE前需要先source toolkit的环境变量,然后直接安装,以默认安装路径/usr/local/Ascend为例: +```sh +source /usr/local/Ascend/ascend-toolkit/set_env.sh +bash Ascend-mindie_*.run --install +``` +2.ubuntu下可通过apt-get install ffmpeg命令安装ffmpeg + +## 安装补丁步骤 +1.进入patches/whisper目录,下载zh.wav音频文件。 +```sh +wget https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav +``` +2.在patches/whisper目录下进行补丁操作。 +```sh +bash whisper_patch.sh +``` +注意事项: +1.参考设计默认的模型为tiny,运行的设备为Ascend310P3,使用的DEVICE为0,如需更换模型、设备类型、device对应修改whisper_patch.sh中MODEL_NAME、SOC_VERSION、DEVICE参数后再执行补丁操作。 + +## 模型推理 +1.命令行调用 +```sh +whisper zh.wav --model tiny +``` +其中:zh.wav表示待转译的音频文件,支持的音频类型包括M4A、MP3、MP4、MPEG、MPGA、WAV 和 WEBM等; +tiny表示使用的模型,支持tiny、base、small、medium、large模型。 + +2.Python调用 + +```python +from whisper import load_model +from whisper.transcribe import transcribe +# 模型加载 +model = load_model('tiny') +# 音频转译 +result = transcribe(model, audio="zh.wav", verbose=False, beam_size=5, temperature=0) +print(result['text']) +``` +3.运行结果 +```commandline +"我認爲跑步最重要的事就是給我帶來了身體健康" +``` +## 参数说明 +whisper接口功能请参考openai-whisper官方接口:https://github.com/openai/whisper + +因模型经过npu适配重新编译,使用时需注意以下两点: + +1.whisper.load_model为模型加载方法,使用时参数name填写经过本地向量化的模型名称,如需更换模型请重新编译后再使用;参数download_root默认为编译模型的导出路径,使用时无需填写。 + +2.whisper.transcribe.transcribe为模型转译方法,当temperature使用默认值0时需声明beam_size=5,当temperature使用其他非0值时需声明best_of=5。 diff --git a/mxRAG/MainRepo/patches/whisper/compile.py b/mxRAG/MainRepo/patches/whisper/compile.py new file mode 100644 index 0000000000000000000000000000000000000000..513f2b0af814d6234a5af8dc6d27040aea7ffa50 --- /dev/null +++ b/mxRAG/MainRepo/patches/whisper/compile.py @@ -0,0 +1,133 @@ +# Copyright 2024 
Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse + +import torch +import mindietorch + +_FRAMES = 3000 +_HALF_FRAMES = 1500 +_MAX_TOKEN = 448 +_KV_NUM = 2 + +def parse_args(): + parser = argparse.ArgumentParser(description="mindietorch model compilation") + parser.add_argument("--model_path", default="/tmp/models") + parser.add_argument("--beam_size", type=int, default=5) + parser.add_argument("--nblocks", type=int, default=4) + parser.add_argument("--hidden", type=int, default=384) + parser.add_argument("--n_mels", type=int, default=80) + parser.add_argument("--soc_version", default="Ascend310P3") + args = parser.parse_args() + return args + +def compile_and_save(ts_model, input_info, soc_version, save_path): + ts_model.eval() + mindie_model = mindietorch.compile( + ts_model, + inputs=input_info, + precision_policy=mindietorch._enums.PrecisionPolicy.FP16, + truncate_long_and_double=True, + allow_tensor_replace_int=True, + soc_version=soc_version, + optimization_level=0 + ) + mindie_model.save(save_path) + +def encoder(args): + ts_model = torch.jit.load(f"{args.model_path}/encoder.ts") + input_mel_info = mindietorch.Input([1, args.n_mels, _FRAMES]) + input_info = [input_mel_info] + save_path = f"{args.model_path}/encoder_compiled.ts" + compile_and_save(ts_model, input_info, args.soc_version, save_path) + +def language(args): + ts_model = torch.jit.load(f"{args.model_path}/decoder_prefill.ts") + input_tokens_info = mindietorch.Input([1, 1]) + input_audio_features_info = mindietorch.Input([1, _HALF_FRAMES, args.hidden]) + input_pos_embed_info = mindietorch.Input([1, args.hidden]) + input_info = [ + input_tokens_info, + input_audio_features_info, + input_pos_embed_info, + ] + save_path = f"{args.model_path}/language_detection_compiled.ts" + compile_and_save(ts_model, input_info, args.soc_version, save_path) + +def prefill(args): + ts_model = torch.jit.load(f"{args.model_path}/decoder_prefill.ts") + + input_tokens_info = mindietorch.Input( + min_shape=[args.beam_size, 1], + max_shape=[args.beam_size, _MAX_TOKEN] + ) + input_audio_features_info = mindietorch.Input( + min_shape=[1, _HALF_FRAMES, args.hidden], + max_shape=[1, _HALF_FRAMES, args.hidden] + ) + input_pos_embed_info = mindietorch.Input( + min_shape=[1, args.hidden], + max_shape=[_MAX_TOKEN, args.hidden] + ) + input_info = [ + input_tokens_info, + input_audio_features_info, + input_pos_embed_info, + ] + save_path = f"{args.model_path}/decoder_prefill_compiled.ts" + compile_and_save(ts_model, input_info, args.soc_version, save_path) + +def decode(args): + ts_model = torch.jit.load(f"{args.model_path}/decoder_decode.ts") + + input_tokens_info = mindietorch.Input( + min_shape=[args.beam_size, 1], + max_shape=[args.beam_size, 1] + ) + input_audio_features_info = mindietorch.Input( + min_shape=[1, _HALF_FRAMES, args.hidden], + max_shape=[1, _HALF_FRAMES, args.hidden] + ) + input_pos_embed_info = mindietorch.Input( + min_shape=[args.hidden], + max_shape=[args.hidden] + ) + input_cache_dyn_info = 
mindietorch.Input( + min_shape=(args.nblocks, _KV_NUM, args.beam_size, 1, args.hidden), + max_shape=(args.nblocks, _KV_NUM, args.beam_size, _MAX_TOKEN, args.hidden) + ) + input_cache_sta_info = mindietorch.Input( + min_shape=[args.nblocks, _KV_NUM, 1, _HALF_FRAMES, args.hidden], + max_shape=[args.nblocks, _KV_NUM, 1, _HALF_FRAMES, args.hidden] + ) + + input_info = [ + input_tokens_info, + input_audio_features_info, + input_pos_embed_info, + input_cache_dyn_info, + input_cache_sta_info + ] + + save_path = f"{args.model_path}/decoder_decode_compiled.ts" + compile_and_save(ts_model, input_info, args.soc_version, save_path) + +def main(): + args = parse_args() + for func in encoder, language, prefill, decode: + func(args) + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/mxRAG/MainRepo/patches/whisper/mindietorch_infer.patch b/mxRAG/MainRepo/patches/whisper/mindietorch_infer.patch new file mode 100644 index 0000000000000000000000000000000000000000..fc7f771847486b8542d41a2a54876304c481399e --- /dev/null +++ b/mxRAG/MainRepo/patches/whisper/mindietorch_infer.patch @@ -0,0 +1,226 @@ +diff --git a/whisper/decoding.py b/whisper/decoding.py +index 49485d0..4dccc86 100644 +--- a/whisper/decoding.py ++++ b/whisper/decoding.py +@@ -6,6 +6,7 @@ import torch + import torch.nn.functional as F + from torch import Tensor + from torch.distributions import Categorical ++import mindietorch + + from .audio import CHUNK_LENGTH + from .tokenizer import Tokenizer, get_tokenizer +@@ -14,6 +15,7 @@ from .utils import compression_ratio + if TYPE_CHECKING: + from .model import Whisper + ++mindietorch.set_device(0) + + @torch.no_grad() + def detect_language( +@@ -54,7 +56,7 @@ def detect_language( + # forward pass using a single token, startoftranscript + n_audio = mel.shape[0] + x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device) # [n_audio, 1] +- logits = model.logits(x, mel)[:, 0] ++ logits = model.logits(x, mel)[0][:, 0] + + # collect detected languages; suppress all non-language tokens + mask = torch.ones(logits.shape[-1], dtype=torch.bool) +@@ -145,36 +147,35 @@ class PyTorchInference(Inference): + def __init__(self, model: "Whisper", initial_token_length: int): + self.model: "Whisper" = model + self.initial_token_length = initial_token_length +- self.kv_cache = {} +- self.hooks = [] +- +- key_modules = [block.attn.key for block in self.model.decoder.blocks] +- value_modules = [block.attn.value for block in self.model.decoder.blocks] +- self.kv_modules = key_modules + value_modules ++ self.cache_dyn = None ++ self.cache_sta = None + + def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor: +- if not self.kv_cache: +- self.kv_cache, self.hooks = self.model.install_kv_cache_hooks() +- + if tokens.shape[-1] > self.initial_token_length: + # only need to use the last token except in the first forward pass + tokens = tokens[:, -1:] ++ pos_embed = self.model.decoder.positional_embedding[self.cache_dyn.shape[3]] ++ logits, cache_dyn, _ = self.model.decoder( ++ tokens, audio_features, pos_embed, self.cache_dyn, self.cache_sta) ++ self.cache_dyn = cache_dyn ++ else: ++ pos_embed = self.model.decoder.positional_embedding[:tokens.shape[-1]] ++ logits, cache_dyn, cache_sta = self.model.decoder(tokens, audio_features, pos_embed) ++ self.cache_dyn = cache_dyn ++ self.cache_sta = cache_sta + +- return self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache) ++ return logits + + def cleanup_caching(self): +- for hook in self.hooks: +- hook.remove() +- +- self.kv_cache = 
{} +- self.hooks = [] ++ self.cache_dyn = None ++ self.cache_sta = None + + def rearrange_kv_cache(self, source_indices): + if source_indices != list(range(len(source_indices))): +- for module in self.kv_modules: +- # update the key/value cache to contain the selected sequences +- self.kv_cache[module] = self.kv_cache[module][source_indices].detach() +- ++ blocks = self.cache_dyn.shape[0] ++ for i in range(blocks): ++ for j in range(2): # k and v 2 items ++ self.cache_dyn[i][j] = self.cache_dyn[i][j][source_indices] + + class SequenceRanker: + def rank( +diff --git a/whisper/model.py b/whisper/model.py +index a678283..c94a024 100644 +--- a/whisper/model.py ++++ b/whisper/model.py +@@ -1,12 +1,14 @@ + import base64 + import gzip + from dataclasses import dataclass ++import os + from typing import Dict, Iterable, Optional + + import numpy as np + import torch + import torch.nn.functional as F + from torch import Tensor, nn ++import mindietorch + + from .decoding import decode as decode_function + from .decoding import detect_language as detect_language_function +@@ -153,24 +155,19 @@ class AudioEncoder(nn.Module): + [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)] + ) + self.ln_post = LayerNorm(n_state) ++ self.device = "npu:0" ++ self.mindietorch_encoder_model = torch.jit.load( ++ "/tmp/models/encoder_compiled.ts" ++ ).eval().to(self.device) + + def forward(self, x: Tensor): + """ + x : torch.Tensor, shape = (batch_size, n_mels, n_ctx) + the mel spectrogram of the audio + """ +- x = F.gelu(self.conv1(x)) +- x = F.gelu(self.conv2(x)) +- x = x.permute(0, 2, 1) +- +- assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape" +- x = (x + self.positional_embedding).to(x.dtype) +- +- for block in self.blocks: +- x = block(x) +- +- x = self.ln_post(x) +- return x ++ x = x.to(self.device) ++ x = self.mindietorch_encoder_model(x) ++ return x.cpu() + + + class TextDecoder(nn.Module): +@@ -193,29 +190,58 @@ class TextDecoder(nn.Module): + mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1) + self.register_buffer("mask", mask, persistent=False) + +- def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None): +- """ +- x : torch.LongTensor, shape = (batch_size, <= n_ctx) +- the text tokens +- xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state) +- the encoded audio features to be attended on +- """ +- offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0 +- x = ( +- self.token_embedding(x) +- + self.positional_embedding[offset : offset + x.shape[-1]] +- ) +- x = x.to(xa.dtype) +- +- for block in self.blocks: +- x = block(x, xa, mask=self.mask, kv_cache=kv_cache) ++ self.device = "npu:0" ++ self.mindietorch_language_detection_model = torch.jit.load( ++ "/tmp/models/language_detection_compiled.ts" ++ ).eval().to(self.device) ++ self.mindietorch_prefill_model = torch.jit.load( ++ "/tmp/models/decoder_prefill_compiled.ts" ++ ).eval().to(self.device) ++ self.mindietorch_decode_model = torch.jit.load( ++ "/tmp/models/decoder_decode_compiled.ts" ++ ).eval().to(self.device) + +- x = self.ln(x) +- logits = ( +- x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1) +- ).float() +- +- return logits ++ def forward( ++ self, ++ x: Tensor, ++ xa: Tensor, ++ pos_embed: Tensor = None, ++ cache_dyn: Tensor = None, ++ cache_sta: Tensor = None, ++ ): ++ if cache_dyn is None: ++ tokens_npu = x.float().to(self.device) ++ audio_features_npu = xa.to(self.device) ++ pos_embed_npu = pos_embed.to(self.device) ++ if x.shape[0] != 1: 
++ logits, cache_dyn, cache_sta = self.mindietorch_prefill_model( ++ tokens_npu, ++ audio_features_npu, ++ pos_embed_npu ++ ) ++ else: ++ logits, cache_dyn, cache_sta = self.mindietorch_language_detection_model( ++ tokens_npu, ++ audio_features_npu, ++ pos_embed_npu ++ ) ++ logits = logits.cpu() ++ cache_dyn = cache_dyn.cpu() ++ else: ++ tokens_npu = x.float().to(self.device) ++ audio_features_npu = xa.to(self.device) ++ pos_embed_npu = pos_embed.to(self.device) ++ cache_dyn_npu = cache_dyn.to(self.device) ++ logits, cache_dyn, _ = self.mindietorch_decode_model( ++ tokens_npu, ++ audio_features_npu, ++ pos_embed_npu, ++ cache_dyn_npu, ++ cache_sta ++ ) ++ logits = logits.cpu() ++ cache_dyn = cache_dyn.cpu() ++ return logits, cache_dyn, cache_sta + + + class Whisper(nn.Module): +@@ -257,7 +283,8 @@ class Whisper(nn.Module): + return self.encoder(mel) + + def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor): +- return self.decoder(tokens, audio_features) ++ pos_embed = self.decoder.positional_embedding[:tokens.shape[-1]] ++ return self.decoder(tokens, audio_features, pos_embed) + + def forward( + self, mel: torch.Tensor, tokens: torch.Tensor diff --git a/mxRAG/MainRepo/patches/whisper/trace_model.patch b/mxRAG/MainRepo/patches/whisper/trace_model.patch new file mode 100644 index 0000000000000000000000000000000000000000..903444eddb385a3f8e8ed4e09e0c0eeeeea1f9e0 --- /dev/null +++ b/mxRAG/MainRepo/patches/whisper/trace_model.patch @@ -0,0 +1,343 @@ +diff --git a/whisper/decoding.py b/whisper/decoding.py +index 49485d0..495fe45 100644 +--- a/whisper/decoding.py ++++ b/whisper/decoding.py +@@ -2,6 +2,7 @@ from dataclasses import dataclass, field, replace + from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union + + import numpy as np ++import os + import torch + import torch.nn.functional as F + from torch import Tensor +@@ -49,12 +50,24 @@ def detect_language( + + # skip encoder forward pass if already-encoded audio features were given + if mel.shape[-2:] != (model.dims.n_audio_ctx, model.dims.n_audio_state): ++ encoder_ts_model = torch.jit.trace(model.encoder, mel) ++ encoder_ts_model.save( ++ "/tmp/models/encoder.ts") ++ torch.onnx.export( ++ model.encoder, ++ (mel), ++ "/tmp/models/onnx/encode/encoder.onnx", ++ opset_version=11, ++ input_names=["mel"], ++ output_names=["ret"] ++ ) ++ + mel = model.encoder(mel) + + # forward pass using a single token, startoftranscript + n_audio = mel.shape[0] + x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device) # [n_audio, 1] +- logits = model.logits(x, mel)[:, 0] ++ logits = model.logits(x, mel)[0][:, 0] + + # collect detected languages; suppress all non-language tokens + mask = torch.ones(logits.shape[-1], dtype=torch.bool) +@@ -145,36 +158,74 @@ class PyTorchInference(Inference): + def __init__(self, model: "Whisper", initial_token_length: int): + self.model: "Whisper" = model + self.initial_token_length = initial_token_length +- self.kv_cache = {} +- self.hooks = [] +- +- key_modules = [block.attn.key for block in self.model.decoder.blocks] +- value_modules = [block.attn.value for block in self.model.decoder.blocks] +- self.kv_modules = key_modules + value_modules ++ self.cache_dyn = None ++ self.cache_sta = None + + def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor: +- if not self.kv_cache: +- self.kv_cache, self.hooks = self.model.install_kv_cache_hooks() +- + if tokens.shape[-1] > self.initial_token_length: + # only need to use the last token except in the first forward pass + 
tokens = tokens[:, -1:] ++ pos_embed = self.model.decoder.positional_embedding[self.cache_dyn.shape[3]] ++ torch.onnx.export( ++ self.model.decoder, ++ (tokens, audio_features, pos_embed, self.cache_dyn, self.cache_sta), ++ "/tmp/models/onnx/decode/decoder_decode.onnx", ++ opset_version=11, ++ input_names=["tokens", "audio_features", "pos_embed", "cache_dyn", "cache_sta"], ++ output_names=["logits", "new_cache_dyn", "new_cache_sta"], ++ dynamic_axes={ ++ "cache_dyn": {3: "ntokens"}, ++ "new_cache_dyn": {3: "ntokens"} ++ } ++ ) ++ decoder_decode_ts_model = torch.jit.trace( ++ self.model.decoder, ++ (tokens, audio_features, pos_embed, self.cache_dyn, self.cache_sta) ++ ) ++ decoder_decode_ts_model.save( ++ "/tmp/models/decoder_decode.ts") ++ logits, cache_dyn, _ = self.model.decoder( ++ tokens, audio_features, pos_embed, self.cache_dyn, self.cache_sta) ++ os.sys.exit(0) ++ self.cache_dyn = cache_dyn ++ else: ++ pos_embed = self.model.decoder.positional_embedding[:tokens.shape[-1]] ++ torch.onnx.export( ++ self.model.decoder, ++ (tokens, audio_features, pos_embed), ++ "/tmp/models/onnx/prefill/decoder_prefill.onnx", ++ opset_version=11, ++ input_names=["tokens", "audio_features", "pos_embed"], ++ output_names=["logits", "cache_dyn", "cache_sta"], ++ dynamic_axes={ ++ "tokens": {1: "ntokens"}, ++ "pos_embed": {0: "ntokens"}, ++ "logits": {1: "ntokens"}, ++ "cache_dyn": {3: "ntokens"} ++ } ++ ) ++ decoder_prefill_ts_model = torch.jit.trace( ++ self.model.decoder, ++ (tokens, audio_features, pos_embed) ++ ) ++ decoder_prefill_ts_model.save( ++ "/tmp/models/decoder_prefill.ts") ++ logits, cache_dyn, cache_sta = self.model.decoder(tokens, audio_features, pos_embed) ++ self.cache_dyn = cache_dyn ++ self.cache_sta = cache_sta + +- return self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache) ++ return logits + + def cleanup_caching(self): +- for hook in self.hooks: +- hook.remove() +- +- self.kv_cache = {} +- self.hooks = [] ++ self.cache_dyn = None ++ self.cache_sta = None + + def rearrange_kv_cache(self, source_indices): + if source_indices != list(range(len(source_indices))): +- for module in self.kv_modules: +- # update the key/value cache to contain the selected sequences +- self.kv_cache[module] = self.kv_cache[module][source_indices].detach() +- ++ blocks = self.cache_dyn.shape[0] ++ for i in range(blocks): ++ for j in range(2): # k and v 2 items ++ self.cache_dyn[i][j] = self.cache_dyn[i][j][source_indices] + + class SequenceRanker: + def rank( +diff --git a/whisper/model.py b/whisper/model.py +index a678283..2a95e28 100644 +--- a/whisper/model.py ++++ b/whisper/model.py +@@ -1,6 +1,7 @@ + import base64 + import gzip + from dataclasses import dataclass ++import os + from typing import Dict, Iterable, Optional + + import numpy as np +@@ -68,6 +69,63 @@ class MultiHeadAttention(nn.Module): + self.value = Linear(n_state, n_state) + self.out = Linear(n_state, n_state) + ++ def encoder_forward(self, x: Tensor): ++ q = self.query(x) ++ k = self.key(x) ++ v = self.value(x) ++ wv, qk = self.qkv_attention(q, k, v) ++ return self.out(wv) ++ ++ def prefill_self_attn_forward( ++ self, ++ x: Tensor, ++ mask: Tensor, ++ ): ++ q = self.query(x) ++ k = self.key(x) ++ v = self.value(x) ++ cache_dyn = torch.stack([k, v]) ++ wv, _ = self.qkv_attention(q, k, v, mask) ++ return self.out(wv), cache_dyn ++ ++ def prefill_cross_attn_forward( ++ self, ++ x: Tensor, ++ xa: Tensor, ++ ): ++ q = self.query(x) ++ k = self.key(xa) ++ v = self.value(xa) ++ cache_sta = torch.stack([k, v]) ++ wv, _ = 
self.qkv_attention(q, k, v) ++ return self.out(wv), cache_sta ++ ++ def decode_self_attn_forward( ++ self, ++ x: Tensor, ++ mask: Tensor, ++ cache_dyn: Tensor ++ ): ++ q = self.query(x) ++ token_k = self.key(x) ++ k = torch.cat([cache_dyn[0], token_k], dim=1).detach() ++ token_v = self.value(x) ++ v = torch.cat([cache_dyn[1], token_v], dim=1).detach() ++ new_cache_dyn = torch.stack([k, v]) ++ wv, _ = self.qkv_attention(q, k, v, mask) ++ return self.out(wv), new_cache_dyn ++ ++ def decode_cross_attn_forward( ++ self, ++ x: Tensor, ++ cache_sta: Tensor ++ ): ++ q = self.query(x) ++ k = cache_sta[0] ++ v = cache_sta[1] ++ wv, _ = self.qkv_attention(q, k, v) ++ return self.out(wv) ++ + def forward( + self, + x: Tensor, +@@ -126,6 +184,39 @@ class ResidualAttentionBlock(nn.Module): + ) + self.mlp_ln = LayerNorm(n_state) + ++ def encoder_forward(self, x: Tensor): ++ x = x + self.attn.encoder_forward(self.attn_ln(x)) ++ x = x + self.mlp(self.mlp_ln(x)) ++ return x ++ ++ def prefill_forward( ++ self, ++ x: Tensor, ++ xa: Tensor, ++ mask: Tensor, ++ ): ++ self_attn_out, new_cache_dyn = self.attn.prefill_self_attn_forward(self.attn_ln(x), mask) ++ x = x + self_attn_out ++ cross_attn_out, new_cache_sta = self.cross_attn.prefill_cross_attn_forward(self.cross_attn_ln(x), xa) ++ x = x + cross_attn_out ++ x = x + self.mlp(self.mlp_ln(x)) ++ return x, new_cache_dyn, new_cache_sta ++ ++ def decode_forward( ++ self, ++ x: Tensor, ++ xa: Tensor, ++ mask: Tensor, ++ cache_dyn: Tensor, ++ cache_sta: Tensor ++ ): ++ self_attn_out, new_cache_dyn = self.attn.decode_self_attn_forward(self.attn_ln(x), mask, cache_dyn) ++ x = x + self_attn_out ++ cross_attn_out = self.cross_attn.decode_cross_attn_forward(self.cross_attn_ln(x), cache_sta) ++ x = x + cross_attn_out ++ x = x + self.mlp(self.mlp_ln(x)) ++ return x, new_cache_dyn ++ + def forward( + self, + x: Tensor, +@@ -163,11 +254,10 @@ class AudioEncoder(nn.Module): + x = F.gelu(self.conv2(x)) + x = x.permute(0, 2, 1) + +- assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape" + x = (x + self.positional_embedding).to(x.dtype) + + for block in self.blocks: +- x = block(x) ++ x = block.encoder_forward(x) + + x = self.ln_post(x) + return x +@@ -193,29 +283,56 @@ class TextDecoder(nn.Module): + mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1) + self.register_buffer("mask", mask, persistent=False) + +- def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None): +- """ +- x : torch.LongTensor, shape = (batch_size, <= n_ctx) +- the text tokens +- xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state) +- the encoded audio features to be attended on +- """ +- offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0 +- x = ( +- self.token_embedding(x) +- + self.positional_embedding[offset : offset + x.shape[-1]] +- ) +- x = x.to(xa.dtype) ++ def prefill(self, x: Tensor, xa: Tensor, pos_embed: Tensor): ++ x = (self.token_embedding(x) + pos_embed).to(xa.dtype) + ++ cache_dyn_list = [] ++ cache_sta_list = [] + for block in self.blocks: +- x = block(x, xa, mask=self.mask, kv_cache=kv_cache) ++ x, new_cache_dyn, new_cache_sta = block.prefill_forward(x, xa, self.mask) ++ cache_dyn_list.append(new_cache_dyn) ++ cache_sta_list.append(new_cache_sta) ++ ++ cache_dyn = torch.stack(cache_dyn_list) ++ cache_sta = torch.stack(cache_sta_list) + + x = self.ln(x) + logits = ( + x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1) + ).float() + +- return logits ++ return logits, cache_dyn, cache_sta ++ ++ def 
decode(self, x: Tensor, xa: Tensor, pos_embed: Tensor, cache_dyn: Tensor, cache_sta: Tensor): ++ x = (self.token_embedding(x) + pos_embed).to(xa.dtype) ++ ++ cache_dyn_list = [] ++ for idx, block in enumerate(self.blocks): ++ x, new_cache_dyn = block.decode_forward(x, xa, self.mask, cache_dyn[idx], cache_sta[idx]) ++ cache_dyn_list.append(new_cache_dyn) ++ ++ new_cache_dyn = torch.stack(cache_dyn_list) ++ ++ x = self.ln(x) ++ logits = ( ++ x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1) ++ ).float() ++ ++ return logits, new_cache_dyn ++ ++ def forward( ++ self, ++ x: Tensor, ++ xa: Tensor, ++ pos_embed: Tensor = None, ++ cache_dyn: Tensor = None, ++ cache_sta: Tensor = None, ++ ): ++ if cache_dyn is None: ++ logits, cache_dyn, cache_sta = self.prefill(x, xa, pos_embed) ++ else: ++ logits, cache_dyn = self.decode(x, xa, pos_embed, cache_dyn, cache_sta) ++ return logits, cache_dyn, cache_sta + + + class Whisper(nn.Module): +@@ -257,7 +374,8 @@ class Whisper(nn.Module): + return self.encoder(mel) + + def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor): +- return self.decoder(tokens, audio_features) ++ pos_embed = self.decoder.positional_embedding[:tokens.shape[-1]] ++ return self.decoder(tokens, audio_features, pos_embed) + + def forward( + self, mel: torch.Tensor, tokens: torch.Tensor diff --git a/mxRAG/MainRepo/patches/whisper/whisper_patch.sh b/mxRAG/MainRepo/patches/whisper/whisper_patch.sh new file mode 100644 index 0000000000000000000000000000000000000000..5779dca9fd023297d8d18623c01337d10518bd52 --- /dev/null +++ b/mxRAG/MainRepo/patches/whisper/whisper_patch.sh @@ -0,0 +1,82 @@ +#!/bin/bash +# Run this one-click patch script from the patches/whisper directory +set -e + +MODEL_NAME='tiny' +DEVICE=0 +SOC_VERSION="Ascend310P3" + +PATCH_DIR=$(dirname $(readlink -f $0)) +dist_packages=$(python3 -c "import site;print(site.getsitepackages()[0])") +dist_packages_path=$(echo "$dist_packages" | sed "s/[',\[\]]//g") +echo "dist-packages path is: $dist_packages_path" + +# Model parameters for each Whisper size: n_blocks hidden n_mels +declare -A params +params[tiny]="4 384 80" +params[base]="6 512 80" +params[small]="12 768 80" +params[medium]="24 1024 80" +params[large]="32 1280 128" + +# Extract the parameters for the selected model +IFS=' ' read -r -a model_params <<< "${params[$MODEL_NAME]}" +N_BLOCKS=${model_params[0]} +HIDDEN=${model_params[1]} +N_MEIS=${model_params[2]} + +# Inject the DEVICE index into mindietorch_infer.patch +sed -E -i "s/set_device\([0-9]+\)/set_device($DEVICE)/g" mindietorch_infer.patch +sed -E -i "s/npu:[0-9]+/npu:$DEVICE/g" mindietorch_infer.patch +echo "set device is: $DEVICE" + +function install_packages(){ + pip3 install onnx==1.16.1 + pip3 uninstall -y openai-whisper==20231117 + pip3 install openai-whisper==20231117 + +} + +function patch_trace_model(){ + cd $dist_packages_path + patch -p1 < $PATCH_DIR/trace_model.patch + cd $PATCH_DIR + DIRS=("/tmp/models" + "/tmp/models/onnx" + "/tmp/models/onnx/encode" + "/tmp/models/onnx/decode" + "/tmp/models/onnx/prefill" + ) + for dir in "${DIRS[@]}"; do + if [ ! -d "$dir" ]; then + mkdir -p "$dir" + echo "Directory $dir created." + else + echo "Directory $dir already exists." 
+ fi + done + whisper zh.wav --model $MODEL_NAME +} + +function compile_model(){ + source /usr/local/Ascend/ascend-toolkit/set_env.sh + source /usr/local/Ascend/mindie/set_env.sh + source /usr/local/Ascend/ascend-toolkit/set_env.sh + python3 compile.py --nblocks $N_BLOCKS --hidden $HIDDEN --n_mels $N_MEIS --soc_version $SOC_VERSION +} + +function patch_mindietorch(){ + pip3 uninstall -y openai-whisper==20231117 + pip3 install openai-whisper==20231117 + cd $dist_packages_path + patch -p1 < $PATCH_DIR/mindietorch_infer.patch +} + +function main(){ + install_packages + patch_trace_model + compile_model + patch_mindietorch +} + +main \ No newline at end of file