diff --git a/vllm_mindspore/__init__.py b/vllm_mindspore/__init__.py
index 77fa9197e3c9f2033e85fd66b5d53579d756f6fb..5b6099445e73c2bd23fc4df74ed660442dfd4f87 100644
--- a/vllm_mindspore/__init__.py
+++ b/vllm_mindspore/__init__.py
@@ -116,14 +116,12 @@ vllm.model_executor.sampling_metadata.SamplingMetadataCache = SamplingMetadataCa
 vllm.model_executor.sampling_metadata.SamplingMetadata = SamplingMetadata

 from vllm_mindspore.worker.cache_engine import (
-    ms_allocate_kv_cache,
     ms_swap_in,
     ms_swap_out,
 )

 import vllm.worker.cache_engine

-vllm.worker.cache_engine.CacheEngine._allocate_kv_cache = ms_allocate_kv_cache
 vllm.worker.cache_engine.CacheEngine.swap_in = ms_swap_in
 vllm.worker.cache_engine.CacheEngine.swap_out = ms_swap_out

@@ -148,14 +146,12 @@ Worker.init_device = wrapper_worker_init_device(Worker.init_device)

 from vllm_mindspore.worker.model_runner import (
     _get_cuda_graph_pad_size,
-    _dummy_run,
     _get_supported_attention_backends,
 )

 vllm.worker.model_runner.ModelInputForGPUBuilder._get_cuda_graph_pad_size = (
     _get_cuda_graph_pad_size
 )
-vllm.worker.model_runner.GPUModelRunnerBase._dummy_run = _dummy_run

 import vllm.worker.multi_step_model_runner

diff --git a/vllm_mindspore/attention/layer.py b/vllm_mindspore/attention/layer.py
index 84335349b49f20eabedb8bb0ee90ef1025726e97..38ce35c64f96b51fb4f3a8a105e49ed2b67be4d9 100644
--- a/vllm_mindspore/attention/layer.py
+++ b/vllm_mindspore/attention/layer.py
@@ -175,7 +175,8 @@ class Attention(nn.Cell):
             block_tables: shape = [block_size, num_block]
         """
         output = query
-        key_cache, value_cache = kv_cache
+        key_cache = kv_cache[0]
+        value_cache = kv_cache[1]
         cache_out = self.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping)
         query = ops.depend(query, cache_out)
         if num_prefill_tokens > 0:
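A minimal sketch of the cache layout this hunk switches to, assuming each attention layer now receives a single MindSpore tensor whose leading axis packs the key and value planes; the sizes below are illustrative, not from the patch:

    import mindspore as ms
    from mindspore import ops

    num_blocks, block_size, num_kv_heads, head_size = 4, 16, 8, 128
    # One tensor per layer; axis 0 holds the key plane (index 0) and value plane (index 1).
    kv_cache = ops.zeros((2, num_blocks, block_size, num_kv_heads, head_size), ms.bfloat16)
    key_cache = kv_cache[0]    # [num_blocks, block_size, num_kv_heads, head_size]
    value_cache = kv_cache[1]

Indexing replaces the old (key, value) tuple unpacking, so reshape_and_cache and the attention kernels still see the same per-plane shapes as before.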
diff --git a/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py b/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py
index f640583ce5c18758348c01da7f7c04e8ede78c7c..f2ce3705668a1ee19b3c1bd69a56b1e3e4015f10 100644
--- a/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py
+++ b/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py
@@ -69,14 +69,15 @@ class Fake_Attention:
         )
         head_size = vllm_config.model_config.get_head_size()
         num_block = 0
-        self.kv_shape = [num_block, block_size, num_kv_heads, head_size]
+        self.kv_shape = [2, num_block, block_size, num_kv_heads, head_size]
         self.kv_cache = [
-            (
-                torch.zeros(self.kv_shape, dtype=ms.bfloat16, device="Ascend"),
-                torch.zeros(self.kv_shape, dtype=ms.bfloat16, device="Ascend"),
-            )
+            torch.zeros(self.kv_shape, dtype=ms.bfloat16, device="Ascend")
             for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
         ]
+        self.num_block = num_block
+        self.block_size = block_size
+        self.num_kv_heads = num_kv_heads
+        self.head_size = head_size
         self.attn_type = AttentionType.DECODER

@@ -84,8 +85,9 @@ class Fake_MLA(Fake_Attention):
     def __init__(self):
         super().__init__()
         vllm_config = get_current_vllm_config()
+        self.kv_shape = [1, self.num_block, self.block_size, self.num_kv_heads, self.head_size]
         self.kv_cache = [
-            (torch.zeros(self.kv_shape, dtype=ms.bfloat16, device="Ascend"),)
+            torch.zeros(self.kv_shape, dtype=ms.bfloat16, device="Ascend")
             for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
         ]

@@ -99,15 +101,13 @@ class Fake_Attention_V1(Attention):
         )
         head_size = vllm_config.model_config.get_head_size()
         num_block = 0
-        self.kv_shape = [num_block, block_size, num_kv_heads, head_size]
+        self.kv_shape = [2, num_block, block_size, num_kv_heads, head_size]
         self.kv_cache = [
-            (
-                torch.zeros(self.kv_shape, dtype=ms.bfloat16, device="Ascend"),
-                torch.zeros(self.kv_shape, dtype=ms.bfloat16, device="Ascend"),
-            )
+            torch.zeros(self.kv_shape, dtype=ms.bfloat16, device="Ascend")
             for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
         ]
         self.attn_type = AttentionType.DECODER
+        self.num_block = num_block
         self.num_kv_heads = num_kv_heads
         self.head_size = head_size
         self.dtype = vllm_config.model_config.dtype
@@ -119,8 +119,9 @@ class Fake_MLA_V1(Fake_Attention_V1):
     def __init__(self):
         super().__init__()
         vllm_config = get_current_vllm_config()
+        self.kv_shape = [1, self.num_block, self.block_size, self.num_kv_heads, self.head_size]
         self.kv_cache = [
-            (torch.zeros(self.kv_shape, dtype=ms.bfloat16, device="Ascend"),)
+            torch.zeros(self.kv_shape, dtype=ms.bfloat16, device="Ascend")
             for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
         ]

diff --git a/vllm_mindspore/model_executor/models/model_base.py b/vllm_mindspore/model_executor/models/model_base.py
index 7d3ed381aea1aad3914d2c9b5bbe13bb06dd0d67..056676196e04a31d3ecd0d53aef5b2a79f759f6a 100644
--- a/vllm_mindspore/model_executor/models/model_base.py
+++ b/vllm_mindspore/model_executor/models/model_base.py
@@ -167,10 +167,8 @@ class MsModelBase():
         num_layers = self.model_config.get_num_layers(self.parallel_config)

-        dyn_key_cache = mutable(Tensor(shape=kv_cache_shape, dtype=kv_cache_dtype))
-        dyn_value_cache = mutable(Tensor(shape=kv_cache_shape, dtype=kv_cache_dtype))
-        dyn_kv_cache = mutable((dyn_key_cache, dyn_value_cache))
-        dyn_kv_caches = mutable([dyn_kv_cache for _ in range(num_layers)])
+        kv_cache_shape = (2, None, block_size, num_kv_heads, head_size)
+        dyn_kv_caches = mutable([Tensor(shape=kv_cache_shape, dtype=kv_cache_dtype) for _ in range(num_layers)])

         dyn_num_prefill_tokens = mutable(1)
         dyn_num_decode_tokens = mutable(0)
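A minimal sketch of the dynamic-shape placeholders built here, assuming MindSpore's convention that None marks a dimension left dynamic for graph compilation; the dtype and sizes below are illustrative:

    import mindspore as ms
    from mindspore import Tensor, mutable

    num_layers, block_size, num_kv_heads, head_size = 2, 16, 8, 128
    # None leaves num_blocks dynamic; axis 0 packs key/value as in the fused layout.
    kv_cache_shape = (2, None, block_size, num_kv_heads, head_size)
    dyn_kv_caches = mutable([Tensor(shape=kv_cache_shape, dtype=ms.bfloat16)
                             for _ in range(num_layers)])

One placeholder tensor per layer replaces the previous per-layer (key, value) tuple, keeping the compiled graph's input structure aligned with the runtime caches.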
diff --git a/vllm_mindspore/worker/cache_engine.py b/vllm_mindspore/worker/cache_engine.py
index dfd0ef10e64630933b81764626aab2c986fbf2d5..d56c91c75122f106d2cae82ce5d8df503a2e9b5c 100644
--- a/vllm_mindspore/worker/cache_engine.py
+++ b/vllm_mindspore/worker/cache_engine.py
@@ -35,29 +35,6 @@ def create_block(shape, dtype, name=None, device=None):
     return blocks


-def ms_allocate_kv_cache(
-    self,
-    num_blocks: int,
-    device: str,
-) -> List[MsKVCache]:
-    """Allocates KV cache on the specified device."""
-    kv_cache_shape = self.attn_backend.get_kv_cache_shape(
-        num_blocks, self.block_size, self.num_kv_heads, self.head_size
-    )
-    kv_cache: List[MsKVCache] = []
-
-    self.dtype = get_valid_dtype(self.dtype)
-
-    for _ in range(self.num_attention_layers):
-        device_type = "CPU" if device == "cpu" else "Ascend"
-        current_cache = []
-        for i in range(kv_cache_shape[0]):
-            cache_blocks = create_block(
-                kv_cache_shape[1:], self.dtype, device=device_type
-            )
-            current_cache.append(mutable(cache_blocks))
-        kv_cache.append(mutable(tuple(current_cache)))
-    return mutable(kv_cache)


 def ms_swap_in(self, src_to_dst: ms.Tensor) -> None:
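With the fused layout, per-layer allocation no longer needs to build (key, value) tuples, which is why the ms_allocate_kv_cache override above is dropped in favor of the default allocation path. A rough sketch of an equivalent fused allocation under that assumption; allocate_fused_kv_cache is a hypothetical helper for illustration, not code from this patch:

    import mindspore as ms
    from mindspore import mutable, ops

    def allocate_fused_kv_cache(num_layers, num_blocks, block_size,
                                num_kv_heads, head_size, dtype=ms.bfloat16):
        # One tensor per layer; axis 0 packs the key and value planes.
        shape = (2, num_blocks, block_size, num_kv_heads, head_size)
        return mutable([ops.zeros(shape, dtype) for _ in range(num_layers)])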
diff --git a/vllm_mindspore/worker/model_runner.py b/vllm_mindspore/worker/model_runner.py
index 55bb26ec4ee65181cfc30425640149532c5b36bd..a4f118b93d16cc83f7438f7f9ff44d12974f278d 100644
--- a/vllm_mindspore/worker/model_runner.py
+++ b/vllm_mindspore/worker/model_runner.py
@@ -40,133 +40,6 @@ def _get_cuda_graph_pad_size(
     return -1


-def _dummy_run(self,
-               max_num_batched_tokens: int,
-               max_num_seqs: int = 1) -> None:
-    with self.set_in_profile_run():
-        # Enable top-k sampling to reflect the accurate memory usage.
-        sampling_params = \
-            SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
-
-        # This represents the maximum number of different requests
-        # that will have unique loras, an therefore the max amount of memory
-        # consumption create dummy lora request copies from the lora request
-        # passed in, which contains a lora from the lora warmup path.
-        dummy_lora_requests: List[LoRARequest] = []
-        dummy_lora_requests_per_seq: List[LoRARequest] = []
-        if self.lora_config:
-            assert self.lora_manager is not None
-            with self.lora_manager.dummy_lora_cache():
-                for idx in range(self.lora_config.max_loras):
-                    lora_id = idx + 1
-                    dummy_lora_request = LoRARequest(
-                        lora_name=f"warmup_{lora_id}",
-                        lora_int_id=lora_id,
-                        lora_path="/not/a/real/path",
-                    )
-                    self.lora_manager.add_dummy_lora(dummy_lora_request,
-                                                     rank=LORA_WARMUP_RANK)
-                    dummy_lora_requests.append(dummy_lora_request)
-                dummy_lora_requests_per_seq = [
-                    dummy_lora_requests[idx % len(dummy_lora_requests)]
-                    for idx in range(max_num_seqs)
-                ]
-
-        # Profile memory usage with max_num_sequences sequences and the
-        # total number of tokens equal to max_num_batched_tokens.
-        seqs: List[SequenceGroupMetadata] = []
-        # Additional GPU memory may be needed for multi-modal encoding,
-        # which needs to be accounted for when calculating the GPU blocks
-        # for vLLM blocker manager.
-        # To exercise the worst scenario for GPU memory consumption,
-        # the number of seqs (batch_size) is chosen to maximize the number
-        # of images processed.
-
-        max_mm_tokens = self.mm_registry.get_max_multimodal_tokens(
-            self.model_config)
-        if max_mm_tokens > 0:
-            max_num_seqs_orig = max_num_seqs
-            max_num_seqs = min(max_num_seqs,
-                               max_num_batched_tokens // max_mm_tokens)
-            if max_num_seqs < 1:
-                expr = (f"min({max_num_seqs_orig}, "
-                        f"{max_num_batched_tokens} // {max_mm_tokens})")
-                logger.warning(
-                    "Computed max_num_seqs (%s) to be less than 1. "
-                    "Setting it to the minimum value of 1.", expr)
-                max_num_seqs = 1
-
-        batch_size = 0
-        for group_id in range(max_num_seqs):
-            seq_len = (max_num_batched_tokens // max_num_seqs +
-                       (group_id < max_num_batched_tokens % max_num_seqs))
-            batch_size += seq_len
-
-            dummy_data = self.input_registry \
-                .dummy_data_for_profiling(self.model_config,
-                                          seq_len,
-                                          self.mm_registry)
-
-            seq = SequenceGroupMetadata(
-                request_id=str(group_id),
-                is_prompt=True,
-                seq_data={group_id: dummy_data.seq_data},
-                sampling_params=sampling_params,
-                block_tables=None,
-                lora_request=dummy_lora_requests_per_seq[group_id]
-                if dummy_lora_requests_per_seq else None,
-                multi_modal_data=dummy_data.multi_modal_data,
-                multi_modal_placeholders=dummy_data.
-                multi_modal_placeholders,
-            )
-            seqs.append(seq)
-
-        # Run the model with the dummy inputs.
-        num_layers = self.model_config.get_num_layers(self.parallel_config)
-        # use an empty tensor instead of `None`` to force Dynamo to pass
-        # it by reference, rather by specializing on the value ``None``.
-        # the `dtype` argument does not matter, and we use `float32` as
-        # a placeholder (it has wide hardware support).
-        # it is important to create tensors inside the loop, rather than
-        # multiplying the list, to avoid Dynamo from treating them as
-        # tensor aliasing.
-        kv_cache_dtype = self.model_config.dtype if self.cache_config.cache_dtype == "auto" \
-            else self.cache_config.cache_dtype
-        if kv_cache_dtype in STR_DTYPE_TO_TENSOR_DTYPE:
-            kv_cache_dtype = STR_DTYPE_TO_TENSOR_DTYPE[kv_cache_dtype]
-        block_size = self.cache_config.block_size
-        num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
-        head_size = self.model_config.get_head_size()
-        kv_shape = [0, block_size, num_kv_heads, head_size]
-        kv_caches = mutable([
-            mutable((
-                mutable(torch.tensor([], dtype=kv_cache_dtype, device=self.device).reshape(kv_shape)),
-                mutable(torch.tensor([], dtype=kv_cache_dtype, device=self.device).reshape(kv_shape)),
-            ))
-            for _ in range(num_layers)
-        ])
-        finished_requests_ids = [seq.request_id for seq in seqs]
-        model_input = self.prepare_model_input(
-            seqs, finished_requests_ids=finished_requests_ids)
-        intermediate_tensors = None
-        if not get_pp_group().is_first_rank:
-            intermediate_tensors = \
-                self.model.make_empty_intermediate_tensors(
-                    batch_size=batch_size,
-                    dtype=self.model_config.dtype,
-                    device=self.device)
-
-        # Disable KV Scale Calculation for dummy data during profile run
-        if model_input.attn_metadata is not None:
-            model_input.attn_metadata.enable_kv_scales_calculation = False
-
-        self.execute_model(model_input, kv_caches, intermediate_tensors)
-        torch.cuda.synchronize()
-        if self.lora_config:
-            # Remove dummy loras.
-            assert self.lora_manager is not None
-            self.remove_all_loras()
-        return


 MULTI_STEP_ATTENTION_BACKENDS = [