From 43db8b7791eb11bc2f75b2d0c868c3f0c18ac4c9 Mon Sep 17 00:00:00 2001
From: zhanzhan1
Date: Wed, 14 May 2025 09:58:14 +0800
Subject: [PATCH 1/4] Support for enable_prefix_caching

---
 vllm_mindspore/attention/layer.py             | 11 ++-
 .../model_executor/layers/logits_processor.py | 18 +++-
 .../models/{mf_models => }/attention_mask.py  | 23 ++----
 .../models/mf_models/mf_model_base.py         |  5 +-
 .../model_executor/models/model_base.py       | 44 +----
 vllm_mindspore/model_executor/models/qwen2.py | 73 ++++++++++++++---
 vllm_mindspore/utils.py                       |  9 +--
 7 files changed, 93 insertions(+), 90 deletions(-)
 rename vllm_mindspore/model_executor/models/{mf_models => }/attention_mask.py (59%)

diff --git a/vllm_mindspore/attention/layer.py b/vllm_mindspore/attention/layer.py
index 4634727b9..f5f160851 100644
--- a/vllm_mindspore/attention/layer.py
+++ b/vllm_mindspore/attention/layer.py
@@ -157,11 +157,10 @@ class Attention(nn.Cell):
         value_cache: Tensor,
         is_prefill: bool,
         slot_mapping: Tensor,
-        batch_valid_length: Tuple[int],
+        attn_mask: Tensor,
+        batch_valid_length: Tensor,
         q_seq_lens: Tensor,
         block_tables: Tensor,
-        attn_mask: Tensor,
-        decode_mask: Tensor,
     ) -> Tensor:
         """Attention foward, support MHA and GQA.

@@ -181,7 +180,7 @@ class Attention(nn.Cell):
             output = self._run_prefill_forward(query, key, value, attn_mask, batch_valid_length, batch_valid_length)
         else:
             output = self._run_decode_forward(query, key_cache, value_cache, block_tables, batch_valid_length,
-                                              decode_mask, q_seq_lens)
+                                              attn_mask, q_seq_lens)
         return output

     def _run_prefill_forward(
@@ -228,7 +227,7 @@ class Attention(nn.Cell):
         value_cache: Tensor,
         block_tables: Tensor,
         batch_valid_length: Tensor,
-        decode_mask: Tensor,
+        attn_mask: Tensor,
         q_seq_lens: Tensor,
     ) -> Tensor:
         """Decode with PagedAttention.
@@ -248,7 +247,7 @@ class Attention(nn.Cell):
             batch_valid_length,
             None,
             None,
-            decode_mask,
+            attn_mask * -10000,
             q_seq_lens
         )
         return output
diff --git a/vllm_mindspore/model_executor/layers/logits_processor.py b/vllm_mindspore/model_executor/layers/logits_processor.py
index 647b4ac83..75f35d6da 100644
--- a/vllm_mindspore/model_executor/layers/logits_processor.py
+++ b/vllm_mindspore/model_executor/layers/logits_processor.py
@@ -41,6 +41,7 @@ if envs.VLLM_LOGITS_PROCESSOR_THREADS is not None:
     _logits_processor_threadpool = ThreadPoolExecutor(
         envs.VLLM_LOGITS_PROCESSOR_THREADS)

+
 class LogitsProcessor(nn.Cell):
     """Process logits and apply logits processors from sampling metadata.

@@ -88,6 +89,8 @@ class LogitsProcessor(nn.Cell):
             logits = hidden_states
         else:
             if sampling_metadata is not None:
+                if sampling_metadata.selected_token_indices.numel() <= 0:
+                    return mint.zeros((0, self.vocab_size), dtype=hidden_states.dtype)
                 hidden_states = _prune_hidden_states(hidden_states, sampling_metadata)

             # Get the logits for the next tokens.
@@ -102,7 +105,7 @@ class LogitsProcessor(nn.Cell):
                 logits *= self.scale

             # Apply logits processors (if any).
-            if sampling_metadata is not None:
+            if sampling_metadata.seq_groups is not None:
                 logits = _apply_logits_processors(logits, sampling_metadata)

         return logits
@@ -146,10 +149,10 @@ def _prune_hidden_states(
     # NOTE(kzawora): The if guard is needed for Gaudi - in some scenarios
     # (warmup, profile_run) we might not have selected_token_indices,
     # so we skip pruning.
-    if sampling_metadata.selected_token_indices is not None:
-        return ops.gather(hidden_states, sampling_metadata.selected_token_indices, 0)
-    else:
-        return hidden_states
+    indices = sampling_metadata.selected_token_indices
+    if indices is not None and indices.numel() > 0:
+        return mint.index_select(hidden_states, 0, sampling_metadata.selected_token_indices)
+    return hidden_states


 def _apply_logits_processors(
@@ -187,7 +190,7 @@ def _apply_logits_processors(
             logits_processed += len(seq_group.sample_indices) + len(
                 seq_group.prompt_logprob_indices
             )
-
+    
     for logits_row_idx, future in logits_row_ids_and_logits_row_futures:
         logits[logits_row_idx] = future.result()

@@ -196,6 +199,7 @@ def _apply_logits_processors(
     assert logits_processed == logits.shape[0]
     return logits

+
 def _apply_logits_processors_single_seq(logits_row, logits_processors,
                                         past_tokens_ids,
                                         prompt_tokens_ids) -> Tensor:
@@ -206,4 +210,4 @@ def _apply_logits_processors_single_seq(logits_row, logits_processors,
                                           logits_row)
         else:
             logits_row = logits_processor(past_tokens_ids, logits_row)
-    return logits_row
\ No newline at end of file
+    return logits_row
diff --git a/vllm_mindspore/model_executor/models/mf_models/attention_mask.py b/vllm_mindspore/model_executor/models/attention_mask.py
similarity index 59%
rename from vllm_mindspore/model_executor/models/mf_models/attention_mask.py
rename to vllm_mindspore/model_executor/models/attention_mask.py
index 10fcd25ec..3a9fe6983 100644
--- a/vllm_mindspore/model_executor/models/mf_models/attention_mask.py
+++ b/vllm_mindspore/model_executor/models/attention_mask.py
@@ -18,36 +18,29 @@
 infer attention mask.
 """
 import numpy as np
-import mindspore as ms
-from mindspore import Tensor, JitConfig, Model
+from mindspore import Tensor, mint


 class LowerTriangularMask:
     r"""
     Provide Infer model attention mask.

     Args:
-        mf_model_config (MF Config): The config of Infer model.
+        dtype (mstype): The dtype of the mask.
+        max_model_len (int): The maximum length of the model.
""" - def __init__(self, mf_model_config): - compute_dtype = mf_model_config.compute_dtype - seq_length = mf_model_config.seq_length - self.prefill_mask = Tensor(np.triu(np.ones(shape=(128, 128), dtype=np.float16), k=1), dtype=compute_dtype) - - self.decode_mask = Tensor(np.triu(np.ones(shape=(seq_length, seq_length), dtype=np.int8), k=1), - dtype=compute_dtype) - - self.hard_mask = Tensor([0], dtype=compute_dtype).reshape(1, 1) - - self.gather = ms.ops.Gather() + def __init__(self, dtype, max_model_len): + self.prefill_mask = Tensor.from_numpy(np.triu(np.ones(shape=(128, 128), dtype=np.float16), k=1)).to(dtype) + self.decode_mask = Tensor.from_numpy(np.triu(np.ones(shape=(max_model_len, max_model_len), dtype=np.int8), k=1)).to(dtype) + self.hard_mask = mint.zeros((1, 1), dtype=dtype) def gen_attention_mask(self, is_prefill, position_ids, query_lens): if is_prefill: attention_mask = self.prefill_mask else: if max(query_lens) > 1: - attention_mask = self.gather(self.decode_mask, position_ids, 0) + attention_mask = mint.index_select(self.decode_mask, 0, position_ids) else: attention_mask = self.hard_mask return attention_mask diff --git a/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py b/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py index d8769c37d..51f0d9858 100644 --- a/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py +++ b/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py @@ -40,7 +40,7 @@ from mindformers.core.context import build_mf_context from mindformers.core.parallel_config import build_parallel_config from vllm_mindspore.model_executor.models.model_base import MsModelBase -from vllm_mindspore.model_executor.models.mf_models.attention_mask import LowerTriangularMask +from vllm_mindspore.model_executor.models.attention_mask import LowerTriangularMask logger = init_logger(__name__) @@ -73,7 +73,8 @@ class MfModelBase(MsModelBase): ) self.mf_config.model.model_config.parallel_config.pipeline_stage = 1 self._generate_model_config() - self.casual_mask = LowerTriangularMask(mf_model_config=self.mf_model_config) + self.casual_mask = LowerTriangularMask(dtype=self.mf_model_config.compute_dtype, + max_model_len=self.mf_model_config.seq_length) self.network, self.lm_head = self._create_network() affinity_config = self.mf_config.get('context', {}).get('affinity_cpu_list', {}) if isinstance(affinity_config, dict): diff --git a/vllm_mindspore/model_executor/models/model_base.py b/vllm_mindspore/model_executor/models/model_base.py index b97d71526..961f54a2d 100644 --- a/vllm_mindspore/model_executor/models/model_base.py +++ b/vllm_mindspore/model_executor/models/model_base.py @@ -31,9 +31,7 @@ from vllm.forward_context import get_forward_context import torch from mindspore import Tensor, nn, mutable -from mindspore import dtype as mstype -from vllm_mindspore.utils import STR_DTYPE_TO_MS_DTYPE class Fake_Attention: def __init__(self): @@ -64,6 +62,7 @@ class Fake_MLA(Fake_Attention): for _ in range(vllm_config.parallel_config.pipeline_parallel_size) ] + class MsModelBase(): def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: super(MsModelBase, self).__init__() @@ -187,47 +186,6 @@ class MsModelBase(): ) -> Union[Tensor, IntermediateTensors]: raise NotImplementedError - def set_model_inputs(self, is_prefill): - dyn_input_ids = Tensor(shape=[None, None], dtype=mstype.int64) - dyn_position_ids = Tensor(shape=[None], dtype=mstype.int64) - - block_size = self.cache_config.block_size - num_kv_heads = 
self.model_config.get_num_kv_heads(self.parallel_config) - head_size = self.model_config.get_head_size() - kv_cache_shape = (None, block_size, num_kv_heads, head_size) - - kv_cache_dtype = self.model_config.dtype if self.cache_config.cache_dtype == "auto" \ - else self.cache_config.cache_dtype - kv_cache_dtype = STR_DTYPE_TO_MS_DTYPE[kv_cache_dtype] - - num_layers = self.model_config.get_num_layers(self.parallel_config) - - dyn_key_cache = mutable(Tensor(shape=kv_cache_shape, dtype=kv_cache_dtype)) - dyn_value_cache = mutable(Tensor(shape=kv_cache_shape, dtype=kv_cache_dtype)) - dyn_key_caches = mutable([dyn_key_cache for _ in range(num_layers)]) - dyn_value_caches = mutable([dyn_value_cache for _ in range(num_layers)]) - - dyn_batch_valid_length = Tensor(shape=[None, ], dtype=mstype.int32) - dyn_q_seq_lens = Tensor(shape=[None, ], dtype=mstype.int32) - dyn_slot_mapping = Tensor(shape=[None, ], dtype=mstype.int32) - dyn_block_tables = Tensor(shape=[None, None], dtype=mstype.int32) - dyn_intermediate_tensors = None - dyn_inputs_embeds = None - - self.model.set_inputs( - dyn_input_ids, - dyn_position_ids, - dyn_key_caches, - dyn_value_caches, - is_prefill, - dyn_slot_mapping, - dyn_batch_valid_length, - dyn_q_seq_lens, - dyn_block_tables, - dyn_intermediate_tensors, - dyn_inputs_embeds - ) - def get_kvcache(self): key_cache = [] value_cache = [] diff --git a/vllm_mindspore/model_executor/models/qwen2.py b/vllm_mindspore/model_executor/models/qwen2.py index 32d9da8d9..5eb70a827 100644 --- a/vllm_mindspore/model_executor/models/qwen2.py +++ b/vllm_mindspore/model_executor/models/qwen2.py @@ -15,7 +15,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -from vllm.config import get_current_vllm_config from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union, Iterable if TYPE_CHECKING: @@ -25,7 +24,7 @@ else: import numpy as np -from mindspore import Parameter, Tensor, mint, nn, jit, ops +from mindspore import Parameter, Tensor, mint, nn, jit, ops, mutable from mindspore.common import dtype as mstype @@ -49,6 +48,8 @@ from vllm_mindspore.model_executor.models.utils import ( maybe_prefix) from vllm_mindspore.model_executor.sampling_metadata import SamplingMetadata from vllm_mindspore.model_executor.models.model_base import MsModelBase, Fake_Attention +from vllm_mindspore.model_executor.models.attention_mask import LowerTriangularMask +from vllm_mindspore.utils import STR_DTYPE_TO_MS_DTYPE from vllm.config import CacheConfig, VllmConfig @@ -173,8 +174,6 @@ class Qwen2Attention(nn.Cell): prefix=f"{prefix}.attn", attn_type=attn_type ) - self.attn_mask = mint.triu(mint.ones(size=(128, 128), dtype=mstype.bfloat16), 1) - self.hard_mask = Tensor([0], dtype=mstype.bfloat16).reshape(1, 1) @jit def construct( @@ -185,15 +184,16 @@ class Qwen2Attention(nn.Cell): value_cache: Tensor, is_prefill: bool, slot_mapping: Tensor, - batch_valid_length: Tuple[int], + attn_mask: Tensor, + batch_valid_length: Tensor, q_seq_lens: Tensor, block_tables: Tensor, ) -> Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = mint.split(qkv, (self.q_size, self.kv_size, self.kv_size), -1) - q, k = self.rotary_emb(positions, q, k, q_seq_lens, is_prefill) - attn_output = self.attn(q, k, v, key_cache, value_cache, is_prefill, slot_mapping, batch_valid_length, - q_seq_lens, block_tables, self.attn_mask, self.hard_mask) + q, k = self.rotary_emb(positions, q, k, batch_valid_length, 
is_prefill) + attn_output = self.attn(q, k, v, key_cache, value_cache, is_prefill, slot_mapping, attn_mask, + batch_valid_length, q_seq_lens, block_tables) output, _ = self.o_proj(attn_output) return output @@ -257,7 +257,8 @@ class Qwen2DecoderLayer(nn.Cell): value_cache: Tensor, is_prefill: bool, slot_mapping: Tensor, - batch_valid_length: Tuple[int], + attn_mask: Tensor, + batch_valid_length: Tensor, q_seq_lens: Tensor, block_tables: Tensor, residual: Optional[Tensor], @@ -275,6 +276,7 @@ class Qwen2DecoderLayer(nn.Cell): value_cache, is_prefill, slot_mapping, + attn_mask, batch_valid_length, q_seq_lens, block_tables @@ -342,6 +344,7 @@ class Qwen2Model(nn.Cell): value_caches: List[Tensor], is_prefill: bool, slot_mapping: Tensor, + attn_mask: Tensor, batch_valid_length: Tensor, q_seq_lens: Tensor, block_tables: Tensor, @@ -367,6 +370,7 @@ class Qwen2Model(nn.Cell): value_caches[i - self.start_layer], is_prefill, slot_mapping, + attn_mask, batch_valid_length, q_seq_lens, block_tables, @@ -486,6 +490,9 @@ class Qwen2ForCausalLM(MsModelBase): self.set_modules({"model": self.model, "lm_head": self.lm_head}) self.prefill = True + self.mstype = STR_DTYPE_TO_MS_DTYPE.get(self.model_config.dtype, self.model_config.dtype) + self.casual_mask = LowerTriangularMask(dtype=self.mstype, + max_model_len=self.model_config.max_model_len) self.set_model_inputs(self.prefill) self.kv_caches = [Fake_Attention() for i in range(config.num_hidden_layers)] compilation_config = vllm_config.compilation_config @@ -495,8 +502,47 @@ class Qwen2ForCausalLM(MsModelBase): for i in range(config.num_hidden_layers): compilation_config.static_forward_context[str(i)] = self.kv_caches[i] - def get_input_embeddings(self, input_ids: Tensor) -> Tensor: - return self.model.get_input_embeddings(input_ids) + def set_model_inputs(self, is_prefill): + dyn_input_ids = Tensor(shape=[None, None], dtype=mstype.int64) + dyn_position_ids = Tensor(shape=[None], dtype=mstype.int64) + + block_size = self.cache_config.block_size + num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config) + head_size = self.model_config.get_head_size() + kv_cache_shape = (None, block_size, num_kv_heads, head_size) + + kv_cache_dtype = self.model_config.dtype if self.cache_config.cache_dtype == "auto" \ + else self.cache_config.cache_dtype + kv_cache_dtype = STR_DTYPE_TO_MS_DTYPE[kv_cache_dtype] + + num_layers = self.model_config.get_num_layers(self.parallel_config) + + dyn_key_cache = Tensor(shape=kv_cache_shape, dtype=kv_cache_dtype) + dyn_value_cache = Tensor(shape=kv_cache_shape, dtype=kv_cache_dtype) + dyn_key_caches = mutable([dyn_key_cache for _ in range(num_layers)]) + dyn_value_caches = mutable([dyn_value_cache for _ in range(num_layers)]) + + dyn_slot_mapping = Tensor(shape=[None, ], dtype=mstype.int32) + dynamic_attention_mask = Tensor(shape=[None, None], dtype=self.mstype) + dyn_batch_valid_length = Tensor(shape=[None,], dtype=mstype.int32) + dyn_q_seq_lens = Tensor(shape=[None, ], dtype=mstype.int32) + dyn_block_tables = Tensor(shape=[None, None], dtype=mstype.int32) + dyn_intermediate_tensors = None + dyn_inputs_embeds = None + self.model.set_inputs( + dyn_input_ids, + dyn_position_ids, + dyn_key_caches, + dyn_value_caches, + is_prefill, + dyn_slot_mapping, + dynamic_attention_mask, + dyn_batch_valid_length, + dyn_q_seq_lens, + dyn_block_tables, + dyn_intermediate_tensors, + dyn_inputs_embeds + ) def forward( self, @@ -535,7 +581,9 @@ class Qwen2ForCausalLM(MsModelBase): self.set_model_inputs(self.prefill) slot_mapping = 
attn_metadata.slot_mapping - batch_valid_length = Tensor.from_numpy(np.array(attn_metadata.seq_lens, dtype=np.int32)) + attn_mask = self.casual_mask.gen_attention_mask(is_prefill, positions, query_lens) + seq_lens_np = np.array(attn_metadata.seq_lens, dtype=np.int32) + batch_valid_length = Tensor.from_numpy(seq_lens_np) q_seq_lens = Tensor.from_numpy(np.array(attn_metadata.query_lens, dtype=np.int32)) block_tables = attn_metadata.block_tables model_output = self.model(input_ids, @@ -544,6 +592,7 @@ class Qwen2ForCausalLM(MsModelBase): value_cache, is_prefill, slot_mapping, + attn_mask, batch_valid_length, q_seq_lens, block_tables, diff --git a/vllm_mindspore/utils.py b/vllm_mindspore/utils.py index f1939751b..b67c3cc87 100644 --- a/vllm_mindspore/utils.py +++ b/vllm_mindspore/utils.py @@ -229,6 +229,10 @@ def check_ready(): # Common environment variables of predict. set_context(jit_config={"jit_level": "O0", "infer_boost": "on"}) + default_env = { + "MS_INTERNAL_DISABLE_CUSTOM_KERNEL_LIST": "FlashAttentionScore,PagedAttention", + } + env_setup(default_env) if os.getenv("MS_MEMPOOL_BLOCK_SIZE"): set_context(mempool_block_size=f"{os.environ['MS_MEMPOOL_BLOCK_SIZE']}GB") @@ -243,11 +247,6 @@ def check_ready(): 'For "MindFormers" model backend, environments %s should be set!' % str(lost_envs) ) - - mindformers_default_env = { - "MS_INTERNAL_DISABLE_CUSTOM_KERNEL_LIST": "FlashAttentionScore,PagedAttention", - } - env_setup(mindformers_default_env) elif is_mindway_model_backend(): logger.info("Run with MindWAY backend!") else: -- Gitee From 73212f83f5c8e98a4d112385d1933248c287d8df Mon Sep 17 00:00:00 2001 From: alien_0119 Date: Thu, 8 May 2025 17:38:08 +0800 Subject: [PATCH 2/4] add qwen2.5-vl --- .../model_executor/layers/rotary_embedding.py | 48 +- .../model_executor/models/model_base.py | 4 +- .../model_executor/models/qwen2_5_vl.py | 1077 +++++++++++++++++ .../model_executor/models/registry.py | 1 + vllm_mindspore/worker/worker.py | 10 +- 5 files changed, 1120 insertions(+), 20 deletions(-) create mode 100644 vllm_mindspore/model_executor/models/qwen2_5_vl.py diff --git a/vllm_mindspore/model_executor/layers/rotary_embedding.py b/vllm_mindspore/model_executor/layers/rotary_embedding.py index c9dfe254d..1fa2ae0d9 100644 --- a/vllm_mindspore/model_executor/layers/rotary_embedding.py +++ b/vllm_mindspore/model_executor/layers/rotary_embedding.py @@ -22,6 +22,7 @@ import numpy as np import mindspore from mindspore import Tensor, mint, ops from mindspore.common import dtype as mstype +from mindspore.ops.auto_generate.gen_ops_prim import SliceExt from transformers import PretrainedConfig @@ -460,47 +461,64 @@ class InferMRotaryEmbedding(InferRotaryEmbedding): query: [num_tokens, num_heads * head_size] key: [num_tokens, num_kv_heads * head_size] """ + half_rotary_dim = self.rotary_dim // 2 # prefill if is_prefill: num_tokens = positions.shape[-1] cos, sin = self.freqs_cos[positions], self.freqs_sin[positions] - cos, sin = cos[..., :self.rotary_dim//2], sin[..., :self.rotary_dim//2] + #cos, sin = cos[..., :self.rotary_dim//2], sin[..., :self.rotary_dim//2] + cos = SliceExt()(cos, -1, 0, half_rotary_dim, 1) + sin = SliceExt()(sin, -1, 0, half_rotary_dim, 1) if positions.ndim == 2: - cos_l = ops.split(cos, self.mrope_section, axis=-1) - sin_l = ops.split(sin, self.mrope_section, axis=-1) + cos_l = mint.split(cos, self.mrope_section, dim=-1) + sin_l = mint.split(sin, self.mrope_section, dim=-1) cos, sin = (), () for i in range(len(self.mrope_section)): - cos += (cos_l[i][i],) - sin += (sin_l[i][i],) + 
#cos += (cos_l[i][i],) + #sin += (sin_l[i][i],) + cos_l_select = mint.index_select(cos_l[i], 0, Tensor([i])).squeeze(0) + cos += (cos_l_select,) + sin_l_select = mint.index_select(sin_l[i], 0, Tensor([i])).squeeze(0) + sin += (sin_l_select,) cos = ops.cat(cos, axis=-1) sin = ops.cat(sin, axis=-1) query_shape = query.shape query = query.view(num_tokens, -1, self.head_size) - query_rot = query[..., :self.rotary_dim] - query_pass = query[..., self.rotary_dim:] + #query_rot = query[..., :self.rotary_dim] + #query_pass = query[..., self.rotary_dim:] + query_rot = SliceExt()(query, -1, 0, self.rotary_dim, 1) + query_pass = SliceExt()(query, -1, self.rotary_dim, query_shape[-1], 1) query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style) query = ops.cat((query_rot, query_pass), axis=-1).view(query_shape) key_shape = key.shape key = key.view(num_tokens, -1, self.head_size) - key_rot = key[..., :self.rotary_dim] - key_pass = key[..., self.rotary_dim:] + #key_rot = key[..., :self.rotary_dim] + #key_pass = key[..., self.rotary_dim:] + key_rot = SliceExt()(key, -1, 0, self.rotary_dim, 1) + key_pass = SliceExt()(key, -1, self.rotary_dim, key_shape[-1], 1) key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style) key = ops.cat((key_rot, key_pass), axis=-1).view(key_shape) return query, key # decode - if positions.ndim == 2 and positions.shape[0] == len(self.mrope_section): + if positions.ndim == 2: num_tokens = positions.shape[-1] cos, sin = self.freqs_cos[positions], self.freqs_sin[positions] - cos, sin = cos[..., :self.rotary_dim//2], sin[..., :self.rotary_dim//2] - cos_l = ops.split(cos, self.mrope_section, axis=-1) - sin_l = ops.split(sin, self.mrope_section, axis=-1) + #cos, sin = cos[..., :self.rotary_dim//2], sin[..., :self.rotary_dim//2] + cos = SliceExt()(cos, -1, 0, half_rotary_dim, 1) + sin = SliceExt()(sin, -1, 0, half_rotary_dim, 1) + cos_l = mint.split(cos, self.mrope_section, dim=-1) + sin_l = mint.split(sin, self.mrope_section, dim=-1) cos, sin = (), () for i in range(len(self.mrope_section)): - cos += (cos_l[i][i],) - sin += (sin_l[i][i],) + #cos += (cos_l[i][i],) + #sin += (sin_l[i][i],) + cos_l_select = mint.index_select(cos_l[i], 0, Tensor([i])).squeeze(0) + cos += (cos_l_select,) + sin_l_select = mint.index_select(sin_l[i], 0, Tensor([i])).squeeze(0) + sin += (sin_l_select,) cos = ops.cat(cos, axis=-1) sin = ops.cat(sin, axis=-1) freqs_cos = ops.cat([cos, cos], axis=-1).squeeze(1) diff --git a/vllm_mindspore/model_executor/models/model_base.py b/vllm_mindspore/model_executor/models/model_base.py index 961f54a2d..8d6d6dea5 100644 --- a/vllm_mindspore/model_executor/models/model_base.py +++ b/vllm_mindspore/model_executor/models/model_base.py @@ -162,6 +162,7 @@ class MsModelBase(): inputs_embeds: Optional[Tensor] = None, previous_hidden_states: Optional[Tensor] = None, spec_step_idx: int = 0, + **kwargs, ) -> Union[Tensor, IntermediateTensors]: return self.forward( input_ids, @@ -171,7 +172,8 @@ class MsModelBase(): intermediate_tensors, inputs_embeds, previous_hidden_states=previous_hidden_states, - spec_step_idx=spec_step_idx + spec_step_idx=spec_step_idx, + **kwargs, ) def forward( diff --git a/vllm_mindspore/model_executor/models/qwen2_5_vl.py b/vllm_mindspore/model_executor/models/qwen2_5_vl.py new file mode 100644 index 000000000..908e9a445 --- /dev/null +++ b/vllm_mindspore/model_executor/models/qwen2_5_vl.py @@ -0,0 +1,1077 @@ +#!/usr/bin/env python3 +# encoding: utf-8 +# Copyright 2025 Huawei Technologies Co., Ltd +# Copyright 2024 The vLLM team. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +# credits to zhtmike and SamitHuang from https://github.com/SamitHuang/vllm-mindspore +"""Inference-only Qwen2.5-VL model compatible with HuggingFace weights.""" +from functools import partial +from typing import Callable, Iterable, List, Mapping, Optional, Set, Tuple, Union, Dict, Any + +import numpy as np +import math +import mindspore as ms +import mindspore.nn as nn +import mindspore.mint as mint +import mindspore.ops as ops +import mindspore.mint.nn.functional as F + +from transformers import BatchFeature +from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig + +from vllm_mindspore.model_executor.sampling_metadata import SamplingMetadata +from vllm_mindspore.model_executor.layers.layernorm import RMSNorm +from vllm_mindspore.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear +from vllm_mindspore.model_executor.layers.logits_processor import LogitsProcessor +from vllm_mindspore.model_executor.layers.quantization.base_config import QuantizationConfig +from vllm_mindspore.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm_mindspore.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm_mindspore.model_executor.model_loader.weight_utils import default_weight_loader +from vllm_mindspore.model_executor.models.model_base import MsModelBase, Fake_Attention +from vllm_mindspore.model_executor.models.interfaces import SupportsMultiModal +from vllm_mindspore.model_executor.models.qwen2 import Qwen2Model +from vllm_mindspore.model_executor.models.utils import PPMissingLayer, WeightsMapper, maybe_prefix, merge_multimodal_embeddings +from vllm_mindspore.utils import STR_DTYPE_TO_MS_DTYPE + +from vllm.model_executor.models.module_mapping import MultiModelKeys +from vllm.model_executor.models.qwen2_vl import Qwen2VLDummyInputsBuilder as Qwen2_5_VLDummyInputsBuilder +from vllm.model_executor.models.qwen2_vl import Qwen2VLMultiModalProcessor +from vllm.model_executor.models.qwen2_5_vl import Qwen2_5_VLImageInputs, Qwen2_5_VLVideoInputs, Qwen2_5_VLImagePixelInputs, Qwen2_5_VLImageEmbeddingInputs, Qwen2_5_VLVideoPixelInputs, Qwen2_5_VLVideoEmbeddingInputs, Qwen2_5_VLProcessingInfo +from vllm.attention import AttentionMetadata +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs +from vllm.multimodal.processing import PromptReplacement +from vllm.multimodal.parse import MultiModalDataItems +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank, tensor_model_parallel_all_gather +from vllm.distributed import utils as dist_utils +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.config import uses_mrope + + +logger = 
init_logger(__name__) + + +_ACTIVATION_REGISTRY = {"silu": F.silu} + + +# === Vision Inputs === # + +class _Qwen2VLMultiModalProcessor(Qwen2VLMultiModalProcessor): + def _get_prompt_replacements( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, Any], + out_mm_kwargs: MultiModalKwargs, + ) -> list[PromptReplacement]: + hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + image_processor = self.info.get_image_processor( + **hf_processor_mm_kwargs) + tokenizer = self.info.get_tokenizer() + vocab = tokenizer.get_vocab() + + placeholder = { + "image": vocab[hf_processor.image_token], + "video": vocab[hf_processor.video_token], + } + + merge_length = image_processor.merge_size**2 + + def get_replacement_qwen2vl(item_idx: int, modality: str): + grid_thw = out_mm_kwargs[f"{modality}_grid_thw"][item_idx] + assert isinstance(grid_thw, np.ndarray) + + num_tokens = int(grid_thw.prod()) // merge_length + return [placeholder[modality]] * num_tokens + + return [ + PromptReplacement( + modality=modality, + target=[placeholder[modality]], + replacement=partial(get_replacement_qwen2vl, + modality=modality), + ) for modality in ("image", "video") + ] + +# === Vision Encoder === # + + +class Qwen2_5_VisionMLP(nn.Cell): + + def __init__(self, + in_features: int, + hidden_features: int, + bias: bool = False, + act_fn: Callable[[ms.Tensor], ms.Tensor] = F.silu, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__() + self.gate_proj = ColumnParallelLinear(in_features, + hidden_features, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.gate_proj", + params_dtype=ms.bfloat16) + self.up_proj = ColumnParallelLinear(in_features, + hidden_features, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.up_proj", + params_dtype=ms.bfloat16) + self.down_proj = RowParallelLinear(hidden_features, + in_features, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + params_dtype=ms.bfloat16) + self.act_fn = act_fn + + def construct(self, x: ms.Tensor): + x_gate, _ = self.gate_proj(x) + x_gate = self.act_fn(x_gate) + x_up, _ = self.up_proj(x) + x_down, _ = self.down_proj(x_gate * x_up) + return x_down + + +def apply_rotary_pos_emb_flashatt(q: ms.Tensor, k: ms.Tensor, cos: ms.Tensor, sin: ms.Tensor) -> Tuple[ms.Tensor, ms.Tensor]: + q_embed = ops.rotary_position_embedding(q.float(), cos, sin).type_as(q) + k_embed = ops.rotary_position_embedding(k.float(), cos, sin).type_as(k) + return q_embed, k_embed + + +class Qwen2_5_VisionAttention(nn.Cell): + + def __init__( + self, + embed_dim: int, + num_heads: int, + projection_size: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + # Per attention head and per partition values. 
+ self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + self.hidden_size_per_attention_head = dist_utils.divide(projection_size, num_heads) + self.num_attention_heads_per_partition = dist_utils.divide(num_heads, self.tp_size) + self.num_heads = num_heads + + self.qkv = ColumnParallelLinear(input_size=embed_dim, + output_size=3 * projection_size, + quant_config=quant_config, + prefix=f"{prefix}.qkv", + params_dtype=ms.bfloat16) + self.proj = RowParallelLinear(input_size=projection_size, + output_size=embed_dim, + quant_config=quant_config, + prefix=f"{prefix}.proj", + params_dtype=ms.bfloat16) + + def split_qkv(self, qkv: ms.Tensor) -> tuple[ms.Tensor, ...]: + # [s, 3 * head * head_dim] + seq_len, _ = qkv.shape + if self.tp_size > 1: + qkv = tensor_model_parallel_all_gather(qkv) + + # [s, 3 * head * head_dim] -> 3 * [s, head * head_dim] + q, k, v = mint.chunk(qkv, 3, dim=-1) + + # 3 * [s, head * head_dim] + if self.tp_size > 1: + splitter = partial(dist_utils.split_tensor_along_last_dim, + num_partitions=self.tp_size) + q = splitter(q)[self.tp_rank] + k = splitter(k)[self.tp_rank] + v = splitter(v)[self.tp_rank] + + # 3 * [s, head * head_dim] -> 3 * [s, head, head_dim] + new_shape = (seq_len, self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head) + q, k, v = (x.view(*new_shape) for x in (q, k, v)) + return q, k, v + + def construct( + self, + x: ms.Tensor, + cu_seqlens: ms.Tensor, + position_embeddings: Tuple[ms.Tensor, ms.Tensor], + ) -> ms.Tensor: + seq_length = x.shape[0] + x, _ = self.qkv(x) + q, k, v = self.split_qkv(x) + + cos, sin = position_embeddings + q, k = apply_rotary_pos_emb_flashatt(mint.unsqueeze(q, 0), mint.unsqueeze(k, 0), cos, sin) + + q = mint.squeeze(q, 0) + k = mint.squeeze(k, 0) + + context_layer = ops.flash_attention_score( + q, + k, + v, + self.num_heads // self.tp_size, + actual_seq_qlen=cu_seqlens, + actual_seq_kvlen=cu_seqlens, + scalar_value=1 / math.sqrt(q.shape[-1]), + input_layout="TND", + ).reshape(seq_length, -1) + output, _ = self.proj(context_layer) + return output + + +class Qwen2_5_VisionBlock(nn.Cell): + + def __init__( + self, + dim: int, + num_heads: int, + mlp_hidden_dim: int, + act_fn: Callable[[ms.Tensor], ms.Tensor] = F.silu, + norm_layer: Optional[Callable[[int], nn.Cell]] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + if norm_layer is None: + norm_layer = partial(mint.nn.LayerNorm, eps=1e-6, dtype=ms.bfloat16) + self.norm1 = norm_layer(dim) + self.norm2 = norm_layer(dim) + self.attn = Qwen2_5_VisionAttention(embed_dim=dim, + num_heads=num_heads, + projection_size=dim, + quant_config=quant_config, + prefix=f"{prefix}.attn") + self.mlp = Qwen2_5_VisionMLP(dim, + mlp_hidden_dim, + act_fn=act_fn, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + + def construct(self, x: ms.Tensor, cu_seqlens: ms.Tensor, + position_embeddings: Tuple[ms.Tensor, ms.Tensor]) -> ms.Tensor: + x = x + self.attn(self.norm1(x), + cu_seqlens=cu_seqlens, + position_embeddings=position_embeddings) + x = x + self.mlp(self.norm2(x)) + return x + + +class Qwen2_5_VisionPatchEmbed(nn.Cell): + + def __init__( + self, + patch_size: int = 14, + temporal_patch_size: int = 2, + in_channels: int = 3, + hidden_size: int = 1152, + ) -> None: + super().__init__() + self.patch_size = patch_size + self.temporal_patch_size = temporal_patch_size + self.hidden_size = hidden_size + + kernel_size = (temporal_patch_size, patch_size, 
patch_size) + self.proj = mint.nn.Conv3d(in_channels, + hidden_size, + kernel_size=kernel_size, + stride=kernel_size, + bias=False, + dtype=ms.bfloat16) + + def construct(self, x: ms.Tensor) -> ms.Tensor: + L, C = x.shape + x = x.view(L, -1, self.temporal_patch_size, self.patch_size, + self.patch_size) + x = self.proj(x).view(L, self.hidden_size) + return x + + +class Qwen2_5_VisionPatchMerger(nn.Cell): + + def __init__( + self, + d_model: int, + context_dim: int, + norm_layer: Optional[Callable[[int], nn.Cell]] = None, + spatial_merge_size: int = 2, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = context_dim * (spatial_merge_size**2) + if norm_layer is None: + norm_layer = partial(mint.nn.LayerNorm, eps=1e-6, dtype=ms.bfloat16) + self.ln_q = norm_layer(context_dim) + self.mlp = nn.CellList([ + ColumnParallelLinear(self.hidden_size, + self.hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.mlp.0", + params_dtype=ms.bfloat16), + mint.nn.GELU(), + RowParallelLinear(self.hidden_size, + d_model, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.mlp.2", + params_dtype=ms.bfloat16), + ]) + + def construct(self, x: ms.Tensor) -> ms.Tensor: + x = self.ln_q(x) + x = x.view(-1, self.hidden_size) + + mlp_fc1, mlp_act, mlp_fc2 = self.mlp + x_parallel, _ = mlp_fc1(x) + x_parallel = mlp_act(x_parallel) + out, _ = mlp_fc2(x_parallel) + return out + + +class Qwen2_5_VisionRotaryEmbedding(nn.Cell): + + def __init__(self, dim: int, theta: float = 10000.0) -> None: + super().__init__() + self.dim = dim + self.theta = theta + self.inv_freq = 1.0 / (theta + ** (mint.arange(0, dim, 2, dtype=ms.float32) / dim)) + self._seq_len_cached = 0 + self._freqs_cached = None + + def update_freqs_cache(self, seqlen: int) -> None: + if seqlen > self._seq_len_cached: + seqlen *= 2 + self._seq_len_cached = seqlen + self.inv_freq = 1.0 / (self.theta**(mint.arange + (0, self.dim, 2, dtype=ms.float32) / self.dim)) + seq = mint.arange(seqlen, dtype=self.inv_freq.dtype) + freqs = mint.outer(seq, self.inv_freq) + self._freqs_cached = freqs + + def construct(self, seqlen: int) -> ms.Tensor: + self.update_freqs_cache(seqlen) + return self._freqs_cached[:seqlen] + + +class Qwen2_5_VisionTransformer(nn.Cell): + + def __init__( + self, + vision_config: Qwen2_5_VLVisionConfig, + norm_eps: float = 1e-6, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + + patch_size = vision_config.patch_size + temporal_patch_size = vision_config.temporal_patch_size + in_channels = vision_config.in_channels + depth = vision_config.depth + self.hidden_size = vision_config.hidden_size + self.num_heads = vision_config.num_heads + + # args for get_window_index + self.window_size = vision_config.window_size + self.patch_size = vision_config.patch_size + self.spatial_merge_size = vision_config.spatial_merge_size + self.fullatt_block_indexes = vision_config.fullatt_block_indexes + self.spatial_merge_unit = self.spatial_merge_size**2 + + self.patch_embed = Qwen2_5_VisionPatchEmbed( + patch_size=patch_size, + temporal_patch_size=temporal_patch_size, + in_channels=in_channels, + hidden_size=self.hidden_size, + ) + + norm_layer = partial(RMSNorm, eps=norm_eps, params_dtype=ms.bfloat16) + head_dim = self.hidden_size // self.num_heads + self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2) + + self.blocks = nn.CellList([ + Qwen2_5_VisionBlock( + dim=self.hidden_size, + 
num_heads=self.num_heads, + mlp_hidden_dim=vision_config.intermediate_size, + act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act], + norm_layer=norm_layer, + quant_config=quant_config, + prefix=f"{prefix}.blocks.{layer_idx}") + for layer_idx in range(depth) + ]) + self.merger = Qwen2_5_VisionPatchMerger( + d_model=vision_config.out_hidden_size, + context_dim=self.hidden_size, + norm_layer=norm_layer, + spatial_merge_size=self.spatial_merge_size, + quant_config=quant_config, + prefix=f"{prefix}.merger", + ) + + @property + def dtype(self) -> ms.Type: + return self.patch_embed.proj.weight.dtype + + def rot_pos_emb(self, grid_thw: ms.Tensor) -> ms.Tensor: + pos_ids = [] + for t, h, w in grid_thw: + t, h, w = t.item(), h.item(), w.item() + hpos_ids = mint.arange(h).unsqueeze(1).expand((-1, w)) + wpos_ids = mint.arange(w).unsqueeze(0).expand((h, -1)) + hpos_ids = hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ).permute(0, 2, 1, 3).flatten() + wpos_ids = wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ).permute(0, 2, 1, 3).flatten() + pos_ids.append(mint.tile(mint.stack([hpos_ids, wpos_ids], dim=-1), (t, 1))) + pos_ids = mint.cat(pos_ids, dim=0) + max_grid_size = grid_thw[:, 1:].max().item() + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + return rotary_pos_emb + + def get_window_index(self, grid_thw): + window_index: list = [] + cu_window_seqlens: list = [0] + window_index_id = 0 + vit_merger_window_size = (self.window_size // + self.spatial_merge_size // self.patch_size) + + for grid_t, grid_h, grid_w in grid_thw: + grid_t, grid_h, grid_w = grid_t.item(), grid_h.item(), grid_w.item() + llm_grid_h = grid_h // self.spatial_merge_size + llm_grid_w = grid_w // self.spatial_merge_size + index = mint.arange(grid_t * llm_grid_h * llm_grid_w).reshape( + grid_t, llm_grid_h, llm_grid_w) + pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size + pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size + num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size + num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size + index_padded = F.pad(index, (0, pad_w, 0, pad_h), 'constant', -100) + index_padded = index_padded.reshape(grid_t, num_windows_h, + vit_merger_window_size, + num_windows_w, + vit_merger_window_size) + index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape( + grid_t, num_windows_h * num_windows_w, vit_merger_window_size, + vit_merger_window_size) + seqlens = (index_padded != -100).sum([2, 3]).reshape(-1) + index_padded = index_padded.reshape(-1) + index_new = index_padded[index_padded != -100] + window_index.append(index_new + window_index_id) + cu_seqlens_tmp = mint.cumsum(seqlens, 0) * self.spatial_merge_unit + cu_window_seqlens[-1] + cu_window_seqlens.extend(cu_seqlens_tmp.tolist()) + window_index_id += grid_t * llm_grid_h * llm_grid_w + window_index = mint.cat(window_index, dim=0) + return window_index, cu_window_seqlens + + def construct( + self, + x: ms.Tensor, + grid_thw: ms.Tensor, + ) -> ms.Tensor: + # patchify + hidden_states = x.to(dtype=self.dtype) + hidden_states = self.patch_embed(hidden_states) + + # compute position embedding + rotary_pos_emb = self.rot_pos_emb(grid_thw) + + # windows attention + window_index, cu_window_seqlens = self.get_window_index(grid_thw) + cu_window_seqlens = 
ms.tensor(cu_window_seqlens, dtype=ms.int32) + cu_window_seqlens = mint.unique_consecutive(cu_window_seqlens) + seq_len, _ = hidden_states.shape + hidden_states = hidden_states.reshape( + seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + hidden_states = hidden_states[window_index, :, :] + hidden_states = hidden_states.reshape(seq_len, -1) + rotary_pos_emb = rotary_pos_emb.reshape( + seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + rotary_pos_emb = rotary_pos_emb[window_index, :, :] + rotary_pos_emb = rotary_pos_emb.reshape(1, seq_len, 1, -1) + emb = mint.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + position_embeddings = (mint.cos(emb), mint.sin(emb)) + + # compute cu_seqlens + cu_seqlens = mint.cumsum(mint.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]), dim=0, dtype=ms.int32) + cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0) + + # transformers + for layer_num, blk in enumerate(self.blocks): + if layer_num in self.fullatt_block_indexes: + cu_seqlens_now = cu_seqlens + else: + cu_seqlens_now = cu_window_seqlens + hidden_states = blk(hidden_states, + cu_seqlens=cu_seqlens_now, + position_embeddings=position_embeddings) + + # adapter + hidden_states = self.merger(hidden_states) + reverse_indices = mint.argsort(window_index) + hidden_states = hidden_states[reverse_indices, :] + return hidden_states + + def load_weights(self, weights: Iterable[Tuple[str, + ms.Tensor]], params_dict: Dict[str, ms.Parameter]) -> Set[str]: + loaded_params: Set[str] = set() + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + + for name, loaded_weight in weights: + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class Qwen2_5_VLMultiModalProcessor(_Qwen2VLMultiModalProcessor): + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict( + **super()._get_mm_fields_config(hf_inputs, hf_processor_mm_kwargs), + second_per_grid_ts=MultiModalFieldConfig.batched("video"), + ) + + +@MULTIMODAL_REGISTRY.register_processor( + Qwen2_5_VLMultiModalProcessor, + info=Qwen2_5_VLProcessingInfo, + dummy_inputs=Qwen2_5_VLDummyInputsBuilder) +class Qwen2_5_VLForConditionalGeneration(MsModelBase, SupportsMultiModal): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + # LoRA specific attributes + supported_lora_modules = [ + # language model + "qkv_proj", + "o_proj", + "gate_up_proj", + "down_proj", # Same name with vision encoder + # vision tower + "qkv", + "gate_proj", + "up_proj", + "attn.proj", # Distinguish patch_embed.proj + "fc1", + "fc2", + # projector + "mlp.0", + "mlp.2" + ] + + embedding_modules = {} + embedding_padding_modules = [] + + # To ensure correct weight loading and mapping. 
+ hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={ + "lm_head.": "language_model.lm_head.", + "model.": "language_model.model.", + }) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + config: Qwen2_5_VLConfig = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + + self.config = config + self.multimodal_config = multimodal_config + + self.visual = Qwen2_5_VisionTransformer( + config.vision_config, + norm_eps=getattr(config, "rms_norm_eps", 1e-6), + quant_config=self._maybe_ignore_quant_config(quant_config), + prefix=maybe_prefix(prefix, "visual"), + ) + + self.model = Qwen2Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + + if get_pp_group().is_last_rank: + if config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + params_dtype=ms.bfloat16, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head")) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = get_sampler() + else: + self.lm_head = PPMissingLayer() + + self.make_empty_intermediate_tensors = self.model.make_empty_intermediate_tensors + self.set_modules({"visual": self.visual, "model": self.model, "lm_head": self.lm_head}) + + self.prefill = True + + self.kv_caches = [Fake_Attention() for i in range(config.num_hidden_layers)] + compilation_config = vllm_config.compilation_config + + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + for i in range(config.num_hidden_layers): + compilation_config.static_forward_context[str(i)] = self.kv_caches[i] + + def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig): + # GPTQ configs do not have a list of ignored modules, however AutoGPTQ + # seems to avoid vision encoder sections for some models. + # if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)): + # return None + return quant_config + + def _validate_and_reshape_mm_tensor(self, mm_input: object, + name: str) -> ms.Tensor: + if not isinstance(mm_input, (ms.Tensor, list)): + raise ValueError(f"Incorrect type of {name}. " + f"Got type: {type(mm_input)}") + if isinstance(mm_input, ms.Tensor): + if mm_input.ndim == 2: + return mm_input + if mm_input.ndim != 3: + raise ValueError(f"{name} should be 2D or batched 3D tensor. " + f"Got ndim: {mm_input.ndim} " + f"(shape={mm_input.shape})") + return mint.concat(list(mm_input)) + else: + return mint.concat(mm_input) + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[Qwen2_5_VLImageInputs]: + pixel_values = kwargs.pop("pixel_values", None) + image_embeds = kwargs.pop("image_embeds", None) + image_grid_thw = kwargs.pop("image_grid_thw", None) + + if pixel_values is None and image_embeds is None: + return None + + if pixel_values is not None: + pixel_values = self._validate_and_reshape_mm_tensor( + pixel_values, "image pixel values") + image_grid_thw = self._validate_and_reshape_mm_tensor( + image_grid_thw, "image grid_thw") + + if not isinstance(pixel_values, (ms.Tensor, list)): + raise ValueError("Incorrect type of image pixel values. 
" + f"Got type: {type(pixel_values)}") + + return Qwen2_5_VLImagePixelInputs(type="pixel_values", + pixel_values=pixel_values, + image_grid_thw=image_grid_thw) + + if image_embeds is not None: + image_embeds = self._validate_and_reshape_mm_tensor( + image_embeds, "image embeds") + image_grid_thw = self._validate_and_reshape_mm_tensor( + image_grid_thw, "image grid_thw") + + if not isinstance(image_embeds, ms.Tensor): + raise ValueError("Incorrect type of image embeddings. " + f"Got type: {type(image_embeds)}") + return Qwen2_5_VLImageEmbeddingInputs( + type="image_embeds", + image_embeds=image_embeds, + image_grid_thw=image_grid_thw) + + def _parse_and_validate_video_input( + self, **kwargs: object) -> Optional[Qwen2_5_VLVideoInputs]: + pixel_values_videos = kwargs.pop("pixel_values_videos", None) + video_embeds = kwargs.pop("video_embeds", None) + video_grid_thw = kwargs.pop("video_grid_thw", None) + second_per_grid_ts = kwargs.pop("second_per_grid_ts", None) + + if pixel_values_videos is None and video_embeds is None: + return None + + if pixel_values_videos is not None: + pixel_values_videos = self._validate_and_reshape_mm_tensor( + pixel_values_videos, "video pixel values") + video_grid_thw = self._validate_and_reshape_mm_tensor( + video_grid_thw, "video grid_thw") + + return Qwen2_5_VLVideoPixelInputs( + type="pixel_values_videos", + pixel_values_videos=pixel_values_videos, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + ) + + if video_embeds is not None: + video_embeds = self._validate_and_reshape_mm_tensor( + video_embeds, "video embeds") + video_grid_thw = self._validate_and_reshape_mm_tensor( + video_grid_thw, "video grid_thw") + + if not isinstance(video_embeds, ms.Tensor): + raise ValueError("Incorrect type of video embeddings. " + f"Got type: {type(video_embeds)}") + return Qwen2_5_VLVideoEmbeddingInputs( + type="video_embeds", + video_embeds=video_embeds, + video_grid_thw=video_grid_thw) + + def _process_image_input( + self, + image_input: Qwen2_5_VLImageInputs) -> tuple[ms.Tensor, ...]: + + grid_thw = image_input["image_grid_thw"] + assert grid_thw.ndim == 2 + + if image_input["type"] == "image_embeds": + image_embeds = image_input["image_embeds"].type(self.visual.dtype) + else: + pixel_values = image_input["pixel_values"].type(self.visual.dtype) + image_embeds = self.visual(pixel_values, grid_thw=grid_thw) + + # Split concatenated embeddings for each image item. + merge_size = self.visual.spatial_merge_size + sizes = grid_thw.prod(-1) // merge_size // merge_size + + return image_embeds.split(sizes.tolist()) + + def _process_video_input( + self, + video_input: Qwen2_5_VLVideoInputs) -> tuple[ms.Tensor, ...]: + + grid_thw = video_input["video_grid_thw"] + assert grid_thw.ndim == 2 + + if video_input["type"] == "video_embeds": + video_embeds = video_input["video_embeds"].type(self.visual.dtype) + else: + pixel_values_videos = video_input["pixel_values_videos"].type( + self.visual.dtype) + video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw) + + # Split concatenated embeddings for each video item. + merge_size = self.visual.spatial_merge_size + sizes = grid_thw.prod(-1) // merge_size // merge_size + + return video_embeds.split(sizes.tolist()) + + def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: + modalities = {} + + # Preserve the order of modalities if there are multiple of them + # from the order of kwargs. 
+ for input_key in kwargs: + if input_key in ("pixel_values", + "image_embeds") and "images" not in modalities: + modalities["images"] = self._parse_and_validate_image_input( + **kwargs) + if input_key in ("pixel_values_videos", + "video_embeds") and "videos" not in modalities: + modalities["videos"] = self._parse_and_validate_video_input( + **kwargs) + return modalities + + def get_multimodal_embeddings( + self, **kwargs) -> Optional[tuple[ms.Tensor, ...]]: + + modalities = self._parse_and_validate_multimodal_inputs(**kwargs) + if not modalities: + return None + + # The result multimodal_embeddings is tuple of tensors, with each + # tensor correspoending to a multimodal data item (image or video). + multimodal_embeddings: tuple[ms.Tensor, ...] = () + + # NOTE: It is important to iterate over the keys in this dictionary + # to preserve the order of the modalities. + for modality in modalities: + if modality == "images": + image_input = modalities["images"] + vision_embeddings = self._process_image_input(image_input) + multimodal_embeddings += vision_embeddings + if modality == "videos": + video_input = modalities["videos"] + video_embeddings = self._process_video_input(video_input) + multimodal_embeddings += video_embeddings + return multimodal_embeddings + + def get_input_embeddings( + self, + input_ids: ms.Tensor, + multimodal_embeddings: Optional[tuple[ms.Tensor, ...]] = None, + ) -> ms.Tensor: + inputs_embeds = self.model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None: + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, multimodal_embeddings, + [self.config.image_token_id, self.config.video_token_id]) + return inputs_embeds + + def get_input_embeddings_v0( + self, + input_ids: ms.Tensor, + image_input: Optional[tuple[ms.Tensor, ...]] = None, + video_input: Optional[tuple[ms.Tensor, ...]] = None, + ) -> ms.Tensor: + + inputs_embeds = self.get_input_embeddings(input_ids) + if image_input is not None: + image_embeds = self._process_image_input(image_input) + inputs_embeds = merge_multimodal_embeddings( + input_ids, + inputs_embeds, + image_embeds, + placeholder_token_id=self.config.image_token_id, + ) + + if video_input is not None: + video_embeds = self._process_video_input(video_input) + inputs_embeds = merge_multimodal_embeddings( + input_ids, + inputs_embeds, + video_embeds, + placeholder_token_id=self.config.video_token_id, + ) + return inputs_embeds + + def set_model_inputs(self, input_ids: Optional[ms.Tensor] = None, position_ids: Optional[ms.Tensor] = None, + inputs_embeds: Optional[ms.Tensor] = None, is_prefill: bool = False,): + if input_ids is None: + dyn_input_ids = None + else: + dyn_input_ids = ms.Tensor(shape=[None] * input_ids.ndim, dtype=input_ids.dtype) + + if position_ids is None: + dyn_position_ids = None + else: + dyn_position_ids = ms.Tensor(shape=[None] * position_ids.ndim, dtype=position_ids.dtype) + + if inputs_embeds is None: + dyn_inputs_embeds = None + else: + dyn_inputs_embeds = ms.Tensor(shape=[None] * inputs_embeds.ndim, dtype=inputs_embeds.dtype) + + block_size = self.cache_config.block_size + num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config) + head_size = self.model_config.get_head_size() + kv_cache_shape = (None, block_size, num_kv_heads, head_size) + + kv_cache_dtype = self.model_config.dtype if self.cache_config.cache_dtype == "auto" \ + else self.cache_config.cache_dtype + kv_cache_dtype = STR_DTYPE_TO_MS_DTYPE[kv_cache_dtype] + + num_layers = 
self.model_config.get_num_layers(self.parallel_config) + + dyn_key_cache = ms.mutable(ms.Tensor(shape=kv_cache_shape, dtype=kv_cache_dtype)) + dyn_value_cache = ms.mutable(ms.Tensor(shape=kv_cache_shape, dtype=kv_cache_dtype)) + dyn_key_caches = ms.mutable([dyn_key_cache for _ in range(num_layers)]) + dyn_value_caches = ms.mutable([dyn_value_cache for _ in range(num_layers)]) + + dyn_is_prefill = ms.mutable(is_prefill) + + dyn_batch_valid_length = ms.Tensor(shape=[None, ], dtype=ms.int32) + dyn_q_seq_lens = ms.Tensor(shape=[None, ], dtype=ms.int32) + dyn_slot_mapping = ms.Tensor(shape=[None, ], dtype=ms.int32) + dyn_block_tables = ms.Tensor(shape=[None, None], dtype=ms.int32) + dyn_intermediate_tensors = None + + self.model.set_inputs( + dyn_input_ids, + dyn_position_ids, + dyn_key_caches, + dyn_value_caches, + dyn_is_prefill, + dyn_slot_mapping, + dyn_batch_valid_length, + dyn_q_seq_lens, + dyn_block_tables, + dyn_intermediate_tensors, + dyn_inputs_embeds + ) + + self.lm_head.set_inputs( + ms.Tensor(shape=[None], dtype=ms.int64) + ) + + def forward( + self, + input_ids: ms.Tensor, + positions: ms.Tensor, + kv_caches: List[Tuple[ms.Tensor, ms.Tensor]], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[ms.Tensor] = None, + **kwargs: object, + ) -> Union[ms.Tensor, IntermediateTensors]: + key_cache, value_cache = self.get_kvcache() + + seq_lens = attn_metadata.seq_lens + max_query_len = attn_metadata.max_query_len + # When Mutli-Step is enabled with Chunked-Prefill, prefills and + # decodes are scheduled together. In the first step, all the + # prefills turn into decodes and max_query_len will be 1. + if self.is_multi_step_chunked_prefill and max_query_len == 1: + query_lens = [1] * len(seq_lens) + else: + query_lens = attn_metadata.query_lens + + seq_lens_np = np.array(seq_lens, dtype=np.int32) + query_lens_np = np.array(query_lens, dtype=np.int32) + kv_cache_lens = seq_lens_np - query_lens_np + is_prefill = bool(attn_metadata.num_decode_tokens == 0 and kv_cache_lens.max() == 0) + + if intermediate_tensors is not None: + inputs_embeds = None + + # NOTE: In v1, inputs_embeds is always generated at model runner from + # `get_multimodal_embeddings` and `get_input_embeddings`, this + # condition is only for v0 compatibility. 
+ elif inputs_embeds is None: + image_input = self._parse_and_validate_image_input(**kwargs) + video_input = self._parse_and_validate_video_input(**kwargs) + + if image_input is None and video_input is None: + inputs_embeds = None + else: + if uses_mrope(self.config): + assert positions.ndim == 2 and positions.shape[0] == 3, ( + "multimodal section rotary embedding requires " + f"(3, seq_len) positions, but got {positions.shape}") + inputs_embeds = self.get_input_embeddings_v0( + input_ids, + image_input=image_input, + video_input=video_input) + input_ids = None + + # a patch to avoid compiling different positions format between warm-up and real-run + if positions.ndim == 1: + positions = mint.tile(positions.expand_dims(0), (3, 1)) + + if is_prefill: + if input_ids is not None: + input_ids = input_ids.expand_dims(0) + else: + inputs_embeds = inputs_embeds.expand_dims(0) + if not self.prefill: + self.prefill = True + self.set_model_inputs(input_ids, positions, inputs_embeds, self.prefill) + else: + if input_ids is not None: + input_ids = input_ids.expand_dims(1) + else: + inputs_embeds = inputs_embeds.expand_dims(1) + if self.prefill: + self.prefill = False + self.set_model_inputs(input_ids, positions, inputs_embeds, self.prefill) + + slot_mapping = attn_metadata.slot_mapping + batch_valid_length = ms.Tensor.from_numpy(np.array(attn_metadata.seq_lens, dtype=np.int32)) + q_seq_lens = ms.Tensor.from_numpy(np.array(attn_metadata.query_lens, dtype=np.int32)) + block_tables = attn_metadata.block_tables + + model_output = self.model( + input_ids, + positions, + key_cache, + value_cache, + ms.mutable(is_prefill), + slot_mapping, + batch_valid_length, + q_seq_lens, + block_tables, + intermediate_tensors, + inputs_embeds, + ) + + if is_prefill: + model_output = model_output.squeeze(0) + else: + model_output = model_output.squeeze(1) + + return model_output + + def compute_logits( + self, + hidden_states: ms.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[ms.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) + return logits + + def sample( + self, logits: ms.Tensor, sampling_metadata: SamplingMetadata + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, + ms.Tensor]]) -> Set[str]: + params_dict = self.get_params_dict() + for name, weight in weights: + if "visual." 
+                self.visual.load_weights([(name, weight)], params_dict)
+            else:
+                self.model.load_weights([(name, weight)], params_dict)
+
+    def get_mm_mapping(self) -> MultiModelKeys:
+        """
+        Get the module prefix in multimodal models
+        """
+        return MultiModelKeys.from_string_field(
+            language_model="language_model",
+            connector="visual.",
+            tower_model="visual.merger.")
diff --git a/vllm_mindspore/model_executor/models/registry.py b/vllm_mindspore/model_executor/models/registry.py
index c9790915c..7b72ca846 100644
--- a/vllm_mindspore/model_executor/models/registry.py
+++ b/vllm_mindspore/model_executor/models/registry.py
@@ -32,6 +32,7 @@ from vllm_mindspore.utils import is_mindformers_model_backend, is_mindway_model_
 _MINDSPORE_MODELS = {
     "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
     "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
+    "Qwen2_5_VLForConditionalGeneration": ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"),
 }
 
 _MINDFORMERS_MODELS = {
diff --git a/vllm_mindspore/worker/worker.py b/vllm_mindspore/worker/worker.py
index 8ce1bc91d..bd5137b42 100644
--- a/vllm_mindspore/worker/worker.py
+++ b/vllm_mindspore/worker/worker.py
@@ -56,13 +56,14 @@ def _prepare_input_for_warmup(model_config, model_runner, cache_engine, is_prefi
             sampling_params=SamplingParams(),
             block_tables={idx: block_tables},
             lora_request=None,
-            multi_modal_data=None,
-            multi_modal_placeholders=None,
+            multi_modal_data=dummy_data.multi_modal_data,
+            multi_modal_placeholders=dummy_data.multi_modal_placeholders,
         )
         for idx in range(bs)
     ]
 
-    model_input = model_runner.prepare_model_input(seqs)
+    finished_requests_ids = [seq.request_id for seq in seqs]
+    model_input = model_runner.prepare_model_input(seqs, finished_requests_ids=finished_requests_ids)
     block_tables = model_input.attn_metadata.block_tables
     if block_tables is not None and block_tables.numel() <= 0:
         model_input.attn_metadata.block_tables = torch.zeros((1, 1), dtype=torch.int32)
@@ -76,7 +77,8 @@ def _warm_up_model(self) -> None:
     # cache_engine is a list with length equal to the size of pipeline-parallel, and only pp=1 is supported.
     kv_cache = self.cache_engine[0].gpu_cache
     is_mtp_model = self.speculative_config is not None and self.model_config.hf_config.model_type == "deepseek_mtp"
-    if is_mtp_model:
+    max_mm_tokens = self.model_runner.mm_registry.get_max_multimodal_tokens(self.model_config)
+    if is_mtp_model or max_mm_tokens > 0:
         # prefill mtp model
         model_input, previous_hidden_states = _prepare_input_for_warmup(self.model_config, self.model_runner,
                                                                         self.cache_engine[0], True, is_mtp_model)
--
Gitee

From 5d88541e7f5c4aff6f91c27c4fde3584b8c81f16 Mon Sep 17 00:00:00 2001
From: twc
Date: Thu, 15 May 2025 11:20:47 +0800
Subject: [PATCH 3/4] opt qwen2.5 vl performance

---
 .../model_executor/models/qwen2_5_vl.py | 77 ++++++++++++-------
 1 file changed, 49 insertions(+), 28 deletions(-)

diff --git a/vllm_mindspore/model_executor/models/qwen2_5_vl.py b/vllm_mindspore/model_executor/models/qwen2_5_vl.py
index 908e9a445..6033a5932 100644
--- a/vllm_mindspore/model_executor/models/qwen2_5_vl.py
+++ b/vllm_mindspore/model_executor/models/qwen2_5_vl.py
@@ -15,8 +15,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
-# credits to zhtmike and SamitHuang from https://github.com/SamitHuang/vllm-mindspore
 """Inference-only Qwen2.5-VL model compatible with HuggingFace weights."""
+import time
 from functools import partial
 from typing import Callable, Iterable, List, Mapping, Optional, Set, Tuple, Union, Dict, Any
@@ -43,6 +43,7 @@ from vllm_mindspore.model_executor.models.model_base import MsModelBase, Fake_At
 from vllm_mindspore.model_executor.models.interfaces import SupportsMultiModal
 from vllm_mindspore.model_executor.models.qwen2 import Qwen2Model
 from vllm_mindspore.model_executor.models.utils import PPMissingLayer, WeightsMapper, maybe_prefix, merge_multimodal_embeddings
+from vllm_mindspore.model_executor.models.attention_mask import LowerTriangularMask
 from vllm_mindspore.utils import STR_DTYPE_TO_MS_DTYPE
 
 from vllm.model_executor.models.module_mapping import MultiModelKeys
@@ -181,7 +182,7 @@ class Qwen2_5_VisionAttention(nn.Cell):
             quant_config=quant_config,
             prefix=f"{prefix}.proj",
             params_dtype=ms.bfloat16)
-
+
     def split_qkv(self, qkv: ms.Tensor) -> tuple[ms.Tensor, ...]:
         # [s, 3 * head * head_dim]
         seq_len, _ = qkv.shape
@@ -205,7 +206,7 @@ class Qwen2_5_VisionAttention(nn.Cell):
         q, k, v = (x.view(*new_shape) for x in (q, k, v))
         return q, k, v
 
-    def construct(
+    def construct(
         self,
         x: ms.Tensor,
         cu_seqlens: ms.Tensor,
@@ -427,7 +428,8 @@ class Qwen2_5_VisionTransformer(nn.Cell):
             quant_config=quant_config,
             prefix=f"{prefix}.merger",
         )
-
+        from mindspore.communication.management import get_rank
+        self.rank_id = get_rank()
     @property
     def dtype(self) -> ms.Type:
         return self.patch_embed.proj.weight.dtype
@@ -645,7 +647,7 @@ class Qwen2_5_VLForConditionalGeneration(MsModelBase, SupportsMultiModal):
         self.model = Qwen2Model(vllm_config=vllm_config,
                                 prefix=maybe_prefix(prefix, "model"))
-
+
         if get_pp_group().is_last_rank:
             if config.tie_word_embeddings:
                 self.lm_head = self.model.embed_tokens
@@ -665,6 +667,8 @@ class Qwen2_5_VLForConditionalGeneration(MsModelBase, SupportsMultiModal):
 
         self.prefill = True
 
+        self.mstype = STR_DTYPE_TO_MS_DTYPE.get(self.model_config.dtype, self.model_config.dtype)
+
         self.kv_caches = [Fake_Attention() for i in range(config.num_hidden_layers)]
 
         compilation_config = vllm_config.compilation_config
@@ -672,7 +676,11 @@ class Qwen2_5_VLForConditionalGeneration(MsModelBase, SupportsMultiModal):
             raise ValueError(f"Duplicate layer name: {prefix}")
         for i in range(config.num_hidden_layers):
             compilation_config.static_forward_context[str(i)] = self.kv_caches[i]
+        self.casual_mask = LowerTriangularMask(dtype=self.mstype,
+                                               max_model_len=self.model_config.max_model_len)
+        from mindspore.communication import get_group_size, get_rank
+        self.rank_id = get_rank()
 
     def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
         # GPTQ configs do not have a list of ignored modules, however AutoGPTQ
         # seems to avoid vision encoder sections for some models.
@@ -887,14 +895,14 @@ class Qwen2_5_VLForConditionalGeneration(MsModelBase, SupportsMultiModal):
             placeholder_token_id=self.config.video_token_id,
         )
         return inputs_embeds
-
-    def set_model_inputs(self, input_ids: Optional[ms.Tensor] = None, position_ids: Optional[ms.Tensor] = None,
+
+    def set_model_inputs(self, input_ids: Optional[ms.Tensor] = None, position_ids: Optional[ms.Tensor] = None,
                          inputs_embeds: Optional[ms.Tensor] = None, is_prefill: bool = False,):
         if input_ids is None:
             dyn_input_ids = None
         else:
             dyn_input_ids = ms.Tensor(shape=[None] * input_ids.ndim, dtype=input_ids.dtype)
-
+
         if position_ids is None:
             dyn_position_ids = None
         else:
@@ -921,8 +929,7 @@ class Qwen2_5_VLForConditionalGeneration(MsModelBase, SupportsMultiModal):
         dyn_key_caches = ms.mutable([dyn_key_cache for _ in range(num_layers)])
         dyn_value_caches = ms.mutable([dyn_value_cache for _ in range(num_layers)])
 
-        dyn_is_prefill = ms.mutable(is_prefill)
-
+        dynamic_attention_mask = ms.Tensor(shape=[None, None], dtype=self.mstype)
         dyn_batch_valid_length = ms.Tensor(shape=[None, ], dtype=ms.int32)
         dyn_q_seq_lens = ms.Tensor(shape=[None, ], dtype=ms.int32)
         dyn_slot_mapping = ms.Tensor(shape=[None, ], dtype=ms.int32)
@@ -934,8 +941,9 @@ class Qwen2_5_VLForConditionalGeneration(MsModelBase, SupportsMultiModal):
             dyn_position_ids,
             dyn_key_caches,
             dyn_value_caches,
-            dyn_is_prefill,
+            is_prefill,
             dyn_slot_mapping,
+            dynamic_attention_mask,
             dyn_batch_valid_length,
             dyn_q_seq_lens,
             dyn_block_tables,
@@ -958,7 +966,7 @@ class Qwen2_5_VLForConditionalGeneration(MsModelBase, SupportsMultiModal):
         **kwargs: object,
     ) -> Union[ms.Tensor, IntermediateTensors]:
         key_cache, value_cache = self.get_kvcache()
-
+
         seq_lens = attn_metadata.seq_lens
         max_query_len = attn_metadata.max_query_len
         # When Multi-Step is enabled with Chunked-Prefill, prefills and
@@ -991,11 +999,17 @@ class Qwen2_5_VLForConditionalGeneration(MsModelBase, SupportsMultiModal):
                     assert positions.ndim == 2 and positions.shape[0] == 3, (
                         "multimodal section rotary embedding requires "
                         f"(3, seq_len) positions, but got {positions.shape}")
+                # ms.runtime.synchronize()
+                # start = time.time()
                 inputs_embeds = self.get_input_embeddings_v0(
                     input_ids,
                     image_input=image_input,
                     video_input=video_input)
                 input_ids = None
+                # ms.runtime.synchronize()
+                # end = time.time()
+                # if self.rank_id == 0:
+                #     print("get_input_embeddings_v0--use--",end-start,flush=True)
 
         # a patch to avoid compiling different positions format between warm-up and real-run
         if positions.ndim == 1:
@@ -1016,27 +1030,33 @@ class Qwen2_5_VLForConditionalGeneration(MsModelBase, SupportsMultiModal):
                 inputs_embeds = inputs_embeds.expand_dims(1)
             if self.prefill:
                 self.prefill = False
-                self.set_model_inputs(input_ids, positions, inputs_embeds, self.prefill)
+                self.set_model_inputs(input_ids, positions, inputs_embeds, self.prefill)
 
         slot_mapping = attn_metadata.slot_mapping
         batch_valid_length = ms.Tensor.from_numpy(np.array(attn_metadata.seq_lens, dtype=np.int32))
         q_seq_lens = ms.Tensor.from_numpy(np.array(attn_metadata.query_lens, dtype=np.int32))
         block_tables = attn_metadata.block_tables
-
-        model_output = self.model(
-            input_ids,
-            positions,
-            key_cache,
-            value_cache,
-            ms.mutable(is_prefill),
-            slot_mapping,
-            batch_valid_length,
-            q_seq_lens,
-            block_tables,
-            intermediate_tensors,
-            inputs_embeds,
-        )
-
+        attn_mask = self.casual_mask.gen_attention_mask(is_prefill, positions, query_lens)
+
+        # ms.runtime.synchronize()
+        # start = time.time()
+        model_output = self.model(input_ids,
+                                  positions,
+                                  key_cache,
+                                  value_cache,
+                                  is_prefill,
+                                  slot_mapping,
+                                  attn_mask,
+                                  batch_valid_length,
+                                  q_seq_lens,
+                                  block_tables,
+                                  intermediate_tensors,
+                                  inputs_embeds)
+        # ms.runtime.synchronize()
+        # end = time.time()
+
+        # if self.rank_id == 0:
+        #     print("is_prefill:{} model---use:{}----".format(is_prefill, end-start), flush=True)
         if is_prefill:
             model_output = model_output.squeeze(0)
         else:
@@ -1066,6 +1086,7 @@ class Qwen2_5_VLForConditionalGeneration(MsModelBase, SupportsMultiModal):
                 self.visual.load_weights([(name, weight)], params_dict)
             else:
                 self.model.load_weights([(name, weight)], params_dict)
+                # pass
 
     def get_mm_mapping(self) -> MultiModelKeys:
         """
--
Gitee

From 628b3ba3fcf97eac35ca9e87a4f35131dc2e7aba Mon Sep 17 00:00:00 2001
From: twc
Date: Tue, 3 Jun 2025 14:33:05 +0800
Subject: [PATCH 4/4] add print

---
 .../model_executor/models/qwen2_5_vl.py | 16 ++++++++++++++--
 vllm_mindspore/worker/worker.py         |  3 +++
 2 files changed, 17 insertions(+), 2 deletions(-)

diff --git a/vllm_mindspore/model_executor/models/qwen2_5_vl.py b/vllm_mindspore/model_executor/models/qwen2_5_vl.py
index 6033a5932..60d3a9ce1 100644
--- a/vllm_mindspore/model_executor/models/qwen2_5_vl.py
+++ b/vllm_mindspore/model_executor/models/qwen2_5_vl.py
@@ -68,7 +68,6 @@ logger = init_logger(__name__)
 
 _ACTIVATION_REGISTRY = {"silu": F.silu}
 
-
 # === Vision Inputs === #
 
 class _Qwen2VLMultiModalProcessor(Qwen2VLMultiModalProcessor):
@@ -221,6 +220,13 @@ class Qwen2_5_VisionAttention(nn.Cell):
         q = mint.squeeze(q, 0)
         k = mint.squeeze(k, 0)
 
+        # print("seq_length------",seq_length,flush=True)
+        # print("q.shape------",q.shape,flush=True)
+        # print("k.shape------",k.shape,flush=True)
+        # print("v.shape------",v.shape,flush=True)
+        # print("self.num_heads // self.tp_size------",self.num_heads // self.tp_size,flush=True)
+        # print("cu_seqlens------",cu_seqlens,flush=True)
+        # print("1 / math.sqrt(q.shape[-1])------",1 / math.sqrt(q.shape[-1]),flush=True)
         context_layer = ops.flash_attention_score(
             q,
@@ -901,6 +907,7 @@ class Qwen2_5_VLForConditionalGeneration(MsModelBase, SupportsMultiModal):
         if input_ids is None:
             dyn_input_ids = None
         else:
+            # dyn_input_ids = ms.Tensor(shape=[None] * input_ids.ndim, dtype=input_ids.dtype)
             dyn_input_ids = ms.Tensor(shape=[None] * input_ids.ndim, dtype=input_ids.dtype)
 
         if position_ids is None:
@@ -1037,7 +1044,12 @@ class Qwen2_5_VLForConditionalGeneration(MsModelBase, SupportsMultiModal):
         q_seq_lens = ms.Tensor.from_numpy(np.array(attn_metadata.query_lens, dtype=np.int32))
         block_tables = attn_metadata.block_tables
         attn_mask = self.casual_mask.gen_attention_mask(is_prefill, positions, query_lens)
-
+        # print("positions---2222222--", positions.shape, flush=True)
+        # print("is_prefill---222222--", is_prefill, flush=True)
+        # print("slot_mapping---222222--", slot_mapping, flush=True)
+        # print("attn_mask--222222---", attn_mask, flush=True)
+        # print("batch_valid_length--222222---", batch_valid_length, flush=True)
+        # print("q_seq_lens---222222--", q_seq_lens, flush=True)
         # ms.runtime.synchronize()
         # start = time.time()
         model_output = self.model(input_ids,
diff --git a/vllm_mindspore/worker/worker.py b/vllm_mindspore/worker/worker.py
index bd5137b42..1195de6ed 100644
--- a/vllm_mindspore/worker/worker.py
+++ b/vllm_mindspore/worker/worker.py
@@ -48,6 +48,8 @@ def _prepare_input_for_warmup(model_config, model_runner, cache_engine, is_prefi
     seq_len = model_runner.scheduler_config.max_num_batched_tokens if is_prefill else 1
     dummy_data = model_runner.input_registry.dummy_data_for_profiling(model_config, seq_len,
                                                                     model_runner.mm_registry)
     block_tables = [i for i in range(math.ceil(seq_len / cache_engine.block_size))]
+    # print("seq_len----",seq_len,flush=True)
+    # print("block_tables----",block_tables,flush=True)
     seqs = [
         SequenceGroupMetadata(
             request_id=str(idx),
@@ -61,6 +63,7 @@ def _prepare_input_for_warmup(model_config, model_runner, cache_engine, is_prefi
         )
         for idx in range(bs)
     ]
 
+    # print("seqs----",seqs,flush=True)
     finished_requests_ids = [seq.request_id for seq in seqs]
     model_input = model_runner.prepare_model_input(seqs, finished_requests_ids=finished_requests_ids)
--
Gitee
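
Reviewer sketch: patch 3 builds the mask once per step via self.casual_mask.gen_attention_mask(is_prefill, positions, query_lens) and feeds it through the dynamic attn_mask input. The snippet below is only a minimal, standalone NumPy approximation of what such a mask could look like (lower-triangular during prefill, one row per decoded position otherwise); the helper names, the NumPy dtypes and the 0/1 encoding are illustrative assumptions, not vllm-mindspore APIs.

# Illustrative sketch only; names and dtypes are assumptions, not vllm-mindspore code.
import numpy as np

def make_prefill_mask(seq_len: int) -> np.ndarray:
    # 1.0 marks future positions that must be masked out for causal attention.
    return np.triu(np.ones((seq_len, seq_len), dtype=np.float16), k=1)

def make_decode_mask(positions: np.ndarray, max_model_len: int) -> np.ndarray:
    # One row per decoded token: mask every cache slot beyond the token's position.
    slots = np.arange(max_model_len, dtype=np.int32)[None, :]
    return (slots > positions[:, None]).astype(np.float16)

if __name__ == "__main__":
    print(make_prefill_mask(4))                # 4x4, ones strictly above the diagonal
    print(make_decode_mask(np.array([4]), 8))  # 1x8 row masking slots 5..7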