diff --git a/PyTorch/built-in/foundation/GPT-NeoX/configs/20B.yml b/PyTorch/built-in/foundation/GPT-NeoX/configs/20B.yml
index 6595919cb099ca52a81261db2c17e1b882f912b4..7fcbdab89a64abed1b3aa0b80a36360164460eac 100644
--- a/PyTorch/built-in/foundation/GPT-NeoX/configs/20B.yml
+++ b/PyTorch/built-in/foundation/GPT-NeoX/configs/20B.yml
@@ -13,8 +13,8 @@
 
   # parallelism settings ( you will want to change these based on your cluster setup, ideally scheduling pipeline stages
   # across the node boundaries )
-  "pipe-parallel-size": 4,
-  "model-parallel-size": 2,
+  "pipe-parallel-size": 2,
+  "model-parallel-size": 4,
 
   # model settings
   "num-layers": 44,
@@ -28,13 +28,15 @@
   "no-weight-tying": true,
   "gpt_j_residual": true,
   "output_layer_parallelism": "column",
-  "scaled-upper-triang-masked-softmax-fusion": true,
-  "bias-gelu-fusion": true,
+  "scaled-upper-triang-masked-softmax-fusion": false,
+  "bias-gelu-fusion": false,
 
   # init methods
   "init_method": "small_init",
   "output_layer_init_method": "wang_init",
 
+  "scaled_masked_softmax_fusion": true,
+
   # optimizer settings
   "optimizer": {
     "type": "Adam",
@@ -82,7 +84,7 @@
     "enabled": true,
     "loss_scale": 0,
     "loss_scale_window": 1000,
-    "initial_scale_power": 12,
+    "initial_scale_power": 32,
     "hysteresis": 2,
     "min_loss_scale": 1
   },
@@ -99,12 +101,12 @@
   "eval-iters": 10,
 
   # logging
-  "log-interval": 2,
-  "steps_per_print": 2,
+  "log-interval": 1,
+  "steps_per_print": 1,
   "wall_clock_breakdown": false,
 
   ### NEW DATA: ####
-  "tokenizer_type": "HFTokenizer",
+#  "tokenizer_type": "HFTokenizer",
   "tensorboard-dir": "./tensorboard",
   "log-dir": "./logs",
diff --git a/PyTorch/built-in/foundation/GPT-NeoX/megatron/fused_kernels/__init__.py b/PyTorch/built-in/foundation/GPT-NeoX/megatron/fused_kernels/__init__.py
index 75dd19f39a3178d5cff1502322f652e690cd87c0..faf28406966eb3e71e0167efd83334d4cef161d1 100644
--- a/PyTorch/built-in/foundation/GPT-NeoX/megatron/fused_kernels/__init__.py
+++ b/PyTorch/built-in/foundation/GPT-NeoX/megatron/fused_kernels/__init__.py
@@ -31,8 +31,8 @@ os.environ["TORCH_CUDA_ARCH_LIST"] = ""
 
 def load_fused_kernels():
     try:
-        import scaled_upper_triang_masked_softmax_cuda
-        import scaled_masked_softmax_cuda
+        import torch_npu
+        from torch_npu import npu_scaled_masked_softmax
     except (ImportError, ModuleNotFoundError):
         print("\n")
         print("=" * 100)
diff --git a/PyTorch/built-in/foundation/GPT-NeoX/megatron/model/fused_softmax.py b/PyTorch/built-in/foundation/GPT-NeoX/megatron/model/fused_softmax.py
index 78f2992adb7b1c51f3531897dd569dd10cd8b0d1..ff4b6b33cd62070b467da576f19c350ac627caee 100644
--- a/PyTorch/built-in/foundation/GPT-NeoX/megatron/model/fused_softmax.py
+++ b/PyTorch/built-in/foundation/GPT-NeoX/megatron/model/fused_softmax.py
@@ -133,39 +133,26 @@ class FusedScaleMaskSoftmax(nn.Module):
         self.softmax_in_fp32 = softmax_in_fp32
         self.scale = scale
 
-        assert (
-            self.scale is None or softmax_in_fp32
-        ), "softmax should be in fp32 when scaled"
-
     def forward(self, input, mask):
         # [b, np, sq, sk]
         assert input.dim() == 4
 
         if self.is_kernel_available(mask, *input.size()):
-            return self.forward_fused_softmax(input, mask)
+            import torch_npu
+            return torch_npu.npu_scaled_masked_softmax(input, mask, self.scale, False)
         else:
             return self.forward_torch_softmax(input, mask)
 
+
     def is_kernel_available(self, mask, b, np, sq, sk):
-        attn_batches = b * np
-
-        if (
-            self.fusion  # user wants to fuse
-            and self.input_in_float16  # input must be fp16
-            and mask is not None  # mask tensor must not be None
-            and 16 < sk <= 2048  # sk must be 16 ~ 2048
-            and sq % 4 == 0  # sq must be divisor of 4
-            and attn_batches % 4 == 0  # np * b must be divisor of 4
-        ):
-            if 0 <= sk <= 2048:
-                batch_per_block = self.get_batch_per_block(sq, sk, b, np)
-
-                if self.upper_triang_mask_fusion:
-                    if attn_batches % batch_per_block == 0:
-                        return True
-                else:
-                    if sq % batch_per_block == 0:
-                        return True
-        return False
+        return (
+            self.fusion  # user wants to fuse
+            and self.input_in_float16  # input must be fp16
+            and 32 < sk <= 2048  # sk must be 32 ~ 2048
+            and sq % 16 == 0  # sq must be a multiple of 16
+            and sk % 16 == 0  # sk must be a multiple of 16
+        )
+
+
     def forward_fused_softmax(self, input, mask):
         b, np, sq, sk = input.size()
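Note: the fused_softmax.py hunk above routes FusedScaleMaskSoftmax.forward to torch_npu.npu_scaled_masked_softmax. The sketch below is one way to sanity-check that call against an eager reference; it assumes the usual Megatron mask convention (True marks positions to suppress), approximates the fused op's masked fill with -10000.0, and needs an Ascend device with torch_npu installed. None of those details are spelled out in the patch, so treat the tolerance as a knob rather than a specification.

```python
# Hedged sanity check for the fused NPU softmax (assumptions noted above).
import torch
import torch_npu  # registers the NPU backend and the fused op

def eager_scaled_masked_softmax(x, mask, scale):
    # Eager stand-in for what the fused call is expected to compute.
    x = x.float() * scale
    x = x.masked_fill(mask, -10000.0)  # assumed fill value for masked positions
    return torch.nn.Softmax(dim=-1)(x).half()

b, heads, sq, sk = 2, 4, 64, 64        # sq and sk multiples of 16, 32 < sk <= 2048
scores = torch.randn(b, heads, sq, sk).half().npu()
mask = (torch.rand(b, 1, sq, sk) > 0.5).npu()   # [b, 1, sq, sk], as in the model code

fused = torch_npu.npu_scaled_masked_softmax(scores, mask, 1.0, False)
reference = eager_scaled_masked_softmax(scores, mask, 1.0)
print(torch.allclose(fused.cpu(), reference.cpu(), atol=1e-3))
```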
diff --git a/PyTorch/built-in/foundation/GPT-NeoX/megatron/model/positional_embeddings.py b/PyTorch/built-in/foundation/GPT-NeoX/megatron/model/positional_embeddings.py
index ff53e9fbf04c615b5c4caa0163a3ce99bbfd5f92..8620800d36281eac1af825b4ae740192a533faa0 100644
--- a/PyTorch/built-in/foundation/GPT-NeoX/megatron/model/positional_embeddings.py
+++ b/PyTorch/built-in/foundation/GPT-NeoX/megatron/model/positional_embeddings.py
@@ -66,7 +66,7 @@ class RotaryEmbedding(torch.nn.Module):
 
 # rotary pos emb helpers:
 def rotate_half(x):
-    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
+    x1, x2 = torch.chunk(x, 2, -1)
     return torch.cat(
         (-x2, x1), dim=x1.ndim - 1
     )  # dim=-1 triggers a bug in earlier torch versions
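Note: the rotate_half change only swaps explicit slicing for torch.chunk. A quick CPU check of the equivalence (any even last dimension works) is sketched below.

```python
# CPU check: the removed slicing form and the new torch.chunk form of
# rotate_half produce identical results for an even last dimension.
import torch

def rotate_half_slice(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=x1.ndim - 1)

def rotate_half_chunk(x):
    x1, x2 = torch.chunk(x, 2, -1)
    return torch.cat((-x2, x1), dim=x1.ndim - 1)

x = torch.randn(2, 8, 4, 64)
print(torch.equal(rotate_half_slice(x), rotate_half_chunk(x)))  # True
```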
diff --git a/PyTorch/built-in/foundation/GPT-NeoX/megatron/model/transformer.py b/PyTorch/built-in/foundation/GPT-NeoX/megatron/model/transformer.py
index e917396a26e2d21e70d3ac540d9b056e01fad33e..d5d428faa1459378878f7fe08629f815f4b66594 100644
--- a/PyTorch/built-in/foundation/GPT-NeoX/megatron/model/transformer.py
+++ b/PyTorch/built-in/foundation/GPT-NeoX/megatron/model/transformer.py
@@ -224,8 +224,7 @@ class ParallelSelfAttention(nn.Module):
             gather_output=False,
             init_method=init_method,
         )
-
-        coeff = None
+        coeff = 1
         self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
         if self.apply_query_key_layer_scaling:
             coeff = max(1, self.layer_number)
@@ -332,14 +331,6 @@ class ParallelSelfAttention(nn.Module):
         )
         key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)
 
-        # preallocating result tensor: [b * np, sq, sk]
-        matmul_result = torch.empty(
-            output_size[0] * output_size[1],
-            output_size[2],
-            output_size[3],
-            dtype=query_layer.dtype,
-            device=torch.cuda.current_device(),
-        )
 
         '''
         ### Raw attention scores. [b * np, sq, sk] # orig
@@ -544,6 +535,11 @@ class ParallelSelfAttention(nn.Module):
 
                 # full rotary
                 query_rot, key_rot = query_layer, key_layer
+            query_rot = query_rot.contiguous()
+            query_pass = query_pass.contiguous()
+            key_rot = key_rot.contiguous()
+            key_pass = key_pass.contiguous()
+
             apply_rotary_fn = (
                 apply_rotary_pos_emb_torch if self.bf16 else apply_rotary_pos_emb
             )
@@ -717,7 +713,7 @@ class ParallelTransformerLayer(nn.Module):
             with torch.enable_grad():
                 attention_output = bias_dropout_fn(
                     attention_output,
-                    bias=attention_bias.expand_as(attention_output),
+                    bias=attention_bias,
                     residual=None,
                     prob=self.hidden_dropout,
                 )
@@ -727,7 +723,7 @@ class ParallelTransformerLayer(nn.Module):
             with torch.enable_grad():
                 output = bias_dropout_fn(
                     mlp_output,
-                    bias=mlp_bias.expand_as(mlp_output),
+                    bias=mlp_bias,
                     residual=attention_output,
                     prob=self.hidden_dropout,
                 )
diff --git a/PyTorch/built-in/foundation/GPT-NeoX/megatron/mpu/cross_entropy.py b/PyTorch/built-in/foundation/GPT-NeoX/megatron/mpu/cross_entropy.py
index d28e0cc6baf83dabc22cbc1b9a844fbdcfa79be5..a13d7601b131a5ebcdd40f63f4a7f76bccc29c4d 100644
--- a/PyTorch/built-in/foundation/GPT-NeoX/megatron/mpu/cross_entropy.py
+++ b/PyTorch/built-in/foundation/GPT-NeoX/megatron/mpu/cross_entropy.py
@@ -50,7 +50,7 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
         # Create a mask of valid vocab ids (1 means it needs to be masked).
         target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
         masked_target = target.clone() - vocab_start_index
-        masked_target[target_mask] = 0
+        masked_target *= ~target_mask
 
         # Get predicted-logits = logits[target].
         # For Simplicity, we convert logits to a 2-D tensor with size
@@ -60,10 +60,10 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
         arange_1d = torch.arange(
             start=0, end=logits_2d.size()[0], device=logits_2d.device
         )
-        predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
+        predicted_logits_1d = logits_2d[arange_1d, masked_target_1d.long()]
         predicted_logits_1d = predicted_logits_1d.clone().contiguous()
         predicted_logits = predicted_logits_1d.view_as(target)
-        predicted_logits[target_mask] = 0.0
+        predicted_logits *= ~target_mask
         # All reduce is needed to get the chunks from other GPUs.
         torch.distributed.all_reduce(
             predicted_logits,
diff --git a/PyTorch/built-in/foundation/GPT-NeoX/megatron/mpu/layers.py b/PyTorch/built-in/foundation/GPT-NeoX/megatron/mpu/layers.py
index 7408f1f490fd16d2ef4e0e0bc0066816ff887098..2c48ff6918ce54dd71be45d4a9850ccc1e78e5ef 100644
--- a/PyTorch/built-in/foundation/GPT-NeoX/megatron/mpu/layers.py
+++ b/PyTorch/built-in/foundation/GPT-NeoX/megatron/mpu/layers.py
@@ -186,7 +186,7 @@ class VocabParallelEmbedding(torch.nn.Module):
             )
             # Mask the input.
             masked_input = input_.clone() - self.vocab_start_index
-            masked_input[input_mask] = 0
+            masked_input *= ~input_mask
         else:
             masked_input = input_
         # Get the embeddings.
@@ -201,7 +201,7 @@ class VocabParallelEmbedding(torch.nn.Module):
         )
         # Mask the output embedding.
         if self.model_parallel_size > 1:
-            output_parallel[input_mask, :] = 0.0
+            output_parallel *= ~input_mask[..., None]
         # Reduce across all the model parallel GPUs.
         output = reduce_from_model_parallel_region(output_parallel)
         return output
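Note: the cross_entropy.py and layers.py hunks above replace boolean index assignment (t[mask] = 0) with multiplication by the inverted mask, which zeroes the same elements without boolean index assignment. The CPU sketch below illustrates the equivalence; the vocab bounds are illustrative values, not taken from the patch.

```python
# CPU sketch: masked assignment vs. multiplication by the inverted mask.
import torch

target = torch.randint(0, 100, (4, 8))
vocab_start, vocab_end = 25, 75          # hypothetical partition bounds
mask = (target < vocab_start) | (target >= vocab_end)

a = target.clone() - vocab_start
a[mask] = 0                              # original formulation
b = target.clone() - vocab_start
b *= ~mask                               # formulation used in the patch
print(torch.equal(a, b))                 # True

emb = torch.randn(4, 8, 16)
c = emb.clone()
c[mask, :] = 0.0                         # original VocabParallelEmbedding masking
d = emb.clone()
d *= ~mask[..., None]                    # mask broadcast over the hidden dimension
print(torch.equal(c, d))                 # True
```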
diff --git a/PyTorch/built-in/foundation/GPT-NeoX/tests/env_npu.sh b/PyTorch/built-in/foundation/GPT-NeoX/tests/env_npu.sh
new file mode 100644
index 0000000000000000000000000000000000000000..6bfb85d9fd5fa375c3557b9a6c70b0451b58726c
--- /dev/null
+++ b/PyTorch/built-in/foundation/GPT-NeoX/tests/env_npu.sh
@@ -0,0 +1,70 @@
+#!/bin/bash
+CANN_INSTALL_PATH_CONF='/etc/Ascend/ascend_cann_install.info'
+
+if [ -f $CANN_INSTALL_PATH_CONF ]; then
+    CANN_INSTALL_PATH=$(cat $CANN_INSTALL_PATH_CONF | grep Install_Path | cut -d "=" -f 2)
+else
+    CANN_INSTALL_PATH="/usr/local/Ascend"
+fi
+
+if [ -d ${CANN_INSTALL_PATH}/ascend-toolkit/latest ]; then
+    source ${CANN_INSTALL_PATH}/ascend-toolkit/set_env.sh
+else
+    source ${CANN_INSTALL_PATH}/nnae/set_env.sh
+fi
+
+# Set the device-side log level to error
+msnpureport -g error -d 0
+msnpureport -g error -d 1
+msnpureport -g error -d 2
+msnpureport -g error -d 3
+msnpureport -g error -d 4
+msnpureport -g error -d 5
+msnpureport -g error -d 6
+msnpureport -g error -d 7
+
+# Disable device-side event logging
+msnpureport -e disable
+
+# Output host logs to the serial port, 0-off/1-on
+export ASCEND_SLOG_PRINT_TO_STDOUT=0
+# Set the default log level, 0-debug/1-info/2-warning/3-error
+export ASCEND_GLOBAL_LOG_LEVEL=3
+# Event log enable flag, 0-off/1-on
+export ASCEND_GLOBAL_EVENT_ENABLE=0
+# Whether to enable the task queue, 0-off/1-on
+export TASK_QUEUE_ENABLE=1
+# Whether to enable PTCopy, 0-off/1-on
+export PTCOPY_ENABLE=1
+# Whether to enable the combined flag for 2 non-contiguous tensors, 0-off/1-on
+export COMBINED_ENABLE=0
+# Whether special cases need recompilation; no need to modify
+export DYNAMIC_OP="ADD#MUL"
+# HCCL whitelist switch, 1-off/0-on
+export HCCL_WHITELIST_DISABLE=1
+# Set the HCCL connect timeout
+export HCCL_CONNECT_TIMEOUT=1200
+ulimit -SHn 512000
+
+path_lib=$(python3.7 -c """
+import sys
+import re
+result=''
+for index in range(len(sys.path)):
+    match_sit = re.search('-packages', sys.path[index])
+    if match_sit is not None:
+        match_lib = re.search('lib', sys.path[index])
+
+        if match_lib is not None:
+            end=match_lib.span()[1]
+            result += sys.path[index][0:end] + ':'
+
+        result+=sys.path[index] + '/torch/lib:'
+print(result)"""
+)
+
+echo ${path_lib}
+
+export LD_LIBRARY_PATH=/usr/local/python3.7.5/lib/:${path_lib}:$LD_LIBRARY_PATH
+export HCCL_WHITELIST_DISABLE=1
+#export HCCL_IF_IP=$(hostname -I |awk '{print $1}')
diff --git a/PyTorch/built-in/foundation/GPT-NeoX/train.py b/PyTorch/built-in/foundation/GPT-NeoX/train.py
index 9731b77deb05efecab5d66a938078ed4c44afa04..d8340ae5fdcae632e104952746690c2cee5874c9 100644
--- a/PyTorch/built-in/foundation/GPT-NeoX/train.py
+++ b/PyTorch/built-in/foundation/GPT-NeoX/train.py
@@ -27,7 +27,7 @@ import os
 
 
 if __name__ == "__main__":
-    torch.npu.set_compile_mode(jit_compile=False)
+    torch.npu.set_compile_mode(jit_compile=True)
     option = {"NPU_FUZZY_COMPILE_BLACKLIST": "Tril,SoftmaxV2,LayerNormGrad", "MM_BMM_ND_ENABLE": 'enable'}
     torch.npu.set_option(option)
     neox_args = NeoXArgs.consume_neox_args()
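Note: the python3.7 -c snippet embedded in env_npu.sh above builds the colon-separated library path fragment that gets prepended to LD_LIBRARY_PATH. Below is a more readable restatement of the same logic; it only re-expresses what the embedded code already does and assumes nothing beyond it.

```python
# Readable restatement of the embedded snippet in env_npu.sh: for every sys.path
# entry that looks like a site-packages directory, emit the prefix up to and
# including "lib" (when present) plus that entry's torch/lib subdirectory.
import sys

parts = []
for p in sys.path:
    if "-packages" in p:
        i = p.find("lib")
        if i != -1:
            parts.append(p[: i + 3])   # e.g. /usr/local/python3.7.5/lib
        parts.append(p + "/torch/lib")
print(":".join(parts) + ":" if parts else "")
```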