diff --git a/.gitmodules b/.gitmodules index b85ca99e1a54fc4460daa421913f1fe0416230fa..01625342abc79bfced64ceba3992dee31089493c 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,4 +1,4 @@ [submodule "tests/mindformers"] path = tests/mindformers url = https://gitee.com/mindspore/mindformers.git - branch = br_infer_boom + branch = br_feature_infer diff --git a/tests/mindformers b/tests/mindformers index a96a0cf9a0b1e3efee819c17dcf050604caa3512..6aaf2bbbbab06dabe483b0ba50d6e447aa2cc2cb 160000 --- a/tests/mindformers +++ b/tests/mindformers @@ -1 +1 @@ -Subproject commit a96a0cf9a0b1e3efee819c17dcf050604caa3512 +Subproject commit 6aaf2bbbbab06dabe483b0ba50d6e447aa2cc2cb diff --git a/vllm_mindspore/__init__.py b/vllm_mindspore/__init__.py index a4799e8115fc714a95c613e48ea2ca40b0c7030d..2c7b7ebe9c031d06df720a639d096e1d31999bdb 100644 --- a/vllm_mindspore/__init__.py +++ b/vllm_mindspore/__init__.py @@ -40,6 +40,7 @@ from vllm_mindspore.platforms.ascend import AscendPlatform ascend_platform = AscendPlatform() + import vllm.config vllm.config.current_platform = ascend_platform @@ -75,12 +76,27 @@ from vllm_mindspore.v1.engine.core import ( vllm.v1.engine.core.DPEngineCoreProc._init_data_parallel = _init_data_parallel vllm.v1.engine.core.DPEngineCoreProc.shutdown = shutdown +from vllm_mindspore.v1.core.kv_cache_utils import ( + get_kv_cache_config, +) +vllm.v1.core.kv_cache_utils.get_kv_cache_config = get_kv_cache_config +vllm.v1.engine.core.get_kv_cache_config = get_kv_cache_config +from vllm_mindspore.v1.core.single_type_kv_cache_manager import find_longest_cache_hit, spec_manager_map +vllm.v1.core.single_type_kv_cache_manager.FullAttentionManager.find_longest_cache_hit = find_longest_cache_hit +vllm.v1.core.single_type_kv_cache_manager.spec_manager_map = spec_manager_map from vllm_mindspore.utils import ( make_tensor_with_pad, async_tensor_h2d, ascend_is_initialized, ms_memory_profiling, ) +from dataclasses import fields, dataclass +from vllm_mindspore.config import CacheDType, _CacheConfig +vllm.config.CacheConfig = _CacheConfig +vllm.config.CacheDType = CacheDType +import vllm.engine.arg_utils +vllm.engine.arg_utils.CacheDType = CacheDType +vllm.engine.arg_utils.CacheConfig = _CacheConfig vllm.utils.make_tensor_with_pad = make_tensor_with_pad vllm.utils.async_tensor_h2d = async_tensor_h2d @@ -161,12 +177,16 @@ from vllm_mindspore.worker.cache_engine import ( ms_swap_in, ms_swap_out, ) - +from vllm_mindspore.utils import get_dtype_size import vllm.worker.cache_engine vllm.worker.cache_engine.CacheEngine._allocate_kv_cache = ms_allocate_kv_cache vllm.worker.cache_engine.CacheEngine.swap_in = ms_swap_in vllm.worker.cache_engine.CacheEngine.swap_out = ms_swap_out +vllm.worker.cache_engine.get_dtype_size = get_dtype_size + +import vllm.v1.kv_cache_interface +vllm.v1.kv_cache_interface.get_dtype_size = get_dtype_size from vllm_mindspore.model_executor.model_loader.weight_utils import ( safetensors_weights_iterator, ) diff --git a/vllm_mindspore/config.py b/vllm_mindspore/config.py index c464227bf6afc8cbe1d34d8c5356de2bec89c1ce..80b4db06f941bafe7ca030cccf3b6015ff7c68da 100644 --- a/vllm_mindspore/config.py +++ b/vllm_mindspore/config.py @@ -31,10 +31,26 @@ from transformers import PretrainedConfig from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass from vllm.config import (_STR_DTYPE_TO_TORCH_DTYPE, CompilationConfig, CompilationLevel, VllmConfig, _find_dtype, + PrefixCachingHashAlgo, config, BlockSize, CacheConfig, _resolve_auto_dtype) from vllm.logger import init_logger 
from vllm.utils import random_uuid +import hashlib +from collections import Counter +from dataclasses import field +from typing import (Any, Literal, Optional, Union, get_args) + +import torch +from pydantic import SkipValidation +from pydantic.dataclasses import dataclass +from transformers import PretrainedConfig + +import vllm.envs as envs +from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass +from vllm.logger import init_logger +from vllm.transformers_utils.config import (try_get_safetensors_metadata) +from vllm.utils import (GiB_bytes, get_cpu_memory, random_uuid) logger = init_logger(__name__) @@ -386,3 +402,155 @@ def stateless_destroy_socket_process_group( dp_group.close() logger.info("Socket process group for rank %d destroyed.", dp_group.rank) + +CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2", "int8"] + +@config +@dataclass +class _CacheConfig(CacheConfig): + """Configuration for the KV cache.""" + + block_size: SkipValidation[BlockSize] = None # type: ignore + """Size of a contiguous cache block in number of tokens. This is ignored on + neuron devices and set to `--max-model-len`. On CUDA devices, only block + sizes up to 32 are supported. On HPU devices, block size defaults to 128. + + This config has no static default. If left unspecified by the user, it will + be set in `Platform.check_and_update_configs()` based on the current + platform.""" + gpu_memory_utilization: float = 0.9 + """The fraction of GPU memory to be used for the model executor, which can + range from 0 to 1. For example, a value of 0.5 would imply 50% GPU memory + utilization. If unspecified, will use the default value of 0.9. This is a + per-instance limit, and only applies to the current vLLM instance. It does + not matter if you have another vLLM instance running on the same GPU. For + example, if you have two vLLM instances running on the same GPU, you can + set the GPU memory utilization to 0.5 for each instance.""" + swap_space: float = 4 + """Size of the CPU swap space per GPU (in GiB).""" + cache_dtype: CacheDType = "auto" + """Data type for kv cache storage. If "auto", will use model data type. + CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ROCm (AMD GPU) supports + fp8 (=fp8_e4m3).""" + is_attention_free: bool = False + """Whether the model is attention-free. This is primarily set in + `ModelConfig` and that value should be manually duplicated here.""" + num_gpu_blocks_override: Optional[int] = None + """Number of GPU blocks to use. This overrides the profiled `num_gpu_blocks` + if specified. Does nothing if `None`. Used for testing preemption.""" + sliding_window: Optional[int] = None + """Sliding window size for the KV cache. This is primarily set in + `ModelConfig` and that value should be manually duplicated here.""" + enable_prefix_caching: Optional[bool] = None + """Whether to enable prefix caching. Disabled by default for V0. Enabled by + default for V1.""" + prefix_caching_hash_algo: PrefixCachingHashAlgo = "builtin" + """Set the hash algorithm for prefix caching:\n + - "builtin" is Python's built-in hash.\n + - "sha256" is collision resistant but with certain overheads.""" + cpu_offload_gb: float = 0 + """The space in GiB to offload to CPU, per GPU. Default is 0, which means + no offloading. Intuitively, this argument can be seen as a virtual way to + increase the GPU memory size. For example, if you have one 24 GB GPU and + set this to 10, virtually you can think of it as a 34 GB GPU. 
Then you can + load a 13B model with BF16 weight, which requires at least 26GB GPU memory. + Note that this requires fast CPU-GPU interconnect, as part of the model is + loaded from CPU memory to GPU memory on the fly in each model forward pass. + """ + calculate_kv_scales: bool = False + """This enables dynamic calculation of `k_scale` and `v_scale` when + kv_cache_dtype is fp8. If `False`, the scales will be loaded from the model + checkpoint if available. Otherwise, the scales will default to 1.0.""" + + # Will be set after profiling. + num_gpu_blocks: Optional[int] = field(default=None, init=False) + """The number of blocks to allocate for GPU memory.""" + num_cpu_blocks: Optional[int] = field(default=None, init=False) + """The number of blocks to allocate for CPU memory.""" + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + factors: list[Any] = [] + factors.append(self.cache_dtype) + # `cpu_offload_gb` does not use `torch.compile` yet. + hash_str = hashlib.md5(str(factors).encode(), + usedforsecurity=False).hexdigest() + return hash_str + + def __post_init__(self) -> None: + self.swap_space_bytes = self.swap_space * GiB_bytes + + self._verify_args() + self._verify_cache_dtype() + self._verify_prefix_caching() + + def metrics_info(self): + # convert cache_config to dict(key: str, value: str) for prometheus + # metrics info + return {key: str(value) for key, value in self.__dict__.items()} + + def _verify_args(self) -> None: + if self.cpu_offload_gb < 0: + raise ValueError("CPU offload space must be non-negative" + f", but got {self.cpu_offload_gb}") + + if self.gpu_memory_utilization > 1.0: + raise ValueError( + "GPU memory utilization must be less than 1.0. Got " + f"{self.gpu_memory_utilization}.") + + def _verify_cache_dtype(self) -> None: + if self.cache_dtype == "auto": + pass + elif self.cache_dtype in ("fp8", "fp8_e4m3", "fp8_e5m2", "int8"): + logger.info( + "Using fp8 data type to store kv cache. It reduces the GPU " + "memory footprint and boosts the performance. " + "Meanwhile, it may cause accuracy drop without a proper " + "scaling factor") + else: + raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}") + + def _verify_prefix_caching(self) -> None: + if not self.enable_prefix_caching: + return + + if self.sliding_window is not None and not envs.VLLM_USE_V1: + raise NotImplementedError( + "Prefix caching is not supported with sliding window. " + "Run with --disable-sliding-window to use prefix caching.") + + if (self.enable_prefix_caching and self.prefix_caching_hash_algo + not in get_args(PrefixCachingHashAlgo)): + raise ValueError( + "Unknown prefix caching hash algorithm: " + f"{self.prefix_caching_hash_algo}. Must be one of " + f"{get_args(PrefixCachingHashAlgo)}.") + + def verify_with_parallel_config( + self, + parallel_config: "ParallelConfig", + ) -> None: + total_cpu_memory = get_cpu_memory() + # FIXME(woosuk): Here, it is assumed that the GPUs in a tensor parallel + # group are in the same node. However, the GPUs may span multiple nodes. 
+ num_gpus_per_node = parallel_config.tensor_parallel_size + cpu_memory_usage = self.swap_space_bytes * num_gpus_per_node + + msg = (f"{cpu_memory_usage / GiB_bytes:.2f} GiB out of the " + f"{total_cpu_memory / GiB_bytes:.2f} GiB total CPU memory " + "is allocated for the swap space.") + if cpu_memory_usage > 0.7 * total_cpu_memory: + raise ValueError("Too large swap space. " + msg) + elif cpu_memory_usage > 0.4 * total_cpu_memory: + logger.warning("Possibly too large swap space. %s", msg) diff --git a/vllm_mindspore/engine/arg_utils.py b/vllm_mindspore/engine/arg_utils.py index a7bfc2204126628a8d66e620b3e6ce195db50b37..0f0584918f57ccff26d678ede4dac7a282949b81 100644 --- a/vllm_mindspore/engine/arg_utils.py +++ b/vllm_mindspore/engine/arg_utils.py @@ -92,21 +92,8 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: return False # No Fp8 KV cache so far. - if self.kv_cache_dtype != "auto": - fp8_attention = self.kv_cache_dtype.startswith("fp8") - will_use_fa = (current_platform.is_cuda() - and not envs.is_set("VLLM_ATTENTION_BACKEND") - ) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1" - supported = False - if current_platform.is_rocm(): - supported = True - elif fp8_attention and will_use_fa: - from vllm.attention.utils.fa_utils import flash_attn_supports_fp8 - supported = flash_attn_supports_fp8() - if not supported: - _raise_or_fallback(feature_name="--kv-cache-dtype", - recommend_to_remove=False) - return False + # vllm-mindspore accepts non-"auto" kv-cache dtypes (e.g. int8 for fa3_quant), + # so the upstream fp8/FlashAttention capability check is dropped here. # No Prompt Adapter so far.
if self.enable_prompt_adapter: diff --git a/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py b/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py index eec50dd8982a283f5ec2213512e94305d4ebdf11..fede35363f4591db12ccf59beb8af61f2b0c5ddf 100644 --- a/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py +++ b/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py @@ -22,7 +22,7 @@ from collections.abc import Iterable import mindspore as ms import numpy as np from mindformers.trainer.utils import transform_and_load_checkpoint -from mindspore import Model, Tensor, mutable +from mindspore import Model, Tensor, mutable, ops from mindspore.common import dtype as msdtype from mindspore.common.api import _pynative_executor from mindspore.nn.utils import no_init_parameters @@ -30,13 +30,14 @@ from mindspore_gs.common import BackendTarget from mindspore_gs.ptq import (PTQ, GPTQQuantConfig, OutliersSuppressionType, PrecisionRecovery, PTQConfig, PTQMode, QuantGranularity) - +from mindspore import Tensor, mint, mutable, ops # isort: off from research.deepseek3.deepseek3 import (DeepseekV3ForCausalLM as DeepseekV3ForCausalLM_MF) from research.deepseek3.deepseek3_config import (DeepseekV3Config as DeepseekV3Config_MF) # isort: on +import vllm.envs as envs from research.deepseek3.deepseek3_model_infer import DeepseekV3DecodeLayer from vllm.config import VllmConfig, get_current_vllm_config from vllm.distributed.parallel_state import ( @@ -123,6 +124,47 @@ def _get_padding_index(q_seq_len): ms.from_numpy(ffn_padding_idx), ms.from_numpy(ffn_unpadding_idx) +class DeepseekV3MLAAttentionWrapper(MLAAttentionWrapper): + + def __init__(self, kv_cache_dtype): + super().__init__() + vllm_config = get_current_vllm_config() + self.use_mla_op = bool( + vllm_config.additional_config + and vllm_config.additional_config.get('use_mla_op') == 1) + self.dtype = kv_cache_dtype + self.fa3_quant = bool(vllm_config.additional_config and vllm_config.additional_config.get('fa3_quant') == 1) + if self.use_mla_op: + kv_lora_rank = getattr(vllm_config.model_config.hf_text_config, + 'kv_lora_rank', 0) + qk_rope_head_dim = getattr(vllm_config.model_config.hf_text_config, + 'qk_rope_head_dim', 0) + # k_shape, r_shape used for mla_op + if self.fa3_quant: + k_shape = [*(self.kv_shape[0:-2]), kv_lora_rank] + r_shape = [*(self.kv_shape[0:-2]), qk_rope_head_dim] + self.kv_cache = [ + (ops.auto_generate.format_cast( + ms.mint.zeros(k_shape, dtype=kv_cache_dtype), + 29), + ops.auto_generate.format_cast( + ms.mint.zeros(r_shape, dtype=vllm_config.model_config.dtype), + 29)) + for _ in range( + vllm_config.parallel_config.pipeline_parallel_size) + ] + + else: + k_shape = [*(self.kv_shape[0:-1]), kv_lora_rank] + r_shape = [*(self.kv_shape[0:-1]), qk_rope_head_dim] + self.kv_cache = [ + (ms.mint.zeros(k_shape, dtype=kv_cache_dtype), + ms.mint.zeros(r_shape, dtype=vllm_config.model_config.dtype)) + for _ in range( + vllm_config.parallel_config.pipeline_parallel_size) + ] + + class DeepseekV3ForCausalLM(MfModelBase): def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: @@ -135,10 +177,10 @@ class DeepseekV3ForCausalLM(MfModelBase): self.sampler = get_sampler() self.set_modules({"model": self.network}) - self.kv_caches = [ - MLAAttentionWrapper() - for i in range(self.mf_model_config.num_layers) - ] + self.kv_caches = [DeepseekV3MLAAttentionWrapper(ms.int8) if self.fa3_quant and str(i) not in \ + self.fa3_no_quant_layers else \ + DeepseekV3MLAAttentionWrapper(vllm_config.model_config.dtype) \ 
+ for i in range(self.mf_model_config.num_layers)] compilation_config = get_current_vllm_config().compilation_config if prefix in compilation_config.static_forward_context: @@ -158,6 +200,12 @@ class DeepseekV3ForCausalLM(MfModelBase): self.mf_model_config = DeepseekV3Config_MF( **self.mf_config.model.model_config) + self.mf_model_config.use_mla_op = self.use_mla_op + if self.fa3_quant: + self.mf_model_config.quantization_config.kvcache_quant_dtype = "int8" + self.mf_model_config.quantization_config.modules_to_not_convert = self.fa3_no_quant_layers + if self.use_mla_op: + assert envs.VLLM_USE_V1 if self.mf_config.moe_config: self.mf_model_config.moe_config = self.mf_config.moe_config # dispatch/combine in moe need max_num_seqs as global_max_bs @@ -187,11 +235,18 @@ class DeepseekV3ForCausalLM(MfModelBase): def get_kvcache(self): key_cache = [] forward_context = get_forward_context() - for i in range(self.mf_model_config.num_layers): - k_cache = self.kv_caches[i].kv_cache[ - forward_context.virtual_engine][0] - key_cache.append(k_cache) - return mutable(key_cache), None + key_cache = [ + self.kv_caches[i].kv_cache[forward_context.virtual_engine][0] + for i in range(self.mf_model_config.num_layers) + ] + if not self.use_mla_op: + return mutable(key_cache), None + else: + value_cache = [ + self.kv_caches[i].kv_cache[forward_context.virtual_engine][1] + for i in range(self.mf_model_config.num_layers) + ] + return mutable(key_cache), mutable(value_cache) # DLLM def connector_send_kvcache(self): diff --git a/vllm_mindspore/model_executor/models/mf_models/deepseekv3_weight_processor.py b/vllm_mindspore/model_executor/models/mf_models/deepseekv3_weight_processor.py index 899365eb6ffe8fc64692fa61409b97ef4307f82c..f75d0d78cd8bac8680ca5da1a50abfb8b0d94e99 100644 --- a/vllm_mindspore/model_executor/models/mf_models/deepseekv3_weight_processor.py +++ b/vllm_mindspore/model_executor/models/mf_models/deepseekv3_weight_processor.py @@ -1024,8 +1024,45 @@ class DeepseekV3WeightProcessor(BaseWeightProcessor): qkv2l_quant_scale_name = f"{base_path}.quant_op.input_scale" qkv2l_rmsnorm_beta_name = f"{base_path}.quant_op.beta" - if hasattr(self.config.model.model_config, "use_mla_pre" - ) and self.config.model.model_config.use_mla_pre: + qnope_scale_ms_name = f"model.layers.{layer_id}.attention.qnope_scale" + ctkv_scale_ms_name = f"model.layers.{layer_id}.attention.ctkv_scale" + qk_descale_ms_name = f"model.layers.{layer_id}.attention.qk_descale" + pv_descale_ms_name = f"model.layers.{layer_id}.attention.pv_descale" + qnope_scale_hf_name = f"model.layers.{layer_id}.self_attn.fa_q.scale" + ctkv_scale_hf_name = f"model.layers.{layer_id}.self_attn.fa_k.scale" + if hasattr(self.config.model.model_config, "use_mla_op" + ) and self.config.model.model_config.use_mla_op: + quant_config = self.config.model.model_config.quantization_config + if hasattr(quant_config, "kvcache_quant_dtype") and quant_config.kvcache_quant_dtype == "int8" \ + and str(layer_id) not in quant_config.modules_to_not_convert: + print(f"layer_id {layer_id} fa3_quant load weight.") + qnope_scale, _ = self.get_safetensor_from_file_split_tp_group( + qnope_scale_hf_name, src_hf_dir, hf_weight_map, split_axis=0) + qnope_scale = qnope_scale.squeeze(-1) + parameter_dict[qnope_scale_ms_name] = ms.Parameter( + ms.Tensor(1 / qnope_scale, ms.bfloat16), + name=qnope_scale_ms_name, + requires_grad=False) + + ctkv_scale, _ = self.get_safetensor_from_file( + ctkv_scale_hf_name, src_hf_dir, hf_weight_map) + ctkv_scale = ctkv_scale.squeeze(-1) + 
parameter_dict[ctkv_scale_ms_name] = ms.Parameter( + ms.Tensor(ctkv_scale, ms.bfloat16), + name=ctkv_scale_ms_name, + requires_grad=False) + + ctkv_scale = np.repeat(ctkv_scale, qnope_scale.shape[0]) + qk_descale = ctkv_scale.astype(np.float32) * qnope_scale.astype(np.float32) + parameter_dict[qk_descale_ms_name] = ms.Parameter( + ms.Tensor(qk_descale, ms.float32), + name=qk_descale_ms_name, + requires_grad=False) + + parameter_dict[pv_descale_ms_name] = ms.Parameter( + ms.Tensor(ctkv_scale, ms.float32), + name=pv_descale_ms_name, + requires_grad=False) qkv2l_weight = np.concatenate((kv2l_ms_param, q2l_ms_param), 0) qkv2l_bias = np.concatenate( (kv2l_quant_bias_ms_param, q2l_quant_bias_ms_param), 0) diff --git a/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py b/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py index f82d2d9f5192485666e1c41879e4de54f690e05b..f0c1c8f59244240900c3160e077f96ca67b8c1d4 100644 --- a/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py +++ b/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py @@ -39,7 +39,7 @@ from vllm.sequence import IntermediateTensors from vllm_mindspore.model_executor.models.attention_mask import ( LowerTriangularMask) from vllm_mindspore.model_executor.models.model_base import MsModelBase - +from vllm_mindspore.utils import get_fa3_no_quant_layers try: # Need to apply dllm pd patch on vllm to use pd disagg related functions from vllm.attention.layer import (maybe_save_kv_layer_to_connector, @@ -74,6 +74,20 @@ class MfModelBase(MsModelBase): self.mf_config.model.model_config.parallel_config.model_parallel = ( get_tensor_model_parallel_world_size()) self.mf_config.model.model_config.parallel_config.pipeline_stage = 1 + self.use_mla_op = \ + bool(vllm_config.additional_config + and vllm_config.additional_config.get('use_mla_op') == 1) + self.fa3_quant = \ + bool(vllm_config.additional_config and vllm_config.additional_config.get('fa3_quant') == 1) + if self.fa3_quant and vllm_config.cache_config.cache_dtype == "auto": + raise RuntimeError("To use fa3_quant, you need to set \"--kv-cache-dtype 'int8'\" in the startup command.") + if self.fa3_quant and not self.use_mla_op: + raise RuntimeError("To use fa3_quant, you need to set \"\'use_mla_op\': 1\" in the additional-config of startup command.") + if self.fa3_quant: + self.fa3_no_quant_layers = get_fa3_no_quant_layers() + else: + self.fa3_no_quant_layers = [] + self._generate_model_config() if not hasattr(self, 'mf_model_config'): raise RuntimeError('mf_model_config not initialized') diff --git a/vllm_mindspore/utils.py b/vllm_mindspore/utils.py index 179dc93f48cc82d77a06f381a0cb349f172908ff..c7b53eedfe7b45a513505f0a85990f89ff0e5dbb 100644 --- a/vllm_mindspore/utils.py +++ b/vllm_mindspore/utils.py @@ -54,8 +54,13 @@ STR_DTYPE_TO_MS_DTYPE = { "fp8": ms.uint8, "fp8_e4m3": ms.uint8, "fp8_e5m2": ms.uint8, + "int8": ms.int8, } +def get_fa3_no_quant_layers(): + return ["0"] + #["0", "1", "2", "46", "47", "50", "54", "55", "56", \ + # "57", "58", "59", "60"] def get_valid_dtype(dtype): if isinstance(dtype, str): @@ -67,7 +72,7 @@ def get_dtype_size(dtype: torch.dtype) -> int: """Get the size of the data type in bytes.""" if isinstance(dtype, str): dtype = STR_DTYPE_TO_TENSOR_DTYPE[dtype] - return torch.tensor([], dtype=dtype).element_size() + return torch.tensor([1], dtype=dtype).itemsize def _create_empty_tensor(ms_type): @@ -138,6 +143,7 @@ STR_DTYPE_TO_TENSOR_DTYPE = { "fp8": torch.uint8, "fp8_e4m3": torch.uint8, "fp8_e5m2": torch.uint8, + "int8": 
torch.int8, } STR_DTYPE_TO_MS_DTYPE = { @@ -148,8 +154,16 @@ STR_DTYPE_TO_MS_DTYPE = { "fp8": mstype.uint8, "fp8_e4m3": mstype.uint8, "fp8_e5m2": mstype.uint8, + "int8": ms.int8, } +def get_kv_cache_dtype(vllm_config): + """Resolve the KV cache dtype from the vllm config.""" + kv_cache_dtype = vllm_config.model_config.dtype if vllm_config.cache_config.cache_dtype == "auto" \ + else vllm_config.cache_config.cache_dtype + if kv_cache_dtype in STR_DTYPE_TO_MS_DTYPE: + kv_cache_dtype = STR_DTYPE_TO_MS_DTYPE[kv_cache_dtype] + return kv_cache_dtype class vllmModelBackendEnum(str, Enum): """Define the variable Enum of vLLM_MODEL_BACKEND""" diff --git a/vllm_mindspore/v1/core/kv_cache_utils.py b/vllm_mindspore/v1/core/kv_cache_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3f02590723ee7a3e6a54af47dc7afd833aa15c92 --- /dev/null +++ b/vllm_mindspore/v1/core/kv_cache_utils.py @@ -0,0 +1,113 @@ +from vllm.config import VllmConfig +from vllm.utils import cdiv +from vllm.v1.kv_cache_interface import (KVCacheConfig, KVCacheSpec, + KVCacheTensor) +from vllm.v1.core.kv_cache_utils import (check_enough_kv_cache_memory, + create_kv_cache_group_specs, + unify_hybrid_kv_cache_specs, logger, + is_kv_cache_type_uniform, + is_kv_cache_page_size_uniform, + _get_kv_cache_config_uniform_type, + _get_kv_cache_config_uniform_page_size) +def get_max_concurrency_for_kv_cache_config_diff_page_size( + vllm_config: VllmConfig, kv_cache_config: KVCacheConfig) -> float: + """ + Get the maximum concurrency for the given KV cache configuration. + """ + block_size = kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size + num_block_per_request = cdiv(vllm_config.model_config.max_model_len, + block_size) + max_concurrency = kv_cache_config.num_blocks / num_block_per_request + return max_concurrency + + +def _get_kv_cache_config_not_uniform(vllm_config: VllmConfig, + kv_cache_spec: dict[str, KVCacheSpec], + available_memory: int) -> KVCacheConfig: + """ + Generates the KV cache configuration for a model whose layers may have + different page sizes: each layer gets its own tensor, and all layers share the same number of blocks. + + Args: + vllm_config: The global VllmConfig + kv_cache_spec: The kv cache spec of each attention layer in the model + available_memory: Memory available for KV cache in bytes.
+ + Returns: + The generated KVCacheConfig + """ + + # different layers may have different page_size_bytes + page_sizes = sum([layer.page_size_bytes for layer in kv_cache_spec.values()]) + num_blocks = int(available_memory // page_sizes) + # create different groups based on page_size_bytes + page_size_layer_map = {} + kv_cache_tensors = [] + + for layer_name, layer in kv_cache_spec.items(): + kv_cache_tensors.append( + KVCacheTensor( + size=num_blocks * layer.page_size_bytes, + shared_by=[layer_name] + ) + ) + page_size_layer_map.setdefault(layer.page_size_bytes, []).append(layer_name) + + grouped_layer_names = list(page_size_layer_map.values()) + kv_cache_config = KVCacheConfig( + num_blocks=num_blocks, + kv_cache_tensors=kv_cache_tensors, + kv_cache_groups=create_kv_cache_group_specs(kv_cache_spec, + grouped_layer_names), + ) + + num_tokens = num_blocks * vllm_config.cache_config.block_size + num_tokens_str = f"{num_tokens:,}" + logger.info("GPU KV cache size: %s tokens", num_tokens_str) + max_model_len_str = f"{vllm_config.model_config.max_model_len:,}" + max_concurrency = get_max_concurrency_for_kv_cache_config_diff_page_size( + vllm_config, kv_cache_config) + logger.info("Maximum concurrency for %s tokens per request: %.2fx", + max_model_len_str, max_concurrency) + return kv_cache_config + + +def get_kv_cache_config( + vllm_config: VllmConfig, + kv_cache_spec: dict[str, KVCacheSpec], + available_memory: int, +) -> KVCacheConfig: + """ + Generates the KV cache configuration for a model. + + Args: + vllm_config: The global VllmConfig + kv_cache_spec: The kv cache spec of each attention layer in the model + available_memory: Memory available for KV cache in bytes. + + Returns: + The generated KVCacheConfigs + """ + check_enough_kv_cache_memory(vllm_config, kv_cache_spec, available_memory) + + if vllm_config.scheduler_config.disable_hybrid_kv_cache_manager: + unify_hybrid_kv_cache_specs(kv_cache_spec) + + if is_kv_cache_type_uniform(kv_cache_spec): + # KV cache of all layers are the same, which is true for + # most models. Allocate the same amount of memory for + # each layer. + return _get_kv_cache_config_uniform_type(vllm_config, kv_cache_spec, + available_memory) + elif is_kv_cache_page_size_uniform(kv_cache_spec): + # Model contains multiple attention types, but KV cache of all layers + # have the same physical memory per block per layer. Split the layers + # into groups with the same number of layers, and thus same total page + # size. 
+ return _get_kv_cache_config_uniform_page_size(vllm_config, + kv_cache_spec, + available_memory) + + return _get_kv_cache_config_not_uniform(vllm_config, kv_cache_spec, + available_memory) diff --git a/vllm_mindspore/v1/core/single_type_kv_cache_manager.py b/vllm_mindspore/v1/core/single_type_kv_cache_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..ee7de23579cf4c82a2bbba26afa0ed4951d7cfb3 --- /dev/null +++ b/vllm_mindspore/v1/core/single_type_kv_cache_manager.py @@ -0,0 +1,52 @@ +from abc import ABC, abstractmethod +from collections import defaultdict +from typing import Callable + +from vllm.utils import cdiv +from vllm.v1.core.block_pool import BlockPool +from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock +from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheSpec, + SlidingWindowSpec) +from vllm.v1.request import Request +from vllm.v1.core.single_type_kv_cache_manager import (SingleTypeKVCacheManager, + FullAttentionManager, + SlidingWindowManager) +from vllm_mindspore.v1.kv_cache_interface import MLAQuantFullAttentionSpec + +@classmethod +def find_longest_cache_hit( + cls, + block_hashes: list[BlockHash], + max_length: int, + kv_cache_group_ids: list[int], + block_pool: BlockPool, + kv_cache_spec: KVCacheSpec, + use_eagle: bool, +) -> tuple[list[KVCacheBlock], ...]: + assert isinstance(kv_cache_spec, (FullAttentionSpec, MLAQuantFullAttentionSpec)), ( + "FullAttentionManager can only be used for full attention or mla quant full attention groups") + computed_blocks: tuple[list[KVCacheBlock], ...] = tuple( + [] for _ in range(len(kv_cache_group_ids))) + max_num_blocks = max_length // kv_cache_spec.block_size + for i, block_hash in zip(range(max_num_blocks), block_hashes): + # block_hashes is a chain of block hashes. If a block hash is not + # in the cached_block_hash_to_id, the following block hashes are + # not computed yet for sure. 
+ if cached_block := block_pool.get_cached_block( + block_hash, kv_cache_group_ids): + for computed, cached in zip(computed_blocks, cached_block): + computed.append(cached) + else: + break + if use_eagle and computed_blocks[0]: + for computed in computed_blocks: + computed.pop() + return computed_blocks + + + +spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = { + FullAttentionSpec: FullAttentionManager, + MLAQuantFullAttentionSpec: FullAttentionManager, + SlidingWindowSpec: SlidingWindowManager, +} \ No newline at end of file diff --git a/vllm_mindspore/v1/kv_cache_interface.py b/vllm_mindspore/v1/kv_cache_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..99a8bf7cf1d8045f6c52654ca3d71941586e7ac9 --- /dev/null +++ b/vllm_mindspore/v1/kv_cache_interface.py @@ -0,0 +1,22 @@ +from vllm.config import VllmConfig +from vllm.utils import cdiv +from vllm.v1.kv_cache_interface import AttentionSpec + +class MLAQuantFullAttentionSpec(AttentionSpec): + + @property + def type_id(self) -> str: + return f"mla_quant_full_attention_{self.block_size}_{self.page_size_bytes}" + + def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: + max_model_len = vllm_config.model_config.max_model_len + return cdiv(max_model_len, self.block_size) * self.page_size_bytes + + @property + def page_size_bytes(self) -> int: + # For MLA we only store a single latent vector + coef = 1 if self.use_mla else 2 + assert self.head_size == 576 + ctkv_nope_dim = 512 + qk_rope_dim = 64 + return coef * self.block_size * self.num_kv_heads * (ctkv_nope_dim + qk_rope_dim * 2) diff --git a/vllm_mindspore/v1/worker/gpu_model_runner.py b/vllm_mindspore/v1/worker/gpu_model_runner.py index c11e37fa06fd02d90067c9e8ef436c14ab28bf37..5483077d9ee3fb9e922121b6cbf37afcf0248669 100644 --- a/vllm_mindspore/v1/worker/gpu_model_runner.py +++ b/vllm_mindspore/v1/worker/gpu_model_runner.py @@ -22,9 +22,17 @@ from typing import Any, Optional import mindspore as ms import numpy as np +from mindspore import Tensor, mint, mutable, ops from mindspore import Generator as msGenerator -from mindspore import Tensor, mint, mutable +from vllm_mindspore.v1.attention.backends.ms_attn import MsAttentionMetadata +from vllm_mindspore.utils import get_valid_dtype, get_dtype_size, get_fa3_no_quant_layers +from vllm_mindspore.model_executor.layers.rotary_embedding import InferMRotaryEmbedding as MRotaryEmbedding # type: ignore[attr-defined] + +from vllm.v1.outputs import ModelRunnerOutput from vllm.attention import AttentionType +from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheSpec, SlidingWindowSpec +from vllm_mindspore.v1.kv_cache_interface import MLAQuantFullAttentionSpec +from vllm.v1.utils import bind_kv_cache from vllm.logger import init_logger from vllm.sampling_params import SamplingType from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheSpec, @@ -223,12 +231,38 @@ def _allocate_kv_cache_tensors(self, kv_cache_config): use_mla = kv_cache_spec.use_mla dtype = kv_cache_spec.dtype coef = 1 if use_mla else 2 + use_mla_op = bool(self.vllm_config.additional_config and self.vllm_config.additional_config.get('use_mla_op') == 1) + fa3_quant = bool(self.vllm_config.additional_config and self.vllm_config.additional_config.get('fa3_quant') == 1) + fa3_no_quant_layers = get_fa3_no_quant_layers() + kv_lora_rank = getattr(self.vllm_config.model_config.hf_text_config, 'kv_lora_rank', 0) + qk_rope_head_dim = getattr(self.vllm_config.model_config.hf_text_config, 'qk_rope_head_dim', 0) + 
def get_dtype_from_groups(layer_name, kv_cache_groups): + for group in kv_cache_groups: + if layer_name in group.layer_names: + kv_cache_spec = group.kv_cache_spec + use_mla = kv_cache_spec.use_mla + dtype = kv_cache_spec.dtype + coef = 1 if use_mla else 2 + return dtype, coef + return None, None kv_cache_raw_tensors: dict[str, Tensor] = {} target_dtype = get_valid_dtype(dtype) dtype_size = get_dtype_size(target_dtype) for kv_cache_tensor in kv_cache_config.kv_cache_tensors: + assert len(kv_cache_tensor.shared_by) == 1 raw_tensors = [] + layer_name = kv_cache_tensor.shared_by[0] + if fa3_quant: + target_dtype, coef = get_dtype_from_groups(layer_name, kv_cache_config.kv_cache_groups) + dtype_size = get_dtype_size(target_dtype) + is_fa3_quant_layer = use_mla_op and fa3_quant and layer_name not in fa3_no_quant_layers raw_tensor_shape = kv_cache_tensor.size // dtype_size // coef + if is_fa3_quant_layer: + # The int8 ctkv part and bf16 rope part are carved out of the raw + # byte size below, so keep kv_cache_tensor.size unscaled here. + raw_tensor_shape = kv_cache_tensor.size for i in range(coef): # Formulas for calculating each parameter: # 1. page_size = coef * self.block_size * self.num_kv_heads * self.head_size * # get_dtype_size(self.dtype)) # 4. kv cache shape: num_blocks, block_size, num_kv_heads, head_size - raw_tensor_split = mint.zeros( - raw_tensor_shape, - dtype=target_dtype, - ) - raw_tensors.append(raw_tensor_split) + if not use_mla_op: + raw_tensors.extend( + [mint.zeros(raw_tensor_shape, dtype=target_dtype)] + ) + elif is_fa3_quant_layer: + raw_tensors.extend( + # rope_cache dtype is bfloat16, and the target_dtype is int8, so need to add twice qk_rope_head_dim + [mint.zeros(int(raw_tensor_shape * kv_lora_rank / (kv_lora_rank + qk_rope_head_dim * 2)), + dtype=target_dtype), + mint.zeros(int(raw_tensor_shape * qk_rope_head_dim / (kv_lora_rank + qk_rope_head_dim * 2)), + dtype=self.vllm_config.model_config.dtype)] + ) + else: + raw_tensors.extend( + [mint.zeros(int(raw_tensor_shape * kv_lora_rank / (kv_lora_rank + qk_rope_head_dim)), + dtype=target_dtype), + mint.zeros(int(raw_tensor_shape * qk_rope_head_dim / (kv_lora_rank + qk_rope_head_dim)), + dtype=target_dtype)] + ) + for layer_name in kv_cache_tensor.shared_by: kv_cache_raw_tensors[layer_name] = tuple(raw_tensors) @@ -271,6 +320,11 @@ def _reshape_kv_cache_tensors( Dict[str, Tensor]: A map between layer names to their corresponding memory buffer for KV cache.
""" + use_mla_op = bool(self.vllm_config.additional_config and self.vllm_config.additional_config.get('use_mla_op') == 1) + fa3_quant = bool(self.vllm_config.additional_config and self.vllm_config.additional_config.get('fa3_quant') == 1) + kv_lora_rank = getattr(self.vllm_config.model_config.hf_text_config, 'kv_lora_rank', 0) + qk_rope_head_dim = getattr(self.vllm_config.model_config.hf_text_config, 'qk_rope_head_dim', 0) + fa3_no_quant_layers = get_fa3_no_quant_layers() kv_caches: dict[str, tuple] = {} for i, kv_cache_group_spec in enumerate(kv_cache_config.kv_cache_groups): kv_cache_spec = kv_cache_group_spec.kv_cache_spec @@ -279,9 +333,14 @@ def _reshape_kv_cache_tensors( raw_tensor = kv_cache_raw_tensors[layer_name] target_dtype = get_valid_dtype(kv_cache_spec.dtype) dtype_size = get_dtype_size(target_dtype) - num_blocks = (raw_tensor[0].numel() * coef * dtype_size // - kv_cache_spec.page_size_bytes) - if isinstance(kv_cache_spec, FullAttentionSpec): + is_fa3_quant_layer = use_mla_op and fa3_quant and layer_name not in fa3_no_quant_layers + if is_fa3_quant_layer: + num_blocks = (raw_tensor[0].numel() + raw_tensor[1].numel() * 2) * coef // kv_cache_spec.page_size_bytes + else: + num_blocks = \ + (raw_tensor[0].numel() if not use_mla_op else (raw_tensor[0].numel() + raw_tensor[1].numel())) * \ + coef * dtype_size // kv_cache_spec.page_size_bytes + if isinstance(kv_cache_spec, (FullAttentionSpec, MLAQuantFullAttentionSpec)): kv_cache_shape = self.attn_backends[i].get_kv_cache_shape( num_blocks, kv_cache_spec.block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) @@ -304,10 +363,29 @@ def _reshape_kv_cache_tensors( for i in range(len(kv_cache_stride_order)) ] kv_cache_layer = [] - for kv_cache_raw_tensor in kv_cache_raw_tensors[layer_name]: - cache_block = kv_cache_raw_tensor.view( - kv_cache_shape[1:]).permute(*inv_order[1:]) - kv_cache_layer.append(cache_block) + for idx, kv_cache_raw_tensor in enumerate( + kv_cache_raw_tensors[layer_name]): + if use_mla_op: + cache_shape = [ + *(kv_cache_shape[1:-1]), + kv_lora_rank if idx == 0 else qk_rope_head_dim + ] + cache_block = kv_cache_raw_tensor.view( + cache_shape).permute(*inv_order[1:]) + else: + cache_block = kv_cache_raw_tensor.view(kv_cache_shape[1:]).permute(*inv_order[1:]) + if fa3_quant: + num_blocks, block_size, _, _ = cache_block.shape + cache_block = ops.reshape(cache_block, (num_blocks, block_size, -1)) + from mindspore.common.api import _pynative_executor + cache_block_nz = ops.auto_generate.format_cast(cache_block, 29) + _pynative_executor.sync() + import gc + del cache_block + gc.collect() + kv_cache_layer.append(cache_block_nz) + else: + kv_cache_layer.append(cache_block) kv_caches[layer_name] = mutable(tuple(kv_cache_layer)) else: raise NotImplementedError @@ -524,6 +602,9 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: forward_ctx = self.vllm_config.compilation_config.static_forward_context block_size = self.vllm_config.cache_config.block_size use_mla = self.vllm_config.model_config.use_mla + use_mla_op = bool(self.vllm_config.additional_config and self.vllm_config.additional_config.get('use_mla_op') == 1) + fa3_quant = bool(self.vllm_config.additional_config and self.vllm_config.additional_config.get('fa3_quant') == 1) + fa3_no_quant_layers = get_fa3_no_quant_layers() kv_cache_spec: dict[str, KVCacheSpec] = {} for layer_name, attn_module in forward_ctx.items(): # vllm-mindspore AttentionWrapper is not an Attention isinstance @@ -538,12 +619,24 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: 
sliding_window=attn_module.sliding_window, use_mla=use_mla) else: - kv_cache_spec[layer_name] = FullAttentionSpec( + kv_cache_dtype = self.kv_cache_dtype + fa3_quant_layer = layer_name not in fa3_no_quant_layers + if fa3_quant and not fa3_quant_layer: + kv_cache_dtype = self.vllm_config.model_config.dtype + if use_mla_op and fa3_quant and fa3_quant_layer: + kv_cache_spec[layer_name] = MLAQuantFullAttentionSpec( block_size=block_size, num_kv_heads=attn_module.num_kv_heads, head_size=attn_module.head_size, - dtype=self.kv_cache_dtype, + dtype=kv_cache_dtype, use_mla=use_mla) + else: + kv_cache_spec[layer_name] = FullAttentionSpec( + block_size=block_size, + num_kv_heads=attn_module.num_kv_heads, + head_size=attn_module.head_size, + dtype=kv_cache_dtype, + use_mla=use_mla) elif attn_module.attn_type in (AttentionType.ENCODER, AttentionType.ENCODER_ONLY): # encoder-only attention does not need KV cache.
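The constraints added in mf_model_base.py tie the new options together: fa3_quant requires use_mla_op and a non-"auto" KV cache dtype (the error message asks for --kv-cache-dtype 'int8'), while use_mla_op on its own only asserts the V1 engine. A minimal sketch of an engine launch consistent with those checks follows; it assumes a vLLM build that exposes kv_cache_dtype and additional_config on its engine arguments, and the checkpoint path is a placeholder:

    from vllm import LLM

    # Hypothetical DeepSeek-V3 checkpoint; fa3_quant additionally expects the
    # fa_q/fa_k KV-cache scales that deepseekv3_weight_processor.py loads.
    llm = LLM(
        model="/path/to/deepseek-v3",
        kv_cache_dtype="int8",  # accepted once CacheDType includes "int8"
        additional_config={"use_mla_op": 1, "fa3_quant": 1},
    )

    outputs = llm.generate("Give me a short introduction to MLA.")

For sizing, MLAQuantFullAttentionSpec.page_size_bytes evaluates to block_size * num_kv_heads * (512 + 64 * 2) bytes: the 512-dim ctkv latent is cached in int8 (1 byte per element) while the 64-dim rope part stays in bf16 (2 bytes per element), which is why the rope term is doubled. With block_size=16 and a single latent KV head that is 16 * 640 = 10240 bytes per block.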