diff --git a/PyTorch/built-in/foundation/GPT-NeoX/megatron/model/transformer.py b/PyTorch/built-in/foundation/GPT-NeoX/megatron/model/transformer.py
index 8392c92f2f97e0576a8abbdc4faf88274c9f071f..528eb76554d06351567c607eac6fb8980337e19c 100644
--- a/PyTorch/built-in/foundation/GPT-NeoX/megatron/model/transformer.py
+++ b/PyTorch/built-in/foundation/GPT-NeoX/megatron/model/transformer.py
@@ -40,6 +40,9 @@ from megatron.model.fused_bias_dropout import (
 )
 from megatron.model.utils import configure_sparse_attention
 
+import torch_npu
+from einops import rearrange
+
 # flags required to enable jit fusion kernels
 torch._C._jit_set_profiling_mode(False)
 torch._C._jit_set_profiling_executor(False)
@@ -175,6 +178,58 @@ class ParallelLinear(nn.Module):
     def forward(self, hidden_states):
         return self.final_linear(hidden_states)
 
+class FlashSelfAttention(torch.nn.Module):
+    """Scaled dot-product attention with softmax, backed by the NPU fused
+    attention kernel (torch_npu.npu_fusion_attention).
+    Arguments
+    ---------
+        softmax_scale: scaling applied to QK^T before the softmax
+                       (default: 1.0; callers typically pass 1/sqrt(d_k))
+        attention_dropout: dropout rate applied to the attention weights
+                           (default: 0.0)
+    """
+
+    def __init__(self, causal=True, seq_len=4096, softmax_scale=1.0, attention_dropout=0., device=None, dtype=None, layerout="SBH"):
+        super().__init__()
+        self.causal = causal
+        self.seq_len = seq_len
+        self.softmax_scale = softmax_scale
+        self.dropout_p = attention_dropout
+        self.layerout = layerout
+
+    def forward(self, q, k, v, n, attention_mask):
+        attention_mask_shape = attention_mask.shape
+        # if attention_mask_shape[0] == 1:
+        #     attention_mask = attention_mask.view((attention_mask_shape[-2], attention_mask_shape[-1]))
+        # print(f'q : {q.shape} k: {k.shape} v: {v.shape}, mask {attention_mask.shape}')
+
+        if self.causal:
+            output = torch_npu.npu_fusion_attention(
+                q, k, v, n, self.layerout,
+                pse=None,
+                padding_mask=None,
+                atten_mask=attention_mask,
+                scale=self.softmax_scale,
+                pre_tockens=self.seq_len,
+                next_tockens=0,
+                keep_prob=1 - self.dropout_p,
+            )[0]
+        else:
+            output = torch_npu.npu_fusion_attention(
+                q, k, v, n, self.layerout,
+                pse=None,
+                padding_mask=None,
+                atten_mask=attention_mask,
+                scale=self.softmax_scale,
+                pre_tockens=65536,
+                next_tockens=65536,
+                keep_prob=1 - self.dropout_p,
+            )[0]
+        return output
+
+
+
 class ParallelSelfAttention(nn.Module):
     """Parallel self-attention layer abstract class.
@@ -269,7 +324,9 @@ class ParallelSelfAttention(nn.Module):
             self.rotary_emb = None
 
         self.attention_type = neox_args.attention_config[layer_number]
+        self.use_npu_flash_attn = False
         self.use_flash_attention = self.attention_type == "flash"
+        self.layerout = "BSND"
         self.sparse = self.attention_type not in ("global", "flash")
         self.sparse = self.attention_type != "global" and not self.use_flash_attention
         if self.sparse:
@@ -288,7 +345,9 @@ class ParallelSelfAttention(nn.Module):
            from megatron.model.flash_attention import (
                flash_attn_unpadded_qkvpacked_func,
            )
-
+        elif self.use_npu_flash_attn:
+            self.npu_flash_attention = FlashSelfAttention(causal=True, seq_len=2048, softmax_scale=(1.0 / self.norm_factor),
+                                                          attention_dropout=0, layerout=self.layerout)
         else:
            self.scale_mask_softmax = FusedScaleMaskSoftmax(
                input_in_fp16=self.fp16,
@@ -318,10 +377,25 @@ class ParallelSelfAttention(nn.Module):
         self.checkpoint_activations = neox_args.checkpoint_activations
         self.checkpoint_selective = neox_args.checkpoint_selective
 
+        print(f'triangle: {self.use_triangle_attn} flash {self.use_npu_flash_attn}')
     def attention(
         self, query_layer, key_layer, value_layer, layer_past, attention_mask
     ):
-        if self.use_triangle_attn and layer_past is None and query_layer.size(
+        if self.use_npu_flash_attn:
+            # context layer shape: [b, np, sq, hn]
+
+            output_size = (value_layer.size(1),
+                           value_layer.size(2),
+                           query_layer.size(0),
+                           value_layer.size(3))
+
+            sequence_len, bsz, head_num, head_dim = query_layer.shape
+
+            query_layer, key_layer, value_layer = [rearrange(x, 's b n d -> b s n d').contiguous() for x in (query_layer, key_layer, value_layer)]
+            attn_output = self.npu_flash_attention(query_layer, key_layer, value_layer, head_num, attention_mask)
+            context_layer = attn_output.view(*output_size)
+            return context_layer
+        elif self.use_triangle_attn and layer_past is None and query_layer.size(
             0) >= self.block_size * 2 and query_layer.size(0) % self.block_size == 0:
             # context layer shape: [b, np, sq, hn]
             output_size = (value_layer.size(1),
diff --git a/PyTorch/built-in/foundation/GPT-NeoX/requirements/requirements.txt b/PyTorch/built-in/foundation/GPT-NeoX/requirements/requirements.txt
index b1930922cfdb83eafc69fec232535dfc0b76e13e..9058a148c3d6f5d8a89193100249925b622e8a51 100644
--- a/PyTorch/built-in/foundation/GPT-NeoX/requirements/requirements.txt
+++ b/PyTorch/built-in/foundation/GPT-NeoX/requirements/requirements.txt
@@ -5,7 +5,7 @@ protobuf==3.20.3
 pybind11>=2.6.2
 regex
 sentencepiece
-best_download
+best_download==0.0.9
 cloudpickle==2.2.1
 decorator==5.1.1
 psutil==5.9.5
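# Minimal usage sketch of the FlashSelfAttention wrapper added by this patch.
# Assumptions (not part of the diff): an Ascend NPU with torch_npu installed and the
# patched GPT-NeoX on the import path; the shapes, dtype, and the mask construction
# below are illustrative only. ParallelSelfAttention.attention() builds the mask
# elsewhere and rearranges q/k/v from [s, b, n, d] to the "BSND" layout before the
# call, which is mirrored here.
import torch
import torch_npu  # noqa: F401  (registers the "npu" device backend)
from einops import rearrange

from megatron.model.transformer import FlashSelfAttention

b, s, n, d = 2, 2048, 16, 64  # batch, sequence, heads, head dim (illustrative)
device = "npu"

# q/k/v as the core attention sees them before the rearrange: [s, b, n, d]
q = torch.randn(s, b, n, d, dtype=torch.float16, device=device)
k = torch.randn(s, b, n, d, dtype=torch.float16, device=device)
v = torch.randn(s, b, n, d, dtype=torch.float16, device=device)

# Causal mask: True above the diagonal marks positions to be masked out
# (the exact convention follows whatever mask the training code already passes in).
attention_mask = torch.triu(
    torch.ones(s, s, dtype=torch.bool, device=device), diagonal=1
)

attn = FlashSelfAttention(
    causal=True,
    seq_len=s,
    softmax_scale=1.0 / (d ** 0.5),  # the patch passes 1.0 / self.norm_factor here
    attention_dropout=0.0,
    layerout="BSND",
)

# Same call pattern as the new branch in ParallelSelfAttention.attention():
# rearrange to [b, s, n, d], then hand the head count and mask to the fused kernel.
q, k, v = [rearrange(x, "s b n d -> b s n d").contiguous() for x in (q, k, v)]
out = attn(q, k, v, n, attention_mask)  # output stays in the BSND layout: [b, s, n, d]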