diff --git a/AscendIE/TorchAIE/built-in/cv/gan/stylegan2-ada/README.md b/AscendIE/TorchAIE/built-in/cv/gan/stylegan2-ada/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..d9e7975c86b18d11e5fe5ec78ea05b303711b374
--- /dev/null
+++ b/AscendIE/TorchAIE/built-in/cv/gan/stylegan2-ada/README.md
@@ -0,0 +1,176 @@
+# StyleGAN2-ADA模型-推理指导
+
+
+- [概述](#ZH-CN_TOPIC_0000001172161501)
+
+ - [输入输出数据](#section540883920406)
+
+
+
+- [推理环境准备](#ZH-CN_TOPIC_0000001126281702)
+
+- [快速上手](#ZH-CN_TOPIC_0000001126281700)
+
+ - [获取源码](#section4622531142816)
+ - [准备数据集](#section183221994411)
+ - [模型推理](#section741711594517)
+
+- [模型推理性能&精度](#ZH-CN_TOPIC_0000001172201573)
+- [开发者自测 && 测试覆盖率](#section741711594518)
+
+ ******
+
+# 概述
+
+StyleGAN2-ADA是具有自适应鉴别器增强(ADA)的StyleGAN2,用有限的数据训练生成对抗网络,[论文链接](https://arxiv.org/abs/2006.06676)。
+
+- 参考实现:
+
+ ```
+ url=https://github.com/NVlabs/stylegan2-ada-pytorch
+ commit_id=765125e7f01a4c265e25d0a619c37a89ac3f5107
+ code_path=
+ model_name=StyleGAN2-ADA
+ ```
+
+## 输入输出数据
+
+- 输入数据
+
+ | 输入数据 | 数据类型 | 大小 | 数据排布格式 |
+ | -------- | -------- | ------------------------- | ------------ |
+ | input | RGB_FP32 | batchsize x 512 | ND |
+
+
+- 输出数据
+
+ | 输出数据 | 数据类型 | 大小 | 数据排布格式 |
+ | -------- | -------- | -------- | ------------ |
+ | output1 | FLOAT32 | batchsize x 3 x 512 x 512| NCHW |
+
+
+# 推理环境准备
+
+- 该模型需要以下插件与驱动
+
+ **表 1** 版本配套表
+
+ | 配套 | 版本 | 环境准备指导 |
+ |---------| ------- | ------------------------------------------------------------ |
+ | 固件与驱动 | 23.0.rc1 | [Pytorch框架推理环境准备](https://www.hiascend.com/document/detail/zh/ModelZoo/pytorchframework/pies) |
+ | CANN | 7.0.RC1.alpha003 | - |
+ | Python | 3.9.11 | - |
+ | PyTorch | 2.0.1 | - |
+ | Torch_AIE | 6.3.rc2 | - |
+
+# 快速上手
+
+## 获取源码
+
+1. 获取源码。
+
+ ```
+ git clone https://github.com/NVlabs/stylegan2-ada-pytorch.git
+ ```
+
+2. 安装依赖。
+
+ ```
+ pip3 install -r requirements.txt
+ ```
+
+## 准备数据集
+
+1. 获取原始数据集。(解压命令参考tar –xvf \*.tar与 unzip \*.zip)
+
+ StyleGAN2-ADA网络使用随机生成隐变量作为输入来生成输出,生成方式见下一步。
+
+
+2. 数据预处理,将原始数据集转换为模型输入的数据。
+
+ 执行`preprocess.py`脚本,完成预处理。
+
+ ```
+ python3 preprocess.py --num_input=200 --save_path=./pre_data
+ ```
+ 生成`num_input个`随机输入,并保存为bin文件,保存目录为`./pre_data`。
+
+
+## 模型推理
+
+1. 开始推理验证。
+
+ 1. 获取权重文件。
+
+ 权重文件为:[G_ema_bs8_8p_kimg1000.pkl](https://ascend-repo-modelzoo.obs.cn-east-2.myhuaweicloud.com/model/1_PyTorch_PTH/StyleGAN/PTH/G_ema_bs8_8p_kimg1000.pkl)
+ 将获取的权重文件放在当前工作路径下。
+
+ 2. 导出ts模型
+
+ 执行如下命令导出ts模型文件:
+ ```
+ python3 export.py --stylegan_path ./stylegan2-ada-pytorch-main --pkl_file ./G_ema_bs8_8p_kimg1000.pkl
+ ```
+ - 参数说明:
+
+ - stylegan_path:拉下github仓的目录
+ - pkl_file:pkl文件目录
+
+
+ 3. 执行推理,生成伪装图片
+
+ ```
+ python3 infer.py --bin_path ./pre_data --image_path ./results --ts_model_path ./stylegan2.ts
+ ```
+
+ - 参数说明:
+
+ - bin_path:预处理得到的bin文件的目录
+ - image_path:推理后得到图片的目录
+ - ts_model_path:ts模型文件路径
+
+
+ 3. 精度验证(肉眼验证)。
+
+ a.调用`perf_gpu.py`脚本使用pkl权重文件成生图像,用于进行精度对比。
+
+ ```
+ python3 perf_gpu.py --stylegan_path ./stylegan2-ada-pytorch --pkl_file ./G_ema_bs8_8p_kimg1000.pkl --input_path ./pre_data --image_path ./results
+ ```
+
+ - 参数说明:
+
+ - stylegan_path:拉下github仓的目录
+ - pkl_file:pkl文件目录
+ - input_path:预处理得到的bin文件的目录
+ - image_path:推理后得到图片的目录
+
+ b.精度比对方法:将ts模型的推理结果转化的图像与pkl权重文件成生图像进行对比(均在image path路径下,原pkl推理结果在pkl_img子目录下),两幅图像在视觉效果上一致即可。
+
+
+ 4. 性能验证。
+
+ 使用`perf.py`脚本进行性能验证
+ ```
+ python3 perf.py --mode ts --ts_path ./stylegan2.ts --batch_size 1
+ ```
+
+ - 参数说明:
+
+ - mode:使用ts模型进行推理
+ - ts_path:ts模型文件所在路径
+ - batch_size:用于验证性能的batch size路径
+
+
+# 模型推理性能&精度
+
+调用ACL接口推理计算,性能参考下列数据。
+
+| 芯片型号 | Batch Size | 数据集 | 精度 | 性能 |
+| --------- | ---------------- | ---------- | ---------- | --------------- |
+| Ascend310P3 | 1 | 随机生成数据 | 图片视觉评价 | 23.50 fps |
+| Ascend310P3 | 16 | 随机生成数据 | 图片视觉评价 | 24.59 fps |
+
+# 开发者自测 && 测试覆盖率
+
+
\ No newline at end of file
diff --git a/AscendIE/TorchAIE/built-in/cv/gan/stylegan2-ada/export.py b/AscendIE/TorchAIE/built-in/cv/gan/stylegan2-ada/export.py
new file mode 100644
index 0000000000000000000000000000000000000000..76ec6a26d35df9bc47a5a5824bcc07a0f0d3ba8f
--- /dev/null
+++ b/AscendIE/TorchAIE/built-in/cv/gan/stylegan2-ada/export.py
@@ -0,0 +1,103 @@
+# Copyright 2023 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the License);
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# BSD 3-Clause License
+#
+# Copyright (c) 2017 xxxx
+# All rights reserved.
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+# ============================================================================
+
+
+import os
+import sys
+import pickle
+import argparse
+import functools
+
+import torch
+
+
+def parse_args():
+ args = argparse.ArgumentParser(description="A program that trace the pkl model into torchscript file")
+ args.add_argument('--stylegan_path', type=str, default='./stylegan2-ada-pytorch-main',
+ help='Path to the stylegan2-ada-pytorch Github directory')
+ args.add_argument('--pkl_file',help='stylegan2-ada pth file path', type=str,
+ default='./G_ema_bs8_8p_kimg1000.pkl'
+ )
+ return args.parse_args()
+
+
+def check_args(args):
+ if not os.path.exists(args.stylegan_path):
+ raise FileNotFoundError(f'The stylegan2-ada-pytorch Github directory {args.stylegan_path} not exists')
+ if not os.path.exists(args.pkl_file):
+ raise FileNotFoundError(f'The model file stylegan2-ada pth {args.pkl_file} not exists')
+
+
+def trace_ts_model(args):
+ sys.path.append(args.stylegan_path)
+
+ # Set up options
+ bs = 1
+ pkl_file = args.pkl_file
+
+ # Load pkl model
+ with open(pkl_file, 'rb') as f:
+ G = pickle.load(f)['G_ema']
+ G.forward = functools.partial(G.forward, force_fp32=True)
+
+ print("type of G is:", type(G))
+
+ # Prepare input
+ print(G.z_dim)
+ z = torch.randn([bs, G.z_dim])
+ c = torch.empty([bs, 0], device='cpu')
+ dummy_input = (z, c)
+ G.eval()
+
+ ts_model = torch.jit.trace(G, dummy_input)
+ ts_model.save("./stylegan2-bs.ts")
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ check_args(args)
+ trace_ts_model(args)
diff --git a/AscendIE/TorchAIE/built-in/cv/gan/stylegan2-ada/infer.py b/AscendIE/TorchAIE/built-in/cv/gan/stylegan2-ada/infer.py
new file mode 100644
index 0000000000000000000000000000000000000000..15e1d00203e0c6237692585613d00f29aec03225
--- /dev/null
+++ b/AscendIE/TorchAIE/built-in/cv/gan/stylegan2-ada/infer.py
@@ -0,0 +1,98 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the License);
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# BSD 3-Clause License
+#
+# Copyright (c) 2017 xxxx
+# All rights reserved.
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+# ============================================================================
+
+import torch
+import os
+import numpy as np
+import argparse
+
+from tqdm import tqdm
+from perf_gpu import save_image_grid
+
+import torch_aie
+
+
+def test_om(args):
+ ts_model_path = args.ts_model_path
+ ts_model = torch.jit.load(ts_model_path)
+ input_info = [torch_aie.Input((1, 512)), torch_aie.Input((1, 0))]
+ torch_aie.set_device(0)
+ print("start_compile")
+ torchaie_model = torch_aie.compile(
+ ts_model,
+ inputs=input_info,
+ precision_policy=torch_aie.PrecisionPolicy.PREF_FP32, # _enums.PrecisionPolicy.FP32
+ soc_version='Ascend310P3'
+ )
+ print("end_compile")
+ torchaie_model.eval()
+
+ bin_path = args.bin_path
+ image_path = args.image_path
+ bin_list = os.listdir(bin_path)
+ bin_list.sort()
+ for i in tqdm(range(len(bin_list))):
+ input = np.fromfile(os.path.join(bin_path, bin_list[i]), dtype=np.float32).reshape(-1,512)
+ input = torch.Tensor(input)
+ c = torch.empty([1, 0])
+ images = torchaie_model(input.to('npu'), c.to('npu'))
+ images = images.to('cpu')
+ bs = images.shape[0]
+ grid_size = (4, 4) if bs == 16 else (1, 1)
+ os.makedirs(image_path, exist_ok=True)
+ save_image_grid(images, os.path.join(image_path, f'gen_image_{i:04d}') + ".png", drange=[-1, 1],
+ grid_size=grid_size)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--bin_path', type=str, default="./pre_data")
+ parser.add_argument('--image_path', type=str, default='./results')
+ parser.add_argument('--ts_model_path', type=str, default='./stylegan2.ts')
+ args = parser.parse_args()
+
+ test_om(args)
diff --git a/AscendIE/TorchAIE/built-in/cv/gan/stylegan2-ada/perf.py b/AscendIE/TorchAIE/built-in/cv/gan/stylegan2-ada/perf.py
new file mode 100644
index 0000000000000000000000000000000000000000..4765be36327890d186071213922b5b3205f013d9
--- /dev/null
+++ b/AscendIE/TorchAIE/built-in/cv/gan/stylegan2-ada/perf.py
@@ -0,0 +1,78 @@
+import argparse
+import time
+
+import torch
+import numpy as np
+
+import torch_aie
+from torch_aie import _enums
+
+
+INPUT_SIZE = 512
+
+
+def parse_args():
+ args = argparse.ArgumentParser(description="A program that gets the performance of torchaie on stylegan.")
+ args.add_argument('--ts_path',help='stylegan2 ts file path', type=str,
+ default='./stylegan2-bs.ts'
+ )
+ args.add_argument("--batch_size", type=int, default=1, help="batch size.")
+ args.add_argument('--image_size', type=int, default=512, help='Image size')
+ return args.parse_args()
+
+
+def perf(torchaie_model, batch_size, image_size):
+ input = np.zeros((batch_size, image_size))
+ input = torch.Tensor(input)
+ c = torch.empty([batch_size, 0])
+ input_tensor = input.to('npu')
+ input_tensor2 = c.to('npu')
+ print("ready to infer")
+ _ = torchaie_model(input_tensor, input_tensor2)
+ print("finish infer once")
+ loops = 100
+ warm_ctr = 10
+
+ default_stream = torch_aie.npu.default_stream()
+ time_cost = 0
+
+ while warm_ctr:
+ _ = torchaie_model(input_tensor, input_tensor2)
+ default_stream.synchronize()
+ warm_ctr -= 1
+
+ for i in range(loops):
+ t0 = time.time()
+ _ = torchaie_model(input_tensor, input_tensor2)
+ default_stream.synchronize()
+ t1 = time.time()
+ time_cost += (t1 - t0)
+ print(i)
+
+ print(f"fps: {loops} * {batch_size} / {time_cost : .3f} samples/s")
+ print("torch_aie fps: ", loops * batch_size / time_cost)
+
+ from datetime import datetime
+ current_time = datetime.now()
+ formatted_time = current_time.strftime("%Y-%m-%d %H:%M:%S")
+ print("Current Time:", formatted_time)
+
+
+if __name__ == '__main__':
+ opts = parse_args()
+ batch_size = opts.batch_size
+ image_size = opts.image_size
+ ts_model = torch.jit.load(opts.ts_path)
+ input_info = [torch_aie.Input((batch_size, image_size)), torch_aie.Input((batch_size, 0))]
+ torch_aie.set_device(0)
+ print("start compile")
+ torchaie_model = torch_aie.compile(
+ ts_model,
+ inputs=input_info,
+ precision_policy=_enums.PrecisionPolicy.FP32,
+ soc_version='Ascend310P3',
+ )
+ print("end compile")
+ torchaie_model.eval()
+
+ perf(torchaie_model, batch_size, image_size)
diff --git a/AscendIE/TorchAIE/built-in/cv/gan/stylegan2-ada/perf_gpu.py b/AscendIE/TorchAIE/built-in/cv/gan/stylegan2-ada/perf_gpu.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb4a06a6042400b97ce9efcdd5a541c145008dff
--- /dev/null
+++ b/AscendIE/TorchAIE/built-in/cv/gan/stylegan2-ada/perf_gpu.py
@@ -0,0 +1,125 @@
+# Copyright 2023 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the License);
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# BSD 3-Clause License
+#
+# Copyright (c) 2017 xxxx
+# All rights reserved.
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+# ============================================================================
+
+import os
+import sys
+import time
+import pickle
+import argparse
+import functools
+from tqdm import tqdm
+
+import numpy as np
+import torch
+import PIL.Image
+
+
+def save_image_grid(img, fname, drange, grid_size):
+ lo, hi = drange
+ img = np.asarray(img, dtype=np.float32)
+ img = (img - lo) * (255 / (hi - lo))
+ img = np.rint(img).clip(0, 255).astype(np.uint8)
+
+ gw, gh = grid_size
+ _N, C, H, W = img.shape
+ img = img.reshape(gh, gw, C, H, W)
+ img = img.transpose(0, 3, 1, 4, 2)
+ img = img.reshape(gh * H, gw * W, C)
+
+ assert C in [1, 3]
+ if C == 1:
+ PIL.Image.fromarray(img[:, :, 0], 'L').save(fname)
+ if C == 3:
+ PIL.Image.fromarray(img, 'RGB').save(fname)
+
+
+def main(args):
+ pkl_file = args.pkl_file
+ bs = args.batch_size
+ input_path = args.input_path
+ image_path = args.image_path
+ device = "cpu"
+
+ grid_size = (1, 1)
+ input_files = os.listdir(input_path)
+ input_files.sort()
+ image_path = os.path.join(image_path, 'pkl_img')
+ os.makedirs(image_path, exist_ok=True)
+ # load model
+ start = time.time()
+ with open(pkl_file, 'rb') as f:
+ G = pickle.load(f)['G_ema']
+
+ G.forward = functools.partial(G.forward, force_fp32=True)
+ for i in tqdm(range(len(input_files))):
+ input_file = input_files[i]
+ input_file = os.path.join(input_path, input_file)
+ input_file = np.fromfile(input_file, dtype=np.float32)
+ z = torch.tensor(input_file).reshape(-1, G.z_dim).to(device)
+ c = torch.empty(bs, 0).to(device)
+ image = G(z, c)
+ image = image.reshape(-1, 3, 512, 512)
+ image = image.cpu()
+ save_image_grid(image, os.path.join(image_path, f'gen_image_{i:04d}') + '.png', drange=[-1, 1],
+ grid_size=grid_size)
+
+ end = time.time()
+ print(f'Inference average time : {((end - start) * 1000 / len(input_files)):.2f} ms')
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--stylegan_path', type=str, default='./stylegan2-ada-pytorch',
+ help='Path to the stylegan2-ada-pytorch Github directory')
+ parser.add_argument('--pkl_file', type=str, default='./G_ema_bs8_8p_kimg1000.pkl')
+ parser.add_argument('--input_path', type=str, default='./pre_data')
+ parser.add_argument('--image_path', type=str, default='./results')
+ parser.add_argument('--batch_size', type=int, default=1)
+ args = parser.parse_args()
+
+ sys.path.append(args.stylegan_path)
+
+ main(args)
diff --git a/AscendIE/TorchAIE/built-in/cv/gan/stylegan2-ada/preprocess.py b/AscendIE/TorchAIE/built-in/cv/gan/stylegan2-ada/preprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..1dcbbf8bccb32cfdce90d8ab54c32eb0f5df8c2d
--- /dev/null
+++ b/AscendIE/TorchAIE/built-in/cv/gan/stylegan2-ada/preprocess.py
@@ -0,0 +1,92 @@
+# Copyright 2023 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the License);
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# BSD 3-Clause License
+#
+# Copyright (c) 2017 xxxx
+# All rights reserved.
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+# ============================================================================
+
+import os
+import torch
+import argparse
+
+from tqdm import tqdm
+
+
+def parse_args():
+ args = argparse.ArgumentParser(description="A program that generates model data to be predicted")
+ args.add_argument('--num_input', type=int, default=1)
+ args.add_argument('--batch_size', type=int, default=1)
+ args.add_argument('--save_path', type=str, default='./input')
+ return args.parse_args()
+
+
+def check_args(args):
+ if args.num_input <= 0:
+ raise ValueError(f"num_input must be greater than 0. Got: {args.num_input}.")
+ if args.batch_size <= 0:
+ raise ValueError(f"batch_size must be greater than 0. Got: {args.batch_size}.")
+
+
+def main(args):
+ # set up option
+ z_dim = 512
+ c_dim = 0
+ bs = 1
+ num = args.num_input
+ save_path = args.save_path
+
+ # create save path dir
+ os.makedirs(save_path, exist_ok=True)
+
+ # generate input
+ for i in tqdm(range(num)):
+ z = torch.randn([bs, z_dim])
+ c = torch.empty([bs, c_dim])
+ input = torch.cat((z, c), 1).numpy()
+ input.tofile(os.path.join(save_path, f'input_bs{bs}_{i:04d}.bin'))
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ check_args(args)
+ main(args)
diff --git a/AscendIE/TorchAIE/built-in/cv/gan/stylegan2-ada/requirements.txt b/AscendIE/TorchAIE/built-in/cv/gan/stylegan2-ada/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..ed86fa1b6ef6eec02e1dd019f47b7fdd614663bd
--- /dev/null
+++ b/AscendIE/TorchAIE/built-in/cv/gan/stylegan2-ada/requirements.txt
@@ -0,0 +1,2 @@
+torch==2.0.1
+pillow==10.1.0
\ No newline at end of file
diff --git a/AscendIE/TorchAIE/built-in/cv/gan/stylegan2-ada/test_export.py b/AscendIE/TorchAIE/built-in/cv/gan/stylegan2-ada/test_export.py
new file mode 100644
index 0000000000000000000000000000000000000000..c5b7352dcea9aa84a61fb030920c0cb76ca41420
--- /dev/null
+++ b/AscendIE/TorchAIE/built-in/cv/gan/stylegan2-ada/test_export.py
@@ -0,0 +1,119 @@
+# Copyright 2023 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the License);
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# BSD 3-Clause License
+#
+# Copyright (c) 2017 xxxx
+# All rights reserved.
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+# ============================================================================
+
+import unittest
+import argparse
+import os
+import shutil
+from unittest.mock import patch
+
+from export import parse_args, check_args, trace_ts_model
+
+
+class TestParseArgs(unittest.TestCase):
+ @patch('argparse.ArgumentParser.parse_args', return_value=argparse.Namespace(
+ stylegan_path='./stylegan',
+ pkl_file='model.pth'
+ ))
+ def test_parse_args_with_arguments(self, mock_parse_args):
+ # Call the function
+ args = parse_args()
+
+ # Check if the function returns the expected arguments
+ self.assertEqual(args.stylegan_path, './stylegan')
+ self.assertEqual(args.pkl_file, 'model.pth')
+
+ # Verify that argparse.ArgumentParser.parse_args was called
+ mock_parse_args.assert_called_once()
+
+
+class TestCheckArgs(unittest.TestCase):
+ def test_nonexistent_stylegan_path(self):
+ args = argparse.Namespace(stylegan_path='nonexistent_stylegan.pth', pkl_file='nonexistent_model.pth')
+
+ with self.assertRaises(FileNotFoundError) as context:
+ check_args(args)
+ expected_message = f'The stylegan2-ada-pytorch Github directory {args.stylegan_path} not exists'
+ self.assertEqual(str(context.exception), expected_message)
+
+ def test_nonexistent_model_path(self):
+ stylegan_path = 'existing_stylegan'
+ os.makedirs(stylegan_path, exist_ok=True)
+
+ args = argparse.Namespace(stylegan_path=stylegan_path, pkl_file='nonexistent_model.pth')
+
+ with self.assertRaises(FileNotFoundError) as context:
+ check_args(args)
+ expected_message = f'The model file stylegan2-ada pth {args.pkl_file} not exists'
+ shutil.rmtree(stylegan_path)
+ self.assertEqual(str(context.exception), expected_message)
+
+
+class TestTraceTSModel(unittest.TestCase):
+ def setUp(self):
+ pass
+ # self.temp_dir = 'temp_dir'
+ # os.makedirs(self.temp_dir, exist_ok=True)
+
+ def tearDown(self):
+ pass
+ # shutil.rmtree(self.temp_dir)
+
+ @patch('argparse.ArgumentParser.parse_args', return_value=argparse.Namespace(
+ stylegan_path='./stylegan2-ada-pytorch-main',
+ pkl_file='./models/G_ema_bs8_8p_kimg1000.pkl'
+ ))
+ def test_trace_ts_model(self, mock_parse_args):
+ args = parse_args()
+ trace_ts_model(args)
+
+ # Check if the traced model file exists
+ self.assertTrue(os.path.exists("./stylegan2-bs.ts"))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/AscendIE/TorchAIE/built-in/cv/gan/stylegan2-ada/test_infer.py b/AscendIE/TorchAIE/built-in/cv/gan/stylegan2-ada/test_infer.py
new file mode 100644
index 0000000000000000000000000000000000000000..b15f7201db7fe64a242903ea3b6b2f85b74e9f31
--- /dev/null
+++ b/AscendIE/TorchAIE/built-in/cv/gan/stylegan2-ada/test_infer.py
@@ -0,0 +1,117 @@
+# Copyright 2023 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the License);
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# BSD 3-Clause License
+#
+# Copyright (c) 2017 xxxx
+# All rights reserved.
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+# ============================================================================
+
+import unittest
+import argparse
+import os
+import shutil
+from unittest.mock import patch
+
+from preprocess import parse_args, check_args, main
+from infer import test_om
+
+
+class TestParseArgs(unittest.TestCase):
+ @patch('argparse.ArgumentParser.parse_args', return_value=argparse.Namespace(
+ num_input=1,
+ batch_size=1
+ ))
+ def test_parse_args_with_arguments(self, mock_parse_args):
+ # Call the function
+ args = parse_args()
+
+ # Check if the function returns the expected arguments
+ self.assertEqual(args.num_input, 1)
+ self.assertEqual(args.batch_size, 1)
+
+ # Verify that argparse.ArgumentParser.parse_args was called
+ mock_parse_args.assert_called_once()
+
+
+class TestCheckArgs(unittest.TestCase):
+ def test_negative_num_input(self):
+ args = argparse.Namespace(num_input=-1, batch_size=1)
+
+ with self.assertRaises(ValueError) as context:
+ check_args(args)
+ expected_message = f"num_input must be greater than 0. Got: {args.num_input}."
+ self.assertEqual(str(context.exception), expected_message)
+
+ def test_negative_batch_size(self):
+ args = argparse.Namespace(num_input=1, batch_size=-1)
+
+ with self.assertRaises(ValueError) as context:
+ check_args(args)
+ expected_message = f"batch_size must be greater than 0. Got: {args.batch_size}."
+ self.assertEqual(str(context.exception), expected_message)
+
+
+class TestMain(unittest.TestCase):
+ def test_preprocess_main(self):
+ with patch('argparse.ArgumentParser.parse_args', return_value=argparse.Namespace(
+ num_input=200,
+ batch_size=1,
+ save_path="temp_path"
+ )):
+ args = parse_args()
+ check_args(args)
+ main(args)
+ files_in_folder = os.listdir(args.save_path)
+ self.assertEqual(len(files_in_folder), args.num_input, f"Expected {args.num_input} files, but found {len(files_in_folder)}.")
+
+ def test_infer(self):
+ with patch('argparse.ArgumentParser.parse_args', return_value=argparse.Namespace(
+ bin_path="pre_data", # 假设已存在
+ image_path="./results-bak",
+ ts_model_path="./stylegan2-bs.ts"
+ )):
+ args = parse_args()
+ test_om(args)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/AscendIE/TorchAIE/built-in/cv/gan/stylegan2-ada/test_perf.py b/AscendIE/TorchAIE/built-in/cv/gan/stylegan2-ada/test_perf.py
new file mode 100644
index 0000000000000000000000000000000000000000..24404c8653e4fa1928062351b34486d5e7d623ea
--- /dev/null
+++ b/AscendIE/TorchAIE/built-in/cv/gan/stylegan2-ada/test_perf.py
@@ -0,0 +1,91 @@
+# Copyright 2023 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the License);
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# BSD 3-Clause License
+#
+# Copyright (c) 2017 xxxx
+# All rights reserved.
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+# ============================================================================
+
+import unittest
+import argparse
+from unittest.mock import patch
+import torch
+
+from perf import parse_args, perf
+
+import torch_aie
+from torch_aie import _enums
+
+
+class TestStyleGANPerf(unittest.TestCase):
+ def test_parse_args(self):
+ with patch('argparse.ArgumentParser.parse_args', return_value=argparse.Namespace(
+ ts_path='./stylegan2-bs.ts',
+ batch_size=1,
+ image_size=512
+ )):
+ args = parse_args()
+
+ self.assertEqual(args.ts_path, './stylegan2-bs.ts')
+ self.assertEqual(args.batch_size, 1)
+ self.assertEqual(args.image_size, 512)
+
+ def test_perf(self):
+ with patch('argparse.ArgumentParser.parse_args', return_value=argparse.Namespace(
+ ts_path='./stylegan2-bs.ts',
+ batch_size=1,
+ image_size=512
+ )):
+ args = parse_args()
+ ts_model = torch.jit.load(args.ts_path)
+ input_info = [torch_aie.Input((args.batch_size, args.image_size)), torch_aie.Input((args.batch_size, 0))]
+ torchaie_model = torch_aie.compile(
+ ts_model,
+ inputs=input_info,
+ precision_policy=_enums.PrecisionPolicy.FP32, # 必须是32才能跑
+ soc_version='Ascend310P3'
+ )
+ torchaie_model.eval()
+ perf(torchaie_model, args.batch_size, args.image_size)
+
+
+if __name__ == '__main__':
+ unittest.main()