diff --git a/mindspeed_llm/tasks/megatron_adaptor.py b/mindspeed_llm/tasks/megatron_adaptor.py
index 7767e65c1f598233d916e0126185a1e0ea35f846..fceed228af2be86fd99762c3b4361cc2d2ad6f0e 100644
--- a/mindspeed_llm/tasks/megatron_adaptor.py
+++ b/mindspeed_llm/tasks/megatron_adaptor.py
@@ -428,7 +428,7 @@ class CoreAdaptation(MegatronAdaptationABC):
         MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.permute', permute)
         MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.unpermute', unpermute)
 
-        if args.moe_tp_extend_ep:
+        if args.moe_tp_extend_ep or args.moe_hierarchical_alltoallv:
             from mindspeed.core.transformer.moe.token_dispatcher import (
                 preprocess_tp_extend_ep, alltoall_token_unpermutation_tp_extend_ep,
                 alltoall_token_permutation_tp_extend_ep
diff --git a/mindspeed_llm/training/arguments.py b/mindspeed_llm/training/arguments.py
index cbc6b5beded097a19fd6108936821de8c6398bad..adc36b2bc3509fcfa540b0c5bed0d69183a23940 100644
--- a/mindspeed_llm/training/arguments.py
+++ b/mindspeed_llm/training/arguments.py
@@ -17,6 +17,7 @@ import os
 import argparse
 from pathlib import Path
 from functools import wraps
+import torch
 from mindspeed_llm.training.utils import print_rank0_by_args
 
 cur_file_dir = Path(__file__).absolute().parent
@@ -153,6 +154,8 @@ def _add_deepseek_moe_args(parser):
                             'global batch, where the bias is increased for the experts with less assigned tokens and '
                             'decreased for the experts with more assigned tokens. '
                             'The default value 1e-3 is same as that used in DeepSeekV3.')
+    group.add_argument('--moe-hierarchical-alltoallv', action='store_true',
+                       help='Use hierarchical alltoallv communication to reduce communication cost between nodes.')
 
     return parser
 
@@ -233,7 +236,7 @@ def _add_coc_args(parser):
     group.add_argument('--disable-gloo-group', action='store_true',
                        help='Replace the communication method of the DP group in the distributed optimizer from gloo to hccl.')
     group.add_argument('--hccl-slice-size', type=int, default=10 * 1024 * 1024,
-                       help='data slice size on each dp rank in distributed optimizer')
+                       help='data slice size on each dp rank in distributed optimizer')
 
     return parser
 
@@ -969,7 +972,7 @@ def _add_dataset_args(parser):
                        default=[], help='Additional keys need to be add from dataset.'
                        )
-
+
     return parser
 
@@ -1026,7 +1029,7 @@ def _validate_recompute_args(args):
             raise AssertionError('uniform recomputation is not compatible with activation function recomputation.')
         if args.recompute_granularity == "selective":
             raise AssertionError('--recompute-activation-function is not compatible with selective recomputation.')
-
+
     if args.recompute_norm:
         if args.recompute_method == "uniform":
             raise AssertionError('uniform recomputation is not compatible with norm recomputation.')
@@ -1034,7 +1037,7 @@ def _validate_recompute_args(args):
             raise AssertionError('--recompute-norm is not compatible with selective recomputation')
         if not args.use_mcore_models:
             raise AssertionError('--recompute-norm is only supported with mcore models')
-
+
     if args.swap_attention and args.swap_modules is None:
         if args.use_mcore_models:
             args.swap_modules = "input_layernorm,self_attention,pre_cross_attn_layernorm"
@@ -1119,6 +1122,14 @@ def _validate_moe_args(args):
         raise AssertionError('`--moe-zero-memory` does not support full recomputation for now.')
     if args.shared_expert_gate and args.gradient_accumulation_fusion:
         raise AssertionError('args.shared_expert_gate does not support gradient_accumulation_fusion.')
+    if args.moe_hierarchical_alltoallv:
+        tp = args.tensor_model_parallel_size
+        ep = args.expert_model_parallel_size
+        if ((not args.moe_alltoall_overlap_comm) or (not args.moe_tp_extend_ep) or tp <= 1 or tp > torch.npu.device_count() or
+                ep * tp <= torch.npu.device_count() or args.world_size <= torch.npu.device_count()):
+            raise AssertionError(
+                '`--moe-hierarchical-alltoallv` requires `--moe-alltoall-overlap-comm` and `--moe-tp-extend-ep` to be enabled, '
+                '1 < tp <= torch.npu.device_count(), and cross-node communication (ep * tp and world_size > torch.npu.device_count()).')
 
 
 def _validate_mla(args):
@@ -1326,7 +1337,6 @@ def _add_dummy_args(args):
     args.attention_mask_type = args.cp_attention_mask_type
     args.hccl_group_buffer_adaptive = False
     args.moe_bmm_mc2 = False
-    args.moe_hierarchical_alltoallv = False
    args.moe_experts_pipeline_degree = 0
     args.context_parallel_kv_cache_policy = None
     args.context_parallel_cache_interval = 0
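
Note: the sketch below is only a minimal illustration of the configuration that the new `_validate_moe_args` check accepts. All numbers (8 NPUs per node, tp=2, ep=8, a two-node world size of 16) and the name `devices_per_node` are assumptions for the example, not MindSpeed defaults or API.

    # Minimal sketch of the constraint enforced for --moe-hierarchical-alltoallv.
    # Every value here is hypothetical and chosen only to make the check pass.
    devices_per_node = 8              # stands in for torch.npu.device_count()
    tp, ep, world_size = 2, 8, 16     # hypothetical layout: 2 nodes x 8 NPUs
    moe_alltoall_overlap_comm = True  # --moe-alltoall-overlap-comm
    moe_tp_extend_ep = True           # --moe-tp-extend-ep

    hierarchical_alltoallv_ok = (
        moe_alltoall_overlap_comm
        and moe_tp_extend_ep
        and 1 < tp <= devices_per_node        # tp stays within a single node
        and ep * tp > devices_per_node        # expert-parallel groups span nodes
        and world_size > devices_per_node     # training uses more than one node
    )
    print(hierarchical_alltoallv_ok)  # True for this layout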