From b6efe38a789b1f3b4a02c8235a9f1bc6285392e7 Mon Sep 17 00:00:00 2001 From: alien_0119 Date: Wed, 11 Jun 2025 09:43:50 +0800 Subject: [PATCH] adapt v0.8.3 --- .../models/mindone_models/base.py | 46 +++++- .../models/mindone_models/qwen2.py | 145 ++++-------------- .../models/mindone_models/qwen2_5_vl.py | 130 +++++++++------- .../models/mindone_models/qwen3.py | 142 ++++------------- vllm_mindspore/worker/worker.py | 9 +- 5 files changed, 188 insertions(+), 284 deletions(-) diff --git a/vllm_mindspore/model_executor/models/mindone_models/base.py b/vllm_mindspore/model_executor/models/mindone_models/base.py index c0df6280..c715f20f 100644 --- a/vllm_mindspore/model_executor/models/mindone_models/base.py +++ b/vllm_mindspore/model_executor/models/mindone_models/base.py @@ -14,9 +14,14 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -from vllm.config import VllmConfig +from typing import Optional +from mindspore import Tensor, ops -from vllm_mindspore.model_executor.models.model_base import MsModelBase +from vllm.config import VllmConfig +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm_mindspore.model_executor.layers.sampler import SamplerOutput +from vllm_mindspore.model_executor.models.model_base import AttentionWrapper, MsModelBase +from vllm_mindspore.model_executor.models.attention_mask import LowerTriangularMask class MindONEModelBase(MsModelBase): @@ -91,3 +96,40 @@ class MindONEModelBase(MsModelBase): def __call__(self, *args, **kwargs): return self.forward(*args, **kwargs) + + def common_preprocess(self, vllm_config, prefix = ""): + self.set_modules({"model": self.model, "lm_head": self.lm_head}) + + self.casual_mask = LowerTriangularMask(dtype=self.model_config.dtype, + max_model_len=self.model_config.max_model_len) + self.kv_caches = [AttentionWrapper() for i in range(self.config.num_hidden_layers)] + + compilation_config = vllm_config.compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + for i in range(self.config.num_hidden_layers): + compilation_config.static_forward_context[str(i)] = self.kv_caches[i] + + def sample(self, logits: Tensor, + sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, *args, **kwargs): + if self.config.tie_word_embeddings: + self.lm_head.weight.set_data( + self.model.embed_tokens.embedding_table.data) + + def compute_logits( + self, + hidden_states: Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[Tensor]: + if sampling_metadata is not None: + if sampling_metadata.selected_token_indices is not None: + hidden_states = ops.gather( + hidden_states, sampling_metadata.selected_token_indices, 0) + + logits = self.lm_head(hidden_states).float() + + return logits diff --git a/vllm_mindspore/model_executor/models/mindone_models/qwen2.py b/vllm_mindspore/model_executor/models/mindone_models/qwen2.py index 7afb0ae3..a46f2b60 100644 --- a/vllm_mindspore/model_executor/models/mindone_models/qwen2.py +++ b/vllm_mindspore/model_executor/models/mindone_models/qwen2.py @@ -21,28 +21,21 @@ if TYPE_CHECKING: else: Qwen2Config = None -import numpy as np from mindone.transformers.models.qwen2.modeling_qwen2 import ( Qwen2Attention, Qwen2MLP, Qwen2PreTrainedModel, Qwen2RMSNorm) -from mindspore import 
Tensor, jit, mutable, nn, ops +import mindspore as ms +from mindspore import Tensor, mutable, nn from mindspore.common import dtype as mstype -from vllm.attention.backends.abstract import AttentionMetadata, AttentionType +from vllm.attention.backends.abstract import AttentionType from vllm.config import VllmConfig from vllm.sequence import IntermediateTensors from vllm_mindspore.attention import Attention from vllm_mindspore.model_executor.layers.rotary_embedding import get_rope -from vllm_mindspore.model_executor.layers.sampler import (SamplerOutput, - get_sampler) -from vllm_mindspore.model_executor.models.attention_mask import ( - LowerTriangularMask) +from vllm_mindspore.model_executor.layers.sampler import get_sampler from vllm_mindspore.model_executor.models.mindone_models.base import ( MindONEModelBase) -from vllm_mindspore.model_executor.models.mindone_models.utils import ( - enable_dynamic_shape) -from vllm_mindspore.model_executor.models.model_base import Fake_Attention -from vllm_mindspore.model_executor.sampling_metadata import SamplingMetadata -from vllm_mindspore.utils import STR_DTYPE_TO_MS_DTYPE +from vllm_mindspore.model_executor.models.mindone_models.utils import enable_dynamic_shape class vLLMQwen2Attention(Qwen2Attention): @@ -67,7 +60,6 @@ class vLLMQwen2Attention(Qwen2Attention): prefix=f"model.layers.{self.layer_idx}.self_attn.attn", attn_type=AttentionType.DECODER) - @jit def construct( self, positions: Tensor, @@ -105,7 +97,6 @@ class vLLMQwen2DecoderLayer(nn.Cell): self.self_attn = vLLMQwen2Attention(config, layer_idx) - @jit def construct( self, positions: Tensor, @@ -163,7 +154,6 @@ class vLLMQwen2Model(Qwen2PreTrainedModel): def get_input_embeddings(self): return self.embed_tokens - @jit def construct( self, input_ids: Optional[Tensor], @@ -230,24 +220,10 @@ class Qwen2ForCausalLM(MindONEModelBase): self.lora_config = lora_config self.quant_config = quant_config self.sampler = get_sampler() - - self.set_modules({"model": self.model, "lm_head": self.lm_head}) - - self.prefill = True - self.mstype = STR_DTYPE_TO_MS_DTYPE.get(self.model_config.dtype, - self.model_config.dtype) - self.casual_mask = LowerTriangularMask( - dtype=self.mstype, max_model_len=self.model_config.max_model_len) - self.kv_caches = [ - Fake_Attention() for i in range(config.num_hidden_layers) - ] - compilation_config = vllm_config.compilation_config - - if prefix in compilation_config.static_forward_context: - raise ValueError(f"Duplicate layer name: {prefix}") - for i in range(config.num_hidden_layers): - compilation_config.static_forward_context[str( - i)] = self.kv_caches[i] + self.common_preprocess(vllm_config, prefix) + self.prev_prefill = False + self.is_graph_mode = False if vllm_config.model_config.enforce_eager else True + self.run_model = None def get_input_embeddings(self, input_ids: Tensor) -> Tensor: return self.model.get_input_embeddings(input_ids) @@ -255,98 +231,37 @@ class Qwen2ForCausalLM(MindONEModelBase): def forward(self, input_ids: Tensor, positions: Tensor, - kv_caches: List[Tuple[Tensor, Tensor]], - attn_metadata: AttentionMetadata, intermediate_tensors: IntermediateTensors = None, inputs_embeds: Tensor = None, **kwargs) -> Union[Tensor, IntermediateTensors]: - key_cache, value_cache = self.get_kvcache() - - seq_lens = attn_metadata.seq_lens - max_query_len = attn_metadata.max_query_len - # When Mutli-Step is enabled with Chunked-Prefill, prefills and - # decodes are scheduled together. 
In the first step, all the - # prefills turn into decodes and max_query_len will be 1. - if self.is_multi_step_chunked_prefill and max_query_len == 1: - query_lens = [1] * len(seq_lens) - else: - query_lens = attn_metadata.query_lens - - seq_lens_np = np.array(seq_lens, dtype=np.int32) - query_lens_np = np.array(query_lens, dtype=np.int32) - kv_cache_lens = seq_lens_np - query_lens_np - is_prefill = bool(attn_metadata.num_decode_tokens == 0 - and kv_cache_lens.max() == 0) - if is_prefill: - input_ids = ops.expand_dims(input_ids, 0) - else: - input_ids = ops.expand_dims(input_ids, 1) - - slot_mapping = attn_metadata.slot_mapping - attn_mask = self.casual_mask.gen_attention_mask( - is_prefill, positions, query_lens) - seq_lens_np = np.array(attn_metadata.seq_lens, dtype=np.int32) - batch_valid_length = Tensor.from_numpy(seq_lens_np) - q_seq_lens = Tensor.from_numpy( - np.array(attn_metadata.query_lens, dtype=np.int32)) - block_tables = attn_metadata.block_tables + ori_model_inputs, is_prefill = self.prepare_base_inputs(input_ids, positions) + is_prefill = bool(is_prefill) model_inputs = (\ - input_ids, - positions, - key_cache, - value_cache, + ori_model_inputs["input_ids"], + ori_model_inputs["position_ids"], + ori_model_inputs["key_cache"], + ori_model_inputs["value_cache"], mutable(is_prefill), - slot_mapping, - attn_mask, - batch_valid_length, - q_seq_lens, - block_tables, + ori_model_inputs["slot_mapping"], + ori_model_inputs["attention_mask"], + ori_model_inputs["batch_valid_length"], + ori_model_inputs["q_seq_lens"], + ori_model_inputs["block_tables"], intermediate_tensors, - inputs_embeds + inputs_embeds, ) - if is_prefill: - if not self.prefill: - self.prefill = True - enable_dynamic_shape( - self.model, *model_inputs - ) # enable dynamic shape once on first prefill step - else: - if self.prefill: - self.prefill = False - enable_dynamic_shape( - self.model, *model_inputs - ) # enable dynamic shape once on first decode step + if self.prev_prefill != is_prefill and self.is_graph_mode: + enable_dynamic_shape(self.model, *model_inputs) + self.prev_prefill = is_prefill - model_output = self.model(*model_inputs) + # for dummy_attention_metadata + if is_prefill and not self.set_flags: + self.set_flags = True - if is_prefill: - model_output = ops.squeeze(model_output, 0) - else: - model_output = ops.squeeze(model_output, 1) + if self.run_model is None: + self.run_model = ms.jit(function=self.model, jit_level='O0') if self.is_graph_mode else self.model + model_output = self.model(*model_inputs) return model_output - - def sample(self, logits: Tensor, - sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - - def load_weights(self, *args, **kwargs): - if self.config.tie_word_embeddings: - self.lm_head.weight.set_data( - self.model.embed_tokens.embedding_table.data) - - def compute_logits( - self, - hidden_states: Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[Tensor]: - if sampling_metadata.selected_token_indices is not None: - hidden_states = ops.gather( - hidden_states, sampling_metadata.selected_token_indices, 0) - - logits = self.lm_head(hidden_states).float() - - return logits diff --git a/vllm_mindspore/model_executor/models/mindone_models/qwen2_5_vl.py b/vllm_mindspore/model_executor/models/mindone_models/qwen2_5_vl.py index 730b63d1..f8e10a2f 100644 --- a/vllm_mindspore/model_executor/models/mindone_models/qwen2_5_vl.py +++ 
b/vllm_mindspore/model_executor/models/mindone_models/qwen2_5_vl.py @@ -29,8 +29,9 @@ from mindone.transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( Qwen2_5_VLPreTrainedModel) from mindspore import Tensor, mint, mutable, nn, ops from mindspore.common.api import _pynative_executor -from vllm.attention.backends.abstract import AttentionMetadata from vllm.config import VllmConfig +import vllm.envs as envs +from vllm.forward_context import get_forward_context from vllm.model_executor.models.qwen2_5_vl import ( Qwen2_5_VLConfig, Qwen2_5_VLDummyInputsBuilder) from vllm.model_executor.models.qwen2_5_vl import ( @@ -43,6 +44,7 @@ from vllm.model_executor.models.qwen2_5_vl import ( from vllm.model_executor.models.qwen2_5_vl import ( Qwen2_5_VLProcessingInfo, Qwen2_5_VLVideoEmbeddingInputs, Qwen2_5_VLVideoInputs, Qwen2_5_VLVideoPixelInputs) +from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import MultiModalKwargs from vllm.multimodal.parse import MultiModalDataItems @@ -54,7 +56,7 @@ from vllm_mindspore.model_executor.layers.sampler import (SamplerOutput, get_sampler) from vllm_mindspore.model_executor.models.attention_mask import ( LowerTriangularMask) -from vllm_mindspore.model_executor.models.interfaces import SupportsMultiModal +from vllm_mindspore.model_executor.models.interfaces import SupportsMultiModal, MultiModalEmbeddings from vllm_mindspore.model_executor.models.mindone_models.qwen2 import ( MindONEModelBase) from vllm_mindspore.model_executor.models.mindone_models.qwen2 import ( @@ -63,11 +65,11 @@ from vllm_mindspore.model_executor.models.mindone_models.qwen2 import ( vLLMQwen2Model) from vllm_mindspore.model_executor.models.mindone_models.utils import ( enable_dynamic_shape) -from vllm_mindspore.model_executor.models.model_base import Fake_Attention +from vllm_mindspore.model_executor.models.model_base import AttentionWrapper from vllm_mindspore.model_executor.models.utils import ( maybe_prefix, merge_multimodal_embeddings) -from vllm_mindspore.model_executor.sampling_metadata import SamplingMetadata from vllm_mindspore.utils import STR_DTYPE_TO_MS_DTYPE +from vllm_mindspore.v1.attention.backends.ms_attn import MsAttentionMetadata # yapf:enable @@ -92,9 +94,7 @@ class Qwen2ForCausalLM(vLLM_Qwen2ForCausalLM): self.sampler = get_sampler() self.set_modules({"model": self.model, "lm_head": self.lm_head}) - self.kv_caches = [ - Fake_Attention() for i in range(config.num_hidden_layers) - ] + self.kv_caches = [AttentionWrapper() for i in range(config.num_hidden_layers)] compilation_config = vllm_config.compilation_config if prefix in compilation_config.static_forward_context: @@ -301,12 +301,12 @@ class Qwen2_5_VLForConditionalGeneration(MindONEModelBase, SupportsMultiModal): _process_image_input = vLLM_Qwen2_5_VLForConditionalGeneration._process_image_input _process_video_input = vLLM_Qwen2_5_VLForConditionalGeneration._process_video_input _parse_and_validate_multimodal_inputs = vLLM_Qwen2_5_VLForConditionalGeneration._parse_and_validate_multimodal_inputs - get_multimodal_embeddings = None + get_multimodal_embeddings = vLLM_Qwen2_5_VLForConditionalGeneration.get_multimodal_embeddings def get_input_embeddings( self, input_ids: mindspore.Tensor, - multimodal_embeddings: Optional[tuple[mindspore.Tensor, ...]] = None, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, ) -> mindspore.Tensor: inputs_embeds = self.language_model.model.embed_tokens(input_ids) if multimodal_embeddings 
is not None: @@ -318,8 +318,8 @@ class Qwen2_5_VLForConditionalGeneration(MindONEModelBase, SupportsMultiModal): def get_input_embeddings_v0( self, input_ids: mindspore.Tensor, - image_input: Optional[tuple[mindspore.Tensor, ...]] = None, - video_input: Optional[tuple[mindspore.Tensor, ...]] = None, + image_input: Optional[Qwen2_5_VLImageInputs] = None, + video_input: Optional[Qwen2_5_VLVideoInputs] = None, ) -> mindspore.Tensor: inputs_embeds = self.get_input_embeddings(input_ids) if image_input is not None: @@ -345,53 +345,53 @@ class Qwen2_5_VLForConditionalGeneration(MindONEModelBase, SupportsMultiModal): self, input_ids: Tensor, positions: Tensor, - kv_caches: List[Tuple[Tensor, Tensor]], - attn_metadata: AttentionMetadata, intermediate_tensors: IntermediateTensors = None, inputs_embeds: Tensor = None, ): key_caches, value_caches = self.language_model.get_kvcache() - seq_lens = attn_metadata.seq_lens - max_query_len = attn_metadata.max_query_len - # When Mutli-Step is enabled with Chunked-Prefill, prefills and - # decodes are scheduled together. In the first step, all the - # prefills turn into decodes and max_query_len will be 1. - if self.is_multi_step_chunked_prefill and max_query_len == 1: - query_lens = [1] * len(seq_lens) - else: - query_lens = attn_metadata.query_lens - - seq_lens_np = np.array(seq_lens, dtype=np.int32) - query_lens_np = np.array(query_lens, dtype=np.int32) - kv_cache_lens = seq_lens_np - query_lens_np - is_prefill = bool(attn_metadata.num_decode_tokens == 0 - and kv_cache_lens.max() == 0) - - if is_prefill > 0: - if input_ids is not None: - input_ids = input_ids.expand_dims(0) + attn_metadata = get_forward_context().attn_metadata + # input_ids = input_ids.to(mindspore.int64) + if attn_metadata is None: + attn_metadata = self._dummy_attention_metadata(input_ids, positions) + if not envs.VLLM_USE_V1: + seq_lens = attn_metadata.seq_lens + max_query_len = attn_metadata.max_query_len + # When Mutli-Step is enabled with Chunked-Prefill, prefills and + # decodes are scheduled together. In the first step, all the + # prefills turn into decodes and max_query_len will be 1. 
+ if self.is_multi_step_chunked_prefill and max_query_len == 1: + query_lens = [1] * len(seq_lens) else: - inputs_embeds = inputs_embeds.expand_dims(0) + query_lens = attn_metadata.query_lens + + seq_lens_np = np.array(seq_lens, dtype=np.int32) + query_lens_np = np.array(query_lens, dtype=np.int32) + kv_cache_lens = seq_lens_np - query_lens_np + is_prefill = bool(attn_metadata.num_decode_tokens == 0 and kv_cache_lens.max() == 0) + slot_mapping = attn_metadata.slot_mapping + batch_valid_length = Tensor.from_numpy(np.array(attn_metadata.seq_lens, dtype=np.int32)) + q_seq_lens = mindspore.Tensor(query_lens_np, dtype=mindspore.int32) + block_tables = attn_metadata.block_tables + position_ids = mindspore.Tensor(positions, dtype=mindspore.int32) + attn_mask = self.casual_mask.gen_attention_mask(is_prefill, position_ids, query_lens) else: - if input_ids is not None: - input_ids = input_ids.expand_dims(1) + if attn_metadata.max_context_lens == 0: + is_prefill = True else: - inputs_embeds = inputs_embeds.expand_dims(1) + is_prefill = False + slot_mapping = attn_metadata.slot_mapping + batch_valid_length = Tensor.from_numpy(attn_metadata.seq_lens_np) + block_tables = attn_metadata.block_tables + query_lens_np = attn_metadata.q_seq_lens_np + attn_mask = self.casual_mask.gen_attention_mask(is_prefill, positions, query_lens_np) + q_seq_lens = mindspore.Tensor(query_lens_np, dtype=mindspore.int32) + positions = positions.to(mindspore.int64) if inputs_embeds is None: inputs_embeds = self.language_model.model.embed_tokens(input_ids) input_ids = None - slot_mapping = attn_metadata.slot_mapping - attn_mask = self.casual_mask.gen_attention_mask( - is_prefill, positions, query_lens) - seq_lens_np = np.array(attn_metadata.seq_lens, dtype=np.int32) - batch_valid_length = Tensor.from_numpy(seq_lens_np) - q_seq_lens = Tensor.from_numpy( - np.array(attn_metadata.query_lens, dtype=np.int32)) - block_tables = attn_metadata.block_tables - # keep position.ndim to 2, for work on mindspore dynamic shape if positions.ndim == 1: positions = positions[None] @@ -424,12 +424,11 @@ class Qwen2_5_VLForConditionalGeneration(MindONEModelBase, SupportsMultiModal): self.language_model.model, *model_inputs ) # enable dynamic shape once on first decode step - hidden_states = self.language_model.model(*model_inputs) + # for dummy_attention_metadata + if is_prefill and not self.set_flags: + self.set_flags = True - if is_prefill: - hidden_states = ops.squeeze(hidden_states, 0) - else: - hidden_states = ops.squeeze(hidden_states, 1) + hidden_states = self.language_model.model(*model_inputs) return hidden_states @@ -437,8 +436,6 @@ class Qwen2_5_VLForConditionalGeneration(MindONEModelBase, SupportsMultiModal): self, input_ids: mindspore.Tensor, positions: mindspore.Tensor, - kv_caches: List[mindspore.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[mindspore.Tensor] = None, **kwargs: object, @@ -490,12 +487,39 @@ class Qwen2_5_VLForConditionalGeneration(MindONEModelBase, SupportsMultiModal): input_ids = None hidden_states = self.run_language_model(input_ids, positions, - kv_caches, attn_metadata, intermediate_tensors, inputs_embeds) return hidden_states + def _dummy_attention_metadata(self, input_ids: Tensor, positions: Tensor): + if input_ids is None: + input_len = input_ids.shape[0] + else: + input_len = positions.shape[1] + max_seq_len = mindspore.Tensor(input_len, dtype=mindspore.int32) + seq_lengths = mindspore.Tensor([input_len], dtype=mindspore.int32) 
+ q_seq_lens_np = np.array([input_len], dtype=np.int32) + seq_lens_np = np.array([input_len], dtype=np.int32) + context_lens_tensor = mindspore.Tensor([0], dtype=mindspore.int32) + + block_tables = mindspore.Tensor([[0]], dtype=mindspore.int32) + slot_mapping = [-1 for _ in range(input_len)] + slot_mapping = mindspore.Tensor(slot_mapping, dtype=mindspore.int32) + return MsAttentionMetadata( + max_seq_len=max_seq_len, + seq_lens=seq_lengths, + seq_lens_np=seq_lens_np, + block_tables=block_tables, + slot_mapping=slot_mapping, + q_seq_lens_np=q_seq_lens_np, + context_lens=context_lens_tensor, + # To enforce prefill and decode are both complied in warmup process. + # So set max_context_lens to 0 for prefill and 1 for decode. + max_context_lens=0 if not self.set_flags else 1, + query_start_loc = None + ) + def compute_logits( self, hidden_states: mindspore.Tensor, diff --git a/vllm_mindspore/model_executor/models/mindone_models/qwen3.py b/vllm_mindspore/model_executor/models/mindone_models/qwen3.py index bbc1cc97..cd66bc71 100644 --- a/vllm_mindspore/model_executor/models/mindone_models/qwen3.py +++ b/vllm_mindspore/model_executor/models/mindone_models/qwen3.py @@ -19,28 +19,23 @@ from typing import TYPE_CHECKING, List, Optional, Tuple, Union if TYPE_CHECKING: from transformers import Qwen3Config -import numpy as np from mindone.transformers.models.qwen3.modeling_qwen3 import ( Qwen3Attention, Qwen3MLP, Qwen3PreTrainedModel, Qwen3RMSNorm) -from mindspore import Tensor, jit, mint, mutable, nn, ops +import mindspore as ms +from mindspore import Tensor, mint, mutable, nn from mindspore.common import dtype as mstype -from vllm.attention.backends.abstract import AttentionMetadata, AttentionType +from vllm.attention.backends.abstract import AttentionType from vllm.config import VllmConfig from vllm.sequence import IntermediateTensors from vllm_mindspore.attention import Attention from vllm_mindspore.model_executor.layers.rotary_embedding import get_rope -from vllm_mindspore.model_executor.layers.sampler import (SamplerOutput, - get_sampler) -from vllm_mindspore.model_executor.models.attention_mask import ( - LowerTriangularMask) +from vllm_mindspore.model_executor.layers.sampler import ( + get_sampler) from vllm_mindspore.model_executor.models.mindone_models.base import ( MindONEModelBase) from vllm_mindspore.model_executor.models.mindone_models.utils import ( enable_dynamic_shape) -from vllm_mindspore.model_executor.models.model_base import Fake_Attention -from vllm_mindspore.model_executor.sampling_metadata import SamplingMetadata -from vllm_mindspore.utils import STR_DTYPE_TO_MS_DTYPE class vLLMQwen3Attention(Qwen3Attention): @@ -69,7 +64,6 @@ class vLLMQwen3Attention(Qwen3Attention): self.q_norm = Qwen3RMSNorm(self.head_dim, eps=self.config.rms_norm_eps) self.k_norm = Qwen3RMSNorm(self.head_dim, eps=self.config.rms_norm_eps) - @jit def construct( self, positions: Tensor, @@ -116,7 +110,6 @@ class vLLMQwen3DecoderLayer(nn.Cell): self.self_attn = vLLMQwen3Attention(config, layer_idx) - @jit def construct( self, positions: Tensor, @@ -174,7 +167,6 @@ class vLLMQwen3Model(Qwen3PreTrainedModel): def get_input_embeddings(self): return self.embed_tokens - @jit def construct( self, input_ids: Optional[Tensor], @@ -242,23 +234,10 @@ class Qwen3ForCausalLM(MindONEModelBase): self.quant_config = quant_config self.sampler = get_sampler() - self.set_modules({"model": self.model, "lm_head": self.lm_head}) - - self.prefill = True - self.mstype = STR_DTYPE_TO_MS_DTYPE.get(self.model_config.dtype, - 
self.model_config.dtype) - self.casual_mask = LowerTriangularMask( - dtype=self.mstype, max_model_len=self.model_config.max_model_len) - self.kv_caches = [ - Fake_Attention() for i in range(config.num_hidden_layers) - ] - compilation_config = vllm_config.compilation_config - - if prefix in compilation_config.static_forward_context: - raise ValueError(f"Duplicate layer name: {prefix}") - for i in range(config.num_hidden_layers): - compilation_config.static_forward_context[str( - i)] = self.kv_caches[i] + self.common_preprocess(vllm_config, prefix) + self.prev_prefill = False + self.is_graph_mode = False if vllm_config.model_config.enforce_eager else True + self.run_model = None def get_input_embeddings(self, input_ids: Tensor) -> Tensor: return self.model.get_input_embeddings(input_ids) @@ -266,98 +245,37 @@ class Qwen3ForCausalLM(MindONEModelBase): def forward(self, input_ids: Tensor, positions: Tensor, - kv_caches: List[Tuple[Tensor, Tensor]], - attn_metadata: AttentionMetadata, intermediate_tensors: IntermediateTensors = None, inputs_embeds: Tensor = None, **kwargs) -> Union[Tensor, IntermediateTensors]: - key_cache, value_cache = self.get_kvcache() - - seq_lens = attn_metadata.seq_lens - max_query_len = attn_metadata.max_query_len - # When Mutli-Step is enabled with Chunked-Prefill, prefills and - # decodes are scheduled together. In the first step, all the - # prefills turn into decodes and max_query_len will be 1. - if self.is_multi_step_chunked_prefill and max_query_len == 1: - query_lens = [1] * len(seq_lens) - else: - query_lens = attn_metadata.query_lens - - seq_lens_np = np.array(seq_lens, dtype=np.int32) - query_lens_np = np.array(query_lens, dtype=np.int32) - kv_cache_lens = seq_lens_np - query_lens_np - is_prefill = bool(attn_metadata.num_decode_tokens == 0 - and kv_cache_lens.max() == 0) - if is_prefill: - input_ids = ops.expand_dims(input_ids, 0) - else: - input_ids = ops.expand_dims(input_ids, 1) - - slot_mapping = attn_metadata.slot_mapping - attn_mask = self.casual_mask.gen_attention_mask( - is_prefill, positions, query_lens) - seq_lens_np = np.array(attn_metadata.seq_lens, dtype=np.int32) - batch_valid_length = Tensor.from_numpy(seq_lens_np) - q_seq_lens = Tensor.from_numpy( - np.array(attn_metadata.query_lens, dtype=np.int32)) - block_tables = attn_metadata.block_tables + ori_model_inputs, is_prefill = self.prepare_base_inputs(input_ids, positions) + is_prefill = bool(is_prefill) model_inputs = (\ - input_ids, - positions, - key_cache, - value_cache, + ori_model_inputs["input_ids"], + ori_model_inputs["position_ids"], + ori_model_inputs["key_cache"], + ori_model_inputs["value_cache"], mutable(is_prefill), - slot_mapping, - attn_mask, - batch_valid_length, - q_seq_lens, - block_tables, + ori_model_inputs["slot_mapping"], + ori_model_inputs["attention_mask"], + ori_model_inputs["batch_valid_length"], + ori_model_inputs["q_seq_lens"], + ori_model_inputs["block_tables"], intermediate_tensors, - inputs_embeds + inputs_embeds, ) - if is_prefill: - if not self.prefill: - self.prefill = True - enable_dynamic_shape( - self.model, *model_inputs - ) # enable dynamic shape once on first prefill step - else: - if self.prefill: - self.prefill = False - enable_dynamic_shape( - self.model, *model_inputs - ) # enable dynamic shape once on first decode step + if self.prev_prefill != is_prefill and self.is_graph_mode: + enable_dynamic_shape(self.model, *model_inputs) + self.prev_prefill = is_prefill - model_output = self.model(*model_inputs) + # for dummy_attention_metadata + if 
is_prefill and not self.set_flags: + self.set_flags = True - if is_prefill: - model_output = ops.squeeze(model_output, 0) - else: - model_output = ops.squeeze(model_output, 1) + if self.run_model is None: + self.run_model = ms.jit(function=self.model, jit_level='O0') if self.is_graph_mode else self.model + model_output = self.model(*model_inputs) return model_output - - def load_weights(self, *args, **kwargs): - if self.config.tie_word_embeddings: - self.lm_head.weight.set_data( - self.model.embed_tokens.embedding_table.data) - - def sample(self, logits: Tensor, - sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - - def compute_logits( - self, - hidden_states: Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[Tensor]: - if sampling_metadata.selected_token_indices is not None: - hidden_states = ops.gather( - hidden_states, sampling_metadata.selected_token_indices, 0) - - logits = self.lm_head(hidden_states).float() - - return logits diff --git a/vllm_mindspore/worker/worker.py b/vllm_mindspore/worker/worker.py index 8ce1bc91..1aac63fb 100644 --- a/vllm_mindspore/worker/worker.py +++ b/vllm_mindspore/worker/worker.py @@ -36,7 +36,7 @@ from vllm.logger import init_logger from vllm_mindspore.utils import get_valid_dtype from vllm.model_executor import set_random_seed -from vllm.sequence import SequenceGroupMetadata +from vllm.sequence import SequenceData, SequenceGroupMetadata from vllm.sampling_params import SamplingParams @@ -48,11 +48,16 @@ def _prepare_input_for_warmup(model_config, model_runner, cache_engine, is_prefi seq_len = model_runner.scheduler_config.max_num_batched_tokens if is_prefill else 1 dummy_data = model_runner.input_registry.dummy_data_for_profiling(model_config, seq_len, model_runner.mm_registry) block_tables = [i for i in range(math.ceil(seq_len / cache_engine.block_size))] + + seq_data = dummy_data.seq_data + if seq_len == 1: + seq_data = dummy_data.seq_data.from_prompt_token_counts((0, seq_len)) + seqs = [ SequenceGroupMetadata( request_id=str(idx), is_prompt=is_prefill, - seq_data={idx: dummy_data.seq_data}, + seq_data={idx: seq_data}, sampling_params=SamplingParams(), block_tables={idx: block_tables}, lora_request=None, -- Gitee
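
Note: the patched forward() methods of Qwen2ForCausalLM and Qwen3ForCausalLM above now build the backbone callable lazily, JIT-compiling it only when enforce_eager is off (self.is_graph_mode). Below is a minimal, self-contained sketch of that lazy graph-mode dispatch pattern, assuming a MindSpore release that supports the ms.jit(function=..., jit_level=...) keyword form used in the patch; ToyModel and ToyRunner are hypothetical stand-ins for the real Qwen backbone and causal-LM wrapper and are not part of this change.

    # Sketch only: ToyModel/ToyRunner are illustrative stand-ins; just the
    # ms.jit(function=..., jit_level='O0') wrapping mirrors the patch above.
    import mindspore as ms
    from mindspore import Tensor, nn, ops


    class ToyModel(nn.Cell):
        """Stand-in for the Qwen2/Qwen3 backbone cell."""

        def construct(self, x: Tensor) -> Tensor:
            return ops.relu(x) + 1


    class ToyRunner:
        """Stand-in for the causal-LM wrapper holding the lazily built callable."""

        def __init__(self, enforce_eager: bool):
            self.model = ToyModel()
            self.is_graph_mode = not enforce_eager
            self.run_model = None  # created on the first forward call

        def forward(self, x: Tensor) -> Tensor:
            if self.run_model is None:
                # Compile the cell once with jit_level 'O0' in graph mode;
                # otherwise keep the plain eager cell.
                self.run_model = (ms.jit(function=self.model, jit_level='O0')
                                  if self.is_graph_mode else self.model)
            return self.run_model(x)


    if __name__ == "__main__":
        runner = ToyRunner(enforce_eager=False)
        print(runner.forward(Tensor([[-1.0, 2.0]], ms.float32)))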