From 0bcd43e5da0d79747a41fc0c6ce8566c9c2b1e0f Mon Sep 17 00:00:00 2001
From: superxf
Date: Thu, 22 May 2025 16:54:07 +0800
Subject: [PATCH] support 310p nz

---
 .../model_executor/models/model_base.py      | 88 +++++++++++++------
 vllm_mindspore/utils.py                      | 36 ++++++++
 vllm_mindspore/v1/worker/gpu_model_runner.py |  9 +-
 vllm_mindspore/worker/cache_engine.py        |  9 +-
 4 files changed, 112 insertions(+), 30 deletions(-)

diff --git a/vllm_mindspore/model_executor/models/model_base.py b/vllm_mindspore/model_executor/models/model_base.py
index 0d933a2d..073b77c0 100644
--- a/vllm_mindspore/model_executor/models/model_base.py
+++ b/vllm_mindspore/model_executor/models/model_base.py
@@ -30,8 +30,8 @@
 from vllm.attention.layer import Attention
 
 import torch
-from mindspore import Tensor, nn, mutable
-
+from mindspore import Tensor, nn, mutable, ops
+from vllm_mindspore.utils import is_310p
 
 class Fake_Attention:
     def __init__(self):
@@ -42,14 +42,24 @@ class Fake_Attention:
         )
         head_size = vllm_config.model_config.get_head_size()
         num_block = 0
-        self.kv_shape = [num_block, block_size, num_kv_heads, head_size]
-        self.kv_cache = [
-            (
-                torch.zeros(self.kv_shape, dtype=torch.bfloat16, device="Ascend"),
-                torch.zeros(self.kv_shape, dtype=torch.bfloat16, device="Ascend"),
-            )
-            for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
-        ]
+        if is_310p():
+            self.kv_shape = [num_block, block_size, num_kv_heads * head_size]
+            self.kv_cache = [
+                (
+                    ops.auto_generate.format_cast(torch.zeros(self.kv_shape, dtype=torch.float16, device="Ascend"), 29),
+                    ops.auto_generate.format_cast(torch.zeros(self.kv_shape, dtype=torch.float16, device="Ascend"), 29),
+                )
+                for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
+            ]
+        else:
+            self.kv_shape = [num_block, block_size, num_kv_heads, head_size]
+            self.kv_cache = [
+                (
+                    torch.zeros(self.kv_shape, dtype=torch.bfloat16, device="Ascend"),
+                    torch.zeros(self.kv_shape, dtype=torch.bfloat16, device="Ascend"),
+                )
+                for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
+            ]
 
         self.attn_type = AttentionType.DECODER
 
@@ -57,10 +67,18 @@ class Fake_Attention:
 class Fake_MLA(Fake_Attention):
     def __init__(self):
         super().__init__()
         vllm_config = get_current_vllm_config()
-        self.kv_cache = [
-            (torch.zeros(self.kv_shape, dtype=torch.bfloat16, device="Ascend"),)
-            for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
-        ]
+        if is_310p():
+            self.kv_cache = [
+                (
+                    ops.auto_generate.format_cast((torch.zeros(self.kv_shape, dtype=torch.float16, device="Ascend")), 29),
+                )
+                for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
+            ]
+        else:
+            self.kv_cache = [
+                (torch.zeros(self.kv_shape, dtype=torch.bfloat16, device="Ascend"),)
+                for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
+            ]
 
 class Fake_Attention_V1(Attention):
@@ -72,14 +90,24 @@ class Fake_Attention_V1(Attention):
         )
         head_size = vllm_config.model_config.get_head_size()
         num_block = 0
-        self.kv_shape = [num_block, block_size, num_kv_heads, head_size]
-        self.kv_cache = [
-            (
-                torch.zeros(self.kv_shape, dtype=torch.bfloat16, device="Ascend"),
-                torch.zeros(self.kv_shape, dtype=torch.bfloat16, device="Ascend"),
-            )
-            for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
-        ]
+        if is_310p():
+            self.kv_shape = [num_block, block_size, num_kv_heads * head_size]
+            self.kv_cache = [
+                (
+                    ops.auto_generate.format_cast(torch.zeros(self.kv_shape, dtype=torch.float16, device="Ascend"), 29),
+                    ops.auto_generate.format_cast(torch.zeros(self.kv_shape, dtype=torch.float16, device="Ascend"), 29),
+                )
+                for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
+            ]
+        else:
+            self.kv_shape = [num_block, block_size, num_kv_heads, head_size]
+            self.kv_cache = [
+                (
+                    torch.zeros(self.kv_shape, dtype=torch.bfloat16, device="Ascend"),
+                    torch.zeros(self.kv_shape, dtype=torch.bfloat16, device="Ascend"),
+                )
+                for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
+            ]
         self.attn_type = AttentionType.DECODER
         self.num_block = num_block
         self.num_kv_heads = num_kv_heads
@@ -93,10 +121,18 @@ class Fake_Attention_V1(Attention):
 class Fake_MLA_V1(Fake_Attention_V1):
     def __init__(self):
         super().__init__()
         vllm_config = get_current_vllm_config()
-        self.kv_cache = [
-            (torch.zeros(self.kv_shape, dtype=torch.bfloat16, device="Ascend"),)
-            for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
-        ]
+        if is_310p():
+            self.kv_cache = [
+                (
+                    ops.auto_generate.format_cast((torch.zeros(self.kv_shape, dtype=torch.float16, device="Ascend")), 29),
+                )
+                for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
+            ]
+        else:
+            self.kv_cache = [
+                (torch.zeros(self.kv_shape, dtype=torch.bfloat16, device="Ascend"),)
+                for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
+            ]
 
 class MsModelBase():
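All four Fake_* classes above apply the same 310P-specific pattern, summarised by the sketch below. The sketch is illustrative only: is_310p and ops.auto_generate.format_cast are the helpers this patch relies on, 29 is the NZ format id used throughout the patch, and the helper name make_fake_kv_cache is hypothetical.

    # Illustrative sketch of the KV-cache allocation pattern used by the
    # Fake_* classes in this patch; `make_fake_kv_cache` is a hypothetical name.
    import torch
    from mindspore import ops
    from vllm_mindspore.utils import is_310p

    def make_fake_kv_cache(num_block, block_size, num_kv_heads, head_size, pp_size):
        if is_310p():
            # 310P path: fold the head dims together, allocate float16 and cast
            # each block to the NZ layout (format id 29), as in the patch above.
            kv_shape = [num_block, block_size, num_kv_heads * head_size]

            def new_block():
                block = torch.zeros(kv_shape, dtype=torch.float16, device="Ascend")
                return ops.auto_generate.format_cast(block, 29)
        else:
            # Other Ascend SoCs: keep (num_kv_heads, head_size) separate, use bfloat16.
            kv_shape = [num_block, block_size, num_kv_heads, head_size]

            def new_block():
                return torch.zeros(kv_shape, dtype=torch.bfloat16, device="Ascend")

        # One (key_cache, value_cache) pair per pipeline-parallel stage.
        return [(new_block(), new_block()) for _ in range(pp_size)]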
diff --git a/vllm_mindspore/utils.py b/vllm_mindspore/utils.py
index 60cd4af0..b1acea74 100644
--- a/vllm_mindspore/utils.py
+++ b/vllm_mindspore/utils.py
@@ -54,6 +54,22 @@ STR_DTYPE_TO_MS_DTYPE = {
     "fp8_e5m2": ms.uint8,
 }
 
+def is_version_ge(current_version, base_version):
+    """
+    Check whether current_version is greater than or equal to base_version,
+    i.e. return current_version >= base_version.
+    For example, current_version 1.8.1 with base_version 1.11.0 returns False.
+    """
+    version_split_char = '.'
+    if version_split_char not in base_version or version_split_char not in current_version:
+        raise ValueError("The version string must contain `.`, "
+                         "e.g. current_version 1.8.1, base_version 1.11.0.")
+    for x, y in zip(current_version.split(version_split_char), base_version.split(version_split_char)):
+        if not x.isdigit() or not y.isdigit():
+            continue
+        if int(x) != int(y):
+            return int(x) >= int(y)
+    return True
 
 def get_valid_dtype(dtype):
     if isinstance(dtype, str):
@@ -196,6 +212,26 @@ def ascend_device_count_stateless() -> int:
     return len(avl_devices)
 
+def get_ascend_soc_version():
+    """Get the Ascend SoC version."""
+    if is_version_ge(ms.__version__, "2.2.0"):
+        from mindspore._c_expression import MSContext
+        return MSContext.get_instance().get_ascend_soc_version()
+    ascend_chip_type = os.getenv("ASCEND_CHIP_TYPE", "UNSET")
+    if ascend_chip_type not in ["910a", "910b", "UNSET"]:
+        raise EnvironmentError(f"ASCEND_CHIP_TYPE should be one of ['910a', '910b'], but got {ascend_chip_type}")
+    if ascend_chip_type == "UNSET":
+        logger.info("The chip type cannot be detected automatically, so it must be set manually "
+                    "via an environment variable:\n"
+                    "For Atlas 800, run 'export ASCEND_CHIP_TYPE=910a' before starting the program.\n"
+                    "For Atlas 800T A2, run 'export ASCEND_CHIP_TYPE=910b' before starting the program.\n"
+                    "To detect the chip type automatically, MindSpore 2.2 or later is recommended.")
+    return ascend_chip_type
+
+def is_310p():
+    device = get_ascend_soc_version()
+    return device in ['310p', 'ascend310p']
+
 def ascend_is_initialized():
     # Just return true for check.
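A short usage sketch of the utils.py helpers added above; the version strings and the printed message are examples only.

    # Illustrative usage of is_version_ge / is_310p from this patch.
    from vllm_mindspore.utils import is_version_ge, is_310p

    assert is_version_ge("2.2.0", "2.2.0")       # equal versions compare as >=
    assert is_version_ge("2.3.1", "2.2.0")       # 3 > 2 decides the comparison
    assert not is_version_ge("1.8.1", "1.11.0")  # 8 < 11, compared numerically

    # is_310p() gates every NZ-format branch in this patch. On MindSpore >= 2.2
    # the SoC name comes from MSContext; otherwise ASCEND_CHIP_TYPE must be set.
    if is_310p():
        print("310P detected: KV cache will be allocated in NZ format")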
diff --git a/vllm_mindspore/v1/worker/gpu_model_runner.py b/vllm_mindspore/v1/worker/gpu_model_runner.py
index a21a2f73..daba4932 100644
--- a/vllm_mindspore/v1/worker/gpu_model_runner.py
+++ b/vllm_mindspore/v1/worker/gpu_model_runner.py
@@ -4,12 +4,12 @@
 import gc
 
 import numpy as np
 import torch
-from mindspore import mutable
+from mindspore import mutable, ops
 import mindspore as ms
 from vllm_mindspore.v1.attention.backends.flash_attn import (FlashAttentionMetadata,
                                                              FlashAttentionBackend,
                                                              MLABackend)
-from vllm_mindspore.utils import get_valid_dtype
+from vllm_mindspore.utils import get_valid_dtype, is_310p
 from vllm.v1.kv_cache_interface import FullAttentionSpec
 from vllm.v1.utils import bind_kv_cache
@@ -171,6 +171,8 @@ def _prepare_inputs(
 def create_block(shape, dtype, name=None, device=None):
     from mindspore import mint
     blocks = mint.empty(shape, dtype=dtype, device=device)
+    if is_310p():
+        blocks = ops.auto_generate.format_cast(blocks, 29)
     return blocks
 
 def initialize_kv_cache(self, kv_cache_config) -> None:
@@ -205,6 +207,9 @@
             kv_cache_shape = self.attn_backend.get_kv_cache_shape(
                 num_blocks, kv_cache_spec.block_size, kv_cache_spec.num_kv_heads,
                 kv_cache_spec.head_size)
+            if is_310p():
+                kv_cache_shape = (kv_cache_shape[0], kv_cache_shape[1],
+                                  kv_cache_shape[2], kv_cache_shape[3] * kv_cache_shape[4])
             dtype = kv_cache_spec.dtype
             dtype = get_valid_dtype(dtype)
             current_cache = []
diff --git a/vllm_mindspore/worker/cache_engine.py b/vllm_mindspore/worker/cache_engine.py
index 2df44ee5..792f43b7 100644
--- a/vllm_mindspore/worker/cache_engine.py
+++ b/vllm_mindspore/worker/cache_engine.py
@@ -18,16 +18,18 @@
 """CacheEngine class for managing the KV cache."""
 
 import mindspore as ms
-from mindspore import mutable, mint
+from mindspore import mutable, mint, ops
 from typing import List
 
 from vllm.logger import init_logger
-from vllm_mindspore.utils import MsKVCache, get_valid_dtype
+from vllm_mindspore.utils import MsKVCache, get_valid_dtype, is_310p
 
 logger = init_logger(__name__)
 
 
 def create_block(shape, dtype, name=None, device=None):
     blocks = mint.empty(shape, dtype=dtype, device=device)
+    if is_310p():
+        blocks = ops.auto_generate.format_cast(blocks, 29)
     return blocks
 
@@ -39,6 +41,9 @@ def ms_allocate_kv_cache(
     """Allocates KV cache on the specified device."""
     kv_cache_shape = self.attn_backend.get_kv_cache_shape(
         num_blocks, self.block_size, self.num_kv_heads, self.head_size)
+    if is_310p():
+        kv_cache_shape = (kv_cache_shape[0], kv_cache_shape[1],
+                          kv_cache_shape[2], kv_cache_shape[3] * kv_cache_shape[4])
     kv_cache: List[MsKVCache] = []
 
     self.dtype = get_valid_dtype(self.dtype)
-- 
Gitee
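Both initialize_kv_cache and ms_allocate_kv_cache apply the same 310P reshape before allocation. The walk-through below uses made-up sizes and assumes, as the indexing above implies, that the backend returns a five-element shape whose last two entries are num_kv_heads and head_size.

    # Illustrative walk-through of the 310P reshape applied in both files above.
    # Example values only; the real shape comes from attn_backend.get_kv_cache_shape().
    kv_cache_shape = (2, 512, 16, 8, 128)   # (..., num_kv_heads=8, head_size=128)
    # On 310P the last two dims are merged into a single trailing dimension,
    # matching the NZ-formatted blocks produced by create_block().
    kv_cache_shape = (kv_cache_shape[0], kv_cache_shape[1],
                      kv_cache_shape[2], kv_cache_shape[3] * kv_cache_shape[4])
    assert kv_cache_shape == (2, 512, 16, 1024)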