From 188b3e1cc164ffdcf625632343a025a165c75cfe Mon Sep 17 00:00:00 2001
From: superxf
Date: Tue, 5 Aug 2025 18:33:50 +0800
Subject: [PATCH] support 310p qwen3 mcore

---
 vllm_mindspore/__init__.py                    |   6 +-
 vllm_mindspore/config.py                      |   5 +
 .../model_executor/models/mf_models/config.py |  20 ++-
 .../models/mf_models/mindformers.py           |  14 +-
 .../model_executor/models/model_base.py       |  12 +-
 vllm_mindspore/utils.py                       |  20 +++
 vllm_mindspore/v1/worker/gpu_model_runner.py  | 138 +++++++++++++++---
 7 files changed, 184 insertions(+), 31 deletions(-)

diff --git a/vllm_mindspore/__init__.py b/vllm_mindspore/__init__.py
index 2b711153..1f96192a 100644
--- a/vllm_mindspore/__init__.py
+++ b/vllm_mindspore/__init__.py
@@ -376,14 +376,14 @@
 from vllm_mindspore.v1.worker.gpu_model_runner import _update_states
 vllm.v1.worker.gpu_model_runner.GPUModelRunner._update_states = _update_states
 
 from vllm_mindspore.v1.worker.gpu_model_runner import (
-    _allocate_kv_cache_tensors,
-    get_kv_cache_spec,
-)
+    _allocate_kv_cache_tensors, get_kv_cache_spec, initialize_kv_cache_tensors)
 vllm.v1.worker.gpu_model_runner.GPUModelRunner._allocate_kv_cache_tensors = (
     _allocate_kv_cache_tensors)
 vllm.v1.worker.gpu_model_runner.GPUModelRunner.get_kv_cache_spec = (
     get_kv_cache_spec)
+vllm.v1.worker.gpu_model_runner.GPUModelRunner.initialize_kv_cache_tensors = (
+    initialize_kv_cache_tensors)
 
 from vllm_mindspore.v1.worker.gpu_model_runner import _reshape_kv_cache_tensors
 vllm.v1.worker.gpu_model_runner.GPUModelRunner._reshape_kv_cache_tensors = (
diff --git a/vllm_mindspore/config.py b/vllm_mindspore/config.py
index c464227b..8e572249 100644
--- a/vllm_mindspore/config.py
+++ b/vllm_mindspore/config.py
@@ -35,6 +35,8 @@ from vllm.config import (_STR_DTYPE_TO_TORCH_DTYPE, CompilationConfig,
 from vllm.logger import init_logger
 from vllm.utils import random_uuid
 
+from vllm_mindspore.utils import is_310p
+
 logger = init_logger(__name__)
 
 
@@ -240,6 +242,9 @@ def _get_and_verify_dtype(
     else:
         raise ValueError(f"Unknown dtype: {dtype}")
 
+    if torch_dtype == torch.bfloat16 and is_310p():
+        torch_dtype = torch.float16
+
     if torch_dtype != config_dtype:
         if torch_dtype == torch.float32:
             # Upcasting to float32 is allowed.
diff --git a/vllm_mindspore/model_executor/models/mf_models/config.py b/vllm_mindspore/model_executor/models/mf_models/config.py
index a658d0b3..d1fc68b9 100644
--- a/vllm_mindspore/model_executor/models/mf_models/config.py
+++ b/vllm_mindspore/model_executor/models/mf_models/config.py
@@ -19,6 +19,8 @@ from mindformers.tools.register.config import MindFormerConfig
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 
+from vllm_mindspore.utils import is_310p
+
 logger = init_logger(__name__)
 
 
@@ -70,7 +72,7 @@ source_path: Specifies the path to a configuration parameter in VllmConfig,
 value_or_function: Specifies the default value for the configuration parameter
     or a partial function for computing configuration values.
""" -# yapf: disable +# yapf: disable # noqa:ERA001 # flake8: noqa: E501 MF_CTX_MAPPING = { 'run_mode': (None, "predict"), @@ -94,7 +96,15 @@ MF_MODEL_COMMON_MAPPING = { 'model.model_config.params_dtype': (None, 'bfloat16'), 'model.model_config.router_dense_type': (None, 'bfloat16'), } -# yapf: enable + +MF_MODEL_COMMON_MAPPING_310p = { + 'model.model_config.compute_dtype': ('model_config.hf_config.torch_dtype', 'float16'), + 'model.model_config.layernorm_compute_dtype': (None, 'float16'), + 'model.model_config.rotary_dtype': (None, 'float16'), + 'model.model_config.params_dtype': (None, 'float16'), + 'model.model_config.router_dense_type': (None, 'float16'), +} +# yapf: enable # noqa:ERA001 # model default config MODEL_RELATED_MAPPING = { @@ -211,7 +221,11 @@ def gen_mf_config(vllm_config: VllmConfig): target_config.set_value( 'model.model_config', MindFormerConfig(**gen_model_config_dict(vllm_config))) - transform_config(MF_MODEL_COMMON_MAPPING, vllm_config, target_config) + if is_310p(): + transform_config(MF_MODEL_COMMON_MAPPING_310p, vllm_config, + target_config) + else: + transform_config(MF_MODEL_COMMON_MAPPING, vllm_config, target_config) # Update target config with additional config. # The configuration hierarchy in the additional config must match the # hierarchy structure of the MindFormers YAML configuration file. diff --git a/vllm_mindspore/model_executor/models/mf_models/mindformers.py b/vllm_mindspore/model_executor/models/mf_models/mindformers.py index 8a5d7b94..611dfe7c 100644 --- a/vllm_mindspore/model_executor/models/mf_models/mindformers.py +++ b/vllm_mindspore/model_executor/models/mf_models/mindformers.py @@ -38,6 +38,7 @@ from vllm_mindspore.model_executor.models.attention_mask import ( from vllm_mindspore.model_executor.models.mf_models.config import gen_mf_config from vllm_mindspore.model_executor.models.model_base import (AttentionWrapper, MsModelBase) +from vllm_mindspore.utils import is_310p logger = init_logger(__name__) @@ -49,6 +50,7 @@ class MindFormersForCausalLM(MsModelBase): self.set_flags = False self.max_model_len = vllm_config.model_config.max_model_len self.hf_config = vllm_config.model_config.hf_config + self.lm_head_graph = None mf_config = gen_mf_config(vllm_config) mf_config.load_checkpoint = self.get_model_path() @@ -191,16 +193,22 @@ class MindFormersForCausalLM(MsModelBase): and selected_token_indices.numel() <= 0: logits = ms.mint.zeros((0, self.hf_config.vocab_size), dtype=self.hf_config.torch_dtype) + return logits else: hidden_states = hidden_states.reshape( (-1, hidden_states.shape[-1])) hidden_states = hidden_states.index_select( 0, selected_token_indices) - logits = self.lm_head(hidden_states) - logits = logits.view(-1, logits.shape[-1]) + if is_310p(): + # To get better performance in 310p, the lm head should run + # in O0 mode to avoid transdata, 910 keep the original process. 
+ if self.lm_head_graph is None: + self.lm_head_graph = ms.jit(function=self.lm_head, + jit_level="O0") + logits = self.lm_head_graph(hidden_states) else: logits = self.lm_head(hidden_states) - logits = logits.view(-1, logits.shape[-1]) + logits = logits.view(-1, logits.shape[-1]) return logits def sample( diff --git a/vllm_mindspore/model_executor/models/model_base.py b/vllm_mindspore/model_executor/models/model_base.py index 58178292..30e14f39 100644 --- a/vllm_mindspore/model_executor/models/model_base.py +++ b/vllm_mindspore/model_executor/models/model_base.py @@ -34,7 +34,7 @@ from vllm.sequence import IntermediateTensors from vllm_mindspore.model_executor.models.attention_mask import ( LowerTriangularMask) from vllm_mindspore.model_executor.utils import set_model_context -from vllm_mindspore.utils import STR_DTYPE_TO_MS_DTYPE +from vllm_mindspore.utils import STR_DTYPE_TO_MS_DTYPE, create_kv_cache from vllm_mindspore.v1.attention.backends.ms_attn import MsAttentionMetadata @@ -48,10 +48,12 @@ class AttentionWrapper: head_size = vllm_config.model_config.get_head_size() num_block = 0 self.kv_shape = [num_block, block_size, num_kv_heads, head_size] - self.kv_cache = [( - ms.mint.zeros(self.kv_shape, dtype=vllm_config.model_config.dtype), - ms.mint.zeros(self.kv_shape, dtype=vllm_config.model_config.dtype), - ) for _ in range(vllm_config.parallel_config.pipeline_parallel_size)] + self.kv_cache = [ + (create_kv_cache(self.kv_shape, vllm_config.model_config.dtype), + create_kv_cache(self.kv_shape, vllm_config.model_config.dtype)) + for _ in range(vllm_config.parallel_config.pipeline_parallel_size) + ] + self.attn_type = AttentionType.DECODER # add for v1 diff --git a/vllm_mindspore/utils.py b/vllm_mindspore/utils.py index 179dc93f..c971f944 100644 --- a/vllm_mindspore/utils.py +++ b/vllm_mindspore/utils.py @@ -56,6 +56,26 @@ STR_DTYPE_TO_MS_DTYPE = { "fp8_e5m2": ms.uint8, } +FORMAT_TYPE = { + "nz": 29, +} + + +def create_kv_cache(kv_shape, dtype): + if is_310p(): + if len(kv_shape) != 4: + raise ValueError(f"Format_cast op need kv_cache shape be" + f"(batch_size, num_heads, seq_len, head_dim), " + f"but got {len(kv_shape)} dimensions: {kv_shape}") + + batch_size, num_heads, seq_len, head_dim = kv_shape + reshaped_for_nz = (batch_size, num_heads, seq_len * head_dim) + zeros_tensor = ms.mint.zeros(reshaped_for_nz, dtype=dtype) + + return ms.ops.auto_generate.format_cast(zeros_tensor, + FORMAT_TYPE['nz']) + return ms.mint.zeros(kv_shape, dtype=dtype) + def get_valid_dtype(dtype): if isinstance(dtype, str): diff --git a/vllm_mindspore/v1/worker/gpu_model_runner.py b/vllm_mindspore/v1/worker/gpu_model_runner.py index 12e0b0c4..c9ab37ef 100644 --- a/vllm_mindspore/v1/worker/gpu_model_runner.py +++ b/vllm_mindspore/v1/worker/gpu_model_runner.py @@ -22,20 +22,24 @@ from typing import Any, Optional import mindspore as ms import numpy as np +import torch from mindspore import Generator as msGenerator from mindspore import Tensor, mint, mutable from vllm.attention import AttentionType from vllm.logger import init_logger from vllm.sampling_params import SamplingType -from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheSpec, - SlidingWindowSpec) +from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, + KVCacheSpec, SlidingWindowSpec) from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.spec_decode.metadata import SpecDecodeMetadata +from vllm.v1.utils import bind_kv_cache from vllm.v1.worker.gpu_input_batch import CachedRequestState +from vllm.v1.worker.utils 
import initialize_kv_cache_for_kv_sharing from vllm_mindspore.model_executor.layers.rotary_embedding import ( InferMRotaryEmbedding as MRotaryEmbedding) -from vllm_mindspore.utils import get_dtype_size, get_valid_dtype +from vllm_mindspore.utils import (create_kv_cache, get_dtype_size, + get_valid_dtype, is_310p) logger = init_logger(__name__) @@ -62,9 +66,10 @@ def _prepare_inputs( # Get request indices. # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] req_indices = np.repeat(self.arange_np[:num_reqs], num_scheduled_tokens) - - # cu_num_tokens: [2, 5, 3] -> [2, 7, 10] - # arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + """ + cu_num_tokens: [2, 5, 3] -> [2, 7, 10] + arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + """ cu_num_tokens, arange = self._get_cumsum_and_arange(num_scheduled_tokens) # Get positions. @@ -208,6 +213,66 @@ def create_block(shape, dtype, name=None, device=None): return blocks +def _allocate_nz_kv_cache_tensors(self, kv_cache_config): + """ + Initializes and reshape the KV cache buffer with the correct size. + The buffer needs to be convert to nz format for 310p. + + Args: + kv_cache_config: The KV cache config + Returns: + dict[str, Tensor]: A map between layer names to their + corresponding memory buffer for KV cache. + """ + kv_caches: dict[str, tuple] = {} + + layer_to_group_info = { + layer_name: (i, group.kv_cache_spec) + for i, group in enumerate(kv_cache_config.kv_cache_groups) + for layer_name in group.layer_names + } + + use_mla_op = bool( + self.vllm_config.additional_config + and self.vllm_config.additional_config.get('use_mla_op') == 1) + if use_mla_op: + logger.error("For 310p, mla kv cache not supported") + raise NotImplementedError + + for kv_cache_tensor in kv_cache_config.kv_cache_tensors: + if not kv_cache_tensor.shared_by: + continue + + rep_layer_name = kv_cache_tensor.shared_by[0] + group_idx, kv_cache_spec = layer_to_group_info[rep_layer_name] + if not isinstance(kv_cache_spec, FullAttentionSpec): + raise NotImplementedError + + attn_backend = self.attn_backends[group_idx] + target_dtype = get_valid_dtype(kv_cache_spec.dtype) + + num_blocks = kv_cache_tensor.size // kv_cache_spec.page_size_bytes + kv_cache_shape = attn_backend.get_kv_cache_shape( + num_blocks, kv_cache_spec.block_size, kv_cache_spec.num_kv_heads, + kv_cache_spec.head_size) + + reshaped_layer_tensors = [] + coef = 1 if kv_cache_spec.use_mla else 2 + for _ in range(coef): + reshaped_layer_tensors.append( + create_kv_cache(kv_cache_shape[1:], target_dtype)) + + final_kv_tuple = mutable(tuple(reshaped_layer_tensors)) + for layer_name in kv_cache_tensor.shared_by: + kv_caches[layer_name] = final_kv_tuple + + all_layers = set(layer_to_group_info.keys()) + if all_layers != set(kv_caches.keys()): + raise RuntimeError("Some layers were not initialized") + + return kv_caches + + def _allocate_kv_cache_tensors(self, kv_cache_config): """ Initializes the KV cache buffer with the correct size. The buffer needs @@ -218,7 +283,7 @@ def _allocate_kv_cache_tensors(self, kv_cache_config): Returns: dict[str, Tensor]: A map between layer names to their corresponding memory buffer for KV cache. - """ + """ kv_cache_spec = kv_cache_config.kv_cache_groups[0].kv_cache_spec use_mla = kv_cache_spec.use_mla dtype = kv_cache_spec.dtype @@ -239,14 +304,16 @@ def _allocate_kv_cache_tensors(self, kv_cache_config): raw_tensors = [] raw_tensor_shape = kv_cache_tensor.size // dtype_size // coef for i in range(coef): - # Formulas for calculating each parameter: - # 1. 
page_size = coef * self.block_size * self.num_kv_heads * - # self.head_size * get_dtype_size(self.dtype) - # 2. num_blocks = kv_cache_tensors.size / page_size - # 3. kv_cache_tensors.size = num_blocks * (coef * - # self.block_size * self.num_kv_heads * self.head_size * - # get_dtype_size(self.dtype)) - # 4. kv cache shape: num_blocks, block_size, num_kv_heads, head_size + """ + Formulas for calculating each parameter: + 1. page_size = coef * self.block_size * self.num_kv_heads * + self.head_size * get_dtype_size(self.dtype) + 2. num_blocks = kv_cache_tensors.size / page_size + 3. kv_cache_tensors.size = num_blocks * (coef * + self.block_size * self.num_kv_heads * self.head_size * + get_dtype_size(self.dtype)) + 4. kv cache shape: num_blocks, block_size, num_kv_heads, head_size + """ raw_tensors.extend( [mint.zeros(raw_tensor_shape, dtype=target_dtype)] if not use_mla_op else [ @@ -350,6 +417,41 @@ def _reshape_kv_cache_tensors( return kv_caches +def initialize_kv_cache_tensors( + self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]: + """ + Initialize the memory buffer for KV cache. + + Args: + kv_cache_config: The KV cache config + Returns: + Dict[str, torch.Tensor]: A map between layer names to their + corresponding memory buffer for KV cache. + """ + if is_310p(): + kv_caches = _allocate_nz_kv_cache_tensors(self, kv_cache_config) + else: + # Initialize the memory buffer for KV cache + kv_cache_raw_tensors = self._allocate_kv_cache_tensors(kv_cache_config) + # Change the memory buffer to the desired shape + kv_caches = self._reshape_kv_cache_tensors(kv_cache_config, + kv_cache_raw_tensors) + + # Setup `kv_cache_config` and `kv_caches` for models + # with cross-layer KV sharing + if self.shared_kv_cache_layers: + initialize_kv_cache_for_kv_sharing( + self.shared_kv_cache_layers, + kv_cache_config.kv_cache_groups, + kv_caches, + ) + + bind_kv_cache(kv_caches, + self.vllm_config.compilation_config.static_forward_context, + self.kv_caches) + return kv_caches + + def _update_states(self, scheduler_output) -> None: """Update the cached states and the persistent batch with the scheduler output. @@ -562,8 +664,10 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: use_mla = self.vllm_config.model_config.use_mla kv_cache_spec: dict[str, KVCacheSpec] = {} for layer_name, attn_module in forward_ctx.items(): - # vllm-mindspore AttentionWrapper is not an Attention isinstance - # assert isinstance(attn_module, Attention) + """ + vllm-mindspore AttentionWrapper is not an Attention isinstance + assert isinstance(attn_module, Attention) + """ if attn_module.attn_type == AttentionType.DECODER: if attn_module.sliding_window is not None: kv_cache_spec[layer_name] = SlidingWindowSpec( -- Gitee
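Note (not part of the patch): the core 310p-specific idea above is that KV-cache buffers are allocated with the last two dimensions folded together and then cast to the FRACTAL_NZ layout (ACL format id 29) via ms.ops.auto_generate.format_cast, as done in create_kv_cache() in utils.py. The minimal sketch below restates that helper in a standalone form; the function name create_kv_cache_sketch, the on_310p flag, and the example shape and dtype are hypothetical illustration values, not taken from the patch.

import mindspore as ms

FORMAT_NZ = 29  # same ACL format id the patch stores in FORMAT_TYPE["nz"]


def create_kv_cache_sketch(kv_shape, dtype, on_310p):
    """Allocate one KV-cache buffer; cast it to NZ format only on 310p."""
    if on_310p:
        if len(kv_shape) != 4:
            raise ValueError(f"expected a 4-D KV cache shape, got {kv_shape}")
        d0, d1, d2, d3 = kv_shape
        # format_cast in the patch works on a 3-D tensor, so the last two
        # dimensions are folded together before the cast.
        zeros = ms.mint.zeros((d0, d1, d2 * d3), dtype=dtype)
        return ms.ops.auto_generate.format_cast(zeros, FORMAT_NZ)
    # Non-310p path keeps the plain ND-format buffer, as before the patch.
    return ms.mint.zeros(kv_shape, dtype=dtype)


# Hypothetical usage: one (key, value) pair for a single attention layer with
# 4 blocks of 16 tokens, 8 KV heads and head size 128, in float16 (the dtype
# the patch forces on 310p in place of bfloat16).
key_cache = create_kv_cache_sketch((4, 16, 8, 128), ms.float16, on_310p=True)
value_cache = create_kv_cache_sketch((4, 16, 8, 128), ms.float16, on_310p=True)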