From a1ed8e38f7933a4bc49da91642a9cb9f90268dad Mon Sep 17 00:00:00 2001 From: zhanzhan1 Date: Tue, 22 Jul 2025 10:04:04 +0800 Subject: [PATCH] support pp develop --- .../python/cases_parallel/vllm_qwen_7b_v1.py | 34 ++++++++++ tests/st/python/test_cases_parallel.py | 4 +- vllm_mindspore/__init__.py | 5 ++ .../models/mf_models/deepseek_v3.py | 26 +++++--- .../mf_models/deepseekv3_weight_processor.py | 62 ++++++++++------- .../models/mf_models/mf_model_base.py | 49 ++++++++++++-- .../models/mf_models/weight_processor.py | 66 +++++++++++++++---- .../model_executor/models/model_base.py | 41 ++++++++---- vllm_mindspore/model_executor/models/qwen2.py | 32 ++++----- vllm_mindspore/utils.py | 31 +++++++++ vllm_mindspore/v1/worker/gpu_worker.py | 4 ++ vllm_mindspore/worker/worker.py | 18 ++++- 12 files changed, 291 insertions(+), 81 deletions(-) diff --git a/tests/st/python/cases_parallel/vllm_qwen_7b_v1.py b/tests/st/python/cases_parallel/vllm_qwen_7b_v1.py index dd7b0a1c..0ec205ff 100644 --- a/tests/st/python/cases_parallel/vllm_qwen_7b_v1.py +++ b/tests/st/python/cases_parallel/vllm_qwen_7b_v1.py @@ -74,3 +74,37 @@ def test_vllm_qwen(): # unset env env_manager.unset_all() + +def test_vllm_qwen_pp2(): + """ + test case qwen2.5 7B with pipeline parallel + """ + + # Sample prompts. + prompts = [ + "You are a helpful assistant.<|User|>将文本分类为中性、负面或正面。" + " \n文本:我认为这次假期还可以。 \n情感:<|Assistant|>\n", + ] + + # Create a sampling params object. + sampling_params = SamplingParams(temperature=0.0, max_tokens=10, top_k=1) + + # Create an LLM. + llm = LLM( + model="/home/workspace/mindspore_dataset/weight/Qwen2.5-7B-Instruct", + gpu_memory_utilization=0.9, + pipeline_parallel_size=2, + distributed_executor_backend='ray') + # Generate texts from the prompts. The output is a list of RequestOutput + # objects that contain the prompt, generated text, and other information. + outputs = llm.generate(prompts, sampling_params) + except_list = ['中性<|Assistant|> 这句话'] + # Print the outputs. 
+ for i, output in enumerate(outputs): + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + assert generated_text == except_list[i] + + # unset env + env_manager.unset_all() diff --git a/tests/st/python/test_cases_parallel.py b/tests/st/python/test_cases_parallel.py index 4bca8b68..f8c73c06 100644 --- a/tests/st/python/test_cases_parallel.py +++ b/tests/st/python/test_cases_parallel.py @@ -67,7 +67,9 @@ def test_cases_parallel_part0(): "vllm_mf_qwen_7b_chunk_prefill_test_mf_qwen_7b_chunk_prefill.log"), (2, "cases_parallel/vllm_mf_qwen_7b_chunk_prefill_v1.py" "::test_mf_qwen_7b_chunk_prefill", - "vllm_mf_qwen_7b_chunk_prefill_v1_test_mf_qwen_7b_chunk_prefill.log") + "vllm_mf_qwen_7b_chunk_prefill_v1_test_mf_qwen_7b_chunk_prefill.log"), + (2, "cases_parallel/vllm_qwen_7b_v1.py::test_vllm_qwen_pp2", + "vllm_qwen_7b_v1_test_vllm_qwen_pp2.log"), ] run_tasks(cases) diff --git a/vllm_mindspore/__init__.py b/vllm_mindspore/__init__.py index a4e04937..b29e4012 100644 --- a/vllm_mindspore/__init__.py +++ b/vllm_mindspore/__init__.py @@ -476,3 +476,8 @@ from vllm_mindspore.entrypoints.__main__ import ( patch_server_run_api_server_worker_proc() check_ready() + +from vllm_mindspore.utils import view +from mindspore import Tensor + +Tensor.view = view diff --git a/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py b/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py index dfc8be5d..5fc1067a 100644 --- a/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py +++ b/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py @@ -40,9 +40,10 @@ from research.deepseek3.deepseek3_config import (DeepseekV3Config as from research.deepseek3.deepseek3_model_infer import DeepseekV3DecodeLayer from vllm.config import VllmConfig, get_current_vllm_config from vllm.distributed.parallel_state import ( - get_dp_group, get_tensor_model_parallel_world_size) + get_dp_group, get_pp_group, get_tensor_model_parallel_world_size) from vllm.forward_context import get_forward_context from vllm.logger import init_logger +from vllm.model_executor.models.interfaces import SupportsPP from vllm_mindspore.model_executor.layers.sampler import get_sampler from vllm_mindspore.model_executor.models.attention_mask import ( @@ -55,6 +56,8 @@ from vllm_mindspore.model_executor.models.mf_models \ from vllm_mindspore.model_executor.models.mf_models.mf_model_base import ( MfModelBase) from vllm_mindspore.model_executor.models.model_base import MLAAttentionWrapper +from vllm_mindspore.model_executor.models.utils import ( + make_empty_intermediate_tensors_factory) with contextlib.suppress(ImportError): # Need to apply dllm pd patch on vllm to use pd disagg related functions @@ -122,7 +125,7 @@ def _get_padding_index(q_seq_len): ms.from_numpy(ffn_padding_idx), ms.from_numpy(ffn_unpadding_idx) -class DeepseekV3ForCausalLM(MfModelBase): +class DeepseekV3ForCausalLM(MfModelBase, SupportsPP): def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: super().__init__(vllm_config=vllm_config, prefix=prefix) @@ -134,15 +137,16 @@ class DeepseekV3ForCausalLM(MfModelBase): self.sampler = get_sampler() self.set_modules({"model": self.network}) + self.num_layers = self.model_config.get_num_layers( + self.parallel_config) self.kv_caches = [ - MLAAttentionWrapper() - for i in range(self.mf_model_config.num_layers) + MLAAttentionWrapper() for _ in range(self.num_layers) ] compilation_config = 
get_current_vllm_config().compilation_config if prefix in compilation_config.static_forward_context: raise ValueError(f"Duplicate layer name: {prefix}") - for i in range(self.mf_model_config.num_layers): + for i in range(self.num_layers): compilation_config.static_forward_context[str( i)] = self.kv_caches[i] @@ -151,6 +155,10 @@ class DeepseekV3ForCausalLM(MfModelBase): self.casual_mask = MLALowerTriangularMask( dtype=self.mf_model_config.compute_dtype, max_model_len=self.model_config.max_model_len) + self.make_empty_intermediate_tensors = \ + make_empty_intermediate_tensors_factory( + keys=["hidden_states"], + hidden_size=self.mf_model_config.hidden_size) def _generate_model_config(self): self.mf_config.load_checkpoint = self.get_model_path() @@ -181,12 +189,14 @@ class DeepseekV3ForCausalLM(MfModelBase): if ptq is not None: ptq.apply(network) ptq.convert(network) - return network, network.lm_head + if get_pp_group().is_last_rank: + return network, network.lm_head + return network, None def get_kvcache(self): key_cache = [] forward_context = get_forward_context() - for i in range(self.mf_model_config.num_layers): + for i in range(self.num_layers): k_cache = self.kv_caches[i].kv_cache[ forward_context.virtual_engine][0] key_cache.append(k_cache) @@ -217,7 +227,7 @@ class DeepseekV3ForCausalLM(MfModelBase): do_predict=True) else: weight_processor = DeepseekV3WeightProcessor( - self.mf_config, self.network, self.is_quant) + self.mf_config, self.network, self.is_quant, self.vllm_config) weight_processor.load_safetensors_shard( self.mf_config.load_checkpoint) return None # type: ignore[return-value] diff --git a/vllm_mindspore/model_executor/models/mf_models/deepseekv3_weight_processor.py b/vllm_mindspore/model_executor/models/mf_models/deepseekv3_weight_processor.py index b344eefe..776842ca 100644 --- a/vllm_mindspore/model_executor/models/mf_models/deepseekv3_weight_processor.py +++ b/vllm_mindspore/model_executor/models/mf_models/deepseekv3_weight_processor.py @@ -28,6 +28,7 @@ from mindformers.parallel_core.inference.parallel_state import ( from mindspore import dtype from mindspore.communication.management import get_rank from tqdm import tqdm +from vllm.distributed import get_pp_group, get_pp_indices from vllm.logger import init_logger from vllm_mindspore.model_executor.models.mf_models.weight_processor import ( @@ -62,9 +63,14 @@ class DeepseekV3WeightProcessor(BaseWeightProcessor): """ - def __init__(self, config, network, is_quant): - super().__init__(config, network, is_quant) - self.num_layers = self.config.model.model_config.num_layers + def __init__(self, config, network, is_quant, vllm_config): + super().__init__(config, network, is_quant, vllm_config) + self.num_layers = self.vllm_config.model_config.get_num_layers( + self.vllm_config.parallel_config) + self.start_layer, self.end_layer = get_pp_indices( + self.config.model.model_config.num_layers, + get_pp_group().rank_in_group, + get_pp_group().world_size) self.expert_num = self.config.moe_config.expert_num self.moe_split_tp = self.moe_tp_size > 1 self.moe_split_ep = self.moe_ep_size > 1 @@ -422,20 +428,20 @@ class DeepseekV3WeightProcessor(BaseWeightProcessor): w2_scale_hf_name, w3_scale_hf_name, src_hf_dir, hf_weight_map): if self.ep_method in [EPMethod.DEFAULT, EPMethod.ALLGATHER]: - w1_ms_param, _ = self.get_safetensor_from_file_split_global_group( + w1_ms_param, _ = self.get_safetensor_from_file_split_tp_dp_group( w1_hf_name, src_hf_dir, hf_weight_map, split_axis=0) - w2_ms_param, _ = 
self.get_safetensor_from_file_split_global_group( + w2_ms_param, _ = self.get_safetensor_from_file_split_tp_dp_group( w2_hf_name, src_hf_dir, hf_weight_map, split_axis=1) - w3_ms_param, _ = self.get_safetensor_from_file_split_global_group( + w3_ms_param, _ = self.get_safetensor_from_file_split_tp_dp_group( w3_hf_name, src_hf_dir, hf_weight_map, split_axis=0) w1_scale_ms_param, _ = ( - self.get_safetensor_from_file_split_global_group( + self.get_safetensor_from_file_split_tp_dp_group( w1_scale_hf_name, src_hf_dir, hf_weight_map, split_axis=0)) w2_scale_ms_param, _ = self.get_safetensor_from_file( w2_scale_hf_name, src_hf_dir, hf_weight_map) w3_scale_ms_param, _ = ( - self.get_safetensor_from_file_split_global_group( + self.get_safetensor_from_file_split_tp_dp_group( w3_scale_hf_name, src_hf_dir, hf_weight_map, split_axis=0)) elif self.ep_method == EPMethod.ALLTOALL: w1_ms_param, _ = self.get_safetensor_from_file( @@ -1135,7 +1141,7 @@ class DeepseekV3WeightProcessor(BaseWeightProcessor): def convert_mtp_weight_name(self, weight_name: str): layer = 0 if 'layers.' not in weight_name else int( weight_name[weight_name.find('layers.'):].split('.')[1]) - if layer < self.num_layers: + if self.start_layer <= layer < self.end_layer: return weight_name mtp_prefix = 'mtp_model' is_mtp_layer = ('tok_embeddings' not in weight_name @@ -1186,14 +1192,17 @@ class DeepseekV3WeightProcessor(BaseWeightProcessor): base_ms_name = f"model.layers.{layer_id}.feed_forward.routed_experts" w1_ms_name = f"{base_ms_name}.ffn.w1.weight" - w1_ms_name = (w1_ms_name if layer_id < self.num_layers else - self.convert_mtp_weight_name(w1_ms_name)) + w1_ms_name = ( + w1_ms_name if self.start_layer <= layer_id < self.end_layer + else self.convert_mtp_weight_name(w1_ms_name)) w2_ms_name = f"{base_ms_name}.ffn.w2.weight" - w2_ms_name = (w2_ms_name if layer_id < self.num_layers else - self.convert_mtp_weight_name(w2_ms_name)) + w2_ms_name = ( + w2_ms_name if self.start_layer <= layer_id < self.end_layer + else self.convert_mtp_weight_name(w2_ms_name)) w3_ms_name = f"{base_ms_name}.ffn.w3.weight" - w3_ms_name = (w3_ms_name if layer_id < self.num_layers else - self.convert_mtp_weight_name(w3_ms_name)) + w3_ms_name = ( + w3_ms_name if self.start_layer <= layer_id < self.end_layer + else self.convert_mtp_weight_name(w3_ms_name)) for index in range(0, self.num_router_experts): base_hf_name = f"model.layers.{layer_id}.mlp.experts.{index}" @@ -1221,7 +1230,8 @@ class DeepseekV3WeightProcessor(BaseWeightProcessor): base_path = f"model.layers.{layer_id}.feed_forward.routed_experts" w_gate_hidden_name = f"{base_path}.ffn.w_gate_hidden.weight" w_gate_hidden_name = ( - w_gate_hidden_name if layer_id < self.num_layers else + w_gate_hidden_name + if self.start_layer <= layer_id < self.end_layer else self.convert_mtp_weight_name(w_gate_hidden_name)) w_gate_hidden_np = np.concatenate( [w1_ms_stack_param, w3_ms_stack_param], axis=1) @@ -1253,11 +1263,11 @@ class DeepseekV3WeightProcessor(BaseWeightProcessor): def get_moe_shared_expert_weight(self, w1_hf_name, w2_hf_name, w3_hf_name, src_hf_dir, hf_weight_map): if self.ep_method in [EPMethod.DEFAULT, EPMethod.ALLGATHER]: - w1_ms_param, _ = self.get_safetensor_from_file_split_global_group( + w1_ms_param, _ = self.get_safetensor_from_file_split_tp_dp_group( w1_hf_name, src_hf_dir, hf_weight_map, split_axis=0) - w2_ms_param, _ = self.get_safetensor_from_file_split_global_group( + w2_ms_param, _ = self.get_safetensor_from_file_split_tp_dp_group( w2_hf_name, src_hf_dir, hf_weight_map, split_axis=1) - 
w3_ms_param, _ = self.get_safetensor_from_file_split_global_group( + w3_ms_param, _ = self.get_safetensor_from_file_split_tp_dp_group( w3_hf_name, src_hf_dir, hf_weight_map, split_axis=0) elif self.ep_method == EPMethod.ALLTOALL: w1_ms_param, _ = self.get_safetensor_from_file( @@ -1293,7 +1303,8 @@ class DeepseekV3WeightProcessor(BaseWeightProcessor): base_path = f"model.layers.{layer_id}.feed_forward.shared_experts" w_gate_hidden_name = f"{base_path}.w_gate_hidden.weight" w_gate_hidden_name = ( - w_gate_hidden_name if layer_id < self.num_layers else + w_gate_hidden_name + if self.start_layer <= layer_id < self.end_layer else self.convert_mtp_weight_name(w_gate_hidden_name)) w_gate_hidden_np = np.concatenate([w1_ms_param, w3_ms_param], axis=0) @@ -1554,7 +1565,7 @@ class DeepseekV3WeightProcessor(BaseWeightProcessor): self.infer_process_norm_weight(src_hf_dir, layer_id, hf_weight_map) # convert mtp shared weights. - if layer_id >= self.num_layers: + if layer_id >= self.end_layer: self.infer_process_mtp_layer_weight(src_hf_dir, layer_id, hf_weight_map) @@ -2223,10 +2234,11 @@ class DeepseekV3WeightProcessor(BaseWeightProcessor): enable_tqdm = rank_id == 0 mtp_layers = self.config.model.model_config.num_nextn_predict_layers - start_layer = 0 if not is_mtp_model else self.num_layers - end_layer = (self.num_layers if not is_mtp_model else self.num_layers + - mtp_layers) - for layer_id in tqdm(range(start_layer, end_layer), + self.start_layer = ( + self.start_layer if not is_mtp_model else self.end_layer) + self.end_layer = (self.end_layer + if not is_mtp_model else self.end_layer + mtp_layers) + for layer_id in tqdm(range(self.start_layer, self.end_layer), desc="Weight loading", disable=not enable_tqdm): if self.is_quant: diff --git a/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py b/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py index 38e97956..fa59893e 100644 --- a/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py +++ b/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py @@ -20,6 +20,7 @@ from collections.abc import Iterable from typing import Optional, Union import mindspore as ms +import numpy as np from mindformers.core.context import build_mf_context from mindformers.core.parallel_config import build_parallel_config from mindformers.tools.register.config import MindFormerConfig @@ -27,9 +28,10 @@ from mindformers.tools.utils import is_pynative from mindspore import Tensor, nn from mindspore.common.api import _pynative_executor from mindspore.communication import get_rank +from vllm import envs from vllm.config import VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.distributed.parallel_state import get_dp_group +from vllm.distributed.parallel_state import get_dp_group, get_pp_group from vllm.forward_context import get_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.sampler import SamplerOutput @@ -73,7 +75,10 @@ class MfModelBase(MsModelBase): self.mf_config.parallel_config) self.mf_config.model.model_config.parallel_config.model_parallel = ( get_tensor_model_parallel_world_size()) - self.mf_config.model.model_config.parallel_config.pipeline_stage = 1 + self.mf_config.model.model_config.parallel_config.pipeline_stage = ( + get_pp_group().world_size) + self.mf_config.model.model_config.offset = self.get_mf_offset( + self.mf_config.model.model_config) self._generate_model_config() if not hasattr(self, 'mf_model_config'): raise RuntimeError('mf_model_config 
not initialized') @@ -123,7 +128,8 @@ class MfModelBase(MsModelBase): raise RuntimeError('mf_model_config not initialized') dynamic_hidden_states = Tensor( shape=[None, None], dtype=self.mf_model_config.compute_dtype) - self.ready_lm_head.set_inputs(dynamic_hidden_states) + if get_pp_group().is_last_rank: + self.ready_lm_head.set_inputs(dynamic_hidden_states) def prepare_inputs(self, input_ids, positions): return self.prepare_base_inputs(input_ids, positions) @@ -150,6 +156,34 @@ class MfModelBase(MsModelBase): for i in range(self.mf_model_config.num_layers): wait_for_kv_layer_from_connector("key." + str(i)) + def get_mf_offset(self, model_config): + """ get pp offset from vllm style""" + partition_list_str = envs.VLLM_PP_LAYER_PARTITION + num_layers = model_config.num_layers + pp_size = model_config.parallel_config.pipeline_stage + if partition_list_str is not None: + try: + partitions = [ + int(layer) for layer in partition_list_str.split(",") + ] + except ValueError as err: + raise ValueError("Invalid partition string: {}".format( + partition_list_str)) from err + if len(partitions) != pp_size: + raise ValueError(f"{len(partitions)=} does not match {pp_size=}.") + if sum(partitions) != num_layers: + raise ValueError( + f"{sum(partitions)=} does not match {num_layers=}.") + partitions = np.array(partitions, dtype=np.int32) + avg_layers = num_layers // pp_size + avg_layers_list = np.ones((pp_size, ), dtype=np.int32) * avg_layers + if (partitions == avg_layers_list).all(): + return 0 + else: + return (partitions - avg_layers_list).tolist() + else: + return 0 + def forward(self, input_ids: Tensor, positions: Tensor, @@ -158,6 +192,10 @@ class MfModelBase(MsModelBase): **kwargs) -> Union[Tensor, IntermediateTensors]: model_inputs, is_prefill = self.prepare_inputs(input_ids, positions) model_inputs = self.update_model_inputs(model_inputs, **kwargs) + model_inputs["hidden_states"] = None + if intermediate_tensors is not None: + model_inputs["hidden_states"] = intermediate_tensors["hidden_states"] + if is_prefill: self.network.phase = "prefill" @@ -179,7 +217,10 @@ class MfModelBase(MsModelBase): self.connector_wait_for_kv_layer() logger.debug("connector_wait_for_kv_layer success") hidden_states = self.network(**model_inputs) - + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + }) return hidden_states def compute_logits( diff --git a/vllm_mindspore/model_executor/models/mf_models/weight_processor.py b/vllm_mindspore/model_executor/models/mf_models/weight_processor.py index d60506fb..1b687743 100644 --- a/vllm_mindspore/model_executor/models/mf_models/weight_processor.py +++ b/vllm_mindspore/model_executor/models/mf_models/weight_processor.py @@ -21,8 +21,11 @@ import os from enum import Enum from mindformers.parallel_core.inference.parallel_state import ( - get_data_parallel_world_size) -from mindformers.parallel_core.inference.utils import get_tp_world_size + get_data_parallel_world_size, get_moe_expert_parallel_rank, + get_moe_expert_parallel_world_size, get_moe_tensor_parallel_rank, + get_moe_tensor_parallel_world_size, get_tensor_and_data_parallel_rank, + get_tensor_and_data_parallel_world_size, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) from mindspore.communication.management import get_group_size, get_rank from safetensors import safe_open @@ -45,30 +48,32 @@ class BaseWeightProcessor: """ - def __init__(self, config, network, is_quant): + def __init__(self, config, network, is_quant, vllm_config): + 
self.vllm_config = vllm_config self.config = config self.network = network self.is_quant = is_quant self.global_rank_id = get_rank() self.global_group_size = get_group_size() - self.tp_group_size = get_tp_world_size() + self.tp_group_size = get_tensor_model_parallel_world_size() self.dp_group_size = get_data_parallel_world_size() + self.tp_dp_group_size = get_tensor_and_data_parallel_world_size() + self.tp_dp_group_id = get_tensor_and_data_parallel_rank() self.num_router_experts = self.config.moe_config.expert_num if \ self.config.moe_config.expert_num else 1 - self.moe_ep_size = self.config.parallel_config.expert_parallel \ - if self.config.parallel_config.expert_parallel else 1 - self.moe_tp_size = self.global_group_size // self.moe_ep_size + self.moe_ep_size = get_moe_expert_parallel_world_size() + self.moe_tp_size = get_moe_tensor_parallel_world_size() self.ep_method = EPMethod.DEFAULT if self.dp_group_size > 1\ - and self.moe_ep_size == self.global_group_size: + and self.moe_ep_size == self.tp_dp_group_size: self.ep_method = EPMethod.ALLTOALL elif self.dp_group_size > 1: self.ep_method = EPMethod.ALLGATHER - self.tp_rank_id = self.global_rank_id % self.tp_group_size + self.tp_rank_id = get_tensor_model_parallel_rank() self.ep_group_nums = self.num_router_experts // self.moe_ep_size - self.moe_ep_rank_id = self.global_rank_id // self.moe_tp_size - self.moe_tp_rank_id = self.global_rank_id % self.moe_tp_size + self.moe_ep_rank_id = get_moe_expert_parallel_rank() + self.moe_tp_rank_id = get_moe_tensor_parallel_rank() self.ep_start = self.moe_ep_rank_id * self.ep_group_nums self.ep_stop = (self.moe_ep_rank_id + 1) * self.ep_group_nums @@ -90,8 +95,8 @@ class BaseWeightProcessor: filename = os.path.join(src_hf_dir, safetensor_file) sf_file = self.get_file_handles(filename) qint4 = False - if sf_file.metadata( - ) is not None and hf_param_name in sf_file.metadata(): + if sf_file.metadata() is not None and hf_param_name in sf_file.metadata( + ).keys(): qint4 = True np_data = sf_file.get_tensor(hf_param_name) @@ -197,6 +202,41 @@ class BaseWeightProcessor: "split_axis:{} is not supported.".format(split_axis)) return split_data, qint4 + def get_safetensor_from_file_split_tp_dp_group(self, + hf_param_name, + src_hf_dir, + hf_weight_map, + split_axis=0): + safetensor_file = hf_weight_map[hf_param_name] + filename = os.path.join(src_hf_dir, safetensor_file) + sf_file = self.get_file_handles(filename) + qint4 = False + if sf_file.metadata( + ) is not None and hf_param_name in sf_file.metadata().keys(): + qint4 = True + + np_data = sf_file.get_slice(hf_param_name) + shape = np_data.get_shape() + if split_axis == 0: + split_size = shape[0] // self.tp_dp_group_size + start = self.tp_dp_group_id * split_size + stop = (self.tp_dp_group_id + 1) * split_size + split_data = np_data[start:stop] + elif split_axis == 1: + split_size = shape[1] // self.tp_dp_group_size + start = self.tp_dp_group_id * split_size + stop = (self.tp_dp_group_id + 1) * split_size + split_data = np_data[:, start:stop] + elif split_axis == 2: + split_size = shape[2] // self.tp_dp_group_size + start = self.tp_dp_group_id * split_size + stop = (self.tp_dp_group_id + 1) * split_size + split_data = np_data[:, :, start:stop] + else: + raise ValueError( + "split_axis:{} is not supported.".format(split_axis)) + return split_data, qint4 + def get_routed_safetensor_3_dim(self, hf_param_name, src_hf_dir, diff --git a/vllm_mindspore/model_executor/models/model_base.py b/vllm_mindspore/model_executor/models/model_base.py index 
38199ae5..1bcf5b32 100644 --- a/vllm_mindspore/model_executor/models/model_base.py +++ b/vllm_mindspore/model_executor/models/model_base.py @@ -80,6 +80,7 @@ class MsModelBase: def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: super().__init__() + self.vllm_config = vllm_config config = vllm_config.model_config.hf_config lora_config = vllm_config.lora_config @@ -216,7 +217,9 @@ class MsModelBase: key_cache = [] value_cache = [] forward_context = get_forward_context() - for i in range(self.config.num_hidden_layers): + num_hidden_layers = self.model_config.get_num_layers( + self.parallel_config) + for i in range(num_hidden_layers): k_cache = self.kv_caches[i].kv_cache[ forward_context.virtual_engine][0] v_cache = self.kv_caches[i].kv_cache[ @@ -350,18 +353,18 @@ class NativeModel(MsModelBase): self.casual_mask = LowerTriangularMask( dtype=self.model_config.dtype, max_model_len=self.model_config.max_model_len) - self.kv_caches = [ - AttentionWrapper() for i in range(self.config.num_hidden_layers) - ] + num_hidden_layers = self.model_config.get_num_layers( + self.parallel_config) + self.kv_caches = [AttentionWrapper() for _ in range(num_hidden_layers)] compilation_config = vllm_config.compilation_config if prefix in compilation_config.static_forward_context: raise ValueError(f"Duplicate layer name: {prefix}") - for i in range(self.config.num_hidden_layers): + for i in range(num_hidden_layers): compilation_config.static_forward_context[str( i)] = self.kv_caches[i] - def set_model_inputs(self, is_prefill): + def set_model_inputs(self, is_prefill, intermediate_tensors): dyn_input_ids = Tensor(shape=[None], dtype=mstype.int32) dyn_position_ids = Tensor(shape=[None], dtype=mstype.int32) @@ -396,21 +399,36 @@ class NativeModel(MsModelBase): None, ], dtype=mstype.int32) dyn_block_tables = Tensor(shape=[None, None], dtype=mstype.int32) - dyn_intermediate_tensors = None + if intermediate_tensors is None: + dyn_hidden_states = None + dyn_residual = None + else: + dyn_hidden_states = Tensor(shape=[None, None], + dtype=self.model_config.dtype) + dyn_residual = Tensor(shape=[None, None], + dtype=self.model_config.dtype) dyn_inputs_embeds = None self.ready_model.set_inputs( dyn_input_ids, dyn_position_ids, dyn_key_caches, dyn_value_caches, is_prefill, dyn_slot_mapping, dynamic_attention_mask, dyn_batch_valid_length, dyn_q_seq_lens, dyn_block_tables, - dyn_intermediate_tensors, dyn_inputs_embeds) + dyn_hidden_states, dyn_residual, dyn_inputs_embeds) def prepare_inputs(self, input_ids, positions, intermediate_tensors, inputs_embeds): model_inputs, is_prefill = self.prepare_base_inputs( input_ids, positions) + # for pp + if intermediate_tensors is None: + model_inputs["hidden_states"] = None + model_inputs["residual"] = None + else: + model_inputs["hidden_states"] = intermediate_tensors[ + "hidden_states"] + model_inputs["residual"] = intermediate_tensors["residual"] + # for multimodal model - model_inputs["intermediate_tensors"] = intermediate_tensors model_inputs["inputs_embeds"] = inputs_embeds return model_inputs, is_prefill @@ -426,7 +444,7 @@ class NativeModel(MsModelBase): inputs_embeds) if self.prev_prefill != is_prefill and self.is_graph_mode: - self.set_model_inputs(is_prefill) + self.set_model_inputs(is_prefill, intermediate_tensors) self.prev_prefill = is_prefill # for dummy_attention_metadata @@ -450,7 +468,8 @@ class NativeModel(MsModelBase): batch_valid_length=model_inputs["batch_valid_length"], q_seq_lens=model_inputs["q_seq_lens"], 
block_tables=model_inputs["block_tables"], - intermediate_tensors=model_inputs["intermediate_tensors"], + hidden_states=model_inputs["hidden_states"], + residual=model_inputs["residual"], inputs_embeds=model_inputs["inputs_embeds"], ) # type: ignore[misc] diff --git a/vllm_mindspore/model_executor/models/qwen2.py b/vllm_mindspore/model_executor/models/qwen2.py index e1163ba3..02426df1 100644 --- a/vllm_mindspore/model_executor/models/qwen2.py +++ b/vllm_mindspore/model_executor/models/qwen2.py @@ -41,7 +41,7 @@ from vllm.attention.backends.abstract import AttentionType from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.models.interfaces import SupportsLoRA +from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP from vllm.sequence import IntermediateTensors from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -333,9 +333,10 @@ class Qwen2Model(nn.Cell): batch_valid_length: Tensor, q_seq_lens: Tensor, block_tables: Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, + hidden_states: Optional[Tensor] = None, + residual: Optional[Tensor] = None, inputs_embeds: Optional[Tensor] = None, - ) -> Union[Tensor, IntermediateTensors]: + ) -> tuple[Tensor, Tensor]: if get_pp_group().is_first_rank: if inputs_embeds is not None: @@ -343,9 +344,6 @@ class Qwen2Model(nn.Cell): else: hidden_states = self.get_input_embeddings(input_ids) residual = None - else: - hidden_states = intermediate_tensors["hidden_states"] - residual = intermediate_tensors["residual"] for i in range(self.start_layer, self.end_layer): # PP 并行对层进行切分 layer = self.layers[i] @@ -355,13 +353,9 @@ class Qwen2Model(nn.Cell): is_prefill, slot_mapping, attn_mask, batch_valid_length, q_seq_lens, block_tables, residual) - if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) - hidden_states, _ = self.norm(hidden_states, residual) - return hidden_states + if get_pp_group().is_last_rank: + hidden_states, residual = self.norm(hidden_states, residual) + return hidden_states, residual def load_weights(self, weights: Iterable[tuple[str, Tensor]], params_dict: dict[str, Parameter]): @@ -415,7 +409,7 @@ class Qwen2Model(nn.Cell): return loaded_params -class Qwen2ForCausalLM(NativeModel, SupportsLoRA): +class Qwen2ForCausalLM(NativeModel, SupportsLoRA, SupportsPP): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -477,8 +471,14 @@ class Qwen2ForCausalLM(NativeModel, SupportsLoRA): intermediate_tensors: IntermediateTensors = None, inputs_embeds: Tensor = None, **kwargs) -> Union[Tensor, IntermediateTensors]: - hidden_states = self.exec_model(input_ids, positions, - intermediate_tensors, inputs_embeds) + hidden_states, residual = self.exec_model(input_ids, positions, + intermediate_tensors, + inputs_embeds) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual, + }) return hidden_states def load_weights(self, weights: Iterable[tuple[str, Tensor]]) -> set[str]: diff --git a/vllm_mindspore/utils.py b/vllm_mindspore/utils.py index b2615b46..0f8d0f17 100644 --- a/vllm_mindspore/utils.py +++ b/vllm_mindspore/utils.py @@ -35,6 +35,7 @@ else: import mindspore as ms from mindspore import dtype as mstype +from mindspore._c_expression import typing from mindspore.common.initializer import 
Zero from vllm.logger import init_logger from vllm.utils import (TORCH_DTYPE_TO_NUMPY_DTYPE, MemoryProfilingResult, @@ -321,3 +322,33 @@ def ms_memory_profiling( result.non_torch_increase = diff_from_create.non_torch_memory result.profile_time = diff_profile.timestamp result.non_kv_cache_memory = result.non_torch_increase + result.torch_peak_increase + result.weights_memory # noqa + + +def view(self, *shape_or_dtype): + if len(shape_or_dtype) == 1 and isinstance(shape_or_dtype[0], typing.Type): + target_dtype = shape_or_dtype[0] + ori_shape = self.shape + target_shape = (-1, ) + if len(ori_shape) > 1: + target_shape = ori_shape[:-1] + target_shape + out = np.frombuffer( + self.numpy(), + torch.ops.creation._TypeDict.get(target_dtype, np.float32)) + if not out.flags.aligned: + out = np.require(out, requirements=["ALIGNED"]) + if target_dtype == ms.bfloat16: + return ms.Tensor.from_numpy(out.astype( + np.float32)).astype(target_dtype).reshape(target_shape) + return ms.Tensor.from_numpy(out).reshape(target_shape) + result = [] + if type(shape_or_dtype) is tuple: + for items in shape_or_dtype: + if not isinstance(items, int): + for item in items: + if not isinstance(item, int): + result.append(item.item()) + else: + result.append(item) + else: + result.append(items) + return ms.ops.reshape(self, result) diff --git a/vllm_mindspore/v1/worker/gpu_worker.py b/vllm_mindspore/v1/worker/gpu_worker.py index ab3bd566..db92c182 100644 --- a/vllm_mindspore/v1/worker/gpu_worker.py +++ b/vllm_mindspore/v1/worker/gpu_worker.py @@ -71,6 +71,10 @@ def compile_or_warm_up_model(self) -> None: # MindSpore does not support cuda graph. No need to warm up the model. # Since prefill is done previously, we do decode here. default_max_num_reqs = 1 # For MindSpore, we only do one more decode here. + # Only pp_last_rank requires _dummy_sampler_run, + # and only pp_last_rank can _dummy_sampler_run. 
if get_pp_group().is_last_rank: self.model_runner._dummy_sampler_run( self.model_runner._dummy_run(num_tokens=default_max_num_reqs)) + else: + self.model_runner._dummy_run(num_tokens=default_max_num_reqs) diff --git a/vllm_mindspore/worker/worker.py b/vllm_mindspore/worker/worker.py index ed2b636a..f9e38ff6 100644 --- a/vllm_mindspore/worker/worker.py +++ b/vllm_mindspore/worker/worker.py @@ -18,6 +18,7 @@ import math import torch +from vllm.distributed import get_pp_group from vllm.logger import init_logger from vllm.model_executor import set_random_seed from vllm.sampling_params import SamplingParams @@ -72,6 +73,17 @@ def _warm_up_model(self) -> None: kv_cache = self.cache_engine[0].gpu_cache is_mtp_model = self.speculative_config is not None and \ self.model_config.hf_config.model_type == "deepseek_mtp" + + def get_model(cls): + if cls.vllm_config.scheduler_config.is_multi_step: + return cls.model_runner._base_model_runner.model + return cls.model_runner.model + + intermediate_tensors = None + model = get_model(self) + if not get_pp_group().is_first_rank: + intermediate_tensors = model.make_empty_intermediate_tensors( + batch_size=1, dtype=self.model_config.dtype, device=self.devices) if is_mtp_model: # prefill mtp model model_input, previous_hidden_states = _prepare_input_for_warmup( @@ -80,7 +92,7 @@ def _warm_up_model(self) -> None: self.model_runner.execute_model( model_input, kv_cache, - None, + intermediate_tensors, previous_hidden_states=previous_hidden_states) # warmup for decode @@ -89,7 +101,7 @@ def _warm_up_model(self) -> None: self.model_config, self.model_runner._base_model_runner, self.cache_engine[0], False) self.model_runner._base_model_runner.execute_model( - model_input, kv_cache, None) + model_input, kv_cache, intermediate_tensors) else: model_input, previous_hidden_states = _prepare_input_for_warmup( self.model_config, self.model_runner, self.cache_engine[0], False, @@ -97,7 +109,7 @@ def _warm_up_model(self) -> None: self.model_runner.execute_model( model_input, kv_cache, - None, + intermediate_tensors, previous_hidden_states=previous_hidden_states) torch.cuda.synchronize() -- Gitee
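
Illustration (not part of the patch): the pipeline-parallel plumbing above rests on two pieces of arithmetic -- get_pp_indices() handing each stage its [start_layer, end_layer) slice for weight loading, and MfModelBase.get_mf_offset() translating the vLLM-style VLLM_PP_LAYER_PARTITION string into the per-stage offset list that the MindFormers pipeline_stage config expects. Below is a minimal standalone sketch of that offset arithmetic; the helper name compute_pp_offset is hypothetical and only NumPy is assumed.

from typing import Optional, Union

import numpy as np


def compute_pp_offset(num_layers: int, pp_size: int,
                      partition_str: Optional[str]) -> Union[int, list]:
    """Mirror of the get_mf_offset logic: deviation of each pipeline
    stage's layer count from the even split num_layers // pp_size."""
    if partition_str is None:
        return 0
    partitions = np.array([int(p) for p in partition_str.split(",")],
                          dtype=np.int32)
    if len(partitions) != pp_size or partitions.sum() != num_layers:
        raise ValueError(f"invalid partition {partition_str!r} for "
                         f"{num_layers} layers on {pp_size} stages")
    avg = np.full(pp_size, num_layers // pp_size, dtype=np.int32)
    if (partitions == avg).all():
        return 0  # even split: MindFormers needs no per-stage offset
    return (partitions - avg).tolist()


# Example: 61 layers on 2 stages with VLLM_PP_LAYER_PARTITION="30,31"
# yields offsets [0, 1]; with no partition string the offset is simply 0.
print(compute_pp_offset(61, 2, "30,31"))  # [0, 1]
print(compute_pp_offset(64, 2, None))     # 0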