diff --git a/mindspeed_llm/mindspore/core/transformer/moe/moe_utils.py b/mindspeed_llm/mindspore/core/transformer/moe/moe_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..68feb0fadcecdafe4117d21b1d884d02afbb34d0
--- /dev/null
+++ b/mindspeed_llm/mindspore/core/transformer/moe/moe_utils.py
@@ -0,0 +1,40 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+import torch
+from megatron.core.transformer.moe.moe_utils import unpermute_with_padded_tokens
+
+
+def unpermute(
+    permuted_tokens: torch.Tensor,
+    sorted_indices: torch.Tensor,
+    probs: torch.Tensor = None,
+    padded_mode: bool = False,
+    restore_shape: torch.Size = None,
+):
+    if padded_mode:
+        return unpermute_with_padded_tokens(
+            permuted_tokens, sorted_indices, probs, restore_shape=restore_shape
+        )
+
+    assert sorted_indices.numel() == permuted_tokens.size(0)
+    if probs is not None:
+        # Unpermute and merge the tokens with their probabilities
+        num_unpermuted_tokens = probs.numel()
+        topk = probs.size(1)
+    else:
+        # Unpermute the tokens without merge
+        num_unpermuted_tokens = permuted_tokens.size(0)
+        topk = 1
+
+    unpermuted_tokens = torch.zeros(
+        [num_unpermuted_tokens, permuted_tokens.shape[-1]],
+        dtype=permuted_tokens.dtype,
+        device=permuted_tokens.device,
+    )
+    unpermuted_tokens.index_add_(0, sorted_indices, permuted_tokens)
+    unpermuted_tokens = unpermuted_tokens.reshape(-1, topk, permuted_tokens.size(-1))
+    if probs is not None:
+        unpermuted_tokens = unpermuted_tokens * probs.unsqueeze(-1)
+    unpermuted_tokens = unpermuted_tokens.sum(dim=1)
+
+    return unpermuted_tokens
\ No newline at end of file
diff --git a/mindspeed_llm/mindspore/mindspore_adaptor.py b/mindspeed_llm/mindspore/mindspore_adaptor.py
index f0791fc7064905c8ec51b60c10b1abf66846351f..b08b3cf3de7d17f9d93f16ede18ec3611b131571 100644
--- a/mindspeed_llm/mindspore/mindspore_adaptor.py
+++ b/mindspeed_llm/mindspore/mindspore_adaptor.py
@@ -55,6 +55,9 @@ class MindSporeAdaptation(MegatronAdaptationABC):
                                      groupedmlp_init_wrapper)
         MindSporeAdaptation.register('megatron.core.transformer.moe.moe_layer.MoELayer.forward', moe_layer_forward)
 
+        from .core.transformer.moe.moe_utils import unpermute
+        MindSporeAdaptation.register('mindspeed.core.transformer.moe.moe_utils.unpermute', unpermute)
+
         if args.moe_permutation_async_comm:
             if args.moe_token_dispatcher_type == 'alltoall':
                 if args.moe_alltoall_overlap_comm:
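
For reference, below is a minimal round-trip sketch (not part of the diff) of how the patched unpermute is expected to behave. It assumes the new module and its Megatron dependency are importable; the local permute helper is hypothetical and only mirrors Megatron-style top-k token sorting, and the tensor shapes and values are arbitrary.

import torch

from mindspeed_llm.mindspore.core.transformer.moe.moe_utils import unpermute


def permute(tokens: torch.Tensor, indices: torch.Tensor):
    # Hypothetical helper: replicate each token top-k times and sort the
    # copies by expert id, returning the mapping needed to undo the sort.
    topk = indices.size(1)
    sorted_indices = torch.argsort(indices.reshape(-1))
    permuted_tokens = tokens.repeat_interleave(topk, dim=0)[sorted_indices]
    return permuted_tokens, sorted_indices


tokens = torch.randn(4, 8)                 # [num_tokens, hidden_size]
expert_ids = torch.randint(0, 3, (4, 2))   # top-2 routing over 3 experts
probs = torch.full((4, 2), 0.5)            # per-token top-k weights, sum to 1

permuted, sorted_idx = permute(tokens, expert_ids)
restored = unpermute(permuted, sorted_idx, probs=probs)
assert torch.allclose(restored, tokens, atol=1e-5)

Because sorted_indices is a permutation of range(num_unpermuted_tokens), the index_add_ into a zero-initialized tensor acts as a scatter-copy back to the pre-sort order; the probability weighting and sum over the top-k axis then merge the expert copies of each token into a single output row.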