From 444419c6d0b6a3666347f2a430b08d9cfa786197 Mon Sep 17 00:00:00 2001
From: fengtingyan
Date: Fri, 28 Mar 2025 01:43:23 +0000
Subject: [PATCH] kvcache: merge the two KV cache tensors into one
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 vllm_mindspore/attention/layer.py                  |  3 ++-
 vllm_mindspore/model_executor/models/model_base.py |  6 ++----
 vllm_mindspore/worker/model_runner.py              | 12 +++---------
 3 files changed, 7 insertions(+), 14 deletions(-)

diff --git a/vllm_mindspore/attention/layer.py b/vllm_mindspore/attention/layer.py
index 84335349..38ce35c6 100644
--- a/vllm_mindspore/attention/layer.py
+++ b/vllm_mindspore/attention/layer.py
@@ -175,7 +175,8 @@ class Attention(nn.Cell):
             block_tables: shape = [block_size, num_block]
         """
         output = query
-        key_cache, value_cache = kv_cache
+        key_cache = kv_cache[0]
+        value_cache = kv_cache[1]
         cache_out = self.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping)
         query = ops.depend(query, cache_out)
         if num_prefill_tokens > 0:
diff --git a/vllm_mindspore/model_executor/models/model_base.py b/vllm_mindspore/model_executor/models/model_base.py
index f1bb2361..1e81fcb4 100644
--- a/vllm_mindspore/model_executor/models/model_base.py
+++ b/vllm_mindspore/model_executor/models/model_base.py
@@ -165,10 +165,8 @@ class MsModelBase():
 
         num_layers = self.model_config.get_num_layers(self.parallel_config)
 
-        dyn_key_cache = mutable(Tensor(shape=kv_cache_shape, dtype=kv_cache_dtype))
-        dyn_value_cache = mutable(Tensor(shape=kv_cache_shape, dtype=kv_cache_dtype))
-        dyn_kv_cache = mutable((dyn_key_cache, dyn_value_cache))
-        dyn_kv_caches = mutable([dyn_kv_cache for _ in range(num_layers)])
+        kv_cache_shape = (2, None, block_size, num_kv_heads, head_size)
+        dyn_kv_caches = mutable([Tensor(shape=kv_cache_shape, dtype=kv_cache_dtype) for _ in range(num_layers)])
 
         dyn_num_prefill_tokens = mutable(1)
         dyn_num_decode_tokens = mutable(0)
diff --git a/vllm_mindspore/worker/model_runner.py b/vllm_mindspore/worker/model_runner.py
index 9a6df846..a19bde81 100644
--- a/vllm_mindspore/worker/model_runner.py
+++ b/vllm_mindspore/worker/model_runner.py
@@ -133,14 +133,8 @@ def profile_run(self) -> None:
     block_size = self.cache_config.block_size
     num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
     head_size = self.model_config.get_head_size()
-    kv_shape = [0, block_size, num_kv_heads, head_size]
-    kv_caches = mutable([
-        mutable((
-            mutable(torch.tensor([], dtype=kv_cache_dtype, device=self.device).reshape(kv_shape)),
-            mutable(torch.tensor([], dtype=kv_cache_dtype, device=self.device).reshape(kv_shape)),
-        ))
-        for _ in range(num_layers)
-    ])
+    kv_shape = [2, 0, block_size, num_kv_heads, head_size]
+    kv_caches = mutable([torch.tensor([], dtype=kv_cache_dtype, device=self.device).reshape(kv_shape) for _ in range(num_layers)])
     finished_requests_ids = [seq.request_id for seq in seqs]
     model_input = self.prepare_model_input(
         seqs, finished_requests_ids=finished_requests_ids)
@@ -166,4 +160,4 @@ def _get_supported_attention_backends(chunked_prefill_enabled: bool) \
     if chunked_prefill_enabled:
         return MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS
     else:
-        return MULTI_STEP_ATTENTION_BACKENDS
\ No newline at end of file
+        return MULTI_STEP_ATTENTION_BACKENDS
--
Gitee
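
Note for reviewers: below is a minimal sketch of the layout change this patch
makes, written in plain PyTorch with made-up sizes (num_blocks, block_size,
num_kv_heads, head_size are placeholder values, not values from this repo).
It only illustrates the tensor layout, not the MindSpore dynamic-shape setup.

import torch

# Placeholder sizes for illustration only.
num_blocks, block_size, num_kv_heads, head_size = 4, 16, 8, 64
dtype = torch.float16

# Old layout: a separate (key_cache, value_cache) tuple per layer.
key_cache = torch.zeros(num_blocks, block_size, num_kv_heads, head_size, dtype=dtype)
value_cache = torch.zeros_like(key_cache)

# New layout: one tensor per layer whose leading dimension of size 2 stacks
# the key plane (index 0) and the value plane (index 1), matching the
# kv_cache[0] / kv_cache[1] reads added in attention/layer.py above.
kv_cache = torch.stack((key_cache, value_cache))
assert kv_cache.shape == (2, num_blocks, block_size, num_kv_heads, head_size)
assert torch.equal(kv_cache[0], key_cache)
assert torch.equal(kv_cache[1], value_cache)

Folding the pair into one tensor means each layer's cache is a single
dynamic-shape Tensor instead of a nested mutable tuple, which is what lets
model_base.py declare the dynamic shape once as
(2, None, block_size, num_kv_heads, head_size) and lets profile_run build the
dummy caches with the matching [2, 0, block_size, num_kv_heads, head_size]
shape.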