From 0be490e9ded532670340616593eaaf6cf67317a4 Mon Sep 17 00:00:00 2001
From: zhanzhan1
Date: Tue, 3 Jun 2025 20:17:40 +0800
Subject: [PATCH] support pp with qwen2

---
 .../model_executor/models/model_base.py       |  3 +-
 vllm_mindspore/model_executor/models/qwen2.py | 51 +++++++++++--------
 vllm_mindspore/model_executor/models/utils.py | 11 ++--
 vllm_mindspore/worker/model_runner.py         |  3 +-
 vllm_mindspore/worker/worker.py               | 14 +++--
 5 files changed, 51 insertions(+), 31 deletions(-)

diff --git a/vllm_mindspore/model_executor/models/model_base.py b/vllm_mindspore/model_executor/models/model_base.py
index 961f54a2..40f5fada 100644
--- a/vllm_mindspore/model_executor/models/model_base.py
+++ b/vllm_mindspore/model_executor/models/model_base.py
@@ -190,7 +190,8 @@ class MsModelBase():
         key_cache = []
         value_cache = []
         forward_context = get_forward_context()
-        for i in range(self.config.num_hidden_layers):
+        num_layers = self.model_config.get_num_layers(self.parallel_config)
+        for i in range(num_layers):
             k_cache = self.kv_caches[i].kv_cache[forward_context.virtual_engine][0]
             v_cache = self.kv_caches[i].kv_cache[forward_context.virtual_engine][1]
             key_cache.append(k_cache)
diff --git a/vllm_mindspore/model_executor/models/qwen2.py b/vllm_mindspore/model_executor/models/qwen2.py
index 5eb70a82..a829b5e2 100644
--- a/vllm_mindspore/model_executor/models/qwen2.py
+++ b/vllm_mindspore/model_executor/models/qwen2.py
@@ -59,6 +59,7 @@ from vllm.sequence import IntermediateTensors
 from vllm.attention.backends.abstract import AttentionType
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.attention.backends.abstract import AttentionMetadata
+from vllm.model_executor.models.interfaces import SupportsPP


 class Qwen2MLP(nn.Cell):
@@ -348,7 +349,8 @@ class Qwen2Model(nn.Cell):
         batch_valid_length: Tensor,
         q_seq_lens: Tensor,
         block_tables: Tensor,
-        intermediate_tensors: Optional[IntermediateTensors] = None,
+        hidden_states: Optional[Tensor] = None,
+        residual: Optional[Tensor] = None,
         inputs_embeds: Optional[Tensor] = None,
     ) -> Union[Tensor, IntermediateTensors]:
         if get_pp_group().is_first_rank:
@@ -357,9 +359,6 @@ class Qwen2Model(nn.Cell):
             else:
                 hidden_states = self.get_input_embeddings(input_ids)
             residual = None
-        else:
-            hidden_states = intermediate_tensors["hidden_states"]
-            residual = intermediate_tensors["residual"]
         for i in range(self.start_layer, self.end_layer):  # layers are split across pipeline-parallel (PP) ranks
             layer = self.layers[i]
@@ -377,12 +376,9 @@ class Qwen2Model(nn.Cell):
                 residual
             )
         if not get_pp_group().is_last_rank:
-            return IntermediateTensors({
-                "hidden_states": hidden_states,
-                "residual": residual
-            })
-        hidden_states, _ = self.norm(hidden_states, residual)
-        return hidden_states
+            return hidden_states, residual
+        hidden_states, residual = self.norm(hidden_states, residual)
+        return hidden_states, residual

     def load_weights(self, weights: Iterable[Tuple[str, Tensor]], params_dict: Dict[str, Parameter]):
         loaded_params: Set[str] = set()
@@ -435,7 +431,7 @@ class Qwen2Model(nn.Cell):
         return loaded_params


-class Qwen2ForCausalLM(MsModelBase):
+class Qwen2ForCausalLM(MsModelBase, SupportsPP):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -491,15 +487,17 @@ class Qwen2ForCausalLM(MsModelBase):
         self.prefill = True
         self.mstype = STR_DTYPE_TO_MS_DTYPE.get(self.model_config.dtype,
                                                 self.model_config.dtype)
-        self.casual_mask = LowerTriangularMask(dtype=self.mstype, 
+        self.casual_mask = LowerTriangularMask(dtype=self.mstype,
                                                max_model_len=self.model_config.max_model_len)
         self.set_model_inputs(self.prefill)
-        self.kv_caches = [Fake_Attention() for i in range(config.num_hidden_layers)]
+
+        self.num_layers = self.model_config.get_num_layers(self.parallel_config)
+        self.kv_caches = [Fake_Attention() for i in range(self.num_layers)]
         compilation_config = vllm_config.compilation_config

         if prefix in compilation_config.static_forward_context:
             raise ValueError(f"Duplicate layer name: {prefix}")
-        for i in range(config.num_hidden_layers):
+        for i in range(self.num_layers):
             compilation_config.static_forward_context[str(i)] = self.kv_caches[i]

     def set_model_inputs(self, is_prefill):
@@ -527,7 +525,10 @@ class Qwen2ForCausalLM(MsModelBase):
         dyn_batch_valid_length = Tensor(shape=[None,], dtype=mstype.int32)
         dyn_q_seq_lens = Tensor(shape=[None, ], dtype=mstype.int32)
         dyn_block_tables = Tensor(shape=[None, None], dtype=mstype.int32)
-        dyn_intermediate_tensors = None
+        dyn_hidden_states = Tensor(shape=[None, None, None],
+                                   dtype=self.mstype) if not get_pp_group().is_first_rank else None
+        dyn_residual = Tensor(shape=[None, None, None],
+                              dtype=self.mstype) if not get_pp_group().is_first_rank else None
         dyn_inputs_embeds = None
         self.model.set_inputs(
             dyn_input_ids,
@@ -540,7 +541,8 @@ class Qwen2ForCausalLM(MsModelBase):
             dyn_batch_valid_length,
             dyn_q_seq_lens,
             dyn_block_tables,
-            dyn_intermediate_tensors,
+            dyn_hidden_states,
+            dyn_residual,
             dyn_inputs_embeds
         )

@@ -586,6 +588,10 @@ class Qwen2ForCausalLM(MsModelBase):
         batch_valid_length = Tensor.from_numpy(seq_lens_np)
         q_seq_lens = Tensor.from_numpy(np.array(attn_metadata.query_lens, dtype=np.int32))
         block_tables = attn_metadata.block_tables
+
+        hidden_states = intermediate_tensors["hidden_states"] if intermediate_tensors else None
+        residual = intermediate_tensors["residual"] if intermediate_tensors else None
+
         model_output = self.model(input_ids,
                                   positions,
                                   key_cache,
@@ -596,13 +602,16 @@ class Qwen2ForCausalLM(MsModelBase):
                                   batch_valid_length,
                                   q_seq_lens,
                                   block_tables,
-                                  intermediate_tensors,
+                                  hidden_states,
+                                  residual,
                                   inputs_embeds)
+
+        if not get_pp_group().is_last_rank:
+            return IntermediateTensors({"hidden_states": model_output[0], "residual": model_output[1], })
+
         if is_prefill:
-            model_output = ops.squeeze(model_output, 0)
-        else:
-            model_output = ops.squeeze(model_output, 1)
-        return model_output
+            return ops.squeeze(model_output[0], 0)
+        return ops.squeeze(model_output[0], 1)

     def load_weights(self, weights: Iterable[Tuple[str, Tensor]]) -> Set[str]:
         params_dict = self.get_params_dict()
diff --git a/vllm_mindspore/model_executor/models/utils.py b/vllm_mindspore/model_executor/models/utils.py
index 4bb7831c..49ae816d 100644
--- a/vllm_mindspore/model_executor/models/utils.py
+++ b/vllm_mindspore/model_executor/models/utils.py
@@ -155,10 +155,11 @@ def make_empty_intermediate_tensors_factory(keys: List[str], hidden_size: int):
         batch_size: int,
         dtype,
         device,
+        seq_lens: Optional[int] = 1,
     ) -> IntermediateTensors:
         dtype = get_valid_dtype(dtype)
         return IntermediateTensors(
-            {key: mint.zeros((batch_size, hidden_size), dtype=dtype) for key in keys}
+            {key: mint.zeros((batch_size, seq_lens, hidden_size), dtype=dtype) for key in keys}
         )

     return make_empty_intermediate_tensors
@@ -229,7 +230,7 @@ def merge_multimodal_embeddings(
     Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
     positions in ``inputs_embeds`` corresponding to placeholder tokens in
     ``input_ids``.
-    
+
     ``placeholder_token_id`` can be a list of token ids (e.g, token ids of
     img_start, img_break, and img_end tokens) when needed: This means the order
     of these tokens in the ``input_ids`` MUST MATCH the order of
@@ -242,7 +243,7 @@ def merge_multimodal_embeddings(
     - I is image embedding token
     - B is image break token
     - E is image end token.
-    
+
     Then the image embeddings (that correspond to I's) from vision encoder
     must be padded with embeddings of S, B, and E in the same order of
     input_ids for a correct embedding merge.
@@ -252,7 +253,7 @@ def merge_multimodal_embeddings(
     """
     if isinstance(placeholder_token_id, list):
         placeholder_token_id = ms.Tensor(placeholder_token_id,
-                                        device=input_ids.device)
+                                         device=input_ids.device)
     return _merge_multimodal_embeddings(
         inputs_embeds,
         ms.numpy.isin(input_ids, placeholder_token_id),
@@ -263,4 +264,4 @@ def merge_multimodal_embeddings(
         inputs_embeds,
         (input_ids == placeholder_token_id),
         multimodal_embeddings,
-    )
\ No newline at end of file
+    )
diff --git a/vllm_mindspore/worker/model_runner.py b/vllm_mindspore/worker/model_runner.py
index 561fd202..50f4baff 100644
--- a/vllm_mindspore/worker/model_runner.py
+++ b/vllm_mindspore/worker/model_runner.py
@@ -151,7 +151,8 @@ def _dummy_run(self,
         if not get_pp_group().is_first_rank:
             intermediate_tensors = \
                 self.model.make_empty_intermediate_tensors(
-                    batch_size=batch_size,
+                    batch_size=1,
+                    seq_lens=batch_size,
                     dtype=self.model_config.dtype,
                     device=self.device)

diff --git a/vllm_mindspore/worker/worker.py b/vllm_mindspore/worker/worker.py
index 8ce1bc91..e58f8d17 100644
--- a/vllm_mindspore/worker/worker.py
+++ b/vllm_mindspore/worker/worker.py
@@ -31,6 +31,7 @@ from vllm.distributed import (
     init_distributed_environment,
     set_custom_all_reduce,
 )
+from vllm.distributed.parallel_state import get_pp_group

 from vllm.logger import init_logger

@@ -74,23 +75,30 @@ def _prepare_input_for_warmup(model_config, model_runner, cache_engine, is_prefi

 def _warm_up_model(self) -> None:
     # cache_engine is a list with length equal to the size of pipeline-parallel, and only pp=1 is supported.
+    intermediate_tensors = None
+    if not get_pp_group().is_first_rank:
+        intermediate_tensors = self.model_runner.model.make_empty_intermediate_tensors(
+            batch_size=1,
+            dtype=self.model_config.dtype,
+            device=self.device,
+        )
     kv_cache = self.cache_engine[0].gpu_cache
     is_mtp_model = self.speculative_config is not None and self.model_config.hf_config.model_type == "deepseek_mtp"
     if is_mtp_model:
         # prefill mtp model
         model_input, previous_hidden_states = _prepare_input_for_warmup(self.model_config, self.model_runner, self.cache_engine[0], True, is_mtp_model)
-        self.model_runner.execute_model(model_input, kv_cache, None, previous_hidden_states=previous_hidden_states)
+        self.model_runner.execute_model(model_input, kv_cache, intermediate_tensors, previous_hidden_states=previous_hidden_states)

     # warmup for decode
     if self.vllm_config.scheduler_config.is_multi_step:
         model_input, _ = _prepare_input_for_warmup(self.model_config, self.model_runner._base_model_runner, self.cache_engine[0], False)
-        self.model_runner._base_model_runner.execute_model(model_input, kv_cache, None)
+        self.model_runner._base_model_runner.execute_model(model_input, kv_cache, intermediate_tensors)
     else:
         model_input, previous_hidden_states = _prepare_input_for_warmup(self.model_config, self.model_runner, self.cache_engine[0], False, is_mtp_model)
-        self.model_runner.execute_model(model_input, kv_cache, None, previous_hidden_states=previous_hidden_states)
+        self.model_runner.execute_model(model_input, kv_cache, intermediate_tensors, previous_hidden_states=previous_hidden_states)

     torch.cuda.synchronize()
-- 
Gitee
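
Reviewer note (not part of the patch): below is a minimal sketch of how the pipeline-parallel path this change enables could be exercised with vLLM's offline API. The checkpoint name, parallel sizes, and executor backend are illustrative assumptions and may need adjusting to the local vLLM / vllm-mindspore versions; vllm-mindspore must be installed so it can patch vLLM at import time.

    # Offline-inference sketch: run a Qwen2 checkpoint across two pipeline stages.
    # With pipeline_parallel_size > 1, each rank builds only its own slice of layers
    # (model_config.get_num_layers(parallel_config)); non-first ranks receive their
    # inputs as IntermediateTensors carrying "hidden_states" and "residual".
    from vllm import LLM, SamplingParams

    llm = LLM(
        model="Qwen/Qwen2-7B-Instruct",      # placeholder checkpoint
        tensor_parallel_size=1,
        pipeline_parallel_size=2,            # two PP stages
        distributed_executor_backend="ray",  # PP in vLLM typically runs on the Ray backend
    )
    outputs = llm.generate(["An example prompt"], SamplingParams(max_tokens=32))
    print(outputs[0].outputs[0].text)

The new (batch_size, seq_lens, hidden_size) shape in make_empty_intermediate_tensors matches this flow: during warm-up and dummy runs, non-first ranks now allocate placeholders of shape (1, num_tokens, hidden_size) instead of (num_tokens, hidden_size), which is why _dummy_run passes batch_size=1 and seq_lens=batch_size.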