From 95e3008ec3a0e92861e27e30461495d9f35abff3 Mon Sep 17 00:00:00 2001 From: HighCloud Date: Fri, 4 Jul 2025 15:04:51 +0800 Subject: [PATCH 01/12] support native qwq --- vllm_mindspore/__init__.py | 10 ++ .../distributed/communication_op.py | 10 ++ vllm_mindspore/distributed/parallel_state.py | 93 +++++++++++++++++++ .../model_executor/layers/linear.py | 1 + .../model_loader/weight_utils.py | 13 +-- .../model_executor/models/model_base.py | 52 +++++++++-- vllm_mindspore/model_executor/models/qwen2.py | 30 +++++- vllm_mindspore/utils.py | 90 +++++++++++++----- vllm_mindspore/v1/worker/gpu_model_runner.py | 18 +++- vllm_mindspore/worker/cache_engine.py | 18 +++- vllm_mindspore/worker/model_runner.py | 5 +- 11 files changed, 292 insertions(+), 48 deletions(-) create mode 100644 vllm_mindspore/distributed/parallel_state.py diff --git a/vllm_mindspore/__init__.py b/vllm_mindspore/__init__.py index 2b711153..fe916fe6 100644 --- a/vllm_mindspore/__init__.py +++ b/vllm_mindspore/__init__.py @@ -324,6 +324,16 @@ RejectionSampler._smallest_positive_value = _smallest_positive_value RejectionSampler._smallest_positive_value.__set_name__( RejectionSampler, "_smallest_positive_value") +import vllm.distributed.communication_op +import vllm.worker.worker_base +from vllm_mindspore.distributed.communication_op import cpu_broadcast_tensor_dict +vllm.distributed.communication_op.broadcast_tensor_dict = cpu_broadcast_tensor_dict +vllm.worker.worker_base.broadcast_tensor_dict = cpu_broadcast_tensor_dict + +import vllm.distributed.parallel_state +from vllm_mindspore.distributed.parallel_state import gc_broadcast_tensor_dict +vllm.distributed.parallel_state.GroupCoordinator.broadcast_tensor_dict = gc_broadcast_tensor_dict + ######### for multi-model from vllm_mindspore.inputs.registry import call_hf_processor from vllm.inputs.registry import InputProcessingContext diff --git a/vllm_mindspore/distributed/communication_op.py b/vllm_mindspore/distributed/communication_op.py index c933dc4a..475a282d 100644 --- a/vllm_mindspore/distributed/communication_op.py +++ b/vllm_mindspore/distributed/communication_op.py @@ -21,10 +21,20 @@ Implement a unified communication interface for both graph and pynative mode. """ +from typing import Any, Dict, Optional, Union +import torch + from mindspore import nn, ops from vllm.distributed.parallel_state import ( get_tensor_model_parallel_world_size, get_tp_group) +def cpu_broadcast_tensor_dict(tensor_dict: Optional[Dict[Any, Union[torch.Tensor, + Any]]] = None, + src: int = 0): + if not torch.distributed.is_initialized(): + return tensor_dict + return get_tp_group().broadcast_tensor_dict(tensor_dict, src, group=get_tp_group().cpu_group) + class ReduceFromModelParallelRegion(nn.Cell): "All reduce the input from the model parallel region." 
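The two hunks above reroute vLLM's module-level broadcast_tensor_dict (together with the GroupCoordinator method patched in the next file) onto the tensor-parallel CPU group. A minimal usage sketch of the patched entry point, assuming an initialized tensor-parallel group; the rank check and the example payload are illustrative assumptions, not taken from this patch:

    import torch
    from vllm.distributed.parallel_state import get_tensor_model_parallel_rank
    from vllm.distributed.communication_op import broadcast_tensor_dict  # patched above

    if get_tensor_model_parallel_rank() == 0:
        # Source rank: the metadata list is pickled and sent over the CPU group,
        # then each non-empty tensor is broadcast asynchronously.
        broadcast_tensor_dict({"step": 1, "mask": torch.ones(4, dtype=torch.bool)}, src=0)
    else:
        # Other ranks pass None and receive the reconstructed dict.
        received = broadcast_tensor_dict(None, src=0)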
diff --git a/vllm_mindspore/distributed/parallel_state.py b/vllm_mindspore/distributed/parallel_state.py new file mode 100644 index 00000000..697196fa --- /dev/null +++ b/vllm_mindspore/distributed/parallel_state.py @@ -0,0 +1,93 @@ +import torch +import torch.distributed +from torch.distributed import ProcessGroup + +from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union) +from vllm.distributed.parallel_state import _split_tensor_dict, TensorMetadata +from vllm_mindspore.utils import atlas_inference + +def gc_broadcast_tensor_dict( + self, + tensor_dict: Optional[Dict[str, Union[torch.Tensor, Any]]] = None, + src: int = 0, + group: Optional[ProcessGroup] = None, + metadata_group: Optional[ProcessGroup] = None + ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: + """Broadcast the input tensor dictionary. + NOTE: `src` is the local rank of the source rank. + """ + # Bypass the function if we are using only 1 GPU. + if (not torch.distributed.is_initialized() or self.world_size == 1): + return tensor_dict + + if not atlas_inference(): + group = self.device_group + metadata_group = self.cpu_group + assert src < self.world_size, f"Invalid src rank ({src})" + + rank_in_group = self.rank_in_group + if rank_in_group == src: + metadata_list: List[Tuple[Any, Any]] = [] + assert isinstance( + tensor_dict, + dict), (f"Expecting a dictionary, got {type(tensor_dict)}") + metadata_list, tensor_list = _split_tensor_dict(tensor_dict) + # `metadata_list` lives in CPU memory. + # `broadcast_object_list` has serialization & deserialization, + # all happening on CPU. Therefore, we can use the CPU group. + self.broadcast_object(metadata_list, src=src) + async_handles = [] + for tensor in tensor_list: + if tensor.numel() == 0: + # Skip broadcasting empty tensors. + continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + handle = torch.distributed.broadcast(tensor, + src=self.ranks[src], + group=metadata_group, + async_op=True) + else: + # use group for GPU tensors + handle = torch.distributed.broadcast(tensor, + src=self.ranks[src], + group=group, + async_op=True) + async_handles.append(handle) + for async_handle in async_handles: + async_handle.wait() + + else: + metadata_list = self.broadcast_object(None, src=src) + tensor_dict = {} + async_handles = [] + for key, value in metadata_list: + if isinstance(value, TensorMetadata): + tensor = torch.empty(value.size, + dtype=value.dtype, + device=value.device) + if tensor.numel() == 0: + # Skip broadcasting empty tensors. 
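+                        # The receiver keeps this locally allocated
+                        # zero-element tensor as-is; no collective is issued.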
+ tensor_dict[key] = tensor + continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + handle = torch.distributed.broadcast( + tensor, + src=self.ranks[src], + group=metadata_group, + async_op=True) + else: + # use group for GPU tensors + handle = torch.distributed.broadcast( + tensor, + src=self.ranks[src], + group=group, + async_op=True) + async_handles.append(handle) + tensor_dict[key] = tensor + else: + tensor_dict[key] = value + for async_handle in async_handles: + async_handle.wait() + return tensor_dict diff --git a/vllm_mindspore/model_executor/layers/linear.py b/vllm_mindspore/model_executor/layers/linear.py index 8c438994..d1076c68 100644 --- a/vllm_mindspore/model_executor/layers/linear.py +++ b/vllm_mindspore/model_executor/layers/linear.py @@ -608,6 +608,7 @@ class RowParallelLinear(LinearBase): def weight_loader(self, param, loaded_weight): tp_rank = get_tensor_model_parallel_rank() + param_data = param.data input_dim = getattr(param, "input_dim", None) shard_size = self.input_size_per_partition start_idx = tp_rank * shard_size diff --git a/vllm_mindspore/model_executor/model_loader/weight_utils.py b/vllm_mindspore/model_executor/model_loader/weight_utils.py index 6bf2dd4c..e02de0ab 100644 --- a/vllm_mindspore/model_executor/model_loader/weight_utils.py +++ b/vllm_mindspore/model_executor/model_loader/weight_utils.py @@ -25,6 +25,8 @@ import mindspore as ms from mindspore import Parameter from safetensors import safe_open from tqdm.auto import tqdm +from vllm_mindspore.utils import atlas_inference +import numpy as np from vllm.model_executor.model_loader.weight_utils import (_BAR_FORMAT, enable_tqdm) @@ -66,12 +68,11 @@ def safetensors_weights_iterator( ): with safe_open(st_file, framework="np") as f: for name in f.keys(): # noqa: SIM118 - # Return a lightweight PySafeSlice object that uses file - # pointer offset internally to read Safetensor on demand, - # avoiding memory explosion. 
Actual data can be obtained - # through slicing operation like param[start:end] - param = f.get_slice(name) - yield name, param + # TODO: use slice + x = f.get_tensor(name) + x = x.astype(np.float16) \ + if (str(x.dtype) == 'bfloat16' and atlas_inference()) else x + yield name, ms.tensor(x) def default_weight_loader(param: Parameter, loaded_weight: Any) -> None: diff --git a/vllm_mindspore/model_executor/models/model_base.py b/vllm_mindspore/model_executor/models/model_base.py index 58178292..4711cb99 100644 --- a/vllm_mindspore/model_executor/models/model_base.py +++ b/vllm_mindspore/model_executor/models/model_base.py @@ -22,7 +22,7 @@ from typing import Any, Optional, Union, cast import mindspore as ms import numpy as np import vllm.envs as envs -from mindspore import Tensor, mutable, nn +from mindspore import Tensor, mutable, nn, ops from mindspore.common import dtype as mstype from vllm.attention.backends.abstract import AttentionType from vllm.config import VllmConfig, get_current_vllm_config @@ -36,7 +36,7 @@ from vllm_mindspore.model_executor.models.attention_mask import ( from vllm_mindspore.model_executor.utils import set_model_context from vllm_mindspore.utils import STR_DTYPE_TO_MS_DTYPE from vllm_mindspore.v1.attention.backends.ms_attn import MsAttentionMetadata - +from vllm_mindspore.utils import atlas_inference, FORMAT_TYPE class AttentionWrapper: @@ -47,11 +47,32 @@ class AttentionWrapper: vllm_config.parallel_config) head_size = vllm_config.model_config.get_head_size() num_block = 0 - self.kv_shape = [num_block, block_size, num_kv_heads, head_size] - self.kv_cache = [( - ms.mint.zeros(self.kv_shape, dtype=vllm_config.model_config.dtype), - ms.mint.zeros(self.kv_shape, dtype=vllm_config.model_config.dtype), - ) for _ in range(vllm_config.parallel_config.pipeline_parallel_size)] + if atlas_inference(): + self.kv_shape = [num_block, block_size, num_kv_heads * head_size] + self.kv_cache = [ + ( + ops.auto_generate.format_cast( + ms.mint.zeros( + self.kv_shape, dtype=vllm_config.model_config.dtype + ), + FORMAT_TYPE['nz'], + ), + ops.auto_generate.format_cast( + ms.mint.zeros( + self.kv_shape, dtype=vllm_config.model_config.dtype + ), + FORMAT_TYPE['nz'], + ), + ) + for _ in range(vllm_config.parallel_config.pipeline_parallel_size) + ] + else: + self.kv_shape = [num_block, block_size, num_kv_heads, head_size] + self.kv_cache = [( + ms.mint.zeros(self.kv_shape, dtype=vllm_config.model_config.dtype), + ms.mint.zeros(self.kv_shape, dtype=vllm_config.model_config.dtype), + ) for _ in range(vllm_config.parallel_config.pipeline_parallel_size)] + self.attn_type = AttentionType.DECODER # add for v1 @@ -71,7 +92,19 @@ class MLAAttentionWrapper(AttentionWrapper): self.use_mla_op = bool( vllm_config.additional_config and vllm_config.additional_config.get('use_mla_op') == 1) - if not self.use_mla_op: + if atlas_inference(): + self.kv_cache = [ + ( + ops.auto_generate.format_cast( + ms.mint.zeros( + self.kv_shape, dtype=vllm_config.model_config.dtype + ), + FORMAT_TYPE['nz'], + ), + ) + for _ in range(vllm_config.parallel_config.pipeline_parallel_size) + ] + elif not self.use_mla_op: self.kv_cache = [ ( ms.mint.zeros( @@ -431,7 +464,8 @@ class NativeModel(MsModelBase): block_size = self.cache_config.block_size num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config) head_size = self.model_config.get_head_size() - kv_cache_shape = (None, block_size, num_kv_heads, head_size) + kv_cache_shape = (None, block_size, num_kv_heads * head_size) if atlas_inference() \ + else (None, 
block_size, num_kv_heads, head_size) kv_cache_dtype = (self.model_config.dtype if self.cache_config.cache_dtype == "auto" else diff --git a/vllm_mindspore/model_executor/models/qwen2.py b/vllm_mindspore/model_executor/models/qwen2.py index 9cacf08a..8be2b4d5 100644 --- a/vllm_mindspore/model_executor/models/qwen2.py +++ b/vllm_mindspore/model_executor/models/qwen2.py @@ -35,7 +35,8 @@ if TYPE_CHECKING: else: Qwen2Config = None -from mindspore import Parameter, Tensor, mint, nn +from mindspore import Parameter, Tensor, mint, nn, ops +import mindspore as ms from vllm.attention.backends.abstract import AttentionType from vllm.config import CacheConfig, VllmConfig @@ -46,6 +47,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.sequence import IntermediateTensors from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm_mindspore.utils import atlas_inference, FORMAT_TYPE from vllm_mindspore.attention import Attention from vllm_mindspore.model_executor.layers.activation import SwiGLU from vllm_mindspore.model_executor.layers.layernorm import RMSNorm @@ -409,9 +411,35 @@ class Qwen2Model(nn.Cell): param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) + # Norm type in weights may be f32 + if(loaded_weight.dtype != param.dtype): + loaded_weight = loaded_weight.to(dtype=param.dtype) weight_loader(param, loaded_weight) loaded_params.add(name) + def adjust_weight(params_dict): + if not atlas_inference(): + return + + target_keywords = [ + "qkv_proj.weight", + "o_proj.weight", + "gate_up_proj.weight", + "down_proj.weight", + # "lm_head.weight", + ] + + for name, param in params_dict.items(): + if any(name.endswith(keyword) for keyword in target_keywords): + cast_weight = ops.auto_generate.format_cast(param, FORMAT_TYPE['nz']) + ms.runtime.synchronize() + param.set_data(cast_weight) + + if atlas_inference(): + ms.runtime.synchronize() + adjust_weight(params_dict) + ms.runtime.synchronize() + return loaded_params diff --git a/vllm_mindspore/utils.py b/vllm_mindspore/utils.py index 179dc93f..8bb43caa 100644 --- a/vllm_mindspore/utils.py +++ b/vllm_mindspore/utils.py @@ -56,6 +56,9 @@ STR_DTYPE_TO_MS_DTYPE = { "fp8_e5m2": ms.uint8, } +FORMAT_TYPE = { + "nz": 29, +} def get_valid_dtype(dtype): if isinstance(dtype, str): @@ -237,30 +240,6 @@ def is_310p(): return device in ['310p', 'ascend310p'] -def check_ready(): - from mindspore import set_context - - # Common environment variables of predict. - set_context(jit_config={"jit_level": "O0", "infer_boost": "on"}) - default_env = { - "MS_INTERNAL_DISABLE_CUSTOM_KERNEL_LIST": - "FlashAttentionScore,PagedAttention", - } - env_setup(default_env) - - if os.getenv("MS_MEMPOOL_BLOCK_SIZE"): - set_context( - mempool_block_size=f"{os.environ['MS_MEMPOOL_BLOCK_SIZE']}GB") - - if is_mindformers_model_backend(): - logger.info("Run with Mindformers backend!") - elif is_mindone_model_backend(): - logger.info("Run with MindONE backend!") - else: - logger.info("Run with native model backend!") - register_connector() - - def convert_np_to_ms_dtype(value): """convert_np_to_ms_dtype""" if value.dtype == np.int8: @@ -368,3 +347,66 @@ def ms_memory_profiling( result.non_torch_increase = diff_from_create.non_torch_memory result.profile_time = diff_profile.timestamp result.non_kv_cache_memory = result.non_torch_increase + result.torch_peak_increase + result.weights_memory # noqa + + +def is_version_ge(current_version, base_version): + """ + return current_version >= base_version. 
+ Check whether the current version is higher than or equal to the base version. + for current_version: 1.8.1, base_version: 1.11.0, it return False. + """ + version_split_char = '.' + if version_split_char not in base_version or version_split_char not in current_version: + raise ValueError("The version string will contain the `.`." + "For example, current_version 1.8.1, base_version: 1.11.0.") + for x, y in zip(current_version.split(version_split_char), base_version.split(version_split_char)): + if not x.isdigit() or not y.isdigit(): + continue + if int(x) != int(y): + return int(x) >= int(y) + return True + +def get_ascend_soc_version(): + """Get ascend soc version.""" + if is_version_ge(ms.__version__, "2.2.0"): + from mindspore._c_expression import MSContext + return MSContext.get_instance().get_ascend_soc_version() + ascend_chip_type = os.getenv("ASCEND_CHIP_TYPE", "UNSET") + if ascend_chip_type not in ["910a", "910b", "UNSET"]: + raise EnvironmentError(f"ASCEND_CHIP_TYPE should be in ['910a', '910b'],but get {ascend_chip_type}") + if ascend_chip_type == "UNSET": + logger.info("Environment variables need to be set manually to obtain the chip type," + "which can be set as follows: \n" + "For Atlas 800, run 'export ASCEND_CHIP_TYPE=910a' before the program runs.\n" + "For Atlas 800T A2, run 'export ASCEND_CHIP_TYPE=910b' before the program runs.\n" + "If you need to get chip information automatically, MindSpore 2.2 and above is recommended") + return ascend_chip_type + +def atlas_inference(): + device = get_ascend_soc_version() + return device in ['310p', 'ascend310p'] + +def check_ready(): + from mindspore import set_context + + # Common environment variables of predict. + set_context(jit_config={"jit_level": "O0", "infer_boost": "on"}) + default_env = { + "MS_INTERNAL_DISABLE_CUSTOM_KERNEL_LIST": + "FlashAttentionScore,PagedAttention", + } + if atlas_inference(): + default_env["MS_ENABLE_INTERNAL_BOOST"] = "off" + env_setup(default_env) + + if os.getenv("MS_MEMPOOL_BLOCK_SIZE"): + set_context( + mempool_block_size=f"{os.environ['MS_MEMPOOL_BLOCK_SIZE']}GB") + + if is_mindformers_model_backend(): + logger.info("Run with Mindformers backend!") + elif is_mindone_model_backend(): + logger.info("Run with MindONE backend!") + else: + logger.info("Run with native model backend!") + register_connector() diff --git a/vllm_mindspore/v1/worker/gpu_model_runner.py b/vllm_mindspore/v1/worker/gpu_model_runner.py index 15e105c5..e7f06d4f 100644 --- a/vllm_mindspore/v1/worker/gpu_model_runner.py +++ b/vllm_mindspore/v1/worker/gpu_model_runner.py @@ -24,7 +24,7 @@ from typing import Any, Optional import mindspore as ms import numpy as np from mindspore import Generator as msGenerator -from mindspore import Tensor, mint, mutable +from mindspore import Tensor, mint, mutable, ops from vllm.attention import AttentionType from vllm.logger import init_logger from vllm.sampling_params import SamplingType @@ -36,7 +36,7 @@ from vllm.v1.worker.gpu_input_batch import CachedRequestState from vllm_mindspore.model_executor.layers.rotary_embedding import ( InferMRotaryEmbedding as MRotaryEmbedding) -from vllm_mindspore.utils import get_dtype_size, get_valid_dtype +from vllm_mindspore.utils import get_dtype_size, get_valid_dtype, atlas_inference, FORMAT_TYPE logger = init_logger(__name__) @@ -312,6 +312,9 @@ def _reshape_kv_cache_tensors( kv_cache_shape = self.attn_backends[i].get_kv_cache_shape( num_blocks, kv_cache_spec.block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) + if atlas_inference(): + 
*dims, second_last, last = kv_cache_shape + kv_cache_shape = (*dims, second_last * last) try: kv_cache_stride_order = self.attn_backends[ i].get_kv_cache_stride_order() @@ -344,7 +347,16 @@ def _reshape_kv_cache_tensors( else: cache_block = kv_cache_raw_tensor.view( kv_cache_shape[1:]).permute(*inv_order[1:]) - kv_cache_layer.append(cache_block) + if atlas_inference(): + from mindspore.common.api import _pynative_executor + cache_block_nz = ops.auto_generate.format_cast(cache_block, FORMAT_TYPE['nz']) + _pynative_executor.sync() + import gc + del cache_block + gc.collect() + kv_cache_layer.append(cache_block_nz) + else: + kv_cache_layer.append(cache_block) kv_caches[layer_name] = mutable(tuple(kv_cache_layer)) else: raise NotImplementedError diff --git a/vllm_mindspore/worker/cache_engine.py b/vllm_mindspore/worker/cache_engine.py index b57b8833..7675379c 100644 --- a/vllm_mindspore/worker/cache_engine.py +++ b/vllm_mindspore/worker/cache_engine.py @@ -22,17 +22,26 @@ # isort:skip_file import mindspore as ms -from mindspore import mutable, mint +from mindspore import mutable, mint, ops from typing import List from vllm.logger import init_logger -from vllm_mindspore.utils import MsKVCache, get_valid_dtype +from vllm_mindspore.utils import MsKVCache, get_valid_dtype, atlas_inference, FORMAT_TYPE logger = init_logger(__name__) def create_block(shape, dtype, name=None, device=None): - blocks = mint.empty(shape, dtype=dtype, device=device) + from mindspore.common.api import _pynative_executor + blocks = mint.empty(*shape, dtype=dtype, device=device) + if device == "Ascend" and atlas_inference(): + blocks_nz = ops.auto_generate.format_cast(blocks, FORMAT_TYPE['nz']) + _pynative_executor.sync() + import gc + del blocks + gc.collect() + ms.hal.empty_cache() + return blocks_nz return blocks @@ -44,6 +53,9 @@ def ms_allocate_kv_cache( """Allocates KV cache on the specified device.""" kv_cache_shape = self.attn_backend.get_kv_cache_shape( num_blocks, self.block_size, self.num_kv_heads, self.head_size) + if atlas_inference(): + *dims, second_last, last = kv_cache_shape + kv_cache_shape = (*dims, second_last * last) kv_cache: List[MsKVCache] = [] self.dtype = get_valid_dtype(self.dtype) diff --git a/vllm_mindspore/worker/model_runner.py b/vllm_mindspore/worker/model_runner.py index 7fd89fc5..6ab97c1b 100644 --- a/vllm_mindspore/worker/model_runner.py +++ b/vllm_mindspore/worker/model_runner.py @@ -28,7 +28,7 @@ from vllm.lora.request import LoRARequest from vllm.sampling_params import SamplingParams from vllm.sequence import SequenceGroupMetadata -from vllm_mindspore.utils import STR_DTYPE_TO_TENSOR_DTYPE +from vllm_mindspore.utils import STR_DTYPE_TO_TENSOR_DTYPE, atlas_inference logger = init_logger(__name__) @@ -140,7 +140,8 @@ def _dummy_run(self, block_size = self.cache_config.block_size num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config) head_size = self.model_config.get_head_size() - kv_shape = [0, block_size, num_kv_heads, head_size] + kv_shape = [0, block_size, num_kv_heads * head_size] if atlas_inference() else \ + [0, block_size, num_kv_heads, head_size] kv_caches = mutable([ mutable( ( -- Gitee From 58473b7dc1f8994e025a50c9b5c0e1c1944430f4 Mon Sep 17 00:00:00 2001 From: one_east Date: Thu, 24 Jul 2025 20:31:06 +0800 Subject: [PATCH 02/12] CPU bind for 910B and 910C --- vllm_mindspore/worker/worker.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm_mindspore/worker/worker.py b/vllm_mindspore/worker/worker.py index b1acf865..06e6d731 100644 --- 
a/vllm_mindspore/worker/worker.py +++ b/vllm_mindspore/worker/worker.py @@ -16,6 +16,7 @@ """Adapted functions for mindspore in Worker.""" import math +import subprocess import os import subprocess -- Gitee From 1bd9ce6ed644424515172d42479342d94cb5f15b Mon Sep 17 00:00:00 2001 From: HighCloud Date: Wed, 30 Jul 2025 15:24:10 +0800 Subject: [PATCH 03/12] cpu bind support 310p --- vllm_mindspore/worker/worker.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/vllm_mindspore/worker/worker.py b/vllm_mindspore/worker/worker.py index 06e6d731..ad6a0ba3 100644 --- a/vllm_mindspore/worker/worker.py +++ b/vllm_mindspore/worker/worker.py @@ -76,14 +76,17 @@ def get_numa_map(): "topo"]).strip().split("\n") numa_to_npu_map = {} max_affinity_cpu = 0 - if "Affinity" not in numa_topo_info[0]: + if "Affinity" not in numa_topo_info[0] or is_310p(): # If the device does not provide affinity, # the CPUs will be evenly distributed. cpu_num_per_npu = total_cpu_count // (npu_count * chip_count) for i in range(npu_count * chip_count): cpu_start = i * cpu_num_per_npu - # 4 CPUs are reserved for CANN - npu_to_core_map[i] = [cpu_start, cpu_start + cpu_num_per_npu - 4] + # 4 CPUs are reserved for CANN(not for 310p) + npu_to_core_map[i] = [ + cpu_start, + cpu_start + cpu_num_per_npu - (0 if is_310p() else 4) + ] return npu_to_core_map else: npu_num = 0 @@ -154,13 +157,12 @@ def wrapper_worker_bind_cpu(fun): def new_fun(*arg, **kwargs): # Bind CPU with wrapper when workers are initializing. - # Support 910B and 910C. - if not is_310p(): - local_rank = kwargs.get("local_rank") - parallel_config = kwargs.get("vllm_config").parallel_config - local_rank = (parallel_config.data_parallel_rank_local * - parallel_config.world_size + local_rank) - bind_cpu(local_rank) + # Support 910B, 910C and 310P. 
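+        # Illustrative numbers: with data_parallel_rank_local=1, world_size=8
+        # and local_rank=2, the value computed below is 1 * 8 + 2 = 10, i.e.
+        # the global device index passed to bind_cpu().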
+ local_rank = kwargs.get("local_rank") + parallel_config = kwargs.get("vllm_config").parallel_config + local_rank = (parallel_config.data_parallel_rank_local * + parallel_config.world_size + local_rank) + bind_cpu(local_rank) fun(*arg, **kwargs) return new_fun -- Gitee From 7f735b73a5df4adb26e4d5eb37d349180973d6a6 Mon Sep 17 00:00:00 2001 From: superxf Date: Wed, 23 Jul 2025 15:28:07 +0800 Subject: [PATCH 04/12] support qwq --- vllm_mindspore/__init__.py | 40 ++++++ vllm_mindspore/config.py | 5 + vllm_mindspore/engine/arg_utils.py | 123 +++++++++++++++++- .../model_executor/layers/linear.py | 43 ++++-- .../layers/quantization/__init__.py | 49 +++++++ .../layers/quantization/base_config.py | 3 + .../quantization/smooth_quant_modelslim.py | 38 +++--- .../model_loader/default_loader.py | 99 ++++++++++++++ .../model_executor/model_loader/utils.py | 60 +++++++++ .../model_loader/weight_utils.py | 109 +++++++++++++++- vllm_mindspore/model_executor/models/qwen2.py | 5 + 11 files changed, 533 insertions(+), 41 deletions(-) create mode 100644 vllm_mindspore/model_executor/model_loader/default_loader.py diff --git a/vllm_mindspore/__init__.py b/vllm_mindspore/__init__.py index fe916fe6..157fad33 100644 --- a/vllm_mindspore/__init__.py +++ b/vllm_mindspore/__init__.py @@ -533,6 +533,46 @@ sys.modules["vllm.entrypoints.openai.tool_parsers.deepseekv3_tool_parser"] = ( from vllm_mindspore.entrypoints.__main__ import ( patch_server_run_api_server_worker_proc, ) +from vllm_mindspore.model_executor.model_loader.utils import ( + process_weights_after_loading) + +vllm.model_executor.model_loader.utils.process_weights_after_loading = ( + process_weights_after_loading) +vllm.model_executor.model_loader.base_loader.process_weights_after_loading = ( + process_weights_after_loading) + +from vllm_mindspore.model_executor.layers.quantization import ( + get_quantization_config) + +vllm.model_executor.layers.quantization.get_quantization_config = ( + get_quantization_config) +vllm.config.get_quantization_config = get_quantization_config +vllm.model_executor.model_loader.weight_utils.get_quantization_config = ( + get_quantization_config) + +from vllm_mindspore.model_executor.model_loader.weight_utils import ( + get_quant_config) + +vllm.model_executor.model_loader.weight_utils.get_quant_config = ( + get_quant_config) +vllm.config.get_quant_config = get_quant_config + +from vllm_mindspore.model_executor.layers.quantization import ( + QuantizationMethods) + +vllm.model_executor.layers.quantization.QuantizationMethods = ( + QuantizationMethods) + +from vllm_mindspore.engine.arg_utils import get_kwargs + +vllm.engine.arg_utils.get_kwargs = get_kwargs + +from vllm_mindspore.model_executor.model_loader.default_loader import ( + _prepare_weights) +from vllm.model_executor.model_loader.default_loader import DefaultModelLoader + +DefaultModelLoader._prepare_weights = _prepare_weights + patch_server_run_api_server_worker_proc() from vllm_mindspore.model_executor.models.registry import _normalize_archs diff --git a/vllm_mindspore/config.py b/vllm_mindspore/config.py index c464227b..f0376485 100644 --- a/vllm_mindspore/config.py +++ b/vllm_mindspore/config.py @@ -35,6 +35,8 @@ from vllm.config import (_STR_DTYPE_TO_TORCH_DTYPE, CompilationConfig, from vllm.logger import init_logger from vllm.utils import random_uuid +from vllm_mindspore.utils import is_310p + logger = init_logger(__name__) @@ -254,6 +256,9 @@ def _get_and_verify_dtype( if torch_dtype in _STR_DTYPE_TO_TORCH_DTYPE: torch_dtype = 
_STR_DTYPE_TO_TORCH_DTYPE[torch_dtype] + if is_310p() and torch_dtype == torch.bfloat16: + return torch.float16 + return torch_dtype diff --git a/vllm_mindspore/engine/arg_utils.py b/vllm_mindspore/engine/arg_utils.py index a7bfc220..a497ad5d 100644 --- a/vllm_mindspore/engine/arg_utils.py +++ b/vllm_mindspore/engine/arg_utils.py @@ -19,15 +19,128 @@ # limitations under the License. """Adaption for arguments utils.""" +import argparse +import json import threading -from typing import get_args +from dataclasses import MISSING, fields, is_dataclass +from typing import Any, Literal, get_origin import torch import vllm.envs as envs -from vllm.config import (GuidedDecodingBackendV1, LoadFormat, ModelConfig, - ParallelConfig, SchedulerConfig) -from vllm.engine.arg_utils import (EngineArgs, _raise_or_fallback, - _warn_or_fallback) +from pydantic import TypeAdapter, ValidationError +from vllm.config import (ConfigType, GuidedDecodingBackendV1, LoadFormat, + ModelConfig, ParallelConfig, SchedulerConfig) +from vllm.engine.arg_utils import (EngineArgs, TypeHint, _raise_or_fallback, + _warn_or_fallback, contains_type, get_args, + get_attr_docs, get_type, get_type_hints, + human_readable_int, is_not_builtin, + literal_to_kwargs, optional_type, + parse_type, union_dict_and_str) + +from vllm_mindspore.model_executor.layers.quantization import ( + QUANTIZATION_METHODS) + + +def get_kwargs(cls: ConfigType) -> dict[str, Any]: + cls_docs = get_attr_docs(cls) + kwargs = {} + for field in fields(cls): + type_hints: set[TypeHint] = get_type_hints(field.type) + + # If the field is a dataclass, we can use the model_validate_json + generator = (th for th in type_hints if is_dataclass(th)) + dataclass_cls = next(generator, None) + + # Get the default value of the field + if field.default is not MISSING: + default = field.default + elif field.default_factory is not MISSING: + default = field.default_factory() + + # Get the help text for the field + name = field.name + help = cls_docs[name].strip() + # Escape % for argparse + help = help.replace("%", "%%") + + # Initialise the kwargs dictionary for the field + kwargs[name] = {"default": default, "help": help} + + # Set other kwargs based on the type hints + json_tip = """\n\nShould either be a valid JSON string or JSON keys + passed individually. For example, the following sets of arguments are + equivalent:\n\n + - `--json-arg '{"key1": "value1", "key2": {"key3": "value2"}}'`\n + - `--json-arg.key1 value1 --json-arg.key2.key3 value2`\n\n""" + if dataclass_cls is not None: + + def parse_dataclass(val: str, cls=dataclass_cls) -> Any: + try: + if hasattr(cls, "from_cli"): + return cls.from_cli(val) + return TypeAdapter(cls).validate_json(val) + except ValidationError as e: + raise argparse.ArgumentTypeError(repr(e)) from e + + kwargs[name]["type"] = parse_dataclass + kwargs[name]["help"] += json_tip + elif contains_type(type_hints, bool): + # Creates --no- and -- flags + kwargs[name]["action"] = argparse.BooleanOptionalAction + elif contains_type(type_hints, Literal): + kwargs[name].update(literal_to_kwargs(type_hints)) + elif contains_type(type_hints, tuple): + type_hint = get_type(type_hints, tuple) + types = get_args(type_hint) + tuple_type = types[0] + assert all(t is tuple_type for t in types if t is not Ellipsis), ( + "All non-Ellipsis tuple elements must be of the same " + f"type. 
Got {types}.") + kwargs[name]["type"] = tuple_type + kwargs[name]["nargs"] = "+" if Ellipsis in types else len(types) + elif contains_type(type_hints, list): + type_hint = get_type(type_hints, list) + types = get_args(type_hint) + assert len(types) == 1, ( + "List type must have exactly one type. Got " + f"{type_hint} with types {types}") + kwargs[name]["type"] = types[0] + kwargs[name]["nargs"] = "+" + elif contains_type(type_hints, int): + kwargs[name]["type"] = int + # Special case for large integers + if name in {"max_model_len", "max_num_batched_tokens"}: + kwargs[name]["type"] = human_readable_int + elif contains_type(type_hints, float): + kwargs[name]["type"] = float + elif (contains_type(type_hints, dict) + and (contains_type(type_hints, str) + or any(is_not_builtin(th) for th in type_hints))): + kwargs[name]["type"] = union_dict_and_str + elif contains_type(type_hints, dict): + kwargs[name]["type"] = parse_type(json.loads) + kwargs[name]["help"] += json_tip + elif (contains_type(type_hints, str) + or any(is_not_builtin(th) for th in type_hints)): + kwargs[name]["type"] = str + else: + raise ValueError( + f"Unsupported type {type_hints} for argument {name}.") + + # If the type hint was a sequence of literals, use the helper function + # to update the type and choices + if get_origin(kwargs[name].get("type")) is Literal: + kwargs[name].update(literal_to_kwargs({kwargs[name]["type"]})) + + # If None is in type_hints, make the argument optional. + # But not if it's a bool, argparse will handle this better. + if type(None) in type_hints and not contains_type(type_hints, bool): + kwargs[name]["type"] = optional_type(kwargs[name]["type"]) + if kwargs[name].get("choices"): + kwargs[name]["choices"].append("None") + if field.name == "quantization": + kwargs[name]["choices"] = QUANTIZATION_METHODS + return kwargs def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: diff --git a/vllm_mindspore/model_executor/layers/linear.py b/vllm_mindspore/model_executor/layers/linear.py index d1076c68..79dbd0f1 100644 --- a/vllm_mindspore/model_executor/layers/linear.py +++ b/vllm_mindspore/model_executor/layers/linear.py @@ -18,7 +18,6 @@ # See the License for the specific language governing permissions and # limitations under the License. """ Linear methods for quantized linear layers. 
""" - from abc import abstractmethod from typing import Optional, Union @@ -344,14 +343,23 @@ class MergedColumnParallelLinear(ColumnParallelLinear): assert loaded_shard_id < len(self.output_sizes) shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size shard_size = self.output_sizes[loaded_shard_id] // tp_size - - start_idx = tp_rank * shard_size - loaded_weight = split_loaded_weight(loaded_weight, output_dim, - start_idx, shard_size) - - assert loaded_weight.shape == (shard_size, param.shape[1]) - param[shard_offset:shard_offset + - shard_size, :] = ms.from_numpy(loaded_weight) + param_data = param.data + param_data = param_data.narrow(output_dim, shard_offset, + shard_size) + start_idx = tp_rank * shard_size + loaded_weight = loaded_weight.narrow(output_dim, start_idx, + shard_size).contiguous() + assert param_data.shape == loaded_weight.shape + if len(loaded_weight.shape) == 2: + param[shard_offset:shard_offset + + shard_size, :] = loaded_weight + else: + param[shard_offset:shard_offset + shard_size] = loaded_weight + else: + assert param.shape == loaded_weight.shape + if loaded_weight.dtype == ms.float32 and param.dtype == ms.float16: + loaded_weight = loaded_weight.astype(ms.float16) + param.set_data(loaded_weight.contiguous()) class QKVParallelLinear(ColumnParallelLinear): @@ -434,6 +442,11 @@ class QKVParallelLinear(ColumnParallelLinear): loaded_weight, loaded_shard_id: Optional[str] = None): output_dim = getattr(param, "output_dim", None) + if output_dim is None: + if loaded_weight.dtype == ms.float32 and param.dtype == ms.float16: + loaded_weight = loaded_weight.astype(ms.float16) + param.set_data(loaded_weight.contiguous()) + return tp_rank = get_tensor_model_parallel_rank() # QKV loaded weight is already fused on disk (qkv safetensors). @@ -483,11 +496,13 @@ class QKVParallelLinear(ColumnParallelLinear): start_idx, shard_size) loaded_weight = ms.from_numpy(loaded_weight) - if param.name.endswith("weight"): - assert loaded_weight.shape == (shard_size, param.shape[1]) - if param.name.endswith("bias"): - assert loaded_weight.shape == (shard_size, ) - param[shard_offset:shard_offset + shard_size] = loaded_weight + loaded_weight = loaded_weight.narrow(output_dim, start_idx, + shard_size).contiguous() + assert param_data.shape == loaded_weight.shape + if len(loaded_weight.shape) == 2: + param[shard_offset:shard_offset + shard_size, :] = loaded_weight + else: + param[shard_offset:shard_offset + shard_size] = loaded_weight class RowParallelLinear(LinearBase): diff --git a/vllm_mindspore/model_executor/layers/quantization/__init__.py b/vllm_mindspore/model_executor/layers/quantization/__init__.py index e69de29b..6c9e2e41 100644 --- a/vllm_mindspore/model_executor/layers/quantization/__init__.py +++ b/vllm_mindspore/model_executor/layers/quantization/__init__.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 +# Copyright 2025 Huawei Technologies Co., Ltd +# Copyright 2024 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +from typing import Literal, get_args + +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) + +QuantizationMethods = Literal["smoothquant"] +QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods)) + +# The customized quantization methods which will be added to this dict. +_CUSTOMIZED_METHOD_TO_QUANT_CONFIG = {} + + +def get_quantization_config(quantization: str) -> type[QuantizationConfig]: + if quantization not in QUANTIZATION_METHODS: + raise ValueError(f"Invalid quantization method: {quantization}") + + # lazy import to avoid triggering `torch.compile` too early + from .smooth_quant_modelslim import SmoothQuantModelSlimConfig + method_to_config: dict[str, type[QuantizationConfig]] = { + "smoothquant": SmoothQuantModelSlimConfig + } + # Update the `method_to_config` with customized quantization methods. + method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG) + + return method_to_config[quantization] + + +__all__ = [ + "QuantizationConfig", "get_quantization_config", "QUANTIZATION_METHODS", + "QuantizationMethods" +] diff --git a/vllm_mindspore/model_executor/layers/quantization/base_config.py b/vllm_mindspore/model_executor/layers/quantization/base_config.py index 37144a43..5728702d 100644 --- a/vllm_mindspore/model_executor/layers/quantization/base_config.py +++ b/vllm_mindspore/model_executor/layers/quantization/base_config.py @@ -142,6 +142,9 @@ class QuantizationConfig(ABC): """ raise NotImplementedError + def get_cache_scale(self, name: str) -> Optional[str]: + return None + def method_has_implemented_embedding( method_class: type[QuantizeMethodBase]) -> bool: diff --git a/vllm_mindspore/model_executor/layers/quantization/smooth_quant_modelslim.py b/vllm_mindspore/model_executor/layers/quantization/smooth_quant_modelslim.py index 172401e4..94df6fd8 100644 --- a/vllm_mindspore/model_executor/layers/quantization/smooth_quant_modelslim.py +++ b/vllm_mindspore/model_executor/layers/quantization/smooth_quant_modelslim.py @@ -14,11 +14,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import re from typing import Any, Optional import mindspore import numpy as np +import regex as re from mindspore import Parameter, Tensor, ops from mindspore.common.initializer import initializer from mindspore.ops.auto_generate import (DynamicQuantExt, GroupedMatmul, @@ -107,7 +107,7 @@ class SmoothQuantModelSlimConfig(QuantizationConfig): return BaseKVCacheMethod(self) if isinstance(layer, LinearBase): - if quant_config and quant_config.lower() == 'w8a8': + if quant_config and quant_config.lower() == 'w8a8s': return A8W8LinearMethod(self) if quant_config and quant_config.lower() == 'w8a8_dyn': self.dynamic_quant = True @@ -225,12 +225,12 @@ class A8W8LinearMethod(LinearMethodBase): self.params_dtype), name="input_offset") if self.is_310p: - quant_bias_ = Parameter(initializer( + quant_bias = Parameter(initializer( 'zeros', (self.output_size_per_partition // self.quant_config.pack_factor, ), mindspore.int32), - name="quant_bias_") + name="quant_bias") else: - quant_bias_ = None + quant_bias = None set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) set_weight_attrs(weight_scale, {"output_dim": 0}) @@ -241,11 +241,12 @@ class A8W8LinearMethod(LinearMethodBase): set_weight_attrs(deq_scale, extra_weight_attrs) set_weight_attrs(input_scale, extra_weight_attrs) set_weight_attrs(input_offset, extra_weight_attrs) - if quant_bias_ is not None: - set_weight_attrs(quant_bias_, extra_weight_attrs) - layer.insert_param_to_cell("quant_bias_", quant_bias_) + if quant_bias is not None: + set_weight_attrs(quant_bias, extra_weight_attrs) + set_weight_attrs(quant_bias, {"output_dim": 0}) + layer.insert_param_to_cell("quant_bias", quant_bias) else: - layer.quant_bias_ = None + layer.quant_bias = None layer.insert_param_to_cell("weight", weight) layer.insert_param_to_cell("weight_scale", weight_scale) @@ -286,7 +287,7 @@ class A8W8LinearMethod(LinearMethodBase): input_offset = Parameter(initializer('zeros', input_scale_shape, self.params_dtype), name="input_offset") - quant_bias_ = None + quant_bias = None set_weight_attrs(weight, { "ep_dim": 0, "input_dim": 1, @@ -301,11 +302,11 @@ class A8W8LinearMethod(LinearMethodBase): set_weight_attrs(deq_scale, extra_weight_attrs) set_weight_attrs(input_scale, extra_weight_attrs) set_weight_attrs(input_offset, extra_weight_attrs) - if quant_bias_ is not None: - set_weight_attrs(quant_bias_, extra_weight_attrs) - layer.insert_param_to_cell("quant_bias_", quant_bias_) + if quant_bias is not None: + set_weight_attrs(quant_bias, extra_weight_attrs) + layer.insert_param_to_cell("quant_bias", quant_bias) else: - layer.quant_bias_ = None + layer.quant_bias = None layer.insert_param_to_cell("weight", weight) layer.insert_param_to_cell("weight_scale", weight_scale) @@ -314,8 +315,7 @@ class A8W8LinearMethod(LinearMethodBase): layer.insert_param_to_cell("input_offset", input_offset) def process_weights_after_loading(self, layer: mindspore.nn.Cell) -> None: - input_offset = np.array([0]) - params_dtype = layer.params_dtype + input_offset = layer.input_offset.asnumpy() layer.input_offset = Parameter(Tensor(input_offset, dtype=mindspore.int8), name=layer.input_offset.name) @@ -336,7 +336,7 @@ class A8W8LinearMethod(LinearMethodBase): layer.weight_scale = Parameter(Tensor( weight_scale, dtype=layer.weight_scale.dtype), name=layer.weight_scale.name) - if not self.is_310p and params_dtype is mindspore.bfloat16: + if not self.is_310p and self.params_dtype is mindspore.bfloat16: deq_scale = layer.deq_scale.asnumpy().astype(np.int32).view( np.float32) layer.deq_scale = 
Parameter(Tensor(deq_scale, @@ -374,10 +374,8 @@ class A8W8LinearMethod(LinearMethodBase): group_type=0, group_list_type=0 if cumsum_flag else 1)[0] else: - qx = self.matmul(qx, weight, deq_scale, None, layer.quant_bias_, + qx = self.matmul(qx, weight, deq_scale, None, layer.quant_bias, None) - if bias is not None: - qx = self.bias_add(qx, bias) qx = qx.reshape(output_shape) return qx diff --git a/vllm_mindspore/model_executor/model_loader/default_loader.py b/vllm_mindspore/model_executor/model_loader/default_loader.py new file mode 100644 index 00000000..dbd6ea8b --- /dev/null +++ b/vllm_mindspore/model_executor/model_loader/default_loader.py @@ -0,0 +1,99 @@ +# SPDX-License-Identifier: Apache-2.0 +import glob +import os +from typing import Optional + +from transformers.utils import SAFE_WEIGHTS_INDEX_NAME +from vllm.config import LoadFormat +from vllm.distributed import get_tensor_model_parallel_rank +from vllm.model_executor.model_loader.weight_utils import ( + download_safetensors_index_file_from_hf, download_weights_from_hf, + filter_duplicate_safetensors_files, filter_files_not_needed_for_inference) + + +def _prepare_weights( + self, + model_name_or_path: str, + revision: Optional[str], + fall_back_to_pt: bool, + allow_patterns_overrides: Optional[list[str]], +) -> tuple[str, list[str], bool]: + """Prepare weights for the model. + + If the model is not local, it will be downloaded.""" + model_name_or_path = (self._maybe_download_from_modelscope( + model_name_or_path, revision) or model_name_or_path) + + is_local = os.path.isdir(model_name_or_path) + load_format = self.load_config.load_format + use_safetensors = False + index_file = SAFE_WEIGHTS_INDEX_NAME + # Some quantized models use .pt files for storing the weights. + if load_format == LoadFormat.AUTO: + allow_patterns = ["*.safetensors", "*.bin"] + elif (load_format == LoadFormat.SAFETENSORS + or load_format == LoadFormat.FASTSAFETENSORS): + use_safetensors = True + allow_patterns = ["*.safetensors"] + elif load_format == LoadFormat.MISTRAL: + use_safetensors = True + allow_patterns = ["consolidated*.safetensors"] + index_file = "consolidated.safetensors.index.json" + elif load_format == LoadFormat.PT: + allow_patterns = ["*.pt"] + elif load_format == LoadFormat.NPCACHE: + allow_patterns = ["*.bin"] + else: + raise ValueError(f"Unknown load_format: {load_format}") + + if fall_back_to_pt: + allow_patterns += ["*.pt"] + + if allow_patterns_overrides is not None: + allow_patterns = allow_patterns_overrides + + if not is_local: + hf_folder = download_weights_from_hf( + model_name_or_path, + self.load_config.download_dir, + allow_patterns, + revision, + ignore_patterns=self.load_config.ignore_patterns, + ) + else: + hf_folder = model_name_or_path + hf_weights_files: list[str] = [] + for pattern in allow_patterns: + hf_weights_files += glob.glob(os.path.join(hf_folder, pattern)) + if len(hf_weights_files) == 0: + tp_rank = get_tensor_model_parallel_rank() + hf_weights_files += glob.glob( + os.path.join(hf_folder, f"rank_{tp_rank}", pattern)) + if len(hf_weights_files) > 0: + if pattern == "*.safetensors": + use_safetensors = True + break + if use_safetensors: + # For models like Mistral-7B-Instruct-v0.3 + # there are both sharded safetensors files and a consolidated + # safetensors file. Using both breaks. + # Here, we download the `model.safetensors.index.json` and filter + # any files not found in the index. 
+ if not is_local: + download_safetensors_index_file_from_hf( + model_name_or_path, + index_file, + self.load_config.download_dir, + revision, + ) + hf_weights_files = filter_duplicate_safetensors_files( + hf_weights_files, hf_folder, index_file) + else: + hf_weights_files = filter_files_not_needed_for_inference( + hf_weights_files) + + if len(hf_weights_files) == 0: + raise RuntimeError( + f"Cannot find any model weights with `{model_name_or_path}`") + + return hf_folder, hf_weights_files, use_safetensors diff --git a/vllm_mindspore/model_executor/model_loader/utils.py b/vllm_mindspore/model_executor/model_loader/utils.py index 98d97eed..c3bf3dd1 100644 --- a/vllm_mindspore/model_executor/model_loader/utils.py +++ b/vllm_mindspore/model_executor/model_loader/utils.py @@ -18,10 +18,19 @@ # See the License for the specific language governing permissions and # limitations under the License. """ utils for load model """ +<<<<<<< HEAD from mindspore import nn +======= +import numpy as np +import torch +from torch import nn +from vllm.attention import Attention +>>>>>>> 0cd4cd6 (support qwq) from vllm.config import ModelConfig from vllm.model_executor.models import ModelRegistry +from vllm_mindspore.model_executor.layers.quantization.base_config import ( + QuantizeMethodBase) from vllm_mindspore.model_executor.models.registry import ( MindSporeModelRegistry, is_mf_mcore_archs) @@ -48,3 +57,54 @@ def get_ms_model_architecture( raise RecursionError("MindSpore unsupported reward model task now!") return model_cls, arch + + +def convert_uint64_to_fp32(arr: np.ndarray): + arr_fp32 = arr.view(np.float32) + output = arr_fp32[:, :, 0::2] + return output + + +def np_int4data_pack_to_int8_3d(np_data): + np_data = np_data.astype(np.int8) + np_data &= 0x000F + np_data[::, ::, 0::2] <<= 0 + np_data[::, ::, 1::2] <<= 4 + np_int4_data = np_data[::, ::, 0::2] | np_data[::, ::, 1::2] + return np_int4_data + + +def unpack_int8_to_int4_3d(packed_data): + low_nibbles = (packed_data & 0x0F).astype(np.uint8) + high_nibbles = ((packed_data >> 4) & 0x0F).astype(np.uint8) + + unpacked = np.empty((*packed_data.shape[:2], packed_data.shape[2] * 2), + dtype=np.uint8) + unpacked[..., 0::2] = low_nibbles + unpacked[..., 1::2] = high_nibbles + + return unpacked + + +def process_weights_after_loading(model: nn.Module, model_config: ModelConfig, + target_device: torch.device) -> None: + for _, module in model.named_modules(): + quant_method = getattr(module, "quant_method", None) + if isinstance(quant_method, QuantizeMethodBase): + # # When quant methods need to process weights after loading + # # (for repacking, quantizing, etc), they expect parameters + # # to be on the global target device. This scope is for the + # # case where cpu offloading is used, where we will move the + # # parameters onto device for processing and back off after. + # with device_loading_context(module, target_device): + quant_method.process_weights_after_loading(module) + + # Currently only used by MLA. + # NOTE: This intentionally happens after other modules so we can easily + # decompress the weights for MLA. 
+ for _, module in model.named_modules(): + if isinstance(module, Attention) and \ + hasattr(module, "process_weights_after_loading"): + # TODO(lucas): see if there is a way to unify the signatures + # of process_weights_after_loading + module.process_weights_after_loading(model_config.dtype) diff --git a/vllm_mindspore/model_executor/model_loader/weight_utils.py b/vllm_mindspore/model_executor/model_loader/weight_utils.py index e02de0ab..4e535c01 100644 --- a/vllm_mindspore/model_executor/model_loader/weight_utils.py +++ b/vllm_mindspore/model_executor/model_loader/weight_utils.py @@ -17,16 +17,27 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import glob +import json +import os from collections.abc import Generator from typing import Any +import huggingface_hub import mindspore as ms +import numpy as np +from huggingface_hub import snapshot_download from mindspore import Parameter from safetensors import safe_open from tqdm.auto import tqdm +from vllm.config import LoadConfig +from vllm.model_executor.model_loader.weight_utils import (DisabledTqdm, + get_lock) + +from vllm_mindspore.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm_mindspore.platforms.ascend import ModelConfig from vllm_mindspore.utils import atlas_inference -import numpy as np from vllm.model_executor.model_loader.weight_utils import (_BAR_FORMAT, enable_tqdm) @@ -79,3 +90,97 @@ def default_weight_loader(param: Parameter, loaded_weight: Any) -> None: """Default weight loader.""" loaded_weight = loaded_weight[:] param.set_data(ms.Tensor(loaded_weight, dtype=param.dtype)) + + +def get_quant_config(model_config: ModelConfig, + load_config: LoadConfig) -> QuantizationConfig: + + from vllm_mindspore.model_executor.layers.quantization import ( + get_quantization_config) + quant_cls = get_quantization_config(model_config.quantization) + + # GGUF doesn't have config file + if model_config.quantization == "gguf": + return quant_cls.from_config({}) + + # Read the quantization config from the HF model config, if available. + hf_quant_config = getattr(model_config.hf_config, "quantization_config", + None) + # some vision model may keep quantization_config in their text_config + hf_text_config = getattr(model_config.hf_config, "text_config", None) + if hf_quant_config is None and hf_text_config is not None: + hf_quant_config = getattr(hf_text_config, "quantization_config", None) + if hf_quant_config is None: + # compressed-tensors uses a compressions_config + hf_quant_config = getattr(model_config.hf_config, "compression_config", + None) + if hf_quant_config is not None: + if os.path.isdir(model_config.model): + quant_config_file = os.path.join( + model_config.model, + quant_cls.get_config_filenames()[0]) + with open(quant_config_file) as f: + quant_config = json.load(f) + return quant_cls.from_config(hf_quant_config | quant_config) + + # In case of bitsandbytes/QLoRA, get quant config from the adapter model. 
+ if model_config.quantization == "bitsandbytes": + if (not load_config.model_loader_extra_config + or "qlora_adapter_name_or_path" + not in load_config.model_loader_extra_config): + return quant_cls.from_config({"adapter_name_or_path": ""}) + model_name_or_path = load_config.model_loader_extra_config[ + "qlora_adapter_name_or_path"] + + else: + model_name_or_path = model_config.model + is_local = os.path.isdir(model_name_or_path) + if not is_local: + # Download the config files. + with get_lock(model_name_or_path, load_config.download_dir): + hf_folder = snapshot_download( + model_name_or_path, + revision=model_config.revision, + allow_patterns="*.json", + cache_dir=load_config.download_dir, + local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, + tqdm_class=DisabledTqdm, + ) + else: + hf_folder = model_name_or_path + + possible_config_filenames = quant_cls.get_config_filenames() + + # If the quantization config is not found, use the default config. + if not possible_config_filenames: + return quant_cls() + + config_files = glob.glob(os.path.join(hf_folder, "*.json")) + + quant_config_files = [ + f for f in config_files if any( + f.endswith(x) for x in possible_config_filenames) + ] + if len(quant_config_files) == 0: + raise ValueError( + f"Cannot find the config file for {model_config.quantization}") + if len(quant_config_files) > 1: + raise ValueError( + f"Found multiple config files for {model_config.quantization}: " + f"{quant_config_files}") + + quant_config_file = quant_config_files[0] + with open(quant_config_file) as f: + config = json.load(f) + + if model_config.quantization == "bitsandbytes": + config["adapter_name_or_path"] = model_name_or_path + elif model_config.quantization == "modelopt": + if config["producer"]["name"] == "modelopt": + return quant_cls.from_config(config) + else: + raise ValueError( + f"Unsupported quantization config" + f" found for {model_config.quantization} in {f}.") + + return quant_cls.from_config(config) diff --git a/vllm_mindspore/model_executor/models/qwen2.py b/vllm_mindspore/model_executor/models/qwen2.py index 8be2b4d5..42d08c7a 100644 --- a/vllm_mindspore/model_executor/models/qwen2.py +++ b/vllm_mindspore/model_executor/models/qwen2.py @@ -46,6 +46,7 @@ from vllm.model_executor.models.interfaces import SupportsLoRA from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.sequence import IntermediateTensors from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.distributed import get_tensor_model_parallel_rank from vllm_mindspore.utils import atlas_inference, FORMAT_TYPE from vllm_mindspore.attention import Attention @@ -378,6 +379,10 @@ class Qwen2Model(nn.Cell): ] for name, loaded_weight in weights: + if get_tensor_model_parallel_rank( + ) > 0 and "o_proj.quant_bias" in name: + continue + if "rotary_emb.inv_freq" in name: continue if ("rotary_emb.cos_cached" in name -- Gitee From 5bddce2f61d2889bdcdc96443046cec839ef1883 Mon Sep 17 00:00:00 2001 From: superxf Date: Wed, 30 Jul 2025 16:47:52 +0800 Subject: [PATCH 05/12] fix new branch --- .../model_executor/layers/linear.py | 43 ++++++------------- .../model_loader/weight_utils.py | 23 +++++++++- 2 files changed, 35 insertions(+), 31 deletions(-) diff --git a/vllm_mindspore/model_executor/layers/linear.py b/vllm_mindspore/model_executor/layers/linear.py index 79dbd0f1..4e719ae3 100644 --- a/vllm_mindspore/model_executor/layers/linear.py +++ b/vllm_mindspore/model_executor/layers/linear.py @@ -343,23 +343,14 @@ class 
MergedColumnParallelLinear(ColumnParallelLinear): assert loaded_shard_id < len(self.output_sizes) shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size shard_size = self.output_sizes[loaded_shard_id] // tp_size - param_data = param.data - param_data = param_data.narrow(output_dim, shard_offset, - shard_size) - start_idx = tp_rank * shard_size - loaded_weight = loaded_weight.narrow(output_dim, start_idx, - shard_size).contiguous() - assert param_data.shape == loaded_weight.shape - if len(loaded_weight.shape) == 2: - param[shard_offset:shard_offset + - shard_size, :] = loaded_weight - else: - param[shard_offset:shard_offset + shard_size] = loaded_weight - else: - assert param.shape == loaded_weight.shape - if loaded_weight.dtype == ms.float32 and param.dtype == ms.float16: - loaded_weight = loaded_weight.astype(ms.float16) - param.set_data(loaded_weight.contiguous()) + + start_idx = tp_rank * shard_size + loaded_weight = split_loaded_weight(loaded_weight, output_dim, + start_idx, shard_size) + if param.name.endswith("weight"): + assert loaded_weight.shape == (shard_size, param.shape[1]) + param[shard_offset:shard_offset + + shard_size] = ms.from_numpy(loaded_weight) class QKVParallelLinear(ColumnParallelLinear): @@ -442,11 +433,6 @@ class QKVParallelLinear(ColumnParallelLinear): loaded_weight, loaded_shard_id: Optional[str] = None): output_dim = getattr(param, "output_dim", None) - if output_dim is None: - if loaded_weight.dtype == ms.float32 and param.dtype == ms.float16: - loaded_weight = loaded_weight.astype(ms.float16) - param.set_data(loaded_weight.contiguous()) - return tp_rank = get_tensor_model_parallel_rank() # QKV loaded weight is already fused on disk (qkv safetensors). @@ -496,13 +482,11 @@ class QKVParallelLinear(ColumnParallelLinear): start_idx, shard_size) loaded_weight = ms.from_numpy(loaded_weight) - loaded_weight = loaded_weight.narrow(output_dim, start_idx, - shard_size).contiguous() - assert param_data.shape == loaded_weight.shape - if len(loaded_weight.shape) == 2: - param[shard_offset:shard_offset + shard_size, :] = loaded_weight - else: - param[shard_offset:shard_offset + shard_size] = loaded_weight + if param.name.endswith("weight"): + assert loaded_weight.shape == (shard_size, param.shape[1]) + if param.name.endswith("bias"): + assert loaded_weight.shape == (shard_size, ) + param[shard_offset:shard_offset + shard_size] = loaded_weight class RowParallelLinear(LinearBase): @@ -623,7 +607,6 @@ class RowParallelLinear(LinearBase): def weight_loader(self, param, loaded_weight): tp_rank = get_tensor_model_parallel_rank() - param_data = param.data input_dim = getattr(param, "input_dim", None) shard_size = self.input_size_per_partition start_idx = tp_rank * shard_size diff --git a/vllm_mindspore/model_executor/model_loader/weight_utils.py b/vllm_mindspore/model_executor/model_loader/weight_utils.py index 4e535c01..4a0fdcf8 100644 --- a/vllm_mindspore/model_executor/model_loader/weight_utils.py +++ b/vllm_mindspore/model_executor/model_loader/weight_utils.py @@ -31,7 +31,9 @@ from mindspore import Parameter from safetensors import safe_open from tqdm.auto import tqdm from vllm.config import LoadConfig -from vllm.model_executor.model_loader.weight_utils import (DisabledTqdm, +from vllm.model_executor.model_loader.weight_utils import (_BAR_FORMAT, + DisabledTqdm, + enable_tqdm, get_lock) from vllm_mindspore.model_executor.layers.quantization.base_config import ( @@ -63,6 +65,11 @@ def split_loaded_weight(loaded_weight, shard_dim, start_idx, shard_size): 
loaded_weight = loaded_weight[:, :, start_idx:end_idx] else: raise ValueError("shard_dim:{} is not supported.".format(shard_dim)) + loaded_weight = ( + loaded_weight.astype(np.float16) + if (str(loaded_weight.dtype) == 'bfloat16' and is_310p()) + else loaded_weight + ) return loaded_weight @@ -79,16 +86,30 @@ def safetensors_weights_iterator( ): with safe_open(st_file, framework="np") as f: for name in f.keys(): # noqa: SIM118 +<<<<<<< HEAD # TODO: use slice x = f.get_tensor(name) x = x.astype(np.float16) \ if (str(x.dtype) == 'bfloat16' and atlas_inference()) else x yield name, ms.tensor(x) +======= + # Return a lightweight PySafeSlice object that uses file + # pointer offset internally to read Safetensor on demand, + # avoiding memory explosion. Actual data can be obtained + # through slicing operation like param[start:end] + param = f.get_slice(name) + yield name, param +>>>>>>> 8858529 (fix new branch) def default_weight_loader(param: Parameter, loaded_weight: Any) -> None: """Default weight loader.""" loaded_weight = loaded_weight[:] + loaded_weight = ( + loaded_weight.astype(np.float16) + if (str(loaded_weight.dtype) == 'bfloat16' and is_310p()) + else loaded_weight + ) param.set_data(ms.Tensor(loaded_weight, dtype=param.dtype)) -- Gitee From d9fa2613759b9caf5c7f1a1ea799eed971dadcd1 Mon Sep 17 00:00:00 2001 From: HighCloud Date: Tue, 22 Jul 2025 14:36:53 +0800 Subject: [PATCH 06/12] change atlas_inference to is_310p --- vllm_mindspore/distributed/parallel_state.py | 5 ++-- .../model_loader/weight_utils.py | 4 +-- .../mf_models/deepseekv3_weight_processor.py | 13 ++++++++++ .../models/mf_models/weight_processor.py | 8 ++++-- .../model_executor/models/model_base.py | 12 ++++----- vllm_mindspore/model_executor/models/qwen2.py | 6 ++--- vllm_mindspore/utils.py | 25 +++++++++++++++++++ vllm_mindspore/v1/worker/gpu_model_runner.py | 8 +++--- vllm_mindspore/worker/cache_engine.py | 7 +++--- vllm_mindspore/worker/model_runner.py | 6 ++--- 10 files changed, 70 insertions(+), 24 deletions(-) diff --git a/vllm_mindspore/distributed/parallel_state.py b/vllm_mindspore/distributed/parallel_state.py index 697196fa..a3ef9fd8 100644 --- a/vllm_mindspore/distributed/parallel_state.py +++ b/vllm_mindspore/distributed/parallel_state.py @@ -4,7 +4,8 @@ from torch.distributed import ProcessGroup from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union) from vllm.distributed.parallel_state import _split_tensor_dict, TensorMetadata -from vllm_mindspore.utils import atlas_inference +from vllm_mindspore.utils import is_310p + def gc_broadcast_tensor_dict( self, @@ -20,7 +21,7 @@ def gc_broadcast_tensor_dict( if (not torch.distributed.is_initialized() or self.world_size == 1): return tensor_dict - if not atlas_inference(): + if not is_310p(): group = self.device_group metadata_group = self.cpu_group assert src < self.world_size, f"Invalid src rank ({src})" diff --git a/vllm_mindspore/model_executor/model_loader/weight_utils.py b/vllm_mindspore/model_executor/model_loader/weight_utils.py index 4a0fdcf8..1e162829 100644 --- a/vllm_mindspore/model_executor/model_loader/weight_utils.py +++ b/vllm_mindspore/model_executor/model_loader/weight_utils.py @@ -39,7 +39,7 @@ from vllm.model_executor.model_loader.weight_utils import (_BAR_FORMAT, from vllm_mindspore.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm_mindspore.platforms.ascend import ModelConfig -from vllm_mindspore.utils import atlas_inference +from vllm_mindspore.utils import is_310p from 
vllm.model_executor.model_loader.weight_utils import (_BAR_FORMAT, enable_tqdm) @@ -90,7 +90,7 @@ def safetensors_weights_iterator( # TODO: use slice x = f.get_tensor(name) x = x.astype(np.float16) \ - if (str(x.dtype) == 'bfloat16' and atlas_inference()) else x + if (str(x.dtype) == 'bfloat16' and is_310p()) else x yield name, ms.tensor(x) ======= # Return a lightweight PySafeSlice object that uses file diff --git a/vllm_mindspore/model_executor/models/mf_models/deepseekv3_weight_processor.py b/vllm_mindspore/model_executor/models/mf_models/deepseekv3_weight_processor.py index 73bc3027..a8aa7127 100644 --- a/vllm_mindspore/model_executor/models/mf_models/deepseekv3_weight_processor.py +++ b/vllm_mindspore/model_executor/models/mf_models/deepseekv3_weight_processor.py @@ -402,6 +402,19 @@ class DeepseekV3WeightProcessor(BaseWeightProcessor): w2_scale_ms_stack_param = np.stack(w2_scale_list, axis=0) w3_scale_ms_stack_param = np.stack(w3_scale_list, axis=0) +<<<<<<< HEAD +======= + if self.is_310p: + weight_scale_dtype = ms.float32 + weight_concat_axis = 2 + w1_ms_stack_param = w1_ms_stack_param.transpose(0, 2, 1) + w2_ms_stack_param = w2_ms_stack_param.transpose(0, 2, 1) + w3_ms_stack_param = w3_ms_stack_param.transpose(0, 2, 1) + else: + weight_scale_dtype = ms.bfloat16 + weight_concat_axis = 1 + +>>>>>>> f3c366e (change atlas_inference to is_310p) if ffn_concat: # w_gate_hidden w_gate_hidden_name = f"{base_path}.w_gate_hidden._layer.weight" diff --git a/vllm_mindspore/model_executor/models/mf_models/weight_processor.py b/vllm_mindspore/model_executor/models/mf_models/weight_processor.py index d60506fb..35694a6a 100644 --- a/vllm_mindspore/model_executor/models/mf_models/weight_processor.py +++ b/vllm_mindspore/model_executor/models/mf_models/weight_processor.py @@ -25,7 +25,7 @@ from mindformers.parallel_core.inference.parallel_state import ( from mindformers.parallel_core.inference.utils import get_tp_world_size from mindspore.communication.management import get_group_size, get_rank from safetensors import safe_open - +from vllm_mindspore.utils import is_310p class EPMethod(Enum): """ @@ -45,7 +45,9 @@ class BaseWeightProcessor: """ - def __init__(self, config, network, is_quant): + def __init__(self, config, network, is_quant, vllm_config): + self.vllm_config = vllm_config + self.is_310p = is_310p() self.config = config self.network = network self.is_quant = is_quant @@ -165,6 +167,7 @@ class BaseWeightProcessor: else: raise ValueError( "split_axis:{} is not supported.".format(split_axis)) + return split_data, qint4 def get_safetensor_from_file_split_moe_tp_group(self, @@ -195,6 +198,7 @@ class BaseWeightProcessor: else: raise ValueError( "split_axis:{} is not supported.".format(split_axis)) + return split_data, qint4 def get_routed_safetensor_3_dim(self, diff --git a/vllm_mindspore/model_executor/models/model_base.py b/vllm_mindspore/model_executor/models/model_base.py index 4711cb99..6a73e171 100644 --- a/vllm_mindspore/model_executor/models/model_base.py +++ b/vllm_mindspore/model_executor/models/model_base.py @@ -34,9 +34,8 @@ from vllm.sequence import IntermediateTensors from vllm_mindspore.model_executor.models.attention_mask import ( LowerTriangularMask) from vllm_mindspore.model_executor.utils import set_model_context -from vllm_mindspore.utils import STR_DTYPE_TO_MS_DTYPE +from vllm_mindspore.utils import FORMAT_TYPE, STR_DTYPE_TO_MS_DTYPE, is_310p from vllm_mindspore.v1.attention.backends.ms_attn import MsAttentionMetadata -from vllm_mindspore.utils import atlas_inference, 
FORMAT_TYPE class AttentionWrapper: @@ -47,7 +46,7 @@ class AttentionWrapper: vllm_config.parallel_config) head_size = vllm_config.model_config.get_head_size() num_block = 0 - if atlas_inference(): + if is_310p(): self.kv_shape = [num_block, block_size, num_kv_heads * head_size] self.kv_cache = [ ( @@ -92,7 +91,7 @@ class MLAAttentionWrapper(AttentionWrapper): self.use_mla_op = bool( vllm_config.additional_config and vllm_config.additional_config.get('use_mla_op') == 1) - if atlas_inference(): + if is_310p(): self.kv_cache = [ ( ops.auto_generate.format_cast( @@ -464,8 +463,9 @@ class NativeModel(MsModelBase): block_size = self.cache_config.block_size num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config) head_size = self.model_config.get_head_size() - kv_cache_shape = (None, block_size, num_kv_heads * head_size) if atlas_inference() \ - else (None, block_size, num_kv_heads, head_size) + kv_cache_shape = (None, block_size, num_kv_heads * head_size) \ + if is_310p() else (None, block_size, num_kv_heads, + head_size) kv_cache_dtype = (self.model_config.dtype if self.cache_config.cache_dtype == "auto" else diff --git a/vllm_mindspore/model_executor/models/qwen2.py b/vllm_mindspore/model_executor/models/qwen2.py index 42d08c7a..a9c72791 100644 --- a/vllm_mindspore/model_executor/models/qwen2.py +++ b/vllm_mindspore/model_executor/models/qwen2.py @@ -48,7 +48,7 @@ from vllm.sequence import IntermediateTensors from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.distributed import get_tensor_model_parallel_rank -from vllm_mindspore.utils import atlas_inference, FORMAT_TYPE +from vllm_mindspore.utils import is_310p, FORMAT_TYPE from vllm_mindspore.attention import Attention from vllm_mindspore.model_executor.layers.activation import SwiGLU from vllm_mindspore.model_executor.layers.layernorm import RMSNorm @@ -423,7 +423,7 @@ class Qwen2Model(nn.Cell): loaded_params.add(name) def adjust_weight(params_dict): - if not atlas_inference(): + if not is_310p(): return target_keywords = [ @@ -440,7 +440,7 @@ class Qwen2Model(nn.Cell): ms.runtime.synchronize() param.set_data(cast_weight) - if atlas_inference(): + if is_310p(): ms.runtime.synchronize() adjust_weight(params_dict) ms.runtime.synchronize() diff --git a/vllm_mindspore/utils.py b/vllm_mindspore/utils.py index 8bb43caa..f86d754d 100644 --- a/vllm_mindspore/utils.py +++ b/vllm_mindspore/utils.py @@ -240,6 +240,31 @@ def is_310p(): return device in ['310p', 'ascend310p'] +def check_ready(): + from mindspore import set_context + + # Common environment variables of predict. 
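+    # Select the O0 JIT level with infer_boost for inference and disable the
+    # internal FlashAttentionScore/PagedAttention custom kernels by default.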
+ set_context(jit_config={"jit_level": "O0", "infer_boost": "on"}) + default_env = { + "MS_INTERNAL_DISABLE_CUSTOM_KERNEL_LIST": + "FlashAttentionScore,PagedAttention", + } + if is_310p(): + default_env["MS_ENABLE_INTERNAL_BOOST"] = "off" + env_setup(default_env) + + if os.getenv("MS_MEMPOOL_BLOCK_SIZE"): + set_context( + mempool_block_size=f"{os.environ['MS_MEMPOOL_BLOCK_SIZE']}GB") + + if is_mindformers_model_backend(): + logger.info("Run with Mindformers backend!") + elif is_mindone_model_backend(): + logger.info("Run with MindONE backend!") + else: + logger.info("Run with native model backend!") + register_connector() + def convert_np_to_ms_dtype(value): """convert_np_to_ms_dtype""" if value.dtype == np.int8: diff --git a/vllm_mindspore/v1/worker/gpu_model_runner.py b/vllm_mindspore/v1/worker/gpu_model_runner.py index e7f06d4f..d40efe70 100644 --- a/vllm_mindspore/v1/worker/gpu_model_runner.py +++ b/vllm_mindspore/v1/worker/gpu_model_runner.py @@ -36,7 +36,8 @@ from vllm.v1.worker.gpu_input_batch import CachedRequestState from vllm_mindspore.model_executor.layers.rotary_embedding import ( InferMRotaryEmbedding as MRotaryEmbedding) -from vllm_mindspore.utils import get_dtype_size, get_valid_dtype, atlas_inference, FORMAT_TYPE +from vllm_mindspore.utils import (FORMAT_TYPE, get_dtype_size, get_valid_dtype, + is_310p) logger = init_logger(__name__) @@ -312,7 +313,7 @@ def _reshape_kv_cache_tensors( kv_cache_shape = self.attn_backends[i].get_kv_cache_shape( num_blocks, kv_cache_spec.block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) - if atlas_inference(): + if is_310p(): *dims, second_last, last = kv_cache_shape kv_cache_shape = (*dims, second_last * last) try: @@ -334,6 +335,7 @@ def _reshape_kv_cache_tensors( for i in range(len(kv_cache_stride_order)) ] kv_cache_layer = [] +<<<<<<< HEAD for idx, kv_cache_raw_tensor in enumerate( kv_cache_raw_tensors[layer_name]): if use_mla_op: @@ -347,7 +349,7 @@ def _reshape_kv_cache_tensors( else: cache_block = kv_cache_raw_tensor.view( kv_cache_shape[1:]).permute(*inv_order[1:]) - if atlas_inference(): + if is_310p(): from mindspore.common.api import _pynative_executor cache_block_nz = ops.auto_generate.format_cast(cache_block, FORMAT_TYPE['nz']) _pynative_executor.sync() diff --git a/vllm_mindspore/worker/cache_engine.py b/vllm_mindspore/worker/cache_engine.py index 7675379c..e8c20397 100644 --- a/vllm_mindspore/worker/cache_engine.py +++ b/vllm_mindspore/worker/cache_engine.py @@ -26,7 +26,8 @@ from mindspore import mutable, mint, ops from typing import List from vllm.logger import init_logger -from vllm_mindspore.utils import MsKVCache, get_valid_dtype, atlas_inference, FORMAT_TYPE +from vllm_mindspore.utils import (MsKVCache, get_valid_dtype, is_310p, + FORMAT_TYPE) logger = init_logger(__name__) @@ -34,7 +35,7 @@ logger = init_logger(__name__) def create_block(shape, dtype, name=None, device=None): from mindspore.common.api import _pynative_executor blocks = mint.empty(*shape, dtype=dtype, device=device) - if device == "Ascend" and atlas_inference(): + if device == "Ascend" and is_310p(): blocks_nz = ops.auto_generate.format_cast(blocks, FORMAT_TYPE['nz']) _pynative_executor.sync() import gc @@ -53,7 +54,7 @@ def ms_allocate_kv_cache( """Allocates KV cache on the specified device.""" kv_cache_shape = self.attn_backend.get_kv_cache_shape( num_blocks, self.block_size, self.num_kv_heads, self.head_size) - if atlas_inference(): + if is_310p(): *dims, second_last, last = kv_cache_shape kv_cache_shape = (*dims, second_last * last) 
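+        # Flattening (num_kv_heads, head_size) into one axis keeps the block
+        # layout compatible with the later format_cast to the NZ format on 310P.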
kv_cache: List[MsKVCache] = [] diff --git a/vllm_mindspore/worker/model_runner.py b/vllm_mindspore/worker/model_runner.py index 6ab97c1b..1c37be98 100644 --- a/vllm_mindspore/worker/model_runner.py +++ b/vllm_mindspore/worker/model_runner.py @@ -28,7 +28,7 @@ from vllm.lora.request import LoRARequest from vllm.sampling_params import SamplingParams from vllm.sequence import SequenceGroupMetadata -from vllm_mindspore.utils import STR_DTYPE_TO_TENSOR_DTYPE, atlas_inference +from vllm_mindspore.utils import STR_DTYPE_TO_TENSOR_DTYPE, is_310p logger = init_logger(__name__) @@ -140,8 +140,8 @@ def _dummy_run(self, block_size = self.cache_config.block_size num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config) head_size = self.model_config.get_head_size() - kv_shape = [0, block_size, num_kv_heads * head_size] if atlas_inference() else \ - [0, block_size, num_kv_heads, head_size] + kv_shape = [0, block_size, num_kv_heads * head_size] \ + if is_310p() else [0, block_size, num_kv_heads, head_size] kv_caches = mutable([ mutable( ( -- Gitee From d4978f2a3c4d40c6c211da2ceb5b284b27685d3a Mon Sep 17 00:00:00 2001 From: luolihao Date: Thu, 24 Jul 2025 19:08:26 +0800 Subject: [PATCH 07/12] support qwq w8a8sc --- .../layers/quantization/__init__.py | 9 +- .../quantization/sparse_quant_modelslim.py | 182 ++++++++++++++++++ vllm_mindspore/model_executor/models/qwen2.py | 50 ++++- 3 files changed, 238 insertions(+), 3 deletions(-) create mode 100644 vllm_mindspore/model_executor/layers/quantization/sparse_quant_modelslim.py diff --git a/vllm_mindspore/model_executor/layers/quantization/__init__.py b/vllm_mindspore/model_executor/layers/quantization/__init__.py index 6c9e2e41..3c6c2da9 100644 --- a/vllm_mindspore/model_executor/layers/quantization/__init__.py +++ b/vllm_mindspore/model_executor/layers/quantization/__init__.py @@ -21,7 +21,10 @@ from typing import Literal, get_args from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -QuantizationMethods = Literal["smoothquant"] +QuantizationMethods = Literal[ + "smoothquant", + "sparsequant" +] QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods)) # The customized quantization methods which will be added to this dict. @@ -34,8 +37,10 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: # lazy import to avoid triggering `torch.compile` too early from .smooth_quant_modelslim import SmoothQuantModelSlimConfig + from .sparse_quant_modelslim import SparseQuantModelSlimConfig method_to_config: dict[str, type[QuantizationConfig]] = { - "smoothquant": SmoothQuantModelSlimConfig + "smoothquant": SmoothQuantModelSlimConfig, + "sparsequant": SparseQuantModelSlimConfig } # Update the `method_to_config` with customized quantization methods. method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG) diff --git a/vllm_mindspore/model_executor/layers/quantization/sparse_quant_modelslim.py b/vllm_mindspore/model_executor/layers/quantization/sparse_quant_modelslim.py new file mode 100644 index 00000000..f6ede5ed --- /dev/null +++ b/vllm_mindspore/model_executor/layers/quantization/sparse_quant_modelslim.py @@ -0,0 +1,182 @@ +#!/usr/bin/env python3 +# encoding: utf-8 +# Copyright 2025 Huawei Technologies Co., Ltd +# Copyright 2024 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +from typing import Any, Optional, Dict + +import torch +import numpy as np +import mindspore + +from mindspore.common.initializer import initializer +from mindspore import Parameter, ops, Tensor +from mindspore.ops.operations._infer_ops import QuantV2 +from mindspore.communication import get_rank +from vllm_mindspore.model_executor.layers.linear import LinearMethodBase, UnquantizedLinearMethod, LinearBase + +from .base_config import QuantizationConfig + + + +class SparseQuantModelSlimConfig(QuantizationConfig): + '''Config class for SparseQuant.''' + + def __init__( + self, + full_config: Dict[str, Any], + weight_bits: Optional[int] = 8, + group_size: Optional[int] = 1, + zero_point: Optional[bool] = True, + dynamic_quant: Optional[bool] = False, + kv_cache_bits: Optional[int] = 16, + modules_to_not_convert: Optional[list[str]] = None, + ) -> None: + super().__init__() + self.full_config = full_config + self.weight_bits = weight_bits + self.group_size = group_size + self.zero_point = zero_point + self.dynamic_quant = dynamic_quant + self.kv_cache_bits = kv_cache_bits + self.modules_to_not_convert = modules_to_not_convert or [] + + if self.weight_bits != 8: + raise ValueError( + "Currently, only 8-bit weight quantization is supported for " + f"A8W8SC, but got {self.weight_bits} bits.") + self.pack_factor = 8 // self.weight_bits + + def __repr__(self) -> str: + return (f"SparseConfig(weight_bits={self.weight_bits}, " + f"group_size={self.group_size}, " + f"zero_point={self.zero_point}, " + f"modules_to_not_convert={self.modules_to_not_convert})") + + @staticmethod + def get_config_filenames() -> list[str]: + return [ + "quant_model_description.json" + ] + + @classmethod + def get_min_capability(cls) -> int: + """Minimum GPU capability to support the quantization method. + + E.g., 70 for Volta, 75 for Turing, 80 for Ampere. + This requirement is due to the custom CUDA kernels used by the + quantization method. 
+ """ + return -1 + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "SparseQuantModelSlimConfig": + return cls(config) + + def get_name(self) -> str: + return "SparseQuant" + + def get_supported_act_dtypes(self) -> list[torch.dtype]: + return [torch.int8, torch.float16, torch.bfloat16] + + def get_quant_method(self, layer: mindspore.nn.Cell, + prefix: str) -> "QuantizeMethodBase": + + rank_id = get_rank() + sparse_quant_description = self.full_config[f'rank_{rank_id}'] + if isinstance(layer, LinearBase) and sparse_quant_description[f"{prefix}.weight"].lower() == "w8a8s": + compress_weight_size = sparse_quant_description[f"{prefix}.weight.shape"] + compress_index_size = sparse_quant_description[f"{prefix}.index.shape"] + + return A8W8SCLinearMethod(self, compress_weight_size[0], compress_index_size[0]) + + return UnquantizedLinearMethod() + + +class A8W8SCLinearMethod(LinearMethodBase): + '''Linear method for A8W8SCLinearMethod.''' + + def __init__(self, quant_config: SparseQuantModelSlimConfig, compress_weight_size=None, compress_index_size=None): + self.quant_config = quant_config + self.compress_weight_size = compress_weight_size + self.compress_index_size = compress_index_size + + self.quant = QuantV2() + self.linear_sparse = ops.auto_generate.QuantLinearSparse() + + def create_weights(self, + layer: mindspore.nn.Cell, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype, + is_group_mm=False, + expert_num_per_partition=1, + **extra_weight_attrs): + if input_size_per_partition % self.quant_config.group_size != 0: + raise ValueError( + "The input size is not aligned with the quantized " + "weight shape. This can be caused by too large " + "tensor parallel size.") + + output_size_per_partition = sum(output_partition_sizes) + self.output_size_per_partition = output_size_per_partition + self.input_size_per_partition = input_size_per_partition + if output_size_per_partition % self.quant_config.pack_factor != 0: + raise ValueError( + "The output size is not aligned with the quantized " + "weight shape. 
This can be caused by too large " + "tensor parallel size.") + + weight = Parameter(initializer('normal', (self.compress_weight_size), mindspore.int8), name="weight") + index = Parameter(initializer('normal', (self.compress_index_size), mindspore.int8), name="index") + deq_scale = Parameter(initializer('normal', (self.output_size_per_partition), mindspore.int64), + name="deq_scale") + quant_bias = Parameter(initializer('zeros', (self.output_size_per_partition), mindspore.int32), + name="quant_bias") + input_scale = Parameter(Tensor(np.ones(self.input_size_per_partition), mindspore.float16), + name="input_scale") + input_offset = Parameter(Tensor(np.zeros(self.input_size_per_partition), mindspore.int8), + name="input_offset") + + layer.insert_param_to_cell("weight", weight) + layer.insert_param_to_cell("index", index) + layer.insert_param_to_cell("deq_scale", deq_scale) + layer.insert_param_to_cell("quant_bias", quant_bias) + layer.insert_param_to_cell("input_scale", input_scale) + layer.insert_param_to_cell("input_offset", input_offset) + + def apply(self, + layer: mindspore.nn.Cell, + x: mindspore.Tensor, + bias: mindspore.Parameter = None, group_list=None, cumsum_flag=False) -> mindspore.Tensor: + weight = layer.weight + index = layer.index + deq_scale = layer.deq_scale + quant_bias = layer.quant_bias + input_scale = layer.input_scale + input_offset = layer.input_offset + + output_shape = x.shape[:-1] + (self.output_size_per_partition,) + x = x.reshape(-1, self.input_size_per_partition) + + x = self.quant(x, input_scale, input_offset, False, "ROUND", mindspore.int8) + x = self.linear_sparse(x, weight, deq_scale, index, quant_bias) + + x = x.reshape(output_shape) + + return x \ No newline at end of file diff --git a/vllm_mindspore/model_executor/models/qwen2.py b/vllm_mindspore/model_executor/models/qwen2.py index a9c72791..6891aea4 100644 --- a/vllm_mindspore/model_executor/models/qwen2.py +++ b/vllm_mindspore/model_executor/models/qwen2.py @@ -65,6 +65,7 @@ from vllm_mindspore.model_executor.models.model_base import (NativeModel) from vllm_mindspore.model_executor.models.utils import ( PPMissingLayer, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) +from mindspore.communication.management import get_rank class Qwen2MLP(nn.Cell): @@ -366,6 +367,50 @@ class Qwen2Model(nn.Cell): hidden_states, _ = self.norm(hidden_states, residual) return hidden_states + def load_split_weights(self, weights: Iterable[tuple[str, Tensor]], + params_dict: dict[str, Parameter]): + weights_dict = dict(weights) + + for name, loaded_weight in weights_dict.items(): + if get_tensor_model_parallel_rank( + ) > 0 and "o_proj.quant_bias" in name: + continue + + if name not in params_dict: + continue + + param = params_dict[name] + param.set_data(loaded_weight.contiguous()) + + def adjust_weight(params_dict): + if not is_310p(): + return + + target_keywords = [ + "qkv_proj.weight", + "o_proj.weight", + "gate_up_proj.weight", + "down_proj.weight", + # "lm_head.weight", + ] + + rank_id = get_rank() + for name, param in params_dict.items(): + if any(name.endswith(keyword) for keyword in target_keywords): + weight_type = self.quant_config.full_config[f"rank_{rank_id}"][name] + if weight_type.lower() == "w8a8s": + # 压缩后权重不需要转Nz + continue + + cast_weight = ops.auto_generate.format_cast(param, FORMAT_TYPE['nz']) + ms.runtime.synchronize() + param.set_data(cast_weight) + + if is_310p(): + ms.runtime.synchronize() + adjust_weight(params_dict) + ms.runtime.synchronize() + def load_weights(self, weights: 
Iterable[tuple[str, Tensor]], params_dict: dict[str, Parameter]): loaded_params: set[str] = set() @@ -516,7 +561,10 @@ class Qwen2ForCausalLM(NativeModel, SupportsLoRA): def load_weights(self, weights: Iterable[tuple[str, Tensor]]) -> set[str]: params_dict = self.get_params_dict() - self.model.load_weights(weights, params_dict) + if self.vllm_config.model_config.quantization == "sparsequant": + self.model.load_split_weights(weights, params_dict) + else: + self.model.load_weights(weights, params_dict) def sample(self, logits: Tensor, sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]: -- Gitee From 06c741b6e866be489bc4734468f5d4999e9b9af8 Mon Sep 17 00:00:00 2001 From: huangzhuo Date: Mon, 4 Aug 2025 15:29:35 +0800 Subject: [PATCH 08/12] graph mode support mutilora --- vllm_mindspore/__init__.py | 3 +- vllm_mindspore/lora/layers.py | 72 ++++++++- vllm_mindspore/lora/ops/torch_ops/lora_ops.py | 120 ++++++++------ .../lora/punica_wrapper/punica_npu.py | 150 ++++++++++++------ vllm_mindspore/lora/utils.py | 18 +++ .../model_executor/models/model_base.py | 6 +- 6 files changed, 256 insertions(+), 113 deletions(-) diff --git a/vllm_mindspore/__init__.py b/vllm_mindspore/__init__.py index 157fad33..c361f65e 100644 --- a/vllm_mindspore/__init__.py +++ b/vllm_mindspore/__init__.py @@ -100,10 +100,11 @@ vllm.utils.memory_profiling = ms_memory_profiling import vllm.lora.utils from vllm_mindspore.model_executor.layers.linear import LinearBase -from vllm_mindspore.lora.utils import _all_lora_classes +from vllm_mindspore.lora.utils import _all_lora_classes, replace_submodule vllm.lora.utils._all_lora_classes = _all_lora_classes vllm.lora.utils.LinearBase = LinearBase +vllm.lora.utils.replace_submodule = replace_submodule import vllm.lora.models from vllm_mindspore.lora.models import ( diff --git a/vllm_mindspore/lora/layers.py b/vllm_mindspore/lora/layers.py index 16351109..bff4d3f7 100644 --- a/vllm_mindspore/lora/layers.py +++ b/vllm_mindspore/lora/layers.py @@ -24,7 +24,9 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Optional, Union, cast import mindspore as ms -from mindspore import mint +from mindspore import Parameter, ops, mint +from mindspore.common.initializer import initializer +import torch.nn.functional as F from transformers import PretrainedConfig from vllm.adapter_commons.layers import AdapterMapping from vllm.config import LoRAConfig @@ -321,7 +323,7 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA): self.output_size, self.tp_size)) else: raise NotImplementedError - + ''' self.lora_a_stacked = tuple( mint.zeros( ( @@ -342,6 +344,13 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA): ), dtype=lora_config.lora_dtype, ) for _ in range(self.n_slices)) + ''' + self.lora_a_stacked = Parameter( + initializer('zeros', (self.n_slices, max_loras, 1, lora_a_out_size, self.input_size), + lora_config.lora_dtype)) + self.lora_b_stacked = Parameter( + initializer('zeros', (self.n_slices, max_loras, 1, lora_b_out_size, lora_config.max_lora_rank), + lora_config.lora_dtype)) if lora_config.bias_enabled: lora_bias_out_size = lora_b_out_size self.lora_bias_stacked = tuple( @@ -405,8 +414,8 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA): x: ms.Tensor, bias: Optional[ms.Tensor] = None) -> ms.Tensor: output = self.base_layer.quant_method.apply(self.base_layer, x, bias) - self.punica_wrapper.add_lora_linear(output, x, self.lora_a_stacked, - self.lora_b_stacked, + self.punica_wrapper.add_lora_linear(output, x, self.lora_a_stacked[0], + self.lora_b_stacked[0], 
self.lora_bias_stacked, 1.0, self.output_slices) return output @@ -549,6 +558,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): lora_a_output_size_per_partition = ( lora_config.max_lora_rank if not lora_config.fully_sharded_loras else divide(lora_config.max_lora_rank, self.tp_size)) + ''' self.lora_a_stacked = tuple( mint.zeros( ( @@ -559,8 +569,16 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): ), dtype=lora_config.lora_dtype, ) for _ in range(self.n_slices)) - self.lora_b_stacked = tuple( - mint.zeros( + ''' + self.lora_a_stacked = Parameter( + initializer('zeros', (self.n_slices, max_loras, 1, lora_a_output_size_per_partition, self.input_size), + lora_config.lora_dtype)) + self.lora_b_stacked = Parameter( + initializer('zeros', (self.n_slices, max_loras, 1, self.output_slices[0], lora_config.max_lora_rank), + lora_config.lora_dtype)) + if lora_config.bias_enabled: + self.lora_bias_stacked = tuple( + mint.zeros( ( max_loras, 1, @@ -653,6 +671,25 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): return (type(source_layer) is MergedColumnParallelLinear and len(packed_modules_list) == 2) + def apply(self, + x: ms.Tensor, + bias: Optional[ms.Tensor] = None) -> ms.Tensor: + output = self.base_layer.quant_method.apply(self.base_layer, x, bias) + output_w1, output_w3= mint.split( + output, + (self.lora_b_stacked[0].shape[2], self.lora_b_stacked[0].shape[2]), + -1) + output_w1 = self.punica_wrapper(output_w1, x, self.lora_a_stacked[0], + self.lora_b_stacked[0], + self.lora_bias_stacked, 1.0, + self.output_slices) + output_w3 = self.punica_wrapper(output_w3, x, self.lora_a_stacked[1], + self.lora_b_stacked[1], + self.lora_bias_stacked, 1.0, + self.output_slices) + + return ops.cat((output_w1, output_w3), 1) + class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): """ @@ -782,6 +819,29 @@ class MergedQKVParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA): return (type(source_layer) is QKVParallelLinear and len(packed_modules_list) == 3) + def apply(self, + x: ms.Tensor, + bias: Optional[ms.Tensor] = None) -> ms.Tensor: + output = self.base_layer.quant_method.apply(self.base_layer, x, bias) + output_q, output_k, output_v = mint.split( + output, + (self.q_proj_shard_size, self.kv_proj_shard_size, self.kv_proj_shard_size), + -1) + output_q = self.punica_wrapper(output_q, x, self.lora_a_stacked[0], + self.lora_b_stacked[0], + self.lora_bias_stacked, 1.0, + self.output_slices) + output_k = self.punica_wrapper(output_k, x, self.lora_a_stacked[1], + self.lora_b_stacked[1,:,:,:self.output_slices[1],:], + self.lora_bias_stacked, 1.0, + self.output_slices) + output_v = self.punica_wrapper(output_v, x, self.lora_a_stacked[2], + self.lora_b_stacked[2,:,:,:self.output_slices[2],:], + self.lora_bias_stacked, 1.0, + self.output_slices) + + return ops.cat((output_q, output_k, output_v), 1) + class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA): diff --git a/vllm_mindspore/lora/ops/torch_ops/lora_ops.py b/vllm_mindspore/lora/ops/torch_ops/lora_ops.py index d3d48975..92d58e61 100644 --- a/vllm_mindspore/lora/ops/torch_ops/lora_ops.py +++ b/vllm_mindspore/lora/ops/torch_ops/lora_ops.py @@ -20,7 +20,7 @@ """ For punica_npu """ -from mindspore import mint +from mindspore import mint, ops from mindspore.ops.auto_generate import grouped_matmul_v4 @@ -85,7 +85,7 @@ def bgmv_expand(inputs, def sgmv_shrink( inputs, lora_a_weights, - output_tensor, + group_list, b_seq_start_loc, seq_len_tensor, lora_indices_tensor, @@ -94,47 
+94,53 @@ def sgmv_shrink( token_nums, scaling, ): - group_list = seq_len_tensor - if (lora_indices_tensor.unique().shape[0] != lora_indices_tensor.shape[0]): - sorted_ids, sorted_counts = sort_lora_by_token_count( - lora_indices_tensor, seq_len_tensor) - group_list = sorted_counts - if lora_a_weights.shape[0] != group_list.shape[0]: - new_tensor = mint.zeros(lora_a_weights.shape[0], - dtype=group_list.dtype) - new_tensor[:group_list.size(0)] = group_list - group_list = new_tensor - if len(lora_a_weights.shape) == 4: - lora_a_weights = lora_a_weights.squeeze(1) - lora_a_weights = mint.transpose(lora_a_weights, 1, 2) + # group_list = seq_len_tensor + # if (lora_indices_tensor.unique().shape[0] != lora_indices_tensor.shape[0]): + # sorted_ids, sorted_counts = sort_lora_by_token_count( + # lora_indices_tensor, seq_len_tensor) + # group_list = sorted_counts + # if lora_a_weights.shape[0] != group_list.shape[0]: + # new_tensor = mint.zeros(lora_a_weights.shape[0], + # dtype=group_list.dtype) + # new_tensor[:group_list.size(0)] = group_list + # group_list = new_tensor + # if len(lora_a_weights.shape) == 4: + # lora_a_weights = lora_a_weights.squeeze(1) + # lora_a_weights = mint.transpose(lora_a_weights, 1, 2) + # outputs = grouped_matmul_v4([inputs], [lora_a_weights], + # group_list=group_list, + # split_item=3, + # group_type=0, + # group_list_type=1) + #outputs = bgmv_shrink(inputs, lora_a_weights, 0, scaling) + lora_a_weights = lora_a_weights.squeeze(1) + lora_a_weights = mint.transpose(lora_a_weights, 1, 2) outputs = grouped_matmul_v4([inputs], [lora_a_weights], group_list=group_list, split_item=3, group_type=0, group_list_type=1) - outputs = outputs[0] - output_tensor[:, :outputs.shape[1]] = scaling * outputs[:] - return output_tensor + # output_tensor[:, :outputs.shape[1]] = scaling * outputs[:] + return outputs * scaling def bgmv_shrink(inputs, lora_b_weights, - output_tensor, lora_indices_tensor, scaling=1.0): - selected_loras = lora_b_weights[lora_indices_tensor].astype( - output_tensor.dtype) - inputs = inputs.astype(output_tensor.dtype) - if len(selected_loras.shape) == 4: - selected_loras = selected_loras.squeeze(1) - outputs = einsum_ms(inputs, selected_loras) - output_tensor[:, :outputs.shape[1]] = scaling * outputs[:] - return output_tensor + selected_loras = lora_b_weights[lora_indices_tensor] + inputs = inputs.astype(lora_b_weights[0].dtype) + # if len(selected_loras.shape) == 4: + selected_loras = selected_loras.squeeze(1) + outputs = einsum_ms(inputs, selected_loras) * scaling + # output_tensor[:, :outputs.shape[1]] = scaling * outputs[:] + return scaling * outputs def sgmv_expand_slice(inputs, lora_b_weights, output_tensor, + group_list, b_seq_start_loc, seq_len_tensor, lora_indices_tensor, @@ -144,30 +150,39 @@ def sgmv_expand_slice(inputs, slice_offset, slice_size, add_inputs=False): - group_list = seq_len_tensor - if (lora_indices_tensor.unique().shape[0] != lora_indices_tensor.shape[0]): - sorted_ids, sorted_counts = sort_lora_by_token_count( - lora_indices_tensor, seq_len_tensor) - group_list = sorted_counts - if lora_b_weights.shape[0] != group_list.shape[0]: - new_tensor = mint.zeros(lora_b_weights.shape[0], - dtype=group_list.dtype) - new_tensor[:group_list.size(0)] = group_list - group_list = new_tensor - if len(lora_b_weights.shape) == 4: - lora_b_weights = lora_b_weights.squeeze(1) - lora_b_weights = mint.transpose(lora_b_weights, 1, 2) - inputs = inputs.astype(output_tensor.dtype) + # group_list = seq_len_tensor + # if (lora_indices_tensor.unique().shape[0] != 
lora_indices_tensor.shape[0]): + # sorted_ids, sorted_counts = sort_lora_by_token_count( + # lora_indices_tensor, seq_len_tensor) + # group_list = sorted_counts + # if lora_b_weights.shape[0] != group_list.shape[0]: + # new_tensor = mint.zeros(lora_b_weights.shape[0], + # dtype=group_list.dtype) + # new_tensor[:group_list.size(0)] = group_list + # group_list = new_tensor + # if len(lora_b_weights.shape) == 4: + # lora_b_weights = lora_b_weights.squeeze(1) + # lora_b_weights = mint.transpose(lora_b_weights, 1, 2) + # inputs = inputs.astype(output_tensor.dtype) + # outputs = grouped_matmul_v4([inputs], [lora_b_weights], + # group_list=group_list, + # split_item=3, + # group_type=0, + # group_list_type=1) + # outputs = outputs[0] + # if add_inputs: + # output_tensor += outputs[:] + # else: + # output_tensor = outputs[:] + #output_tensor = bgmv_expand_slice(inputs, lora_b_weights, output_tensor, 0, 0, 0) + lora_b_weights = lora_b_weights.squeeze(1) + lora_b_weights = mint.transpose(lora_b_weights, 1, 2) outputs = grouped_matmul_v4([inputs], [lora_b_weights], group_list=group_list, split_item=3, group_type=0, group_list_type=1) - outputs = outputs[0] - if add_inputs: - output_tensor[:, slice_offset:slice_offset + slice_size] += outputs[:] - else: - output_tensor[:, slice_offset:slice_offset + slice_size] = outputs[:] + output_tensor = ops.add(output_tensor, outputs) return output_tensor @@ -181,11 +196,12 @@ def bgmv_expand_slice(inputs, selected_loras = lora_b_weights[lora_indices_tensor].astype( output_tensor.dtype) inputs = inputs.astype(output_tensor.dtype) - if len(selected_loras.shape) == 4: - selected_loras = selected_loras.squeeze(1) + # if len(selected_loras.shape) == 4: + selected_loras = selected_loras.squeeze(1) outputs = einsum_ms(inputs, selected_loras) - if add_inputs: - output_tensor[:, slice_offset:slice_offset + slice_size] += outputs[:] - else: - output_tensor[:, slice_offset:slice_offset + slice_size] = outputs[:] + output_tensor = ops.add(output_tensor, outputs) + # if add_inputs: + # output_tensor += outputs[:] + # else: + # output_tensor = outputs[:] return output_tensor diff --git a/vllm_mindspore/lora/punica_wrapper/punica_npu.py b/vllm_mindspore/lora/punica_wrapper/punica_npu.py index 0a60baf2..8167c670 100644 --- a/vllm_mindspore/lora/punica_wrapper/punica_npu.py +++ b/vllm_mindspore/lora/punica_wrapper/punica_npu.py @@ -19,20 +19,22 @@ # isort: skip_file """Punica wrapper for NPU.""" -from typing import Callable +from typing import Callable, Optional -from mindspore import mint +from mindspore import mint, nn, Parameter, ops, dtype from mindspore.common import dtype as mstype +from mindspore.common.initializer import initializer from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase from vllm_mindspore.lora.ops.torch_ops.lora_ops import ( bgmv_expand, bgmv_expand_slice, bgmv_shrink, sgmv_expand, - sgmv_expand_slice, sgmv_shrink) + sgmv_expand_slice, sgmv_shrink, sort_lora_by_token_count) +from vllm_mindspore.model_executor.utils import get_model_context # The platforms that are compatible with the PyTorch-native implementation can # inherit this class -class PunicaWrapperNPU(PunicaWrapperBase): +class PunicaWrapperNPU(PunicaWrapperBase, nn.Cell): """ PunicaWrapperNPU is designed to manage and provide metadata for the punica kernel. 
The main function is to maintain the state information for @@ -40,32 +42,34 @@ class PunicaWrapperNPU(PunicaWrapperBase): """ def __init__(self, max_num_batched_tokens, max_batches, device, **kwargs): + nn.Cell.__init__(self) PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, device) + self.max_loras = kwargs["max_loras"] + self.group_list = Parameter(initializer("ones", self.max_loras, dtype.int64), name="group_list") + self.lora_indices = Parameter(initializer("ones", self.max_loras, dtype.int64), name="lora_indices") def _shrink_prefill( self, - y, x, w_t_all, scale, ): - sgmv_shrink( # type: ignore + return sgmv_shrink( # type: ignore x, w_t_all, - y, + self.group_list, *self.prefill_metadata, scale, ) def _shrink_decode( self, - y, x, w_t_all, scale, ): - bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale) + return bgmv_shrink(x, w_t_all, self.token_lora_indices, scale) def _expand_prefill( self, @@ -78,6 +82,7 @@ class PunicaWrapperNPU(PunicaWrapperBase): x, w_t_all, y, + self.group_list, *self.prefill_metadata, add_inputs, ) @@ -100,10 +105,11 @@ class PunicaWrapperNPU(PunicaWrapperBase): y_slice_size, add_inputs, ): - sgmv_expand_slice( # type: ignore + return sgmv_expand_slice( # type: ignore x, w_t_all, y, + self.group_list, *self.prefill_metadata, y_offset, y_slice_size, @@ -119,7 +125,7 @@ class PunicaWrapperNPU(PunicaWrapperBase): y_slice_size, add_inputs, ): - bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset, + return bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset, y_slice_size, add_inputs) def _apply_expand( @@ -138,11 +144,11 @@ class PunicaWrapperNPU(PunicaWrapperBase): """ expand_slice_fun: Callable = (self._expand_slice_prefill - if self.is_prefill else + if get_model_context("is_prefill") else self._expand_slice_decode) - expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_inputs) + return expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_inputs) - def _apply_shrink(self, y, x, w_t_all, scale): + def _apply_shrink(self, x, w_t_all, scale): """ Perform the ` y+=x@w_t_all` computation, which is suitable for the GEMM of lora'a. @@ -151,14 +157,51 @@ class PunicaWrapperNPU(PunicaWrapperBase): Otherwise, it is the decode stage, and the _shrink_decode function should be called. """ - y_org = y - y = y.view(-1, y.shape[-1]) + # y_org = y + # y = y.view(-1, y.shape[-1]) shrink_fun: Callable = (self._shrink_prefill - if self.is_prefill else self._shrink_decode) - shrink_fun(y, x, w_t_all, scale) - y.view_as(y_org) + if get_model_context("is_prefill") else self._shrink_decode) + y = shrink_fun(x, w_t_all, scale) + return y + + def update_metadata( + self, + mapping: "LoRAMapping", + lora_index_to_id: list[Optional[int]], + max_loras: int, + vocab_size: int, + extra_vocab_size: int, + long_lora_context: Optional["LongContextLoRAContext"] = None, + **kwargs): + self._update_base_metadata(mapping, lora_index_to_id, max_loras, + vocab_size, extra_vocab_size, + long_lora_context) + if mapping.is_prefill: + # Update metadata required for prefill-related operators. 
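+            # The per-LoRA token counts computed below become the group_list
+            # consumed by grouped_matmul_v4 (one entry per active adapter).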
+ self._update_prefill_metadata(self.token_lora_indices) + self.is_prefill = True + else: + self.is_prefill = False + _, seq_len, lora_indices, _, _, _ = self.prefill_metadata + sorted_ids, sorted_counts = sort_lora_by_token_count( + lora_indices, seq_len) + group_list = sorted_counts + self.group_list.set_data(group_list.astype(dtype.int64)) + # if len(lora_indices) > self.max_loras: + # self.group_list.set_data(seq_len[:self.max_loras].astype(dtype.int64)) + # self.lora_indices.set_data(lora_indices[:self.max_loras].astype(dtype.int64)) + # elif len(lora_indices) < self.max_loras: + # pad_len = int(self.max_loras - len(lora_indices)) + # lora_indices = ops.pad(lora_indices, (0, pad_len), mode='constant', value=0) + # seq_len = ops.pad(seq_len, (0, pad_len), mode='constant', value=0) + # self.group_list.set_data(seq_len.astype(dtype.int64)) + # self.lora_indices.set_data(lora_indices.astype(dtype.int64)) + # else: + # self.group_list.set_data(seq_len.astype(dtype.int64)) + # self.lora_indices.set_data(lora_indices.astype(dtype.int64)) + - def add_shrink(self, y, x, lora_a_stacked, scale, **kwargs): + def add_shrink(self, x, lora_a_stacked, scale, **kwargs): """ Performs GEMM for multiple slices of lora_a. When `is_prefill is` true, it indicates that it is currently the @@ -179,9 +222,10 @@ class PunicaWrapperNPU(PunicaWrapperBase): x = x.view(-1, x.shape[-1]) # TODO fuse these kernels - for slice_idx in range(len(lora_a_stacked)): - self._apply_shrink(y[slice_idx], x, lora_a_stacked[slice_idx], - scale) + # for slice_idx in range(len(lora_a_stacked)): + y = self._apply_shrink(x, lora_a_stacked, + scale) + return y def add_expand(self, y, @@ -214,20 +258,20 @@ class PunicaWrapperNPU(PunicaWrapperBase): y_org = y y = y.view(-1, y.shape[-1]) offset_left = offset_start - if lora_bias_stacked is not None: - self._apply_bias(self.token_lora_indices, y, output_slices, - lora_bias_stacked) - for slice_idx in range(len(lora_b_stacked)): - self._apply_expand( - y, - x[slice_idx], - lora_b_stacked[slice_idx], - offset_left, - output_slices[slice_idx], - add_inputs=add_inputs, - ) - offset_left += output_slices[slice_idx] - y.view_as(y_org) + # if lora_bias_stacked is not None: + # self._apply_bias(self.token_lora_indices, y, output_slices, + # lora_bias_stacked) + # for slice_idx in range(len(lora_b_stacked)): + y = self._apply_expand( + y, + x, + lora_b_stacked, + offset_left, + output_slices, + add_inputs=add_inputs, + ) + #offset_left += output_slices[slice_idx] + return y.view_as(y_org) def add_lora_embedding(self, y, @@ -292,27 +336,28 @@ class PunicaWrapperNPU(PunicaWrapperBase): if self.no_lora: return x = x.reshape(-1, x.shape[-1]) - assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices) - if lora_bias_stacked is not None: - assert len(lora_bias_stacked) == len(output_slices) - y = self._apply_bias(self.token_lora_indices, y, output_slices, - lora_bias_stacked) + # assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices) + # if lora_bias_stacked is not None: + # assert len(lora_bias_stacked) == len(output_slices) + # y = self._apply_bias(self.token_lora_indices, y, output_slices, + # lora_bias_stacked) - if buffer is None: - r = lora_b_stacked[0].shape[-1] - # We set the buffer to be float32 by default, consistent with the - # triton op - buffer = tuple( - mint.zeros((x.shape[0], r), dtype=mstype.float32) - for _ in range(len(output_slices))) - self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs) - self.add_expand(y, + # if buffer is None: + # r = 
lora_b_stacked[0].shape[-1] + # # We set the buffer to be float32 by default, consistent with the + # # triton op + # buffer = tuple( + # mint.zeros((x.shape[0], r), dtype=mstype.float32) + # for _ in range(len(output_slices))) + buffer = self.add_shrink(x, lora_a_stacked, scale, **kwargs) + y = self.add_expand(y, buffer, lora_b_stacked, None, output_slices, add_inputs=True, **kwargs) + return y def add_lora_logits(self, y, @@ -357,3 +402,6 @@ class PunicaWrapperNPU(PunicaWrapperBase): self.sampler_indices, add_inputs=True) y.view_as(y_org) + + def construct(self, *args, **kwargs): + return self.add_lora_linear(*args, **kwargs) diff --git a/vllm_mindspore/lora/utils.py b/vllm_mindspore/lora/utils.py index 0a96b555..a4bc9389 100644 --- a/vllm_mindspore/lora/utils.py +++ b/vllm_mindspore/lora/utils.py @@ -50,3 +50,21 @@ _all_lora_classes: set[type[BaseLayerWithLoRA]] = { RowParallelLinearWithShardedLoRA, LinearScalingRotaryEmbeddingWithLoRA, } + +def replace_submodule(model, module_name, new_module): + """Replace a submodule in a model with a new module.""" + parent = model.get_submodule(".".join(module_name.split(".")[:-1])) + target_name = module_name.split(".")[-1] + weight_name = module_name + ".weight" + lora_a_name = module_name + ".lora_a_weight" + lora_b_name = module_name + ".lora_b_weight" + lora_bias_name = module_name + ".lora_bias" + bias_name = module_name + ".bias" + setattr(parent, target_name, new_module) + new_module.base_layer.weight.name = weight_name + new_module.lora_a_stacked.name = lora_a_name + new_module.lora_b_stacked.name = lora_b_name + if new_module.base_layer.bias is not None: + new_module.base_layer.bias.name = bias_name + new_module.lora_bias_stacked.name = lora_bias_name + return new_module diff --git a/vllm_mindspore/model_executor/models/model_base.py b/vllm_mindspore/model_executor/models/model_base.py index 6a73e171..c8c8f79a 100644 --- a/vllm_mindspore/model_executor/models/model_base.py +++ b/vllm_mindspore/model_executor/models/model_base.py @@ -394,9 +394,9 @@ class NativeModel(MsModelBase): def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: super().__init__(vllm_config=vllm_config, prefix=prefix) self.quant_config = vllm_config.quant_config - if vllm_config.lora_config is not None: - # native model lora only support pynative mode now - vllm_config.model_config.enforce_eager = True + # if vllm_config.lora_config is not None: + # # native model lora only support pynative mode now + # vllm_config.model_config.enforce_eager = True self.is_eager_mode = vllm_config.model_config.enforce_eager self.prefill_graph = None self.decode_graph = None -- Gitee From 60d21b8531d4bdc9a2e5a147914dd67b5a07b702 Mon Sep 17 00:00:00 2001 From: huangzhuo Date: Tue, 12 Aug 2025 16:37:11 +0800 Subject: [PATCH 09/12] fix bug --- vllm_mindspore/lora/layers.py | 41 ++++--------------- vllm_mindspore/lora/ops/torch_ops/lora_ops.py | 14 ++++--- .../lora/punica_wrapper/punica_npu.py | 23 ++++++----- vllm_mindspore/lora/utils.py | 2 +- .../model_executor/model_loader/utils.py | 7 +--- .../model_loader/weight_utils.py | 10 +---- .../mf_models/deepseekv3_weight_processor.py | 13 ------ .../model_executor/models/model_base.py | 1 + vllm_mindspore/v1/worker/gpu_model_runner.py | 1 - vllm_mindspore/worker/worker.py | 1 - 10 files changed, 36 insertions(+), 77 deletions(-) diff --git a/vllm_mindspore/lora/layers.py b/vllm_mindspore/lora/layers.py index bff4d3f7..c6ec9906 100644 --- a/vllm_mindspore/lora/layers.py +++ b/vllm_mindspore/lora/layers.py @@ -353,15 +353,9 
@@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA): lora_config.lora_dtype)) if lora_config.bias_enabled: lora_bias_out_size = lora_b_out_size - self.lora_bias_stacked = tuple( - mint.zeros( - ( - max_loras, - 1, - lora_bias_out_size, - ), - dtype=lora_config.lora_dtype, - ) for _ in range(self.n_slices)) + self.lora_bias_stacked = Parameter( + initializer('zeros', (self.n_slices, max_loras, 1, lora_bias_out_size), + lora_config.lora_dtype)) self.output_slices = (self.lora_b_stacked[0].shape[2], ) def reset_lora(self, index: int): @@ -414,7 +408,7 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA): x: ms.Tensor, bias: Optional[ms.Tensor] = None) -> ms.Tensor: output = self.base_layer.quant_method.apply(self.base_layer, x, bias) - self.punica_wrapper.add_lora_linear(output, x, self.lora_a_stacked[0], + output = self.punica_wrapper(output, x, self.lora_a_stacked[0], self.lora_b_stacked[0], self.lora_bias_stacked, 1.0, self.output_slices) @@ -550,7 +544,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): model_config: Optional[PretrainedConfig] = None, ) -> None: """ - The main reason for overriding this function is to enhance code + The main reason for overriding this function is to enhance code maintainability. """ self.lora_config = lora_config @@ -577,26 +571,9 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): initializer('zeros', (self.n_slices, max_loras, 1, self.output_slices[0], lora_config.max_lora_rank), lora_config.lora_dtype)) if lora_config.bias_enabled: - self.lora_bias_stacked = tuple( - mint.zeros( - ( - max_loras, - 1, - output_size, - lora_config.max_lora_rank, - ), - dtype=lora_config.lora_dtype, - ) for output_size in self.output_slices) - if lora_config.bias_enabled: - self.lora_bias_stacked = tuple( - mint.zeros( - ( - max_loras, - 1, - output_size, - ), - dtype=lora_config.lora_dtype, - ) for output_size in self.output_slices) + self.lora_bias_stacked = Parameter( + initializer('zeros', (self.n_slices, max_loras, 1, self.output_slices[0]), + lora_config.lora_dtype)) def slice_lora_a( self, lora_a: list[Union[ms.Tensor, @@ -802,7 +779,7 @@ class MergedQKVParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA): model_config: Optional[PretrainedConfig] = None, ) -> None: """ - The main reason for overloading this function is to handle inconsistent + The main reason for overloading this function is to handle inconsistent weight dimensions in qkv lora. 
""" super().create_lora_weights(max_loras, lora_config, model_config) diff --git a/vllm_mindspore/lora/ops/torch_ops/lora_ops.py b/vllm_mindspore/lora/ops/torch_ops/lora_ops.py index 92d58e61..21fd83a2 100644 --- a/vllm_mindspore/lora/ops/torch_ops/lora_ops.py +++ b/vllm_mindspore/lora/ops/torch_ops/lora_ops.py @@ -113,23 +113,25 @@ def sgmv_shrink( # group_type=0, # group_list_type=1) #outputs = bgmv_shrink(inputs, lora_a_weights, 0, scaling) + #group_list = seq_len_tensor + #lora_a_weights = lora_a_weights[lora_indices_tensor] lora_a_weights = lora_a_weights.squeeze(1) lora_a_weights = mint.transpose(lora_a_weights, 1, 2) outputs = grouped_matmul_v4([inputs], [lora_a_weights], group_list=group_list, split_item=3, group_type=0, - group_list_type=1) + group_list_type=1)[0] # output_tensor[:, :outputs.shape[1]] = scaling * outputs[:] return outputs * scaling def bgmv_shrink(inputs, - lora_b_weights, + lora_a_weights, lora_indices_tensor, scaling=1.0): - selected_loras = lora_b_weights[lora_indices_tensor] - inputs = inputs.astype(lora_b_weights[0].dtype) + selected_loras = lora_a_weights[lora_indices_tensor] + inputs = inputs.astype(lora_a_weights[0].dtype) # if len(selected_loras.shape) == 4: selected_loras = selected_loras.squeeze(1) outputs = einsum_ms(inputs, selected_loras) * scaling @@ -175,13 +177,15 @@ def sgmv_expand_slice(inputs, # else: # output_tensor = outputs[:] #output_tensor = bgmv_expand_slice(inputs, lora_b_weights, output_tensor, 0, 0, 0) + #group_list = seq_len_tensor + #lora_b_weights = lora_b_weights[lora_indices_tensor] lora_b_weights = lora_b_weights.squeeze(1) lora_b_weights = mint.transpose(lora_b_weights, 1, 2) outputs = grouped_matmul_v4([inputs], [lora_b_weights], group_list=group_list, split_item=3, group_type=0, - group_list_type=1) + group_list_type=1)[0] output_tensor = ops.add(output_tensor, outputs) return output_tensor diff --git a/vllm_mindspore/lora/punica_wrapper/punica_npu.py b/vllm_mindspore/lora/punica_wrapper/punica_npu.py index 8167c670..6531659a 100644 --- a/vllm_mindspore/lora/punica_wrapper/punica_npu.py +++ b/vllm_mindspore/lora/punica_wrapper/punica_npu.py @@ -147,6 +147,7 @@ class PunicaWrapperNPU(PunicaWrapperBase, nn.Cell): if get_model_context("is_prefill") else self._expand_slice_decode) return expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_inputs) + #return self._expand_slice_prefill(y, x, w_t_all, y_offset, y_slice_size, add_inputs) def _apply_shrink(self, x, w_t_all, scale): """ @@ -161,8 +162,8 @@ class PunicaWrapperNPU(PunicaWrapperBase, nn.Cell): # y = y.view(-1, y.shape[-1]) shrink_fun: Callable = (self._shrink_prefill if get_model_context("is_prefill") else self._shrink_decode) - y = shrink_fun(x, w_t_all, scale) - return y + #y = shrink_fun(x, w_t_all, scale) + return shrink_fun(x, w_t_all, scale) def update_metadata( self, @@ -176,16 +177,20 @@ class PunicaWrapperNPU(PunicaWrapperBase, nn.Cell): self._update_base_metadata(mapping, lora_index_to_id, max_loras, vocab_size, extra_vocab_size, long_lora_context) - if mapping.is_prefill: - # Update metadata required for prefill-related operators. - self._update_prefill_metadata(self.token_lora_indices) - self.is_prefill = True - else: - self.is_prefill = False + # if mapping.is_prefill: + # Update metadata required for prefill-related operators. 
+ self._update_prefill_metadata(self.token_lora_indices) + self.is_prefill = True + # else: + # self.is_prefill = False _, seq_len, lora_indices, _, _, _ = self.prefill_metadata sorted_ids, sorted_counts = sort_lora_by_token_count( lora_indices, seq_len) group_list = sorted_counts + if len(group_list) < self.max_loras: + new_tensor = mint.zeros(self.max_loras, dtype=group_list.dtype) + new_tensor[:group_list.size(0)] = group_list + group_list = new_tensor self.group_list.set_data(group_list.astype(dtype.int64)) # if len(lora_indices) > self.max_loras: # self.group_list.set_data(seq_len[:self.max_loras].astype(dtype.int64)) @@ -261,7 +266,6 @@ class PunicaWrapperNPU(PunicaWrapperBase, nn.Cell): # if lora_bias_stacked is not None: # self._apply_bias(self.token_lora_indices, y, output_slices, # lora_bias_stacked) - # for slice_idx in range(len(lora_b_stacked)): y = self._apply_expand( y, x, @@ -270,7 +274,6 @@ class PunicaWrapperNPU(PunicaWrapperBase, nn.Cell): output_slices, add_inputs=add_inputs, ) - #offset_left += output_slices[slice_idx] return y.view_as(y_org) def add_lora_embedding(self, diff --git a/vllm_mindspore/lora/utils.py b/vllm_mindspore/lora/utils.py index a4bc9389..1fb215e0 100644 --- a/vllm_mindspore/lora/utils.py +++ b/vllm_mindspore/lora/utils.py @@ -66,5 +66,5 @@ def replace_submodule(model, module_name, new_module): new_module.lora_b_stacked.name = lora_b_name if new_module.base_layer.bias is not None: new_module.base_layer.bias.name = bias_name - new_module.lora_bias_stacked.name = lora_bias_name + #new_module.lora_bias_stacked.name = lora_bias_name return new_module diff --git a/vllm_mindspore/model_executor/model_loader/utils.py b/vllm_mindspore/model_executor/model_loader/utils.py index c3bf3dd1..02ba4fca 100644 --- a/vllm_mindspore/model_executor/model_loader/utils.py +++ b/vllm_mindspore/model_executor/model_loader/utils.py @@ -18,14 +18,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
""" utils for load model """ -<<<<<<< HEAD -from mindspore import nn -======= import numpy as np import torch from torch import nn +from mindspore import nn as ms_nn from vllm.attention import Attention ->>>>>>> 0cd4cd6 (support qwq) from vllm.config import ModelConfig from vllm.model_executor.models import ModelRegistry @@ -36,7 +33,7 @@ from vllm_mindspore.model_executor.models.registry import ( def get_ms_model_architecture( - model_config: ModelConfig) -> tuple[type[nn.Cell], str]: + model_config: ModelConfig) -> tuple[type[ms_nn.Cell], str]: architectures = getattr(model_config.hf_config, "architectures", []) if is_mf_mcore_archs(architectures): architectures.append("MindFormersForCausalLM") diff --git a/vllm_mindspore/model_executor/model_loader/weight_utils.py b/vllm_mindspore/model_executor/model_loader/weight_utils.py index 1e162829..8e893cf5 100644 --- a/vllm_mindspore/model_executor/model_loader/weight_utils.py +++ b/vllm_mindspore/model_executor/model_loader/weight_utils.py @@ -70,7 +70,7 @@ def split_loaded_weight(loaded_weight, shard_dim, start_idx, shard_size): if (str(loaded_weight.dtype) == 'bfloat16' and is_310p()) else loaded_weight ) - return loaded_weight + return loaded_weight.asnumpy() def safetensors_weights_iterator( @@ -86,20 +86,12 @@ def safetensors_weights_iterator( ): with safe_open(st_file, framework="np") as f: for name in f.keys(): # noqa: SIM118 -<<<<<<< HEAD - # TODO: use slice - x = f.get_tensor(name) - x = x.astype(np.float16) \ - if (str(x.dtype) == 'bfloat16' and is_310p()) else x - yield name, ms.tensor(x) -======= # Return a lightweight PySafeSlice object that uses file # pointer offset internally to read Safetensor on demand, # avoiding memory explosion. Actual data can be obtained # through slicing operation like param[start:end] param = f.get_slice(name) yield name, param ->>>>>>> 8858529 (fix new branch) def default_weight_loader(param: Parameter, loaded_weight: Any) -> None: diff --git a/vllm_mindspore/model_executor/models/mf_models/deepseekv3_weight_processor.py b/vllm_mindspore/model_executor/models/mf_models/deepseekv3_weight_processor.py index a8aa7127..73bc3027 100644 --- a/vllm_mindspore/model_executor/models/mf_models/deepseekv3_weight_processor.py +++ b/vllm_mindspore/model_executor/models/mf_models/deepseekv3_weight_processor.py @@ -402,19 +402,6 @@ class DeepseekV3WeightProcessor(BaseWeightProcessor): w2_scale_ms_stack_param = np.stack(w2_scale_list, axis=0) w3_scale_ms_stack_param = np.stack(w3_scale_list, axis=0) -<<<<<<< HEAD -======= - if self.is_310p: - weight_scale_dtype = ms.float32 - weight_concat_axis = 2 - w1_ms_stack_param = w1_ms_stack_param.transpose(0, 2, 1) - w2_ms_stack_param = w2_ms_stack_param.transpose(0, 2, 1) - w3_ms_stack_param = w3_ms_stack_param.transpose(0, 2, 1) - else: - weight_scale_dtype = ms.bfloat16 - weight_concat_axis = 1 - ->>>>>>> f3c366e (change atlas_inference to is_310p) if ffn_concat: # w_gate_hidden w_gate_hidden_name = f"{base_path}.w_gate_hidden._layer.weight" diff --git a/vllm_mindspore/model_executor/models/model_base.py b/vllm_mindspore/model_executor/models/model_base.py index c8c8f79a..d4f87a6c 100644 --- a/vllm_mindspore/model_executor/models/model_base.py +++ b/vllm_mindspore/model_executor/models/model_base.py @@ -136,6 +136,7 @@ class MsModelBase: config = vllm_config.model_config.hf_config lora_config = vllm_config.lora_config + self.vllm_config = vllm_config self.config = config self.model_config = vllm_config.model_config self.lora_config = lora_config diff --git 
a/vllm_mindspore/v1/worker/gpu_model_runner.py b/vllm_mindspore/v1/worker/gpu_model_runner.py index d40efe70..88ff271e 100644 --- a/vllm_mindspore/v1/worker/gpu_model_runner.py +++ b/vllm_mindspore/v1/worker/gpu_model_runner.py @@ -335,7 +335,6 @@ def _reshape_kv_cache_tensors( for i in range(len(kv_cache_stride_order)) ] kv_cache_layer = [] -<<<<<<< HEAD for idx, kv_cache_raw_tensor in enumerate( kv_cache_raw_tensors[layer_name]): if use_mla_op: diff --git a/vllm_mindspore/worker/worker.py b/vllm_mindspore/worker/worker.py index ad6a0ba3..9c01e6d1 100644 --- a/vllm_mindspore/worker/worker.py +++ b/vllm_mindspore/worker/worker.py @@ -18,7 +18,6 @@ import math import subprocess import os -import subprocess import psutil import torch -- Gitee From 9b50e1edd6db619c619fb67141422f3d48539a00 Mon Sep 17 00:00:00 2001 From: huangzhuo Date: Sat, 16 Aug 2025 17:32:53 +0800 Subject: [PATCH 10/12] fix mutilora bug --- .../distributed/communication_op.py | 177 ++++++++++-------- vllm_mindspore/lora/layers.py | 11 +- .../model_executor/layers/logits_processor.py | 121 +++++++++--- .../model_loader/weight_utils.py | 2 +- vllm_mindspore/model_executor/models/qwen2.py | 8 +- 5 files changed, 205 insertions(+), 114 deletions(-) diff --git a/vllm_mindspore/distributed/communication_op.py b/vllm_mindspore/distributed/communication_op.py index 475a282d..112cc52e 100644 --- a/vllm_mindspore/distributed/communication_op.py +++ b/vllm_mindspore/distributed/communication_op.py @@ -1,76 +1,101 @@ -# SPDX-License-Identifier: Apache-2.0 - -# Communication functions are adapted from -# https://github.com/vllm-project/vllm/blob/v0.8.3/vllm/distributed/communication_op.py -# -# Copyright 2025 Huawei Technologies Co., Ltd. -# Copyright 2024-2025 The vLLM team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Implement a unified communication interface for both graph and pynative mode. -""" - -from typing import Any, Dict, Optional, Union -import torch - -from mindspore import nn, ops -from vllm.distributed.parallel_state import ( - get_tensor_model_parallel_world_size, get_tp_group) - -def cpu_broadcast_tensor_dict(tensor_dict: Optional[Dict[Any, Union[torch.Tensor, - Any]]] = None, - src: int = 0): - if not torch.distributed.is_initialized(): - return tensor_dict - return get_tp_group().broadcast_tensor_dict(tensor_dict, src, group=get_tp_group().cpu_group) - - -class ReduceFromModelParallelRegion(nn.Cell): - "All reduce the input from the model parallel region." - - def __init__(self): - super().__init__() - self.world_size = get_tensor_model_parallel_world_size() - if self.world_size > 1: - self.tp_group = get_tp_group().device_group._name - self.all_reduce = ops.AllReduce(group=self.tp_group) - - def construct(self, input_): - if self.world_size == 1: - return input_ - output = self.all_reduce(input_) - return output - - -class AllGatherFromModelParallelRegion(nn.Cell): - """ - Gather the input from world parallel region and concatenate, - simultaneously perform transpose operation on input. 
- """ - - def __init__(self): - super().__init__() - self.world_size = get_tensor_model_parallel_world_size() - if self.world_size > 1: - self.tp_group = get_tp_group().device_group._name - self.all_gather_into_tensor = ops.AllGather(group=self.tp_group) - - def construct(self, input_): - # Size and dimension. - if self.world_size == 1: - return input_ - input_ = ops.swapaxes(input_, 0, -1) - output = self.all_gather_into_tensor(input_) - output = ops.swapaxes(output, 0, -1) - return output +# SPDX-License-Identifier: Apache-2.0 + +# Communication functions are adapted from +# https://github.com/vllm-project/vllm/blob/v0.8.3/vllm/distributed/communication_op.py +# +# Copyright 2025 Huawei Technologies Co., Ltd. +# Copyright 2024-2025 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Implement a unified communication interface for both graph and pynative mode. +""" + +from typing import Any, Dict, Optional, Union +import torch + +from mindspore import Tensor, mint, nn, ops +from vllm.distributed.parallel_state import ( + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, + get_tp_group) + +def cpu_broadcast_tensor_dict(tensor_dict: Optional[Dict[Any, Union[torch.Tensor, + Any]]] = None, + src: int = 0): + if not torch.distributed.is_initialized(): + return tensor_dict + return get_tp_group().broadcast_tensor_dict(tensor_dict, src, group=get_tp_group().cpu_group) + + +class ReduceFromModelParallelRegion(nn.Cell): + "All reduce the input from the model parallel region." + + def __init__(self): + super().__init__() + self.world_size = get_tensor_model_parallel_world_size() + if self.world_size > 1: + self.tp_group = get_tp_group().device_group._name + self.all_reduce = ops.AllReduce(group=self.tp_group) + + def construct(self, input_): + if self.world_size == 1: + return input_ + output = self.all_reduce(input_) + return output + + +class AllGatherFromModelParallelRegion(nn.Cell): + """ + Gather the input from world parallel region and concatenate, + simultaneously perform transpose operation on input. + """ + + def __init__(self): + super().__init__() + self.world_size = get_tensor_model_parallel_world_size() + if self.world_size > 1: + self.tp_group = get_tp_group().device_group._name + self.all_gather_into_tensor = ops.AllGather(group=self.tp_group) + + def construct(self, input_): + # Size and dimension. + if self.world_size == 1: + return input_ + input_ = ops.swapaxes(input_, 0, -1) + output = self.all_gather_into_tensor(input_) + output = ops.swapaxes(output, 0, -1) + return output + + +class GatherFromModelParallelRegion(nn.Cell): + "Gather the input from model parallel region and concatenate." + + def __init__(self): + super().__init__() + self.world_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + if self.world_size > 1: + self.tp_group = get_tp_group().device_group._name + + def construct(self, + input_: Tensor, + dst: int = 0, + dim: int = -1) -> Optional[Tensor]: + # Size and dimension. 
+ if self.world_size == 1: + return input_ + output = ops.CollectiveGather(dest_rank=dst, + group=self.tp_group)(mint.transpose(input_, 0, dim)) + if self.tp_rank != dst: + return None + return mint.transpose(output, 0, dim) \ No newline at end of file diff --git a/vllm_mindspore/lora/layers.py b/vllm_mindspore/lora/layers.py index c6ec9906..7008a20d 100644 --- a/vllm_mindspore/lora/layers.py +++ b/vllm_mindspore/lora/layers.py @@ -33,8 +33,7 @@ from vllm.config import LoRAConfig from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, split_tensor_along_last_dim, - tensor_model_parallel_all_gather, - tensor_model_parallel_all_reduce) + tensor_model_parallel_all_gather) from vllm.distributed.utils import divide # yapf: enable from vllm.model_executor.layers.logits_processor import LogitsProcessor @@ -356,13 +355,15 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA): self.lora_bias_stacked = Parameter( initializer('zeros', (self.n_slices, max_loras, 1, lora_bias_out_size), lora_config.lora_dtype)) + else: + self.lora_bias_stacked = None self.output_slices = (self.lora_b_stacked[0].shape[2], ) def reset_lora(self, index: int): for s_index in range(self.n_slices): self.lora_a_stacked[s_index][index] = 0 self.lora_b_stacked[s_index][index] = 0 - if self.lora_config.bias_enabled: + if self.lora_bias_stacked: # Make mypy happy self.lora_bias_stacked = cast(tuple[ms.Tensor, ...], self.lora_bias_stacked) @@ -396,7 +397,7 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA): self.lora_b_stacked[0][index, 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( lora_b.T, non_blocking=True) - if lora_bias is not None: + if self.lora_bias_stacked is not None: self.lora_bias_stacked = cast(tuple[ms.Tensor, ...], self.lora_bias_stacked) @@ -874,7 +875,7 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA): # Matrix multiply. output_parallel = self.apply(input_parallel) if self.base_layer.reduce_results and self.base_layer.tp_size > 1: - output_ = tensor_model_parallel_all_reduce(output_parallel) + output_ = self.base_layer.tensor_model_parallel_all_reduce(output_parallel) else: output_ = output_parallel diff --git a/vllm_mindspore/model_executor/layers/logits_processor.py b/vllm_mindspore/model_executor/layers/logits_processor.py index ee8c8edc..6910804a 100644 --- a/vllm_mindspore/model_executor/layers/logits_processor.py +++ b/vllm_mindspore/model_executor/layers/logits_processor.py @@ -23,12 +23,14 @@ from concurrent.futures import ThreadPoolExecutor from typing import Optional import vllm.envs as envs -from mindspore import Tensor, mint, nn -from vllm.config import current_platform +from mindspore import Tensor, jit, mint, nn +from vllm.config import current_platform, get_current_vllm_config from vllm.distributed import (tensor_model_parallel_all_gather, tensor_model_parallel_gather) from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm_mindspore.distributed.communication_op import ( + AllGatherFromModelParallelRegion, GatherFromModelParallelRegion) from vllm_mindspore.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) @@ -60,6 +62,9 @@ class LogitsProcessor(nn.Cell): scale: A scaling factor to apply to the logits. """ super().__init__() + vllm_config = get_current_vllm_config() + self.vllm_config = vllm_config + self.is_graph_mode = bool(not vllm_config.model_config.enforce_eager) self.scale = scale self.vocab_size = vocab_size # Whether the input is logits (default is hidden states). 
@@ -71,25 +76,101 @@ class LogitsProcessor(nn.Cell): # Whether to use gather or all-gather to gather the logits. self.use_all_gather = current_platform.use_all_gather() + if self.use_all_gather: + self.tensor_model_parallel_all_gather = AllGatherFromModelParallelRegion() + else: + self.tensor_model_parallel_gather = GatherFromModelParallelRegion() + self.lm_head = None + self.run_model = None + self.cached_input_info = {} + + def set_dynamic_inputs(self): + dyn_hidden_states = Tensor(shape=[None, None], + dtype=self.vllm_config.model_config.dtype) + + if self.cached_input_info["indices"] is None: + dyn_indices = None + else: + dyn_indices_shape = [ + None for _ in range(self.cached_input_info["indices"]["ndim"]) + ] + dyn_indices_dtype = self.cached_input_info["indices"]["dtype"] + dyn_indices = Tensor(shape=dyn_indices_shape, + dtype=dyn_indices_dtype) + + if self.cached_input_info["bias"] is None: + dyn_bias = None + else: + dyn_bias_shape = [ + None for _ in range(self.cached_input_info["bias"]["ndim"]) + ] + dyn_bias_dtype = self.cached_input_info["bias"]["dtype"] + dyn_bias = Tensor(shape=dyn_bias_shape, dtype=dyn_bias_dtype) + + self.set_inputs(dyn_hidden_states, dyn_indices, dyn_bias) + + def __call__( + self, + lm_head: VocabParallelEmbedding, + hidden_states: Tensor, + sampling_metadata: Optional[SamplingMetadata] = None, + embedding_bias: Optional[Tensor] = None, + ) -> Optional[Tensor]: + if self.lm_head is None: + self.lm_head = lm_head + if self.run_model is None: + self.run_model = jit( + function=self.construct, + jit_level='O0') if self.is_graph_mode else self.construct + selected_token_indices = None + if sampling_metadata is not None: + selected_token_indices = sampling_metadata.selected_token_indices + dyn_indices_info = None if selected_token_indices is None else { + "ndim": selected_token_indices.ndim, + "dtype": selected_token_indices.dtype, + } + dyn_bias_info = None if embedding_bias is None else { + "ndim": embedding_bias.ndim, + "dtype": embedding_bias.dtype, + } + if self.cached_input_info != {"indices": dyn_indices_info, + "bias": dyn_bias_info}: + self.cached_input_info = { + "indices": dyn_indices_info, + "bias": dyn_bias_info, + } + self.set_dynamic_inputs() + + logits = self.run_model( + hidden_states, + selected_token_indices, + embedding_bias + ) + + if sampling_metadata is not None and \ + sampling_metadata.seq_groups is not None: + logits = _apply_logits_processors(logits, sampling_metadata) + + return logits + def construct( self, - lm_head: VocabParallelEmbedding, hidden_states: Tensor, - sampling_metadata: Optional[SamplingMetadata] = None, + selected_token_indices: Optional[Tensor] = None, embedding_bias: Optional[Tensor] = None, ) -> Optional[Tensor]: if self.logits_as_input: logits = hidden_states else: - if sampling_metadata is not None: - if sampling_metadata.selected_token_indices.numel() <= 0: - return mint.zeros((0, self.vocab_size), - dtype=hidden_states.dtype) - hidden_states = _prune_hidden_states(hidden_states, - sampling_metadata) + if selected_token_indices is not None: + if selected_token_indices.numel() <= 0: + return mint.zeros((0, self.vocab_size), dtype=hidden_states.dtype) + hidden_states = mint.index_select( + hidden_states, 0, selected_token_indices) # Get the logits for the next tokens. 
- logits = self._get_logits(hidden_states, lm_head, embedding_bias) + logits = self._get_logits( + hidden_states, self.lm_head, embedding_bias) if logits is not None: if self.soft_cap is not None: logits = logits / self.soft_cap @@ -100,9 +181,6 @@ class LogitsProcessor(nn.Cell): logits *= self.scale # Apply logits processors (if any). - if sampling_metadata is not None and \ - sampling_metadata.seq_groups is not None: - logits = _apply_logits_processors(logits, sampling_metadata) return logits @@ -118,10 +196,10 @@ class LogitsProcessor(nn.Cell): bias=embedding_bias) if self.use_all_gather: # Gather is not supported for some devices such as NPUs. - logits = tensor_model_parallel_all_gather(logits) + logits = self.tensor_model_parallel_all_gather(logits) else: # None may be returned for rank > 0 - logits = tensor_model_parallel_gather(logits) + logits = self.tensor_model_parallel_gather(logits) # Remove paddings in vocab (if any). if logits is not None: logits = logits[..., :self.org_vocab_size] @@ -134,17 +212,6 @@ class LogitsProcessor(nn.Cell): return s -def _prune_hidden_states( - hidden_states: Tensor, - sampling_metadata: SamplingMetadata, -) -> Tensor: - indices = sampling_metadata.selected_token_indices - if indices is not None and indices.numel() > 0: - return mint.index_select(hidden_states, 0, - sampling_metadata.selected_token_indices) - return hidden_states - - def _apply_logits_processors( logits: Tensor, sampling_metadata: SamplingMetadata, diff --git a/vllm_mindspore/model_executor/model_loader/weight_utils.py b/vllm_mindspore/model_executor/model_loader/weight_utils.py index 8e893cf5..3642f234 100644 --- a/vllm_mindspore/model_executor/model_loader/weight_utils.py +++ b/vllm_mindspore/model_executor/model_loader/weight_utils.py @@ -70,7 +70,7 @@ def split_loaded_weight(loaded_weight, shard_dim, start_idx, shard_size): if (str(loaded_weight.dtype) == 'bfloat16' and is_310p()) else loaded_weight ) - return loaded_weight.asnumpy() + return loaded_weight def safetensors_weights_iterator( diff --git a/vllm_mindspore/model_executor/models/qwen2.py b/vllm_mindspore/model_executor/models/qwen2.py index 6891aea4..b20a9a4b 100644 --- a/vllm_mindspore/model_executor/models/qwen2.py +++ b/vllm_mindspore/model_executor/models/qwen2.py @@ -380,6 +380,7 @@ class Qwen2Model(nn.Cell): continue param = params_dict[name] + loaded_weight = ms.Tensor(loaded_weight[:], dtype=param.dtype) param.set_data(loaded_weight.contiguous()) def adjust_weight(params_dict): @@ -391,7 +392,7 @@ class Qwen2Model(nn.Cell): "o_proj.weight", "gate_up_proj.weight", "down_proj.weight", - # "lm_head.weight", + "lm_head.weight", ] rank_id = get_rank() @@ -461,9 +462,6 @@ class Qwen2Model(nn.Cell): param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) - # Norm type in weights may be f32 - if(loaded_weight.dtype != param.dtype): - loaded_weight = loaded_weight.to(dtype=param.dtype) weight_loader(param, loaded_weight) loaded_params.add(name) @@ -476,7 +474,7 @@ class Qwen2Model(nn.Cell): "o_proj.weight", "gate_up_proj.weight", "down_proj.weight", - # "lm_head.weight", + "lm_head.weight", ] for name, param in params_dict.items(): -- Gitee From 632a0c35c6ade3dd4b148567d90f1ddedeaa1606 Mon Sep 17 00:00:00 2001 From: superxf Date: Tue, 5 Aug 2025 18:33:50 +0800 Subject: [PATCH 11/12] support 310p qwen3 mcore --- vllm_mindspore/__init__.py | 6 +- vllm_mindspore/config.py | 5 +- .../model_executor/models/mf_models/config.py | 16 ++- .../models/mf_models/mindformers.py | 14 +- 
.../model_executor/models/model_base.py | 39 +++--- vllm_mindspore/utils.py | 1 + vllm_mindspore/v1/worker/gpu_model_runner.py | 121 +++++++++++++++++- 7 files changed, 168 insertions(+), 34 deletions(-) diff --git a/vllm_mindspore/__init__.py b/vllm_mindspore/__init__.py index c361f65e..6fb1b1a4 100644 --- a/vllm_mindspore/__init__.py +++ b/vllm_mindspore/__init__.py @@ -387,14 +387,14 @@ from vllm_mindspore.v1.worker.gpu_model_runner import _update_states vllm.v1.worker.gpu_model_runner.GPUModelRunner._update_states = _update_states from vllm_mindspore.v1.worker.gpu_model_runner import ( - _allocate_kv_cache_tensors, - get_kv_cache_spec, -) + _allocate_kv_cache_tensors, get_kv_cache_spec, initialize_kv_cache_tensors) vllm.v1.worker.gpu_model_runner.GPUModelRunner._allocate_kv_cache_tensors = ( _allocate_kv_cache_tensors) vllm.v1.worker.gpu_model_runner.GPUModelRunner.get_kv_cache_spec = ( get_kv_cache_spec) +vllm.v1.worker.gpu_model_runner.GPUModelRunner.initialize_kv_cache_tensors = ( + initialize_kv_cache_tensors) from vllm_mindspore.v1.worker.gpu_model_runner import _reshape_kv_cache_tensors vllm.v1.worker.gpu_model_runner.GPUModelRunner._reshape_kv_cache_tensors = ( diff --git a/vllm_mindspore/config.py b/vllm_mindspore/config.py index f0376485..f6b13e4f 100644 --- a/vllm_mindspore/config.py +++ b/vllm_mindspore/config.py @@ -256,9 +256,8 @@ def _get_and_verify_dtype( if torch_dtype in _STR_DTYPE_TO_TORCH_DTYPE: torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[torch_dtype] - if is_310p() and torch_dtype == torch.bfloat16: - return torch.float16 - + if torch_dtype == torch.bfloat16 and is_310p(): + torch_dtype = torch.float16 return torch_dtype diff --git a/vllm_mindspore/model_executor/models/mf_models/config.py b/vllm_mindspore/model_executor/models/mf_models/config.py index a658d0b3..b552667d 100644 --- a/vllm_mindspore/model_executor/models/mf_models/config.py +++ b/vllm_mindspore/model_executor/models/mf_models/config.py @@ -19,6 +19,8 @@ from mindformers.tools.register.config import MindFormerConfig from vllm.config import VllmConfig from vllm.logger import init_logger +from vllm_mindspore.utils import is_310p + logger = init_logger(__name__) @@ -94,6 +96,14 @@ MF_MODEL_COMMON_MAPPING = { 'model.model_config.params_dtype': (None, 'bfloat16'), 'model.model_config.router_dense_type': (None, 'bfloat16'), } + +MF_MODEL_COMMON_MAPPING_310p = { + 'model.model_config.compute_dtype': ('model_config.hf_config.torch_dtype', 'float16'), + 'model.model_config.layernorm_compute_dtype': (None, 'float16'), + 'model.model_config.rotary_dtype': (None, 'float16'), + 'model.model_config.params_dtype': (None, 'float16'), + 'model.model_config.router_dense_type': (None, 'float16'), +} # yapf: enable # model default config @@ -211,7 +221,11 @@ def gen_mf_config(vllm_config: VllmConfig): target_config.set_value( 'model.model_config', MindFormerConfig(**gen_model_config_dict(vllm_config))) - transform_config(MF_MODEL_COMMON_MAPPING, vllm_config, target_config) + if is_310p(): + transform_config(MF_MODEL_COMMON_MAPPING_310p, vllm_config, + target_config) + else: + transform_config(MF_MODEL_COMMON_MAPPING, vllm_config, target_config) # Update target config with additional config. # The configuration hierarchy in the additional config must match the # hierarchy structure of the MindFormers YAML configuration file. 
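The last two hunks apply the same rule in two places: on 310P (Atlas 300I) hardware the patches treat bfloat16 as unusable and swap it for float16, both in the patched _get_and_verify_dtype and in the MF_MODEL_COMMON_MAPPING_310p branch chosen by gen_mf_config. A minimal sketch of that rule follows; the helper name pick_compute_dtype and the plain boolean device flag are assumptions for illustration, not the patched API.

import torch

def pick_compute_dtype(requested: torch.dtype, on_310p: bool) -> torch.dtype:
    # The patches consistently downgrade bf16 requests to float16 on 310P;
    # every other dtype passes through unchanged.
    if requested == torch.bfloat16 and on_310p:
        return torch.float16
    return requested

assert pick_compute_dtype(torch.bfloat16, on_310p=True) == torch.float16
assert pick_compute_dtype(torch.bfloat16, on_310p=False) == torch.bfloat16
assert pick_compute_dtype(torch.float32, on_310p=True) == torch.float32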
diff --git a/vllm_mindspore/model_executor/models/mf_models/mindformers.py b/vllm_mindspore/model_executor/models/mf_models/mindformers.py index 8a5d7b94..82a5f90c 100644 --- a/vllm_mindspore/model_executor/models/mf_models/mindformers.py +++ b/vllm_mindspore/model_executor/models/mf_models/mindformers.py @@ -49,6 +49,8 @@ class MindFormersForCausalLM(MsModelBase): self.set_flags = False self.max_model_len = vllm_config.model_config.max_model_len self.hf_config = vllm_config.model_config.hf_config + self.is_eager_mode = vllm_config.model_config.enforce_eager + self.lm_head_graph = None mf_config = gen_mf_config(vllm_config) mf_config.load_checkpoint = self.get_model_path() @@ -191,16 +193,20 @@ class MindFormersForCausalLM(MsModelBase): and selected_token_indices.numel() <= 0: logits = ms.mint.zeros((0, self.hf_config.vocab_size), dtype=self.hf_config.torch_dtype) + return logits else: hidden_states = hidden_states.reshape( (-1, hidden_states.shape[-1])) hidden_states = hidden_states.index_select( 0, selected_token_indices) - logits = self.lm_head(hidden_states) - logits = logits.view(-1, logits.shape[-1]) - else: + if self.is_eager_mode: logits = self.lm_head(hidden_states) - logits = logits.view(-1, logits.shape[-1]) + else: + if self.lm_head_graph is None: + self.lm_head_graph = ms.jit(function=self.lm_head, + jit_level="O0") + logits = self.lm_head_graph(hidden_states) + logits = logits.view(-1, logits.shape[-1]) return logits def sample( diff --git a/vllm_mindspore/model_executor/models/model_base.py b/vllm_mindspore/model_executor/models/model_base.py index d4f87a6c..6c4cf677 100644 --- a/vllm_mindspore/model_executor/models/model_base.py +++ b/vllm_mindspore/model_executor/models/model_base.py @@ -48,29 +48,28 @@ class AttentionWrapper: num_block = 0 if is_310p(): self.kv_shape = [num_block, block_size, num_kv_heads * head_size] - self.kv_cache = [ - ( - ops.auto_generate.format_cast( - ms.mint.zeros( - self.kv_shape, dtype=vllm_config.model_config.dtype - ), - FORMAT_TYPE['nz'], - ), - ops.auto_generate.format_cast( - ms.mint.zeros( - self.kv_shape, dtype=vllm_config.model_config.dtype - ), - FORMAT_TYPE['nz'], - ), - ) - for _ in range(vllm_config.parallel_config.pipeline_parallel_size) - ] + self.kv_cache = [( + ops.auto_generate.format_cast( + ms.mint.zeros(self.kv_shape, + dtype=vllm_config.model_config.dtype), + FORMAT_TYPE['nz'], + ), + ops.auto_generate.format_cast( + ms.mint.zeros(self.kv_shape, + dtype=vllm_config.model_config.dtype), + FORMAT_TYPE['nz'], + ), + ) for _ in range( + vllm_config.parallel_config.pipeline_parallel_size)] else: self.kv_shape = [num_block, block_size, num_kv_heads, head_size] self.kv_cache = [( - ms.mint.zeros(self.kv_shape, dtype=vllm_config.model_config.dtype), - ms.mint.zeros(self.kv_shape, dtype=vllm_config.model_config.dtype), - ) for _ in range(vllm_config.parallel_config.pipeline_parallel_size)] + ms.mint.zeros(self.kv_shape, + dtype=vllm_config.model_config.dtype), + ms.mint.zeros(self.kv_shape, + dtype=vllm_config.model_config.dtype), + ) for _ in range( + vllm_config.parallel_config.pipeline_parallel_size)] self.attn_type = AttentionType.DECODER diff --git a/vllm_mindspore/utils.py b/vllm_mindspore/utils.py index f86d754d..7f9222ee 100644 --- a/vllm_mindspore/utils.py +++ b/vllm_mindspore/utils.py @@ -60,6 +60,7 @@ FORMAT_TYPE = { "nz": 29, } + def get_valid_dtype(dtype): if isinstance(dtype, str): dtype = STR_DTYPE_TO_MS_DTYPE[dtype] diff --git a/vllm_mindspore/v1/worker/gpu_model_runner.py 
b/vllm_mindspore/v1/worker/gpu_model_runner.py index 88ff271e..3a1391cf 100644 --- a/vllm_mindspore/v1/worker/gpu_model_runner.py +++ b/vllm_mindspore/v1/worker/gpu_model_runner.py @@ -23,16 +23,19 @@ from typing import Any, Optional import mindspore as ms import numpy as np +import torch from mindspore import Generator as msGenerator from mindspore import Tensor, mint, mutable, ops from vllm.attention import AttentionType from vllm.logger import init_logger from vllm.sampling_params import SamplingType -from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheSpec, - SlidingWindowSpec) +from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, + KVCacheSpec, SlidingWindowSpec) from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.spec_decode.metadata import SpecDecodeMetadata +from vllm.v1.utils import bind_kv_cache from vllm.v1.worker.gpu_input_batch import CachedRequestState +from vllm.v1.worker.utils import initialize_kv_cache_for_kv_sharing from vllm_mindspore.model_executor.layers.rotary_embedding import ( InferMRotaryEmbedding as MRotaryEmbedding) @@ -210,6 +213,83 @@ def create_block(shape, dtype, name=None, device=None): return blocks +def _allocate_nz_kv_cache_tensors(self, kv_cache_config): + """ + Initializes and reshape the KV cache buffer with the correct size for 310p. + + Args: + kv_cache_config: The KV cache config + Returns: + dict[str, Tensor]: A map between layer names to their + corresponding memory buffer for KV cache. + """ + kv_caches: dict[str, tuple] = {} + + layer_to_group_info = { + layer_name: (i, group.kv_cache_spec) + for i, group in enumerate(kv_cache_config.kv_cache_groups) + for layer_name in group.layer_names + } + + use_mla_op = bool( + self.vllm_config.additional_config + and self.vllm_config.additional_config.get('use_mla_op') == 1) + if use_mla_op: + logger.error("For 310p, mla kv cache not supported") + raise NotImplementedError + + for kv_cache_tensor in kv_cache_config.kv_cache_tensors: + if not kv_cache_tensor.shared_by: + continue + + rep_layer_name = kv_cache_tensor.shared_by[0] + group_idx, kv_cache_spec = layer_to_group_info[rep_layer_name] + if not isinstance(kv_cache_spec, FullAttentionSpec): + raise NotImplementedError + + attn_backend = self.attn_backends[group_idx] + target_dtype = get_valid_dtype(kv_cache_spec.dtype) + coef = 2 + + num_blocks = kv_cache_tensor.size // kv_cache_spec.page_size_bytes + kv_cache_shape = attn_backend.get_kv_cache_shape( + num_blocks, kv_cache_spec.block_size, kv_cache_spec.num_kv_heads, + kv_cache_spec.head_size) + *dims, second_last, last = kv_cache_shape + kv_cache_shape = (*dims, second_last * last) + try: + stride_order = attn_backend.get_kv_cache_stride_order() + assert len(stride_order) == len(kv_cache_shape) + except (AttributeError, NotImplementedError): + stride_order = tuple(range(len(kv_cache_shape))) + + permuted_shape = tuple(kv_cache_shape[i] for i in stride_order) + + inv_order = [ + stride_order.index(i) - 1 for i in range(len(stride_order)) + ] + + reshaped_layer_tensors = [] + + for _ in range(coef): + tensor_split = mint.zeros( + permuted_shape[1:], dtype=target_dtype).permute(*inv_order[1:]) + tensor_split = ops.auto_generate.format_cast( + tensor_split, FORMAT_TYPE['nz']) + reshaped_layer_tensors.append(tensor_split) + + ms.runtime.synchronize() + final_kv_tuple = mutable(tuple(reshaped_layer_tensors)) + for layer_name in kv_cache_tensor.shared_by: + kv_caches[layer_name] = final_kv_tuple + + all_layers = set(layer_to_group_info.keys()) + assert all_layers == 
set( + kv_caches.keys()), "Some layers were not initialized" + + return kv_caches + + def _allocate_kv_cache_tensors(self, kv_cache_config): """ Initializes the KV cache buffer with the correct size. The buffer needs @@ -220,7 +300,7 @@ def _allocate_kv_cache_tensors(self, kv_cache_config): Returns: dict[str, Tensor]: A map between layer names to their corresponding memory buffer for KV cache. - """ + """ kv_cache_spec = kv_cache_config.kv_cache_groups[0].kv_cache_spec use_mla = kv_cache_spec.use_mla dtype = kv_cache_spec.dtype @@ -364,6 +444,41 @@ def _reshape_kv_cache_tensors( return kv_caches +def initialize_kv_cache_tensors( + self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]: + """ + Initialize the memory buffer for KV cache. + + Args: + kv_cache_config: The KV cache config + Returns: + Dict[str, torch.Tensor]: A map between layer names to their + corresponding memory buffer for KV cache. + """ + if is_310p(): + kv_caches = _allocate_nz_kv_cache_tensors(self, kv_cache_config) + else: + # Initialize the memory buffer for KV cache + kv_cache_raw_tensors = self._allocate_kv_cache_tensors(kv_cache_config) + # Change the memory buffer to the desired shape + kv_caches = self._reshape_kv_cache_tensors(kv_cache_config, + kv_cache_raw_tensors) + + # Setup `kv_cache_config` and `kv_caches` for models + # with cross-layer KV sharing + if self.shared_kv_cache_layers: + initialize_kv_cache_for_kv_sharing( + self.shared_kv_cache_layers, + kv_cache_config.kv_cache_groups, + kv_caches, + ) + + bind_kv_cache(kv_caches, + self.vllm_config.compilation_config.static_forward_context, + self.kv_caches) + return kv_caches + + def _update_states(self, scheduler_output) -> None: """Update the cached states and the persistent batch with the scheduler output. -- Gitee From 16df9dea1afbe7ec5d190b4fa13376a2dcc69746 Mon Sep 17 00:00:00 2001 From: huangzhuo Date: Mon, 18 Aug 2025 11:44:40 +0800 Subject: [PATCH 12/12] fix sparsequant/mutilora conflict --- vllm_mindspore/lora/models.py | 7 ++- vllm_mindspore/lora/ops/torch_ops/lora_ops.py | 20 ++++--- .../lora/punica_wrapper/punica_npu.py | 58 +++++++++---------- vllm_mindspore/lora/utils.py | 22 +++---- 4 files changed, 58 insertions(+), 49 deletions(-) diff --git a/vllm_mindspore/lora/models.py b/vllm_mindspore/lora/models.py index 621f609a..253242b0 100644 --- a/vllm_mindspore/lora/models.py +++ b/vllm_mindspore/lora/models.py @@ -20,6 +20,7 @@ """Models for Multi-LoRA.""" import os +import numpy as np from typing import Optional, Union import mindspore as ms @@ -33,6 +34,7 @@ from vllm.model_executor.models.utils import WeightsMapper from vllm.utils import is_pin_memory_available from vllm_mindspore.lora.layers import BaseLayerWithLoRA +from vllm_mindspore.utils import is_310p _GLOBAL_LORA_ID = 0 @@ -197,7 +199,10 @@ def from_local_checkpoint( check_unexpected_modules(f) for module in f.keys(): # noqa # vllm-mindspore add numpy to tensor - tensors[module] = mint.Tensor(f.get_tensor(module)) + np_data = f.get_tensor(module) + if is_310p() and str(np_data.dtype) == "bfloat16": + np_data = np_data.astype(np.float32).astype(np.float16) + tensors[module] = mint.Tensor(np_data) elif os.path.isfile(lora_bin_file_path): # When a bin file is provided, we rely on config to find unexpected # modules. 
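The from_local_checkpoint hunk above downcasts bfloat16 LoRA weights to float16 on 310P before wrapping them in a mint.Tensor. A standalone sketch of that cast is below; ml_dtypes is only a stand-in for however the safetensors reader actually exposes bfloat16 arrays, and the helper name to_310p_friendly is hypothetical.

import numpy as np
import ml_dtypes  # numpy-compatible bfloat16 dtype, used here purely for illustration

def to_310p_friendly(np_data: np.ndarray) -> np.ndarray:
    # Same two-step cast as the patched from_local_checkpoint, which routes
    # bf16 through float32 rather than casting straight to float16.
    if str(np_data.dtype) == "bfloat16":
        return np_data.astype(np.float32).astype(np.float16)
    return np_data

lora_a = np.arange(8, dtype=np.float32).astype(ml_dtypes.bfloat16)
assert to_310p_friendly(lora_a).dtype == np.float16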
diff --git a/vllm_mindspore/lora/ops/torch_ops/lora_ops.py b/vllm_mindspore/lora/ops/torch_ops/lora_ops.py index 21fd83a2..acd279f3 100644 --- a/vllm_mindspore/lora/ops/torch_ops/lora_ops.py +++ b/vllm_mindspore/lora/ops/torch_ops/lora_ops.py @@ -20,7 +20,7 @@ """ For punica_npu """ -from mindspore import mint, ops +from mindspore import mint, ops, dtype from mindspore.ops.auto_generate import grouped_matmul_v4 @@ -86,6 +86,7 @@ def sgmv_shrink( inputs, lora_a_weights, group_list, + lora_indices, b_seq_start_loc, seq_len_tensor, lora_indices_tensor, @@ -114,9 +115,9 @@ def sgmv_shrink( # group_list_type=1) #outputs = bgmv_shrink(inputs, lora_a_weights, 0, scaling) #group_list = seq_len_tensor - #lora_a_weights = lora_a_weights[lora_indices_tensor] + lora_a_weights = lora_a_weights[lora_indices] lora_a_weights = lora_a_weights.squeeze(1) - lora_a_weights = mint.transpose(lora_a_weights, 1, 2) + lora_a_weights = lora_a_weights.transpose(0, 2, 1) outputs = grouped_matmul_v4([inputs], [lora_a_weights], group_list=group_list, split_item=3, @@ -143,6 +144,7 @@ def sgmv_expand_slice(inputs, lora_b_weights, output_tensor, group_list, + lora_indices, b_seq_start_loc, seq_len_tensor, lora_indices_tensor, @@ -178,9 +180,9 @@ def sgmv_expand_slice(inputs, # output_tensor = outputs[:] #output_tensor = bgmv_expand_slice(inputs, lora_b_weights, output_tensor, 0, 0, 0) #group_list = seq_len_tensor - #lora_b_weights = lora_b_weights[lora_indices_tensor] + lora_b_weights = lora_b_weights[lora_indices] lora_b_weights = lora_b_weights.squeeze(1) - lora_b_weights = mint.transpose(lora_b_weights, 1, 2) + lora_b_weights = lora_b_weights.transpose(0, 2, 1) outputs = grouped_matmul_v4([inputs], [lora_b_weights], group_list=group_list, split_item=3, @@ -204,8 +206,8 @@ def bgmv_expand_slice(inputs, selected_loras = selected_loras.squeeze(1) outputs = einsum_ms(inputs, selected_loras) output_tensor = ops.add(output_tensor, outputs) - # if add_inputs: - # output_tensor += outputs[:] - # else: - # output_tensor = outputs[:] + if add_inputs: + output_tensor += outputs[:] + else: + output_tensor = outputs[:] return output_tensor diff --git a/vllm_mindspore/lora/punica_wrapper/punica_npu.py b/vllm_mindspore/lora/punica_wrapper/punica_npu.py index 6531659a..9cdd985c 100644 --- a/vllm_mindspore/lora/punica_wrapper/punica_npu.py +++ b/vllm_mindspore/lora/punica_wrapper/punica_npu.py @@ -59,6 +59,7 @@ class PunicaWrapperNPU(PunicaWrapperBase, nn.Cell): x, w_t_all, self.group_list, + self.lora_indices, *self.prefill_metadata, scale, ) @@ -82,7 +83,6 @@ class PunicaWrapperNPU(PunicaWrapperBase, nn.Cell): x, w_t_all, y, - self.group_list, *self.prefill_metadata, add_inputs, ) @@ -110,6 +110,7 @@ class PunicaWrapperNPU(PunicaWrapperBase, nn.Cell): w_t_all, y, self.group_list, + self.lora_indices, *self.prefill_metadata, y_offset, y_slice_size, @@ -146,8 +147,8 @@ class PunicaWrapperNPU(PunicaWrapperBase, nn.Cell): expand_slice_fun: Callable = (self._expand_slice_prefill if get_model_context("is_prefill") else self._expand_slice_decode) - return expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_inputs) - #return self._expand_slice_prefill(y, x, w_t_all, y_offset, y_slice_size, add_inputs) + #return expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_inputs) + return self._expand_slice_prefill(y, x, w_t_all, y_offset, y_slice_size, add_inputs) def _apply_shrink(self, x, w_t_all, scale): """ @@ -163,7 +164,7 @@ class PunicaWrapperNPU(PunicaWrapperBase, nn.Cell): shrink_fun: Callable = (self._shrink_prefill if 
get_model_context("is_prefill") else self._shrink_decode) #y = shrink_fun(x, w_t_all, scale) - return shrink_fun(x, w_t_all, scale) + return self._shrink_prefill(x, w_t_all, scale) def update_metadata( self, @@ -184,26 +185,26 @@ class PunicaWrapperNPU(PunicaWrapperBase, nn.Cell): # else: # self.is_prefill = False _, seq_len, lora_indices, _, _, _ = self.prefill_metadata - sorted_ids, sorted_counts = sort_lora_by_token_count( - lora_indices, seq_len) - group_list = sorted_counts - if len(group_list) < self.max_loras: - new_tensor = mint.zeros(self.max_loras, dtype=group_list.dtype) - new_tensor[:group_list.size(0)] = group_list - group_list = new_tensor - self.group_list.set_data(group_list.astype(dtype.int64)) - # if len(lora_indices) > self.max_loras: - # self.group_list.set_data(seq_len[:self.max_loras].astype(dtype.int64)) - # self.lora_indices.set_data(lora_indices[:self.max_loras].astype(dtype.int64)) - # elif len(lora_indices) < self.max_loras: - # pad_len = int(self.max_loras - len(lora_indices)) - # lora_indices = ops.pad(lora_indices, (0, pad_len), mode='constant', value=0) - # seq_len = ops.pad(seq_len, (0, pad_len), mode='constant', value=0) - # self.group_list.set_data(seq_len.astype(dtype.int64)) - # self.lora_indices.set_data(lora_indices.astype(dtype.int64)) - # else: - # self.group_list.set_data(seq_len.astype(dtype.int64)) - # self.lora_indices.set_data(lora_indices.astype(dtype.int64)) + # sorted_ids, sorted_counts = sort_lora_by_token_count( + # lora_indices, seq_len) + # group_list = sorted_counts + # if len(group_list) < self.max_loras: + # new_tensor = mint.zeros(self.max_loras, dtype=group_list.dtype) + # new_tensor[:group_list.size(0)] = group_list + # group_list = new_tensor + # self.group_list.set_data(group_list.astype(dtype.int64)) + if len(lora_indices) > self.max_loras: + self.group_list.set_data(seq_len[:self.max_loras].astype(dtype.int64)) + self.lora_indices.set_data(lora_indices[:self.max_loras].astype(dtype.int64)) + elif len(lora_indices) < self.max_loras: + pad_len = int(self.max_loras - len(lora_indices)) + lora_indices = ops.pad(lora_indices, (0, pad_len), mode='constant', value=0) + seq_len = ops.pad(seq_len, (0, pad_len), mode='constant', value=0) + self.group_list.set_data(seq_len.astype(dtype.int64)) + self.lora_indices.set_data(lora_indices.astype(dtype.int64)) + else: + self.group_list.set_data(seq_len.astype(dtype.int64)) + self.lora_indices.set_data(lora_indices.astype(dtype.int64)) def add_shrink(self, x, lora_a_stacked, scale, **kwargs): @@ -339,11 +340,10 @@ class PunicaWrapperNPU(PunicaWrapperBase, nn.Cell): if self.no_lora: return x = x.reshape(-1, x.shape[-1]) - # assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices) - # if lora_bias_stacked is not None: - # assert len(lora_bias_stacked) == len(output_slices) - # y = self._apply_bias(self.token_lora_indices, y, output_slices, - # lora_bias_stacked) + if lora_bias_stacked is not None: + assert len(lora_bias_stacked) == len(output_slices) + y = self._apply_bias(self.token_lora_indices, y, output_slices, + lora_bias_stacked) # if buffer is None: # r = lora_b_stacked[0].shape[-1] diff --git a/vllm_mindspore/lora/utils.py b/vllm_mindspore/lora/utils.py index 1fb215e0..53cc6d41 100644 --- a/vllm_mindspore/lora/utils.py +++ b/vllm_mindspore/lora/utils.py @@ -32,6 +32,7 @@ from vllm_mindspore.lora.layers import (BaseLayerWithLoRA, QKVParallelLinearWithLoRA, RowParallelLinearWithLoRA, VocabParallelEmbeddingWithLoRA) +from 
vllm_mindspore.model_executor.layers.quantization.sparse_quant_modelslim import A8W8SCLinearMethod # yapf: enable @@ -55,16 +56,17 @@ def replace_submodule(model, module_name, new_module): """Replace a submodule in a model with a new module.""" parent = model.get_submodule(".".join(module_name.split(".")[:-1])) target_name = module_name.split(".")[-1] - weight_name = module_name + ".weight" - lora_a_name = module_name + ".lora_a_weight" - lora_b_name = module_name + ".lora_b_weight" - lora_bias_name = module_name + ".lora_bias" - bias_name = module_name + ".bias" setattr(parent, target_name, new_module) - new_module.base_layer.weight.name = weight_name - new_module.lora_a_stacked.name = lora_a_name - new_module.lora_b_stacked.name = lora_b_name + new_module.base_layer.weight.name = module_name + ".weight" + new_module.lora_a_stacked.name = module_name + ".lora_a_weight" + new_module.lora_b_stacked.name = module_name + ".lora_b_weight" if new_module.base_layer.bias is not None: - new_module.base_layer.bias.name = bias_name - #new_module.lora_bias_stacked.name = lora_bias_name + new_module.base_layer.bias.name = module_name + ".bias" + #new_module.lora_bias_stacked.name = module_name + ".lora_bias" + if isinstance(new_module.base_layer.quant_method, A8W8SCLinearMethod): + new_module.base_layer.index.name = module_name + ".index" + new_module.base_layer.input_scale.name = module_name + ".input_scale" + new_module.base_layer.input_offset.name = module_name + ".input_offset" + new_module.base_layer.deq_scale.name = module_name + ".deq_scale" + new_module.base_layer.quant_bias.name = module_name + ".quant_bias" return new_module -- Gitee
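The update_metadata hunk in this last patch restores the branch that truncates or zero-pads seq_len and lora_indices to max_loras, the values that later feed the grouped_matmul_v4-based shrink/expand kernels. A small numpy sketch of that bookkeeping follows; the function name pad_group_metadata is hypothetical, and keeping the fixed length is, presumably, so group_list retains a static shape across steps.

import numpy as np

def pad_group_metadata(seq_len: np.ndarray, lora_indices: np.ndarray,
                       max_loras: int):
    # Truncate when more adapter groups than max_loras are active,
    # zero-pad when fewer are, mirroring PunicaWrapperNPU.update_metadata.
    n = len(lora_indices)
    if n > max_loras:
        return seq_len[:max_loras], lora_indices[:max_loras]
    if n < max_loras:
        pad = (0, max_loras - n)
        return np.pad(seq_len, pad), np.pad(lora_indices, pad)
    return seq_len, lora_indices

group_list, ids = pad_group_metadata(np.array([5, 3]), np.array([1, 0]), 4)
print(group_list.tolist(), ids.tolist())  # [5, 3, 0, 0] [1, 0, 0, 0]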