From b44dea8243b274fc7781476be0059d0f5d6a3ae9 Mon Sep 17 00:00:00 2001 From: HanhuiChen Date: Tue, 5 Aug 2025 20:14:31 +0800 Subject: [PATCH 1/9] refactor dot product attention --- .../custom_dot_product_attention.py | 439 ++++++++++++++++++ .../infer_dot_product_attention.py | 439 ++++++++++++++++++ mindspeed_llm/features_manager/__init__.py | 2 + 3 files changed, 880 insertions(+) create mode 100644 mindspeed_llm/core/transformer/custom_dot_product_attention.py create mode 100644 mindspeed_llm/core/transformer/infer_dot_product_attention.py diff --git a/mindspeed_llm/core/transformer/custom_dot_product_attention.py b/mindspeed_llm/core/transformer/custom_dot_product_attention.py new file mode 100644 index 000000000..653e8fcf1 --- /dev/null +++ b/mindspeed_llm/core/transformer/custom_dot_product_attention.py @@ -0,0 +1,439 @@ +# coding=utf-8 +# Copyright (c) 2025, Huawei Technologies Co., Ltd. All rights reserved. + +import logging +import math +from functools import wraps +from typing import Union, List + +import torch +import torch_npu +from torch import Tensor +from megatron.core.fusions.fused_softmax import FusedScaleMaskSoftmax +from megatron.training import get_args +from megatron.core import mpu, parallel_state, tensor_parallel +from megatron.core.transformer import TransformerConfig +from megatron.core.transformer.dot_product_attention import DotProductAttention +from megatron.core.transformer.enums import AttnMaskType +from megatron.core.transformer.utils import attention_mask_func +from megatron.core.utils import divide +from megatron.core.packed_seq_params import PackedSeqParams +from mindspeed.core.context_parallel.ulysses_context_parallel.ulysses_context_parallel import ulyssesattn_context_parallel +from mindspeed.core.parallel_state import (get_context_parallel_group_for_hybrid_ring, + get_context_parallel_for_hybrid_ring_world_size, + get_context_parallel_for_hybrid_ring_rank, + get_context_parallel_for_hybrid_ring_global_ranks, + get_ring_ranks_for_intra_window, + get_ring_ranks_for_inter_window_kv, + get_ring_ranks_for_inter_window_dkv, + get_ring_group_for_intra_window, + get_ring_group_for_intra_window_send_recv_overlap) +from mindspeed.core.context_parallel.adaptive_context_parallel.adaptive_context_parallel import adaptive_attn_context_parallel +from mindspeed.core.context_parallel.utils import get_scheduling_info +from mindspeed.model.transformer import get_attention_mask +from mindspeed.core.context_parallel.ring_context_parallel.context_parallel_kv_cache import get_cache_policy +from mindspeed.utils import get_actual_seq_len, compute_qkv_index, get_position_ids + +from mindspeed_llm.tasks.models.common.alibi import Alibi +from mindspeed_llm.core.context_parallel.ring_context_parallel import ringattn_context_parallel +from mindspeed_llm.training.utils import recompute_valid_actual_seq_len + +logger = logging.getLogger(__name__) + +try: + from einops import rearrange +except ImportError: + rearrange = None + +ACTUAL_SEQ_LEN_THRESHOLD = 2048 + +def dot_product_attention_init( + self, + config: TransformerConfig, + layer_number: int, + attn_mask_type: AttnMaskType, + attention_type: str, + attention_dropout: float = None, + softmax_scale: float = None, + cp_comm_type: str = None, +): + cp_size = config.context_parallel_size + config.context_parallel_size = 1 + + super(DotProductAttention, self).__init__(config=config) + assert ( + self.config.context_parallel_size == 1 + ), "Context parallelism is only supported by TEDotProductAttention!" 
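+    # NOTE: the real context-parallel size is cached and forced to 1 before the base-class
+    # constructor runs, purely so that the check above passes; the cached value is written
+    # back to config.context_parallel_size later in this initializer.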
+ + assert ( + self.config.window_size is None + ), "Sliding Window Attention is only supported by TEDotProductAttention!" + + self.layer_number = max(1, layer_number) + self.attn_mask_type = attn_mask_type + self.attention_type = attention_type # unused for now + + projection_size = self.config.kv_channels * self.config.num_attention_heads + args = get_args() + # Per attention head and per partition values. + world_size = args.tp_x if args.tp_2d else parallel_state.get_tensor_model_parallel_world_size() + self.hidden_size_per_partition = divide(projection_size, world_size) + self.hidden_size_per_attention_head = divide(projection_size, config.num_attention_heads) + self.num_attention_heads_per_partition = divide(self.config.num_attention_heads, world_size) + self.num_query_groups_per_partition = divide(self.config.num_query_groups, world_size) + + coeff = None + self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) + if self.config.apply_query_key_layer_scaling: + coeff = self.layer_number + self.norm_factor *= coeff + + self.scale_mask_softmax = FusedScaleMaskSoftmax( + input_in_fp16=self.config.fp16, + input_in_bf16=self.config.bf16, + attn_mask_type=self.attn_mask_type, + scaled_masked_softmax_fusion=self.config.masked_softmax_fusion, + mask_func=attention_mask_func, + softmax_in_fp32=self.config.attention_softmax_in_fp32, + scale=coeff, + ) + + # Dropout. Note that for a single iteration, this layer will generate + # different outputs on different number of parallel partitions but + # on average it should not be partition dependent. + self.attention_dropout = torch.nn.Dropout( + self.config.attention_dropout if attention_dropout is None else attention_dropout + ) + + config.context_parallel_size = cp_size + self.pse = None + self.pse_type = None + self.attn_logit_softcapping = args.attn_logit_softcapping + self.square_alibi_mask = args.square_alibi_mask + self.fill_neg_inf = args.fill_neg_inf + self.beta = 1.0 + self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling + if self.apply_query_key_layer_scaling: + self.beta = 1.0 / self.layer_number + + if args.position_embedding_type == 'alibi': + get_alibi(self, args.seq_length) + self.alibi_output_size = None + else: + self.alibi = None + + if args.query_pre_attn_scalar: + self.norm_factor = args.query_pre_attn_scalar ** 0.5 + self.scale_mask_softmax.scale = 1.0 + self.softmax_scale = 1.0 / self.norm_factor + + +def get_alibi(self, seq_length): + args = get_args() + self.alibi = Alibi() + alibi = self.alibi._build_alibi_tensor(seq_length, + args.num_attention_heads, + args.square_alibi_mask, + args.fill_neg_inf, + ).to(torch.cuda.current_device()) + if args.params_dtype == torch.float16: + alibi = alibi.to(torch.float16) + elif args.params_dtype == torch.bfloat16: + alibi = alibi.to(torch.bfloat16) + self.alibi.alibi = alibi + +def dot_product_attention_forward_wrapper(fn): + @wraps(fn) + def wrapper(self, query, key, value, attention_mask, attn_mask_type, attention_bias=None, packed_seq_params=None): + if attention_mask is None: + attention_mask = get_attention_mask() + + args = get_args() + if args.use_flash_attn and args.tp_2d: + from mindspeed.core.transformer.dot_product_attention import dot_product_attention_forward + return dot_product_attention_forward(self, query, key, value, attention_mask, attn_mask_type, attention_bias, packed_seq_params) + if self.config.context_parallel_size > 1 and args.context_parallel_algo == "ulysses_cp_algo" and args.context_parallel_kv_cache_policy: + return 
do_ulyssesattn_context_parallel(self, query, key, value, attention_mask, attn_mask_type, packed_seq_params) + # =================================== + # Raw attention scores. [b, n/p, s, s] + # =================================== + + # expand the key and value [sk, b, ng, hn] -> [sk, b, np, hn] + # This is a noop for normal attention where ng == np. When using group query attention this + # creates a view that has the keys and values virtually repeated along their dimension to + # match the number of queries. + + heads_per_gqa_group = self.num_attention_heads_per_partition // self.num_query_groups_per_partition + if not args.use_flash_attn: + if heads_per_gqa_group > 1: + key = key.repeat_interleave(heads_per_gqa_group, dim=2) + value = value.repeat_interleave(heads_per_gqa_group, dim=2) + else: + # Do repeat KV to support PFA + should_kv_repeat_before_pfa = hasattr(args, 'use_kv_cache') and args.use_kv_cache + if heads_per_gqa_group > 1 and should_kv_repeat_before_pfa: + key = key.repeat_interleave(heads_per_gqa_group, dim=2) + value = value.repeat_interleave(heads_per_gqa_group, dim=2) + + return flash_attention_forward(self, query, key, value, attention_mask, attn_mask_type, + packed_seq_params) + + output_size = ( + query.size(1), + query.size(2), + query.size(0), + key.size(0), + ) + + # [sq, b, np, hn] -> [sq, b * np, hn] + # This will be a simple view when doing normal attention, but in group query attention + # the key and value tensors are repeated to match the queries so you can't use simple strides + # to extract the queries. + query = query.reshape(output_size[2], output_size[0] * output_size[1], -1) + # [sk, b, np, hn] -> [sk, b * np, hn] + key = key.view(output_size[3], output_size[0] * output_size[1], -1) + + # preallocting input tensor: [b * np, sq, sk] + if self.alibi is None: + matmul_input_buffer = parallel_state.get_global_memory_buffer().get_tensor( + (output_size[0] * output_size[1], output_size[2], output_size[3]), query.dtype, "mpu", + ) + + # Raw attention scores. 
[b * np, sq, sk] + matmul_result = torch.baddbmm( + matmul_input_buffer, + query.transpose(0, 1), # [b * np, sq, hn] + key.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] + beta=0.0, + alpha=(1.0 / self.norm_factor), + ) + else: + if self.alibi.alibi_pse is None or self.alibi.output_size != output_size: + self.alibi.output_size = output_size + self.alibi.get_alibi_pse(attention_mask, output_size[0], output_size[2], output_size[3]) + + q_trans = query.transpose(0, 1).contiguous() + k_trans = key.transpose(0, 1).transpose(1, 2).contiguous() + matmul_result = self.beta * self.alibi.alibi_pse + torch.bmm(q_trans, k_trans) * (1.0 / self.norm_factor) + + if self.attn_logit_softcapping is not None: + matmul_result = matmul_result / self.attn_logit_softcapping + matmul_result = torch.tanh(matmul_result) + matmul_result = matmul_result * self.attn_logit_softcapping + + # change view to [b, np, sq, sk] + attention_scores = matmul_result.view(*output_size) + + # =========================== + # Attention probs and dropout + # =========================== + + # attention scores and attention mask [b, np, sq, sk] + if self.square_alibi_mask: + attention_scores = torch.max( + attention_scores, torch.tensor(torch.finfo(attention_scores.dtype).min) + ) + attention_probs = torch.nn.functional.softmax(attention_scores, -1) + else: + attention_probs: Tensor = self.scale_mask_softmax(attention_scores, attention_mask) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + + if not self.config.sequence_parallel: + with tensor_parallel.get_cuda_rng_tracker().fork(): + attention_probs = self.attention_dropout(attention_probs) + else: + attention_probs = self.attention_dropout(attention_probs) + + # ========================= + # Context layer. [sq, b, hp] + # ========================= + + # value -> context layer. 
+ # [sk, b, np, hn] --> [b, np, sq, hn] + + # context layer shape: [b, np, sq, hn] + output_size = ( + value.size(1), + value.size(2), + query.size(0), + value.size(3), + ) + + # change view [sk, b * np, hn] + value = value.view(value.size(0), output_size[0] * output_size[1], -1) + + # change view [b * np, sq, sk] + attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1) + + # matmul: [b * np, sq, hn] + context = torch.bmm(attention_probs, value.transpose(0, 1)) + + # change view [b, np, sq, hn] + context = context.view(*output_size) + + # [b, np, sq, hn] --> [sq, b, np, hn] + context = context.permute(2, 0, 1, 3).contiguous() + + # [sq, b, np, hn] --> [sq, b, hp] + new_context_shape = context.size()[:-2] + (self.hidden_size_per_partition,) + context = context.view(*new_context_shape) + + return context + + return wrapper + + +def flash_attention_forward( + self, + query: Union[Tensor, List[Tensor]], + key: Union[Tensor, List[Tensor]], + value: Tensor, + attention_mask, + attn_mask_type, + packed_seq_params, +): + query_rope, key_rope = None, None + if isinstance(query, List): + query, query_rope = query[0], query[1] + if isinstance(key, List): + key, key_rope = key[0], key[1] + + args = get_args() + + seq_length, batch_size, n_head, head_dim = query.shape[0], query.shape[1], query.shape[2], query.shape[3] + scale = 1.0 / math.sqrt(self.hidden_size_per_attention_head) \ + if self.scale_mask_softmax.scale is None else self.softmax_scale + actual_seq_len = get_actual_seq_len() + if actual_seq_len is not None and args.mtp_num_layers: + actual_seq_len = actual_seq_len[self.mtp_idx] + + if args.context_parallel_size > 1 and args.context_parallel_algo in ['megatron_cp_algo', 'hybrid_cp_algo', + 'adaptive_cp_algo', 'hybrid_adaptive_cp_algo']: + query, key, value = [rearrange(x, 's b h d -> s b (h d)') for x in [query, key, value]] + return do_ring_context_parallel( + query, key, value, head_num=n_head, softmax_scale=scale, attn_mask=attention_mask, pse=self.pse, + pse_type=self.pse_type, packed_seq_params=packed_seq_params) + + if args.shape_order == "TND": # varlen FA + if args.mla_fa_divide_qk: + query, key, value = [rearrange(x, 's b h d -> (b s) h d') for x in [query, key, value]] + if query_rope is not None and key_rope is not None: + query_rope, key_rope = [rearrange(x, 's b h d -> (b s) h d') for x in [query_rope, key_rope]] + else: + query, key, value = [rearrange(x, 's b h d -> (s b) h d') for x in [query, key, value]] + args.sparse_mode = 4 + elif args.shape_order == "BNSD": + query, key, value = [rearrange(x, 's b h d -> b h s d') for x in [query, key, value]] + else: + query, key, value = [rearrange(x, 's b h d -> s b (h d)') for x in [query, key, value]] + args.shape_order = "SBH" + + if self.hidden_size_per_attention_head == 0: + raise AssertionError("self.hidden_size_per_attention_head should not be ZERO.") + if not hasattr(self, 'attention_mask') or \ + self.attention_mask is None or \ + self.attention_mask.shape[0] != seq_length: + if self.alibi is not None: + self.attention_mask = torch.triu( + torch.ones(seq_length, seq_length), + 1).bool().npu() + else: + self.attention_mask = attention_mask + + use_sliding_windows = args.sliding_window is not None and seq_length > args.sliding_window + + if use_sliding_windows: + args.pre_tockens = args.sliding_window + args.sparse_mode = 4 + + pse = None + size_record = key.shape + if self.alibi is not None and (self.alibi.output_size != size_record) and pse is None: + if args.shape_order != 'SBH': + raise 
ValueError( + 'FlashAttention with Alibi requires for SBH shape_order, but is {}.'.format(args.shape_order)) + + self.alibi.output_size = size_record + self.alibi.get_alibi_pse(self.attention_mask, batch_size, query.shape[0], key.shape[0]) + + if self.alibi and pse is None: + pse = self.alibi.alibi_pse.reshape( + batch_size, n_head, self.alibi.alibi_pse.size(1), -1) + if hasattr(args, 'use_kv_cache') and args.use_kv_cache: + pse = pse * self.beta + else: + pse = pse * self.beta * self.norm_factor + args.pre_tockens = seq_length + args.sparse_mode = 0 + + if hasattr(args, 'use_kv_cache') and args.use_kv_cache: + query, key, value = [rearrange(x, 's b h -> b s h') for x in [query, key, value]] + if query.shape[1] == 1 and query.shape[1] != key.shape[1]: + output = torch_npu.npu_incre_flash_attention( + query, key, value, + num_heads=n_head, + input_layout="BSH", + pse_shift=pse, + padding_mask=None, + scale_value=scale + ) + else: + output = torch_npu.npu_prompt_flash_attention( + query, key, value, + num_heads=n_head, + input_layout="BSH", + pse_shift=pse, + sparse_mode=args.sparse_mode, + padding_mask=None, + atten_mask=self.attention_mask, + scale_value=scale, + pre_tokens=args.pre_tockens, + next_tokens=args.next_tockens + ) + output = output.transpose(0, 1) + else: + if not args.mla_fa_divide_qk: + if actual_seq_len is not None and len(actual_seq_len) > ACTUAL_SEQ_LEN_THRESHOLD: + logger.warning("flash-attention get a long actual_seq_len, maybe create a coredump!") + actual_seq_len = recompute_valid_actual_seq_len(get_position_ids(), actual_seq_len) + + output = torch_npu.npu_fusion_attention( + query, key, value, n_head, args.shape_order, + pse=pse, + padding_mask=None, + atten_mask=self.attention_mask, + actual_seq_qlen=actual_seq_len, + actual_seq_kvlen=actual_seq_len, + scale=scale, + pre_tockens=args.pre_tockens, + next_tockens=args.next_tockens, + keep_prob=1 - self.attention_dropout.p, + inner_precise=0, + sparse_mode=args.sparse_mode + )[0] + else: + output = torch_npu.npu_fusion_attention_v2( + query, key, value, n_head, args.shape_order, + pse=pse, + padding_mask=None, + atten_mask=self.attention_mask, + query_rope=query_rope, + key_rope=key_rope, + actual_seq_qlen=actual_seq_len, + actual_seq_kvlen=actual_seq_len, + scale=scale, + pre_tokens=args.pre_tockens, + next_tokens=args.next_tockens, + keep_prob=1 - self.attention_dropout.p, + inner_precise=0, + sparse_mode=args.sparse_mode + )[0] + + if args.shape_order == "TND": # varlen FA + output = rearrange(output, '(s b) h d -> s b (h d)', s=seq_length) + elif args.shape_order == "BNSD": + output = rearrange(output, 'b h s d -> s b (h d)') + + return output diff --git a/mindspeed_llm/core/transformer/infer_dot_product_attention.py b/mindspeed_llm/core/transformer/infer_dot_product_attention.py new file mode 100644 index 000000000..653e8fcf1 --- /dev/null +++ b/mindspeed_llm/core/transformer/infer_dot_product_attention.py @@ -0,0 +1,439 @@ +# coding=utf-8 +# Copyright (c) 2025, Huawei Technologies Co., Ltd. All rights reserved. 
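+# NOTE: this module currently mirrors custom_dot_product_attention.py (the two files are
+# identical at this revision); judging by its name it is intended for the inference path.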
+ +import logging +import math +from functools import wraps +from typing import Union, List + +import torch +import torch_npu +from torch import Tensor +from megatron.core.fusions.fused_softmax import FusedScaleMaskSoftmax +from megatron.training import get_args +from megatron.core import mpu, parallel_state, tensor_parallel +from megatron.core.transformer import TransformerConfig +from megatron.core.transformer.dot_product_attention import DotProductAttention +from megatron.core.transformer.enums import AttnMaskType +from megatron.core.transformer.utils import attention_mask_func +from megatron.core.utils import divide +from megatron.core.packed_seq_params import PackedSeqParams +from mindspeed.core.context_parallel.ulysses_context_parallel.ulysses_context_parallel import ulyssesattn_context_parallel +from mindspeed.core.parallel_state import (get_context_parallel_group_for_hybrid_ring, + get_context_parallel_for_hybrid_ring_world_size, + get_context_parallel_for_hybrid_ring_rank, + get_context_parallel_for_hybrid_ring_global_ranks, + get_ring_ranks_for_intra_window, + get_ring_ranks_for_inter_window_kv, + get_ring_ranks_for_inter_window_dkv, + get_ring_group_for_intra_window, + get_ring_group_for_intra_window_send_recv_overlap) +from mindspeed.core.context_parallel.adaptive_context_parallel.adaptive_context_parallel import adaptive_attn_context_parallel +from mindspeed.core.context_parallel.utils import get_scheduling_info +from mindspeed.model.transformer import get_attention_mask +from mindspeed.core.context_parallel.ring_context_parallel.context_parallel_kv_cache import get_cache_policy +from mindspeed.utils import get_actual_seq_len, compute_qkv_index, get_position_ids + +from mindspeed_llm.tasks.models.common.alibi import Alibi +from mindspeed_llm.core.context_parallel.ring_context_parallel import ringattn_context_parallel +from mindspeed_llm.training.utils import recompute_valid_actual_seq_len + +logger = logging.getLogger(__name__) + +try: + from einops import rearrange +except ImportError: + rearrange = None + +ACTUAL_SEQ_LEN_THRESHOLD = 2048 + +def dot_product_attention_init( + self, + config: TransformerConfig, + layer_number: int, + attn_mask_type: AttnMaskType, + attention_type: str, + attention_dropout: float = None, + softmax_scale: float = None, + cp_comm_type: str = None, +): + cp_size = config.context_parallel_size + config.context_parallel_size = 1 + + super(DotProductAttention, self).__init__(config=config) + assert ( + self.config.context_parallel_size == 1 + ), "Context parallelism is only supported by TEDotProductAttention!" + + assert ( + self.config.window_size is None + ), "Sliding Window Attention is only supported by TEDotProductAttention!" + + self.layer_number = max(1, layer_number) + self.attn_mask_type = attn_mask_type + self.attention_type = attention_type # unused for now + + projection_size = self.config.kv_channels * self.config.num_attention_heads + args = get_args() + # Per attention head and per partition values. 
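+    # Hypothetical example: with num_attention_heads=32, kv_channels=128 and a tensor-parallel
+    # world size of 8, projection_size = 32 * 128 = 4096, hidden_size_per_partition = 512 and
+    # each partition owns 32 / 8 = 4 attention heads.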
+ world_size = args.tp_x if args.tp_2d else parallel_state.get_tensor_model_parallel_world_size() + self.hidden_size_per_partition = divide(projection_size, world_size) + self.hidden_size_per_attention_head = divide(projection_size, config.num_attention_heads) + self.num_attention_heads_per_partition = divide(self.config.num_attention_heads, world_size) + self.num_query_groups_per_partition = divide(self.config.num_query_groups, world_size) + + coeff = None + self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) + if self.config.apply_query_key_layer_scaling: + coeff = self.layer_number + self.norm_factor *= coeff + + self.scale_mask_softmax = FusedScaleMaskSoftmax( + input_in_fp16=self.config.fp16, + input_in_bf16=self.config.bf16, + attn_mask_type=self.attn_mask_type, + scaled_masked_softmax_fusion=self.config.masked_softmax_fusion, + mask_func=attention_mask_func, + softmax_in_fp32=self.config.attention_softmax_in_fp32, + scale=coeff, + ) + + # Dropout. Note that for a single iteration, this layer will generate + # different outputs on different number of parallel partitions but + # on average it should not be partition dependent. + self.attention_dropout = torch.nn.Dropout( + self.config.attention_dropout if attention_dropout is None else attention_dropout + ) + + config.context_parallel_size = cp_size + self.pse = None + self.pse_type = None + self.attn_logit_softcapping = args.attn_logit_softcapping + self.square_alibi_mask = args.square_alibi_mask + self.fill_neg_inf = args.fill_neg_inf + self.beta = 1.0 + self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling + if self.apply_query_key_layer_scaling: + self.beta = 1.0 / self.layer_number + + if args.position_embedding_type == 'alibi': + get_alibi(self, args.seq_length) + self.alibi_output_size = None + else: + self.alibi = None + + if args.query_pre_attn_scalar: + self.norm_factor = args.query_pre_attn_scalar ** 0.5 + self.scale_mask_softmax.scale = 1.0 + self.softmax_scale = 1.0 / self.norm_factor + + +def get_alibi(self, seq_length): + args = get_args() + self.alibi = Alibi() + alibi = self.alibi._build_alibi_tensor(seq_length, + args.num_attention_heads, + args.square_alibi_mask, + args.fill_neg_inf, + ).to(torch.cuda.current_device()) + if args.params_dtype == torch.float16: + alibi = alibi.to(torch.float16) + elif args.params_dtype == torch.bfloat16: + alibi = alibi.to(torch.bfloat16) + self.alibi.alibi = alibi + +def dot_product_attention_forward_wrapper(fn): + @wraps(fn) + def wrapper(self, query, key, value, attention_mask, attn_mask_type, attention_bias=None, packed_seq_params=None): + if attention_mask is None: + attention_mask = get_attention_mask() + + args = get_args() + if args.use_flash_attn and args.tp_2d: + from mindspeed.core.transformer.dot_product_attention import dot_product_attention_forward + return dot_product_attention_forward(self, query, key, value, attention_mask, attn_mask_type, attention_bias, packed_seq_params) + if self.config.context_parallel_size > 1 and args.context_parallel_algo == "ulysses_cp_algo" and args.context_parallel_kv_cache_policy: + return do_ulyssesattn_context_parallel(self, query, key, value, attention_mask, attn_mask_type, packed_seq_params) + # =================================== + # Raw attention scores. [b, n/p, s, s] + # =================================== + + # expand the key and value [sk, b, ng, hn] -> [sk, b, np, hn] + # This is a noop for normal attention where ng == np. 
When using group query attention this + # creates a view that has the keys and values virtually repeated along their dimension to + # match the number of queries. + + heads_per_gqa_group = self.num_attention_heads_per_partition // self.num_query_groups_per_partition + if not args.use_flash_attn: + if heads_per_gqa_group > 1: + key = key.repeat_interleave(heads_per_gqa_group, dim=2) + value = value.repeat_interleave(heads_per_gqa_group, dim=2) + else: + # Do repeat KV to support PFA + should_kv_repeat_before_pfa = hasattr(args, 'use_kv_cache') and args.use_kv_cache + if heads_per_gqa_group > 1 and should_kv_repeat_before_pfa: + key = key.repeat_interleave(heads_per_gqa_group, dim=2) + value = value.repeat_interleave(heads_per_gqa_group, dim=2) + + return flash_attention_forward(self, query, key, value, attention_mask, attn_mask_type, + packed_seq_params) + + output_size = ( + query.size(1), + query.size(2), + query.size(0), + key.size(0), + ) + + # [sq, b, np, hn] -> [sq, b * np, hn] + # This will be a simple view when doing normal attention, but in group query attention + # the key and value tensors are repeated to match the queries so you can't use simple strides + # to extract the queries. + query = query.reshape(output_size[2], output_size[0] * output_size[1], -1) + # [sk, b, np, hn] -> [sk, b * np, hn] + key = key.view(output_size[3], output_size[0] * output_size[1], -1) + + # preallocting input tensor: [b * np, sq, sk] + if self.alibi is None: + matmul_input_buffer = parallel_state.get_global_memory_buffer().get_tensor( + (output_size[0] * output_size[1], output_size[2], output_size[3]), query.dtype, "mpu", + ) + + # Raw attention scores. [b * np, sq, sk] + matmul_result = torch.baddbmm( + matmul_input_buffer, + query.transpose(0, 1), # [b * np, sq, hn] + key.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] + beta=0.0, + alpha=(1.0 / self.norm_factor), + ) + else: + if self.alibi.alibi_pse is None or self.alibi.output_size != output_size: + self.alibi.output_size = output_size + self.alibi.get_alibi_pse(attention_mask, output_size[0], output_size[2], output_size[3]) + + q_trans = query.transpose(0, 1).contiguous() + k_trans = key.transpose(0, 1).transpose(1, 2).contiguous() + matmul_result = self.beta * self.alibi.alibi_pse + torch.bmm(q_trans, k_trans) * (1.0 / self.norm_factor) + + if self.attn_logit_softcapping is not None: + matmul_result = matmul_result / self.attn_logit_softcapping + matmul_result = torch.tanh(matmul_result) + matmul_result = matmul_result * self.attn_logit_softcapping + + # change view to [b, np, sq, sk] + attention_scores = matmul_result.view(*output_size) + + # =========================== + # Attention probs and dropout + # =========================== + + # attention scores and attention mask [b, np, sq, sk] + if self.square_alibi_mask: + attention_scores = torch.max( + attention_scores, torch.tensor(torch.finfo(attention_scores.dtype).min) + ) + attention_probs = torch.nn.functional.softmax(attention_scores, -1) + else: + attention_probs: Tensor = self.scale_mask_softmax(attention_scores, attention_mask) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + + if not self.config.sequence_parallel: + with tensor_parallel.get_cuda_rng_tracker().fork(): + attention_probs = self.attention_dropout(attention_probs) + else: + attention_probs = self.attention_dropout(attention_probs) + + # ========================= + # Context layer. 
[sq, b, hp] + # ========================= + + # value -> context layer. + # [sk, b, np, hn] --> [b, np, sq, hn] + + # context layer shape: [b, np, sq, hn] + output_size = ( + value.size(1), + value.size(2), + query.size(0), + value.size(3), + ) + + # change view [sk, b * np, hn] + value = value.view(value.size(0), output_size[0] * output_size[1], -1) + + # change view [b * np, sq, sk] + attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1) + + # matmul: [b * np, sq, hn] + context = torch.bmm(attention_probs, value.transpose(0, 1)) + + # change view [b, np, sq, hn] + context = context.view(*output_size) + + # [b, np, sq, hn] --> [sq, b, np, hn] + context = context.permute(2, 0, 1, 3).contiguous() + + # [sq, b, np, hn] --> [sq, b, hp] + new_context_shape = context.size()[:-2] + (self.hidden_size_per_partition,) + context = context.view(*new_context_shape) + + return context + + return wrapper + + +def flash_attention_forward( + self, + query: Union[Tensor, List[Tensor]], + key: Union[Tensor, List[Tensor]], + value: Tensor, + attention_mask, + attn_mask_type, + packed_seq_params, +): + query_rope, key_rope = None, None + if isinstance(query, List): + query, query_rope = query[0], query[1] + if isinstance(key, List): + key, key_rope = key[0], key[1] + + args = get_args() + + seq_length, batch_size, n_head, head_dim = query.shape[0], query.shape[1], query.shape[2], query.shape[3] + scale = 1.0 / math.sqrt(self.hidden_size_per_attention_head) \ + if self.scale_mask_softmax.scale is None else self.softmax_scale + actual_seq_len = get_actual_seq_len() + if actual_seq_len is not None and args.mtp_num_layers: + actual_seq_len = actual_seq_len[self.mtp_idx] + + if args.context_parallel_size > 1 and args.context_parallel_algo in ['megatron_cp_algo', 'hybrid_cp_algo', + 'adaptive_cp_algo', 'hybrid_adaptive_cp_algo']: + query, key, value = [rearrange(x, 's b h d -> s b (h d)') for x in [query, key, value]] + return do_ring_context_parallel( + query, key, value, head_num=n_head, softmax_scale=scale, attn_mask=attention_mask, pse=self.pse, + pse_type=self.pse_type, packed_seq_params=packed_seq_params) + + if args.shape_order == "TND": # varlen FA + if args.mla_fa_divide_qk: + query, key, value = [rearrange(x, 's b h d -> (b s) h d') for x in [query, key, value]] + if query_rope is not None and key_rope is not None: + query_rope, key_rope = [rearrange(x, 's b h d -> (b s) h d') for x in [query_rope, key_rope]] + else: + query, key, value = [rearrange(x, 's b h d -> (s b) h d') for x in [query, key, value]] + args.sparse_mode = 4 + elif args.shape_order == "BNSD": + query, key, value = [rearrange(x, 's b h d -> b h s d') for x in [query, key, value]] + else: + query, key, value = [rearrange(x, 's b h d -> s b (h d)') for x in [query, key, value]] + args.shape_order = "SBH" + + if self.hidden_size_per_attention_head == 0: + raise AssertionError("self.hidden_size_per_attention_head should not be ZERO.") + if not hasattr(self, 'attention_mask') or \ + self.attention_mask is None or \ + self.attention_mask.shape[0] != seq_length: + if self.alibi is not None: + self.attention_mask = torch.triu( + torch.ones(seq_length, seq_length), + 1).bool().npu() + else: + self.attention_mask = attention_mask + + use_sliding_windows = args.sliding_window is not None and seq_length > args.sliding_window + + if use_sliding_windows: + args.pre_tockens = args.sliding_window + args.sparse_mode = 4 + + pse = None + size_record = key.shape + if self.alibi is not None and (self.alibi.output_size 
!= size_record) and pse is None: + if args.shape_order != 'SBH': + raise ValueError( + 'FlashAttention with Alibi requires for SBH shape_order, but is {}.'.format(args.shape_order)) + + self.alibi.output_size = size_record + self.alibi.get_alibi_pse(self.attention_mask, batch_size, query.shape[0], key.shape[0]) + + if self.alibi and pse is None: + pse = self.alibi.alibi_pse.reshape( + batch_size, n_head, self.alibi.alibi_pse.size(1), -1) + if hasattr(args, 'use_kv_cache') and args.use_kv_cache: + pse = pse * self.beta + else: + pse = pse * self.beta * self.norm_factor + args.pre_tockens = seq_length + args.sparse_mode = 0 + + if hasattr(args, 'use_kv_cache') and args.use_kv_cache: + query, key, value = [rearrange(x, 's b h -> b s h') for x in [query, key, value]] + if query.shape[1] == 1 and query.shape[1] != key.shape[1]: + output = torch_npu.npu_incre_flash_attention( + query, key, value, + num_heads=n_head, + input_layout="BSH", + pse_shift=pse, + padding_mask=None, + scale_value=scale + ) + else: + output = torch_npu.npu_prompt_flash_attention( + query, key, value, + num_heads=n_head, + input_layout="BSH", + pse_shift=pse, + sparse_mode=args.sparse_mode, + padding_mask=None, + atten_mask=self.attention_mask, + scale_value=scale, + pre_tokens=args.pre_tockens, + next_tokens=args.next_tockens + ) + output = output.transpose(0, 1) + else: + if not args.mla_fa_divide_qk: + if actual_seq_len is not None and len(actual_seq_len) > ACTUAL_SEQ_LEN_THRESHOLD: + logger.warning("flash-attention get a long actual_seq_len, maybe create a coredump!") + actual_seq_len = recompute_valid_actual_seq_len(get_position_ids(), actual_seq_len) + + output = torch_npu.npu_fusion_attention( + query, key, value, n_head, args.shape_order, + pse=pse, + padding_mask=None, + atten_mask=self.attention_mask, + actual_seq_qlen=actual_seq_len, + actual_seq_kvlen=actual_seq_len, + scale=scale, + pre_tockens=args.pre_tockens, + next_tockens=args.next_tockens, + keep_prob=1 - self.attention_dropout.p, + inner_precise=0, + sparse_mode=args.sparse_mode + )[0] + else: + output = torch_npu.npu_fusion_attention_v2( + query, key, value, n_head, args.shape_order, + pse=pse, + padding_mask=None, + atten_mask=self.attention_mask, + query_rope=query_rope, + key_rope=key_rope, + actual_seq_qlen=actual_seq_len, + actual_seq_kvlen=actual_seq_len, + scale=scale, + pre_tokens=args.pre_tockens, + next_tokens=args.next_tockens, + keep_prob=1 - self.attention_dropout.p, + inner_precise=0, + sparse_mode=args.sparse_mode + )[0] + + if args.shape_order == "TND": # varlen FA + output = rearrange(output, '(s b) h d -> s b (h d)', s=seq_length) + elif args.shape_order == "BNSD": + output = rearrange(output, 'b h s d -> s b (h d)') + + return output diff --git a/mindspeed_llm/features_manager/__init__.py b/mindspeed_llm/features_manager/__init__.py index 4e9e578e4..e0914d6d4 100644 --- a/mindspeed_llm/features_manager/__init__.py +++ b/mindspeed_llm/features_manager/__init__.py @@ -22,6 +22,7 @@ from mindspeed.features_manager import ( UnalignedLinearFeature, UnalignedPipelineFeature, VirtualOptimizerFeature, + AlibiFeature ) from mindspeed.features_manager.feature import MindSpeedFeature from mindspeed.features_manager.features_manager import MindSpeedFeaturesManager @@ -131,6 +132,7 @@ def add_transformer_features(features_list: List[MindSpeedFeature]): MultiTokenPredictionFeature(), # LLM feature TransformerBlockFeature(), + AlibiFeature() ]) -- Gitee From 8c1c0733485898b5264e4e8a7f34a7d9169bf01f Mon Sep 17 00:00:00 2001 From: HanhuiChen Date: 
Tue, 5 Aug 2025 20:35:50 +0800 Subject: [PATCH 2/9] refactor dot product attention --- mindspeed_llm/features_manager/__init__.py | 5 +- .../flash_attention/alibi_feature.py | 66 +++++++++++++++++++ 2 files changed, 69 insertions(+), 2 deletions(-) create mode 100644 mindspeed_llm/features_manager/transformer/flash_attention/alibi_feature.py diff --git a/mindspeed_llm/features_manager/__init__.py b/mindspeed_llm/features_manager/__init__.py index e0914d6d4..f50f1b698 100644 --- a/mindspeed_llm/features_manager/__init__.py +++ b/mindspeed_llm/features_manager/__init__.py @@ -22,7 +22,6 @@ from mindspeed.features_manager import ( UnalignedLinearFeature, UnalignedPipelineFeature, VirtualOptimizerFeature, - AlibiFeature ) from mindspeed.features_manager.feature import MindSpeedFeature from mindspeed.features_manager.features_manager import MindSpeedFeaturesManager @@ -49,6 +48,7 @@ from mindspeed_llm.features_manager.pipeline_parallel.dualpipev_feature import D from mindspeed_llm.features_manager.pipeline_parallel.noop_layers import NoopLayersFeature from mindspeed_llm.features_manager.tokenizer.build_tokenizer import BuildTokenizerFeature from mindspeed_llm.features_manager.transformer.flash_attention.fusion_attention_feature import FusionAttentionFeature +from mindspeed_llm.features_manager.transformer.flash_attention.alibi_feature import AlibiFeature from mindspeed_llm.features_manager.transformer.mtp import MultiTokenPredictionFeature from mindspeed_llm.features_manager.transformer.multi_latent_attention.mla_feature import MLAFeature from mindspeed_llm.features_manager.transformer.transformer_block import TransformerBlockFeature @@ -132,7 +132,8 @@ def add_transformer_features(features_list: List[MindSpeedFeature]): MultiTokenPredictionFeature(), # LLM feature TransformerBlockFeature(), - AlibiFeature() + # LLM feature + AlibiFeature(), ]) diff --git a/mindspeed_llm/features_manager/transformer/flash_attention/alibi_feature.py b/mindspeed_llm/features_manager/transformer/flash_attention/alibi_feature.py new file mode 100644 index 000000000..c62570b93 --- /dev/null +++ b/mindspeed_llm/features_manager/transformer/flash_attention/alibi_feature.py @@ -0,0 +1,66 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. All rights reserved. +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + +from logging import getLogger +from argparse import ArgumentParser + +from mindspeed.features_manager.feature import MindSpeedFeature + + +class AlibiFeature(MindSpeedFeature): + """ + Attention positional embedding. + To enable this feature, the reference is as follows . 
+ + Usage: + "--position-embedding-type alibi" + "--alibi-fusion-attn-type 0 or 2" + "[--alibi-diagonal-opposite]" + """ + + def __init__(self): + super().__init__( + 'position-embedding-type', + optimization_level=2 + ) + + def is_need_apply(self, args): + pse = getattr(args, self.feature_name, None) + need_apply = False + if pse == 'alibi': + need_apply = True + return ( + self.optimization_level <= args.optimization_level and + need_apply + ) or self.default_patches + + def register_args(self, parser: ArgumentParser): + self.add_parser_argument_choices_value( + parser, + "--position-embedding-type", + 'alibi' + ) + + group = parser.add_argument_group(title='alibi') + group.add_argument( + '--square-alibi-mask', + action='store_true', + default=False, + help='attention mask of alibi is squared' + ) + group.add_argument( + '--fill-neg-inf', + action='store_true', + default=False, + help='fill alibi with negative inf' + ) + + + def register_patches(self, patch_manager, args): + if int(getattr(args, 'context_parallel_size', 1)) == 1: + # from mindspeed.core.transformer.flash_attention.alibi.adaptor import MindSpeedDotProductAttention + from mindspeed_llm.core.transformer.custom_dot_product_attention import DotProductAttention + patch_manager.register_patch( + 'megatron.core.transformer.dot_product_attention.DotProductAttention', + DotProductAttention + ) \ No newline at end of file -- Gitee From ec8b2be217a10a0ea09b16cbef50cdf795984887 Mon Sep 17 00:00:00 2001 From: HanhuiChen Date: Tue, 12 Aug 2025 17:11:05 +0800 Subject: [PATCH 3/9] refactor dot product attention for non-cp --- .../custom_dot_product_attention.py | 631 +++++++----------- .../infer_dot_product_attention.py | 439 ------------ .../flash_attention/alibi_feature.py | 13 - .../fusion_attention_feature.py | 26 +- 4 files changed, 244 insertions(+), 865 deletions(-) delete mode 100644 mindspeed_llm/core/transformer/infer_dot_product_attention.py diff --git a/mindspeed_llm/core/transformer/custom_dot_product_attention.py b/mindspeed_llm/core/transformer/custom_dot_product_attention.py index 653e8fcf1..7aca7a36a 100644 --- a/mindspeed_llm/core/transformer/custom_dot_product_attention.py +++ b/mindspeed_llm/core/transformer/custom_dot_product_attention.py @@ -3,39 +3,20 @@ import logging import math -from functools import wraps -from typing import Union, List +from typing import List import torch import torch_npu -from torch import Tensor from megatron.core.fusions.fused_softmax import FusedScaleMaskSoftmax from megatron.training import get_args -from megatron.core import mpu, parallel_state, tensor_parallel -from megatron.core.transformer import TransformerConfig +from megatron.core import mpu, parallel_state from megatron.core.transformer.dot_product_attention import DotProductAttention -from megatron.core.transformer.enums import AttnMaskType from megatron.core.transformer.utils import attention_mask_func from megatron.core.utils import divide -from megatron.core.packed_seq_params import PackedSeqParams -from mindspeed.core.context_parallel.ulysses_context_parallel.ulysses_context_parallel import ulyssesattn_context_parallel -from mindspeed.core.parallel_state import (get_context_parallel_group_for_hybrid_ring, - get_context_parallel_for_hybrid_ring_world_size, - get_context_parallel_for_hybrid_ring_rank, - get_context_parallel_for_hybrid_ring_global_ranks, - get_ring_ranks_for_intra_window, - get_ring_ranks_for_inter_window_kv, - get_ring_ranks_for_inter_window_dkv, - get_ring_group_for_intra_window, - 
get_ring_group_for_intra_window_send_recv_overlap) -from mindspeed.core.context_parallel.adaptive_context_parallel.adaptive_context_parallel import adaptive_attn_context_parallel -from mindspeed.core.context_parallel.utils import get_scheduling_info from mindspeed.model.transformer import get_attention_mask -from mindspeed.core.context_parallel.ring_context_parallel.context_parallel_kv_cache import get_cache_policy -from mindspeed.utils import get_actual_seq_len, compute_qkv_index, get_position_ids +from mindspeed.utils import get_actual_seq_len, get_position_ids from mindspeed_llm.tasks.models.common.alibi import Alibi -from mindspeed_llm.core.context_parallel.ring_context_parallel import ringattn_context_parallel from mindspeed_llm.training.utils import recompute_valid_actual_seq_len logger = logging.getLogger(__name__) @@ -47,393 +28,251 @@ except ImportError: ACTUAL_SEQ_LEN_THRESHOLD = 2048 -def dot_product_attention_init( - self, - config: TransformerConfig, - layer_number: int, - attn_mask_type: AttnMaskType, - attention_type: str, - attention_dropout: float = None, - softmax_scale: float = None, - cp_comm_type: str = None, -): - cp_size = config.context_parallel_size - config.context_parallel_size = 1 - - super(DotProductAttention, self).__init__(config=config) - assert ( - self.config.context_parallel_size == 1 - ), "Context parallelism is only supported by TEDotProductAttention!" - - assert ( - self.config.window_size is None - ), "Sliding Window Attention is only supported by TEDotProductAttention!" - - self.layer_number = max(1, layer_number) - self.attn_mask_type = attn_mask_type - self.attention_type = attention_type # unused for now - - projection_size = self.config.kv_channels * self.config.num_attention_heads - args = get_args() - # Per attention head and per partition values. - world_size = args.tp_x if args.tp_2d else parallel_state.get_tensor_model_parallel_world_size() - self.hidden_size_per_partition = divide(projection_size, world_size) - self.hidden_size_per_attention_head = divide(projection_size, config.num_attention_heads) - self.num_attention_heads_per_partition = divide(self.config.num_attention_heads, world_size) - self.num_query_groups_per_partition = divide(self.config.num_query_groups, world_size) - - coeff = None - self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) - if self.config.apply_query_key_layer_scaling: - coeff = self.layer_number - self.norm_factor *= coeff - - self.scale_mask_softmax = FusedScaleMaskSoftmax( - input_in_fp16=self.config.fp16, - input_in_bf16=self.config.bf16, - attn_mask_type=self.attn_mask_type, - scaled_masked_softmax_fusion=self.config.masked_softmax_fusion, - mask_func=attention_mask_func, - softmax_in_fp32=self.config.attention_softmax_in_fp32, - scale=coeff, - ) - - # Dropout. Note that for a single iteration, this layer will generate - # different outputs on different number of parallel partitions but - # on average it should not be partition dependent. 
- self.attention_dropout = torch.nn.Dropout( - self.config.attention_dropout if attention_dropout is None else attention_dropout - ) - - config.context_parallel_size = cp_size - self.pse = None - self.pse_type = None - self.attn_logit_softcapping = args.attn_logit_softcapping - self.square_alibi_mask = args.square_alibi_mask - self.fill_neg_inf = args.fill_neg_inf - self.beta = 1.0 - self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling - if self.apply_query_key_layer_scaling: - self.beta = 1.0 / self.layer_number - - if args.position_embedding_type == 'alibi': - get_alibi(self, args.seq_length) - self.alibi_output_size = None - else: - self.alibi = None - - if args.query_pre_attn_scalar: - self.norm_factor = args.query_pre_attn_scalar ** 0.5 - self.scale_mask_softmax.scale = 1.0 - self.softmax_scale = 1.0 / self.norm_factor - - -def get_alibi(self, seq_length): - args = get_args() - self.alibi = Alibi() - alibi = self.alibi._build_alibi_tensor(seq_length, - args.num_attention_heads, - args.square_alibi_mask, - args.fill_neg_inf, - ).to(torch.cuda.current_device()) - if args.params_dtype == torch.float16: - alibi = alibi.to(torch.float16) - elif args.params_dtype == torch.bfloat16: - alibi = alibi.to(torch.bfloat16) - self.alibi.alibi = alibi - -def dot_product_attention_forward_wrapper(fn): - @wraps(fn) - def wrapper(self, query, key, value, attention_mask, attn_mask_type, attention_bias=None, packed_seq_params=None): - if attention_mask is None: - attention_mask = get_attention_mask() +class CustomDotProductAttentionImpl: + """ + Implementation of dot product attention with non-cp support. + """ + + def __init__(self, + config, + layer_number, + attn_mask_type, + attention_type, + attention_dropout: float = None, + softmax_scale: float = None, + cp_comm_type: str = None + ): + super().__init__(config, layer_number, attn_mask_type, attention_type, attention_dropout, softmax_scale, cp_comm_type) args = get_args() - if args.use_flash_attn and args.tp_2d: - from mindspeed.core.transformer.dot_product_attention import dot_product_attention_forward - return dot_product_attention_forward(self, query, key, value, attention_mask, attn_mask_type, attention_bias, packed_seq_params) - if self.config.context_parallel_size > 1 and args.context_parallel_algo == "ulysses_cp_algo" and args.context_parallel_kv_cache_policy: - return do_ulyssesattn_context_parallel(self, query, key, value, attention_mask, attn_mask_type, packed_seq_params) - # =================================== - # Raw attention scores. [b, n/p, s, s] - # =================================== - - # expand the key and value [sk, b, ng, hn] -> [sk, b, np, hn] - # This is a noop for normal attention where ng == np. When using group query attention this - # creates a view that has the keys and values virtually repeated along their dimension to - # match the number of queries. 
- heads_per_gqa_group = self.num_attention_heads_per_partition // self.num_query_groups_per_partition - if not args.use_flash_attn: - if heads_per_gqa_group > 1: - key = key.repeat_interleave(heads_per_gqa_group, dim=2) - value = value.repeat_interleave(heads_per_gqa_group, dim=2) - else: - # Do repeat KV to support PFA - should_kv_repeat_before_pfa = hasattr(args, 'use_kv_cache') and args.use_kv_cache - if heads_per_gqa_group > 1 and should_kv_repeat_before_pfa: - key = key.repeat_interleave(heads_per_gqa_group, dim=2) - value = value.repeat_interleave(heads_per_gqa_group, dim=2) - - return flash_attention_forward(self, query, key, value, attention_mask, attn_mask_type, - packed_seq_params) - - output_size = ( - query.size(1), - query.size(2), - query.size(0), - key.size(0), + assert getattr(config, 'context_parallel_size', 1) == 1, "CustomDotProductAttention only supported by non-cp (context_parallel_size == 1)" + assert bool(getattr(args, 'use_flash_attn', False)) == True, "CustomDotProductAttention only supported by FlashAttention (args.use_flash_attn == True)" + + self.config = config + self.layer_number = max(1, layer_number) + self.attn_mask_type = attn_mask_type # unused for now + self.attention_type = attention_type + + projection_size = self.config.kv_channels * self.config.num_attention_heads + world_size = args.tp_x if args.tp_2d else parallel_state.get_tensor_model_parallel_world_size() + self.hidden_size_per_partition = divide(projection_size, world_size) + self.hidden_size_per_attention_head = divide(projection_size, config.num_attention_heads) + self.num_attention_heads_per_partition = divide(self.config.num_attention_heads, world_size) + self.num_query_groups_per_partition = divide(self.config.num_query_groups, world_size) + + coeff = None + self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) + if self.config.apply_query_key_layer_scaling: + coeff = self.layer_number + self.norm_factor *= coeff + + self.scale_mask_softmax = FusedScaleMaskSoftmax( + input_in_fp16=self.config.fp16, + input_in_bf16=self.config.bf16, + attn_mask_type=self.attn_mask_type, + scaled_masked_softmax_fusion=self.config.masked_softmax_fusion, + mask_func=attention_mask_func, + softmax_in_fp32=self.config.attention_softmax_in_fp32, + scale=coeff, ) - # [sq, b, np, hn] -> [sq, b * np, hn] - # This will be a simple view when doing normal attention, but in group query attention - # the key and value tensors are repeated to match the queries so you can't use simple strides - # to extract the queries. - query = query.reshape(output_size[2], output_size[0] * output_size[1], -1) - # [sk, b, np, hn] -> [sk, b * np, hn] - key = key.view(output_size[3], output_size[0] * output_size[1], -1) - - # preallocting input tensor: [b * np, sq, sk] - if self.alibi is None: - matmul_input_buffer = parallel_state.get_global_memory_buffer().get_tensor( - (output_size[0] * output_size[1], output_size[2], output_size[3]), query.dtype, "mpu", - ) - - # Raw attention scores. 
[b * np, sq, sk] - matmul_result = torch.baddbmm( - matmul_input_buffer, - query.transpose(0, 1), # [b * np, sq, hn] - key.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] - beta=0.0, - alpha=(1.0 / self.norm_factor), - ) - else: - if self.alibi.alibi_pse is None or self.alibi.output_size != output_size: - self.alibi.output_size = output_size - self.alibi.get_alibi_pse(attention_mask, output_size[0], output_size[2], output_size[3]) - - q_trans = query.transpose(0, 1).contiguous() - k_trans = key.transpose(0, 1).transpose(1, 2).contiguous() - matmul_result = self.beta * self.alibi.alibi_pse + torch.bmm(q_trans, k_trans) * (1.0 / self.norm_factor) - - if self.attn_logit_softcapping is not None: - matmul_result = matmul_result / self.attn_logit_softcapping - matmul_result = torch.tanh(matmul_result) - matmul_result = matmul_result * self.attn_logit_softcapping - - # change view to [b, np, sq, sk] - attention_scores = matmul_result.view(*output_size) - - # =========================== - # Attention probs and dropout - # =========================== - - # attention scores and attention mask [b, np, sq, sk] - if self.square_alibi_mask: - attention_scores = torch.max( - attention_scores, torch.tensor(torch.finfo(attention_scores.dtype).min) - ) - attention_probs = torch.nn.functional.softmax(attention_scores, -1) - else: - attention_probs: Tensor = self.scale_mask_softmax(attention_scores, attention_mask) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. - - if not self.config.sequence_parallel: - with tensor_parallel.get_cuda_rng_tracker().fork(): - attention_probs = self.attention_dropout(attention_probs) - else: - attention_probs = self.attention_dropout(attention_probs) - - # ========================= - # Context layer. [sq, b, hp] - # ========================= - - # value -> context layer. - # [sk, b, np, hn] --> [b, np, sq, hn] - - # context layer shape: [b, np, sq, hn] - output_size = ( - value.size(1), - value.size(2), - query.size(0), - value.size(3), + # Dropout. Note that for a single iteration, this layer will generate + # different outputs on different number of parallel partitions but + # on average it should not be partition dependent. 
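+        # The explicit attention_dropout argument, when given, overrides the value from
+        # self.config.attention_dropout.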
+ self.attention_dropout = torch.nn.Dropout( + self.config.attention_dropout if attention_dropout is None else attention_dropout ) - # change view [sk, b * np, hn] - value = value.view(value.size(0), output_size[0] * output_size[1], -1) - - # change view [b * np, sq, sk] - attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1) - - # matmul: [b * np, sq, hn] - context = torch.bmm(attention_probs, value.transpose(0, 1)) - - # change view [b, np, sq, hn] - context = context.view(*output_size) - - # [b, np, sq, hn] --> [sq, b, np, hn] - context = context.permute(2, 0, 1, 3).contiguous() - - # [sq, b, np, hn] --> [sq, b, hp] - new_context_shape = context.size()[:-2] + (self.hidden_size_per_partition,) - context = context.view(*new_context_shape) - - return context - - return wrapper + self.pse = None + self.pse_type = None + self.attn_logit_softcapping = args.attn_logit_softcapping + self.square_alibi_mask = args.square_alibi_mask + self.fill_neg_inf = args.fill_neg_inf + self.beta = 1.0 + self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling + if self.apply_query_key_layer_scaling: + self.beta = 1.0 / self.layer_number + + if args.position_embedding_type == 'alibi': + self.alibi = Alibi() + alibi = self.alibi._build_alibi_tensor(args.seq_length, + args.num_attention_heads, + args.square_alibi_mask, + args.fill_neg_inf, + ).to(torch.cuda.current_device()) + if args.params_dtype == torch.float16: + alibi = alibi.to(torch.float16) + elif args.params_dtype == torch.bfloat16: + alibi = alibi.to(torch.bfloat16) + self.alibi.alibi = alibi + self.alibi_output_size = None + else: + self.alibi = None + if args.query_pre_attn_scalar: + self.norm_factor = args.query_pre_attn_scalar ** 0.5 + self.scale_mask_softmax.scale = 1.0 + self.softmax_scale = 1.0 / self.norm_factor -def flash_attention_forward( + def forward( self, - query: Union[Tensor, List[Tensor]], - key: Union[Tensor, List[Tensor]], - value: Tensor, + query, + key, + value, attention_mask, - attn_mask_type, - packed_seq_params, -): - query_rope, key_rope = None, None - if isinstance(query, List): - query, query_rope = query[0], query[1] - if isinstance(key, List): - key, key_rope = key[0], key[1] - - args = get_args() - - seq_length, batch_size, n_head, head_dim = query.shape[0], query.shape[1], query.shape[2], query.shape[3] - scale = 1.0 / math.sqrt(self.hidden_size_per_attention_head) \ - if self.scale_mask_softmax.scale is None else self.softmax_scale - actual_seq_len = get_actual_seq_len() - if actual_seq_len is not None and args.mtp_num_layers: - actual_seq_len = actual_seq_len[self.mtp_idx] - - if args.context_parallel_size > 1 and args.context_parallel_algo in ['megatron_cp_algo', 'hybrid_cp_algo', - 'adaptive_cp_algo', 'hybrid_adaptive_cp_algo']: - query, key, value = [rearrange(x, 's b h d -> s b (h d)') for x in [query, key, value]] - return do_ring_context_parallel( - query, key, value, head_num=n_head, softmax_scale=scale, attn_mask=attention_mask, pse=self.pse, - pse_type=self.pse_type, packed_seq_params=packed_seq_params) - - if args.shape_order == "TND": # varlen FA - if args.mla_fa_divide_qk: - query, key, value = [rearrange(x, 's b h d -> (b s) h d') for x in [query, key, value]] - if query_rope is not None and key_rope is not None: - query_rope, key_rope = [rearrange(x, 's b h d -> (b s) h d') for x in [query_rope, key_rope]] - else: - query, key, value = [rearrange(x, 's b h d -> (s b) h d') for x in [query, key, value]] - args.sparse_mode = 4 - elif args.shape_order == 
"BNSD": - query, key, value = [rearrange(x, 's b h d -> b h s d') for x in [query, key, value]] - else: - query, key, value = [rearrange(x, 's b h d -> s b (h d)') for x in [query, key, value]] - args.shape_order = "SBH" - - if self.hidden_size_per_attention_head == 0: - raise AssertionError("self.hidden_size_per_attention_head should not be ZERO.") - if not hasattr(self, 'attention_mask') or \ - self.attention_mask is None or \ - self.attention_mask.shape[0] != seq_length: - if self.alibi is not None: - self.attention_mask = torch.triu( - torch.ones(seq_length, seq_length), - 1).bool().npu() - else: - self.attention_mask = attention_mask - - use_sliding_windows = args.sliding_window is not None and seq_length > args.sliding_window - - if use_sliding_windows: - args.pre_tockens = args.sliding_window - args.sparse_mode = 4 + attn_mask_type=None, + attention_bias=None, + packed_seq_params=None, + ): + if attention_mask is None: + attention_mask = get_attention_mask() - pse = None - size_record = key.shape - if self.alibi is not None and (self.alibi.output_size != size_record) and pse is None: - if args.shape_order != 'SBH': - raise ValueError( - 'FlashAttention with Alibi requires for SBH shape_order, but is {}.'.format(args.shape_order)) + query_rope, key_rope = None, None + if isinstance(query, List): + query, query_rope = query[0], query[1] + if isinstance(key, List): + key, key_rope = key[0], key[1] - self.alibi.output_size = size_record - self.alibi.get_alibi_pse(self.attention_mask, batch_size, query.shape[0], key.shape[0]) + args = get_args() + heads_per_gqa_group = self.num_attention_heads_per_partition // self.num_query_groups_per_partition + should_kv_repeat_before_pfa = hasattr(args, 'use_kv_cache') and args.use_kv_cache + if heads_per_gqa_group > 1 and should_kv_repeat_before_pfa: + key = key.repeat_interleave(heads_per_gqa_group, dim=2) + value = value.repeat_interleave(heads_per_gqa_group, dim=2) + + seq_length, batch_size, n_head, head_dim = query.shape[0], query.shape[1], query.shape[2], query.shape[3] + scale = (1.0 / math.sqrt(self.hidden_size_per_attention_head)) if self.scale_mask_softmax.scale is None \ + else self.softmax_scale + + actual_seq_len = get_actual_seq_len() + if actual_seq_len is not None and args.mtp_num_layers: + actual_seq_len = actual_seq_len[self.mtp_idx] + + if args.shape_order == "TND": # varlen FA + if args.mla_fa_divide_qk: + query, key, value = [rearrange(x, 's b h d -> (b s) h d') for x in [query, key, value]] + if query_rope is not None and key_rope is not None: + query_rope, key_rope = [rearrange(x, 's b h d -> (b s) h d') for x in [query_rope, key_rope]] + else: + query, key, value = [rearrange(x, 's b h d -> (s b) h d') for x in [query, key, value]] + args.sparse_mode = 4 + elif args.shape_order == "BNSD": + query, key, value = [rearrange(x, 's b h d -> b h s d') for x in [query, key, value]] + else: + query, key, value = [rearrange(x, 's b h d -> s b (h d)') for x in [query, key, value]] + args.shape_order = "SBH" + + if self.hidden_size_per_attention_head == 0: + raise AssertionError("self.hidden_size_per_attention_head should not be ZERO.") + if not hasattr(self, 'attention_mask') or \ + self.attention_mask is None or \ + self.attention_mask.shape[0] != seq_length: + if self.alibi is not None: + self.attention_mask = torch.triu( + torch.ones(seq_length, seq_length), + 1).bool().npu() + else: + self.attention_mask = attention_mask + + use_sliding_windows = args.sliding_window is not None and seq_length > args.sliding_window + + if 
use_sliding_windows: + args.pre_tockens = args.sliding_window + args.sparse_mode = 4 + + pse = None + size_record = key.shape + if self.alibi is not None and (self.alibi.output_size != size_record) and pse is None: + if args.shape_order != 'SBH': + raise ValueError( + 'FlashAttention with Alibi requires for SBH shape_order, but is {}.'.format(args.shape_order)) + + self.alibi.output_size = size_record + self.alibi.get_alibi_pse(self.attention_mask, batch_size, query.shape[0], key.shape[0]) + + if self.alibi and pse is None: + pse = self.alibi.alibi_pse.reshape( + batch_size, n_head, self.alibi.alibi_pse.size(1), -1) + if hasattr(args, 'use_kv_cache') and args.use_kv_cache: + pse = pse * self.beta + else: + pse = pse * self.beta * self.norm_factor + args.pre_tockens = seq_length + args.sparse_mode = 0 - if self.alibi and pse is None: - pse = self.alibi.alibi_pse.reshape( - batch_size, n_head, self.alibi.alibi_pse.size(1), -1) if hasattr(args, 'use_kv_cache') and args.use_kv_cache: - pse = pse * self.beta - else: - pse = pse * self.beta * self.norm_factor - args.pre_tockens = seq_length - args.sparse_mode = 0 - - if hasattr(args, 'use_kv_cache') and args.use_kv_cache: - query, key, value = [rearrange(x, 's b h -> b s h') for x in [query, key, value]] - if query.shape[1] == 1 and query.shape[1] != key.shape[1]: - output = torch_npu.npu_incre_flash_attention( - query, key, value, - num_heads=n_head, - input_layout="BSH", - pse_shift=pse, - padding_mask=None, - scale_value=scale - ) - else: - output = torch_npu.npu_prompt_flash_attention( - query, key, value, - num_heads=n_head, - input_layout="BSH", - pse_shift=pse, - sparse_mode=args.sparse_mode, - padding_mask=None, - atten_mask=self.attention_mask, - scale_value=scale, - pre_tokens=args.pre_tockens, - next_tokens=args.next_tockens - ) - output = output.transpose(0, 1) - else: - if not args.mla_fa_divide_qk: - if actual_seq_len is not None and len(actual_seq_len) > ACTUAL_SEQ_LEN_THRESHOLD: - logger.warning("flash-attention get a long actual_seq_len, maybe create a coredump!") - actual_seq_len = recompute_valid_actual_seq_len(get_position_ids(), actual_seq_len) - - output = torch_npu.npu_fusion_attention( - query, key, value, n_head, args.shape_order, - pse=pse, - padding_mask=None, - atten_mask=self.attention_mask, - actual_seq_qlen=actual_seq_len, - actual_seq_kvlen=actual_seq_len, - scale=scale, - pre_tockens=args.pre_tockens, - next_tockens=args.next_tockens, - keep_prob=1 - self.attention_dropout.p, - inner_precise=0, - sparse_mode=args.sparse_mode - )[0] + query, key, value = [rearrange(x, 's b h -> b s h') for x in [query, key, value]] + if query.shape[1] == 1 and query.shape[1] != key.shape[1]: + output = torch_npu.npu_incre_flash_attention( + query, key, value, + num_heads=n_head, + input_layout="BSH", + pse_shift=pse, + padding_mask=None, + scale_value=scale + ) + else: + output = torch_npu.npu_prompt_flash_attention( + query, key, value, + num_heads=n_head, + input_layout="BSH", + pse_shift=pse, + sparse_mode=args.sparse_mode, + padding_mask=None, + atten_mask=self.attention_mask, + scale_value=scale, + pre_tokens=args.pre_tockens, + next_tokens=args.next_tockens + ) + output = output.transpose(0, 1) else: - output = torch_npu.npu_fusion_attention_v2( - query, key, value, n_head, args.shape_order, - pse=pse, - padding_mask=None, - atten_mask=self.attention_mask, - query_rope=query_rope, - key_rope=key_rope, - actual_seq_qlen=actual_seq_len, - actual_seq_kvlen=actual_seq_len, - scale=scale, - pre_tokens=args.pre_tockens, - 
next_tokens=args.next_tockens, - keep_prob=1 - self.attention_dropout.p, - inner_precise=0, - sparse_mode=args.sparse_mode - )[0] - - if args.shape_order == "TND": # varlen FA - output = rearrange(output, '(s b) h d -> s b (h d)', s=seq_length) - elif args.shape_order == "BNSD": - output = rearrange(output, 'b h s d -> s b (h d)') - - return output + if not args.mla_fa_divide_qk: + if actual_seq_len is not None and len(actual_seq_len) > ACTUAL_SEQ_LEN_THRESHOLD: + logger.warning("flash-attention get a long actual_seq_len, maybe create a coredump!") + actual_seq_len = recompute_valid_actual_seq_len(get_position_ids(), actual_seq_len) + + output = torch_npu.npu_fusion_attention( + query, key, value, n_head, args.shape_order, + pse=pse, + padding_mask=None, + atten_mask=self.attention_mask, + actual_seq_qlen=actual_seq_len, + actual_seq_kvlen=actual_seq_len, + scale=scale, + pre_tockens=args.pre_tockens, + next_tockens=args.next_tockens, + keep_prob=1 - self.attention_dropout.p, + inner_precise=0, + sparse_mode=args.sparse_mode + )[0] + else: + output = torch_npu.npu_fusion_attention_v2( + query, key, value, n_head, args.shape_order, + pse=pse, + padding_mask=None, + atten_mask=self.attention_mask, + query_rope=query_rope, + key_rope=key_rope, + actual_seq_qlen=actual_seq_len, + actual_seq_kvlen=actual_seq_len, + scale=scale, + pre_tokens=args.pre_tockens, + next_tokens=args.next_tockens, + keep_prob=1 - self.attention_dropout.p, + inner_precise=0, + sparse_mode=args.sparse_mode + )[0] + + if args.shape_order == "TND": # varlen FA + output = rearrange(output, '(s b) h d -> s b (h d)', s=seq_length) + elif args.shape_order == "BNSD": + output = rearrange(output, 'b h s d -> s b (h d)') + + return output + + +class CustomDotProductAttention(CustomDotProductAttentionImpl, DotProductAttention): + + def __init__(self, *args, **kwargs): + CustomDotProductAttentionImpl.__init__(self, *args, **kwargs) \ No newline at end of file diff --git a/mindspeed_llm/core/transformer/infer_dot_product_attention.py b/mindspeed_llm/core/transformer/infer_dot_product_attention.py deleted file mode 100644 index 653e8fcf1..000000000 --- a/mindspeed_llm/core/transformer/infer_dot_product_attention.py +++ /dev/null @@ -1,439 +0,0 @@ -# coding=utf-8 -# Copyright (c) 2025, Huawei Technologies Co., Ltd. All rights reserved. 
- -import logging -import math -from functools import wraps -from typing import Union, List - -import torch -import torch_npu -from torch import Tensor -from megatron.core.fusions.fused_softmax import FusedScaleMaskSoftmax -from megatron.training import get_args -from megatron.core import mpu, parallel_state, tensor_parallel -from megatron.core.transformer import TransformerConfig -from megatron.core.transformer.dot_product_attention import DotProductAttention -from megatron.core.transformer.enums import AttnMaskType -from megatron.core.transformer.utils import attention_mask_func -from megatron.core.utils import divide -from megatron.core.packed_seq_params import PackedSeqParams -from mindspeed.core.context_parallel.ulysses_context_parallel.ulysses_context_parallel import ulyssesattn_context_parallel -from mindspeed.core.parallel_state import (get_context_parallel_group_for_hybrid_ring, - get_context_parallel_for_hybrid_ring_world_size, - get_context_parallel_for_hybrid_ring_rank, - get_context_parallel_for_hybrid_ring_global_ranks, - get_ring_ranks_for_intra_window, - get_ring_ranks_for_inter_window_kv, - get_ring_ranks_for_inter_window_dkv, - get_ring_group_for_intra_window, - get_ring_group_for_intra_window_send_recv_overlap) -from mindspeed.core.context_parallel.adaptive_context_parallel.adaptive_context_parallel import adaptive_attn_context_parallel -from mindspeed.core.context_parallel.utils import get_scheduling_info -from mindspeed.model.transformer import get_attention_mask -from mindspeed.core.context_parallel.ring_context_parallel.context_parallel_kv_cache import get_cache_policy -from mindspeed.utils import get_actual_seq_len, compute_qkv_index, get_position_ids - -from mindspeed_llm.tasks.models.common.alibi import Alibi -from mindspeed_llm.core.context_parallel.ring_context_parallel import ringattn_context_parallel -from mindspeed_llm.training.utils import recompute_valid_actual_seq_len - -logger = logging.getLogger(__name__) - -try: - from einops import rearrange -except ImportError: - rearrange = None - -ACTUAL_SEQ_LEN_THRESHOLD = 2048 - -def dot_product_attention_init( - self, - config: TransformerConfig, - layer_number: int, - attn_mask_type: AttnMaskType, - attention_type: str, - attention_dropout: float = None, - softmax_scale: float = None, - cp_comm_type: str = None, -): - cp_size = config.context_parallel_size - config.context_parallel_size = 1 - - super(DotProductAttention, self).__init__(config=config) - assert ( - self.config.context_parallel_size == 1 - ), "Context parallelism is only supported by TEDotProductAttention!" - - assert ( - self.config.window_size is None - ), "Sliding Window Attention is only supported by TEDotProductAttention!" - - self.layer_number = max(1, layer_number) - self.attn_mask_type = attn_mask_type - self.attention_type = attention_type # unused for now - - projection_size = self.config.kv_channels * self.config.num_attention_heads - args = get_args() - # Per attention head and per partition values. 
- world_size = args.tp_x if args.tp_2d else parallel_state.get_tensor_model_parallel_world_size() - self.hidden_size_per_partition = divide(projection_size, world_size) - self.hidden_size_per_attention_head = divide(projection_size, config.num_attention_heads) - self.num_attention_heads_per_partition = divide(self.config.num_attention_heads, world_size) - self.num_query_groups_per_partition = divide(self.config.num_query_groups, world_size) - - coeff = None - self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) - if self.config.apply_query_key_layer_scaling: - coeff = self.layer_number - self.norm_factor *= coeff - - self.scale_mask_softmax = FusedScaleMaskSoftmax( - input_in_fp16=self.config.fp16, - input_in_bf16=self.config.bf16, - attn_mask_type=self.attn_mask_type, - scaled_masked_softmax_fusion=self.config.masked_softmax_fusion, - mask_func=attention_mask_func, - softmax_in_fp32=self.config.attention_softmax_in_fp32, - scale=coeff, - ) - - # Dropout. Note that for a single iteration, this layer will generate - # different outputs on different number of parallel partitions but - # on average it should not be partition dependent. - self.attention_dropout = torch.nn.Dropout( - self.config.attention_dropout if attention_dropout is None else attention_dropout - ) - - config.context_parallel_size = cp_size - self.pse = None - self.pse_type = None - self.attn_logit_softcapping = args.attn_logit_softcapping - self.square_alibi_mask = args.square_alibi_mask - self.fill_neg_inf = args.fill_neg_inf - self.beta = 1.0 - self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling - if self.apply_query_key_layer_scaling: - self.beta = 1.0 / self.layer_number - - if args.position_embedding_type == 'alibi': - get_alibi(self, args.seq_length) - self.alibi_output_size = None - else: - self.alibi = None - - if args.query_pre_attn_scalar: - self.norm_factor = args.query_pre_attn_scalar ** 0.5 - self.scale_mask_softmax.scale = 1.0 - self.softmax_scale = 1.0 / self.norm_factor - - -def get_alibi(self, seq_length): - args = get_args() - self.alibi = Alibi() - alibi = self.alibi._build_alibi_tensor(seq_length, - args.num_attention_heads, - args.square_alibi_mask, - args.fill_neg_inf, - ).to(torch.cuda.current_device()) - if args.params_dtype == torch.float16: - alibi = alibi.to(torch.float16) - elif args.params_dtype == torch.bfloat16: - alibi = alibi.to(torch.bfloat16) - self.alibi.alibi = alibi - -def dot_product_attention_forward_wrapper(fn): - @wraps(fn) - def wrapper(self, query, key, value, attention_mask, attn_mask_type, attention_bias=None, packed_seq_params=None): - if attention_mask is None: - attention_mask = get_attention_mask() - - args = get_args() - if args.use_flash_attn and args.tp_2d: - from mindspeed.core.transformer.dot_product_attention import dot_product_attention_forward - return dot_product_attention_forward(self, query, key, value, attention_mask, attn_mask_type, attention_bias, packed_seq_params) - if self.config.context_parallel_size > 1 and args.context_parallel_algo == "ulysses_cp_algo" and args.context_parallel_kv_cache_policy: - return do_ulyssesattn_context_parallel(self, query, key, value, attention_mask, attn_mask_type, packed_seq_params) - # =================================== - # Raw attention scores. [b, n/p, s, s] - # =================================== - - # expand the key and value [sk, b, ng, hn] -> [sk, b, np, hn] - # This is a noop for normal attention where ng == np. 
When using group query attention this - # creates a view that has the keys and values virtually repeated along their dimension to - # match the number of queries. - - heads_per_gqa_group = self.num_attention_heads_per_partition // self.num_query_groups_per_partition - if not args.use_flash_attn: - if heads_per_gqa_group > 1: - key = key.repeat_interleave(heads_per_gqa_group, dim=2) - value = value.repeat_interleave(heads_per_gqa_group, dim=2) - else: - # Do repeat KV to support PFA - should_kv_repeat_before_pfa = hasattr(args, 'use_kv_cache') and args.use_kv_cache - if heads_per_gqa_group > 1 and should_kv_repeat_before_pfa: - key = key.repeat_interleave(heads_per_gqa_group, dim=2) - value = value.repeat_interleave(heads_per_gqa_group, dim=2) - - return flash_attention_forward(self, query, key, value, attention_mask, attn_mask_type, - packed_seq_params) - - output_size = ( - query.size(1), - query.size(2), - query.size(0), - key.size(0), - ) - - # [sq, b, np, hn] -> [sq, b * np, hn] - # This will be a simple view when doing normal attention, but in group query attention - # the key and value tensors are repeated to match the queries so you can't use simple strides - # to extract the queries. - query = query.reshape(output_size[2], output_size[0] * output_size[1], -1) - # [sk, b, np, hn] -> [sk, b * np, hn] - key = key.view(output_size[3], output_size[0] * output_size[1], -1) - - # preallocting input tensor: [b * np, sq, sk] - if self.alibi is None: - matmul_input_buffer = parallel_state.get_global_memory_buffer().get_tensor( - (output_size[0] * output_size[1], output_size[2], output_size[3]), query.dtype, "mpu", - ) - - # Raw attention scores. [b * np, sq, sk] - matmul_result = torch.baddbmm( - matmul_input_buffer, - query.transpose(0, 1), # [b * np, sq, hn] - key.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] - beta=0.0, - alpha=(1.0 / self.norm_factor), - ) - else: - if self.alibi.alibi_pse is None or self.alibi.output_size != output_size: - self.alibi.output_size = output_size - self.alibi.get_alibi_pse(attention_mask, output_size[0], output_size[2], output_size[3]) - - q_trans = query.transpose(0, 1).contiguous() - k_trans = key.transpose(0, 1).transpose(1, 2).contiguous() - matmul_result = self.beta * self.alibi.alibi_pse + torch.bmm(q_trans, k_trans) * (1.0 / self.norm_factor) - - if self.attn_logit_softcapping is not None: - matmul_result = matmul_result / self.attn_logit_softcapping - matmul_result = torch.tanh(matmul_result) - matmul_result = matmul_result * self.attn_logit_softcapping - - # change view to [b, np, sq, sk] - attention_scores = matmul_result.view(*output_size) - - # =========================== - # Attention probs and dropout - # =========================== - - # attention scores and attention mask [b, np, sq, sk] - if self.square_alibi_mask: - attention_scores = torch.max( - attention_scores, torch.tensor(torch.finfo(attention_scores.dtype).min) - ) - attention_probs = torch.nn.functional.softmax(attention_scores, -1) - else: - attention_probs: Tensor = self.scale_mask_softmax(attention_scores, attention_mask) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. - - if not self.config.sequence_parallel: - with tensor_parallel.get_cuda_rng_tracker().fork(): - attention_probs = self.attention_dropout(attention_probs) - else: - attention_probs = self.attention_dropout(attention_probs) - - # ========================= - # Context layer. 
[sq, b, hp] - # ========================= - - # value -> context layer. - # [sk, b, np, hn] --> [b, np, sq, hn] - - # context layer shape: [b, np, sq, hn] - output_size = ( - value.size(1), - value.size(2), - query.size(0), - value.size(3), - ) - - # change view [sk, b * np, hn] - value = value.view(value.size(0), output_size[0] * output_size[1], -1) - - # change view [b * np, sq, sk] - attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1) - - # matmul: [b * np, sq, hn] - context = torch.bmm(attention_probs, value.transpose(0, 1)) - - # change view [b, np, sq, hn] - context = context.view(*output_size) - - # [b, np, sq, hn] --> [sq, b, np, hn] - context = context.permute(2, 0, 1, 3).contiguous() - - # [sq, b, np, hn] --> [sq, b, hp] - new_context_shape = context.size()[:-2] + (self.hidden_size_per_partition,) - context = context.view(*new_context_shape) - - return context - - return wrapper - - -def flash_attention_forward( - self, - query: Union[Tensor, List[Tensor]], - key: Union[Tensor, List[Tensor]], - value: Tensor, - attention_mask, - attn_mask_type, - packed_seq_params, -): - query_rope, key_rope = None, None - if isinstance(query, List): - query, query_rope = query[0], query[1] - if isinstance(key, List): - key, key_rope = key[0], key[1] - - args = get_args() - - seq_length, batch_size, n_head, head_dim = query.shape[0], query.shape[1], query.shape[2], query.shape[3] - scale = 1.0 / math.sqrt(self.hidden_size_per_attention_head) \ - if self.scale_mask_softmax.scale is None else self.softmax_scale - actual_seq_len = get_actual_seq_len() - if actual_seq_len is not None and args.mtp_num_layers: - actual_seq_len = actual_seq_len[self.mtp_idx] - - if args.context_parallel_size > 1 and args.context_parallel_algo in ['megatron_cp_algo', 'hybrid_cp_algo', - 'adaptive_cp_algo', 'hybrid_adaptive_cp_algo']: - query, key, value = [rearrange(x, 's b h d -> s b (h d)') for x in [query, key, value]] - return do_ring_context_parallel( - query, key, value, head_num=n_head, softmax_scale=scale, attn_mask=attention_mask, pse=self.pse, - pse_type=self.pse_type, packed_seq_params=packed_seq_params) - - if args.shape_order == "TND": # varlen FA - if args.mla_fa_divide_qk: - query, key, value = [rearrange(x, 's b h d -> (b s) h d') for x in [query, key, value]] - if query_rope is not None and key_rope is not None: - query_rope, key_rope = [rearrange(x, 's b h d -> (b s) h d') for x in [query_rope, key_rope]] - else: - query, key, value = [rearrange(x, 's b h d -> (s b) h d') for x in [query, key, value]] - args.sparse_mode = 4 - elif args.shape_order == "BNSD": - query, key, value = [rearrange(x, 's b h d -> b h s d') for x in [query, key, value]] - else: - query, key, value = [rearrange(x, 's b h d -> s b (h d)') for x in [query, key, value]] - args.shape_order = "SBH" - - if self.hidden_size_per_attention_head == 0: - raise AssertionError("self.hidden_size_per_attention_head should not be ZERO.") - if not hasattr(self, 'attention_mask') or \ - self.attention_mask is None or \ - self.attention_mask.shape[0] != seq_length: - if self.alibi is not None: - self.attention_mask = torch.triu( - torch.ones(seq_length, seq_length), - 1).bool().npu() - else: - self.attention_mask = attention_mask - - use_sliding_windows = args.sliding_window is not None and seq_length > args.sliding_window - - if use_sliding_windows: - args.pre_tockens = args.sliding_window - args.sparse_mode = 4 - - pse = None - size_record = key.shape - if self.alibi is not None and (self.alibi.output_size 
!= size_record) and pse is None: - if args.shape_order != 'SBH': - raise ValueError( - 'FlashAttention with Alibi requires for SBH shape_order, but is {}.'.format(args.shape_order)) - - self.alibi.output_size = size_record - self.alibi.get_alibi_pse(self.attention_mask, batch_size, query.shape[0], key.shape[0]) - - if self.alibi and pse is None: - pse = self.alibi.alibi_pse.reshape( - batch_size, n_head, self.alibi.alibi_pse.size(1), -1) - if hasattr(args, 'use_kv_cache') and args.use_kv_cache: - pse = pse * self.beta - else: - pse = pse * self.beta * self.norm_factor - args.pre_tockens = seq_length - args.sparse_mode = 0 - - if hasattr(args, 'use_kv_cache') and args.use_kv_cache: - query, key, value = [rearrange(x, 's b h -> b s h') for x in [query, key, value]] - if query.shape[1] == 1 and query.shape[1] != key.shape[1]: - output = torch_npu.npu_incre_flash_attention( - query, key, value, - num_heads=n_head, - input_layout="BSH", - pse_shift=pse, - padding_mask=None, - scale_value=scale - ) - else: - output = torch_npu.npu_prompt_flash_attention( - query, key, value, - num_heads=n_head, - input_layout="BSH", - pse_shift=pse, - sparse_mode=args.sparse_mode, - padding_mask=None, - atten_mask=self.attention_mask, - scale_value=scale, - pre_tokens=args.pre_tockens, - next_tokens=args.next_tockens - ) - output = output.transpose(0, 1) - else: - if not args.mla_fa_divide_qk: - if actual_seq_len is not None and len(actual_seq_len) > ACTUAL_SEQ_LEN_THRESHOLD: - logger.warning("flash-attention get a long actual_seq_len, maybe create a coredump!") - actual_seq_len = recompute_valid_actual_seq_len(get_position_ids(), actual_seq_len) - - output = torch_npu.npu_fusion_attention( - query, key, value, n_head, args.shape_order, - pse=pse, - padding_mask=None, - atten_mask=self.attention_mask, - actual_seq_qlen=actual_seq_len, - actual_seq_kvlen=actual_seq_len, - scale=scale, - pre_tockens=args.pre_tockens, - next_tockens=args.next_tockens, - keep_prob=1 - self.attention_dropout.p, - inner_precise=0, - sparse_mode=args.sparse_mode - )[0] - else: - output = torch_npu.npu_fusion_attention_v2( - query, key, value, n_head, args.shape_order, - pse=pse, - padding_mask=None, - atten_mask=self.attention_mask, - query_rope=query_rope, - key_rope=key_rope, - actual_seq_qlen=actual_seq_len, - actual_seq_kvlen=actual_seq_len, - scale=scale, - pre_tokens=args.pre_tockens, - next_tokens=args.next_tockens, - keep_prob=1 - self.attention_dropout.p, - inner_precise=0, - sparse_mode=args.sparse_mode - )[0] - - if args.shape_order == "TND": # varlen FA - output = rearrange(output, '(s b) h d -> s b (h d)', s=seq_length) - elif args.shape_order == "BNSD": - output = rearrange(output, 'b h s d -> s b (h d)') - - return output diff --git a/mindspeed_llm/features_manager/transformer/flash_attention/alibi_feature.py b/mindspeed_llm/features_manager/transformer/flash_attention/alibi_feature.py index c62570b93..800cf9971 100644 --- a/mindspeed_llm/features_manager/transformer/flash_attention/alibi_feature.py +++ b/mindspeed_llm/features_manager/transformer/flash_attention/alibi_feature.py @@ -1,7 +1,6 @@ # Copyright (c) 2025, Huawei Technologies Co., Ltd. All rights reserved. # Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-from logging import getLogger from argparse import ArgumentParser from mindspeed.features_manager.feature import MindSpeedFeature @@ -14,8 +13,6 @@ class AlibiFeature(MindSpeedFeature): Usage: "--position-embedding-type alibi" - "--alibi-fusion-attn-type 0 or 2" - "[--alibi-diagonal-opposite]" """ def __init__(self): @@ -53,14 +50,4 @@ class AlibiFeature(MindSpeedFeature): action='store_true', default=False, help='fill alibi with negative inf' - ) - - - def register_patches(self, patch_manager, args): - if int(getattr(args, 'context_parallel_size', 1)) == 1: - # from mindspeed.core.transformer.flash_attention.alibi.adaptor import MindSpeedDotProductAttention - from mindspeed_llm.core.transformer.custom_dot_product_attention import DotProductAttention - patch_manager.register_patch( - 'megatron.core.transformer.dot_product_attention.DotProductAttention', - DotProductAttention ) \ No newline at end of file diff --git a/mindspeed_llm/features_manager/transformer/flash_attention/fusion_attention_feature.py b/mindspeed_llm/features_manager/transformer/flash_attention/fusion_attention_feature.py index ce33e3b9d..4f28449eb 100644 --- a/mindspeed_llm/features_manager/transformer/flash_attention/fusion_attention_feature.py +++ b/mindspeed_llm/features_manager/transformer/flash_attention/fusion_attention_feature.py @@ -22,21 +22,13 @@ class FusionAttentionFeature(MindSpeedFusionAttentionFeature): def register_patches(self, pm, args): from mindspeed.core.transformer.attention import attention_init - from mindspeed_llm.core.transformer.dot_product_attention import dot_product_attention_init, \ - dot_product_attention_forward_wrapper, ulysses_context_parallel_forward_wrapper - from mindspeed_llm.core.models.gpt.gpt_model import GPTModel - + from mindspeed_llm.core.transformer.custom_dot_product_attention import CustomDotProductAttention + # Attention - pm.register_patch('megatron.core.transformer.attention.Attention.__init__', - attention_init) - pm.register_patch('megatron.core.transformer.dot_product_attention.DotProductAttention.__init__', - dot_product_attention_init) - pm.register_patch('megatron.core.transformer.dot_product_attention.DotProductAttention.forward', - dot_product_attention_forward_wrapper) - pm.register_patch('megatron.core.transformer.custom_layers.transformer_engine.TEDotProductAttention.__init__', - dot_product_attention_init) - pm.register_patch('megatron.core.transformer.custom_layers.transformer_engine.TEDotProductAttention.forward', - dot_product_attention_forward_wrapper) - # For GQA in ulysses and hybrid - pm.register_patch('mindspeed.core.context_parallel.ulysses_context_parallel.ulysses_context_parallel.UlyssesContextAttention.forward', - ulysses_context_parallel_forward_wrapper) \ No newline at end of file + if int(getattr(args, 'context_parallel_size', 1)) < 2 and bool(getattr(args, 'use_flash_attn', False)): + pm.register_patch('megatron.core.transformer.attention.Attention.__init__', + attention_init) + pm.register_patch('megatron.core.transformer.dot_product_attention.DotProductAttention', + CustomDotProductAttention) + pm.register_patch('megatron.core.transformer.custom_layers.transformer_engine.TEDotProductAttention', + CustomDotProductAttention) \ No newline at end of file -- Gitee From cbbf7f31c1920ca214ecf8915c3d16622e546fd7 Mon Sep 17 00:00:00 2001 From: HanhuiChen Date: Tue, 12 Aug 2025 18:35:52 +0800 Subject: [PATCH 4/9] refactor dot product attention for non-cp --- mindspeed_llm/training/arguments.py | 2 -- 1 file changed, 2 deletions(-) diff --git 
a/mindspeed_llm/training/arguments.py b/mindspeed_llm/training/arguments.py index 19a055555..4999df89b 100644 --- a/mindspeed_llm/training/arguments.py +++ b/mindspeed_llm/training/arguments.py @@ -1303,8 +1303,6 @@ def _add_dummy_args_v2(args): args.tp_2d = False args.tp_x = 1 args.attn_logit_softcapping = False - args.square_alibi_mask = False - args.fill_neg_inf = False args.query_pre_attn_scalar = 0.0 args.add_output_layer_bias = False args.is_pairwise_dataset = False -- Gitee From a64507338d1ffb2d19a3356d141fad035cb4b203 Mon Sep 17 00:00:00 2001 From: HanhuiChen Date: Wed, 13 Aug 2025 11:48:46 +0800 Subject: [PATCH 5/9] refactor non-cp dot product atten --- .../custom_dot_product_attention.py | 183 ++++++++++++++---- .../flash_attention/alibi_feature.py | 41 ++-- 2 files changed, 161 insertions(+), 63 deletions(-) diff --git a/mindspeed_llm/core/transformer/custom_dot_product_attention.py b/mindspeed_llm/core/transformer/custom_dot_product_attention.py index 7aca7a36a..d8019b2f8 100644 --- a/mindspeed_llm/core/transformer/custom_dot_product_attention.py +++ b/mindspeed_llm/core/transformer/custom_dot_product_attention.py @@ -7,9 +7,9 @@ from typing import List import torch import torch_npu -from megatron.core.fusions.fused_softmax import FusedScaleMaskSoftmax from megatron.training import get_args from megatron.core import mpu, parallel_state +from megatron.core.fusions.fused_softmax import FusedScaleMaskSoftmax from megatron.core.transformer.dot_product_attention import DotProductAttention from megatron.core.transformer.utils import attention_mask_func from megatron.core.utils import divide @@ -31,7 +31,8 @@ ACTUAL_SEQ_LEN_THRESHOLD = 2048 class CustomDotProductAttentionImpl: """ - Implementation of dot product attention with non-cp support. + Implementation of dot product attention with non-CP (no context-parallel) support. + This module assumes FlashAttention kernels are available and enforces the constraint. """ def __init__(self, @@ -41,32 +42,59 @@ class CustomDotProductAttentionImpl: attention_type, attention_dropout: float = None, softmax_scale: float = None, - cp_comm_type: str = None - ): + cp_comm_type: str = None): + """ + Args: + config: TransformerConfig-like object containing model hyperparameters. + layer_number (int): 1-based index of the transformer layer (used for scaling). + attn_mask_type: Type of attention mask (causal/bidirectional). Currently unused here. + attention_type: Attention impl selector (e.g., self/cross); passed through for compatibility. + attention_dropout (float, optional): Attention dropout probability. If None, read from config. + softmax_scale (float, optional): External softmax scaling factor; if None, computed internally. + cp_comm_type (str, optional): Context-parallel comm type (unused because CP is disabled). 
+ """ + # --------------------------------------------------------------------- + # Preconditions: Only non-CP and FlashAttention are supported + # --------------------------------------------------------------------- super().__init__(config, layer_number, attn_mask_type, attention_type, attention_dropout, softmax_scale, cp_comm_type) args = get_args() - assert getattr(config, 'context_parallel_size', 1) == 1, "CustomDotProductAttention only supported by non-cp (context_parallel_size == 1)" - assert bool(getattr(args, 'use_flash_attn', False)) == True, "CustomDotProductAttention only supported by FlashAttention (args.use_flash_attn == True)" + assert getattr(config, 'context_parallel_size', 1) == 1, \ + "CustomDotProductAttention only supported by non-CP (context_parallel_size == 1)" + assert bool(getattr(args, 'use_flash_attn', False)) is True, \ + "CustomDotProductAttention only supported by FlashAttention (args.use_flash_attn == True)" + # --------------------------------------------------------------------- + # Basic attributes and tensor-parallel partition shapes + # --------------------------------------------------------------------- self.config = config self.layer_number = max(1, layer_number) - self.attn_mask_type = attn_mask_type # unused for now + self.attn_mask_type = attn_mask_type # unused for now self.attention_type = attention_type + # Projection size = H * Dh (heads × head_dim) projection_size = self.config.kv_channels * self.config.num_attention_heads + + # Determine model-parallel world size (2D-TP or standard TP) world_size = args.tp_x if args.tp_2d else parallel_state.get_tensor_model_parallel_world_size() + + # Partitioned hidden and heads (per TP shard) self.hidden_size_per_partition = divide(projection_size, world_size) self.hidden_size_per_attention_head = divide(projection_size, config.num_attention_heads) self.num_attention_heads_per_partition = divide(self.config.num_attention_heads, world_size) self.num_query_groups_per_partition = divide(self.config.num_query_groups, world_size) + # --------------------------------------------------------------------- + # Scaling strategy (Megatron-style query-key layer scaling) + # - norm_factor = sqrt(Dh) * layer_number (if apply_query_key_layer_scaling) + # --------------------------------------------------------------------- coeff = None self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) if self.config.apply_query_key_layer_scaling: coeff = self.layer_number self.norm_factor *= coeff + # Fused scale+mask+softmax for pre-FA paths (kept for parity / mask handling) self.scale_mask_softmax = FusedScaleMaskSoftmax( input_in_fp16=self.config.fp16, input_in_bf16=self.config.bf16, @@ -77,44 +105,65 @@ class CustomDotProductAttentionImpl: scale=coeff, ) - # Dropout. Note that for a single iteration, this layer will generate - # different outputs on different number of parallel partitions but - # on average it should not be partition dependent. 
+ # --------------------------------------------------------------------- + # Dropout layer (kept to pass keep_prob to FA kernels) + # - Single-iteration outputs may differ across partitions, but expectation matches + # --------------------------------------------------------------------- self.attention_dropout = torch.nn.Dropout( self.config.attention_dropout if attention_dropout is None else attention_dropout ) + # --------------------------------------------------------------------- + # Positional bias / soft-capping / ALiBi options + # --------------------------------------------------------------------- self.pse = None self.pse_type = None self.attn_logit_softcapping = args.attn_logit_softcapping self.square_alibi_mask = args.square_alibi_mask self.fill_neg_inf = args.fill_neg_inf + + # Beta is used to down-scale PSE when KV-cache is active (per-layer scaling) self.beta = 1.0 self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling if self.apply_query_key_layer_scaling: self.beta = 1.0 / self.layer_number + # --------------------------------------------------------------------- + # ALiBi positional bias precomputation (if enabled via args.position_embedding_type == 'alibi') + # - Prebuild and cast once per dtype and device + # --------------------------------------------------------------------- if args.position_embedding_type == 'alibi': self.alibi = Alibi() - alibi = self.alibi._build_alibi_tensor(args.seq_length, - args.num_attention_heads, - args.square_alibi_mask, - args.fill_neg_inf, - ).to(torch.cuda.current_device()) + alibi = self.alibi._build_alibi_tensor( + args.seq_length, + args.num_attention_heads, + args.square_alibi_mask, + args.fill_neg_inf, + ).to(torch.cuda.current_device()) + if args.params_dtype == torch.float16: alibi = alibi.to(torch.float16) elif args.params_dtype == torch.bfloat16: alibi = alibi.to(torch.bfloat16) + self.alibi.alibi = alibi self.alibi_output_size = None else: self.alibi = None + # --------------------------------------------------------------------- + # Optional: query pre-attention scaling override + # - When enabled, override scale used by softmax to 1/sqrt(query_pre_attn_scalar) + # --------------------------------------------------------------------- if args.query_pre_attn_scalar: self.norm_factor = args.query_pre_attn_scalar ** 0.5 self.scale_mask_softmax.scale = 1.0 self.softmax_scale = 1.0 / self.norm_factor + # Final scale used by FA kernels (fallback to 1/sqrt(Dh) if not overridden) + self.scale = 1.0 / math.sqrt(self.hidden_size_per_attention_head) \ + if self.scale_mask_softmax.scale is None else self.softmax_scale + def forward( self, query, @@ -125,9 +174,28 @@ class CustomDotProductAttentionImpl: attention_bias=None, packed_seq_params=None, ): + """ + Args: + query: Tensor of shape [S, B, H, Dh] (default SBHD) before layout transforms. + key: Tensor with same logical layout as query. + value: Tensor with same logical layout as query. + attention_mask: Precomputed mask (e.g., causal) or None to fetch global mask. + attn_mask_type: Optional mask type override (unused here; parity with base API). + attention_bias: Optional additive attention bias (unused here; PSE used for ALiBi). + packed_seq_params: Optional varlen pack info for FA (handled via shape_order logic). + + Returns: + output: Tensor of shape [S, B, H * Dh] (SBH merged heads×dim at the end). 
+ """ + # --------------------------------------------------------------------- + # 0) Guard: ensure we have a valid attention mask + # --------------------------------------------------------------------- if attention_mask is None: attention_mask = get_attention_mask() + # --------------------------------------------------------------------- + # 1) Unpack optional rope-carrying lists (query/key may be [tensor, rope]) + # --------------------------------------------------------------------- query_rope, key_rope = None, None if isinstance(query, List): query, query_rope = query[0], query[1] @@ -135,21 +203,36 @@ class CustomDotProductAttentionImpl: key, key_rope = key[0], key[1] args = get_args() + + # --------------------------------------------------------------------- + # 2) GQA group expansion when using KV cache + # - If heads_per_group > 1 and KV cache is enabled, repeat KV across heads in group + # --------------------------------------------------------------------- heads_per_gqa_group = self.num_attention_heads_per_partition // self.num_query_groups_per_partition should_kv_repeat_before_pfa = hasattr(args, 'use_kv_cache') and args.use_kv_cache if heads_per_gqa_group > 1 and should_kv_repeat_before_pfa: key = key.repeat_interleave(heads_per_gqa_group, dim=2) value = value.repeat_interleave(heads_per_gqa_group, dim=2) + # Shapes prior to any layout rearrangement seq_length, batch_size, n_head, head_dim = query.shape[0], query.shape[1], query.shape[2], query.shape[3] - scale = (1.0 / math.sqrt(self.hidden_size_per_attention_head)) if self.scale_mask_softmax.scale is None \ - else self.softmax_scale + # --------------------------------------------------------------------- + # 3) Variable-length (packed) sequence handling + # - actual_seq_len may be per-token; trim / recompute if too long (core-dump risk) + # --------------------------------------------------------------------- actual_seq_len = get_actual_seq_len() if actual_seq_len is not None and args.mtp_num_layers: actual_seq_len = actual_seq_len[self.mtp_idx] - if args.shape_order == "TND": # varlen FA + # --------------------------------------------------------------------- + # 4) Layout transforms for FA kernels + # shape_order: + # - "TND": treat (T,N,D) with heads factored outside; kernel expects packed batch-major + # - "BNSD": [B, H, S, Dh] + # - default -> "SBH": [S, B, H*Dh] (Megatron classic) + # --------------------------------------------------------------------- + if args.shape_order == "TND": # varlen FA path if args.mla_fa_divide_qk: query, key, value = [rearrange(x, 's b h d -> (b s) h d') for x in [query, key, value]] if query_rope is not None and key_rope is not None: @@ -163,56 +246,77 @@ class CustomDotProductAttentionImpl: query, key, value = [rearrange(x, 's b h d -> s b (h d)') for x in [query, key, value]] args.shape_order = "SBH" + # --------------------------------------------------------------------- + # 5) Prepare / cache the attention mask (and causal mask for ALiBi) + # --------------------------------------------------------------------- if self.hidden_size_per_attention_head == 0: raise AssertionError("self.hidden_size_per_attention_head should not be ZERO.") - if not hasattr(self, 'attention_mask') or \ - self.attention_mask is None or \ - self.attention_mask.shape[0] != seq_length: + + if (not hasattr(self, 'attention_mask')) or (self.attention_mask is None) or (self.attention_mask.shape[0] != seq_length): if self.alibi is not None: - self.attention_mask = torch.triu( - torch.ones(seq_length, 
seq_length), - 1).bool().npu() + # Strict causal upper-triangular mask for ALiBi + self.attention_mask = torch.triu(torch.ones(seq_length, seq_length), 1).bool().npu() else: + # Use provided (or global) attention mask as-is self.attention_mask = attention_mask + # --------------------------------------------------------------------- + # 6) Sliding-window attention (Long context sparsity) + # - When window is smaller than sequence, switch to sparse mode + # --------------------------------------------------------------------- use_sliding_windows = args.sliding_window is not None and seq_length > args.sliding_window - if use_sliding_windows: args.pre_tockens = args.sliding_window args.sparse_mode = 4 + # --------------------------------------------------------------------- + # 7) Build/reshape ALiBi PSE if needed (enforce SBH layout for FA+ALiBi) + # - PSE is scaled by beta and optionally by norm_factor (no KV cache) + # --------------------------------------------------------------------- pse = None size_record = key.shape if self.alibi is not None and (self.alibi.output_size != size_record) and pse is None: if args.shape_order != 'SBH': - raise ValueError( - 'FlashAttention with Alibi requires for SBH shape_order, but is {}.'.format(args.shape_order)) - + raise ValueError(f'FlashAttention with ALiBi requires SBH shape_order, but got {args.shape_order}.') self.alibi.output_size = size_record self.alibi.get_alibi_pse(self.attention_mask, batch_size, query.shape[0], key.shape[0]) if self.alibi and pse is None: - pse = self.alibi.alibi_pse.reshape( - batch_size, n_head, self.alibi.alibi_pse.size(1), -1) + pse = self.alibi.alibi_pse.reshape(batch_size, n_head, self.alibi.alibi_pse.size(1), -1) if hasattr(args, 'use_kv_cache') and args.use_kv_cache: pse = pse * self.beta else: pse = pse * self.beta * self.norm_factor + # With dense ALiBi PSE we disable sparsity args.pre_tockens = seq_length args.sparse_mode = 0 + # --------------------------------------------------------------------- + # 8) Execute FlashAttention kernels on Ascend NPU (torch_npu) + # Two paths: + # a) KV cache enabled, only supports infernce mode: + # - npu_incre_flash_attention for single-token decode (BSH, step by step) + # - npu_prompt_flash_attention for prompt / extended decode + # b) No KV cache: + # - npu_fusion_attention (standard FA) + # - npu_fusion_attention_v2 (FA supports mla with seperate q and k) + # --------------------------------------------------------------------- if hasattr(args, 'use_kv_cache') and args.use_kv_cache: + # Kernels below expect [B, S, H*Dh] for BSH layout query, key, value = [rearrange(x, 's b h -> b s h') for x in [query, key, value]] + if query.shape[1] == 1 and query.shape[1] != key.shape[1]: + # Incremental decode kernel: append a single step using cached K/V output = torch_npu.npu_incre_flash_attention( query, key, value, num_heads=n_head, input_layout="BSH", pse_shift=pse, padding_mask=None, - scale_value=scale + scale_value=self.scale ) else: + # Prompt + decode kernel: extend using both prompt and cached segments output = torch_npu.npu_prompt_flash_attention( query, key, value, num_heads=n_head, @@ -221,15 +325,18 @@ class CustomDotProductAttentionImpl: sparse_mode=args.sparse_mode, padding_mask=None, atten_mask=self.attention_mask, - scale_value=scale, + scale_value=self.scale, pre_tokens=args.pre_tockens, next_tokens=args.next_tockens ) + # Restore to [S, B, H*Dh]-compatible first dimension for later reshapes output = output.transpose(0, 1) else: + # No KV cache: fused attention over 
full sequences if not args.mla_fa_divide_qk: + # Standard FA path if actual_seq_len is not None and len(actual_seq_len) > ACTUAL_SEQ_LEN_THRESHOLD: - logger.warning("flash-attention get a long actual_seq_len, maybe create a coredump!") + logger.warning("flash-attention got a long actual_seq_len; recomputing to avoid potential coredump.") actual_seq_len = recompute_valid_actual_seq_len(get_position_ids(), actual_seq_len) output = torch_npu.npu_fusion_attention( @@ -239,7 +346,7 @@ class CustomDotProductAttentionImpl: atten_mask=self.attention_mask, actual_seq_qlen=actual_seq_len, actual_seq_kvlen=actual_seq_len, - scale=scale, + scale=self.scale, pre_tockens=args.pre_tockens, next_tockens=args.next_tockens, keep_prob=1 - self.attention_dropout.p, @@ -247,6 +354,7 @@ class CustomDotProductAttentionImpl: sparse_mode=args.sparse_mode )[0] else: + # FA v2 with separate Q/K RoPE inputs output = torch_npu.npu_fusion_attention_v2( query, key, value, n_head, args.shape_order, pse=pse, @@ -256,7 +364,7 @@ class CustomDotProductAttentionImpl: key_rope=key_rope, actual_seq_qlen=actual_seq_len, actual_seq_kvlen=actual_seq_len, - scale=scale, + scale=self.scale, pre_tokens=args.pre_tockens, next_tokens=args.next_tockens, keep_prob=1 - self.attention_dropout.p, @@ -264,7 +372,10 @@ class CustomDotProductAttentionImpl: sparse_mode=args.sparse_mode )[0] - if args.shape_order == "TND": # varlen FA + # --------------------------------------------------------------------- + # 9) Restore to canonical [S, B, H*Dh] layout expected by upper layers + # --------------------------------------------------------------------- + if args.shape_order == "TND": # varlen FA output = rearrange(output, '(s b) h d -> s b (h d)', s=seq_length) elif args.shape_order == "BNSD": output = rearrange(output, 'b h s d -> s b (h d)') diff --git a/mindspeed_llm/features_manager/transformer/flash_attention/alibi_feature.py b/mindspeed_llm/features_manager/transformer/flash_attention/alibi_feature.py index 800cf9971..af74eee9f 100644 --- a/mindspeed_llm/features_manager/transformer/flash_attention/alibi_feature.py +++ b/mindspeed_llm/features_manager/transformer/flash_attention/alibi_feature.py @@ -16,38 +16,25 @@ class AlibiFeature(MindSpeedFeature): """ def __init__(self): - super().__init__( - 'position-embedding-type', - optimization_level=2 - ) + super().__init__('position-embedding-type', optimization_level=2) def is_need_apply(self, args): pse = getattr(args, self.feature_name, None) need_apply = False if pse == 'alibi': need_apply = True - return ( - self.optimization_level <= args.optimization_level and - need_apply - ) or self.default_patches + + return (self.optimization_level <= args.optimization_level and need_apply) or self.default_patches def register_args(self, parser: ArgumentParser): - self.add_parser_argument_choices_value( - parser, - "--position-embedding-type", - 'alibi' - ) - - group = parser.add_argument_group(title='alibi') - group.add_argument( - '--square-alibi-mask', - action='store_true', - default=False, - help='attention mask of alibi is squared' - ) - group.add_argument( - '--fill-neg-inf', - action='store_true', - default=False, - help='fill alibi with negative inf' - ) \ No newline at end of file + group = parser.add_argument_group(title=self.feature_name) + self.add_parser_argument_choices_value(parser, "--position-embedding-type", 'alibi') + group.add_argument('--square-alibi-mask', action='store_true', default=False, + help='attention mask of alibi is squared') + group.add_argument('--fill-neg-inf', 
action='store_true', default=False, + help='fill alibi with negative inf') + + def validate_args(self, args): + # alibi only support by FA + if getattr(args, "position_embedding_type", None) == "alibi" and not getattr(args, "use_flash_attn", False): + raise AssertionError("`--position-embedding-type alibi` requires `--use-flash-attn` to be enabled.") \ No newline at end of file -- Gitee From 716dc9d908fcd629f961ace7b003cf4bcaf8d64d Mon Sep 17 00:00:00 2001 From: HanhuiChen Date: Wed, 13 Aug 2025 11:50:40 +0800 Subject: [PATCH 6/9] refactor non-cp dot product atten --- .../core/transformer/custom_dot_product_attention.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/mindspeed_llm/core/transformer/custom_dot_product_attention.py b/mindspeed_llm/core/transformer/custom_dot_product_attention.py index d8019b2f8..9789fe468 100644 --- a/mindspeed_llm/core/transformer/custom_dot_product_attention.py +++ b/mindspeed_llm/core/transformer/custom_dot_product_attention.py @@ -384,6 +384,14 @@ class CustomDotProductAttentionImpl: class CustomDotProductAttention(CustomDotProductAttentionImpl, DotProductAttention): + """ + Dot product attention class combining: + - CustomDotProductAttentionImpl: Non-CP + FlashAttention optimized implementation + - DotProductAttention: Base attention interface for compatibility with Megatron-LM + + Inherits both to allow seamless replacement in the transformer stack + while keeping the optimized forward logic from CustomDotProductAttentionImpl. + """ def __init__(self, *args, **kwargs): CustomDotProductAttentionImpl.__init__(self, *args, **kwargs) \ No newline at end of file -- Gitee From e421efc3a3d2a4bf078ab2f56a9d612e7be925d8 Mon Sep 17 00:00:00 2001 From: HanhuiChen Date: Wed, 13 Aug 2025 16:39:33 +0800 Subject: [PATCH 7/9] refactor non-cp dot product atten --- .../custom_dot_product_attention.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/mindspeed_llm/core/transformer/custom_dot_product_attention.py b/mindspeed_llm/core/transformer/custom_dot_product_attention.py index 9789fe468..032925e8e 100644 --- a/mindspeed_llm/core/transformer/custom_dot_product_attention.py +++ b/mindspeed_llm/core/transformer/custom_dot_product_attention.py @@ -59,10 +59,12 @@ class CustomDotProductAttentionImpl: super().__init__(config, layer_number, attn_mask_type, attention_type, attention_dropout, softmax_scale, cp_comm_type) args = get_args() - assert getattr(config, 'context_parallel_size', 1) == 1, \ - "CustomDotProductAttention only supported by non-CP (context_parallel_size == 1)" - assert bool(getattr(args, 'use_flash_attn', False)) is True, \ - "CustomDotProductAttention only supported by FlashAttention (args.use_flash_attn == True)" + if getattr(config, 'context_parallel_size', 1) != 1: + raise AssertionError("CustomDotProductAttention only supported by non-CP (context_parallel_size == 1)") + + if not bool(getattr(args, 'use_flash_attn', False)): + raise AssertionError("CustomDotProductAttention only supported by FlashAttention (args.use_flash_attn == True)") + # --------------------------------------------------------------------- # Basic attributes and tensor-parallel partition shapes @@ -72,7 +74,6 @@ class CustomDotProductAttentionImpl: self.attn_mask_type = attn_mask_type # unused for now self.attention_type = attention_type - # Projection size = H * Dh (heads × head_dim) projection_size = self.config.kv_channels * self.config.num_attention_heads # Determine model-parallel world size (2D-TP or standard TP) @@ -86,7 +87,6 
@@ class CustomDotProductAttentionImpl: # --------------------------------------------------------------------- # Scaling strategy (Megatron-style query-key layer scaling) - # - norm_factor = sqrt(Dh) * layer_number (if apply_query_key_layer_scaling) # --------------------------------------------------------------------- coeff = None self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) @@ -160,7 +160,6 @@ class CustomDotProductAttentionImpl: self.scale_mask_softmax.scale = 1.0 self.softmax_scale = 1.0 / self.norm_factor - # Final scale used by FA kernels (fallback to 1/sqrt(Dh) if not overridden) self.scale = 1.0 / math.sqrt(self.hidden_size_per_attention_head) \ if self.scale_mask_softmax.scale is None else self.softmax_scale @@ -302,7 +301,6 @@ class CustomDotProductAttentionImpl: # - npu_fusion_attention_v2 (FA supports mla with seperate q and k) # --------------------------------------------------------------------- if hasattr(args, 'use_kv_cache') and args.use_kv_cache: - # Kernels below expect [B, S, H*Dh] for BSH layout query, key, value = [rearrange(x, 's b h -> b s h') for x in [query, key, value]] if query.shape[1] == 1 and query.shape[1] != key.shape[1]: @@ -329,7 +327,6 @@ class CustomDotProductAttentionImpl: pre_tokens=args.pre_tockens, next_tokens=args.next_tockens ) - # Restore to [S, B, H*Dh]-compatible first dimension for later reshapes output = output.transpose(0, 1) else: # No KV cache: fused attention over full sequences @@ -388,9 +385,6 @@ class CustomDotProductAttention(CustomDotProductAttentionImpl, DotProductAttenti Dot product attention class combining: - CustomDotProductAttentionImpl: Non-CP + FlashAttention optimized implementation - DotProductAttention: Base attention interface for compatibility with Megatron-LM - - Inherits both to allow seamless replacement in the transformer stack - while keeping the optimized forward logic from CustomDotProductAttentionImpl. """ def __init__(self, *args, **kwargs): -- Gitee From 843cd49c7002aa9556f2be6e38507e450bac1e6f Mon Sep 17 00:00:00 2001 From: HanhuiChen Date: Thu, 14 Aug 2025 01:16:47 +0000 Subject: [PATCH 8/9] update mindspeed_llm/training/arguments.py. 
Signed-off-by: HanhuiChen --- mindspeed_llm/training/arguments.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/mindspeed_llm/training/arguments.py b/mindspeed_llm/training/arguments.py index 5a3953698..ac61212f1 100644 --- a/mindspeed_llm/training/arguments.py +++ b/mindspeed_llm/training/arguments.py @@ -1286,8 +1286,6 @@ def _add_dummy_args_v2(args): args.tp_2d = False args.tp_x = 1 args.tp_y = 1 - args.attn_logit_softcapping = False - args.query_pre_attn_scalar = 0.0 args.add_output_layer_bias = False args.is_pairwise_dataset = False args.enable_share_memory = False -- Gitee From 956959ab9c7a581bf68b793d9c4cef215fc72618 Mon Sep 17 00:00:00 2001 From: HanhuiChen Date: Thu, 14 Aug 2025 16:07:12 +0800 Subject: [PATCH 9/9] refactor non-cp dot product attention --- .../transformer/flash_attention/alibi_feature.py | 8 -------- .../flash_attention/fusion_attention_feature.py | 2 +- 2 files changed, 1 insertion(+), 9 deletions(-) diff --git a/mindspeed_llm/features_manager/transformer/flash_attention/alibi_feature.py b/mindspeed_llm/features_manager/transformer/flash_attention/alibi_feature.py index af74eee9f..d417935b7 100644 --- a/mindspeed_llm/features_manager/transformer/flash_attention/alibi_feature.py +++ b/mindspeed_llm/features_manager/transformer/flash_attention/alibi_feature.py @@ -18,14 +18,6 @@ class AlibiFeature(MindSpeedFeature): def __init__(self): super().__init__('position-embedding-type', optimization_level=2) - def is_need_apply(self, args): - pse = getattr(args, self.feature_name, None) - need_apply = False - if pse == 'alibi': - need_apply = True - - return (self.optimization_level <= args.optimization_level and need_apply) or self.default_patches - def register_args(self, parser: ArgumentParser): group = parser.add_argument_group(title=self.feature_name) self.add_parser_argument_choices_value(parser, "--position-embedding-type", 'alibi') diff --git a/mindspeed_llm/features_manager/transformer/flash_attention/fusion_attention_feature.py b/mindspeed_llm/features_manager/transformer/flash_attention/fusion_attention_feature.py index 628cee62d..8626f0608 100644 --- a/mindspeed_llm/features_manager/transformer/flash_attention/fusion_attention_feature.py +++ b/mindspeed_llm/features_manager/transformer/flash_attention/fusion_attention_feature.py @@ -27,7 +27,7 @@ class FusionAttentionFeature(MindSpeedFusionAttentionFeature): from mindspeed_llm.core.transformer.custom_dot_product_attention import CustomDotProductAttention # Attention - if int(getattr(args, 'context_parallel_size', 1)) < 2 and bool(getattr(args, 'use_flash_attn', False)): + if int(getattr(args, 'context_parallel_size', 1)) < 2: pm.register_patch('megatron.core.transformer.attention.Attention.__init__', attention_init) pm.register_patch('megatron.core.transformer.dot_product_attention.DotProductAttention', -- Gitee
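
Taken together, the series ends up gating attention in two places: PATCH 9/9 registers CustomDotProductAttention for every non-CP run (context_parallel_size < 2), while the FlashAttention requirement moves into CustomDotProductAttentionImpl.__init__ (PATCH 7/9) and the ALiBi-requires-FA check into AlibiFeature.validate_args (PATCH 5/9). The short Python sketch below is a minimal illustration of that gating, assuming the argument names used throughout the series; the _Args container and the patches_custom_attention()/validate() helpers are hypothetical, introduced only for illustration, and are not part of the patches themselves.

    # Minimal sketch, assuming the argument names used in the series above.
    class _Args:  # hypothetical stand-in for Megatron's parsed args
        context_parallel_size = 1
        use_flash_attn = True
        position_embedding_type = 'alibi'

    def patches_custom_attention(args) -> bool:
        # PATCH 9/9: CustomDotProductAttention is registered for any non-CP run;
        # the FlashAttention requirement is enforced later, at construction time.
        return int(getattr(args, 'context_parallel_size', 1)) < 2

    def validate(args) -> None:
        # PATCH 7/9: CustomDotProductAttentionImpl.__init__ raises when FA is disabled.
        if not bool(getattr(args, 'use_flash_attn', False)):
            raise AssertionError("CustomDotProductAttention only supported by FlashAttention")
        # PATCH 5/9: ALiBi is only supported together with FlashAttention.
        if getattr(args, 'position_embedding_type', None) == 'alibi' and not args.use_flash_attn:
            raise AssertionError("`--position-embedding-type alibi` requires `--use-flash-attn`")

    args = _Args()
    if patches_custom_attention(args):
        validate(args)  # passes for this configuration
        print("CustomDotProductAttention would be patched in for this configuration")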