From fe0c65cc286ce5f2f228632341174d914f56c1ce Mon Sep 17 00:00:00 2001
From: commc
Date: Wed, 4 Sep 2024 14:29:23 +0800
Subject: [PATCH 1/2] Customer requirement: CLIP model compile code
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../built-in/multimodal/compile_aie.py | 169 ++++++++++++++++++
 .../built-in/multimodal/compile_ts.py  | 147 +++++++++++++++
 2 files changed, 316 insertions(+)
 create mode 100644 MindIE/MindIE-Torch/built-in/multimodal/compile_aie.py
 create mode 100644 MindIE/MindIE-Torch/built-in/multimodal/compile_ts.py

diff --git a/MindIE/MindIE-Torch/built-in/multimodal/compile_aie.py b/MindIE/MindIE-Torch/built-in/multimodal/compile_aie.py
new file mode 100644
index 0000000000..35f13c9747
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/multimodal/compile_aie.py
@@ -0,0 +1,169 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import time
+import json
+import logging
+import argparse
+import torch
+import mindietorch
+import torch.nn as nn
+from torch._export import export, dynamic_dim
+from mindietorch import _enums
+from transformers.models.auto.modeling_auto import AutoModel
+
+logging.basicConfig(level=logging.INFO)
+
+
+class CLIPWrapper(nn.Module):
+    def __init__(self, clip):
+        super(CLIPWrapper, self).__init__()
+        self.model = clip
+        self.logit_scale = clip.logit_scale.exp().to(self.model.device)
+
+    def forward(self, input_ids, pixel_values, attention_mask):
+        image_embeds = self.model.get_image_features(pixel_values)
+        text_embeds = self.model.get_text_features(input_ids, attention_mask)
+
+        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
+        text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
+
+        logits_per_image = image_embeds @ text_embeds.transpose(1, 0).contiguous() * self.logit_scale
+        logits_per_text = logits_per_image.transpose(1, 0).contiguous()
+
+        return image_embeds, text_embeds, logits_per_text, logits_per_image
+
+
+def compile_clip(args):
+    # Load the PyTorch model
+    with torch.no_grad():
+        torch_model = AutoModel.from_pretrained(args.hf_model_path).float().eval()
+        torch_model = CLIPWrapper(torch_model)
+
+    hf_config_path = os.path.join(args.hf_model_path, "config.json")
+    if not os.path.exists(hf_config_path):
+        raise FileNotFoundError(f"config.json not found under {args.hf_model_path}: {hf_config_path}")
+    with open(hf_config_path, "r") as f:
+        config_dict = json.load(f)
+
+    # Construct dummy model inputs
+    image_width = config_dict["vision_config"]["image_size"]
+    pixel_values_shape = (args.img_max_batch, 3, image_width, image_width)
+    input_ids_shape = (args.text_max_batch, args.max_token_len)
+    pixel_values = torch.ones(pixel_values_shape, dtype=torch.float32)
+    input_ids = torch.randint(high=1, size=input_ids_shape, dtype=torch.int32)
+    attention_mask = torch.ones_like(input_ids, dtype=torch.int32)
+
+    # Export the model in FX format, then compile it with MindIE
+    constraints = [
+        # input ids
+        dynamic_dim(input_ids, 0) >= 1,
+        dynamic_dim(input_ids, 0) <= args.text_max_batch,
+        dynamic_dim(input_ids, 1) >= 1,
+        dynamic_dim(input_ids, 1) <= args.max_token_len,
+        # pixel input
+        dynamic_dim(pixel_values, 0) >= 1,
+        dynamic_dim(pixel_values, 0) <= args.img_max_batch,
+        # input ids attention mask
+        dynamic_dim(attention_mask, 0) == dynamic_dim(input_ids, 0),
+        dynamic_dim(attention_mask, 1) == dynamic_dim(input_ids, 1),
+    ]
+
+    logging.info("Starting to export dynamic clip ...")
+    intermediate_model = export(
+        torch_model,
+        args=(input_ids, pixel_values, attention_mask,),
+        constraints=constraints
+    )
+    logging.info("Successfully exported dynamic clip!")
+
+    mindietorch.set_device(args.device_id)
+    pixel_values_min_shape = (1, 3, image_width, image_width)
+    pixel_values_max_shape = (args.img_max_batch, 3, image_width, image_width)
+    input_ids_min_shape = (1, 1)
+    input_ids_max_shape = (args.text_max_batch, args.max_token_len)
+
+    # Run MindIE Torch compilation
+    compile_inputs = [
+        mindietorch.Input(min_shape=input_ids_min_shape, max_shape=input_ids_max_shape),
+        mindietorch.Input(min_shape=pixel_values_min_shape, max_shape=pixel_values_max_shape),
+        mindietorch.Input(min_shape=input_ids_min_shape, max_shape=input_ids_max_shape),  # attention mask
+    ]
+
+    if args.precision == "fp16":
+        model_precision = _enums.PrecisionPolicy.FP16
+    elif args.precision == "fp32":
+        model_precision = _enums.PrecisionPolicy.FP32
+    else:
+        raise ValueError("Unsupported precision type!")
+
+    logging.info("Starting to compile mindietorch clip ...")
+    ts = time.time()
+    compiled_model = mindietorch.compile(
+        intermediate_model,
+        inputs=compile_inputs,
+        precision_policy=model_precision,
+        soc_version=args.soc_version,
+    )
+    compile_cost = time.time() - ts
+    logging.info("compile time cost: %f", compile_cost)
+    logging.info("Successfully compiled mindietorch clip!")
+
+    logging.info("Starting to save ...")
+    model_save_dir = args.save_dir
+    if not os.path.exists(model_save_dir):
+        os.makedirs(model_save_dir)
+    compiled_file_name = f"CLIP-{args.model_version}-MindIE.pt"
+    torch.save(compiled_model, os.path.join(model_save_dir, compiled_file_name), pickle_protocol=4)
+    logging.info("Saving done!")
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Compile CLIP model")
+    parser.add_argument("--soc-version", type=str, default="Ascend910B4", help="NPU version")
+    parser.add_argument("--device-id", type=int, default=0)
+    parser.add_argument("--text-max-batch", type=int, default=80)
+    parser.add_argument("--img-max-batch", type=int, default=8)
+    parser.add_argument(
+        "--max-token-len",
+        type=int,
+        default=52,
+        help="The padded length of the input text (including the [CLS] and [SEP] tokens)."
+    )
+    parser.add_argument(
+        "--model-version",
+        default="ViT-B-16",
+        choices=["ViT-B-16", "ViT-L-14", "ViT-L-14-336", "ViT-H-14", "RN50"],
+        help="Specify the architecture of the CLIP model to be converted."
+    )
+    parser.add_argument(
+        "--hf-model-path",
+        default="/Path/to/Huggingface_model_path",
+        type=str,
+        help="Hugging Face CLIP model path."
+    )
+    parser.add_argument(
+        "--precision",
+        default="fp16",
+        choices=["fp16", "fp32"],
+        help="Specify the precision of the CLIP model to be converted."
+    )
+    parser.add_argument("--save-dir", type=str, default="./", help="Path to save the exported model")
+
+    return parser.parse_args()
+
+
+if __name__ == "__main__":
+    compile_args = parse_args()
+    compile_clip(compile_args)
\ No newline at end of file
diff --git a/MindIE/MindIE-Torch/built-in/multimodal/compile_ts.py b/MindIE/MindIE-Torch/built-in/multimodal/compile_ts.py
new file mode 100644
index 0000000000..74f6c1892e
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/multimodal/compile_ts.py
@@ -0,0 +1,147 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import time
+import json
+import logging
+import argparse
+import torch
+import mindietorch
+import torch.nn as nn
+from mindietorch import _enums
+from transformers.models.auto.modeling_auto import AutoModel
+
+logging.basicConfig(level=logging.INFO)
+
+
+class CLIPWrapper(nn.Module):
+    def __init__(self, clip):
+        super(CLIPWrapper, self).__init__()
+        self.model = clip
+        # Tensor.to() is not in-place, so keep the returned tensor
+        self.logit_scale = clip.logit_scale.exp().to(self.model.device)
+
+    def forward(self, input_ids, pixel_values, attention_mask):
+        image_embeds = self.model.get_image_features(pixel_values)
+        text_embeds = self.model.get_text_features(input_ids, attention_mask)
+
+        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
+        text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
+
+        logits_per_image = image_embeds @ text_embeds.transpose(1, 0).contiguous() * self.logit_scale
+        logits_per_text = logits_per_image.transpose(1, 0).contiguous()
+
+        return image_embeds, text_embeds, logits_per_text, logits_per_image
+
+
+def compile_clip(args):
+    # Load the PyTorch model
+    with torch.no_grad():
+        torch_model = AutoModel.from_pretrained(args.hf_model_path).float().eval()
+        torch_model = CLIPWrapper(torch_model)
+
+    hf_config_path = os.path.join(args.hf_model_path, "config.json")
+    if not os.path.exists(hf_config_path):
+        raise FileNotFoundError(f"config.json not found under {args.hf_model_path}: {hf_config_path}")
+    with open(hf_config_path, "r") as f:
+        config_dict = json.load(f)
+
+    # Construct dummy model inputs
+    image_width = config_dict["vision_config"]["image_size"]
+    pixel_values_shape = (args.img_batch, 3, image_width, image_width)
+    input_ids_shape = (args.text_batch, args.token_len)
+    pixel_values = torch.ones(pixel_values_shape, dtype=torch.float32)
+    input_ids = torch.randint(high=1, size=input_ids_shape, dtype=torch.int32)
+    attention_mask = torch.ones_like(input_ids, dtype=torch.int32)
+
+    # Trace the model to TorchScript, then compile it with MindIE
+    input_data = (input_ids, pixel_values, attention_mask)
+    logging.info("Starting to trace clip ...")
+    intermediate_model = torch.jit.trace(torch_model, input_data)
+    logging.info("Successfully traced clip!")
+
+    mindietorch.set_device(args.device_id)
+
+    # Run MindIE Torch compilation
+    compile_inputs = [
+        mindietorch.Input(shape=input_ids_shape, dtype=torch.int32),
+        mindietorch.Input(shape=pixel_values_shape, dtype=torch.float32),
+        mindietorch.Input(shape=input_ids_shape, dtype=torch.int32)  # attention mask
+    ]
+    if args.precision == "fp16":
+        model_precision = _enums.PrecisionPolicy.FP16
+    elif args.precision == "fp32":
+        model_precision = _enums.PrecisionPolicy.FP32
+    else:
+        raise ValueError("Unsupported precision type!")
+
+    logging.info("Starting to compile mindietorch clip ...")
+    ts = time.time()
+    compiled_model = mindietorch.compile(
+        intermediate_model,
+        inputs=compile_inputs,
+        precision_policy=model_precision,
+        soc_version=args.soc_version,
+    )
+    compile_cost = time.time() - ts
+    logging.info("compile time cost: %f", compile_cost)
+    logging.info("Successfully compiled mindietorch clip!")
+
+    logging.info("Starting to save ...")
+    model_save_dir = args.save_dir
+    if not os.path.exists(model_save_dir):
+        os.makedirs(model_save_dir)
+    compiled_file_name = f"CLIP-{args.model_version}-MindIE.ts"
+    compiled_model.save(os.path.join(model_save_dir, compiled_file_name))
+    logging.info("Saving done!")
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Compile CLIP model")
+    parser.add_argument("--soc-version", default="Ascend910B4", help="NPU version")
+    parser.add_argument("--device-id", type=int, default=0)
+    parser.add_argument("--text-batch", type=int, default=80)
+    parser.add_argument("--img-batch", type=int, default=1)
+    parser.add_argument(
+        "--token-len",
+        type=int,
+        default=52,
+        help="The padded length of the input text (including the [CLS] and [SEP] tokens)."
+    )
+    parser.add_argument(
+        "--model-version",
+        default="ViT-L-14",
+        choices=["ViT-B-16", "ViT-L-14", "ViT-L-14-336", "ViT-H-14", "RN50"],
+        help="Specify the architecture of the CLIP model to be converted."
+    )
+    parser.add_argument(
+        "--hf-model-path",
+        default="/Path/to/Huggingface_model_path",
+        type=str,
+        help="Hugging Face CLIP model path."
+    )
+    parser.add_argument(
+        "--precision",
+        default="fp16",
+        choices=["fp16", "fp32"],
+        help="Specify the precision of the CLIP model to be converted."
+    )
+    parser.add_argument("--save-dir", type=str, default="./", help="Path to save the exported model")
+
+    return parser.parse_args()
+
+
+if __name__ == "__main__":
+    compile_args = parse_args()
+    compile_clip(compile_args)
\ No newline at end of file
-- 
Gitee

From e0726036ac43e73770453d3d5fb42de750e7842f Mon Sep 17 00:00:00 2001
From: commc
Date: Wed, 4 Sep 2024 14:59:00 +0800
Subject: [PATCH 2/2] Customer requirement: CLIP model ONNX export
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../built-in/multimodal/export_onnx.py | 126 ++++++++++++++++++
 1 file changed, 126 insertions(+)
 create mode 100644 MindIE/MindIE-Torch/built-in/multimodal/export_onnx.py

diff --git a/MindIE/MindIE-Torch/built-in/multimodal/export_onnx.py b/MindIE/MindIE-Torch/built-in/multimodal/export_onnx.py
new file mode 100644
index 0000000000..0b5897cb28
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/multimodal/export_onnx.py
@@ -0,0 +1,126 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import json
+import logging
+import argparse
+import torch
+import torch.onnx
+import torch.nn as nn
+from transformers.models.auto.modeling_auto import AutoModel
+
+logging.basicConfig(level=logging.INFO)
+
+
+class CLIPWrapper(nn.Module):
+    def __init__(self, clip):
+        super(CLIPWrapper, self).__init__()
+        self.model = clip
+        # Tensor.to() is not in-place, so keep the returned tensor
+        self.logit_scale = clip.logit_scale.exp().to(self.model.device)
+
+    def forward(self, input_ids, pixel_values, attention_mask):
+        image_embeds = self.model.get_image_features(pixel_values)
+        text_embeds = self.model.get_text_features(input_ids, attention_mask)
+
+        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
+        text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
+
+        logits_per_image = image_embeds @ text_embeds.transpose(1, 0).contiguous() * self.logit_scale
+        logits_per_text = logits_per_image.transpose(1, 0).contiguous()
+
+        return image_embeds, text_embeds, logits_per_text, logits_per_image
+
+
+def export_onnx(args):
+
+    # Load the PyTorch model
+    with torch.no_grad():
+        torch_model = AutoModel.from_pretrained(args.hf_model_path).float().eval()
+        torch_model = CLIPWrapper(torch_model)
+
+    hf_config_path = os.path.join(args.hf_model_path, "config.json")
+    if not os.path.exists(hf_config_path):
+        raise FileNotFoundError(f"config.json not found under {args.hf_model_path}: {hf_config_path}")
+    with open(hf_config_path, "r") as f:
+        config_dict = json.load(f)
+
+    # Construct dummy model inputs
+    image_width = config_dict["vision_config"]["image_size"]
+    img_input_shape = (1, 3, image_width, image_width)
+    text_input_shape = (3, args.max_token_len)
+    input_img = torch.ones(img_input_shape, dtype=torch.float32)
+    input_ids = torch.randint(high=1, size=text_input_shape, dtype=torch.int32)
+    attention_mask = torch.ones_like(input_ids, dtype=torch.int32)
+    torch_model(input_ids, input_img, attention_mask)  # sanity-check forward pass before export
+
+    # Export the ONNX model
+    file_name = f"CLIP-{args.model_version}.onnx"
+    model_save_path = os.path.join(args.save_dir, file_name)
+    logging.info("Starting to export dynamic onnx ...")
+    text_batch_size = "text_batch_size"
+    image_batch_size = "image_batch_size"
+    seq_len = "seq_len"
+    torch.onnx.export(
+        torch_model,
+        (input_ids, input_img, attention_mask),
+        model_save_path,
+        input_names=["input_ids", "pixel_values", "attention_mask"],
+        output_names=["image_embeds", "text_embeds", "logits_per_text", "logits_per_image"],
+        export_params=True,
+        opset_version=13,
+        verbose=True,
+        dynamic_axes={
+            "input_ids": {0: text_batch_size, 1: seq_len},
+            "pixel_values": {0: image_batch_size},
+            "attention_mask": {0: text_batch_size, 1: seq_len},
+            "image_embeds": {0: image_batch_size},
+            "text_embeds": {0: text_batch_size},
+            "logits_per_text": {0: text_batch_size, 1: image_batch_size},
+            "logits_per_image": {0: image_batch_size, 1: text_batch_size},
+        }
+    )
+    logging.info("Successfully exported dynamic onnx!")
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Export CLIP model to ONNX")
+    parser.add_argument("--text-max-batch", type=int, default=80)
+    parser.add_argument("--img-max-batch", type=int, default=8)
+    parser.add_argument(
+        "--max-token-len",
+        type=int,
+        default=52,
+        help="The padded length of the input text (including the [CLS] and [SEP] tokens)."
+    )
+    parser.add_argument(
+        "--model-version",
+        default="ViT-B-16",
+        choices=["ViT-B-16", "ViT-L-14", "ViT-L-14-336", "ViT-H-14", "RN50"],
+        help="Specify the architecture of the CLIP model to be converted."
+    )
+    parser.add_argument(
+        "--hf-model-path",
+        default="/Path/to/Huggingface_model_path",
+        type=str,
+        help="Hugging Face CLIP model path."
+    )
+    parser.add_argument("--save-dir", type=str, default="./", help="Path to save the exported model")
+
+    return parser.parse_args()
+
+
+if __name__ == "__main__":
+    compile_args = parse_args()
+    export_onnx(compile_args)
\ No newline at end of file
-- 
Gitee
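A minimal smoke test of the produced artifacts could look like the following sketch. It is illustrative only: the file names follow the defaults of the three scripts above (CLIP-ViT-B-16-MindIE.pt, CLIP-ViT-L-14-MindIE.ts, CLIP-ViT-B-16.onnx, assumed to sit in the current directory), the 224x224 image size and the 80x52 text shape mirror the default arguments, and it only loads and validates the artifacts; actually running the compiled models additionally requires placing the input tensors on the Ascend device as described in the MindIE Torch documentation.

    import torch
    import onnx
    import mindietorch  # importing mindietorch makes the MindIE-compiled modules deserializable

    mindietorch.set_device(0)

    # TorchScript artifact from compile_ts.py (saved via compiled_model.save)
    ts_model = torch.jit.load("CLIP-ViT-L-14-MindIE.ts")

    # FX/ExportedProgram-based artifact from compile_aie.py (saved via torch.save)
    aie_model = torch.load("CLIP-ViT-B-16-MindIE.pt")

    # ONNX artifact from export_onnx.py: structural validation only
    onnx.checker.check_model(onnx.load("CLIP-ViT-B-16.onnx"))

    # Dummy inputs mirroring the shapes used at compile/export time
    # (the TorchScript artifact is compiled for fixed shapes, so the batch
    # sizes must match the defaults used in compile_ts.py)
    input_ids = torch.randint(high=1, size=(80, 52), dtype=torch.int32)
    attention_mask = torch.ones_like(input_ids)
    pixel_values = torch.ones((1, 3, 224, 224), dtype=torch.float32)

    # To run ts_model or aie_model, first move these tensors to the NPU device
    # (see the MindIE Torch documentation for device handling), then call
    # model(input_ids, pixel_values, attention_mask).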