diff --git a/mindformers/pynative/distributed/expert_parallel.py b/mindformers/pynative/distributed/expert_parallel.py
new file mode 100644
index 0000000000000000000000000000000000000000..7bb47760b241fa63075c44d72715a1f105694dc1
--- /dev/null
+++ b/mindformers/pynative/distributed/expert_parallel.py
@@ -0,0 +1,254 @@
+# Copyright 2026 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+# Implementation of Expert Parallel for the GroupedMLP in MoE.
+from mindspore import mint, nn, ops
+from mindspore.common import dtype as mstype
+
+from hyper_parallel.core.shard.api import shard
+from hyper_parallel.core.placement_types import Shard
+from hyper_parallel.core.device_mesh import DeviceMesh
+
+from mindformers.pynative.distributed.style import ParallelStyle
+from mindformers.pynative.transformers.moe.experts import GroupedMLP
+
+
+class AlltoAll(nn.Cell):
+    """AlltoAll operation wrapper. Acts as an identity when no group is given."""
+
+    def __init__(self, split_count, split_dim, concat_dim, group=None):
+        super().__init__()
+        self.group_is_none = group is None
+        if not self.group_is_none:
+            self.ops = ops.AlltoAll(split_count, split_dim, concat_dim, group)
+
+    def construct(self, input_x):
+        if self.group_is_none:
+            return input_x
+        return self.ops(input_x)
+
+
+class AlltoAllV(nn.Cell):
+    """AlltoAllV operation wrapper. Acts as an identity when no group is given."""
+
+    def __init__(self, group=None, block_size=1):
+        super().__init__()
+        self.group_is_none = group is None
+        if not self.group_is_none:
+            self.ops = ops.AlltoAllV(group=group, block_size=block_size)
+
+    def construct(self, input_x, send_numel_list, recv_numel_list):
+        if self.group_is_none:
+            return input_x
+        return self.ops(input_x, send_numel_list, recv_numel_list)
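+
+
+# Assumed AlltoAllV contract for the wrappers above (a sketch, not verified
+# against every MindSpore version): send_numel_list[i] is the number of
+# blocks this rank sends to rank i and recv_numel_list[i] the number it
+# expects back, with block_size elements per block. Passing
+# block_size=hidden_size below lets per-token counts double as the lists.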
+
+
+class ExpertParallel(ParallelStyle):
+    """Expert parallel style: all-to-all token dispatch/combine around a GroupedMLP."""
+
+    def __init__(self):
+        super().__init__()
+        self.ctx = None
+
+        self.cast = ops.cast
+        self.shape = ops.shape
+        self.reshape = mint.reshape
+        self.transpose = mint.transpose
+        self.sort = mint.sort
+        self.fmod = mint.fmod
+        self.index_select = mint.index_select
+        self.sum = mint.sum
+        self.cumsum = mint.cumsum
+        self.mul = mint.mul
+        self.d2h = ops.MoveTo()
+
+    # perform all-to-all token dispatch on the inputs (runs as a forward pre-hook)
+    def _token_dispatch(self, cell, inputs, device_mesh):
+        tokens, probs, topk_indices, num_tokens_per_expert = inputs
+        ep_degree = device_mesh.mesh_shape[0]
+
+        tokens_shape = self.shape(tokens)
+        tokens = self.reshape(tokens, (-1, tokens_shape[-1]))
+        topk_indices_shape = self.shape(topk_indices)
+        topk_indices = self.transpose(topk_indices, 1, 0)  # (B*S, k) --> (k, B*S)
+        topk_indices = self.reshape(topk_indices, (-1,))  # (k, B*S) --> (k*T,)
+
+        # cast to float32 so mint.sort can handle the indices
+        sorted_topk_indices, token_indices_experts_sorted = self.sort(
+            self.cast(topk_indices, mstype.float32), dim=-1)
+
+        _, unsort_token_indices_experts = self.sort(self.cast(token_indices_experts_sorted, mstype.float32), dim=-1)
+        unsort_token_indices_experts = self.reshape(unsort_token_indices_experts,
+                                                    (topk_indices_shape[1], topk_indices_shape[0]))
+        unsort_token_indices_experts = self.transpose(unsort_token_indices_experts, 1, 0)  # (k, B*S) --> (B*S, k)
+
+        inter_map = self.fmod(token_indices_experts_sorted, topk_indices_shape[0])
+        index = self.reshape(inter_map, (-1,))
+
+        routed_input = self.index_select(tokens, 0, index)
+        routed_input = self.reshape(routed_input, (tokens_shape[0], -1, tokens_shape[-1]))
+
+        # generate the input splits and output splits for all-to-all
+        num_tokens_per_expert_group = AlltoAll(
+            split_count=ep_degree,
+            split_dim=-1,
+            concat_dim=-2,
+            group=device_mesh.get_group()
+        )(num_tokens_per_expert)
+
+        num_tokens_per_expert_reshaped = self.reshape(num_tokens_per_expert, (ep_degree, -1))
+        input_splits = self.cast(self.sum(num_tokens_per_expert_reshaped, dim=-1, keepdim=False), mstype.int64)
+        num_tokens_per_expert_group_reshaped = self.reshape(num_tokens_per_expert_group, (ep_degree, -1))
+        output_splits = self.cast(self.sum(num_tokens_per_expert_group_reshaped, dim=-1, keepdim=False), mstype.int64)
+        num_tokens_per_expert = self.cumsum(self.sum(num_tokens_per_expert_group_reshaped, dim=-2, keepdim=False), 0)
+        num_tokens_per_expert = self.cast(num_tokens_per_expert, mstype.int64)
+
+        # move the split lists to host; AlltoAllV consumes them on the host side
+        input_splits = self.d2h(input_splits, "CPU", True)
+        output_splits = self.d2h(output_splits, "CPU", True)
+
+        # perform expert parallel AlltoAll communication
+        original_shape = routed_input.shape
+        global_input_tokens = AlltoAllV(group=device_mesh.get_group(), block_size=cell.hidden_size)(
+            self.reshape(routed_input, (-1,)), input_splits, output_splits
+        )
+        global_input_tokens = self.reshape(global_input_tokens, (1, -1, cell.hidden_size))
+        routing_map = self.reshape(self.cast(sorted_topk_indices, mstype.float32), (-1,))
+        routing_map = AlltoAllV(group=device_mesh.get_group(), block_size=1)(
+            routing_map, input_splits, output_splits
+        )
+        routing_map = self.reshape(routing_map, (1, -1))
+
+        # sort tokens by local expert
+        _, sorted_map = self.sort(routing_map)
+        _, unsorted_map = self.sort(self.cast(sorted_map, mstype.float32))
+        index = self.reshape(sorted_map, (-1,))
+        global_input_tokens_shape = self.shape(global_input_tokens)
+        global_input_tokens = self.reshape(global_input_tokens, (-1, global_input_tokens_shape[-1]))
+        global_input_tokens = self.index_select(global_input_tokens, 0, index)
+        global_input_tokens = self.reshape(global_input_tokens,
+                                           (global_input_tokens_shape[0], -1, global_input_tokens_shape[-1]))
+
+        self.ctx = (
+            probs, unsorted_map, unsort_token_indices_experts,
+            input_splits, output_splits, original_shape
+        )
+
+        return global_input_tokens, probs, topk_indices, num_tokens_per_expert
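+
+    # A worked sketch of the split bookkeeping in _token_dispatch
+    # (illustrative numbers): with ep_degree=2 and 4 experts, a rank holding
+    # num_tokens_per_expert=[3, 1, 2, 2] reshapes to [[3, 1], [2, 2]] and
+    # sums per row, so input_splits=[4, 4]: 4 tokens go to rank 0 (experts
+    # 0-1) and 4 to rank 1 (experts 2-3). output_splits is the same
+    # reduction applied to the AlltoAll-exchanged counts, i.e. how many
+    # tokens each peer rank will send here.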
+
+    @staticmethod
+    def _partition_fn(cell, device_mesh):
+        # shard on the expert dimension
+        sharding_plan = {
+            "parameter": {"weight1": (Shard(0),),
+                          "weight2": (Shard(0),)},
+        }
+        shard(cell, device_mesh, sharding_plan)
+        # for name, param in cell.named_parameters(recurse=False):
+        #     dist_param = nn.Parameter(distribute_tensor(param, device_mesh, [Shard(0)]))
+        #     cell.register_parameter(name, dist_param)
+
+    # perform all-to-all token combine on the outputs (runs as a forward hook)
+    def _token_combine(self, cell, routed_output, device_mesh):
+        probs, unsorted_map, unsort_token_indices_experts, \
+            input_splits, output_splits, original_shape = self.ctx
+        routed_output = self.reshape(routed_output, (1, -1, cell.hidden_size))
+
+        # unsort tokens by local expert
+        index = self.reshape(unsorted_map, (-1,))
+        routed_output_shape = routed_output.shape
+        routed_output = self.reshape(routed_output, (-1, routed_output_shape[-1]))
+        routed_output = self.index_select(routed_output, 0, index)
+        routed_output = self.reshape(routed_output, (routed_output_shape[0], -1, routed_output_shape[-1]))
+
+        # perform expert parallel AlltoAll communication
+        permutated_local_input_tokens = AlltoAllV(group=device_mesh.get_group(), block_size=cell.hidden_size)(
+            self.reshape(routed_output, (-1,)), output_splits, input_splits
+        )
+        permutated_local_input_tokens = self.reshape(permutated_local_input_tokens, original_shape)
+
+        # restore the pre-dispatch token order
+        index = self.reshape(unsort_token_indices_experts, (-1,))
+        permutated_local_input_tokens = self.reshape(permutated_local_input_tokens,
+                                                     (-1, self.shape(permutated_local_input_tokens)[-1]))
+        routed_output = self.index_select(permutated_local_input_tokens, 0, index)
+        unsort_token_indices_experts_shape = self.shape(unsort_token_indices_experts)
+        routed_output = self.reshape(routed_output,
+                                     (unsort_token_indices_experts_shape[0],
+                                      unsort_token_indices_experts_shape[1], -1))
+        probs = self.reshape(probs, (self.shape(probs)[0], self.shape(probs)[1], 1))
+        routed_output = self.mul(routed_output, self.cast(probs, routed_output.dtype))
+        # weight by routing probs and sum over the top-k axis, not the hidden axis
+        routed_output = ops.ReduceSum(keep_dims=False)(routed_output, 1)
+
+        return routed_output
+
+    def _apply(self, module: nn.Cell, device_mesh: DeviceMesh) -> nn.Cell:
+        # only supports GroupedMLP
+        if not isinstance(module, GroupedMLP):
+            raise NotImplementedError("ExpertParallel only supports GroupedMLP cells")
+
+        self._partition_fn(module, device_mesh=device_mesh)
+
+        module.register_forward_pre_hook(
+            lambda cell, inputs: self._token_dispatch(cell, inputs, device_mesh)
+        )
+
+        # forward hooks receive (cell, inputs, outputs); combine acts on the outputs
+        module.register_forward_hook(
+            lambda cell, inputs, outputs: self._token_combine(cell, outputs, device_mesh)
+        )
+
+        return module
+        # to be replaced with distribute_module in the future
+        """
+        return distribute_module(
+            module,
+            device_mesh,
+            partition_fn=ExpertParallel._partition_fn,
+            input_fn=self._token_dispatch,
+            output_fn=self._token_combine,
+        )
+        """
+
+
+# TODO
+class DeredundancyExpertParallel(ExpertParallel):
+    def __init__(self, nums_per_device: int = 8):
+        self.nums_per_device = nums_per_device
+        super().__init__()
+
+    def _token_dispatch(self, cell, inputs, device_mesh):
+        pass
+
+    def _token_combine(self, cell, routed_output, device_mesh):
+        pass
+
+
+"""
+grouped_mlp = GroupedMLP()
+expert_parallel = ExpertParallel()
+
+device_mesh = init_device_mesh(
+    mesh_shape=(2, 2),
+    alias_name=("dp", "ep")
+)
+
+ep_mesh = device_mesh["ep"]
+expert_parallel._apply(grouped_mlp, ep_mesh)
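+
+# a hypothetical follow-up call (illustrative only); in practice MoELayer
+# drives the forward pass, see moe_layer.py below:
+# routed_output = grouped_mlp(tokens, probs, topk_indices, num_tokens_per_expert)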
+"""
diff --git a/mindformers/pynative/distributed/style.py b/mindformers/pynative/distributed/style.py
new file mode 100644
index 0000000000000000000000000000000000000000..53e414346c375e95aa6f2dcc0e7d3e26e19abbb1
--- /dev/null
+++ b/mindformers/pynative/distributed/style.py
@@ -0,0 +1,24 @@
+# Copyright 2026 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+from abc import ABC, abstractmethod
+
+from mindspore import nn
+from hyper_parallel.core.device_mesh import DeviceMesh
+
+
+class ParallelStyle(ABC):
+    """Base class for parallel styles that rewrite a cell for a device mesh."""
+
+    @abstractmethod
+    def _apply(self, module: nn.Cell, device_mesh: DeviceMesh) -> nn.Cell: ...
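+
+
+# A minimal sketch of a concrete style, assuming shard() applies a sharding
+# plan to the cell's parameters in place (the class and plan below are
+# illustrative only, not part of this module):
+#
+# class RowwiseParallel(ParallelStyle):
+#     def _apply(self, module, device_mesh):
+#         shard(module, device_mesh, {"parameter": {"weight": (Shard(1),)}})
+#         return module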
""" _, _, dim = tokens.shape tokens = self.reshape(tokens, (-1, dim)) - # group tokens together by expert indices from 0 to num_experts and pass that to experts forward - num_tokens_per_expert = self.histc( - selected_experts_indices, - bins=self.num_local_experts, - min=0, - max=self.num_local_experts, - ) + num_tokens_per_expert = self.cast(num_tokens_per_expert, selected_experts_indices.dtype) + num_tokens_per_expert = self.cumsum(num_tokens_per_expert, dim=0, dtype=ms.int64) # Reorder the token indices to match the order of the experts # token_indices_experts_sorted shape (bs*slen*top_k,) @@ -110,7 +108,7 @@ class GroupedMLP(nn.Cell): # shape (bs*slen*top_k, dim) routed_input = tokens[token_indices_experts_sorted // self.top_k] - + routed_input = self.reshape(routed_input, (-1, self.hidden_size)) return num_tokens_per_expert, token_indices_experts_sorted, top_scores_experts_sorted, routed_input def unpermute(self, routed_output, token_indices_experts_sorted, shape): @@ -128,20 +126,27 @@ class GroupedMLP(nn.Cell): out_experts = routed_output_unsorted.sum(dim=1) return out_experts - def construct(self, tokens, probs, topk_indices): + def construct(self, tokens, probs, topk_indices, num_tokens_per_expert): """Construct function of GroupedMLP.""" - tokens_per_expert, token_indices_experts_sorted, permuted_probs, permuted_local_hidden_states = self.permute( - tokens, probs, topk_indices) + need_dispatch = not isinstance(self.weight1, DTensor) or "ep" not in self.weight1.device_mesh.mesh_dim_names + + if need_dispatch: + tokens_per_expert, token_indices_experts_sorted, permuted_probs, permuted_local_hidden_states = self.permute( + tokens, probs, topk_indices, num_tokens_per_expert) + else: + tokens_per_expert, permuted_probs, permuted_local_hidden_states = num_tokens_per_expert, probs, tokens + token_indices_experts_sorted = None - permuted_local_hidden_states = self.reshape(permuted_local_hidden_states, (-1, self.hidden_size)) if self.config.moe_apply_probs_on_input: permuted_local_hidden_states = self.mul(self.unsqueeze(permuted_probs, -1), permuted_local_hidden_states) permuted_probs = self.ones_like(permuted_probs) experts_output = self.experts_forward(permuted_local_hidden_states, tokens_per_expert, permuted_probs) - output = self.unpermute(experts_output, token_indices_experts_sorted, tokens.shape) - return output + if need_dispatch: + experts_output = self.unpermute(experts_output, token_indices_experts_sorted, tokens.shape) + + return experts_output def experts_forward(self, permuted_local_hidden_states, tokens_per_expert, permuted_probs): """Forward step of GroupedMLP.""" @@ -150,7 +155,7 @@ class GroupedMLP(nn.Cell): w2 = self.cast(self.weight2, original_dtype) w1 = self.reshape(w1, (-1, self.hidden_size, self.moe_ffn_hidden_size)) w2 = self.reshape(w2, (-1, self.config.moe_ffn_hidden_size, self.hidden_size)) - tokens_per_expert = self.cumsum(tokens_per_expert, dim=0, dtype=ms.int64) + fc1_output = GroupedMatmul(split_item=3, group_type=0)( [permuted_local_hidden_states], [w1], None, None, None, None, None, tokens_per_expert)[0] diff --git a/mindformers/pynative/transformers/moe/moe_layer.py b/mindformers/pynative/transformers/moe/moe_layer.py index 8b6071c10e18f93aa4b76db42f8f3568dc4c52cf..121f3edd372494e12b85db717d827bf07d0018f1 100644 --- a/mindformers/pynative/transformers/moe/moe_layer.py +++ b/mindformers/pynative/transformers/moe/moe_layer.py @@ -90,7 +90,7 @@ class MoELayer(nn.Cell): self.tokens_per_expert.add_(num_tokens_per_expert) - routed_output = 
+
+        if need_dispatch:
+            tokens_per_expert, token_indices_experts_sorted, permuted_probs, permuted_local_hidden_states = self.permute(
+                tokens, probs, topk_indices, num_tokens_per_expert)
+        else:
+            tokens_per_expert, permuted_probs, permuted_local_hidden_states = num_tokens_per_expert, probs, tokens
+            token_indices_experts_sorted = None
 
-        permuted_local_hidden_states = self.reshape(permuted_local_hidden_states, (-1, self.hidden_size))
         if self.config.moe_apply_probs_on_input:
             permuted_local_hidden_states = self.mul(self.unsqueeze(permuted_probs, -1), permuted_local_hidden_states)
             permuted_probs = self.ones_like(permuted_probs)
 
         experts_output = self.experts_forward(permuted_local_hidden_states, tokens_per_expert, permuted_probs)
 
-        output = self.unpermute(experts_output, token_indices_experts_sorted, tokens.shape)
-        return output
+        if need_dispatch:
+            experts_output = self.unpermute(experts_output, token_indices_experts_sorted, tokens.shape)
+
+        return experts_output
 
     def experts_forward(self, permuted_local_hidden_states, tokens_per_expert, permuted_probs):
         """Forward step of GroupedMLP."""
@@ -150,7 +155,7 @@ class GroupedMLP(nn.Cell):
         w2 = self.cast(self.weight2, original_dtype)
         w1 = self.reshape(w1, (-1, self.hidden_size, self.moe_ffn_hidden_size))
         w2 = self.reshape(w2, (-1, self.config.moe_ffn_hidden_size, self.hidden_size))
-        tokens_per_expert = self.cumsum(tokens_per_expert, dim=0, dtype=ms.int64)
+
         fc1_output = GroupedMatmul(split_item=3, group_type=0)(
             [permuted_local_hidden_states], [w1], None, None, None, None, None,
             tokens_per_expert)[0]
diff --git a/mindformers/pynative/transformers/moe/moe_layer.py b/mindformers/pynative/transformers/moe/moe_layer.py
index 8b6071c10e18f93aa4b76db42f8f3568dc4c52cf..121f3edd372494e12b85db717d827bf07d0018f1 100644
--- a/mindformers/pynative/transformers/moe/moe_layer.py
+++ b/mindformers/pynative/transformers/moe/moe_layer.py
@@ -90,7 +90,7 @@ class MoELayer(nn.Cell):
 
         self.tokens_per_expert.add_(num_tokens_per_expert)
 
-        routed_output = self.experts(hidden_states, top_scores, selected_experts_indices)
+        routed_output = self.experts(hidden_states, top_scores, selected_experts_indices, num_tokens_per_expert)
 
         shared_output = None
         if self.shared_experts is not None: