From 6a25eef3b2d3347b2584eaaae183ed48a0d88d12 Mon Sep 17 00:00:00 2001
From: w00521005
Date: Mon, 1 Sep 2025 12:33:26 +0800
Subject: [PATCH] adapt ds_mtp model (mcore)

---
 tests/mindformers                                        | 2 +-
 vllm_mindspore/lora/ops/torch_ops/lora_ops.py            | 2 +-
 vllm_mindspore/model_executor/model_loader/utils.py      | 5 ++++-
 vllm_mindspore/model_executor/models/mf_models/config.py | 4 ++++
 .../model_executor/models/mf_models/mindformers.py       | 3 +++
 vllm_mindspore/model_executor/models/utils.py            | 3 +++
 vllm_mindspore/worker/worker.py                          | 2 +-
 7 files changed, 17 insertions(+), 4 deletions(-)

diff --git a/tests/mindformers b/tests/mindformers
index b4f30790..21bf5ccd 160000
--- a/tests/mindformers
+++ b/tests/mindformers
@@ -1 +1 @@
-Subproject commit b4f3079026a1f2d809be7485f6bae6d96a74c423
+Subproject commit 21bf5ccddb757e02ea139a6beb525c878dc03e9e
diff --git a/vllm_mindspore/lora/ops/torch_ops/lora_ops.py b/vllm_mindspore/lora/ops/torch_ops/lora_ops.py
index 094e24b8..38c4b39a 100644
--- a/vllm_mindspore/lora/ops/torch_ops/lora_ops.py
+++ b/vllm_mindspore/lora/ops/torch_ops/lora_ops.py
@@ -25,7 +25,7 @@ from mindspore.ops.auto_generate import grouped_matmul_v4
 
 
 def einsum_ms(inputs, selected_loras):
-    # mint.einsum("bi, boi -> bo", inputs, selected_loras)
+    # equivalent to einsum("bi, boi -> bo", inputs, selected_loras)
     selected_loras = mint.transpose(selected_loras, 1, 2)
     outputs = mint.matmul(inputs.unsqueeze(1), selected_loras).squeeze(1)
     return outputs
diff --git a/vllm_mindspore/model_executor/model_loader/utils.py b/vllm_mindspore/model_executor/model_loader/utils.py
index 54288209..fb155027 100644
--- a/vllm_mindspore/model_executor/model_loader/utils.py
+++ b/vllm_mindspore/model_executor/model_loader/utils.py
@@ -33,7 +33,10 @@ from vllm_mindspore.utils import (is_mindformers_model_backend,
 
 
 def mf_mcore_compatible(arch):
-    return arch in mcore_support_list
+    # vLLM overrides the model arch to `DeepSeekMTPModel` for MTP models;
+    # that arch is not registered independently in MindFormers and is
+    # a subclass of `DeepseekV3ForCausalLM`.
+    return arch in mcore_support_list or arch == "DeepSeekMTPModel"
 
 
 def resolve_mf_mcore_arch(model_config: ModelConfig, architectures: list[str]):
diff --git a/vllm_mindspore/model_executor/models/mf_models/config.py b/vllm_mindspore/model_executor/models/mf_models/config.py
index 469a3797..3a1d281f 100644
--- a/vllm_mindspore/model_executor/models/mf_models/config.py
+++ b/vllm_mindspore/model_executor/models/mf_models/config.py
@@ -127,6 +127,10 @@ MODEL_RELATED_MAPPING = {
     },
     'deepseek_v3': {
         'multi_latent_attention': True,
+    },
+    'deepseek_mtp': {
+        'multi_latent_attention': True,
+        'model_type': 'deepseek_mtp'
     }
     # Add anther model type...
 }
diff --git a/vllm_mindspore/model_executor/models/mf_models/mindformers.py b/vllm_mindspore/model_executor/models/mf_models/mindformers.py
index 2d7cacfe..8bfa7681 100644
--- a/vllm_mindspore/model_executor/models/mf_models/mindformers.py
+++ b/vllm_mindspore/model_executor/models/mf_models/mindformers.py
@@ -350,6 +350,9 @@ class MindFormersForCausalLM(MsModelBase, SupportsPP):
         if intermediate_tensors is not None:
             model_inputs["hidden_states"] = \
                 intermediate_tensors["hidden_states"]
+        elif kwargs.get("previous_hidden_states") is not None:
+            # used for DeepSeek MTP
+            model_inputs["hidden_states"] = kwargs["previous_hidden_states"]
 
         if is_prefill or is_ringmla_chunked:
             self.network.phase = \
diff --git a/vllm_mindspore/model_executor/models/utils.py b/vllm_mindspore/model_executor/models/utils.py
index 2c2f0ec6..f099df88 100644
--- a/vllm_mindspore/model_executor/models/utils.py
+++ b/vllm_mindspore/model_executor/models/utils.py
@@ -277,6 +277,9 @@ def is_use_ringmla(vllm_config, mf_config=None):
         return False
     if is_310p():
         return False
+    if vllm_config.model_config.hf_config.model_type == "deepseek_mtp":
+        # the weights of the DeepSeek MTP model are not quantized
+        return False
     use_ringmla = (vllm_config.model_config.use_mla
                    and vllm_config.model_config.quantization is not None
                    and vllm_config.parallel_config.tensor_parallel_size < 16)
diff --git a/vllm_mindspore/worker/worker.py b/vllm_mindspore/worker/worker.py
index 111c89a2..2a78881b 100644
--- a/vllm_mindspore/worker/worker.py
+++ b/vllm_mindspore/worker/worker.py
@@ -194,7 +194,7 @@ def _prepare_input_for_warmup(model_config,
 
     model_input = model_runner.prepare_model_input(seqs)
     previous_hidden_states = None if not is_mtp_model else torch.ones(
-        [bs, seq_len, model_config.get_hidden_size()],
+        [bs * seq_len, model_config.get_hidden_size()],
         dtype=get_valid_dtype(model_config.dtype))
 
     return model_input, previous_hidden_states
--
Gitee
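
Reviewer note: a minimal standalone sketch of the two MTP-specific branches this
patch adds. Names and list contents below are simplified assumptions, not the
actual vllm_mindspore code; the sketch only illustrates the intended control flow.

    # Toy sketch (assumptions, not vllm_mindspore code): how the arch check
    # and the hidden-states plumbing behave for DeepSeek MTP.

    mcore_support_list = ["DeepseekV3ForCausalLM"]  # assumed contents


    def mf_mcore_compatible(arch):
        # `DeepSeekMTPModel` is accepted even though it is absent from the
        # mcore support list, because vLLM rewrites the MTP arch to that name.
        return arch in mcore_support_list or arch == "DeepSeekMTPModel"


    def build_model_inputs(intermediate_tensors=None, **kwargs):
        model_inputs = {}
        if intermediate_tensors is not None:
            # pipeline-parallel path: hidden states arrive from the prior stage
            model_inputs["hidden_states"] = intermediate_tensors["hidden_states"]
        elif kwargs.get("previous_hidden_states") is not None:
            # DeepSeek MTP path: hidden states from the main model's last pass,
            # flattened to [bs * seq_len, hidden_size] as in the worker warmup
            model_inputs["hidden_states"] = kwargs["previous_hidden_states"]
        return model_inputs


    assert mf_mcore_compatible("DeepSeekMTPModel")
    assert not mf_mcore_compatible("SomeUnknownModel")
    assert build_model_inputs(previous_hidden_states=[1.0])["hidden_states"] == [1.0]

As written in the patch, intermediate_tensors takes precedence, so the MTP path
only fills hidden_states when no pipeline-parallel input is present.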