From 8d938ce3f643c178f962a60ba324cc1d9571d596 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E8=B5=B5=E6=B1=9F=E6=B1=9F?=
Date: Fri, 18 Apr 2025 17:03:59 +0800
Subject: [PATCH] feat: bge-m3

---
 .../built-in/embedding/bge-m3/README.md       |  80 +++++++
 .../built-in/embedding/bge-m3/infer.py        | 221 ++++++++++++++++++
 .../embedding/bge-m3/requirements.txt         |   8 +
 3 files changed, 309 insertions(+)
 create mode 100644 ACL_PyTorch/built-in/embedding/bge-m3/README.md
 create mode 100644 ACL_PyTorch/built-in/embedding/bge-m3/infer.py
 create mode 100644 ACL_PyTorch/built-in/embedding/bge-m3/requirements.txt

diff --git a/ACL_PyTorch/built-in/embedding/bge-m3/README.md b/ACL_PyTorch/built-in/embedding/bge-m3/README.md
new file mode 100644
index 0000000000..e139c5ebc2
--- /dev/null
+++ b/ACL_PyTorch/built-in/embedding/bge-m3/README.md
@@ -0,0 +1,80 @@
+# BGE-M3 Model Adaptation
+
+- [Overview](#overview)
+- [Inference Environment](#inference-environment)
+- [Quick Start](#quick-start)
+  - [Get the Source Code](#get-the-source-code)
+  - [Model Inference](#model-inference)
+    - [Run Inference](#run-inference)
+    - [Performance](#performance)
+
+******
+
+# Overview
+```BGE-M3``` is an advanced multilingual, multi-functional text embedding model from BAAI General Embedding. Built on a Transformer encoder, it adds sparse (lexical) retrieval and multi-vector retrieval, producing three kinds of semantic representations. It supports more than 100 languages and sequence lengths up to 8192 tokens, which makes it well suited to long documents. ```BGE-M3``` generates the three representations quickly and efficiently, and different combinations of them support multiple retrieval modes, so the model performs strongly on multilingual, cross-lingual, and long-text information retrieval, giving developers a practical tool.
+
+# Inference Environment
+- The model requires the following software and drivers.
+
+  **Table 1** Version compatibility
+
+  | Component                | Version     | Setup Guide                                                  |
+  |--------------------------|-------------|--------------------------------------------------------------|
+  | Firmware & Driver        | 25.0.RC1    | [PyTorch inference environment setup](https://www.hiascend.com/document/detail/zh/ModelZoo/pytorchframework/pies) |
+  | CANN                     | 8.1.RC1     | Includes the kernels and toolkit packages                    |
+  | Python                   | 3.10        | -                                                            |
+  | PyTorch                  | 2.5.1       | -                                                            |
+  | Ascend Extension PyTorch | 2.5.1.post2 | -                                                            |
+
+  Note: for the Atlas 800I A2 and Atlas 300I DUO inference cards, choose the actual firmware and driver versions according to your CANN version.
+
+# Quick Start
+
+## Get the Source Code
+1. Download the open-source model code and weights (optional).
+   > If your machine can download the weights and code directly from the Hugging Face Hub, skip this step.
+   ```
+   # Download the model with git; make sure git lfs is installed
+   git clone https://huggingface.co/BAAI/bge-m3
+   cd bge-m3
+   git reset --hard 5617a9f
+   ```
+   After the download completes, the local directory tree looks like this:
+   ```TEXT
+   bge-m3/
+   ├── colbert_linear.pt
+   ├── config.json
+   ├── config_sentence_transformers.json
+   ├── infer.py                # custom inference script provided by this repo
+   ├── modules.json
+   ├── pytorch_model.bin
+   ├── sentence_bert_config.json
+   ├── sentencepiece.bpe.model
+   ├── sparse_linear.pt
+   ├── special_tokens_map.json
+   ├── tokenizer.json
+   └── tokenizer_config.json
+   ```
+2. Install the dependencies.
+   ```SHELL
+   pip3 install FlagEmbedding transformers==4.51.1
+   ```
+   See `requirements.txt` for the remaining base dependencies.
+
+## Model Inference
+### Run Inference
+Set the environment variables and run the inference command:
+```SHELL
+# Select the NPU ID to use; defaults to 0
+export ASCEND_RT_VISIBLE_DEVICES=0
+# If you can download weights from the Hugging Face Hub directly, run:
+# python3 infer.py --model_path=BAAI/bge-m3
+# Optional flags: --warmup (default 4), --loop (default 10), --devices (default "['npu:0']")
+python3 infer.py  # use --model_path to point at the weight directory
+```
+After inference starts, the script first runs a warm-up by default. This triggers the first graph compilation, which takes a while. Once the warm-up finishes, the script runs the actual inference and prints the end-to-end (E2E) latency. To measure the model forward time, add timestamps around `outputs = self.model(...)` (line 423) in `YOUR_ENV/FlagEmbedding/inference/embedder/encoder_only/m3.py`.
+> YOUR_ENV is your current environment path; find it with ```pip show FlagEmbedding | grep Location```
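+
+A minimal timing sketch for that spot (it assumes `torch_npu` is already imported in m3.py so that `torch.npu.synchronize` is available; `self.model(...)` stands for the existing call at that line):
+```PYTHON
+import time
+
+torch.npu.synchronize()            # flush queued NPU work before timing
+forward_start = time.time()
+outputs = self.model(...)          # the existing call in m3.py
+torch.npu.synchronize()            # wait for the forward pass to finish
+print(f"forward time = {(time.time() - forward_start) * 1000:.2f}ms")
+```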
+
+### Performance
+| Model  | Chip           | E2E      | forward |
+|--------|----------------|----------|---------|
+| bge-m3 | Atlas 300I DUO | 137.59ms | 23.23ms |
+| bge-m3 | Atlas 800I A2  | 103.88ms | 14.71ms |
\ No newline at end of file
diff --git a/ACL_PyTorch/built-in/embedding/bge-m3/infer.py b/ACL_PyTorch/built-in/embedding/bge-m3/infer.py
new file mode 100644
index 0000000000..36a81e51df
--- /dev/null
+++ b/ACL_PyTorch/built-in/embedding/bge-m3/infer.py
@@ -0,0 +1,221 @@
+# Copyright (c) 2025 Huawei Technologies Co., Ltd
+# [Software Name] is licensed under Mulan PSL v2.
+# You can use this software according to the terms and conditions of the Mulan PSL v2.
+# You may obtain a copy of Mulan PSL v2 at:
+# http://license.coscl.org.cn/MulanPSL2
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
+# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
+# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
+# See the Mulan PSL v2 for more details.
+
+import ast
+import time
+import argparse
+from typing import Tuple, Optional
+
+import torch
+from torch import nn
+import torch_npu
+import torchair as tng
+from torchair.configs.compiler_config import CompilerConfig
+
+from FlagEmbedding import BGEM3FlagModel
+from transformers.models.xlm_roberta.modeling_xlm_roberta import XLMRobertaEmbeddings
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="BGE-M3 infer")
+    parser.add_argument("--model_path", type=str, default="./",
+                        help="model path (a local directory or a Hugging Face Hub model id)")
+    parser.add_argument('--warmup', type=int, default=4, help="warm-up iterations")
+    parser.add_argument('--loop', type=int, default=10, help="timed loop iterations")
+    parser.add_argument("--devices", type=str, default="['npu:0']", help="target npu devices")
+    args = parser.parse_args()
+    return args
+
+
+def create_model(args):
+    model = BGEM3FlagModel(args.model_path, trust_remote_code=True)
+    # model.target_devices defaults to ['npu:0', 'npu:1', 'npu:2', 'npu:3']
+    model.target_devices = ast.literal_eval(args.devices)
+    return model
+
+
+class MyXLMRobertaEmbeddings(XLMRobertaEmbeddings):
+    """
+    Re-implementation of the model's embedding layer.
+    Overrides create_position_ids_from_input_ids from XLMRobertaEmbeddings so that
+    padding_idx is moved to the same device and dtype as input_ids.
+    """
+    def create_position_ids_from_input_ids(self, input_ids, padding_idx, past_key_values_length=0):
+        # Move padding_idx to the device/dtype of input_ids to avoid FakeTensor issues
+        padding_idx = torch.tensor(padding_idx, device=input_ids.device, dtype=input_ids.dtype)
+
+        mask = input_ids.ne(padding_idx).int()
+        incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
+        return incremental_indices.long() + padding_idx
+
+    def forward(
+        self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
+    ):
+        if position_ids is None:
+            if input_ids is not None:
+                # Uses the overridden create_position_ids_from_input_ids above
+                position_ids = self.create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length)
+            else:
+                position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
+
+        if input_ids is not None:
+            input_shape = input_ids.size()
+        else:
+            input_shape = inputs_embeds.size()[:-1]
+
+        seq_length = input_shape[1]
+
+        if token_type_ids is None:
+            if hasattr(self, "token_type_ids"):
+                buffered_token_type_ids = self.token_type_ids[:, :seq_length]
+                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
+                token_type_ids = buffered_token_type_ids_expanded
+            else:
+                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
+
+        if inputs_embeds is None:
+            inputs_embeds = self.word_embeddings(input_ids)
+        token_type_embeddings = self.token_type_embeddings(token_type_ids)
+
+        embeddings = inputs_embeds + token_type_embeddings
+        if self.position_embedding_type == "absolute":
+            position_embeddings = self.position_embeddings(position_ids)
+            embeddings += position_embeddings
+        embeddings = self.LayerNorm(embeddings)
+        embeddings = self.dropout(embeddings)
+        return embeddings
+
+
+def rewrite_self_attention_forward(model):
+    """
+    One optimization is applied here: a single fused Linear (qkv) replaces the
+    original three Linear projections.
+    """
+    # Build the fused Linear (qkv) and load its weights.
+    # nn.Linear stores weights as (out_features, in_features), so the three
+    # projection weights are concatenated along dim 0.
+    wq = model.query.weight
+    wk = model.key.weight
+    wv = model.value.weight
+    model.qkv = nn.Linear(wq.shape[1], wq.shape[0] + wk.shape[0] + wv.shape[0])
+    model.qkv.weight = nn.Parameter(torch.concat([wq, wk, wv], dim=0), requires_grad=False)
+    model.qkv.bias = nn.Parameter(torch.concat([
+        model.query.bias, model.key.bias, model.value.bias
+    ], dim=0), requires_grad=False)
+    chunk_size = wq.shape[0]
+    del model.query
+    del model.key
+    del model.value
+
+    def forward(
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        output_attentions: Optional[bool] = False,
+    ) -> Tuple[torch.Tensor]:
+        # Project with the fused qkv Linear...
+        qkv_layers = model.qkv(hidden_states)
+        # ...then slice out the individual q, k and v
+        query_layer = qkv_layers[:, :, :chunk_size]
+        key_layer = qkv_layers[:, :, chunk_size:chunk_size * 2]
+        value_layer = qkv_layers[:, :, chunk_size * 2:]
+
+        bsz, tgt_len, _ = hidden_states.size()
+
+        query_layer = model.transpose_for_scores(query_layer)
+        key_layer = model.transpose_for_scores(key_layer)
+        value_layer = model.transpose_for_scores(value_layer)
+
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            query_layer,
+            key_layer,
+            value_layer,
+            attn_mask=attention_mask,
+            dropout_p=0.0,
+            is_causal=False
+        )
+
+        attn_output = attn_output.transpose(1, 2)
+        attn_output = attn_output.reshape(bsz, tgt_len, model.all_head_size)
+
+        outputs = (attn_output,)
+        if model.is_decoder:
+            outputs = outputs + (past_key_value,)
+        return outputs
+
+    model.forward = forward
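+
+
+# Optional sanity check, not called in the main flow: a minimal sketch verifying
+# that the fused qkv Linear reproduces the three original projections. `attn` is
+# assumed to be a self-attention module that has NOT been rewritten yet, e.g.
+# model.model.model.encoder.layer[0].attention.self.
+def check_qkv_equivalence(attn):
+    weight = attn.query.weight
+    hidden = torch.randn(1, 4, weight.shape[1], device=weight.device, dtype=weight.dtype)
+    # Reference output from the three separate projections, computed before the rewrite
+    expected = torch.cat([attn.query(hidden), attn.key(hidden), attn.value(hidden)], dim=-1)
+    rewrite_self_attention_forward(attn)  # builds attn.qkv and deletes query/key/value
+    torch.testing.assert_close(attn.qkv(hidden), expected)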
+
+
+def modify_model(model):
+    xlm_roberta_config = model.model.model.config
+    xlm_roberta_embeddings = model.model.model.embeddings
+    model.model.model.embeddings = MyXLMRobertaEmbeddings(xlm_roberta_config)
+    model.model.model.embeddings.load_state_dict(xlm_roberta_embeddings.state_dict())
+    model.model.model.embeddings.eval().half()
+
+    for layer in model.model.model.encoder.layer:
+        rewrite_self_attention_forward(layer.attention.self)
+
+    return model
+
+
+if __name__ == '__main__':
+    args = parse_args()
+
+    torch_npu.npu.set_compile_mode(jit_compile=False)
+
+    # Configure torchair
+    config = CompilerConfig()
+    config.experimental_config.frozen_parameter = True
+    npu_backend = tng.get_npu_backend(compiler_config=config)
+
+    # Create the model and wrap its forward with torchair
+    model = create_model(args)
+    model = modify_model(model)
+    model.model.eval().half()
+    model.model.forward = torch.compile(model.model.forward, dynamic=True, fullgraph=True, backend=npu_backend)
+
+    sentences1 = ["What is BGE M3?", "Definition of BM25"]
+    sentences2 = [
+        "BGE M3 is an embedding model supporting dense retrieval, lexical matching and multi-vector interaction.",
+        "BM25 is a bag-of-words retrieval function that ranks a set of documents based on the query terms appearing in each document"
+    ]
+
+    with torch.inference_mode():
+        for _ in range(args.warmup):
+            output1 = model.encode(sentences1, return_dense=True, return_sparse=True, return_colbert_vecs=True)
+            output2 = model.encode(sentences2, return_dense=True, return_sparse=True, return_colbert_vecs=True)
+
+        dense_vecs1 = output1['dense_vecs']
+        dense_vecs2 = output2['dense_vecs']
+        print("sentences1 = {}".format(sentences1))
+        print("sentences2 = {}".format(sentences2))
+        print("dense similarity scores = {}".format(dense_vecs1 @ dense_vecs2.T))
+
+        lexical_weights1 = output1['lexical_weights']
+        lexical_weights2 = output2['lexical_weights']
+        print("lexical_weights for sentences1 = {}".format(model.convert_id_to_token(lexical_weights1)))
+        print("lexical_weights for sentences2 = {}".format(model.convert_id_to_token(lexical_weights2)))
+        lexical_similarity = model.compute_lexical_matching_score(lexical_weights1, lexical_weights2)
+        print("lexical similarity scores = {}".format(lexical_similarity))
+
+        multi_vecs1 = output1['colbert_vecs']
+        multi_vecs2 = output2['colbert_vecs']
+        multi_vecs_similarity = [
+            model.colbert_score(multi_vecs1[0], multi_vecs2[0]),
+            model.colbert_score(multi_vecs1[0], multi_vecs2[1]),
+            model.colbert_score(multi_vecs1[1], multi_vecs2[0]),
+            model.colbert_score(multi_vecs1[1], multi_vecs2[1]),
+        ]
+        print("multi vecs similarity scores = {}".format(multi_vecs_similarity))
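+
+        # The three signals can also be fused into a single hybrid score. A hedged
+        # sketch (compute_score and weights_for_different_modes come from the
+        # FlagEmbedding BGE-M3 examples; the weights below are illustrative, not tuned):
+        # hybrid = model.compute_score(
+        #     [[sentences1[0], sentences2[0]]],
+        #     weights_for_different_modes=[0.4, 0.2, 0.4],  # dense, sparse, multi-vector
+        # )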
{}".format(sentences2)) + print(" dense similarity scores = {}".format(dense_vecs1 @ dense_vecs2.T)) + + lexical_weights1 = output1['lexical_weights'] + lexical_weights2 = output2['lexical_weights'] + print("lexical_weights for sentences1 = {}".format(model.convert_id_to_token(lexical_weights1))) + print("lexical_weights for sentences2 = {}".format(model.convert_id_to_token(lexical_weights2))) + lexical_similarity = model.compute_lexical_matching_score(lexical_weights1, lexical_weights2) + print("lexical similarity scores = {}".format(lexical_similarity)) + + multi_vecs1 = output1['colbert_vecs'] + multi_vecs2 = output2['colbert_vecs'] + multi_vecs_similarity = [ + model.colbert_score(multi_vecs1[0], multi_vecs2[0]), + model.colbert_score(multi_vecs1[0], multi_vecs2[1]), + model.colbert_score(multi_vecs1[1], multi_vecs2[0]), + model.colbert_score(multi_vecs1[1], multi_vecs2[1]), + ] + print("multi vecs similarity scores = {}".format(multi_vecs_similarity)) + + start_time = time.time() + for _ in range(args.loop): + output1 = model.encode(sentences1, return_dense=True, return_sparse=True, return_colbert_vecs=True) + output2 = model.encode(sentences2, return_dense=True, return_sparse=True, return_colbert_vecs=True) + print(f'E2E time = {(time.time() - start_time) / args.loop * 1000}ms') \ No newline at end of file diff --git a/ACL_PyTorch/built-in/embedding/bge-m3/requirements.txt b/ACL_PyTorch/built-in/embedding/bge-m3/requirements.txt new file mode 100644 index 0000000000..54edc7eec2 --- /dev/null +++ b/ACL_PyTorch/built-in/embedding/bge-m3/requirements.txt @@ -0,0 +1,8 @@ +decorator==5.2.1 +einops==0.8.1 +FlagEmbedding==1.3.4 +numpy==1.26.4 +PyYAML==6.0.2 +tokenizers==0.21.1 +torch==2.5.1 +transformers==4.51.1 -- Gitee