From 1f0266430e8f90596ae1ae2bd63365e88d33a57c Mon Sep 17 00:00:00 2001
From: lijing
Date: Tue, 19 Apr 2022 10:49:44 +0800
Subject: [PATCH 1/5] Update README and fix the BertPooler dimension mismatch
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../nlp/Bert_Chinese_for_PyTorch/README.cn.md      | 14 ++++++++++++++
 .../nlp/Bert_Chinese_for_PyTorch/requirements.txt  |  1 +
 .../src/transformers/models/bert/modeling_bert.py  |  2 +-
 3 files changed, 16 insertions(+), 1 deletion(-)

diff --git a/PyTorch/built-in/nlp/Bert_Chinese_for_PyTorch/README.cn.md b/PyTorch/built-in/nlp/Bert_Chinese_for_PyTorch/README.cn.md
index b39fda0a8f..722b82f649 100644
--- a/PyTorch/built-in/nlp/Bert_Chinese_for_PyTorch/README.cn.md
+++ b/PyTorch/built-in/nlp/Bert_Chinese_for_PyTorch/README.cn.md
@@ -104,3 +104,17 @@ python3 -m torch.distributed.launch --nproc_per_node 8 run_mlm.py \
   --output_dir ./output                  # output directory
 ```
 
+### Q&A
+
+1. Q: What should I do if the first run fails with an error like "xxx **socket timeout** xxx"?
+
+   A: On the first run the tokenizer preprocesses the corpus. How long this takes depends on the size of your dataset; if it takes too long, HCCL communication may time out. In that case, set a larger timeout threshold (in seconds; the default is 600) through the following environment variable:
+
+   ```
+   export HCCL_CONNECT_TIMEOUT=3600
+   ```
+
+
+
+
+
diff --git a/PyTorch/built-in/nlp/Bert_Chinese_for_PyTorch/requirements.txt b/PyTorch/built-in/nlp/Bert_Chinese_for_PyTorch/requirements.txt
index bc4f8bc394..b264df8cf0 100644
--- a/PyTorch/built-in/nlp/Bert_Chinese_for_PyTorch/requirements.txt
+++ b/PyTorch/built-in/nlp/Bert_Chinese_for_PyTorch/requirements.txt
@@ -4,3 +4,4 @@ tokenizers
 sentencepiece != 0.1.92
 protobuf
 wikiextractor
+sklearn
diff --git a/PyTorch/built-in/nlp/Bert_Chinese_for_PyTorch/transformers/src/transformers/models/bert/modeling_bert.py b/PyTorch/built-in/nlp/Bert_Chinese_for_PyTorch/transformers/src/transformers/models/bert/modeling_bert.py
index b341e36238..39bb0946ea 100644
--- a/PyTorch/built-in/nlp/Bert_Chinese_for_PyTorch/transformers/src/transformers/models/bert/modeling_bert.py
+++ b/PyTorch/built-in/nlp/Bert_Chinese_for_PyTorch/transformers/src/transformers/models/bert/modeling_bert.py
@@ -1075,7 +1075,7 @@ class BertModel(BertPreTrainedModel):
             return_dict=return_dict,
         )
         sequence_output = encoder_outputs[0]
-        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+        pooled_output = self.pooler(sequence_output.view(bs, from_seq_len, -1)) if self.pooler is not None else None
 
         if not return_dict:
             return (sequence_output, pooled_output) + encoder_outputs[1:]
--
Gitee
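
A note on the `modeling_bert.py` change in the patch above: it feeds the pooler a view reshaped back to `(batch, seq_len, hidden)`. The names `bs` and `from_seq_len` are presumably defined earlier in the NPU-adapted `forward`, which appears to hand back `sequence_output` flattened to `(batch * seq_len, hidden)`; that is an assumption, not something visible in the hunk. A minimal sketch (simplified pooler, illustrative shapes only) of why the pooler needs the 3-D view:

```python
import torch
from torch import nn


class SimpleBertPooler(nn.Module):
    """Simplified BertPooler: expects hidden states of shape (batch, seq_len, hidden)."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        first_token = hidden_states[:, 0]  # (batch, hidden): the [CLS] position
        return self.activation(self.dense(first_token))


bs, from_seq_len, hidden = 8, 128, 768
pooler = SimpleBertPooler(hidden)

# If the encoder hands back a flattened (bs * from_seq_len, hidden) tensor,
# hidden_states[:, 0] would pick one scalar per row instead of one [CLS]
# vector per sample, so the pooled output would no longer have shape (bs, hidden).
flat_sequence_output = torch.randn(bs * from_seq_len, hidden)
pooled = pooler(flat_sequence_output.view(bs, from_seq_len, -1))  # restore the 3-D shape first
print(pooled.shape)  # torch.Size([8, 768])
```
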
From 99f45fff0e9123a7ae4c9c4fe0740d875fd57d6a Mon Sep 17 00:00:00 2001
From: lijing
Date: Tue, 19 Apr 2022 13:12:09 +0800
Subject: [PATCH 2/5] Rework the mixed-precision (apex) initialization logic for eval
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../transformers/src/transformers/trainer.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/PyTorch/built-in/nlp/Bert_Chinese_for_PyTorch/transformers/src/transformers/trainer.py b/PyTorch/built-in/nlp/Bert_Chinese_for_PyTorch/transformers/src/transformers/trainer.py
index 640bb2dbcf..2ade6810fc 100644
--- a/PyTorch/built-in/nlp/Bert_Chinese_for_PyTorch/transformers/src/transformers/trainer.py
+++ b/PyTorch/built-in/nlp/Bert_Chinese_for_PyTorch/transformers/src/transformers/trainer.py
@@ -1051,8 +1051,11 @@ class Trainer:
             return model
 
         # Mixed precision training with apex (torch < 1.6)
-        if self.use_apex and training:
-            model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level, loss_scale=self.args.loss_scale, combine_grad=self.args.use_combine_grad)
+        if self.use_apex:
+            if training:
+                model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level, loss_scale=self.args.loss_scale, combine_grad=self.args.use_combine_grad)
+            elif self.optimizer is None:
+                model = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level, loss_scale=self.args.loss_scale, combine_grad=self.args.use_combine_grad)
 
         # Multi-gpu training (should be after apex fp16 initialization)
         if self.args.n_gpu > 1:
--
Gitee
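
For context on the change above: Apex allows `amp.initialize` to be called without an optimizer, in which case only the patched model is returned, which is what the new `elif` branch relies on for evaluation-only runs; `loss_scale` and `combine_grad` as used in the patch are keywords of the Ascend Apex fork used by this repository. A hedged sketch of the two call patterns (the model, optimizer, and `training` flag below are placeholders, not the Trainer's actual objects):

```python
import torch
from torch import nn
from apex import amp  # requires an Apex build (NVIDIA Apex or the Ascend fork)

model = nn.Linear(768, 2).npu()  # placeholder model; use .cuda() on GPU systems
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
training = False  # placeholder flag

if training:
    # Training: both the model and the optimizer are patched and returned.
    model, optimizer = amp.initialize(model, optimizer, opt_level="O2", loss_scale=8192)
else:
    # Evaluation only: with no optimizer passed, amp.initialize returns just the cast model.
    model = amp.initialize(model, opt_level="O2")
```
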
From 83e2d5e34a37922c645f08a40187d47892d45648 Mon Sep 17 00:00:00 2001
From: lijing
Date: Tue, 19 Apr 2022 18:42:11 +0800
Subject: [PATCH 3/5] Use fixed shapes in the eval phase; during evaluation, gather device results every 100 steps by default
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../run_mlm_bertbase_1p.sh                    |  1 +
 .../run_mlm_bertbase_8p.sh                    |  1 +
 .../run_mlm_bertlarge_1p.sh                   |  1 +
 .../run_mlm_bertlarge_8p.sh                   |  1 +
 .../transformers/src/transformers/trainer.py  | 26 ++++++++++++-------
 5 files changed, 20 insertions(+), 10 deletions(-)

diff --git a/PyTorch/built-in/nlp/Bert_Chinese_for_PyTorch/run_mlm_bertbase_1p.sh b/PyTorch/built-in/nlp/Bert_Chinese_for_PyTorch/run_mlm_bertbase_1p.sh
index fdce09aca0..e4ef7e87d0 100644
--- a/PyTorch/built-in/nlp/Bert_Chinese_for_PyTorch/run_mlm_bertbase_1p.sh
+++ b/PyTorch/built-in/nlp/Bert_Chinese_for_PyTorch/run_mlm_bertbase_1p.sh
@@ -16,6 +16,7 @@ python3 run_mlm.py \
     --per_device_eval_batch_size 32 \
     --do_train \
     --do_eval \
+    --eval_accumulation_steps 100 \
     --fp16 \
     --fp16_opt_level O2 \
     --loss_scale 8192 \
diff --git a/PyTorch/built-in/nlp/Bert_Chinese_for_PyTorch/run_mlm_bertbase_8p.sh b/PyTorch/built-in/nlp/Bert_Chinese_for_PyTorch/run_mlm_bertbase_8p.sh
index 75ab7a2188..99d52812ee 100644
--- a/PyTorch/built-in/nlp/Bert_Chinese_for_PyTorch/run_mlm_bertbase_8p.sh
+++ b/PyTorch/built-in/nlp/Bert_Chinese_for_PyTorch/run_mlm_bertbase_8p.sh
@@ -16,6 +16,7 @@ python3 -m torch.distributed.launch --nproc_per_node 8 run_mlm.py \
     --per_device_eval_batch_size 32 \
     --do_train \
     --do_eval \
+    --eval_accumulation_steps 100 \
     --fp16 \
     --fp16_opt_level O2 \
     --loss_scale 8192 \
diff --git a/PyTorch/built-in/nlp/Bert_Chinese_for_PyTorch/run_mlm_bertlarge_1p.sh b/PyTorch/built-in/nlp/Bert_Chinese_for_PyTorch/run_mlm_bertlarge_1p.sh
index 0e237b1bbb..61c70f6389 100644
--- a/PyTorch/built-in/nlp/Bert_Chinese_for_PyTorch/run_mlm_bertlarge_1p.sh
+++ b/PyTorch/built-in/nlp/Bert_Chinese_for_PyTorch/run_mlm_bertlarge_1p.sh
@@ -16,6 +16,7 @@ python3 run_mlm.py \
     --per_device_eval_batch_size 16 \
     --do_train \
     --do_eval \
+    --eval_accumulation_steps 100 \
     --fp16 \
     --fp16_opt_level O2 \
     --loss_scale 8192 \
diff --git a/PyTorch/built-in/nlp/Bert_Chinese_for_PyTorch/run_mlm_bertlarge_8p.sh b/PyTorch/built-in/nlp/Bert_Chinese_for_PyTorch/run_mlm_bertlarge_8p.sh
index c95acb947e..afa52ae7de 100644
--- a/PyTorch/built-in/nlp/Bert_Chinese_for_PyTorch/run_mlm_bertlarge_8p.sh
+++ b/PyTorch/built-in/nlp/Bert_Chinese_for_PyTorch/run_mlm_bertlarge_8p.sh
@@ -16,6 +16,7 @@ python3 -m torch.distributed.launch --nproc_per_node 8 run_mlm.py \
     --per_device_eval_batch_size 16 \
     --do_train \
     --do_eval \
+    --eval_accumulation_steps 100 \
     --fp16 \
     --fp16_opt_level O2 \
     --loss_scale 8192 \
diff --git a/PyTorch/built-in/nlp/Bert_Chinese_for_PyTorch/transformers/src/transformers/trainer.py b/PyTorch/built-in/nlp/Bert_Chinese_for_PyTorch/transformers/src/transformers/trainer.py
index 2ade6810fc..4d704a9cfb 100644
--- a/PyTorch/built-in/nlp/Bert_Chinese_for_PyTorch/transformers/src/transformers/trainer.py
+++ b/PyTorch/built-in/nlp/Bert_Chinese_for_PyTorch/transformers/src/transformers/trainer.py
@@ -2442,9 +2442,9 @@ class Trainer:
 
         # Initialize containers
         # losses/preds/labels on GPU/TPU (accumulated for eval_accumulation_steps)
-        losses_host = None
-        preds_host = None
-        labels_host = None
+        losses_host = []
+        preds_host = []
+        labels_host = []
         # losses/preds/labels on CPU (final containers)
         all_losses = None
         all_preds = None
@@ -2471,35 +2471,38 @@ class Trainer:
             # Update containers on host
             if loss is not None:
                 losses = self._nested_gather(loss.repeat(batch_size))
-                losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)
+                losses_host.append(losses)
             if labels is not None:
                 labels = self._pad_across_processes(labels)
                 labels = self._nested_gather(labels)
-                labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)
+                labels_host.append(labels)
             if logits is not None:
                 logits = self._pad_across_processes(logits)
                 if self.preprocess_logits_for_metrics is not None:
                     logits = self.preprocess_logits_for_metrics(logits, labels)
                 logits = self._nested_gather(logits)
-                preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)
+                preds_host.append(logits)
             self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)
 
             # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
             if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0:
-                if losses_host is not None:
+                if losses_host:
+                    losses_host = torch.cat(losses_host, dim=0)
                     losses = nested_numpify(losses_host)
                     all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
-                if preds_host is not None:
+                if preds_host:
+                    preds_host = torch.cat(preds_host, dim=0)
                     logits = nested_numpify(preds_host)
                     all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
-                if labels_host is not None:
+                if labels_host:
+                    labels_host = torch.cat(labels_host, dim=0)
                     labels = nested_numpify(labels_host)
                     all_labels = (
                         labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)
                     )
 
                 # Set back to None to begin a new accumulation
-                losses_host, preds_host, labels_host = None, None, None
+                losses_host, preds_host, labels_host = [], [], []
 
         if args.past_index and hasattr(self, "_past"):
             # Clean the state at the end of the evaluation loop
@@ -2507,12 +2510,15 @@ class Trainer:
 
         # Gather all remaining tensors and put them back on the CPU
         if losses_host is not None:
+            losses_host = torch.cat(losses_host, dim=0)
             losses = nested_numpify(losses_host)
             all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
         if preds_host is not None:
+            preds_host = torch.cat(preds_host, dim=0)
             logits = nested_numpify(preds_host)
             all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
         if labels_host is not None:
+            labels_host = torch.cat(labels_host, dim=0)
             labels = nested_numpify(labels_host)
             all_labels = labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)
 
--
Gitee
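
The trainer change above replaces the per-step `nested_concat` of device tensors with plain list appends, concatenating once per `--eval_accumulation_steps` window before the results are moved to the CPU, which avoids repeatedly re-allocating an ever-growing device tensor. A minimal sketch of the same pattern on flat tensors (the `eval_batches` generator and the sizes are made up for illustration, not the Trainer's nested structures):

```python
import numpy as np
import torch

eval_accumulation_steps = 100  # mirrors --eval_accumulation_steps 100 above


def eval_batches(num_steps=250, batch_size=8, vocab=128):
    """Stand-in for the evaluation loop: yields per-step logits living on the device."""
    for _ in range(num_steps):
        yield torch.randn(batch_size, vocab)


preds_host, all_preds = [], []
for step, logits in enumerate(eval_batches()):
    preds_host.append(logits)  # O(1) append; no device-side concat every step
    if (step + 1) % eval_accumulation_steps == 0:
        # One concat per window, then move the block to the CPU and drop the device copies.
        all_preds.append(torch.cat(preds_host, dim=0).cpu().numpy())
        preds_host = []

if preds_host:  # flush the final, possibly shorter window
    all_preds.append(torch.cat(preds_host, dim=0).cpu().numpy())

final_preds = np.concatenate(all_preds, axis=0)
print(final_preds.shape)  # (250 * 8, 128)
```
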
From 7bcebd0f0c099cefb02f7311b64b960130015da6 Mon Sep 17 00:00:00 2001
From: lijing
Date: Tue, 19 Apr 2022 18:45:36 +0800
Subject: [PATCH 4/5] Use fixed shapes in the eval phase; during evaluation, gather device results every 100 steps by default
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../transformers/src/transformers/trainer.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/PyTorch/built-in/nlp/Bert_Chinese_for_PyTorch/transformers/src/transformers/trainer.py b/PyTorch/built-in/nlp/Bert_Chinese_for_PyTorch/transformers/src/transformers/trainer.py
index 4d704a9cfb..14d6c1904c 100644
--- a/PyTorch/built-in/nlp/Bert_Chinese_for_PyTorch/transformers/src/transformers/trainer.py
+++ b/PyTorch/built-in/nlp/Bert_Chinese_for_PyTorch/transformers/src/transformers/trainer.py
@@ -2509,15 +2509,15 @@ class Trainer:
             delattr(self, "_past")
 
         # Gather all remaining tensors and put them back on the CPU
-        if losses_host is not None:
+        if losses_host:
             losses_host = torch.cat(losses_host, dim=0)
             losses = nested_numpify(losses_host)
             all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
-        if preds_host is not None:
+        if preds_host:
             preds_host = torch.cat(preds_host, dim=0)
             logits = nested_numpify(preds_host)
             all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
-        if labels_host is not None:
+        if labels_host:
             labels_host = torch.cat(labels_host, dim=0)
             labels = nested_numpify(labels_host)
             all_labels = labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)
--
Gitee
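
Patch 4 above is a follow-up to patch 3: once the host containers became lists, the final flush still used `is not None`, which is always true for a list, so an evaluation whose last window is empty would reach `torch.cat` on an empty list and raise. A small standalone illustration of the pitfall (not the Trainer code itself):

```python
import torch

labels_host = []  # an empty accumulation window at the end of evaluation

if labels_host is not None:  # always True for a list, even an empty one
    try:
        torch.cat(labels_host, dim=0)
    except RuntimeError as err:
        print("old check falls through and fails:", err)

if labels_host:  # truthiness check: the empty window is skipped entirely
    torch.cat(labels_host, dim=0)
else:
    print("new check skips the empty window")
```
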
From 7ad99a579bf00b526f0d31672b3a36e365580b29 Mon Sep 17 00:00:00 2001
From: lijing
Date: Sun, 24 Apr 2022 14:06:30 +0800
Subject: [PATCH 5/5] Add a distributed process group timeout parameter and document it
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../nlp/Bert_Chinese_for_PyTorch/README.cn.md       | 14 +++++++++++---
 .../built-in/nlp/Bert_Chinese_for_PyTorch/env.sh    |  1 +
 .../run_mlm_bertbase_8p.sh                          |  1 +
 .../run_mlm_bertlarge_8p.sh                         |  1 +
 .../transformers/src/transformers/training_args.py  | 10 +++++++++-
 5 files changed, 23 insertions(+), 4 deletions(-)

diff --git a/PyTorch/built-in/nlp/Bert_Chinese_for_PyTorch/README.cn.md b/PyTorch/built-in/nlp/Bert_Chinese_for_PyTorch/README.cn.md
index 722b82f649..3aeaaf928a 100644
--- a/PyTorch/built-in/nlp/Bert_Chinese_for_PyTorch/README.cn.md
+++ b/PyTorch/built-in/nlp/Bert_Chinese_for_PyTorch/README.cn.md
@@ -108,12 +108,20 @@ python3 -m torch.distributed.launch --nproc_per_node 8 run_mlm.py \
 
 1. Q: What should I do if the first run fails with an error like "xxx **socket timeout** xxx"?
 
-   A: On the first run the tokenizer preprocesses the corpus. How long this takes depends on the size of your dataset; if it takes too long, HCCL communication may time out. In that case, set a larger timeout threshold (in seconds; the default is 600) through the following environment variable:
+   A: On the first run the tokenizer preprocesses the corpus. How long this takes depends on the size of your dataset; if it takes too long, the wait may time out. In that case, try raising the timeout thresholds as follows:
+
+   (1) Raise PyTorch's built-in process group timeout by passing a larger value (in seconds) for distributed_process_group_timeout in the launch script, for example 7200:
+
 
    ```
-   export HCCL_CONNECT_TIMEOUT=3600
+--distributed_process_group_timeout 7200
    ```
-
+
+   (2) Raise the HCCL connection-setup timeout by changing the HCCL_CONNECT_TIMEOUT environment variable (in seconds) in env.sh:
+
+   ```
+   export HCCL_CONNECT_TIMEOUT=7200
+   ```
 
 
 
 
diff --git a/PyTorch/built-in/nlp/Bert_Chinese_for_PyTorch/env.sh b/PyTorch/built-in/nlp/Bert_Chinese_for_PyTorch/env.sh
index 8ba618f127..d72034c89b 100644
--- a/PyTorch/built-in/nlp/Bert_Chinese_for_PyTorch/env.sh
+++ b/PyTorch/built-in/nlp/Bert_Chinese_for_PyTorch/env.sh
@@ -44,6 +44,7 @@ export DYNAMIC_OP="ADD#MUL"
 #HCCL whitelist switch: 1 = disabled / 0 = enabled
 export HCCL_WHITELIST_DISABLE=1
 export HCCL_IF_IP=$(hostname -I |awk '{print $1}')
+export HCCL_CONNECT_TIMEOUT=5400
 
 #Set device-side log level to error
 ${install_path}/driver/tools/msnpureport -g error -d 0
diff --git a/PyTorch/built-in/nlp/Bert_Chinese_for_PyTorch/run_mlm_bertbase_8p.sh b/PyTorch/built-in/nlp/Bert_Chinese_for_PyTorch/run_mlm_bertbase_8p.sh
index 99d52812ee..66737acdf3 100644
--- a/PyTorch/built-in/nlp/Bert_Chinese_for_PyTorch/run_mlm_bertbase_8p.sh
+++ b/PyTorch/built-in/nlp/Bert_Chinese_for_PyTorch/run_mlm_bertbase_8p.sh
@@ -22,4 +22,5 @@ python3 -m torch.distributed.launch --nproc_per_node 8 run_mlm.py \
     --loss_scale 8192 \
     --use_combine_grad \
     --optim adamw_apex_fused_npu \
+    --distributed_process_group_timeout 5400 \
     --output_dir ./output
\ No newline at end of file
diff --git a/PyTorch/built-in/nlp/Bert_Chinese_for_PyTorch/run_mlm_bertlarge_8p.sh b/PyTorch/built-in/nlp/Bert_Chinese_for_PyTorch/run_mlm_bertlarge_8p.sh
index afa52ae7de..6d81435b1e 100644
--- a/PyTorch/built-in/nlp/Bert_Chinese_for_PyTorch/run_mlm_bertlarge_8p.sh
+++ b/PyTorch/built-in/nlp/Bert_Chinese_for_PyTorch/run_mlm_bertlarge_8p.sh
@@ -22,4 +22,5 @@ python3 -m torch.distributed.launch --nproc_per_node 8 run_mlm.py \
     --loss_scale 8192 \
     --use_combine_grad \
     --optim adamw_apex_fused_npu \
+    --distributed_process_group_timeout 5400 \
     --output_dir ./output
\ No newline at end of file
diff --git a/PyTorch/built-in/nlp/Bert_Chinese_for_PyTorch/transformers/src/transformers/training_args.py b/PyTorch/built-in/nlp/Bert_Chinese_for_PyTorch/transformers/src/transformers/training_args.py
index dbad5e03e0..09dee52967 100644
--- a/PyTorch/built-in/nlp/Bert_Chinese_for_PyTorch/transformers/src/transformers/training_args.py
+++ b/PyTorch/built-in/nlp/Bert_Chinese_for_PyTorch/transformers/src/transformers/training_args.py
@@ -16,6 +16,7 @@ import contextlib
 import json
 import math
 import os
+import datetime
 import warnings
 from dataclasses import asdict, dataclass, field
 from enum import Enum
@@ -710,6 +711,13 @@ class TrainingArguments:
             "`DistributedDataParallel`."
         },
     )
+    distributed_process_group_timeout: Optional[int] = field(
+        default=1800,
+        metadata={
+            "help": "Timeout (seconds) for operations executed against the process group; passed as the "
+            "`timeout` argument of `init_process_group`."
+        },
+    )
     dataloader_pin_memory: bool = field(
         default=True, metadata={"help": "Whether or not to pin memory for DataLoader."}
     )
@@ -1076,7 +1084,7 @@ class TrainingArguments:
         else:
             # Here, we'll use torch.distributed.
             # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
-            torch.distributed.init_process_group(backend="hccl")
+            torch.distributed.init_process_group(backend="hccl", timeout=datetime.timedelta(seconds=self.distributed_process_group_timeout))
             device = torch.device("npu", self.local_rank)
             self._n_gpu = 1
 
--
Gitee
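
The `training_args.py` change above threads the new flag into `torch.distributed.init_process_group`, whose `timeout` argument must be a `datetime.timedelta` rather than a plain number of seconds. A minimal single-process sketch of the same call (the rendezvous values are assumptions; the `hccl` backend additionally requires `torch_npu` on Ascend, so substitute `gloo` or `nccl` elsewhere):

```python
import datetime
import os

import torch.distributed as dist

# Rendezvous information normally injected by torch.distributed.launch / torchrun.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")

timeout_seconds = 5400  # mirrors --distributed_process_group_timeout 5400 above

dist.init_process_group(
    backend="hccl",  # "gloo" or "nccl" on non-Ascend systems
    world_size=1,
    rank=0,
    timeout=datetime.timedelta(seconds=timeout_seconds),
)
print(dist.get_world_size())  # 1
dist.destroy_process_group()
```
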