From f6ee9cc4498d2ac6ad15489c4cdebf33c5883e8b Mon Sep 17 00:00:00 2001
From: zlq2020
Date: Thu, 19 Jun 2025 22:08:08 +0800
Subject: [PATCH] refactor native model jit

---
 .jenkins/test/config/dependent_packages.yaml  |  2 +-
 vllm_mindspore/attention/layer.py             |  5 +-
 .../model_executor/layers/rotary_embedding.py |  8 +-
 vllm_mindspore/model_executor/models/llama.py | 11 +--
 .../model_executor/models/model_base.py       | 87 ++++++++++---------
 vllm_mindspore/model_executor/models/qwen2.py | 13 ++-
 vllm_mindspore/model_executor/utils.py        | 12 +++
 7 files changed, 76 insertions(+), 62 deletions(-)

diff --git a/.jenkins/test/config/dependent_packages.yaml b/.jenkins/test/config/dependent_packages.yaml
index 16ca50f..9c68e4e 100644
--- a/.jenkins/test/config/dependent_packages.yaml
+++ b/.jenkins/test/config/dependent_packages.yaml
@@ -1,5 +1,5 @@
 mindspore:
-  'https://repo.mindspore.cn/mindspore/mindspore/version/202506/20250613/br_infer_iter_20250613031508_11bcfd2ff4dc201a1c07e5d525cbeff7ec7f9558_newest/'
+  'https://repo.mindspore.cn/mindspore/mindspore/version/202506/20250620/br_infer_iter_20250620031508_8daa8f27c9ff571aa1028635d636265e7178a93d_newest/'
 mindspore_gs:
   'https://repo.mindspore.cn/mindspore/golden-stick/version/202506/20250604/master_20250604160014_35fcbec4406d3b18faf02ef99fcbe2741e80348e_newest/'
diff --git a/vllm_mindspore/attention/layer.py b/vllm_mindspore/attention/layer.py
index f4af1af..77f4ea6 100644
--- a/vllm_mindspore/attention/layer.py
+++ b/vllm_mindspore/attention/layer.py
@@ -26,6 +26,8 @@ from vllm.config import CacheConfig
 from vllm.model_executor.layers.quantization.base_config import \
     QuantizationConfig
 
+from vllm_mindspore.model_executor.utils import get_model_context
+
 
 def _pad_to_max_tensor(input_: Tensor,
                        max_len: int,
@@ -142,7 +144,6 @@ class Attention(nn.Cell):
                   value: Tensor,
                   key_cache: Tensor,
                   value_cache: Tensor,
-                  is_prefill: bool,
                   slot_mapping: Tensor,
                   attn_mask: Tensor,
                   batch_valid_length: Tensor,
@@ -167,7 +168,7 @@ class Attention(nn.Cell):
         cache_out = self.reshape_and_cache(key, value, key_cache, value_cache,
                                            slot_mapping)
         query = ops.depend(query, cache_out)
-        if is_prefill:
+        if get_model_context("is_prefill"):
             output = self._run_prefill_forward(query, key, value, attn_mask,
                                                batch_valid_length,
                                                batch_valid_length)
diff --git a/vllm_mindspore/model_executor/layers/rotary_embedding.py b/vllm_mindspore/model_executor/layers/rotary_embedding.py
index 7470233..82fe1dd 100644
--- a/vllm_mindspore/model_executor/layers/rotary_embedding.py
+++ b/vllm_mindspore/model_executor/layers/rotary_embedding.py
@@ -30,6 +30,7 @@ from mindspore.ops.auto_generate.gen_ops_prim import SliceExt
 from transformers import PretrainedConfig
 from vllm.config import get_current_vllm_config
 
+from vllm_mindspore.model_executor.utils import get_model_context
 
 def _apply_rotary_emb(
     x: Tensor,
@@ -191,12 +192,11 @@ class InferRotaryEmbedding(nn.Cell):
         query: Tensor,
         key: Tensor,
         batch_valid_length: Tensor,
-        is_prefill: bool,
         offsets: Optional[Tensor] = None,
     ) -> Tuple[Tensor, Tensor]:
         query = query.contiguous()
         key = key.contiguous()
-        if is_prefill:
+        if get_model_context("is_prefill"):
             return self.rotary_embedding_op(query, key, self.freqs_cos,
                                             self.freqs_sin, batch_valid_length)
@@ -282,7 +282,6 @@ class MRotaryEmbedding(RotaryEmbedding):
         query: mindspore.Tensor,
         key: mindspore.Tensor,
         batch_valid_length: Tensor = None,
-        is_prefill: bool = False,
     ) -> Tuple[mindspore.Tensor, mindspore.Tensor]:
         """
         Args:
@@ -526,7 +525,6 @@ class InferMRotaryEmbedding(InferRotaryEmbedding):
         query: mindspore.Tensor,
         key: mindspore.Tensor,
         batch_valid_length: Tensor = None,
-        is_prefill: bool = False,
     ) -> Tuple[mindspore.Tensor, mindspore.Tensor]:
         """
         Args:
@@ -538,7 +536,7 @@ class InferMRotaryEmbedding(InferRotaryEmbedding):
         """
         half_rotary_dim = self.rotary_dim // 2
         # prefill
-        if is_prefill:
+        if get_model_context("is_prefill"):
             num_tokens = positions.shape[-1]
             cos, sin = self.freqs_cos[positions], self.freqs_sin[positions]
             cos = SliceExt()(cos, -1, 0, half_rotary_dim, 1)
diff --git a/vllm_mindspore/model_executor/models/llama.py b/vllm_mindspore/model_executor/models/llama.py
index 954579f..b4fab2b 100644
--- a/vllm_mindspore/model_executor/models/llama.py
+++ b/vllm_mindspore/model_executor/models/llama.py
@@ -196,7 +196,6 @@ class LlamaAttention(nn.Cell):
                   hidden_states: Tensor,
                   key_cache: Tensor,
                   value_cache: Tensor,
-                  is_prefill: bool,
                   slot_mapping: Tensor,
                   attn_mask: Tensor,
                   batch_valid_length: Tensor,
@@ -206,8 +205,8 @@ class LlamaAttention(nn.Cell):
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = mint.split(qkv, (self.q_size, self.kv_size, self.kv_size),
                              -1)
-        q, k = self.rotary_emb(positions, q, k, batch_valid_length, is_prefill)
-        attn_output = self.attn(q, k, v, key_cache, value_cache, is_prefill,
+        q, k = self.rotary_emb(positions, q, k, batch_valid_length)
+        attn_output = self.attn(q, k, v, key_cache, value_cache,
                                 slot_mapping, attn_mask, batch_valid_length,
                                 q_seq_lens, block_tables)
         output, _ = self.o_proj(attn_output)
@@ -276,7 +275,6 @@ class LlamaDecoderLayer(nn.Cell):
         hidden_states: Tensor,
         key_cache: Tensor,
         value_cache: Tensor,
-        is_prefill: bool,
         slot_mapping: Tensor,
         attn_mask: Tensor,
         batch_valid_length: Tensor,
@@ -293,7 +291,7 @@ class LlamaDecoderLayer(nn.Cell):
                 hidden_states, residual)
         hidden_states = self.self_attn(positions, hidden_states, key_cache,
-                                       value_cache, is_prefill, slot_mapping,
+                                       value_cache, slot_mapping,
                                        attn_mask, batch_valid_length,
                                        q_seq_lens, block_tables)
@@ -365,7 +363,6 @@ class LlamaModel(nn.Cell):
         positions: Tensor,
         key_caches: List[Tensor],
         value_caches: List[Tensor],
-        is_prefill: bool,
         slot_mapping: Tensor,
         attn_mask: Tensor,
         batch_valid_length: Tensor,
@@ -390,7 +387,7 @@ class LlamaModel(nn.Cell):
             hidden_states, residual = layer(positions, hidden_states,
                                             key_caches[i - self.start_layer],
                                             value_caches[i - self.start_layer],
-                                            is_prefill, slot_mapping,
+                                            slot_mapping,
                                             attn_mask, batch_valid_length,
                                             q_seq_lens, block_tables, residual)
diff --git a/vllm_mindspore/model_executor/models/model_base.py b/vllm_mindspore/model_executor/models/model_base.py
index d2db979..b8de76a 100644
--- a/vllm_mindspore/model_executor/models/model_base.py
+++ b/vllm_mindspore/model_executor/models/model_base.py
@@ -22,6 +22,10 @@ from abc import abstractmethod
 from typing import Iterable, Optional, Set, Tuple, Union, Dict
 
 import numpy as np
+import mindspore as ms
+from mindspore import Tensor, mutable, nn
+from mindspore.common import dtype as mstype
+
 from vllm.attention.backends.abstract import AttentionType
 from vllm.config import VllmConfig, get_current_vllm_config
 from vllm.forward_context import get_forward_context
@@ -30,14 +34,10 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 import vllm.envs as envs
 
-import mindspore as ms
-from mindspore import Tensor, nn, mutable
-from mindspore.common import dtype as mstype
-
 from vllm_mindspore.model_executor.models.attention_mask import LowerTriangularMask
 from vllm_mindspore.utils import STR_DTYPE_TO_MS_DTYPE
 from vllm_mindspore.v1.attention.backends.ms_attn import MsAttentionMetadata
-
+from vllm_mindspore.model_executor.utils import set_model_context
 
 
 class AttentionWrapper:
@@ -336,9 +336,10 @@ class NativeModel(MsModelBase):
         if vllm_config.lora_config is not None:
             # native model lora only support pynative mode now
             vllm_config.model_config.enforce_eager = True
-        self.is_graph_mode = bool(not vllm_config.model_config.enforce_eager)
-        self.prev_prefill = False
-        self.run_model = None
+        self.is_eager_mode = vllm_config.model_config.enforce_eager
+        self.fa_network = None
+        self.pa_network = None
+
 
     def common_preprocess(self, vllm_config, prefix=""):
         self.set_modules({
@@ -416,7 +417,6 @@ class NativeModel(MsModelBase):
             dyn_position_ids,
             dyn_key_caches,  # type: ignore[attr-defined]
             dyn_value_caches,
-            is_prefill,
             dyn_slot_mapping,
             dynamic_attention_mask,
             dyn_batch_valid_length,
@@ -435,11 +435,21 @@
         model_inputs, is_prefill = self.prepare_base_inputs(
             input_ids, positions)
 
+        new_model_inputs = {}
+        new_model_inputs["input_ids"] = model_inputs["input_ids"]
+        new_model_inputs["batch_valid_length"] = model_inputs["batch_valid_length"]
+        new_model_inputs["block_tables"] = model_inputs["block_tables"]
+        new_model_inputs["slot_mapping"] = model_inputs["slot_mapping"]
+        new_model_inputs["positions"] = model_inputs["position_ids"]
+        new_model_inputs["q_seq_lens"] = model_inputs["q_seq_lens"]
+        new_model_inputs["attn_mask"] = model_inputs["attention_mask"]
+        new_model_inputs["key_caches"] = model_inputs["key_cache"]
+        new_model_inputs["value_caches"] = model_inputs["value_cache"]
         # for multimodal model
-        model_inputs["intermediate_tensors"] = intermediate_tensors
-        model_inputs["inputs_embeds"] = inputs_embeds
+        new_model_inputs["intermediate_tensors"] = intermediate_tensors
+        new_model_inputs["inputs_embeds"] = inputs_embeds
 
-        return model_inputs, is_prefill
+        return new_model_inputs, is_prefill
 
     def exec_model(self,
                    input_ids: Tensor,
@@ -451,33 +461,32 @@
                                                  intermediate_tensors,
                                                  inputs_embeds)
 
-        if self.prev_prefill != is_prefill and self.is_graph_mode:
-            self.set_model_inputs(input_ids, positions, intermediate_tensors,
-                                  inputs_embeds, is_prefill)
-            self.prev_prefill = is_prefill
-
-        # for dummy_attention_metadata
+        # for dummy_attention_metadata
         if is_prefill and not self.set_flags:
-            self.set_flags = True
-
-        if self.run_model is None:
-            self.run_model = ms.jit(
-                function=self.model,  # type: ignore[attr-defined]
-                jit_level='O0'
-            ) if self.is_graph_mode else self.model  # type: ignore[attr-defined]
-        model_output = self.run_model(  # type: ignore[misc]
-            input_ids=model_inputs["input_ids"],
-            positions=model_inputs["position_ids"],
-            key_caches=model_inputs["key_cache"],
-            value_caches=model_inputs["value_cache"],
-            is_prefill=is_prefill,
-            slot_mapping=model_inputs["slot_mapping"],
-            attn_mask=model_inputs["attention_mask"],
-            batch_valid_length=model_inputs["batch_valid_length"],
-            q_seq_lens=model_inputs["q_seq_lens"],
-            block_tables=model_inputs["block_tables"],
-            intermediate_tensors=model_inputs["intermediate_tensors"],
-            inputs_embeds=model_inputs["inputs_embeds"],
-        )
+            self.set_flags = True
+
+        # eager mode
+        if self.is_eager_mode:
+            set_model_context("is_prefill", is_prefill)
+            model_output = self.model(**model_inputs)
+            return model_output
+
+        # graph mode
+        if is_prefill:
+            self.model.phase = "prefill"  # For better performance; will be improved later.
+            if self.fa_network is None:
+                set_model_context("is_prefill", True)
+                self.model._set_jit_graph_name("prefill")
+                self.set_model_inputs()
+                self.fa_network = ms.jit(function=self.model, jit_level="O0")
+            model_output = self.fa_network(**model_inputs)
+        else:
+            self.model.phase = "increment"  # For better performance; will be improved later.
+            if self.pa_network is None:
+                set_model_context("is_prefill", False)
+                self.model._set_jit_graph_name("decode")
+                self.set_model_inputs()
+                self.pa_network = ms.jit(function=self.model, jit_level="O0")
+            model_output = self.pa_network(**model_inputs)
 
         return model_output
diff --git a/vllm_mindspore/model_executor/models/qwen2.py b/vllm_mindspore/model_executor/models/qwen2.py
index 87c54c2..c370376 100644
--- a/vllm_mindspore/model_executor/models/qwen2.py
+++ b/vllm_mindspore/model_executor/models/qwen2.py
@@ -163,7 +163,6 @@ class Qwen2Attention(nn.Cell):
                   hidden_states: Tensor,
                   key_cache: Tensor,
                   value_cache: Tensor,
-                  is_prefill: bool,
                   slot_mapping: Tensor,
                   attn_mask: Tensor,
                   batch_valid_length: Tensor,
@@ -173,8 +172,8 @@ class Qwen2Attention(nn.Cell):
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = mint.split(qkv, (self.q_size, self.kv_size, self.kv_size),
                              -1)
-        q, k = self.rotary_emb(positions, q, k, batch_valid_length, is_prefill)
-        attn_output = self.attn(q, k, v, key_cache, value_cache, is_prefill,
+        q, k = self.rotary_emb(positions, q, k, batch_valid_length)
+        attn_output = self.attn(q, k, v, key_cache, value_cache,
                                 slot_mapping, attn_mask, batch_valid_length,
                                 q_seq_lens, block_tables)
         output, _ = self.o_proj(attn_output)
@@ -239,7 +238,6 @@ class Qwen2DecoderLayer(nn.Cell):
         hidden_states: Tensor,
         key_cache: Tensor,
         value_cache: Tensor,
-        is_prefill: bool,
         slot_mapping: Tensor,
         attn_mask: Tensor,
         batch_valid_length: Tensor,
@@ -255,7 +253,7 @@ class Qwen2DecoderLayer(nn.Cell):
             hidden_states, residual = self.input_layernorm(
                 hidden_states, residual)
         hidden_states = self.self_attn(positions, hidden_states, key_cache,
-                                       value_cache, is_prefill, slot_mapping,
+                                       value_cache, slot_mapping,
                                        attn_mask, batch_valid_length,
                                        q_seq_lens, block_tables)
@@ -316,7 +314,6 @@ class Qwen2Model(nn.Cell):
         positions: Tensor,
         key_caches: List[Tensor],
         value_caches: List[Tensor],
-        is_prefill: bool,
         slot_mapping: Tensor,
         attn_mask: Tensor,
         batch_valid_length: Tensor,
@@ -335,12 +332,12 @@ class Qwen2Model(nn.Cell):
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]
 
-        for i in range(self.start_layer, self.end_layer):  # PP 并行对层进行切分
+        for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
             hidden_states, residual = layer(positions, hidden_states,
                                             key_caches[i - self.start_layer],
                                             value_caches[i - self.start_layer],
-                                            is_prefill, slot_mapping,
+                                            slot_mapping,
                                             attn_mask, batch_valid_length,
                                             q_seq_lens, block_tables, residual)
         if not get_pp_group().is_last_rank:
diff --git a/vllm_mindspore/model_executor/utils.py b/vllm_mindspore/model_executor/utils.py
index eb421de..affadbd 100644
--- a/vllm_mindspore/model_executor/utils.py
+++ b/vllm_mindspore/model_executor/utils.py
@@ -28,3 +28,15 @@ def set_weight_attrs(
         return
     for key, value in weight_attrs.items():
         setattr(weight, key, value)
+
+
+_native_model_context = {
+    "is_prefill": True
+}
+
+def set_model_context(key, value):
+    global _native_model_context
+    _native_model_context[key] = value
+
+def get_model_context(key):
+    return _native_model_context[key]
--
Gitee
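
Note on the is_prefill plumbing: the per-call is_prefill argument is replaced by
the small module-level context added to vllm_mindspore/model_executor/utils.py.
The sketch below only illustrates the intended calling pattern; the runner
function name is hypothetical, while get_model_context/set_model_context are the
helpers added by this patch.

    from vllm_mindspore.model_executor.utils import (get_model_context,
                                                     set_model_context)

    def run_step(model, model_inputs, is_prefill):
        # Set the flag once per step instead of threading is_prefill
        # through every construct() signature.
        set_model_context("is_prefill", is_prefill)
        return model(**model_inputs)

    # Inside a layer, the branch that used to take the argument now reads
    # the flag back, e.g.:
    #     if get_model_context("is_prefill"):
    #         ...  # prefill (flash-attention) path
    #     else:
    #         ...  # decode (paged-attention) path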
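
Note on the dual-graph dispatch: exec_model now caches two compiled callables,
fa_network for prefill and pa_network for decode, each built once with
ms.jit(function=..., jit_level="O0") as in the patch. The wrapper class below is
hypothetical and only illustrates that caching pattern in isolation.

    import mindspore as ms

    from vllm_mindspore.model_executor.utils import set_model_context

    class TwoGraphRunner:
        """Compile one graph per phase, lazily, and reuse it afterwards."""

        def __init__(self, model):
            self.model = model
            self.compiled = {}  # phase name -> jit-compiled callable

        def __call__(self, is_prefill, **model_inputs):
            phase = "prefill" if is_prefill else "decode"
            if phase not in self.compiled:
                # Set the flag before compiling so the traced graph takes
                # the matching prefill/decode branch.
                set_model_context("is_prefill", is_prefill)
                self.compiled[phase] = ms.jit(function=self.model,
                                              jit_level="O0")
            return self.compiled[phase](**model_inputs)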