diff --git a/MindIE/MindIE-Torch/built-in/multimodal/EVA_CLIP/README.md b/MindIE/MindIE-Torch/built-in/multimodal/EVA_CLIP/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..b4eb84b5038d5774fb42003a89d14c9ede31b4ad
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/multimodal/EVA_CLIP/README.md
@@ -0,0 +1,130 @@
+# EVA-CLIP Inference Guide
+
+
+- [Overview](#overview)
+
+- [Inference Environment Preparation](#inference-environment-preparation-all-versions)
+
+- [Quick Start](#quick-start)
+
+- [Model Inference Accuracy](#model-inference-accuracy)
+
+
+# Overview
+
+([From the open-source repository](https://github.com/baaivision/EVA.git)) EVA-CLIP is a family of models that significantly improves the training efficiency and effectiveness of CLIP (Contrastive Language-Image Pre-training). By combining novel techniques for representation learning, optimization, and augmentation, EVA-CLIP achieves better performance than previous CLIP models with the same number of parameters while substantially reducing training cost. This project demonstrates how to run EVA-02-CLIP inference with the MindIE components.
+
+
+# Inference Environment Preparation \[All Versions\]
+
+- This model requires the following plugins and drivers.
+
+  **Table 1** Version compatibility
+  | Component          | Version |
+  |--------------------|---------|
+  | Firmware & Drivers | -       |
+  | CANN               | -       |
+  | Python             | 3.10.13 |
+  | PyTorch            | 2.1.0   |
+  | MindIE             | -       |
+
+Note: MindIE does not yet have a commercial release that supports this model, so please contact a Huawei engineer to obtain download links for the matching firmware/driver, CANN, and MindIE PoC versions.
+For installing the firmware/driver and CANN, refer to the official Ascend documentation [Quick Environment Deployment](https://www.hiascend.com/document/detail/zh/quick-installation/24.0.RC1/quickinstg/800_3000/quickinstg_800_3000_0001.html).
+
+To install MindIE, first source the toolkit environment variables and then run the installer directly. Taking the default install path `/usr/local/Ascend` as an example:
+```
+source /usr/local/Ascend/ascend-toolkit/set_env.sh
+bash Ascend-mindie_*.run --install
+```
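+
+After installation, a quick sanity check (assuming the default install path) is to source the MindIE environment and import the Python package used later in this guide:
+```
+source /usr/local/Ascend/mindie/set_env.sh
+python -c "import mindietorch; print('mindietorch OK')"
+```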
+
+
+# Quick Start
+
+1. Download the source code
+ ```
+ git clone https://github.com/baaivision/EVA.git
+ ```
+    Note: after cloning the source code, you also need to download the [pretrained model weights](https://huggingface.co/QuanSun/EVA-CLIP/tree/main); see the example below.
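+
+    One way to fetch a checkpoint (this assumes `huggingface_hub` is installed and that the filename matches your chosen model variant — check the repository listing first):
+    ```
+    pip install -U huggingface_hub
+    huggingface-cli download QuanSun/EVA-CLIP EVA02_CLIP_B_psz16_s8B.pt --local-dir ./weights
+    ```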
+
+    After downloading, the directory structure looks like this:
+ ```
+ |
+ |----EVA
+ | |----assets
+ | |----EVA-01
+ | ...
+ | |----EVA-CLIP
+ | |____...
+ |----README.md
+ |____eva_clip_export.py
+ ```
+
+    Install the third-party dependencies from the requirements file provided by the source repository:
+
+ ```
+ pip install -r ./EVA/EVA-CLIP/requirements.txt
+ ```
+
+2. Model compilation
+
+    ❗❗❗ Important: before compiling, edit the corresponding model config `.json` file and change every `"xattn"` entry to `false`. Taking the `EVA02-CLIP-bigE-14` model as an example, its config file is usually located at `EVA/EVA-CLIP/rei/eva_clip/model_configs/EVA02-CLIP-bigE-14.json`. A sketch of the required change is shown below.
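+
+    The snippet below is only a hedged sketch of the target state (the real config files contain many other keys, which must be left unchanged); the point is simply that every `"xattn"` value ends up as `false`:
+
+    ```
+    {
+        "vision_cfg": {
+            "xattn": false
+        },
+        "text_cfg": {
+            "xattn": false
+        }
+    }
+    ```
+
+    After the config has been modified, compile with the following commands: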
+
+ ```
+ source /usr/local/Ascend/ascend-toolkit/set_env.sh
+ source /usr/local/Ascend/mindie/set_env.sh
+ python eva_clip_export.py --model_name {model_name} --model_path {model_path} --save_dir {save_dir} --compile
+ ```
+
+    😊 Tip: compilation consists of two stages, `torch.jit.trace` and `mindietorch.compile`. The time required grows with model size and can be on the order of hours, so please be patient, and it is recommended to keep the compiled model weights for reuse.
+
+    When the command finishes, four files are generated under `{save_dir}`: `{model_name}-trace-MindIE-Vis.pt`, `{model_name}-trace-MindIE-Text.pt`, `{model_name}-compile-MindIE-Vis.pt`, and `{model_name}-compile-MindIE-Text.pt`. The `*-trace-MindIE-*.pt` weights can be used to check whether the traced model is correct; keep or delete them as you see fit.
+
+    Parameter description:
+    - --model_name: model architecture name; one of `["EVA02-CLIP-B-16", "EVA02-CLIP-L-14", "EVA02-CLIP-L-14-336",
+      "EVA02-CLIP-bigE-14", "EVA02-CLIP-bigE-14-plus"]`.
+    - --model_path: path to the downloaded pretrained weights.
+    - --save_dir: directory in which the compiled weights are saved.
+    - --compile: compile the model. A concrete invocation is shown below.
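+
+    A concrete invocation might look like this (the weight path is illustrative — substitute the checkpoint you actually downloaded):
+    ```
+    python eva_clip_export.py --model_name EVA02-CLIP-B-16 --model_path ./weights/EVA02_CLIP_B_psz16_s8B.pt --save_dir ./mindie_weights --compile
+    ```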
+
+3. Model inference
+ ```
+ python eva_clip_export.py --model_name {model_name} --model_path {model_path} --save_dir {save_dir} --image_path {image_path} --classes {classes} --infer
+ ```
+
+    Parameter description:
+    - --image_path: path of the test image.
+    - --classes: the candidate class labels (a list of strings, passed space-separated and individually quoted, e.g. "a diagram" "a dog" "a cat").
+    - --infer: run in inference mode.
+    - For other parameters, see the parse_args section of the script. A concrete invocation is shown below.
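+
+    The paths below are illustrative — substitute your own weights and image:
+    ```
+    python eva_clip_export.py --model_name EVA02-CLIP-B-16 --model_path ./weights/EVA02_CLIP_B_psz16_s8B.pt --save_dir ./mindie_weights --image_path ./EVA/EVA-CLIP/assets/CLIP.png --classes "a diagram" "a dog" "a cat" --infer
+    ```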
+
+
+# Model Inference Accuracy
+
+Compare the cosine similarity between the outputs of the original PyTorch model and the compiled MindIE model:
+```
+python eva_clip_export.py --model_name {model_name} --model_path {model_path} --save_dir {save_dir} --image_path {image_path} --classes {classes} --compare
+```
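+
+The reported metric is the plain cosine similarity computed by the `cos_sim` helper in `eva_clip_export.py`: both output tensors are flattened and their normalized dot product is reported, so values close to 1.0 mean the compiled model closely matches the PyTorch reference. Conceptually:
+```
+import torch
+
+# equivalent of the cos_sim helper: flatten both tensors, then take the normalized dot product
+a, b = torch.randn(1, 512), torch.randn(1, 512)
+similarity = torch.dot(a.flatten(), b.flatten()) / (a.norm() * b.norm())
+print(similarity.item())
+```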
+
+Parameter description:
+- --compare: run in comparison mode.
+
+After execution, the expected output looks like:
+```
+*******compile compare********
+>>>>>>image similarity: 0.9999822974205017>>>>>>>>>>>
+>>>>>>text similarity: 0.9999982118606567>>>>>>>>>>>>
+>>>>>>result similarity: 1.0>>>>>>>>>>>>
+```
+
+Note: the output values vary with the model architecture, test image, and test classes.
+
+**Table 2** Cosine similarity of inference results
+| Model | Result similarity |
+|---------------------|-------------------|
+| EVA02-CLIP-B-16 | 0.9877 |
+| EVA02-CLIP-L-14 | 0.9999 |
+| EVA02-CLIP-L-14-336 | 0.9999 |
+| EVA02-CLIP-L-14-224to336 | 0.9999 |
+| EVA02-CLIP-bigE-14 | 1.0 |
+| EVA02-CLIP-bigE-14-plus | 0.9999 |
+
diff --git a/MindIE/MindIE-Torch/built-in/multimodal/EVA_CLIP/eva_clip_export.py b/MindIE/MindIE-Torch/built-in/multimodal/EVA_CLIP/eva_clip_export.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e88dd74996cf46415ec6c227738f14ff758c011
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/multimodal/EVA_CLIP/eva_clip_export.py
@@ -0,0 +1,308 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
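+"""Export, compile, and validate EVA-CLIP with MindIE-Torch.
+
+Modes (choose one flag):
+  --compile   trace the vision and text towers with torch.jit.trace, then compile them with mindietorch.compile
+  --infer     run a single image/text inference with the compiled model (--infer_type aie) or the original PyTorch model (--infer_type pytorch)
+  --compare   run both paths and report the cosine similarity of their outputs
+"""
+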
+import argparse
+import os
+import sys
+import time
+from pathlib import Path
+
+import torch
+import mindietorch
+from PIL import Image
+
+current_dir = Path(__file__).resolve().parent
+sys.path.append((current_dir / 'EVA' / 'EVA-CLIP' / 'rei').as_posix())
+
+from eva_clip import create_model_and_transforms, get_tokenizer, get_model_config
+
+def cos_sim(tensor1, tensor2):
+    # Flatten the multi-dimensional tensors into 1-D vectors
+    flattened_tensor1 = torch.flatten(tensor1)
+    flattened_tensor2 = torch.flatten(tensor2)
+
+    # Compute the dot product
+    dot_product = torch.dot(flattened_tensor1, flattened_tensor2)
+
+    # Compute the Euclidean (L2) norm of each vector
+    norm_flattened_tensor1 = torch.norm(flattened_tensor1)
+    norm_flattened_tensor2 = torch.norm(flattened_tensor2)
+
+    # Cosine similarity = dot product / (product of the norms)
+    cosine_similarity_manual = dot_product / (norm_flattened_tensor1 * norm_flattened_tensor2)
+    return cosine_similarity_manual
+
+def check_and_create_directory(path):
+ """
+ Check if a directory exists at the given path and create it if it doesn't.
+ Args:
+ path (str): The path of the directory to check and create.
+ """
+ if not os.path.isdir(path):
+ # Create the directory
+ os.makedirs(path, exist_ok=True)
+
+def compile_eva_vis(args):
+    # Load the model and extract the vision tower
+ model_config = get_model_config(args.model_name)
+ model, _, _ = create_model_and_transforms(args.model_name, args.model_path, force_custom_clip=True)
+ torch_model = model.visual
+
+    # Build a dummy image input from the configured image size (used only for tracing)
+    image_size = model_config["vision_cfg"]["image_size"]
+    img_input_shape = (args.img_max_batch, 3, image_size, image_size)
+
+    input_img = torch.ones(img_input_shape, dtype=torch.float32)
+    dummy_input = input_img
+
+    # torch.jit.trace the model and save the traced graph
+ check_and_create_directory(args.save_dir)
+ traced_path = os.path.join(args.save_dir, f"{args.model_name}-trace-MindIE-Vis.pt")
+ print("Starting trace eva clip vision_tower...")
+ torch.jit.trace(torch_model, dummy_input).save(traced_path)
+ print("Trace eva clip vision_tower success!")
+
+ img_min_shape = (1, 3, image_size, image_size)
+ img_max_shape = (args.img_max_batch, 3, image_size, image_size)
+
+    # Run MindIE-Torch compilation over a dynamic batch range
+ compile_inputs = [
+ mindietorch.Input(min_shape=img_min_shape,
+ max_shape=img_max_shape, dtype=torch.float32)
+ ]
+
+ print("start mindietorch compile eva clip vision_tower...")
+ intermediate_model = torch.jit.load(traced_path).eval()
+ ts = time.time()
+ compiled_model = mindietorch.compile(
+ intermediate_model,
+ inputs=compile_inputs,
+ precision_policy=mindietorch._enums.PrecisionPolicy.FP16,
+ soc_version=args.soc_version,
+ )
+ compile_cost = time.time() - ts
+ print("compile time cost: ", compile_cost)
+ print("end mindietorch compile eva clip vision_tower.")
+
+ print("start saving eva clip vision_tower")
+ compile_path = os.path.join(args.save_dir, f"{args.model_name}-compile-MindIE-Vis.pt")
+ torch.jit.save(compiled_model, compile_path)
+ print("saving done eva clip vision_tower")
+
+def compile_eva_text(args):
+    # Load the model and extract the text tower
+ model_config = get_model_config(args.model_name)
+ model, _, _ = create_model_and_transforms(args.model_name, args.model_path, force_custom_clip=True)
+ torch_model = model.text
+
+    # Build a dummy token-id input (all zeros) with the configured context length (used only for tracing)
+    text_length = model_config["text_cfg"]["context_length"]
+    text_input_shape = (args.text_max_batch, text_length)
+
+    input_ids = torch.randint(high=1, size=text_input_shape, dtype=torch.int64)
+    dummy_input = input_ids
+
+    # torch.jit.trace the model and save the traced graph
+ check_and_create_directory(args.save_dir)
+ traced_path = os.path.join(args.save_dir, f"{args.model_name}-trace-MindIE-Text.pt")
+ print("Starting trace eva clip text_tower ...")
+ torch.jit.trace(torch_model, dummy_input).save(traced_path)
+ print("Trace eva clip text_tower success!")
+
+ token_min_shape = (1, text_length)
+ token_max_shape = (args.text_max_batch, text_length)
+
+    # Run MindIE-Torch compilation over a dynamic batch range
+ compile_inputs = [
+ mindietorch.Input(min_shape=token_min_shape,
+ max_shape=token_max_shape, dtype=torch.int64)
+ ]
+
+ print("start mindietorch compile eva clip text_tower...")
+ intermediate_model = torch.jit.load(traced_path).eval()
+ ts = time.time()
+ compiled_model = mindietorch.compile(
+ intermediate_model,
+ inputs=compile_inputs,
+ precision_policy=mindietorch._enums.PrecisionPolicy.FP16,
+ soc_version=args.soc_version,
+ )
+ compile_cost = time.time() - ts
+ print("compile time cost: ", compile_cost)
+ print("end mindietorch compile eva clip text_tower.")
+
+ print("start saving eva clip text_tower")
+ compile_path = os.path.join(args.save_dir, f"{args.model_name}-compile-MindIE-Text.pt")
+ torch.jit.save(compiled_model, compile_path)
+ print("saving done eva clip text_tower")
+
+
+def inference_aie(args):
+ print("Params: ")
+ for name in sorted(vars(args)):
+ val = getattr(args, name)
+ print(f" {name}: {val}")
+
+    # Load the traced or compiled MindIE models
+ mindietorch.set_device(args.device)
+ stream = mindietorch.npu.Stream()
+ mindietorch.npu.set_stream(stream)
+ if args.compare_type == "traced":
+ vis_traced_path = os.path.join(f"{args.save_dir}", f"{args.model_name}-trace-MindIE-Vis.pt")
+ text_traced_path = os.path.join(f"{args.save_dir}", f"{args.model_name}-trace-MindIE-Text.pt")
+ else:
+ vis_traced_path = os.path.join(f"{args.save_dir}", f"{args.model_name}-compile-MindIE-Vis.pt")
+ text_traced_path = os.path.join(f"{args.save_dir}", f"{args.model_name}-compile-MindIE-Text.pt")
+
+ compile_vis = torch.jit.load(vis_traced_path).to(f"npu:{args.device}")
+ compile_text = torch.jit.load(text_traced_path).to(f"npu:{args.device}")
+
+    # Prepare the image and text inputs
+ _, _, preprocessor = create_model_and_transforms(args.model_name, args.model_path, force_custom_clip=True)
+ tokenizer = get_tokenizer(args.model_name)
+ image_pixel = preprocessor(Image.open(args.image_path)).unsqueeze(0).to(f"npu:{args.device}")
+ input_ids = tokenizer(args.classes).to(f"npu:{args.device}")
+
+ image_features = compile_vis(image_pixel)
+ text_features = compile_text(input_ids)
+
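+    # Normalize the features and compute image-text similarity logits (scaled by 100.0, as in CLIP), then softmax over the classes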
+ image_features_norm = image_features / image_features.norm(dim=-1, keepdim=True)
+ text_features_norm = text_features / text_features.norm(dim=-1, keepdim=True)
+ text_probs = (100.0 * image_features_norm @ text_features_norm.T).softmax(dim=-1)
+ print("Label probs:", text_probs)
+ return image_features, text_features, text_probs
+
+def inference_pytorch(args):
+ model_name = args.model_name
+ pretrained = args.model_path
+
+ image_path = args.image_path
+ caption = args.classes
+
+ model, _, preprocess = create_model_and_transforms(model_name, pretrained, force_custom_clip=True)
+ tokenizer = get_tokenizer(model_name)
+ model = model.to(f"npu:{args.device}")
+
+ image = preprocess(Image.open(image_path)).unsqueeze(0).to(f"npu:{args.device}")
+ text = tokenizer(caption).to(f"npu:{args.device}")
+
+ image_features = model.encode_image(image)
+ text_features = model.encode_text(text)
+ image_features_norm = image_features / image_features.norm(dim=-1, keepdim=True)
+ text_features_norm = text_features / text_features.norm(dim=-1, keepdim=True)
+ text_probs = (100.0 * image_features_norm @ text_features_norm.T).softmax(dim=-1)
+ print("Label probs:", text_probs)
+ return image_features, text_features, text_probs
+
+def compare(args):
+ aie_img, aie_text, aie_res = inference_aie(args)
+ pytorch_img, pytorch_text, pytorch_res = inference_pytorch(args)
+
+ img_cmp = cos_sim(aie_img, pytorch_img)
+ text_cmp = cos_sim(aie_text, pytorch_text)
+ res_cmp = cos_sim(aie_res, pytorch_res)
+
+ print("*******compile compare********")
+ print(">>>>>>image similarity: {}>>>>>>>>>>>".format(img_cmp))
+ print(">>>>>>text similarity: {}>>>>>>>>>>>>".format(text_cmp))
+ print(">>>>>>result similarity: {}>>>>>>>>>>>>".format(res_cmp))
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Compile, infer, and compare the EVA-CLIP model with MindIE-Torch")
+ parser.add_argument("--soc_version", default="Ascend910B4", help="NPU version")
+ parser.add_argument("--device", type=int, default=4)
+ parser.add_argument("--text-max-batch", type=int, default=80)
+ parser.add_argument("--img-max-batch", type=int, default=8)
+
+ parser.add_argument(
+ "--model_name",
+ default="EVA02-CLIP-B-16",
+ choices=["EVA02-CLIP-B-16", "EVA02-CLIP-L-14", "EVA02-CLIP-L-14-336",
+ "EVA02-CLIP-bigE-14", "EVA02-CLIP-bigE-14-plus"],
+ help="Specify the architecture of EVA CLIP model to be converted."
+ )
+ parser.add_argument(
+ "--model_path",
+ type=str,
+ help="Path of the pretrained model."
+ )
+ parser.add_argument(
+ "--image_path",
+ type=str,
+ default="./EVA/EVA-CLIP/assets/CLIP.png",
+ help="Path of the test image."
+ )
+    parser.add_argument(
+        "--compare_type",
+        default="compiled",
+        type=str,
+        choices=["traced", "compiled"],
+        help="Run inference with the traced model or the compiled model."
+    )
+ parser.add_argument(
+ "--save_dir",
+ default="./mindie_weights",
+ help="Path to save compiled weight"
+ )
+    parser.add_argument(
+        "--classes",
+        nargs="+",
+        default=["a diagram", "a dog", "a cat"],
+        help="Text class labels for zero-shot classification"
+    )
+ parser.add_argument(
+ "--compile",
+ action='store_true',
+ help="Compile EVA CLIP model"
+ )
+ parser.add_argument(
+ "--infer",
+ action='store_true',
+ help="Infer EVA CLIP model."
+ )
+    parser.add_argument(
+        "--infer_type",
+        default="aie",
+        choices=["aie", "pytorch"],
+        help="Specify the inference backend: MindIE (aie) or PyTorch."
+    )
+ parser.add_argument(
+ "--compare",
+ action='store_true',
+ help="Compare EVA CLIP compiled model with pytorch model"
+ )
+ args = parser.parse_args()
+ return args
+
+if __name__ == "__main__":
+ args = parse_args()
+ if args.compile:
+ compile_eva_text(args)
+ compile_eva_vis(args)
+ elif args.infer and args.infer_type == "aie":
+ inference_aie(args)
+ elif args.infer and args.infer_type == "pytorch":
+ inference_pytorch(args)
+ elif args.compare:
+ compare(args)
+ else:
+        raise ValueError("You must specify one of --compile, --infer, or --compare")