diff --git a/TensorFlow/built-in/audio/Tacotron2-v1-Encoder_ID1997_for_TensorFlow/data_load.py b/TensorFlow/built-in/audio/Tacotron2-v1-Encoder_ID1997_for_TensorFlow/data_load.py index 32bf11ca3198d89764c97bb27cf1c92f5cce3f17..866a10616ec7bfb16d3e4ca94ca6e46be9385688 100644 --- a/TensorFlow/built-in/audio/Tacotron2-v1-Encoder_ID1997_for_TensorFlow/data_load.py +++ b/TensorFlow/built-in/audio/Tacotron2-v1-Encoder_ID1997_for_TensorFlow/data_load.py @@ -385,7 +385,10 @@ def make_dataset(config, train_form): num_parallel_calls=14) ds = ds.prefetch(buffer_size=tf.contrib.data.AUTOTUNE).map(reshape_fn_with_mags) - ds = ds.batch(config.batch_size, drop_remainder=True) + if config.dynamic_bs: + ds = ds.batch(config.batch_size) + else: + ds = ds.batch(config.batch_size, drop_remainder=True) return ds, num_batch diff --git a/TensorFlow/built-in/audio/Tacotron2-v1-Encoder_ID1997_for_TensorFlow/modules.py b/TensorFlow/built-in/audio/Tacotron2-v1-Encoder_ID1997_for_TensorFlow/modules.py index b5fa2a726d2735aef14f3ce510f5c8d4d7f0b8b4..019e0c8a50ff02a98e0a41b5d0ab979d0950a638 100644 --- a/TensorFlow/built-in/audio/Tacotron2-v1-Encoder_ID1997_for_TensorFlow/modules.py +++ b/TensorFlow/built-in/audio/Tacotron2-v1-Encoder_ID1997_for_TensorFlow/modules.py @@ -38,6 +38,7 @@ import numpy as np from zoneout_LSTM import ZoneoutLSTMCell from tensorflow.contrib.rnn import LSTMStateTuple from npu_bridge.estimator import npu_ops +from npu_bridge.estimator.npu.npu_dynamic_rnn import DynamicRNN def embed(inputs, vocab_size, num_units, zero_pad=False, scope="embedding", reuse=None): @@ -179,7 +180,7 @@ def fullyconnected(inputs, is_training, layer_size, activation,scope='fc',reuse= return output -def bidirectional_LSTM(inputs, scope, training): +def bidirectional_LSTM_ori(inputs, scope, training): with tf.variable_scope(scope): outputs, (fw_state, bw_state) = tf.nn.bidirectional_dynamic_rnn( @@ -199,8 +200,25 @@ def bidirectional_LSTM(inputs, scope, training): return tf.concat(outputs, axis=2), final_state # Concat forward + backward outputs and final states +def bidirectional_LSTM(inputs, scope, training): -def unidirectional_LSTM(is_training, layers, size): + with tf.variable_scope(scope): + inputs_data = tf.transpose(inputs, perm=[1, 0, 2], name="transpose_inputdata") + fw_cell = DynamicRNN(hidden_size=hp.enc_units, forget_bias=1.0, dtype=tf.float32) + fw_y, fw_output_h, fw_output_c, i, j, f, o, tanhc = fw_cell(inputs_data) + bw_cell = DynamicRNN(hidden_size=hp.enc_units, forget_bias=1.0, dtype=tf.float32) + bw_y, bw_output_h, bw_output_c, i, j, f, o, tanhc = bw_cell(tf.reverse(inputs_data, axis=[0])) + output_rnn = tf.concat((fw_y, tf.reverse(bw_y, axis=[0])), axis=2) + output = tf.transpose(output_rnn, perm=[1, 0, 2], name="transpose_outputdata") + + encoder_final_state_c = tf.concat((fw_output_c, bw_output_c), 1) + encoder_final_state_h = tf.concat((fw_output_h, bw_output_h), 1) + final_state = LSTMStateTuple(c=encoder_final_state_c, h=encoder_final_state_h) + + return output, final_state # Concat forward + backward outputs and final states + + +def unidirectional_LSTM_ori(is_training, layers, size): # rnn_layers = [tf.nn.rnn_cell.LSTMCell(hp.enc_units) for i in range(layers)] @@ -209,4 +227,12 @@ def unidirectional_LSTM(is_training, layers, size): ext_proj=hp.n_mels) for i in range(layers)] stacked_LSTM_Cell = tf.nn.rnn_cell.MultiRNNCell(rnn_layers) - return stacked_LSTM_Cell \ No newline at end of file + return stacked_LSTM_Cell + + +def unidirectional_LSTM(is_training, layers, size): + + rnn_layers = [tf.nn.rnn_cell.LSTMCell(hp.enc_units) for i in range(layers)] + + stacked_LSTM_Cell = tf.nn.rnn_cell.MultiRNNCell(rnn_layers) + return stacked_LSTM_Cell diff --git a/TensorFlow/built-in/audio/Tacotron2-v1-Encoder_ID1997_for_TensorFlow/networks.py b/TensorFlow/built-in/audio/Tacotron2-v1-Encoder_ID1997_for_TensorFlow/networks.py index a279e1caac13b22177b15b64f8f20d26722caf0e..5602cacf442ad0e116a321914cc6d86c37d5fded 100644 --- a/TensorFlow/built-in/audio/Tacotron2-v1-Encoder_ID1997_for_TensorFlow/networks.py +++ b/TensorFlow/built-in/audio/Tacotron2-v1-Encoder_ID1997_for_TensorFlow/networks.py @@ -35,6 +35,7 @@ from __future__ import print_function from hyperparams import Hyperparams as hp from modules import * import tensorflow as tf +from npu_bridge.estimator.npu.npu_dynamic_rnn import DynamicRNN from rnn_wrappers import TacotronDecoderWrapper from attention_wrapper import AttentionWrapper, LocationBasedAttention, BahdanauAttention #from helpers import TacoTrainingHelper, TacoTestHelper @@ -64,8 +65,8 @@ def encoder(inputs, training=True, scope="encoder", reuse=None): return output -def decoder(mel_targets, encoder_output, scope="decoder", training=True, reuse=None): - batch_size = mel_targets.shape[0] +def decoder(mel_targets, encoder_output, batch_size, scope="decoder", training=True, reuse=None): + print("mel_targets shape:{}, encoder_output shape:{}".format(mel_targets ,encoder_output)) with tf.variable_scope(scope, reuse=reuse): decoder_cell = TacotronDecoderWrapper(unidirectional_LSTM(training, layers=hp.dec_LSTM_layers, size=hp.dec_LSTM_size), training) attention_decoder = AttentionWrapper( @@ -141,7 +142,45 @@ def decoder(mel_targets, encoder_output, scope="decoder", training=True, reuse=N return mel_logits, final_projection, done_output, final_decoder_state, concat_LSTM_att,step + def converter(inputs, training=True, scope="converter", reuse=None): + with tf.variable_scope(scope, reuse=reuse): + + with tf.variable_scope("converter_rnn"): + inputs_data = tf.transpose(inputs, perm=[1, 0, 2], name="converter_transpose_inputdata") + fw_cell = DynamicRNN(hidden_size=hp.n_mels, forget_bias=1.0, dtype=tf.float32) + # first rnn + output_fw, fw_output_h, fw_output_c, i, j, f, o, tanhc = fw_cell(inputs_data) + inputs_data = (inputs_data + output_fw) * tf.sqrt(0.5) + # second rnn + output_fw_sec, fw_output_h, fw_output_c, i, j, f, o, tanhc = fw_cell(inputs_data) + inputs_data = (inputs_data + output_fw_sec) * tf.sqrt(0.5) + + output_rnn = tf.transpose(inputs_data, perm=[1, 0, 2], name="converter_transpose_outputdata") + if hp.print_shapes: print(output_rnn) + + with tf.variable_scope("converter_conv"): + for i in range(hp.converter_layers): + outputs = conv_block(inputs, + size=hp.converter_filter_size, + rate=2 ** i, + padding="SAME", + training=training, + scope="converter_conv_{}".format(i)) + inputs = (inputs + outputs) * tf.sqrt(0.5) + output_conv = inputs + if hp.print_shapes: print(output_conv) + + inputs = (output_rnn + output_conv) * tf.sqrt(0.5) + if hp.print_shapes: print(inputs) + + with tf.variable_scope("mag_logits"): + mag_logits = fc_block(inputs, hp.n_fft // 2 + 1, training=training) + if hp.print_shapes: print(mag_logits) + + return mag_logits + +def converter_ori(inputs, training=True, scope="converter", reuse=None): with tf.variable_scope(scope, reuse=reuse): diff --git a/TensorFlow/built-in/audio/Tacotron2-v1-Encoder_ID1997_for_TensorFlow/test/train_full_1p.sh b/TensorFlow/built-in/audio/Tacotron2-v1-Encoder_ID1997_for_TensorFlow/test/train_full_1p.sh index 22f42099fce202df17b0a3e8b81c7e6f26537966..1112f2e8ecd8d13c8cf8d424f20e46f9bd07a939 100644 --- a/TensorFlow/built-in/audio/Tacotron2-v1-Encoder_ID1997_for_TensorFlow/test/train_full_1p.sh +++ b/TensorFlow/built-in/audio/Tacotron2-v1-Encoder_ID1997_for_TensorFlow/test/train_full_1p.sh @@ -52,6 +52,7 @@ if [[ $1 == --help || $1 == -h ]];then --modeldir model dir --save_interval save interval for ckpt --loss_scale enable loss scale ,default is False + --dynamic_bs dynamic batch size, default is False -h/--help show help message " exit 1 @@ -99,6 +100,8 @@ do save_interval=`echo ${para#*=}` elif [[ $para == --loss_scale* ]];then loss_scale=`echo ${para#*=}` + elif [[ $para == --dynamic_bs* ]];then + dynamic_bs=`echo ${para#*=}` fi done @@ -134,6 +137,7 @@ do --data_paths=${data_path} \ --epoch=${train_epochs} \ --batch_size=${batch_size} \ + --dynamic_bs=${dynamic_bs} \ --precision_mode=${precision_mode} \ --over_dump=${over_dump} \ --over_dump_path=${over_dump_path} \ diff --git a/TensorFlow/built-in/audio/Tacotron2-v1-Encoder_ID1997_for_TensorFlow/test/train_performance_1p.sh b/TensorFlow/built-in/audio/Tacotron2-v1-Encoder_ID1997_for_TensorFlow/test/train_performance_1p.sh index eed4f6f5c0f0bc5706cda8811a0e82dbde7db253..f9d8fff43a4305d5a9fed1765ec3bf2ff89f6cdb 100644 --- a/TensorFlow/built-in/audio/Tacotron2-v1-Encoder_ID1997_for_TensorFlow/test/train_performance_1p.sh +++ b/TensorFlow/built-in/audio/Tacotron2-v1-Encoder_ID1997_for_TensorFlow/test/train_performance_1p.sh @@ -53,6 +53,7 @@ if [[ $1 == --help || $1 == -h ]];then --modeldir model dir --save_interval save interval for ckpt --loss_scale enable loss scale ,default is False + --dynamic_bs dynamic batch size, default is False -h/--help show help message " exit 1 @@ -100,6 +101,8 @@ do save_interval=`echo ${para#*=}` elif [[ $para == --loss_scale* ]];then loss_scale=`echo ${para#*=}` + elif [[ $para == --dynamic_bs* ]];then + dynamic_bs=`echo ${para#*=}` fi done @@ -136,6 +139,7 @@ do --epoch=${train_epochs} \ --num_iterations=${train_steps} \ --batch_size=${batch_size} \ + --dynamic_bs=${dynamic_bs} \ --precision_mode=${precision_mode} \ --over_dump=${over_dump} \ --over_dump_path=${over_dump_path} \ diff --git a/TensorFlow/built-in/audio/Tacotron2-v1-Encoder_ID1997_for_TensorFlow/test/train_performance_bs128_1p.sh b/TensorFlow/built-in/audio/Tacotron2-v1-Encoder_ID1997_for_TensorFlow/test/train_performance_bs128_1p.sh new file mode 100644 index 0000000000000000000000000000000000000000..18db79c590538759e0ed24f7dad2aa084fed47e2 --- /dev/null +++ b/TensorFlow/built-in/audio/Tacotron2-v1-Encoder_ID1997_for_TensorFlow/test/train_performance_bs128_1p.sh @@ -0,0 +1,203 @@ +#!/bin/bash +#当前路径,不需要修改 +cur_path=`pwd` + +#集合通信参数,不需要修改 + +export RANK_SIZE=1 +export JOB_ID=10087 + +RANK_ID_START=0 + + +# 数据集路径,保持为空,不需要修改 +data_path="" + + +#基础参数,需要模型审视修改 +#网络名称,同目录名称 +Network="Tacotron2-v1-Encoder_ID1997_for_TensorFlow" +#训练epoch +train_epochs=2 +#训练step +train_steps=100 +#训练batch_size +batch_size=128 +#学习率 +learning_rate=1e-3 + + +#维测参数,precision_mode需要模型审视修改 +precision_mode="allow_mix_precision" +#维持参数,以下不需要修改 +over_dump=False +data_dump_flag=False +data_dump_step="10" +profiling=False +autotune=False + +# 帮助信息,不需要修改 +if [[ $1 == --help || $1 == -h ]];then + echo "usage:./train_performance_1p.sh " + echo " " + echo "parameter explain: + --precision_mode precision mode(allow_fp32_to_fp16/force_fp16/must_keep_origin_dtype/allow_mix_precision) + --over_dump if or not over detection, default is False + --data_dump_flag data dump flag, default is False + --data_dump_step data dump step, default is 10 + --profiling if or not profiling for performance debug, default is False + --data_path source data of training + --max_step # of step for training + --learning_rate learning rate + --batch batch size + --modeldir model dir + --save_interval save interval for ckpt + --loss_scale enable loss scale ,default is False + --dynamic_bs dynamic batch size, default is False + -h/--help show help message + " + exit 1 +fi + +#参数校验,不需要修改 +for para in $* +do + if [[ $para == --precision_mode* ]];then + precision_mode=`echo ${para#*=}` + elif [[ $para == --over_dump* ]];then + over_dump=`echo ${para#*=}` + over_dump_path=${cur_path}/output/overflow_dump + mkdir -p ${over_dump_path} + elif [[ $para == --data_dump_flag* ]];then + data_dump_flag=`echo ${para#*=}` + data_dump_path=${cur_path}/output/data_dump + mkdir -p ${data_dump_path} + elif [[ $para == --data_dump_step* ]];then + data_dump_step=`echo ${para#*=}` + elif [[ $para == --profiling* ]];then + profiling=`echo ${para#*=}` + profiling_dump_path=${cur_path}/output/profiling + mkdir -p ${profiling_dump_path} + elif [[ $para == --autotune* ]];then + autotune=`echo ${para#*=}` + mv $install_path/fwkacllib/data/rl/Ascend910/custom $install_path/fwkacllib/data/rl/Ascend910/custom_bak + mv $install_path/fwkacllib/data/tiling/Ascend910/custom $install_path/fwkacllib/data/tiling/Ascend910/custom_bak + autotune_dump_path=${cur_path}/output/autotune_dump + mkdir -p ${autotune_dump_path}/GA + mkdir -p ${autotune_dump_path}/rl + cp -rf $install_path/fwkacllib/data/tiling/Ascend910/custom ${autotune_dump_path}/GA/ + cp -rf $install_path/fwkacllib/data/rl/Ascend910/custom ${autotune_dump_path}/RL/ + elif [[ $para == --data_path* ]];then + data_path=`echo ${para#*=}` + elif [[ $para == --max_step* ]];then + train_steps=`echo ${para#*=}` + elif [[ $para == --learning_rate* ]];then + learning_rate=`echo ${para#*=}` + elif [[ $para == --batch* ]];then + batch_size=`echo ${para#*=}` + elif [[ $para == --modeldir* ]];then + modeldir=`echo ${para#*=}` + elif [[ $para == --save_interval* ]];then + save_interval=`echo ${para#*=}` + elif [[ $para == --loss_scale* ]];then + loss_scale=`echo ${para#*=}` + elif [[ $para == --dynamic_bs* ]];then + dynamic_bs=`echo ${para#*=}` + fi +done + +#校验是否传入data_path,不需要修改 +if [[ $data_path == "" ]];then + echo "[Error] para \"data_path\" must be config" + exit 1 +fi + +#训练开始时间,不需要修改 +start_time=$(date +%s) +cd $cur_path/../ +#进入训练脚本目录,需要模型审视修改 +for((RANK_ID=$RANK_ID_START;RANK_ID<$((RANK_SIZE+RANK_ID_START));RANK_ID++)); +do + #设置环境变量,不需要修改 + echo "Device ID: $ASCEND_DEVICE_ID" + export RANK_ID=$RANK_ID + + #创建DeviceID输出目录,不需要修改 + if [ -d ${cur_path}/output/${ASCEND_DEVICE_ID} ];then + rm -rf ${cur_path}/output/${ASCEND_DEVICE_ID} + mkdir -p ${cur_path}/output/$ASCEND_DEVICE_ID/ckpt + else + mkdir -p ${cur_path}/output/$ASCEND_DEVICE_ID/ckpt + fi + + #执行训练脚本,以下传参不需要修改,其他需要模型审视修改 + #--data_path, --model_dir, --precision_mode, --precision_mode, --over_dump, --over_dump_path,--data_dump_flag,--data_dump_step,--data_dump_path,--profiling,--profiling_dump_path,--autotune + nohup python3 train.py \ + --log_dir=${cur_path}/output \ + --log_name="test" \ + --data_paths=${data_path} \ + --epoch=${train_epochs} \ + --num_iterations=${train_steps} \ + --batch_size=${batch_size} \ + --dynamic_bs=${dynamic_bs} \ + --precision_mode=${precision_mode} \ + --over_dump=${over_dump} \ + --over_dump_path=${over_dump_path} \ + --data_dump_flag=${data_dump_flag} \ + --data_dump_step=${data_dump_step} \ + --data_dump_path=${data_dump_path} \ + --profiling=${profiling} \ + --profiling_dump_path=${profiling_dump_path} \ + --deltree=True > ${cur_path}/output/${ASCEND_DEVICE_ID}/train_${ASCEND_DEVICE_ID}.log 2>&1 & +if [ $? -ne 0 ];then + exit 1 +fi +done +wait + +#训练结束时间,不需要修改 +end_time=$(date +%s) +e2e_time=$(( $end_time - $start_time )) + +#结果打印,不需要修改 +echo "------------------ Final result ------------------" +#输出性能FPS,需要模型审视修改 +FPS_tmp=`grep "perf = " $cur_path/output/$ASCEND_DEVICE_ID/train_$ASCEND_DEVICE_ID.log | tail -1 | awk -F "=" '{print $3}' | sed -e 's/^[ ]*//g' | sed -e 's/[ ]*$//g'` +FPS=`awk 'BEGIN {printf "%.2f\n", '${batch_size}'/'${FPS_tmp}'}'` + +TrainingTime=`awk 'BEGIN {printf "%.2f\n", '1000'*'${batch_size}'/'${FPS}'}'` +#打印,不需要修改 +echo "Final Performance images/sec : $FPS" + +#输出训练精度,需要模型审视修改 +#train_accuracy=`grep "Accuracy:" $cur_path/output/$ASCEND_DEVICE_ID/train_$ASCEND_DEVICE_ID.log | awk -F " " '{print $2}'` +#打印,不需要修改 +#echo "Final Train Accuracy : ${train_accuracy}" +echo "E2E Training Duration sec : $e2e_time" + +#性能看护结果汇总 +#训练用例信息,不需要修改 +BatchSize=${batch_size} +DeviceType=`uname -m` +CaseName=${Network}_bs${BatchSize}_${RANK_SIZE}'p'_'perf' + +##获取性能数据,不需要修改 +#吞吐量 +ActualFPS=${FPS} + +#从train_$ASCEND_DEVICE_ID.log提取Loss到train_${CaseName}_loss.txt中,需要根据模型审视 +grep 'Loss =' $cur_path/output/$ASCEND_DEVICE_ID/train_$ASCEND_DEVICE_ID.log|awk '{print $6}' > $cur_path/output/$ASCEND_DEVICE_ID/train_${CaseName}_loss.txt + +#最后一个迭代loss值,不需要修改 +ActualLoss=`awk 'END {print}' $cur_path/output/$ASCEND_DEVICE_ID/train_${CaseName}_loss.txt` + +#关键信息打印到${CaseName}.log中,不需要修改 +echo "Network = ${Network}" > $cur_path/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "RankSize = ${RANK_SIZE}" >> $cur_path/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "BatchSize = ${BatchSize}" >> $cur_path/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "DeviceType = ${DeviceType}" >> $cur_path/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "CaseName = ${CaseName}" >> $cur_path/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "ActualFPS = ${ActualFPS}" >> $cur_path/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "TrainingTime = ${TrainingTime}" >> $cur_path/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "ActualLoss = ${ActualLoss}" >> $cur_path/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "E2ETrainingTime = ${e2e_time}" >> $cur_path/output/$ASCEND_DEVICE_ID/${CaseName}.log \ No newline at end of file diff --git a/TensorFlow/built-in/audio/Tacotron2-v1-Encoder_ID1997_for_TensorFlow/train.py b/TensorFlow/built-in/audio/Tacotron2-v1-Encoder_ID1997_for_TensorFlow/train.py index 554426f5c962f98826b79e12ee1147bb98190230..d65b4a85de1cb73817d00554e4f9879a213038ba 100644 --- a/TensorFlow/built-in/audio/Tacotron2-v1-Encoder_ID1997_for_TensorFlow/train.py +++ b/TensorFlow/built-in/audio/Tacotron2-v1-Encoder_ID1997_for_TensorFlow/train.py @@ -75,7 +75,8 @@ class Graph(): with tf.variable_scope('encoder'): self.encoded = encoder(self.x, training=training) with tf.variable_scope('decoder'): - (self.mel_output, self.bef_mel_output, self.done_output, self.decoder_state, self.LTSM, self.step) = decoder(self.y1, self.encoded, training=training) + (self.mel_output, self.bef_mel_output, self.done_output, self.decoder_state, self.LTSM, self.step) = \ + decoder(self.y1, self.encoded, config.batch_size, training=training) self.cell_state = self.decoder_state.cell_state self.mel_output = tf.nn.sigmoid(self.mel_output) if (train_form == 'Both'): @@ -121,7 +122,8 @@ class Graph(): # incr_every_n_steps=1000, # decr_every_n_nan_or_inf=2, # decr_ratio=0.5) - loss_scale_manager = FixedLossScaleManager(loss_scale=1024) + # overflow also update for dynamic_bs + loss_scale_manager = FixedLossScaleManager(loss_scale=1024, enable_overflow_check=False) self.optimizer = NPULossScaleOptimizer(opt_tmp, loss_scale_manager) self.gvs = self.optimizer.compute_gradients(self.loss) @@ -183,6 +185,7 @@ def main(): parser.add_argument('--data_dump_step', type=str, default='10') parser.add_argument('--over_dump_path', type=str, default='/home/data') parser.add_argument('--profiling_dump_path', type=str, default='/home/data') + parser.add_argument('--dynamic_bs', default=False) config = parser.parse_args() config.log_dir = ((config.log_dir + '/') + config.log_name) @@ -200,7 +203,16 @@ def main(): custom_op = session_config.graph_options.rewrite_options.custom_optimizers.add() custom_op.name = "NpuOptimizer" custom_op.parameter_map["precision_mode"].s = tf.compat.as_bytes(config.precision_mode) - custom_op.parameter_map["enable_data_pre_proc"].b = True + if config.dynamic_bs: + custom_op.parameter_map["dynamic_input"].b = True + custom_op.parameter_map["dynamic_graph_execute_mode"].s = tf.compat.as_bytes("dynamic_execute") + custom_op.parameter_map["dynamic_inputs_shape_range"].s = tf.compat.as_bytes("data:[1~200,200],[1~200,810,80],[1~200,810,1025]") + # add for bs=128/160 + custom_op.parameter_map["graph_memory_max_size"].s = tf.compat.as_bytes(str(30 * 1024 * 1024 * 1024)) + custom_op.parameter_map["variable_memory_max_size"].s = tf.compat.as_bytes(str(1 * 1024 * 1024 * 1024)) + else: + custom_op.parameter_map["enable_data_pre_proc"].b = True + if config.data_dump_flag == 'True': custom_op.parameter_map["enable_dump"].b = True custom_op.parameter_map["dump_path"].s = tf.compat.as_bytes(config.data_dump_path) @@ -259,12 +271,12 @@ def main(): if (hp.test_only == 0): with tf.Session(config=npu_config) as sess: if config.load_path: - infolog.log(('Resuming from checkpoint: %s ' % tf.train.latest_checkpoint(config.log_dir)), slack=True) - tf.train.Saver().restore(sess, tf.train.latest_checkpoint(config.log_dir)) + infolog.log(('Resuming from checkpoint: %s ' % tf.train.latest_checkpoint(config.log_dir)), slack=True) + tf.train.Saver().restore(sess, tf.train.latest_checkpoint(config.log_dir)) else: infolog.log('Starting new training', slack=True) - summary_writer = tf.summary.FileWriter(config.log_dir, sess.graph) - print("========After summary_writer") + # summary_writer = tf.summary.FileWriter(config.log_dir, sess.graph) + # print("========After summary_writer") sess.run(iterator.initializer) sess.run(tf.global_variables_initializer()) for epoch in range(1, config.epoch): @@ -277,8 +289,8 @@ def main(): (gs, merged, loss, loss1, loss1b, loss2, loss3, _) = sess.run(fetches=fetch) loss_one = [loss, loss1, loss1b, loss2, loss3] else: - fetch = [g.global_step, g.merged, g.loss, g.loss1, g.loss1b, g.loss3, g.train_op] - (gs, merged, loss, loss1, loss1b, loss3, _) = sess.run(fetches=fetch) + fetch = [g.global_step, g.loss, g.loss1, g.loss1b, g.loss3, g.train_op] + (gs, loss, loss1, loss1b, loss3, _) = sess.run(fetches=fetch) loss_one = [loss, loss1, loss1b, loss3, 0] elif (hp.train_form == 'Encoder'): if hp.include_dones: @@ -286,8 +298,8 @@ def main(): (gs, merged, loss, loss1, loss1b, loss2, _) = sess.run(fetches=fetch) loss_one = [loss, loss1, loss1b, loss2, 0] else: - fetch = [g.global_step, g.merged, g.loss, g.loss1, g.loss1b, g.train_op] - (gs, merged, loss, loss1, loss1b, _) = sess.run(fetches=fetch) + fetch = [g.global_step, g.loss, g.loss1, g.loss1b, g.train_op] + (gs, loss, loss1, loss1b, _) = sess.run(fetches=fetch) loss_one = [loss, loss1, loss1b, 0, 0] else: print("========before sess.run") @@ -315,9 +327,9 @@ def main(): pass print('###############################################################################') print("========After for num_batch") - if ((epoch % config.summary_interval) == 0): + if not config.dynamic_bs and ((epoch % config.summary_interval) == 0): infolog.log('Saving summary') - summary_writer.add_summary(merged, gs) + # summary_writer.add_summary(merged, gs) if (hp.train_form == 'Both'): if hp.include_dones: (origx, Kmel_out, Ky1, Kdone_out, Ky2, Kmag_out, Ky3) = sess.run([g.origx, g.mel_output, g.y1, g.done_output, g.y2, g.mag_output, g.y3])