diff --git a/examples/generate_multimodal_vllm.py b/examples/generate_multimodal_vllm.py
index 052fa182e3d751c92abe9097e48da813df733004..a17731b68690633aee5479c8659d2ff6100501f7 100644
--- a/examples/generate_multimodal_vllm.py
+++ b/examples/generate_multimodal_vllm.py
@@ -22,11 +22,13 @@ from vllm import LLM, SamplingParams  # noqa: E402
 
 
 # Qwen2.5-VL
-def get_llm(model_path: str, question: str, modality: str):
+def get_llm(question, modality, args):
     llm = LLM(
-        model=model_path,
+        model=args.model_path,
         max_model_len=4096,
         max_num_seqs=5,
+        max_num_batched_tokens=args.max_num_batched_tokens,
+        gpu_memory_utilization=args.gpu_memory_utilization,
         mm_processor_kwargs={
             "min_pixels": 28 * 28,
             "max_pixels": 1280 * 28 * 28,
@@ -55,10 +57,7 @@ def main(args):
     # Prepare args and inputs.
     img_question = "What is the content of this image?"
     img = Image.open("./imgs/1.jpg").convert("RGB")
-    llm, prompt, stop_token_ids = get_llm(
-        args.model_path, img_question, "image"
-    )
-
+    llm, prompt, stop_token_ids = get_llm(img_question, "image", args)
     inputs = [
         {
             "prompt": prompt,
@@ -92,6 +91,8 @@ if __name__ == "__main__":
     parser.add_argument(
         "--model_path", type=str, default="Qwen/Qwen2.5-VL-3B-Instruct"
     )
+    parser.add_argument("--max_num_batched_tokens", type=int, default=2048)
+    parser.add_argument("--gpu_memory_utilization", type=float, default=0.8)
     args, _ = parser.parse_known_args()
 
     main(args)
diff --git a/vllm_mindspore/model_executor/models/mindone_models/base.py b/vllm_mindspore/model_executor/models/mindone_models/base.py
index d7399281b89b601af8b5de49211ae351fa511840..6b6e373e1d5591bb548818e1eac818344010fcb0 100644
--- a/vllm_mindspore/model_executor/models/mindone_models/base.py
+++ b/vllm_mindspore/model_executor/models/mindone_models/base.py
@@ -13,9 +13,18 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+# ============================================================================
+from typing import Optional
+
+from mindspore import Tensor, ops
 from vllm.config import VllmConfig
+from vllm.model_executor.layers.sampler import SamplerOutput
+from vllm.model_executor.sampling_metadata import SamplingMetadata
 
-from vllm_mindspore.model_executor.models.model_base import MsModelBase
+from vllm_mindspore.model_executor.models.attention_mask import (
+    LowerTriangularMask)
+from vllm_mindspore.model_executor.models.model_base import (AttentionWrapper,
+                                                             MsModelBase)
 
 
 class MindONEModelBase(MsModelBase):
@@ -89,3 +98,46 @@ class MindONEModelBase(MsModelBase):
 
     def __call__(self, *args, **kwargs):
         return self.forward(*args, **kwargs)
+
+    def common_preprocess(self, vllm_config, prefix=""):
+        self.set_modules({"model": self.model, "lm_head": self.lm_head})
+
+        self.casual_mask = LowerTriangularMask(
+            dtype=self.model_config.dtype,
+            max_model_len=self.model_config.max_model_len)
+        self.kv_caches = [
+            AttentionWrapper() for i in range(self.config.num_hidden_layers)
+        ]
+
+        compilation_config = vllm_config.compilation_config
+        if prefix in compilation_config.static_forward_context:
+            raise ValueError(f"Duplicate layer name: {prefix}")
+        for i in range(self.config.num_hidden_layers):
+            compilation_config.static_forward_context[str(
+                i)] = self.kv_caches[i]
+        self.is_eager_mode = vllm_config.model_config.enforce_eager
+        self.prefill_graph = None
+        self.decode_graph = None
+
+    def sample(self, logits: Tensor,
+               sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(logits, sampling_metadata)
+        return next_tokens
+
+    def load_weights(self, *args, **kwargs):
+        if self.config.tie_word_embeddings:
+            self.lm_head.weight.set_data(
+                self.model.embed_tokens.embedding_table.data)
+
+    def compute_logits(
+        self,
+        hidden_states: Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[Tensor]:
+        if sampling_metadata is not None and sampling_metadata.selected_token_indices is not None:  # noqa: E501
+            hidden_states = ops.gather(
+                hidden_states, sampling_metadata.selected_token_indices, 0)
+
+        logits = self.lm_head(hidden_states).float()
+
+        return logits
diff --git a/vllm_mindspore/model_executor/models/mindone_models/qwen2.py b/vllm_mindspore/model_executor/models/mindone_models/qwen2.py
index b4150ef6dc2a37c24920d5314b41844dd09c4f47..7f6f7d5970689506eb78c826ed3bd6861cab583c 100644
--- a/vllm_mindspore/model_executor/models/mindone_models/qwen2.py
+++ b/vllm_mindspore/model_executor/models/mindone_models/qwen2.py
@@ -31,27 +31,23 @@ if TYPE_CHECKING:
 else:
     Qwen2Config = None
 
-import numpy as np
+import mindspore as ms
 from mindone.transformers.models.qwen2.modeling_qwen2 import (
     Qwen2Attention, Qwen2MLP, Qwen2PreTrainedModel, Qwen2RMSNorm)
-from mindspore import Tensor, jit, mutable, nn, ops
+from mindspore import Tensor, nn
 from mindspore.common import dtype as mstype
-from vllm.attention.backends.abstract import AttentionMetadata, AttentionType
+from vllm.attention.backends.abstract import AttentionType
 from vllm.config import VllmConfig
-from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
+from vllm.model_executor.layers.sampler import get_sampler
 from vllm.sequence import IntermediateTensors
 
 from vllm_mindspore.attention import Attention
 from vllm_mindspore.model_executor.layers.rotary_embedding import get_rope
-from vllm_mindspore.model_executor.models.attention_mask import (
-    LowerTriangularMask)
 from vllm_mindspore.model_executor.models.mindone_models.base import (
     MindONEModelBase)
 from vllm_mindspore.model_executor.models.mindone_models.utils import (
     enable_dynamic_shape)
-from vllm_mindspore.model_executor.models.model_base import Fake_Attention
-from vllm_mindspore.model_executor.sampling_metadata import SamplingMetadata
-from vllm_mindspore.utils import STR_DTYPE_TO_MS_DTYPE
+from vllm_mindspore.model_executor.utils import set_model_context
 
 
 class vLLMQwen2Attention(Qwen2Attention):
@@ -77,14 +73,12 @@ class vLLMQwen2Attention(Qwen2Attention):
             attn_type=AttentionType.DECODER,
         )
 
-    @jit
     def construct(
         self,
         positions: Tensor,
         hidden_states: Tensor,
         key_cache: Tensor,
         value_cache: Tensor,
-        is_prefill: bool,
         slot_mapping: Tensor,
         attn_mask: Tensor,
         batch_valid_length: Tensor,
@@ -95,14 +89,13 @@ class vLLMQwen2Attention(Qwen2Attention):
         k = self.k_proj(hidden_states)
         v = self.v_proj(hidden_states)
 
-        q, k = self.rotary_emb(positions, q, k, batch_valid_length, is_prefill)
+        q, k = self.rotary_emb(positions, q, k, batch_valid_length)
         attn_output = self.attn(
             q,
             k,
             v,
             key_cache,
             value_cache,
-            is_prefill,
             slot_mapping,
             attn_mask,
             batch_valid_length,
@@ -125,14 +118,12 @@ class vLLMQwen2DecoderLayer(nn.Cell):
 
         self.self_attn = vLLMQwen2Attention(config, layer_idx)
 
-    @jit
     def construct(
         self,
         positions: Tensor,
         hidden_states: Tensor,
         key_cache: Tensor,
         value_cache: Tensor,
-        is_prefill: bool,
         slot_mapping: Tensor,
         attn_mask: Tensor,
         batch_valid_length: Tensor,
@@ -148,7 +139,6 @@ class vLLMQwen2DecoderLayer(nn.Cell):
             hidden_states,
             key_cache,
             value_cache,
-            is_prefill,
             slot_mapping,
             attn_mask,
             batch_valid_length,
@@ -190,14 +180,12 @@ class vLLMQwen2Model(Qwen2PreTrainedModel):
     def get_input_embeddings(self):
         return self.embed_tokens
 
-    @jit
     def construct(
         self,
         input_ids: Optional[Tensor],
         positions: Tensor,
         key_caches: list[Tensor],
         value_caches: list[Tensor],
-        is_prefill: bool,
         slot_mapping: Tensor,
         attn_mask: Tensor,
         batch_valid_length: Tensor,
@@ -216,7 +204,6 @@ class vLLMQwen2Model(Qwen2PreTrainedModel):
                 hidden_states,
                 key_caches[i],
                 value_caches[i],
-                is_prefill,
                 slot_mapping,
                 attn_mask,
                 batch_valid_length,
@@ -262,125 +249,62 @@ class Qwen2ForCausalLM(MindONEModelBase):
         self.lora_config = lora_config
         self.quant_config = quant_config
         self.sampler = get_sampler()
-
-        self.set_modules({"model": self.model, "lm_head": self.lm_head})
-
-        self.prefill = True
-        self.mstype = STR_DTYPE_TO_MS_DTYPE.get(self.model_config.dtype,
-                                                self.model_config.dtype)
-        self.casual_mask = LowerTriangularMask(
-            dtype=self.mstype, max_model_len=self.model_config.max_model_len)
-        self.kv_caches = [
-            Fake_Attention() for i in range(config.num_hidden_layers)
-        ]
-        compilation_config = vllm_config.compilation_config
-
-        if prefix in compilation_config.static_forward_context:
-            raise ValueError(f"Duplicate layer name: {prefix}")
-        for i in range(config.num_hidden_layers):
-            compilation_config.static_forward_context[str(
-                i)] = self.kv_caches[i]
+        self.common_preprocess(vllm_config, prefix)
 
     def get_input_embeddings(self, input_ids: Tensor) -> Tensor:
         return self.model.get_input_embeddings(input_ids)
 
-    def forward(
-        self,
-        input_ids: Tensor,
-        positions: Tensor,
-        kv_caches: list[tuple[Tensor, Tensor]],
-        attn_metadata: AttentionMetadata,
-        intermediate_tensors: IntermediateTensors = None,
-        inputs_embeds: Tensor = None,
-        **kwargs,
-    ) -> Union[Tensor, IntermediateTensors]:
-        key_cache, value_cache = self.get_kvcache()
-
-        seq_lens = attn_metadata.seq_lens
-        max_query_len = attn_metadata.max_query_len
-        # When Mutli-Step is enabled with Chunked-Prefill, prefills and
-        # decodes are scheduled together. In the first step, all the
-        # prefills turn into decodes and max_query_len will be 1.
-        if self.is_multi_step_chunked_prefill and max_query_len == 1:
-            query_lens = [1] * len(seq_lens)
-        else:
-            query_lens = attn_metadata.query_lens
-
-        seq_lens_np = np.array(seq_lens, dtype=np.int32)
-        query_lens_np = np.array(query_lens, dtype=np.int32)
-        kv_cache_lens = seq_lens_np - query_lens_np
-        is_prefill = bool(attn_metadata.num_decode_tokens == 0
-                          and kv_cache_lens.max() == 0)
-        if is_prefill:
-            input_ids = ops.expand_dims(input_ids, 0)
-        else:
-            input_ids = ops.expand_dims(input_ids, 1)
-
-        slot_mapping = attn_metadata.slot_mapping
-        attn_mask = self.casual_mask.gen_attention_mask(
-            is_prefill, positions, query_lens)
-        seq_lens_np = np.array(attn_metadata.seq_lens, dtype=np.int32)
-        batch_valid_length = Tensor.from_numpy(seq_lens_np)
-        q_seq_lens = Tensor.from_numpy(
-            np.array(attn_metadata.query_lens, dtype=np.int32))
-        block_tables = attn_metadata.block_tables
-
-        model_inputs = (
-            input_ids,
-            positions,
-            key_cache,
-            value_cache,
-            mutable(is_prefill),
-            slot_mapping,
-            attn_mask,
-            batch_valid_length,
-            q_seq_lens,
-            block_tables,
+    def forward(self,
+                input_ids: Tensor,
+                positions: Tensor,
+                intermediate_tensors: IntermediateTensors = None,
+                inputs_embeds: Tensor = None,
+                **kwargs) -> Union[Tensor, IntermediateTensors]:
+        ori_model_inputs, is_prefill = self.prepare_base_inputs(
+            input_ids, positions)
+        is_prefill = bool(is_prefill)
+
+        model_inputs = (\
+            ori_model_inputs["input_ids"],
+            ori_model_inputs["position_ids"],
+            ori_model_inputs["key_cache"],
+            ori_model_inputs["value_cache"],
+            ori_model_inputs["slot_mapping"],
+            ori_model_inputs["attention_mask"],
+            ori_model_inputs["batch_valid_length"],
+            ori_model_inputs["q_seq_lens"],
+            ori_model_inputs["block_tables"],
             intermediate_tensors,
             inputs_embeds,
         )
-        if is_prefill:
-            if not self.prefill:
-                self.prefill = True
-                enable_dynamic_shape(
-                    self.model, *model_inputs
-                )  # enable dynamic shape once on first prefill step
-        else:
-            if self.prefill:
-                self.prefill = False
-                enable_dynamic_shape(
-                    self.model, *model_inputs
-                )  # enable dynamic shape once on first decode step
+        # for dummy_attention_metadata
+        if is_prefill and not self.set_flags:  # type: ignore
+            self.set_flags = True
 
-        model_output = self.model(*model_inputs)
+        # eager mode
+        if self.is_eager_mode:
+            set_model_context("is_prefill", is_prefill)
+            model_output = self.model(*model_inputs)
+            return model_output
 
+        # graph mode
         if is_prefill:
-            model_output = ops.squeeze(model_output, 0)
+            self.model.phase = "prefill"
+            if self.prefill_graph is None:
+                set_model_context("is_prefill", True)
+                self.model._set_jit_graph_name("prefill")
+                enable_dynamic_shape(self.model, *model_inputs)
+                self.prefill_graph = ms.jit(function=self.model,
+                                            jit_level="O0")
+            model_output = self.prefill_graph(*model_inputs)
         else:
-            model_output = ops.squeeze(model_output, 1)
+            self.model.phase = "increment"
+            if self.decode_graph is None:
+                set_model_context("is_prefill", False)
+                self.model._set_jit_graph_name("decode")
+                enable_dynamic_shape(self.model, *model_inputs)
+                self.decode_graph = ms.jit(function=self.model, jit_level="O0")
+            model_output = self.decode_graph(*model_inputs)
 
         return model_output
-
-    def sample(self, logits: Tensor,
-               sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
-
-    def load_weights(self, *args, **kwargs):
-        if self.config.tie_word_embeddings:
-            self.lm_head.weight.set_data(
-                self.model.embed_tokens.embedding_table.data)
-
-    def compute_logits(
-        self,
-        hidden_states: Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[Tensor]:
-        if sampling_metadata.selected_token_indices is not None:
-            hidden_states = ops.gather(
-                hidden_states, sampling_metadata.selected_token_indices, 0)
-
-        logits = self.lm_head(hidden_states).float()
-
-        return logits
diff --git a/vllm_mindspore/model_executor/models/mindone_models/qwen2_5_vl.py b/vllm_mindspore/model_executor/models/mindone_models/qwen2_5_vl.py
index 51cc8705a592f504ae8899c24f506ab284a0f4b6..e3354ba051f56434ce9df19fb5ab35480804de2d 100644
--- a/vllm_mindspore/model_executor/models/mindone_models/qwen2_5_vl.py
+++ b/vllm_mindspore/model_executor/models/mindone_models/qwen2_5_vl.py
@@ -42,23 +42,11 @@ from mindone.transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
     Qwen2_5_VLForConditionalGeneration as MindONE_Qwen2_5_VLForConditionalGeneration)  # noqa: E501
 from mindone.transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
     Qwen2_5_VLPreTrainedModel)
-from mindspore import Tensor, mint, mutable, nn, ops
+from mindspore import Tensor, mint, nn, ops
 from mindspore.common.api import _pynative_executor
-from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.config import VllmConfig
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
-from vllm.model_executor.models.qwen2_5_vl import (
-    Qwen2_5_VLConfig, Qwen2_5_VLDummyInputsBuilder)
-from vllm.model_executor.models.qwen2_5_vl import (
-    Qwen2_5_VLForConditionalGeneration as vLLM_Qwen2_5_VLForConditionalGeneration)  # noqa: E501
-from vllm.model_executor.models.qwen2_5_vl import (
-    Qwen2_5_VLImageEmbeddingInputs, Qwen2_5_VLImageInputs,
-    Qwen2_5_VLImagePixelInputs)
-from vllm.model_executor.models.qwen2_5_vl import (
-    Qwen2_5_VLMultiModalProcessor as vLLM_Qwen2_5_VLMultiModalProcessor)
-from vllm.model_executor.models.qwen2_5_vl import (
-    Qwen2_5_VLProcessingInfo, Qwen2_5_VLVideoEmbeddingInputs,
-    Qwen2_5_VLVideoInputs, Qwen2_5_VLVideoPixelInputs)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import MultiModalKwargs
 from vllm.multimodal.parse import MultiModalDataItems
@@ -67,8 +55,9 @@ from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.config import uses_mrope
 
 from vllm_mindspore.model_executor.models.attention_mask import (
-    LowerTriangularMask)
-from vllm_mindspore.model_executor.models.interfaces import SupportsMultiModal
+    MultiModalLowerTriangularMask)
+from vllm_mindspore.model_executor.models.interfaces import (
+    MultiModalEmbeddings, SupportsMultiModal)
 from vllm_mindspore.model_executor.models.mindone_models.qwen2 import (
     MindONEModelBase)
 from vllm_mindspore.model_executor.models.mindone_models.qwen2 import (
@@ -77,12 +66,21 @@ from vllm_mindspore.model_executor.models.mindone_models.qwen2 import (
     vLLMQwen2Model)
 from vllm_mindspore.model_executor.models.mindone_models.utils import (
     enable_dynamic_shape)
-from vllm_mindspore.model_executor.models.model_base import Fake_Attention
+from vllm_mindspore.model_executor.models.qwen2_5_vl import (
+    Qwen2_5_VLConfig, Qwen2_5_VLDummyInputsBuilder)
+from vllm_mindspore.model_executor.models.qwen2_5_vl import (
+    Qwen2_5_VLForConditionalGeneration as vLLM_Qwen2_5_VLForConditionalGeneration)  # noqa: E501
+from vllm_mindspore.model_executor.models.qwen2_5_vl import (
+    Qwen2_5_VLImageEmbeddingInputs, Qwen2_5_VLImageInputs,
+    Qwen2_5_VLImagePixelInputs)
+from vllm_mindspore.model_executor.models.qwen2_5_vl import (
+    Qwen2_5_VLMultiModalProcessor as vLLM_Qwen2_5_VLMultiModalProcessor)
+from vllm_mindspore.model_executor.models.qwen2_5_vl import (
+    Qwen2_5_VLProcessingInfo, Qwen2_5_VLVideoEmbeddingInputs,
+    Qwen2_5_VLVideoInputs, Qwen2_5_VLVideoPixelInputs)
 from vllm_mindspore.model_executor.models.utils import (
     maybe_prefix, merge_multimodal_embeddings)
-from vllm_mindspore.model_executor.sampling_metadata import SamplingMetadata
-from vllm_mindspore.utils import STR_DTYPE_TO_MS_DTYPE
-
+from vllm_mindspore.model_executor.utils import set_model_context
 
 # yapf:enable
 
@@ -105,17 +103,10 @@ class Qwen2ForCausalLM(vLLM_Qwen2ForCausalLM):
         self.quant_config = quant_config
         self.sampler = get_sampler()
 
-        self.set_modules({"model": self.model, "lm_head": self.lm_head})
-        self.kv_caches = [
-            Fake_Attention() for i in range(config.num_hidden_layers)
-        ]
-        compilation_config = vllm_config.compilation_config
-
-        if prefix in compilation_config.static_forward_context:
-            raise ValueError(f"Duplicate layer name: {prefix}")
-        for i in range(config.num_hidden_layers):
-            compilation_config.static_forward_context[str(
-                i)] = self.kv_caches[i]
+        self.common_preprocess(vllm_config, prefix)
+        self.casual_mask = MultiModalLowerTriangularMask(
+            dtype=vllm_config.model_config.dtype,
+            max_model_len=vllm_config.model_config.max_model_len)
 
 
 class Qwen2_5_VLMultiModalProcessor(vLLM_Qwen2_5_VLMultiModalProcessor):
@@ -207,11 +198,12 @@ class Qwen2_5_VLForConditionalGeneration(MindONEModelBase, SupportsMultiModal):
             "visual": self.visual,
             "language_model": self.language_model
         })
-        self.prefill = True
-        self.mstype = STR_DTYPE_TO_MS_DTYPE.get(self.model_config.dtype,
-                                                self.model_config.dtype)
-        self.casual_mask = LowerTriangularMask(
-            dtype=self.mstype, max_model_len=self.model_config.max_model_len)
+
+        self.casual_mask = self.language_model.casual_mask
+        self.kv_caches = self.language_model.kv_caches
+        self.is_eager_mode = vllm_config.model_config.enforce_eager
+        self.prefill_graph = None
+        self.decode_graph = None
 
     @cached_property
     def sampler(self):
@@ -318,17 +310,57 @@ class Qwen2_5_VLForConditionalGeneration(MindONEModelBase, SupportsMultiModal):
         return None
 
-    _process_image_input = vLLM_Qwen2_5_VLForConditionalGeneration._process_image_input  # noqa: E501
-    _process_video_input = vLLM_Qwen2_5_VLForConditionalGeneration._process_video_input  # noqa: E501
+    def _process_image_input(
+        self, image_input: Qwen2_5_VLImageInputs
+    ) -> tuple[mindspore.Tensor, ...]:
+
+        grid_thw = image_input["image_grid_thw"]
+        assert grid_thw.ndim == 2
+        # grid_thw_list = grid_thw.tolist()
+
+        if image_input["type"] == "image_embeds":
+            image_embeds = image_input["image_embeds"].type(self.visual.dtype)
+        else:
+            pixel_values = image_input["pixel_values"].type(self.visual.dtype)
+            image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
+
+        # Split concatenated embeddings for each image item.
+        merge_size = self.visual.spatial_merge_size
+        sizes = grid_thw.prod(-1) // merge_size // merge_size
+
+        return image_embeds.split(sizes.tolist())
+
+    def _process_video_input(
+        self, video_input: Qwen2_5_VLVideoInputs
+    ) -> tuple[mindspore.Tensor, ...]:
+
+        grid_thw = video_input["video_grid_thw"]
+        assert grid_thw.ndim == 2
+        # grid_thw_list = grid_thw.tolist()
+
+        if video_input["type"] == "video_embeds":
+            video_embeds = video_input["video_embeds"].type(self.visual.dtype)
+        else:
+            pixel_values_videos = video_input["pixel_values_videos"].type(
+                self.visual.dtype)
+            video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)
+
+        # Split concatenated embeddings for each video item.
+        merge_size = self.visual.spatial_merge_size
+        sizes = grid_thw.prod(-1) // merge_size // merge_size
+
+        return video_embeds.split(sizes.tolist())
+
     _parse_and_validate_multimodal_inputs = (
         vLLM_Qwen2_5_VLForConditionalGeneration.
         _parse_and_validate_multimodal_inputs)
 
-    get_multimodal_embeddings = None
+    get_multimodal_embeddings = (
+        vLLM_Qwen2_5_VLForConditionalGeneration.get_multimodal_embeddings)
 
     def get_input_embeddings(
         self,
         input_ids: mindspore.Tensor,
-        multimodal_embeddings: Optional[tuple[mindspore.Tensor, ...]] = None,
+        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
     ) -> mindspore.Tensor:
         inputs_embeds = self.language_model.model.embed_tokens(input_ids)
         if multimodal_embeddings is not None:
@@ -343,8 +375,8 @@ class Qwen2_5_VLForConditionalGeneration(MindONEModelBase, SupportsMultiModal):
     def get_input_embeddings_v0(
         self,
         input_ids: mindspore.Tensor,
-        image_input: Optional[tuple[mindspore.Tensor, ...]] = None,
-        video_input: Optional[tuple[mindspore.Tensor, ...]] = None,
+        image_input: Optional[Qwen2_5_VLImageInputs] = None,
+        video_input: Optional[Qwen2_5_VLVideoInputs] = None,
     ) -> mindspore.Tensor:
         inputs_embeds = self.get_input_embeddings(input_ids)
         if image_input is not None:
@@ -370,91 +402,63 @@ class Qwen2_5_VLForConditionalGeneration(MindONEModelBase, SupportsMultiModal):
         self,
         input_ids: Tensor,
         positions: Tensor,
-        kv_caches: list[tuple[Tensor, Tensor]],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: IntermediateTensors = None,
         inputs_embeds: Tensor = None,
     ):
-        key_caches, value_caches = self.language_model.get_kvcache()
-
-        seq_lens = attn_metadata.seq_lens
-        max_query_len = attn_metadata.max_query_len
-        # When Mutli-Step is enabled with Chunked-Prefill, prefills and
-        # decodes are scheduled together. In the first step, all the
-        # prefills turn into decodes and max_query_len will be 1.
-        if self.is_multi_step_chunked_prefill and max_query_len == 1:
-            query_lens = [1] * len(seq_lens)
-        else:
-            query_lens = attn_metadata.query_lens
-
-        seq_lens_np = np.array(seq_lens, dtype=np.int32)
-        query_lens_np = np.array(query_lens, dtype=np.int32)
-        kv_cache_lens = seq_lens_np - query_lens_np
-        is_prefill = bool(attn_metadata.num_decode_tokens == 0
-                          and kv_cache_lens.max() == 0)
-
-        if is_prefill > 0:
-            if input_ids is not None:
-                input_ids = input_ids.expand_dims(0)
-            else:
-                inputs_embeds = inputs_embeds.expand_dims(0)
-        else:
-            if input_ids is not None:
-                input_ids = input_ids.expand_dims(1)
-            else:
-                inputs_embeds = inputs_embeds.expand_dims(1)
-
+        ori_model_inputs, is_prefill = self.prepare_base_inputs(
+            input_ids, positions)
+        is_prefill = bool(is_prefill)
         if inputs_embeds is None:
             inputs_embeds = self.language_model.model.embed_tokens(input_ids)
             input_ids = None
 
-        slot_mapping = attn_metadata.slot_mapping
-        attn_mask = self.casual_mask.gen_attention_mask(
-            is_prefill, positions, query_lens)
-        seq_lens_np = np.array(attn_metadata.seq_lens, dtype=np.int32)
-        batch_valid_length = Tensor.from_numpy(seq_lens_np)
-        q_seq_lens = Tensor.from_numpy(
-            np.array(attn_metadata.query_lens, dtype=np.int32))
-        block_tables = attn_metadata.block_tables
-
         # keep position.ndim to 2, for work on mindspore dynamic shape
         if positions.ndim == 1:
             positions = positions[None]
 
-        model_inputs = (
-            input_ids,
-            positions,
-            key_caches,
-            value_caches,
-            mutable(is_prefill),
-            slot_mapping,
-            attn_mask,
-            batch_valid_length,
-            q_seq_lens,
-            block_tables,
+        model_inputs = (\
+            ori_model_inputs["input_ids"],
+            ori_model_inputs["position_ids"],
+            ori_model_inputs["key_cache"],
+            ori_model_inputs["value_cache"],
+            ori_model_inputs["slot_mapping"],
+            ori_model_inputs["attention_mask"],
+            ori_model_inputs["batch_valid_length"],
+            ori_model_inputs["q_seq_lens"],
+            ori_model_inputs["block_tables"],
             intermediate_tensors,
             inputs_embeds,
         )
 
-        if is_prefill:
-            if not self.prefill:
-                self.prefill = True
-                enable_dynamic_shape(
-                    self.language_model.model, *model_inputs
-                )  # enable dynamic shape once on first prefill step
-        else:
-            if self.prefill:
-                self.prefill = False
-                enable_dynamic_shape(
-                    self.language_model.model, *model_inputs
-                )  # enable dynamic shape once on first decode step
+        # for dummy_attention_metadata
+        if is_prefill and not self.set_flags:  # type: ignore
+            self.set_flags = True
 
-        hidden_states = self.language_model.model(*model_inputs)
+        # eager mode
+        if self.is_eager_mode:
+            set_model_context("is_prefill", is_prefill)
+            hidden_states = self.language_model.model(*model_inputs)
+            return hidden_states
 
+        # graph mode
         if is_prefill:
-            hidden_states = ops.squeeze(hidden_states, 0)
+            self.language_model.model.phase = "prefill"
+            if self.prefill_graph is None:
+                set_model_context("is_prefill", True)
+                self.language_model.model._set_jit_graph_name("prefill")
+                enable_dynamic_shape(self.language_model.model, *model_inputs)
+                self.prefill_graph = mindspore.jit(
+                    function=self.language_model.model, jit_level="O0")
+            hidden_states = self.prefill_graph(*model_inputs)
         else:
-            hidden_states = ops.squeeze(hidden_states, 1)
+            self.language_model.model.phase = "increment"
+            if self.decode_graph is None:
+                set_model_context("is_prefill", False)
+                self.language_model.model._set_jit_graph_name("decode")
+                enable_dynamic_shape(self.language_model.model, *model_inputs)
+                self.decode_graph = mindspore.jit(
+                    function=self.language_model.model, jit_level="O0")
+            hidden_states = self.decode_graph(*model_inputs)
 
         return hidden_states
 
@@ -462,8 +466,6 @@ class Qwen2_5_VLForConditionalGeneration(MindONEModelBase, SupportsMultiModal):
         self,
         input_ids: mindspore.Tensor,
         positions: mindspore.Tensor,
-        kv_caches: list[mindspore.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[mindspore.Tensor] = None,
         **kwargs: object,
@@ -514,14 +516,9 @@ class Qwen2_5_VLForConditionalGeneration(MindONEModelBase, SupportsMultiModal):
                 video_input=video_input)
             input_ids = None
 
-        hidden_states = self.run_language_model(
-            input_ids,
-            positions,
-            kv_caches,
-            attn_metadata,
-            intermediate_tensors,
-            inputs_embeds,
-        )
+        hidden_states = self.run_language_model(input_ids, positions,
+                                                intermediate_tensors,
+                                                inputs_embeds)
 
         return hidden_states
 
diff --git a/vllm_mindspore/model_executor/models/mindone_models/qwen3.py b/vllm_mindspore/model_executor/models/mindone_models/qwen3.py
index 0546cfefdb315cb1f8fd04d0fd4b6aacb70cf987..dd185dda014861ca70d714f0f904c1946ef39ca7 100644
--- a/vllm_mindspore/model_executor/models/mindone_models/qwen3.py
+++ b/vllm_mindspore/model_executor/models/mindone_models/qwen3.py
@@ -25,27 +25,23 @@ from typing import TYPE_CHECKING, Optional, Union
 if TYPE_CHECKING:
     from transformers import Qwen3Config
 
-import numpy as np
+import mindspore as ms
 from mindone.transformers.models.qwen3.modeling_qwen3 import (
     Qwen3Attention, Qwen3MLP, Qwen3PreTrainedModel, Qwen3RMSNorm)
-from mindspore import Tensor, jit, mint, mutable, nn, ops
+from mindspore import Tensor, mint, nn
 from mindspore.common import dtype as mstype
-from vllm.attention.backends.abstract import AttentionMetadata, AttentionType
+from vllm.attention.backends.abstract import AttentionType
 from vllm.config import VllmConfig
-from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
+from vllm.model_executor.layers.sampler import get_sampler
 from vllm.sequence import IntermediateTensors
 
 from vllm_mindspore.attention import Attention
 from vllm_mindspore.model_executor.layers.rotary_embedding import get_rope
-from vllm_mindspore.model_executor.models.attention_mask import (
-    LowerTriangularMask)
 from vllm_mindspore.model_executor.models.mindone_models.base import (
     MindONEModelBase)
 from vllm_mindspore.model_executor.models.mindone_models.utils import (
     enable_dynamic_shape)
-from vllm_mindspore.model_executor.models.model_base import Fake_Attention
-from vllm_mindspore.model_executor.sampling_metadata import SamplingMetadata
-from vllm_mindspore.utils import STR_DTYPE_TO_MS_DTYPE
+from vllm_mindspore.model_executor.utils import set_model_context
 
 
 class vLLMQwen3Attention(Qwen3Attention):
@@ -75,14 +71,12 @@ class vLLMQwen3Attention(Qwen3Attention):
         self.q_norm = Qwen3RMSNorm(self.head_dim, eps=self.config.rms_norm_eps)
         self.k_norm = Qwen3RMSNorm(self.head_dim, eps=self.config.rms_norm_eps)
 
-    @jit
     def construct(
         self,
         positions: Tensor,
         hidden_states: Tensor,
         key_cache: Tensor,
         value_cache: Tensor,
-        is_prefill: bool,
         slot_mapping: Tensor,
         attn_mask: Tensor,
         batch_valid_length: Tensor,
@@ -102,14 +96,13 @@ class vLLMQwen3Attention(Qwen3Attention):
         k_by_head = self.k_norm(k_by_head)
         k = k_by_head.view(k.shape)
 
-        q, k = self.rotary_emb(positions, q, k, batch_valid_length, is_prefill)
+        q, k = self.rotary_emb(positions, q, k, batch_valid_length)
         attn_output = self.attn(
             q,
             k,
             v,
             key_cache,
             value_cache,
-            is_prefill,
             slot_mapping,
             attn_mask,
             batch_valid_length,
@@ -132,14 +125,12 @@ class vLLMQwen3DecoderLayer(nn.Cell):
 
         self.self_attn = vLLMQwen3Attention(config, layer_idx)
 
-    @jit
     def construct(
         self,
         positions: Tensor,
         hidden_states: Tensor,
         key_cache: Tensor,
         value_cache: Tensor,
-        is_prefill: bool,
         slot_mapping: Tensor,
         attn_mask: Tensor,
         batch_valid_length: Tensor,
@@ -155,7 +146,6 @@ class vLLMQwen3DecoderLayer(nn.Cell):
             hidden_states,
             key_cache,
             value_cache,
-            is_prefill,
             slot_mapping,
             attn_mask,
             batch_valid_length,
@@ -197,14 +187,12 @@ class vLLMQwen3Model(Qwen3PreTrainedModel):
     def get_input_embeddings(self):
         return self.embed_tokens
 
-    @jit
     def construct(
         self,
         input_ids: Optional[Tensor],
         positions: Tensor,
         key_caches: list[Tensor],
         value_caches: list[Tensor],
-        is_prefill: bool,
         slot_mapping: Tensor,
         attn_mask: Tensor,
         batch_valid_length: Tensor,
@@ -223,7 +211,6 @@ class vLLMQwen3Model(Qwen3PreTrainedModel):
                 hidden_states,
                 key_caches[i],
                 value_caches[i],
-                is_prefill,
                 slot_mapping,
                 attn_mask,
                 batch_valid_length,
@@ -270,124 +257,62 @@ class Qwen3ForCausalLM(MindONEModelBase):
         self.quant_config = quant_config
         self.sampler = get_sampler()
 
-        self.set_modules({"model": self.model, "lm_head": self.lm_head})
-
-        self.prefill = True
-        self.mstype = STR_DTYPE_TO_MS_DTYPE.get(self.model_config.dtype,
-                                                self.model_config.dtype)
-        self.casual_mask = LowerTriangularMask(
-            dtype=self.mstype, max_model_len=self.model_config.max_model_len)
-        self.kv_caches = [
-            Fake_Attention() for i in range(config.num_hidden_layers)
-        ]
-        compilation_config = vllm_config.compilation_config
-
-        if prefix in compilation_config.static_forward_context:
-            raise ValueError(f"Duplicate layer name: {prefix}")
-        for i in range(config.num_hidden_layers):
-            compilation_config.static_forward_context[str(
-                i)] = self.kv_caches[i]
+        self.common_preprocess(vllm_config, prefix)
 
     def get_input_embeddings(self, input_ids: Tensor) -> Tensor:
         return self.model.get_input_embeddings(input_ids)
 
-    def forward(
-        self,
-        input_ids: Tensor,
-        positions: Tensor,
-        kv_caches: list[tuple[Tensor, Tensor]],
-        attn_metadata: AttentionMetadata,
-        intermediate_tensors: IntermediateTensors = None,
-        inputs_embeds: Tensor = None,
-        **kwargs,
-    ) -> Union[Tensor, IntermediateTensors]:
-        key_cache, value_cache = self.get_kvcache()
-
-        seq_lens = attn_metadata.seq_lens
-        max_query_len = attn_metadata.max_query_len
-        # When Mutli-Step is enabled with Chunked-Prefill, prefills and
-        # decodes are scheduled together. In the first step, all the
-        # prefills turn into decodes and max_query_len will be 1.
-        if self.is_multi_step_chunked_prefill and max_query_len == 1:
-            query_lens = [1] * len(seq_lens)
-        else:
-            query_lens = attn_metadata.query_lens
-
-        seq_lens_np = np.array(seq_lens, dtype=np.int32)
-        query_lens_np = np.array(query_lens, dtype=np.int32)
-        kv_cache_lens = seq_lens_np - query_lens_np
-        is_prefill = bool(attn_metadata.num_decode_tokens == 0
-                          and kv_cache_lens.max() == 0)
-        if is_prefill:
-            input_ids = ops.expand_dims(input_ids, 0)
-        else:
-            input_ids = ops.expand_dims(input_ids, 1)
-
-        slot_mapping = attn_metadata.slot_mapping
-        attn_mask = self.casual_mask.gen_attention_mask(
-            is_prefill, positions, query_lens)
-        seq_lens_np = np.array(attn_metadata.seq_lens, dtype=np.int32)
-        batch_valid_length = Tensor.from_numpy(seq_lens_np)
-        q_seq_lens = Tensor.from_numpy(
-            np.array(attn_metadata.query_lens, dtype=np.int32))
-        block_tables = attn_metadata.block_tables
-
-        model_inputs = (
-            input_ids,
-            positions,
-            key_cache,
-            value_cache,
-            mutable(is_prefill),
-            slot_mapping,
-            attn_mask,
-            batch_valid_length,
-            q_seq_lens,
-            block_tables,
+    def forward(self,
+                input_ids: Tensor,
+                positions: Tensor,
+                intermediate_tensors: IntermediateTensors = None,
+                inputs_embeds: Tensor = None,
+                **kwargs) -> Union[Tensor, IntermediateTensors]:
+        ori_model_inputs, is_prefill = self.prepare_base_inputs(
+            input_ids, positions)
+        is_prefill = bool(is_prefill)
+
+        model_inputs = (\
+            ori_model_inputs["input_ids"],
+            ori_model_inputs["position_ids"],
+            ori_model_inputs["key_cache"],
+            ori_model_inputs["value_cache"],
+            ori_model_inputs["slot_mapping"],
+            ori_model_inputs["attention_mask"],
+            ori_model_inputs["batch_valid_length"],
+            ori_model_inputs["q_seq_lens"],
+            ori_model_inputs["block_tables"],
             intermediate_tensors,
             inputs_embeds,
         )
-        if is_prefill:
-            if not self.prefill:
-                self.prefill = True
-                enable_dynamic_shape(
-                    self.model, *model_inputs
-                )  # enable dynamic shape once on first prefill step
-        else:
-            if self.prefill:
-                self.prefill = False
-                enable_dynamic_shape(
-                    self.model, *model_inputs
-                )  # enable dynamic shape once on first decode step
+        # for dummy_attention_metadata
+        if is_prefill and not self.set_flags:  # type: ignore
+            self.set_flags = True
 
-        model_output = self.model(*model_inputs)
+        # eager mode
+        if self.is_eager_mode:
+            set_model_context("is_prefill", is_prefill)
+            model_output = self.model(*model_inputs)
+            return model_output
 
+        # graph mode
         if is_prefill:
-            model_output = ops.squeeze(model_output, 0)
+            self.model.phase = "prefill"
+            if self.prefill_graph is None:
+                set_model_context("is_prefill", True)
+                self.model._set_jit_graph_name("prefill")
+                enable_dynamic_shape(self.model, *model_inputs)
+                self.prefill_graph = ms.jit(function=self.model,
+                                            jit_level="O0")
+            model_output = self.prefill_graph(*model_inputs)
         else:
-            model_output = ops.squeeze(model_output, 1)
+            self.model.phase = "increment"
+            if self.decode_graph is None:
+                set_model_context("is_prefill", False)
+                self.model._set_jit_graph_name("decode")
+                enable_dynamic_shape(self.model, *model_inputs)
+                self.decode_graph = ms.jit(function=self.model, jit_level="O0")
+            model_output = self.decode_graph(*model_inputs)
 
         return model_output
-
-    def load_weights(self, *args, **kwargs):
-        if self.config.tie_word_embeddings:
-            self.lm_head.weight.set_data(
-                self.model.embed_tokens.embedding_table.data)
-
-    def sample(self, logits: Tensor,
-               sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
-
-    def compute_logits(
-        self,
-        hidden_states: Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[Tensor]:
-        if sampling_metadata.selected_token_indices is not None:
-            hidden_states = ops.gather(
-                hidden_states, sampling_metadata.selected_token_indices, 0)
-
-        logits = self.lm_head(hidden_states).float()
-
-        return logits
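
The pattern shared by the three rewritten forward() methods above is lazy per-phase graph capture: the first prefill step and the first decode step each compile the underlying model once with ms.jit (jit_level="O0", mirroring the calls in the patch), the compiled callable is cached on the wrapper and reused on every later step, and enforce_eager skips compilation entirely. Below is a minimal, self-contained sketch of that pattern only; TinyModel and PhaseGraphRunner are hypothetical stand-ins for self.model and the Qwen*ForCausalLM wrappers, not part of the patch, and it assumes a MindSpore version whose ms.jit accepts function= and jit_level= as used above.

import mindspore as ms
from mindspore import Tensor, nn, ops


class TinyModel(nn.Cell):
    """Hypothetical stand-in for the real decoder stack (self.model)."""

    def construct(self, hidden_states: Tensor) -> Tensor:
        return ops.relu(hidden_states)


class PhaseGraphRunner:
    """Caches one compiled graph for prefill and one for decode."""

    def __init__(self, model: nn.Cell):
        self.model = model
        self.prefill_graph = None
        self.decode_graph = None

    def run(self, hidden_states: Tensor, is_prefill: bool) -> Tensor:
        if is_prefill:
            if self.prefill_graph is None:
                # Compiled once, on the first prefill step.
                self.prefill_graph = ms.jit(function=self.model,
                                            jit_level="O0")
            return self.prefill_graph(hidden_states)
        if self.decode_graph is None:
            # Compiled once, on the first decode step.
            self.decode_graph = ms.jit(function=self.model, jit_level="O0")
        return self.decode_graph(hidden_states)


if __name__ == "__main__":
    runner = PhaseGraphRunner(TinyModel())
    x = Tensor([[1.0, -2.0, 3.0]], ms.float32)
    print(runner.run(x, is_prefill=True))   # triggers prefill compilation
    print(runner.run(x, is_prefill=False))  # triggers decode compilation
    print(runner.run(x, is_prefill=False))  # reuses the cached decode graph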