From d4961397bdfb5a63f1e618d5e7b83027226f968b Mon Sep 17 00:00:00 2001
From: xinyuan
Date: Thu, 7 Aug 2025 20:25:30 +0800
Subject: [PATCH] alltoalloverlap0.12

---
 .../glm45-moe/pretrain_glm45_moe_106b_4k_A3_ms.sh |  2 +-
 mindspeed_llm/mindspore/mindspore_adaptor.py      | 14 +++++++-------
 2 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/examples/mindspore/glm45-moe/pretrain_glm45_moe_106b_4k_A3_ms.sh b/examples/mindspore/glm45-moe/pretrain_glm45_moe_106b_4k_A3_ms.sh
index 2ff137c89..4809c7412 100644
--- a/examples/mindspore/glm45-moe/pretrain_glm45_moe_106b_4k_A3_ms.sh
+++ b/examples/mindspore/glm45-moe/pretrain_glm45_moe_106b_4k_A3_ms.sh
@@ -29,6 +29,7 @@ DISTRIBUTED_ARGS="
 
 MOE_ARGS="
     --moe-grouped-gemm \
+    --moe-alltoall-overlap-comm \
     --moe-permutation-async-comm \
     --moe-token-dispatcher-type alltoall_seq \
     --first-k-dense-replace 1 \
@@ -63,7 +64,6 @@ GPT_ARGS="
     --kv-channels 128 \
     --use-fused-rmsnorm \
     --use-fused-swiglu \
-    --overlap-grad-reduce \
     --use-distributed-optimizer \
     --num-layers 48 \
     --hidden-size 4096 \
diff --git a/mindspeed_llm/mindspore/mindspore_adaptor.py b/mindspeed_llm/mindspore/mindspore_adaptor.py
index 2835535cb..6b3fbbda5 100644
--- a/mindspeed_llm/mindspore/mindspore_adaptor.py
+++ b/mindspeed_llm/mindspore/mindspore_adaptor.py
@@ -49,17 +49,17 @@ class MindSporeAdaptation(MegatronAdaptationABC):
             MindSporeAdaptation.register('megatron.core.transformer.moe.moe_layer.MoELayer.forward', moe_layer_forward)
 
         if args.moe_permutation_async_comm:
-            if args.moe_token_dispatcher_type == 'alltoall':
+            if args.moe_token_dispatcher_type == 'alltoall_seq':
                 if args.moe_alltoall_overlap_comm:
                     from mindspeed.mindspore.core.transformer.moe.legacy_a2a_token_dispatcher import alltoall_token_permutation_new, \
                         alltoall_token_unpermutation_new
                     from mindspeed.mindspore.core.transformer.moe.experts import group_mlp_forward
-
+                    MindSporeAdaptation.register('megatron.core.transformer.moe.experts.GroupedMLP.forward', group_mlp_forward)
                     MindSporeAdaptation.register(
-                        'megatron.core.transformer.moe.token_dispatcher.MoEAlltoAllTokenDispatcher.token_permutation',
+                        'megatron.core.transformer.moe.legacy_a2a_token_dispatcher.MoEAlltoAllSEQTokenDispatcher.token_permutation',
                         alltoall_token_permutation_new)
                     MindSporeAdaptation.register(
-                        'megatron.core.transformer.moe.token_dispatcher.MoEAlltoAllTokenDispatcher.token_unpermutation',
+                        'megatron.core.transformer.moe.legacy_a2a_token_dispatcher.MoEAlltoAllSEQTokenDispatcher.token_unpermutation',
                         alltoall_token_unpermutation_new)
 
         if hasattr(args, 'use_fused_moe_token_permute_and_unpermute') and args.use_fused_moe_token_permute_and_unpermute and not args.moe_expert_capacity_factor:
@@ -160,6 +160,9 @@ class MindSporeAdaptation(MegatronAdaptationABC):
                                          apply_seq_aux_loss)
             MindSporeAdaptation.register('megatron.core.transformer.moe.router.TopKRouter.gating', topk_router_gating_func)
 
+        from mindspeed.mindspore.core.transformer.moe.comm_utils import async_all_to_all
+        MindSporeAdaptation.register('mindspeed.core.transformer.moe.comm_utils.async_all_to_all',
+                                     async_all_to_all)
         if args.moe_fb_overlap:
             from mindspeed_llm.mindspore.tasks.models.transformer.multi_head_latent_attention import mla_forward
             MindSporeAdaptation.register('mindspeed_llm.tasks.models.transformer.multi_head_latent_attention.MultiHeadLatentAttention.forward',
@@ -264,9 +267,6 @@ class MindSporeAdaptation(MegatronAdaptationABC):
                                          overlap_matmul)
 
-        from mindspeed.mindspore.core.transformer.moe.comm_utils import async_all_to_all
-        MindSporeAdaptation.register('mindspeed.core.transformer.moe.comm_utils.async_all_to_all',
-                                     async_all_to_all)
         from mindspeed.mindspore.core.pipeline_parallel.fb_overlap.modules.token_dispatcher import alltoall_token_perm1, overlap_stream
         MindSporeAdaptation.register('mindspeed.core.pipeline_parallel.fb_overlap.modules.token_dispatcher.alltoall_token_perm1',
-- 
Gitee
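
Note on the registration mechanism (reviewer illustration, not part of the patch): the patch works entirely through MindSpeed-LLM's MegatronAdaptation registration, which takes a dotted import path and a replacement callable and rebinds whatever the path names. That is why redirecting token_permutation/token_unpermutation from MoEAlltoAllTokenDispatcher to MoEAlltoAllSEQTokenDispatcher only requires changing the path strings. A minimal Python sketch of that pattern follows, assuming register() ultimately reduces to a setattr on the resolved owner; patch_dotted_path is a hypothetical helper, not the MindSpeed implementation.

import importlib

def patch_dotted_path(dotted_path, replacement):
    """Rebind the attribute named by `dotted_path` (e.g. 'pkg.mod.Class.method')
    to `replacement`. Hypothetical sketch of an adaptor-style register()."""
    parts = dotted_path.split('.')
    # Import the longest prefix of the path that is a module.
    for i in range(len(parts) - 1, 0, -1):
        try:
            owner = importlib.import_module('.'.join(parts[:i]))
            attrs = parts[i:]
            break
        except ImportError:
            continue
    else:
        raise ImportError(f'no importable prefix in {dotted_path!r}')
    # Walk class/attribute names down to the owner of the final attribute.
    for name in attrs[:-1]:
        owner = getattr(owner, name)
    # The monkey patch: later lookups through the module resolve to the replacement.
    setattr(owner, attrs[-1], replacement)

Under that assumption, MindSporeAdaptation.register('megatron.core.transformer.moe.experts.GroupedMLP.forward', group_mlp_forward) amounts to setattr(GroupedMLP, 'forward', group_mlp_forward), so the MindSpore implementations replace the Megatron/MindSpeed ones without editing their sources.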