diff --git a/TensorFlow/contrib/cv/deblur_gan/DeblurGAN_ID0207_for_TensorFlow/main.py b/TensorFlow/contrib/cv/deblur_gan/DeblurGAN_ID0207_for_TensorFlow/main.py index 6b55a29e8e42827b49e0a82f39ae08ef3f5e00e3..4adcb4f39bb1b4f7e3fc3f23899fb68ca14da11b 100644 --- a/TensorFlow/contrib/cv/deblur_gan/DeblurGAN_ID0207_for_TensorFlow/main.py +++ b/TensorFlow/contrib/cv/deblur_gan/DeblurGAN_ID0207_for_TensorFlow/main.py @@ -112,6 +112,9 @@ sess = tf.Session(config=config) sess.run(tf.global_variables_initializer()) saver = tf.train.Saver(max_to_keep=10) +if not os.path.exists(args.result_path): + os.makedirs(args.result_path) + if args.mode == 'train': train(args, model, sess, saver) diff --git a/TensorFlow/contrib/cv/deblur_gan/DeblurGAN_ID0207_for_TensorFlow/mode.py b/TensorFlow/contrib/cv/deblur_gan/DeblurGAN_ID0207_for_TensorFlow/mode.py index fcf2a8236ec257f04b1d50aa3e9182496e1e6af7..43d8a47f596b3c51c07927cca17f08735dae4e1e 100644 --- a/TensorFlow/contrib/cv/deblur_gan/DeblurGAN_ID0207_for_TensorFlow/mode.py +++ b/TensorFlow/contrib/cv/deblur_gan/DeblurGAN_ID0207_for_TensorFlow/mode.py @@ -143,9 +143,10 @@ def train(args, model, sess, saver): def test(args, model, sess, saver, file, step=-1, loading=False): if loading: - saver.restore(sess, args.pre_trained_model) + latest_checkpoint = tf.train.latest_checkpoint(args.pre_trained_model) + saver.restore(sess, latest_checkpoint) print("saved model is loaded for test!") - print("model path is %s" % args.pre_trained_model) + #print("model path is: ", lastest_checkpoint) blur_img_name = sorted(os.listdir(args.test_Blur_path)) sharp_img_name = sorted(os.listdir(args.test_Sharp_path)) diff --git a/TensorFlow/contrib/cv/deblur_gan/DeblurGAN_ID0207_for_TensorFlow/test/train_full_1p.sh b/TensorFlow/contrib/cv/deblur_gan/DeblurGAN_ID0207_for_TensorFlow/test/train_full_1p.sh index efa122787d0cffd631faf603cf346f5a7aa1a373..d65d3555df6c0ac301b3edf0a451302ab4b2fb37 100644 --- a/TensorFlow/contrib/cv/deblur_gan/DeblurGAN_ID0207_for_TensorFlow/test/train_full_1p.sh +++ b/TensorFlow/contrib/cv/deblur_gan/DeblurGAN_ID0207_for_TensorFlow/test/train_full_1p.sh @@ -22,7 +22,7 @@ export ASCEND_GLOBAL_LOG_LEVEL=3 #网络名称,同目录名称 Network="DeblurGAN_ID0207_for_TensorFlow" #训练epoch -train_epochs=10 +train_epochs=300 #训练batch_size batch_size=1 #训练step @@ -137,6 +137,9 @@ do --max_epoch=300 \ --vgg_path=${ckpt_path}/vgg19.npy \ --model_path=${cur_path}/test/output/$ASCEND_DEVICE_ID/ckpt \ + --save_test_result=True \ + --in_memory=True \ + --result_path=./result \ --mode=train > ${cur_path}test/output/${ASCEND_DEVICE_ID}/train_${ASCEND_DEVICE_ID}.log 2>&1 python3 main.py \