diff --git a/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/README.md b/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/README.md index 76f04b28458ed9e29ce749f13c29ff7170572590..0c4b6f638440a1000678260e9373aa653be8d953 100644 --- a/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/README.md +++ b/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/README.md @@ -39,12 +39,28 @@ train_data | -- valid.en_XX-ro_RO.ro_RO.idx | -- valid.en_XX-ro_RO.en_XX.bin | -- valid.en_XX-ro_RO.en_XX.idx + | -- en_de + | -- preprocess.log + | -- dict.en_XX.txt + | -- dict.de_DE.txt + | -- test.en_XX-de_DE.de_DE.bin + | -- test.en_XX-de_DE.de_DE.idx + | -- test.en_XX-de_DE.en_XX.bin + | -- test.en_XX-de_DE.en_XX.idx + | -- train.en_XX-de_DE.de_DE.bin + | -- train.en_XX-de_DE.de_DE.idx + | -- train.en_XX-de_DE.en_XX.bin + | -- train.en_XX-de_DE.en_XX.idx + | -- valid.en_XX-de_DE.de_DE.bin + | -- valid.en_XX-de_DE.de_DE.idx + | -- valid.en_XX-de_DE.en_XX.bin + | -- valid.en_XX-de_DE.en_XX.idx ``` ## 方法二. 下载数据集并自行处理 ### 1. 分词处理 -1. 下载en_ro数据集并放于工程根目录下 +1. 下载数据集并放于工程根目录下,以en_ro数据集为例 2. 下载并安装SPM [here](https://github.com/google/sentencepiece) ```bash SPM=/path/to/sentencepiece/build/src/spm_encode @@ -95,7 +111,8 @@ fairseq-preprocess \ # 在数据集上进行fine-tune ```bash -1. 修改run_8p.sh中PRETRAIN为模型的路径,DATA_PATH为数据集的路径 +1. 修改run_8p.sh中PRETRAIN为模型的路径,DATA_PATH为数据集的路径 + [若需要训练en_de数据集,则需要将run_8p.sh中dropout的参数设置为0.1,total-num-update与max-update设置为300000] 2. 执行 bash run_8p.sh ``` # 在数据集上进行评估 @@ -111,8 +128,12 @@ pip install sacrebleu==1.5.1 2.执行评估脚本 ```bash +验证en_ro精度 1. 修改generate_on_en_ro.sh中DATA_PATH为数据集的路径,BPE_PATH为sentence.bpe.model的路径,SCRIPTS为mosesdecoder/scripts的路径,WMT16_SCRIPTS为wmt16-scripts的路径 -2. 执行 bash generate_on_en_ro.sh checkpoints/checkpoint_best.pt +2. 执行 bash generate_on_en_ro.sh checkpoints/checkpoint_best.pt 验证en_ro的训练精度 +验证en_de精度 +1. 修改generate_on_en_de.sh中DATA_PATH为数据集的路径,BPE_PATH为sentence.bpe.model的路径,DETOKENIZER为mosesdecoder/scripts/tokenizer/detokenizer.perl的路径 +2. 
执行 bash generate_on_en_de.sh checkpoints/checkpoint_best.pt 验证en_de的训练精度 ``` # Docker容器训练 diff --git a/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq/data/data_utils.py b/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq/data/data_utils.py index 74a560678676f63be7caa157fcc9ce836e5d31ab..e386fe86a35f6ed9fd395699516a4ea178a93159 100644 --- a/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq/data/data_utils.py +++ b/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq/data/data_utils.py @@ -43,19 +43,11 @@ def collate_tokens( ): """Convert a list of 1d tensors into a padded 2d tensor.""" size = max(max(v.size(0) for v in values), max(ov.size(0) for ov in other_values)) - if size <= 16: - size = 16 - elif size <= 32: - size = 32 - elif size <= 64: - size = 64 - elif size <= 128: - size = 128 - elif size <= 256: - size = 256 - else: - size = 512 - + buckets = [16, 32, 64, 128, 256, 512, 1024] + for buck in buckets: + if size <= buck: + size = buck + break size = size if pad_to_length is None else max(size, pad_to_length) if pad_to_multiple != 1 and size % pad_to_multiple != 0: size = int(((size - 0.1) // pad_to_multiple + 1) * pad_to_multiple) @@ -352,17 +344,20 @@ def batch_by_size( def batch_by_size_fast_fix(indices, num_tokens_fn,max_tokens,max_sentences,bsz_mult): indices_view = indices - fix_shape_dict = {16:[],32:[],64:[],128:[],256:[],512:[]} + buckets = [16, 32, 64, 128, 256, 512, 1024] + fix_shape_dict = {} + for buck in buckets: + fix_shape_dict[buck] = [] batch_by_size_list = [] for i in range(len(indices_view)): idx = indices_view[i] fix_shape_dict[num_tokens_fn(idx)].append(idx) - for key_len in fix_shape_dict.keys(): + for idx, key_len in enumerate(buckets): max_batch = max_tokens // key_len division_len = max_batch * (len(fix_shape_dict[key_len]) // max_batch) tail_len = len(fix_shape_dict[key_len]) - division_len - if tail_len > 0 and key_len != 512: - fix_shape_dict[key_len * 2] = fix_shape_dict[key_len][division_len:] + fix_shape_dict[key_len * 2] + if tail_len > 0 and key_len != 1024: + fix_shape_dict[buckets[idx + 1]] = fix_shape_dict[key_len][division_len:] + fix_shape_dict[buckets[idx + 1]] if division_len == 0: pass else: diff --git a/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq/data/language_pair_dataset.py b/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq/data/language_pair_dataset.py index 730739a48b892a125d473eea127031b020df0874..1286b7944736a68336fbdeac0d7a1a998304f0fe 100644 --- a/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq/data/language_pair_dataset.py +++ b/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq/data/language_pair_dataset.py @@ -404,22 +404,14 @@ class LanguagePairDataset(FairseqDataset): def num_tokens(self, index): """Return the number of tokens in a sample. This value is used to enforce ``--max-tokens`` during batching.""" + buckets = [16, 32, 64, 128, 256, 512, 1024] src = max( self.src_sizes[index], self.tgt_sizes[index] if self.tgt_sizes is not None else 0, ) - if src <= 16: - return 16 - elif src <= 32: - return 32 - elif src <= 64: - return 64 - elif src <= 128: - return 128 - elif src <= 256: - return 256 - else: - return 512 + for buck in buckets: + if src <= buck: + return buck def size(self, index): """Return an example's size as a float or tuple. 
This value is used when diff --git a/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq/models/transformer.py b/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq/models/transformer.py index 2d898f5c73aaffbe5f999355bf9cd84f532f2a1c..5bf2c44f0c99f73f336b510167f4becb4e109d48 100644 --- a/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq/models/transformer.py +++ b/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq/models/transformer.py @@ -419,11 +419,14 @@ class TransformerEncoder(FairseqEncoder): # compute padding mask encoder_padding_mask = src_tokens.eq(self.padding_idx) - encoder_padding_mask = (encoder_padding_mask.to(torch.float16) * 65504).unsqueeze(1).unsqueeze(2) + encoder_padding_mask = (encoder_padding_mask.to(torch.float16) * -65504).unsqueeze(1).unsqueeze(2) encoder_padding_mask = encoder_padding_mask.repeat(1,self.encoder_attention_heads, tgt_len, 1).clone().npu_format_cast(29) encoder_states = [] if return_all_hiddens else None - + if len(x.shape) == 3: + x = x.view(-1, x.shape[2]).clone().npu_format_cast(29) + else: + x = x.npu_format_cast(29) # encoder layers for layer in self.layers: x = layer(x, encoder_padding_mask, bsz, tgt_len, s_len) @@ -790,28 +793,41 @@ class TransformerDecoder(FairseqIncrementalDecoder): self_attn_padding_mask: Optional[Tensor] = None if self.cross_self_attention or prev_output_tokens.eq(self.padding_idx).any(): self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx) - self_attn_padding_mask = (self_attn_padding_mask.to(torch.float16) * 65504).unsqueeze(1).unsqueeze(2) + self_attn_padding_mask = (self_attn_padding_mask.to(torch.float16) * -65504).unsqueeze(1).unsqueeze(2) self_attn_padding_mask = self_attn_padding_mask.repeat(1, self.decoder_attention_heads, tgt_len, 1).clone().npu_format_cast( 29) + if encoder_out is not None: encoder_padding_mask = encoder_out.encoder_padding_mask.unsqueeze(1).unsqueeze(2)\ .repeat(1, self.decoder_attention_heads, tgt_len, 1) + if len(encoder_out.encoder_out.shape) == 3: + encoder_out_ = encoder_out.encoder_out.view(-1, encoder_out.encoder_out.shape[2]).clone().npu_format_cast(29) + else: + encoder_out_ = encoder_out.encoder_out.npu_format_cast(29) + if len(x.shape) == 3: + x = x.view(-1, x.shape[2]).clone().npu_format_cast(29) + else: + x = x.npu_format_cast(29) # decoder layers attn: Optional[Tensor] = None inner_states: List[Optional[Tensor]] = [x] for idx, layer in enumerate(self.layers): if incremental_state is None and not full_context_alignment: self_attn_mask = self.buffered_future_mask(x, tgt_len) - else: - self_attn_mask = None + if self_attn_padding_mask is not None: + self_attn_padding_mask = self_attn_padding_mask + self_attn_mask + else: + self_attn_padding_mask = self_attn_mask.unsqueeze(0).unsqueeze(1).repeat(bsz, self.decoder_attention_heads, + 1, 1).clone().npu_format_cast( + 29) x, layer_attn, _ = layer( x, bsz, tgt_len, s_len, - encoder_out.encoder_out if encoder_out is not None else None, + encoder_out_ if encoder_out is not None else None, encoder_padding_mask if encoder_out is not None else None, incremental_state, - self_attn_mask=self_attn_mask, + self_attn_mask=None, self_attn_padding_mask=self_attn_padding_mask, need_attn=bool((idx == alignment_layer)), need_head_weights=bool((idx == alignment_layer)), diff --git a/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq/modules/__init__.py b/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq/modules/__init__.py index e2326ac6e36d9264e74f64e580834685ed99b546..4b7fcb95e1873b58edca5853fa3e058fde1eb1b2 100644 --- 
a/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq/modules/__init__.py +++ b/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq/modules/__init__.py @@ -13,7 +13,7 @@ from .cross_entropy import cross_entropy from .downsampled_multihead_attention import DownsampledMultiHeadAttention from .dynamic_convolution import DynamicConv, DynamicConv1dTBC from .dynamic_crf_layer import DynamicCRF -from .fairseq_dropout import FairseqDropout +from .fairseq_dropout import FairseqDropout, NpuFairseqDropout from .fp32_group_norm import Fp32GroupNorm from .gelu import gelu, gelu_accurate from .grad_multiply import GradMultiply @@ -48,6 +48,7 @@ __all__ = [ "DynamicConv", "DynamicCRF", "FairseqDropout", + "NpuFairseqDropout", "Fp32GroupNorm", "Fp32LayerNorm", "gelu", diff --git a/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq/modules/fairseq_dropout.py b/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq/modules/fairseq_dropout.py index f070a804e6c1e00b6c0db315b944305c2c41d807..c82f175c41acc889dd31fb03495d4987181c2f6e 100644 --- a/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq/modules/fairseq_dropout.py +++ b/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq/modules/fairseq_dropout.py @@ -6,12 +6,19 @@ import logging from typing import List, Optional +import torch import torch.nn as nn import torch.nn.functional as F logger = logging.getLogger(__name__) +def get_dropout_class(): + try: + from torch import npu_dropout_do_mask + return NpuFairseqDropout + except: + return FairseqDropout class FairseqDropout(nn.Module): def __init__(self, p, module_name=None): @@ -49,3 +56,78 @@ class FairseqDropout(nn.Module): self.apply_during_inference = True else: logger.info("Disabling dropout for module: {}".format(name)) + +class DropOutTask: + def __init__(self, shape, dtype, device, p): + self.shape = shape + self.dtype = dtype + self.device = device + self.p = p + self.request_count = 0 + self.mask_queue = [] + +class NpuFairseqDropout(torch.nn.Dropout): + task_dict = {} + dropout_stream = None + + def __init__(self, p, module_name=None): + super().__init__(p) + self.module_name = module_name + + def forward(self, x): + if isinstance(x, torch.Tensor): + shape = x.shape + dtype = x.dtype + device = x.device + do_mask_flag = True + return_obj = x + elif isinstance(x, list): + shape, dtype, device = x + do_mask_flag = False + return_obj = None + else: + raise RuntimeError("input type error!") + + if self.p == 0: + return return_obj + key = (shape, dtype, device, self.p) + if key not in NpuFairseqDropout.task_dict: + dropout_task = DropOutTask(shape, dtype, device, self.p) + dropout_task.request_count += 1 + NpuFairseqDropout.task_dict[key] = dropout_task + return return_obj + elif not NpuFairseqDropout.task_dict[key].mask_queue: + NpuFairseqDropout.task_dict[key].request_count += 1 + return return_obj + else: + mask, event = NpuFairseqDropout.task_dict[key].mask_queue.pop(0) + if do_mask_flag: + return torch.npu_dropout_do_mask(x, mask, self.p)[0] + else: + return mask + + @classmethod + def enable_dropout_ensemble(cls, model): + if cls.dropout_stream is None: + cls.dropout_stream = torch.npu.Stream() + + def wait_stream_hook_func(): + def hook_function(module, inputs): + torch.npu.current_stream().wait_stream(cls.dropout_stream) + return hook_function + model.register_forward_pre_hook(wait_stream_hook_func()) + + def mask_gen_hook_func(): + def hook_function(module, inputs, outputs): + with torch.npu.stream(cls.dropout_stream): + with torch.no_grad(): + for _, task in cls.task_dict.items(): 
+ if len(task.mask_queue) < task.request_count: + for j in range(task.request_count - len(task.mask_queue)): + mask = torch.npu_dropout_gen_mask(task.shape, p=task.p, dtype=task.dtype, + device=task.device) + event = None + task.mask_queue.append((mask, event)) + return hook_function + + model.register_forward_hook(mask_gen_hook_func()) diff --git a/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq/modules/gelu.py b/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq/modules/gelu.py index 7c56bf32d6eed7fc052b6a04c11b4f7a663a688c..a60c15a5a1110dad9dd3f01302d38ca4459038cc 100644 --- a/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq/modules/gelu.py +++ b/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq/modules/gelu.py @@ -22,4 +22,4 @@ def gelu_accurate(x): def gelu(x: torch.Tensor) -> torch.Tensor: - return torch.nn.functional.gelu(x).type_as(x) + return torch.fast_gelu(x).type_as(x) diff --git a/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq/modules/multihead_attention.py b/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq/modules/multihead_attention.py index 6d984e8c88ca7102d44eedf58a02851c9ae9b71a..5d061ce0ca3cec518aa4cb1c209e8c5aca308706 100644 --- a/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq/modules/multihead_attention.py +++ b/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq/modules/multihead_attention.py @@ -10,11 +10,13 @@ import torch import torch.nn.functional as F from fairseq import utils from fairseq.incremental_decoding_utils import with_incremental_state -from fairseq.modules.fairseq_dropout import FairseqDropout +from fairseq.modules.fairseq_dropout import NpuFairseqDropout, get_dropout_class from fairseq.modules.quant_noise import quant_noise from torch import Tensor, nn from torch.nn import Parameter +dropout_class = get_dropout_class() + class NpuLinear(nn.Linear): def forward(self, input): @@ -29,6 +31,17 @@ class NpuLinear(nn.Linear): else: raise RuntimeError('not support this dim') +class MHAConfig: + use_fussion_mha = False + + @classmethod + def set_fussion(cls): + try: + from torch import npu_multi_head_attention + cls.use_fussion_mha = True + except: + cls.use_fussion_mha = False + class MatmulApply(torch.autograd.Function): @staticmethod def forward(ctx, self, mat2): @@ -74,9 +87,12 @@ class MultiheadAttention(nn.Module): self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim self.num_heads = num_heads - self.dropout_module = FairseqDropout( + self.dropout_module = dropout_class( dropout, module_name=self.__class__.__name__ ) + self.dropout_prob = dropout + + self.use_dropout_optim = (dropout_class is NpuFairseqDropout) self.head_dim = embed_dim // num_heads assert ( @@ -143,6 +159,7 @@ class MultiheadAttention(nn.Module): nn.init.xavier_normal_(self.bias_k) if self.bias_v is not None: nn.init.xavier_normal_(self.bias_v) + def transpose_for_scores(self, x): new_x_shape = (self.batch_size, self.squence_length) + (self.num_attention_heads, self.attention_head_size) return x.npu_confusion_transpose((0, 2, 1, 3), new_x_shape, False) @@ -177,6 +194,15 @@ class MultiheadAttention(nn.Module): weights for each head. Implies *need_weights*. Default: return the average attention weights over all heads. 
""" + if MHAConfig.use_fussion_mha: + attn = self.multi_attn(query, key, value, key_padding_mask, bsz, tgt_len) + return attn, None + else: + return self.ori_attn(query, key, value, bsz, tgt_len, key_padding_mask, incremental_state, + need_weights, static_kv, attn_mask, before_softmax, need_head_weights) + + def ori_attn(self, query, key, value, bsz, tgt_len, key_padding_mask, incremental_state, + need_weights, static_kv, attn_mask, before_softmax, need_head_weights): if need_head_weights: need_weights = True @@ -184,41 +210,6 @@ class MultiheadAttention(nn.Module): assert embed_dim == self.embed_dim assert list(query.size()) == [tgt_len * bsz, embed_dim] - if ( - not self.onnx_trace - and not self.tpu # don't use PyTorch version on TPUs - and not self.npu - and incremental_state is None - and not static_kv - # A workaround for quantization to work. Otherwise JIT compilation - # treats bias in linear module as method. - and not torch.jit.is_scripting() - ): - assert key is not None and value is not None - return F.multi_head_attention_forward( - query, - key, - value, - self.embed_dim, - self.num_heads, - torch.empty([0]), - torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)), - self.bias_k, - self.bias_v, - self.add_zero_attn, - self.dropout_module.p, - self.out_proj.weight, - self.out_proj.bias, - self.training or self.dropout_module.apply_during_inference, - key_padding_mask, - need_weights, - attn_mask, - use_separate_proj_weight=True, - q_proj_weight=self.q_proj.weight, - k_proj_weight=self.k_proj.weight, - v_proj_weight=self.v_proj.weight, - ) - if incremental_state is not None: saved_state = self._get_input_buffer(incremental_state) if saved_state is not None and "prev_key" in saved_state: @@ -359,7 +350,7 @@ class MultiheadAttention(nn.Module): if key_padding_mask is not None: # don't attend to padding symbols attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights - key_padding_mask + attn_weights = attn_weights + key_padding_mask if before_softmax: return attn_weights, v @@ -393,6 +384,20 @@ class MultiheadAttention(nn.Module): return attn, attn_weights + def multi_attn(self, query, key, value, key_padding_mask, bsz, tgt_len): + src_len = key.size(0) // bsz + if self.use_dropout_optim: + dropout_mask = self.dropout_module([(bsz, self.num_heads, tgt_len, src_len), query.dtype, query.device]) + else: + dropout_mask = None + attn = torch.npu_multi_head_attention(query, key, value, self.q_proj.weight, + self.k_proj.weight, self.v_proj.weight, + key_padding_mask, self.out_proj.weight, + self.q_proj.bias, self.k_proj.bias, self.v_proj.bias, + self.out_proj.bias, dropout_mask, self.num_heads, + self.head_dim, src_len, tgt_len, self.dropout_prob, True) + return attn[0] + @staticmethod def _append_prev_key_padding_mask( key_padding_mask: Optional[Tensor], diff --git a/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq/modules/transformer_layer.py b/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq/modules/transformer_layer.py index 413d3cdd4a528928953a2790e06e3985b473092d..b889b7cfc448e5fade7735ac2a50c364d1345a88 100644 --- a/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq/modules/transformer_layer.py +++ b/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq/modules/transformer_layer.py @@ -9,10 +9,12 @@ import torch import torch.nn as nn from fairseq import utils from fairseq.modules import LayerNorm, MultiheadAttention -from fairseq.modules.fairseq_dropout import FairseqDropout +from 
fairseq.modules.fairseq_dropout import get_dropout_class from fairseq.modules.quant_noise import quant_noise from torch import Tensor +dropout_class = get_dropout_class() + class NpuLinear(torch.nn.Linear): def forward(self, input): return torch.npu_linear(input, self.weight, self.bias) @@ -39,7 +41,7 @@ class TransformerEncoderLayer(nn.Module): self.quant_noise_block_size = getattr(args, "quant_noise_pq_block_size", 8) self.self_attn = self.build_self_attention(self.embed_dim, args) self.self_attn_layer_norm = LayerNorm(self.embed_dim) - self.dropout_module = FairseqDropout( + self.dropout_module = dropout_class( args.dropout, module_name=self.__class__.__name__ ) self.activation_fn = utils.get_activation_fn( @@ -49,7 +51,7 @@ class TransformerEncoderLayer(nn.Module): if activation_dropout_p == 0: # for backwards compatibility with models that use args.relu_dropout activation_dropout_p = getattr(args, "relu_dropout", 0) - self.activation_dropout_module = FairseqDropout( + self.activation_dropout_module = dropout_class( float(activation_dropout_p), module_name=self.__class__.__name__ ) self.normalize_before = args.encoder_normalize_before @@ -129,17 +131,11 @@ class TransformerEncoderLayer(nn.Module): if attn_mask is not None: attn_mask = attn_mask.masked_fill(attn_mask.to(torch.bool), -1e8) - if len(x.shape) == 3: - residual = x.view(-1, x.shape[2]).clone().npu_format_cast(29) - else: - residual = x.npu_format_cast(29) + residual = x if self.normalize_before: x = self.self_attn_layer_norm(x) - if len(x.shape) == 3: - x = x.view(-1,x.shape[2]).clone().npu_format_cast(29) - else: - x= x.npu_format_cast(29) + x, _ = self.self_attn( query=x, key=x, @@ -188,7 +184,7 @@ class TransformerDecoderLayer(nn.Module): ): super().__init__() self.embed_dim = args.decoder_embed_dim - self.dropout_module = FairseqDropout( + self.dropout_module = dropout_class( args.dropout, module_name=self.__class__.__name__ ) self.quant_noise = getattr(args, "quant_noise_pq", 0) @@ -212,7 +208,7 @@ class TransformerDecoderLayer(nn.Module): if activation_dropout_p == 0: # for backwards compatibility with models that use args.relu_dropout activation_dropout_p = getattr(args, "relu_dropout", 0) - self.activation_dropout_module = FairseqDropout( + self.activation_dropout_module = dropout_class( float(activation_dropout_p), module_name=self.__class__.__name__ ) self.normalize_before = args.decoder_normalize_before @@ -315,22 +311,10 @@ class TransformerDecoderLayer(nn.Module): if need_head_weights: need_attn = True - if len(x.shape) == 3: - residual = x.view(-1, x.shape[2]).clone().npu_format_cast(29) - else: - residual = x.npu_format_cast(29) - - if len(encoder_out.shape) == 3: - encoder_out = encoder_out.view(-1, encoder_out.shape[2]).clone().npu_format_cast(29) - else: - encoder_out = encoder_out.npu_format_cast(29) + residual = x if self.normalize_before: x = self.self_attn_layer_norm(x) - if len(x.shape) == 3: - x = x.view(-1,x.shape[2]).clone().npu_format_cast(29) - else: - x= x.npu_format_cast(29) if prev_self_attn_state is not None: prev_key, prev_value = prev_self_attn_state[:2] diff --git a/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq/optim/__init__.py b/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq/optim/__init__.py index 94eb2c7ee966756590a4294eab181fc87fac2fa1..ab45abfd1f19a7d2adb237e59815d9d5b37fae9c 100644 --- a/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq/optim/__init__.py +++ b/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq/optim/__init__.py @@ -41,7 +41,6 @@ def 
build_optimizer( ): if all(isinstance(p, dict) for p in params): params = [t for p in params for t in p.values()] - params = list(filter(lambda p: p.requires_grad, params)) return _build_optimizer(optimizer_cfg, params, *extra_args, **extra_kwargs) diff --git a/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq/optim/fp16_optimizer.py b/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq/optim/fp16_optimizer.py index 385946b1c24cc678af04247146a2c840dfe4983f..dd2d7155d8f929aced9dd41510d8c7668baa2482 100644 --- a/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq/optim/fp16_optimizer.py +++ b/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq/optim/fp16_optimizer.py @@ -10,7 +10,7 @@ import torch from fairseq import optim, utils from .dynamic_loss_scaler import DynamicLossScaler - +from apex.contrib.combine_tensors import combine_npu class _FP16OptimizerMixin(object): def __init__(self, *args, **kwargs): @@ -26,38 +26,26 @@ class _FP16OptimizerMixin(object): ) @classmethod + @torch.no_grad() def build_fp32_params(cls, args, params, flatten=True): # create FP32 copy of parameters and grads + cls.fp32_tmp_params = dict() if flatten: is_pipeline_parallel = getattr( args, "pipeline_model_parallel", False ) and getattr(args, "distributed_no_spawn", False) - total_param_size = sum(p.data.numel() for p in params) + devices = [torch.npu.current_device()] if is_pipeline_parallel: devices = list(set(args.pipeline_devices)) fp32_params = {} for device in devices: - if is_pipeline_parallel: - device_param_size = sum( - p.data.numel() for p in params if p.device.index == device - ) - device_params = [p for p in params if p.device.index == device] - else: - device_param_size = total_param_size - device_params = params - fp32_params[device] = ( - device_params[0].new(0).float().new(device_param_size) - ) - offset = 0 - for p in device_params: - numel = p.data.numel() - fp32_params[device][offset : offset + numel].copy_(p.data.view(-1)) - offset += numel - fp32_params[device] = torch.nn.Parameter(fp32_params[device]) - fp32_params[device].grad = fp32_params[device].data.new( - device_param_size - ) + cls.fp32_tmp_params[device] = [] + for idx, p in enumerate(params): + cls.fp32_tmp_params[device].append(p.data.float()) + fp32_params[device] = combine_npu(cls.fp32_tmp_params[device]) + + fp32_params[device].grad = torch.zeros_like(fp32_params[device].data) return fp32_params else: fp32_params = [] @@ -103,24 +91,31 @@ class _FP16OptimizerMixin(object): # copy FP16 grads to FP32 if self.has_flat_params: devices = list(self.fp32_params.keys()) - device_params_dict = defaultdict(list) - for p in self.fp16_params: - if p.requires_grad: - device_params_dict[p.device.index].append(p) - for device in devices: - device_params = device_params_dict[device] - offset = 0 - for p in device_params: - grad_data = ( - p.grad.data - if p.grad is not None - else p.data.new_zeros(p.data.shape) - ) - numel = grad_data.numel() - self.fp32_params[device].grad.data[ - offset : offset + numel - ].copy_(grad_data.view(-1)) - offset += numel + if not self.combine_grads_flag: + device_params_dict = defaultdict(list) + for p in self.fp16_params: + if p.requires_grad: + device_params_dict[p.device.index].append(p) + for device in devices: + device_params = device_params_dict[device] + fp16_grads_list = [] + fp16_params_list = [] + for p in device_params: + grad_data = ( + p.grad.data + if p.grad is not None + else p.data.new_zeros(p.data.shape) + ) + fp16_grads_list.append(grad_data) + fp16_params_list.append(p.data) + 
self.fp16_tmp_grads[device] = combine_npu(fp16_grads_list) + self.fp16_tmp_params[device] = combine_npu(fp16_params_list) + self.fp32_params[device].grad.data.copy_(self.fp16_tmp_grads[device]) + self.combine_grads_flag = True + else: + for device in devices: + self.fp32_params[device].grad.data.copy_(self.fp16_tmp_grads[device]) + else: for p, p32 in zip(self.fp16_params, self.fp32_params): if not p.requires_grad: @@ -136,20 +131,17 @@ class _FP16OptimizerMixin(object): # copy FP32 params back into FP16 model if self.has_flat_params: devices = list(self.fp32_params.keys()) - device_params_dict = defaultdict(list) - for p in self.fp16_params: - device_params_dict[p.device.index].append(p) - for device in devices: - device_params = device_params_dict[device] - offset = 0 - for p in device_params: - numel = p.data.numel() - p.data.copy_( - self.fp32_params[device] - .data[offset : offset + numel] - .view_as(p.data) - ) - offset += numel + if not self.combine_grads_flag: + device_params_dict = defaultdict(list) + for p in self.fp16_params: + device_params_dict[p.device.index].append(p) + for device in devices: + device_params = device_params_dict[device] + for idx, p in enumerate(device_params): + p.data.copy_(self.fp32_tmp_params[device][idx].data) + else: + for device in devices: + self.fp16_tmp_params[device].data.copy_(self.fp32_params[device]) else: for p, p32 in zip(self.fp16_params, self.fp32_params): if not p.requires_grad: @@ -175,10 +167,10 @@ class _FP16OptimizerMixin(object): ) if self.scaler is not None: - if grad_norm > max_norm > 0.0: - self._multiply_factor *= max_norm / grad_norm + if max_norm > 0.0: + if grad_norm > max_norm: + self._multiply_factor *= max_norm / grad_norm - self.scaler.check_overflow(grad_norm) elif max_norm > 0.0: clip_coef = (max_norm / (grad_norm + 1e-6)).clamp_(max=1) self._multiply_factor *= clip_coef @@ -202,8 +194,14 @@ class _FP16OptimizerMixin(object): def zero_grad(self): """Clears the gradients of all optimized parameters.""" - for p in self.fp16_params: - p.grad = None + if self.combine_grads_flag: + devices = list(self.fp16_tmp_grads.keys()) + for device in devices: + self.fp16_tmp_grads[device].zero_() + else: + for p in self.fp16_params: + if p.grad is not None: + p.grad.zero_() if self.has_flat_params: if torch.is_tensor(self.fp32_params): self.fp32_params.grad.zero_() @@ -232,6 +230,9 @@ class FP16Optimizer(_FP16OptimizerMixin, optim.FairseqOptimizer): self.fp16_params = params self.fp32_optimizer = fp32_optimizer self.fp32_params = fp32_params + self.fp16_tmp_grads = dict() + self.fp16_tmp_params = dict() + self.combine_grads_flag = False if getattr(args, "fp16_scale_window", None) is None: if len(args.update_freq) > 1: diff --git a/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq/trainer.py b/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq/trainer.py index fe28b41f7c214b8d244d3371e65226fec8de64f1..25332f9b38e63ec6d02752759a207596755d3fce 100644 --- a/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq/trainer.py +++ b/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq/trainer.py @@ -15,16 +15,48 @@ from itertools import chain from typing import Any, Dict, List import torch +import torch.distributed as dist from fairseq import checkpoint_utils, distributed_utils, models, optim, utils from fairseq.file_io import PathManager from fairseq.logging import meters, metrics from fairseq.nan_detector import NanDetector from fairseq.optim import lr_scheduler +from fairseq.modules.fairseq_dropout import NpuFairseqDropout, 
get_dropout_class logger = logging.getLogger(__name__) +class PreFetcher: + def __init__(self, loader, device): + self.stream = torch.npu.Stream() + self.iterable = iter(loader) + self.len = len(loader) + self.device = device + self.preload() + + def __iter__(self): + return self + + def __next__(self): + torch.npu.current_stream().wait_stream(self.stream) + data = self.next_data + if data == -1: + raise StopIteration + if data is not None: + self.preload() + return data + + def preload(self): + try: + self.next_data = next(self.iterable) + except StopIteration: + self.next_data = -1 + return + with torch.npu.stream(self.stream): + self.next_data = utils.move_to_cuda(self.next_data, self.device) + + class Trainer(object): """Main class for data parallel training. @@ -38,7 +70,8 @@ class Trainer(object): def __init__(self, args, task, model, criterion, quantizer=None): self.args = args self.task = task - + self.reduce_stream = torch.npu.Stream() + self.first_grad = None # catalog shared parameters shared_params = _catalog_shared_params(model) @@ -67,6 +100,9 @@ class Trainer(object): if not args.pipeline_model_parallel: self._criterion = self._criterion.to(device=self.device) self._model = self._model.to(device=self.device) + + self._model.encoder.embed_positions.weight.data = self._model.encoder.embed_positions.weight.data.npu_format_cast(29) + self._model.decoder.embed_positions.weight.data = self._model.decoder.embed_positions.weight.data.npu_format_cast(29) self.pipeline_model_parallel = args.pipeline_model_parallel self.last_device = None if self.npu and self.pipeline_model_parallel: @@ -119,6 +155,8 @@ class Trainer(object): self._start_time = time.time() self._previous_training_time = 0 self._cumulative_training_time = None + if get_dropout_class() is NpuFairseqDropout: + NpuFairseqDropout.enable_dropout_ensemble(self.model) def reinitialize(self): """Reinitialize the Trainer, typically after model params change.""" @@ -436,13 +474,13 @@ class Trainer(object): self.criterion.train() self.zero_grad() + prefetch_samples = PreFetcher(samples, torch.npu.current_device()) metrics.log_start_time("train_wall", priority=800, round=4) # forward and backward pass logging_outputs, sample_size, ooms = [], 0, 0 - for i, sample in enumerate(samples): - sample = self._prepare_sample(sample) - if sample is None: + for i, sample in enumerate(prefetch_samples): + if sample == {}: # when sample is None, run forward/backward on a dummy batch # and ignore the resulting gradients sample = self._prepare_sample(self._dummy_batch) @@ -544,7 +582,10 @@ class Trainer(object): ) if hasattr(self.model, "all_reduce"): - self.model.all_reduce() + torch.npu.current_stream().wait_stream(self.reduce_stream) + if self.first_grad is not None: + self.first_grad.div_(8) + dist.all_reduce(self.first_grad) overflow = False try: diff --git a/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq_cli/train.py b/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq_cli/train.py index b7b7d7a26652299dc83fde98d2b7ad3f46e8f799..eb4e45c3a14b26a24b1cef6b355eebc25b56924e 100644 --- a/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq_cli/train.py +++ b/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq_cli/train.py @@ -28,6 +28,7 @@ from fairseq.data import iterators from fairseq.logging import meters, metrics, progress_bar from fairseq.model_parallel.megatron_trainer import MegatronTrainer from fairseq.trainer import Trainer +from fairseq.modules.multihead_attention import MHAConfig logging.basicConfig( @@ -172,6 +173,66 @@ 
def should_stop_early(args, valid_loss): else: return False +def wrapper_model_all_reduce(model, fp_16_grads, reduce_stream): + from change_data_ptr import change_data_ptr + import torch.distributed as dist + total_param_size = 0 + for name, para in model.named_parameters(): + if name == "module.encoder.embed_tokens.weight": + continue + total_param_size += para.storage().size() + + target_para_size_list = [] + tmp_size = 0 + name_dict = dict() + name_order = 0 + for name, para in model.named_parameters(): + if name == "module.necoder.embed_tokens.weight": + target_para_size_list.append(para.storage().size()) + name_dict[name] =name_order + name_order += 1 + continue + tmp_size += para.storage().size() + name_dict[name] = name_order + if tmp_size > total_param_size // 8: + target_para_size_list.append(tmp_size) + tmp_size = 0 + name_order += 1 + target_para_size_list.append(tmp_size) + partial_combined_grad_list = [] + idx = 0 + for ss in target_para_size_list: + tmp_tensor = torch.zeros(ss).half().npu() + for device in fp_16_grads: + change_data_ptr(tmp_tensor, fp_16_grads[device], idx) + partial_combined_grad_list.append(tmp_tensor) + idx += ss + + target_para_size_list = [pp *2 for pp in target_para_size_list] + current_para_size_list = [0] *(len(target_para_size_list)) + ready_reduce_index = [] + + def hook_func(name, target_para_size_list, current_para_size_list, name_dict, reduce_stream, partial_combined_grad_list, ready_reduce_index): + def hook_function(grad): + if ready_reduce_index: + index = ready_reduce_index.pop() + current_para_size_list[index] = 0 + with torch.npu.stream(reduce_stream): + partial_combined_grad_list[index].div_(8) + dist.all_reduce(partial_combined_grad_list[index]) + + current_para_size_list[name_dict[name]] += grad.storage().size() + for i in range(len(current_para_size_list)): + if current_para_size_list[i] == target_para_size_list[i] and i != 0: + ready_reduce_index.append(i) + return + return hook_function + + for name, para in model.named_parameters(): + para.register_hook(hook_func(name, target_para_size_list, current_para_size_list, name_dict, reduce_stream, partial_combined_grad_list, ready_reduce_index)) + + return partial_combined_grad_list[0] + @metrics.aggregate("train") def train(args, trainer, task, epoch_itr): @@ -206,11 +267,16 @@ def train(args, trainer, task, epoch_itr): valid_subsets = args.valid_subset.split(",") should_stop = False num_updates = trainer.get_num_updates() + visited = False + MHAConfig.set_fussion() for i, samples in enumerate(progress): with metrics.aggregate("train_inner"), torch.autograd.profiler.record_function( "train_step-%d" % i ): log_output = trainer.train_step(samples) + if hasattr(trainer.model, "all_reduce") and (trainer.optimizer.fp16_tmp_grads is not None) and (not visited) and (epoch_itr.epoch <= 1): + trainer.first_grad = wrapper_model_all_reduce(trainer.model, trainer.optimizer.fp16_tmp_grads, trainer.reduce_stream) + visited = True if log_output is not None: # not OOM, overflow, ... 
# log mid-epoch stats diff --git a/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/generate_on_en_de.sh b/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/generate_on_en_de.sh new file mode 100644 index 0000000000000000000000000000000000000000..1a202ccee400a1cbac57920634cbebc51349d5fd --- /dev/null +++ b/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/generate_on_en_de.sh @@ -0,0 +1,23 @@ +source env.sh +DATA_PATH=path_of_data # fix it to your own train data path +BPE_PATH=/path/sentence.bpe.model # fix it to your own sentence.bpe.model path +langs=ar_AR,cs_CZ,de_DE,en_XX,es_XX,et_EE,fi_FI,fr_XX,gu_IN,hi_IN,it_IT,ja_XX,kk_KZ,ko_KR,lt_LT,lv_LV,my_MM,ne_NP,nl_XX,ro_RO,ru_RU,si_LK,tr_TR,vi_VN,zh_CN +model_dir=$1 +DETOKENIZER=mosesdecoder/scripts/tokenizer/detokenizer.perl +HYP=hyp +REF=ref + +fairseq-generate $DATA_PATH \ + --fp16 --path $model_dir --max-tokens 4096 \ + --task translation_from_pretrained_bart \ + --gen-subset test \ + -t de_DE -s en_XX \ + --bpe 'sentencepiece' --sentencepiece-model $BPE_PATH \ + --scoring sacrebleu --remove-bpe 'sentencepiece' \ + --batch-size 32 --langs $langs > en_de +sed -i '$d' en_de +cat en_de | grep -P "^H" |sort -V |cut -f 3- > $HYP".txt" +cat en_de | grep -P "^T" |sort -V |cut -f 2- > $REF".txt" + +$DETOKENIZER -l de < $HYP".txt" >test.detok.hyp +sacrebleu -t wmt20 -l en-de -i test.detok.hyp -b \ No newline at end of file diff --git a/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/run_8p.sh b/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/run_8p.sh index 1d21b82a27d8e20d24b69a2ebd6552941a1a7701..e8761ac32a290e06f4ab149146685058add6ed4a 100644 --- a/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/run_8p.sh +++ b/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/run_8p.sh @@ -22,7 +22,7 @@ do --optimizer adam --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' \ --lr-scheduler polynomial_decay --lr 3e-05 --min-lr -1 --warmup-updates 2500 --total-num-update 40000 \ --dropout 0.3 --attention-dropout 0.1 --weight-decay 0.0 \ - --max-tokens 512 --update-freq 2 \ + --max-tokens 1024 --update-freq 2 \ --save-interval 1 --save-interval-updates 5000 --keep-interval-updates 10 --no-epoch-checkpoints \ --seed 222 --log-format simple --log-interval 2 \ --restore-file $PRETRAIN \ @@ -40,7 +40,7 @@ do --optimizer adam --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' \ --lr-scheduler polynomial_decay --lr 3e-05 --min-lr -1 --warmup-updates 2500 --total-num-update 40000 \ --dropout 0.3 --attention-dropout 0.1 --weight-decay 0.0 \ - --max-tokens 512 --update-freq 2 \ + --max-tokens 1024 --update-freq 2 \ --save-interval 1 --save-interval-updates 5000 --keep-interval-updates 10 --no-epoch-checkpoints \ --seed 222 --log-format simple --log-interval 2 \ --restore-file $PRETRAIN \ diff --git a/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/test/train_full_8p.sh b/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/test/train_full_8p.sh index 9570c9014db6017820d0d7916fa785f2ce2298c6..27e798534c1c141d1a342e6f13ea71311f4f28de 100644 --- a/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/test/train_full_8p.sh +++ b/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/test/train_full_8p.sh @@ -12,7 +12,7 @@ export RANK_SIZE=8 #网络名称,同目录名称 Network="mBART_for_PyTorch" #训练batch_size -token_size=512 +token_size=1024 #训练开始时间,不需要修改 start_time=$(date +%s) @@ -64,7 +64,7 @@ do --optimizer adam --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' \ --lr-scheduler polynomial_decay --lr 3e-05 --min-lr -1 --warmup-updates 2500 --total-num-update 40000 \ --dropout 0.3 --attention-dropout 0.1 --weight-decay 0.0 \ - --max-tokens 512 
--update-freq 2 \ + --max-tokens 1024 --update-freq 2 \ --save-interval 1 --save-interval-updates 5000 --keep-interval-updates 10 --no-epoch-checkpoints \ --seed 222 --log-format simple --log-interval 2 \ --restore-file $PRETRAIN \ @@ -82,7 +82,7 @@ do --optimizer adam --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' \ --lr-scheduler polynomial_decay --lr 3e-05 --min-lr -1 --warmup-updates 2500 --total-num-update 40000 \ --dropout 0.3 --attention-dropout 0.1 --weight-decay 0.0 \ - --max-tokens 512 --update-freq 2 \ + --max-tokens 1024 --update-freq 2 \ --save-interval 1 --save-interval-updates 5000 --keep-interval-updates 10 --no-epoch-checkpoints \ --seed 222 --log-format simple --log-interval 2 \ --restore-file $PRETRAIN \ diff --git a/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/test/train_performance_1p.sh b/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/test/train_performance_1p.sh index c29e1c7ecc92834eb612682c7369137c6a6edd56..1c329db30fa6158e3ccc768126b1c5d08076a094 100644 --- a/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/test/train_performance_1p.sh +++ b/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/test/train_performance_1p.sh @@ -1,127 +1,127 @@ -#!/bin/bash - -cur_path=`pwd`/../ -#失败用例打屏 -export ASCEND_SLOG_PRINT_TO_STDOUT=0 -export SCALAR_TO_HOST_MEM=1 - -export MKL_SERVICE_FORCE_INTEL=1 -export BMMV2_ENABLE=1 -langs=ar_AR,cs_CZ,de_DE,en_XX,es_XX,et_EE,fi_FI,fr_XX,gu_IN,hi_IN,it_IT,ja_XX,kk_KZ,ko_KR,lt_LT,lv_LV,my_MM,ne_NP,nl_XX,ro_RO,ru_RU,si_LK,tr_TR,vi_VN,zh_CN -#基础参数,需要模型审视修改 -#Batch Size -batch_size=512 -#网络名称,同目录名称 -Network="mBART_ID2372_for_PyTorch" -#Device数量,单卡默认为1 -RankSize=1 -#训练epoch,可选 -train_epochs=1 -#训练step -train_steps= -#学习率 -learning_rate=3e-05 - -#参数配置 -data_path="" - -if [[ $1 == --help || $1 == --h ]];then - echo "usage:./train_performance_1p.sh " - exit 1 -fi - -for para in $* -do - if [[ $para == --data_path* ]];then - data_path=`echo ${para#*=}` - elif [[ $para == --conda_name* ]];then - conda_name=`echo ${para#*=}` - source set_conda.sh --conda_name=$conda_name - #export PATH=/usr/local/python3.7.5/bin:/home/anaconda3/bin:$PATH - #source activate py8 - source activate $conda_name - - fi -done - -if [[ $data_path == "" ]];then - echo "[Error] para \"data_path\" must be config" - exit 1 - -fi -sed -i "s|checkpoint_utils.save_checkpoint(|#checkpoint_utils.save_checkpoint(|g" $cur_path/fairseq_cli/train.py -##############执行训练########## -cd $cur_path -if [ -d $cur_path/test/output ];then - rm -rf $cur_path/test/output/* - mkdir -p $cur_path/test/output/$ASCEND_DEVICE_ID -else - mkdir -p $cur_path/test/output/$ASCEND_DEVICE_ID -fi -wait - - -pip3 install --editable ./ -start=$(date +%s) -python3 train.py $data_path/en_ro/ \ - --distributed-world-size 1 --npu --npu-id $ASCEND_DEVICE_ID --fp16 --encoder-normalize-before --decoder-normalize-before \ - --arch mbart_large --layernorm-embedding \ - --task translation_from_pretrained_bart \ - --source-lang en_XX --target-lang ro_RO \ - --criterion label_smoothed_cross_entropy --label-smoothing 0.2 \ - --optimizer adam --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' \ - --lr-scheduler polynomial_decay --lr 3e-05 --min-lr -1 --warmup-updates 2500 --total-num-update 40000 \ - --dropout 0.3 --attention-dropout 0.1 --weight-decay 0.0 \ - --max-tokens 512 --update-freq 2 \ - --save-interval 1 --save-interval-updates 5000 --keep-interval-updates 10 --no-epoch-checkpoints \ - --seed 222 --log-format simple --log-interval 2 \ - --restore-file $data_path/mbart.cc25/model.pt \ - --reset-optimizer --reset-meters --reset-dataloader 
--reset-lr-scheduler \ - --langs $langs \ - --max-epoch $train_epochs \ - --max-update 200 \ - --ddp-backend no_c10d > $cur_path/test/output/$ASCEND_DEVICE_ID/train_$ASCEND_DEVICE_ID.log 2>&1 & -wait -end=$(date +%s) -e2etime=$(( $end - $start )) - -sed -i "s|#checkpoint_utils.save_checkpoint(|checkpoint_utils.save_checkpoint(|g" $cur_path/fairseq_cli/train.py -#结果打印,不需要修改 -echo "------------------ Final result ------------------" -#输出性能FPS,需要模型审视修改 -TrainingTime=0 -FPS=`grep -rn train_inner $cur_path/test/output/$ASCEND_DEVICE_ID/train_$ASCEND_DEVICE_ID.log | awk -F "wps=" '{print$2}' | awk -F "," '{print$1}' | tail -n+6 | awk '{sum+=$1} END {print"",sum/NR}' | sed s/[[:space:]]//g` -#打印,不需要修改 -echo "Final Performance images/sec : $FPS" -TrainingTime=`awk 'BEGIN{printf "%.2f\n",'${batch_size}'*1000/'${FPS}'}'` -#输出训练精度,需要模型审视修改 -#打印,不需要修改 -#echo "Final Train Accuracy : ${train_accuracy}" -echo "E2E Training Duration sec : $e2e_time" -#性能看护结果汇总 -#训练用例信息,不需要修改 -BatchSize=${batch_size} -DeviceType=`uname -m` -CaseName=${Network}_bs${BatchSize}_${RANK_SIZE}'p'_'perf' - -##获取性能数据,不需要修改 -#吞吐量 -ActualFPS=${FPS} - -#从train_$ASCEND_DEVICE_ID.log提取Loss到train_${CaseName}_loss.txt中,需要根据模型审视 -grep train_inner $cur_path/test/output/$ASCEND_DEVICE_ID/train_$ASCEND_DEVICE_ID.log | awk -F "loss=" '{print$2}' | awk -F "," '{print$1}' > $cur_path/test/output/$ASCEND_DEVICE_ID/train_${CaseName}_loss.txt - -#最后一个迭代loss值,不需要修改 -ActualLoss=`awk 'END {print}' $cur_path/test/output/$ASCEND_DEVICE_ID/train_${CaseName}_loss.txt` - -#关键信息打印到${CaseName}.log中,不需要修改 -echo "Network = ${Network}" > $cur_path/test/output/$ASCEND_DEVICE_ID/${CaseName}.log -echo "RankSize = ${RANK_SIZE}" >> $cur_path/test/output/$ASCEND_DEVICE_ID/${CaseName}.log -echo "BatchSize = ${BatchSize}" >> $cur_path/test/output/$ASCEND_DEVICE_ID/${CaseName}.log -echo "DeviceType = ${DeviceType}" >> $cur_path/test/output/$ASCEND_DEVICE_ID/${CaseName}.log -echo "CaseName = ${CaseName}" >> $cur_path/test/output/$ASCEND_DEVICE_ID/${CaseName}.log -echo "ActualFPS = ${ActualFPS}" >> $cur_path/test/output/$ASCEND_DEVICE_ID/${CaseName}.log -echo "TrainingTime = ${TrainingTime}" >> $cur_path/test/output/$ASCEND_DEVICE_ID/${CaseName}.log -echo "ActualLoss = ${ActualLoss}" >> $cur_path/test/output/$ASCEND_DEVICE_ID/${CaseName}.log -echo "E2ETrainingTime = ${e2etime}" >> $cur_path/test/output/$ASCEND_DEVICE_ID/${CaseName}.log - +#!/bin/bash + +cur_path=`pwd`/../ +#失败用例打屏 +export ASCEND_SLOG_PRINT_TO_STDOUT=0 +export SCALAR_TO_HOST_MEM=1 + +export MKL_SERVICE_FORCE_INTEL=1 +export BMMV2_ENABLE=1 +langs=ar_AR,cs_CZ,de_DE,en_XX,es_XX,et_EE,fi_FI,fr_XX,gu_IN,hi_IN,it_IT,ja_XX,kk_KZ,ko_KR,lt_LT,lv_LV,my_MM,ne_NP,nl_XX,ro_RO,ru_RU,si_LK,tr_TR,vi_VN,zh_CN +#基础参数,需要模型审视修改 +#Batch Size +batch_size=1024 +#网络名称,同目录名称 +Network="mBART_ID2372_for_PyTorch" +#Device数量,单卡默认为1 +RankSize=1 +#训练epoch,可选 +train_epochs=1 +#训练step +train_steps= +#学习率 +learning_rate=3e-05 + +#参数配置 +data_path="" + +if [[ $1 == --help || $1 == --h ]];then + echo "usage:./train_performance_1p.sh " + exit 1 +fi + +for para in $* +do + if [[ $para == --data_path* ]];then + data_path=`echo ${para#*=}` + elif [[ $para == --conda_name* ]];then + conda_name=`echo ${para#*=}` + source set_conda.sh --conda_name=$conda_name + #export PATH=/usr/local/python3.7.5/bin:/home/anaconda3/bin:$PATH + #source activate py8 + source activate $conda_name + + fi +done + +if [[ $data_path == "" ]];then + echo "[Error] para \"data_path\" must be config" + exit 1 + +fi +sed -i 
"s|checkpoint_utils.save_checkpoint(|#checkpoint_utils.save_checkpoint(|g" $cur_path/fairseq_cli/train.py +##############执行训练########## +cd $cur_path +if [ -d $cur_path/test/output ];then + rm -rf $cur_path/test/output/* + mkdir -p $cur_path/test/output/$ASCEND_DEVICE_ID +else + mkdir -p $cur_path/test/output/$ASCEND_DEVICE_ID +fi +wait + + +pip3 install --editable ./ +start=$(date +%s) +python3 train.py $data_path/en_ro/ \ + --distributed-world-size 1 --npu --npu-id $ASCEND_DEVICE_ID --fp16 --encoder-normalize-before --decoder-normalize-before \ + --arch mbart_large --layernorm-embedding \ + --task translation_from_pretrained_bart \ + --source-lang en_XX --target-lang ro_RO \ + --criterion label_smoothed_cross_entropy --label-smoothing 0.2 \ + --optimizer adam --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' \ + --lr-scheduler polynomial_decay --lr 3e-05 --min-lr -1 --warmup-updates 2500 --total-num-update 40000 \ + --dropout 0.3 --attention-dropout 0.1 --weight-decay 0.0 \ + --max-tokens 1024 --update-freq 2 \ + --save-interval 1 --save-interval-updates 5000 --keep-interval-updates 10 --no-epoch-checkpoints \ + --seed 222 --log-format simple --log-interval 2 \ + --restore-file $data_path/mbart.cc25/model.pt \ + --reset-optimizer --reset-meters --reset-dataloader --reset-lr-scheduler \ + --langs $langs \ + --max-epoch $train_epochs \ + --max-update 200 \ + --ddp-backend no_c10d > $cur_path/test/output/$ASCEND_DEVICE_ID/train_$ASCEND_DEVICE_ID.log 2>&1 & +wait +end=$(date +%s) +e2etime=$(( $end - $start )) + +sed -i "s|#checkpoint_utils.save_checkpoint(|checkpoint_utils.save_checkpoint(|g" $cur_path/fairseq_cli/train.py +#结果打印,不需要修改 +echo "------------------ Final result ------------------" +#输出性能FPS,需要模型审视修改 +TrainingTime=0 +FPS=`grep -rn train_inner $cur_path/test/output/$ASCEND_DEVICE_ID/train_$ASCEND_DEVICE_ID.log | awk -F "wps=" '{print$2}' | awk -F "," '{print$1}' | tail -n+6 | awk '{sum+=$1} END {print"",sum/NR}' | sed s/[[:space:]]//g` +#打印,不需要修改 +echo "Final Performance images/sec : $FPS" +TrainingTime=`awk 'BEGIN{printf "%.2f\n",'${batch_size}'*1000/'${FPS}'}'` +#输出训练精度,需要模型审视修改 +#打印,不需要修改 +#echo "Final Train Accuracy : ${train_accuracy}" +echo "E2E Training Duration sec : $e2e_time" +#性能看护结果汇总 +#训练用例信息,不需要修改 +BatchSize=${batch_size} +DeviceType=`uname -m` +CaseName=${Network}_bs${BatchSize}_${RANK_SIZE}'p'_'perf' + +##获取性能数据,不需要修改 +#吞吐量 +ActualFPS=${FPS} + +#从train_$ASCEND_DEVICE_ID.log提取Loss到train_${CaseName}_loss.txt中,需要根据模型审视 +grep train_inner $cur_path/test/output/$ASCEND_DEVICE_ID/train_$ASCEND_DEVICE_ID.log | awk -F "loss=" '{print$2}' | awk -F "," '{print$1}' > $cur_path/test/output/$ASCEND_DEVICE_ID/train_${CaseName}_loss.txt + +#最后一个迭代loss值,不需要修改 +ActualLoss=`awk 'END {print}' $cur_path/test/output/$ASCEND_DEVICE_ID/train_${CaseName}_loss.txt` + +#关键信息打印到${CaseName}.log中,不需要修改 +echo "Network = ${Network}" > $cur_path/test/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "RankSize = ${RANK_SIZE}" >> $cur_path/test/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "BatchSize = ${BatchSize}" >> $cur_path/test/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "DeviceType = ${DeviceType}" >> $cur_path/test/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "CaseName = ${CaseName}" >> $cur_path/test/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "ActualFPS = ${ActualFPS}" >> $cur_path/test/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "TrainingTime = ${TrainingTime}" >> $cur_path/test/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "ActualLoss = ${ActualLoss}" >> 
$cur_path/test/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "E2ETrainingTime = ${e2etime}" >> $cur_path/test/output/$ASCEND_DEVICE_ID/${CaseName}.log + diff --git a/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/test/train_performance_8p.sh b/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/test/train_performance_8p.sh index 9dc7e1687a0553027f1908f4ccc2faa566bae9ed..ab1b36a9e03ccc957274b9c973c27affae8ff8ec 100644 --- a/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/test/train_performance_8p.sh +++ b/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/test/train_performance_8p.sh @@ -16,7 +16,7 @@ train_epochs=1 #网络名称,同目录名称 Network="mBART_ID2372_for_PyTorch" #训练batch_size -token_size=512 +token_size=1024 #训练开始时间,不需要修改 start_time=$(date +%s) @@ -80,7 +80,7 @@ do --optimizer adam --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' \ --lr-scheduler polynomial_decay --lr 3e-05 --min-lr -1 --warmup-updates 2500 --total-num-update 40000 \ --dropout 0.3 --attention-dropout 0.1 --weight-decay 0.0 \ - --max-tokens 512 --update-freq 2 \ + --max-tokens 1024 --update-freq 2 \ --save-interval 1 --save-interval-updates 5000 --keep-interval-updates 10 --no-epoch-checkpoints \ --seed 222 --log-format simple --log-interval 2 \ --restore-file $data_path/mbart.cc25/model.pt \ @@ -99,7 +99,7 @@ do --optimizer adam --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' \ --lr-scheduler polynomial_decay --lr 3e-05 --min-lr -1 --warmup-updates 2500 --total-num-update 40000 \ --dropout 0.3 --attention-dropout 0.1 --weight-decay 0.0 \ - --max-tokens 512 --update-freq 2 \ + --max-tokens 1024 --update-freq 2 \ --save-interval 1 --save-interval-updates 5000 --keep-interval-updates 10 --no-epoch-checkpoints \ --seed 222 --log-format simple --log-interval 2 \ --restore-file $data_path/mbart.cc25/model.pt \
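
Note on the fixed-shape bucketing introduced in `collate_tokens`, `num_tokens`, and `batch_by_size_fast_fix` above: sequence lengths are rounded up to a fixed set of buckets (16–1024) so that every padded batch takes one of a handful of static shapes, which keeps operator shapes stable on the NPU across steps. Below is a minimal plain-Python sketch of that idea, not the code in this patch; the helper names (`bucketed_length`, `batches_by_bucket`) are illustrative, and it assumes samples longer than the largest bucket were filtered out beforehand.

```python
BUCKETS = [16, 32, 64, 128, 256, 512, 1024]

def bucketed_length(length, buckets=BUCKETS):
    """Round a raw sequence length up to the nearest bucket size."""
    for bucket in buckets:
        if length <= bucket:
            return bucket
    # The patch assumes samples are pre-truncated to the largest bucket.
    raise ValueError(f"length {length} exceeds the largest bucket {buckets[-1]}")

def batches_by_bucket(lengths, max_tokens, buckets=BUCKETS):
    """Group sample indices so every batch holds samples of a single bucket."""
    per_bucket = {bucket: [] for bucket in buckets}
    for idx, length in enumerate(lengths):
        per_bucket[bucketed_length(length, buckets)].append(idx)

    batches = []
    for i, bucket in enumerate(buckets):
        indices = per_bucket[bucket]
        batch_size = max(1, max_tokens // bucket)
        full = (len(indices) // batch_size) * batch_size
        if full < len(indices) and i + 1 < len(buckets):
            # As in batch_by_size_fast_fix, leftover samples are promoted to the
            # next (larger) bucket instead of forming an oddly-sized tail batch.
            per_bucket[buckets[i + 1]].extend(indices[full:])
            indices = indices[:full]
        # (In this sketch the largest bucket simply keeps its tail as a smaller
        # final batch.)
        batches.extend(indices[j:j + batch_size]
                       for j in range(0, len(indices), batch_size))
    return batches
```

Because `num_tokens` now reports the bucketed (padded) length, `--max-tokens` effectively counts padded tokens, which is presumably why the training scripts raise `--max-tokens` from 512 to 1024 alongside the new 1024 bucket.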
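
Note on the attention-mask sign change in `transformer.py` and `multihead_attention.py`: the padding mask is built as an additive fp16 bias — the boolean mask is scaled by -65504 (the most negative finite fp16 value) and added to the attention scores, replacing the earlier `* 65504` plus subtraction — so padded key positions receive effectively zero probability after softmax. A small device-agnostic sketch of that additive masking in plain PyTorch (not the fused NPU kernels; shapes and names are illustrative):

```python
import torch

def additive_padding_mask(pad_mask, num_heads, tgt_len):
    """Turn a boolean [batch, src_len] padding mask into an additive bias of shape
    [batch, num_heads, tgt_len, src_len]: 0 at real tokens, a large negative
    value at padding, so softmax gives padded keys ~zero attention weight."""
    neg = torch.finfo(torch.float16).min  # -65504, as used in the patch
    bias = pad_mask.to(torch.float16) * neg
    return bias[:, None, None, :].expand(-1, num_heads, tgt_len, -1)

# Toy usage: mask the last two source positions of every sample.
batch, heads, tgt_len, src_len = 2, 4, 5, 7
scores = torch.randn(batch, heads, tgt_len, src_len, dtype=torch.float16)
pad_mask = torch.zeros(batch, src_len, dtype=torch.bool)
pad_mask[:, -2:] = True
scores = scores + additive_padding_mask(pad_mask, heads, tgt_len)
probs = torch.softmax(scores.float(), dim=-1)  # padded columns get ~0 weight
```

With the negative sign baked into the mask, the consumer only has to add it, which matches the switch from `attn_weights - key_padding_mask` to `attn_weights + key_padding_mask` in `multihead_attention.py`.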
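
Note on the `fp16_optimizer.py` changes: they keep the standard mixed-precision pattern — an FP32 master copy of the parameters and gradients — but build it with `combine_npu` from the Ascend build of apex instead of the original per-parameter flatten/copy loop, and cache combined FP16 views (`fp16_tmp_grads`, `fp16_tmp_params`) so later steps reduce to whole-tensor copies. As a reference for what the master-copy bookkeeping is doing, here is a minimal, device-agnostic sketch in plain PyTorch (no tensor fusion; function names are illustrative, not from this patch):

```python
import torch

def build_fp32_master(params):
    """One flat FP32 copy of the FP16 parameters, plus a flat FP32 grad buffer."""
    flat = torch.cat([p.detach().reshape(-1).float() for p in params])
    master = torch.nn.Parameter(flat)
    master.grad = torch.zeros_like(master)
    return master

def sync_grads_to_master(params, master):
    """Copy (and upcast) the FP16 grads into the flat FP32 grad buffer."""
    offset = 0
    for p in params:
        n = p.numel()
        grad = p.grad if p.grad is not None else torch.zeros_like(p)
        master.grad[offset:offset + n].copy_(grad.reshape(-1))
        offset += n

def sync_master_to_params(params, master):
    """Copy the updated FP32 master weights back into the FP16 parameters."""
    offset = 0
    for p in params:
        n = p.numel()
        p.data.copy_(master.data[offset:offset + n].view_as(p))
        offset += n
```

The patched code avoids these per-parameter loops after the first step: once the FP16 gradients have been combined into one contiguous tensor, the FP16→FP32 and FP32→FP16 transfers each become a single `copy_` between the combined tensors.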