From e49e9b35a86a65562a8eb5966fc931530753f4a6 Mon Sep 17 00:00:00 2001
From: fengtingyan
Date: Tue, 1 Jul 2025 22:29:46 +0800
Subject: [PATCH] [feature]Enable MLA op

---
 .../models/mf_models/deepseek_v3.py           | 38 ++++++++++++++++---
 .../mf_models/deepseekv3_weight_processor.py  | 20 +++++++---
 .../models/mf_models/mf_model_base.py         |  4 ++
 .../model_executor/models/model_base.py       |  3 --
 vllm_mindspore/v1/worker/gpu_model_runner.py  | 31 +++++++++++----
 5 files changed, 73 insertions(+), 23 deletions(-)

diff --git a/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py b/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py
index 1e2df73a..014d11c2 100644
--- a/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py
+++ b/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py
@@ -120,6 +120,21 @@ def _get_padding_index(q_seq_len):
         ms.from_numpy(ffn_unpadding_idx)
 
 
+class DeepseekV3MLAAttentionWrapper(MLAAttentionWrapper):
+    def __init__(self):
+        super().__init__()
+        vllm_config = get_current_vllm_config()
+        self.use_mla_op = bool(vllm_config.additional_config and vllm_config.additional_config.get('use_mla_op') == 1)
+        if self.use_mla_op:
+            kv_lora_rank = getattr(vllm_config.model_config.hf_text_config, 'kv_lora_rank', 0)
+            qk_rope_head_dim = getattr(vllm_config.model_config.hf_text_config, 'qk_rope_head_dim', 0)
+            # k_shape, r_shape used for mla_op
+            k_shape = [*(self.kv_shape[0:-1]), kv_lora_rank] if self.use_mla_op else None
+            r_shape = [*(self.kv_shape[0:-1]), qk_rope_head_dim] if self.use_mla_op else None
+            self.kv_cache = [(ms.mint.zeros(k_shape, dtype=vllm_config.model_config.dtype),
+                              ms.mint.zeros(r_shape, dtype=vllm_config.model_config.dtype))
+                             for _ in range(vllm_config.parallel_config.pipeline_parallel_size)]
+
 class DeepseekV3ForCausalLM(MfModelBase):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
@@ -133,7 +148,7 @@ class DeepseekV3ForCausalLM(MfModelBase):
         self.sampler = get_sampler()
         self.set_modules({"model": self.network})
 
-        self.kv_caches = [MLAAttentionWrapper() for i in range(self.mf_model_config.num_layers)]
+        self.kv_caches = [DeepseekV3MLAAttentionWrapper() for i in range(self.mf_model_config.num_layers)]
         compilation_config = get_current_vllm_config().compilation_config
 
         if prefix in compilation_config.static_forward_context:
@@ -151,6 +166,9 @@ class DeepseekV3ForCausalLM(MfModelBase):
         self.mf_model_config = DeepseekV3Config_MF(**self.mf_config.model.model_config)
         self.mf_model_config.enable_micro_batch = self.enable_micro_batch
+        self.mf_model_config.use_mla_op = self.use_mla_op
+        if self.use_mla_op:
+            assert envs.VLLM_USE_V1
 
         if self.mf_config.moe_config:
             self.mf_model_config.moe_config = self.mf_config.moe_config
         # dispatch/combine in moe need max_num_seqs as global_max_bs
@@ -174,12 +192,16 @@ class DeepseekV3ForCausalLM(MfModelBase):
         return network, network.lm_head
 
     def get_kvcache(self):
-        key_cache = []
+        kv_cache = []
         forward_context = get_forward_context()
-        for i in range(self.mf_model_config.num_layers):
-            k_cache = self.kv_caches[i].kv_cache[forward_context.virtual_engine][0]
-            key_cache.append(k_cache)
-        return mutable(key_cache), None
+        kv_cache = [self.kv_caches[i].kv_cache[forward_context.virtual_engine][0]
+                    for i in range(self.mf_model_config.num_layers)]
+        if not self.use_mla_op:
+            return mutable(kv_cache), None
+        else:
+            rope_cache = [self.kv_caches[i].kv_cache[forward_context.virtual_engine][1]
+                          for i in range(self.mf_model_config.num_layers)]
+            return mutable(kv_cache), mutable(rope_cache)
 
     def connector_send_kvcache(self):
         logger.debug(f"reached deepseek_v3 connector_send_kvcache")
@@ -216,6 +238,10 @@ class DeepseekV3ForCausalLM(MfModelBase):
         model_inputs["ffn_padding_idx"] = ffn_padding_idx
         model_inputs["ffn_unpadding_idx"] = ffn_unpadding_idx
 
+        model_inputs.pop("value_cache")
+        _, rope_cache = self.get_kvcache()
+        model_inputs["rope_cache"] = rope_cache
+
         return model_inputs, is_prefill
 
     def get_model_path(self):
diff --git a/vllm_mindspore/model_executor/models/mf_models/deepseekv3_weight_processor.py b/vllm_mindspore/model_executor/models/mf_models/deepseekv3_weight_processor.py
index c63abe69..83f5e052 100644
--- a/vllm_mindspore/model_executor/models/mf_models/deepseekv3_weight_processor.py
+++ b/vllm_mindspore/model_executor/models/mf_models/deepseekv3_weight_processor.py
@@ -973,20 +973,28 @@ class DeepseekV3WeightProcessor(BaseWeightProcessor):
         qkv2l_quant_zp_name = f"model.layers.{layer_id}.attention.qkv2l.quant_op.input_zp"
         qkv2l_quant_scale_name = f"model.layers.{layer_id}.attention.qkv2l.quant_op.input_scale"
         qkv2l_rmsnorm_beta_name = f"model.layers.{layer_id}.attention.qkv2l.quant_op.beta"
-
-        qkv2l_weight = np.concatenate((q2l_ms_param, kv2l_ms_param), 0)
+        if hasattr(self.config.model.model_config, "use_mla_op"
+                   ) and self.config.model.model_config.use_mla_op:
+            qkv2l_weight = np.concatenate((kv2l_ms_param, q2l_ms_param), 0)
+            qkv2l_bias = np.concatenate(
+                (kv2l_quant_bias_ms_param, q2l_quant_bias_ms_param), 0)
+            qkv2l_scale = np.concatenate(
+                (kv2l_dequant_scale_ms_param, q2l_dequant_scale_ms_param), 0)
+        else:
+            qkv2l_weight = np.concatenate((q2l_ms_param, kv2l_ms_param), 0)
+            qkv2l_bias = np.concatenate(
+                (q2l_quant_bias_ms_param, kv2l_quant_bias_ms_param), 0)
+            qkv2l_scale = np.concatenate(
+                (q2l_dequant_scale_ms_param, kv2l_dequant_scale_ms_param), 0)
+
         parameter_dict[qkv2l_weight_name] = ms.Parameter(
             ms.Tensor(qkv2l_weight, ms.int8),
             name=qkv2l_weight_name,
             requires_grad=False)
-        qkv2l_bias = np.concatenate(
-            (q2l_quant_bias_ms_param, kv2l_quant_bias_ms_param), 0)
         parameter_dict[qkv2l_bias_name] = ms.Parameter(
             ms.Tensor(qkv2l_bias, ms.int32),
             name=qkv2l_bias_name,
             requires_grad=False)
-        qkv2l_scale = np.concatenate(
-            (q2l_dequant_scale_ms_param, kv2l_dequant_scale_ms_param), 0)
         parameter_dict[qkv2l_scale_name] = ms.Parameter(
             ms.Tensor(qkv2l_scale, ms.float32),
             name=qkv2l_scale_name,
diff --git a/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py b/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py
index c4df0f43..d6637605 100644
--- a/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py
+++ b/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py
@@ -78,6 +78,10 @@ class MfModelBase(MsModelBase):
         self.mf_config.model.model_config.parallel_config.model_parallel = (
             get_tensor_model_parallel_world_size())
         self.mf_config.model.model_config.parallel_config.pipeline_stage = 1
+        self.enable_micro_batch = \
+            bool(vllm_config.additional_config and vllm_config.additional_config.get('enable_micro_batch') == 1)
+        self.use_mla_op = \
+            bool(vllm_config.additional_config and vllm_config.additional_config.get('use_mla_op') == 1)
         self._generate_model_config()
         self.casual_mask = LowerTriangularMask(
             dtype=self.mf_model_config.compute_dtype,
diff --git a/vllm_mindspore/model_executor/models/model_base.py b/vllm_mindspore/model_executor/models/model_base.py
index 123256da..2592455a 100644
--- a/vllm_mindspore/model_executor/models/model_base.py
+++ b/vllm_mindspore/model_executor/models/model_base.py
@@ -90,9 +90,6 @@ class MsModelBase:
         self.parallel_config = vllm_config.parallel_config
         self.load_config = vllm_config.load_config
         self.scheduler_config = vllm_config.scheduler_config
-        self.enable_micro_batch = \
-            vllm_config.additional_config.get('enable_micro_batch', 0) == 1 \
-            if vllm_config.additional_config is not None else False
 
         self.modules_dict = None
 
diff --git a/vllm_mindspore/v1/worker/gpu_model_runner.py b/vllm_mindspore/v1/worker/gpu_model_runner.py
index e9818717..da06e161 100644
--- a/vllm_mindspore/v1/worker/gpu_model_runner.py
+++ b/vllm_mindspore/v1/worker/gpu_model_runner.py
@@ -223,6 +223,9 @@ def _allocate_kv_cache_tensors(
     use_mla = kv_cache_spec.use_mla
     dtype = kv_cache_spec.dtype
     coef = 1 if use_mla else 2
+    use_mla_op = bool(self.vllm_config.additional_config and self.vllm_config.additional_config.get('use_mla_op') == 1)
+    kv_lora_rank = getattr(self.vllm_config.model_config.hf_text_config, 'kv_lora_rank', 0)
+    qk_rope_head_dim = getattr(self.vllm_config.model_config.hf_text_config, 'qk_rope_head_dim', 0)
 
     kv_cache_raw_tensors: dict[str, torch.Tensor] = {}
     target_dtype = get_valid_dtype(dtype)
@@ -239,10 +242,14 @@ def _allocate_kv_cache_tensors(
         # self.block_size * self.num_kv_heads * self.head_size *
         # get_dtype_size(self.dtype))
         # 4. kv cache shape: num_blocks, block_size, num_kv_heads, head_size
-        raw_tensor_split = torch.zeros(raw_tensor_shape,
-                                       dtype=target_dtype,
-                                       device=self.device)
-        raw_tensors.append(raw_tensor_split)
+        raw_tensors.extend(
+            [torch.zeros(raw_tensor_shape, dtype=target_dtype, device=self.device)]
+            if not use_mla_op else
+            [torch.zeros(int(raw_tensor_shape * kv_lora_rank / (kv_lora_rank + qk_rope_head_dim)),
+                         dtype=target_dtype, device=self.device),
+             torch.zeros(int(raw_tensor_shape * qk_rope_head_dim / (kv_lora_rank + qk_rope_head_dim)),
+                         dtype=target_dtype, device=self.device)]
+        )
 
         for layer_name in kv_cache_tensor.shared_by:
             kv_cache_raw_tensors[layer_name] = tuple(raw_tensors)
@@ -270,6 +277,9 @@ def _reshape_kv_cache_tensors(
         Dict[str, torch.Tensor]: A map between layer names to their
             corresponding memory buffer for KV cache.
     """
+    use_mla_op = bool(self.vllm_config.additional_config and self.vllm_config.additional_config.get('use_mla_op') == 1)
+    kv_lora_rank = getattr(self.vllm_config.model_config.hf_text_config, 'kv_lora_rank', 0)
+    qk_rope_head_dim = getattr(self.vllm_config.model_config.hf_text_config, 'qk_rope_head_dim', 0)
     kv_caches: dict[str, tuple] = {}
     for i, kv_cache_group_spec in enumerate(
             kv_cache_config.kv_cache_groups):
@@ -279,8 +289,9 @@ def _reshape_kv_cache_tensors(
             raw_tensor = kv_cache_raw_tensors[layer_name]
             target_dtype = get_valid_dtype(kv_cache_spec.dtype)
             dtype_size = get_dtype_size(target_dtype)
-            num_blocks = (raw_tensor[0].numel() * coef *
-                          dtype_size // kv_cache_spec.page_size_bytes)
+            num_blocks = \
+                (raw_tensor[0].numel() if not use_mla_op else (raw_tensor[0].numel() + raw_tensor[1].numel())) * \
+                coef * dtype_size // kv_cache_spec.page_size_bytes
             if isinstance(kv_cache_spec, FullAttentionSpec):
                 kv_cache_shape = self.attn_backends[i].get_kv_cache_shape(
                     num_blocks, kv_cache_spec.block_size,
@@ -306,8 +317,12 @@ def _reshape_kv_cache_tensors(
                         for i in range(len(kv_cache_stride_order))
                     ]
                     kv_cache_layer = []
-                    for kv_cache_raw_tensor in kv_cache_raw_tensors[layer_name]:
-                        cache_block = kv_cache_raw_tensor.view(kv_cache_shape[1:]).permute(*inv_order[1:])
+                    for idx, kv_cache_raw_tensor in enumerate(kv_cache_raw_tensors[layer_name]):
+                        if use_mla_op:
+                            cache_shape = [*(kv_cache_shape[1:-1]), kv_lora_rank if idx == 0 else qk_rope_head_dim]
+                            cache_block = kv_cache_raw_tensor.view(cache_shape).permute(*inv_order[1:])
+                        else:
+                            cache_block = kv_cache_raw_tensor.view(kv_cache_shape[1:]).permute(*inv_order[1:])
                         kv_cache_layer.append(cache_block)
                     kv_caches[layer_name] = mutable(tuple(kv_cache_layer))
                 else:
-- 
Gitee