From 5924bbe43712d1e633e9f24fd32d8bb9260c78aa Mon Sep 17 00:00:00 2001 From: ccsszz Date: Thu, 28 Aug 2025 20:30:20 +0800 Subject: [PATCH] deepseek support fa3 quant --- vllm_mindspore/__init__.py | 26 +++ vllm_mindspore/config.py | 169 +++++++++++++++++- vllm_mindspore/engine/arg_utils.py | 18 +- .../models/mf_models/mindformers.py | 22 ++- .../model_executor/models/model_base.py | 44 +++-- vllm_mindspore/utils.py | 4 + vllm_mindspore/v1/core/kv_cache_utils.py | 140 +++++++++++++++ .../v1/core/single_type_kv_cache_manager.py | 68 +++++++ vllm_mindspore/v1/kv_cache_interface.py | 45 +++++ vllm_mindspore/v1/worker/gpu_model_runner.py | 133 ++++++++++++-- 10 files changed, 615 insertions(+), 54 deletions(-) create mode 100644 vllm_mindspore/v1/core/kv_cache_utils.py create mode 100644 vllm_mindspore/v1/core/single_type_kv_cache_manager.py create mode 100644 vllm_mindspore/v1/kv_cache_interface.py diff --git a/vllm_mindspore/__init__.py b/vllm_mindspore/__init__.py index 84c933ab8..3c9fb8e6f 100644 --- a/vllm_mindspore/__init__.py +++ b/vllm_mindspore/__init__.py @@ -84,6 +84,17 @@ from vllm_mindspore.v1.engine.core import shutdown vllm.v1.engine.core.DPEngineCoreProc.shutdown = shutdown +from vllm_mindspore.v1.core.kv_cache_utils import get_kv_cache_config + +vllm.v1.core.kv_cache_utils.get_kv_cache_config = get_kv_cache_config +vllm.v1.engine.core.get_kv_cache_config = get_kv_cache_config +from vllm_mindspore.v1.core.single_type_kv_cache_manager import ( + find_longest_cache_hit, spec_manager_map) + +vllm.v1.core.single_type_kv_cache_manager.FullAttentionManager.\ + find_longest_cache_hit = find_longest_cache_hit +vllm.v1.core.single_type_kv_cache_manager.spec_manager_map = spec_manager_map + from vllm_mindspore.utils import ( make_tensor_with_pad, async_tensor_h2d, @@ -91,6 +102,15 @@ from vllm_mindspore.utils import ( ms_memory_profiling, ) +from vllm_mindspore.config import CacheDType, _CacheConfig + +vllm.config.CacheConfig = _CacheConfig +vllm.config.CacheDType 
= CacheDType +import vllm.engine.arg_utils + +vllm.engine.arg_utils.CacheDType = CacheDType +vllm.engine.arg_utils.CacheConfig = _CacheConfig + vllm.utils.make_tensor_with_pad = make_tensor_with_pad vllm.utils.async_tensor_h2d = async_tensor_h2d vllm.utils.cuda_is_initialized = ascend_is_initialized @@ -171,11 +191,17 @@ from vllm_mindspore.worker.cache_engine import ( ms_swap_out, ) +from vllm_mindspore.utils import get_dtype_size import vllm.worker.cache_engine vllm.worker.cache_engine.CacheEngine._allocate_kv_cache = ms_allocate_kv_cache vllm.worker.cache_engine.CacheEngine.swap_in = ms_swap_in vllm.worker.cache_engine.CacheEngine.swap_out = ms_swap_out +vllm.worker.cache_engine.get_dtype_size = get_dtype_size + +import vllm.v1.kv_cache_interface + +vllm.v1.kv_cache_interface.get_dtype_size = get_dtype_size from vllm_mindspore.model_executor.model_loader.weight_utils import ( safetensors_weights_iterator, ) diff --git a/vllm_mindspore/config.py b/vllm_mindspore/config.py index faab407be..28f66acba 100644 --- a/vllm_mindspore/config.py +++ b/vllm_mindspore/config.py @@ -18,22 +18,27 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import hashlib import socket import threading import time from collections import Counter -from typing import Optional, Union +from dataclasses import field +from typing import Any, Literal, Optional, Union, get_args import msgspec import torch import vllm.envs as envs +from pydantic import SkipValidation +from pydantic.dataclasses import dataclass from transformers import PretrainedConfig from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass -from vllm.config import (_STR_DTYPE_TO_TORCH_DTYPE, CompilationConfig, - CompilationLevel, VllmConfig, _find_dtype, - _resolve_auto_dtype) +from vllm.config import (_STR_DTYPE_TO_TORCH_DTYPE, BlockSize, CacheConfig, + CompilationConfig, CompilationLevel, ParallelConfig, + PrefixCachingHashAlgo, VllmConfig, _find_dtype, + _resolve_auto_dtype, config) from vllm.logger import init_logger -from vllm.utils import random_uuid +from vllm.utils import GiB_bytes, get_cpu_memory, random_uuid from vllm_mindspore.utils import is_310p @@ -395,3 +400,157 @@ def stateless_destroy_socket_process_group( dp_group.close() logger.info("Socket process group for rank %d destroyed.", dp_group.rank) + + +CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2", "int8"] + + +@config +@dataclass +class _CacheConfig(CacheConfig): + """Configuration for the KV cache.""" + + block_size: SkipValidation[BlockSize] = None # type: ignore + """Size of a contiguous cache block in number of tokens. This is ignored on + neuron devices and set to `--max-model-len`. On CUDA devices, only block + sizes up to 32 are supported. On HPU devices, block size defaults to 128. + + This config has no static default. If left unspecified by the user, it will + be set in `Platform.check_and_update_configs()` based on the current + platform.""" + gpu_memory_utilization: float = 0.9 + """The fraction of GPU memory to be used for the model executor, which can + range from 0 to 1. For example, a value of 0.5 would imply 50% GPU memory + utilization. 
If unspecified, will use the default value of 0.9. This is a + per-instance limit, and only applies to the current vLLM instance. It does + not matter if you have another vLLM instance running on the same GPU. For + example, if you have two vLLM instances running on the same GPU, you can + set the GPU memory utilization to 0.5 for each instance.""" + swap_space: float = 4 + """Size of the CPU swap space per GPU (in GiB).""" + cache_dtype: CacheDType = "auto" + """Data type for kv cache storage. If "auto", will use model data type. + CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ROCm (AMD GPU) supports + fp8 (=fp8_e4m3).""" + is_attention_free: bool = False + """Whether the model is attention-free. This is primarily set in + `ModelConfig` and that value should be manually duplicated here.""" + num_gpu_blocks_override: Optional[int] = None + """Number of GPU blocks to use. This overrides the profiled `num_gpu_blocks` + if specified. Does nothing if `None`. Used for testing preemption.""" + sliding_window: Optional[int] = None + """Sliding window size for the KV cache. This is primarily set in + `ModelConfig` and that value should be manually duplicated here.""" + enable_prefix_caching: Optional[bool] = None + """Whether to enable prefix caching. Disabled by default for V0. Enabled by + default for V1.""" + prefix_caching_hash_algo: PrefixCachingHashAlgo = "builtin" + """Set the hash algorithm for prefix caching:\n + - "builtin" is Python's built-in hash.\n + - "sha256" is collision resistant but with certain overheads.""" + cpu_offload_gb: float = 0 + """The space in GiB to offload to CPU, per GPU. Default is 0, which means + no offloading. Intuitively, this argument can be seen as a virtual way to + increase the GPU memory size. For example, if you have one 24 GB GPU and + set this to 10, virtually you can think of it as a 34 GB GPU. Then you can + load a 13B model with BF16 weight, which requires at least 26GB GPU memory. 
+ Note that this requires fast CPU-GPU interconnect, as part of the model is + loaded from CPU memory to GPU memory on the fly in each model forward pass. + """ + calculate_kv_scales: bool = False + """This enables dynamic calculation of `k_scale` and `v_scale` when + kv_cache_dtype is fp8. If `False`, the scales will be loaded from the model + checkpoint if available. Otherwise, the scales will default to 1.0.""" + + # Will be set after profiling. + num_gpu_blocks: Optional[int] = field(default=None, init=False) + """The number of blocks to allocate for GPU memory.""" + num_cpu_blocks: Optional[int] = field(default=None, init=False) + """The number of blocks to allocate for CPU memory.""" + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + factors: list[Any] = [] + factors.append(self.cache_dtype) + # `cpu_offload_gb` does not use `torch.compile` yet. + hash_str = hashlib.md5(str(factors).encode(), + usedforsecurity=False).hexdigest() + return hash_str + + def __post_init__(self) -> None: + self.swap_space_bytes = self.swap_space * GiB_bytes + + self._verify_args() + self._verify_cache_dtype() + self._verify_prefix_caching() + + def metrics_info(self): + # convert cache_config to dict(key: str, value: str) for prometheus + # metrics info + return {key: str(value) for key, value in self.__dict__.items()} + + def _verify_args(self) -> None: + if self.cpu_offload_gb < 0: + raise ValueError("CPU offload space must be non-negative" + f", but got {self.cpu_offload_gb}") + + if self.gpu_memory_utilization > 1.0: + raise ValueError( + "GPU memory utilization must be less than 1.0. 
Got " + f"{self.gpu_memory_utilization}.") + + def _verify_cache_dtype(self) -> None: + if self.cache_dtype == "auto": + pass + elif self.cache_dtype in ("fp8", "fp8_e4m3", "fp8_e5m2", "int8"): + logger.info( + "Using %s data type to store kv cache. It reduces the GPU " + "memory footprint and boosts the performance. " + "Meanwhile, it may cause accuracy drop without a proper " + "scaling factor", self.cache_dtype) + else: + raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}") + + def _verify_prefix_caching(self) -> None: + if not self.enable_prefix_caching: + return + + if self.sliding_window is not None and not envs.VLLM_USE_V1: + raise NotImplementedError( + "Prefix caching is not supported with sliding window. " + "Run with --disable-sliding-window to use prefix caching.") + + if (self.enable_prefix_caching and self.prefix_caching_hash_algo + not in get_args(PrefixCachingHashAlgo)): + raise ValueError( + "Unknown prefix caching hash algorithm: " + f"{self.prefix_caching_hash_algo}. Must be one of " + f"{get_args(PrefixCachingHashAlgo)}.") + + def verify_with_parallel_config( + self, + parallel_config: "ParallelConfig", + ) -> None: + total_cpu_memory = get_cpu_memory() + # FIXME(woosuk): Here, it is assumed that the GPUs in a tensor parallel + # group are in the same node. However, the GPUs may span multiple nodes. + num_gpus_per_node = parallel_config.tensor_parallel_size + cpu_memory_usage = self.swap_space_bytes * num_gpus_per_node + + msg = (f"{cpu_memory_usage / GiB_bytes:.2f} GiB out of the " + f"{total_cpu_memory / GiB_bytes:.2f} GiB total CPU memory " + "is allocated for the swap space.") + if cpu_memory_usage > 0.7 * total_cpu_memory: + raise ValueError("Too large swap space. " + msg) + elif cpu_memory_usage > 0.4 * total_cpu_memory: + logger.warning("Possibly too large swap space. 
%s", msg) diff --git a/vllm_mindspore/engine/arg_utils.py b/vllm_mindspore/engine/arg_utils.py index fb72f32ce..f7a3d6aa6 100644 --- a/vllm_mindspore/engine/arg_utils.py +++ b/vllm_mindspore/engine/arg_utils.py @@ -95,22 +95,8 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: recommend_to_remove=False) return False - # No Fp8 KV cache so far. - if self.kv_cache_dtype != "auto": - fp8_attention = self.kv_cache_dtype.startswith("fp8") - will_use_fa = (current_platform.is_cuda() - and not envs.is_set("VLLM_ATTENTION_BACKEND") - ) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1" - supported = False - if current_platform.is_rocm(): - supported = True - elif fp8_attention and will_use_fa: - from vllm.attention.utils.fa_utils import flash_attn_supports_fp8 - supported = flash_attn_supports_fp8() - if not supported: - _raise_or_fallback(feature_name="--kv-cache-dtype", - recommend_to_remove=False) - return False + # The validation of kv_cache_dtype has been removed, as in fa3_quant, + # the kv_cache_dtype is required to be of type int8. # No Prompt Adapter so far. 
if self.enable_prompt_adapter: diff --git a/vllm_mindspore/model_executor/models/mf_models/mindformers.py b/vllm_mindspore/model_executor/models/mf_models/mindformers.py index 74bf154b4..34be519b5 100644 --- a/vllm_mindspore/model_executor/models/mf_models/mindformers.py +++ b/vllm_mindspore/model_executor/models/mf_models/mindformers.py @@ -66,7 +66,19 @@ class MindFormersForCausalLM(MsModelBase, SupportsPP): self.network, self.lm_head = self._create_network() self.casual_mask = self._create_mask() - + self.fa3_quant = self.network.quant_config.fa3_quant \ + if self.network.quant_config else False + self.fa3_quant_layer = self.network.quant_config.fa3_quant_layer \ + if self.network.quant_config else set() + # fa3_quant stores the quantized kv cache as int8, so reject any other cache dtype. + if self.fa3_quant and vllm_config.cache_config.cache_dtype != "int8": + raise RuntimeError( + "To use fa3_quant, you need to set " + "\"--kv-cache-dtype 'int8'\" in the startup command.") + if self.fa3_quant and not self.use_ringmla: + raise ValueError('To use fa3_quant, it is necessary to set use_ringmla to True.') + # used when allocate the kvcache in GPUModelRunner + vllm_config.quant_config = self.network.quant_config self._set_dynamic_inputs() self.set_modules({"model": self.network}) @@ -105,7 +117,13 @@ class MindFormersForCausalLM(MsModelBase, SupportsPP): # Initial kv_caches wrapper_func = (MLAAttentionWrapper if self.mla_config else AttentionWrapper) - return [wrapper_func() for _ in range(num_layers)] + if self.fa3_quant: + return [wrapper_func(fa3_quant=True, kv_cache_dtype=ms.int8) \ + if i in self.fa3_quant_layer \ + else wrapper_func(fa3_quant=True, kv_cache_dtype=self.model_config.dtype) \ + for i in range(num_layers)] + else: + return [wrapper_func() for _ in range(num_layers)] def get_kvcache(self): if not self.mla_config: diff --git a/vllm_mindspore/model_executor/models/model_base.py b/vllm_mindspore/model_executor/models/model_base.py 
index 0634f6052..a10528d94 100644 --- a/vllm_mindspore/model_executor/models/model_base.py +++ b/vllm_mindspore/model_executor/models/model_base.py @@ -22,7 +22,7 @@ from typing import Any, Optional, Union, cast import mindspore as ms import numpy as np import vllm.envs as envs -from mindspore import Tensor, mutable, nn +from mindspore import Tensor, mutable, nn, ops from mindspore.common import dtype as mstype from vllm.attention.backends.abstract import AttentionType from vllm.config import VllmConfig, get_current_vllm_config @@ -67,10 +67,13 @@ class AttentionWrapper: class MLAAttentionWrapper(AttentionWrapper): - def __init__(self): + def __init__(self, fa3_quant=False, kv_cache_dtype=None): super().__init__() vllm_config = get_current_vllm_config() self.use_ringmla = is_use_ringmla(vllm_config) + if kv_cache_dtype is None: + kv_cache_dtype = vllm_config.model_config.dtype + self.dtype = kv_cache_dtype if not self.use_ringmla: self.kv_cache = [ ( @@ -85,16 +88,33 @@ class MLAAttentionWrapper(AttentionWrapper): qk_rope_head_dim = getattr(vllm_config.model_config.hf_text_config, 'qk_rope_head_dim', 0) # k_shape, r_shape used for mla_op - k_shape = [*(self.kv_shape[0:-1]), kv_lora_rank - ] if self.use_ringmla else None - r_shape = [*(self.kv_shape[0:-1]), qk_rope_head_dim - ] if self.use_ringmla else None - self.kv_cache = [ - (ms.mint.zeros(k_shape, dtype=vllm_config.model_config.dtype), - ms.mint.zeros(r_shape, dtype=vllm_config.model_config.dtype)) - for _ in range( - vllm_config.parallel_config.pipeline_parallel_size) - ] + if fa3_quant: + # num_block is set to 1 because setting it to 0, + # format_cast ops may not recycle device memory + k_shape = [1, *(self.kv_shape[1:-2]), kv_lora_rank] + r_shape = [1, *(self.kv_shape[1:-2]), qk_rope_head_dim] + self.kv_cache = [ + (ops.auto_generate.format_cast( + ms.mint.zeros(k_shape, dtype=kv_cache_dtype), 29), + ops.auto_generate.format_cast( + ms.mint.zeros(r_shape, + dtype=vllm_config.model_config.dtype), + 29)) + 
for _ in range( + vllm_config.parallel_config.pipeline_parallel_size) + ] + + else: + k_shape = [*(self.kv_shape[0:-1]), kv_lora_rank] + r_shape = [*(self.kv_shape[0:-1]), qk_rope_head_dim] + self.kv_cache = [ + (ms.mint.zeros(k_shape, + dtype=vllm_config.model_config.dtype), + ms.mint.zeros(r_shape, + dtype=vllm_config.model_config.dtype)) + for _ in range( + vllm_config.parallel_config.pipeline_parallel_size) + ] class MsModelBase: diff --git a/vllm_mindspore/utils.py b/vllm_mindspore/utils.py index 8690fa57b..ca264031e 100644 --- a/vllm_mindspore/utils.py +++ b/vllm_mindspore/utils.py @@ -59,6 +59,7 @@ STR_DTYPE_TO_MS_DTYPE = { "fp8": ms.uint8, "fp8_e4m3": ms.uint8, "fp8_e5m2": ms.uint8, + "int8": ms.int8, } FORMAT_TYPE = { @@ -163,6 +164,7 @@ STR_DTYPE_TO_TENSOR_DTYPE = { "fp8": torch.uint8, "fp8_e4m3": torch.uint8, "fp8_e5m2": torch.uint8, + "int8": torch.int8, } STR_DTYPE_TO_MS_DTYPE = { @@ -173,6 +175,7 @@ STR_DTYPE_TO_MS_DTYPE = { "fp8": mstype.uint8, "fp8_e4m3": mstype.uint8, "fp8_e5m2": mstype.uint8, + "int8": mstype.int8, } @@ -301,6 +304,7 @@ def check_ready(): # Common environment variables of predict. set_context(jit_config={"jit_level": "O0", "infer_boost": "on"}) + # Graph dumping (save_graphs) stays off here; enable it only for local debugging. default_env = { "MS_INTERNAL_DISABLE_CUSTOM_KERNEL_LIST": "FlashAttentionScore,PagedAttention", diff --git a/vllm_mindspore/v1/core/kv_cache_utils.py b/vllm_mindspore/v1/core/kv_cache_utils.py new file mode 100644 index 000000000..cc797d461 --- /dev/null +++ b/vllm_mindspore/v1/core/kv_cache_utils.py @@ -0,0 +1,140 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from +# https://github.com/vllm-project/vllm/blob/v0.9.1/vllm/v1/core/kv_cache_utils.py +# +# Copyright 2025 Huawei Technologies Co., Ltd. +# Copyright 2024-2025 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from vllm.config import VllmConfig +from vllm.utils import cdiv +from vllm.v1.core.kv_cache_utils import ( + _get_kv_cache_config_uniform_page_size, _get_kv_cache_config_uniform_type, + check_enough_kv_cache_memory, create_kv_cache_group_specs, + is_kv_cache_page_size_uniform, is_kv_cache_type_uniform, logger, + unify_hybrid_kv_cache_specs) +from vllm.v1.kv_cache_interface import (KVCacheConfig, KVCacheSpec, + KVCacheTensor) + + +def get_max_concurrency_for_kv_cache_config_diff_page_size( + vllm_config: VllmConfig, kv_cache_config: KVCacheConfig) -> float: + """ + Get the maximum concurrency for the given KV cache configuration. + """ + block_size = kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size + num_block_per_request = cdiv(vllm_config.model_config.max_model_len, + block_size) + max_concurrency = kv_cache_config.num_blocks / num_block_per_request + return max_concurrency + + +def _get_kv_cache_config_not_uniform(vllm_config: VllmConfig, + kv_cache_spec: dict[str, KVCacheSpec], + available_memory: int) -> KVCacheConfig: + """ + Generates the KV cache configuration for a model with different page_size + of KV cache. now only use for fa3 quant deepseek network, in the case have + two AttentionSpec: + MLAQuantFullAttentionSpec for fa3 quant layer + FullAttentionSpec for not fa3 quant layer + Divide the available memory equally among all layers. + + Args: + vllm_config: The global VllmConfig + kv_cache_spec: The kv cache spec of each attention layer in the model + available_memory: Memory available for KV cache in bytes. 
+ + Returns: + The generated KVCacheConfig + """ + + # different layers may have different page_size_bytes + page_sizes = sum( + [layer.page_size_bytes for layer in kv_cache_spec.values()]) + num_blocks = int(available_memory // page_sizes) + # create different groups based on page_size_bytes + page_size_layer_map: dict[int, list[str]] = {} + kv_cache_tensors = [] + + for layer_name, layer in kv_cache_spec.items(): + kv_cache_tensors.append( + KVCacheTensor(size=num_blocks * layer.page_size_bytes, + shared_by=[layer_name])) + page_size_layer_map.setdefault(layer.page_size_bytes, + []).append(layer_name) + + grouped_layer_names = list(page_size_layer_map.values()) + kv_cache_config = KVCacheConfig( + num_blocks=num_blocks, + kv_cache_tensors=kv_cache_tensors, + kv_cache_groups=create_kv_cache_group_specs(kv_cache_spec, + grouped_layer_names), + ) + + num_tokens = num_blocks * vllm_config.cache_config.block_size + num_tokens_str = f"{num_tokens:,}" + logger.info("GPU KV cache size: %s tokens", num_tokens_str) + max_model_len_str = f"{vllm_config.model_config.max_model_len:,}" + max_concurrency = get_max_concurrency_for_kv_cache_config_diff_page_size( + vllm_config, kv_cache_config) + logger.info("Maximum concurrency for %s tokens per request: %.2fx", + max_model_len_str, max_concurrency) + return kv_cache_config + + +# Compared to the native vLLM, the _get_kv_cache_config_not_uniform method +# is added to support DeepSeek FA3 quant. Since some layers in fa3 quant DeepSeek +# network are not quant layers with non-uniform page sizes, and the native +# vLLM does not support varying page sizes across layers, +# this new method is implemented to handle such cases. +def get_kv_cache_config( + vllm_config: VllmConfig, + kv_cache_spec: dict[str, KVCacheSpec], + available_memory: int, +) -> KVCacheConfig: + """ + Generates the KV cache configuration for a model. 
+ + Args: + vllm_config: The global VllmConfig + kv_cache_spec: The kv cache spec of each attention layer in the model + available_memory: Memory available for KV cache in bytes. + + Returns: + The generated KVCacheConfigs + """ + check_enough_kv_cache_memory(vllm_config, kv_cache_spec, available_memory) + + if vllm_config.scheduler_config.disable_hybrid_kv_cache_manager: + unify_hybrid_kv_cache_specs(kv_cache_spec) + + if is_kv_cache_type_uniform(kv_cache_spec): + # KV cache of all layers are the same, which is true for + # most models. Allocate the same amount of memory for + # each layer. + return _get_kv_cache_config_uniform_type(vllm_config, kv_cache_spec, + available_memory) + elif is_kv_cache_page_size_uniform(kv_cache_spec): + # Model contains multiple attention types, but KV cache of all layers + # have the same physical memory per block per layer. Split the layers + # into groups with the same number of layers, and thus same total page + # size. + return _get_kv_cache_config_uniform_page_size(vllm_config, + kv_cache_spec, + available_memory) + + return _get_kv_cache_config_not_uniform(vllm_config, kv_cache_spec, + available_memory) diff --git a/vllm_mindspore/v1/core/single_type_kv_cache_manager.py b/vllm_mindspore/v1/core/single_type_kv_cache_manager.py new file mode 100644 index 000000000..5f9090b57 --- /dev/null +++ b/vllm_mindspore/v1/core/single_type_kv_cache_manager.py @@ -0,0 +1,68 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from +# https://github.com/vllm-project/vllm/blob/v0.9.1/vllm/v1/core/single_type_kv_cache_manager.py +# +# Copyright 2025 Huawei Technologies Co., Ltd. +# Copyright 2024-2025 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from vllm.v1.core.block_pool import BlockPool +from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock +from vllm.v1.core.single_type_kv_cache_manager import ( + FullAttentionManager, SingleTypeKVCacheManager, SlidingWindowManager) +from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheSpec, + SlidingWindowSpec) + +from vllm_mindspore.v1.kv_cache_interface import MLAQuantFullAttentionSpec + + +@classmethod +def find_longest_cache_hit( + cls, + block_hashes: list[BlockHash], + max_length: int, + kv_cache_group_ids: list[int], + block_pool: BlockPool, + kv_cache_spec: KVCacheSpec, + use_eagle: bool, +) -> tuple[list[KVCacheBlock], ...]: + assert isinstance( + kv_cache_spec, (FullAttentionSpec, MLAQuantFullAttentionSpec)), ( + "FullAttentionManager can only be used for full attention " + "or mla quant full attention groups") + computed_blocks: tuple[list[KVCacheBlock], ...] = tuple( + [] for _ in range(len(kv_cache_group_ids))) + max_num_blocks = max_length // kv_cache_spec.block_size + for i, block_hash in zip(range(max_num_blocks), block_hashes): + # block_hashes is a chain of block hashes. If a block hash is not + # in the cached_block_hash_to_id, the following block hashes are + # not computed yet for sure. 
+ if cached_block := block_pool.get_cached_block(block_hash, + kv_cache_group_ids): + for computed, cached in zip(computed_blocks, cached_block): + computed.append(cached) + else: + break + if use_eagle and computed_blocks[0]: + for computed in computed_blocks: + computed.pop() + return computed_blocks + + +spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = { + FullAttentionSpec: FullAttentionManager, + MLAQuantFullAttentionSpec: FullAttentionManager, + SlidingWindowSpec: SlidingWindowManager, +} diff --git a/vllm_mindspore/v1/kv_cache_interface.py b/vllm_mindspore/v1/kv_cache_interface.py new file mode 100644 index 000000000..37e253f0c --- /dev/null +++ b/vllm_mindspore/v1/kv_cache_interface.py @@ -0,0 +1,45 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from +# https://github.com/vllm-project/vllm/blob/v0.9.1/vllm/v1/kv_cache_interface.py +# +# Copyright 2025 Huawei Technologies Co., Ltd. +# Copyright 2024-2025 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from vllm.config import VllmConfig +from vllm.utils import cdiv +from vllm.v1.kv_cache_interface import AttentionSpec + + +class MLAQuantFullAttentionSpec(AttentionSpec): + + @property + def type_id(self) -> str: + return f"mla_quant_full_attention_{self.block_size}" \ + f"_{self.page_size_bytes}" + + def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: + max_model_len = vllm_config.model_config.max_model_len + return cdiv(max_model_len, self.block_size) * self.page_size_bytes + + @property + def page_size_bytes(self) -> int: + # For MLA we only store a single latent vector + coef = 1 if self.use_mla else 2 + assert self.head_size == 576 + ctkv_nope_dim = 512 + qk_rope_dim = 64 + return coef * self.block_size * self.num_kv_heads * (ctkv_nope_dim + + qk_rope_dim * 2) diff --git a/vllm_mindspore/v1/worker/gpu_model_runner.py b/vllm_mindspore/v1/worker/gpu_model_runner.py index 743660c11..b70f6744f 100644 --- a/vllm_mindspore/v1/worker/gpu_model_runner.py +++ b/vllm_mindspore/v1/worker/gpu_model_runner.py @@ -25,7 +25,7 @@ import mindspore as ms import numpy as np import torch from mindspore import Generator as msGenerator -from mindspore import Tensor, mint, mutable +from mindspore import Tensor, mint, mutable, ops from vllm.attention import AttentionType from vllm.logger import init_logger from vllm.sampling_params import SamplingType @@ -42,6 +42,7 @@ from vllm_mindspore.model_executor.layers.rotary_embedding import ( from vllm_mindspore.model_executor.models.utils import is_use_ringmla from vllm_mindspore.utils import (create_kv_cache, get_dtype_size, get_valid_dtype, is_310p) +from vllm_mindspore.v1.kv_cache_interface import MLAQuantFullAttentionSpec logger = init_logger(__name__) @@ -290,17 +291,44 @@ def _allocate_kv_cache_tensors(self, kv_cache_config): coef = 1 if use_mla else 2 # Determine whether deepseek use mla op use_ringmla = is_use_ringmla(self.vllm_config) + fa3_quant = self.vllm_config.quant_config.fa3_quant \ + if 
self.vllm_config.quant_config else False + fa3_quant_layer = self.vllm_config.quant_config.fa3_quant_layer \ + if self.vllm_config.quant_config else set() kv_lora_rank = getattr(self.vllm_config.model_config.hf_text_config, 'kv_lora_rank', 0) qk_rope_head_dim = getattr(self.vllm_config.model_config.hf_text_config, 'qk_rope_head_dim', 0) + def get_dtype_from_groups(layer_name, kv_cache_groups): + for group in kv_cache_groups: + if layer_name in group.layer_names: + kv_cache_spec = group.kv_cache_spec + use_mla = kv_cache_spec.use_mla + dtype = kv_cache_spec.dtype + coef = 1 if use_mla else 2 + return dtype, coef + raise ValueError(f"Layer {layer_name} is not in any kv cache group.") + kv_cache_raw_tensors: dict[str, Tensor] = {} target_dtype = get_valid_dtype(dtype) dtype_size = get_dtype_size(target_dtype) for kv_cache_tensor in kv_cache_config.kv_cache_tensors: + assert not fa3_quant or len(kv_cache_tensor.shared_by) == 1 raw_tensors = [] + + layer_name = kv_cache_tensor.shared_by[0] + if fa3_quant: + # fa3_quant have two groups + # fa3 quant layer target_dtype is int8 + # no fa3 quant layer target_dtype is bfloat16 + target_dtype, coef = get_dtype_from_groups( + layer_name, kv_cache_config.kv_cache_groups) + dtype_size = get_dtype_size(target_dtype) + is_fa3_quant_layer = use_ringmla and fa3_quant and \ + int(layer_name) in fa3_quant_layer raw_tensor_shape = kv_cache_tensor.size // dtype_size // coef + # for fa3_quant_layer, target_dtype is int8, dtype_size is 1 for i in range(coef): """ Formulas for calculating each parameter: @@ -312,17 +340,48 @@ def _allocate_kv_cache_tensors(self, kv_cache_config): get_dtype_size(self.dtype)) 4. 
kv cache shape: num_blocks, block_size, num_kv_heads, head_size """ - raw_tensors.extend( - [mint.zeros(raw_tensor_shape, dtype=target_dtype)] - if not use_ringmla else [ + if not use_ringmla: + raw_tensors.extend( + [mint.zeros(raw_tensor_shape, dtype=target_dtype)]) + elif is_fa3_quant_layer: + """ + for fa3_quant_layer, k_cache is int8, v_cache is bfloat16 + k_cache shape: + [num_block, block_size, 1(head_dim), 512(kv_lora_rank)] + v_cache shape: + [num_block, block_size, 1(head_dim), 64(qk_rope_head_dim)] + and target_dtype is int8, + raw_tensor_shape equals to kv_cache_tensor.size + The bytes occupied by k_cache is + num_block*block_size*512* 1bytes(int8) + The bytes occupied by v_cache is + num_block*block_size*64* 2bytes(bfloat16) + so k_cache row tensor shape is: + raw_tensor_shape*kv_lora_rank/ + (kv_lora_rank+qk_rope_head_dim*2) + v_cache row tensor shape is: + raw_tensor_shape*qk_rope_head_dim / + (kv_lora_rank+qk_rope_head_dim*2) + """ + raw_tensors.extend( + [ + mint.zeros(int(raw_tensor_shape * kv_lora_rank / + (kv_lora_rank + qk_rope_head_dim * 2)), + dtype=target_dtype), + mint.zeros(int(raw_tensor_shape * qk_rope_head_dim / + (kv_lora_rank + qk_rope_head_dim * 2)), + dtype=self.vllm_config.model_config.dtype) + ]) + else: + raw_tensors.extend([ mint.zeros(int(raw_tensor_shape * kv_lora_rank / (kv_lora_rank + qk_rope_head_dim)), dtype=target_dtype), - # deepseek mla op need key cache and rope cache mint.zeros(int(raw_tensor_shape * qk_rope_head_dim / (kv_lora_rank + qk_rope_head_dim)), dtype=target_dtype) ]) + for layer_name in kv_cache_tensor.shared_by: kv_cache_raw_tensors[layer_name] = tuple(raw_tensors) @@ -352,6 +411,10 @@ def _reshape_kv_cache_tensors( """ # Determine whether deepseek use mla op use_ringmla = is_use_ringmla(self.vllm_config) + fa3_quant = self.vllm_config.quant_config.fa3_quant \ + if self.vllm_config.quant_config else False + fa3_quant_layer = self.vllm_config.quant_config.fa3_quant_layer \ + if self.vllm_config.quant_config 
else set() kv_lora_rank = getattr(self.vllm_config.model_config.hf_text_config, 'kv_lora_rank', 0) qk_rope_head_dim = getattr(self.vllm_config.model_config.hf_text_config, @@ -364,13 +427,20 @@ def _reshape_kv_cache_tensors( raw_tensor = kv_cache_raw_tensors[layer_name] target_dtype = get_valid_dtype(kv_cache_spec.dtype) dtype_size = get_dtype_size(target_dtype) - num_blocks = \ - (raw_tensor[0].numel() - if not use_ringmla else - # deepseek mla op need key cache and rope cache - (raw_tensor[0].numel() + raw_tensor[1].numel())) * \ - coef * dtype_size // kv_cache_spec.page_size_bytes - if isinstance(kv_cache_spec, FullAttentionSpec): + is_fa3_quant_layer = use_ringmla and fa3_quant and \ + int(layer_name) in fa3_quant_layer + # fa3_quant_layer k_cache is int8, v_cache is bfloat16 + # so need to raw_tensor[0].numel() * 1 + raw_tensor[1].numel() * 2 + if is_fa3_quant_layer: + num_blocks = (raw_tensor[0].numel() + raw_tensor[1].numel() * + 2) * coef // kv_cache_spec.page_size_bytes + else: + num_blocks = \ + (raw_tensor[0].numel() if not use_ringmla else \ + (raw_tensor[0].numel() + raw_tensor[1].numel())) * \ + coef * dtype_size // kv_cache_spec.page_size_bytes + if isinstance(kv_cache_spec, + (FullAttentionSpec, MLAQuantFullAttentionSpec)): kv_cache_shape = self.attn_backends[i].get_kv_cache_shape( num_blocks, kv_cache_spec.block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) @@ -406,7 +476,16 @@ def _reshape_kv_cache_tensors( else: cache_block = kv_cache_raw_tensor.view( kv_cache_shape[1:]).permute(*inv_order[1:]) - kv_cache_layer.append(cache_block) + if fa3_quant: + # for fa3_quant, kvcache need be nz format due to ops + num_blocks, block_size, _, _ = cache_block.shape + cache_block = ops.reshape(cache_block, + (num_blocks, block_size, -1)) + cache_block_nz = ops.auto_generate.format_cast( + cache_block, 29) + kv_cache_layer.append(cache_block_nz) + else: + kv_cache_layer.append(cache_block) kv_caches[layer_name] = mutable(tuple(kv_cache_layer)) else: 
raise NotImplementedError @@ -659,6 +738,10 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: forward_ctx = self.vllm_config.compilation_config.static_forward_context block_size = self.vllm_config.cache_config.block_size use_mla = self.vllm_config.model_config.use_mla + fa3_quant = self.vllm_config.quant_config.fa3_quant \ + if self.vllm_config.quant_config else False + fa3_quant_layer = self.vllm_config.quant_config.fa3_quant_layer \ + if self.vllm_config.quant_config else set() kv_cache_spec: dict[str, KVCacheSpec] = {} for layer_name, attn_module in forward_ctx.items(): """ @@ -675,12 +758,24 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: sliding_window=attn_module.sliding_window, use_mla=use_mla) else: - kv_cache_spec[layer_name] = FullAttentionSpec( - block_size=block_size, - num_kv_heads=attn_module.num_kv_heads, - head_size=attn_module.head_size, - dtype=self.kv_cache_dtype, - use_mla=use_mla) + kv_cache_dtype = self.kv_cache_dtype + is_fa3_quant_layer = int(layer_name) in fa3_quant_layer + if fa3_quant and not is_fa3_quant_layer: + kv_cache_dtype = self.vllm_config.model_config.dtype + if fa3_quant and is_fa3_quant_layer: + kv_cache_spec[layer_name] = MLAQuantFullAttentionSpec( + block_size=block_size, + num_kv_heads=attn_module.num_kv_heads, + head_size=attn_module.head_size, + dtype=kv_cache_dtype, + use_mla=use_mla) + else: + kv_cache_spec[layer_name] = FullAttentionSpec( + block_size=block_size, + num_kv_heads=attn_module.num_kv_heads, + head_size=attn_module.head_size, + dtype=kv_cache_dtype, + use_mla=use_mla) elif attn_module.attn_type in (AttentionType.ENCODER, AttentionType.ENCODER_ONLY): # encoder-only attention does not need KV cache. -- Gitee