import os
import tensorflow as tf
from Parameters import Parameters as pm
from data_processing import read_category, get_wordid, get_word2vec, process, batch_iter, seq_length
from Lstm_Cnn import Lstm_CNN
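# Training script for the Lstm_CNN text classifier (TensorFlow 1.x API).
# It assumes the Parameters class (pm) exposes at least the fields referenced
# below: num_epochs, batch_size, keep_prob, learning_rate, lr_decay,
# train_filename, test_filename, val_filename, vocab_filename,
# vector_word_npz, vocab_size and pre_trianing.
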
def train():
    # Directories for TensorBoard logs and model checkpoints.
    tensorboard_dir = './tensorboard/Lstm_CNN'
    save_dir = './checkpoints/Lstm_CNN'
    if not os.path.exists(tensorboard_dir):
        os.makedirs(tensorboard_dir)
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    save_path = os.path.join(save_dir, 'best_validation')
    # Scalar summaries for TensorBoard, plus the saver and session.
    tf.summary.scalar('loss', model.loss)
    tf.summary.scalar('accuracy', model.accuracy)
    merged_summary = tf.summary.merge_all()
    writer = tf.summary.FileWriter(tensorboard_dir)
    saver = tf.train.Saver()

    session = tf.Session()
    session.run(tf.global_variables_initializer())
    writer.add_graph(session.graph)
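    # The curves written by the FileWriter can be inspected with the
    # TensorBoard CLI, e.g.:  tensorboard --logdir ./tensorboard/Lstm_CNN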
    # Build fixed-length (max_length=300) word-id sequences and labels from the raw text files.
    x_train, y_train = process(pm.train_filename, wordid, cat_to_id, max_length=300)
    x_test, y_test = process(pm.test_filename, wordid, cat_to_id, max_length=300)
    for epoch in range(pm.num_epochs):
        print('Epoch:', epoch + 1)
        num_batchs = int((len(x_train) - 1) / pm.batch_size) + 1
        batch_train = batch_iter(x_train, y_train, batch_size=pm.batch_size)
        for x_batch, y_batch in batch_train:
            # Real (unpadded) length of each sample in the batch.
            real_seq_len = seq_length(x_batch)
            feed_dict = model.feed_data(x_batch, y_batch, real_seq_len, pm.keep_prob)
            _, global_step, _summary, train_loss, train_accuracy = session.run(
                [model.optimizer, model.global_step, merged_summary, model.loss, model.accuracy],
                feed_dict=feed_dict)
            if global_step % 100 == 0:
                # Periodically write summaries and evaluate on the test set.
                writer.add_summary(_summary, global_step)
                test_loss, test_accuracy = model.test(session, x_test, y_test)
                print('global_step:', global_step, 'train_loss:', train_loss, 'train_accuracy:', train_accuracy,
                      'test_loss:', test_loss, 'test_accuracy:', test_accuracy)
            if global_step % num_batchs == 0:
                # Save a checkpoint once per epoch.
                print('Saving Model...')
                saver.save(session, save_path, global_step=global_step)

        # Decay the stored learning rate after each epoch.
        pm.learning_rate *= pm.lr_decay

if __name__ == '__main__':
    filenames = [pm.train_filename, pm.test_filename, pm.val_filename]
    categories, cat_to_id = read_category()
    wordid = get_wordid(pm.vocab_filename)
    pm.vocab_size = len(wordid)
    # Attribute name kept as 'pre_trianing' (original spelling used by the model).
    pm.pre_trianing = get_word2vec(pm.vector_word_npz)
    model = Lstm_CNN()
    train()
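
# A minimal sketch (not part of the training flow above) of how a saved
# checkpoint could later be restored for inference, assuming the same graph
# definition in Lstm_Cnn.py:
#
#     model = Lstm_CNN()
#     session = tf.Session()
#     saver = tf.train.Saver()
#     saver.restore(session, tf.train.latest_checkpoint('./checkpoints/Lstm_CNN'))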