From 234a9a63e76a4635de7011cbcba612be27c00a52 Mon Sep 17 00:00:00 2001
From: little_nik
Date: Thu, 19 Jun 2025 15:29:36 +0800
Subject: [PATCH] Support ring attention parallel for context parallel

---
 mindspeed_llm/mindspore/mindspore_adaptor.py | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/mindspeed_llm/mindspore/mindspore_adaptor.py b/mindspeed_llm/mindspore/mindspore_adaptor.py
index a140ee171..0e468d925 100644
--- a/mindspeed_llm/mindspore/mindspore_adaptor.py
+++ b/mindspeed_llm/mindspore/mindspore_adaptor.py
@@ -344,6 +344,18 @@ class MindSporeAdaptation(MegatronAdaptationABC):
         from mindspeed_llm.mindspore.core.transformer.dot_product_attention import flash_attention_forward
         MindSporeAdaptation.register('mindspeed_llm.core.transformer.dot_product_attention.flash_attention_forward', flash_attention_forward)
 
+        from mindspeed.mindspore.core.context_parallel.utils import general_out_update
+        from mindspeed.mindspore.ops.fusion_attention_v2 import npu_fusion_attention, npu_fusion_attention_grad
+        MindSporeAdaptation.register('mindspeed.core.context_parallel.utils.general_out_update', general_out_update)
+        MindSporeAdaptation.register('mindspeed.core.context_parallel.ring_context_parallel.general_out_update', general_out_update)
+        MindSporeAdaptation.register('mindspeed.ops.fusion_attention_v2.npu_fusion_attention', npu_fusion_attention)
+        MindSporeAdaptation.register('mindspeed.ops.fusion_attention_v2.npu_fusion_attention_grad', npu_fusion_attention_grad)
+        try:
+            from mindspeed.mindspore.ops.npu_ring_attention_update import npu_ring_attention_update
+            MindSporeAdaptation.register('mindspeed.ops.npu_ring_attention_update.npu_ring_attention_update', npu_ring_attention_update)
+        except ImportError:
+            print("WARNING: npu_ring_attention_update is not supported by MindSpore yet! Skipping import.")
+
     @staticmethod
     def reparse_args():
         """
-- 
Gitee
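
Note (not part of the patch): the diff leans on MindSpeed's adaptation-registration pattern, where each call maps a dotted import path to a replacement callable that is monkey-patched in later. The sketch below illustrates that pattern only; the Adaptation class, its _patches store, and apply_all are hypothetical stand-ins, not the actual MindSporeAdaptation / MegatronAdaptationABC implementation. Deferring the patching step is what lets the try/except above skip optional ops such as npu_ring_attention_update when their import fails.

    import importlib


    class Adaptation:
        """Collects dotted-path -> replacement pairs and applies them later."""

        _patches = {}

        @classmethod
        def register(cls, target_path, replacement):
            # Record the patch without touching the target module yet, so a
            # failed import of an optional op can be caught before anything
            # is rewired.
            cls._patches[target_path] = replacement

        @classmethod
        def apply_all(cls):
            # Split each dotted path into (module, attribute) and swap the
            # attribute for its registered replacement.
            for target_path, replacement in cls._patches.items():
                module_path, attr_name = target_path.rsplit('.', 1)
                module = importlib.import_module(module_path)
                setattr(module, attr_name, replacement)


    # Usage sketch: register a replacement for math.sqrt, then apply it.
    Adaptation.register('math.sqrt', lambda x: x ** 0.5)
    Adaptation.apply_all()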