From 38dadba8236afa448d9daa2d9ea74a27e00a08d5 Mon Sep 17 00:00:00 2001 From: tronzhang Date: Sat, 29 Mar 2025 20:21:39 +0800 Subject: [PATCH 1/5] adapte vLLM-MindSpore with vLLM-Ascend --- vllm_mindspore/__init__.py | 139 +++++++++++++------ vllm_mindspore/attention/backends/ms_attn.py | 25 ++-- vllm_mindspore/platforms/ascend.py | 19 ++- vllm_mindspore/worker/model_runner.py | 101 +++++++++++++- 4 files changed, 219 insertions(+), 65 deletions(-) diff --git a/vllm_mindspore/__init__.py b/vllm_mindspore/__init__.py index 963316a7..91aa9483 100644 --- a/vllm_mindspore/__init__.py +++ b/vllm_mindspore/__init__.py @@ -17,13 +17,7 @@ # ============================================================================ import sys -import warnings -if "vllm" in sys.modules: - # Check models variable in sub process, cannot raise here. - warnings.warn( - "vllm import before vllm_mindspore, vllm_mindspore cannot worker right!" - ) # 1. set env before import mindspore. from vllm_mindspore.scripts import env_setup @@ -32,32 +26,33 @@ env_setup() # 2. update the log configuration ahead of other modifications. import vllm_mindspore.logger -from vllm_mindspore.platforms.ascend import AscendPlatform +# ================ For vllm ================ +# from vllm_mindspore.platforms.ascend import AscendPlatform -ascend_platform = AscendPlatform() +# ascend_platform = AscendPlatform() -import vllm.config +# import vllm.config -vllm.config.current_platform = ascend_platform +# vllm.config.current_platform = ascend_platform -import vllm.platforms +# import vllm.platforms -vllm.platforms.current_platform = ascend_platform +# vllm.platforms.current_platform = ascend_platform import vllm.utils -vllm.utils.current_platform = ascend_platform +# vllm.utils.current_platform = ascend_platform -import vllm.attention.selector -vllm.attention.selector.current_platform = ascend_platform +# import vllm.attention.selector +# vllm.attention.selector.current_platform = ascend_platform -import vllm.engine.arg_utils -from vllm_mindspore.engine.arg_utils import _is_v1_supported_oracle -vllm.engine.arg_utils.EngineArgs._is_v1_supported_oracle = _is_v1_supported_oracle +# import vllm.engine.arg_utils +# from vllm_mindspore.engine.arg_utils import _is_v1_supported_oracle +# vllm.engine.arg_utils.EngineArgs._is_v1_supported_oracle = _is_v1_supported_oracle -import vllm.v1.engine.core -from vllm_mindspore.v1.engine.core import shutdown -vllm.v1.engine.core.DPEngineCoreProc.shutdown = shutdown +# import vllm.v1.engine.core +# from vllm_mindspore.v1.engine.core import shutdown +# vllm.v1.engine.core.DPEngineCoreProc.shutdown = shutdown from vllm_mindspore.utils import ( make_tensor_with_pad, @@ -102,17 +97,17 @@ from vllm_mindspore.model_executor.sampling_metadata import SamplingTensors vllm.model_executor.sampling_metadata.async_tensor_h2d = async_tensor_h2d vllm.model_executor.sampling_metadata.SamplingTensors.from_lists = SamplingTensors.from_lists -from vllm_mindspore.worker.cache_engine import ( - ms_allocate_kv_cache, - ms_swap_in, - ms_swap_out, -) +# from vllm_mindspore.worker.cache_engine import ( +# ms_allocate_kv_cache, +# ms_swap_in, +# ms_swap_out, +# ) -import vllm.worker.cache_engine +# import vllm.worker.cache_engine -vllm.worker.cache_engine.CacheEngine._allocate_kv_cache = ms_allocate_kv_cache -vllm.worker.cache_engine.CacheEngine.swap_in = ms_swap_in -vllm.worker.cache_engine.CacheEngine.swap_out = ms_swap_out +# vllm.worker.cache_engine.CacheEngine._allocate_kv_cache = ms_allocate_kv_cache +# 
vllm.worker.cache_engine.CacheEngine.swap_in = ms_swap_in +# vllm.worker.cache_engine.CacheEngine.swap_out = ms_swap_out from vllm_mindspore.model_executor.model_loader.weight_utils import ( safetensors_weights_iterator, @@ -122,16 +117,16 @@ vllm.model_executor.model_loader.loader.safetensors_weights_iterator = ( safetensors_weights_iterator ) -from vllm_mindspore.worker.worker import _warm_up_model -from vllm_mindspore.worker.profile import ( - wrapper_worker_init, - wrapper_worker_init_device, -) -from vllm.worker.worker import Worker +# from vllm_mindspore.worker.worker import _warm_up_model +# from vllm_mindspore.worker.profile import ( +# wrapper_worker_init, +# wrapper_worker_init_device, +# ) +# from vllm.worker.worker import Worker -Worker._warm_up_model = _warm_up_model -Worker.__init__ = wrapper_worker_init(Worker.__init__) -Worker.init_device = wrapper_worker_init_device(Worker.init_device) +# Worker._warm_up_model = _warm_up_model +# Worker.__init__ = wrapper_worker_init(Worker.__init__) +# Worker.init_device = wrapper_worker_init_device(Worker.init_device) from vllm_mindspore.worker.model_runner import ( _get_cuda_graph_pad_size, @@ -139,6 +134,7 @@ from vllm_mindspore.worker.model_runner import ( _get_supported_attention_backends, ) +import vllm.worker.model_runner vllm.worker.model_runner.ModelInputForGPUBuilder._get_cuda_graph_pad_size = ( _get_cuda_graph_pad_size ) @@ -273,11 +269,11 @@ vllm.v1.worker.gpu_model_runner.InputBatch._make_sampling_metadata = _make_sampl vllm.v1.worker.gpu_input_batch.InputBatch._make_prompt_token_ids_tensor = _make_prompt_token_ids_tensor vllm.v1.worker.gpu_model_runner.InputBatch._make_prompt_token_ids_tensor = _make_prompt_token_ids_tensor -from vllm.v1.worker.gpu_worker import Worker -from vllm_mindspore.v1.worker.gpu_worker import init_device +# from vllm.v1.worker.gpu_worker import Worker +# from vllm_mindspore.v1.worker.gpu_worker import init_device -Worker.__init__ = wrapper_worker_init(Worker.__init__) -Worker.init_device = wrapper_worker_init_device(init_device) +# Worker.__init__ = wrapper_worker_init(Worker.__init__) +# Worker.init_device = wrapper_worker_init_device(init_device) import vllm.v1.utils @@ -311,9 +307,62 @@ from vllm_mindspore.distributed.shm_broadcast import initialize_ShmRingBuffer from vllm.distributed.device_communicators.shm_broadcast import ShmRingBuffer ShmRingBuffer.__init__ = initialize_ShmRingBuffer +# from vllm_mindspore.v1.worker.gpu_worker import compile_or_warm_up_model +# from vllm.v1.worker.gpu_worker import Worker +# Worker.compile_or_warm_up_model = compile_or_warm_up_model + +from vllm_mindspore.v1.core.sched.scheduler import schedule +from vllm.v1.core.sched.scheduler import Scheduler +Scheduler.schedule = schedule + +# ================ For vllm-ascend ================ + +from vllm_mindspore.platforms.ascend import get_attn_backend_cls + +import vllm.platforms + +vllm.platforms.current_platform.get_attn_backend_cls = get_attn_backend_cls + +from vllm_mindspore.worker.cache_engine import ( + ms_allocate_kv_cache, + ms_swap_in, + ms_swap_out, +) + +from vllm_ascend.worker.worker import CacheEngine + +CacheEngine._allocate_kv_cache = ms_allocate_kv_cache +CacheEngine.swap_in = ms_swap_in +CacheEngine.swap_out = ms_swap_out + +from vllm_mindspore.worker.worker import _warm_up_model + +from vllm_mindspore.worker.profile import ( + wrapper_worker_init, + wrapper_worker_init_device, +) + +from vllm_ascend.worker.worker import NPUWorker + +NPUWorker._warm_up_model = _warm_up_model +NPUWorker.__init__ = 
wrapper_worker_init(NPUWorker.__init__) +NPUWorker.init_device = wrapper_worker_init_device(NPUWorker.init_device) + +from vllm_mindspore.worker.model_runner import profile_run +from vllm_ascend.worker.model_runner import NPUModelRunner + +NPUModelRunner.profile_run = profile_run + from vllm_mindspore.v1.worker.gpu_worker import compile_or_warm_up_model -from vllm.v1.worker.gpu_worker import Worker -Worker.compile_or_warm_up_model = compile_or_warm_up_model +from vllm_ascend.worker.worker_v1 import NPUWorker +NPUWorker.compile_or_warm_up_model = compile_or_warm_up_model + +from vllm_mindspore.v1.worker.gpu_worker import init_device + +NPUWorker.__init__ = wrapper_worker_init(NPUWorker.__init__) +NPUWorker.init_device = wrapper_worker_init_device(init_device) + +# ================ End ================ from .utils import check_ready diff --git a/vllm_mindspore/attention/backends/ms_attn.py b/vllm_mindspore/attention/backends/ms_attn.py index d6123b0a..64c94efc 100644 --- a/vllm_mindspore/attention/backends/ms_attn.py +++ b/vllm_mindspore/attention/backends/ms_attn.py @@ -501,17 +501,12 @@ class MsAttentionMetadataBuilder(AttentionMetadataBuilder[MSAttentionMetadata]): self, seq_lens: List[int], query_lens: List[int], - cuda_graph_pad_size: int, - batch_size: int, ): """Build attention metadata with on-device tensors. Args: seq_lens: The maybe padded sequence lengths of the input sequences. query_lens: The query lengths of the input sequences. - cuda_graph_pad_size: The padding size for cuda graph. - -1 if cuda graph is not used. - batch_size: The maybe padded batch size. """ prefix_cache_hit = any( [ @@ -525,7 +520,6 @@ class MsAttentionMetadataBuilder(AttentionMetadataBuilder[MSAttentionMetadata]): ) device = self.runner.device - use_captured_graph = cuda_graph_pad_size != -1 max_query_len = max(query_lens) decode_query_lens = query_lens[self.num_prefills:] @@ -539,15 +533,12 @@ class MsAttentionMetadataBuilder(AttentionMetadataBuilder[MSAttentionMetadata]): query_start_loc = list(accumulate(query_lens, initial=0)) seq_start_loc = list(accumulate(seq_lens, initial=0)) - if use_captured_graph: - raise RuntimeError("Doesnot support captured graph now!") - else: - block_tables = make_tensor_with_pad( - self.block_tables, - pad=-1, - dtype=torch.int, - device=device, - ) + block_tables = make_tensor_with_pad( + self.block_tables, + pad=-1, + dtype=torch.int, + device=device, + ) assert max_query_len > 0, "query_lens: {}".format(query_lens) context_lens_tensor = ms.Tensor(self.context_lens, dtype=ms.int32) @@ -596,6 +587,10 @@ class MsAttentionBackend(AttentionBackend): def get_builder_cls() -> Type["MsAttentionMetadataBuilder"]: return MsAttentionMetadataBuilder + @classmethod + def make_metadata_builder(cls, *args, **kwargs) -> "MsAttentionMetadataBuilder": + return cls.get_builder_cls()(*args, **kwargs) + @staticmethod def get_state_cls() -> Type["AttentionState"]: return MsAttentionState diff --git a/vllm_mindspore/platforms/ascend.py b/vllm_mindspore/platforms/ascend.py index 356a33a0..77686171 100644 --- a/vllm_mindspore/platforms/ascend.py +++ b/vllm_mindspore/platforms/ascend.py @@ -144,4 +144,21 @@ class AscendPlatform(Platform): @classmethod def supports_v1(cls, model_config: ModelConfig) -> bool: - return True \ No newline at end of file + return True + +def get_attn_backend_cls(selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1, use_mla): + """Get the attention backend class of a device.""" + if use_v1: + raise RuntimeError("vLLM-MindSpore do not support v1 egine 
now!") + if use_mla: + logger.info("Using MindSpore MLA backend.") + return "vllm_mindspore.attention.backends.ms_attn.MLABackend" + + if selected_backend == _Backend.FLASH_ATTN or selected_backend is None: + logger.info("Using MindSpore Attention backend.") + return "vllm_mindspore.attention.backends.ms_attn.MsAttentionBackend" + + raise ValueError( + "Invaild attention backend %s for vLLM-MindSpore with head_size: %s, dtype: %s, kv_cache_dtype: %s, block_size: %s." + % (str(selected_backend), str(head_size), str(dtype), str(kv_cache_dtype), str(block_size)) + ) diff --git a/vllm_mindspore/worker/model_runner.py b/vllm_mindspore/worker/model_runner.py index 55bb26ec..56e425aa 100644 --- a/vllm_mindspore/worker/model_runner.py +++ b/vllm_mindspore/worker/model_runner.py @@ -19,9 +19,9 @@ from typing import List import torch + from vllm.distributed import get_pp_group from vllm.logger import init_logger -from vllm.lora.request import LoRARequest from vllm.sampling_params import SamplingParams from vllm.sequence import SequenceGroupMetadata from vllm_mindspore.utils import STR_DTYPE_TO_TENSOR_DTYPE @@ -132,8 +132,7 @@ def _dummy_run(self, # tensor aliasing. kv_cache_dtype = self.model_config.dtype if self.cache_config.cache_dtype == "auto" \ else self.cache_config.cache_dtype - if kv_cache_dtype in STR_DTYPE_TO_TENSOR_DTYPE: - kv_cache_dtype = STR_DTYPE_TO_TENSOR_DTYPE[kv_cache_dtype] + kv_cache_dtype = STR_DTYPE_TO_TENSOR_DTYPE[kv_cache_dtype] block_size = self.cache_config.block_size num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config) head_size = self.model_config.get_head_size() @@ -179,4 +178,98 @@ def _get_supported_attention_backends(chunked_prefill_enabled: bool) \ if chunked_prefill_enabled: return MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS else: - return MULTI_STEP_ATTENTION_BACKENDS \ No newline at end of file + return MULTI_STEP_ATTENTION_BACKENDS + + +def profile_run(self) -> None: + # Enable top-k sampling to reflect the accurate memory usage. + sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1) + max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens + max_num_seqs = self.scheduler_config.max_num_seqs + + # Profile memory usage with max_num_sequences sequences and the total + # number of tokens equal to max_num_batched_tokens. + seqs: List[SequenceGroupMetadata] = [] + # Additional GPU memory may be needed for multi-modal encoding, which + # needs to be accounted for when calculating the GPU blocks for + # vLLM blocker manager. + # To exercise the worst scenario for GPU memory consumption, + # the number of seqs (batch_size) is chosen to maximize the number + # of images processed. + + max_mm_tokens = self.mm_registry.get_max_multimodal_tokens( + self.model_config) + if max_mm_tokens > 0: + max_num_seqs_orig = max_num_seqs + max_num_seqs = min(max_num_seqs, + max_num_batched_tokens // max_mm_tokens) + if max_num_seqs < 1: + expr = (f"min({max_num_seqs_orig}, " + f"{max_num_batched_tokens} // {max_mm_tokens})") + logger.warning( + "Computed max_num_seqs (%s) to be less than 1. 
" + "Setting it to the minimum value of 1.", expr) + max_num_seqs = 1 + + batch_size = 0 + for group_id in range(max_num_seqs): + seq_len = (max_num_batched_tokens // max_num_seqs + + (group_id < max_num_batched_tokens % max_num_seqs)) + batch_size += seq_len + + dummy_data = self.input_registry \ + .dummy_data_for_profiling(self.model_config, + seq_len, + self.mm_registry) + + seq = SequenceGroupMetadata( + request_id=str(group_id), + is_prompt=True, + seq_data={group_id: dummy_data.seq_data}, + sampling_params=sampling_params, + block_tables=None, + lora_request=None, + multi_modal_data=dummy_data.multi_modal_data, + multi_modal_placeholders=dummy_data.multi_modal_placeholders, + ) + seqs.append(seq) + + # Run the model with the dummy inputs. + num_layers = self.model_config.get_num_layers(self.parallel_config) + # use an empty tensor instead of `None`` to force Dynamo to pass + # it by reference, rather by specializing on the value ``None``. + # the `dtype` argument does not matter, and we use `float32` as + # a placeholder (it has wide hardware support). + # it is important to create tensors inside the loop, rather than + # multiplying the list, to avoid Dynamo from treating them as + # tensor aliasing. + + kv_cache_dtype = self.model_config.dtype if self.cache_config.cache_dtype == "auto" \ + else self.cache_config.cache_dtype + kv_cache_dtype = STR_DTYPE_TO_TENSOR_DTYPE[kv_cache_dtype] + block_size = self.cache_config.block_size + num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config) + head_size = self.model_config.get_head_size() + kv_shape = [0, block_size, num_kv_heads, head_size] + kv_caches = mutable([ + mutable(( + mutable(torch.tensor([], dtype=kv_cache_dtype, device=self.device).reshape(kv_shape)), + mutable(torch.tensor([], dtype=kv_cache_dtype, device=self.device).reshape(kv_shape)), + )) + for _ in range(num_layers) + ]) + finished_requests_ids = [seq.request_id for seq in seqs] + model_input = self.prepare_model_input( + seqs, finished_requests_ids=finished_requests_ids) + intermediate_tensors = None + if not get_pp_group().is_first_rank: + intermediate_tensors = self.model.make_empty_intermediate_tensors( + batch_size=batch_size, + dtype=self.model_config.dtype, + device=self.device) + + self.execute_model(model_input, kv_caches, intermediate_tensors) + + from vllm.platforms import current_platform + current_platform.synchronize() + return -- Gitee From 1fefe47d24cc8de74f9934efd56992a23b2de814 Mon Sep 17 00:00:00 2001 From: zlq2020 Date: Tue, 8 Apr 2025 13:15:39 +0800 Subject: [PATCH 2/5] adapter vllm v1+plugin --- vllm_mindspore/__init__.py | 25 + .../attention/backends/ms_attn_v1.py | 210 ++++++ vllm_mindspore/platforms/ascend.py | 7 +- vllm_mindspore/utils.py | 1 - vllm_mindspore/worker/model_runner_v1.py | 692 ++++++++++++++++++ vllm_mindspore/worker/worker_v1.py | 72 ++ 6 files changed, 1004 insertions(+), 3 deletions(-) create mode 100644 vllm_mindspore/attention/backends/ms_attn_v1.py create mode 100644 vllm_mindspore/worker/model_runner_v1.py create mode 100644 vllm_mindspore/worker/worker_v1.py diff --git a/vllm_mindspore/__init__.py b/vllm_mindspore/__init__.py index 91aa9483..011af523 100644 --- a/vllm_mindspore/__init__.py +++ b/vllm_mindspore/__init__.py @@ -316,6 +316,8 @@ from vllm.v1.core.sched.scheduler import Scheduler Scheduler.schedule = schedule # ================ For vllm-ascend ================ +import vllm_ascend.utils +vllm_ascend.utils.vllm_version_is = lambda version: True from vllm_mindspore.platforms.ascend import 
get_attn_backend_cls @@ -364,6 +366,29 @@ NPUWorker.init_device = wrapper_worker_init_device(init_device) # ================ End ================ +# ============ For v1 start =========== +from vllm_mindspore.config import _get_and_verify_dtype +vllm.config._get_and_verify_dtype = _get_and_verify_dtype + +from vllm_mindspore.worker.model_runner_v1 import _dummy_run, _process_reqs +from vllm_ascend.worker.model_runner_v1 import NPUModelRunner +NPUModelRunner._dummy_run = _dummy_run +NPUModelRunner._process_reqs = _process_reqs + +#from vllm.v1.worker.gpu_model_runner import GPUModelRunner +#NPUModelRunner.execute_model = GPUModelRunner.execute_model + +from vllm_mindspore.worker.model_runner_v1 import _prepare_inputs, _update_states, initialize_kv_cache, get_kv_cache_spec +NPUModelRunner._prepare_inputs = _prepare_inputs +NPUModelRunner._update_states = _update_states +NPUModelRunner.initialize_kv_cache = initialize_kv_cache +NPUModelRunner.get_kv_cache_spec = get_kv_cache_spec + +from vllm_mindspore.worker.worker_v1 import determine_available_memory +NPUWorker.determine_available_memory = determine_available_memory + +# ============ For v1 end =========== + from .utils import check_ready from vllm_mindspore.engine.multiprocessing.engine import cleanup diff --git a/vllm_mindspore/attention/backends/ms_attn_v1.py b/vllm_mindspore/attention/backends/ms_attn_v1.py new file mode 100644 index 00000000..731a28d6 --- /dev/null +++ b/vllm_mindspore/attention/backends/ms_attn_v1.py @@ -0,0 +1,210 @@ +#!/usr/bin/env python3 +# encoding: utf-8 +# Copyright 2025 Huawei Technologies Co., Ltd +# Copyright 2024 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Type +import numpy as np + +import torch +from mindspore import mutable +from mindspore._c_expression import swap_cache + +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionMetadata, AttentionType) +from vllm.logger import init_logger + +from vllm_mindspore.utils import MsKVCache + +logger = init_logger(__name__) + + +class MsAttentionBackend(AttentionBackend): + + accept_output_buffer: bool = True + + @staticmethod + def get_supported_head_sizes() -> List[int]: + return [32, 64, 96, 128, 160, 192, 224, 256] + + @staticmethod + def get_name() -> str: + return "MS_ATTN" + + @staticmethod + def get_impl_cls() -> Type["AttentionImpl"]: + return MsAttentionImpl + + @staticmethod + def get_metadata_cls() -> Type["AttentionMetadata"]: + return MSAttentionMetadata + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + if block_size % 16 != 0: + raise ValueError("Block size must be a multiple of 16.") + return (2, num_blocks, block_size, num_kv_heads, head_size) + + @staticmethod + def use_cascade_attention(*args, **kwargs) -> bool: + return False + + +class MLABackend(AttentionBackend): + @staticmethod + def get_name() -> str: + return "MS_MLA" + + @staticmethod + def get_impl_cls() -> Type["AttentionImpl"]: + return MsAttentionImpl + + @staticmethod + def get_metadata_cls() -> Type["AttentionMetadata"]: + return MSAttentionMetadata + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + if block_size % 16 != 0: + raise ValueError("Block size must be a multiple of 16.") + return (1, num_blocks, block_size, 1, head_size) + + @staticmethod + def use_cascade_attention(*args, **kwargs) -> bool: + return False + + @staticmethod + def get_supported_head_sizes() -> List[int]: + return [576] + + +@dataclass +class MSAttentionMetadata: + # NOTE(sang): Definition of context_len, query_len, and seq_len. 
+ # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ---------------------| + # |-- query_len ---| + + # max_seq_len: int + # seq_lens: torch.Tensor + # seq_lens_np: np.ndarray + # block_tables: torch.Tensor + # slot_mapping: torch.Tensor + # q_seq_lens: torch.Tensor + # context_lens: torch.Tensor + # max_context_lens: int + + # def __getitem__(self, key): + # if key == "batch_valid_length": + # key = "seq_lens" + # if key == "block_tables": + # if getattr(self, key).ndim == 1: + # return mutable(getattr(self, key).expand_dims(0)) + # return mutable(getattr(self, key)) + # return getattr(self, key) + + # AscendMetadata + block_tables: Optional[torch.Tensor] + seq_lens: Optional[List[int]] = None + context_lens: Optional[List[int]] = None + max_query_len: Optional[int] = None + slot_mapping: torch.Tensor = None + is_only_prefill: bool = False + attn_mask: Optional[torch.Tensor] = None + + # add for mindspore + num_decode_tokens: int = 0 + query_lens: Optional[List[int]] = None + + +class MsAttentionImpl(AttentionImpl): + """ + If the input tensors contain prompt tokens, the layout is as follows: + |<--------------- num_prefill_tokens ----------------->| + |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->| + + Otherwise, the layout is as follows: + |<----------------- num_decode_tokens ------------------>| + |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->| + + Generation tokens can contain padding when cuda-graph is used. + Currently, prompt tokens don't contain any padding. + + The prompts might have different lengths, while the generation tokens + always have length 1. + + If chunked prefill is enabled, prefill tokens and decode tokens can be + batched together in a flattened 1D query. + + |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->| + |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->| + + Currently, cuda graph is disabled for chunked prefill, meaning there's no + padding between prefill and decode tokens. + """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + attn_type: AttentionType = AttentionType.DECODER, + ) -> None: + pass + + def forward( + self, + layer: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: MSAttentionMetadata, + output: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward pass with FlashAttention. + + Args: + query: shape = [num_tokens, num_heads, head_size] + key: shape = [num_tokens, num_kv_heads, head_size] + value: shape = [num_tokens, num_kv_heads, head_size] + output: shape = [num_tokens, num_heads, head_size] + kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size] + NOTE: kv_cache will be an empty tensor with shape [0] + for profiling run. + attn_metadata: Metadata for attention. + NOTE: It in-place updates the output tensor. 
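+
+        Note: in this MindSpore V1 backend both ``__init__`` and ``forward``
+        are currently no-op stubs; the backend and metadata classes defined
+        above are what the V1 model runner actually consumes.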
+ """ + pass diff --git a/vllm_mindspore/platforms/ascend.py b/vllm_mindspore/platforms/ascend.py index 77686171..89a41646 100644 --- a/vllm_mindspore/platforms/ascend.py +++ b/vllm_mindspore/platforms/ascend.py @@ -109,7 +109,6 @@ class AscendPlatform(Platform): if use_mla: return "vllm_mindspore.v1.attention.backends.flash_attn.MLABackend" return "vllm_mindspore.v1.attention.backends.flash_attn.FlashAttentionBackend" - raise RuntimeError("vLLM-MindSpore do not support v1 egine now!") if use_mla: logger.info("Using MindSpore MLA backend.") return "vllm_mindspore.attention.backends.ms_attn.MLABackend" @@ -149,7 +148,11 @@ class AscendPlatform(Platform): def get_attn_backend_cls(selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1, use_mla): """Get the attention backend class of a device.""" if use_v1: - raise RuntimeError("vLLM-MindSpore do not support v1 egine now!") + if use_mla: + logger.info("Using MindSpore MLA backend for V1.") + return "vllm_mindspore.attention.backends.ms_attn_v1.MLABackend" + logger.info("Using MindSpore Attention backend for V1.") + return "vllm_mindspore.attention.backends.ms_attn_v1.MsAttentionBackend" if use_mla: logger.info("Using MindSpore MLA backend.") return "vllm_mindspore.attention.backends.ms_attn.MLABackend" diff --git a/vllm_mindspore/utils.py b/vllm_mindspore/utils.py index 153589ed..af65c4de 100644 --- a/vllm_mindspore/utils.py +++ b/vllm_mindspore/utils.py @@ -164,7 +164,6 @@ def check_ready(): import vllm.envs as envs from mindspore import set_context - # Common environment variables of predict. set_context(jit_config={"jit_level": "O0", "infer_boost": "on"}) default_env = { diff --git a/vllm_mindspore/worker/model_runner_v1.py b/vllm_mindspore/worker/model_runner_v1.py new file mode 100644 index 00000000..0e857c5d --- /dev/null +++ b/vllm_mindspore/worker/model_runner_v1.py @@ -0,0 +1,692 @@ +#!/usr/bin/env python3 +# encoding: utf-8 +# Copyright 2025 Huawei Technologies Co., Ltd +# Copyright 2024 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +from typing import List, Optional, Tuple +import numpy as np + +import torch +import mindspore as ms +from mindspore import Tensor + +from vllm.logger import init_logger +from vllm.distributed.parallel_state import get_pp_group +from vllm.forward_context import set_forward_context +from vllm.sequence import IntermediateTensors + +from vllm_mindspore.attention.backends.ms_attn_v1 import MSAttentionMetadata, MsAttentionBackend + +#################33 +from mindspore import mutable +from vllm_mindspore.utils import get_valid_dtype +# from vllm_mindspore.utils import is_use_mla + +from vllm.attention import AttentionType +from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig, KVCacheSpec +from vllm.v1.utils import bind_kv_cache +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs +from vllm.logger import logger +from vllm.distributed.parallel_state import get_pp_group +from vllm.utils import cdiv +from vllm.logger import init_logger +from vllm.v1.worker.gpu_input_batch import CachedRequestState +from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding +from vllm.sampling_params import SamplingType +################### + + +logger = init_logger(__name__) + + +@torch.inference_mode() +def _dummy_run( + self, + num_tokens: int, + dummy_kv_caches: List[torch.Tensor], +) -> torch.Tensor: + model = self.model + if self.is_multimodal_model: + input_ids = None + inputs_embeds = self.inputs_embeds[:num_tokens] + else: + input_ids = self.input_ids[:num_tokens] + inputs_embeds = None + + if self.uses_mrope: + positions = self.mrope_positions[:, :num_tokens] + else: + positions = self.input_positions_cpu[:num_tokens] + + if get_pp_group().is_first_rank: + intermediate_tensors = None + else: + if self.intermediate_tensors is None: + self.intermediate_tensors = ( + self.model.make_empty_intermediate_tensors( + batch_size=self.max_num_tokens, + dtype=self.model_config.dtype, + device=self.device)) + intermediate_tensors = IntermediateTensors({ + k: v[:num_tokens] + for k, v in self.intermediate_tensors.items() + }) + + if dummy_kv_caches is None: # for compile_or_warm_up_model + attn_metadata = _dummy_attention_metadata(input_ids, positions, False) + else: + attn_metadata = _dummy_attention_metadata(input_ids, positions, True) + + with set_forward_context(attn_metadata, self.vllm_config): + hidden_states = model(input_ids=input_ids, + #positions=positions.to(self.device), + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + kv_caches=dummy_kv_caches, + attn_metadata=attn_metadata) + return hidden_states + + +def _dummy_attention_metadata(input_ids: Tensor, positions: Tensor, is_prefill=True) -> MSAttentionMetadata: + input_len = input_ids.shape[0] + # max_seq_len = ms.Tensor(input_len, dtype=ms.int32) + # seq_lengths = ms.Tensor([input_len], dtype=ms.int32) + # q_seq_lens = ms.Tensor([input_len], dtype=ms.int32) + seq_lens_np = np.array([input_len], dtype=np.int32) + + block_tables = ms.Tensor([[0]], dtype=ms.int32) + slot_mapping = [-1 for _ in range(input_len)] + slot_mapping = ms.Tensor(slot_mapping, dtype=ms.int32) + return MSAttentionMetadata( + block_tables=block_tables, + seq_lens=seq_lens_np, + context_lens=0, + max_query_len=1, + slot_mapping=slot_mapping, + num_decode_tokens=0 if is_prefill else 1, + query_lens=seq_lens_np, + ) + + +def _process_reqs( + self, + scheduler_output: "SchedulerOutput", + intermediate_tensors: 
Optional[IntermediateTensors] = None, +) -> torch.Tensor: + # Check input valid + total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + assert total_num_scheduled_tokens > 0 + num_reqs = self.input_batch.num_reqs + assert num_reqs > 0 + + # Copy the blocks from CPU to NPU. + # OPTIMIZATION: Start copying the block table first. + # This way, we can overlap the copy with the following CPU operations. + self.input_batch.block_table.commit(num_reqs) + + # Get the number of scheduled tokens for each request. + # TODO: The Python loop can be slow. Optimize. + num_scheduled_tokens = np.empty(num_reqs, dtype=np.int32) + max_num_scheduled_tokens = 0 + for i, req_id in enumerate(self.input_batch.req_ids): + num_tokens = scheduler_output.num_scheduled_tokens[req_id] + num_scheduled_tokens[i] = num_tokens + # max_num_scheduled_tokens = max(max_num_scheduled_tokens, + # num_tokens) + + # Prepare positions + req_indices = np.repeat(self.arange_np[:num_reqs], + num_scheduled_tokens) + cu_num_tokens = np.cumsum(num_scheduled_tokens) + cumsums_offsets = np.repeat(cu_num_tokens - num_scheduled_tokens, + num_scheduled_tokens) + arange = self.arange_np[:total_num_scheduled_tokens] - cumsums_offsets + + positions_np = self.positions_np[:total_num_scheduled_tokens] + np.add(self.input_batch.num_computed_tokens_cpu[req_indices], + arange, + out=positions_np) + + self.positions[:total_num_scheduled_tokens].copy_( + self.positions_cpu[:total_num_scheduled_tokens], non_blocking=True) + positions = self.positions[:total_num_scheduled_tokens] + + self.seq_lens_np[:num_reqs] = ( + self.input_batch.num_computed_tokens_cpu[:num_reqs] + + num_scheduled_tokens) + seq_lens = self.seq_lens_cpu[:num_reqs] + + query_lens = torch.from_numpy(num_scheduled_tokens) + + block_table_indices = (req_indices * self.max_num_blocks_per_req + + positions_np // self.block_size) + block_table_cpu = self.input_batch.block_table.get_cpu_tensor() + block_numbers = block_table_cpu.flatten()[block_table_indices].numpy() + block_offsets = positions_np % self.block_size + np.add(block_numbers * self.block_size, + block_offsets, + out=self.slot_mapping_np[:total_num_scheduled_tokens]) + slot_mapping = self.slot_mapping_cpu[:total_num_scheduled_tokens].to( + self.device, non_blocking=True) + + # attn_mask = self.make_attention_mask(seq_lens=seq_lens, + # query_lens=num_scheduled_tokens, + # position=positions) + + num_decode_tokens = self.input_batch.num_computed_tokens_cpu[:num_reqs].max() + + attn_metadata = MSAttentionMetadata( + seq_lens=query_lens, + context_lens=seq_lens, + slot_mapping=slot_mapping, + block_tables=( + self.input_batch.block_table.get_device_tensor()[:num_reqs]), + # attn_mask=attn_mask + num_decode_tokens=num_decode_tokens, + query_lens=query_lens + ) + + # Prepare input_ids + token_indices = (positions_np + + req_indices * self.input_batch.token_ids_cpu.shape[1]) + torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(), + 0, + torch.from_numpy(token_indices), + out=self.input_ids_cpu[:total_num_scheduled_tokens]) + # Copy the tensors to the NPU. 
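+    # The copy below is issued with non_blocking=True, so the host-to-device
+    # transfer can overlap with setting up the forward context before the
+    # model call.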
+ self.input_ids[:total_num_scheduled_tokens].copy_( + self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True) + input_ids = self.input_ids[:total_num_scheduled_tokens] + + # Run forward pass + with set_forward_context(attn_metadata, self.vllm_config): + assert self.model is not None + hidden_states = self.model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=None, + kv_caches=self.kv_caches, + attn_metadata=attn_metadata, + ) + + return hidden_states[cu_num_tokens - 1] + + +def _prepare_inputs( + self, + scheduler_output: "SchedulerOutput", +) -> Tuple[MSAttentionMetadata, torch.Tensor]: + total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + assert total_num_scheduled_tokens > 0 + num_reqs = self.input_batch.num_reqs + assert num_reqs > 0 + + # OPTIMIZATION: Start copying the block table first. + # This way, we can overlap the copy with the following CPU operations. + self.input_batch.block_table.commit(num_reqs) + # context_lens = ms.Tensor(self.input_batch.num_computed_tokens_cpu[:num_reqs], dtype=torch.int32) + context_lens = ms.from_numpy(self.input_batch.num_computed_tokens_cpu[:num_reqs]) + context_lens.move_to("Ascend", blocking=False) + # Get the number of scheduled tokens for each request. + # TODO: The Python loop can be slow. Optimize. + num_scheduled_tokens = np.empty(num_reqs, dtype=np.int32) + max_num_scheduled_tokens = 0 + for i, req_id in enumerate(self.input_batch.req_ids): + num_tokens = scheduler_output.num_scheduled_tokens[req_id] + num_scheduled_tokens[i] = num_tokens + max_num_scheduled_tokens = max(max_num_scheduled_tokens, + num_tokens) + + # non_blocking send q_seq_lens to device + # q_seq_lens = ms.Tensor(num_scheduled_tokens, dtype=ms.int32) + q_seq_lens = ms.from_numpy(num_scheduled_tokens) + q_seq_lens.move_to("Ascend", blocking=False) + + # Get request indices. + # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] + req_indices = np.repeat(self.arange_np[:num_reqs], + num_scheduled_tokens) + + # Get batched arange. + # E.g., [2, 5, 3] -> [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + # Equivalent to but faster than: + # np.concatenate([np.arange(n) for n in num_scheduled_tokens]) + # Step 1. [2, 5, 3] -> [2, 7, 10] + cu_num_tokens = np.cumsum(num_scheduled_tokens) + # Step 2. [2, 7, 10] -> [0, 0, 2, 2, 2, 2, 2, 7, 7, 7] + cumsums_offsets = np.repeat(cu_num_tokens - num_scheduled_tokens, + num_scheduled_tokens) + # Step 3. [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + arange = self.arange_np[:total_num_scheduled_tokens] - cumsums_offsets + + # Get positions. + positions_np = self.positions_np[:total_num_scheduled_tokens] + np.add(self.input_batch.num_computed_tokens_cpu[req_indices], + arange, + out=positions_np) + # self.positions_cpu[:total_num_scheduled_tokens] = torch.from_numpy(positions_np) + self.positions[:total_num_scheduled_tokens] = torch.from_numpy(positions_np) + + # Get token indices. + # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2] + # where M is the max_model_len. + token_indices = (positions_np + + req_indices * self.input_batch.token_ids_cpu.shape[1]) + + # NOTE(woosuk): We use torch.index_select instead of np.take here + # because torch.index_select is much faster than np.take for large + # tensors. 
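+    # Despite the note above, this MindSpore port gathers the token ids with
+    # np.take on the CPU numpy buffer and wraps the result with
+    # torch.from_numpy; the torch.index_select variants are kept below only
+    # as commented-out reference code.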
+ # self.input_ids_cpu[:total_num_scheduled_tokens] = \ + # torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(), + # 0, + # torch.from_numpy(token_indices)) + # self.input_ids[:total_num_scheduled_tokens] = \ + # torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(), + # 0, + # torch.from_numpy(token_indices)) + self.input_ids[:total_num_scheduled_tokens] = torch.from_numpy( + np.take(self.input_batch.token_ids_cpu.flatten(), + token_indices, + 0) + ) + + # Calculate the slot mapping. + # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1] + # where K is the max_num_blocks_per_req and the block size is 2. + # NOTE(woosuk): We can't simply use `token_indices // block_size` here + # because M (max_model_len) is not necessarily divisible by block_size. + block_table_indices = (req_indices * self.max_num_blocks_per_req + + positions_np // self.block_size) + # NOTE(woosuk): We use torch.index_select instead of np.take here + # because torch.index_select is much faster than np.take for large + # tensors. + # block_table_cpu = self.input_batch.block_table.get_cpu_tensor() + # block_numbers = block_table_cpu.flatten()[block_table_indices].numpy() + block_numbers = self.input_batch.block_table.block_table_np.flatten()[block_table_indices] + block_offsets = positions_np % self.block_size + np.add(block_numbers * self.block_size, + block_offsets, + out=self.slot_mapping_np[:total_num_scheduled_tokens]) + # TODO: + # self.slot_mapping_cpu[:total_num_scheduled_tokens] = \ + # torch.from_numpy(self.slot_mapping_np[:total_num_scheduled_tokens]) + + # non_blocking send q_seq_lens to device + # TODO: + # slot_mapping = self.slot_mapping_cpu[:total_num_scheduled_tokens].to(torch.int32) + # slot_mapping = ms.Tensor(self.slot_mapping_np[:total_num_scheduled_tokens], dtype=ms.int32) + slot_mapping = ms.from_numpy(self.slot_mapping_np[:total_num_scheduled_tokens]) + slot_mapping.move_to("Ascend", blocking=False) + + # # Prepare the attention metadata. + self.query_start_loc_np[0] = 0 + self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens + # query_start_loc = ms.Tensor(self.query_start_loc_np[:num_reqs + 1], dtype=ms.int32) + query_start_loc = ms.from_numpy(self.query_start_loc_np[:num_reqs + 1]) + query_start_loc.move_to("Ascend", blocking=False) + # TODO: + # self.query_start_loc_cpu[1:num_reqs + 1] = \ + # torch.from_numpy(self.query_start_loc_np[1:num_reqs + 1]) + + self.seq_lens_np[:num_reqs] = ( + self.input_batch.num_computed_tokens_cpu[:num_reqs] + + num_scheduled_tokens) + max_seq_len = self.seq_lens_np[:num_reqs].max() + # TODO: + seq_lens_np = self.seq_lens_np[:num_reqs] + # seq_lens = ms.Tensor(seq_lens_np, dtype=ms.int32) + seq_lens = ms.from_numpy(seq_lens_np) + seq_lens.move_to("Ascend", blocking=False) + # self.seq_lens_cpu[:num_reqs] = torch.from_numpy(self.seq_lens_np[:num_reqs]) + + # # Copy the tensors to the GPU. 
+ # self.input_ids[:total_num_scheduled_tokens] = \ + # self.input_ids_cpu[:total_num_scheduled_tokens] + + # # Common case (1D positions) + # self.positions[:total_num_scheduled_tokens] = \ + # self.positions_cpu[:total_num_scheduled_tokens] + + # TODO: + # query_start_loc = self.query_start_loc_cpu[:num_reqs + 1] + + # seq_lens = self.seq_lens_cpu[:num_reqs] + + max_context_lens = self.input_batch.num_computed_tokens_cpu[:num_reqs].max() + + attn_metadata = MSAttentionMetadata( + block_tables=(self.input_batch.block_table.get_device_tensor()[:num_reqs]), + seq_lens=seq_lens, + context_lens=context_lens, + max_query_len=1, + slot_mapping=slot_mapping, + num_decode_tokens=num_decode_tokens, + query_lens=query_len + ) + + use_spec_decode = len( + scheduler_output.scheduled_spec_decode_tokens) > 0 + if use_spec_decode: + logits_indices = self._calc_spec_decode_metadata( + scheduler_output, cu_num_tokens) + else: + # NOTE(woosuk): Due to chunked prefills, the batch may contain + # partial requests. While we should not sample any token + # from these partial requests, we do so for simplicity. + # We will ignore the sampled tokens from the partial requests. + # TODO: Support prompt logprobs. + logits_indices = query_start_loc[1:] - 1 + + # Hot-Swap lora model + if self.lora_config: + self.set_active_loras(self.input_batch, num_scheduled_tokens) + + return attn_metadata, logits_indices + + +def create_block(shape, dtype, name=None, device=None): + from mindspore import mint + blocks = mint.empty(shape, dtype=dtype, device=device) + return blocks + +def initialize_kv_cache(self, kv_cache_config) -> None: + """ + Initialize KV cache based on `kv_cache_config`. + Args: + kv_cache_config: Configuration for the KV cache, including the KV + cache size of each layer + """ + if len(kv_cache_config.groups) > 1: + raise NotImplementedError( + "Hybrid models with more than one KV cache type are not " + "supported yet.") + + kv_caches: Dict[str, torch.Tensor] = {} + + # backend = MLABackend if is_use_mla(self.model_config) else FlashAttentionBackend + backend = MsAttentionBackend + + for layer_name, layer_spec in kv_cache_config.kv_cache_spec.items(): + tensor_config = kv_cache_config.tensors[layer_name] + assert tensor_config.size % layer_spec.page_size_bytes == 0 + num_blocks = tensor_config.size // layer_spec.page_size_bytes + if isinstance(layer_spec, FullAttentionSpec): + kv_cache_shape = backend.get_kv_cache_shape( + num_blocks, layer_spec.block_size, layer_spec.num_kv_heads, + layer_spec.head_size) + dtype = layer_spec.dtype + dtype = get_valid_dtype(dtype) + current_cache = [] + device_type = "CPU" if self.device.type == "cpu" else "Ascend" + for i in range(kv_cache_shape[0]): + cache_blocks = create_block( + kv_cache_shape[1:], dtype, device=device_type + ) + current_cache.append(mutable(cache_blocks)) + kv_caches[layer_name] = mutable(tuple(current_cache)) + else: + raise NotImplementedError + + bind_kv_cache( + kv_caches, + self.vllm_config.compilation_config.static_forward_context, + self.kv_caches) + + +def _update_states(self, scheduler_output: "SchedulerOutput") -> None: + """Update the cached states and the persistent batch with the scheduler + output. + + The updated states are used by the `_prepare_inputs` function to create + the input GPU tensors for the model. + + The SamplingMetadata is updated and copied to the GPU if there is a + new/resumed/paused/finished request in the batch. + """ + # Remove finished requests from the cached states. 
+ for req_id in scheduler_output.finished_req_ids: + self.requests.pop(req_id, None) + self.encoder_cache.pop(req_id, None) + # Remove the finished requests from the persistent batch. + # NOTE(woosuk): There could be an edge case where finished_req_ids and + # scheduled_req_ids overlap. This happens when a request is aborted and + # then resubmitted with the same ID. In this case, we treat them as two + # distinct requests - clearing the cached states for the first request + # and handling the second as a new request. + removed_req_indices: List[int] = [] + for req_id in scheduler_output.finished_req_ids: + req_index = self.input_batch.remove_request(req_id) + if req_index is not None: + removed_req_indices.append(req_index) + + # Free the cached encoder outputs. + for req_id, input_id in scheduler_output.free_encoder_input_ids: + encoder_outputs = self.encoder_cache.get(req_id) + if encoder_outputs is not None: + encoder_outputs.pop(input_id, None) + if not encoder_outputs: + self.encoder_cache.pop(req_id, None) + + # Remove the unscheduled requests from the persistent batch. + # NOTE(woosuk): The unscheduled requests are either preempted requests + # or running requests that are not scheduled in this step. We remove + # them from the persistent batch but keep their cached states since + # they will be scheduled again sometime in the future. + scheduled_req_ids = scheduler_output.num_scheduled_tokens.keys() + cached_req_ids = self.input_batch.req_id_to_index.keys() + unscheduled_req_ids = cached_req_ids - scheduled_req_ids + # NOTE(woosuk): The persistent batch optimization assumes that + # consecutive batches contain mostly the same requests. If batches + # have low request overlap (e.g., alternating between two distinct + # sets of requests), this optimization becomes very inefficient. + for req_id in unscheduled_req_ids: + req_index = self.input_batch.remove_request(req_id) + assert req_index is not None + removed_req_indices.append(req_index) + + req_ids_to_add: List[str] = [] + # Add new requests to the cached states. 
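+    # Each new request is wrapped in a CachedRequestState; requests that
+    # sample with a fixed seed get a dedicated torch.Generator, and M-RoPE
+    # positions are precomputed for multimodal models such as Qwen2-VL.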
+ for new_req_data in scheduler_output.scheduled_new_reqs: + req_id = new_req_data.req_id + sampling_params = new_req_data.sampling_params + if sampling_params.sampling_type == SamplingType.RANDOM_SEED: + generator = torch.Generator(device=self.device) + generator.manual_seed(sampling_params.seed) + else: + generator = None + + self.requests[req_id] = CachedRequestState( + req_id=req_id, + prompt_token_ids=new_req_data.prompt_token_ids, + prompt=new_req_data.prompt, + mm_inputs=new_req_data.mm_inputs, + mm_positions=new_req_data.mm_positions, + sampling_params=sampling_params, + generator=generator, + block_ids=new_req_data.block_ids, + num_computed_tokens=new_req_data.num_computed_tokens, + output_token_ids=[], + lora_request=new_req_data.lora_request, + ) + + # Only relevant for models using M-RoPE (e.g, Qwen2-VL) + if self.uses_mrope: + image_grid_thw = [] + video_grid_thw = [] + second_per_grid_ts = [] + for mm_input in self.requests[req_id].mm_inputs: + if mm_input.get("image_grid_thw") is not None: + image_grid_thw.extend( + mm_input["image_grid_thw"].tolist()) + if mm_input.get("video_grid_thw") is not None: + video_grid_thw.extend( + mm_input["video_grid_thw"].tolist()) + if mm_input.get("second_per_grid_ts") is not None: + second_per_grid_ts.extend( + mm_input["second_per_grid_ts"]) + + hf_config = self.model_config.hf_config + + self.requests[req_id].mrope_positions, \ + self.requests[req_id].mrope_position_delta = \ + MRotaryEmbedding.get_input_positions_tensor( + self.requests[req_id].prompt_token_ids, + hf_config=hf_config, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + ) + + req_ids_to_add.append(req_id) + + # Update the states of the running/resumed requests. + for req_data in scheduler_output.scheduled_cached_reqs: + req_id = req_data.req_id + req_state = self.requests[req_id] + + # Update the cached states. + num_computed_tokens = req_data.num_computed_tokens + req_state.num_computed_tokens = num_computed_tokens + # Add the sampled token(s) from the previous step (if any). + # This doesn't include "unverified" tokens like spec decode tokens. + num_new_tokens = (num_computed_tokens + + len(req_data.new_token_ids) - + req_state.num_tokens) + if num_new_tokens == 1: + # Avoid slicing list in most common case. + req_state.output_token_ids.append(req_data.new_token_ids[-1]) + elif num_new_tokens > 0: + req_state.output_token_ids.extend( + req_data.new_token_ids[-num_new_tokens:]) + # Update the block IDs. + if not req_data.resumed_from_preemption: + # Append the new blocks to the existing block IDs. + req_state.block_ids.extend(req_data.new_block_ids) + else: + # The request is resumed from preemption. + # Replace the existing block IDs with the new ones. + req_state.block_ids = req_data.new_block_ids + + req_index = self.input_batch.req_id_to_index.get(req_id) + if req_index is None: + # The request is not in the persistent batch. + # The request was either preempted and resumed later, or was not + # scheduled in the previous step and needs to be added again. + req_ids_to_add.append(req_id) + continue + + # Update the persistent batch. + self.input_batch.num_computed_tokens_cpu[req_index] = ( + num_computed_tokens) + start_index = (len(req_state.block_ids) - + len(req_data.new_block_ids)) + self.input_batch.block_table.append_row(req_index, start_index, + req_data.new_block_ids) + # Add new_token_ids to token_ids_cpu. 
+ start_token_index = num_computed_tokens + end_token_index = num_computed_tokens + len(req_data.new_token_ids) + self.input_batch.token_ids_cpu[ + req_index, + start_token_index:end_token_index] = req_data.new_token_ids + # ####################################################### + # self.input_batch.token_ids_cpu_tensor[ + # req_index, + # start_token_index:end_token_index] = torch.from_numpy( + # self.input_batch.token_ids_cpu[ + # req_index, + # start_token_index:end_token_index]) + # ####################################### + self.input_batch.num_tokens_no_spec[req_index] = end_token_index + # Add spec_token_ids to token_ids_cpu. + spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get( + req_id, ()) + if spec_token_ids: + start_index = end_token_index + end_token_index += len(spec_token_ids) + self.input_batch.token_ids_cpu[ + req_index, start_index:end_token_index] = spec_token_ids + # NOTE(woosuk): `num_tokens` here may include spec decode tokens. + self.input_batch.num_tokens[req_index] = end_token_index + + + # self.input_batch.token_ids_cpu_tensor.copy_(torch.from_numpy(self.input_batch.token_ids_cpu)) + # Check if the batch has changed. If not, we can skip copying the + # sampling metadata from CPU to GPU. + batch_changed = len(removed_req_indices) > 0 or len(req_ids_to_add) > 0 + + # Add the new or resumed requests to the persistent batch. + # The smaller empty indices are filled first. + removed_req_indices = sorted(removed_req_indices, reverse=True) + for req_id in req_ids_to_add: + req_state = self.requests[req_id] + if removed_req_indices: + # Fill the empty index. + req_index = removed_req_indices.pop() + else: + # Append to the end. + req_index = None + self.input_batch.add_request(req_state, req_index) + + # self.input_batch.commit() + + # Condense the batched states if there are empty indices. + if removed_req_indices: + self.input_batch.condense(removed_req_indices) + # self.input_batch.commit() + + if batch_changed: + self.input_batch.refresh_sampling_metadata() + + +def get_kv_cache_spec(self) -> KVCacheSpec: + """ + Generates the KVCacheSpec by parsing the kv cache format from each + Attention module in the static forward context. + Returns: + KVCacheSpec: A dictionary mapping layer names to their KV cache + format. Layers that do not need KV cache are not included. + """ + + forward_ctx = self.vllm_config.compilation_config.static_forward_context + block_size = self.vllm_config.cache_config.block_size + kv_cache_spec: KVCacheSpec = {} + for layer_name, attn_module in forward_ctx.items(): + # if isinstance(attn_module, FusedMoE): + # continue + + # TODO: Support other attention modules, e.g., sliding window, + # cross-attention, MLA + #assert isinstance(attn_module, Attention) + if attn_module.attn_type == AttentionType.DECODER: + kv_cache_spec[layer_name] = FullAttentionSpec( + block_size=block_size, + num_kv_heads=attn_module.num_kv_heads, + head_size=attn_module.head_size, + dtype=attn_module.dtype) + elif attn_module.attn_type in (AttentionType.ENCODER, + AttentionType.ENCODER_ONLY): + # encoder-only attention does not need KV cache. 
+ continue + elif attn_module.attn_type == AttentionType.ENCODER_DECODER: + raise NotImplementedError + else: + raise ValueError( + f"Unknown attention type: {attn_module.attn_type}") + + return kv_cache_spec \ No newline at end of file diff --git a/vllm_mindspore/worker/worker_v1.py b/vllm_mindspore/worker/worker_v1.py new file mode 100644 index 00000000..685b1967 --- /dev/null +++ b/vllm_mindspore/worker/worker_v1.py @@ -0,0 +1,72 @@ +#!/usr/bin/env python3 +# encoding: utf-8 +# Copyright 2025 Huawei Technologies Co., Ltd +# Copyright 2024 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import torch + +from vllm.logger import init_logger + +logger = init_logger(__name__) + +@torch.inference_mode() +def determine_available_memory(self) -> int: + """Profiles the peak memory usage of the model to determine how much + memory can be used for KV cache without OOMs. + + The engine will first conduct a profiling of the existing memory usage. + Then, it calculate the free memory that can be used for KV cache in + bytes. + + .. tip:: + You may limit the usage of GPU memory + by adjusting the `gpu_memory_utilization` parameter. + """ + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + + _, total_gpu_memory = torch.cuda.mem_get_info() + # Execute a forward pass with dummy inputs to profile the memory usage + # of the model. + self.model_runner.profile_run() + + free_gpu_memory, _ = torch.cuda.mem_get_info() + # NOTE(woosuk): Here we assume that the other processes using the same + # GPU did not change their memory usage during the profiling. + assert self.init_gpu_memory > free_gpu_memory, ( + "Error in memory profiling. " + f"Initial free memory {self.init_gpu_memory}, current free memory" + f" {free_gpu_memory}. This happens when the GPU memory was " + "not properly cleaned up before initializing the vLLM instance.") + + # Get the peak memory allocation recorded by torch + peak_memory = torch.cuda.memory_stats()["allocated_bytes.all.peak"] + + # Check for any memory left around that may have been allocated on the + # gpu outside of `torch`. 
NCCL operations, for example, can use a few
+    # GB during a forward pass
+    torch.cuda.empty_cache()
+    torch_allocated_bytes = torch.cuda.memory_stats(
+    )["allocated_bytes.all.current"]
+    total_allocated_bytes = torch.cuda.mem_get_info(
+    )[1] - torch.cuda.mem_get_info()[0]
+    non_torch_allocations = total_allocated_bytes - torch_allocated_bytes
+    if non_torch_allocations > 0:
+        peak_memory += non_torch_allocations
+    available_kv_cache_memory = (
+        total_gpu_memory * self.cache_config.gpu_memory_utilization -
+        peak_memory)
+
+    return int(available_kv_cache_memory)
-- 
Gitee


From 482bdf10e058ad55620b1857ef9b741b38623ede Mon Sep 17 00:00:00 2001
From: can-gaa-hou
Date: Tue, 13 May 2025 09:02:44 +0000
Subject: [PATCH 3/5] fix bugs and support v0

---
 vllm_mindspore/__init__.py                    | 547 +++++++---------
 vllm_mindspore/attention/backends/ms_attn.py  |  25 +-
 .../attention/backends/ms_attn_v1.py          | 210 -------
 vllm_mindspore/patch/__init__.py              |   0
 vllm_mindspore/patch/patch_vllm_ascend.py     | 274 ++++++++
 vllm_mindspore/platforms/ascend.py            |   6 +-
 vllm_mindspore/utils.py                       |   9 +
 vllm_mindspore/v1/worker/gpu_worker.py        |   6 +-
 vllm_mindspore/worker/model_runner.py         |  99 +--
 vllm_mindspore/worker/model_runner_v1.py      | 595 +++--------------
 vllm_mindspore/worker/worker_v1.py            |  82 +--
 11 files changed, 662 insertions(+), 1191 deletions(-)
 delete mode 100644 vllm_mindspore/attention/backends/ms_attn_v1.py
 create mode 100644 vllm_mindspore/patch/__init__.py
 create mode 100644 vllm_mindspore/patch/patch_vllm_ascend.py

diff --git a/vllm_mindspore/__init__.py b/vllm_mindspore/__init__.py
index 011af523..37e1f5e7 100644
--- a/vllm_mindspore/__init__.py
+++ b/vllm_mindspore/__init__.py
@@ -17,7 +17,13 @@
 # ============================================================================
 
 import sys
+import warnings
 
+if "vllm" in sys.modules:
+    # The models variable is also checked in a subprocess, so do not raise here.
+    warnings.warn(
+        "vllm was imported before vllm_mindspore; vllm_mindspore may not work correctly!"
+    )
 
 # 1. set env before import mindspore.
 from vllm_mindspore.scripts import env_setup
@@ -26,368 +32,285 @@ env_setup()
 
 # 2. update the log configuration ahead of other modifications.
import vllm_mindspore.logger -# ================ For vllm ================ -# from vllm_mindspore.platforms.ascend import AscendPlatform +import importlib.util +if importlib.util.find_spec("vllm_ascend") is not None: + import vllm_mindspore.patch.patch_vllm_ascend +else: + warnings.warn( + f"vllm-ascend is not imported because: {e}" + ) -# ascend_platform = AscendPlatform() + from vllm_mindspore.platforms.ascend import AscendPlatform -# import vllm.config + ascend_platform = AscendPlatform() -# vllm.config.current_platform = ascend_platform + import vllm.config -# import vllm.platforms + vllm.config.current_platform = ascend_platform -# vllm.platforms.current_platform = ascend_platform + import vllm.platforms -import vllm.utils + vllm.platforms.current_platform = ascend_platform -# vllm.utils.current_platform = ascend_platform + import vllm.utils -# import vllm.attention.selector -# vllm.attention.selector.current_platform = ascend_platform + vllm.utils.current_platform = ascend_platform -# import vllm.engine.arg_utils -# from vllm_mindspore.engine.arg_utils import _is_v1_supported_oracle -# vllm.engine.arg_utils.EngineArgs._is_v1_supported_oracle = _is_v1_supported_oracle + import vllm.attention.selector + vllm.attention.selector.current_platform = ascend_platform -# import vllm.v1.engine.core -# from vllm_mindspore.v1.engine.core import shutdown -# vllm.v1.engine.core.DPEngineCoreProc.shutdown = shutdown + import vllm.engine.arg_utils + from vllm_mindspore.engine.arg_utils import _is_v1_supported_oracle + vllm.engine.arg_utils.EngineArgs._is_v1_supported_oracle = _is_v1_supported_oracle -from vllm_mindspore.utils import ( - make_tensor_with_pad, - async_tensor_h2d, - ascend_is_initialized, - ms_memory_profiling, -) + from vllm_mindspore.utils import ( + make_tensor_with_pad, + async_tensor_h2d, + ascend_is_initialized, + ms_memory_profiling, + ) -vllm.utils.make_tensor_with_pad = make_tensor_with_pad -vllm.utils.async_tensor_h2d = async_tensor_h2d -vllm.utils.cuda_is_initialized = ascend_is_initialized -vllm.utils.memory_profiling = ms_memory_profiling + vllm.utils.make_tensor_with_pad = make_tensor_with_pad + vllm.utils.async_tensor_h2d = async_tensor_h2d + vllm.utils.cuda_is_initialized = ascend_is_initialized + vllm.utils.memory_profiling = ms_memory_profiling -import vllm.executor + from vllm_mindspore.model_executor.models.registry import ( + MindSporeModelRegistry, + _SUBPROCESS_COMMAND, + ) -from vllm_mindspore.model_executor.models.registry import ( - MindSporeModelRegistry, - _SUBPROCESS_COMMAND, -) + vllm.config.ModelRegistry = MindSporeModelRegistry -vllm.config.ModelRegistry = MindSporeModelRegistry + import vllm.model_executor -import vllm.model_executor + vllm.model_executor.models.ModelRegistry = MindSporeModelRegistry + vllm.model_executor.models.registry._SUBPROCESS_COMMAND = _SUBPROCESS_COMMAND -vllm.model_executor.models.ModelRegistry = MindSporeModelRegistry -vllm.model_executor.models.registry._SUBPROCESS_COMMAND = _SUBPROCESS_COMMAND + from vllm_mindspore.model_executor.model_loader.utils import get_ms_model_architecture -from vllm_mindspore.model_executor.model_loader.utils import get_ms_model_architecture + # To patching the get_model_architecture, should import it first. + from vllm.model_executor.model_loader import get_model_architecture -# To patching the get_model_architecture, should import it first. 
-from vllm.model_executor.model_loader import get_model_architecture + from vllm_mindspore.model_executor.sampling_metadata import SamplingTensors -vllm.model_executor.model_loader.get_model_architecture = get_ms_model_architecture -vllm.model_executor.model_loader.utils.get_model_architecture = ( - get_ms_model_architecture -) -vllm.model_executor.model_loader.loader.get_model_architecture = ( - get_ms_model_architecture -) + vllm.model_executor.sampling_metadata.async_tensor_h2d = async_tensor_h2d + vllm.model_executor.sampling_metadata.SamplingTensors.from_lists = SamplingTensors.from_lists + # from vllm_mindspore.worker.cache_engine import ( + # ms_allocate_kv_cache, + # ms_swap_in, + # ms_swap_out, + # ) -from vllm_mindspore.model_executor.sampling_metadata import SamplingTensors + from vllm_mindspore.worker.cache_engine import ( + ms_allocate_kv_cache, + ms_swap_in, + ms_swap_out, + ) -vllm.model_executor.sampling_metadata.async_tensor_h2d = async_tensor_h2d -vllm.model_executor.sampling_metadata.SamplingTensors.from_lists = SamplingTensors.from_lists -# from vllm_mindspore.worker.cache_engine import ( -# ms_allocate_kv_cache, -# ms_swap_in, -# ms_swap_out, -# ) + import vllm.worker.cache_engine -# import vllm.worker.cache_engine + vllm.worker.cache_engine.CacheEngine._allocate_kv_cache = ms_allocate_kv_cache + vllm.worker.cache_engine.CacheEngine.swap_in = ms_swap_in + vllm.worker.cache_engine.CacheEngine.swap_out = ms_swap_out -# vllm.worker.cache_engine.CacheEngine._allocate_kv_cache = ms_allocate_kv_cache -# vllm.worker.cache_engine.CacheEngine.swap_in = ms_swap_in -# vllm.worker.cache_engine.CacheEngine.swap_out = ms_swap_out + from vllm_mindspore.model_executor.model_loader.weight_utils import ( + safetensors_weights_iterator, + ) -from vllm_mindspore.model_executor.model_loader.weight_utils import ( - safetensors_weights_iterator, -) + vllm.model_executor.model_loader.loader.safetensors_weights_iterator = ( + safetensors_weights_iterator + ) -vllm.model_executor.model_loader.loader.safetensors_weights_iterator = ( - safetensors_weights_iterator -) + from vllm_mindspore.worker.worker import _warm_up_model + from vllm_mindspore.worker.profile import ( + wrapper_worker_init, + wrapper_worker_init_device, + ) + from vllm.worker.worker import Worker -# from vllm_mindspore.worker.worker import _warm_up_model -# from vllm_mindspore.worker.profile import ( -# wrapper_worker_init, -# wrapper_worker_init_device, -# ) -# from vllm.worker.worker import Worker + Worker._warm_up_model = _warm_up_model + Worker.__init__ = wrapper_worker_init(Worker.__init__) + Worker.init_device = wrapper_worker_init_device(Worker.init_device) -# Worker._warm_up_model = _warm_up_model -# Worker.__init__ = wrapper_worker_init(Worker.__init__) -# Worker.init_device = wrapper_worker_init_device(Worker.init_device) + from vllm_mindspore.worker.model_runner import ( + _get_cuda_graph_pad_size, + _dummy_run, + _get_supported_attention_backends, + ) -from vllm_mindspore.worker.model_runner import ( - _get_cuda_graph_pad_size, - _dummy_run, - _get_supported_attention_backends, -) + vllm.worker.model_runner.ModelInputForGPUBuilder._get_cuda_graph_pad_size = ( + _get_cuda_graph_pad_size + ) + vllm.worker.model_runner.GPUModelRunnerBase._dummy_run = _dummy_run -import vllm.worker.model_runner -vllm.worker.model_runner.ModelInputForGPUBuilder._get_cuda_graph_pad_size = ( - _get_cuda_graph_pad_size -) -vllm.worker.model_runner.GPUModelRunnerBase._dummy_run = _dummy_run + import vllm.worker.multi_step_model_runner -import 
vllm.worker.multi_step_model_runner + vllm.worker.multi_step_model_runner._get_supported_attention_backends = ( + _get_supported_attention_backends + ) -vllm.worker.multi_step_model_runner._get_supported_attention_backends = ( - _get_supported_attention_backends -) + from vllm_mindspore.executor.multiproc_worker_utils import ( + get_mp_context as ms_get_mp_context, + ) -from vllm_mindspore.executor.multiproc_worker_utils import ( - get_mp_context as ms_get_mp_context, - terminate_worker as ms_terminate_worker, -) + # To patching the get_mp_context, should import it first. + from vllm.executor.multiproc_worker_utils import get_mp_context -# To patching the get_mp_context, should import it first. -from vllm.executor.multiproc_worker_utils import get_mp_context + vllm.executor.multiproc_worker_utils.get_mp_context = ms_get_mp_context -vllm.executor.multiproc_worker_utils.get_mp_context = ms_get_mp_context + import vllm.v1.executor.multiproc_executor + vllm.v1.executor.multiproc_executor.get_mp_context = ms_get_mp_context + import vllm.v1.utils + vllm.v1.utils.get_mp_context = ms_get_mp_context -import vllm.executor.multiproc_worker_utils + from vllm_mindspore.executor.ray_gpu_executor import ( + ms_init_workers_ray, + initialize_ray_cluster, + ) -vllm.executor.multiproc_worker_utils.ProcessWorkerWrapper.terminate_worker = ms_terminate_worker + from vllm.executor.ray_distributed_executor import RayDistributedExecutor -import vllm.v1.executor.multiproc_executor -vllm.v1.executor.multiproc_executor.get_mp_context = ms_get_mp_context -import vllm.v1.utils -vllm.v1.utils.get_mp_context = ms_get_mp_context + RayDistributedExecutor._init_workers_ray = ms_init_workers_ray -from vllm_mindspore.executor.ray_gpu_executor import ( - ms_init_workers_ray, - initialize_ray_cluster, -) + vllm.executor.ray_distributed_executor.initialize_ray_cluster = initialize_ray_cluster + vllm.executor.ray_utils.initialize_ray_cluster = initialize_ray_cluster -from vllm.executor.ray_distributed_executor import RayDistributedExecutor + import vllm.engine.llm_engine + import vllm.engine.async_llm_engine -RayDistributedExecutor._init_workers_ray = ms_init_workers_ray + vllm.engine.llm_engine.initialize_ray_cluster = initialize_ray_cluster + vllm.engine.async_llm_engine.initialize_ray_cluster = initialize_ray_cluster -vllm.executor.ray_distributed_executor.initialize_ray_cluster = initialize_ray_cluster -vllm.executor.ray_utils.initialize_ray_cluster = initialize_ray_cluster -import vllm.engine.llm_engine -import vllm.engine.async_llm_engine + from .config import _verify_quantization, _verify_args, vllm_config_post_init, model_post_init, \ + _get_and_verify_dtype, stateless_init_dp_group, has_unfinished_dp -vllm.engine.llm_engine.initialize_ray_cluster = initialize_ray_cluster -vllm.engine.async_llm_engine.initialize_ray_cluster = initialize_ray_cluster + vllm.config.ModelConfig._verify_quantization = _verify_quantization + vllm.config.VllmConfig.__post_init__ = vllm_config_post_init + vllm.config.SchedulerConfig._verify_args = _verify_args + vllm.config.CompilationConfig.model_post_init = model_post_init + vllm.config._get_and_verify_dtype = _get_and_verify_dtype + vllm.config.ParallelConfig.stateless_init_dp_group = stateless_init_dp_group + vllm.config.ParallelConfig.has_unfinished_dp = has_unfinished_dp + from .utils import update_modules + from vllm_mindspore.attention.backends import ms_attn + update_modules("vllm.attention.backends.flash_attn", ms_attn) + + from vllm_mindspore.worker.spec_decode_worker import ( + 
spec_decode_worker_init, + _run_no_spec, + _verify_tokens, + _create_output, + _merge_outputs, + ) + from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker + SpecDecodeWorker.__init__ = spec_decode_worker_init + SpecDecodeWorker._verify_tokens = _verify_tokens + SpecDecodeWorker._run_no_spec = _run_no_spec + + from vllm.model_executor.layers.spec_decode_base_sampler import SpecDecodeBaseSampler + SpecDecodeBaseSampler._create_output = _create_output + + from vllm.spec_decode.top1_proposer import Top1Proposer + Top1Proposer._merge_outputs = _merge_outputs + + from vllm_mindspore.model_executor.layers.rejection_sampler import _smallest_positive_value, _multinomial + from vllm.model_executor.layers.rejection_sampler import RejectionSampler + RejectionSampler._smallest_positive_value = _smallest_positive_value + RejectionSampler._smallest_positive_value.__set_name__(RejectionSampler, '_smallest_positive_value') + vllm.model_executor.layers.rejection_sampler._multinomial = _multinomial + + from vllm_mindspore.v1.sample import rejection_sampler + update_modules("vllm.v1.sample.rejection_sampler", rejection_sampler) + + from vllm_mindspore.v1.spec_decode import eagle + update_modules("vllm.v1.spec_decode.eagle", eagle) + + from vllm_mindspore.v1.attention.backends import flash_attn + import vllm.v1.attention.backends + sys.modules['vllm.v1.attention.backends.flash_attn'] = flash_attn + import vllm.v1.attention.backends.flash_attn -from .config import _verify_quantization, _verify_args, vllm_config_post_init, model_post_init, \ - _get_and_verify_dtype, stateless_init_dp_group, has_unfinished_dp - -vllm.config.ModelConfig._verify_quantization = _verify_quantization -vllm.config.VllmConfig.__post_init__ = vllm_config_post_init -vllm.config.SchedulerConfig._verify_args = _verify_args -vllm.config.CompilationConfig.model_post_init = model_post_init -vllm.config._get_and_verify_dtype = _get_and_verify_dtype -vllm.config.ParallelConfig.stateless_init_dp_group = stateless_init_dp_group -vllm.config.ParallelConfig.has_unfinished_dp = has_unfinished_dp + import vllm.v1.worker.gpu_model_runner + + from vllm_mindspore.v1.worker.gpu_model_runner import _prepare_inputs + vllm.v1.worker.gpu_model_runner.GPUModelRunner._prepare_inputs = _prepare_inputs -from .utils import update_modules -from vllm_mindspore.attention.backends import ms_attn -update_modules("vllm.attention.backends.flash_attn", ms_attn) + from vllm_mindspore.v1.worker.gpu_model_runner import _update_states + vllm.v1.worker.gpu_model_runner.GPUModelRunner._update_states = _update_states -from vllm_mindspore.worker.spec_decode_worker import ( - spec_decode_worker_init, - _run_no_spec, - _verify_tokens, - _create_output, - _merge_outputs, -) -from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker -SpecDecodeWorker.__init__ = spec_decode_worker_init -SpecDecodeWorker._verify_tokens = _verify_tokens -SpecDecodeWorker._run_no_spec = _run_no_spec + from vllm_mindspore.v1.worker.gpu_model_runner import initialize_kv_cache + vllm.v1.worker.gpu_model_runner.GPUModelRunner.initialize_kv_cache = initialize_kv_cache -from vllm.model_executor.layers.spec_decode_base_sampler import SpecDecodeBaseSampler -SpecDecodeBaseSampler._create_output = _create_output - -from vllm.spec_decode.top1_proposer import Top1Proposer -Top1Proposer._merge_outputs = _merge_outputs + import vllm.v1.worker.block_table + from vllm_mindspore.v1.worker.block_table import BlockTable + vllm.v1.worker.block_table.BlockTable = BlockTable + 
vllm.v1.worker.gpu_input_batch.BlockTable = BlockTable -from vllm_mindspore.model_executor.layers.rejection_sampler import _smallest_positive_value, _multinomial -from vllm.model_executor.layers.rejection_sampler import RejectionSampler -RejectionSampler._smallest_positive_value = _smallest_positive_value -RejectionSampler._smallest_positive_value.__set_name__(RejectionSampler, '_smallest_positive_value') -vllm.model_executor.layers.rejection_sampler._multinomial = _multinomial - -######### for multi-model -from vllm_mindspore.inputs.registry import call_hf_processor -from vllm.inputs.registry import InputProcessingContext -InputProcessingContext.call_hf_processor = call_hf_processor - -from vllm_mindspore.multimodal.inputs import as_kwargs -from vllm.multimodal.inputs import MultiModalKwargs -MultiModalKwargs.as_kwargs = as_kwargs - -from vllm_mindspore.model_executor.layers.rotary_embedding import InferMRotaryEmbedding -vllm.model_executor.layers.rotary_embedding.MRotaryEmbedding = InferMRotaryEmbedding - -from vllm_mindspore.v1.sample import rejection_sampler -update_modules("vllm.v1.sample.rejection_sampler", rejection_sampler) - -from vllm_mindspore.v1.spec_decode import eagle -update_modules("vllm.v1.spec_decode.eagle", eagle) - -from vllm_mindspore.v1.attention.backends import flash_attn -import vllm.v1.attention.backends -sys.modules['vllm.v1.attention.backends.flash_attn'] = flash_attn -import vllm.v1.attention.backends.flash_attn - -import vllm.v1.worker.gpu_model_runner - -from vllm_mindspore.v1.worker.gpu_model_runner import _prepare_inputs -vllm.v1.worker.gpu_model_runner.GPUModelRunner._prepare_inputs = _prepare_inputs - -from vllm_mindspore.v1.worker.gpu_model_runner import _update_states -vllm.v1.worker.gpu_model_runner.GPUModelRunner._update_states = _update_states - -from vllm_mindspore.v1.worker.gpu_model_runner import initialize_kv_cache -vllm.v1.worker.gpu_model_runner.GPUModelRunner.initialize_kv_cache = initialize_kv_cache - -import vllm.v1.worker.block_table -from vllm_mindspore.v1.worker.block_table import BlockTable -vllm.v1.worker.block_table.BlockTable = BlockTable -vllm.v1.worker.gpu_input_batch.BlockTable = BlockTable - -import vllm.v1.worker.gpu_input_batch -from vllm_mindspore.v1.worker.gpu_input_batch import _make_sampling_metadata, _make_prompt_token_ids_tensor -vllm.v1.worker.gpu_input_batch.InputBatch._make_sampling_metadata = _make_sampling_metadata -vllm.v1.worker.gpu_model_runner.InputBatch._make_sampling_metadata = _make_sampling_metadata -vllm.v1.worker.gpu_input_batch.InputBatch._make_prompt_token_ids_tensor = _make_prompt_token_ids_tensor -vllm.v1.worker.gpu_model_runner.InputBatch._make_prompt_token_ids_tensor = _make_prompt_token_ids_tensor - -# from vllm.v1.worker.gpu_worker import Worker -# from vllm_mindspore.v1.worker.gpu_worker import init_device - -# Worker.__init__ = wrapper_worker_init(Worker.__init__) -# Worker.init_device = wrapper_worker_init_device(init_device) - - -import vllm.v1.utils -from vllm_mindspore.v1.utils import copy_slice -vllm.v1.utils.copy_slice = copy_slice -vllm.v1.worker.gpu_input_batch.copy_slice = copy_slice - -from vllm_mindspore.v1.sample.ops.penalties import _convert_to_tensors -import vllm.v1.sample.ops.penalties -vllm.v1.sample.ops.penalties._convert_to_tensors = _convert_to_tensors -import vllm.model_executor.layers.utils -from vllm_mindspore.model_executor.layers.utils import apply_penalties -vllm.model_executor.layers.utils.apply_penalties = apply_penalties -vllm.v1.sample.ops.penalties.apply_penalties = 
apply_penalties - - -from vllm_mindspore.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p, random_sample, \ - apply_top_k_only, topk_topp_sampler_forward_native - -import vllm.v1.sample.ops.topk_topp_sampler -from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler -TopKTopPSampler.forward_native = topk_topp_sampler_forward_native -vllm.v1.sample.ops.topk_topp_sampler.apply_top_k_top_p = apply_top_k_top_p -vllm.v1.sample.ops.topk_topp_sampler.random_sample = random_sample -vllm.v1.sample.ops.topk_topp_sampler.apply_top_k_only = apply_top_k_only -from vllm_mindspore.v1.sample.sampler import apply_temperature -import vllm.v1.sample.sampler -vllm.v1.sample.sampler.Sampler.apply_temperature = apply_temperature - -from vllm_mindspore.distributed.shm_broadcast import initialize_ShmRingBuffer -from vllm.distributed.device_communicators.shm_broadcast import ShmRingBuffer -ShmRingBuffer.__init__ = initialize_ShmRingBuffer - -# from vllm_mindspore.v1.worker.gpu_worker import compile_or_warm_up_model -# from vllm.v1.worker.gpu_worker import Worker -# Worker.compile_or_warm_up_model = compile_or_warm_up_model - -from vllm_mindspore.v1.core.sched.scheduler import schedule -from vllm.v1.core.sched.scheduler import Scheduler -Scheduler.schedule = schedule - -# ================ For vllm-ascend ================ -import vllm_ascend.utils -vllm_ascend.utils.vllm_version_is = lambda version: True - -from vllm_mindspore.platforms.ascend import get_attn_backend_cls - -import vllm.platforms - -vllm.platforms.current_platform.get_attn_backend_cls = get_attn_backend_cls - -from vllm_mindspore.worker.cache_engine import ( - ms_allocate_kv_cache, - ms_swap_in, - ms_swap_out, -) - -from vllm_ascend.worker.worker import CacheEngine - -CacheEngine._allocate_kv_cache = ms_allocate_kv_cache -CacheEngine.swap_in = ms_swap_in -CacheEngine.swap_out = ms_swap_out - -from vllm_mindspore.worker.worker import _warm_up_model - -from vllm_mindspore.worker.profile import ( - wrapper_worker_init, - wrapper_worker_init_device, -) - -from vllm_ascend.worker.worker import NPUWorker - -NPUWorker._warm_up_model = _warm_up_model -NPUWorker.__init__ = wrapper_worker_init(NPUWorker.__init__) -NPUWorker.init_device = wrapper_worker_init_device(NPUWorker.init_device) - -from vllm_mindspore.worker.model_runner import profile_run -from vllm_ascend.worker.model_runner import NPUModelRunner - -NPUModelRunner.profile_run = profile_run - -from vllm_mindspore.v1.worker.gpu_worker import compile_or_warm_up_model -from vllm_ascend.worker.worker_v1 import NPUWorker -NPUWorker.compile_or_warm_up_model = compile_or_warm_up_model - -from vllm_mindspore.v1.worker.gpu_worker import init_device - -NPUWorker.__init__ = wrapper_worker_init(NPUWorker.__init__) -NPUWorker.init_device = wrapper_worker_init_device(init_device) - -# ================ End ================ - -# ============ For v1 start =========== -from vllm_mindspore.config import _get_and_verify_dtype -vllm.config._get_and_verify_dtype = _get_and_verify_dtype - -from vllm_mindspore.worker.model_runner_v1 import _dummy_run, _process_reqs -from vllm_ascend.worker.model_runner_v1 import NPUModelRunner -NPUModelRunner._dummy_run = _dummy_run -NPUModelRunner._process_reqs = _process_reqs - -#from vllm.v1.worker.gpu_model_runner import GPUModelRunner -#NPUModelRunner.execute_model = GPUModelRunner.execute_model - -from vllm_mindspore.worker.model_runner_v1 import _prepare_inputs, _update_states, initialize_kv_cache, get_kv_cache_spec -NPUModelRunner._prepare_inputs = _prepare_inputs 
-NPUModelRunner._update_states = _update_states -NPUModelRunner.initialize_kv_cache = initialize_kv_cache -NPUModelRunner.get_kv_cache_spec = get_kv_cache_spec - -from vllm_mindspore.worker.worker_v1 import determine_available_memory -NPUWorker.determine_available_memory = determine_available_memory - -# ============ For v1 end =========== + import vllm.v1.worker.gpu_input_batch + from vllm_mindspore.v1.worker.gpu_input_batch import _make_sampling_metadata, _make_prompt_token_ids_tensor + vllm.v1.worker.gpu_input_batch.InputBatch._make_sampling_metadata = _make_sampling_metadata + vllm.v1.worker.gpu_model_runner.InputBatch._make_sampling_metadata = _make_sampling_metadata + vllm.v1.worker.gpu_input_batch.InputBatch._make_prompt_token_ids_tensor = _make_prompt_token_ids_tensor + vllm.v1.worker.gpu_model_runner.InputBatch._make_prompt_token_ids_tensor = _make_prompt_token_ids_tensor + + from vllm.v1.worker.gpu_worker import Worker + from vllm_mindspore.v1.worker.gpu_worker import init_device + + Worker.__init__ = wrapper_worker_init(Worker.__init__) + Worker.init_device = wrapper_worker_init_device(init_device) + + ######### for multi-model + from vllm_mindspore.inputs.registry import call_hf_processor + from vllm.inputs.registry import InputProcessingContext + InputProcessingContext.call_hf_processor = call_hf_processor + + from vllm_mindspore.multimodal.inputs import as_kwargs + from vllm.multimodal.inputs import MultiModalKwargs + MultiModalKwargs.as_kwargs = as_kwargs + + from vllm_mindspore.model_executor.layers.rotary_embedding import InferMRotaryEmbedding + vllm.model_executor.layers.rotary_embedding.MRotaryEmbedding = InferMRotaryEmbedding + + from vllm_mindspore.v1.sample import rejection_sampler + update_modules("vllm.v1.sample.rejection_sampler", rejection_sampler) + + import vllm.v1.utils + from vllm_mindspore.v1.utils import copy_slice + vllm.v1.utils.copy_slice = copy_slice + vllm.v1.worker.gpu_input_batch.copy_slice = copy_slice + + from vllm_mindspore.v1.sample.ops.penalties import _convert_to_tensors + import vllm.v1.sample.ops.penalties + vllm.v1.sample.ops.penalties._convert_to_tensors = _convert_to_tensors + import vllm.model_executor.layers.utils + from vllm_mindspore.model_executor.layers.utils import apply_penalties + vllm.model_executor.layers.utils.apply_penalties = apply_penalties + vllm.v1.sample.ops.penalties.apply_penalties = apply_penalties + + + from vllm_mindspore.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p, random_sample, \ + apply_top_k_only, topk_topp_sampler_forward_native + + import vllm.v1.sample.ops.topk_topp_sampler + from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler + TopKTopPSampler.forward_native = topk_topp_sampler_forward_native + vllm.v1.sample.ops.topk_topp_sampler.apply_top_k_top_p = apply_top_k_top_p + vllm.v1.sample.ops.topk_topp_sampler.random_sample = random_sample + vllm.v1.sample.ops.topk_topp_sampler.apply_top_k_only = apply_top_k_only + from vllm_mindspore.v1.sample.sampler import apply_temperature + import vllm.v1.sample.sampler + vllm.v1.sample.sampler.Sampler.apply_temperature = apply_temperature + + from vllm_mindspore.distributed.shm_broadcast import initialize_ShmRingBuffer + from vllm.distributed.device_communicators.shm_broadcast import ShmRingBuffer + ShmRingBuffer.__init__ = initialize_ShmRingBuffer + + from vllm_mindspore.v1.worker.gpu_worker import compile_or_warm_up_model + from vllm.v1.worker.gpu_worker import Worker + Worker.compile_or_warm_up_model = compile_or_warm_up_model from .utils import 
check_ready diff --git a/vllm_mindspore/attention/backends/ms_attn.py b/vllm_mindspore/attention/backends/ms_attn.py index 64c94efc..d6123b0a 100644 --- a/vllm_mindspore/attention/backends/ms_attn.py +++ b/vllm_mindspore/attention/backends/ms_attn.py @@ -501,12 +501,17 @@ class MsAttentionMetadataBuilder(AttentionMetadataBuilder[MSAttentionMetadata]): self, seq_lens: List[int], query_lens: List[int], + cuda_graph_pad_size: int, + batch_size: int, ): """Build attention metadata with on-device tensors. Args: seq_lens: The maybe padded sequence lengths of the input sequences. query_lens: The query lengths of the input sequences. + cuda_graph_pad_size: The padding size for cuda graph. + -1 if cuda graph is not used. + batch_size: The maybe padded batch size. """ prefix_cache_hit = any( [ @@ -520,6 +525,7 @@ class MsAttentionMetadataBuilder(AttentionMetadataBuilder[MSAttentionMetadata]): ) device = self.runner.device + use_captured_graph = cuda_graph_pad_size != -1 max_query_len = max(query_lens) decode_query_lens = query_lens[self.num_prefills:] @@ -533,12 +539,15 @@ class MsAttentionMetadataBuilder(AttentionMetadataBuilder[MSAttentionMetadata]): query_start_loc = list(accumulate(query_lens, initial=0)) seq_start_loc = list(accumulate(seq_lens, initial=0)) - block_tables = make_tensor_with_pad( - self.block_tables, - pad=-1, - dtype=torch.int, - device=device, - ) + if use_captured_graph: + raise RuntimeError("Doesnot support captured graph now!") + else: + block_tables = make_tensor_with_pad( + self.block_tables, + pad=-1, + dtype=torch.int, + device=device, + ) assert max_query_len > 0, "query_lens: {}".format(query_lens) context_lens_tensor = ms.Tensor(self.context_lens, dtype=ms.int32) @@ -587,10 +596,6 @@ class MsAttentionBackend(AttentionBackend): def get_builder_cls() -> Type["MsAttentionMetadataBuilder"]: return MsAttentionMetadataBuilder - @classmethod - def make_metadata_builder(cls, *args, **kwargs) -> "MsAttentionMetadataBuilder": - return cls.get_builder_cls()(*args, **kwargs) - @staticmethod def get_state_cls() -> Type["AttentionState"]: return MsAttentionState diff --git a/vllm_mindspore/attention/backends/ms_attn_v1.py b/vllm_mindspore/attention/backends/ms_attn_v1.py deleted file mode 100644 index 731a28d6..00000000 --- a/vllm_mindspore/attention/backends/ms_attn_v1.py +++ /dev/null @@ -1,210 +0,0 @@ -#!/usr/bin/env python3 -# encoding: utf-8 -# Copyright 2025 Huawei Technologies Co., Ltd -# Copyright 2024 The vLLM team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================ -from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Type -import numpy as np - -import torch -from mindspore import mutable -from mindspore._c_expression import swap_cache - -from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata, AttentionType) -from vllm.logger import init_logger - -from vllm_mindspore.utils import MsKVCache - -logger = init_logger(__name__) - - -class MsAttentionBackend(AttentionBackend): - - accept_output_buffer: bool = True - - @staticmethod - def get_supported_head_sizes() -> List[int]: - return [32, 64, 96, 128, 160, 192, 224, 256] - - @staticmethod - def get_name() -> str: - return "MS_ATTN" - - @staticmethod - def get_impl_cls() -> Type["AttentionImpl"]: - return MsAttentionImpl - - @staticmethod - def get_metadata_cls() -> Type["AttentionMetadata"]: - return MSAttentionMetadata - - @staticmethod - def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, - head_size: int, - ) -> Tuple[int, ...]: - if block_size % 16 != 0: - raise ValueError("Block size must be a multiple of 16.") - return (2, num_blocks, block_size, num_kv_heads, head_size) - - @staticmethod - def use_cascade_attention(*args, **kwargs) -> bool: - return False - - -class MLABackend(AttentionBackend): - @staticmethod - def get_name() -> str: - return "MS_MLA" - - @staticmethod - def get_impl_cls() -> Type["AttentionImpl"]: - return MsAttentionImpl - - @staticmethod - def get_metadata_cls() -> Type["AttentionMetadata"]: - return MSAttentionMetadata - - @staticmethod - def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, - head_size: int, - ) -> Tuple[int, ...]: - if block_size % 16 != 0: - raise ValueError("Block size must be a multiple of 16.") - return (1, num_blocks, block_size, 1, head_size) - - @staticmethod - def use_cascade_attention(*args, **kwargs) -> bool: - return False - - @staticmethod - def get_supported_head_sizes() -> List[int]: - return [576] - - -@dataclass -class MSAttentionMetadata: - # NOTE(sang): Definition of context_len, query_len, and seq_len. 
- # |---------- N-1 iteration --------| - # |---------------- N iteration ---------------------| - # |- tokenA -|......................|-- newTokens ---| - # |---------- context_len ----------| - # |-------------------- seq_len ---------------------| - # |-- query_len ---| - - # max_seq_len: int - # seq_lens: torch.Tensor - # seq_lens_np: np.ndarray - # block_tables: torch.Tensor - # slot_mapping: torch.Tensor - # q_seq_lens: torch.Tensor - # context_lens: torch.Tensor - # max_context_lens: int - - # def __getitem__(self, key): - # if key == "batch_valid_length": - # key = "seq_lens" - # if key == "block_tables": - # if getattr(self, key).ndim == 1: - # return mutable(getattr(self, key).expand_dims(0)) - # return mutable(getattr(self, key)) - # return getattr(self, key) - - # AscendMetadata - block_tables: Optional[torch.Tensor] - seq_lens: Optional[List[int]] = None - context_lens: Optional[List[int]] = None - max_query_len: Optional[int] = None - slot_mapping: torch.Tensor = None - is_only_prefill: bool = False - attn_mask: Optional[torch.Tensor] = None - - # add for mindspore - num_decode_tokens: int = 0 - query_lens: Optional[List[int]] = None - - -class MsAttentionImpl(AttentionImpl): - """ - If the input tensors contain prompt tokens, the layout is as follows: - |<--------------- num_prefill_tokens ----------------->| - |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->| - - Otherwise, the layout is as follows: - |<----------------- num_decode_tokens ------------------>| - |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->| - - Generation tokens can contain padding when cuda-graph is used. - Currently, prompt tokens don't contain any padding. - - The prompts might have different lengths, while the generation tokens - always have length 1. - - If chunked prefill is enabled, prefill tokens and decode tokens can be - batched together in a flattened 1D query. - - |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->| - |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->| - - Currently, cuda graph is disabled for chunked prefill, meaning there's no - padding between prefill and decode tokens. - """ - - def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: int, - alibi_slopes: Optional[List[float]], - sliding_window: Optional[int], - kv_cache_dtype: str, - blocksparse_params: Optional[Dict[str, Any]] = None, - logits_soft_cap: Optional[float] = None, - attn_type: AttentionType = AttentionType.DECODER, - ) -> None: - pass - - def forward( - self, - layer: torch.nn.Module, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: MSAttentionMetadata, - output: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """Forward pass with FlashAttention. - - Args: - query: shape = [num_tokens, num_heads, head_size] - key: shape = [num_tokens, num_kv_heads, head_size] - value: shape = [num_tokens, num_kv_heads, head_size] - output: shape = [num_tokens, num_heads, head_size] - kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size] - NOTE: kv_cache will be an empty tensor with shape [0] - for profiling run. - attn_metadata: Metadata for attention. - NOTE: It in-place updates the output tensor. 
- """ - pass diff --git a/vllm_mindspore/patch/__init__.py b/vllm_mindspore/patch/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/vllm_mindspore/patch/patch_vllm_ascend.py b/vllm_mindspore/patch/patch_vllm_ascend.py new file mode 100644 index 00000000..78ec8872 --- /dev/null +++ b/vllm_mindspore/patch/patch_vllm_ascend.py @@ -0,0 +1,274 @@ +import sys +# ================ For vllm ================ + +import vllm.utils + +from vllm_mindspore.utils import ( + direct_register_custom_op, + make_tensor_with_pad, + async_tensor_h2d, + get_dtype_size, + ascend_device_count_stateless, + ascend_is_initialized, +) + +vllm.utils.direct_register_custom_op = direct_register_custom_op +vllm.utils.make_tensor_with_pad = make_tensor_with_pad +vllm.utils.async_tensor_h2d = async_tensor_h2d +vllm.utils.get_dtype_size = get_dtype_size +vllm.utils.cuda_device_count_stateless = ascend_device_count_stateless +vllm.utils.cuda_is_initialized = ascend_is_initialized +vllm.config.cuda_device_count_stateless = ascend_device_count_stateless + +import vllm.executor + +vllm.executor.cuda_device_count_stateless = ascend_device_count_stateless + +from vllm_mindspore.model_executor.models.registry import ( + MindSporeModelRegistry, + _SUBPROCESS_COMMAND, +) + + +vllm.config.ModelRegistry = MindSporeModelRegistry + +import vllm.model_executor + +vllm.model_executor.models.ModelRegistry = MindSporeModelRegistry +vllm.model_executor.models.registry._SUBPROCESS_COMMAND = _SUBPROCESS_COMMAND + +from vllm_mindspore.model_executor.model_loader.utils import get_ms_model_architecture + +# To patching the get_model_architecture, should import it first. +from vllm.model_executor.model_loader import get_model_architecture + +vllm.model_executor.model_loader.get_model_architecture = get_ms_model_architecture +vllm.model_executor.model_loader.utils.get_model_architecture = ( + get_ms_model_architecture +) +vllm.model_executor.model_loader.loader.get_model_architecture = ( + get_ms_model_architecture +) + +from vllm_mindspore.model_executor.sampling_metadata import ( + SequenceGroupToSample, + SamplingMetadataCache, + SamplingMetadata, +) + +vllm.model_executor.SamplingMetadataCache = SamplingMetadataCache +vllm.model_executor.SamplingMetadata = SamplingMetadata +vllm.model_executor.sampling_metadata.SequenceGroupToSample = SequenceGroupToSample +vllm.model_executor.sampling_metadata.SamplingMetadataCache = SamplingMetadataCache +vllm.model_executor.sampling_metadata.SamplingMetadata = SamplingMetadata + +from vllm_mindspore.model_executor.model_loader.weight_utils import ( + safetensors_weights_iterator, +) + +vllm.model_executor.model_loader.loader.safetensors_weights_iterator = ( + safetensors_weights_iterator +) + +from vllm_mindspore.worker.model_runner import ( + _get_cuda_graph_pad_size, + _dummy_run, + _get_supported_attention_backends, +) + +import vllm.worker.model_runner +vllm.worker.model_runner.ModelInputForGPUBuilder._get_cuda_graph_pad_size = ( + _get_cuda_graph_pad_size +) +vllm.worker.model_runner.GPUModelRunnerBase._dummy_run = _dummy_run + +import vllm.worker.multi_step_model_runner + +vllm.worker.multi_step_model_runner._get_supported_attention_backends = ( + _get_supported_attention_backends +) + +from vllm_mindspore.executor.multiproc_worker_utils import ( + get_mp_context as ms_get_mp_context, +) + +from vllm_mindspore.executor.ray_gpu_executor import ( + ms_init_workers_ray, + initialize_ray_cluster, +) + +from vllm.executor.ray_distributed_executor import RayDistributedExecutor + 
+RayDistributedExecutor._init_workers_ray = ms_init_workers_ray + +vllm.executor.ray_distributed_executor.initialize_ray_cluster = initialize_ray_cluster +vllm.executor.ray_utils.initialize_ray_cluster = initialize_ray_cluster + +import vllm.engine.llm_engine +import vllm.engine.async_llm_engine + +vllm.engine.llm_engine.initialize_ray_cluster = initialize_ray_cluster +vllm.engine.async_llm_engine.initialize_ray_cluster = initialize_ray_cluster + + +from vllm_mindspore.config import _verify_quantization, _verify_args, vllm_config_post_init, model_post_init, \ + _get_and_verify_dtype, stateless_init_dp_group, has_unfinished_dp + +vllm.config.ModelConfig._verify_quantization = _verify_quantization +vllm.config.VllmConfig.__post_init__ = vllm_config_post_init +vllm.config.SchedulerConfig._verify_args = _verify_args +vllm.config.CompilationConfig.model_post_init = model_post_init +vllm.config._get_and_verify_dtype = _get_and_verify_dtype +vllm.config.ParallelConfig.stateless_init_dp_group = stateless_init_dp_group +vllm.config.ParallelConfig.has_unfinished_dp = has_unfinished_dp + +from vllm_mindspore.utils import update_modules +from vllm_mindspore.attention.backends import ms_attn +update_modules("vllm.attention.backends.flash_attn", ms_attn) + +from vllm_mindspore.worker.spec_decode_worker import ( + spec_decode_worker_init, + _run_no_spec, + _verify_tokens, + _create_output, + _merge_outputs, +) +from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker +SpecDecodeWorker.__init__ = spec_decode_worker_init +SpecDecodeWorker._verify_tokens = _verify_tokens +SpecDecodeWorker._run_no_spec = _run_no_spec + +from vllm.model_executor.layers.spec_decode_base_sampler import SpecDecodeBaseSampler +SpecDecodeBaseSampler._create_output = _create_output + +from vllm.spec_decode.top1_proposer import Top1Proposer +Top1Proposer._merge_outputs = _merge_outputs + +from vllm_mindspore.model_executor.layers.rejection_sampler import _smallest_positive_value, _multinomial +from vllm.model_executor.layers.rejection_sampler import RejectionSampler +RejectionSampler._smallest_positive_value = _smallest_positive_value +RejectionSampler._smallest_positive_value.__set_name__(RejectionSampler, '_smallest_positive_value') +vllm.model_executor.layers.rejection_sampler._multinomial = _multinomial + +from vllm_mindspore.v1.sample import rejection_sampler +update_modules("vllm.v1.sample.rejection_sampler", rejection_sampler) + +from vllm_mindspore.v1.spec_decode import eagle +update_modules("vllm.v1.spec_decode.eagle", eagle) + +from vllm_mindspore.v1.attention.backends import flash_attn +import vllm.v1.attention.backends +sys.modules['vllm.v1.attention.backends.flash_attn'] = flash_attn +import vllm.v1.attention.backends.flash_attn + +import vllm.v1.worker.gpu_model_runner + +import vllm.v1.worker.block_table +from vllm_mindspore.v1.worker.block_table import BlockTable +vllm.v1.worker.block_table.BlockTable = BlockTable +vllm.v1.worker.gpu_input_batch.BlockTable = BlockTable + +import vllm.v1.worker.gpu_input_batch +from vllm_mindspore.v1.worker.gpu_input_batch import _make_sampling_metadata, _make_prompt_token_ids_tensor +vllm.v1.worker.gpu_input_batch.InputBatch._make_sampling_metadata = _make_sampling_metadata +vllm.v1.worker.gpu_model_runner.InputBatch._make_sampling_metadata = _make_sampling_metadata +vllm.v1.worker.gpu_input_batch.InputBatch._make_prompt_token_ids_tensor = _make_prompt_token_ids_tensor +vllm.v1.worker.gpu_model_runner.InputBatch._make_prompt_token_ids_tensor = _make_prompt_token_ids_tensor + 
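The new patch module above works almost entirely by attribute assignment: a plain function defined in vllm_mindspore is assigned onto an existing vLLM or vllm-ascend class, and Python's normal attribute lookup then serves it as a bound method on every instance. A minimal standalone sketch of that pattern follows; `Engine` and `patched_step` are hypothetical stand-ins for illustration, not real vLLM symbols.

    class Engine:                       # stand-in for a vLLM class such as Scheduler
        def step(self):
            return "original"

    def patched_step(self):             # replacement defined in a separate module
        return "patched"

    # Same shape as `Scheduler.schedule = schedule` above: assigning a plain
    # function to the class rebinds the method for existing and future
    # instances alike, so the patch only needs to run once at import time.
    Engine.step = patched_step
    assert Engine().step() == "patched"

Because method lookup happens at call time, code that constructed its objects before (or without ever importing) the patch module still picks up the replacement, which is why these assignments can live in a single side-effect-only import.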
+import vllm.v1.utils +from vllm_mindspore.v1.utils import copy_slice +vllm.v1.utils.copy_slice = copy_slice +vllm.v1.worker.gpu_input_batch.copy_slice = copy_slice + +from vllm_mindspore.v1.sample.ops.penalties import _convert_to_tensors +import vllm.v1.sample.ops.penalties +vllm.v1.sample.ops.penalties._convert_to_tensors = _convert_to_tensors +import vllm.model_executor.layers.utils +from vllm_mindspore.model_executor.layers.utils import apply_penalties +vllm.model_executor.layers.utils.apply_penalties = apply_penalties +vllm.v1.sample.ops.penalties.apply_penalties = apply_penalties + + +from vllm_mindspore.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p, random_sample, \ + apply_top_k_only, topk_topp_sampler_forward_native + +import vllm.v1.sample.ops.topk_topp_sampler +from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler +TopKTopPSampler.forward_native = topk_topp_sampler_forward_native +vllm.v1.sample.ops.topk_topp_sampler.apply_top_k_top_p = apply_top_k_top_p +vllm.v1.sample.ops.topk_topp_sampler.random_sample = random_sample +vllm.v1.sample.ops.topk_topp_sampler.apply_top_k_only = apply_top_k_only +from vllm_mindspore.v1.sample.sampler import apply_temperature +import vllm.v1.sample.sampler +vllm.v1.sample.sampler.Sampler.apply_temperature = apply_temperature + +from vllm_mindspore.distributed.shm_broadcast import initialize_ShmRingBuffer +from vllm.distributed.device_communicators.shm_broadcast import ShmRingBuffer +ShmRingBuffer.__init__ = initialize_ShmRingBuffer + +from vllm_mindspore.v1.worker.gpu_worker import compile_or_warm_up_model +from vllm.v1.worker.gpu_worker import Worker +Worker.compile_or_warm_up_model = compile_or_warm_up_model + +# ================ For vllm-ascend ================ +import types +fake_mod = types.ModuleType("vllm_ascend.vllm_ascend_C") +fake_mod.init_module = fake_mod.python_create_and_map = fake_mod.python_unmap_and_release = lambda *a, **kw: None +sys.modules.update({"vllm_ascend.vllm_ascend_C": fake_mod}) + +import vllm_ascend.utils +from vllm_mindspore.utils import vllm_version_is +vllm_ascend.utils.vllm_version_is = vllm_version_is + +from vllm_mindspore.platforms.ascend import get_attn_backend_cls +from vllm_ascend.platform import NPUPlatform +NPUPlatform.get_attn_backend_cls = get_attn_backend_cls + +from vllm_mindspore.worker.cache_engine import ( + ms_allocate_kv_cache, + ms_swap_in, + ms_swap_out, +) + +from vllm_ascend.worker.worker import CacheEngine + +CacheEngine._allocate_kv_cache = ms_allocate_kv_cache +CacheEngine.swap_in = ms_swap_in +CacheEngine.swap_out = ms_swap_out + +from vllm_mindspore.worker.worker import _warm_up_model + +from vllm_mindspore.worker.profile import ( + wrapper_worker_init, + wrapper_worker_init_device, +) + +from vllm_ascend.worker.worker import NPUWorker + +NPUWorker._warm_up_model = _warm_up_model +NPUWorker.__init__ = wrapper_worker_init(NPUWorker.__init__) +NPUWorker.init_device = wrapper_worker_init_device(NPUWorker.init_device) + +# ================ End ================ + +# ============ For v1 start =========== +from vllm_mindspore.config import _get_and_verify_dtype +vllm.config._get_and_verify_dtype = _get_and_verify_dtype + +from vllm_mindspore.worker.model_runner_v1 import _dummy_run, _process_reqs, wrapper_runner_init +from vllm_ascend.worker.model_runner_v1 import NPUModelRunner +NPUModelRunner._dummy_run = _dummy_run +NPUModelRunner._process_reqs = _process_reqs +NPUModelRunner.__init__ = wrapper_runner_init(NPUModelRunner.__init__) + +from vllm_mindspore.worker.worker_v1 
import determine_available_memory +from vllm_ascend.worker.worker_v1 import NPUWorker +NPUWorker.determine_available_memory = determine_available_memory + +from vllm_mindspore.v1.worker.gpu_model_runner import initialize_kv_cache +NPUModelRunner.initialize_kv_cache = initialize_kv_cache + +from vllm_mindspore.v1.worker.gpu_worker import compile_or_warm_up_model +from vllm_ascend.worker.worker_v1 import NPUWorker +NPUWorker.compile_or_warm_up_model = compile_or_warm_up_model +# ============ For v1 end =========== \ No newline at end of file diff --git a/vllm_mindspore/platforms/ascend.py b/vllm_mindspore/platforms/ascend.py index 89a41646..d4b71afc 100644 --- a/vllm_mindspore/platforms/ascend.py +++ b/vllm_mindspore/platforms/ascend.py @@ -145,14 +145,14 @@ class AscendPlatform(Platform): def supports_v1(cls, model_config: ModelConfig) -> bool: return True -def get_attn_backend_cls(selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1, use_mla): +def get_attn_backend_cls(cls, selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1, use_mla): """Get the attention backend class of a device.""" if use_v1: if use_mla: logger.info("Using MindSpore MLA backend for V1.") - return "vllm_mindspore.attention.backends.ms_attn_v1.MLABackend" + return "vllm_mindspore.v1.attention.backends.flash_attn.MLABackend" logger.info("Using MindSpore Attention backend for V1.") - return "vllm_mindspore.attention.backends.ms_attn_v1.MsAttentionBackend" + return "vllm_mindspore.v1.attention.backends.flash_attn.FlashAttentionBackend" if use_mla: logger.info("Using MindSpore MLA backend.") return "vllm_mindspore.attention.backends.ms_attn.MLABackend" diff --git a/vllm_mindspore/utils.py b/vllm_mindspore/utils.py index af65c4de..9d488d7c 100644 --- a/vllm_mindspore/utils.py +++ b/vllm_mindspore/utils.py @@ -30,6 +30,8 @@ if TYPE_CHECKING: else: Library = None +from packaging.version import Version + from vllm.logger import init_logger import mindspore as ms @@ -294,3 +296,10 @@ def ms_memory_profiling( result.non_torch_increase = diff_from_create.non_torch_memory result.profile_time = diff_profile.timestamp result.non_kv_cache_memory = result.non_torch_increase + result.torch_peak_increase + result.weights_memory # noqa + + +def vllm_version_is(version: str): + import vllm + if vllm.__version__ == '0.8.3': # since vllm-ascend support from 0.8.4, 0.8.3 should be supported too. + return True + return Version(vllm.__version__) == Version(version) diff --git a/vllm_mindspore/v1/worker/gpu_worker.py b/vllm_mindspore/v1/worker/gpu_worker.py index 0395c339..e6543b1c 100644 --- a/vllm_mindspore/v1/worker/gpu_worker.py +++ b/vllm_mindspore/v1/worker/gpu_worker.py @@ -5,6 +5,7 @@ import gc import torch from vllm.logger import init_logger from vllm.distributed.parallel_state import get_pp_group +from vllm.model_executor import set_random_seed logger = init_logger(__name__) @@ -48,5 +49,6 @@ def compile_or_warm_up_model(self) -> None: # Since prefill is done previously, we do decode here. default_max_num_reqs = 1 # For MindSpore, we only do one more decode here. 
if get_pp_group().is_last_rank: - self.model_runner._dummy_sampler_run(self.model_runner._dummy_run( - num_tokens=default_max_num_reqs)) + self.model_runner._dummy_run( + num_tokens=default_max_num_reqs) + set_random_seed(self.model_config.seed) diff --git a/vllm_mindspore/worker/model_runner.py b/vllm_mindspore/worker/model_runner.py index 56e425aa..68a9360c 100644 --- a/vllm_mindspore/worker/model_runner.py +++ b/vllm_mindspore/worker/model_runner.py @@ -19,9 +19,9 @@ from typing import List import torch - from vllm.distributed import get_pp_group from vllm.logger import init_logger +from vllm.lora.request import LoRARequest from vllm.sampling_params import SamplingParams from vllm.sequence import SequenceGroupMetadata from vllm_mindspore.utils import STR_DTYPE_TO_TENSOR_DTYPE @@ -132,7 +132,8 @@ def _dummy_run(self, # tensor aliasing. kv_cache_dtype = self.model_config.dtype if self.cache_config.cache_dtype == "auto" \ else self.cache_config.cache_dtype - kv_cache_dtype = STR_DTYPE_TO_TENSOR_DTYPE[kv_cache_dtype] + if kv_cache_dtype in STR_DTYPE_TO_TENSOR_DTYPE: + kv_cache_dtype = STR_DTYPE_TO_TENSOR_DTYPE[kv_cache_dtype] block_size = self.cache_config.block_size num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config) head_size = self.model_config.get_head_size() @@ -179,97 +180,3 @@ def _get_supported_attention_backends(chunked_prefill_enabled: bool) \ return MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS else: return MULTI_STEP_ATTENTION_BACKENDS - - -def profile_run(self) -> None: - # Enable top-k sampling to reflect the accurate memory usage. - sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1) - max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens - max_num_seqs = self.scheduler_config.max_num_seqs - - # Profile memory usage with max_num_sequences sequences and the total - # number of tokens equal to max_num_batched_tokens. - seqs: List[SequenceGroupMetadata] = [] - # Additional GPU memory may be needed for multi-modal encoding, which - # needs to be accounted for when calculating the GPU blocks for - # vLLM blocker manager. - # To exercise the worst scenario for GPU memory consumption, - # the number of seqs (batch_size) is chosen to maximize the number - # of images processed. - - max_mm_tokens = self.mm_registry.get_max_multimodal_tokens( - self.model_config) - if max_mm_tokens > 0: - max_num_seqs_orig = max_num_seqs - max_num_seqs = min(max_num_seqs, - max_num_batched_tokens // max_mm_tokens) - if max_num_seqs < 1: - expr = (f"min({max_num_seqs_orig}, " - f"{max_num_batched_tokens} // {max_mm_tokens})") - logger.warning( - "Computed max_num_seqs (%s) to be less than 1. " - "Setting it to the minimum value of 1.", expr) - max_num_seqs = 1 - - batch_size = 0 - for group_id in range(max_num_seqs): - seq_len = (max_num_batched_tokens // max_num_seqs + - (group_id < max_num_batched_tokens % max_num_seqs)) - batch_size += seq_len - - dummy_data = self.input_registry \ - .dummy_data_for_profiling(self.model_config, - seq_len, - self.mm_registry) - - seq = SequenceGroupMetadata( - request_id=str(group_id), - is_prompt=True, - seq_data={group_id: dummy_data.seq_data}, - sampling_params=sampling_params, - block_tables=None, - lora_request=None, - multi_modal_data=dummy_data.multi_modal_data, - multi_modal_placeholders=dummy_data.multi_modal_placeholders, - ) - seqs.append(seq) - - # Run the model with the dummy inputs. 
- num_layers = self.model_config.get_num_layers(self.parallel_config) - # use an empty tensor instead of `None`` to force Dynamo to pass - # it by reference, rather by specializing on the value ``None``. - # the `dtype` argument does not matter, and we use `float32` as - # a placeholder (it has wide hardware support). - # it is important to create tensors inside the loop, rather than - # multiplying the list, to avoid Dynamo from treating them as - # tensor aliasing. - - kv_cache_dtype = self.model_config.dtype if self.cache_config.cache_dtype == "auto" \ - else self.cache_config.cache_dtype - kv_cache_dtype = STR_DTYPE_TO_TENSOR_DTYPE[kv_cache_dtype] - block_size = self.cache_config.block_size - num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config) - head_size = self.model_config.get_head_size() - kv_shape = [0, block_size, num_kv_heads, head_size] - kv_caches = mutable([ - mutable(( - mutable(torch.tensor([], dtype=kv_cache_dtype, device=self.device).reshape(kv_shape)), - mutable(torch.tensor([], dtype=kv_cache_dtype, device=self.device).reshape(kv_shape)), - )) - for _ in range(num_layers) - ]) - finished_requests_ids = [seq.request_id for seq in seqs] - model_input = self.prepare_model_input( - seqs, finished_requests_ids=finished_requests_ids) - intermediate_tensors = None - if not get_pp_group().is_first_rank: - intermediate_tensors = self.model.make_empty_intermediate_tensors( - batch_size=batch_size, - dtype=self.model_config.dtype, - device=self.device) - - self.execute_model(model_input, kv_caches, intermediate_tensors) - - from vllm.platforms import current_platform - current_platform.synchronize() - return diff --git a/vllm_mindspore/worker/model_runner_v1.py b/vllm_mindspore/worker/model_runner_v1.py index 0e857c5d..10d7e2d2 100644 --- a/vllm_mindspore/worker/model_runner_v1.py +++ b/vllm_mindspore/worker/model_runner_v1.py @@ -17,17 +17,21 @@ # ============================================================================ from typing import List, Optional, Tuple import numpy as np +import weakref import torch import mindspore as ms from mindspore import Tensor +from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.distributed.parallel_state import get_pp_group from vllm.forward_context import set_forward_context from vllm.sequence import IntermediateTensors +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.attention.layer import Attention -from vllm_mindspore.attention.backends.ms_attn_v1 import MSAttentionMetadata, MsAttentionBackend +from vllm_mindspore.v1.attention.backends.flash_attn import FlashAttentionMetadata #################33 from mindspore import mutable @@ -51,12 +55,30 @@ from vllm.sampling_params import SamplingType logger = init_logger(__name__) +def wrapper_runner_init(func): + def wrapper(*args, **kwargs): + func(*args, **kwargs) + self = args[0] + self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1, + dtype=torch.int32, + device="cpu", + pin_memory=True) + self.query_start_loc_np = self.query_start_loc_cpu.numpy() + + import weakref + self.attn_metadata_builder = self.attn_backend.get_builder_cls()( + weakref.proxy(self)) + + return wrapper + + @torch.inference_mode() def _dummy_run( self, - num_tokens: int, - dummy_kv_caches: List[torch.Tensor], + num_tokens: int = None, ) -> torch.Tensor: + if num_tokens is None: + num_tokens = self.max_num_tokens model = self.model if self.is_multimodal_model: input_ids = None @@ -68,7 +90,8 @@ def _dummy_run( if self.uses_mrope: positions = 
self.mrope_positions[:, :num_tokens] else: - positions = self.input_positions_cpu[:num_tokens] + positions_np = self.positions_np[:num_tokens] + positions = torch.from_numpy(positions_np) if get_pp_group().is_first_rank: intermediate_tensors = None @@ -84,163 +107,33 @@ def _dummy_run( for k, v in self.intermediate_tensors.items() }) - if dummy_kv_caches is None: # for compile_or_warm_up_model - attn_metadata = _dummy_attention_metadata(input_ids, positions, False) - else: - attn_metadata = _dummy_attention_metadata(input_ids, positions, True) - - with set_forward_context(attn_metadata, self.vllm_config): + with set_forward_context(None, self.vllm_config): hidden_states = model(input_ids=input_ids, - #positions=positions.to(self.device), positions=positions, intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds, - kv_caches=dummy_kv_caches, - attn_metadata=attn_metadata) + inputs_embeds=inputs_embeds) return hidden_states -def _dummy_attention_metadata(input_ids: Tensor, positions: Tensor, is_prefill=True) -> MSAttentionMetadata: - input_len = input_ids.shape[0] - # max_seq_len = ms.Tensor(input_len, dtype=ms.int32) - # seq_lengths = ms.Tensor([input_len], dtype=ms.int32) - # q_seq_lens = ms.Tensor([input_len], dtype=ms.int32) - seq_lens_np = np.array([input_len], dtype=np.int32) - - block_tables = ms.Tensor([[0]], dtype=ms.int32) - slot_mapping = [-1 for _ in range(input_len)] - slot_mapping = ms.Tensor(slot_mapping, dtype=ms.int32) - return MSAttentionMetadata( - block_tables=block_tables, - seq_lens=seq_lens_np, - context_lens=0, - max_query_len=1, - slot_mapping=slot_mapping, - num_decode_tokens=0 if is_prefill else 1, - query_lens=seq_lens_np, - ) - - def _process_reqs( self, scheduler_output: "SchedulerOutput", intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: - # Check input valid total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens assert total_num_scheduled_tokens > 0 num_reqs = self.input_batch.num_reqs assert num_reqs > 0 - # Copy the blocks from CPU to NPU. - # OPTIMIZATION: Start copying the block table first. - # This way, we can overlap the copy with the following CPU operations. - self.input_batch.block_table.commit(num_reqs) - - # Get the number of scheduled tokens for each request. - # TODO: The Python loop can be slow. Optimize. 
- num_scheduled_tokens = np.empty(num_reqs, dtype=np.int32) - max_num_scheduled_tokens = 0 - for i, req_id in enumerate(self.input_batch.req_ids): - num_tokens = scheduler_output.num_scheduled_tokens[req_id] - num_scheduled_tokens[i] = num_tokens - # max_num_scheduled_tokens = max(max_num_scheduled_tokens, - # num_tokens) - - # Prepare positions - req_indices = np.repeat(self.arange_np[:num_reqs], - num_scheduled_tokens) - cu_num_tokens = np.cumsum(num_scheduled_tokens) - cumsums_offsets = np.repeat(cu_num_tokens - num_scheduled_tokens, - num_scheduled_tokens) - arange = self.arange_np[:total_num_scheduled_tokens] - cumsums_offsets - - positions_np = self.positions_np[:total_num_scheduled_tokens] - np.add(self.input_batch.num_computed_tokens_cpu[req_indices], - arange, - out=positions_np) - - self.positions[:total_num_scheduled_tokens].copy_( - self.positions_cpu[:total_num_scheduled_tokens], non_blocking=True) - positions = self.positions[:total_num_scheduled_tokens] - - self.seq_lens_np[:num_reqs] = ( - self.input_batch.num_computed_tokens_cpu[:num_reqs] + - num_scheduled_tokens) - seq_lens = self.seq_lens_cpu[:num_reqs] - - query_lens = torch.from_numpy(num_scheduled_tokens) - - block_table_indices = (req_indices * self.max_num_blocks_per_req + - positions_np // self.block_size) - block_table_cpu = self.input_batch.block_table.get_cpu_tensor() - block_numbers = block_table_cpu.flatten()[block_table_indices].numpy() - block_offsets = positions_np % self.block_size - np.add(block_numbers * self.block_size, - block_offsets, - out=self.slot_mapping_np[:total_num_scheduled_tokens]) - slot_mapping = self.slot_mapping_cpu[:total_num_scheduled_tokens].to( - self.device, non_blocking=True) - - # attn_mask = self.make_attention_mask(seq_lens=seq_lens, - # query_lens=num_scheduled_tokens, - # position=positions) - - num_decode_tokens = self.input_batch.num_computed_tokens_cpu[:num_reqs].max() - - attn_metadata = MSAttentionMetadata( - seq_lens=query_lens, - context_lens=seq_lens, - slot_mapping=slot_mapping, - block_tables=( - self.input_batch.block_table.get_device_tensor()[:num_reqs]), - # attn_mask=attn_mask - num_decode_tokens=num_decode_tokens, - query_lens=query_lens - ) - - # Prepare input_ids - token_indices = (positions_np + - req_indices * self.input_batch.token_ids_cpu.shape[1]) - torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(), - 0, - torch.from_numpy(token_indices), - out=self.input_ids_cpu[:total_num_scheduled_tokens]) - # Copy the tensors to the NPU. 
- self.input_ids[:total_num_scheduled_tokens].copy_( - self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True) - input_ids = self.input_ids[:total_num_scheduled_tokens] - - # Run forward pass - with set_forward_context(attn_metadata, self.vllm_config): - assert self.model is not None - hidden_states = self.model( - input_ids=input_ids, - positions=positions, - intermediate_tensors=intermediate_tensors, - inputs_embeds=None, - kv_caches=self.kv_caches, - attn_metadata=attn_metadata, - ) - - return hidden_states[cu_num_tokens - 1] - - -def _prepare_inputs( - self, - scheduler_output: "SchedulerOutput", -) -> Tuple[MSAttentionMetadata, torch.Tensor]: - total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens - assert total_num_scheduled_tokens > 0 - num_reqs = self.input_batch.num_reqs - assert num_reqs > 0 + modified_batch = self.attn_metadata_builder.reorder_batch( + self.input_batch, scheduler_output) + if modified_batch: + self.input_batch.refresh_sampling_metadata() # OPTIMIZATION: Start copying the block table first. # This way, we can overlap the copy with the following CPU operations. self.input_batch.block_table.commit(num_reqs) - # context_lens = ms.Tensor(self.input_batch.num_computed_tokens_cpu[:num_reqs], dtype=torch.int32) - context_lens = ms.from_numpy(self.input_batch.num_computed_tokens_cpu[:num_reqs]) - context_lens.move_to("Ascend", blocking=False) + # Get the number of scheduled tokens for each request. # TODO: The Python loop can be slow. Optimize. num_scheduled_tokens = np.empty(num_reqs, dtype=np.int32) @@ -251,11 +144,6 @@ def _prepare_inputs( max_num_scheduled_tokens = max(max_num_scheduled_tokens, num_tokens) - # non_blocking send q_seq_lens to device - # q_seq_lens = ms.Tensor(num_scheduled_tokens, dtype=ms.int32) - q_seq_lens = ms.from_numpy(num_scheduled_tokens) - q_seq_lens.move_to("Ascend", blocking=False) - # Get request indices. # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] req_indices = np.repeat(self.arange_np[:num_reqs], @@ -278,29 +166,30 @@ def _prepare_inputs( np.add(self.input_batch.num_computed_tokens_cpu[req_indices], arange, out=positions_np) - # self.positions_cpu[:total_num_scheduled_tokens] = torch.from_numpy(positions_np) - self.positions[:total_num_scheduled_tokens] = torch.from_numpy(positions_np) + + if self.uses_mrope: + self._calc_mrope_positions(scheduler_output) + + if self.uses_mrope: + # Only relevant for models using M-RoPE (e.g, Qwen2-VL) + self.mrope_positions[:, :total_num_scheduled_tokens].copy_( + self.mrope_positions_cpu[:, :total_num_scheduled_tokens], + non_blocking=True) + positions = self.mrope_positions[:, :total_num_scheduled_tokens] + else: + self.positions[:total_num_scheduled_tokens] = torch.from_numpy(positions_np) + positions = self.positions[:total_num_scheduled_tokens] + # Get token indices. # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2] # where M is the max_model_len. token_indices = (positions_np + - req_indices * self.input_batch.token_ids_cpu.shape[1]) - - # NOTE(woosuk): We use torch.index_select instead of np.take here - # because torch.index_select is much faster than np.take for large - # tensors. 
- # self.input_ids_cpu[:total_num_scheduled_tokens] = \ - # torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(), - # 0, - # torch.from_numpy(token_indices)) - # self.input_ids[:total_num_scheduled_tokens] = \ - # torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(), - # 0, - # torch.from_numpy(token_indices)) + req_indices * self.input_batch.token_ids_cpu.shape[1]) + self.input_ids[:total_num_scheduled_tokens] = torch.from_numpy( - np.take(self.input_batch.token_ids_cpu.flatten(), + np.take(self.input_batch.token_ids_cpu.ravel(), token_indices, 0) ) @@ -313,380 +202,40 @@ def _prepare_inputs( # because M (max_model_len) is not necessarily divisible by block_size. block_table_indices = (req_indices * self.max_num_blocks_per_req + positions_np // self.block_size) - # NOTE(woosuk): We use torch.index_select instead of np.take here - # because torch.index_select is much faster than np.take for large - # tensors. - # block_table_cpu = self.input_batch.block_table.get_cpu_tensor() - # block_numbers = block_table_cpu.flatten()[block_table_indices].numpy() - block_numbers = self.input_batch.block_table.block_table_np.flatten()[block_table_indices] + + + block_numbers = self.input_batch.block_table.block_table_np.ravel()[block_table_indices] block_offsets = positions_np % self.block_size np.add(block_numbers * self.block_size, block_offsets, out=self.slot_mapping_np[:total_num_scheduled_tokens]) - # TODO: - # self.slot_mapping_cpu[:total_num_scheduled_tokens] = \ - # torch.from_numpy(self.slot_mapping_np[:total_num_scheduled_tokens]) - - # non_blocking send q_seq_lens to device - # TODO: - # slot_mapping = self.slot_mapping_cpu[:total_num_scheduled_tokens].to(torch.int32) - # slot_mapping = ms.Tensor(self.slot_mapping_np[:total_num_scheduled_tokens], dtype=ms.int32) - slot_mapping = ms.from_numpy(self.slot_mapping_np[:total_num_scheduled_tokens]) - slot_mapping.move_to("Ascend", blocking=False) # # Prepare the attention metadata. self.query_start_loc_np[0] = 0 self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens - # query_start_loc = ms.Tensor(self.query_start_loc_np[:num_reqs + 1], dtype=ms.int32) - query_start_loc = ms.from_numpy(self.query_start_loc_np[:num_reqs + 1]) - query_start_loc.move_to("Ascend", blocking=False) - # TODO: - # self.query_start_loc_cpu[1:num_reqs + 1] = \ - # torch.from_numpy(self.query_start_loc_np[1:num_reqs + 1]) self.seq_lens_np[:num_reqs] = ( self.input_batch.num_computed_tokens_cpu[:num_reqs] + num_scheduled_tokens) - max_seq_len = self.seq_lens_np[:num_reqs].max() - # TODO: - seq_lens_np = self.seq_lens_np[:num_reqs] - # seq_lens = ms.Tensor(seq_lens_np, dtype=ms.int32) - seq_lens = ms.from_numpy(seq_lens_np) - seq_lens.move_to("Ascend", blocking=False) - # self.seq_lens_cpu[:num_reqs] = torch.from_numpy(self.seq_lens_np[:num_reqs]) - - # # Copy the tensors to the GPU. 
- # self.input_ids[:total_num_scheduled_tokens] = \ - # self.input_ids_cpu[:total_num_scheduled_tokens] - - # # Common case (1D positions) - # self.positions[:total_num_scheduled_tokens] = \ - # self.positions_cpu[:total_num_scheduled_tokens] - - # TODO: - # query_start_loc = self.query_start_loc_cpu[:num_reqs + 1] - - # seq_lens = self.seq_lens_cpu[:num_reqs] - - max_context_lens = self.input_batch.num_computed_tokens_cpu[:num_reqs].max() - - attn_metadata = MSAttentionMetadata( - block_tables=(self.input_batch.block_table.get_device_tensor()[:num_reqs]), - seq_lens=seq_lens, - context_lens=context_lens, - max_query_len=1, - slot_mapping=slot_mapping, - num_decode_tokens=num_decode_tokens, - query_lens=query_len - ) - use_spec_decode = len( - scheduler_output.scheduled_spec_decode_tokens) > 0 - if use_spec_decode: - logits_indices = self._calc_spec_decode_metadata( - scheduler_output, cu_num_tokens) - else: - # NOTE(woosuk): Due to chunked prefills, the batch may contain - # partial requests. While we should not sample any token - # from these partial requests, we do so for simplicity. - # We will ignore the sampled tokens from the partial requests. - # TODO: Support prompt logprobs. - logits_indices = query_start_loc[1:] - 1 - - # Hot-Swap lora model - if self.lora_config: - self.set_active_loras(self.input_batch, num_scheduled_tokens) - - return attn_metadata, logits_indices - - -def create_block(shape, dtype, name=None, device=None): - from mindspore import mint - blocks = mint.empty(shape, dtype=dtype, device=device) - return blocks - -def initialize_kv_cache(self, kv_cache_config) -> None: - """ - Initialize KV cache based on `kv_cache_config`. - Args: - kv_cache_config: Configuration for the KV cache, including the KV - cache size of each layer - """ - if len(kv_cache_config.groups) > 1: - raise NotImplementedError( - "Hybrid models with more than one KV cache type are not " - "supported yet.") - - kv_caches: Dict[str, torch.Tensor] = {} - - # backend = MLABackend if is_use_mla(self.model_config) else FlashAttentionBackend - backend = MsAttentionBackend - - for layer_name, layer_spec in kv_cache_config.kv_cache_spec.items(): - tensor_config = kv_cache_config.tensors[layer_name] - assert tensor_config.size % layer_spec.page_size_bytes == 0 - num_blocks = tensor_config.size // layer_spec.page_size_bytes - if isinstance(layer_spec, FullAttentionSpec): - kv_cache_shape = backend.get_kv_cache_shape( - num_blocks, layer_spec.block_size, layer_spec.num_kv_heads, - layer_spec.head_size) - dtype = layer_spec.dtype - dtype = get_valid_dtype(dtype) - current_cache = [] - device_type = "CPU" if self.device.type == "cpu" else "Ascend" - for i in range(kv_cache_shape[0]): - cache_blocks = create_block( - kv_cache_shape[1:], dtype, device=device_type - ) - current_cache.append(mutable(cache_blocks)) - kv_caches[layer_name] = mutable(tuple(current_cache)) - else: - raise NotImplementedError - - bind_kv_cache( - kv_caches, - self.vllm_config.compilation_config.static_forward_context, - self.kv_caches) - - -def _update_states(self, scheduler_output: "SchedulerOutput") -> None: - """Update the cached states and the persistent batch with the scheduler - output. - - The updated states are used by the `_prepare_inputs` function to create - the input GPU tensors for the model. - - The SamplingMetadata is updated and copied to the GPU if there is a - new/resumed/paused/finished request in the batch. - """ - # Remove finished requests from the cached states. 
- for req_id in scheduler_output.finished_req_ids: - self.requests.pop(req_id, None) - self.encoder_cache.pop(req_id, None) - # Remove the finished requests from the persistent batch. - # NOTE(woosuk): There could be an edge case where finished_req_ids and - # scheduled_req_ids overlap. This happens when a request is aborted and - # then resubmitted with the same ID. In this case, we treat them as two - # distinct requests - clearing the cached states for the first request - # and handling the second as a new request. - removed_req_indices: List[int] = [] - for req_id in scheduler_output.finished_req_ids: - req_index = self.input_batch.remove_request(req_id) - if req_index is not None: - removed_req_indices.append(req_index) - - # Free the cached encoder outputs. - for req_id, input_id in scheduler_output.free_encoder_input_ids: - encoder_outputs = self.encoder_cache.get(req_id) - if encoder_outputs is not None: - encoder_outputs.pop(input_id, None) - if not encoder_outputs: - self.encoder_cache.pop(req_id, None) - - # Remove the unscheduled requests from the persistent batch. - # NOTE(woosuk): The unscheduled requests are either preempted requests - # or running requests that are not scheduled in this step. We remove - # them from the persistent batch but keep their cached states since - # they will be scheduled again sometime in the future. - scheduled_req_ids = scheduler_output.num_scheduled_tokens.keys() - cached_req_ids = self.input_batch.req_id_to_index.keys() - unscheduled_req_ids = cached_req_ids - scheduled_req_ids - # NOTE(woosuk): The persistent batch optimization assumes that - # consecutive batches contain mostly the same requests. If batches - # have low request overlap (e.g., alternating between two distinct - # sets of requests), this optimization becomes very inefficient. - for req_id in unscheduled_req_ids: - req_index = self.input_batch.remove_request(req_id) - assert req_index is not None - removed_req_indices.append(req_index) - - req_ids_to_add: List[str] = [] - # Add new requests to the cached states. 
- for new_req_data in scheduler_output.scheduled_new_reqs: - req_id = new_req_data.req_id - sampling_params = new_req_data.sampling_params - if sampling_params.sampling_type == SamplingType.RANDOM_SEED: - generator = torch.Generator(device=self.device) - generator.manual_seed(sampling_params.seed) - else: - generator = None - - self.requests[req_id] = CachedRequestState( - req_id=req_id, - prompt_token_ids=new_req_data.prompt_token_ids, - prompt=new_req_data.prompt, - mm_inputs=new_req_data.mm_inputs, - mm_positions=new_req_data.mm_positions, - sampling_params=sampling_params, - generator=generator, - block_ids=new_req_data.block_ids, - num_computed_tokens=new_req_data.num_computed_tokens, - output_token_ids=[], - lora_request=new_req_data.lora_request, - ) + attn_metadata = self.attn_metadata_builder.build( + num_reqs=num_reqs, + num_actual_tokens=total_num_scheduled_tokens, + max_query_len=max_num_scheduled_tokens, + common_prefix_len=0, + ) - # Only relevant for models using M-RoPE (e.g, Qwen2-VL) - if self.uses_mrope: - image_grid_thw = [] - video_grid_thw = [] - second_per_grid_ts = [] - for mm_input in self.requests[req_id].mm_inputs: - if mm_input.get("image_grid_thw") is not None: - image_grid_thw.extend( - mm_input["image_grid_thw"].tolist()) - if mm_input.get("video_grid_thw") is not None: - video_grid_thw.extend( - mm_input["video_grid_thw"].tolist()) - if mm_input.get("second_per_grid_ts") is not None: - second_per_grid_ts.extend( - mm_input["second_per_grid_ts"]) - - hf_config = self.model_config.hf_config - - self.requests[req_id].mrope_positions, \ - self.requests[req_id].mrope_position_delta = \ - MRotaryEmbedding.get_input_positions_tensor( - self.requests[req_id].prompt_token_ids, - hf_config=hf_config, - image_grid_thw=image_grid_thw, - video_grid_thw=video_grid_thw, - second_per_grid_ts=second_per_grid_ts, - ) - - req_ids_to_add.append(req_id) - - # Update the states of the running/resumed requests. - for req_data in scheduler_output.scheduled_cached_reqs: - req_id = req_data.req_id - req_state = self.requests[req_id] - - # Update the cached states. - num_computed_tokens = req_data.num_computed_tokens - req_state.num_computed_tokens = num_computed_tokens - # Add the sampled token(s) from the previous step (if any). - # This doesn't include "unverified" tokens like spec decode tokens. - num_new_tokens = (num_computed_tokens + - len(req_data.new_token_ids) - - req_state.num_tokens) - if num_new_tokens == 1: - # Avoid slicing list in most common case. - req_state.output_token_ids.append(req_data.new_token_ids[-1]) - elif num_new_tokens > 0: - req_state.output_token_ids.extend( - req_data.new_token_ids[-num_new_tokens:]) - # Update the block IDs. - if not req_data.resumed_from_preemption: - # Append the new blocks to the existing block IDs. - req_state.block_ids.extend(req_data.new_block_ids) - else: - # The request is resumed from preemption. - # Replace the existing block IDs with the new ones. - req_state.block_ids = req_data.new_block_ids - - req_index = self.input_batch.req_id_to_index.get(req_id) - if req_index is None: - # The request is not in the persistent batch. - # The request was either preempted and resumed later, or was not - # scheduled in the previous step and needs to be added again. - req_ids_to_add.append(req_id) - continue - - # Update the persistent batch. 
- self.input_batch.num_computed_tokens_cpu[req_index] = ( - num_computed_tokens) - start_index = (len(req_state.block_ids) - - len(req_data.new_block_ids)) - self.input_batch.block_table.append_row(req_index, start_index, - req_data.new_block_ids) - # Add new_token_ids to token_ids_cpu. - start_token_index = num_computed_tokens - end_token_index = num_computed_tokens + len(req_data.new_token_ids) - self.input_batch.token_ids_cpu[ - req_index, - start_token_index:end_token_index] = req_data.new_token_ids - # ####################################################### - # self.input_batch.token_ids_cpu_tensor[ - # req_index, - # start_token_index:end_token_index] = torch.from_numpy( - # self.input_batch.token_ids_cpu[ - # req_index, - # start_token_index:end_token_index]) - # ####################################### - self.input_batch.num_tokens_no_spec[req_index] = end_token_index - # Add spec_token_ids to token_ids_cpu. - spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get( - req_id, ()) - if spec_token_ids: - start_index = end_token_index - end_token_index += len(spec_token_ids) - self.input_batch.token_ids_cpu[ - req_index, start_index:end_token_index] = spec_token_ids - # NOTE(woosuk): `num_tokens` here may include spec decode tokens. - self.input_batch.num_tokens[req_index] = end_token_index - - - # self.input_batch.token_ids_cpu_tensor.copy_(torch.from_numpy(self.input_batch.token_ids_cpu)) - # Check if the batch has changed. If not, we can skip copying the - # sampling metadata from CPU to GPU. - batch_changed = len(removed_req_indices) > 0 or len(req_ids_to_add) > 0 - - # Add the new or resumed requests to the persistent batch. - # The smaller empty indices are filled first. - removed_req_indices = sorted(removed_req_indices, reverse=True) - for req_id in req_ids_to_add: - req_state = self.requests[req_id] - if removed_req_indices: - # Fill the empty index. - req_index = removed_req_indices.pop() - else: - # Append to the end. - req_index = None - self.input_batch.add_request(req_state, req_index) - - # self.input_batch.commit() - - # Condense the batched states if there are empty indices. - if removed_req_indices: - self.input_batch.condense(removed_req_indices) - # self.input_batch.commit() - - if batch_changed: - self.input_batch.refresh_sampling_metadata() + input_ids = self.input_ids[:total_num_scheduled_tokens] + attn_metadata.num_input_tokens = total_num_scheduled_tokens + # Run forward pass + with set_forward_context(attn_metadata, self.vllm_config): + assert self.model is not None + hidden_states = self.model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=None, + ) -def get_kv_cache_spec(self) -> KVCacheSpec: - """ - Generates the KVCacheSpec by parsing the kv cache format from each - Attention module in the static forward context. - Returns: - KVCacheSpec: A dictionary mapping layer names to their KV cache - format. Layers that do not need KV cache are not included. 
- """ - - forward_ctx = self.vllm_config.compilation_config.static_forward_context - block_size = self.vllm_config.cache_config.block_size - kv_cache_spec: KVCacheSpec = {} - for layer_name, attn_module in forward_ctx.items(): - # if isinstance(attn_module, FusedMoE): - # continue - - # TODO: Support other attention modules, e.g., sliding window, - # cross-attention, MLA - #assert isinstance(attn_module, Attention) - if attn_module.attn_type == AttentionType.DECODER: - kv_cache_spec[layer_name] = FullAttentionSpec( - block_size=block_size, - num_kv_heads=attn_module.num_kv_heads, - head_size=attn_module.head_size, - dtype=attn_module.dtype) - elif attn_module.attn_type in (AttentionType.ENCODER, - AttentionType.ENCODER_ONLY): - # encoder-only attention does not need KV cache. - continue - elif attn_module.attn_type == AttentionType.ENCODER_DECODER: - raise NotImplementedError - else: - raise ValueError( - f"Unknown attention type: {attn_module.attn_type}") - - return kv_cache_spec \ No newline at end of file + return hidden_states[cu_num_tokens - 1] diff --git a/vllm_mindspore/worker/worker_v1.py b/vllm_mindspore/worker/worker_v1.py index 685b1967..db106838 100644 --- a/vllm_mindspore/worker/worker_v1.py +++ b/vllm_mindspore/worker/worker_v1.py @@ -15,58 +15,70 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ +import gc +from typing import Dict, List + import torch from vllm.logger import init_logger +from vllm.v1.utils import bind_kv_cache +from vllm.v1.kv_cache_interface import FullAttentionSpec + +from vllm_ascend.platform import NPUPlatform logger = init_logger(__name__) @torch.inference_mode() def determine_available_memory(self) -> int: - """Profiles the peak memory usage of the model to determine how much - memory can be used for KV cache without OOMs. + # kv_caches: Dict[str, torch.Tensor] = {} + # kv_cache_spec = self.model_runner.get_kv_cache_spec() + # for layer_name, layer_spec in kv_cache_spec.items(): + # if isinstance(layer_spec, FullAttentionSpec): + # # Use an empty tensor instead of `None`` to force Dynamo to pass + # # it by reference, rather by specializing on the value ``None``. + # npu_k_cache = torch.zeros([0, 0, 0, 0], + # dtype=layer_spec.dtype, + # device=self.device) + # npu_v_cache = torch.zeros([0, 0, 0, 0], + # dtype=layer_spec.dtype, + # device=self.device) + # kv_caches[layer_name] = (npu_k_cache, npu_v_cache) + # else: + # raise NotImplementedError - The engine will first conduct a profiling of the existing memory usage. - Then, it calculate the free memory that can be used for KV cache in - bytes. + # runner_kv_caches: List[torch.Tensor] = [] + # bind_kv_cache( + # kv_caches, + # self.vllm_config.compilation_config.static_forward_context, + # runner_kv_caches) - .. tip:: - You may limit the usage of GPU memory - by adjusting the `gpu_memory_utilization` parameter. - """ - torch.cuda.empty_cache() - torch.cuda.reset_peak_memory_stats() + # Profile the memory usage of the model and get the maximum number of + # cache blocks that can be allocated with the remaining free memory. + NPUPlatform.empty_cache() - _, total_gpu_memory = torch.cuda.mem_get_info() # Execute a forward pass with dummy inputs to profile the memory usage # of the model. self.model_runner.profile_run() - free_gpu_memory, _ = torch.cuda.mem_get_info() + # Calculate the number of blocks that can be allocated with the + # profiled peak memory. 
+ free_npu_memory, total_npu_memory = NPUPlatform.mem_get_info() # NOTE(woosuk): Here we assume that the other processes using the same # GPU did not change their memory usage during the profiling. - assert self.init_gpu_memory > free_gpu_memory, ( + peak_memory = self.init_npu_memory - free_npu_memory + assert peak_memory > 0, ( "Error in memory profiling. " - f"Initial free memory {self.init_gpu_memory}, current free memory" - f" {free_gpu_memory}. This happens when the GPU memory was " + f"Initial free memory {self.init_npu_memory}, current free memory" + f" {free_npu_memory}. This happens when the NPU memory was " "not properly cleaned up before initializing the vLLM instance.") - # Get the peak memory allocation recorded by torch - peak_memory = torch.cuda.memory_stats()["allocated_bytes.all.peak"] - - # Check for any memory left around that may have been allocated on the - # gpu outside of `torch`. NCCL operations, for example, can use a few - # GB during a forward pass - torch.cuda.empty_cache() - torch_allocated_bytes = torch.cuda.memory_stats( - )["allocated_bytes.all.current"] - total_allocated_bytes = torch.cuda.mem_get_info( - )[1] - torch.cuda.mem_get_info()[0] - non_torch_allocations = total_allocated_bytes - torch_allocated_bytes - if non_torch_allocations > 0: - peak_memory += non_torch_allocations - available_kv_cache_memory = ( - total_gpu_memory * self.cache_config.gpu_memory_utilization - - peak_memory) - - return int(available_kv_cache_memory) + gc.collect() + # TODO: don`t need impl this func after empty_cache in + # Worker.determine_num_available_blocks() unified` + NPUPlatform.empty_cache() + usable_memory_size = total_npu_memory * self.cache_config.gpu_memory_utilization - peak_memory + npu_kv_cache_bytes = max(usable_memory_size, 0) + logger.info( + f"Available memory: {usable_memory_size}, total memory: {total_npu_memory}" + ) + return int(npu_kv_cache_bytes) -- Gitee From 966c31e95cbdebb66b6d981a3b6c08987d4acb23 Mon Sep 17 00:00:00 2001 From: can-gaa-hou Date: Wed, 4 Jun 2025 03:13:56 +0000 Subject: [PATCH 4/5] adapt v0.8.5 --- vllm_mindspore/patch/patch_vllm_ascend.py | 27 ++++++++++++++++++++++- vllm_mindspore/worker/worker_v1.py | 22 ------------------ 2 files changed, 26 insertions(+), 23 deletions(-) diff --git a/vllm_mindspore/patch/patch_vllm_ascend.py b/vllm_mindspore/patch/patch_vllm_ascend.py index 78ec8872..230887bf 100644 --- a/vllm_mindspore/patch/patch_vllm_ascend.py +++ b/vllm_mindspore/patch/patch_vllm_ascend.py @@ -211,7 +211,32 @@ from vllm.v1.worker.gpu_worker import Worker Worker.compile_or_warm_up_model = compile_or_warm_up_model # ================ For vllm-ascend ================ +# ============ For 0.8.5 start =========== +import importlib import types +memory_mod = importlib.import_module("torch.cuda.memory") +if not hasattr(memory_mod, "NPUPluggableAllocator"): + memory_mod.NPUPluggableAllocator = memory_mod.CUDAPluggableAllocator +sys.modules["torch_npu.op_plugin"] = types.ModuleType("torch_npu.op_plugin") +sys.modules["torch_npu.op_plugin.atb"] = types.ModuleType("torch_npu.op_plugin.atb") +fake_mod = types.ModuleType("torch_npu.op_plugin.atb._atb_ops") +fake_mod._register_atb_extensions = lambda *a, **kw: None +sys.modules["torch_npu.op_plugin.atb._atb_ops"] = fake_mod +fake_mod = types.ModuleType("torchair._contrib") +sys.modules["torchair._contrib"] = fake_mod +fake_mod = types.ModuleType("torchair._contrib.custom_torch_ops") +sys.modules["torchair._contrib.custom_torch_ops"] = fake_mod +import torch +if not hasattr(torch, "Tag"): + 
class _FakeTag: + needs_fixed_stride_order = "needs_fixed_stride_order" + torch.Tag = _FakeTag +fake_fused_moe = types.ModuleType("vllm.model_executor.layers.fused_moe.fused_moe") +fake_fused_moe.direct_register_custom_op = lambda *a, **kw: None +sys.modules["vllm.model_executor.layers.fused_moe.fused_moe"] = fake_fused_moe +import vllm_ascend.ops +vllm_ascend.ops.register_dummy_fusion_op = lambda *a, **kw: None +# ============ For 0.8.5 end =========== fake_mod = types.ModuleType("vllm_ascend.vllm_ascend_C") fake_mod.init_module = fake_mod.python_create_and_map = fake_mod.python_unmap_and_release = lambda *a, **kw: None sys.modules.update({"vllm_ascend.vllm_ascend_C": fake_mod}) @@ -271,4 +296,4 @@ NPUModelRunner.initialize_kv_cache = initialize_kv_cache from vllm_mindspore.v1.worker.gpu_worker import compile_or_warm_up_model from vllm_ascend.worker.worker_v1 import NPUWorker NPUWorker.compile_or_warm_up_model = compile_or_warm_up_model -# ============ For v1 end =========== \ No newline at end of file +# ============ For v1 end =========== diff --git a/vllm_mindspore/worker/worker_v1.py b/vllm_mindspore/worker/worker_v1.py index db106838..8ad5f0c9 100644 --- a/vllm_mindspore/worker/worker_v1.py +++ b/vllm_mindspore/worker/worker_v1.py @@ -30,28 +30,6 @@ logger = init_logger(__name__) @torch.inference_mode() def determine_available_memory(self) -> int: - # kv_caches: Dict[str, torch.Tensor] = {} - # kv_cache_spec = self.model_runner.get_kv_cache_spec() - # for layer_name, layer_spec in kv_cache_spec.items(): - # if isinstance(layer_spec, FullAttentionSpec): - # # Use an empty tensor instead of `None`` to force Dynamo to pass - # # it by reference, rather by specializing on the value ``None``. - # npu_k_cache = torch.zeros([0, 0, 0, 0], - # dtype=layer_spec.dtype, - # device=self.device) - # npu_v_cache = torch.zeros([0, 0, 0, 0], - # dtype=layer_spec.dtype, - # device=self.device) - # kv_caches[layer_name] = (npu_k_cache, npu_v_cache) - # else: - # raise NotImplementedError - - # runner_kv_caches: List[torch.Tensor] = [] - # bind_kv_cache( - # kv_caches, - # self.vllm_config.compilation_config.static_forward_context, - # runner_kv_caches) - # Profile the memory usage of the model and get the maximum number of # cache blocks that can be allocated with the remaining free memory. 
NPUPlatform.empty_cache() -- Gitee From 459b8fd1be5f060ec91381b0c197a5487b1aa33a Mon Sep 17 00:00:00 2001 From: can-gaa-hou Date: Tue, 10 Jun 2025 09:11:45 +0000 Subject: [PATCH 5/5] fix v0 ms_attn bugs --- vllm_mindspore/__init__.py | 7 +--- vllm_mindspore/attention/backends/ms_attn.py | 26 ++++++------- vllm_mindspore/patch/patch_vllm_ascend.py | 39 +++++--------------- vllm_mindspore/worker/worker.py | 1 - 4 files changed, 21 insertions(+), 52 deletions(-) diff --git a/vllm_mindspore/__init__.py b/vllm_mindspore/__init__.py index 37e1f5e7..1a76027f 100644 --- a/vllm_mindspore/__init__.py +++ b/vllm_mindspore/__init__.py @@ -37,7 +37,7 @@ if importlib.util.find_spec("vllm_ascend") is not None: import vllm_mindspore.patch.patch_vllm_ascend else: warnings.warn( - f"vllm-ascend is not imported because: {e}" + f"vllm-ascend is not imported because vllm_ascend is not installed" ) from vllm_mindspore.platforms.ascend import AscendPlatform @@ -96,11 +96,6 @@ else: vllm.model_executor.sampling_metadata.async_tensor_h2d = async_tensor_h2d vllm.model_executor.sampling_metadata.SamplingTensors.from_lists = SamplingTensors.from_lists - # from vllm_mindspore.worker.cache_engine import ( - # ms_allocate_kv_cache, - # ms_swap_in, - # ms_swap_out, - # ) from vllm_mindspore.worker.cache_engine import ( ms_allocate_kv_cache, diff --git a/vllm_mindspore/attention/backends/ms_attn.py b/vllm_mindspore/attention/backends/ms_attn.py index d6123b0a..e76557df 100644 --- a/vllm_mindspore/attention/backends/ms_attn.py +++ b/vllm_mindspore/attention/backends/ms_attn.py @@ -501,17 +501,13 @@ class MsAttentionMetadataBuilder(AttentionMetadataBuilder[MSAttentionMetadata]): self, seq_lens: List[int], query_lens: List[int], - cuda_graph_pad_size: int, - batch_size: int, + graph_size: int = -1, ): """Build attention metadata with on-device tensors. Args: seq_lens: The maybe padded sequence lengths of the input sequences. query_lens: The query lengths of the input sequences. - cuda_graph_pad_size: The padding size for cuda graph. - -1 if cuda graph is not used. - batch_size: The maybe padded batch size. 
""" prefix_cache_hit = any( [ @@ -525,7 +521,6 @@ class MsAttentionMetadataBuilder(AttentionMetadataBuilder[MSAttentionMetadata]): ) device = self.runner.device - use_captured_graph = cuda_graph_pad_size != -1 max_query_len = max(query_lens) decode_query_lens = query_lens[self.num_prefills:] @@ -539,15 +534,12 @@ class MsAttentionMetadataBuilder(AttentionMetadataBuilder[MSAttentionMetadata]): query_start_loc = list(accumulate(query_lens, initial=0)) seq_start_loc = list(accumulate(seq_lens, initial=0)) - if use_captured_graph: - raise RuntimeError("Doesnot support captured graph now!") - else: - block_tables = make_tensor_with_pad( - self.block_tables, - pad=-1, - dtype=torch.int, - device=device, - ) + block_tables = make_tensor_with_pad( + self.block_tables, + pad=-1, + dtype=torch.int, + device=device, + ) assert max_query_len > 0, "query_lens: {}".format(query_lens) context_lens_tensor = ms.Tensor(self.context_lens, dtype=ms.int32) @@ -595,6 +587,10 @@ class MsAttentionBackend(AttentionBackend): @staticmethod def get_builder_cls() -> Type["MsAttentionMetadataBuilder"]: return MsAttentionMetadataBuilder + + @classmethod + def make_metadata_builder(cls, *args, **kwargs) -> "MsAttentionMetadataBuilder": + return cls.get_builder_cls()(*args, **kwargs) @staticmethod def get_state_cls() -> Type["AttentionState"]: diff --git a/vllm_mindspore/patch/patch_vllm_ascend.py b/vllm_mindspore/patch/patch_vllm_ascend.py index 230887bf..c9a4fcb9 100644 --- a/vllm_mindspore/patch/patch_vllm_ascend.py +++ b/vllm_mindspore/patch/patch_vllm_ascend.py @@ -3,33 +3,27 @@ import sys import vllm.utils +import vllm.engine.arg_utils +from vllm_mindspore.engine.arg_utils import _is_v1_supported_oracle +vllm.engine.arg_utils.EngineArgs._is_v1_supported_oracle = _is_v1_supported_oracle + from vllm_mindspore.utils import ( - direct_register_custom_op, make_tensor_with_pad, async_tensor_h2d, - get_dtype_size, - ascend_device_count_stateless, ascend_is_initialized, + ms_memory_profiling, ) -vllm.utils.direct_register_custom_op = direct_register_custom_op vllm.utils.make_tensor_with_pad = make_tensor_with_pad vllm.utils.async_tensor_h2d = async_tensor_h2d -vllm.utils.get_dtype_size = get_dtype_size -vllm.utils.cuda_device_count_stateless = ascend_device_count_stateless vllm.utils.cuda_is_initialized = ascend_is_initialized -vllm.config.cuda_device_count_stateless = ascend_device_count_stateless - -import vllm.executor - -vllm.executor.cuda_device_count_stateless = ascend_device_count_stateless +vllm.utils.memory_profiling = ms_memory_profiling from vllm_mindspore.model_executor.models.registry import ( MindSporeModelRegistry, _SUBPROCESS_COMMAND, ) - vllm.config.ModelRegistry = MindSporeModelRegistry import vllm.model_executor @@ -42,25 +36,10 @@ from vllm_mindspore.model_executor.model_loader.utils import get_ms_model_archit # To patching the get_model_architecture, should import it first. 
from vllm.model_executor.model_loader import get_model_architecture -vllm.model_executor.model_loader.get_model_architecture = get_ms_model_architecture -vllm.model_executor.model_loader.utils.get_model_architecture = ( - get_ms_model_architecture -) -vllm.model_executor.model_loader.loader.get_model_architecture = ( - get_ms_model_architecture -) - -from vllm_mindspore.model_executor.sampling_metadata import ( - SequenceGroupToSample, - SamplingMetadataCache, - SamplingMetadata, -) +from vllm_mindspore.model_executor.sampling_metadata import SamplingTensors -vllm.model_executor.SamplingMetadataCache = SamplingMetadataCache -vllm.model_executor.SamplingMetadata = SamplingMetadata -vllm.model_executor.sampling_metadata.SequenceGroupToSample = SequenceGroupToSample -vllm.model_executor.sampling_metadata.SamplingMetadataCache = SamplingMetadataCache -vllm.model_executor.sampling_metadata.SamplingMetadata = SamplingMetadata +vllm.model_executor.sampling_metadata.async_tensor_h2d = async_tensor_h2d +vllm.model_executor.sampling_metadata.SamplingTensors.from_lists = SamplingTensors.from_lists from vllm_mindspore.model_executor.model_loader.weight_utils import ( safetensors_weights_iterator, diff --git a/vllm_mindspore/worker/worker.py b/vllm_mindspore/worker/worker.py index 8ce1bc91..2dc69fcd 100644 --- a/vllm_mindspore/worker/worker.py +++ b/vllm_mindspore/worker/worker.py @@ -26,7 +26,6 @@ import torch from vllm.config import VllmConfig from vllm.distributed import ( - ensure_kv_transfer_initialized, ensure_model_parallel_initialized, init_distributed_environment, set_custom_all_reduce, -- Gitee
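
Note on the patching idiom: every hunk in this series relies on the same import-time monkey-patching pattern — replace attributes on already-imported vLLM / vLLM-Ascend classes (CacheEngine, NPUWorker, NPUModelRunner, Scheduler) with MindSpore-aware implementations, and pre-register stub modules in sys.modules so CUDA- or CANN-only native extensions are never actually loaded. Below is a minimal, self-contained sketch of that idiom. The names thirdparty, ThirdPartyWorker, native_ext, and ms_warm_up are hypothetical stand-ins invented for this example (they are not part of vLLM, vLLM-Ascend, or vLLM-MindSpore), so the snippet runs without any of those packages installed.

import sys
import types


class ThirdPartyWorker:
    """Hypothetical stand-in for a backend worker class such as NPUWorker."""

    def warm_up(self):
        return "original warm-up"


def ms_warm_up(self):
    # Replacement a MindSpore backend would install instead of the
    # CUDA/CANN-specific original.
    return "mindspore warm-up"


# 1. Pre-register stub modules so an unavailable native extension is never
#    really imported (mirrors the fake torch_npu / torchair / vllm_ascend_C
#    modules placed into sys.modules by patch_vllm_ascend.py).
parent = types.ModuleType("thirdparty")
native_ext = types.ModuleType("thirdparty.native_ext")
native_ext.init_module = lambda *args, **kwargs: None
parent.native_ext = native_ext
sys.modules["thirdparty"] = parent
sys.modules["thirdparty.native_ext"] = native_ext

# 2. Swap the method on the class object itself, so every existing call site
#    that still calls worker.warm_up() now runs the replacement (the same
#    idiom as NPUWorker._warm_up_model = _warm_up_model).
ThirdPartyWorker.warm_up = ms_warm_up

if __name__ == "__main__":
    import thirdparty.native_ext as ext  # resolves to the stub registered above

    print(ThirdPartyWorker().warm_up())  # -> mindspore warm-up
    print(ext.init_module())             # -> None

As in the patches themselves, the ordering constraint is the whole point: the stubs and method swaps must be installed before any module that imports the originals is loaded, which is why vllm_mindspore performs them at package import time.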