From 915adbb1adb4cccc690e17f4b59c9175e45b1257 Mon Sep 17 00:00:00 2001 From: CCCP <648976749@qq.com> Date: Thu, 11 Aug 2022 01:09:01 +0800 Subject: [PATCH 1/4] prsr_pr --- .../prsr/PRSR_ID2111_for_TensorFlow/data.py | 76 ++++------- .../prsr/PRSR_ID2111_for_TensorFlow/solver.py | 126 +++++++++--------- .../prsr/PRSR_ID2111_for_TensorFlow/train.py | 89 +++++-------- .../prsr/PRSR_ID2111_for_TensorFlow/utils.py | 28 ---- 4 files changed, 116 insertions(+), 203 deletions(-) diff --git a/TensorFlow/contrib/cv/prsr/PRSR_ID2111_for_TensorFlow/data.py b/TensorFlow/contrib/cv/prsr/PRSR_ID2111_for_TensorFlow/data.py index 716c6436c..0d96d3784 100644 --- a/TensorFlow/contrib/cv/prsr/PRSR_ID2111_for_TensorFlow/data.py +++ b/TensorFlow/contrib/cv/prsr/PRSR_ID2111_for_TensorFlow/data.py @@ -1,31 +1,3 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -# Copyright 2021 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -34,29 +6,29 @@ import tensorflow as tf class DataSet(object): def __init__(self, images_list_path, num_epoch, batch_size): - # filling the record_list - input_file = open(images_list_path, 'r') - self.record_list = [] - for line in input_file: - line = line.strip() - self.record_list.append(line) - filename_queue = tf.train.string_input_producer(self.record_list, num_epochs=None) - image_reader = tf.WholeFileReader() - _, image_file = image_reader.read(filename_queue) - image = tf.image.decode_jpeg(image_file, 3) - #preprocess - hr_image = tf.image.resize_images(image, [32, 32]) - lr_image = tf.image.resize_images(image, [8, 8]) - hr_image = tf.cast(hr_image, tf.float32) - lr_image = tf.cast(lr_image, tf.float32) - # - min_after_dequeue = 1000 - capacity = min_after_dequeue + 400 * batch_size - - # --------------------------------2021.11.3 整网数据比对前,去除训练脚本内部使用到的随机处理--------------------- - # self.hr_images, self.lr_images = tf.train.shuffle_batch([hr_image, lr_image], batch_size=batch_size, capacity=capacity, - # min_after_dequeue=min_after_dequeue) - self.hr_images, self.lr_images = tf.train.batch([hr_image, lr_image], batch_size = batch_size, capacity = capacity) - # --------------------------------2021.11.3 整网数据比对前,去除训练脚本内部使用到的随机处理--------------------- + def parse_example(example): + content = tf.read_file(example) + image = tf.image.decode_jpeg(content, channels = 3) + + hr_image = tf.image.resize_images(image, [32, 32]) + lr_image = tf.image.resize_images(image, [8, 8]) + hr_image = tf.cast(hr_image, tf.float32) + lr_image = tf.cast(lr_image, tf.float32) + + return hr_image, lr_image + + dataset = tf.data.TextLineDataset(images_list_path) + + num_example = 1000 + with open(images_list_path, 'r') as f: + num_example = len(list(f)) + + dataset = dataset.map(parse_example).shuffle(buffer_size = num_example).batch(batch_size, drop_remainder=True).repeat(num_epoch) + iterator = dataset.make_one_shot_iterator() + try: + self.hr_images, self.lr_images = iterator.get_next() + except tf.errors.OutOfRangeError: + iterator = dataset.make_one_shot_iterator() + self.hr_images, self.lr_images = iterator.get_next() \ No newline at end of file diff --git a/TensorFlow/contrib/cv/prsr/PRSR_ID2111_for_TensorFlow/solver.py b/TensorFlow/contrib/cv/prsr/PRSR_ID2111_for_TensorFlow/solver.py index 68bd96a7c..bb30f440c 100644 --- a/TensorFlow/contrib/cv/prsr/PRSR_ID2111_for_TensorFlow/solver.py +++ b/TensorFlow/contrib/cv/prsr/PRSR_ID2111_for_TensorFlow/solver.py @@ -1,37 +1,24 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -# Copyright 2021 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - from __future__ import absolute_import from __future__ import division from __future__ import print_function from npu_bridge.npu_init import * import tensorflow as tf + +# 引用precision_tool/tf_config.py +import precision_tool.tf_config as npu_tf_config + +# #------------------ Dump数据采集 ------------------------ +import argparse +# import precision_tool.tf_config as npu_tf_config +# import moxing as mox +import precision_tool.config as CONFIG +# #------------------ Dump数据采集 ------------------------ + +# #------------------ NPU 关闭融合规则 ---------------------- +# import precision_tool.tf_config as npu_tf_config +# #------------------ NPU 关闭融合规则 ---------------------- + import numpy as np from ops import * from data import * @@ -39,16 +26,6 @@ from net import * from utils import * import os import time - -# #------------------11.4 Dump数据采集------------------------ -# # 引用precision_tool/tf_config.py -# import argparse -# import precision_tool.tf_config as npu_tf_config -# import moxing as mox -# import precision_tool.config as CONFIG -# #------------------11.4 Dump数据采集------------------------ - - flags = tf.app.flags conf = flags.FLAGS @@ -62,7 +39,7 @@ class Solver(object): os.makedirs(self.train_dir) if not os.path.exists(self.samples_dir): os.makedirs(self.samples_dir) - # datasets params + # datasets params self.num_epoch = conf.num_epoch self.batch_size = conf.batch_size # optimizer parameter @@ -81,38 +58,48 @@ class Solver(object): learning_rate = tf.train.exponential_decay(self.learning_rate, self.global_step, 500000, 0.5, staircase = True) optimizer = tf.train.RMSPropOptimizer(learning_rate, decay = 0.95, momentum = 0.9, epsilon = 1e-8) + + # ------------------- NPU LossScale ------------------------- + self.loss_scale_manager = ExponentialUpdateLossScaleManager(init_loss_scale = 2 ** 32, incr_every_n_steps = 1000, + decr_every_n_nan_or_inf = 2, incr_ratio = 2, + decr_ratio = 0.8) + optimizer = NPULossScaleOptimizer(optimizer, self.loss_scale_manager) + # ------------------- NPU LossScale ------------------------- + self.train_op = optimizer.minimize(self.net.loss, global_step = self.global_step) + + def train(self): init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()) summary_op = tf.summary.merge_all() saver = tf.train.Saver() - # Create a session for running operations in the Graph. - # config = tf.ConfigProto(allow_soft_placement=True) - # config.gpu_options.allow_growth = True - - # ----------------------------2021.9.11 NPU------------------------------------- - config = tf.ConfigProto() - # # ------------------11.4 Dump数据采集------------------------ - # config = npu_tf_config.session_dump_config(config, action = 'dump') - # # ------------------11.4 Dump数据采集------------------------ + # # ----------------- 溢出检测 -------------------------------- + # config = tf.ConfigProto() + # config = npu_tf_config.session_dump_config(config, action = 'overflow') + # # ------------------------------------------------------- + #----------------- NPU -------------------------------- + config = tf.ConfigProto() custom_op = config.graph_options.rewrite_options.custom_optimizers.add() custom_op.name = "NpuOptimizer" - # ----------------------2021.9.16 NPU------------- - custom_op.parameter_map["mix_compile_mode"].b = True - # ----------------------2021.9.16 NPU------------- + config.graph_options.rewrite_options.remapping = RewriterConfig.OFF # 必须显式关闭remap + #------------------------------------------------------- + + # ------------------ NPU 混合精度 ---------------------------- + custom_op.parameter_map["precision_mode"].s = tf.compat.as_bytes("allow_mix_precision") + custom_op.parameter_map["modify_mixlist"].s = tf.compat.as_bytes("/home/ma-user/modelarts/user-job-dir/code/ops_info.json") + # ---------------------------------------------------------- - # # ----------------------2021.10.31 NPU------------- - # custom_op.parameter_map["use_off_line"].b = True - # custom_op.parameter_map["precision_mode"].s = tf.compat.as_bytes("allow_mix_precision") - # # ----------------------2021.10.31 NPU------------- + # #---------------- NPU DUMP -------------------------- + # config = npu_tf_config.session_dump_config(config, action = 'dump') + # #----------------------------------------------------- - config.graph_options.rewrite_options.remapping = RewriterConfig.OFF # 必须显式关闭remap - config.graph_options.rewrite_options.memory_optimization = RewriterConfig.OFF - # ----------------------------2021.9.11 NPU------------------------------------- + # # ------------------ NPU 关闭融合规则 ---------------------- + # config = npu_tf_config.session_dump_config(config, action = 'fusion_off') + # # ------------------------------------------------------- sess = tf.Session(config = config) @@ -122,16 +109,22 @@ class Solver(object): summary_writer = tf.summary.FileWriter(self.train_dir, sess.graph) # Start input enqueue threads. coord = tf.train.Coordinator() - threads = tf.train.start_queue_runners(sess = sess, coord = coord) + threads = tf.train.start_queue_runners(sess=sess, coord=coord) iters = 0 try: - while not (coord.should_stop() | iters == 380): + # while not (coord.should_stop() | iters == 380): + while not (coord.should_stop()): # Run training steps or whatever t1 = time.time() _, loss = sess.run([self.train_op, self.net.loss], feed_dict = {self.net.train: True}) + # _, loss, scale_value = sess.run( + # [self.train_op, self.net.loss, self.loss_scale_manager.get_loss_scale()], + # feed_dict = {self.net.train: True}) t2 = time.time() - print('step %d, loss = %.2f %.1f examples/sec %.3f sec/batch' % ( + print('step %d, loss = %.2f (%.1f examples/sec; %.3f sec/batch)' % ( (iters, loss, self.batch_size / (t2 - t1), (t2 - t1)))) + # print('step %d, loss = %.2f (%.1f examples/sec; %.3f sec/batch), scale:%d' % ( + # (iters, loss, self.batch_size / (t2 - t1), (t2 - t1), scale_value))) iters += 1 if iters % 10 == 0: summary_str = sess.run(summary_op, feed_dict = {self.net.train: True}) @@ -139,16 +132,17 @@ class Solver(object): if iters % 1000 == 0: # self.sample(sess, mu=1.0, step=iters) self.sample(sess, mu = 1.1, step = iters) - # self.sample(sess, mu=100, step=iters) + # self.sample(sess, mu=100, step=iters) if iters % 10000 == 0: checkpoint_path = os.path.join(self.train_dir, 'model.ckpt') saver.save(sess, checkpoint_path, global_step = iters) - # if iters % 370 == 0: + if iters % 370 == 0: + continue # parser = argparse.ArgumentParser() # parser.add_argument("--train_url", type = str, default = "./output") - # config2, unparsed = parser.parse_known_args() - #mox.file.copy_parallel(CONFIG.ROOT_DIR, config2.train_url) - #mox.file.copy_parallel('/home/ma-user/modelarts/user-job-dir/code', config2.train_url) + # configs = parser.parse_args() + # mox.file.copy_parallel(CONFIG.ROOT_DIR, configs.train_url) + # # mox.file.copy_parallel('/home/homema-user/modelarts/user-job-dir/code', configs.train_url) except tf.errors.OutOfRangeError: checkpoint_path = os.path.join(self.train_dir, 'model.ckpt') saver.save(sess, checkpoint_path) diff --git a/TensorFlow/contrib/cv/prsr/PRSR_ID2111_for_TensorFlow/train.py b/TensorFlow/contrib/cv/prsr/PRSR_ID2111_for_TensorFlow/train.py index 9943defcd..ccaf56ea6 100644 --- a/TensorFlow/contrib/cv/prsr/PRSR_ID2111_for_TensorFlow/train.py +++ b/TensorFlow/contrib/cv/prsr/PRSR_ID2111_for_TensorFlow/train.py @@ -1,41 +1,14 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -# Copyright 2021 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - +# -*- coding: utf-8 -*- import tensorflow as tf import sys import os import subprocess -#import precision_tool.config as CONFIG +import precision_tool.config as CONFIG # subprocess.call(["export LD_PRELOAD=/usr/lib/aarch64-linux-gnu/libgomp.so.1:/$LD_PRELOAD"]) #-----------------------打印路径------------------- -print('当前路径为: {}'.format(os.path.abspath(__file__))) +# print('当前路径为: {}'.format(os.path.abspath(__file__))) # print('数据路径为: {}'.format(os.listdir('/home/ma-user/modelarts/inputs/data_url_0'))) #------------------------------------------------ @@ -43,41 +16,43 @@ sys.path.insert(0, './') flags = tf.app.flags -#-----------------------上载OBS文件至ModelArts------------ -import argparse -#import moxing as mox -# 解析输入参数data_url -parser = argparse.ArgumentParser() -parser.add_argument("--train_url", type=str, default="./output") -parser.add_argument("--data_url", type=str, default="/celeba-mytest/celebA_test") -config, unparsed = parser.parse_known_args() -# 在ModelArts容器创建数据存放目录 +# #-----------------------上载OBS文件至ModelArts------------ +# import argparse +# import moxing as mox +# # 解析输入参数data_url +# parser = argparse.ArgumentParser() +# parser.add_argument("--train_url", type=str, default="./output") +# parser.add_argument("--data_url", type=str, default="/celeba-mytest/celebA") +# # parser.add_argument("--data_url", type=str, default="/home/ma-user/modelarts/inputs/data_url_0") +# config, unparsed = parser.parse_known_args() +# # 在ModelArts容器创建数据存放目录 # os.makedirs("/cache/overflow_data") # data_dir = "/cache/dataset" # os.makedirs(data_dir) -# OBS数据拷贝到ModelArts容器内 +# # OBS数据拷贝到ModelArts容器内 # mox.file.copy_parallel(config.data_url, data_dir) # #------------------------------------------------------------ # print('数据路径2为: {}'.format(os.listdir('/cache/dataset'))) -# #---------------改写txt文件------------------- -# # options.dataset = '/root/data/celebA' -# # options.outfile = '/root/data/train.txt' -# # dataset = '/home/ma-user/modelarts/inputs/data_url_0' -# dataset = data_dir -# outfile = '/home/ma-user/modelarts/user-job-dir/code/train.txt' - -# f = open(outfile, 'w') -# dataset_basepath = dataset -# for p1 in os.listdir(dataset_basepath): -# image = os.path.abspath(dataset_basepath + '/' + p1) -# f.write(image + '\n') -# f.close() +#---------------改写txt文件------------------- +# options.dataset = '/root/data/celebA' +# options.outfile = '/root/data/train.txt' +# dataset = '/home/ma-user/modelarts/inputs/data_url_0' +data_dir = "/home/disk/celebA" +dataset = data_dir +outfile = './train.txt' + +f = open(outfile, 'w') +dataset_basepath = dataset +for p1 in os.listdir(dataset_basepath): + image = os.path.abspath(dataset_basepath + '/' + p1) + f.write(image + '\n') +f.close() #---------------------------------------------------------- #solver -flags.DEFINE_string("train_dir", "models", "trained model save path") +flags.DEFINE_string("train_dir", "./output", "trained model save path") flags.DEFINE_string("samples_dir", "samples", "sampled images save path") flags.DEFINE_string("imgs_list_path", "./train.txt", "images list file path") @@ -85,7 +60,7 @@ flags.DEFINE_boolean("use_gpu", True, "whether to use gpu for training") flags.DEFINE_integer("device_id", 0, "gpu device id") # flags.DEFINE_integer("num_epoch", 30, "train epoch num") -flags.DEFINE_integer("num_epoch", 1, "train epoch num") +flags.DEFINE_integer("num_epoch", 3000, "train epoch num") flags.DEFINE_integer("batch_size", 32, "batch_size") flags.DEFINE_float("learning_rate", 4e-4, "learning rate") @@ -97,7 +72,7 @@ def main(_): solver = Solver() solver.train() - # #-------------------------- 2021.10.18 NPU Modelarts文件传到OBS中------------------------- + # # #-------------------------- 2021.10.18 NPU Modelarts文件传到OBS中------------------------- # # 解析输入参数train_url # parser = argparse.ArgumentParser() # parser.add_argument("--train_url", type = str, default = "./output") @@ -107,7 +82,7 @@ def main(_): # os.makedirs(model_dir) # # 训练结束后,将ModelArts容器内的训练输出拷贝到OBS # mox.file.copy_parallel(model_dir, config.train_url) - # # -------------------------- 2021.10.18 NPU Modelarts文件传到OBS中------------------------- + # # # -------------------------- 2021.10.18 NPU Modelarts文件传到OBS中------------------------- if __name__ == '__main__': # EventLOG diff --git a/TensorFlow/contrib/cv/prsr/PRSR_ID2111_for_TensorFlow/utils.py b/TensorFlow/contrib/cv/prsr/PRSR_ID2111_for_TensorFlow/utils.py index c528eadc5..08a66d9f2 100644 --- a/TensorFlow/contrib/cv/prsr/PRSR_ID2111_for_TensorFlow/utils.py +++ b/TensorFlow/contrib/cv/prsr/PRSR_ID2111_for_TensorFlow/utils.py @@ -1,31 +1,3 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -# Copyright 2021 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - import numpy as np # from skimage.io import imsave # import scipy.misc -- Gitee From 8d7ac8a9067fbb41758d2b15046166352a0ed379 Mon Sep 17 00:00:00 2001 From: CCCP <648976749@qq.com> Date: Wed, 17 Aug 2022 23:49:32 +0800 Subject: [PATCH 2/4] prsr_pr --- .idea/.name | 2 +- .../contrib/cv/prsr/PRSR_ID2111_for_TensorFlow/data.py | 3 ++- .../contrib/cv/prsr/PRSR_ID2111_for_TensorFlow/solver.py | 5 +++-- .../contrib/cv/prsr/PRSR_ID2111_for_TensorFlow/utils.py | 2 +- 4 files changed, 7 insertions(+), 5 deletions(-) diff --git a/.idea/.name b/.idea/.name index 4c5b8f06e..eadf20fea 100644 --- a/.idea/.name +++ b/.idea/.name @@ -1 +1 @@ -train_pixel_link.py \ No newline at end of file +train.py \ No newline at end of file diff --git a/TensorFlow/contrib/cv/prsr/PRSR_ID2111_for_TensorFlow/data.py b/TensorFlow/contrib/cv/prsr/PRSR_ID2111_for_TensorFlow/data.py index 0d96d3784..ae3f5d072 100644 --- a/TensorFlow/contrib/cv/prsr/PRSR_ID2111_for_TensorFlow/data.py +++ b/TensorFlow/contrib/cv/prsr/PRSR_ID2111_for_TensorFlow/data.py @@ -23,7 +23,8 @@ class DataSet(object): with open(images_list_path, 'r') as f: num_example = len(list(f)) - dataset = dataset.map(parse_example).shuffle(buffer_size = num_example).batch(batch_size, drop_remainder=True).repeat(num_epoch) + dataset = dataset.map(parse_example).shuffle(buffer_size = num_example).\ + batch(batch_size, drop_remainder=True).repeat(num_epoch) iterator = dataset.make_one_shot_iterator() diff --git a/TensorFlow/contrib/cv/prsr/PRSR_ID2111_for_TensorFlow/solver.py b/TensorFlow/contrib/cv/prsr/PRSR_ID2111_for_TensorFlow/solver.py index bb30f440c..5356d262c 100644 --- a/TensorFlow/contrib/cv/prsr/PRSR_ID2111_for_TensorFlow/solver.py +++ b/TensorFlow/contrib/cv/prsr/PRSR_ID2111_for_TensorFlow/solver.py @@ -90,7 +90,8 @@ class Solver(object): # ------------------ NPU 混合精度 ---------------------------- custom_op.parameter_map["precision_mode"].s = tf.compat.as_bytes("allow_mix_precision") - custom_op.parameter_map["modify_mixlist"].s = tf.compat.as_bytes("/home/ma-user/modelarts/user-job-dir/code/ops_info.json") + custom_op.parameter_map["modify_mixlist"].s = tf.compat.as_bytes\ + ("/home/ma-user/modelarts/user-job-dir/code/ops_info.json") # ---------------------------------------------------------- # #---------------- NPU DUMP -------------------------- @@ -155,7 +156,7 @@ class Solver(object): coord.join(threads) sess.close() - def sample(self, sess, mu = 1.1, step = None): + def sample(self, sess, mu=1.1, step=None): c_logits = self.net.conditioning_logits p_logits = self.net.prior_logits lr_imgs = self.dataset.lr_images diff --git a/TensorFlow/contrib/cv/prsr/PRSR_ID2111_for_TensorFlow/utils.py b/TensorFlow/contrib/cv/prsr/PRSR_ID2111_for_TensorFlow/utils.py index 08a66d9f2..397d856c7 100644 --- a/TensorFlow/contrib/cv/prsr/PRSR_ID2111_for_TensorFlow/utils.py +++ b/TensorFlow/contrib/cv/prsr/PRSR_ID2111_for_TensorFlow/utils.py @@ -19,7 +19,7 @@ def save_samples(np_imgs, img_path): merge_img = np.zeros((num * H, num * W, 3), dtype=np.uint8) for i in range(num): for j in range(num): - merge_img[i*H:(i+1)*H, j*W:(j+1)*W, :] = np_imgs[i*num+j,:,:,:] + merge_img[i*H:(i+1)*H, j*W:(j+1)*W, :] = np_imgs[i*num+j, :, :, :] # imsave(img_path, merge_img) # misc.imsave(img_path, merge_img) -- Gitee From 6d8dbc5662ddf76ded9c1d4631d9623903e94006 Mon Sep 17 00:00:00 2001 From: CCCP <648976749@qq.com> Date: Fri, 2 Sep 2022 01:04:36 +0800 Subject: [PATCH 3/4] prsr_pr --- .../prsr/PRSR_ID2111_for_TensorFlow/ops_info.json | 15 +++++++++++++++ 1 file changed, 15 insertions(+) create mode 100644 TensorFlow/contrib/cv/prsr/PRSR_ID2111_for_TensorFlow/ops_info.json diff --git a/TensorFlow/contrib/cv/prsr/PRSR_ID2111_for_TensorFlow/ops_info.json b/TensorFlow/contrib/cv/prsr/PRSR_ID2111_for_TensorFlow/ops_info.json new file mode 100644 index 000000000..e67141812 --- /dev/null +++ b/TensorFlow/contrib/cv/prsr/PRSR_ID2111_for_TensorFlow/ops_info.json @@ -0,0 +1,15 @@ +{ + "black-list": { + "to-remove" : [ + ], + "to-add" : [ + "ReduceSumD" + ] + }, + "white-list" : { + "to-remove" :[ + ], + "to-add" : [ + ] + } +} \ No newline at end of file -- Gitee From b46738bad7e90b193fcd4ad2a83b8832d9bc28eb Mon Sep 17 00:00:00 2001 From: CCCP <648976749@qq.com> Date: Sat, 3 Sep 2022 23:13:05 +0800 Subject: [PATCH 4/4] add inference.py --- .../PRSR_ID2111_for_TensorFlow/inference.py | 59 +++++++++++++++++++ 1 file changed, 59 insertions(+) create mode 100644 TensorFlow/contrib/cv/prsr/PRSR_ID2111_for_TensorFlow/inference.py diff --git a/TensorFlow/contrib/cv/prsr/PRSR_ID2111_for_TensorFlow/inference.py b/TensorFlow/contrib/cv/prsr/PRSR_ID2111_for_TensorFlow/inference.py new file mode 100644 index 000000000..b399e3042 --- /dev/null +++ b/TensorFlow/contrib/cv/prsr/PRSR_ID2111_for_TensorFlow/inference.py @@ -0,0 +1,59 @@ +from net import * +from data import * +from ops import * +from utils import * +import tensorflow as tf +import numpy as np + +batch_size = 1 + +dataset = DataSet("./train.txt", 30, batch_size) +net = Net(dataset.hr_images, dataset.lr_images, 'prsr') + +init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()) +saver = tf.train.Saver() + +# Create a session for running operations in the Graph. +config = tf.ConfigProto(allow_soft_placement = True) +config.gpu_options.allow_growth = True +sess = tf.Session(config = config) +sess.run(init_op) +saver.restore(sess, './output/model.ckpt-280000') + +coord = tf.train.Coordinator() +threads = tf.train.start_queue_runners(sess=sess, coord=coord) + +c_logits = net.conditioning_logits +p_logits = net.prior_logits +lr_imgs = dataset.lr_images +hr_imgs = dataset.hr_images +np_hr_imgs, np_lr_imgs = sess.run([hr_imgs, lr_imgs]) +gen_hr_imgs = np.zeros((batch_size, 32, 32, 3), dtype = np.float32) +np_c_logits = sess.run(c_logits, feed_dict = {lr_imgs: np_lr_imgs, net.train: False}) + +mu = 1.0 +for i in range(32): + for j in range(32): + for c in range(3): + np_p_logits = sess.run(p_logits, feed_dict = {hr_imgs: gen_hr_imgs}) + new_pixel = logits_2_pixel_value( + np_c_logits[:, i, j, c * 256:(c + 1) * 256] + np_p_logits[:, i, j, c * 256:(c + 1) * 256], + mu = mu) + gen_hr_imgs[:, i, j, c] = new_pixel +save_samples(gen_hr_imgs, './generate_imgs' + '.jpg') +save_samples(np_hr_imgs, './hr_imgs' + '.jpg') + +import cv2 +import numpy as np +import math + +def psnr1(img1, img2): + mse = np.mean((img1 / 255. - img2 / 255.) ** 2) + if mse < 1.0e-10: + return 100 + PIXEL_MAX = 1 + return 20 * math.log10(PIXEL_MAX / math.sqrt(mse)) + +img1 = cv2.imread("./generate_imgs.jpg") +img2 = cv2.imread("./hr_imgs.jpg") +print("PSNR is ", psnr1(img1, img2)) \ No newline at end of file -- Gitee