From bb1d18bea28696b1e935b48f21b981f7f414c511 Mon Sep 17 00:00:00 2001 From: yanzhixiao Date: Thu, 7 Aug 2025 15:31:21 +0800 Subject: [PATCH] refactor context parallel --- .../core/context_parallel/adaptor.py | 37 ++ .../context_parallel/dot_product_attention.py | 335 ++++++++++++++++++ .../core/context_parallel/get_batch_utils.py | 68 ++++ mindspeed_llm/features_manager/__init__.py | 14 + .../context_parallel/__init__.py | 0 .../context_parallel_feature.py | 52 +++ .../mamba_context_parallel.py | 29 ++ .../ulysses_context_parallel.py | 28 ++ .../megatron_basic/model_basic.py | 4 - .../megatron_basic/training_basic.py | 2 - .../tokenizer/build_tokenizer.py | 5 + .../fusion_attention_feature.py | 30 +- 12 files changed, 583 insertions(+), 21 deletions(-) create mode 100644 mindspeed_llm/core/context_parallel/adaptor.py create mode 100644 mindspeed_llm/core/context_parallel/dot_product_attention.py create mode 100644 mindspeed_llm/core/context_parallel/get_batch_utils.py create mode 100644 mindspeed_llm/features_manager/context_parallel/__init__.py create mode 100644 mindspeed_llm/features_manager/context_parallel/context_parallel_feature.py create mode 100644 mindspeed_llm/features_manager/context_parallel/mamba_context_parallel.py create mode 100644 mindspeed_llm/features_manager/context_parallel/ulysses_context_parallel.py diff --git a/mindspeed_llm/core/context_parallel/adaptor.py b/mindspeed_llm/core/context_parallel/adaptor.py new file mode 100644 index 0000000000..3ffefb03b3 --- /dev/null +++ b/mindspeed_llm/core/context_parallel/adaptor.py @@ -0,0 +1,37 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. All rights reserved. +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +from functools import wraps +from mindspeed.core.context_parallel import mpu +from mindspeed.core.context_parallel.model_parallel_utils import get_context_parallel_group_for_hybrid_ulysses +from mindspeed.core.context_parallel.ulysses_context_parallel.ulysses_context_parallel import UlyssesContextAttention +from mindspeed.core.context_parallel import DotProductAttention as MegatronDotProductAttention + +from mindspeed_llm.core.context_parallel.dot_product_attention import CPDotProductAttentionImpl + + +class CPDotProductAttention(CPDotProductAttentionImpl, MegatronDotProductAttention): + + def __init__(self, *args, **kwargs): + CPDotProductAttentionImpl.__init__(self, *args, **kwargs) + + +def attention_init_wrapper(fn): + @wraps(fn) + def wrapper( + self, + config, + submodules, + layer_number, + attn_mask_type, + attention_type, + cp_comm_type: str = None,): + fn(self, config, submodules, layer_number, attn_mask_type, attention_type, cp_comm_type) + + if config.context_parallel_size > 1 and config.context_parallel_algo in ['ulysses_cp_algo', 'hybrid_cp_algo']: + if config.context_parallel_algo in ['hybrid_cp_algo']: + ulysses_group = get_context_parallel_group_for_hybrid_ulysses() + else: + ulysses_group = mpu.get_context_parallel_group() + self.core_attention = UlyssesContextAttention(self.core_attention, ulysses_group) + + return wrapper \ No newline at end of file diff --git a/mindspeed_llm/core/context_parallel/dot_product_attention.py b/mindspeed_llm/core/context_parallel/dot_product_attention.py new file mode 100644 index 0000000000..3bbb234baf --- /dev/null +++ b/mindspeed_llm/core/context_parallel/dot_product_attention.py @@ -0,0 +1,335 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. All rights reserved. 
+# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +from typing import Union, List +import math +import torch +import torch_npu +from torch import Tensor + +from megatron.core.transformer.utils import attention_mask_func +from megatron.core.packed_seq_params import PackedSeqParams +from megatron.training import get_args +from megatron.core import mpu, parallel_state, tensor_parallel +from mindspeed.core.transformer.attention import self_attention_init_wrapper +from mindspeed.utils import get_actual_seq_len, compute_qkv_index, get_position_ids +from mindspeed.core.context_parallel import AttnMaskType +from mindspeed.core.context_parallel import FusedScaleMaskSoftmax +from mindspeed.core.context_parallel import divide +from mindspeed.ops.fusion_attention_v2 import npu_fusion_attention +from mindspeed.model.transformer import get_attention_mask +from mindspeed.core.tensor_parallel_y_union_cp import TensorParallelYUnionCP +from mindspeed.core.context_parallel.ring_context_parallel.context_parallel_kv_cache import get_cache_policy +from mindspeed.core.context_parallel.ulysses_context_parallel.ulysses_context_parallel import ulyssesattn_context_parallel +from mindspeed.core.context_parallel.ring_context_parallel.ring_context_parallel import ringattn_context_parallel +from mindspeed.core.context_parallel.utils import get_scheduling_info +from mindspeed.core.context_parallel.model_parallel_utils import (get_context_parallel_group_for_hybrid_ring, + get_context_parallel_for_hybrid_ring_world_size, + get_context_parallel_for_hybrid_ring_rank, + get_context_parallel_for_hybrid_ring_global_ranks, + get_ring_ranks_for_intra_window, + get_ring_ranks_for_inter_window_kv, + get_ring_ranks_for_inter_window_dkv, + get_ring_group_for_intra_window, + get_ring_group_for_intra_window_send_recv_overlap) + + +try: + from einops import rearrange +except ImportError: + rearrange = None + +ACTUAL_SEQ_LEN_THRESHOLD = 2048 + + +def do_ulyssesattn_context_parallel_with_kv_cache_policy(self, + query: Tensor, + key: Tensor, + value: Tensor, + attention_mask, + packed_seq_params): + args = get_args() + + self.ulysses_comm_para['cache_policy'] = get_cache_policy( + self.layer_number, args.context_parallel_kv_cache_policy, args.context_parallel_cache_interval + ) + self.ulysses_comm_para['use_ulysses_allgather_kv'] = args.use_ulysses_allgather_kv + attn_para = dict() + attn_para['packed_seq_params'] = packed_seq_params + attn_para['attention_mask'] = attention_mask + attn_para['scale'] = self.scale + attn_para['pre_tokens'] = args.pre_tockens + attn_para['next_tokens'] = args.next_tockens + attn_para['keep_prob'] = 1 - self.attention_dropout.p + attn_para['sparse_mode'] = self.sparse_mode + output = ulyssesattn_context_parallel(query, key, value, attn_para, self.ulysses_comm_para) + + return output + + +def do_ring_context_parallel(self, + query: Tensor, + key: Tensor, + value: Tensor, + head_num, + attention_mask, + dropout_p=0, + packed_seq_params=None, + actual_seq_len=None): + args = get_args() + + if args.shape_order == "TND": + packed_seq_params = PackedSeqParams( + cu_seqlens_q=torch.tensor(actual_seq_len, dtype=torch.int64, device=torch.cuda.current_device()), + cu_seqlens_kv=torch.tensor(actual_seq_len, dtype=torch.int64, device=torch.cuda.current_device()) + ) + + q_index, kv_index = compute_qkv_index( + torch.tensor(actual_seq_len, dtype=torch.int64, device=torch.cuda.current_device()).clone().tolist()) + packed_seq_params.q_index = q_index + packed_seq_params.kv_index = kv_index + 
packed_seq_params.position_ids = get_position_ids() + + in_hybrid_mode = get_context_parallel_group_for_hybrid_ring(check_initialized=False) is not None + if in_hybrid_mode: + cp_group = get_context_parallel_group_for_hybrid_ring() + cp_size = get_context_parallel_for_hybrid_ring_world_size() + rank = get_context_parallel_for_hybrid_ring_rank() + cp_global_ranks = get_context_parallel_for_hybrid_ring_global_ranks() + else: + cp_group = mpu.get_context_parallel_group() + cp_size = mpu.get_context_parallel_world_size() + rank = mpu.get_context_parallel_rank() + cp_global_ranks = mpu.get_context_parallel_global_ranks() + + cp_para = dict() + + cp_para['causal'] = args.attention_mask_type == 'causal' + cp_para['cp_group'] = cp_group + cp_para['cp_size'] = cp_size + cp_para['rank'] = rank + if args.context_parallel_algo in ['megatron_cp_algo', 'hybrid_cp_algo']: + cp_para['cp_global_ranks'] = cp_global_ranks + cp_para['cp_group_for_send_recv_overlap'] = mpu.get_context_parallel_group_for_send_recv_overlap() \ + if args.use_cp_send_recv_overlap else None + cp_para['pse'] = self.pse + cp_para['pse_type'] = self.pse_type + + cp_para['cp_inner_ranks'] = get_ring_ranks_for_intra_window() + cp_para['cp_outer_ranks'] = get_ring_ranks_for_inter_window_kv() + cp_para['cp_dkv_outer_ranks'] = get_ring_ranks_for_inter_window_dkv() + cp_para['cp_group_for_intra_window'] = get_ring_group_for_intra_window() + cp_para['cp_group_for_intra_window_send_recv_overlap'] = get_ring_group_for_intra_window_send_recv_overlap() + cp_para['cache_policy'] = get_cache_policy( + self.layer_number, args.context_parallel_kv_cache_policy, args.context_parallel_cache_interval + ) + + output = ringattn_context_parallel(query, key, value, head_num, cp_para, self.scale, attention_mask, dropout_p, + packed_seq_params) + + return output + + +class CPDotProductAttentionImpl: + """ + Implementation of dot product attention with cp support. + """ + + def __init__(self, + config, + layer_number, + attn_mask_type, + attention_type, + attention_dropout: float = None, + softmax_scale: float = None, + cp_comm_type: str = None): + cp_size = config.context_parallel_size + config.context_parallel_size = 1 + self.config = config + super().__init__(config, layer_number, attn_mask_type, attention_type, attention_dropout, softmax_scale, cp_comm_type) + if self.config.context_parallel_size != 1: + raise AssertionError("Context parallelism is only supported by TEDotProductAttention!") + + if self.config.window_size is not None: + raise AssertionError("Sliding Window Attention is only supported by TEDotProductAttention!") + + self.layer_number = max(1, layer_number) + self.attn_mask_type = attn_mask_type + self.attention_type = attention_type # unused for now + + projection_size = self.config.kv_channels * self.config.num_attention_heads + # Per attention head and per partition values. 
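+        # Illustrative example (hypothetical values): with kv_channels=128,
+        # num_attention_heads=32 and tensor-model-parallel world size 8:
+        #   projection_size = 128 * 32 = 4096
+        #   hidden_size_per_partition = 4096 // 8 = 512
+        #   hidden_size_per_attention_head = 4096 // 32 = 128
+        #   num_attention_heads_per_partition = 32 // 8 = 4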
+ world_size = parallel_state.get_tensor_model_parallel_world_size() + self.hidden_size_per_partition = divide(projection_size, world_size) + self.hidden_size_per_attention_head = divide(projection_size, config.num_attention_heads) + self.num_attention_heads_per_partition = divide(self.config.num_attention_heads, world_size) + self.num_query_groups_per_partition = divide(self.config.num_query_groups, world_size) + + coeff = None + if softmax_scale is None: + self.softmax_scale = 1.0 / math.sqrt(self.hidden_size_per_attention_head) + else: + self.softmax_scale = softmax_scale + + if self.config.apply_query_key_layer_scaling: + coeff = self.layer_number + self.softmax_scale /= coeff + + self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) + if self.config.apply_query_key_layer_scaling: + coeff = self.layer_number + self.norm_factor *= coeff + + self.scale_mask_softmax = FusedScaleMaskSoftmax( + input_in_fp16=self.config.fp16, + input_in_bf16=self.config.bf16, + attn_mask_type=self.attn_mask_type, + scaled_masked_softmax_fusion=self.config.masked_softmax_fusion, + mask_func=attention_mask_func, + softmax_in_fp32=self.config.attention_softmax_in_fp32, + scale=coeff, + ) + + # Dropout. Note that for a single iteration, this layer will generate + # different outputs on different number of parallel partitions but + # on average it should not be partition dependent. + self.attention_dropout = torch.nn.Dropout( + self.config.attention_dropout if attention_dropout is None else attention_dropout + ) + + config.context_parallel_size = cp_size + + # add pse + self.pse = None + self.pse_type = None + self.attn_logit_softcapping = self.config.attn_logit_softcapping + self.square_alibi_mask = self.config.square_alibi_mask + self.fill_neg_inf = self.config.fill_neg_inf + self.beta = 1.0 + self.apply_query_key_layer_scaling = self.config.apply_query_key_layer_scaling + + if self.apply_query_key_layer_scaling: + self.beta = 1.0 / self.layer_number + + if self.config.position_embedding_type == 'alibi': + get_alibi(self, args.seq_length) + self.alibi_output_size = None + else: + self.alibi = None + + if self.config.query_pre_attn_scalar: + self.norm_factor = self.config.query_pre_attn_scalar ** 0.5 + self.scale_mask_softmax.scale = 1.0 + self.softmax_scale = 1.0 / self.norm_factor + + self.scale = 1.0 / math.sqrt( + self.hidden_size_per_attention_head) if self.scale_mask_softmax.scale is None else self.softmax_scale + + def forward( + self, + query, + key, + value, + attention_mask, + attn_mask_type=None, + attention_bias=None, + packed_seq_params=None, + ): + if attention_mask is None: + attention_mask = get_attention_mask() + query_rope, key_rope = None, None + if isinstance(query, List): + query, query_rope = query[0], query[1] + if isinstance(key, List): + key, key_rope = key[0], key[1] + + args = get_args() + self.sparse_mode = args.sparse_mode + seq_length, batch_size, n_head, head_dim = query.shape[0], query.shape[1], query.shape[2], query.shape[3] + actual_seq_len = get_actual_seq_len() + if actual_seq_len is not None and args.mtp_num_layers: + actual_seq_len = actual_seq_len[self.mtp_idx] + + if attn_mask_type == AttnMaskType.no_mask: + self.sparse_mode = 0 # default mask + + # ulyssesattn_context_parallel_with_kv_cache_policy + if (self.config.context_parallel_size > 1 and self.config.context_parallel_algo == "ulysses_cp_algo" + and self.config.context_parallel_kv_cache_policy): + return do_ulyssesattn_context_parallel_with_kv_cache_policy(self, query, key, value, 
attention_mask=attention_mask, packed_seq_params=packed_seq_params) + + # ring_context_parallel + if self.config.context_parallel_size > 1 and self.config.context_parallel_algo in ['megatron_cp_algo', 'hybrid_cp_algo']: + query, key, value = [rearrange(x, 's b h d -> s b (h d)') for x in [query, key, value]] + return do_ring_context_parallel(self, query, key, value, head_num=n_head, attention_mask=attention_mask, packed_seq_params=packed_seq_params, actual_seq_len=actual_seq_len) + + # process shape order + if args.shape_order == "TND": # varlen FA + if args.mla_fa_divide_qk: + query, key, value = [rearrange(x, 's b h d -> (b s) h d') for x in [query, key, value]] + if query_rope is not None and key_rope is not None: + query_rope, key_rope = [rearrange(x, 's b h d -> (b s) h d') for x in [query_rope, key_rope]] + else: + query, key, value = [rearrange(x, 's b h d -> (s b) h d') for x in [query, key, value]] + self.sparse_mode = 4 + elif args.shape_order == "BNSD": + query, key, value = [rearrange(x, 's b h d -> b h s d') for x in [query, key, value]] + else: + query, key, value = [rearrange(x, 's b h d -> s b (h d)') for x in [query, key, value]] + args.shape_order = "SBH" + + if self.hidden_size_per_attention_head == 0: + raise AssertionError("self.hidden_size_per_attention_head should not be ZERO.") + if not hasattr(self, 'attention_mask') or \ + self.attention_mask is None or \ + self.attention_mask.shape[0] != seq_length: + if self.alibi is not None: + self.attention_mask = torch.triu( + torch.ones(seq_length, seq_length), 1).bool().npu() + else: + self.attention_mask = attention_mask + + + if not args.mla_fa_divide_qk: + if actual_seq_len is not None and len(actual_seq_len) > ACTUAL_SEQ_LEN_THRESHOLD: + logger.warning("flash-attention get a long actual_seq_len, maybe create a coredump!") + actual_seq_len = recompute_valid_actual_seq_len(get_position_ids(), actual_seq_len) + + output = torch_npu.npu_fusion_attention( + query, key, value, n_head, args.shape_order, + pse=self.pse, + padding_mask=None, + atten_mask=self.attention_mask, + actual_seq_qlen=actual_seq_len, + actual_seq_kvlen=actual_seq_len, + scale=self.scale, + pre_tockens=args.pre_tockens, + next_tockens=args.next_tockens, + keep_prob=1 - self.attention_dropout.p, + inner_precise=0, + sparse_mode=self.sparse_mode + )[0] + else: + output = torch_npu.npu_fusion_attention_v2( + query, key, value, n_head, args.shape_order, + pse=self.pse, + padding_mask=None, + atten_mask=self.attention_mask, + query_rope=query_rope, + key_rope=key_rope, + actual_seq_qlen=actual_seq_len, + actual_seq_kvlen=actual_seq_len, + scale=self.scale, + pre_tokens=args.pre_tockens, + next_tokens=args.next_tockens, + keep_prob=1 - self.attention_dropout.p, + inner_precise=0, + sparse_mode=self.sparse_mode + )[0] + + # post_process after FA + if args.shape_order == "TND": # varlen FA + output = rearrange(output, '(s b) h d -> s b (h d)', s=seq_length) + elif args.shape_order == "BNSD": + output = rearrange(output, 'b h s d -> s b (h d)') + + return output diff --git a/mindspeed_llm/core/context_parallel/get_batch_utils.py b/mindspeed_llm/core/context_parallel/get_batch_utils.py new file mode 100644 index 0000000000..cd3ebd4726 --- /dev/null +++ b/mindspeed_llm/core/context_parallel/get_batch_utils.py @@ -0,0 +1,68 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. All rights reserved. +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
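+# Batch-slicing helpers for context parallelism (CP). Illustrative example with a causal
+# mask and context_parallel_size=2: the sequence is split into 2*CP = 4 chunks; CP rank 0
+# keeps chunks (0, 3) and CP rank 1 keeps chunks (1, 2), so the attention workload is
+# balanced across the CP group (see get_batch_on_this_cp_rank below).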
+from megatron.training import get_args +from megatron.core import mpu +from mindspeed.model.transformer import set_attention_mask +from mindspeed.core.context_parallel.get_batch_utils import (set_actual_seq_len, + _get_batch_on_this_cp_rank_in_megatron_cp, + _get_batch_on_this_cp_rank_in_ulysses_cp, + _get_batch_on_this_cp_rank_in_hybrid_cp_general, + _get_batch_on_this_cp_rank_in_hybrid_cp, + _get_batch_on_this_cp_rank_in_adaptive_cp, + _get_batch_on_this_cp_rank_in_hybrid_adaptive_cp, + broadcast_dynamic, _broadcast, get_ring_degree) + + +def get_batch_on_this_cp_rank(batch): + """ Slice batch input along sequence dimension into multiple chunks, + which are parallelized across GPUs in a context parallel group. + """ + + # With causal masking, each token only attends to its prior tokens. Simply split + # sequence into CP chunks can result in severe load imbalance. That's to say, chunks + # at the end of sequence have bigger workload than others. To address this issue, + # we split sequence into 2*CP ranks. Assuming CP=2, we then get 4 chunks, chunk_0 + # and chunk_3 are assigned to GPU0, chunk_1 and chunk_2 are assigned to GPU1, so + # that we can get balanced workload among GPUs in a context parallel group. + args = get_args() + + if not args.context_parallel_size > 1: + return batch + + if args.attention_mask_type == 'general' and batch.get("attention_mask", None) is not None: + set_attention_mask(batch['attention_mask'].squeeze()) + + cp_expanded_by_2d_tp = args.tp_y > 1 + if args.context_parallel_algo == 'megatron_cp_algo': + if args.attention_mask_type == 'general': + batch = _get_batch_on_this_cp_rank_in_megatron_cp_general(batch) + elif cp_expanded_by_2d_tp: + batch = _get_batch_on_this_tp_y_cp_rank_in_megatron_cp(batch) + else: + batch = _get_batch_on_this_cp_rank_in_megatron_cp(batch) + elif args.context_parallel_algo == 'ulysses_cp_algo' or args.context_parallel_algo == 'mamba_cp_algo': + batch = _get_batch_on_this_cp_rank_in_ulysses_cp(batch) + elif args.context_parallel_algo == 'hybrid_cp_algo': + if args.attention_mask_type == 'general': + batch = _get_batch_on_this_cp_rank_in_hybrid_cp_general(batch) + else: + batch = _get_batch_on_this_cp_rank_in_hybrid_cp(batch) + return batch + + +def _get_batch_on_this_cp_rank_in_megatron_cp_general(batch): + cp_rank = mpu.get_context_parallel_rank() + cp_size = mpu.get_context_parallel_world_size() + for key, val in batch.items(): + if key == 'attention_mask' and val is not None: + seq_dim = 2 if len(val.shape) == 4 else 0 + mask_row = val.chunk(cp_size, dim=seq_dim)[cp_rank].contiguous() + mask_list = [m.contiguous() for m in mask_row.chunk(cp_size, dim=seq_dim + 1)] + batch[key] = mask_list + continue + if val is not None: + seq_dim = 1 + val = val.chunk(cp_size, dim=seq_dim)[cp_rank].contiguous() + batch[key] = val + + return batch \ No newline at end of file diff --git a/mindspeed_llm/features_manager/__init__.py b/mindspeed_llm/features_manager/__init__.py index f738a70ed7..ff91842b4c 100644 --- a/mindspeed_llm/features_manager/__init__.py +++ b/mindspeed_llm/features_manager/__init__.py @@ -31,11 +31,15 @@ from mindspeed.features_manager import ( RecomputeMethodFeature, SmartSwapFeature, SwapAttentionFeature, + ContextParallelKvCacheFeature ) from mindspeed.features_manager.feature import MindSpeedFeature from mindspeed.features_manager.features_manager import MindSpeedFeaturesManager from mindspeed_llm.features_manager.affinity.affinity import AffinityFeature +from 
mindspeed_llm.features_manager.context_parallel.context_parallel_feature import ContextParallelFeature +from mindspeed_llm.features_manager.context_parallel.ulysses_context_parallel import UlyssesContextParallelFeature +from mindspeed_llm.features_manager.context_parallel.mamba_context_parallel import MambaContextParallelFeature from mindspeed_llm.features_manager.common.data import DataFeature from mindspeed_llm.features_manager.common.embedding import LanguageModelEmbeddingFeature from mindspeed_llm.features_manager.common.rotary import RotaryPositionEmbeddingFeature @@ -111,6 +115,15 @@ def add_affinity_features(features_list: List[MindSpeedFeature]): ]) +def add_context_parallel_features(features_list: List[MindSpeedFeature]): + features_list.extend([ + ContextParallelFeature(), + UlyssesContextParallelFeature(), + ContextParallelKvCacheFeature(), + MambaContextParallelFeature() + ]) + + def add_fusions_features(features_list: List[MindSpeedFeature]): features_list.extend([ FusedSwigluFeature(), @@ -223,6 +236,7 @@ def add_disable_gloo_group_feature(features_list: List[MindSpeedFeature]): def create_features_list(): features_list = [] add_megatron_basic_features(features_list) + add_context_parallel_features(features_list) add_llm_features(features_list) add_affinity_features(features_list) add_fusions_features(features_list) diff --git a/mindspeed_llm/features_manager/context_parallel/__init__.py b/mindspeed_llm/features_manager/context_parallel/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/mindspeed_llm/features_manager/context_parallel/context_parallel_feature.py b/mindspeed_llm/features_manager/context_parallel/context_parallel_feature.py new file mode 100644 index 0000000000..42716c09c4 --- /dev/null +++ b/mindspeed_llm/features_manager/context_parallel/context_parallel_feature.py @@ -0,0 +1,52 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. All rights reserved. 
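+# LLM-side extension of the MindSpeed ContextParallelFeature. It registers the
+# --context-parallel-algo related command-line arguments and, when
+# context_parallel_size > 1, patches Megatron's parallel state, rotary position
+# embedding, batch slicing and attention classes with the CP-aware adaptors from
+# mindspeed_llm.core.context_parallel.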
+ +from argparse import ArgumentParser +from mindspeed.features_manager.context_parallel.context_parallel_feature import ContextParallelFeature as MindspeedContextParallelFeature + + +class ContextParallelFeature(MindspeedContextParallelFeature): + + def __init__(self): + super().__init__() + + def register_args(self, parser: ArgumentParser): + group = parser.add_argument_group(title=self.feature_name) + group.add_argument('--context-parallel-algo', type=str, default='megatron_cp_algo', + choices=['megatron_cp_algo', 'hybrid_cp_algo'], + help='context parallel algorithm') + + # ring context parallel + group.add_argument('--cp-window-size', type=int, default=1) + group.add_argument('--attention-mask-type', type=str, default='causal', + choices=['causal', 'general'], help='context parallel attention mask type') + group.add_argument('--use-cp-send-recv-overlap', action='store_true', + help='use this flag to enable cp send-recv-overlap.') + group.add_argument("--use-fused-ring-attention-update", action='store_true', + help="Use fused ring attention update.") + group.add_argument("--megatron-cp-in-bnsd", action='store_true', + help="Megatron CP in bnsd.") + + + def register_patches(self, patch_manager, args): + if int(getattr(args, 'context_parallel_size', 1)) > 1: + from mindspeed.core.context_parallel.model_parallel_utils import initialize_model_parallel_cp_wrapper, \ + destroy_model_parallel_cp_wrapper, get_context_parallel_group_for_send_recv_overlap + from mindspeed.core.context_parallel.rotary_pos_embedding_utils import get_pos_emb_on_this_cp_rank + from mindspeed_llm.core.context_parallel.adaptor import CPDotProductAttention + from mindspeed_llm.core.context_parallel.adaptor import attention_init_wrapper + from mindspeed_llm.core.context_parallel.get_batch_utils import get_batch_on_this_cp_rank + patch_manager.register_patch('megatron.core.parallel_state.initialize_model_parallel', + initialize_model_parallel_cp_wrapper) + patch_manager.register_patch('megatron.core.parallel_state.destroy_model_parallel', + destroy_model_parallel_cp_wrapper) + patch_manager.register_patch('megatron.core.parallel_state.get_context_parallel_group_for_send_recv_overlap', + get_context_parallel_group_for_send_recv_overlap) + patch_manager.register_patch('megatron.core.models.common.embeddings.rotary_pos_embedding.get_pos_emb_on_this_cp_rank', + get_pos_emb_on_this_cp_rank) + patch_manager.register_patch('megatron.training.utils.get_batch_on_this_cp_rank', get_batch_on_this_cp_rank) + patch_manager.register_patch('megatron.core.transformer.attention.Attention.__init__', + attention_init_wrapper) + patch_manager.register_patch('megatron.core.transformer.dot_product_attention.DotProductAttention', + CPDotProductAttention) + patch_manager.register_patch('megatron.core.extensions.transformer_engine.TEDotProductAttention', + CPDotProductAttention) diff --git a/mindspeed_llm/features_manager/context_parallel/mamba_context_parallel.py b/mindspeed_llm/features_manager/context_parallel/mamba_context_parallel.py new file mode 100644 index 0000000000..e119be4919 --- /dev/null +++ b/mindspeed_llm/features_manager/context_parallel/mamba_context_parallel.py @@ -0,0 +1,29 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. All rights reserved. 
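+# Adds the 'mamba_cp_algo' choice to --context-parallel-algo and validates that
+# seq_length is divisible by context_parallel_size and that num_attention_heads is
+# divisible by context_parallel_size * tensor_model_parallel_size; selecting the
+# algorithm also forces use_flash_attn.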
+ +from argparse import ArgumentParser + +from mindspeed.features_manager.feature import MindSpeedFeature + + +class MambaContextParallelFeature(MindSpeedFeature): + + def __init__(self): + super().__init__('context-parallel-size') + + def register_args(self, parser: ArgumentParser): + group = parser.add_argument_group(title=self.feature_name) + self.add_parser_argument_choices_value( + parser, "--context-parallel-algo", 'mamba_cp_algo' + ) + + + def validate_args(self, args): + # mamba context parallel + if args.context_parallel_size > 1 and args.context_parallel_algo == 'mamba_cp_algo': + if args.seq_length % args.context_parallel_size != 0: + raise AssertionError("sequence length must be divisible by context_parallel_size") + head, remainder = divmod(args.num_attention_heads, + args.context_parallel_size * args.tensor_model_parallel_size) + if not (head >= 1 and remainder == 0): + raise AssertionError("num_attention_heads must be divisible by context_parallel_size * tensor_model_parallel_size") + args.use_flash_attn = True \ No newline at end of file diff --git a/mindspeed_llm/features_manager/context_parallel/ulysses_context_parallel.py b/mindspeed_llm/features_manager/context_parallel/ulysses_context_parallel.py new file mode 100644 index 0000000000..d780500d5d --- /dev/null +++ b/mindspeed_llm/features_manager/context_parallel/ulysses_context_parallel.py @@ -0,0 +1,28 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. All rights reserved. + +from argparse import ArgumentParser + +from mindspeed.features_manager.context_parallel.ulysses_context_parallel import UlyssesContextParallelFeature as MindspeedUlyssesContextParallel + + +class UlyssesContextParallelFeature(MindspeedUlyssesContextParallel): + + def __init__(self): + super().__init__() + + def register_args(self, parser: ArgumentParser): + super().register_args(parser) + group = parser.add_argument_group(title=self.feature_name) + group.add_argument('--kv-head-repeat-before-uly-alltoall', action='store_true', default=True, + help='use it to expand key and value for ulysses when GQA/MQA is used.') + + def validate_args(self, args): + super().validate_args(args) + if args.context_parallel_size <= 1: + if args.kv_head_repeat_before_uly_alltoall: + from mindspeed_llm.training.utils import print_rank0_by_args + args.kv_head_repeat_before_uly_alltoall = False + print_rank0_by_args(args, + f"When context_parallel is not activated, kv_head_repeat_before_uly_alltoall would be set to False for reducing memory usage.") + + diff --git a/mindspeed_llm/features_manager/megatron_basic/model_basic.py b/mindspeed_llm/features_manager/megatron_basic/model_basic.py index 452813c6be..10b38b5334 100644 --- a/mindspeed_llm/features_manager/megatron_basic/model_basic.py +++ b/mindspeed_llm/features_manager/megatron_basic/model_basic.py @@ -13,7 +13,6 @@ class ModelBasicFeature(MindSpeedFeature): def patch_model_patches(self, pm, args): from mindspeed_llm.training.tokenizer import build_tokenizer from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec - from mindspeed.core.context_parallel.rotary_pos_embedding_utils import get_pos_emb_on_this_cp_rank from mindspeed_llm.core.models.gpt.gpt_model import GPTModel from mindspeed_llm.training.utils import get_device_wrapper from mindspeed_llm.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec_wrapper @@ -21,9 +20,6 @@ class ModelBasicFeature(MindSpeedFeature): checkpoint_backward_wrapper pm.register_patch('megatron.training.global_vars.build_tokenizer', 
build_tokenizer) - # Embedding - pm.register_patch('megatron.core.models.common.embeddings.rotary_pos_embedding.get_pos_emb_on_this_cp_rank', - get_pos_emb_on_this_cp_rank) pm.register_patch('megatron.core.models.gpt.gpt_model.GPTModel', GPTModel) diff --git a/mindspeed_llm/features_manager/megatron_basic/training_basic.py b/mindspeed_llm/features_manager/megatron_basic/training_basic.py index 9a986c0e3c..1c03e2203c 100644 --- a/mindspeed_llm/features_manager/megatron_basic/training_basic.py +++ b/mindspeed_llm/features_manager/megatron_basic/training_basic.py @@ -37,8 +37,6 @@ class TrainingBasicFeature(MindSpeedFeature): group = parser.add_argument_group(title=self.feature_name) group.add_argument('--jit-compile', action='store_true', default=False, help='Setting jit compile mode to True') - group.add_argument('--attention-mask-type', type=str, default='causal', choices=['causal', 'general'], - help='context parallel attention mask type') group.add_argument('--load-checkpoint-loosely', action='store_true', default=False, help='Enable loading checkpoint not strictly.') diff --git a/mindspeed_llm/features_manager/tokenizer/build_tokenizer.py b/mindspeed_llm/features_manager/tokenizer/build_tokenizer.py index 658c5c201b..ec7e61daac 100644 --- a/mindspeed_llm/features_manager/tokenizer/build_tokenizer.py +++ b/mindspeed_llm/features_manager/tokenizer/build_tokenizer.py @@ -32,3 +32,8 @@ class BuildTokenizerFeature(MindSpeedBuildTokenizerFeature): help='Path to the json file of templates.') group.add_argument('--tokenizer-padding-side', type=str, default='right', help="tokenizer padding side") + + def register_patches(self, patch_manager, args): + if args.tokenizer_type == "PretrainedFromHF": + from mindspeed_llm.training.tokenizer import build_tokenizer + patch_manager.register_patch('megatron.training.tokenizer.tokenizer.build_tokenizer', build_tokenizer) \ No newline at end of file diff --git a/mindspeed_llm/features_manager/transformer/flash_attention/fusion_attention_feature.py b/mindspeed_llm/features_manager/transformer/flash_attention/fusion_attention_feature.py index ce33e3b9dc..205fda3ad1 100644 --- a/mindspeed_llm/features_manager/transformer/flash_attention/fusion_attention_feature.py +++ b/mindspeed_llm/features_manager/transformer/flash_attention/fusion_attention_feature.py @@ -25,18 +25,18 @@ class FusionAttentionFeature(MindSpeedFusionAttentionFeature): from mindspeed_llm.core.transformer.dot_product_attention import dot_product_attention_init, \ dot_product_attention_forward_wrapper, ulysses_context_parallel_forward_wrapper from mindspeed_llm.core.models.gpt.gpt_model import GPTModel - - # Attention - pm.register_patch('megatron.core.transformer.attention.Attention.__init__', - attention_init) - pm.register_patch('megatron.core.transformer.dot_product_attention.DotProductAttention.__init__', - dot_product_attention_init) - pm.register_patch('megatron.core.transformer.dot_product_attention.DotProductAttention.forward', - dot_product_attention_forward_wrapper) - pm.register_patch('megatron.core.transformer.custom_layers.transformer_engine.TEDotProductAttention.__init__', - dot_product_attention_init) - pm.register_patch('megatron.core.transformer.custom_layers.transformer_engine.TEDotProductAttention.forward', - dot_product_attention_forward_wrapper) - # For GQA in ulysses and hybrid - pm.register_patch('mindspeed.core.context_parallel.ulysses_context_parallel.ulysses_context_parallel.UlyssesContextAttention.forward', - ulysses_context_parallel_forward_wrapper) \ No newline at 
end of file + if int(getattr(args, 'context_parallel_size', 1)) < 2: + # Attention + pm.register_patch('megatron.core.transformer.attention.Attention.__init__', + attention_init) + pm.register_patch('megatron.core.transformer.dot_product_attention.DotProductAttention.__init__', + dot_product_attention_init) + pm.register_patch('megatron.core.transformer.dot_product_attention.DotProductAttention.forward', + dot_product_attention_forward_wrapper) + pm.register_patch('megatron.core.transformer.custom_layers.transformer_engine.TEDotProductAttention.__init__', + dot_product_attention_init) + pm.register_patch('megatron.core.transformer.custom_layers.transformer_engine.TEDotProductAttention.forward', + dot_product_attention_forward_wrapper) + # For GQA in ulysses and hybrid + pm.register_patch('mindspeed.core.context_parallel.ulysses_context_parallel.ulysses_context_parallel.UlyssesContextAttention.forward', + ulysses_context_parallel_forward_wrapper) \ No newline at end of file -- Gitee