From 3cee446316b9382f28656430a6d441c5f507911e Mon Sep 17 00:00:00 2001 From: zhanzhan1 Date: Mon, 14 Jul 2025 15:47:43 +0800 Subject: [PATCH 1/2] support pp 0.9.1 --- vllm_mindspore/__init__.py | 4 ++ vllm_mindspore/model_executor/models/llama.py | 27 ++++----- .../models/mf_models/deepseek_v3.py | 23 +++++--- .../mf_models/deepseekv3_weight_processor.py | 48 +++++++++------- .../models/mf_models/mf_model_base.py | 17 +++++- .../models/mf_models/weight_processor.py | 57 ++++++++++++++++--- .../model_executor/models/model_base.py | 35 +++++++++--- vllm_mindspore/model_executor/models/qwen2.py | 30 +++++----- vllm_mindspore/utils.py | 28 +++++++++ vllm_mindspore/v1/worker/gpu_worker.py | 3 + vllm_mindspore/worker/worker.py | 18 +++++- 11 files changed, 207 insertions(+), 83 deletions(-) diff --git a/vllm_mindspore/__init__.py b/vllm_mindspore/__init__.py index 0f39d7d0..b0c98678 100644 --- a/vllm_mindspore/__init__.py +++ b/vllm_mindspore/__init__.py @@ -425,3 +425,7 @@ try: except: pass check_ready() + +from vllm_mindspore.utils import view +from mindspore import Tensor +Tensor.view = view diff --git a/vllm_mindspore/model_executor/models/llama.py b/vllm_mindspore/model_executor/models/llama.py index 954579f1..e49f0097 100644 --- a/vllm_mindspore/model_executor/models/llama.py +++ b/vllm_mindspore/model_executor/models/llama.py @@ -371,19 +371,16 @@ class LlamaModel(nn.Cell): batch_valid_length: Tensor, q_seq_lens: Tensor, block_tables: Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, + hidden_states: Optional[Tensor] = None, + residual: Optional[Tensor] = None, inputs_embeds: Optional[Tensor] = None, - ) -> Union[Tensor, IntermediateTensors]: + ) -> Tuple[Tensor, Tensor]: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds else: hidden_states = self.get_input_embeddings(input_ids) residual = None - else: - assert intermediate_tensors is not None - hidden_states = intermediate_tensors["hidden_states"] - residual = intermediate_tensors["residual"] for i in range(self.start_layer, self.end_layer): layer = self.layers[i] @@ -394,14 +391,9 @@ class LlamaModel(nn.Cell): attn_mask, batch_valid_length, q_seq_lens, block_tables, residual) - if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) - - hidden_states, _ = self.norm(hidden_states, residual) - return hidden_states + if get_pp_group().is_last_rank: + hidden_states, residual = self.norm(hidden_states, residual) + return hidden_states, residual def load_weights(self, weights: Iterable[Tuple[str, Tensor]], params_dict): loaded_params: Set[str] = set() @@ -493,8 +485,13 @@ class LlamaForCausalLM(NativeModel, SupportsPP): intermediate_tensors=None, inputs_embeds=None, **kwargs): - hidden_states = self.exec_model(input_ids, positions, + hidden_states, residual = self.exec_model(input_ids, positions, intermediate_tensors, inputs_embeds) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual, + }) return hidden_states def load_weights(self, weights: Iterable[Tuple[str, Tensor]]) -> Set[str]: diff --git a/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py b/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py index 1e2df73a..fe3d9e60 100644 --- a/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py +++ b/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py @@ -26,10 +26,11 @@ import mindspore as ms from
vllm.config import VllmConfig from vllm.config import get_current_vllm_config -from vllm.distributed.parallel_state import get_dp_group, get_tensor_model_parallel_world_size +from vllm.distributed.parallel_state import get_dp_group, get_tensor_model_parallel_world_size, get_pp_group from vllm.forward_context import get_forward_context from vllm.logger import init_logger from vllm.attention.layer import Attention +from vllm.model_executor.models.interfaces import SupportsPP import mindspore as ms from mindspore import Tensor, JitConfig, Model, mutable @@ -56,6 +57,7 @@ from vllm_mindspore.model_executor.models.model_base import MLAAttentionWrapper from vllm_mindspore.model_executor.models.mf_models.mf_model_base import MfModelBase from vllm_mindspore.model_executor.models.mf_models.deepseekv3_weight_processor import DeepseekV3WeightProcessor from vllm_mindspore.model_executor.models.attention_mask import MLALowerTriangularMask +from vllm_mindspore.model_executor.models.utils import make_empty_intermediate_tensors_factory try: # Need to apply dllm pd patch on vllm to use pd disagg related functions @@ -121,7 +123,7 @@ def _get_padding_index(q_seq_len): -class DeepseekV3ForCausalLM(MfModelBase): +class DeepseekV3ForCausalLM(MfModelBase, SupportsPP): def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: super(DeepseekV3ForCausalLM, self).__init__( vllm_config=vllm_config, prefix=prefix @@ -133,18 +135,21 @@ class DeepseekV3ForCausalLM(MfModelBase): self.sampler = get_sampler() self.set_modules({"model": self.network}) - self.kv_caches = [MLAAttentionWrapper() for i in range(self.mf_model_config.num_layers)] + self.num_layers = self.model_config.get_num_layers(self.parallel_config) + self.kv_caches = [MLAAttentionWrapper() for _ in range(self.num_layers)] compilation_config = get_current_vllm_config().compilation_config if prefix in compilation_config.static_forward_context: raise ValueError(f"Duplicate layer name: {prefix}") - for i in range(self.mf_model_config.num_layers): + for i in range(self.num_layers): compilation_config.static_forward_context[str(i)] = self.kv_caches[i] self.set_flags = False set_runtime_kernel_launch_group() self.casual_mask = MLALowerTriangularMask(dtype=self.mf_model_config.compute_dtype, max_model_len=self.model_config.max_model_len) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(keys=["hidden_states"], + hidden_size=self.model_config.hf_config.hidden_size) def _generate_model_config(self): self.mf_config.load_checkpoint = self.get_model_path() @@ -171,12 +176,14 @@ class DeepseekV3ForCausalLM(MfModelBase): if ptq is not None: ptq.apply(network) ptq.convert(network) - return network, network.lm_head + if get_pp_group().is_last_rank: + return network, network.lm_head + return network, None def get_kvcache(self): key_cache = [] forward_context = get_forward_context() - for i in range(self.mf_model_config.num_layers): + for i in range(self.num_layers): k_cache = self.kv_caches[i].kv_cache[forward_context.virtual_engine][0] key_cache.append(k_cache) return mutable(key_cache), None @@ -185,7 +192,7 @@ class DeepseekV3ForCausalLM(MfModelBase): logger.debug(f"reached deepseek_v3 connector_send_kvcache") _pynative_executor.sync() forward_context = get_forward_context() - for i in range(self.mf_model_config.num_layers): + for i in range(self.num_layers): kv_cache_module = self.kv_caches[i] kv_cache = kv_cache_module.kv_cache[forward_context.virtual_engine][0] maybe_save_kv_layer_to_connector(str(i), kv_cache) @@ -201,7 
+208,7 @@ class DeepseekV3ForCausalLM(MfModelBase): self.mf_config, model, self.network, infer_data, do_predict=True ) else: - weight_processor = DeepseekV3WeightProcessor(self.mf_config, self.network, self.is_quant) + weight_processor = DeepseekV3WeightProcessor(self.mf_config, self.network, self.is_quant, self.vllm_config) weight_processor.load_safetensors_shard(self.mf_config.load_checkpoint) return None diff --git a/vllm_mindspore/model_executor/models/mf_models/deepseekv3_weight_processor.py b/vllm_mindspore/model_executor/models/mf_models/deepseekv3_weight_processor.py index c63abe69..74a0e478 100644 --- a/vllm_mindspore/model_executor/models/mf_models/deepseekv3_weight_processor.py +++ b/vllm_mindspore/model_executor/models/mf_models/deepseekv3_weight_processor.py @@ -27,6 +27,7 @@ from mindspore import dtype from mindspore.communication.management import get_rank from tqdm import tqdm from vllm.logger import init_logger +from vllm.distributed import get_pp_group, get_pp_indices from vllm_mindspore.model_executor.models.mf_models.weight_processor import ( BaseWeightProcessor, EPMethod) @@ -60,9 +61,14 @@ class DeepseekV3WeightProcessor(BaseWeightProcessor): """ - def __init__(self, config, network, is_quant): - super().__init__(config, network, is_quant) - self.num_layers = self.config.model.model_config.num_layers + def __init__(self, config, network, is_quant, vllm_config): + super().__init__(config, network, is_quant, vllm_config) + self.num_layers = self.vllm_config.model_config.get_num_layers(self.vllm_config.parallel_config) + self.start_layer, self.end_layer = get_pp_indices( + self.config.model.model_config.num_layers, + get_pp_group().rank_in_group, + get_pp_group().world_size, + ) self.expert_num = self.config.moe_config.expert_num self.moe_split_tp = self.moe_tp_size > 1 self.moe_split_ep = self.moe_ep_size > 1 @@ -415,18 +421,18 @@ class DeepseekV3WeightProcessor(BaseWeightProcessor): w2_scale_hf_name, w3_scale_hf_name, src_hf_dir, hf_weight_map): if self.ep_method in [EPMethod.DEFAULT, EPMethod.ALLGATHER]: - w1_ms_param, _ = self.get_safetensor_from_file_split_global_group( + w1_ms_param, _ = self.get_safetensor_from_file_split_tp_dp_group( w1_hf_name, src_hf_dir, hf_weight_map, split_axis=0) - w2_ms_param, _ = self.get_safetensor_from_file_split_global_group( + w2_ms_param, _ = self.get_safetensor_from_file_split_tp_dp_group( w2_hf_name, src_hf_dir, hf_weight_map, split_axis=1) - w3_ms_param, _ = self.get_safetensor_from_file_split_global_group( + w3_ms_param, _ = self.get_safetensor_from_file_split_tp_dp_group( w3_hf_name, src_hf_dir, hf_weight_map, split_axis=0) - w1_scale_ms_param, _ = self.get_safetensor_from_file_split_global_group( + w1_scale_ms_param, _ = self.get_safetensor_from_file_split_tp_dp_group( w1_scale_hf_name, src_hf_dir, hf_weight_map, split_axis=0) w2_scale_ms_param, _ = self.get_safetensor_from_file( w2_scale_hf_name, src_hf_dir, hf_weight_map) - w3_scale_ms_param, _ = self.get_safetensor_from_file_split_global_group( + w3_scale_ms_param, _ = self.get_safetensor_from_file_split_tp_dp_group( w3_scale_hf_name, src_hf_dir, hf_weight_map, split_axis=0) elif self.ep_method == EPMethod.ALLTOALL: w1_ms_param, _ = self.get_safetensor_from_file( @@ -1115,7 +1121,7 @@ class DeepseekV3WeightProcessor(BaseWeightProcessor): def convert_mtp_weight_name(self, weight_name: str): layer = 0 if 'layers.' 
not in weight_name else int( weight_name[weight_name.find('layers.'):].split('.')[1]) - if layer < self.num_layers: + if self.start_layer <= layer < self.end_layer: return weight_name mtp_prefix = 'mtp_model' is_mtp_layer = 'tok_embeddings' not in weight_name and 'shared_head.' not in weight_name @@ -1161,13 +1167,13 @@ class DeepseekV3WeightProcessor(BaseWeightProcessor): w3_list = [] w1_ms_name = f"model.layers.{layer_id}.feed_forward.routed_experts.ffn.w1.weight" - w1_ms_name = w1_ms_name if layer_id < self.num_layers else self.convert_mtp_weight_name( + w1_ms_name = w1_ms_name if self.start_layer <= layer_id < self.end_layer else self.convert_mtp_weight_name( w1_ms_name) w2_ms_name = f"model.layers.{layer_id}.feed_forward.routed_experts.ffn.w2.weight" - w2_ms_name = w2_ms_name if layer_id < self.num_layers else self.convert_mtp_weight_name( + w2_ms_name = w2_ms_name if self.start_layer <= layer_id < self.end_layer else self.convert_mtp_weight_name( w2_ms_name) w3_ms_name = f"model.layers.{layer_id}.feed_forward.routed_experts.ffn.w3.weight" - w3_ms_name = w3_ms_name if layer_id < self.num_layers else self.convert_mtp_weight_name( + w3_ms_name = w3_ms_name if self.start_layer <= layer_id < self.end_layer else self.convert_mtp_weight_name( w3_ms_name) for index in range(0, self.num_router_experts): @@ -1193,7 +1199,7 @@ class DeepseekV3WeightProcessor(BaseWeightProcessor): if ffn_concat: w_gate_hidden_name = f"model.layers.{layer_id}.feed_forward.routed_experts.ffn.w_gate_hidden.weight" - w_gate_hidden_name = w_gate_hidden_name if layer_id < self.num_layers else \ + w_gate_hidden_name = w_gate_hidden_name if self.start_layer <= layer_id < self.end_layer else \ self.convert_mtp_weight_name(w_gate_hidden_name) w_gate_hidden_np = np.concatenate( [w1_ms_stack_param, w3_ms_stack_param], axis=1) @@ -1225,11 +1231,11 @@ class DeepseekV3WeightProcessor(BaseWeightProcessor): def get_moe_shared_expert_weight(self, w1_hf_name, w2_hf_name, w3_hf_name, src_hf_dir, hf_weight_map): if self.ep_method in [EPMethod.DEFAULT, EPMethod.ALLGATHER]: - w1_ms_param, _ = self.get_safetensor_from_file_split_global_group( + w1_ms_param, _ = self.get_safetensor_from_file_split_tp_dp_group( w1_hf_name, src_hf_dir, hf_weight_map, split_axis=0) - w2_ms_param, _ = self.get_safetensor_from_file_split_global_group( + w2_ms_param, _ = self.get_safetensor_from_file_split_tp_dp_group( w2_hf_name, src_hf_dir, hf_weight_map, split_axis=1) - w3_ms_param, _ = self.get_safetensor_from_file_split_global_group( + w3_ms_param, _ = self.get_safetensor_from_file_split_tp_dp_group( w3_hf_name, src_hf_dir, hf_weight_map, split_axis=0) elif self.ep_method == EPMethod.ALLTOALL: w1_ms_param, _ = self.get_safetensor_from_file( @@ -1261,7 +1267,7 @@ class DeepseekV3WeightProcessor(BaseWeightProcessor): if ffn_concat: w_gate_hidden_name = f"model.layers.{layer_id}.feed_forward.shared_experts.w_gate_hidden.weight" - w_gate_hidden_name = w_gate_hidden_name if layer_id < self.num_layers else \ + w_gate_hidden_name = w_gate_hidden_name if self.start_layer <= layer_id < self.end_layer else \ self.convert_mtp_weight_name(w_gate_hidden_name) w_gate_hidden_np = np.concatenate([w1_ms_param, w3_ms_param], axis=0) @@ -1516,7 +1522,7 @@ class DeepseekV3WeightProcessor(BaseWeightProcessor): self.infer_process_norm_weight(src_hf_dir, layer_id, hf_weight_map) # convert mtp shared weights.
- if layer_id >= self.num_layers: + if layer_id >= self.end_layer: self.infer_process_mtp_layer_weight(src_hf_dir, layer_id, hf_weight_map) @@ -2161,9 +2167,9 @@ class DeepseekV3WeightProcessor(BaseWeightProcessor): enable_tqdm = rank_id == 0 mtp_layers = self.config.model.model_config.num_nextn_predict_layers - start_layer = 0 if not is_mtp_model else self.num_layers - end_layer = self.num_layers if not is_mtp_model else self.num_layers + mtp_layers - for layer_id in tqdm(range(start_layer, end_layer), + self.start_layer = self.start_layer if not is_mtp_model else self.end_layer + self.end_layer = self.end_layer if not is_mtp_model else self.end_layer + mtp_layers + for layer_id in tqdm(range(self.start_layer, self.end_layer), desc="Weight loading", disable=not enable_tqdm): if self.is_quant: diff --git a/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py b/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py index c4df0f43..ed16ec45 100644 --- a/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py +++ b/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py @@ -29,7 +29,7 @@ from mindspore.common.api import _pynative_executor from mindspore.communication import get_rank from vllm.config import VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.distributed.parallel_state import get_dp_group +from vllm.distributed.parallel_state import get_dp_group, get_pp_group from vllm.forward_context import get_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.sampler import SamplerOutput @@ -77,7 +77,8 @@ class MfModelBase(MsModelBase): self.mf_config.parallel_config) self.mf_config.model.model_config.parallel_config.model_parallel = ( get_tensor_model_parallel_world_size()) - self.mf_config.model.model_config.parallel_config.pipeline_stage = 1 + self.mf_config.model.model_config.parallel_config.pipeline_stage = ( + get_pp_group().world_size) self._generate_model_config() self.casual_mask = LowerTriangularMask( dtype=self.mf_model_config.compute_dtype, @@ -117,7 +118,8 @@ class MfModelBase(MsModelBase): self.network.set_dynamic_inputs() dynamic_hidden_states = Tensor( shape=[None, None], dtype=self.mf_model_config.compute_dtype) - self.lm_head.set_inputs(dynamic_hidden_states) + if get_pp_group().is_last_rank: + self.lm_head.set_inputs(dynamic_hidden_states) def prepare_inputs(self, input_ids, positions): return self.prepare_base_inputs(input_ids, positions) @@ -149,6 +151,10 @@ class MfModelBase(MsModelBase): **kwargs) -> Union[Tensor, IntermediateTensors]: model_inputs, is_prefill = self.prepare_inputs(input_ids, positions) model_inputs = self.update_model_inputs(model_inputs, **kwargs) + model_inputs["hidden_states"] = None + if intermediate_tensors is not None: + model_inputs["hidden_states"] = intermediate_tensors["hidden_states"] + # enable_mb_split is True in lager EP enable micro-batch and per-dp-bs > 1 enable_mb_split = self.is_enable_micro_batch_split( @@ -183,6 +189,11 @@ class MfModelBase(MsModelBase): self.connector_wait_for_kv_layer() logger.debug(f"connector_wait_for_kv_layer success") hidden_states = self.network(**model_inputs) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + }) return hidden_states diff --git a/vllm_mindspore/model_executor/models/mf_models/weight_processor.py b/vllm_mindspore/model_executor/models/mf_models/weight_processor.py index 89d786eb..c627dac7 100644 --- 
a/vllm_mindspore/model_executor/models/mf_models/weight_processor.py +++ b/vllm_mindspore/model_executor/models/mf_models/weight_processor.py @@ -19,10 +19,15 @@ transform huggingface safetensor. import os from enum import Enum + +from mindformers.parallel_core.inference.parallel_state import ( + get_data_parallel_world_size, get_moe_expert_parallel_rank, + get_moe_tensor_parallel_rank, get_pipeline_model_parallel_world_size, + get_tensor_and_data_model_parallel_rank, + get_tensor_and_data_model_parallel_world_size, + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) +from mindspore.communication.management import get_group_size, get_rank from safetensors import safe_open -from mindspore.communication.management import get_rank, get_group_size -from mindformers.parallel_core.inference.utils import get_tp_world_size -from mindformers.parallel_core.inference.parallel_state import get_data_parallel_world_size class EPMethod(Enum): @@ -43,28 +48,32 @@ class BaseWeightProcessor: """ - def __init__(self, config, network, is_quant): + def __init__(self, config, network, is_quant, vllm_config): + self.vllm_config = vllm_config self.config = config self.network = network self.is_quant = is_quant self.global_rank_id = get_rank() self.global_group_size = get_group_size() - self.tp_group_size = get_tp_world_size() + self.tp_group_size = get_tensor_model_parallel_world_size() self.dp_group_size = get_data_parallel_world_size() + self.tp_dp_group_size = get_tensor_and_data_model_parallel_world_size() + self.tp_dp_gourp_id = get_tensor_and_data_model_parallel_rank() + self.pp_group_size = get_pipeline_model_parallel_world_size() self.num_router_experts = self.config.moe_config.expert_num if self.config.moe_config.expert_num else 1 self.moe_ep_size = self.config.parallel_config.expert_parallel \ if self.config.parallel_config.expert_parallel else 1 - self.moe_tp_size = self.global_group_size // self.moe_ep_size + self.moe_tp_size = self.global_group_size // self.moe_ep_size // self.pp_group_size self.ep_method = EPMethod.DEFAULT if self.dp_group_size > 1 and self.moe_ep_size == self.global_group_size: self.ep_method = EPMethod.ALLTOALL elif self.dp_group_size > 1: self.ep_method = EPMethod.ALLGATHER - self.tp_rank_id = self.global_rank_id % self.tp_group_size + self.tp_rank_id = get_tensor_model_parallel_rank() self.ep_group_nums = self.num_router_experts // self.moe_ep_size - self.moe_ep_rank_id = self.global_rank_id // self.moe_tp_size - self.moe_tp_rank_id = self.global_rank_id % self.moe_tp_size + self.moe_ep_rank_id = get_moe_expert_parallel_rank() + self.moe_tp_rank_id = get_moe_tensor_parallel_rank() self.ep_start = self.moe_ep_rank_id * self.ep_group_nums self.ep_stop = (self.moe_ep_rank_id + 1) * self.ep_group_nums @@ -120,6 +129,36 @@ class BaseWeightProcessor: raise ValueError("split_axis:{} is not supported.".format(split_axis)) return split_data, qint4 + def get_safetensor_from_file_split_tp_dp_group(self, hf_param_name, src_hf_dir, hf_weight_map, split_axis=0): + safetensor_file = hf_weight_map[hf_param_name] + filename = os.path.join(src_hf_dir, safetensor_file) + sf_file = self.get_file_handles(filename) + qint4 = False + if sf_file.metadata() is not None and hf_param_name in sf_file.metadata().keys(): + qint4 = True + + np_data = sf_file.get_slice(hf_param_name) + shape = np_data.get_shape() + if split_axis == 0: + split_size = shape[0] // self.tp_dp_group_size + start = self.tp_dp_gourp_id * split_size + stop = (self.tp_dp_gourp_id + 1) * split_size + split_data = 
np_data[start:stop] + elif split_axis == 1: + split_size = shape[1] // self.tp_dp_group_size + start = self.tp_dp_gourp_id * split_size + stop = (self.tp_dp_gourp_id + 1) * split_size + split_data = np_data[:, start:stop] + elif split_axis == 2: + split_size = shape[2] // self.tp_dp_group_size + start = self.tp_dp_gourp_id * split_size + stop = (self.tp_dp_gourp_id + 1) * split_size + split_data = np_data[:, :, start:stop] + else: + raise ValueError("split_axis:{} is not supported.".format(split_axis)) + return split_data, qint4 + + def get_safetensor_from_file_split_global_group(self, hf_param_name, src_hf_dir, hf_weight_map, split_axis=0): safetensor_file = hf_weight_map[hf_param_name] filename = os.path.join(src_hf_dir, safetensor_file) diff --git a/vllm_mindspore/model_executor/models/model_base.py b/vllm_mindspore/model_executor/models/model_base.py index 123256da..ec7021be 100644 --- a/vllm_mindspore/model_executor/models/model_base.py +++ b/vllm_mindspore/model_executor/models/model_base.py @@ -80,6 +80,7 @@ class MsModelBase: def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: super().__init__() + self.vllm_config = vllm_config config = vllm_config.model_config.hf_config lora_config = vllm_config.lora_config @@ -216,7 +217,8 @@ class MsModelBase: key_cache = [] value_cache = [] forward_context = get_forward_context() - for i in range(self.config.num_hidden_layers): + num_layers = self.model_config.get_num_layers(self.parallel_config) + for i in range(num_layers): k_cache = self.kv_caches[i].kv_cache[ # type: ignore[attr-defined] forward_context.virtual_engine][0] v_cache = self.kv_caches[i].kv_cache[ # type: ignore[attr-defined] @@ -352,14 +354,15 @@ class NativeModel(MsModelBase): self.casual_mask = LowerTriangularMask( dtype=self.model_config.dtype, max_model_len=self.model_config.max_model_len) + num_layers = self.model_config.get_num_layers(self.parallel_config) self.kv_caches = [ - AttentionWrapper() for i in range(self.config.num_hidden_layers) + AttentionWrapper() for _ in range(num_layers) ] compilation_config = vllm_config.compilation_config if prefix in compilation_config.static_forward_context: raise ValueError(f"Duplicate layer name: {prefix}") - for i in range(self.config.num_hidden_layers): + for i in range(num_layers): compilation_config.static_forward_context[str( i)] = self.kv_caches[i] @@ -384,11 +387,15 @@ class NativeModel(MsModelBase): dtype=inputs_embeds.dtype) if intermediate_tensors is None: - dyn_intermediate_tensors = None + dyn_hidden_states = None + dyn_residual = None else: - dyn_intermediate_tensors = ms.Tensor( - shape=[None] * intermediate_tensors.ndim, - dtype=intermediate_tensors.dtype) + dyn_hidden_states = ms.Tensor( + shape=[None] * intermediate_tensors["hidden_states"].ndim, + dtype=intermediate_tensors["hidden_states"].dtype) + dyn_residual = ms.Tensor( + shape=[None] * intermediate_tensors["residual"].ndim, + dtype=intermediate_tensors["residual"].dtype) block_size = self.cache_config.block_size num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config) @@ -425,7 +432,8 @@ class NativeModel(MsModelBase): dyn_batch_valid_length, dyn_q_seq_lens, dyn_block_tables, - dyn_intermediate_tensors, + dyn_hidden_states, + dyn_residual, dyn_inputs_embeds) dynamic_hidden_states = Tensor(shape=[None, None], @@ -437,6 +445,14 @@ class NativeModel(MsModelBase): inputs_embeds): model_inputs, is_prefill = self.prepare_base_inputs( input_ids, positions) + + #for pp + if intermediate_tensors is not None: + 
model_inputs["hidden_states"] = intermediate_tensors["hidden_states"] + model_inputs["residual"] = intermediate_tensors["residual"] + else: + model_inputs["hidden_states"] = None + model_inputs["residual"] = None # for multimodal model model_inputs["intermediate_tensors"] = intermediate_tensors @@ -479,7 +495,8 @@ class NativeModel(MsModelBase): batch_valid_length=model_inputs["batch_valid_length"], q_seq_lens=model_inputs["q_seq_lens"], block_tables=model_inputs["block_tables"], - intermediate_tensors=model_inputs["intermediate_tensors"], + hidden_states=model_inputs["hidden_states"], + residual=model_inputs["residual"], inputs_embeds=model_inputs["inputs_embeds"], ) diff --git a/vllm_mindspore/model_executor/models/qwen2.py b/vllm_mindspore/model_executor/models/qwen2.py index 87c54c21..68f23d25 100644 --- a/vllm_mindspore/model_executor/models/qwen2.py +++ b/vllm_mindspore/model_executor/models/qwen2.py @@ -30,7 +30,7 @@ from vllm.attention.backends.abstract import AttentionType from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.models.interfaces import SupportsLoRA +from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP from vllm.sequence import IntermediateTensors from vllm_mindspore.attention import Attention @@ -322,18 +322,16 @@ class Qwen2Model(nn.Cell): batch_valid_length: Tensor, q_seq_lens: Tensor, block_tables: Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, + hidden_states: Optional[Tensor] = None, + residual: Optional[Tensor] = None, inputs_embeds: Optional[Tensor] = None, - ) -> Union[Tensor, IntermediateTensors]: + ) -> Tuple[Tensor, Tensor]: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds else: hidden_states = self.get_input_embeddings(input_ids) residual = None - else: - hidden_states = intermediate_tensors["hidden_states"] - residual = intermediate_tensors["residual"] for i in range(self.start_layer, self.end_layer): # PP splits the layers across pipeline stages layer = self.layers[i] @@ -343,13 +341,10 @@ class Qwen2Model(nn.Cell): is_prefill, slot_mapping, attn_mask, batch_valid_length, q_seq_lens, block_tables, residual) - if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) - hidden_states, _ = self.norm(hidden_states, residual) - return hidden_states + + if get_pp_group().is_last_rank: + hidden_states, residual = self.norm(hidden_states, residual) + return hidden_states, residual def load_weights(self, weights: Iterable[Tuple[str, Tensor]], params_dict: Dict[str, Parameter]): @@ -403,7 +398,7 @@ class Qwen2Model(nn.Cell): return loaded_params -class Qwen2ForCausalLM(NativeModel, SupportsLoRA): +class Qwen2ForCausalLM(NativeModel, SupportsLoRA, SupportsPP): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -465,8 +460,13 @@ class Qwen2ForCausalLM(NativeModel, SupportsLoRA): intermediate_tensors: IntermediateTensors = None, inputs_embeds: Tensor = None, **kwargs) -> Union[Tensor, IntermediateTensors]: - hidden_states = self.exec_model(input_ids, positions, + hidden_states, residual = self.exec_model(input_ids, positions, intermediate_tensors, inputs_embeds) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual, + }) return hidden_states def load_weights(self, weights: Iterable[Tuple[str,
Tensor]]) -> Set[str]: diff --git a/vllm_mindspore/utils.py b/vllm_mindspore/utils.py index e4ab9fca..28c325a5 100644 --- a/vllm_mindspore/utils.py +++ b/vllm_mindspore/utils.py @@ -33,6 +33,7 @@ else: import mindspore as ms from mindspore import dtype as mstype from mindspore.common.initializer import Zero +from mindspore._c_expression import typing from vllm.logger import init_logger from vllm.utils import (TORCH_DTYPE_TO_NUMPY_DTYPE, MemoryProfilingResult, MemorySnapshot, T, make_ndarray_with_pad) @@ -303,3 +304,30 @@ def ms_memory_profiling( result.non_torch_increase = diff_from_create.non_torch_memory result.profile_time = diff_profile.timestamp result.non_kv_cache_memory = result.non_torch_increase + result.torch_peak_increase + result.weights_memory # noqa + + +def view(self, *shape_or_dtype): + if len(shape_or_dtype) == 1 and isinstance(shape_or_dtype[0], typing.Type): + target_dtype = shape_or_dtype[0] + ori_shape = self.shape + target_shape = (-1,) + if len(ori_shape) > 1: + target_shape = ori_shape[:-1] + target_shape + out = np.frombuffer(self.numpy(), torch.ops.creation._TypeDict.get(target_dtype, np.float32)) + if not out.flags.aligned: + out = np.require(out, requirements=["ALIGNED"]) + if target_dtype == ms.bfloat16: + return ms.Tensor.from_numpy(out.astype(np.float32)).astype(target_dtype).reshape(target_shape) + return ms.Tensor.from_numpy(out).reshape(target_shape) + result = [] + if type(shape_or_dtype) is tuple: + for items in shape_or_dtype: + if not isinstance(items, int): + for item in items: + if not isinstance(item, int): + result.append(item.item()) + else: + result.append(item) + else: + result.append(items) + return ms.ops.reshape(self, result) diff --git a/vllm_mindspore/v1/worker/gpu_worker.py b/vllm_mindspore/v1/worker/gpu_worker.py index bb77182e..df417c8b 100644 --- a/vllm_mindspore/v1/worker/gpu_worker.py +++ b/vllm_mindspore/v1/worker/gpu_worker.py @@ -51,6 +51,9 @@ def compile_or_warm_up_model(self) -> None: # MindSpore does not support cuda graph. No need to warm up the model. # Since prefill is done previously, we do decode here. default_max_num_reqs = 1 # For MindSpore, we only do one more decode here. + # Only pp_last_rank requires _dummy_sampler_run, and only pp_last_rank can _dummy_sampler_run. if get_pp_group().is_last_rank: self.model_runner._dummy_sampler_run(self.model_runner._dummy_run( num_tokens=default_max_num_reqs)) + else: + self.model_runner._dummy_run(num_tokens=default_max_num_reqs) diff --git a/vllm_mindspore/worker/worker.py b/vllm_mindspore/worker/worker.py index 5eed6136..740aa059 100644 --- a/vllm_mindspore/worker/worker.py +++ b/vllm_mindspore/worker/worker.py @@ -27,6 +27,7 @@ from vllm_mindspore.utils import get_valid_dtype from vllm.model_executor import set_random_seed from vllm.sequence import SequenceGroupMetadata from vllm.sampling_params import SamplingParams +from vllm.distributed import get_pp_group logger = init_logger(__name__) @@ -72,6 +73,17 @@ def _warm_up_model(self) -> None: # cache_engine is a list with length equal to the size of pipeline-parallel, and only pp=1 is supported. 
kv_cache = self.cache_engine[0].gpu_cache is_mtp_model = self.speculative_config is not None and self.model_config.hf_config.model_type == "deepseek_mtp" + intermediate_tensors = None + if self.vllm_config.scheduler_config.is_multi_step: + make_empty_intermediate_tensors = self.model_runner._base_model_runner.model.make_empty_intermediate_tensors + else: + make_empty_intermediate_tensors = self.model_runner.model.make_empty_intermediate_tensors + if not get_pp_group().is_first_rank: + intermediate_tensors = make_empty_intermediate_tensors( + batch_size=1, + dtype=self.model_config.dtype, + device=self.devices, + ) if is_mtp_model: # prefill mtp model model_input, previous_hidden_states = _prepare_input_for_warmup( @@ -80,7 +92,7 @@ def _warm_up_model(self) -> None: self.model_runner.execute_model( model_input, kv_cache, - None, + intermediate_tensors, previous_hidden_states=previous_hidden_states) # warmup for decode @@ -89,7 +101,7 @@ def _warm_up_model(self) -> None: self.model_config, self.model_runner._base_model_runner, self.cache_engine[0], False) self.model_runner._base_model_runner.execute_model( - model_input, kv_cache, None) + model_input, kv_cache, intermediate_tensors) else: model_input, previous_hidden_states = _prepare_input_for_warmup( self.model_config, self.model_runner, self.cache_engine[0], False, @@ -97,7 +109,7 @@ def _warm_up_model(self) -> None: self.model_runner.execute_model( model_input, kv_cache, - None, + intermediate_tensors, previous_hidden_states=previous_hidden_states) torch.cuda.synchronize() -- Gitee From 4780b18d446efba444b16dfdf1de969ff2ba9529 Mon Sep 17 00:00:00 2001 From: zhanzhan1 Date: Thu, 17 Jul 2025 16:01:23 +0800 Subject: [PATCH 2/2] support partitions to offset --- .../models/mf_models/mf_model_base.py | 3 ++ .../models/mf_models/weight_processor.py | 8 ++--- vllm_mindspore/model_executor/models/utils.py | 31 +++++++++++++++++++ 3 files changed, 38 insertions(+), 4 deletions(-) diff --git a/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py b/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py index ed16ec45..9fa04610 100644 --- a/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py +++ b/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py @@ -40,6 +40,7 @@ from vllm_mindspore.model_executor.models.attention_mask import ( LowerTriangularMask) from mindspore.common.api import _pynative_executor from vllm_mindspore.model_executor.models.model_base import MsModelBase +from vllm_mindspore.model_executor.models.utils import get_mf_offset try: # Need to apply dllm pd patch on vllm to use pd disagg related functions @@ -79,6 +80,8 @@ class MfModelBase(MsModelBase): get_tensor_model_parallel_world_size()) self.mf_config.model.model_config.parallel_config.pipeline_stage = ( get_pp_group().world_size) + self.mf_config.model.model_config.offset = get_mf_offset( + self.mf_config.model.model_config) self._generate_model_config() self.casual_mask = LowerTriangularMask( dtype=self.mf_model_config.compute_dtype, diff --git a/vllm_mindspore/model_executor/models/mf_models/weight_processor.py b/vllm_mindspore/model_executor/models/mf_models/weight_processor.py index c627dac7..94d2eb2d 100644 --- a/vllm_mindspore/model_executor/models/mf_models/weight_processor.py +++ b/vllm_mindspore/model_executor/models/mf_models/weight_processor.py @@ -23,8 +23,8 @@ from enum import Enum from mindformers.parallel_core.inference.parallel_state import ( get_data_parallel_world_size, get_moe_expert_parallel_rank, 
get_moe_tensor_parallel_rank, get_pipeline_model_parallel_world_size, - get_tensor_and_data_model_parallel_rank, - get_tensor_and_data_model_parallel_world_size, + get_tensor_and_data_parallel_rank, + get_tensor_and_data_parallel_world_size, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from mindspore.communication.management import get_group_size, get_rank from safetensors import safe_open @@ -57,8 +57,8 @@ class BaseWeightProcessor: self.global_group_size = get_group_size() self.tp_group_size = get_tensor_model_parallel_world_size() self.dp_group_size = get_data_parallel_world_size() - self.tp_dp_group_size = get_tensor_and_data_model_parallel_world_size() - self.tp_dp_gourp_id = get_tensor_and_data_model_parallel_rank() + self.tp_dp_group_size = get_tensor_and_data_parallel_world_size() + self.tp_dp_gourp_id = get_tensor_and_data_parallel_rank() self.pp_group_size = get_pipeline_model_parallel_world_size() self.num_router_experts = self.config.moe_config.expert_num if self.config.moe_config.expert_num else 1 self.moe_ep_size = self.config.parallel_config.expert_parallel \ diff --git a/vllm_mindspore/model_executor/models/utils.py b/vllm_mindspore/model_executor/models/utils.py index 493664cd..f796ba92 100644 --- a/vllm_mindspore/model_executor/models/utils.py +++ b/vllm_mindspore/model_executor/models/utils.py @@ -20,9 +20,12 @@ from dataclasses import dataclass, field from typing import Iterable, List, Mapping, Optional, Tuple, Union +import numpy as np + import mindspore as ms from mindspore import mint, ops from vllm.sequence import IntermediateTensors +import vllm.envs as envs from vllm_mindspore.multimodal.inputs import NestedTensors # type: ignore[attr-defined] from vllm_mindspore.utils import get_valid_dtype @@ -261,3 +264,31 @@ def merge_multimodal_embeddings( (input_ids == placeholder_token_id), multimodal_embeddings, ) + +def get_mf_offset(model_config): + """ get mindformers offset from vllm style""" + partition_list_str = envs.VLLM_PP_LAYER_PARTITION + num_layers = model_config.num_layers + pp_size = model_config.parallel_config.pipeline_stage + if partition_list_str is not None: + try: + partitions = [ + int(layer) for layer in partition_list_str.split(",") + ] + except ValueError as err: + raise ValueError("Invalid partition string: {}".format( + partition_list_str)) from err + if len(partitions) != pp_size: + raise ValueError(f"{len(partitions)=} does not match {pp_size=}.") + if sum(partitions) != num_layers: + raise ValueError( + f"{sum(partitions)=} does not match {num_layers=}.") + partitions = np.array(partitions, dtype=np.int32) + avg_layers = num_layers // pp_size + avg_layers_list = np.ones((pp_size, ), dtype=np.int32) * avg_layers + if (partitions == avg_layers_list).all(): + return 0 + else: + return (partitions - avg_layers_list).tolist() + else: + return 0 -- Gitee
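
For reference, the partition-to-offset conversion introduced by get_mf_offset in the second patch can be exercised in isolation with a sketch like the one below. The helper name partitions_to_offset and the explicit string argument are illustrative stand-ins only: the real code reads envs.VLLM_PP_LAYER_PARTITION and the MindFormers model config rather than taking them as parameters.

import numpy as np

def partitions_to_offset(partition_str, num_layers, pp_size):
    # Convert a comma-separated per-stage layer count (the format of
    # vLLM's VLLM_PP_LAYER_PARTITION) into the MindFormers "offset":
    # 0 for a uniform split, otherwise the per-stage deviation from
    # num_layers // pp_size.
    if partition_str is None:
        return 0
    partitions = [int(p) for p in partition_str.split(",")]
    if len(partitions) != pp_size:
        raise ValueError(f"{len(partitions)=} does not match {pp_size=}.")
    if sum(partitions) != num_layers:
        raise ValueError(f"{sum(partitions)=} does not match {num_layers=}.")
    partitions = np.array(partitions, dtype=np.int32)
    avg = np.full((pp_size,), num_layers // pp_size, dtype=np.int32)
    return 0 if (partitions == avg).all() else (partitions - avg).tolist()

# Example: 61 layers on 4 stages, one extra layer on the first stage.
print(partitions_to_offset("16,15,15,15", 61, 4))  # [1, 0, 0, 0]
print(partitions_to_offset("15,15,15,15", 60, 4))  # 0 (uniform split)
print(partitions_to_offset(None, 60, 4))           # 0 (no partition set)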