diff --git a/TensorFlow/built-in/nlp/BertGoogle_Series_for_TensorFlow/run_squad.py b/TensorFlow/built-in/nlp/BertGoogle_Series_for_TensorFlow/run_squad.py
index 7a056a808b5b7ae8fa435450cf1b9d164979de79..8c70b73c91bc29a7547406b6c652ca7d416f1a72 100644
--- a/TensorFlow/built-in/nlp/BertGoogle_Series_for_TensorFlow/run_squad.py
+++ b/TensorFlow/built-in/nlp/BertGoogle_Series_for_TensorFlow/run_squad.py
@@ -176,6 +176,8 @@ flags.DEFINE_float(
     "null_score_diff_threshold", 0.0,
     "If null_score - best_non_null is greater than the threshold predict null.")
 
+flags.DEFINE_bool("read_tf_record", False, "Whether to read training data from a pre-converted TFRecord file.")
+
 
 class SquadExample(object):
   """A single training/test example for simple sequence classification.
@@ -743,9 +745,10 @@ def input_fn_builder(input_file, seq_length, is_training, drop_remainder=True):
     # For eval, we want no shuffling and parallel reading doesn't matter.
     d = tf.data.TFRecordDataset(input_file)
     if is_training:
-      d = d.repeat()
+      # d = d.repeat()
       if rank_size > 1:
         d = d.shard(rank_size, rank_id)
+      d = d.repeat()
       d = d.shuffle(buffer_size=100)
     d = d.apply(
         tf.contrib.data.map_and_batch(
@@ -1198,71 +1201,106 @@ def main(_):
   train_examples = None
   num_train_steps = FLAGS.num_train_steps
   num_warmup_steps = None
-  if FLAGS.do_train:
-    train_examples = read_squad_examples(
-        input_file=FLAGS.train_file, is_training=True)
-
-    if num_train_steps == 0:
-      num_train_steps = int(
-          len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs)
-
-    #print("lenof train_examples = %s , num_train_epochs = %s" %(len(train_examples), FLAGS.num_train_epochs))
-    #num_train_steps = int(num_train_steps / rank_size)
-    num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)
-
-    # Pre-shuffle the input to avoid having to make a very large shuffle
-    # buffer in in the `input_fn`.
-    rng = random.Random(12345)
-    rng.shuffle(train_examples)
-
-  model_fn = model_fn_builder(
-      bert_config=bert_config,
-      init_checkpoint=FLAGS.init_checkpoint,
-      learning_rate=FLAGS.learning_rate,
-      num_train_steps=num_train_steps,
-      num_warmup_steps=num_warmup_steps,
-      use_tpu=FLAGS.use_tpu,
-      use_one_hot_embeddings=FLAGS.use_tpu)
-
-  # If TPU is not available, this will fall back to normal Estimator on CPU
-  # or GPU.
-  estimator = NPUEstimator(
-      model_fn=model_fn,
-      config=run_config,
-      model_dir=FLAGS.output_dir,
-      params={"batch_size": FLAGS.train_batch_size, "predict_batch_size": FLAGS.predict_batch_size})
-      #train_batch_size=FLAGS.train_batch_size,
-      #predict_batch_size=FLAGS.predict_batch_size)
+  if FLAGS.read_tf_record:
+    if FLAGS.do_train:
+      num_train_steps = int(87599 / FLAGS.train_batch_size * FLAGS.num_train_epochs)  # 87599 examples in the SQuAD v1.1 training set
+      num_train_steps = int(num_train_steps / rank_size)
+      num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)
 
-  if FLAGS.do_train:
-    # We write to a temporary file to avoid storing very large constant tensors
-    # in memory.
-    train_writer = FeatureWriter(
-        filename=os.path.join(FLAGS.output_dir, "train.tf_record"),
-        is_training=True)
-    convert_examples_to_features(
-        examples=train_examples,
-        tokenizer=tokenizer,
-        max_seq_length=FLAGS.max_seq_length,
-        doc_stride=FLAGS.doc_stride,
-        max_query_length=FLAGS.max_query_length,
-        is_training=True,
-        output_fn=train_writer.process_feature)
-    train_writer.close()
-
-    tf.logging.info("***** Running training *****")
-    tf.logging.info("  Num orig examples = %d", len(train_examples))
-    tf.logging.info("  Num split examples = %d", train_writer.num_features)
-    tf.logging.info("  Batch size = %d", FLAGS.train_batch_size)
-    tf.logging.info("  Num steps = %d", num_train_steps)
-    del train_examples
-
-    train_input_fn = input_fn_builder(
-        input_file=train_writer.filename,
-        seq_length=FLAGS.max_seq_length,
-        is_training=True,
-        drop_remainder=True)
-    estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)
+    model_fn = model_fn_builder(
+        bert_config=bert_config,
+        init_checkpoint=FLAGS.init_checkpoint,
+        learning_rate=FLAGS.learning_rate,
+        num_train_steps=num_train_steps,
+        num_warmup_steps=num_warmup_steps,
+        use_tpu=FLAGS.use_tpu,
+        use_one_hot_embeddings=FLAGS.use_tpu)
+
+    # If TPU is not available, this will fall back to normal Estimator on CPU
+    # or GPU.
+    estimator = NPUEstimator(
+        model_fn=model_fn,
+        config=run_config,
+        model_dir=FLAGS.output_dir,
+        params={"batch_size": FLAGS.train_batch_size, "predict_batch_size": FLAGS.predict_batch_size})
+
+    if FLAGS.do_train:
+      tf.logging.info("***** Running training *****")
+      tf.logging.info("  Batch size = %d", FLAGS.train_batch_size)
+      tf.logging.info("  Num steps = %d", num_train_steps)
+
+      train_input_fn = input_fn_builder(
+          input_file=FLAGS.train_file,
+          seq_length=FLAGS.max_seq_length,
+          is_training=True,
+          drop_remainder=True)
+      estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)
+  else:
+    if FLAGS.do_train:
+      train_examples = read_squad_examples(
+          input_file=FLAGS.train_file, is_training=True)
+
+      if num_train_steps == 0:
+        num_train_steps = int(
+            len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs)
+
+      #print("lenof train_examples = %s , num_train_epochs = %s" %(len(train_examples), FLAGS.num_train_epochs))
+      #num_train_steps = int(num_train_steps / rank_size)
+      num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)
+
+      # Pre-shuffle the input to avoid having to make a very large shuffle
+      # buffer in the `input_fn`.
+      rng = random.Random(12345)
+      rng.shuffle(train_examples)
+
+    model_fn = model_fn_builder(
+        bert_config=bert_config,
+        init_checkpoint=FLAGS.init_checkpoint,
+        learning_rate=FLAGS.learning_rate,
+        num_train_steps=num_train_steps,
+        num_warmup_steps=num_warmup_steps,
+        use_tpu=FLAGS.use_tpu,
+        use_one_hot_embeddings=FLAGS.use_tpu)
+
+    # If TPU is not available, this will fall back to normal Estimator on CPU
+    # or GPU.
+    estimator = NPUEstimator(
+        model_fn=model_fn,
+        config=run_config,
+        model_dir=FLAGS.output_dir,
+        params={"batch_size": FLAGS.train_batch_size, "predict_batch_size": FLAGS.predict_batch_size})
+        #train_batch_size=FLAGS.train_batch_size,
+        #predict_batch_size=FLAGS.predict_batch_size)
+
+    if FLAGS.do_train:
+      # We write to a temporary file to avoid storing very large constant tensors
+      # in memory.
+      train_writer = FeatureWriter(
+          filename=os.path.join(FLAGS.output_dir, "train.tf_record"),
+          is_training=True)
+      convert_examples_to_features(
+          examples=train_examples,
+          tokenizer=tokenizer,
+          max_seq_length=FLAGS.max_seq_length,
+          doc_stride=FLAGS.doc_stride,
+          max_query_length=FLAGS.max_query_length,
+          is_training=True,
+          output_fn=train_writer.process_feature)
+      train_writer.close()
+
+      tf.logging.info("***** Running training *****")
+      tf.logging.info("  Num orig examples = %d", len(train_examples))
+      tf.logging.info("  Num split examples = %d", train_writer.num_features)
+      tf.logging.info("  Batch size = %d", FLAGS.train_batch_size)
+      tf.logging.info("  Num steps = %d", num_train_steps)
+      del train_examples
+
+      train_input_fn = input_fn_builder(
+          input_file=train_writer.filename,
+          seq_length=FLAGS.max_seq_length,
+          is_training=True,
+          drop_remainder=True)
+      estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)
 
   if FLAGS.do_predict:
     eval_examples = read_squad_examples(
diff --git a/TensorFlow/built-in/nlp/BertGoogle_Series_for_TensorFlow/test/train_ID0495_Bert-Squad_full_1p.sh b/TensorFlow/built-in/nlp/BertGoogle_Series_for_TensorFlow/test/train_ID0495_Bert-Squad_full_1p.sh
index 771488ea15bc6699199b550c8a6f682f8c4fe7c8..4c5aca93b9d8023c66da8d3dcd670ad13cb6a2e6 100644
--- a/TensorFlow/built-in/nlp/BertGoogle_Series_for_TensorFlow/test/train_ID0495_Bert-Squad_full_1p.sh
+++ b/TensorFlow/built-in/nlp/BertGoogle_Series_for_TensorFlow/test/train_ID0495_Bert-Squad_full_1p.sh
@@ -75,7 +75,7 @@ fi
 vocab_file=${data_path}/model/vocab.txt
 bert_config_file=${data_path}/model/bert_config.json
 init_checkpoint=${data_path}/model/bert_model.ckpt
-train_file=${data_path}/dataset/train-v1.1.json
+train_file=${data_path}/dataset/train.tf_record
 predict_file=${data_path}/dataset/dev-v1.1.json
 
 # Training start time; no need to modify.
@@ -100,7 +100,8 @@ do
 nohup python3.7 ${parent_path}/run_squad.py \
   --vocab_file=$vocab_file \
   --bert_config_file=$bert_config_file \
   --init_checkpoint=$init_checkpoint \
+  --read_tf_record=True \
   --train_file=$train_file \
   --do_predict=True \
   --do_train=True \
diff --git a/TensorFlow/built-in/nlp/BertGoogle_Series_for_TensorFlow/test/train_ID0495_Bert-Squad_full_8p.sh b/TensorFlow/built-in/nlp/BertGoogle_Series_for_TensorFlow/test/train_ID0495_Bert-Squad_full_8p.sh
index a12eeef6d5f5d1c2f593e65235c7de3686d35d10..be1bbf0f19a81fc376439f223d6b86e51d0c598d 100644
--- a/TensorFlow/built-in/nlp/BertGoogle_Series_for_TensorFlow/test/train_ID0495_Bert-Squad_full_8p.sh
+++ b/TensorFlow/built-in/nlp/BertGoogle_Series_for_TensorFlow/test/train_ID0495_Bert-Squad_full_8p.sh
@@ -16,7 +16,7 @@ data_path=""
 # Basic parameters; review and adjust for the model.
 # Network name, same as the directory name.
 Network="Bertsquad_ID0495_for_TensorFlow"
-batch_size=32
+batch_size=16
 epoch=2
 
 # Fixed parameters; no need to modify.
@@ -76,7 +76,7 @@ fi
 vocab_file=${data_path}/model/vocab.txt
 bert_config_file=${data_path}/model/bert_config.json
 init_checkpoint=${data_path}/model/bert_model.ckpt
-train_file=${data_path}/dataset/train-v1.1.json
+train_file=${data_path}/dataset/train.tf_record
 predict_file=${data_path}/dataset/dev-v1.1.json
 
 # Training start time; no need to modify.
@@ -103,13 +103,14 @@ do
   --vocab_file=$vocab_file \
   --bert_config_file=$bert_config_file \
   --init_checkpoint=$init_checkpoint \
+  --read_tf_record=True \
   --train_file=$train_file \
   --do_predict=True \
   --do_train=True \
   --predict_file=$predict_file \
   --train_batch_size=${batch_size} \
   --num_train_epochs=${epoch} \
-  --learning_rate=3e-5 \
+  --learning_rate=5e-5 \
   --max_seq_length=384 \
   --doc_stride=128 \
--output_dir=${cur_path}/output/${ASCEND_DEVICE_ID}/ckpt > ${cur_path}/output/${ASCEND_DEVICE_ID}/train_${ASCEND_DEVICE_ID}.log 2>&1 &
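
Note on the input_fn_builder hunk above: repeat() is moved so that it runs after shard(), which matches the usual tf.data guidance to split the records across workers before any repeating or shuffling, so each rank trains on its own slice of every epoch. Below is a minimal sketch of the resulting training pipeline, with the parsing and batching steps omitted; rank_size/rank_id stand in for the distributed-rank values used in run_squad.py, and the helper name build_training_dataset is illustrative, not part of the patch.

    import tensorflow as tf

    def build_training_dataset(input_file, rank_size=1, rank_id=0):
        # Shard -> repeat -> shuffle, mirroring the patched input_fn_builder.
        d = tf.data.TFRecordDataset(input_file)
        if rank_size > 1:
            # Keep only this worker's slice of the records so the ranks read
            # disjoint data instead of each reading the whole file.
            d = d.shard(rank_size, rank_id)
        d = d.repeat()                  # repeat the per-worker shard across epochs
        d = d.shuffle(buffer_size=100)  # small shuffle buffer, as in the patch
        return d

Because both test scripts now point train_file at ${data_path}/dataset/train.tf_record, that TFRecord has to exist before training; the else branch above shows one way it can be produced, since a run without --read_tf_record converts train-v1.1.json with FeatureWriter and writes train.tf_record into the output directory.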