From 0d1237ef7afd0168af1753ad7777c355263f3ba1 Mon Sep 17 00:00:00 2001
From: nashturing
Date: Wed, 11 Jun 2025 10:52:35 +0800
Subject: [PATCH] [feature] multi-LoRA support for MindFormers mcore models

---
 vllm_mindspore/__init__.py                    | 29 +++++++++++++------
 vllm_mindspore/lora/models.py                 | 18 ++++++++++++
 vllm_mindspore/lora/utils.py                  |  9 ++++++
 vllm_mindspore/lora/worker_manager.py         | 29 +++++++++++++++++++
 .../model_executor/models/mf_models/qwen3.py  | 15 ++++++++--
 vllm_mindspore/utils.py                       | 10 +++----
 6 files changed, 94 insertions(+), 16 deletions(-)
 create mode 100644 vllm_mindspore/lora/worker_manager.py

diff --git a/vllm_mindspore/__init__.py b/vllm_mindspore/__init__.py
index c5e2deb..0b9e727 100644
--- a/vllm_mindspore/__init__.py
+++ b/vllm_mindspore/__init__.py
@@ -76,26 +76,19 @@ vllm.utils.cuda_is_initialized = ascend_is_initialized
 vllm.utils.memory_profiling = ms_memory_profiling
 
 import vllm.lora.utils
-
+from vllm_mindspore.lora.utils import get_supported_lora_modules_ms
 from vllm_mindspore.model_executor.layers.linear import LinearBase
 from vllm_mindspore.lora.utils import _all_lora_classes
 
+vllm.lora.utils.get_supported_lora_modules = get_supported_lora_modules_ms
 vllm.lora.utils._all_lora_classes = _all_lora_classes
 vllm.lora.utils.LinearBase = LinearBase
 
-import vllm.lora.models
-from vllm_mindspore.lora.models import register_module, from_local_checkpoint, from_lora_tensors
-
-vllm.lora.models.LoRAModelManager.register_module = register_module
-vllm.lora.models.LoRAModel.from_local_checkpoint = from_local_checkpoint
-vllm.lora.models.LoRAModel.from_lora_tensors = from_lora_tensors
-
 from vllm_mindspore.lora.layers import (ColumnParallelLinearWithLoRA,
                                         MergedColumnParallelLinearWithLoRA,
                                         MergedQKVParallelLinearWithLoRA,
                                         QKVParallelLinearWithLoRA,
                                         RowParallelLinearWithLoRA)
-
 import vllm.lora.layers
 
 vllm.lora.layers.ColumnParallelLinearWithLoRA = ColumnParallelLinearWithLoRA
@@ -104,6 +97,18 @@ vllm.lora.layers.MergedQKVParallelLinearWithLoRA = MergedQKVParallelLinearWithLo
 vllm.lora.layers.QKVParallelLinearWithLoRA = QKVParallelLinearWithLoRA
 vllm.lora.layers.RowParallelLinearWithLoRA = RowParallelLinearWithLoRA
 
+from vllm_mindspore.lora.models import LoraModelMs, LoraModelManagerMs
+from vllm_mindspore.lora.worker_manager import _load_adapter_ms
+import vllm.lora.models
+from vllm.lora.worker_manager import WorkerLoRAManager
+
+vllm.lora.models.LoRAModelManager._create_lora_modules = LoraModelManagerMs._create_lora_modules
+vllm.lora.models.LoRAModelManager._set_adapter_mapping = LoraModelManagerMs._set_adapter_mapping
+vllm.lora.models.LoRAModelManager.register_module = LoraModelManagerMs.register_module
+vllm.lora.models.LoRAModel.from_lora_tensors = LoraModelMs.from_lora_tensors
+vllm.lora.models.LoRAModel.from_local_checkpoint = LoraModelMs.from_local_checkpoint
+vllm.lora.worker_manager.WorkerLoRAManager._load_adapter = _load_adapter_ms
+
 import vllm.executor
 
 from vllm_mindspore.model_executor.models.registry import (
@@ -265,13 +270,16 @@ vllm.model_executor.layers.rejection_sampler._multinomial = _multinomial
 ######### for multi-model
 from vllm_mindspore.inputs.registry import call_hf_processor
 from vllm.inputs.registry import InputProcessingContext
+
 InputProcessingContext.call_hf_processor = call_hf_processor
 
 from vllm_mindspore.multimodal.inputs import as_kwargs
 from vllm.multimodal.inputs import MultiModalKwargs
+
 MultiModalKwargs.as_kwargs = as_kwargs
 
 from vllm_mindspore.model_executor.layers.rotary_embedding import InferMRotaryEmbedding
+
 vllm.model_executor.layers.rotary_embedding.MRotaryEmbedding = InferMRotaryEmbedding
 
 # patch for V1
@@ -284,6 +292,7 @@ from vllm_mindspore.v1.spec_decode import eagle
 update_modules("vllm.v1.spec_decode.eagle", eagle)
 
 from vllm_mindspore.v1.attention.backends import ms_attn
+
 update_modules("vllm.v1.attention.backends.flash_attn", ms_attn)
 
 import vllm.v1.worker.gpu_model_runner
@@ -297,6 +306,7 @@ from vllm_mindspore.v1.worker.gpu_model_runner import _update_states
 vllm.v1.worker.gpu_model_runner.GPUModelRunner._update_states = _update_states
 
 from vllm_mindspore.v1.worker.gpu_model_runner import initialize_kv_cache, get_kv_cache_spec
+
 vllm.v1.worker.gpu_model_runner.GPUModelRunner.initialize_kv_cache = initialize_kv_cache
 vllm.v1.worker.gpu_model_runner.GPUModelRunner.get_kv_cache_spec = get_kv_cache_spec
 
@@ -369,6 +379,7 @@ Worker.compile_or_warm_up_model = compile_or_warm_up_model
 
 from vllm_mindspore.v1.core.sched.scheduler import update_from_output
 from vllm.v1.core.sched.scheduler import Scheduler
+
 Scheduler.update_from_output = update_from_output
 
 from .utils import check_ready
diff --git a/vllm_mindspore/lora/models.py b/vllm_mindspore/lora/models.py
index 9219784..07ec706 100644
--- a/vllm_mindspore/lora/models.py
+++ b/vllm_mindspore/lora/models.py
@@ -21,12 +21,14 @@ from typing import Dict, List, Optional, Union
 import safetensors.torch
 import torch
 from vllm.lora.lora import LoRALayerWeights
+from vllm.lora.models import LoRAModel, LoRAModelManager
 from vllm.lora.peft_helper import PEFTHelper
 from vllm.lora.utils import is_regex_target_modules, parse_fine_tuned_lora_name
 from vllm.model_executor.models.utils import WeightsMapper
 from vllm.utils import is_pin_memory_available
 
 from vllm_mindspore.lora.layers import BaseLayerWithLoRA
+from vllm_mindspore.utils import is_mindformers_model_backend
 
 _GLOBAL_LORA_ID = 0
 
@@ -225,3 +227,19 @@ def from_local_checkpoint(
             embedding_modules=embedding_modules,
             embedding_padding_modules=embedding_padding_modules,
             weights_mapper=weights_mapper)
+
+
+if is_mindformers_model_backend():
+    from mindformers.pet.loramodels import LoRAModel as MFLoRAModel
+    from mindformers.pet.models.vllm_multilora import VLLMLoRAModelManager
+    LoRAModel.from_lora_tensors = MFLoRAModel.from_lora_tensors
+    LoRAModel.from_local_checkpoint = MFLoRAModel.from_local_checkpoint
+    LoRAModelManager._create_lora_modules = VLLMLoRAModelManager._create_lora_modules
+    LoRAModelManager._set_adapter_mapping = VLLMLoRAModelManager._set_adapter_mapping
+else:
+    LoRAModel.from_lora_tensors = from_lora_tensors
+    LoRAModel.from_local_checkpoint = from_local_checkpoint
+    LoRAModelManager.register_module = register_module
+
+LoraModelMs = LoRAModel
+LoraModelManagerMs = LoRAModelManager
diff --git a/vllm_mindspore/lora/utils.py b/vllm_mindspore/lora/utils.py
index 0084e60..0dcf104 100644
--- a/vllm_mindspore/lora/utils.py
+++ b/vllm_mindspore/lora/utils.py
@@ -1,4 +1,5 @@
 #!/usr/bin/env python3
+# isort:skip_file
 # Copyright 2025 Huawei Technologies Co., Ltd
 # Copyright 2024 The vLLM team.
 #
@@ -22,6 +23,7 @@ from vllm.lora.fully_sharded_layers import (
     MergedColumnParallelLinearWithShardedLoRA,
     MergedQKVParallelLinearWithShardedLoRA, QKVParallelLinearWithShardedLoRA,
     RowParallelLinearWithShardedLoRA)
+from vllm.lora.utils import get_supported_lora_modules
 
 from vllm_mindspore.lora.layers import (
     BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
@@ -29,6 +31,7 @@ from vllm_mindspore.lora.layers import (
     MergedColumnParallelLinearWithLoRA, MergedQKVParallelLinearWithLoRA,
     QKVParallelLinearWithLoRA, RowParallelLinearWithLoRA,
     VocabParallelEmbeddingWithLoRA)
+from vllm_mindspore.utils import is_mindformers_model_backend
 
 _all_lora_classes: Set[Type[BaseLayerWithLoRA]] = {
     VocabParallelEmbeddingWithLoRA,
@@ -45,3 +48,9 @@ _all_lora_classes: Set[Type[BaseLayerWithLoRA]] = {
     RowParallelLinearWithShardedLoRA,
     LinearScalingRotaryEmbeddingWithLoRA,
 }
+
+if is_mindformers_model_backend():
+    from mindformers.pet.utils import get_mf_supported_lora_modules
+    get_supported_lora_modules_ms = get_mf_supported_lora_modules
+else:
+    get_supported_lora_modules_ms = get_supported_lora_modules
diff --git a/vllm_mindspore/lora/worker_manager.py b/vllm_mindspore/lora/worker_manager.py
new file mode 100644
index 0000000..340cba4
--- /dev/null
+++ b/vllm_mindspore/lora/worker_manager.py
@@ -0,0 +1,29 @@
+from typing import List
+
+from vllm.lora.request import LoRARequest
+
+from vllm_mindspore.utils import is_mindformers_model_backend
+
+if not is_mindformers_model_backend():
+    from vllm.lora.worker_manager import WorkerLoRAManager
+    _load_adapter_ms = WorkerLoRAManager._load_adapter
+else:
+    from mindformers.pet.loramodels import load_lora_ckpt
+
+    def _load_adapter_ms(self, lora_request: LoRARequest):
+        supported_lora_modules = (self._adapter_manager.supported_lora_modules)
+        packed_modules_mapping = (self._adapter_manager.packed_modules_mapping)
+        expected_lora_modules: List[str] = []
+        for module in supported_lora_modules:
+            if module in packed_modules_mapping:
+                expected_lora_modules.extend(packed_modules_mapping[module])
+            else:
+                expected_lora_modules.append(module)
+
+        expected_lora_modules = list(set(expected_lora_modules))
+        lora = load_lora_ckpt(lora_request, self.lora_config, self.vocab_size,
+                              expected_lora_modules,
+                              self.max_position_embeddings,
+                              self.embedding_modules,
+                              self.embedding_padding_modules)
+        return lora
diff --git a/vllm_mindspore/model_executor/models/mf_models/qwen3.py b/vllm_mindspore/model_executor/models/mf_models/qwen3.py
index a11a93f..9f5f668 100644
--- a/vllm_mindspore/model_executor/models/mf_models/qwen3.py
+++ b/vllm_mindspore/model_executor/models/mf_models/qwen3.py
@@ -33,6 +33,7 @@ from vllm.config import VllmConfig, get_current_vllm_config
 from vllm.forward_context import get_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.layers.sampler import SamplerOutput
+from vllm.model_executor.models.interfaces import SupportsLoRA
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
@@ -47,7 +48,18 @@ from vllm_mindspore.model_executor.models.model_base import (AttentionWrapper,
 logger = init_logger(__name__)
 
 
-class Qwen3ForCausalLM(MsModelBase):
+class Qwen3ForCausalLM(MsModelBase, SupportsLoRA):
+    packed_modules_mapping = {
+        "linear_qkv": [
+            "linear_q",
+            "linear_k",
+            "linear_v",
+        ],
+        "linear_fc1": [
+            "gating",
+            "linear_fc1",
+        ]
+    }
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         super().__init__(vllm_config=vllm_config, prefix=prefix)
@@ -220,5 +232,4 @@ class Qwen3ForCausalLM(MsModelBase):
 
     def load_weights(self, weights: Iterable[Tuple[str, Tensor]]):
         self.network.load_weights(self.mf_config.load_checkpoint)
-        self.network.set_dynamic_inputs()
         return None
diff --git a/vllm_mindspore/utils.py b/vllm_mindspore/utils.py
index 920bb23..9acc753 100644
--- a/vllm_mindspore/utils.py
+++ b/vllm_mindspore/utils.py
@@ -141,7 +141,7 @@ STR_DTYPE_TO_MS_DTYPE = {
 }
 
 
-class vllmModelBackendEnum(str, Enum):
+class VllmModelBackendEnum(str, Enum):
     """Define the variable Enum of vLLM_MODEL_BACKEND"""
     MF = 'MindFormers'
     MIND_ONE = 'MindONE'
@@ -156,10 +156,10 @@ def is_mindformers_model_backend():
     vllm_model_backend = os.getenv("vLLM_MODEL_BACKEND")  # noqa: SIM112
     if vllm_model_backend:
         try:
-            vllmModelBackendEnum(vllm_model_backend)
-            return vllm_model_backend == vllmModelBackendEnum.MF
+            VllmModelBackendEnum(vllm_model_backend)
+            return vllm_model_backend == VllmModelBackendEnum.MF
         except ValueError as exc:
-            allowed_values = [member.value for member in vllmModelBackendEnum]
+            allowed_values = [member.value for member in VllmModelBackendEnum]
             raise ValueError(
                 f"Illegal value of vLLM_MODEL_BACKEND '{vllm_model_backend}',"
                 f" allowed_values: {', '.join(allowed_values)}") from exc
@@ -170,7 +170,7 @@ def is_mindformers_model_backend():
 def is_mindone_model_backend():
     return (os.getenv("vLLM_MODEL_BACKEND")  # noqa: SIM112
             and os.environ["vLLM_MODEL_BACKEND"]  # noqa: SIM112
-            == vllmModelBackendEnum.MIND_ONE)
+            == VllmModelBackendEnum.MIND_ONE)
 
 
 def check_ready():
-- 
Gitee