# Patch for mindspeed_llm/mindspore/mindspore_adaptor.py; the hunk header places it
# inside class MindSporeAdaptation(MegatronAdaptationABC). Per the header
# (-298,3 +298,20) it adds 17 lines after the existing set_is_first_microbatch
# registration.
#
# What the added code does: when `args.moe_zerc` is truthy, it imports the *_zerc
# functions from mindspeed.mindspore.core.transformer.moe.moe_zerc (modules `fwdbwd`
# and `token_dispatcher`) and passes each one to MindSporeAdaptation.register
# together with one of these dotted target paths:
#   mindspeed.core.pipeline_parallel.fb_overlap.overlap_funcs.fwdbwd
#     .transformer_layer_forward_moe_backward_dense_overlaping
#     .transformer_layer_forward_moe_backward_moe_overlaping
#   mindspeed.core.pipeline_parallel.fb_overlap.modules.token_dispatcher
#     .alltoall_token_perm1 / .alltoall_token_perm2
#     .alltoall_token_unperm1 / .alltoall_token_unperm2
#
# NOTE(review): `register` is defined outside this hunk; presumably it substitutes the
# given function for the target path -- confirm against its definition.
# NOTE(review): `args` and the enclosing method are outside this hunk; confirm that
# `moe_zerc` is always present on `args` before this branch is reached.
# NOTE(review): "overlaping" is the spelling used in the target paths and imported
# names; presumably it matches the names defined in mindspeed -- do not "correct" it
# here without checking.
# NOTE(review): this copy of the patch is stored on two physical lines. Unified-diff
# tools expect each header, context and +/- line on its own line, so the original line
# breaks and indentation must be restored from the upstream commit before applying.
diff --git a/mindspeed_llm/mindspore/mindspore_adaptor.py b/mindspeed_llm/mindspore/mindspore_adaptor.py index d73080d6a75959fdf89b2aa79f8418f472d74bd2..8336c7e1c4e7d9de295fcff871e3c1999b9ca87e 100644 --- a/mindspeed_llm/mindspore/mindspore_adaptor.py +++ b/mindspeed_llm/mindspore/mindspore_adaptor.py @@ -298,3 +298,20 @@ class MindSporeAdaptation(MegatronAdaptationABC): from ..mindspore.core.transformer.module import set_is_first_microbatch MindSporeAdaptation.register('megatron.core.transformer.module.MegatronModule.set_is_first_microbatch', set_is_first_microbatch) + + + if args.moe_zerc: + from mindspeed.mindspore.core.transformer.moe.moe_zerc.fwdbwd import transformer_layer_forward_moe_backward_dense_overlaping_zerc, transformer_layer_forward_moe_backward_moe_overlaping_zerc + MindSporeAdaptation.register('mindspeed.core.pipeline_parallel.fb_overlap.overlap_funcs.fwdbwd.transformer_layer_forward_moe_backward_dense_overlaping', + transformer_layer_forward_moe_backward_dense_overlaping_zerc) + MindSporeAdaptation.register('mindspeed.core.pipeline_parallel.fb_overlap.overlap_funcs.fwdbwd.transformer_layer_forward_moe_backward_moe_overlaping', + transformer_layer_forward_moe_backward_moe_overlaping_zerc) + from mindspeed.mindspore.core.transformer.moe.moe_zerc.token_dispatcher import zerc_alltoall_token_perm1, zerc_alltoall_token_perm2, zerc_alltoall_token_unperm1, zerc_alltoall_token_unperm2 + MindSporeAdaptation.register('mindspeed.core.pipeline_parallel.fb_overlap.modules.token_dispatcher.alltoall_token_perm1', + zerc_alltoall_token_perm1) + MindSporeAdaptation.register('mindspeed.core.pipeline_parallel.fb_overlap.modules.token_dispatcher.alltoall_token_perm2', + zerc_alltoall_token_perm2) + MindSporeAdaptation.register('mindspeed.core.pipeline_parallel.fb_overlap.modules.token_dispatcher.alltoall_token_unperm1', + zerc_alltoall_token_unperm1) + 
MindSporeAdaptation.register('mindspeed.core.pipeline_parallel.fb_overlap.modules.token_dispatcher.alltoall_token_unperm2', + zerc_alltoall_token_unperm2)