def unpermute(
    permuted_tokens: torch.Tensor,
    sorted_indices: torch.Tensor,
    probs: torch.Tensor = None,
    padded_mode: bool = False,
    restore_shape: torch.Size = None,
):
    """Restore expert-permuted tokens to their pre-permutation order.

    Args:
        permuted_tokens: 2-D tensor of shape (num_permuted_tokens, hidden)
            as produced by the MoE permute step.
        sorted_indices: 1-D tensor; entry i gives the destination row for
            permuted row i. Must have exactly one entry per permuted token.
        probs: optional (num_tokens, topk) routing weights. When given, the
            topk copies of each token are weighted by probs and summed.
        padded_mode: when True, delegate entirely to the padded-token helper.
        restore_shape: target shape, only used in padded mode.

    Returns:
        Tensor of shape (num_tokens, hidden) with tokens back in original
        order (topk copies merged when probs is provided).
    """
    # Padded dispatch path is handled wholesale by the Megatron helper.
    if padded_mode:
        return unpermute_with_padded_tokens(
            permuted_tokens, sorted_indices, probs, restore_shape=restore_shape
        )

    assert sorted_indices.numel() == permuted_tokens.size(0)

    # Work out output row count and the top-k width to merge over.
    if probs is None:
        out_rows, k = permuted_tokens.size(0), 1
    else:
        out_rows, k = probs.numel(), probs.size(1)

    hidden = permuted_tokens.shape[-1]
    # Scatter-add each permuted row back into its original slot.
    restored = permuted_tokens.new_zeros((out_rows, hidden))
    restored.index_add_(0, sorted_indices, permuted_tokens)

    # Group the topk copies of each token, optionally weight, then merge.
    restored = restored.reshape(-1, k, hidden)
    if probs is not None:
        restored = restored * probs.unsqueeze(-1)
    return restored.sum(dim=1)