diff --git a/vllm_mindspore/__init__.py b/vllm_mindspore/__init__.py
index 5892937a8957829c8a47ae1df13ada581137f7e6..d66c0e5a20e7fdaa720c17a9eb4500a930c453d4 100644
--- a/vllm_mindspore/__init__.py
+++ b/vllm_mindspore/__init__.py
@@ -333,6 +333,18 @@ from vllm_mindspore.v1.worker.gpu_worker import compile_or_warm_up_model
 from vllm.v1.worker.gpu_worker import Worker
 Worker.compile_or_warm_up_model = compile_or_warm_up_model
 
+import vllm.distributed.communication_op
+import vllm.worker.worker_base
+from vllm_mindspore.distributed.communication_op import ds_broadcast_tensor_dict
+vllm.distributed.communication_op.broadcast_tensor_dict = ds_broadcast_tensor_dict
+vllm.worker.worker_base.broadcast_tensor_dict = ds_broadcast_tensor_dict
+
+import vllm.distributed.parallel_state
+from vllm_mindspore.distributed.parallel_state import ds_gc_broadcast_tensor_dict, ds_init_model_parallel_group
+
+vllm.distributed.parallel_state.GroupCoordinator.broadcast_tensor_dict = ds_gc_broadcast_tensor_dict
+vllm.distributed.parallel_state.init_model_parallel_group = ds_init_model_parallel_group
+
 from .utils import check_ready
 
 from vllm_mindspore.engine.multiprocessing.engine import cleanup
diff --git a/vllm_mindspore/distributed/communication_op.py b/vllm_mindspore/distributed/communication_op.py
index 00447432e546516bf4d8629c374ac36e491041e8..0ca0408ded30c29cabfe11056d7a6d50f8ef01fe 100644
--- a/vllm_mindspore/distributed/communication_op.py
+++ b/vllm_mindspore/distributed/communication_op.py
@@ -21,6 +21,7 @@
 # Do not simply copy MindSpeed's code: training involves many extra features, while inference only needs very simple communication, which also improves performance.
 from typing import Any, Dict, Optional, Union
 
+import torch
 import mindspore as ms
 from mindspore import Tensor, nn, ops
 
@@ -95,6 +96,14 @@ def broadcast_tensor_dict(tensor_dict: Optional[Dict[Any, Union[Tensor,
 #     return get_tp_group().broadcast_tensor_dict(tensor_dict, src)
 
 
+def ds_broadcast_tensor_dict(tensor_dict: Optional[Dict[Any, Union[torch.Tensor,
+                                                                    Any]]] = None,
+                             src: int = 0):
+    if not torch.distributed.is_initialized():
+        return tensor_dict
+    return get_tp_group().broadcast_tensor_dict(tensor_dict, src, group=get_tp_group().cpu_group)
+
+
 def send_to_next_pp_rank(tensor):
     send(tensor, next_pp_rank(), group=get_pp_group())
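Note on the rebinding in `__init__.py` above: both `vllm.distributed.communication_op.broadcast_tensor_dict` and the copy in `vllm.worker.worker_base` are patched, presumably because the worker module holds its own reference (as a `from ... import` would create), so patching the source module alone would not reach it. A toy illustration of that import-binding behaviour (plain Python, not vLLM code):

```python
# Two stand-in modules: y.f mimics what `from x import f` leaves behind in module y.
import types

x = types.ModuleType("x")
x.f = lambda: "original"

y = types.ModuleType("y")
y.f = x.f                      # the binding a `from x import f` would create

x.f = lambda: "patched"        # rebinding only the source module...
print(x.f(), y.f())            # -> patched original  (y still holds the old function)
```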
diff --git a/vllm_mindspore/distributed/parallel_state.py b/vllm_mindspore/distributed/parallel_state.py
new file mode 100644
index 0000000000000000000000000000000000000000..ccb78bcc6707febf61b66aaee2ed3c22e5a8b526
--- /dev/null
+++ b/vllm_mindspore/distributed/parallel_state.py
@@ -0,0 +1,109 @@
+import torch
+import torch.distributed
+from torch.distributed import ProcessGroup
+
+from typing import (Any, Dict, List, Optional, Tuple,
+                    Union)
+from vllm.distributed.parallel_state import _split_tensor_dict, TensorMetadata, GroupCoordinator
+
+def ds_gc_broadcast_tensor_dict(
+        self,
+        tensor_dict: Optional[Dict[str, Union[torch.Tensor, Any]]] = None,
+        src: int = 0,
+        group: Optional[ProcessGroup] = None,
+        metadata_group: Optional[ProcessGroup] = None
+) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
+    """Broadcast the input tensor dictionary.
+    NOTE: `src` is the local rank of the source rank.
+    """
+    # Bypass the function if we are using only 1 GPU.
+    if (not torch.distributed.is_initialized() or self.world_size == 1):
+        return tensor_dict
+
+    metadata_group = self.cpu_group
+    assert src < self.world_size, f"Invalid src rank ({src})"
+
+    rank_in_group = self.rank_in_group
+    if rank_in_group == src:
+        metadata_list: List[Tuple[Any, Any]] = []
+        assert isinstance(
+            tensor_dict,
+            dict), (f"Expecting a dictionary, got {type(tensor_dict)}")
+        metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
+        # `metadata_list` lives in CPU memory.
+        # `broadcast_object_list` has serialization & deserialization,
+        # all happening on CPU. Therefore, we can use the CPU group.
+        self.broadcast_object(metadata_list, src=src)
+        async_handles = []
+        for tensor in tensor_list:
+            if tensor.numel() == 0:
+                # Skip broadcasting empty tensors.
+                continue
+            if tensor.is_cpu:
+                # use metadata_group for CPU tensors
+                handle = torch.distributed.broadcast(tensor,
+                                                     src=self.ranks[src],
+                                                     group=metadata_group,
+                                                     async_op=True)
+            else:
+                # use group for GPU tensors
+                handle = torch.distributed.broadcast(tensor,
+                                                     src=self.ranks[src],
+                                                     group=group,
+                                                     async_op=True)
+            async_handles.append(handle)
+        for async_handle in async_handles:
+            async_handle.wait()
+
+    else:
+        metadata_list = self.broadcast_object(None, src=src)
+        tensor_dict = {}
+        async_handles = []
+        for key, value in metadata_list:
+            if isinstance(value, TensorMetadata):
+                tensor = torch.empty(value.size,
+                                     dtype=value.dtype,
+                                     device=value.device)
+                if tensor.numel() == 0:
+                    # Skip broadcasting empty tensors.
+                    tensor_dict[key] = tensor
+                    continue
+                if tensor.is_cpu:
+                    # use metadata_group for CPU tensors
+                    handle = torch.distributed.broadcast(
+                        tensor,
+                        src=self.ranks[src],
+                        group=metadata_group,
+                        async_op=True)
+                else:
+                    # use group for GPU tensors
+                    handle = torch.distributed.broadcast(
+                        tensor,
+                        src=self.ranks[src],
+                        group=group,
+                        async_op=True)
+                async_handles.append(handle)
+                tensor_dict[key] = tensor
+            else:
+                tensor_dict[key] = value
+        for async_handle in async_handles:
+            async_handle.wait()
+    return tensor_dict
+
+def ds_init_model_parallel_group(
+    group_ranks: List[List[int]],
+    local_rank: int,
+    backend: str,
+    use_message_queue_broadcaster: bool = False,
+    group_name: Optional[str] = None,
+) -> GroupCoordinator:
+    if group_name == "pp":
+        backend = "gloo"
+    return GroupCoordinator(
+        group_ranks=group_ranks,
+        local_rank=local_rank,
+        torch_distributed_backend=backend,
+        use_device_communicator=True,
+        use_message_queue_broadcaster=use_message_queue_broadcaster,
+        group_name=group_name,
+    )
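For context, `ds_gc_broadcast_tensor_dict` keeps vLLM's usual two-channel scheme: picklable metadata travels via `broadcast_object` over the CPU (gloo) group, while the tensor payloads are broadcast asynchronously on the device group (or on the CPU group for CPU tensors). A simplified sketch of the metadata/tensor split idea — a hypothetical stand-in, not vLLM's actual `_split_tensor_dict`:

```python
import torch

def split_tensor_dict(tensor_dict):
    """Replace tensors with (dtype, shape) stubs so only metadata needs pickling."""
    metadata, tensors = [], []
    for key, value in tensor_dict.items():
        if isinstance(value, torch.Tensor):
            metadata.append((key, ("tensor", value.dtype, tuple(value.shape))))
            tensors.append(value)
        else:
            metadata.append((key, value))
    return metadata, tensors

meta, payload = split_tensor_dict({"ids": torch.arange(4), "step": 7})
print(meta)     # [('ids', ('tensor', torch.int64, (4,))), ('step', 7)]
print(payload)  # [tensor([0, 1, 2, 3])]
```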
diff --git a/vllm_mindspore/model_executor/models/mf_models/qwen2.py b/vllm_mindspore/model_executor/models/mf_models/qwen2.py
index d871be483b191995863112ee839aa5f2c7656765..556f22968469f6c2d03cb1afc9fea76b85c777ec 100644
--- a/vllm_mindspore/model_executor/models/mf_models/qwen2.py
+++ b/vllm_mindspore/model_executor/models/mf_models/qwen2.py
@@ -37,10 +37,26 @@ from vllm_mindspore.model_executor.models.model_base import Fake_Attention, Fake
 from vllm_mindspore.model_executor.models.mf_models.mf_model_base import MfModelBase
 from vllm_mindspore.model_executor.models.mf_models.qwen2_weight_processor import Qwen2WeightProcessor
 
+import os
+import mindspore as ms
 
 logger = init_logger(__name__)
 
 
+def set_runtime_kernel_launch_group():
+    kernel_launch_group = {}
+    env_kernel_launch_group = os.getenv("EXPERIMENTAL_KERNEL_LAUNCH_GROUP", None)
+    if env_kernel_launch_group is None:
+        return
+    pairs = env_kernel_launch_group.split(',')
+    for pair in pairs:
+        key, val = pair.split(':')
+        kernel_launch_group[key] = val
+    thread_num = int(kernel_launch_group.get('thread_num', 2))
+    kernel_group_num = int(kernel_launch_group.get('kernel_group_num', 8))
+    ms.runtime.set_kernel_launch_group(thread_num=thread_num, kernel_group_num=kernel_group_num)
+
+
 class Qwen2ForCausalLM(MfModelBase):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         super(Qwen2ForCausalLM, self).__init__(vllm_config=vllm_config, prefix=prefix)
@@ -60,6 +76,7 @@ class Qwen2ForCausalLM(MfModelBase):
 
         for i in range(self.mf_model_config.num_layers):
             compilation_config.static_forward_context[str(i)] = self.kv_caches[i]
+        set_runtime_kernel_launch_group()
         self.set_flags = False
 
     def _generate_model_config(self):
@@ -70,9 +87,9 @@ class Qwen2ForCausalLM(MfModelBase):
         self.mf_model_config.return_hidden_states = True
 
         # qwen qkv concat will support in next version
-        self.mf_model_config.qkv_concat = False
+        # self.mf_model_config.qkv_concat = True
         setattr(self.mf_model_config, 'npu_mem_size', -1)
-        self.mf_config.model.model_config.qkv_concat = False
+        # self.mf_config.model.model_config.qkv_concat = True
 
     def _create_network(self):
         # Initial network
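`set_runtime_kernel_launch_group()` reads `EXPERIMENTAL_KERNEL_LAUNCH_GROUP` as comma-separated `key:value` pairs and currently honours the `thread_num` and `kernel_group_num` keys. A minimal sketch of the expected format and parsing (the values shown are simply the defaults used in the code above):

```python
import os

# e.g.  export EXPERIMENTAL_KERNEL_LAUNCH_GROUP="thread_num:2,kernel_group_num:8"
os.environ["EXPERIMENTAL_KERNEL_LAUNCH_GROUP"] = "thread_num:2,kernel_group_num:8"

raw = os.environ["EXPERIMENTAL_KERNEL_LAUNCH_GROUP"]
parsed = dict(pair.split(":") for pair in raw.split(","))
print(parsed)  # {'thread_num': '2', 'kernel_group_num': '8'}
```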
+ """ + w = qkv_bias.shape[0] + n_rep = num_heads // n_kv_heads + q_channel = hidden_size // self.tp_group_size + kv_channel = (hidden_size // n_rep) // self.tp_group_size + q_weight = qkv_bias[: q_channel] + k_weight = qkv_bias[q_channel: q_channel + kv_channel] + v_weight = qkv_bias[q_channel + kv_channel: q_channel + 2 * kv_channel] + q_w_reshape = q_weight.reshape(n_kv_heads // self.tp_group_size, hidden_size // n_kv_heads) + k_w_reshape = k_weight.reshape(n_kv_heads // self.tp_group_size, hidden_size // num_heads) + v_w_reshape = v_weight.reshape(n_kv_heads // self.tp_group_size, hidden_size // num_heads) + + cat_qkv_weight = np.concatenate((q_w_reshape, k_w_reshape, v_w_reshape), axis=1) + out_qkv_weight = cat_qkv_weight.reshape(w,) + return out_qkv_weight + + def ffn_concat_hf2mg(self, ffn_weights: np.ndarray, ffn_hidden_size): + """ + convert ffn_concat weight with huggingface format to megatron format. + """ + w, h = ffn_weights.shape + gate_weight = ffn_weights[: w // 2, :] + hidden_weight = ffn_weights[w // 2: w // 2 * 2, :] + gate_w_reshape = gate_weight.reshape(-1, 1, ffn_hidden_size) + hidden_w_reshape = hidden_weight.reshape(-1, 1, ffn_hidden_size) + cat_ffn_weight = np.concatenate((gate_w_reshape, hidden_w_reshape), axis=1) + out_ffn_weight = cat_ffn_weight.reshape(w, h) + return out_ffn_weight def infer_convert_outer_weight(self, src_hf_dir, hf_weight_map): """convert weight not in model""" @@ -51,14 +106,14 @@ class Qwen2WeightProcessor(BaseWeightProcessor): else: np_data, _ = self.get_safetensor_from_file_split_tp_group(embed_tokens_hf_name, src_hf_dir, hf_weight_map, split_axis=0) - self.parameter_dict[embed_tokens_ms_name] = ms.Parameter(ms.from_numpy(np_data).astype(ms.bfloat16), + self.parameter_dict[embed_tokens_ms_name] = ms.Parameter(ms.from_numpy(np_data).astype(self.dtype), name=embed_tokens_ms_name, requires_grad=False) norm_hf_name = "model.norm.weight" norm_ms_name = self.convert_weight_name(norm_hf_name) np_data, _ = self.get_safetensor_from_file(norm_hf_name, src_hf_dir, hf_weight_map) - self.parameter_dict[norm_ms_name] = ms.Parameter(ms.from_numpy(np_data).astype(ms.bfloat16), + self.parameter_dict[norm_ms_name] = ms.Parameter(ms.from_numpy(np_data).astype(self.dtype), name=norm_ms_name, requires_grad=False) @@ -70,7 +125,7 @@ class Qwen2WeightProcessor(BaseWeightProcessor): split_axis=0) else: np_data, _ = self.get_safetensor_from_file(lm_head_hf_name, src_hf_dir, hf_weight_map) - self.parameter_dict[lm_head_ms_name] = ms.Parameter(ms.from_numpy(np_data).astype(ms.bfloat16), + self.parameter_dict[lm_head_ms_name] = ms.Parameter(ms.from_numpy(np_data).astype(self.dtype), name=lm_head_ms_name, requires_grad=False) @@ -93,7 +148,7 @@ class Qwen2WeightProcessor(BaseWeightProcessor): def infer_process_dense_ffn_weight(self, src_hf_dir, layer_id, hf_weight_map): """infer process dense ffn weight""" - ffn_concat = self.config.model.model_config.qkv_concat + ffn_concat = self.config.model.model_config.ffn_concat w1_hf_name = f"model.layers.{layer_id}.mlp.gate_proj.weight" w1_ms_name = self.convert_weight_name(w1_hf_name) w1_ms_param, _ = self.get_safetensor_from_file_split_tp_group(w1_hf_name, src_hf_dir, hf_weight_map, @@ -112,17 +167,18 @@ class Qwen2WeightProcessor(BaseWeightProcessor): if ffn_concat: w_gate_hidden_name = f"model.layers.{layer_id}.feed_forward.w_gate_hidden.weight" w_gate_hidden_param = np.concatenate((w1_ms_param, w3_ms_param), axis=0) + w_gate_hidden_param = self.ffn_concat_hf2mg(w_gate_hidden_param, self.hidden_size) 
@@ -51,14 +106,14 @@ class Qwen2WeightProcessor(BaseWeightProcessor):
         else:
             np_data, _ = self.get_safetensor_from_file_split_tp_group(embed_tokens_hf_name, src_hf_dir, hf_weight_map,
                                                                       split_axis=0)
-        self.parameter_dict[embed_tokens_ms_name] = ms.Parameter(ms.from_numpy(np_data).astype(ms.bfloat16),
+        self.parameter_dict[embed_tokens_ms_name] = ms.Parameter(ms.from_numpy(np_data).astype(self.dtype),
                                                                   name=embed_tokens_ms_name,
                                                                   requires_grad=False)
 
         norm_hf_name = "model.norm.weight"
         norm_ms_name = self.convert_weight_name(norm_hf_name)
         np_data, _ = self.get_safetensor_from_file(norm_hf_name, src_hf_dir, hf_weight_map)
-        self.parameter_dict[norm_ms_name] = ms.Parameter(ms.from_numpy(np_data).astype(ms.bfloat16),
+        self.parameter_dict[norm_ms_name] = ms.Parameter(ms.from_numpy(np_data).astype(self.dtype),
                                                          name=norm_ms_name,
                                                          requires_grad=False)
 
@@ -70,7 +125,7 @@ class Qwen2WeightProcessor(BaseWeightProcessor):
                                                                        split_axis=0)
         else:
             np_data, _ = self.get_safetensor_from_file(lm_head_hf_name, src_hf_dir, hf_weight_map)
-        self.parameter_dict[lm_head_ms_name] = ms.Parameter(ms.from_numpy(np_data).astype(ms.bfloat16),
+        self.parameter_dict[lm_head_ms_name] = ms.Parameter(ms.from_numpy(np_data).astype(self.dtype),
                                                             name=lm_head_ms_name,
                                                             requires_grad=False)
 
@@ -93,7 +148,7 @@ class Qwen2WeightProcessor(BaseWeightProcessor):
 
     def infer_process_dense_ffn_weight(self, src_hf_dir, layer_id, hf_weight_map):
         """infer process dense ffn weight"""
-        ffn_concat = self.config.model.model_config.qkv_concat
+        ffn_concat = self.config.model.model_config.ffn_concat
         w1_hf_name = f"model.layers.{layer_id}.mlp.gate_proj.weight"
         w1_ms_name = self.convert_weight_name(w1_hf_name)
         w1_ms_param, _ = self.get_safetensor_from_file_split_tp_group(w1_hf_name, src_hf_dir, hf_weight_map,
@@ -112,17 +167,18 @@ class Qwen2WeightProcessor(BaseWeightProcessor):
         if ffn_concat:
             w_gate_hidden_name = f"model.layers.{layer_id}.feed_forward.w_gate_hidden.weight"
             w_gate_hidden_param = np.concatenate((w1_ms_param, w3_ms_param), axis=0)
+            w_gate_hidden_param = self.ffn_concat_hf2mg(w_gate_hidden_param, self.hidden_size)
             self.parameter_dict[w_gate_hidden_name] = ms.Parameter(w_gate_hidden_param,
                                                                    name=w_gate_hidden_name,
                                                                    requires_grad=False)
         else:
-            self.parameter_dict[w1_ms_name] = ms.Parameter(ms.from_numpy(w1_ms_param).astype(ms.bfloat16),
+            self.parameter_dict[w1_ms_name] = ms.Parameter(ms.from_numpy(w1_ms_param).astype(self.dtype),
                                                            name=w1_ms_name,
                                                            requires_grad=False)
-            self.parameter_dict[w3_ms_name] = ms.Parameter(ms.from_numpy(w3_ms_param).astype(ms.bfloat16),
+            self.parameter_dict[w3_ms_name] = ms.Parameter(ms.from_numpy(w3_ms_param).astype(self.dtype),
                                                            name=w3_ms_name,
                                                            requires_grad=False)
 
-        self.parameter_dict[w2_ms_name] = ms.Parameter(ms.from_numpy(w2_ms_param).astype(ms.bfloat16),
+        self.parameter_dict[w2_ms_name] = ms.Parameter(ms.from_numpy(w2_ms_param).astype(self.dtype),
                                                        name=w2_ms_name,
                                                        requires_grad=False)
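Similarly, `ffn_concat_hf2mg` (used for `w_gate_hidden` above) does not keep the gate and hidden blocks contiguous; it interleaves their rows one by one. A toy NumPy check of that reordering (sizes are illustrative only):

```python
import numpy as np

ffn_hidden_size = 3                                                 # input width of each row
gate = np.arange(4 * ffn_hidden_size).reshape(4, ffn_hidden_size)
hidden = 100 + np.arange(4 * ffn_hidden_size).reshape(4, ffn_hidden_size)

hf_fused = np.concatenate([gate, hidden], axis=0)                   # HF layout: [gate; hidden]

g = gate.reshape(-1, 1, ffn_hidden_size)
h = hidden.reshape(-1, 1, ffn_hidden_size)
mg_fused = np.concatenate([g, h], axis=1).reshape(hf_fused.shape)

print(mg_fused[:, 0])  # rows now alternate: g0, h0, g1, h1, ...
```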
@@ -165,35 +221,37 @@ class Qwen2WeightProcessor(BaseWeightProcessor):
         if qkv_concat:
             w_qkv_name = f"model.layers.{layer_id}.attention.w_qkv.weight"
             w_qkv_param = np.concatenate((wq_ms_param, wk_ms_param, wv_ms_param), axis=0)
-            w_qkv_param = ms.from_numpy(w_qkv_param).astype(ms.bfloat16)
+            w_qkv_param = self.qkv_concat_hf2mg(w_qkv_param, self.num_heads, self.kv_heads, self.hidden_size)
+            w_qkv_param = ms.from_numpy(w_qkv_param).astype(self.dtype)
             self.parameter_dict[w_qkv_name] = ms.Parameter(w_qkv_param, name=w_qkv_name, requires_grad=False)
 
             w_qkv_bias_name = f"model.layers.{layer_id}.attention.w_qkv.bias"
             w_qkv_bias_param = np.concatenate((wq_bias_ms_param, wk_bias_ms_param, wv_bias_ms_param), axis=0)
-            w_qkv_bias_param = ms.from_numpy(w_qkv_bias_param).astype(ms.bfloat16)
+            w_qkv_bias_param = self.qkv_bias_concat_hf2mg(w_qkv_bias_param, self.num_heads, self.kv_heads, self.hidden_size)
+            w_qkv_bias_param = ms.from_numpy(w_qkv_bias_param).astype(self.dtype)
             self.parameter_dict[w_qkv_bias_name] = ms.Parameter(w_qkv_bias_param, name=w_qkv_bias_name,
                                                                 requires_grad=False)
         else:
-            self.parameter_dict[wq_ms_name] = ms.Parameter(ms.from_numpy(wq_ms_param).astype(ms.bfloat16),
+            self.parameter_dict[wq_ms_name] = ms.Parameter(ms.from_numpy(wq_ms_param).astype(self.dtype),
                                                            name=wq_ms_name,
                                                            requires_grad=False)
-            self.parameter_dict[wk_ms_name] = ms.Parameter(ms.from_numpy(wk_ms_param).astype(ms.bfloat16),
+            self.parameter_dict[wk_ms_name] = ms.Parameter(ms.from_numpy(wk_ms_param).astype(self.dtype),
                                                            name=wk_ms_name,
                                                            requires_grad=False)
-            self.parameter_dict[wv_ms_name] = ms.Parameter(ms.from_numpy(wv_ms_param).astype(ms.bfloat16),
+            self.parameter_dict[wv_ms_name] = ms.Parameter(ms.from_numpy(wv_ms_param).astype(self.dtype),
                                                            name=wv_ms_name,
                                                            requires_grad=False)
 
             self.parameter_dict[wq_bias_ms_name] = ms.Parameter(
-                ms.from_numpy(wq_bias_ms_param).astype(ms.bfloat16),
+                ms.from_numpy(wq_bias_ms_param).astype(self.dtype),
                 name=wq_bias_ms_name,
                 requires_grad=False)
 
             self.parameter_dict[wk_bias_ms_name] = ms.Parameter(
-                ms.from_numpy(wk_bias_ms_param).astype(ms.bfloat16),
+                ms.from_numpy(wk_bias_ms_param).astype(self.dtype),
                 name=wk_bias_ms_name,
                 requires_grad=False)
 
             self.parameter_dict[wv_bias_ms_name] = ms.Parameter(
-                ms.from_numpy(wv_bias_ms_param).astype(ms.bfloat16),
+                ms.from_numpy(wv_bias_ms_param).astype(self.dtype),
                 name=wv_bias_ms_name,
                 requires_grad=False)
 
@@ -202,7 +260,7 @@ class Qwen2WeightProcessor(BaseWeightProcessor):
         wo_ms_name = self.convert_weight_name(wo_hf_name)
         wo_ms_param, _ = self.get_safetensor_from_file_split_tp_group(wo_hf_name, src_hf_dir, hf_weight_map,
                                                                       split_axis=1)
-        self.parameter_dict[wo_ms_name] = ms.Parameter(ms.from_numpy(wo_ms_param).astype(ms.bfloat16),
+        self.parameter_dict[wo_ms_name] = ms.Parameter(ms.from_numpy(wo_ms_param).astype(self.dtype),
                                                        name=wo_ms_name,
                                                        requires_grad=False)
 
@@ -215,7 +273,7 @@ class Qwen2WeightProcessor(BaseWeightProcessor):
                                                                   src_hf_dir,
                                                                   hf_weight_map)
         self.parameter_dict[attention_norm_ms_name] = ms.Parameter(
-            ms.from_numpy(attention_norm_ms_param).astype(ms.bfloat16),
+            ms.from_numpy(attention_norm_ms_param).astype(self.dtype),
             name=attention_norm_ms_name,
             requires_grad=False)
 
@@ -224,7 +282,7 @@ class Qwen2WeightProcessor(BaseWeightProcessor):
         ffn_norm_ms_name = self.convert_weight_name(ffn_norm_hf_name)
         ffn_norm_ms_param, _ = self.get_safetensor_from_file(ffn_norm_hf_name, src_hf_dir, hf_weight_map)
         self.parameter_dict[ffn_norm_ms_name] = ms.Parameter(
-            ms.from_numpy(ffn_norm_ms_param).astype(ms.bfloat16),
+            ms.from_numpy(ffn_norm_ms_param).astype(self.dtype),
             name=ffn_norm_ms_name,
             requires_grad=False)
 
diff --git a/vllm_mindspore/model_executor/models/model_base.py b/vllm_mindspore/model_executor/models/model_base.py
index 0d933a2db438919cf68388833e1f95c572436c81..974dee229380048d0095794e11837ab4bba5c65e 100644
--- a/vllm_mindspore/model_executor/models/model_base.py
+++ b/vllm_mindspore/model_executor/models/model_base.py
@@ -27,10 +27,11 @@ from vllm.sequence import IntermediateTensors
 from vllm.attention.backends.abstract import AttentionType
 from vllm.forward_context import get_forward_context
 from vllm.attention.layer import Attention
+from vllm_mindspore.utils import FORMAT_TYPE, is_310p
 
 import torch
 
-from mindspore import Tensor, nn, mutable
+from mindspore import Tensor, nn, mutable, ops
 
 
 class Fake_Attention:
@@ -42,14 +43,24 @@ class Fake_Attention:
         )
         head_size = vllm_config.model_config.get_head_size()
         num_block = 0
-        self.kv_shape = [num_block, block_size, num_kv_heads, head_size]
-        self.kv_cache = [
-            (
-                torch.zeros(self.kv_shape, dtype=torch.bfloat16, device="Ascend"),
-                torch.zeros(self.kv_shape, dtype=torch.bfloat16, device="Ascend"),
-            )
-            for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
-        ]
+        if is_310p():
+            self.kv_shape = [num_block, block_size, num_kv_heads*head_size]
+            self.kv_cache = [
+                (
+                    ops.auto_generate.format_cast(torch.zeros(self.kv_shape, dtype=torch.float16, device="Ascend"), FORMAT_TYPE["nz"]),
+                    ops.auto_generate.format_cast(torch.zeros(self.kv_shape, dtype=torch.float16, device="Ascend"), FORMAT_TYPE["nz"]),
+                )
+                for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
+            ]
+        else:
+            self.kv_shape = [num_block, block_size, num_kv_heads, head_size]
+            self.kv_cache = [
+                (
+                    torch.zeros(self.kv_shape, dtype=torch.bfloat16, device="Ascend"),
+                    torch.zeros(self.kv_shape, dtype=torch.bfloat16, device="Ascend"),
+                )
+                for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
+            ]
         self.attn_type = AttentionType.DECODER
 
 
@@ -57,11 +68,17 @@ class Fake_MLA(Fake_Attention):
     def __init__(self):
         super().__init__()
         vllm_config = get_current_vllm_config()
-        self.kv_cache = [
-            (torch.zeros(self.kv_shape, dtype=torch.bfloat16, device="Ascend"),)
-            for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
-        ]
-
+        if is_310p():
+            self.kv_cache = [
+                (ops.auto_generate.format_cast(torch.zeros(
+                    self.kv_shape, dtype=torch.float16, device="Ascend"), FORMAT_TYPE["nz"]),)
+                for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
+            ]
+        else:
+            self.kv_cache = [
+                (torch.zeros(self.kv_shape, dtype=torch.bfloat16, device="Ascend"),)
+                for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
+            ]
 
 class Fake_Attention_V1(Attention):
     def __init__(self):
@@ -72,14 +89,24 @@ class Fake_Attention_V1(Attention):
         )
         head_size = vllm_config.model_config.get_head_size()
         num_block = 0
-        self.kv_shape = [num_block, block_size, num_kv_heads, head_size]
-        self.kv_cache = [
-            (
-                torch.zeros(self.kv_shape, dtype=torch.bfloat16, device="Ascend"),
-                torch.zeros(self.kv_shape, dtype=torch.bfloat16, device="Ascend"),
-            )
-            for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
-        ]
+        if is_310p():
+            self.kv_shape = [num_block, block_size, num_kv_heads*head_size]
+            self.kv_cache = [
+                (
+                    ops.auto_generate.format_cast(torch.zeros(self.kv_shape, dtype=torch.float16, device="Ascend"), FORMAT_TYPE["nz"]),
+                    ops.auto_generate.format_cast(torch.zeros(self.kv_shape, dtype=torch.float16, device="Ascend"), FORMAT_TYPE["nz"]),
+                )
+                for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
+            ]
+        else:
+            self.kv_shape = [num_block, block_size, num_kv_heads, head_size]
+            self.kv_cache = [
+                (
+                    torch.zeros(self.kv_shape, dtype=torch.bfloat16, device="Ascend"),
+                    torch.zeros(self.kv_shape, dtype=torch.bfloat16, device="Ascend"),
+                )
+                for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
+            ]
         self.attn_type = AttentionType.DECODER
         self.num_block = num_block
         self.num_kv_heads = num_kv_heads
@@ -93,10 +120,17 @@ class Fake_MLA_V1(Fake_Attention_V1):
     def __init__(self):
         super().__init__()
         vllm_config = get_current_vllm_config()
-        self.kv_cache = [
-            (torch.zeros(self.kv_shape, dtype=torch.bfloat16, device="Ascend"),)
-            for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
-        ]
+        if is_310p():
+            self.kv_cache = [
+                (ops.auto_generate.format_cast(torch.zeros(
+                    self.kv_shape, dtype=torch.float16, device="Ascend"), FORMAT_TYPE["nz"]),)
+                for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
+            ]
+        else:
+            self.kv_cache = [
+                (torch.zeros(self.kv_shape, dtype=torch.bfloat16, device="Ascend"),)
+                for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
+            ]
 
 
 class MsModelBase():
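On 310P the fake attention layers above fold the head dimensions together, switch to float16, and cast the empty cache blocks to the FRACTAL_NZ layout (format id 29, i.e. `FORMAT_TYPE["nz"]`). A minimal sketch of that branching, reusing the `format_cast` call from the patch (only meaningful on an Ascend backend; shapes are toy values and the helper name is made up):

```python
import mindspore as ms
from mindspore import mint, ops

FORMAT_NZ = 29  # same value as FORMAT_TYPE["nz"] in vllm_mindspore/utils.py

def make_kv_block(num_blocks, block_size, num_kv_heads, head_size, on_310p):
    # Hypothetical helper mirroring the branching in Fake_Attention above.
    if on_310p:
        # 310P: heads folded into the last dim, float16, FRACTAL_NZ storage.
        shape = (num_blocks, block_size, num_kv_heads * head_size)
        block = mint.empty(shape, dtype=ms.float16)
        return ops.auto_generate.format_cast(block, FORMAT_NZ)
    # Default (e.g. 910B): keep the (heads, head_size) split and bfloat16 in ND format.
    shape = (num_blocks, block_size, num_kv_heads, head_size)
    return mint.empty(shape, dtype=ms.bfloat16)
```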
+ "For example, current_version 1.8.1, base_version: 1.11.0.") + for x, y in zip(current_version.split(version_split_char), base_version.split(version_split_char)): + if not x.isdigit() or not y.isdigit(): + continue + if int(x) != int(y): + return int(x) >= int(y) + return True + + +def get_ascend_soc_version(): + """Get ascend soc version.""" + if is_version_ge(ms.__version__, "2.2.0"): + from mindspore._c_expression import MSContext + return MSContext.get_instance().get_ascend_soc_version() + ascend_chip_type = os.getenv("ASCEND_CHIP_TYPE", "UNSET") + if ascend_chip_type not in ["910a", "910b", "UNSET"]: + raise EnvironmentError(f"ASCEND_CHIP_TYPE should be in ['910a', '910b'],but get {ascend_chip_type}") + if ascend_chip_type == "UNSET": + logger.info("Environment variables need to be set manually to obtain the chip type," + "which can be set as follows: \n" + "For Atlas 800, run 'export ASCEND_CHIP_TYPE=910a' before the program runs.\n" + "For Atlas 800T A2, run 'export ASCEND_CHIP_TYPE=910b' before the program runs.\n" + "If you need to get chip information automatically, MindSpore 2.2 and above is recommended") + return ascend_chip_type + + +def is_310p(): + device = get_ascend_soc_version() + return device in ['310p', 'ascend310p'] + + # Replace the directly loaded module in vllm, such as 'from module import xxx' def update_modules(name, module): logger.debug(f"replace module {0} by {1}".format(name, module)) diff --git a/vllm_mindspore/v1/attention/backends/flash_attn.py b/vllm_mindspore/v1/attention/backends/flash_attn.py index b5c5629ee51fc7faf969f18a7b596e60d939387f..2734ec55f9f5c9c8212f2bb060a72909001d6ce7 100644 --- a/vllm_mindspore/v1/attention/backends/flash_attn.py +++ b/vllm_mindspore/v1/attention/backends/flash_attn.py @@ -54,7 +54,11 @@ class FlashAttentionBackend(AttentionBackend): ) -> Tuple[int, ...]: if block_size % 16 != 0: raise ValueError("Block size must be a multiple of 16.") - return (2, num_blocks, block_size, num_kv_heads, head_size) + from vllm_mindspore.utils import is_310p + if is_310p(): + return (2, num_blocks, block_size, num_kv_heads*head_size) + else: + return (2, num_blocks, block_size, num_kv_heads, head_size) @staticmethod def use_cascade_attention(*args, **kwargs) -> bool: diff --git a/vllm_mindspore/v1/worker/gpu_model_runner.py b/vllm_mindspore/v1/worker/gpu_model_runner.py index a21a2f73e889169e6d30ca2d2bdd23bb03bcc29b..1b7639394af597a110f5ff6fa0333dfaaa0d98b7 100644 --- a/vllm_mindspore/v1/worker/gpu_model_runner.py +++ b/vllm_mindspore/v1/worker/gpu_model_runner.py @@ -6,10 +6,11 @@ import torch from mindspore import mutable import mindspore as ms +from mindspore import ops from vllm_mindspore.v1.attention.backends.flash_attn import (FlashAttentionMetadata, FlashAttentionBackend, MLABackend) -from vllm_mindspore.utils import get_valid_dtype +from vllm_mindspore.utils import get_valid_dtype, FORMAT_TYPE, is_310p from vllm.v1.kv_cache_interface import FullAttentionSpec from vllm.v1.utils import bind_kv_cache @@ -171,6 +172,8 @@ def _prepare_inputs( def create_block(shape, dtype, name=None, device=None): from mindspore import mint blocks = mint.empty(shape, dtype=dtype, device=device) + if is_310p(): + blocks = ops.auto_generate.format_cast(blocks, FORMAT_TYPE["nz"]) return blocks def initialize_kv_cache(self, kv_cache_config) -> None: diff --git a/vllm_mindspore/worker/profile.py b/vllm_mindspore/worker/profile.py index 9958ebcbedd5803504fe4c7eca8b0d09a525180f..4309112aab35ce5ce253ff471412241a8e419693 100644 --- 
diff --git a/vllm_mindspore/worker/profile.py b/vllm_mindspore/worker/profile.py
index 9958ebcbedd5803504fe4c7eca8b0d09a525180f..4309112aab35ce5ce253ff471412241a8e419693 100644
--- a/vllm_mindspore/worker/profile.py
+++ b/vllm_mindspore/worker/profile.py
@@ -21,7 +21,6 @@ class AdapterProfiler:
     def __init__(self, path):
         self.profiler = Profiler(
             profiler_level=ProfilerLevel.Level1,
-            activities=[ProfilerActivity.CPU, ProfilerActivity.NPU],
             output_path=path,
             start_profile=False
         )