diff --git a/vllm_mindspore/attention/layer.py b/vllm_mindspore/attention/layer.py
index 84335349b49f20eabedb8bb0ee90ef1025726e97..38ce35c64f96b51fb4f3a8a105e49ed2b67be4d9 100644
--- a/vllm_mindspore/attention/layer.py
+++ b/vllm_mindspore/attention/layer.py
@@ -175,7 +175,8 @@ class Attention(nn.Cell):
             block_tables: shape = [block_size, num_block]
         """
         output = query
-        key_cache, value_cache = kv_cache
+        key_cache = kv_cache[0]
+        value_cache = kv_cache[1]
         cache_out = self.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping)
         query = ops.depend(query, cache_out)
         if num_prefill_tokens > 0:
diff --git a/vllm_mindspore/model_executor/models/model_base.py b/vllm_mindspore/model_executor/models/model_base.py
index f1bb23615371f91672f403da251fea01515e6aae..1e81fcb484cd44b28c49c3c94f24ee10fa87ba91 100644
--- a/vllm_mindspore/model_executor/models/model_base.py
+++ b/vllm_mindspore/model_executor/models/model_base.py
@@ -165,10 +165,8 @@ class MsModelBase():
         num_layers = self.model_config.get_num_layers(self.parallel_config)
 
-        dyn_key_cache = mutable(Tensor(shape=kv_cache_shape, dtype=kv_cache_dtype))
-        dyn_value_cache = mutable(Tensor(shape=kv_cache_shape, dtype=kv_cache_dtype))
-        dyn_kv_cache = mutable((dyn_key_cache, dyn_value_cache))
-        dyn_kv_caches = mutable([dyn_kv_cache for _ in range(num_layers)])
+        kv_cache_shape = (2, None, block_size, num_kv_heads, head_size)
+        dyn_kv_caches = mutable([Tensor(shape=kv_cache_shape, dtype=kv_cache_dtype) for _ in range(num_layers)])
 
         dyn_num_prefill_tokens = mutable(1)
         dyn_num_decode_tokens = mutable(0)
 
diff --git a/vllm_mindspore/worker/model_runner.py b/vllm_mindspore/worker/model_runner.py
index 9a6df84634af2f36e62a3dd1eb3b4d80ecb12b50..a19bde81316d91f6b906e53f10ed70b4c782ed90 100644
--- a/vllm_mindspore/worker/model_runner.py
+++ b/vllm_mindspore/worker/model_runner.py
@@ -133,14 +133,8 @@ def profile_run(self) -> None:
     block_size = self.cache_config.block_size
     num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
     head_size = self.model_config.get_head_size()
-    kv_shape = [0, block_size, num_kv_heads, head_size]
-    kv_caches = mutable([
-        mutable((
-            mutable(torch.tensor([], dtype=kv_cache_dtype, device=self.device).reshape(kv_shape)),
-            mutable(torch.tensor([], dtype=kv_cache_dtype, device=self.device).reshape(kv_shape)),
-        ))
-        for _ in range(num_layers)
-    ])
+    kv_shape = [2, 0, block_size, num_kv_heads, head_size]
+    kv_caches = mutable([torch.tensor([], dtype=kv_cache_dtype, device=self.device).reshape(kv_shape) for _ in range(num_layers)])
     finished_requests_ids = [seq.request_id for seq in seqs]
     model_input = self.prepare_model_input(
         seqs, finished_requests_ids=finished_requests_ids)
@@ -166,4 +160,4 @@ def _get_supported_attention_backends(chunked_prefill_enabled: bool) \
     if chunked_prefill_enabled:
         return MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS
     else:
-        return MULTI_STEP_ATTENTION_BACKENDS
\ No newline at end of file
+        return MULTI_STEP_ATTENTION_BACKENDS
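
The layer.py hunk replaces tuple unpacking with explicit indexing because each layer's cache is now a single stacked tensor rather than a (key_cache, value_cache) tuple. A minimal PyTorch sketch of the fused layout; all dimension sizes below are illustrative assumptions, not values from the patch:

    import torch

    num_blocks, block_size, num_kv_heads, head_size = 8, 16, 4, 64

    # Old per-layer layout: a tuple of two separate tensors.
    key_cache = torch.zeros(num_blocks, block_size, num_kv_heads, head_size)
    value_cache = torch.zeros(num_blocks, block_size, num_kv_heads, head_size)

    # New per-layer layout: one tensor, key and value stacked on axis 0.
    kv_cache = torch.stack([key_cache, value_cache], dim=0)
    assert kv_cache.shape == (2, num_blocks, block_size, num_kv_heads, head_size)

    # Consumers index into axis 0 instead of tuple-unpacking,
    # mirroring what Attention.construct now does.
    assert torch.equal(kv_cache[0], key_cache)
    assert torch.equal(kv_cache[1], value_cache)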
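The model_base.py hunk declares the compile-time placeholder for that fused cache. A sketch, assuming MindSpore's dynamic-shape convention where a Tensor constructed with None in its shape marks that dimension as dynamic and mutable() keeps the Python list from being constant-folded during graph compilation; sizes and dtype here are illustrative:

    import mindspore as ms
    from mindspore import Tensor, mutable

    num_layers, block_size, num_kv_heads, head_size = 2, 16, 4, 64

    # Leading axis 2 = (key, value); None leaves num_blocks dynamic, so a
    # single compiled graph serves any number of allocated cache blocks.
    kv_cache_shape = (2, None, block_size, num_kv_heads, head_size)
    dyn_kv_caches = mutable([
        Tensor(shape=kv_cache_shape, dtype=ms.float16)
        for _ in range(num_layers)
    ])

Collapsing each layer's nested mutable tuple into one tensor also shrinks the mutable structure the graph compiler has to track per layer.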
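The model_runner.py hunk applies the same layout to the dummy caches used during memory profiling: with num_blocks fixed at 0, each tensor is empty but keeps the rank and dtype of the fused layout. A self-contained sketch, with the dtype chosen here only for illustration:

    import torch

    num_layers, block_size, num_kv_heads, head_size = 2, 16, 4, 64

    # num_blocks = 0: zero elements, but the rank matches the fused layout,
    # so profile_run exercises the real code path without allocating cache
    # memory.
    kv_shape = [2, 0, block_size, num_kv_heads, head_size]
    kv_caches = [
        torch.tensor([], dtype=torch.float16).reshape(kv_shape)
        for _ in range(num_layers)
    ]
    assert kv_caches[0].shape == (2, 0, block_size, num_kv_heads, head_size)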