diff --git a/mindspeed_llm/mindspore/mindspore_adaptor.py b/mindspeed_llm/mindspore/mindspore_adaptor.py
index cd3992a2b5dc57d05bf72382a0ae204708ce49bc..d0d2c5015b3b455fd09ee7265c9e6de125237a85 100644
--- a/mindspeed_llm/mindspore/mindspore_adaptor.py
+++ b/mindspeed_llm/mindspore/mindspore_adaptor.py
@@ -159,6 +159,12 @@ class MindSporeAdaptation(MegatronAdaptationABC):
                                      apply_seq_aux_loss)
         MindSporeAdaptation.register('megatron.core.transformer.moe.router.TopKRouter.gating',
                                      topk_router_gating_func)
+        from mindspeed.mindspore.core.transformer.transformer import core_mlp_forward_wrapper
+        MindSporeAdaptation.register('megatron.core.transformer.mlp.MLP.forward',
+                                     core_mlp_forward_wrapper)
+
+
+
         if args.moe_fb_overlap:
             from mindspeed_llm.mindspore.tasks.models.transformer.multi_head_latent_attention import mla_forward
             MindSporeAdaptation.register('mindspeed_llm.tasks.models.transformer.multi_head_latent_attention.MultiHeadLatentAttention.forward',
@@ -267,10 +273,6 @@ class MindSporeAdaptation(MegatronAdaptationABC):
             MindSporeAdaptation.register('mindspeed.core.transformer.moe.comm_utils.async_all_to_all',
                                          async_all_to_all)
 
-            from mindspeed.mindspore.core.transformer.transformer import core_mlp_forward_wrapper
-            MindSporeAdaptation.register('megatron.core.transformer.mlp.MLP.forward',
-                                         core_mlp_forward_wrapper)
-
             from mindspeed.mindspore.core.pipeline_parallel.fb_overlap.modules.token_dispatcher import alltoall_token_perm1, overlap_stream
             MindSporeAdaptation.register('mindspeed.core.pipeline_parallel.fb_overlap.modules.token_dispatcher.alltoall_token_perm1',
                                          alltoall_token_perm1)
@@ -331,7 +333,10 @@ class MindSporeAdaptation(MegatronAdaptationABC):
 
         if args.gemm_gradient_accumulation_fusion:
             from mindspeed.mindspore.ops.npu_groupmatmul_add import npu_groupmatmul_add_fp32
-            MindSporeAdaptation.register('mindspeed.ops.npu_groupmatmul_add', npu_groupmatmul_add_fp32)
+            MindSporeAdaptation.register('mindspeed.ops.npu_groupmatmul_add.npu_groupmatmul_add_fp32', npu_groupmatmul_add_fp32)
+            from mindspeed.mindspore.ops.npu_matmul_add import npu_matmul_add_fp32
+            MindSporeAdaptation.register('fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32', npu_matmul_add_fp32)
+            MindSporeAdaptation.register('mindspeed.ops.npu_matmul_add.npu_matmul_add_fp32', npu_matmul_add_fp32)
 
         if args.use_moba_attn:
             from mindspeed_llm.mindspore.core.transformer.dot_product_attention import flash_attention_forward