From 589a907ce5e4fadcc852bb7a370c972a29af23d4 Mon Sep 17 00:00:00 2001
From: zhanzhan1
Date: Thu, 15 May 2025 17:34:06 +0800
Subject: [PATCH] vllm support pp with qwen2.5

---
 .../models/mf_models/mf_model_base.py          | 15 ++++++++----
 .../model_executor/models/mf_models/qwen2.py   | 18 ++++++++++----
 .../mf_models/qwen2_weight_processor.py        | 24 +++++++++----------
 .../models/mf_models/weight_processor.py       |  9 ++++---
 .../model_executor/models/model_base.py        |  3 ++-
 vllm_mindspore/model_executor/models/utils.py  |  2 +-
 vllm_mindspore/worker/worker.py                | 14 ++++++++---
 7 files changed, 57 insertions(+), 28 deletions(-)

diff --git a/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py b/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py
index 51f0d985..b44ee04d 100644
--- a/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py
+++ b/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py
@@ -27,7 +27,7 @@ from vllm.config import VllmConfig
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
-from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed import get_tensor_model_parallel_world_size, get_pp_group
 from vllm.logger import init_logger
 
 import torch
@@ -71,7 +71,7 @@ class MfModelBase(MsModelBase):
         self.mf_config.model.model_config.parallel_config.model_parallel = (
             get_tensor_model_parallel_world_size()
         )
-        self.mf_config.model.model_config.parallel_config.pipeline_stage = 1
+        self.mf_config.model.model_config.parallel_config.pipeline_stage = get_pp_group().world_size
         self._generate_model_config()
         self.casual_mask = LowerTriangularMask(dtype=self.mf_model_config.compute_dtype,
                                                max_model_len=self.mf_model_config.seq_length)
@@ -92,8 +92,9 @@ class MfModelBase(MsModelBase):
 
     def _set_dynamic_inputs(self):
         self.network.set_dynamic_inputs()
-        dynamic_hidden_states = Tensor(shape=[None, None], dtype=self.mf_model_config.compute_dtype)
-        self.lm_head.set_inputs(dynamic_hidden_states)
+        if get_pp_group().is_last_rank:
+            dynamic_hidden_states = Tensor(shape=[None, None], dtype=self.mf_model_config.compute_dtype)
+            self.lm_head.set_inputs(dynamic_hidden_states)
 
     def prepare_inputs(self, input_ids, positions, attn_metadata):
         key_cache, value_cache = self.get_kvcache()
@@ -147,6 +148,7 @@ class MfModelBase(MsModelBase):
     ) -> Union[Tensor, IntermediateTensors]:
         model_inputs, is_prefill = self.prepare_inputs(input_ids, positions, attn_metadata)
         model_inputs = self.update_model_inputs(model_inputs, **kwargs)
+        model_inputs["hidden_states"] = intermediate_tensors["hidden_states"] if intermediate_tensors else None
 
         if is_prefill:
             self.network.phase = "prefill"
@@ -159,6 +161,10 @@ class MfModelBase(MsModelBase):
             )
             self.set_flags = True
         else:
             hidden_states = self.network(**model_inputs)
+        if not get_pp_group().is_last_rank:
+            return IntermediateTensors({
+                "hidden_states": hidden_states,
+            })
 
         return hidden_states
 
@@ -172,6 +178,7 @@ class MfModelBase(MsModelBase):
             logits = ms.mint.zeros((0, self.mf_model_config.vocab_size),
                                    dtype=self.mf_model_config.compute_dtype)
         else:
+            hidden_states = hidden_states.view(-1, self.mf_model_config.hidden_size)
             hidden_states = hidden_states.index_select(0, selected_token_indices)
             logits = self.lm_head(hidden_states)
             logits = logits.reshape(-1, logits.shape[-1])
diff --git a/vllm_mindspore/model_executor/models/mf_models/qwen2.py b/vllm_mindspore/model_executor/models/mf_models/qwen2.py
index e7475bb0..9175094f 100644
--- a/vllm_mindspore/model_executor/models/mf_models/qwen2.py
+++ b/vllm_mindspore/model_executor/models/mf_models/qwen2.py
@@ -21,6 +21,8 @@ from typing import Iterable, Set, Tuple
 from vllm.config import VllmConfig
 from vllm.config import get_current_vllm_config
 from vllm.logger import init_logger
+from vllm.distributed.parallel_state import get_pp_group
+from vllm.model_executor.models.interfaces import SupportsPP
 
 from mindspore import Tensor, JitConfig
 from mindspore.nn.utils import no_init_parameters
@@ -34,29 +36,35 @@ from vllm_mindspore.model_executor.layers.sampler import get_sampler
 from vllm_mindspore.model_executor.models.model_base import Fake_Attention
 from vllm_mindspore.model_executor.models.mf_models.mf_model_base import MfModelBase
 from vllm_mindspore.model_executor.models.mf_models.qwen2_weight_processor import Qwen2WeightProcessor
+from vllm_mindspore.model_executor.models.utils import make_empty_intermediate_tensors_factory
 
 logger = init_logger(__name__)
 
 
-class Qwen2ForCausalLM(MfModelBase):
+class Qwen2ForCausalLM(MfModelBase, SupportsPP):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         super(Qwen2ForCausalLM, self).__init__(vllm_config=vllm_config, prefix=prefix)
         self.mf_kvcaches_init = False
 
         self.sampler = get_sampler()
         self.set_modules({"model": self.network})
+        self.num_layers = self.model_config.get_num_layers(self.parallel_config)
 
-        self.kv_caches = [Fake_Attention() for i in range(self.mf_model_config.num_layers)]
+        self.kv_caches = [Fake_Attention() for i in range(self.num_layers)]
         compilation_config = get_current_vllm_config().compilation_config
 
         if prefix in compilation_config.static_forward_context:
             raise ValueError(f"Duplicate layer name: {prefix}")
-        for i in range(self.mf_model_config.num_layers):
+        for i in range(self.num_layers):
             compilation_config.static_forward_context[str(i)] = self.kv_caches[i]
         self.set_flags = False
 
+        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
+            ["hidden_states", "residual"], self.mf_model_config.hidden_size
+        )
+
     def _generate_model_config(self):
         self.mf_config.load_checkpoint = self.get_model_path()
         self.mf_model_config = LlamaConfig_MF(**self.mf_config.model.model_config)
@@ -73,7 +81,9 @@ class Qwen2ForCausalLM(MfModelBase):
         # Initial network
         with no_init_parameters():  # Delay initialization
             network = ParallelQwenForCausalLM_MF(self.mf_model_config)
-        return network, network.lm_head
+        if get_pp_group().is_last_rank:
+            return network, network.lm_head
+        return network, None
 
     def load_weights(self, weights: Iterable[Tuple[str, Tensor]]) -> Set[str]:
         weight_processor = Qwen2WeightProcessor(self.mf_config, self.network, False)
diff --git a/vllm_mindspore/model_executor/models/mf_models/qwen2_weight_processor.py b/vllm_mindspore/model_executor/models/mf_models/qwen2_weight_processor.py
index 59423eca..6f08f385 100644
--- a/vllm_mindspore/model_executor/models/mf_models/qwen2_weight_processor.py
+++ b/vllm_mindspore/model_executor/models/mf_models/qwen2_weight_processor.py
@@ -48,7 +48,7 @@ class Qwen2WeightProcessor(BaseWeightProcessor):
             np_data, _ = self.get_safetensor_from_file(embed_tokens_hf_name, src_hf_dir, hf_weight_map)
         else:
             np_data, _ = self.get_safetensor_from_file(embed_tokens_hf_name, src_hf_dir, hf_weight_map,
-                                                       is_split_param=True, split_axis=0)
+                                                       is_split_param=self.is_split_param, split_axis=0)
         self.parameter_dict[embed_tokens_ms_name] = ms.Parameter(ms.from_numpy(np_data).astype(ms.bfloat16),
                                                                   name=embed_tokens_ms_name,
                                                                   requires_grad=False)
@@ -65,7 +65,7 @@ class Qwen2WeightProcessor(BaseWeightProcessor):
         if not self.config.model.model_config.tie_word_embeddings:
             if not self.config.parallel_config.vocab_emb_dp:
                 np_data, _ = self.get_safetensor_from_file(lm_head_hf_name, src_hf_dir, hf_weight_map,
-                                                           is_split_param=True, split_axis=0)
+                                                           is_split_param=self.is_split_param, split_axis=0)
             else:
                 np_data, _ = self.get_safetensor_from_file(lm_head_hf_name, src_hf_dir, hf_weight_map)
             self.parameter_dict[lm_head_ms_name] = ms.Parameter(ms.from_numpy(np_data).astype(ms.bfloat16),
@@ -94,17 +94,17 @@ class Qwen2WeightProcessor(BaseWeightProcessor):
         ffn_concat = self.config.model.model_config.qkv_concat
         w1_hf_name = f"model.layers.{layer_id}.mlp.gate_proj.weight"
         w1_ms_name = self.convert_weight_name(w1_hf_name)
-        w1_ms_param, _ = self.get_safetensor_from_file(w1_hf_name, src_hf_dir, hf_weight_map, is_split_param=True,
+        w1_ms_param, _ = self.get_safetensor_from_file(w1_hf_name, src_hf_dir, hf_weight_map, is_split_param=self.is_split_param,
                                                        split_axis=0)
 
         w2_hf_name = f"model.layers.{layer_id}.mlp.down_proj.weight"
         w2_ms_name = self.convert_weight_name(w2_hf_name)
-        w2_ms_param, _ = self.get_safetensor_from_file(w2_hf_name, src_hf_dir, hf_weight_map, is_split_param=True,
+        w2_ms_param, _ = self.get_safetensor_from_file(w2_hf_name, src_hf_dir, hf_weight_map, is_split_param=self.is_split_param,
                                                        split_axis=1)
 
         w3_hf_name = f"model.layers.{layer_id}.mlp.up_proj.weight"
         w3_ms_name = self.convert_weight_name(w3_hf_name)
-        w3_ms_param, _ = self.get_safetensor_from_file(w3_hf_name, src_hf_dir, hf_weight_map, is_split_param=True,
+        w3_ms_param, _ = self.get_safetensor_from_file(w3_hf_name, src_hf_dir, hf_weight_map, is_split_param=self.is_split_param,
                                                        split_axis=0)
 
         if ffn_concat:
@@ -130,37 +130,37 @@ class Qwen2WeightProcessor(BaseWeightProcessor):
         # wq
         wq_hf_name = f"model.layers.{layer_id}.self_attn.q_proj.weight"
         wq_ms_name = self.convert_weight_name(wq_hf_name)
-        wq_ms_param, _ = self.get_safetensor_from_file(wq_hf_name, src_hf_dir, hf_weight_map, is_split_param=True,
+        wq_ms_param, _ = self.get_safetensor_from_file(wq_hf_name, src_hf_dir, hf_weight_map, is_split_param=self.is_split_param,
                                                        split_axis=0)
 
         # wq bias
         wq_bias_hf_name = f"model.layers.{layer_id}.self_attn.q_proj.bias"
         wq_bias_ms_name = self.convert_weight_name(wq_bias_hf_name)
         wq_bias_ms_param, _ = self.get_safetensor_from_file(wq_bias_hf_name, src_hf_dir, hf_weight_map,
-                                                            is_split_param=True,
+                                                            is_split_param=self.is_split_param,
                                                             split_axis=0)
 
         # wk
         wk_hf_name = f"model.layers.{layer_id}.self_attn.k_proj.weight"
         wk_ms_name = self.convert_weight_name(wk_hf_name)
-        wk_ms_param, _ = self.get_safetensor_from_file(wk_hf_name, src_hf_dir, hf_weight_map, is_split_param=True,
+        wk_ms_param, _ = self.get_safetensor_from_file(wk_hf_name, src_hf_dir, hf_weight_map, is_split_param=self.is_split_param,
                                                        split_axis=0)
 
         # wk bias
         wk_bias_hf_name = f"model.layers.{layer_id}.self_attn.k_proj.bias"
         wk_bias_ms_name = self.convert_weight_name(wk_bias_hf_name)
         wk_bias_ms_param, _ = self.get_safetensor_from_file(wk_bias_hf_name, src_hf_dir, hf_weight_map,
-                                                            is_split_param=True,
+                                                            is_split_param=self.is_split_param,
                                                             split_axis=0)
 
         # wv
         wv_hf_name = f"model.layers.{layer_id}.self_attn.v_proj.weight"
         wv_ms_name = self.convert_weight_name(wv_hf_name)
-        wv_ms_param, _ = self.get_safetensor_from_file(wv_hf_name, src_hf_dir, hf_weight_map, is_split_param=True,
+        wv_ms_param, _ = self.get_safetensor_from_file(wv_hf_name, src_hf_dir, hf_weight_map, is_split_param=self.is_split_param,
                                                        split_axis=0)
 
         # wv bias
         wv_bias_hf_name = f"model.layers.{layer_id}.self_attn.v_proj.bias"
         wv_bias_ms_name = self.convert_weight_name(wv_bias_hf_name)
         wv_bias_ms_param, _ = self.get_safetensor_from_file(wv_bias_hf_name, src_hf_dir, hf_weight_map,
-                                                            is_split_param=True,
+                                                            is_split_param=self.is_split_param,
                                                             split_axis=0)
 
         if qkv_concat:
@@ -201,7 +201,7 @@ class Qwen2WeightProcessor(BaseWeightProcessor):
         # wo
         wo_hf_name = f"model.layers.{layer_id}.self_attn.o_proj.weight"
         wo_ms_name = self.convert_weight_name(wo_hf_name)
-        wo_ms_param, _ = self.get_safetensor_from_file(wo_hf_name, src_hf_dir, hf_weight_map, is_split_param=True,
+        wo_ms_param, _ = self.get_safetensor_from_file(wo_hf_name, src_hf_dir, hf_weight_map, is_split_param=self.is_split_param,
                                                        split_axis=1)
         self.parameter_dict[wo_ms_name] = ms.Parameter(ms.from_numpy(wo_ms_param).astype(ms.bfloat16),
                                                         name=wo_ms_name,
diff --git a/vllm_mindspore/model_executor/models/mf_models/weight_processor.py b/vllm_mindspore/model_executor/models/mf_models/weight_processor.py
index 9b0aab3a..228dd791 100644
--- a/vllm_mindspore/model_executor/models/mf_models/weight_processor.py
+++ b/vllm_mindspore/model_executor/models/mf_models/weight_processor.py
@@ -19,7 +19,8 @@ transform huggingface safetensor.
 import os
 from safetensors import safe_open
-from mindspore.communication.management import get_rank, get_group_size
+
+from vllm.distributed import get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank
 
 
 class BaseWeightProcessor:
@@ -35,11 +36,13 @@ class BaseWeightProcessor:
         self.config = config
         self.network = network
         self.is_quant = is_quant
-        self.tp_group_size = get_group_size()
-        self.rank_id = get_rank()
+        self.tp_group_size = get_tensor_model_parallel_world_size()
+        self.rank_id = get_tensor_model_parallel_rank()
         self.parameter_dict = {}
         self.file_handles = {}
 
+        self.is_split_param = self.tp_group_size > 1
+
     def get_file_handles(self, filename):
         if filename not in self.file_handles:
             fp = safe_open(filename, framework="np")
diff --git a/vllm_mindspore/model_executor/models/model_base.py b/vllm_mindspore/model_executor/models/model_base.py
index 961f54a2..40f5fada 100644
--- a/vllm_mindspore/model_executor/models/model_base.py
+++ b/vllm_mindspore/model_executor/models/model_base.py
@@ -190,7 +190,8 @@ class MsModelBase():
         key_cache = []
         value_cache = []
         forward_context = get_forward_context()
-        for i in range(self.config.num_hidden_layers):
+        num_layers = self.model_config.get_num_layers(self.parallel_config)
+        for i in range(num_layers):
             k_cache = self.kv_caches[i].kv_cache[forward_context.virtual_engine][0]
             v_cache = self.kv_caches[i].kv_cache[forward_context.virtual_engine][1]
             key_cache.append(k_cache)
diff --git a/vllm_mindspore/model_executor/models/utils.py b/vllm_mindspore/model_executor/models/utils.py
index 4bb7831c..adcc65ba 100644
--- a/vllm_mindspore/model_executor/models/utils.py
+++ b/vllm_mindspore/model_executor/models/utils.py
@@ -158,7 +158,7 @@ def make_empty_intermediate_tensors_factory(keys: List[str], hidden_size: int):
     ) -> IntermediateTensors:
         dtype = get_valid_dtype(dtype)
         return IntermediateTensors(
-            {key: mint.zeros((batch_size, hidden_size), dtype=dtype) for key in keys}
+            {key: mint.zeros((batch_size, 1, hidden_size), dtype=dtype) for key in keys}
         )
 
     return make_empty_intermediate_tensors
diff --git a/vllm_mindspore/worker/worker.py b/vllm_mindspore/worker/worker.py
index 8ce1bc91..e58f8d17 100644
--- a/vllm_mindspore/worker/worker.py
+++ b/vllm_mindspore/worker/worker.py
@@ -31,6 +31,7 @@ from vllm.distributed import (
     init_distributed_environment,
     set_custom_all_reduce,
 )
+from vllm.distributed.parallel_state import get_pp_group
 
 from vllm.logger import init_logger
 
@@ -74,23 +75,30 @@ def _prepare_input_for_warmup(model_config, model_runner, cache_engine, is_prefi
 
 def _warm_up_model(self) -> None:
     # cache_engine is a list with length equal to the size of pipeline-parallel, and only pp=1 is supported.
+    intermediate_tensors = None
+    if not get_pp_group().is_first_rank:
+        intermediate_tensors = self.model_runner.model.make_empty_intermediate_tensors(
+            batch_size=1,
+            dtype=self.model_config.dtype,
+            device=self.device,
+        )
     kv_cache = self.cache_engine[0].gpu_cache
     is_mtp_model = self.speculative_config is not None and self.model_config.hf_config.model_type == "deepseek_mtp"
     if is_mtp_model:
         # prefill mtp model
         model_input, previous_hidden_states = _prepare_input_for_warmup(self.model_config, self.model_runner,
                                                                         self.cache_engine[0], True, is_mtp_model)
-        self.model_runner.execute_model(model_input, kv_cache, None, previous_hidden_states=previous_hidden_states)
+        self.model_runner.execute_model(model_input, kv_cache, intermediate_tensors, previous_hidden_states=previous_hidden_states)
 
     # warmup for decode
     if self.vllm_config.scheduler_config.is_multi_step:
         model_input, _ = _prepare_input_for_warmup(self.model_config, self.model_runner._base_model_runner,
                                                    self.cache_engine[0], False)
-        self.model_runner._base_model_runner.execute_model(model_input, kv_cache, None)
+        self.model_runner._base_model_runner.execute_model(model_input, kv_cache, intermediate_tensors)
     else:
         model_input, previous_hidden_states = _prepare_input_for_warmup(self.model_config, self.model_runner,
                                                                         self.cache_engine[0], False, is_mtp_model)
-        self.model_runner.execute_model(model_input, kv_cache, None, previous_hidden_states=previous_hidden_states)
+        self.model_runner.execute_model(model_input, kv_cache, intermediate_tensors, previous_hidden_states=previous_hidden_states)
 
     torch.cuda.synchronize()
-- 
Gitee
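
Usage note (not part of the patch): a minimal sketch of how the pipeline-parallel path introduced above might be exercised through the vLLM offline API. The plugin import line and the model id are assumptions for illustration; the exact entry point for vllm-mindspore may differ.

    # Hypothetical example. Assumption: importing vllm_mindspore before vLLM
    # activates the MindSpore backend; model id and parallel sizes are illustrative.
    import vllm_mindspore  # assumed plugin import
    from vllm import LLM, SamplingParams

    # pipeline_parallel_size > 1 drives pipeline_stage = get_pp_group().world_size
    # in MfModelBase above; tensor_parallel_size maps to model_parallel.
    llm = LLM(
        model="Qwen/Qwen2.5-7B-Instruct",
        tensor_parallel_size=2,
        pipeline_parallel_size=2,
    )
    outputs = llm.generate(["What is pipeline parallelism?"],
                           SamplingParams(max_tokens=32))
    print(outputs[0].outputs[0].text)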