diff --git a/.jenkins/test/config/dependent_packages.yaml b/.jenkins/test/config/dependent_packages.yaml
index 2efd644ec233c9acca1e996d0d45c297810dc393..31c5fbd7040eb43959ada2d631075dfce7ffc547 100644
--- a/.jenkins/test/config/dependent_packages.yaml
+++ b/.jenkins/test/config/dependent_packages.yaml
@@ -1,5 +1,6 @@
 mindspore:
-  'https://repo.mindspore.cn/mindspore/mindspore/version/202506/20250613/br_infer_iter_20250613031508_11bcfd2ff4dc201a1c07e5d525cbeff7ec7f9558_newest/'
+  'https://repo.mindspore.cn/mindspore/mindspore/version/202506/20250620/br_infer_iter_20250620031508_8daa8f27c9ff571aa1028635d636265e7178a93d_newest/'
+
 mindspore_gs:
   'https://repo.mindspore.cn/mindspore/golden-stick/version/202506/20250604/master_20250604160014_35fcbec4406d3b18faf02ef99fcbe2741e80348e_newest/'
diff --git a/vllm_mindspore/attention/layer.py b/vllm_mindspore/attention/layer.py
index f4af1afba6e5e40691dce46d9a4a40eafe18ea0d..77f4ea6791fa9350631ada48b2a66d4610ccbbca 100644
--- a/vllm_mindspore/attention/layer.py
+++ b/vllm_mindspore/attention/layer.py
@@ -26,6 +26,8 @@ from vllm.config import CacheConfig
 from vllm.model_executor.layers.quantization.base_config import \
     QuantizationConfig
 
+from vllm_mindspore.model_executor.utils import get_model_context
+
 
 def _pad_to_max_tensor(input_: Tensor,
                        max_len: int,
@@ -142,7 +144,6 @@ class Attention(nn.Cell):
                   value: Tensor,
                   key_cache: Tensor,
                   value_cache: Tensor,
-                  is_prefill: bool,
                   slot_mapping: Tensor,
                   attn_mask: Tensor,
                   batch_valid_length: Tensor,
@@ -167,7 +168,7 @@ class Attention(nn.Cell):
         cache_out = self.reshape_and_cache(key, value, key_cache, value_cache,
                                            slot_mapping)
         query = ops.depend(query, cache_out)
-        if is_prefill:
+        if get_model_context("is_prefill"):
             output = self._run_prefill_forward(query, key, value, attn_mask,
                                                batch_valid_length,
                                                batch_valid_length)
diff --git a/vllm_mindspore/model_executor/layers/rotary_embedding.py b/vllm_mindspore/model_executor/layers/rotary_embedding.py
index 7470233475254ccccdf677d5627a6f3bd59f6408..82fe1dde5aa9e918ca9f8edc82c032b77e8b433c 100644
--- a/vllm_mindspore/model_executor/layers/rotary_embedding.py
+++ b/vllm_mindspore/model_executor/layers/rotary_embedding.py
@@ -30,6 +30,7 @@ from mindspore.ops.auto_generate.gen_ops_prim import SliceExt
 from transformers import PretrainedConfig
 from vllm.config import get_current_vllm_config
 
+from vllm_mindspore.model_executor.utils import get_model_context
 
 def _apply_rotary_emb(
     x: Tensor,
@@ -191,12 +192,11 @@ class InferRotaryEmbedding(nn.Cell):
         query: Tensor,
         key: Tensor,
         batch_valid_length: Tensor,
-        is_prefill: bool,
         offsets: Optional[Tensor] = None,
     ) -> Tuple[Tensor, Tensor]:
         query = query.contiguous()
         key = key.contiguous()
-        if is_prefill:
+        if get_model_context("is_prefill"):
             return self.rotary_embedding_op(query, key, self.freqs_cos,
                                             self.freqs_sin,
                                             batch_valid_length)
@@ -282,7 +282,6 @@ class MRotaryEmbedding(RotaryEmbedding):
         query: mindspore.Tensor,
         key: mindspore.Tensor,
         batch_valid_length: Tensor = None,
-        is_prefill: bool = False,
     ) -> Tuple[mindspore.Tensor, mindspore.Tensor]:
         """
         Args:
@@ -526,7 +525,6 @@ class InferMRotaryEmbedding(InferRotaryEmbedding):
         query: mindspore.Tensor,
         key: mindspore.Tensor,
         batch_valid_length: Tensor = None,
-        is_prefill: bool = False,
     ) -> Tuple[mindspore.Tensor, mindspore.Tensor]:
         """
         Args:
@@ -538,7 +536,7 @@ class InferMRotaryEmbedding(InferRotaryEmbedding):
         """
         half_rotary_dim = self.rotary_dim // 2
         # prefill
-        if is_prefill:
+        if get_model_context("is_prefill"):
             num_tokens = positions.shape[-1]
             cos, sin = self.freqs_cos[positions], self.freqs_sin[positions]
             cos = SliceExt()(cos, -1, 0, half_rotary_dim, 1)
diff --git a/vllm_mindspore/model_executor/models/llama.py b/vllm_mindspore/model_executor/models/llama.py
index 8781e3397f1444c7c5d11389bb26b0a94c9d3a3e..e5d0d14797a20e634602e184df2c168ea629b5b8 100644
--- a/vllm_mindspore/model_executor/models/llama.py
+++ b/vllm_mindspore/model_executor/models/llama.py
@@ -195,7 +195,6 @@ class LlamaAttention(nn.Cell):
                   hidden_states: Tensor,
                   key_cache: Tensor,
                   value_cache: Tensor,
-                  is_prefill: bool,
                   slot_mapping: Tensor,
                   attn_mask: Tensor,
                   batch_valid_length: Tensor,
@@ -205,8 +204,8 @@ class LlamaAttention(nn.Cell):
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = mint.split(qkv, (self.q_size, self.kv_size, self.kv_size),
                              -1)
-        q, k = self.rotary_emb(positions, q, k, batch_valid_length, is_prefill)
-        attn_output = self.attn(q, k, v, key_cache, value_cache, is_prefill,
+        q, k = self.rotary_emb(positions, q, k, batch_valid_length)
+        attn_output = self.attn(q, k, v, key_cache, value_cache,
                                 slot_mapping, attn_mask, batch_valid_length,
                                 q_seq_lens, block_tables)
         output, _ = self.o_proj(attn_output)
@@ -275,7 +274,6 @@ class LlamaDecoderLayer(nn.Cell):
                   hidden_states: Tensor,
                   key_cache: Tensor,
                   value_cache: Tensor,
-                  is_prefill: bool,
                   slot_mapping: Tensor,
                   attn_mask: Tensor,
                   batch_valid_length: Tensor,
@@ -292,7 +290,7 @@ class LlamaDecoderLayer(nn.Cell):
                 hidden_states, residual)
         hidden_states = self.self_attn(positions, hidden_states, key_cache,
-                                       value_cache, is_prefill, slot_mapping,
+                                       value_cache, slot_mapping,
                                        attn_mask, batch_valid_length,
                                        q_seq_lens, block_tables)
@@ -364,7 +362,6 @@ class LlamaModel(nn.Cell):
                   positions: Tensor,
                   key_caches: List[Tensor],
                   value_caches: List[Tensor],
-                  is_prefill: bool,
                   slot_mapping: Tensor,
                   attn_mask: Tensor,
                   batch_valid_length: Tensor,
@@ -389,7 +386,7 @@ class LlamaModel(nn.Cell):
             hidden_states, residual = layer(positions, hidden_states,
                                             key_caches[i - self.start_layer],
                                             value_caches[i - self.start_layer],
-                                            is_prefill, slot_mapping,
+                                            slot_mapping,
                                             attn_mask, batch_valid_length,
                                             q_seq_lens, block_tables, residual)
diff --git a/vllm_mindspore/model_executor/models/model_base.py b/vllm_mindspore/model_executor/models/model_base.py
index f4b81766010b7c714735fc427fe746c45317baf8..e060c5f6195991e7c11f46251c1641dd5f345c81 100644
--- a/vllm_mindspore/model_executor/models/model_base.py
+++ b/vllm_mindspore/model_executor/models/model_base.py
@@ -22,6 +22,10 @@ from abc import abstractmethod
 from typing import Iterable, Optional, Set, Tuple, Union, Dict
 
 import numpy as np
+import mindspore as ms
+from mindspore import Tensor, mutable, nn
+from mindspore.common import dtype as mstype
+
 from vllm.attention.backends.abstract import AttentionType
 from vllm.config import VllmConfig, get_current_vllm_config
 from vllm.forward_context import get_forward_context
@@ -30,14 +34,11 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
 import vllm.envs as envs
-import mindspore as ms
-from mindspore import Tensor, nn, mutable
-from mindspore.common import dtype as mstype
-
 from vllm_mindspore.model_executor.models.attention_mask import LowerTriangularMask
 from vllm_mindspore.utils import STR_DTYPE_TO_MS_DTYPE
 from vllm_mindspore.v1.attention.backends.ms_attn import MsAttentionMetadata
 from vllm_mindspore.worker.profile_controller import vllm_mindspore_profile_controller
+from vllm_mindspore.model_executor.utils import set_model_context
 
 
 class AttentionWrapper:
@@ -340,9 +341,10 @@ class NativeModel(MsModelBase):
         if vllm_config.lora_config is not None:
             # native model lora only support pynative mode now
             vllm_config.model_config.enforce_eager = True
-        self.is_graph_mode = bool(not vllm_config.model_config.enforce_eager)
-        self.prev_prefill = False
-        self.run_model = None
+        self.is_eager_mode = vllm_config.model_config.enforce_eager
+        self.fa_network = None
+        self.pa_network = None
+
 
     def common_preprocess(self, vllm_config, prefix=""):
         self.set_modules({
@@ -364,8 +366,8 @@ class NativeModel(MsModelBase):
             compilation_config.static_forward_context[str(
                 i)] = self.kv_caches[i]
 
-    def set_model_inputs(self, input_ids, position_ids, intermediate_tensors,
-                         inputs_embeds, is_prefill):
+    def set_model_inputs(self, input_ids=None, position_ids=None, intermediate_tensors=None,
+                         inputs_embeds=None):
         if input_ids is None:
             dyn_input_ids = None
         else:
@@ -420,7 +422,6 @@ class NativeModel(MsModelBase):
             dyn_position_ids,
             dyn_key_caches,  # type: ignore[attr-defined]
             dyn_value_caches,
-            is_prefill,
             dyn_slot_mapping,
             dynamic_attention_mask,
             dyn_batch_valid_length,
@@ -439,11 +440,21 @@ class NativeModel(MsModelBase):
         model_inputs, is_prefill = self.prepare_base_inputs(
             input_ids, positions)
 
+        new_model_inputs = {}
+        new_model_inputs["input_ids"] = model_inputs["input_ids"]
+        new_model_inputs["batch_valid_length"] = model_inputs["batch_valid_length"]
+        new_model_inputs["block_tables"] = model_inputs["block_tables"]
+        new_model_inputs["slot_mapping"] = model_inputs["slot_mapping"]
+        new_model_inputs["positions"] = model_inputs["position_ids"]
+        new_model_inputs["q_seq_lens"] = model_inputs["q_seq_lens"]
+        new_model_inputs["attn_mask"] = model_inputs["attention_mask"]
+        new_model_inputs["key_caches"] = model_inputs["key_cache"]
+        new_model_inputs["value_caches"] = model_inputs["value_cache"]
         # for multimodal model
-        model_inputs["intermediate_tensors"] = intermediate_tensors
-        model_inputs["inputs_embeds"] = inputs_embeds
+        new_model_inputs["intermediate_tensors"] = intermediate_tensors
+        new_model_inputs["inputs_embeds"] = inputs_embeds
 
-        return model_inputs, is_prefill
+        return new_model_inputs, is_prefill
 
     def exec_model(self,
                    input_ids: Tensor,
@@ -455,33 +466,32 @@ class NativeModel(MsModelBase):
                                                        intermediate_tensors,
                                                        inputs_embeds)
 
-        if self.prev_prefill != is_prefill and self.is_graph_mode:
-            self.set_model_inputs(input_ids, positions, intermediate_tensors,
-                                  inputs_embeds, is_prefill)
-        self.prev_prefill = is_prefill
-
         # for dummy_attention_metadata
         if is_prefill and not self.set_flags:
             self.set_flags = True
 
-        if self.run_model is None:
-            self.run_model = ms.jit(
-                function=self.model,  # type: ignore[attr-defined]
-                jit_level='O0'
-            ) if self.is_graph_mode else self.model  # type: ignore[attr-defined]
-        model_output = self.run_model(  # type: ignore[misc]
-            input_ids=model_inputs["input_ids"],
-            positions=model_inputs["position_ids"],
-            key_caches=model_inputs["key_cache"],
-            value_caches=model_inputs["value_cache"],
-            is_prefill=is_prefill,
-            slot_mapping=model_inputs["slot_mapping"],
-            attn_mask=model_inputs["attention_mask"],
-            batch_valid_length=model_inputs["batch_valid_length"],
-            q_seq_lens=model_inputs["q_seq_lens"],
-            block_tables=model_inputs["block_tables"],
-            intermediate_tensors=model_inputs["intermediate_tensors"],
-            inputs_embeds=model_inputs["inputs_embeds"],
-        )
+        # eager mode
+        if self.is_eager_mode:
+            set_model_context("is_prefill", is_prefill)
+            model_output = self.model(**model_inputs)
+            return model_output
+
+        # graph mode
+        if is_prefill:
+            self.model.phase = "prefill"  # For better performance, it will be improved.
+            if self.fa_network is None:
+                set_model_context("is_prefill", True)
+                self.model._set_jit_graph_name("prefill")
+                self.set_model_inputs()
+                self.fa_network = ms.jit(function=self.model, jit_level="O0")
+            model_output = self.fa_network(**model_inputs)
+        else:
+            self.model.phase = "increment"  # For better performance, it will be improved.
+            if self.pa_network is None:
+                set_model_context("is_prefill", False)
+                self.model._set_jit_graph_name("decode")
+                self.set_model_inputs()
+                self.pa_network = ms.jit(function=self.model, jit_level="O0")
+            model_output = self.pa_network(**model_inputs)
 
         return model_output
diff --git a/vllm_mindspore/model_executor/models/qwen2.py b/vllm_mindspore/model_executor/models/qwen2.py
index 0f6e7bf7d159c3c82a1284a7294407efb81fe6a1..790b1a67982a23ad30982e9f86d9ccc2a4fcf561 100644
--- a/vllm_mindspore/model_executor/models/qwen2.py
+++ b/vllm_mindspore/model_executor/models/qwen2.py
@@ -162,7 +162,6 @@ class Qwen2Attention(nn.Cell):
                   hidden_states: Tensor,
                   key_cache: Tensor,
                   value_cache: Tensor,
-                  is_prefill: bool,
                   slot_mapping: Tensor,
                   attn_mask: Tensor,
                   batch_valid_length: Tensor,
@@ -172,8 +171,8 @@ class Qwen2Attention(nn.Cell):
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = mint.split(qkv, (self.q_size, self.kv_size, self.kv_size),
                              -1)
-        q, k = self.rotary_emb(positions, q, k, batch_valid_length, is_prefill)
-        attn_output = self.attn(q, k, v, key_cache, value_cache, is_prefill,
+        q, k = self.rotary_emb(positions, q, k, batch_valid_length)
+        attn_output = self.attn(q, k, v, key_cache, value_cache,
                                 slot_mapping, attn_mask, batch_valid_length,
                                 q_seq_lens, block_tables)
         output, _ = self.o_proj(attn_output)
@@ -238,7 +237,6 @@ class Qwen2DecoderLayer(nn.Cell):
                   hidden_states: Tensor,
                   key_cache: Tensor,
                   value_cache: Tensor,
-                  is_prefill: bool,
                   slot_mapping: Tensor,
                   attn_mask: Tensor,
                   batch_valid_length: Tensor,
@@ -254,7 +252,7 @@ class Qwen2DecoderLayer(nn.Cell):
             hidden_states, residual = self.input_layernorm(
                 hidden_states, residual)
         hidden_states = self.self_attn(positions, hidden_states, key_cache,
-                                       value_cache, is_prefill, slot_mapping,
+                                       value_cache, slot_mapping,
                                        attn_mask, batch_valid_length,
                                        q_seq_lens, block_tables)
@@ -315,7 +313,6 @@ class Qwen2Model(nn.Cell):
                   positions: Tensor,
                   key_caches: List[Tensor],
                   value_caches: List[Tensor],
-                  is_prefill: bool,
                   slot_mapping: Tensor,
                   attn_mask: Tensor,
                   batch_valid_length: Tensor,
@@ -334,12 +331,12 @@ class Qwen2Model(nn.Cell):
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]
 
-        for i in range(self.start_layer, self.end_layer):  # PP 并行对层进行切分
+        for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
             hidden_states, residual = layer(positions, hidden_states,
                                             key_caches[i - self.start_layer],
                                             value_caches[i - self.start_layer],
-                                            is_prefill, slot_mapping,
+                                            slot_mapping,
                                             attn_mask, batch_valid_length,
                                             q_seq_lens, block_tables, residual)
         if not get_pp_group().is_last_rank:
diff --git a/vllm_mindspore/model_executor/utils.py b/vllm_mindspore/model_executor/utils.py
index eb421de0ba02911be0da18afb4b885efc566fc74..affadbd8cf48433c7d5904efdb83c4c41dd2aded 100644
--- a/vllm_mindspore/model_executor/utils.py
+++ b/vllm_mindspore/model_executor/utils.py
@@ -28,3 +28,15 @@ def set_weight_attrs(
         return
     for key, value in weight_attrs.items():
         setattr(weight, key, value)
+
+
+_native_model_context = {
+    "is_prefill": True
+}
+
+def set_model_context(key, value):
+    global _native_model_context
+    _native_model_context[key] = value
+
+def get_model_context(key):
+    return _native_model_context[key]
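
Usage sketch (illustrative, not part of the patch): the module-level context added in vllm_mindspore/model_executor/utils.py replaces the is_prefill argument that was previously threaded through every construct() signature. The runner sets the flag once per step with set_model_context (see NativeModel.exec_model above), and the layers read it back with get_model_context. The snippet below assumes the patched package is importable; DummyAttention is a hypothetical stand-in for the real Attention cell, not project code.

    # Minimal sketch of the new prefill/decode routing via the shared model context.
    # Assumes the patched vllm_mindspore is installed; DummyAttention is a
    # hypothetical stand-in for Attention/LlamaAttention.
    from vllm_mindspore.model_executor.utils import (get_model_context,
                                                     set_model_context)


    class DummyAttention:

        def construct(self, query):
            # The branch is read from the context instead of an is_prefill argument.
            if get_model_context("is_prefill"):
                return f"prefill path for {query}"
            return f"decode path for {query}"


    attn = DummyAttention()
    set_model_context("is_prefill", True)   # the model runner sets the phase once
    print(attn.construct("q0"))             # -> prefill path for q0
    set_model_context("is_prefill", False)
    print(attn.construct("q1"))             # -> decode path for q1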