diff --git a/vllm_mindspore/model_executor/models/model_base.py b/vllm_mindspore/model_executor/models/model_base.py
index 0d933a2db438919cf68388833e1f95c572436c81..073b77c039a8b571d1183ab736818f3d9fe6fb56 100644
--- a/vllm_mindspore/model_executor/models/model_base.py
+++ b/vllm_mindspore/model_executor/models/model_base.py
@@ -30,8 +30,8 @@
 from vllm.attention.layer import Attention
 
 import torch
-from mindspore import Tensor, nn, mutable
-
+from mindspore import Tensor, nn, mutable, ops
+from vllm_mindspore.utils import is_310p
 
 class Fake_Attention:
     def __init__(self):
@@ -42,14 +42,24 @@ class Fake_Attention:
         )
         head_size = vllm_config.model_config.get_head_size()
         num_block = 0
-        self.kv_shape = [num_block, block_size, num_kv_heads, head_size]
-        self.kv_cache = [
-            (
-                torch.zeros(self.kv_shape, dtype=torch.bfloat16, device="Ascend"),
-                torch.zeros(self.kv_shape, dtype=torch.bfloat16, device="Ascend"),
-            )
-            for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
-        ]
+        if is_310p():
+            self.kv_shape = [num_block, block_size, num_kv_heads * head_size]
+            self.kv_cache = [
+                (
+                    ops.auto_generate.format_cast(torch.zeros(self.kv_shape, dtype=torch.float16, device="Ascend"), 29),
+                    ops.auto_generate.format_cast(torch.zeros(self.kv_shape, dtype=torch.float16, device="Ascend"), 29),
+                )
+                for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
+            ]
+        else:
+            self.kv_shape = [num_block, block_size, num_kv_heads, head_size]
+            self.kv_cache = [
+                (
+                    torch.zeros(self.kv_shape, dtype=torch.bfloat16, device="Ascend"),
+                    torch.zeros(self.kv_shape, dtype=torch.bfloat16, device="Ascend"),
+                )
+                for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
+            ]
         self.attn_type = AttentionType.DECODER
 
 
@@ -57,10 +67,18 @@ class Fake_MLA(Fake_Attention):
     def __init__(self):
         super().__init__()
         vllm_config = get_current_vllm_config()
-        self.kv_cache = [
-            (torch.zeros(self.kv_shape, dtype=torch.bfloat16, device="Ascend"),)
-            for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
-        ]
+        if is_310p():
+            self.kv_cache = [
+                (
+                    ops.auto_generate.format_cast((torch.zeros(self.kv_shape, dtype=torch.float16, device="Ascend")), 29),
+                )
+                for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
+            ]
+        else:
+            self.kv_cache = [
+                (torch.zeros(self.kv_shape, dtype=torch.bfloat16, device="Ascend"),)
+                for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
+            ]
 
 
 class Fake_Attention_V1(Attention):
@@ -72,14 +90,24 @@ class Fake_Attention_V1(Attention):
         )
         head_size = vllm_config.model_config.get_head_size()
         num_block = 0
-        self.kv_shape = [num_block, block_size, num_kv_heads, head_size]
-        self.kv_cache = [
-            (
-                torch.zeros(self.kv_shape, dtype=torch.bfloat16, device="Ascend"),
-                torch.zeros(self.kv_shape, dtype=torch.bfloat16, device="Ascend"),
-            )
-            for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
-        ]
+        if is_310p():
+            self.kv_shape = [num_block, block_size, num_kv_heads * head_size]
+            self.kv_cache = [
+                (
+                    ops.auto_generate.format_cast(torch.zeros(self.kv_shape, dtype=torch.float16, device="Ascend"), 29),
+                    ops.auto_generate.format_cast(torch.zeros(self.kv_shape, dtype=torch.float16, device="Ascend"), 29),
+                )
+                for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
+            ]
+        else:
+            self.kv_shape = [num_block, block_size, num_kv_heads, head_size]
+            self.kv_cache = [
+                (
+                    torch.zeros(self.kv_shape, dtype=torch.bfloat16, device="Ascend"),
+                    torch.zeros(self.kv_shape, dtype=torch.bfloat16, device="Ascend"),
+                )
+                for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
+            ]
         self.attn_type = AttentionType.DECODER
         self.num_block = num_block
         self.num_kv_heads = num_kv_heads
@@ -93,10 +121,18 @@ class Fake_MLA_V1(Fake_Attention_V1):
     def __init__(self):
         super().__init__()
         vllm_config = get_current_vllm_config()
-        self.kv_cache = [
-            (torch.zeros(self.kv_shape, dtype=torch.bfloat16, device="Ascend"),)
-            for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
-        ]
+        if is_310p():
+            self.kv_cache = [
+                (
+                    ops.auto_generate.format_cast((torch.zeros(self.kv_shape, dtype=torch.float16, device="Ascend")), 29),
+                )
+                for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
+            ]
+        else:
+            self.kv_cache = [
+                (torch.zeros(self.kv_shape, dtype=torch.bfloat16, device="Ascend"),)
+                for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
+            ]
 
 
 class MsModelBase():
diff --git a/vllm_mindspore/utils.py b/vllm_mindspore/utils.py
index 60cd4af040fc9e9bda617eb8b6cd5d5130ef765f..b1acea74df82d7b85b5956e99ce1931ae73dab2d 100644
--- a/vllm_mindspore/utils.py
+++ b/vllm_mindspore/utils.py
@@ -54,6 +54,22 @@ STR_DTYPE_TO_MS_DTYPE = {
     "fp8_e5m2": ms.uint8,
 }
 
+def is_version_ge(current_version, base_version):
+    """
+    Return True if current_version >= base_version.
+    Versions are compared segment by segment as integers, so
+    current_version 1.8.1 with base_version 1.11.0 returns False.
+    """
+    version_split_char = '.'
+    if version_split_char not in base_version or version_split_char not in current_version:
+        raise ValueError("The version string must contain `.`, "
+                         "e.g. current_version 1.8.1, base_version 1.11.0.")
+    for x, y in zip(current_version.split(version_split_char), base_version.split(version_split_char)):
+        if not x.isdigit() or not y.isdigit():
+            continue
+        if int(x) != int(y):
+            return int(x) >= int(y)
+    return True
 
 def get_valid_dtype(dtype):
     if isinstance(dtype, str):
@@ -196,6 +212,26 @@ def ascend_device_count_stateless() -> int:
 
     return len(avl_devices)
 
+def get_ascend_soc_version():
+    """Get the Ascend SoC version."""
+    if is_version_ge(ms.__version__, "2.2.0"):
+        from mindspore._c_expression import MSContext
+        return MSContext.get_instance().get_ascend_soc_version()
+    ascend_chip_type = os.getenv("ASCEND_CHIP_TYPE", "UNSET")
+    if ascend_chip_type not in ["910a", "910b", "UNSET"]:
+        raise EnvironmentError(f"ASCEND_CHIP_TYPE should be in ['910a', '910b'], but got {ascend_chip_type}")
+    if ascend_chip_type == "UNSET":
+        logger.info("The chip type could not be detected and must be set manually "
+                    "through an environment variable, as follows:\n"
+                    "For Atlas 800, run 'export ASCEND_CHIP_TYPE=910a' before the program runs.\n"
+                    "For Atlas 800T A2, run 'export ASCEND_CHIP_TYPE=910b' before the program runs.\n"
+                    "To detect the chip type automatically, MindSpore 2.2 or later is recommended.")
+    return ascend_chip_type
+
+
+def is_310p():
+    device = get_ascend_soc_version()
+    return device in ['310p', 'ascend310p']
 
 def ascend_is_initialized():
     # Just return true for check.
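
Note (reviewer sketch, not part of the patch): a minimal usage example for the helpers added to vllm_mindspore/utils.py above. The expected results follow the segment-by-segment integer comparison described in the is_version_ge docstring, and the is_310p() branch mirrors how the rest of this patch switches to a float16 KV cache cast with ops.auto_generate.format_cast(..., 29); whether format id 29 corresponds to FRACTAL_NZ on a given CANN/MindSpore build is an assumption worth verifying.

# Sketch only: exercises the new helpers; this block is not applied by the diff.
from vllm_mindspore.utils import is_version_ge, is_310p

assert is_version_ge("1.8.1", "1.11.0") is False   # 8 < 11 in the second segment
assert is_version_ge("2.2.10", "2.2.0") is True    # 10 >= 0 in the third segment
assert is_version_ge("2.2.0", "2.2.0") is True     # all segments equal

if is_310p():
    # 310P path taken elsewhere in this patch: float16 cache + format_cast(..., 29).
    print("310P detected: use float16 KV cache in the cast format")
else:
    print("non-310P device: keep the bfloat16 KV cache in the default layout")
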
diff --git a/vllm_mindspore/v1/worker/gpu_model_runner.py b/vllm_mindspore/v1/worker/gpu_model_runner.py
index a21a2f73e889169e6d30ca2d2bdd23bb03bcc29b..daba49328447123df249a23db4bbbf3d09e6720e 100644
--- a/vllm_mindspore/v1/worker/gpu_model_runner.py
+++ b/vllm_mindspore/v1/worker/gpu_model_runner.py
@@ -4,12 +4,12 @@
 import gc
 
 import numpy as np
 import torch
-from mindspore import mutable
+from mindspore import mutable, ops
 import mindspore as ms
 from vllm_mindspore.v1.attention.backends.flash_attn import (FlashAttentionMetadata,
                                                               FlashAttentionBackend,
                                                               MLABackend)
-from vllm_mindspore.utils import get_valid_dtype
+from vllm_mindspore.utils import get_valid_dtype, is_310p
 from vllm.v1.kv_cache_interface import FullAttentionSpec
 from vllm.v1.utils import bind_kv_cache
@@ -171,6 +171,8 @@ def _prepare_inputs(
 def create_block(shape, dtype, name=None, device=None):
     from mindspore import mint
     blocks = mint.empty(shape, dtype=dtype, device=device)
+    if is_310p():
+        blocks = ops.auto_generate.format_cast(blocks, 29)
     return blocks
 
 def initialize_kv_cache(self, kv_cache_config) -> None:
@@ -205,6 +207,9 @@ def initialize_kv_cache(self, kv_cache_config) -> None:
             kv_cache_shape = self.attn_backend.get_kv_cache_shape(
                 num_blocks, kv_cache_spec.block_size,
                 kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
+            if is_310p():
+                kv_cache_shape = (kv_cache_shape[0], kv_cache_shape[1],
+                                  kv_cache_shape[2], kv_cache_shape[3] * kv_cache_shape[4])
             dtype = kv_cache_spec.dtype
             dtype = get_valid_dtype(dtype)
             current_cache = []
diff --git a/vllm_mindspore/worker/cache_engine.py b/vllm_mindspore/worker/cache_engine.py
index 2df44ee55d663b3b14e70932d1c76afd6cbdaaae..792f43b7940a9351c96292dac2b5802595b3e728 100644
--- a/vllm_mindspore/worker/cache_engine.py
+++ b/vllm_mindspore/worker/cache_engine.py
@@ -18,16 +18,18 @@
 """CacheEngine class for managing the KV cache."""
 
 import mindspore as ms
-from mindspore import mutable, mint
+from mindspore import mutable, mint, ops
 from typing import List
 
 from vllm.logger import init_logger
 
-from vllm_mindspore.utils import MsKVCache, get_valid_dtype
+from vllm_mindspore.utils import MsKVCache, get_valid_dtype, is_310p
 
 logger = init_logger(__name__)
 
 
 def create_block(shape, dtype, name=None, device=None):
     blocks = mint.empty(shape, dtype=dtype, device=device)
+    if is_310p():
+        blocks = ops.auto_generate.format_cast(blocks, 29)
     return blocks
@@ -39,6 +41,9 @@ def ms_allocate_kv_cache(
     """Allocates KV cache on the specified device."""
     kv_cache_shape = self.attn_backend.get_kv_cache_shape(
         num_blocks, self.block_size, self.num_kv_heads, self.head_size)
+    if is_310p():
+        kv_cache_shape = (kv_cache_shape[0], kv_cache_shape[1],
+                          kv_cache_shape[2], kv_cache_shape[3] * kv_cache_shape[4])
     kv_cache: List[MsKVCache] = []
 
     self.dtype = get_valid_dtype(self.dtype)
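
Note (reviewer sketch, not part of the patch): the kv_cache_shape handling added to initialize_kv_cache and ms_allocate_kv_cache above folds the last two dimensions returned by the attention backend into one on 310P, matching the [num_block, block_size, num_kv_heads * head_size] layout used by the Fake_Attention classes in model_base.py. The 5-element tuple below assumes the backend's usual (2, num_blocks, block_size, num_kv_heads, head_size) shape; the concrete sizes are made-up examples.

# Sketch only: illustrates the 310P shape folding with example numbers.
num_blocks, block_size, num_kv_heads, head_size = 512, 16, 8, 128
kv_cache_shape = (2, num_blocks, block_size, num_kv_heads, head_size)  # key/value planes first
kv_cache_shape_310p = (kv_cache_shape[0], kv_cache_shape[1],
                       kv_cache_shape[2], kv_cache_shape[3] * kv_cache_shape[4])
assert kv_cache_shape_310p == (2, 512, 16, 1024)  # heads folded into the last dimension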