diff --git a/MindIE/MindIE-Torch/built-in/multimodal/CLIP/README.md b/MindIE/MindIE-Torch/built-in/multimodal/CLIP/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..a103850620081c4f2d9299abd97e040dfd6c2e97
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/multimodal/CLIP/README.md
@@ -0,0 +1,261 @@
+# CLIP-推理指导
+
+
+- [概述](#ZH-CN_TOPIC_0000001172161501)
+
+- [推理环境准备](#ZH-CN_TOPIC_0000001126281702)
+
+- [快速上手](#ZH-CN_TOPIC_0000001126281700)
+
+- [模型推理性能精度](#ZH-CN_TOPIC_0000001172201573)
+
+
+# 概述
+
+([来自开源代码仓](https://github.com/openai/CLIP))CLIP使用大规模(图片、文本)数据对进行训练,旨在帮助用户快速实现图文特征&相似度计算、跨模态检索、零样本图片分类等任务。该项目旨在演示在多判断分支、使用hook机制的场景下,如何使用MindIE组件对主要计算模块进行加速。端到端推理演示通过命令行调用的形式呈现,每次都要花费很长时间载入模型,且各主要计算模块之间使用cpu计算串联,不适合直接在业务中应用。开发者可以根据该项目代码,配套Torch_NPU,开发出高性能的图文识别服务。
+
+# 推理环境准备[所有版本]
+- 该模型需要以下插件与驱动
+
+ **表 1** 版本配套表
+ | 配套 | 版本 |
+ |---------| ------- |
+ | 固件与驱动 | - |
+ | CANN | - |
+ | Python | 3.10.13 |
+ | PyTorch | 2.1.0 |
+ | MindIE | - |
+
+ 注意:由于MindIE暂无支持该模型的商发版本,烦请用户联系华为工程师获取对应的固件驱动,CANN,MindIE PoC版本链接。
+ 固件驱动和CANN的安装,请参考昇腾官方文档[环境快速部署](https://www.hiascend.com/document/detail/zh/quick-installation/24.0.RC1/quickinstg/800_3000/quickinstg_800_3000_0001.html)。
+
+ MindIE的安装需要先source toolkit的环境变量,然后直接安装,以默认安装路径`/usr/local/Ascend`为例:
+ ```
+ source /usr/local/Ascend/ascend-tookit/set_env.sh
+ bash Ascend-mindie_*.run --install
+ ```
+
+
+# 快速上手
+1. 安装依赖
+ ```
+ pip install transformers==4.41.1
+ ```
+2. 模型下载-本项目支持clip-zh:VIT-B-16以及clip-en:VIT-L/14
+ ```
+ CLIP:
+ git clone https://huggingface.co/openai/clip-vit-large-patch14-336
+ ChineseCLIP:
+ git clone https://huggingface.co/OFA-Sys/chinese-clip-vit-base-patch16
+ ```
+3. 导出onnx模型-以ChineseCLIP为例
+ ```shell
+ cd CLIP/
+ hf_model_path="/path/to/chinese-clip-vit-base-patch16"
+ python ./export_onnx.py \
+ --text-max-batch 80 \
+ --img-max-batch 8 \
+ --max-token-len 52 \
+ --model-version "ViT-B-16" \
+ --hf-model-path $hf_model_path \
+ --save-dir "./"
+ ```
+ 执行结束后,会在save-dir目录下生成CLIP-ViT-B-16.onnx
+
+ 参数说明:
+ - --hf-model-path:下载的Huggingface模型路径。
+ - --soc_version:芯片类型,当前仅在Ascend910B4上调试。
+ - --text-max-batch:最大文本batchsize,默认最大值为80。
+ - --img-max-batch:最大图片batchsize,默认最大值为8。
+ - --max-token-len:文本token最大嵌入长度,默认最大值为52。
+ - --model-version:CLIP版本。
+ - 其他参数请参考脚本`parse_args`部分。
+
+4. 编译静态输入ts模型
+ ```shell
+ cd CLIP/
+ hf_model_path=/path/to/chinese-clip-vit-base-patch16
+ python ./compile_ts.py \
+ --soc-version "Ascend910B4" \
+ --device-id 0 \
+ --text-batch 5 \
+ --img-batch 1 \
+ --token-len 52 \
+ --model-version "ViT-B-16" \
+ --hf-model-path $hf_model_path \
+ --precision "fp16" \
+ --save-dir "./"
+ ```
+ 执行结束后,会在save-dir目录下生成CLIP_{--model-version}-MindIE.ts
+
+ 参数说明:
+ - --hf-model-path:下载的Huggingface模型路径。
+ - --soc_version: 芯片类型,当前仅在Ascend910B4上调试。
+ - --text-batch: 固定输入的文本batchsize,默认值为80。
+ - --img-max-batch: 固定输入的图片batchsize,默认值为1。
+ - --token-len: 固定的文本token嵌入长度,默认值为52。
+ - --model-version: CLIP版本。
+ - 其他参数请参考脚本`parse_args`部分。
+
+5. 编译动态输入aie模型(可选)
+ ```shell
+ cd chiness_clip/
+ hf_model_path="/path/to/huggingface_clip_model"
+ python ./compile_aie.py \
+ --soc-version "Ascend910B4" \
+ --device-id 0 \
+ --text-max-batch 80 \
+ --img-max-batch 8 \
+ --max-token-len 52 \
+ --model-version "ViT-B-16" \
+ --hf-model-path $hf_model_path \
+ --precision "fp16" \
+ --save-dir "./"
+ ```
+ 执行结束后,会在save-dir目录下生成CLIP_{--model-version}-MindIE.pt
+
+ 参数说明:
+ - --hf-model-path:下载的Huggingface模型路径。
+ - --soc_version: 芯片类型,当前仅在Ascend910B4上调试。
+ - --text-max-batch: 最大文本batchsize,默认最大值为80。
+ - --img-max-batch: 最大图片batchsize,默认最大值为8。
+ - --max-token-len: 文本token最大嵌入长度,,默认最大值为52。
+ - --model-version: CLIP版本。
+ - 其他参数请参考脚本`parse_args`部分。
+
+
+6. 模型推理
+ ```shell
+ aie_path="/path/to/ts_model"
+ hf_model_path="/path/to huggingface_clip_model"
+ image_path="/path/to/image"
+ python3 clip_infer.py \
+ --device-id 0 \
+ --clip-aie-path $aie_path \
+ --hf-model-path $hf_model_path \
+ --max-token-len 52 \
+ --image-path $image_path \
+ --text "杰尼龟" "妙蛙种子" "小火龙" "皮卡丘" "image内容"
+ ```
+ 推理结束后,会在命令行打印出类似如下输出:
+ ```
+ Probs per image: [[0.00939221866428852, 0.03177346661686897, 0.010315335355699062, 0.07622058689594269, 0.8722984194755554]]
+ ```
+ 参数说明:
+ - --hf-model-path: 所下载的Huggingface模型文件夹路径。
+ - --clip-aie-path:第4步或第5步编译得到的模型路径。clip-aie-path的模型必须与hf-model-path中的模型类型一致。
+ - --image-path: 图片路径,当前仅支持单张图片推理。
+ - --text:所需要的输入形式如:List[str, ]形式,例如:["杰尼龟", "妙蛙种子", "小火龙", "皮卡丘"]。其中每一个字段代表一个text_batch,如果导出的是静态ts模型,确保text_batch维度和编译时设置的text_batch输入维度保持一致。
+ - 其他参数请参考脚本`parse_args`部分。
+
+
+# 模型推理性能精度
+
+1. 精度验证
+ ```shell
+ hf_model_path="/path/to huggingface_clip_model"
+ aie_path="/path/to/ts_model"
+ onnx_path="/path/to/onnx_model"
+ python3 precision_test.py \
+ --device-id 0 \
+ --clip-aie-path $aie_path \
+ --clip-onnx-path $onnx_path \
+ --hf-model-path $hf_model_path \
+ --text-batchsize 5 \
+ --image-batchsize 1 \
+ --token-len 52 \
+ --sim-threshold 0.99
+ ```
+
+ 参数说明:
+ - --hf-model-path: 下载的Huggingface模型文件夹路径。
+ - --clip-aie-path:由hf-model-path模型经过第4步或第5步编译得到的模型路径。
+ - --clip-onnx-path:onnx_export.py文件导出的onnx模型路径。
+ - --sim-threshold: 余弦相似度阈值,默认0.99。
+ - --text-batchsize: dummy文本输入的batchsize,默认10,对于ts模型这里的输入维度需要与编译时设置的输入维度一致。
+ - --token-len:dummy文本输入的token长度,默认为20,对于ts模型这里的输入维度需要与编译时设置的输入维度一致。
+ - --image-batchsize: dummy图片输入的batchsize,默认为1,对于ts模型这里的输入维度需要与编译时设置的输入维度一致。
+ - 其他参数请参考脚本`parse_args`部分。
+
+ 执行结束后,期望输出如下:
+ ```
+ === Compare the outputs of ONNX and AIE ===
+ Start comparing clip...
+ Number of outputs to compare: 4
+ Number of outputs with cosine similarity > 0.99: 4
+ ```
+
+2. 性能验证
+
+ a) aie模型性能测试-以VIT-B-16为例
+ ```shell
+ aie_path="/path/to/ts_model or pt model"
+ hf_model_path="/path/tohuggingface_clip_model"
+ python3 perf_test_aie.py \
+ --device-id 0 \
+ --clip-aie-path $aie_path \
+ --hf-model-path $hf_model_path \
+ --text-batchsize 5 \
+ --token-len 52 \
+ --image-batchsize 1
+ ```
+
+ 参数说明:
+ - --hf-model-path: Huggingface模型文件夹路径。
+ - --clip-aie-path:由hf-model-path模型经过第4步或第5步编译得到的模型路径。
+ - --text-batchsize: dummy文本输入的batchsize,默认10。
+ - --token-len:dummy文本输入的token长度,默认为20。
+ - --image-batchsize: dummy图片输入的batchsize,默认为1。
+ - 其他参数请参考脚本`parse_args`部分。
+
+ 执行结束后,期望输出如下:
+ ```
+ CLIP latency: 40.47 ms
+ CLIP throughput: 24.71 fps
+ ```
+
+ b) onnx模型性能测试
+ (可选)若使用GPU,请确保已安装CUDA和pytorch-gpu版本,同时需安装onnxruntime-gpu,如下所示:
+ ```shell
+ pip uninstall onnxruntime
+ pip install onnxruntime-gpu
+ ```
+ 验证onnxruntime-gpu是否安装成功:
+ ```python
+ import onnxruntime
+ print(onnxruntime.get_device()) # 若输出为GPU,则说明安装成功
+ ```
+ 执行性能测试
+ ```shell
+ onnx_path="/path/to/onnx_model"
+ hf_model_path="/path/tohuggingface_clip_model"
+ python perf_test_onnx.py \
+ --onnx-path $onnx_path \
+ --hf-model-path $hf_model_path \
+ --text-batchsize 5 \
+ --token-len 52 \
+ --image-batchsize 1
+ ```
+
+ 参数说明:
+ - --use-gpu: 使能gpu推理,不加该选项默认cpu。
+ - --hf-model-path: Huggingface模型文件夹路径。
+ - --clip-onnx-path:hf-model-path经过第3步导出的onnx模型路径。
+ - --text-batchsize: dummy文本输入的batchsize,默认10。
+ - --token-len:dummy文本输入的token长度,默认为20。
+ - --image-batchsize: dummy图片输入的batchsize,默认为1。
+ - 其他参数请参考脚本`parse_args`部分。
+
+ 执行结束后,期望输出如下:
+ ```
+ CLIP latency: 181.08 ms
+ CLIP throughput: 5.51 fps
+ ```
+
+
+ | 模型 | pt插件 - 910B4性能(时延/吞吐率) |
+ |---------|--------------------------------|
+ | Chinese-CLIP-ViT-Base-Patch16 | 40.47 ms / 24.71 fps |
+ | clip-vit-large-patch14-336 | 47.73 ms / 21.12 fps |
+
+ 注:该性能是ts模型在text-batch=5, image-batch=1, token-len=52设置下测试10次取平均值得到,以上数据仅作参考。
\ No newline at end of file
diff --git a/MindIE/MindIE-Torch/built-in/multimodal/CLIP/clip_infer.py b/MindIE/MindIE-Torch/built-in/multimodal/CLIP/clip_infer.py
new file mode 100644
index 0000000000000000000000000000000000000000..4157194111837e872a274883b620d7406998bb4c
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/multimodal/CLIP/clip_infer.py
@@ -0,0 +1,99 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+import argparse
+import torch
+import mindietorch
+from PIL import Image
+import torch.nn.functional as F
+from transformers import AutoProcessor
+
+logging.basicConfig(level=logging.INFO)
+
+
+def inference(args):
+ device = f'npu:{args.device_id}'
+ stream = mindietorch.npu.Stream(device)
+ if args.clip_aie_path.endswith(".ts"):
+ aie_model = torch.jit.load(args.clip_aie_path)
+ else:
+ aie_model = torch.load(args.clip_aie_path)
+ aie_model.eval().to(device)
+
+ processor = AutoProcessor.from_pretrained(args.hf_model_path)
+ inputs = processor(text=args.text, images=Image.open(args.image_path), return_tensors="pt", padding=True)
+ input_ids = inputs.input_ids.to(torch.int32)
+ attention_mask = inputs.attention_mask.to(torch.int32)
+ pad_length = args.max_token_len
+ cur_length = inputs.input_ids.size(-1)
+ if pad_length > cur_length:
+ input_ids = F.pad(input_ids, (0, pad_length - cur_length), value=0)
+ attention_mask = F.pad(attention_mask, (0, pad_length - cur_length), value=0)
+
+ input_ids = input_ids.to(torch.int32).to(device)
+ attention_mask = attention_mask.to(torch.int32).to(device)
+ input_img = inputs.pixel_values.to(torch.float32).to(device)
+ inputs = [input_ids, input_img, attention_mask]
+
+ with mindietorch.npu.stream(stream):
+ aie_out = aie_model(*inputs)
+ stream.synchronize()
+
+ aie_out = [x.cpu() for x in aie_out]
+ logging.info("Probs per image: %s", aie_out[-1].softmax(dim=-1).tolist())
+
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--device-id", type=int, help="NPU device id", default=0)
+ parser.add_argument(
+ "--clip-aie-path",
+ type=str,
+ default="/Path/to/compiled/aie_or_ts_model"
+ )
+ parser.add_argument(
+ "--hf-model-path",
+ default="/Path/to/Huggingface_model_path",
+ type=str,
+ help="Huggingface CLIP Model Path."
+ )
+ parser.add_argument(
+ "--max-token-len",
+ type=int,
+ default=52,
+ help="Manually pad input ids to max-token-len"
+ )
+ parser.add_argument(
+ "--image-path",
+ type=str,
+ default="/Path/to/image"
+ )
+ parser.add_argument(
+ "--text",
+ type=str,
+ nargs='+',
+ default=["杰尼龟", "妙蛙种子", "小火龙", "皮卡丘", "image内容"],
+ )
+
+ return parser.parse_args()
+
+
+def main():
+ infer_args = parse_args()
+ mindietorch.set_device(infer_args.device_id)
+ inference(infer_args)
+
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file