From 0a99725e0a9510957d0b51b6e947ca92b36d3296 Mon Sep 17 00:00:00 2001 From: mhh001 Date: Fri, 6 Jun 2025 15:15:23 +0800 Subject: [PATCH 1/3] [refactor] moe tp extend ep patch refactor --- mindspeed_llm/features_manager/__init__.py | 2 + .../common/moe_tp_extend_ep.py | 66 +++++++++++++++++++ mindspeed_llm/tasks/megatron_adaptor.py | 5 -- mindspeed_llm/training/arguments.py | 2 - 4 files changed, 68 insertions(+), 7 deletions(-) create mode 100644 mindspeed_llm/features_manager/common/moe_tp_extend_ep.py diff --git a/mindspeed_llm/features_manager/__init__.py b/mindspeed_llm/features_manager/__init__.py index 7a4a7bafb..ba6a128cf 100644 --- a/mindspeed_llm/features_manager/__init__.py +++ b/mindspeed_llm/features_manager/__init__.py @@ -9,6 +9,7 @@ from mindspeed_llm.features_manager.models.mamba import MambaModel from mindspeed_llm.features_manager.communication.coc import AscendCocFeature from mindspeed_llm.features_manager.common.data import DataFeature from mindspeed_llm.features_manager.common.moe_router import MOERouter +from mindspeed_llm.features_manager.common.moe_tp_extend_ep import MoeTpExtendEp FEATURES_LIST = [ # MindSpeed Legacy Features @@ -25,6 +26,7 @@ FEATURES_LIST = [ LanguageModelEmbeddingFeature(), MambaModel(), MOERouter(), + MoeTpExtendEp(), AscendCocFeature() # MindSpeed-LLM Legacy Features diff --git a/mindspeed_llm/features_manager/common/moe_tp_extend_ep.py b/mindspeed_llm/features_manager/common/moe_tp_extend_ep.py new file mode 100644 index 000000000..9a6e79cdc --- /dev/null +++ b/mindspeed_llm/features_manager/common/moe_tp_extend_ep.py @@ -0,0 +1,66 @@ +from mindspeed.features_manager.feature import MindSpeedFeature + + +class MoeTpExtendEp(MindSpeedFeature): + def __init__(self): + super(MoeTpExtendEp, self).__init__(feature_name="moe_tp_extend_ep", optimization_level=0) + + def register_args(self, parser): + group = parser.add_argument_group(title=self.feature_name) + group.add_argument("--moe-tp-extend-ep", action='store_true', + help="use tp group to extend experts parallism instead of sharding weight tensor of experts in tp group") + + + + def validate_args(self, args): + self._validate_moe_args(args) + self._validate_group_limited_greedy(args) + self._validate_aux_loss_free(args) + + def _validate_moe_args(self, args): + from mindspeed_llm.training.utils import print_rank0_by_args + if args.moe_expert_capacity_factor is not None: + if args.moe_token_dispatcher_type != "alltoall": + raise ValueError(f'moe_expert_capacity_factor only works with alltoall token dispatcher') + if args.moe_expert_capacity_factor < 0: + args.moe_expert_capacity_factor = None + print_rank0_by_args( + f'When moe_expert_capacity_factor < 0, no token would be drop, so moe_expert_capacity_factor should be set to false.') + if args.moe_router_load_balancing_type not in ["aux_loss", "none"]: + raise ValueError(f'moe_expert_capacity_factor only works with aux_loss or none load balancing') + if args.moe_expert_capacity_factor is None and args.moe_pad_expert_input_to_capacity: + raise ValueError(f'moe_expert_capacity_factor must be set to use moe_pad_expert_input_to_capacity') + if args.shared_expert_gate_output_dimension != 1 and args.shared_expert_gate_output_dimension != args.hidden_size: + raise AssertionError('shared expert gate output dimension can only be configured with 1 or hidden_size') + if hasattr(args, + 'use_fused_moe_token_permute_and_unpermute') and args.use_fused_moe_token_permute_and_unpermute: + raise AssertionError( + 'moe_expert_capacity_factor mode does not support 
use_fused_moe_token_permute_and_unpermute') + + def _validate_group_limited_greedy(self, args): + if args.moe_router_load_balancing_type == "group_limited_greedy": + if args.topk_group is None: + raise AssertionError('The parameter topk-group should be set when use group_limited_greedy.') + elif args.routed_scaling_factor is None: + raise AssertionError( + 'The parameter routed_scaling_factor should be set when use multi_head_latent_attention.') + elif args.topk_group >= args.expert_model_parallel_size: + raise AssertionError('The topk group ({}) should be less than n-group(EP)({}).'.format(args.topk_group, + args.expert_model_parallel_size)) + + def _validate_aux_loss_free(self, args): + if args.moe_router_enable_expert_bias and args.moe_router_score_function != "sigmoid": + raise ValueError( + "Expert bias for aux-loss-free routing only supports sigmoid score function." + "Please set --moe-router-score-function sigmoid for sigmoid score function." + ) + + def register_patches(self, patch_manager, args): + + # For moe tp extend ep ckpt + if args.moe_tp_extend_ep: + from mindspeed.core.transformer.moe.moe_layer import base_moe_init_wrapper + patch_manager.register_patch('megatron.core.transformer.moe.moe_layer.BaseMoELayer.__init__', + base_moe_init_wrapper) + + diff --git a/mindspeed_llm/tasks/megatron_adaptor.py b/mindspeed_llm/tasks/megatron_adaptor.py index 5a3c51448..d534ed446 100644 --- a/mindspeed_llm/tasks/megatron_adaptor.py +++ b/mindspeed_llm/tasks/megatron_adaptor.py @@ -396,11 +396,6 @@ class CoreAdaptation(MegatronAdaptationABC): dualpipe_register_patches(MegatronAdaptation) args = MegatronAdaptation.get_args() - # For moe tp extend ep ckpt - if args.moe_tp_extend_ep: - from mindspeed.core.transformer.moe.moe_layer import base_moe_init_wrapper - MegatronAdaptation.register('megatron.core.transformer.moe.moe_layer.BaseMoELayer.__init__', - base_moe_init_wrapper) if args.moe_permutation_async_comm: if args.moe_token_dispatcher_type == 'allgather': diff --git a/mindspeed_llm/training/arguments.py b/mindspeed_llm/training/arguments.py index c7538a9df..7801cf823 100644 --- a/mindspeed_llm/training/arguments.py +++ b/mindspeed_llm/training/arguments.py @@ -407,8 +407,6 @@ def _add_moe_args(parser): help='moe_alltoall_overlap_comm') group.add_argument("--cla-share-factor", type=int, default=1, help="Cross-Layer Attention share kv between cla-share-factor layers") - group.add_argument("--moe-tp-extend-ep", action='store_true', - help="use tp group to extend experts parallism instead of sharding weight tensor of experts in tp group") group.add_argument("--moe-zero-memory", type=str, default='disable', choices=['disable', 'level0', 'level1'], help="Save activation memory in moe layer.") -- Gitee From eb9b8fe15c137b8b195b484a766a075ef97b7f95 Mon Sep 17 00:00:00 2001 From: mhh001 Date: Fri, 6 Jun 2025 15:25:23 +0800 Subject: [PATCH 2/3] [refactor] moe tp extend ep patch refactor --- .../common/moe_tp_extend_ep.py | 50 +------------------ 1 file changed, 1 insertion(+), 49 deletions(-) diff --git a/mindspeed_llm/features_manager/common/moe_tp_extend_ep.py b/mindspeed_llm/features_manager/common/moe_tp_extend_ep.py index 9a6e79cdc..ca83b9ed8 100644 --- a/mindspeed_llm/features_manager/common/moe_tp_extend_ep.py +++ b/mindspeed_llm/features_manager/common/moe_tp_extend_ep.py @@ -10,57 +10,9 @@ class MoeTpExtendEp(MindSpeedFeature): group.add_argument("--moe-tp-extend-ep", action='store_true', help="use tp group to extend experts parallism instead of sharding weight tensor of experts in tp 
group") - - - def validate_args(self, args): - self._validate_moe_args(args) - self._validate_group_limited_greedy(args) - self._validate_aux_loss_free(args) - - def _validate_moe_args(self, args): - from mindspeed_llm.training.utils import print_rank0_by_args - if args.moe_expert_capacity_factor is not None: - if args.moe_token_dispatcher_type != "alltoall": - raise ValueError(f'moe_expert_capacity_factor only works with alltoall token dispatcher') - if args.moe_expert_capacity_factor < 0: - args.moe_expert_capacity_factor = None - print_rank0_by_args( - f'When moe_expert_capacity_factor < 0, no token would be drop, so moe_expert_capacity_factor should be set to false.') - if args.moe_router_load_balancing_type not in ["aux_loss", "none"]: - raise ValueError(f'moe_expert_capacity_factor only works with aux_loss or none load balancing') - if args.moe_expert_capacity_factor is None and args.moe_pad_expert_input_to_capacity: - raise ValueError(f'moe_expert_capacity_factor must be set to use moe_pad_expert_input_to_capacity') - if args.shared_expert_gate_output_dimension != 1 and args.shared_expert_gate_output_dimension != args.hidden_size: - raise AssertionError('shared expert gate output dimension can only be configured with 1 or hidden_size') - if hasattr(args, - 'use_fused_moe_token_permute_and_unpermute') and args.use_fused_moe_token_permute_and_unpermute: - raise AssertionError( - 'moe_expert_capacity_factor mode does not support use_fused_moe_token_permute_and_unpermute') - - def _validate_group_limited_greedy(self, args): - if args.moe_router_load_balancing_type == "group_limited_greedy": - if args.topk_group is None: - raise AssertionError('The parameter topk-group should be set when use group_limited_greedy.') - elif args.routed_scaling_factor is None: - raise AssertionError( - 'The parameter routed_scaling_factor should be set when use multi_head_latent_attention.') - elif args.topk_group >= args.expert_model_parallel_size: - raise AssertionError('The topk group ({}) should be less than n-group(EP)({}).'.format(args.topk_group, - args.expert_model_parallel_size)) - - def _validate_aux_loss_free(self, args): - if args.moe_router_enable_expert_bias and args.moe_router_score_function != "sigmoid": - raise ValueError( - "Expert bias for aux-loss-free routing only supports sigmoid score function." - "Please set --moe-router-score-function sigmoid for sigmoid score function." 
- ) - def register_patches(self, patch_manager, args): - # For moe tp extend ep ckpt if args.moe_tp_extend_ep: from mindspeed.core.transformer.moe.moe_layer import base_moe_init_wrapper patch_manager.register_patch('megatron.core.transformer.moe.moe_layer.BaseMoELayer.__init__', - base_moe_init_wrapper) - - + base_moe_init_wrapper) -- Gitee From 963c969a4734568005308f35b8189d70ed77c06e Mon Sep 17 00:00:00 2001 From: mhh001 Date: Mon, 9 Jun 2025 14:24:37 +0800 Subject: [PATCH 3/3] [refactor] moe tp extend ep patch refactor --- mindspeed_llm/features_manager/__init__.py | 8 ++++---- mindspeed_llm/features_manager/common/moe_router.py | 4 ++-- mindspeed_llm/features_manager/common/moe_tp_extend_ep.py | 4 ++-- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/mindspeed_llm/features_manager/__init__.py b/mindspeed_llm/features_manager/__init__.py index ba6a128cf..9a4d08b27 100644 --- a/mindspeed_llm/features_manager/__init__.py +++ b/mindspeed_llm/features_manager/__init__.py @@ -8,8 +8,8 @@ from mindspeed_llm.features_manager.models.mamba import MambaModel from mindspeed_llm.features_manager.communication.coc import AscendCocFeature from mindspeed_llm.features_manager.common.data import DataFeature -from mindspeed_llm.features_manager.common.moe_router import MOERouter -from mindspeed_llm.features_manager.common.moe_tp_extend_ep import MoeTpExtendEp +from mindspeed_llm.features_manager.common.moe_router import MoERouter +from mindspeed_llm.features_manager.common.moe_tp_extend_ep import MoETpExtendEp FEATURES_LIST = [ # MindSpeed Legacy Features @@ -25,8 +25,8 @@ FEATURES_LIST = [ RotaryPositionEmbeddingFeature(), LanguageModelEmbeddingFeature(), MambaModel(), - MOERouter(), - MoeTpExtendEp(), + MoERouter(), + MoETpExtendEp(), AscendCocFeature() # MindSpeed-LLM Legacy Features diff --git a/mindspeed_llm/features_manager/common/moe_router.py b/mindspeed_llm/features_manager/common/moe_router.py index 156f6934c..74aef5bc5 100644 --- a/mindspeed_llm/features_manager/common/moe_router.py +++ b/mindspeed_llm/features_manager/common/moe_router.py @@ -1,9 +1,9 @@ from mindspeed.features_manager.feature import MindSpeedFeature -class MOERouter(MindSpeedFeature): +class MoERouter(MindSpeedFeature): def __init__(self): - super(MOERouter, self).__init__(feature_name="moe_router", optimization_level=0) + super(MoERouter, self).__init__(feature_name="moe_router", optimization_level=0) def register_args(self, parser): group = parser.add_argument_group(title=self.feature_name) diff --git a/mindspeed_llm/features_manager/common/moe_tp_extend_ep.py b/mindspeed_llm/features_manager/common/moe_tp_extend_ep.py index ca83b9ed8..d8c089d40 100644 --- a/mindspeed_llm/features_manager/common/moe_tp_extend_ep.py +++ b/mindspeed_llm/features_manager/common/moe_tp_extend_ep.py @@ -1,9 +1,9 @@ from mindspeed.features_manager.feature import MindSpeedFeature -class MoeTpExtendEp(MindSpeedFeature): +class MoETpExtendEp(MindSpeedFeature): def __init__(self): - super(MoeTpExtendEp, self).__init__(feature_name="moe_tp_extend_ep", optimization_level=0) + super(MoETpExtendEp, self).__init__(feature_name="moe_tp_extend_ep", optimization_level=0) def register_args(self, parser): group = parser.add_argument_group(title=self.feature_name) -- Gitee
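
Net effect of the series: the --moe-tp-extend-ep flag moves out of mindspeed_llm/training/arguments.py, the BaseMoELayer.__init__ patch moves out of megatron_adaptor.py, and both are now owned by the self-contained MoETpExtendEp feature (renamed from MoeTpExtendEp in patch 3/3). Below is a minimal, runnable sketch of that registration pattern, assuming stand-in classes; FeatureStub, PatchManagerStub, MoETpExtendEpSketch, and the placeholder wrapper are illustrative names rather than the actual MindSpeed APIs, and only the flag name and the patch target string are taken from the diffs above.

import argparse


class FeatureStub:
    """Illustrative stand-in for mindspeed.features_manager.feature.MindSpeedFeature."""

    def __init__(self, feature_name, optimization_level=0):
        self.feature_name = feature_name
        self.optimization_level = optimization_level

    def register_args(self, parser):
        pass

    def register_patches(self, patch_manager, args):
        pass


class MoETpExtendEpSketch(FeatureStub):
    def __init__(self):
        super().__init__(feature_name="moe_tp_extend_ep", optimization_level=0)

    def register_args(self, parser):
        # Same flag the feature adds in patch 1/3, with the help text tidied up.
        group = parser.add_argument_group(title=self.feature_name)
        group.add_argument("--moe-tp-extend-ep", action="store_true",
                           help="use the TP group to extend expert parallelism "
                                "instead of sharding expert weights across TP ranks")

    def register_patches(self, patch_manager, args):
        # Patch BaseMoELayer.__init__ only when the flag is enabled,
        # mirroring the guard in the real register_patches above.
        if args.moe_tp_extend_ep:
            patch_manager.register_patch(
                "megatron.core.transformer.moe.moe_layer.BaseMoELayer.__init__",
                lambda init_fn: init_fn)  # placeholder standing in for base_moe_init_wrapper


class PatchManagerStub:
    """Records requested patches instead of applying them."""

    def __init__(self):
        self.patches = {}

    def register_patch(self, target, wrapper):
        self.patches[target] = wrapper


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    feature = MoETpExtendEpSketch()
    feature.register_args(parser)
    args = parser.parse_args(["--moe-tp-extend-ep"])

    patch_manager = PatchManagerStub()
    feature.register_patches(patch_manager, args)
    print(sorted(patch_manager.patches))

Running the sketch prints the single registered target, megatron.core.transformer.moe.moe_layer.BaseMoELayer.__init__, which is the conditional registration that patch 1/3 introduces in the feature class and patch 2/3 trims down to register_args plus register_patches.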