diff --git a/vllm_mindspore/model_executor/layers/layernorm.py b/vllm_mindspore/model_executor/layers/layernorm.py
index bc6dfe11015f36947ef55d019da15ad2747dd541..761bbd381ba3b8f520171c03d0d787f0c9451830 100644
--- a/vllm_mindspore/model_executor/layers/layernorm.py
+++ b/vllm_mindspore/model_executor/layers/layernorm.py
@@ -23,6 +23,7 @@
 from typing import Optional, Union
 
 from mindspore import Parameter, Tensor, mint, nn, ops
 from mindspore._c_expression import typing
+from mindspore.ops.auto_generate.gen_ops_prim import AddRmsNorm
 
 from vllm.config import get_current_vllm_config
@@ -45,6 +46,8 @@ class RMSNorm(nn.Cell):
             params_dtype = get_current_vllm_config().model_config.dtype
         self.weight = Parameter(mint.ones(hidden_size, dtype=params_dtype))
         self.rms_norm = ops.RmsNorm(eps)
+        self.eps = eps
+        self.add_rms_norm = AddRmsNorm()
 
     def construct(
         self,
@@ -52,9 +55,8 @@
         residual: Optional[Tensor] = None
     ) -> Union[Tensor, tuple[Tensor, Tensor]]:
         if residual is not None:
-            x = x + residual
-            residual = x
+            output, _, residual = self.add_rms_norm(x, residual, self.weight,
+                                                    self.eps)
+            return output, residual
         output = self.rms_norm(x, self.weight)[0]
-        if residual is None:
-            return output
-        return output, residual
+        return output
diff --git a/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py b/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py
index ee725f9af8353c6430c4a3515a38e7006feb6251..556afe4676d9ea11002325ac50344457fb2a54f9 100644
--- a/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py
+++ b/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py
@@ -346,13 +346,14 @@ class VocabParallelEmbedding(nn.Cell):
         # If parameter does not have output dim, then it should
         # be copied onto all gpus (e.g. g_idx for act_order gptq).
         if output_dim is None:
+            loaded_weight = loaded_weight[:]
             assert param.data.shape == loaded_weight.shape
             if param.data.shape != loaded_weight.shape:
                 raise ValueError(
                     f"'param.data.shape' should be equal to "
                     f"'loaded_weight.shape', but got {param.data.shape} "
                     f"and {loaded_weight.shape}")
-            param.set_data(loaded_weight)
+            param.set_data(ms.from_numpy(loaded_weight))
             return
 
         # Shard indexes for loading the weight
diff --git a/vllm_mindspore/model_executor/models/model_base.py b/vllm_mindspore/model_executor/models/model_base.py
index e7384143e78cee262603879d2cb7bc9652842b10..d1fdcaafce4770ca00932cf364b02e596af649d0 100644
--- a/vllm_mindspore/model_executor/models/model_base.py
+++ b/vllm_mindspore/model_executor/models/model_base.py
@@ -346,6 +346,7 @@ class NativeModel(MsModelBase):
         self.is_graph_mode = not vllm_config.model_config.enforce_eager
         self.prev_prefill = False
         self.run_model = None
+        self.dynamic_tensors_has_init = False
 
     @property
     def ready_model(self) -> nn.Cell:
@@ -376,30 +377,31 @@ class NativeModel(MsModelBase):
             compilation_config.static_forward_context[str(
                 i)] = self.kv_caches[i]
 
-    def set_model_inputs(self, input_ids, position_ids, intermediate_tensors,
-                         inputs_embeds, is_prefill):
+    def _init_base_dynamic_tensors(self, input_ids, position_ids,
+                                   intermediate_tensors, inputs_embeds):
         if input_ids is None:
-            dyn_input_ids = None
+            self.dyn_input_ids = None
         else:
-            dyn_input_ids = ms.Tensor(shape=[None] * input_ids.ndim,
-                                      dtype=mstype.int32)
+            self.dyn_input_ids = ms.Tensor(shape=[None] * input_ids.ndim,
+                                           dtype=mstype.int32)
 
         if position_ids is None:
-            dyn_position_ids = None
+            self.dyn_position_ids = None
         else:
-            dyn_position_ids = ms.Tensor(shape=[None] * position_ids.ndim,
-                                         dtype=mstype.int32)
+            self.dyn_position_ids = ms.Tensor(shape=[None] * position_ids.ndim,
+                                              dtype=mstype.int32)
 
         if inputs_embeds is None:
-            dyn_inputs_embeds = None
+            self.dyn_inputs_embeds = None
         else:
-            dyn_inputs_embeds = ms.Tensor(shape=[None] * inputs_embeds.ndim,
-                                          dtype=inputs_embeds.dtype)
+            self.dyn_inputs_embeds = ms.Tensor(shape=[None] *
+                                               inputs_embeds.ndim,
+                                               dtype=inputs_embeds.dtype)
 
         if intermediate_tensors is None:
-            dyn_intermediate_tensors = None
+            self.dyn_intermediate_tensors = None
         else:
-            dyn_intermediate_tensors = ms.Tensor(
+            self.dyn_intermediate_tensors = ms.Tensor(
                 shape=[None] * intermediate_tensors.ndim,
                 dtype=intermediate_tensors.dtype)
 
@@ -418,25 +420,45 @@ class NativeModel(MsModelBase):
 
         dyn_key_cache = Tensor(shape=kv_cache_shape, dtype=kv_cache_dtype)
         dyn_value_cache = Tensor(shape=kv_cache_shape, dtype=kv_cache_dtype)
-        dyn_key_caches = mutable([dyn_key_cache for _ in range(num_layers)])
-        dyn_value_caches = mutable(
+        self.dyn_key_caches = mutable(
+            [dyn_key_cache for _ in range(num_layers)])
+        self.dyn_value_caches = mutable(
             [dyn_value_cache for _ in range(num_layers)])
 
-        dyn_slot_mapping = Tensor(shape=[None], dtype=mstype.int32)
-        dynamic_attention_mask = Tensor(shape=[None, None],
-                                        dtype=self.model_config.dtype)
-        dyn_batch_valid_length = Tensor(shape=[None], dtype=mstype.int32)
-        dyn_q_seq_lens = Tensor(shape=[None], dtype=mstype.int32)
-        dyn_block_tables = Tensor(shape=[None, None], dtype=mstype.int32)
-        self.ready_model.set_inputs(
-            dyn_input_ids, dyn_position_ids, dyn_key_caches, dyn_value_caches,
-            is_prefill, dyn_slot_mapping, dynamic_attention_mask,
-            dyn_batch_valid_length, dyn_q_seq_lens, dyn_block_tables,
-            dyn_intermediate_tensors, dyn_inputs_embeds)
+        self.dyn_slot_mapping = Tensor(shape=[
+            None,
+        ], dtype=mstype.int32)
+        self.dynamic_attention_mask = Tensor(shape=[None, None],
+                                             dtype=self.model_config.dtype)
+        self.dyn_batch_valid_length = Tensor(shape=[
+            None,
+        ],
+                                             dtype=mstype.int32)
+        self.dyn_q_seq_lens = Tensor(shape=[
+            None,
+        ], dtype=mstype.int32)
+        self.dyn_block_tables = Tensor(shape=[None, None], dtype=mstype.int32)
+        self.dynamic_hidden_states = Tensor(shape=[None, None],
+                                            dtype=self.model_config.dtype)
 
-        dynamic_hidden_states = Tensor(shape=[None, None],
-                                       dtype=self.model_config.dtype)
-        self.ready_lm_head.set_inputs(dynamic_hidden_states)
+    def set_model_inputs(self, input_ids, position_ids, intermediate_tensors,
+                         inputs_embeds, is_prefill):
+        if not self.dynamic_tensors_has_init:
+            self._init_base_dynamic_tensors(input_ids, position_ids,
+                                            intermediate_tensors,
+                                            inputs_embeds)
+            self.ready_lm_head.set_inputs(self.dynamic_hidden_states)
+            self.dynamic_tensors_has_init = True
+
+        self.ready_model.set_inputs(
+            self.dyn_input_ids, self.dyn_position_ids, self.dyn_key_caches,
+            self.dyn_value_caches, is_prefill, self.dyn_slot_mapping,
+            self.dynamic_attention_mask, self.dyn_batch_valid_length,
+            self.dyn_q_seq_lens, self.dyn_block_tables,
+            self.dyn_intermediate_tensors, self.dyn_inputs_embeds)
+        # By setting the phase, some processes in jit compilation
+        # are skipped to achieve acceleration.
+        self.ready_model.phase = "prefill" if is_prefill else "increment"
 
     def prepare_inputs(self, input_ids, positions, intermediate_tensors,
                        inputs_embeds):
@@ -459,21 +481,22 @@ class NativeModel(MsModelBase):
                                                        intermediate_tensors,
                                                        inputs_embeds)
 
-        if self.prev_prefill != is_prefill and self.is_graph_mode:
+        if self.is_graph_mode and self.prev_prefill != is_prefill:
             self.set_model_inputs(input_ids, positions, intermediate_tensors,
                                   inputs_embeds, is_prefill)
-            self.prev_prefill = is_prefill
+        self.prev_prefill = is_prefill
 
         # for dummy_attention_metadata
-        if is_prefill and not self.set_flags:
+        if not self.set_flags and is_prefill:
             self.set_flags = True
 
         if self.run_model is None:
+            if self.model is None:
+                raise RuntimeError("model is not initialized")
             self.run_model = ms.jit(
                 function=self.model,
                 jit_level='O0') if self.is_graph_mode else self.model
-        if self.model is None:
-            raise RuntimeError("model is not initialized")
+
         model_output = self.run_model(
            input_ids=model_inputs["input_ids"],
            positions=model_inputs["position_ids"],
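
Reviewer note, not part of the patch: the sketch below is a minimal illustration of the fused residual path that the layernorm.py change introduces, where a single AddRmsNorm call replaces the separate residual add and RmsNorm. AddRmsNorm, ops.RmsNorm, and mint.ones are taken from the diff above; the class name FusedResidualRMSNorm, the float16 dtype, the hidden size, and the tensor shapes are illustrative assumptions, and actually executing the fused kernel presumably requires a MindSpore backend that provides AddRmsNorm.

import mindspore as ms
from mindspore import Parameter, mint, nn, ops
from mindspore.ops.auto_generate.gen_ops_prim import AddRmsNorm


class FusedResidualRMSNorm(nn.Cell):
    """Illustrative stand-in for the patched RMSNorm (name is hypothetical)."""

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = Parameter(mint.ones(hidden_size, dtype=ms.float16))
        self.rms_norm = ops.RmsNorm(eps)
        self.add_rms_norm = AddRmsNorm()
        self.eps = eps

    def construct(self, x, residual=None):
        if residual is not None:
            # AddRmsNorm returns (normed output, rstd, x + residual); the
            # summed input is handed back as the new residual, as in the
            # patched construct above.
            output, _, residual = self.add_rms_norm(x, residual, self.weight,
                                                    self.eps)
            return output, residual
        # No residual: plain RmsNorm, whose first output is the normed tensor.
        return self.rms_norm(x, self.weight)[0]


# Illustrative call with [num_tokens, hidden_size] activations in float16.
norm = FusedResidualRMSNorm(hidden_size=4096)
h = mint.ones((8, 4096), dtype=ms.float16)
r = mint.ones((8, 4096), dtype=ms.float16)
out, new_residual = norm(h, r)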