From 93e80e6aeb0357c41c0d42eb251ce9a62083f77e Mon Sep 17 00:00:00 2001
From: 黎木林 <762129126@qq.com>
Date: Fri, 14 Jul 2023 02:09:34 +0000
Subject: [PATCH] [ADD] Add performance instrumentation logging. [UPDATE]
 Change the checkpoint save frequency from once every 1000 steps to once
 every 100000 steps, to reduce the impact on end-to-end performance.
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../run_squad.py | 81 ++++++++++++++++++-
 1 file changed, 78 insertions(+), 3 deletions(-)

diff --git a/TensorFlow/built-in/nlp/BertGoogle_Series_for_TensorFlow/run_squad.py b/TensorFlow/built-in/nlp/BertGoogle_Series_for_TensorFlow/run_squad.py
index 8c70b73c9..af51680a7 100644
--- a/TensorFlow/built-in/nlp/BertGoogle_Series_for_TensorFlow/run_squad.py
+++ b/TensorFlow/built-in/nlp/BertGoogle_Series_for_TensorFlow/run_squad.py
@@ -41,6 +41,7 @@ import modeling
 import optimization
 import tokenization
 import six
+import time
 import tensorflow as tf
 from npu_bridge.estimator.npu.npu_config import NPURunConfig
 from npu_bridge.estimator import npu_ops
@@ -50,6 +51,80 @@ flags = tf.flags
 
 FLAGS = flags.FLAGS
 
+class _TrainHook(tf.train.SessionRunHook):
+  """Logs step time and throughput during training."""
+  def __init__(self, num_train_steps):
+    self.num_train_steps = num_train_steps / FLAGS.iterations_per_loop
+
+  def after_create_session(self, session, coord):
+    self.init_time = time.time()
+    self.hist_time = 0
+    self.hist_fps = 0
+    self.hist_samples = 0
+    self.epoch1_time = 0
+    self.epoch_fps = 0
+    self.epoch1_samples = 0
+    self.epoch2_time = 0
+    self.epoch2_samples = 0
+
+  def begin(self):
+    self._step = 0
+    self._start_time = time.time()
+
+  def before_run(self, run_context):
+    self._step += 1
+    self._start_time = time.time()
+
+  def after_run(self, run_context, run_values):
+    duration = time.time() - self._start_time
+    examples_per_sec = FLAGS.train_batch_size * rank_size * FLAGS.iterations_per_loop / duration
+    if self._step <= self.num_train_steps / 2:
+      self.epoch1_time += duration
+      self.epoch_time = self.epoch1_time
+      self.epoch1_samples += FLAGS.train_batch_size * FLAGS.iterations_per_loop
+      self.epoch_samples = self.epoch1_samples
+    else:
+      self.epoch2_time += duration
+      self.epoch_time = self.epoch2_time
+      self.epoch2_samples += FLAGS.train_batch_size * FLAGS.iterations_per_loop
+      self.epoch_samples = self.epoch2_samples
+    epoch = self._step / (self.num_train_steps / 2) + 1
+    self.epoch_fps = self.epoch_samples / self.epoch_time * rank_size
+    self.hist_samples += FLAGS.train_batch_size * FLAGS.iterations_per_loop
+    self.hist_time += duration
+    self.hist_fps = self.hist_samples * rank_size / self.hist_time
+    print('epoch:%d, step:%d, examples/sec:%.1f, time:%.3f, epoch_time:%.3f, epoch_fps:%.1f, hist_samples:%.1f, hist_time:%.3f, hist_fps:%.1f' % (epoch, self._step * FLAGS.iterations_per_loop,
+          examples_per_sec, duration, self.epoch_time, self.epoch_fps, self.hist_samples, self.hist_time, self.hist_fps))
+
+class _EvalHook(tf.train.SessionRunHook):
+  """Logs step time and throughput during prediction."""
+  def __init__(self, samples_num):
+    self.samples_num = samples_num
+
+  def after_create_session(self, session, coord):
+    self.init_time = time.time()
+    self.hist_time = 0
+    self.hist_fps = 0
+    self.hist_samples = 0
+
+
+  def begin(self):
+    self._step = -1
+    self._start_time = time.time()
+
+  def before_run(self, run_context):
+    self._step += 1
+    self._start_time = time.time()
+
+  def after_run(self, run_context, run_values):
+    duration = time.time() - self._start_time
+    examples_per_sec = FLAGS.predict_batch_size / duration
+    self.hist_samples += FLAGS.predict_batch_size
+    self.hist_time += duration
+    self.hist_fps = self.hist_samples / self.hist_time
+    print('predict: step:%d, examples/sec:%.1f, time:%.3f, hist_samples:%.1f, hist_time:%.3f, hist_fps:%.1f' % (self._step,
+          examples_per_sec, duration, self.hist_samples, self.hist_time, self.hist_fps))
+
 rank_size = int(os.getenv("RANK_SIZE"))
 rank_id = int(os.getenv("RANK_ID"))
 
@@ -118,7 +193,7 @@ flags.DEFINE_float(
     "Proportion of training to perform linear learning rate warmup for. "
     "E.g., 0.1 = 10% of training.")
 
-flags.DEFINE_integer("save_checkpoints_steps", 1000,
+flags.DEFINE_integer("save_checkpoints_steps", 100000,
                      "How often to save the model checkpoint.")
 
 flags.DEFINE_integer("num_train_steps", 0,
@@ -1234,7 +1309,7 @@ def main(_):
         seq_length=FLAGS.max_seq_length,
         is_training=True,
         drop_remainder=True)
-    estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)
+    estimator.train(input_fn=train_input_fn, max_steps=num_train_steps, hooks=[_TrainHook(num_train_steps)])
   else:
     if FLAGS.do_train:
       train_examples = read_squad_examples(
@@ -1342,7 +1417,7 @@ def main(_):
     # steps.
     all_results = []
     for result in estimator.predict(
-        predict_input_fn, yield_single_examples=True):
+        predict_input_fn, yield_single_examples=True, hooks=[_EvalHook(len(eval_examples))]):
       if len(all_results) % 1000 == 0:
         tf.logging.info("Processing example: %d" % (len(all_results)))
       unique_id = int(result["unique_ids"])
-- 
Gitee
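
Note on the mechanism the patch relies on: tf.train.SessionRunHook (TF 1.x) calls
begin() once before the session is created, then wraps every session.run() call
with before_run()/after_run(), so timing the interval between those two callbacks
yields the per-step duration. Because the NPU executes FLAGS.iterations_per_loop
device steps per session.run(), the hooks above scale the sample count by that
factor. Below is a minimal, self-contained sketch of the same pattern;
ThroughputHook and the batch size of 32 are illustrative placeholders, not part
of the patch.

import time
import tensorflow as tf  # TF 1.x API, as used by run_squad.py

class ThroughputHook(tf.train.SessionRunHook):
  """Illustrative hook: prints per-step time and examples/sec."""

  def __init__(self, batch_size):
    self.batch_size = batch_size  # stand-in for FLAGS.train_batch_size

  def begin(self):
    self._step = 0  # called once, before the session exists

  def before_run(self, run_context):
    self._start_time = time.time()  # fires just before each session.run()

  def after_run(self, run_context, run_values):
    duration = time.time() - self._start_time  # fires just after session.run()
    self._step += 1
    if duration > 0:  # guard against a zero-length interval
      print('step:%d, time:%.3f, examples/sec:%.1f'
            % (self._step, duration, self.batch_size / duration))

A hook like this attaches the same way the patch attaches _TrainHook and _EvalHook:

estimator.train(input_fn=train_input_fn, max_steps=num_train_steps,
                hooks=[ThroughputHook(batch_size=32)])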