From 33a59a4f69f7c329b1be66ddb3cbd34b144d0ca0 Mon Sep 17 00:00:00 2001
From: lijing
Date: Mon, 11 Apr 2022 19:45:56 +0800
Subject: [PATCH] add support for eval

---
 .../nlp/Bert_Chinese_for_PyTorch/README.cn.md | 38 ++++++++++++++++++-
 .../nlp/Bert_Chinese_for_PyTorch/run_mlm.py   |  4 +-
 .../Bert_Chinese_for_PyTorch/run_mlm_cn.sh    |  4 +-
 .../Bert_Chinese_for_PyTorch/run_mlm_cn_8p.sh |  4 +-
 .../transformers/src/transformers/trainer.py  |  6 +--
 5 files changed, 49 insertions(+), 7 deletions(-)

diff --git a/PyTorch/built-in/nlp/Bert_Chinese_for_PyTorch/README.cn.md b/PyTorch/built-in/nlp/Bert_Chinese_for_PyTorch/README.cn.md
index 95e75224bf..24a950440b 100644
--- a/PyTorch/built-in/nlp/Bert_Chinese_for_PyTorch/README.cn.md
+++ b/PyTorch/built-in/nlp/Bert_Chinese_for_PyTorch/README.cn.md
@@ -30,7 +30,17 @@ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/bert-base-chinese
 
 The download creates a bert-base-chinese subdirectory in the current directory.
 
-### 4. Training
+### 4. Download the accuracy metric script
+
+Download command:
+
+```
+curl https://raw.githubusercontent.com/huggingface/datasets/master/metrics/accuracy/accuracy.py -k -o accuracy.py
+```
+
+By default this downloads accuracy.py into the current directory. If you place it elsewhere, set the **--eval_metric_path** argument to the actual path of accuracy.py.
+
+### 5. Training
 
 Set the **--train_file** argument in run_mlm_cn.sh and run_mlm_cn_8p.sh to the actual path of your Chinese text data, then launch training:
 
@@ -46,3 +56,29 @@
 bash run_mlm_cn.sh
 bash run_mlm_cn_8p.sh
 ```
+### Appendix: arguments of the single-node 8-card training script
+
+```
+python3 -m torch.distributed.launch --nproc_per_node 8 run_mlm.py \
+    --model_type bert \                            # model type
+    --config_name bert-base-chinese/config.json \  # model config file
+    --tokenizer_name bert-base-chinese \           # tokenizer path
+    --train_file ./train_huawei.txt \              # dataset path (split into train and val automatically)
+    --eval_metric_path ./accuracy.py \             # path to the accuracy metric script
+    --line_by_line \                               # treat each line of the data as one sentence
+    --pad_to_max_length \                          # pad every sample to max length
+    --remove_unused_columns false \                # whether to drop unused dataset columns
+    --save_steps 5000 \                            # checkpoint interval, in steps
+    --overwrite_output_dir \                       # overwrite the output directory
+    --per_device_train_batch_size 32 \             # per-device train batch size
+    --per_device_eval_batch_size 32 \              # per-device eval batch size
+    --do_train \                                   # run training
+    --do_eval \                                    # run evaluation
+    --fp16 \                                       # use mixed precision
+    --fp16_opt_level O2 \                          # mixed-precision level
+    --loss_scale 8192 \                            # loss scale value
+    --use_combine_grad \                           # enable combined-tensor gradient optimization
+    --optim adamw_apex_fused_npu \                 # optimizer
+    --output_dir ./output                          # output directory
+```
+
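[Editor's note on what `--eval_metric_path` feeds into: `run_mlm.py` hands the path straight to `datasets.load_metric`, which accepts a local script path as well as a hub name. A minimal sketch of that usage, assuming accuracy.py was downloaded as described in the README above; the toy predictions/references are illustrative only:]

```python
# Minimal sketch: load the locally downloaded accuracy.py instead of
# resolving the "accuracy" metric from the Hugging Face hub.
from datasets import load_metric

metric = load_metric("./accuracy.py")  # the value passed via --eval_metric_path

# Toy inputs just to show the call shape; the real inputs come from
# compute_metrics in run_mlm.py with padding positions filtered out.
result = metric.compute(predictions=[1, 0, 1], references=[1, 1, 1])
print(result)  # {'accuracy': 0.6666666666666666}
```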
diff --git a/PyTorch/built-in/nlp/Bert_Chinese_for_PyTorch/run_mlm.py b/PyTorch/built-in/nlp/Bert_Chinese_for_PyTorch/run_mlm.py
index 12d0b3668b..3f9ed43626 100644
--- a/PyTorch/built-in/nlp/Bert_Chinese_for_PyTorch/run_mlm.py
+++ b/PyTorch/built-in/nlp/Bert_Chinese_for_PyTorch/run_mlm.py
@@ -187,6 +187,7 @@ class DataTrainingArguments:
             "value if set."
         },
     )
+    eval_metric_path: Optional[str] = field(default='accuracy.py', metadata={"help": "Path to the metric processing script."})
 
     def __post_init__(self):
         if self.dataset_name is None and self.train_file is None and self.validation_file is None:
@@ -486,9 +487,10 @@ def main():
             # Depending on the model and config, logits may contain extra tensors,
             # like past_key_values, but logits always come first
             logits = logits[0]
+            logits = logits.view(labels.shape[0], labels.shape[1], -1)  # on NPU the logits are 2-dim; reshape before argmax
             return logits.argmax(dim=-1)
 
-        metric = load_metric("accuracy")
+        metric = load_metric(data_args.eval_metric_path)
 
         def compute_metrics(eval_preds):
             preds, labels = eval_preds
diff --git a/PyTorch/built-in/nlp/Bert_Chinese_for_PyTorch/run_mlm_cn.sh b/PyTorch/built-in/nlp/Bert_Chinese_for_PyTorch/run_mlm_cn.sh
index d933f30ad3..182f6db929 100644
--- a/PyTorch/built-in/nlp/Bert_Chinese_for_PyTorch/run_mlm_cn.sh
+++ b/PyTorch/built-in/nlp/Bert_Chinese_for_PyTorch/run_mlm_cn.sh
@@ -1,17 +1,19 @@
 source env.sh
-export PYTHONPATH=${pwd}/transformers/src:$PYTHONPATH
 python3 run_mlm.py \
     --model_type bert \
     --config_name bert-base-chinese/config.json \
     --tokenizer_name bert-base-chinese \
     --train_file ./train_huawei.txt \
+    --eval_metric_path ./accuracy.py \
     --line_by_line \
     --pad_to_max_length \
+    --remove_unused_columns false \
     --save_steps 5000 \
     --overwrite_output_dir \
     --per_device_train_batch_size 32 \
     --per_device_eval_batch_size 32 \
     --do_train \
+    --do_eval \
     --fp16 \
     --fp16_opt_level O2 \
     --loss_scale 8192 \
diff --git a/PyTorch/built-in/nlp/Bert_Chinese_for_PyTorch/run_mlm_cn_8p.sh b/PyTorch/built-in/nlp/Bert_Chinese_for_PyTorch/run_mlm_cn_8p.sh
index 22b073e1df..1b4381d15b 100644
--- a/PyTorch/built-in/nlp/Bert_Chinese_for_PyTorch/run_mlm_cn_8p.sh
+++ b/PyTorch/built-in/nlp/Bert_Chinese_for_PyTorch/run_mlm_cn_8p.sh
@@ -1,17 +1,19 @@
 source env.sh
-export PYTHONPATH=${pwd}/transformers/src:$PYTHONPATH
 python3 -m torch.distributed.launch --nproc_per_node 8 run_mlm.py \
     --model_type bert \
     --config_name bert-base-chinese/config.json \
     --tokenizer_name bert-base-chinese \
     --train_file ./train_huawei.txt \
+    --eval_metric_path ./accuracy.py \
     --line_by_line \
     --pad_to_max_length \
+    --remove_unused_columns false \
     --save_steps 5000 \
     --overwrite_output_dir \
     --per_device_train_batch_size 32 \
     --per_device_eval_batch_size 32 \
     --do_train \
+    --do_eval \
     --fp16 \
     --fp16_opt_level O2 \
     --loss_scale 8192 \
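[Editor's note: the `logits.view(...)` line added to run_mlm.py above is easiest to see with shapes. Below is a self-contained sketch of the patched `preprocess_logits_for_metrics`, run on CPU with dummy tensors standing in for NPU outputs; the 21128 vocab size matches bert-base-chinese, while the batch and sequence sizes are made up:]

```python
import torch

def preprocess_logits_for_metrics(logits, labels):
    # Mirrors the patched run_mlm.py: on NPU the logits arrive flattened
    # to (batch * seq_len, vocab); restore (batch, seq_len, vocab) so the
    # argmax over the vocab axis aligns token-for-token with the labels.
    if isinstance(logits, tuple):
        logits = logits[0]  # drop extras such as past_key_values
    logits = logits.view(labels.shape[0], labels.shape[1], -1)
    return logits.argmax(dim=-1)

labels = torch.zeros(4, 128, dtype=torch.long)   # (batch, seq_len), dummy values
flat_logits = torch.randn(4 * 128, 21128)        # 2-dim, as returned on NPU
preds = preprocess_logits_for_metrics(flat_logits, labels)
assert preds.shape == labels.shape               # (4, 128)
```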
diff --git a/PyTorch/built-in/nlp/Bert_Chinese_for_PyTorch/transformers/src/transformers/trainer.py b/PyTorch/built-in/nlp/Bert_Chinese_for_PyTorch/transformers/src/transformers/trainer.py
index a8f7a1e5cc..640bb2dbcf 100644
--- a/PyTorch/built-in/nlp/Bert_Chinese_for_PyTorch/transformers/src/transformers/trainer.py
+++ b/PyTorch/built-in/nlp/Bert_Chinese_for_PyTorch/transformers/src/transformers/trainer.py
@@ -669,8 +669,8 @@ class Trainer:
             raise ValueError("Trainer: training requires a train_dataset.")
 
         train_dataset = self.train_dataset
-        # if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
-        #     train_dataset = self._remove_unused_columns(train_dataset, description="training")
+        if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
+            train_dataset = self._remove_unused_columns(train_dataset, description="training")
 
         if isinstance(train_dataset, torch.utils.data.IterableDataset):
             if self.args.world_size > 1:
@@ -2475,9 +2475,9 @@ class Trainer:
                 labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)
             if logits is not None:
                 logits = self._pad_across_processes(logits)
-                logits = self._nested_gather(logits)
                 if self.preprocess_logits_for_metrics is not None:
                     logits = self.preprocess_logits_for_metrics(logits, labels)
+                logits = self._nested_gather(logits)
                 preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)
             self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)
 
--
Gitee
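[Editor's note on the trainer.py hunk that swaps `_nested_gather` and `preprocess_logits_for_metrics`: beyond letting the NPU-specific reshape run against the local labels, a plausible side benefit is that each rank now all-gathers small int64 predictions instead of full float logits. A rough back-of-the-envelope in PyTorch, with assumed (not measured) eval shapes:]

```python
import torch

batch, seq_len, vocab = 32, 512, 21128        # assumed shapes, vocab = bert-base-chinese
logits = torch.randn(batch, seq_len, vocab)

# Old order: gather first, so every rank ships the full float logits tensor.
logits_bytes = logits.numel() * logits.element_size()   # ~1.4 GB

# New order: argmax locally, then gather only the predictions.
preds = logits.argmax(dim=-1)                           # (batch, seq_len), int64
preds_bytes = preds.numel() * preds.element_size()      # ~131 KB

print(f"gathered before the fix: {logits_bytes / 2**20:.0f} MiB, "
      f"after: {preds_bytes / 2**10:.0f} KiB per eval step")
```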