diff --git a/vllm_mindspore/__init__.py b/vllm_mindspore/__init__.py
index b31f00ce8825f42198425a7d489b2ce1fa149812..2f4e3ce2464f13afdf26186b9cacc99dec3218e5 100644
--- a/vllm_mindspore/__init__.py
+++ b/vllm_mindspore/__init__.py
@@ -86,12 +86,30 @@ from vllm_mindspore.v1.engine.core import (
 vllm.v1.engine.core.DPEngineCoreProc._init_data_parallel = _init_data_parallel
 vllm.v1.engine.core.DPEngineCoreProc.shutdown = shutdown
 
+from vllm_mindspore.v1.core.kv_cache_utils import get_kv_cache_config
+
+vllm.v1.core.kv_cache_utils.get_kv_cache_config = get_kv_cache_config
+vllm.v1.engine.core.get_kv_cache_config = get_kv_cache_config
+from vllm_mindspore.v1.core.single_type_kv_cache_manager import (
+    find_longest_cache_hit, spec_manager_map)
+
+vllm.v1.core.single_type_kv_cache_manager.FullAttentionManager.\
+    find_longest_cache_hit = find_longest_cache_hit
+vllm.v1.core.single_type_kv_cache_manager.spec_manager_map = spec_manager_map
 from vllm_mindspore.utils import (
     make_tensor_with_pad,
     async_tensor_h2d,
     ascend_is_initialized,
     ms_memory_profiling,
 )
+from vllm_mindspore.config import CacheDType, _CacheConfig
+
+vllm.config.CacheConfig = _CacheConfig
+vllm.config.CacheDType = CacheDType
+import vllm.engine.arg_utils
+
+vllm.engine.arg_utils.CacheDType = CacheDType
+vllm.engine.arg_utils.CacheConfig = _CacheConfig
 
 vllm.utils.make_tensor_with_pad = make_tensor_with_pad
 vllm.utils.async_tensor_h2d = async_tensor_h2d
@@ -172,12 +190,17 @@ from vllm_mindspore.worker.cache_engine import (
     ms_swap_in,
     ms_swap_out,
 )
-
+from vllm_mindspore.utils import get_dtype_size
 import vllm.worker.cache_engine
 
 vllm.worker.cache_engine.CacheEngine._allocate_kv_cache = ms_allocate_kv_cache
 vllm.worker.cache_engine.CacheEngine.swap_in = ms_swap_in
 vllm.worker.cache_engine.CacheEngine.swap_out = ms_swap_out
+vllm.worker.cache_engine.get_dtype_size = get_dtype_size
+
+import vllm.v1.kv_cache_interface
+
+vllm.v1.kv_cache_interface.get_dtype_size = get_dtype_size
 
 from vllm_mindspore.model_executor.model_loader.weight_utils import (
     safetensors_weights_iterator,
 )
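The hunk above follows the package's usual import-time patching style: the MindSpore implementations are assigned over the attributes of the already-imported vllm modules, and the assignment is repeated for every module that holds its own copy of the name (e.g. both vllm.config and vllm.engine.arg_utils get the new CacheConfig/CacheDType). A minimal sketch of that pattern, using a toy module instead of vllm so the snippet is self-contained; the names are illustrative only:

# Minimal sketch of the attribute-patching pattern used above.
import types

fake_vllm_config = types.ModuleType("fake_vllm_config")
fake_vllm_config.CacheConfig = object          # stands in for the upstream class

class PatchedCacheConfig:                      # stands in for _CacheConfig
    pass

# Rebind the attribute on every module that keeps its own reference to the
# name, because `from vllm.config import CacheConfig` copies the reference.
fake_vllm_config.CacheConfig = PatchedCacheConfig
assert fake_vllm_config.CacheConfig is PatchedCacheConfig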
diff --git a/vllm_mindspore/config.py b/vllm_mindspore/config.py
index faab407becc0e26d84bd8a73932a28e423ae0d45..28f66acba68cc01bf3ada421cbbdac6aa5652d47 100644
--- a/vllm_mindspore/config.py
+++ b/vllm_mindspore/config.py
@@ -18,22 +18,27 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import hashlib
 import socket
 import threading
 import time
 from collections import Counter
-from typing import Optional, Union
+from dataclasses import field
+from typing import Any, Literal, Optional, Union, get_args
 
 import msgspec
 import torch
 import vllm.envs as envs
+from pydantic import SkipValidation
+from pydantic.dataclasses import dataclass
 from transformers import PretrainedConfig
 from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
-from vllm.config import (_STR_DTYPE_TO_TORCH_DTYPE, CompilationConfig,
-                         CompilationLevel, VllmConfig, _find_dtype,
-                         _resolve_auto_dtype)
+from vllm.config import (_STR_DTYPE_TO_TORCH_DTYPE, BlockSize, CacheConfig,
+                         CompilationConfig, CompilationLevel, ParallelConfig,
+                         PrefixCachingHashAlgo, VllmConfig, _find_dtype,
+                         _resolve_auto_dtype, config)
 from vllm.logger import init_logger
-from vllm.utils import random_uuid
+from vllm.utils import GiB_bytes, get_cpu_memory, random_uuid
 
 from vllm_mindspore.utils import is_310p
 
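These imports pull in BlockSize, CacheConfig, PrefixCachingHashAlgo and the `config` decorator so that the subclass below can mirror upstream CacheConfig while widening the accepted cache_dtype literal. A quick, self-contained illustration of what the widened Literal provides (this snippet is illustrative, not part of the patch); vLLM's dataclass-driven EngineArgs derives the --kv-cache-dtype choices from this type hint, which appears to be why vllm.engine.arg_utils.CacheDType is also patched in __init__.py:

# Illustrative check that "int8" is now a member of the cache-dtype Literal.
from typing import Literal, get_args

CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2", "int8"]

assert "int8" in get_args(CacheDType)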
@@ -395,3 +400,157 @@ def stateless_destroy_socket_process_group(
     dp_group.close()
 
     logger.info("Socket process group for rank %d destroyed.", dp_group.rank)
+
+
+CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2", "int8"]
+
+
+@config
+@dataclass
+class _CacheConfig(CacheConfig):
+    """Configuration for the KV cache."""
+
+    block_size: SkipValidation[BlockSize] = None  # type: ignore
+    """Size of a contiguous cache block in number of tokens. This is ignored on
+    neuron devices and set to `--max-model-len`. On CUDA devices, only block
+    sizes up to 32 are supported. On HPU devices, block size defaults to 128.
+
+    This config has no static default. If left unspecified by the user, it will
+    be set in `Platform.check_and_update_configs()` based on the current
+    platform."""
+    gpu_memory_utilization: float = 0.9
+    """The fraction of GPU memory to be used for the model executor, which can
+    range from 0 to 1. For example, a value of 0.5 would imply 50% GPU memory
+    utilization. If unspecified, will use the default value of 0.9. This is a
+    per-instance limit, and only applies to the current vLLM instance. It does
+    not matter if you have another vLLM instance running on the same GPU. For
+    example, if you have two vLLM instances running on the same GPU, you can
+    set the GPU memory utilization to 0.5 for each instance."""
+    swap_space: float = 4
+    """Size of the CPU swap space per GPU (in GiB)."""
+    cache_dtype: CacheDType = "auto"
+    """Data type for kv cache storage. If "auto", will use model data type.
+    CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ROCm (AMD GPU) supports
+    fp8 (=fp8_e4m3). This vllm-mindspore build additionally accepts "int8",
+    which is used by the fa3_quant MLA KV-cache path."""
+    is_attention_free: bool = False
+    """Whether the model is attention-free. This is primarily set in
+    `ModelConfig` and that value should be manually duplicated here."""
+    num_gpu_blocks_override: Optional[int] = None
+    """Number of GPU blocks to use. This overrides the profiled `num_gpu_blocks`
+    if specified. Does nothing if `None`. Used for testing preemption."""
+    sliding_window: Optional[int] = None
+    """Sliding window size for the KV cache. This is primarily set in
+    `ModelConfig` and that value should be manually duplicated here."""
+    enable_prefix_caching: Optional[bool] = None
+    """Whether to enable prefix caching. Disabled by default for V0. Enabled by
+    default for V1."""
+    prefix_caching_hash_algo: PrefixCachingHashAlgo = "builtin"
+    """Set the hash algorithm for prefix caching:\n
+    - "builtin" is Python's built-in hash.\n
+    - "sha256" is collision resistant but with certain overheads."""
+    cpu_offload_gb: float = 0
+    """The space in GiB to offload to CPU, per GPU. Default is 0, which means
+    no offloading. Intuitively, this argument can be seen as a virtual way to
+    increase the GPU memory size. For example, if you have one 24 GB GPU and
+    set this to 10, virtually you can think of it as a 34 GB GPU. Then you can
+    load a 13B model with BF16 weight, which requires at least 26GB GPU memory.
+    Note that this requires fast CPU-GPU interconnect, as part of the model is
+    loaded from CPU memory to GPU memory on the fly in each model forward pass.
+    """
+    calculate_kv_scales: bool = False
+    """This enables dynamic calculation of `k_scale` and `v_scale` when
+    kv_cache_dtype is fp8. If `False`, the scales will be loaded from the model
+    checkpoint if available. Otherwise, the scales will default to 1.0."""
+
+    # Will be set after profiling.
+    num_gpu_blocks: Optional[int] = field(default=None, init=False)
+    """The number of blocks to allocate for GPU memory."""
+    num_cpu_blocks: Optional[int] = field(default=None, init=False)
+    """The number of blocks to allocate for CPU memory."""
+
+    def compute_hash(self) -> str:
+        """
+        WARNING: Whenever a new field is added to this config,
+        ensure that it is included in the factors list if
+        it affects the computation graph.
+
+        Provide a hash that uniquely identifies all the configs
+        that affect the structure of the computation
+        graph from input ids/embeddings to the final hidden states,
+        excluding anything before input ids/embeddings and after
+        the final hidden states.
+        """
+        factors: list[Any] = []
+        factors.append(self.cache_dtype)
+        # `cpu_offload_gb` does not use `torch.compile` yet.
+        hash_str = hashlib.md5(str(factors).encode(),
+                               usedforsecurity=False).hexdigest()
+        return hash_str
+
+    def __post_init__(self) -> None:
+        self.swap_space_bytes = self.swap_space * GiB_bytes
+
+        self._verify_args()
+        self._verify_cache_dtype()
+        self._verify_prefix_caching()
+
+    def metrics_info(self):
+        # convert cache_config to dict(key: str, value: str) for prometheus
+        # metrics info
+        return {key: str(value) for key, value in self.__dict__.items()}
+
+    def _verify_args(self) -> None:
+        if self.cpu_offload_gb < 0:
+            raise ValueError("CPU offload space must be non-negative"
+                             f", but got {self.cpu_offload_gb}")
+
+        if self.gpu_memory_utilization > 1.0:
+            raise ValueError(
+                "GPU memory utilization must be less than 1.0. Got "
+                f"{self.gpu_memory_utilization}.")
+
+    def _verify_cache_dtype(self) -> None:
+        if self.cache_dtype == "auto":
+            pass
+        elif self.cache_dtype in ("fp8", "fp8_e4m3", "fp8_e5m2", "int8"):
+            logger.info(
+                "Using %s data type to store kv cache. It reduces the GPU "
+                "memory footprint and boosts the performance. "
+                "Meanwhile, it may cause accuracy drop without a proper "
+                "scaling factor.", self.cache_dtype)
+        else:
+            raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}")
+
+    def _verify_prefix_caching(self) -> None:
+        if not self.enable_prefix_caching:
+            return
+
+        if self.sliding_window is not None and not envs.VLLM_USE_V1:
+            raise NotImplementedError(
+                "Prefix caching is not supported with sliding window. "
+                "Run with --disable-sliding-window to use prefix caching.")
+
+        if (self.enable_prefix_caching and self.prefix_caching_hash_algo
+                not in get_args(PrefixCachingHashAlgo)):
+            raise ValueError(
+                "Unknown prefix caching hash algorithm: "
+                f"{self.prefix_caching_hash_algo}. Must be one of "
+                f"{get_args(PrefixCachingHashAlgo)}.")
+
+    def verify_with_parallel_config(
+        self,
+        parallel_config: "ParallelConfig",
+    ) -> None:
+        total_cpu_memory = get_cpu_memory()
+        # FIXME(woosuk): Here, it is assumed that the GPUs in a tensor parallel
+        # group are in the same node. However, the GPUs may span multiple nodes.
+        num_gpus_per_node = parallel_config.tensor_parallel_size
+        cpu_memory_usage = self.swap_space_bytes * num_gpus_per_node
+
+        msg = (f"{cpu_memory_usage / GiB_bytes:.2f} GiB out of the "
+               f"{total_cpu_memory / GiB_bytes:.2f} GiB total CPU memory "
+               "is allocated for the swap space.")
+        if cpu_memory_usage > 0.7 * total_cpu_memory:
+            raise ValueError("Too large swap space. " + msg)
+        elif cpu_memory_usage > 0.4 * total_cpu_memory:
+            logger.warning("Possibly too large swap space. %s", msg)
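To make the swap-space thresholds in verify_with_parallel_config concrete, a small worked example with illustrative numbers (not taken from the patch):

# Swap space is allocated per GPU, so it scales with tensor_parallel_size.
GiB = 1 << 30
swap_space_bytes = 4 * GiB                                  # default swap_space = 4 GiB
tensor_parallel_size = 8
cpu_memory_usage = swap_space_bytes * tensor_parallel_size  # 32 GiB

total_cpu_memory = 64 * GiB
assert cpu_memory_usage > 0.4 * total_cpu_memory       # 32 GiB > 25.6 GiB -> warning only
assert not cpu_memory_usage > 0.7 * total_cpu_memory   # 32 GiB <= 44.8 GiB -> no error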
" + "Run with --disable-sliding-window to use prefix caching.") + + if (self.enable_prefix_caching and self.prefix_caching_hash_algo + not in get_args(PrefixCachingHashAlgo)): + raise ValueError( + "Unknown prefix caching hash algorithm: " + f"{self.prefix_caching_hash_algo}. Must be one of " + f"{get_args(PrefixCachingHashAlgo)}.") + + def verify_with_parallel_config( + self, + parallel_config: "ParallelConfig", + ) -> None: + total_cpu_memory = get_cpu_memory() + # FIXME(woosuk): Here, it is assumed that the GPUs in a tensor parallel + # group are in the same node. However, the GPUs may span multiple nodes. + num_gpus_per_node = parallel_config.tensor_parallel_size + cpu_memory_usage = self.swap_space_bytes * num_gpus_per_node + + msg = (f"{cpu_memory_usage / GiB_bytes:.2f} GiB out of the " + f"{total_cpu_memory / GiB_bytes:.2f} GiB total CPU memory " + "is allocated for the swap space.") + if cpu_memory_usage > 0.7 * total_cpu_memory: + raise ValueError("Too large swap space. " + msg) + elif cpu_memory_usage > 0.4 * total_cpu_memory: + logger.warning("Possibly too large swap space. %s", msg) diff --git a/vllm_mindspore/engine/arg_utils.py b/vllm_mindspore/engine/arg_utils.py index a7bfc2204126628a8d66e620b3e6ce195db50b37..0f0584918f57ccff26d678ede4dac7a282949b81 100644 --- a/vllm_mindspore/engine/arg_utils.py +++ b/vllm_mindspore/engine/arg_utils.py @@ -92,21 +92,21 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: return False # No Fp8 KV cache so far. - if self.kv_cache_dtype != "auto": - fp8_attention = self.kv_cache_dtype.startswith("fp8") - will_use_fa = (current_platform.is_cuda() - and not envs.is_set("VLLM_ATTENTION_BACKEND") - ) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1" - supported = False - if current_platform.is_rocm(): - supported = True - elif fp8_attention and will_use_fa: - from vllm.attention.utils.fa_utils import flash_attn_supports_fp8 - supported = flash_attn_supports_fp8() - if not supported: - _raise_or_fallback(feature_name="--kv-cache-dtype", - recommend_to_remove=False) - return False + # if self.kv_cache_dtype != "auto": + # fp8_attention = self.kv_cache_dtype.startswith("fp8") + # will_use_fa = ( + # current_platform.is_cuda() + # and not envs.is_set("VLLM_ATTENTION_BACKEND") + # ) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1" + # supported = False + # if fp8_attention and will_use_fa: + # from vllm.vllm_flash_attn.fa_utils import ( + # flash_attn_supports_fp8) + # supported = flash_attn_supports_fp8() + # if not supported: + # _raise_or_fallback(feature_name="--kv-cache-dtype", + # recommend_to_remove=False) + # return False # No Prompt Adapter so far. 
diff --git a/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py b/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py
index d13e4351c8252134be298c3172b01fe331e13ec9..4b422fc44e1c0c86733e0f838ca6be935a0d1257 100644
--- a/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py
+++ b/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py
@@ -39,6 +39,7 @@ from vllm.sequence import IntermediateTensors
 from vllm_mindspore.model_executor.models.attention_mask import (
     LowerTriangularMask)
 from vllm_mindspore.model_executor.models.model_base import MsModelBase
+from vllm_mindspore.utils import get_fa3_no_quant_layers
 
 try:
     # Need to apply dllm pd patch on vllm to use pd disagg related functions
@@ -77,6 +78,23 @@ class MfModelBase(MsModelBase):
         self.use_mla_op = \
             bool(vllm_config.additional_config
                  and vllm_config.additional_config.get('use_mla_op') == 1)
+        self.fa3_quant = \
+            bool(vllm_config.additional_config
+                 and vllm_config.additional_config.get('fa3_quant') == 1)
+        if self.fa3_quant and vllm_config.cache_config.cache_dtype == "auto":
+            raise RuntimeError(
+                "To use fa3_quant, you need to set "
+                "\"--kv-cache-dtype 'int8'\" in the startup command.")
+        if self.fa3_quant and not self.use_mla_op:
+            raise RuntimeError(
+                "To use fa3_quant, you need to set "
+                "\"'use_mla_op': 1\" in the additional-config of the "
+                "startup command.")
+        if self.fa3_quant:
+            self.fa3_no_quant_layers = get_fa3_no_quant_layers()
+        else:
+            self.fa3_no_quant_layers = []
+
         self._generate_model_config()
         if not hasattr(self, 'mf_model_config'):
            raise RuntimeError('mf_model_config not initialized')
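MfModelBase now refuses fa3_quant unless both use_mla_op and an int8 KV cache dtype are configured. A hypothetical offline-serving configuration that satisfies those checks (parameter spelling taken from the error messages above; not verified end to end):

# Hypothetical launch configuration consistent with the checks above:
# fa3_quant requires use_mla_op and --kv-cache-dtype int8.
from vllm import LLM

llm = LLM(
    model="path/to/deepseek-mla-model",      # placeholder model path
    kv_cache_dtype="int8",
    additional_config={"use_mla_op": 1, "fa3_quant": 1},
)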
diff --git a/vllm_mindspore/model_executor/models/model_base.py b/vllm_mindspore/model_executor/models/model_base.py
index 4d35d968449d602115bdab34f01f9fcde2fcc091..3a8c20bb10cbadec156b980ab6f9d90d8ef3459a 100644
--- a/vllm_mindspore/model_executor/models/model_base.py
+++ b/vllm_mindspore/model_executor/models/model_base.py
@@ -22,7 +22,7 @@ from typing import Any, Optional, Union, cast
 import mindspore as ms
 import numpy as np
 import vllm.envs as envs
-from mindspore import Tensor, mutable, nn
+from mindspore import Tensor, mutable, nn, ops
 from mindspore.common import dtype as mstype
 from vllm.attention.backends.abstract import AttentionType
 from vllm.config import VllmConfig, get_current_vllm_config
@@ -67,12 +67,15 @@ class AttentionWrapper:
 
 class MLAAttentionWrapper(AttentionWrapper):
 
-    def __init__(self):
+    def __init__(self, kv_cache_dtype=None):
         super().__init__()
         vllm_config = get_current_vllm_config()
         self.use_mla_op = bool(
             vllm_config.additional_config
             and vllm_config.additional_config.get('use_mla_op') == 1)
+        if kv_cache_dtype is None:
+            kv_cache_dtype = vllm_config.model_config.dtype
+        self.dtype = kv_cache_dtype
         if not self.use_mla_op:
             self.kv_cache = [
                 (
@@ -82,21 +85,39 @@ class MLAAttentionWrapper(AttentionWrapper):
                 range(vllm_config.parallel_config.pipeline_parallel_size)
             ]
         else:
+            self.fa3_quant = bool(
+                vllm_config.additional_config
+                and vllm_config.additional_config.get('fa3_quant') == 1)
             kv_lora_rank = getattr(vllm_config.model_config.hf_text_config,
                                    'kv_lora_rank', 0)
             qk_rope_head_dim = getattr(vllm_config.model_config.hf_text_config,
                                        'qk_rope_head_dim', 0)
             # k_shape, r_shape used for mla_op
-            k_shape = [*(self.kv_shape[0:-1]), kv_lora_rank
-                       ] if self.use_mla_op else None
-            r_shape = [*(self.kv_shape[0:-1]), qk_rope_head_dim
-                       ] if self.use_mla_op else None
-            self.kv_cache = [
-                (ms.mint.zeros(k_shape,
-                               dtype=vllm_config.model_config.dtype),
-                 ms.mint.zeros(r_shape, dtype=vllm_config.model_config.dtype))
-                for _ in range(
-                    vllm_config.parallel_config.pipeline_parallel_size)
-            ]
+            if self.fa3_quant:
+                k_shape = [*(self.kv_shape[0:-2]), kv_lora_rank]
+                r_shape = [*(self.kv_shape[0:-2]), qk_rope_head_dim]
+                self.kv_cache = [
+                    (ops.auto_generate.format_cast(
+                        ms.mint.zeros(k_shape, dtype=kv_cache_dtype), 29),
+                     ops.auto_generate.format_cast(
+                         ms.mint.zeros(r_shape,
+                                       dtype=vllm_config.model_config.dtype),
+                         29))
+                    for _ in range(
+                        vllm_config.parallel_config.pipeline_parallel_size)
+                ]
+
+            else:
+                k_shape = [*(self.kv_shape[0:-1]), kv_lora_rank]
+                r_shape = [*(self.kv_shape[0:-1]), qk_rope_head_dim]
+                self.kv_cache = [
+                    (ms.mint.zeros(k_shape,
+                                   dtype=vllm_config.model_config.dtype),
+                     ms.mint.zeros(r_shape,
+                                   dtype=vllm_config.model_config.dtype))
+                    for _ in range(
+                        vllm_config.parallel_config.pipeline_parallel_size)
+                ]
 
 class MsModelBase:
diff --git a/vllm_mindspore/utils.py b/vllm_mindspore/utils.py
index 445d49985ccc8e0044b0cb93a6ca3fcbe9c694c7..40beeb163de458a003e60029a461c1704b5c193a 100644
--- a/vllm_mindspore/utils.py
+++ b/vllm_mindspore/utils.py
@@ -54,6 +54,7 @@ STR_DTYPE_TO_MS_DTYPE = {
     "fp8": ms.uint8,
     "fp8_e4m3": ms.uint8,
     "fp8_e5m2": ms.uint8,
+    "int8": ms.int8,
 }
 
 FORMAT_TYPE = {
@@ -77,6 +78,11 @@ def create_kv_cache(kv_shape, dtype):
     return ms.mint.zeros(kv_shape, dtype=dtype)
 
 
+def get_fa3_no_quant_layers():
+    return ["0", "1", "2", "46", "47", "50", "54", "55", "56", \
+            "57", "58", "59", "60"]
+
+
 def get_valid_dtype(dtype):
     if isinstance(dtype, str):
         dtype = STR_DTYPE_TO_MS_DTYPE[dtype]
@@ -158,6 +164,7 @@ STR_DTYPE_TO_TENSOR_DTYPE = {
     "fp8": torch.uint8,
     "fp8_e4m3": torch.uint8,
     "fp8_e5m2": torch.uint8,
+    "int8": torch.int8,
 }
 
 STR_DTYPE_TO_MS_DTYPE = {
@@ -168,9 +175,20 @@ STR_DTYPE_TO_MS_DTYPE = {
     "fp8": mstype.uint8,
     "fp8_e4m3": mstype.uint8,
     "fp8_e5m2": mstype.uint8,
+    "int8": mstype.int8,
 }
 
 
+def get_kv_cache_dtype(vllm_config):
+    """Return the MindSpore dtype used to store the KV cache."""
+    kv_cache_dtype = vllm_config.model_config.dtype if \
+        vllm_config.cache_config.cache_dtype == "auto" \
+        else vllm_config.cache_config.cache_dtype
+    if kv_cache_dtype in STR_DTYPE_TO_MS_DTYPE:
+        kv_cache_dtype = STR_DTYPE_TO_MS_DTYPE[kv_cache_dtype]
+    return kv_cache_dtype
+
+
 class vllmModelBackendEnum(str, Enum):
     """Define the variable Enum of VLLM_MS_MODEL_BACKEND"""
     MF = 'MindFormers'
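get_fa3_no_quant_layers() lists the layer indices (as strings) whose KV cache is kept in the model dtype rather than int8 when fa3_quant is enabled, and get_kv_cache_dtype maps the configured cache dtype string to a MindSpore dtype, falling back to the model dtype for "auto". A small illustration of the mapping with stubbed config objects (illustrative only):

# Illustration of get_kv_cache_dtype's mapping with stubbed config objects.
from types import SimpleNamespace
import mindspore as ms

from vllm_mindspore.utils import get_kv_cache_dtype

cfg = SimpleNamespace(
    model_config=SimpleNamespace(dtype=ms.bfloat16),
    cache_config=SimpleNamespace(cache_dtype="int8"),
)
assert get_kv_cache_dtype(cfg) == ms.int8          # "int8" -> mstype.int8

cfg.cache_config.cache_dtype = "auto"
assert get_kv_cache_dtype(cfg) == ms.bfloat16      # "auto" -> model dtype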
diff --git a/vllm_mindspore/v1/core/kv_cache_utils.py b/vllm_mindspore/v1/core/kv_cache_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..777be9b44060ecf5943b0a62b36d3775ff80c81d
--- /dev/null
+++ b/vllm_mindspore/v1/core/kv_cache_utils.py
@@ -0,0 +1,131 @@
+# SPDX-License-Identifier: Apache-2.0
+
+# Adapted from
+# https://github.com/vllm-project/vllm/blob/v0.9.1/vllm/v1/core/kv_cache_utils.py
+#
+# Copyright 2025 Huawei Technologies Co., Ltd.
+# Copyright 2024-2025 The vLLM team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from vllm.config import VllmConfig
+from vllm.utils import cdiv
+from vllm.v1.core.kv_cache_utils import (
+    _get_kv_cache_config_uniform_page_size, _get_kv_cache_config_uniform_type,
+    check_enough_kv_cache_memory, create_kv_cache_group_specs,
+    is_kv_cache_page_size_uniform, is_kv_cache_type_uniform, logger,
+    unify_hybrid_kv_cache_specs)
+from vllm.v1.kv_cache_interface import (KVCacheConfig, KVCacheSpec,
+                                        KVCacheTensor)
+
+
+def get_max_concurrency_for_kv_cache_config_diff_page_size(
+        vllm_config: VllmConfig, kv_cache_config: KVCacheConfig) -> float:
+    """
+    Get the maximum concurrency for the given KV cache configuration.
+    """
+    block_size = kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size
+    num_block_per_request = cdiv(vllm_config.model_config.max_model_len,
+                                 block_size)
+    max_concurrency = kv_cache_config.num_blocks / num_block_per_request
+    return max_concurrency
+
+
+def _get_kv_cache_config_not_uniform(vllm_config: VllmConfig,
+                                     kv_cache_spec: dict[str, KVCacheSpec],
+                                     available_memory: int) -> KVCacheConfig:
+    """
+    Generates the KV cache configuration for a model whose layers may have
+    different KV-cache page sizes. Every layer gets the same number of blocks,
+    each sized by that layer's own page_size_bytes.
+
+    Args:
+        vllm_config: The global VllmConfig
+        kv_cache_spec: The kv cache spec of each attention layer in the model
+        available_memory: Memory available for KV cache in bytes.
+
+    Returns:
+        The generated KVCacheConfig
+    """
+
+    # different layers may have different page_size_bytes
+    page_sizes = sum(
+        [layer.page_size_bytes for layer in kv_cache_spec.values()])
+    num_blocks = int(available_memory // page_sizes)
+    # create different groups based on page_size_bytes
+    page_size_layer_map: dict[int, list[str]] = {}
+    kv_cache_tensors = []
+
+    for layer_name, layer in kv_cache_spec.items():
+        kv_cache_tensors.append(
+            KVCacheTensor(size=num_blocks * layer.page_size_bytes,
+                          shared_by=[layer_name]))
+        page_size_layer_map.setdefault(layer.page_size_bytes,
+                                       []).append(layer_name)
+
+    grouped_layer_names = list(page_size_layer_map.values())
+    kv_cache_config = KVCacheConfig(
+        num_blocks=num_blocks,
+        kv_cache_tensors=kv_cache_tensors,
+        kv_cache_groups=create_kv_cache_group_specs(kv_cache_spec,
+                                                    grouped_layer_names),
+    )
+
+    num_tokens = num_blocks * vllm_config.cache_config.block_size
+    num_tokens_str = f"{num_tokens:,}"
+    logger.info("GPU KV cache size: %s tokens", num_tokens_str)
+    max_model_len_str = f"{vllm_config.model_config.max_model_len:,}"
+    max_concurrency = get_max_concurrency_for_kv_cache_config_diff_page_size(
+        vllm_config, kv_cache_config)
+    logger.info("Maximum concurrency for %s tokens per request: %.2fx",
+                max_model_len_str, max_concurrency)
+    return kv_cache_config
+
+
+def get_kv_cache_config(
+    vllm_config: VllmConfig,
+    kv_cache_spec: dict[str, KVCacheSpec],
+    available_memory: int,
+) -> KVCacheConfig:
+    """
+    Generates the KV cache configuration for a model.
+
+    Args:
+        vllm_config: The global VllmConfig
+        kv_cache_spec: The kv cache spec of each attention layer in the model
+        available_memory: Memory available for KV cache in bytes.
+
+    Returns:
+        The generated KVCacheConfig
+    """
+    check_enough_kv_cache_memory(vllm_config, kv_cache_spec, available_memory)
+
+    if vllm_config.scheduler_config.disable_hybrid_kv_cache_manager:
+        unify_hybrid_kv_cache_specs(kv_cache_spec)
+
+    if is_kv_cache_type_uniform(kv_cache_spec):
+        # KV cache of all layers are the same, which is true for
+        # most models. Allocate the same amount of memory for
+        # each layer.
+        return _get_kv_cache_config_uniform_type(vllm_config, kv_cache_spec,
+                                                 available_memory)
+    elif is_kv_cache_page_size_uniform(kv_cache_spec):
+        # Model contains multiple attention types, but KV cache of all layers
+        # have the same physical memory per block per layer. Split the layers
+        # into groups with the same number of layers, and thus same total page
+        # size.
+        return _get_kv_cache_config_uniform_page_size(vllm_config,
+                                                      kv_cache_spec,
+                                                      available_memory)
+
+    return _get_kv_cache_config_not_uniform(vllm_config, kv_cache_spec,
+                                            available_memory)
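_get_kv_cache_config_not_uniform is the new fallback used when page sizes differ between layers (e.g. int8-quantized and bf16 MLA layers side by side). A worked example with illustrative sizes, not DeepSeek's real layer count:

# Two int8 MLA layers and one bf16 MLA layer, block_size = 16 tokens.
quant_layer_page = 16 * 1 * (512 + 64 * 2)   # 10,240 B per block (int8 nope + bf16 rope)
bf16_layer_page = 16 * 1 * 576 * 2           # 18,432 B per block (FullAttentionSpec, MLA)
per_block_total = 2 * quant_layer_page + 1 * bf16_layer_page   # 38,912 B

available_memory = 1 << 30                        # pretend 1 GiB is left for KV cache
num_blocks = available_memory // per_block_total  # 27,594 blocks, shared by every layer
# Each layer then gets its own KVCacheTensor of num_blocks * its page size,
# and layers are grouped by page size to form the kv_cache_groups.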
diff --git a/vllm_mindspore/v1/core/single_type_kv_cache_manager.py b/vllm_mindspore/v1/core/single_type_kv_cache_manager.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f9090b577efa9f28802c45ff20862d387eaaf6b
--- /dev/null
+++ b/vllm_mindspore/v1/core/single_type_kv_cache_manager.py
@@ -0,0 +1,68 @@
+# SPDX-License-Identifier: Apache-2.0
+
+# Adapted from
+# https://github.com/vllm-project/vllm/blob/v0.9.1/vllm/v1/core/single_type_kv_cache_manager.py
+#
+# Copyright 2025 Huawei Technologies Co., Ltd.
+# Copyright 2024-2025 The vLLM team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from vllm.v1.core.block_pool import BlockPool
+from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
+from vllm.v1.core.single_type_kv_cache_manager import (
+    FullAttentionManager, SingleTypeKVCacheManager, SlidingWindowManager)
+from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheSpec,
+                                        SlidingWindowSpec)
+
+from vllm_mindspore.v1.kv_cache_interface import MLAQuantFullAttentionSpec
+
+
+@classmethod
+def find_longest_cache_hit(
+    cls,
+    block_hashes: list[BlockHash],
+    max_length: int,
+    kv_cache_group_ids: list[int],
+    block_pool: BlockPool,
+    kv_cache_spec: KVCacheSpec,
+    use_eagle: bool,
+) -> tuple[list[KVCacheBlock], ...]:
+    assert isinstance(
+        kv_cache_spec, (FullAttentionSpec, MLAQuantFullAttentionSpec)), (
+            "FullAttentionManager can only be used for full attention "
+            "or mla quant full attention groups")
+    computed_blocks: tuple[list[KVCacheBlock], ...] = tuple(
+        [] for _ in range(len(kv_cache_group_ids)))
+    max_num_blocks = max_length // kv_cache_spec.block_size
+    for i, block_hash in zip(range(max_num_blocks), block_hashes):
+        # block_hashes is a chain of block hashes. If a block hash is not
+        # in the cached_block_hash_to_id, the following block hashes are
+        # not computed yet for sure.
+        if cached_block := block_pool.get_cached_block(block_hash,
+                                                       kv_cache_group_ids):
+            for computed, cached in zip(computed_blocks, cached_block):
+                computed.append(cached)
+        else:
+            break
+    if use_eagle and computed_blocks[0]:
+        for computed in computed_blocks:
+            computed.pop()
+    return computed_blocks
+
+
+spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = {
+    FullAttentionSpec: FullAttentionManager,
+    MLAQuantFullAttentionSpec: FullAttentionManager,
+    SlidingWindowSpec: SlidingWindowManager,
+}
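This override mirrors upstream FullAttentionManager.find_longest_cache_hit but widens the isinstance check so prefix-cache lookups also work for MLAQuantFullAttentionSpec groups, and spec_manager_map routes that spec to the ordinary FullAttentionManager. A stub illustration of the walk-until-miss behaviour (the stub block pool is illustrative, not vLLM's real BlockPool):

# The loop consumes block hashes in order and stops at the first cache miss.
class StubBlockPool:
    def __init__(self, cached_hashes):
        self.cached_hashes = set(cached_hashes)

    def get_cached_block(self, block_hash, kv_cache_group_ids):
        if block_hash in self.cached_hashes:
            return [f"block-for-{block_hash}" for _ in kv_cache_group_ids]
        return None

pool = StubBlockPool(cached_hashes={"h0", "h1"})
hits = []
for block_hash in ["h0", "h1", "h2", "h3"]:     # "h2" is the first miss
    cached = pool.get_cached_block(block_hash, kv_cache_group_ids=[0])
    if cached is None:
        break
    hits.append(cached[0])
assert hits == ["block-for-h0", "block-for-h1"]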
diff --git a/vllm_mindspore/v1/kv_cache_interface.py b/vllm_mindspore/v1/kv_cache_interface.py
new file mode 100644
index 0000000000000000000000000000000000000000..37e253f0c1712e04934b62e780db9c2237a89375
--- /dev/null
+++ b/vllm_mindspore/v1/kv_cache_interface.py
@@ -0,0 +1,45 @@
+# SPDX-License-Identifier: Apache-2.0
+
+# Adapted from
+# https://github.com/vllm-project/vllm/blob/v0.9.1/vllm/v1/kv_cache_interface.py
+#
+# Copyright 2025 Huawei Technologies Co., Ltd.
+# Copyright 2024-2025 The vLLM team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from vllm.config import VllmConfig
+from vllm.utils import cdiv
+from vllm.v1.kv_cache_interface import AttentionSpec
+
+
+class MLAQuantFullAttentionSpec(AttentionSpec):
+
+    @property
+    def type_id(self) -> str:
+        return f"mla_quant_full_attention_{self.block_size}" \
+               f"_{self.page_size_bytes}"
+
+    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
+        max_model_len = vllm_config.model_config.max_model_len
+        return cdiv(max_model_len, self.block_size) * self.page_size_bytes
+
+    @property
+    def page_size_bytes(self) -> int:
+        # For MLA we only store a single latent vector
+        coef = 1 if self.use_mla else 2
+        assert self.head_size == 576
+        ctkv_nope_dim = 512
+        qk_rope_dim = 64
+        # The ctkv (nope) part is stored as int8 (1 byte per element) while
+        # the rope part stays in the model dtype (bf16, 2 bytes per element),
+        # hence the qk_rope_dim * 2 term.
+        return coef * self.block_size * self.num_kv_heads * (ctkv_nope_dim +
+                                                             qk_rope_dim * 2)
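The byte accounting behind page_size_bytes, written out (assuming use_mla=True so coef == 1, and num_kv_heads == 1 as is typical for MLA):

# 512-dim latent stored as int8 plus 64-dim rope kept in bf16 -> 640 B/token.
block_size = 16
ctkv_nope_dim, qk_rope_dim = 512, 64
bytes_per_token = ctkv_nope_dim * 1 + qk_rope_dim * 2            # 640
page_size_bytes = 1 * block_size * 1 * (ctkv_nope_dim + qk_rope_dim * 2)
assert page_size_bytes == block_size * bytes_per_token           # 10,240 bytes per block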
diff --git a/vllm_mindspore/v1/worker/gpu_model_runner.py b/vllm_mindspore/v1/worker/gpu_model_runner.py
index b37b39bc1f718473c16cc86a3cbcb28fc73fc70a..11de216f59baa9216553eb07bd997a1d84f9a47a 100644
--- a/vllm_mindspore/v1/worker/gpu_model_runner.py
+++ b/vllm_mindspore/v1/worker/gpu_model_runner.py
@@ -25,7 +25,7 @@ import mindspore as ms
 import numpy as np
 import torch
 from mindspore import Generator as msGenerator
-from mindspore import Tensor, mint, mutable
+from mindspore import Tensor, mint, mutable, ops
 from vllm.attention import AttentionType
 from vllm.logger import init_logger
 from vllm.sampling_params import SamplingType
@@ -40,6 +40,8 @@ from vllm.v1.worker.utils import initialize_kv_cache_for_kv_sharing
 from vllm_mindspore.model_executor.layers.rotary_embedding import (
     InferMRotaryEmbedding as MRotaryEmbedding)
 from vllm_mindspore.utils import (create_kv_cache, get_dtype_size,
+                                  get_fa3_no_quant_layers,
                                   get_valid_dtype, is_310p)
+from vllm_mindspore.v1.kv_cache_interface import MLAQuantFullAttentionSpec
 
 logger = init_logger(__name__)
@@ -293,16 +294,42 @@ def _allocate_kv_cache_tensors(self, kv_cache_config):
     use_mla_op = bool(
         self.vllm_config.additional_config and
         self.vllm_config.additional_config.get('use_mla_op') == 1)
+    fa3_quant = bool(
+        self.vllm_config.additional_config
+        and self.vllm_config.additional_config.get('fa3_quant') == 1)
     kv_lora_rank = getattr(self.vllm_config.model_config.hf_text_config,
                            'kv_lora_rank', 0)
     qk_rope_head_dim = getattr(self.vllm_config.model_config.hf_text_config,
                                'qk_rope_head_dim', 0)
+    fa3_no_quant_layers = get_fa3_no_quant_layers()
+
+    def get_dtype_from_groups(layer_name, kv_cache_groups):
+        for group in kv_cache_groups:
+            if layer_name in group.layer_names:
+                kv_cache_spec = group.kv_cache_spec
+                use_mla = kv_cache_spec.use_mla
+                dtype = kv_cache_spec.dtype
+                coef = 1 if use_mla else 2
+                return dtype, coef
+        return None, None
 
     kv_cache_raw_tensors: dict[str, Tensor] = {}
     target_dtype = get_valid_dtype(dtype)
     dtype_size = get_dtype_size(target_dtype)
     for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
+        assert len(kv_cache_tensor.shared_by) == 1
         raw_tensors = []
+        layer_name = kv_cache_tensor.shared_by[0]
+        if fa3_quant:
+            target_dtype, coef = get_dtype_from_groups(
+                layer_name, kv_cache_config.kv_cache_groups)
+            dtype_size = get_dtype_size(target_dtype)
+        is_fa3_quant_layer = use_mla_op and fa3_quant and \
+            layer_name not in fa3_no_quant_layers
+        if is_fa3_quant_layer:
+            raw_tensor_shape = kv_cache_tensor.size
+        else:
+            raw_tensor_shape = kv_cache_tensor.size // dtype_size // coef
         raw_tensor_shape = kv_cache_tensor.size // dtype_size // coef
         for i in range(coef):
@@ -315,17 +342,31 @@ def _allocate_kv_cache_tensors(self, kv_cache_config):
                 get_dtype_size(self.dtype))
             4. kv cache shape: num_blocks, block_size, num_kv_heads, head_size
             """
-            raw_tensors.extend(
-                [mint.zeros(raw_tensor_shape, dtype=target_dtype)]
-                if not use_mla_op else [
+            if not use_mla_op:
+                raw_tensors.extend(
+                    [mint.zeros(raw_tensor_shape, dtype=target_dtype)])
+            elif is_fa3_quant_layer:
+                raw_tensors.extend(
+                    # rope_cache dtype is bfloat16, and the target_dtype is
+                    # int8, so need to add twice qk_rope_head_dim
+                    [
+                        mint.zeros(int(raw_tensor_shape * kv_lora_rank /
+                                       (kv_lora_rank + qk_rope_head_dim * 2)),
+                                   dtype=target_dtype),
+                        mint.zeros(int(raw_tensor_shape * qk_rope_head_dim /
+                                       (kv_lora_rank + qk_rope_head_dim * 2)),
+                                   dtype=self.vllm_config.model_config.dtype)
+                    ])
+            else:
+                raw_tensors.extend([
                     mint.zeros(int(raw_tensor_shape * kv_lora_rank /
                                    (kv_lora_rank + qk_rope_head_dim)),
                                dtype=target_dtype),
-                    # deepseek mla op need key cache and rope cache
                     mint.zeros(int(raw_tensor_shape * qk_rope_head_dim /
                                    (kv_lora_rank + qk_rope_head_dim)),
                                dtype=target_dtype)
                 ])
+
         for layer_name in kv_cache_tensor.shared_by:
             kv_cache_raw_tensors[layer_name] = tuple(raw_tensors)
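For an fa3-quant layer, kv_cache_tensor.size is already a byte budget covering the int8 nope cache plus the bf16 rope cache, and it is split proportionally: kv_lora_rank parts go to the int8 ctkv buffer and 2 * qk_rope_head_dim parts (bytes, since bf16 takes 2 bytes) go to the rope buffer. Worked numbers, assuming DeepSeek-style kv_lora_rank=512 and qk_rope_head_dim=64 (the byte budget is illustrative):

# Worked example of the fa3-quant split in _allocate_kv_cache_tensors.
kv_lora_rank, qk_rope_head_dim = 512, 64
raw_tensor_shape = 10_240_000            # bytes reserved for this layer (illustrative)

denom = kv_lora_rank + qk_rope_head_dim * 2                       # 640
nope_elems = int(raw_tensor_shape * kv_lora_rank / denom)         # 8,192,000 int8 elements
rope_elems = int(raw_tensor_shape * qk_rope_head_dim / denom)     # 1,024,000 bf16 elements

# int8 is 1 byte and bf16 is 2 bytes, so the two buffers together consume
# exactly the reserved byte budget.
assert nope_elems * 1 + rope_elems * 2 == raw_tensor_shape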
@@ -357,10 +398,14 @@ def _reshape_kv_cache_tensors(
     use_mla_op = bool(
         self.vllm_config.additional_config
         and self.vllm_config.additional_config.get('use_mla_op') == 1)
+    fa3_quant = bool(
+        self.vllm_config.additional_config
+        and self.vllm_config.additional_config.get('fa3_quant') == 1)
     kv_lora_rank = getattr(self.vllm_config.model_config.hf_text_config,
                            'kv_lora_rank', 0)
     qk_rope_head_dim = getattr(self.vllm_config.model_config.hf_text_config,
                                'qk_rope_head_dim', 0)
+    fa3_no_quant_layers = get_fa3_no_quant_layers()
     kv_caches: dict[str, tuple] = {}
     for i, kv_cache_group_spec in enumerate(kv_cache_config.kv_cache_groups):
         kv_cache_spec = kv_cache_group_spec.kv_cache_spec
@@ -369,13 +414,18 @@ def _reshape_kv_cache_tensors(
             raw_tensor = kv_cache_raw_tensors[layer_name]
             target_dtype = get_valid_dtype(kv_cache_spec.dtype)
             dtype_size = get_dtype_size(target_dtype)
-            num_blocks = \
-                (raw_tensor[0].numel()
-                 if not use_mla_op else
-                 # deepseek mla op need key cache and rope cache
-                 (raw_tensor[0].numel() + raw_tensor[1].numel())) * \
-                coef * dtype_size // kv_cache_spec.page_size_bytes
-            if isinstance(kv_cache_spec, FullAttentionSpec):
+            is_fa3_quant_layer = use_mla_op and fa3_quant and \
+                layer_name not in fa3_no_quant_layers
+            if is_fa3_quant_layer:
+                num_blocks = (raw_tensor[0].numel() + raw_tensor[1].numel() *
+                              2) * coef // kv_cache_spec.page_size_bytes
+            else:
+                num_blocks = \
+                    (raw_tensor[0].numel() if not use_mla_op else \
+                    (raw_tensor[0].numel() + raw_tensor[1].numel())) * \
+                    coef * dtype_size // kv_cache_spec.page_size_bytes
+            if isinstance(kv_cache_spec,
+                          (FullAttentionSpec, MLAQuantFullAttentionSpec)):
                 kv_cache_shape = self.attn_backends[i].get_kv_cache_shape(
                     num_blocks, kv_cache_spec.block_size,
                     kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
@@ -411,7 +461,20 @@ def _reshape_kv_cache_tensors(
                     else:
                         cache_block = kv_cache_raw_tensor.view(
                             kv_cache_shape[1:]).permute(*inv_order[1:])
-                    kv_cache_layer.append(cache_block)
+                    if fa3_quant:
+                        num_blocks, block_size, _, _ = cache_block.shape
+                        cache_block = ops.reshape(cache_block,
+                                                  (num_blocks, block_size, -1))
+                        from mindspore.common.api import _pynative_executor
+                        cache_block_nz = ops.auto_generate.format_cast(
+                            cache_block, 29)
+                        _pynative_executor.sync()
+                        import gc
+                        del cache_block
+                        gc.collect()
+                        kv_cache_layer.append(cache_block_nz)
+                    else:
+                        kv_cache_layer.append(cache_block)
                 kv_caches[layer_name] = mutable(tuple(kv_cache_layer))
             else:
                 raise NotImplementedError
@@ -664,6 +727,13 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
     forward_ctx = self.vllm_config.compilation_config.static_forward_context
     block_size = self.vllm_config.cache_config.block_size
     use_mla = self.vllm_config.model_config.use_mla
+    use_mla_op = bool(
+        self.vllm_config.additional_config
+        and self.vllm_config.additional_config.get('use_mla_op') == 1)
+    fa3_quant = bool(
+        self.vllm_config.additional_config
+        and self.vllm_config.additional_config.get('fa3_quant') == 1)
+    fa3_no_quant_layers = get_fa3_no_quant_layers()
     kv_cache_spec: dict[str, KVCacheSpec] = {}
     for layer_name, attn_module in forward_ctx.items():
         """
@@ -680,12 +750,24 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
                 sliding_window=attn_module.sliding_window,
                 use_mla=use_mla)
         else:
-            kv_cache_spec[layer_name] = FullAttentionSpec(
-                block_size=block_size,
-                num_kv_heads=attn_module.num_kv_heads,
-                head_size=attn_module.head_size,
-                dtype=self.kv_cache_dtype,
-                use_mla=use_mla)
+            kv_cache_dtype = self.kv_cache_dtype
+            fa3_quant_layer = layer_name not in fa3_no_quant_layers
+            if fa3_quant and not fa3_quant_layer:
+                kv_cache_dtype = self.vllm_config.model_config.dtype
+            if use_mla_op and fa3_quant and fa3_quant_layer:
+                kv_cache_spec[layer_name] = MLAQuantFullAttentionSpec(
+                    block_size=block_size,
+                    num_kv_heads=attn_module.num_kv_heads,
+                    head_size=attn_module.head_size,
+                    dtype=kv_cache_dtype,
+                    use_mla=use_mla)
+            else:
+                kv_cache_spec[layer_name] = FullAttentionSpec(
+                    block_size=block_size,
+                    num_kv_heads=attn_module.num_kv_heads,
+                    head_size=attn_module.head_size,
+                    dtype=kv_cache_dtype,
+                    use_mla=use_mla)
     elif attn_module.attn_type in (AttentionType.ENCODER,
                                    AttentionType.ENCODER_ONLY):
         # encoder-only attention does not need KV cache.
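Per-layer spec selection in get_kv_cache_spec: layers listed in get_fa3_no_quant_layers() keep a FullAttentionSpec with their dtype forced back to the model dtype, while the remaining decoder layers get the int8 MLAQuantFullAttentionSpec; this mix is exactly what produces the differing page sizes handled by _get_kv_cache_config_not_uniform. A condensed restatement of that branch (illustrative; it mirrors the logic above rather than replacing it):

# Condensed restatement of the per-layer choice in get_kv_cache_spec.
def choose_spec(layer_name, fa3_quant, use_mla_op, fa3_no_quant_layers,
                model_dtype, kv_cache_dtype):
    quantize_this_layer = layer_name not in fa3_no_quant_layers
    if fa3_quant and not quantize_this_layer:
        # excluded layers fall back to the model dtype (e.g. bf16)
        return ("FullAttentionSpec", model_dtype)
    if use_mla_op and fa3_quant and quantize_this_layer:
        return ("MLAQuantFullAttentionSpec", kv_cache_dtype)   # e.g. int8
    return ("FullAttentionSpec", kv_cache_dtype)

assert choose_spec("2", True, True, ["0", "1", "2"], "bf16", "int8") == \
    ("FullAttentionSpec", "bf16")
assert choose_spec("3", True, True, ["0", "1", "2"], "bf16", "int8") == \
    ("MLAQuantFullAttentionSpec", "int8")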