diff --git a/ACL_PyTorch/contrib/cv/image_process/DnCNN/DnCNN_710/.keep b/ACL_PyTorch/contrib/cv/image_process/DnCNN/DnCNN_710/.keep new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ACL_PyTorch/contrib/cv/image_process/DnCNN/DnCNN_710/DnCNN_postprocess.py b/ACL_PyTorch/contrib/cv/image_process/DnCNN/DnCNN_710/DnCNN_postprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..dbaf8635a6f9cb01c11193a030c60e08ed8b2147 --- /dev/null +++ b/ACL_PyTorch/contrib/cv/image_process/DnCNN/DnCNN_710/DnCNN_postprocess.py @@ -0,0 +1,103 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import os +import sys +import glob +import numpy as np +import cv2 +import torch +import torch.nn as nn +import struct +from skimage.metrics import peak_signal_noise_ratio as compare_psnr + + +def batch_PSNR(img, imclean, data_range): + + Img = img.data.cpu().numpy().astype(np.float32) + Iclean = imclean.data.cpu().numpy().astype(np.float32) + PSNR = 0 + for i in range(Img.shape[0]): + PSNR += compare_psnr(Iclean[i, :, :, :], Img[i, :, :, :], data_range=data_range) + return (PSNR / Img.shape[0]) + + +def bin2npy(filepath): + + size = os.path.getsize(filepath) + res = [] + L = int(size / 4) + binfile = open(filepath, 'rb') + for i in range(L): + data = binfile.read(4) + num = struct.unpack('f', data) + res.append(num[0]) + binfile.close() + dim_res = np.array(res).reshape(1, 1, 481, 481) + return dim_res + + +def main(Result_path): + + # load data info + print('Loading ISource bin ...\n') + ISource = glob.glob(os.path.join('ISource', '*.bin')) + ISource.sort() + print('Loading INoisy bin ...\n') + INoisy = glob.glob(os.path.join('INoisy', '*.bin')) + INoisy.sort() + # load result file + print('Loading res bin ...\n') + Result_path = glob.glob(os.path.join(Result_path, '*.bin')) + Result_path.sort() + + # begin data + print('begin infer') + psnr_test = 0 + n_lables = 0 + + for isource in ISource: + isource_name = isource + # isource + isource = bin2npy(isource) + isource = torch.from_numpy(isource) + # inoisy + inoisy = bin2npy(INoisy[n_lables]) + inoisy = torch.from_numpy(inoisy) + # Result_path + Result = bin2npy(Result_path[n_lables]) + Result = torch.from_numpy(Result) + n_lables += 1 + print('infering...') + with torch.no_grad(): + Out = torch.clamp(inoisy - Result, 0., 1.) + psnr = batch_PSNR(Out, isource, 1.) + psnr_test += psnr + print("%s PSNR %f" % (isource_name, psnr)) + psnr_test /= len(ISource) + print("\nPSNR on test data %f" % psnr_test) + +if __name__ == "__main__": + + try: + Result_path = sys.argv[1] + + except IndexError: + print("Stopped!") + exit(1) + + if not (os.path.exists(Result_path)): + print("Result path doesn't exist.") + + main(Result_path) diff --git a/ACL_PyTorch/contrib/cv/image_process/DnCNN/DnCNN_710/DnCNN_preprocess.py b/ACL_PyTorch/contrib/cv/image_process/DnCNN/DnCNN_710/DnCNN_preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..dd612dcef814aa1947311432bdab6c9102968092 --- /dev/null +++ b/ACL_PyTorch/contrib/cv/image_process/DnCNN/DnCNN_710/DnCNN_preprocess.py @@ -0,0 +1,79 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import sys +import os +import os.path +import numpy as np +import random +import torch +import cv2 +import glob + +infer_data = 'Set68' +infer_noiseL = 15 + +def normalize(data): + return data / 255. + + +def proprecess(data_path, ISource_bin, INoisy_bin): + + # load data info + print('Loading data info ...\n') + files = glob.glob(os.path.join(data_path, infer_data, '*.png')) + files.sort() + # process data + for i in range(len(files)): + # image + filename = os.path.basename(files[i]) + img = cv2.imread(files[i]) + img = normalize(np.float32(img[:, :, 0])) + + img_padded = np.full([481, 481], 0, dtype=np.float32) + width_offset = (481 - img.shape[1]) // 2 + height_offset = (481 - img.shape[0]) // 2 + img_padded[height_offset:height_offset + img.shape[0], width_offset:width_offset + img.shape[1]] = img + img = img_padded + + img = np.expand_dims(img, 0) + img = np.expand_dims(img, 1) + + ISource = torch.Tensor(img) + # noise + noise = torch.FloatTensor(ISource.size()).normal_(mean=0, std=infer_noiseL / 255.) + # noisy image + INoisy = ISource + noise + + # save ISource_bin + ISource = ISource.numpy() + print("ISource shape is", ISource.shape) + ISource.tofile(os.path.join(ISource_bin, filename.split('.')[0] + '.bin')) + + # save INoisy_bin + INoisy = INoisy.numpy() + print("INoisy shape is", INoisy.shape) + INoisy.tofile(os.path.join(INoisy_bin, filename.split('.')[0] + '.bin')) + +if __name__ == '__main__': + + data_path = sys.argv[1] + ISource_bin = sys.argv[2] + INoisy_bin = sys.argv[3] + if os.path.exists(ISource_bin) is False: + os.mkdir(ISource_bin) + if os.path.exists(INoisy_bin) is False: + os.mkdir(INoisy_bin) + + proprecess(data_path, ISource_bin, INoisy_bin) diff --git a/ACL_PyTorch/contrib/cv/image_process/DnCNN/DnCNN_710/DnCNN_pth2onnx.py b/ACL_PyTorch/contrib/cv/image_process/DnCNN/DnCNN_710/DnCNN_pth2onnx.py new file mode 100644 index 0000000000000000000000000000000000000000..8362e6eea7735e96a7763fd570a6553104eb0a7d --- /dev/null +++ b/ACL_PyTorch/contrib/cv/image_process/DnCNN/DnCNN_710/DnCNN_pth2onnx.py @@ -0,0 +1,80 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import torch +import torch.onnx +import torch.nn as nn +import sys + +from collections import OrderedDict + +class DnCNN(nn.Module): + def __init__(self, channels, num_of_layers=17): + super(DnCNN, self).__init__() + kernel_size = 3 + padding = 1 + features = 64 + layers = [] + layers.append(nn.Conv2d(in_channels=channels, out_channels=features, \ + kernel_size=kernel_size, padding=padding, bias=False)) + layers.append(nn.ReLU(inplace=True)) + for _ in range(num_of_layers - 2): + layers.append(nn.Conv2d(in_channels=features, out_channels=features, \ + kernel_size=kernel_size, padding=padding, bias=False)) + layers.append(nn.BatchNorm2d(features)) + layers.append(nn.ReLU(inplace=True)) + layers.append(nn.Conv2d(in_channels=features, out_channels=channels, \ + kernel_size=kernel_size, padding=padding, bias=False)) + self.dncnn = nn.Sequential(*layers) + + def forward(self, x): + + out = self.dncnn(x) + return out + + +def proc_nodes_module(checkpoint): + + new_state_dict = OrderedDict() + for k, v in checkpoint.items(): + if(k[0:7] == "module."): + name = k[7:] + else: + name = k[0:] + new_state_dict[name]=v + return new_state_dict + + +def convert(pth_file, onnx_file): + + pretrained_net = torch.load(pth_file, map_location='cpu') + pretrained_net['state_dict'] = proc_nodes_module(pretrained_net) + + model = DnCNN(channels=1, num_of_layers=17) + model.load_state_dict(pretrained_net['state_dict']) + model.eval() + input_names = ["actual_input_1"] + dummy_input = torch.randn(1, 1, 481, 481) + #torch.onnx.export(model, dummy_input, onnx_file, input_names = input_names, opset_version=11, verbose=True) + dynamic_axes = {'actual_input_1': {0: '-1'}} + torch.onnx.export(model, dummy_input, onnx_file, dynamic_axes=dynamic_axes, \ + input_names=input_names, opset_version=11) + +if __name__ == "__main__": + + pth_file = sys.argv[1] + onnx_file = sys.argv[2] + + convert(pth_file, onnx_file) diff --git a/ACL_PyTorch/contrib/cv/image_process/DnCNN/DnCNN_710/get_info.py b/ACL_PyTorch/contrib/cv/image_process/DnCNN/DnCNN_710/get_info.py new file mode 100644 index 0000000000000000000000000000000000000000..def864bec0be6d2a594819936f282ea03aba46f4 --- /dev/null +++ b/ACL_PyTorch/contrib/cv/image_process/DnCNN/DnCNN_710/get_info.py @@ -0,0 +1,60 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import os +import sys +import cv2 +from glob import glob + + +def get_bin_info(file_path, info_name, width, height): + bin_images = glob(os.path.join(file_path, '*.bin')) + with open(info_name, 'w') as file: + for index, img in enumerate(bin_images): + content = ' '.join([str(index), img, width, height]) + file.write(content) + file.write('\n') + + +def get_jpg_info(file_path, info_name): + extensions = ['jpg', 'jpeg', 'JPG', 'JPEG'] + image_names = [] + for extension in extensions: + image_names.append(glob(os.path.join(file_path, '*.' + extension))) + with open(info_name, 'w') as file: + for image_name in image_names: + if len(image_name) == 0: + continue + else: + for index, img in enumerate(image_name): + img_cv = cv2.imread(img) + shape = img_cv.shape + width, height = shape[1], shape[0] + content = ' '.join([str(index), img, str(width), str(height)]) + file.write(content) + file.write('\n') + + +if __name__ == '__main__': + file_type = sys.argv[1] + file_path = sys.argv[2] + info_name = sys.argv[3] + if file_type == 'bin': + width = sys.argv[4] + height = sys.argv[5] + assert len(sys.argv) == 6, 'The number of input parameters must be equal to 5' + get_bin_info(file_path, info_name, width, height) + elif file_type == 'jpg': + assert len(sys.argv) == 4, 'The number of input parameters must be equal to 3' + get_jpg_info(file_path, info_name) diff --git a/ACL_PyTorch/contrib/cv/image_process/DnCNN/DnCNN_710/requirements.txt b/ACL_PyTorch/contrib/cv/image_process/DnCNN/DnCNN_710/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..965bdf8bdd73d4bbaab48e2fcc37f0cab79f100d --- /dev/null +++ b/ACL_PyTorch/contrib/cv/image_process/DnCNN/DnCNN_710/requirements.txt @@ -0,0 +1,6 @@ +torch==1.8.0 +torchvision==0.9.0 +onnx==1.9.0 +numpy==1.20.2 +opencv-python==4.5.2.52 +scikit-image==0.16.2 \ No newline at end of file diff --git a/ACL_PyTorch/contrib/cv/image_process/DnCNN/READEME.md b/ACL_PyTorch/contrib/cv/image_process/DnCNN/READEME.md index 592f7f9936ff3c5bcb93ad47f60fcb136bf32842..35ee9aa42b639f6d9bb55976fe5e294f92692d83 100644 --- a/ACL_PyTorch/contrib/cv/image_process/DnCNN/READEME.md +++ b/ACL_PyTorch/contrib/cv/image_process/DnCNN/READEME.md @@ -108,10 +108,22 @@ python3.7 DnCNN_pth2onnx.py net.pth DnCNN-S-15.onnx ``` source env.sh ``` -2.使用atc将onnx模型转换为om模型文件 +2.增加benchmark.{arch}可执行权限。 +``` +chmod u+x benchmark.x86_64 +``` +3.使用atc将onnx模型转换为om模型文件 ``` atc --framework=5 --model=./DnCNN-S-15.onnx --input_format=NCHW --input_shape="actual_input_1:1,1,481,481" --output=DnCNN-S-15_bs1 --log=debug --soc_version=Ascend310 ``` +(710_bs1) +``` +atc --framework=5 --model=./DnCNN-S-15.onnx --input_format=NCHW --input_shape="actual_input_1:1,1,481,481" --output=DnCNN-S-15_bs1 --log=debug --soc_version=Ascend710 +``` +(710_bs16) +``` +atc --framework=5 --model=./DnCNN-S-15.onnx --input_format=NCHW --input_shape="actual_input_1:16,1,481,481" --output=DnCNN-S-15_bs16 --log=debug --soc_version=Ascend710 +``` ## 4 数据集预处理 @@ -155,9 +167,15 @@ benchmark工具为华为自研的模型推理工具,支持多种模型的离 source env.sh ``` 2.执行离线推理 +(bs1) +``` +./benchmark.x86_64 -model_type=vision -om_path=DnCNN-S-15_bs1.om -device_id=0 -batch_size=1 -input_text_path=DnCNN_bin.info -input_width=481 -input_height=481 -useDvpp=false -output_binary=true ``` -./benchmark.x86_64 -model_type=vision -om_path=DnCNN-S-15.om -device_id=0 -batch_size=1 -input_text_path=DnCNN_bin.info -input_width=481 -input_height=481 -useDvpp=false -output_binary=true +(bs16) ``` +./benchmark.x86_64 -model_type=vision -om_path=DnCNN-S-15_bs16.om -device_id=0 -batch_size=16 -input_text_path=DnCNN_bin.info -input_width=481 -input_height=481 -useDvpp=false -output_binary=true +``` + 输出结果默认保存在当前目录result/dumpOutput_deviceX(X为对应的device_id),每个输入对应的输出对应一个_X.bin文件。 ## 6 精度对比 diff --git a/ACL_PyTorch/contrib/cv/image_process/DnCNN/postprocess.py b/ACL_PyTorch/contrib/cv/image_process/DnCNN/postprocess.py index 0333302e810f7ded587c87489659fdf8a766484e..7945fd671a25c6e800b5ef1b4631722d09bce708 100644 --- a/ACL_PyTorch/contrib/cv/image_process/DnCNN/postprocess.py +++ b/ACL_PyTorch/contrib/cv/image_process/DnCNN/postprocess.py @@ -20,7 +20,7 @@ import cv2 import torch import torch.nn as nn import struct -from skimage.measure.simple_metrics import compare_psnr +from skimage.metrics import peak_signal_noise_ratio as compare_psnr def batch_PSNR(img, imclean, data_range):