diff --git a/vllm_mindspore/attention/layer.py b/vllm_mindspore/attention/layer.py
index 40f4328d91895659b0971de5d765b52c5fc37257..1490d3cf1189d1ea1a7d2e85012a1de8a9495c5f 100644
--- a/vllm_mindspore/attention/layer.py
+++ b/vllm_mindspore/attention/layer.py
@@ -28,6 +28,8 @@ from vllm.config import CacheConfig
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 
+from vllm_mindspore.model_executor.utils import get_model_context
+
 
 class Attention(nn.Cell):
     """Attention layer.
@@ -84,10 +86,9 @@ class Attention(nn.Cell):
                                  kv_head_num=num_kv_heads)
 
     def construct(self, query: Tensor, key: Tensor, value: Tensor,
-                  key_cache: Tensor, value_cache: Tensor, is_prefill: bool,
-                  slot_mapping: Tensor, attn_mask: Tensor,
-                  batch_valid_length: Tensor, q_seq_lens: Tensor,
-                  block_tables: Tensor) -> Tensor:
+                  key_cache: Tensor, value_cache: Tensor, slot_mapping: Tensor,
+                  attn_mask: Tensor, batch_valid_length: Tensor,
+                  q_seq_lens: Tensor, block_tables: Tensor) -> Tensor:
         """Attention forward, support MHA and GQA.
 
         Args:
@@ -106,7 +107,7 @@ class Attention(nn.Cell):
         cache_out = self.reshape_and_cache(key, value, key_cache, value_cache,
                                            slot_mapping)
         query = ops.depend(query, cache_out)
-        if is_prefill:
+        if get_model_context("is_prefill"):
             output = self._run_prefill_forward(query, key, value, attn_mask,
                                                batch_valid_length,
                                                batch_valid_length)
diff --git a/vllm_mindspore/model_executor/layers/quantization/attention.py b/vllm_mindspore/model_executor/layers/quantization/attention.py
index 4aeebf4fdfb0bb96c98fad2dd92cc05173d5ac20..9b6659143c1f36d118aac8984985fd7a0d3c8893 100644
--- a/vllm_mindspore/model_executor/layers/quantization/attention.py
+++ b/vllm_mindspore/model_executor/layers/quantization/attention.py
@@ -34,7 +34,8 @@ from vllm.config import CacheConfig
 
 from vllm_mindspore.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
-from vllm_mindspore.model_executor.utils import set_weight_attrs
+from vllm_mindspore.model_executor.utils import (get_model_context,
+                                                 set_weight_attrs)
 from vllm_mindspore.utils import is_310p
 
 
@@ -134,7 +135,6 @@ class BaseKVCacheMethod(QuantizeMethodBase):
               value: Tensor,
               key_cache: Tensor,
               value_cache: Tensor,
-              is_prefill: bool,
               slot_mapping: Tensor,
               attn_mask: Tensor,
               batch_valid_length: Tensor,
@@ -147,6 +147,7 @@ class BaseKVCacheMethod(QuantizeMethodBase):
         cache_out = self.reshape_and_cache(key, value, key_cache, value_cache,
                                            slot_mapping)
         query = ops.depend(query, cache_out)
+        is_prefill = get_model_context("is_prefill")
         if self.use_fused_attn:
             if not is_prefill:
                 key = self.reshape(
@@ -351,7 +352,6 @@ class KVCacheInt8Method(BaseKVCacheMethod):
               value: Tensor,
               key_cache: Tensor,
               value_cache: Tensor,
-              is_prefill: bool,
               slot_mapping: Tensor,
               attn_mask: Tensor,
               batch_valid_length: Tensor,
@@ -370,7 +370,7 @@ class KVCacheInt8Method(BaseKVCacheMethod):
         cache_out = self.reshape_and_cache(quant_key, quant_value, key_cache,
                                            value_cache, slot_mapping)
         query = ops.depend(query, cache_out)
-        if is_prefill:
+        if get_model_context("is_prefill"):
             output = self._run_prefill_forward(query, key, value, attn_mask,
                                                batch_valid_length,
                                                batch_valid_length)
diff --git a/vllm_mindspore/model_executor/layers/rotary_embedding.py b/vllm_mindspore/model_executor/layers/rotary_embedding.py
index e13642a53bb6f520ebb1a06c99fe291e8b6698be..afe4a0d2381ee49b780d9781bc0bcdc5b574849a 100644
--- a/vllm_mindspore/model_executor/layers/rotary_embedding.py
+++ b/vllm_mindspore/model_executor/layers/rotary_embedding.py
@@ -35,6 +35,8 @@ from mindspore.ops.auto_generate.gen_ops_prim import SliceExt
 from transformers import PretrainedConfig
 from vllm.config import get_current_vllm_config
 
+from vllm_mindspore.model_executor.utils import get_model_context
+
 
 def _apply_rotary_emb(
     x: Tensor,
@@ -195,12 +197,11 @@ class InferRotaryEmbedding(nn.Cell):
         query: Tensor,
         key: Tensor,
         batch_valid_length: Tensor,
-        is_prefill: bool,
         offsets: Optional[Tensor] = None,
     ) -> tuple[Tensor, Tensor]:
         query = query.contiguous()
         key = key.contiguous()
-        if is_prefill:
+        if get_model_context("is_prefill"):
             return self.rotary_embedding_op(query, key, self.freqs_cos,
                                             self.freqs_sin,
                                             batch_valid_length)
@@ -286,7 +287,6 @@ class MRotaryEmbedding(RotaryEmbedding):
         query: mindspore.Tensor,
         key: mindspore.Tensor,
         batch_valid_length: Tensor = None,
-        is_prefill: bool = False,
     ) -> tuple[mindspore.Tensor, mindspore.Tensor]:
         """
         Args:
@@ -531,7 +531,6 @@ class InferMRotaryEmbedding(InferRotaryEmbedding):
         query: mindspore.Tensor,
         key: mindspore.Tensor,
         batch_valid_length: Tensor = None,
-        is_prefill: bool = False,
     ) -> tuple[mindspore.Tensor, mindspore.Tensor]:
         """
         Args:
@@ -543,7 +542,7 @@ class InferMRotaryEmbedding(InferRotaryEmbedding):
         """
         half_rotary_dim = self.rotary_dim // 2
         # prefill
-        if is_prefill:
+        if get_model_context("is_prefill"):
             num_tokens = positions.shape[-1]
             cos, sin = self.freqs_cos[positions], self.freqs_sin[positions]
             cos = SliceExt()(cos, -1, 0, half_rotary_dim, 1)
diff --git a/vllm_mindspore/model_executor/models/llama.py b/vllm_mindspore/model_executor/models/llama.py
index 086e5a43cb4a2386b9f6c7ff726fedfaecbc40b3..aed237ddb2c432910b052837e4290bd2644cc4af 100644
--- a/vllm_mindspore/model_executor/models/llama.py
+++ b/vllm_mindspore/model_executor/models/llama.py
@@ -217,7 +217,6 @@ class LlamaAttention(nn.Cell):
         hidden_states: Tensor,
         key_cache: Tensor,
         value_cache: Tensor,
-        is_prefill: bool,
         slot_mapping: Tensor,
         attn_mask: Tensor,
         batch_valid_length: Tensor,
@@ -227,10 +226,10 @@ class LlamaAttention(nn.Cell):
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = mint.split(qkv, (self.q_size, self.kv_size, self.kv_size),
                              -1)
-        q, k = self.rotary_emb(positions, q, k, batch_valid_length, is_prefill)
-        attn_output = self.attn(q, k, v, key_cache, value_cache, is_prefill,
-                                slot_mapping, attn_mask, batch_valid_length,
-                                q_seq_lens, block_tables)
+        q, k = self.rotary_emb(positions, q, k, batch_valid_length)
+        attn_output = self.attn(q, k, v, key_cache, value_cache, slot_mapping,
+                                attn_mask, batch_valid_length, q_seq_lens,
+                                block_tables)
         output, _ = self.o_proj(attn_output)
         return output
 
@@ -297,7 +296,6 @@ class LlamaDecoderLayer(nn.Cell):
         hidden_states: Tensor,
         key_cache: Tensor,
         value_cache: Tensor,
-        is_prefill: bool,
         slot_mapping: Tensor,
         attn_mask: Tensor,
         batch_valid_length: Tensor,
@@ -314,9 +312,9 @@
                 hidden_states, residual)
         hidden_states = self.self_attn(positions, hidden_states, key_cache,
-                                       value_cache, is_prefill, slot_mapping,
-                                       attn_mask, batch_valid_length,
-                                       q_seq_lens, block_tables)
+                                       value_cache, slot_mapping, attn_mask,
+                                       batch_valid_length, q_seq_lens,
+                                       block_tables)
 
         # Fully Connected
         hidden_states, residual = self.post_attention_layernorm(
@@ -387,7 +385,6 @@ class LlamaModel(nn.Cell):
         positions: Tensor,
         key_caches: list[Tensor],
         value_caches: list[Tensor],
-        is_prefill: bool,
         slot_mapping: Tensor,
         attn_mask: Tensor,
         batch_valid_length: Tensor,
@@ -412,9 +409,9 @@
             hidden_states, residual = layer(positions, hidden_states,
                                             key_caches[i - self.start_layer],
                                             value_caches[i - self.start_layer],
-                                            is_prefill, slot_mapping,
-                                            attn_mask, batch_valid_length,
-                                            q_seq_lens, block_tables, residual)
+                                            slot_mapping, attn_mask,
+                                            batch_valid_length, q_seq_lens,
+                                            block_tables, residual)
 
         if not get_pp_group().is_last_rank:
             return IntermediateTensors({
diff --git a/vllm_mindspore/model_executor/models/model_base.py b/vllm_mindspore/model_executor/models/model_base.py
index e7384143e78cee262603879d2cb7bc9652842b10..af6a7a6c0f2d2f725309bbc4ae59544f6f1da382 100644
--- a/vllm_mindspore/model_executor/models/model_base.py
+++ b/vllm_mindspore/model_executor/models/model_base.py
@@ -33,6 +33,7 @@ from vllm.sequence import IntermediateTensors
 
 from vllm_mindspore.model_executor.models.attention_mask import (
     LowerTriangularMask)
+from vllm_mindspore.model_executor.utils import set_model_context
 from vllm_mindspore.utils import STR_DTYPE_TO_MS_DTYPE
 from vllm_mindspore.v1.attention.backends.ms_attn import MsAttentionMetadata
 
@@ -343,9 +344,9 @@ class NativeModel(MsModelBase):
         if vllm_config.lora_config is not None:
             # native model lora only support pynative mode now
             vllm_config.model_config.enforce_eager = True
-        self.is_graph_mode = not vllm_config.model_config.enforce_eager
-        self.prev_prefill = False
-        self.run_model = None
+        self.is_eager_mode = vllm_config.model_config.enforce_eager
+        self.prefill_graph = None
+        self.decode_graph = None
 
     @property
     def ready_model(self) -> nn.Cell:
@@ -376,8 +377,11 @@ class NativeModel(MsModelBase):
             compilation_config.static_forward_context[str(
                 i)] = self.kv_caches[i]
 
-    def set_model_inputs(self, input_ids, position_ids, intermediate_tensors,
-                         inputs_embeds, is_prefill):
+    def set_model_inputs(self,
+                         input_ids=None,
+                         position_ids=None,
+                         intermediate_tensors=None,
+                         inputs_embeds=None):
         if input_ids is None:
             dyn_input_ids = None
         else:
@@ -428,11 +432,12 @@
         dyn_batch_valid_length = Tensor(shape=[None], dtype=mstype.int32)
         dyn_q_seq_lens = Tensor(shape=[None], dtype=mstype.int32)
         dyn_block_tables = Tensor(shape=[None, None], dtype=mstype.int32)
-        self.ready_model.set_inputs(
-            dyn_input_ids, dyn_position_ids, dyn_key_caches, dyn_value_caches,
-            is_prefill, dyn_slot_mapping, dynamic_attention_mask,
-            dyn_batch_valid_length, dyn_q_seq_lens, dyn_block_tables,
-            dyn_intermediate_tensors, dyn_inputs_embeds)
+        self.ready_model.set_inputs(dyn_input_ids, dyn_position_ids,
+                                    dyn_key_caches, dyn_value_caches,
+                                    dyn_slot_mapping, dynamic_attention_mask,
+                                    dyn_batch_valid_length, dyn_q_seq_lens,
+                                    dyn_block_tables, dyn_intermediate_tensors,
+                                    dyn_inputs_embeds)
 
         dynamic_hidden_states = Tensor(shape=[None, None],
                                        dtype=self.model_config.dtype)
@@ -443,11 +448,22 @@
         model_inputs, is_prefill = self.prepare_base_inputs(
             input_ids, positions)
 
+        new_model_inputs = {}
+        new_model_inputs["input_ids"] = model_inputs["input_ids"]
+        new_model_inputs["batch_valid_length"] = model_inputs[
+            "batch_valid_length"]
+        new_model_inputs["block_tables"] = model_inputs["block_tables"]
+        new_model_inputs["slot_mapping"] = model_inputs["slot_mapping"]
+        new_model_inputs["positions"] = model_inputs["position_ids"]
+        new_model_inputs["q_seq_lens"] = model_inputs["q_seq_lens"]
+        new_model_inputs["attn_mask"] = model_inputs["attention_mask"]
+        new_model_inputs["key_caches"] = model_inputs["key_cache"]
+        new_model_inputs["value_caches"] = model_inputs["value_cache"]
         # for multimodal model
-        model_inputs["intermediate_tensors"] = intermediate_tensors
-        model_inputs["inputs_embeds"] = inputs_embeds
+        new_model_inputs["intermediate_tensors"] = intermediate_tensors
+        new_model_inputs["inputs_embeds"] = inputs_embeds
 
-        return model_inputs, is_prefill
+        return new_model_inputs, is_prefill
 
     def exec_model(self,
                    input_ids: Tensor,
@@ -459,34 +475,35 @@
                                                        intermediate_tensors,
                                                        inputs_embeds)
 
-        if self.prev_prefill != is_prefill and self.is_graph_mode:
-            self.set_model_inputs(input_ids, positions, intermediate_tensors,
-                                  inputs_embeds, is_prefill)
-            self.prev_prefill = is_prefill
-
         # for dummy_attention_metadata
         if is_prefill and not self.set_flags:
             self.set_flags = True
 
-        if self.run_model is None:
-            self.run_model = ms.jit(
-                function=self.model,
-                jit_level='O0') if self.is_graph_mode else self.model
-        if self.model is None:
-            raise RuntimeError("model is not initialized")
-        model_output = self.run_model(
-            input_ids=model_inputs["input_ids"],
-            positions=model_inputs["position_ids"],
-            key_caches=model_inputs["key_cache"],
-            value_caches=model_inputs["value_cache"],
-            is_prefill=is_prefill,
-            slot_mapping=model_inputs["slot_mapping"],
-            attn_mask=model_inputs["attention_mask"],
-            batch_valid_length=model_inputs["batch_valid_length"],
-            q_seq_lens=model_inputs["q_seq_lens"],
-            block_tables=model_inputs["block_tables"],
-            intermediate_tensors=model_inputs["intermediate_tensors"],
-            inputs_embeds=model_inputs["inputs_embeds"],
-        )  # type: ignore[misc]
+        # eager mode
+        if self.is_eager_mode:
+            set_model_context("is_prefill", is_prefill)
+            model_output = self.model(**model_inputs)
+            return model_output
+
+        # graph mode
+        if is_prefill:
+            self.model.phase = "prefill"
+            if self.prefill_graph is None:
+                set_model_context("is_prefill", True)
+                self.model._set_jit_graph_name("prefill")
+                self.set_model_inputs(input_ids, positions,
+                                      intermediate_tensors, inputs_embeds)
+                self.prefill_graph = ms.jit(function=self.model,
+                                            jit_level="O0")
+            model_output = self.prefill_graph(**model_inputs)
+        else:
+            self.model.phase = "increment"
+            if self.decode_graph is None:
+                set_model_context("is_prefill", False)
+                self.model._set_jit_graph_name("decode")
+                self.set_model_inputs(input_ids, positions,
+                                      intermediate_tensors, inputs_embeds)
+                self.decode_graph = ms.jit(function=self.model, jit_level="O0")
+            model_output = self.decode_graph(**model_inputs)
 
         return model_output
diff --git a/vllm_mindspore/model_executor/models/qwen2.py b/vllm_mindspore/model_executor/models/qwen2.py
index 82937be771822009b9ac874d98279f4b148a9cb3..414e639655b8356ab168bd0895599f8a1719e4fe 100644
--- a/vllm_mindspore/model_executor/models/qwen2.py
+++ b/vllm_mindspore/model_executor/models/qwen2.py
@@ -174,7 +174,6 @@ class Qwen2Attention(nn.Cell):
         hidden_states: Tensor,
         key_cache: Tensor,
         value_cache: Tensor,
-        is_prefill: bool,
         slot_mapping: Tensor,
         attn_mask: Tensor,
         batch_valid_length: Tensor,
@@ -184,10 +183,10 @@ class Qwen2Attention(nn.Cell):
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = mint.split(qkv, (self.q_size, self.kv_size, self.kv_size),
                              -1)
-        q, k = self.rotary_emb(positions, q, k, batch_valid_length, is_prefill)
-        attn_output = self.attn(q, k, v, key_cache, value_cache, is_prefill,
-                                slot_mapping, attn_mask, batch_valid_length,
-                                q_seq_lens, block_tables)
+        q, k = self.rotary_emb(positions, q, k, batch_valid_length)
+        attn_output = self.attn(q, k, v, key_cache, value_cache, slot_mapping,
+                                attn_mask, batch_valid_length, q_seq_lens,
+                                block_tables)
         output, _ = self.o_proj(attn_output)
         return output
 
@@ -250,7 +249,6 @@ class Qwen2DecoderLayer(nn.Cell):
         hidden_states: Tensor,
         key_cache: Tensor,
         value_cache: Tensor,
-        is_prefill: bool,
         slot_mapping: Tensor,
         attn_mask: Tensor,
         batch_valid_length: Tensor,
@@ -266,9 +264,9 @@
             hidden_states, residual = self.input_layernorm(
                 hidden_states, residual)
         hidden_states = self.self_attn(positions, hidden_states, key_cache,
-                                       value_cache, is_prefill, slot_mapping,
-                                       attn_mask, batch_valid_length,
-                                       q_seq_lens, block_tables)
+                                       value_cache, slot_mapping, attn_mask,
+                                       batch_valid_length, q_seq_lens,
+                                       block_tables)
 
         # Fully Connected
         hidden_states, residual = self.post_attention_layernorm(
@@ -331,7 +329,6 @@ class Qwen2Model(nn.Cell):
         positions: Tensor,
         key_caches: list[Tensor],
         value_caches: list[Tensor],
-        is_prefill: bool,
         slot_mapping: Tensor,
         attn_mask: Tensor,
         batch_valid_length: Tensor,
@@ -351,14 +348,14 @@ class Qwen2Model(nn.Cell):
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]
 
-        for i in range(self.start_layer, self.end_layer):  # PP 并行对层进行切分
+        for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
             hidden_states, residual = layer(positions, hidden_states,
                                             key_caches[i - self.start_layer],
                                             value_caches[i - self.start_layer],
-                                            is_prefill, slot_mapping,
-                                            attn_mask, batch_valid_length,
-                                            q_seq_lens, block_tables, residual)
+                                            slot_mapping, attn_mask,
+                                            batch_valid_length, q_seq_lens,
+                                            block_tables, residual)
         if not get_pp_group().is_last_rank:
             return IntermediateTensors({
                 "hidden_states": hidden_states,
diff --git a/vllm_mindspore/model_executor/models/qwen2_5_vl.py b/vllm_mindspore/model_executor/models/qwen2_5_vl.py
index c67d5b7fad4ed1687404452f3eb84c4cd0e5bd18..47e405202e13cc08c19b6fcd50f85084232f3aa3 100644
--- a/vllm_mindspore/model_executor/models/qwen2_5_vl.py
+++ b/vllm_mindspore/model_executor/models/qwen2_5_vl.py
@@ -1204,7 +1204,7 @@ class Qwen2_5_VLForConditionalGeneration(NativeModel, SupportsMultiModal):
             quant_config=self._maybe_ignore_quant_config(quant_config),
             prefix=maybe_prefix(prefix, "visual"),
         )
-        if self.is_graph_mode:
+        if not self.is_eager_mode:
             self.visual.construct = ms.jit(function=self.visual,
                                            jit_level='O0')
             self.visual.set_model_inputs()
diff --git a/vllm_mindspore/model_executor/models/qwen3.py b/vllm_mindspore/model_executor/models/qwen3.py
index 84fccb8ecd0aea6f9fd3373c1820c0a8b6f6d103..857f2766281c93685779e70536980636831a919b 100644
--- a/vllm_mindspore/model_executor/models/qwen3.py
+++ b/vllm_mindspore/model_executor/models/qwen3.py
@@ -157,7 +157,6 @@ class Qwen3Attention(nn.Cell):
         hidden_states: Tensor,
         key_cache: Tensor,
         value_cache: Tensor,
-        is_prefill: bool,
         slot_mapping: Tensor,
         attn_mask: Tensor,
         batch_valid_length: Tensor,
@@ -175,10 +174,10 @@ class Qwen3Attention(nn.Cell):
                                 self.head_dim)
         k_by_head = self.k_norm(k_by_head)
         k = k_by_head.view(k.shape)
-        q, k = self.rotary_emb(positions, q, k, batch_valid_length, is_prefill)
-        attn_output = self.attn(q, k, v, key_cache, value_cache, is_prefill,
-                                slot_mapping, attn_mask, batch_valid_length,
-                                q_seq_lens, block_tables)
+        q, k = self.rotary_emb(positions, q, k, batch_valid_length)
+        attn_output = self.attn(q, k, v, key_cache, value_cache, slot_mapping,
+                                attn_mask, batch_valid_length, q_seq_lens,
+                                block_tables)
         output, _ = self.o_proj(attn_output)
         return output
 
@@ -240,7 +239,6 @@ class Qwen3DecoderLayer(nn.Cell):
         hidden_states: Tensor,
         key_cache: Tensor,
         value_cache: Tensor,
-        is_prefill: bool,
         slot_mapping: Tensor,
         attn_mask: Tensor,
         batch_valid_length: Tensor,
@@ -256,9 +254,9 @@
             hidden_states, residual = self.input_layernorm(
                 hidden_states, residual)
         hidden_states = self.self_attn(positions, hidden_states, key_cache,
-                                       value_cache, is_prefill, slot_mapping,
-                                       attn_mask, batch_valid_length,
-                                       q_seq_lens, block_tables)
+                                       value_cache, slot_mapping, attn_mask,
+                                       batch_valid_length, q_seq_lens,
+                                       block_tables)
 
         # Fully Connected
         hidden_states, residual = self.post_attention_layernorm(
diff --git a/vllm_mindspore/model_executor/utils.py b/vllm_mindspore/model_executor/utils.py
index 38b4f7a2726c87808b09ef09f67b4b91a749a8d1..b7187afeda9d54d5352ee3454dbf5e00d12cf164 100644
--- a/vllm_mindspore/model_executor/utils.py
+++ b/vllm_mindspore/model_executor/utils.py
@@ -30,4 +30,16 @@ def set_weight_attrs(
     if weight_attrs is None:
         return
     for key, value in weight_attrs.items():
-        setattr(weight, key, value)
\ No newline at end of file
+        setattr(weight, key, value)
+
+
+_native_model_context = {"is_prefill": True}
+
+
+def set_model_context(key, value):
+    global _native_model_context
+    _native_model_context[key] = value
+
+
+def get_model_context(key):
+    return _native_model_context[key]
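The patch removes is_prefill from every construct() signature and instead publishes the flag through a small module-level "model context" in vllm_mindspore/model_executor/utils.py; layers read it back with get_model_context("is_prefill"). The sketch below is a minimal, framework-free illustration of that flow on the eager path; ToyAttention and run_step are hypothetical stand-ins, while the two context helpers mirror the patched utils.py.

# Minimal sketch of the is_prefill plumbing introduced above (eager path only).
# ToyAttention and run_step are illustrative stand-ins, not vllm_mindspore APIs;
# the context helpers mirror the patched utils.py.

_native_model_context = {"is_prefill": True}


def set_model_context(key, value):
    _native_model_context[key] = value


def get_model_context(key):
    return _native_model_context[key]


class ToyAttention:
    """Reads the phase from the model context instead of an is_prefill arg."""

    def construct(self, tokens):
        if get_model_context("is_prefill"):
            return f"prefill attention over {len(tokens)} prompt tokens"
        return f"decode attention for the latest of {len(tokens)} tokens"


def run_step(layer, tokens, is_prefill):
    # Mirrors NativeModel.exec_model in eager mode: publish the flag once per
    # step, then call the layer without threading is_prefill through every
    # construct() signature.
    set_model_context("is_prefill", is_prefill)
    return layer.construct(tokens)


layer = ToyAttention()
print(run_step(layer, [101, 2009, 2003], is_prefill=True))
print(run_step(layer, [101, 2009, 2003, 1012], is_prefill=False))

In graph mode the patch takes a different route: exec_model keeps one compiled artifact per phase (self.prefill_graph / self.decode_graph), sets the context flag once before each ms.jit trace, and tags the graphs via the model's phase and _set_jit_graph_name, so the prefill/decode branch is resolved at compile time instead of being passed in as a graph input.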