diff --git a/tests/st/python/multilora_inference.py b/tests/st/python/multilora_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..7dd07c8e88fb102eb03db1a73f41a3b1124df500 --- /dev/null +++ b/tests/st/python/multilora_inference.py @@ -0,0 +1,113 @@ +#!/usr/bin/env python3 +# encoding: utf-8 +# Copyright 2025 Huawei Technologies Co., Ltd +# Copyright 2024 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +""" +This example shows how to use the multi-LoRA functionality +for offline inference. + +""" +import vllm_mindspore +from typing import List, Optional, Tuple + + +from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams +from vllm.lora.request import LoRARequest + + +def create_test_prompts( + lora_path: str +) -> List[Tuple[str, SamplingParams, Optional[LoRARequest]]]: + """Create a list of test prompts with their sampling parameters. + + 2 requests for base model, 4 requests for the LoRA. We define 2 + different LoRA adapters (using the same model for demo purposes). + Since we also set `max_loras=1`, the expectation is that the requests + with the second LoRA adapter will be ran after all requests with the + first adapter have finished. + """ + return [ + ("A robot may not injure a human being", + SamplingParams(temperature=0.0, + max_tokens=128), None), + ("To be or not to be,", + SamplingParams(temperature=0.8, + top_k=5, + max_tokens=128), None), + ( + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501 + SamplingParams(temperature=0.0, + max_tokens=128), + LoRARequest("sql-lora", 1, lora_path)), + ( + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501 + SamplingParams(temperature=0.0, + max_tokens=128), + LoRARequest("sql-lora2", 2, lora_path)), + ] + + +def process_requests(engine: LLMEngine, + test_prompts: List[Tuple[str, SamplingParams, + Optional[LoRARequest]]]): + """Continuously process a list of prompts and handle the outputs.""" + request_id = 0 + + while test_prompts or engine.has_unfinished_requests(): + if test_prompts: + prompt, sampling_params, lora_request = test_prompts.pop(0) + engine.add_request(str(request_id), + prompt, + sampling_params, + lora_request=lora_request) + request_id += 1 + + request_outputs: List[RequestOutput] = engine.step() + + for request_output in request_outputs: + if request_output.finished: + print(request_output) + + +def initialize_engine() -> LLMEngine: + """Initialize the LLMEngine.""" + # max_loras: controls the number of LoRAs that can be used in the same + # batch. 
Larger numbers will cause higher memory usage, as each LoRA + # slot requires its own preallocated tensor. + # max_lora_rank: controls the maximum supported rank of all LoRAs. Larger + # numbers will cause higher memory usage. If you know that all LoRAs will + # use the same rank, it is recommended to set this as low as possible. + # max_cpu_loras: controls the size of the CPU LoRA cache. + engine_args = EngineArgs(model="/home/ckpt/qwen2.5-7b-hf", + enable_lora=True, + max_loras=1, + max_lora_rank=64, + max_cpu_loras=2, + max_num_seqs=256) + return LLMEngine.from_engine_args(engine_args) + + +def main(): + """Main function that sets up and runs the prompt processing.""" + engine = initialize_engine() + lora_path = "/home/ckpt/qwen2.5-7b-lora-law" + test_prompts = create_test_prompts(lora_path) + process_requests(engine, test_prompts) + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/vllm_mindspore/__init__.py b/vllm_mindspore/__init__.py index 47a9e4d58c3a922c2f16ff444ffb1a2619696627..dd307c4d590320e81450f066afdc3c5d6ef767e9 100644 --- a/vllm_mindspore/__init__.py +++ b/vllm_mindspore/__init__.py @@ -62,6 +62,18 @@ vllm.utils.cuda_device_count_stateless = ascend_device_count_stateless vllm.utils.cuda_is_initialized = ascend_is_initialized vllm.config.cuda_device_count_stateless = ascend_device_count_stateless +import vllm.lora.utils + +from vllm_mindspore.lora.utils import from_layer +vllm.lora.utils.from_layer = from_layer + +import vllm.lora.models +from vllm_mindspore.lora.models import register_module, from_local_checkpoint, from_lora_tensors + +vllm.lora.models.LoRAModelManager.register_module = register_module +vllm.lora.models.LoRAModel.from_local_checkpoint = from_local_checkpoint +vllm.lora.models.LoRAModel.from_lora_tensors =from_lora_tensors + import vllm.executor vllm.executor.cuda_device_count_stateless = ascend_device_count_stateless diff --git a/vllm_mindspore/lora/__init__.py b/vllm_mindspore/lora/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/vllm_mindspore/lora/fully_sharded_layers.py b/vllm_mindspore/lora/fully_sharded_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..cf93162099bb3f28caee9fd42ffd7cdaf8bff8a8 --- /dev/null +++ b/vllm_mindspore/lora/fully_sharded_layers.py @@ -0,0 +1,352 @@ +#!/usr/bin/env python3 +# encoding: utf-8 +# Copyright 2025 Huawei Technologies Co., Ltd +# Copyright 2024 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + + +# pylint: disable=unused-argument +from typing import TYPE_CHECKING, List, Optional, Tuple, Union, cast + +import torch +import torch.nn as nn +from transformers import PretrainedConfig + +from vllm.config import LoRAConfig +from vllm.distributed.communication_op import ( + tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce) +from vllm.distributed.parallel_state import get_tensor_model_parallel_rank +from vllm_mindspore.lora.layers import (ColumnParallelLinearWithLoRA, + MergedColumnParallelLinearWithLoRA, + MergedQKVParallelLinearWithLora, + QKVParallelLinearWithLora, + RowParallelLinearWithLoRA) + +if TYPE_CHECKING: + pass + + +def _fully_sharded_can_replace(can_replace): + """ + decorator which adds the condition of fully sharded loras + intended to wrap can_replace_layer() + """ + + def dec(*args, **kwargs): + return (can_replace(*args, **kwargs) + and kwargs["lora_config"].fully_sharded_loras) + + return dec + + +def _mcp_apply(x, bias, layer: ColumnParallelLinearWithLoRA): + """ + For `ColumnParallelLinearWithLoRA` or classes that inherit from + `ColumnParallelLinearWithLoRA`, they share the same `apply` logic. + """ + assert (layer.n_slices == len(layer.lora_a_stacked) == len( + layer.lora_b_stacked) == len(layer.output_slices)) + if layer.lora_bias_stacked is not None: + assert layer.n_slices == len(layer.lora_bias_stacked) + + output = layer.base_layer.quant_method.apply(layer.base_layer, x, bias) + + x = x.view(-1, x.shape[-1]) + output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape + + # Since communication is needed, the buffer is directly initialized as a + # tensor rather than a tuple of tensor. + buffers = torch.zeros( + (layer.n_slices, x.shape[0], layer.lora_a_stacked[0].shape[2]), + dtype=torch.float32, + device=x.device, + ) + + layer.punica_wrapper.add_shrink(buffers, x, layer.lora_a_stacked, 1.0) + buffers = tensor_model_parallel_all_gather(buffers) + layer.punica_wrapper.add_expand(output, + buffers, + layer.lora_b_stacked, + layer.lora_bias_stacked, + layer.output_slices, + offset_start=0, + add_input=True) + + output = output.view(*out_orig_shape) + # now have column partitioned and packed output + return output + + +# these layers are based on the tensor parallelism strategy given in +# Y. Sheng et al., S-LoRA: Serving Thousands of Concurrent LoRA Adapters. 2023, +# https://arxiv.org/abs/2311.03285. + + +class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA): + """ + Differs from ColumnParallelLinearWithLoRA by slicing LoRA A also. + + Based on S-LoRA, slicing happens along the rank dim. + """ + + # For all LoRA layers where the `base_layer` is `ColumnParallelLinear`, + # their `lora_a` and `lora_b` have different sharding patterns. After + # completing the `lora_a` GEMM , a gather operation is performed. + # Therefore, the sharding of `lora_a` only needs to correspond with the + # gather operation. 
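+    # Illustrative sketch (hypothetical sizes, not from a real config): with
+    # max_lora_rank=64, fully_sharded_loras=True and tp_size=4, each rank's
+    # lora_a_stacked[0].shape[2] == 16, so rank 0 keeps lora_a[:, 0:16],
+    # rank 1 keeps lora_a[:, 16:32], and so on; the all_gather in _mcp_apply
+    # then reassembles the full rank-64 intermediate before the lora_b
+    # expand is applied to the column-sharded output.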
+ def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: + tp_rank = get_tensor_model_parallel_rank() + shard_size = self.lora_a_stacked[0].shape[2] + start_idx = tp_rank * shard_size + lora_a = lora_a[:, start_idx:start_idx + shard_size] + return lora_a + + def apply(self, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + return _mcp_apply(x, bias, self) + + @classmethod + @_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: + # specifying kwargs so they can be easily accessed in decorator + return super().can_replace_layer( + source_layer=source_layer, + lora_config=lora_config, + packed_modules_list=packed_modules_list, + model_config=model_config, + decorate=False, + ) + + +class MergedColumnParallelLinearWithShardedLoRA( + MergedColumnParallelLinearWithLoRA): + """ + Differs from MergedColumnParallelLinearWithLoRA by slicing the + LoRA A's also. + + Based on S-LoRA, slicing happens along the rank dim. + """ + + def slice_lora_a( + self, lora_a: List[Union[torch.Tensor, None]] + ) -> List[Union[torch.Tensor, None]]: + #NOTE: lora_a contains 2 subloras, and each sublora could be None. + output_shard_size = self.lora_a_stacked[0].shape[2] + output_start_idx = self.tp_rank * output_shard_size + lora_a = [ + lora_a[0][:, output_start_idx:output_start_idx + + output_shard_size] if lora_a[0] is not None else None, + lora_a[1][:, output_start_idx:output_start_idx + + output_shard_size] if lora_a[1] is not None else None, + ] + return lora_a + + def apply(self, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + return _mcp_apply(x, bias, self) + + @classmethod + @_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: + # specifying kwargs so they can be easily accessed in decorator + return super().can_replace_layer( + source_layer=source_layer, + lora_config=lora_config, + packed_modules_list=packed_modules_list, + model_config=model_config, + decorate=False, + ) + + +class QKVParallelLinearWithShardedLora(QKVParallelLinearWithLora): + """ + Differs from QKVParallelLinearWithLora by slicing the + LoRA A's also. + + Based on S-LoRA, slicing happens along the rank dim. + """ + + def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: + tp_rank = get_tensor_model_parallel_rank() + shard_size = self.lora_a_stacked[0].shape[2] + start_idx = tp_rank * shard_size + lora_a = lora_a[:, start_idx:start_idx + shard_size] + return lora_a + + def apply(self, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + return _mcp_apply(x, bias, self) + + @classmethod + @_fully_sharded_can_replace + def can_replace_layer(cls, source_layer: nn.Module, + lora_config: LoRAConfig, packed_modules_list: List, + model_config: Optional[PretrainedConfig]) -> bool: + # specifying kwargs so they can be easily accessed in decorator + return super().can_replace_layer( + source_layer=source_layer, + lora_config=lora_config, + packed_modules_list=packed_modules_list, + model_config=model_config, + decorate=False, + ) + + +class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora): + """ + Differs from MergedQKVParallelLinearWithLora by slicing the + LoRA A's also. + + Based on S-LoRA, slicing happens along the rank dim. 
+ """ + + def slice_lora_a( + self, lora_a: List[Union[torch.Tensor, None]] + ) -> List[Union[torch.Tensor, None]]: + # NOTE: lora_a contains 3 subloras, and each sublora could be None. + shard_size = [self.lora_a_stacked[i].shape[2] for i in range(3)] + start_idx = [self.tp_rank * shard_size[i] for i in range(3)] + lora_a = [ + lora_a[0][:, start_idx[0]:start_idx[0] + + shard_size[0]] if lora_a[0] is not None else None, + lora_a[1][:, start_idx[1]:start_idx[1] + + shard_size[1]] if lora_a[1] is not None else None, + lora_a[2][:, start_idx[2]:start_idx[2] + + shard_size[2]] if lora_a[2] is not None else None, + ] + return lora_a + + def apply(self, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + return _mcp_apply(x, bias, self) + + @classmethod + @_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: + # specifying kwargs so they can be easily accessed in decorator + return super().can_replace_layer( + source_layer=source_layer, + lora_config=lora_config, + packed_modules_list=packed_modules_list, + model_config=model_config, + decorate=False, + ) + + +class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA): + """ + Differs from RowParallelLinearWithLoRA by slicing the + LoRA B's also. + + Based on S-LoRA, slicing happens along the output dim. + This yields a combined partial sum from the row parallel base + layer and column partitioned output from the LoRA. + """ + + def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: + shard_size = self.lora_b_stacked[0].shape[2] + start_idx = self.tp_rank * shard_size + end_idx = (self.tp_rank + 1) * shard_size + lora_b = lora_b[:, start_idx:end_idx] + return lora_b + + def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: + if bias is None: + return bias + self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...], + self.lora_bias_stacked) + shard_size = self.lora_bias_stacked[0].shape[2] + start_idx = self.tp_rank * shard_size + end_idx = (self.tp_rank + 1) * shard_size + bias = bias[start_idx:end_idx] + return bias + + def apply(self, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + output = self.base_layer.quant_method.apply(self.base_layer, x) + + x = x.view(-1, x.shape[-1]) + output, out_orig_shape = output.view(-1, + output.shape[-1]), output.shape + buffer = torch.zeros( + (self.n_slices, x.shape[0], self.lora_a_stacked[0].shape[2]), + dtype=torch.float32, + device=x.device, + ) + + self.punica_wrapper.add_shrink(buffer, x, self.lora_a_stacked, 1.0) + buffer = tensor_model_parallel_all_reduce(buffer) + + # following S-LoRA, allows the fusing of all_gather and all_reduce + # by adding the column partitioned lora output to a slice of output + # tensor, which is a partial sum due to row parallel. All that + # remains is a standard all_reduce. User should be aware though that + # the output is not the same as a normal row_parallel, it should be + # reduced before being used + # NOTE offset are based on the rank. 
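+        # Illustrative sketch (hypothetical sizes): with output_size=4096 and
+        # tp_size=2, lora_b_stacked[0].shape[2] == 2048, so rank 0 adds its
+        # LoRA slice into output[..., 0:2048] and rank 1 into
+        # output[..., 2048:4096]; the later all_reduce then sums the partial
+        # base-layer results and combines the per-rank LoRA slices in a
+        # single collective.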
+ shard_size = self.lora_b_stacked[0].shape[2] + offset_start = self.tp_rank * shard_size + self.punica_wrapper.add_expand( + output, + buffer, + self.lora_b_stacked, + self.lora_bias_stacked, + self.output_slices, + offset_start=offset_start, + add_input=True, + ) + output = output.view(*out_orig_shape) + return output + + @classmethod + @_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: + # specifying kwargs so they can be easily accessed in decorator + return super().can_replace_layer( + source_layer=source_layer, + lora_config=lora_config, + packed_modules_list=packed_modules_list, + model_config=model_config, + decorate=False, + ) diff --git a/vllm_mindspore/lora/layers.py b/vllm_mindspore/lora/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..80c4919e0ab2c6126f10d80d383698b27f5f1b6e --- /dev/null +++ b/vllm_mindspore/lora/layers.py @@ -0,0 +1,1182 @@ +#!/usr/bin/env python3 +# encoding: utf-8 +# Copyright 2025 Huawei Technologies Co., Ltd +# Copyright 2024 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +# pylint: disable=unused-argument +import math +from dataclasses import dataclass +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, cast + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import PretrainedConfig + +from vllm.adapter_commons.layers import AdapterMapping +from vllm.config import LoRAConfig +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + split_tensor_along_last_dim, + tensor_model_parallel_all_gather, + tensor_model_parallel_all_reduce) +from vllm.distributed.utils import divide +# yapf: disable +from vllm_mindspore.model_executor.layers.linear import (ColumnParallelLinear, + LinearBase, + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +# yapf: enable +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.rotary_embedding import ( + LinearScalingRotaryEmbedding, RotaryEmbedding) +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +import mindspore as ms + +if TYPE_CHECKING: + from vllm.lora.punica_wrapper import PunicaWrapperBase + + +def _get_lora_device(base_layer: nn.Module) -> torch.device: + # code borrowed from https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/vllm/lora/layers.py#L34 + """Returns the device for where to place the LoRA tensors.""" + # unquantizedLinear + if hasattr(base_layer, "weight"): + return base_layer.weight.device + # Compressed Tensor + elif hasattr(base_layer, "weight_packed"): + return base_layer.weight_packed.device + # GPTQ/AWQ + elif hasattr(base_layer, "qweight"): + return base_layer.qweight.device + # marlin + elif 
hasattr(base_layer, "B"): + return base_layer.B.device + # HQQ marlin + elif hasattr(base_layer, "W_q"): + return base_layer.W_q.device + else: + raise ValueError(f"Unsupported base layer: {base_layer}") + + +def _not_fully_sharded_can_replace(can_replace): + """ + decorator which adds the condition of not using fully sharded loras + intended to wrap can_replace_layer() + """ + + def dec(*args, **kwargs): + decorate = kwargs.pop("decorate") if "decorate" in kwargs else True + condition = (not kwargs["lora_config"].fully_sharded_loras + if decorate else True) + return can_replace(*args, **kwargs) and condition + + return dec + + +@dataclass +class LoRAMapping(AdapterMapping): + is_prefill: bool = False + + +class BaseLayerWithLoRA(ms.nn.Cell): + + def slice_lora_a( + self, lora_a: Union[torch.Tensor, List[Union[torch.Tensor, None]]] + ) -> Union[torch.Tensor, List[Union[torch.Tensor, None]]]: + """Slice lora a if splitting for tensor parallelism.""" + ... + + def slice_lora_b( + self, lora_b: Union[torch.Tensor, List[Union[torch.Tensor, None]]] + ) -> Union[torch.Tensor, List[Union[torch.Tensor, None]]]: + """Slice lora b if splitting with tensor parallelism.""" + ... + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: + """Initializes lora matrices.""" + ... + + def reset_lora(self, index: int): + """Resets the lora weights at index back to 0.""" + ... + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + bias: Optional[torch.Tensor] = None, + ): + """Overwrites lora tensors at index.""" + ... + + def set_mapping( + self, + punica_wrapper, + ): + self.punica_wrapper: PunicaWrapperBase = punica_wrapper + + @classmethod + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: + raise NotImplementedError + + def construct(self, x: ms.Tensor) -> ms.Tensor: + raise NotImplementedError + + +class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): + + def __init__(self, base_layer: VocabParallelEmbedding) -> None: + super().__init__() + self.base_layer = base_layer + self.embeddings_slice: Optional[Tuple[int, int]] + self.embeddings_weights: Optional[torch.Tensor] + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None) -> None: + + if self.base_layer.num_added_embeddings_per_partition > 0: + # We can start adding lora weights + self.embeddings_weights = self.base_layer.weight.data[ + self.base_layer.num_org_embeddings_per_partition:self. 
+ base_layer.num_org_embeddings_per_partition + + self.base_layer.num_added_embeddings_per_partition] + self.embeddings_slice = ( + self.base_layer.shard_indices.added_vocab_start_index - + self.base_layer.org_vocab_size, + self.base_layer.shard_indices.added_vocab_end_index - + self.base_layer.org_vocab_size) + self.base_layer.weight.data[ + self.base_layer.num_org_embeddings_per_partition:].fill_(0) + else: + self.embeddings_slice = None + self.embeddings_weights = None + + self.embeddings_tensors = torch.zeros( + ( + max_loras, + lora_config.lora_extra_vocab_size, + self.base_layer.embedding_dim, + ), + dtype=self.base_layer.weight.dtype, + device=self.base_layer.weight.device, + ) + self.lora_a_stacked = torch.zeros( + ( + max_loras, + self.base_layer.org_vocab_size + + lora_config.lora_extra_vocab_size, + lora_config.max_lora_rank, + ), + dtype=torch.bfloat16, + device=self.base_layer.weight.device, + ) + self.lora_b_stacked = torch.zeros( + ( + max_loras, + 1, + self.base_layer.embedding_dim, + lora_config.max_lora_rank, + ), + dtype=torch.bfloat16, + device=self.base_layer.weight.device, + ) + self.lora_a_stacked_2d = self.lora_a_stacked.view( + self.lora_a_stacked.shape[0] * self.lora_a_stacked.shape[1], + self.lora_a_stacked.shape[2], + ) + + def reset_lora(self, index: int): + self.lora_a_stacked[index] = 0 + self.lora_b_stacked[index] = 0 + self.embeddings_tensors[index] = 0 + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + bias: Optional[torch.Tensor] = None, + ): + self.reset_lora(index) + self.lora_a_stacked[index, :lora_a.shape[0], :lora_a.shape[1]].copy_( + lora_a, non_blocking=True) + self.lora_b_stacked[index, + 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( + lora_b.T, non_blocking=True) + if embeddings_tensor is not None: + self.embeddings_tensors[ + index, + :embeddings_tensor.shape[0], + :embeddings_tensor.shape[1], + ].copy_(embeddings_tensor, non_blocking=True) + if self.embeddings_slice is not None: + # TODO(yard1): Optimize this copy, we don't need to copy + # everything, just the modified part + embeddings = self.embeddings_tensors.view( + self.embeddings_tensors.shape[0] * + self.embeddings_tensors.shape[1], + self.embeddings_tensors.shape[2], + )[self.embeddings_slice[0]:self.embeddings_slice[1]] + assert self.embeddings_weights is not None + self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings) + + def construct(self, x: torch.Tensor) -> torch.Tensor: + added_tokens_mask = x > self.base_layer.org_vocab_size - 1 + embeddings_indices = self.punica_wrapper.embeddings_indices + indices = embeddings_indices[1].view_as(x) + full_lora_a_embeddings = F.embedding( + x + indices, + self.lora_a_stacked_2d, + ) + indices = embeddings_indices[0].view_as(x) + full_output = self.base_layer.forward( + x.add_(indices * added_tokens_mask)) + + full_output_org = full_output + if full_output.ndim == 3: + full_output = full_output.view( + full_output.shape[0] * full_output.shape[1], -1) + if full_lora_a_embeddings.ndim == 3: + full_lora_a_embeddings = full_lora_a_embeddings.view( + full_lora_a_embeddings.shape[0] * + full_lora_a_embeddings.shape[1], + -1, + ) + + full_output = self.punica_wrapper.add_lora_embedding(full_output, + full_lora_a_embeddings, + self.lora_b_stacked, + add_input=True) + return full_output.view_as(full_output_org) + + + @classmethod + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: 
Optional[PretrainedConfig], + ) -> bool: + return type(source_layer) is VocabParallelEmbedding + + +class BaseLinearLayerWithLoRA(BaseLayerWithLoRA): + + def __init__(self, base_layer: LinearBase): + super().__init__() + self.base_layer = base_layer + self.input_size = self.base_layer.input_size + self.device = _get_lora_device(self.base_layer) + self.lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]] = None + + self.output_slices: Tuple[int, ...] + self.tp_size: int + self.output_size: int + self.n_slices: int + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: + self.lora_config = lora_config + # TODO ReplicatedLinear + # if isinstance(self.base_layer, ReplicatedLinear): + # lora_a_out_size = lora_config.max_lora_rank + # lora_b_out_size = self.output_size + + if isinstance(self.base_layer, ColumnParallelLinear): + lora_a_out_size = (lora_config.max_lora_rank if + not lora_config.fully_sharded_loras else divide( + lora_config.max_lora_rank, self.tp_size)) + lora_b_out_size = self.output_size + + elif isinstance(self.base_layer, RowParallelLinear): + lora_a_out_size = lora_config.max_lora_rank + lora_b_out_size = (self.output_size if + not lora_config.fully_sharded_loras else divide( + self.output_size, self.tp_size)) + else: + raise NotImplementedError + + self.lora_a_stacked = tuple( + torch.zeros( + max_loras, + 1, + lora_a_out_size, + self.input_size, + dtype=torch.bfloat16, + device=self.device, + ) for _ in range(self.n_slices)) + self.lora_b_stacked = tuple( + torch.zeros( + max_loras, + 1, + lora_b_out_size, + lora_config.max_lora_rank, + dtype=torch.bfloat16, + device=self.device, + ) for _ in range(self.n_slices)) + if lora_config.bias_enabled: + lora_bias_out_size = lora_b_out_size + self.lora_bias_stacked = tuple( + torch.zeros( + max_loras, + 1, + lora_bias_out_size, + dtype=torch.bfloat16, + device=self.device, + ) for _ in range(self.n_slices)) + self.output_slices = (self.lora_b_stacked[0].shape[2], ) + + def reset_lora(self, index: int): + for s_index in range(self.n_slices): + self.lora_a_stacked[s_index][index] = 0 + self.lora_b_stacked[s_index][index] = 0 + if self.lora_config.bias_enabled: + # Make mypy happy + self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...], + self.lora_bias_stacked) + self.lora_bias_stacked[s_index][index] = 0 + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + lora_bias: Optional[torch.Tensor] = None, + ): + # Except for QKVParallelLinearWithLora and + # MergedColumnParallelLinearWithLoRA, all other linear LoRA layers + # store weights in a tuple of size 1. These two layers will + # override this function. 
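+        # Layout note: lora_a is expected as (input_size, rank) and stored
+        # transposed as (rank, input_size); lora_b is expected as
+        # (rank, output_size) and stored as (output_size, rank). The stacked
+        # buffers are allocated at max_lora_rank, so entries beyond the
+        # adapter's actual rank simply stay zero.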
+ assert (len(self.lora_a_stacked) == len(self.lora_b_stacked) == + self.n_slices == 1) + + self.reset_lora(index) + if self.tp_size > 1: + lora_a = self.slice_lora_a(lora_a) + lora_b = self.slice_lora_b(lora_b) + if lora_bias is not None: + lora_bias = self.slice_bias(lora_bias) + + self.lora_a_stacked[0][index, + 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( + lora_a.T, non_blocking=True) + self.lora_b_stacked[0][index, + 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( + lora_b.T, non_blocking=True) + if lora_bias is not None: + + self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...], + self.lora_bias_stacked) + assert len(self.lora_bias_stacked) + self.lora_bias_stacked[0][index, 0, :lora_bias.shape[0]].copy_( + lora_bias.T, non_blocking=True) + + def apply(self, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + output = self.base_layer.quant_method.apply(self.base_layer, x, bias) + output = self.punica_wrapper.add_lora_linear(output, x, self.lora_a_stacked, + self.lora_b_stacked, + self.lora_bias_stacked, 1.0, + self.output_slices) + return output + +# TODO support ReplicatedLinearWithLoRA + +class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA): + """ + LoRA on top of ColumnParallelLinear layer. + LoRA B is sliced for tensor parallelism. + There are two types for the `base_layer`: + 1. ColumnParallelLinear, e.g.`dense_h_to_4h` in `FalconForCausalLM`. + 2. MergedColumnParallelLinear, e.g.`gate_up_proj` in `Phi3ForCausalLM`. + """ + + def __init__(self, base_layer: ColumnParallelLinear) -> None: + super().__init__(base_layer) + # The base_layer type is ColumnParallelLinear or + # MergedColumnParallelLinear, their weight sharding logic is + # inconsistent when TP is greater than 1. + self.is_merged_col_linear = type( + base_layer) is MergedColumnParallelLinear + self.tp_size = get_tensor_model_parallel_world_size() + self.output_size = self.base_layer.output_size_per_partition + # There is only one LoRA layer + self.n_slices = 1 + + def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: + return lora_a + + def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: + # Applicable to cases where the base_layer is + # MergedColumnParallelLinear. + if self.is_merged_col_linear: + tp_rank = get_tensor_model_parallel_rank() + shard_size = self.output_size // 2 + offset = lora_b.shape[-1] // 2 + + left_weight = lora_b[:, tp_rank * shard_size:(tp_rank + 1) * + shard_size] + right_weight = lora_b[:, offset + tp_rank * shard_size:offset + + (tp_rank + 1) * shard_size] + lora_b = torch.cat([left_weight, right_weight], dim=1) + # Applicable to cases where the base_layer is + # ColumnParallelLinear. + else: + tensor_model_parallel_rank = get_tensor_model_parallel_rank() + shard_size = self.output_size + start_idx = tensor_model_parallel_rank * shard_size + end_idx = (tensor_model_parallel_rank + 1) * shard_size + lora_b = lora_b[:, start_idx:end_idx] + return lora_b + + def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: + # TODO: Fix the slicing logic of bias. + if bias is None: + return bias + tensor_model_parallel_rank = get_tensor_model_parallel_rank() + shard_size = self.output_size + start_idx = tensor_model_parallel_rank * shard_size + end_idx = (tensor_model_parallel_rank + 1) * shard_size + bias = bias[start_idx:end_idx] + return bias + + def construct( + self, input_: torch.Tensor + ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: + """Forward of ColumnParallelLinear + + Args: + input_: Tensor whose last dimension is `input_size`. 
+ + Returns: + - output + - bias + """ + bias = (self.base_layer.bias + if not self.base_layer.skip_bias_add else None) + + # Matrix multiply. + output_parallel = self.apply(input_, bias) + if self.base_layer.gather_output: + # All-gather across the partitions. + output = tensor_model_parallel_all_gather(output_parallel) + else: + output = output_parallel + output_bias = (self.base_layer.bias + if self.base_layer.skip_bias_add else None) + return output, output_bias + + + @classmethod + @_not_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: + return type(source_layer) is ColumnParallelLinear or ( + type(source_layer) is MergedColumnParallelLinear + and len(packed_modules_list) == 1) + + +class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): + """ColumnParallelLinear layer that is composed of 2 sublayers (slices) + packed together (eg. gate_proj + up_proj -> gate_up_proj). + + This means we have 2 LoRAs, each applied to one half of the layer. + + Both slices must have the same size. + """ + + def __init__( + self, base_layer: Union[MergedColumnParallelLinear, + QKVParallelLinear]) -> None: + super().__init__(base_layer) + # There are two LoRA layers + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + # the output_sizes in MergedColumnParallelLinear is not sharded by tp + # we need to divide it by the tp_size to get correct slices size + output_sizes = self.base_layer.output_sizes + self.output_slices = tuple( + divide(output_size, self.tp_size) for output_size in output_sizes) + self.n_slices = len(self.output_slices) + self.output_ids = (self.tp_rank, ) * self.n_slices + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: + """ + The main reason for overriding this function is to enhance code + maintainability. 
+ """ + self.lora_config = lora_config + + lora_a_output_size_per_partition = ( + lora_config.max_lora_rank if not lora_config.fully_sharded_loras + else divide(lora_config.max_lora_rank, self.tp_size)) + self.lora_a_stacked = tuple( + torch.zeros( + max_loras, + 1, + lora_a_output_size_per_partition, + self.input_size, + dtype=torch.bfloat16, + device=self.device, + ) for _ in range(self.n_slices)) + self.lora_b_stacked = tuple( + torch.zeros( + max_loras, + 1, + output_size, + lora_config.max_lora_rank, + dtype=torch.bfloat16, + device=self.device, + ) for output_size in self.output_slices) + if lora_config.bias_enabled: + self.lora_bias_stacked = tuple( + torch.zeros( + max_loras, + 1, + output_size, + dtype=torch.bfloat16, + device=self.device, + ) for output_size in self.output_slices) + + def slice_lora_a( + self, lora_a: List[Union[torch.Tensor, None]] + ) -> List[Union[torch.Tensor, None]]: + return lora_a + + def slice_lora_b( + self, lora_b: List[Union[torch.Tensor, None]] + ) -> List[Union[torch.Tensor, None]]: + for i, (shard_id, shard_size) in enumerate( + zip(self.output_ids, self.output_slices)): + if (lora_b_i := lora_b[i]) is not None: + lora_b[i] = lora_b_i[:, shard_size * shard_id:shard_size * + (shard_id + 1)] + return lora_b + + def slice_bias( + self, bias: List[Union[torch.Tensor, + None]]) -> List[Union[torch.Tensor, None]]: + for i, (shard_id, shard_size) in enumerate( + zip(self.output_ids, self.output_slices)): + if (bias_i := bias[i]) is not None: + bias[i] = bias_i[shard_size * shard_id:shard_size * + (shard_id + 1)] + return bias + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + lora_bias: Optional[torch.Tensor] = None, + ): + self.reset_lora(index) + + if self.tp_size > 1: + lora_a = self.slice_lora_a(lora_a) + lora_b = self.slice_lora_b(lora_b) + if lora_bias is not None: + lora_bias = self.slice_bias(lora_bias) + + for i in range(self.n_slices): + if (lora_a_i := lora_a[i]) is not None: + self.lora_a_stacked[i][ + index, 0, :lora_a_i.shape[1], :lora_a_i.shape[0]].copy_( + lora_a_i.T, non_blocking=True) + if (lora_b_i := lora_b[i]) is not None: + self.lora_b_stacked[i][ + index, 0, :lora_b_i.shape[1], :lora_b_i.shape[0]].copy_( + lora_b_i.T, non_blocking=True) + + if lora_bias is not None: + self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...], + self.lora_bias_stacked) + for i in range(self.n_slices): + if (lora_bias_i := lora_bias[i]) is not None: + self.lora_bias_stacked[i][index, + 0, :lora_bias_i.shape[0]].copy_( + lora_bias_i.T, + non_blocking=True) + + @classmethod + @_not_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: + return (type(source_layer) is MergedColumnParallelLinear + and len(packed_modules_list) == 2) + + +class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): + """ + ColumnParallelLinear layer that is specifically designed for + qkv_proj. Certain models, such as chatglm3 and baichuan-7b, + only contains a single LoRA within their qkv_proj layer. + + During inference with Tensor Parallel, the weights of lora_b + must be accurately partitioned according to the respective ranks. + + Q slice may have different shape than K and V slices (which both have + the same shape). 
+ """ + + def __init__(self, base_layer: QKVParallelLinear) -> None: + super().__init__(base_layer) + self.q_proj_total_size = (self.base_layer.total_num_heads * + self.base_layer.head_size) + self.q_proj_shard_size = (self.base_layer.num_heads * + self.base_layer.head_size) + self.kv_proj_shard_size = (self.base_layer.num_kv_heads * + self.base_layer.head_size) + self.kv_proj_total_size = (self.base_layer.total_num_kv_heads * + self.base_layer.head_size) + # There is only one LoRA layer + self.n_slices = 1 + + def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: + tp_rank = get_tensor_model_parallel_rank() + self.q_shard_id = tp_rank + self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas + lora_b_q = lora_b[:, self.q_proj_shard_size * + self.q_shard_id:self.q_proj_shard_size * + (self.q_shard_id + 1)] + k_offset = self.q_proj_total_size + lora_b_k = lora_b[:, k_offset + + self.kv_proj_shard_size * self.kv_shard_id:k_offset + + self.kv_proj_shard_size * (self.kv_shard_id + 1)] + v_offset = k_offset + self.kv_proj_total_size + lora_b_v = lora_b[:, v_offset + + self.kv_proj_shard_size * self.kv_shard_id:v_offset + + self.kv_proj_shard_size * (self.kv_shard_id + 1)] + lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=1) + return lora_b + + def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: + bias_q = bias[self.q_proj_shard_size * + self.q_shard_id:self.q_proj_shard_size * + (self.q_shard_id + 1)] + k_offset = self.q_proj_total_size + bias_k = bias[k_offset + + self.kv_proj_shard_size * self.kv_shard_id:k_offset + + self.kv_proj_shard_size * (self.kv_shard_id + 1)] + v_offset = k_offset + self.kv_proj_total_size + bias_v = bias[v_offset + + self.kv_proj_shard_size * self.kv_shard_id:v_offset + + self.kv_proj_shard_size * (self.kv_shard_id + 1)] + bias = torch.cat([bias_q, bias_k, bias_v], dim=1) + return bias + + @classmethod + @_not_fully_sharded_can_replace + def can_replace_layer(cls, source_layer: nn.Module, + lora_config: LoRAConfig, packed_modules_list: List, + model_config: Optional[PretrainedConfig]) -> bool: + return type(source_layer) is QKVParallelLinear and len( + packed_modules_list) == 1 + + +class MergedQKVParallelLinearWithLora(MergedColumnParallelLinearWithLoRA): + """MergedColumnParallelLinear layer that is composed of 3 sublayers (slices) + packed together in qkv proj fashion + (q_proj + k_proj + v_proj -> qkv_proj). + + This means we have 3 LoRAs, each applied to one slice of the layer. + + Q slice may have different shape than K and V slices (which both have + the same shape). + """ + + def __init__(self, base_layer: QKVParallelLinear) -> None: + super().__init__(base_layer) + # There are three LoRA layer. 
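+        # Illustrative sketch (hypothetical config): with 32 query heads,
+        # 8 KV heads, head_size=128 and tp_size=4, each rank holds 8 query
+        # heads and 2 KV heads, so q_proj_shard_size == 1024,
+        # kv_proj_shard_size == 256 and output_slices == (1024, 256, 256).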
+ self.n_slices = len(self.base_layer.output_sizes) + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + + self.q_proj_shard_size = (self.base_layer.num_heads * + self.base_layer.head_size) + self.kv_proj_shard_size = (self.base_layer.num_kv_heads * + self.base_layer.head_size) + self.q_shard_id = self.tp_rank + self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas + + self.output_slices = ( + self.q_proj_shard_size, + self.kv_proj_shard_size, + self.kv_proj_shard_size, + ) + self.output_ids = ( + self.q_shard_id, + self.kv_shard_id, + self.kv_shard_id, + ) + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: + """ + The main reason for overloading this function is to handle inconsistent + weight dimensions in qkv lora. + """ + super().create_lora_weights(max_loras, lora_config, model_config) + + @classmethod + @_not_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: + return (type(source_layer) is QKVParallelLinear + and len(packed_modules_list) == 3) + + +class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA): + + def __init__(self, base_layer: RowParallelLinear) -> None: + super().__init__(base_layer) + + self.tp_size = get_tensor_model_parallel_world_size() + # reset input_size + self.input_size = self.base_layer.input_size_per_partition + self.output_size = self.base_layer.output_size + + self.tp_rank = get_tensor_model_parallel_rank() + # There is only one LoRA layer. + self.n_slices = 1 + + def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: + + shard_size = self.input_size + start_idx = self.tp_rank * shard_size + end_idx = (self.tp_rank + 1) * shard_size + lora_a = lora_a[start_idx:end_idx, :] + return lora_a + + def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: + return lora_b + + def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: + return bias + + def construct( + self, input_: torch.Tensor + ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: + """Forward of RowParallelLinear + + Args: + input_: tensor whose last dimension is `input_size`. If + `input_is_parallel` is set, then the last dimension + is `input_size // tp_size`. + + Returns: + - output + - bias + """ + # Set up backprop all-reduce. + if self.base_layer.input_is_parallel: + input_parallel = input_ + else: + # TODO: simplify code below + splitted_input = split_tensor_along_last_dim( + input_, num_partitions=self.base_layer.tp_size) + input_parallel = splitted_input[self.tp_rank].contiguous() + + # Matrix multiply. 
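+        # apply() runs the base layer's quantized matmul and then adds the
+        # LoRA contribution (shrink through lora_a_stacked, expand through
+        # lora_b_stacked) via the punica wrapper, before the tensor-parallel
+        # reduction below.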
+ output_parallel = self.apply(input_parallel) + if self.base_layer.reduce_results and self.base_layer.tp_size > 1: + output_ = tensor_model_parallel_all_reduce(output_parallel) + else: + output_ = output_parallel + + if not self.base_layer.skip_bias_add: + output = (output_ + self.base_layer.bias + if self.base_layer.bias is not None else output_) + output_bias = None + else: + output = output_ + output_bias = self.base_layer.bias + return output, output_bias + + + @property + def weight(self): + return (self.base_layer.weight if hasattr(self.base_layer, "weight") + else self.base_layer.qweight) + + @classmethod + @_not_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: + return type(source_layer) is RowParallelLinear + + +class LogitsProcessorWithLoRA(BaseLayerWithLoRA): + """ + LoRA wrapper for LogitsProcessor, with extra logic to handle the + application of the LoRA adapter and added LoRA vocabulary. + + Args: + base_layer: LogitsProcessor layer + hidden_size: hidden size of the model + dtype: data type of the model + device: device of the model + sharded_to_full_mapping: index mapping from sharded vocab to full vocab + received from base_layer.get_sharded_to_full_mapping(). If None, + no reindexing will be done. + """ + + def __init__(self, base_layer: LogitsProcessor, hidden_size: int, + dtype: torch.dtype, device: torch.device, + sharded_to_full_mapping: Optional[List[int]]) -> None: + super().__init__() + self.base_layer = base_layer + self.hidden_size = hidden_size + self.dtype = dtype + self.device = device + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + self.sharded_to_full_mapping = sharded_to_full_mapping + + @property + def logits_as_input(self): + return self.base_layer.logits_as_input + + @property + def vocab_size(self): + return self.base_layer.vocab_size + + @property + def scale(self): + return self.base_layer.scale + + @property + def soft_cap(self): + return self.base_layer.soft_cap + + @property + def use_all_gather(self): + return self.base_layer.use_all_gather + + @property + def org_vocab_size(self): + return self.base_layer.org_vocab_size + + @property + def include_gpu_probs_tensor(self): + return self.base_layer.include_gpu_probs_tensor + + @property + def should_modify_greedy_probs_inplace(self): + return self.base_layer.should_modify_greedy_probs_inplace + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: + # TODO: Verify if this condition can be further relaxed + if 32000 < self.base_layer.vocab_size > 257024: + raise ValueError("When using LoRA, vocab size must be " + "32000 >= vocab_size <= 257024") + self.lora_a_stacked = torch.zeros( + ( + max_loras, + 1, + lora_config.max_lora_rank, + self.hidden_size, + ), + dtype=torch.bfloat16, + device=self.device, + ) + self.lora_b_stacked = torch.zeros( + ( + max_loras, + 1, + # Pad for kernel compatibility + math.ceil(self.base_layer.vocab_size / + lora_config.lora_vocab_padding_size) * + lora_config.lora_vocab_padding_size, + lora_config.max_lora_rank, + ), + dtype=torch.bfloat16, + device=self.device, + ) + self.embeddings_tensors = torch.full( + (max_loras, lora_config.lora_extra_vocab_size, self.hidden_size), + fill_value=float("-inf"), + dtype=self.dtype, + device=self.device, + ) + if 
self.sharded_to_full_mapping is not None: + self.sharded_to_full_mapping_gpu = torch.tensor( + self.sharded_to_full_mapping, + device=self.device, + dtype=torch.long) + else: + self.sharded_to_full_mapping_gpu = None + + def reset_lora(self, index: int): + self.lora_a_stacked[index] = 0 + self.lora_b_stacked[index] = 0 + self.embeddings_tensors[index] = float("-inf") + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + bias: Optional[torch.Tensor] = None, + ): + self.reset_lora(index) + self.lora_a_stacked[index, + 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( + lora_a.T, non_blocking=True) + self.lora_b_stacked[index, + 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( + lora_b.T, non_blocking=True) + if embeddings_tensor is not None: + self.embeddings_tensors[ + index, + :embeddings_tensor.shape[0], + :embeddings_tensor.shape[1], + ] = embeddings_tensor + + def _get_logits( + self, + hidden_states: torch.Tensor, + lm_head: VocabParallelEmbedding, + embedding_bias: Optional[torch.Tensor] = None, + ) -> Optional[torch.Tensor]: + # Get the logits for the next tokens. + logits = lm_head.quant_method.apply(lm_head, hidden_states) + if embedding_bias is not None: + logits += embedding_bias + + # Gather logits for TP + logits = self.base_layer._gather_logits(logits) + + if logits is None: + return None + + if self.sharded_to_full_mapping_gpu is not None: + # Reindex full logits tensor to ensure 1:1 mapping between + # index and token_id + # Example for: + # org_vocab_size = 4 + # added_vocab_size = 2 + # pad_to_size = 8 + # tp_size = 2 + + # indices: [0, 1, 2, 3, 4, 5, 6, 7] + # token_id: [0, 1, 4, -1, 2, 3, 5, -1] + + # Therefore, the mapping is expected to be: + # [0, 1, 4, 6, 2, 3, 5, 7] so that when we reindex, + # we get: + # indices: [0, 1, 2, 3, 4, 5, 6, 7] + # token_id: [0, 1, 2, 3, 4, 5, -1, -1] + logits = logits[:, self.sharded_to_full_mapping_gpu] + + lora_logits = torch.empty( + self.embeddings_tensors.shape[0] + 1, + self.embeddings_tensors.shape[1], + hidden_states.shape[0], + dtype=self.embeddings_tensors.dtype, + device=self.embeddings_tensors.device, + ) + torch.matmul(self.embeddings_tensors, + hidden_states.T, + out=lora_logits[:-1]) + lora_logits[-1] = float("-inf") + lora_logits = lora_logits.mT + indices_padded = self.punica_wrapper.sampler_indices_padded + lora_logits = (lora_logits.reshape( + lora_logits.shape[0] * lora_logits.shape[1], + lora_logits.shape[2], + ).index_select(0, indices_padded).nan_to_num_(nan=float("-inf"), + posinf=float("inf"), + neginf=float("-inf"))) + + logits[:, + self.base_layer.org_vocab_size:self.base_layer.org_vocab_size + + lora_logits.shape[1]] = lora_logits + + # LogitsProcessorWithLoRA always using bgmv + self.punica_wrapper.add_lora_logits(logits, hidden_states, + self.lora_a_stacked, + self.lora_b_stacked, 1.0) + + # Remove paddings in vocab (if any). + logits = logits[:, :self.base_layer.vocab_size] + return logits + + def construct(self, *args, **kwargs): + return type(self.base_layer).forward(self, *args, **kwargs) + + + @classmethod + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: + # Special handling for the LogitsProcessor. + return False + + +class LinearScalingRotaryEmbeddingWithLora(BaseLayerWithLoRA): + """Implements RoPE-scaled embeddings with linear scaling for + multiple LoRA adapters with a specialized kernel. 
+ + Replace LinearScalingRotaryEmbedding with MultiLinearScalingRotaryEmbedding + which can handle multi lora adapters in a specialied kernel. + """ + + def __init__(self, base_layer: RotaryEmbedding) -> None: + super().__init__() + self.base_layer = base_layer + + @property + def scaling_factors(self): + return self.base_layer.scaling_factors + + @property + def rotary_dim(self): + return self.base_layer.rotary_dim + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: + scaling_factors = (list(lora_config.long_lora_scaling_factors) + if lora_config.long_lora_scaling_factors else []) + base_scaling_factor = (self.base_layer.scaling_factor if isinstance( + self.base_layer, LinearScalingRotaryEmbedding) else 1.0) + scaling_factors = sorted( + list(set([base_scaling_factor] + scaling_factors))) + self.base_layer = LinearScalingRotaryEmbedding( + self.base_layer.head_size, + self.base_layer.rotary_dim, + self.base_layer.max_position_embeddings, + self.base_layer.base, + self.base_layer.is_neox_style, + scaling_factors, + self.base_layer.dtype, + ) + + def reset_lora(self, index: int): + ... + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + bias: Optional[torch.Tensor] = None, + ): + ... + + def construct( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + return self.base_layer( + positions, + query, + key, + offsets=self.punica_wrapper.long_lora_indices, + ) + + @property + def scaling_factor_to_offset(self) -> Dict[float, int]: + return self.base_layer.scaling_factor_to_offset + + @classmethod + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: + """Returns True if the layer can be replaced by this LoRA layer.""" + return (type(source_layer) is LinearScalingRotaryEmbedding + or type(source_layer) is RotaryEmbedding) + + def extra_repr(self) -> str: + return self.base_layer.extra_repr() diff --git a/vllm_mindspore/lora/models.py b/vllm_mindspore/lora/models.py new file mode 100644 index 0000000000000000000000000000000000000000..67dacae0b3688ad06446e0c37b93aef3176adab1 --- /dev/null +++ b/vllm_mindspore/lora/models.py @@ -0,0 +1,232 @@ +#!/usr/bin/env python3 +# encoding: utf-8 +# Copyright 2025 Huawei Technologies Co., Ltd +# Copyright 2024 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +from vllm_mindspore.lora.layers import BaseLayerWithLoRA +import copy +import math +import os +import re +from dataclasses import dataclass, field +from typing import (Any, Callable, Dict, List, Optional, Sequence, Set, Type, + Union) + +import safetensors.torch +import torch + +from vllm.lora.peft_helper import PEFTHelper +from vllm.lora.lora import LoRALayerWeights +from vllm.lora.utils import (is_regex_target_modules, + parse_fine_tuned_lora_name) +from vllm.model_executor.models.utils import WeightsMapper +from vllm.utils import is_pin_memory_available + +_GLOBAL_LORA_ID = 0 + +def get_lora_id(): + global _GLOBAL_LORA_ID + _GLOBAL_LORA_ID += 1 + return _GLOBAL_LORA_ID + +def register_module(self, module_name: str, module: "BaseLayerWithLoRA"): + assert isinstance(module, BaseLayerWithLoRA) + self.modules[module_name] = module + +@classmethod +def from_lora_tensors( + cls, + lora_model_id: int, + tensors: Dict[str, torch.Tensor], + peft_helper: PEFTHelper, + device: str = "cuda", + dtype: Optional[torch.dtype] = None, + embeddings: Optional[Dict[str, torch.Tensor]] = None, + target_embedding_padding: Optional[int] = None, + embedding_modules: Optional[Dict[str, str]] = None, + embedding_padding_modules: Optional[List[str]] = None, + weights_mapper: Optional[WeightsMapper] = None, + ) -> "LoRAModel": + """Create a LoRAModel from a dictionary of tensors.""" + pin_memory = str(device) == "cpu" and is_pin_memory_available() + loras: Dict[str, LoRALayerWeights] = {} + for tensor_name, tensor in tensors.items(): + module_name, is_lora_a, is_bias = parse_fine_tuned_lora_name( + tensor_name, weights_mapper) + if module_name not in loras: + lora_embeddings_tensor = None + if embeddings: + assert embedding_modules is not None + embeddings_module = next( + (k for k in embedding_modules if k in module_name), + None) + if embeddings_module: + lora_embeddings_tensor = embeddings[ + embedding_modules[embeddings_module]] + if pin_memory: + lora_embeddings_tensor = ( + lora_embeddings_tensor.pin_memory()) + loras[module_name] = LoRALayerWeights.from_config( + module_name, peft_helper, lora_embeddings_tensor) + + if is_bias: + loras[module_name].bias = tensor.t() + bias = tensor.t() + if pin_memory: + bias = bias.pin_memory() + loras[module_name].bias = bias + elif is_lora_a: + loras[module_name].lora_a = tensor.t() + if pin_memory: + loras[module_name].lora_a = loras[ + module_name].lora_a.pin_memory() + else: + loras[module_name].lora_b = tensor.t() + assert embedding_padding_modules is not None + if any(name in module_name + for name in embedding_padding_modules + ) and target_embedding_padding is not None: + lora_b = loras[module_name].lora_b + assert target_embedding_padding >= lora_b.shape[1] + addition = target_embedding_padding - lora_b.shape[1] + loras[module_name].lora_b = torch.nn.functional.pad( + lora_b, (0, addition)) + if pin_memory: + loras[module_name].lora_b = loras[ + module_name].lora_b.pin_memory() + + for lora in loras.values(): + lora.optimize() + + return cls(lora_model_id, + peft_helper.r, + loras, + scaling_factor=peft_helper.vllm_long_context_scaling_factor) + +@classmethod +def from_local_checkpoint( + cls, + lora_dir: str, + expected_lora_modules: List[str], + peft_helper: PEFTHelper, + *, + lora_model_id: Optional[int] = None, + device: str = "cuda", + dtype: Optional[torch.dtype] = None, + target_embedding_padding: Optional[int] = None, + embedding_modules: Optional[Dict[str, str]] = None, + 
embedding_padding_modules: Optional[List[str]] = None, + weights_mapper: Optional[WeightsMapper] = None, + ) -> "LoRAModel": + """Create a LoRAModel from a local checkpoint. + + Args: + lora_dir: The local path that has lora data. + expected_lora_modules: Name of modules that are expected to be + replaced by lora. + peft_helper: Loaded lora configuration information. + lora_model_id: Lora model id. If not given, automatically set by + a global counter. + device: Device where the lora model is loaded. + dtype: dtype of the lora model weights. + + Returns: + Loaded LoRA Model. + """ + lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors") + lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin") + new_embeddings_tensor_path = os.path.join( + lora_dir, "new_embeddings.safetensors") + new_embeddings_bin_file_path = os.path.join(lora_dir, + "new_embeddings.bin") + + unexpected_modules: List[Union[list[str], str]] + if os.path.isfile(lora_tensor_path): + tensors: Dict[str, torch.Tensor] = {} + # Find unexpected modules. + # Use safetensor key as a source of truth to find expected modules. + # in peft if you have target_modules A, B, C and C does not exist + # in the model it won’t error and model will be trained with A, B + # loraified. C won’t exist in the safetensor but it will exist in + # the target_modules of the adapter_config.json. + unexpected_modules = [] + with safetensors.safe_open(lora_tensor_path, + framework="np") as f: # type: ignore + for lora_module in f.keys(): # noqa + module_name, _, _ = parse_fine_tuned_lora_name( + lora_module, weights_mapper) + part_name = module_name.split(".")[-1] + if part_name not in expected_lora_modules: + unexpected_modules.append(module_name) + if unexpected_modules: + raise ValueError( + f"While loading {lora_dir}, expected" + f" target modules in {expected_lora_modules}" + f" but received {unexpected_modules}." + f" Please verify that the loaded LoRA module is correct" + ) + # Load tensors if there are only expected modules. + for module in f.keys(): # noqa + tensors[module] = torch.Tensor(f.get_tensor(module)) + elif os.path.isfile(lora_bin_file_path): + # When a bin file is provided, we rely on config to find unexpected + # modules. + unexpected_modules = [] + target_modules = peft_helper.target_modules + if not isinstance(target_modules, list): + target_modules = [target_modules] + for module in target_modules: + # Compatible with more modules, + # such as:layers.11.self_attn.k_proj + part_name = module.split(".")[-1] + if part_name not in expected_lora_modules: + unexpected_modules.append(module) + # loaded lora's target modules must be a subset of + # expected_lora_modules. It is not reliable. See + # https://github.com/vllm-project/vllm/pull/5909. But there's no + # other better mechanism. + if unexpected_modules and not is_regex_target_modules( + peft_helper.target_modules, expected_lora_modules): + raise ValueError( + f"While loading {lora_dir}, expected" + f" target modules in {expected_lora_modules}" + f" but received {unexpected_modules}." 
+ f" Please verify that the loaded LoRA module is correct") + tensors = torch.load(lora_bin_file_path, map_location=device) + else: + raise ValueError(f"{lora_dir} doesn't contain tensors") + + embeddings = None + if os.path.isfile(new_embeddings_tensor_path): + embeddings = safetensors.torch.load_file( + new_embeddings_tensor_path) + elif os.path.isfile(new_embeddings_bin_file_path): + embeddings = torch.load(new_embeddings_bin_file_path, + map_location=device, + weights_only=True) + + return cls.from_lora_tensors( + lora_model_id=get_lora_id() + if lora_model_id is None else lora_model_id, + tensors=tensors, + peft_helper=peft_helper, + device=device, + dtype=dtype, + embeddings=embeddings, + target_embedding_padding=target_embedding_padding, + embedding_modules=embedding_modules, + embedding_padding_modules=embedding_padding_modules, + weights_mapper=weights_mapper) diff --git a/vllm_mindspore/lora/ops/torch_ops/__init__.py b/vllm_mindspore/lora/ops/torch_ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f1fb53bd6dc6e17ca55caacc9e5bb5865158a64a --- /dev/null +++ b/vllm_mindspore/lora/ops/torch_ops/__init__.py @@ -0,0 +1,31 @@ +#!/usr/bin/env python3 +# encoding: utf-8 +# Copyright 2025 Huawei Technologies Co., Ltd +# Copyright 2024 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +from vllm_mindspore.lora.ops.torch_ops.lora_ops import bgmv_expand # noqa: F401 +from vllm_mindspore.lora.ops.torch_ops.lora_ops import (bgmv_expand_slice, bgmv_shrink, + sgmv_expand, sgmv_expand_slice, + sgmv_shrink) + +__all__ = [ + "bgmv_expand", + "bgmv_expand_slice", + "bgmv_shrink", + "sgmv_expand", + "sgmv_expand_slice", + "sgmv_shrink", +] diff --git a/vllm_mindspore/lora/ops/torch_ops/lora_ops.py b/vllm_mindspore/lora/ops/torch_ops/lora_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..bfad97faceb86a2ba1852bfc6be637dd238507ce --- /dev/null +++ b/vllm_mindspore/lora/ops/torch_ops/lora_ops.py @@ -0,0 +1,143 @@ +#!/usr/bin/env python3 +# encoding: utf-8 +# Copyright 2025 Huawei Technologies Co., Ltd +# Copyright 2024 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +import torch +from mindspore import mint +import mindspore as ms + +def einsum_ms(inputs, selected_loras): + # mint.einsum("bi, boi -> bo", inputs, selected_loras) + selected_loras = mint.transpose(selected_loras, 1, 2) + outputs = mint.matmul(inputs.unsqueeze(1), selected_loras).squeeze(1) + return outputs + + +def sgmv_expand(inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + token_nums: int, + add_inputs: bool = False): + exploded_indices = mint.repeat_interleave(lora_indices_tensor, + seq_len_tensor) + + return bgmv_expand(inputs, lora_b_weights, output_tensor, exploded_indices, + add_inputs) + + +def bgmv_expand(inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + add_inputs: bool = True): + selected_loras = lora_b_weights[lora_indices_tensor].astype(output_tensor.dtype) + inputs = inputs.astype(output_tensor.dtype) + if len(selected_loras.shape) == 4: + selected_loras = selected_loras.squeeze(1) + outputs = einsum_ms(inputs, selected_loras) + limit = output_tensor.shape[0] + out_type = output_tensor.dtype + outputs = outputs.astype(ms.float32) + output_tensor = output_tensor.astype(ms.float32) + if outputs.shape[0] == 1 and output_tensor.shape[0] != 1: + limit = 1 + if add_inputs: + output_tensor[:, :outputs.shape[1]] += outputs[:limit, :] + else: + output_tensor[:, :outputs.shape[1]] = outputs[:limit, :] + return output_tensor.astype(out_type) + + +def sgmv_shrink( + inputs: torch.Tensor, + lora_a_weights: torch.Tensor, + output_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + token_nums: int, + scaling: float, +): + exploded_indices = mint.repeat_interleave(lora_indices_tensor, + seq_len_tensor) + return bgmv_shrink(inputs, lora_a_weights, output_tensor, exploded_indices, + scaling) + + +def bgmv_shrink(inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + scaling: float = 1.0): + selected_loras = lora_b_weights[lora_indices_tensor].astype(output_tensor.dtype) + inputs = inputs.astype(output_tensor.dtype) + if len(selected_loras.shape) == 4: + selected_loras = selected_loras.squeeze(1) + outputs = einsum_ms(inputs, selected_loras) + out_type = output_tensor.dtype + outputs = outputs.astype(ms.float32) + output_tensor = output_tensor.astype(ms.float32) + output_tensor[:, :outputs.shape[1]] = scaling * outputs[:] + return output_tensor.astype(out_type) + + +def sgmv_expand_slice(inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + token_nums: int, + slice_offset: int, + slice_size: int, + add_inputs: bool = False): + exploded_indices = mint.repeat_interleave(lora_indices_tensor, + seq_len_tensor) + + return bgmv_expand_slice(inputs, lora_b_weights, output_tensor, exploded_indices, + slice_offset, slice_size, add_inputs) + + +def bgmv_expand_slice(inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + slice_offset: int, + slice_size: int, 
+ add_inputs: bool = True): + selected_loras = lora_b_weights[lora_indices_tensor].astype(output_tensor.dtype) + inputs = inputs.astype(output_tensor.dtype) + if len(selected_loras.shape) == 4: + selected_loras = selected_loras.squeeze(1) + outputs = einsum_ms(inputs, selected_loras) + out_type = output_tensor.dtype + outputs = outputs.astype(ms.float32) + output_tensor = output_tensor.astype(ms.float32) + if add_inputs: + output_tensor[:, slice_offset:slice_offset + slice_size] += outputs[:] + else: + output_tensor[:, slice_offset:slice_offset + slice_size] = outputs[:] + return output_tensor.astype(out_type) \ No newline at end of file diff --git a/vllm_mindspore/lora/punica_wrapper/__init__.py b/vllm_mindspore/lora/punica_wrapper/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/vllm_mindspore/lora/punica_wrapper/punica_base.py b/vllm_mindspore/lora/punica_wrapper/punica_base.py new file mode 100644 index 0000000000000000000000000000000000000000..6c42834f2afadba04fb762c76b762bf1e35f0753 --- /dev/null +++ b/vllm_mindspore/lora/punica_wrapper/punica_base.py @@ -0,0 +1,500 @@ +#!/usr/bin/env python3 +# encoding: utf-8 +# Copyright 2025 Huawei Technologies Co., Ltd +# Copyright 2024 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +""" +Based on: +Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). +Punica: Multi-Tenant LoRA Serving. +https://arxiv.org/abs/2310.18547 +""" + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, List, Optional, Tuple, Union + +import torch + +from vllm.lora.punica_wrapper.utils import compute_meta, convert_mapping + +if TYPE_CHECKING: + # avoid circuit import + from vllm.lora.layers import LoRAMapping + from vllm.lora.models import LongContextLoRAContext + + +class PunicaWrapperABC(ABC): + """ + PunicaWrapper ABC. + """ + + @abstractmethod + def update_metadata( + self, + mapping: "LoRAMapping", + lora_index_to_id: List[Optional[int]], + max_loras: int, + vocab_size: int, + extra_vocab_size: int, + long_lora_context: Optional["LongContextLoRAContext"] = None, + **kwargs, + ) -> None: + """ + Update the lora-related metadata + """ + raise NotImplementedError + + @abstractmethod + def add_shrink( + self, + y: Union[Tuple[torch.Tensor, ...], torch.Tensor], + x: torch.Tensor, + lora_a_stacked: Tuple[torch.Tensor, ...], + scale: float, + **kwargs, + ) -> None: + """ + Performs GEMM for multiple slices of lora_a. + """ + + raise NotImplementedError + + @abstractmethod + def add_expand( + self, + y: torch.Tensor, + x: Union[Tuple[torch.Tensor, ...], torch.Tensor], + lora_b_stacked: Tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + output_slices: Tuple[int, ...], + offset_start: int = 0, + add_inputs=True, + **kwargs, + ) -> None: + """ + Performs GEMM and bias addition for multiple slices of lora_b. 
+ """ + raise NotImplementedError + + @abstractmethod + def add_lora_embedding( + self, + y: torch.Tensor, + x: torch.Tensor, + lora_b_stacked: torch.Tensor, + add_inputs: bool = True, + **kwargs, + ) -> None: + """ + Applies lora specifically for VocabParallelEmbeddingWithLoRA, + and this layer only requires the expand operation. + """ + raise NotImplementedError + + @abstractmethod + def add_lora_linear(self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: Tuple[torch.Tensor, ...], + lora_b_stacked: Tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + scale: float, + output_slices: Tuple[int, ...], + *, + buffer: Optional[Tuple[torch.Tensor, ...]] = None, + **kwargs) -> None: + """ + Applicable to linear-related lora. + """ + + raise NotImplementedError + + @abstractmethod + def add_lora_logits(self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: torch.Tensor, + lora_b_stacked: torch.Tensor, + scale, + *, + buffer: Optional[torch.Tensor] = None, + **kwargs) -> None: + """ + Applies lora specifically for LogitsProcessorWithLoRA. + """ + raise NotImplementedError + + +class PunicaWrapperBase(PunicaWrapperABC): + """ + PunicaWrapperBase is designed to manage and provide metadata for the punica + kernel. The main function is to maintain the state information for + Multi-LoRA, and to provide the interface for the punica. + """ + + def __init__(self, max_num_batched_tokens: int, max_batches: int, + device: Union[torch.device, str], **kwargs): + self._token_lora_indices = torch.empty(max_num_batched_tokens, + dtype=torch.long, + device=device) + self._sampler_indices = torch.empty(max_num_batched_tokens, + dtype=torch.long, + device=device) + self._sampler_indices_padded = torch.empty(max_num_batched_tokens, + dtype=torch.long, + device=device) + self._embeddings_indices = torch.empty(2, + max_num_batched_tokens, + dtype=torch.long, + device=device) + self._long_lora_indices = torch.empty(max_num_batched_tokens, + dtype=torch.long, + device=device) + + # 5 is the number of indices tensors. + # base_indices, sampler_indices, sampler_indices_padded, + # embeddings_indices,long_lora_indices + self.indices_len: List[Optional[int]] = [None] * 5 + # these attributes are the information required for sgmv kernel + self._seq_start_locs = torch.empty(max_batches, + dtype=torch.long, + device=device) + self._seq_lengths = torch.empty(max_batches, + dtype=torch.long, + device=device) + self._lora_indices_per_batch = torch.empty(max_batches, + dtype=torch.long, + device=device) + self.device: torch.device = device + self.max_length: int = 0 + self.token_nums: int = 0 + self.batch_size: int = -1 + self.is_prefill = False + self.no_lora = False + + def _update_base_metadata( + self, + mapping: "LoRAMapping", + lora_index_to_id: List[Optional[int]], + max_loras: int, + vocab_size: int, + extra_vocab_size: int, + long_lora_context: Optional["LongContextLoRAContext"] = None, + ): + ( + base_indices, + sampler_indices, + sampler_indices_padded, + embeddings_indices, + long_lora_offsets_tensor, + indices_len, + ) = convert_mapping( + mapping, + lora_index_to_id, + max_loras, + vocab_size, + extra_vocab_size, + self.device, + long_lora_context, + ) + self._token_lora_indices[:base_indices.shape[0]].copy_(base_indices) + self._sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices) + self._sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_( + sampler_indices_padded) + self._embeddings_indices[:embeddings_indices. 
+ shape[0], :embeddings_indices.shape[1]].copy_( + embeddings_indices) + if long_lora_offsets_tensor is not None: + self._long_lora_indices[:long_lora_offsets_tensor.shape[0]].copy_( + long_lora_offsets_tensor) + else: + self._long_lora_indices.zero_() + self.indices_len[:] = indices_len + + def _update_prefill_metada(self, token_lora_tensor: torch.Tensor) -> None: + + (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor, + batch_size, max_length, token_nums, + no_lora) = compute_meta(token_lora_tensor) + + self._seq_start_locs[:b_seq_start_tensor.shape[0]].copy_( + b_seq_start_tensor) + self._seq_lengths[:seq_length_tensor.shape[0]].copy_(seq_length_tensor) + self._lora_indices_per_batch[:lora_indices_tensor.shape[0]].copy_( + lora_indices_tensor) + self.batch_size = batch_size + self.max_length = max_length + self.token_nums = token_nums + self.no_lora = no_lora + + def _apply_bias( + self, + indices: torch.Tensor, + output: torch.Tensor, + output_slices: Tuple[int, ...], + lora_bias_stacked: Tuple[Optional[torch.Tensor], ...], + ): + """Applies bias to output + + Input shapes: + lora_bias_stacked: 3 element tuple of (num_loras, output_dim) + indices: (batch_size) + output: (batch_size, q_slice_size + 2*kv_slice_size) + output_slices: n-1 element tuple of (slice_size...), + where n is number of slices + """ + org_output = output + output = output.view(-1, output.shape[-1]) + indices = indices.view(-1) + + offset_left = 0 + for slice_idx, slice in enumerate(output_slices): + bias = lora_bias_stacked[slice_idx] + if bias is not None: + bias = bias.view(-1, bias.shape[-1]) + bias = bias[indices] + bias[indices == -1] = 0 + output[:, offset_left:offset_left + slice] += bias + offset_left += slice + + return output.view_as(org_output) + + @property + def prefill_metadata( + self + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int]: + """ + This property provides a convenient way to access the necessary + metadata for prefill-related kernel computations. + 1. seq_start_locs: Tensor of sequence start positions. + 2. seq_lengths: Tensor of sequence lengths. + 3. lora_indices_per_batch: Tensor of lora indices, and an index of + -1 means no lora should be applied. + 4. batch_size: Batch size after clustering identical lora indices. + 5. max_length: The maximum sequence length in the batch. + 6. token_nums: The token numbers in the batch. + """ + return (self._seq_start_locs[:self.batch_size], + self._seq_lengths[:self.batch_size], + self._lora_indices_per_batch[:self.batch_size], + self.batch_size, self.max_length, self.token_nums) + + @property + def token_lora_indices(self) -> torch.Tensor: + """ + This property provides the lora indices corresponding to each token + in the batch. An index of -1 means no lora should be applied. + """ + token_lora_len = self.indices_len[0] + return self._token_lora_indices[:token_lora_len] + + @property + def sampler_indices(self) -> torch.Tensor: + """ + This property is used to access the lora indices specifically for + LogitsProcessorWithLoRA. + """ + sampler_indices_len = self.indices_len[1] + return self._sampler_indices[:sampler_indices_len] + + @property + def sampler_indices_padded(self) -> torch.Tensor: + """ + This property provides access to padded sampler indices. 
+ """ + indices_padded_len = self.indices_len[2] + return self._sampler_indices_padded[:indices_padded_len] + + @property + def embeddings_indices(self) -> torch.Tensor: + """ + This property provides access to the indices used for lora embeddings, + specifically for VocabParallelEmbeddingWithLoRA. + """ + embeddings_indices_len = self.indices_len[3] + return self._embeddings_indices[:, :embeddings_indices_len] + + @property + def long_lora_indices(self) -> torch.Tensor: + """ + This property provides access to the indices used for long context + lora, specifically for LinearScalingRotaryEmbeddingWithLora. + """ + long_lora_len = self.indices_len[4] + return self._long_lora_indices[:long_lora_len] + + def update_metadata( + self, + mapping: "LoRAMapping", + lora_index_to_id: List[Optional[int]], + max_loras: int, + vocab_size: int, + extra_vocab_size: int, + long_lora_context: Optional["LongContextLoRAContext"] = None, + **kwargs): + + self._update_base_metadata(mapping, lora_index_to_id, max_loras, + vocab_size, extra_vocab_size, + long_lora_context) + if mapping.is_prefill: + # Update metadata required for prefill-related operators. + self._update_prefill_metada(self.token_lora_indices) + self.is_prefill = True + else: + self.is_prefill = False + + @abstractmethod + def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], + x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...], + scale: float, **kwargs) -> None: + """ + Performs GEMM for multiple slices of lora_a. + + Semantics: + for i in range(len(lora_a_stacked)): + y[i] += (x @ lora_a_stacked[i]) * scale + + Args: + y (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Output tensors + x (torch.Tensor): Input tensor + lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights + scale (float): Scaling factor for the operation + + """ + # TODO: implement it based on torch ops + raise NotImplementedError + + @abstractmethod + def add_expand(self, + y: torch.Tensor, + x: Union[Tuple[torch.Tensor, ...], torch.Tensor], + lora_b_stacked: Tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + output_slices: Tuple[int, ...], + offset_start: int = 0, + add_inputs=True, + **kwargs) -> None: + """ + Performs GEMM and bias addition for multiple slices of lora_b. + + Semantics: + offset = offset_start + for i in range(len(lora_b_stacked)): + slice = output_slices[i] + y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] + + lora_bias_stacked[i] + offset += slice + + Args: + y (torch.Tensor): Output tensor. + x (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Input tensors + lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight + lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): + bias's weight + output_slices (Tuple[int, ...]): Every slice's size + offset_start (int): The starting position of y, defaults to 0 + add_inputs (bool): Defaults to True. + + """ + # TODO: implement it based on torch ops + raise NotImplementedError + + @abstractmethod + def add_lora_embedding(self, + y: torch.Tensor, + x: torch.Tensor, + lora_b_stacked: torch.Tensor, + add_inputs: bool = True, + **kwargs) -> None: + """ + Applies lora specifically for VocabParallelEmbeddingWithLoRA. + and this layer only requires the expand operation. + Semantics: + y += x @ lora_b_stacked + + Args: + y (torch.Tensor): Output tensor. + x (torch.Tensor): Input tensor. + lora_b_stacked (torch.Tensor): lora_b's weights. + add_inputs (bool): Default to True. 
+ """ + # TODO: implement it based on torch ops + raise NotImplementedError + + @abstractmethod + def add_lora_linear(self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: Tuple[torch.Tensor, ...], + lora_b_stacked: Tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + scale: float, + output_slices: Tuple[int, ...], + *, + buffer: Optional[Tuple[torch.Tensor, ...]] = None, + **kwargs) -> None: + """ + Applicable to linear-related lora. + + Semantics: + for i in range(len(lora_a_stacked)): + y[i] += ( + x[i].unsqueeze(0) + @ lora_a_stacked[indices[i], layer_idx, :, :] + @ lora_b_stacked[indices[i], layer_idx, :, :] + * scale + ).squeeze(0)+lora_bias_stacked[i] + + Args: + y (torch.Tensor): Output tensor. Will be changed in-place. + x (torch.Tensor): Input tensor + lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weight. + lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight. + lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora's bias. + scale (float): Scaling factor. + output_slices (Tuple[int, ...]): Every slice's size. + buffer (Optional[Tuple[torch.Tensor, ...]]): Defaults to None. + """ + # TODO: implement it based on torch ops + raise NotImplementedError + + @abstractmethod + def add_lora_logits(self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: torch.Tensor, + lora_b_stacked: torch.Tensor, + scale, + *, + buffer: Optional[torch.Tensor] = None, + **kwargs) -> None: + """ + Applies lora specifically for LogitsProcessorWithLoRA. + + Semantics: + buffer = (x @ lora_a_stacked) * scale + y += buffer @ lora_b_stacked + + Args: + y (torch.Tensor): Output tensor. + x (torch.Tensor): Input tensor. + lora_a_stacked (torch.Tensor): lora_a's weights. + lora_b_stacked (torch.Tensor):lora_b's weights. + scale (float): Scaling factor. + buffer (Optional[torch.Tensor]):Default to None. + """ + # TODO: implement it based on torch ops + raise NotImplementedError diff --git a/vllm_mindspore/lora/punica_wrapper/punica_npu.py b/vllm_mindspore/lora/punica_wrapper/punica_npu.py new file mode 100644 index 0000000000000000000000000000000000000000..ab53920c0138e1f1ad23ecd3065b6e006d586bb1 --- /dev/null +++ b/vllm_mindspore/lora/punica_wrapper/punica_npu.py @@ -0,0 +1,367 @@ +#!/usr/bin/env python3 +# encoding: utf-8 +# Copyright 2025 Huawei Technologies Co., Ltd +# Copyright 2024 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +from typing import Callable, Optional, Tuple, Union + +import torch +from vllm_mindspore.lora.ops.torch_ops import (bgmv_expand, bgmv_expand_slice, + bgmv_shrink, sgmv_expand, + sgmv_expand_slice, sgmv_shrink) +from vllm_mindspore.lora.punica_wrapper.punica_base import PunicaWrapperBase +from mindspore import mint + +# The platforms that are compatible with the PyTorch-native implementation can +# inherit this class +class PunicaWrapperNPU(PunicaWrapperBase): + """ + PunicaWrapperNPU is designed to manage and provide metadata for the punica + kernel. The main function is to maintain the state information for + Multi-LoRA, and to provide the interface for the pytorch punica ops. + """ + + def __init__(self, max_num_batched_tokens: int, max_batches: int, + device: Union[torch.device, str], **kwargs): + PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, + device) + + def _shrink_prefill( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + scale: float, + ): + return sgmv_shrink( + x, + w_t_all, + y, + *self.prefill_metadata, + scale, + ) + + def _shrink_decode( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + scale: float, + ): + return bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale) + + def _expand_prefill( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + add_inputs: bool, + ): + return sgmv_expand( + x, + w_t_all, + y, + *self.prefill_metadata, + add_inputs, + ) + + def _expand_decode( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + add_inputs: bool, + ): + return bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_inputs) + + def _expand_slice_prefill( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + y_offset: int, + y_slice_size: int, + add_inputs: bool, + ): + return sgmv_expand_slice( + x, + w_t_all, + y, + *self.prefill_metadata, + y_offset, + y_slice_size, + add_inputs, + ) + + def _expand_slice_decode( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + y_offset: int, + y_slice_size: int, + add_inputs: bool, + ): + return bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset, + y_slice_size, add_inputs) + + def _apply_expand( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + y_offset: int, + y_slice_size: int, + add_inputs: bool = True, + ): + """ + Perform the ` y[:,y_offset:y_offset+y_slice_size]+=x@w_t_all` + computation, which is suitable for the + GEMM of lora'b. + """ + + expand_slice_fun: Callable = (self._expand_slice_prefill + if self.is_prefill else + self._expand_slice_decode) + y = expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_inputs) + return y + + def _apply_shrink(self, y: torch.Tensor, x: torch.Tensor, + w_t_all: torch.Tensor, scale: float): + """ + Perform the ` y+=x@w_t_all` computation, which is suitable for the + GEMM of lora'a. + When `is_prefill is` true, it indicates that it is currently the + prefill stage, and the `_shrink_prefill` function should be called. + Otherwise, it is the decode stage, and the _shrink_decode function + should be called. 
+ """ + y_org = y + y = y.view(-1, y.shape[-1]) + shrink_fun: Callable = (self._shrink_prefill + if self.is_prefill else self._shrink_decode) + y = shrink_fun(y, x, w_t_all, scale) + y = y.view_as(y_org) + return y + + def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], + x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...], + scale: float, **kwargs): + """ + Performs GEMM for multiple slices of lora_a. + When `is_prefill is` true, it indicates that it is currently the + prefill stage, and the `_shrink_prefill` function should be called. + Otherwise, it is the decode stage, and the _shrink_decode function + should be called. + + Semantics: + for i in range(len(lora_a_stacked)): + y[i] += (x @ lora_a_stacked[i]) * scale + + Args: + y (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Output tensors + x (torch.Tensor): Input tensor + lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights + scale (float): Scaling factor for the operation + """ + + x = x.view(-1, x.shape[-1]) + # TODO fuse these kernels + for slice_idx in range(len(lora_a_stacked)): + y_list = list(y) + y_list[slice_idx] = self._apply_shrink(y[slice_idx], x, lora_a_stacked[slice_idx], scale) + y = tuple(y_list) + return y + + def add_expand(self, + y: torch.Tensor, + x: Union[Tuple[torch.Tensor, ...], torch.Tensor], + lora_b_stacked: Tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + output_slices: Tuple[int, ...], + offset_start: int = 0, + add_inputs=True, + **kwargs) -> None: + """ + Performs GEMM and bias addition for multiple slices of lora_b. + + Semantics: + for i in range(len(lora_b_stacked)): + slice = output_slices[i] + y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] + + lora_bias_stacked[i] + offset += slice + + Args: + y (torch.Tensor): Output tensor. + x (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Input tensors + lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight + lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): + bias's weight + output_slices (Tuple[int, ...]): Every slice's size + add_inputs (bool): Defaults to True. + """ + y_org = y + y = y.view(-1, y.shape[-1]) + offset_left = offset_start + if lora_bias_stacked is not None: + self._apply_bias(self.token_lora_indices, y, output_slices, + lora_bias_stacked) + for slice_idx in range(len(lora_b_stacked)): + y = self._apply_expand( + y, + x[slice_idx], + lora_b_stacked[slice_idx], + offset_left, + output_slices[slice_idx], + add_inputs=add_inputs, + ) + offset_left += output_slices[slice_idx] + y = y.view_as(y_org) + return y + + def add_lora_embedding(self, + y: torch.Tensor, + x: torch.Tensor, + lora_b_stacked: torch.Tensor, + add_inputs: bool = True, + **kwargs) -> None: + """ + Applies lora specifically for VocabParallelEmbeddingWithLoRA. + + Semantics: + y += x @ lora_b_stacked + + Args: + y (torch.Tensor): Output tensor. + x (torch.Tensor): Input tensor. + lora_b_stacked (torch.Tensor): lora_b's weights. + add_inputs (bool): Default to True. 
+ """ + #No LoRA request, so return directly + if self.no_lora: + return y + # Embedding layer only need expand op + expand_fun: Callable = (self._expand_prefill + if self.is_prefill else self._expand_decode) + y = expand_fun(y, x, lora_b_stacked, add_inputs) + return y + + def add_lora_linear(self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: Tuple[torch.Tensor, ...], + lora_b_stacked: Tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + scale: float, + output_slices: Tuple[int, ...], + *, + buffer: Optional[Tuple[torch.Tensor, ...]] = None, + **kwargs) -> None: + """ + Applicable to linear-related lora. + + Semantics: + for i in range(len(lora_a_stacked)): + y[i] += ( + x[i].unsqueeze(0) + @ lora_a_stacked[indices[i], layer_idx, :, :] + @ lora_b_stacked[indices[i], layer_idx, :, :] + * scale + ).squeeze(0)+lora_bias_stacked[i] + + Args: + y (torch.Tensor): Output tensor. Will be changed in-place. + x (torch.Tensor): Input tensor + lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weight. + lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight. + lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora's bias. + scale (float): Scaling factor. + output_slices (Tuple[int, ...]): Every slice's size. + buffer (Optional[Tuple[torch.Tensor, ...]]): Defaults to None. + """ + #No LoRA request, so return directly + if self.no_lora: + return y + x = x.reshape(-1, x.shape[-1]) + assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices) + if lora_bias_stacked is not None: + assert len(lora_bias_stacked) == len(output_slices) + y = self._apply_bias(self.token_lora_indices, y, output_slices, + lora_bias_stacked) + + if buffer is None: + r = lora_b_stacked[0].shape[-1] + # We set the buffer to be float32 by default, consistent with the + # triton op + buffer = tuple( + mint.zeros( + (x.shape[0], r), dtype=torch.float32) + for _ in range(len(output_slices))) + buffer = self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs) + y = self.add_expand(y, + buffer, + lora_b_stacked, + None, + output_slices, + add_inputs=True, + **kwargs) + return y + + def add_lora_logits(self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: torch.Tensor, + lora_b_stacked: torch.Tensor, + scale, + *, + buffer: Optional[torch.Tensor] = None, + **kwargs) -> None: + """ + Applies lora specifically for LogitsProcessorWithLoRA. + + Semantics: + buffer = (x @ lora_a_stacked) * scale + y += buffer @ lora_b_stacked + + Args: + y (torch.Tensor): Output tensor. + x (torch.Tensor): Input tensor. + lora_a_stacked (torch.Tensor): lora_a's weights. + lora_b_stacked (torch.Tensor):lora_b's weights. + scale (float): Scaling factor. + buffer (Optional[torch.Tensor]):Default to None. + """ + #No LoRA request, so return directly + if self.no_lora: + return y + y_org = y + y = y.view(-1, y.shape[-1]) + x = x.view(-1, x.shape[-1]) + r = lora_b_stacked.shape[-1] + if buffer is None: + # We set the buffer to be float32 by default, consistent with the + # triton op + buffer = mint.zeros((x.shape[0], r), + dtype=torch.float32) + # LogitsProcessorWithLoRA always using bgmv. 
+ buffer = bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices, scale) + y = bgmv_expand(buffer, + lora_b_stacked, + y, + self.sampler_indices, + add_inputs=True) + return y.view_as(y_org) \ No newline at end of file diff --git a/vllm_mindspore/lora/utils.py b/vllm_mindspore/lora/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ed399825e5d6707a45abb24f50cee710bcb8f0e1 --- /dev/null +++ b/vllm_mindspore/lora/utils.py @@ -0,0 +1,76 @@ +#!/usr/bin/env python3 +# encoding: utf-8 +# Copyright 2025 Huawei Technologies Co., Ltd +# Copyright 2024 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import os +import re +from typing import List, Optional, Set, Tuple, Type, Union + +from torch import nn +from transformers import PretrainedConfig + +from vllm.config import LoRAConfig +from vllm_mindspore.lora.fully_sharded_layers import ( + ColumnParallelLinearWithShardedLoRA, + MergedColumnParallelLinearWithShardedLoRA, + MergedQKVParallelLinearWithShardedLora, QKVParallelLinearWithShardedLora, + RowParallelLinearWithShardedLoRA) + +# TODO ReplicatedLinearWithLoRA +from vllm_mindspore.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA, + LinearScalingRotaryEmbeddingWithLora, + LogitsProcessorWithLoRA, + MergedColumnParallelLinearWithLoRA, + MergedQKVParallelLinearWithLora, + QKVParallelLinearWithLora, + RowParallelLinearWithLoRA, + VocabParallelEmbeddingWithLoRA) + + +# TODO ReplicatedLinearWithLoRA +_all_lora_classes: Set[Type[BaseLayerWithLoRA]] = { + VocabParallelEmbeddingWithLoRA, + ColumnParallelLinearWithLoRA, + MergedColumnParallelLinearWithLoRA, + QKVParallelLinearWithLora, + MergedQKVParallelLinearWithLora, + RowParallelLinearWithLoRA, + LogitsProcessorWithLoRA, + ColumnParallelLinearWithShardedLoRA, + QKVParallelLinearWithShardedLora, + MergedColumnParallelLinearWithShardedLoRA, + MergedQKVParallelLinearWithShardedLora, + RowParallelLinearWithShardedLoRA, + LinearScalingRotaryEmbeddingWithLora, +} + + +def from_layer(layer: nn.Module, + max_loras: int, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig] = None) -> nn.Module: + for lora_cls in _all_lora_classes: + # specifying kwargs so they can be easily accessed in decorator + if lora_cls.can_replace_layer(source_layer=layer, + lora_config=lora_config, + packed_modules_list=packed_modules_list, + model_config=model_config): + ret = lora_cls(layer) + ret.create_lora_weights(max_loras, lora_config, model_config) + return ret + return layer \ No newline at end of file diff --git a/vllm_mindspore/model_executor/layers/linear.py b/vllm_mindspore/model_executor/layers/linear.py index 45aa4c4399a47e4912aa4fda35b19cdfaaae38d5..ac1fa93c3f627f3dd77e115c1c98570b47c4dd40 100644 --- a/vllm_mindspore/model_executor/layers/linear.py +++ b/vllm_mindspore/model_executor/layers/linear.py @@ -21,7 +21,7 @@ from abc import abstractmethod import numpy as np import 
mindspore as ms -from mindspore import mint, ops, Tensor +from mindspore import mint, ops, Tensor, jit from mindspore import Parameter from vllm.distributed import ( diff --git a/vllm_mindspore/model_executor/models/model_base.py b/vllm_mindspore/model_executor/models/model_base.py index b97d71526a408102248eb8e79321d86fa7a19261..b42c4791bdb52a6c5686afbfeeb079fcb58cda3d 100644 --- a/vllm_mindspore/model_executor/models/model_base.py +++ b/vllm_mindspore/model_executor/models/model_base.py @@ -142,8 +142,17 @@ class MsModelBase(): module_name = name + "." + module_name yield module_name, sub_module - def get_submodule(self): - raise RuntimeError("Cannot get submodule for mindspore model now!") + def get_submodule(self, target: str): + parts = target.split(".") + if target == "": + return self + for part in parts: + if not part: + raise ValueError(f"Invalid submodule path: empty part in '{target}'") + current = self + for part in parts: + current = getattr(current, part) + return current def eval(self): self._check_modules_valid() diff --git a/vllm_mindspore/model_executor/models/qwen2.py b/vllm_mindspore/model_executor/models/qwen2.py index 32d9da8d91b1a3cf2e8a7f6f51d51e152793bb09..17def30467fe93578ea3ce0334d28c5b1286230b 100644 --- a/vllm_mindspore/model_executor/models/qwen2.py +++ b/vllm_mindspore/model_executor/models/qwen2.py @@ -58,6 +58,7 @@ from vllm.sequence import IntermediateTensors from vllm.attention.backends.abstract import AttentionType from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.attention.backends.abstract import AttentionMetadata +from vllm.model_executor.models.interfaces import SupportsLoRA class Qwen2MLP(nn.Cell): @@ -92,7 +93,6 @@ class Qwen2MLP(nn.Cell): "Only silu is supported for now.") self.act_fn = SwiGLU() - @jit def construct(self, x): x, _ = self.gate_up_proj(x) x = self.act_fn(x) @@ -176,7 +176,6 @@ class Qwen2Attention(nn.Cell): self.attn_mask = mint.triu(mint.ones(size=(128, 128), dtype=mstype.bfloat16), 1) self.hard_mask = Tensor([0], dtype=mstype.bfloat16).reshape(1, 1) - @jit def construct( self, positions: Tensor, @@ -248,7 +247,6 @@ class Qwen2DecoderLayer(nn.Cell): eps=config.rms_norm_eps, params_dtype=mstype.bfloat16,) - @jit def construct( self, positions: Tensor, @@ -333,7 +331,6 @@ class Qwen2Model(nn.Cell): def get_input_embeddings(self, input_ids: Tensor) -> Tensor: return self.embed_tokens(input_ids) - @jit def construct( self, input_ids: Optional[Tensor], @@ -431,7 +428,7 @@ class Qwen2Model(nn.Cell): return loaded_params -class Qwen2ForCausalLM(MsModelBase): +class Qwen2ForCausalLM(MsModelBase, SupportsLoRA): packed_modules_mapping = { "qkv_proj": [ "q_proj", diff --git a/vllm_mindspore/platforms/ascend.py b/vllm_mindspore/platforms/ascend.py index b96403d4959fe14be3fcf3dc2f43d2b9bcc564e1..7d492445b63bdd198e7fb950e2555288e3123c1e 100644 --- a/vllm_mindspore/platforms/ascend.py +++ b/vllm_mindspore/platforms/ascend.py @@ -127,3 +127,7 @@ class AscendPlatform(Platform): """Get the total memory of a device in bytes.""" device_props = torch.cuda.get_device_properties(device_id) return device_props.total_memory + + @classmethod + def get_punica_wrapper(cls) -> str: + return "vllm_mindspore.lora.punica_wrapper.punica_npu.PunicaWrapperNPU" \ No newline at end of file diff --git a/vllm_mindspore/worker/model_runner.py b/vllm_mindspore/worker/model_runner.py index 561fd2021dd7c84764a04aaa1b3b06389f720b55..f5b0134743e91a4d120276bf01d6ac472944b0de 100644 --- a/vllm_mindspore/worker/model_runner.py +++ 
b/vllm_mindspore/worker/model_runner.py @@ -47,31 +47,8 @@ def _dummy_run(self, # Enable top-k sampling to reflect the accurate memory usage. sampling_params = \ SamplingParams(top_p=0.99, top_k=self.vocab_size - 1) - - # This represents the maximum number of different requests - # that will have unique loras, an therefore the max amount of memory - # consumption create dummy lora request copies from the lora request - # passed in, which contains a lora from the lora warmup path. - dummy_lora_requests: List[LoRARequest] = [] + # TODO _dummy_run support lora dummy_lora_requests_per_seq: List[LoRARequest] = [] - if self.lora_config: - assert self.lora_manager is not None - with self.lora_manager.dummy_lora_cache(): - for idx in range(self.lora_config.max_loras): - lora_id = idx + 1 - dummy_lora_request = LoRARequest( - lora_name=f"warmup_{lora_id}", - lora_int_id=lora_id, - lora_path="/not/a/real/path", - ) - self.lora_manager.add_dummy_lora(dummy_lora_request, - rank=LORA_WARMUP_RANK) - dummy_lora_requests.append(dummy_lora_request) - dummy_lora_requests_per_seq = [ - dummy_lora_requests[idx % len(dummy_lora_requests)] - for idx in range(max_num_seqs) - ] - # Profile memory usage with max_num_sequences sequences and the # total number of tokens equal to max_num_batched_tokens. seqs: List[SequenceGroupMetadata] = [] @@ -161,10 +138,6 @@ def _dummy_run(self, self.execute_model(model_input, kv_caches, intermediate_tensors) torch.cuda.synchronize() - if self.lora_config: - # Remove dummy loras. - assert self.lora_manager is not None - self.remove_all_loras() return
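
Note on vllm_mindspore/lora/ops/torch_ops/lora_ops.py: the bgmv_* functions replace the Triton punica kernels with plain tensor ops. Each token gathers its LoRA matrix by index, and einsum_ms expresses the einsum "bi, boi -> bo" as a transpose plus a batched matmul. A minimal PyTorch sketch of that gather-then-matmul pattern; the names and shapes below are illustrative, not part of the patch:

    import torch

    def bgmv_expand_reference(inputs, lora_b_stacked, output, lora_indices, add_inputs=True):
        # inputs: (num_tokens, rank); lora_b_stacked: (num_loras, out_dim, rank)
        # lora_indices: (num_tokens,) selecting one LoRA slot per token
        selected = lora_b_stacked[lora_indices]                # (num_tokens, out_dim, rank)
        # "bi, boi -> bo" as a batched matmul: (t, 1, r) @ (t, r, o) -> (t, 1, o)
        delta = torch.matmul(inputs.unsqueeze(1), selected.transpose(1, 2)).squeeze(1)
        if add_inputs:
            output[:, :delta.shape[1]] += delta
        else:
            output[:, :delta.shape[1]] = delta
        return output

    # toy shapes
    x = torch.randn(4, 8)          # 4 tokens, rank 8
    B = torch.randn(2, 16, 8)      # 2 LoRA slots, out_dim 16
    idx = torch.tensor([0, 1, 1, 0])
    y = torch.zeros(4, 16)
    bgmv_expand_reference(x, B, y, idx)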
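
The sgmv_* entry points in the same file do not implement a separate grouped kernel; they expand the per-request lora_indices_tensor into per-token indices with repeat_interleave and delegate to the bgmv path. A small example of that index expansion (the values are made up):

    import torch

    # two prefill requests: the first uses LoRA slot 1 for 3 tokens,
    # the second has no LoRA (-1) for 2 tokens
    lora_indices_tensor = torch.tensor([1, -1])
    seq_len_tensor = torch.tensor([3, 2])
    per_token = torch.repeat_interleave(lora_indices_tensor, seq_len_tensor)
    # per_token == tensor([ 1,  1,  1, -1, -1])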
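
In PunicaWrapperNPU.add_lora_linear the two ops compose: a shrink (x @ lora_a, scaled, written into a float32 buffer) followed by an expand (buffer @ lora_b added into the matching output slice). Ignoring the per-token index gather and the stacked, transposed storage layout, the per-slice semantics reduce to the reference below, which is a sketch rather than the patch's implementation:

    import torch

    def lora_linear_reference(y, x, lora_a, lora_b, scale, slice_offset, slice_size):
        # y: (num_tokens, out_dim_total), x: (num_tokens, in_dim)
        # lora_a: (in_dim, rank), lora_b: (rank, slice_size)
        buffer = (x @ lora_a) * scale                                     # shrink
        y[:, slice_offset:slice_offset + slice_size] += buffer @ lora_b   # expand into the slice
        return y

    y = torch.zeros(2, 6)
    x = torch.randn(2, 4)
    A = torch.randn(4, 3)    # rank 3
    B = torch.randn(3, 6)
    lora_linear_reference(y, x, A, B, scale=0.5, slice_offset=0, slice_size=6)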
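
The MsModelBase.get_submodule change lets the LoRA manager locate layers by dotted path, mirroring torch.nn.Module.get_submodule: an empty path returns the model itself, otherwise each dotted component is resolved with getattr. A self-contained stand-in showing that traversal; the object tree and attribute names here are illustrative only:

    from types import SimpleNamespace

    model = SimpleNamespace(embed_tokens="embedding", lm_head="head")

    def get_submodule(root, target: str):
        # empty path -> the root object; otherwise walk each dotted component
        if target == "":
            return root
        current = root
        for part in target.split("."):
            if not part:
                raise ValueError(f"Invalid submodule path: empty part in '{target}'")
            current = getattr(current, part)
        return current

    assert get_submodule(model, "embed_tokens") == "embedding"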
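
AscendPlatform.get_punica_wrapper returns a dotted class path rather than a class object, and vLLM is expected to import it dynamically. The resolver below is an assumption about that contract, included only to illustrate what the returned string must point at (running it requires vllm_mindspore to be importable):

    import importlib

    def resolve_punica_wrapper(qualname: str):
        # "pkg.module.ClassName" -> the class object
        module_name, _, class_name = qualname.rpartition(".")
        return getattr(importlib.import_module(module_name), class_name)

    wrapper_cls = resolve_punica_wrapper(
        "vllm_mindspore.lora.punica_wrapper.punica_npu.PunicaWrapperNPU")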