From bce2d7d5a62483937749eb88c6f65980406f2075 Mon Sep 17 00:00:00 2001 From: Traly <445326569@qq.com> Date: Thu, 12 Jan 2023 22:08:35 +0800 Subject: [PATCH] GhostSR model commit update research/cv/GhostSR/GhostSR/EDSR_mindspore/common_gumbel_softmax_ms.py. Signed-off-by: lose4578 <445326569@qq.com> update research/cv/GhostSR/GhostSR/EDSR_mindspore/common_gumbel_softmax_ms.py. Signed-off-by: lose4578 <445326569@qq.com> update research/cv/GhostSR/GhostSR/EDSR_mindspore/common_gumbel_softmax_ms.py. Signed-off-by: lose4578 <445326569@qq.com> --- research/cv/GhostSR/DIV2K_config.yaml | 129 +++++++ .../common_gumbel_softmax_ms.py | 231 ++++++++++++ .../cv/GhostSR/GhostSR/EDSR_mindspore/edsr.py | 166 +++++++++ .../GhostSR/unsupported_model/PixelShuffle.py | 131 +++++++ .../GhostSR/unsupported_model/__init__.py | 0 research/cv/GhostSR/README_CN.md | 229 ++++++++++++ .../cv/GhostSR/ascend310_infer/CMakeLists.txt | 14 + research/cv/GhostSR/ascend310_infer/build.sh | 26 ++ .../cv/GhostSR/ascend310_infer/inc/utils.h | 33 ++ .../cv/GhostSR/ascend310_infer/src/main.cc | 141 ++++++++ .../cv/GhostSR/ascend310_infer/src/utils.cc | 144 ++++++++ research/cv/GhostSR/benchmark_config.yaml | 71 ++++ research/cv/GhostSR/eval.py | 207 +++++++++++ research/cv/GhostSR/eval_onnx.py | 111 ++++++ research/cv/GhostSR/export.py | 65 ++++ .../infer/convert/aipp_edsr_opencv.cfg | 5 + .../cv/GhostSR/infer/convert/convert_om.sh | 28 ++ .../GhostSR/infer/data/config/edsr.pipeline | 28 ++ .../cv/GhostSR/infer/docker_start_infer.sh | 49 +++ .../cv/GhostSR/infer/mxbase/CMakeLists.txt | 55 +++ .../infer/mxbase/EdsrSuperresolution.cpp | 200 +++++++++++ .../infer/mxbase/EdsrSuperresolution.h | 54 +++ research/cv/GhostSR/infer/mxbase/build.sh | 46 +++ research/cv/GhostSR/infer/mxbase/main.cpp | 46 +++ research/cv/GhostSR/infer/sdk/eval.py | 72 ++++ research/cv/GhostSR/infer/sdk/main.py | 45 +++ research/cv/GhostSR/infer/sdk/run.sh | 76 ++++ .../cv/GhostSR/infer/sdk/sr_infer_wrapper.py | 127 +++++++ research/cv/GhostSR/mindspore_hub_conf.py | 26 ++ research/cv/GhostSR/model_utils/__init__.py | 0 research/cv/GhostSR/model_utils/config.py | 136 +++++++ .../cv/GhostSR/model_utils/device_adapter.py | 27 ++ .../cv/GhostSR/model_utils/local_adapter.py | 37 ++ .../cv/GhostSR/model_utils/moxing_adapter.py | 124 +++++++ research/cv/GhostSR/modelarts/train_start.py | 130 +++++++ research/cv/GhostSR/postprocess.py | 193 ++++++++++ research/cv/GhostSR/preprocess.py | 100 ++++++ research/cv/GhostSR/requirements.txt | 7 + research/cv/GhostSR/scripts/run_eval.sh | 55 +++ research/cv/GhostSR/scripts/run_eval_onnx.sh | 30 ++ research/cv/GhostSR/scripts/run_infer_310.sh | 135 +++++++ research/cv/GhostSR/scripts/run_train.sh | 55 +++ research/cv/GhostSR/src/__init__.py | 0 research/cv/GhostSR/src/dataset.py | 333 +++++++++++++++++ research/cv/GhostSR/src/edsr.py | 204 +++++++++++ research/cv/GhostSR/src/metric.py | 338 ++++++++++++++++++ research/cv/GhostSR/src/utils.py | 212 +++++++++++ research/cv/GhostSR/train.py | 155 ++++++++ 48 files changed, 4826 insertions(+) create mode 100644 research/cv/GhostSR/DIV2K_config.yaml create mode 100644 research/cv/GhostSR/GhostSR/EDSR_mindspore/common_gumbel_softmax_ms.py create mode 100644 research/cv/GhostSR/GhostSR/EDSR_mindspore/edsr.py create mode 100644 research/cv/GhostSR/GhostSR/unsupported_model/PixelShuffle.py create mode 100644 research/cv/GhostSR/GhostSR/unsupported_model/__init__.py create mode 100644 research/cv/GhostSR/README_CN.md create mode 100644 
research/cv/GhostSR/ascend310_infer/CMakeLists.txt create mode 100644 research/cv/GhostSR/ascend310_infer/build.sh create mode 100644 research/cv/GhostSR/ascend310_infer/inc/utils.h create mode 100644 research/cv/GhostSR/ascend310_infer/src/main.cc create mode 100644 research/cv/GhostSR/ascend310_infer/src/utils.cc create mode 100644 research/cv/GhostSR/benchmark_config.yaml create mode 100644 research/cv/GhostSR/eval.py create mode 100644 research/cv/GhostSR/eval_onnx.py create mode 100644 research/cv/GhostSR/export.py create mode 100644 research/cv/GhostSR/infer/convert/aipp_edsr_opencv.cfg create mode 100644 research/cv/GhostSR/infer/convert/convert_om.sh create mode 100644 research/cv/GhostSR/infer/data/config/edsr.pipeline create mode 100644 research/cv/GhostSR/infer/docker_start_infer.sh create mode 100644 research/cv/GhostSR/infer/mxbase/CMakeLists.txt create mode 100644 research/cv/GhostSR/infer/mxbase/EdsrSuperresolution.cpp create mode 100644 research/cv/GhostSR/infer/mxbase/EdsrSuperresolution.h create mode 100644 research/cv/GhostSR/infer/mxbase/build.sh create mode 100644 research/cv/GhostSR/infer/mxbase/main.cpp create mode 100644 research/cv/GhostSR/infer/sdk/eval.py create mode 100644 research/cv/GhostSR/infer/sdk/main.py create mode 100644 research/cv/GhostSR/infer/sdk/run.sh create mode 100644 research/cv/GhostSR/infer/sdk/sr_infer_wrapper.py create mode 100644 research/cv/GhostSR/mindspore_hub_conf.py create mode 100644 research/cv/GhostSR/model_utils/__init__.py create mode 100644 research/cv/GhostSR/model_utils/config.py create mode 100644 research/cv/GhostSR/model_utils/device_adapter.py create mode 100644 research/cv/GhostSR/model_utils/local_adapter.py create mode 100644 research/cv/GhostSR/model_utils/moxing_adapter.py create mode 100644 research/cv/GhostSR/modelarts/train_start.py create mode 100644 research/cv/GhostSR/postprocess.py create mode 100644 research/cv/GhostSR/preprocess.py create mode 100644 research/cv/GhostSR/requirements.txt create mode 100644 research/cv/GhostSR/scripts/run_eval.sh create mode 100644 research/cv/GhostSR/scripts/run_eval_onnx.sh create mode 100644 research/cv/GhostSR/scripts/run_infer_310.sh create mode 100644 research/cv/GhostSR/scripts/run_train.sh create mode 100644 research/cv/GhostSR/src/__init__.py create mode 100644 research/cv/GhostSR/src/dataset.py create mode 100644 research/cv/GhostSR/src/edsr.py create mode 100644 research/cv/GhostSR/src/metric.py create mode 100644 research/cv/GhostSR/src/utils.py create mode 100644 research/cv/GhostSR/train.py diff --git a/research/cv/GhostSR/DIV2K_config.yaml b/research/cv/GhostSR/DIV2K_config.yaml new file mode 100644 index 000000000..5cbfc9534 --- /dev/null +++ b/research/cv/GhostSR/DIV2K_config.yaml @@ -0,0 +1,129 @@ +# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing) +enable_modelarts: False +# Url for modelarts +data_url: "" +train_url: "" +checkpoint_url: "" +# Path for local +data_path: "/cache/data/DIV2K" +output_path: "./output" +device_target: "GPU" + +network: "EDSR4GhostSRMs" # EDSR_mindspore / EDSR4GhostSRMs + +# ============================================================================== +# train options +amp_level: "O3" +loss_scale: 1000.0 # for ['O2', 'O3', 'auto'] +keep_checkpoint_max: 60 +save_epoch_frq: 100 +ckpt_save_dir: "./ckpt/" +epoch_size: 6000 + +# eval options +eval_epoch_frq: 20 +self_ensemble: False +save_sr: True +eval_type: "" + +# Adam opt options +opt_type: Adam +weight_decay: 0.0 + +# learning 
rate options
+learning_rate: 0.0001
+milestones: [ 4000 ]
+gamma: 0.5
+
+# dataset options
+dataset_name: "DIV2K"
+lr_type: "bicubic"
+batch_size: 2
+patch_size: 192
+scale: 2
+dataset_sink_mode: True
+need_unzip_in_modelarts: False
+need_unzip_files:
+    - "DIV2K_train_HR.zip"
+    - "DIV2K_train_LR_bicubic_X2.zip"
+    - "DIV2K_train_LR_bicubic_X3.zip"
+    - "DIV2K_train_LR_bicubic_X4.zip"
+    - "DIV2K_train_LR_unknown_X2.zip"
+    - "DIV2K_train_LR_unknown_X3.zip"
+    - "DIV2K_train_LR_unknown_X4.zip"
+    - "DIV2K_valid_HR.zip"
+    - "DIV2K_valid_LR_bicubic_X2.zip"
+    - "DIV2K_valid_LR_bicubic_X3.zip"
+    - "DIV2K_valid_LR_bicubic_X4.zip"
+    - "DIV2K_valid_LR_unknown_X2.zip"
+    - "DIV2K_valid_LR_unknown_X3.zip"
+    - "DIV2K_valid_LR_unknown_X4.zip"
+
+# net options
+pre_trained: ""
+rgb_range: 255
+rgb_mean: [ 0.4488, 0.4371, 0.4040 ]
+rgb_std: [ 1.0, 1.0, 1.0 ]
+n_colors: 3
+n_feats: 256
+kernel_size: 3
+n_resblocks: 32
+res_scale: 0.1
+
+
+---
+# helper
+
+enable_modelarts: "set True if run in modelarts, default: False"
+# Url for modelarts
+data_url: "modelarts data path"
+train_url: "modelarts code path"
+checkpoint_url: "modelarts checkpoint save path"
+# Path for local
+data_path: "local data path, data will be downloaded from 'data_url', default: /cache/data"
+output_path: "local output path, checkpoint will be uploaded to 'checkpoint_url', default: /cache/train"
+device_target: "choice from ['Ascend'], default: Ascend"
+
+# ==============================================================================
+# train options
+amp_level: "choice from ['O0', 'O2', 'O3', 'auto'], default: O3"
+loss_scale: "loss scale will be used except 'O0', default: 1000.0"
+keep_checkpoint_max: "max number of checkpoints to be saved, default: 60"
+save_epoch_frq: "frequency to save checkpoint, default: 100"
+ckpt_save_dir: "the relative path to save checkpoint, root path is 'output_path', default: ./ckpt/"
+epoch_size: "the number of training epochs, default: 6000"
+
+# eval options
+eval_epoch_frq: "frequency to evaluate model, default: 20"
+self_ensemble: "set True if you want to do self-ensemble while evaluating, default: True"
+save_sr: "set True if you want to save sr and hr images while evaluating, default: True"
+
+# opt options
+opt_type: "optimizer type, choice from ['Adam'], default: Adam"
+weight_decay: "weight_decay for optimizer, default: 0.0"
+
+# learning rate options
+learning_rate: "learning rate, default: 0.0001"
+milestones: "the key epochs to do a gamma decay, default: [4000]"
+gamma: "gamma decay rate, default: 0.5"
+
+# dataset options
+dataset_name: "dataset name, default: DIV2K"
+lr_type: "lr image degradation type, choice from ['bicubic', 'unknown'], default: bicubic"
+batch_size: "batch size for training; total batch size = 16 is recommended, default: 2"
+patch_size: "cut hr images into patch size for training, lr images auto-adjust by 'scale', default: 192"
+scale: "scale for super resolution reconstruction, choice from [2,3,4], default: 4"
+dataset_sink_mode: "set True if you want to use dataset sink mode, default: True"
+need_unzip_in_modelarts: "set True if you want to unzip data after downloading it from s3, default: False"
+need_unzip_files: "list of zip files to unzip, only works while 'need_unzip_in_modelarts'=True"
+
+# net options
+pre_trained: "load pre-trained model, x2/x3/x4 models can be loaded for each other, choice from [[S3_ABS_PATH], [RELATIVE_PATH below 'output_path'], [LOCAL_ABS_PATH], ''], default: ''"
+rgb_range: "pixel value range, default: 255"
+rgb_mean: "rgb mean, default: [0.4488, 0.4371, 0.4040]"
+rgb_std: "rgb standard deviation, default: [1.0, 1.0, 1.0]"
"rgb standard deviation, defalue: [1.0, 1.0, 1.0]" +n_colors: "the number of RGB image channels, defalue: 3" +n_feats: "the number of output features for each Conv2d, defalue: 256" +kernel_size: "kernel size for Conv2d, defalue: 3" +n_resblocks: "the number of resblocks, defalue: 32" +res_scale: "zoom scale of res branch, defalue: 0.1" diff --git a/research/cv/GhostSR/GhostSR/EDSR_mindspore/common_gumbel_softmax_ms.py b/research/cv/GhostSR/GhostSR/EDSR_mindspore/common_gumbel_softmax_ms.py new file mode 100644 index 000000000..06154d5e2 --- /dev/null +++ b/research/cv/GhostSR/GhostSR/EDSR_mindspore/common_gumbel_softmax_ms.py @@ -0,0 +1,231 @@ +# 2022.12.27-Changed for EDSR-PyTorch +# Huawei Technologies Co., Ltd. +# Copyright 2022 Huawei Technologies Co., Ltd. +# Copyright 2018 sanghyun-son (https://github.com/sanghyun-son/EDSR-PyTorch). +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import math +import numpy as np + +import mindspore as ms +from mindspore import nn +import mindspore.common.initializer as init +from GhostSR.unsupported_model.PixelShuffle import PixelShuffle + + + +def exponential_decay(t, _init=10, m=200, finish=1e-2): + alpha = np.log(_init / finish) / m + l = - np.log(_init) / alpha + decay = np.exp(-alpha * (t + l)) + return decay + + +def sample_gumbel(size, eps=1e-20): + # print('size.dtype: ', size[0].dtype) + uniform_real = ms.ops.UniformReal() + U = uniform_real(size) + return -ms.ops.log(-ms.ops.log(U + eps) + eps) + + +def gumbel_softmax(weights, epoch): + noise_temp = 0.97 ** (epoch - 1) + noise = sample_gumbel(weights.shape) * noise_temp + y = weights + noise + y_abs = y.abs().view(1, -1) + y_hard = ms.ops.zeros_like(y_abs) + y_hard[0, ms.ops.Argmax()(y_abs)] = 1 + y_hard = y_hard.view(weights.shape) + # ret = (y_hard - weights).detach() + weights + ret = ms.ops.stop_gradient(y_hard - weights) + weights + return ret + + +def hard_softmax(weights): + y_abs = weights.abs().view(1, -1) + y_hard = ms.ops.ZerosLike()(y_abs) + y_hard[0, ms.ops.Argmax()(y_abs)] = 1 + y_hard = y_hard.view(weights.shape) + return y_hard + + +# 1*1*3*3 shift +class ShiftConvGeneral(nn.Cell): + def __init__(self, act_channel, in_channels=1, out_channels=1, kernel_size=3, stride=1, + padding=1, groups=1, + bias=False): + super(ShiftConvGeneral, self).__init__() + self.stride = stride + self.padding = padding + self.bias = bias + self.epoch = 1 + self.act_channel = act_channel + # self.w_out_channels = in_channels // groups + self.kernel_size = kernel_size + self.weight = ms.Parameter( + ms.Tensor(shape=(out_channels, in_channels // groups, kernel_size, kernel_size), + dtype=ms.float16, + init=init.HeUniform(negative_slope=math.sqrt(5))), requires_grad=True) + if bias: + self.b = ms.Parameter(ms.ops.Zeros(act_channel), requires_grad=True) + # self.reset_parameters() + + def reset_parameters(self): + init.HeUniform(self.weight, a=math.sqrt(5)) + + def construct(self, x): + assert x.shape[1] == self.act_channel + if self.training: + w = gumbel_softmax(self.weight, self.epoch) + else: + w = 
+
+
+# 1*1*3*3 shift
+class ShiftConvGeneral(nn.Cell):
+    def __init__(self, act_channel, in_channels=1, out_channels=1, kernel_size=3, stride=1,
+                 padding=1, groups=1,
+                 bias=False):
+        super(ShiftConvGeneral, self).__init__()
+        self.stride = stride
+        self.padding = padding
+        self.bias = bias
+        self.epoch = 1
+        self.act_channel = act_channel
+        self.kernel_size = kernel_size
+        self.weight = ms.Parameter(
+            ms.Tensor(shape=(out_channels, in_channels // groups, kernel_size, kernel_size),
+                      dtype=ms.float16,
+                      init=init.HeUniform(negative_slope=math.sqrt(5))), requires_grad=True)
+        if bias:
+            self.b = ms.Parameter(ms.ops.Zeros()((act_channel,), ms.float16), requires_grad=True)
+        # self.reset_parameters()
+
+    def reset_parameters(self):
+        self.weight.set_data(init.initializer(init.HeUniform(negative_slope=math.sqrt(5)),
+                                              self.weight.shape, self.weight.dtype))
+
+    def construct(self, x):
+        assert x.shape[1] == self.act_channel
+        if self.training:
+            w = gumbel_softmax(self.weight, self.epoch)
+        else:
+            w = hard_softmax(self.weight)
+        w = w.astype(x.dtype)
+        w = ms.numpy.tile(w, (x.shape[1], 1, 1, 1))
+        out = ms.ops.Conv2D(self.act_channel, self.kernel_size, stride=self.stride,
+                            pad=self.padding, dilation=1,
+                            pad_mode='pad', group=x.shape[1])(x, w)
+        if self.bias:
+            out += self.b.unsqueeze(0).unsqueeze(2).unsqueeze(3)
+        return out
+
+
+# Shift all ghost channels in a single direction; ratio != 0.5 is also supported
+class GhostModule(nn.Cell):
+    def __init__(self, inp, oup, kernel_size, dir_num, ratio=0.5, stride=1, bias=True):
+        super(GhostModule, self).__init__()
+        self.oup = oup
+        init_channels = math.ceil(oup * ratio)
+        new_channels = oup - init_channels
+
+        self.primary_conv = nn.Conv2d(inp, init_channels, kernel_size, stride,
+                                      pad_mode='pad', padding=kernel_size // 2, has_bias=bias)
+        self.cheap_conv = ShiftConvGeneral(new_channels, 1, 1, kernel_size=3, stride=1, padding=1,
+                                           groups=1, bias=False)
+        self.concat = ms.ops.Concat(axis=1)
+
+        self.init_channels = init_channels
+        self.new_channels = new_channels
+
+    def construct(self, x):
+        if self.init_channels > self.new_channels:
+            x1 = self.primary_conv(x)
+            x2 = self.cheap_conv(x1[:, :self.new_channels, :, :])
+        elif self.init_channels == self.new_channels:
+            x1 = self.primary_conv(x)
+            x2 = self.cheap_conv(x1)
+        # elif self.init_channels < self.new_channels:
+        else:
+            x1 = self.primary_conv(x)
+            x1 = ms.numpy.tile(x1, (1, 3, 1, 1))
+            x2 = self.cheap_conv(x1[:, :self.new_channels, :, :])
+        out = self.concat([x1, x2])
+        return out[:, :self.oup, :, :]
+
+
+def default_conv(in_channels, out_channels, kernel_size, bias=True):
+    return nn.Conv2d(
+        in_channels, out_channels, kernel_size,
+        pad_mode='pad', padding=(kernel_size // 2), has_bias=bias)
+
+
+def default_ghost(in_channels, out_channels, kernel_size, dir_num, bias=True):
+    return GhostModule(
+        in_channels, out_channels, kernel_size, dir_num, bias=bias)
+
+
+class MeanShift(nn.Conv2d):
+    def __init__(
+            self, rgb_range,
+            rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1):
+        super(MeanShift, self).__init__(3, 3, kernel_size=1, pad_mode='valid', has_bias=True)
+        std = ms.Tensor(rgb_std)
+        self.weight.set_data(ms.ops.eye(3, 3, ms.float32).view(3, 3, 1, 1) / std.view(3, 1, 1, 1))
+
+        self.bias.set_data(sign * rgb_range * ms.Tensor(rgb_mean) / std)
+        for p in self.get_parameters():
+            p.requires_grad = False
+
+
+class GhostResBlock(nn.Cell):
+    def __init__(self, conv, n_feats, kernel_size, dir_num=1, bias=True, bn=False, act=nn.ReLU(),
+                 res_scale=1):
+        super(GhostResBlock, self).__init__()
+        m = []
+        for i in range(2):
+            m.append(conv(n_feats, n_feats, kernel_size, dir_num=dir_num, bias=bias))
+            if bn:
+                m.append(nn.BatchNorm2d(n_feats, momentum=0.9))
+            if i == 0:
+                m.append(act)
+
+        self.body = nn.SequentialCell(m)
+        self.mul = ms.ops.Mul()
+        self.res_scale = res_scale
+
+    def construct(self, x):
+        res = self.mul(self.body(x), self.res_scale)
+        res += x
+        return res
+
+
+class ConvResBlock(nn.Cell):
+    def __init__(self, conv, n_feats, kernel_size, bias=True, bn=False, act=nn.ReLU(), res_scale=1):
+        super(ConvResBlock, self).__init__()
+        m = []
+        for i in range(2):
+            m.append(conv(n_feats, n_feats, kernel_size, bias=bias))
+            if bn:
+                m.append(nn.BatchNorm2d(n_feats, momentum=0.9))
+            if i == 0:
+                m.append(act)
+
+        self.body = nn.SequentialCell(m)
+        # res_scale is fractional (e.g. 0.1), so it must be stored as a float tensor
+        self.res_scale = ms.Tensor(res_scale, dtype=ms.float32)
+
+    def construct(self, x):
+        res = self.body(x).mul(self.res_scale)
+        res += x
+        return res
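+
+
+# Shape sketch for GhostModule (illustrative only, not part of the original
+# network code; the sizes are assumptions):
+#
+#   m = GhostModule(inp=64, oup=64, kernel_size=3, dir_num=1, ratio=0.5)
+#   x = ms.ops.ones((1, 64, 48, 48), ms.float16)
+#   y = m(x)  # (1, 64, 48, 48): 32 channels from the dense 3x3 primary conv,
+#             # 32 "ghost" channels from the nearly free one-hot shift conv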
+
+
+class Upsampler(nn.SequentialCell):
+    def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True):
+
+        m = []
+        if (scale & (scale - 1)) == 0:  # Is scale = 2^n?
+            for _ in range(int(math.log(scale, 2))):
+                m.append(conv(n_feats, 4 * n_feats, 3, bias))
+                m.append(PixelShuffle(2))
+                if bn:
+                    m.append(nn.BatchNorm2d(n_feats, momentum=0.9))
+                if act == 'relu':
+                    m.append(nn.ReLU())
+                elif act == 'prelu':
+                    m.append(nn.PReLU(n_feats))
+
+        elif scale == 3:
+            m.append(conv(n_feats, 9 * n_feats, 3, bias))
+            m.append(PixelShuffle(3))
+            if bn:
+                m.append(nn.BatchNorm2d(n_feats, momentum=0.9))
+            if act == 'relu':
+                m.append(nn.ReLU())
+            elif act == 'prelu':
+                m.append(nn.PReLU(n_feats))
+        else:
+            raise NotImplementedError
+
+        super(Upsampler, self).__init__(m)
diff --git a/research/cv/GhostSR/GhostSR/EDSR_mindspore/edsr.py b/research/cv/GhostSR/GhostSR/EDSR_mindspore/edsr.py
new file mode 100644
index 000000000..d82d0a5b3
--- /dev/null
+++ b/research/cv/GhostSR/GhostSR/EDSR_mindspore/edsr.py
@@ -0,0 +1,166 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+# This file refers to https://github.com/sanghyun-son/EDSR-PyTorch
+
+import numpy as np
+import mindspore
+import mindspore.common.initializer as init
+from mindspore import nn, Parameter
+
+from GhostSR.EDSR_mindspore import common_gumbel_softmax_ms as mycommon
+
+convghost = 'ghost'
+default_conv = mycommon.default_conv
+ghost_conv = mycommon.default_ghost
+
+
+def _weights_init(m):
+    if isinstance(m, (nn.Dense, nn.Conv2d)):
+        # zero-initialize weights (and bias, when present)
+        m.weight.set_data(init.initializer(init.Zero(), m.weight.shape, m.weight.dtype))
+        if m.bias is not None:
+            m.bias.set_data(init.initializer(init.Zero(), m.bias.shape, m.bias.dtype))
+
+
+class RgbNormal(nn.Conv2d):
+    """
+    "MeanShift" in the EDSR paper's PyTorch code:
+    https://github.com/sanghyun-son/EDSR-PyTorch/blob/master/src/model/common.py
+
+    its behavior is not reasonable in the cases below:
+    if std != 1 and sign = -1: y = x * rgb_std - rgb_range * rgb_mean
+    if std != 1 and sign = 1:  y = x * rgb_std + rgb_range * rgb_mean
+    they are not inverse operations of each other!
+
+    so use "RgbNormal" instead, it runs as below:
+    if inverse = False: y = (x / rgb_range - mean) / std
+    if inverse = True : x = (y * std + mean) * rgb_range
+    """
+
+    def __init__(self, rgb_range, rgb_mean, rgb_std, inverse=False):
+        super().__init__(3, 3, kernel_size=1, pad_mode='valid', has_bias=True)
+        self.rgb_range = rgb_range
+        self.rgb_mean = rgb_mean
+        self.rgb_std = rgb_std
+        self.inverse = inverse
+        std = np.array(self.rgb_std, dtype=np.float32)
+        mean = np.array(self.rgb_mean, dtype=np.float32)
+        if not inverse:
+            # y: (x / rgb_range - mean) / std <=> x * (1.0 / rgb_range / std) + (-mean) / std
+            scale = (1.0 / self.rgb_range / std)
+            bias = (-mean / std)
+        else:
+            # x: (y * std + mean) * rgb_range <=> y * (std * rgb_range) + mean * rgb_range
+            scale = (self.rgb_range * std)
+            bias = (mean * rgb_range)
+
+        # diagonal 1x1 kernel: each output channel only rescales its own input channel
+        weight = np.eye(3, dtype=np.float32).reshape((3, 3, 1, 1)) * scale.reshape((1, 3, 1, 1))
+
+        self.weight = Parameter(mindspore.Tensor(weight, mindspore.float16), requires_grad=False)
+        self.bias = Parameter(mindspore.Tensor(bias, mindspore.float16), requires_grad=False)
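+
+
+# Worked example for RgbNormal (illustrative values, not from the original file):
+# with rgb_range=255, rgb_mean=(0.4488, 0.4371, 0.4040) and rgb_std=(1.0, 1.0, 1.0),
+# a mean-colored pixel x = (114.4, 111.5, 103.0) maps forward to
+#     y = x / 255 - mean ≈ (0.0, 0.0, 0.0)
+# and the inverse instance maps it back:
+#     x = (y * 1.0 + mean) * 255 ≈ (114.4, 111.5, 103.0)
+# so a (forward, inverse) pair really is a round trip, unlike the MeanShift pair
+# criticized in the docstring above.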
+
+
+class EDSR4GhostSRMs(nn.Cell):
+    """
+    EDSR for GhostSR version
+    """
+
+    def __init__(self, scale=2):
+        super().__init__()
+        n_resblocks = 32
+        n_feats = 256
+        kernel_size = 3
+        act = nn.ReLU()
+
+        self.sub_mean = mycommon.MeanShift(255)
+        self.add_mean = mycommon.MeanShift(255, sign=1)
+
+        # self.sub_mean = RgbNormal(rgb_range, rgb_mean, rgb_std, inverse=False)
+        # self.add_mean = RgbNormal(rgb_range, rgb_mean, rgb_std, inverse=True)
+
+        # define head module
+        m_head = [default_conv(3, n_feats, kernel_size)]
+
+        # define body module
+        if convghost == 'ghost':
+            m_body = [
+                mycommon.GhostResBlock(
+                    ghost_conv, n_feats, kernel_size, dir_num=1, act=act, res_scale=0.1
+                ) for _ in range(n_resblocks)
+            ]
+            m_body.append(default_conv(n_feats, n_feats, kernel_size))
+
+        elif convghost == 'conv':
+            m_body = [
+                mycommon.ConvResBlock(
+                    default_conv, n_feats, kernel_size, act=act, res_scale=0.1
+                ) for _ in range(n_resblocks)
+            ]
+            m_body.append(default_conv(n_feats, n_feats, kernel_size))
+
+        # define tail module
+        m_tail = [
+            mycommon.Upsampler(default_conv, scale, n_feats, act=False),
+            default_conv(n_feats, 3, kernel_size)
+        ]
+
+        self.head = nn.SequentialCell(m_head)
+        self.body = nn.SequentialCell(m_body)
+        self.tail = nn.SequentialCell(m_tail)
+
+    def construct(self, x):
+        """
+        construct
+        """
+        x = self.sub_mean(x)
+        x = self.head(x)
+
+        res = self.body(x)
+        res += x
+
+        x = self.tail(res)
+        x = self.add_mean(x)
+
+        return x
+
+    def load_pre_trained_param_dict(self, new_param_dict, strict=True):
+        """
+        load pre_trained param dict from edsr_x2
+        """
+        own_param = self.parameters_dict()
+        for name, new_param in new_param_dict.items():
+            if len(name) >= 4 and name[:4] == "net.":
+                name = name[4:]
+            if name in own_param:
+                if isinstance(new_param, Parameter):
+                    param = own_param[name]
+                    if tuple(param.data.shape) == tuple(new_param.data.shape):
+                        param.set_data(type(param.data)(new_param.data))
+                    elif name.find('tail') == -1:
+                        raise RuntimeError('While copying the parameter named {}, '
+                                           'whose dimensions in the model are {} and '
+                                           'whose dimensions in the checkpoint are {}.'
+                                           .format(name, own_param[name].shape, new_param.shape))
+            elif strict:
+                if name.find('tail') == -1:
+                    raise KeyError('unexpected key "{}" in parameters_dict()'
+                                   .format(name))
+
+
+if __name__ == '__main__':
+    model = EDSR4GhostSRMs()
+    mindspore.load_param_into_net(model, mindspore.load_checkpoint('./model_best.pt'), strict_load=True)
+    print('success')
diff --git a/research/cv/GhostSR/GhostSR/unsupported_model/PixelShuffle.py b/research/cv/GhostSR/GhostSR/unsupported_model/PixelShuffle.py
new file mode 100644
index 000000000..9e85fd437
--- /dev/null
+++ b/research/cv/GhostSR/GhostSR/unsupported_model/PixelShuffle.py
@@ -0,0 +1,131 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import mindspore.ops as ops
+from mindspore import nn
+from mindspore._checkparam import Validator as validator
+from mindspore.ops.primitive import constexpr
+
+
+@constexpr
+def _check_positive_int(arg_value, arg_name=None, prim_name=None):
+    validator.check_positive_int(arg_value, arg_name=arg_name, prim_name=prim_name)
+
+
+def pixel_shuffle(x, upscale_factor):
+    r"""
+    pixel_shuffle operation.
+
+    Applies a pixel_shuffle operation over an input signal composed of several input planes. This is useful for
+    implementing efficient sub-pixel convolution with a stride of :math:`1/r`. For more details, refer to
+    `Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network
+    <https://arxiv.org/abs/1609.05158>`_ .
+
+    Typically, the `x` is of shape :math:`(*, C \times r^2, H, W)` , and the output is of shape
+    :math:`(*, C, H \times r, W \times r)`, where `r` is an upscale factor and `*` is zero or more batch dimensions.
+
+    Args:
+        x (Tensor): Tensor of shape :math:`(*, C \times r^2, H, W)` . The dimension of `x` is larger than 2, and
+            the length of the third-to-last dimension must be divisible by `upscale_factor` squared.
+        upscale_factor (int): factor to increase spatial resolution by, and is a positive integer.
+
+    Returns:
+        - **output** (Tensor) - Tensor of shape :math:`(*, C, H \times r, W \times r)` .
+
+    Raises:
+        ValueError: If `upscale_factor` is not a positive integer.
+        ValueError: If the length of the third-to-last dimension is not divisible by `upscale_factor` squared.
+        TypeError: If the dimension of `x` is less than 3.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> input_x = np.arange(3 * 2 * 9 * 4 * 4).reshape((3, 2, 9, 4, 4))
+        >>> input_x = mindspore.Tensor(input_x, mindspore.dtype.int32)
+        >>> output = pixel_shuffle(input_x, 3)
+        >>> print(output.shape)
+        (3, 2, 1, 12, 12)
+    """
+    _check_positive_int(upscale_factor, "upscale_factor")
+    idx = x.shape
+    length = len(idx)
+    if length < 3:
+        raise TypeError(
+            f"For pixel_shuffle, the dimension of `x` should be larger than 2, but got {length}.")
+    pre = idx[:-3]
+    c, h, w = idx[-3:]
+    if c % upscale_factor ** 2 != 0:
+        raise ValueError(
+            "For 'pixel_shuffle', the length of the third-to-last dimension is not divisible "
+            "by `upscale_factor` squared.")
+    c = c // upscale_factor ** 2
+    new_shape = pre + (c, upscale_factor, upscale_factor, h, w)
+    reshape = ops.Reshape()
+    x = reshape(x, new_shape)
+    input_perm = [i for i in range(length - 2)]
+    input_perm = input_perm + [length, length - 2, length + 1, length - 1]
+    input_perm = tuple(input_perm)
+    transpose = ops.Transpose()
+    x = transpose(x, input_perm)
+    x = reshape(x, (pre + (c, upscale_factor * h, upscale_factor * w)))
+    return x
+
+
+class PixelShuffle(nn.Cell):
+    r"""
+    PixelShuffle operation.
+
+    Applies a pixelshuffle operation over an input signal composed of several input planes. This is useful for
+    implementing efficient sub-pixel convolution with a stride of :math:`1/r`. For more details, refer to
+    `Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network
+    <https://arxiv.org/abs/1609.05158>`_ .
+
+    Typically, the input is of shape :math:`(*, C \times r^2, H, W)` , and the output is of shape
+    :math:`(*, C, H \times r, W \times r)`, where r is an upscale factor and * is zero or more batch dimensions.
+
+    Args:
+        upscale_factor (int): factor to increase spatial resolution by, and is a positive integer.
+
+    Inputs:
+        - **x** (Tensor) - Tensor of shape :math:`(*, C \times r^2, H, W)` . The dimension of `x` is larger than 2,
+          and the length of the third-to-last dimension must be divisible by `upscale_factor` squared.
+
+    Outputs:
+        - **output** (Tensor) - Tensor of shape :math:`(*, C, H \times r, W \times r)` .
+
+    Raises:
+        ValueError: If `upscale_factor` is not a positive integer.
+        ValueError: If the length of the third-to-last dimension of `x` is not divisible by `upscale_factor` squared.
+        TypeError: If the dimension of `x` is less than 3.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> input_x = np.arange(3 * 2 * 9 * 4 * 4).reshape((3, 2, 9, 4, 4))
+        >>> input_x = mindspore.Tensor(input_x, mindspore.dtype.int32)
+        >>> pixel_shuffle = PixelShuffle(3)
+        >>> output = pixel_shuffle(input_x)
+        >>> print(output.shape)
+        (3, 2, 1, 12, 12)
+    """
+
+    def __init__(self, upscale_factor):
+        super(PixelShuffle, self).__init__()
+        self.upscale_factor = upscale_factor
+
+    def construct(self, x):
+        return pixel_shuffle(x, self.upscale_factor)
diff --git a/research/cv/GhostSR/GhostSR/unsupported_model/__init__.py b/research/cv/GhostSR/GhostSR/unsupported_model/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/research/cv/GhostSR/README_CN.md b/research/cv/GhostSR/README_CN.md
new file mode 100644
index 000000000..a563933d4
--- /dev/null
+++ b/research/cv/GhostSR/README_CN.md
@@ -0,0 +1,229 @@
+# Contents
+
+- [Contents](#contents)
+- [GhostSR Description](#ghostsr-description)
+- [Environment Setup / Inference / Export](#environment-setup--inference--export)
+- [Dataset](#dataset)
+- [Quick Start](#quick-start)
+- [Script Description](#script-description)
+    - [Scripts and Sample Code](#scripts-and-sample-code)
+    - [Script Parameters](#script-parameters)
+- [Model Evaluation](#model-evaluation)
+    - [Evaluation Performance](#evaluation-performance)
+    - [Evaluating the 2x super-resolution GhostSR_EDSR on DIV2K](#evaluating-the-2x-super-resolution-ghostsr_edsr-on-div2k)
+- [Notes on Randomness](#notes-on-randomness)
+- [ModelZoo Homepage](#modelzoo-homepage)
+
+# GhostSR Description
+
+GhostSR is a lightweight single-image super-resolution network proposed in 2022. It introduces a shift operation to
+generate ghost features, which greatly reduces the number of parameters, the FLOPs and the inference latency with
+almost no loss of accuracy.
+
+Paper: [GhostSR: Learning Ghost Features for Efficient Image Super-Resolution](https://arxiv.org/abs/2101.08525)
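+
+The core idea can be sketched in a few lines of NumPy. The snippet below is an illustrative
+reconstruction, not code from this repository (the function name and shapes are assumptions):
+half of the output channels come from an ordinary convolution, and the other "ghost" half are
+produced by shifting those intrinsic features by one pixel in a direction that is learned via
+Gumbel-softmax during training.
+
+```python
+import numpy as np
+
+def ghost_features_sketch(x):
+    """x: (C, H, W) intrinsic features -> (2*C, H, W) with shifted ghost features."""
+    # a fixed one-pixel shift (down-right) stands in for the learned one-hot
+    # 3x3 shift kernel that GhostSR selects per layer
+    ghost = np.zeros_like(x)
+    ghost[:, 1:, 1:] = x[:, :-1, :-1]
+    return np.concatenate([x, ghost], axis=0)
+
+x = np.random.rand(32, 48, 48).astype(np.float32)
+print(ghost_features_sketch(x).shape)  # (64, 48, 48)
+```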
+
+# Environment Setup / Inference / Export
+
+This code is adapted from [EDSR (MindSpore)](https://gitee.com/mindspore/models/tree/master/official/cv/EDSR).
+Refer to EDSR for environment setup, inference and export.
+
+# Dataset
+
+Dataset used: [DIV2K](https://data.vision.ee.ethz.ch/cvl/DIV2K/)
+
+- Dataset size: 7.11 GB, 1000 groups (HR, LRx2, LRx3, LRx4) of valid color images
+    - Training set: 6.01 GB, 800 groups of images
+    - Validation set: 783.68 MB, 100 groups of images
+    - Test set: 349.53 MB, 100 groups of images (without HR images)
+- Data format: PNG image files
+    - Note: the data is processed in src/dataset.py.
+- Data directory tree: after downloading the data from the official site and unzipping it, the directory structure required for training and validation is as follows:
+
+```shell
+├─DIV2K_train_HR
+│  ├─0001.png
+│  ├─...
+│  └─0800.png
+├─DIV2K_train_LR_bicubic
+│  ├─X2
+│  │  ├─0001x2.png
+│  │  ├─...
+│  │  └─0800x2.png
+│  ├─X3
+│  │  ├─0001x3.png
+│  │  ├─...
+│  │  └─0800x3.png
+│  └─X4
+│     ├─0001x4.png
+│     ├─...
+│     └─0800x4.png
+├─DIV2K_valid_HR
+│  ├─0801.png
+│  ├─...
+│  └─0900.png
+└─DIV2K_valid_LR_bicubic
+   ├─X2
+   │  ├─0801x2.png
+   │  ├─...
+   │  └─0900x2.png
+   ├─X3
+   │  ├─0801x3.png
+   │  ├─...
+   │  └─0900x3.png
+   └─X4
+      ├─0801x4.png
+      ├─...
+      └─0900x4.png
+```
+
+# Quick Start
+
+After installing MindSpore from the official website, you can follow the steps below for training and evaluation. For distributed training, an hccl configuration file in JSON format needs to be created in advance; follow the instructions at <https://gitee.com/mindspore/models/tree/master/utils/hccl_tools>.
+
+- Run single-device evaluation on DIV2K in a GPU environment
+
+    ```shell
+    # run the evaluation example (EDSR_mindspore(x2) in the paper)
+    python eval.py --config_path DIV2K_config.yaml --scale 2 --data_path [DIV2K path] --output_path [path to save sr] --pre_trained ./ckpt/EDSR_GhostSR_x2.ckpt > train.log 2>&1 &
+    ```
+
+- Run single-device evaluation on the benchmarks in a GPU environment
+
+    ```shell
+    # run the evaluation example (EDSR_mindspore(x2) in the paper)
+    python eval.py --config_path benchmark_config.yaml --scale 2 --data_path [benchmark path] --output_path [path to save sr] --pre_trained ./ckpt/EDSR_GhostSR_x2.ckpt > train.log 2>&1 &
+    ```
+
+# Script Description
+
+## Scripts and Sample Code
+
+```text
+├── model_zoo
+    ├── README.md                     // descriptions of all models
+    ├── EDSR
+        ├── README_CN.md              // EDSR description
+        ├── model_utils               // utility scripts for running on the cloud
+        ├── DIV2K_config.yaml         // EDSR parameters
+        ├── ckpt
+        │   └── EDSR_GhostSR_x2.ckpt  // EDSR_GhostSR 2x super-resolution model weights
+        ├── GhostSR                   // GhostSR network architecture
+        │   ├── EDSR_mindspore        // EDSR_GhostSR network architecture
+        │   └── unsupported_model     // operators not natively supported by mindspore
+        ├── scripts
+        │   ├── run_train.sh          // shell script for distributed training on Ascend
+        │   ├── run_eval.sh           // shell script for evaluation on Ascend
+        │   ├── run_infer_310.sh      // shell script for Ascend-310 inference
+        │   └── run_eval_onnx.sh      // shell script for ONNX evaluation
+        ├── src
+        │   ├── dataset.py            // dataset creation
+        │   ├── edsr.py               // edsr network architecture
+        │   ├── config.py             // parameter configuration
+        │   ├── metric.py             // evaluation metrics
+        │   ├── utils.py              // code shared by train.py/eval.py
+        ├── train.py                  // training script
+        ├── eval.py                   // evaluation script
+        ├── eval_onnx.py              // ONNX evaluation script
+        ├── export.py                 // export checkpoint files to onnx/air/mindir
+        ├── preprocess.py             // data pre-processing script for Ascend-310 inference
+        ├── ascend310_infer
+        │   ├── src                   // Ascend-310 inference source code
+        │   ├── inc                   // Ascend-310 inference headers
+        │   ├── build.sh              // shell script to build the Ascend-310 inference program
+        │   ├── CMakeLists.txt        // CMakeLists for the Ascend-310 inference program
+        ├── postprocess.py            // data post-processing script for Ascend-310 inference
+```
+
+## Script Parameters
+
+Training and evaluation parameters can both be configured in DIV2K_config.yaml. Parameters with the same name in benchmark_config.yaml have the same definition.
+
+- The configuration help can be printed with the following command
+
+    ```shell
+    python train.py --config_path DIV2K_config.yaml --help
+    ```
+
+- The configuration descriptions inside DIV2K_config.yaml can also be read directly; they are as follows
+
+    ```yaml
+    enable_modelarts: "set True when running on ModelArts, default: False"
+
+    data_url: "ModelArts data path"
+    train_url: "ModelArts code path"
+    checkpoint_url: "ModelArts save path"
+
+    data_path: "data path on the running machine, downloaded by the script from 'data_url', default: /cache/data"
+    output_path: "output path on the running machine, uploaded by the script to 'checkpoint_url', default: /cache/train"
+    device_target: "choice from ['Ascend'], default: Ascend"
+
+    amp_level: "choice from ['O0', 'O2', 'O3', 'auto'], default: O3"
+    loss_scale: "loss scaling used by all mixed-precision levels except O0, default: 1000.0"
+    keep_checkpoint_max: "max number of checkpoints to keep, default: 60"
+    save_epoch_frq: "save a checkpoint every this many epochs, default: 100"
+    ckpt_save_dir: "relative local save path, the root path is output_path, default: ./ckpt/"
+    epoch_size: "number of epochs to train, default: 6000"
+
+    eval_epoch_frq: "run validation every this many epochs during training, default: 20"
+    self_ensemble: "do self_ensemble during validation, only used in eval.py, default: True"
+    save_sr: "save sr and hr images during validation, only used in eval.py, default: True"
+
+    opt_type: "optimizer type, choice from ['Adam'], default: Adam"
+    weight_decay: "optimizer weight decay, default: 0.0"
+
+    learning_rate: "learning rate, default: 0.0001"
+    milestones: "list of epochs at which the learning rate decays, default: [4000]"
+    gamma: "learning rate decay rate, default: 0.5"
+
+    dataset_name: "dataset name, default: DIV2K"
+    lr_type: "degradation type of the lr images, choice from ['bicubic', 'unknown'], default: bicubic"
+    batch_size: "to reproduce the results, 2 is recommended for 8 devices and 16 for a single device, default: 2"
+    patch_size: "HR patch size cropped during training, the LR crop size is adjusted by 'scale', default: 192"
+    scale: "super-resolution scale of the model, choice from [2,3,4], default: 4"
+    dataset_sink_mode: "use dataset sink mode for training, default: True"
+    need_unzip_in_modelarts: "unzip the data after downloading it from s3, default: False"
+    need_unzip_files: "list of data files to unzip, only effective when need_unzip_in_modelarts=True"
+
+    pre_trained: "load a pre-trained model, x2/x3/x4 models can be loaded for each other, choice from [[S3 absolute path], [relative path below output_path], [local absolute path], ''], default: ''"
+    rgb_range: "pixel value range of the images, default: 255"
+    rgb_mean: "RGB mean of the images, default: [0.4488, 0.4371, 0.4040]"
+    rgb_std: "RGB standard deviation of the images, default: [1.0, 1.0, 1.0]"
+    n_colors: "RGB images have 3 channels, default: 3"
+    n_feats: "number of output features of each convolution layer, default: 256"
+    kernel_size: "convolution kernel size, default: 3"
+    n_resblocks: "number of resblocks, default: 32"
+    res_scale: "coefficient of the residual branch, default: 0.1"
+    ```
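+
+Every key in the block above can also be overridden on the command line without editing the yaml file. A hypothetical example (the flag values are placeholders, not recommended settings):
+
+```shell
+# evaluate without self-ensemble and with a custom output directory
+python eval.py --config_path DIV2K_config.yaml --scale 2 --data_path [DIV2K path] --output_path ./output_noens --pre_trained ./ckpt/EDSR_GhostSR_x2.ckpt --self_ensemble False
+```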
''" + rgb_range: "图片像素的范围,defalue: 255" + rgb_mean: "图片RGB均值,defalue: [0.4488, 0.4371, 0.4040]" + rgb_std: "图片RGB方差,defalue: [1.0, 1.0, 1.0]" + n_colors: "RGB图片3通道,defalue: 3" + n_feats: "每个卷积层的输出特征数量,defalue: 256" + kernel_size: "卷积核大小,defalue: 3" + n_resblocks: "resblocks数量,defalue: 32" + res_scale: "res的分支的系数,defalue: 0.1" + ``` + +# 模型评估 + +## 性能 + +### DIV2K上的评估2倍/3倍/4倍超分辨率重建的EDSR + +| 参数 | Ascend | +|--------------|---| +| 模型版本 | EDSR-GhostSR(x2) | +| MindSpore版本 | 1.9.0 | +| 数据集 | DIV2K, 100张图像 | +| self_ensemble | True | +| batch_size | 1 | +| 输出 | 超分辨率重建RGB图 | +| Set5 psnr | 38.101 db | +| Set14 psnr | 33.856 db | +| B100 psnr | 32.288 db | +| Urban100 psnr | 32.793 db | +| DIV2K psnr | 34.8748 db | +| 推理模型 | 83.3 MB (.ckpt文件) | + +# 随机情况说明 + +在train.py,eval.py中,我们设置了mindspore.common.set_seed(2021)种子。 + +# ModelZoo主页 + +请浏览官网[主页](https://gitee.com/mindspore/models)。 diff --git a/research/cv/GhostSR/ascend310_infer/CMakeLists.txt b/research/cv/GhostSR/ascend310_infer/CMakeLists.txt new file mode 100644 index 000000000..ee3c85447 --- /dev/null +++ b/research/cv/GhostSR/ascend310_infer/CMakeLists.txt @@ -0,0 +1,14 @@ +cmake_minimum_required(VERSION 3.14.1) +project(Ascend310Infer) +add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0) +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -g -std=c++17 -Werror -Wall -fPIE -Wl,--allow-shlib-undefined") +set(PROJECT_SRC_ROOT ${CMAKE_CURRENT_LIST_DIR}/) +option(MINDSPORE_PATH "mindspore install path" "") +include_directories(${MINDSPORE_PATH}) +include_directories(${MINDSPORE_PATH}/include) +include_directories(${PROJECT_SRC_ROOT}) +find_library(MS_LIB libmindspore.so ${MINDSPORE_PATH}/lib) +file(GLOB_RECURSE MD_LIB ${MINDSPORE_PATH}/_c_dataengine*) + +add_executable(main src/main.cc src/utils.cc) +target_link_libraries(main ${MS_LIB} ${MD_LIB} gflags) diff --git a/research/cv/GhostSR/ascend310_infer/build.sh b/research/cv/GhostSR/ascend310_infer/build.sh new file mode 100644 index 000000000..a19f43626 --- /dev/null +++ b/research/cv/GhostSR/ascend310_infer/build.sh @@ -0,0 +1,26 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +if [ ! -d out ]; then + mkdir out +fi +cd out || exit +MINDSPORE_PATH="`pip show mindspore-ascend | grep Location | awk '{print $2"/mindspore"}' | xargs realpath`" +if [[ ! $MINDSPORE_PATH ]];then + MINDSPORE_PATH="`pip show mindspore | grep Location | awk '{print $2"/mindspore"}' | xargs realpath`" +fi +cmake .. -DMINDSPORE_PATH=$MINDSPORE_PATH +make diff --git a/research/cv/GhostSR/ascend310_infer/inc/utils.h b/research/cv/GhostSR/ascend310_infer/inc/utils.h new file mode 100644 index 000000000..b536d9048 --- /dev/null +++ b/research/cv/GhostSR/ascend310_infer/inc/utils.h @@ -0,0 +1,33 @@ +/* + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_INFERENCE_UTILS_H_
+#define MINDSPORE_INFERENCE_UTILS_H_
+
+#include <sys/stat.h>
+#include <dirent.h>
+#include <vector>
+#include <string_view>
+#include <string>
+#include "include/api/types.h"
+
+std::vector<std::string> GetAllFiles(std::string_view dirName);
+DIR *OpenDir(std::string_view dirName);
+std::string RealPath(std::string_view path);
+mindspore::MSTensor ReadFileToTensor(const std::string &file);
+int WriteResult(const std::string& imageFile, const std::vector<mindspore::MSTensor> &outputs,
+                const std::string &homePath);
+#endif
diff --git a/research/cv/GhostSR/ascend310_infer/src/main.cc b/research/cv/GhostSR/ascend310_infer/src/main.cc
new file mode 100644
index 000000000..cb27a2ac2
--- /dev/null
+++ b/research/cv/GhostSR/ascend310_infer/src/main.cc
@@ -0,0 +1,141 @@
+/*
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <sys/time.h>
+#include <gflags/gflags.h>
+#include <dirent.h>
+#include <iostream>
+#include <string>
+#include <algorithm>
+#include <map>
+#include <vector>
+#include <fstream>
+#include <sstream>
+
+#include "include/api/model.h"
+#include "include/api/context.h"
+#include "include/api/types.h"
+#include "include/api/serialization.h"
+#include "include/dataset/vision_ascend.h"
+#include "include/dataset/execute.h"
+#include "include/dataset/vision.h"
+#include "inc/utils.h"
+
+using mindspore::Context;
+using mindspore::Serialization;
+using mindspore::Model;
+using mindspore::Status;
+using mindspore::ModelType;
+using mindspore::GraphCell;
+using mindspore::kSuccess;
+using mindspore::MSTensor;
+using mindspore::dataset::Execute;
+using mindspore::dataset::MapTargetDevice;
+using mindspore::dataset::TensorTransform;
+using mindspore::dataset::vision::Resize;
+using mindspore::dataset::vision::HWC2CHW;
+using mindspore::dataset::vision::Normalize;
+using mindspore::dataset::vision::Decode;
+using mindspore::dataset::vision::CenterCrop;
+
+DEFINE_string(mindir_path, "", "mindir path");
+DEFINE_string(dataset_path, ".", "dataset path");
+DEFINE_string(save_dir, "", "save dir");
+DEFINE_int32(device_id, 0, "device id");
+
+int main(int argc, char **argv) {
+  gflags::ParseCommandLineFlags(&argc, &argv, true);
+  if (RealPath(FLAGS_mindir_path).empty()) {
+    std::cout << "Invalid mindir" << std::endl;
+    return 1;
+  }
+
+  DIR *dir = OpenDir(FLAGS_save_dir);
+  if (dir == nullptr) {
+    return 1;
+  }
+
+  auto context = std::make_shared<Context>();
+  auto ascend310 = std::make_shared<mindspore::Ascend310DeviceInfo>();
+  ascend310->SetDeviceID(FLAGS_device_id);
+  ascend310->SetBufferOptimizeMode("off_optimize");
+  context->MutableDeviceInfo().push_back(ascend310);
+  mindspore::Graph graph;
+  Serialization::Load(FLAGS_mindir_path, ModelType::kMindIR, &graph);
+
+  Model model;
+  Status ret = model.Build(GraphCell(graph), context);
+  if (ret != kSuccess) {
+    std::cout << "ERROR: Build failed." << std::endl;
+    return 1;
+  }
+
+  auto decode = Decode();
+  auto normalize = Normalize({0.0, 0.0, 0.0}, {1.0, 1.0, 1.0});
+  auto hwc2chw = HWC2CHW();
+  Execute transform({decode, normalize, hwc2chw});
+
+  auto all_files = GetAllFiles(FLAGS_dataset_path);
+  std::map<double, double> costTime_map;
+  size_t size = all_files.size();
+
+  for (size_t i = 0; i < size; ++i) {
+    struct timeval start = {0};
+    struct timeval end = {0};
+    double startTimeMs = 0.0;
+    double endTimeMs = 0.0;
+    std::vector<MSTensor> inputs;
+    std::vector<MSTensor> outputs;
+    std::cout << "Start predict input files:" << all_files[i] << std::endl;
+    auto img = MSTensor();
+    auto image = ReadFileToTensor(all_files[i]);
+    transform(image, &img);
+    std::vector<MSTensor> model_inputs = model.GetInputs();
+    inputs.emplace_back(model_inputs[0].Name(), model_inputs[0].DataType(), model_inputs[0].Shape(),
+                        img.Data().get(), img.DataSize());
+    gettimeofday(&start, nullptr);
+    ret = model.Predict(inputs, &outputs);
+    gettimeofday(&end, nullptr);
+    if (ret != kSuccess) {
+      std::cout << "Predict " << all_files[i] << " failed." << std::endl;
+      return 1;
+    }
+    startTimeMs = (1.0 * start.tv_sec * 1000000 + start.tv_usec) / 1000;
+    endTimeMs = (1.0 * end.tv_sec * 1000000 + end.tv_usec) / 1000;
+    costTime_map.insert(std::pair<double, double>(startTimeMs, endTimeMs));
+    WriteResult(all_files[i], outputs, FLAGS_save_dir);
+  }
+  double average = 0.0;
+  int inferCount = 0;
+
+  for (auto iter = costTime_map.begin(); iter != costTime_map.end(); iter++) {
+    double diff = 0.0;
+    diff = iter->second - iter->first;
+    average += diff;
+    inferCount++;
+  }
+
+  average = average / inferCount;
+  std::stringstream timeCost;
+  timeCost << "NN inference cost average time: " << average << " ms of infer_count " << inferCount << std::endl;
+  std::cout << "NN inference cost average time: " << average << "ms of infer_count " << inferCount << std::endl;
+  std::string fileName = FLAGS_save_dir + std::string("/test_perform_static.txt");
+  std::ofstream fileStream(fileName.c_str(), std::ios::trunc);
+  fileStream << timeCost.str();
+  fileStream.close();
+  costTime_map.clear();
+  return 0;
+}
diff --git a/research/cv/GhostSR/ascend310_infer/src/utils.cc b/research/cv/GhostSR/ascend310_infer/src/utils.cc
new file mode 100644
index 000000000..1df8bdf7a
--- /dev/null
+++ b/research/cv/GhostSR/ascend310_infer/src/utils.cc
@@ -0,0 +1,144 @@
+/*
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "inc/utils.h"
+
+#include <fstream>
+#include <algorithm>
+#include <iostream>
+
+using mindspore::MSTensor;
+using mindspore::DataType;
+
+std::vector<std::string> GetAllFiles(std::string_view dirName) {
+  struct dirent *filename;
+  DIR *dir = OpenDir(dirName);
+  if (dir == nullptr) {
+    return {};
+  }
+  std::vector<std::string> dirs;
+  std::vector<std::string> files;
+  while ((filename = readdir(dir)) != nullptr) {
+    std::string dName = std::string(filename->d_name);
+    if (dName == "." || dName == "..") {
+      continue;
+    } else if (filename->d_type == DT_DIR) {
+      dirs.emplace_back(std::string(dirName) + "/" + filename->d_name);
+    } else if (filename->d_type == DT_REG) {
+      files.emplace_back(std::string(dirName) + "/" + filename->d_name);
+    } else {
+      continue;
+    }
+  }
+
+  for (auto d : dirs) {
+    dir = OpenDir(d);
+    while ((filename = readdir(dir)) != nullptr) {
+      std::string dName = std::string(filename->d_name);
+      if (dName == "." || dName == ".." || filename->d_type != DT_REG) {
+        continue;
+      }
+      files.emplace_back(std::string(d) + "/" + filename->d_name);
+    }
+  }
+  std::sort(files.begin(), files.end());
+  for (auto &f : files) {
+    std::cout << "image file: " << f << std::endl;
+  }
+  return files;
+}
+
+int WriteResult(const std::string& imageFile, const std::vector<MSTensor> &outputs, const std::string &homePath) {
+  for (size_t i = 0; i < outputs.size(); ++i) {
+    std::shared_ptr<const void> netOutput;
+    netOutput = outputs[i].Data();
+    size_t outputSize = outputs[i].DataSize();
+    int pos = imageFile.rfind('/');
+    std::string fileName(imageFile, pos + 1);
+    fileName.replace(fileName.find('.'), fileName.size() - fileName.find('.'), '_' + std::to_string(i) + ".bin");
+    std::string outFileName = homePath + "/" + fileName;
+    FILE * outputFile = fopen(outFileName.c_str(), "wb");
+    fwrite(netOutput.get(), outputSize, sizeof(char), outputFile);
+    fclose(outputFile);
+    outputFile = nullptr;
+  }
+  return 0;
+}
+
+mindspore::MSTensor ReadFileToTensor(const std::string &file) {
+  if (file.empty()) {
+    std::cout << "Pointer file is nullptr" << std::endl;
+    return mindspore::MSTensor();
+  }
+
+  std::ifstream ifs(file);
+  if (!ifs.good()) {
+    std::cout << "File: " << file << " does not exist" << std::endl;
+    return mindspore::MSTensor();
+  }
+
+  if (!ifs.is_open()) {
+    std::cout << "File: " << file << " open failed" << std::endl;
+    return mindspore::MSTensor();
+  }
+
+  ifs.seekg(0, std::ios::end);
+  size_t size = ifs.tellg();
+  mindspore::MSTensor buffer(
+      file, mindspore::DataType::kNumberTypeUInt8, {static_cast<int64_t>(size)}, nullptr, size);
+
+  ifs.seekg(0, std::ios::beg);
+  ifs.read(reinterpret_cast<char *>(buffer.MutableData()), size);
+  ifs.close();
+
+  return buffer;
+}
+
+DIR *OpenDir(std::string_view dirName) {
+  if (dirName.empty()) {
+    std::cout << " dirName is null ! " << std::endl;
+    return nullptr;
+  }
+  std::string realPath = RealPath(dirName);
+  struct stat s;
+  lstat(realPath.c_str(), &s);
+  if (!S_ISDIR(s.st_mode)) {
+    std::cout << "dirName is not a valid directory !" << std::endl;
+    return nullptr;
+  }
+  DIR *dir = opendir(realPath.c_str());
+  if (dir == nullptr) {
+    std::cout << "Can not open dir " << dirName << std::endl;
+    return nullptr;
+  }
+  std::cout << "Successfully opened the dir " << dirName << std::endl;
+  return dir;
+}
+
+std::string RealPath(std::string_view path) {
+  char realPathMem[PATH_MAX] = {0};
+  char *realPathRet = nullptr;
+  realPathRet = realpath(path.data(), realPathMem);
+
+  if (realPathRet == nullptr) {
+    std::cout << "File: " << path << " does not exist.";
+    return "";
+  }
+
+  std::string realPath(realPathMem);
+  std::cout << path << " realpath is: " << realPath << std::endl;
+  return realPath;
+}
diff --git a/research/cv/GhostSR/benchmark_config.yaml b/research/cv/GhostSR/benchmark_config.yaml
new file mode 100644
index 000000000..f20ac1358
--- /dev/null
+++ b/research/cv/GhostSR/benchmark_config.yaml
@@ -0,0 +1,71 @@
+# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
+enable_modelarts: False
+# Url for modelarts
+data_url: ""
+train_url: ""
+checkpoint_url: ""
+# Path for local
+data_path: "/cache/data"
+output_path: "./output"
+device_target: "GPU"
+
+network: "EDSR4GhostSRMs" # EDSR_mindspore / EDSR4GhostSRMs
+
+# ==============================================================================
+ckpt_save_dir: "./ckpt/"
+
+self_ensemble: True
+save_sr: True
+
+# dataset options
+dataset_name: "benchmark"
+scale: 2
+need_unzip_in_modelarts: False
+
+# net options
+pre_trained: ""
+rgb_range: 255
+rgb_mean: [ 0.4488, 0.4371, 0.4040 ]
+rgb_std: [ 1.0, 1.0, 1.0 ]
+n_colors: 3
+n_feats: 256
+kernel_size: 3
+n_resblocks: 32
+res_scale: 0.1
+
+---
+# helper
+
+enable_modelarts: "set True if run in modelarts, default: False"
+# Url for modelarts
+data_url: "modelarts data path"
+train_url: "modelarts code path"
+checkpoint_url: "modelarts checkpoint save path"
+# Path for local
+data_path: "local data path, data will be downloaded from 'data_url', default: /cache/data"
+output_path: "local output path, checkpoint will be uploaded to 'checkpoint_url', default: /cache/train"
+device_target: "choice from ['Ascend'], default: Ascend"
+
+# ==============================================================================
+# train options
+ckpt_save_dir: "the relative path to save checkpoint, root path is 'output_path', default: ./ckpt/"
+
+self_ensemble: "set True if you want to do self-ensemble while evaluating, default: True"
+save_sr: "set True if you want to save sr and hr images while evaluating, default: True"
+
+# dataset options
+dataset_name: "dataset name, default: DIV2K"
+scale: "scale for super resolution reconstruction, choice from [2,3,4], default: 4"
+need_unzip_in_modelarts: "set True if you want to unzip data after downloading it from s3, default: False"
+need_unzip_files: "list of zip files to unzip, only works while 'need_unzip_in_modelarts'=True"
+
+# net options
+pre_trained: "load pre-trained model, x2/x3/x4 models can be loaded for each other, choice from [[S3_ABS_PATH], [RELATIVE_PATH below 'output_path'], [LOCAL_ABS_PATH], ''], default: ''"
+rgb_range: "pixel value range, default: 255"
+rgb_mean: "rgb mean, default: [0.4488, 0.4371, 0.4040]"
+rgb_std: "rgb standard deviation, default: [1.0, 1.0, 1.0]"
+n_colors: "the number of RGB image channels, default: 3"
+n_feats: "the number of output features for each Conv2d, default: 256"
+kernel_size: "kernel size for Conv2d, default: 3"
+n_resblocks: "the number of resblocks, default: 32"
+res_scale: "zoom scale of res branch, default: 0.1"
diff --git a/research/cv/GhostSR/eval.py b/research/cv/GhostSR/eval.py
new file mode 100644
index 000000000..c7022893f
--- /dev/null
+++ b/research/cv/GhostSR/eval.py
@@ -0,0 +1,207 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+#################evaluate EDSR_mindspore example on DIV2K########################
+"""
+import math
+import os
+
+import numpy as np
+from mindspore import Tensor, ops
+from mindspore import dataset as ds
+from mindspore.common import set_seed
+
+from model_utils.config import config
+from model_utils.moxing_adapter import moxing_wrapper, get_rank_id
+from src.dataset import get_rank_info, LrHrImages, hwc2chw, uint8_to_float32
+from src.metric import SelfEnsembleWrapperNumpy, PSNR, SaveSrHr
+from src.utils import init_env, init_dataset, init_net, modelarts_pre_process, do_eval
+
+set_seed(2021)
+
+
+class HrCutter:
+    """
+    Cut hr to the shape matching lr * scale, for benchmark evaluation
+    """
+
+    def __init__(self, lr_scale):
+        self.lr_scale = lr_scale
+
+    def __call__(self, lr, hr):
+        lrh, lrw, _ = lr.shape
+        hrh, hrw, _ = hr.shape
+        h, w = lrh * self.lr_scale, lrw * self.lr_scale
+        if hrh != h or hrw != w:
+            hr = hr[0:h, 0:w, :]
+        return lr, hr
+
+
+class RepeatDataSet:
+    """
+    Repeat a dataset so that small benchmarks such as Set5 can be evaluated with a distributed sampler
+    """
+
+    def __init__(self, dataset, repeat):
+        self.dataset = dataset
+        self.repeat = repeat
+
+    def __getitem__(self, idx):
+        return self.dataset[idx % len(self.dataset)]
+
+    def __len__(self):
+        return len(self.dataset) * self.repeat
+
+
+def create_dataset_benchmark(dataset_path, scale):
+    """
+    create a train or eval benchmark dataset
+    Args:
+        dataset_path(string): the path of dataset.
+ scale(int): lr scale, read data ordered by it, choices=(2,3,4) + Returns: + multi_datasets + """ + lr_scale = scale + + multi_datasets = {} + for dataset_name in ["Set5", "Set14", "BSDS100", "Urban100"]: + # get HR_PATH/*.png + dir_hr = os.path.join(dataset_path, dataset_name, "HR") + hr_pattern = os.path.join(dir_hr, "*.png") + + # get LR + column_names = [f"lrx{lr_scale}", "hr"] + dir_lr = os.path.join(dataset_path, dataset_name, "LR_bicubic", f"X{lr_scale}") + lr_pattern = os.path.join(dir_lr, f"*x{lr_scale}.png") + lrs_pattern = [lr_pattern] + + device_num, rank_id = get_rank_info() + + # make dataset + dataset = LrHrImages(lr_pattern=lrs_pattern, hr_pattern=hr_pattern) + if len(dataset) < device_num: + dataset = RepeatDataSet(dataset, repeat=device_num // len(dataset) + 1) + + # make mindspore dataset + if device_num == 1 or device_num is None: + generator_dataset = ds.GeneratorDataset(dataset, column_names=column_names, + num_parallel_workers=3, + shuffle=False) + else: + sampler = ds.DistributedSampler(num_shards=device_num, shard_id=rank_id, shuffle=False, + offset=0) + generator_dataset = ds.GeneratorDataset(dataset, column_names=column_names, + num_parallel_workers=3, + sampler=sampler) + + # define map operations + transform_img = [ + HrCutter(lr_scale), + hwc2chw, + uint8_to_float32, + ] + + # pre-process hr lr + generator_dataset = generator_dataset.map(input_columns=column_names, + output_columns=column_names, + operations=transform_img) + + # apply batch operations + generator_dataset = generator_dataset.batch(1, drop_remainder=False) + + multi_datasets[dataset_name] = generator_dataset + return multi_datasets + + +class BenchmarkPSNR(PSNR): + """ + eval psnr for Benchmark + """ + + def __init__(self, rgb_range, shave, channels_scale): + super(BenchmarkPSNR, self).__init__(rgb_range=rgb_range, shave=shave) + self.channels_scale = channels_scale + self.c_scale = Tensor( + np.array(self.channels_scale, dtype=np.float32).reshape((1, -1, 1, 1))) + self.sum = ops.ReduceSum(keep_dims=True) + + def update(self, *inputs): + if len(inputs) != 2: + raise ValueError('PSNR need 2 inputs (sr, hr), but got {}'.format(len(inputs))) + sr, hr = inputs + sr = self.quantize(sr) + diff = (sr - hr) / self.rgb_range + diff = diff * self.c_scale + valid = self.sum(diff, 1) + if self.shave is not None and self.shave != 0: + valid = valid[..., self.shave:(-self.shave), self.shave:(-self.shave)] + mse_list = (valid ** 2).mean(axis=(1, 2, 3)) + mse_list = self._convert_data(mse_list).tolist() + psnr_list = [float(1e32) if mse == 0 else (- 10.0 * math.log10(mse)) for mse in mse_list] + self._accumulate(psnr_list) + + +@moxing_wrapper(pre_process=modelarts_pre_process) +def run_eval(): + """ + run eval + """ + print(config, flush=True) + cfg = config + + init_env(cfg) + net = init_net(cfg) + eval_net = SelfEnsembleWrapperNumpy(net) if cfg.self_ensemble else net + + if cfg.dataset_name == "DIV2K": + cfg.batch_size = 1 + cfg.patch_size = -1 + ds_val = init_dataset(cfg, "valid") + metrics = { + "psnr": PSNR(rgb_range=cfg.rgb_range, shave=6 + cfg.scale), + } + if config.save_sr: + save_img_dir = os.path.join(cfg.output_path, "HrSr") + os.makedirs(save_img_dir, exist_ok=True) + metrics["num_sr"] = SaveSrHr(save_img_dir) + do_eval(eval_net, ds_val, metrics) + print("eval success", flush=True) + + elif cfg.dataset_name == "benchmark": + multi_datasets = create_dataset_benchmark(dataset_path=cfg.data_path, scale=cfg.scale) + result = {} + for dname, ds_val in multi_datasets.items(): + dpnsr = f"{dname}_psnr" 
+ gray_coeffs = [65.738, 129.057, 25.064] + channels_scale = [x / 256.0 for x in gray_coeffs] + metrics = { + dpnsr: BenchmarkPSNR(rgb_range=cfg.rgb_range, shave=cfg.scale, + channels_scale=channels_scale) + } + if config.save_sr: + save_img_dir = os.path.join(cfg.output_path, "HrSr", dname) + os.makedirs(save_img_dir, exist_ok=True) + metrics["num_sr"] = SaveSrHr(save_img_dir) + result[dpnsr] = do_eval(eval_net, ds_val, metrics)[dpnsr] + if get_rank_id() == 0: + print(result, flush=True) + print("eval success", flush=True) + else: + raise RuntimeError("Unsupported dataset.") + + +if __name__ == '__main__': + run_eval() diff --git a/research/cv/GhostSR/eval_onnx.py b/research/cv/GhostSR/eval_onnx.py new file mode 100644 index 000000000..44ba320bc --- /dev/null +++ b/research/cv/GhostSR/eval_onnx.py @@ -0,0 +1,111 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +#################evaluate EDSR_mindspore example on DIV2K######################## +""" +import os +import time + +import mindspore +from mindspore import Tensor +from mindspore.common import set_seed + +import onnxruntime as ort +from model_utils.config import config +from src.metric import PSNR, SaveSrHr +from src.utils import init_env, init_dataset + +set_seed(2021) + + +def create_session(checkpoint_path, target_device): + """Create ONNX runtime session""" + if target_device == 'GPU': + providers = ['CUDAExecutionProvider'] + elif target_device in ('CPU', 'Ascend'): + providers = ['CPUExecutionProvider'] + else: + raise ValueError( + f"Unsupported target device '{target_device}'. 
Expected one of: 'CPU', 'GPU', 'Ascend'")
+    session = ort.InferenceSession(checkpoint_path, providers=providers)
+    input_names = [x.name for x in session.get_inputs()]
+    return session, input_names
+
+
+def unpadding(img, target_shape):
+    h, w = target_shape[2], target_shape[3]
+    _, _, img_h, img_w = img.shape
+    if img_h > h:
+        img = img[:, :, :h, :]
+    if img_w > w:
+        img = img[:, :, :, :w]
+    return img
+
+
+def do_eval(session, input_names, ds_val, metrics, cur_epoch=None):
+    """
+    do eval for psnr and save hr, sr
+    """
+    total_step = ds_val.get_dataset_size()
+    setw = len(str(total_step))
+    begin = time.time()
+    step_begin = time.time()
+    rank_id = 0
+    for i, (lr, hr) in enumerate(ds_val):
+        input_data = [lr.asnumpy()]
+        sr = session.run(None, dict(zip(input_names, input_data)))
+        sr = Tensor(unpadding(sr[0], hr.shape), mindspore.float32)
+        _ = [m.update(sr, hr) for m in metrics.values()]
+        result = {k: m.eval(sync=False) for k, m in metrics.items()}
+        result["time"] = time.time() - step_begin
+        step_begin = time.time()
+        print(f"[{i + 1:>{setw}}/{total_step:>{setw}}] rank = {rank_id} result = {result}",
+              flush=True)
+    result = {k: m.eval(sync=True) for k, m in metrics.items()}
+    result["time"] = time.time() - begin
+    print(f"evaluation result = {result}", flush=True)
+    return result
+
+
+def run_eval():
+    """
+    run eval
+    """
+    print(config, flush=True)
+    cfg = config
+    cfg.lr_type = "bicubic_AUG_self_ensemble"
+
+    init_env(cfg)
+    session, input_names = create_session(cfg.pre_trained, cfg.device_target)
+
+    if cfg.dataset_name == "DIV2K":
+        cfg.batch_size = 1
+        cfg.patch_size = -1
+        ds_val = init_dataset(cfg, "valid")
+        metrics = {
+            "psnr": PSNR(rgb_range=cfg.rgb_range, shave=6 + cfg.scale),
+        }
+        if config.save_sr:
+            save_img_dir = os.path.join(cfg.output_path, "HrSr")
+            os.makedirs(save_img_dir, exist_ok=True)
+            metrics["num_sr"] = SaveSrHr(save_img_dir)
+        do_eval(session, input_names, ds_val, metrics)
+        print("eval success", flush=True)
+    else:
+        raise RuntimeError("Unsupported dataset.")
+
+
+if __name__ == '__main__':
+    run_eval()
diff --git a/research/cv/GhostSR/export.py b/research/cv/GhostSR/export.py
new file mode 100644
index 000000000..1eddada7d
--- /dev/null
+++ b/research/cv/GhostSR/export.py
@@ -0,0 +1,65 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+##############export checkpoint file into air, mindir models#################
+python export.py
+"""
+import os
+import numpy as np
+
+import mindspore as ms
+from mindspore import Tensor, export, context
+
+from model_utils.config import config
+from model_utils.device_adapter import get_device_id
+from model_utils.moxing_adapter import moxing_wrapper
+from src.utils import init_net
+
+context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
+if config.device_target == "Ascend":
+    context.set_context(device_id=get_device_id())
+
+MAX_HR_SIZE = 2040
+
+
+@moxing_wrapper()
+def run_export():
+    """
+    run export
+    """
+    print(config)
+
+    cfg = config
+    if cfg.pre_trained is None:
+        raise RuntimeError('config.pre_trained is None.')
+
+    net = init_net(cfg)
+    max_lr_size = MAX_HR_SIZE // cfg.scale
+    input_arr = Tensor(np.ones([1, cfg.n_colors, max_lr_size, max_lr_size]), ms.float32)
+    file_name = os.path.splitext(os.path.basename(cfg.pre_trained))[0]
+    file_name = file_name + f"_InputSize{max_lr_size}"
+    file_path = os.path.join(cfg.output_path, file_name)
+    file_format = 'MINDIR'
+
+    num_params = sum([param.size for param in net.parameters_dict().values()])
+    export(net, input_arr, file_name=file_path, file_format=file_format)
+    print("export success", flush=True)
+    print(
+        f"{cfg.pre_trained} -> {file_path}.{file_format.lower()}, net parameters = {num_params / 1000000:>0.4}M",
+        flush=True)
+
+
+if __name__ == '__main__':
+    run_export()
diff --git a/research/cv/GhostSR/infer/convert/aipp_edsr_opencv.cfg b/research/cv/GhostSR/infer/convert/aipp_edsr_opencv.cfg
new file mode 100644
index 000000000..931dbfeda
--- /dev/null
+++ b/research/cv/GhostSR/infer/convert/aipp_edsr_opencv.cfg
@@ -0,0 +1,5 @@
+aipp_op {
+aipp_mode:static
+input_format:RGB888_U8
+}
+
diff --git a/research/cv/GhostSR/infer/convert/convert_om.sh b/research/cv/GhostSR/infer/convert/convert_om.sh
new file mode 100644
index 000000000..a57ec0d1b
--- /dev/null
+++ b/research/cv/GhostSR/infer/convert/convert_om.sh
@@ -0,0 +1,28 @@
+#!/bin/bash
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================ + +model_path=$1 +aipp_cfg_path=$2 +output_model_name=$3 + +atc \ +--model=$model_path \ +--input_format=NCHW \ +--framework=1 \ +--output=$output_model_name \ +--log=error \ +--soc_version=Ascend310 \ +--insert_op_conf=$aipp_cfg_path diff --git a/research/cv/GhostSR/infer/data/config/edsr.pipeline b/research/cv/GhostSR/infer/data/config/edsr.pipeline new file mode 100644 index 000000000..ea3c2a13d --- /dev/null +++ b/research/cv/GhostSR/infer/data/config/edsr.pipeline @@ -0,0 +1,28 @@ +{ + "edsr_superResolution": { + "stream_config": { + "deviceId": "0" + }, + "appsrc0": { + "props": { + "blocksize": "409600" + }, + "factory": "appsrc", + "next": "mxpi_tensorinfer0" + }, + "mxpi_tensorinfer0": { + "props": { + "dataSource": "appsrc0", + "modelPath": "../model/edsr.om" + }, + "factory": "mxpi_tensorinfer", + "next": "appsink0" + }, + "appsink0": { + "props": { + "blocksize": "409600" + }, + "factory": "appsink" + } + } + } \ No newline at end of file diff --git a/research/cv/GhostSR/infer/docker_start_infer.sh b/research/cv/GhostSR/infer/docker_start_infer.sh new file mode 100644 index 000000000..d8cf64915 --- /dev/null +++ b/research/cv/GhostSR/infer/docker_start_infer.sh @@ -0,0 +1,49 @@ +#!/bin/bash + +#coding = utf-8 +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+docker_image=$1
+data_dir=$2
+
+function show_help() {
+  echo "Usage: docker_start_infer.sh docker_image data_dir"
+}
+
+function param_check() {
+  if [ -z "${docker_image}" ]; then
+    echo "please input docker_image"
+    show_help
+    exit 1
+  fi
+
+  if [ -z "${data_dir}" ]; then
+    echo "please input data_dir"
+    show_help
+    exit 1
+  fi
+}
+
+param_check
+
+docker run -it \
+  --device=/dev/davinci0 \
+  --device=/dev/davinci_manager \
+  --device=/dev/devmm_svm \
+  --device=/dev/hisi_hdc \
+  -v /usr/local/Ascend/driver:/usr/local/Ascend/driver \
+  -v ${data_dir}:${data_dir} \
+  ${docker_image} \
+  /bin/bash
diff --git a/research/cv/GhostSR/infer/mxbase/CMakeLists.txt b/research/cv/GhostSR/infer/mxbase/CMakeLists.txt
new file mode 100644
index 000000000..1538ccd9d
--- /dev/null
+++ b/research/cv/GhostSR/infer/mxbase/CMakeLists.txt
@@ -0,0 +1,55 @@
+cmake_minimum_required(VERSION 3.14.0)
+project(edsr)
+
+set(TARGET edsr)
+
+add_definitions(-DENABLE_DVPP_INTERFACE)
+#add_compile_options(-std=c++11 -fPIE -fstack-protector-all -fPIC -Wall)
+add_link_options(-Wl,-z,relro,-z,now,-z,noexecstack -s -pie)
+
+add_compile_options(-std=c++11 -fPIE -fstack-protector-all -fPIC -Wall
+-Dgoogle=mindxsdk_private -D_GLIBCXX_USE_CXX11_ABI=0)
+
+
+# Check environment variables
+if(NOT DEFINED ENV{ASCEND_HOME})
+    message(FATAL_ERROR "please define environment variable:ASCEND_HOME")
+endif()
+if(NOT DEFINED ENV{ASCEND_VERSION})
+    message(WARNING "please define environment variable:ASCEND_VERSION")
+endif()
+if(NOT DEFINED ENV{ARCH_PATTERN})
+    message(WARNING "please define environment variable:ARCH_PATTERN")
+endif()
+set(ACL_INC_DIR $ENV{ASCEND_HOME}/$ENV{ASCEND_VERSION}/$ENV{ARCH_PATTERN}/acllib/include)
+set(ACL_LIB_DIR $ENV{ASCEND_HOME}/$ENV{ASCEND_VERSION}/$ENV{ARCH_PATTERN}/acllib/lib64)
+
+set(MXBASE_ROOT_DIR $ENV{MX_SDK_HOME})
+set(MXBASE_INC ${MXBASE_ROOT_DIR}/include)
+set(MXBASE_LIB_DIR ${MXBASE_ROOT_DIR}/lib)
+set(MXBASE_POST_LIB_DIR ${MXBASE_ROOT_DIR}/lib/modelpostprocessors)
+set(MXBASE_POST_PROCESS_DIR ${MXBASE_ROOT_DIR}/postprocess/include)
+
+
+if(DEFINED ENV{MXSDK_OPENSOURCE_DIR})
+    set(OPENSOURCE_DIR $ENV{MXSDK_OPENSOURCE_DIR})
+else()
+    set(OPENSOURCE_DIR ${MXBASE_ROOT_DIR}/opensource)
+endif()
+
+include_directories(${ACL_INC_DIR})
+include_directories(${OPENSOURCE_DIR}/include)
+include_directories(${OPENSOURCE_DIR}/include/opencv4)
+
+include_directories(${MXBASE_INC})
+include_directories(${MXBASE_POST_PROCESS_DIR})
+
+link_directories(${ACL_LIB_DIR})
+link_directories(${OPENSOURCE_DIR}/lib)
+link_directories(${MXBASE_LIB_DIR})
+link_directories(${MXBASE_POST_LIB_DIR})
+
+add_executable(${TARGET} main.cpp EdsrSuperresolution.cpp)
+target_link_libraries(${TARGET} glog cpprest mxbase opencv_world)
+
+install(TARGETS ${TARGET} RUNTIME DESTINATION ${PROJECT_SOURCE_DIR}/)
diff --git a/research/cv/GhostSR/infer/mxbase/EdsrSuperresolution.cpp b/research/cv/GhostSR/infer/mxbase/EdsrSuperresolution.cpp
new file mode 100644
index 000000000..1016326d4
--- /dev/null
+++ b/research/cv/GhostSR/infer/mxbase/EdsrSuperresolution.cpp
@@ -0,0 +1,200 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "EdsrSuperresolution.h"
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "MxBase/DeviceManager/DeviceManager.h"
+#include "MxBase/DvppWrapper/DvppWrapper.h"
+#include "MxBase/Log/Log.h"
+
+
+namespace localParameter {
+    const uint32_t VECTOR_FIRST_INDEX = 0;
+    const uint32_t VECTOR_SECOND_INDEX = 1;
+    const uint32_t VECTOR_THIRD_INDEX = 2;
+    const uint32_t VECTOR_FOURTH_INDEX = 3;
+    const uint32_t VECTOR_FIFTH_INDEX = 4;
+}
+
+APP_ERROR EdsrSuperresolution::Init(const InitParam &initParam) {
+    deviceId_ = initParam.deviceId;
+    APP_ERROR ret = MxBase::DeviceManager::GetInstance()->InitDevices();
+    if (ret != APP_ERR_OK) {
+        LogError << "Init devices failed, ret=" << ret << ".";
+        return ret;
+    }
+
+    ret = MxBase::TensorContext::GetInstance()->SetContext(initParam.deviceId);
+    if (ret != APP_ERR_OK) {
+        LogError << "Set context failed, ret=" << ret << ".";
+        return ret;
+    }
+
+    model_ = std::make_shared<MxBase::ModelInferenceProcessor>();
+    ret = model_->Init(initParam.modelPath, modelDesc_);
+    if (ret != APP_ERR_OK) {
+        LogError << "ModelInferenceProcessor init failed, ret=" << ret << ".";
+        return ret;
+    }
+    uint32_t outputModelHeight = modelDesc_.outputTensors[0].tensorDims[localParameter::VECTOR_THIRD_INDEX];
+    uint32_t inputModelHeight = modelDesc_.inputTensors[0].tensorDims[localParameter::VECTOR_SECOND_INDEX];
+    uint32_t inputModelWidth = modelDesc_.inputTensors[0].tensorDims[localParameter::VECTOR_THIRD_INDEX];
+
+    scale_ = outputModelHeight / inputModelHeight;
+    maxEdge_ = inputModelWidth > inputModelHeight ? inputModelWidth : inputModelHeight;
+    return APP_ERR_OK;
+}
+
+APP_ERROR EdsrSuperresolution::DeInit() {
+    model_->DeInit();
+    MxBase::DeviceManager::GetInstance()->DestroyDevices();
+    return APP_ERR_OK;
+}
+
+APP_ERROR EdsrSuperresolution::ReadImage(const std::string &imgPath, cv::Mat *imageMat) {
+    *imageMat = cv::imread(imgPath, cv::IMREAD_COLOR);
+    imageWidth_ = imageMat->cols;
+    imageHeight_ = imageMat->rows;
+    return APP_ERR_OK;
+}
+
+APP_ERROR EdsrSuperresolution::PaddingImage(cv::Mat *imageSrc, cv::Mat *imageDst, const uint32_t &targetLength) {
+    uint32_t padding_h = targetLength - imageHeight_;
+    uint32_t padding_w = targetLength - imageWidth_;
+    cv::copyMakeBorder(*imageSrc, *imageDst, 0, padding_h, 0, padding_w, cv::BORDER_CONSTANT, 0);
+    return APP_ERR_OK;
+}
+
+
+APP_ERROR EdsrSuperresolution::CVMatToTensorBase(const cv::Mat &imageMat, MxBase::TensorBase *tensorBase) {
+    const uint32_t dataSize = imageMat.cols * imageMat.rows * MxBase::YUV444_RGB_WIDTH_NU;
+
+    MxBase::MemoryData memoryDataDst(dataSize, MxBase::MemoryData::MEMORY_DEVICE, deviceId_);
+
+    MxBase::MemoryData memoryDataSrc(imageMat.data, dataSize, MxBase::MemoryData::MEMORY_HOST_MALLOC);
+
+    APP_ERROR ret = MxBase::MemoryHelper::MxbsMallocAndCopy(memoryDataDst, memoryDataSrc);
+    if (ret != APP_ERR_OK) {
+        LogError << GetError(ret) << "Memory malloc failed.";
+        return ret;
+    }
+
+    std::vector<uint32_t> shape = {imageMat.rows * MxBase::YUV444_RGB_WIDTH_NU, static_cast<uint32_t>(imageMat.cols)};
+    *tensorBase = MxBase::TensorBase(memoryDataDst, false, shape, MxBase::TENSOR_DTYPE_UINT8);
+    return APP_ERR_OK;
+}
+
+APP_ERROR EdsrSuperresolution::Inference(std::vector<MxBase::TensorBase> *inputs,
+                                         std::vector<MxBase::TensorBase> *outputs) {
+    auto dtypes = model_->GetOutputDataType();
+    for (size_t i = 0; i < modelDesc_.outputTensors.size(); ++i) {
+        std::vector<uint32_t> shape = {};
+        for (size_t j = 0; j < modelDesc_.outputTensors[i].tensorDims.size(); ++j) {
+            shape.push_back((uint32_t)modelDesc_.outputTensors[i].tensorDims[j]);
+        }
+        MxBase::TensorBase tensor(shape, dtypes[i], MxBase::MemoryData::MemoryType::MEMORY_DEVICE, deviceId_);
+        APP_ERROR ret = MxBase::TensorBase::TensorBaseMalloc(tensor);
+        if (ret != APP_ERR_OK) {
+            LogError << "TensorBaseMalloc failed, ret=" << ret << ".";
+            return ret;
+        }
+        outputs->push_back(tensor);
+    }
+    MxBase::DynamicInfo dynamicInfo = {};
+    dynamicInfo.dynamicType = MxBase::DynamicType::STATIC_BATCH;
+    APP_ERROR ret = model_->ModelInference(*inputs, *outputs, dynamicInfo);
+    if (ret != APP_ERR_OK) {
+        LogError << "ModelInference failed, ret=" << ret << ".";
+        return ret;
+    }
+    return APP_ERR_OK;
+}
+
+
+APP_ERROR EdsrSuperresolution::PostProcess(std::vector<MxBase::TensorBase> *inputs, cv::Mat *imageMat) {
+    MxBase::TensorBase tensor = *inputs->begin();
+    int ret = tensor.ToHost();
+    if (ret != APP_ERR_OK) {
+        LogError << GetError(ret) << "Tensor deploy to host failed.";
+        return ret;
+    }
+    uint32_t outputModelChannel = tensor.GetShape()[localParameter::VECTOR_SECOND_INDEX];
+    uint32_t outputModelHeight = tensor.GetShape()[localParameter::VECTOR_THIRD_INDEX];
+    uint32_t outputModelWidth = tensor.GetShape()[localParameter::VECTOR_FOURTH_INDEX];
+    LogInfo << "Channel:" << outputModelChannel << " Height:" << outputModelHeight << " Width:" << outputModelWidth;
+
+    uint32_t finalHeight = imageHeight_ * scale_;
+    uint32_t finalWidth = imageWidth_ * scale_;
+    cv::Mat output(finalHeight, finalWidth, CV_32FC3);
+
+    auto data = reinterpret_cast<float *>(tensor.GetBuffer());  // flat NCHW buffer, batch 0
+
+    for (size_t c = 0; c < outputModelChannel; ++c) {
+        for (size_t x = 0; x < finalHeight; ++x) {
+            for (size_t y = 0; y < finalWidth; ++y) {
+                output.at<cv::Vec3f>(x, y)[c] = data[(c * outputModelHeight + x) * outputModelWidth + y];
+            }
+        }
+    }
+
+    *imageMat = output;
+    return APP_ERR_OK;
+}
+
+APP_ERROR EdsrSuperresolution::Process(const std::string &imgPath) {
+    cv::Mat imageMat;
+    APP_ERROR ret = ReadImage(imgPath, &imageMat);
+    if (ret != APP_ERR_OK) {
+        LogError << "ReadImage failed, ret=" << ret << ".";
+        return ret;
+    }
+
+    PaddingImage(&imageMat, &imageMat, maxEdge_);
+    MxBase::TensorBase tensorBase;
+    ret = CVMatToTensorBase(imageMat, &tensorBase);
+    if (ret != APP_ERR_OK) {
+        LogError << "CVMatToTensorBase failed, ret=" << ret << ".";
+        return ret;
+    }
+
+
+    std::vector<MxBase::TensorBase> inputs = {};
+    std::vector<MxBase::TensorBase> outputs = {};
+    inputs.push_back(tensorBase);
+    ret = Inference(&inputs, &outputs);
+
+    if (ret != APP_ERR_OK) {
+        LogError << "Inference failed, ret=" << ret << ".";
+        return ret;
+    }
+
+    cv::Mat output;
+    ret = PostProcess(&outputs, &output);
+    if (ret != APP_ERR_OK) {
+        LogError << "PostProcess failed, ret=" << ret << ".";
+        return ret;
+    }
+
+    std::string resultPath = imgPath;
+    size_t pos = resultPath.find_last_of(".");
+    resultPath.replace(resultPath.begin() + pos, resultPath.end(), "_infer.png");
+    cv::imwrite(resultPath, output);
+    return APP_ERR_OK;
+}
diff --git a/research/cv/GhostSR/infer/mxbase/EdsrSuperresolution.h b/research/cv/GhostSR/infer/mxbase/EdsrSuperresolution.h
new file mode 100644
index 000000000..36b1ab9cf
--- /dev/null
+++ b/research/cv/GhostSR/infer/mxbase/EdsrSuperresolution.h
@@ -0,0 +1,54 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef EDSR_SUPERRESOLUTION_H
+#define EDSR_SUPERRESOLUTION_H
+
+#include <memory>
+#include <string>
+#include <vector>
+#include <opencv2/opencv.hpp>
+#include "MxBase/ModelInfer/ModelInferenceProcessor.h"
+#include "MxBase/PostProcessBases/PostProcessDataType.h"
+#include "MxBase/Tensor/TensorContext/TensorContext.h"
+
+struct InitParam {
+    uint32_t deviceId;
+    std::string modelPath;
+};
+
+class EdsrSuperresolution {
+ public:
+    APP_ERROR Init(const InitParam &initParam);
+    APP_ERROR DeInit();
+    APP_ERROR ReadImage(const std::string &imgPath, cv::Mat *imageMat);
+    APP_ERROR CVMatToTensorBase(const cv::Mat &imageMat, MxBase::TensorBase *tensorBase);
+    APP_ERROR Inference(std::vector<MxBase::TensorBase> *inputs, std::vector<MxBase::TensorBase> *outputs);
+    APP_ERROR Process(const std::string &imgPath);
+    APP_ERROR PostProcess(std::vector<MxBase::TensorBase> *inputs, cv::Mat *imageMat);
+    APP_ERROR PaddingImage(cv::Mat *imageSrc, cv::Mat *imageDst, const uint32_t &targetLength);
+
+ private:
+    std::shared_ptr<MxBase::ModelInferenceProcessor> model_;
+    MxBase::ModelDesc modelDesc_;
+    uint32_t deviceId_ = 0;
+    uint32_t scale_ = 0;
+    uint32_t imageWidth_ = 0;
+    uint32_t imageHeight_ = 0;
+    uint32_t maxEdge_ = 0;
+};
+
+#endif
diff --git a/research/cv/GhostSR/infer/mxbase/build.sh b/research/cv/GhostSR/infer/mxbase/build.sh
new file mode 100644
index 000000000..566c6461e
--- /dev/null
+++ b/research/cv/GhostSR/infer/mxbase/build.sh
@@ -0,0 +1,46 @@
+#!/bin/bash
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+
+# env
+
+mkdir -p build
+cd build || exit
+
+function make_plugin() {
+    if ! cmake ..;
+    then
+        echo "cmake failed."
+        return 1
+    fi
+
+    if ! (make);
+    then
+        echo "make failed."
+        return 1
+    fi
+
+    return 0
+}
+
+if make_plugin;
+then
+    echo "INFO: Build succeeded."
+else
+    echo "ERROR: Build failed."
+fi
+
+cd - || exit
diff --git a/research/cv/GhostSR/infer/mxbase/main.cpp b/research/cv/GhostSR/infer/mxbase/main.cpp
new file mode 100644
index 000000000..cb3401c0d
--- /dev/null
+++ b/research/cv/GhostSR/infer/mxbase/main.cpp
@@ -0,0 +1,46 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "EdsrSuperresolution.h"
+#include "MxBase/Log/Log.h"
+
+
+// infer an image
+int main(int argc, char *argv[]) {
+    if (argc <= 1) {
+        LogWarn << "Please input image path, such as './test.png'";
+        return APP_ERR_OK;
+    }
+    InitParam initParam = {};
+    initParam.deviceId = 0;
+    initParam.modelPath = "../model/edsr.om";
+    EdsrSuperresolution edsrSR;
+    APP_ERROR ret = edsrSR.Init(initParam);
+    if (ret != APP_ERR_OK) {
+        LogError << "EdsrSuperresolution init failed, ret=" << ret << ".";
+        return ret;
+    }
+    std::string imgPath = argv[1];
+    ret = edsrSR.Process(imgPath);
+    if (ret != APP_ERR_OK) {
+        LogError << "EdsrSuperresolution process failed, ret=" << ret << ".";
+        edsrSR.DeInit();
+        return ret;
+    }
+
+    edsrSR.DeInit();
+    return APP_ERR_OK;
+}
diff --git a/research/cv/GhostSR/infer/sdk/eval.py b/research/cv/GhostSR/infer/sdk/eval.py
new file mode 100644
index 000000000..dc2433be1
--- /dev/null
+++ b/research/cv/GhostSR/infer/sdk/eval.py
@@ -0,0 +1,72 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""eval for sdk infer"""
+import argparse
+import math
+import os
+
+import cv2
+import numpy as np
+
+
+def parser_args():
+    """parse arguments"""
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--label_dir", type=str, default="../data/DIV2K/label/",
+                        help="path of label images directory")
+    parser.add_argument("--infer_dir", type=str, default="../data/sdk_out",
+                        help="path of infer images directory")
+    parser.add_argument("--scale", type=int, default=2)
+    return parser.parse_args()
+
+
+def calc_psnr(sr, hr, scale, rgb_range):
+    """calculate psnr"""
+    hr = np.float32(hr)
+    sr = np.float32(sr)
+    diff = (sr - hr) / rgb_range
+    gray_coeffs = np.array([65.738, 129.057, 25.064]).reshape((1, 3, 1, 1)) / 256
+    diff = np.multiply(diff, gray_coeffs).sum(1)
+    if hr.size == 1:
+        return 0
+    if scale != 1:
+        shave = scale
+    else:
+        shave = scale + 6
+    if scale == 1:
+        valid = diff
+    else:
+        valid = diff[..., shave:-shave, shave:-shave]
+    mse = np.mean(pow(valid, 2))
+    return -10 * math.log10(mse)
+
+
+if __name__ == '__main__':
+    args = parser_args()
+    infer_path_list = os.listdir(args.infer_dir)
+    total_num = len(infer_path_list)
+    mean_psnr = 0.0
+    for infer_p in infer_path_list:
+        infer_path = os.path.join(args.infer_dir, infer_p)
+        label_path = os.path.join(args.label_dir, infer_p.replace('_infer', ''))
+        infer_img = cv2.imread(infer_path)
+        h, w = infer_img.shape[:2]
+        label_img = cv2.imread(label_path)[0:h, 0:w]
+        infer_img = np.expand_dims(infer_img, 0).transpose((0, 3, 1, 2))
+        label_img = np.expand_dims(label_img, 0).transpose((0, 3, 1, 2))
+        psnr = calc_psnr(infer_img, label_img, args.scale, 255.0)
+        mean_psnr += psnr / total_num
+        print("current psnr: ", psnr)
+    print('Mean psnr of %d images is %.4f' % (total_num, mean_psnr))
diff --git a/research/cv/GhostSR/infer/sdk/main.py b/research/cv/GhostSR/infer/sdk/main.py
new file mode 100644
index 000000000..63d09c41a
--- /dev/null
+++ b/research/cv/GhostSR/infer/sdk/main.py
@@ -0,0 +1,45 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""run sdk infer"""
+import argparse
+import os
+
+from sr_infer_wrapper import SRInferWrapper
+
+
+def parser_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--input_dir", type=str, default="../data/DIV2K/input/",
+                        help="path of input images directory")
+    parser.add_argument("--pipeline_path", type=str, default="../data/config/edsr.pipeline",
+                        help="path of pipeline file")
+    parser.add_argument("--output_dir", type=str, default="../data/sdk_out/",
+                        help="path of output images directory")
+    return parser.parse_args()
+
+
+if __name__ == '__main__':
+    args = parser_args()
+    sr_infer = SRInferWrapper()
+    sr_infer.load_pipeline(args.pipeline_path)
+    path_list = os.listdir(args.input_dir)
+    path_list.sort()
+    if not os.path.exists(args.output_dir):
+        os.makedirs(args.output_dir)
+
+    for img_path in path_list:
+        print(img_path)
+        res = sr_infer.do_infer(os.path.join(args.input_dir, img_path))
+        res.save(os.path.join(args.output_dir, img_path.replace('x2', '_infer')))
diff --git a/research/cv/GhostSR/infer/sdk/run.sh b/research/cv/GhostSR/infer/sdk/run.sh
new file mode 100644
index 000000000..f1b9ff91f
--- /dev/null
+++ b/research/cv/GhostSR/infer/sdk/run.sh
@@ -0,0 +1,76 @@
+#!/bin/bash
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+# The number of parameters must be 3.
+if [ $# -ne 3 ]
+then
+    echo "Wrong parameter format."
+    echo "Usage:"
+    echo "  bash $0 [INPUT_PATH] [PIPELINE_PATH] [OUTPUT_PATH]"
+    echo "Example: "
+    echo "  bash run.sh ../data/DIV2K/input/ ../data/config/edsr.pipeline ../data/sdk_out/"
+
+    exit 1
+fi
+
+# The path of a folder containing eval images.
+input_dir=$1
+# The path of pipeline file.
+pipeline_path=$2
+# The path of a folder used to store all results.
+output_dir=$3
+
+
+if [ ! -d $input_dir ]
+then
+    echo "Please input the correct directory containing images."
+    exit
+fi
+
+if [ ! -d $output_dir ]
+then
+    mkdir -p $output_dir
+fi
+
+set -e
+
+CUR_PATH=$(cd "$(dirname "$0")" || { warn "Failed to check path/to/run.sh" ; exit ; } ; pwd)
+echo "enter $CUR_PATH"
+
+# Simple log helper functions
+info() { echo -e "\033[1;34m[INFO ][MxStream] $1\033[1;37m" ; }
+warn() { echo >&2 -e "\033[1;31m[WARN ][MxStream] $1\033[1;37m" ; }
+
+export LD_LIBRARY_PATH=${MX_SDK_HOME}/lib:${MX_SDK_HOME}/opensource/lib:${MX_SDK_HOME}/opensource/lib64:/usr/local/Ascend/${ASCEND_VERSION}/latest/acllib/lib64:${LD_LIBRARY_PATH}
+
+# to set PYTHONPATH, import the StreamManagerApi.py
+export PYTHONPATH=$PYTHONPATH:${MX_SDK_HOME}/python
+
+if [ -n "${MX_SDK_HOME}" ]
+then
+export GST_PLUGIN_SCANNER=${MX_SDK_HOME}/opensource/libexec/gstreamer-1.0/gst-plugin-scanner
+fi
+
+if [ -n "${MX_SDK_HOME}" ]
+then
+export GST_PLUGIN_PATH=${MX_SDK_HOME}/opensource/lib/gstreamer-1.0:${MX_SDK_HOME}/lib/plugins
+fi
+
+python3 main.py --input_dir=$input_dir \
+                --pipeline_path=$pipeline_path \
+                --output_dir=$output_dir
+
+exit 0
\ No newline at end of file
diff --git a/research/cv/GhostSR/infer/sdk/sr_infer_wrapper.py b/research/cv/GhostSR/infer/sdk/sr_infer_wrapper.py
new file mode 100644
index 000000000..d71e36e60
--- /dev/null
+++ b/research/cv/GhostSR/infer/sdk/sr_infer_wrapper.py
@@ -0,0 +1,127 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""super resolution infer wrapper"""
+import json
+
+import MxpiDataType_pb2 as MxpiDataType
+import cv2
+import numpy as np
+from PIL import Image
+from StreamManagerApi import StreamManagerApi, StringVector, InProtobufVector, MxProtobufIn
+
+DEFAULT_IMAGE_WIDTH = 1020
+DEFAULT_IMAGE_HEIGHT = 1020
+CHANNELS = 3
+SCALE = 2
+
+
+def padding(img, target_shape):
+    h, w = target_shape[0], target_shape[1]
+    img_h, img_w, _ = img.shape
+    dh, dw = h - img_h, w - img_w
+    if dh < 0 or dw < 0:
+        raise RuntimeError(f"img.shape is bigger than target_shape, {img.shape} > {target_shape}")
+    if dh != 0 or dw != 0:
+        img = np.pad(img, ((0, int(dh)), (0, int(dw)), (0, 0)), "reflect")
+    return img
+
+
+def unpadding(img, target_shape):
+    h, w = target_shape[0], target_shape[1]
+    img_h, img_w, _ = img.shape
+    if img_h > h:
+        img = img[:h, :, :]
+    if img_w > w:
+        img = img[:, :w, :]
+    return img
+
+
+class SRInferWrapper:
+    """super resolution infer wrapper"""
+
+    def __init__(self):
+        self.stream_name = None
+        self.streamManagerApi = StreamManagerApi()
+        # init stream manager
+        if self.streamManagerApi.InitManager() != 0:
+            raise RuntimeError("Failed to init stream manager.")
+
+    def load_pipeline(self, pipeline_path):
+        # create streams by pipeline config file
+        with open(pipeline_path, 'r') as f:
+            pipeline = json.load(f)
+        self.stream_name = list(pipeline.keys())[0].encode()
+        pipelineStr = json.dumps(pipeline).encode()
+        if self.streamManagerApi.CreateMultipleStreams(pipelineStr) != 0:
+            raise RuntimeError("Failed to create stream.")
+
+    def do_infer(self, image_path):
+        """do infer process"""
+        # construct the input of the stream
+        image = cv2.imread(image_path)
+        ori_h, ori_w, _ = image.shape
+        image = padding(image, (DEFAULT_IMAGE_HEIGHT, DEFAULT_IMAGE_WIDTH))
+        tensor_pkg_list = MxpiDataType.MxpiTensorPackageList()
+        tensor_pkg = tensor_pkg_list.tensorPackageVec.add()
+        tensor_vec = tensor_pkg.tensorVec.add()
+        tensor_vec.deviceId = 0
+        tensor_vec.memType = 0
+
+        for dim in [1, *image.shape]:
+            tensor_vec.tensorShape.append(dim)
+
+        input_data = image.tobytes()
+        tensor_vec.dataStr = input_data
+        tensor_vec.tensorDataSize = len(input_data)
+
+        protobuf_vec = InProtobufVector()
+        protobuf = MxProtobufIn()
+        protobuf.key = b'appsrc0'
+        protobuf.type = b'MxTools.MxpiTensorPackageList'
+        protobuf.protobuf = tensor_pkg_list.SerializeToString()
+        protobuf_vec.push_back(protobuf)
+
+        unique_id = self.streamManagerApi.SendProtobuf(
+            self.stream_name, 0, protobuf_vec)
+        if unique_id < 0:
+            raise RuntimeError("Failed to send data to stream.")
+
+        # get plugin output data
+        key = b"mxpi_tensorinfer0"
+        keyVec = StringVector()
+        keyVec.push_back(key)
+        inferResult = self.streamManagerApi.GetProtobuf(self.stream_name, 0, keyVec)
+        if inferResult.size() == 0:
+            raise RuntimeError("inferResult is null")
+        if inferResult[0].errorCode != 0:
+            raise RuntimeError("GetProtobuf error.
errorCode=%d, errorMsg=%s" % ( + inferResult[0].errorCode, inferResult[0].messageName.decode())) + + # get the infer result + inferList0 = MxpiDataType.MxpiTensorPackageList() + inferList0.ParseFromString(inferResult[0].messageBuf) + inferVisionData = inferList0.tensorPackageVec[0].tensorVec[0].dataStr + + # converting the byte data into 32 bit float array + output_img_data = np.frombuffer(inferVisionData, dtype=np.float32) + output_img_data = np.clip(output_img_data, 0, 255) + output_img_data = np.round(output_img_data).astype(np.uint8) + output_img_data = np.reshape(output_img_data, ( + CHANNELS, SCALE * DEFAULT_IMAGE_HEIGHT, SCALE * DEFAULT_IMAGE_WIDTH)) + output_img_data = output_img_data.transpose((1, 2, 0)) + output_img_data = unpadding(output_img_data, (SCALE * ori_h, SCALE * ori_w)) + result = Image.fromarray(output_img_data[..., ::-1]) + + return result diff --git a/research/cv/GhostSR/mindspore_hub_conf.py b/research/cv/GhostSR/mindspore_hub_conf.py new file mode 100644 index 000000000..684cc34ea --- /dev/null +++ b/research/cv/GhostSR/mindspore_hub_conf.py @@ -0,0 +1,26 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""hub config.""" +from src.edsr import EDSR + + +def edsr(*args, **kwargs): + return EDSR(*args, **kwargs) + + +def create_network(name, *args, **kwargs): + if name == "edsr": + return edsr(*args, **kwargs) + raise NotImplementedError(f"{name} is not implemented in the repo") diff --git a/research/cv/GhostSR/model_utils/__init__.py b/research/cv/GhostSR/model_utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/research/cv/GhostSR/model_utils/config.py b/research/cv/GhostSR/model_utils/config.py new file mode 100644 index 000000000..a19f176b8 --- /dev/null +++ b/research/cv/GhostSR/model_utils/config.py @@ -0,0 +1,136 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Parse arguments""" + +import argparse +import ast +import os +from pprint import pformat + +import yaml + + +class Config: + """ + Configuration namespace. Convert dictionary to members. 
+ """ + + def __init__(self, cfg_dict): + for k, v in cfg_dict.items(): + if isinstance(v, (list, tuple)): + setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v]) + else: + setattr(self, k, Config(v) if isinstance(v, dict) else v) + + def __str__(self): + return pformat(self.__dict__) + + def __repr__(self): + return self.__str__() + + +def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path="default_config.yaml"): + """ + Parse command line arguments to the configuration according to the default yaml. + + Args: + parser: Parent parser. + cfg: Base configuration. + helper: Helper description. + cfg_path: Path to the default yaml config. + """ + parser = argparse.ArgumentParser(description="[REPLACE THIS at config.py]", + parents=[parser]) + helper = {} if helper is None else helper + choices = {} if choices is None else choices + for item in cfg: + if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict): + help_description = helper[item] if item in helper else "Please reference to {}".format( + cfg_path) + choice = choices[item] if item in choices else None + if isinstance(cfg[item], bool): + parser.add_argument("--" + item, type=ast.literal_eval, default=cfg[item], + choices=choice, + help=help_description) + else: + parser.add_argument("--" + item, type=type(cfg[item]), default=cfg[item], + choices=choice, + help=help_description) + args = parser.parse_args() + return args + + +def parse_yaml(yaml_path): + """ + Parse the yaml config file. + + Args: + yaml_path: Path to the yaml config. + """ + with open(yaml_path, 'r') as fin: + try: + cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader) + cfgs = [x for x in cfgs] + if len(cfgs) == 1: + cfg_helper = {} + cfg = cfgs[0] + cfg_choices = {} + elif len(cfgs) == 2: + cfg, cfg_helper = cfgs + cfg_choices = {} + elif len(cfgs) == 3: + cfg, cfg_helper, cfg_choices = cfgs + else: + raise ValueError( + "At most 3 docs (config, description for help, choices) are supported in config yaml") + except: + raise ValueError("Failed to parse yaml") + return cfg, cfg_helper, cfg_choices + + +def merge(args, cfg): + """ + Merge the base config from yaml file and command line arguments. + + Args: + args: Command line arguments. + cfg: Base configuration. + """ + args_var = vars(args) + for item in args_var: + cfg[item] = args_var[item] + return cfg + + +def get_config(): + """ + Get Config according to the yaml file and cli arguments. + """ + parser = argparse.ArgumentParser(description="default name", add_help=False) + current_dir = os.path.dirname(os.path.abspath(__file__)) + parser.add_argument("--config_path", type=str, + default=os.path.join(current_dir, "../default_config.yaml"), + help="Config file path") + path_args, _ = parser.parse_known_args() + default, helper, choices = parse_yaml(path_args.config_path) + args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, + cfg_path=path_args.config_path) + final_config = merge(args, default) + final_config = Config(final_config) + return final_config + + +config = get_config() diff --git a/research/cv/GhostSR/model_utils/device_adapter.py b/research/cv/GhostSR/model_utils/device_adapter.py new file mode 100644 index 000000000..7c5d7f837 --- /dev/null +++ b/research/cv/GhostSR/model_utils/device_adapter.py @@ -0,0 +1,27 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Device adapter for ModelArts""" + +from .config import config + +if config.enable_modelarts: + from .moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id +else: + from .local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id + +__all__ = [ + "get_device_id", "get_device_num", "get_rank_id", "get_job_id" +] diff --git a/research/cv/GhostSR/model_utils/local_adapter.py b/research/cv/GhostSR/model_utils/local_adapter.py new file mode 100644 index 000000000..8a1b1fa1f --- /dev/null +++ b/research/cv/GhostSR/model_utils/local_adapter.py @@ -0,0 +1,37 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Local adapter""" + +import os + + +def get_device_id(): + device_id = os.getenv('DEVICE_ID', '0') + return int(device_id) + + +def get_device_num(): + device_num = os.getenv('RANK_SIZE', '1') + return int(device_num) + + +def get_rank_id(): + global_rank_id = os.getenv('RANK_ID', '0') + return int(global_rank_id) + + +def get_job_id(): + return "Local Job" diff --git a/research/cv/GhostSR/model_utils/moxing_adapter.py b/research/cv/GhostSR/model_utils/moxing_adapter.py new file mode 100644 index 000000000..b120426e0 --- /dev/null +++ b/research/cv/GhostSR/model_utils/moxing_adapter.py @@ -0,0 +1,124 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================
+
+"""Moxing adapter for ModelArts"""
+
+import functools
+import os
+
+from mindspore import context
+
+from .config import config
+
+_global_sync_count = 0
+
+
+def get_device_id():
+    device_id = os.getenv('DEVICE_ID', '0')
+    return int(device_id)
+
+
+def get_device_num():
+    device_num = os.getenv('RANK_SIZE', '1')
+    return int(device_num)
+
+
+def get_rank_id():
+    global_rank_id = os.getenv('RANK_ID', '0')
+    return int(global_rank_id)
+
+
+def get_job_id():
+    job_id = os.getenv('JOB_ID')
+    job_id = job_id if job_id else "default"
+    return job_id
+
+
+def sync_data(from_path, to_path):
+    """
+    Download data from a remote OBS url to a local directory when the first path is a remote
+    url and the second is a local path; upload from a local directory to remote OBS otherwise.
+    """
+    import moxing as mox
+    import time
+    global _global_sync_count
+    sync_lock = "/tmp/copy_sync.lock" + str(_global_sync_count)
+    _global_sync_count += 1
+
+    # Each server contains at most 8 devices.
+    if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
+        print("from path: ", from_path)
+        print("to path: ", to_path)
+        mox.file.copy_parallel(from_path, to_path)
+        print("===finish data synchronization===")
+        try:
+            os.mknod(sync_lock)
+        except IOError:
+            pass
+        print("===save flag===")
+
+    while True:
+        if os.path.exists(sync_lock):
+            break
+        time.sleep(1)
+
+    print("Finish sync data from {} to {}.".format(from_path, to_path))
+
+
+def moxing_wrapper(pre_process=None, post_process=None):
+    """
+    Moxing wrapper to download dataset and upload outputs.
+    """
+
+    def wrapper(run_func):
+        @functools.wraps(run_func)
+        def wrapped_func(*args, **kwargs):
+            # Download data from data_url
+            if config.enable_modelarts:
+                if config.data_url:
+                    sync_data(config.data_url, config.data_path)
+                    print("Dataset downloaded: ", os.listdir(config.data_path))
+                if config.checkpoint_url:
+                    sync_data(config.checkpoint_url, config.load_path)
+                    print("Preload downloaded: ", os.listdir(config.load_path))
+                if config.train_url:
+                    sync_data(config.train_url, config.output_path)
+                    print("Workspace downloaded: ", os.listdir(config.output_path))
+
+                context.set_context(
+                    save_graphs_path=os.path.join(config.output_path, str(get_rank_id())))
+                config.device_num = get_device_num()
+                config.device_id = get_device_id()
+                if not os.path.exists(config.output_path):
+                    os.makedirs(config.output_path)
+
+                if pre_process:
+                    pre_process()
+
+            # Run the main function
+            run_func(*args, **kwargs)
+
+            # Upload data to train_url
+            if config.enable_modelarts:
+                if post_process:
+                    post_process()
+
+                if config.train_url:
+                    print("Start to copy output directory")
+                    sync_data(config.output_path, config.train_url)
+
+        return wrapped_func
+
+    return wrapper
diff --git a/research/cv/GhostSR/modelarts/train_start.py b/research/cv/GhostSR/modelarts/train_start.py
new file mode 100644
index 000000000..b5e7c17d8
--- /dev/null
+++ b/research/cv/GhostSR/modelarts/train_start.py
@@ -0,0 +1,130 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""train edsr on modelarts"""
+import argparse
+import os
+import subprocess
+
+import moxing as mox
+
+_CACHE_DATA_URL = "/cache/data_url"
+_CACHE_TRAIN_URL = "/cache/train_url"
+
+
+def _parse_args():
+    """parse arguments"""
+    parser = argparse.ArgumentParser(description='train and export edsr on modelarts')
+    # train output path
+    parser.add_argument('--train_url', type=str, default='',
+                        help='where training log and ckpts saved')
+    # dataset dir
+    parser.add_argument('--data_url', type=str, default='',
+                        help='path of the dataset directory')
+    # train config
+    parser.add_argument('--data_train', type=str, default='DIV2K', help='train dataset name')
+    parser.add_argument('--epochs', type=int, default=1, help='number of epochs to train')
+    parser.add_argument('--batch_size', type=int, default=16, help='input batch size for training')
+    parser.add_argument('--lr', type=float, default=1e-4, help='learning rate')
+    parser.add_argument('--init_loss_scale', type=float, default=65536., help='scaling factor')
+    parser.add_argument('--loss_scale', type=float, default=1024.0, help='loss_scale')
+    parser.add_argument('--scale', type=str, default='2', help='super resolution scale')
+    parser.add_argument('--ckpt_save_path', type=str, default='ckpt', help='path to save ckpt')
+    parser.add_argument('--ckpt_save_interval', type=int, default=10,
+                        help='save ckpt frequency, unit is epoch')
+    parser.add_argument('--ckpt_save_max', type=int, default=5, help='max number of saved ckpt')
+    parser.add_argument('--task_id', type=int, default=0)
+    # export config
+    parser.add_argument("--export_batch_size", type=int, default=1, help="batch size")
+    parser.add_argument("--export_file_name", type=str, default="edsr", help="output file name.")
+    parser.add_argument("--export_file_format", type=str, default="AIR",
+                        choices=['MINDIR', 'AIR', 'ONNX'], help="file format")
+    args, _ = parser.parse_known_args()
+
+    return args
+
+
+def _train(args, data_url):
+    """use train.py"""
+    pwd = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
+    train_file = os.path.join(pwd, "train.py")
+
+    cmd = ["python", train_file,
+           f"--dir_data={os.path.abspath(data_url)}",
+           f"--data_train={args.data_train}",
+           f"--epochs={args.epochs}",
+           f"--batch_size={args.batch_size}",
+           f"--lr={args.lr}",
+           f"--init_loss_scale={args.init_loss_scale}",
+           f"--loss_scale={args.loss_scale}",
+           f"--scale={args.scale}",
+           f"--task_id={args.task_id}",
+           f"--ckpt_save_path={os.path.join(_CACHE_TRAIN_URL, args.ckpt_save_path)}",
+           f"--ckpt_save_interval={args.ckpt_save_interval}",
+           f"--ckpt_save_max={args.ckpt_save_max}"]
+
+    print(' '.join(cmd))
+    process = subprocess.Popen(cmd, shell=False)
+    return process.wait()
+
+
+def _get_last_ckpt(ckpt_dir):
+    """get the last ckpt path"""
+    file_dict = {}
+    lists = os.listdir(ckpt_dir)
+    if not lists:
+        print("No ckpt file found.")
+        return None
+    for i in lists:
+        ctime = os.stat(os.path.join(ckpt_dir, i)).st_ctime
+        file_dict[ctime] = i
+    max_ctime = max(file_dict.keys())
+    ckpt_file =
os.path.join(ckpt_dir, file_dict[max_ctime]) + + return ckpt_file + + +def _export_air(args, ckpt_dir): + """export""" + ckpt_file = _get_last_ckpt(ckpt_dir) + if not ckpt_file: + return + pwd = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) + export_file = os.path.join(pwd, "export.py") + cmd = ["python", export_file, + f"--batch_size={args.export_batch_size}", + f"--ckpt_path={ckpt_file}", + f"--file_name={os.path.join(_CACHE_TRAIN_URL, args.export_file_name)}", + f"--file_format={args.export_file_format}"] + print(f"Start exporting, cmd = {' '.join(cmd)}.") + process = subprocess.Popen(cmd, shell=False) + process.wait() + + +def main(): + args = _parse_args() + + os.makedirs(_CACHE_TRAIN_URL, exist_ok=True) + os.makedirs(_CACHE_DATA_URL, exist_ok=True) + + mox.file.copy_parallel(args.data_url, _CACHE_DATA_URL) + data_url = _CACHE_DATA_URL + + _train(args, data_url) + _export_air(args, os.path.join(_CACHE_TRAIN_URL, args.ckpt_save_path)) + mox.file.copy_parallel(_CACHE_TRAIN_URL, args.train_url) + + +if __name__ == '__main__': + main() diff --git a/research/cv/GhostSR/postprocess.py b/research/cv/GhostSR/postprocess.py new file mode 100644 index 000000000..d5389ef09 --- /dev/null +++ b/research/cv/GhostSR/postprocess.py @@ -0,0 +1,193 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================
+'''post process for 310 inference'''
+import math
+import os
+
+import numpy as np
+from PIL import Image
+from mindspore import Tensor
+
+from model_utils.config import config
+from model_utils.moxing_adapter import moxing_wrapper
+from src.dataset import FolderImagePair, AUG_DICT
+from src.metric import PSNR
+from src.utils import init_env, modelarts_pre_process
+
+
+def read_bin(bin_path):
+    img = np.fromfile(bin_path, dtype=np.float32)
+    num_pix = img.size
+    img_shape = int(math.sqrt(num_pix // 3))
+    if 1 * 3 * img_shape * img_shape != num_pix:
+        raise RuntimeError(f'bin file error: {bin_path} is not an output of the edsr network')
+    img = img.reshape(1, 3, img_shape, img_shape)
+    return img
+
+
+def read_bin_as_hwc(bin_path):
+    nchw_img = read_bin(bin_path)
+    chw_img = np.squeeze(nchw_img)
+    hwc_img = chw_img.transpose(1, 2, 0)
+    return hwc_img
+
+
+def unpadding(img, target_shape):
+    h, w = target_shape[0], target_shape[1]
+    img_h, img_w, _ = img.shape
+    if img_h > h:
+        img = img[:h, :, :]
+    if img_w > w:
+        img = img[:, :w, :]
+    return img
+
+
+def img_to_tensor(img):
+    img = np.array([img.transpose(2, 0, 1)], np.float32)
+    img = Tensor(img)
+    return img
+
+
+def float_to_uint8(img):
+    clip_img = np.clip(img, 0, 255)
+    round_img = np.round(clip_img)
+    uint8_img = round_img.astype(np.uint8)
+    return uint8_img
+
+
+def bin_to_png(cfg):
+    """
+    convert the bin files output by ascend310_infer to png images
+    """
+    dataset_path = cfg.data_path
+    dataset_type = "valid"
+    aug_keys = list(AUG_DICT.keys())
+    lr_scale = cfg.scale
+
+    if cfg.self_ensemble:
+        dir_sr_bin = os.path.join(dataset_path, f"DIV2K_{dataset_type}_SR_bin", f"X{lr_scale}")
+        save_sr_se_dir = os.path.join(dataset_path, f"DIV2K_{dataset_type}_SR_self_ensemble",
+                                      f"X{lr_scale}")
+        if os.path.isdir(dir_sr_bin):
+            os.makedirs(save_sr_se_dir, exist_ok=True)
+            bin_patterns = [os.path.join(dir_sr_bin, f"*x{lr_scale}_{a_key}_0.bin") for a_key in
+                            aug_keys]
+            dataset = FolderImagePair(bin_patterns, reader=read_bin_as_hwc)
+            for i in range(len(dataset)):
+                img_key = dataset.get_key(i)
+                sr_se_path = os.path.join(save_sr_se_dir, f"{img_key}x{lr_scale}.png")
+                if os.path.isfile(sr_se_path):
+                    continue
+                data = dataset[i]
+                img_key, sr_8 = data[0], data[1:]
+                sr = np.zeros_like(sr_8[0], dtype=np.float64)
+                for img, a_key in zip(sr_8, aug_keys):
+                    aug = AUG_DICT[a_key]
+                    for a in reversed(aug):
+                        img = a(img)
+                    sr += img
+                sr /= len(sr_8)
+                sr = float_to_uint8(sr)
+                Image.fromarray(sr).save(sr_se_path)
+                print(f"merge sr bin save to {sr_se_path}")
+        return
+
+    if not cfg.self_ensemble:
+        dir_sr_bin = os.path.join(dataset_path, f"DIV2K_{dataset_type}_SR_bin", f"X{lr_scale}")
+        save_sr_dir = os.path.join(dataset_path, f"DIV2K_{dataset_type}_SR", f"X{lr_scale}")
+        if os.path.isdir(dir_sr_bin):
+            os.makedirs(save_sr_dir, exist_ok=True)
+            bin_patterns = [os.path.join(dir_sr_bin, f"*x{lr_scale}_0_0.bin")]
+            dataset = FolderImagePair(bin_patterns, reader=read_bin_as_hwc)
+            for i in range(len(dataset)):
+                img_key = dataset.get_key(i)
+                sr_path = os.path.join(save_sr_dir, f"{img_key}x{lr_scale}.png")
+                if os.path.isfile(sr_path):
+                    continue
+                img_key, sr = dataset[i]
+                sr = float_to_uint8(sr)
+                Image.fromarray(sr).save(sr_path)
+                print(f"merge sr bin save to {sr_path}")
+        return
+
+
+def get_hr_sr_dataset(cfg):
+    """
+    make hr sr dataset
+    """
+    dataset_path = cfg.data_path
+    dataset_type = "valid"
+    lr_scale = cfg.scale
+
+    dir_patterns = []
+
+    # get HR_PATH/*.png
+    dir_hr = os.path.join(dataset_path, f"DIV2K_{dataset_type}_HR")
+
+
+def get_hr_sr_dataset(cfg):
+    """
+    make hr-sr pair dataset
+    """
+    dataset_path = cfg.data_path
+    dataset_type = "valid"
+    lr_scale = cfg.scale
+
+    dir_patterns = []
+
+    # get HR_PATH/*.png
+    dir_hr = os.path.join(dataset_path, f"DIV2K_{dataset_type}_HR")
+    hr_pattern = os.path.join(dir_hr, "*.png")
+    dir_patterns.append(hr_pattern)
+
+    # get SR_PATH/X2/*x2.png, SR_PATH/X3/*x3.png, SR_PATH/X4/*x4.png
+    se = "_self_ensemble" if cfg.self_ensemble else ""
+
+    dir_sr = os.path.join(dataset_path, f"DIV2K_{dataset_type}_SR" + se, f"X{lr_scale}")
+    if not os.path.isdir(dir_sr):
+        raise RuntimeError(f'{dir_sr} is not a dir for saving sr')
+    sr_pattern = os.path.join(dir_sr, f"*x{lr_scale}.png")
+    dir_patterns.append(sr_pattern)
+
+    # make dataset
+    dataset = FolderImagePair(dir_patterns)
+    return dataset
+
+
+@moxing_wrapper(pre_process=modelarts_pre_process)
+def run_post_process():
+    """
+    run post process
+    """
+    print(config, flush=True)
+    cfg = config
+    lr_scale = cfg.scale
+
+    init_env(cfg)
+
+    print("begin to run bin_to_png...")
+    bin_to_png(cfg)
+    print("bin_to_png finish")
+
+    dataset = get_hr_sr_dataset(cfg)
+
+    metrics = {
+        "psnr": PSNR(rgb_range=cfg.rgb_range, shave=6 + lr_scale),
+    }
+
+    total_step = len(dataset)
+    setw = len(str(total_step))
+    for i in range(len(dataset)):
+        _, hr, sr = dataset[i]
+        sr = unpadding(sr, hr.shape)
+        sr = img_to_tensor(sr)
+        hr = img_to_tensor(hr)
+        _ = [m.update(sr, hr) for m in metrics.values()]
+        result = {k: m.eval(sync=False) for k, m in metrics.items()}
+        print(f"[{i + 1:>{setw}}/{total_step:>{setw}}] result = {result}", flush=True)
+    result = {k: m.eval(sync=False) for k, m in metrics.items()}
+    print(f"evaluation result = {result}", flush=True)
+
+    print("post_process success", flush=True)
+
+
+if __name__ == "__main__":
+    run_post_process()
diff --git a/research/cv/GhostSR/preprocess.py b/research/cv/GhostSR/preprocess.py
new file mode 100644
index 000000000..cc66d5db8
--- /dev/null
+++ b/research/cv/GhostSR/preprocess.py
@@ -0,0 +1,100 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+'''pre process for 310 inference'''
+import os
+
+import numpy as np
+from PIL import Image
+
+from model_utils.config import config
+from model_utils.moxing_adapter import moxing_wrapper
+from src.dataset import FolderImagePair, AUG_DICT
+from src.utils import modelarts_pre_process
+
+MAX_HR_SIZE = 2040
+
+
+def padding(img, target_shape):
+    h, w = target_shape[0], target_shape[1]
+    img_h, img_w, _ = img.shape
+    dh, dw = h - img_h, w - img_w
+    if dh < 0 or dw < 0:
+        raise RuntimeError(f"img.shape is bigger than target_shape, {img.shape} > {target_shape}")
+    if dh != 0 or dw != 0:
+        img = np.pad(img, ((0, dh), (0, dw), (0, 0)), "constant")
+    return img
+
+
+def get_lr_dataset(cfg):
+    """
+    get lr dataset
+    """
+    dataset_path = cfg.data_path
+    lr_scale = cfg.scale
+    lr_type = cfg.lr_type
+    dataset_type = "valid"
+    self_ensemble = "_self_ensemble" if cfg.self_ensemble else ""
+
+    # get LR_PATH/X2/*x2.png, LR_PATH/X3/*x3.png, LR_PATH/X4/*x4.png
+    lrs_pattern = []
+    dir_lr = os.path.join(dataset_path, f"DIV2K_{dataset_type}_LR_{lr_type}", f"X{lr_scale}")
+    lr_pattern = os.path.join(dir_lr, f"*x{lr_scale}.png")
+    lrs_pattern.append(lr_pattern)
+    save_dir = os.path.join(dataset_path, f"DIV2K_{dataset_type}_LR_{lr_type}_AUG{self_ensemble}",
+                            f"X{lr_scale}")
+    os.makedirs(save_dir, exist_ok=True)
+    save_format = os.path.join(save_dir, "{}" + f"x{lr_scale}" + "_{}.png")
+
+    # make dataset
+    dataset = FolderImagePair(lrs_pattern)
+
+    return dataset, save_format
+
+
+@moxing_wrapper(pre_process=modelarts_pre_process)
+def run_pre_process():
+    """
+    run pre process
+    """
+    print(config)
+    cfg = config
+
+    aug_dict = AUG_DICT
+    if not cfg.self_ensemble:
+        aug_dict = {"0": AUG_DICT["0"]}
+
+    dataset, save_format = get_lr_dataset(cfg)
+    for i in range(len(dataset)):
+        img_key = dataset.get_key(i)
+        org_img = None
+        for a_key, aug in aug_dict.items():
+            save_path = save_format.format(img_key, a_key)
+            if os.path.isfile(save_path):
+                continue
+            if org_img is None:
+                _, lr = dataset[i]
+                target_shape = [MAX_HR_SIZE // cfg.scale, MAX_HR_SIZE // cfg.scale]
+                org_img = padding(lr, target_shape)
+            img = org_img.copy()
+            for a in aug:
+                img = a(img)
+            Image.fromarray(img).save(save_path)
+            print(f"[{i + 1}/{len(dataset)}]\tsave {save_path}\tshape = {img.shape}", flush=True)
+
+    print("pre_process success", flush=True)
+
+
+if __name__ == "__main__":
+    run_pre_process()
diff --git a/research/cv/GhostSR/requirements.txt b/research/cv/GhostSR/requirements.txt
new file mode 100644
index 000000000..ec5f17ff9
--- /dev/null
+++ b/research/cv/GhostSR/requirements.txt
@@ -0,0 +1,7 @@
+# onnxruntime-gpu
+pillow~=9.3.0
+numpy~=1.19.5
+pyyaml~=5.1
+matplotlib~=3.5.1
+torch~=1.5.1
+mmcv~=1.7.0
\ No newline at end of file
diff --git a/research/cv/GhostSR/scripts/run_eval.sh b/research/cv/GhostSR/scripts/run_eval.sh
new file mode 100644
index 000000000..9d6e209cd
--- /dev/null
+++ b/research/cv/GhostSR/scripts/run_eval.sh
@@ -0,0 +1,55 @@
+#!/bin/bash
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+if [ $# -lt 1 ]
+then
+    echo "Usage: bash scripts/run_eval.sh [RANK_TABLE_FILE] --opt1 opt1_value --opt2 opt2_value ..."
+exit 1
+fi
+
+if [ ! -f $1 ]
+then
+    echo "error: RANK_TABLE_FILE=$1 is not a file"
+exit 1
+fi
+
+ulimit -u unlimited
+export DEVICE_NUM=8
+export RANK_SIZE=8
+PATH1=$(realpath $1)
+export RANK_TABLE_FILE=$PATH1
+echo "RANK_TABLE_FILE=${PATH1}"
+
+export PYTHONPATH=$PWD:$PYTHONPATH
+export SERVER_ID=0
+rank_start=$((DEVICE_NUM * SERVER_ID))
+for((i=0; i<${DEVICE_NUM}; i++))
+do
+    export DEVICE_ID=$i
+    export RANK_ID=$((rank_start + i))
+    rm -rf ./eval_parallel$i
+    mkdir ./eval_parallel$i
+    cp -r ./src ./eval_parallel$i
+    cp -r ./model_utils ./eval_parallel$i
+    cp -r ./*.yaml ./eval_parallel$i
+    cp ./eval.py ./eval_parallel$i
+    echo "start evaluation for rank $RANK_ID, device $DEVICE_ID"
+    cd ./eval_parallel$i || exit
+    env > env.log
+    export args=${*:2}
+    python eval.py $args > eval.log 2>&1 &
+    cd ..
+done
diff --git a/research/cv/GhostSR/scripts/run_eval_onnx.sh b/research/cv/GhostSR/scripts/run_eval_onnx.sh
new file mode 100644
index 000000000..da58aac3b
--- /dev/null
+++ b/research/cv/GhostSR/scripts/run_eval_onnx.sh
@@ -0,0 +1,30 @@
+#!/bin/bash
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+echo "=============================================================================================================="
+echo "Please run the script as: "
+echo "bash scripts/run_eval_onnx.sh [config_path] [scale] [data_path] [output_path] [pre_trained_model_path] [eval_type]"
+echo "For example: bash scripts/run_eval_onnx.sh ./DIV2K_config.yaml 2 DIV2K_path output_path pre_trained_model_path ONNX"
+echo "It is better to use the absolute path."
+echo "=============================================================================================================="
+
+if [ $# != 6 ]
+then
+    echo "Usage: bash scripts/run_eval_onnx.sh [config_path] [scale] [data_path] [output_path] [pre_trained_model_path] [eval_type]"
+exit 1
+fi
+
+export args=${*:1}
+python eval_onnx.py --config_path $1 --scale $2 --data_path $3 --output_path $4 --pre_trained $5 --eval_type $6 > eval_onnx.log 2>&1 &
diff --git a/research/cv/GhostSR/scripts/run_infer_310.sh b/research/cv/GhostSR/scripts/run_infer_310.sh
new file mode 100644
index 000000000..2a49eeeba
--- /dev/null
+++ b/research/cv/GhostSR/scripts/run_infer_310.sh
@@ -0,0 +1,135 @@
+#!/bin/bash
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+if [[ $# -lt 3 || $# -gt 5 ]]; then
+    echo "Usage: bash scripts/run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [SCALE] [LOG_FILE] [DEVICE_ID]
+    LOG_FILE and DEVICE_ID are optional. DEVICE_ID can also be set by the environment variable device_id, default: 0"
+exit 1
+fi
+
+get_real_path(){
+    if [ "${1:0:1}" == "/" ]; then
+        echo "$1"
+    else
+        echo "$(realpath -m $PWD/$1)"
+    fi
+}
+
+model=$(get_real_path $1)
+data_path=$(get_real_path $2)
+scale=$3
+
+if [[ $scale -ne "2" && $scale -ne "3" && $scale -ne "4" ]]; then
+    echo "[SCALE] should be in [2,3,4]"
+exit 1
+fi
+
+log_file="./run_infer.log"
+if [ $# -ge 4 ]; then
+    log_file=$4
+fi
+log_file=$(get_real_path $log_file)
+
+
+device_id=0
+if [ $# == 5 ]; then
+    device_id=$5
+fi
+
+self_ensemble="True"
+
+echo "***************** param *****************"
+echo "mindir name: "$model
+echo "dataset path: "$data_path
+echo "scale: "$scale
+echo "log file: "$log_file
+echo "device id: "$device_id
+echo "self_ensemble: "$self_ensemble
+echo "***************** param *****************"
+
+function compile_app()
+{
+    echo "begin to compile app..."
+    cd ./ascend310_infer || exit
+    bash build.sh >> $log_file 2>&1
+    cd -
+    echo "finish compiling app"
+}
+
+function preprocess()
+{
+    echo "begin to preprocess..."
+    export DEVICE_ID=$device_id
+    export RANK_SIZE=1
+    python preprocess.py --data_path=$data_path --config_path=DIV2K_config.yaml --device_target=CPU --scale=$scale --self_ensemble=$self_ensemble >> $log_file 2>&1
+    echo "finish preprocess"
+}
+
+function infer()
+{
+    echo "begin to infer..."
+    if [ $self_ensemble == "True" ]; then
+        read_data_path=$data_path"/DIV2K_valid_LR_bicubic_AUG_self_ensemble/X"$scale
+    else
+        read_data_path=$data_path"/DIV2K_valid_LR_bicubic_AUG/X"$scale
+    fi
+    save_data_path=$data_path"/DIV2K_valid_SR_bin/X"$scale
+    if [ -d $save_data_path ]; then
+        rm -rf $save_data_path
+    fi
+    mkdir -p $save_data_path
+    ./ascend310_infer/out/main --mindir_path=$model --dataset_path=$read_data_path --device_id=$device_id --save_dir=$save_data_path >> $log_file 2>&1
+    echo "finish infer"
+}
+
+function postprocess()
+{
+    echo "begin to postprocess..."
+    export DEVICE_ID=$device_id
+    export RANK_SIZE=1
+    python postprocess.py --data_path=$data_path --config_path=DIV2K_config.yaml --device_target=CPU --scale=$scale --self_ensemble=$self_ensemble >> $log_file 2>&1
+    echo "finish postprocess"
+}
+
+echo "" > $log_file
+echo "read the log command: "
+echo "    tail -f $log_file"
+
+compile_app
+if [ $? -ne 0 ]; then
+    echo "compile app code failed, check $log_file"
+    exit 1
+fi
+
+preprocess
+if [ $? -ne 0 ]; then
+    echo "preprocess code failed, check $log_file"
+    exit 1
+fi
+
+infer
+if [ $? -ne 0 ]; then
+    echo "execute inference failed, check $log_file"
+    exit 1
+fi
+
+postprocess
+if [ $? -ne 0 ]; then
+    echo "postprocess failed, check $log_file"
+    exit 1
+fi
+
+tail -n 3 $log_file | head -n 1
diff --git a/research/cv/GhostSR/scripts/run_train.sh b/research/cv/GhostSR/scripts/run_train.sh
new file mode 100644
index 000000000..5f03e2ee4
--- /dev/null
+++ b/research/cv/GhostSR/scripts/run_train.sh
@@ -0,0 +1,55 @@
+#!/bin/bash
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+if [ $# -lt 1 ]
+then
+    echo "Usage: bash scripts/run_train.sh [RANK_TABLE_FILE] --opt1 opt1_value --opt2 opt2_value ..."
+exit 1
+fi
+
+if [ ! -f $1 ]
+then
+    echo "error: RANK_TABLE_FILE=$1 is not a file"
+exit 1
+fi
+
+ulimit -u unlimited
+export DEVICE_NUM=8
+export RANK_SIZE=8
+PATH1=$(realpath $1)
+export RANK_TABLE_FILE=$PATH1
+echo "RANK_TABLE_FILE=${PATH1}"
+
+export PYTHONPATH=$PWD:$PYTHONPATH
+export SERVER_ID=0
+rank_start=$((DEVICE_NUM * SERVER_ID))
+for((i=0; i<${DEVICE_NUM}; i++))
+do
+    export DEVICE_ID=$i
+    export RANK_ID=$((rank_start + i))
+    rm -rf ./train_parallel$i
+    mkdir ./train_parallel$i
+    cp -r ./src ./train_parallel$i
+    cp -r ./model_utils ./train_parallel$i
+    cp -r ./*.yaml ./train_parallel$i
+    cp ./train.py ./train_parallel$i
+    echo "start training for rank $RANK_ID, device $DEVICE_ID"
+    cd ./train_parallel$i || exit
+    env > env.log
+    export args=${*:2}
+    python train.py $args > train.log 2>&1 &
+    cd ..
done
diff --git a/research/cv/GhostSR/src/__init__.py b/research/cv/GhostSR/src/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/research/cv/GhostSR/src/dataset.py b/research/cv/GhostSR/src/dataset.py
new file mode 100644
index 000000000..4c3143487
--- /dev/null
+++ b/research/cv/GhostSR/src/dataset.py
@@ -0,0 +1,333 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+Data operations, will be used in train.py and eval.py
+"""
+import glob
+import os
+import random
+import re
+from functools import reduce
+
+import mindspore.dataset as ds
+import numpy as np
+from PIL import Image
+
+
+def get_rank_info():
+    """
+    get rank size and rank id
+    """
+    from model_utils.moxing_adapter import get_rank_id, get_device_num
+    return get_device_num(), get_rank_id()
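+
+
+# The FolderImagePair class below pairs files across several glob patterns by
+# the characters the wildcards match; e.g. the pattern ".../X2/*x2.png"
+# applied to a file named "0801x2.png" yields the key "0801", and only keys
+# present under every pattern become pairs (paths here are illustrative):
+#
+#     pair = FolderImagePair(["DIV2K/LR/X2/*x2.png", "DIV2K/HR/*.png"])
+#     key, lr, hr = pair[0]  # key == "0801", lr/hr are HWC uint8 arrays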
+
+
+class FolderImagePair:
+    """
+    get image pair
+    dir_patterns(list): a list of image path patterns, such as ["/LR/*.jpg", "/HR/*.png", ...]
+        the file key is the characters matched by * and ?
+    reader(object/func): a method to read an image by path.
+    """
+
+    def __init__(self, dir_patterns, reader=None):
+        self.dir_patterns = dir_patterns
+        self.reader = reader
+        self.pair_keys, self.image_pairs = self.scan_pair(self.dir_patterns)
+
+    @staticmethod
+    def scan_pair(dir_patterns):
+        """
+        scan pair
+        """
+        images = []
+        for _dir in dir_patterns:
+            imgs = glob.glob(_dir)
+            _dir = os.path.basename(_dir)
+            pat = _dir.replace("*", "(.*)").replace("?", "(.?)")
+            pat = re.compile(pat, re.I | re.M)
+            keys = [re.findall(pat, os.path.basename(p))[0] for p in imgs]
+            images.append(dict(zip(keys, imgs)))
+        same_keys = reduce(lambda x, y: set(x) & set(y), images)
+        same_keys = sorted(same_keys)
+        image_pairs = [[d[k] for d in images] for k in same_keys]
+        same_keys = [x if isinstance(x, str) else "_".join(x) for x in same_keys]
+        return same_keys, image_pairs
+
+    def get_key(self, idx):
+        return self.pair_keys[idx]
+
+    def __getitem__(self, idx):
+        if self.reader is None:
+            images = [Image.open(p) for p in self.image_pairs[idx]]
+            images = [img.convert('RGB') for img in images]
+            images = [np.array(img) for img in images]
+        else:
+            images = [self.reader(p) for p in self.image_pairs[idx]]
+        pair_key = self.pair_keys[idx]
+        return (pair_key, *images)
+
+    def __len__(self):
+        return len(self.pair_keys)
+
+
+class LrHrImages(FolderImagePair):
+    """
+    make LrHrImages dataset
+    """
+
+    def __init__(self, lr_pattern, hr_pattern, reader=None):
+        self.hr_pattern = hr_pattern
+        self.lr_pattern = lr_pattern
+        self.dir_patterns = []
+        if isinstance(self.lr_pattern, str):
+            self.is_multi_lr = False
+            self.dir_patterns.append(self.lr_pattern)
+        elif len(lr_pattern) == 1:
+            self.is_multi_lr = False
+            self.dir_patterns.append(self.lr_pattern[0])
+        else:
+            self.is_multi_lr = True
+            self.dir_patterns.extend(self.lr_pattern)
+        self.dir_patterns.append(self.hr_pattern)
+        super(LrHrImages, self).__init__(self.dir_patterns, reader=reader)
+
+    def __getitem__(self, idx):
+        _, *images = super(LrHrImages, self).__getitem__(idx)
+        return tuple(images)
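+
+
+# The patch cutters below derive one set of coordinates on the most
+# down-scaled input and rescale them per image, so all patches cover the same
+# region; e.g. with patch_size=192 and lr_scale=[2, 1] (an x2 LR plus its HR):
+# tp = 192 // 2 = 96 on the LR image, and the HR patch is cut at
+# (tx * 2, ty * 2) with size 96 * 2 = 192.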
+
+
+class _BasePatchCutter:
+    """
+    cut patch from images
+    patch_size(int): patch size, input images should be bigger than patch_size.
+    lr_scale(int/list): lr scales for input images.
+        Choose from [1, 2, 3, 4] or a combination of them.
+    """
+
+    def __init__(self, patch_size, lr_scale):
+        self.patch_size = patch_size
+        self.multi_lr_scale = lr_scale
+        if isinstance(lr_scale, int):
+            self.multi_lr_scale = [lr_scale]
+        else:
+            self.multi_lr_scale = [*lr_scale]
+        self.max_lr_scale_idx = self.multi_lr_scale.index(max(self.multi_lr_scale))
+        self.max_lr_scale = self.multi_lr_scale[self.max_lr_scale_idx]
+
+    def get_tx_ty(self, target_height, target_width, target_patch_size):
+        raise NotImplementedError()
+
+    def __call__(self, *images):
+        target_img = images[self.max_lr_scale_idx]
+
+        tp = self.patch_size // self.max_lr_scale
+        th, tw, _ = target_img.shape
+
+        tx, ty = self.get_tx_ty(th, tw, tp)
+
+        patch_images = []
+        for _, (img, lr_scale) in enumerate(zip(images, self.multi_lr_scale)):
+            x = tx * self.max_lr_scale // lr_scale
+            y = ty * self.max_lr_scale // lr_scale
+            p = tp * self.max_lr_scale // lr_scale
+            patch_images.append(img[y:(y + p), x:(x + p), :])
+        return tuple(patch_images)
+
+
+class RandomPatchCutter(_BasePatchCutter):
+
+    def __init__(self, patch_size, lr_scale):
+        super(RandomPatchCutter, self).__init__(patch_size=patch_size, lr_scale=lr_scale)
+
+    def get_tx_ty(self, target_height, target_width, target_patch_size):
+        target_x = random.randrange(0, target_width - target_patch_size + 1)
+        target_y = random.randrange(0, target_height - target_patch_size + 1)
+        return target_x, target_y
+
+
+class CentrePatchCutter(_BasePatchCutter):
+
+    def __init__(self, patch_size, lr_scale):
+        super(CentrePatchCutter, self).__init__(patch_size=patch_size, lr_scale=lr_scale)
+
+    def get_tx_ty(self, target_height, target_width, target_patch_size):
+        target_x = (target_width - target_patch_size) // 2
+        target_y = (target_height - target_patch_size) // 2
+        return target_x, target_y
+
+
+def hflip(img):
+    return img[:, ::-1, :]
+
+
+def vflip(img):
+    return img[::-1, :, :]
+
+
+def trnsp(img):
+    return img.transpose(1, 0, 2)
+
+
+AUG_LIST = [
+    [],
+    [trnsp],
+    [vflip],
+    [vflip, trnsp],
+    [hflip],
+    [hflip, trnsp],
+    [hflip, vflip],
+    [hflip, vflip, trnsp],
+]
+
+AUG_DICT = {
+    "0": [],
+    "t": [trnsp],
+    "v": [vflip],
+    "vt": [vflip, trnsp],
+    "h": [hflip],
+    "ht": [hflip, trnsp],
+    "hv": [hflip, vflip],
+    "hvt": [hflip, vflip, trnsp],
+}
+
+
+def flip_and_rotate(*images):
+    aug = random.choice(AUG_LIST)
+    res = []
+    for img in images:
+        for a in aug:
+            img = a(img)
+        res.append(img)
+    return tuple(res)
+
+
+def hwc2chw(*images):
+    res = [i.transpose(2, 0, 1) for i in images]
+    return tuple(res)
+
+
+def uint8_to_float32(*images):
+    res = [(i.astype(np.float32) if i.dtype == np.uint8 else i) for i in images]
+    return tuple(res)
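+
+
+# Every op in AUG_DICT is an involution (its own inverse), so an augmentation
+# is undone by replaying its ops in reverse order; postprocess.py relies on
+# this when merging the x8 self-ensemble outputs. A minimal check:
+#
+#     img = np.arange(24).reshape(2, 4, 3)
+#     out = img
+#     for a in AUG_DICT["hvt"]:
+#         out = a(out)
+#     for a in reversed(AUG_DICT["hvt"]):
+#         out = a(out)
+#     assert (out == img).all()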
+
+
+def create_dataset_DIV2K(config, dataset_type="train", num_parallel_workers=10, shuffle=True):
+    """
+    create a train or eval DIV2K dataset
+    Args:
+        config(dict):
+            dataset_path(string): the path of dataset.
+            scale(int/list): lr scale, read data ordered by it, choices=(2,3,4,[2],[3],[4],[2,3],[2,4],[3,4],[2,3,4])
+            lr_type(string): lr images type, choices=("bicubic", "unknown"), Default "bicubic"
+            batch_size(int): the batch size of dataset. (train param), Default 1
+            patch_size(int): train data size. (train param), Default -1
+            epoch_size(int): times to repeat dataset for dataset_sink_mode, Default None
+            eval_type(string): set to "ONNX" to read the x8-augmented images exported for ONNX eval, Default ""
+        dataset_type(string): choices=("train", "valid", "test"), Default "train"
+        num_parallel_workers(int): num-workers to read data, Default 10
+        shuffle(bool): shuffle dataset. Default: True
+    Returns:
+        dataset
+    """
+    dataset_path = config["dataset_path"]
+    lr_scale = config["scale"]
+    if config["eval_type"] == "ONNX":
+        lr_type = config.get("lr_type", "bicubic_AUG_self_ensemble")
+    else:
+        lr_type = config.get("lr_type", "bicubic")
+    batch_size = config.get("batch_size", 1)
+    patch_size = config.get("patch_size", -1)
+    epoch_size = config.get("epoch_size", None)
+
+    # for multi lr scale, such as [2,3,4]
+    if isinstance(lr_scale, int):
+        multi_lr_scale = [lr_scale]
+    else:
+        multi_lr_scale = lr_scale
+
+    # get HR_PATH/*.png
+    dir_hr = os.path.join(dataset_path, f"DIV2K_{dataset_type}_HR")
+    hr_pattern = os.path.join(dir_hr, "*.png")
+
+    # get LR_PATH/X2/*x2.png, LR_PATH/X3/*x3.png, LR_PATH/X4/*x4.png
+    column_names = []
+    lrs_pattern = []
+    for lr_scale in multi_lr_scale:
+        dir_lr = os.path.join(dataset_path, f"DIV2K_{dataset_type}_LR_{lr_type}", f"X{lr_scale}")
+        if config["eval_type"] == "ONNX":
+            lr_pattern = os.path.join(dir_lr, f"*x{lr_scale}_0.png")
+        else:
+            lr_pattern = os.path.join(dir_lr, f"*x{lr_scale}.png")
+        lrs_pattern.append(lr_pattern)
+        column_names.append(f"lrx{lr_scale}")
+    column_names.append("hr")  # ["lrx2","lrx3","lrx4",..., "hr"]
+
+    # make dataset
+    dataset = LrHrImages(lr_pattern=lrs_pattern, hr_pattern=hr_pattern)
+
+    # make mindspore dataset
+    device_num, rank_id = get_rank_info()
+    if device_num == 1 or device_num is None:
+        generator_dataset = ds.GeneratorDataset(dataset, column_names=column_names,
+                                                num_parallel_workers=num_parallel_workers,
+                                                shuffle=shuffle and dataset_type == "train")
+    elif dataset_type == "train":
+        generator_dataset = ds.GeneratorDataset(dataset, column_names=column_names,
+                                                num_parallel_workers=num_parallel_workers,
+                                                shuffle=shuffle and dataset_type == "train",
+                                                num_shards=device_num, shard_id=rank_id)
+    else:
+        sampler = ds.DistributedSampler(num_shards=device_num, shard_id=rank_id, shuffle=False,
+                                        offset=0)
+        generator_dataset = ds.GeneratorDataset(dataset, column_names=column_names,
+                                                num_parallel_workers=num_parallel_workers,
+                                                sampler=sampler)
+
+    # define map operations
+    if dataset_type == "train":
+        transform_img = [
+            RandomPatchCutter(patch_size, multi_lr_scale + [1]),
+            flip_and_rotate,
+            hwc2chw,
+            uint8_to_float32,
+        ]
+    elif patch_size > 0:
+        transform_img = [
+            CentrePatchCutter(patch_size, multi_lr_scale + [1]),
+            hwc2chw,
+            uint8_to_float32,
+        ]
+    else:
+        transform_img = [
+            hwc2chw,
+            uint8_to_float32,
+        ]
+
+    # pre-process hr and lr
+    generator_dataset = generator_dataset.map(input_columns=column_names,
+                                              output_columns=column_names,
+                                              operations=transform_img)
+
+    # apply batch operations
+    generator_dataset = generator_dataset.batch(batch_size, drop_remainder=False)
+
+    # apply repeat operations
+    if dataset_type == "train" and epoch_size is not None and epoch_size != 1:
+        generator_dataset = generator_dataset.repeat(epoch_size)
+
+    return generator_dataset
diff --git a/research/cv/GhostSR/src/edsr.py b/research/cv/GhostSR/src/edsr.py
new file mode 100644
index 000000000..97695de53
--- /dev/null
+++ b/research/cv/GhostSR/src/edsr.py
@@ -0,0 +1,204 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""EDSR_mindspore"""
+import numpy as np
+from mindspore import Parameter
+from mindspore import nn, ops
+from mindspore.common.initializer import TruncatedNormal
+
+
+class RgbNormal(nn.Cell):
+    """
+    "MeanShift" in the EDSR paper's pytorch code:
+    https://github.com/sanghyun-son/EDSR-PyTorch/blob/master/src/model/common.py
+
+    it is not reasonable in the cases below:
+    if std != 1 and sign = -1: y = x * rgb_std - rgb_range * rgb_mean
+    if std != 1 and sign = 1: y = x * rgb_std + rgb_range * rgb_mean
+    they are not inverse operations of each other!
+
+    so use "RgbNormal" instead, it runs as below:
+    if inverse = False: y = (x / rgb_range - mean) / std
+    if inverse = True : x = (y * std + mean) * rgb_range
+    """
+
+    def __init__(self, rgb_range, rgb_mean, rgb_std, inverse=False):
+        super(RgbNormal, self).__init__()
+        self.rgb_range = rgb_range
+        self.rgb_mean = rgb_mean
+        self.rgb_std = rgb_std
+        self.inverse = inverse
+        std = np.array(self.rgb_std, dtype=np.float32)
+        mean = np.array(self.rgb_mean, dtype=np.float32)
+        if not inverse:
+            # y: (x / rgb_range - mean) / std <=> x * (1.0 / rgb_range / std) + (-mean) / std
+            weight = (1.0 / self.rgb_range / std).reshape((1, -1, 1, 1))
+            bias = (-mean / std).reshape((1, -1, 1, 1))
+        else:
+            # x: (y * std + mean) * rgb_range <=> y * (std * rgb_range) + mean * rgb_range
+            weight = (self.rgb_range * std).reshape((1, -1, 1, 1))
+            bias = (mean * rgb_range).reshape((1, -1, 1, 1))
+        self.weight = Parameter(weight, requires_grad=False)
+        self.bias = Parameter(bias, requires_grad=False)
+
+    def construct(self, x):
+        return x * self.weight + self.bias
+
+    def extend_repr(self):
+        s = 'rgb_range={}, rgb_mean={}, rgb_std={}, inverse = {}' \
+            .format(self.rgb_range, self.rgb_mean, self.rgb_std, self.inverse)
+        return s
+
+
+def make_conv2d(in_channels, out_channels, kernel_size, has_bias=True):
+    return nn.Conv2d(
+        in_channels, out_channels, kernel_size,
+        pad_mode="same", has_bias=has_bias, weight_init=TruncatedNormal(0.02))
+
+
+class ResBlock(nn.Cell):
+    """
+    Resnet Block
+    """
+
+    def __init__(
+            self, in_channels, out_channels, kernel_size=1, has_bias=True, res_scale=1):
+        super(ResBlock, self).__init__()
+        self.conv1 = make_conv2d(in_channels, in_channels, kernel_size, has_bias)
+        self.relu = nn.ReLU()
+        self.conv2 = make_conv2d(in_channels, out_channels, kernel_size, has_bias)
+        self.res_scale = res_scale
+
+    def construct(self, x):
+        res = self.conv1(x)
+        res = self.relu(res)
+        res = self.conv2(res)
+        res = res * self.res_scale
+        x = x + res
+        return x
+
+
+class PixelShuffle(nn.Cell):
+    """
+    PixelShuffle using ops.DepthToSpace
+    """
+
+    def __init__(self, upscale_factor):
+        super(PixelShuffle, self).__init__()
+        self.upscale_factor = upscale_factor
+        self.upper = ops.DepthToSpace(self.upscale_factor)
+
+    def construct(self, x):
+        return self.upper(x)
+
+    def extend_repr(self):
+        return 'upscale_factor={}'.format(self.upscale_factor)
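+
+
+# ops.DepthToSpace rearranges (N, C*r*r, H, W) into (N, C, H*r, W*r), which is
+# why the upsampler below widens the channels with a conv before each shuffle;
+# a rough shape check (shapes are illustrative):
+#
+#     ps = PixelShuffle(2)
+#     x = mindspore.Tensor(np.zeros((1, 3 * 2 * 2, 8, 8), np.float32))
+#     ps(x).shape  # -> (1, 3, 16, 16)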
+
+
+def UpsamplerBlockList(upscale_factor, n_feats, has_bias=True):
+    """
+    make Upsampler Block List
+    """
+    if upscale_factor == 1:
+        return []
+    allow_sub_upscale_factor = [2, 3, None]
+    for sub in allow_sub_upscale_factor:
+        if sub is None:
+            raise NotImplementedError(
+                f"Only support \"scales\" that are divisible by {allow_sub_upscale_factor[:-1]}")
+        if upscale_factor % sub == 0:
+            break
+    sub_block_list = [
+        make_conv2d(n_feats, sub * sub * n_feats, 3, has_bias),
+        PixelShuffle(sub),
+    ]
+    return sub_block_list + UpsamplerBlockList(upscale_factor // sub, n_feats, has_bias)
+
+
+class Upsampler(nn.Cell):
+
+    def __init__(self, scale, n_feats, has_bias=True):
+        super(Upsampler, self).__init__()
+        up = UpsamplerBlockList(scale, n_feats, has_bias)
+        self.up = nn.SequentialCell(*up)
+
+    def construct(self, x):
+        x = self.up(x)
+        return x
+
+
+class EDSR(nn.Cell):
+    """
+    EDSR_mindspore network
+    """
+
+    def __init__(self, scale, n_feats, kernel_size, n_resblocks,
+                 n_colors=3,
+                 res_scale=0.1,
+                 rgb_range=255,
+                 rgb_mean=(0.0, 0.0, 0.0),
+                 rgb_std=(1.0, 1.0, 1.0)):
+        super(EDSR, self).__init__()
+
+        self.norm = RgbNormal(rgb_range, rgb_mean, rgb_std, inverse=False)
+        self.de_norm = RgbNormal(rgb_range, rgb_mean, rgb_std, inverse=True)
+
+        m_head = [make_conv2d(n_colors, n_feats, kernel_size)]
+
+        m_body = [
+            ResBlock(n_feats, n_feats, kernel_size, res_scale=res_scale)
+            for _ in range(n_resblocks)
+        ]
+        m_body.append(make_conv2d(n_feats, n_feats, kernel_size))
+
+        m_tail = [
+            Upsampler(scale, n_feats),
+            make_conv2d(n_feats, n_colors, kernel_size)
+        ]
+
+        self.head = nn.SequentialCell(m_head)
+        self.body = nn.SequentialCell(m_body)
+        self.tail = nn.SequentialCell(m_tail)
+
+    def construct(self, x):
+        x = self.norm(x)
+        x = self.head(x)
+        x = x + self.body(x)
+        x = self.tail(x)
+        x = self.de_norm(x)
+        return x
+
+    def load_pre_trained_param_dict(self, new_param_dict, strict=True):
+        """
+        load pre_trained param dict from edsr_x2
+        """
+        own_param = self.parameters_dict()
+        for name, new_param in new_param_dict.items():
+            if len(name) >= 4 and name[:4] == "net.":
+                name = name[4:]
+            if name in own_param:
+                if isinstance(new_param, Parameter):
+                    param = own_param[name]
+                    if tuple(param.data.shape) == tuple(new_param.data.shape):
+                        param.set_data(type(param.data)(new_param.data))
+                    elif name.find('tail') == -1:
+                        raise RuntimeError('While copying the parameter named {}, '
+                                           'its shape in the model is {} but '
+                                           'its shape in the checkpoint is {}.'
+                                           .format(name, own_param[name].shape, new_param.shape))
+            elif strict:
+                if name.find('tail') == -1:
+                    raise KeyError('unexpected key "{}" in parameters_dict()'
+                                   .format(name))
diff --git a/research/cv/GhostSR/src/metric.py b/research/cv/GhostSR/src/metric.py
new file mode 100644
index 000000000..014e830da
--- /dev/null
+++ b/research/cv/GhostSR/src/metric.py
@@ -0,0 +1,338 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Metric for evaluation."""
+import math
+import os
+
+import numpy as np
+from PIL import Image
+from mindspore import dtype as mstype
+from mindspore import nn, Tensor, ops
+from mindspore.ops.operations.comm_ops import ReduceOp
+
+try:
+    from model_utils.device_adapter import get_rank_id, get_device_num
+except ImportError:
+    get_rank_id = None
+    get_device_num = None
+
+
+class SelfEnsembleWrapperNumpy:
+    """
+    SelfEnsembleWrapperNumpy using numpy
+    """
+
+    def __init__(self, net):
+        super(SelfEnsembleWrapperNumpy, self).__init__()
+        self.net = net
+
+    def hflip(self, x):
+        return x[:, :, :, ::-1]
+
+    def vflip(self, x):
+        return x[:, :, ::-1, :]
+
+    def trnsps(self, x):
+        return x.transpose(0, 1, 3, 2)
+
+    def aug_x8(self, x):
+        """
+        do x8 augments for input image
+        """
+        # hflip
+        hx = self.hflip(x)
+        # vflip
+        vx = self.vflip(x)
+        vhx = self.vflip(hx)
+        # trnsps
+        tx = self.trnsps(x)
+        thx = self.trnsps(hx)
+        tvx = self.trnsps(vx)
+        tvhx = self.trnsps(vhx)
+        return x, hx, vx, vhx, tx, thx, tvx, tvhx
+
+    def aug_x8_reverse(self, x, hx, vx, vhx, tx, thx, tvx, tvhx):
+        """
+        undo x8 augments for input images
+        """
+        # trnsps
+        tvhx = self.trnsps(tvhx)
+        tvx = self.trnsps(tvx)
+        thx = self.trnsps(thx)
+        tx = self.trnsps(tx)
+        # vflip
+        tvhx = self.vflip(tvhx)
+        tvx = self.vflip(tvx)
+        vhx = self.vflip(vhx)
+        vx = self.vflip(vx)
+        # hflip
+        tvhx = self.hflip(tvhx)
+        thx = self.hflip(thx)
+        vhx = self.hflip(vhx)
+        hx = self.hflip(hx)
+        return x, hx, vx, vhx, tx, thx, tvx, tvhx
+
+    def to_numpy(self, *inputs):
+        if len(inputs) == 1:
+            return inputs[0].asnumpy()
+        return [x.asnumpy() for x in inputs]
+
+    def to_tensor(self, *inputs):
+        if len(inputs) == 1:
+            return Tensor(inputs[0])
+        return [Tensor(x) for x in inputs]
+
+    def set_train(self, mode=True):
+        self.net.set_train(mode)
+        return self
+
+    def __call__(self, x):
+        x = self.to_numpy(x)
+        x0, x1, x2, x3, x4, x5, x6, x7 = self.aug_x8(x)
+        x0, x1, x2, x3, x4, x5, x6, x7 = self.to_tensor(x0, x1, x2, x3, x4, x5, x6, x7)
+        x0 = self.net(x0)
+        x1 = self.net(x1)
+        x2 = self.net(x2)
+        x3 = self.net(x3)
+        x4 = self.net(x4)
+        x5 = self.net(x5)
+        x6 = self.net(x6)
+        x7 = self.net(x7)
+        x0, x1, x2, x3, x4, x5, x6, x7 = self.to_numpy(x0, x1, x2, x3, x4, x5, x6, x7)
+        x0, x1, x2, x3, x4, x5, x6, x7 = self.aug_x8_reverse(x0, x1, x2, x3, x4, x5, x6, x7)
+        x0, x1, x2, x3, x4, x5, x6, x7 = self.to_tensor(x0, x1, x2, x3, x4, x5, x6, x7)
+        return (x0 + x1 + x2 + x3 + x4 + x5 + x6 + x7) / 8
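+
+
+# Hypothetical evaluation wiring: wrap a trained network so every forward pass
+# averages the eight augmented predictions (at the cost of 8x inference time):
+#
+#     eval_net = SelfEnsembleWrapperNumpy(net)
+#     eval_net.set_train(False)
+#     sr = eval_net(lr)  # lr: Tensor of shape (N, C, H, W)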
+
+
+class SelfEnsembleWrapper(nn.Cell):
+    """
+    because of the [::-1] operator error, use "SelfEnsembleWrapperNumpy" instead
+    """
+
+    def __init__(self, net):
+        super(SelfEnsembleWrapper, self).__init__()
+        self.net = net
+
+    def hflip(self, x):
+        raise NotImplementedError(
+            "https://gitee.com/mindspore/mindspore/issues/I41ONQ?from=project-issue")
+
+    def vflip(self, x):
+        raise NotImplementedError(
+            "https://gitee.com/mindspore/mindspore/issues/I41ONQ?from=project-issue")
+
+    def trnsps(self, x):
+        return x.transpose(0, 1, 3, 2)
+
+    def aug_x8(self, x):
+        """
+        do x8 augments for input image
+        """
+        # hflip
+        hx = self.hflip(x)
+        # vflip
+        vx = self.vflip(x)
+        vhx = self.vflip(hx)
+        # trnsps
+        tx = self.trnsps(x)
+        thx = self.trnsps(hx)
+        tvx = self.trnsps(vx)
+        tvhx = self.trnsps(vhx)
+        return x, hx, vx, vhx, tx, thx, tvx, tvhx
+
+    def aug_x8_reverse(self, x, hx, vx, vhx, tx, thx, tvx, tvhx):
+        """
+        undo x8 augments for input images
+        """
+        # trnsps
+        tvhx = self.trnsps(tvhx)
+        tvx = self.trnsps(tvx)
+        thx = self.trnsps(thx)
+        tx = self.trnsps(tx)
+        # vflip
+        tvhx = self.vflip(tvhx)
+        tvx = self.vflip(tvx)
+        vhx = self.vflip(vhx)
+        vx = self.vflip(vx)
+        # hflip
+        tvhx = self.hflip(tvhx)
+        thx = self.hflip(thx)
+        vhx = self.hflip(vhx)
+        hx = self.hflip(hx)
+        return x, hx, vx, vhx, tx, thx, tvx, tvhx
+
+    def construct(self, x):
+        """
+        do x8 aug, run network, undo x8 aug, calculate the mean of the 8 outputs
+        """
+        x0, x1, x2, x3, x4, x5, x6, x7 = self.aug_x8(x)
+        x0 = self.net(x0)
+        x1 = self.net(x1)
+        x2 = self.net(x2)
+        x3 = self.net(x3)
+        x4 = self.net(x4)
+        x5 = self.net(x5)
+        x6 = self.net(x6)
+        x7 = self.net(x7)
+        x0, x1, x2, x3, x4, x5, x6, x7 = self.aug_x8_reverse(x0, x1, x2, x3, x4, x5, x6, x7)
+        return (x0 + x1 + x2 + x3 + x4 + x5 + x6 + x7) / 8
+
+
+class Quantizer(nn.Cell):
+    """
+    clip to [0.0, 255.0], round to int
+    """
+
+    def __init__(self, _min=0.0, _max=255.0):
+        super(Quantizer, self).__init__()
+        self._min = _min
+        self._max = _max
+
+    def construct(self, x):
+        x = ops.clip_by_value(x, self._min, self._max)
+        x = x.astype("Int32")
+        return x
+
+
+class TensorSyncer(nn.Cell):
+    """
+    sync metric values from all mindspore-processes
+    """
+
+    def __init__(self, _type="sum"):
+        super(TensorSyncer, self).__init__()
+        self._type = _type.lower()
+        if self._type == "sum":
+            self.ops = ops.AllReduce(ReduceOp.SUM)
+        elif self._type == "gather":
+            self.ops = ops.AllGather()
+        else:
+            raise ValueError(f"TensorSyncer._type == {self._type} is not supported")
+
+    def construct(self, x):
+        return self.ops(x)
+
+
+class _DistMetric(nn.Metric):
+    """
+    gather data from all ranks when eval(sync=True)
+    _type(str): choice from ["avg", "sum"].
+    """
+
+    def __init__(self, _type):
+        super(_DistMetric, self).__init__()
+        self._type = _type.lower()
+        self.all_reduce_sum = None
+        if get_device_num is not None and get_device_num() > 1:
+            self.all_reduce_sum = TensorSyncer(_type="sum")
+        self.clear()
+
+    def _accumulate(self, value):
+        if isinstance(value, (list, tuple)):
+            self._acc_value += sum(value)
+            self._count += len(value)
+        else:
+            self._acc_value += value
+            self._count += 1
+
+    def clear(self):
+        self._acc_value = 0.0
+        self._count = 0
+
+    def eval(self, sync=True):
+        """
+        sync: True, return the metric value merged across all mindspore-processes
+        sync: False, return the metric value of this single mindspore-process
+        """
+        if self._count == 0:
+            raise RuntimeError('self._count == 0')
+        if self.all_reduce_sum is not None and sync:
+            data = Tensor([self._acc_value, self._count], mstype.float32)
+            data = self.all_reduce_sum(data)
+            acc_value, count = self._convert_data(data).tolist()
+        else:
+            acc_value, count = self._acc_value, self._count
+        if self._type == "avg":
+            return acc_value / count
+        if self._type == "sum":
+            return acc_value
+        raise RuntimeError(f"_DistMetric._type={self._type} is not supported")
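+
+
+# The PSNR metric below works on the range-normalized difference, so the score
+# reduces to -10 * log10(mse); e.g. an mse of 1e-3 on (sr - hr) / 255 gives
+# -10 * math.log10(1e-3) = 30.0 dB, and identical images are scored as 1e32
+# instead of infinity.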
+ """ + + def __init__(self, rgb_range, shave): + super(PSNR, self).__init__(_type="avg") + self.shave = shave + self.rgb_range = rgb_range + self.quantize = Quantizer(0.0, 255.0) + + def update(self, *inputs): + """ + update psnr + """ + if len(inputs) != 2: + raise ValueError('PSNR need 2 inputs (sr, hr), but got {}'.format(len(inputs))) + sr, hr = inputs + sr = self.quantize(sr) + diff = (sr - hr) / self.rgb_range + valid = diff + if self.shave is not None and self.shave != 0: + valid = valid[..., int(self.shave):int(-self.shave), int(self.shave):int(-self.shave)] + mse_list = (valid ** 2).mean(axis=(1, 2, 3)) + mse_list = self._convert_data(mse_list).tolist() + psnr_list = [float(1e32) if mse == 0 else (- 10.0 * math.log10(mse)) for mse in mse_list] + self._accumulate(psnr_list) + + +class SaveSrHr(_DistMetric): + """ + help to save sr and hr + """ + + def __init__(self, save_dir): + super(SaveSrHr, self).__init__(_type="sum") + self.save_dir = save_dir + self.quantize = Quantizer(0.0, 255.0) + self.rank_id = 0 if get_rank_id is None else get_rank_id() + self.device_num = 1 if get_device_num is None else get_device_num() + + def update(self, *inputs): + """ + update images to save + """ + if len(inputs) != 2: + raise ValueError('SaveSrHr need 2 inputs (sr, hr), but got {}'.format(len(inputs))) + sr, hr = inputs + sr = self.quantize(sr) + sr = self._convert_data(sr).astype(np.uint8) + hr = self._convert_data(hr).astype(np.uint8) + for s, h in zip(sr.transpose(0, 2, 3, 1), hr.transpose(0, 2, 3, 1)): + idx = self._count * self.device_num + self.rank_id + sr_path = os.path.join(self.save_dir, f"{idx:0>4}_sr.png") + Image.fromarray(s).save(sr_path) + hr_path = os.path.join(self.save_dir, f"{idx:0>4}_hr.png") + Image.fromarray(h).save(hr_path) + self._accumulate(1) diff --git a/research/cv/GhostSR/src/utils.py b/research/cv/GhostSR/src/utils.py new file mode 100644 index 000000000..10cfe80f3 --- /dev/null +++ b/research/cv/GhostSR/src/utils.py @@ -0,0 +1,212 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================
+"""
+#################utils for train.py and eval.py########################
+"""
+import os
+import time
+from matplotlib import pyplot as plt
+
+import mindspore as ms
+from mindspore import context
+from mindspore.communication.management import init
+from mindspore.context import ParallelMode
+from mindspore.train.serialization import load_checkpoint
+from model_utils.config import config
+from model_utils.device_adapter import get_device_id, get_rank_id, get_device_num
+
+from GhostSR.EDSR_mindspore.edsr import EDSR4GhostSRMs
+from .dataset import create_dataset_DIV2K
+from .edsr import EDSR
+
+
+def plt_tensor_img(tensor):
+    img = tensor.asnumpy().astype('uint8').squeeze().transpose(1, 2, 0)
+    plt.imshow(img)
+    plt.show()
+
+
+def init_env(cfg):
+    """
+    init env for mindspore
+    """
+    context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target)
+    device_num = get_device_num()
+    if cfg.device_target == "Ascend":
+        context.set_context(device_id=get_device_id())
+        if device_num > 1:
+            init()
+            context.reset_auto_parallel_context()
+            context.set_auto_parallel_context(device_num=device_num,
+                                              parallel_mode=ParallelMode.DATA_PARALLEL,
+                                              gradients_mean=True)
+    elif cfg.device_target == "GPU":
+        context.set_context(enable_graph_kernel=True)
+        if device_num > 1:
+            init()
+            context.reset_auto_parallel_context()
+            context.set_auto_parallel_context(device_num=device_num,
+                                              parallel_mode=ParallelMode.DATA_PARALLEL,
+                                              gradients_mean=True)
+    elif cfg.device_target == "CPU":
+        pass
+    else:
+        raise ValueError("Unsupported platform.")
+
+
+def init_dataset(cfg, dataset_type="train"):
+    """
+    init DIV2K dataset
+    """
+    ds_cfg = {
+        "dataset_path": cfg.data_path,
+        "scale": cfg.scale,
+        "lr_type": cfg.lr_type,
+        "batch_size": cfg.batch_size,
+        "patch_size": cfg.patch_size,
+        "eval_type": cfg.eval_type,
+    }
+    if cfg.dataset_name == "DIV2K":
+        dataset = create_dataset_DIV2K(config=ds_cfg,
+                                       dataset_type=dataset_type,
+                                       num_parallel_workers=10,
+                                       shuffle=dataset_type == "train")
+    else:
+        raise ValueError("Unsupported dataset.")
+    return dataset
+
+
+def init_net(cfg):
+    """
+    init edsr network
+    """
+    if cfg.network == 'EDSR_mindspore':
+        net = EDSR(scale=cfg.scale,
+                   n_feats=cfg.n_feats,
+                   kernel_size=cfg.kernel_size,
+                   n_resblocks=cfg.n_resblocks,
+                   n_colors=cfg.n_colors,
+                   res_scale=cfg.res_scale,
+                   rgb_range=cfg.rgb_range,
+                   rgb_mean=cfg.rgb_mean,
+                   rgb_std=cfg.rgb_std)
+    elif cfg.network == 'EDSR4GhostSRMs':
+        net = EDSR4GhostSRMs(scale=cfg.scale)
+    else:
+        raise ValueError(f"Unsupported network: {cfg.network}")
+
+    if cfg.pre_trained:
+        pre_trained_path = os.path.join(cfg.output_path, cfg.pre_trained)
+        if len(cfg.pre_trained) >= 5 and cfg.pre_trained[:5] == "s3://":
+            pre_trained_path = cfg.pre_trained
+            import moxing as mox
+            mox.file.shift("os", "mox")  # so the system can read files from s3://
+        elif os.path.isfile(cfg.pre_trained):
+            pre_trained_path = cfg.pre_trained
+        elif os.path.isfile(pre_trained_path):
+            pass
+        else:
+            raise ValueError(f"pre_trained error: {cfg.pre_trained}")
+        print(f"loading pre_trained = {pre_trained_path}", flush=True)
+        param_dict = load_checkpoint(pre_trained_path)
+        param_not_load = ms.load_param_into_net(net, param_dict, strict_load=True)
+        print(f'param_not_load: {param_not_load}')
+    return net
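+
+
+# Typical wiring for evaluation, with illustrative config values (the real
+# ones come from DIV2K_config.yaml via model_utils.config):
+#
+#     cfg.network, cfg.scale = "EDSR4GhostSRMs", 2
+#     cfg.pre_trained = "./ckpt/EDSR_x2_DIV2K.ckpt"  # hypothetical checkpoint
+#     net = init_net(cfg)
+#     net.set_train(False)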
+
+
+def modelarts_pre_process():
+    '''modelarts pre process function.'''
+
+    def unzip(zip_file, save_dir):
+        import zipfile
+        s_time = time.time()
+        zip_isexist = zipfile.is_zipfile(zip_file)
+        zip_name = os.path.basename(zip_file)
+        if zip_isexist:
+            fz = zipfile.ZipFile(zip_file, 'r')
+            data_num = len(fz.namelist())
+            data_print = int(data_num / 4) if data_num > 4 else 1
+            len_data_num = len(str(data_num))
+            for i, _file in enumerate(fz.namelist()):
+                if i % data_print == 0:
+                    print(
+                        "[{1:>{0}}/{2:>{0}}] {3:>2}% cost time: {4:0>2}:{5:0>2} unzipping {6}".format(
+                            len_data_num,
+                            i,
+                            data_num,
+                            int(i / data_num * 100),
+                            int((time.time() - s_time) / 60),
+                            int(int(time.time() - s_time) % 60),
+                            zip_name),
+                        flush=True)
+                fz.extract(_file, save_dir)
+            print("    finish, cost time: {:0>2}:{:0>2} unzipping {}".format(
+                int((time.time() - s_time) / 60),
+                int(int(time.time() - s_time) % 60),
+                zip_name),
+                flush=True)
+        else:
+            print("{} is not zip.".format(zip_name), flush=True)
+
+    if config.enable_modelarts and config.need_unzip_in_modelarts:
+        sync_lock = "/tmp/unzip_sync.lock"
+        # Each server contains 8 devices at most.
+        if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
+            for ufile in config.need_unzip_files:
+                zip_file = os.path.join(config.data_path, ufile)
+                save_dir = os.path.dirname(zip_file)
+                unzip(zip_file, save_dir)
+            print("===Finish extract data synchronization===", flush=True)
+            try:
+                os.mknod(sync_lock)
+            except IOError:
+                pass
+
+        while True:
+            if os.path.exists(sync_lock):
+                break
+            time.sleep(1)
+
+        print("Device: {}, Finish sync unzip data.".format(get_device_id()), flush=True)
+
+    config.ckpt_save_dir = os.path.join(config.output_path, config.ckpt_save_dir)
+
+
+def do_eval(eval_network, ds_val, metrics, cur_epoch=None):
+    """
+    do eval for psnr and save hr, sr
+    """
+    eval_network.set_train(False)
+    total_step = ds_val.get_dataset_size()
+    setw = len(str(total_step))
+    begin = time.time()
+    step_begin = time.time()
+    rank_id = get_rank_id()
+    for i, (lr, hr) in enumerate(ds_val):
+        sr = eval_network(lr)
+        _ = [m.update(sr, hr) for m in metrics.values()]
+        result = {k: m.eval(sync=False) for k, m in metrics.items()}
+        result["time"] = time.time() - step_begin
+        step_begin = time.time()
+        print(f"[{i + 1:>{setw}}/{total_step:>{setw}}] rank = {rank_id} result = {result}",
+              flush=True)
+    result = {k: m.eval(sync=False) for k, m in metrics.items()}
+    result["time"] = time.time() - begin
+    if cur_epoch is not None:
+        result["epoch"] = cur_epoch
+    if rank_id == 0:
+        print(f"evaluation result = {result}", flush=True)
+    eval_network.set_train(True)
+    return result
diff --git a/research/cv/GhostSR/train.py b/research/cv/GhostSR/train.py
new file mode 100644
index 000000000..e2f1790e5
--- /dev/null
+++ b/research/cv/GhostSR/train.py
@@ -0,0 +1,155 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+#################train EDSR_mindspore example on DIV2K########################
+"""
+
+import numpy as np
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor, \
+    Callback
+from mindspore.train.model import Model
+from mindspore.common import set_seed
+
+from src.metric import PSNR
+from src.utils import init_env, init_dataset, init_net, modelarts_pre_process, do_eval
+from model_utils.config import config
+from model_utils.moxing_adapter import moxing_wrapper
+from model_utils.device_adapter import get_rank_id, get_device_num
+
+set_seed(2021)
+
+
+def lr_steps_edsr(lr, milestones, gamma, epoch_size, steps_per_epoch, last_epoch=None):
+    lr_each_step = []
+    step_begin_epoch = [0] + milestones[:-1]
+    step_end_epoch = milestones[1:] + [epoch_size]
+    for begin, end in zip(step_begin_epoch, step_end_epoch):
+        lr_each_step += [lr] * (end - begin) * steps_per_epoch
+        lr *= gamma
+    if last_epoch is not None:
+        lr_each_step = lr_each_step[last_epoch * steps_per_epoch:]
+    return np.array(lr_each_step).astype(np.float32)
+
+
+def init_opt(cfg, net):
+    """
+    init opt to train edsr
+    """
+    lr = lr_steps_edsr(lr=cfg.learning_rate, milestones=cfg.milestones, gamma=cfg.gamma,
+                       epoch_size=cfg.epoch_size, steps_per_epoch=cfg.steps_per_epoch,
+                       last_epoch=None)
+    loss_scale = 1.0 if cfg.amp_level == "O0" else cfg.loss_scale
+    if cfg.opt_type == "Adam":
+        opt = nn.Adam(params=filter(lambda x: x.requires_grad, net.get_parameters()),
+                      learning_rate=Tensor(lr),
+                      weight_decay=cfg.weight_decay,
+                      loss_scale=loss_scale)
+    elif cfg.opt_type == "SGD":
+        opt = nn.SGD(params=filter(lambda x: x.requires_grad, net.get_parameters()),
+                     learning_rate=Tensor(lr),
+                     weight_decay=cfg.weight_decay,
+                     momentum=cfg.momentum,
+                     dampening=cfg.dampening if hasattr(cfg, "dampening") else 0.0,
+                     nesterov=cfg.nesterov if hasattr(cfg, "nesterov") else False,
+                     loss_scale=loss_scale)
+    else:
+        raise ValueError("Unsupported optimizer.")
+    return opt
+
+
+class EvalCallBack(Callback):
+    """
+    eval callback
+    """
+
+    def __init__(self, eval_network, ds_val, eval_epoch_frq, epoch_size, metrics,
+                 result_evaluation=None):
+        super(EvalCallBack, self).__init__()
+        self.eval_network = eval_network
+        self.ds_val = ds_val
+        self.eval_epoch_frq = eval_epoch_frq
+        self.epoch_size = epoch_size
+        self.result_evaluation = result_evaluation
+        self.metrics = metrics
+        self.best_result = None
+        self.eval_network.set_train(False)
+
+    def epoch_end(self, run_context):
+        """
+        do eval in epoch end
+        """
+        cb_param = run_context.original_args()
+        cur_epoch = cb_param.cur_epoch_num
+        if cur_epoch % self.eval_epoch_frq == 0 or cur_epoch == self.epoch_size:
+            result = do_eval(self.eval_network, self.ds_val, self.metrics, cur_epoch=cur_epoch)
+            if self.best_result is None or self.best_result["psnr"] < result["psnr"]:
+                self.best_result = result
+            if get_rank_id() == 0:
+                print(f"best evaluation result = {self.best_result}", flush=True)
+            if isinstance(self.result_evaluation, dict):
+                for k, v in result.items():
+                    r_list = self.result_evaluation.get(k)
+                    if r_list is None:
+                        r_list = []
+                        self.result_evaluation[k] = r_list
+                    r_list.append(v)
+
+
+@moxing_wrapper(pre_process=modelarts_pre_process)
+def run_train():
+    """
+    run train
+    """
+    print(config, flush=True)
+    cfg = config
+
+    init_env(cfg)
+
+    ds_train = init_dataset(cfg, "train")
+    ds_val = init_dataset(cfg, "valid")
+
+    net = init_net(cfg)
+    cfg.steps_per_epoch = ds_train.get_dataset_size()
+    opt = init_opt(cfg, net)
+
+    loss = nn.L1Loss(reduction='mean')
+
+    eval_net = net
+
+    model = Model(net, loss_fn=loss, optimizer=opt, amp_level=cfg.amp_level)
+
+    metrics = {
+        # note: shave=True behaves as a 1-pixel border shave here; eval.py and
+        # postprocess.py use shave=6 + scale instead
+        "psnr": PSNR(rgb_range=cfg.rgb_range, shave=True),
+    }
+    eval_cb = EvalCallBack(eval_net, ds_val, cfg.eval_epoch_frq, cfg.epoch_size, metrics=metrics)
+
+    config_ck = CheckpointConfig(save_checkpoint_steps=cfg.steps_per_epoch * cfg.save_epoch_frq,
+                                 keep_checkpoint_max=cfg.keep_checkpoint_max)
+    time_cb = TimeMonitor()
+    ckpoint_cb = ModelCheckpoint(prefix=f"EDSR_x{cfg.scale}_" + cfg.dataset_name,
+                                 directory=cfg.ckpt_save_dir,
+                                 config=config_ck)
+    loss_cb = LossMonitor()
+    cbs = [time_cb, ckpoint_cb, loss_cb, eval_cb]
+    if get_device_num() > 1 and get_rank_id() != 0:
+        cbs = [time_cb, loss_cb, eval_cb]
+
+    model.train(cfg.epoch_size, ds_train, dataset_sink_mode=cfg.dataset_sink_mode, callbacks=cbs)
+    print("train success", flush=True)
+
+
+if __name__ == '__main__':
+    run_train()
--
Gitee