diff --git a/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py b/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py
index 51f0d985837485994528ef2afac6a4abd9237087..b44ee04d4b2b58dc092495c8b99568895778b0a4 100644
--- a/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py
+++ b/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py
@@ -27,7 +27,7 @@ from vllm.config import VllmConfig
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
-from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed import get_tensor_model_parallel_world_size, get_pp_group
 from vllm.logger import init_logger

 import torch
@@ -71,7 +71,7 @@ class MfModelBase(MsModelBase):
         self.mf_config.model.model_config.parallel_config.model_parallel = (
             get_tensor_model_parallel_world_size()
         )
-        self.mf_config.model.model_config.parallel_config.pipeline_stage = 1
+        self.mf_config.model.model_config.parallel_config.pipeline_stage = get_pp_group().world_size
         self._generate_model_config()
         self.casual_mask = LowerTriangularMask(dtype=self.mf_model_config.compute_dtype,
                                                max_model_len=self.mf_model_config.seq_length)
@@ -92,8 +92,9 @@ class MfModelBase(MsModelBase):

     def _set_dynamic_inputs(self):
         self.network.set_dynamic_inputs()
-        dynamic_hidden_states = Tensor(shape=[None, None], dtype=self.mf_model_config.compute_dtype)
-        self.lm_head.set_inputs(dynamic_hidden_states)
+        if get_pp_group().is_last_rank:
+            dynamic_hidden_states = Tensor(shape=[None, None], dtype=self.mf_model_config.compute_dtype)
+            self.lm_head.set_inputs(dynamic_hidden_states)

     def prepare_inputs(self, input_ids, positions, attn_metadata):
         key_cache, value_cache = self.get_kvcache()
@@ -147,6 +148,7 @@ class MfModelBase(MsModelBase):
     ) -> Union[Tensor, IntermediateTensors]:
         model_inputs, is_prefill = self.prepare_inputs(input_ids, positions, attn_metadata)
         model_inputs = self.update_model_inputs(model_inputs, **kwargs)
+        model_inputs["hidden_states"] = intermediate_tensors["hidden_states"] if intermediate_tensors else None

         if is_prefill:
             self.network.phase = "prefill"
@@ -159,6 +161,10 @@ class MfModelBase(MsModelBase):
                 self.set_flags = True
         else:
             hidden_states = self.network(**model_inputs)
+            if not get_pp_group().is_last_rank:
+                return IntermediateTensors({
+                    "hidden_states": hidden_states,
+                })

         return hidden_states

@@ -172,6 +178,7 @@ class MfModelBase(MsModelBase):
             logits = ms.mint.zeros((0, self.mf_model_config.vocab_size),
                                    dtype=self.mf_model_config.compute_dtype)
         else:
+            hidden_states = hidden_states.view(-1, self.mf_model_config.hidden_size)
             hidden_states = hidden_states.index_select(0, selected_token_indices)
             logits = self.lm_head(hidden_states)
             logits = logits.reshape(-1, logits.shape[-1])
diff --git a/vllm_mindspore/model_executor/models/mf_models/qwen2.py b/vllm_mindspore/model_executor/models/mf_models/qwen2.py
index e7475bb0d5c9a204f74961fd86cd6b7339b52393..9175094f2717bf1916cd046dccd46c628fb03a44 100644
--- a/vllm_mindspore/model_executor/models/mf_models/qwen2.py
+++ b/vllm_mindspore/model_executor/models/mf_models/qwen2.py
@@ -21,6 +21,8 @@ from typing import Iterable, Set, Tuple
 from vllm.config import VllmConfig
 from vllm.config import get_current_vllm_config
 from vllm.logger import init_logger
+from vllm.distributed.parallel_state import get_pp_group
+from vllm.model_executor.models.interfaces import SupportsPP

 from mindspore import Tensor, JitConfig
 from mindspore.nn.utils import no_init_parameters
@@ -34,29 +36,35 @@ from vllm_mindspore.model_executor.layers.sampler import get_sampler
 from vllm_mindspore.model_executor.models.model_base import Fake_Attention
 from vllm_mindspore.model_executor.models.mf_models.mf_model_base import MfModelBase
 from vllm_mindspore.model_executor.models.mf_models.qwen2_weight_processor import Qwen2WeightProcessor
+from vllm_mindspore.model_executor.models.utils import make_empty_intermediate_tensors_factory

 logger = init_logger(__name__)


-class Qwen2ForCausalLM(MfModelBase):
+class Qwen2ForCausalLM(MfModelBase, SupportsPP):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         super(Qwen2ForCausalLM, self).__init__(vllm_config=vllm_config, prefix=prefix)
         self.mf_kvcaches_init = False

         self.sampler = get_sampler()
         self.set_modules({"model": self.network})
+        self.num_layers = self.model_config.get_num_layers(self.parallel_config)

-        self.kv_caches = [Fake_Attention() for i in range(self.mf_model_config.num_layers)]
+        self.kv_caches = [Fake_Attention() for i in range(self.num_layers)]
         compilation_config = get_current_vllm_config().compilation_config

         if prefix in compilation_config.static_forward_context:
             raise ValueError(f"Duplicate layer name: {prefix}")
-        for i in range(self.mf_model_config.num_layers):
+        for i in range(self.num_layers):
             compilation_config.static_forward_context[str(i)] = self.kv_caches[i]
         self.set_flags = False

+        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
+            ["hidden_states", "residual"], self.mf_model_config.hidden_size
+        )
+
     def _generate_model_config(self):
         self.mf_config.load_checkpoint = self.get_model_path()
         self.mf_model_config = LlamaConfig_MF(**self.mf_config.model.model_config)
@@ -73,7 +81,9 @@ class Qwen2ForCausalLM(MfModelBase):
         # Initial network
         with no_init_parameters():  # Delay initialization
             network = ParallelQwenForCausalLM_MF(self.mf_model_config)
-        return network, network.lm_head
+        if get_pp_group().is_last_rank:
+            return network, network.lm_head
+        return network, None

     def load_weights(self, weights: Iterable[Tuple[str, Tensor]]) -> Set[str]:
         weight_processor = Qwen2WeightProcessor(self.mf_config, self.network, False)
diff --git a/vllm_mindspore/model_executor/models/mf_models/qwen2_weight_processor.py b/vllm_mindspore/model_executor/models/mf_models/qwen2_weight_processor.py
index 59423eca036f17d4c1739735dc2ffe74c7d96b21..6f08f38581b0c69c4a798c2b94b1b4dd9da5e0ba 100644
--- a/vllm_mindspore/model_executor/models/mf_models/qwen2_weight_processor.py
+++ b/vllm_mindspore/model_executor/models/mf_models/qwen2_weight_processor.py
@@ -48,7 +48,7 @@ class Qwen2WeightProcessor(BaseWeightProcessor):
             np_data, _ = self.get_safetensor_from_file(embed_tokens_hf_name, src_hf_dir, hf_weight_map)
         else:
             np_data, _ = self.get_safetensor_from_file(embed_tokens_hf_name, src_hf_dir, hf_weight_map,
-                                                       is_split_param=True, split_axis=0)
+                                                       is_split_param=self.is_split_param, split_axis=0)
         self.parameter_dict[embed_tokens_ms_name] = ms.Parameter(ms.from_numpy(np_data).astype(ms.bfloat16),
                                                                  name=embed_tokens_ms_name,
                                                                  requires_grad=False)
@@ -65,7 +65,7 @@ class Qwen2WeightProcessor(BaseWeightProcessor):
         if not self.config.model.model_config.tie_word_embeddings:
             if not self.config.parallel_config.vocab_emb_dp:
                 np_data, _ = self.get_safetensor_from_file(lm_head_hf_name, src_hf_dir, hf_weight_map,
-                                                           is_split_param=True, split_axis=0)
+                                                           is_split_param=self.is_split_param, split_axis=0)
             else:
                 np_data, _ = self.get_safetensor_from_file(lm_head_hf_name, src_hf_dir, hf_weight_map)
             self.parameter_dict[lm_head_ms_name] = ms.Parameter(ms.from_numpy(np_data).astype(ms.bfloat16),
@@ -94,17 +94,17 @@ class Qwen2WeightProcessor(BaseWeightProcessor):
         ffn_concat = self.config.model.model_config.qkv_concat
         w1_hf_name = f"model.layers.{layer_id}.mlp.gate_proj.weight"
         w1_ms_name = self.convert_weight_name(w1_hf_name)
-        w1_ms_param, _ = self.get_safetensor_from_file(w1_hf_name, src_hf_dir, hf_weight_map, is_split_param=True,
+        w1_ms_param, _ = self.get_safetensor_from_file(w1_hf_name, src_hf_dir, hf_weight_map, is_split_param=self.is_split_param,
                                                        split_axis=0)

         w2_hf_name = f"model.layers.{layer_id}.mlp.down_proj.weight"
         w2_ms_name = self.convert_weight_name(w2_hf_name)
-        w2_ms_param, _ = self.get_safetensor_from_file(w2_hf_name, src_hf_dir, hf_weight_map, is_split_param=True,
+        w2_ms_param, _ = self.get_safetensor_from_file(w2_hf_name, src_hf_dir, hf_weight_map, is_split_param=self.is_split_param,
                                                        split_axis=1)

         w3_hf_name = f"model.layers.{layer_id}.mlp.up_proj.weight"
         w3_ms_name = self.convert_weight_name(w3_hf_name)
-        w3_ms_param, _ = self.get_safetensor_from_file(w3_hf_name, src_hf_dir, hf_weight_map, is_split_param=True,
+        w3_ms_param, _ = self.get_safetensor_from_file(w3_hf_name, src_hf_dir, hf_weight_map, is_split_param=self.is_split_param,
                                                        split_axis=0)

         if ffn_concat:
@@ -130,37 +130,37 @@ class Qwen2WeightProcessor(BaseWeightProcessor):
         # wq
         wq_hf_name = f"model.layers.{layer_id}.self_attn.q_proj.weight"
         wq_ms_name = self.convert_weight_name(wq_hf_name)
-        wq_ms_param, _ = self.get_safetensor_from_file(wq_hf_name, src_hf_dir, hf_weight_map, is_split_param=True,
+        wq_ms_param, _ = self.get_safetensor_from_file(wq_hf_name, src_hf_dir, hf_weight_map, is_split_param=self.is_split_param,
                                                        split_axis=0)
         # wq bias
         wq_bias_hf_name = f"model.layers.{layer_id}.self_attn.q_proj.bias"
         wq_bias_ms_name = self.convert_weight_name(wq_bias_hf_name)
         wq_bias_ms_param, _ = self.get_safetensor_from_file(wq_bias_hf_name, src_hf_dir, hf_weight_map,
-                                                            is_split_param=True,
+                                                            is_split_param=self.is_split_param,
                                                             split_axis=0)

         # wk
         wk_hf_name = f"model.layers.{layer_id}.self_attn.k_proj.weight"
         wk_ms_name = self.convert_weight_name(wk_hf_name)
-        wk_ms_param, _ = self.get_safetensor_from_file(wk_hf_name, src_hf_dir, hf_weight_map, is_split_param=True,
+        wk_ms_param, _ = self.get_safetensor_from_file(wk_hf_name, src_hf_dir, hf_weight_map, is_split_param=self.is_split_param,
                                                        split_axis=0)
         # wk bias
         wk_bias_hf_name = f"model.layers.{layer_id}.self_attn.k_proj.bias"
         wk_bias_ms_name = self.convert_weight_name(wk_bias_hf_name)
         wk_bias_ms_param, _ = self.get_safetensor_from_file(wk_bias_hf_name, src_hf_dir, hf_weight_map,
-                                                            is_split_param=True,
+                                                            is_split_param=self.is_split_param,
                                                             split_axis=0)

         # wv
         wv_hf_name = f"model.layers.{layer_id}.self_attn.v_proj.weight"
         wv_ms_name = self.convert_weight_name(wv_hf_name)
-        wv_ms_param, _ = self.get_safetensor_from_file(wv_hf_name, src_hf_dir, hf_weight_map, is_split_param=True,
+        wv_ms_param, _ = self.get_safetensor_from_file(wv_hf_name, src_hf_dir, hf_weight_map, is_split_param=self.is_split_param,
                                                        split_axis=0)
         # wv bias
         wv_bias_hf_name = f"model.layers.{layer_id}.self_attn.v_proj.bias"
         wv_bias_ms_name = self.convert_weight_name(wv_bias_hf_name)
         wv_bias_ms_param, _ = self.get_safetensor_from_file(wv_bias_hf_name, src_hf_dir, hf_weight_map,
-                                                            is_split_param=True,
+                                                            is_split_param=self.is_split_param,
                                                             split_axis=0)

         if qkv_concat:
@@ -201,7 +201,7 @@ class Qwen2WeightProcessor(BaseWeightProcessor):
         # wo
         wo_hf_name = f"model.layers.{layer_id}.self_attn.o_proj.weight"
         wo_ms_name = self.convert_weight_name(wo_hf_name)
-        wo_ms_param, _ = self.get_safetensor_from_file(wo_hf_name, src_hf_dir, hf_weight_map, is_split_param=True,
+        wo_ms_param, _ = self.get_safetensor_from_file(wo_hf_name, src_hf_dir, hf_weight_map, is_split_param=self.is_split_param,
                                                        split_axis=1)
         self.parameter_dict[wo_ms_name] = ms.Parameter(ms.from_numpy(wo_ms_param).astype(ms.bfloat16),
                                                        name=wo_ms_name,
diff --git a/vllm_mindspore/model_executor/models/mf_models/weight_processor.py b/vllm_mindspore/model_executor/models/mf_models/weight_processor.py
index 9b0aab3a177323bb584f268061a06f9213070494..228dd7912a80cde70e0f46524025a341fe2c7cea 100644
--- a/vllm_mindspore/model_executor/models/mf_models/weight_processor.py
+++ b/vllm_mindspore/model_executor/models/mf_models/weight_processor.py
@@ -19,7 +19,8 @@ transform huggingface safetensor.

 import os
 from safetensors import safe_open
-from mindspore.communication.management import get_rank, get_group_size
+
+from vllm.distributed import get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank


 class BaseWeightProcessor:
@@ -35,11 +36,13 @@ class BaseWeightProcessor:
         self.config = config
         self.network = network
         self.is_quant = is_quant
-        self.tp_group_size = get_group_size()
-        self.rank_id = get_rank()
+        self.tp_group_size = get_tensor_model_parallel_world_size()
+        self.rank_id = get_tensor_model_parallel_rank()
         self.parameter_dict = {}
         self.file_handles = {}

+        self.is_split_param = self.tp_group_size > 1
+
     def get_file_handles(self, filename):
         if filename not in self.file_handles:
             fp = safe_open(filename, framework="np")
diff --git a/vllm_mindspore/model_executor/models/model_base.py b/vllm_mindspore/model_executor/models/model_base.py
index 961f54a2df96ec829ec75980af8994c6055c85e3..40f5fada19b0e419ebb759b98797c514c2cef5e0 100644
--- a/vllm_mindspore/model_executor/models/model_base.py
+++ b/vllm_mindspore/model_executor/models/model_base.py
@@ -190,7 +190,8 @@ class MsModelBase():
         key_cache = []
         value_cache = []
         forward_context = get_forward_context()
-        for i in range(self.config.num_hidden_layers):
+        num_layers = self.model_config.get_num_layers(self.parallel_config)
+        for i in range(num_layers):
             k_cache = self.kv_caches[i].kv_cache[forward_context.virtual_engine][0]
             v_cache = self.kv_caches[i].kv_cache[forward_context.virtual_engine][1]
             key_cache.append(k_cache)
diff --git a/vllm_mindspore/model_executor/models/utils.py b/vllm_mindspore/model_executor/models/utils.py
index 4bb7831c584c03b61bda9e0d751d32be934db19b..adcc65bad257764ead33fe623a746cac22576458 100644
--- a/vllm_mindspore/model_executor/models/utils.py
+++ b/vllm_mindspore/model_executor/models/utils.py
@@ -158,7 +158,7 @@ def make_empty_intermediate_tensors_factory(keys: List[str], hidden_size: int):
     ) -> IntermediateTensors:
         dtype = get_valid_dtype(dtype)
         return IntermediateTensors(
-            {key: mint.zeros((batch_size, hidden_size), dtype=dtype) for key in keys}
+            {key: mint.zeros((batch_size, 1, hidden_size), dtype=dtype) for key in keys}
         )

     return make_empty_intermediate_tensors
diff --git a/vllm_mindspore/worker/worker.py b/vllm_mindspore/worker/worker.py
index 8ce1bc91d511a43a83fd3c8b0e70d228b98b951b..e58f8d17947b1d33607b0b1fc13289b5bec53a78 100644
--- a/vllm_mindspore/worker/worker.py
+++ b/vllm_mindspore/worker/worker.py
@@ -31,6 +31,7 @@ from vllm.distributed import (
     init_distributed_environment,
     set_custom_all_reduce,
 )
+from vllm.distributed.parallel_state import get_pp_group

 from vllm.logger import init_logger

@@ -74,23 +75,30 @@ def _prepare_input_for_warmup(model_config, model_runner, cache_engine, is_prefi

 def _warm_up_model(self) -> None:
     # cache_engine is a list with length equal to the size of pipeline-parallel, and only pp=1 is supported.
+    intermediate_tensors = None
+    if not get_pp_group().is_first_rank:
+        intermediate_tensors = self.model_runner.model.make_empty_intermediate_tensors(
+            batch_size=1,
+            dtype=self.model_config.dtype,
+            device=self.device,
+        )
     kv_cache = self.cache_engine[0].gpu_cache
     is_mtp_model = self.speculative_config is not None and self.model_config.hf_config.model_type == "deepseek_mtp"
     if is_mtp_model:
         # prefill mtp model
         model_input, previous_hidden_states = _prepare_input_for_warmup(self.model_config, self.model_runner,
                                                                         self.cache_engine[0], True, is_mtp_model)
-        self.model_runner.execute_model(model_input, kv_cache, None, previous_hidden_states=previous_hidden_states)
+        self.model_runner.execute_model(model_input, kv_cache, intermediate_tensors, previous_hidden_states=previous_hidden_states)

     # warmup for decode
     if self.vllm_config.scheduler_config.is_multi_step:
         model_input, _ = _prepare_input_for_warmup(self.model_config, self.model_runner._base_model_runner,
                                                    self.cache_engine[0], False)
-        self.model_runner._base_model_runner.execute_model(model_input, kv_cache, None)
+        self.model_runner._base_model_runner.execute_model(model_input, kv_cache, intermediate_tensors)
     else:
         model_input, previous_hidden_states = _prepare_input_for_warmup(self.model_config, self.model_runner,
                                                                         self.cache_engine[0], False, is_mtp_model)
-        self.model_runner.execute_model(model_input, kv_cache, None, previous_hidden_states=previous_hidden_states)
+        self.model_runner.execute_model(model_input, kv_cache, intermediate_tensors, previous_hidden_states=previous_hidden_states)

     torch.cuda.synchronize()
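
Reviewer note (not part of the patch): the pipeline-parallel hand-off introduced above works as follows. Only the last PP rank keeps lm_head and returns logits; every earlier rank returns IntermediateTensors({"hidden_states": ...}), which the next rank feeds back in as model_inputs["hidden_states"]. The standalone Python sketch below mimics that branching with NumPy stand-ins; the Stage and run_pipeline names are hypothetical and do not correspond to the real vllm or MindSpore APIs.

# Hedged sketch of the pipeline-parallel hand-off added in this patch.
# Stage/run_pipeline are assumed names; NumPy stands in for MindSpore tensors.
import numpy as np


class IntermediateTensors(dict):
    """Toy stand-in for vllm.sequence.IntermediateTensors."""


class Stage:
    def __init__(self, rank, world_size, hidden_size, vocab_size):
        self.is_last_rank = rank == world_size - 1
        self.hidden_size = hidden_size
        # Only the last stage owns an lm_head, mirroring `return network, None`
        # on non-last ranks in qwen2.py.
        self.lm_head = np.ones((hidden_size, vocab_size)) if self.is_last_rank else None

    def forward(self, input_ids, intermediate_tensors=None):
        # Non-first ranks consume the previous stage's activations, like
        # model_inputs["hidden_states"] in MfModelBase.forward().
        if intermediate_tensors is not None:
            hidden_states = intermediate_tensors["hidden_states"]
        else:
            hidden_states = np.ones((len(input_ids), self.hidden_size))
        hidden_states = hidden_states + 1.0  # pretend this stage's layers ran

        if not self.is_last_rank:
            # Hand activations to the next stage instead of computing logits.
            return IntermediateTensors({"hidden_states": hidden_states})
        return hidden_states @ self.lm_head  # logits only on the last rank


def run_pipeline(input_ids, world_size=2, hidden_size=8, vocab_size=16):
    carried = None
    out = None
    for rank in range(world_size):
        stage = Stage(rank, world_size, hidden_size, vocab_size)
        out = stage.forward(input_ids, carried)
        carried = out if isinstance(out, IntermediateTensors) else None
    return out  # logits from the last stage


print(run_pipeline([1, 2, 3]).shape)  # (3, 16)

The empty IntermediateTensors that _warm_up_model now builds for non-first ranks plays the same role during warm-up, when no real activations from a previous stage exist yet.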