diff --git a/vllm_mindspore/__init__.py b/vllm_mindspore/__init__.py index 2b711153a2c9db12690b254af6700404ac8d1631..b7b1a4402c2547851fb1fa54ff5f8f137a2c99fd 100644 --- a/vllm_mindspore/__init__.py +++ b/vllm_mindspore/__init__.py @@ -100,10 +100,11 @@ vllm.utils.memory_profiling = ms_memory_profiling import vllm.lora.utils from vllm_mindspore.model_executor.layers.linear import LinearBase -from vllm_mindspore.lora.utils import _all_lora_classes +from vllm_mindspore.lora.utils import _all_lora_classes, replace_submodule vllm.lora.utils._all_lora_classes = _all_lora_classes vllm.lora.utils.LinearBase = LinearBase +vllm.lora.utils.replace_submodule = replace_submodule import vllm.lora.models from vllm_mindspore.lora.models import ( diff --git a/vllm_mindspore/lora/layers.py b/vllm_mindspore/lora/layers.py index a7b53709cbd27b8fb0d6c088f98b8cd910878a80..356e1950aa6ee0fc5d8008ae7abe354a159533fc 100644 --- a/vllm_mindspore/lora/layers.py +++ b/vllm_mindspore/lora/layers.py @@ -24,6 +24,8 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Optional, Union, cast import mindspore as ms +from mindspore import Parameter, ops, mint +from mindspore.common.initializer import initializer import torch import torch.nn as nn import torch.nn.functional as F @@ -326,7 +328,7 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA): self.output_size, self.tp_size)) else: raise NotImplementedError - + ''' self.lora_a_stacked = tuple( torch.zeros( max_loras, @@ -345,16 +347,19 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA): dtype=lora_config.lora_dtype, device=self.device, ) for _ in range(self.n_slices)) + ''' + self.lora_a_stacked = Parameter( + initializer('zeros', (self.n_slices, max_loras, 1, lora_a_out_size, self.input_size), + lora_config.lora_dtype)) + self.lora_b_stacked = Parameter( + initializer('zeros', (self.n_slices, max_loras, 1, lora_b_out_size, lora_config.max_lora_rank), + lora_config.lora_dtype)) if lora_config.bias_enabled: lora_bias_out_size = lora_b_out_size - self.lora_bias_stacked = tuple( - torch.zeros( - max_loras, - 1, - lora_bias_out_size, - dtype=lora_config.lora_dtype, - device=self.device, - ) for _ in range(self.n_slices)) + self.lora_bias_stacked = Parameter( + initializer('zeros', (self.n_slices, max_loras, 1, lora_bias_out_size), + lora_config.lora_dtype)) + self.output_slices = (self.lora_b_stacked[0].shape[2], ) def reset_lora(self, index: int): @@ -407,9 +412,9 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA): x: ms.Tensor, bias: Optional[ms.Tensor] = None) -> ms.Tensor: output = self.base_layer.quant_method.apply(self.base_layer, x, bias) - self.punica_wrapper.add_lora_linear(output, x, self.lora_a_stacked, - self.lora_b_stacked, - self.lora_bias_stacked, 1.0, + self.punica_wrapper.add_lora_linear(output, x, self.lora_a_stacked[0], + self.lora_b_stacked[0], + self.lora_bias_stacked[0], 1.0, self.output_slices) return output @@ -551,6 +556,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): lora_a_output_size_per_partition = ( lora_config.max_lora_rank if not lora_config.fully_sharded_loras else divide(lora_config.max_lora_rank, self.tp_size)) + ''' self.lora_a_stacked = tuple( torch.zeros( max_loras, @@ -569,15 +575,17 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): dtype=lora_config.lora_dtype, device=self.device, ) for output_size in self.output_slices) + ''' + self.lora_a_stacked = Parameter( + initializer('zeros', (self.n_slices, max_loras, 1, lora_a_output_size_per_partition, self.input_size), + 
lora_config.lora_dtype)) + self.lora_b_stacked = Parameter( + initializer('zeros', (self.n_slices, max_loras, 1, self.output_slices[0], lora_config.max_lora_rank), + lora_config.lora_dtype)) if lora_config.bias_enabled: - self.lora_bias_stacked = tuple( - torch.zeros( - max_loras, - 1, - output_size, - dtype=lora_config.lora_dtype, - device=self.device, - ) for output_size in self.output_slices) + self.lora_bias_stacked = Parameter( + initializer('zeros', (self.n_slices, max_loras, 1, self.output_slices[0]), + lora_config.lora_dtype)) def slice_lora_a( self, lora_a: list[Union[ms.Tensor, @@ -652,6 +660,25 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): return (type(source_layer) is MergedColumnParallelLinear and len(packed_modules_list) == 2) + def apply(self, + x: ms.Tensor, + bias: Optional[ms.Tensor] = None) -> ms.Tensor: + output = self.base_layer.quant_method.apply(self.base_layer, x, bias) + output_w1, output_w3= mint.split( + output, + (self.lora_b_stacked[0].shape[2], self.lora_b_stacked[0].shape[2]), + -1) + output_w1 = self.punica_wrapper(output_w1, x, self.lora_a_stacked[0], + self.lora_b_stacked[0], + self.lora_bias_stacked[0], 1.0, + self.output_slices) + output_w3 = self.punica_wrapper(output_w3, x, self.lora_a_stacked[1], + self.lora_b_stacked[1], + self.lora_bias_stacked[1], 1.0, + self.output_slices) + + return ops.cat((output_w1, output_w3), 1) + class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): """ @@ -781,6 +808,29 @@ class MergedQKVParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA): return (type(source_layer) is QKVParallelLinear and len(packed_modules_list) == 3) + def apply(self, + x: ms.Tensor, + bias: Optional[ms.Tensor] = None) -> ms.Tensor: + output = self.base_layer.quant_method.apply(self.base_layer, x, bias) + output_q, output_k, output_v = mint.split( + output, + (self.q_proj_shard_size, self.kv_proj_shard_size, self.kv_proj_shard_size), + -1) + output_q = self.punica_wrapper(output_q, x, self.lora_a_stacked[0], + self.lora_b_stacked[0], + self.lora_bias_stacked[0], 1.0, + self.output_slices) + output_k = self.punica_wrapper(output_k, x, self.lora_a_stacked[1], + self.lora_b_stacked[1,:,:,:self.output_slices[1],:], + self.lora_bias_stacked[1,:,:,:self.output_slices[1]], 1.0, + self.output_slices) + output_v = self.punica_wrapper(output_v, x, self.lora_a_stacked[2], + self.lora_b_stacked[2,:,:,:self.output_slices[2],:], + self.lora_bias_stacked[2,:,:,:self.output_slices[2]], 1.0, + self.output_slices) + + return ops.cat((output_q, output_k, output_v), 1) + class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA): diff --git a/vllm_mindspore/lora/ops/torch_ops/lora_ops.py b/vllm_mindspore/lora/ops/torch_ops/lora_ops.py index d3d48975546d52b58248e37bc3c43c77eaed549a..e003bea86331be31f496e8f6c8f07bde9f24b299 100644 --- a/vllm_mindspore/lora/ops/torch_ops/lora_ops.py +++ b/vllm_mindspore/lora/ops/torch_ops/lora_ops.py @@ -20,7 +20,7 @@ """ For punica_npu """ -from mindspore import mint +from mindspore import mint, ops from mindspore.ops.auto_generate import grouped_matmul_v4 @@ -85,7 +85,7 @@ def bgmv_expand(inputs, def sgmv_shrink( inputs, lora_a_weights, - output_tensor, + group_list, b_seq_start_loc, seq_len_tensor, lora_indices_tensor, @@ -94,47 +94,55 @@ def sgmv_shrink( token_nums, scaling, ): - group_list = seq_len_tensor - if (lora_indices_tensor.unique().shape[0] != lora_indices_tensor.shape[0]): - sorted_ids, sorted_counts = sort_lora_by_token_count( - lora_indices_tensor, seq_len_tensor) - group_list = 
sorted_counts - if lora_a_weights.shape[0] != group_list.shape[0]: - new_tensor = mint.zeros(lora_a_weights.shape[0], - dtype=group_list.dtype) - new_tensor[:group_list.size(0)] = group_list - group_list = new_tensor - if len(lora_a_weights.shape) == 4: - lora_a_weights = lora_a_weights.squeeze(1) - lora_a_weights = mint.transpose(lora_a_weights, 1, 2) + # group_list = seq_len_tensor + # if (lora_indices_tensor.unique().shape[0] != lora_indices_tensor.shape[0]): + # sorted_ids, sorted_counts = sort_lora_by_token_count( + # lora_indices_tensor, seq_len_tensor) + # group_list = sorted_counts + # if lora_a_weights.shape[0] != group_list.shape[0]: + # new_tensor = mint.zeros(lora_a_weights.shape[0], + # dtype=group_list.dtype) + # new_tensor[:group_list.size(0)] = group_list + # group_list = new_tensor + # if len(lora_a_weights.shape) == 4: + # lora_a_weights = lora_a_weights.squeeze(1) + # lora_a_weights = mint.transpose(lora_a_weights, 1, 2) + # outputs = grouped_matmul_v4([inputs], [lora_a_weights], + # group_list=group_list, + # split_item=3, + # group_type=0, + # group_list_type=1) + #outputs = bgmv_shrink(inputs, lora_a_weights, 0, scaling) + #group_list = seq_len_tensor + lora_a_weights = lora_a_weights.squeeze(1) + lora_a_weights = mint.transpose(lora_a_weights, 1, 2) + #lora_a_weights = lora_a_weights[lora_indices_tensor] outputs = grouped_matmul_v4([inputs], [lora_a_weights], group_list=group_list, split_item=3, group_type=0, - group_list_type=1) - outputs = outputs[0] - output_tensor[:, :outputs.shape[1]] = scaling * outputs[:] - return output_tensor + group_list_type=1)[0] + # output_tensor[:, :outputs.shape[1]] = scaling * outputs[:] + return outputs * scaling def bgmv_shrink(inputs, lora_b_weights, - output_tensor, lora_indices_tensor, scaling=1.0): - selected_loras = lora_b_weights[lora_indices_tensor].astype( - output_tensor.dtype) - inputs = inputs.astype(output_tensor.dtype) - if len(selected_loras.shape) == 4: - selected_loras = selected_loras.squeeze(1) - outputs = einsum_ms(inputs, selected_loras) - output_tensor[:, :outputs.shape[1]] = scaling * outputs[:] - return output_tensor + selected_loras = lora_b_weights[lora_indices_tensor] + inputs = inputs.astype(lora_b_weights[0].dtype) + # if len(selected_loras.shape) == 4: + selected_loras = selected_loras.squeeze(1) + outputs = einsum_ms(inputs, selected_loras) * scaling + # output_tensor[:, :outputs.shape[1]] = scaling * outputs[:] + return scaling * outputs def sgmv_expand_slice(inputs, lora_b_weights, output_tensor, + group_list, b_seq_start_loc, seq_len_tensor, lora_indices_tensor, @@ -144,30 +152,41 @@ def sgmv_expand_slice(inputs, slice_offset, slice_size, add_inputs=False): - group_list = seq_len_tensor - if (lora_indices_tensor.unique().shape[0] != lora_indices_tensor.shape[0]): - sorted_ids, sorted_counts = sort_lora_by_token_count( - lora_indices_tensor, seq_len_tensor) - group_list = sorted_counts - if lora_b_weights.shape[0] != group_list.shape[0]: - new_tensor = mint.zeros(lora_b_weights.shape[0], - dtype=group_list.dtype) - new_tensor[:group_list.size(0)] = group_list - group_list = new_tensor - if len(lora_b_weights.shape) == 4: - lora_b_weights = lora_b_weights.squeeze(1) - lora_b_weights = mint.transpose(lora_b_weights, 1, 2) - inputs = inputs.astype(output_tensor.dtype) + # group_list = seq_len_tensor + # if (lora_indices_tensor.unique().shape[0] != lora_indices_tensor.shape[0]): + # sorted_ids, sorted_counts = sort_lora_by_token_count( + # lora_indices_tensor, seq_len_tensor) + # group_list = 
sorted_counts + # if lora_b_weights.shape[0] != group_list.shape[0]: + # new_tensor = mint.zeros(lora_b_weights.shape[0], + # dtype=group_list.dtype) + # new_tensor[:group_list.size(0)] = group_list + # group_list = new_tensor + # if len(lora_b_weights.shape) == 4: + # lora_b_weights = lora_b_weights.squeeze(1) + # lora_b_weights = mint.transpose(lora_b_weights, 1, 2) + # inputs = inputs.astype(output_tensor.dtype) + # outputs = grouped_matmul_v4([inputs], [lora_b_weights], + # group_list=group_list, + # split_item=3, + # group_type=0, + # group_list_type=1) + # outputs = outputs[0] + # if add_inputs: + # output_tensor += outputs[:] + # else: + # output_tensor = outputs[:] + #output_tensor = bgmv_expand_slice(inputs, lora_b_weights, output_tensor, 0, 0, 0) + #group_list = seq_len_tensor + lora_b_weights = lora_b_weights.squeeze(1) + lora_b_weights = mint.transpose(lora_b_weights, 1, 2) + #lora_b_weights = lora_b_weights[lora_indices_tensor] outputs = grouped_matmul_v4([inputs], [lora_b_weights], group_list=group_list, split_item=3, group_type=0, - group_list_type=1) - outputs = outputs[0] - if add_inputs: - output_tensor[:, slice_offset:slice_offset + slice_size] += outputs[:] - else: - output_tensor[:, slice_offset:slice_offset + slice_size] = outputs[:] + group_list_type=1)[0] + output_tensor = ops.add(output_tensor, outputs) return output_tensor @@ -181,11 +200,12 @@ def bgmv_expand_slice(inputs, selected_loras = lora_b_weights[lora_indices_tensor].astype( output_tensor.dtype) inputs = inputs.astype(output_tensor.dtype) - if len(selected_loras.shape) == 4: - selected_loras = selected_loras.squeeze(1) + # if len(selected_loras.shape) == 4: + selected_loras = selected_loras.squeeze(1) outputs = einsum_ms(inputs, selected_loras) - if add_inputs: - output_tensor[:, slice_offset:slice_offset + slice_size] += outputs[:] - else: - output_tensor[:, slice_offset:slice_offset + slice_size] = outputs[:] + output_tensor = ops.add(output_tensor, outputs) + # if add_inputs: + # output_tensor += outputs[:] + # else: + # output_tensor = outputs[:] return output_tensor diff --git a/vllm_mindspore/lora/punica_wrapper/punica_npu.py b/vllm_mindspore/lora/punica_wrapper/punica_npu.py index 0a60baf277a56154152aaec146c5db162deec49a..545a169eb0f4ec04def2a2b9dd59859820d8292a 100644 --- a/vllm_mindspore/lora/punica_wrapper/punica_npu.py +++ b/vllm_mindspore/lora/punica_wrapper/punica_npu.py @@ -19,20 +19,22 @@ # isort: skip_file """Punica wrapper for NPU.""" -from typing import Callable +from typing import Callable, Optional -from mindspore import mint +from mindspore import mint, nn, Parameter, ops, dtype from mindspore.common import dtype as mstype +from mindspore.common.initializer import initializer from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase from vllm_mindspore.lora.ops.torch_ops.lora_ops import ( bgmv_expand, bgmv_expand_slice, bgmv_shrink, sgmv_expand, - sgmv_expand_slice, sgmv_shrink) + sgmv_expand_slice, sgmv_shrink, sort_lora_by_token_count) +from vllm_mindspore.model_executor.utils import get_model_context # The platforms that are compatible with the PyTorch-native implementation can # inherit this class -class PunicaWrapperNPU(PunicaWrapperBase): +class PunicaWrapperNPU(PunicaWrapperBase, nn.Cell): """ PunicaWrapperNPU is designed to manage and provide metadata for the punica kernel. 
The main function is to maintain the state information for @@ -40,32 +42,34 @@ class PunicaWrapperNPU(PunicaWrapperBase): """ def __init__(self, max_num_batched_tokens, max_batches, device, **kwargs): + nn.Cell.__init__(self) PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, device) + self.max_loras = kwargs["max_loras"] + self.group_list = Parameter(initializer("ones", self.max_loras, dtype.int64), name="group_list") + self.lora_indices = Parameter(initializer("ones", self.max_loras, dtype.int64), name="lora_indices") def _shrink_prefill( self, - y, x, w_t_all, scale, ): - sgmv_shrink( # type: ignore + return sgmv_shrink( # type: ignore x, w_t_all, - y, + self.group_list, *self.prefill_metadata, scale, ) def _shrink_decode( self, - y, x, w_t_all, scale, ): - bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale) + return bgmv_shrink(x, w_t_all, self.token_lora_indices, scale) def _expand_prefill( self, @@ -100,10 +104,11 @@ class PunicaWrapperNPU(PunicaWrapperBase): y_slice_size, add_inputs, ): - sgmv_expand_slice( # type: ignore + return sgmv_expand_slice( # type: ignore x, w_t_all, y, + self.group_list, *self.prefill_metadata, y_offset, y_slice_size, @@ -119,7 +124,7 @@ class PunicaWrapperNPU(PunicaWrapperBase): y_slice_size, add_inputs, ): - bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset, + return bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset, y_slice_size, add_inputs) def _apply_expand( @@ -138,11 +143,11 @@ class PunicaWrapperNPU(PunicaWrapperBase): """ expand_slice_fun: Callable = (self._expand_slice_prefill - if self.is_prefill else + if get_model_context("is_prefill") else self._expand_slice_decode) - expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_inputs) + return self._expand_slice_prefill(y, x, w_t_all, y_offset, y_slice_size, add_inputs) - def _apply_shrink(self, y, x, w_t_all, scale): + def _apply_shrink(self, x, w_t_all, scale): """ Perform the ` y+=x@w_t_all` computation, which is suitable for the GEMM of lora'a. @@ -151,14 +156,82 @@ class PunicaWrapperNPU(PunicaWrapperBase): Otherwise, it is the decode stage, and the _shrink_decode function should be called. 
""" - y_org = y - y = y.view(-1, y.shape[-1]) + # y_org = y + # y = y.view(-1, y.shape[-1]) shrink_fun: Callable = (self._shrink_prefill - if self.is_prefill else self._shrink_decode) - shrink_fun(y, x, w_t_all, scale) - y.view_as(y_org) + if get_model_context("is_prefill") else self._shrink_decode) + # y = shrink_fun(x, w_t_all, scale) + return self._shrink_prefill(x, w_t_all, scale) - def add_shrink(self, y, x, lora_a_stacked, scale, **kwargs): + def _apply_bias( + self, + indices, + output, + output_slices: tuple[int, ...], + lora_bias_stacked, + ): + """Applies bias to output + + Input shapes: + lora_bias_stacked: 3 element tuple of (num_loras, output_dim) + indices: (batch_size) + output: (batch_size, q_slice_size + 2*kv_slice_size) + output_slices: n-1 element tuple of (slice_size...), + where n is number of slices + """ + org_output = output + output = output.view(-1, output.shape[-1]) + indices = indices.view(-1) + + bias = lora_bias_stacked.view(-1, lora_bias_stacked.shape[-1]) + bias = bias[indices] + bias[indices == -1] = 0 + output = ops.add(output, bias) + return output.view_as(org_output) + + def update_metadata( + self, + mapping: "LoRAMapping", + lora_index_to_id: list[Optional[int]], + max_loras: int, + vocab_size: int, + extra_vocab_size: int, + long_lora_context: Optional["LongContextLoRAContext"] = None, + **kwargs): + self._update_base_metadata(mapping, lora_index_to_id, max_loras, + vocab_size, extra_vocab_size, + long_lora_context) + if mapping.is_prefill: + # Update metadata required for prefill-related operators. + self._update_prefill_metadata(self.token_lora_indices) + self.is_prefill = True + else: + self.is_prefill = False + _, seq_len, lora_indices, _, _, _ = self.prefill_metadata + sorted_ids, sorted_counts = sort_lora_by_token_count( + lora_indices, seq_len) + group_list = sorted_counts + if len(group_list) < self.max_loras: + new_tensor = mint.zeros(self.max_loras, + dtype=group_list.dtype) + new_tensor[:group_list.size(0)] = group_list + group_list = new_tensor + self.group_list.set_data(group_list.astype(dtype.int64)) + # if len(lora_indices) > self.max_loras: + # self.group_list.set_data(seq_len[:self.max_loras].astype(dtype.int64)) + # self.lora_indices.set_data(lora_indices[:self.max_loras].astype(dtype.int64)) + # elif len(lora_indices) < self.max_loras: + # pad_len = int(self.max_loras - len(lora_indices)) + # lora_indices = ops.pad(lora_indices, (0, pad_len), mode='constant', value=0) + # seq_len = ops.pad(seq_len, (0, pad_len), mode='constant', value=0) + # self.group_list.set_data(seq_len.astype(dtype.int64)) + # self.lora_indices.set_data(lora_indices.astype(dtype.int64)) + # else: + # self.group_list.set_data(seq_len.astype(dtype.int64)) + # self.lora_indices.set_data(lora_indices.astype(dtype.int64)) + + + def add_shrink(self, x, lora_a_stacked, scale, **kwargs): """ Performs GEMM for multiple slices of lora_a. 
When `is_prefill is` true, it indicates that it is currently the @@ -179,9 +252,10 @@ class PunicaWrapperNPU(PunicaWrapperBase): x = x.view(-1, x.shape[-1]) # TODO fuse these kernels - for slice_idx in range(len(lora_a_stacked)): - self._apply_shrink(y[slice_idx], x, lora_a_stacked[slice_idx], - scale) + # for slice_idx in range(len(lora_a_stacked)): + y = self._apply_shrink(x, lora_a_stacked, + scale) + return y def add_expand(self, y, @@ -215,19 +289,17 @@ class PunicaWrapperNPU(PunicaWrapperBase): y = y.view(-1, y.shape[-1]) offset_left = offset_start if lora_bias_stacked is not None: - self._apply_bias(self.token_lora_indices, y, output_slices, - lora_bias_stacked) - for slice_idx in range(len(lora_b_stacked)): - self._apply_expand( - y, - x[slice_idx], - lora_b_stacked[slice_idx], - offset_left, - output_slices[slice_idx], - add_inputs=add_inputs, - ) - offset_left += output_slices[slice_idx] - y.view_as(y_org) + y = self._apply_bias(self.token_lora_indices, y, output_slices, + lora_bias_stacked) + y = self._apply_expand( + y, + x, + lora_b_stacked, + offset_left, + output_slices, + add_inputs=add_inputs, + ) + return y.view_as(y_org) def add_lora_embedding(self, y, @@ -253,7 +325,7 @@ class PunicaWrapperNPU(PunicaWrapperBase): # Embedding layer only need expand op expand_fun: Callable = (self._expand_prefill if self.is_prefill else self._expand_decode) - expand_fun(y, x, lora_b_stacked, add_inputs) + return expand_fun(y, x, lora_b_stacked, add_inputs) def add_lora_linear(self, y, @@ -292,27 +364,20 @@ class PunicaWrapperNPU(PunicaWrapperBase): if self.no_lora: return x = x.reshape(-1, x.shape[-1]) - assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices) + # assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices) if lora_bias_stacked is not None: - assert len(lora_bias_stacked) == len(output_slices) y = self._apply_bias(self.token_lora_indices, y, output_slices, lora_bias_stacked) - if buffer is None: - r = lora_b_stacked[0].shape[-1] - # We set the buffer to be float32 by default, consistent with the - # triton op - buffer = tuple( - mint.zeros((x.shape[0], r), dtype=mstype.float32) - for _ in range(len(output_slices))) - self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs) - self.add_expand(y, + buffer = self.add_shrink(x, lora_a_stacked, scale, **kwargs) + y = self.add_expand(y, buffer, lora_b_stacked, None, output_slices, add_inputs=True, **kwargs) + return y def add_lora_logits(self, y, @@ -357,3 +422,6 @@ class PunicaWrapperNPU(PunicaWrapperBase): self.sampler_indices, add_inputs=True) y.view_as(y_org) + + def construct(self, *args, **kwargs): + return self.add_lora_linear(*args, **kwargs) diff --git a/vllm_mindspore/lora/utils.py b/vllm_mindspore/lora/utils.py index 0a96b5559c0400de2d5a638695207dc77497ecbd..a4bc9389e9f723d0c477a6d7706adb52efdc3af6 100644 --- a/vllm_mindspore/lora/utils.py +++ b/vllm_mindspore/lora/utils.py @@ -50,3 +50,21 @@ _all_lora_classes: set[type[BaseLayerWithLoRA]] = { RowParallelLinearWithShardedLoRA, LinearScalingRotaryEmbeddingWithLoRA, } + +def replace_submodule(model, module_name, new_module): + """Replace a submodule in a model with a new module.""" + parent = model.get_submodule(".".join(module_name.split(".")[:-1])) + target_name = module_name.split(".")[-1] + weight_name = module_name + ".weight" + lora_a_name = module_name + ".lora_a_weight" + lora_b_name = module_name + ".lora_b_weight" + lora_bias_name = module_name + ".lora_bias" + bias_name = module_name + ".bias" + setattr(parent, target_name, 
new_module) + new_module.base_layer.weight.name = weight_name + new_module.lora_a_stacked.name = lora_a_name + new_module.lora_b_stacked.name = lora_b_name + if new_module.base_layer.bias is not None: + new_module.base_layer.bias.name = bias_name + new_module.lora_bias_stacked.name = lora_bias_name + return new_module diff --git a/vllm_mindspore/model_executor/models/model_base.py b/vllm_mindspore/model_executor/models/model_base.py index 58178292e8cfa5b01a77676bb497e66bac6ffc3f..bdf51db519b961e457af7047e08d01eae8632578 100644 --- a/vllm_mindspore/model_executor/models/model_base.py +++ b/vllm_mindspore/model_executor/models/model_base.py @@ -362,9 +362,9 @@ class NativeModel(MsModelBase): def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: super().__init__(vllm_config=vllm_config, prefix=prefix) self.quant_config = vllm_config.quant_config - if vllm_config.lora_config is not None: - # native model lora only support pynative mode now - vllm_config.model_config.enforce_eager = True + # if vllm_config.lora_config is not None: + # # native model lora only support pynative mode now + # vllm_config.model_config.enforce_eager = True self.is_eager_mode = vllm_config.model_config.enforce_eager self.prefill_graph = None self.decode_graph = None
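
Review note on the stacked-Parameter layout. The patch replaces the per-slice tuples of torch.zeros buffers with a single MindSpore Parameter per weight family, folding the slice index into a leading dimension of size n_slices. Below is a minimal sketch of the resulting shape bookkeeping; the sizes (n_slices=2, max_loras=4, rank=16, input/output sizes 128/256) are made up purely for illustration and are not taken from the patch.

import mindspore as ms
from mindspore import Parameter
from mindspore.common import dtype as mstype
from mindspore.common.initializer import initializer

n_slices, max_loras, rank = 2, 4, 16
input_size, output_size = 128, 256

# Old layout (removed torch code): a tuple holding one (max_loras, 1, rank, input_size)
# tensor per slice. New layout: one Parameter with the slice index as a leading dimension.
lora_a_stacked = Parameter(
    initializer('zeros', (n_slices, max_loras, 1, rank, input_size), mstype.float16))
lora_b_stacked = Parameter(
    initializer('zeros', (n_slices, max_loras, 1, output_size, rank), mstype.float16))

assert lora_a_stacked.shape == (n_slices, max_loras, 1, rank, input_size)
assert lora_b_stacked.shape == (n_slices, max_loras, 1, output_size, rank)

# Call sites then index the leading dimension, e.g. lora_a_stacked[0], where the old
# code passed the corresponding tuple element.

This is why add_lora_linear and the new apply overrides in layers.py receive lora_a_stacked[0], lora_b_stacked[1], and so on instead of whole tuples.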
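
Review note on group_list construction. sgmv_shrink and sgmv_expand_slice now expect a precomputed group_list for grouped_matmul_v4, built once in update_metadata and padded to a fixed max_loras length so its shape stays static across steps. The sketch below shows only the padding step, assuming (as the call sites suggest) that sort_lora_by_token_count yields per-adapter token counts; the helper name build_group_list and the sample values are illustrative, not from the patch.

import mindspore as ms
from mindspore import mint
from mindspore.common import dtype as mstype

def build_group_list(token_counts: ms.Tensor, max_loras: int) -> ms.Tensor:
    """Pad per-adapter token counts to max_loras entries (zeros for unused slots)."""
    group_list = token_counts
    if group_list.shape[0] < max_loras:
        padded = mint.zeros(max_loras, dtype=group_list.dtype)
        padded[:group_list.shape[0]] = group_list
        group_list = padded
    return group_list.astype(mstype.int64)

# Two active adapters covering 5 and 3 tokens, padded for max_loras=4 -> [5, 3, 0, 0]
counts = ms.Tensor([5, 3], dtype=mstype.int64)
print(build_group_list(counts, max_loras=4))

With group_list_type=1 the counts describe contiguous row groups, so this presumes the prefill tokens are already laid out contiguously per adapter when the prefill metadata is built.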