From c48e66e8bfa73b1f2f992ac705c27d5db2532626 Mon Sep 17 00:00:00 2001
From: jzh
Date: Thu, 25 Sep 2025 19:34:32 +0800
Subject: [PATCH 1/2] [pytorch][bugfix]

---
 .../baichuan2/evaluate_baichuan2_13b_ptd.sh    |  1 -
 .../baichuan2/generate_baichuan2_13b_ptd.sh    |  1 -
 .../context_parallel/dot_product_attention.py  |  2 +-
 .../custom_dot_product_attention.py            |  2 +-
 .../flash_attention/alibi_feature.py           | 19 +++++++++++++++++--
 5 files changed, 19 insertions(+), 6 deletions(-)

diff --git a/examples/mcore/baichuan2/evaluate_baichuan2_13b_ptd.sh b/examples/mcore/baichuan2/evaluate_baichuan2_13b_ptd.sh
index b1038672f..c09429251 100644
--- a/examples/mcore/baichuan2/evaluate_baichuan2_13b_ptd.sh
+++ b/examples/mcore/baichuan2/evaluate_baichuan2_13b_ptd.sh
@@ -32,7 +32,6 @@ torchrun $DISTRIBUTED_ARGS evaluation.py \
     --num-attention-heads 40 \
     --disable-bias-linear \
     --swiglu \
-    --use-flash-attn \
     --position-embedding-type alibi \
     --square-alibi-mask \
     --fill-neg-inf \
diff --git a/examples/mcore/baichuan2/generate_baichuan2_13b_ptd.sh b/examples/mcore/baichuan2/generate_baichuan2_13b_ptd.sh
index 5620208f6..b474774b1 100644
--- a/examples/mcore/baichuan2/generate_baichuan2_13b_ptd.sh
+++ b/examples/mcore/baichuan2/generate_baichuan2_13b_ptd.sh
@@ -30,7 +30,6 @@ torchrun $DISTRIBUTED_ARGS inference.py \
     --num-attention-heads 40 \
     --max-position-embeddings 2048 \
     --position-embedding-type alibi \
-    --use-flash-attn \
     --square-alibi-mask \
     --fill-neg-inf \
     --swiglu \
diff --git a/mindspeed_llm/core/context_parallel/dot_product_attention.py b/mindspeed_llm/core/context_parallel/dot_product_attention.py
index 5868c1e0b..ba29608a6 100644
--- a/mindspeed_llm/core/context_parallel/dot_product_attention.py
+++ b/mindspeed_llm/core/context_parallel/dot_product_attention.py
@@ -306,7 +306,7 @@ class CPDotProductAttentionImpl:
 
         if not args.mla_fa_divide_qk:
             if actual_seq_len is not None and len(actual_seq_len) > ACTUAL_SEQ_LEN_THRESHOLD:
-                actual_seq_len = recompute_valid_actual_seq_len(actual_seq_len)
+                actual_seq_len = recompute_valid_actual_seq_len(actual_seq_len, args.micro_batch_size)
                 if len(actual_seq_len) > ACTUAL_SEQ_LEN_THRESHOLD:
                     logger.warning(
                         f"FlashAttention received unexpectedly long 'actual_seq_len' (length={len(actual_seq_len)}, threshold={ACTUAL_SEQ_LEN_THRESHOLD}). "
diff --git a/mindspeed_llm/core/transformer/custom_dot_product_attention.py b/mindspeed_llm/core/transformer/custom_dot_product_attention.py
index 3e7d09899..71cbc8d95 100644
--- a/mindspeed_llm/core/transformer/custom_dot_product_attention.py
+++ b/mindspeed_llm/core/transformer/custom_dot_product_attention.py
@@ -333,7 +333,7 @@ class CustomDotProductAttentionImpl:
         if not args.mla_fa_divide_qk:
             # Standard FA path
             if actual_seq_len is not None and len(actual_seq_len) > ACTUAL_SEQ_LEN_THRESHOLD:
-                actual_seq_len = recompute_valid_actual_seq_len(actual_seq_len)
+                actual_seq_len = recompute_valid_actual_seq_len(actual_seq_len, args.micro_batch_size)
                 if len(actual_seq_len) > ACTUAL_SEQ_LEN_THRESHOLD:
                     logger.warning(
                         f"FlashAttention received unexpectedly long 'actual_seq_len' (length={len(actual_seq_len)}, threshold={ACTUAL_SEQ_LEN_THRESHOLD}). "
diff --git a/mindspeed_llm/features_manager/transformer/flash_attention/alibi_feature.py b/mindspeed_llm/features_manager/transformer/flash_attention/alibi_feature.py
index fc154683a..91d1a127f 100644
--- a/mindspeed_llm/features_manager/transformer/flash_attention/alibi_feature.py
+++ b/mindspeed_llm/features_manager/transformer/flash_attention/alibi_feature.py
@@ -29,6 +29,21 @@ class AlibiFeature(MindSpeedFeature):
                             help='alibi pse type, support for 0,2')
 
     def validate_args(self, args):
-        # alibi only support by FA
+        # alibi is supported with either context parallel or flash attention;
+        # without FA, the dot-product-attention patches registered below are required
+        return
+
+    def register_patches(self, patch_manager, args):
         if getattr(args, "position_embedding_type", None) == "alibi" and not getattr(args, "use_flash_attn", False):
-            raise AssertionError("`--position-embedding-type alibi` requires `--use-flash-attn` to be enabled.")
\ No newline at end of file
+            from mindspeed_llm.core.transformer.dot_product_attention import dot_product_attention_init, dot_product_attention_forward_wrapper
+
+            patch_manager.register_patch('megatron.core.transformer.dot_product_attention.DotProductAttention.__init__',
+                                         dot_product_attention_init)
+            patch_manager.register_patch('megatron.core.transformer.dot_product_attention.DotProductAttention.forward',
+                                         dot_product_attention_forward_wrapper)
+            patch_manager.register_patch(
+                'megatron.core.transformer.custom_layers.transformer_engine.TEDotProductAttention.__init__',
+                dot_product_attention_init)
+            patch_manager.register_patch(
+                'megatron.core.transformer.custom_layers.transformer_engine.TEDotProductAttention.forward',
+                dot_product_attention_forward_wrapper)
--
Gitee

From 3b9374ead14a7b23b625bbb7d224ecf67c120c93 Mon Sep 17 00:00:00 2001
From: jzh
Date: Thu, 25 Sep 2025 20:22:30 +0800
Subject: [PATCH 2/2] [pytorch][bugfix] model update

---
 mindspeed_llm/core/transformer/dot_product_attention.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/mindspeed_llm/core/transformer/dot_product_attention.py b/mindspeed_llm/core/transformer/dot_product_attention.py
index e319591c9..250d856a7 100644
--- a/mindspeed_llm/core/transformer/dot_product_attention.py
+++ b/mindspeed_llm/core/transformer/dot_product_attention.py
@@ -171,6 +171,11 @@ def dot_product_attention_init(
     self.num_query_groups_per_partition = divide(self.config.num_query_groups, world_size)
 
     coeff = None
+    if softmax_scale is None:
+        self.softmax_scale = 1.0 / math.sqrt(self.hidden_size_per_attention_head)
+    else:
+        self.softmax_scale = softmax_scale
+
     self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
     if self.config.apply_query_key_layer_scaling:
         coeff = self.layer_number
--
Gitee