From 43c31d3e1362f44478761f5cff5a202a0de4b30c Mon Sep 17 00:00:00 2001
From: commc
Date: Wed, 4 Sep 2024 11:35:42 +0800
Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E5=AE=A2=E6=88=B7=E9=9C=80?=
 =?UTF-8?q?=E6=B1=82=E6=A8=A1=E5=9E=8BCLIP?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../built-in/multimodal/CLIP/README.md        | 261 ++++++++++++++++++
 .../built-in/multimodal/CLIP/clip_infer.py    |  99 +++++++
 .../built-in/multimodal/CLIP/compile_aie.py   | 169 ++++++++++++
 .../built-in/multimodal/CLIP/compile_ts.py    | 147 ++++++++++
 .../built-in/multimodal/CLIP/export_onnx.py   | 126 +++++++++
 .../built-in/multimodal/CLIP/perf_test_aie.py |  99 +++++++
 .../multimodal/CLIP/perf_test_onnx.py         |  97 +++++++
 .../multimodal/CLIP/precision_test.py         | 134 +++++++++
 8 files changed, 1132 insertions(+)
 create mode 100644 MindIE/MindIE-Torch/built-in/multimodal/CLIP/README.md
 create mode 100644 MindIE/MindIE-Torch/built-in/multimodal/CLIP/clip_infer.py
 create mode 100644 MindIE/MindIE-Torch/built-in/multimodal/CLIP/compile_aie.py
 create mode 100644 MindIE/MindIE-Torch/built-in/multimodal/CLIP/compile_ts.py
 create mode 100644 MindIE/MindIE-Torch/built-in/multimodal/CLIP/export_onnx.py
 create mode 100644 MindIE/MindIE-Torch/built-in/multimodal/CLIP/perf_test_aie.py
 create mode 100644 MindIE/MindIE-Torch/built-in/multimodal/CLIP/perf_test_onnx.py
 create mode 100644 MindIE/MindIE-Torch/built-in/multimodal/CLIP/precision_test.py

diff --git a/MindIE/MindIE-Torch/built-in/multimodal/CLIP/README.md b/MindIE/MindIE-Torch/built-in/multimodal/CLIP/README.md
new file mode 100644
index 0000000000..a103850620
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/multimodal/CLIP/README.md
@@ -0,0 +1,261 @@
+# CLIP Inference Guide
+
+
+- [Overview](#ZH-CN_TOPIC_0000001172161501)
+
+- [Inference Environment Setup](#ZH-CN_TOPIC_0000001126281702)
+
+- [Quick Start](#ZH-CN_TOPIC_0000001126281700)
+
+- [Inference Performance and Accuracy](#ZH-CN_TOPIC_0000001172201573)
+
+
+# Overview
+
+CLIP ([upstream repository](https://github.com/openai/CLIP)) is trained on large-scale (image, text) pairs and helps users quickly build image-text feature and similarity computation, cross-modal retrieval, zero-shot image classification, and similar tasks. This project demonstrates how to accelerate the main compute modules with MindIE components for a model that has multiple decision branches and uses the hook mechanism. The end-to-end demo is driven from the command line: every run spends a long time loading the model, and the main compute modules are chained together by CPU computation, so it is not meant to be used in production as-is. Based on this code, together with Torch_NPU, developers can build a high-performance image-text recognition service.
+
+# Inference Environment Setup [all versions]
+- The model requires the following software stack.
+
+  **Table 1** Version compatibility
+  | Component | Version |
+  |---------| ------- |
+  | Firmware & driver | - |
+  | CANN | - |
+  | Python | 3.10.13 |
+  | PyTorch | 2.1.0 |
+  | MindIE | - |
+
+  Note: MindIE currently has no commercial release that supports this model. Please contact Huawei engineers for download links to the matching firmware/driver, CANN, and MindIE PoC packages.
+  For firmware/driver and CANN installation, see the Ascend documentation [Quick Environment Deployment](https://www.hiascend.com/document/detail/zh/quick-installation/24.0.RC1/quickinstg/800_3000/quickinstg_800_3000_0001.html).
+
+  To install MindIE, source the toolkit environment variables first and then run the installer. With the default installation path `/usr/local/Ascend`:
+  ```
+  source /usr/local/Ascend/ascend-toolkit/set_env.sh
+  bash Ascend-mindie_*.run --install
+  ```
+
+
+# Quick Start
+1. Install dependencies.
+   ```
+   pip install transformers==4.41.1
+   ```
+2. Download the model. This project supports clip-zh: ViT-B-16 and clip-en: ViT-L/14. A quick load check is shown below.
+   ```
+   CLIP:
+   git clone https://huggingface.co/openai/clip-vit-large-patch14-336
+   ChineseCLIP:
+   git clone https://huggingface.co/OFA-Sys/chinese-clip-vit-base-patch16
+   ```
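+   (Optional) The snippet below is a minimal sketch, not one of the delivered scripts, for confirming that the downloaded checkpoint loads with `transformers` and produces sensible similarity scores before export; the paths, image, and candidate texts are placeholders.
+   ```python
+   from PIL import Image
+   from transformers import AutoModel, AutoProcessor
+
+   hf_model_path = "/path/to/chinese-clip-vit-base-patch16"  # placeholder path
+   model = AutoModel.from_pretrained(hf_model_path).eval()
+   processor = AutoProcessor.from_pretrained(hf_model_path)
+
+   # Encode a few candidate texts and one image, then print softmax-normalized scores.
+   inputs = processor(text=["杰尼龟", "皮卡丘"], images=Image.open("/path/to/image.jpg"),
+                      return_tensors="pt", padding=True)
+   outputs = model(**inputs)
+   print(outputs.logits_per_image.softmax(dim=-1).tolist())
+   ```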
+3. Export the ONNX model (using ChineseCLIP as the example).
+   ```shell
+   cd CLIP/
+   hf_model_path="/path/to/chinese-clip-vit-base-patch16"
+   python ./export_onnx.py \
+          --text-max-batch 80 \
+          --img-max-batch 8 \
+          --max-token-len 52 \
+          --model-version "ViT-B-16" \
+          --hf-model-path $hf_model_path \
+          --save-dir "./"
+   ```
+   After the script finishes, CLIP-ViT-B-16.onnx is generated under save-dir.
+
+   Parameters:
+   - --hf-model-path: path to the downloaded Hugging Face model.
+   - --text-max-batch: maximum text batch size; default 80.
+   - --img-max-batch: maximum image batch size; default 8.
+   - --max-token-len: maximum text token length; default 52.
+   - --model-version: CLIP variant.
+   - For other parameters, see the `parse_args` section of the script.
+
+4. Compile the static-input ts model.
+   ```shell
+   cd CLIP/
+   hf_model_path=/path/to/chinese-clip-vit-base-patch16
+   python ./compile_ts.py \
+          --soc-version "Ascend910B4" \
+          --device-id 0 \
+          --text-batch 5 \
+          --img-batch 1 \
+          --token-len 52 \
+          --model-version "ViT-B-16" \
+          --hf-model-path $hf_model_path \
+          --precision "fp16" \
+          --save-dir "./"
+   ```
+   After the script finishes, CLIP-{--model-version}-MindIE.ts is generated under save-dir.
+
+   Parameters:
+   - --hf-model-path: path to the downloaded Hugging Face model.
+   - --soc-version: chip type; currently only verified on Ascend910B4.
+   - --text-batch: fixed text batch size of the input; default 80.
+   - --img-batch: fixed image batch size of the input; default 1.
+   - --token-len: fixed text token length; default 52.
+   - --model-version: CLIP variant.
+   - For other parameters, see the `parse_args` section of the script.
+
+5. Compile the dynamic-input aie model (optional).
+   ```shell
+   cd CLIP/
+   hf_model_path="/path/to/huggingface_clip_model"
+   python ./compile_aie.py \
+          --soc-version "Ascend910B4" \
+          --device-id 0 \
+          --text-max-batch 80 \
+          --img-max-batch 8 \
+          --max-token-len 52 \
+          --model-version "ViT-B-16" \
+          --hf-model-path $hf_model_path \
+          --precision "fp16" \
+          --save-dir "./"
+   ```
+   After the script finishes, CLIP-{--model-version}-MindIE.pt is generated under save-dir.
+
+   Parameters:
+   - --hf-model-path: path to the downloaded Hugging Face model.
+   - --soc-version: chip type; currently only verified on Ascend910B4.
+   - --text-max-batch: maximum text batch size; default 80.
+   - --img-max-batch: maximum image batch size; default 8.
+   - --max-token-len: maximum text token length; default 52.
+   - --model-version: CLIP variant.
+   - For other parameters, see the `parse_args` section of the script.
+
+
+6. Run inference.
+   ```shell
+   aie_path="/path/to/ts_model"
+   hf_model_path="/path/to/huggingface_clip_model"
+   image_path="/path/to/image"
+   python3 clip_infer.py \
+           --device-id 0 \
+           --clip-aie-path $aie_path \
+           --hf-model-path $hf_model_path \
+           --max-token-len 52 \
+           --image-path $image_path \
+           --text "杰尼龟" "妙蛙种子" "小火龙" "皮卡丘" "image内容"
+   ```
+   After inference, output similar to the following is printed on the command line:
+   ```
+   Probs per image: [[0.00939221866428852, 0.03177346661686897, 0.010315335355699062, 0.07622058689594269, 0.8722984194755554]]
+   ```
+   Parameters:
+   - --hf-model-path: path to the downloaded Hugging Face model directory.
+   - --clip-aie-path: path to the model compiled in step 4 or step 5. This model must match the model type under hf-model-path.
+   - --image-path: image path; currently only single-image inference is supported.
+   - --text: the input texts, given as a list of strings, e.g. ["杰尼龟", "妙蛙种子", "小火龙", "皮卡丘"]. Each entry is one element of the text batch; for a static ts model, make sure the number of texts matches the text_batch dimension set at compile time.
+   - For other parameters, see the `parse_args` section of the script. A programmatic usage sketch follows below.
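+   To embed the compiled model in your own service instead of going through the command line, the following minimal sketch (condensed from `clip_infer.py`; the model, processor, and image paths are placeholders) shows the core loading and inference flow on NPU:
+   ```python
+   import torch
+   import torch.nn.functional as F
+   import mindietorch
+   from PIL import Image
+   from transformers import AutoProcessor
+
+   device_id = 0
+   device = f"npu:{device_id}"
+   mindietorch.set_device(device_id)
+   stream = mindietorch.npu.Stream(device)
+
+   # Load the compiled TorchScript model once and keep it resident in the service.
+   model = torch.jit.load("/path/to/CLIP-ViT-B-16-MindIE.ts")  # placeholder path
+   model.eval().to(device)
+   processor = AutoProcessor.from_pretrained("/path/to/chinese-clip-vit-base-patch16")
+
+   texts = ["杰尼龟", "妙蛙种子", "小火龙", "皮卡丘", "image内容"]
+   inputs = processor(text=texts, images=Image.open("/path/to/image.jpg"),
+                      return_tensors="pt", padding=True)
+   input_ids = inputs.input_ids.to(torch.int32)
+   attention_mask = inputs.attention_mask.to(torch.int32)
+   pad = 52 - input_ids.size(-1)  # pad to the token length fixed at compile time
+   if pad > 0:
+       input_ids = F.pad(input_ids, (0, pad), value=0)
+       attention_mask = F.pad(attention_mask, (0, pad), value=0)
+
+   with mindietorch.npu.stream(stream):
+       outputs = model(input_ids.to(device), inputs.pixel_values.to(device),
+                       attention_mask.to(device))
+       stream.synchronize()
+   # Outputs: image_embeds, text_embeds, logits_per_text, logits_per_image
+   print(outputs[-1].cpu().softmax(dim=-1).tolist())
+   ```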
+
+
+# Inference Performance and Accuracy
+
+1. Accuracy check
+   ```shell
+   hf_model_path="/path/to/huggingface_clip_model"
+   aie_path="/path/to/ts_model"
+   onnx_path="/path/to/onnx_model"
+   python3 precision_test.py \
+           --device-id 0 \
+           --clip-aie-path $aie_path \
+           --clip-onnx-path $onnx_path \
+           --hf-model-path $hf_model_path \
+           --text-batchsize 5 \
+           --image-batchsize 1 \
+           --token-len 52 \
+           --sim-threshold 0.99
+   ```
+
+   Parameters:
+   - --hf-model-path: path to the downloaded Hugging Face model directory.
+   - --clip-aie-path: model compiled from hf-model-path in step 4 or step 5.
+   - --clip-onnx-path: path to the ONNX model exported by export_onnx.py.
+   - --sim-threshold: cosine similarity threshold; default 0.99.
+   - --text-batchsize: batch size of the dummy text input; default 80. For a ts model it must match the dimension set at compile time.
+   - --token-len: token length of the dummy text input; default 52. For a ts model it must match the dimension set at compile time.
+   - --image-batchsize: batch size of the dummy image input; default 1. For a ts model it must match the dimension set at compile time.
+   - For other parameters, see the `parse_args` section of the script.
+
+   Expected output after the script finishes:
+   ```
+   === Compare the outputs of ONNX and AIE ===
+   Start comparing clip...
+   Number of outputs to compare: 4
+   Number of outputs with cosine similarity > 0.99: 4
+   ```
+
+2. Performance check
+
+   a) AIE model performance test (using ViT-B-16 as the example)
+   ```shell
+   aie_path="/path/to/ts_model_or_pt_model"
+   hf_model_path="/path/to/huggingface_clip_model"
+   python3 perf_test_aie.py \
+           --device-id 0 \
+           --clip-aie-path $aie_path \
+           --hf-model-path $hf_model_path \
+           --text-batchsize 5 \
+           --token-len 52 \
+           --image-batchsize 1
+   ```
+
+   Parameters:
+   - --hf-model-path: Hugging Face model directory.
+   - --clip-aie-path: model compiled from hf-model-path in step 4 or step 5.
+   - --text-batchsize: batch size of the dummy text input; default 80.
+   - --token-len: token length of the dummy text input; default 52.
+   - --image-batchsize: batch size of the dummy image input; default 1.
+   - For other parameters, see the `parse_args` section of the script.
+
+   Expected output after the script finishes:
+   ```
+   CLIP latency: 40.47 ms
+   CLIP throughput: 24.71 fps
+   ```
+
+   b) ONNX model performance test
+   (Optional) To run on GPU, make sure CUDA and the GPU build of PyTorch are installed, and replace onnxruntime with onnxruntime-gpu:
+   ```shell
+   pip uninstall onnxruntime
+   pip install onnxruntime-gpu
+   ```
+   Verify that onnxruntime-gpu is installed correctly:
+   ```python
+   import onnxruntime
+   print(onnxruntime.get_device())  # "GPU" means the installation succeeded
+   ```
+   Run the performance test:
+   ```shell
+   onnx_path="/path/to/onnx_model"
+   hf_model_path="/path/to/huggingface_clip_model"
+   python perf_test_onnx.py \
+          --onnx-path $onnx_path \
+          --hf-model-path $hf_model_path \
+          --text-batchsize 5 \
+          --token-len 52 \
+          --image-batchsize 1
+   ```
+
+   Parameters:
+   - --use-gpu: enable GPU inference; without this flag the CPU is used.
+   - --hf-model-path: Hugging Face model directory.
+   - --onnx-path: path to the ONNX model exported from hf-model-path in step 3.
+   - --text-batchsize: batch size of the dummy text input; default 80.
+   - --token-len: token length of the dummy text input; default 52.
+   - --image-batchsize: batch size of the dummy image input; default 1.
+   - For other parameters, see the `parse_args` section of the script.
+
+   Expected output after the script finishes:
+   ```
+   CLIP latency: 181.08 ms
+   CLIP throughput: 5.51 fps
+   ```
+
+
+   | Model | PT plugin - 910B4 performance (latency / throughput) |
+   |---------|--------------------------------|
+   | Chinese-CLIP-ViT-Base-Patch16 | 40.47 ms / 24.71 fps |
+   | clip-vit-large-patch14-336 | 47.73 ms / 21.12 fps |
+
+   Note: the numbers above were measured on the ts model with text-batch=5, image-batch=1, token-len=52, averaged over 10 runs, and are for reference only.
\ No newline at end of file
diff --git a/MindIE/MindIE-Torch/built-in/multimodal/CLIP/clip_infer.py b/MindIE/MindIE-Torch/built-in/multimodal/CLIP/clip_infer.py
new file mode 100644
index 0000000000..4157194111
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/multimodal/CLIP/clip_infer.py
@@ -0,0 +1,99 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import argparse +import torch +import mindietorch +from PIL import Image +import torch.nn.functional as F +from transformers import AutoProcessor + +logging.basicConfig(level=logging.INFO) + + +def inference(args): + device = f'npu:{args.device_id}' + stream = mindietorch.npu.Stream(device) + if args.clip_aie_path.endswith(".ts"): + aie_model = torch.jit.load(args.clip_aie_path) + else: + aie_model = torch.load(args.clip_aie_path) + aie_model.eval().to(device) + + processor = AutoProcessor.from_pretrained(args.hf_model_path) + inputs = processor(text=args.text, images=Image.open(args.image_path), return_tensors="pt", padding=True) + input_ids = inputs.input_ids.to(torch.int32) + attention_mask = inputs.attention_mask.to(torch.int32) + pad_length = args.max_token_len + cur_length = inputs.input_ids.size(-1) + if pad_length > cur_length: + input_ids = F.pad(input_ids, (0, pad_length - cur_length), value=0) + attention_mask = F.pad(attention_mask, (0, pad_length - cur_length), value=0) + + input_ids = input_ids.to(torch.int32).to(device) + attention_mask = attention_mask.to(torch.int32).to(device) + input_img = inputs.pixel_values.to(torch.float32).to(device) + inputs = [input_ids, input_img, attention_mask] + + with mindietorch.npu.stream(stream): + aie_out = aie_model(*inputs) + stream.synchronize() + + aie_out = [x.cpu() for x in aie_out] + logging.info("Probs per image: %s", aie_out[-1].softmax(dim=-1).tolist()) + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--device-id", type=int, help="NPU device id", default=0) + parser.add_argument( + "--clip-aie-path", + type=str, + default="/Path/to/compiled/aie_or_ts_model" + ) + parser.add_argument( + "--hf-model-path", + default="/Path/to/Huggingface_model_path", + type=str, + help="Huggingface CLIP Model Path." + ) + parser.add_argument( + "--max-token-len", + type=int, + default=52, + help="Manually pad input ids to max-token-len" + ) + parser.add_argument( + "--image-path", + type=str, + default="/Path/to/image" + ) + parser.add_argument( + "--text", + type=str, + nargs='+', + default=["杰尼龟", "妙蛙种子", "小火龙", "皮卡丘", "image内容"], + ) + + return parser.parse_args() + + +def main(): + infer_args = parse_args() + mindietorch.set_device(infer_args.device_id) + inference(infer_args) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/multimodal/CLIP/compile_aie.py b/MindIE/MindIE-Torch/built-in/multimodal/CLIP/compile_aie.py new file mode 100644 index 0000000000..35f13c9747 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/multimodal/CLIP/compile_aie.py @@ -0,0 +1,169 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import time +import json +import logging +import argparse +import torch +import mindietorch +import torch.nn as nn +from torch._export import export, dynamic_dim +from mindietorch import _enums +from transformers.models.auto.modeling_auto import AutoModel + +logging.basicConfig(level=logging.INFO) + + +class CLIPWrapper(nn.Module): + def __init__(self, clip): + super(CLIPWrapper, self).__init__() + self.model = clip + self.logit_scale = clip.logit_scale.exp().to(self.model.device) + + def forward(self, input_ids, pixel_values, attention_mask): + image_embeds = self.model.get_image_features(pixel_values) + text_embeds = self.model.get_text_features(input_ids, attention_mask) + + image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True) + text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True) + + logits_per_image = image_embeds @ text_embeds.transpose(1, 0).contiguous() * self.logit_scale + logits_per_text = logits_per_image.transpose(1, 0).contiguous() + + return image_embeds, text_embeds, logits_per_text, logits_per_image + + +def compile_clip(args): + # 加载Pytorch模型 + with torch.no_grad(): + torch_model = AutoModel.from_pretrained(args.hf_model_path).float().eval() + torch_model = CLIPWrapper(torch_model) + + hf_config_path = os.path.join(args.hf_model_path, "config.json") + if not os.path.exists(hf_config_path): + raise FileNotFoundError(f"config.json not found at {args.hf_model_path}: {hf_config_path}") + with open(hf_config_path, "r") as f: + config_dict = json.load(f) + + # 构造模型输入 + image_width = config_dict["vision_config"]["image_size"] + pixel_values_shape = (args.img_max_batch, 3, image_width, image_width) + input_ids_shape = (args.text_max_batch, args.max_token_len) + pixel_values = torch.ones(pixel_values_shape, dtype=torch.float32) + input_ids = torch.randint(high=1, size=input_ids_shape, dtype=torch.int32) + attention_mask = torch.ones_like(input_ids, dtype=torch.int32) + + # 导出fx格式模型并执行MindIE编译 + constraints = [ + # input ids + dynamic_dim(input_ids, 0) >= 1, + dynamic_dim(input_ids, 0) <= args.text_max_batch, + dynamic_dim(input_ids, 1) >= 1, + dynamic_dim(input_ids, 1) <= 52, + # pixel input + dynamic_dim(pixel_values, 0) >= 1, + dynamic_dim(pixel_values, 0) <= args.img_max_batch, + # input ids attention mask + dynamic_dim(attention_mask, 0) == dynamic_dim(input_ids, 0), + dynamic_dim(attention_mask, 1) == dynamic_dim(input_ids, 1), + ] + + logging.info("Starting to export dynamic clip ...") + intermediate_model = export( + torch_model, + args=(input_ids, pixel_values, attention_mask,), + constraints=constraints + ) + logging.info("Successfully export dynamic clip!") + + mindietorch.set_device(args.device_id) + pixel_values_min_shape = (1, 3, image_width, image_width) + pixel_values_max_shape = (args.img_max_batch, 3, image_width, image_width) + input_ids_min_shape = (1, 1) + input_ids_max_shape = (args.text_max_batch, args.max_token_len) + + # 执行MindIETorch编译 + compile_inputs = [ + mindietorch.Input(min_shape=input_ids_min_shape, max_shape=input_ids_max_shape), + mindietorch.Input(min_shape=pixel_values_min_shape, 
max_shape=pixel_values_max_shape), + mindietorch.Input(min_shape=input_ids_min_shape, max_shape=input_ids_max_shape), # attention mask + ] + + if args.precision == "fp16": + model_precision = _enums.PrecisionPolicy.FP16 + elif args.precision == "fp32": + model_precision = _enums.PrecisionPolicy.FP32 + else: + raise ValueError("Unsupported precision type!") + + logging.info("Starting to compile mindietorch clip ...") + ts = time.time() + compiled_model = mindietorch.compile( + intermediate_model, + inputs=compile_inputs, + precision_policy=model_precision, + soc_version=args.soc_version, + ) + compile_cost = time.time() - ts + logging.info("compile time cost: %f", compile_cost) + logging.info("Successfully exported mindietorch clip!") + + logging.info("Starting to save ...") + model_save_dir = f"{args.save_dir}" + if not os.path.exists(model_save_dir): + os.makedirs(model_save_dir) + compiled_file_name = f"CLIP-{args.model_version}-MindIE.pt" + torch.save(compiled_model, model_save_dir + compiled_file_name, pickle_protocol=4) + logging.info("Saving done!") + + +def parse_args(): + parser = argparse.ArgumentParser(description="Compile Clip model") + parser.add_argument("--soc-version", type=str, default="Ascend910B4", help="NPU version") + parser.add_argument("--device-id", type=int, default=0) + parser.add_argument("--text-max-batch", type=int, default=80) + parser.add_argument("--img-max-batch", type=int, default=8) + parser.add_argument( + "--max-token-len", + type=int, + default=52, + help="The padded length of input text (include [CLS] & [SEP] tokens)." + ) + parser.add_argument( + "--model-version", + default="ViT-B-16", + choices=["ViT-B-16", "ViT-L-14", "ViT-L-14-336", "ViT-H-14", "RN50"], + help="Specify the architecture of CLIP model to be converted." + ) + parser.add_argument( + "--hf-model-path", + default="/Path/to/Huggingface_model_path", + type=str, + help="Huggingface CLIP Model Path." + ) + parser.add_argument( + "--precision", + default="fp16", + choices=["fp16", "fp32"], + help="Specify the precision of CLIP model to be converted." + ) + parser.add_argument("--save-dir", type=str, default="./", help="Path to save the exported model") + + return parser.parse_args() + + +if __name__ == "__main__": + compile_args = parse_args() + compile_clip(compile_args) \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/multimodal/CLIP/compile_ts.py b/MindIE/MindIE-Torch/built-in/multimodal/CLIP/compile_ts.py new file mode 100644 index 0000000000..74f6c1892e --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/multimodal/CLIP/compile_ts.py @@ -0,0 +1,147 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
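+#
+# compile_ts.py traces the Hugging Face CLIP/ChineseCLIP model with torch.jit.trace
+# and compiles it into a static-shape TorchScript model (.ts) with MindIE Torch.
+# Example invocation (paths are placeholders; see README step 4):
+#   python ./compile_ts.py --soc-version "Ascend910B4" --device-id 0 \
+#       --text-batch 5 --img-batch 1 --token-len 52 --model-version "ViT-B-16" \
+#       --hf-model-path /path/to/chinese-clip-vit-base-patch16 --precision "fp16"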
+import os +import time +import json +import logging +import argparse +import torch +import mindietorch +import torch.nn as nn +from mindietorch import _enums +from transformers.models.auto.modeling_auto import AutoModel + +logging.basicConfig(level=logging.INFO) + + +class CLIPWrapper(nn.Module): + def __init__(self, clip): + super(CLIPWrapper, self).__init__() + self.model = clip + self.logit_scale = clip.logit_scale.exp() + self.logit_scale.to(self.model.device) + + def forward(self, input_ids, pixel_values, attention_mask): + image_embeds = self.model.get_image_features(pixel_values) + text_embeds = self.model.get_text_features(input_ids, attention_mask) + + image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True) + text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True) + + logits_per_image = image_embeds @ text_embeds.transpose(1, 0).contiguous() * self.logit_scale + logits_per_text = logits_per_image.transpose(1, 0).contiguous() + + return image_embeds, text_embeds, logits_per_text, logits_per_image + + +def compile_clip(args): + # 加载Pytorch模型 + with torch.no_grad(): + torch_model = AutoModel.from_pretrained(args.hf_model_path).float().eval() + torch_model = CLIPWrapper(torch_model) + + hf_config_path = os.path.join(args.hf_model_path, "config.json") + if not os.path.exists(hf_config_path): + raise FileNotFoundError(f"config.json not found at {args.hf_model_path}: {hf_config_path}") + with open(hf_config_path, "r") as f: + config_dict = json.load(f) + + # 构造模型输入 + image_width = config_dict["vision_config"]["image_size"] + pixel_values_shape = (args.img_batch, 3, image_width, image_width) + input_ids_shape = (args.text_batch, args.token_len) + pixel_values = torch.ones(pixel_values_shape, dtype=torch.float32) + input_ids = torch.randint(high=1, size=input_ids_shape, dtype=torch.int32) + attention_mask = torch.ones_like(input_ids, dtype=torch.int32) + + # 导出ts格式模型并执行MindIE编译 + input_data = [input_ids, pixel_values, attention_mask] + logging.info("Starting to trace clip ...") + intermediate_model = torch.jit.trace(torch_model, input_data) + logging.info("Successfully trace clip!") + + mindietorch.set_device(args.device_id) + + # 执行MindIETorch编译 + compile_inputs = [ + mindietorch.Input(shape=input_ids_shape, dtype=torch.int32), + mindietorch.Input(shape=pixel_values_shape, dtype=torch.float32), + mindietorch.Input(shape=input_ids_shape, dtype=torch.int32) + ] + if args.precision == "fp16": + model_precision = _enums.PrecisionPolicy.FP16 + elif args.precision == "fp32": + model_precision = _enums.PrecisionPolicy.FP32 + else: + raise ValueError("Unsupported precision type!") + + logging.info("Starting to compile mindietorch clip ...") + ts = time.time() + compiled_model = mindietorch.compile( + intermediate_model, + inputs=compile_inputs, + precision_policy=model_precision, + soc_version=args.soc_version, + ) + compile_cost = time.time() - ts + logging.info("compile time cost: %f", compile_cost) + logging.info("Successfully exported mindietorch clip!") + + logging.info("Starting to save ...") + model_save_dir = f"{args.save_dir}" + if not os.path.exists(model_save_dir): + os.makedirs(model_save_dir) + compiled_file_name = f"CLIP-{args.model_version}-MindIE.ts" + compiled_model.save(model_save_dir + compiled_file_name) + logging.info("Saving done!") + + +def parse_args(): + parser = argparse.ArgumentParser(description="Compile Clip model") + parser.add_argument("--soc-version", default="Ascend910B4", help="NPU version") + parser.add_argument("--device-id", 
type=int, default=0) + parser.add_argument("--text-batch", type=int, default=80) + parser.add_argument("--img-batch", type=int, default=1) + parser.add_argument( + "--token-len", + type=int, + default=52, + help="The padded length of input text (include [CLS] & [SEP] tokens)." + ) + parser.add_argument( + "--model-version", + default="ViT-L-14", + choices=["ViT-B-16", "ViT-L-14", "ViT-L-14-336", "ViT-H-14", "RN50"], + help="Specify the architecture of CLIP model to be converted." + ) + parser.add_argument( + "--hf-model-path", + default="/Path/to/Huggingface_model_path", + type=str, + help="Huggingface CLIP Model Path." + ) + parser.add_argument( + "--precision", + default="fp16", + choices=["fp16", "fp32"], + help="Specify the precision of CLIP model to be converted." + ) + parser.add_argument("--save-dir", type=str, default="./", help="Path to save the exported model") + + return parser.parse_args() + + +if __name__ == "__main__": + compile_args = parse_args() + compile_clip(compile_args) \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/multimodal/CLIP/export_onnx.py b/MindIE/MindIE-Torch/built-in/multimodal/CLIP/export_onnx.py new file mode 100644 index 0000000000..0b5897cb28 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/multimodal/CLIP/export_onnx.py @@ -0,0 +1,126 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
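+#
+# export_onnx.py wraps the Hugging Face CLIP/ChineseCLIP model and exports it as an
+# ONNX model with dynamic text batch, image batch, and sequence length axes.
+# Example invocation (paths are placeholders; see README step 3):
+#   python ./export_onnx.py --text-max-batch 80 --img-max-batch 8 --max-token-len 52 \
+#       --model-version "ViT-B-16" --hf-model-path /path/to/chinese-clip-vit-base-patch16 \
+#       --save-dir "./"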
+import os +import json +import logging +import argparse +import torch +import torch.onnx +import torch.nn as nn +from transformers.models.auto.modeling_auto import AutoModel + +logging.basicConfig(level=logging.INFO) + + +class CLIPWrapper(nn.Module): + def __init__(self, clip): + super(CLIPWrapper, self).__init__() + self.model = clip + self.logit_scale = clip.logit_scale.exp() + self.logit_scale.to(self.model.device) + + def forward(self, input_ids, pixel_values, attention_mask): + image_embeds = self.model.get_image_features(pixel_values) + text_embeds = self.model.get_text_features(input_ids, attention_mask) + + image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True) + text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True) + + logits_per_image = image_embeds @ text_embeds.transpose(1, 0).contiguous() * self.logit_scale + logits_per_text = logits_per_image.transpose(1, 0).contiguous() + + return image_embeds, text_embeds, logits_per_text, logits_per_image + + +def export_onnx(args): + + # 加载Pytorch模型 + with torch.no_grad(): + torch_model = AutoModel.from_pretrained(args.hf_model_path).float().eval() + torch_model = CLIPWrapper(torch_model) + + hf_config_path = os.path.join(args.hf_model_path, "config.json") + if not os.path.exists(hf_config_path): + raise FileNotFoundError(f"config.json not found at {args.hf_model_path}: {hf_config_path}") + with open(hf_config_path, "r") as f: + config_dict = json.load(f) + + # 构造模型输入 + image_width = config_dict["vision_config"]["image_size"] + img_input_shape = (1, 3, image_width, image_width) + text_input_shape = (3, args.max_token_len) + input_img = torch.ones(img_input_shape, dtype=torch.float32) + input_ids = torch.randint(high=1, size=text_input_shape, dtype=torch.int32) + attention_mask = torch.ones_like(input_ids, dtype=torch.int32) + torch_model(input_ids, input_img, attention_mask) + + # 导出onnx模型 + file_name = f"CLIP-{args.model_version}.onnx" + model_save_dir = args.save_dir + file_name + logging.info("Starting to export dynamic onnx ...") + text_batch_size = "text_batch_size" + image_batch_size = "image_batch_size" + seq_len = "seq_len" + torch.onnx.export( + torch_model, + (input_ids, input_img, attention_mask), + model_save_dir, + input_names=['input_ids', "pixel_values", "attention_mask"], + output_names=['image_embeds', "text_embeds", "logits_per_text", "logits_per_image"], + export_params=True, + opset_version=13, + verbose=True, + dynamic_axes={ + "input_ids":{0: text_batch_size, 1: seq_len}, + "pixel_values":{0: image_batch_size}, + "attention_mask":{0: text_batch_size, 1: seq_len}, + "image_embeds":{0: image_batch_size}, + "text_embeds":{0: text_batch_size}, + "logits_per_text":{0: text_batch_size, 1: image_batch_size}, + "logits_per_image":{0: image_batch_size, 1:text_batch_size}, + } + ) + logging.info("Successfully exported dynamic onnx!") + + +def parse_args(): + parser = argparse.ArgumentParser(description="Compile Clip model") + parser.add_argument("--text-max-batch", type=int, default=80) + parser.add_argument("--img-max-batch", type=int, default=8) + parser.add_argument( + "--max-token-len", + type=int, + default=52, + help="The padded length of input text (include [CLS] & [SEP] tokens)." + ) + parser.add_argument( + "--model-version", + default="ViT-B-16", + choices=["ViT-B-16", "ViT-L-14", "ViT-L-14-336", "ViT-H-14", "RN50"], + help="Specify the architecture of CLIP model to be converted." 
+ ) + parser.add_argument( + "--hf-model-path", + default="/Path/to/Huggingface_model_path", + type=str, + help="Huggingface CLIP Model Path." + ) + parser.add_argument("--save-dir", type=str, default="./", help="Path to save the exported model") + + return parser.parse_args() + + +if __name__ == "__main__": + compile_args = parse_args() + export_onnx(compile_args) \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/multimodal/CLIP/perf_test_aie.py b/MindIE/MindIE-Torch/built-in/multimodal/CLIP/perf_test_aie.py new file mode 100644 index 0000000000..f268c4680d --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/multimodal/CLIP/perf_test_aie.py @@ -0,0 +1,99 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import json +import logging +import argparse +import time +import torch +import mindietorch + +logging.basicConfig(level=logging.INFO) + + +def test(inputs, model, stream, meta=""): + # warmup + for _ in range(10): + with mindietorch.npu.stream(stream): + model(*inputs) + stream.synchronize() + + # performance test + num_infer = 100 + start = time.time() + for _ in range(num_infer): + with mindietorch.npu.stream(stream): + model(*inputs) + stream.synchronize() + end = time.time() + + logging.info("%s latency: %.2f ms", meta, (end - start) / num_infer * 1000) + logging.info("%s throughput: %.2f fps", meta, num_infer / (end - start)) + + +def test_clip(args): + device = f'npu:{args.device_id}' + stream = mindietorch.npu.Stream(device) + if args.clip_aie_path.endswith(".ts"): + model = torch.jit.load(args.clip_aie_path) + else: + model = torch.load(args.clip_aie_path) + model.eval().to(device) + + hf_config_path = os.path.join(args.hf_model_path, "config.json") + if not os.path.exists(hf_config_path): + raise FileNotFoundError(f"config.json not found at {args.hf_model_path}: {hf_config_path}") + with open(hf_config_path, "r") as f: + config_dict = json.load(f) + + image_width = config_dict["vision_config"]["image_size"] + img_input_shape = (args.image_batchsize, 3, image_width, image_width) + text_input_shape = (args.text_batchsize, args.token_len) + input_img = torch.randn(img_input_shape, dtype=torch.float32).to(device) + input_ids = torch.randint(high=1000, size=text_input_shape, dtype=torch.int32).to(device) + attention_mask = torch.ones(text_input_shape, dtype=torch.int32).to(device) + inputs = [input_ids, input_img, attention_mask] + + test(inputs, model, stream, "CLIP") + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--device-id", type=int, help="NPU device id", default=0) + parser.add_argument( + "--clip-aie-path", + type=str, + default="/Path/to/compiled/aie_or_ts_model" + ) + parser.add_argument( + "--hf-model-path", + default="/Path/to/Huggingface_model_path", + type=str, + help="Huggingface CLIP Model Path." 
+ ) + parser.add_argument("--text-batchsize", type=int, default=80) + parser.add_argument("--image-batchsize", type=int, default=1) + parser.add_argument("--token-len", type=int, default=52) + + return parser.parse_args() + + +def main(): + perf_args = parse_args() + mindietorch.set_device(perf_args.device_id) + test_clip(perf_args) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/multimodal/CLIP/perf_test_onnx.py b/MindIE/MindIE-Torch/built-in/multimodal/CLIP/perf_test_onnx.py new file mode 100644 index 0000000000..106b7b87d2 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/multimodal/CLIP/perf_test_onnx.py @@ -0,0 +1,97 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import json +import logging +import argparse +import time +import torch +import onnxruntime as ort + +logging.basicConfig(level=logging.INFO) + + +def test(encoder_path, provider, output_names, onnx_inputs, meta=""): + onnx_model = ort.InferenceSession( + encoder_path, + providers=[provider] + ) + + # warmup + for _ in range(10): + onnx_model.run(output_names, onnx_inputs) + # performance test + num_infer = 100 + start = time.time() + for _ in range(num_infer): + onnx_model.run(output_names, onnx_inputs) + end = time.time() + + logging.info("%s latency: %.2f ms", meta, (end - start) / num_infer * 1000) + logging.info("%s throughput: %.2f fps", meta, num_infer / (end - start)) + + +def test_clip(args, provider): + hf_config_path = os.path.join(args.hf_model_path, "config.json") + if not os.path.exists(hf_config_path): + raise FileNotFoundError(f"config.json not found at {args.hf_model_path}: {hf_config_path}") + with open(hf_config_path, "r") as f: + config_dict = json.load(f) + + image_width = config_dict["vision_config"]["image_size"] + img_input_shape = (args.image_batchsize, 3, image_width, image_width) + text_input_shape = (args.text_batchsize, args.token_len) + input_img = torch.randn(img_input_shape, dtype=torch.float32).detach().numpy() + input_ids = torch.randint(high=1000, size=text_input_shape, dtype=torch.int32).detach().numpy() + attention_mask = torch.ones(text_input_shape, dtype=torch.int32).detach().numpy() + + onnx_inputs = {"input_ids": input_ids, "pixel_values": input_img, "attention_mask": attention_mask} + output_names = ["image_embeds", "text_embeds", "logits_per_text", "logits_per_image"] + + test(args.onnx_path, provider, output_names, onnx_inputs, "CLIP") + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--onnx-path", + type=str, + default="/Path/to/onnx_model" + ) + parser.add_argument( + "--hf-model-path", + default="/Path/to/Huggingface_model_path", + type=str, + help="Huggingface CLIP Model Path." 
+ ) + parser.add_argument("--text-batchsize", type=int, default=80) + parser.add_argument("--image-batchsize", type=int, default=1) + parser.add_argument("--token-len", type=int, default=52) + parser.add_argument("--use-gpu", action="store_true") + + return parser.parse_args() + + +def main(): + perf_args = parse_args() + if perf_args.use_gpu: + provider = "CUDAExecutionProvider" + else: + provider = "CPUExecutionProvider" + + test_clip(perf_args, provider) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/multimodal/CLIP/precision_test.py b/MindIE/MindIE-Torch/built-in/multimodal/CLIP/precision_test.py new file mode 100644 index 0000000000..8a46d0e965 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/multimodal/CLIP/precision_test.py @@ -0,0 +1,134 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import json +import logging +import argparse +import torch +import mindietorch +import torch +import onnxruntime as ort +import numpy as np +import torch.nn.functional as F + +logging.basicConfig(level=logging.INFO) + + +def compare_onnx_aie_output(onnx_out, aie_out, sim_threshold=0.99): + num_sim = 0 + for i, (a, b) in enumerate(zip(onnx_out, aie_out)): + a = a.reshape(1, -1).astype(np.float32) + b = b.reshape(1, -1) + sim = F.cosine_similarity(torch.from_numpy(a), b, dim=1) + if sim > sim_threshold: + num_sim += 1 + else: + logging.info('Output %d similarity: %f', i, sim) + + logging.info('Number of outputs to compare: %d', len(onnx_out)) + logging.info('Number of outputs with cosine similarity > %.2f: %d', sim_threshold, num_sim) + + +def compare(args): + # MindIETorch + device = f'npu:{args.device_id}' + stream = mindietorch.npu.Stream(device) + + if args.clip_aie_path.endswith(".ts"): + aie_model = torch.jit.load(args.clip_aie_path) + else: + aie_model = torch.load(args.clip_aie_path) + aie_model.eval().to(device) + + hf_config_path = os.path.join(args.hf_model_path, "config.json") + if not os.path.exists(hf_config_path): + raise FileNotFoundError(f"config.json not found at {args.hf_model_path}: {hf_config_path}") + with open(hf_config_path, "r") as f: + config_dict = json.load(f) + + image_width = config_dict["vision_config"]["image_size"] + img_input_shape = (args.image_batchsize, 3, image_width, image_width) + text_input_shape = (args.text_batchsize, args.token_len) + input_img = torch.randn(img_input_shape, dtype=torch.float32).to(device) + input_ids = torch.randint(high=1000, size=text_input_shape, dtype=torch.int32).to(device) + attention_mask = torch.ones(text_input_shape, dtype=torch.int32).to(device) + inputs = [input_ids, input_img, attention_mask] + + with mindietorch.npu.stream(stream): + aie_out = aie_model(*inputs) + stream.synchronize() + + if isinstance(aie_out, tuple) or isinstance(aie_out, list): + aie_out = (x.cpu() for x in aie_out) + else: + aie_out = aie_out.cpu() + + # ONNX + input_img = input_img.cpu().detach().numpy() + input_ids = 
input_ids.cpu().detach().numpy() + attention_mask = attention_mask.cpu().detach().numpy() + + if args.use_gpu: + provider = "CUDAExecutionProvider" + else: + provider = "CPUExecutionProvider" + + onnx_model = ort.InferenceSession( + args.clip_onnx_path, + providers=[provider] + ) + onnx_inputs = {"input_ids": input_ids, "pixel_values": input_img, "attention_mask": attention_mask} + output_names = ["image_embeds", "text_embeds", "logits_per_text", "logits_per_image"] + onnx_out = onnx_model.run(output_names, onnx_inputs) + + compare_onnx_aie_output(onnx_out, aie_out, args.sim_threshold) + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--device-id", type=int, default=0, help="NPU device id") + parser.add_argument( + "--clip-aie-path", + type=str, + default="/Path/to/compiled/aie_or_ts_model" + ) + parser.add_argument( + "--clip-onnx-path", + type=str, + default="/Path/to/onnx_model" + ) + parser.add_argument( + "--hf-model-path", + default="/Path/to/Huggingface_model_path", + type=str, + help="Huggingface CLIP Model Path." + ) + parser.add_argument("--text-batchsize", type=int, default=80) + parser.add_argument("--image-batchsize", type=int, default=1) + parser.add_argument("--token-len", type=int, default=52) + parser.add_argument('--sim-threshold', type=float, default=0.99) + parser.add_argument("--use-gpu", action="store_true") + + return parser.parse_args() + + +def main(): + compare_args = parse_args() + mindietorch.set_device(compare_args.device_id) + logging.info('=== Compare the outputs of ONNX and AIE ===') + compare(compare_args) + + +if __name__ == "__main__": + main() \ No newline at end of file -- Gitee