From 14afcca39139f0ef4430d2c948ad5c2bd904db9c Mon Sep 17 00:00:00 2001 From: lvhaoyu Date: Wed, 18 Jun 2025 16:15:25 +0800 Subject: [PATCH 01/77] add qwen3 moe --- .../layers/fused_moe/__init__.py | 2 + .../layers/fused_moe/fused_moe.py | 0 .../model_executor/layers/linear.py | 83 ++- .../model_executor/models/qwen3_moe.py | 531 ++++++++++++++++++ vllm_mindspore/model_executor/models/utils.py | 33 +- 5 files changed, 647 insertions(+), 2 deletions(-) create mode 100644 vllm_mindspore/model_executor/layers/fused_moe/__init__.py create mode 100644 vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py create mode 100644 vllm_mindspore/model_executor/models/qwen3_moe.py diff --git a/vllm_mindspore/model_executor/layers/fused_moe/__init__.py b/vllm_mindspore/model_executor/layers/fused_moe/__init__.py new file mode 100644 index 00000000..a38a67cd --- /dev/null +++ b/vllm_mindspore/model_executor/layers/fused_moe/__init__.py @@ -0,0 +1,2 @@ +class FusedMoE: + ... \ No newline at end of file diff --git a/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py b/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py new file mode 100644 index 00000000..e69de29b diff --git a/vllm_mindspore/model_executor/layers/linear.py b/vllm_mindspore/model_executor/layers/linear.py index e0851149..0dee09d6 100644 --- a/vllm_mindspore/model_executor/layers/linear.py +++ b/vllm_mindspore/model_executor/layers/linear.py @@ -16,7 +16,7 @@ # limitations under the License. # ============================================================================ -from typing import List, Optional +from typing import List, Optional, Union from abc import abstractmethod import numpy as np @@ -185,6 +185,87 @@ class LinearBase(ms.nn.Cell): return None +class ReplicatedLinear(LinearBase): + """Replicated linear layer. + + Args: + input_size: input dimension of the linear layer. + output_size: output dimension of the linear layer. + bias: If true, add bias. + skip_bias_add: If true, skip adding bias but instead return it. + params_dtype: Data type for the parameters. + quant_config: Quantization configure. + prefix: The name of the layer in the state dict, including all parents + (e.g. model.layers.0.qkv_proj) + return_bias: If true, return bias together with outputs in forward pass. + """ + + def __init__( + self, + input_size: int, + output_size: int, + bias: bool = True, + skip_bias_add: bool = False, + params_dtype = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + *, + return_bias: bool = True, + ): + super().__init__(input_size, + output_size, + skip_bias_add, + params_dtype, + quant_config, + prefix=prefix, + return_bias=return_bias) + + # All the linear layer supports quant method. 
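+        # ReplicatedLinear keeps the full weight on every rank, so the
+        # per-partition sizes passed to create_weights below equal the full
+        # input_size and output_size.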
+ assert self.quant_method is not None + self.quant_method.create_weights(self, + self.input_size, [self.output_size], + self.input_size, + self.output_size, + self.params_dtype, + weight_loader=self.weight_loader) + + if bias: + self.bias = Parameter( + mint.empty(self.output_size, dtype=self.params_dtype)) + set_weight_attrs(self.bias, { + "output_dim": 0, + "weight_loader": self.weight_loader, + }) + else: + self.bias = None + + def weight_loader(self, param: Parameter, loaded_weight: Tensor): + if len(loaded_weight.shape) == 0: + loaded_weight = loaded_weight.reshape(1) + + assert param.size() == loaded_weight.size(), ( + f"Tried to load weights of size {loaded_weight.size()}" + f"to a parameter of size {param.size()}") + param.set_data(loaded_weight) + + def forward( + self, x: Tensor + ) -> Union[Tensor, tuple[Tensor, Optional[Parameter]]]: + bias = self.bias if not self.skip_bias_add else None + assert self.quant_method is not None + output = self.quant_method.apply(self, x, bias) + output_bias = self.bias if self.skip_bias_add else None + if not self.return_bias: + return output + return output, output_bias + + def extra_repr(self) -> str: + s = f"in_features={self.input_size}" + s += f", output_features={self.output_size}" + s += f", bias={self.bias is not None}" + return s + + class ColumnParallelLinear(LinearBase): def __init__( self, diff --git a/vllm_mindspore/model_executor/models/qwen3_moe.py b/vllm_mindspore/model_executor/models/qwen3_moe.py new file mode 100644 index 00000000..27533115 --- /dev/null +++ b/vllm_mindspore/model_executor/models/qwen3_moe.py @@ -0,0 +1,531 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Copyright 2024 The Qwen team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Inference-only Qwen3MoE model compatible with HuggingFace weights.""" +from collections.abc import Iterable +from typing import Any, Optional, Union + +from mindspore import Tensor, nn +from transformers import PretrainedConfig +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.models.interfaces import SupportsPP +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from vllm_mindspore.attention import Attention +from vllm_mindspore.model_executor.layers.activation import SiluAndMul +from vllm_mindspore.model_executor.layers.fused_moe import FusedMoE +from vllm_mindspore.model_executor.layers.layernorm import RMSNorm +from vllm_mindspore.model_executor.layers.linear import ( + MergedColumnParallelLinear, QKVParallelLinear, ReplicatedLinear, + RowParallelLinear) +from vllm_mindspore.model_executor.layers.logits_processor import ( + LogitsProcessor) +from vllm_mindspore.model_executor.layers.rotary_embedding import get_rope +from vllm_mindspore.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm_mindspore.model_executor.model_loader.weight_utils import default_weight_loader + +from vllm_mindspore.model_executor.models.utils import ( + extract_layer_index, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) +from vllm_mindspore.model_executor.models.model_base import NativeModel + +logger = init_logger(__name__) + + +class Qwen3MoeMLP(nn.Cell): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: bool = True, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj") + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + reduce_results=reduce_results, + prefix=f"{prefix}.down_proj") + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. 
" + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class Qwen3MoeSparseMoeBlock(nn.Cell): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + + if self.tp_size > config.num_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {config.num_experts}.") + + self.experts = FusedMoE(num_experts=config.num_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + prefix=f"{prefix}.experts") + + self.gate = ReplicatedLinear(config.hidden_size, + config.num_experts, + bias=False, + quant_config=None, + prefix=f"{prefix}.gate") + + def forward(self, hidden_states: Tensor) -> Tensor: + # NOTE: hidden_states can have either 1D or 2D shape. + orig_shape = hidden_states.shape + hidden_dim = hidden_states.shape[-1] + hidden_states = hidden_states.view(-1, hidden_dim) + + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + final_hidden_states = self.experts(hidden_states=hidden_states, + router_logits=router_logits) + final_hidden_states = final_hidden_states + if self.tp_size > 1: + final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( # noqa E501 + final_hidden_states) + + return final_hidden_states.view(orig_shape) + + +class Qwen3MoeAttention(nn.Cell): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[dict[str, Any]] = None, + max_position_embeddings: int = 8192, + head_dim: Optional[int] = None, + rms_norm_eps: float = 1e-06, + qkv_bias: bool = False, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
+ assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = head_dim or (hidden_size // self.total_num_heads) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear(hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=qkv_bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj") + + self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj") + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) + self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) + + def forward( + self, + positions: Tensor, + hidden_states: Tensor, + ) -> Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + # Add qk-norm + q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, + self.head_dim) + q_by_head = self.q_norm(q_by_head) + q = q_by_head.view(q.shape) + + k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, + self.head_dim) + k_by_head = self.k_norm(k_by_head) + k = k_by_head.view(k.shape) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + +class Qwen3MoeDecoderLayer(nn.Cell): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + self.self_attn = Qwen3MoeAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + rms_norm_eps=config.rms_norm_eps, + qkv_bias=getattr(config, 'attention_bias', False), + head_dim=getattr(config, 'head_dim', None), + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + + # `mlp_only_layers` in the config. 
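+        # Layers listed in `mlp_only_layers` always use the dense MLP;
+        # every `decoder_sparse_step`-th remaining layer uses the sparse
+        # MoE block instead (see the condition below).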
+ layer_idx = extract_layer_index(prefix) + mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else + config.mlp_only_layers) + if (layer_idx not in mlp_only_layers) and ( + config.num_experts > 0 and + (layer_idx + 1) % config.decoder_sparse_step == 0): + self.mlp = Qwen3MoeSparseMoeBlock(config=config, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + else: + self.mlp = Qwen3MoeMLP(hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: Tensor, + hidden_states: Tensor, + residual: Optional[Tensor], + ) -> Tensor: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class Qwen3MoeModel(nn.Cell): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.config = config + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + prefix=f"{prefix}.embed_tokens") + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: Qwen3MoeDecoderLayer(config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix), + prefix=f"{prefix}.layers", + ) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def get_input_embeddings(self, input_ids: Tensor) -> Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: Tensor, + positions: Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[Tensor] = None, + ) -> Union[Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states, residual = layer(positions, hidden_states, residual) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, + Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", 
"gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.num_experts) + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + for (param_name, weight_name, shard_id) in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if "mlp.experts" in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + if name not in params_dict: + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + # Skip loading extra bias for GPTQ models. + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id) + break + else: + # Skip loading extra bias for GPTQ models. + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + # Remapping the name of FP8 kv-scale. + if name.endswith("kv_scale"): + remapped_kv_scale_name = name.replace( + ".kv_scale", ".attn.kv_scale") + if remapped_kv_scale_name not in params_dict: + logger.warning_once( + "Found kv scale in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). 
kv-scale is not loaded.", # noqa: E501 + name, + remapped_kv_scale_name, + ) + continue + else: + name = remapped_kv_scale_name + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class Qwen3MoeForCausalLM(NativeModel, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + fall_back_to_pt_during_load = False + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + self.model = Qwen3MoeModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + self.logits_processor = LogitsProcessor(config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: Tensor) -> Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: Tensor, + positions: Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[Tensor] = None, + ) -> Union[Tensor, IntermediateTensors]: + hidden_states = self.exec_model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, + Tensor]]) -> set[str]: + params_dict = self.get_params_dict() + self.model.load_weights(weights, params_dict) diff --git a/vllm_mindspore/model_executor/models/utils.py b/vllm_mindspore/model_executor/models/utils.py index 493664cd..9ecebe2a 100644 --- a/vllm_mindspore/model_executor/models/utils.py +++ b/vllm_mindspore/model_executor/models/utils.py @@ -20,7 +20,7 @@ from dataclasses import dataclass, field from typing import Iterable, List, Mapping, Optional, Tuple, Union -import mindspore as ms +import mindspore as ms, nn from mindspore import mint, ops from vllm.sequence import IntermediateTensors @@ -261,3 +261,34 @@ def merge_multimodal_embeddings( (input_ids == placeholder_token_id), multimodal_embeddings, ) + + +_model_to_pp_missing_layer_names: dict[int, list[str]] = {} + + +def get_pp_missing_layer_names(model: nn.Cell) -> list[str]: + """Get the names of the missing layers in a pipeline parallel model.""" + model_id = id(model) + if model_id in _model_to_pp_missing_layer_names: + return _model_to_pp_missing_layer_names[model_id] + + missing_layer_names = [] + for name, cell in model.name_cells(): + if isinstance(cell, PPMissingLayer): + # NOTE: the trailing dot is used to match the prefix of the layer. 
+ # without the dot, we could match a layer that is not missing, + # e.g., 'encoder.layer.1' would match 'encoder.layer.11' + missing_layer_names.append(name + '.') + _model_to_pp_missing_layer_names[model_id] = missing_layer_names + + return missing_layer_names + + +def is_pp_missing_parameter(name: str, model: nn.Cell) -> bool: + """Check if a parameter is missing in a pipeline parallel model.""" + if isinstance(model, PPMissingLayer): + return True + + return any( + name.startswith(missing_layer_name) + for missing_layer_name in get_pp_missing_layer_names(model)) -- Gitee From 6b50bfebd0b2901eda1cb9d77bf3a6944a7b8a0d Mon Sep 17 00:00:00 2001 From: lvhaoyu Date: Thu, 19 Jun 2025 16:25:01 +0800 Subject: [PATCH 02/77] add layers --- .../model_executor/layers/fused_moe/layer.py | 695 ++++++++++++++++++ 1 file changed, 695 insertions(+) create mode 100644 vllm_mindspore/model_executor/layers/fused_moe/layer.py diff --git a/vllm_mindspore/model_executor/layers/fused_moe/layer.py b/vllm_mindspore/model_executor/layers/fused_moe/layer.py new file mode 100644 index 00000000..bb2c02d6 --- /dev/null +++ b/vllm_mindspore/model_executor/layers/fused_moe/layer.py @@ -0,0 +1,695 @@ +# SPDX-License-Identifier: Apache-2.0 + +import importlib +from abc import abstractmethod +from dataclasses import dataclass +from enum import Enum +from typing import Callable, Optional + +import torch +import torch.nn.functional as F +from torch.nn.parameter import UninitializedParameter + +import vllm.envs as envs +from vllm.config import ParallelConfig, get_current_vllm_config +from vllm.distributed import (get_dp_group, get_ep_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) +from vllm.forward_context import ForwardContext, get_forward_context +from vllm.logger import init_logger +from vllm.model_executor.custom_op import CustomOp +from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( + is_rocm_aiter_moe_enabled) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.utils import set_weight_attrs +from vllm.platforms import current_platform +from vllm.platforms.interface import CpuArchEnum +from vllm.utils import direct_register_custom_op +from vllm.model_executor.layers.fused_moe.layers import FusedMoEParallelConfig + + +from mindspore import nn + +class FusedMoE(nn.Cell): + """FusedMoE layer for MoE models. + + This layer contains both MergedColumnParallel weights (gate_up_proj / + w13) and RowParallelLinear weights (down_proj/ w2). + + Note: Mixtral uses w1, w2, and w3 for gate, up, and down_proj. We + copy that naming convention here and handle any remapping in the + load_weights function in each model implementation. + + Args: + num_experts: Number of experts in the model + top_k: Number of experts selected for each token + hidden_size: Input hidden state size of the transformer + intermediate_size: Intermediate size of the experts + params_dtype: Data type for the parameters. + reduce_results: Whether to all all_reduce on the output of the layer + renomalize: Whether to renormalize the logits in the fused_moe kernel + quant_config: Quantization configure. 
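+        use_grouped_topk: Whether to use grouped top-k routing.
+        num_expert_group: Number of expert groups (grouped top-k only).
+        topk_group: Number of groups to select from (grouped top-k only).
+        tp_size/ep_size/dp_size: Optional overrides for the tensor-, expert-
+            and data-parallel sizes; the current parallel group sizes are
+            used by default.
+        prefix: The name of the layer in the state dict, including all
+            parents (e.g. model.layers.0.mlp.experts).
+        custom_routing_function: Optional callable that replaces the default
+            top-k routing.
+        scoring_func: Scoring function applied to the router logits.
+        e_score_correction_bias: Optional correction bias added to the expert
+            scores before top-k selection.
+        apply_router_weight_on_input: Apply the routing weights on the input
+            activations instead of on the expert outputs.
+        activation: Activation function used inside the experts.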
+ """ + + def __init__( + self, + num_experts: int, # Global number of experts + top_k: int, + hidden_size: int, + intermediate_size: int, + params_dtype: Optional[torch.dtype] = None, + reduce_results: bool = False, + renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None, + quant_config: Optional[QuantizationConfig] = None, + tp_size: Optional[int] = None, + ep_size: Optional[int] = None, + dp_size: Optional[int] = None, + prefix: str = "", + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + ): + super().__init__() + + if params_dtype is None: + params_dtype = get_current_vllm_config().model_config.dtype + self.params_dtype = params_dtype + + vllm_config = get_current_vllm_config() + self.moe_parallel_config: FusedMoEParallelConfig = ( + FusedMoEParallelConfig.make( + tp_size_=(tp_size if tp_size is not None else + get_tensor_model_parallel_world_size()), + dp_size_=(dp_size if dp_size is not None else + get_dp_group().world_size), + vllm_parallel_config=vllm_config.parallel_config)) + + self.global_num_experts = num_experts + + # For smuggling this layer into the fused moe custom op + self.use_direct_call = self.dp_size == 1 + if not self.use_direct_call: + compilation_config = vllm_config.compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError("Duplicate layer name: {}".format(prefix)) + compilation_config.static_forward_context[prefix] = self + self.layer_name = prefix + + # Determine expert maps + if self.use_ep: + self.local_num_experts, self.expert_map = determine_expert_map( + ep_size=self.ep_size, + ep_rank=self.ep_rank, + global_num_experts=self.global_num_experts) + else: + self.local_num_experts, self.expert_map = (self.global_num_experts, + None) + + self.top_k = top_k + + assert intermediate_size % self.tp_size == 0 + self.hidden_size = hidden_size + self.intermediate_size_per_partition = intermediate_size // self.tp_size + self.reduce_results = reduce_results + self.renormalize = renormalize + self.use_grouped_topk = use_grouped_topk + if self.use_grouped_topk: + assert num_expert_group is not None and topk_group is not None + self.num_expert_group = num_expert_group + self.topk_group = topk_group + self.custom_routing_function = custom_routing_function + self.scoring_func = scoring_func + self.e_score_correction_bias = e_score_correction_bias + self.apply_router_weight_on_input = apply_router_weight_on_input + self.activation = activation + + if self.scoring_func != "softmax" and not self.use_grouped_topk: + raise ValueError("Only softmax scoring function is supported for " + "non-grouped topk.") + if current_platform.is_hpu(): + from vllm_hpu_extension.ops import DynamicFusedMOE + self.hpu_fused_moe = DynamicFusedMOE(self.global_num_experts) + + moe = MoEConfig( + num_experts=self.global_num_experts, + experts_per_token=top_k, + hidden_dim=hidden_size, + num_local_experts=self.local_num_experts, + moe_parallel_config=self.moe_parallel_config, + # TODO (bnell): this needs to be fixed for quantized types. + in_dtype=params_dtype, + max_num_tokens=MOE_DP_CHUNK_SIZE, + ) + self.moe_config = moe + self.quant_config = quant_config + + # Note: get_quant_method will look at the layer's local_num_experts + # for heuristic purposes, so it must be initialized first. 
+ quant_method: Optional[QuantizeMethodBase] = None + + if quant_config is None: + quant_method = UnquantizedFusedMoEMethod(moe) + else: + quant_method = quant_config.get_quant_method(self, prefix) + + assert quant_method is not None + assert isinstance(quant_method, FusedMoEMethodBase) + self.quant_method = quant_method + + moe_quant_params = { + "num_experts": self.local_num_experts, + "hidden_size": hidden_size, + "intermediate_size_per_partition": + self.intermediate_size_per_partition, + "params_dtype": params_dtype, + "weight_loader": self.weight_loader, + } + # need full intermediate size pre-sharding for WNA16 act order + if (self.quant_method.__class__.__name__ + in ("GPTQMarlinMoEMethod", + "CompressedTensorsWNA16MarlinMoEMethod", + "CompressedTensorsWNA16MoEMethod")): + moe_quant_params["intermediate_size_full"] = intermediate_size + + self.quant_method.create_weights(layer=self, **moe_quant_params) + + @property + def tp_size(self): + return self.moe_parallel_config.tp_size + + @property + def dp_size(self): + return self.moe_parallel_config.dp_size + + @property + def ep_size(self): + return self.moe_parallel_config.ep_size + + @property + def tp_rank(self): + return self.moe_parallel_config.tp_rank + + @property + def dp_rank(self): + return self.moe_parallel_config.dp_rank + + @property + def ep_rank(self): + return self.moe_parallel_config.ep_rank + + @property + def use_ep(self): + return self.moe_parallel_config.use_ep + + @property + def use_pplx_kernels(self): + return self.moe_parallel_config.use_pplx_kernels + + def _load_per_tensor_weight_scale(self, shard_id: str, + param: torch.nn.Parameter, + loaded_weight: torch.Tensor, + expert_id: int): + param_data = param.data + # for per tensor weight quantization + if shard_id in ("w1", "w3"): + # We have to keep the weight scales of w1 and w3 because + # we need to re-quantize w1/w3 weights after weight loading. + idx = 0 if shard_id == "w1" else 1 + param_data[expert_id][idx] = loaded_weight + # If we are in the row parallel case (down_proj) + elif shard_id == "w2": + param_data[expert_id] = loaded_weight + + def _load_model_weight_or_group_weight_scale(self, + shard_dim: int, + expert_data: torch.Tensor, + shard_id: str, + loaded_weight: torch.Tensor, + tp_rank: int, + load_full_w2: bool = False): + """ + Load grouped weight scales for group quantization or model weights + :param shard_dim: dimension to shard + :param expert_data: parameter for a particular expert + :param shard_id: either w1, w2, or w3 + :param loaded_weight: checkpoint weight to load into the param + :param tp_rank: tensor parallel rank + :param load_full_w2: whether or not the w2 loaded should be sharded. 
+ """ + if shard_id == "w2": + # In the case where we have actorder/g_idx, we do not partition the + # w2 scales, as indicated by `load_full` argument, for all tp cases + self._load_w2(shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=tp_rank, + load_full=load_full_w2) + elif shard_id in ("w1", "w3"): + self._load_w13(shard_id=shard_id, + shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=tp_rank) + + def _load_per_channel_weight_scale(self, expert_data: torch.Tensor, + shard_dim: int, shard_id: str, + loaded_weight: torch.Tensor, + tp_rank: int): + # for per channel weight quantization + if shard_id == "w2": + expert_data.copy_(loaded_weight) + elif shard_id in ("w1", "w3"): + self._load_w13(shard_id=shard_id, + shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=tp_rank) + + def _load_w13(self, expert_data: torch.Tensor, shard_dim: int, + shard_id: str, loaded_weight: torch.Tensor, tp_rank: int): + + # Index the loaded weight for tp sharding. + # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim + shard_size = expert_data.shape[shard_dim] // 2 + loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank, + shard_size) + # Narrow parameter and load. + # w1, gate_proj: Load into first logical weight of w13. + if shard_id == "w1": + expert_data = expert_data.narrow(shard_dim, 0, shard_size) + # w3, up_proj: Load into second logical weight of w13. + else: + assert shard_id == "w3" + expert_data = expert_data.narrow(shard_dim, shard_size, shard_size) + expert_data.copy_(loaded_weight) + + def _load_w2(self, + expert_data: torch.Tensor, + shard_dim: int, + loaded_weight: torch.Tensor, + tp_rank: int, + load_full: bool = False): + + # Index the loaded weight for tp sharding. + # down_proj: "RowParallel" so tp sharding on input_dim + # Narrow parameter and load. + shard_size = expert_data.shape[shard_dim] + if not load_full: + loaded_weight = loaded_weight.narrow(shard_dim, + shard_size * tp_rank, + shard_size) + # w2, down_proj: Load into only logical weight of w2. + expert_data.copy_(loaded_weight) + + def _load_single_value(self, param: torch.nn.Parameter, + loaded_weight: torch.Tensor, expert_id: int): + param_data = param.data + + # Input scales can be loaded directly and should be equal. 
+ param_data[expert_id] = loaded_weight + + def _load_g_idx(self, shard_id: str, expert_data: torch.Tensor, + shard_dim: int, loaded_weight: torch.Tensor, tp_rank: int): + + if shard_id == "w2": + self._load_w2(shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=tp_rank) + else: + assert shard_id in ("w1", "w3") + expert_data.copy_(loaded_weight) + + def _map_global_expert_id_to_local_expert_id(self, expert_id: int) -> int: + if self.expert_map is None: + return expert_id + return self.expert_map[expert_id].item() + + def weight_loader(self, param: torch.nn.Parameter, + loaded_weight: torch.Tensor, weight_name: str, + shard_id: str, expert_id: int) -> None: + + expert_id = self._map_global_expert_id_to_local_expert_id(expert_id) + if expert_id == -1: + return + quant_method_name = self.quant_method.__class__.__name__ + # compressed-tensors checkpoints with packed weights are stored flipped + # TODO (mgoin): check self.quant_method.quant_config.quant_format + # against known CompressionFormat enum values that have this quality + if self.quant_method.__class__.__name__ in ( + "CompressedTensorsWNA16MarlinMoEMethod", + "CompressedTensorsWNA16MoEMethod"): + loaded_weight = loaded_weight.t().contiguous() + + if shard_id not in ("w1", "w2", "w3"): + raise ValueError(f"shard_id must be ['w1','w2','w3'] but " + f"got {shard_id}.") + + WEIGHT_SCALE_SUPPORTED = [ + e.value for e in FusedMoeWeightScaleSupported + ] + # Fetch the dim to shard the parameter/loaded weight + # based on the shard id. This will be whatever + # dimension intermediate_size_per_partition is used. + SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0} + + is_gguf_weight = getattr(param, "is_gguf_weight", False) + is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False) + if is_gguf_weight_type: + param.weight_type = loaded_weight.item() + param.data.copy_(loaded_weight) + return + + # is_transposed: if the dim to shard the weight + # should be flipped. Required by GPTQ, compressed-tensors + # should be whatever dimension intermediate_size_per_partition is + is_transposed = getattr(param, "is_transposed", False) + shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id] + if is_transposed: + shard_dim = int(not shard_dim) + + full_load = len(loaded_weight.shape) == 3 + if full_load: + shard_dim += 1 + + # Materialize GGUF UninitializedParameter + if is_gguf_weight and isinstance(param, UninitializedParameter): + final_shape = list(loaded_weight.shape) + if shard_id in ["w1", "w3"]: + final_shape[1] *= 2 + final_shape[shard_dim] = final_shape[shard_dim] // self.tp_size + param.materialize(final_shape, dtype=loaded_weight.dtype) + + expert_data = param.data if full_load else param.data[expert_id] + # Case input scale: input_scale loading is only supported for fp8 + if "input_scale" in weight_name: + # this is needed for compressed-tensors only + loaded_weight = loaded_weight.to(param.data.device) + + if ("compressed" in quant_method_name.lower() + and param.data[expert_id] != 1 + and (param.data[expert_id] - loaded_weight).abs() > 1e-5): + raise ValueError( + "input_scales of w1 and w3 of a layer " + f"must be equal. But got {param.data[expert_id]} " + f"vs. 
{loaded_weight}") + + self._load_single_value(param=param, + loaded_weight=loaded_weight, + expert_id=expert_id) + return + + # Case g_idx + if "g_idx" in weight_name: + self._load_g_idx(shard_dim=0, + shard_id=shard_id, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=self.tp_rank) + return + + if "ModelOpt" in quant_method_name: + if ('weight_scale_2' in weight_name + or 'input_scale' in weight_name): + self._load_per_tensor_weight_scale(shard_id=shard_id, + param=param, + loaded_weight=loaded_weight, + expert_id=expert_id) + elif "weight" in weight_name: + self._load_model_weight_or_group_weight_scale( + shard_id=shard_id, + shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=self.tp_rank) + return + + # Case weight scales, zero_points and offset + if ("scale" in weight_name or "zero" in weight_name + or "offset" in weight_name): + # load the weight scales and zp based on the quantization scheme + # supported weight scales/zp can be found in + # FusedMoeWeightScaleSupported + # TODO @dsikka: once hardened, refactor to use vLLM Parameters + # specific to each case + quant_method = getattr(param, "quant_method", None) + if quant_method == FusedMoeWeightScaleSupported.CHANNEL.value: + self._load_per_channel_weight_scale( + shard_id=shard_id, + shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=self.tp_rank) + elif quant_method in [ + FusedMoeWeightScaleSupported.GROUP.value, + FusedMoeWeightScaleSupported.BLOCK.value, + ]: + self._load_model_weight_or_group_weight_scale( + shard_id=shard_id, + shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=self.tp_rank, + load_full_w2=getattr(param, "load_full_w2", False)) + elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value: + self._load_per_tensor_weight_scale(shard_id=shard_id, + param=param, + loaded_weight=loaded_weight, + expert_id=expert_id) + else: + raise ValueError( + f"quant method must be one of {WEIGHT_SCALE_SUPPORTED}") + return + + # Case weight_shape + if "weight_shape" in weight_name: + # only required by compressed-tensors + self._load_single_value(param=param, + loaded_weight=loaded_weight, + expert_id=expert_id) + return + + # Case model weights + if "weight" in weight_name: + self._load_model_weight_or_group_weight_scale( + shard_id=shard_id, + shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=self.tp_rank) + return + + @staticmethod + def select_experts(hidden_states: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + use_grouped_topk: bool, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + indices_type: Optional[torch.dtype] = None): + from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk + + # DeekSeekv2 uses grouped_top_k + if use_grouped_topk: + assert topk_group is not None + assert num_expert_group is not None + topk_weights, topk_ids = grouped_topk( + hidden_states=hidden_states, + gating_output=router_logits, + topk=top_k, + renormalize=renormalize, + num_expert_group=num_expert_group, + topk_group=topk_group, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias) + if indices_type is not None: + topk_ids = topk_ids.to(dtype=indices_type) + elif custom_routing_function is None: + topk_weights, 
topk_ids, token_expert_indices = fused_topk( + hidden_states=hidden_states, + gating_output=router_logits, + topk=top_k, + renormalize=renormalize, + indices_type=indices_type, + ) + else: + topk_weights, topk_ids = custom_routing_function( + hidden_states=hidden_states, + gating_output=router_logits, + topk=top_k, + renormalize=renormalize) + if indices_type is not None: + topk_ids = topk_ids.to(dtype=indices_type) + + return topk_weights, topk_ids + + def must_reduce_shared_expert_outputs(self) -> bool: + """ + The shared_experts are typically computed using the RowParallelLinear + layer. The result of this function is typically used as + the reduce_results argument to the module. + When just tensor-parallel is used, it is not required to reduce + the shared_experts results immediately. Instead we reduce at the + once at the end of the MoE op. (Refer to DeepSeekV2MoE module) + With EP and the pplx kernels - this is no longer viable as all + GPU ranks in DP, produce the complete set of hidden_states. + Therefore it is required that we reduce the shared_experts output + early. + """ + return self.use_pplx_kernels + + def maybe_all_reduce_tensor_model_parallel( + self, final_hidden_states: torch.Tensor): + """ + The pplx combine kernel reduces across GPU ranks by default. + """ + if self.use_pplx_kernels: + return final_hidden_states + else: + return tensor_model_parallel_all_reduce(final_hidden_states) + + def forward(self, hidden_states: torch.Tensor, + router_logits: torch.Tensor): + if self.use_direct_call: + return self.forward_impl(hidden_states, router_logits) + else: + return torch.ops.vllm.moe_forward(hidden_states, router_logits, + self.layer_name) + + def forward_impl_chunked(self, full_hidden_states: torch.Tensor, + full_router_logits: torch.Tensor): + + full_final_hidden_states = torch.empty_like(full_hidden_states) + + def process_chunk(chunk_start, chunk_end, skip_result_store=False): + hidden_states = full_hidden_states[chunk_start:chunk_end, :] + router_logits = full_router_logits[chunk_start:chunk_end, :] + + # Matrix multiply. 
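+            # quant_method.apply selects the experts from the router logits
+            # and runs the fused gate_up/down projections for this chunk.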
+ final_hidden_states = self.quant_method.apply( + layer=self, + x=hidden_states, + router_logits=router_logits, + top_k=self.top_k, + renormalize=self.renormalize, + use_grouped_topk=self.use_grouped_topk, + global_num_experts=self.global_num_experts, + expert_map=self.expert_map, + topk_group=self.topk_group, + num_expert_group=self.num_expert_group, + custom_routing_function=self.custom_routing_function, + scoring_func=self.scoring_func, + e_score_correction_bias=self.e_score_correction_bias, + activation=self.activation, + ) + + if not skip_result_store: + full_final_hidden_states[chunk_start:chunk_end, :].copy_( + final_hidden_states) + + ctx = get_forward_context() + max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp_cpu + moe_dp_chunk_size_per_rank = MOE_DP_CHUNK_SIZE + + num_tokens = full_hidden_states.size(0) + for chunk_start_ in range(0, max_tokens_across_dp, + moe_dp_chunk_size_per_rank): + chunk_start = chunk_start_ + chunk_end = min(chunk_start + moe_dp_chunk_size_per_rank, + max_tokens_across_dp) + # clamp start and end + chunk_start = min(chunk_start, num_tokens - 1) + chunk_end = min(chunk_end, num_tokens) + + process_chunk(chunk_start, + chunk_end, + skip_result_store=chunk_start_ >= num_tokens) + + return full_final_hidden_states + + def forward_impl(self, hidden_states: torch.Tensor, + router_logits: torch.Tensor): + assert self.quant_method is not None + if self.moe_parallel_config.use_pplx_kernels: + return self.forward_impl_chunked(hidden_states, router_logits) + + if self.dp_size > 1: + hidden_states, router_logits = get_ep_group().dispatch( + hidden_states, router_logits) + # Matrix multiply. + final_hidden_states = self.quant_method.apply( + layer=self, + x=hidden_states, + router_logits=router_logits, + top_k=self.top_k, + renormalize=self.renormalize, + use_grouped_topk=self.use_grouped_topk, + global_num_experts=self.global_num_experts, + expert_map=self.expert_map, + topk_group=self.topk_group, + num_expert_group=self.num_expert_group, + custom_routing_function=self.custom_routing_function, + scoring_func=self.scoring_func, + e_score_correction_bias=self.e_score_correction_bias, + activation=self.activation, + apply_router_weight_on_input=self.apply_router_weight_on_input, + ) + + if self.dp_size > 1: + final_hidden_states = get_ep_group().combine(final_hidden_states) + + if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): + # Default set to False. (May have to add shared expert outputs.) 
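+            # Partial expert outputs are summed across the tensor-parallel
+            # group here.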
+ final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states) + + return final_hidden_states + + @classmethod + def make_expert_params_mapping( + cls, ckpt_gate_proj_name: str, ckpt_down_proj_name: str, + ckpt_up_proj_name: str, + num_experts: int) -> list[tuple[str, str, int, str]]: + + return [ + # (param_name, weight_name, expert_id, shard_id) + ("experts.w13_" if weight_name + in [ckpt_gate_proj_name, ckpt_up_proj_name] else "experts.w2_", + f"experts.{expert_id}.{weight_name}.", expert_id, shard_id) + for expert_id in range(num_experts) for shard_id, weight_name in [ + ("w1", ckpt_gate_proj_name), + ("w2", ckpt_down_proj_name), + ("w3", ckpt_up_proj_name), + ] + ] + + def extra_repr(self) -> str: + + s = ( + f"global_num_experts={self.global_num_experts}, " + f"local_num_experts={self.local_num_experts}, " + f"top_k={self.top_k}, " + f"intermediate_size_per_partition={self.intermediate_size_per_partition}, " # noqa: E501 + f"tp_size={self.tp_size},\n" + f"ep_size={self.ep_size}, " + f"reduce_results={self.reduce_results}, " + f"renormalize={self.renormalize}, " + f"use_grouped_topk={self.use_grouped_topk}") + + if self.use_grouped_topk: + s += f", num_expert_group={self.num_expert_group}, topk_group={self.topk_group}" # noqa: E501 + + s += f", scoring_func='{self.scoring_func}', activation='{self.activation}'" # noqa: E501 + + return s -- Gitee From d8011fb81ef281363baeb143a86c651534bfd0d0 Mon Sep 17 00:00:00 2001 From: lvhaoyu Date: Fri, 20 Jun 2025 11:41:48 +0800 Subject: [PATCH 03/77] update --- .../layers/fused_moe/fused_moe.py | 86 +++++ .../model_executor/layers/fused_moe/layer.py | 321 +++++++++++++----- 2 files changed, 314 insertions(+), 93 deletions(-) diff --git a/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py b/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py index e69de29b..b5c20ef2 100644 --- a/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py @@ -0,0 +1,86 @@ +from typing import Optional + +from mindspore import Tensor +from mindspore.ops.auto_generate import FusedAddTopKDiv +import mindspore as ms +def fused_topk( + hidden_states: Tensor, + gating_output: Tensor, + topk: int, + renormalize: bool, + indices_type = None, +) -> tuple[Tensor, Tensor, Tensor]: + assert hidden_states.shape[0] == gating_output.shape[0], ( + "Number of tokens mismatch") + fused_add_topk_div = FusedAddTopKDiv() + e_score_correction_bias = 0 + num_expert_group = 0 + topk_group = 0 + scoring_type = 0 # softmax + group_max_topk = 2 + topk_weights, topk_ids = fused_add_topk_div( + gating_output, + e_score_correction_bias, + num_expert_group, + topk_group, + topk, + group_max_topk, + scoring_type, + renormalize) + if indices_type is not None: + topk_ids = topk_ids.to(indices_type) + return topk_weights, topk_ids + + +def grouped_topk( + hidden_states: Tensor, + gating_output: Tensor, + topk: int, + renormalize: bool, + num_expert_group: int = 0, + topk_group: int = 0, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[Tensor] = None +) -> tuple[Tensor, Tensor]: + fused_add_topk_div = FusedAddTopKDiv() + assert hidden_states.shape[0] == gating_output.shape[0], ( + "Number of tokens mismatch") + scoring_type = 0 # sigmoid + group_max_topk = 2 + topk_weights, topk_ids = fused_add_topk_div( + gating_output, + e_score_correction_bias, + num_expert_group, + topk_group, + topk, + group_max_topk, + scoring_type, + renormalize) + + return topk_weights.to(ms.float32), 
topk_ids.to(ms.int32) + + +def fused_experts(hidden_states: Tensor, + w1: Tensor, + w2: Tensor, + topk_weights: Tensor, + topk_ids: Tensor, + inplace: bool = False, + activation: str = "silu", + apply_router_weight_on_input: bool = False, + use_fp8_w8a8: bool = False, + use_int8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + per_channel_quant: bool = False, + global_num_experts: int = -1, + expert_map: Optional[Tensor] = None, + w1_scale: Optional[Tensor] = None, + w2_scale: Optional[Tensor] = None, + w1_zp: Optional[Tensor] = None, + w2_zp: Optional[Tensor] = None, + a1_scale: Optional[Tensor] = None, + a2_scale: Optional[Tensor] = None, + block_shape: Optional[list[int]] = None, + allow_deep_gemm: bool = False) -> Tensor: + ... \ No newline at end of file diff --git a/vllm_mindspore/model_executor/layers/fused_moe/layer.py b/vllm_mindspore/model_executor/layers/fused_moe/layer.py index bb2c02d6..151946fc 100644 --- a/vllm_mindspore/model_executor/layers/fused_moe/layer.py +++ b/vllm_mindspore/model_executor/layers/fused_moe/layer.py @@ -7,7 +7,6 @@ from enum import Enum from typing import Callable, Optional import torch -import torch.nn.functional as F from torch.nn.parameter import UninitializedParameter import vllm.envs as envs @@ -29,8 +28,208 @@ from vllm.platforms.interface import CpuArchEnum from vllm.utils import direct_register_custom_op from vllm.model_executor.layers.fused_moe.layers import FusedMoEParallelConfig +from vllm.model_executor.layers.fused_moe.layers import (determine_expert_map, MoEConfig, + FusedMoeWeightScaleSupported, + FusedMoEMethodBase) + + +from vllm_mindspore.model_executor.layers.fused_moe.fused_moe import (fused_topk, + grouped_topk, + MOE_DP_CHUNK_SIZE, + fused_expert) + +from mindspore import nn, Tensor, Parameter + + + +class UnquantizedFusedMoEMethod(FusedMoEMethodBase, nn.Cell): + """MoE method without quantization.""" + + def __init__(self, moe: MoEConfig): + super().__init__() + self.fused_experts = fused_experts # type: ignore + self.moe = moe + + self.rocm_aiter_fused_experts = None # type: ignore + + def select_gemm_impl( + self, prepare_finalize: Optional[FusedMoEPrepareAndFinalize]): + + assert self.fused_experts == fused_experts + + all2all_manager = get_ep_group().device_communicator.all2all_manager + assert all2all_manager is not None + + experts: Optional[FusedMoEPermuteExpertsUnpermute] = None + + if isinstance(prepare_finalize, + (BatchedPrepareAndFinalize, PplxPrepareAndFinalize)): + logger.debug("BatchedTritonExperts %s", self.moe) + experts = BatchedTritonExperts( + max_num_tokens=MOE_DP_CHUNK_SIZE, + world_size=all2all_manager.world_size, + # dp_size actually means tp_size, bug in pplx kernels + dp_size=all2all_manager.tp_group.world_size, + use_fp8_w8a8=False, + use_int8_w8a8=False, + use_int8_w8a16=False, + use_int4_w4a16=False, + block_shape=None, + ) + else: + logger.debug("TritonExperts %s", self.moe) + experts = TritonExperts( + use_fp8_w8a8=False, + use_int8_w8a8=False, + use_int8_w8a16=False, + use_int4_w4a16=False, + block_shape=None, + per_channel_quant=False, + ) + return experts + + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size_per_partition: int, + params_dtype: torch.dtype, **extra_weight_attrs): + # Fused gate_up_proj (column parallel) + w13_weight = torch.nn.Parameter(torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size, + dtype=params_dtype), + requires_grad=False) + 
layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + # down_proj (row parallel) + w2_weight = torch.nn.Parameter(torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor: + # Pad the weight tensor. This is an optimization on ROCm platform, which + # can benefit from tensors located far enough from one another in memory + if (envs.VLLM_ROCM_MOE_PADDING and current_platform.is_rocm() + and weight.stride(-1) == 1 + and (weight.stride(-2) * weight.element_size()) % 512 == 0): + num_pad = 256 // weight.element_size() + weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad] + torch.cuda.empty_cache() + return weight + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + super().process_weights_after_loading(layer) + + # Padding the weight for better performance on ROCm + layer.w13_weight.data = self._maybe_pad_weight(layer.w13_weight.data) + layer.w2_weight.data = self._maybe_pad_weight(layer.w2_weight.data) + # Lazy import to avoid importing triton. + from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( + shuffle_weights) + + if self.rocm_aiter_moe_enabled: + shuffled_w13, shuffled_w2 = shuffle_weights( + layer.w13_weight.data, layer.w2_weight.data) + + layer.w13_weight.data = shuffled_w13 + layer.w2_weight.data = shuffled_w2 + + if current_platform.is_cpu(): + if current_platform.get_cpu_architecture() == CpuArchEnum.X86: + import intel_extension_for_pytorch as ipex + layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE( + layer.w13_weight, + layer.w2_weight, + use_prepack=envs.VLLM_CPU_MOE_PREPACK, + ) + else: + raise NotImplementedError("CPU MOE only supports x86 arch.") + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + ) -> torch.Tensor: + return self.forward_npu( + x=x, + layer=layer, + router_logits=router_logits, + top_k=top_k, + renormalize=renormalize, + use_grouped_topk=use_grouped_topk, + topk_group=topk_group, + num_expert_group=num_expert_group, + global_num_experts=global_num_experts, + expert_map=expert_map, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input) + + def forward_npu( + self, + layer: torch.nn.Module, + x: torch.Tensor, + use_grouped_topk: bool, + top_k: int, + router_logits: torch.Tensor, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + ) -> 
torch.Tensor: + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + indices_type=torch.uint32 if self.moe.use_pplx_kernels else None) + + return self.fused_experts( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, + global_num_experts=global_num_experts, + expert_map=expert_map, + ) + -from mindspore import nn class FusedMoE(nn.Cell): """FusedMoE layer for MoE models. @@ -59,7 +258,7 @@ class FusedMoE(nn.Cell): top_k: int, hidden_size: int, intermediate_size: int, - params_dtype: Optional[torch.dtype] = None, + params_dtype = None, reduce_results: bool = False, renormalize: bool = True, use_grouped_topk: bool = False, @@ -72,7 +271,7 @@ class FusedMoE(nn.Cell): prefix: str = "", custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", - e_score_correction_bias: Optional[torch.Tensor] = None, + e_score_correction_bias: Optional[Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", ): @@ -133,9 +332,6 @@ class FusedMoE(nn.Cell): if self.scoring_func != "softmax" and not self.use_grouped_topk: raise ValueError("Only softmax scoring function is supported for " "non-grouped topk.") - if current_platform.is_hpu(): - from vllm_hpu_extension.ops import DynamicFusedMOE - self.hpu_fused_moe = DynamicFusedMOE(self.global_num_experts) moe = MoEConfig( num_experts=self.global_num_experts, @@ -213,8 +409,8 @@ class FusedMoE(nn.Cell): return self.moe_parallel_config.use_pplx_kernels def _load_per_tensor_weight_scale(self, shard_id: str, - param: torch.nn.Parameter, - loaded_weight: torch.Tensor, + param: Parameter, + loaded_weight: Tensor, expert_id: int): param_data = param.data # for per tensor weight quantization @@ -229,9 +425,9 @@ class FusedMoE(nn.Cell): def _load_model_weight_or_group_weight_scale(self, shard_dim: int, - expert_data: torch.Tensor, + expert_data: Tensor, shard_id: str, - loaded_weight: torch.Tensor, + loaded_weight: Tensor, tp_rank: int, load_full_w2: bool = False): """ @@ -258,9 +454,9 @@ class FusedMoE(nn.Cell): expert_data=expert_data, tp_rank=tp_rank) - def _load_per_channel_weight_scale(self, expert_data: torch.Tensor, + def _load_per_channel_weight_scale(self, expert_data: Tensor, shard_dim: int, shard_id: str, - loaded_weight: torch.Tensor, + loaded_weight: Tensor, tp_rank: int): # for per channel weight quantization if shard_id == "w2": @@ -272,8 +468,8 @@ class FusedMoE(nn.Cell): expert_data=expert_data, tp_rank=tp_rank) - def _load_w13(self, expert_data: torch.Tensor, shard_dim: int, - shard_id: str, loaded_weight: torch.Tensor, tp_rank: int): + def _load_w13(self, expert_data: Tensor, shard_dim: int, + shard_id: str, loaded_weight: Tensor, tp_rank: int): # Index the loaded weight for tp sharding. 
# gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim @@ -291,9 +487,9 @@ class FusedMoE(nn.Cell): expert_data.copy_(loaded_weight) def _load_w2(self, - expert_data: torch.Tensor, + expert_data: Tensor, shard_dim: int, - loaded_weight: torch.Tensor, + loaded_weight: Tensor, tp_rank: int, load_full: bool = False): @@ -309,14 +505,14 @@ class FusedMoE(nn.Cell): expert_data.copy_(loaded_weight) def _load_single_value(self, param: torch.nn.Parameter, - loaded_weight: torch.Tensor, expert_id: int): + loaded_weight: Tensor, expert_id: int): param_data = param.data # Input scales can be loaded directly and should be equal. param_data[expert_id] = loaded_weight - def _load_g_idx(self, shard_id: str, expert_data: torch.Tensor, - shard_dim: int, loaded_weight: torch.Tensor, tp_rank: int): + def _load_g_idx(self, shard_id: str, expert_data: Tensor, + shard_dim: int, loaded_weight: Tensor, tp_rank: int): if shard_id == "w2": self._load_w2(shard_dim=shard_dim, @@ -333,7 +529,7 @@ class FusedMoE(nn.Cell): return self.expert_map[expert_id].item() def weight_loader(self, param: torch.nn.Parameter, - loaded_weight: torch.Tensor, weight_name: str, + loaded_weight: Tensor, weight_name: str, shard_id: str, expert_id: int) -> None: expert_id = self._map_global_expert_id_to_local_expert_id(expert_id) @@ -487,8 +683,8 @@ class FusedMoE(nn.Cell): return @staticmethod - def select_experts(hidden_states: torch.Tensor, - router_logits: torch.Tensor, + def select_experts(hidden_states: Tensor, + router_logits: Tensor, top_k: int, use_grouped_topk: bool, renormalize: bool, @@ -496,9 +692,8 @@ class FusedMoE(nn.Cell): num_expert_group: Optional[int] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", - e_score_correction_bias: Optional[torch.Tensor] = None, + e_score_correction_bias: Optional[Tensor] = None, indices_type: Optional[torch.dtype] = None): - from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk # DeekSeekv2 uses grouped_top_k if use_grouped_topk: @@ -516,7 +711,7 @@ class FusedMoE(nn.Cell): if indices_type is not None: topk_ids = topk_ids.to(dtype=indices_type) elif custom_routing_function is None: - topk_weights, topk_ids, token_expert_indices = fused_topk( + topk_weights, topk_ids = fused_topk( hidden_states=hidden_states, gating_output=router_logits, topk=top_k, @@ -550,79 +745,19 @@ class FusedMoE(nn.Cell): return self.use_pplx_kernels def maybe_all_reduce_tensor_model_parallel( - self, final_hidden_states: torch.Tensor): + self, final_hidden_states: Tensor): """ The pplx combine kernel reduces across GPU ranks by default. """ - if self.use_pplx_kernels: - return final_hidden_states - else: - return tensor_model_parallel_all_reduce(final_hidden_states) - - def forward(self, hidden_states: torch.Tensor, - router_logits: torch.Tensor): - if self.use_direct_call: - return self.forward_impl(hidden_states, router_logits) - else: - return torch.ops.vllm.moe_forward(hidden_states, router_logits, - self.layer_name) - - def forward_impl_chunked(self, full_hidden_states: torch.Tensor, - full_router_logits: torch.Tensor): - - full_final_hidden_states = torch.empty_like(full_hidden_states) - - def process_chunk(chunk_start, chunk_end, skip_result_store=False): - hidden_states = full_hidden_states[chunk_start:chunk_end, :] - router_logits = full_router_logits[chunk_start:chunk_end, :] - - # Matrix multiply. 
- final_hidden_states = self.quant_method.apply( - layer=self, - x=hidden_states, - router_logits=router_logits, - top_k=self.top_k, - renormalize=self.renormalize, - use_grouped_topk=self.use_grouped_topk, - global_num_experts=self.global_num_experts, - expert_map=self.expert_map, - topk_group=self.topk_group, - num_expert_group=self.num_expert_group, - custom_routing_function=self.custom_routing_function, - scoring_func=self.scoring_func, - e_score_correction_bias=self.e_score_correction_bias, - activation=self.activation, - ) - - if not skip_result_store: - full_final_hidden_states[chunk_start:chunk_end, :].copy_( - final_hidden_states) - - ctx = get_forward_context() - max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp_cpu - moe_dp_chunk_size_per_rank = MOE_DP_CHUNK_SIZE - - num_tokens = full_hidden_states.size(0) - for chunk_start_ in range(0, max_tokens_across_dp, - moe_dp_chunk_size_per_rank): - chunk_start = chunk_start_ - chunk_end = min(chunk_start + moe_dp_chunk_size_per_rank, - max_tokens_across_dp) - # clamp start and end - chunk_start = min(chunk_start, num_tokens - 1) - chunk_end = min(chunk_end, num_tokens) - - process_chunk(chunk_start, - chunk_end, - skip_result_store=chunk_start_ >= num_tokens) + return tensor_model_parallel_all_reduce(final_hidden_states) - return full_final_hidden_states + def construct(self, hidden_states: Tensor, + router_logits: Tensor): + return self.forward_impl(hidden_states, router_logits) - def forward_impl(self, hidden_states: torch.Tensor, - router_logits: torch.Tensor): + def forward_impl(self, hidden_states: Tensor, + router_logits: Tensor): assert self.quant_method is not None - if self.moe_parallel_config.use_pplx_kernels: - return self.forward_impl_chunked(hidden_states, router_logits) if self.dp_size > 1: hidden_states, router_logits = get_ep_group().dispatch( -- Gitee From 60309c95fcc7240634476cce6399c56e6f2641e9 Mon Sep 17 00:00:00 2001 From: lvhaoyu Date: Fri, 20 Jun 2025 14:56:20 +0800 Subject: [PATCH 04/77] update --- .../layers/fused_moe/fused_moe.py | 30 ++++++++----------- 1 file changed, 12 insertions(+), 18 deletions(-) diff --git a/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py b/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py index b5c20ef2..c6416b47 100644 --- a/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py @@ -1,6 +1,6 @@ from typing import Optional -from mindspore import Tensor +from mindspore import Tensor, mint from mindspore.ops.auto_generate import FusedAddTopKDiv import mindspore as ms def fused_topk( @@ -9,27 +9,21 @@ def fused_topk( topk: int, renormalize: bool, indices_type = None, -) -> tuple[Tensor, Tensor, Tensor]: +) -> tuple[Tensor, Tensor]: assert hidden_states.shape[0] == gating_output.shape[0], ( "Number of tokens mismatch") - fused_add_topk_div = FusedAddTopKDiv() - e_score_correction_bias = 0 - num_expert_group = 0 - topk_group = 0 - scoring_type = 0 # softmax - group_max_topk = 2 - topk_weights, topk_ids = fused_add_topk_div( - gating_output, - e_score_correction_bias, - num_expert_group, - topk_group, - topk, - group_max_topk, - scoring_type, - renormalize) + score = mint.softmax(gating_output, dim=-1) + topk_weights, topk_ids = mint.topk( + score, + k=topk, + dim=-1 + ) + if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + if indices_type is not None: topk_ids = topk_ids.to(indices_type) - return topk_weights, topk_ids + return 
topk_weights.to(ms.float32), topk_ids.to(ms.int32) def grouped_topk( -- Gitee From 68b007b6d28463e64b6fa539919a97bb8aeddf1c Mon Sep 17 00:00:00 2001 From: lvhaoyu Date: Fri, 20 Jun 2025 15:30:58 +0800 Subject: [PATCH 05/77] update --- vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py b/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py index c6416b47..ec1642bd 100644 --- a/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py @@ -3,6 +3,8 @@ from typing import Optional from mindspore import Tensor, mint from mindspore.ops.auto_generate import FusedAddTopKDiv import mindspore as ms + + def fused_topk( hidden_states: Tensor, gating_output: Tensor, @@ -40,14 +42,14 @@ def grouped_topk( assert hidden_states.shape[0] == gating_output.shape[0], ( "Number of tokens mismatch") scoring_type = 0 # sigmoid - group_max_topk = 2 + topk_in_group = 2 topk_weights, topk_ids = fused_add_topk_div( gating_output, e_score_correction_bias, num_expert_group, topk_group, topk, - group_max_topk, + topk_in_group, scoring_type, renormalize) -- Gitee From 52423267154cb3ca82627e8eaa9b92096a55eac7 Mon Sep 17 00:00:00 2001 From: lvhaoyu Date: Mon, 23 Jun 2025 12:08:30 +0800 Subject: [PATCH 06/77] update --- .../device_communicators/__init__.py | 0 .../device_communicators/npu_communicator.py | 4 + .../layers/fused_moe/fused_moe.py | 46 ++- .../model_executor/layers/fused_moe/layer.py | 354 ++++-------------- vllm_mindspore/platforms/ascend.py | 2 +- 5 files changed, 129 insertions(+), 277 deletions(-) create mode 100644 vllm_mindspore/distributed/device_communicators/__init__.py create mode 100644 vllm_mindspore/distributed/device_communicators/npu_communicator.py diff --git a/vllm_mindspore/distributed/device_communicators/__init__.py b/vllm_mindspore/distributed/device_communicators/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/vllm_mindspore/distributed/device_communicators/npu_communicator.py b/vllm_mindspore/distributed/device_communicators/npu_communicator.py new file mode 100644 index 00000000..3885baa9 --- /dev/null +++ b/vllm_mindspore/distributed/device_communicators/npu_communicator.py @@ -0,0 +1,4 @@ +from vllm.distributed.device_communicators.cuda_communicator import CudaCommunicator + +class NPUCommunicator(CudaCommunicator): + ... diff --git a/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py b/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py index ec1642bd..169716d7 100644 --- a/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py @@ -3,7 +3,7 @@ from typing import Optional from mindspore import Tensor, mint from mindspore.ops.auto_generate import FusedAddTopKDiv import mindspore as ms - +from vllm.distributed.parallel_state import get_tp_group def fused_topk( hidden_states: Tensor, @@ -79,4 +79,46 @@ def fused_experts(hidden_states: Tensor, a2_scale: Optional[Tensor] = None, block_shape: Optional[list[int]] = None, allow_deep_gemm: bool = False) -> Tensor: - ... 
\ No newline at end of file + use_ep = False + if expert_map is not None: + use_ep = True + + if use_ep: + _run_ep_moe(hidden_states, w1, w2, topk_ids, topk_weights, activation) + else: + _run_tp_moe(hidden_states, w1, w2, topk_ids, topk_weights, activation) + +def _run_activation(x, activation): + if activation == "silu": + return mint.silu(x) + elif activation == "gelu": + return mint.gelu(x) + else: + raise ValueError(f"Unsupported activation function: {activation}") + + +def _run_ep_moe(hidden_states, + w1, + w2, + group_list, + group_logits, + activation): + hidden_states = mint.group_matmul(hidden_states, w1, group_list) + hidden_states = _run_activation(hidden_states, activation) + hidden_states = mint.group_matmul(hidden_states, w2, group_list) + hidden_states = mint.mul(hidden_states, group_logits) + return hidden_states + + +def _run_tp_moe(hidden_states, + w1, + w2, + group_list, + group_logits, + activation): + hidden_states = mint.group_matmul(hidden_states, w1, group_list) + hidden_states = _run_activation(hidden_states, activation) + hidden_states = mint.group_matmul(hidden_states, w2, group_list) + hidden_states = mint.all_reduce(hidden_states, get_tp_group()) + hidden_states = mint.mul(hidden_states, group_logits) + return hidden_states diff --git a/vllm_mindspore/model_executor/layers/fused_moe/layer.py b/vllm_mindspore/model_executor/layers/fused_moe/layer.py index 151946fc..5bc0aef7 100644 --- a/vllm_mindspore/model_executor/layers/fused_moe/layer.py +++ b/vllm_mindspore/model_executor/layers/fused_moe/layer.py @@ -36,9 +36,9 @@ from vllm.model_executor.layers.fused_moe.layers import (determine_expert_map, M from vllm_mindspore.model_executor.layers.fused_moe.fused_moe import (fused_topk, grouped_topk, MOE_DP_CHUNK_SIZE, - fused_expert) + fused_experts) -from mindspore import nn, Tensor, Parameter +from mindspore import nn, Tensor, Parameter, mint @@ -50,124 +50,47 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, nn.Cell): self.fused_experts = fused_experts # type: ignore self.moe = moe - self.rocm_aiter_fused_experts = None # type: ignore - - def select_gemm_impl( - self, prepare_finalize: Optional[FusedMoEPrepareAndFinalize]): - - assert self.fused_experts == fused_experts - - all2all_manager = get_ep_group().device_communicator.all2all_manager - assert all2all_manager is not None - - experts: Optional[FusedMoEPermuteExpertsUnpermute] = None - - if isinstance(prepare_finalize, - (BatchedPrepareAndFinalize, PplxPrepareAndFinalize)): - logger.debug("BatchedTritonExperts %s", self.moe) - experts = BatchedTritonExperts( - max_num_tokens=MOE_DP_CHUNK_SIZE, - world_size=all2all_manager.world_size, - # dp_size actually means tp_size, bug in pplx kernels - dp_size=all2all_manager.tp_group.world_size, - use_fp8_w8a8=False, - use_int8_w8a8=False, - use_int8_w8a16=False, - use_int4_w4a16=False, - block_shape=None, - ) - else: - logger.debug("TritonExperts %s", self.moe) - experts = TritonExperts( - use_fp8_w8a8=False, - use_int8_w8a8=False, - use_int8_w8a16=False, - use_int4_w4a16=False, - block_shape=None, - per_channel_quant=False, - ) - return experts - - def create_weights(self, layer: torch.nn.Module, num_experts: int, + def create_weights(self, layer: nn.Cell, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, - params_dtype: torch.dtype, **extra_weight_attrs): + params_dtype, **extra_weight_attrs): # Fused gate_up_proj (column parallel) - w13_weight = torch.nn.Parameter(torch.empty( + w13_weight = Parameter(mint.empty( num_experts, 2 * 
intermediate_size_per_partition, hidden_size, dtype=params_dtype), requires_grad=False) - layer.register_parameter("w13_weight", w13_weight) + layer.insert_param_to_cell("w13_weight", w13_weight) set_weight_attrs(w13_weight, extra_weight_attrs) # down_proj (row parallel) - w2_weight = torch.nn.Parameter(torch.empty( + w2_weight = Parameter(mint.empty( num_experts, hidden_size, intermediate_size_per_partition, dtype=params_dtype), requires_grad=False) - layer.register_parameter("w2_weight", w2_weight) + layer.insert_param_to_cell("w2_weight", w2_weight) set_weight_attrs(w2_weight, extra_weight_attrs) - def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor: - # Pad the weight tensor. This is an optimization on ROCm platform, which - # can benefit from tensors located far enough from one another in memory - if (envs.VLLM_ROCM_MOE_PADDING and current_platform.is_rocm() - and weight.stride(-1) == 1 - and (weight.stride(-2) * weight.element_size()) % 512 == 0): - num_pad = 256 // weight.element_size() - weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad] - torch.cuda.empty_cache() - return weight - - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - super().process_weights_after_loading(layer) - - # Padding the weight for better performance on ROCm - layer.w13_weight.data = self._maybe_pad_weight(layer.w13_weight.data) - layer.w2_weight.data = self._maybe_pad_weight(layer.w2_weight.data) - # Lazy import to avoid importing triton. - from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( - shuffle_weights) - - if self.rocm_aiter_moe_enabled: - shuffled_w13, shuffled_w2 = shuffle_weights( - layer.w13_weight.data, layer.w2_weight.data) - - layer.w13_weight.data = shuffled_w13 - layer.w2_weight.data = shuffled_w2 - - if current_platform.is_cpu(): - if current_platform.get_cpu_architecture() == CpuArchEnum.X86: - import intel_extension_for_pytorch as ipex - layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE( - layer.w13_weight, - layer.w2_weight, - use_prepack=envs.VLLM_CPU_MOE_PREPACK, - ) - else: - raise NotImplementedError("CPU MOE only supports x86 arch.") - def apply( self, - layer: torch.nn.Module, - x: torch.Tensor, - router_logits: torch.Tensor, + layer: nn.Cell, + x: Tensor, + router_logits: Tensor, top_k: int, renormalize: bool, use_grouped_topk: bool = False, topk_group: Optional[int] = None, num_expert_group: Optional[int] = None, global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, + expert_map: Optional[Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", - e_score_correction_bias: Optional[torch.Tensor] = None, + e_score_correction_bias: Optional[Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", - ) -> torch.Tensor: + ) -> Tensor: return self.forward_npu( x=x, layer=layer, @@ -187,22 +110,22 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, nn.Cell): def forward_npu( self, - layer: torch.nn.Module, - x: torch.Tensor, + layer: nn.Cell, + x: Tensor, use_grouped_topk: bool, top_k: int, - router_logits: torch.Tensor, + router_logits: Tensor, renormalize: bool, topk_group: Optional[int] = None, num_expert_group: Optional[int] = None, global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, + expert_map: Optional[Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", - e_score_correction_bias: Optional[torch.Tensor] = None, + e_score_correction_bias: 
Optional[Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", - ) -> torch.Tensor: + ) -> Tensor: topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, router_logits=router_logits, @@ -408,126 +331,103 @@ class FusedMoE(nn.Cell): def use_pplx_kernels(self): return self.moe_parallel_config.use_pplx_kernels - def _load_per_tensor_weight_scale(self, shard_id: str, - param: Parameter, - loaded_weight: Tensor, - expert_id: int): - param_data = param.data - # for per tensor weight quantization - if shard_id in ("w1", "w3"): - # We have to keep the weight scales of w1 and w3 because - # we need to re-quantize w1/w3 weights after weight loading. - idx = 0 if shard_id == "w1" else 1 - param_data[expert_id][idx] = loaded_weight - # If we are in the row parallel case (down_proj) - elif shard_id == "w2": - param_data[expert_id] = loaded_weight - - def _load_model_weight_or_group_weight_scale(self, - shard_dim: int, - expert_data: Tensor, - shard_id: str, - loaded_weight: Tensor, - tp_rank: int, - load_full_w2: bool = False): - """ - Load grouped weight scales for group quantization or model weights - :param shard_dim: dimension to shard - :param expert_data: parameter for a particular expert - :param shard_id: either w1, w2, or w3 - :param loaded_weight: checkpoint weight to load into the param - :param tp_rank: tensor parallel rank - :param load_full_w2: whether or not the w2 loaded should be sharded. - """ - if shard_id == "w2": - # In the case where we have actorder/g_idx, we do not partition the - # w2 scales, as indicated by `load_full` argument, for all tp cases - self._load_w2(shard_dim=shard_dim, - loaded_weight=loaded_weight, - expert_data=expert_data, - tp_rank=tp_rank, - load_full=load_full_w2) - elif shard_id in ("w1", "w3"): - self._load_w13(shard_id=shard_id, - shard_dim=shard_dim, - loaded_weight=loaded_weight, - expert_data=expert_data, - tp_rank=tp_rank) - - def _load_per_channel_weight_scale(self, expert_data: Tensor, - shard_dim: int, shard_id: str, - loaded_weight: Tensor, - tp_rank: int): - # for per channel weight quantization - if shard_id == "w2": - expert_data.copy_(loaded_weight) - elif shard_id in ("w1", "w3"): - self._load_w13(shard_id=shard_id, - shard_dim=shard_dim, - loaded_weight=loaded_weight, - expert_data=expert_data, - tp_rank=tp_rank) - - def _load_w13(self, expert_data: Tensor, shard_dim: int, - shard_id: str, loaded_weight: Tensor, tp_rank: int): + def _load_w13(self, param: Parameter, shard_dim: int, + shard_id: str, loaded_weight: Tensor, expert_id: int, + tp_rank: int): # Index the loaded weight for tp sharding. # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim - shard_size = expert_data.shape[shard_dim] // 2 + shard_size = param.shape[shard_dim] // 2 loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank, shard_size) # Narrow parameter and load. # w1, gate_proj: Load into first logical weight of w13. if shard_id == "w1": - expert_data = expert_data.narrow(shard_dim, 0, shard_size) + # expert_data = expert_data.narrow(shard_dim, 0, shard_size) + param[expert_id, ] = loaded_weight # w3, up_proj: Load into second logical weight of w13. 
else: assert shard_id == "w3" - expert_data = expert_data.narrow(shard_dim, shard_size, shard_size) - expert_data.copy_(loaded_weight) + # expert_data = expert_data.narrow(shard_dim, shard_size, shard_size) + param[expert_id, ] = loaded_weight + # expert_data.set_data(loaded_weight) def _load_w2(self, - expert_data: Tensor, + param: Parameter, shard_dim: int, loaded_weight: Tensor, tp_rank: int, - load_full: bool = False): + expert_id: int, + load_full: bool = False) # Index the loaded weight for tp sharding. # down_proj: "RowParallel" so tp sharding on input_dim # Narrow parameter and load. - shard_size = expert_data.shape[shard_dim] + shard_size = param.shape[shard_dim] if not load_full: loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank, shard_size) # w2, down_proj: Load into only logical weight of w2. - expert_data.copy_(loaded_weight) + param[expert_id] = loaded_weight - def _load_single_value(self, param: torch.nn.Parameter, + def _load_single_value(self, param: Parameter, loaded_weight: Tensor, expert_id: int): - param_data = param.data + param[expert_id] = loaded_weight - # Input scales can be loaded directly and should be equal. - param_data[expert_id] = loaded_weight - - def _load_g_idx(self, shard_id: str, expert_data: Tensor, - shard_dim: int, loaded_weight: Tensor, tp_rank: int): + def _load_g_idx(self, shard_id: str, param: Parameter, + shard_dim: int, loaded_weight: Tensor, tp_rank: int, + expert_id: int): if shard_id == "w2": self._load_w2(shard_dim=shard_dim, loaded_weight=loaded_weight, - expert_data=expert_data, + param=param, + expert_id=expert_id, tp_rank=tp_rank) else: assert shard_id in ("w1", "w3") - expert_data.copy_(loaded_weight) + param[expert_id] = loaded_weight def _map_global_expert_id_to_local_expert_id(self, expert_id: int) -> int: if self.expert_map is None: return expert_id return self.expert_map[expert_id].item() + def _load_model_weight_or_group_weight_scale(self, + shard_dim: int, + param: Parameter, + shard_id: str, + loaded_weight: Tensor, + tp_rank: int, + expert_id: int, + load_full_w2: bool = False): + """ + Load grouped weight scales for group quantization or model weights + :param shard_dim: dimension to shard + :param expert_data: parameter for a particular expert + :param shard_id: either w1, w2, or w3 + :param loaded_weight: checkpoint weight to load into the param + :param tp_rank: tensor parallel rank + :param load_full_w2: whether or not the w2 loaded should be sharded. 
+ """ + if shard_id == "w2": + # In the case where we have actorder/g_idx, we do not partition the + # w2 scales, as indicated by `load_full` argument, for all tp cases + self._load_w2(shard_dim=shard_dim, + loaded_weight=loaded_weight, + param=param, + tp_rank=tp_rank, + expert_id=expert_id, + load_full=load_full_w2) + elif shard_id in ("w1", "w3"): + self._load_w13(shard_id=shard_id, + shard_dim=shard_dim, + loaded_weight=loaded_weight, + param=param, + expert_id=expert_id, + tp_rank=tp_rank) + def weight_loader(self, param: torch.nn.Parameter, loaded_weight: Tensor, weight_name: str, shard_id: str, expert_id: int) -> None: @@ -535,14 +435,6 @@ class FusedMoE(nn.Cell): expert_id = self._map_global_expert_id_to_local_expert_id(expert_id) if expert_id == -1: return - quant_method_name = self.quant_method.__class__.__name__ - # compressed-tensors checkpoints with packed weights are stored flipped - # TODO (mgoin): check self.quant_method.quant_config.quant_format - # against known CompressionFormat enum values that have this quality - if self.quant_method.__class__.__name__ in ( - "CompressedTensorsWNA16MarlinMoEMethod", - "CompressedTensorsWNA16MoEMethod"): - loaded_weight = loaded_weight.t().contiguous() if shard_id not in ("w1", "w2", "w3"): raise ValueError(f"shard_id must be ['w1','w2','w3'] but " @@ -556,13 +448,6 @@ class FusedMoE(nn.Cell): # dimension intermediate_size_per_partition is used. SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0} - is_gguf_weight = getattr(param, "is_gguf_weight", False) - is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False) - if is_gguf_weight_type: - param.weight_type = loaded_weight.item() - param.data.copy_(loaded_weight) - return - # is_transposed: if the dim to shard the weight # should be flipped. Required by GPTQ, compressed-tensors # should be whatever dimension intermediate_size_per_partition is @@ -575,93 +460,14 @@ class FusedMoE(nn.Cell): if full_load: shard_dim += 1 - # Materialize GGUF UninitializedParameter - if is_gguf_weight and isinstance(param, UninitializedParameter): - final_shape = list(loaded_weight.shape) - if shard_id in ["w1", "w3"]: - final_shape[1] *= 2 - final_shape[shard_dim] = final_shape[shard_dim] // self.tp_size - param.materialize(final_shape, dtype=loaded_weight.dtype) - - expert_data = param.data if full_load else param.data[expert_id] - # Case input scale: input_scale loading is only supported for fp8 - if "input_scale" in weight_name: - # this is needed for compressed-tensors only - loaded_weight = loaded_weight.to(param.data.device) - - if ("compressed" in quant_method_name.lower() - and param.data[expert_id] != 1 - and (param.data[expert_id] - loaded_weight).abs() > 1e-5): - raise ValueError( - "input_scales of w1 and w3 of a layer " - f"must be equal. But got {param.data[expert_id]} " - f"vs. 
{loaded_weight}") - - self._load_single_value(param=param, - loaded_weight=loaded_weight, - expert_id=expert_id) - return - # Case g_idx if "g_idx" in weight_name: self._load_g_idx(shard_dim=0, shard_id=shard_id, loaded_weight=loaded_weight, - expert_data=expert_data, - tp_rank=self.tp_rank) - return - - if "ModelOpt" in quant_method_name: - if ('weight_scale_2' in weight_name - or 'input_scale' in weight_name): - self._load_per_tensor_weight_scale(shard_id=shard_id, - param=param, - loaded_weight=loaded_weight, - expert_id=expert_id) - elif "weight" in weight_name: - self._load_model_weight_or_group_weight_scale( - shard_id=shard_id, - shard_dim=shard_dim, - loaded_weight=loaded_weight, - expert_data=expert_data, - tp_rank=self.tp_rank) - return - - # Case weight scales, zero_points and offset - if ("scale" in weight_name or "zero" in weight_name - or "offset" in weight_name): - # load the weight scales and zp based on the quantization scheme - # supported weight scales/zp can be found in - # FusedMoeWeightScaleSupported - # TODO @dsikka: once hardened, refactor to use vLLM Parameters - # specific to each case - quant_method = getattr(param, "quant_method", None) - if quant_method == FusedMoeWeightScaleSupported.CHANNEL.value: - self._load_per_channel_weight_scale( - shard_id=shard_id, - shard_dim=shard_dim, - loaded_weight=loaded_weight, - expert_data=expert_data, - tp_rank=self.tp_rank) - elif quant_method in [ - FusedMoeWeightScaleSupported.GROUP.value, - FusedMoeWeightScaleSupported.BLOCK.value, - ]: - self._load_model_weight_or_group_weight_scale( - shard_id=shard_id, - shard_dim=shard_dim, - loaded_weight=loaded_weight, - expert_data=expert_data, - tp_rank=self.tp_rank, - load_full_w2=getattr(param, "load_full_w2", False)) - elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value: - self._load_per_tensor_weight_scale(shard_id=shard_id, - param=param, - loaded_weight=loaded_weight, - expert_id=expert_id) - else: - raise ValueError( - f"quant method must be one of {WEIGHT_SCALE_SUPPORTED}") + param=param, + tp_rank=self.tp_rank, + expert_id=expert_id) return # Case weight_shape @@ -678,7 +484,7 @@ class FusedMoE(nn.Cell): shard_id=shard_id, shard_dim=shard_dim, loaded_weight=loaded_weight, - expert_data=expert_data, + param=param, tp_rank=self.tp_rank) return diff --git a/vllm_mindspore/platforms/ascend.py b/vllm_mindspore/platforms/ascend.py index 43d5d177..7a31885c 100644 --- a/vllm_mindspore/platforms/ascend.py +++ b/vllm_mindspore/platforms/ascend.py @@ -136,7 +136,7 @@ class AscendPlatform(Platform): def get_device_communicator_cls(cls) -> str: """Get device specific communicator class for distributed communication.""" if envs.VLLM_USE_V1: - return "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator" + return "vllm_mindspore.distributed.device_communicators.npu_communicator.NPUCommunicator" return "vllm.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase" @classmethod -- Gitee From 72c90dc355544d975b70a838b1f129f556fc4765 Mon Sep 17 00:00:00 2001 From: lvhaoyu Date: Mon, 23 Jun 2025 15:17:11 +0800 Subject: [PATCH 07/77] update load --- .../model_executor/layers/fused_moe/layer.py | 53 ++++++++++++++----- 1 file changed, 40 insertions(+), 13 deletions(-) diff --git a/vllm_mindspore/model_executor/layers/fused_moe/layer.py b/vllm_mindspore/model_executor/layers/fused_moe/layer.py index 5bc0aef7..f5ea5590 100644 --- a/vllm_mindspore/model_executor/layers/fused_moe/layer.py +++ 
b/vllm_mindspore/model_executor/layers/fused_moe/layer.py @@ -333,7 +333,7 @@ class FusedMoE(nn.Cell): def _load_w13(self, param: Parameter, shard_dim: int, shard_id: str, loaded_weight: Tensor, expert_id: int, - tp_rank: int): + tp_rank: int, load_full: bool = False): # Index the loaded weight for tp sharding. # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim @@ -342,15 +342,38 @@ class FusedMoE(nn.Cell): shard_size) # Narrow parameter and load. # w1, gate_proj: Load into first logical weight of w13. - if shard_id == "w1": - # expert_data = expert_data.narrow(shard_dim, 0, shard_size) - param[expert_id, ] = loaded_weight - # w3, up_proj: Load into second logical weight of w13. + if not load_full: + if shard_id == "w1": + if shard_dim == 1: + param[expert_id, :, 0:shard_size] = loaded_weight + else: + assert shard_dim == 0 + param[expert_id, 0:shard_size, :] = loaded_weight + # w3, up_proj: Load into second logical weight of w13. + else: + assert shard_id == "w3" + # expert_data = expert_data.narrow(shard_dim, shard_size, shard_size) + if shard_dim == 1: + param[expert_id, :, shard_size:shard_size*2] = loaded_weight + else: + assert shard_dim == 0 + param[expert_id, shard_size:shard_size*2, :] = loaded_weight else: - assert shard_id == "w3" - # expert_data = expert_data.narrow(shard_dim, shard_size, shard_size) - param[expert_id, ] = loaded_weight - # expert_data.set_data(loaded_weight) + if shard_id == "w1": + if shard_dim == 2: + param[:, :, 0:shard_size] = loaded_weight + else: + assert shard_dim == 1 + param[:, 0:shard_size, :] = loaded_weight + # w3, up_proj: Load into second logical weight of w13. + else: + assert shard_id == "w3" + # expert_data = expert_data.narrow(shard_dim, shard_size, shard_size) + if shard_dim == 2: + param[:, :, shard_size:shard_size*2] = loaded_weight + else: + assert shard_dim == 1 + param[:, shard_size:shard_size*2, :] = loaded_weight def _load_w2(self, param: Parameter, @@ -358,7 +381,7 @@ class FusedMoE(nn.Cell): loaded_weight: Tensor, tp_rank: int, expert_id: int, - load_full: bool = False) + load_full: bool = False): # Index the loaded weight for tp sharding. # down_proj: "RowParallel" so tp sharding on input_dim @@ -368,8 +391,10 @@ class FusedMoE(nn.Cell): loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank, shard_size) + param.set_data(loaded_weight) # w2, down_proj: Load into only logical weight of w2. 
- param[expert_id] = loaded_weight + else: + param[expert_id] = loaded_weight def _load_single_value(self, param: Parameter, loaded_weight: Tensor, expert_id: int): @@ -401,7 +426,8 @@ class FusedMoE(nn.Cell): loaded_weight: Tensor, tp_rank: int, expert_id: int, - load_full_w2: bool = False): + load_full_w2: bool = False, + load_full_w3: bool = False): """ Load grouped weight scales for group quantization or model weights :param shard_dim: dimension to shard @@ -426,7 +452,8 @@ class FusedMoE(nn.Cell): loaded_weight=loaded_weight, param=param, expert_id=expert_id, - tp_rank=tp_rank) + tp_rank=tp_rank, + load_full=load_full_w3) def weight_loader(self, param: torch.nn.Parameter, loaded_weight: Tensor, weight_name: str, -- Gitee From 31a6c465ae8b71022a81f35c00fc6383a58c448b Mon Sep 17 00:00:00 2001 From: lvhaoyu Date: Mon, 23 Jun 2025 16:23:38 +0800 Subject: [PATCH 08/77] update moe --- .../layers/fused_moe/fused_moe.py | 62 +++++++++++-------- .../model_executor/layers/fused_moe/layer.py | 11 ++-- 2 files changed, 42 insertions(+), 31 deletions(-) diff --git a/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py b/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py index 169716d7..1f6f4b7e 100644 --- a/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py @@ -1,9 +1,9 @@ from typing import Optional -from mindspore import Tensor, mint +from mindspore import Tensor, mint, ops from mindspore.ops.auto_generate import FusedAddTopKDiv import mindspore as ms -from vllm.distributed.parallel_state import get_tp_group +from vllm.distributed.parallel_state import get_ep_group def fused_topk( hidden_states: Tensor, @@ -56,37 +56,39 @@ def grouped_topk( return topk_weights.to(ms.float32), topk_ids.to(ms.int32) +def _ep_dispatch(x, topk_ids): + return mint.distributed.all_to_all(x, topk_ids) + +def _ep_combine(x, topk_ids): + return mint.distributed.all_to_all(x, topk_ids) + def fused_experts(hidden_states: Tensor, w1: Tensor, w2: Tensor, topk_weights: Tensor, topk_ids: Tensor, - inplace: bool = False, activation: str = "silu", - apply_router_weight_on_input: bool = False, - use_fp8_w8a8: bool = False, - use_int8_w8a8: bool = False, - use_int8_w8a16: bool = False, - use_int4_w4a16: bool = False, - per_channel_quant: bool = False, global_num_experts: int = -1, - expert_map: Optional[Tensor] = None, - w1_scale: Optional[Tensor] = None, - w2_scale: Optional[Tensor] = None, - w1_zp: Optional[Tensor] = None, - w2_zp: Optional[Tensor] = None, - a1_scale: Optional[Tensor] = None, - a2_scale: Optional[Tensor] = None, - block_shape: Optional[list[int]] = None, - allow_deep_gemm: bool = False) -> Tensor: + apply_router_weight_on_input: bool = False, + expert_map: Optional[Tensor] = None) -> Tensor: + use_ep = False if expert_map is not None: use_ep = True if use_ep: - _run_ep_moe(hidden_states, w1, w2, topk_ids, topk_weights, activation) + hidden_states = _ep_dispatch(hidden_states, topk_ids) + hidden_states = _run_ep_moe(hidden_states, w1, w2, topk_ids, topk_weights, activation) + hidden_states = _ep_combine(hidden_states, topk_ids) + if apply_router_weight_on_input: + hidden_states = mint.mul(hidden_states, topk_weights) + hidden_states = hidden_states.sum(-1) else: - _run_tp_moe(hidden_states, w1, w2, topk_ids, topk_weights, activation) + hidden_states =_run_tp_moe(hidden_states, w1, w2, topk_ids, topk_weights, activation) + if apply_router_weight_on_input: + hidden_states = mint.mul(hidden_states, topk_weights) + + return 
hidden_states def _run_activation(x, activation): if activation == "silu": @@ -97,16 +99,22 @@ def _run_activation(x, activation): raise ValueError(f"Unsupported activation function: {activation}") +group_matmul_ops = ops.auto_generate.GroupedMatmulV4() + +def _run_group_matmul(hidden_states, weight, group_list): + return group_matmul_ops([hidden_states], [weight], group_list, + None, None, None, None, None, None, + group_list, split_item=3, group_type=0, group_list_type=1) + def _run_ep_moe(hidden_states, w1, w2, group_list, group_logits, activation): - hidden_states = mint.group_matmul(hidden_states, w1, group_list) + hidden_states = _run_group_matmul(hidden_states, w1, group_list) hidden_states = _run_activation(hidden_states, activation) - hidden_states = mint.group_matmul(hidden_states, w2, group_list) - hidden_states = mint.mul(hidden_states, group_logits) + hidden_states = _run_group_matmul(hidden_states, w2, group_list) return hidden_states @@ -116,9 +124,9 @@ def _run_tp_moe(hidden_states, group_list, group_logits, activation): - hidden_states = mint.group_matmul(hidden_states, w1, group_list) + # hidden_states = mint.group_matmul(hidden_states, w1, group_list) + hidden_states = _run_group_matmul([hidden_states], [w1], group_list) hidden_states = _run_activation(hidden_states, activation) - hidden_states = mint.group_matmul(hidden_states, w2, group_list) - hidden_states = mint.all_reduce(hidden_states, get_tp_group()) - hidden_states = mint.mul(hidden_states, group_logits) + hidden_states = _run_group_matmul(hidden_states, w2, group_list) + hidden_states = mint.distributed.all_reduce(hidden_states, get_ep_group()) return hidden_states diff --git a/vllm_mindspore/model_executor/layers/fused_moe/layer.py b/vllm_mindspore/model_executor/layers/fused_moe/layer.py index f5ea5590..9dcc459d 100644 --- a/vllm_mindspore/model_executor/layers/fused_moe/layer.py +++ b/vllm_mindspore/model_executor/layers/fused_moe/layer.py @@ -145,10 +145,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, nn.Cell): w2=layer.w2_weight, topk_weights=topk_weights, topk_ids=topk_ids, - inplace=True, activation=activation, - apply_router_weight_on_input=apply_router_weight_on_input, global_num_experts=global_num_experts, + apply_router_weight_on_input=apply_router_weight_on_input, expert_map=expert_map, ) @@ -592,9 +591,13 @@ class FusedMoE(nn.Cell): router_logits: Tensor): assert self.quant_method is not None - if self.dp_size > 1: + do_naive_dispatch_combine: bool = ( + self.dp_size > 1 + and not self.ep_size > 1) + if do_naive_dispatch_combine: hidden_states, router_logits = get_ep_group().dispatch( hidden_states, router_logits) + # Matrix multiply. 
final_hidden_states = self.quant_method.apply( layer=self, @@ -614,7 +617,7 @@ class FusedMoE(nn.Cell): apply_router_weight_on_input=self.apply_router_weight_on_input, ) - if self.dp_size > 1: + if do_naive_dispatch_combine: final_hidden_states = get_ep_group().combine(final_hidden_states) if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): -- Gitee From 95a1b96501e2a22139af9298cf5a296413e1d3ce Mon Sep 17 00:00:00 2001 From: lvhaoyu Date: Mon, 23 Jun 2025 16:25:17 +0800 Subject: [PATCH 09/77] register model --- vllm_mindspore/model_executor/models/registry.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm_mindspore/model_executor/models/registry.py b/vllm_mindspore/model_executor/models/registry.py index 009d84a0..50dde9a4 100644 --- a/vllm_mindspore/model_executor/models/registry.py +++ b/vllm_mindspore/model_executor/models/registry.py @@ -30,6 +30,7 @@ _NATIVE_MODELS = { "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"), "Qwen2_5_VLForConditionalGeneration": ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"), + "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"), } _MINDFORMERS_MODELS = { -- Gitee From 395791820f33cc1a161bb629e1adc63fe29f311f Mon Sep 17 00:00:00 2001 From: lvhaoyu Date: Mon, 23 Jun 2025 17:21:05 +0800 Subject: [PATCH 10/77] update --- .../layers/fused_moe/__init__.py | 3 +- .../model_executor/layers/fused_moe/layer.py | 70 ++++++++++++++++++- .../model_executor/models/qwen3_moe.py | 2 +- vllm_mindspore/model_executor/models/utils.py | 4 +- 4 files changed, 71 insertions(+), 8 deletions(-) diff --git a/vllm_mindspore/model_executor/layers/fused_moe/__init__.py b/vllm_mindspore/model_executor/layers/fused_moe/__init__.py index a38a67cd..29502460 100644 --- a/vllm_mindspore/model_executor/layers/fused_moe/__init__.py +++ b/vllm_mindspore/model_executor/layers/fused_moe/__init__.py @@ -1,2 +1 @@ -class FusedMoE: - ... \ No newline at end of file +from .layer import FusedMoE \ No newline at end of file diff --git a/vllm_mindspore/model_executor/layers/fused_moe/layer.py b/vllm_mindspore/model_executor/layers/fused_moe/layer.py index 9dcc459d..6b8db5ad 100644 --- a/vllm_mindspore/model_executor/layers/fused_moe/layer.py +++ b/vllm_mindspore/model_executor/layers/fused_moe/layer.py @@ -26,9 +26,9 @@ from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.platforms.interface import CpuArchEnum from vllm.utils import direct_register_custom_op -from vllm.model_executor.layers.fused_moe.layers import FusedMoEParallelConfig +from vllm.model_executor.layers.fused_moe.layer import FusedMoEParallelConfig -from vllm.model_executor.layers.fused_moe.layers import (determine_expert_map, MoEConfig, +from vllm.model_executor.layers.fused_moe.layer import (determine_expert_map, FusedMoeWeightScaleSupported, FusedMoEMethodBase) @@ -40,6 +40,70 @@ from vllm_mindspore.model_executor.layers.fused_moe.fused_moe import (fused_topk from mindspore import nn, Tensor, Parameter, mint +logger = init_logger(__name__) + + +@dataclass +class MoEConfig: + num_experts: int + experts_per_token: int + hidden_dim: int + + num_local_experts: int + moe_parallel_config: FusedMoEParallelConfig + + in_dtype: torch.dtype # The activation type. + quant_dtype: torch.dtype = None + + # TODO: add more quantization params, blocked, per-token, etc. 
+ block_size: int = 128 + + max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE + + def __post_init__(self): + if self.dp_size > 1: + logger.debug("Using MOEConfig::max_num_tokens=%d", + self.max_num_tokens) + + @property + def tp_size(self): + return self.moe_parallel_config.tp_size + + @property + def dp_size(self): + return self.moe_parallel_config.dp_size + + @property + def ep_size(self): + return self.moe_parallel_config.ep_size + + @property + def tp_rank(self): + return self.moe_parallel_config.tp_rank + + @property + def dp_rank(self): + return self.moe_parallel_config.dp_rank + + @property + def ep_rank(self): + return self.moe_parallel_config.ep_rank + + @property + def use_ep(self): + return self.moe_parallel_config.use_ep + + @property + def use_pplx_kernels(self): + return self.moe_parallel_config.use_pplx_kernels + + @property + def use_deepep_ht_kernels(self): + return self.moe_parallel_config.use_deepep_ht_kernels + + @property + def use_deepep_ll_kernels(self): + return self.moe_parallel_config.use_deepep_ll_kernels class UnquantizedFusedMoEMethod(FusedMoEMethodBase, nn.Cell): @@ -137,7 +201,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, nn.Cell): custom_routing_function=custom_routing_function, scoring_func=scoring_func, e_score_correction_bias=e_score_correction_bias, - indices_type=torch.uint32 if self.moe.use_pplx_kernels else None) + indices_type=None) return self.fused_experts( hidden_states=x, diff --git a/vllm_mindspore/model_executor/models/qwen3_moe.py b/vllm_mindspore/model_executor/models/qwen3_moe.py index 27533115..5e25a372 100644 --- a/vllm_mindspore/model_executor/models/qwen3_moe.py +++ b/vllm_mindspore/model_executor/models/qwen3_moe.py @@ -486,7 +486,7 @@ class Qwen3MoeForCausalLM(NativeModel, SupportsPP): fall_back_to_pt_during_load = False def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() + super().__init__(vllm_config=vllm_config, prefix=prefix) config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config self.config = config diff --git a/vllm_mindspore/model_executor/models/utils.py b/vllm_mindspore/model_executor/models/utils.py index 9ecebe2a..56e7b623 100644 --- a/vllm_mindspore/model_executor/models/utils.py +++ b/vllm_mindspore/model_executor/models/utils.py @@ -20,8 +20,8 @@ from dataclasses import dataclass, field from typing import Iterable, List, Mapping, Optional, Tuple, Union -import mindspore as ms, nn -from mindspore import mint, ops +import mindspore as ms +from mindspore import mint, ops, nn from vllm.sequence import IntermediateTensors from vllm_mindspore.multimodal.inputs import NestedTensors # type: ignore[attr-defined] -- Gitee From bb18fc8abae763f00270c6aadd6f226e621578ba Mon Sep 17 00:00:00 2001 From: lvhaoyu Date: Mon, 23 Jun 2025 17:24:04 +0800 Subject: [PATCH 11/77] update --- .../model_executor/layers/fused_moe/layer.py | 142 +++++++++++++++++- 1 file changed, 140 insertions(+), 2 deletions(-) diff --git a/vllm_mindspore/model_executor/layers/fused_moe/layer.py b/vllm_mindspore/model_executor/layers/fused_moe/layer.py index 6b8db5ad..1d1eac54 100644 --- a/vllm_mindspore/model_executor/layers/fused_moe/layer.py +++ b/vllm_mindspore/model_executor/layers/fused_moe/layer.py @@ -26,11 +26,13 @@ from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.platforms.interface import CpuArchEnum from vllm.utils import direct_register_custom_op -from vllm.model_executor.layers.fused_moe.layer import 
FusedMoEParallelConfig +# from vllm.model_executor.layers.fused_moe.layer import FusedMoEParallelConfig from vllm.model_executor.layers.fused_moe.layer import (determine_expert_map, FusedMoeWeightScaleSupported, - FusedMoEMethodBase) + FusedMoEMethodBase, + #MoEConfig, + ) from vllm_mindspore.model_executor.layers.fused_moe.fused_moe import (fused_topk, @@ -43,6 +45,142 @@ from mindspore import nn, Tensor, Parameter, mint logger = init_logger(__name__) +@dataclass +class FusedMoEParallelConfig: + tp_size: int + dp_size: int + ep_size: int + tp_rank: int + dp_rank: int + ep_rank: int + + use_ep: bool # whether to use EP or not + + @property + def use_all2all_kernels(self): + return self.dp_size > 1 and self.use_ep + + @property + def use_pplx_kernels(self): + return (self.use_all2all_kernels + and envs.VLLM_ALL2ALL_BACKEND == "pplx") + + @property + def use_deepep_ht_kernels(self): + return (self.use_all2all_kernels + and envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput") + + @property + def use_deepep_ll_kernels(self): + return (self.use_all2all_kernels + and envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency") + + @staticmethod + def make(tp_size_: int, dp_size_: int, + vllm_parallel_config: ParallelConfig) -> "FusedMoEParallelConfig": + """ + Determine MoE parallel configuration. Based on the input tp_size_, + dp_size_, ep_size_ and vllm's parallel config, determine what + level's of parallelism to use in the fused moe layer. + + Args: + tp_size_ (int): tp_size passed into the FusedMoE constructor. + dp_size_ (int): dp_size passed into the FusedMoE constructor. + ep_size_ (int): ep_size passed into the FusedMoE constructor. + vllm_parallel_config (ParallelConfig): vllm's parallel config + object. + + Examples: + When there is no parallelism requested, i.e. tp_size_ = dp_size_ = 1, + we simply return the sizes unaltered and the ranks set to 0. + + Expert Parallelism is considered only when either dp_size_ or tp_size_ + is non trivial. + + When TP = 2, DP = 1 and EP = False, the configuration on different + devices, + - device 0 : TP = {2, 0} DP = {1, 0} EP = {1, 0} // + legend : {size, rank} + - device 1 : TP = {2, 1} DP = {1, 0} EP = {1, 0} + - Comment : Tensors are sharded across 2 devices. + + When TP = 1, DP = 2 and EP = False, the configuration on different + devices, + - device 0 : TP = {2, 0} DP = {2, 0} EP = {1, 0} + - device 1 : TP = {2, 1} DP = {2, 1} EP = {1, 0} + - Comment: There are 2 engine instances and the tensors are sharded + across 2 decvices. + + When TP = 2, DP = 2 and EP = False, the configuration on different + devices, + - device 0: TP = {4, 0} DP = {2, 0} EP = {1, 0} + - device 1: TP = {4, 1} DP = {2, 0} EP = {1, 0} + - device 2: TP = {4, 2} DP = {2, 1} EP = {1, 0} + - device 3: TP = {4, 3} DP = {2, 1} EP = {1, 0} + - Comment: There are 2 engine instances and the tensors are sharded + across 4 devices. + + When, TP = 2, DP = 1 and EP = True, the configuration on different + devices, + - device 0: TP = {1, 0} DP = {1, 0} EP = {2, 0} + - device 1: TP = {1, 0} DP = {1, 0} EP = {2, 1} + - Comment: The experts are split between the 2 devices. + + When, TP = 1, DP = 2 and EP = True, the configuration on different + devices, + - device 0: TP = {1, 0} DP = {2, 0} EP = {2, 0} + - device 1: TP = {1, 0} DP = {2, 1} EP = {2, 1} + - Comment: There are 2 engine instances and the experts are split + between the 2 devices. 
+ + When TP = 2, DP = 2 and EP = True, the configuration on different + devices, + - device 0: TP = {1, 0} DP = {2, 0} EP = {4, 0} + - device 1: TP = {1, 0} DP = {2, 0} EP = {4, 1} + - device 2: TP = {1, 0} DP = {2, 1} EP = {4, 2} + - device 3: TP = {1, 0} DP = {2, 1} EP = {4, 3} + - Comment: There are 2 engine instances and the experts are split + between the 4 devices. + """ + + def flatten_tp_across_dp(dp_rank: int): + tp_rank = 0 if tp_size_ == 1 else get_tensor_model_parallel_rank() + # There are actually dp_size_ * tp_size_ devices. Update tp_size + # and tp_rank so we shard across all devices. + tp_size = dp_size_ * tp_size_ + tp_rank = dp_rank * tp_size_ + tp_rank + return tp_size, tp_rank + + use_ep = (dp_size_ * tp_size_ > 1 + and vllm_parallel_config.enable_expert_parallel) + + dp_size = dp_size_ + dp_rank = get_dp_group().rank_in_group if dp_size > 1 else 0 + tp_size, tp_rank = flatten_tp_across_dp(dp_rank) + + if not use_ep: + return FusedMoEParallelConfig(tp_size=tp_size, + tp_rank=tp_rank, + dp_size=dp_size, + dp_rank=dp_rank, + ep_size=1, + ep_rank=0, + use_ep=False) + # DP + EP / TP + EP / DP + TP + EP + assert use_ep + # In EP, each device owns a set of experts fully. There is no tensor + # parallel update tp_size, tp_rank, ep_size and ep_rank to reflect that. + ep_size = tp_size + ep_rank = tp_rank + return FusedMoEParallelConfig(tp_size=1, + tp_rank=0, + dp_size=dp_size, + dp_rank=dp_rank, + ep_size=ep_size, + ep_rank=ep_rank, + use_ep=True) + + @dataclass class MoEConfig: num_experts: int -- Gitee From b1462eea9823cc2dbd1936bf4755a4a81203d82d Mon Sep 17 00:00:00 2001 From: lvhaoyu Date: Mon, 23 Jun 2025 19:35:35 +0800 Subject: [PATCH 12/77] update --- .../model_executor/layers/fused_moe/layer.py | 43 ++++++++++++++++--- .../model_executor/models/qwen3_moe.py | 13 +++--- vllm_mindspore/model_executor/models/utils.py | 2 +- 3 files changed, 44 insertions(+), 14 deletions(-) diff --git a/vllm_mindspore/model_executor/layers/fused_moe/layer.py b/vllm_mindspore/model_executor/layers/fused_moe/layer.py index 1d1eac54..a4a164fe 100644 --- a/vllm_mindspore/model_executor/layers/fused_moe/layer.py +++ b/vllm_mindspore/model_executor/layers/fused_moe/layer.py @@ -20,8 +20,7 @@ from vllm.logger import init_logger from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( is_rocm_aiter_moe_enabled) -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.platforms.interface import CpuArchEnum @@ -37,8 +36,8 @@ from vllm.model_executor.layers.fused_moe.layer import (determine_expert_map, from vllm_mindspore.model_executor.layers.fused_moe.fused_moe import (fused_topk, grouped_topk, - MOE_DP_CHUNK_SIZE, fused_experts) +from vllm_mindspore.model_executor.layers.quantization.base_config import QuantizeMethodBase from mindspore import nn, Tensor, Parameter, mint @@ -244,6 +243,35 @@ class MoEConfig: return self.moe_parallel_config.use_deepep_ll_kernels +class FusedMoEMethodBase(QuantizeMethodBase): + + @abstractmethod + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size_per_partition: int, + params_dtype: torch.dtype, **extra_weight_attrs): + raise NotImplementedError + + @abstractmethod + def apply( + self, + layer: 
torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + ) -> torch.Tensor: + raise NotImplementedError + class UnquantizedFusedMoEMethod(FusedMoEMethodBase, nn.Cell): """MoE method without quantization.""" @@ -465,7 +493,7 @@ class FusedMoE(nn.Cell): moe_parallel_config=self.moe_parallel_config, # TODO (bnell): this needs to be fixed for quantized types. in_dtype=params_dtype, - max_num_tokens=MOE_DP_CHUNK_SIZE, + max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE, ) self.moe_config = moe self.quant_config = quant_config @@ -587,15 +615,15 @@ class FusedMoE(nn.Cell): # Index the loaded weight for tp sharding. # down_proj: "RowParallel" so tp sharding on input_dim # Narrow parameter and load. - shard_size = param.shape[shard_dim] if not load_full: + shard_size = param.shape[shard_dim + 1] loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank, shard_size) - param.set_data(loaded_weight) + param[expert_id] = loaded_weight # w2, down_proj: Load into only logical weight of w2. else: - param[expert_id] = loaded_weight + param.set_data(loaded_weight) def _load_single_value(self, param: Parameter, loaded_weight: Tensor, expert_id: int): @@ -713,6 +741,7 @@ class FusedMoE(nn.Cell): shard_dim=shard_dim, loaded_weight=loaded_weight, param=param, + expert_id=expert_id, tp_rank=self.tp_rank) return diff --git a/vllm_mindspore/model_executor/models/qwen3_moe.py b/vllm_mindspore/model_executor/models/qwen3_moe.py index 5e25a372..a31b5c3e 100644 --- a/vllm_mindspore/model_executor/models/qwen3_moe.py +++ b/vllm_mindspore/model_executor/models/qwen3_moe.py @@ -22,9 +22,9 @@ # limitations under the License. 
"""Inference-only Qwen3MoE model compatible with HuggingFace weights.""" from collections.abc import Iterable -from typing import Any, Optional, Union +from typing import Any, Optional, Union, Dict, Tuple -from mindspore import Tensor, nn +from mindspore import Tensor, nn, Parameter from transformers import PretrainedConfig from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size @@ -371,8 +371,8 @@ class Qwen3MoeModel(nn.Cell): hidden_states, _ = self.norm(hidden_states, residual) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[Tuple[str, Tensor]], + params_dict: Dict[str, Parameter]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -390,7 +390,6 @@ class Qwen3MoeModel(nn.Cell): ckpt_up_proj_name="up_proj", num_experts=self.config.num_experts) - params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: for (param_name, weight_name, shard_id) in stacked_params_mapping: @@ -502,6 +501,8 @@ class Qwen3MoeForCausalLM(NativeModel, SupportsPP): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + self.common_preprocess(vllm_config, prefix) + def get_input_embeddings(self, input_ids: Tensor) -> Tensor: return self.model.get_input_embeddings(input_ids) @@ -528,4 +529,4 @@ class Qwen3MoeForCausalLM(NativeModel, SupportsPP): def load_weights(self, weights: Iterable[tuple[str, Tensor]]) -> set[str]: params_dict = self.get_params_dict() - self.model.load_weights(weights, params_dict) + return self.model.load_weights(weights, params_dict) diff --git a/vllm_mindspore/model_executor/models/utils.py b/vllm_mindspore/model_executor/models/utils.py index 56e7b623..26b5c268 100644 --- a/vllm_mindspore/model_executor/models/utils.py +++ b/vllm_mindspore/model_executor/models/utils.py @@ -273,7 +273,7 @@ def get_pp_missing_layer_names(model: nn.Cell) -> list[str]: return _model_to_pp_missing_layer_names[model_id] missing_layer_names = [] - for name, cell in model.name_cells(): + for cell, name in model.cells_and_names(): if isinstance(cell, PPMissingLayer): # NOTE: the trailing dot is used to match the prefix of the layer. # without the dot, we could match a layer that is not missing, -- Gitee From c0995996ddd91081b5431e0d18a0ae374445228f Mon Sep 17 00:00:00 2001 From: lvhaoyu Date: Mon, 23 Jun 2025 19:48:16 +0800 Subject: [PATCH 13/77] update --- vllm_mindspore/config.py | 13 ++++++++++++- vllm_mindspore/v1/worker/gpu_model_runner.py | 5 ++++- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/vllm_mindspore/config.py b/vllm_mindspore/config.py index 0fd6ca23..5cdb6736 100644 --- a/vllm_mindspore/config.py +++ b/vllm_mindspore/config.py @@ -16,7 +16,7 @@ # limitations under the License. 
# ============================================================================ from collections import Counter -from typing import Union +from typing import Union, TypeVar import sys import socket import threading @@ -409,3 +409,14 @@ def stateless_destroy_socket_process_group(dp_group: "SocketProcessGroup") -> No if dp_group: dp_group.close() logger.info(f"Socket process group for rank {dp_group.rank} destroyed.") + +T = TypeVar("T") + +def get_layers_from_vllm_config(vllm_config: VllmConfig, + layer_type: type[T]) -> dict[str, T]: + return { + layer_name: layer + for layer_name, layer in + vllm_config.compilation_config.static_forward_context.items() + if isinstance(layer, layer_type) + } diff --git a/vllm_mindspore/v1/worker/gpu_model_runner.py b/vllm_mindspore/v1/worker/gpu_model_runner.py index 7f4e3fe1..1edc077b 100644 --- a/vllm_mindspore/v1/worker/gpu_model_runner.py +++ b/vllm_mindspore/v1/worker/gpu_model_runner.py @@ -25,6 +25,8 @@ from mindspore import mutable from vllm_mindspore.v1.attention.backends.ms_attn import MsAttentionMetadata from vllm_mindspore.utils import get_valid_dtype from vllm_mindspore.model_executor.layers.rotary_embedding import InferMRotaryEmbedding as MRotaryEmbedding # type: ignore[attr-defined] +from vllm_mindspore.config import get_layers_from_vllm_config +from vllm_mindspore.model_executor.models.model_base import AttentionWrapper from vllm.v1.outputs import ModelRunnerOutput from vllm.attention import AttentionType @@ -444,7 +446,8 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: block_size = self.vllm_config.cache_config.block_size use_mla = self.vllm_config.model_config.use_mla kv_cache_spec: dict[str, KVCacheSpec] = {} - for layer_name, attn_module in forward_ctx.items(): + attn_layers = get_layers_from_vllm_config(self.vllm_config, AttentionWrapper) + for layer_name, attn_module in attn_layers.items(): # vllm-mindspore AttentionWrapper is not an Attention isinstance # assert isinstance(attn_module, Attention) if attn_module.attn_type == AttentionType.DECODER: -- Gitee From 1c7f00087e8bbe005d1fcbd68c8bf05f371a40bb Mon Sep 17 00:00:00 2001 From: lvhaoyu Date: Mon, 23 Jun 2025 20:13:51 +0800 Subject: [PATCH 14/77] update --- .../model_executor/models/qwen3_moe.py | 58 ++++++++++++++----- 1 file changed, 44 insertions(+), 14 deletions(-) diff --git a/vllm_mindspore/model_executor/models/qwen3_moe.py b/vllm_mindspore/model_executor/models/qwen3_moe.py index a31b5c3e..59c92a63 100644 --- a/vllm_mindspore/model_executor/models/qwen3_moe.py +++ b/vllm_mindspore/model_executor/models/qwen3_moe.py @@ -22,7 +22,7 @@ # limitations under the License. """Inference-only Qwen3MoE model compatible with HuggingFace weights.""" from collections.abc import Iterable -from typing import Any, Optional, Union, Dict, Tuple +from typing import Any, Optional, Union, Dict, Tuple, List from mindspore import Tensor, nn, Parameter from transformers import PretrainedConfig @@ -84,7 +84,7 @@ class Qwen3MoeMLP(nn.Cell): "Only silu is supported for now.") self.act_fn = SiluAndMul() - def forward(self, x): + def construct(self, x): gate_up, _ = self.gate_up_proj(x) x = self.act_fn(gate_up) x, _ = self.down_proj(x) @@ -122,7 +122,7 @@ class Qwen3MoeSparseMoeBlock(nn.Cell): quant_config=None, prefix=f"{prefix}.gate") - def forward(self, hidden_states: Tensor) -> Tensor: + def construct(self, hidden_states: Tensor) -> Tensor: # NOTE: hidden_states can have either 1D or 2D shape. 
orig_shape = hidden_states.shape hidden_dim = hidden_states.shape[-1] @@ -212,10 +212,18 @@ class Qwen3MoeAttention(nn.Cell): self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) - def forward( + def construct( self, positions: Tensor, hidden_states: Tensor, + key_cache: Tensor, + value_cache: Tensor, + is_prefill: bool, + slot_mapping: Tensor, + attn_mask: Tensor, + batch_valid_length: Tensor, + q_seq_lens: Tensor, + block_tables: Tensor, ) -> Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) @@ -229,8 +237,10 @@ class Qwen3MoeAttention(nn.Cell): self.head_dim) k_by_head = self.k_norm(k_by_head) k = k_by_head.view(k.shape) - q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v) + q, k = self.rotary_emb(positions, q, k, batch_valid_length, is_prefill) + attn_output = self.attn(q, k, v, key_cache, value_cache, is_prefill, + slot_mapping, attn_mask, batch_valid_length, + q_seq_lens, block_tables) output, _ = self.o_proj(attn_output) return output @@ -286,10 +296,18 @@ class Qwen3MoeDecoderLayer(nn.Cell): self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - def forward( + def construct( self, positions: Tensor, hidden_states: Tensor, + key_cache: Tensor, + value_cache: Tensor, + is_prefill: bool, + slot_mapping: Tensor, + attn_mask: Tensor, + batch_valid_length: Tensor, + q_seq_lens: Tensor, + block_tables: Tensor, residual: Optional[Tensor], ) -> Tensor: # Self Attention @@ -299,11 +317,10 @@ class Qwen3MoeDecoderLayer(nn.Cell): else: hidden_states, residual = self.input_layernorm( hidden_states, residual) - hidden_states = self.self_attn( - positions=positions, - hidden_states=hidden_states, - ) - + hidden_states = self.self_attn(positions, hidden_states, key_cache, + value_cache, is_prefill, slot_mapping, + attn_mask, batch_valid_length, + q_seq_lens, block_tables) # Fully Connected hidden_states, residual = self.post_attention_layernorm( hidden_states, residual) @@ -343,10 +360,18 @@ class Qwen3MoeModel(nn.Cell): def get_input_embeddings(self, input_ids: Tensor) -> Tensor: return self.embed_tokens(input_ids) - def forward( + def construct( self, input_ids: Tensor, positions: Tensor, + key_caches: List[Tensor], + value_caches: List[Tensor], + is_prefill: bool, + slot_mapping: Tensor, + attn_mask: Tensor, + batch_valid_length: Tensor, + q_seq_lens: Tensor, + block_tables: Tensor, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[Tensor] = None, ) -> Union[Tensor, IntermediateTensors]: @@ -362,7 +387,12 @@ class Qwen3MoeModel(nn.Cell): residual = intermediate_tensors["residual"] for i in range(self.start_layer, self.end_layer): layer = self.layers[i] - hidden_states, residual = layer(positions, hidden_states, residual) + hidden_states, residual = layer(positions, hidden_states, + key_caches[i - self.start_layer], + value_caches[i - self.start_layer], + is_prefill, slot_mapping, + attn_mask, batch_valid_length, + q_seq_lens, block_tables, residual) if not get_pp_group().is_last_rank: return IntermediateTensors({ "hidden_states": hidden_states, -- Gitee From a1df535dd87bcc23b8d4fd7eead5eef6cc8dcd55 Mon Sep 17 00:00:00 2001 From: lvhaoyu Date: Mon, 23 Jun 2025 20:15:01 +0800 Subject: [PATCH 15/77] update --- vllm_mindspore/model_executor/models/qwen3_moe.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm_mindspore/model_executor/models/qwen3_moe.py 
b/vllm_mindspore/model_executor/models/qwen3_moe.py index 59c92a63..f0ffc6f8 100644 --- a/vllm_mindspore/model_executor/models/qwen3_moe.py +++ b/vllm_mindspore/model_executor/models/qwen3_moe.py @@ -542,6 +542,7 @@ class Qwen3MoeForCausalLM(NativeModel, SupportsPP): positions: Tensor, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[Tensor] = None, + **kwargs ) -> Union[Tensor, IntermediateTensors]: hidden_states = self.exec_model(input_ids, positions, intermediate_tensors, inputs_embeds) -- Gitee From 17a3aa1afe0382e82e8a12e8e8903dca96a53241 Mon Sep 17 00:00:00 2001 From: lvhaoyu Date: Mon, 23 Jun 2025 20:45:20 +0800 Subject: [PATCH 16/77] update --- vllm_mindspore/model_executor/layers/linear.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_mindspore/model_executor/layers/linear.py b/vllm_mindspore/model_executor/layers/linear.py index 0dee09d6..f2a883a7 100644 --- a/vllm_mindspore/model_executor/layers/linear.py +++ b/vllm_mindspore/model_executor/layers/linear.py @@ -248,7 +248,7 @@ class ReplicatedLinear(LinearBase): f"to a parameter of size {param.size()}") param.set_data(loaded_weight) - def forward( + def construct( self, x: Tensor ) -> Union[Tensor, tuple[Tensor, Optional[Parameter]]]: bias = self.bias if not self.skip_bias_add else None -- Gitee From 80d4d72a641349b3ccb300635a0f4e1299c6f071 Mon Sep 17 00:00:00 2001 From: lvhaoyu Date: Mon, 23 Jun 2025 20:49:42 +0800 Subject: [PATCH 17/77] update npucomm --- .../device_communicators/npu_communicator.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/vllm_mindspore/distributed/device_communicators/npu_communicator.py b/vllm_mindspore/distributed/device_communicators/npu_communicator.py index 3885baa9..9cd0f278 100644 --- a/vllm_mindspore/distributed/device_communicators/npu_communicator.py +++ b/vllm_mindspore/distributed/device_communicators/npu_communicator.py @@ -1,4 +1,17 @@ +from mindspore import Tensor + from vllm.distributed.device_communicators.cuda_communicator import CudaCommunicator class NPUCommunicator(CudaCommunicator): - ... 
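A side note on the forward-to-construct renames in the patches above: MindSpore's nn.Cell dispatches calls to construct, and graph compilation via ms.jit traces construct, so a cell that only defines forward would not be captured when the model is later jitted. A minimal sketch of the pattern, with made-up names and shapes:

    import mindspore as ms
    from mindspore import Parameter, Tensor, mint, nn

    class TinyProj(nn.Cell):
        def __init__(self, in_dim: int, out_dim: int):
            super().__init__()
            self.weight = Parameter(mint.zeros((out_dim, in_dim), dtype=ms.float32))

        def construct(self, x: Tensor) -> Tensor:
            # construct, not forward, is what graph mode / ms.jit compiles
            return mint.matmul(x, mint.transpose(self.weight, -1, -2))

    net = TinyProj(4, 8)
    jitted = ms.jit(function=net, jit_level="O0")   # mirrors how run_model is built later
    y = jitted(mint.ones((2, 4), dtype=ms.float32))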
+ def dispatch( + self, hidden_states: Tensor, + router_logits: Tensor) -> tuple[Tensor, Tensor]: + assert self.all2all_manager is not None + hidden_states, router_logits = self.all2all_manager.dispatch( + hidden_states, router_logits) + return hidden_states, router_logits + + def combine(self, hidden_states: Tensor) -> Tensor: + assert self.all2all_manager is not None + hidden_states = self.all2all_manager.combine(hidden_states) + return hidden_states -- Gitee From edc03deeaa9b8981714b4b528bea21d1dc0d969f Mon Sep 17 00:00:00 2001 From: lvhaoyu Date: Mon, 23 Jun 2025 21:10:53 +0800 Subject: [PATCH 18/77] update --- .../device_communicators/npu_communicator.py | 49 +++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/vllm_mindspore/distributed/device_communicators/npu_communicator.py b/vllm_mindspore/distributed/device_communicators/npu_communicator.py index 9cd0f278..cfb89294 100644 --- a/vllm_mindspore/distributed/device_communicators/npu_communicator.py +++ b/vllm_mindspore/distributed/device_communicators/npu_communicator.py @@ -1,8 +1,35 @@ from mindspore import Tensor +from mindspore.communication import get_rank, get_group_size +import torch.distributed as dist from vllm.distributed.device_communicators.cuda_communicator import CudaCommunicator +from vllm.distributed.parallel_state import (get_dp_group, + get_tp_group, + in_the_same_node_as) +from vllm.forward_context import get_forward_context + class NPUCommunicator(CudaCommunicator): + def __init__(self, + cpu_group, + device = None, + device_group = None, + unique_name: str = ""): + super().__init__(cpu_group, device, device_group, unique_name) + + # all2all lives in ep group, which is merged from dp and tp group + self.dp_group = get_dp_group() + self.tp_group = get_tp_group() + # no self.ep_group since self.ep_group is still in construction + # when we create this object + self.dp_rank = self.dp_group.rank_in_group + self.dp_world_size = self.dp_group.world_size + self.rank = dist.get_rank(cpu_group) + self.world_size = dist.get_world_size(cpu_group) + + # all2all communication often has separate implementations for + # intra-node and inter-node communication + self.internode = not all(in_the_same_node_as(cpu_group, source_rank=0)) def dispatch( self, hidden_states: Tensor, router_logits: Tensor) -> tuple[Tensor, Tensor]: @@ -15,3 +42,25 @@ class NPUCommunicator(CudaCommunicator): assert self.all2all_manager is not None hidden_states = self.all2all_manager.combine(hidden_states) return hidden_states + + def dispatch(self, hidden_states: Tensor, + router_logits: Tensor): + cu_tokens_across_dp_cpu = get_forward_context( + ).dp_metadata.cu_tokens_across_dp_cpu + + hidden_states = self.naive_multicast(hidden_states, + cu_tokens_across_dp_cpu) + router_logits = self.naive_multicast(router_logits, + cu_tokens_across_dp_cpu) + return hidden_states, router_logits + + def combine(self, hidden_states: Tensor) -> Tensor: + cu_tokens_across_dp_cpu = get_forward_context( + ).dp_metadata.cu_tokens_across_dp_cpu + start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[ + self.dp_rank - 1] + end = cu_tokens_across_dp_cpu[self.dp_rank] + + all_hidden_states = self.dp_group.all_reduce(hidden_states) + hidden_states = all_hidden_states[start:end, :] + return hidden_states \ No newline at end of file -- Gitee From 926d26935acfd13a5e87ae43a0eda1cb4944d82e Mon Sep 17 00:00:00 2001 From: lvhaoyu Date: Tue, 24 Jun 2025 11:18:35 +0800 Subject: [PATCH 19/77] update --- .../layers/fused_moe/fused_moe.py | 143 ++++++++++++------ 
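The dispatch/combine pair added to NPUCommunicator above follows the naive DP exchange: dispatch hands every rank the concatenated DP batch (driven by the cumulative token counts in the forward context), the experts run over that global batch, and combine all-reduces and slices this rank's rows back out. The slicing arithmetic, shown standalone with made-up counts:

    import numpy as np

    def local_slice(cu_tokens_across_dp, dp_rank):
        # cu_tokens_across_dp: cumulative per-rank token counts, e.g. [3, 8, 10]
        start = 0 if dp_rank == 0 else int(cu_tokens_across_dp[dp_rank - 1])
        end = int(cu_tokens_across_dp[dp_rank])
        return start, end

    cu = np.cumsum([3, 5, 2])      # three DP ranks with 3, 5 and 2 local tokens
    print(local_slice(cu, 0))      # (0, 3)
    print(local_slice(cu, 1))      # (3, 8): rank 1 owns rows 3..7 of the global batch
    print(local_slice(cu, 2))      # (8, 10)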
.../model_executor/layers/fused_moe/layer.py | 18 ++- vllm_mindspore/platforms/ascend.py | 3 +- 3 files changed, 107 insertions(+), 57 deletions(-) diff --git a/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py b/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py index 1f6f4b7e..aa2e32ca 100644 --- a/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py @@ -1,9 +1,12 @@ from typing import Optional from mindspore import Tensor, mint, ops -from mindspore.ops.auto_generate import FusedAddTopKDiv +from mindspore.ops.auto_generate import (GroupedMatmulV4, + FusedAddTopKDiv, + MoeInitRoutingV2, + MoeTokenUnpermute) import mindspore as ms -from vllm.distributed.parallel_state import get_ep_group +from vllm.distributed.parallel_state import get_ep_group, get_dp_group def fused_topk( hidden_states: Tensor, @@ -25,8 +28,8 @@ def fused_topk( if indices_type is not None: topk_ids = topk_ids.to(indices_type) - return topk_weights.to(ms.float32), topk_ids.to(ms.int32) - + return topk_weights, topk_ids + def grouped_topk( hidden_states: Tensor, @@ -53,14 +56,8 @@ def grouped_topk( scoring_type, renormalize) - return topk_weights.to(ms.float32), topk_ids.to(ms.int32) - - -def _ep_dispatch(x, topk_ids): - return mint.distributed.all_to_all(x, topk_ids) + return topk_weights, topk_ids -def _ep_combine(x, topk_ids): - return mint.distributed.all_to_all(x, topk_ids) def fused_experts(hidden_states: Tensor, w1: Tensor, @@ -70,38 +67,47 @@ def fused_experts(hidden_states: Tensor, activation: str = "silu", global_num_experts: int = -1, apply_router_weight_on_input: bool = False, - expert_map: Optional[Tensor] = None) -> Tensor: - - use_ep = False - if expert_map is not None: - use_ep = True - - if use_ep: - hidden_states = _ep_dispatch(hidden_states, topk_ids) - hidden_states = _run_ep_moe(hidden_states, w1, w2, topk_ids, topk_weights, activation) - hidden_states = _ep_combine(hidden_states, topk_ids) - if apply_router_weight_on_input: - hidden_states = mint.mul(hidden_states, topk_weights) - hidden_states = hidden_states.sum(-1) - else: - hidden_states =_run_tp_moe(hidden_states, w1, w2, topk_ids, topk_weights, activation) - if apply_router_weight_on_input: - hidden_states = mint.mul(hidden_states, topk_weights) + expert_map: Optional[Tensor] = None, + tp_size: int = 1, + ep_size: int = 0) -> Tensor: + + if tp_size >= 1: + # no ep, pure tp + if ep_size == 1: + hidden_states = _run_tp_moe(hidden_states, w1, w2, topk_ids, topk_weights, + activation, global_num_experts, + apply_router_weight_on_input) + # ep_size > 1 : pure ep or tp + ep + else: + # pure ep + if tp_size == 1: + hidden_states = _run_ep_moe(hidden_states, w1, w2, topk_ids, topk_weights, + activation, global_num_experts, + apply_router_weight_on_input) + # tp_size > 1 : tp + ep + else: + hidden_states = _run_tp_ep_moe(hidden_states, w1, w2, topk_ids, topk_weights, + activation, global_num_experts, + apply_router_weight_on_input) return hidden_states -def _run_activation(x, activation): +def _gate_activation(gate, activation): if activation == "silu": - return mint.silu(x) + return mint.silu(gate) elif activation == "gelu": - return mint.gelu(x) + return mint.gelu(gate) else: raise ValueError(f"Unsupported activation function: {activation}") -group_matmul_ops = ops.auto_generate.GroupedMatmulV4() +group_matmul_ops = GroupedMatmulV4() +moe_init_routing_op = MoeInitRoutingV2() +moe_token_unpermute = MoeTokenUnpermute() +all_gather_dp = 
ops.AllGather(get_dp_group()) +all_reduce_ep = ops.AllReduce(get_ep_group()) -def _run_group_matmul(hidden_states, weight, group_list): +def _group_matmul(hidden_states, weight, group_list): return group_matmul_ops([hidden_states], [weight], group_list, None, None, None, None, None, None, group_list, split_item=3, group_type=0, group_list_type=1) @@ -109,24 +115,65 @@ def _run_group_matmul(hidden_states, weight, group_list): def _run_ep_moe(hidden_states, w1, w2, - group_list, - group_logits, - activation): - hidden_states = _run_group_matmul(hidden_states, w1, group_list) - hidden_states = _run_activation(hidden_states, activation) - hidden_states = _run_group_matmul(hidden_states, w2, group_list) + topk_ids, + topk_weights, + activation, + global_num_experts, + apply_router_weight_on_input): + hidden_states = _group_matmul(hidden_states, w1, topk_ids) + hidden_states = _gate_activation(hidden_states, activation) + hidden_states = _group_matmul(hidden_states, w2, topk_ids) return hidden_states def _run_tp_moe(hidden_states, w1, w2, - group_list, - group_logits, - activation): - # hidden_states = mint.group_matmul(hidden_states, w1, group_list) - hidden_states = _run_group_matmul([hidden_states], [w1], group_list) - hidden_states = _run_activation(hidden_states, activation) - hidden_states = _run_group_matmul(hidden_states, w2, group_list) - hidden_states = mint.distributed.all_reduce(hidden_states, get_ep_group()) - return hidden_states + topk_ids, + topk_weights, + activation, + global_num_experts, + apply_router_weight_on_input): + hidden_states = all_gather_dp(hidden_states) + topk_ids = all_gather_dp(topk_ids) + topk_weights = all_gather_dp(topk_weights) + sorted_input_tensor, unsort_map, group_list, _ = \ + moe_init_routing_op( + hidden_states, + topk_ids, + active_num=0, + expert_capacity=0, + expert_num=global_num_experts, + drop_pad_mode=0, + expert_tokens_count_or_cumsum_flag=2, + expert_tokens_before_capacity_flag=True) + + group_list = group_list.astype(ms.int64) + + gate_hidden_out = _group_matmul(sorted_input_tensor, w1, group_list) + gate, hidden = mint.split(gate_hidden_out, + (w1.shape[0] // 2, w1.shape[0] // 2), -1) + gate = _gate_activation(gate, activation) + hidden = mint.mul(hidden, gate) + hidden = _group_matmul(hidden, w2, group_list) + expert_output = all_reduce_ep(hidden) + if apply_router_weight_on_input: + moe_output = moe_token_unpermute(permuted_tokens=expert_output, + sorted_indices=unsort_map, + probs=topk_weights, + padded_mode=False, + restore_shape=None) + moe_output = moe_output[:].sum(-1) + return moe_output + + +def _run_tp_ep_moe(hidden_states, + w1, + w2, + group_list, + group_logits, + activation, + global_num_experts, + apply_router_weight_on_input): + raise NotImplementedError( + "TP + EP MoE is not implemented yet. 
Please use pure TP or pure EP MoE instead.") diff --git a/vllm_mindspore/model_executor/layers/fused_moe/layer.py b/vllm_mindspore/model_executor/layers/fused_moe/layer.py index a4a164fe..a774e17f 100644 --- a/vllm_mindspore/model_executor/layers/fused_moe/layer.py +++ b/vllm_mindspore/model_executor/layers/fused_moe/layer.py @@ -379,6 +379,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, nn.Cell): global_num_experts=global_num_experts, apply_router_weight_on_input=apply_router_weight_on_input, expert_map=expert_map, + tp_size=self.moe.tp_size, + ep_size=self.moe.ep_size, ) @@ -822,12 +824,12 @@ class FusedMoE(nn.Cell): router_logits: Tensor): assert self.quant_method is not None - do_naive_dispatch_combine: bool = ( - self.dp_size > 1 - and not self.ep_size > 1) - if do_naive_dispatch_combine: - hidden_states, router_logits = get_ep_group().dispatch( - hidden_states, router_logits) + # do_naive_dispatch_combine: bool = ( + # self.dp_size > 1 + # and not self.ep_size > 1) + # if do_naive_dispatch_combine: + # hidden_states, router_logits = get_ep_group().dispatch( + # hidden_states, router_logits) # Matrix multiply. final_hidden_states = self.quant_method.apply( @@ -848,8 +850,8 @@ class FusedMoE(nn.Cell): apply_router_weight_on_input=self.apply_router_weight_on_input, ) - if do_naive_dispatch_combine: - final_hidden_states = get_ep_group().combine(final_hidden_states) + # if do_naive_dispatch_combine: + # final_hidden_states = get_ep_group().combine(final_hidden_states) if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): # Default set to False. (May have to add shared expert outputs.) diff --git a/vllm_mindspore/platforms/ascend.py b/vllm_mindspore/platforms/ascend.py index 7a31885c..c61e3978 100644 --- a/vllm_mindspore/platforms/ascend.py +++ b/vllm_mindspore/platforms/ascend.py @@ -136,7 +136,8 @@ class AscendPlatform(Platform): def get_device_communicator_cls(cls) -> str: """Get device specific communicator class for distributed communication.""" if envs.VLLM_USE_V1: - return "vllm_mindspore.distributed.device_communicators.npu_communicator.NPUCommunicator" + return "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator" + # return "vllm_mindspore.distributed.device_communicators.npu_communicator.NPUCommunicator" return "vllm.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase" @classmethod -- Gitee From c9b2258018a194a26e924d0ed0cc235cd18bce7c Mon Sep 17 00:00:00 2001 From: lvhaoyu Date: Tue, 24 Jun 2025 16:42:36 +0800 Subject: [PATCH 20/77] update --- .../layers/fused_moe/fused_moe.py | 35 ++++++++----------- .../model_executor/layers/fused_moe/layer.py | 21 ++++++++++- 2 files changed, 35 insertions(+), 21 deletions(-) diff --git a/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py b/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py index aa2e32ca..99498877 100644 --- a/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py @@ -94,9 +94,9 @@ def fused_experts(hidden_states: Tensor, def _gate_activation(gate, activation): if activation == "silu": - return mint.silu(gate) + return mint.nn.functional.silu(gate) elif activation == "gelu": - return mint.gelu(gate) + return mint.nn.functional.gelu(gate) else: raise ValueError(f"Unsupported activation function: {activation}") @@ -104,13 +104,11 @@ def _gate_activation(gate, activation): group_matmul_ops = GroupedMatmulV4() moe_init_routing_op = MoeInitRoutingV2() 
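The TP path being built up here leans on three fused kernels: MoeInitRoutingV2 sorts the tokens by the expert that will process them and returns the per-expert token counts (group_list) plus an unsort map, GroupedMatmulV4 then runs one matmul per expert over the sorted rows, and MoeTokenUnpermute restores the original token order while applying the routing weights. Functionally this matches the unfused sketch below, assuming w1 stacks the gate and up projections as (num_experts, 2*intermediate, hidden) and w2 is (num_experts, hidden, intermediate), which is how the fused code splits and transposes them; capacity and token-drop handling are omitted:

    import numpy as np

    def unfused_tp_moe(tokens, topk_ids, topk_weights, w1, w2, act):
        # tokens: (T, H); topk_ids / topk_weights: (T, K)
        # w1: (E, 2*I, H); w2: (E, H, I)
        T, K = topk_ids.shape
        out = np.zeros_like(tokens)
        for t in range(T):
            for k in range(K):
                e = int(topk_ids[t, k])
                gate_hidden = tokens[t] @ w1[e].T             # (2*I,)
                gate, hidden = np.split(gate_hidden, 2)
                expert_out = (act(gate) * hidden) @ w2[e].T   # (H,)
                out[t] += topk_weights[t, k] * expert_out
        return out

    silu = lambda x: x / (1.0 + np.exp(-x))   # the "silu" activation used above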
moe_token_unpermute = MoeTokenUnpermute() -all_gather_dp = ops.AllGather(get_dp_group()) -all_reduce_ep = ops.AllReduce(get_ep_group()) def _group_matmul(hidden_states, weight, group_list): - return group_matmul_ops([hidden_states], [weight], group_list, + return group_matmul_ops([hidden_states], [weight], None, None, None, None, None, None, - group_list, split_item=3, group_type=0, group_list_type=1) + group_list, split_item=3, group_type=0, group_list_type=1)[0] def _run_ep_moe(hidden_states, w1, @@ -134,9 +132,9 @@ def _run_tp_moe(hidden_states, activation, global_num_experts, apply_router_weight_on_input): - hidden_states = all_gather_dp(hidden_states) - topk_ids = all_gather_dp(topk_ids) - topk_weights = all_gather_dp(topk_weights) + topk_weights = mint.cast(topk_weights, hidden_states.dtype) + topk_ids = mint.cast(topk_ids, ms.int32) + sorted_input_tensor, unsort_map, group_list, _ = \ moe_init_routing_op( hidden_states, @@ -150,20 +148,17 @@ def _run_tp_moe(hidden_states, group_list = group_list.astype(ms.int64) - gate_hidden_out = _group_matmul(sorted_input_tensor, w1, group_list) + gate_hidden_out = _group_matmul(sorted_input_tensor, mint.transpose(w1, -1, -2), group_list) gate, hidden = mint.split(gate_hidden_out, - (w1.shape[0] // 2, w1.shape[0] // 2), -1) + (w1.shape[1] // 2, w1.shape[1] // 2), -1) gate = _gate_activation(gate, activation) hidden = mint.mul(hidden, gate) - hidden = _group_matmul(hidden, w2, group_list) - expert_output = all_reduce_ep(hidden) - if apply_router_weight_on_input: - moe_output = moe_token_unpermute(permuted_tokens=expert_output, - sorted_indices=unsort_map, - probs=topk_weights, - padded_mode=False, - restore_shape=None) - moe_output = moe_output[:].sum(-1) + expert_output = _group_matmul(hidden, mint.transpose(w2, -1, -2), group_list) + moe_output = moe_token_unpermute(permuted_tokens=expert_output, + sorted_indices=unsort_map, + probs=topk_weights, + padded_mode=False, + restore_shape=None) return moe_output diff --git a/vllm_mindspore/model_executor/layers/fused_moe/layer.py b/vllm_mindspore/model_executor/layers/fused_moe/layer.py index a774e17f..5069e94c 100644 --- a/vllm_mindspore/model_executor/layers/fused_moe/layer.py +++ b/vllm_mindspore/model_executor/layers/fused_moe/layer.py @@ -39,7 +39,8 @@ from vllm_mindspore.model_executor.layers.fused_moe.fused_moe import (fused_topk fused_experts) from vllm_mindspore.model_executor.layers.quantization.base_config import QuantizeMethodBase -from mindspore import nn, Tensor, Parameter, mint +from mindspore import nn, Tensor, Parameter, mint, ops +import mindspore as ms logger = init_logger(__name__) @@ -530,6 +531,11 @@ class FusedMoE(nn.Cell): self.quant_method.create_weights(layer=self, **moe_quant_params) + if self.dp_size > 1 and self.ep_size == 1: + self.pure_tp = True + self.all_gather_from_dp_group = ops.Gather(get_dp_group()) + self.all_reduce_from_world_group = ops.AllReduce(get_ep_group) + @property def tp_size(self): return self.moe_parallel_config.tp_size @@ -831,6 +837,13 @@ class FusedMoE(nn.Cell): # hidden_states, router_logits = get_ep_group().dispatch( # hidden_states, router_logits) + if self.pure_tp: + hidden_states = self.all_gather_from_dp_group(hidden_states) + router_logits = self.all_gather_from_dp_group(router_logits) + tokens_num = Tensor(hidden_states.shape[0], ms.int32) + tokens_num_total = self.all_gather_from_dp_group(tokens_num, 0) + tokens_cumulative = mint.cumsum(tokens_num_total) + # Matrix multiply. 
final_hidden_states = self.quant_method.apply( layer=self, @@ -850,6 +863,12 @@ class FusedMoE(nn.Cell): apply_router_weight_on_input=self.apply_router_weight_on_input, ) + if self.pure_tp: + final_hidden_states = self.all_reduce(final_hidden_states) + start = 0 if self.dp_rank == 0 else tokens_cumulative[self.dp_rank - 1] + end = tokens_cumulative[self.dp_rank] + final_hidden_states = final_hidden_states[start:end] + # if do_naive_dispatch_combine: # final_hidden_states = get_ep_group().combine(final_hidden_states) -- Gitee From 4813e87b13c71cc21991bda077315148c484534c Mon Sep 17 00:00:00 2001 From: lvhaoyu Date: Tue, 24 Jun 2025 19:10:13 +0800 Subject: [PATCH 21/77] update v1 --- .../model_executor/layers/fused_moe/fused_moe.py | 4 ++-- .../model_executor/layers/fused_moe/layer.py | 14 ++++++++------ vllm_mindspore/model_executor/models/qwen3_moe.py | 9 +++++++++ 3 files changed, 19 insertions(+), 8 deletions(-) diff --git a/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py b/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py index 99498877..a3d9a173 100644 --- a/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py @@ -132,8 +132,8 @@ def _run_tp_moe(hidden_states, activation, global_num_experts, apply_router_weight_on_input): - topk_weights = mint.cast(topk_weights, hidden_states.dtype) - topk_ids = mint.cast(topk_ids, ms.int32) + topk_weights = topk_weights.astype(hidden_states.dtype) + topk_ids = topk_ids.astype(ms.int32) sorted_input_tensor, unsort_map, group_list, _ = \ moe_init_routing_op( diff --git a/vllm_mindspore/model_executor/layers/fused_moe/layer.py b/vllm_mindspore/model_executor/layers/fused_moe/layer.py index 5069e94c..9415ecab 100644 --- a/vllm_mindspore/model_executor/layers/fused_moe/layer.py +++ b/vllm_mindspore/model_executor/layers/fused_moe/layer.py @@ -533,8 +533,8 @@ class FusedMoE(nn.Cell): if self.dp_size > 1 and self.ep_size == 1: self.pure_tp = True - self.all_gather_from_dp_group = ops.Gather(get_dp_group()) - self.all_reduce_from_world_group = ops.AllReduce(get_ep_group) + self.all_gather_from_dp_group = ops.AllGather(get_dp_group().device_group._name) + self.all_reduce_from_world_group = ops.AllReduce(get_ep_group().device_group._name) @property def tp_size(self): @@ -838,11 +838,12 @@ class FusedMoE(nn.Cell): # hidden_states, router_logits) if self.pure_tp: + tokens_num = Tensor([[hidden_states.shape[0]]], dtype=ms.int32) + tokens_num_total = self.all_gather_from_dp_group(tokens_num) + tokens_num_total = tokens_num_total.reshape(-1) + tokens_cumulative = mint.cumsum(tokens_num_total, 0) hidden_states = self.all_gather_from_dp_group(hidden_states) router_logits = self.all_gather_from_dp_group(router_logits) - tokens_num = Tensor(hidden_states.shape[0], ms.int32) - tokens_num_total = self.all_gather_from_dp_group(tokens_num, 0) - tokens_cumulative = mint.cumsum(tokens_num_total) # Matrix multiply. 
final_hidden_states = self.quant_method.apply( @@ -864,7 +865,8 @@ class FusedMoE(nn.Cell): ) if self.pure_tp: - final_hidden_states = self.all_reduce(final_hidden_states) + # final_hidden_states = self.all_reduce_from_world_group(final_hidden_states) + mint.distributed.all_reduce(final_hidden_states) start = 0 if self.dp_rank == 0 else tokens_cumulative[self.dp_rank - 1] end = tokens_cumulative[self.dp_rank] final_hidden_states = final_hidden_states[start:end] diff --git a/vllm_mindspore/model_executor/models/qwen3_moe.py b/vllm_mindspore/model_executor/models/qwen3_moe.py index f0ffc6f8..adb710b4 100644 --- a/vllm_mindspore/model_executor/models/qwen3_moe.py +++ b/vllm_mindspore/model_executor/models/qwen3_moe.py @@ -52,6 +52,8 @@ from vllm_mindspore.model_executor.models.utils import ( extract_layer_index, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) from vllm_mindspore.model_executor.models.model_base import NativeModel +from vllm_mindspore.model_executor.layers.sampler import (SamplerOutput, + get_sampler) logger = init_logger(__name__) @@ -531,6 +533,8 @@ class Qwen3MoeForCausalLM(NativeModel, SupportsPP): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + self.sampler = get_sampler() + self.common_preprocess(vllm_config, prefix) def get_input_embeddings(self, input_ids: Tensor) -> Tensor: @@ -548,6 +552,11 @@ class Qwen3MoeForCausalLM(NativeModel, SupportsPP): inputs_embeds) return hidden_states + def sample(self, logits: Tensor, + sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + def compute_logits( self, hidden_states: Tensor, -- Gitee From e5956279373a1bd719c6979b7fb6e53bec20c784 Mon Sep 17 00:00:00 2001 From: lvhaoyu Date: Wed, 25 Jun 2025 09:10:38 +0800 Subject: [PATCH 22/77] update --- .../model_executor/layers/fused_moe/layer.py | 21 ++++++++++++------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/vllm_mindspore/model_executor/layers/fused_moe/layer.py b/vllm_mindspore/model_executor/layers/fused_moe/layer.py index 9415ecab..bc2ab61b 100644 --- a/vllm_mindspore/model_executor/layers/fused_moe/layer.py +++ b/vllm_mindspore/model_executor/layers/fused_moe/layer.py @@ -531,10 +531,14 @@ class FusedMoE(nn.Cell): self.quant_method.create_weights(layer=self, **moe_quant_params) - if self.dp_size > 1 and self.ep_size == 1: + self.dp_group = get_dp_group().device_group._name + self.ep_group = get_ep_group().device_group._name + + if self.dp_size > 1 and self.ep_size == 1 or self.dp_size == 1: self.pure_tp = True - self.all_gather_from_dp_group = ops.AllGather(get_dp_group().device_group._name) - self.all_reduce_from_world_group = ops.AllReduce(get_ep_group().device_group._name) + if self.dp_size > 1: + self.all_gather_from_dp_group = ops.AllGather(self.dp_group) + self.all_reduce_from_world_group = ops.AllReduce(self.ep_group) @property def tp_size(self): @@ -837,7 +841,7 @@ class FusedMoE(nn.Cell): # hidden_states, router_logits = get_ep_group().dispatch( # hidden_states, router_logits) - if self.pure_tp: + if self.dp_size > 1 and self.pure_tp: tokens_num = Tensor([[hidden_states.shape[0]]], dtype=ms.int32) tokens_num_total = self.all_gather_from_dp_group(tokens_num) tokens_num_total = tokens_num_total.reshape(-1) @@ -866,10 +870,11 @@ class FusedMoE(nn.Cell): if self.pure_tp: # final_hidden_states = self.all_reduce_from_world_group(final_hidden_states) - 
mint.distributed.all_reduce(final_hidden_states) - start = 0 if self.dp_rank == 0 else tokens_cumulative[self.dp_rank - 1] - end = tokens_cumulative[self.dp_rank] - final_hidden_states = final_hidden_states[start:end] + mint.distributed.all_reduce(final_hidden_states, self.ep_group) + if self.dp_size > 1: + start = 0 if self.dp_rank == 0 else tokens_cumulative[self.dp_rank - 1] + end = tokens_cumulative[self.dp_rank] + final_hidden_states = final_hidden_states[start:end] # if do_naive_dispatch_combine: # final_hidden_states = get_ep_group().combine(final_hidden_states) -- Gitee From a7b5ed465a6d93b32a8992a69347eecbde923906 Mon Sep 17 00:00:00 2001 From: lvhaoyu Date: Wed, 25 Jun 2025 10:31:53 +0800 Subject: [PATCH 23/77] fix --- vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py | 1 + vllm_mindspore/model_executor/layers/fused_moe/layer.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py b/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py index a3d9a173..6beb2aa5 100644 --- a/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py @@ -154,6 +154,7 @@ def _run_tp_moe(hidden_states, gate = _gate_activation(gate, activation) hidden = mint.mul(hidden, gate) expert_output = _group_matmul(hidden, mint.transpose(w2, -1, -2), group_list) + expert_output = mint.nan_to_num(expert_output, 0, 0, 0) moe_output = moe_token_unpermute(permuted_tokens=expert_output, sorted_indices=unsort_map, probs=topk_weights, diff --git a/vllm_mindspore/model_executor/layers/fused_moe/layer.py b/vllm_mindspore/model_executor/layers/fused_moe/layer.py index bc2ab61b..d83347c4 100644 --- a/vllm_mindspore/model_executor/layers/fused_moe/layer.py +++ b/vllm_mindspore/model_executor/layers/fused_moe/layer.py @@ -578,7 +578,7 @@ class FusedMoE(nn.Cell): # Index the loaded weight for tp sharding. # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim - shard_size = param.shape[shard_dim] // 2 + shard_size = param.shape[shard_dim + 1] // 2 loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank, shard_size) # Narrow parameter and load. -- Gitee From aefe3efd68294faabf02f0d126effe3909b87e31 Mon Sep 17 00:00:00 2001 From: lvhaoyu Date: Wed, 25 Jun 2025 15:19:40 +0800 Subject: [PATCH 24/77] update good presision --- vllm_mindspore/model_executor/layers/fused_moe/layer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm_mindspore/model_executor/layers/fused_moe/layer.py b/vllm_mindspore/model_executor/layers/fused_moe/layer.py index d83347c4..614eb6ae 100644 --- a/vllm_mindspore/model_executor/layers/fused_moe/layer.py +++ b/vllm_mindspore/model_executor/layers/fused_moe/layer.py @@ -824,7 +824,8 @@ class FusedMoE(nn.Cell): """ The pplx combine kernel reduces across GPU ranks by default. 
""" - return tensor_model_parallel_all_reduce(final_hidden_states) + # return tensor_model_parallel_all_reduce(final_hidden_states) + return final_hidden_states def construct(self, hidden_states: Tensor, router_logits: Tensor): -- Gitee From 57a13bd3e66ac29186abbdc138735852ccd4f2e8 Mon Sep 17 00:00:00 2001 From: lvhaoyu Date: Wed, 25 Jun 2025 18:05:07 +0800 Subject: [PATCH 25/77] update --- .../model_executor/layers/fused_moe/layer.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/vllm_mindspore/model_executor/layers/fused_moe/layer.py b/vllm_mindspore/model_executor/layers/fused_moe/layer.py index 614eb6ae..cf58f6bf 100644 --- a/vllm_mindspore/model_executor/layers/fused_moe/layer.py +++ b/vllm_mindspore/model_executor/layers/fused_moe/layer.py @@ -847,8 +847,23 @@ class FusedMoE(nn.Cell): tokens_num_total = self.all_gather_from_dp_group(tokens_num) tokens_num_total = tokens_num_total.reshape(-1) tokens_cumulative = mint.cumsum(tokens_num_total, 0) - hidden_states = self.all_gather_from_dp_group(hidden_states) - router_logits = self.all_gather_from_dp_group(router_logits) + start = 0 if self.dp_rank == 0 else tokens_cumulative[self.dp_rank - 1] + end = tokens_cumulative[self.dp_rank] + + hidden_buffer = mint.zeros(tokens_cumulative[-1].item(), + hidden_states.shape[-1], + dtype=hidden_states.dtype) + hidden_buffer[start:end] = hidden_states + mint.distributed.all_reduce(hidden_buffer, self.ep_group) + + logit_buffer = mint.zeros(tokens_cumulative[-1].item(), + router_logits.shape[-1], + dtype=router_logits.dtype) + logit_buffer[start:end] = router_logits + mint.distributed.all_reduce(logit_buffer, self.ep_group) + + hidden_states = hidden_buffer + router_logits = logit_buffer # Matrix multiply. final_hidden_states = self.quant_method.apply( -- Gitee From ff4960a8f1a99ad3e807ba6d3c6298999a3b6866 Mon Sep 17 00:00:00 2001 From: lvhaoyu Date: Thu, 26 Jun 2025 09:18:01 +0800 Subject: [PATCH 26/77] update --- vllm_mindspore/model_executor/layers/fused_moe/layer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm_mindspore/model_executor/layers/fused_moe/layer.py b/vllm_mindspore/model_executor/layers/fused_moe/layer.py index cf58f6bf..6f9bf8bc 100644 --- a/vllm_mindspore/model_executor/layers/fused_moe/layer.py +++ b/vllm_mindspore/model_executor/layers/fused_moe/layer.py @@ -854,13 +854,13 @@ class FusedMoE(nn.Cell): hidden_states.shape[-1], dtype=hidden_states.dtype) hidden_buffer[start:end] = hidden_states - mint.distributed.all_reduce(hidden_buffer, self.ep_group) + mint.distributed.all_reduce(hidden_buffer, group=self.dp_group) logit_buffer = mint.zeros(tokens_cumulative[-1].item(), router_logits.shape[-1], dtype=router_logits.dtype) logit_buffer[start:end] = router_logits - mint.distributed.all_reduce(logit_buffer, self.ep_group) + mint.distributed.all_reduce(logit_buffer, group=self.dp_group) hidden_states = hidden_buffer router_logits = logit_buffer @@ -886,7 +886,7 @@ class FusedMoE(nn.Cell): if self.pure_tp: # final_hidden_states = self.all_reduce_from_world_group(final_hidden_states) - mint.distributed.all_reduce(final_hidden_states, self.ep_group) + mint.distributed.all_reduce(final_hidden_states, group=self.ep_group) if self.dp_size > 1: start = 0 if self.dp_rank == 0 else tokens_cumulative[self.dp_rank - 1] end = tokens_cumulative[self.dp_rank] -- Gitee From 03d4dfd9831292eadf9b4103d7ad11f9ef8fdbdd Mon Sep 17 00:00:00 2001 From: lvhaoyu Date: Thu, 26 Jun 2025 09:30:34 +0800 Subject: [PATCH 27/77] update --- 
vllm_mindspore/model_executor/layers/fused_moe/layer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm_mindspore/model_executor/layers/fused_moe/layer.py b/vllm_mindspore/model_executor/layers/fused_moe/layer.py index 6f9bf8bc..80bc19bb 100644 --- a/vllm_mindspore/model_executor/layers/fused_moe/layer.py +++ b/vllm_mindspore/model_executor/layers/fused_moe/layer.py @@ -850,14 +850,14 @@ class FusedMoE(nn.Cell): start = 0 if self.dp_rank == 0 else tokens_cumulative[self.dp_rank - 1] end = tokens_cumulative[self.dp_rank] - hidden_buffer = mint.zeros(tokens_cumulative[-1].item(), - hidden_states.shape[-1], + hidden_buffer = mint.zeros((tokens_cumulative[-1].item(), + hidden_states.shape[-1]), dtype=hidden_states.dtype) hidden_buffer[start:end] = hidden_states mint.distributed.all_reduce(hidden_buffer, group=self.dp_group) - logit_buffer = mint.zeros(tokens_cumulative[-1].item(), - router_logits.shape[-1], + logit_buffer = mint.zeros((tokens_cumulative[-1].item(), + router_logits.shape[-1]), dtype=router_logits.dtype) logit_buffer[start:end] = router_logits mint.distributed.all_reduce(logit_buffer, group=self.dp_group) -- Gitee From f35239220ba2537caef95a79e05b984962ae19f5 Mon Sep 17 00:00:00 2001 From: lvhaoyu Date: Thu, 26 Jun 2025 11:45:06 +0800 Subject: [PATCH 28/77] update --- .../model_executor/layers/fused_moe/layer.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/vllm_mindspore/model_executor/layers/fused_moe/layer.py b/vllm_mindspore/model_executor/layers/fused_moe/layer.py index 80bc19bb..949fdbef 100644 --- a/vllm_mindspore/model_executor/layers/fused_moe/layer.py +++ b/vllm_mindspore/model_executor/layers/fused_moe/layer.py @@ -536,9 +536,10 @@ class FusedMoE(nn.Cell): if self.dp_size > 1 and self.ep_size == 1 or self.dp_size == 1: self.pure_tp = True + self.all_reduce_from_ep_group = ops.AllReduce(group=self.ep_group) if self.dp_size > 1: - self.all_gather_from_dp_group = ops.AllGather(self.dp_group) - self.all_reduce_from_world_group = ops.AllReduce(self.ep_group) + self.all_gather_from_dp_group = ops.AllGather(group=self.dp_group) + self.all_reduce_from_dp_group = ops.AllReduce(group=self.dp_group) @property def tp_size(self): @@ -847,20 +848,21 @@ class FusedMoE(nn.Cell): tokens_num_total = self.all_gather_from_dp_group(tokens_num) tokens_num_total = tokens_num_total.reshape(-1) tokens_cumulative = mint.cumsum(tokens_num_total, 0) - start = 0 if self.dp_rank == 0 else tokens_cumulative[self.dp_rank - 1] - end = tokens_cumulative[self.dp_rank] + start = 0 if self.dp_rank == 0 else tokens_cumulative[self.dp_rank - 1].item() + end = tokens_cumulative[self.dp_rank].item() hidden_buffer = mint.zeros((tokens_cumulative[-1].item(), hidden_states.shape[-1]), dtype=hidden_states.dtype) hidden_buffer[start:end] = hidden_states - mint.distributed.all_reduce(hidden_buffer, group=self.dp_group) - + # mint.distributed.all_reduce(hidden_buffer, group=self.dp_group) + hidden_buffer = self.all_reduce_from_dp_group(hidden_buffer) logit_buffer = mint.zeros((tokens_cumulative[-1].item(), router_logits.shape[-1]), dtype=router_logits.dtype) logit_buffer[start:end] = router_logits - mint.distributed.all_reduce(logit_buffer, group=self.dp_group) + # mint.distributed.all_reduce(logit_buffer, group=self.dp_group) + logit_buffer = self.all_reduce_from_dp_group(logit_buffer) hidden_states = hidden_buffer router_logits = logit_buffer @@ -886,7 +888,8 @@ class FusedMoE(nn.Cell): if self.pure_tp: # final_hidden_states = 
self.all_reduce_from_world_group(final_hidden_states) - mint.distributed.all_reduce(final_hidden_states, group=self.ep_group) + # mint.distributed.all_reduce(final_hidden_states, group=self.ep_group) + final_hidden_states = self.all_reduce_from_ep_group(final_hidden_states) if self.dp_size > 1: start = 0 if self.dp_rank == 0 else tokens_cumulative[self.dp_rank - 1] end = tokens_cumulative[self.dp_rank] -- Gitee From 45424e42c5373652f7cf23639c5aa224c44d60bf Mon Sep 17 00:00:00 2001 From: lvhaoyu Date: Thu, 26 Jun 2025 11:58:16 +0800 Subject: [PATCH 29/77] test moe2 --- .../layers/fused_moe/fused_moe2.py | 186 ++++++++++++++++++ 1 file changed, 186 insertions(+) create mode 100644 vllm_mindspore/model_executor/layers/fused_moe/fused_moe2.py diff --git a/vllm_mindspore/model_executor/layers/fused_moe/fused_moe2.py b/vllm_mindspore/model_executor/layers/fused_moe/fused_moe2.py new file mode 100644 index 00000000..f3c44146 --- /dev/null +++ b/vllm_mindspore/model_executor/layers/fused_moe/fused_moe2.py @@ -0,0 +1,186 @@ +from typing import Optional + +from mindspore import Tensor, mint, ops, nn +from mindspore.ops.auto_generate import (GroupedMatmulV4, + FusedAddTopKDiv, + MoeInitRoutingV2, + MoeTokenUnpermute) +import mindspore as ms +from vllm.distributed.parallel_state import get_ep_group, get_dp_group + +def fused_topk( + hidden_states: Tensor, + gating_output: Tensor, + topk: int, + renormalize: bool, + indices_type = None, +) -> tuple[Tensor, Tensor]: + assert hidden_states.shape[0] == gating_output.shape[0], ( + "Number of tokens mismatch") + score = mint.softmax(gating_output, dim=-1) + topk_weights, topk_ids = mint.topk( + score, + k=topk, + dim=-1 + ) + if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + + if indices_type is not None: + topk_ids = topk_ids.to(indices_type) + return topk_weights, topk_ids + + +def grouped_topk( + hidden_states: Tensor, + gating_output: Tensor, + topk: int, + renormalize: bool, + num_expert_group: int = 0, + topk_group: int = 0, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[Tensor] = None +) -> tuple[Tensor, Tensor]: + fused_add_topk_div = FusedAddTopKDiv() + assert hidden_states.shape[0] == gating_output.shape[0], ( + "Number of tokens mismatch") + scoring_type = 0 # sigmoid + topk_in_group = 2 + topk_weights, topk_ids = fused_add_topk_div( + gating_output, + e_score_correction_bias, + num_expert_group, + topk_group, + topk, + topk_in_group, + scoring_type, + renormalize) + + return topk_weights, topk_ids + + +class FusedExperts(nn.Cell): + def __init__(self): + super().__init__() + self.group_matmul_ops = GroupedMatmulV4() + self.moe_init_routing_op = MoeInitRoutingV2() + self.moe_token_unpermute = MoeTokenUnpermute() + + def construct(self, + hidden_states: Tensor, + w1: Tensor, + w2: Tensor, + topk_weights: Tensor, + topk_ids: Tensor, + activation: str = "silu", + global_num_experts: int = -1, + apply_router_weight_on_input: bool = False, + expert_map: Optional[Tensor] = None, + tp_size: int = 1, + ep_size: int = 0) -> Tensor: + + if tp_size >= 1: + # no ep, pure tp + if ep_size == 1: + hidden_states = self._run_tp_moe(hidden_states, w1, w2, topk_ids, topk_weights, + activation, global_num_experts, + apply_router_weight_on_input) + # ep_size > 1 : pure ep or tp + ep + else: + # pure ep + if tp_size == 1: + hidden_states = self._run_ep_moe(hidden_states, w1, w2, topk_ids, topk_weights, + activation, global_num_experts, + apply_router_weight_on_input) + # tp_size > 1 : tp + ep + else: + 
hidden_states = self._run_tp_ep_moe(hidden_states, w1, w2, topk_ids, topk_weights, + activation, global_num_experts, + apply_router_weight_on_input) + + return hidden_states + + + def _gate_activation(self, gate, activation): + if activation == "silu": + return mint.nn.functional.silu(gate) + elif activation == "gelu": + return mint.nn.functional.gelu(gate) + else: + raise ValueError(f"Unsupported activation function: {activation}") + + + + + def _group_matmul(self, hidden_states, weight, group_list): + return self.group_matmul_ops([hidden_states], [weight], + None, None, None, None, None, None, + group_list, split_item=3, group_type=0, group_list_type=1)[0] + + def _run_ep_moe(self, + hidden_states, + w1, + w2, + topk_ids, + topk_weights, + activation, + global_num_experts, + apply_router_weight_on_input): + hidden_states = self._group_matmul(hidden_states, w1, topk_ids) + hidden_states = self._gate_activation(hidden_states, activation) + hidden_states = self._group_matmul(hidden_states, w2, topk_ids) + return hidden_states + + + def _run_tp_moe(self, + hidden_states, + w1, + w2, + topk_ids, + topk_weights, + activation, + global_num_experts, + apply_router_weight_on_input): + topk_weights = topk_weights.astype(hidden_states.dtype) + topk_ids = topk_ids.astype(ms.int32) + + sorted_input_tensor, unsort_map, group_list, _ = \ + self.moe_init_routing_op( + hidden_states, + topk_ids, + active_num=0, + expert_capacity=0, + expert_num=global_num_experts, + drop_pad_mode=0, + expert_tokens_count_or_cumsum_flag=2, + expert_tokens_before_capacity_flag=True) + + group_list = group_list.astype(ms.int64) + + gate_hidden_out = self._group_matmul(sorted_input_tensor, mint.transpose(w1, -1, -2), group_list) + gate, hidden = mint.split(gate_hidden_out, + (w1.shape[1] // 2, w1.shape[1] // 2), -1) + gate = self._gate_activation(gate, activation) + hidden = mint.mul(hidden, gate) + expert_output = self._group_matmul(hidden, mint.transpose(w2, -1, -2), group_list) + expert_output = mint.nan_to_num(expert_output, 0, 0, 0) + moe_output = self.moe_token_unpermute(permuted_tokens=expert_output, + sorted_indices=unsort_map, + probs=topk_weights, + padded_mode=False, + restore_shape=None) + return moe_output + + + def _run_tp_ep_moe( + self, + hidden_states, + w1, + w2, + group_list, + group_logits, + activation, + global_num_experts, + apply_router_weight_on_input): + raise NotImplementedError( + "TP + EP MoE is not implemented yet. 
Please use pure TP or pure EP MoE instead.") -- Gitee From d7a9bc9066d52713124ec6432531d136625851d1 Mon Sep 17 00:00:00 2001 From: lvhaoyu Date: Thu, 26 Jun 2025 15:50:15 +0800 Subject: [PATCH 30/77] support jit --- vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py b/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py index 6beb2aa5..a7b3bf7d 100644 --- a/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py @@ -15,8 +15,6 @@ def fused_topk( renormalize: bool, indices_type = None, ) -> tuple[Tensor, Tensor]: - assert hidden_states.shape[0] == gating_output.shape[0], ( - "Number of tokens mismatch") score = mint.softmax(gating_output, dim=-1) topk_weights, topk_ids = mint.topk( score, -- Gitee From df37c48aadb2b0c5fed5cca72bf51f8e06275d76 Mon Sep 17 00:00:00 2001 From: lvhaoyu Date: Thu, 26 Jun 2025 20:09:02 +0800 Subject: [PATCH 31/77] suit tp+dp jit --- .../model_executor/layers/fused_moe/layer.py | 33 ++-- .../model_executor/models/qwen3_moe.py | 157 +++++++++++++++++- 2 files changed, 162 insertions(+), 28 deletions(-) diff --git a/vllm_mindspore/model_executor/layers/fused_moe/layer.py b/vllm_mindspore/model_executor/layers/fused_moe/layer.py index 949fdbef..863330f7 100644 --- a/vllm_mindspore/model_executor/layers/fused_moe/layer.py +++ b/vllm_mindspore/model_executor/layers/fused_moe/layer.py @@ -829,11 +829,12 @@ class FusedMoE(nn.Cell): return final_hidden_states def construct(self, hidden_states: Tensor, - router_logits: Tensor): - return self.forward_impl(hidden_states, router_logits) + router_logits: Tensor, + dp_pad_input): + return self.forward_impl(hidden_states, router_logits, dp_pad_input) def forward_impl(self, hidden_states: Tensor, - router_logits: Tensor): + router_logits: Tensor, dp_pad_input): assert self.quant_method is not None # do_naive_dispatch_combine: bool = ( @@ -844,24 +845,12 @@ class FusedMoE(nn.Cell): # hidden_states, router_logits) if self.dp_size > 1 and self.pure_tp: - tokens_num = Tensor([[hidden_states.shape[0]]], dtype=ms.int32) - tokens_num_total = self.all_gather_from_dp_group(tokens_num) - tokens_num_total = tokens_num_total.reshape(-1) - tokens_cumulative = mint.cumsum(tokens_num_total, 0) - start = 0 if self.dp_rank == 0 else tokens_cumulative[self.dp_rank - 1].item() - end = tokens_cumulative[self.dp_rank].item() - - hidden_buffer = mint.zeros((tokens_cumulative[-1].item(), - hidden_states.shape[-1]), - dtype=hidden_states.dtype) - hidden_buffer[start:end] = hidden_states - # mint.distributed.all_reduce(hidden_buffer, group=self.dp_group) + tokens_num = hidden_states.shape[0] + + hidden_buffer = mint.nn.functional.pad(hidden_states, dp_pad_input) hidden_buffer = self.all_reduce_from_dp_group(hidden_buffer) - logit_buffer = mint.zeros((tokens_cumulative[-1].item(), - router_logits.shape[-1]), - dtype=router_logits.dtype) - logit_buffer[start:end] = router_logits - # mint.distributed.all_reduce(logit_buffer, group=self.dp_group) + + logit_buffer = mint.nn.functional.pad(router_logits, dp_pad_input) logit_buffer = self.all_reduce_from_dp_group(logit_buffer) hidden_states = hidden_buffer @@ -891,8 +880,8 @@ class FusedMoE(nn.Cell): # mint.distributed.all_reduce(final_hidden_states, group=self.ep_group) final_hidden_states = self.all_reduce_from_ep_group(final_hidden_states) if self.dp_size > 1: - start = 0 if self.dp_rank == 0 else 
tokens_cumulative[self.dp_rank - 1] - end = tokens_cumulative[self.dp_rank] + start = dp_pad_input[-2] + end = start + tokens_num final_hidden_states = final_hidden_states[start:end] # if do_naive_dispatch_combine: diff --git a/vllm_mindspore/model_executor/models/qwen3_moe.py b/vllm_mindspore/model_executor/models/qwen3_moe.py index adb710b4..da88a23e 100644 --- a/vllm_mindspore/model_executor/models/qwen3_moe.py +++ b/vllm_mindspore/model_executor/models/qwen3_moe.py @@ -24,10 +24,15 @@ from collections.abc import Iterable from typing import Any, Optional, Union, Dict, Tuple, List -from mindspore import Tensor, nn, Parameter +import mindspore as ms +from mindspore import Tensor, nn, Parameter, mint +from mindspore import Tensor, nn, mutable +from mindspore.common import dtype as mstype + from transformers import PretrainedConfig from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_world_size, + get_dp_group) from vllm.logger import init_logger from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.models.interfaces import SupportsPP @@ -54,6 +59,7 @@ from vllm_mindspore.model_executor.models.utils import ( from vllm_mindspore.model_executor.models.model_base import NativeModel from vllm_mindspore.model_executor.layers.sampler import (SamplerOutput, get_sampler) +from vllm_mindspore.utils import STR_DTYPE_TO_MS_DTYPE logger = init_logger(__name__) @@ -86,7 +92,7 @@ class Qwen3MoeMLP(nn.Cell): "Only silu is supported for now.") self.act_fn = SiluAndMul() - def construct(self, x): + def construct(self, x, dp_pad_input): gate_up, _ = self.gate_up_proj(x) x = self.act_fn(gate_up) x, _ = self.down_proj(x) @@ -124,7 +130,7 @@ class Qwen3MoeSparseMoeBlock(nn.Cell): quant_config=None, prefix=f"{prefix}.gate") - def construct(self, hidden_states: Tensor) -> Tensor: + def construct(self, hidden_states: Tensor, dp_pad_input) -> Tensor: # NOTE: hidden_states can have either 1D or 2D shape. 
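A note on the pad trick this patch switches to: mint.nn.functional.pad takes its pad widths last dimension first, so the four values threaded through as dp_pad_input / dp_pad_index, (0, 0, start, after), leave the hidden dimension untouched and pad the token dimension with `start` rows before and `after` rows after the local tokens. Every rank therefore produces a buffer shaped like the full DP batch without any `.item()` host sync, which is what keeps the path traceable under ms.jit. A small sketch of the semantics (a plain tuple is used here for clarity; the patch passes the same four values as an int32 Tensor so they can be fed as a graph input):

    import mindspore as ms
    from mindspore import mint

    # this rank owns 5 tokens of a 10-token DP batch, preceded by 3 tokens
    local = mint.ones((5, 4), dtype=ms.float32)
    padded = mint.nn.functional.pad(local, (0, 0, 3, 2))   # hidden: +0/+0, tokens: +3 before / +2 after
    assert padded.shape == (10, 4)
    # after the DP all_reduce of every rank's padded buffer, this rank's rows
    # sit at [3:3 + 5] of the summed result, matching the [-2] element used for the slice
    recovered = padded[3:3 + 5]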
orig_shape = hidden_states.shape hidden_dim = hidden_states.shape[-1] @@ -311,6 +317,7 @@ class Qwen3MoeDecoderLayer(nn.Cell): q_seq_lens: Tensor, block_tables: Tensor, residual: Optional[Tensor], + dp_pad_input: Optional[bool] = None, ) -> Tensor: # Self Attention if residual is None: @@ -326,7 +333,7 @@ class Qwen3MoeDecoderLayer(nn.Cell): # Fully Connected hidden_states, residual = self.post_attention_layernorm( hidden_states, residual) - hidden_states = self.mlp(hidden_states) + hidden_states = self.mlp(hidden_states, dp_pad_input) return hidden_states, residual @@ -376,6 +383,7 @@ class Qwen3MoeModel(nn.Cell): block_tables: Tensor, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[Tensor] = None, + dp_pad_input = None, ) -> Union[Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: if inputs_embeds is not None: @@ -394,7 +402,8 @@ class Qwen3MoeModel(nn.Cell): value_caches[i - self.start_layer], is_prefill, slot_mapping, attn_mask, batch_valid_length, - q_seq_lens, block_tables, residual) + q_seq_lens, block_tables, residual, + dp_pad_input) if not get_pp_group().is_last_rank: return IntermediateTensors({ "hidden_states": hidden_states, @@ -537,6 +546,12 @@ class Qwen3MoeForCausalLM(NativeModel, SupportsPP): self.common_preprocess(vllm_config, prefix) + if get_dp_group().world_size > 1 and not self.parallel_config.enable_expert_parallel: + self.dp_pad_input = True + self.dp_group = get_dp_group().device_group._name + self.dp_world_size = get_dp_group().world_size + self.dp_rank = get_dp_group().rank_in_group + def get_input_embeddings(self, input_ids: Tensor) -> Tensor: return self.model.get_input_embeddings(input_ids) @@ -570,3 +585,133 @@ class Qwen3MoeForCausalLM(NativeModel, SupportsPP): Tensor]]) -> set[str]: params_dict = self.get_params_dict() return self.model.load_weights(weights, params_dict) + + def exec_model(self, + input_ids: Tensor, + positions: Tensor, + intermediate_tensors: IntermediateTensors = None, + inputs_embeds: Tensor = None, + **kwargs): + model_inputs, is_prefill = self.prepare_inputs(input_ids, positions, + intermediate_tensors, + inputs_embeds) + + if self.prev_prefill != is_prefill and self.is_graph_mode: + self.set_model_inputs(input_ids, positions, intermediate_tensors, + inputs_embeds, is_prefill) + self.prev_prefill = is_prefill + + # for dummy_attention_metadata + if is_prefill and not self.set_flags: + self.set_flags = True + + if self.run_model is None: + self.run_model = ms.jit( + function=self.model, # type: ignore[attr-defined] + jit_level='O0' + ) if self.is_graph_mode else self.model # type: ignore[attr-defined] + + if self.dp_pad_input: + # if dp and not ep, should pad input to gather. 
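The bookkeeping that follows derives those pad amounts from an all-gather of per-rank token counts: with counts [3, 5, 2] the cumulative sums are [3, 8, 10], so DP rank 1 pads 3 rows in front of and 10 - 8 = 2 rows behind its 5 local tokens. The same arithmetic, sketched without any collectives (the counts are assumed to be already gathered):

    import numpy as np

    def dp_pad_spec(token_counts, dp_rank):
        # token_counts: tokens per DP rank, e.g. the gathered token_num_total
        cu = np.cumsum(token_counts)
        start = 0 if dp_rank == 0 else int(cu[dp_rank - 1])
        after = int(cu[-1] - cu[dp_rank])
        return [0, 0, start, after]        # same layout as the dp_pad_index tensor built below

    assert dp_pad_spec([3, 5, 2], 0) == [0, 0, 0, 7]
    assert dp_pad_spec([3, 5, 2], 1) == [0, 0, 3, 2]
    assert dp_pad_spec([3, 5, 2], 2) == [0, 0, 8, 0]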
+ token_num_total = mint.empty((self.dp_world_size, 1), dtype=ms.int32) + send_tensor = ms.Tensor([[input_ids.shape[0]]], dtype=ms.int32) + mint.distributed.all_gather_into_tensor(token_num_total, send_tensor, + group=self.dp_group) + token_num_total = token_num_total.reshape(-1) + tokens_cumulative = mint.cumsum(token_num_total, dim=0) + start = 0 if self.dp_rank == 0 else tokens_cumulative[self.dp_rank - 1].item() + end = tokens_cumulative[self.dp_rank].item() + end2 = tokens_cumulative[-1].item() - end + pad_index_tensor = ms.Tensor([0, 0, start, end2], dtype=ms.int32) + + model_output = self.run_model( # type: ignore[misc] + input_ids=model_inputs["input_ids"], + positions=model_inputs["position_ids"], + key_caches=model_inputs["key_cache"], + value_caches=model_inputs["value_cache"], + is_prefill=is_prefill, + slot_mapping=model_inputs["slot_mapping"], + attn_mask=model_inputs["attention_mask"], + batch_valid_length=model_inputs["batch_valid_length"], + q_seq_lens=model_inputs["q_seq_lens"], + block_tables=model_inputs["block_tables"], + intermediate_tensors=model_inputs["intermediate_tensors"], + inputs_embeds=model_inputs["inputs_embeds"], + pad_index_tensor=pad_index_tensor if self.dp_pad_input else None, + ) + + return model_output + + + def set_model_inputs(self, input_ids, position_ids, intermediate_tensors, + inputs_embeds, is_prefill): + if input_ids is None: + dyn_input_ids = None + else: + dyn_input_ids = ms.Tensor(shape=[None] * input_ids.ndim, + dtype=mstype.int32) + + if position_ids is None: + dyn_position_ids = None + else: + dyn_position_ids = ms.Tensor(shape=[None] * position_ids.ndim, + dtype=mstype.int32) + + if inputs_embeds is None: + dyn_inputs_embeds = None + else: + dyn_inputs_embeds = ms.Tensor(shape=[None] * inputs_embeds.ndim, + dtype=inputs_embeds.dtype) + + if intermediate_tensors is None: + dyn_intermediate_tensors = None + else: + dyn_intermediate_tensors = ms.Tensor( + shape=[None] * intermediate_tensors.ndim, + dtype=intermediate_tensors.dtype) + + block_size = self.cache_config.block_size + num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config) + head_size = self.model_config.get_head_size() + kv_cache_shape = (None, block_size, num_kv_heads, head_size) + + kv_cache_dtype = self.model_config.dtype if self.cache_config.cache_dtype == "auto" \ + else self.cache_config.cache_dtype + if kv_cache_dtype in STR_DTYPE_TO_MS_DTYPE: + kv_cache_dtype = STR_DTYPE_TO_MS_DTYPE[kv_cache_dtype] + + num_layers = self.model_config.get_num_layers(self.parallel_config) + + dyn_key_cache = Tensor(shape=kv_cache_shape, dtype=kv_cache_dtype) + dyn_value_cache = Tensor(shape=kv_cache_shape, dtype=kv_cache_dtype) + dyn_key_caches = mutable([dyn_key_cache for _ in range(num_layers)]) + dyn_value_caches = mutable( + [dyn_value_cache for _ in range(num_layers)]) + + dyn_slot_mapping = Tensor(shape=[None], dtype=mstype.int32) + dynamic_attention_mask = Tensor(shape=[None, None], + dtype=self.model_config.dtype) + dyn_batch_valid_length = Tensor(shape=[None], dtype=mstype.int32) + dyn_q_seq_lens = Tensor(shape=[None], dtype=mstype.int32) + dyn_block_tables = Tensor(shape=[None, None], dtype=mstype.int32) + dyn_pad_input = Tensor(shape=[4], dtype=mstype.int32) if self.dp_pad_input else None + + self.model.set_inputs( + dyn_input_ids, + dyn_position_ids, + dyn_key_caches, # type: ignore[attr-defined] + dyn_value_caches, + is_prefill, + dyn_slot_mapping, + dynamic_attention_mask, + dyn_batch_valid_length, + dyn_q_seq_lens, + dyn_block_tables, + 
dyn_intermediate_tensors, + dyn_inputs_embeds, + dyn_pad_input) + + dynamic_hidden_states = Tensor(shape=[None, None], + dtype=self.model_config.dtype) + self.lm_head.set_inputs( + dynamic_hidden_states) # type: ignore[attr-defined] -- Gitee From 638b3e15d0c88695f5cc67e6c8842b03feb51dd4 Mon Sep 17 00:00:00 2001 From: lvhaoyu Date: Thu, 26 Jun 2025 20:22:16 +0800 Subject: [PATCH 32/77] update --- .../model_executor/layers/fused_moe/layer.py | 12 ++++++------ .../model_executor/models/qwen3_moe.py | 19 ++++++++++--------- 2 files changed, 16 insertions(+), 15 deletions(-) diff --git a/vllm_mindspore/model_executor/layers/fused_moe/layer.py b/vllm_mindspore/model_executor/layers/fused_moe/layer.py index 863330f7..d280acb2 100644 --- a/vllm_mindspore/model_executor/layers/fused_moe/layer.py +++ b/vllm_mindspore/model_executor/layers/fused_moe/layer.py @@ -830,11 +830,11 @@ class FusedMoE(nn.Cell): def construct(self, hidden_states: Tensor, router_logits: Tensor, - dp_pad_input): - return self.forward_impl(hidden_states, router_logits, dp_pad_input) + dp_pad_index): + return self.forward_impl(hidden_states, router_logits, dp_pad_index) def forward_impl(self, hidden_states: Tensor, - router_logits: Tensor, dp_pad_input): + router_logits: Tensor, dp_pad_index): assert self.quant_method is not None # do_naive_dispatch_combine: bool = ( @@ -847,10 +847,10 @@ class FusedMoE(nn.Cell): if self.dp_size > 1 and self.pure_tp: tokens_num = hidden_states.shape[0] - hidden_buffer = mint.nn.functional.pad(hidden_states, dp_pad_input) + hidden_buffer = mint.nn.functional.pad(hidden_states, dp_pad_index) hidden_buffer = self.all_reduce_from_dp_group(hidden_buffer) - logit_buffer = mint.nn.functional.pad(router_logits, dp_pad_input) + logit_buffer = mint.nn.functional.pad(router_logits, dp_pad_index) logit_buffer = self.all_reduce_from_dp_group(logit_buffer) hidden_states = hidden_buffer @@ -880,7 +880,7 @@ class FusedMoE(nn.Cell): # mint.distributed.all_reduce(final_hidden_states, group=self.ep_group) final_hidden_states = self.all_reduce_from_ep_group(final_hidden_states) if self.dp_size > 1: - start = dp_pad_input[-2] + start = dp_pad_index[-2] end = start + tokens_num final_hidden_states = final_hidden_states[start:end] diff --git a/vllm_mindspore/model_executor/models/qwen3_moe.py b/vllm_mindspore/model_executor/models/qwen3_moe.py index da88a23e..0e492e0d 100644 --- a/vllm_mindspore/model_executor/models/qwen3_moe.py +++ b/vllm_mindspore/model_executor/models/qwen3_moe.py @@ -92,7 +92,7 @@ class Qwen3MoeMLP(nn.Cell): "Only silu is supported for now.") self.act_fn = SiluAndMul() - def construct(self, x, dp_pad_input): + def construct(self, x, dp_pad_index): gate_up, _ = self.gate_up_proj(x) x = self.act_fn(gate_up) x, _ = self.down_proj(x) @@ -130,7 +130,7 @@ class Qwen3MoeSparseMoeBlock(nn.Cell): quant_config=None, prefix=f"{prefix}.gate") - def construct(self, hidden_states: Tensor, dp_pad_input) -> Tensor: + def construct(self, hidden_states: Tensor, dp_pad_index) -> Tensor: # NOTE: hidden_states can have either 1D or 2D shape. 
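# A small standalone sketch of the shape handling mentioned in the note above
# (hypothetical sizes, numpy in place of the tensor ops): a 1-D input is viewed
# as a single token row for routing and restored to the caller's shape after.
import numpy as np
hidden_dim = 8
hidden_1d = np.zeros(hidden_dim)                # 1-D case: a single token
orig_shape_1d = hidden_1d.shape
tokens = hidden_1d.reshape(-1, hidden_dim)      # (1, hidden_dim) seen by the router
restored = tokens.reshape(orig_shape_1d)        # back to the original view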
orig_shape = hidden_states.shape hidden_dim = hidden_states.shape[-1] @@ -139,7 +139,8 @@ class Qwen3MoeSparseMoeBlock(nn.Cell): # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) final_hidden_states = self.experts(hidden_states=hidden_states, - router_logits=router_logits) + router_logits=router_logits, + dp_pad_index=dp_pad_index) final_hidden_states = final_hidden_states if self.tp_size > 1: final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( # noqa E501 @@ -317,7 +318,7 @@ class Qwen3MoeDecoderLayer(nn.Cell): q_seq_lens: Tensor, block_tables: Tensor, residual: Optional[Tensor], - dp_pad_input: Optional[bool] = None, + dp_pad_index: Optional[bool] = None, ) -> Tensor: # Self Attention if residual is None: @@ -333,7 +334,7 @@ class Qwen3MoeDecoderLayer(nn.Cell): # Fully Connected hidden_states, residual = self.post_attention_layernorm( hidden_states, residual) - hidden_states = self.mlp(hidden_states, dp_pad_input) + hidden_states = self.mlp(hidden_states, dp_pad_index) return hidden_states, residual @@ -383,7 +384,7 @@ class Qwen3MoeModel(nn.Cell): block_tables: Tensor, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[Tensor] = None, - dp_pad_input = None, + dp_pad_index = None, ) -> Union[Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: if inputs_embeds is not None: @@ -403,7 +404,7 @@ class Qwen3MoeModel(nn.Cell): is_prefill, slot_mapping, attn_mask, batch_valid_length, q_seq_lens, block_tables, residual, - dp_pad_input) + dp_pad_index) if not get_pp_group().is_last_rank: return IntermediateTensors({ "hidden_states": hidden_states, @@ -622,7 +623,7 @@ class Qwen3MoeForCausalLM(NativeModel, SupportsPP): start = 0 if self.dp_rank == 0 else tokens_cumulative[self.dp_rank - 1].item() end = tokens_cumulative[self.dp_rank].item() end2 = tokens_cumulative[-1].item() - end - pad_index_tensor = ms.Tensor([0, 0, start, end2], dtype=ms.int32) + dp_pad_index = ms.Tensor([0, 0, start, end2], dtype=ms.int32) model_output = self.run_model( # type: ignore[misc] input_ids=model_inputs["input_ids"], @@ -637,7 +638,7 @@ class Qwen3MoeForCausalLM(NativeModel, SupportsPP): block_tables=model_inputs["block_tables"], intermediate_tensors=model_inputs["intermediate_tensors"], inputs_embeds=model_inputs["inputs_embeds"], - pad_index_tensor=pad_index_tensor if self.dp_pad_input else None, + dp_pad_index=dp_pad_index if self.dp_pad_input else None, ) return model_output -- Gitee From 89e46734c6557aa6f29bac6bd53930fb35f0a336 Mon Sep 17 00:00:00 2001 From: lvhaoyu Date: Thu, 26 Jun 2025 20:55:09 +0800 Subject: [PATCH 33/77] fix --- vllm_mindspore/model_executor/models/qwen3_moe.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm_mindspore/model_executor/models/qwen3_moe.py b/vllm_mindspore/model_executor/models/qwen3_moe.py index 0e492e0d..339508d5 100644 --- a/vllm_mindspore/model_executor/models/qwen3_moe.py +++ b/vllm_mindspore/model_executor/models/qwen3_moe.py @@ -547,6 +547,7 @@ class Qwen3MoeForCausalLM(NativeModel, SupportsPP): self.common_preprocess(vllm_config, prefix) + self.dp_pad_input = False if get_dp_group().world_size > 1 and not self.parallel_config.enable_expert_parallel: self.dp_pad_input = True self.dp_group = get_dp_group().device_group._name -- Gitee From 6faab79abda2b484557264eb58465d04808ad9de Mon Sep 17 00:00:00 2001 From: lvhaoyu Date: Fri, 27 Jun 2025 10:27:42 +0800 Subject: [PATCH 34/77] update --- .../model_executor/layers/fused_moe/layer.py | 39 ++++++++++++------- 
.../model_executor/models/qwen3_moe.py | 36 ++++++++++++----- 2 files changed, 52 insertions(+), 23 deletions(-) diff --git a/vllm_mindspore/model_executor/layers/fused_moe/layer.py b/vllm_mindspore/model_executor/layers/fused_moe/layer.py index d280acb2..1ac6b782 100644 --- a/vllm_mindspore/model_executor/layers/fused_moe/layer.py +++ b/vllm_mindspore/model_executor/layers/fused_moe/layer.py @@ -540,6 +540,7 @@ class FusedMoE(nn.Cell): if self.dp_size > 1: self.all_gather_from_dp_group = ops.AllGather(group=self.dp_group) self.all_reduce_from_dp_group = ops.AllReduce(group=self.dp_group) + self.reduce_scatter_from_dp_group = ops.ReduceScatter(group=self.dp_group) @property def tp_size(self): @@ -829,12 +830,13 @@ class FusedMoE(nn.Cell): return final_hidden_states def construct(self, hidden_states: Tensor, - router_logits: Tensor, - dp_pad_index): - return self.forward_impl(hidden_states, router_logits, dp_pad_index) + router_logits: Tensor, + dp_pad_index, + dp_select_index): + return self.forward_impl(hidden_states, router_logits, dp_pad_index, dp_select_index) def forward_impl(self, hidden_states: Tensor, - router_logits: Tensor, dp_pad_index): + router_logits: Tensor, dp_pad_index, dp_select_index): assert self.quant_method is not None # do_naive_dispatch_combine: bool = ( @@ -847,14 +849,21 @@ class FusedMoE(nn.Cell): if self.dp_size > 1 and self.pure_tp: tokens_num = hidden_states.shape[0] + # hidden_buffer = mint.nn.functional.pad(hidden_states, dp_pad_index) + # hidden_buffer = self.all_reduce_from_dp_group(hidden_buffer) + + # logit_buffer = mint.nn.functional.pad(router_logits, dp_pad_index) + # logit_buffer = self.all_reduce_from_dp_group(logit_buffer) + + # ops.AllGather is not supported for uneven size tensor, so need to pad to same size. hidden_buffer = mint.nn.functional.pad(hidden_states, dp_pad_index) - hidden_buffer = self.all_reduce_from_dp_group(hidden_buffer) + hidden_buffer = self.all_gather_from_dp_group(hidden_buffer) logit_buffer = mint.nn.functional.pad(router_logits, dp_pad_index) - logit_buffer = self.all_reduce_from_dp_group(logit_buffer) + logit_buffer = self.all_gather_from_dp_group(logit_buffer) - hidden_states = hidden_buffer - router_logits = logit_buffer + # hidden_states = mint.index_select(hidden_buffer, 0, dp_select_index) + # router_logits = mint.index_select(logit_buffer, 0, dp_select_index) # Matrix multiply. 
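# A standalone numpy sketch of the pad-and-gather bookkeeping above, assuming
# hypothetical per-DP-rank token counts [1, 3, 4, 2]: every rank pads to the
# maximum (4 rows) so the AllGather sees equally sized tensors, and the real
# rows are recovered from the gathered buffer by an offset index such as:
import numpy as np
token_counts = np.array([1, 3, 4, 2])
max_tokens = int(token_counts.max())
select_index = np.concatenate(
    [np.arange(n) + rank * max_tokens for rank, n in enumerate(token_counts)])
# select_index -> [0, 4, 5, 6, 8, 9, 10, 11, 12, 13]: the positions of the real
# tokens inside the (dp_size * max_tokens)-row gathered buffer.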
final_hidden_states = self.quant_method.apply( @@ -878,11 +887,15 @@ class FusedMoE(nn.Cell): if self.pure_tp: # final_hidden_states = self.all_reduce_from_world_group(final_hidden_states) # mint.distributed.all_reduce(final_hidden_states, group=self.ep_group) - final_hidden_states = self.all_reduce_from_ep_group(final_hidden_states) - if self.dp_size > 1: - start = dp_pad_index[-2] - end = start + tokens_num - final_hidden_states = final_hidden_states[start:end] + if self.dp_size == 1: + final_hidden_states = self.all_reduce_from_ep_group(final_hidden_states) + # dp_size > 1 + else: + final_hidden_states = self.reduce_scatter_from_ep_group(final_hidden_states) + final_hidden_states = mint.index_select(final_hidden_states, 0, dp_select_index) + # start = dp_pad_index[-2] + # end = start + tokens_num + # final_hidden_states = final_hidden_states[start:end] # if do_naive_dispatch_combine: # final_hidden_states = get_ep_group().combine(final_hidden_states) diff --git a/vllm_mindspore/model_executor/models/qwen3_moe.py b/vllm_mindspore/model_executor/models/qwen3_moe.py index 339508d5..f3ab0a9a 100644 --- a/vllm_mindspore/model_executor/models/qwen3_moe.py +++ b/vllm_mindspore/model_executor/models/qwen3_moe.py @@ -24,6 +24,7 @@ from collections.abc import Iterable from typing import Any, Optional, Union, Dict, Tuple, List +import numpy as np import mindspore as ms from mindspore import Tensor, nn, Parameter, mint from mindspore import Tensor, nn, mutable @@ -92,7 +93,7 @@ class Qwen3MoeMLP(nn.Cell): "Only silu is supported for now.") self.act_fn = SiluAndMul() - def construct(self, x, dp_pad_index): + def construct(self, x, dp_pad_index, dp_select_index): gate_up, _ = self.gate_up_proj(x) x = self.act_fn(gate_up) x, _ = self.down_proj(x) @@ -319,6 +320,7 @@ class Qwen3MoeDecoderLayer(nn.Cell): block_tables: Tensor, residual: Optional[Tensor], dp_pad_index: Optional[bool] = None, + dp_select_index: Optional[Tensor] = None, ) -> Tensor: # Self Attention if residual is None: @@ -334,7 +336,7 @@ class Qwen3MoeDecoderLayer(nn.Cell): # Fully Connected hidden_states, residual = self.post_attention_layernorm( hidden_states, residual) - hidden_states = self.mlp(hidden_states, dp_pad_index) + hidden_states = self.mlp(hidden_states, dp_pad_index, dp_select_index) return hidden_states, residual @@ -385,6 +387,7 @@ class Qwen3MoeModel(nn.Cell): intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[Tensor] = None, dp_pad_index = None, + dp_select_index: Optional[Tensor] = None, ) -> Union[Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: if inputs_embeds is not None: @@ -404,7 +407,7 @@ class Qwen3MoeModel(nn.Cell): is_prefill, slot_mapping, attn_mask, batch_valid_length, q_seq_lens, block_tables, residual, - dp_pad_index) + dp_pad_index, dp_select_index) if not get_pp_group().is_last_rank: return IntermediateTensors({ "hidden_states": hidden_states, @@ -620,11 +623,20 @@ class Qwen3MoeForCausalLM(NativeModel, SupportsPP): mint.distributed.all_gather_into_tensor(token_num_total, send_tensor, group=self.dp_group) token_num_total = token_num_total.reshape(-1) - tokens_cumulative = mint.cumsum(token_num_total, dim=0) - start = 0 if self.dp_rank == 0 else tokens_cumulative[self.dp_rank - 1].item() - end = tokens_cumulative[self.dp_rank].item() - end2 = tokens_cumulative[-1].item() - end - dp_pad_index = ms.Tensor([0, 0, start, end2], dtype=ms.int32) + # tokens_cumulative = mint.cumsum(token_num_total, dim=0) + # start = 0 if self.dp_rank == 0 else 
tokens_cumulative[self.dp_rank - 1].item() + # end = tokens_cumulative[self.dp_rank].item() + # end2 = tokens_cumulative[-1].item() - end + # dp_pad_index = ms.Tensor([0, 0, start, end2], dtype=ms.int32) + token_num_total = token_num_total.asnumpy() + max_token_num = int(token_num_total.max()) + total_pad_num = (max_token_num - token_num_total) + this_pad_num = total_pad_num[self.dp_rank] + dp_pad_index = ms.Tensor([0, 0, 0, int(this_pad_num)], dtype=mstype.int32) + dp_select_index = [j + self.dp_rank * max_token_num + for j in range(token_num_total[self.dp_rank])] + dp_select_index = ms.Tensor(dp_select_index, dtype=mstype.int32) + model_output = self.run_model( # type: ignore[misc] input_ids=model_inputs["input_ids"], @@ -640,6 +652,7 @@ class Qwen3MoeForCausalLM(NativeModel, SupportsPP): intermediate_tensors=model_inputs["intermediate_tensors"], inputs_embeds=model_inputs["inputs_embeds"], dp_pad_index=dp_pad_index if self.dp_pad_input else None, + dp_select_index=dp_select_index if self.dp_pad_input else None ) return model_output @@ -696,7 +709,9 @@ class Qwen3MoeForCausalLM(NativeModel, SupportsPP): dyn_batch_valid_length = Tensor(shape=[None], dtype=mstype.int32) dyn_q_seq_lens = Tensor(shape=[None], dtype=mstype.int32) dyn_block_tables = Tensor(shape=[None, None], dtype=mstype.int32) - dyn_pad_input = Tensor(shape=[4], dtype=mstype.int32) if self.dp_pad_input else None + dyn_dp_pad_input = Tensor(shape=[4], dtype=mstype.int32) if self.dp_pad_input else None + dyn_dp_select_index = Tensor(shape=[None], dtype=mstype.int32) if self.dp_pad_input else None + self.model.set_inputs( dyn_input_ids, @@ -711,7 +726,8 @@ class Qwen3MoeForCausalLM(NativeModel, SupportsPP): dyn_block_tables, dyn_intermediate_tensors, dyn_inputs_embeds, - dyn_pad_input) + dyn_dp_pad_input, + dyn_dp_select_index) dynamic_hidden_states = Tensor(shape=[None, None], dtype=self.model_config.dtype) -- Gitee From 9887e4af707bd8fd2a4b909e8083bf20e5890145 Mon Sep 17 00:00:00 2001 From: lvhaoyu Date: Fri, 27 Jun 2025 14:10:12 +0800 Subject: [PATCH 35/77] update --- .../model_executor/layers/fused_moe/layer.py | 42 ++++++------- .../model_executor/models/qwen3_moe.py | 62 +++++++++++++------ 2 files changed, 63 insertions(+), 41 deletions(-) diff --git a/vllm_mindspore/model_executor/layers/fused_moe/layer.py b/vllm_mindspore/model_executor/layers/fused_moe/layer.py index 1ac6b782..df74ede1 100644 --- a/vllm_mindspore/model_executor/layers/fused_moe/layer.py +++ b/vllm_mindspore/model_executor/layers/fused_moe/layer.py @@ -538,6 +538,7 @@ class FusedMoE(nn.Cell): self.pure_tp = True self.all_reduce_from_ep_group = ops.AllReduce(group=self.ep_group) if self.dp_size > 1: + self.gather = ops.Gather() self.all_gather_from_dp_group = ops.AllGather(group=self.dp_group) self.all_reduce_from_dp_group = ops.AllReduce(group=self.dp_group) self.reduce_scatter_from_dp_group = ops.ReduceScatter(group=self.dp_group) @@ -832,23 +833,21 @@ class FusedMoE(nn.Cell): def construct(self, hidden_states: Tensor, router_logits: Tensor, dp_pad_index, - dp_select_index): - return self.forward_impl(hidden_states, router_logits, dp_pad_index, dp_select_index) + dp_unpad_index, + dp_pad_index_with_offset, + dp_unpad_index_total_with_offset): + return self.forward_impl(hidden_states, router_logits, dp_pad_index, + dp_unpad_index, dp_pad_index_with_offset, + dp_unpad_index_total_with_offset) def forward_impl(self, hidden_states: Tensor, - router_logits: Tensor, dp_pad_index, dp_select_index): - assert self.quant_method is not None - - # 
do_naive_dispatch_combine: bool = ( - # self.dp_size > 1 - # and not self.ep_size > 1) - # if do_naive_dispatch_combine: - # hidden_states, router_logits = get_ep_group().dispatch( - # hidden_states, router_logits) + router_logits: Tensor, dp_pad_index, dp_unpad_index, + dp_pad_index_total_with_offset, + dp_unpad_index_total_with_offset): + # dp_pad_index = [0, 1, 2, 3, 0, 0, 0] + # dp_pad_index_with_offset = [5, 6, 7, 8, 0, 0, 0] if self.dp_size > 1 and self.pure_tp: - tokens_num = hidden_states.shape[0] - # hidden_buffer = mint.nn.functional.pad(hidden_states, dp_pad_index) # hidden_buffer = self.all_reduce_from_dp_group(hidden_buffer) @@ -856,14 +855,14 @@ class FusedMoE(nn.Cell): # logit_buffer = self.all_reduce_from_dp_group(logit_buffer) # ops.AllGather is not supported for uneven size tensor, so need to pad to same size. - hidden_buffer = mint.nn.functional.pad(hidden_states, dp_pad_index) + hidden_buffer = self.gather(hidden_states, dp_pad_index) hidden_buffer = self.all_gather_from_dp_group(hidden_buffer) - logit_buffer = mint.nn.functional.pad(router_logits, dp_pad_index) + logit_buffer = self.gather(router_logits, dp_pad_index) logit_buffer = self.all_gather_from_dp_group(logit_buffer) - # hidden_states = mint.index_select(hidden_buffer, 0, dp_select_index) - # router_logits = mint.index_select(logit_buffer, 0, dp_select_index) + hidden_states = mint.index_select(hidden_buffer, 0, dp_unpad_index_total_with_offset) + router_logits = mint.index_select(logit_buffer, 0, dp_unpad_index_total_with_offset) # Matrix multiply. final_hidden_states = self.quant_method.apply( @@ -885,21 +884,18 @@ class FusedMoE(nn.Cell): ) if self.pure_tp: - # final_hidden_states = self.all_reduce_from_world_group(final_hidden_states) - # mint.distributed.all_reduce(final_hidden_states, group=self.ep_group) if self.dp_size == 1: final_hidden_states = self.all_reduce_from_ep_group(final_hidden_states) # dp_size > 1 else: + final_hidden_states = self.gather(final_hidden_states, dp_pad_index_total_with_offset) final_hidden_states = self.reduce_scatter_from_ep_group(final_hidden_states) - final_hidden_states = mint.index_select(final_hidden_states, 0, dp_select_index) + final_hidden_states = self.gather(final_hidden_states, dp_unpad_index) + # final_hidden_states = mint.index_select(final_hidden_states, 0, dp_unpad_index) # start = dp_pad_index[-2] # end = start + tokens_num # final_hidden_states = final_hidden_states[start:end] - # if do_naive_dispatch_combine: - # final_hidden_states = get_ep_group().combine(final_hidden_states) - if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): # Default set to False. (May have to add shared expert outputs.) 
final_hidden_states = tensor_model_parallel_all_reduce( diff --git a/vllm_mindspore/model_executor/models/qwen3_moe.py b/vllm_mindspore/model_executor/models/qwen3_moe.py index f3ab0a9a..bd29b708 100644 --- a/vllm_mindspore/model_executor/models/qwen3_moe.py +++ b/vllm_mindspore/model_executor/models/qwen3_moe.py @@ -93,7 +93,7 @@ class Qwen3MoeMLP(nn.Cell): "Only silu is supported for now.") self.act_fn = SiluAndMul() - def construct(self, x, dp_pad_index, dp_select_index): + def construct(self, x, dp_pad_index, dp_unpad_index, dp_unpad_index_total_with_offset): gate_up, _ = self.gate_up_proj(x) x = self.act_fn(gate_up) x, _ = self.down_proj(x) @@ -131,7 +131,8 @@ class Qwen3MoeSparseMoeBlock(nn.Cell): quant_config=None, prefix=f"{prefix}.gate") - def construct(self, hidden_states: Tensor, dp_pad_index) -> Tensor: + def construct(self, hidden_states: Tensor, dp_pad_index, dp_unpad_index, + dp_pad_index_with_offset, dp_unpad_index_total_with_offset) -> Tensor: # NOTE: hidden_states can have either 1D or 2D shape. orig_shape = hidden_states.shape hidden_dim = hidden_states.shape[-1] @@ -141,7 +142,10 @@ class Qwen3MoeSparseMoeBlock(nn.Cell): router_logits, _ = self.gate(hidden_states) final_hidden_states = self.experts(hidden_states=hidden_states, router_logits=router_logits, - dp_pad_index=dp_pad_index) + dp_pad_index=dp_pad_index, + dp_unpad_index=dp_unpad_index, + dp_pad_index_with_offset=dp_pad_index_with_offset, + dp_unpad_index_total_with_offset=dp_unpad_index_total_with_offset) final_hidden_states = final_hidden_states if self.tp_size > 1: final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( # noqa E501 @@ -320,7 +324,9 @@ class Qwen3MoeDecoderLayer(nn.Cell): block_tables: Tensor, residual: Optional[Tensor], dp_pad_index: Optional[bool] = None, - dp_select_index: Optional[Tensor] = None, + dp_unpad_index: Optional[Tensor] = None, + dp_pad_index_with_offset: Optional[Tensor] = None, + dp_unpad_index_total_with_offset: Optional[Tensor] = None, ) -> Tensor: # Self Attention if residual is None: @@ -336,7 +342,8 @@ class Qwen3MoeDecoderLayer(nn.Cell): # Fully Connected hidden_states, residual = self.post_attention_layernorm( hidden_states, residual) - hidden_states = self.mlp(hidden_states, dp_pad_index, dp_select_index) + hidden_states = self.mlp(hidden_states, dp_pad_index, dp_unpad_index, + dp_pad_index_with_offset, dp_unpad_index_total_with_offset) return hidden_states, residual @@ -387,7 +394,10 @@ class Qwen3MoeModel(nn.Cell): intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[Tensor] = None, dp_pad_index = None, - dp_select_index: Optional[Tensor] = None, + dp_unpad_index: Optional[Tensor] = None, + dp_pad_index_total_with_offset: Optional[Tensor] = None, + dp_unpad_index_total_with_offset: Optional[Tensor] = None, + ) -> Union[Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: if inputs_embeds is not None: @@ -407,7 +417,9 @@ class Qwen3MoeModel(nn.Cell): is_prefill, slot_mapping, attn_mask, batch_valid_length, q_seq_lens, block_tables, residual, - dp_pad_index, dp_select_index) + dp_pad_index, dp_unpad_index, + dp_pad_index_total_with_offset, + dp_unpad_index_total_with_offset) if not get_pp_group().is_last_rank: return IntermediateTensors({ "hidden_states": hidden_states, @@ -629,13 +641,21 @@ class Qwen3MoeForCausalLM(NativeModel, SupportsPP): # end2 = tokens_cumulative[-1].item() - end # dp_pad_index = ms.Tensor([0, 0, start, end2], dtype=ms.int32) token_num_total = token_num_total.asnumpy() - max_token_num = 
int(token_num_total.max()) - total_pad_num = (max_token_num - token_num_total) + max_token_num = token_num_total.max() + total_pad_num = max_token_num - token_num_total this_pad_num = total_pad_num[self.dp_rank] - dp_pad_index = ms.Tensor([0, 0, 0, int(this_pad_num)], dtype=mstype.int32) - dp_select_index = [j + self.dp_rank * max_token_num - for j in range(token_num_total[self.dp_rank])] - dp_select_index = ms.Tensor(dp_select_index, dtype=mstype.int32) + + dp_unpad_index = np.arange(token_num_total[self.dp_rank]) + dp_pad_index = np.pad(dp_unpad_index, (0, this_pad_num)) + + dp_pad_index_total_with_offset = [np.pad(np.arange(token_num_total[rank]), (0, total_pad_num[rank])) + for rank in self.dp_world_size] + dp_pad_index_total_with_offset = np.concatenate(dp_pad_index_total_with_offset, axis=0) + + + dp_unpad_index_total_with_offset = [np.arange(token_num_total[rank]) + rank * max_token_num + for rank in self.dp_world_size] + dp_unpad_index_total_with_offset = ms.Tensor(dp_unpad_index_total_with_offset, dtype=mstype.int32) model_output = self.run_model( # type: ignore[misc] @@ -652,7 +672,9 @@ class Qwen3MoeForCausalLM(NativeModel, SupportsPP): intermediate_tensors=model_inputs["intermediate_tensors"], inputs_embeds=model_inputs["inputs_embeds"], dp_pad_index=dp_pad_index if self.dp_pad_input else None, - dp_select_index=dp_select_index if self.dp_pad_input else None + dp_unpad_index=dp_unpad_index if self.dp_pad_input else None, + dp_pad_index_total_with_offset=dp_pad_index_total_with_offset if self.dp_pad_input else None + dp_unpad_index_total_with_offset=dp_unpad_index_total_with_offset if self.dp_pad_input else None ) return model_output @@ -709,8 +731,10 @@ class Qwen3MoeForCausalLM(NativeModel, SupportsPP): dyn_batch_valid_length = Tensor(shape=[None], dtype=mstype.int32) dyn_q_seq_lens = Tensor(shape=[None], dtype=mstype.int32) dyn_block_tables = Tensor(shape=[None, None], dtype=mstype.int32) - dyn_dp_pad_input = Tensor(shape=[4], dtype=mstype.int32) if self.dp_pad_input else None - dyn_dp_select_index = Tensor(shape=[None], dtype=mstype.int32) if self.dp_pad_input else None + dyn_dp_pad_index = Tensor(shape=[None], dtype=mstype.int32) if self.dp_pad_input else None + dyn_dp_unpad_index = Tensor(shape=[None], dtype=mstype.int32) if self.dp_pad_input else None + dyn_dp_pad_index_with_offset = Tensor(shape=[None], dtype=mstype.int32) if self.dp_pad_input else None + dp_unpad_index_total_with_offset = Tensor(shape=[None], dtype=mstype.int32) if self.dp_pad_input else None self.model.set_inputs( @@ -726,8 +750,10 @@ class Qwen3MoeForCausalLM(NativeModel, SupportsPP): dyn_block_tables, dyn_intermediate_tensors, dyn_inputs_embeds, - dyn_dp_pad_input, - dyn_dp_select_index) + dyn_dp_pad_index, + dyn_dp_unpad_index, + dyn_dp_pad_index_with_offset, + dp_unpad_index_total_with_offset) dynamic_hidden_states = Tensor(shape=[None, None], dtype=self.model_config.dtype) -- Gitee From 137b932fede213bcd3249cab4928d6769519642a Mon Sep 17 00:00:00 2001 From: lvhaoyu Date: Fri, 27 Jun 2025 14:38:48 +0800 Subject: [PATCH 36/77] update --- .../model_executor/layers/fused_moe/layer.py | 8 ++++---- vllm_mindspore/model_executor/models/qwen3_moe.py | 12 +++++++----- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/vllm_mindspore/model_executor/layers/fused_moe/layer.py b/vllm_mindspore/model_executor/layers/fused_moe/layer.py index df74ede1..ff6244be 100644 --- a/vllm_mindspore/model_executor/layers/fused_moe/layer.py +++ b/vllm_mindspore/model_executor/layers/fused_moe/layer.py @@ -855,10 
+855,10 @@ class FusedMoE(nn.Cell): # logit_buffer = self.all_reduce_from_dp_group(logit_buffer) # ops.AllGather is not supported for uneven size tensor, so need to pad to same size. - hidden_buffer = self.gather(hidden_states, dp_pad_index) + hidden_buffer = self.gather(hidden_states, dp_pad_index, 0) hidden_buffer = self.all_gather_from_dp_group(hidden_buffer) - logit_buffer = self.gather(router_logits, dp_pad_index) + logit_buffer = self.gather(router_logits, dp_pad_index, 0) logit_buffer = self.all_gather_from_dp_group(logit_buffer) hidden_states = mint.index_select(hidden_buffer, 0, dp_unpad_index_total_with_offset) @@ -888,9 +888,9 @@ class FusedMoE(nn.Cell): final_hidden_states = self.all_reduce_from_ep_group(final_hidden_states) # dp_size > 1 else: - final_hidden_states = self.gather(final_hidden_states, dp_pad_index_total_with_offset) + final_hidden_states = self.gather(final_hidden_states, dp_pad_index_total_with_offset, 0) final_hidden_states = self.reduce_scatter_from_ep_group(final_hidden_states) - final_hidden_states = self.gather(final_hidden_states, dp_unpad_index) + final_hidden_states = self.gather(final_hidden_states, dp_unpad_index, 0) # final_hidden_states = mint.index_select(final_hidden_states, 0, dp_unpad_index) # start = dp_pad_index[-2] # end = start + tokens_num diff --git a/vllm_mindspore/model_executor/models/qwen3_moe.py b/vllm_mindspore/model_executor/models/qwen3_moe.py index bd29b708..d0926e1f 100644 --- a/vllm_mindspore/model_executor/models/qwen3_moe.py +++ b/vllm_mindspore/model_executor/models/qwen3_moe.py @@ -645,16 +645,18 @@ class Qwen3MoeForCausalLM(NativeModel, SupportsPP): total_pad_num = max_token_num - token_num_total this_pad_num = total_pad_num[self.dp_rank] - dp_unpad_index = np.arange(token_num_total[self.dp_rank]) - dp_pad_index = np.pad(dp_unpad_index, (0, this_pad_num)) + dp_unpad_index = ms.Tensor(np.arange(token_num_total[self.dp_rank]), dtype=ms.int32) + dp_pad_index = ms.Tensor(np.pad(dp_unpad_index, (0, this_pad_num)), dtype=ms.int32) dp_pad_index_total_with_offset = [np.pad(np.arange(token_num_total[rank]), (0, total_pad_num[rank])) - for rank in self.dp_world_size] + for rank in range(self.dp_world_size)] dp_pad_index_total_with_offset = np.concatenate(dp_pad_index_total_with_offset, axis=0) + dp_pad_index_total_with_offset = ms.Tensor(dp_pad_index_total_with_offset, dtype=mstype.int32) dp_unpad_index_total_with_offset = [np.arange(token_num_total[rank]) + rank * max_token_num - for rank in self.dp_world_size] + for rank in range(self.dp_world_size)] + dp_unpad_index_total_with_offset = np.concatenate(dp_unpad_index_total_with_offset, axis=0) dp_unpad_index_total_with_offset = ms.Tensor(dp_unpad_index_total_with_offset, dtype=mstype.int32) @@ -673,7 +675,7 @@ class Qwen3MoeForCausalLM(NativeModel, SupportsPP): inputs_embeds=model_inputs["inputs_embeds"], dp_pad_index=dp_pad_index if self.dp_pad_input else None, dp_unpad_index=dp_unpad_index if self.dp_pad_input else None, - dp_pad_index_total_with_offset=dp_pad_index_total_with_offset if self.dp_pad_input else None + dp_pad_index_total_with_offset=dp_pad_index_total_with_offset if self.dp_pad_input else None, dp_unpad_index_total_with_offset=dp_unpad_index_total_with_offset if self.dp_pad_input else None ) -- Gitee From ced4491a6512d8a19ef93aaaa93419c2a7c185b9 Mon Sep 17 00:00:00 2001 From: lvhaoyu Date: Fri, 27 Jun 2025 15:13:09 +0800 Subject: [PATCH 37/77] update --- .../model_executor/layers/fused_moe/layer.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff 
--git a/vllm_mindspore/model_executor/layers/fused_moe/layer.py b/vllm_mindspore/model_executor/layers/fused_moe/layer.py index ff6244be..e8ee29bd 100644 --- a/vllm_mindspore/model_executor/layers/fused_moe/layer.py +++ b/vllm_mindspore/model_executor/layers/fused_moe/layer.py @@ -534,6 +534,8 @@ class FusedMoE(nn.Cell): self.dp_group = get_dp_group().device_group._name self.ep_group = get_ep_group().device_group._name + self.tp_world_size = get_tensor_model_parallel_world_size() + if self.dp_size > 1 and self.ep_size == 1 or self.dp_size == 1: self.pure_tp = True self.all_reduce_from_ep_group = ops.AllReduce(group=self.ep_group) @@ -541,7 +543,9 @@ class FusedMoE(nn.Cell): self.gather = ops.Gather() self.all_gather_from_dp_group = ops.AllGather(group=self.dp_group) self.all_reduce_from_dp_group = ops.AllReduce(group=self.dp_group) - self.reduce_scatter_from_dp_group = ops.ReduceScatter(group=self.dp_group) + # self.reduce_scatter_from_ep_group = ops.ReduceScatter(group=self.ep_group) + self.reduce_from_ep_group = ops.Reduce(0, group=self.ep_group) + self.scatter_to_ep_group = ops.CollectiveScatter(0, group=self.ep_group) @property def tp_size(self): @@ -889,7 +893,10 @@ class FusedMoE(nn.Cell): # dp_size > 1 else: final_hidden_states = self.gather(final_hidden_states, dp_pad_index_total_with_offset, 0) - final_hidden_states = self.reduce_scatter_from_ep_group(final_hidden_states) + final_hidden_states = self.reduce_from_ep_group(final_hidden_states) + final_hidden_states = mint.repeat_interleave(final_hidden_states, self.tp_world_size, dim=0) + final_hidden_states = self.scatter_to_ep_group(final_hidden_states) + # final_hidden_states = self.reduce_scatter_from_ep_group(final_hidden_states) final_hidden_states = self.gather(final_hidden_states, dp_unpad_index, 0) # final_hidden_states = mint.index_select(final_hidden_states, 0, dp_unpad_index) # start = dp_pad_index[-2] -- Gitee From 1e7fc08ece80737191e7643b8da27957460bda5e Mon Sep 17 00:00:00 2001 From: lvhaoyu Date: Fri, 27 Jun 2025 16:13:23 +0800 Subject: [PATCH 38/77] update --- vllm_mindspore/model_executor/layers/fused_moe/layer.py | 8 ++------ vllm_mindspore/model_executor/models/qwen3_moe.py | 9 +++++++-- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/vllm_mindspore/model_executor/layers/fused_moe/layer.py b/vllm_mindspore/model_executor/layers/fused_moe/layer.py index e8ee29bd..e964d987 100644 --- a/vllm_mindspore/model_executor/layers/fused_moe/layer.py +++ b/vllm_mindspore/model_executor/layers/fused_moe/layer.py @@ -543,9 +543,7 @@ class FusedMoE(nn.Cell): self.gather = ops.Gather() self.all_gather_from_dp_group = ops.AllGather(group=self.dp_group) self.all_reduce_from_dp_group = ops.AllReduce(group=self.dp_group) - # self.reduce_scatter_from_ep_group = ops.ReduceScatter(group=self.ep_group) - self.reduce_from_ep_group = ops.Reduce(0, group=self.ep_group) - self.scatter_to_ep_group = ops.CollectiveScatter(0, group=self.ep_group) + self.reduce_scatter_from_ep_group = ops.ReduceScatter(group=self.ep_group) @property def tp_size(self): @@ -893,10 +891,8 @@ class FusedMoE(nn.Cell): # dp_size > 1 else: final_hidden_states = self.gather(final_hidden_states, dp_pad_index_total_with_offset, 0) - final_hidden_states = self.reduce_from_ep_group(final_hidden_states) final_hidden_states = mint.repeat_interleave(final_hidden_states, self.tp_world_size, dim=0) - final_hidden_states = self.scatter_to_ep_group(final_hidden_states) - # final_hidden_states = self.reduce_scatter_from_ep_group(final_hidden_states) + 
final_hidden_state = self.reduce_scatter_from_ep_group(final_hidden_state) final_hidden_states = self.gather(final_hidden_states, dp_unpad_index, 0) # final_hidden_states = mint.index_select(final_hidden_states, 0, dp_unpad_index) # start = dp_pad_index[-2] diff --git a/vllm_mindspore/model_executor/models/qwen3_moe.py b/vllm_mindspore/model_executor/models/qwen3_moe.py index d0926e1f..882f584e 100644 --- a/vllm_mindspore/model_executor/models/qwen3_moe.py +++ b/vllm_mindspore/model_executor/models/qwen3_moe.py @@ -641,6 +641,7 @@ class Qwen3MoeForCausalLM(NativeModel, SupportsPP): # end2 = tokens_cumulative[-1].item() - end # dp_pad_index = ms.Tensor([0, 0, start, end2], dtype=ms.int32) token_num_total = token_num_total.asnumpy() + token_num_total_cumsum = np.cumsum(token_num_total) max_token_num = token_num_total.max() total_pad_num = max_token_num - token_num_total this_pad_num = total_pad_num[self.dp_rank] @@ -648,8 +649,12 @@ class Qwen3MoeForCausalLM(NativeModel, SupportsPP): dp_unpad_index = ms.Tensor(np.arange(token_num_total[self.dp_rank]), dtype=ms.int32) dp_pad_index = ms.Tensor(np.pad(dp_unpad_index, (0, this_pad_num)), dtype=ms.int32) - dp_pad_index_total_with_offset = [np.pad(np.arange(token_num_total[rank]), (0, total_pad_num[rank])) - for rank in range(self.dp_world_size)] + # dp_pad_index_total_with_offset = [np.pad(np.arange(token_num_total[rank]), (0, total_pad_num[rank])) + # for rank in range(self.dp_world_size)] + dp_pad_index_total_with_offset = [np.pad(np.arange(0 if rank == 0 else token_num_total_cumsum[rank - 1], + token_num_total_cumsum[rank]), (0, total_pad_num[rank])) + for rank in range(self.dp_world_size)] + dp_pad_index_total_with_offset = np.concatenate(dp_pad_index_total_with_offset, axis=0) dp_pad_index_total_with_offset = ms.Tensor(dp_pad_index_total_with_offset, dtype=mstype.int32) -- Gitee From 80b9c6a99e88f3554cfbc45c0c3ef83b72d6b766 Mon Sep 17 00:00:00 2001 From: lvhaoyu Date: Fri, 27 Jun 2025 17:22:59 +0800 Subject: [PATCH 39/77] update --- vllm_mindspore/model_executor/layers/fused_moe/layer.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/vllm_mindspore/model_executor/layers/fused_moe/layer.py b/vllm_mindspore/model_executor/layers/fused_moe/layer.py index e964d987..cc5a35ba 100644 --- a/vllm_mindspore/model_executor/layers/fused_moe/layer.py +++ b/vllm_mindspore/model_executor/layers/fused_moe/layer.py @@ -857,6 +857,10 @@ class FusedMoE(nn.Cell): # logit_buffer = self.all_reduce_from_dp_group(logit_buffer) # ops.AllGather is not supported for uneven size tensor, so need to pad to same size. 
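# A standalone numpy sketch of the cumulative-sum indexing introduced above:
# the experts' output holds every DP rank's real tokens back to back, so the
# block for rank r spans [cumsum[r - 1], cumsum[r]) and is padded up to the
# common maximum length (hypothetical counts [1, 3, 4, 2]):
import numpy as np
counts = np.array([1, 3, 4, 2])
cumsum = np.cumsum(counts)                       # [1, 4, 8, 10]
pad = counts.max() - counts                      # [3, 1, 0, 2]
chunks = [np.pad(np.arange(0 if r == 0 else cumsum[r - 1], cumsum[r]),
                 (0, pad[r])) for r in range(len(counts))]
repad_index = np.concatenate(chunks)
# repad_index -> [0 0 0 0 | 1 2 3 0 | 4 5 6 7 | 8 9 0 0]: one equally sized
# block per rank, ready for the reshape / repeat_interleave / reduce_scatter
# recombination used later in this series.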
+ num_token = ms.Tensor([[hidden_states.shape[0]]], dtype=ms.int32) + all_num_token = self.all_gather_from_dp_group(num_token) + all_num_token_cumsum = mint.cumsum(all_num_token, dim=0) + hidden_buffer = self.gather(hidden_states, dp_pad_index, 0) hidden_buffer = self.all_gather_from_dp_group(hidden_buffer) @@ -891,8 +895,10 @@ class FusedMoE(nn.Cell): # dp_size > 1 else: final_hidden_states = self.gather(final_hidden_states, dp_pad_index_total_with_offset, 0) + final_hidden_states = final_hidden_states.reshape(self.dp_size, -1, final_hidden_states.shape[-1]) final_hidden_states = mint.repeat_interleave(final_hidden_states, self.tp_world_size, dim=0) - final_hidden_state = self.reduce_scatter_from_ep_group(final_hidden_state) + final_hidden_states = final_hidden_states.reshape(-1, final_hidden_states.shape[-1]) + final_hidden_states = self.reduce_scatter_from_ep_group(final_hidden_states) final_hidden_states = self.gather(final_hidden_states, dp_unpad_index, 0) # final_hidden_states = mint.index_select(final_hidden_states, 0, dp_unpad_index) # start = dp_pad_index[-2] -- Gitee From f20a9df9a028c4568a60f186e964e9cf1c236194 Mon Sep 17 00:00:00 2001 From: lvhaoyu Date: Fri, 27 Jun 2025 17:36:29 +0800 Subject: [PATCH 40/77] update --- vllm_mindspore/model_executor/layers/fused_moe/layer.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/vllm_mindspore/model_executor/layers/fused_moe/layer.py b/vllm_mindspore/model_executor/layers/fused_moe/layer.py index cc5a35ba..0685b9dd 100644 --- a/vllm_mindspore/model_executor/layers/fused_moe/layer.py +++ b/vllm_mindspore/model_executor/layers/fused_moe/layer.py @@ -857,10 +857,6 @@ class FusedMoE(nn.Cell): # logit_buffer = self.all_reduce_from_dp_group(logit_buffer) # ops.AllGather is not supported for uneven size tensor, so need to pad to same size. - num_token = ms.Tensor([[hidden_states.shape[0]]], dtype=ms.int32) - all_num_token = self.all_gather_from_dp_group(num_token) - all_num_token_cumsum = mint.cumsum(all_num_token, dim=0) - hidden_buffer = self.gather(hidden_states, dp_pad_index, 0) hidden_buffer = self.all_gather_from_dp_group(hidden_buffer) -- Gitee From 43d15258e97201206a188199a932dac595cd3479 Mon Sep 17 00:00:00 2001 From: lvhaoyu Date: Sat, 28 Jun 2025 10:23:06 +0800 Subject: [PATCH 41/77] test --- .../model_executor/layers/fused_moe/layer.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/vllm_mindspore/model_executor/layers/fused_moe/layer.py b/vllm_mindspore/model_executor/layers/fused_moe/layer.py index 0685b9dd..4cb4e8d3 100644 --- a/vllm_mindspore/model_executor/layers/fused_moe/layer.py +++ b/vllm_mindspore/model_executor/layers/fused_moe/layer.py @@ -822,14 +822,15 @@ class FusedMoE(nn.Cell): Therefore it is required that we reduce the shared_experts output early. """ - return self.use_pplx_kernels + return not (self.pure_tp and self.dp_size == 1) def maybe_all_reduce_tensor_model_parallel( self, final_hidden_states: Tensor): """ The pplx combine kernel reduces across GPU ranks by default. 
""" - # return tensor_model_parallel_all_reduce(final_hidden_states) + if self.pure_tp and self.dp_size == 1: + return tensor_model_parallel_all_reduce(final_hidden_states) return final_hidden_states def construct(self, hidden_states: Tensor, @@ -885,11 +886,12 @@ class FusedMoE(nn.Cell): apply_router_weight_on_input=self.apply_router_weight_on_input, ) - if self.pure_tp: - if self.dp_size == 1: - final_hidden_states = self.all_reduce_from_ep_group(final_hidden_states) - # dp_size > 1 - else: + # if self.pure_tp: + # if self.dp_size == 1: + # final_hidden_states = self.all_reduce_from_ep_group(final_hidden_states) + # # dp_size > 1 + # else: + if self.pure_tp and self.dp_size > 1: final_hidden_states = self.gather(final_hidden_states, dp_pad_index_total_with_offset, 0) final_hidden_states = final_hidden_states.reshape(self.dp_size, -1, final_hidden_states.shape[-1]) final_hidden_states = mint.repeat_interleave(final_hidden_states, self.tp_world_size, dim=0) -- Gitee From efc30693af49dcf5ba14f7a877bd42f68c205cd0 Mon Sep 17 00:00:00 2001 From: lvhaoyu Date: Sat, 28 Jun 2025 11:04:24 +0800 Subject: [PATCH 42/77] update test --- .../model_executor/layers/fused_moe/layer.py | 31 +++++++++---------- 1 file changed, 14 insertions(+), 17 deletions(-) diff --git a/vllm_mindspore/model_executor/layers/fused_moe/layer.py b/vllm_mindspore/model_executor/layers/fused_moe/layer.py index 4cb4e8d3..304b4557 100644 --- a/vllm_mindspore/model_executor/layers/fused_moe/layer.py +++ b/vllm_mindspore/model_executor/layers/fused_moe/layer.py @@ -38,6 +38,7 @@ from vllm_mindspore.model_executor.layers.fused_moe.fused_moe import (fused_topk grouped_topk, fused_experts) from vllm_mindspore.model_executor.layers.quantization.base_config import QuantizeMethodBase +from vllm_mindspore.distributed.communication_op import ReduceFromModelParallelRegion from mindspore import nn, Tensor, Parameter, mint, ops import mindspore as ms @@ -536,9 +537,13 @@ class FusedMoE(nn.Cell): self.tp_world_size = get_tensor_model_parallel_world_size() - if self.dp_size > 1 and self.ep_size == 1 or self.dp_size == 1: + self.reduce_from_tp_group = ReduceFromModelParallelRegion() + + # pure_tp means using tensor parallelism only, no expert parallelism. + self.pure_tp = False + + if self.tp_size >= 1 and self.ep_size == 1: self.pure_tp = True - self.all_reduce_from_ep_group = ops.AllReduce(group=self.ep_group) if self.dp_size > 1: self.gather = ops.Gather() self.all_gather_from_dp_group = ops.AllGather(group=self.dp_group) @@ -810,27 +815,19 @@ class FusedMoE(nn.Cell): return topk_weights, topk_ids def must_reduce_shared_expert_outputs(self) -> bool: - """ - The shared_experts are typically computed using the RowParallelLinear - layer. The result of this function is typically used as - the reduce_results argument to the module. - When just tensor-parallel is used, it is not required to reduce - the shared_experts results immediately. Instead we reduce at the - once at the end of the MoE op. (Refer to DeepSeekV2MoE module) - With EP and the pplx kernels - this is no longer viable as all - GPU ranks in DP, produce the complete set of hidden_states. - Therefore it is required that we reduce the shared_experts output - early. - """ + # If dp_size == 1, means routed expert use the same tensor parallel group as shared expert. + # And meanwhile if ep_size == 1, it means using tensor parallel to compute routed expert. + # So we can delay the shared expert outputs reduce after the routed expert and + # the shared expert are added. 
return not (self.pure_tp and self.dp_size == 1) def maybe_all_reduce_tensor_model_parallel( self, final_hidden_states: Tensor): """ - The pplx combine kernel reduces across GPU ranks by default. + To all_reduce after routed expert and shared expert are added. """ if self.pure_tp and self.dp_size == 1: - return tensor_model_parallel_all_reduce(final_hidden_states) + return self.reduce_from_tp_group(final_hidden_states) return final_hidden_states def construct(self, hidden_states: Tensor, @@ -850,7 +847,7 @@ class FusedMoE(nn.Cell): # dp_pad_index = [0, 1, 2, 3, 0, 0, 0] # dp_pad_index_with_offset = [5, 6, 7, 8, 0, 0, 0] - if self.dp_size > 1 and self.pure_tp: + if self.pure_tp and self.dp_size > 1: # hidden_buffer = mint.nn.functional.pad(hidden_states, dp_pad_index) # hidden_buffer = self.all_reduce_from_dp_group(hidden_buffer) -- Gitee From f0bbd02c16765ebcfd6d1621936ffa7ce7d43b05 Mon Sep 17 00:00:00 2001 From: lvhaoyu Date: Sat, 28 Jun 2025 11:29:01 +0800 Subject: [PATCH 43/77] update --- .../model_executor/layers/fused_moe/layer.py | 32 +++++++------------ 1 file changed, 11 insertions(+), 21 deletions(-) diff --git a/vllm_mindspore/model_executor/layers/fused_moe/layer.py b/vllm_mindspore/model_executor/layers/fused_moe/layer.py index 304b4557..2692ce1f 100644 --- a/vllm_mindspore/model_executor/layers/fused_moe/layer.py +++ b/vllm_mindspore/model_executor/layers/fused_moe/layer.py @@ -545,7 +545,6 @@ class FusedMoE(nn.Cell): if self.tp_size >= 1 and self.ep_size == 1: self.pure_tp = True if self.dp_size > 1: - self.gather = ops.Gather() self.all_gather_from_dp_group = ops.AllGather(group=self.dp_group) self.all_reduce_from_dp_group = ops.AllReduce(group=self.dp_group) self.reduce_scatter_from_ep_group = ops.ReduceScatter(group=self.ep_group) @@ -603,7 +602,6 @@ class FusedMoE(nn.Cell): # w3, up_proj: Load into second logical weight of w13. else: assert shard_id == "w3" - # expert_data = expert_data.narrow(shard_dim, shard_size, shard_size) if shard_dim == 1: param[expert_id, :, shard_size:shard_size*2] = loaded_weight else: @@ -619,7 +617,6 @@ class FusedMoE(nn.Cell): # w3, up_proj: Load into second logical weight of w13. else: assert shard_id == "w3" - # expert_data = expert_data.narrow(shard_dim, shard_size, shard_size) if shard_dim == 2: param[:, :, shard_size:shard_size*2] = loaded_weight else: @@ -844,21 +841,19 @@ class FusedMoE(nn.Cell): router_logits: Tensor, dp_pad_index, dp_unpad_index, dp_pad_index_total_with_offset, dp_unpad_index_total_with_offset): - # dp_pad_index = [0, 1, 2, 3, 0, 0, 0] - # dp_pad_index_with_offset = [5, 6, 7, 8, 0, 0, 0] - + """ + If dp_world_size == 4, dp_rank == 1, tokens_num across dp is [1, 3, 4, 2], then + dp_pad_index = [0, 1, 2, 0] + dp_unpad_index = [0, 1, 2] + dp_pad_index_total_with_offset = [0, 0, 0, 0, 1, 2, 3, 0, 4, 5, 6, 0, 7, 8, 0, 0] + dp_unpad_index_total_with_offset = [0, 4, 5, 6, 8, 9, 10, 11, 12, 13] + """ if self.pure_tp and self.dp_size > 1: - # hidden_buffer = mint.nn.functional.pad(hidden_states, dp_pad_index) - # hidden_buffer = self.all_reduce_from_dp_group(hidden_buffer) - - # logit_buffer = mint.nn.functional.pad(router_logits, dp_pad_index) - # logit_buffer = self.all_reduce_from_dp_group(logit_buffer) - # ops.AllGather is not supported for uneven size tensor, so need to pad to same size. 
- hidden_buffer = self.gather(hidden_states, dp_pad_index, 0) + hidden_buffer = mint.index_select(hidden_states, 0, dp_pad_index) hidden_buffer = self.all_gather_from_dp_group(hidden_buffer) - logit_buffer = self.gather(router_logits, dp_pad_index, 0) + logit_buffer = mint.index_select(router_logits, 0, dp_pad_index) logit_buffer = self.all_gather_from_dp_group(logit_buffer) hidden_states = mint.index_select(hidden_buffer, 0, dp_unpad_index_total_with_offset) @@ -883,18 +878,13 @@ class FusedMoE(nn.Cell): apply_router_weight_on_input=self.apply_router_weight_on_input, ) - # if self.pure_tp: - # if self.dp_size == 1: - # final_hidden_states = self.all_reduce_from_ep_group(final_hidden_states) - # # dp_size > 1 - # else: if self.pure_tp and self.dp_size > 1: - final_hidden_states = self.gather(final_hidden_states, dp_pad_index_total_with_offset, 0) + final_hidden_states = mint.index_select(final_hidden_states, 0, dp_pad_index_total_with_offset) final_hidden_states = final_hidden_states.reshape(self.dp_size, -1, final_hidden_states.shape[-1]) final_hidden_states = mint.repeat_interleave(final_hidden_states, self.tp_world_size, dim=0) final_hidden_states = final_hidden_states.reshape(-1, final_hidden_states.shape[-1]) final_hidden_states = self.reduce_scatter_from_ep_group(final_hidden_states) - final_hidden_states = self.gather(final_hidden_states, dp_unpad_index, 0) + final_hidden_states = mint.index_select(final_hidden_states, 0, dp_unpad_index) # final_hidden_states = mint.index_select(final_hidden_states, 0, dp_unpad_index) # start = dp_pad_index[-2] # end = start + tokens_num -- Gitee From 6cfd5b07627809fc0be29c629ed0ef8e13572892 Mon Sep 17 00:00:00 2001 From: lvhaoyu Date: Sat, 28 Jun 2025 16:48:33 +0800 Subject: [PATCH 44/77] update --- .../layers/fused_moe/fused_moe2.py | 131 ++++++++++++++---- .../model_executor/layers/fused_moe/layer.py | 1 + 2 files changed, 107 insertions(+), 25 deletions(-) diff --git a/vllm_mindspore/model_executor/layers/fused_moe/fused_moe2.py b/vllm_mindspore/model_executor/layers/fused_moe/fused_moe2.py index f3c44146..e40b022c 100644 --- a/vllm_mindspore/model_executor/layers/fused_moe/fused_moe2.py +++ b/vllm_mindspore/model_executor/layers/fused_moe/fused_moe2.py @@ -6,7 +6,7 @@ from mindspore.ops.auto_generate import (GroupedMatmulV4, MoeInitRoutingV2, MoeTokenUnpermute) import mindspore as ms -from vllm.distributed.parallel_state import get_ep_group, get_dp_group +from vllm.distributed.parallel_state import get_ep_group, get_tp_group, get_tensor_model_parallel_rank def fused_topk( hidden_states: Tensor, @@ -60,12 +60,51 @@ def grouped_topk( class FusedExperts(nn.Cell): - def __init__(self): + def __init__(self, moe_config): super().__init__() self.group_matmul_ops = GroupedMatmulV4() self.moe_init_routing_op = MoeInitRoutingV2() self.moe_token_unpermute = MoeTokenUnpermute() + self.pure_tp = False + self.pure_ep = True + + if moe_config.moe_parallel_config.ep_size > 1 and \ + moe_config.moe_parallel_config.tp_size == 1: + # pure ep + self.pure_ep = True + self.tp_rank = get_tensor_model_parallel_rank() + ep_size = moe_config.moe_parallel_config.ep_size + self.ep_size = ep_size + self.ep_group = get_ep_group().device_group._name + experts_num = moe_config.num_experts + expert_num_map = [(experts_num // ep_size) * i - 1 + for i in range(1, ep_size)] + expert_num_map.append(experts_num - ((experts_num // ep_size) * (ep_size - 1))) + if self.tp_rank == 0: + self.send_experts_num_map = ms.Tensor(expert_num_map, dtype=ms.int64) + else: + 
self.send_experts_num_map = ms.zeros(ep_size, dtype=ms.int64) + self.recv_experts_num_map = ms.Tensor([moe_config.num_local_experts for _ in range(ep_size)], + dtype=ms.int64) + + self.all_to_all_across_ep = ops.AlltoAll(split_count=ep_size, + split_dim=0, + concat_dim=0, + group=self.ep_group) + + self.expert_hidden_size = moe_config.hidden_dim + self.all_to_all_v_across_ep_with_block_size = ops.AlltoAllV(block_size=self.hidden_size, + group=self.ep_group) + self.all_to_all_v_across_ep = ops.AlltoAllV(group=self.ep_group) + + self.tp_group = get_tp_group().device_group._name + self.broadcast_to_tensor_parallel_region = ops.Broadcast(0, group=self.tp_group) + + if moe_config.moe_parallel_config.ep_size == 1 and \ + moe_config.moe_parallel_config.tp_size >= 1: + self.pure_tp = True + def construct(self, hidden_states: Tensor, w1: Tensor, @@ -79,24 +118,21 @@ class FusedExperts(nn.Cell): tp_size: int = 1, ep_size: int = 0) -> Tensor: - if tp_size >= 1: - # no ep, pure tp - if ep_size == 1: - hidden_states = self._run_tp_moe(hidden_states, w1, w2, topk_ids, topk_weights, - activation, global_num_experts, - apply_router_weight_on_input) + if self.pure_tp: + hidden_states = self._run_tp_moe(hidden_states, w1, w2, topk_ids, topk_weights, + activation, global_num_experts, + apply_router_weight_on_input) # ep_size > 1 : pure ep or tp + ep - else: + elif self.pure_ep: # pure ep - if tp_size == 1: - hidden_states = self._run_ep_moe(hidden_states, w1, w2, topk_ids, topk_weights, - activation, global_num_experts, - apply_router_weight_on_input) - # tp_size > 1 : tp + ep - else: - hidden_states = self._run_tp_ep_moe(hidden_states, w1, w2, topk_ids, topk_weights, - activation, global_num_experts, - apply_router_weight_on_input) + hidden_states = self._run_ep_moe(hidden_states, w1, w2, topk_ids, topk_weights, + activation, global_num_experts, + apply_router_weight_on_input) + # tp_size > 1 : tp + ep + else: + hidden_states = self._run_tp_ep_moe(hidden_states, w1, w2, topk_ids, topk_weights, + activation, global_num_experts, + apply_router_weight_on_input) return hidden_states @@ -109,9 +145,6 @@ class FusedExperts(nn.Cell): else: raise ValueError(f"Unsupported activation function: {activation}") - - - def _group_matmul(self, hidden_states, weight, group_list): return self.group_matmul_ops([hidden_states], [weight], None, None, None, None, None, None, @@ -126,10 +159,58 @@ class FusedExperts(nn.Cell): activation, global_num_experts, apply_router_weight_on_input): - hidden_states = self._group_matmul(hidden_states, w1, topk_ids) - hidden_states = self._gate_activation(hidden_states, activation) - hidden_states = self._group_matmul(hidden_states, w2, topk_ids) - return hidden_states + topk_weights = topk_weights.astype(hidden_states.dtype) + topk_ids = topk_ids.astype(ms.int32) + + sorted_input_tensor, unsort_map, group_list, _ = \ + self.moe_init_routing_op( + hidden_states, + topk_ids, + active_num=0, + expert_capacity=0, + expert_num=global_num_experts, + drop_pad_mode=0, + expert_tokens_count_or_cumsum_flag=2, + expert_tokens_before_capacity_flag=True) + + group_list = group_list.astype(ms.int64) + + # group_list = group_list.reshape(1, -1) + + if self.tp_rank == 0: + group_list_cumsum = mint.cumsum(group_list) + # expert index = [3, 7, 11, 15] (self.ep_group_size,) + send_list = group_list_cumsum[self.expert_num_map] # [20, 30, 40, 50] + else: + send_list = mint.zeros(self.ep_size, dtype=ms.int64) # [0, 0, 0, 0] + + recv_list = self.all_to_all_across_ep(send_list) + # recv_list [20, 40, 60, 70] + 
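# A standalone numpy sketch of how per-peer token counts can be derived from
# the cumulative group_list, assuming 8 global experts split evenly over
# ep_size = 4 ranks (2 local experts each) and hypothetical routing counts:
import numpy as np
group_list = np.array([1, 2, 3, 4, 5, 6, 7, 8])   # tokens routed to each expert
cumsum = np.cumsum(group_list)                    # [1, 3, 6, 10, 15, 21, 28, 36]
last_expert_per_rank = np.array([1, 3, 5, 7])     # last expert owned by each rank
boundary_totals = cumsum[last_expert_per_rank]    # [3, 10, 21, 36]
per_peer_counts = np.diff(boundary_totals, prepend=0)   # [3, 7, 11, 15]
# Differencing the boundary totals gives the number of sorted tokens destined
# for each EP peer, which is the shape information an AlltoAllV dispatch needs.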
local_input_tensor = self.all_to_all_v_across_ep_with_block_size(sorted_input_tensor, + send_list, recv_list) + + local_group_list = self.all_to_all_v_across_ep(group_list, self.send_experts_num_map, + self.recv_experts_num_map) + local_group_list = local_group_list.reshape(-1, self.local_expert_num) + local_group_list = local_group_list.sum(dim=0) + + recv_tokens = recv_list.sum() + if recv_tokens > 0: + gate_hidden_out = self._group_matmul(local_input_tensor, mint.transpose(w1, -1, -2), local_group_list) + gate, hidden = mint.split(gate_hidden_out, + (w1.shape[1] // 2, w1.shape[1] // 2), -1) + gate = self._gate_activation(gate, activation) + hidden = mint.mul(hidden, gate) + expert_output = self._group_matmul(hidden, mint.transpose(w2, -1, -2), group_list) + expert_output = mint.nan_to_num(expert_output, 0, 0, 0) + expert_output = self.all_to_all_v_across_ep_with_block_size(expert_output, recv_list, send_list) + moe_output = self.moe_token_unpermute(permuted_tokens=expert_output, + sorted_indices=unsort_map, + probs=topk_weights, + padded_mode=False, + restore_shape=None) + moe_output = self.broadcast_to_tensor_parallel_region(moe_output) + return moe_output def _run_tp_moe(self, diff --git a/vllm_mindspore/model_executor/layers/fused_moe/layer.py b/vllm_mindspore/model_executor/layers/fused_moe/layer.py index 2692ce1f..9aa9bb4c 100644 --- a/vllm_mindspore/model_executor/layers/fused_moe/layer.py +++ b/vllm_mindspore/model_executor/layers/fused_moe/layer.py @@ -280,6 +280,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, nn.Cell): def __init__(self, moe: MoEConfig): super().__init__() self.fused_experts = fused_experts # type: ignore + # self.fused_experts = fused_experts(moe) self.moe = moe def create_weights(self, layer: nn.Cell, num_experts: int, -- Gitee From f28294f207913aae341ec28cab4095bb5d2c2fdc Mon Sep 17 00:00:00 2001 From: lvhaoyu Date: Sat, 28 Jun 2025 17:43:27 +0800 Subject: [PATCH 45/77] update --- .../layers/fused_moe/fused_moe2.py | 30 ++++++++++++------- .../model_executor/layers/fused_moe/layer.py | 5 ++-- 2 files changed, 22 insertions(+), 13 deletions(-) diff --git a/vllm_mindspore/model_executor/layers/fused_moe/fused_moe2.py b/vllm_mindspore/model_executor/layers/fused_moe/fused_moe2.py index e40b022c..b11253ee 100644 --- a/vllm_mindspore/model_executor/layers/fused_moe/fused_moe2.py +++ b/vllm_mindspore/model_executor/layers/fused_moe/fused_moe2.py @@ -81,10 +81,11 @@ class FusedExperts(nn.Cell): expert_num_map = [(experts_num // ep_size) * i - 1 for i in range(1, ep_size)] expert_num_map.append(experts_num - ((experts_num // ep_size) * (ep_size - 1))) + self.experts_num_map = ms.Tensor(expert_num_map, dtype=ms.int64) if self.tp_rank == 0: self.send_experts_num_map = ms.Tensor(expert_num_map, dtype=ms.int64) else: - self.send_experts_num_map = ms.zeros(ep_size, dtype=ms.int64) + self.send_experts_num_map = mint.zeros(ep_size, dtype=ms.int64) self.recv_experts_num_map = ms.Tensor([moe_config.num_local_experts for _ in range(ep_size)], dtype=ms.int64) @@ -94,13 +95,16 @@ class FusedExperts(nn.Cell): group=self.ep_group) self.expert_hidden_size = moe_config.hidden_dim - self.all_to_all_v_across_ep_with_block_size = ops.AlltoAllV(block_size=self.hidden_size, + self.all_to_all_v_across_ep_with_block_size = ops.AlltoAllV(block_size=self.expert_hidden_size, group=self.ep_group) self.all_to_all_v_across_ep = ops.AlltoAllV(group=self.ep_group) self.tp_group = get_tp_group().device_group._name self.broadcast_to_tensor_parallel_region = ops.Broadcast(0, 
group=self.tp_group) + self.dummy_token = mint.zeros((1, self.hidden_size), dtype=moe_config.in_dtype) + + if moe_config.moe_parallel_config.ep_size == 1 and \ moe_config.moe_parallel_config.tp_size >= 1: self.pure_tp = True @@ -178,15 +182,15 @@ class FusedExperts(nn.Cell): # group_list = group_list.reshape(1, -1) if self.tp_rank == 0: - group_list_cumsum = mint.cumsum(group_list) + group_list_cumsum = mint.cumsum(group_list, 0) # expert index = [3, 7, 11, 15] (self.ep_group_size,) - send_list = group_list_cumsum[self.expert_num_map] # [20, 30, 40, 50] + send_list = group_list_cumsum[self.experts_num_map] # [20, 30, 40, 50] else: send_list = mint.zeros(self.ep_size, dtype=ms.int64) # [0, 0, 0, 0] recv_list = self.all_to_all_across_ep(send_list) # recv_list [20, 40, 60, 70] - local_input_tensor = self.all_to_all_v_across_ep_with_block_size(sorted_input_tensor, + local_input_tensor = self.all_to_all_v_across_ep_with_block_size(sorted_input_tensor.reshape(-1), send_list, recv_list) local_group_list = self.all_to_all_v_across_ep(group_list, self.send_experts_num_map, @@ -196,6 +200,7 @@ class FusedExperts(nn.Cell): recv_tokens = recv_list.sum() if recv_tokens > 0: + local_input_tensor = local_input_tensor.reshape(-1, self.hidden_size) gate_hidden_out = self._group_matmul(local_input_tensor, mint.transpose(w1, -1, -2), local_group_list) gate, hidden = mint.split(gate_hidden_out, (w1.shape[1] // 2, w1.shape[1] // 2), -1) @@ -203,12 +208,15 @@ class FusedExperts(nn.Cell): hidden = mint.mul(hidden, gate) expert_output = self._group_matmul(hidden, mint.transpose(w2, -1, -2), group_list) expert_output = mint.nan_to_num(expert_output, 0, 0, 0) - expert_output = self.all_to_all_v_across_ep_with_block_size(expert_output, recv_list, send_list) - moe_output = self.moe_token_unpermute(permuted_tokens=expert_output, - sorted_indices=unsort_map, - probs=topk_weights, - padded_mode=False, - restore_shape=None) + else: + expert_output = self.dummy_token + expert_output = self.all_to_all_v_across_ep_with_block_size(expert_output.reshape(-1), + recv_list, send_list) + moe_output = self.moe_token_unpermute(permuted_tokens=expert_output, + sorted_indices=unsort_map, + probs=topk_weights, + padded_mode=False, + restore_shape=None) moe_output = self.broadcast_to_tensor_parallel_region(moe_output) return moe_output diff --git a/vllm_mindspore/model_executor/layers/fused_moe/layer.py b/vllm_mindspore/model_executor/layers/fused_moe/layer.py index 9aa9bb4c..d3514411 100644 --- a/vllm_mindspore/model_executor/layers/fused_moe/layer.py +++ b/vllm_mindspore/model_executor/layers/fused_moe/layer.py @@ -37,6 +37,7 @@ from vllm.model_executor.layers.fused_moe.layer import (determine_expert_map, from vllm_mindspore.model_executor.layers.fused_moe.fused_moe import (fused_topk, grouped_topk, fused_experts) +from vllm_mindspore.model_executor.layers.fused_moe.fused_moe2 import FusedExperts from vllm_mindspore.model_executor.layers.quantization.base_config import QuantizeMethodBase from vllm_mindspore.distributed.communication_op import ReduceFromModelParallelRegion @@ -279,8 +280,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, nn.Cell): def __init__(self, moe: MoEConfig): super().__init__() - self.fused_experts = fused_experts # type: ignore - # self.fused_experts = fused_experts(moe) + # self.fused_experts = fused_experts # type: ignore + self.fused_experts = FusedExperts(moe) self.moe = moe def create_weights(self, layer: nn.Cell, num_experts: int, -- Gitee From b58c0f8020e6020433d647f014e4b6f3bac24c8d Mon Sep 17 
00:00:00 2001 From: lvhaoyu Date: Mon, 30 Jun 2025 20:53:23 +0800 Subject: [PATCH 46/77] update --- .../layers/fused_moe/fused_moe2.py | 90 ++++++++++++------- 1 file changed, 58 insertions(+), 32 deletions(-) diff --git a/vllm_mindspore/model_executor/layers/fused_moe/fused_moe2.py b/vllm_mindspore/model_executor/layers/fused_moe/fused_moe2.py index b11253ee..d22a72a9 100644 --- a/vllm_mindspore/model_executor/layers/fused_moe/fused_moe2.py +++ b/vllm_mindspore/model_executor/layers/fused_moe/fused_moe2.py @@ -1,12 +1,14 @@ from typing import Optional +import numpy as np from mindspore import Tensor, mint, ops, nn from mindspore.ops.auto_generate import (GroupedMatmulV4, FusedAddTopKDiv, MoeInitRoutingV2, MoeTokenUnpermute) import mindspore as ms -from vllm.distributed.parallel_state import get_ep_group, get_tp_group, get_tensor_model_parallel_rank +from vllm.distributed.parallel_state import (get_ep_group, get_tp_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) def fused_topk( hidden_states: Tensor, @@ -78,33 +80,47 @@ class FusedExperts(nn.Cell): self.ep_size = ep_size self.ep_group = get_ep_group().device_group._name experts_num = moe_config.num_experts - expert_num_map = [(experts_num // ep_size) * i - 1 - for i in range(1, ep_size)] - expert_num_map.append(experts_num - ((experts_num // ep_size) * (ep_size - 1))) - self.experts_num_map = ms.Tensor(expert_num_map, dtype=ms.int64) + experts_num_map = [(experts_num // ep_size) + for _ in range(ep_size - 1)] + experts_num_map.append(experts_num - ((experts_num // ep_size) * (ep_size - 1))) + # self.experts_num_map = ms.Tensor(expert_num_map, dtype=ms.int64) + experts_num_map_np = np.array(experts_num_map, dtype=np.int64) + experts_num_map_cu_np = np.cumsum(experts_num_map_np, dtype=np.int64) + self.experts_num_map_cu_index = ms.Tensor(experts_num_map_cu_np - 1, dtype=ms.int64) + if self.tp_rank == 0: - self.send_experts_num_map = ms.Tensor(expert_num_map, dtype=ms.int64) + self.send_experts_num_map = ms.Tensor(experts_num_map, dtype=ms.int64) else: self.send_experts_num_map = mint.zeros(ep_size, dtype=ms.int64) - self.recv_experts_num_map = ms.Tensor([moe_config.num_local_experts for _ in range(ep_size)], - dtype=ms.int64) - - self.all_to_all_across_ep = ops.AlltoAll(split_count=ep_size, - split_dim=0, - concat_dim=0, - group=self.ep_group) - self.expert_hidden_size = moe_config.hidden_dim - self.all_to_all_v_across_ep_with_block_size = ops.AlltoAllV(block_size=self.expert_hidden_size, + tp_world_size = get_tensor_model_parallel_world_size() + recv_num_map_list = [] + for i in range(self.ep_size): + if i % tp_world_size == 0: + recv_num_map_list.append(moe_config.num_local_experts) + else: + recv_num_map_list.append(0) + self.recv_experts_num_map = ms.Tensor(recv_num_map_list, dtype=ms.int64) + self.local_expert_num = moe_config.num_local_experts + + self.prepend_tensor = ms.Tensor([0], dtype=ms.int32) + + # self.all_to_all_across_ep = ops.AlltoAll(split_count=ep_size, + # split_dim=0, + # concat_dim=0, + # group=self.ep_group) + + self.hidden_size = moe_config.hidden_dim + self.all_to_all_v_across_ep_with_block_size = ops.AlltoAllV(block_size=self.hidden_size, group=self.ep_group) self.all_to_all_v_across_ep = ops.AlltoAllV(group=self.ep_group) + self.even_list = [1 for _ in range(ep_size)] self.tp_group = get_tp_group().device_group._name self.broadcast_to_tensor_parallel_region = ops.Broadcast(0, group=self.tp_group) self.dummy_token = mint.zeros((1, self.hidden_size), dtype=moe_config.in_dtype) - if 
moe_config.moe_parallel_config.ep_size == 1 and \ moe_config.moe_parallel_config.tp_size >= 1: self.pure_tp = True @@ -177,23 +193,28 @@ class FusedExperts(nn.Cell): expert_tokens_count_or_cumsum_flag=2, expert_tokens_before_capacity_flag=True) - group_list = group_list.astype(ms.int64) - # group_list = group_list.reshape(1, -1) if self.tp_rank == 0: - group_list_cumsum = mint.cumsum(group_list, 0) + group_list_cumsum = mint.cumsum(group_list, 0, dtype=ms.int32) # expert index = [3, 7, 11, 15] (self.ep_group_size,) - send_list = group_list_cumsum[self.experts_num_map] # [20, 30, 40, 50] + # 看下每个rank, 发送多少tensor 数据给其他的rank + send_list = group_list_cumsum[self.experts_num_map_cu_index] # [20, 30, 40, 50] + send_list = mint.diff(send_list, prepend=self.prepend_tensor) else: - send_list = mint.zeros(self.ep_size, dtype=ms.int64) # [0, 0, 0, 0] + send_list = mint.zeros(self.ep_size, dtype=ms.int32) # [0, 0, 0, 0] + + group_list = group_list.astype(ms.int64) - recv_list = self.all_to_all_across_ep(send_list) + # recv_list = self.all_to_all_across_ep(send_list) + recv_list = self.all_to_all_v_across_ep(send_list, self.even_list, self.even_list) # recv_list [20, 40, 60, 70] local_input_tensor = self.all_to_all_v_across_ep_with_block_size(sorted_input_tensor.reshape(-1), - send_list, recv_list) + send_list.tolist(), + recv_list.tolist()) - local_group_list = self.all_to_all_v_across_ep(group_list, self.send_experts_num_map, + local_group_list = self.all_to_all_v_across_ep(group_list, + self.send_experts_num_map, self.recv_experts_num_map) local_group_list = local_group_list.reshape(-1, self.local_expert_num) local_group_list = local_group_list.sum(dim=0) @@ -206,18 +227,23 @@ class FusedExperts(nn.Cell): (w1.shape[1] // 2, w1.shape[1] // 2), -1) gate = self._gate_activation(gate, activation) hidden = mint.mul(hidden, gate) - expert_output = self._group_matmul(hidden, mint.transpose(w2, -1, -2), group_list) + expert_output = self._group_matmul(hidden, mint.transpose(w2, -1, -2), local_group_list) expert_output = mint.nan_to_num(expert_output, 0, 0, 0) else: expert_output = self.dummy_token expert_output = self.all_to_all_v_across_ep_with_block_size(expert_output.reshape(-1), - recv_list, send_list) - moe_output = self.moe_token_unpermute(permuted_tokens=expert_output, - sorted_indices=unsort_map, - probs=topk_weights, - padded_mode=False, - restore_shape=None) - moe_output = self.broadcast_to_tensor_parallel_region(moe_output) + recv_list.tolist(), + send_list.tolist()) + if self.tp_rank == 0: + expert_output = expert_output.reshape(-1, self.hidden_size) + moe_output = self.moe_token_unpermute(permuted_tokens=expert_output, + sorted_indices=unsort_map, + probs=topk_weights, + padded_mode=False, + restore_shape=None) + moe_output = self.broadcast_to_tensor_parallel_region((moe_output,))[0] + else: + moe_output = self.broadcast_to_tensor_parallel_region((hidden_states,))[0] return moe_output -- Gitee From 4abea095abb0a5cc03bc90f9c5de80305ff6c28b Mon Sep 17 00:00:00 2001 From: lvhaoyu Date: Tue, 1 Jul 2025 11:48:24 +0800 Subject: [PATCH 47/77] support EP + TP --- .../layers/fused_moe/fused_moe2.py | 33 ++++++++----------- .../model_executor/layers/fused_moe/layer.py | 3 +- 2 files changed, 15 insertions(+), 21 deletions(-) diff --git a/vllm_mindspore/model_executor/layers/fused_moe/fused_moe2.py b/vllm_mindspore/model_executor/layers/fused_moe/fused_moe2.py index d22a72a9..8a665418 100644 --- a/vllm_mindspore/model_executor/layers/fused_moe/fused_moe2.py +++ 
b/vllm_mindspore/model_executor/layers/fused_moe/fused_moe2.py @@ -76,6 +76,7 @@ class FusedExperts(nn.Cell): # pure ep self.pure_ep = True self.tp_rank = get_tensor_model_parallel_rank() + self.tp_world_size = get_tensor_model_parallel_world_size() ep_size = moe_config.moe_parallel_config.ep_size self.ep_size = ep_size self.ep_group = get_ep_group().device_group._name @@ -103,12 +104,7 @@ class FusedExperts(nn.Cell): self.recv_experts_num_map = ms.Tensor(recv_num_map_list, dtype=ms.int64) self.local_expert_num = moe_config.num_local_experts - self.prepend_tensor = ms.Tensor([0], dtype=ms.int32) - - # self.all_to_all_across_ep = ops.AlltoAll(split_count=ep_size, - # split_dim=0, - # concat_dim=0, - # group=self.ep_group) + self.prepend_tensor = ms.Tensor([0], dtype=ms.int64) self.hidden_size = moe_config.hidden_dim self.all_to_all_v_across_ep_with_block_size = ops.AlltoAllV(block_size=self.hidden_size, @@ -134,9 +130,7 @@ class FusedExperts(nn.Cell): activation: str = "silu", global_num_experts: int = -1, apply_router_weight_on_input: bool = False, - expert_map: Optional[Tensor] = None, - tp_size: int = 1, - ep_size: int = 0) -> Tensor: + expert_map: Optional[Tensor] = None) -> Tensor: if self.pure_tp: hidden_states = self._run_tp_moe(hidden_states, w1, w2, topk_ids, topk_weights, @@ -196,13 +190,13 @@ class FusedExperts(nn.Cell): # group_list = group_list.reshape(1, -1) if self.tp_rank == 0: - group_list_cumsum = mint.cumsum(group_list, 0, dtype=ms.int32) + group_list_cumsum = mint.cumsum(group_list, 0, dtype=ms.int64) # expert index = [3, 7, 11, 15] (self.ep_group_size,) # 看下每个rank, 发送多少tensor 数据给其他的rank send_list = group_list_cumsum[self.experts_num_map_cu_index] # [20, 30, 40, 50] send_list = mint.diff(send_list, prepend=self.prepend_tensor) else: - send_list = mint.zeros(self.ep_size, dtype=ms.int32) # [0, 0, 0, 0] + send_list = mint.zeros(self.ep_size, dtype=ms.int64) # [0, 0, 0, 0] group_list = group_list.astype(ms.int64) @@ -210,8 +204,8 @@ class FusedExperts(nn.Cell): recv_list = self.all_to_all_v_across_ep(send_list, self.even_list, self.even_list) # recv_list [20, 40, 60, 70] local_input_tensor = self.all_to_all_v_across_ep_with_block_size(sorted_input_tensor.reshape(-1), - send_list.tolist(), - recv_list.tolist()) + send_list, + recv_list) local_group_list = self.all_to_all_v_across_ep(group_list, self.send_experts_num_map, @@ -232,8 +226,8 @@ class FusedExperts(nn.Cell): else: expert_output = self.dummy_token expert_output = self.all_to_all_v_across_ep_with_block_size(expert_output.reshape(-1), - recv_list.tolist(), - send_list.tolist()) + recv_list, + send_list) if self.tp_rank == 0: expert_output = expert_output.reshape(-1, self.hidden_size) moe_output = self.moe_token_unpermute(permuted_tokens=expert_output, @@ -241,12 +235,13 @@ class FusedExperts(nn.Cell): probs=topk_weights, padded_mode=False, restore_shape=None) - moe_output = self.broadcast_to_tensor_parallel_region((moe_output,))[0] - else: - moe_output = self.broadcast_to_tensor_parallel_region((hidden_states,))[0] + if self.tp_world_size > 0: + if self.tp_rank == 0: + moe_output = self.broadcast_to_tensor_parallel_region((moe_output,))[0] + else: + moe_output = self.broadcast_to_tensor_parallel_region((hidden_states,))[0] return moe_output - def _run_tp_moe(self, hidden_states, w1, diff --git a/vllm_mindspore/model_executor/layers/fused_moe/layer.py b/vllm_mindspore/model_executor/layers/fused_moe/layer.py index d3514411..01684916 100644 --- a/vllm_mindspore/model_executor/layers/fused_moe/layer.py +++ 
b/vllm_mindspore/model_executor/layers/fused_moe/layer.py @@ -383,8 +383,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, nn.Cell): global_num_experts=global_num_experts, apply_router_weight_on_input=apply_router_weight_on_input, expert_map=expert_map, - tp_size=self.moe.tp_size, - ep_size=self.moe.ep_size, ) @@ -825,6 +823,7 @@ class FusedMoE(nn.Cell): """ To all_reduce after routed expert and shared expert are added. """ + # Do delay allreduce If "must_reduce_shared_expert_outputs" return True if self.pure_tp and self.dp_size == 1: return self.reduce_from_tp_group(final_hidden_states) return final_hidden_states -- Gitee From 1a677235982b730084ce594299d3365d98859514 Mon Sep 17 00:00:00 2001 From: lvhaoyu Date: Tue, 1 Jul 2025 14:12:13 +0800 Subject: [PATCH 48/77] update --- .../model_executor/layers/fused_moe/layer.py | 75 ++++--------------- 1 file changed, 15 insertions(+), 60 deletions(-) diff --git a/vllm_mindspore/model_executor/layers/fused_moe/layer.py b/vllm_mindspore/model_executor/layers/fused_moe/layer.py index 01684916..27c07ab3 100644 --- a/vllm_mindspore/model_executor/layers/fused_moe/layer.py +++ b/vllm_mindspore/model_executor/layers/fused_moe/layer.py @@ -6,31 +6,20 @@ from dataclasses import dataclass from enum import Enum from typing import Callable, Optional -import torch -from torch.nn.parameter import UninitializedParameter - import vllm.envs as envs from vllm.config import ParallelConfig, get_current_vllm_config from vllm.distributed import (get_dp_group, get_ep_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) -from vllm.forward_context import ForwardContext, get_forward_context from vllm.logger import init_logger -from vllm.model_executor.custom_op import CustomOp -from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( - is_rocm_aiter_moe_enabled) from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.utils import set_weight_attrs -from vllm.platforms import current_platform -from vllm.platforms.interface import CpuArchEnum -from vllm.utils import direct_register_custom_op # from vllm.model_executor.layers.fused_moe.layer import FusedMoEParallelConfig from vllm.model_executor.layers.fused_moe.layer import (determine_expert_map, - FusedMoeWeightScaleSupported, - FusedMoEMethodBase, - #MoEConfig, + FusedMoeWeightScaleSupported, + FusedMoEMethodBase, ) @@ -43,6 +32,7 @@ from vllm_mindspore.distributed.communication_op import ReduceFromModelParallelR from mindspore import nn, Tensor, Parameter, mint, ops import mindspore as ms +import mindspore._c_expression.typing.Type as ms_dtype logger = init_logger(__name__) @@ -58,25 +48,6 @@ class FusedMoEParallelConfig: use_ep: bool # whether to use EP or not - @property - def use_all2all_kernels(self): - return self.dp_size > 1 and self.use_ep - - @property - def use_pplx_kernels(self): - return (self.use_all2all_kernels - and envs.VLLM_ALL2ALL_BACKEND == "pplx") - - @property - def use_deepep_ht_kernels(self): - return (self.use_all2all_kernels - and envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput") - - @property - def use_deepep_ll_kernels(self): - return (self.use_all2all_kernels - and envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency") - @staticmethod def make(tp_size_: int, dp_size_: int, vllm_parallel_config: ParallelConfig) -> "FusedMoEParallelConfig": @@ -192,8 +163,8 @@ class MoEConfig: num_local_experts: int moe_parallel_config: FusedMoEParallelConfig - in_dtype: torch.dtype 
# The activation type. - quant_dtype: torch.dtype = None + in_dtype: ms_dtype # The activation type. + quant_dtype: ms_dtype = None # TODO: add more quantization params, blocked, per-token, etc. block_size: int = 128 @@ -233,46 +204,34 @@ class MoEConfig: def use_ep(self): return self.moe_parallel_config.use_ep - @property - def use_pplx_kernels(self): - return self.moe_parallel_config.use_pplx_kernels - - @property - def use_deepep_ht_kernels(self): - return self.moe_parallel_config.use_deepep_ht_kernels - - @property - def use_deepep_ll_kernels(self): - return self.moe_parallel_config.use_deepep_ll_kernels - class FusedMoEMethodBase(QuantizeMethodBase): @abstractmethod - def create_weights(self, layer: torch.nn.Module, num_experts: int, + def create_weights(self, layer: nn.Cell, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, - params_dtype: torch.dtype, **extra_weight_attrs): + params_dtype, **extra_weight_attrs): raise NotImplementedError @abstractmethod def apply( self, - layer: torch.nn.Module, - x: torch.Tensor, - router_logits: torch.Tensor, + layer: nn.Cell, + x: Tensor, + router_logits: Tensor, top_k: int, renormalize: bool, use_grouped_topk: bool = False, topk_group: Optional[int] = None, num_expert_group: Optional[int] = None, global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, + expert_map: Optional[Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", - e_score_correction_bias: Optional[torch.Tensor] = None, + e_score_correction_bias: Optional[Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", - ) -> torch.Tensor: + ) -> Tensor: raise NotImplementedError class UnquantizedFusedMoEMethod(FusedMoEMethodBase, nn.Cell): @@ -577,10 +536,6 @@ class FusedMoE(nn.Cell): def use_ep(self): return self.moe_parallel_config.use_ep - @property - def use_pplx_kernels(self): - return self.moe_parallel_config.use_pplx_kernels - def _load_w13(self, param: Parameter, shard_dim: int, shard_id: str, loaded_weight: Tensor, expert_id: int, tp_rank: int, load_full: bool = False): @@ -703,7 +658,7 @@ class FusedMoE(nn.Cell): tp_rank=tp_rank, load_full=load_full_w3) - def weight_loader(self, param: torch.nn.Parameter, + def weight_loader(self, param: nn.Parameter, loaded_weight: Tensor, weight_name: str, shard_id: str, expert_id: int) -> None: @@ -775,7 +730,7 @@ class FusedMoE(nn.Cell): custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", e_score_correction_bias: Optional[Tensor] = None, - indices_type: Optional[torch.dtype] = None): + indices_type=None): # DeekSeekv2 uses grouped_top_k if use_grouped_topk: -- Gitee From 848160a3be06847a8363416f099a2cbf0b994ccc Mon Sep 17 00:00:00 2001 From: lvhaoyu Date: Tue, 1 Jul 2025 14:19:25 +0800 Subject: [PATCH 49/77] update --- vllm_mindspore/model_executor/layers/fused_moe/layer.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/vllm_mindspore/model_executor/layers/fused_moe/layer.py b/vllm_mindspore/model_executor/layers/fused_moe/layer.py index 27c07ab3..69604b7a 100644 --- a/vllm_mindspore/model_executor/layers/fused_moe/layer.py +++ b/vllm_mindspore/model_executor/layers/fused_moe/layer.py @@ -32,7 +32,6 @@ from vllm_mindspore.distributed.communication_op import ReduceFromModelParallelR from mindspore import nn, Tensor, Parameter, mint, ops import mindspore as ms -import mindspore._c_expression.typing.Type as ms_dtype logger = init_logger(__name__) @@ -163,8 +162,8 @@ class 
MoEConfig: num_local_experts: int moe_parallel_config: FusedMoEParallelConfig - in_dtype: ms_dtype # The activation type. - quant_dtype: ms_dtype = None + in_dtype: ms.dtype.Type # The activation type. + quant_dtype: ms.dtype.Type = None # TODO: add more quantization params, blocked, per-token, etc. block_size: int = 128 -- Gitee From 160b9f67f2ce6d2bea149e5cb5d0bd680a54e3b8 Mon Sep 17 00:00:00 2001 From: lvhaoyu Date: Tue, 1 Jul 2025 14:21:26 +0800 Subject: [PATCH 50/77] update --- vllm_mindspore/model_executor/layers/fused_moe/layer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_mindspore/model_executor/layers/fused_moe/layer.py b/vllm_mindspore/model_executor/layers/fused_moe/layer.py index 69604b7a..f7005098 100644 --- a/vllm_mindspore/model_executor/layers/fused_moe/layer.py +++ b/vllm_mindspore/model_executor/layers/fused_moe/layer.py @@ -657,7 +657,7 @@ class FusedMoE(nn.Cell): tp_rank=tp_rank, load_full=load_full_w3) - def weight_loader(self, param: nn.Parameter, + def weight_loader(self, param: Parameter, loaded_weight: Tensor, weight_name: str, shard_id: str, expert_id: int) -> None: -- Gitee From bdd830110b7ba44593bde55a12360d5dfe56ae59 Mon Sep 17 00:00:00 2001 From: lvhaoyu Date: Tue, 1 Jul 2025 17:39:12 +0800 Subject: [PATCH 51/77] update --- .../layers/fused_moe/fused_moe2.py | 25 +++++++++++++------ 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/vllm_mindspore/model_executor/layers/fused_moe/fused_moe2.py b/vllm_mindspore/model_executor/layers/fused_moe/fused_moe2.py index 8a665418..011ef8dd 100644 --- a/vllm_mindspore/model_executor/layers/fused_moe/fused_moe2.py +++ b/vllm_mindspore/model_executor/layers/fused_moe/fused_moe2.py @@ -200,22 +200,31 @@ class FusedExperts(nn.Cell): group_list = group_list.astype(ms.int64) - # recv_list = self.all_to_all_across_ep(send_list) - recv_list = self.all_to_all_v_across_ep(send_list, self.even_list, self.even_list) + local_group_list = self.all_to_all_v_across_ep(group_list, + self.send_experts_num_map, + self.recv_experts_num_map) + + local_group_list = local_group_list.reshape(-1, self.local_expert_num) + recv_list = local_group_list.sum(dim=1) + + # recv_list = self.all_to_all_v_across_ep(send_list, self.even_list, self.even_list) # recv_list [20, 40, 60, 70] local_input_tensor = self.all_to_all_v_across_ep_with_block_size(sorted_input_tensor.reshape(-1), send_list, recv_list) - local_group_list = self.all_to_all_v_across_ep(group_list, - self.send_experts_num_map, - self.recv_experts_num_map) - local_group_list = local_group_list.reshape(-1, self.local_expert_num) - local_group_list = local_group_list.sum(dim=0) + topk_ids_1d, _ = mint.sort(topk_ids.reshape(-1)) + topk_ids_local = self.all_to_all_v_across_ep(topk_ids_1d, send_list, recv_list) + local_group_list = local_group_list.sum(dim=0) recv_tokens = recv_list.sum() if recv_tokens > 0: + _, resort_index = mint.sort(topk_ids_local) + _, unresort_index = mint.sort(resort_index) + local_input_tensor = local_input_tensor.reshape(-1, self.hidden_size) + local_input_tensor = mint.index_select(local_input_tensor, 0, resort_index) + gate_hidden_out = self._group_matmul(local_input_tensor, mint.transpose(w1, -1, -2), local_group_list) gate, hidden = mint.split(gate_hidden_out, (w1.shape[1] // 2, w1.shape[1] // 2), -1) @@ -223,6 +232,8 @@ class FusedExperts(nn.Cell): hidden = mint.mul(hidden, gate) expert_output = self._group_matmul(hidden, mint.transpose(w2, -1, -2), local_group_list) expert_output = mint.nan_to_num(expert_output, 0, 0, 0) + + 
expert_output = mint.index_select(expert_output, 0, unresort_index) else: expert_output = self.dummy_token expert_output = self.all_to_all_v_across_ep_with_block_size(expert_output.reshape(-1), -- Gitee From ba2128b555cb882e82a5f673e8c63e512a7bc2b5 Mon Sep 17 00:00:00 2001 From: lvhaoyu Date: Tue, 1 Jul 2025 19:59:35 +0800 Subject: [PATCH 52/77] update --- .../model_executor/layers/fused_moe/fused_moe2.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/vllm_mindspore/model_executor/layers/fused_moe/fused_moe2.py b/vllm_mindspore/model_executor/layers/fused_moe/fused_moe2.py index 011ef8dd..3f2a0acb 100644 --- a/vllm_mindspore/model_executor/layers/fused_moe/fused_moe2.py +++ b/vllm_mindspore/model_executor/layers/fused_moe/fused_moe2.py @@ -96,12 +96,16 @@ class FusedExperts(nn.Cell): tp_world_size = get_tensor_model_parallel_world_size() recv_num_map_list = [] + recv_list_index = [] for i in range(self.ep_size): if i % tp_world_size == 0: recv_num_map_list.append(moe_config.num_local_experts) + recv_list_index.append(i) else: recv_num_map_list.append(0) self.recv_experts_num_map = ms.Tensor(recv_num_map_list, dtype=ms.int64) + self.recv_list_index = ms.Tensor(recv_list_index, dtype=ms.int64) + self.local_expert_num = moe_config.num_local_experts self.prepend_tensor = ms.Tensor([0], dtype=ms.int64) @@ -205,7 +209,9 @@ class FusedExperts(nn.Cell): self.recv_experts_num_map) local_group_list = local_group_list.reshape(-1, self.local_expert_num) - recv_list = local_group_list.sum(dim=1) + recv_list_value = local_group_list.sum(dim=-1) + recv_list = mint.zeros(self.ep_size, dtype=ms.int64) + recv_list[self.recv_list_index] = recv_list_value # recv_list = self.all_to_all_v_across_ep(send_list, self.even_list, self.even_list) # recv_list [20, 40, 60, 70] -- Gitee From 959b1c4f35f1da87344b6dc3421f6d7909d9212c Mon Sep 17 00:00:00 2001 From: lvhaoyu Date: Wed, 2 Jul 2025 15:46:40 +0800 Subject: [PATCH 53/77] replace broadcast with allreduce --- .../layers/fused_moe/fused_moe2.py | 65 +++++++++++-------- 1 file changed, 39 insertions(+), 26 deletions(-) diff --git a/vllm_mindspore/model_executor/layers/fused_moe/fused_moe2.py b/vllm_mindspore/model_executor/layers/fused_moe/fused_moe2.py index 3f2a0acb..b128ba65 100644 --- a/vllm_mindspore/model_executor/layers/fused_moe/fused_moe2.py +++ b/vllm_mindspore/model_executor/layers/fused_moe/fused_moe2.py @@ -9,6 +9,7 @@ from mindspore.ops.auto_generate import (GroupedMatmulV4, import mindspore as ms from vllm.distributed.parallel_state import (get_ep_group, get_tp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) +from vllm_mindspore.distributed.communication_op import ReduceFromModelParallelRegion def fused_topk( hidden_states: Tensor, @@ -69,17 +70,31 @@ class FusedExperts(nn.Cell): self.moe_token_unpermute = MoeTokenUnpermute() self.pure_tp = False - self.pure_ep = True + self.pure_ep = False + # pure ep mode if moe_config.moe_parallel_config.ep_size > 1 and \ moe_config.moe_parallel_config.tp_size == 1: - # pure ep self.pure_ep = True + + # some configuration for tensor model parallel region self.tp_rank = get_tensor_model_parallel_rank() self.tp_world_size = get_tensor_model_parallel_world_size() + self.tp_group = get_tp_group().device_group._name + self.all_reduce_across_tp = ReduceFromModelParallelRegion() + self.broadcast_to_tensor_parallel_region = ops.Broadcast(0, group=self.tp_group) + + + # some configuration for expert parallel region ep_size = moe_config.moe_parallel_config.ep_size 
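# A standalone numpy illustration (separate from the surrounding diff; the
# world size and tensor shape are chosen arbitrarily) of why the rank-0
# broadcast on this expert-parallel path can later be replaced by an
# all-reduce: only tp_rank 0 holds the real MoE output, so summing a zero
# contribution from every other TP rank reproduces the broadcast result.
import numpy as np

tp_world_size, num_tokens, hidden = 4, 8, 16
rank0_output = np.random.rand(num_tokens, hidden).astype(np.float32)
contributions = [rank0_output if r == 0 else np.zeros_like(rank0_output)
                 for r in range(tp_world_size)]
# sum-all-reduce over the TP group == broadcast from rank 0 in this setting
assert np.allclose(np.sum(contributions, axis=0), rank0_output)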
self.ep_size = ep_size self.ep_group = get_ep_group().device_group._name + + # some configuration for experts + self.hidden_size = moe_config.hidden_dim + self.local_expert_num = moe_config.num_local_experts + + # some configuration for alltoall communication experts_num = moe_config.num_experts experts_num_map = [(experts_num // ep_size) for _ in range(ep_size - 1)] @@ -94,11 +109,10 @@ class FusedExperts(nn.Cell): else: self.send_experts_num_map = mint.zeros(ep_size, dtype=ms.int64) - tp_world_size = get_tensor_model_parallel_world_size() recv_num_map_list = [] recv_list_index = [] for i in range(self.ep_size): - if i % tp_world_size == 0: + if i % self.tp_world_size == 0: recv_num_map_list.append(moe_config.num_local_experts) recv_list_index.append(i) else: @@ -106,24 +120,22 @@ class FusedExperts(nn.Cell): self.recv_experts_num_map = ms.Tensor(recv_num_map_list, dtype=ms.int64) self.recv_list_index = ms.Tensor(recv_list_index, dtype=ms.int64) - self.local_expert_num = moe_config.num_local_experts - self.prepend_tensor = ms.Tensor([0], dtype=ms.int64) - self.hidden_size = moe_config.hidden_dim self.all_to_all_v_across_ep_with_block_size = ops.AlltoAllV(block_size=self.hidden_size, group=self.ep_group) self.all_to_all_v_across_ep = ops.AlltoAllV(group=self.ep_group) self.even_list = [1 for _ in range(ep_size)] - self.tp_group = get_tp_group().device_group._name - self.broadcast_to_tensor_parallel_region = ops.Broadcast(0, group=self.tp_group) - self.dummy_token = mint.zeros((1, self.hidden_size), dtype=moe_config.in_dtype) - if moe_config.moe_parallel_config.ep_size == 1 and \ + # pure tp mode + elif moe_config.moe_parallel_config.ep_size == 1 and \ moe_config.moe_parallel_config.tp_size >= 1: self.pure_tp = True + # tp + ep mode + else: + raise NotImplementedError("tp + ep mode not support yet.") def construct(self, hidden_states: Tensor, @@ -192,16 +204,6 @@ class FusedExperts(nn.Cell): expert_tokens_before_capacity_flag=True) # group_list = group_list.reshape(1, -1) - - if self.tp_rank == 0: - group_list_cumsum = mint.cumsum(group_list, 0, dtype=ms.int64) - # expert index = [3, 7, 11, 15] (self.ep_group_size,) - # 看下每个rank, 发送多少tensor 数据给其他的rank - send_list = group_list_cumsum[self.experts_num_map_cu_index] # [20, 30, 40, 50] - send_list = mint.diff(send_list, prepend=self.prepend_tensor) - else: - send_list = mint.zeros(self.ep_size, dtype=ms.int64) # [0, 0, 0, 0] - group_list = group_list.astype(ms.int64) local_group_list = self.all_to_all_v_across_ep(group_list, @@ -213,6 +215,15 @@ class FusedExperts(nn.Cell): recv_list = mint.zeros(self.ep_size, dtype=ms.int64) recv_list[self.recv_list_index] = recv_list_value + if self.tp_rank == 0: + group_list_cumsum = mint.cumsum(group_list, 0, dtype=ms.int64) + # expert index = [3, 7, 11, 15] (self.ep_group_size,) + # 看下每个rank, 发送多少tensor 数据给其他的rank + send_list = group_list_cumsum[self.experts_num_map_cu_index] # [20, 30, 40, 50] + send_list = mint.diff(send_list, prepend=self.prepend_tensor) + else: + send_list = mint.zeros(self.ep_size, dtype=ms.int64) # [0, 0, 0, 0] + # recv_list = self.all_to_all_v_across_ep(send_list, self.even_list, self.even_list) # recv_list [20, 40, 60, 70] local_input_tensor = self.all_to_all_v_across_ep_with_block_size(sorted_input_tensor.reshape(-1), @@ -252,11 +263,13 @@ class FusedExperts(nn.Cell): probs=topk_weights, padded_mode=False, restore_shape=None) - if self.tp_world_size > 0: - if self.tp_rank == 0: - moe_output = self.broadcast_to_tensor_parallel_region((moe_output,))[0] - else: - moe_output = 
self.broadcast_to_tensor_parallel_region((hidden_states,))[0] + else: + # moe_output = hidden_states + moe_output = mint.zeros_like(hidden_states) + + # if self.tp_world_size > 0: + # moe_output = self.broadcast_to_tensor_parallel_region((moe_output,))[0] + moe_output = self.all_reduce_across_tp(moe_output) return moe_output def _run_tp_moe(self, -- Gitee From 748c51cc020bb891090849be235d63f0395da039 Mon Sep 17 00:00:00 2001 From: lvhaoyu Date: Thu, 3 Jul 2025 10:20:41 +0800 Subject: [PATCH 54/77] update good for jit ep --- .../model_executor/layers/fused_moe/fused_moe2.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/vllm_mindspore/model_executor/layers/fused_moe/fused_moe2.py b/vllm_mindspore/model_executor/layers/fused_moe/fused_moe2.py index b128ba65..2bef513f 100644 --- a/vllm_mindspore/model_executor/layers/fused_moe/fused_moe2.py +++ b/vllm_mindspore/model_executor/layers/fused_moe/fused_moe2.py @@ -129,6 +129,8 @@ class FusedExperts(nn.Cell): self.dummy_token = mint.zeros((1, self.hidden_size), dtype=moe_config.in_dtype) + self.depend = ops.Depend() + # pure tp mode elif moe_config.moe_parallel_config.ep_size == 1 and \ moe_config.moe_parallel_config.tp_size >= 1: @@ -264,12 +266,13 @@ class FusedExperts(nn.Cell): padded_mode=False, restore_shape=None) else: - # moe_output = hidden_states - moe_output = mint.zeros_like(hidden_states) + hidden_states = self.depend(hidden_states, expert_output) + moe_output = hidden_states + # moe_output = mint.zeros_like(hidden_states) - # if self.tp_world_size > 0: - # moe_output = self.broadcast_to_tensor_parallel_region((moe_output,))[0] - moe_output = self.all_reduce_across_tp(moe_output) + if self.tp_world_size > 0: + moe_output = self.broadcast_to_tensor_parallel_region((moe_output,))[0] + # moe_output = self.all_reduce_across_tp(moe_output) return moe_output def _run_tp_moe(self, -- Gitee From 67b8739f0ec46f08a40da6b7601311f8c522996e Mon Sep 17 00:00:00 2001 From: lvhaoyu Date: Fri, 4 Jul 2025 09:16:22 +0800 Subject: [PATCH 55/77] update --- .../model_executor/layers/fused_moe/fused_moe2.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/vllm_mindspore/model_executor/layers/fused_moe/fused_moe2.py b/vllm_mindspore/model_executor/layers/fused_moe/fused_moe2.py index 2bef513f..8b67d8e0 100644 --- a/vllm_mindspore/model_executor/layers/fused_moe/fused_moe2.py +++ b/vllm_mindspore/model_executor/layers/fused_moe/fused_moe2.py @@ -266,13 +266,14 @@ class FusedExperts(nn.Cell): padded_mode=False, restore_shape=None) else: - hidden_states = self.depend(hidden_states, expert_output) - moe_output = hidden_states - # moe_output = mint.zeros_like(hidden_states) - - if self.tp_world_size > 0: - moe_output = self.broadcast_to_tensor_parallel_region((moe_output,))[0] - # moe_output = self.all_reduce_across_tp(moe_output) + # hidden_states = self.depend(hidden_states, expert_output) + # moe_output = hidden_states + moe_output = mint.zeros_like(hidden_states) + moe_output = self.depend(moe_output, expert_output) + + # if self.tp_world_size > 0: + # moe_output = self.broadcast_to_tensor_parallel_region((moe_output,))[0] + moe_output = self.all_reduce_across_tp(moe_output) return moe_output def _run_tp_moe(self, -- Gitee From b975a8b0d55ef84dc86574142684ed55040a50e2 Mon Sep 17 00:00:00 2001 From: lvhaoyu Date: Fri, 4 Jul 2025 11:23:15 +0800 Subject: [PATCH 56/77] add dispatch and combine --- .../layers/fused_moe/fused_moe2.py | 102 ++++++++++++++---- 
vllm_mindspore/model_executor/models/utils.py | 2 +- 2 files changed, 82 insertions(+), 22 deletions(-) diff --git a/vllm_mindspore/model_executor/layers/fused_moe/fused_moe2.py b/vllm_mindspore/model_executor/layers/fused_moe/fused_moe2.py index 8b67d8e0..fb5873e6 100644 --- a/vllm_mindspore/model_executor/layers/fused_moe/fused_moe2.py +++ b/vllm_mindspore/model_executor/layers/fused_moe/fused_moe2.py @@ -5,7 +5,9 @@ from mindspore import Tensor, mint, ops, nn from mindspore.ops.auto_generate import (GroupedMatmulV4, FusedAddTopKDiv, MoeInitRoutingV2, - MoeTokenUnpermute) + MoeTokenUnpermute, + MoeDistributeDispatch, + MoeDistributeCombine) import mindspore as ms from vllm.distributed.parallel_state import (get_ep_group, get_tp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) @@ -89,6 +91,7 @@ class FusedExperts(nn.Cell): ep_size = moe_config.moe_parallel_config.ep_size self.ep_size = ep_size self.ep_group = get_ep_group().device_group._name + self.ep_rank = get_ep_group().rank_in_group # some configuration for experts self.hidden_size = moe_config.hidden_dim @@ -96,6 +99,7 @@ class FusedExperts(nn.Cell): # some configuration for alltoall communication experts_num = moe_config.num_experts + self.expert_num = experts_num experts_num_map = [(experts_num // ep_size) for _ in range(ep_size - 1)] experts_num_map.append(experts_num - ((experts_num // ep_size) * (ep_size - 1))) @@ -131,6 +135,12 @@ class FusedExperts(nn.Cell): self.depend = ops.Depend() + self.dispatch = MoeDistributeDispatch() # only support in 910b and 910_A3 + self.combine = MoeDistributeCombine() # only support in 910b and 910_A3 + self.dispatch_tp_world_size = 0 + self.dispatch_shared_expert_num = 0 + self.max_bs = 256 * self.ep_size + # pure tp mode elif moe_config.moe_parallel_config.ep_size == 1 and \ moe_config.moe_parallel_config.tp_size >= 1: @@ -151,18 +161,18 @@ class FusedExperts(nn.Cell): expert_map: Optional[Tensor] = None) -> Tensor: if self.pure_tp: - hidden_states = self._run_tp_moe(hidden_states, w1, w2, topk_ids, topk_weights, + hidden_states = self.run_tp_moe(hidden_states, w1, w2, topk_ids, topk_weights, activation, global_num_experts, apply_router_weight_on_input) # ep_size > 1 : pure ep or tp + ep elif self.pure_ep: # pure ep - hidden_states = self._run_ep_moe(hidden_states, w1, w2, topk_ids, topk_weights, + hidden_states = self.run_ep_moe(hidden_states, w1, w2, topk_ids, topk_weights, activation, global_num_experts, apply_router_weight_on_input) # tp_size > 1 : tp + ep else: - hidden_states = self._run_tp_ep_moe(hidden_states, w1, w2, topk_ids, topk_weights, + hidden_states = self.run_tp_ep_moe(hidden_states, w1, w2, topk_ids, topk_weights, activation, global_num_experts, apply_router_weight_on_input) @@ -181,8 +191,18 @@ class FusedExperts(nn.Cell): return self.group_matmul_ops([hidden_states], [weight], None, None, None, None, None, None, group_list, split_item=3, group_type=0, group_list_type=1)[0] + + def _ffn(self, hidden_state, w1, w2, group_list, activation): + gate_hidden_out = self._group_matmul(hidden_state, mint.transpose(w1, -1, -2), hidden_state) + gate, hidden = mint.split(gate_hidden_out, + (w1.shape[1] // 2, w1.shape[1] // 2), -1) + gate = self._gate_activation(gate, activation) + hidden = mint.mul(hidden, gate) + expert_output = self._group_matmul(hidden, mint.transpose(w2, -1, -2), group_list) + expert_output = mint.nan_to_num(expert_output, 0, 0, 0) + return expert_output - def _run_ep_moe(self, + def run_ep_moe(self, hidden_states, w1, w2, @@ -191,6 
+211,18 @@ class FusedExperts(nn.Cell): activation, global_num_experts, apply_router_weight_on_input): + + # return self._moe_with_dispatch_combine( + # hidden_states, + # w1, + # w2, + # topk_ids, + # topk_weights, + # activation, + # global_num_experts, + # apply_router_weight_on_input + # ) + topk_weights = topk_weights.astype(hidden_states.dtype) topk_ids = topk_ids.astype(ms.int32) @@ -244,13 +276,7 @@ class FusedExperts(nn.Cell): local_input_tensor = local_input_tensor.reshape(-1, self.hidden_size) local_input_tensor = mint.index_select(local_input_tensor, 0, resort_index) - gate_hidden_out = self._group_matmul(local_input_tensor, mint.transpose(w1, -1, -2), local_group_list) - gate, hidden = mint.split(gate_hidden_out, - (w1.shape[1] // 2, w1.shape[1] // 2), -1) - gate = self._gate_activation(gate, activation) - hidden = mint.mul(hidden, gate) - expert_output = self._group_matmul(hidden, mint.transpose(w2, -1, -2), local_group_list) - expert_output = mint.nan_to_num(expert_output, 0, 0, 0) + expert_output = self.ffn(local_input_tensor, w1, w2, local_group_list, activation) expert_output = mint.index_select(expert_output, 0, unresort_index) else: @@ -276,7 +302,7 @@ class FusedExperts(nn.Cell): moe_output = self.all_reduce_across_tp(moe_output) return moe_output - def _run_tp_moe(self, + def run_tp_moe(self, hidden_states, w1, w2, @@ -301,13 +327,8 @@ class FusedExperts(nn.Cell): group_list = group_list.astype(ms.int64) - gate_hidden_out = self._group_matmul(sorted_input_tensor, mint.transpose(w1, -1, -2), group_list) - gate, hidden = mint.split(gate_hidden_out, - (w1.shape[1] // 2, w1.shape[1] // 2), -1) - gate = self._gate_activation(gate, activation) - hidden = mint.mul(hidden, gate) - expert_output = self._group_matmul(hidden, mint.transpose(w2, -1, -2), group_list) - expert_output = mint.nan_to_num(expert_output, 0, 0, 0) + expert_output = self._ffn(sorted_input_tensor, w1, w2, group_list, activation) + moe_output = self.moe_token_unpermute(permuted_tokens=expert_output, sorted_indices=unsort_map, probs=topk_weights, @@ -316,7 +337,7 @@ class FusedExperts(nn.Cell): return moe_output - def _run_tp_ep_moe( + def run_tp_ep_moe( self, hidden_states, w1, @@ -328,3 +349,42 @@ class FusedExperts(nn.Cell): apply_router_weight_on_input): raise NotImplementedError( "TP + EP MoE is not implemented yet. 
Please use pure TP or pure EP MoE instead.") + + + def _moe_with_dispatch_combine(self, hidden_states, w1, w2, topk_ids, topk_weights, + activation, global_num_experts, + apply_router_weight_on_input): + """fused ops, moe feed forward with dispatch and combine.""" + # Dispatch + expand_x, _, expand_idx, expert_token_nums, ep_recv_counts, tp_recv_counts, _ = self.dispatch( + x=hidden_states, + expert_ids=topk_ids, + ep_world_size=self.ep_size, + ep_rank_id=self.ep_rank, + moe_expert_num=global_num_experts, + group_ep=self.ep_group, + tp_world_size=self.tp_world_size, + shared_expert_num=self.dispatch_shared_expert_num, + global_bs=self.max_bs, + expert_token_nums_type=1) + + # GroupMamtul + ffn_res = self._ffn(expand_x, w1, w2, expert_token_nums, activation) + + # Combine + moe_output = self.combine( + expand_x=ffn_res, + expert_ids=topk_ids, + expand_idx=expand_idx, + ep_send_counts=ep_recv_counts, + expert_scales=topk_weights, + ep_world_size=self.ep_size, + ep_rank_id=self.ep_rank, + moe_expert_num=global_num_experts, + tp_send_counts=tp_recv_counts, + group_ep=self.ep_group, + tp_world_size=self.dispatch_tp_world_size, + shared_expert_num=self.dispatch_shared_expert_num, + global_bs=self.max_bs) + + return moe_output diff --git a/vllm_mindspore/model_executor/models/utils.py b/vllm_mindspore/model_executor/models/utils.py index 26b5c268..3518bb6d 100644 --- a/vllm_mindspore/model_executor/models/utils.py +++ b/vllm_mindspore/model_executor/models/utils.py @@ -273,7 +273,7 @@ def get_pp_missing_layer_names(model: nn.Cell) -> list[str]: return _model_to_pp_missing_layer_names[model_id] missing_layer_names = [] - for cell, name in model.cells_and_names(): + for name, cell in model.cells_and_names(): if isinstance(cell, PPMissingLayer): # NOTE: the trailing dot is used to match the prefix of the layer. 
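# A small standalone sketch (illustration only; the layers chosen here are
# arbitrary) of the iteration order the loop fix above relies on:
# Cell.cells_and_names() yields (name, cell) pairs, so the name must be
# unpacked first and the Cell instance second.
from mindspore import nn as msnn

def _check_cells_and_names_order():
    net = msnn.SequentialCell([msnn.Dense(4, 4), msnn.ReLU()])
    for name, cell in net.cells_and_names():
        assert isinstance(name, str)        # the prefix/name comes first
        assert isinstance(cell, msnn.Cell)  # the Cell instance comes second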
# without the dot, we could match a layer that is not missing, -- Gitee From dee421a1e8baaa82b17263fe4223490d12c402df Mon Sep 17 00:00:00 2001 From: lvhaoyu Date: Fri, 4 Jul 2025 12:04:01 +0800 Subject: [PATCH 57/77] update --- .../layers/fused_moe/fused_moe2.py | 131 ++++++++++-------- 1 file changed, 75 insertions(+), 56 deletions(-) diff --git a/vllm_mindspore/model_executor/layers/fused_moe/fused_moe2.py b/vllm_mindspore/model_executor/layers/fused_moe/fused_moe2.py index fb5873e6..21af0bbb 100644 --- a/vllm_mindspore/model_executor/layers/fused_moe/fused_moe2.py +++ b/vllm_mindspore/model_executor/layers/fused_moe/fused_moe2.py @@ -135,6 +135,7 @@ class FusedExperts(nn.Cell): self.depend = ops.Depend() + # some configuration for dispatch and combine self.dispatch = MoeDistributeDispatch() # only support in 910b and 910_A3 self.combine = MoeDistributeCombine() # only support in 910b and 910_A3 self.dispatch_tp_world_size = 0 @@ -191,9 +192,9 @@ class FusedExperts(nn.Cell): return self.group_matmul_ops([hidden_states], [weight], None, None, None, None, None, None, group_list, split_item=3, group_type=0, group_list_type=1)[0] - + def _ffn(self, hidden_state, w1, w2, group_list, activation): - gate_hidden_out = self._group_matmul(hidden_state, mint.transpose(w1, -1, -2), hidden_state) + gate_hidden_out = self._group_matmul(hidden_state, mint.transpose(w1, -1, -2), group_list) gate, hidden = mint.split(gate_hidden_out, (w1.shape[1] // 2, w1.shape[1] // 2), -1) gate = self._gate_activation(gate, activation) @@ -202,6 +203,54 @@ class FusedExperts(nn.Cell): expert_output = mint.nan_to_num(expert_output, 0, 0, 0) return expert_output + def run_tp_moe(self, + hidden_states, + w1, + w2, + topk_ids, + topk_weights, + activation, + global_num_experts, + apply_router_weight_on_input): + topk_weights = topk_weights.astype(hidden_states.dtype) + topk_ids = topk_ids.astype(ms.int32) + + sorted_input_tensor, unsort_map, group_list, _ = \ + self.moe_init_routing_op( + hidden_states, + topk_ids, + active_num=0, + expert_capacity=0, + expert_num=global_num_experts, + drop_pad_mode=0, + expert_tokens_count_or_cumsum_flag=2, + expert_tokens_before_capacity_flag=True) + + group_list = group_list.astype(ms.int64) + + expert_output = self._ffn(sorted_input_tensor, w1, w2, group_list, activation) + + moe_output = self.moe_token_unpermute(permuted_tokens=expert_output, + sorted_indices=unsort_map, + probs=topk_weights, + padded_mode=False, + restore_shape=None) + return moe_output + + + def run_tp_ep_moe( + self, + hidden_states, + w1, + w2, + group_list, + group_logits, + activation, + global_num_experts, + apply_router_weight_on_input): + raise NotImplementedError( + "TP + EP MoE is not implemented yet. 
Please use pure TP or pure EP MoE instead.") + def run_ep_moe(self, hidden_states, w1, @@ -211,8 +260,10 @@ class FusedExperts(nn.Cell): activation, global_num_experts, apply_router_weight_on_input): - - # return self._moe_with_dispatch_combine( + topk_weights = topk_weights.astype(hidden_states.dtype) + topk_ids = topk_ids.astype(ms.int32) + + # return self._ep_with_dispatch_combine( # hidden_states, # w1, # w2, @@ -223,8 +274,25 @@ class FusedExperts(nn.Cell): # apply_router_weight_on_input # ) - topk_weights = topk_weights.astype(hidden_states.dtype) - topk_ids = topk_ids.astype(ms.int32) + return self._ep_with_all_to_all_v( + hidden_states, + w1, + w2, + topk_ids, + topk_weights, + activation, + global_num_experts, + apply_router_weight_on_input) + + def _ep_with_all_to_all_v(self, + hidden_states, + w1, + w2, + topk_ids, + topk_weights, + activation, + global_num_experts, + apply_router_weight_on_input): sorted_input_tensor, unsort_map, group_list, _ = \ self.moe_init_routing_op( @@ -302,56 +370,7 @@ class FusedExperts(nn.Cell): moe_output = self.all_reduce_across_tp(moe_output) return moe_output - def run_tp_moe(self, - hidden_states, - w1, - w2, - topk_ids, - topk_weights, - activation, - global_num_experts, - apply_router_weight_on_input): - topk_weights = topk_weights.astype(hidden_states.dtype) - topk_ids = topk_ids.astype(ms.int32) - - sorted_input_tensor, unsort_map, group_list, _ = \ - self.moe_init_routing_op( - hidden_states, - topk_ids, - active_num=0, - expert_capacity=0, - expert_num=global_num_experts, - drop_pad_mode=0, - expert_tokens_count_or_cumsum_flag=2, - expert_tokens_before_capacity_flag=True) - - group_list = group_list.astype(ms.int64) - - expert_output = self._ffn(sorted_input_tensor, w1, w2, group_list, activation) - - moe_output = self.moe_token_unpermute(permuted_tokens=expert_output, - sorted_indices=unsort_map, - probs=topk_weights, - padded_mode=False, - restore_shape=None) - return moe_output - - - def run_tp_ep_moe( - self, - hidden_states, - w1, - w2, - group_list, - group_logits, - activation, - global_num_experts, - apply_router_weight_on_input): - raise NotImplementedError( - "TP + EP MoE is not implemented yet. 
Please use pure TP or pure EP MoE instead.") - - - def _moe_with_dispatch_combine(self, hidden_states, w1, w2, topk_ids, topk_weights, + def _ep_with_dispatch_combine(self, hidden_states, w1, w2, topk_ids, topk_weights, activation, global_num_experts, apply_router_weight_on_input): """fused ops, moe feed forward with dispatch and combine.""" -- Gitee From b50dfd973df0224689ad7323fc92ef5e7520c331 Mon Sep 17 00:00:00 2001 From: lvhaoyu Date: Fri, 4 Jul 2025 15:36:16 +0800 Subject: [PATCH 58/77] add dispatch alltoall and chunked alltoall --- .../layers/fused_moe/fused_moe2.py | 107 +++++++++--------- .../model_executor/layers/fused_moe/layer.py | 74 ++++++++++++ .../model_executor/models/qwen3_moe.py | 1 + 3 files changed, 131 insertions(+), 51 deletions(-) diff --git a/vllm_mindspore/model_executor/layers/fused_moe/fused_moe2.py b/vllm_mindspore/model_executor/layers/fused_moe/fused_moe2.py index 21af0bbb..1666172b 100644 --- a/vllm_mindspore/model_executor/layers/fused_moe/fused_moe2.py +++ b/vllm_mindspore/model_executor/layers/fused_moe/fused_moe2.py @@ -92,55 +92,59 @@ class FusedExperts(nn.Cell): self.ep_size = ep_size self.ep_group = get_ep_group().device_group._name self.ep_rank = get_ep_group().rank_in_group + self.expert_num = moe_config.num_experts + + self.use_dispatch_kernels = moe_config.use_dispatch_kernels + + if self.use_dispatch_kernels: + # some configuration for dispatch and combine + self.dispatch = MoeDistributeDispatch() # only support in 910b and 910_A3 + self.combine = MoeDistributeCombine() # only support in 910b and 910_A3 + self.dispatch_tp_world_size = 0 + self.dispatch_shared_expert_num = 0 + self.max_bs = 256 * self.ep_size - # some configuration for experts - self.hidden_size = moe_config.hidden_dim - self.local_expert_num = moe_config.num_local_experts - - # some configuration for alltoall communication - experts_num = moe_config.num_experts - self.expert_num = experts_num - experts_num_map = [(experts_num // ep_size) - for _ in range(ep_size - 1)] - experts_num_map.append(experts_num - ((experts_num // ep_size) * (ep_size - 1))) - # self.experts_num_map = ms.Tensor(expert_num_map, dtype=ms.int64) - experts_num_map_np = np.array(experts_num_map, dtype=np.int64) - experts_num_map_cu_np = np.cumsum(experts_num_map_np, dtype=np.int64) - self.experts_num_map_cu_index = ms.Tensor(experts_num_map_cu_np - 1, dtype=ms.int64) - - if self.tp_rank == 0: - self.send_experts_num_map = ms.Tensor(experts_num_map, dtype=ms.int64) else: - self.send_experts_num_map = mint.zeros(ep_size, dtype=ms.int64) - - recv_num_map_list = [] - recv_list_index = [] - for i in range(self.ep_size): - if i % self.tp_world_size == 0: - recv_num_map_list.append(moe_config.num_local_experts) - recv_list_index.append(i) + # some configuration for experts + self.hidden_size = moe_config.hidden_dim + self.local_expert_num = moe_config.num_local_experts + + # some configuration for alltoall communication + experts_num = self.expert_num + experts_num_map = [(experts_num // ep_size) + for _ in range(ep_size - 1)] + experts_num_map.append(experts_num - ((experts_num // ep_size) * (ep_size - 1))) + # self.experts_num_map = ms.Tensor(expert_num_map, dtype=ms.int64) + experts_num_map_np = np.array(experts_num_map, dtype=np.int64) + experts_num_map_cu_np = np.cumsum(experts_num_map_np, dtype=np.int64) + self.experts_num_map_cu_index = ms.Tensor(experts_num_map_cu_np - 1, dtype=ms.int64) + + if self.tp_rank == 0: + self.send_experts_num_map = ms.Tensor(experts_num_map, dtype=ms.int64) else: - 
recv_num_map_list.append(0) - self.recv_experts_num_map = ms.Tensor(recv_num_map_list, dtype=ms.int64) - self.recv_list_index = ms.Tensor(recv_list_index, dtype=ms.int64) + self.send_experts_num_map = mint.zeros(ep_size, dtype=ms.int64) - self.prepend_tensor = ms.Tensor([0], dtype=ms.int64) + recv_num_map_list = [] + recv_list_index = [] + for i in range(self.ep_size): + if i % self.tp_world_size == 0: + recv_num_map_list.append(moe_config.num_local_experts) + recv_list_index.append(i) + else: + recv_num_map_list.append(0) + self.recv_experts_num_map = ms.Tensor(recv_num_map_list, dtype=ms.int64) + self.recv_list_index = ms.Tensor(recv_list_index, dtype=ms.int64) - self.all_to_all_v_across_ep_with_block_size = ops.AlltoAllV(block_size=self.hidden_size, - group=self.ep_group) - self.all_to_all_v_across_ep = ops.AlltoAllV(group=self.ep_group) - self.even_list = [1 for _ in range(ep_size)] + self.prepend_tensor = ms.Tensor([0], dtype=ms.int64) - self.dummy_token = mint.zeros((1, self.hidden_size), dtype=moe_config.in_dtype) + self.all_to_all_v_across_ep_with_block_size = ops.AlltoAllV(block_size=self.hidden_size, + group=self.ep_group) + self.all_to_all_v_across_ep = ops.AlltoAllV(group=self.ep_group) + self.even_list = [1 for _ in range(ep_size)] - self.depend = ops.Depend() + self.dummy_token = mint.zeros((1, self.hidden_size), dtype=moe_config.in_dtype) - # some configuration for dispatch and combine - self.dispatch = MoeDistributeDispatch() # only support in 910b and 910_A3 - self.combine = MoeDistributeCombine() # only support in 910b and 910_A3 - self.dispatch_tp_world_size = 0 - self.dispatch_shared_expert_num = 0 - self.max_bs = 256 * self.ep_size + self.depend = ops.Depend() # pure tp mode elif moe_config.moe_parallel_config.ep_size == 1 and \ @@ -263,16 +267,17 @@ class FusedExperts(nn.Cell): topk_weights = topk_weights.astype(hidden_states.dtype) topk_ids = topk_ids.astype(ms.int32) - # return self._ep_with_dispatch_combine( - # hidden_states, - # w1, - # w2, - # topk_ids, - # topk_weights, - # activation, - # global_num_experts, - # apply_router_weight_on_input - # ) + if self.use_dispatch_kernels: + return self._ep_with_dispatch_combine( + hidden_states, + w1, + w2, + topk_ids, + topk_weights, + activation, + global_num_experts, + apply_router_weight_on_input + ) return self._ep_with_all_to_all_v( hidden_states, diff --git a/vllm_mindspore/model_executor/layers/fused_moe/layer.py b/vllm_mindspore/model_executor/layers/fused_moe/layer.py index f7005098..8e5f2e7d 100644 --- a/vllm_mindspore/model_executor/layers/fused_moe/layer.py +++ b/vllm_mindspore/model_executor/layers/fused_moe/layer.py @@ -15,6 +15,7 @@ from vllm.distributed import (get_dp_group, get_ep_group, from vllm.logger import init_logger from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.utils import set_weight_attrs +from vllm.forward_context import get_forward_context # from vllm.model_executor.layers.fused_moe.layer import FusedMoEParallelConfig from vllm.model_executor.layers.fused_moe.layer import (determine_expert_map, @@ -47,6 +48,15 @@ class FusedMoEParallelConfig: use_ep: bool # whether to use EP or not + @property + def use_all2all_kernels(self): + return self.dp_size > 1 and self.use_ep + + @property + def use_dispatch_kernels(self): + return (self.use_all2all_kernels + and envs.VLLM_ALL2ALL_BACKEND == 'dispatch') + @staticmethod def make(tp_size_: int, dp_size_: int, vllm_parallel_config: ParallelConfig) -> "FusedMoEParallelConfig": @@ -203,6 +213,9 
@@ class MoEConfig: def use_ep(self): return self.moe_parallel_config.use_ep + @property + def use_dispatch_kernels(self): + return self.moe_parallel_config.use_dispatch_kernels class FusedMoEMethodBase(QuantizeMethodBase): @@ -535,6 +548,10 @@ class FusedMoE(nn.Cell): def use_ep(self): return self.moe_parallel_config.use_ep + @property + def use_dispatch_kernels(self): + return self.moe_parallel_config.use_dispatch_kernels + def _load_w13(self, param: Parameter, shard_dim: int, shard_id: str, loaded_weight: Tensor, expert_id: int, tp_rank: int, load_full: bool = False): @@ -788,6 +805,11 @@ class FusedMoE(nn.Cell): dp_unpad_index, dp_pad_index_with_offset, dp_unpad_index_total_with_offset): + if self.use_dispatch_kernels: + return self.forward_impl_chunked(hidden_states, router_logits, dp_pad_index, + dp_unpad_index, dp_pad_index_with_offset, + dp_unpad_index_total_with_offset) + return self.forward_impl(hidden_states, router_logits, dp_pad_index, dp_unpad_index, dp_pad_index_with_offset, dp_unpad_index_total_with_offset) @@ -852,6 +874,58 @@ class FusedMoE(nn.Cell): return final_hidden_states + def forward_impl_chunked(self, full_hidden_states: Tensor, + full_router_logits: Tensor): + + full_final_hidden_states = mint.empty_like(full_hidden_states) + + def process_chunk(chunk_start, chunk_end, skip_result_store=False): + chunk_size = chunk_end - chunk_start + hidden_states = full_hidden_states[chunk_start:chunk_end, :] + router_logits = full_router_logits[chunk_start:chunk_end, :] + + # Matrix multiply. + final_hidden_states = self.quant_method.apply( + layer=self, + x=hidden_states, + router_logits=router_logits, + top_k=self.top_k, + renormalize=self.renormalize, + use_grouped_topk=self.use_grouped_topk, + global_num_experts=self.global_num_experts, + expert_map=self.expert_map, + topk_group=self.topk_group, + num_expert_group=self.num_expert_group, + custom_routing_function=self.custom_routing_function, + scoring_func=self.scoring_func, + e_score_correction_bias=self.e_score_correction_bias, + activation=self.activation, + apply_router_weight_on_input=self.apply_router_weight_on_input, + ) + + if not skip_result_store: + full_final_hidden_states[chunk_start:chunk_end, :] = final_hidden_states + + ctx = get_forward_context() + max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp_cpu + moe_dp_chunk_size_per_rank = self.moe_config.max_num_tokens + + num_tokens = full_hidden_states.size(0) + for chunk_start_ in range(0, max_tokens_across_dp, + moe_dp_chunk_size_per_rank): + chunk_start = chunk_start_ + chunk_end = min(chunk_start + moe_dp_chunk_size_per_rank, + max_tokens_across_dp) + # clamp start and end + chunk_start = min(chunk_start, num_tokens - 1) + chunk_end = min(chunk_end, num_tokens) + + process_chunk(chunk_start, + chunk_end, + skip_result_store=chunk_start_ >= num_tokens) + + return full_final_hidden_states + @classmethod def make_expert_params_mapping( cls, ckpt_gate_proj_name: str, ckpt_down_proj_name: str, diff --git a/vllm_mindspore/model_executor/models/qwen3_moe.py b/vllm_mindspore/model_executor/models/qwen3_moe.py index 882f584e..107e4cb5 100644 --- a/vllm_mindspore/model_executor/models/qwen3_moe.py +++ b/vllm_mindspore/model_executor/models/qwen3_moe.py @@ -39,6 +39,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.models.interfaces import SupportsPP from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors +from vllm.forward_context import 
get_forward_context from vllm_mindspore.attention import Attention from vllm_mindspore.model_executor.layers.activation import SiluAndMul -- Gitee From 0d91746d5da63f4def9fba83fe227eab79eedb44 Mon Sep 17 00:00:00 2001 From: lvhaoyu Date: Fri, 4 Jul 2025 16:03:44 +0800 Subject: [PATCH 59/77] remove some not use code --- .../device_communicators/npu_communicator.py | 66 ------- .../layers/fused_moe/fused_moe.py | 173 ------------------ 2 files changed, 239 deletions(-) delete mode 100644 vllm_mindspore/distributed/device_communicators/npu_communicator.py delete mode 100644 vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py diff --git a/vllm_mindspore/distributed/device_communicators/npu_communicator.py b/vllm_mindspore/distributed/device_communicators/npu_communicator.py deleted file mode 100644 index cfb89294..00000000 --- a/vllm_mindspore/distributed/device_communicators/npu_communicator.py +++ /dev/null @@ -1,66 +0,0 @@ -from mindspore import Tensor -from mindspore.communication import get_rank, get_group_size -import torch.distributed as dist - -from vllm.distributed.device_communicators.cuda_communicator import CudaCommunicator -from vllm.distributed.parallel_state import (get_dp_group, - get_tp_group, - in_the_same_node_as) -from vllm.forward_context import get_forward_context - - -class NPUCommunicator(CudaCommunicator): - def __init__(self, - cpu_group, - device = None, - device_group = None, - unique_name: str = ""): - super().__init__(cpu_group, device, device_group, unique_name) - - # all2all lives in ep group, which is merged from dp and tp group - self.dp_group = get_dp_group() - self.tp_group = get_tp_group() - # no self.ep_group since self.ep_group is still in construction - # when we create this object - self.dp_rank = self.dp_group.rank_in_group - self.dp_world_size = self.dp_group.world_size - self.rank = dist.get_rank(cpu_group) - self.world_size = dist.get_world_size(cpu_group) - - # all2all communication often has separate implementations for - # intra-node and inter-node communication - self.internode = not all(in_the_same_node_as(cpu_group, source_rank=0)) - def dispatch( - self, hidden_states: Tensor, - router_logits: Tensor) -> tuple[Tensor, Tensor]: - assert self.all2all_manager is not None - hidden_states, router_logits = self.all2all_manager.dispatch( - hidden_states, router_logits) - return hidden_states, router_logits - - def combine(self, hidden_states: Tensor) -> Tensor: - assert self.all2all_manager is not None - hidden_states = self.all2all_manager.combine(hidden_states) - return hidden_states - - def dispatch(self, hidden_states: Tensor, - router_logits: Tensor): - cu_tokens_across_dp_cpu = get_forward_context( - ).dp_metadata.cu_tokens_across_dp_cpu - - hidden_states = self.naive_multicast(hidden_states, - cu_tokens_across_dp_cpu) - router_logits = self.naive_multicast(router_logits, - cu_tokens_across_dp_cpu) - return hidden_states, router_logits - - def combine(self, hidden_states: Tensor) -> Tensor: - cu_tokens_across_dp_cpu = get_forward_context( - ).dp_metadata.cu_tokens_across_dp_cpu - start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[ - self.dp_rank - 1] - end = cu_tokens_across_dp_cpu[self.dp_rank] - - all_hidden_states = self.dp_group.all_reduce(hidden_states) - hidden_states = all_hidden_states[start:end, :] - return hidden_states \ No newline at end of file diff --git a/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py b/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py deleted file mode 100644 index 
a7b3bf7d..00000000 --- a/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py +++ /dev/null @@ -1,173 +0,0 @@ -from typing import Optional - -from mindspore import Tensor, mint, ops -from mindspore.ops.auto_generate import (GroupedMatmulV4, - FusedAddTopKDiv, - MoeInitRoutingV2, - MoeTokenUnpermute) -import mindspore as ms -from vllm.distributed.parallel_state import get_ep_group, get_dp_group - -def fused_topk( - hidden_states: Tensor, - gating_output: Tensor, - topk: int, - renormalize: bool, - indices_type = None, -) -> tuple[Tensor, Tensor]: - score = mint.softmax(gating_output, dim=-1) - topk_weights, topk_ids = mint.topk( - score, - k=topk, - dim=-1 - ) - if renormalize: - topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) - - if indices_type is not None: - topk_ids = topk_ids.to(indices_type) - return topk_weights, topk_ids - - -def grouped_topk( - hidden_states: Tensor, - gating_output: Tensor, - topk: int, - renormalize: bool, - num_expert_group: int = 0, - topk_group: int = 0, - scoring_func: str = "softmax", - e_score_correction_bias: Optional[Tensor] = None -) -> tuple[Tensor, Tensor]: - fused_add_topk_div = FusedAddTopKDiv() - assert hidden_states.shape[0] == gating_output.shape[0], ( - "Number of tokens mismatch") - scoring_type = 0 # sigmoid - topk_in_group = 2 - topk_weights, topk_ids = fused_add_topk_div( - gating_output, - e_score_correction_bias, - num_expert_group, - topk_group, - topk, - topk_in_group, - scoring_type, - renormalize) - - return topk_weights, topk_ids - - -def fused_experts(hidden_states: Tensor, - w1: Tensor, - w2: Tensor, - topk_weights: Tensor, - topk_ids: Tensor, - activation: str = "silu", - global_num_experts: int = -1, - apply_router_weight_on_input: bool = False, - expert_map: Optional[Tensor] = None, - tp_size: int = 1, - ep_size: int = 0) -> Tensor: - - if tp_size >= 1: - # no ep, pure tp - if ep_size == 1: - hidden_states = _run_tp_moe(hidden_states, w1, w2, topk_ids, topk_weights, - activation, global_num_experts, - apply_router_weight_on_input) - # ep_size > 1 : pure ep or tp + ep - else: - # pure ep - if tp_size == 1: - hidden_states = _run_ep_moe(hidden_states, w1, w2, topk_ids, topk_weights, - activation, global_num_experts, - apply_router_weight_on_input) - # tp_size > 1 : tp + ep - else: - hidden_states = _run_tp_ep_moe(hidden_states, w1, w2, topk_ids, topk_weights, - activation, global_num_experts, - apply_router_weight_on_input) - - return hidden_states - -def _gate_activation(gate, activation): - if activation == "silu": - return mint.nn.functional.silu(gate) - elif activation == "gelu": - return mint.nn.functional.gelu(gate) - else: - raise ValueError(f"Unsupported activation function: {activation}") - - -group_matmul_ops = GroupedMatmulV4() -moe_init_routing_op = MoeInitRoutingV2() -moe_token_unpermute = MoeTokenUnpermute() - -def _group_matmul(hidden_states, weight, group_list): - return group_matmul_ops([hidden_states], [weight], - None, None, None, None, None, None, - group_list, split_item=3, group_type=0, group_list_type=1)[0] - -def _run_ep_moe(hidden_states, - w1, - w2, - topk_ids, - topk_weights, - activation, - global_num_experts, - apply_router_weight_on_input): - hidden_states = _group_matmul(hidden_states, w1, topk_ids) - hidden_states = _gate_activation(hidden_states, activation) - hidden_states = _group_matmul(hidden_states, w2, topk_ids) - return hidden_states - - -def _run_tp_moe(hidden_states, - w1, - w2, - topk_ids, - topk_weights, - activation, - global_num_experts, - 
apply_router_weight_on_input): - topk_weights = topk_weights.astype(hidden_states.dtype) - topk_ids = topk_ids.astype(ms.int32) - - sorted_input_tensor, unsort_map, group_list, _ = \ - moe_init_routing_op( - hidden_states, - topk_ids, - active_num=0, - expert_capacity=0, - expert_num=global_num_experts, - drop_pad_mode=0, - expert_tokens_count_or_cumsum_flag=2, - expert_tokens_before_capacity_flag=True) - - group_list = group_list.astype(ms.int64) - - gate_hidden_out = _group_matmul(sorted_input_tensor, mint.transpose(w1, -1, -2), group_list) - gate, hidden = mint.split(gate_hidden_out, - (w1.shape[1] // 2, w1.shape[1] // 2), -1) - gate = _gate_activation(gate, activation) - hidden = mint.mul(hidden, gate) - expert_output = _group_matmul(hidden, mint.transpose(w2, -1, -2), group_list) - expert_output = mint.nan_to_num(expert_output, 0, 0, 0) - moe_output = moe_token_unpermute(permuted_tokens=expert_output, - sorted_indices=unsort_map, - probs=topk_weights, - padded_mode=False, - restore_shape=None) - return moe_output - - -def _run_tp_ep_moe(hidden_states, - w1, - w2, - group_list, - group_logits, - activation, - global_num_experts, - apply_router_weight_on_input): - raise NotImplementedError( - "TP + EP MoE is not implemented yet. Please use pure TP or pure EP MoE instead.") -- Gitee From 5efd293facd0dace19eff63bb133a88460f253a6 Mon Sep 17 00:00:00 2001 From: lvhaoyu Date: Fri, 4 Jul 2025 17:53:58 +0800 Subject: [PATCH 60/77] update --- vllm_mindspore/model_executor/layers/fused_moe/layer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm_mindspore/model_executor/layers/fused_moe/layer.py b/vllm_mindspore/model_executor/layers/fused_moe/layer.py index 8e5f2e7d..3f74b4ab 100644 --- a/vllm_mindspore/model_executor/layers/fused_moe/layer.py +++ b/vllm_mindspore/model_executor/layers/fused_moe/layer.py @@ -24,9 +24,9 @@ from vllm.model_executor.layers.fused_moe.layer import (determine_expert_map, ) -from vllm_mindspore.model_executor.layers.fused_moe.fused_moe import (fused_topk, - grouped_topk, - fused_experts) +# from vllm_mindspore.model_executor.layers.fused_moe.fused_moe import (fused_topk, + # grouped_topk, + # fused_experts) from vllm_mindspore.model_executor.layers.fused_moe.fused_moe2 import FusedExperts from vllm_mindspore.model_executor.layers.quantization.base_config import QuantizeMethodBase from vllm_mindspore.distributed.communication_op import ReduceFromModelParallelRegion -- Gitee From d70e5c0b24f9d2cc6119de13420fcb16685d19fe Mon Sep 17 00:00:00 2001 From: lvhaoyu Date: Fri, 4 Jul 2025 18:04:19 +0800 Subject: [PATCH 61/77] Revert "remove some not use code" This reverts commit 0d91746d5da63f4def9fba83fe227eab79eedb44. 
--- .../device_communicators/npu_communicator.py | 66 +++++++ .../layers/fused_moe/fused_moe.py | 173 ++++++++++++++++++ 2 files changed, 239 insertions(+) create mode 100644 vllm_mindspore/distributed/device_communicators/npu_communicator.py create mode 100644 vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py diff --git a/vllm_mindspore/distributed/device_communicators/npu_communicator.py b/vllm_mindspore/distributed/device_communicators/npu_communicator.py new file mode 100644 index 00000000..cfb89294 --- /dev/null +++ b/vllm_mindspore/distributed/device_communicators/npu_communicator.py @@ -0,0 +1,66 @@ +from mindspore import Tensor +from mindspore.communication import get_rank, get_group_size +import torch.distributed as dist + +from vllm.distributed.device_communicators.cuda_communicator import CudaCommunicator +from vllm.distributed.parallel_state import (get_dp_group, + get_tp_group, + in_the_same_node_as) +from vllm.forward_context import get_forward_context + + +class NPUCommunicator(CudaCommunicator): + def __init__(self, + cpu_group, + device = None, + device_group = None, + unique_name: str = ""): + super().__init__(cpu_group, device, device_group, unique_name) + + # all2all lives in ep group, which is merged from dp and tp group + self.dp_group = get_dp_group() + self.tp_group = get_tp_group() + # no self.ep_group since self.ep_group is still in construction + # when we create this object + self.dp_rank = self.dp_group.rank_in_group + self.dp_world_size = self.dp_group.world_size + self.rank = dist.get_rank(cpu_group) + self.world_size = dist.get_world_size(cpu_group) + + # all2all communication often has separate implementations for + # intra-node and inter-node communication + self.internode = not all(in_the_same_node_as(cpu_group, source_rank=0)) + def dispatch( + self, hidden_states: Tensor, + router_logits: Tensor) -> tuple[Tensor, Tensor]: + assert self.all2all_manager is not None + hidden_states, router_logits = self.all2all_manager.dispatch( + hidden_states, router_logits) + return hidden_states, router_logits + + def combine(self, hidden_states: Tensor) -> Tensor: + assert self.all2all_manager is not None + hidden_states = self.all2all_manager.combine(hidden_states) + return hidden_states + + def dispatch(self, hidden_states: Tensor, + router_logits: Tensor): + cu_tokens_across_dp_cpu = get_forward_context( + ).dp_metadata.cu_tokens_across_dp_cpu + + hidden_states = self.naive_multicast(hidden_states, + cu_tokens_across_dp_cpu) + router_logits = self.naive_multicast(router_logits, + cu_tokens_across_dp_cpu) + return hidden_states, router_logits + + def combine(self, hidden_states: Tensor) -> Tensor: + cu_tokens_across_dp_cpu = get_forward_context( + ).dp_metadata.cu_tokens_across_dp_cpu + start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[ + self.dp_rank - 1] + end = cu_tokens_across_dp_cpu[self.dp_rank] + + all_hidden_states = self.dp_group.all_reduce(hidden_states) + hidden_states = all_hidden_states[start:end, :] + return hidden_states \ No newline at end of file diff --git a/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py b/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py new file mode 100644 index 00000000..a7b3bf7d --- /dev/null +++ b/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py @@ -0,0 +1,173 @@ +from typing import Optional + +from mindspore import Tensor, mint, ops +from mindspore.ops.auto_generate import (GroupedMatmulV4, + FusedAddTopKDiv, + MoeInitRoutingV2, + MoeTokenUnpermute) +import mindspore as 
ms +from vllm.distributed.parallel_state import get_ep_group, get_dp_group + +def fused_topk( + hidden_states: Tensor, + gating_output: Tensor, + topk: int, + renormalize: bool, + indices_type = None, +) -> tuple[Tensor, Tensor]: + score = mint.softmax(gating_output, dim=-1) + topk_weights, topk_ids = mint.topk( + score, + k=topk, + dim=-1 + ) + if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + + if indices_type is not None: + topk_ids = topk_ids.to(indices_type) + return topk_weights, topk_ids + + +def grouped_topk( + hidden_states: Tensor, + gating_output: Tensor, + topk: int, + renormalize: bool, + num_expert_group: int = 0, + topk_group: int = 0, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[Tensor] = None +) -> tuple[Tensor, Tensor]: + fused_add_topk_div = FusedAddTopKDiv() + assert hidden_states.shape[0] == gating_output.shape[0], ( + "Number of tokens mismatch") + scoring_type = 0 # sigmoid + topk_in_group = 2 + topk_weights, topk_ids = fused_add_topk_div( + gating_output, + e_score_correction_bias, + num_expert_group, + topk_group, + topk, + topk_in_group, + scoring_type, + renormalize) + + return topk_weights, topk_ids + + +def fused_experts(hidden_states: Tensor, + w1: Tensor, + w2: Tensor, + topk_weights: Tensor, + topk_ids: Tensor, + activation: str = "silu", + global_num_experts: int = -1, + apply_router_weight_on_input: bool = False, + expert_map: Optional[Tensor] = None, + tp_size: int = 1, + ep_size: int = 0) -> Tensor: + + if tp_size >= 1: + # no ep, pure tp + if ep_size == 1: + hidden_states = _run_tp_moe(hidden_states, w1, w2, topk_ids, topk_weights, + activation, global_num_experts, + apply_router_weight_on_input) + # ep_size > 1 : pure ep or tp + ep + else: + # pure ep + if tp_size == 1: + hidden_states = _run_ep_moe(hidden_states, w1, w2, topk_ids, topk_weights, + activation, global_num_experts, + apply_router_weight_on_input) + # tp_size > 1 : tp + ep + else: + hidden_states = _run_tp_ep_moe(hidden_states, w1, w2, topk_ids, topk_weights, + activation, global_num_experts, + apply_router_weight_on_input) + + return hidden_states + +def _gate_activation(gate, activation): + if activation == "silu": + return mint.nn.functional.silu(gate) + elif activation == "gelu": + return mint.nn.functional.gelu(gate) + else: + raise ValueError(f"Unsupported activation function: {activation}") + + +group_matmul_ops = GroupedMatmulV4() +moe_init_routing_op = MoeInitRoutingV2() +moe_token_unpermute = MoeTokenUnpermute() + +def _group_matmul(hidden_states, weight, group_list): + return group_matmul_ops([hidden_states], [weight], + None, None, None, None, None, None, + group_list, split_item=3, group_type=0, group_list_type=1)[0] + +def _run_ep_moe(hidden_states, + w1, + w2, + topk_ids, + topk_weights, + activation, + global_num_experts, + apply_router_weight_on_input): + hidden_states = _group_matmul(hidden_states, w1, topk_ids) + hidden_states = _gate_activation(hidden_states, activation) + hidden_states = _group_matmul(hidden_states, w2, topk_ids) + return hidden_states + + +def _run_tp_moe(hidden_states, + w1, + w2, + topk_ids, + topk_weights, + activation, + global_num_experts, + apply_router_weight_on_input): + topk_weights = topk_weights.astype(hidden_states.dtype) + topk_ids = topk_ids.astype(ms.int32) + + sorted_input_tensor, unsort_map, group_list, _ = \ + moe_init_routing_op( + hidden_states, + topk_ids, + active_num=0, + expert_capacity=0, + expert_num=global_num_experts, + drop_pad_mode=0, + 
expert_tokens_count_or_cumsum_flag=2, + expert_tokens_before_capacity_flag=True) + + group_list = group_list.astype(ms.int64) + + gate_hidden_out = _group_matmul(sorted_input_tensor, mint.transpose(w1, -1, -2), group_list) + gate, hidden = mint.split(gate_hidden_out, + (w1.shape[1] // 2, w1.shape[1] // 2), -1) + gate = _gate_activation(gate, activation) + hidden = mint.mul(hidden, gate) + expert_output = _group_matmul(hidden, mint.transpose(w2, -1, -2), group_list) + expert_output = mint.nan_to_num(expert_output, 0, 0, 0) + moe_output = moe_token_unpermute(permuted_tokens=expert_output, + sorted_indices=unsort_map, + probs=topk_weights, + padded_mode=False, + restore_shape=None) + return moe_output + + +def _run_tp_ep_moe(hidden_states, + w1, + w2, + group_list, + group_logits, + activation, + global_num_experts, + apply_router_weight_on_input): + raise NotImplementedError( + "TP + EP MoE is not implemented yet. Please use pure TP or pure EP MoE instead.") -- Gitee From 2e805c4f92605baceb5c80d435c3ead3e51945bf Mon Sep 17 00:00:00 2001 From: lvhaoyu Date: Fri, 4 Jul 2025 18:04:48 +0800 Subject: [PATCH 62/77] Revert "update" This reverts commit 5efd293facd0dace19eff63bb133a88460f253a6. --- vllm_mindspore/model_executor/layers/fused_moe/layer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm_mindspore/model_executor/layers/fused_moe/layer.py b/vllm_mindspore/model_executor/layers/fused_moe/layer.py index 3f74b4ab..8e5f2e7d 100644 --- a/vllm_mindspore/model_executor/layers/fused_moe/layer.py +++ b/vllm_mindspore/model_executor/layers/fused_moe/layer.py @@ -24,9 +24,9 @@ from vllm.model_executor.layers.fused_moe.layer import (determine_expert_map, ) -# from vllm_mindspore.model_executor.layers.fused_moe.fused_moe import (fused_topk, - # grouped_topk, - # fused_experts) +from vllm_mindspore.model_executor.layers.fused_moe.fused_moe import (fused_topk, + grouped_topk, + fused_experts) from vllm_mindspore.model_executor.layers.fused_moe.fused_moe2 import FusedExperts from vllm_mindspore.model_executor.layers.quantization.base_config import QuantizeMethodBase from vllm_mindspore.distributed.communication_op import ReduceFromModelParallelRegion -- Gitee From aa7f6384ec43bf8e20cb5ca33101ed69a2551fb5 Mon Sep 17 00:00:00 2001 From: lvhaoyu Date: Fri, 4 Jul 2025 18:06:00 +0800 Subject: [PATCH 63/77] remove use code --- .../device_communicators/npu_communicator.py | 66 ------------------- 1 file changed, 66 deletions(-) delete mode 100644 vllm_mindspore/distributed/device_communicators/npu_communicator.py diff --git a/vllm_mindspore/distributed/device_communicators/npu_communicator.py b/vllm_mindspore/distributed/device_communicators/npu_communicator.py deleted file mode 100644 index cfb89294..00000000 --- a/vllm_mindspore/distributed/device_communicators/npu_communicator.py +++ /dev/null @@ -1,66 +0,0 @@ -from mindspore import Tensor -from mindspore.communication import get_rank, get_group_size -import torch.distributed as dist - -from vllm.distributed.device_communicators.cuda_communicator import CudaCommunicator -from vllm.distributed.parallel_state import (get_dp_group, - get_tp_group, - in_the_same_node_as) -from vllm.forward_context import get_forward_context - - -class NPUCommunicator(CudaCommunicator): - def __init__(self, - cpu_group, - device = None, - device_group = None, - unique_name: str = ""): - super().__init__(cpu_group, device, device_group, unique_name) - - # all2all lives in ep group, which is merged from dp and tp group - self.dp_group = get_dp_group() - 
self.tp_group = get_tp_group() - # no self.ep_group since self.ep_group is still in construction - # when we create this object - self.dp_rank = self.dp_group.rank_in_group - self.dp_world_size = self.dp_group.world_size - self.rank = dist.get_rank(cpu_group) - self.world_size = dist.get_world_size(cpu_group) - - # all2all communication often has separate implementations for - # intra-node and inter-node communication - self.internode = not all(in_the_same_node_as(cpu_group, source_rank=0)) - def dispatch( - self, hidden_states: Tensor, - router_logits: Tensor) -> tuple[Tensor, Tensor]: - assert self.all2all_manager is not None - hidden_states, router_logits = self.all2all_manager.dispatch( - hidden_states, router_logits) - return hidden_states, router_logits - - def combine(self, hidden_states: Tensor) -> Tensor: - assert self.all2all_manager is not None - hidden_states = self.all2all_manager.combine(hidden_states) - return hidden_states - - def dispatch(self, hidden_states: Tensor, - router_logits: Tensor): - cu_tokens_across_dp_cpu = get_forward_context( - ).dp_metadata.cu_tokens_across_dp_cpu - - hidden_states = self.naive_multicast(hidden_states, - cu_tokens_across_dp_cpu) - router_logits = self.naive_multicast(router_logits, - cu_tokens_across_dp_cpu) - return hidden_states, router_logits - - def combine(self, hidden_states: Tensor) -> Tensor: - cu_tokens_across_dp_cpu = get_forward_context( - ).dp_metadata.cu_tokens_across_dp_cpu - start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[ - self.dp_rank - 1] - end = cu_tokens_across_dp_cpu[self.dp_rank] - - all_hidden_states = self.dp_group.all_reduce(hidden_states) - hidden_states = all_hidden_states[start:end, :] - return hidden_states \ No newline at end of file -- Gitee From 95e382f47eca849fde1c9d78cdc8d5bf025fcd57 Mon Sep 17 00:00:00 2001 From: lvhaoyu Date: Fri, 4 Jul 2025 18:13:22 +0800 Subject: [PATCH 64/77] add ep.patch --- vllm_dp/ep.patch | 80 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 80 insertions(+) create mode 100644 vllm_dp/ep.patch diff --git a/vllm_dp/ep.patch b/vllm_dp/ep.patch new file mode 100644 index 00000000..9b7be259 --- /dev/null +++ b/vllm_dp/ep.patch @@ -0,0 +1,80 @@ +diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py +index fa493fefb..1426354ec 100644 +--- a/vllm/distributed/parallel_state.py ++++ b/vllm/distributed/parallel_state.py +@@ -761,6 +761,14 @@ def get_dp_group() -> GroupCoordinator: + return _DP + + ++_EP: Optional[GroupCoordinator] = None ++ ++ ++def get_ep_group() -> GroupCoordinator: ++ assert _EP is not None, ("expert parallel group is not initialized") ++ return _EP ++ ++ + def get_pp_group() -> GroupCoordinator: + assert _PP is not None, ( + "pipeline model parallel group is not initialized") +@@ -954,10 +962,21 @@ def initialize_model_parallel( + backend, + group_name="dp") + ++ global _EP ++ assert _EP is None, ("expert parallel group is already initialized") ++ group_ranks = all_ranks.transpose(1, 2).reshape( ++ -1, data_parallel_size * tensor_model_parallel_size).unbind(0) ++ group_ranks = [x.tolist() for x in group_ranks] ++ _EP = init_model_parallel_group(group_ranks, ++ get_world_group().local_rank, ++ backend, ++ group_name="ep") ++ + logger.info( + "rank %s in world size %s is assigned as " +- "DP rank %s, PP rank %s, TP rank %s", rank, world_size, +- _DP.rank_in_group, _PP.rank_in_group, _TP.rank_in_group) ++ "DP rank %s, PP rank %s, TP rank %s, EP rank %s", rank, world_size, ++ _DP.rank_in_group, _PP.rank_in_group, 
_TP.rank_in_group, ++ _EP.rank_in_group) + + + def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None: +@@ -1068,6 +1087,10 @@ def destroy_model_parallel(): + _DP.destroy() + _DP = None + ++ global _EP ++ if _EP: ++ _EP.destroy() ++ _EP = None + + def destroy_distributed_environment(): + global _WORLD +diff --git a/vllm/envs.py b/vllm/envs.py +index 6067f5bdd..1becf9e1b 100644 +--- a/vllm/envs.py ++++ b/vllm/envs.py +@@ -106,6 +106,8 @@ if TYPE_CHECKING: + VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION: bool = False + VLLM_TPU_BUCKET_PADDING_GAP: int = 0 + VLLM_USE_DEEP_GEMM: bool = False ++ VLLM_MOE_DP_CHUNK_SIZE: int = 256 ++ VLLM_ALL2ALL_BACKEND: str = "naive" + + + def get_default_cache_root(): +@@ -692,6 +694,12 @@ environment_variables: dict[str, Callable[[], Any]] = { + # Allow use of DeepGemm kernels for fused moe ops. + "VLLM_USE_DEEP_GEMM": + lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM", "0"))), ++ ++ "VLLM_ALL2ALL_BACKEND": ++ lambda: os.getenv("VLLM_ALL2ALL_BACKEND", "naive"), ++ ++ "VLLM_MOE_DP_CHUNK_SIZE": ++ lambda: int(os.getenv("VLLM_MOE_DP_CHUNK_SIZE", "256")), + } + + # end-env-vars-definition -- Gitee From 5487ff75e4eddf993b1e3b7a6f465f73c5bb60cc Mon Sep 17 00:00:00 2001 From: lvhaoyu Date: Mon, 7 Jul 2025 09:38:32 +0800 Subject: [PATCH 65/77] update --- .../device_communicators/__init__.py | 0 .../layers/fused_moe/fused_moe-bak.py | 173 ++++++++ .../layers/fused_moe/fused_moe.py | 416 ++++++++++++++---- .../layers/fused_moe/fused_moe2.py | 414 ----------------- .../model_executor/layers/fused_moe/layer.py | 8 +- 5 files changed, 505 insertions(+), 506 deletions(-) delete mode 100644 vllm_mindspore/distributed/device_communicators/__init__.py create mode 100644 vllm_mindspore/model_executor/layers/fused_moe/fused_moe-bak.py delete mode 100644 vllm_mindspore/model_executor/layers/fused_moe/fused_moe2.py diff --git a/vllm_mindspore/distributed/device_communicators/__init__.py b/vllm_mindspore/distributed/device_communicators/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/vllm_mindspore/model_executor/layers/fused_moe/fused_moe-bak.py b/vllm_mindspore/model_executor/layers/fused_moe/fused_moe-bak.py new file mode 100644 index 00000000..a7b3bf7d --- /dev/null +++ b/vllm_mindspore/model_executor/layers/fused_moe/fused_moe-bak.py @@ -0,0 +1,173 @@ +from typing import Optional + +from mindspore import Tensor, mint, ops +from mindspore.ops.auto_generate import (GroupedMatmulV4, + FusedAddTopKDiv, + MoeInitRoutingV2, + MoeTokenUnpermute) +import mindspore as ms +from vllm.distributed.parallel_state import get_ep_group, get_dp_group + +def fused_topk( + hidden_states: Tensor, + gating_output: Tensor, + topk: int, + renormalize: bool, + indices_type = None, +) -> tuple[Tensor, Tensor]: + score = mint.softmax(gating_output, dim=-1) + topk_weights, topk_ids = mint.topk( + score, + k=topk, + dim=-1 + ) + if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + + if indices_type is not None: + topk_ids = topk_ids.to(indices_type) + return topk_weights, topk_ids + + +def grouped_topk( + hidden_states: Tensor, + gating_output: Tensor, + topk: int, + renormalize: bool, + num_expert_group: int = 0, + topk_group: int = 0, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[Tensor] = None +) -> tuple[Tensor, Tensor]: + fused_add_topk_div = FusedAddTopKDiv() + assert hidden_states.shape[0] == gating_output.shape[0], ( + "Number of tokens mismatch") + scoring_type = 0 # sigmoid + topk_in_group = 2 + 
topk_weights, topk_ids = fused_add_topk_div( + gating_output, + e_score_correction_bias, + num_expert_group, + topk_group, + topk, + topk_in_group, + scoring_type, + renormalize) + + return topk_weights, topk_ids + + +def fused_experts(hidden_states: Tensor, + w1: Tensor, + w2: Tensor, + topk_weights: Tensor, + topk_ids: Tensor, + activation: str = "silu", + global_num_experts: int = -1, + apply_router_weight_on_input: bool = False, + expert_map: Optional[Tensor] = None, + tp_size: int = 1, + ep_size: int = 0) -> Tensor: + + if tp_size >= 1: + # no ep, pure tp + if ep_size == 1: + hidden_states = _run_tp_moe(hidden_states, w1, w2, topk_ids, topk_weights, + activation, global_num_experts, + apply_router_weight_on_input) + # ep_size > 1 : pure ep or tp + ep + else: + # pure ep + if tp_size == 1: + hidden_states = _run_ep_moe(hidden_states, w1, w2, topk_ids, topk_weights, + activation, global_num_experts, + apply_router_weight_on_input) + # tp_size > 1 : tp + ep + else: + hidden_states = _run_tp_ep_moe(hidden_states, w1, w2, topk_ids, topk_weights, + activation, global_num_experts, + apply_router_weight_on_input) + + return hidden_states + +def _gate_activation(gate, activation): + if activation == "silu": + return mint.nn.functional.silu(gate) + elif activation == "gelu": + return mint.nn.functional.gelu(gate) + else: + raise ValueError(f"Unsupported activation function: {activation}") + + +group_matmul_ops = GroupedMatmulV4() +moe_init_routing_op = MoeInitRoutingV2() +moe_token_unpermute = MoeTokenUnpermute() + +def _group_matmul(hidden_states, weight, group_list): + return group_matmul_ops([hidden_states], [weight], + None, None, None, None, None, None, + group_list, split_item=3, group_type=0, group_list_type=1)[0] + +def _run_ep_moe(hidden_states, + w1, + w2, + topk_ids, + topk_weights, + activation, + global_num_experts, + apply_router_weight_on_input): + hidden_states = _group_matmul(hidden_states, w1, topk_ids) + hidden_states = _gate_activation(hidden_states, activation) + hidden_states = _group_matmul(hidden_states, w2, topk_ids) + return hidden_states + + +def _run_tp_moe(hidden_states, + w1, + w2, + topk_ids, + topk_weights, + activation, + global_num_experts, + apply_router_weight_on_input): + topk_weights = topk_weights.astype(hidden_states.dtype) + topk_ids = topk_ids.astype(ms.int32) + + sorted_input_tensor, unsort_map, group_list, _ = \ + moe_init_routing_op( + hidden_states, + topk_ids, + active_num=0, + expert_capacity=0, + expert_num=global_num_experts, + drop_pad_mode=0, + expert_tokens_count_or_cumsum_flag=2, + expert_tokens_before_capacity_flag=True) + + group_list = group_list.astype(ms.int64) + + gate_hidden_out = _group_matmul(sorted_input_tensor, mint.transpose(w1, -1, -2), group_list) + gate, hidden = mint.split(gate_hidden_out, + (w1.shape[1] // 2, w1.shape[1] // 2), -1) + gate = _gate_activation(gate, activation) + hidden = mint.mul(hidden, gate) + expert_output = _group_matmul(hidden, mint.transpose(w2, -1, -2), group_list) + expert_output = mint.nan_to_num(expert_output, 0, 0, 0) + moe_output = moe_token_unpermute(permuted_tokens=expert_output, + sorted_indices=unsort_map, + probs=topk_weights, + padded_mode=False, + restore_shape=None) + return moe_output + + +def _run_tp_ep_moe(hidden_states, + w1, + w2, + group_list, + group_logits, + activation, + global_num_experts, + apply_router_weight_on_input): + raise NotImplementedError( + "TP + EP MoE is not implemented yet. 
Please use pure TP or pure EP MoE instead.") diff --git a/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py b/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py index a7b3bf7d..2e6285f6 100644 --- a/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py @@ -1,12 +1,17 @@ from typing import Optional -from mindspore import Tensor, mint, ops +import numpy as np +from mindspore import Tensor, mint, ops, nn from mindspore.ops.auto_generate import (GroupedMatmulV4, FusedAddTopKDiv, MoeInitRoutingV2, - MoeTokenUnpermute) + MoeTokenUnpermute, + MoeDistributeDispatch, + MoeDistributeCombine) import mindspore as ms -from vllm.distributed.parallel_state import get_ep_group, get_dp_group +from vllm.distributed.parallel_state import (get_ep_group, get_tp_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) +from vllm_mindspore.distributed.communication_op import ReduceFromModelParallelRegion def fused_topk( hidden_states: Tensor, @@ -57,7 +62,98 @@ def grouped_topk( return topk_weights, topk_ids -def fused_experts(hidden_states: Tensor, +class FusedExperts(nn.Cell): + def __init__(self, moe_config): + super().__init__() + self.group_matmul_ops = GroupedMatmulV4() + self.moe_init_routing_op = MoeInitRoutingV2() + self.moe_token_unpermute = MoeTokenUnpermute() + + self.pure_tp = False + self.pure_ep = False + + # pure ep mode + if moe_config.moe_parallel_config.ep_size > 1 and \ + moe_config.moe_parallel_config.tp_size == 1: + self.pure_ep = True + + # some configuration for tensor model parallel region + self.tp_rank = get_tensor_model_parallel_rank() + self.tp_world_size = get_tensor_model_parallel_world_size() + self.tp_group = get_tp_group().device_group._name + self.all_reduce_across_tp = ReduceFromModelParallelRegion() + self.broadcast_to_tensor_parallel_region = ops.Broadcast(0, group=self.tp_group) + + + # some configuration for expert parallel region + ep_size = moe_config.moe_parallel_config.ep_size + self.ep_size = ep_size + self.ep_group = get_ep_group().device_group._name + self.ep_rank = get_ep_group().rank_in_group + self.expert_num = moe_config.num_experts + + self.use_dispatch_kernels = moe_config.use_dispatch_kernels + + if self.use_dispatch_kernels: + # some configuration for dispatch and combine + self.dispatch = MoeDistributeDispatch() # only support in 910b and 910_A3 + self.combine = MoeDistributeCombine() # only support in 910b and 910_A3 + self.dispatch_tp_world_size = 0 + self.dispatch_shared_expert_num = 0 + self.max_bs = 256 * self.ep_size + + else: + # some configuration for experts + self.hidden_size = moe_config.hidden_dim + self.local_expert_num = moe_config.num_local_experts + + # some configuration for alltoall communication + experts_num = self.expert_num + experts_num_map = [(experts_num // ep_size) + for _ in range(ep_size - 1)] + experts_num_map.append(experts_num - ((experts_num // ep_size) * (ep_size - 1))) + # self.experts_num_map = ms.Tensor(expert_num_map, dtype=ms.int64) + experts_num_map_np = np.array(experts_num_map, dtype=np.int64) + experts_num_map_cu_np = np.cumsum(experts_num_map_np, dtype=np.int64) + self.experts_num_map_cu_index = ms.Tensor(experts_num_map_cu_np - 1, dtype=ms.int64) + + if self.tp_rank == 0: + self.send_experts_num_map = ms.Tensor(experts_num_map, dtype=ms.int64) + else: + self.send_experts_num_map = mint.zeros(ep_size, dtype=ms.int64) + + recv_num_map_list = [] + recv_list_index = [] + for i in range(self.ep_size): + if i 
% self.tp_world_size == 0: + recv_num_map_list.append(moe_config.num_local_experts) + recv_list_index.append(i) + else: + recv_num_map_list.append(0) + self.recv_experts_num_map = ms.Tensor(recv_num_map_list, dtype=ms.int64) + self.recv_list_index = ms.Tensor(recv_list_index, dtype=ms.int64) + + self.prepend_tensor = ms.Tensor([0], dtype=ms.int64) + + self.all_to_all_v_across_ep_with_block_size = ops.AlltoAllV(block_size=self.hidden_size, + group=self.ep_group) + self.all_to_all_v_across_ep = ops.AlltoAllV(group=self.ep_group) + self.even_list = [1 for _ in range(ep_size)] + + self.dummy_token = mint.zeros((1, self.hidden_size), dtype=moe_config.in_dtype) + + self.depend = ops.Depend() + + # pure tp mode + elif moe_config.moe_parallel_config.ep_size == 1 and \ + moe_config.moe_parallel_config.tp_size >= 1: + self.pure_tp = True + # tp + ep mode + else: + raise NotImplementedError("tp + ep mode not support yet.") + + def construct(self, + hidden_states: Tensor, w1: Tensor, w2: Tensor, topk_weights: Tensor, @@ -65,109 +161,251 @@ def fused_experts(hidden_states: Tensor, activation: str = "silu", global_num_experts: int = -1, apply_router_weight_on_input: bool = False, - expert_map: Optional[Tensor] = None, - tp_size: int = 1, - ep_size: int = 0) -> Tensor: - - if tp_size >= 1: - # no ep, pure tp - if ep_size == 1: - hidden_states = _run_tp_moe(hidden_states, w1, w2, topk_ids, topk_weights, + expert_map: Optional[Tensor] = None) -> Tensor: + + if self.pure_tp: + hidden_states = self.run_tp_moe(hidden_states, w1, w2, topk_ids, topk_weights, + activation, global_num_experts, + apply_router_weight_on_input) + # ep_size > 1 : pure ep or tp + ep + elif self.pure_ep: + # pure ep + hidden_states = self.run_ep_moe(hidden_states, w1, w2, topk_ids, topk_weights, activation, global_num_experts, apply_router_weight_on_input) - # ep_size > 1 : pure ep or tp + ep - else: - # pure ep - if tp_size == 1: - hidden_states = _run_ep_moe(hidden_states, w1, w2, topk_ids, topk_weights, - activation, global_num_experts, - apply_router_weight_on_input) # tp_size > 1 : tp + ep - else: - hidden_states = _run_tp_ep_moe(hidden_states, w1, w2, topk_ids, topk_weights, - activation, global_num_experts, - apply_router_weight_on_input) + else: + hidden_states = self.run_tp_ep_moe(hidden_states, w1, w2, topk_ids, topk_weights, + activation, global_num_experts, + apply_router_weight_on_input) - return hidden_states + return hidden_states -def _gate_activation(gate, activation): - if activation == "silu": - return mint.nn.functional.silu(gate) - elif activation == "gelu": - return mint.nn.functional.gelu(gate) - else: - raise ValueError(f"Unsupported activation function: {activation}") + def _gate_activation(self, gate, activation): + if activation == "silu": + return mint.nn.functional.silu(gate) + elif activation == "gelu": + return mint.nn.functional.gelu(gate) + else: + raise ValueError(f"Unsupported activation function: {activation}") -group_matmul_ops = GroupedMatmulV4() -moe_init_routing_op = MoeInitRoutingV2() -moe_token_unpermute = MoeTokenUnpermute() + def _group_matmul(self, hidden_states, weight, group_list): + return self.group_matmul_ops([hidden_states], [weight], + None, None, None, None, None, None, + group_list, split_item=3, group_type=0, group_list_type=1)[0] -def _group_matmul(hidden_states, weight, group_list): - return group_matmul_ops([hidden_states], [weight], - None, None, None, None, None, None, - group_list, split_item=3, group_type=0, group_list_type=1)[0] + def _ffn(self, hidden_state, w1, w2, 
group_list, activation): + gate_hidden_out = self._group_matmul(hidden_state, mint.transpose(w1, -1, -2), group_list) + gate, hidden = mint.split(gate_hidden_out, + (w1.shape[1] // 2, w1.shape[1] // 2), -1) + gate = self._gate_activation(gate, activation) + hidden = mint.mul(hidden, gate) + expert_output = self._group_matmul(hidden, mint.transpose(w2, -1, -2), group_list) + expert_output = mint.nan_to_num(expert_output, 0, 0, 0) + return expert_output -def _run_ep_moe(hidden_states, - w1, - w2, + def run_tp_moe(self, + hidden_states, + w1, + w2, + topk_ids, + topk_weights, + activation, + global_num_experts, + apply_router_weight_on_input): + topk_weights = topk_weights.astype(hidden_states.dtype) + topk_ids = topk_ids.astype(ms.int32) + + sorted_input_tensor, unsort_map, group_list, _ = \ + self.moe_init_routing_op( + hidden_states, topk_ids, - topk_weights, - activation, - global_num_experts, - apply_router_weight_on_input): - hidden_states = _group_matmul(hidden_states, w1, topk_ids) - hidden_states = _gate_activation(hidden_states, activation) - hidden_states = _group_matmul(hidden_states, w2, topk_ids) - return hidden_states + active_num=0, + expert_capacity=0, + expert_num=global_num_experts, + drop_pad_mode=0, + expert_tokens_count_or_cumsum_flag=2, + expert_tokens_before_capacity_flag=True) + + group_list = group_list.astype(ms.int64) + + expert_output = self._ffn(sorted_input_tensor, w1, w2, group_list, activation) + moe_output = self.moe_token_unpermute(permuted_tokens=expert_output, + sorted_indices=unsort_map, + probs=topk_weights, + padded_mode=False, + restore_shape=None) + return moe_output -def _run_tp_moe(hidden_states, + + def run_tp_ep_moe(self, + hidden_states, + w1, + w2, + group_list, + group_logits, + activation, + global_num_experts, + apply_router_weight_on_input): + raise NotImplementedError( + "TP + EP MoE is not implemented yet. 
Please use pure TP or pure EP MoE instead.") + + def run_ep_moe(self, + hidden_states, + w1, + w2, + topk_ids, + topk_weights, + activation, + global_num_experts, + apply_router_weight_on_input): + topk_weights = topk_weights.astype(hidden_states.dtype) + topk_ids = topk_ids.astype(ms.int32) + + if self.use_dispatch_kernels: + return self._ep_with_dispatch_combine( + hidden_states, w1, w2, topk_ids, topk_weights, activation, global_num_experts, - apply_router_weight_on_input): - topk_weights = topk_weights.astype(hidden_states.dtype) - topk_ids = topk_ids.astype(ms.int32) + apply_router_weight_on_input + ) - sorted_input_tensor, unsort_map, group_list, _ = \ - moe_init_routing_op( + return self._ep_with_all_to_all_v( hidden_states, + w1, + w2, topk_ids, - active_num=0, - expert_capacity=0, - expert_num=global_num_experts, - drop_pad_mode=0, - expert_tokens_count_or_cumsum_flag=2, - expert_tokens_before_capacity_flag=True) - - group_list = group_list.astype(ms.int64) - - gate_hidden_out = _group_matmul(sorted_input_tensor, mint.transpose(w1, -1, -2), group_list) - gate, hidden = mint.split(gate_hidden_out, - (w1.shape[1] // 2, w1.shape[1] // 2), -1) - gate = _gate_activation(gate, activation) - hidden = mint.mul(hidden, gate) - expert_output = _group_matmul(hidden, mint.transpose(w2, -1, -2), group_list) - expert_output = mint.nan_to_num(expert_output, 0, 0, 0) - moe_output = moe_token_unpermute(permuted_tokens=expert_output, - sorted_indices=unsort_map, - probs=topk_weights, - padded_mode=False, - restore_shape=None) - return moe_output - - -def _run_tp_ep_moe(hidden_states, - w1, - w2, - group_list, - group_logits, - activation, - global_num_experts, - apply_router_weight_on_input): - raise NotImplementedError( - "TP + EP MoE is not implemented yet. 
Please use pure TP or pure EP MoE instead.") + topk_weights, + activation, + global_num_experts, + apply_router_weight_on_input) + + def _ep_with_all_to_all_v(self, + hidden_states, + w1, + w2, + topk_ids, + topk_weights, + activation, + global_num_experts, + apply_router_weight_on_input): + + sorted_input_tensor, unsort_map, group_list, _ = \ + self.moe_init_routing_op( + hidden_states, + topk_ids, + active_num=0, + expert_capacity=0, + expert_num=global_num_experts, + drop_pad_mode=0, + expert_tokens_count_or_cumsum_flag=2, + expert_tokens_before_capacity_flag=True) + + # group_list = group_list.reshape(1, -1) + group_list = group_list.astype(ms.int64) + + local_group_list = self.all_to_all_v_across_ep(group_list, + self.send_experts_num_map, + self.recv_experts_num_map) + + local_group_list = local_group_list.reshape(-1, self.local_expert_num) + recv_list_value = local_group_list.sum(dim=-1) + recv_list = mint.zeros(self.ep_size, dtype=ms.int64) + recv_list[self.recv_list_index] = recv_list_value + + if self.tp_rank == 0: + group_list_cumsum = mint.cumsum(group_list, 0, dtype=ms.int64) + # expert index = [3, 7, 11, 15] (self.ep_group_size,) + # 看下每个rank, 发送多少tensor 数据给其他的rank + send_list = group_list_cumsum[self.experts_num_map_cu_index] # [20, 30, 40, 50] + send_list = mint.diff(send_list, prepend=self.prepend_tensor) + else: + send_list = mint.zeros(self.ep_size, dtype=ms.int64) # [0, 0, 0, 0] + + # recv_list = self.all_to_all_v_across_ep(send_list, self.even_list, self.even_list) + # recv_list [20, 40, 60, 70] + local_input_tensor = self.all_to_all_v_across_ep_with_block_size(sorted_input_tensor.reshape(-1), + send_list, + recv_list) + + topk_ids_1d, _ = mint.sort(topk_ids.reshape(-1)) + topk_ids_local = self.all_to_all_v_across_ep(topk_ids_1d, send_list, recv_list) + + local_group_list = local_group_list.sum(dim=0) + recv_tokens = recv_list.sum() + if recv_tokens > 0: + _, resort_index = mint.sort(topk_ids_local) + _, unresort_index = mint.sort(resort_index) + + local_input_tensor = local_input_tensor.reshape(-1, self.hidden_size) + local_input_tensor = mint.index_select(local_input_tensor, 0, resort_index) + + expert_output = self.ffn(local_input_tensor, w1, w2, local_group_list, activation) + + expert_output = mint.index_select(expert_output, 0, unresort_index) + else: + expert_output = self.dummy_token + expert_output = self.all_to_all_v_across_ep_with_block_size(expert_output.reshape(-1), + recv_list, + send_list) + if self.tp_rank == 0: + expert_output = expert_output.reshape(-1, self.hidden_size) + moe_output = self.moe_token_unpermute(permuted_tokens=expert_output, + sorted_indices=unsort_map, + probs=topk_weights, + padded_mode=False, + restore_shape=None) + else: + # hidden_states = self.depend(hidden_states, expert_output) + # moe_output = hidden_states + moe_output = mint.zeros_like(hidden_states) + moe_output = self.depend(moe_output, expert_output) + + # if self.tp_world_size > 0: + # moe_output = self.broadcast_to_tensor_parallel_region((moe_output,))[0] + moe_output = self.all_reduce_across_tp(moe_output) + return moe_output + + def _ep_with_dispatch_combine(self, hidden_states, w1, w2, topk_ids, topk_weights, + activation, global_num_experts, + apply_router_weight_on_input): + """fused ops, moe feed forward with dispatch and combine.""" + # Dispatch + expand_x, _, expand_idx, expert_token_nums, ep_recv_counts, tp_recv_counts, _ = self.dispatch( + x=hidden_states, + expert_ids=topk_ids, + ep_world_size=self.ep_size, + ep_rank_id=self.ep_rank, + 
moe_expert_num=global_num_experts, + group_ep=self.ep_group, + tp_world_size=self.tp_world_size, + shared_expert_num=self.dispatch_shared_expert_num, + global_bs=self.max_bs, + expert_token_nums_type=1) + + # GroupMamtul + ffn_res = self._ffn(expand_x, w1, w2, expert_token_nums, activation) + + # Combine + moe_output = self.combine( + expand_x=ffn_res, + expert_ids=topk_ids, + expand_idx=expand_idx, + ep_send_counts=ep_recv_counts, + expert_scales=topk_weights, + ep_world_size=self.ep_size, + ep_rank_id=self.ep_rank, + moe_expert_num=global_num_experts, + tp_send_counts=tp_recv_counts, + group_ep=self.ep_group, + tp_world_size=self.dispatch_tp_world_size, + shared_expert_num=self.dispatch_shared_expert_num, + global_bs=self.max_bs) + + return moe_output diff --git a/vllm_mindspore/model_executor/layers/fused_moe/fused_moe2.py b/vllm_mindspore/model_executor/layers/fused_moe/fused_moe2.py deleted file mode 100644 index 1666172b..00000000 --- a/vllm_mindspore/model_executor/layers/fused_moe/fused_moe2.py +++ /dev/null @@ -1,414 +0,0 @@ -from typing import Optional - -import numpy as np -from mindspore import Tensor, mint, ops, nn -from mindspore.ops.auto_generate import (GroupedMatmulV4, - FusedAddTopKDiv, - MoeInitRoutingV2, - MoeTokenUnpermute, - MoeDistributeDispatch, - MoeDistributeCombine) -import mindspore as ms -from vllm.distributed.parallel_state import (get_ep_group, get_tp_group, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) -from vllm_mindspore.distributed.communication_op import ReduceFromModelParallelRegion - -def fused_topk( - hidden_states: Tensor, - gating_output: Tensor, - topk: int, - renormalize: bool, - indices_type = None, -) -> tuple[Tensor, Tensor]: - assert hidden_states.shape[0] == gating_output.shape[0], ( - "Number of tokens mismatch") - score = mint.softmax(gating_output, dim=-1) - topk_weights, topk_ids = mint.topk( - score, - k=topk, - dim=-1 - ) - if renormalize: - topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) - - if indices_type is not None: - topk_ids = topk_ids.to(indices_type) - return topk_weights, topk_ids - - -def grouped_topk( - hidden_states: Tensor, - gating_output: Tensor, - topk: int, - renormalize: bool, - num_expert_group: int = 0, - topk_group: int = 0, - scoring_func: str = "softmax", - e_score_correction_bias: Optional[Tensor] = None -) -> tuple[Tensor, Tensor]: - fused_add_topk_div = FusedAddTopKDiv() - assert hidden_states.shape[0] == gating_output.shape[0], ( - "Number of tokens mismatch") - scoring_type = 0 # sigmoid - topk_in_group = 2 - topk_weights, topk_ids = fused_add_topk_div( - gating_output, - e_score_correction_bias, - num_expert_group, - topk_group, - topk, - topk_in_group, - scoring_type, - renormalize) - - return topk_weights, topk_ids - - -class FusedExperts(nn.Cell): - def __init__(self, moe_config): - super().__init__() - self.group_matmul_ops = GroupedMatmulV4() - self.moe_init_routing_op = MoeInitRoutingV2() - self.moe_token_unpermute = MoeTokenUnpermute() - - self.pure_tp = False - self.pure_ep = False - - # pure ep mode - if moe_config.moe_parallel_config.ep_size > 1 and \ - moe_config.moe_parallel_config.tp_size == 1: - self.pure_ep = True - - # some configuration for tensor model parallel region - self.tp_rank = get_tensor_model_parallel_rank() - self.tp_world_size = get_tensor_model_parallel_world_size() - self.tp_group = get_tp_group().device_group._name - self.all_reduce_across_tp = ReduceFromModelParallelRegion() - self.broadcast_to_tensor_parallel_region = 
ops.Broadcast(0, group=self.tp_group) - - - # some configuration for expert parallel region - ep_size = moe_config.moe_parallel_config.ep_size - self.ep_size = ep_size - self.ep_group = get_ep_group().device_group._name - self.ep_rank = get_ep_group().rank_in_group - self.expert_num = moe_config.num_experts - - self.use_dispatch_kernels = moe_config.use_dispatch_kernels - - if self.use_dispatch_kernels: - # some configuration for dispatch and combine - self.dispatch = MoeDistributeDispatch() # only support in 910b and 910_A3 - self.combine = MoeDistributeCombine() # only support in 910b and 910_A3 - self.dispatch_tp_world_size = 0 - self.dispatch_shared_expert_num = 0 - self.max_bs = 256 * self.ep_size - - else: - # some configuration for experts - self.hidden_size = moe_config.hidden_dim - self.local_expert_num = moe_config.num_local_experts - - # some configuration for alltoall communication - experts_num = self.expert_num - experts_num_map = [(experts_num // ep_size) - for _ in range(ep_size - 1)] - experts_num_map.append(experts_num - ((experts_num // ep_size) * (ep_size - 1))) - # self.experts_num_map = ms.Tensor(expert_num_map, dtype=ms.int64) - experts_num_map_np = np.array(experts_num_map, dtype=np.int64) - experts_num_map_cu_np = np.cumsum(experts_num_map_np, dtype=np.int64) - self.experts_num_map_cu_index = ms.Tensor(experts_num_map_cu_np - 1, dtype=ms.int64) - - if self.tp_rank == 0: - self.send_experts_num_map = ms.Tensor(experts_num_map, dtype=ms.int64) - else: - self.send_experts_num_map = mint.zeros(ep_size, dtype=ms.int64) - - recv_num_map_list = [] - recv_list_index = [] - for i in range(self.ep_size): - if i % self.tp_world_size == 0: - recv_num_map_list.append(moe_config.num_local_experts) - recv_list_index.append(i) - else: - recv_num_map_list.append(0) - self.recv_experts_num_map = ms.Tensor(recv_num_map_list, dtype=ms.int64) - self.recv_list_index = ms.Tensor(recv_list_index, dtype=ms.int64) - - self.prepend_tensor = ms.Tensor([0], dtype=ms.int64) - - self.all_to_all_v_across_ep_with_block_size = ops.AlltoAllV(block_size=self.hidden_size, - group=self.ep_group) - self.all_to_all_v_across_ep = ops.AlltoAllV(group=self.ep_group) - self.even_list = [1 for _ in range(ep_size)] - - self.dummy_token = mint.zeros((1, self.hidden_size), dtype=moe_config.in_dtype) - - self.depend = ops.Depend() - - # pure tp mode - elif moe_config.moe_parallel_config.ep_size == 1 and \ - moe_config.moe_parallel_config.tp_size >= 1: - self.pure_tp = True - # tp + ep mode - else: - raise NotImplementedError("tp + ep mode not support yet.") - - def construct(self, - hidden_states: Tensor, - w1: Tensor, - w2: Tensor, - topk_weights: Tensor, - topk_ids: Tensor, - activation: str = "silu", - global_num_experts: int = -1, - apply_router_weight_on_input: bool = False, - expert_map: Optional[Tensor] = None) -> Tensor: - - if self.pure_tp: - hidden_states = self.run_tp_moe(hidden_states, w1, w2, topk_ids, topk_weights, - activation, global_num_experts, - apply_router_weight_on_input) - # ep_size > 1 : pure ep or tp + ep - elif self.pure_ep: - # pure ep - hidden_states = self.run_ep_moe(hidden_states, w1, w2, topk_ids, topk_weights, - activation, global_num_experts, - apply_router_weight_on_input) - # tp_size > 1 : tp + ep - else: - hidden_states = self.run_tp_ep_moe(hidden_states, w1, w2, topk_ids, topk_weights, - activation, global_num_experts, - apply_router_weight_on_input) - - return hidden_states - - - def _gate_activation(self, gate, activation): - if activation == "silu": - return 
mint.nn.functional.silu(gate) - elif activation == "gelu": - return mint.nn.functional.gelu(gate) - else: - raise ValueError(f"Unsupported activation function: {activation}") - - def _group_matmul(self, hidden_states, weight, group_list): - return self.group_matmul_ops([hidden_states], [weight], - None, None, None, None, None, None, - group_list, split_item=3, group_type=0, group_list_type=1)[0] - - def _ffn(self, hidden_state, w1, w2, group_list, activation): - gate_hidden_out = self._group_matmul(hidden_state, mint.transpose(w1, -1, -2), group_list) - gate, hidden = mint.split(gate_hidden_out, - (w1.shape[1] // 2, w1.shape[1] // 2), -1) - gate = self._gate_activation(gate, activation) - hidden = mint.mul(hidden, gate) - expert_output = self._group_matmul(hidden, mint.transpose(w2, -1, -2), group_list) - expert_output = mint.nan_to_num(expert_output, 0, 0, 0) - return expert_output - - def run_tp_moe(self, - hidden_states, - w1, - w2, - topk_ids, - topk_weights, - activation, - global_num_experts, - apply_router_weight_on_input): - topk_weights = topk_weights.astype(hidden_states.dtype) - topk_ids = topk_ids.astype(ms.int32) - - sorted_input_tensor, unsort_map, group_list, _ = \ - self.moe_init_routing_op( - hidden_states, - topk_ids, - active_num=0, - expert_capacity=0, - expert_num=global_num_experts, - drop_pad_mode=0, - expert_tokens_count_or_cumsum_flag=2, - expert_tokens_before_capacity_flag=True) - - group_list = group_list.astype(ms.int64) - - expert_output = self._ffn(sorted_input_tensor, w1, w2, group_list, activation) - - moe_output = self.moe_token_unpermute(permuted_tokens=expert_output, - sorted_indices=unsort_map, - probs=topk_weights, - padded_mode=False, - restore_shape=None) - return moe_output - - - def run_tp_ep_moe( - self, - hidden_states, - w1, - w2, - group_list, - group_logits, - activation, - global_num_experts, - apply_router_weight_on_input): - raise NotImplementedError( - "TP + EP MoE is not implemented yet. 
Please use pure TP or pure EP MoE instead.") - - def run_ep_moe(self, - hidden_states, - w1, - w2, - topk_ids, - topk_weights, - activation, - global_num_experts, - apply_router_weight_on_input): - topk_weights = topk_weights.astype(hidden_states.dtype) - topk_ids = topk_ids.astype(ms.int32) - - if self.use_dispatch_kernels: - return self._ep_with_dispatch_combine( - hidden_states, - w1, - w2, - topk_ids, - topk_weights, - activation, - global_num_experts, - apply_router_weight_on_input - ) - - return self._ep_with_all_to_all_v( - hidden_states, - w1, - w2, - topk_ids, - topk_weights, - activation, - global_num_experts, - apply_router_weight_on_input) - - def _ep_with_all_to_all_v(self, - hidden_states, - w1, - w2, - topk_ids, - topk_weights, - activation, - global_num_experts, - apply_router_weight_on_input): - - sorted_input_tensor, unsort_map, group_list, _ = \ - self.moe_init_routing_op( - hidden_states, - topk_ids, - active_num=0, - expert_capacity=0, - expert_num=global_num_experts, - drop_pad_mode=0, - expert_tokens_count_or_cumsum_flag=2, - expert_tokens_before_capacity_flag=True) - - # group_list = group_list.reshape(1, -1) - group_list = group_list.astype(ms.int64) - - local_group_list = self.all_to_all_v_across_ep(group_list, - self.send_experts_num_map, - self.recv_experts_num_map) - - local_group_list = local_group_list.reshape(-1, self.local_expert_num) - recv_list_value = local_group_list.sum(dim=-1) - recv_list = mint.zeros(self.ep_size, dtype=ms.int64) - recv_list[self.recv_list_index] = recv_list_value - - if self.tp_rank == 0: - group_list_cumsum = mint.cumsum(group_list, 0, dtype=ms.int64) - # expert index = [3, 7, 11, 15] (self.ep_group_size,) - # 看下每个rank, 发送多少tensor 数据给其他的rank - send_list = group_list_cumsum[self.experts_num_map_cu_index] # [20, 30, 40, 50] - send_list = mint.diff(send_list, prepend=self.prepend_tensor) - else: - send_list = mint.zeros(self.ep_size, dtype=ms.int64) # [0, 0, 0, 0] - - # recv_list = self.all_to_all_v_across_ep(send_list, self.even_list, self.even_list) - # recv_list [20, 40, 60, 70] - local_input_tensor = self.all_to_all_v_across_ep_with_block_size(sorted_input_tensor.reshape(-1), - send_list, - recv_list) - - topk_ids_1d, _ = mint.sort(topk_ids.reshape(-1)) - topk_ids_local = self.all_to_all_v_across_ep(topk_ids_1d, send_list, recv_list) - - local_group_list = local_group_list.sum(dim=0) - recv_tokens = recv_list.sum() - if recv_tokens > 0: - _, resort_index = mint.sort(topk_ids_local) - _, unresort_index = mint.sort(resort_index) - - local_input_tensor = local_input_tensor.reshape(-1, self.hidden_size) - local_input_tensor = mint.index_select(local_input_tensor, 0, resort_index) - - expert_output = self.ffn(local_input_tensor, w1, w2, local_group_list, activation) - - expert_output = mint.index_select(expert_output, 0, unresort_index) - else: - expert_output = self.dummy_token - expert_output = self.all_to_all_v_across_ep_with_block_size(expert_output.reshape(-1), - recv_list, - send_list) - if self.tp_rank == 0: - expert_output = expert_output.reshape(-1, self.hidden_size) - moe_output = self.moe_token_unpermute(permuted_tokens=expert_output, - sorted_indices=unsort_map, - probs=topk_weights, - padded_mode=False, - restore_shape=None) - else: - # hidden_states = self.depend(hidden_states, expert_output) - # moe_output = hidden_states - moe_output = mint.zeros_like(hidden_states) - moe_output = self.depend(moe_output, expert_output) - - # if self.tp_world_size > 0: - # moe_output = 
self.broadcast_to_tensor_parallel_region((moe_output,))[0] - moe_output = self.all_reduce_across_tp(moe_output) - return moe_output - - def _ep_with_dispatch_combine(self, hidden_states, w1, w2, topk_ids, topk_weights, - activation, global_num_experts, - apply_router_weight_on_input): - """fused ops, moe feed forward with dispatch and combine.""" - # Dispatch - expand_x, _, expand_idx, expert_token_nums, ep_recv_counts, tp_recv_counts, _ = self.dispatch( - x=hidden_states, - expert_ids=topk_ids, - ep_world_size=self.ep_size, - ep_rank_id=self.ep_rank, - moe_expert_num=global_num_experts, - group_ep=self.ep_group, - tp_world_size=self.tp_world_size, - shared_expert_num=self.dispatch_shared_expert_num, - global_bs=self.max_bs, - expert_token_nums_type=1) - - # GroupMamtul - ffn_res = self._ffn(expand_x, w1, w2, expert_token_nums, activation) - - # Combine - moe_output = self.combine( - expand_x=ffn_res, - expert_ids=topk_ids, - expand_idx=expand_idx, - ep_send_counts=ep_recv_counts, - expert_scales=topk_weights, - ep_world_size=self.ep_size, - ep_rank_id=self.ep_rank, - moe_expert_num=global_num_experts, - tp_send_counts=tp_recv_counts, - group_ep=self.ep_group, - tp_world_size=self.dispatch_tp_world_size, - shared_expert_num=self.dispatch_shared_expert_num, - global_bs=self.max_bs) - - return moe_output diff --git a/vllm_mindspore/model_executor/layers/fused_moe/layer.py b/vllm_mindspore/model_executor/layers/fused_moe/layer.py index 8e5f2e7d..1e228219 100644 --- a/vllm_mindspore/model_executor/layers/fused_moe/layer.py +++ b/vllm_mindspore/model_executor/layers/fused_moe/layer.py @@ -24,10 +24,12 @@ from vllm.model_executor.layers.fused_moe.layer import (determine_expert_map, ) +# from vllm_mindspore.model_executor.layers.fused_moe.fused_moe import (fused_topk, +# grouped_topk, +# fused_experts) from vllm_mindspore.model_executor.layers.fused_moe.fused_moe import (fused_topk, - grouped_topk, - fused_experts) -from vllm_mindspore.model_executor.layers.fused_moe.fused_moe2 import FusedExperts + grouped_topk, + FusedExperts) from vllm_mindspore.model_executor.layers.quantization.base_config import QuantizeMethodBase from vllm_mindspore.distributed.communication_op import ReduceFromModelParallelRegion -- Gitee From 9a230bc9a80e1d90d62740fa3b7a9fd0c2277acb Mon Sep 17 00:00:00 2001 From: lvhaoyu Date: Mon, 7 Jul 2025 09:40:53 +0800 Subject: [PATCH 66/77] update --- vllm_mindspore/platforms/ascend.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm_mindspore/platforms/ascend.py b/vllm_mindspore/platforms/ascend.py index c61e3978..43d5d177 100644 --- a/vllm_mindspore/platforms/ascend.py +++ b/vllm_mindspore/platforms/ascend.py @@ -137,7 +137,6 @@ class AscendPlatform(Platform): """Get device specific communicator class for distributed communication.""" if envs.VLLM_USE_V1: return "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator" - # return "vllm_mindspore.distributed.device_communicators.npu_communicator.NPUCommunicator" return "vllm.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase" @classmethod -- Gitee From a39b79c14e988e44d21dbb01741e4810d3d5a255 Mon Sep 17 00:00:00 2001 From: lvhaoyu Date: Mon, 7 Jul 2025 09:51:49 +0800 Subject: [PATCH 67/77] update --- vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py b/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py index 2e6285f6..c5894b29 100644 --- 
From 9a230bc9a80e1d90d62740fa3b7a9fd0c2277acb Mon Sep 17 00:00:00 2001
From: lvhaoyu
Date: Mon, 7 Jul 2025 09:40:53 +0800
Subject: [PATCH 66/77] update

---
 vllm_mindspore/platforms/ascend.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/vllm_mindspore/platforms/ascend.py b/vllm_mindspore/platforms/ascend.py
index c61e3978..43d5d177 100644
--- a/vllm_mindspore/platforms/ascend.py
+++ b/vllm_mindspore/platforms/ascend.py
@@ -137,7 +137,6 @@ class AscendPlatform(Platform):
         """Get device specific communicator class for distributed communication."""
         if envs.VLLM_USE_V1:
             return "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator"
-            # return "vllm_mindspore.distributed.device_communicators.npu_communicator.NPUCommunicator"
         return "vllm.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase"
 
     @classmethod
--
Gitee

From a39b79c14e988e44d21dbb01741e4810d3d5a255 Mon Sep 17 00:00:00 2001
From: lvhaoyu
Date: Mon, 7 Jul 2025 09:51:49 +0800
Subject: [PATCH 67/77] update

---
 vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py b/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py
index 2e6285f6..c5894b29 100644
--- a/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py
@@ -346,7 +346,7 @@ class FusedExperts(nn.Cell):
             local_input_tensor = local_input_tensor.reshape(-1, self.hidden_size)
             local_input_tensor = mint.index_select(local_input_tensor, 0, resort_index)
 
-            expert_output = self.ffn(local_input_tensor, w1, w2, local_group_list, activation)
+            expert_output = self._ffn(local_input_tensor, w1, w2, local_group_list, activation)
 
             expert_output = mint.index_select(expert_output, 0, unresort_index)
         else:
--
Gitee

From ac65e4cf9c3965b97b80848ffe51f9973ec4954b Mon Sep 17 00:00:00 2001
From: lvhaoyu
Date: Mon, 7 Jul 2025 09:54:53 +0800
Subject: [PATCH 68/77] transpose w1 and w2 in init

---
 .../layers/fused_moe/fused_moe.py            |  6 ++++--
 .../model_executor/layers/fused_moe/layer.py | 18 ++++++++++++++++--
 2 files changed, 20 insertions(+), 4 deletions(-)

diff --git a/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py b/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py
index c5894b29..7ed4224c 100644
--- a/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py
@@ -196,12 +196,14 @@ class FusedExperts(nn.Cell):
                                        group_list, split_item=3,
                                        group_type=0, group_list_type=1)[0]
 
     def _ffn(self, hidden_state, w1, w2, group_list, activation):
-        gate_hidden_out = self._group_matmul(hidden_state, mint.transpose(w1, -1, -2), group_list)
+        # gate_hidden_out = self._group_matmul(hidden_state, mint.transpose(w1, -1, -2), group_list)
+        gate_hidden_out = self._group_matmul(hidden_state, w1, group_list)
         gate, hidden = mint.split(gate_hidden_out,
                                   (w1.shape[1] // 2, w1.shape[1] // 2), -1)
         gate = self._gate_activation(gate, activation)
         hidden = mint.mul(hidden, gate)
-        expert_output = self._group_matmul(hidden, mint.transpose(w2, -1, -2), group_list)
+        # expert_output = self._group_matmul(hidden, mint.transpose(w2, -1, -2), group_list)
+        expert_output = self._group_matmul(hidden, w2, group_list)
         expert_output = mint.nan_to_num(expert_output, 0, 0, 0)
         return expert_output
 
diff --git a/vllm_mindspore/model_executor/layers/fused_moe/layer.py b/vllm_mindspore/model_executor/layers/fused_moe/layer.py
index 1e228219..07c8e2d5 100644
--- a/vllm_mindspore/model_executor/layers/fused_moe/layer.py
+++ b/vllm_mindspore/model_executor/layers/fused_moe/layer.py
@@ -261,24 +261,38 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, nn.Cell):
                        hidden_size: int,
                        intermediate_size_per_partition: int, params_dtype,
                        **extra_weight_attrs):
         # Fused gate_up_proj (column parallel)
+        # w13_weight = Parameter(mint.empty(
+        #     num_experts,
+        #     2 * intermediate_size_per_partition,
+        #     hidden_size,
+        #     dtype=params_dtype),
+        #     requires_grad=False)
         w13_weight = Parameter(mint.empty(
             num_experts,
-            2 * intermediate_size_per_partition,
             hidden_size,
+            2 * intermediate_size_per_partition,
             dtype=params_dtype),
                                requires_grad=False)
         layer.insert_param_to_cell("w13_weight", w13_weight)
         set_weight_attrs(w13_weight, extra_weight_attrs)
+        set_weight_attrs(w13_weight, {"is_transposed": True})
 
         # down_proj (row parallel)
+        # w2_weight = Parameter(mint.empty(
+        #     num_experts,
+        #     hidden_size,
+        #     intermediate_size_per_partition,
+        #     dtype=params_dtype),
+        #     requires_grad=False)
         w2_weight = Parameter(mint.empty(
             num_experts,
-            hidden_size,
             intermediate_size_per_partition,
+            hidden_size,
             dtype=params_dtype),
                               requires_grad=False)
         layer.insert_param_to_cell("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)
+        set_weight_attrs(w2_weight, {"is_transposed": True})
 
     def apply(
         self,
--
Gitee

From 6dfcc95d901315e4947a8deecd6f5dc82e52e145 Mon Sep 17 00:00:00 2001
From: lvhaoyu
Date: Mon, 7 Jul 2025 10:04:51 +0800
Subject: [PATCH 69/77] Revert "transpose w1 and w2 in init"

This reverts commit ac65e4cf9c3965b97b80848ffe51f9973ec4954b.
---
 .../layers/fused_moe/fused_moe.py            |  6 ++----
 .../model_executor/layers/fused_moe/layer.py | 18 ++----------------
 2 files changed, 4 insertions(+), 20 deletions(-)

diff --git a/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py b/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py
index 7ed4224c..c5894b29 100644
--- a/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py
@@ -196,14 +196,12 @@ class FusedExperts(nn.Cell):
                                        group_list, split_item=3,
                                        group_type=0, group_list_type=1)[0]
 
     def _ffn(self, hidden_state, w1, w2, group_list, activation):
-        # gate_hidden_out = self._group_matmul(hidden_state, mint.transpose(w1, -1, -2), group_list)
-        gate_hidden_out = self._group_matmul(hidden_state, w1, group_list)
+        gate_hidden_out = self._group_matmul(hidden_state, mint.transpose(w1, -1, -2), group_list)
         gate, hidden = mint.split(gate_hidden_out,
                                   (w1.shape[1] // 2, w1.shape[1] // 2), -1)
         gate = self._gate_activation(gate, activation)
         hidden = mint.mul(hidden, gate)
-        # expert_output = self._group_matmul(hidden, mint.transpose(w2, -1, -2), group_list)
-        expert_output = self._group_matmul(hidden, w2, group_list)
+        expert_output = self._group_matmul(hidden, mint.transpose(w2, -1, -2), group_list)
         expert_output = mint.nan_to_num(expert_output, 0, 0, 0)
         return expert_output
 
diff --git a/vllm_mindspore/model_executor/layers/fused_moe/layer.py b/vllm_mindspore/model_executor/layers/fused_moe/layer.py
index 07c8e2d5..1e228219 100644
--- a/vllm_mindspore/model_executor/layers/fused_moe/layer.py
+++ b/vllm_mindspore/model_executor/layers/fused_moe/layer.py
@@ -261,38 +261,24 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, nn.Cell):
                        hidden_size: int,
                        intermediate_size_per_partition: int, params_dtype,
                        **extra_weight_attrs):
         # Fused gate_up_proj (column parallel)
-        # w13_weight = Parameter(mint.empty(
-        #     num_experts,
-        #     2 * intermediate_size_per_partition,
-        #     hidden_size,
-        #     dtype=params_dtype),
-        #     requires_grad=False)
         w13_weight = Parameter(mint.empty(
             num_experts,
-            hidden_size,
             2 * intermediate_size_per_partition,
+            hidden_size,
             dtype=params_dtype),
                                requires_grad=False)
         layer.insert_param_to_cell("w13_weight", w13_weight)
         set_weight_attrs(w13_weight, extra_weight_attrs)
-        set_weight_attrs(w13_weight, {"is_transposed": True})
 
         # down_proj (row parallel)
-        # w2_weight = Parameter(mint.empty(
-        #     num_experts,
-        #     hidden_size,
-        #     intermediate_size_per_partition,
-        #     dtype=params_dtype),
-        #     requires_grad=False)
         w2_weight = Parameter(mint.empty(
             num_experts,
-            intermediate_size_per_partition,
             hidden_size,
+            intermediate_size_per_partition,
             dtype=params_dtype),
                               requires_grad=False)
         layer.insert_param_to_cell("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)
-        set_weight_attrs(w2_weight, {"is_transposed": True})
 
     def apply(
         self,
--
Gitee
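NOTE: patches 68/69 first stored w13/w2 pre-transposed (so `_ffn` could skip the
`mint.transpose` calls) and then reverted to the original layout. A minimal NumPy sketch
of the gated expert FFN that `_ffn` computes -- gate/up split, gate activation, down
projection -- is shown below for a single expert. The SiLU gate and the shapes are
assumptions inferred from the hunks, not the repository's exact code.

    import numpy as np

    def silu(x):
        return x / (1.0 + np.exp(-x))

    def expert_ffn(tokens, w1, w2):
        # tokens: (n, hidden); w1: (2*inter, hidden); w2: (hidden, inter)
        gate_up = tokens @ w1.T            # mirrors _group_matmul(x, transpose(w1, -1, -2))
        gate, up = np.split(gate_up, 2, axis=-1)
        hidden = silu(gate) * up           # gated activation, as in _ffn
        return hidden @ w2.T               # down projection back to hidden size

    rng = np.random.default_rng(0)
    n, hidden, inter = 4, 8, 16
    out = expert_ffn(rng.standard_normal((n, hidden)),
                     rng.standard_normal((2 * inter, hidden)),
                     rng.standard_normal((hidden, inter)))
    assert out.shape == (n, hidden)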
From 23df86c9863c762f70af58fa6fcf4aca53aaf759 Mon Sep 17 00:00:00 2001
From: lvhaoyu
Date: Mon, 7 Jul 2025 21:25:30 +0800
Subject: [PATCH 70/77] update depend

---
 vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py b/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py
index c5894b29..30e5a0e8 100644
--- a/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py
@@ -330,11 +330,13 @@ class FusedExperts(nn.Cell):
             # recv_list = self.all_to_all_v_across_ep(send_list, self.even_list, self.even_list)
             # recv_list [20, 40, 60, 70]
 
+            recv_list = self.depend(recv_list, local_group_list)
             local_input_tensor = self.all_to_all_v_across_ep_with_block_size(sorted_input_tensor.reshape(-1),
                                                                              send_list, recv_list)
 
             topk_ids_1d, _ = mint.sort(topk_ids.reshape(-1))
+            topk_ids_1d = self.depend(topk_ids_1d, local_input_tensor)
             topk_ids_local = self.all_to_all_v_across_ep(topk_ids_1d, send_list, recv_list)
 
             local_group_list = local_group_list.sum(dim=0)
@@ -351,6 +353,7 @@ class FusedExperts(nn.Cell):
             expert_output = mint.index_select(expert_output, 0, unresort_index)
         else:
             expert_output = self.dummy_token
+        expert_output = self.depend(expert_output, topk_ids_local)
 
         expert_output = self.all_to_all_v_across_ep_with_block_size(expert_output.reshape(-1),
                                                                     recv_list, send_list)
--
Gitee
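NOTE: patch 70 inserts `self.depend(...)` calls so that, under graph compilation, the
all-to-all collectives cannot be scheduled ahead of the tensors they logically follow.
A small sketch of this ordering pattern is below; it assumes `self.depend` wraps
`mindspore.ops.Depend` (the hunk does not show the constructor), and the arithmetic ops
stand in for the real collectives.

    import mindspore as ms
    from mindspore import nn, ops

    class OrderedOps(nn.Cell):
        """Sketch: force `second` to be computed only after `first`."""
        def __init__(self):
            super().__init__()
            self.depend = ops.Depend()

        def construct(self, x):
            first = x * 2          # stands in for computing local_group_list
            y = x + 1              # stands in for recv_list
            # Depend returns `y` unchanged but records an ordering edge,
            # so the op consuming `y` runs only after `first` is produced.
            y = self.depend(y, first)
            second = y * 3         # stands in for the all-to-all that must follow
            return first, second

    print(OrderedOps()(ms.Tensor(1.0)))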
From 4489ae488ff34390f6d4adf2842a9ff7a9045db2 Mon Sep 17 00:00:00 2001
From: lvhaoyu
Date: Tue, 8 Jul 2025 16:13:01 +0800
Subject: [PATCH 71/77] fix big ep

---
 .../model_executor/layers/fused_moe/layer.py | 16 +++++++++++-----
 1 file changed, 11 insertions(+), 5 deletions(-)

diff --git a/vllm_mindspore/model_executor/layers/fused_moe/layer.py b/vllm_mindspore/model_executor/layers/fused_moe/layer.py
index 1e228219..3d1f370b 100644
--- a/vllm_mindspore/model_executor/layers/fused_moe/layer.py
+++ b/vllm_mindspore/model_executor/layers/fused_moe/layer.py
@@ -35,6 +35,7 @@ from vllm_mindspore.distributed.communication_op import ReduceFromModelParallelR
 
 from mindspore import nn, Tensor, Parameter, mint, ops
 import mindspore as ms
+from mindspore.ops import ReduceOp
 
 logger = init_logger(__name__)
 
@@ -522,6 +523,9 @@ class FusedMoE(nn.Cell):
             self.all_reduce_from_dp_group = ops.AllReduce(group=self.dp_group)
             self.reduce_scatter_from_ep_group = ops.ReduceScatter(group=self.ep_group)
 
+        if self.dp_size > 1 and self.ep_size > 1:
+            self.all_reduce_max_across_dp = ops.AllReduce(op=ReduceOp.MAX, group=self.dp_group)
+
     @property
     def tp_size(self):
         return self.moe_parallel_config.tp_size
@@ -808,9 +812,7 @@ class FusedMoE(nn.Cell):
                 dp_pad_index_with_offset, dp_unpad_index_total_with_offset):
 
         if self.use_dispatch_kernels:
-            return self.forward_impl_chunked(hidden_states, router_logits, dp_pad_index,
-                                             dp_unpad_index, dp_pad_index_with_offset,
-                                             dp_unpad_index_total_with_offset)
+            return self.forward_impl_chunked(hidden_states, router_logits)
 
         return self.forward_impl(hidden_states, router_logits, dp_pad_index,
                                  dp_unpad_index, dp_pad_index_with_offset,
@@ -908,9 +910,13 @@ class FusedMoE(nn.Cell):
             if not skip_result_store:
                 full_final_hidden_states[chunk_start:chunk_end, :] = final_hidden_states
 
-        ctx = get_forward_context()
-        max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp_cpu
+        # ctx = get_forward_context()
+        # max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp_cpu
+        # moe_dp_chunk_size_per_rank = self.moe_config.max_num_tokens
+        tokens_num = full_hidden_states.shape[0]
+        max_tokens_across_dp = self.all_reduce_max_across_dp(ms.Tensor(tokens_num))
         moe_dp_chunk_size_per_rank = self.moe_config.max_num_tokens
+        max_tokens_across_dp = max_tokens_across_dp.item()
 
         num_tokens = full_hidden_states.size(0)
         for chunk_start_ in range(0, max_tokens_across_dp,
--
Gitee

From 5a917b76aa940b4ab93dc8ab01f429d6c94f11b0 Mon Sep 17 00:00:00 2001
From: lvhaoyu
Date: Tue, 8 Jul 2025 21:37:21 +0800
Subject: [PATCH 72/77] update big ep

---
 vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py | 2 +-
 vllm_mindspore/model_executor/layers/fused_moe/layer.py     | 6 ++----
 2 files changed, 3 insertions(+), 5 deletions(-)

diff --git a/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py b/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py
index 30e5a0e8..82991513 100644
--- a/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm_mindspore/model_executor/layers/fused_moe/fused_moe.py
@@ -387,7 +387,7 @@ class FusedExperts(nn.Cell):
             ep_rank_id=self.ep_rank,
             moe_expert_num=global_num_experts,
             group_ep=self.ep_group,
-            tp_world_size=self.tp_world_size,
+            tp_world_size=self.dispatch_tp_world_size,
             shared_expert_num=self.dispatch_shared_expert_num,
             global_bs=self.max_bs,
             expert_token_nums_type=1)
diff --git a/vllm_mindspore/model_executor/layers/fused_moe/layer.py b/vllm_mindspore/model_executor/layers/fused_moe/layer.py
index 3d1f370b..94e4f173 100644
--- a/vllm_mindspore/model_executor/layers/fused_moe/layer.py
+++ b/vllm_mindspore/model_executor/layers/fused_moe/layer.py
@@ -913,12 +913,10 @@ class FusedMoE(nn.Cell):
         # ctx = get_forward_context()
         # max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp_cpu
         # moe_dp_chunk_size_per_rank = self.moe_config.max_num_tokens
-        tokens_num = full_hidden_states.shape[0]
-        max_tokens_across_dp = self.all_reduce_max_across_dp(ms.Tensor(tokens_num))
+        num_tokens = ops.shape(full_hidden_states.shape)[0]
+        max_tokens_across_dp = self.all_reduce_max_across_dp(ops.scalar_to_tensor(num_tokens, dtype=ms.int32))
         moe_dp_chunk_size_per_rank = self.moe_config.max_num_tokens
-        max_tokens_across_dp = max_tokens_across_dp.item()
 
-        num_tokens = full_hidden_states.size(0)
         for chunk_start_ in range(0, max_tokens_across_dp,
                                   moe_dp_chunk_size_per_rank):
             chunk_start = chunk_start_
--
Gitee

From b0f79fb04ce8179f7fa3e8a91999afc62a80aac8 Mon Sep 17 00:00:00 2001
From: lvhaoyu
Date: Wed, 9 Jul 2025 09:20:51 +0800
Subject: [PATCH 73/77] update

---
 vllm_mindspore/model_executor/layers/fused_moe/layer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm_mindspore/model_executor/layers/fused_moe/layer.py b/vllm_mindspore/model_executor/layers/fused_moe/layer.py
index 94e4f173..d2c34aec 100644
--- a/vllm_mindspore/model_executor/layers/fused_moe/layer.py
+++ b/vllm_mindspore/model_executor/layers/fused_moe/layer.py
@@ -913,7 +913,7 @@ class FusedMoE(nn.Cell):
         # ctx = get_forward_context()
         # max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp_cpu
         # moe_dp_chunk_size_per_rank = self.moe_config.max_num_tokens
-        num_tokens = ops.shape(full_hidden_states.shape)[0]
+        num_tokens = ops.shape(full_hidden_states)[0]
         max_tokens_across_dp = self.all_reduce_max_across_dp(ops.scalar_to_tensor(num_tokens, dtype=ms.int32))
         moe_dp_chunk_size_per_rank = self.moe_config.max_num_tokens
--
Gitee
From 1c0edc083ec7ad0f3c23f2274e3090722a12a6e1 Mon Sep 17 00:00:00 2001
From: lvhaoyu
Date: Wed, 9 Jul 2025 09:23:55 +0800
Subject: [PATCH 74/77] update

---
 vllm_mindspore/model_executor/layers/fused_moe/layer.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/vllm_mindspore/model_executor/layers/fused_moe/layer.py b/vllm_mindspore/model_executor/layers/fused_moe/layer.py
index d2c34aec..0bd2ca34 100644
--- a/vllm_mindspore/model_executor/layers/fused_moe/layer.py
+++ b/vllm_mindspore/model_executor/layers/fused_moe/layer.py
@@ -915,6 +915,7 @@ class FusedMoE(nn.Cell):
         # moe_dp_chunk_size_per_rank = self.moe_config.max_num_tokens
         num_tokens = ops.shape(full_hidden_states)[0]
         max_tokens_across_dp = self.all_reduce_max_across_dp(ops.scalar_to_tensor(num_tokens, dtype=ms.int32))
+        max_tokens_across_dp = max_tokens_across_dp.item()
         moe_dp_chunk_size_per_rank = self.moe_config.max_num_tokens
 
         for chunk_start_ in range(0, max_tokens_across_dp,
--
Gitee
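NOTE: patches 71-74 replace the forward-context metadata with an explicit AllReduce(MAX)
over per-rank token counts, so every data-parallel rank walks the same number of chunks
and the collectives inside the loop stay matched. The control flow they converge on is
roughly the pure-Python sketch below; the builtin `max` stands in for
`ops.AllReduce(op=ReduceOp.MAX, group=dp_group)`, and the clamping of `chunk_end` is an
assumption, not the repository's exact bookkeeping.

    def chunked_moe_schedule(tokens_per_rank, chunk_size):
        # Each DP rank only knows its own token count; the MAX all-reduce gives
        # all ranks the same loop bound.
        max_tokens_across_dp = max(tokens_per_rank)      # stand-in for AllReduce(MAX)
        schedules = {}
        for rank, num_tokens in enumerate(tokens_per_rank):
            chunks = []
            for chunk_start in range(0, max_tokens_across_dp, chunk_size):
                chunk_end = min(chunk_start + chunk_size, num_tokens)
                skip_result_store = chunk_start >= num_tokens  # rank ran out of real tokens
                chunks.append((chunk_start, max(chunk_end, chunk_start), skip_result_store))
            schedules[rank] = chunks
        return schedules

    # Two DP ranks with uneven batches still execute the same number of chunks.
    print(chunked_moe_schedule([5, 12], chunk_size=4))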
""" # Do delay allreduce If "must_reduce_shared_expert_outputs" return True - if self.pure_tp and self.dp_size == 1: + if self.pure_tp: return self.reduce_from_tp_group(final_hidden_states) return final_hidden_states @@ -860,16 +861,17 @@ class FusedMoE(nn.Cell): ) if self.pure_tp and self.dp_size > 1: - final_hidden_states = mint.index_select(final_hidden_states, 0, dp_pad_index_total_with_offset) - final_hidden_states = final_hidden_states.reshape(self.dp_size, -1, final_hidden_states.shape[-1]) + final_hidden_states = mint.index_select(final_hidden_states, 0, dp_pad_index_total_with_offset) + final_hidden_states = final_hidden_states.reshape(self.dp_size, -1, final_hidden_states.shape[-1]) + if self.reduce_results: final_hidden_states = mint.repeat_interleave(final_hidden_states, self.tp_world_size, dim=0) final_hidden_states = final_hidden_states.reshape(-1, final_hidden_states.shape[-1]) final_hidden_states = self.reduce_scatter_from_ep_group(final_hidden_states) final_hidden_states = mint.index_select(final_hidden_states, 0, dp_unpad_index) - # final_hidden_states = mint.index_select(final_hidden_states, 0, dp_unpad_index) - # start = dp_pad_index[-2] - # end = start + tokens_num - # final_hidden_states = final_hidden_states[start:end] + else: + final_hidden_states = final_hidden_states.reshape(-1, final_hidden_states.shape[-1]) + final_hidden_states = self.reduce_scatter_from_dp_group(final_hidden_states) + final_hidden_states = mint.index_select(final_hidden_states, 0, dp_unpad_index) if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): # Default set to False. (May have to add shared expert outputs.) -- Gitee From 189e17239b8923816a2cd72590012dd703e83c31 Mon Sep 17 00:00:00 2001 From: lvhaoyu Date: Wed, 9 Jul 2025 16:00:09 +0800 Subject: [PATCH 77/77] update --- vllm_mindspore/model_executor/layers/fused_moe/layer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm_mindspore/model_executor/layers/fused_moe/layer.py b/vllm_mindspore/model_executor/layers/fused_moe/layer.py index 71d5c0ef..29802bca 100644 --- a/vllm_mindspore/model_executor/layers/fused_moe/layer.py +++ b/vllm_mindspore/model_executor/layers/fused_moe/layer.py @@ -872,6 +872,7 @@ class FusedMoE(nn.Cell): final_hidden_states = final_hidden_states.reshape(-1, final_hidden_states.shape[-1]) final_hidden_states = self.reduce_scatter_from_dp_group(final_hidden_states) final_hidden_states = mint.index_select(final_hidden_states, 0, dp_unpad_index) + return final_hidden_states if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): # Default set to False. (May have to add shared expert outputs.) -- Gitee