diff --git a/mindspeed/core/transformer/flash_attention/__init__.py b/mindspeed/core/transformer/flash_attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mindspeed/core/transformer/flash_attention/alibi/__init__.py b/mindspeed/core/transformer/flash_attention/alibi/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mindspeed/core/transformer/flash_attention/alibi/adaptor.py b/mindspeed/core/transformer/flash_attention/alibi/adaptor.py new file mode 100644 index 0000000000000000000000000000000000000000..24be3b407089c1c425f694ab47f18179df1ba29f --- /dev/null +++ b/mindspeed/core/transformer/flash_attention/alibi/adaptor.py @@ -0,0 +1,40 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. All rights reserved. +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + +from typing import Optional + +from torch import Tensor +from megatron.core.transformer.dot_product_attention import DotProductAttention as MegatronDotProductAttention +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.transformer.enums import AttnMaskType +from megatron.core.packed_seq_params import PackedSeqParams + +from mindspeed.core.transformer.flash_attention.alibi.dot_product_attention import DotProductAttentionImpl + + +class MindSpeedDotProductAttention(DotProductAttentionImpl, MegatronDotProductAttention): + + def __init__( + self, + config: TransformerConfig, + layer_number: int, + attn_mask_type: AttnMaskType, + attention_type: str, + attention_dropout: float = None, + softmax_scale: float = None, + cp_comm_type: str = None, + ): + MegatronDotProductAttention.__init__( + self, + config, + layer_number, + attn_mask_type, + attention_type, + attention_dropout, + softmax_scale, + cp_comm_type + ) + + # add pse + DotProductAttentionImpl.__init__(self) + diff --git a/mindspeed/core/transformer/flash_attention/alibi/alibi.py b/mindspeed/core/transformer/flash_attention/alibi/alibi.py new file mode 100644 index 0000000000000000000000000000000000000000..df138999bc22f3c821d8892df9ffba1612805183 --- /dev/null +++ b/mindspeed/core/transformer/flash_attention/alibi/alibi.py @@ -0,0 +1,125 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. All rights reserved. +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ +import threading + +import torch + +from megatron.core import parallel_state +from mindspeed.core.transformer.flash_attention.alibi.alibi_utils import get_slopes + + +class Alibi: + _instance = None + alibi = None + matmul_result = None + output_size = None + _lock = threading.Lock() + + def __new__(cls, *args, **kwargs): + if cls._instance is None: + with cls._lock: + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + +class AlibiForFusionAttnSingleton: + _alibi_tensor_args = None + _alibi_tensor = None + + _alibi_slopes_headnum = None + _alibi_slopes = None + + @classmethod + def get_alibi_tensor_for_fusion_attn( + cls, + max_seq_len, + num_attention_heads, + dtype, + neg_diagonal_opposite=False, + last_k=1024 + ): + if ( + cls._alibi_tensor is None or + cls._alibi_tensor_args != ( + max_seq_len, num_attention_heads, + neg_diagonal_opposite, last_k + ) + ): + if last_k > max_seq_len: + last_k = max_seq_len + + tp_world_size = ( + parallel_state.get_tensor_model_parallel_world_size() + ) + current_head_num = num_attention_heads // tp_world_size + slopes = ( + AlibiForFusionAttnSingleton.get_alibi_slopes_for_fusion_attn( + num_attention_heads + ) + ) + + position_point = torch.arange(max_seq_len) - max_seq_len + 1 + diag = torch.diag( + torch.diag(position_point) + ).unsqueeze(0).unsqueeze(0) + + position_point = ( + position_point.unsqueeze(0).unsqueeze(0).expand( + current_head_num, last_k, -1 + ) + ) + position_point = ( + position_point - + diag.transpose(-1, -2)[:, -last_k:, :].expand( + current_head_num, last_k, max_seq_len + ) + ) + + alibi = slopes.unsqueeze(1).unsqueeze(1) * position_point.npu() + + if not neg_diagonal_opposite: + alibi = -torch.abs(alibi) + + alibi = alibi.unsqueeze(0) + alibi = torch.Tensor(alibi).npu() + if dtype == torch.float16: + alibi = alibi.to(torch.float16) + elif dtype == torch.bfloat16: + alibi = alibi.to(torch.bfloat16) + + cls._alibi_tensor = alibi + cls._alibi_tensor_args = ( + max_seq_len, num_attention_heads, + neg_diagonal_opposite, last_k + ) + + return cls._alibi_tensor + + @classmethod + def get_alibi_slopes_for_fusion_attn(cls, n): + if ( + cls._alibi_slopes is None or + cls._alibi_slopes_headnum != n + ): + slopes = get_slopes(n) + + tp_world_size = ( + parallel_state.get_tensor_model_parallel_world_size() + ) + tp_index = parallel_state.get_tensor_model_parallel_rank() + + current_head_num = n // tp_world_size + slopes = torch.Tensor( + slopes[ + tp_index * current_head_num: + tp_index * current_head_num + current_head_num + ] + ).npu() + + cls._alibi_slopes = slopes + cls._alibi_slopes_headnum = n + return cls._alibi_slopes + + return cls._alibi_slopes \ No newline at end of file diff --git a/mindspeed/core/transformer/flash_attention/alibi/alibi_utils.py b/mindspeed/core/transformer/flash_attention/alibi/alibi_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..31d879a2acdb5d8ce76b4773eb7e07468dfec083 --- /dev/null +++ b/mindspeed/core/transformer/flash_attention/alibi/alibi_utils.py @@ -0,0 +1,92 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. All rights reserved. +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + +import math + +import torch +from megatron.core import parallel_state + + +def get_slopes(n): + """ + Generate ALiBi slopes for n attention heads. + The slopes are computed based on the number of heads and follow a power-of-2 pattern. + + Args: + n (int): Number of attention heads. 
+ + Returns: + List[float]: A list of slopes for each attention head. + """ + + def get_slopes_power_of_2(n): + start = (2 ** (-2 ** -(math.log2(n) - 3))) + ratio = start + return [start * ratio ** i for i in range(n)] + + if math.log2(n).is_integer(): + return get_slopes_power_of_2(n) + else: + closest_power_of_2 = 2 ** math.floor(math.log2(n)) + return ( + get_slopes_power_of_2(closest_power_of_2) + + get_slopes(2 * closest_power_of_2)[0::2][:n - closest_power_of_2] + ) + + +def _get_inverted_mask(attention_mask, alibi): + inverted_mask = attention_mask.to(alibi.dtype) + inverted_mask = inverted_mask.masked_fill( + inverted_mask.to(torch.bool), float("-inf") + ) + return inverted_mask.to(alibi.device) + alibi.unsqueeze(0) + + +def _build_alibi_tensor( + max_seq_len, + num_attention_heads, + square_alibi_mask, + fill_neg_inf +): + def _fill_with_neg_inf(t): + """FP16-compatible function that fills a tensor with -inf.""" + return t.float().fill_(float("-inf")).type_as(t) + + def _buffered_future_mask(maxpos, alibi, attn_heads): + _future_mask = torch.triu( + _fill_with_neg_inf(torch.zeros([maxpos, maxpos])), + 1 + ) + _future_mask = _future_mask.unsqueeze(0) + alibi + return _future_mask[:attn_heads, :maxpos, :maxpos] + + slopes = torch.Tensor(get_slopes(num_attention_heads)) + if square_alibi_mask: + position_point = torch.arange(max_seq_len) - max_seq_len + 1 + position_point = ( + position_point.unsqueeze(0).unsqueeze(0).expand( + num_attention_heads, max_seq_len, -1 + ) + ) + diag = torch.diag(position_point[0]) + position_point = ( + position_point - diag.unsqueeze(0).unsqueeze(0).transpose(-1, -2) + ) + alibi = slopes.unsqueeze(1).unsqueeze(1) * position_point + else: + alibi = ( + slopes.unsqueeze(1).unsqueeze(1) * + torch.arange(max_seq_len).unsqueeze(0).unsqueeze(0).expand( + num_attention_heads, -1, -1 + ) + ) + + # Select the part of the tensor that corresponds to our tensor parallel index. + tp_world_size = parallel_state.get_tensor_model_parallel_world_size() + tp_index = parallel_state.get_tensor_model_parallel_rank() + alibi = alibi.reshape((tp_world_size, -1, *alibi.shape[1:]))[tp_index] + + if fill_neg_inf: + return _buffered_future_mask(max_seq_len, alibi, num_attention_heads) + + return alibi diff --git a/mindspeed/core/transformer/flash_attention/alibi/dot_product_attention.py b/mindspeed/core/transformer/flash_attention/alibi/dot_product_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..4d930c27d31c31f6ae50678ac6d5379e98711841 --- /dev/null +++ b/mindspeed/core/transformer/flash_attention/alibi/dot_product_attention.py @@ -0,0 +1,119 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. All rights reserved. +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + +import math +from typing import Optional + +from torch import Tensor + +from megatron.core.transformer.enums import AttnMaskType +from megatron.core.packed_seq_params import PackedSeqParams + +from mindspeed.ops.fusion_attention_v2 import npu_fusion_attention +from mindspeed.core.transformer.flash_attention.alibi.alibi import AlibiForFusionAttnSingleton + +try: + from einops import rearrange +except ImportError: + rearrange = None + + +class DotProductAttentionImpl(): + """ + Implementation of dot product attention with ALiBi support. 
+ """ + + def __init__(self): + # add pse + self.pse = None + self.pse_type = self.config.alibi_fusion_attn_type + + if self.pse_type is None: + self.pse_type = 1 # not use pse + elif self.pse_type == 0: + alibi = ( + AlibiForFusionAttnSingleton.get_alibi_tensor_for_fusion_attn( + self.config.seq_length, + self.config.num_attention_heads, + self.config.params_dtype, + self.config.alibi_diagonal_opposite, + 1024 + ) + ) + self.pse = alibi + elif self.pse_type == 2 or self.pse_type == 3: + self.pse = ( + AlibiForFusionAttnSingleton.get_alibi_slopes_for_fusion_attn( + self.config.num_attention_heads + ) + ) + + def forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + attention_mask: Tensor, + attn_mask_type: AttnMaskType = None, + attention_bias: Tensor = None, + packed_seq_params: Optional[PackedSeqParams] = None, + ): + assert attention_bias is None, \ + "Attention bias is not supported for DotProductAttention." + + if packed_seq_params is None: + seq_length, bsz, n_head, head_dim = ( + query.shape[0], query.shape[1], query.shape[2], query.shape[3] + ) + else: + seq_length, n_head, head_dim = ( + query.shape[0], query.shape[1], query.shape[2] + ) + + sparse_mode = self.config.sparse_mode + if attn_mask_type == AttnMaskType.no_mask: + sparse_mode = 0 # default mask + + scale = ( + 1.0 / math.sqrt(self.hidden_size_per_attention_head) + if self.scale_mask_softmax.scale is None + else self.softmax_scale + ) + + if packed_seq_params is not None: # TND + actual_seq_qlen = packed_seq_params.cu_seqlens_q.tolist() + actual_seq_kvlen = packed_seq_params.cu_seqlens_kv.tolist() + query, key, value = ( + [rearrange(x, 's b h d -> (b s) h d') for x in [query, key, value]] + ) + shape_order = 'TND' + else: # SBH + actual_seq_qlen = None + actual_seq_kvlen = None + query, key, value = ( + [rearrange(x, 's b h d -> s b (h d)') for x in [query, key, value]] + ) + shape_order = 'SBH' + + output = npu_fusion_attention( + query, key, value, n_head, shape_order, + pse=self.pse, + padding_mask=None, + atten_mask=attention_mask, + scale=scale, + pse_type=self.pse_type, + pre_tokens=self.config.pre_tockens, + next_tokens=self.config.next_tockens, + keep_prob=1 - self.attention_dropout.p, + inner_precise=0, + sparse_mode=sparse_mode, + actual_seq_qlen=actual_seq_qlen, + actual_seq_kvlen=actual_seq_kvlen + )[0] + + if packed_seq_params is not None: + output = ( + rearrange(output, '(b s) h d -> s b (h d)', s=seq_length, b=bsz) + ) + + return output \ No newline at end of file diff --git a/mindspeed/core/transformer/flash_attention/generate_mask/__init__.py b/mindspeed/core/transformer/flash_attention/generate_mask/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mindspeed/core/transformer/flash_attention/generate_mask/adaptor.py b/mindspeed/core/transformer/flash_attention/generate_mask/adaptor.py new file mode 100644 index 0000000000000000000000000000000000000000..f27dae85ebdafd28076036a016652063e9d84f1e --- /dev/null +++ b/mindspeed/core/transformer/flash_attention/generate_mask/adaptor.py @@ -0,0 +1,29 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. All rights reserved. +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ +from functools import wraps + +from megatron.core.transformer.enums import AttnMaskType +from mindspeed.core.transformer.flash_attention.generate_mask.generate_mask import get_attention_mask + + +def dot_product_attention_forward_wrapper(fn): + @wraps(fn) + def wrapper( + self, query, key, value, + attention_mask, + attn_mask_type, + attention_bias, + packed_seq_params + ): + if ( + attention_mask is None and + self.attn_mask_type == AttnMaskType.causal + ) and not getattr(self.config, 'is_llava', False): + self.config.sparse_mode = 2 + attention_mask = get_attention_mask(self.config) + return fn( + self, query, key, value, + attention_mask, attn_mask_type, attention_bias, packed_seq_params + ) + return wrapper \ No newline at end of file diff --git a/mindspeed/core/transformer/flash_attention/generate_mask/generate_mask.py b/mindspeed/core/transformer/flash_attention/generate_mask/generate_mask.py new file mode 100644 index 0000000000000000000000000000000000000000..b4bad6d0f6d31d8b87e9f4332c4219d7a9701cce --- /dev/null +++ b/mindspeed/core/transformer/flash_attention/generate_mask/generate_mask.py @@ -0,0 +1,65 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. All rights reserved. +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + +import warnings + +import torch + +_GLOBAL_ATTN_MASK = None + + +def set_attention_mask(attn_mask): + global _GLOBAL_ATTN_MASK + _GLOBAL_ATTN_MASK = attn_mask + + +def generate_attention_mask(args, compress, device): + global _GLOBAL_ATTN_MASK + if not args.use_flash_attn: + warnings.warn("Flash Attention is highly recommended") + _GLOBAL_ATTN_MASK = ( + torch.tril( + torch.ones( + [args.micro_batch_size, 1, args.seq_length, args.seq_length], + dtype=bool, + device=device + ), + diagonal=-(args.pre_tockens + 1) + ) + + torch.triu( + torch.ones( + [args.micro_batch_size, 1, args.seq_length, args.seq_length], + dtype=bool, + device=device + ), + diagonal=args.next_tockens + 1 + ) + ) + return + + if compress: + seq_len = 2048 + else: + seq_len = args.seq_length + + _GLOBAL_ATTN_MASK = torch.triu( + torch.ones( + (seq_len, seq_len), + device=device, + dtype=torch.bool + ), + diagonal=1 + ) + + +def get_attention_mask(args): + global _GLOBAL_ATTN_MASK + if _GLOBAL_ATTN_MASK is not None: + return _GLOBAL_ATTN_MASK + + device = 'npu' + compress = True + + generate_attention_mask(args, compress, device) + + return _GLOBAL_ATTN_MASK \ No newline at end of file diff --git a/mindspeed/features_manager/__init__.py b/mindspeed/features_manager/__init__.py index f31b29cb08ed31a4d97eff513360cbf523c63c3f..4fffefd3a7eb8ac701ff732397ac969ddbb57682 100644 --- a/mindspeed/features_manager/__init__.py +++ b/mindspeed/features_manager/__init__.py @@ -52,6 +52,10 @@ from mindspeed.features_manager.optimizer.virtual_optimizer import ( ) from mindspeed.features_manager.pipeline_parallel.noop_layers import NoopLayersFeature +from mindspeed.features_manager.transformer.flash_attention.fusion_attention_v2_feature import FusionAttentionV2Feature +from mindspeed.features_manager.transformer.flash_attention.alibi_feature import AlibiFeature +from mindspeed.features_manager.transformer.flash_attention.generate_mask_feature import GenerateMaskFeature + FEATURES_LIST = [ # Functional features ProfilerDefaultFeature(), @@ -85,6 +89,12 @@ FEATURES_LIST_V2 = ( UnalignedLinearFeature(), # llava-multimodal LlavaModel(), + + # Transformer flash attention features + FusionAttentionV2Feature(), + AlibiFeature(), + GenerateMaskFeature(), + # MoeExperts use gemm 
     MoEGmmFeature(),
     # MoeTp2EpFeature
diff --git a/mindspeed/features_manager/transformer/__init__.py b/mindspeed/features_manager/transformer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/mindspeed/features_manager/transformer/flash_attention/__init__.py b/mindspeed/features_manager/transformer/flash_attention/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/mindspeed/features_manager/transformer/flash_attention/alibi_feature.py b/mindspeed/features_manager/transformer/flash_attention/alibi_feature.py
new file mode 100644
index 0000000000000000000000000000000000000000..a190780fcf4edc38ff1d5ee30bbcee35e86536c4
--- /dev/null
+++ b/mindspeed/features_manager/transformer/flash_attention/alibi_feature.py
@@ -0,0 +1,97 @@
+# Copyright (c) 2025, Huawei Technologies Co., Ltd. All rights reserved.
+# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+
+from logging import getLogger
+from argparse import ArgumentParser
+
+from mindspeed.features_manager.feature import MindSpeedFeature
+
+
+class AlibiFeature(MindSpeedFeature):
+    """
+    ALiBi attention positional embedding.
+    To enable this feature, pass the following arguments:
+
+    Usage:
+    "--position-embedding-type alibi"
+    "--alibi-fusion-attn-type [2, 3]"
+    "[--alibi-diagonal-opposite]"
+    """
+
+    def __init__(self):
+        super().__init__(
+            'position-embedding-type',
+            optimization_level=2
+        )
+
+    def is_need_apply(self, args):
+        pse = getattr(args, self.feature_name, None)
+        need_apply = False
+        if pse == 'alibi':
+            need_apply = True
+        return (
+            self.optimization_level <= args.optimization_level and
+            need_apply
+        ) or self.default_patches
+
+    def register_args(self, parser: ArgumentParser):
+        self.add_parser_argument_choices_value(
+            parser,
+            "--position-embedding-type",
+            'alibi'
+        )
+
+        group = parser.add_argument_group(title='alibi')
+        group.add_argument(
+            '--square-alibi-mask',
+            action='store_true',
+            default=False,
+            help='use a square alibi attention mask'
+        )
+        group.add_argument(
+            '--fill-neg-inf',
+            action='store_true',
+            default=False,
+            help='fill the alibi mask with negative infinity'
+        )
+
+        group.add_argument(
+            '--alibi-fusion-attn-type',
+            type=int,
+            help='alibi pse type; supported values are 0, 2 and 3'
+        )
+        group.add_argument(
+            '--alibi-diagonal-opposite',
+            action='store_true',
+            default=False,
+            help='make the alibi diagonal opposite'
+        )
+
+    def validate_args(self, args):
+        if args.alibi_fusion_attn_type is not None:
+            if args.alibi_fusion_attn_type not in [0, 2, 3]:
+                raise AssertionError(
+                    '--alibi-fusion-attn-type only supports `0, 2, 3`'
+                )
+            if args.alibi_fusion_attn_type == 0:
+                raise AssertionError(
+                    'fusion attention v2 only supports the compress mode currently. '
+                    'Please use 2 or 3.'
+                )
+            # alibi is only supported with fusion attention v2
+            if args.alibi_fusion_attn_type in [2, 3]:
+                args.use_fusion_attn_v2 = True
+
+        if args.use_fusion_attn_v2:
+            args.use_flash_attn = True
+            print(
+                '[WARNING] "use_fusion_attn_v2" is not recommended. '
+                'This feature is not officially released.'
+ ) + + def register_patches(self, patch_manager, args): + from mindspeed.core.transformer.flash_attention.alibi.adaptor import MindSpeedDotProductAttention + patch_manager.register_patch( + 'megatron.core.transformer.dot_product_attention.DotProductAttention', + MindSpeedDotProductAttention + ) \ No newline at end of file diff --git a/mindspeed/features_manager/transformer/flash_attention/fusion_attention_v2_feature.py b/mindspeed/features_manager/transformer/flash_attention/fusion_attention_v2_feature.py new file mode 100644 index 0000000000000000000000000000000000000000..c274b7b35d1c9dedcca8a3bb78e75b0a88c56549 --- /dev/null +++ b/mindspeed/features_manager/transformer/flash_attention/fusion_attention_v2_feature.py @@ -0,0 +1,52 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved. +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + +from argparse import ArgumentParser + +from mindspeed.features_manager.feature import MindSpeedFeature + + +class FusionAttentionV2Feature(MindSpeedFeature): + ''' + fusion attention v2 is a expand to fusion attention v1 + and only support for alibi positional embeding currently. + Close by default. + ''' + + def __init__(self): + super().__init__( + 'use-fusion-attn-v2', + optimization_level=2 + ) + + def register_args(self, parser: ArgumentParser): + group = parser.add_argument_group(title='fusion attention v2') + group.add_argument( + '--use-fusion-attn-v2', + action='store_true', + default=False, + help='enalbe fusion attention v2' + ) + group.add_argument( + '--pre-tockens', + type=int, + default=65536, + help='pre-tockens is used by Flash attention' + ) + group.add_argument( + '--next-tockens', + type=int, + default=0, + help='next-tockens is used by Flash attention' + ) + group.add_argument( + '--sparse-mode', + type=int, + default=0, + choices=[0, 1, 2, 3, 4, 5, 6, 7, 8], + help='mask type for fusion attention' + ) + + def validate_args(self, args): + if args.use_fusion_attn_v2: + args.use_flash_attn = True diff --git a/mindspeed/features_manager/transformer/flash_attention/generate_mask_feature.py b/mindspeed/features_manager/transformer/flash_attention/generate_mask_feature.py new file mode 100644 index 0000000000000000000000000000000000000000..16705975899d797ee6ebf952fcbc226f49cc61b8 --- /dev/null +++ b/mindspeed/features_manager/transformer/flash_attention/generate_mask_feature.py @@ -0,0 +1,35 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. All rights reserved. +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+
+from typing import Any
+
+from mindspeed.features_manager.feature import MindSpeedFeature
+
+
+class GenerateMaskFeature(MindSpeedFeature):
+
+    def __init__(self):
+        super().__init__(
+            'no-create-attention-mask-in-dataloader',
+            optimization_level=2
+        )
+
+    def is_need_apply(self, args: Any) -> bool:
+        """Check whether the feature needs to be applied."""
+        need_apply = True
+
+        # if the argument is not set, the feature needs to be applied
+        if getattr(args, self.feature_name, None):
+            need_apply = False
+
+        return (
+            self.optimization_level <= args.optimization_level and
+            need_apply
+        ) or self.default_patches
+
+    def register_patches(self, patch_manager, args):
+        from mindspeed.core.transformer.flash_attention.generate_mask.adaptor import dot_product_attention_forward_wrapper
+        patch_manager.register_patch(
+            'megatron.core.transformer.dot_product_attention.DotProductAttention.forward',
+            dot_product_attention_forward_wrapper
+        )
diff --git a/requirements.txt b/requirements.txt
index 1044aea477e7d37a270eeb6617fee247eeb0ee73..b7e8f1c8bb7b8134a3f536f68a09a23d9362649d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -13,6 +13,7 @@ einops
 scipy
 sentencepiece
 pytest
+pytest-mock
 tokenizers<=0.20.3
 transformers>=4.43.2
 gpytorch
diff --git a/tests_extend_v2/unit_tests/features/flash_attention/test_alibi.py b/tests_extend_v2/unit_tests/features/flash_attention/test_alibi.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd002ad8fdffd19c7f05ff33fa8d41807acfe56c
--- /dev/null
+++ b/tests_extend_v2/unit_tests/features/flash_attention/test_alibi.py
@@ -0,0 +1,72 @@
+import pytest
+import torch
+import torch_npu
+
+from megatron.training.global_vars import set_args
+from megatron.training.arguments import parse_args
+from megatron.core.transformer.transformer_config import TransformerConfig
+
+from mindspeed import megatron_adaptor_v2
+from mindspeed.core.transformer.flash_attention.alibi.adaptor import MindSpeedDotProductAttention
+
+DEVICE_NAME = torch_npu.npu.get_device_name(0)[:10]
+
+
+def run_fusion_attn_with_pse_alibi(bs, seq_len, dtype):
+    from megatron.core.transformer.enums import AttnMaskType
+
+    args = parse_args(None, True)
+    set_args(args)
+
+    config = TransformerConfig(
+        num_layers=2,
+        hidden_size=32,
+        num_attention_heads=4,
+        attention_dropout=0.0,
+        params_dtype=dtype
+    )
+
+    # extra arguments required by mindspeed
+    config.use_flash_attn = True
+    config.use_fusion_attn_v2 = True
+    config.alibi_fusion_attn_type = 2
+    config.sparse_mode = 2
+    config.seq_length = seq_len
+    config.alibi_diagonal_opposite = False
+
+    attn = MindSpeedDotProductAttention(
+        config=config,
+        layer_number=1,
+        attn_mask_type=AttnMaskType.causal,
+        attention_type='self'
+    )
+
+    # attn.pse should exist and not be None
+    assert attn.pse is not None
+
+    b, n, s, d = bs, 4, seq_len, 8
+
+    q = torch.randn(s, b, n, d, dtype=dtype, device='npu', requires_grad=True)
+    k = torch.randn(s, b, n, d, dtype=dtype, device='npu', requires_grad=True)
+    v = torch.randn(s, b, n, d, dtype=dtype, device='npu', requires_grad=True)
+
+    # the global attention mask is generated by the DotProductAttention forward wrapper
+    out = attn(q, k, v, None, None, None, None)
+    assert isinstance(out, torch.Tensor)
+
+
+class TestAlibi():
+
+    @pytest.mark.skipif(DEVICE_NAME != 'Ascend910B', reason='device type is not supported, skip this UT!')
+    def test_alibi(self, mocker):
+        mock_world_size = mocker.patch(
+            "megatron.core.parallel_state.get_tensor_model_parallel_world_size",
+            return_value=1
+        )
+        mock_rank = mocker.patch(
"megatron.core.parallel_state.get_tensor_model_parallel_rank", + return_value=0 + ) + run_fusion_attn_with_pse_alibi(2, 256, torch.bfloat16) + mock_world_size.assert_called() + mock_rank.assert_called_once()