From 45ef5cdff9204ae43b823c999cf7ce9afea83eb4 Mon Sep 17 00:00:00 2001
From: lijiakun
Date: Fri, 11 Apr 2025 16:32:23 +0800
Subject: [PATCH] test

---
 tests/st/python/test_vllm_qwen_7b.py          | 74 +++++++++++++++++++
 vllm_mindspore/attention/backends/ms_attn.py  | 15 +++-
 vllm_mindspore/attention/layer.py             | 50 ++++++++-----
 .../models/mf_models/mf_model_base.py         | 13 ----
 .../model_executor/models/model_base.py       | 25 +++++--
 vllm_mindspore/model_executor/models/qwen2.py | 67 ++++++++++-------
 6 files changed, 177 insertions(+), 67 deletions(-)
 create mode 100644 tests/st/python/test_vllm_qwen_7b.py

diff --git a/tests/st/python/test_vllm_qwen_7b.py b/tests/st/python/test_vllm_qwen_7b.py
new file mode 100644
index 00000000..bce75d3e
--- /dev/null
+++ b/tests/st/python/test_vllm_qwen_7b.py
@@ -0,0 +1,74 @@
+# Copyright 2024 The vLLM team.
+# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""test vllm qwen."""
+import pytest
+import os
+from . import set_env
+env_manager = set_env.EnvVarManager()
+# define environment variables
+env_vars = {
+    "ASCEND_CUSTOM_PATH": os.path.expandvars("$ASCEND_HOME_PATH/../"),
+    "MS_ENABLE_LCCL": "off",
+    "HCCL_OP_EXPANSION_MODE": "AIV",
+    "ASCEND_RT_VISIBLE_DEVICES": "0,1",
+    "MS_ALLOC_CONF": "enable_vmm:True",
+    "LCCL_DETERMINISTIC": "1",
+    "HCCL_DETERMINISTIC": "true",
+    "ATB_MATMUL_SHUFFLE_K_ENABLE": "0",
+    "ATB_LLM_LCOC_ENABLE": "0"
+}
+# set env
+env_manager.setup_ai_environment(env_vars)
+import vllm_mindspore
+from vllm import LLM, SamplingParams
+
+
+class TestQwen:
+    """
+    Test Qwen.
+    """
+    @pytest.mark.level0
+    @pytest.mark.platform_arm_ascend910b_training
+    @pytest.mark.env_single
+    def test_vllm_qwen(self):
+        """
+        test case qwen2.5 7B
+        """
+
+        # Sample prompts.
+        prompts = [
+            "You are a helpful assistant.<|User|>将文本分类为中性、负面或正面。 \n文本:我认为这次假期还可以。 \n情感:<|Assistant|>\n",
+        ]
+
+        # Create a sampling params object.
+        sampling_params = SamplingParams(temperature=0.0, max_tokens=10, top_k=1)
+
+        # Create an LLM.
+        llm = LLM(model="/home/workspace/mindspore_dataset/weight/Qwen2.5-7B-Instruct",
+                  gpu_memory_utilization=0.9, tensor_parallel_size=2)
+        # Generate texts from the prompts. The output is a list of RequestOutput objects
+        # that contain the prompt, generated text, and other information.
+        outputs = llm.generate(prompts, sampling_params)
+        expect_list = ['中性<|Assistant|> 这句话']
+        # Print the outputs.
+        for i, output in enumerate(outputs):
+            prompt = output.prompt
+            generated_text = output.outputs[0].text
+            print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+            assert generated_text == expect_list[i]
+
+        # unset env
+        env_manager.unset_all()
diff --git a/vllm_mindspore/attention/backends/ms_attn.py b/vllm_mindspore/attention/backends/ms_attn.py
index 49932122..13e73dd0 100644
--- a/vllm_mindspore/attention/backends/ms_attn.py
+++ b/vllm_mindspore/attention/backends/ms_attn.py
@@ -23,6 +23,8 @@ from itertools import accumulate
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
 import os
 
+import numpy as np
+
 import torch
 
 from vllm.attention.backends.abstract import (
@@ -55,6 +57,7 @@ import mindspore as ms
 from mindspore import mutable
 from mindspore._c_expression import swap_cache
 
+
 def advance_step_op(sampled_token_ids,
                     model_input,
                     seq_lens_tensor,
@@ -391,13 +394,17 @@ class MSAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
             raise AttributeError(f"Invalid attention type {str(attn_type)}")
 
     def keys(self):
-        return ["num_prefill_tokens", "num_decode_tokens", "slot_mapping", "batch_valid_length", "context_lens", "block_tables"]
+        return ["num_prefill_tokens", "num_decode_tokens", "slot_mapping", "batch_valid_length", "q_seq_lens", "block_tables"]
 
     def __getitem__(self, key):
-        if key == "context_lens":
-            key = "seq_lens_tensor"
+        if key == "q_seq_lens":
+            query_lens = getattr(self, "query_lens")
+            return ms.Tensor.from_numpy(np.array(query_lens, dtype=np.int32))
         if key == "batch_valid_length":
-            return mutable(getattr(self, "seq_lens"), dynamic_len=True)
+            a = getattr(self, "seq_lens_tensor")
+            a.asnumpy()
+            return a
+            # return mutable(getattr(self, "seq_lens_tensor"))
         if key == "block_tables":
             if getattr(self, key).ndim == 1:
                 return mutable(getattr(self, key).expand_dims(0))
diff --git a/vllm_mindspore/attention/layer.py b/vllm_mindspore/attention/layer.py
index 84335349..bb24b9cb 100644
--- a/vllm_mindspore/attention/layer.py
+++ b/vllm_mindspore/attention/layer.py
@@ -153,15 +153,16 @@ class Attention(nn.Cell):
         query: Tensor,
         key: Tensor,
         value: Tensor,
-        kv_cache: Tuple[Tensor, Tensor],
+        key_cache, value_cache: Tuple[Tensor, Tensor],
         # attn_metadata: MSMetadata,
-        num_prefill_tokens: int,
+        num_prefill_tokens: bool,
         num_decode_tokens: int,
         slot_mapping: Tensor,
         batch_valid_length: Tuple[int],
-        context_lens: Tensor,
+        q_seq_lens: Tensor,
         block_tables: Tensor,
         attn_mask: Tensor,
+        decode_mask: Tensor,
     ) -> Tensor:
         """Attention foward, support MHA and GQA.
 
@@ -175,13 +176,14 @@ class Attention(nn.Cell): block_tables: shape = [block_size, num_block] """ output = query - key_cache, value_cache = kv_cache + # key_cache, value_cache = kv_cache cache_out = self.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping) query = ops.depend(query, cache_out) if num_prefill_tokens > 0: output = self._run_prefill_forward(query, key, value, attn_mask, batch_valid_length, batch_valid_length) if num_decode_tokens > 0: - output = self._run_decode_forward(query, key_cache, value_cache, block_tables, context_lens) + output = self._run_decode_forward(query, key_cache, value_cache, block_tables,batch_valid_length, + decode_mask, q_seq_lens) return output def _run_prefill_forward( @@ -206,16 +208,18 @@ class Attention(nn.Cell): query = query.view(-1, self.hidden_size_per_partition) key = key.view(-1, self.kv_hidden_size_per_partition) value = value.view(-1, self.kv_hidden_size_per_partition) - _, _, _, output = self.flash_attention(query, - key, - value, - None, - None, - None, - attn_mask, - None, - actual_seq_qlen, - actual_seq_kvlen) + _, _, _, output = self.flash_attention( + query, + key, + value, + None, + None, + None, + attn_mask, + None, + actual_seq_qlen, + actual_seq_kvlen + ) output = output.view(1, -1, self.hidden_size_per_partition) return output @@ -225,7 +229,9 @@ class Attention(nn.Cell): key_cache: Tensor, value_cache: Tensor, block_tables: Tensor, - context_lens: Tensor, + batch_valid_length: Tensor, + decode_mask:Tensor, + q_seq_lens: Tensor, ) -> Tensor: """Decode with PagedAttention. @@ -236,5 +242,15 @@ class Attention(nn.Cell): block_tables: shape = [block_size, num_block] context_lens: shape = [batch_size, ] """ - output = self.paged_attention(query, key_cache, value_cache, block_tables, context_lens) + output = self.paged_attention( + query, + key_cache, + value_cache, + block_tables, + batch_valid_length, + None, + None, + decode_mask, + q_seq_lens + ) return output diff --git a/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py b/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py index 7ac62f49..62a03cfa 100644 --- a/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py +++ b/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py @@ -27,7 +27,6 @@ from vllm.config import VllmConfig from vllm.config import get_current_vllm_config from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.forward_context import ForwardContext, get_forward_context from vllm.sequence import IntermediateTensors from vllm.distributed import get_tensor_model_parallel_world_size from vllm.attention.backends.abstract import AttentionType @@ -123,18 +122,6 @@ class MfModelBase(MsModelBase): raise NotImplementedError("Function _create_network should be Implemented!") - def get_kvcache(self): - key_cache = [] - value_cache = [] - forward_context = get_forward_context() - for i in range(self.mf_model_config.num_layers): - k_cache = self.kv_caches[i].kv_cache[forward_context.virtual_engine][0] - v_cache = self.kv_caches[i].kv_cache[forward_context.virtual_engine][1] - key_cache.append(k_cache) - value_cache.append(v_cache) - return mutable(key_cache), mutable(value_cache) - - def prepare_inputs(self, input_ids, positions, attn_metadata): key_cache, value_cache = self.get_kvcache() seq_lens = attn_metadata.seq_lens diff --git a/vllm_mindspore/model_executor/models/model_base.py b/vllm_mindspore/model_executor/models/model_base.py 
index d6355e42..6bd16792 100644 --- a/vllm_mindspore/model_executor/models/model_base.py +++ b/vllm_mindspore/model_executor/models/model_base.py @@ -25,6 +25,7 @@ from vllm.config import VllmConfig from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors +from vllm.forward_context import get_forward_context from mindspore import Tensor, nn, mutable from mindspore import dtype as mstype @@ -172,13 +173,13 @@ class MsModelBase(): dyn_key_cache = mutable(Tensor(shape=kv_cache_shape, dtype=kv_cache_dtype)) dyn_value_cache = mutable(Tensor(shape=kv_cache_shape, dtype=kv_cache_dtype)) - dyn_kv_cache = mutable((dyn_key_cache, dyn_value_cache)) - dyn_kv_caches = mutable([dyn_kv_cache for _ in range(num_layers)]) + dyn_key_caches = mutable([dyn_key_cache for _ in range(num_layers)]) + dyn_value_caches = mutable([dyn_value_cache for _ in range(num_layers)]) dyn_num_prefill_tokens = mutable(1) dyn_num_decode_tokens = mutable(0) - dyn_context_lens = Tensor(shape=[None, ], dtype=mstype.int32) - dyn_batch_valid_length = mutable([0, 0, 0], dynamic_len=True) + dyn_batch_valid_length = Tensor(shape=[None, ], dtype=mstype.int32) + dyn_q_seq_lens = Tensor(shape=[None, ], dtype=mstype.int32) dyn_slot_mapping = Tensor(shape=[None, ], dtype=mstype.int32) dyn_block_tables = Tensor(shape=[None, None], dtype=mstype.int32) dyn_intermediate_tensors = None @@ -187,17 +188,29 @@ class MsModelBase(): self.model.set_inputs( dyn_input_ids, dyn_position_ids, - dyn_kv_caches, + dyn_key_caches, + dyn_value_caches, dyn_num_prefill_tokens, dyn_num_decode_tokens, - dyn_context_lens, dyn_batch_valid_length, + dyn_q_seq_lens, dyn_slot_mapping, dyn_block_tables, dyn_intermediate_tensors, dyn_inputs_embeds ) + def get_kvcache(self): + key_cache = [] + value_cache = [] + forward_context = get_forward_context() + for i in range(self.config.num_hidden_layers): + k_cache = self.kv_caches[i].kv_cache[forward_context.virtual_engine][0] + v_cache = self.kv_caches[i].kv_cache[forward_context.virtual_engine][1] + key_cache.append(k_cache) + value_cache.append(v_cache) + return mutable(key_cache), mutable(value_cache) + @abstractmethod def compute_logits( self, diff --git a/vllm_mindspore/model_executor/models/qwen2.py b/vllm_mindspore/model_executor/models/qwen2.py index 2c3c81d4..1464a8b9 100644 --- a/vllm_mindspore/model_executor/models/qwen2.py +++ b/vllm_mindspore/model_executor/models/qwen2.py @@ -15,6 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================ +from vllm.config import get_current_vllm_config from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union, Iterable if TYPE_CHECKING: @@ -33,8 +34,6 @@ from vllm_mindspore.model_executor.layers.linear import ( MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm_mindspore.model_executor.layers.logits_processor import \ LogitsProcessor -from vllm.model_executor.layers.quantization import \ - QuantizationConfig from vllm_mindspore.model_executor.layers.rotary_embedding import get_rope from vllm_mindspore.model_executor.layers.sampler import (SamplerOutput, get_sampler) @@ -47,9 +46,12 @@ from vllm_mindspore.model_executor.models.utils import ( maybe_prefix) from vllm_mindspore.model_executor.sampling_metadata import SamplingMetadata from vllm_mindspore.model_executor.models.model_base import MsModelBase +from vllm_mindspore.model_executor.models.mf_models.mf_model_base import Fake_Attention from vllm.config import CacheConfig, VllmConfig +from vllm.model_executor.layers.quantization import \ + QuantizationConfig from vllm.sequence import IntermediateTensors from vllm.attention.backends.abstract import AttentionType from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size @@ -170,26 +172,27 @@ class Qwen2Attention(nn.Cell): attn_type=attn_type ) self.attn_mask = mint.triu(mint.ones(size=(128, 128), dtype=mstype.bfloat16), 1) + self.hard_mask = Tensor([0], dtype=mstype.bfloat16).reshape(1, 1) @jit def construct( self, positions: Tensor, hidden_states: Tensor, - kv_cache: Tuple[Tensor, Tensor], + key_cache, value_cache: Tuple[Tensor, Tensor], # attn_metadata: AttentionMetadata, num_prefill_tokens: int, num_decode_tokens: int, slot_mapping: Tensor, batch_valid_length: Tuple[int], - context_lens: Tensor, + q_seq_lens: Tensor, block_tables: Tensor, ) -> Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = mint.split(qkv, (self.q_size, self.kv_size, self.kv_size), -1) - q, k = self.rotary_emb(positions, q, k, context_lens, num_prefill_tokens) - attn_output = self.attn(q, k, v, kv_cache, num_prefill_tokens, num_decode_tokens, - slot_mapping, batch_valid_length, context_lens, block_tables, self.attn_mask) + q, k = self.rotary_emb(positions, q, k, q_seq_lens, num_prefill_tokens) + attn_output = self.attn(q, k, v, key_cache, value_cache, num_prefill_tokens, num_decode_tokens, + slot_mapping, batch_valid_length, q_seq_lens, block_tables, self.attn_mask, self.hard_mask) output, _ = self.o_proj(attn_output) return output @@ -249,13 +252,13 @@ class Qwen2DecoderLayer(nn.Cell): self, positions: Tensor, hidden_states: Tensor, - kv_cache: Tuple[Tensor, Tensor], + key_cache, value_cache: Tuple[Tensor, Tensor], # attn_metadata: AttentionMetadata, num_prefill_tokens: int, num_decode_tokens: int, slot_mapping: Tensor, batch_valid_length: Tuple[int], - context_lens: Tensor, + q_seq_lens: Tensor, block_tables: Tensor, residual: Optional[Tensor], ) -> Tuple[Tensor, Tensor]: @@ -268,12 +271,12 @@ class Qwen2DecoderLayer(nn.Cell): hidden_states = self.self_attn( positions, hidden_states, - kv_cache, + key_cache, value_cache, num_prefill_tokens, num_decode_tokens, slot_mapping, batch_valid_length, - context_lens, + q_seq_lens, block_tables ) @@ -335,13 +338,13 @@ class Qwen2Model(nn.Cell): self, input_ids: Optional[Tensor], positions: Tensor, - kv_caches: List[Tuple[Tensor, Tensor]], + key_cache, value_cache, # attn_metadata: AttentionMetadata, num_prefill_tokens: int, 
num_decode_tokens: int, slot_mapping: Tensor, - batch_valid_length: Tuple[int], - context_lens: Tensor, + batch_valid_length: Tensor, + q_seq_lens: Tensor, block_tables: Tensor, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[Tensor] = None, @@ -361,12 +364,13 @@ class Qwen2Model(nn.Cell): hidden_states, residual = layer( positions, hidden_states, - kv_caches[i - self.start_layer], + key_cache[i - self.start_layer], + value_cache[i - self.start_layer], num_prefill_tokens, num_decode_tokens, slot_mapping, batch_valid_length, - context_lens, + q_seq_lens, block_tables, residual ) @@ -398,16 +402,16 @@ class Qwen2Model(nn.Cell): # the checkpoint. Skip them. continue if (self.quant_config is not None and - (scale_name := self.quant_config.get_cache_scale(name))): - # Loading kv cache quantization scales - param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else - loaded_weight[0]) - weight_loader(param, loaded_weight) - loaded_params.add(scale_name) - continue + (scale_name := self.quant_config.get_cache_scale(name))): + # Loading kv cache quantization scales + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else + loaded_weight[0]) + weight_loader(param, loaded_weight) + loaded_params.add(scale_name) + continue for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue @@ -484,6 +488,13 @@ class Qwen2ForCausalLM(MsModelBase): self.set_modules({"model": self.model, "lm_head": self.lm_head}) self.set_model_inputs() + self.kv_caches = [Fake_Attention() for i in range(config.num_hidden_layers)] + compilation_config = vllm_config.compilation_config + + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + for i in range(config.num_hidden_layers): + compilation_config.static_forward_context[str(i)] = self.kv_caches[i] def get_input_embeddings(self, input_ids: Tensor) -> Tensor: return self.model.get_input_embeddings(input_ids) @@ -498,13 +509,15 @@ class Qwen2ForCausalLM(MsModelBase): inputs_embeds: Tensor = None, **kwargs ) -> Union[Tensor, IntermediateTensors]: + key_cache, value_cache = self.get_kvcache() if attn_metadata.num_prefill_tokens > 0: input_ids = input_ids.expand_dims(0) if attn_metadata.num_decode_tokens > 0: input_ids = input_ids.expand_dims(1) model_output = self.model(input_ids, positions, - kv_caches, + key_cache, + value_cache, **dict(attn_metadata), intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds) -- Gitee