From 94ad5095b65b7195d8146aa5ede38f7d0cbde0c4 Mon Sep 17 00:00:00 2001 From: commc Date: Wed, 4 Sep 2024 15:41:41 +0800 Subject: [PATCH 1/2] =?UTF-8?q?=E5=AE=A2=E6=88=B7=E9=9C=80=E6=B1=82infer?= =?UTF-8?q?=E5=8F=8Areadme?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../built-in/multimodal/README.md | 261 ++++++++++++++++++ .../built-in/multimodal/clip_infer.py | 99 +++++++ 2 files changed, 360 insertions(+) create mode 100644 MindIE/MindIE-Torch/built-in/multimodal/README.md create mode 100644 MindIE/MindIE-Torch/built-in/multimodal/clip_infer.py diff --git a/MindIE/MindIE-Torch/built-in/multimodal/README.md b/MindIE/MindIE-Torch/built-in/multimodal/README.md new file mode 100644 index 0000000000..a103850620 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/multimodal/README.md @@ -0,0 +1,261 @@ +# CLIP-推理指导 + + +- [概述](#ZH-CN_TOPIC_0000001172161501) + +- [推理环境准备](#ZH-CN_TOPIC_0000001126281702) + +- [快速上手](#ZH-CN_TOPIC_0000001126281700) + +- [模型推理性能精度](#ZH-CN_TOPIC_0000001172201573) + + +# 概述 + +([来自开源代码仓](https://github.com/openai/CLIP))CLIP使用大规模(图片、文本)数据对进行训练,旨在帮助用户快速实现图文特征&相似度计算、跨模态检索、零样本图片分类等任务。该项目旨在演示在多判断分支、使用hook机制的场景下,如何使用MindIE组件对主要计算模块进行加速。端到端推理演示通过命令行调用的形式呈现,每次都要花费很长时间载入模型,且各主要计算模块之间使用cpu计算串联,不适合直接在业务中应用。开发者可以根据该项目代码,配套Torch_NPU,开发出高性能的图文识别服务。 + +# 推理环境准备[所有版本] +- 该模型需要以下插件与驱动 + + **表 1** 版本配套表 + | 配套 | 版本 | + |---------| ------- | + | 固件与驱动 | - | + | CANN | - | + | Python | 3.10.13 | + | PyTorch | 2.1.0 | + | MindIE | - | + + 注意:由于MindIE暂无支持该模型的商发版本,烦请用户联系华为工程师获取对应的固件驱动,CANN,MindIE PoC版本链接。 + 固件驱动和CANN的安装,请参考昇腾官方文档[环境快速部署](https://www.hiascend.com/document/detail/zh/quick-installation/24.0.RC1/quickinstg/800_3000/quickinstg_800_3000_0001.html)。 + + MindIE的安装需要先source toolkit的环境变量,然后直接安装,以默认安装路径`/usr/local/Ascend`为例: + ``` + source /usr/local/Ascend/ascend-tookit/set_env.sh + bash Ascend-mindie_*.run --install + ``` + + +# 快速上手 +1. 安装依赖 + ``` + pip install transformers==4.41.1 + ``` +2. 模型下载-本项目支持clip-zh:VIT-B-16以及clip-en:VIT-L/14 + ``` + CLIP: + git clone https://huggingface.co/openai/clip-vit-large-patch14-336 + ChineseCLIP: + git clone https://huggingface.co/OFA-Sys/chinese-clip-vit-base-patch16 + ``` +3. 导出onnx模型-以ChineseCLIP为例 + ```shell + cd CLIP/ + hf_model_path="/path/to/chinese-clip-vit-base-patch16" + python ./export_onnx.py \ + --text-max-batch 80 \ + --img-max-batch 8 \ + --max-token-len 52 \ + --model-version "ViT-B-16" \ + --hf-model-path $hf_model_path \ + --save-dir "./" + ``` + 执行结束后,会在save-dir目录下生成CLIP-ViT-B-16.onnx + + 参数说明: + - --hf-model-path:下载的Huggingface模型路径。 + - --soc_version:芯片类型,当前仅在Ascend910B4上调试。 + - --text-max-batch:最大文本batchsize,默认最大值为80。 + - --img-max-batch:最大图片batchsize,默认最大值为8。 + - --max-token-len:文本token最大嵌入长度,默认最大值为52。 + - --model-version:CLIP版本。 + - 其他参数请参考脚本`parse_args`部分。 + +4. 编译静态输入ts模型 + ```shell + cd CLIP/ + hf_model_path=/path/to/chinese-clip-vit-base-patch16 + python ./compile_ts.py \ + --soc-version "Ascend910B4" \ + --device-id 0 \ + --text-batch 5 \ + --img-batch 1 \ + --token-len 52 \ + --model-version "ViT-B-16" \ + --hf-model-path $hf_model_path \ + --precision "fp16" \ + --save-dir "./" + ``` + 执行结束后,会在save-dir目录下生成CLIP_{--model-version}-MindIE.ts + + 参数说明: + - --hf-model-path:下载的Huggingface模型路径。 + - --soc_version: 芯片类型,当前仅在Ascend910B4上调试。 + - --text-batch: 固定输入的文本batchsize,默认值为80。 + - --img-max-batch: 固定输入的图片batchsize,默认值为1。 + - --token-len: 固定的文本token嵌入长度,默认值为52。 + - --model-version: CLIP版本。 + - 其他参数请参考脚本`parse_args`部分。 + +5. 编译动态输入aie模型(可选) + ```shell + cd chiness_clip/ + hf_model_path="/path/to/huggingface_clip_model" + python ./compile_aie.py \ + --soc-version "Ascend910B4" \ + --device-id 0 \ + --text-max-batch 80 \ + --img-max-batch 8 \ + --max-token-len 52 \ + --model-version "ViT-B-16" \ + --hf-model-path $hf_model_path \ + --precision "fp16" \ + --save-dir "./" + ``` + 执行结束后,会在save-dir目录下生成CLIP_{--model-version}-MindIE.pt + + 参数说明: + - --hf-model-path:下载的Huggingface模型路径。 + - --soc_version: 芯片类型,当前仅在Ascend910B4上调试。 + - --text-max-batch: 最大文本batchsize,默认最大值为80。 + - --img-max-batch: 最大图片batchsize,默认最大值为8。 + - --max-token-len: 文本token最大嵌入长度,,默认最大值为52。 + - --model-version: CLIP版本。 + - 其他参数请参考脚本`parse_args`部分。 + + +6. 模型推理 + ```shell + aie_path="/path/to/ts_model" + hf_model_path="/path/to huggingface_clip_model" + image_path="/path/to/image" + python3 clip_infer.py \ + --device-id 0 \ + --clip-aie-path $aie_path \ + --hf-model-path $hf_model_path \ + --max-token-len 52 \ + --image-path $image_path \ + --text "杰尼龟" "妙蛙种子" "小火龙" "皮卡丘" "image内容" + ``` + 推理结束后,会在命令行打印出类似如下输出: + ``` + Probs per image: [[0.00939221866428852, 0.03177346661686897, 0.010315335355699062, 0.07622058689594269, 0.8722984194755554]] + ``` + 参数说明: + - --hf-model-path: 所下载的Huggingface模型文件夹路径。 + - --clip-aie-path:第4步或第5步编译得到的模型路径。clip-aie-path的模型必须与hf-model-path中的模型类型一致。 + - --image-path: 图片路径,当前仅支持单张图片推理。 + - --text:所需要的输入形式如:List[str, ]形式,例如:["杰尼龟", "妙蛙种子", "小火龙", "皮卡丘"]。其中每一个字段代表一个text_batch,如果导出的是静态ts模型,确保text_batch维度和编译时设置的text_batch输入维度保持一致。 + - 其他参数请参考脚本`parse_args`部分。 + + +# 模型推理性能精度 + +1. 精度验证 + ```shell + hf_model_path="/path/to huggingface_clip_model" + aie_path="/path/to/ts_model" + onnx_path="/path/to/onnx_model" + python3 precision_test.py \ + --device-id 0 \ + --clip-aie-path $aie_path \ + --clip-onnx-path $onnx_path \ + --hf-model-path $hf_model_path \ + --text-batchsize 5 \ + --image-batchsize 1 \ + --token-len 52 \ + --sim-threshold 0.99 + ``` + + 参数说明: + - --hf-model-path: 下载的Huggingface模型文件夹路径。 + - --clip-aie-path:由hf-model-path模型经过第4步或第5步编译得到的模型路径。 + - --clip-onnx-path:onnx_export.py文件导出的onnx模型路径。 + - --sim-threshold: 余弦相似度阈值,默认0.99。 + - --text-batchsize: dummy文本输入的batchsize,默认10,对于ts模型这里的输入维度需要与编译时设置的输入维度一致。 + - --token-len:dummy文本输入的token长度,默认为20,对于ts模型这里的输入维度需要与编译时设置的输入维度一致。 + - --image-batchsize: dummy图片输入的batchsize,默认为1,对于ts模型这里的输入维度需要与编译时设置的输入维度一致。 + - 其他参数请参考脚本`parse_args`部分。 + + 执行结束后,期望输出如下: + ``` + === Compare the outputs of ONNX and AIE === + Start comparing clip... + Number of outputs to compare: 4 + Number of outputs with cosine similarity > 0.99: 4 + ``` + +2. 性能验证 + + a) aie模型性能测试-以VIT-B-16为例 + ```shell + aie_path="/path/to/ts_model or pt model" + hf_model_path="/path/tohuggingface_clip_model" + python3 perf_test_aie.py \ + --device-id 0 \ + --clip-aie-path $aie_path \ + --hf-model-path $hf_model_path \ + --text-batchsize 5 \ + --token-len 52 \ + --image-batchsize 1 + ``` + + 参数说明: + - --hf-model-path: Huggingface模型文件夹路径。 + - --clip-aie-path:由hf-model-path模型经过第4步或第5步编译得到的模型路径。 + - --text-batchsize: dummy文本输入的batchsize,默认10。 + - --token-len:dummy文本输入的token长度,默认为20。 + - --image-batchsize: dummy图片输入的batchsize,默认为1。 + - 其他参数请参考脚本`parse_args`部分。 + + 执行结束后,期望输出如下: + ``` + CLIP latency: 40.47 ms + CLIP throughput: 24.71 fps + ``` + + b) onnx模型性能测试 + (可选)若使用GPU,请确保已安装CUDA和pytorch-gpu版本,同时需安装onnxruntime-gpu,如下所示: + ```shell + pip uninstall onnxruntime + pip install onnxruntime-gpu + ``` + 验证onnxruntime-gpu是否安装成功: + ```python + import onnxruntime + print(onnxruntime.get_device()) # 若输出为GPU,则说明安装成功 + ``` + 执行性能测试 + ```shell + onnx_path="/path/to/onnx_model" + hf_model_path="/path/tohuggingface_clip_model" + python perf_test_onnx.py \ + --onnx-path $onnx_path \ + --hf-model-path $hf_model_path \ + --text-batchsize 5 \ + --token-len 52 \ + --image-batchsize 1 + ``` + + 参数说明: + - --use-gpu: 使能gpu推理,不加该选项默认cpu。 + - --hf-model-path: Huggingface模型文件夹路径。 + - --clip-onnx-path:hf-model-path经过第3步导出的onnx模型路径。 + - --text-batchsize: dummy文本输入的batchsize,默认10。 + - --token-len:dummy文本输入的token长度,默认为20。 + - --image-batchsize: dummy图片输入的batchsize,默认为1。 + - 其他参数请参考脚本`parse_args`部分。 + + 执行结束后,期望输出如下: + ``` + CLIP latency: 181.08 ms + CLIP throughput: 5.51 fps + ``` + + + | 模型 | pt插件 - 910B4性能(时延/吞吐率) | + |---------|--------------------------------| + | Chinese-CLIP-ViT-Base-Patch16 | 40.47 ms / 24.71 fps | + | clip-vit-large-patch14-336 | 47.73 ms / 21.12 fps | + + 注:该性能是ts模型在text-batch=5, image-batch=1, token-len=52设置下测试10次取平均值得到,以上数据仅作参考。 \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/multimodal/clip_infer.py b/MindIE/MindIE-Torch/built-in/multimodal/clip_infer.py new file mode 100644 index 0000000000..4157194111 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/multimodal/clip_infer.py @@ -0,0 +1,99 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import argparse +import torch +import mindietorch +from PIL import Image +import torch.nn.functional as F +from transformers import AutoProcessor + +logging.basicConfig(level=logging.INFO) + + +def inference(args): + device = f'npu:{args.device_id}' + stream = mindietorch.npu.Stream(device) + if args.clip_aie_path.endswith(".ts"): + aie_model = torch.jit.load(args.clip_aie_path) + else: + aie_model = torch.load(args.clip_aie_path) + aie_model.eval().to(device) + + processor = AutoProcessor.from_pretrained(args.hf_model_path) + inputs = processor(text=args.text, images=Image.open(args.image_path), return_tensors="pt", padding=True) + input_ids = inputs.input_ids.to(torch.int32) + attention_mask = inputs.attention_mask.to(torch.int32) + pad_length = args.max_token_len + cur_length = inputs.input_ids.size(-1) + if pad_length > cur_length: + input_ids = F.pad(input_ids, (0, pad_length - cur_length), value=0) + attention_mask = F.pad(attention_mask, (0, pad_length - cur_length), value=0) + + input_ids = input_ids.to(torch.int32).to(device) + attention_mask = attention_mask.to(torch.int32).to(device) + input_img = inputs.pixel_values.to(torch.float32).to(device) + inputs = [input_ids, input_img, attention_mask] + + with mindietorch.npu.stream(stream): + aie_out = aie_model(*inputs) + stream.synchronize() + + aie_out = [x.cpu() for x in aie_out] + logging.info("Probs per image: %s", aie_out[-1].softmax(dim=-1).tolist()) + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--device-id", type=int, help="NPU device id", default=0) + parser.add_argument( + "--clip-aie-path", + type=str, + default="/Path/to/compiled/aie_or_ts_model" + ) + parser.add_argument( + "--hf-model-path", + default="/Path/to/Huggingface_model_path", + type=str, + help="Huggingface CLIP Model Path." + ) + parser.add_argument( + "--max-token-len", + type=int, + default=52, + help="Manually pad input ids to max-token-len" + ) + parser.add_argument( + "--image-path", + type=str, + default="/Path/to/image" + ) + parser.add_argument( + "--text", + type=str, + nargs='+', + default=["杰尼龟", "妙蛙种子", "小火龙", "皮卡丘", "image内容"], + ) + + return parser.parse_args() + + +def main(): + infer_args = parse_args() + mindietorch.set_device(infer_args.device_id) + inference(infer_args) + + +if __name__ == "__main__": + main() \ No newline at end of file -- Gitee From 02a270e8791de591bf62465365322990f84468e3 Mon Sep 17 00:00:00 2001 From: commc Date: Thu, 5 Sep 2024 17:11:35 +0800 Subject: [PATCH 2/2] =?UTF-8?q?=E6=AD=A3=E7=A1=AE=E6=96=87=E4=BB=B6?= =?UTF-8?q?=E5=A4=B9=E5=B1=82=E7=BA=A7=E6=95=B4=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- MindIE/MindIE-Torch/built-in/multimodal/{ => CLIP}/README.md | 0 MindIE/MindIE-Torch/built-in/multimodal/{ => CLIP}/clip_infer.py | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename MindIE/MindIE-Torch/built-in/multimodal/{ => CLIP}/README.md (100%) rename MindIE/MindIE-Torch/built-in/multimodal/{ => CLIP}/clip_infer.py (100%) diff --git a/MindIE/MindIE-Torch/built-in/multimodal/README.md b/MindIE/MindIE-Torch/built-in/multimodal/CLIP/README.md similarity index 100% rename from MindIE/MindIE-Torch/built-in/multimodal/README.md rename to MindIE/MindIE-Torch/built-in/multimodal/CLIP/README.md diff --git a/MindIE/MindIE-Torch/built-in/multimodal/clip_infer.py b/MindIE/MindIE-Torch/built-in/multimodal/CLIP/clip_infer.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/multimodal/clip_infer.py rename to MindIE/MindIE-Torch/built-in/multimodal/CLIP/clip_infer.py -- Gitee