From 7e9ed959ac752166730d62ff24419d0d3e098b41 Mon Sep 17 00:00:00 2001
From: huandong
Date: Fri, 8 Aug 2025 09:44:03 +0800
Subject: [PATCH] prefill/chunk/decode use ringmla

---
 .../models/mf_models/deepseek_v3.py           | 52 +++++++++++++++++++
 1 file changed, 52 insertions(+)

diff --git a/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py b/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py
index 1e2df73a..8fdc3c90 100644
--- a/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py
+++ b/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py
@@ -63,6 +63,14 @@ try:
 except ImportError:
     pass
 
+from mindformers.tools.utils import is_pynative
+try:
+    # Need to apply dllm pd patch on vllm to use pd disagg related functions
+    from vllm.attention.layer import maybe_save_kv_layer_to_connector, wait_for_kv_layer_from_connector
+    from vllm.distributed.kv_transfer import is_v1_kv_transfer_group
+    kv_transfer_supported = True
+except:
+    kv_transfer_supported = False
 
 logger = init_logger(__name__)
 
@@ -312,3 +320,47 @@ class DeepseekV3ForCausalLM(MfModelBase):
         ptq.layer_policies[r'.*\.shared_experts.w2.*'].aclnn_quant_list = ["w2"]
         ptq.decoder_layer_types.append(DeepseekV3DecodeLayer)
         return ptq
+
+    def forward(self,
+                input_ids: Tensor,
+                positions: Tensor,
+                intermediate_tensors: Optional[IntermediateTensors] = None,
+                inputs_embeds: Optional[Tensor] = None,
+                **kwargs) -> Union[Tensor, IntermediateTensors]:
+        model_inputs, is_prefill = self.prepare_inputs(input_ids, positions)
+        model_inputs = self.update_model_inputs(model_inputs, **kwargs)
+
+        # enable_mb_split is True in lager EP enable micro-batch and per-dp-bs > 1
+        enable_mb_split = self.is_enable_micro_batch_split(
+            is_prefill, model_inputs["q_seq_lens"])
+
+        is_only_decode = not is_prefill and model_inputs['q_seq_lens'].max() == 1
+
+        if not is_only_decode:
+            if self.enable_micro_batch:
+                self.network.phase = "prefill" if not enable_mb_split else "prefill_micro_batch"
+                if not self.set_flags or is_pynative() or enable_mb_split:
+                    self.network.add_flags_custom(is_first_iteration=True)
+                    self.network.add_flags_enable_micro_batch(
+                        enable_micro_batch=enable_mb_split)
+            else:
+                self.network.phase = "prefill"
+                if not self.set_flags or is_pynative():
+                    self.network.add_flags_custom(is_first_iteration=True)
+
+            hidden_states = self.network(**model_inputs)
+        else:
+            self.network.phase = "increment"
+            if not self.set_flags or is_pynative():
+                self.network.add_flags_custom(is_first_iteration=False)
+                self.set_flags = True
+            if kv_transfer_supported:
+                if is_v1_kv_transfer_group() and self.is_prefill_task():
+                    self.connector_send_kvcache()
+
+                if is_v1_kv_transfer_group() and self.is_decoder_task():
+                    self.connector_wait_for_kv_layer()
+                    logger.debug(f"connector_wait_for_kv_layer success")
+            hidden_states = self.network(**model_inputs)
+
+        return hidden_states
-- 
Gitee