From 688b8f5e6ebcd28751a561cba4e3b2da4873b721 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=88=B4=E6=98=8A=E6=9B=8C?= Date: Thu, 15 May 2025 15:54:00 +0800 Subject: [PATCH] [Feature] V1 PD disaggregation, DLLM vllm v0.8.3 --- ...6572dd1f76b31b93be19e550790afcfb8843.patch | 3381 +++++++++++++++++ vllm_mindspore/__init__.py | 7 + vllm_mindspore/engine/arg_utils.py | 6 - .../model_loader/weight_utils.py | 3 +- .../models/mf_models/deepseek_v3.py | 17 + .../models/mf_models/mf_model_base.py | 51 +- vllm_mindspore/v1/core/sched/scheduler.py | 26 +- vllm_mindspore/v1/worker/gpu_worker.py | 14 +- vllm_mindspore/worker/worker.py | 1 - 9 files changed, 3493 insertions(+), 13 deletions(-) create mode 100644 vllm_dp/vllm_0_8_3_pr15960_pr15977_296c6572dd1f76b31b93be19e550790afcfb8843.patch diff --git a/vllm_dp/vllm_0_8_3_pr15960_pr15977_296c6572dd1f76b31b93be19e550790afcfb8843.patch b/vllm_dp/vllm_0_8_3_pr15960_pr15977_296c6572dd1f76b31b93be19e550790afcfb8843.patch new file mode 100644 index 00000000..2eaade01 --- /dev/null +++ b/vllm_dp/vllm_0_8_3_pr15960_pr15977_296c6572dd1f76b31b93be19e550790afcfb8843.patch @@ -0,0 +1,3381 @@ +#commit 296c6572dd1f76b31b93be19e550790afcfb8843 (grafted, tag: v0.8.3, origin/v0.8.3) +--- + requirements/build.txt | 2 +- + vllm/attention/layer.py | 45 ++- + vllm/compilation/compiler_interface.py | 4 +- + vllm/compilation/inductor_pass.py | 3 +- + vllm/config.py | 49 ++- + vllm/distributed/kv_transfer/__init__.py | 11 + + .../kv_transfer/kv_connector/base.py | 4 + + .../kv_transfer/kv_connector/factory.py | 52 ++- + .../kv_transfer/kv_connector/v1/__init__.py | 8 + + .../kv_transfer/kv_connector/v1/base.py | 209 ++++++++++ + .../v1/shared_storage_connector.py | 382 ++++++++++++++++++ + .../kv_transfer/kv_connector_agent.py | 76 ++++ + .../kv_transfer/kv_transfer_state.py | 70 ++++ + vllm/distributed/parallel_state.py | 39 +- + vllm/distributed/utils.py | 8 +- + vllm/engine/arg_utils.py | 56 ++- + vllm/entrypoints/cli/serve.py | 81 +++- + vllm/forward_context.py | 26 +- + vllm/utils.py | 56 ++- + vllm/v1/attention/backends/mla/common.py | 12 +- + vllm/v1/core/sched/output.py | 5 + + vllm/v1/core/sched/scheduler.py | 96 ++++- + vllm/v1/engine/core.py | 198 ++++++--- + vllm/v1/engine/core_client.py | 351 ++++++++++------ + vllm/v1/executor/multiproc_executor.py | 6 +- + vllm/v1/serial_utils.py | 4 +- + vllm/v1/utils.py | 112 +++-- + vllm/v1/worker/gpu_model_runner.py | 103 +++-- + vllm/v1/worker/gpu_worker.py | 12 +- + vllm/v1/worker/utils.py | 44 ++ + vllm/worker/model_runner.py | 5 +- + vllm/worker/worker.py | 9 +- + 32 files changed, 1763 insertions(+), 375 deletions(-) + create mode 100644 vllm/distributed/kv_transfer/kv_connector/v1/__init__.py + create mode 100644 vllm/distributed/kv_transfer/kv_connector/v1/base.py + create mode 100644 vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py + create mode 100644 vllm/distributed/kv_transfer/kv_connector_agent.py + create mode 100644 vllm/distributed/kv_transfer/kv_transfer_state.py + +diff --git a/requirements/build.txt b/requirements/build.txt +index 13d643b..1745191 100644 +--- a/requirements/build.txt ++++ b/requirements/build.txt +@@ -4,6 +4,6 @@ ninja + packaging + setuptools>=61 + setuptools-scm>=8 +-torch==2.6.0 ++torch==2.5.1 + wheel + jinja2>=3.1.6 +diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py +index dbf4723..68452f4 100644 +--- a/vllm/attention/layer.py ++++ b/vllm/attention/layer.py +@@ -10,6 +10,9 @@ import vllm.envs as envs + from vllm.attention import AttentionType + from 
vllm.attention.selector import backend_name_to_enum, get_attn_backend + from vllm.config import CacheConfig, get_current_vllm_config ++from vllm.distributed.kv_transfer import (get_kv_transfer_group, ++ has_kv_transfer_group, ++ is_v1_kv_transfer_group) + from vllm.forward_context import ForwardContext, get_forward_context + from vllm.model_executor.layers.linear import UnquantizedLinearMethod + from vllm.model_executor.layers.quantization.base_config import ( +@@ -329,17 +332,54 @@ class MultiHeadAttention(nn.Module): + return out.reshape(bsz, q_len, -1) + + ++def wait_for_kv_layer_from_connector(layer_name: str): ++ if not has_kv_transfer_group() or not is_v1_kv_transfer_group(): ++ return ++ ++ connector = get_kv_transfer_group() ++ ++ forward_context: ForwardContext = get_forward_context() ++ attn_metadata = forward_context.attn_metadata ++ if attn_metadata is None: ++ return ++ ++ connector.wait_for_layer_load(layer_name) ++ ++ ++def maybe_save_kv_layer_to_connector( ++ layer_name: str, ++ kv_cache_layer: List[torch.Tensor], ++): ++ if not has_kv_transfer_group() or not is_v1_kv_transfer_group(): ++ return ++ ++ connector = get_kv_transfer_group() ++ ++ forward_context: ForwardContext = get_forward_context() ++ attn_metadata = forward_context.attn_metadata ++ if attn_metadata is None: ++ return ++ ++ connector.save_kv_layer(layer_name, kv_cache_layer, attn_metadata) ++ ++ + def unified_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + layer_name: str, + ) -> torch.Tensor: ++ wait_for_kv_layer_from_connector(layer_name) ++ + forward_context: ForwardContext = get_forward_context() + attn_metadata = forward_context.attn_metadata + self = forward_context.no_compile_layers[layer_name] + kv_cache = self.kv_cache[forward_context.virtual_engine] +- return self.impl.forward(self, query, key, value, kv_cache, attn_metadata) ++ output = self.impl.forward(self, query, key, value, kv_cache, ++ attn_metadata) ++ ++ maybe_save_kv_layer_to_connector(layer_name, kv_cache) ++ return output + + + def unified_attention_fake( +@@ -367,6 +407,7 @@ def unified_attention_with_output( + output: torch.Tensor, + layer_name: str, + ) -> None: ++ wait_for_kv_layer_from_connector(layer_name) + forward_context: ForwardContext = get_forward_context() + attn_metadata = forward_context.attn_metadata + self = forward_context.no_compile_layers[layer_name] +@@ -379,6 +420,8 @@ def unified_attention_with_output( + attn_metadata, + output=output) + ++ maybe_save_kv_layer_to_connector(layer_name, kv_cache) ++ + + def unified_attention_with_output_fake( + query: torch.Tensor, +diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py +index 5a22cf7..73d822e 100644 +--- a/vllm/compilation/compiler_interface.py ++++ b/vllm/compilation/compiler_interface.py +@@ -14,7 +14,7 @@ import torch.fx as fx + from packaging.version import Version + + from vllm.config import VllmConfig +- ++from vllm.utils import is_torch_equal_or_newer + + class CompilerInterface: + """ +@@ -379,7 +379,7 @@ class InductorAdaptor(CompilerInterface): + manually setting up internal contexts. But we also rely on non-public + APIs which might not provide these guarantees. 
+ """ +- if Version(importlib.metadata.version('torch')) >= Version("2.6"): ++ if is_torch_equal_or_newer("2.6"): + import torch._dynamo.utils + return torch._dynamo.utils.get_metrics_context() + else: +diff --git a/vllm/compilation/inductor_pass.py b/vllm/compilation/inductor_pass.py +index 08dd8c8..348b56b 100644 +--- a/vllm/compilation/inductor_pass.py ++++ b/vllm/compilation/inductor_pass.py +@@ -10,8 +10,9 @@ from typing import Any, Callable, Dict, Optional, Union + import torch + from packaging.version import Version + from torch import fx ++from vllm.utils import is_torch_equal_or_newer + +-if Version(importlib.metadata.version('torch')) >= Version("2.6"): ++if is_torch_equal_or_newer("2.6"): + from torch._inductor.custom_graph_pass import CustomGraphPass + else: + # CustomGraphPass is not present in 2.5 or lower, import our version +diff --git a/vllm/config.py b/vllm/config.py +index bd52fc9..a384ade 100644 +--- a/vllm/config.py ++++ b/vllm/config.py +@@ -40,7 +40,7 @@ from vllm.transformers_utils.config import ( + from vllm.transformers_utils.s3_utils import S3Model + from vllm.transformers_utils.utils import is_s3, maybe_model_redirect + from vllm.utils import (GiB_bytes, LayerBlockType, cuda_device_count_stateless, +- get_cpu_memory, get_open_port, random_uuid, ++ get_cpu_memory, get_open_port, random_uuid, is_torch_equal_or_newer, + resolve_obj_by_qualname) + + if TYPE_CHECKING: +@@ -1429,16 +1429,27 @@ class LoadConfig: + class ParallelConfig: + """Configuration for the distributed execution.""" + +- pipeline_parallel_size: int = 1 # Number of pipeline parallel groups. +- tensor_parallel_size: int = 1 # Number of tensor parallel groups. +- data_parallel_size: int = 1 # Number of data parallel groups. +- data_parallel_rank: int = 0 # Rank of the data parallel group. +- # Local rank of the data parallel group, defaults to global rank. ++ pipeline_parallel_size: int = 1 ++ """Number of pipeline parallel groups.""" ++ tensor_parallel_size: int = 1 ++ """Number of tensor parallel groups.""" ++ data_parallel_size: int = 1 ++ """Number of data parallel groups. MoE layers will be sharded according to ++ the product of the tensor parallel size and data parallel size.""" ++ data_parallel_size_local: int = 1 ++ """Number of local data parallel groups.""" ++ data_parallel_rank: int = 0 ++ """Rank of the data parallel group.""" + data_parallel_rank_local: Optional[int] = None + # IP of the data parallel master. + data_parallel_master_ip: str = "127.0.0.1" +- data_parallel_master_port: int = 29500 # Port of the data parallel master. +- enable_expert_parallel: bool = False # Use EP instead of TP for MoE layers. ++ """IP of the data parallel master.""" ++ data_parallel_rpc_port: int = 29550 ++ """Port for data parallel messaging.""" ++ data_parallel_master_port: int = 29500 ++ """Port of the data parallel master.""" ++ enable_expert_parallel: bool = False ++ """Use expert parallelism instead of tensor parallelism for MoE layers.""" + + # Maximum number of multiple batches + # when load model sequentially. To avoid RAM OOM when using tensor +@@ -1475,12 +1486,16 @@ class ParallelConfig: + + # world_size is TPxPP, it affects the number of workers we create. + world_size: int = field(init=False) +- # world_size_across_dp is TPxPPxDP, it is the size of the world +- # including data parallelism. 
+- world_size_across_dp: int = field(init=False) ++ """world_size is TPxPP, it affects the number of workers we create.""" + + rank: int = 0 + ++ @property ++ def world_size_across_dp(self) -> int: ++ """world_size_across_dp is TPxPPxDP, it is the size of the world ++ including data parallelism.""" ++ return self.world_size * self.data_parallel_size ++ + def get_next_dp_init_port(self) -> int: + """ + We might need to initialize process groups in multiple +@@ -1533,16 +1548,20 @@ class ParallelConfig: + factors: list[Any] = [] + factors.append(self.pipeline_parallel_size) + factors.append(self.tensor_parallel_size) ++ factors.append(self.data_parallel_size) + return hashlib.sha256(str(factors).encode()).hexdigest() + + def __post_init__(self) -> None: + self.world_size = self.pipeline_parallel_size * \ + self.tensor_parallel_size + +- if self.data_parallel_size > 1: ++ if self.data_parallel_size_local > self.data_parallel_size: ++ raise ValueError( ++ "data_parallel_size_local must be <= data_parallel_size") ++ ++ if self.data_parallel_size > 1 or self.data_parallel_size_local == 0: + # Data parallel was specified in the engine args. + self.data_parallel_master_port = get_open_port() +- # TODO multi-node + else: + # Otherwise fall back to env vars (e.g. for offline SPMD case). + self.data_parallel_size = envs.VLLM_DP_SIZE +@@ -1551,8 +1570,6 @@ class ParallelConfig: + self.data_parallel_master_ip = envs.VLLM_DP_MASTER_IP + self.data_parallel_master_port = envs.VLLM_DP_MASTER_PORT + +- self.world_size_across_dp = self.world_size * self.data_parallel_size +- + if self.distributed_executor_backend == "external_launcher": + import os + os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" +@@ -3268,7 +3285,7 @@ class CompilationConfig(BaseModel): + # and it is not yet a priority. 
RFC here: + # https://github.com/vllm-project/vllm/issues/14703 + +- if Version(importlib.metadata.version('torch')) >= Version("2.6"): ++ if is_torch_equal_or_newer("2.6"): + KEY = 'enable_auto_functionalized_v2' + if KEY not in self.inductor_compile_config: + self.inductor_compile_config[KEY] = False +diff --git a/vllm/distributed/kv_transfer/__init__.py b/vllm/distributed/kv_transfer/__init__.py +index e69de29..ec07c6f 100644 +--- a/vllm/distributed/kv_transfer/__init__.py ++++ b/vllm/distributed/kv_transfer/__init__.py +@@ -0,0 +1,11 @@ ++# SPDX-License-Identifier: Apache-2.0 ++ ++from vllm.distributed.kv_transfer.kv_transfer_state import ( ++ ensure_kv_transfer_initialized, get_kv_transfer_group, ++ has_kv_transfer_group, is_v1_kv_transfer_group) ++ ++__all__ = [ ++ "get_kv_transfer_group", "has_kv_transfer_group", ++ "is_v1_kv_transfer_group", "ensure_kv_transfer_initialized", ++ "KVConnectorBaseType" ++] +diff --git a/vllm/distributed/kv_transfer/kv_connector/base.py b/vllm/distributed/kv_transfer/kv_connector/base.py +index 57c764b..0d1a3d4 100644 +--- a/vllm/distributed/kv_transfer/kv_connector/base.py ++++ b/vllm/distributed/kv_transfer/kv_connector/base.py +@@ -12,6 +12,7 @@ from typing import TYPE_CHECKING, List, Tuple, Union + + import torch + ++from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1 + from vllm.sequence import IntermediateTensors + + if TYPE_CHECKING: +@@ -121,3 +122,6 @@ class KVConnectorBase(ABC): + """ + + raise NotImplementedError ++ ++ ++KVConnectorBaseType = Union[KVConnectorBase, KVConnectorBase_V1] +diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py +index e37ce6d..665ea2f 100644 +--- a/vllm/distributed/kv_transfer/kv_connector/factory.py ++++ b/vllm/distributed/kv_transfer/kv_connector/factory.py +@@ -3,14 +3,22 @@ + import importlib + from typing import TYPE_CHECKING, Callable, Dict, Type + ++import vllm.envs as envs ++from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType ++from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1, ++ KVConnectorRole) ++from vllm.logger import init_logger ++ + from .base import KVConnectorBase + + if TYPE_CHECKING: + from vllm.config import VllmConfig + ++logger = init_logger(__name__) ++ + + class KVConnectorFactory: +- _registry: Dict[str, Callable[[], Type[KVConnectorBase]]] = {} ++ _registry: Dict[str, Callable[[], Type[KVConnectorBaseType]]] = {} + + @classmethod + def register_connector(cls, name: str, module_path: str, +@@ -19,22 +27,51 @@ class KVConnectorFactory: + if name in cls._registry: + raise ValueError(f"Connector '{name}' is already registered.") + +- def loader() -> Type[KVConnectorBase]: ++ def loader() -> Type[KVConnectorBaseType]: + module = importlib.import_module(module_path) + return getattr(module, class_name) + + cls._registry[name] = loader + + @classmethod +- def create_connector(cls, rank: int, local_rank: int, +- config: "VllmConfig") -> KVConnectorBase: ++ def create_connector_v0(cls, rank: int, local_rank: int, ++ config: "VllmConfig") -> KVConnectorBase: ++ if envs.VLLM_USE_V1: ++ raise ValueError("Attempting to initialize a V0 Connector, " ++ f"but found {envs.VLLM_USE_V1=}") ++ + connector_name = config.kv_transfer_config.kv_connector + if connector_name not in cls._registry: + raise ValueError(f"Unsupported connector type: {connector_name}") + + connector_cls = cls._registry[connector_name]() ++ assert issubclass(connector_cls, KVConnectorBase) + 
return connector_cls(rank, local_rank, config) + ++ @classmethod ++ def create_connector_v1( ++ cls, ++ config: "VllmConfig", ++ role: KVConnectorRole, ++ ) -> KVConnectorBase_V1: ++ if not envs.VLLM_USE_V1: ++ raise ValueError("Attempting to initialize a V1 Connector, " ++ f"but found {envs.VLLM_USE_V1=}") ++ ++ connector_name = config.kv_transfer_config.kv_connector ++ connector_cls = cls._registry[connector_name]() ++ assert issubclass(connector_cls, KVConnectorBase_V1) ++ logger.info("Creating v1 connector with name: %s", connector_name) ++ # NOTE(Kuntai): v1 connector is explicitly separated into two roles. ++ # Scheduler connector: ++ # - Co-locate with scheduler process ++ # - Should only be used inside the Scheduler class ++ # Worker connector: ++ # - Co-locate with worker process ++ # - Should only be used inside the forward context & attention layer ++ # We build separately to enforce strict separation ++ return connector_cls(config, role) ++ + + # Register various connectors here. + # The registration should not be done in each individual file, as we want to +@@ -57,4 +94,9 @@ KVConnectorFactory.register_connector( + KVConnectorFactory.register_connector( + "MooncakeStoreConnector", + "vllm.distributed.kv_transfer.kv_connector.mooncake_store_connector", +- "MooncakeStoreConnector") +\ No newline at end of file ++ "MooncakeStoreConnector") ++ ++KVConnectorFactory.register_connector( ++ "SharedStorageConnector", ++ "vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector", ++ "SharedStorageConnector") +diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/__init__.py b/vllm/distributed/kv_transfer/kv_connector/v1/__init__.py +new file mode 100644 +index 0000000..a017b14 +--- /dev/null ++++ b/vllm/distributed/kv_transfer/kv_connector/v1/__init__.py +@@ -0,0 +1,8 @@ ++# SPDX-License-Identifier: Apache-2.0 ++from vllm.distributed.kv_transfer.kv_connector.v1.base import ( ++ KVConnectorBase_V1, KVConnectorRole) ++ ++__all__ = [ ++ "KVConnectorRole", ++ "KVConnectorBase_V1", ++] +diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py +new file mode 100644 +index 0000000..95967d2 +--- /dev/null ++++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py +@@ -0,0 +1,209 @@ ++# SPDX-License-Identifier: Apache-2.0 ++""" ++KVConnectorBase_V1 Class for Distributed KV Cache & Hidden State ++communication in vLLM v1 ++ ++The class provides the following primitives: ++ Scheduler-side: runs in the scheduler, binds metadata, which ++ is used by the worker-side to load/save KV cache. ++ get_num_new_matched_tokens() - get number of new tokens ++ that exist in the remote KV cache ++ update_state_after_alloc() - update KVConnector state after ++ temporary buffer alloc by the CacheManager. ++ ++ Worker-side: runs in each worker, loads/saves KV cache to/from ++ the Connector based on the metadata. 
++ start_load_kv() - starts loading all KVs (maybe async) ++ wait_for_layer_load() - blocks until layer i load is done ++ ++ save_kv_layer() - starts saving KV for layer i (maybe async) ++ wait_for_save() - blocks until all saves are done ++""" ++ ++import enum ++from abc import ABC, abstractmethod ++from dataclasses import dataclass ++from typing import TYPE_CHECKING ++ ++import torch ++ ++from vllm.logger import init_logger ++from vllm.v1.core.sched.output import SchedulerOutput ++ ++if TYPE_CHECKING: ++ from vllm.attention.backends.abstract import AttentionMetadata ++ from vllm.config import VllmConfig ++ from vllm.forward_context import ForwardContext ++ from vllm.v1.request import Request ++ ++logger = init_logger(__name__) ++ ++ ++class KVConnectorRole(enum.Enum): ++ # Connector running in the scheduler process ++ SCHEDULER = 0 ++ ++ # Connector running in the worker process ++ WORKER = 1 ++ ++ ++@dataclass ++class KVConnectorMetadata: ++ pass ++ ++ ++class KVConnectorBase_V1(ABC): ++ ++ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): ++ logger.warning( ++ "Initializing KVConnectorBase_V1. This API is experimental and " ++ "subject to change in the future as we iterate the design.") ++ self._connector_metadata = KVConnectorMetadata() ++ self._vllm_config = vllm_config ++ self._role = role ++ ++ @property ++ def role(self) -> KVConnectorRole: ++ return self._role ++ ++ def bind_connector_metadata( ++ self, connector_metadata: KVConnectorMetadata) -> None: ++ """Set the connector metadata from the scheduler. ++ ++ This function should be called by the model runner every time ++ before the model execution. The metadata will be used for runtime ++ KV cache loading and saving. ++ ++ Args: ++ connector_metadata (dict): the connector metadata. ++ """ ++ self._connector_metadata = connector_metadata ++ ++ def clear_connector_metadata(self) -> None: ++ """Clear the connector metadata. ++ ++ This function should be called by the model runner every time ++ after the model execution. ++ """ ++ self._connector_metadata = KVConnectorMetadata() ++ ++ def _get_connector_metadata(self) -> KVConnectorMetadata: ++ """Get the connector metadata. ++ ++ This function should only be called inside the connector. ++ ++ Returns: ++ ConnectorMetadata: the connector metadata. ++ """ ++ return self._connector_metadata ++ ++ # ============================== ++ # Worker-side methods ++ # ============================== ++ ++ @abstractmethod ++ def start_load_kv(self, forward_context: "ForwardContext", ++ **kwargs) -> None: ++ """ ++ Start loading the KV cache from the connector to vLLM's paged ++ KV buffer. This is called from the forward context before the ++ forward pass to enable async loading during model execution. ++ ++ Args: ++ forward_context (ForwardContext): the forward context. ++ **kwargs: additional arguments for the load operation ++ ++ Note: ++ The number of elements in kv_caches and layer_names should be ++ the same. ++ ++ """ ++ pass ++ ++ @abstractmethod ++ def wait_for_layer_load(self, layer_name: str) -> None: ++ """ ++ Block until the KV for a specific layer is loaded into vLLM's ++ paged buffer. This is called from within attention layer to ensure ++ async copying from start_load_kv is complete. ++ ++ This interface will be useful for layer-by-layer pipelining. 
++ ++ Args: ++ layer_name: the name of that layer ++ """ ++ pass ++ ++ @abstractmethod ++ def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, ++ attn_metadata: "AttentionMetadata", **kwargs) -> None: ++ """ ++ Start saving a layer of KV cache from vLLM's paged buffer ++ to the connector. This is called from within attention layer to ++ enable async copying during execution. ++ ++ Args: ++ layer_name (str): the name of the layer. ++ kv_layer (torch.Tensor): the paged KV buffer of the current ++ layer in vLLM. ++ attn_metadata (AttentionMetadata): the attention metadata. ++ **kwargs: additional arguments for the save operation. ++ """ ++ pass ++ ++ @abstractmethod ++ def wait_for_save(self): ++ """ ++ Block until all the save operations is done. This is called ++ as the forward context exits to ensure that the async saving ++ from save_kv_layer is complete before finishing the forward. ++ ++ This prevents overwrites of paged KV buffer before saving done. ++ """ ++ pass ++ ++ # ============================== ++ # Scheduler-side methods ++ # ============================== ++ @abstractmethod ++ def get_num_new_matched_tokens( ++ self, ++ request: "Request", ++ num_computed_tokens: int, ++ ) -> int: ++ """ ++ Get number of new tokens that can be loaded from the ++ external KV cache beyond the num_computed_tokens. ++ ++ Args: ++ request (Request): the request object. ++ num_computed_tokens (int): the number of locally ++ computed tokens for this request ++ ++ Returns: ++ the number of tokens that can be loaded from the ++ external KV cache beyond what is already computed. ++ """ ++ pass ++ ++ @abstractmethod ++ def update_state_after_alloc(self, request: "Request", ++ num_external_tokens: int): ++ """ ++ Update KVConnector state after block allocation. ++ """ ++ pass ++ ++ @abstractmethod ++ def build_connector_meta( ++ self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata: ++ """ ++ Build the connector metadata for this step. ++ ++ This function should NOT modify fields in the scheduler_output. ++ Also, calling this function will reset the state of the connector. ++ ++ Args: ++ scheduler_output (SchedulerOutput): the scheduler output object. 
++ """ ++ pass +diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py +new file mode 100644 +index 0000000..1d20407 +--- /dev/null ++++ b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py +@@ -0,0 +1,382 @@ ++# SPDX-License-Identifier: Apache-2.0 ++import hashlib ++import os ++from dataclasses import dataclass ++from typing import TYPE_CHECKING ++ ++import safetensors ++import torch ++ ++from vllm.config import VllmConfig ++from vllm.distributed.kv_transfer.kv_connector.v1.base import ( ++ KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) ++from vllm.logger import init_logger ++from vllm.v1.attention.backends.mla.common import MLACommonMetadata ++from vllm.v1.core.sched.output import SchedulerOutput ++ ++if TYPE_CHECKING: ++ from vllm.attention.backends.abstract import AttentionMetadata ++ from vllm.forward_context import ForwardContext ++ from vllm.v1.request import Request ++ ++logger = init_logger(__name__) ++ ++ ++@dataclass ++class ReqMeta: ++ # Request tokens ++ token_ids: torch.Tensor ++ # Slot mappings, should have the same length as token_ids ++ slot_mapping: torch.Tensor ++ # Is store or load ++ is_store: bool ++ ++ @staticmethod ++ def make_meta(token_ids: list[int], block_ids: list[int], block_size: int, ++ is_store: bool) -> "ReqMeta": ++ valid_num_tokens = align_to_block_size(len(token_ids), block_size) ++ token_ids_tensor = torch.tensor(token_ids)[:valid_num_tokens] ++ block_ids_tensor = torch.tensor(block_ids) ++ num_blocks = block_ids_tensor.shape[0] ++ block_offsets = torch.arange(0, block_size) ++ slot_mapping = block_offsets.reshape((1, block_size)) + \ ++ block_ids_tensor.reshape((num_blocks, 1)) * block_size ++ slot_mapping = slot_mapping.flatten()[:valid_num_tokens] ++ return ReqMeta( ++ token_ids=token_ids_tensor, ++ slot_mapping=slot_mapping, ++ is_store=is_store, ++ ) ++ ++ ++@dataclass ++class SharedStorageConnectorMetadata(KVConnectorMetadata): ++ requests: list[ReqMeta] ++ ++ def __init__(self): ++ self.requests = [] ++ ++ def add_request( ++ self, ++ token_ids: list[int], ++ block_ids: list[int], ++ block_size: int, ++ is_store: bool, ++ ) -> None: ++ self.requests.append( ++ ReqMeta.make_meta(token_ids, block_ids, block_size, is_store)) ++ ++ ++class SharedStorageConnector(KVConnectorBase_V1): ++ # NOTE: This is Simple debug implementation of the KV connector. ++ # It save / load the KV cache to / from the disk. ++ # It does extra work which will overwrite the existing prefix-cache in GPU ++ # - to remove the overhead, need to add some "mask" in the ReqMeta class ++ ++ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): ++ super().__init__(vllm_config=vllm_config, role=role) ++ self._block_size = vllm_config.cache_config.block_size ++ self._requests_need_load: dict[str, Request] = {} ++ transfer_config = vllm_config.kv_transfer_config ++ self._storage_path = transfer_config.get_from_extra_config( ++ "shared_storage_path", "/tmp") ++ logger.info(vllm_config.kv_transfer_config) ++ logger.info("Shared storage path is %s", self._storage_path) ++ ++ def start_load_kv(self, forward_context: "ForwardContext", ++ **kwargs) -> None: ++ """Start loading the KV cache from the connector buffer to vLLM's ++ paged KV buffer. ++ ++ Args: ++ forward_context (ForwardContext): the forward context. 
++ **kwargs: additional arguments for the load operation ++ ++ Note: ++ The number of elements in kv_caches and layer_names should be ++ the same. ++ """ ++ attn_metadata = forward_context.attn_metadata ++ ++ def inject_kv_into_layer( ++ dst_kv_cache_layer: torch.Tensor, ++ src_kv_cache: torch.Tensor, ++ slot_mapping: torch.Tensor, ++ ) -> None: ++ """Inject the KV cache into the layer. ++ ++ Args: ++ dst_kv_cache_layer (torch.Tensor): the destination KV cache ++ layer. In shape [2, num_pages, page_size, xxx] if not ++ using MLA, [num_pages, page_size, xxx] otherwise. ++ src_kv_cache (torch.Tensor): the source KV cache. In shape ++ [2, num_tokens, xxx] if not using MLA, [num_tokens, xxx] ++ otherwise. ++ slot_mapping (torch.Tensor): the slot mapping. In shape ++ [num_tokens]. ++ """ ++ dst_kv_cache_layer_shape = dst_kv_cache_layer.shape ++ if isinstance(attn_metadata, MLACommonMetadata): ++ num_pages = dst_kv_cache_layer_shape[0] ++ page_size = dst_kv_cache_layer_shape[1] ++ dst_kv_cache_layer = dst_kv_cache_layer.reshape( ++ num_pages * page_size, -1) ++ dst_kv_cache_layer[slot_mapping, ...] = src_kv_cache ++ dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape) ++ else: ++ num_pages = dst_kv_cache_layer_shape[1] ++ page_size = dst_kv_cache_layer_shape[2] ++ dst_kv_cache_layer = dst_kv_cache_layer.reshape( ++ 2, num_pages * page_size, -1) ++ dst_kv_cache_layer[:, slot_mapping, ...] = src_kv_cache ++ dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape) ++ ++ # Get the metadata ++ metadata: KVConnectorMetadata = \ ++ self._get_connector_metadata() ++ assert isinstance(metadata, SharedStorageConnectorMetadata) ++ ++ if metadata is None: ++ logger.warning( ++ "In connector.start_load_kv, but the connector metadata is None" ++ ) ++ return ++ ++ attn_metadata = forward_context.attn_metadata ++ if attn_metadata is None: ++ logger.warning( ++ "In connector.start_load_kv, but the attn_metadata is None") ++ return ++ ++ # Load the KV for each request each layer ++ for request in metadata.requests: ++ if request.is_store: ++ continue ++ logger.info("Inject KV cache of %d tokens to the paged memory", ++ len(request.slot_mapping)) ++ for layer_name in forward_context.no_compile_layers: ++ attn_layer = forward_context.no_compile_layers[layer_name] ++ kv_cache_layer = attn_layer.kv_cache[\ ++ forward_context.virtual_engine] ++ ++ filename = self._generate_filename_debug( ++ layer_name, request.token_ids) ++ kv_cache = safetensors.torch.load_file( ++ filename)["kv_cache"].cuda() ++ inject_kv_into_layer(kv_cache_layer, kv_cache, ++ request.slot_mapping) ++ ++ def wait_for_layer_load(self, layer_name: str) -> None: ++ """Blocking until the KV for a specific layer is loaded into vLLM's ++ paged buffer. ++ ++ This interface will be useful for layer-by-layer pipelining. ++ ++ Args: ++ layer_name: the name of that layer ++ """ ++ return ++ ++ def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, ++ attn_metadata: "AttentionMetadata", **kwargs) -> None: ++ """Start saving the KV cache of the layer from vLLM's paged buffer ++ to the connector. ++ ++ Args: ++ layer_name (str): the name of the layer. ++ kv_layer (torch.Tensor): the paged KV buffer of the current ++ layer in vLLM. ++ attn_metadata (AttentionMetadata): the attention metadata. ++ **kwargs: additional arguments for the save operation. ++ """ ++ ++ def extract_kv_from_layer( ++ layer: torch.Tensor, ++ slot_mapping: torch.Tensor, ++ ) -> torch.Tensor: ++ """Extract the KV cache from the layer. 
++ ++ Assume the shape of the layer is (2, num_pages, page_size, xxx) ++ if MLA is not used, and (num_pages, page_size, xxx) otherwise. ++ """ ++ if isinstance(attn_metadata, MLACommonMetadata): ++ num_pages, page_size = layer.shape[0], layer.shape[1] ++ return layer.reshape(num_pages * page_size, -1)[slot_mapping, ++ ...] ++ num_pages, page_size = layer.shape[1], layer.shape[2] ++ return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping, ++ ...] ++ ++ connector_metadata = self._get_connector_metadata() ++ assert isinstance(connector_metadata, SharedStorageConnectorMetadata) ++ for request in connector_metadata.requests: ++ if request.is_store: ++ filename = self._generate_filename_debug( ++ layer_name, request.token_ids) ++ kv_cache = extract_kv_from_layer(kv_layer, ++ request.slot_mapping) ++ tensors = {"kv_cache": kv_cache.detach().cpu()} ++ safetensors.torch.save_file(tensors, filename) ++ ++ def wait_for_save(self): ++ return ++ ++ def get_num_new_matched_tokens( ++ self, ++ request: "Request", ++ num_computed_tokens: int, ++ ) -> int: ++ """ ++ Get number of new tokens that can be loaded from the ++ external KV cache beyond the num_computed_tokens. ++ ++ Args: ++ request (Request): the request object. ++ num_computed_tokens (int): the number of locally ++ computed tokens for this request ++ ++ Returns: ++ the number of tokens that can be loaded from the ++ external KV cache beyond what is already computed. ++ """ ++ ++ # NOTE: in this debug implementation, we assume that the prompt is ++ # cached_prompt + newly_generated_single_token ++ # Therefore, we use prompt_token_ids[:-1] to determine the folder name ++ ++ # NOTE: in current v1 scheduler, the num_computed_tokens is aligned ++ # with the block granularity. And it expects the returned blocks and ++ # num_computed_tokens to also be aligned with the block granularity. ++ if not self._found_match_for_request(request): ++ return 0 ++ ++ logger.info("External Cache Hit!") ++ ++ # Now, first num_tokens_to_check tokens are hit, we need to prepare ++ # the metadata for the worker connector to correctly load the KV ++ num_tokens_to_check = align_to_block_size( ++ len(request.prompt_token_ids) - 1, self._block_size) ++ ++ return num_tokens_to_check - num_computed_tokens ++ ++ def update_state_after_alloc(self, request: "Request", ++ num_external_tokens: int): ++ """ ++ Update KVConnector state after block allocation. ++ ++ If blocks were allocated, add to _requests_need_load, ++ such that we load the KVs in the next forward pass. ++ """ ++ if num_external_tokens > 0: ++ self._requests_need_load[request.request_id] = request ++ ++ def build_connector_meta( ++ self, ++ scheduler_output: SchedulerOutput, ++ ) -> KVConnectorMetadata: ++ """Build the connector metadata for this step. ++ ++ This function should NOT modify any fields in the scheduler_output. ++ Also, calling this function will reset the state of the connector. ++ ++ Args: ++ scheduler_output (SchedulerOutput): the scheduler output object. ++ """ ++ meta = SharedStorageConnectorMetadata() ++ ++ total_need_load = 0 ++ for new_req in scheduler_output.scheduled_new_reqs: ++ if new_req.req_id in self._requests_need_load: ++ meta.add_request(token_ids=new_req.prompt_token_ids, ++ block_ids=new_req.block_ids, ++ block_size=self._block_size, ++ is_store=False) ++ total_need_load += 1 ++ else: ++ # NOTE: here, we set the store and load being exclusive, ++ # but a single request can have both store and load. 
++ # NOTE(rob): for this debug implementation, we only cache ++ # the original prompt tokens. ++ if not self._found_match_for_request(new_req): ++ meta.add_request(token_ids=new_req.prompt_token_ids, ++ block_ids=new_req.block_ids, ++ block_size=self._block_size, ++ is_store=True) ++ ++ for cached_req in scheduler_output.scheduled_cached_reqs: ++ # NOTE(rob): here we rely on the resumed requests being ++ # the first N requests in the list scheduled_cache_reqs. ++ if not cached_req.resumed_from_preemption: ++ break ++ if cached_req.req_id in self._requests_need_load: ++ # NOTE(rob): cached_req_data does not have the full ++ # list of token ids (only new tokens). So we look it ++ # up in the actual request object. ++ request = self._requests_need_load[cached_req.req_id] ++ total_tokens = (len(cached_req.new_token_ids) + ++ cached_req.num_computed_tokens) ++ token_ids = request.all_token_ids[:total_tokens] ++ ++ # NOTE(rob): For resumed req, new_block_ids is all ++ # of the block_ids for the request. ++ block_ids = cached_req.new_block_ids ++ ++ meta.add_request(token_ids=token_ids, ++ block_ids=block_ids, ++ block_size=self._block_size, ++ is_store=False) ++ total_need_load += 1 ++ ++ assert total_need_load == len(self._requests_need_load) ++ self._requests_need_load.clear() ++ return meta ++ ++ # ============================== ++ # Helper functions ++ # ============================== ++ ++ def _found_match_for_request( ++ self, ++ request: "Request", ++ ) -> bool: ++ """Check if the cache is hit for the request. ++ """ ++ num_tokens_to_check = align_to_block_size( ++ len(request.prompt_token_ids) - 1, self._block_size) ++ foldername = self._generate_foldername_debug(torch.tensor( ++ request.prompt_token_ids)[:num_tokens_to_check], ++ create_folder=False) ++ return os.path.exists(foldername) ++ ++ def _generate_foldername_debug( ++ self, ++ input_ids: torch.Tensor, ++ create_folder=False, ++ ) -> str: ++ """Generate a folder name based on the hash of the bytes of the input ++ ids. ++ """ ++ input_ids_bytes = input_ids.numpy().tobytes() ++ input_ids_hash = hashlib.md5(input_ids_bytes).hexdigest() ++ foldername = os.path.join(self._storage_path, input_ids_hash) ++ if create_folder: ++ os.makedirs(foldername, exist_ok=True) ++ return foldername ++ ++ def _generate_filename_debug( ++ self, ++ layer_name: str, ++ input_ids: torch.Tensor, ++ ) -> str: ++ """Generate a file name based on the layer name and the hash ++ of the bytes of the input ids. ++ """ ++ foldername = self._generate_foldername_debug(input_ids, ++ create_folder=True) ++ return os.path.join(foldername, f"{layer_name}.safetensors") ++ ++ ++def align_to_block_size(num_tokens: int, block_size) -> int: ++ """Align the number of tokens to the block size. ++ """ ++ return (num_tokens - 1) // block_size * block_size +diff --git a/vllm/distributed/kv_transfer/kv_connector_agent.py b/vllm/distributed/kv_transfer/kv_connector_agent.py +new file mode 100644 +index 0000000..9d71450 +--- /dev/null ++++ b/vllm/distributed/kv_transfer/kv_connector_agent.py +@@ -0,0 +1,76 @@ ++# SPDX-License-Identifier: Apache-2.0 ++"""A centralized entrypoint to perform distributed KV cache transfer. ++ ++This implementation is a shim wrapper on two APIs exposed by `kv_connector`: ++1. `send_kv_caches_and_hidden_states` ++2. 
`recv_kv_caches_and_hidden_states ++""" ++from typing import TYPE_CHECKING, List, Tuple, Union ++ ++if TYPE_CHECKING: ++ from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata ++ from vllm.config import VllmConfig ++ ++import torch ++ ++from vllm.distributed.kv_transfer.kv_connector.factory import ( ++ KVConnectorFactory) ++from vllm.logger import init_logger ++from vllm.sequence import IntermediateTensors ++ ++logger = init_logger(__name__) ++ ++ ++class KVTransferAgent: ++ """ ++ A class designated for distributed KV transfer ++ ++ Target use cases: ++ 1. Disaggregated prefill ++ 2. Remote KV cache storage ++ """ ++ ++ def __init__( ++ self, ++ rank: int, ++ local_rank: int, ++ config: "VllmConfig", ++ ): ++ ++ self.config = config ++ ++ if config.kv_transfer_config is None: ++ raise ValueError("KVTransferConfig is not set in the VllmConfig," ++ " cannot initialize KVConnector.") ++ ++ assert self.config.kv_transfer_config.is_kv_transfer_instance, "KV"\ ++ "TransferAgent should only be used when kv_connector is set." ++ ++ self.connector = KVConnectorFactory.create_connector_v0( ++ rank, local_rank, config) ++ ++ def send_kv_caches_and_hidden_states( ++ self, ++ model_executable: torch.nn.Module, ++ model_input: "ModelInputForGPUWithSamplingMetadata", ++ kv_caches: List[torch.Tensor], ++ hidden_or_intermediate_states: Union[torch.Tensor, ++ IntermediateTensors], ++ ) -> None: ++ ++ self.connector.send_kv_caches_and_hidden_states( ++ model_executable, model_input, kv_caches, ++ hidden_or_intermediate_states) ++ ++ def close(self) -> None: ++ self.connector.close() ++ ++ def recv_kv_caches_and_hidden_states( ++ self, model_executable: torch.nn.Module, ++ model_input: "ModelInputForGPUWithSamplingMetadata", ++ kv_caches: List[torch.Tensor] ++ ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool, ++ "ModelInputForGPUWithSamplingMetadata"]: ++ ++ return self.connector.recv_kv_caches_and_hidden_states( ++ model_executable, model_input, kv_caches) +diff --git a/vllm/distributed/kv_transfer/kv_transfer_state.py b/vllm/distributed/kv_transfer/kv_transfer_state.py +new file mode 100644 +index 0000000..25d2f2c +--- /dev/null ++++ b/vllm/distributed/kv_transfer/kv_transfer_state.py +@@ -0,0 +1,70 @@ ++# SPDX-License-Identifier: Apache-2.0 ++from typing import TYPE_CHECKING, Optional ++ ++from vllm import envs ++from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType ++from vllm.distributed.kv_transfer.kv_connector.factory import ( ++ KVConnectorFactory) ++from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1, ++ KVConnectorRole) ++from vllm.distributed.parallel_state import get_world_group ++ ++if TYPE_CHECKING: ++ from vllm.config import VllmConfig ++ ++_KV_CONNECTOR_AGENT: Optional[KVConnectorBaseType] = None ++ ++ ++def get_kv_transfer_group() -> KVConnectorBaseType: ++ assert _KV_CONNECTOR_AGENT is not None, ( ++ "disaggregated KV cache transfer parallel group is not initialized") ++ return _KV_CONNECTOR_AGENT ++ ++ ++def has_kv_transfer_group() -> bool: ++ return _KV_CONNECTOR_AGENT is not None ++ ++ ++def is_v1_kv_transfer_group( ++ connector: Optional[KVConnectorBaseType] = None) -> bool: ++ """Check if the KV connector is the v1 connector. ++ If the argument is None, it will check the global KV connector ++ ++ Args: ++ connector: The KV connector to check. If None, it will check the ++ global KV connector. ++ ++ Note: ++ This function will no-longer be needed after the v1 KV connector ++ becomes the default. 
++ """ ++ if connector is None: ++ connector = _KV_CONNECTOR_AGENT ++ ++ if connector is None: ++ return False ++ ++ return isinstance(connector, KVConnectorBase_V1) ++ ++ ++def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None: ++ """ ++ Initialize KV cache transfer parallel group. ++ """ ++ ++ global _KV_CONNECTOR_AGENT ++ ++ if vllm_config.kv_transfer_config is None: ++ return ++ ++ if (vllm_config.kv_transfer_config.is_kv_transfer_instance ++ and _KV_CONNECTOR_AGENT is None): ++ if envs.VLLM_USE_V1: ++ _KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector_v1( ++ config=vllm_config, role=KVConnectorRole.WORKER) ++ else: ++ _KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector_v0( ++ rank=get_world_group().rank, ++ local_rank=get_world_group().local_rank, ++ config=vllm_config, ++ ) +diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py +index fa493fe..d0ac7e9 100644 +--- a/vllm/distributed/parallel_state.py ++++ b/vllm/distributed/parallel_state.py +@@ -29,15 +29,13 @@ from collections import namedtuple + from contextlib import contextmanager, nullcontext + from dataclasses import dataclass + from multiprocessing import shared_memory +-from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, +- Union) ++from typing import Any, Callable, Dict, List, Optional, Tuple, Union + from unittest.mock import patch + + import torch + import torch.distributed + from torch.distributed import Backend, ProcessGroup + +-import vllm.distributed.kv_transfer.kv_transfer_agent as kv_transfer + import vllm.envs as envs + from vllm.distributed.device_communicators.base_device_communicator import ( + DeviceCommunicatorBase) +@@ -46,9 +44,6 @@ from vllm.logger import init_logger + from vllm.utils import (direct_register_custom_op, resolve_obj_by_qualname, + supports_custom_op) + +-if TYPE_CHECKING: +- from vllm.config import VllmConfig +- + + @dataclass + class GraphCaptureContext: +@@ -194,9 +189,11 @@ class GroupCoordinator: + + from vllm.platforms import current_platform + +- # TODO: fix it for other platforms + if current_platform.is_cuda_alike(): + self.device = torch.device(f"cuda:{local_rank}") ++ elif current_platform.is_out_of_tree(): ++ self.device = torch.device( ++ f"{current_platform.device_name}:{local_rank}") + else: + self.device = torch.device("cpu") + +@@ -770,14 +767,6 @@ def get_pp_group() -> GroupCoordinator: + # kept for backward compatibility + get_pipeline_model_parallel_group = get_pp_group + +-_KV_TRANSFER: Optional[kv_transfer.KVTransferAgent] = None +- +- +-def get_kv_transfer_group() -> kv_transfer.KVTransferAgent: +- assert _KV_TRANSFER is not None, ( +- "disaggregated KV cache transfer parallel group is not initialized") +- return _KV_TRANSFER +- + + @contextmanager + def graph_capture(device: torch.device): +@@ -960,26 +949,6 @@ def initialize_model_parallel( + _DP.rank_in_group, _PP.rank_in_group, _TP.rank_in_group) + + +-def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None: +- """ +- Initialize KV cache transfer parallel group. 
+- """ +- +- global _KV_TRANSFER +- +- if vllm_config.kv_transfer_config is None: +- return +- +- if all([ +- vllm_config.kv_transfer_config.is_kv_transfer_instance, +- _KV_TRANSFER is None +- ]): +- _KV_TRANSFER = kv_transfer.KVTransferAgent( +- rank=get_world_group().rank, +- local_rank=get_world_group().local_rank, +- config=vllm_config) +- +- + def ensure_model_parallel_initialized( + tensor_model_parallel_size: int, + pipeline_model_parallel_size: int, +diff --git a/vllm/distributed/utils.py b/vllm/distributed/utils.py +index cae1a25..2bdcdf4 100644 +--- a/vllm/distributed/utils.py ++++ b/vllm/distributed/utils.py +@@ -21,6 +21,7 @@ from torch.distributed.rendezvous import rendezvous + + import vllm.envs as envs + from vllm.logger import init_logger ++from vllm.utils import get_tcp_uri + + logger = init_logger(__name__) + +@@ -282,7 +283,7 @@ def stateless_init_torch_distributed_process_group( + always formed with process 1, 2, ..., 8, and the additional communication + channel is formed with process 9 and 10. + """ +- init_method = f"tcp://{host}:{port}" ++ init_method = get_tcp_uri(host, port) + backend = Backend(backend) # it is basically string + timeout = _get_default_timeout(backend) + +@@ -301,6 +302,9 @@ def stateless_init_torch_distributed_process_group( + prefix_store, + group_rank, + group_size, ++ ProcessGroup.Options( ++ backend=backend ++ ) + ) + + if backend == "gloo": +@@ -325,7 +329,7 @@ def stateless_init_torch_distributed_process_group( + else: + raise RuntimeError(f"Unsupported torch distributed backend: {backend}") + +- pg._set_default_backend(backend_type) ++ #pg._set_default_backend(backend_type) + backend_class._set_sequence_number_for_group() + + pg._register_backend(device, backend_type, backend_class) +diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py +index 89c9b67..1e3969b 100644 +--- a/vllm/engine/arg_utils.py ++++ b/vllm/engine/arg_utils.py +@@ -113,11 +113,14 @@ class EngineArgs: + distributed_executor_backend: Optional[Union[str, + Type[ExecutorBase]]] = None + # number of P/D disaggregation (or other disaggregation) workers +- pipeline_parallel_size: int = 1 +- tensor_parallel_size: int = 1 +- data_parallel_size: int = 1 +- enable_expert_parallel: bool = False +- max_parallel_loading_workers: Optional[int] = None ++ pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size ++ tensor_parallel_size: int = ParallelConfig.tensor_parallel_size ++ data_parallel_size: int = ParallelConfig.data_parallel_size ++ data_parallel_size_local: Optional[int] = None ++ data_parallel_address: Optional[str] = None ++ data_parallel_rpc_port: Optional[int] = None ++ enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel ++ max_parallel_loading_workers: Optional[int] = ParallelConfig.max_parallel_loading_workers + block_size: Optional[int] = None + enable_prefix_caching: Optional[bool] = None + prefix_caching_hash_algo: str = "builtin" +@@ -126,7 +129,7 @@ class EngineArgs: + use_v2_block_manager: bool = True + swap_space: float = 4 # GiB + cpu_offload_gb: float = 0 # GiB +- gpu_memory_utilization: float = 0.90 ++ gpu_memory_utilization: float = 0.80 + max_num_batched_tokens: Optional[int] = None + max_num_partial_prefills: Optional[int] = 1 + max_long_partial_prefills: Optional[int] = 1 +@@ -434,6 +437,21 @@ class EngineArgs: + 'MoE layers will be sharded according to the ' + 'product of the tensor-parallel-size and ' + 'data-parallel-size.') ++ parser.add_argument('--data-parallel-size-local', ++ '-dpl', ++ type=int, ++ help='Number 
of data parallel replicas ' ++ 'to run on this node.') ++ parser.add_argument('--data-parallel-address', ++ '-dpa', ++ type=str, ++ help='Address of data parallel cluster ' ++ 'head-node.') ++ parser.add_argument('--data-parallel-rpc-port', ++ '-dpp', ++ type=int, ++ help='Port for data parallel RPC ' ++ 'communication.') + parser.add_argument( + '--enable-expert-parallel', + action='store_true', +@@ -1186,10 +1204,30 @@ class EngineArgs: + # but we should not do this here. + placement_group = ray.util.get_current_placement_group() + ++ # Local DP size defaults to global DP size if not set. ++ data_parallel_size_local = self.data_parallel_size if ( ++ self.data_parallel_size_local ++ is None) else self.data_parallel_size_local ++ ++ # DP address, used in multi-node case for torch distributed group ++ # and ZMQ sockets. ++ data_parallel_address = self.data_parallel_address if ( ++ self.data_parallel_address ++ is not None) else ParallelConfig.data_parallel_master_ip ++ ++ # This port is only used when there are remote data parallel engines, ++ # otherwise the local IPC transport is used. ++ data_parallel_rpc_port = self.data_parallel_rpc_port if ( ++ self.data_parallel_rpc_port ++ is not None) else ParallelConfig.data_parallel_rpc_port ++ + parallel_config = ParallelConfig( + pipeline_parallel_size=self.pipeline_parallel_size, + tensor_parallel_size=self.tensor_parallel_size, + data_parallel_size=self.data_parallel_size, ++ data_parallel_size_local=data_parallel_size_local, ++ data_parallel_master_ip=data_parallel_address, ++ data_parallel_rpc_port=data_parallel_rpc_port, + enable_expert_parallel=self.enable_expert_parallel, + max_parallel_loading_workers=self.max_parallel_loading_workers, + disable_custom_all_reduce=self.disable_custom_all_reduce, +@@ -1487,12 +1525,6 @@ class EngineArgs: + recommend_to_remove=False) + return False + +- # No Disaggregated Prefill so far. +- if self.kv_transfer_config != EngineArgs.kv_transfer_config: +- _raise_or_fallback(feature_name="--kv-transfer-config", +- recommend_to_remove=False) +- return False +- + # No FlashInfer or XFormers so far. 
+ V1_BACKENDS = [ + "FLASH_ATTN_VLLM_V1", "FLASH_ATTN", "PALLAS", "PALLAS_VLLM_V1", +diff --git a/vllm/entrypoints/cli/serve.py b/vllm/entrypoints/cli/serve.py +index e89ac4e..ffcc2bb 100644 +--- a/vllm/entrypoints/cli/serve.py ++++ b/vllm/entrypoints/cli/serve.py +@@ -1,14 +1,24 @@ + # SPDX-License-Identifier: Apache-2.0 + + import argparse ++import signal + + import uvloop + ++import vllm.envs as envs ++from vllm import AsyncEngineArgs + from vllm.entrypoints.cli.types import CLISubcommand + from vllm.entrypoints.openai.api_server import run_server + from vllm.entrypoints.openai.cli_args import (make_arg_parser, + validate_parsed_serve_args) +-from vllm.utils import FlexibleArgumentParser ++from vllm.logger import init_logger ++from vllm.usage.usage_lib import UsageContext ++from vllm.utils import FlexibleArgumentParser, get_tcp_uri ++from vllm.v1.engine.core import EngineCoreProc ++from vllm.v1.engine.core_client import CoreEngineProcManager ++from vllm.v1.executor.abstract import Executor ++ ++logger = init_logger(__name__) + + + class ServeSubcommand(CLISubcommand): +@@ -24,7 +34,10 @@ class ServeSubcommand(CLISubcommand): + if hasattr(args, 'model_tag') and args.model_tag is not None: + args.model = args.model_tag + +- uvloop.run(run_server(args)) ++ if args.headless: ++ run_headless(args) ++ else: ++ uvloop.run(run_server(args)) + + def validate(self, args: argparse.Namespace) -> None: + validate_parsed_serve_args(args) +@@ -41,6 +54,18 @@ class ServeSubcommand(CLISubcommand): + nargs='?', + help="The model tag to serve " + "(optional if specified in config)") ++ serve_parser.add_argument( ++ "--headless", ++ action='store_true', ++ default=False, ++ help="Run in headless mode. See multi-node data parallel " ++ "documentation for more details.") ++ serve_parser.add_argument( ++ '--data-parallel-start-rank', ++ '-dpr', ++ type=int, ++ default=0, ++ help='Starting data parallel rank for secondary nodes.') + serve_parser.add_argument( + "--config", + type=str, +@@ -56,3 +81,55 @@ class ServeSubcommand(CLISubcommand): + + def cmd_init() -> list[CLISubcommand]: + return [ServeSubcommand()] ++ ++ ++def run_headless(args: argparse.Namespace): ++ ++ # Create the EngineConfig. ++ engine_args = AsyncEngineArgs.from_cli_args(args) ++ usage_context = UsageContext.OPENAI_API_SERVER ++ vllm_config = engine_args.create_engine_config(usage_context=usage_context) ++ ++ if not envs.VLLM_USE_V1: ++ raise RuntimeError("Headless mode is only supported for V1") ++ ++ parallel_config = vllm_config.parallel_config ++ local_engine_count = parallel_config.data_parallel_size_local ++ host = parallel_config.data_parallel_master_ip ++ port = engine_args.data_parallel_rpc_port # add to config too ++ input_address = get_tcp_uri(host, port) ++ ++ if local_engine_count <= 0: ++ raise RuntimeError("data_parallel_size_local must be > 0 in " ++ "headless mode") ++ ++ # Catch SIGTERM and SIGINT to allow graceful shutdown. ++ def signal_handler(signum, frame): ++ logger.debug("Received %d signal.", signum) ++ raise SystemExit ++ ++ signal.signal(signal.SIGTERM, signal_handler) ++ signal.signal(signal.SIGINT, signal_handler) ++ ++ logger.info( ++ "Launching %d data parallel engine(s) in headless mode, " ++ "with head node address %s.", local_engine_count, input_address) ++ ++ # Create the engines. 
++ engine_manager = CoreEngineProcManager( ++ target_fn=EngineCoreProc.run_engine_core, ++ local_engine_count=local_engine_count, ++ start_index=args.data_parallel_start_rank, ++ local_start_index=0, ++ vllm_config=vllm_config, ++ on_head_node=False, ++ input_address=input_address, ++ executor_class=Executor.get_class(vllm_config), ++ log_stats=not engine_args.disable_log_stats, ++ ) ++ ++ try: ++ engine_manager.join_first() ++ finally: ++ logger.info("Shutting down.") ++ engine_manager.close() +diff --git a/vllm/forward_context.py b/vllm/forward_context.py +index e195a03..178d197 100644 +--- a/vllm/forward_context.py ++++ b/vllm/forward_context.py +@@ -11,6 +11,10 @@ import torch.distributed as dist + + import vllm.envs as envs + from vllm.config import VllmConfig ++from vllm.distributed.kv_transfer import (get_kv_transfer_group, ++ has_kv_transfer_group, ++ is_v1_kv_transfer_group) ++from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1 + from vllm.logger import init_logger + + if TYPE_CHECKING: +@@ -77,7 +81,8 @@ def set_forward_context(attn_metadata: Any, + attn_metadata.num_decode_tokens + else: + # for v1 attention backends +- batchsize = attn_metadata.num_input_tokens ++ # batchsize = attn_metadata.num_input_tokens ++ batchsize = len(attn_metadata.seq_lens) + else: + batchsize = num_tokens + num_tokens_across_dp = [0] * dp_size +@@ -98,6 +103,17 @@ def set_forward_context(attn_metadata: Any, + virtual_engine=virtual_engine, + attn_metadata=attn_metadata, + dp_metadata=dp_metadata) ++ ++ # KVConnector: trigger (possibly async) load before forward. ++ # Each attn layer will block until the reading is complete. ++ trigger_kv_transfer = (attn_metadata is not None ++ and has_kv_transfer_group() ++ and is_v1_kv_transfer_group()) ++ if trigger_kv_transfer: ++ kv_connector = get_kv_transfer_group() ++ assert isinstance(kv_connector, KVConnectorBase_V1) ++ kv_connector.start_load_kv(_forward_context) ++ + try: + yield + finally: +@@ -133,4 +149,12 @@ def set_forward_context(attn_metadata: Any, + logger.info(("Batchsize forward time stats " + "(batchsize, count, median_time(ms)): %s"), + forward_stats) ++ ++ # KVConnector: each attn layer triggers (possibly async) save. ++ # Ensure all those operations complete before forward() is done. 
++ if trigger_kv_transfer: ++ kv_connector = get_kv_transfer_group() ++ assert isinstance(kv_connector, KVConnectorBase_V1) ++ kv_connector.wait_for_save() ++ + _forward_context = prev_context +diff --git a/vllm/utils.py b/vllm/utils.py +index 5f32f8c..b324bbd 100644 +--- a/vllm/utils.py ++++ b/vllm/utils.py +@@ -53,6 +53,7 @@ import torch.types + import yaml + import zmq + import zmq.asyncio ++from packaging import version + from packaging.version import Version + from torch.library import Library + from typing_extensions import Never, ParamSpec, TypeIs, assert_never +@@ -551,6 +552,10 @@ def is_valid_ipv6_address(address: str) -> bool: + + + def get_distributed_init_method(ip: str, port: int) -> str: ++ return get_tcp_uri(ip, port) ++ ++ ++def get_tcp_uri(ip: str, port: int) -> str: + # Brackets are not permitted in ipv4 addresses, + # see https://github.com/python/cpython/issues/103848 + return f"tcp://[{ip}]:{port}" if ":" in ip else f"tcp://{ip}:{port}" +@@ -2189,6 +2194,8 @@ def make_zmq_socket( + ctx: Union[zmq.asyncio.Context, zmq.Context], # type: ignore[name-defined] + path: str, + socket_type: Any, ++ bind: Optional[bool] = None, ++ identity: Optional[bytes] = None, + ) -> Union[zmq.Socket, zmq.asyncio.Socket]: # type: ignore[name-defined] + """Make a ZMQ socket with the proper bind/connect semantics.""" + +@@ -2207,16 +2214,24 @@ def make_zmq_socket( + else: + buf_size = -1 # Use system default buffer size + +- if socket_type == zmq.constants.PULL: +- socket.setsockopt(zmq.constants.RCVHWM, 0) +- socket.setsockopt(zmq.constants.RCVBUF, buf_size) ++ if bind is None: ++ bind = socket_type != zmq.PUSH ++ ++ if socket_type in (zmq.PULL, zmq.DEALER, zmq.ROUTER): ++ socket.setsockopt(zmq.RCVHWM, 0) ++ socket.setsockopt(zmq.RCVBUF, buf_size) ++ ++ if socket_type in (zmq.PUSH, zmq.DEALER, zmq.ROUTER): ++ socket.setsockopt(zmq.SNDHWM, 0) ++ socket.setsockopt(zmq.SNDBUF, buf_size) ++ ++ if identity is not None: ++ socket.setsockopt(zmq.IDENTITY, identity) ++ ++ if bind: + socket.bind(path) +- elif socket_type == zmq.constants.PUSH: +- socket.setsockopt(zmq.constants.SNDHWM, 0) +- socket.setsockopt(zmq.constants.SNDBUF, buf_size) +- socket.connect(path) + else: +- raise ValueError(f"Unknown Socket Type: {socket_type}") ++ socket.connect(path) + + return socket + +@@ -2225,14 +2240,19 @@ def make_zmq_socket( + def zmq_socket_ctx( + path: str, + socket_type: Any, ++ bind: Optional[bool] = None, + linger: int = 0, ++ identity: Optional[bytes] = None, + ) -> Iterator[zmq.Socket]: + """Context manager for a ZMQ socket""" + + ctx = zmq.Context() # type: ignore[attr-defined] + try: +- yield make_zmq_socket(ctx, path, socket_type) +- ++ yield make_zmq_socket(ctx, ++ path, ++ socket_type, ++ bind=bind, ++ identity=identity) + except KeyboardInterrupt: + logger.debug("Got Keyboard Interrupt.") + +@@ -2564,3 +2584,19 @@ def sha256(input) -> int: + input_bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL) + return int.from_bytes(hashlib.sha256(input_bytes).digest(), + byteorder="big") ++ ++def is_torch_equal_or_newer(target: str) -> bool: ++ """Check if the installed torch version is >= the target version. ++ ++ Args: ++ target: a version string, like "2.6.0". ++ ++ Returns: ++ Whether the condition meets. ++ """ ++ try: ++ torch_version = version.parse(str(torch.__version__)) ++ return torch_version >= version.parse(target) ++ except Exception: ++ # Fallback to PKG-INFO to load the package info, needed by the doc gen. 
++ return Version(importlib.metadata.version('torch')) >= Version(target) +\ No newline at end of file +diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py +index 1437db7..7b95360 100644 +--- a/vllm/v1/attention/backends/mla/common.py ++++ b/vllm/v1/attention/backends/mla/common.py +@@ -195,7 +195,6 @@ from vllm import _custom_ops as ops + from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer, + AttentionMetadata, + MLAAttentionImpl) +-from vllm.attention.ops.triton_merge_attn_states import merge_attn_states + from vllm.logger import init_logger + from vllm.model_executor.layers.linear import (ColumnParallelLinear, + LinearBase, RowParallelLinear, +@@ -204,12 +203,21 @@ from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding + from vllm.platforms import current_platform + from vllm.utils import cdiv, round_down + from vllm.vllm_flash_attn.fa_utils import get_flash_attn_version ++from vllm.triton_utils import HAS_TRITON ++ ++if HAS_TRITON: ++ from vllm.attention.ops.triton_merge_attn_states import merge_attn_states ++else: ++ merge_attn_states = None + + try: + from vllm.vllm_flash_attn import flash_attn_varlen_func + except ImportError: + # For rocm use upstream flash attention +- from flash_attn import flash_attn_varlen_func ++ try: ++ from flash_attn import flash_attn_varlen_func ++ except ImportError: ++ flash_attn_varlen_func = None + + if TYPE_CHECKING: + from vllm.v1.core.sched.output import SchedulerOutput +diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py +index dc0d2d5..1d3f1f4 100644 +--- a/vllm/v1/core/sched/output.py ++++ b/vllm/v1/core/sched/output.py +@@ -9,6 +9,8 @@ if TYPE_CHECKING: + import numpy as np + import numpy.typing as npt + ++ from vllm.distributed.kv_transfer.kv_connector.v1.base import ( ++ KVConnectorMetadata) + from vllm.lora.request import LoRARequest + from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange + from vllm.sampling_params import SamplingParams +@@ -121,3 +123,6 @@ class SchedulerOutput: + structured_output_request_ids: dict[str, int] + # the bitmask for the whole batch + grammar_bitmask: Optional[npt.NDArray[np.int32]] ++ ++ # KV Cache Connector metadata. 
++ kv_connector_metadata: Optional[KVConnectorMetadata] = None +diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py +index 81f8ad2..7e658d1 100644 +--- a/vllm/v1/core/sched/scheduler.py ++++ b/vllm/v1/core/sched/scheduler.py +@@ -7,7 +7,10 @@ from collections import deque + from collections.abc import Iterable + from typing import Optional, Union + +-from vllm.config import CacheConfig, LoRAConfig, ModelConfig, SchedulerConfig ++from vllm.config import VllmConfig ++from vllm.distributed.kv_transfer.kv_connector.factory import ( ++ KVConnectorFactory) ++from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole + from vllm.logger import init_logger + from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry + from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager, +@@ -33,19 +36,17 @@ class Scheduler(SchedulerInterface): + + def __init__( + self, +- scheduler_config: SchedulerConfig, +- model_config: ModelConfig, +- cache_config: CacheConfig, +- lora_config: Optional[LoRAConfig], ++ vllm_config: VllmConfig, + kv_cache_config: KVCacheConfig, + structured_output_manager: StructuredOutputManager, + mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, + include_finished_set: bool = False, + log_stats: bool = False, + ) -> None: +- self.scheduler_config = scheduler_config +- self.cache_config = cache_config +- self.lora_config = lora_config ++ self.vllm_config = vllm_config ++ self.scheduler_config = vllm_config.scheduler_config ++ self.cache_config = vllm_config.cache_config ++ self.lora_config = vllm_config.lora_config + self.kv_cache_config = kv_cache_config + self.log_stats = log_stats + self.structured_output_manager = structured_output_manager +@@ -62,11 +63,22 @@ class Scheduler(SchedulerInterface): + self.scheduler_config.max_num_batched_tokens + self.max_model_len = self.scheduler_config.max_model_len + ++ # Create KVConnector for the Scheduler. Note that each Worker ++ # will have a corresponding KVConnector with Role=WORKER. ++ # KV Connector pushes/pull of remote KVs for P/D and offloading. ++ self.connector = None ++ if self.vllm_config.kv_transfer_config is not None: ++ self.connector = KVConnectorFactory.create_connector_v1( ++ config=self.vllm_config, role=KVConnectorRole.SCHEDULER) ++ ++ num_gpu_blocks = self.cache_config.num_gpu_blocks ++ assert num_gpu_blocks is not None and num_gpu_blocks > 0 ++ + # Create the KV cache manager. + self.kv_cache_manager = KVCacheManager( + kv_cache_config=kv_cache_config, + max_model_len=self.max_model_len, +- enable_caching=cache_config.enable_prefix_caching, ++ enable_caching=self.cache_config.enable_prefix_caching, + caching_hash_algo=self.cache_config.prefix_caching_hash_algo, + log_stats=self.log_stats) + self.block_size = self.cache_config.block_size +@@ -97,8 +109,8 @@ class Scheduler(SchedulerInterface): + # This can be changed when we make encoder cache for embedding caching + # across requests. 
+ encoder_compute_budget, encoder_cache_size = compute_encoder_budget( +- model_config=model_config, +- scheduler_config=scheduler_config, ++ model_config=vllm_config.model_config, ++ scheduler_config=vllm_config.scheduler_config, + mm_registry=mm_registry, + ) + +@@ -112,6 +124,12 @@ class Scheduler(SchedulerInterface): + self.encoder_cache_manager = EncoderCacheManager( + cache_size=encoder_cache_size) + ++ self.num_lookahead_tokens = 0 ++ speculative_config = vllm_config.speculative_config ++ if speculative_config and speculative_config.method == "eagle": ++ self.num_lookahead_tokens = \ ++ speculative_config.num_speculative_tokens ++ + def schedule(self) -> SchedulerOutput: + # NOTE(woosuk) on the scheduling algorithm: + # There's no "decoding phase" nor "prefill phase" in the scheduler. +@@ -188,7 +206,9 @@ class Scheduler(SchedulerInterface): + + while True: + new_blocks = self.kv_cache_manager.allocate_slots( +- request, num_new_tokens) ++ request, ++ num_new_tokens, ++ num_lookahead_tokens=self.num_lookahead_tokens) + if new_blocks is None: + # The request cannot be scheduled. + # Preempt the lowest-priority request. +@@ -295,6 +315,16 @@ class Scheduler(SchedulerInterface): + # Get already-cached tokens. + computed_blocks, num_computed_tokens = \ + self.kv_cache_manager.get_computed_blocks(request) ++ ++ # Get externally-cached tokens if using a KVConnector. ++ num_external_tokens = ( ++ 0 if self.connector is None else ++ self.connector.get_num_new_matched_tokens( ++ request, num_computed_tokens)) ++ ++ # Total computed tokens (local + external). ++ num_computed_tokens += num_external_tokens ++ + # Number of tokens to be scheduled. + # We use `request.num_tokens` instead of + # `request.num_prompt_tokens` to consider the resumed requests, +@@ -321,11 +351,21 @@ class Scheduler(SchedulerInterface): + new_encoder_budget = encoder_budget + + new_blocks = self.kv_cache_manager.allocate_slots( +- request, num_new_tokens, computed_blocks) ++ request, num_new_tokens + num_external_tokens, ++ computed_blocks) + if new_blocks is None: + # The request cannot be scheduled. + break + ++ # KVConnector: update internal state after allocation. ++ # This information is used to determine if a load is ++ # needed for this request. ++ if self.connector is not None: ++ self.connector.update_state_after_alloc( ++ request, ++ num_external_tokens, ++ ) ++ + self.waiting.popleft() + if request.use_structured_output: + structured_output_request_ids[ +@@ -434,6 +474,14 @@ class Scheduler(SchedulerInterface): + grammar_bitmask=grammar_bitmask, + ) + ++ # NOTE(Kuntai): this function is designed for multiple purposes: ++ # 1. Plan the KV cache store ++ # 2. Wrap up all the KV cache load / save ops into an opaque object ++ # 3. Clear the internal states of the connector ++ if self.connector is not None: ++ meta = self.connector.build_connector_meta(scheduler_output) ++ scheduler_output.kv_connector_metadata = meta ++ + # Advance the number of computed tokens for the request AFTER + # the request is scheduled. + # 1. The scheduler_output of the current step has to include the +@@ -499,14 +547,17 @@ class Scheduler(SchedulerInterface): + If an encoder input cannot be scheduled due to cache or budget + limitations, the method adjusts `num_new_tokens` to schedule only the + decoder tokens up to just before the unschedulable encoder input. ++ ++ Note that num_computed_tokens includes both locally cached ++ blocks and externally cached blocks (via KVConnector). 
+ """ + encoder_inputs_to_schedule: list[int] = [] + mm_positions = request.mm_positions + assert mm_positions is not None + assert len(mm_positions) > 0 + for i, pos_info in enumerate(mm_positions): +- start_pos = pos_info["offset"] +- num_encoder_tokens = pos_info["length"] ++ start_pos = pos_info.offset ++ num_encoder_tokens = pos_info.length + + # The encoder output is needed if the two ranges overlap: + # [num_computed_tokens, num_computed_tokens + num_new_tokens) and +@@ -522,6 +573,17 @@ class Scheduler(SchedulerInterface): + if self.encoder_cache_manager.has_cache(request, i): + # The encoder input is already computed and cached. + continue ++ ++ # If no encoder input chunking is allowed, we do not want to ++ # partially schedule a multimodal item. If the scheduled range would ++ # only cover part of the mm input, roll back to before the mm item. ++ if (self.scheduler_config.disable_chunked_mm_input ++ and num_computed_tokens < start_pos ++ and (num_computed_tokens + num_new_tokens) ++ < (start_pos + num_encoder_tokens)): ++ num_new_tokens = start_pos - num_computed_tokens ++ break ++ + if (not self.encoder_cache_manager.can_allocate(request, i) + or num_encoder_tokens > encoder_budget): + # The encoder cache is full or the encoder budget is exhausted. +@@ -596,8 +658,8 @@ class Scheduler(SchedulerInterface): + if cached_encoder_input_ids: + for input_id in list(cached_encoder_input_ids): + mm_positions = request.mm_positions[input_id] +- start_pos = mm_positions["offset"] +- num_tokens = mm_positions["length"] ++ start_pos = mm_positions.offset ++ num_tokens = mm_positions.length + if start_pos + num_tokens <= request.num_computed_tokens: + # The encoder output is already processed and stored + # in the decoder's KV cache. +diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py +index 39caca0..573d00b 100644 +--- a/vllm/v1/engine/core.py ++++ b/vllm/v1/engine/core.py +@@ -23,7 +23,7 @@ from vllm.lora.request import LoRARequest + from vllm.transformers_utils.config import ( + maybe_register_config_serialize_by_value) + from vllm.utils import (get_exception_traceback, resolve_obj_by_qualname, +- zmq_socket_ctx) ++ make_zmq_socket, resolve_obj_by_qualname, zmq_socket_ctx) + from vllm.v1.core.kv_cache_utils import (get_kv_cache_config, + unify_kv_cache_configs) + from vllm.v1.core.sched.interface import SchedulerInterface +@@ -43,6 +43,7 @@ from vllm.version import __version__ as VLLM_VERSION + logger = init_logger(__name__) + + POLLING_TIMEOUT_S = 2.5 ++HANDSHAKE_TIMEOUT_MINS = 5 + + _R = TypeVar('_R') # Return type for collective_rpc + +@@ -93,10 +94,7 @@ class EngineCore: + vllm_config.scheduler_config.scheduler_cls) + + self.scheduler: SchedulerInterface = Scheduler( +- scheduler_config=vllm_config.scheduler_config, +- model_config=vllm_config.model_config, +- cache_config=vllm_config.cache_config, +- lora_config=vllm_config.lora_config, ++ vllm_config=vllm_config, + kv_cache_config=kv_cache_config, + structured_output_manager=self.structured_output_manager, + include_finished_set=vllm_config.parallel_config.data_parallel_size +@@ -306,43 +304,109 @@ class EngineCore: + + class EngineCoreProc(EngineCore): + """ZMQ-wrapper for running EngineCore in background process.""" ++ ENGINE_CORE_DEAD = b'ENGINE_CORE_DEAD' + + def __init__( + self, +- input_path: str, +- output_path: str, + vllm_config: VllmConfig, ++ on_head_node: bool, ++ input_address: str, + executor_class: type[Executor], + log_stats: bool, + engine_index: int = 0, + ): +- super().__init__(vllm_config, 
executor_class, log_stats) +- +- # Background Threads and Queues for IO. These enable us to +- # overlap ZMQ socket IO with GPU since they release the GIL, +- # and to overlap some serialization/deserialization with the +- # model forward pass. +- # Threads handle Socket <-> Queues and core_busy_loop uses Queue. +- self.input_queue: queue.Queue[tuple[EngineCoreRequestType, +- Any]] = queue.Queue() +- self.output_queue: queue.Queue[EngineCoreOutputs] = queue.Queue() +- threading.Thread(target=self.process_input_socket, +- args=(input_path, ), +- daemon=True).start() +- threading.Thread(target=self.process_output_socket, +- args=(output_path, engine_index), +- daemon=True).start() +- +- self.global_unfinished_reqs = False +- +- self.step_fn = (self.step if self.batch_queue is None else +- self.step_with_batch_queue) ++ input_queue = queue.Queue[tuple[EngineCoreRequestType, Any]]() ++ ++# GZQ DP patch TODO: executor_fail_callback is useful sometimes for reliabiltiy issues, should add later ++ executor_fail_callback = lambda: input_queue.put_nowait( ++ (EngineCoreRequestType.EXECUTOR_FAILED, b'')) ++ ++ # Create input socket. ++ input_ctx = zmq.Context() ++ identity = engine_index.to_bytes(length=2, byteorder="little") ++ input_socket = make_zmq_socket(input_ctx, ++ input_address, ++ zmq.DEALER, ++ identity=identity, ++ bind=False) ++ try: ++ # Register engine with front-end. ++ output_address = self.startup_handshake( ++ input_socket, on_head_node, vllm_config.parallel_config) ++ # Update config which may have changed from the handshake. ++ vllm_config.__post_init__() ++ # Set up data parallel environment. ++ self._init_data_parallel(vllm_config) ++ ++ # Initialize engine core and model. ++ super().__init__(vllm_config, executor_class, log_stats) ++ ++ self.step_fn = (self.step if self.batch_queue is None else ++ self.step_with_batch_queue) ++ ++ self.global_unfinished_reqs = False ++ ++ # Send ready message. ++ input_socket.send( ++ msgspec.msgpack.encode({ ++ "status": "READY", ++ "local": on_head_node ++ })) ++ ++ # Background Threads and Queues for IO. These enable us to ++ # overlap ZMQ socket IO with GPU since they release the GIL, ++ # and to overlap some serialization/deserialization with the ++ # model forward pass. ++ # Threads handle Socket <-> Queues and core_busy_loop uses Queue. ++ self.input_queue = input_queue ++ self.output_queue = queue.Queue[Union[EngineCoreOutputs, bytes]]() ++ threading.Thread(target=self.process_input_socket, ++ args=(input_socket, ), ++ daemon=True).start() ++ input_socket = None ++ self.output_thread = threading.Thread( ++ target=self.process_output_socket, ++ args=(output_address, engine_index), ++ daemon=True) ++ self.output_thread.start() ++ finally: ++ if input_socket is not None: ++ input_socket.close(linger=0) ++ ++ @staticmethod ++ def startup_handshake(input_socket: zmq.Socket, on_head_node: bool, ++ parallel_config: ParallelConfig) -> str: ++ ++ # Send registration message. ++ input_socket.send( ++ msgspec.msgpack.encode({ ++ "status": "HELLO", ++ "local": on_head_node, ++ })) ++ ++ # Receive initialization message. 
++ logger.info("Waiting for init message from front-end.") ++ if not input_socket.poll(timeout=HANDSHAKE_TIMEOUT_MINS * 60 * 1000): ++ raise RuntimeError("Did not receive response from front-end " ++ f"process within {HANDSHAKE_TIMEOUT_MINS} " ++ f"minutes") ++ init_bytes = input_socket.recv() ++ init_message = msgspec.msgpack.decode(init_bytes) ++ logger.debug("Received init message: %s", init_message) ++ ++ output_socket_address = init_message["output_socket_address"] ++ #TBD(nick) maybe replace IP with configured head node address ++ ++ received_parallel_config = init_message["parallel_config"] ++ for key, value in received_parallel_config.items(): ++ setattr(parallel_config, key, value) ++ ++ return output_socket_address + + @staticmethod + def run_engine_core(*args, + dp_rank: int = 0, + local_dp_rank: int = 0, +- ready_pipe, + **kwargs): + """Launch EngineCore busy loop in background process.""" + +@@ -369,7 +433,7 @@ class EngineCoreProc(EngineCore): + try: + parallel_config: ParallelConfig = kwargs[ + "vllm_config"].parallel_config +- if parallel_config.data_parallel_size > 1: ++ if parallel_config.data_parallel_size > 1 or dp_rank > 0: + # Set data parallel rank for this engine process. + parallel_config.data_parallel_rank = dp_rank + parallel_config.data_parallel_rank_local = local_dp_rank +@@ -377,9 +441,6 @@ class EngineCoreProc(EngineCore): + else: + engine_core = EngineCoreProc(*args, **kwargs) + +- # Send Readiness signal to EngineClient. +- ready_pipe.send({"status": "READY"}) +- + engine_core.run_busy_loop() + + except SystemExit: +@@ -394,6 +455,9 @@ class EngineCoreProc(EngineCore): + if engine_core is not None: + engine_core.shutdown() + ++ def _init_data_parallel(self, vllm_config: VllmConfig): ++ pass ++ + def run_busy_loop(self): + """Core busy loop of the EngineCore.""" + +@@ -476,27 +540,37 @@ class EngineCoreProc(EngineCore): + and not isinstance(v, p.annotation) else v + for v, p in zip(args, arg_types)) + +- def process_input_socket(self, input_path: str): ++ def _send_engine_dead(self): ++ """Send EngineDead status to the EngineCoreClient.""" ++ ++ # Put ENGINE_CORE_DEAD in the queue. ++ self.output_queue.put_nowait(EngineCoreProc.ENGINE_CORE_DEAD) ++ ++ # Wait until msg sent by the daemon before shutdown. ++ self.output_thread.join(timeout=5.0) ++ if self.output_thread.is_alive(): ++ logger.fatal("vLLM shutdown signal from EngineCore failed " ++ "to send. Please report this issue.") ++ ++ def process_input_socket(self, input_socket: zmq.Socket): + """Input socket IO thread.""" + + # Msgpack serialization decoding. + add_request_decoder = MsgpackDecoder(EngineCoreRequest) + generic_decoder = MsgpackDecoder() + +- with zmq_socket_ctx(input_path, zmq.constants.PULL) as socket: +- while True: +- # (RequestType, RequestData) +- type_frame, data_frame = socket.recv_multipart(copy=False) +- request_type = EngineCoreRequestType(bytes(type_frame.buffer)) ++ while True: ++ # (RequestType, RequestData) ++ type_frame, data_frames = input_socket.recv_multipart(copy=False) ++ request_type = EngineCoreRequestType(bytes(type_frame.buffer)) + +- # Deserialize the request data. +- decoder = add_request_decoder if ( +- request_type +- == EngineCoreRequestType.ADD) else generic_decoder +- request = decoder.decode(data_frame.buffer) ++ # Deserialize the request data. ++ decoder = add_request_decoder if ( ++ request_type == EngineCoreRequestType.ADD) else generic_decoder ++ request = decoder.decode(data_frames) + +- # Push to input queue for core busy loop. 
+- self.input_queue.put_nowait((request_type, request)) ++ # Push to input queue for core busy loop. ++ self.input_queue.put_nowait((request_type, request)) + + def process_output_socket(self, output_path: str, engine_index: int): + """Output socket IO thread.""" +@@ -523,9 +597,9 @@ class DPEngineCoreProc(EngineCoreProc): + + def __init__( + self, +- input_path: str, +- output_path: str, + vllm_config: VllmConfig, ++ on_head_node: bool, ++ input_address: str, + executor_class: type[Executor], + log_stats: bool, + ): +@@ -537,8 +611,20 @@ class DPEngineCoreProc(EngineCoreProc): + _add_prefix(sys.stdout, process_name, pid) + _add_prefix(sys.stderr, process_name, pid) + +- dp_size = vllm_config.parallel_config.data_parallel_size ++ # Counts forward-passes of the model so that we can synchronize ++ # finished with DP peers every N steps. ++ self.counter = 0 ++ ++ # Initialize the engine. ++ dp_rank = vllm_config.parallel_config.data_parallel_rank ++ super().__init__(vllm_config, on_head_node, input_address, ++ executor_class, log_stats, dp_rank) ++ ++ def _init_data_parallel(self, vllm_config: VllmConfig): ++ ++ # Configure GPUs and stateless process group for data parallel. + dp_rank = vllm_config.parallel_config.data_parallel_rank ++ dp_size = vllm_config.parallel_config.data_parallel_size + local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local + + assert dp_size > 1 +@@ -547,22 +633,14 @@ class DPEngineCoreProc(EngineCoreProc): + from vllm.platforms import current_platform + if current_platform.is_cuda_alike(): + from vllm.platforms.cuda import device_id_to_physical_device_id +- tp_size = vllm_config.parallel_config.tensor_parallel_size ++ world_size = vllm_config.parallel_config.world_size + os.environ["CUDA_VISIBLE_DEVICES"] = ",".join( + str(device_id_to_physical_device_id(i)) +- for i in range(local_dp_rank * tp_size, (local_dp_rank + 1) * +- tp_size)) ++ for i in range(local_dp_rank * ++ world_size, (local_dp_rank + 1) * world_size)) + + self.dp_group = vllm_config.parallel_config.stateless_init_dp_group() + +- # Initialize the engine after setting up environment. +- super().__init__(input_path, output_path, vllm_config, executor_class, +- log_stats, dp_rank) +- +- # Counts forward-passes of the model so that we can synchronize +- # finished with DP peers every N steps. 
+- self.counter = 0 +- + def shutdown(self): + super().shutdown() + if dp_group := getattr(self, "dp_group", None): +diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py +index e948e59..e443f45 100644 +--- a/vllm/v1/engine/core_client.py ++++ b/vllm/v1/engine/core_client.py +@@ -8,26 +8,29 @@ import threading + import uuid + import weakref + from abc import ABC, abstractmethod +-from collections.abc import Awaitable, Sequence ++from collections.abc import Awaitable + from concurrent.futures import Future +-from dataclasses import dataclass, field ++from dataclasses import dataclass ++from enum import Enum, auto + from threading import Thread + from typing import Any, Callable, Optional, TypeVar, Union + ++import msgspec + import zmq + import zmq.asyncio + +-from vllm.config import VllmConfig ++from vllm.config import ParallelConfig, VllmConfig + from vllm.logger import init_logger + from vllm.lora.request import LoRARequest +-from vllm.utils import (get_open_zmq_inproc_path, get_open_zmq_ipc_path, +- kill_process_tree, make_zmq_socket) ++ ++from vllm.utils import (get_open_port, get_open_zmq_inproc_path, ++ get_open_zmq_ipc_path, kill_process_tree, get_tcp_uri, make_zmq_socket) + from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest, + EngineCoreRequestType, UtilityOutput) + from vllm.v1.engine.core import EngineCore, EngineCoreProc + from vllm.v1.executor.abstract import Executor +-from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder +-from vllm.v1.utils import BackgroundProcHandle ++from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder, bytestr ++from vllm.v1.utils import CoreEngineProcManager + + logger = init_logger(__name__) + +@@ -35,6 +38,8 @@ AnyFuture = Union[asyncio.Future[Any], Future[Any]] + + _R = TypeVar('_R') # Return type for collective_rpc + ++STARTUP_POLL_PERIOD_MS = 10000 ++ + + class EngineCoreClient(ABC): + """ +@@ -253,52 +258,21 @@ class InprocClient(EngineCoreClient): + return self.engine_core.collective_rpc(method, timeout, args, kwargs) + + +-class CoreEngine: +- """One per data parallel rank.""" ++class CoreEngineState(Enum): ++ NEW = auto() ++ CONNECTED = auto() ++ READY = auto() + +- def __init__( +- self, +- vllm_config: VllmConfig, +- executor_class: type[Executor], +- log_stats: bool, +- ctx: Union[zmq.Context, zmq.asyncio.Context], +- output_path: str, +- index: int = 0, +- local_dp_rank: int = 0, +- ): +- # Paths and sockets for IPC. +- input_path = get_open_zmq_ipc_path() +- self.input_socket = make_zmq_socket(ctx, input_path, +- zmq.constants.PUSH) +- try: +- # Start EngineCore in background process. +- self.proc_handle = BackgroundProcHandle( +- input_path=input_path, +- output_path=output_path, +- process_name=f"EngineCore_{index}", +- target_fn=EngineCoreProc.run_engine_core, +- process_kwargs={ +- "vllm_config": vllm_config, +- "dp_rank": index, +- "local_dp_rank": local_dp_rank, +- "executor_class": executor_class, +- "log_stats": log_stats, +- }) + +- self.num_reqs_in_flight = 0 +- finally: +- if not hasattr(self, "num_reqs_in_flight"): +- # Ensure socket is closed if process fails to start. 
+- self.close() ++class CoreEngine: ++ """One per data parallel rank.""" + +- def send_multipart(self, msg_parts: Sequence): +- return self.input_socket.send_multipart(msg_parts, copy=False) ++ def __init__(self, index: int = 0, local: bool = True): ++ self.local = local ++ self.identity = index.to_bytes(length=2, byteorder="little") + +- def close(self): +- if proc_handle := getattr(self, "proc_handle", None): +- proc_handle.shutdown() +- if socket := getattr(self, "input_socket", None): +- socket.close(linger=0) ++ self.state = CoreEngineState.NEW ++ self.num_reqs_in_flight = 0 + + + @dataclass +@@ -307,20 +281,23 @@ class BackgroundResources: + circular reference back to the client object.""" + + ctx: Union[zmq.Context] +- core_engines: list[CoreEngine] = field(default_factory=list) ++ local_engine_manager: Optional[CoreEngineProcManager] = None + output_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None ++ input_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None + shutdown_path: Optional[str] = None + + def __call__(self): + """Clean up background resources.""" + +- for core_engine in self.core_engines: +- core_engine.close() ++ if self.local_engine_manager is not None: ++ self.local_engine_manager.close() + + # ZMQ context termination can hang if the sockets + # aren't explicitly closed first. + if self.output_socket is not None: + self.output_socket.close(linger=0) ++ if self.input_socket is not None: ++ self.input_socket.close(linger=0) + if self.shutdown_path is not None: + # We must ensure that the sync output socket is + # closed cleanly in its own thread. +@@ -384,38 +361,169 @@ class MPClient(EngineCoreClient): + # exception is raised mid-construction. + self.resources = BackgroundResources(ctx=sync_ctx) + self._finalizer = weakref.finalize(self, self.resources) ++ success = False ++ try: ++ parallel_config = vllm_config.parallel_config ++ local_engine_count = parallel_config.data_parallel_size_local ++ start_index = parallel_config.data_parallel_rank ++ local_start_index = parallel_config.data_parallel_rank_local ++ ++ # SPMD mode is where there is an LLM instance per DP rank and ++ # one core engine per LLM, see ++ # examples/offline_inference/data_parallel.py. ++ spmd_mode = local_start_index is not None ++ if spmd_mode: ++ assert local_engine_count == 1 ++ self.core_engines = [ ++ CoreEngine(index=local_start_index, local=True) ++ ] ++ else: ++ assert start_index == 0 ++ local_start_index = 0 ++ self.core_engines = [ ++ CoreEngine(index=i, local=(i < local_engine_count)) ++ for i in range(parallel_config.data_parallel_size) ++ ] + +- # Paths and sockets for IPC. +- self.output_path = get_open_zmq_ipc_path() +- +- new_core_engine = lambda index, local_dp_rank=None: CoreEngine( +- vllm_config, executor_class, log_stats, self.ctx, self.output_path, +- index, local_dp_rank) +- +- # Start engine core process(es). +- self._init_core_engines(vllm_config, new_core_engine, +- self.resources.core_engines) +- +- # Wait for engine core process(es) to start. +- for engine in self.resources.core_engines: +- engine.proc_handle.wait_for_startup() +- +- self.utility_results: dict[int, AnyFuture] = {} ++ input_address, output_address = self._get_zmq_addresses( ++ parallel_config, spmd_mode) ++ ++ # Create input and output sockets. 
++ self.input_socket = self.resources.input_socket = make_zmq_socket( ++ self.ctx, input_address, zmq.ROUTER, bind=True) ++ ++ self.resources.output_socket = make_zmq_socket( ++ self.ctx, output_address, zmq.constants.PULL) ++ # Start local engines. ++ if local_engine_count: ++ # In server mode, start_index and local_start_index will ++ # both be 0. ++ self.resources.local_engine_manager = CoreEngineProcManager( ++ EngineCoreProc.run_engine_core, ++ vllm_config=vllm_config, ++ executor_class=executor_class, ++ log_stats=log_stats, ++ input_address=input_address, ++ on_head_node=True, ++ local_engine_count=local_engine_count, ++ start_index=start_index, ++ local_start_index=local_start_index) ++ ++ self.core_engine = self.core_engines[0] ++ ++ # Wait for engine core process(es) to start. ++ self._wait_for_engine_startup(output_address, parallel_config) ++ ++ self.utility_results: dict[int, AnyFuture] = {} ++ success = True ++ finally: ++ if not success: ++ self._finalizer() + +- def _init_core_engines( +- self, +- vllm_config: VllmConfig, +- new_core_engine: Callable[[int, Optional[int]], CoreEngine], +- core_engines: list[CoreEngine], +- ) -> None: +- +- # Default case - single core engine. +- dp_rank = vllm_config.parallel_config.data_parallel_rank +- local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local +- core_engine = new_core_engine( +- dp_rank, local_dp_rank if local_dp_rank is not None else dp_rank) +- core_engines.append(core_engine) +- self.core_engine = core_engine ++ @staticmethod ++ def _get_zmq_addresses(parallel_config: ParallelConfig, ++ spmd_mode: bool) -> tuple[str, str]: ++ """Returns (input_address, output_address).""" ++ dp_size = parallel_config.data_parallel_size ++ local_engine_count = parallel_config.data_parallel_size_local ++ ++ if local_engine_count == dp_size or spmd_mode: ++ input_address = get_open_zmq_ipc_path() ++ output_address = get_open_zmq_ipc_path() ++ else: ++ host = parallel_config.data_parallel_master_ip ++ input_port = parallel_config.data_parallel_rpc_port ++ output_port = get_open_port() ++ input_address = get_tcp_uri(host, input_port) ++ output_address = get_tcp_uri(host, output_port) ++ ++ return input_address, output_address ++ ++ def _wait_for_engine_startup(self, output_address: str, ++ parallel_config: ParallelConfig): ++ # Get a sync handle to the socket which can be sync or async. ++ sync_input_socket = zmq.Socket.shadow(self.input_socket) ++ ++ # Wait for engine core process(es) to send ready messages. ++ local_count = parallel_config.data_parallel_size_local ++ remote_count = len(self.core_engines) - local_count ++ # [local, remote] counts ++ conn_pending, start_pending = [local_count, remote_count], [0, 0] ++ ++ poller = zmq.Poller() ++ poller.register(sync_input_socket, zmq.POLLIN) ++ proc_manager = self.resources.local_engine_manager ++ if proc_manager is not None: ++ for sentinel in proc_manager.sentinels(): ++ poller.register(sentinel, zmq.POLLIN) ++ while any(conn_pending) or any(start_pending): ++ events = poller.poll(STARTUP_POLL_PERIOD_MS) ++ if not events: ++ if any(conn_pending): ++ logger.debug( ++ "Waiting for %d local, %d remote core engine proc(s) " ++ "to connect.", *conn_pending) ++ if any(start_pending): ++ logger.debug( ++ "Waiting for %d local, %d remote core engine proc(s) " ++ "to start.", *start_pending) ++ continue ++ if len(events) > 1 or events[0][0] != sync_input_socket: ++ # One of the local core processes exited. 
++ finished = proc_manager.finished_procs( ++ ) if proc_manager else {} ++ raise RuntimeError("Engine core initialization failed. " ++ "See root cause above. " ++ f"Failed core proc(s): {finished}") ++ ++ # Receive HELLO and READY messages from the input socket. ++ eng_identity, ready_msg_bytes = sync_input_socket.recv_multipart() ++ eng_index = int.from_bytes(eng_identity, byteorder="little") ++ engine = next( ++ (e for e in self.core_engines if e.identity == eng_identity), ++ None) ++ if engine is None: ++ raise RuntimeError(f"Message from engine with unexpected data " ++ f"parallel rank: {eng_index}") ++ msg = msgspec.msgpack.decode(ready_msg_bytes) ++ status, local = msg["status"], msg["local"] ++ if local != engine.local: ++ raise RuntimeError(f"{status} message from " ++ f"{'local' if local else 'remote'} " ++ f"engine {eng_index}, expected it to be " ++ f"{'local' if engine.local else 'remote'}") ++ ++ if status == "HELLO" and engine.state == CoreEngineState.NEW: ++ ++ # Send init message with DP config info. ++ init_message = self.encoder.encode({ ++ "output_socket_address": output_address, ++ "parallel_config": { ++ "data_parallel_master_ip": ++ parallel_config.data_parallel_master_ip, ++ "data_parallel_master_port": ++ parallel_config.data_parallel_master_port, ++ "data_parallel_size": ++ parallel_config.data_parallel_size, ++ }, ++ }) ++ sync_input_socket.send_multipart((eng_identity, init_message), ++ copy=False) ++ conn_pending[0 if local else 1] -= 1 ++ start_pending[0 if local else 1] += 1 ++ engine.state = CoreEngineState.CONNECTED ++ elif status == "READY" and (engine.state ++ == CoreEngineState.CONNECTED): ++ start_pending[0 if local else 1] -= 1 ++ engine.state = CoreEngineState.READY ++ else: ++ raise RuntimeError(f"Unexpected {status} message for " ++ f"{'local' if local else 'remote'} engine " ++ f"{eng_index} in {engine.state} state.") ++ ++ logger.debug("%s from %s core engine process %s.", status, ++ "local" if local else "remote", eng_index) ++# >>>>>>> fbe7575cc... squashed commit of pr#15977 + + def shutdown(self): + self._finalizer() +@@ -448,7 +556,8 @@ class SyncMPClient(MPClient): + # Ensure that the outputs socket processing thread does not have + # a ref to the client which prevents gc. + ctx = self.ctx +- output_path = self.output_path ++ out_socket = self.resources.output_socket ++ assert out_socket is not None + decoder = self.decoder + utility_results = self.utility_results + outputs_queue = self.outputs_queue +@@ -458,7 +567,6 @@ class SyncMPClient(MPClient): + + def process_outputs_socket(): + shutdown_socket = ctx.socket(zmq.PAIR) +- out_socket = make_zmq_socket(ctx, output_path, zmq.constants.PULL) + try: + shutdown_socket.bind(shutdown_path) + poller = zmq.Poller() +@@ -490,13 +598,17 @@ class SyncMPClient(MPClient): + daemon=True) + self.output_queue_thread.start() + ++ # The thread takes on responsibility for closing the socket. 
++ self.resources.output_socket = None ++ + def get_output(self) -> EngineCoreOutputs: + return self.outputs_queue.get() + + def _send_input(self, request_type: EngineCoreRequestType, request: Any): +- # (RequestType, SerializedRequest) +- msg = (request_type.value, self.encoder.encode(request)) +- self.core_engine.send_multipart(msg) ++ # (Identity, RequestType, SerializedRequest) ++ msg = (self.core_engine.identity, request_type.value, ++ self.encoder.encode(request)) ++ self.input_socket.send_multipart(msg, copy=False) + + def call_utility(self, method: str, *args) -> Any: + call_id = uuid.uuid1().int >> 64 +@@ -581,6 +693,7 @@ class AsyncMPClient(MPClient): + [AsyncMPClient, EngineCoreOutputs], Awaitable[None]]] = None + + def _ensure_output_queue_task(self): ++ resources = self.resources + if self.outputs_queue is not None: + return + +@@ -592,10 +705,8 @@ class AsyncMPClient(MPClient): + outputs_queue = self.outputs_queue + output_handler = self.outputs_handler + _self_ref = weakref.ref(self) if output_handler else None +- output_path = self.output_path +- output_socket = make_zmq_socket(self.ctx, output_path, +- zmq.constants.PULL) +- self.resources.output_socket = output_socket ++ output_socket = resources.output_socket ++ assert output_socket is not None + + async def process_outputs_socket(): + while True: +@@ -625,30 +736,34 @@ class AsyncMPClient(MPClient): + assert self.outputs_queue is not None + return await self.outputs_queue.get() + +- async def _send_input(self, request_type: EngineCoreRequestType, +- request: Any) -> None: +- await self.core_engine.send_multipart( +- (request_type.value, self.encoder.encode(request))) ++ def _send_input(self, ++ request_type: EngineCoreRequestType, ++ request: Any, ++ engine: Optional[CoreEngine] = None) -> Awaitable[None]: ++ if engine is None: ++ engine = self.core_engine + +- self._ensure_output_queue_task() ++ message = (request_type.value, self.encoder.encode(request)) ++ return self._send_input_message(message, engine) ++ ++ def _send_input_message(self, message: tuple[bytes, bytes], ++ engine: CoreEngine) -> Awaitable[None]: ++ message = (engine.identity, ) + message # type: ignore[assignment] ++ return self.input_socket.send_multipart(message, copy=False) + + async def call_utility_async(self, method: str, *args) -> Any: + return await self._call_utility_async(method, + *args, + engine=self.core_engine) + +- async def _call_utility_async( +- self, +- method: str, +- *args, +- engine: CoreEngine, +- ) -> Any: ++ async def _call_utility_async(self, method: str, *args, ++ engine: CoreEngine) -> Any: + call_id = uuid.uuid1().int >> 64 + future = asyncio.get_running_loop().create_future() + self.utility_results[call_id] = future + message = (EngineCoreRequestType.UTILITY.value, + self.encoder.encode((call_id, method, args))) +- await engine.send_multipart(message) ++ await self._send_input_message(message, engine) + self._ensure_output_queue_task() + return await future + +@@ -657,6 +772,7 @@ class AsyncMPClient(MPClient): + # tokenized. 
+ request.prompt = None + await self._send_input(EngineCoreRequestType.ADD, request) ++ self._ensure_output_queue_task() + + async def abort_requests_async(self, request_ids: list[str]) -> None: + if len(request_ids) > 0: +@@ -728,21 +844,6 @@ class DPAsyncMPClient(AsyncMPClient): + + self.outputs_handler = DPAsyncMPClient.process_engine_outputs # type: ignore[assignment] + +- def _init_core_engines( +- self, +- vllm_config: VllmConfig, +- new_core_engine: Callable[[int, Optional[int]], CoreEngine], +- core_engines: list[CoreEngine], +- ) -> None: +- +- # Launch a core engine for each data parallel rank. +- dp_size = vllm_config.parallel_config.data_parallel_size +- for i in range(dp_size): +- # Multi-node not yet supported so local_dp_rank == dp_rank. +- core_engines.append(new_core_engine(i, i)) +- +- self.core_engines = core_engines +- + async def call_utility_async(self, method: str, *args) -> Any: + # Only the result from the first engine is returned. + return (await asyncio.gather(*[ +@@ -761,15 +862,15 @@ class DPAsyncMPClient(AsyncMPClient): + self.reqs_in_flight[request.request_id] = chosen_engine + chosen_engine.num_reqs_in_flight += 1 + if self.num_engines_running >= len(self.core_engines): +- await chosen_engine.send_multipart(msg) ++ await self._send_input_message(msg, chosen_engine) + else: + # Send request to chosen engine and dp start loop + # control message to all other engines. + self.num_engines_running += len(self.core_engines) + await asyncio.gather(*[ +- engine.send_multipart(msg if engine is +- chosen_engine else self.start_dp_msg) +- for engine in self.core_engines ++ self._send_input_message( ++ msg if engine is chosen_engine else self.start_dp_msg, ++ engine) for engine in self.core_engines + ]) + + self._ensure_output_queue_task() +@@ -794,7 +895,7 @@ class DPAsyncMPClient(AsyncMPClient): + # sure to start the other engines: + self.num_engines_running = len(self.core_engines) + coros = [ +- engine.send_multipart(self.start_dp_msg) ++ self._send_input_message(self.start_dp_msg, engine) + for engine in self.core_engines + if not engine.num_reqs_in_flight + ] +@@ -820,5 +921,5 @@ class DPAsyncMPClient(AsyncMPClient): + + async def _abort_requests(self, request_ids: list[str], + engine: CoreEngine) -> None: +- await engine.send_multipart((EngineCoreRequestType.ABORT.value, +- self.encoder.encode(request_ids))) ++ await self._send_input(EngineCoreRequestType.ABORT, request_ids, ++ engine) +diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py +index 1d5175e..fc300ff 100644 +--- a/vllm/v1/executor/multiproc_executor.py ++++ b/vllm/v1/executor/multiproc_executor.py +@@ -243,7 +243,6 @@ class WorkerProc: + protocol=pickle.HIGHEST_PROTOCOL) + ready_socket.send_string(WorkerProc.READY_STR) + ready_socket.send(payload) +- + self.worker.init_device() + self.worker.load_model() + +@@ -327,6 +326,11 @@ class WorkerProc: + logger.debug("Worker interrupted.") + + except Exception: ++ import sys ++ import traceback ++ exec_type, exec_value, exec_traceback = sys.exc_info() ++ exception_str = "".join(traceback.format_exception(exec_type, exec_value, exec_traceback)) ++ logger.error("WorkerProc failed! %s" % exception_str) + # worker_busy_loop sends exceptions exceptons to Executor + # for shutdown, but if there is an error in startup or an + # error with IPC itself, we need to alert the parent. 
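For illustration only, a minimal sketch (not part of the patch) of the HELLO/READY handshake the hunks above introduce: EngineCoreProc.startup_handshake on the engine side and MPClient._wait_for_engine_startup on the front-end side, carried over a ROUTER/DEALER pair where the 2-byte little-endian DP rank is the socket identity. The tcp address and single local engine are placeholder assumptions; the msgspec-encoded message shapes follow the patch.

import threading

import msgspec
import zmq

ADDRESS = "tcp://127.0.0.1:5570"  # placeholder for the DP input_address


def engine(index: int) -> None:
    # Mirrors EngineCoreProc: DEALER socket, identity = DP rank (2 bytes, LE).
    sock = zmq.Context.instance().socket(zmq.DEALER)
    sock.setsockopt(zmq.IDENTITY, index.to_bytes(length=2, byteorder="little"))
    sock.connect(ADDRESS)
    # 1. Register with the front-end.
    sock.send(msgspec.msgpack.encode({"status": "HELLO", "local": True}))
    # 2. Receive the init message (output address, parallel config, ...).
    init_msg = msgspec.msgpack.decode(sock.recv())
    assert "output_socket_address" in init_msg
    # 3. Report readiness once the engine would be fully initialized.
    sock.send(msgspec.msgpack.encode({"status": "READY", "local": True}))
    sock.close(linger=0)


def front_end() -> None:
    # Mirrors MPClient: ROUTER socket bound to the shared input address.
    sock = zmq.Context.instance().socket(zmq.ROUTER)
    sock.bind(ADDRESS)
    identity, hello = sock.recv_multipart()
    assert msgspec.msgpack.decode(hello)["status"] == "HELLO"
    # Reply with the init message, routed by the engine's identity frame.
    sock.send_multipart((identity, msgspec.msgpack.encode(
        {"output_socket_address": "ipc:///tmp/engine_out"})))
    identity, ready = sock.recv_multipart()
    print("engine", int.from_bytes(identity, "little"), "is",
          msgspec.msgpack.decode(ready)["status"])
    sock.close(linger=0)


worker = threading.Thread(target=engine, args=(0,))
worker.start()
front_end()
worker.join()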
+diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py +index 146d7d7..7c1d484 100644 +--- a/vllm/v1/serial_utils.py ++++ b/vllm/v1/serial_utils.py +@@ -2,9 +2,10 @@ + + import pickle + from types import FunctionType +-from typing import Any, Optional ++from typing import Any, Optional, Union + + import cloudpickle ++import zmq + import torch + from msgspec import msgpack + +@@ -12,6 +13,7 @@ CUSTOM_TYPE_TENSOR = 1 + CUSTOM_TYPE_PICKLE = 2 + CUSTOM_TYPE_CLOUDPICKLE = 3 + ++bytestr = Union[bytes, bytearray, memoryview, zmq.Frame] + + class MsgpackEncoder: + """Encoder with custom torch tensor serialization.""" +diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py +index f42b350..fbc0ee3 100644 +--- a/vllm/v1/utils.py ++++ b/vllm/v1/utils.py +@@ -2,17 +2,21 @@ + + import multiprocessing + import os ++import time + import weakref + from collections import defaultdict + from collections.abc import Sequence ++from multiprocessing import Process, connection + from typing import (TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar, + Union, overload) + + import torch + ++from vllm.config import VllmConfig + from vllm.logger import init_logger + from vllm.model_executor.models.utils import extract_layer_index + from vllm.utils import get_mp_context, kill_process_tree ++from vllm.v1.executor.abstract import Executor + + if TYPE_CHECKING: + from vllm.attention.layer import Attention +@@ -90,7 +94,7 @@ class ConstantList(Generic[T], Sequence): + return f"ConstantList({self._x})" + + +-class BackgroundProcHandle: ++class CoreEngineProcManager: + """ + Utility class to handle creation, readiness, and shutdown + of background processes used by the AsyncLLM and LLMEngine. +@@ -98,55 +102,91 @@ class BackgroundProcHandle: + + def __init__( + self, +- input_path: str, +- output_path: str, +- process_name: str, + target_fn: Callable, +- process_kwargs: dict[Any, Any], ++ local_engine_count: int, ++ start_index: int, ++ local_start_index: int, ++ vllm_config: VllmConfig, ++ on_head_node: bool, ++ input_address: str, ++ executor_class: type[Executor], ++ log_stats: bool, + ): + context = get_mp_context() +- self.reader, writer = context.Pipe(duplex=False) +- +- assert ("ready_pipe" not in process_kwargs +- and "input_path" not in process_kwargs +- and "output_path" not in process_kwargs) +- process_kwargs["ready_pipe"] = writer +- process_kwargs["input_path"] = input_path +- process_kwargs["output_path"] = output_path +- +- # Run busy loop in background process. +- self.proc = context.Process(target=target_fn, +- kwargs=process_kwargs, +- name=process_name) +- self._finalizer = weakref.finalize(self, shutdown, self.proc, +- input_path, output_path) +- self.proc.start() +- +- def wait_for_startup(self): +- # Wait for startup. +- if self.reader.recv()["status"] != "READY": +- raise RuntimeError(f"{self.proc.name} initialization failed. " +- "See root cause above.") +- +- def shutdown(self): ++ common_kwargs = { ++ "vllm_config": vllm_config, ++ "on_head_node": on_head_node, ++ "input_address": input_address, ++ "executor_class": executor_class, ++ "log_stats": log_stats, ++ } ++ ++ self.processes: list[Process] = [] ++ for index in range(local_engine_count): ++ local_index = local_start_index + index ++ global_index = start_index + index ++ # Start EngineCore in background process. 
++ self.processes.append( ++ context.Process(target=target_fn, ++ name=f"EngineCore_{global_index}", ++ kwargs=common_kwargs | { ++ "dp_rank": global_index, ++ "local_dp_rank": local_index, ++ })) ++ ++ self._finalizer = weakref.finalize(self, shutdown, self.processes, ++ input_address) ++ try: ++ for proc in self.processes: ++ proc.start() ++ finally: ++ # Kill other procs if not all are running. ++ if self.finished_procs(): ++ self.close() ++ ++ def close(self): ++ """Shutdown all procs.""" + self._finalizer() + ++ def join_first(self): ++ """Wait for any process to exit.""" ++ connection.wait(proc.sentinel for proc in self.processes) ++ ++ def sentinels(self) -> list: ++ return [proc.sentinel for proc in self.processes] ++ ++ def finished_procs(self) -> dict[str, int]: ++ """Returns dict of proc name -> exit code for any finished procs.""" ++ return { ++ proc.name: proc.exitcode ++ for proc in self.processes if proc.exitcode is not None ++ } ++ + + # Note(rob): shutdown function cannot be a bound method, + # else the gc cannot collect the object. +-def shutdown(proc: multiprocessing.Process, input_path: str, output_path: str): ++def shutdown(procs: list[Process], input_address: str): + # Shutdown the process. +- if proc.is_alive(): +- proc.terminate() +- proc.join(5) ++ for proc in procs: ++ if proc.is_alive(): ++ proc.terminate() ++ ++ # Allow 5 seconds for remaining procs to terminate. ++ deadline = time.monotonic() + 5 ++ for proc in procs: ++ remaining = deadline - time.monotonic() ++ if remaining <= 0: ++ break ++ if proc.is_alive(): ++ proc.join(remaining) + ++ for proc in procs: + if proc.is_alive(): + kill_process_tree(proc.pid) + + # Remove zmq ipc socket files. +- ipc_sockets = [output_path, input_path] +- for ipc_socket in ipc_sockets: +- socket_file = ipc_socket.replace("ipc://", "") ++ if input_address.startswith("ipc://"): ++ socket_file = input_address[len("ipc://"):] + if os and os.path.exists(socket_file): + os.remove(socket_file) + +diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py +index 5133c63..5b8805f 100644 +--- a/vllm/v1/worker/gpu_model_runner.py ++++ b/vllm/v1/worker/gpu_model_runner.py +@@ -13,13 +13,16 @@ import torch.nn as nn + from vllm.attention import AttentionType, get_attn_backend + from vllm.attention.layer import Attention + from vllm.config import CompilationLevel, VllmConfig ++from vllm.distributed.kv_transfer import (get_kv_transfer_group, ++ has_kv_transfer_group) + from vllm.distributed.parallel_state import get_pp_group, graph_capture + from vllm.forward_context import set_forward_context + from vllm.logger import init_logger + from vllm.model_executor.layers.fused_moe import FusedMoE + from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding + from vllm.model_executor.model_loader import get_model +-from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs ++from vllm.multimodal import MULTIMODAL_REGISTRY ++from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange + from vllm.multimodal.utils import group_mm_inputs_by_modality + from vllm.sampling_params import SamplingType + from vllm.sequence import IntermediateTensors +@@ -43,7 +46,8 @@ from vllm.v1.utils import bind_kv_cache + from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch + from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin + +-from .utils import sanity_check_mm_encoder_outputs ++from .utils import (gather_mm_placeholders, sanity_check_mm_encoder_outputs, ++ 
scatter_mm_placeholders) + + if TYPE_CHECKING: + import xgrammar as xgr +@@ -482,14 +486,10 @@ class GPUModelRunner(LoRAModelRunnerMixin): + self.input_batch.block_table.commit(num_reqs) + + # Get the number of scheduled tokens for each request. +- # TODO: The Python loop can be slow. Optimize. +- num_scheduled_tokens = np.empty(num_reqs, dtype=np.int32) +- max_num_scheduled_tokens = 0 +- for i, req_id in enumerate(self.input_batch.req_ids): +- num_tokens = scheduler_output.num_scheduled_tokens[req_id] +- num_scheduled_tokens[i] = num_tokens +- max_num_scheduled_tokens = max(max_num_scheduled_tokens, +- num_tokens) ++ req_ids = self.input_batch.req_ids ++ tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids] ++ num_scheduled_tokens = np.array(tokens, dtype=np.int32) ++ max_num_scheduled_tokens = max(tokens) + + # Get request indices. + # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] +@@ -830,19 +830,21 @@ class GPUModelRunner(LoRAModelRunnerMixin): + ) + return metadata + +- def _execute_encoder(self, scheduler_output: "SchedulerOutput"): ++ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"): + scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs + if not scheduled_encoder_inputs: + return + + # Batch the multi-modal inputs. +- mm_inputs: list[MultiModalKwargs] = [] +- req_input_ids: list[tuple[str, int]] = [] ++ mm_inputs = list[MultiModalKwargs]() ++ req_ids_pos = list[tuple[str, int, PlaceholderRange]]() + for req_id, encoder_input_ids in scheduled_encoder_inputs.items(): + req_state = self.requests[req_id] +- for input_id in encoder_input_ids: +- mm_inputs.append(req_state.mm_inputs[input_id]) +- req_input_ids.append((req_id, input_id)) ++ ++ for mm_input_id in encoder_input_ids: ++ mm_inputs.append(req_state.mm_inputs[mm_input_id]) ++ req_ids_pos.append( ++ (req_id, mm_input_id, req_state.mm_positions[mm_input_id])) + + # Batch mm inputs as much as we can: if a request in the batch has + # multiple modalities or a different modality than the previous one, +@@ -878,16 +880,23 @@ class GPUModelRunner(LoRAModelRunnerMixin): + encoder_outputs.append(output) + + # Cache the encoder outputs. 
+- for (req_id, input_id), output in zip(req_input_ids, encoder_outputs): ++ for (req_id, input_id, pos_info), output in zip( ++ req_ids_pos, ++ encoder_outputs, ++ ): + if req_id not in self.encoder_cache: + self.encoder_cache[req_id] = {} +- self.encoder_cache[req_id][input_id] = output + +- def _gather_encoder_outputs( ++ self.encoder_cache[req_id][input_id] = scatter_mm_placeholders( ++ output, ++ is_embed=pos_info.is_embed, ++ ) ++ ++ def _gather_mm_embeddings( + self, + scheduler_output: "SchedulerOutput", + ) -> list[torch.Tensor]: +- encoder_outputs: list[torch.Tensor] = [] ++ mm_embeds: list[torch.Tensor] = [] + for req_id in self.input_batch.req_ids: + num_scheduled_tokens = scheduler_output.num_scheduled_tokens[ + req_id] +@@ -895,8 +904,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): + num_computed_tokens = req_state.num_computed_tokens + mm_positions = req_state.mm_positions + for i, pos_info in enumerate(mm_positions): +- start_pos = pos_info["offset"] +- num_encoder_tokens = pos_info["length"] ++ start_pos = pos_info.offset ++ num_encoder_tokens = pos_info.length + + # The encoder output is needed if the two ranges overlap: + # [num_computed_tokens, +@@ -918,8 +927,16 @@ class GPUModelRunner(LoRAModelRunnerMixin): + assert req_id in self.encoder_cache + assert i in self.encoder_cache[req_id] + encoder_output = self.encoder_cache[req_id][i] +- encoder_outputs.append(encoder_output[start_idx:end_idx]) +- return encoder_outputs ++ ++ if (is_embed := pos_info.is_embed) is not None: ++ is_embed = is_embed[start_idx:end_idx] ++ ++ mm_embeds_item = gather_mm_placeholders( ++ encoder_output[start_idx:end_idx], ++ is_embed=is_embed, ++ ) ++ mm_embeds.append(mm_embeds_item) ++ return mm_embeds + + def get_model(self) -> nn.Module: + return self.model +@@ -977,18 +994,16 @@ class GPUModelRunner(LoRAModelRunnerMixin): + scheduler_output: "SchedulerOutput", + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> Union[ModelRunnerOutput, torch.Tensor]: ++ # Update KVConnector with the KVConnector metadata forward(). ++ if has_kv_transfer_group(): ++ get_kv_transfer_group().bind_connector_metadata( ++ scheduler_output.kv_connector_metadata) ++ + self._update_states(scheduler_output) + if not scheduler_output.total_num_scheduled_tokens: +- # Return empty ModelRunnerOuptut if there's no work to do. ++ # Return empty ModelRunnerOutput if there's no work to do. + return EMPTY_MODEL_RUNNER_OUTPUT + +- if self.is_multimodal_model: +- # Run the multimodal encoder if any. +- self._execute_encoder(scheduler_output) +- encoder_outputs = self._gather_encoder_outputs(scheduler_output) +- else: +- encoder_outputs = [] +- + # Prepare the decoder inputs. + attn_metadata, logits_indices, spec_decode_metadata = ( + self._prepare_inputs(scheduler_output)) +@@ -1004,14 +1019,23 @@ class GPUModelRunner(LoRAModelRunnerMixin): + num_input_tokens = num_scheduled_tokens + attn_metadata.num_input_tokens = num_input_tokens + ++ # _prepare_inputs may reorder the batch, so we must gather multi ++ # modal outputs after that to ensure the correct order ++ if self.is_multimodal_model: ++ # Run the multimodal encoder if any. ++ self._execute_mm_encoder(scheduler_output) ++ mm_embeds = self._gather_mm_embeddings(scheduler_output) ++ else: ++ mm_embeds = [] ++ + if self.is_multimodal_model: + # NOTE(woosuk): To unify token ids and soft tokens (vision + # embeddings), we always use embeddings (rather than token ids) + # as input to the multimodal model, even when the input is text. 
+ input_ids = self.input_ids[:num_scheduled_tokens] +- if encoder_outputs: ++ if mm_embeds: + inputs_embeds = self.model.get_input_embeddings( +- input_ids, encoder_outputs) ++ input_ids, mm_embeds) + else: + inputs_embeds = self.model.get_input_embeddings(input_ids) + # TODO(woosuk): Avoid the copy. Optimize. +@@ -1172,9 +1196,12 @@ class GPUModelRunner(LoRAModelRunnerMixin): + + if spec_decode_metadata is None: + # input_ids can be None for multimodal models. ++ # We need to slice token_ids, positions, and hidden_states ++ # because the eagle head does not use cuda graph and should ++ # not include padding. + target_token_ids = self.input_ids[:num_scheduled_tokens] +- target_positions = positions +- target_hidden_states = hidden_states ++ target_positions = positions[:num_scheduled_tokens] ++ target_hidden_states = hidden_states[:num_scheduled_tokens] + target_slot_mapping = attn_metadata.slot_mapping + cu_num_tokens = attn_metadata.query_start_loc + else: +@@ -1213,6 +1240,10 @@ class GPUModelRunner(LoRAModelRunnerMixin): + # in the next step. + del draft_probs + ++ # Clear KVConnector state after all KVs are generated. ++ if has_kv_transfer_group(): ++ get_kv_transfer_group().clear_connector_metadata() ++ + return ModelRunnerOutput( + req_ids=self.input_batch.req_ids, + req_id_to_index=self.input_batch.req_id_to_index, +diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py +index 2972e0f..35d114f 100644 +--- a/vllm/v1/worker/gpu_worker.py ++++ b/vllm/v1/worker/gpu_worker.py +@@ -9,11 +9,12 @@ import torch.distributed + import torch.nn as nn + + import vllm.envs as envs +-from vllm.config import ParallelConfig, VllmConfig ++from vllm.config import VllmConfig + from vllm.device_allocator.cumem import CuMemAllocator + from vllm.distributed import (ensure_model_parallel_initialized, + init_distributed_environment, + set_custom_all_reduce) ++from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized + from vllm.distributed.parallel_state import get_pp_group + from vllm.logger import init_logger + from vllm.lora.request import LoRARequest +@@ -42,6 +43,8 @@ class Worker(WorkerBase): + is_driver_worker: bool = False, + ): + ++ if vllm_config.kv_transfer_config and vllm_config.kv_transfer_config.kv_connector_extra_config: ++ local_rank = vllm_config.kv_transfer_config.kv_connector_extra_config["device_ids"][rank] + super().__init__(vllm_config=vllm_config, + local_rank=local_rank, + rank=rank, +@@ -110,7 +113,7 @@ class Worker(WorkerBase): + raise RuntimeError( + f"Not support device type: {self.device_config.device}") + # Initialize the distributed environment. +- init_worker_distributed_environment(self.parallel_config, self.rank, ++ init_worker_distributed_environment(self.vllm_config, self.rank, + self.distributed_init_method, + self.local_rank) + # Set random seed. 
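For illustration only, a sketch of how the kv_connector_extra_config["device_ids"] override added to Worker.__init__ above remaps each worker rank onto an explicit device. The connector role and device list are assumed example values, not something the patch prescribes; only the "device_ids" key and the rank indexing come from the hunk.

# Hypothetical --kv-transfer-config payload; "DLLMDsConnector" is the connector
# registered in vllm_mindspore/__init__.py below, the rest are placeholders.
kv_transfer_config = {
    "kv_connector": "DLLMDsConnector",
    "kv_role": "kv_producer",
    "kv_connector_extra_config": {"device_ids": [4, 5, 6, 7]},
}


def resolve_local_rank(rank: int, default_local_rank: int) -> int:
    """Mirror of the override above: prefer an explicit per-rank device map."""
    extra = kv_transfer_config.get("kv_connector_extra_config") or {}
    device_ids = extra.get("device_ids")
    return device_ids[rank] if device_ids else default_local_rank


# Worker rank 2 is pinned to device 6 instead of its default local_rank.
assert resolve_local_rank(rank=2, default_local_rank=2) == 6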
+@@ -285,19 +288,20 @@ class Worker(WorkerBase): + + + def init_worker_distributed_environment( +- parallel_config: ParallelConfig, ++ vllm_config: VllmConfig, + rank: int, + distributed_init_method: Optional[str] = None, + local_rank: int = -1, + ) -> None: + """Initialize the distributed environment.""" ++ parallel_config = vllm_config.parallel_config + set_custom_all_reduce(not parallel_config.disable_custom_all_reduce) + + init_distributed_environment(parallel_config.world_size, rank, + distributed_init_method, local_rank) +- + ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, + parallel_config.pipeline_parallel_size) ++ ensure_kv_transfer_initialized(vllm_config) + + + def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): +diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py +index b1d3aa7..2b4b70b 100644 +--- a/vllm/v1/worker/utils.py ++++ b/vllm/v1/worker/utils.py +@@ -1,4 +1,6 @@ + # SPDX-License-Identifier: Apache-2.0 ++from typing import Optional ++ + import torch + + +@@ -27,3 +29,45 @@ def sanity_check_mm_encoder_outputs( + f"but got tensors with shapes {[e.shape for e in mm_embeddings]} " + "instead. This is most likely due to incorrect implementation " + "of the model's `get_multimodal_embeddings` method.") ++ ++def scatter_mm_placeholders( ++ embeds: torch.Tensor, ++ is_embed: Optional[torch.Tensor], ++) -> torch.Tensor: ++ """ ++ Scatter the multimodal embeddings into a contiguous tensor that represents ++ the placeholder tokens. ++ ++ :class:`vllm.multimodal.processing.PromptUpdateDetails.is_embed`. ++ ++ Args: ++ embeds: The multimodal embeddings. ++ Shape: `(num_embeds, embed_dim)` ++ is_embed: A boolean mask indicating which positions in the placeholder ++ tokens need to be filled with multimodal embeddings. ++ Shape: `(num_placeholders, num_embeds)` ++ """ ++ if is_embed is None: ++ return embeds ++ ++ placeholders = embeds.new_full( ++ (is_embed.shape[0], embeds.shape[-1]), ++ fill_value=torch.nan, ++ ) ++ placeholders[is_embed] = embeds ++ return placeholders ++ ++ ++def gather_mm_placeholders( ++ placeholders: torch.Tensor, ++ is_embed: Optional[torch.Tensor], ++) -> torch.Tensor: ++ """ ++ Reconstructs the embeddings from the placeholder tokens. ++ ++ This is the operation of :func:`scatter_mm_placeholders`. 
++ """ ++ if is_embed is None: ++ return placeholders ++ ++ return placeholders[is_embed] +\ No newline at end of file +diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py +index 86e6d97..49b0ba1 100644 +--- a/vllm/worker/model_runner.py ++++ b/vllm/worker/model_runner.py +@@ -15,7 +15,7 @@ import numpy as np + import torch + import torch.distributed + import torch.nn as nn +-from tqdm import tqdm ++from tqdm.auto import tqdm + + import vllm.envs as envs + from vllm.attention import AttentionMetadata, get_attn_backend +@@ -23,7 +23,8 @@ from vllm.attention.backends.abstract import AttentionState + from vllm.attention.backends.utils import CommonAttentionState + from vllm.config import CompilationLevel, VllmConfig + from vllm.core.scheduler import SchedulerOutputs +-from vllm.distributed import get_kv_transfer_group, get_pp_group ++from vllm.distributed import get_pp_group ++from vllm.distributed.kv_transfer import get_kv_transfer_group + from vllm.distributed.parallel_state import (get_tensor_model_parallel_rank, + graph_capture) + from vllm.forward_context import get_forward_context, set_forward_context +diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py +index d59f20f..b4459fd 100644 +--- a/vllm/worker/worker.py ++++ b/vllm/worker/worker.py +@@ -10,10 +10,10 @@ import torch.distributed + import vllm.envs as envs + from vllm.config import VllmConfig + from vllm.device_allocator.cumem import CuMemAllocator +-from vllm.distributed import (ensure_kv_transfer_initialized, +- ensure_model_parallel_initialized, ++from vllm.distributed import (ensure_model_parallel_initialized, + init_distributed_environment, + set_custom_all_reduce) ++from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized + from vllm.logger import init_logger + from vllm.lora.request import LoRARequest + from vllm.model_executor import set_random_seed +@@ -53,6 +53,8 @@ class Worker(LocalOrDistributedWorkerBase): + model_runner_cls: Optional[Type[GPUModelRunnerBase]] = None, + ) -> None: + WorkerBase.__init__(self, vllm_config) ++ if vllm_config.kv_transfer_config and vllm_config.kv_transfer_config.kv_connector_extra_config: ++ local_rank = vllm_config.kv_transfer_config.kv_connector_extra_config["device_ids"][rank] + self.parallel_config.rank = rank + self.local_rank = local_rank + self.rank = rank +@@ -506,7 +508,8 @@ def init_worker_distributed_environment( + distributed_init_method, local_rank) + ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, + parallel_config.pipeline_parallel_size) +- ++ import faulthandler ++ faulthandler.enable() + ensure_kv_transfer_initialized(vllm_config) + + +-- +2.45.1.windows.1 + diff --git a/vllm_mindspore/__init__.py b/vllm_mindspore/__init__.py index af97709b..175b566a 100644 --- a/vllm_mindspore/__init__.py +++ b/vllm_mindspore/__init__.py @@ -178,6 +178,13 @@ from vllm_mindspore.distributed.parallel_state import ( vllm.distributed.parallel_state.init_model_parallel_group = init_model_parallel_group vllm.distributed.parallel_state.GroupCoordinator.__init__ = init_group_coordinator +from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory + +KVConnectorFactory.register_connector( + "DLLMDsConnector", + "dllm.dkvc.v1.dllm_ds_connector", + "DLLMDsConnector") + from vllm_mindspore.executor.multiproc_worker_utils import ( get_mp_context as ms_get_mp_context, ) diff --git a/vllm_mindspore/engine/arg_utils.py b/vllm_mindspore/engine/arg_utils.py index ed74ba9e..9efb8923 100644 --- 
a/vllm_mindspore/engine/arg_utils.py +++ b/vllm_mindspore/engine/arg_utils.py @@ -164,12 +164,6 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: recommend_to_remove=False) return False - # No Disaggregated Prefill so far. - if self.kv_transfer_config != EngineArgs.kv_transfer_config: - _raise_or_fallback(feature_name="--kv-transfer-config", - recommend_to_remove=False) - return False - # No FlashInfer or XFormers so far. V1_BACKENDS = [ "FLASH_ATTN_VLLM_V1", "FLASH_ATTN", "PALLAS", "PALLAS_VLLM_V1", diff --git a/vllm_mindspore/model_executor/model_loader/weight_utils.py b/vllm_mindspore/model_executor/model_loader/weight_utils.py index 45fe4bdd..d68b22e9 100644 --- a/vllm_mindspore/model_executor/model_loader/weight_utils.py +++ b/vllm_mindspore/model_executor/model_loader/weight_utils.py @@ -17,7 +17,7 @@ # ============================================================================ from tqdm.auto import tqdm -from typing import Generator, List, Tuple +from typing import Generator, List, Optional, Tuple import torch @@ -27,6 +27,7 @@ from mindspore import Parameter, Tensor def safetensors_weights_iterator( hf_weights_files: List[str], + enable_tqdm: Optional[bool] = None, ) -> Generator[Tuple[str, torch.Tensor], None, None]: """Iterate over the weights in the model safetensor files.""" from safetensors import safe_open diff --git a/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py b/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py index 7491aeac..e6a515ee 100644 --- a/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py +++ b/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py @@ -29,10 +29,12 @@ from vllm.config import get_current_vllm_config from vllm.distributed.parallel_state import get_dp_group, get_tensor_model_parallel_world_size from vllm.forward_context import get_forward_context from vllm.logger import init_logger +from vllm.attention.layer import Attention from mindspore import Tensor, Model, mutable from mindspore.common import dtype as msdtype from mindspore.nn.utils import no_init_parameters +from mindspore.common.api import _pynative_executor from mindspore_gs.ptq import PTQ from mindspore_gs.ptq import PTQMode, PTQConfig, OutliersSuppressionType, PrecisionRecovery, QuantGranularity, \ @@ -54,6 +56,13 @@ from vllm_mindspore.model_executor.models.mf_models.mf_model_base import MfModel from vllm_mindspore.model_executor.models.mf_models.deepseekv3_weight_processor import DeepseekV3WeightProcessor from vllm_mindspore.model_executor.models.mf_models.attention_mask import LowerTriangularMask +try: + # Need to apply dllm pd patch on vllm to use pd disagg related functions + from vllm.attention.layer import maybe_save_kv_layer_to_connector +except ImportError: + pass + + logger = init_logger(__name__) @@ -168,6 +177,14 @@ class DeepseekV3ForCausalLM(MfModelBase): key_cache.append(k_cache) return mutable(key_cache), None + def connector_send_kvcache(self): + _pynative_executor.sync() + forward_context = get_forward_context() + for i in range(self.mf_model_config.num_layers): + kv_cache_module = self.kv_caches[i] + kv_cache = kv_cache_module.kv_cache[forward_context.virtual_engine][0] + maybe_save_kv_layer_to_connector("key." 
+ str(i), kv_cache)
+
     def load_weights(self, weights: Iterable[Tuple[str, Tensor]]) -> Set[str]:
         if self.mf_config.load_ckpt_format == "ckpt":
             model = Model(self.network)
diff --git a/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py b/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py
index 3e5dca52..6fb4fecf 100644
--- a/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py
+++ b/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py
@@ -35,6 +35,7 @@ from vllm.attention.backends.abstract import AttentionType
 from vllm.logger import init_logger
 from vllm.attention.layer import Attention
 import vllm.envs as envs
+
 import torch
 import mindspore as ms
 from mindspore import Tensor, mutable
@@ -42,10 +43,18 @@ from mindspore import Tensor, mutable
 from mindformers.tools.register.config import MindFormerConfig
 from mindformers.core.context import build_mf_context
 from mindformers.core.parallel_config import build_parallel_config
-
+from mindspore.common.api import _pynative_executor
 from vllm_mindspore.model_executor.models.model_base import MsModelBase
 from vllm_mindspore.v1.attention.backends.flash_attn import FlashAttentionMetadata
 
+try:
+    # Need to apply dllm pd patch on vllm to use pd disagg related functions
+    from vllm.attention.layer import maybe_save_kv_layer_to_connector, wait_for_kv_layer_from_connector
+    from vllm.distributed.kv_transfer import is_v1_kv_transfer_group
+    kv_transfer_supported = True
+except ImportError:
+    kv_transfer_supported = False
+
 logger = init_logger(__name__)
 
 
@@ -120,6 +129,7 @@ class MfModelBase(MsModelBase):
             vllm_config=vllm_config, prefix=prefix
         )
+        self.kv_transfer_config = vllm_config.kv_transfer_config
         self.mf_config = MindFormerConfig(os.getenv("MINDFORMERS_MODEL_CONFIG"))
         build_mf_context(self.mf_config)
         build_parallel_config(self.mf_config)
@@ -159,6 +169,17 @@ class MfModelBase(MsModelBase):
             value_cache.append(v_cache)
         return mutable(key_cache), mutable(value_cache)
 
+    def is_decoder_task(self) -> bool:
+        if self.kv_transfer_config is None:
+            return False
+
+        return self.kv_transfer_config.is_kv_consumer
+
+    def is_prefill_task(self) -> bool:
+        if self.kv_transfer_config is None:
+            return False
+
+        return self.kv_transfer_config.is_kv_producer
     def _dummy_attention_metadata(self, input_ids: Tensor, positions: Tensor) -> FlashAttentionMetadata:
         input_len = input_ids.shape[0]
@@ -256,6 +277,24 @@ class MfModelBase(MsModelBase):
     def update_model_inputs(self, model_inputs, **kwargs):
         return model_inputs
 
+    def connector_send_kvcache(self):
+        # TODO: this can be optimized
+        _pynative_executor.sync()
+        forward_context = get_forward_context()
+        for i in range(self.mf_model_config.num_layers):
+            kv_cache = self.kv_caches[i]
+            k_cache = kv_cache.kv_cache[forward_context.virtual_engine][0]
+            v_cache = kv_cache.kv_cache[forward_context.virtual_engine][1]
+            maybe_save_kv_layer_to_connector("key." + str(i), (k_cache, v_cache))
+
+
+    def connector_wait_for_kv_layer(self):
+        logger.debug("connector_wait_for_kv_layer")
+        # TODO: this can be optimized
+        for i in range(self.mf_model_config.num_layers):
+            wait_for_kv_layer_from_connector("key."
+ str(i)) + + def forward( self, input_ids: Tensor, @@ -279,7 +318,17 @@ class MfModelBase(MsModelBase): if not self.set_flags: self.network.add_flags_custom(is_first_iteration=False) self.set_flags = True + if kv_transfer_supported: + if is_v1_kv_transfer_group(): + self.connector_send_kvcache() else: + if kv_transfer_supported: + if is_v1_kv_transfer_group() and self.is_prefill_task(): + self.connector_send_kvcache() + + if is_v1_kv_transfer_group() and self.is_decoder_task(): + self.connector_wait_for_kv_layer() + logger.debug(f"connector_wait_for_kv_layer success") hidden_states = self.network(**model_inputs) return hidden_states diff --git a/vllm_mindspore/v1/core/sched/scheduler.py b/vllm_mindspore/v1/core/sched/scheduler.py index c03f3469..11419a33 100644 --- a/vllm_mindspore/v1/core/sched/scheduler.py +++ b/vllm_mindspore/v1/core/sched/scheduler.py @@ -112,6 +112,16 @@ def schedule(self) -> SchedulerOutput: # Get already-cached tokens. computed_blocks, num_computed_tokens = ( self.kv_cache_manager.get_computed_blocks(request)) + logger.info(f"num_computed_tokens:{num_computed_tokens}, computed_blocks:{computed_blocks}") + # Get externally-cached tokens if using a KVConnector. + num_external_tokens = ( + 0 if self.connector is None else + self.connector.get_num_new_matched_tokens( + request, num_computed_tokens)) + + # Total computed tokens (local + external). + num_computed_tokens += num_external_tokens + logger.debug(f"num_computed_tokens:{num_computed_tokens}") num_new_tokens = request.num_prompt_tokens - num_computed_tokens if (0 < self.scheduler_config.long_prefill_token_threshold < num_new_tokens): @@ -148,11 +158,19 @@ def schedule(self) -> SchedulerOutput: continue new_blocks = self.kv_cache_manager.allocate_slots( - request, num_new_tokens, computed_blocks) + request, num_new_tokens + num_external_tokens, + computed_blocks) + logger.info(f"computed_blocks:{computed_blocks}, new_blocks:{new_blocks}") if new_blocks is None: # The request cannot be scheduled. break + if self.connector is not None: + self.connector.update_state_after_alloc( + request, + num_external_tokens, + ) + self.waiting.popleft() self.running.append(request) self.scheduled_req_ids.add(request.request_id) @@ -285,6 +303,7 @@ def schedule(self) -> SchedulerOutput: resumed_from_preemption=False, ) for req in scheduled_running_reqs ] + logger.info(f"req_to_new_block_ids:{req_to_new_block_ids}") scheduler_output = SchedulerOutput( scheduled_new_reqs=new_reqs_data, scheduled_cached_reqs=resumed_reqs_data + running_reqs_data, @@ -303,6 +322,11 @@ def schedule(self) -> SchedulerOutput: grammar_bitmask=None, ) + if self.connector is not None: + meta = self.connector.build_connector_meta(scheduler_output) + logger.info(f"scheduler: new reqs: {scheduler_output.scheduled_new_reqs}, kv connector metadata: {meta}") + scheduler_output.kv_connector_metadata = meta + # Advance the number of computed tokens for the request AFTER # the request is scheduled. # 1. 
The scheduler_output of the current step has to include the
diff --git a/vllm_mindspore/v1/worker/gpu_worker.py b/vllm_mindspore/v1/worker/gpu_worker.py
index 0395c339..0af8b49d 100644
--- a/vllm_mindspore/v1/worker/gpu_worker.py
+++ b/vllm_mindspore/v1/worker/gpu_worker.py
@@ -2,6 +2,7 @@
 """A GPU worker class"""
 import gc
+import importlib.util
 import torch
 from vllm.logger import init_logger
 from vllm.distributed.parallel_state import get_pp_group
@@ -31,9 +32,16 @@ def init_device(self):
         self.init_gpu_memory = torch.cuda.mem_get_info()[0]
 
         # Initialize the distributed environment.
-        init_worker_distributed_environment(self.parallel_config, self.rank,
-                                            self.distributed_init_method,
-                                            self.local_rank)
+        if importlib.util.find_spec("vllm.distributed.kv_transfer.kv_transfer_state") is not None:
+            # Module found: the DLLM PD patch has been applied to vllm.
+            init_worker_distributed_environment(self.vllm_config, self.rank,
+                                                self.distributed_init_method,
+                                                self.local_rank)
+        else:
+            # Module not found: patch not applied, use the stock vllm signature.
+            init_worker_distributed_environment(self.parallel_config, self.rank,
+                                                self.distributed_init_method,
+                                                self.local_rank)
 
         # Set random seed.
         set_random_seed(self.model_config.seed)
diff --git a/vllm_mindspore/worker/worker.py b/vllm_mindspore/worker/worker.py
index 8ce1bc91..2dc69fcd 100644
--- a/vllm_mindspore/worker/worker.py
+++ b/vllm_mindspore/worker/worker.py
@@ -26,7 +26,6 @@ import torch
 from vllm.config import VllmConfig
 
 from vllm.distributed import (
-    ensure_kv_transfer_initialized,
     ensure_model_parallel_initialized,
     init_distributed_environment,
     set_custom_all_reduce,
-- 
Gitee
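
Usage sketch (illustrative only, not part of the patch): once this patch is applied and the dllm package providing dllm.dkvc.v1.dllm_ds_connector is installed, the DLLMDsConnector registered in vllm_mindspore/__init__.py is selected through vLLM's standard KVTransferConfig / --kv-transfer-config path. The device_ids values below are placeholder assumptions; kv_connector, kv_role and kv_connector_extra_config["device_ids"] are the fields this patch actually reads.

    # Minimal sketch, assuming the patched vllm + vllm_mindspore and the dllm
    # connector package are installed; device ids are illustrative only.
    from vllm.config import KVTransferConfig

    # Prefill instance: produces KV blocks and pushes them through the connector.
    prefill_cfg = KVTransferConfig(
        kv_connector="DLLMDsConnector",
        kv_role="kv_producer",
        kv_connector_extra_config={"device_ids": [0, 1, 2, 3]},
    )

    # Decode instance: consumes the KV blocks written by the prefill instance.
    decode_cfg = KVTransferConfig(
        kv_connector="DLLMDsConnector",
        kv_role="kv_consumer",
        kv_connector_extra_config={"device_ids": [4, 5, 6, 7]},
    )

    # These are the properties MfModelBase.is_prefill_task() / is_decoder_task()
    # consult before sending or waiting for KV layers in forward().
    assert prefill_cfg.is_kv_producer
    assert decode_cfg.is_kv_consumer

The same JSON-shaped values can be passed on the command line via --kv-transfer-config; Worker.__init__ additionally indexes kv_connector_extra_config["device_ids"] by rank to pick the local device, so the list should have one entry per worker rank.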