diff --git a/TensorFlow/contrib/cv/SRFBN_for_TensorFlow/.keep b/TensorFlow/contrib/cv/SRFBN_for_TensorFlow/.keep
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/TensorFlow/contrib/cv/SRFBN_for_TensorFlow/Basic_Model.py b/TensorFlow/contrib/cv/SRFBN_for_TensorFlow/Basic_Model.py
new file mode 100644
index 0000000000000000000000000000000000000000..700d4e99d9e43066e5531187ab5664a026686a2f
--- /dev/null
+++ b/TensorFlow/contrib/cv/SRFBN_for_TensorFlow/Basic_Model.py
@@ -0,0 +1,170 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from npu_bridge.npu_init import *
+import tensorflow as tf
+from tensorflow.contrib import layers
+import sys
+
+class basic_network(object):
+ def __init__(self, cfg):
+ self.training=True
+ self.cfg = cfg
+ self.params_count = 0 #the amount of parameters
+ def init_params(self, *args, **kwargs):
+ def _variable_on_cpu(w_shape, b_shape, weight_decay=0.99, use_bias=True, name="conv"):
+ with tf.device('/cpu:0'):
+ w = tf.Variable(tf.truncated_normal(w_shape, 0.0, 0.001), trainable=True, name="%s_w" % name)
+ tf.add_to_collection(name="weights_l2_loss", value=self.calc_l1_loss(w, weight_decay))
+ b = tf.Variable(tf.zeros(b_shape), trainable=use_bias, name="%s_b" % name)
+ return w, b
+ kernel_size = kwargs["kernel_size"]
+ in_channels = kwargs["in_channels"]
+ out_channels = kwargs["out_channels"]
+ # weight_decay = kwargs["weight_decay"]
+ w_shape = [kernel_size, kernel_size, in_channels, out_channels]
+ b_shape = [out_channels]
+ name = kwargs["name"]
+ self.params_count += kernel_size*kernel_size*in_channels*out_channels
+ self.params_count += out_channels
+ return _variable_on_cpu(w_shape, b_shape, use_bias=kwargs["use_bias"], name=name)
+
+ def calc_loss(self, *args, **kwargs):
+ loss_type = kwargs["loss_type"]
+ x = kwargs["x"]
+ y = kwargs["y"]
+ if loss_type == "L1":
+ return tf.reduce_sum(tf.abs(x-y), name="L1_loss")
+ elif loss_type == "L2":
+ return tf.nn.l2_loss((x-y), name="L2_loss")
+
+ def activation(self, *args, **kwargs):
+ act_type = kwargs["act_type"]
+ act_type = act_type.lower()
+ if act_type == "relu":
+ return tf.nn.relu(args[0])
+ elif act_type == "lrelu":
+ slope = kwargs["slope"]
+ y = slope*args[0]
+ return tf.maximum(args[0], y)
+ elif act_type == "prelu":
+ return tf.nn.leaky_relu(args[0], alpha=0.2)
+ elif act_type == "tanh":
+ return tf.nn.tanh(args[0])
+ else:
+ return args[0]
+
+ def calc_l2_loss(self, weight, weight_decay):
+ _, _, _, outchannel = weight.get_shape().as_list()
+ return (weight_decay) * tf.reduce_sum(tf.square(weight)) / outchannel
+
+ def calc_l1_loss(self, weight, weight_decay):
+ _, _, _, outchannel = weight.get_shape().as_list()
+ return (weight_decay)*tf.reduce_sum(tf.abs(weight)) / outchannel
+
+ def batch_norm(self, *args, **kwargs):
+ return tf.layers.batch_normalization(args[0], training=kwargs["training"])
+
+ def instance_norm(self, *args, **kwargs):
+ return layers.instance_norm(args[0], kwargs["name"])
+
+ def hard_sigmoid(self, x):
+ return tf.nn.relu6((x+3)/6)
+
+ def hard_swish(self, x):
+ return x * self.hard_sigmoid(x)
+
+ def global_average_pooling(self, x, name="GAP"):
+ return tf.reduce_mean(x, axis=[1, 2], keep_dims=True, name="Global_Average_Pooling_%s" % name)#不降维
+
+
+ def ConvBlock(self,x, in_channels, out_channels, kernel_size, stride=1, name="ConvBlock",
+ BN=True, use_bias=True, padding="VALID", act_type="relu", mode="CNA"):
+
+
+ assert (mode in ['CNA', 'NAC']), '[ERROR] Wrong mode in [%s]!' % sys.modules[__name__]#断言
+ weight, bias = self.init_params(kernel_size=kernel_size, in_channels=in_channels,
+ out_channels=out_channels, use_bias=use_bias, name=name)
+ if mode == "CNA":
+ x = tf.nn.conv2d(x, filter=weight, strides=[1, stride, stride, 1], padding=padding)
+ x = tf.nn.bias_add(x, bias)
+ if BN:
+ if self.cfg.BN_type == "BN":
+ x = self.batch_norm(x, training=self.cfg.istrain)
+ elif self.cfg.BN_type == "IN":
+ x = self.instance_norm(x, name="%s_IN"%name)
+ else:
+ raise NotImplementedError('[ERROR] BN type [%s] is not implemented!' % self.cfg.BN_type)
+ x = self.activation(x, act_type=act_type)
+ return x
+ elif mode=="NAC":
+ if BN:
+ if self.cfg.BN_type == "BN":
+ x = self.batch_norm(x, training=self.cfg.istrain)
+ elif self.cfg.BN_type == "IN":
+ x = self.instance_norm(x, name="%s_IN" % name)
+ else:
+ raise NotImplementedError('[ERROR] BN type [%s] is not implemented!' % self.cfg.BN_type)
+ x = self.activation(x, act_type=act_type)
+ x = tf.nn.conv2d(x, filter=weight, strides=[1, stride, stride, 1], padding=padding)
+ x = tf.nn.bias_add(x, bias)
+ return x
+
+ def DeConvBlock(self, x, in_channels, out_channels, kernel_size, stride=1, name="DeConvBlock",
+ BN=True, use_bias=True, padding="VALID", act_type="relu", mode="CNA"):
+ assert (mode in ['CNA', 'NAC']), '[ERROR] Wrong mode in [%s]!' % sys.modules[__name__]
+ b, h, w, c = x.get_shape().as_list()
+ out_shape = [b, h * self.cfg.scale, w * self.cfg.scale, out_channels]
+ weight, bias = self.init_params(kernel_size=kernel_size, in_channels=out_channels,
+ out_channels=in_channels, use_bias=use_bias, name=name)
+ if mode == "CNA":
+ x = tf.nn.conv2d_transpose(x, filter=weight, output_shape=out_shape,
+ strides=[1, stride, stride, 1], padding=padding)
+ x = tf.nn.bias_add(x, bias)
+ if BN:
+ if self.cfg.BN_type == "BN":
+ x = self.batch_norm(x, training=True)
+ elif self.cfg.BN_type == "IN":
+ x = self.instance_norm(x, name="%s_IN" % name)
+ else:
+ raise NotImplementedError('[ERROR] BN type [%s] is not implemented!' % self.cfg.BN_type)
+ x = self.activation(x, act_type=act_type)
+ return x
+ elif mode == "NAC":
+ if BN:
+ if self.cfg.BN_type == "BN":
+ x = self.batch_norm(x, training=True)
+ elif self.cfg.BN_type == "IN":
+ x = self.instance_norm(x, name="%s_IN" % name)
+ else:
+ raise NotImplementedError('[ERROR] BN type [%s] is not implemented!' % self.cfg.BN_type)
+ x = self.activation(x, act_type=act_type)
+ x = tf.nn.conv2d_transpose(x, filter=weight, output_shape=out_shape,
+ strides=[1, stride, stride, 1], padding=padding)
+ x = tf.nn.bias_add(x, bias)
+ return x
diff --git a/TensorFlow/contrib/cv/SRFBN_for_TensorFlow/LICENSE b/TensorFlow/contrib/cv/SRFBN_for_TensorFlow/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..5ea8a5f7b6ae91ebb12b7f2fa71a5432bb89de63
--- /dev/null
+++ b/TensorFlow/contrib/cv/SRFBN_for_TensorFlow/LICENSE
@@ -0,0 +1,284 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+
+------------------
+Files: third_party/compute_library/...
+
+MIT License
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+
+------------------
+Files: ACKNOWLEDGEMENTS
+LICENSE
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+1. Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+2. Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+ ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+ WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
+ ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+ (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+ LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+ ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+------------------
+Files: third_party/hexagon
+
+Copyright (c) 2016-2019, The Linux Foundation. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted (subject to the limitations in the
+disclaimer below) provided that the following conditions are met:
+
+ * Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+
+ * Redistributions in binary form must reproduce the above
+ copyright notice, this list of conditions and the following
+ disclaimer in the documentation and/or other materials provided
+ with the distribution.
+
+ * Neither the name of The Linux Foundation nor the names of its
+ contributors may be used to endorse or promote products derived
+ from this software without specific prior written permission.
+
+NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE
+GRANTED BY THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT
+HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED
+WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
+MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
+IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
+ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE
+GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER
+IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
+OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN
+IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
\ No newline at end of file
diff --git a/TensorFlow/contrib/cv/SRFBN_for_TensorFlow/PreProcess.py b/TensorFlow/contrib/cv/SRFBN_for_TensorFlow/PreProcess.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ed011cdbf6684e72110a05a5a221ee64ca3f528
--- /dev/null
+++ b/TensorFlow/contrib/cv/SRFBN_for_TensorFlow/PreProcess.py
@@ -0,0 +1,98 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from npu_bridge.npu_init import *
+import cv2
+import numpy as np
+import random
+from skimage import util
+
+
+def add_noise(img):
+ mode_types = ['gaussian', 'localvar', 'poisson', 'speckle'] # 'salt', 'pepper', 's&p' is too fake
+ inx = int(np.random.choice(np.arange(len(mode_types)), 1))
+ inx = 0
+ img = cv2.cvtColor(img.astype(np.uint8), cv2.COLOR_BGR2RGB)
+ mean = random.random() * 0.001 # + 0.001#random.random() generates number between 0 and 1
+ var = random.random() * 0.002 # + 0.01
+ noise_img = util.random_noise(img.copy(), mode=mode_types[inx],
+ mean=mean,
+ var=var)
+ return noise_img
+
+
+def augment_data(img_patch, flip, rot): # img_patchs : n,h,w,c
+ if flip==1:
+ img_patch = img_patch[:, ::-1, :] # hflip
+ elif flip==2:
+ img_patch = img_patch[::-1, :, :] # vflip
+ if rot==1:
+ img_patch = cv2.rotate(img_patch, cv2.ROTATE_90_CLOCKWISE)
+ elif rot==2:
+ img_patch = cv2.rotate(img_patch, cv2.ROTATE_90_COUNTERCLOCKWISE)
+ return img_patch
+
+def preprocess(imgs, cfg):
+ LR_patchs, HR_patchs = [], []
+ for img in imgs:
+
+ HR = cv2.imread(img.strip(), cv2.IMREAD_COLOR)
+ HR = (HR - 127.5) / 128
+ h, w, c = HR.shape
+
+ x_stride = w // (cfg.imagesize * cfg.scale)
+ y_stride = h // (cfg.imagesize * cfg.scale)
+
+ for x in range(x_stride):
+ for y in range(y_stride):
+ HR_patch = HR[y * cfg.imagesize * cfg.scale:(y + 1) * cfg.imagesize * cfg.scale,
+ x * cfg.imagesize * cfg.scale:(x + 1) * cfg.imagesize * cfg.scale, :]
+ # add noise && add blur
+ t = np.random.randint(0, 2, 1)
+ if t == 0:
+ LR_patch = cv2.resize(HR_patch, dsize=None, fx=1 / cfg.scale, fy=1 / cfg.scale,
+ interpolation=cv2.INTER_CUBIC)
+ LR_patch = np.clip(LR_patch, -1.0, 1.0)
+ #LR_patch = add_noise(LR_patch)
+ else:
+ #LR_patch = add_noise(HR_patch) # [-1, 1]
+ LR_patch = cv2.resize(HR_patch, dsize=None, fx=1 / cfg.scale,
+ fy=1 / cfg.scale, interpolation=cv2.INTER_LINEAR)
+ # data augment
+ if cfg.istrain:
+ rot = np.random.randint(0, 3, 1)
+ flip = np.random.randint(0, 3, 1)
+ LR_patch = augment_data(LR_patch, flip, rot)
+ HR_patch = augment_data(HR_patch, flip, rot)
+ LR_patchs.append(LR_patch)
+ HR_patchs.append(HR_patch)
+
+ return HR_patchs, LR_patchs
+
+
+
diff --git a/TensorFlow/contrib/cv/SRFBN_for_TensorFlow/README.md b/TensorFlow/contrib/cv/SRFBN_for_TensorFlow/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..edf8fc968564ffc42699a6382b0938b36b51c4fd
--- /dev/null
+++ b/TensorFlow/contrib/cv/SRFBN_for_TensorFlow/README.md
@@ -0,0 +1,116 @@
+- [基本信息](#基本信息.md)
+- [概述](#概述.md)
+- [训练环境准备](#训练环境准备.md)
+- [快速上手](#快速上手.md)
+- [训练结果](#训练结果.md)
+- [高级参考](#高级参考.md)
+
基本信息
+
+**发布者(Publisher):Huawei**
+
+**应用领域(Application Domain):Computer Vision**
+
+**修改时间(Modified) :2022.11.6**
+
+**框架(Framework):TensorFlow 1.15.0**
+
+**描述(Description):基于TensorFlow框架对高清图片重建相应的超分辨率图片的训练代码**
+
+概述
+
+```
+SRFBN是采取反馈连接来提高重建超分辨率图片效果的网络模型
+```
+- 参考论文:
+
+ https://arxiv.org/pdf/1903.09814.pdf
+
+- 参考实现:
+
+ https://github.com/turboLIU/SRFBN-tensorflow/blob/master/train.py
+
+## 默认配置
+
+- 训练数据集预处理:
+
+ - 图像的输入尺寸为64*64
+- 测试数据集预处理:
+
+ - 图像的输入尺寸为64*64
+- 训练超参
+
+ - Batch size: 1
+ - Train epoch: 1000
+
+
+快速上手
+
+- 数据集准备
+1. 模型训练使用DIV2K数据集。
+
+## 模型训练
+
+- 单卡训练
+
+ 1. 配置训练参数。
+
+ 首先在脚本test/train_performance_1p.sh中,配置batch_size、epochs、data_path等参数,请用户根据实际路径配置data_path,或者在启动训练的命令行中以参数形式下发。
+
+ ```
+ batch_size=1
+ epochs=1000
+ data_path="../DIV2K/DIV2K_train_HR"
+ ```
+
+ 2. 启动训练。
+
+ 启动单卡训练 (脚本为SRFBN_for_TensorFlow/test/train_performance_1p.sh)
+
+ ```
+ bash train_performance_1p.sh --data_path=../DIV2K/DIV2K_train_HR
+ ```
+
+训练结果
+
+- 精度结果比对
+
+| 精度指标项 | GPU实测 | NPU实测 |
+| ---------- | ----------- | ----------- |
+| PSNR | 6.706763287 | 5.831956861 |
+
+- 性能结果比对
+
+| 性能指标项 | GPU实测 | NPU实测 |
+| ---------- | -------------- | -------------- |
+| FPS | 3.358950029841 | 4.976489075014 |
+
+
+高级参考
+
+## 脚本和示例代码
+
+```
+├── Basic_Model.py //基本模型代码
+├── README.md //代码说明文档
+├── config.py //模型配置代码
+├── PreProcess.py //数据预处理代码
+├── requirements.txt //训练python依赖列表
+├── SRFBN_model.py //SRFBN网络模型代码
+├── test.py //测试代码
+├── train.py //训练代码
+├── test
+│ ├──train_performance_1p.sh //单卡训练验证性能启动脚本
+│ ├──train_full_1p.sh //单卡全量训练启动脚本
+```
+
+## 脚本参数
+
+```
+--data_path 数据集路径,默认:path/data
+--batch_size 每个NPU的batch size,默认:1
+--epochs 训练epcoh数量,默认:1000
+```
+
+## 训练过程
+
+1. 通过“模型训练”中的训练指令启动单卡训练。
\ No newline at end of file
diff --git a/TensorFlow/contrib/cv/SRFBN_for_TensorFlow/SRFBN_model.py b/TensorFlow/contrib/cv/SRFBN_for_TensorFlow/SRFBN_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..4afaf9bb579f3c760394d28d114522832aa99c29
--- /dev/null
+++ b/TensorFlow/contrib/cv/SRFBN_for_TensorFlow/SRFBN_model.py
@@ -0,0 +1,209 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from npu_bridge.npu_init import *
+import tensorflow as tf
+import os
+from Basic_Model import basic_network
+
+
+class SRFBN(basic_network):
+ def __init__(self, sess, cfg):
+ super(SRFBN, self).__init__(cfg)
+ self.sess = sess
+ imageshape = [cfg.batchsize, cfg.imagesize, cfg.imagesize, cfg.c_dim]
+ labelshape = [cfg.batchsize, cfg.imagesize * cfg.scale, cfg.imagesize * cfg.scale, cfg.c_dim]
+ self.imageplaceholder = tf.placeholder(dtype=tf.float32, shape=imageshape, name="image")
+ self.labelplaceholder = tf.placeholder(dtype=tf.float32, shape=labelshape, name="label")
+ self.last_hidden = None
+ self.should_reset = True
+ self.outs = []
+ #FB block that forwardpros feedback information
+ def FeedBackBlock(self, x, num_features, num_groups, act_type, name="FBB"):
+ if self.cfg.scale == 1:
+ stride = 1
+ padding = "SAME"
+ kernel_size = 5
+ if self.cfg.scale==2:
+ stride = 2
+ padding = "SAME"
+ kernel_size = 6
+ if self.cfg.scale == 3:
+ stride = 3
+ padding = "SAME"
+ kernel_size = 7
+ if self.cfg.scale == 4:
+ stride = 4
+ padding = "SAME"
+ kernel_size = 8
+ if self.should_reset:
+ self.last_hidden = x
+ self.should_reset = False
+ x = tf.concat([x, self.last_hidden], 3)
+ x = self.ConvBlock(x, 2*num_features, num_features, kernel_size=1, name="FeedBack_compress_in",
+ act_type=act_type)
+
+ lr_features = []
+ hr_features = []
+ lr_features.append(x)
+ for i in range(num_groups):
+ x = tf.concat(lr_features, 3)
+ if i > 0:
+ x = self.ConvBlock(x, num_features*(i+1), num_features, kernel_size=1,stride=1,
+ padding=padding, act_type=act_type, name="%s_%d"%(name, i))
+ x = self.DeConvBlock(x, num_features, num_features, kernel_size=kernel_size, stride=stride,
+ padding=padding, act_type=act_type, name="%s_%d"%(name, i))
+ hr_features.append(x)
+ x = tf.concat(hr_features, 3)
+ if i > 0:
+ x = self.ConvBlock(x, num_features*(i+1), num_features, kernel_size=1, stride=1,
+ padding=padding, act_type=act_type, name="%s_%d"%(name, i))
+ x = self.ConvBlock(x, num_features, num_features, kernel_size=kernel_size, stride=stride,
+ padding=padding, act_type=act_type, name="%s_%d"%(name, i))
+ lr_features.append(x)
+ del hr_features
+
+ x = tf.concat(lr_features[1:], 3)
+
+ x = self.ConvBlock(x, num_features*num_groups, num_features, kernel_size=1,
+ act_type=act_type, name="FeedBack_compress_out")
+
+ self.last_hidden = x
+
+ return x
+
+ def build(self):
+ if self.cfg.scale == 2:
+ stride = 2
+ padding = "SAME"
+ kernel_size = 6
+ if self.cfg.scale == 3:
+ stride = 3
+ padding = "SAME"
+ kernel_size = 7
+ if self.cfg.scale == 4:
+ stride = 4
+ padding = "SAME"
+ kernel_size = 8
+ # x = self.sub_mean(self.imageplaceholder) # normalize
+
+ _, height, width, _ = self.imageplaceholder.get_shape().as_list()
+
+ inter_size = tf.constant([height*self.cfg.scale, width*self.cfg.scale])
+ inter_res = tf.image.resize_images(self.imageplaceholder, inter_size)
+ # inter_res = self.imageplaceholder
+
+ x = self.ConvBlock(self.imageplaceholder, self.cfg.in_channels, 4 * self.cfg.num_features, kernel_size=3,
+ act_type=self.cfg.act_type, padding="SAME", name="conv_in")
+ x = self.ConvBlock(x, 4*self.cfg.num_features, self.cfg.num_features, kernel_size=1,
+ act_type=self.cfg.act_type, padding="SAME", name="feat_in")
+ # outs = []
+ for i in range(self.cfg.num_steps):
+ if i == 0:
+ self.should_reset=True
+ t = self.FeedBackBlock(x, self.cfg.num_features, self.cfg.num_groups, self.cfg.act_type, name="FBB_%d"%i)
+ t = self.DeConvBlock(t, self.cfg.num_features, self.cfg.num_features, kernel_size=kernel_size,
+ stride=stride, padding=padding, act_type="relu", name="out_%d"%i)
+ t = self.ConvBlock(t, self.cfg.num_features, self.cfg.out_channels, kernel_size=3, stride=1,
+ act_type="tanh", padding="SAME", name="conv_out")
+ t = inter_res + t
+ t = tf.clip_by_value(t, -1.0, 1.0)
+ # t = t + inter_res
+ # t = self.add_mean(t)
+ self.outs.append(t)
+
+ def train_step(self):
+ self.build()
+ print("This Net has Params num is %f MB" % (self.params_count * 4 / 1024 / 1024)) # float32
+ tf.summary.image("image/HR", self.labelplaceholder, max_outputs=1)
+ out = tf.add_n(self.outs)/self.cfg.num_steps
+
+ tf.summary.image("image/SR", out, max_outputs=1)
+ tf.summary.image("image/LR", self.imageplaceholder, max_outputs=1)
+
+ self.l2_regularization_loss = tf.reduce_sum(tf.get_collection("weights_l2_loss"))
+
+ self.losses = [self.calc_loss(x=x, y=self.labelplaceholder, loss_type=self.cfg.loss_type) for x in self.outs]
+ self.losses = tf.reduce_sum(self.losses)/len(self.losses)/self.cfg.batchsize + self.l2_regularization_loss
+
+ tf.summary.scalar('loss/total', self.losses)
+ tf.summary.scalar('loss/l2_loss', self.l2_regularization_loss)
+
+ self.merged_summary = tf.summary.merge_all()
+ self.saver = tf.train.Saver(max_to_keep=1)
+ #loading ckpt
+ def load(self):
+ model_name = "SRFBN.model"
+ model_dir = "%s_%s_%s_%s_c%d_x%s" % (
+ "SRFBN", self.cfg.num_features, self.cfg.num_steps, self.cfg.num_groups, self.cfg.c_dim, self.cfg.scale)
+ checkpoint_dir = os.path.join(self.cfg.checkpoint_dir, model_dir)
+ ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
+ if ckpt and ckpt.model_checkpoint_path:
+ ckpt_path = str(ckpt.model_checkpoint_path)
+ self.saver.restore(self.sess, os.path.join(os.getcwd(), ckpt_path))
+ step = int(os.path.basename(ckpt_path).split('-')[1])
+ print("\nCheckpoint Loading Success! %s\n" % ckpt_path)
+ else:
+ step = 0
+ print("\nCheckpoint Loading Failed! \n")
+
+ return step
+ #save model
+ def save(self, step):
+ model_name = "SRFBN.model"
+ model_dir = "%s_%s_%s_%s_c%d_x%s" % \
+ ("SRFBN", self.cfg.num_features, self.cfg.num_steps,
+ self.cfg.num_groups, self.cfg.c_dim, self.cfg.scale)
+ checkpoint_dir = os.path.join(self.cfg.checkpoint_dir, model_dir)
+
+ if not os.path.exists(checkpoint_dir):
+ os.makedirs(checkpoint_dir)
+
+ self.saver.save(self.sess,
+ os.path.join(checkpoint_dir, model_name),
+ global_step=step)
+ #test
+ def test(self, width, height):
+ self.cfg.batchsize = 1
+ testshape = [self.cfg.batchsize, height, width, self.cfg.c_dim]
+ labelshape = [self.cfg.batchsize, height*self.cfg.scale, width*self.cfg.scale, self.cfg.c_dim]
+ self.imageplaceholder = tf.placeholder(dtype=tf.float32, shape=testshape)
+ self.labelplaceholder = tf.placeholder(dtype=tf.float32, shape=labelshape)
+ self.build()
+ # self.outs = [self.add_mean(x) for x in self.outs]
+ out = tf.add_n(self.outs)/self.cfg.num_steps
+ # out = tf.concat(self.outs, -1)
+ return out
+
+
+if __name__ == '__main__':
+ from config import SRFBN_config as config
+ cfg = config()
+ sess = tf.Session(config=npu_config_proto())
+ net = SRFBN(sess, cfg)
+ train_step = net.train_step()
diff --git a/TensorFlow/contrib/cv/SRFBN_for_TensorFlow/config.py b/TensorFlow/contrib/cv/SRFBN_for_TensorFlow/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..391bbadccb78021c08332117159329c5551b6325
--- /dev/null
+++ b/TensorFlow/contrib/cv/SRFBN_for_TensorFlow/config.py
@@ -0,0 +1,81 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from npu_bridge.npu_init import *
+import os
+
+class config:
+ def __init__(self):
+ self.batchsize = 1
+ self.Process_num = 3
+ self.maxsize = 200
+ self.ngpu = 1
+ self.imagesize = 64
+ self.scale = 3
+ self.epoch = 1000
+ #create ckpt,log,result dir
+ self.checkpoint_dir = "./model"
+ if not os.path.exists(self.checkpoint_dir):
+ os.mkdir(self.checkpoint_dir)
+ self.log_dir = "./log"
+ if not os.path.exists(self.log_dir):
+ os.mkdir(self.log_dir)
+ self.result = "./result"
+ if not os.path.exists(self.result):
+ os.mkdir(self.result)
+
+
+
+class SRFBN_config(config):
+ def __init__(self):
+ super(SRFBN_config, self).__init__()
+ self.istrain = True#is train or is test
+ self.istest = not self.istrain
+ self.c_dim = 3 #color channel can train one-channel pic or RGB pic
+ self.in_channels = 3
+ self.out_channels = 3
+ self.num_features = 32#base number of filter
+ self.num_steps = 4# timestep
+ self.num_groups = 6#the number of projection group of FBB feedbackblock
+ self.BN = True
+ if self.BN:
+ self.BN_type = "BN" # "BN" # or "IN"
+ self.act_type = "prelu" #activation function
+ self.loss_type = "L2"
+ self.lr_steps = [150, 300, 550, 750]#iteration
+ self.lr_gama = 1
+ self.learning_rate = 2e-7#learning rate
+ self.load_premodel = True
+ #create dir
+ self.srfbn_logdir = "%s/srfbn" % self.log_dir
+ if not os.path.exists(self.srfbn_logdir):
+ os.mkdir(self.srfbn_logdir)
+ self.srfbn_result = "%s/srfbn" % self.result
+ if not os.path.exists(self.srfbn_result):
+ os.mkdir(self.srfbn_result)
+
diff --git a/TensorFlow/contrib/cv/SRFBN_for_TensorFlow/modelzoo_level.txt b/TensorFlow/contrib/cv/SRFBN_for_TensorFlow/modelzoo_level.txt
new file mode 100644
index 0000000000000000000000000000000000000000..9981888d4330a88ecfc05366e577609f83017194
--- /dev/null
+++ b/TensorFlow/contrib/cv/SRFBN_for_TensorFlow/modelzoo_level.txt
@@ -0,0 +1,6 @@
+GPUStatus:OK
+NPUMigrationStatus:OK
+FuncStatus:OK
+PrecisionStatus:POK
+AutoTune:OK
+PerfStatus:PERFECT
\ No newline at end of file
diff --git a/TensorFlow/contrib/cv/SRFBN_for_TensorFlow/requirements.txt b/TensorFlow/contrib/cv/SRFBN_for_TensorFlow/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e04023578103b7c569e1fd39a65bc81d44632a6b
--- /dev/null
+++ b/TensorFlow/contrib/cv/SRFBN_for_TensorFlow/requirements.txt
@@ -0,0 +1,4 @@
+tensorflow==1.15.0
+cv2
+numpy
+scikit-image
\ No newline at end of file
diff --git a/TensorFlow/contrib/cv/SRFBN_for_TensorFlow/test.py b/TensorFlow/contrib/cv/SRFBN_for_TensorFlow/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..be8f5570f924ef2ef612b529b78e32cfafeb0977
--- /dev/null
+++ b/TensorFlow/contrib/cv/SRFBN_for_TensorFlow/test.py
@@ -0,0 +1,87 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from npu_bridge.npu_init import *
+import tensorflow as tf
+
+from SRFBN_model import SRFBN
+from PreProcess import *
+from skimage.metrics import peak_signal_noise_ratio as compare_psnr
+from skimage.metrics import _structural_similarity
+
+
+
+
+def test_SRFBN(image_lr,image_hr):
+
+ #image
+ height, width, _ = image_lr.shape
+ global load_flag
+ global srfbn
+ global out
+ if load_flag == 0:
+ srfbn = SRFBN(sess, cfg)
+ out = srfbn.test(width, height)
+ tf.global_variables_initializer().run(session=sess)
+ srfbn.saver = tf.train.Saver(max_to_keep=1)
+ srfbn.load()
+ srfbn.l2_regularization_loss = tf.reduce_sum(tf.get_collection("weights_l2_loss"))
+ srfbn.losses = [srfbn.calc_loss(x=x, y=srfbn.labelplaceholder, loss_type=srfbn.cfg.loss_type) for x in
+ srfbn.outs]
+ srfbn.losses = tf.reduce_sum(srfbn.losses) / len(srfbn.losses) / srfbn.cfg.batchsize + srfbn.l2_regularization_loss
+ load_flag += 1
+ #cv2.namedWindow("result", 0)
+
+ img_hr = image_hr.reshape([1,height*srfbn.cfg.scale,width*srfbn.cfg.scale,3])
+ img_lr = image_lr.reshape([1, height, width, 3])
+ output,err,l2_loss = sess.run([out,srfbn.losses,srfbn.l2_regularization_loss], feed_dict={srfbn.imageplaceholder: img_lr,srfbn.labelplaceholder:img_hr})
+ output = output[0] * 128 + 127.5
+ img_hr = img_hr.reshape([height*srfbn.cfg.scale,width*srfbn.cfg.scale,3])
+ PSNR = compare_psnr(output, img_hr, data_range=255)
+ ssim = _structural_similarity.structural_similarity(output, img_hr,win_size=11, data_range=255, multichannel=True)
+ print("loss:[%.8f], l2_loss:[%.8f], PSNR:[%.8f], SSIM:[%.8f]"%(err,l2_loss,PSNR,ssim))
+ #cv2.imshow("result", np.uint8(output))
+ #cv2.waitKey(0)
+
+
+
+if __name__ == '__main__':
+ sess = tf.Session(config=npu_config_proto())
+ from config import SRFBN_config
+ cfg = SRFBN_config()
+ cfg.istest = True
+ cfg.istrain = False
+ image = "/home/TestUser08/BUAA/Resolution_2K/DIV2K/DIV2K_valid_HR/0801.png"
+ batch_label,batch_lrimage = preprocess([image,],cfg)
+ batch_lrimage = np.array(batch_lrimage)
+ batch_label = np.array(batch_label)
+ load_flag = 0
+ for i in range(batch_label.shape[0]):
+ test_SRFBN(batch_lrimage[i],batch_label[i])
+ srfbn.sess.close()
+
diff --git a/TensorFlow/contrib/cv/SRFBN_for_TensorFlow/test/.keep b/TensorFlow/contrib/cv/SRFBN_for_TensorFlow/test/.keep
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/TensorFlow/contrib/cv/SRFBN_for_TensorFlow/test/train_full_1p.sh b/TensorFlow/contrib/cv/SRFBN_for_TensorFlow/test/train_full_1p.sh
new file mode 100644
index 0000000000000000000000000000000000000000..1242fb044895a27e42122ecc43b0f3950dd10619
--- /dev/null
+++ b/TensorFlow/contrib/cv/SRFBN_for_TensorFlow/test/train_full_1p.sh
@@ -0,0 +1,184 @@
+#!/bin/bash
+
+##########################################################
+#########第3行 至 100行,请一定不要、不要、不要修改##########
+#########第3行 至 100行,请一定不要、不要、不要修改##########
+#########第3行 至 100行,请一定不要、不要、不要修改##########
+##########################################################
+# shell脚本所在路径
+cur_path=`echo $(cd $(dirname $0);pwd)`
+
+# 判断当前shell是否是performance
+perf_flag=`echo $0 | grep performance | wc -l`
+
+# 当前执行网络的名称
+Network=`echo $(cd $(dirname $0);pwd) | awk -F"/" '{print $(NF-1)}'`
+
+export RANK_SIZE=1
+export RANK_ID=0
+export JOB_ID=10087
+
+# 路径参数初始化
+data_path=""
+output_path=""
+
+# 帮助信息,不需要修改
+if [[ $1 == --help || $1 == -h ]];then
+ echo "usage:./train_performance_1P.sh "
+ echo " "
+ echo "parameter explain:
+ --data_path # dataset of training
+ --output_path # output of training
+ --train_steps # max_step for training
+ --train_epochs # max_epoch for training
+ --batch_size # batch size
+ -h/--help show help message
+ "
+ exit 1
+fi
+
+# 参数校验,不需要修改
+for para in $*
+do
+ if [[ $para == --data_path* ]];then
+ data_path=`echo ${para#*=}`
+ elif [[ $para == --output_path* ]];then
+ output_path=`echo ${para#*=}`
+ elif [[ $para == --train_steps* ]];then
+ train_steps=`echo ${para#*=}`
+ elif [[ $para == --train_epochs* ]];then
+ train_epochs=`echo ${para#*=}`
+ elif [[ $para == --batch_size* ]];then
+ batch_size=`echo ${para#*=}`
+ fi
+done
+
+# 校验是否传入data_path,不需要修改
+if [[ $data_path == "" ]];then
+ echo "[Error] para \"data_path\" must be config"
+ exit 1
+fi
+
+# 校验是否传入output_path,不需要修改
+if [[ $output_path == "" ]];then
+ output_path="./output/${ASCEND_DEVICE_ID}"
+fi
+
+# 设置打屏日志文件名,请保留,文件名为${print_log}
+print_log="./test/output/${ASCEND_DEVICE_ID}/train_${ASCEND_DEVICE_ID}.log"
+modelarts_flag=`cat /etc/passwd |grep ma-user`
+if [ x"${modelarts_flag}" != x ];
+then
+ echo "running with modelarts_flag..."
+ print_log_name=`ls /home/ma-user/modelarts/log/ | grep proc-rank`
+ print_log="/home/ma-user/modelarts/log/${print_log_name}"
+fi
+echo "### get your log here : ${print_log}"
+
+CaseName=""
+function get_casename()
+{
+ if [ x"${perf_flag}" = x1 ];
+ then
+ CaseName=${Network}_bs${batch_size}_${RANK_SIZE}'p'_'perf'
+ else
+ CaseName=${Network}_bs${batch_size}_${RANK_SIZE}'p'_'acc'
+ fi
+}
+
+# 跳转到code目录
+cd ${cur_path}/../
+rm -rf ./test/output/${ASCEND_DEVICE_ID}
+mkdir -p ./test/output/${ASCEND_DEVICE_ID}
+
+# 训练开始时间记录,不需要修改
+start_time=$(date +%s)
+##########################################################
+#########第3行 至 100行,请一定不要、不要、不要修改##########
+#########第3行 至 100行,请一定不要、不要、不要修改##########
+#########第3行 至 100行,请一定不要、不要、不要修改##########
+##########################################################
+
+#=========================================================
+#=========================================================
+#========训练执行命令,需要根据您的网络进行修改==============
+#=========================================================
+#=========================================================
+# 基础参数,需要模型审视修改
+# 您的训练数据集在${data_path}路径下,请直接使用这个变量获取
+# 您的训练输出目录在${output_path}路径下,请直接使用这个变量获取
+# 您的其他基础参数,可以自定义增加,但是batch_size请保留,并且设置正确的值
+batch_size=1
+
+if [ x"${modelarts_flag}" != x ];
+then
+ python3 ./train.py
+else
+ python3.7 ./train.py --data_path=${data_path} --output_path=${output_path} 1>${print_log} 2>&1
+fi
+
+# 性能相关数据计算
+StepTime=`((cat ${print_log} | grep "time" | head -n 1) && (cat ${print_log} | grep "time" | tail -n 1)) | awk -F ':' '{print $5 $6 }' | awk -F ',' '{print $1 $2}' | awk -F ' ' '{print $1;print $3}' | awk '{if (NR == 1){a=$1} else if (NR == 2){b=$1} else if (NR == 3){c=$1} else if (NR == 4){d=$1}} END {print (d-b)/(c-a)}'`
+FPS=`awk 'BEGIN{printf "%.2f\n", '${batch_size}'/'${StepTime}'}'`
+#PSNR值计算
+PSNR=`cat ${print_log} | grep "time" | tail -n 10 | awk -F ',' '{print $8}' | awk -F ':' '{sum+=$2} END {print sum/NR}'`
+# 提取所有loss打印信息
+grep "loss:" ${print_log} | awk -F "," '{print $6}' > ./test/output/${ASCEND_DEVICE_ID}/my_output_loss.txt
+
+
+###########################################################
+#########后面的所有内容请不要修改###########################
+#########后面的所有内容请不要修改###########################
+#########后面的所有内容请不要修改###########################
+###########################################################
+
+# 判断本次执行是否正确使用Ascend NPU
+tf_flag=`echo ${Network} | grep TensorFlow | wc -l`
+use_npu_flag=`grep "The model has been compiled on the Ascend AI processor" ${print_log} | wc -l`
+if [ x"${use_npu_flag}" == x0 -a x"${tf_flag}" == x1 ];
+then
+ echo "------------------ ERROR NOTICE START ------------------"
+ echo "ERROR, your task haven't used Ascend NPU, please check your npu Migration."
+ echo "------------------ ERROR NOTICE END------------------"
+else
+ echo "------------------ INFO NOTICE START------------------"
+ echo "INFO, your task have used Ascend NPU, please check your result."
+ echo "------------------ INFO NOTICE END------------------"
+fi
+
+# 获取最终的casename,请保留,case文件名为${CaseName}
+get_casename
+
+# 重命名loss文件
+if [ -f ./test/output/${ASCEND_DEVICE_ID}/my_output_loss.txt ];
+then
+ mv ./test/output/${ASCEND_DEVICE_ID}/my_output_loss.txt ./test/output/${ASCEND_DEVICE_ID}/${CaseName}_loss.txt
+fi
+
+# 训练端到端耗时
+end_time=$(date +%s)
+e2e_time=$(( $end_time - $start_time ))
+
+echo "------------------ Final result ------------------"
+# 输出性能FPS/单step耗时/端到端耗时
+echo "Final Performance images/sec : $FPS"
+echo "Final Performance sec/step : $StepTime"
+echo "E2E Training Duration sec : $e2e_time"
+echo "PSNR : $PSNR"
+# 输出训练精度
+#echo "Final Train Accuracy : ${train_accuracy}"
+
+# 最后一个迭代loss值,不需要修改
+ActualLoss=(`awk 'END {print $NF}' ./test/output/${ASCEND_DEVICE_ID}/my_output_loss.txt`)
+
+#关键信息打印到${CaseName}.log中,不需要修改
+echo "Network = ${Network}" > $cur_path/output/$ASCEND_DEVICE_ID/${CaseName}.log
+echo "RankSize = ${RANK_SIZE}" >> $cur_path/output/$ASCEND_DEVICE_ID/${CaseName}.log
+echo "BatchSize = ${batch_size}" >> $cur_path/output/$ASCEND_DEVICE_ID/${CaseName}.log
+echo "DeviceType = `uname -m`" >> $cur_path/output/$ASCEND_DEVICE_ID/${CaseName}.log
+echo "CaseName = ${CaseName}" >> $cur_path/output/$ASCEND_DEVICE_ID/${CaseName}.log
+echo "ActualFPS = ${FPS}" >> $cur_path/output/$ASCEND_DEVICE_ID/${CaseName}.log
+echo "TrainingTime = ${StepTime}" >> $cur_path/output/$ASCEND_DEVICE_ID/${CaseName}.log
+echo "ActualLoss = ${ActualLoss}" >> $cur_path/output/$ASCEND_DEVICE_ID/${CaseName}.log
+echo "E2ETrainingTime = ${e2e_time}" >> $cur_path/output/$ASCEND_DEVICE_ID/${CaseName}.log
+#echo "TrainAccuracy = ${train_accuracy}" >> $cur_path/output/$ASCEND_DEVICE_ID/${CaseName}.log
diff --git a/TensorFlow/contrib/cv/SRFBN_for_TensorFlow/test/train_performance_1p.sh b/TensorFlow/contrib/cv/SRFBN_for_TensorFlow/test/train_performance_1p.sh
new file mode 100644
index 0000000000000000000000000000000000000000..708a9d41cfb630d36853a63f0dbaf4aabd602fa7
--- /dev/null
+++ b/TensorFlow/contrib/cv/SRFBN_for_TensorFlow/test/train_performance_1p.sh
@@ -0,0 +1,152 @@
+#!/bin/bash
+
+##########################################################
+#########第3行 至 100行,请一定不要、不要、不要修改##########
+#########第3行 至 100行,请一定不要、不要、不要修改##########
+#########第3行 至 100行,请一定不要、不要、不要修改##########
+##########################################################
+# shell脚本所在路径
+cur_path=`echo $(cd $(dirname $0);pwd)`
+
+# 判断当前shell是否是performance
+perf_flag=`echo $0 | grep performance | wc -l`
+
+# 当前执行网络的名称
+Network=`echo $(cd $(dirname $0);pwd) | awk -F"/" '{print $(NF-1)}'`
+
+export RANK_SIZE=1
+export RANK_ID=0
+export JOB_ID=10087
+
+# 路径参数初始化
+data_path=""
+output_path=""
+
+# 帮助信息,不需要修改
+if [[ $1 == --help || $1 == -h ]];then
+ echo "usage:./train_performance_1P.sh "
+ echo " "
+ echo "parameter explain:
+ --data_path # dataset of training
+ --output_path # output of training
+ --train_steps # max_step for training
+ --train_epochs # max_epoch for training
+ --batch_size # batch size
+ -h/--help show help message
+ "
+ exit 1
+fi
+
+# 参数校验,不需要修改
+for para in $*
+do
+ if [[ $para == --data_path* ]];then
+ data_path=`echo ${para#*=}`
+ elif [[ $para == --output_path* ]];then
+ output_path=`echo ${para#*=}`
+ elif [[ $para == --train_steps* ]];then
+ train_steps=`echo ${para#*=}`
+ elif [[ $para == --train_epochs* ]];then
+ train_epochs=`echo ${para#*=}`
+ elif [[ $para == --batch_size* ]];then
+ batch_size=`echo ${para#*=}`
+ fi
+done
+
+# 校验是否传入data_path,不需要修改
+if [[ $data_path == "" ]];then
+ echo "[Error] para \"data_path\" must be config"
+ exit 1
+fi
+
+# 校验是否传入output_path,不需要修改
+if [[ $output_path == "" ]];then
+ output_path="./output/${ASCEND_DEVICE_ID}"
+fi
+
+CaseName=""
+function get_casename()
+{
+ if [ x"${perf_flag}" = x1 ];
+ then
+ CaseName=${Network}_bs${batch_size}_${RANK_SIZE}'p'_'perf'
+ else
+ CaseName=${Network}_bs${batch_size}_${RANK_SIZE}'p'_'acc'
+ fi
+}
+
+# 跳转到code目录
+cd ${cur_path}/../
+rm -rf ./test/output/${ASCEND_DEVICE_ID}
+mkdir -p ./test/output/${ASCEND_DEVICE_ID}
+
+# 训练开始时间记录,不需要修改
+start_time=$(date +%s)
+##########################################################
+#########第3行 至 90行,请一定不要、不要、不要修改##########
+#########第3行 至 90行,请一定不要、不要、不要修改##########
+#########第3行 至 90行,请一定不要、不要、不要修改##########
+##########################################################
+
+#=========================================================
+#=========================================================
+#========训练执行命令,需要根据您的网络进行修改==============
+#=========================================================
+#=========================================================
+# 基础参数,需要模型审视修改
+# 您的训练数据集在${data_path}路径下,请直接使用这个变量获取
+# 您的训练输出目录在${output_path}路径下,请直接使用这个变量获取
+# 您的其他基础参数,可以自定义增加,但是batch_size请保留,并且设置正确的值
+train_epochs=1
+batch_size=1
+print_log="./test/output/${ASCEND_DEVICE_ID}/train_${ASCEND_DEVICE_ID}.log"
+python3.7 ./train.py --data_path=${data_path} --output_path=${output_path} 1>${print_log} 2>&1
+StepTime=`((cat ${print_log} | grep "time" | head -n 1) && (cat ${print_log} | grep "time" | tail -n 1)) | awk -F ':' '{print $5 $6 }' | awk -F ',' '{print $1 $2}' | awk -F ' ' '{print $1;print $3}' | awk '{if (NR == 1){a=$1} else if (NR == 2){b=$1} else if (NR == 3){c=$1} else if (NR == 4){d=$1}} END {print (d-b)/(c-a)}'`
+FPS=`awk 'BEGIN{printf "%.2f\n", '${batch_size}'/'${StepTime}'}'`
+#PSNR值计算
+PSNR=`cat ${print_log} | grep "time" | tail -n 10 | awk -F ',' '{print $8}' | awk -F ':' '{sum+=$2} END {print sum/NR}'`
+# 提取所有loss打印信息
+grep "loss:" ${print_log} | awk -F "," '{print $6}' > ./test/output/${ASCEND_DEVICE_ID}/my_output_loss.txt
+
+
+###########################################################
+#########后面的所有内容请不要修改###########################
+#########后面的所有内容请不要修改###########################
+#########后面的所有内容请不要修改###########################
+###########################################################
+
+# 获取最终的casename,请保留,case文件名为${CaseName}
+get_casename
+
+# 重命名loss文件
+if [ -f ./test/output/${ASCEND_DEVICE_ID}/my_output_loss.txt ];
+then
+ mv ./test/output/${ASCEND_DEVICE_ID}/my_output_loss.txt ./test/output/${ASCEND_DEVICE_ID}/${CaseName}_loss.txt
+fi
+
+# 训练端到端耗时
+end_time=$(date +%s)
+e2e_time=$(( $end_time - $start_time ))
+
+echo "------------------ Final result ------------------"
+# 输出性能FPS/单step耗时/端到端耗时
+echo "Final Performance images/sec : $FPS"
+echo "Final Performance sec/step : $StepTime"
+echo "E2E Training Duration sec : $e2e_time"
+echo "PSNR : $PSNR"
+# 输出训练精度
+#echo "Final Train Accuracy : ${train_accuracy}"
+
+# 最后一个迭代loss值,不需要修改
+ActualLoss=(`awk 'END {print $NF}' ./test/output/${ASCEND_DEVICE_ID}/${CaseName}_loss.txt`)
+
+#关键信息打印到${CaseName}.log中,不需要修改
+echo "Network = ${Network}" > $cur_path/output/$ASCEND_DEVICE_ID/${CaseName}.log
+echo "RankSize = ${RANK_SIZE}" >> $cur_path/output/$ASCEND_DEVICE_ID/${CaseName}.log
+echo "BatchSize = ${batch_size}" >> $cur_path/output/$ASCEND_DEVICE_ID/${CaseName}.log
+echo "DeviceType = `uname -m`" >> $cur_path/output/$ASCEND_DEVICE_ID/${CaseName}.log
+echo "CaseName = ${CaseName}" >> $cur_path/output/$ASCEND_DEVICE_ID/${CaseName}.log
+echo "ActualFPS = ${FPS}" >> $cur_path/output/$ASCEND_DEVICE_ID/${CaseName}.log
+echo "TrainingTime = ${StepTime}" >> $cur_path/output/$ASCEND_DEVICE_ID/${CaseName}.log
+echo "ActualLoss = ${ActualLoss}" >> $cur_path/output/$ASCEND_DEVICE_ID/${CaseName}.log
+echo "E2ETrainingTime = ${e2e_time}" >> $cur_path/output/$ASCEND_DEVICE_ID/${CaseName}.log
diff --git a/TensorFlow/contrib/cv/SRFBN_for_TensorFlow/train.py b/TensorFlow/contrib/cv/SRFBN_for_TensorFlow/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..46f8808e3325498fc0cb737f8c36a8d63fd66595
--- /dev/null
+++ b/TensorFlow/contrib/cv/SRFBN_for_TensorFlow/train.py
@@ -0,0 +1,150 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from npu_bridge.npu_init import *
+from SRFBN_model import SRFBN
+import tensorflow as tf
+import time
+from PreProcess import *
+from skimage.metrics import peak_signal_noise_ratio as comparepsnr
+from skimage.metrics import _structural_similarity
+
+def train_SRFBN(dataset, sess, cfg):
+ # start put data in queue
+ with tf.device('/cpu:0'):
+ step = tf.Variable(0, trainable=False)
+ srfbn = SRFBN(sess=sess, cfg=cfg)
+ srfbn.train_step()
+ out = tf.add_n(srfbn.outs) / srfbn.cfg.num_steps
+ ## build Optimizer
+ #make lr_rate different in different stages
+ boundaries = [len(dataset)*epoch//cfg.batchsize for epoch in cfg.lr_steps]
+ values = [cfg.learning_rate*(cfg.lr_gama**i) for i in range(len(cfg.lr_steps)+1)]
+ lr = tf.train.piecewise_constant(step, boundaries, values)
+ update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
+ optimizer = tf.train.AdamOptimizer(learning_rate=lr)
+ with tf.control_dependencies(update_ops):
+ gs_vs = optimizer.compute_gradients(srfbn.losses)
+ with tf.device('/cpu:0'):
+ train_op = optimizer.apply_gradients(grads_and_vars=gs_vs, global_step=step)
+
+ tf.global_variables_initializer().run(session=sess)
+
+ summary_writer = tf.summary.FileWriter(cfg.srfbn_logdir, srfbn.sess.graph)
+ #load model
+ if srfbn.cfg.load_premodel:
+ counter = srfbn.load()
+ else:
+ counter = 0
+ time_ = time.time()
+ print("\nNow Start Training...\n")
+ global_step = 0
+ for ep in range(cfg.epoch):
+
+ #pick pic randomly
+ pic_idx = np.random.permutation(len(dataset))
+ picid = 0
+ #load five pics one time
+ for i in range(0,len(dataset),5):
+ index = []
+ for j in range(5):
+ index.append(pic_idx[i+j])
+ imgnames = []
+ for pic in index:
+ imgnames.append(dataset[pic])
+ picid += 5
+ print(imgnames)
+ batch_labels, batch_images = preprocess(imgnames, cfg)
+ patch_idx = list(range(len(batch_labels)))
+ #make the number of pic chunk divided by batchsize
+ if len(patch_idx) % cfg.batchsize != 0:
+ patch_idx.extend(list(np.random.choice(patch_idx,
+ cfg.batchsize * ((len(patch_idx) // cfg.batchsize)+1) - len(patch_idx))))
+
+
+ patch_idx = np.random.permutation(patch_idx)
+
+
+ iterations = len(patch_idx) // cfg.batchsize
+
+
+ for it in range(iterations):
+
+ idx = list(patch_idx[it * cfg.batchsize: (it+1)* cfg.batchsize])
+
+
+ patch_labels = np.array(batch_labels)[idx]
+
+ patch_images = np.array(batch_images)[idx]
+
+
+ output,_, loss,l2_loss,= srfbn.sess.run([out,train_op, srfbn.losses,srfbn.l2_regularization_loss],
+ feed_dict={srfbn.imageplaceholder: patch_images,
+ srfbn.labelplaceholder: patch_labels})
+ output = output[0] * 128 + 127.5
+ img_hr = patch_labels.reshape([srfbn.cfg.imagesize * srfbn.cfg.scale, srfbn.cfg.imagesize * srfbn.cfg.scale, 3])
+ PSNR = comparepsnr(output, img_hr, data_range=255)
+ ssim = _structural_similarity.structural_similarity(output, img_hr, win_size=11, data_range=255,
+ multichannel=True)
+
+ if it % 10 == 0:
+ print("Epoch:%2d, pic:%d, step:%2d, global_step:%d, time :%4.4f, loss:%.8f, l2_loss:%.8f, PSNR:%.8f, SSIM:%.8f" % (
+ (ep + 1),picid, it,global_step,time.time() - time_, loss,l2_loss,PSNR,ssim))
+ if it % 100 == 0:
+ srfbn.save(counter)
+ summary_str = srfbn.sess.run(srfbn.merged_summary,
+ feed_dict={srfbn.imageplaceholder: patch_images,
+ srfbn.labelplaceholder: patch_labels})
+ summary_writer.add_summary(summary_str, counter)
+
+ global_step += 1
+ counter += 1
+
+#train
+def train(*args, **kwargs):
+ data_dir = kwargs["data_dir"]
+ imgs = [os.path.join(data_dir,data) for data in os.listdir(data_dir)]
+
+
+
+ sess = tf.compat.v1.Session(config=npu_config_proto())
+
+ ## build NetWork
+ from config import SRFBN_config
+ cfg = SRFBN_config()
+ datasetet = imgs
+ train_SRFBN(datasetet, sess, cfg)
+
+
+
+if __name__ == '__main__':
+ import os
+ data_dir = "/home/TestUser08/BUAA/output_npu_20221021153629/SRFBN-tensorflow_npu_20221021153629/Resolution_2K/DIV2K/DIV2K_train_HR"
+ train(data_dir=data_dir)
+