huwei/tableImageParser_tx

This repository declares no open-source license file (LICENSE); check the project description and its upstream code dependencies before use.

train.py (9.64 KB)
tommyMessi committed on 2020-07-14 23:21 +08:00: update code
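train.py below is the training entry point, written against TensorFlow 1.x (tf.app.flags, tf.contrib.slim): it builds one loss tower per GPU over four score maps (nrow, ncol, row, col) and a training mask, averages the per-tower gradients, maintains an exponential moving average of the weights, and periodically writes checkpoints and TensorBoard summaries.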
import time
import numpy as np
import tensorflow as tf
from tensorflow.contrib import slim
import cv2
tf.app.flags.DEFINE_integer('input_size', 512, '')
tf.app.flags.DEFINE_integer('batch_size_per_gpu', 3, '')
tf.app.flags.DEFINE_integer('num_readers', 16, '')
tf.app.flags.DEFINE_float('learning_rate', 0.0001, '')
tf.app.flags.DEFINE_integer('max_steps', 100000, '')
tf.app.flags.DEFINE_float('moving_average_decay', 0.997, '')
tf.app.flags.DEFINE_string('gpu_list', '1', '')
tf.app.flags.DEFINE_string('checkpoint_path', './model/', '')
tf.app.flags.DEFINE_boolean('restore', False, 'whether to restore from checkpoint')
tf.app.flags.DEFINE_integer('save_checkpoint_steps', 100, '')
tf.app.flags.DEFINE_integer('save_summary_steps', 100, '')
tf.app.flags.DEFINE_string('pretrained_model_path', None, '')
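# Note: gpu_list is a comma-separated list of device indices (e.g. '0,1'); the
# effective global batch size is batch_size_per_gpu * the number of listed GPUs.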
import model
import dataf
FLAGS = tf.app.flags.FLAGS
gpus = list(range(len(FLAGS.gpu_list.split(','))))
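# tower_loss builds the forward pass and combined loss for one GPU tower.
# reuse_variables is None only on the first tower, which therefore also owns
# the image and scalar summaries so they are recorded once, not once per GPU.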
def tower_loss(images, score_maps_nrow, score_maps_ncol, score_maps_row,
               score_maps_col, training_masks, reuse_variables=None):
    # Build inference graph
    with tf.variable_scope(tf.get_variable_scope(), reuse=reuse_variables):
        f_score_nrow, f_score_ncol, \
            f_score_row, f_score_col = model.model(images, is_training=True)

    model_loss = model.loss(score_maps_nrow, f_score_nrow,
                            score_maps_ncol, f_score_ncol,
                            score_maps_row, f_score_row,
                            score_maps_col, f_score_col,
                            training_masks)
    total_loss = tf.add_n([model_loss] + tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))

    # add summary (distinct tags per map so TensorBoard does not auto-suffix them)
    if reuse_variables is None:
        tf.summary.image('input', images)
        tf.summary.image('score_map_nrow', score_maps_nrow)
        tf.summary.image('score_map_pred_nrow', f_score_nrow * 255)
        tf.summary.image('score_map_ncol', score_maps_ncol)
        tf.summary.image('score_map_pred_ncol', f_score_ncol * 255)
        tf.summary.image('score_map_row', score_maps_row)
        tf.summary.image('score_map_pred_row', f_score_row * 255)
        tf.summary.image('score_map_col', score_maps_col)
        tf.summary.image('score_map_pred_col', f_score_col * 255)
        tf.summary.image('training_masks', training_masks)
        tf.summary.scalar('model_loss', model_loss)
        tf.summary.scalar('total_loss', total_loss)

    return total_loss, model_loss
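# average_gradients implements the standard multi-tower pattern (as in the
# TensorFlow CIFAR-10 multi-GPU example): for each variable it stacks the
# per-tower gradients along a new axis and reduces them to their mean.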
def average_gradients(tower_grads):
    average_grads = []
    for grad_and_vars in zip(*tower_grads):
        grads = []
        for g, _ in grad_and_vars:
            expanded_g = tf.expand_dims(g, 0)
            grads.append(expanded_g)
        grad = tf.concat(grads, 0)
        grad = tf.reduce_mean(grad, 0)
        v = grad_and_vars[0][1]
        grad_and_var = (grad, v)
        average_grads.append(grad_and_var)
    return average_grads
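# main() wires the pieces together: input placeholders, one tower per GPU,
# averaged gradients applied alongside an exponential moving average of the
# weights, then a feed_dict-driven training loop with periodic checkpoints.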
def main(argv=None):
    import os
    os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu_list
    if not tf.gfile.Exists(FLAGS.checkpoint_path):
        tf.gfile.MkDir(FLAGS.checkpoint_path)
    else:
        if not FLAGS.restore:
            tf.gfile.DeleteRecursively(FLAGS.checkpoint_path)
            tf.gfile.MkDir(FLAGS.checkpoint_path)
    input_images = tf.placeholder(tf.float32, shape=[None, None, None, 3], name='input_images')
    input_score_maps_nrow = tf.placeholder(tf.float32, shape=[None, None, None, 1], name='input_score_maps_nrow')
    input_score_maps_ncol = tf.placeholder(tf.float32, shape=[None, None, None, 1], name='input_score_maps_ncol')
    input_score_maps_row = tf.placeholder(tf.float32, shape=[None, None, None, 1], name='input_score_maps_row')
    input_score_maps_col = tf.placeholder(tf.float32, shape=[None, None, None, 1], name='input_score_maps_col')
    input_training_masks = tf.placeholder(tf.float32, shape=[None, None, None, 1], name='input_training_masks')

    global_step = tf.get_variable('global_step', [], initializer=tf.constant_initializer(0), trainable=False)
    learning_rate = tf.train.exponential_decay(FLAGS.learning_rate, global_step, decay_steps=10000, decay_rate=0.94, staircase=True)
    # add summary
    tf.summary.scalar('learning_rate', learning_rate)
    opt = tf.train.AdamOptimizer(learning_rate)
    # opt = tf.train.MomentumOptimizer(learning_rate, 0.9)
    # split
    input_images_split = tf.split(input_images, len(gpus))
    input_score_maps_split_nrow = tf.split(input_score_maps_nrow, len(gpus))
    input_score_maps_split_ncol = tf.split(input_score_maps_ncol, len(gpus))
    input_score_maps_split_row = tf.split(input_score_maps_row, len(gpus))
    input_score_maps_split_col = tf.split(input_score_maps_col, len(gpus))
    input_training_masks_split = tf.split(input_training_masks, len(gpus))

    tower_grads = []
    reuse_variables = None
    for i, gpu_id in enumerate(gpus):
        with tf.device('/gpu:%d' % gpu_id):
            with tf.name_scope('model_%d' % gpu_id) as scope:
                iis = input_images_split[i]
                isms_nrow = input_score_maps_split_nrow[i]
                isms_ncol = input_score_maps_split_ncol[i]
                isms_row = input_score_maps_split_row[i]
                isms_col = input_score_maps_split_col[i]
                itms = input_training_masks_split[i]
                # total_loss, model_loss = tower_loss(iis, isms, igms, itms, reuse_variables)
                total_loss, model_loss = tower_loss(iis, isms_nrow,
                                                    isms_ncol, isms_row,
                                                    isms_col, itms, reuse_variables)
                batch_norm_updates_op = tf.group(*tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope))
                reuse_variables = True
                grads = opt.compute_gradients(total_loss)
                tower_grads.append(grads)

    grads = average_gradients(tower_grads)
    apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)
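    # train_op (defined just below) groups three side effects behind one no-op:
    # the gradient application above, the EMA update of the trainable variables,
    # and the batch-norm statistics updates gathered from UPDATE_OPS in the loop.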
    summary_op = tf.summary.merge_all()
    # save moving average
    variable_averages = tf.train.ExponentialMovingAverage(
        FLAGS.moving_average_decay, global_step)
    variables_averages_op = variable_averages.apply(tf.trainable_variables())
    # batch norm updates
    with tf.control_dependencies([variables_averages_op, apply_gradient_op, batch_norm_updates_op]):
        train_op = tf.no_op(name='train_op')

    saver = tf.train.Saver(tf.global_variables())
    summary_writer = tf.summary.FileWriter(FLAGS.checkpoint_path, tf.get_default_graph())

    init = tf.global_variables_initializer()

    if FLAGS.pretrained_model_path is not None:
        variable_restore_op = slim.assign_from_checkpoint_fn(FLAGS.pretrained_model_path,
                                                             slim.get_trainable_variables(),
                                                             ignore_missing_vars=True)
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        if FLAGS.restore:
            print('continue training from previous checkpoint')
            ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
            saver.restore(sess, ckpt)
        else:
            sess.run(init)
            if FLAGS.pretrained_model_path is not None:
                variable_restore_op(sess)

        data_generator = dataf.get_batch(num_workers=FLAGS.num_readers,
                                         input_size=FLAGS.input_size,
                                         batch_size=FLAGS.batch_size_per_gpu * len(gpus))
        start = time.time()
        for step in range(FLAGS.max_steps):
            data = next(data_generator)
            ml, tl, _ = sess.run([model_loss, total_loss, train_op],
                                 feed_dict={input_images: data[0],
                                            input_score_maps_nrow: data[2],
                                            input_score_maps_ncol: data[3],
                                            input_score_maps_row: data[4],
                                            input_score_maps_col: data[5],
                                            input_training_masks: data[6]})
            if np.isnan(tl):
                print('Loss diverged, stop training')
                break

            if step % 10 == 0:
                avg_time_per_step = (time.time() - start) / 10
                avg_examples_per_second = (10 * FLAGS.batch_size_per_gpu * len(gpus)) / (time.time() - start)
                start = time.time()
                print('Step {:06d}, model loss {:.4f}, total loss {:.4f}, {:.2f} seconds/step, {:.2f} examples/second'.format(
                    step, ml, tl, avg_time_per_step, avg_examples_per_second))

            if step % FLAGS.save_checkpoint_steps == 0:
                saver.save(sess, FLAGS.checkpoint_path + 'model.ckpt', global_step=global_step)

            if step % FLAGS.save_summary_steps == 0:
                _, tl, summary_str = sess.run([train_op, total_loss, summary_op],
                                              feed_dict={input_images: data[0],
                                                         input_score_maps_nrow: data[2],
                                                         input_score_maps_ncol: data[3],
                                                         input_score_maps_row: data[4],
                                                         input_score_maps_col: data[5],
                                                         input_training_masks: data[6]})
                summary_writer.add_summary(summary_str, global_step=step)
if __name__ == '__main__':
    tf.app.run()
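# Example invocation (a sketch; flag values are illustrative, and dataf.get_batch
# is assumed to yield tuples where index 0 holds the images and indices 2-6 the
# four ground-truth score maps plus the training mask, matching the feed_dict above):
#
#   python train.py --gpu_list=0,1 --batch_size_per_gpu=3 --checkpoint_path=./model/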