From e2b2ec906c71628c083c726e1599ec67b7c744b5 Mon Sep 17 00:00:00 2001
From: w00613459
Date: Tue, 6 May 2025 22:47:44 +0800
Subject: [PATCH] Adapt mindspore PD (prefill/decode) disaggregation
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 vllm_mindspore/__init__.py                  |  7 +++
 vllm_mindspore/engine/arg_utils.py          |  6 ---
 .../model_loader/weight_utils.py            |  1 +
 .../models/mf_models/deepseek_v3.py         | 10 +++++
 .../models/mf_models/mf_model_base.py       | 44 ++++++++++++++++++-
 vllm_mindspore/v1/core/sched/scheduler.py   | 26 ++++++++++-
 vllm_mindspore/v1/worker/gpu_worker.py      | 11 +++--
 vllm_mindspore/worker/worker.py             |  4 +-
 8 files changed, 95 insertions(+), 14 deletions(-)

diff --git a/vllm_mindspore/__init__.py b/vllm_mindspore/__init__.py
index af97709bd..175b566ad 100644
--- a/vllm_mindspore/__init__.py
+++ b/vllm_mindspore/__init__.py
@@ -178,6 +178,13 @@ from vllm_mindspore.distributed.parallel_state import (
 vllm.distributed.parallel_state.init_model_parallel_group = init_model_parallel_group
 vllm.distributed.parallel_state.GroupCoordinator.__init__ = init_group_coordinator
 
+from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
+
+KVConnectorFactory.register_connector(
+    "DLLMDsConnector",
+    "dllm.dkvc.v1.dllm_ds_connector",
+    "DLLMDsConnector")
+
 from vllm_mindspore.executor.multiproc_worker_utils import (
     get_mp_context as ms_get_mp_context,
 )
diff --git a/vllm_mindspore/engine/arg_utils.py b/vllm_mindspore/engine/arg_utils.py
index ed74ba9e3..9efb89236 100644
--- a/vllm_mindspore/engine/arg_utils.py
+++ b/vllm_mindspore/engine/arg_utils.py
@@ -164,12 +164,6 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
                            recommend_to_remove=False)
         return False
 
-    # No Disaggregated Prefill so far.
-    if self.kv_transfer_config != EngineArgs.kv_transfer_config:
-        _raise_or_fallback(feature_name="--kv-transfer-config",
-                           recommend_to_remove=False)
-        return False
-
     # No FlashInfer or XFormers so far.
     V1_BACKENDS = [
         "FLASH_ATTN_VLLM_V1", "FLASH_ATTN", "PALLAS", "PALLAS_VLLM_V1",
diff --git a/vllm_mindspore/model_executor/model_loader/weight_utils.py b/vllm_mindspore/model_executor/model_loader/weight_utils.py
index 45fe4bdd5..680635fba 100644
--- a/vllm_mindspore/model_executor/model_loader/weight_utils.py
+++ b/vllm_mindspore/model_executor/model_loader/weight_utils.py
@@ -27,6 +27,7 @@ from mindspore import Parameter, Tensor
 
 def safetensors_weights_iterator(
     hf_weights_files: List[str],
+    enable_tqdm: bool,
 ) -> Generator[Tuple[str, torch.Tensor], None, None]:
     """Iterate over the weights in the model safetensor files."""
     from safetensors import safe_open
diff --git a/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py b/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py
index 7491aeac5..bf1730fca 100644
--- a/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py
+++ b/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py
@@ -29,10 +29,12 @@ from vllm.config import get_current_vllm_config
 from vllm.distributed.parallel_state import get_dp_group, get_tensor_model_parallel_world_size
 from vllm.forward_context import get_forward_context
 from vllm.logger import init_logger
+from vllm.attention.layer import Attention, maybe_save_kv_layer_to_connector
 
 from mindspore import Tensor, Model, mutable
 from mindspore.common import dtype as msdtype
 from mindspore.nn.utils import no_init_parameters
+from mindspore.common.api import _pynative_executor
 
 from mindspore_gs.ptq import PTQ
 from mindspore_gs.ptq import PTQMode, PTQConfig, OutliersSuppressionType, PrecisionRecovery, QuantGranularity, \
@@ -168,6 +170,14 @@ class DeepseekV3ForCausalLM(MfModelBase):
             key_cache.append(k_cache)
         return mutable(key_cache), None
 
+    def connector_send_kvcache(self):
+        _pynative_executor.sync()
+        forward_context = get_forward_context()
+        for i in range(self.mf_model_config.num_layers):
+            kv_cache_module = self.kv_caches[i]
+            kv_cache = kv_cache_module.kv_cache[forward_context.virtual_engine][0]
+            maybe_save_kv_layer_to_connector("key." + str(i), kv_cache)
+
     def load_weights(self, weights: Iterable[Tuple[str, Tensor]]) -> Set[str]:
         if self.mf_config.load_ckpt_format == "ckpt":
             model = Model(self.network)
diff --git a/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py b/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py
index 3e5dca524..dfcec651c 100644
--- a/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py
+++ b/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py
@@ -31,10 +31,12 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.sequence import IntermediateTensors
 from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed.kv_transfer import is_v1_kv_transfer_group
 from vllm.attention.backends.abstract import AttentionType
 from vllm.logger import init_logger
-from vllm.attention.layer import Attention
+from vllm.attention.layer import Attention, maybe_save_kv_layer_to_connector, wait_for_kv_layer_from_connector
 import vllm.envs as envs
+
 import torch
 import mindspore as ms
 from mindspore import Tensor, mutable
@@ -42,7 +44,7 @@ from mindspore import Tensor, mutable
 from mindformers.tools.register.config import MindFormerConfig
 from mindformers.core.context import build_mf_context
 from mindformers.core.parallel_config import build_parallel_config
-
+from mindspore.common.api import _pynative_executor
 from vllm_mindspore.model_executor.models.model_base import MsModelBase
 from vllm_mindspore.v1.attention.backends.flash_attn import FlashAttentionMetadata
 
@@ -120,6 +122,7 @@ class MfModelBase(MsModelBase):
             vllm_config=vllm_config, prefix=prefix
         )
+        self.kv_transfer_config = vllm_config.kv_transfer_config
         self.mf_config = MindFormerConfig(os.getenv("MINDFORMERS_MODEL_CONFIG"))
         build_mf_context(self.mf_config)
         build_parallel_config(self.mf_config)
@@ -159,6 +162,17 @@ class MfModelBase(MsModelBase):
             value_cache.append(v_cache)
         return mutable(key_cache), mutable(value_cache)
 
+    def is_decoder_task(self) -> bool:
+        if self.kv_transfer_config is None:
+            return False
+
+        return self.kv_transfer_config.is_kv_consumer
+
+    def is_prefill_task(self) -> bool:
+        if self.kv_transfer_config is None:
+            return False
+
+        return self.kv_transfer_config.is_kv_producer
 
     def _dummy_attention_metadata(self, input_ids: Tensor, positions: Tensor) -> FlashAttentionMetadata:
         input_len = input_ids.shape[0]
@@ -256,6 +270,24 @@ class MfModelBase(MsModelBase):
     def update_model_inputs(self, model_inputs, **kwargs):
         return model_inputs
 
+    def connector_send_kvcache(self):
+        # TODO: this can be optimized
+        _pynative_executor.sync()
+        forward_context = get_forward_context()
+        for i in range(self.mf_model_config.num_layers):
+            kv_cache = self.kv_caches[i]
+            k_cache = kv_cache.kv_cache[forward_context.virtual_engine][0]
+            v_cache = kv_cache.kv_cache[forward_context.virtual_engine][1]
+            maybe_save_kv_layer_to_connector("key." + str(i), (k_cache, v_cache))
+
+
+    def connector_wait_for_kv_layer(self):
+        logger.debug("connector_wait_for_kv_layer")
+        # TODO: this can be optimized
+        for i in range(self.mf_model_config.num_layers):
+            wait_for_kv_layer_from_connector("key." + str(i))
+
+
     def forward(
         self,
         input_ids: Tensor,
@@ -279,7 +311,15 @@ class MfModelBase(MsModelBase):
             if not self.set_flags:
                 self.network.add_flags_custom(is_first_iteration=False)
                 self.set_flags = True
+                if is_v1_kv_transfer_group():
+                    self.connector_send_kvcache()
         else:
+            if is_v1_kv_transfer_group() and self.is_prefill_task():
+                self.connector_send_kvcache()
+
+            if is_v1_kv_transfer_group() and self.is_decoder_task():
+                self.connector_wait_for_kv_layer()
+                logger.debug("connector_wait_for_kv_layer success")
             hidden_states = self.network(**model_inputs)
 
         return hidden_states
diff --git a/vllm_mindspore/v1/core/sched/scheduler.py b/vllm_mindspore/v1/core/sched/scheduler.py
index c03f34691..11419a335 100644
--- a/vllm_mindspore/v1/core/sched/scheduler.py
+++ b/vllm_mindspore/v1/core/sched/scheduler.py
@@ -112,6 +112,16 @@ def schedule(self) -> SchedulerOutput:
             # Get already-cached tokens.
             computed_blocks, num_computed_tokens = (
                 self.kv_cache_manager.get_computed_blocks(request))
+            logger.info(f"num_computed_tokens:{num_computed_tokens}, computed_blocks:{computed_blocks}")
+            # Get externally-cached tokens if using a KVConnector.
+            num_external_tokens = (
+                0 if self.connector is None else
+                self.connector.get_num_new_matched_tokens(
+                    request, num_computed_tokens))
+
+            # Total computed tokens (local + external).
+            num_computed_tokens += num_external_tokens
+            logger.debug(f"num_computed_tokens:{num_computed_tokens}")
             num_new_tokens = request.num_prompt_tokens - num_computed_tokens
             if (0 < self.scheduler_config.long_prefill_token_threshold <
                     num_new_tokens):
@@ -148,11 +158,19 @@ def schedule(self) -> SchedulerOutput:
                 continue
 
             new_blocks = self.kv_cache_manager.allocate_slots(
-                request, num_new_tokens, computed_blocks)
+                request, num_new_tokens + num_external_tokens,
+                computed_blocks)
+            logger.info(f"computed_blocks:{computed_blocks}, new_blocks:{new_blocks}")
             if new_blocks is None:
                 # The request cannot be scheduled.
                 break
 
+            if self.connector is not None:
+                self.connector.update_state_after_alloc(
+                    request,
+                    num_external_tokens,
+                )
+
             self.waiting.popleft()
             self.running.append(request)
             self.scheduled_req_ids.add(request.request_id)
@@ -285,6 +303,7 @@ def schedule(self) -> SchedulerOutput:
             resumed_from_preemption=False,
         ) for req in scheduled_running_reqs
     ]
+    logger.info(f"req_to_new_block_ids:{req_to_new_block_ids}")
    scheduler_output = SchedulerOutput(
         scheduled_new_reqs=new_reqs_data,
         scheduled_cached_reqs=resumed_reqs_data + running_reqs_data,
@@ -303,6 +322,11 @@ def schedule(self) -> SchedulerOutput:
         grammar_bitmask=None,
     )
 
+    if self.connector is not None:
+        meta = self.connector.build_connector_meta(scheduler_output)
+        logger.info(f"scheduler: new reqs: {scheduler_output.scheduled_new_reqs}, kv connector metadata: {meta}")
+        scheduler_output.kv_connector_metadata = meta
+
     # Advance the number of computed tokens for the request AFTER
     # the request is scheduled.
     # 1. The scheduler_output of the current step has to include the
diff --git a/vllm_mindspore/v1/worker/gpu_worker.py b/vllm_mindspore/v1/worker/gpu_worker.py
index 0395c3392..ffaa54369 100644
--- a/vllm_mindspore/v1/worker/gpu_worker.py
+++ b/vllm_mindspore/v1/worker/gpu_worker.py
@@ -31,9 +31,14 @@ def init_device(self):
     self.init_gpu_memory = torch.cuda.mem_get_info()[0]
 
     # Initialize the distributed environment.
-    init_worker_distributed_environment(self.parallel_config, self.rank,
-                                        self.distributed_init_method,
-                                        self.local_rank)
+    if config.kv_transfer_config is not None:
+        init_worker_distributed_environment(config, self.rank,
+                                            self.distributed_init_method,
+                                            self.local_rank)
+    else:
+        init_worker_distributed_environment(self.parallel_config, self.rank,
+                                            self.distributed_init_method,
+                                            self.local_rank)
 
     # Set random seed.
     set_random_seed(self.model_config.seed)
diff --git a/vllm_mindspore/worker/worker.py b/vllm_mindspore/worker/worker.py
index 8ce1bc91d..46feb55ab 100644
--- a/vllm_mindspore/worker/worker.py
+++ b/vllm_mindspore/worker/worker.py
@@ -20,13 +20,13 @@
 import gc
 import os
 import math
-from typing import Tuple, Optional
+from typing import Tuple, Optional, Type
 
 import torch
 
 from vllm.config import VllmConfig
+from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
 from vllm.distributed import (
-    ensure_kv_transfer_initialized,
     ensure_model_parallel_initialized,
     init_distributed_environment,
     set_custom_all_reduce,
-- 
Gitee
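
Usage sketch (not part of the patch): the diff registers a "DLLMDsConnector" with vLLM's
KVConnectorFactory and routes KV blocks through it depending on whether kv_transfer_config
marks the instance as a KV producer (prefill) or KV consumer (decode). The snippet below is
a hedged illustration of how two engine instances might select that connector; the model
name is illustrative, and it assumes KVTransferConfig and EngineArgs accept these keyword
arguments in the vLLM version being patched.

    # Hypothetical prefill/decode launch sketch for the connector registered in __init__.py.
    from vllm.config import KVTransferConfig
    from vllm.engine.arg_utils import EngineArgs

    # Prefill ("kv_producer") instance: after its forward pass, the model-side
    # connector_send_kvcache() pushes each layer's KV cache to the connector.
    prefill_args = EngineArgs(
        model="deepseek-ai/DeepSeek-V3",      # illustrative model
        kv_transfer_config=KVTransferConfig(
            kv_connector="DLLMDsConnector",   # name registered via KVConnectorFactory
            kv_role="kv_producer",
        ),
    )

    # Decode ("kv_consumer") instance: connector_wait_for_kv_layer() blocks until the
    # transferred KV layers are available before running the incremental forward pass.
    decode_args = EngineArgs(
        model="deepseek-ai/DeepSeek-V3",
        kv_transfer_config=KVTransferConfig(
            kv_connector="DLLMDsConnector",
            kv_role="kv_consumer",
        ),
    )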