From 25ba6f5c29df25a4352b831699d6d32ea75db5d7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=82=93=E4=BD=B3?=
Date: Mon, 9 Jun 2025 16:15:15 +0800
Subject: [PATCH] feat: balanced moe

---
 .../pipeline_parallel/dualpipe/MTP_overlap.py | 14 ++++++++-
 .../transformer/moe/balanced_moe/adaptor.py   | 29 +++++++++++++++++++
 mindspeed_llm/tasks/megatron_adaptor.py       |  3 ++
 mindspeed_llm/training/arguments.py           |  9 ++++++
 4 files changed, 54 insertions(+), 1 deletion(-)
 create mode 100644 mindspeed_llm/core/transformer/moe/balanced_moe/adaptor.py

diff --git a/mindspeed_llm/core/pipeline_parallel/dualpipe/MTP_overlap.py b/mindspeed_llm/core/pipeline_parallel/dualpipe/MTP_overlap.py
index 2b6e21fb0..37f856aea 100644
--- a/mindspeed_llm/core/pipeline_parallel/dualpipe/MTP_overlap.py
+++ b/mindspeed_llm/core/pipeline_parallel/dualpipe/MTP_overlap.py
@@ -45,7 +45,19 @@ class TransformerMTPoverlap(torch.autograd.Function):
                 inference_params=None, packed_seq_params=None,
                 ):
         with torch.enable_grad():
-            output, context_out, graph = transformer_layer_forward_moe(layer,
+            if hasattr(layer.mlp, 'hot_experts'):
+                from mindspeed.core.transformer.moe.balanced_moe.overlap_funcs.fwd import \
+                    transformer_layer_forward_balanced_moe
+                layer_forward_func = transformer_layer_forward_balanced_moe
+            elif hasattr(layer.mlp, 'experts'):
+                layer_forward_func = transformer_layer_forward_moe
+            else:
+                raise AttributeError(
+                    "Layer's mlp must have either 'hot_experts' or 'experts' attribute. "
+                    f"Actual attributes: {dir(layer.mlp)}"
+                )
+
+            output, context_out, graph = layer_forward_func(layer,
                                                             hidden_states,
                                                             attention_mask,
                                                             context,
diff --git a/mindspeed_llm/core/transformer/moe/balanced_moe/adaptor.py b/mindspeed_llm/core/transformer/moe/balanced_moe/adaptor.py
new file mode 100644
index 000000000..4a5926581
--- /dev/null
+++ b/mindspeed_llm/core/transformer/moe/balanced_moe/adaptor.py
@@ -0,0 +1,29 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
+
+
+def balanced_moe_register_patches(MegatronAdaptation):
+    args = MegatronAdaptation.get_args()
+    from mindspeed.core.transformer.moe.balanced_moe.transformer_config import CustomTransformerConfig
+    from mindspeed.core.transformer.moe.balanced_moe.moe_layer import BalancedMoELayer
+    MegatronAdaptation.register('megatron.core.transformer.moe.moe_layer.MoELayer', BalancedMoELayer)
+    MegatronAdaptation.register('megatron.core.transformer.transformer_config.TransformerConfig',
+                                CustomTransformerConfig)
+    if args.moe_permutation_async_comm and args.moe_tp_extend_ep:
+        from mindspeed.core.transformer.moe.balanced_moe.token_dispatcher import MoEBalancedAlltoAllTokenDispatcher
+        MegatronAdaptation.register(
+            'mindspeed.core.transformer.moe.balanced_moe.token_dispatcher.MoEBalancedAlltoAllTokenDispatcher.preprocess',
+            MoEBalancedAlltoAllTokenDispatcher.preprocess_async_comm_tp_extend_ep)
+        MegatronAdaptation.register(
+            'mindspeed.core.transformer.moe.balanced_moe.token_dispatcher.MoEBalancedAlltoAllTokenDispatcher.token_permutation',
+            MoEBalancedAlltoAllTokenDispatcher.token_permutation_async_comm_tp_extend_ep)
+        MegatronAdaptation.register(
+            'mindspeed.core.transformer.moe.balanced_moe.token_dispatcher.MoEBalancedAlltoAllTokenDispatcher.token_unpermutation',
+            MoEBalancedAlltoAllTokenDispatcher.token_unpermutation_tp_extend_ep)
+    elif args.moe_permutation_async_comm:
+        from mindspeed.core.transformer.moe.balanced_moe.token_dispatcher import MoEBalancedAlltoAllTokenDispatcher
+        MegatronAdaptation.register(
+            'mindspeed.core.transformer.moe.balanced_moe.token_dispatcher.MoEBalancedAlltoAllTokenDispatcher.preprocess',
+            MoEBalancedAlltoAllTokenDispatcher.preprocess_async_comm)
+        MegatronAdaptation.register(
+            'mindspeed.core.transformer.moe.balanced_moe.token_dispatcher.MoEBalancedAlltoAllTokenDispatcher.token_permutation',
+            MoEBalancedAlltoAllTokenDispatcher.token_permutation_async_comm)
diff --git a/mindspeed_llm/tasks/megatron_adaptor.py b/mindspeed_llm/tasks/megatron_adaptor.py
index d6bf26fa5..af51f643a 100644
--- a/mindspeed_llm/tasks/megatron_adaptor.py
+++ b/mindspeed_llm/tasks/megatron_adaptor.py
@@ -540,6 +540,9 @@ class CoreAdaptation(MegatronAdaptationABC):
         if not args.moe_alltoall_overlap_comm and not args.moe_allgather_overlap_comm and not args.moe_fb_overlap:
             MegatronAdaptation.register('megatron.core.transformer.moe.experts.GroupedMLP.forward',
                                         groupedmlp_forward)
+        if args.balanced_moe_experts:
+            from mindspeed_llm.core.transformer.moe.balanced_moe.adaptor import balanced_moe_register_patches
+            balanced_moe_register_patches(MegatronAdaptation)
 
     def patch_pipeline_parallel(self):
         from ..core.pipeline_parallel.p2p_communication import _batched_p2p_ops
diff --git a/mindspeed_llm/training/arguments.py b/mindspeed_llm/training/arguments.py
index 61b88be8d..2ba767202 100644
--- a/mindspeed_llm/training/arguments.py
+++ b/mindspeed_llm/training/arguments.py
@@ -416,6 +416,15 @@ def _add_moe_args(parser):
                             'in each pp stage.')
     group.add_argument('--moe-allgather-overlap-comm', action='store_true', default=False,
                        help='moe_allgather_overlap_comm')
+    # balanced moe arguments
+    group.add_argument("--balanced-moe-experts", action='store_true', default=False,
+                       help='Enable balanced MoE experts. Balance workload across EPs by duplicating experts.')
+    group.add_argument('--balanced-moe-hot-expert-num', type=int, default=3,
+                       help='The number of duplicated hot experts to balance MoE workloads.')
+    group.add_argument('--trans-hot-expert', type=str, default="broadcast", choices=['broadcast',
+                       'alltoall'], help='Communication algorithm used to transfer hot experts.')
+    group.add_argument('--trans-hot-expert-group-num', type=int, default=0,
+                       help='Number of communication groups used when transferring hot experts.')
     return parser
-- 
Gitee
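
Note for reviewers unfamiliar with the adaptor mechanism: balanced_moe_register_patches only records replacements, and MegatronAdaptation later substitutes the object at each dotted path (e.g. swapping MoELayer for BalancedMoELayer). The snippet below is a minimal, self-contained sketch of that register-then-apply pattern, not MindSpeed's actual implementation; PatchRegistry and its methods are hypothetical names, and the sketch only handles module-level attributes (the real adaptor also patches class methods such as MoEBalancedAlltoAllTokenDispatcher.preprocess).

# Minimal sketch of a register-then-apply monkey-patch registry.
# NOTE: PatchRegistry is a hypothetical name used for illustration only;
# MindSpeed's MegatronAdaptation differs in scope and details.
import importlib


class PatchRegistry:
    def __init__(self):
        self._patches = []  # list of (dotted_path, replacement)

    def register(self, dotted_path, replacement):
        # Only record the patch; application is deferred, mirroring how
        # balanced_moe_register_patches just registers its replacements.
        self._patches.append((dotted_path, replacement))

    def apply(self):
        # Resolve each dotted path as <module>.<attribute> and swap the attribute.
        for dotted_path, replacement in self._patches:
            module_path, attr_name = dotted_path.rsplit('.', 1)
            module = importlib.import_module(module_path)
            setattr(module, attr_name, replacement)


if __name__ == '__main__':
    import json

    registry = PatchRegistry()
    # Toy example: replace a module-level function, analogous to replacing
    # megatron.core.transformer.moe.moe_layer.MoELayer with BalancedMoELayer.
    registry.register('json.dumps', lambda obj, **kwargs: '<patched>')
    registry.apply()
    assert json.dumps({'a': 1}) == '<patched>'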