diff --git a/vllm_mindspore/__init__.py b/vllm_mindspore/__init__.py
index af97709bddca61b9e1af2e49c857a36746458cbe..175b566ad34f1acefa1487f12dd1bd980c13b8d4 100644
--- a/vllm_mindspore/__init__.py
+++ b/vllm_mindspore/__init__.py
@@ -178,6 +178,13 @@ from vllm_mindspore.distributed.parallel_state import (
 vllm.distributed.parallel_state.init_model_parallel_group = init_model_parallel_group
 vllm.distributed.parallel_state.GroupCoordinator.__init__ = init_group_coordinator

+from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
+
+KVConnectorFactory.register_connector(
+    "DLLMDsConnector",
+    "dllm.dkvc.v1.dllm_ds_connector",
+    "DLLMDsConnector")
+
 from vllm_mindspore.executor.multiproc_worker_utils import (
     get_mp_context as ms_get_mp_context,
 )
diff --git a/vllm_mindspore/engine/arg_utils.py b/vllm_mindspore/engine/arg_utils.py
index ed74ba9e38d54f7e507951ea585106f833e83d6b..9efb8923686f0fa2ba5fa5554e6baabc7ddfab02 100644
--- a/vllm_mindspore/engine/arg_utils.py
+++ b/vllm_mindspore/engine/arg_utils.py
@@ -164,12 +164,6 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
                            recommend_to_remove=False)
         return False

-    # No Disaggregated Prefill so far.
-    if self.kv_transfer_config != EngineArgs.kv_transfer_config:
-        _raise_or_fallback(feature_name="--kv-transfer-config",
-                           recommend_to_remove=False)
-        return False
-
     # No FlashInfer or XFormers so far.
     V1_BACKENDS = [
         "FLASH_ATTN_VLLM_V1", "FLASH_ATTN", "PALLAS", "PALLAS_VLLM_V1",
diff --git a/vllm_mindspore/model_executor/model_loader/weight_utils.py b/vllm_mindspore/model_executor/model_loader/weight_utils.py
index 45fe4bdd5a4b256a39fe46c73bd49e8537c82a03..680635fbaf7900bb2221e8f353fe6d2ce83fddee 100644
--- a/vllm_mindspore/model_executor/model_loader/weight_utils.py
+++ b/vllm_mindspore/model_executor/model_loader/weight_utils.py
@@ -27,6 +27,7 @@ from mindspore import Parameter, Tensor

 def safetensors_weights_iterator(
     hf_weights_files: List[str],
+    enable_tqdm: bool,
 ) -> Generator[Tuple[str, torch.Tensor], None, None]:
     """Iterate over the weights in the model safetensor files."""
     from safetensors import safe_open
diff --git a/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py b/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py
index 7491aeac5ff9b43be8b129d36a11b79c6964c65c..bf1730fca3263705c62ebfd2b143134de6eaf04d 100644
--- a/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py
+++ b/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py
@@ -29,10 +29,12 @@ from vllm.config import get_current_vllm_config
 from vllm.distributed.parallel_state import get_dp_group, get_tensor_model_parallel_world_size
 from vllm.forward_context import get_forward_context
 from vllm.logger import init_logger
+from vllm.attention.layer import Attention, maybe_save_kv_layer_to_connector

 from mindspore import Tensor, Model, mutable
 from mindspore.common import dtype as msdtype
 from mindspore.nn.utils import no_init_parameters
+from mindspore.common.api import _pynative_executor

 from mindspore_gs.ptq import PTQ
 from mindspore_gs.ptq import PTQMode, PTQConfig, OutliersSuppressionType, PrecisionRecovery, QuantGranularity, \
@@ -168,6 +170,14 @@ class DeepseekV3ForCausalLM(MfModelBase):
             key_cache.append(k_cache)
         return mutable(key_cache), None

+    def connector_send_kvcache(self):
+        _pynative_executor.sync()
+        forward_context = get_forward_context()
+        for i in range(self.mf_model_config.num_layers):
+            kv_cache_module = self.kv_caches[i]
+            kv_cache = kv_cache_module.kv_cache[forward_context.virtual_engine][0]
+            maybe_save_kv_layer_to_connector("key." + str(i), kv_cache)
+
     def load_weights(self, weights: Iterable[Tuple[str, Tensor]]) -> Set[str]:
         if self.mf_config.load_ckpt_format == "ckpt":
             model = Model(self.network)
diff --git a/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py b/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py
index 3e5dca524c3cd5006e5cde3df4a7be8bd010acf9..dfcec651cd295af8fa39cfbbfa49da75ac12ddf1 100644
--- a/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py
+++ b/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py
@@ -31,10 +31,12 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.sequence import IntermediateTensors
 from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed.kv_transfer import is_v1_kv_transfer_group
 from vllm.attention.backends.abstract import AttentionType
 from vllm.logger import init_logger
-from vllm.attention.layer import Attention
+from vllm.attention.layer import Attention, maybe_save_kv_layer_to_connector, wait_for_kv_layer_from_connector
 import vllm.envs as envs
+
 import torch
 import mindspore as ms
 from mindspore import Tensor, mutable
@@ -42,7 +44,7 @@ from mindspore import Tensor, mutable

 from mindformers.tools.register.config import MindFormerConfig
 from mindformers.core.context import build_mf_context
 from mindformers.core.parallel_config import build_parallel_config
-
+from mindspore.common.api import _pynative_executor
 from vllm_mindspore.model_executor.models.model_base import MsModelBase
 from vllm_mindspore.v1.attention.backends.flash_attn import FlashAttentionMetadata
@@ -120,6 +122,7 @@ class MfModelBase(MsModelBase):
             vllm_config=vllm_config, prefix=prefix
         )

+        self.kv_transfer_config = vllm_config.kv_transfer_config
         self.mf_config = MindFormerConfig(os.getenv("MINDFORMERS_MODEL_CONFIG"))
         build_mf_context(self.mf_config)
         build_parallel_config(self.mf_config)
@@ -159,6 +162,17 @@ class MfModelBase(MsModelBase):
             value_cache.append(v_cache)
         return mutable(key_cache), mutable(value_cache)

+    def is_decoder_task(self) -> bool:
+        if self.kv_transfer_config is None:
+            return False
+
+        return self.kv_transfer_config.is_kv_consumer
+
+    def is_prefill_task(self) -> bool:
+        if self.kv_transfer_config is None:
+            return False
+
+        return self.kv_transfer_config.is_kv_producer

     def _dummy_attention_metadata(self, input_ids: Tensor, positions: Tensor) -> FlashAttentionMetadata:
         input_len = input_ids.shape[0]
@@ -256,6 +270,24 @@ class MfModelBase(MsModelBase):
     def update_model_inputs(self, model_inputs, **kwargs):
         return model_inputs

+    def connector_send_kvcache(self):
+        # TODO: this can be optimized
+        _pynative_executor.sync()
+        forward_context = get_forward_context()
+        for i in range(self.mf_model_config.num_layers):
+            kv_cache = self.kv_caches[i]
+            k_cache = kv_cache.kv_cache[forward_context.virtual_engine][0]
+            v_cache = kv_cache.kv_cache[forward_context.virtual_engine][1]
+            maybe_save_kv_layer_to_connector("key." + str(i), (k_cache, v_cache))
+
+
+    def connector_wait_for_kv_layer(self):
+        logger.debug("connector_wait_for_kv_layer")
+        # TODO: this can be optimized
+        for i in range(self.mf_model_config.num_layers):
+            wait_for_kv_layer_from_connector("key." + str(i))
+
+
     def forward(
         self,
         input_ids: Tensor,
@@ -279,7 +311,15 @@ class MfModelBase(MsModelBase):
             if not self.set_flags:
                 self.network.add_flags_custom(is_first_iteration=False)
                 self.set_flags = True
+                if is_v1_kv_transfer_group():
+                    self.connector_send_kvcache()
         else:
+            if is_v1_kv_transfer_group() and self.is_prefill_task():
+                self.connector_send_kvcache()
+
+            if is_v1_kv_transfer_group() and self.is_decoder_task():
+                self.connector_wait_for_kv_layer()
+                logger.debug("connector_wait_for_kv_layer success")
             hidden_states = self.network(**model_inputs)

         return hidden_states
diff --git a/vllm_mindspore/v1/core/sched/scheduler.py b/vllm_mindspore/v1/core/sched/scheduler.py
index c03f34691b7267f745ce230dc296d4ff540eab91..11419a335fe987c5d61fa62db2f247e9154e88c9 100644
--- a/vllm_mindspore/v1/core/sched/scheduler.py
+++ b/vllm_mindspore/v1/core/sched/scheduler.py
@@ -112,6 +112,16 @@ def schedule(self) -> SchedulerOutput:
             # Get already-cached tokens.
             computed_blocks, num_computed_tokens = (
                 self.kv_cache_manager.get_computed_blocks(request))
+            logger.info(f"num_computed_tokens:{num_computed_tokens}, computed_blocks:{computed_blocks}")
+            # Get externally-cached tokens if using a KVConnector.
+            num_external_tokens = (
+                0 if self.connector is None else
+                self.connector.get_num_new_matched_tokens(
+                    request, num_computed_tokens))
+
+            # Total computed tokens (local + external).
+            num_computed_tokens += num_external_tokens
+            logger.debug(f"num_computed_tokens:{num_computed_tokens}")
             num_new_tokens = request.num_prompt_tokens - num_computed_tokens
             if (0 < self.scheduler_config.long_prefill_token_threshold <
                     num_new_tokens):
@@ -148,11 +158,19 @@ def schedule(self) -> SchedulerOutput:
                 continue

             new_blocks = self.kv_cache_manager.allocate_slots(
-                request, num_new_tokens, computed_blocks)
+                request, num_new_tokens + num_external_tokens,
+                computed_blocks)
+            logger.info(f"computed_blocks:{computed_blocks}, new_blocks:{new_blocks}")
             if new_blocks is None:
                 # The request cannot be scheduled.
                 break

+            if self.connector is not None:
+                self.connector.update_state_after_alloc(
+                    request,
+                    num_external_tokens,
+                )
+
             self.waiting.popleft()
             self.running.append(request)
             self.scheduled_req_ids.add(request.request_id)
@@ -285,6 +303,7 @@ def schedule(self) -> SchedulerOutput:
             resumed_from_preemption=False,
         ) for req in scheduled_running_reqs
     ]
+    logger.info(f"req_to_new_block_ids:{req_to_new_block_ids}")
    scheduler_output = SchedulerOutput(
        scheduled_new_reqs=new_reqs_data,
        scheduled_cached_reqs=resumed_reqs_data + running_reqs_data,
@@ -303,6 +322,11 @@ def schedule(self) -> SchedulerOutput:
        grammar_bitmask=None,
    )

+    if self.connector is not None:
+        meta = self.connector.build_connector_meta(scheduler_output)
+        logger.info(f"scheduler: new reqs: {scheduler_output.scheduled_new_reqs}, kv connector metadata: {meta}")
+        scheduler_output.kv_connector_metadata = meta
+
    # Advance the number of computed tokens for the request AFTER
    # the request is scheduled.
    # 1. The scheduler_output of the current step has to include the
diff --git a/vllm_mindspore/v1/worker/gpu_worker.py b/vllm_mindspore/v1/worker/gpu_worker.py
index 0395c33928a2e7e42c4e3c36a12a20963a808133..ffaa54369ca6c63e0044c1e360c7483d27dc514c 100644
--- a/vllm_mindspore/v1/worker/gpu_worker.py
+++ b/vllm_mindspore/v1/worker/gpu_worker.py
@@ -31,9 +31,14 @@ def init_device(self):
        self.init_gpu_memory = torch.cuda.mem_get_info()[0]

    # Initialize the distributed environment.
-    init_worker_distributed_environment(self.parallel_config, self.rank,
-                                        self.distributed_init_method,
-                                        self.local_rank)
+    if config.kv_transfer_config is not None:
+        init_worker_distributed_environment(config, self.rank,
+                                            self.distributed_init_method,
+                                            self.local_rank)
+    else:
+        init_worker_distributed_environment(self.parallel_config, self.rank,
+                                            self.distributed_init_method,
+                                            self.local_rank)

    # Set random seed.
    set_random_seed(self.model_config.seed)
diff --git a/vllm_mindspore/worker/worker.py b/vllm_mindspore/worker/worker.py
index 8ce1bc91d511a43a83fd3c8b0e70d228b98b951b..46feb55ab60a21df78c4af5ccef63c7016830155 100644
--- a/vllm_mindspore/worker/worker.py
+++ b/vllm_mindspore/worker/worker.py
@@ -20,13 +20,13 @@
 import gc
 import os
 import math
-from typing import Tuple, Optional
+from typing import Tuple, Optional, Type

 import torch

 from vllm.config import VllmConfig
+from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
 from vllm.distributed import (
-    ensure_kv_transfer_initialized,
     ensure_model_parallel_initialized,
     init_distributed_environment,
     set_custom_all_reduce,
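
Below is a minimal usage sketch (Python, not part of the diff above) of how the connector registered in vllm_mindspore/__init__.py is selected at runtime. It is illustrative only: it assumes the external dllm package providing dllm.dkvc.v1.dllm_ds_connector is installed, the model path and MINDFORMERS_MODEL_CONFIG value are placeholders, and the exact KVTransferConfig fields may differ between vLLM versions.

    # Illustrative sketch -- paths, model name, and config fields are placeholders.
    import os
    os.environ["MINDFORMERS_MODEL_CONFIG"] = "/path/to/predict_deepseek_v3.yaml"

    import vllm_mindspore  # applies the vLLM-MindSpore patches, incl. connector registration
    from vllm import LLM
    from vllm.config import KVTransferConfig

    # Prefill ("kv_producer") instance: kv_transfer_config.is_kv_producer is True,
    # so MfModelBase.connector_send_kvcache() pushes each layer's KV after prefill.
    prefill_llm = LLM(
        model="/path/to/DeepSeek-V3",
        kv_transfer_config=KVTransferConfig(
            kv_connector="DLLMDsConnector",
            kv_role="kv_producer",
        ),
    )

    # Decode ("kv_consumer") instance: kv_transfer_config.is_kv_consumer is True,
    # so MfModelBase.connector_wait_for_kv_layer() waits for the transferred KV
    # before running the decode step.
    decode_llm = LLM(
        model="/path/to/DeepSeek-V3",
        kv_transfer_config=KVTransferConfig(
            kv_connector="DLLMDsConnector",
            kv_role="kv_consumer",
        ),
    )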