From be18006c68f3d952bb8e817cbb9b603cfc0d08c4 Mon Sep 17 00:00:00 2001
From: dengyunyang
Date: Mon, 11 Aug 2025 16:28:41 +0800
Subject: [PATCH] adapt to eplb, only support deepseekv3 in large ep now

---
 .../models/mf_models/deepseek_v3.py           | 62 ++++++++++++++++++-
 .../mf_models/deepseekv3_weight_processor.py  |  3 +
 .../models/mf_models/weight_processor.py      | 22 ++++++-
 vllm_mindspore/v1/worker/gpu_worker.py        |  9 ++-
 4 files changed, 89 insertions(+), 7 deletions(-)

diff --git a/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py b/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py
index d17cf0bd..1e76ba04 100644
--- a/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py
+++ b/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py
@@ -45,6 +45,10 @@ from vllm.distributed.parallel_state import (
 from vllm.forward_context import get_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.layers.sampler import get_sampler
+try:
+    from vllm.model_executor.models.interfaces import MixtureOfExperts
+except ImportError:
+    MixtureOfExperts = None
 
 from vllm_mindspore.model_executor.models.attention_mask import (
     MLALowerTriangularMask)
@@ -55,6 +59,12 @@ from vllm_mindspore.model_executor.models.mf_models \
 # isort: on
 from vllm_mindspore.model_executor.models.mf_models.mf_model_base import (
     MfModelBase)
+from research.deepseek3.deepseek3_model_infer import (
+    InferenceDeepseekV3ForCausalLM,
+    DeepseekV3Model,
+    DeepseekV3MoE
+)
+from research.deepseek3.moe import ExpertParallelMoE
 from vllm_mindspore.model_executor.models.model_base import MLAAttentionWrapper
 
 with contextlib.suppress(ImportError):
@@ -123,9 +133,10 @@ def _get_padding_index(q_seq_len):
     return ms.from_numpy(attn_padding_idx), ms.from_numpy(attn_unpadding_idx), \
         ms.from_numpy(ffn_padding_idx), ms.from_numpy(ffn_unpadding_idx)
 
-
-class DeepseekV3ForCausalLM(MfModelBase):
-
+if not MixtureOfExperts:
+    logger.warning("Current vllm version does not support EPLB.")
+bases = (MfModelBase, MixtureOfExperts) if MixtureOfExperts else (MfModelBase,)
+class DeepseekV3ForCausalLM(*bases):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         super().__init__(vllm_config=vllm_config, prefix=prefix)
         self.is_quant = bool(
@@ -161,6 +172,35 @@ class DeepseekV3ForCausalLM(MfModelBase):
             dtype=self.mf_model_config.compute_dtype,
             max_model_len=self.model_config.max_model_len)
 
+        self.enable_eplb = hasattr(self.mf_model_config.moe_config, "enable_eplb") and \
+            self.mf_model_config.moe_config.enable_eplb and MixtureOfExperts is not None
+        if self.enable_eplb:
+            self.expert_weights = []
+            self.moe_layers = self._get_moe_layers()
+            example_moe = self.moe_layers[-1]
+            self.num_logical_experts = example_moe.expert_num
+            self.num_physical_experts = example_moe.physical_expert_num
+            self.num_local_physical_experts = example_moe.local_ep_num
+            self.num_routed_experts = example_moe.expert_num
+            self.num_shared_experts = example_moe.dispatch_shared_expert_num
+            self.num_redundant_experts = example_moe.redundant_expert_num
+            self.num_moe_layers = len(self.moe_layers)
+            self.num_expert_groups = self.mf_model_config.moe_config.n_group
+        else:
+            logger.info("EPLB in vllm_mindspore is disabled; check the vllm and mindformers versions.")
+
+    def _get_moe_layers(self):
+        if not isinstance(self.network, InferenceDeepseekV3ForCausalLM) or \
+                not isinstance(self.network.model, DeepseekV3Model):
+            return []
+        moe_layers: list[ExpertParallelMoE] = []
+        for layer in self.network.model.layers:
+            if not isinstance(layer.feed_forward, DeepseekV3MoE) or \
+                    not isinstance(layer.feed_forward.routed_experts, ExpertParallelMoE):
+                continue
+            moe_layers.append(layer.feed_forward.routed_experts)
+        return moe_layers
+
     def _generate_model_config(self):
         self.mf_config.load_checkpoint = self.get_model_path()
 
@@ -192,6 +232,22 @@ class DeepseekV3ForCausalLM(MfModelBase):
                 ptq.convert(network)
         return network, network.lm_head
 
+    def set_eplb_state(
+        self,
+        expert_load_view: Tensor,
+        logical_to_physical_map: Tensor,
+        logical_replica_count: Tensor,
+    ) -> None:
+        for layer_idx, layer in enumerate(self.moe_layers):
+            # Register the expert weights.
+            self.expert_weights.append(layer.get_expert_weights())
+            layer.set_eplb_state(
+                moe_layer_idx=layer_idx,
+                expert_load_view=expert_load_view,
+                logical_to_physical_map=logical_to_physical_map,
+                logical_replica_count=logical_replica_count,
+            )
+
     def get_kvcache(self):
         key_cache = []
         forward_context = get_forward_context()
diff --git a/vllm_mindspore/model_executor/models/mf_models/deepseekv3_weight_processor.py b/vllm_mindspore/model_executor/models/mf_models/deepseekv3_weight_processor.py
index 73bc3027..cd417b26 100644
--- a/vllm_mindspore/model_executor/models/mf_models/deepseekv3_weight_processor.py
+++ b/vllm_mindspore/model_executor/models/mf_models/deepseekv3_weight_processor.py
@@ -237,6 +237,9 @@ class DeepseekV3WeightProcessor(BaseWeightProcessor):
         w3_scale_list = []
 
         for index in range(self.ep_start, self.ep_stop):
+            # Map the local physical expert slot to the logical expert whose weights are loaded.
+            if self.enable_eplb:
+                index = self.init_physical_to_logical_map[index]
             base_path = f"model.layers.{layer_id}.mlp.experts.{index}"
             w1_hf_name = f"{base_path}.gate_proj.weight"
             w2_hf_name = f"{base_path}.down_proj.weight"
diff --git a/vllm_mindspore/model_executor/models/mf_models/weight_processor.py b/vllm_mindspore/model_executor/models/mf_models/weight_processor.py
index d60506fb..820a193f 100644
--- a/vllm_mindspore/model_executor/models/mf_models/weight_processor.py
+++ b/vllm_mindspore/model_executor/models/mf_models/weight_processor.py
@@ -25,7 +25,10 @@ from mindformers.parallel_core.inference.parallel_state import (
 from mindformers.parallel_core.inference.utils import get_tp_world_size
 from mindspore.communication.management import get_group_size, get_rank
 from safetensors import safe_open
-
+try:
+    from vllm.distributed.eplb.eplb_state import EplbState
+except ImportError:
+    EplbState = None
 
 class EPMethod(Enum):
     """
@@ -66,7 +69,16 @@ class BaseWeightProcessor:
             self.ep_method = EPMethod.ALLGATHER
         self.tp_rank_id = self.global_rank_id % self.tp_group_size
 
-        self.ep_group_nums = self.num_router_experts // self.moe_ep_size
+        # For eplb; use hasattr so vllm_mindspore does not depend on a specific mindformers update.
+        self.enable_eplb = hasattr(self.config.moe_config, "enable_eplb") and \
+            self.config.moe_config.enable_eplb and EplbState is not None
+        if self.enable_eplb:
+            self.num_redundant_expert = self.config.moe_config.redundant_expert_num \
+                if self.config.moe_config.redundant_expert_num else 0
+            self.ep_group_nums = (self.num_router_experts + self.num_redundant_expert) // self.moe_ep_size
+        else:
+            self.ep_group_nums = self.num_router_experts // self.moe_ep_size
+
         self.moe_ep_rank_id = self.global_rank_id // self.moe_tp_size
         self.moe_tp_rank_id = self.global_rank_id % self.moe_tp_size
         self.ep_start = self.moe_ep_rank_id * self.ep_group_nums
@@ -75,6 +87,12 @@ class BaseWeightProcessor:
         self.parameter_dict = {}
         self.file_handles = {}
 
+        # For eplb; init_physical_to_logical_map is used by the weight loader.
+        if self.enable_eplb:
+            self.init_physical_to_logical_map = \
+                EplbState.build_initial_global_physical_to_logical_map(
+                    self.num_router_experts, self.num_redundant_expert)
+
     def get_file_handles(self, filename):
         if filename not in self.file_handles:
             fp = safe_open(filename, framework="np")
diff --git a/vllm_mindspore/v1/worker/gpu_worker.py b/vllm_mindspore/v1/worker/gpu_worker.py
index 2a49d511..067099ce 100644
--- a/vllm_mindspore/v1/worker/gpu_worker.py
+++ b/vllm_mindspore/v1/worker/gpu_worker.py
@@ -73,5 +73,10 @@ def compile_or_warm_up_model(self) -> None:
     # Since prefill is done previously, we do decode here.
     default_max_num_reqs = 1  # For MindSpore, we only do one more decode here.
     if get_pp_group().is_last_rank:
-        self.model_runner._dummy_sampler_run(
-            self.model_runner._dummy_run(num_tokens=default_max_num_reqs))
+        # Use hasattr so vllm_mindspore does not depend on a specific vllm version.
+        if not hasattr(self.vllm_config.parallel_config, "enable_eplb"):
+            self.model_runner._dummy_sampler_run(
+                self.model_runner._dummy_run(num_tokens=default_max_num_reqs))
+        else:
+            self.model_runner._dummy_sampler_run(
+                self.model_runner._dummy_run(num_tokens=default_max_num_reqs, skip_eplb=True))
--
Gitee