diff --git a/TensorFlow/built-in/cv/detection/SRNTT-l2_ID0272_for_TensorFlow/SRNTT/model.py b/TensorFlow/built-in/cv/detection/SRNTT-l2_ID0272_for_TensorFlow/SRNTT/model.py index cf6c5e66f90408fc12402d4a98f776310e13c0c8..4cfacceff6d4cb7e0f7ed7b74d5475c74e22be9e 100644 --- a/TensorFlow/built-in/cv/detection/SRNTT-l2_ID0272_for_TensorFlow/SRNTT/model.py +++ b/TensorFlow/built-in/cv/detection/SRNTT-l2_ID0272_for_TensorFlow/SRNTT/model.py @@ -46,6 +46,13 @@ from scipy.io import savemat # set logging level for TensorFlow environ['TF_CPP_MIN_LOG_LEVEL'] = '3' os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' +RANK_SIZE = int(os.environ['RANK_SIZE']) +RANK_ID = int(os.environ['RANK_ID']) + +MY_DEVICE_ID = int(os.environ['DEVICE_ID']) +MY_ASCEND_DEVICE_ID = int(os.environ['ASCEND_DEVICE_ID']) +print("MY DEVICE ID:",MY_DEVICE_ID) +print("MY ASCEND DEVICE ID:",MY_ASCEND_DEVICE_ID) # set logging logging.basicConfig( @@ -55,7 +62,8 @@ logging.basicConfig( ) console = logging.StreamHandler() console.setLevel(logging.INFO) -formatter = logging.Formatter('%(asctime)-15s %(name)-5s %(levelname)-8s %(message)s') +formatter = logging.Formatter( + '%(asctime)-15s %(name)-5s %(levelname)-8s %(message)s') console.setFormatter(formatter) logging.getLogger('').addHandler(console) @@ -70,18 +78,20 @@ SRNTT_MODEL_NAMES = { 'weighted': 'srntt_weighted.npz' } + def npu_tf_optimizer(opt): - npu_opt = NPUDistributedOptimizer(opt) - #loss scale + # loss scale loss_scale_manager = ExponentialUpdateLossScaleManager(init_loss_scale=2 ** 32, incr_every_n_steps=1000, decr_every_n_nan_or_inf=2, decr_ratio=0.5) - if int(os.getenv('RANK_SIZE'))==1: - npu_opt = NPULossScaleOptimizer(npu_opt,loss_scale_manager) + if RANK_SIZE == 1: + npu_opt = NPULossScaleOptimizer(opt, loss_scale_manager) else: - npu_opt = NPULossScaleOptimizer(npu_opt, loss_scale_manager,is_distributed=True) + npu_opt = npu_distributed_optimizer_wrapper(opt) + npu_opt = NPULossScaleOptimizer(npu_opt, loss_scale_manager, is_distributed=True) 
return npu_opt + class SRNTT(object): MAX_IMAGE_SIZE = 2046 ** 2 @@ -132,20 +142,25 @@ class SRNTT(object): padding='SAME', W_init=w_init, b_init=b_init, name='n64s1/c2/%s' % i) net_ = BatchNormLayer(layer=net_, is_train=is_train, gamma_init=g_init, name='n64s1/b2/%s' % i) - net_ = ElementwiseLayer(layer=[net, net_], combine_fn=tf.add, name='b_residual_add/%s' % i) + net_ = ElementwiseLayer( + layer=[net, net_], combine_fn=tf.add, name='b_residual_add/%s' % i) net = net_ net = Conv2d(net=net, n_filter=64, filter_size=(3, 3), strides=(1, 1), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='n64s1/c/m') - net = BatchNormLayer(layer=net, is_train=is_train, gamma_init=g_init, name='n64s1/b/m') - content_feature = ElementwiseLayer(layer=[net, temp], combine_fn=tf.add, name='add3') + net = BatchNormLayer(layer=net, is_train=is_train, + gamma_init=g_init, name='n64s1/b/m') + content_feature = ElementwiseLayer( + layer=[net, temp], combine_fn=tf.add, name='add3') # upscaling (4x) for texture extractor net = Conv2d(net=content_feature, n_filter=256, filter_size=(3, 3), strides=(1, 1), act=None, padding='SAME', W_init=w_init, name='n256s1/1') - net = SubpixelConv2d(net=net, scale=2, n_out_channel=None, act=tf.nn.relu, name='pixelshufflerx2/1') + net = SubpixelConv2d( + net=net, scale=2, n_out_channel=None, act=tf.nn.relu, name='pixelshufflerx2/1') net = Conv2d(net=net, n_filter=256, filter_size=(3, 3), strides=(1, 1), act=None, padding='SAME', W_init=w_init, name='n256s1/2') - net = SubpixelConv2d(net=net, scale=2, n_out_channel=None, act=tf.nn.relu, name='pixelshufflerx2/2') + net = SubpixelConv2d( + net=net, scale=2, n_out_channel=None, act=tf.nn.relu, name='pixelshufflerx2/2') # output value range is [-1, 1] net_upscale = Conv2d(net=net, n_filter=3, filter_size=(1, 1), strides=(1, 1), act=tf.nn.tanh, @@ -161,15 +176,20 @@ class SRNTT(object): assert isinstance(maps, (list, tuple)) # fusion content and texture maps at the smallest scale # print('\tfusion 
content and texture maps at SMALL scale') - map_in = InputLayer(inputs=content_feature.outputs, name='content_feature_maps') + map_in = InputLayer(inputs=content_feature.outputs, + name='content_feature_maps') if weights is not None and concat: - self.a1 = tf.get_variable(dtype=tf.float32, name='small/a', initializer=1.) - self.b1 = tf.get_variable(dtype=tf.float32, name='small/b', initializer=0.) + self.a1 = tf.get_variable( + dtype=tf.float32, name='small/a', initializer=1.) + self.b1 = tf.get_variable( + dtype=tf.float32, name='small/b', initializer=0.) map_ref = maps[0] * tf.nn.sigmoid(self.a1 * weights + self.b1) else: map_ref = maps[0] - map_ref = InputLayer(inputs=map_ref, name='reference_feature_maps1') - net = ConcatLayer(layer=[map_in, map_ref], concat_dim=-1, name='concatenation1') + map_ref = InputLayer( + inputs=map_ref, name='reference_feature_maps1') + net = ConcatLayer(layer=[map_in, map_ref], + concat_dim=-1, name='concatenation1') net = Conv2d(net=net, n_filter=64, filter_size=(3, 3), strides=(1, 1), act=tf.nn.relu, padding='SAME', W_init=w_init, name='small/conv1') for i in range(self.num_res_blocks): # residual blocks @@ -181,29 +201,37 @@ class SRNTT(object): padding='SAME', W_init=w_init, b_init=b_init, name='small/resblock_%d/conv2' % i) net_ = BatchNormLayer(layer=net_, is_train=is_train, gamma_init=g_init, name='small/resblock_%d/bn2' % i) - net_ = ElementwiseLayer(layer=[net, net_], combine_fn=tf.add, name='small/resblock_%d/add' % i) + net_ = ElementwiseLayer( + layer=[net, net_], combine_fn=tf.add, name='small/resblock_%d/add' % i) net = net_ net = Conv2d(net=net, n_filter=64, filter_size=(3, 3), strides=(1, 1), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='small/conv2') - net = BatchNormLayer(layer=net, is_train=is_train, gamma_init=g_init, name='small/bn2') - net = ElementwiseLayer(layer=[net, map_in], combine_fn=tf.add, name='small/add2') + net = BatchNormLayer(layer=net, is_train=is_train, + gamma_init=g_init, 
name='small/bn2') + net = ElementwiseLayer( + layer=[net, map_in], combine_fn=tf.add, name='small/add2') # upscaling (2x) net = Conv2d(net=net, n_filter=256, filter_size=(3, 3), strides=(1, 1), act=None, padding='SAME', W_init=w_init, name='small/conv3') - net = SubpixelConv2d(net=net, scale=2, n_out_channel=None, act=tf.nn.relu, name='small/subpixel') + net = SubpixelConv2d( + net=net, scale=2, n_out_channel=None, act=tf.nn.relu, name='small/subpixel') # fusion content and texture maps at the medium scale # print('\tfusion content and texture maps at MEDIUM scale') map_in = net if weights is not None and concat: - self.a2 = tf.get_variable(dtype=tf.float32, name='medium/a', initializer=1.) - self.b2 = tf.get_variable(dtype=tf.float32, name='medium/b', initializer=0.) + self.a2 = tf.get_variable( + dtype=tf.float32, name='medium/a', initializer=1.) + self.b2 = tf.get_variable( + dtype=tf.float32, name='medium/b', initializer=0.) map_ref = maps[1] * tf.nn.sigmoid(self.a2 * tf.image.resize_bicubic( weights, [weights.get_shape()[1] * 2, weights.get_shape()[2] * 2]) + self.b2) else: map_ref = maps[1] - map_ref = InputLayer(inputs=map_ref, name='reference_feature_maps2') - net = ConcatLayer(layer=[map_in, map_ref], concat_dim=-1, name='concatenation2') + map_ref = InputLayer( + inputs=map_ref, name='reference_feature_maps2') + net = ConcatLayer(layer=[map_in, map_ref], + concat_dim=-1, name='concatenation2') net = Conv2d(net=net, n_filter=64, filter_size=(3, 3), strides=(1, 1), act=tf.nn.relu, padding='SAME', W_init=w_init, name='medium/conv1') for i in range(int(self.num_res_blocks / 2)): # residual blocks @@ -215,29 +243,37 @@ class SRNTT(object): padding='SAME', W_init=w_init, b_init=b_init, name='medium/resblock_%d/conv2' % i) net_ = BatchNormLayer(layer=net_, is_train=is_train, gamma_init=g_init, name='medium/resblock_%d/bn2' % i) - net_ = ElementwiseLayer(layer=[net, net_], combine_fn=tf.add, name='medium/resblock_%d/add' % i) + net_ = ElementwiseLayer( + 
layer=[net, net_], combine_fn=tf.add, name='medium/resblock_%d/add' % i) net = net_ net = Conv2d(net=net, n_filter=64, filter_size=(3, 3), strides=(1, 1), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='medium/conv2') - net = BatchNormLayer(layer=net, is_train=is_train, gamma_init=g_init, name='medium/bn2') - net = ElementwiseLayer(layer=[net, map_in], combine_fn=tf.add, name='medium/add2') + net = BatchNormLayer(layer=net, is_train=is_train, + gamma_init=g_init, name='medium/bn2') + net = ElementwiseLayer( + layer=[net, map_in], combine_fn=tf.add, name='medium/add2') # upscaling (2x) net = Conv2d(net=net, n_filter=256, filter_size=(3, 3), strides=(1, 1), act=None, padding='SAME', W_init=w_init, name='medium/conv3') - net = SubpixelConv2d(net=net, scale=2, n_out_channel=None, act=tf.nn.relu, name='medium/subpixel') + net = SubpixelConv2d( + net=net, scale=2, n_out_channel=None, act=tf.nn.relu, name='medium/subpixel') # fusion content and texture maps at the large scale # print('\tfusion content and texture maps at LARGE scale') map_in = net if weights is not None and concat: - self.a3 = tf.get_variable(dtype=tf.float32, name='large/a', initializer=1.) - self.b3 = tf.get_variable(dtype=tf.float32, name='large/b', initializer=0.) + self.a3 = tf.get_variable( + dtype=tf.float32, name='large/a', initializer=1.) + self.b3 = tf.get_variable( + dtype=tf.float32, name='large/b', initializer=0.) 
map_ref = maps[2] * tf.nn.sigmoid(self.a3 * tf.image.resize_bicubic( weights, [weights.get_shape()[1] * 4, weights.get_shape()[2] * 4]) + self.b3) else: map_ref = maps[2] - map_ref = InputLayer(inputs=map_ref, name='reference_feature_maps3') - net = ConcatLayer(layer=[map_in, map_ref], concat_dim=-1, name='concatenation3') + map_ref = InputLayer( + inputs=map_ref, name='reference_feature_maps3') + net = ConcatLayer(layer=[map_in, map_ref], + concat_dim=-1, name='concatenation3') net = Conv2d(net=net, n_filter=64, filter_size=(3, 3), strides=(1, 1), act=tf.nn.relu, padding='SAME', W_init=w_init, name='large/conv1') for i in range(int(self.num_res_blocks / 4)): # residual blocks @@ -249,12 +285,15 @@ class SRNTT(object): padding='SAME', W_init=w_init, b_init=b_init, name='large/resblock_%d/conv2' % i) net_ = BatchNormLayer(layer=net_, is_train=is_train, gamma_init=g_init, name='large/resblock_%d/bn2' % i) - net_ = ElementwiseLayer(layer=[net, net_], combine_fn=tf.add, name='large/resblock_%d/add' % i) + net_ = ElementwiseLayer( + layer=[net, net_], combine_fn=tf.add, name='large/resblock_%d/add' % i) net = net_ net = Conv2d(net=net, n_filter=64, filter_size=(3, 3), strides=(1, 1), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='large/conv2') - net = BatchNormLayer(layer=net, is_train=is_train, gamma_init=g_init, name='large/bn2') - net = ElementwiseLayer(layer=[net, map_in], combine_fn=tf.add, name='large/add2') + net = BatchNormLayer(layer=net, is_train=is_train, + gamma_init=g_init, name='large/bn2') + net = ElementwiseLayer( + layer=[net, map_in], combine_fn=tf.add, name='large/add2') net = Conv2d(net=net, n_filter=32, filter_size=(3, 3), strides=(1, 1), act=None, padding='SAME', W_init=w_init, name='large/conv3') # net = BatchNormLayer(layer=net, is_train=is_train, gamma_init=g_init, name='large/bn2') @@ -269,7 +308,7 @@ class SRNTT(object): w_init = tf.random_normal_initializer(stddev=0.02) b_init = None g_init = tf.random_normal_initializer(1., 
0.02) - lrelu = lambda x: act.lrelu(x, 0.2) + def lrelu(x): return act.lrelu(x, 0.2) df_dim = 32 with tf.variable_scope('discriminator', reuse=reuse): layers.set_name_reuse(reuse) @@ -278,7 +317,8 @@ class SRNTT(object): n_channels = df_dim * 2 ** i net = Conv2d(net=net, n_filter=n_channels, filter_size=(3, 3), strides=(1, 1), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='n%ds1/c' % n_channels) - net = BatchNormLayer(layer=net, act=lrelu, is_train=is_train, gamma_init=g_init, name='n%ds1/b' % n_channels) + net = BatchNormLayer(layer=net, act=lrelu, is_train=is_train, + gamma_init=g_init, name='n%ds1/b' % n_channels) net = Conv2d(net=net, n_filter=n_channels, filter_size=(3, 3), strides=(2, 2), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='n%ds2/c' % n_channels) net = BatchNormLayer(layer=net, act=lrelu, is_train=is_train, gamma_init=g_init, @@ -292,7 +332,8 @@ class SRNTT(object): return net, logits def tf_gram_matrix(self, x): - x = tf.reshape(x, tf.stack([-1, tf.reduce_prod(x.get_shape()[1:-1]), x.get_shape()[-1]])) + x = tf.reshape(x, tf.stack( + [-1, tf.reduce_prod(x.get_shape()[1:-1]), x.get_shape()[-1]])) return tf.matmul(x, x, transpose_a=True) def eta(self, time_per_iter, n_iter_remain, current_eta=None, alpha=.8): @@ -335,9 +376,12 @@ class SRNTT(object): learning_rate=1e-4, beta1=0.9, use_pretrained_model=True, - use_init_model_only=False, # the init model is trained only with the reconstruction loss - weights=(1e-4, 1e-4, 1e-6, 1., 1.), # (perceptual loss, texture loss, adversarial loss, back projection loss, reconstruction_loss) - vgg_perceptual_loss_layer='relu5_1', # the layer name to compute perceptrual loss + # the init model is trained only with the reconstruction loss + use_init_model_only=False, + # (perceptual loss, texture loss, adversarial loss, back projection loss, reconstruction_loss) + weights=(1e-4, 1e-4, 1e-6, 1., 1.), + # the layer name to compute perceptrual loss + vgg_perceptual_loss_layer='relu5_1', 
is_WGAN_GP=True, is_L1_loss=True, param_WGAN_GP=10, @@ -353,7 +397,8 @@ class SRNTT(object): if self.save_dir is None: self.save_dir = 'default_save_dir' if not use_pretrained_model and exists(join(self.save_dir, MODEL_FOLDER)): - logging.warning('The existing model dir %s is removed!' % join(self.save_dir, MODEL_FOLDER)) + logging.warning('The existing model dir %s is removed!' % + join(self.save_dir, MODEL_FOLDER)) rmtree(join(self.save_dir, MODEL_FOLDER)) # create save folders @@ -374,60 +419,71 @@ class SRNTT(object): # ******************************************************************************** logging.info('Building graph ...') # input LR images, range [-1, 1] - self.input = tf.placeholder(dtype=tf.float32, shape=[batch_size, input_size, input_size, 3]) + self.input = tf.placeholder(dtype=tf.float32, shape=[ + batch_size, input_size, input_size, 3]) # original images, range [-1, 1] - self.ground_truth = tf.placeholder(dtype=tf.float32, shape=[batch_size, input_size * 4, input_size * 4, 3]) + self.ground_truth = tf.placeholder( + dtype=tf.float32, shape=[batch_size, input_size * 4, input_size * 4, 3]) # texture feature maps, range [0, ?] 
self.maps = tuple([tf.placeholder(dtype=tf.float32, shape=[batch_size, m.shape[0], m.shape[1], m.shape[2]]) - for m in np.load(files_map[0], allow_pickle=True)['target_map']]) - + for m in np.load(files_map[0], allow_pickle=True)['target_map']]) # weight maps - self.weights = tf.placeholder(dtype=tf.float32, shape=[batch_size, input_size, input_size]) - + self.weights = tf.placeholder(dtype=tf.float32, shape=[ + batch_size, input_size, input_size]) # reference images, ranges[-1, 1] - self.ref = tf.placeholder(dtype=tf.float32, shape=[batch_size, input_size, input_size, 3]) + self.ref = tf.placeholder(dtype=tf.float32, shape=[ + batch_size, input_size, input_size, 3]) # SRNTT network if use_weight_map: - self.net_upscale, self.net_srntt = self.model(self.input, self.maps, weights=tf.expand_dims(self.weights, axis=-1)) + self.net_upscale, self.net_srntt = self.model( + self.input, self.maps, weights=tf.expand_dims(self.weights, axis=-1)) else: - self.net_upscale, self.net_srntt = self.model(self.input, self.maps) + self.net_upscale, self.net_srntt = self.model( + self.input, self.maps) # VGG19 network, input range [0, 255] - self.net_vgg_sr = VGG19((self.net_srntt.outputs + 1) * 127.5, model_path=self.vgg19_model_path) - self.net_vgg_hr = VGG19((self.ground_truth + 1) * 127.5, model_path=self.vgg19_model_path) - + self.net_vgg_sr = VGG19( + (self.net_srntt.outputs + 1) * 127.5, model_path=self.vgg19_model_path) + self.net_vgg_hr = VGG19((self.ground_truth + 1) + * 127.5, model_path=self.vgg19_model_path) + # discriminator network self.net_d, d_real_logits = self.discriminator(self.ground_truth) - _, d_fake_logits = self.discriminator(self.net_srntt.outputs, reuse=True) + _, d_fake_logits = self.discriminator( + self.net_srntt.outputs, reuse=True) # ******************************************************************************** # *** objectives # ******************************************************************************** # reconstruction loss if is_L1_loss: - 
loss_reconst = tf.reduce_mean(tf.abs(self.net_srntt.outputs - self.ground_truth)) + loss_reconst = tf.reduce_mean( + tf.abs(self.net_srntt.outputs - self.ground_truth)) else: - loss_reconst = cost.mean_squared_error(self.net_srntt.outputs, self.ground_truth, is_mean=True) + loss_reconst = cost.mean_squared_error( + self.net_srntt.outputs, self.ground_truth, is_mean=True) # perceptual loss loss_percep = cost.mean_squared_error( - self.net_vgg_sr.layers[vgg_perceptual_loss_layer], + self.net_vgg_sr.layers[vgg_perceptual_loss_layer], self.net_vgg_hr.layers[vgg_perceptual_loss_layer], is_mean=True) try: available_layers = ['relu2_1', 'relu3_1', 'relu4_1', 'relu5_1'] - available_layers = available_layers[:available_layers.index(vgg_perceptual_loss_layer)] + available_layers = available_layers[:available_layers.index( + vgg_perceptual_loss_layer)] loss_percep_lower_layers = [cost.mean_squared_error( self.net_vgg_sr.layers[l], self.net_vgg_hr.layers[l], is_mean=True) for l in available_layers] if use_lower_layers_in_per_loss: - loss_percep = tf.reduce_mean([loss_percep] + loss_percep_lower_layers) + loss_percep = tf.reduce_mean( + [loss_percep] + loss_percep_lower_layers) except Exception: logging.warning('Failed to use lower layers in perceptual loss!') @@ -436,8 +492,10 @@ class SRNTT(object): self.a1, self.a2, self.a3 = -20., -20, -20 self.b1, self.b2, self.b3 = .65, .65, .65 loss_texture = tf.reduce_mean(tf.squared_difference( - self.tf_gram_matrix(self.maps[0] * tf.nn.sigmoid(tf.expand_dims(self.weights, axis=-1) * self.a1 + self.b1)), - self.tf_gram_matrix(self.net_vgg_sr.layers['relu3_1'] * tf.nn.sigmoid(tf.expand_dims(self.weights, axis=-1) * self.a1 + self.b1)) + self.tf_gram_matrix( + self.maps[0] * tf.nn.sigmoid(tf.expand_dims(self.weights, axis=-1) * self.a1 + self.b1)), + self.tf_gram_matrix(self.net_vgg_sr.layers['relu3_1'] * tf.nn.sigmoid( + tf.expand_dims(self.weights, axis=-1) * self.a1 + self.b1)) ) / 4. 
/ (input_size * input_size * 256) ** 2) + tf.reduce_mean(tf.squared_difference( self.tf_gram_matrix( self.maps[1] * tf.nn.sigmoid(tf.image.resize_bicubic(tf.expand_dims(self.weights, axis=-1), [input_size * 2] * 2) * self.a2 + self.b2)), @@ -446,7 +504,8 @@ class SRNTT(object): ) / 4. / (input_size * input_size * 512) ** 2) + tf.reduce_mean(tf.squared_difference( self.tf_gram_matrix( self.maps[2] * tf.nn.sigmoid(tf.image.resize_bicubic(tf.expand_dims(self.weights, axis=-1), [input_size * 4] * 2) * self.a3 + self.b3)), - self.tf_gram_matrix(self.net_vgg_sr.layers['relu1_1'] * tf.nn.sigmoid(tf.image.resize_bicubic(tf.expand_dims(self.weights, axis=-1), [input_size * 4] * 2) * self.a3 + self.b3)) + self.tf_gram_matrix(self.net_vgg_sr.layers['relu1_1'] * tf.nn.sigmoid(tf.image.resize_bicubic( + tf.expand_dims(self.weights, axis=-1), [input_size * 4] * 2) * self.a3 + self.b3)) ) / 4. / (input_size * input_size * 1024) ** 2) loss_texture /= 3. else: @@ -465,31 +524,39 @@ class SRNTT(object): # adversarial loss if is_WGAN_GP: # WGAN losses - loss_d = tf.reduce_mean(d_fake_logits) - tf.reduce_mean(d_real_logits) + loss_d = tf.reduce_mean(d_fake_logits) - \ + tf.reduce_mean(d_real_logits) loss_g = -tf.reduce_mean(d_fake_logits) # GP: gradient penalty - alpha = tf.random_uniform(shape=[batch_size, 1, 1, 1], minval=0., maxval=1.) - interpolates = alpha * self.ground_truth + ((1 - alpha) * self.net_srntt.outputs) + alpha = tf.random_uniform( + shape=[batch_size, 1, 1, 1], minval=0., maxval=1.) 
+ interpolates = alpha * self.ground_truth + \ + ((1 - alpha) * self.net_srntt.outputs) _, disc_interpolates = self.discriminator(interpolates, reuse=True) gradients = tf.gradients(disc_interpolates, [interpolates])[0] - slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=-1)) + slopes = tf.sqrt(tf.reduce_sum( + tf.square(gradients), reduction_indices=-1)) gradient_penalty = tf.reduce_mean((slopes - 1) ** 2) loss_d += param_WGAN_GP * gradient_penalty else: - loss_g = cost.sigmoid_cross_entropy(d_fake_logits, tf.ones_like(d_fake_logits)) - loss_d_fake = cost.sigmoid_cross_entropy(d_fake_logits, tf.zeros_like(d_fake_logits)) - loss_d_real = cost.sigmoid_cross_entropy(d_real_logits, tf.ones_like(d_real_logits)) + loss_g = cost.sigmoid_cross_entropy( + d_fake_logits, tf.ones_like(d_fake_logits)) + loss_d_fake = cost.sigmoid_cross_entropy( + d_fake_logits, tf.zeros_like(d_fake_logits)) + loss_d_real = cost.sigmoid_cross_entropy( + d_real_logits, tf.ones_like(d_real_logits)) loss_d = loss_d_fake + loss_d_real # back projection loss - loss_bp = back_projection_loss(tf_input=self.input, tf_output=self.net_srntt.outputs) - + loss_bp = back_projection_loss( + tf_input=self.input, tf_output=self.net_srntt.outputs) + # total loss loss_init = weights[4] * loss_reconst + weights[3] * loss_bp loss = weights[4] * loss_reconst + weights[3] * loss_bp + \ - weights[2] * loss_g + \ - weights[1] * loss_texture + \ - weights[0] * loss_percep + weights[2] * loss_g + \ + weights[1] * loss_texture + \ + weights[0] * loss_percep # ******************************************************************************** # *** optimizers @@ -501,7 +568,12 @@ class SRNTT(object): # learning rate decay global_step = tf.Variable(0, trainable=False, name='global_step') - num_batches = int(num_files / batch_size) + + if RANK_SIZE > 1: + num_batches = int(num_files / (batch_size * RANK_SIZE)) + else: + num_batches = int(num_files / batch_size) + print(num_batches) decayed_learning_rate = 
tf.train.exponential_decay( learning_rate=learning_rate, global_step=global_step, @@ -527,22 +599,28 @@ class SRNTT(object): samples_ref = [imresize(imread(files_ref[i], mode='RGB'), (input_size * 4, input_size * 4), interp='bicubic') for i in idx] samples_input = [imresize(img, (input_size, input_size), interp='bicubic').astype(np.float32) / 127.5 - 1 - for img in samples_in] - samples_texture_map_tmp = [np.load(files_map[i], allow_pickle=True)['target_map'] for i in idx] - samples_texture_map = [[] for _ in range(len(samples_texture_map_tmp[0]))] + for img in samples_in] + samples_texture_map_tmp = [np.load(files_map[i], allow_pickle=True)[ + 'target_map'] for i in idx] + samples_texture_map = [[] + for _ in range(len(samples_texture_map_tmp[0]))] for s in samples_texture_map_tmp: for i, item in enumerate(samples_texture_map): item.append(s[i]) samples_texture_map = [np.array(b) for b in samples_texture_map] if use_weight_map: - samples_weight_map = [np.pad(np.load(files_map[i], allow_pickle=True)['weights'], ((1, 1), (1, 1)), 'edge') for i in idx] + samples_weight_map = [np.pad(np.load(files_map[i], allow_pickle=True)[ + 'weights'], ((1, 1), (1, 1)), 'edge') for i in idx] else: - samples_weight_map = np.zeros(shape=(batch_size, input_size, input_size)) + samples_weight_map = np.zeros( + shape=(batch_size, input_size, input_size)) frame_size = int(np.sqrt(batch_size)) - vis.save_images(np.array(samples_in), [frame_size, frame_size], join(self.save_dir, SAMPLE_FOLDER, 'HR.png')) + vis.save_images(np.array(samples_in), [frame_size, frame_size], join( + self.save_dir, SAMPLE_FOLDER, 'HR.png')) vis.save_images(np.round((np.array(samples_input) + 1) * 127.5).astype(np.uint8), [frame_size, frame_size], - join(self.save_dir, SAMPLE_FOLDER, 'LR.png')) - vis.save_images(np.array(samples_ref), [frame_size, frame_size], join(self.save_dir, SAMPLE_FOLDER, 'Ref.png')) + join(self.save_dir, SAMPLE_FOLDER, 'LR.png')) + vis.save_images(np.array(samples_ref), [frame_size, 
frame_size], join( + self.save_dir, SAMPLE_FOLDER, 'Ref.png')) # ******************************************************************************** # *** load models and training @@ -554,12 +632,20 @@ class SRNTT(object): custom_op.parameter_map["use_off_line"].b = True config.graph_options.rewrite_options.remapping = RewriterConfig.OFF # off remap config.graph_options.rewrite_options.memory_optimization = RewriterConfig.OFF + # Allreduce + if RANK_SIZE > 1: + custom_op.parameter_map["hcom_parallel"].b = True with tf.Session(config=npu_config_proto(config_proto=config)) as sess: logging.info('Loading models ...') tf.global_variables_initializer().run() - + # BroadCast + if RANK_SIZE > 1: + input = tf.trainable_variables() + bcast_global_variables_op = hccl_ops.broadcast(input, 0) + sess.run(bcast_global_variables_op) # load pre-trained upscaling. - model_path = join(self.srntt_model_path, SRNTT_MODEL_NAMES['content_extractor']) + model_path = join(self.srntt_model_path, + SRNTT_MODEL_NAMES['content_extractor']) if files.load_and_assign_npz( sess=sess, name=model_path, @@ -567,13 +653,15 @@ class SRNTT(object): logging.error('FAILED load %s' % model_path) exit(0) vis.save_images( - np.round((self.net_upscale.outputs.eval({self.input: samples_input}) + 1) * 127.5).astype(np.uint8), + np.round((self.net_upscale.outputs.eval( + {self.input: samples_input}) + 1) * 127.5).astype(np.uint8), [frame_size, frame_size], join(self.save_dir, SAMPLE_FOLDER, 'Upscale.png')) # load the specific texture transfer model, specified by save_dir is_load_success = False if use_init_model_only: - model_path = join(self.save_dir, MODEL_FOLDER, SRNTT_MODEL_NAMES['init']) + model_path = join(self.save_dir, MODEL_FOLDER, + SRNTT_MODEL_NAMES['init']) if files.load_and_assign_npz( sess=sess, name=model_path, @@ -584,7 +672,8 @@ class SRNTT(object): else: logging.warning('FAILED load %s' % model_path) elif use_pretrained_model: - model_path = join(self.save_dir, MODEL_FOLDER, 
SRNTT_MODEL_NAMES['conditional_texture_transfer']) + model_path = join(self.save_dir, MODEL_FOLDER, + SRNTT_MODEL_NAMES['conditional_texture_transfer']) if files.load_and_assign_npz( sess=sess, name=model_path, @@ -595,7 +684,8 @@ class SRNTT(object): else: logging.warning('FAILED load %s' % model_path) - model_path = join(self.save_dir, MODEL_FOLDER, SRNTT_MODEL_NAMES['discriminator']) + model_path = join(self.save_dir, MODEL_FOLDER, + SRNTT_MODEL_NAMES['discriminator']) if files.load_and_assign_npz( sess=sess, name=model_path, @@ -608,7 +698,8 @@ class SRNTT(object): if not is_load_success: use_weight_map = False if use_init_model_only: - model_path = join(self.srntt_model_path, SRNTT_MODEL_NAMES['init']) + model_path = join(self.srntt_model_path, + SRNTT_MODEL_NAMES['init']) if files.load_and_assign_npz( sess=sess, name=model_path, @@ -619,7 +710,8 @@ class SRNTT(object): logging.error('FAILED load %s' % model_path) exit(0) elif use_pretrained_model: - model_path = join(self.srntt_model_path, SRNTT_MODEL_NAMES['conditional_texture_transfer']) + model_path = join( + self.srntt_model_path, SRNTT_MODEL_NAMES['conditional_texture_transfer']) if files.load_and_assign_npz( sess=sess, name=model_path, @@ -636,15 +728,27 @@ class SRNTT(object): # pre-train with only reconstruction loss current_eta = None idx = np.arange(num_files) + print("num_init_epochs:",num_init_epochs) for epoch in xrange(num_init_epochs): + if RANK_SIZE > 1: + np.random.seed(2) np.random.shuffle(idx) for n_batch in xrange(num_batches): step_time = time.time() - sub_idx = idx[n_batch * batch_size:n_batch * batch_size + batch_size] - batch_imgs = [imread(files_input[i], mode='RGB') for i in sub_idx] - batch_truth = [img.astype(np.float32) / 127.5 - 1 for img in batch_imgs] - batch_input = [imresize(img, .25, interp='bicubic').astype(np.float32)/127.5-1 for img in batch_imgs] - batch_maps_tmp = [np.load(files_map[i], allow_pickle=True)['target_map'] for i in sub_idx] + if RANK_SIZE > 1: + start_idx = 
n_batch * batch_size * RANK_SIZE + RANK_ID * batch_size + sub_idx = idx[start_idx:start_idx + batch_size] + else: + sub_idx = idx[n_batch * batch_size:n_batch * + batch_size + batch_size] + batch_imgs = [imread(files_input[i], mode='RGB') + for i in sub_idx] + batch_truth = [img.astype( + np.float32) / 127.5 - 1 for img in batch_imgs] + batch_input = [imresize(img, .25, interp='bicubic').astype( + np.float32)/127.5-1 for img in batch_imgs] + batch_maps_tmp = [np.load(files_map[i], allow_pickle=True)[ + 'target_map'] for i in sub_idx] batch_maps = [[] for _ in range(len(batch_maps_tmp[0]))] for s in batch_maps_tmp: for i, item in enumerate(batch_maps): @@ -656,7 +760,8 @@ class SRNTT(object): for i in sub_idx] else: - batch_weights = np.zeros(shape=(batch_size, input_size, input_size)) + batch_weights = np.zeros( + shape=(batch_size, input_size, input_size)) # train with reference _, l_reconst, l_bp, map_hr_3, map_hr_2, map_hr_1 = sess.run( fetches=[optimizer_init, loss_reconst, loss_bp, @@ -684,8 +789,10 @@ class SRNTT(object): # print time_per_iter = time.time() - step_time - n_iter_remain = (num_init_epochs - epoch - 1) * num_batches + num_batches - n_batch - eta_str, eta_ = self.eta(time_per_iter, n_iter_remain, current_eta) + n_iter_remain = (num_init_epochs - epoch - 1) * \ + num_batches + num_batches - n_batch + eta_str, eta_ = self.eta( + time_per_iter, n_iter_remain, current_eta) current_eta = eta_ logging.info('Pre-train: Epoch [%02d/%02d] Batch [%03d/%03d]\tETA: %s\n' '\tperf: %.4f\n' @@ -705,20 +812,33 @@ class SRNTT(object): # save model for each epoch files.save_npz( save_list=self.net_srntt.all_params, - name=join(self.save_dir, MODEL_FOLDER, SRNTT_MODEL_NAMES['init']), + name=join(self.save_dir, MODEL_FOLDER, + SRNTT_MODEL_NAMES['init']), sess=sess) + print("num_epochs:",num_epochs) # train with all losses current_eta = None for epoch in xrange(num_epochs): + if RANK_SIZE > 1: + np.random.seed(2) np.random.shuffle(idx) for n_batch in 
xrange(num_batches): step_time = time.time() - sub_idx = idx[n_batch * batch_size:n_batch * batch_size + batch_size] - batch_imgs = [imread(files_input[i], mode='RGB') for i in sub_idx] - batch_truth = [img.astype(np.float32) / 127.5 - 1 for img in batch_imgs] - batch_input = [imresize(img, .25, interp='bicubic').astype(np.float32)/127.5-1 for img in batch_imgs] - batch_maps_tmp = [np.load(files_map[i], allow_pickle=True)['target_map'] for i in sub_idx] + if RANK_SIZE > 1: + start_idx = n_batch * batch_size * RANK_SIZE + RANK_ID * batch_size + sub_idx = idx[start_idx:start_idx + batch_size] + else: + sub_idx = idx[n_batch * batch_size:n_batch * + batch_size + batch_size] + batch_imgs = [imread(files_input[i], mode='RGB') + for i in sub_idx] + batch_truth = [img.astype( + np.float32) / 127.5 - 1 for img in batch_imgs] + batch_input = [imresize(img, .25, interp='bicubic').astype( + np.float32)/127.5-1 for img in batch_imgs] + batch_maps_tmp = [np.load(files_map[i], allow_pickle=True)[ + 'target_map'] for i in sub_idx] batch_maps = [[] for _ in range(len(batch_maps_tmp[0]))] for s in batch_maps_tmp: for i, item in enumerate(batch_maps): @@ -728,7 +848,8 @@ class SRNTT(object): batch_weights = [np.pad(np.load(files_map[i], allow_pickle=True)['weights'], ((1, 1), (1, 1)), 'edge') for i in sub_idx] else: - batch_weights = np.zeros(shape=(batch_size, input_size, input_size)) + batch_weights = np.zeros( + shape=(batch_size, input_size, input_size)) # train with reference for _ in xrange(2): @@ -757,7 +878,8 @@ class SRNTT(object): # train with truth _, _, l_rec, l_per, l_tex, l_adv, l_dis, l_bp = sess.run( - fetches=[optimizer, optimizer_d, loss_reconst, loss_percep, loss_texture, loss_g, loss_d, loss_bp], + fetches=[optimizer, optimizer_d, loss_reconst, + loss_percep, loss_texture, loss_g, loss_d, loss_bp], feed_dict={ self.input: batch_input, self.maps: [map_hr_3, map_hr_2, map_hr_1], @@ -768,8 +890,10 @@ class SRNTT(object): # print time_per_iter = time.time() - 
step_time - n_iter_remain = (num_epochs - epoch - 1) * num_batches + num_batches - n_batch - eta_str, eta_ = self.eta(time_per_iter, n_iter_remain, current_eta) + n_iter_remain = (num_epochs - epoch - 1) * \ + num_batches + num_batches - n_batch + eta_str, eta_ = self.eta( + time_per_iter, n_iter_remain, current_eta) current_eta = eta_ logging.info('Epoch [%02d/%02d] Batch [%03d/%03d]\tETA: %s\n' '\tperf = %.4f\n' @@ -793,11 +917,13 @@ class SRNTT(object): # save models for each epoch files.save_npz( save_list=self.net_srntt.all_params, - name=join(self.save_dir, MODEL_FOLDER, SRNTT_MODEL_NAMES['conditional_texture_transfer']), + name=join(self.save_dir, MODEL_FOLDER, + SRNTT_MODEL_NAMES['conditional_texture_transfer']), sess=sess) files.save_npz( save_list=self.net_d.all_params, - name=join(self.save_dir, MODEL_FOLDER, SRNTT_MODEL_NAMES['discriminator']), + name=join(self.save_dir, MODEL_FOLDER, + SRNTT_MODEL_NAMES['discriminator']), sess=sess) def test( @@ -805,7 +931,8 @@ class SRNTT(object): input_dir, # original image ref_dir=None, # reference images use_pretrained_model=True, - use_init_model_only=False, # the init model is trained only with the reconstruction loss + # the init model is trained only with the reconstruction loss + use_init_model_only=False, use_weight_map=False, result_dir=None, ref_scale=1.0, @@ -860,10 +987,12 @@ class SRNTT(object): stride = 100 for ind_row in range(0, h - (patch_size - stride), stride): for ind_col in range(0, w - (patch_size - stride), stride): - patch = img_input[ind_row:ind_row + patch_size, ind_col:ind_col + patch_size, :] + patch = img_input[ind_row:ind_row + + patch_size, ind_col:ind_col + patch_size, :] if patch.shape != (patch_size, patch_size, 3): patch = np.pad(patch, - ((0, patch_size - patch.shape[0]), (0, patch_size - patch.shape[1]), (0, 0)), + ((0, patch_size - patch.shape[0]), (0, + patch_size - patch.shape[1]), (0, 0)), 'reflect') patches.append(patch) grids.append((ind_row * 4, ind_col * 4, patch_size * 
4)) @@ -889,9 +1018,11 @@ class SRNTT(object): exit(0) if ref_scale <= 0: # keep the same scale as HR image - img_ref = [imresize(img, (h * 4, w * 4), interp='bicubic') for img in img_ref] + img_ref = [imresize(img, (h * 4, w * 4), interp='bicubic') + for img in img_ref] elif ref_scale != 1: - img_ref = [imresize(img, float(ref_scale), interp='bicubic') for img in img_ref] + img_ref = [imresize(img, float(ref_scale), interp='bicubic') + for img in img_ref] for i in xrange(len(img_ref)): h2, w2, _ = img_ref[i].shape @@ -914,11 +1045,12 @@ class SRNTT(object): self.is_model_built = True logging.info('Building graph ...') # input image, range [-1, 1] - self.input_srntt = tf.placeholder(shape=[1, None, None, 3], dtype=tf.float32) + self.input_srntt = tf.placeholder( + shape=[1, None, None, 3], dtype=tf.float32) # reference images, range [0, 255] - self.input_vgg19 = tf.placeholder(shape=[1, None, None, 3], dtype=tf.float32) - + self.input_vgg19 = tf.placeholder( + shape=[1, None, None, 3], dtype=tf.float32) # swapped feature map and weights self.maps = ( @@ -943,7 +1075,8 @@ class SRNTT(object): self.net_upscale, self.net_srntt = self.model( self.input_srntt, self.maps, weights=tf.expand_dims(self.weights, axis=-1), is_train=False) else: - self.net_upscale, self.net_srntt = self.model(self.input_srntt, self.maps, is_train=False) + self.net_upscale, self.net_srntt = self.model( + self.input_srntt, self.maps, is_train=False) # VGG19 network, input range [0, 255] logging.info('Build VGG19 model') @@ -958,7 +1091,8 @@ class SRNTT(object): # ******************************************************************************** config = tf.ConfigProto() config.gpu_options.allow_growth = False - self.sess = tf.Session(config=npu_config_proto(config_proto=config)) + self.sess = tf.Session( + config=npu_config_proto(config_proto=config)) # instant of Swap() logging.info('Initialize the swapper') @@ -968,7 +1102,8 @@ class SRNTT(object): 
self.sess.run(tf.global_variables_initializer()) # load pre-trained content extractor, including upscaling. - model_path = join(self.srntt_model_path, SRNTT_MODEL_NAMES['content_extractor']) + model_path = join(self.srntt_model_path, + SRNTT_MODEL_NAMES['content_extractor']) if files.load_and_assign_npz( sess=self.sess, name=model_path, @@ -979,7 +1114,8 @@ class SRNTT(object): # load the specific conditional texture transfer model, specified by save_dir if self.save_dir is None: if use_init_model_only: - model_path = join(self.srntt_model_path, SRNTT_MODEL_NAMES['init']) + model_path = join(self.srntt_model_path, + SRNTT_MODEL_NAMES['init']) if files.load_and_assign_npz( sess=self.sess, name=model_path, @@ -989,7 +1125,8 @@ class SRNTT(object): logging.error('FAILED load %s' % model_path) exit(0) else: - model_path = join(self.srntt_model_path, SRNTT_MODEL_NAMES['conditional_texture_transfer']) + model_path = join( + self.srntt_model_path, SRNTT_MODEL_NAMES['conditional_texture_transfer']) if files.load_and_assign_npz( sess=self.sess, name=model_path, @@ -1000,7 +1137,8 @@ class SRNTT(object): exit(0) else: if use_init_model_only: - model_path = join(self.save_dir, MODEL_FOLDER, SRNTT_MODEL_NAMES['init']) + model_path = join( + self.save_dir, MODEL_FOLDER, SRNTT_MODEL_NAMES['init']) if files.load_and_assign_npz( sess=self.sess, name=model_path, @@ -1102,16 +1240,17 @@ class SRNTT(object): logging.info('Time elapsed: PM: %.3f sec, SR: %.3f sec' % ((time_step_1 - t_start), (time_step_2 - time_step_1))) - imsave(join(result_dir, 'tmp', 'srntt_%05d.png' % idx), np.round((out_srntt.squeeze() + 1) * 127.5).astype(np.uint8)) imsave(join(result_dir, 'tmp', 'upscale_%05d.png' % idx), np.round((out_upscale.squeeze() + 1) * 127.5).astype(np.uint8)) - logging.info('Saved to %s' % join(result_dir, 'tmp', 'srntt_%05d.png' % idx)) + logging.info('Saved to %s' % + join(result_dir, 'tmp', 'srntt_%05d.png' % idx)) t_end = time.time() logging.info('Reconstruct SR image') 
out_srntt_files = sorted(glob(join(result_dir, 'tmp', 'srntt_*.png'))) - out_upscale_files = sorted(glob(join(result_dir, 'tmp', 'upscale_*.png'))) + out_upscale_files = sorted( + glob(join(result_dir, 'tmp', 'upscale_*.png'))) if grids is not None: patch_size = grids[0, 2] @@ -1121,16 +1260,16 @@ class SRNTT(object): counter = np.zeros_like(out_srntt_large, dtype=np.float32) for idx in xrange(len(grids)): out_upscale_large[ - grids[idx, 0]:grids[idx, 0] + patch_size, - grids[idx, 1]:grids[idx, 1] + patch_size, :] += imread(out_upscale_files[idx], mode='RGB').astype(np.float32) + grids[idx, 0]:grids[idx, 0] + patch_size, + grids[idx, 1]:grids[idx, 1] + patch_size, :] += imread(out_upscale_files[idx], mode='RGB').astype(np.float32) out_srntt_large[ - grids[idx, 0]:grids[idx, 0] + patch_size, - grids[idx, 1]:grids[idx, 1] + patch_size, :] += imread(out_srntt_files[idx], mode='RGB').astype(np.float32) + grids[idx, 0]:grids[idx, 0] + patch_size, + grids[idx, 1]:grids[idx, 1] + patch_size, :] += imread(out_srntt_files[idx], mode='RGB').astype(np.float32) counter[ - grids[idx, 0]:grids[idx, 0] + patch_size, - grids[idx, 1]:grids[idx, 1] + patch_size, :] += 1 + grids[idx, 0]:grids[idx, 0] + patch_size, + grids[idx, 1]:grids[idx, 1] + patch_size, :] += 1 out_upscale_large /= counter out_srntt_large /= counter @@ -1140,10 +1279,10 @@ class SRNTT(object): out_upscale = imread(out_upscale_files[0], mode='RGB') out_srntt = imread(out_srntt_files[0], mode='RGB') - # log run time with open(join(result_dir, 'run_time.txt'), 'w') as f: - line = '%02d min %02d sec\n' % ((t_end - t_start) // 60, (t_end - t_start) % 60) + line = '%02d min %02d sec\n' % ( + (t_end - t_start) // 60, (t_end - t_start) % 60) f.write(line) f.close() @@ -1158,10 +1297,13 @@ class SRNTT(object): for idx, ref in enumerate(img_ref): imsave(join(result_dir, 'Ref_%02d.png' % idx), ref) # save bicubic - imsave(join(result_dir, 'Bicubic.png'), imresize(img_input_copy, 4., interp='bicubic')) + 
imsave(join(result_dir, 'Bicubic.png'), imresize( + img_input_copy, 4., interp='bicubic')) # save SR images - imsave(join(result_dir, 'Upscale.png'), np.array(out_upscale).squeeze().round().clip(0, 255).astype(np.uint8)) - imsave(join(result_dir, 'SRNTT.png'), np.array(out_srntt).squeeze().round().clip(0, 255).astype(np.uint8)) + imsave(join(result_dir, 'Upscale.png'), np.array( + out_upscale).squeeze().round().clip(0, 255).astype(np.uint8)) + imsave(join(result_dir, 'SRNTT.png'), np.array( + out_srntt).squeeze().round().clip(0, 255).astype(np.uint8)) logging.info('Saved results to folder %s' % result_dir) return np.array(out_srntt).squeeze().round().clip(0, 255).astype(np.uint8) @@ -1171,7 +1313,8 @@ class SRNTT(object): input_dir, # original image ref_dir=None, # reference images use_pretrained_model=True, - use_init_model_only=False, # the init model is trained only with the reconstruction loss + # the init model is trained only with the reconstruction loss + use_init_model_only=False, use_weight_map=False, result_dir=None, ref_scale=1.0, @@ -1213,10 +1356,12 @@ class SRNTT(object): stride = 100 for ind_row in range(0, h - (patch_size - stride), stride): for ind_col in range(0, w - (patch_size - stride), stride): - patch = img_input[ind_row:ind_row + patch_size, ind_col:ind_col + patch_size, :] + patch = img_input[ind_row:ind_row + + patch_size, ind_col:ind_col + patch_size, :] if patch.shape != (patch_size, patch_size, 3): patch = np.pad(patch, - ((0, patch_size - patch.shape[0]), (0, patch_size - patch.shape[1]), (0, 0)), + ((0, patch_size - patch.shape[0]), (0, + patch_size - patch.shape[1]), (0, 0)), 'reflect') patches.append(patch) grids.append((ind_row * 4, ind_col * 4, patch_size * 4)) @@ -1248,9 +1393,11 @@ class SRNTT(object): exit(0) if ref_scale <= 0: # keep the same scale as HR image - img_ref = [imresize(img, (h * 4, w * 4), interp='bicubic') for img in img_ref] + img_ref = [imresize(img, (h * 4, w * 4), interp='bicubic') + for img in img_ref] elif 
ref_scale != 1: - img_ref = [imresize(img, float(ref_scale), interp='bicubic') for img in img_ref] + img_ref = [imresize(img, float(ref_scale), interp='bicubic') + for img in img_ref] for i in xrange(len(img_ref)): h2, w2, _ = img_ref[i].shape @@ -1275,10 +1422,12 @@ class SRNTT(object): self.is_model_built = True logging.info('Building graph ...') # input image, range [-1, 1] - self.input_srntt = tf.placeholder(shape=[1, None, None, 3], dtype=tf.float32) + self.input_srntt = tf.placeholder( + shape=[1, None, None, 3], dtype=tf.float32) # reference images, range [0, 255] - self.input_vgg19 = tf.placeholder(shape=[1, None, None, 3], dtype=tf.float32) + self.input_vgg19 = tf.placeholder( + shape=[1, None, None, 3], dtype=tf.float32) # swapped feature map and weights self.maps = ( @@ -1303,7 +1452,8 @@ class SRNTT(object): self.net_upscale, self.net_srntt = self.model( self.input_srntt, self.maps, weights=tf.expand_dims(self.weights, axis=-1), is_train=False) else: - self.net_upscale, self.net_srntt = self.model(self.input_srntt, self.maps, is_train=False) + self.net_upscale, self.net_srntt = self.model( + self.input_srntt, self.maps, is_train=False) # VGG19 network, input range [0, 255] logging.info('Build VGG19 model') @@ -1318,7 +1468,8 @@ class SRNTT(object): # ******************************************************************************** config = tf.ConfigProto() config.gpu_options.allow_growth = False - self.sess = tf.Session(config=npu_config_proto(config_proto=config)) + self.sess = tf.Session( + config=npu_config_proto(config_proto=config)) # instant of Swap() logging.info('Initialize the swapper') @@ -1328,7 +1479,8 @@ class SRNTT(object): self.sess.run(tf.global_variables_initializer()) # load pre-trained content extractor, including upscaling. 
- model_path = join(self.srntt_model_path, SRNTT_MODEL_NAMES['content_extractor']) + model_path = join(self.srntt_model_path, + SRNTT_MODEL_NAMES['content_extractor']) if files.load_and_assign_npz( sess=self.sess, name=model_path, @@ -1339,7 +1491,8 @@ class SRNTT(object): # load the specific conditional texture transfer model, specified by save_dir if self.save_dir is None: if use_init_model_only: - model_path = join(self.srntt_model_path, SRNTT_MODEL_NAMES['init']) + model_path = join(self.srntt_model_path, + SRNTT_MODEL_NAMES['init']) if files.load_and_assign_npz( sess=self.sess, name=model_path, @@ -1349,7 +1502,8 @@ class SRNTT(object): logging.error('FAILED load %s' % model_path) exit(0) else: - model_path = join(self.srntt_model_path, SRNTT_MODEL_NAMES['conditional_texture_transfer']) + model_path = join( + self.srntt_model_path, SRNTT_MODEL_NAMES['conditional_texture_transfer']) if files.load_and_assign_npz( sess=self.sess, name=model_path, @@ -1360,7 +1514,8 @@ class SRNTT(object): exit(0) else: if use_init_model_only: - model_path = join(self.save_dir, MODEL_FOLDER, SRNTT_MODEL_NAMES['init']) + model_path = join( + self.save_dir, MODEL_FOLDER, SRNTT_MODEL_NAMES['init']) if files.load_and_assign_npz( sess=self.sess, name=model_path, @@ -1408,7 +1563,8 @@ class SRNTT(object): if is_ref: for i in img_ref: img_ref_downscale = imresize(i, .25, interp='bicubic') - img_ref_upscale = self.net_upscale.outputs.eval({self.input_srntt: [img_ref_downscale / 127.5 - 1]}, session=self.sess) + img_ref_upscale = self.net_upscale.outputs.eval( + {self.input_srntt: [img_ref_downscale / 127.5 - 1]}, session=self.sess) img_ref_upscale = (img_ref_upscale + 1) * 127.5 map_ref_sr.append( self.net_vgg19.get_layer_output( @@ -1431,18 +1587,22 @@ class SRNTT(object): if 'Urban' in input_dir: img_input_upscale = imread( - join('../EDSR-PyTorch/test_Urban100_MDSR', split(input_dir)[-1], 'SRNTT.png'), + join('../EDSR-PyTorch/test_Urban100_MDSR', + split(input_dir)[-1], 'SRNTT.png'), 
mode='RGB').astype(np.float32) elif 'CUFED5' in input_dir and False: img_input_upscale = imread( - join('../EDSR-PyTorch/test_CUFED5_MDSR', split(input_dir)[-1], 'SRNTT.png'), + join('../EDSR-PyTorch/test_CUFED5_MDSR', + split(input_dir)[-1], 'SRNTT.png'), mode='RGB').astype(np.float32) elif 'Sun80' in input_dir or 'sun80' in input_dir: img_input_upscale = imread( - join('../EDSR-PyTorch/test_Sun100_MDSR', split(input_dir)[-1].split('.')[0], 'SRNTT.png'), + join('../EDSR-PyTorch/test_Sun100_MDSR', + split(input_dir)[-1].split('.')[0], 'SRNTT.png'), mode='RGB').astype(np.float32) else: - img_input_upscale = self.net_upscale.outputs.eval({self.input_srntt: [patch / 127.5 - 1]}, session=self.sess) + img_input_upscale = self.net_upscale.outputs.eval( + {self.input_srntt: [patch / 127.5 - 1]}, session=self.sess) img_input_upscale = (img_input_upscale + 1) * 127.5 if is_ref: @@ -1502,10 +1662,12 @@ class SRNTT(object): np.round((out_srntt.squeeze() + 1) * 127.5).astype(np.uint8)) imsave(join(result_dir, 'tmp', 'upscale_%05d.png' % idx), np.round((out_upscale.squeeze() + 1) * 127.5).astype(np.uint8)) - logging.info('Saved to %s' % join(result_dir, 'tmp', 'srntt_%05d.png' % idx)) + logging.info('Saved to %s' % + join(result_dir, 'tmp', 'srntt_%05d.png' % idx)) logging.info('Reconstruct SR image') out_srntt_files = sorted(glob(join(result_dir, 'tmp', 'srntt_*.png'))) - out_upscale_files = sorted(glob(join(result_dir, 'tmp', 'upscale_*.png'))) + out_upscale_files = sorted( + glob(join(result_dir, 'tmp', 'upscale_*.png'))) if grids is not None: patch_size = grids[0, 2] @@ -1515,16 +1677,16 @@ class SRNTT(object): counter = np.zeros_like(out_srntt_large, dtype=np.float32) for idx in xrange(len(grids)): out_upscale_large[ - grids[idx, 0]:grids[idx, 0] + patch_size, - grids[idx, 1]:grids[idx, 1] + patch_size, :] += imread(out_upscale_files[idx], mode='RGB').astype(np.float32) + grids[idx, 0]:grids[idx, 0] + patch_size, + grids[idx, 1]:grids[idx, 1] + patch_size, :] += 
imread(out_upscale_files[idx], mode='RGB').astype(np.float32) out_srntt_large[ - grids[idx, 0]:grids[idx, 0] + patch_size, - grids[idx, 1]:grids[idx, 1] + patch_size, :] += imread(out_srntt_files[idx], mode='RGB').astype(np.float32) + grids[idx, 0]:grids[idx, 0] + patch_size, + grids[idx, 1]:grids[idx, 1] + patch_size, :] += imread(out_srntt_files[idx], mode='RGB').astype(np.float32) counter[ - grids[idx, 0]:grids[idx, 0] + patch_size, - grids[idx, 1]:grids[idx, 1] + patch_size, :] += 1 + grids[idx, 0]:grids[idx, 0] + patch_size, + grids[idx, 1]:grids[idx, 1] + patch_size, :] += 1 out_upscale_large /= counter out_srntt_large /= counter @@ -1538,7 +1700,8 @@ class SRNTT(object): # log run time with open(join(result_dir, 'run_time.txt'), 'w') as f: - line = '%02d min %02d sec\n' % ((t_end - t_start) // 60, (t_end - t_start) % 60) + line = '%02d min %02d sec\n' % ( + (t_end - t_start) // 60, (t_end - t_start) % 60) f.write(line) f.close() @@ -1553,11 +1716,13 @@ class SRNTT(object): for idx, ref in enumerate(img_ref): imsave(join(result_dir, 'Ref_%02d.png' % idx), ref) # save bicubic - imsave(join(result_dir, 'Bicubic.png'), imresize(img_input_copy, 4., interp='bicubic')) + imsave(join(result_dir, 'Bicubic.png'), imresize( + img_input_copy, 4., interp='bicubic')) # save SR images - imsave(join(result_dir, 'Upscale.png'), np.array(out_upscale).squeeze().round().clip(0, 255).astype(np.uint8)) - imsave(join(result_dir, 'SRNTT.png'), np.array(out_srntt).squeeze().round().clip(0, 255).astype(np.uint8)) + imsave(join(result_dir, 'Upscale.png'), np.array( + out_upscale).squeeze().round().clip(0, 255).astype(np.uint8)) + imsave(join(result_dir, 'SRNTT.png'), np.array( + out_srntt).squeeze().round().clip(0, 255).astype(np.uint8)) logging.info('Saved results to folder %s' % result_dir) return np.array(out_srntt).squeeze().round().clip(0, 255).astype(np.uint8) - diff --git a/TensorFlow/built-in/cv/detection/SRNTT-l2_ID0272_for_TensorFlow/configs/.keep 
b/TensorFlow/built-in/cv/detection/SRNTT-l2_ID0272_for_TensorFlow/configs/.keep new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/TensorFlow/built-in/cv/detection/SRNTT-l2_ID0272_for_TensorFlow/configs/rank_table_8p.json b/TensorFlow/built-in/cv/detection/SRNTT-l2_ID0272_for_TensorFlow/configs/rank_table_8p.json new file mode 100644 index 0000000000000000000000000000000000000000..cd9041f3efa3eb1a9e1959ac758b60e2313778a0 --- /dev/null +++ b/TensorFlow/built-in/cv/detection/SRNTT-l2_ID0272_for_TensorFlow/configs/rank_table_8p.json @@ -0,0 +1,52 @@ +{ + "server_count":"1", + "server_list":[ + { + "server_id":"10.147.179.27", + "device":[ + { + "device_id":"0", + "device_ip":"192.168.100.100", + "rank_id":"0" + }, + { + "device_id":"1", + "device_ip":"192.168.101.100", + "rank_id":"1" + }, + { + "device_id":"2", + "device_ip":"192.168.102.100", + "rank_id":"2" + }, + { + "device_id":"3", + "device_ip":"192.168.103.100", + "rank_id":"3" + }, + { + "device_id":"4", + "device_ip":"192.168.100.101", + "rank_id":"4" + }, + { + "device_id":"5", + "device_ip":"192.168.101.101", + "rank_id":"5" + }, + { + "device_id":"6", + "device_ip":"192.168.102.101", + "rank_id":"6" + }, + { + "device_id":"7", + "device_ip":"192.168.103.101", + "rank_id":"7" + } + ] + } + ], + "status":"completed", + "version":"1.0" +} \ No newline at end of file diff --git a/TensorFlow/built-in/cv/detection/SRNTT-l2_ID0272_for_TensorFlow/test/train_full_8p.sh b/TensorFlow/built-in/cv/detection/SRNTT-l2_ID0272_for_TensorFlow/test/train_full_8p.sh new file mode 100644 index 0000000000000000000000000000000000000000..2df9b65b313751b671b1d89bf74c5efa3724d09e --- /dev/null +++ b/TensorFlow/built-in/cv/detection/SRNTT-l2_ID0272_for_TensorFlow/test/train_full_8p.sh @@ -0,0 +1,129 @@ +#!/bin/bash +cur_path=`pwd`/../ + +RANK_ID_START=0 +# Basic parameters; review and adjust for the model +batch_size=9 +# Network name, same as the directory name +Network="SRNTT-l2_ID0272_for_TensorFlow" +# Number of devices +export 
JOB_ID=10001 +export RANK_SIZE=8 +export RANK_TABLE_FILE=${cur_path}/configs/rank_table_8p.json +# Training epochs (optional) +train_epochs=30 +# Training steps +train_steps=50000 +# Learning rate +learning_rate=1e-4 + +# Parameter configuration +data_path="./" + +if [[ $1 == --help || $1 == --h ]];then + echo "usage:./train_full_8p.sh --data_path=DATA_DIR" + exit 1 +fi + +for para in $* +do + if [[ $para == --data_path* ]];then + data_path=`echo ${para#*=}` + fi +done + +if [[ $data_path == "" ]];then + echo "[Error] para \"data_path \" must be config" + exit 1 +fi +############## Run training ########## +cd $cur_path + +# Modify parameters (temporarily cap num_batches in SRNTT/model.py) +sed -i "s|num_batches = int(num_files / batch_size)|num_batches = 10|g" ./SRNTT/model.py +cp -r ${data_path}/SRNTT/ ./SRNTT/models/ +cp ${data_path}/imagenet-vgg-verydeep-19.mat ./SRNTT/models/VGG19/ + +start=$(date +%s) +for((RANK_ID=$RANK_ID_START;RANK_ID<$((RANK_SIZE+RANK_ID_START));RANK_ID++)); +do + # Set environment variables; no changes needed + echo "Device ID: $RANK_ID" + export RANK_ID=$RANK_ID + export ASCEND_DEVICE_ID=$RANK_ID + export DEVICE_ID=$RANK_ID + + # Create the per-device output directory; no changes needed + if [ -d $cur_path/test/output ];then + rm -rf $cur_path/test/output/${ASCEND_DEVICE_ID} + mkdir -p $cur_path/test/output/$ASCEND_DEVICE_ID/ckpt + else + mkdir -p $cur_path/test/output/$ASCEND_DEVICE_ID/ckpt + fi + + nohup python3 main.py \ + --is_train True \ + --input_dir ${data_path}/data/train/CUFED/input \ + --ref_dir ${data_path}/data/train/CUFED/ref \ + --map_dir ${data_path}/data/train/CUFED/map_321 \ + --use_pretrained_model False \ + --num_init_epochs 2 \ + --num_epochs 2 \ + --save_dir demo_training_srntt > $cur_path/test/output/$ASCEND_DEVICE_ID/train_$ASCEND_DEVICE_ID.log 2>&1 & +done +wait + +end=$(date +%s) +e2e_time=$(( $end - $start )) + +# Restore the modified parameter +sed -i "s|num_batches = 10|num_batches = int(num_files / batch_size)|g" ./SRNTT/model.py + +#echo "Final Performance ms/step : $average_perf" +echo "Final Training Duration sec : $e2e_time" + +# Print results; no changes needed +echo "------------------ Final result ------------------" +# Report performance (FPS); review and adjust for the model. NOTE(review): the training-loop log format visible in SRNTT/model.py prints 'perf = VALUE'; confirm that a 'perf:' line is really emitted, otherwise this grep matches nothing and TrainingTime is empty +TrainingTime=`grep "perf:" 
$cur_path/test/output/$ASCEND_DEVICE_ID/train_$ASCEND_DEVICE_ID.log|awk 'END {print $2}'` +wait +FPS=`awk 'BEGIN{printf "%.2f\n",'${batch_size}'/'${TrainingTime}'}'` +# Print; no changes needed +echo "Final Performance images/sec : $FPS" + +# Report training accuracy; review and adjust for the model +#train_accuracy=`grep "train_acc " $cur_path/test/output/$ASCEND_DEVICE_ID/train_$ASCEND_DEVICE_ID.log|awk 'END {print $12}'|sed 's/,//g'` +# Print; no changes needed +#echo "Final Train Accuracy : ${train_accuracy}" + + +# Performance monitoring result summary +# Training case information; no changes needed +BatchSize=${batch_size} +DeviceType=`uname -m` +CaseName=${Network}_bs${BatchSize}_${RANK_SIZE}'p'_'acc' + +## Collect performance data; no changes needed +# Throughput +ActualFPS=${FPS} +# Training time per iteration +TrainingTime=`awk 'BEGIN{printf "%.2f\n",'${BatchSize}'/'${FPS}'}'` + +# Extract loss from train_$ASCEND_DEVICE_ID.log into train_${CaseName}_loss.txt; review per model +grep "loss:l_rec" $cur_path/test/output/$ASCEND_DEVICE_ID/train_$ASCEND_DEVICE_ID.log|awk '{print $3}' > $cur_path/test/output/$ASCEND_DEVICE_ID/train_${CaseName}_loss.txt + +# Loss of the last iteration; no changes needed +ActualLoss=`awk 'END {print $1}' $cur_path/test/output/$ASCEND_DEVICE_ID/train_${CaseName}_loss.txt` + +# Write key information to ${CaseName}.log; no changes needed +echo "Network = ${Network}" > $cur_path/test/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "RankSize = ${RANK_SIZE}" >> $cur_path/test/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "BatchSize = ${BatchSize}" >> $cur_path/test/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "DeviceType = ${DeviceType}" >> $cur_path/test/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "CaseName = ${CaseName}" >> $cur_path/test/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "ActualFPS = ${ActualFPS}" >> $cur_path/test/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "TrainingTime = ${TrainingTime}" >> $cur_path/test/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "TrainAccuracy = None" >> $cur_path/test/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "ActualLoss = ${ActualLoss}" >> $cur_path/test/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "E2ETrainingTime = ${e2e_time}" >> 
$cur_path/test/output/$ASCEND_DEVICE_ID/${CaseName}.log +