From 28f1082bba793c326b4258d246dcce95be298f31 Mon Sep 17 00:00:00 2001 From: zhang_xu_hao1230 Date: Sat, 26 Apr 2025 16:41:33 +0800 Subject: [PATCH] =?UTF-8?q?=E5=8E=9F=E7=94=9Fqwen2=E6=94=AF=E6=8C=81lora?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- install_depend_pkgs.sh | 2 +- tests/st/python/test_multilora_inference.py | 113 ++ vllm_mindspore/__init__.py | 80 +- vllm_mindspore/attention/layer.py | 77 +- vllm_mindspore/lora/__init__.py | 0 vllm_mindspore/lora/layers.py | 1165 +++++++++++++++++ vllm_mindspore/lora/models.py | 227 ++++ vllm_mindspore/lora/ops/__init__.py | 0 vllm_mindspore/lora/ops/torch_ops/__init__.py | 0 vllm_mindspore/lora/ops/torch_ops/lora_ops.py | 171 +++ .../lora/punica_wrapper/__init__.py | 0 .../lora/punica_wrapper/punica_npu.py | 357 +++++ vllm_mindspore/lora/utils.py | 47 + .../model_executor/layers/rotary_embedding.py | 27 +- .../layers/vocab_parallel_embedding.py | 130 +- .../model_executor/models/model_base.py | 118 +- vllm_mindspore/model_executor/models/qwen2.py | 299 +++-- vllm_mindspore/model_executor/models/utils.py | 62 +- vllm_mindspore/platforms/ascend.py | 27 +- 19 files changed, 2523 insertions(+), 379 deletions(-) create mode 100644 tests/st/python/test_multilora_inference.py create mode 100644 vllm_mindspore/lora/__init__.py create mode 100644 vllm_mindspore/lora/layers.py create mode 100644 vllm_mindspore/lora/models.py create mode 100644 vllm_mindspore/lora/ops/__init__.py create mode 100644 vllm_mindspore/lora/ops/torch_ops/__init__.py create mode 100644 vllm_mindspore/lora/ops/torch_ops/lora_ops.py create mode 100644 vllm_mindspore/lora/punica_wrapper/__init__.py create mode 100644 vllm_mindspore/lora/punica_wrapper/punica_npu.py create mode 100644 vllm_mindspore/lora/utils.py diff --git a/install_depend_pkgs.sh b/install_depend_pkgs.sh index ba0f7988..b3d8306e 100644 --- a/install_depend_pkgs.sh +++ b/install_depend_pkgs.sh @@ -100,4 +100,4 @@ cd "$msadapter_dir" || { echo "Failed to git clone msadapter!"; exit 1; } pip uninstall msadapter -y && pip install . || { echo "Failed to install msadapter"; exit 1; } cd .. -echo "========= All dependencies installed successfully!" +echo "========= All dependencies installed successfully!" \ No newline at end of file diff --git a/tests/st/python/test_multilora_inference.py b/tests/st/python/test_multilora_inference.py new file mode 100644 index 00000000..d5e86441 --- /dev/null +++ b/tests/st/python/test_multilora_inference.py @@ -0,0 +1,113 @@ +#!/usr/bin/env python3 +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +This example shows how to use the multi-LoRA functionality +for offline inference. + +""" +import pytest +import os +from . 
import set_env + +env_manager = set_env.EnvVarManager() +# def env +env_vars = { + "ASCEND_CUSTOM_PATH": os.path.expandvars("$ASCEND_HOME_PATH/../"), + "MS_ENABLE_LCCL": "off", + "HCCL_OP_EXPANSION_MODE": "AIV", + "ASCEND_RT_VISIBLE_DEVICES": "0,1", + "MS_ALLOC_CONF": "enable_vmm:True", + "LCCL_DETERMINISTIC": "1", + "HCCL_DETERMINISTIC": "true", + "ATB_MATMUL_SHUFFLE_K_ENABLE": "0", + "ATB_LLM_LCOC_ENABLE": "0", + "VLLM_USE_V1": "1", +} +# set env +env_manager.setup_ai_environment(env_vars) +import vllm_mindspore +from typing import List, Optional, Tuple + +from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams +from vllm.lora.request import LoRARequest + + +def create_test_prompts( + lora_path: str +) -> List[Tuple[str, SamplingParams, Optional[LoRARequest]]]: + """Create a list of test prompts with their sampling parameters. + """ + return [ + ("违章停车与违法停车是否有区别?", + SamplingParams(temperature=0.0, top_p=1, top_k=-1, + max_tokens=10), LoRARequest("sql-lora1", 1, + lora_path)), + ] + + +def process_requests(engine: LLMEngine, + test_prompts: List[Tuple[str, SamplingParams, + Optional[LoRARequest]]]): + """Continuously process a list of prompts and handle the outputs.""" + request_id = 0 + + while test_prompts or engine.has_unfinished_requests(): + if test_prompts: + prompt, sampling_params, lora_request = test_prompts.pop(0) + engine.add_request(str(request_id), + prompt, + sampling_params, + lora_request=lora_request) + request_id += 1 + + request_outputs: List[RequestOutput] = engine.step() + for request_output in request_outputs: + if request_output.finished: + print(f'text is: {request_output.outputs[0].text}', flush=True) + assert " 从法律上来说,违章停车和违法" in request_output.outputs[0].text + + +def initialize_engine() -> LLMEngine: + """Initialize the LLMEngine.""" + # max_loras: controls the number of LoRAs that can be used in the same + # batch. Larger numbers will cause higher memory usage, as each LoRA + # slot requires its own preallocated tensor. + # max_lora_rank: controls the maximum supported rank of all LoRAs. Larger + # numbers will cause higher memory usage. If you know that all LoRAs will + # use the same rank, it is recommended to set this as low as possible. + # max_cpu_loras: controls the size of the CPU LoRA cache. + engine_args = EngineArgs( + model="/home/workspace/mindspore_dataset/weight/Qwen2.5-7B-Instruct", + enable_lora=True, + max_loras=1, + max_lora_rank=64, + max_cpu_loras=2, + max_num_seqs=256, + max_model_len=256, + max_num_batched_tokens=400) + return LLMEngine.from_engine_args(engine_args) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_single +def test_multilora_inference(): + """test function that sets up and runs the prompt processing.""" + engine = initialize_engine() + lora_path = "/home/workspace/mindspore_dataset/weight/Qwen2.5-7B-Lora-Law" + test_prompts = create_test_prompts(lora_path) + process_requests(engine, test_prompts) + env_manager.unset_all() diff --git a/vllm_mindspore/__init__.py b/vllm_mindspore/__init__.py index 5892937a..dfc1dfb0 100644 --- a/vllm_mindspore/__init__.py +++ b/vllm_mindspore/__init__.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -# encoding: utf-8 +# isort:skip_file # Copyright 2025 Huawei Technologies Co., Ltd # Copyright 2024 The vLLM team. # @@ -27,6 +27,7 @@ if "vllm" in sys.modules: # 1. set env before import mindspore. from vllm_mindspore.scripts import env_setup + env_setup() # 2. update the log configuration ahead of other modifications. 
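The hunks that follow all use one import-time patching idiom: import the upstream vllm module, then rebind its classes or functions to the MindSpore-backed implementations before vLLM's engine code looks them up, so the Ascend/MindSpore versions are picked up transparently. A minimal sketch of that idiom, shown with one of the LoRA layer classes added by this change (nothing here is new; the names are taken directly from the diff below, which lists the full set of rebound symbols):

import vllm.lora.layers
from vllm_mindspore.lora.layers import ColumnParallelLinearWithLoRA

# After this assignment, any later lookup of
# vllm.lora.layers.ColumnParallelLinearWithLoRA resolves to the MindSpore
# implementation, so vLLM's LoRA manager instantiates it instead of the
# original torch/CUDA layer.
vllm.lora.layers.ColumnParallelLinearWithLoRA = ColumnParallelLinearWithLoRA
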
@@ -49,14 +50,17 @@ import vllm.utils vllm.utils.current_platform = ascend_platform import vllm.attention.selector + vllm.attention.selector.current_platform = ascend_platform import vllm.engine.arg_utils from vllm_mindspore.engine.arg_utils import _is_v1_supported_oracle + vllm.engine.arg_utils.EngineArgs._is_v1_supported_oracle = _is_v1_supported_oracle import vllm.v1.engine.core from vllm_mindspore.v1.engine.core import shutdown + vllm.v1.engine.core.DPEngineCoreProc.shutdown = shutdown from vllm_mindspore.utils import ( @@ -78,6 +82,35 @@ vllm.utils.cuda_is_initialized = ascend_is_initialized vllm.utils.memory_profiling = ms_memory_profiling vllm.config.cuda_device_count_stateless = ascend_device_count_stateless +import vllm.lora.utils + +from vllm_mindspore.model_executor.layers.linear import LinearBase +from vllm_mindspore.lora.utils import _all_lora_classes + +vllm.lora.utils._all_lora_classes = _all_lora_classes +vllm.lora.utils.LinearBase = LinearBase + +import vllm.lora.models +from vllm_mindspore.lora.models import register_module, from_local_checkpoint, from_lora_tensors + +vllm.lora.models.LoRAModelManager.register_module = register_module +vllm.lora.models.LoRAModel.from_local_checkpoint = from_local_checkpoint +vllm.lora.models.LoRAModel.from_lora_tensors = from_lora_tensors + +from vllm_mindspore.lora.layers import (ColumnParallelLinearWithLoRA, + MergedColumnParallelLinearWithLoRA, + MergedQKVParallelLinearWithLoRA, + QKVParallelLinearWithLoRA, + RowParallelLinearWithLoRA) + +import vllm.lora.layers + +vllm.lora.layers.ColumnParallelLinearWithLoRA = ColumnParallelLinearWithLoRA +vllm.lora.layers.MergedColumnParallelLinearWithLoRA = MergedColumnParallelLinearWithLoRA +vllm.lora.layers.MergedQKVParallelLinearWithLoRA = MergedQKVParallelLinearWithLoRA +vllm.lora.layers.QKVParallelLinearWithLoRA = QKVParallelLinearWithLoRA +vllm.lora.layers.RowParallelLinearWithLoRA = RowParallelLinearWithLoRA + import vllm.executor vllm.executor.cuda_device_count_stateless = ascend_device_count_stateless @@ -102,11 +135,9 @@ from vllm.model_executor.model_loader import get_model_architecture vllm.model_executor.model_loader.get_model_architecture = get_ms_model_architecture vllm.model_executor.model_loader.utils.get_model_architecture = ( - get_ms_model_architecture -) + get_ms_model_architecture) vllm.model_executor.model_loader.loader.get_model_architecture = ( - get_ms_model_architecture -) + get_ms_model_architecture) from vllm_mindspore.model_executor.sampling_metadata import ( SequenceGroupToSample, @@ -133,12 +164,10 @@ vllm.worker.cache_engine.CacheEngine.swap_in = ms_swap_in vllm.worker.cache_engine.CacheEngine.swap_out = ms_swap_out from vllm_mindspore.model_executor.model_loader.weight_utils import ( - safetensors_weights_iterator, -) + safetensors_weights_iterator, ) vllm.model_executor.model_loader.loader.safetensors_weights_iterator = ( - safetensors_weights_iterator -) + safetensors_weights_iterator) from vllm_mindspore.worker.worker import _warm_up_model from vllm_mindspore.worker.profile import ( @@ -158,15 +187,13 @@ from vllm_mindspore.worker.model_runner import ( ) vllm.worker.model_runner.ModelInputForGPUBuilder._get_cuda_graph_pad_size = ( - _get_cuda_graph_pad_size -) + _get_cuda_graph_pad_size) vllm.worker.model_runner.GPUModelRunnerBase._dummy_run = _dummy_run import vllm.worker.multi_step_model_runner vllm.worker.multi_step_model_runner._get_supported_attention_backends = ( - _get_supported_attention_backends -) + _get_supported_attention_backends) from 
vllm_mindspore.executor.multiproc_worker_utils import ( get_mp_context as ms_get_mp_context, @@ -183,8 +210,10 @@ import vllm.executor.multiproc_worker_utils vllm.executor.multiproc_worker_utils.ProcessWorkerWrapper.terminate_worker = ms_terminate_worker import vllm.v1.executor.multiproc_executor + vllm.v1.executor.multiproc_executor.get_mp_context = ms_get_mp_context import vllm.v1.utils + vllm.v1.utils.get_mp_context = ms_get_mp_context from vllm_mindspore.executor.ray_gpu_executor import ( @@ -219,6 +248,7 @@ vllm.config.ParallelConfig.has_unfinished_dp = has_unfinished_dp from .utils import update_modules from vllm_mindspore.attention.backends import ms_attn + update_modules("vllm.attention.backends.flash_attn", ms_attn) from vllm_mindspore.worker.spec_decode_worker import ( @@ -229,20 +259,25 @@ from vllm_mindspore.worker.spec_decode_worker import ( _merge_outputs, ) from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker + SpecDecodeWorker.__init__ = spec_decode_worker_init SpecDecodeWorker._verify_tokens = _verify_tokens SpecDecodeWorker._run_no_spec = _run_no_spec from vllm.model_executor.layers.spec_decode_base_sampler import SpecDecodeBaseSampler + SpecDecodeBaseSampler._create_output = _create_output from vllm.spec_decode.top1_proposer import Top1Proposer + Top1Proposer._merge_outputs = _merge_outputs from vllm_mindspore.model_executor.layers.rejection_sampler import _smallest_positive_value, _multinomial from vllm.model_executor.layers.rejection_sampler import RejectionSampler + RejectionSampler._smallest_positive_value = _smallest_positive_value -RejectionSampler._smallest_positive_value.__set_name__(RejectionSampler, '_smallest_positive_value') +RejectionSampler._smallest_positive_value.__set_name__( + RejectionSampler, '_smallest_positive_value') vllm.model_executor.layers.rejection_sampler._multinomial = _multinomial ######### for multi-model @@ -258,34 +293,42 @@ from vllm_mindspore.model_executor.layers.rotary_embedding import InferMRotaryEm vllm.model_executor.layers.rotary_embedding.MRotaryEmbedding = InferMRotaryEmbedding from vllm_mindspore.v1.sample import rejection_sampler + update_modules("vllm.v1.sample.rejection_sampler", rejection_sampler) from vllm_mindspore.v1.spec_decode import eagle + update_modules("vllm.v1.spec_decode.eagle", eagle) from vllm_mindspore.v1.attention.backends import flash_attn import vllm.v1.attention.backends + sys.modules['vllm.v1.attention.backends.flash_attn'] = flash_attn import vllm.v1.attention.backends.flash_attn import vllm.v1.worker.gpu_model_runner from vllm_mindspore.v1.worker.gpu_model_runner import _prepare_inputs + vllm.v1.worker.gpu_model_runner.GPUModelRunner._prepare_inputs = _prepare_inputs from vllm_mindspore.v1.worker.gpu_model_runner import _update_states + vllm.v1.worker.gpu_model_runner.GPUModelRunner._update_states = _update_states from vllm_mindspore.v1.worker.gpu_model_runner import initialize_kv_cache + vllm.v1.worker.gpu_model_runner.GPUModelRunner.initialize_kv_cache = initialize_kv_cache import vllm.v1.worker.block_table from vllm_mindspore.v1.worker.block_table import BlockTable + vllm.v1.worker.block_table.BlockTable = BlockTable vllm.v1.worker.gpu_input_batch.BlockTable = BlockTable import vllm.v1.worker.gpu_input_batch from vllm_mindspore.v1.worker.gpu_input_batch import _make_sampling_metadata, _make_prompt_token_ids_tensor + vllm.v1.worker.gpu_input_batch.InputBatch._make_sampling_metadata = _make_sampling_metadata vllm.v1.worker.gpu_model_runner.InputBatch._make_sampling_metadata = 
_make_sampling_metadata vllm.v1.worker.gpu_input_batch.InputBatch._make_prompt_token_ids_tensor = _make_prompt_token_ids_tensor @@ -297,17 +340,19 @@ from vllm_mindspore.v1.worker.gpu_worker import init_device Worker.__init__ = wrapper_worker_init(Worker.__init__) Worker.init_device = wrapper_worker_init_device(init_device) - import vllm.v1.utils from vllm_mindspore.v1.utils import copy_slice + vllm.v1.utils.copy_slice = copy_slice vllm.v1.worker.gpu_input_batch.copy_slice = copy_slice from vllm_mindspore.v1.sample.ops.penalties import _convert_to_tensors import vllm.v1.sample.ops.penalties + vllm.v1.sample.ops.penalties._convert_to_tensors = _convert_to_tensors import vllm.model_executor.layers.utils from vllm_mindspore.model_executor.layers.utils import apply_penalties + vllm.model_executor.layers.utils.apply_penalties = apply_penalties vllm.v1.sample.ops.penalties.apply_penalties = apply_penalties @@ -317,26 +362,31 @@ from vllm_mindspore.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p, ra import vllm.v1.sample.ops.topk_topp_sampler from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler + TopKTopPSampler.forward_native = topk_topp_sampler_forward_native vllm.v1.sample.ops.topk_topp_sampler.apply_top_k_top_p = apply_top_k_top_p vllm.v1.sample.ops.topk_topp_sampler.random_sample = random_sample vllm.v1.sample.ops.topk_topp_sampler.apply_top_k_only = apply_top_k_only from vllm_mindspore.v1.sample.sampler import apply_temperature import vllm.v1.sample.sampler + vllm.v1.sample.sampler.Sampler.apply_temperature = apply_temperature from vllm_mindspore.distributed.shm_broadcast import initialize_ShmRingBuffer from vllm.distributed.device_communicators.shm_broadcast import ShmRingBuffer + ShmRingBuffer.__init__ = initialize_ShmRingBuffer from vllm_mindspore.v1.worker.gpu_worker import compile_or_warm_up_model from vllm.v1.worker.gpu_worker import Worker + Worker.compile_or_warm_up_model = compile_or_warm_up_model from .utils import check_ready from vllm_mindspore.engine.multiprocessing.engine import cleanup import vllm.engine.multiprocessing.engine + vllm.engine.multiprocessing.engine.MQLLMEngine.cleanup = cleanup check_ready() diff --git a/vllm_mindspore/attention/layer.py b/vllm_mindspore/attention/layer.py index 89914e97..73b294f2 100644 --- a/vllm_mindspore/attention/layer.py +++ b/vllm_mindspore/attention/layer.py @@ -1,5 +1,4 @@ #!/usr/bin/env python3 -# encoding: utf-8 # Copyright 2025 Huawei Technologies Co., Ltd # Copyright 2024 The vLLM team. 
# @@ -18,37 +17,31 @@ """Common layer for LLM.""" from typing import Any, Dict, List, Optional, Tuple -from mindspore import Tensor, mint, nn, ops, jit +from mindspore import Tensor, mint, nn, ops from mindspore.common import dtype as mstype from mindspore.ops.auto_generate import PagedAttention, ReshapeAndCache from mindspore.ops.operations.nn_ops import FlashAttentionScore - -from vllm.config import CacheConfig from vllm.attention.backends.abstract import AttentionType -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.config import CacheConfig +from vllm.model_executor.layers.quantization.base_config import \ + QuantizationConfig -def _pad_to_max_tensor( - input_: Tensor, - max_len: int, - dim: int = 0, - pad_value: int = -1 -) -> Tensor: +def _pad_to_max_tensor(input_: Tensor, + max_len: int, + dim: int = 0, + pad_value: int = -1) -> Tensor: """Temporary function, will be deprecated in the future.""" if input_.shape[dim] == max_len: return input_ - pad_shape = (input_.shape[0], max_len - input_.shape[dim], *input_.shape[dim + 1:]) + pad_shape = (input_.shape[0], max_len - input_.shape[dim], + *input_.shape[dim + 1:]) pad_tensor = mint.ones(size=pad_shape, dtype=input_.dtype) * pad_value output = mint.cat([input_, pad_tensor], dim=dim) return output -def _generate_attn_mask( - query: Tensor, - value: Tensor, - flatten: bool -) -> Tensor: +def _generate_attn_mask(query: Tensor, value: Tensor, flatten: bool) -> Tensor: """Temporary function, will be deprecated in the future.""" if flatten: return mint.triu(mint.ones(size=(128, 128), dtype=query.dtype), 1) @@ -59,16 +52,14 @@ def _generate_attn_mask( return mask -def _hidden_states_th2bsh( - input_: Tensor, - batch_valid_length: Tensor -) -> Tensor: +def _hidden_states_th2bsh(input_: Tensor, + batch_valid_length: Tensor) -> Tensor: """Temporary function, will be deprecated in the future.""" max_seq_len = batch_valid_length.max().item() start_pos = 0 padding_input_list = [] for valid_length in batch_valid_length: - valid_input = input_[:, start_pos: start_pos + valid_length, :] + valid_input = input_[:, start_pos:start_pos + valid_length, :] padded_input = _pad_to_max_tensor(valid_input, max_seq_len, 1) padding_input_list.append(padded_input) start_pos += valid_length @@ -76,10 +67,8 @@ def _hidden_states_th2bsh( return bsh_output -def _hidden_states_bsh2th( - input_: Tensor, - batch_valid_length: Tensor -) -> Tensor: +def _hidden_states_bsh2th(input_: Tensor, + batch_valid_length: Tensor) -> Tensor: """Temporary function, will be deprecated in the future.""" unpadded_input_list = [] for batch_index, valid_length in enumerate(batch_valid_length): @@ -128,9 +117,9 @@ class Attention(nn.Cell): self.num_heads = num_heads self.num_kv_heads = num_kv_heads self.head_size = head_size - self.hidden_size_per_partition = num_heads*head_size - self.kv_hidden_size_per_partition = num_kv_heads*head_size self.flatten = True + self.hidden_size_per_partition = num_heads * head_size + self.kv_hidden_size_per_partition = num_kv_heads * head_size input_layout = "TH" if self.flatten else "BSH" # pynative 下不支持拉平操作。 scale = float(scale) @@ -147,7 +136,6 @@ class Attention(nn.Cell): scale_value=scale, kv_head_num=num_kv_heads) - @jit def construct( self, query: Tensor, @@ -162,7 +150,7 @@ class Attention(nn.Cell): q_seq_lens: Tensor, block_tables: Tensor, ) -> Tensor: - """Attention foward, support MHA and GQA. + """Attention forward, support MHA and GQA. 
Args: query: shape = [1, num_tokens, hidden_size] @@ -173,13 +161,18 @@ class Attention(nn.Cell): batch_valid_length: shape = [batch_size, ] block_tables: shape = [block_size, num_block] """ - output = query - cache_out = self.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping) + # ensure that the input tensors of reshape_and_cache is contiguous + value = value.contiguous() + cache_out = self.reshape_and_cache(key, value, key_cache, value_cache, + slot_mapping) query = ops.depend(query, cache_out) if is_prefill: - output = self._run_prefill_forward(query, key, value, attn_mask, batch_valid_length, batch_valid_length) + output = self._run_prefill_forward(query, key, value, attn_mask, + batch_valid_length, + batch_valid_length) else: - output = self._run_decode_forward(query, key_cache, value_cache, block_tables, batch_valid_length, + output = self._run_decode_forward(query, key_cache, value_cache, + block_tables, batch_valid_length, attn_mask, q_seq_lens) return output @@ -239,15 +232,7 @@ class Attention(nn.Cell): block_tables: shape = [block_size, num_block] context_lens: shape = [batch_size, ] """ - output = self.paged_attention( - query, - key_cache, - value_cache, - block_tables, - batch_valid_length, - None, - None, - attn_mask, - q_seq_lens - ) + output = self.paged_attention(query, key_cache, value_cache, + block_tables, batch_valid_length, None, + None, attn_mask, q_seq_lens) return output diff --git a/vllm_mindspore/lora/__init__.py b/vllm_mindspore/lora/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/vllm_mindspore/lora/layers.py b/vllm_mindspore/lora/layers.py new file mode 100644 index 00000000..19a132c0 --- /dev/null +++ b/vllm_mindspore/lora/layers.py @@ -0,0 +1,1165 @@ +#!/usr/bin/env python3 +# Copyright 2025 Huawei Technologies Co., Ltd +# Copyright 2024 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +import math +from dataclasses import dataclass +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, cast + +import mindspore as ms +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import PretrainedConfig +from vllm.adapter_commons.layers import AdapterMapping +from vllm.config import LoRAConfig +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + split_tensor_along_last_dim, + tensor_model_parallel_all_gather, + tensor_model_parallel_all_reduce) +from vllm.distributed.utils import divide +# yapf: enable +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.rotary_embedding import ( + LinearScalingRotaryEmbedding, RotaryEmbedding) +from vllm.model_executor.layers.vocab_parallel_embedding import \ + VocabParallelEmbedding + +# yapf: disable +from vllm_mindspore.model_executor.layers.linear import ( + ColumnParallelLinear, LinearBase, MergedColumnParallelLinear, + QKVParallelLinear, RowParallelLinear) + +if TYPE_CHECKING: + from vllm.lora.punica_wrapper import PunicaWrapperBase + + +def _get_lora_device(base_layer: nn.Module) -> torch.device: + # code borrowed from https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/vllm/lora/layers.py#L34 + """Returns the device for where to place the LoRA tensors.""" + # unquantizedLinear + if hasattr(base_layer, "weight"): + return base_layer.weight.device + # Compressed Tensor + elif hasattr(base_layer, "weight_packed"): + return base_layer.weight_packed.device + # GPTQ/AWQ + elif hasattr(base_layer, "qweight"): + return base_layer.qweight.device + # marlin + elif hasattr(base_layer, "B"): + return base_layer.B.device + # HQQ marlin + elif hasattr(base_layer, "W_q"): + return base_layer.W_q.device + else: + raise ValueError(f"Unsupported base layer: {base_layer}") + + +def _not_fully_sharded_can_replace(can_replace): + """ + decorator which adds the condition of not using fully sharded loras + intended to wrap can_replace_layer() + """ + + def dec(*args, **kwargs): + decorate = kwargs.pop("decorate") if "decorate" in kwargs else True + condition = (not kwargs["lora_config"].fully_sharded_loras + if decorate else True) + return can_replace(*args, **kwargs) and condition + + return dec + + +@dataclass +class LoRAMapping(AdapterMapping): + is_prefill: bool = False + +# vllm-mindspore Inherits ms.nn.Cell +class BaseLayerWithLoRA(ms.nn.Cell): + + def slice_lora_a( + self, lora_a: Union[torch.Tensor, List[Union[torch.Tensor, None]]] + ) -> Union[torch.Tensor, List[Union[torch.Tensor, None]]]: + """Slice lora a if splitting for tensor parallelism.""" + ... + + def slice_lora_b( + self, lora_b: Union[torch.Tensor, List[Union[torch.Tensor, None]]] + ) -> Union[torch.Tensor, List[Union[torch.Tensor, None]]]: + """Slice lora b if splitting with tensor parallelism.""" + ... + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: + """Initializes lora matrices.""" + ... + + def reset_lora(self, index: int): + """Resets the lora weights at index back to 0.""" + ... + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + bias: Optional[torch.Tensor] = None, + ): + """Overwrites lora tensors at index.""" + ... 
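+
+    # Orientation for the subclasses below: each LoRA "slot" (one of
+    # max_loras) keeps adapter i's low-rank factors in lora_a_stacked[i] /
+    # lora_b_stacked[i], and apply() asks the punica wrapper (configured via
+    # set_mapping) to add roughly B_i(A_i x) on top of the frozen base-layer
+    # output, selecting the slot per token. For a linear base layer with
+    # input size H, output size O and maximum rank R, the buffers allocated
+    # in BaseLinearLayerWithLoRA.create_lora_weights are therefore
+    # (max_loras, 1, R, H) for lora_a and (max_loras, 1, O, R) for lora_b.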
+ + def set_mapping( + self, + punica_wrapper, + ): + self.punica_wrapper: PunicaWrapperBase = punica_wrapper + + @classmethod + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: + raise NotImplementedError + + +class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): + + def __init__(self, base_layer: VocabParallelEmbedding) -> None: + super().__init__() + self.base_layer = base_layer + self.embeddings_slice: Optional[Tuple[int, int]] + self.embeddings_weights: Optional[torch.Tensor] + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None) -> None: + + if self.base_layer.num_added_embeddings_per_partition > 0: + # We can start adding lora weights + self.embeddings_weights = self.base_layer.weight.data[ + self.base_layer.num_org_embeddings_per_partition:self. + base_layer.num_org_embeddings_per_partition + + self.base_layer.num_added_embeddings_per_partition] + self.embeddings_slice = ( + self.base_layer.shard_indices.added_vocab_start_index - + self.base_layer.org_vocab_size, + self.base_layer.shard_indices.added_vocab_end_index - + self.base_layer.org_vocab_size) + self.base_layer.weight.data[ + self.base_layer.num_org_embeddings_per_partition:].fill_(0) + else: + self.embeddings_slice = None + self.embeddings_weights = None + + self.embeddings_tensors = torch.zeros( + ( + max_loras, + lora_config.lora_extra_vocab_size, + self.base_layer.embedding_dim, + ), + dtype=self.base_layer.weight.dtype, + device=self.base_layer.weight.device, + ) + self.lora_a_stacked = torch.zeros( + ( + max_loras, + self.base_layer.org_vocab_size + + lora_config.lora_extra_vocab_size, + lora_config.max_lora_rank, + ), + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ) + self.lora_b_stacked = torch.zeros( + ( + max_loras, + 1, + self.base_layer.embedding_dim, + lora_config.max_lora_rank, + ), + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ) + self.lora_a_stacked_2d = self.lora_a_stacked.view( + self.lora_a_stacked.shape[0] * self.lora_a_stacked.shape[1], + self.lora_a_stacked.shape[2], + ) + + def reset_lora(self, index: int): + self.lora_a_stacked[index] = 0 + self.lora_b_stacked[index] = 0 + self.embeddings_tensors[index] = 0 + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + bias: Optional[torch.Tensor] = None, + ): + self.reset_lora(index) + self.lora_a_stacked[index, :lora_a.shape[0], :lora_a.shape[1]].copy_( + lora_a, non_blocking=True) + self.lora_b_stacked[index, + 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( + lora_b.T, non_blocking=True) + if embeddings_tensor is not None: + self.embeddings_tensors[ + index, :embeddings_tensor.shape[0], :embeddings_tensor. 
+ shape[1], ].copy_(embeddings_tensor, non_blocking=True) + if self.embeddings_slice is not None: + # TODO(yard1): Optimize this copy, we don't need to copy + # everything, just the modified part + embeddings = self.embeddings_tensors.view( + self.embeddings_tensors.shape[0] * + self.embeddings_tensors.shape[1], + self.embeddings_tensors.shape[2], + )[self.embeddings_slice[0]:self.embeddings_slice[1]] + assert self.embeddings_weights is not None + self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings) + + def construct(self, x: torch.Tensor) -> torch.Tensor: + added_tokens_mask = x > self.base_layer.org_vocab_size - 1 + embeddings_indices = self.punica_wrapper.embeddings_indices + indices = embeddings_indices[1].view_as(x) + full_lora_a_embeddings = F.embedding( + x + indices, + self.lora_a_stacked_2d, + ) + indices = embeddings_indices[0].view_as(x) + full_output = self.base_layer.forward( + x.add_(indices * added_tokens_mask)) + + full_output_org = full_output + if full_output.ndim == 3: + full_output = full_output.view( + full_output.shape[0] * full_output.shape[1], -1) + if full_lora_a_embeddings.ndim == 3: + full_lora_a_embeddings = full_lora_a_embeddings.view( + full_lora_a_embeddings.shape[0] * + full_lora_a_embeddings.shape[1], + -1, + ) + + full_output = self.punica_wrapper.add_lora_embedding( + full_output, + full_lora_a_embeddings, + self.lora_b_stacked, + add_input=True) + return full_output.view_as(full_output_org) + + @classmethod + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: + return type(source_layer) is VocabParallelEmbedding + + +class BaseLinearLayerWithLoRA(BaseLayerWithLoRA): + + def __init__(self, base_layer: LinearBase): + super().__init__() + self.base_layer = base_layer + self.input_size = self.base_layer.input_size + self.device = _get_lora_device(self.base_layer) + self.lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]] = None + + self.output_slices: Tuple[int, ...] 
+ self.tp_size: int + self.output_size: int + self.n_slices: int + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: + self.lora_config = lora_config + + if isinstance(self.base_layer, ColumnParallelLinear): + lora_a_out_size = (lora_config.max_lora_rank if + not lora_config.fully_sharded_loras else divide( + lora_config.max_lora_rank, self.tp_size)) + lora_b_out_size = self.output_size + + elif isinstance(self.base_layer, RowParallelLinear): + lora_a_out_size = lora_config.max_lora_rank + lora_b_out_size = (self.output_size if + not lora_config.fully_sharded_loras else divide( + self.output_size, self.tp_size)) + else: + raise NotImplementedError + + self.lora_a_stacked = tuple( + torch.zeros( + max_loras, + 1, + lora_a_out_size, + self.input_size, + dtype=lora_config.lora_dtype, + device=self.device, + ) for _ in range(self.n_slices)) + self.lora_b_stacked = tuple( + torch.zeros( + max_loras, + 1, + lora_b_out_size, + lora_config.max_lora_rank, + dtype=lora_config.lora_dtype, + device=self.device, + ) for _ in range(self.n_slices)) + if lora_config.bias_enabled: + lora_bias_out_size = lora_b_out_size + self.lora_bias_stacked = tuple( + torch.zeros( + max_loras, + 1, + lora_bias_out_size, + dtype=lora_config.lora_dtype, + device=self.device, + ) for _ in range(self.n_slices)) + self.output_slices = (self.lora_b_stacked[0].shape[2], ) + + def reset_lora(self, index: int): + for s_index in range(self.n_slices): + self.lora_a_stacked[s_index][index] = 0 + self.lora_b_stacked[s_index][index] = 0 + if self.lora_config.bias_enabled: + # Make mypy happy + self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...], + self.lora_bias_stacked) + self.lora_bias_stacked[s_index][index] = 0 + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + lora_bias: Optional[torch.Tensor] = None, + ): + # Except for QKVParallelLinearWithLora and + # MergedColumnParallelLinearWithLoRA, all other linear LoRA layers + # store weights in a tuple of size 1. These two layers will + # override this function. + assert (len(self.lora_a_stacked) == len(self.lora_b_stacked) == + self.n_slices == 1) + + self.reset_lora(index) + if self.tp_size > 1: + lora_a = self.slice_lora_a(lora_a) + lora_b = self.slice_lora_b(lora_b) + if lora_bias is not None: + lora_bias = self.slice_bias(lora_bias) + + self.lora_a_stacked[0][index, + 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( + lora_a.T, non_blocking=True) + self.lora_b_stacked[0][index, + 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( + lora_b.T, non_blocking=True) + if lora_bias is not None: + + self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...], + self.lora_bias_stacked) + assert len(self.lora_bias_stacked) + self.lora_bias_stacked[0][index, 0, :lora_bias.shape[0]].copy_( + lora_bias.T, non_blocking=True) + + def apply(self, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + output = self.base_layer.quant_method.apply(self.base_layer, x, bias) + self.punica_wrapper.add_lora_linear(output, x, self.lora_a_stacked, + self.lora_b_stacked, + self.lora_bias_stacked, 1.0, + self.output_slices) + return output + + + + +class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA): + """ + LoRA on top of ColumnParallelLinear layer. + LoRA B is sliced for tensor parallelism. + There are two types for the `base_layer`: + 1. ColumnParallelLinear, e.g.`dense_h_to_4h` in `FalconForCausalLM`. + 2. 
MergedColumnParallelLinear, e.g.`gate_up_proj` in `Phi3ForCausalLM`. + """ + + def __init__(self, base_layer: ColumnParallelLinear) -> None: + super().__init__(base_layer) + # The base_layer type is ColumnParallelLinear or + # MergedColumnParallelLinear, their weight sharding logic is + # inconsistent when TP is greater than 1. + self.is_merged_col_linear = type( + base_layer) is MergedColumnParallelLinear + self.tp_size = get_tensor_model_parallel_world_size() + self.output_size = self.base_layer.output_size_per_partition + # There is only one LoRA layer + self.n_slices = 1 + + def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: + return lora_a + + def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: + # Applicable to cases where the base_layer is + # MergedColumnParallelLinear. + if self.is_merged_col_linear: + tp_rank = get_tensor_model_parallel_rank() + shard_size = self.output_size // 2 + offset = lora_b.shape[-1] // 2 + + left_weight = lora_b[:, tp_rank * shard_size:(tp_rank + 1) * + shard_size] + right_weight = lora_b[:, offset + tp_rank * shard_size:offset + + (tp_rank + 1) * shard_size] + lora_b = torch.cat([left_weight, right_weight], dim=1) + # Applicable to cases where the base_layer is + # ColumnParallelLinear. + else: + tensor_model_parallel_rank = get_tensor_model_parallel_rank() + shard_size = self.output_size + start_idx = tensor_model_parallel_rank * shard_size + end_idx = (tensor_model_parallel_rank + 1) * shard_size + lora_b = lora_b[:, start_idx:end_idx] + return lora_b + + def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: + # TODO: Fix the slicing logic of bias. + if bias is None: + return bias + tensor_model_parallel_rank = get_tensor_model_parallel_rank() + shard_size = self.output_size + start_idx = tensor_model_parallel_rank * shard_size + end_idx = (tensor_model_parallel_rank + 1) * shard_size + bias = bias[start_idx:end_idx] + return bias + + def construct( + self, input_: torch.Tensor + ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: + """Forward of ColumnParallelLinear + + Args: + input_: Tensor whose last dimension is `input_size`. + + Returns: + - output + - bias + """ + bias = (self.base_layer.bias + if not self.base_layer.skip_bias_add else None) + + # Matrix multiply. + output_parallel = self.apply(input_, bias) + if self.base_layer.gather_output: + # All-gather across the partitions. + output = tensor_model_parallel_all_gather(output_parallel) + else: + output = output_parallel + output_bias = (self.base_layer.bias + if self.base_layer.skip_bias_add else None) + return output, output_bias + + @classmethod + @_not_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: + return type(source_layer) is ColumnParallelLinear or ( + type(source_layer) is MergedColumnParallelLinear + and len(packed_modules_list) == 1) + + +class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): + """ColumnParallelLinear layer that is composed of 2 sublayers (slices) + packed together (eg. gate_proj + up_proj -> gate_up_proj). + + This means we have 2 LoRAs, each applied to one half of the layer. + + Both slices must have the same size. 
+ """ + + def __init__( + self, base_layer: Union[MergedColumnParallelLinear, + QKVParallelLinear]) -> None: + super().__init__(base_layer) + # There are two LoRA layers + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + # the output_sizes in MergedColumnParallelLinear is not sharded by tp + # we need to divide it by the tp_size to get correct slices size + output_sizes = self.base_layer.output_sizes + self.output_slices = tuple( + divide(output_size, self.tp_size) for output_size in output_sizes) + self.n_slices = len(self.output_slices) + self.output_ids = (self.tp_rank, ) * self.n_slices + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: + """ + The main reason for overriding this function is to enhance code + maintainability. + """ + self.lora_config = lora_config + + lora_a_output_size_per_partition = ( + lora_config.max_lora_rank if not lora_config.fully_sharded_loras + else divide(lora_config.max_lora_rank, self.tp_size)) + self.lora_a_stacked = tuple( + torch.zeros( + max_loras, + 1, + lora_a_output_size_per_partition, + self.input_size, + dtype=lora_config.lora_dtype, + device=self.device, + ) for _ in range(self.n_slices)) + self.lora_b_stacked = tuple( + torch.zeros( + max_loras, + 1, + output_size, + lora_config.max_lora_rank, + dtype=lora_config.lora_dtype, + device=self.device, + ) for output_size in self.output_slices) + if lora_config.bias_enabled: + self.lora_bias_stacked = tuple( + torch.zeros( + max_loras, + 1, + output_size, + dtype=lora_config.lora_dtype, + device=self.device, + ) for output_size in self.output_slices) + + def slice_lora_a( + self, lora_a: List[Union[torch.Tensor, None]] + ) -> List[Union[torch.Tensor, None]]: + return lora_a + + def slice_lora_b( + self, lora_b: List[Union[torch.Tensor, None]] + ) -> List[Union[torch.Tensor, None]]: + for i, (shard_id, shard_size) in enumerate( + zip(self.output_ids, self.output_slices)): + if (lora_b_i := lora_b[i]) is not None: + lora_b[i] = lora_b_i[:, shard_size * shard_id:shard_size * + (shard_id + 1)] + return lora_b + + def slice_bias( + self, bias: List[Union[torch.Tensor, + None]]) -> List[Union[torch.Tensor, None]]: + for i, (shard_id, shard_size) in enumerate( + zip(self.output_ids, self.output_slices)): + if (bias_i := bias[i]) is not None: + bias[i] = bias_i[shard_size * shard_id:shard_size * + (shard_id + 1)] + return bias + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + lora_bias: Optional[torch.Tensor] = None, + ): + self.reset_lora(index) + + if self.tp_size > 1: + lora_a = self.slice_lora_a(lora_a) + lora_b = self.slice_lora_b(lora_b) + if lora_bias is not None: + lora_bias = self.slice_bias(lora_bias) + + for i in range(self.n_slices): + if (lora_a_i := lora_a[i]) is not None: + self.lora_a_stacked[i][ + index, 0, :lora_a_i.shape[1], :lora_a_i.shape[0]].copy_( + lora_a_i.T, non_blocking=True) + if (lora_b_i := lora_b[i]) is not None: + self.lora_b_stacked[i][ + index, 0, :lora_b_i.shape[1], :lora_b_i.shape[0]].copy_( + lora_b_i.T, non_blocking=True) + + if lora_bias is not None: + self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...], + self.lora_bias_stacked) + for i in range(self.n_slices): + if (lora_bias_i := lora_bias[i]) is not None: + self.lora_bias_stacked[i][index, + 0, :lora_bias_i.shape[0]].copy_( + lora_bias_i.T, + non_blocking=True) + + @classmethod + 
@_not_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: + return (type(source_layer) is MergedColumnParallelLinear + and len(packed_modules_list) == 2) + + +class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): + """ + ColumnParallelLinear layer that is specifically designed for + qkv_proj. Certain models, such as chatglm3 and baichuan-7b, + only contains a single LoRA within their qkv_proj layer. + + During inference with Tensor Parallel, the weights of lora_b + must be accurately partitioned according to the respective ranks. + + Q slice may have different shape than K and V slices (which both have + the same shape). + """ + + def __init__(self, base_layer: QKVParallelLinear) -> None: + super().__init__(base_layer) + self.q_proj_total_size = (self.base_layer.total_num_heads * + self.base_layer.head_size) + self.q_proj_shard_size = (self.base_layer.num_heads * + self.base_layer.head_size) + self.kv_proj_shard_size = (self.base_layer.num_kv_heads * + self.base_layer.head_size) + self.kv_proj_total_size = (self.base_layer.total_num_kv_heads * + self.base_layer.head_size) + # There is only one LoRA layer + self.n_slices = 1 + + def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: + tp_rank = get_tensor_model_parallel_rank() + self.q_shard_id = tp_rank + self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas + lora_b_q = lora_b[:, self.q_proj_shard_size * + self.q_shard_id:self.q_proj_shard_size * + (self.q_shard_id + 1)] + k_offset = self.q_proj_total_size + lora_b_k = lora_b[:, k_offset + + self.kv_proj_shard_size * self.kv_shard_id:k_offset + + self.kv_proj_shard_size * (self.kv_shard_id + 1)] + v_offset = k_offset + self.kv_proj_total_size + lora_b_v = lora_b[:, v_offset + + self.kv_proj_shard_size * self.kv_shard_id:v_offset + + self.kv_proj_shard_size * (self.kv_shard_id + 1)] + lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=1) + return lora_b + + def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: + bias_q = bias[self.q_proj_shard_size * + self.q_shard_id:self.q_proj_shard_size * + (self.q_shard_id + 1)] + k_offset = self.q_proj_total_size + bias_k = bias[k_offset + + self.kv_proj_shard_size * self.kv_shard_id:k_offset + + self.kv_proj_shard_size * (self.kv_shard_id + 1)] + v_offset = k_offset + self.kv_proj_total_size + bias_v = bias[v_offset + + self.kv_proj_shard_size * self.kv_shard_id:v_offset + + self.kv_proj_shard_size * (self.kv_shard_id + 1)] + bias = torch.cat([bias_q, bias_k, bias_v], dim=1) + return bias + + @classmethod + @_not_fully_sharded_can_replace + def can_replace_layer(cls, source_layer: nn.Module, + lora_config: LoRAConfig, packed_modules_list: List, + model_config: Optional[PretrainedConfig]) -> bool: + return type(source_layer) is QKVParallelLinear and len( + packed_modules_list) == 1 + + +class MergedQKVParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA): + """MergedColumnParallelLinear layer that is composed of 3 sublayers (slices) + packed together in qkv proj fashion + (q_proj + k_proj + v_proj -> qkv_proj). + + This means we have 3 LoRAs, each applied to one slice of the layer. + + Q slice may have different shape than K and V slices (which both have + the same shape). + """ + + def __init__(self, base_layer: QKVParallelLinear) -> None: + super().__init__(base_layer) + # There are three LoRA layer. 
+ self.n_slices = len(self.base_layer.output_sizes) + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + + self.q_proj_shard_size = (self.base_layer.num_heads * + self.base_layer.head_size) + self.kv_proj_shard_size = (self.base_layer.num_kv_heads * + self.base_layer.head_size) + self.q_shard_id = self.tp_rank + self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas + + self.output_slices = ( + self.q_proj_shard_size, + self.kv_proj_shard_size, + self.kv_proj_shard_size, + ) + self.output_ids = ( + self.q_shard_id, + self.kv_shard_id, + self.kv_shard_id, + ) + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: + """ + The main reason for overloading this function is to handle inconsistent + weight dimensions in qkv lora. + """ + super().create_lora_weights(max_loras, lora_config, model_config) + + @classmethod + @_not_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: + return (type(source_layer) is QKVParallelLinear + and len(packed_modules_list) == 3) + + +class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA): + + def __init__(self, base_layer: RowParallelLinear) -> None: + super().__init__(base_layer) + + self.tp_size = get_tensor_model_parallel_world_size() + # reset input_size + self.input_size = self.base_layer.input_size_per_partition + self.output_size = self.base_layer.output_size + + self.tp_rank = get_tensor_model_parallel_rank() + # There is only one LoRA layer. + self.n_slices = 1 + + def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: + + shard_size = self.input_size + start_idx = self.tp_rank * shard_size + end_idx = (self.tp_rank + 1) * shard_size + lora_a = lora_a[start_idx:end_idx, :] + return lora_a + + def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: + return lora_b + + def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: + return bias + + def construct( + self, input_: torch.Tensor + ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: + """Forward of RowParallelLinear + + Args: + input_: tensor whose last dimension is `input_size`. If + `input_is_parallel` is set, then the last dimension + is `input_size // tp_size`. + + Returns: + - output + - bias + """ + # Set up backprop all-reduce. + if self.base_layer.input_is_parallel: + input_parallel = input_ + else: + # TODO: simplify code below + splitted_input = split_tensor_along_last_dim( + input_, num_partitions=self.base_layer.tp_size) + input_parallel = splitted_input[self.tp_rank].contiguous() + + # Matrix multiply. 
+ output_parallel = self.apply(input_parallel) + if self.base_layer.reduce_results and self.base_layer.tp_size > 1: + output_ = tensor_model_parallel_all_reduce(output_parallel) + else: + output_ = output_parallel + + if not self.base_layer.skip_bias_add: + output = (output_ + self.base_layer.bias + if self.base_layer.bias is not None else output_) + output_bias = None + else: + output = output_ + output_bias = self.base_layer.bias + return output, output_bias + + @property + def weight(self): + return (self.base_layer.weight if hasattr(self.base_layer, "weight") + else self.base_layer.qweight) + + @classmethod + @_not_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: + return type(source_layer) is RowParallelLinear + + +class LogitsProcessorWithLoRA(BaseLayerWithLoRA): + """ + LoRA wrapper for LogitsProcessor, with extra logic to handle the + application of the LoRA adapter and added LoRA vocabulary. + + Args: + base_layer: LogitsProcessor layer + hidden_size: hidden size of the model + dtype: data type of the model + device: device of the model + sharded_to_full_mapping: index mapping from sharded vocab to full vocab + received from base_layer.get_sharded_to_full_mapping(). If None, + no reindexing will be done. + """ + + def __init__(self, base_layer: LogitsProcessor, hidden_size: int, + dtype: torch.dtype, device: torch.device, + sharded_to_full_mapping: Optional[List[int]]) -> None: + super().__init__() + self.base_layer = base_layer + self.hidden_size = hidden_size + self.dtype = dtype + self.device = device + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + self.sharded_to_full_mapping = sharded_to_full_mapping + + @property + def logits_as_input(self): + return self.base_layer.logits_as_input + + @property + def vocab_size(self): + return self.base_layer.vocab_size + + @property + def scale(self): + return self.base_layer.scale + + @property + def soft_cap(self): + return self.base_layer.soft_cap + + @property + def use_all_gather(self): + return self.base_layer.use_all_gather + + @property + def org_vocab_size(self): + return self.base_layer.org_vocab_size + + @property + def include_gpu_probs_tensor(self): + return self.base_layer.include_gpu_probs_tensor + + @property + def should_modify_greedy_probs_inplace(self): + return self.base_layer.should_modify_greedy_probs_inplace + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: + # TODO: Verify if this condition can be further relaxed + if 32000 < self.base_layer.vocab_size > 257024: + raise ValueError("When using LoRA, vocab size must be " + "32000 >= vocab_size <= 257024") + self.lora_a_stacked = torch.zeros( + ( + max_loras, + 1, + lora_config.max_lora_rank, + self.hidden_size, + ), + dtype=lora_config.lora_dtype, + device=self.device, + ) + self.lora_b_stacked = torch.zeros( + ( + max_loras, + 1, + # Pad for kernel compatibility + math.ceil(self.base_layer.vocab_size / + lora_config.lora_vocab_padding_size) * + lora_config.lora_vocab_padding_size, + lora_config.max_lora_rank, + ), + dtype=lora_config.lora_dtype, + device=self.device, + ) + self.embeddings_tensors = torch.full( + (max_loras, lora_config.lora_extra_vocab_size, self.hidden_size), + fill_value=float("-inf"), + dtype=self.dtype, + device=self.device, + ) + if 
self.sharded_to_full_mapping is not None: + self.sharded_to_full_mapping_gpu = torch.tensor( + self.sharded_to_full_mapping, + device=self.device, + dtype=torch.long) + else: + self.sharded_to_full_mapping_gpu = None + + def reset_lora(self, index: int): + self.lora_a_stacked[index] = 0 + self.lora_b_stacked[index] = 0 + self.embeddings_tensors[index] = float("-inf") + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + bias: Optional[torch.Tensor] = None, + ): + self.reset_lora(index) + self.lora_a_stacked[index, + 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( + lora_a.T, non_blocking=True) + self.lora_b_stacked[index, + 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( + lora_b.T, non_blocking=True) + if embeddings_tensor is not None: + self.embeddings_tensors[ + index, :embeddings_tensor.shape[0], :embeddings_tensor. + shape[1], ] = embeddings_tensor + + def _get_logits( + self, + hidden_states: torch.Tensor, + lm_head: VocabParallelEmbedding, + embedding_bias: Optional[torch.Tensor] = None, + ) -> Optional[torch.Tensor]: + # Get the logits for the next tokens. + logits = lm_head.quant_method.apply(lm_head, hidden_states) + if embedding_bias is not None: + logits += embedding_bias + + # Gather logits for TP + logits = self.base_layer._gather_logits(logits) + + if logits is None: + return None + + if self.sharded_to_full_mapping_gpu is not None: + # Reindex full logits tensor to ensure 1:1 mapping between + # index and token_id + # Example for: + # org_vocab_size = 4 + # added_vocab_size = 2 + # pad_to_size = 8 + # tp_size = 2 + + # indices: [0, 1, 2, 3, 4, 5, 6, 7] + # token_id: [0, 1, 4, -1, 2, 3, 5, -1] + + # Therefore, the mapping is expected to be: + # [0, 1, 4, 6, 2, 3, 5, 7] so that when we reindex, + # we get: + # indices: [0, 1, 2, 3, 4, 5, 6, 7] + # token_id: [0, 1, 2, 3, 4, 5, -1, -1] + logits = logits[:, self.sharded_to_full_mapping_gpu] + + lora_logits = torch.empty( + self.embeddings_tensors.shape[0] + 1, + self.embeddings_tensors.shape[1], + hidden_states.shape[0], + dtype=self.embeddings_tensors.dtype, + device=self.embeddings_tensors.device, + ) + torch.matmul(self.embeddings_tensors, + hidden_states.T, + out=lora_logits[:-1]) + lora_logits[-1] = float("-inf") + lora_logits = lora_logits.mT + indices_padded = self.punica_wrapper.sampler_indices_padded + lora_logits = (lora_logits.reshape( + lora_logits.shape[0] * lora_logits.shape[1], + lora_logits.shape[2], + ).index_select(0, indices_padded).nan_to_num_(nan=float("-inf"), + posinf=float("inf"), + neginf=float("-inf"))) + + logits[:, + self.base_layer.org_vocab_size:self.base_layer.org_vocab_size + + lora_logits.shape[1]] = lora_logits + + # LogitsProcessorWithLoRA always using bgmv + self.punica_wrapper.add_lora_logits(logits, hidden_states, + self.lora_a_stacked, + self.lora_b_stacked, 1.0) + + # Remove paddings in vocab (if any). + logits = logits[:, :self.base_layer.vocab_size] + return logits + + def construct(self, *args, **kwargs): + return type(self.base_layer).forward(self, *args, **kwargs) + + @classmethod + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: + # Special handling for the LogitsProcessor. + return False + + +class LinearScalingRotaryEmbeddingWithLoRA(BaseLayerWithLoRA): + """Implements RoPE-scaled embeddings with linear scaling for + multiple LoRA adapters with a specialized kernel. 
+ + Replace LinearScalingRotaryEmbedding with MultiLinearScalingRotaryEmbedding + which can handle multi lora adapters in a specialied kernel. + """ + + def __init__(self, base_layer: RotaryEmbedding) -> None: + super().__init__() + self.base_layer = base_layer + + @property + def scaling_factors(self): + return self.base_layer.scaling_factors + + @property + def rotary_dim(self): + return self.base_layer.rotary_dim + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: + scaling_factors = (list(lora_config.long_lora_scaling_factors) + if lora_config.long_lora_scaling_factors else []) + base_scaling_factor = (self.base_layer.scaling_factor if isinstance( + self.base_layer, LinearScalingRotaryEmbedding) else 1.0) + scaling_factors = sorted( + list(set([base_scaling_factor] + scaling_factors))) + self.base_layer = LinearScalingRotaryEmbedding( + self.base_layer.head_size, + self.base_layer.rotary_dim, + self.base_layer.max_position_embeddings, + self.base_layer.base, + self.base_layer.is_neox_style, + scaling_factors, + self.base_layer.dtype, + ) + + def reset_lora(self, index: int): + ... + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + bias: Optional[torch.Tensor] = None, + ): + ... + + def construct( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + ): + return self.base_layer( + positions, + query, + key, + offsets=self.punica_wrapper.long_lora_indices, + ) + + @property + def scaling_factor_to_offset(self) -> Dict[float, int]: + return self.base_layer.scaling_factor_to_offset + + @classmethod + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: + """Returns True if the layer can be replaced by this LoRA layer.""" + return (type(source_layer) is LinearScalingRotaryEmbedding + or type(source_layer) is RotaryEmbedding) + + def extra_repr(self) -> str: + return self.base_layer.extra_repr() diff --git a/vllm_mindspore/lora/models.py b/vllm_mindspore/lora/models.py new file mode 100644 index 00000000..92197849 --- /dev/null +++ b/vllm_mindspore/lora/models.py @@ -0,0 +1,227 @@ +#!/usr/bin/env python3 +# Copyright 2025 Huawei Technologies Co., Ltd +# Copyright 2024 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +import os +from typing import Dict, List, Optional, Union + +import safetensors.torch +import torch +from vllm.lora.lora import LoRALayerWeights +from vllm.lora.peft_helper import PEFTHelper +from vllm.lora.utils import is_regex_target_modules, parse_fine_tuned_lora_name +from vllm.model_executor.models.utils import WeightsMapper +from vllm.utils import is_pin_memory_available + +from vllm_mindspore.lora.layers import BaseLayerWithLoRA + +_GLOBAL_LORA_ID = 0 + + +def get_lora_id(): + global _GLOBAL_LORA_ID + _GLOBAL_LORA_ID += 1 + return _GLOBAL_LORA_ID + + +def register_module(self, module_name: str, module: "BaseLayerWithLoRA"): + assert isinstance(module, BaseLayerWithLoRA) + self.modules[module_name] = module + + +@classmethod #type:ignore +def from_lora_tensors( + cls, + lora_model_id: int, + tensors: Dict[str, torch.Tensor], + peft_helper: PEFTHelper, + device: str = "cuda", + dtype: Optional[torch.dtype] = None, + embeddings: Optional[Dict[str, torch.Tensor]] = None, + target_embedding_padding: Optional[int] = None, + embedding_modules: Optional[Dict[str, str]] = None, + embedding_padding_modules: Optional[List[str]] = None, + weights_mapper: Optional[WeightsMapper] = None, +): + """Create a LoRAModel from a dictionary of tensors.""" + pin_memory = str(device) == "cpu" and is_pin_memory_available() + loras: Dict[str, LoRALayerWeights] = {} + for tensor_name, tensor in tensors.items(): + module_name, is_lora_a, is_bias = parse_fine_tuned_lora_name( + tensor_name, weights_mapper) + if module_name not in loras: + lora_embeddings_tensor = None + if embeddings: + assert embedding_modules is not None + embeddings_module = next( + (k for k in embedding_modules if k in module_name), None) + if embeddings_module: + lora_embeddings_tensor = embeddings[ + embedding_modules[embeddings_module]] + if pin_memory: + lora_embeddings_tensor = ( + lora_embeddings_tensor.pin_memory()) + loras[module_name] = LoRALayerWeights.from_config( + module_name, peft_helper, lora_embeddings_tensor) + + if is_bias: + # vllm-mindspore remove tensor device + loras[module_name].bias = tensor.to(dtype=dtype).t() + bias = tensor.to(dtype=dtype).t() + if pin_memory: + bias = bias.pin_memory() + loras[module_name].bias = bias + elif is_lora_a: + loras[module_name].lora_a = tensor.to(dtype=dtype).t() + if pin_memory: + loras[module_name].lora_a = loras[ + module_name].lora_a.pin_memory() + else: + loras[module_name].lora_b = tensor.to(dtype=dtype).t() + assert embedding_padding_modules is not None + if any(name in module_name for name in embedding_padding_modules + ) and target_embedding_padding is not None: + lora_b = loras[module_name].lora_b + assert target_embedding_padding >= lora_b.shape[1] + addition = target_embedding_padding - lora_b.shape[1] + loras[module_name].lora_b = torch.nn.functional.pad( + lora_b, (0, addition)) + if pin_memory: + loras[module_name].lora_b = loras[ + module_name].lora_b.pin_memory() + + for lora in loras.values(): + lora.optimize() + + return cls(lora_model_id, + peft_helper.r, + loras, + scaling_factor=peft_helper.vllm_long_context_scaling_factor) + + +@classmethod #type:ignore +def from_local_checkpoint( + cls, + lora_dir: str, + expected_lora_modules: List[str], + peft_helper: PEFTHelper, + *, + lora_model_id: Optional[int] = None, + device: str = "cuda", + dtype: Optional[torch.dtype] = None, + target_embedding_padding: Optional[int] = None, + embedding_modules: Optional[Dict[str, str]] = None, + 
embedding_padding_modules: Optional[List[str]] = None, + weights_mapper: Optional[WeightsMapper] = None, +): + """Create a LoRAModel from a local checkpoint. + + Args: + lora_dir: The local path that has lora data. + expected_lora_modules: Name of modules that are expected to be + replaced by lora. + peft_helper: Loaded lora configuration information. + lora_model_id: Lora model id. If not given, automatically set by + a global counter. + device: Device where the lora model is loaded. + dtype: dtype of the lora model weights. + + Returns: + Loaded LoRA Model. + """ + lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors") + lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin") + new_embeddings_tensor_path = os.path.join(lora_dir, + "new_embeddings.safetensors") + new_embeddings_bin_file_path = os.path.join(lora_dir, "new_embeddings.bin") + + unexpected_modules: List[Union[list[str], str]] + if os.path.isfile(lora_tensor_path): + tensors: Dict[str, torch.Tensor] = {} + # Find unexpected modules. + # Use safetensor key as a source of truth to find expected modules. + # in peft if you have target_modules A, B, C and C does not exist + # in the model it won’t error and model will be trained with A, B + # loraified. C won’t exist in the safetensor but it will exist in + # the target_modules of the adapter_config.json. + unexpected_modules = [] + # vllm-mindspore safetensors open with np + with safetensors.safe_open(lora_tensor_path, + framework="np") as f: # type: ignore + for lora_module in f.keys(): # noqa + module_name, _, _ = parse_fine_tuned_lora_name( + lora_module, weights_mapper) + part_name = module_name.split(".")[-1] + if part_name not in expected_lora_modules: + unexpected_modules.append(module_name) + if unexpected_modules: + raise ValueError( + f"While loading {lora_dir}, expected" + f" target modules in {expected_lora_modules}" + f" but received {unexpected_modules}." + f" Please verify that the loaded LoRA module is correct") + # Load tensors if there are only expected modules. + for module in f.keys(): # noqa + # vllm-mindspore add numpy to tensor + tensors[module] = torch.Tensor(f.get_tensor(module)) + elif os.path.isfile(lora_bin_file_path): + # When a bin file is provided, we rely on config to find unexpected + # modules. + unexpected_modules = [] + target_modules = peft_helper.target_modules + if not isinstance(target_modules, list): + target_modules = [target_modules] + for module in target_modules: + # Compatible with more modules, + # such as:layers.11.self_attn.k_proj + part_name = module.split(".")[-1] + if part_name not in expected_lora_modules: + unexpected_modules.append(module) + # loaded lora's target modules must be a subset of + # expected_lora_modules. It is not reliable. See + # https://github.com/vllm-project/vllm/pull/5909. But there's no + # other better mechanism. + if unexpected_modules and not is_regex_target_modules( + peft_helper.target_modules, expected_lora_modules): + raise ValueError( + f"While loading {lora_dir}, expected" + f" target modules in {expected_lora_modules}" + f" but received {unexpected_modules}." 
+ f" Please verify that the loaded LoRA module is correct") + tensors = torch.load(lora_bin_file_path, map_location=device) + else: + raise ValueError(f"{lora_dir} doesn't contain tensors") + + embeddings = None + if os.path.isfile(new_embeddings_tensor_path): + embeddings = safetensors.torch.load_file(new_embeddings_tensor_path) + elif os.path.isfile(new_embeddings_bin_file_path): + embeddings = torch.load(new_embeddings_bin_file_path, + map_location=device, + weights_only=True) + + return cls.from_lora_tensors( + lora_model_id=get_lora_id() + if lora_model_id is None else lora_model_id, + tensors=tensors, + peft_helper=peft_helper, + device=device, + dtype=dtype, + embeddings=embeddings, + target_embedding_padding=target_embedding_padding, + embedding_modules=embedding_modules, + embedding_padding_modules=embedding_padding_modules, + weights_mapper=weights_mapper) diff --git a/vllm_mindspore/lora/ops/__init__.py b/vllm_mindspore/lora/ops/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/vllm_mindspore/lora/ops/torch_ops/__init__.py b/vllm_mindspore/lora/ops/torch_ops/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/vllm_mindspore/lora/ops/torch_ops/lora_ops.py b/vllm_mindspore/lora/ops/torch_ops/lora_ops.py new file mode 100644 index 00000000..d085c34e --- /dev/null +++ b/vllm_mindspore/lora/ops/torch_ops/lora_ops.py @@ -0,0 +1,171 @@ +#!/usr/bin/env python3 +# Copyright 2025 Huawei Technologies Co., Ltd +# Copyright 2024 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +""" +For punica_npu +""" +from mindspore import mint +from mindspore.ops.auto_generate import grouped_matmul_v4 + +def einsum_ms(inputs, selected_loras): + # mint.einsum("bi, boi -> bo", inputs, selected_loras) + selected_loras = mint.transpose(selected_loras, 1, 2) + outputs = mint.matmul(inputs.unsqueeze(1), selected_loras).squeeze(1) + return outputs + +def sort_lora_by_token_count(lora_indices_tensor, seq_len_tensor): + unique_ids = mint.unique(lora_indices_tensor) + token_sums = [] + for uid in unique_ids: + mask = (lora_indices_tensor == uid) + total_tokens = mint.sum(seq_len_tensor[mask]) + token_sums.append(total_tokens) + token_sums_tensor = mint.stack(token_sums) + sorted_counts, sort_indices = mint.sort(token_sums_tensor, descending=True) + sorted_ids = unique_ids[sort_indices] + return sorted_ids, sorted_counts + +def sgmv_expand(inputs, + lora_b_weights, + output_tensor, + b_seq_start_loc, + seq_len_tensor, + lora_indices_tensor, + batches, + max_seq_length, + token_nums, + add_inputs = False): + exploded_indices = mint.repeat_interleave(lora_indices_tensor, + seq_len_tensor) + + return bgmv_expand(inputs, lora_b_weights, output_tensor, exploded_indices, + add_inputs) + + +def bgmv_expand(inputs, + lora_b_weights, + output_tensor, + lora_indices_tensor, + add_inputs = True): + selected_loras = lora_b_weights[lora_indices_tensor].astype(output_tensor.dtype) + inputs = inputs.astype(output_tensor.dtype) + if len(selected_loras.shape) == 4: + selected_loras = selected_loras.squeeze(1) + outputs = einsum_ms(inputs, selected_loras) + limit = output_tensor.shape[0] + if outputs.shape[0] == 1 and output_tensor.shape[0] != 1: + limit = 1 + if add_inputs: + output_tensor[:, :outputs.shape[1]] += outputs[:limit, :] + else: + output_tensor[:, :outputs.shape[1]] = outputs[:limit, :] + return output_tensor + + +def sgmv_shrink( + inputs, + lora_a_weights, + output_tensor, + b_seq_start_loc, + seq_len_tensor, + lora_indices_tensor, + batches, + max_seq_length, + token_nums, + scaling, +): + group_list = seq_len_tensor + if (lora_indices_tensor.unique().shape[0] != lora_indices_tensor.shape[0]): + sorted_ids, sorted_counts = sort_lora_by_token_count(lora_indices_tensor, seq_len_tensor) + group_list = sorted_counts + if lora_a_weights.shape[0] != group_list.shape[0]: + new_tensor = mint.zeros(lora_a_weights.shape[0], dtype=group_list.dtype) + new_tensor[:group_list.size(0)] = group_list + group_list = new_tensor + if len(lora_a_weights.shape) == 4: + lora_a_weights = lora_a_weights.squeeze(1) + lora_a_weights = mint.transpose(lora_a_weights, 1, 2) + outputs = grouped_matmul_v4([inputs], [lora_a_weights], group_list=group_list, split_item=3, group_type=0, group_list_type=1) + outputs = outputs[0] + output_tensor[:, :outputs.shape[1]] = scaling * outputs[:] + return output_tensor + + +def bgmv_shrink(inputs, + lora_b_weights, + output_tensor, + lora_indices_tensor, + scaling = 1.0): + selected_loras = lora_b_weights[lora_indices_tensor].astype(output_tensor.dtype) + inputs = inputs.astype(output_tensor.dtype) + if len(selected_loras.shape) == 4: + selected_loras = selected_loras.squeeze(1) + outputs = einsum_ms(inputs, selected_loras) + output_tensor[:, :outputs.shape[1]] = scaling * outputs[:] + return output_tensor + + +def sgmv_expand_slice(inputs, + lora_b_weights, + output_tensor, + b_seq_start_loc, + seq_len_tensor, + lora_indices_tensor, + batches, + max_seq_length, + token_nums, + slice_offset, + 
slice_size, + add_inputs = False): + group_list = seq_len_tensor + if (lora_indices_tensor.unique().shape[0] != lora_indices_tensor.shape[0]): + sorted_ids, sorted_counts = sort_lora_by_token_count(lora_indices_tensor, seq_len_tensor) + group_list = sorted_counts + if lora_b_weights.shape[0] != group_list.shape[0]: + new_tensor = mint.zeros(lora_b_weights.shape[0], dtype=group_list.dtype) + new_tensor[:group_list.size(0)] = group_list + group_list = new_tensor + if len(lora_b_weights.shape) == 4: + lora_b_weights = lora_b_weights.squeeze(1) + lora_b_weights = mint.transpose(lora_b_weights, 1, 2) + inputs = inputs.astype(output_tensor.dtype) + outputs = grouped_matmul_v4([inputs], [lora_b_weights], group_list=group_list, split_item=3, group_type=0, group_list_type=1) + outputs = outputs[0] + if add_inputs: + output_tensor[:, slice_offset:slice_offset + slice_size] += outputs[:] + else: + output_tensor[:, slice_offset:slice_offset + slice_size] = outputs[:] + return output_tensor + + +def bgmv_expand_slice(inputs, + lora_b_weights, + output_tensor, + lora_indices_tensor, + slice_offset, + slice_size, + add_inputs = True): + selected_loras = lora_b_weights[lora_indices_tensor].astype(output_tensor.dtype) + inputs = inputs.astype(output_tensor.dtype) + if len(selected_loras.shape) == 4: + selected_loras = selected_loras.squeeze(1) + outputs = einsum_ms(inputs, selected_loras) + if add_inputs: + output_tensor[:, slice_offset:slice_offset + slice_size] += outputs[:] + else: + output_tensor[:, slice_offset:slice_offset + slice_size] = outputs[:] + return output_tensor \ No newline at end of file diff --git a/vllm_mindspore/lora/punica_wrapper/__init__.py b/vllm_mindspore/lora/punica_wrapper/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/vllm_mindspore/lora/punica_wrapper/punica_npu.py b/vllm_mindspore/lora/punica_wrapper/punica_npu.py new file mode 100644 index 00000000..51b41b15 --- /dev/null +++ b/vllm_mindspore/lora/punica_wrapper/punica_npu.py @@ -0,0 +1,357 @@ +#!/usr/bin/env python3 +# Copyright 2025 Huawei Technologies Co., Ltd +# Copyright 2024 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +""" +refer to https://github.com/vllm-project/vllm-ascend/blob/v0.7.3/vllm_ascend/lora/punica_wrapper/punica_npu.py +""" +from typing import Callable + +from mindspore import mint +from mindspore.common import dtype as mstype +from vllm_mindspore.lora.ops.torch_ops.lora_ops import (bgmv_expand, bgmv_expand_slice, + bgmv_shrink, sgmv_expand, + sgmv_expand_slice, sgmv_shrink) +from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase + + +# The platforms that are compatible with the PyTorch-native implementation can +# inherit this class +class PunicaWrapperNPU(PunicaWrapperBase): + """ + PunicaWrapperNPU is designed to manage and provide metadata for the punica + kernel. 
The main function is to maintain the state information for + Multi-LoRA, and to provide the interface for the pytorch punica ops. + """ + + def __init__(self, max_num_batched_tokens, max_batches, device, **kwargs): + PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, + device) + + def _shrink_prefill( + self, + y, + x, + w_t_all, + scale, + ): + sgmv_shrink( + x, + w_t_all, + y, + *self.prefill_metadata, + scale, + ) + + def _shrink_decode( + self, + y, + x, + w_t_all, + scale, + ): + bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale) + + def _expand_prefill( + self, + y, + x, + w_t_all, + add_inputs, + ): + sgmv_expand( + x, + w_t_all, + y, + *self.prefill_metadata, + add_inputs, + ) + + def _expand_decode( + self, + y, + x, + w_t_all, + add_inputs, + ): + bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_inputs) + + def _expand_slice_prefill( + self, + y, + x, + w_t_all, + y_offset, + y_slice_size, + add_inputs, + ): + sgmv_expand_slice( + x, + w_t_all, + y, + *self.prefill_metadata, + y_offset, + y_slice_size, + add_inputs, + ) + + def _expand_slice_decode( + self, + y, + x, + w_t_all, + y_offset, + y_slice_size, + add_inputs, + ): + bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset, + y_slice_size, add_inputs) + + def _apply_expand( + self, + y, + x, + w_t_all, + y_offset, + y_slice_size, + add_inputs, + ): + """ + Perform the ` y[:,y_offset:y_offset+y_slice_size]+=x@w_t_all` + computation, which is suitable for the + GEMM of lora'b. + """ + + expand_slice_fun: Callable = (self._expand_slice_prefill + if self.is_prefill else + self._expand_slice_decode) + expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_inputs) + + def _apply_shrink(self, y, x, w_t_all, scale): + """ + Perform the ` y+=x@w_t_all` computation, which is suitable for the + GEMM of lora'a. + When `is_prefill is` true, it indicates that it is currently the + prefill stage, and the `_shrink_prefill` function should be called. + Otherwise, it is the decode stage, and the _shrink_decode function + should be called. + """ + y_org = y + y = y.view(-1, y.shape[-1]) + shrink_fun: Callable = (self._shrink_prefill + if self.is_prefill else self._shrink_decode) + shrink_fun(y, x, w_t_all, scale) + y.view_as(y_org) + + def add_shrink(self, y, x, lora_a_stacked, scale, **kwargs): + """ + Performs GEMM for multiple slices of lora_a. + When `is_prefill is` true, it indicates that it is currently the + prefill stage, and the `_shrink_prefill` function should be called. + Otherwise, it is the decode stage, and the _shrink_decode function + should be called. + + Semantics: + for i in range(len(lora_a_stacked)): + y[i] += (x @ lora_a_stacked[i]) * scale + + Args: + y (Union[Tuple[ms.Tensor, ...], ms.Tensor]): Output tensors + x (ms.Tensor): Input tensor + lora_a_stacked (Tuple[ms.Tensor, ...]): lora_a's weights + scale (float): Scaling factor for the operation + """ + + x = x.view(-1, x.shape[-1]) + # TODO fuse these kernels + for slice_idx in range(len(lora_a_stacked)): + self._apply_shrink(y[slice_idx], x, lora_a_stacked[slice_idx], + scale) + + def add_expand(self, + y, + x, + lora_b_stacked, + lora_bias_stacked, + output_slices, + offset_start=0, + add_inputs=True, + **kwargs) -> None: + """ + Performs GEMM and bias addition for multiple slices of lora_b. + + Semantics: + for i in range(len(lora_b_stacked)): + slice = output_slices[i] + y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] + + lora_bias_stacked[i] + offset += slice + + Args: + y (ms.Tensor): Output tensor. 
+ x (Union[Tuple[ms.Tensor, ...], ms.Tensor]): Input tensors + lora_b_stacked (Tuple[ms.Tensor, ...]): lora_b's weight + lora_bias_stacked (Optional[Tuple[ms.Tensor, ...]]): + bias's weight + output_slices (Tuple[int, ...]): Every slice's size + add_inputs (bool): Defaults to True. + """ + y_org = y + y = y.view(-1, y.shape[-1]) + offset_left = offset_start + if lora_bias_stacked is not None: + self._apply_bias(self.token_lora_indices, y, output_slices, + lora_bias_stacked) + for slice_idx in range(len(lora_b_stacked)): + self._apply_expand( + y, + x[slice_idx], + lora_b_stacked[slice_idx], + offset_left, + output_slices[slice_idx], + add_inputs=add_inputs, + ) + offset_left += output_slices[slice_idx] + y.view_as(y_org) + + def add_lora_embedding(self, + y, + x, + lora_b_stacked, + add_inputs=True, + **kwargs) -> None: + """ + Applies lora specifically for VocabParallelEmbeddingWithLoRA. + + Semantics: + y += x @ lora_b_stacked + + Args: + y (ms.Tensor): Output tensor. + x (ms.Tensor): Input tensor. + lora_b_stacked (ms.Tensor): lora_b's weights. + add_inputs (bool): Default to True. + """ + #No LoRA request, so return directly + if self.no_lora: + return + # Embedding layer only need expand op + expand_fun: Callable = (self._expand_prefill + if self.is_prefill else self._expand_decode) + expand_fun(y, x, lora_b_stacked, add_inputs) + + def add_lora_linear(self, + y, + x, + lora_a_stacked, + lora_b_stacked, + lora_bias_stacked, + scale, + output_slices, + *, + buffer=None, + **kwargs) -> None: + """ + Applicable to linear-related lora. + + Semantics: + for i in range(len(lora_a_stacked)): + y[i] += ( + x[i].unsqueeze(0) + @ lora_a_stacked[indices[i], layer_idx, :, :] + @ lora_b_stacked[indices[i], layer_idx, :, :] + * scale + ).squeeze(0)+lora_bias_stacked[i] + + Args: + y (ms.Tensor): Output tensor. Will be changed in-place. + x (ms.Tensor): Input tensor + lora_a_stacked (Tuple[ms.Tensor, ...]): lora_a's weight. + lora_b_stacked (Tuple[ms.Tensor, ...]): lora_b's weight. + lora_bias_stacked (Optional[Tuple[ms.Tensor, ...]]): lora's bias. + scale (float): Scaling factor. + output_slices (Tuple[int, ...]): Every slice's size. + buffer (Optional[Tuple[ms.Tensor, ...]]): Defaults to None. + """ + #No LoRA request, so return directly + if self.no_lora: + return + x = x.reshape(-1, x.shape[-1]) + assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices) + if lora_bias_stacked is not None: + assert len(lora_bias_stacked) == len(output_slices) + y = self._apply_bias(self.token_lora_indices, y, output_slices, + lora_bias_stacked) + + if buffer is None: + r = lora_b_stacked[0].shape[-1] + # We set the buffer to be float32 by default, consistent with the + # triton op + buffer = tuple( + mint.zeros((x.shape[0], r), dtype=mstype.float32) + for _ in range(len(output_slices))) + self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs) + self.add_expand(y, + buffer, + lora_b_stacked, + None, + output_slices, + add_inputs=True, + **kwargs) + + def add_lora_logits(self, + y, + x, + lora_a_stacked, + lora_b_stacked, + scale, + *, + buffer=None, + **kwargs) -> None: + """ + Applies lora specifically for LogitsProcessorWithLoRA. + + Semantics: + buffer = (x @ lora_a_stacked) * scale + y += buffer @ lora_b_stacked + + Args: + y (ms.Tensor): Output tensor. + x (ms.Tensor): Input tensor. + lora_a_stacked (ms.Tensor): lora_a's weights. + lora_b_stacked (ms.Tensor):lora_b's weights. + scale (float): Scaling factor. + buffer (Optional[ms.Tensor]):Default to None. 
+ """ + #No LoRA request, so return directly + if self.no_lora: + return + y_org = y + y = y.view(-1, y.shape[-1]) + x = x.view(-1, x.shape[-1]) + r = lora_b_stacked.shape[-1] + if buffer is None: + # We set the buffer to be float32 by default, consistent with the + # triton op + buffer = mint.zeros((x.shape[0], r), dtype=mstype.float32) + # LogitsProcessorWithLoRA always using bgmv. + bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices, scale) + bgmv_expand(buffer, + lora_b_stacked, + y, + self.sampler_indices, + add_inputs=True) + y.view_as(y_org) diff --git a/vllm_mindspore/lora/utils.py b/vllm_mindspore/lora/utils.py new file mode 100644 index 00000000..0084e607 --- /dev/null +++ b/vllm_mindspore/lora/utils.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python3 +# Copyright 2025 Huawei Technologies Co., Ltd +# Copyright 2024 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +from typing import Set, Type + +from vllm.lora.fully_sharded_layers import ( + ColumnParallelLinearWithShardedLoRA, + MergedColumnParallelLinearWithShardedLoRA, + MergedQKVParallelLinearWithShardedLoRA, QKVParallelLinearWithShardedLoRA, + RowParallelLinearWithShardedLoRA) + +from vllm_mindspore.lora.layers import ( + BaseLayerWithLoRA, ColumnParallelLinearWithLoRA, + LinearScalingRotaryEmbeddingWithLoRA, LogitsProcessorWithLoRA, + MergedColumnParallelLinearWithLoRA, MergedQKVParallelLinearWithLoRA, + QKVParallelLinearWithLoRA, RowParallelLinearWithLoRA, + VocabParallelEmbeddingWithLoRA) + +_all_lora_classes: Set[Type[BaseLayerWithLoRA]] = { + VocabParallelEmbeddingWithLoRA, + ColumnParallelLinearWithLoRA, + MergedColumnParallelLinearWithLoRA, + QKVParallelLinearWithLoRA, + MergedQKVParallelLinearWithLoRA, + RowParallelLinearWithLoRA, + LogitsProcessorWithLoRA, + ColumnParallelLinearWithShardedLoRA, + QKVParallelLinearWithShardedLoRA, + MergedColumnParallelLinearWithShardedLoRA, + MergedQKVParallelLinearWithShardedLoRA, + RowParallelLinearWithShardedLoRA, + LinearScalingRotaryEmbeddingWithLoRA, +} diff --git a/vllm_mindspore/model_executor/layers/rotary_embedding.py b/vllm_mindspore/model_executor/layers/rotary_embedding.py index c9dfe254..425714dc 100644 --- a/vllm_mindspore/model_executor/layers/rotary_embedding.py +++ b/vllm_mindspore/model_executor/layers/rotary_embedding.py @@ -1,5 +1,4 @@ #!/usr/bin/env python3 -# encoding: utf-8 # Copyright 2025 Huawei Technologies Co., Ltd # Copyright 2024 The vLLM team. # @@ -58,6 +57,7 @@ def _apply_rotary_emb( class RotaryEmbedding(CustomOp): + def __init__( self, head_size: int, @@ -86,10 +86,8 @@ class RotaryEmbedding(CustomOp): # use CPU to compute the cache and then move it to GPU. However, we # create the cache on GPU for faster initialization. This may cause # a slight numerical difference between the HF implementation and ours. 
- inv_freq = 1.0 / ( - base - ** (mint.arange(0, self.rotary_dim, 2, dtype=mstype.float32) / self.rotary_dim) - ) + inv_freq = 1.0 / (base**(mint.arange( + 0, self.rotary_dim, 2, dtype=mstype.float32) / self.rotary_dim)) return inv_freq def _compute_cos_sin_cache(self) -> Tensor: @@ -121,14 +119,14 @@ class RotaryEmbedding(CustomOp): query_shape = query.shape query = query.view(num_tokens, -1, self.head_size) - query_rot = query[..., : self.rotary_dim] + query_rot = query[..., :self.rotary_dim] query_pass = query[..., self.rotary_dim:] query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style) query = mint.cat((query_rot, query_pass), dim=-1).reshape(query_shape) key_shape = key.shape key = key.view(num_tokens, -1, self.head_size) - key_rot = key[..., : self.rotary_dim] + key_rot = key[..., :self.rotary_dim] key_pass = key[..., self.rotary_dim:] key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style) key = mint.cat((key_rot, key_pass), dim=-1).reshape(key_shape) @@ -136,6 +134,7 @@ class RotaryEmbedding(CustomOp): class InferRotaryEmbedding(CustomOp): + def __init__( self, head_size: int, @@ -146,8 +145,9 @@ class InferRotaryEmbedding(CustomOp): dtype, ) -> None: super().__init__() - freqs_base = np.arange(0, rotary_dim, 2)[: (rotary_dim // 2)].astype(np.float32) # (head_dim // 2, ) - freqs = 1.0 / (base ** (freqs_base / rotary_dim)) # (head_dim // 2, ) + freqs_base = np.arange(0, rotary_dim, 2)[:(rotary_dim // 2)].astype( + np.float32) # (head_dim // 2, ) + freqs = 1.0 / (base**(freqs_base / rotary_dim)) # (head_dim // 2, ) mscale = 1.0 t = np.arange(0, max_position_embeddings, 1).astype(np.float32) @@ -170,12 +170,17 @@ class InferRotaryEmbedding(CustomOp): is_prefill: bool, offsets: Optional[Tensor] = None, ) -> Tuple[Tensor, Tensor]: + # ensure that the input tensors of rotary_embedding_op is contiguous + query = query.contiguous() + key = key.contiguous() if is_prefill: - return self.rotary_embedding_op(query, key, self.freqs_cos, self.freqs_sin, batch_valid_length) + return self.rotary_embedding_op(query, key, self.freqs_cos, + self.freqs_sin, batch_valid_length) freqs_cos = self.gather(self.freqs_cos, positions, 0) freqs_sin = self.gather(self.freqs_sin, positions, 0) - return self.rotary_embedding_op(query, key, freqs_cos, freqs_sin, batch_valid_length) + return self.rotary_embedding_op(query, key, freqs_cos, freqs_sin, + batch_valid_length) class MRotaryEmbedding(RotaryEmbedding): diff --git a/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py b/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py index e3407f51..b694075d 100644 --- a/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py @@ -1,5 +1,4 @@ #!/usr/bin/env python3 -# encoding: utf-8 # Copyright 2025 Huawei Technologies Co., Ltd # Copyright 2024 The vLLM team. 
# @@ -23,16 +22,15 @@ from mindspore import Parameter, Tensor, mint, nn, ops from mindspore.common import dtype as mstype from mindspore.common.dtype import typing from vllm.distributed import (divide, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce,) -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) + get_tensor_model_parallel_world_size) +from vllm.model_executor.layers.quantization.base_config import \ + QuantizationConfig +from vllm_mindspore.distributed.communication_op import \ + ReduceFromModelParallelRegion from vllm_mindspore.model_executor.layers.quantization.base_config import ( QuantizeMethodBase, method_has_implemented_embedding) from vllm_mindspore.model_executor.utils import set_weight_attrs -from vllm_mindspore.distributed.communication_op import ReduceFromModelParallelRegion -from mindspore import jit DEFAULT_VOCAB_PADDING_SIZE = 64 @@ -40,15 +38,13 @@ DEFAULT_VOCAB_PADDING_SIZE = 64 class UnquantizedEmbeddingMethod(QuantizeMethodBase): """Unquantized method for embeddings.""" - def create_weights(self, layer: nn.Cell, - input_size_per_partition: int, + def create_weights(self, layer: nn.Cell, input_size_per_partition: int, output_partition_sizes: List[int], input_size: int, - output_size: int, params_dtype, - **extra_weight_attrs): + output_size: int, params_dtype, **extra_weight_attrs): """Create weights for embedding layer.""" - weight = Parameter(mint.zeros((sum(output_partition_sizes), - input_size_per_partition), - dtype=params_dtype), + weight = Parameter(mint.zeros( + (sum(output_partition_sizes), input_size_per_partition), + dtype=params_dtype), requires_grad=False) set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) layer.insert_param_to_cell("weight", weight) @@ -64,7 +60,7 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase): layer: nn.Cell, x: Tensor, bias: Optional[Tensor] = None) -> Tensor: - output_shape = x.shape[:-1] + (self.output_size_per_partition,) + output_shape = x.shape[:-1] + (self.output_size_per_partition, ) x = x.reshape(-1, self.input_size_per_partition) x = self.matmul(x, layer.weight) if bias is not None: @@ -72,8 +68,7 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase): x = x.reshape(output_shape) return x - def embedding(self, layer: nn.Cell, - input_: Tensor) -> Tensor: + def embedding(self, layer: nn.Cell, input_: Tensor) -> Tensor: return self.gather(layer.weight, input_, 0) @@ -87,12 +82,15 @@ def get_masked_input_and_mask( ) -> Tuple[Tensor, Tensor]: displaced_x = mint.sub(input_, org_vocab_start_index) down_truncated_x = mint.nn.functional.relu(displaced_x) - truncated_x = mint.minimum(down_truncated_x, (org_vocab_end_index - org_vocab_start_index - 1)) + truncated_x = mint.minimum( + down_truncated_x, (org_vocab_end_index - org_vocab_start_index - 1)) org_vocab_mask = mint.eq(displaced_x, truncated_x) displaced_x = mint.sub(input_, added_vocab_start_index) down_truncated_x = mint.nn.functional.relu(displaced_x) - truncated_x = mint.minimum(down_truncated_x, (added_vocab_end_index - added_vocab_start_index - 1)) + truncated_x = mint.minimum( + down_truncated_x, + (added_vocab_end_index - added_vocab_start_index - 1)) added_vocab_mask = mint.eq(displaced_x, truncated_x) added_offset = added_vocab_start_index - ( org_vocab_end_index - org_vocab_start_index) - num_org_vocab_padding @@ -103,26 +101,29 @@ def get_masked_input_and_mask( return input_, vocab_mask.expand_dims(-1) -def pad_vocab_size(vocab_size: int, pad_to: int = 
DEFAULT_VOCAB_PADDING_SIZE) -> int: +def pad_vocab_size(vocab_size: int, + pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int: """Pad the vocab size to the given value.""" return ((vocab_size + pad_to - 1) // pad_to) * pad_to def vocab_range_from_per_partition_vocab_size( - per_partition_vocab_size: int, rank: int, offset: int = 0 -) -> Sequence[int]: + per_partition_vocab_size: int, + rank: int, + offset: int = 0) -> Sequence[int]: index_f = rank * per_partition_vocab_size index_l = index_f + per_partition_vocab_size return index_f + offset, index_l + offset -def vocab_range_from_global_vocab_size( - global_vocab_size: int, rank: int, world_size: int, offset: int = 0 -) -> Sequence[int]: +def vocab_range_from_global_vocab_size(global_vocab_size: int, + rank: int, + world_size: int, + offset: int = 0) -> Sequence[int]: per_partition_vocab_size = divide(global_vocab_size, world_size) - return vocab_range_from_per_partition_vocab_size( - per_partition_vocab_size, rank, offset=offset - ) + return vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, + rank, + offset=offset) @dataclass @@ -185,6 +186,7 @@ class VocabParallelEmbeddingShardIndices: class VocabParallelEmbedding(nn.Cell): + def __init__( self, num_embeddings: int, @@ -203,12 +205,11 @@ class VocabParallelEmbedding(nn.Cell): self.padding_size = padding_size self.org_vocab_size = org_num_embeddings or num_embeddings num_added_embeddings = num_embeddings - self.org_vocab_size - self.org_vocab_size_padded = pad_vocab_size( - self.org_vocab_size, self.padding_size - ) + self.org_vocab_size_padded = pad_vocab_size(self.org_vocab_size, + self.padding_size) self.num_embeddings_padded = pad_vocab_size( - self.org_vocab_size_padded + num_added_embeddings, self.padding_size - ) + self.org_vocab_size_padded + num_added_embeddings, + self.padding_size) assert self.org_vocab_size_padded <= self.num_embeddings_padded self.shard_indices = self._get_indices( @@ -233,13 +234,11 @@ class VocabParallelEmbedding(nn.Cell): # layer type like ParallelLMHead, this is not important. is_embedding_layer = type(self) is VocabParallelEmbedding quant_method_implements_embedding = method_has_implemented_embedding( - type(quant_method) - ) + type(quant_method)) if is_embedding_layer and not quant_method_implements_embedding: raise NotImplementedError( f"The class {type(quant_method).__name__} must implement " - "the 'embedding' method, see UnquantizedEmbeddingMethod." - ) + "the 'embedding' method, see UnquantizedEmbeddingMethod.") self.quant_method: QuantizeMethodBase = quant_method @@ -247,20 +246,16 @@ class VocabParallelEmbedding(nn.Cell): params_dtype = mstype.float16 # Divide the weight matrix along the vocaburaly dimension. 
self.num_added_embeddings = self.num_embeddings - self.org_vocab_size - self.num_embeddings_per_partition = divide( - self.num_embeddings_padded, self.tp_size - ) - assert ( - self.shard_indices.num_elements_padded == self.num_embeddings_per_partition - ) + self.num_embeddings_per_partition = divide(self.num_embeddings_padded, + self.tp_size) + assert (self.shard_indices.num_elements_padded == + self.num_embeddings_per_partition) self.num_org_embeddings_per_partition = ( - self.shard_indices.org_vocab_end_index - - self.shard_indices.org_vocab_start_index - ) + self.shard_indices.org_vocab_end_index - + self.shard_indices.org_vocab_start_index) self.num_added_embeddings_per_partition = ( - self.shard_indices.added_vocab_end_index - - self.shard_indices.added_vocab_start_index - ) + self.shard_indices.added_vocab_end_index - + self.shard_indices.added_vocab_start_index) self.quant_method.create_weights( self, @@ -288,17 +283,19 @@ class VocabParallelEmbedding(nn.Cell): tp_size.""" num_added_embeddings_padded = vocab_size_padded - org_vocab_size_padded padded_org_vocab_start_index, padded_org_vocab_end_index = ( - vocab_range_from_global_vocab_size(org_vocab_size_padded, tp_rank, tp_size) - ) + vocab_range_from_global_vocab_size(org_vocab_size_padded, tp_rank, + tp_size)) padded_added_vocab_start_index, padded_added_vocab_end_index = ( - vocab_range_from_global_vocab_size( - num_added_embeddings_padded, tp_rank, tp_size, offset=org_vocab_size - ) - ) + vocab_range_from_global_vocab_size(num_added_embeddings_padded, + tp_rank, + tp_size, + offset=org_vocab_size)) # remove padding - org_vocab_start_index = min(padded_org_vocab_start_index, org_vocab_size) + org_vocab_start_index = min(padded_org_vocab_start_index, + org_vocab_size) org_vocab_end_index = min(padded_org_vocab_end_index, org_vocab_size) - added_vocab_start_index = min(padded_added_vocab_start_index, vocab_size) + added_vocab_start_index = min(padded_added_vocab_start_index, + vocab_size) added_vocab_end_index = min(padded_added_vocab_end_index, vocab_size) return VocabParallelEmbeddingShardIndices( padded_org_vocab_start_index, @@ -311,18 +308,15 @@ class VocabParallelEmbedding(nn.Cell): added_vocab_end_index, ) - @jit def construct(self, input_): if self.tp_size > 1: # Build the mask. masked_input, input_mask = get_masked_input_and_mask( - input_, - self.shard_indices.org_vocab_start_index, + input_, self.shard_indices.org_vocab_start_index, self.shard_indices.org_vocab_end_index, self.shard_indices.num_org_vocab_padding, self.shard_indices.added_vocab_start_index, - self.shard_indices.added_vocab_end_index - ) + self.shard_indices.added_vocab_end_index) else: masked_input, input_mask = input_, None # Get the embeddings. @@ -354,11 +348,13 @@ class VocabParallelEmbedding(nn.Cell): if loaded_weight.shape[output_dim] != self.org_vocab_size: raise ValueError( f"'loaded_weight.shape[output_dim]' should be equal to 'org_vocab_size'," - f" but got {loaded_weight.shape[output_dim]} and {self.org_vocab_size}") + f" but got {loaded_weight.shape[output_dim]} and {self.org_vocab_size}" + ) # Copy the data. 
- loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size).contiguous() - param[: loaded_weight.shape[0]] = loaded_weight + loaded_weight = loaded_weight.narrow(output_dim, start_idx, + shard_size).contiguous() + param[:loaded_weight.shape[0]] = loaded_weight param[loaded_weight.shape[0]:] = 0 @@ -401,8 +397,8 @@ class ParallelLMHead(VocabParallelEmbedding): self.quant_config = quant_config if bias: self.bias = Parameter( - mint.zeros(self.num_embeddings_per_partition, dtype=params_dtype) - ) + mint.zeros(self.num_embeddings_per_partition, + dtype=params_dtype)) set_weight_attrs( self.bias, { diff --git a/vllm_mindspore/model_executor/models/model_base.py b/vllm_mindspore/model_executor/models/model_base.py index 0d933a2d..c980050e 100644 --- a/vllm_mindspore/model_executor/models/model_base.py +++ b/vllm_mindspore/model_executor/models/model_base.py @@ -1,5 +1,4 @@ #!/usr/bin/env python3 -# encoding: utf-8 # Copyright 2025 Huawei Technologies Co., Ltd # Copyright 2024 The vLLM team. # @@ -18,68 +17,62 @@ import os from abc import abstractmethod -from typing import Iterable, List, Optional, Set, Tuple, Union, Dict +from typing import Dict, Iterable, Optional, Set, Tuple, Union +import torch +from mindspore import Tensor, mutable, nn +from vllm.attention.backends.abstract import AttentionType +from vllm.attention.layer import Attention from vllm.config import VllmConfig, get_current_vllm_config +from vllm.forward_context import get_forward_context from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from vllm.attention.backends.abstract import AttentionType -from vllm.forward_context import get_forward_context -from vllm.attention.layer import Attention - -import torch - -from mindspore import Tensor, nn, mutable class Fake_Attention: + def __init__(self): vllm_config = get_current_vllm_config() block_size = vllm_config.cache_config.block_size num_kv_heads = vllm_config.model_config.get_num_kv_heads( - vllm_config.parallel_config - ) + vllm_config.parallel_config) head_size = vllm_config.model_config.get_head_size() num_block = 0 self.kv_shape = [num_block, block_size, num_kv_heads, head_size] - self.kv_cache = [ - ( - torch.zeros(self.kv_shape, dtype=torch.bfloat16, device="Ascend"), - torch.zeros(self.kv_shape, dtype=torch.bfloat16, device="Ascend"), - ) - for _ in range(vllm_config.parallel_config.pipeline_parallel_size) - ] + self.kv_cache = [( + torch.zeros(self.kv_shape, dtype=torch.bfloat16, device="Ascend"), + torch.zeros(self.kv_shape, dtype=torch.bfloat16, device="Ascend"), + ) for _ in range(vllm_config.parallel_config.pipeline_parallel_size)] self.attn_type = AttentionType.DECODER class Fake_MLA(Fake_Attention): + def __init__(self): super().__init__() vllm_config = get_current_vllm_config() self.kv_cache = [ - (torch.zeros(self.kv_shape, dtype=torch.bfloat16, device="Ascend"),) + (torch.zeros(self.kv_shape, dtype=torch.bfloat16, + device="Ascend"), ) for _ in range(vllm_config.parallel_config.pipeline_parallel_size) ] class Fake_Attention_V1(Attention): + def __init__(self): vllm_config = get_current_vllm_config() block_size = vllm_config.cache_config.block_size num_kv_heads = vllm_config.model_config.get_num_kv_heads( - vllm_config.parallel_config - ) + vllm_config.parallel_config) head_size = vllm_config.model_config.get_head_size() num_block = 0 self.kv_shape = [num_block, block_size, num_kv_heads, head_size] - self.kv_cache = [ - ( - 
torch.zeros(self.kv_shape, dtype=torch.bfloat16, device="Ascend"), - torch.zeros(self.kv_shape, dtype=torch.bfloat16, device="Ascend"), - ) - for _ in range(vllm_config.parallel_config.pipeline_parallel_size) - ] + self.kv_cache = [( + torch.zeros(self.kv_shape, dtype=torch.bfloat16, device="Ascend"), + torch.zeros(self.kv_shape, dtype=torch.bfloat16, device="Ascend"), + ) for _ in range(vllm_config.parallel_config.pipeline_parallel_size)] self.attn_type = AttentionType.DECODER self.num_block = num_block self.num_kv_heads = num_kv_heads @@ -90,18 +83,21 @@ class Fake_Attention_V1(Attention): class Fake_MLA_V1(Fake_Attention_V1): + def __init__(self): super().__init__() vllm_config = get_current_vllm_config() self.kv_cache = [ - (torch.zeros(self.kv_shape, dtype=torch.bfloat16, device="Ascend"),) + (torch.zeros(self.kv_shape, dtype=torch.bfloat16, + device="Ascend"), ) for _ in range(vllm_config.parallel_config.pipeline_parallel_size) ] -class MsModelBase(): +class MsModelBase: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: - super(MsModelBase, self).__init__() + super().__init__() config = vllm_config.model_config.hf_config lora_config = vllm_config.lora_config @@ -125,7 +121,8 @@ class MsModelBase(): if os.path.isdir(model_name_or_path): return model_name_or_path else: - from vllm.model_executor.model_loader.weight_utils import download_weights_from_hf + from vllm.model_executor.model_loader.weight_utils import \ + download_weights_from_hf allow_patterns = ["*.safetensors"] revision = self.model_config.revision return download_weights_from_hf( @@ -171,15 +168,24 @@ class MsModelBase(): def named_modules(self, remove_duplicate: bool = True): self._check_modules_valid() - res_modules = set() for name, module in self.modules_dict.items(): for module_name, sub_module in module.cells_and_names(): if name != "self": module_name = name + "." 
+ module_name yield module_name, sub_module - def get_submodule(self): - raise RuntimeError("Cannot get submodule for mindspore model now!") + def get_submodule(self, target: str): + parts = target.split(".") + if target == "": + return self + for part in parts: + if not part: + raise ValueError( + f"Invalid submodule path: empty part in '{target}'") + current = self + for part in parts: + current = getattr(current, part) + return current def eval(self): self._check_modules_valid() @@ -198,23 +204,19 @@ class MsModelBase(): previous_hidden_states: Optional[Tensor] = None, spec_step_idx: int = 0, ) -> Union[Tensor, IntermediateTensors]: - return self.forward( - input_ids, - positions, - intermediate_tensors, - inputs_embeds, - previous_hidden_states=previous_hidden_states, - spec_step_idx=spec_step_idx - ) - - def forward( - self, - input_ids: Tensor, - positions: Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[Tensor] = None, - **kwargs - ) -> Union[Tensor, IntermediateTensors]: + return self.forward(input_ids, + positions, + intermediate_tensors, + inputs_embeds, + previous_hidden_states=previous_hidden_states, + spec_step_idx=spec_step_idx) + + def forward(self, + input_ids: Tensor, + positions: Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[Tensor] = None, + **kwargs) -> Union[Tensor, IntermediateTensors]: raise NotImplementedError def set_model_inputs(self, is_prefill): @@ -264,8 +266,10 @@ class MsModelBase(): value_cache = [] forward_context = get_forward_context() for i in range(self.config.num_hidden_layers): - k_cache = self.kv_caches[i].kv_cache[forward_context.virtual_engine][0] - v_cache = self.kv_caches[i].kv_cache[forward_context.virtual_engine][1] + k_cache = self.kv_caches[i].kv_cache[ + forward_context.virtual_engine][0] + v_cache = self.kv_caches[i].kv_cache[ + forward_context.virtual_engine][1] key_cache.append(k_cache) value_cache.append(v_cache) return mutable(key_cache), mutable(value_cache) @@ -276,7 +280,8 @@ class MsModelBase(): hidden_states: Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[Tensor]: - raise NotImplementedError("Function compute_logits should be Implemented!") + raise NotImplementedError( + "Function compute_logits should be Implemented!") @abstractmethod def sample( @@ -288,4 +293,5 @@ class MsModelBase(): @abstractmethod def load_weights(self, weights: Iterable[Tuple[str, Tensor]]) -> Set[str]: - raise NotImplementedError("Function load_weights should be Implemented!") + raise NotImplementedError( + "Function load_weights should be Implemented!") diff --git a/vllm_mindspore/model_executor/models/qwen2.py b/vllm_mindspore/model_executor/models/qwen2.py index 444ddc5a..16adcacf 100644 --- a/vllm_mindspore/model_executor/models/qwen2.py +++ b/vllm_mindspore/model_executor/models/qwen2.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 -# encoding: utf-8 +# type: ignore +# isort:skip_file # Copyright 2025 Huawei Technologies Co., Ltd # Copyright 2024 The vLLM team. # @@ -15,21 +16,28 @@ # See the License for the specific language governing permissions and # limitations under the License. 
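[Editor's note] The `get_submodule` implementation added to MsModelBase replaces a stub that used to raise `RuntimeError`; upstream vLLM's LoRA manager relies on it, since `replace_submodule` resolves the dotted parent path of every layer it swaps for a LoRA wrapper. A small self-contained sketch of the lookup semantics, with plain objects standing in for MindSpore cells and purely illustrative names:

    class _Node:
        """Stand-in for a nested hierarchy of cells."""

    root = _Node()
    root.model = _Node()
    root.model.layers = _Node()
    setattr(root.model.layers, "0", _Node())  # list children are reachable by name
    setattr(getattr(root.model.layers, "0"), "self_attn", _Node())

    def get_submodule(module, target: str):
        # Same semantics as the patched MsModelBase.get_submodule: an empty
        # target returns the module itself, otherwise walk the dotted path
        # one attribute at a time, rejecting empty path components.
        if target == "":
            return module
        current = module
        for part in target.split("."):
            if not part:
                raise ValueError(f"Invalid submodule path: empty part in '{target}'")
            current = getattr(current, part)
        return current

    attn = get_submodule(root, "model.layers.0.self_attn")
    assert attn is getattr(getattr(root.model.layers, "0"), "self_attn")
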
# ============================================================================ -from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union, Iterable +from typing import (TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple, + Union) if TYPE_CHECKING: from transformers import Qwen2Config else: Qwen2Config = None +import mindspore as ms import numpy as np - -from mindspore import Parameter, Tensor, mint, nn, jit, ops, mutable +import vllm.envs as envs +from mindspore import Parameter, Tensor, mint, mutable, nn, ops from mindspore.common import dtype as mstype - +from vllm.attention.backends.abstract import AttentionType +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.forward_context import get_forward_context +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.models.interfaces import SupportsLoRA +from vllm.sequence import IntermediateTensors from vllm_mindspore.attention import Attention - from vllm_mindspore.model_executor.layers.activation import SwiGLU from vllm_mindspore.model_executor.layers.layernorm import RMSNorm from vllm_mindspore.model_executor.layers.linear import ( @@ -43,27 +51,22 @@ from vllm_mindspore.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm_mindspore.model_executor.model_loader.weight_utils import \ default_weight_loader +from vllm_mindspore.model_executor.models.attention_mask import \ + LowerTriangularMask +from vllm_mindspore.model_executor.models.model_base import (Fake_Attention, + Fake_Attention_V1, + MsModelBase) from vllm_mindspore.model_executor.models.utils import ( - PPMissingLayer, make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) + PPMissingLayer, _jit, make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix, set_enforce_eager) from vllm_mindspore.model_executor.sampling_metadata import SamplingMetadata -from vllm_mindspore.model_executor.models.model_base import MsModelBase, Fake_Attention, Fake_Attention_V1 -from vllm_mindspore.model_executor.models.attention_mask import LowerTriangularMask from vllm_mindspore.utils import STR_DTYPE_TO_MS_DTYPE +from vllm_mindspore.v1.attention.backends.flash_attn import \ + FlashAttentionMetadata -from vllm.config import CacheConfig, VllmConfig -import vllm.envs as envs -from vllm.model_executor.layers.quantization import \ - QuantizationConfig -from vllm.sequence import IntermediateTensors -from vllm.attention.backends.abstract import AttentionType -from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size -from vllm.forward_context import get_forward_context -from vllm_mindspore.v1.attention.backends.flash_attn import FlashAttentionMetadata -import mindspore as ms - class Qwen2MLP(nn.Cell): + def __init__( self, hidden_size: int, @@ -80,22 +83,18 @@ class Qwen2MLP(nn.Cell): bias=bias, quant_config=quant_config, prefix=f"{prefix}.gate_up_proj", - params_dtype=mstype.bfloat16 - ) - self.down_proj = RowParallelLinear( - input_size=intermediate_size, - output_size=hidden_size, - bias=bias, - quant_config=quant_config, - prefix=f"{prefix}.down_proj", - params_dtype=mstype.bfloat16 - ) + params_dtype=mstype.bfloat16) + self.down_proj = RowParallelLinear(input_size=intermediate_size, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + params_dtype=mstype.bfloat16) if hidden_act != "silu": raise 
ValueError(f"Unsupported activation: {hidden_act}. " "Only silu is supported for now.") self.act_fn = SwiGLU() - @jit def construct(self, x): x, _ = self.gate_up_proj(x) x = self.act_fn(x) @@ -104,19 +103,18 @@ class Qwen2MLP(nn.Cell): class Qwen2Attention(nn.Cell): - def __init__( - self, - hidden_size: int, - num_heads: int, - num_kv_heads: int, - max_position: int = 4096 * 32, - rope_theta: float = 10000, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - rope_scaling: Optional[Tuple] = None, - prefix: str = "", - attn_type: str = AttentionType.DECODER - ) -> None: + + def __init__(self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + max_position: int = 4096 * 32, + rope_theta: float = 10000, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + rope_scaling: Optional[Tuple] = None, + prefix: str = "", + attn_type: str = AttentionType.DECODER) -> None: super().__init__() self.hidden_size = hidden_size tp_size = get_tensor_model_parallel_world_size() @@ -166,18 +164,15 @@ class Qwen2Attention(nn.Cell): rope_scaling=rope_scaling, dtype=mstype.bfloat16, ) - self.attn = Attention( - self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn", - attn_type=attn_type - ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + attn_type=attn_type) - @jit def construct( self, positions: Tensor, @@ -192,10 +187,12 @@ class Qwen2Attention(nn.Cell): block_tables: Tensor, ) -> Tensor: qkv, _ = self.qkv_proj(hidden_states) - q, k, v = mint.split(qkv, (self.q_size, self.kv_size, self.kv_size), -1) + q, k, v = mint.split(qkv, (self.q_size, self.kv_size, self.kv_size), + -1) q, k = self.rotary_emb(positions, q, k, batch_valid_length, is_prefill) - attn_output = self.attn(q, k, v, key_cache, value_cache, is_prefill, slot_mapping, attn_mask, - batch_valid_length, q_seq_lens, block_tables) + attn_output = self.attn(q, k, v, key_cache, value_cache, is_prefill, + slot_mapping, attn_mask, batch_valid_length, + q_seq_lens, block_tables) output, _ = self.o_proj(attn_output) return output @@ -243,14 +240,17 @@ class Qwen2DecoderLayer(nn.Cell): quant_config=quant_config, prefix=f"{prefix}.mlp", ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps, - params_dtype=mstype.bfloat16,) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps, - params_dtype=mstype.bfloat16,) - - @jit + self.input_layernorm = RMSNorm( + config.hidden_size, + eps=config.rms_norm_eps, + params_dtype=mstype.bfloat16, + ) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, + eps=config.rms_norm_eps, + params_dtype=mstype.bfloat16, + ) + def construct( self, positions: Tensor, @@ -270,22 +270,16 @@ class Qwen2DecoderLayer(nn.Cell): residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm(hidden_states, residual) - hidden_states = self.self_attn( - positions, - hidden_states, - key_cache, - value_cache, - is_prefill, - slot_mapping, - attn_mask, - batch_valid_length, - q_seq_lens, - block_tables - ) + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn(positions, hidden_states, key_cache, + 
value_cache, is_prefill, slot_mapping, + attn_mask, batch_valid_length, + q_seq_lens, block_tables) # Fully Connected - hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) hidden_states = self.mlp(hidden_states) return hidden_states, residual @@ -302,6 +296,9 @@ class Qwen2Model(nn.Cell): self.config = config self.quant_config = quant_config self.vocab_size = config.vocab_size + if vllm_config.lora_config is not None: + vllm_config.model_config.enforce_eager = True + set_enforce_eager(vllm_config.model_config.enforce_eager) if get_pp_group().is_first_rank or (config.tie_word_embeddings and get_pp_group().is_last_rank): @@ -328,15 +325,18 @@ class Qwen2Model(nn.Cell): make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size)) if get_pp_group().is_last_rank: - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, - params_dtype=mstype.bfloat16,) + self.norm = RMSNorm( + config.hidden_size, + eps=config.rms_norm_eps, + params_dtype=mstype.bfloat16, + ) else: self.norm = PPMissingLayer() def get_input_embeddings(self, input_ids: Tensor) -> Tensor: return self.embed_tokens(input_ids) - @jit + @_jit def construct( self, input_ids: Optional[Tensor], @@ -364,19 +364,12 @@ class Qwen2Model(nn.Cell): for i in range(self.start_layer, self.end_layer): # PP 并行对层进行切分 layer = self.layers[i] - hidden_states, residual = layer( - positions, - hidden_states, - key_caches[i - self.start_layer], - value_caches[i - self.start_layer], - is_prefill, - slot_mapping, - attn_mask, - batch_valid_length, - q_seq_lens, - block_tables, - residual - ) + hidden_states, residual = layer(positions, hidden_states, + key_caches[i - self.start_layer], + value_caches[i - self.start_layer], + is_prefill, slot_mapping, + attn_mask, batch_valid_length, + q_seq_lens, block_tables, residual) if not get_pp_group().is_last_rank: return IntermediateTensors({ "hidden_states": hidden_states, @@ -385,7 +378,8 @@ class Qwen2Model(nn.Cell): hidden_states, _ = self.norm(hidden_states, residual) return hidden_states - def load_weights(self, weights: Iterable[Tuple[str, Tensor]], params_dict: Dict[str, Parameter]): + def load_weights(self, weights: Iterable[Tuple[str, Tensor]], + params_dict: Dict[str, Parameter]): loaded_params: Set[str] = set() stacked_params_mapping = [ # (param_name, shard_name, shard_id) @@ -405,7 +399,7 @@ class Qwen2Model(nn.Cell): # the checkpoint. Skip them. 
continue if (self.quant_config is not None and - (scale_name := self.quant_config.get_cache_scale(name))): + (scale_name := self.quant_config.get_cache_scale(name))): # Loading kv cache quantization scales param = params_dict[scale_name] weight_loader = getattr(param, "weight_loader", @@ -436,7 +430,7 @@ class Qwen2Model(nn.Cell): return loaded_params -class Qwen2ForCausalLM(MsModelBase): +class Qwen2ForCausalLM(MsModelBase, SupportsLoRA): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -480,7 +474,8 @@ class Qwen2ForCausalLM(MsModelBase): config.hidden_size, params_dtype=mstype.bfloat16, quant_config=quant_config, - prefix=maybe_prefix(prefix, "lm_head")) + prefix=maybe_prefix( + prefix, "lm_head")) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = get_sampler() else: @@ -491,20 +486,26 @@ class Qwen2ForCausalLM(MsModelBase): self.set_modules({"model": self.model, "lm_head": self.lm_head}) self.prefill = True - self.mstype = STR_DTYPE_TO_MS_DTYPE.get(self.model_config.dtype, self.model_config.dtype) - self.casual_mask = LowerTriangularMask(dtype=self.mstype, - max_model_len=self.model_config.max_model_len) + self.mstype = STR_DTYPE_TO_MS_DTYPE.get(self.model_config.dtype, + self.model_config.dtype) + self.casual_mask = LowerTriangularMask( + dtype=self.mstype, max_model_len=self.model_config.max_model_len) self.set_model_inputs(self.prefill) if envs.VLLM_USE_V1: - self.kv_caches = [Fake_Attention_V1() for i in range(config.num_hidden_layers)] + self.kv_caches = [ + Fake_Attention_V1() for i in range(config.num_hidden_layers) + ] else: - self.kv_caches = [Fake_Attention() for i in range(config.num_hidden_layers)] + self.kv_caches = [ + Fake_Attention() for i in range(config.num_hidden_layers) + ] compilation_config = vllm_config.compilation_config if prefix in compilation_config.static_forward_context: raise ValueError(f"Duplicate layer name: {prefix}") for i in range(config.num_hidden_layers): - compilation_config.static_forward_context[str(i)] = self.kv_caches[i] + compilation_config.static_forward_context[str( + i)] = self.kv_caches[i] def set_model_inputs(self, is_prefill): dyn_input_ids = Tensor(shape=[None, None], dtype=mstype.int64) @@ -525,43 +526,40 @@ class Qwen2ForCausalLM(MsModelBase): dyn_key_cache = Tensor(shape=kv_cache_shape, dtype=kv_cache_dtype) dyn_value_cache = Tensor(shape=kv_cache_shape, dtype=kv_cache_dtype) dyn_key_caches = mutable([dyn_key_cache for _ in range(num_layers)]) - dyn_value_caches = mutable([dyn_value_cache for _ in range(num_layers)]) + dyn_value_caches = mutable( + [dyn_value_cache for _ in range(num_layers)]) - dyn_slot_mapping = Tensor(shape=[None, ], dtype=mstype.int32) + dyn_slot_mapping = Tensor(shape=[ + None, + ], dtype=mstype.int32) dynamic_attention_mask = Tensor(shape=[None, None], dtype=self.mstype) - dyn_batch_valid_length = Tensor(shape=[None,], dtype=mstype.int32) - dyn_q_seq_lens = Tensor(shape=[None, ], dtype=mstype.int32) + dyn_batch_valid_length = Tensor(shape=[ + None, + ], dtype=mstype.int32) + dyn_q_seq_lens = Tensor(shape=[ + None, + ], dtype=mstype.int32) dyn_block_tables = Tensor(shape=[None, None], dtype=mstype.int32) dyn_intermediate_tensors = None dyn_inputs_embeds = None - self.model.set_inputs( - dyn_input_ids, - dyn_position_ids, - dyn_key_caches, - dyn_value_caches, - is_prefill, - dyn_slot_mapping, - dynamic_attention_mask, - dyn_batch_valid_length, - dyn_q_seq_lens, - dyn_block_tables, - dyn_intermediate_tensors, - dyn_inputs_embeds - ) - - def forward( - self, - input_ids: Tensor, - 
positions: Tensor, - intermediate_tensors: IntermediateTensors = None, - inputs_embeds: Tensor = None, - **kwargs - ) -> Union[Tensor, IntermediateTensors]: + self.model.set_inputs(dyn_input_ids, dyn_position_ids, dyn_key_caches, + dyn_value_caches, is_prefill, dyn_slot_mapping, + dynamic_attention_mask, dyn_batch_valid_length, + dyn_q_seq_lens, dyn_block_tables, + dyn_intermediate_tensors, dyn_inputs_embeds) + + def forward(self, + input_ids: Tensor, + positions: Tensor, + intermediate_tensors: IntermediateTensors = None, + inputs_embeds: Tensor = None, + **kwargs) -> Union[Tensor, IntermediateTensors]: key_cache, value_cache = self.get_kvcache() attn_metadata = get_forward_context().attn_metadata input_ids = input_ids.to(ms.int64) if attn_metadata is None: - attn_metadata = self._dummy_attention_metadata(input_ids, positions) + attn_metadata = self._dummy_attention_metadata( + input_ids, positions) if not envs.VLLM_USE_V1: seq_lens = attn_metadata.seq_lens max_query_len = attn_metadata.max_query_len @@ -576,24 +574,25 @@ class Qwen2ForCausalLM(MsModelBase): seq_lens_np = np.array(seq_lens, dtype=np.int32) query_lens_np = np.array(query_lens, dtype=np.int32) kv_cache_lens = seq_lens_np - query_lens_np - is_prefill = attn_metadata.num_decode_tokens == 0 and kv_cache_lens.max() == 0 + is_prefill = attn_metadata.num_decode_tokens == 0 and kv_cache_lens.max( + ) == 0 slot_mapping = attn_metadata.slot_mapping - batch_valid_length = Tensor.from_numpy(np.array(attn_metadata.seq_lens, dtype=np.int32)) + batch_valid_length = Tensor.from_numpy( + np.array(attn_metadata.seq_lens, dtype=np.int32)) q_seq_lens = ms.Tensor(query_lens_np, dtype=ms.int32) block_tables = attn_metadata.block_tables position_ids = ms.Tensor(positions, dtype=ms.int32) - attn_mask = self.casual_mask.gen_attention_mask(is_prefill, position_ids, query_lens) + attn_mask = self.casual_mask.gen_attention_mask( + is_prefill, position_ids, query_lens) else: - if attn_metadata.max_context_lens == 0: - is_prefill = True - else: - is_prefill = False + is_prefill = attn_metadata.max_context_lens == 0 slot_mapping = attn_metadata.slot_mapping batch_valid_length = Tensor.from_numpy(attn_metadata.seq_lens_np) q_seq_lens = attn_metadata.q_seq_lens block_tables = attn_metadata.block_tables query_lens_np = attn_metadata.q_seq_lens_np - attn_mask = self.casual_mask.gen_attention_mask(is_prefill, positions, query_lens_np) + attn_mask = self.casual_mask.gen_attention_mask( + is_prefill, positions, query_lens_np) positions = positions.to(ms.int64) if is_prefill: input_ids = ops.expand_dims(input_ids, 0) @@ -623,7 +622,8 @@ class Qwen2ForCausalLM(MsModelBase): model_output = ops.squeeze(model_output, 1) return model_output - def _dummy_attention_metadata(self, input_ids: Tensor, positions: Tensor) -> FlashAttentionMetadata: + def _dummy_attention_metadata(self, input_ids: Tensor, + positions: Tensor) -> FlashAttentionMetadata: input_len = input_ids.shape[0] max_seq_len = ms.Tensor(input_len, dtype=ms.int32) seq_lengths = ms.Tensor([input_len], dtype=ms.int32) @@ -646,16 +646,14 @@ class Qwen2ForCausalLM(MsModelBase): # To enforce prefill and decode are both complied in warmup process. # So set max_context_lens to 0 for prefill and 1 for decode. 
max_context_lens=0 if self.prefill else 1, - query_start_loc = None - ) + query_start_loc=None) def load_weights(self, weights: Iterable[Tuple[str, Tensor]]) -> Set[str]: params_dict = self.get_params_dict() self.model.load_weights(weights, params_dict) - def sample( - self, logits: Tensor, sampling_metadata: SamplingMetadata - ) -> Optional[SamplerOutput]: + def sample(self, logits: Tensor, + sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]: next_tokens = self.sampler(logits, sampling_metadata) return next_tokens @@ -664,5 +662,6 @@ class Qwen2ForCausalLM(MsModelBase): hidden_states: Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) return logits diff --git a/vllm_mindspore/model_executor/models/utils.py b/vllm_mindspore/model_executor/models/utils.py index 1acd616d..252ede54 100644 --- a/vllm_mindspore/model_executor/models/utils.py +++ b/vllm_mindspore/model_executor/models/utils.py @@ -1,5 +1,4 @@ #!/usr/bin/env python3 -# encoding: utf-8 # Copyright 2025 Huawei Technologies Co., Ltd # Copyright 2024 The vLLM team. # @@ -18,9 +17,11 @@ from dataclasses import dataclass, field from typing import List, Tuple, Union, Mapping, Optional, Iterable +from functools import wraps +from typing import List, Tuple import mindspore as ms -from mindspore import mint +from mindspore import jit, mint from mindspore import ops from vllm.sequence import IntermediateTensors @@ -70,6 +71,7 @@ class WeightsMapper: ) -> Iterable[Tuple[str, ms.Tensor]]: return ((out_name, data) for name, data in weights if (out_name := self._map_name(name)) is not None) +enforce_eager = False class PPMissingLayer(ms.nn.Cell): """ @@ -117,9 +119,8 @@ def extract_layer_index(layer_name: str) -> int: int_vals.append(int(subname)) except ValueError: continue - assert len(int_vals) == 1, ( - f"layer name {layer_name} should" " only contain one integer" - ) + assert len(int_vals) == 1, (f"layer name {layer_name} should" + " only contain one integer") return int_vals[0] @@ -134,17 +135,13 @@ def make_layers( from vllm.distributed.parallel_state import get_pp_group from vllm.distributed.utils import get_pp_indices - start_layer, end_layer = get_pp_indices( - num_hidden_layers, get_pp_group().rank_in_group, get_pp_group().world_size - ) - modules = ms.nn.CellList( - [PPMissingLayer() for _ in range(start_layer)] - + [ - maybe_offload_to_cpu(layer_fn(prefix=f"{prefix}.{idx}")) - for idx in range(start_layer, end_layer) - ] - + [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)] - ) + start_layer, end_layer = get_pp_indices(num_hidden_layers, + get_pp_group().rank_in_group, + get_pp_group().world_size) + modules = ms.nn.CellList([PPMissingLayer() for _ in range(start_layer)] + [ + maybe_offload_to_cpu(layer_fn(prefix=f"{prefix}.{idx}")) + for idx in range(start_layer, end_layer) + ] + [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)]) return start_layer, end_layer, modules @@ -156,9 +153,10 @@ def make_empty_intermediate_tensors_factory(keys: List[str], hidden_size: int): device, ) -> IntermediateTensors: dtype = get_valid_dtype(dtype) - return IntermediateTensors( - {key: mint.zeros((batch_size, hidden_size), dtype=dtype) for key in keys} - ) + return IntermediateTensors({ + key: mint.zeros((batch_size, hidden_size), dtype=dtype) + for key in keys + }) return make_empty_intermediate_tensors @@ -262,4 +260,28 @@ def 
merge_multimodal_embeddings( inputs_embeds, (input_ids == placeholder_token_id), multimodal_embeddings, - ) \ No newline at end of file + ) +def set_enforce_eager(value): + """ + set global variable enforce_eager to value. + """ + global enforce_eager + enforce_eager = value + + +def _jit(func): + """ + A decorator to apply JIT compilation to a function or method. + """ + + @wraps(func) + def wrapper(*args, **kwargs): + if enforce_eager: + # If enforce_eager is True, we do not apply JIT compilation. + return func(*args, **kwargs) + if hasattr(func, "__wrapped_by_jit__"): + # If the function is already wrapped by JIT, we call it directly. + return func(*args, **kwargs) + return jit(func, jit_level="O0", infer_boost="on")(*args, **kwargs) + + return wrapper diff --git a/vllm_mindspore/platforms/ascend.py b/vllm_mindspore/platforms/ascend.py index 356a33a0..89e22829 100644 --- a/vllm_mindspore/platforms/ascend.py +++ b/vllm_mindspore/platforms/ascend.py @@ -1,5 +1,4 @@ #!/usr/bin/env python3 -# encoding: utf-8 # Copyright 2025 Huawei Technologies Co., Ltd # Copyright 2024 The vLLM team. # @@ -17,15 +16,12 @@ # ============================================================================ """Ascend platform.""" -import os -from typing import (TYPE_CHECKING, Optional, Union, Tuple) +from typing import TYPE_CHECKING, Optional, Tuple, Union import torch -import mindspore as ms - -from vllm.platforms.interface import DeviceCapability, Platform, PlatformEnum, _Backend -from vllm.logger import init_logger import vllm.envs as envs +from vllm.logger import init_logger +from vllm.platforms.interface import Platform, PlatformEnum, _Backend if TYPE_CHECKING: from vllm.config import ModelConfig, VllmConfig @@ -40,7 +36,7 @@ class AscendPlatform(Platform): _enum = PlatformEnum.OOT device_name: str = "npu" - device_type: str = "cuda" # To use cuda worker, executor... + device_type: str = "cuda" # To use cuda worker, executor... simple_compile_backend: str = "npu" ray_device_key: str = "NPU" device_control_env_var: str = "ASCEND_RT_VISIBLE_DEVICES" @@ -103,7 +99,8 @@ class AscendPlatform(Platform): model_config.disable_cascade_attn = True @classmethod - def get_attn_backend_cls(cls, selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1, use_mla): + def get_attn_backend_cls(cls, selected_backend, head_size, dtype, + kv_cache_dtype, block_size, use_v1, use_mla): """Get the attention backend class of a device.""" if use_v1: if use_mla: @@ -119,12 +116,13 @@ class AscendPlatform(Platform): return "vllm_mindspore.attention.backends.ms_attn.MsAttentionBackend" raise ValueError( - "Invaild attention backend %s for vLLM-MindSpore with head_size: %s, dtype: %s, kv_cache_dtype: %s, block_size: %s." - % (str(selected_backend), str(head_size), str(dtype), str(kv_cache_dtype), str(block_size)) + f"Invalid attention backend {str(selected_backend)} for vLLM-MindSpore with head_size: {str(head_size)}, dtype: {str(dtype)}, kv_cache_dtype: {str(kv_cache_dtype)}, block_size: {str(block_size)}." 
        )

     @classmethod
-    def get_current_memory_usage(cls, device: Optional[torch.types.Device] = None) -> float:
+    def get_current_memory_usage(cls,
+                                 device: Optional[torch.types.Device] = None
+                                 ) -> float:
         """Return the memory usage in bytes."""
         torch.cuda.reset_peak_memory_stats()
         return torch.cuda.max_memory_allocated(device)
@@ -144,4 +142,8 @@ class AscendPlatform(Platform):
 
     @classmethod
     def supports_v1(cls, model_config: ModelConfig) -> bool:
-        return True
\ No newline at end of file
+        return True
+
+    @classmethod
+    def get_punica_wrapper(cls) -> str:
+        return "vllm_mindspore.lora.punica_wrapper.punica_npu.PunicaWrapperNPU"
-- 
Gitee
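A minimal sketch of how the conditional-JIT helper added to vllm_mindspore/model_executor/models/utils.py behaves. Only _jit and set_enforce_eager come from this patch; the toy add() function is hypothetical and exists purely to show that the decision between graph compilation and eager execution is made per call, driven by the global flag that Qwen2Model flips when a LoRA config is present.

import mindspore as ms
from mindspore import Tensor

from vllm_mindspore.model_executor.models.utils import _jit, set_enforce_eager


@_jit
def add(x: Tensor, y: Tensor) -> Tensor:
    # The decorator wraps the function at definition time, but whether the
    # call goes through mindspore.jit is decided at call time.
    return x + y


a = Tensor([1.0], ms.float32)
b = Tensor([2.0], ms.float32)

set_enforce_eager(False)   # add() runs through jit(jit_level="O0", infer_boost="on")
print(add(a, b))

set_enforce_eager(True)    # mirrors what Qwen2Model does when lora_config is set
print(add(a, b))           # now executes eagerly, with no graph compilation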
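AscendPlatform.get_punica_wrapper only hands vLLM a dotted class path; the framework later resolves that string into the PunicaWrapperNPU class when it sets up LoRA. A minimal sketch of that resolution step, using a generic importlib-based lookup as a stand-in for vLLM's own resolver:

import importlib


def resolve_by_qualname(qualname: str):
    """Split 'pkg.module.ClassName' and import the named attribute."""
    module_name, _, attr_name = qualname.rpartition(".")
    module = importlib.import_module(module_name)
    return getattr(module, attr_name)


path = "vllm_mindspore.lora.punica_wrapper.punica_npu.PunicaWrapperNPU"
punica_cls = resolve_by_qualname(path)   # -> the PunicaWrapperNPU class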
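The reworked set_model_inputs relies on MindSpore dynamic-shape placeholders (Tensor(shape=[None, ...]) plus Cell.set_inputs) so that one compiled graph can serve varying batch and sequence lengths. A minimal sketch of the same mechanism on a hypothetical MatMulCell, under the assumption that graph mode is enabled:

import numpy as np
import mindspore as ms
from mindspore import nn, ops, Tensor

ms.set_context(mode=ms.GRAPH_MODE)


class MatMulCell(nn.Cell):

    def construct(self, x, w):
        return ops.matmul(x, w)


net = MatMulCell()
# A None dimension marks the axis as dynamic, as in set_model_inputs.
dyn_x = Tensor(shape=[None, 16], dtype=ms.float32)
dyn_w = Tensor(shape=[16, 16], dtype=ms.float32)
net.set_inputs(dyn_x, dyn_w)

# Different batch sizes reuse the same compiled graph.
for batch in (2, 8):
    x = Tensor(np.ones((batch, 16), np.float32))
    w = Tensor(np.ones((16, 16), np.float32))
    print(net(x, w).shape)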