From 1914a50db135c056d7f74309d0ca7d105380cf2b Mon Sep 17 00:00:00 2001
From: pqhgitee
Date: Wed, 12 Mar 2025 22:45:21 +0800
Subject: [PATCH 1/3] feat: support hierarchical alltoallv feature
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 mindspeed_llm/tasks/megatron_adaptor.py |  2 +-
 mindspeed_llm/training/arguments.py     | 12 +++++++++++-
 2 files changed, 12 insertions(+), 2 deletions(-)

diff --git a/mindspeed_llm/tasks/megatron_adaptor.py b/mindspeed_llm/tasks/megatron_adaptor.py
index 7767e65c1..fceed228a 100644
--- a/mindspeed_llm/tasks/megatron_adaptor.py
+++ b/mindspeed_llm/tasks/megatron_adaptor.py
@@ -428,7 +428,7 @@ class CoreAdaptation(MegatronAdaptationABC):
         MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.permute', permute)
         MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.unpermute', unpermute)
 
-        if args.moe_tp_extend_ep:
+        if args.moe_tp_extend_ep or args.moe_hierarchical_alltoallv:
             from mindspeed.core.transformer.moe.token_dispatcher import (
                 preprocess_tp_extend_ep, alltoall_token_unpermutation_tp_extend_ep,
                 alltoall_token_permutation_tp_extend_ep
diff --git a/mindspeed_llm/training/arguments.py b/mindspeed_llm/training/arguments.py
index cbc6b5bed..9b521d5b9 100644
--- a/mindspeed_llm/training/arguments.py
+++ b/mindspeed_llm/training/arguments.py
@@ -15,6 +15,7 @@
 import os
 import argparse
+import torch
 from pathlib import Path
 from functools import wraps
 from mindspeed_llm.training.utils import print_rank0_by_args
@@ -153,6 +154,8 @@ def _add_deepseek_moe_args(parser):
                             'global batch, where the bias is increased for the experts with less assigned tokens and '
                             'decreased for the experts with more assigned tokens. '
                             'The default value 1e-3 is same as that used in DeepSeekV3.')
+    group.add_argument('--moe-hierarchical-alltoallv', action='store_true',
+                       help='Use hierarchical all-to-all-v to reduce communication cost between nodes.')
     return parser
 
 
@@ -1119,6 +1122,14 @@ def _validate_moe_args(args):
         raise AssertionError('`--moe-zero-memory` does not support full recomputation for now.')
     if args.shared_expert_gate and args.gradient_accumulation_fusion:
         raise AssertionError('args.shared_expert_gate does not support gradient_accumulation_fusion.')
+    if args.moe_hierarchical_alltoallv:
+        tp = args.tensor_model_parallel_size
+        ep = args.expert_model_parallel_size
+        if ((not args.moe_alltoall_overlap_comm) or (not args.moe_tp_extend_ep) or tp <= 1 or tp > torch.npu.device_count() or
+                ep * tp <= torch.npu.device_count() or args.world_size <= torch.npu.device_count()):
+            raise AssertionError(
+                '`--moe-hierarchical-alltoallv` requires `--moe-alltoall-overlap-comm`, `--moe-tp-extend-ep`, '
+                '1 < tp <= torch.npu.device_count(), and inter-node traffic (ep * tp and world_size must exceed the per-node device count)')
 
 
 def _validate_mla(args):
@@ -1326,7 +1337,6 @@ def _add_dummy_args(args):
     args.attention_mask_type = args.cp_attention_mask_type
     args.hccl_group_buffer_adaptive = False
     args.moe_bmm_mc2 = False
-    args.moe_hierarchical_alltoallv = False
     args.moe_experts_pipeline_degree = 0
     args.context_parallel_kv_cache_policy = None
     args.context_parallel_cache_interval = 0
-- 
Gitee

From c76ed508c7383a8443a598de7abb9f7bd28a081f Mon Sep 17 00:00:00 2001
From: pqhgitee
Date: Thu, 13 Mar 2025 10:55:48 +0800
Subject: [PATCH 2/3] code style cleanup
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 mindspeed_llm/training/arguments.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/mindspeed_llm/training/arguments.py b/mindspeed_llm/training/arguments.py
index 9b521d5b9..decbeb5a3 100644
--- a/mindspeed_llm/training/arguments.py
+++ b/mindspeed_llm/training/arguments.py
@@ -16,8 +16,10 @@
 import os
 import argparse
 import torch
+
 from pathlib import Path
 from functools import wraps
+
 from mindspeed_llm.training.utils import print_rank0_by_args
 
 cur_file_dir = Path(__file__).absolute().parent
@@ -236,7 +238,7 @@ def _add_coc_args(parser):
     group.add_argument('--disable-gloo-group', action='store_true',
                        help='Replace the communication method of the DP group in the distributed optimizer from gloo to hccl.')
     group.add_argument('--hccl-slice-size', type=int, default=10 * 1024 * 1024,
-                        help='data slice size on each dp rank in distributed optimizer')
+                       help='data slice size on each dp rank in distributed optimizer')
     return parser
 
 
@@ -972,7 +974,7 @@ def _add_dataset_args(parser):
                        default=[], help='Additional keys need to be add from dataset.'
                        )
-
+
     return parser
 
 
@@ -1029,7 +1031,7 @@ def _validate_recompute_args(args):
             raise AssertionError('uniform recomputation is not compatible with activation function recomputation.')
         if args.recompute_granularity == "selective":
             raise AssertionError('--recompute-activation-function is not compatible with selective recomputation.')
-
+
     if args.recompute_norm:
         if args.recompute_method == "uniform":
             raise AssertionError('uniform recomputation is not compatible with norm recomputation.')
@@ -1037,7 +1039,7 @@
             raise AssertionError('--recompute-norm is not compatible with selective recomputation')
     if not args.use_mcore_models:
         raise AssertionError('--recompute-norm is only supported with mcore models')
-
+
     if args.swap_attention and args.swap_modules is None:
         if args.use_mcore_models:
             args.swap_modules = "input_layernorm,self_attention,pre_cross_attn_layernorm"
-- 
Gitee

From b8ce7404ce1ca4eed2d2931459801f02fae52ec4 Mon Sep 17 00:00:00 2001
From: pqhgitee
Date: Thu, 13 Mar 2025 14:20:43 +0800
Subject: [PATCH 3/3] code style cleanup
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 mindspeed_llm/training/arguments.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/mindspeed_llm/training/arguments.py b/mindspeed_llm/training/arguments.py
index decbeb5a3..adc36b2bc 100644
--- a/mindspeed_llm/training/arguments.py
+++ b/mindspeed_llm/training/arguments.py
@@ -15,11 +15,9 @@
 import os
 import argparse
-import torch
-
 from pathlib import Path
 from functools import wraps
-
+import torch
 from mindspeed_llm.training.utils import print_rank0_by_args
 
 cur_file_dir = Path(__file__).absolute().parent
-- 
Gitee
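
Note: the constraints enforced by the new `--moe-hierarchical-alltoallv` validation
in PATCH 1/3 are easier to read as a standalone predicate. Below is a minimal
sketch, not part of the patch series: it assumes torch.npu.device_count() reports
the per-node device count, and the helper name and the explicit devices_per_node
parameter are invented for illustration so the check runs without NPU hardware.

# hierarchical_alltoallv_check.py -- hypothetical restatement of the
# _validate_moe_args branch added in PATCH 1/3.

def hierarchical_alltoallv_usable(tp, ep, world_size, devices_per_node,
                                  moe_alltoall_overlap_comm,
                                  moe_tp_extend_ep):
    """Return True iff --moe-hierarchical-alltoallv would pass validation."""
    return (
        moe_alltoall_overlap_comm          # --moe-alltoall-overlap-comm on
        and moe_tp_extend_ep               # --moe-tp-extend-ep on
        and 1 < tp <= devices_per_node     # TP group fits inside one node
        and ep * tp > devices_per_node     # expert traffic crosses nodes
        and world_size > devices_per_node  # job spans more than one node
    )


if __name__ == "__main__":
    # 2 nodes x 8 NPUs with tp=4, ep=4: ep * tp = 16 > 8, so the
    # hierarchical all-to-all-v path applies.
    assert hierarchical_alltoallv_usable(4, 4, 16, 8, True, True)
    # Single-node job: no inter-node traffic, so the flag is rejected.
    assert not hierarchical_alltoallv_usable(4, 2, 8, 8, True, True)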