diff --git a/vllm_mindspore/__init__.py b/vllm_mindspore/__init__.py
index eebc2cc6c793876224fdbb97ee7de2f8bb90e21c..700e0618ea5f424c03660d75b855d6cf1389b79e 100644
--- a/vllm_mindspore/__init__.py
+++ b/vllm_mindspore/__init__.py
@@ -138,6 +138,8 @@ from vllm_mindspore.worker.model_runner import (
     _get_cuda_graph_pad_size,
     _dummy_run,
     _get_supported_attention_backends,
+    need_recv_kv,
+    need_send_kv,
 )
 
 vllm.worker.model_runner.ModelInputForGPUBuilder._get_cuda_graph_pad_size = (
@@ -145,6 +147,10 @@ vllm.worker.model_runner.ModelInputForGPUBuilder._get_cuda_graph_pad_size = (
 )
 vllm.worker.model_runner.GPUModelRunnerBase._dummy_run = _dummy_run
 
+vllm.worker.model_runner.ModelRunner.need_recv_kv = need_recv_kv
+
+vllm.worker.model_runner.ModelRunner.need_send_kv = need_send_kv
+
 import vllm.worker.multi_step_model_runner
 
 vllm.worker.multi_step_model_runner._get_supported_attention_backends = (
diff --git a/vllm_mindspore/worker/model_runner.py b/vllm_mindspore/worker/model_runner.py
index 561fd2021dd7c84764a04aaa1b3b06389f720b55..44f3bd6c23ee0723eb78729760cb12180186b24b 100644
--- a/vllm_mindspore/worker/model_runner.py
+++ b/vllm_mindspore/worker/model_runner.py
@@ -178,4 +178,58 @@ def _get_supported_attention_backends(chunked_prefill_enabled: bool) \
     if chunked_prefill_enabled:
         return MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS
     else:
-        return MULTI_STEP_ATTENTION_BACKENDS
\ No newline at end of file
+        return MULTI_STEP_ATTENTION_BACKENDS
+
+def need_recv_kv(self, model_input, kv_caches) -> bool:
+    """Check if we need to receive kv-cache from the other worker.
+    We need to receive KV when
+        1. current vLLM instance is KV cache consumer/decode vLLM instance
+        2. this batch is not a profiling run
+        3. this batch is a prefill run
+
+    Args:
+        model_input: input to the model executable
+        kv_caches: vLLM's paged memory
+    """
+
+    if self.vllm_config.kv_transfer_config is None:
+        return False
+
+    prefill_meta = model_input.attn_metadata.prefill_metadata
+
+    # check if the current run is profiling
+    # original vllm line cannot work with ms tuple;
+    # is_profile_run = (kv_caches[0].numel() == 0)
+    is_profile_run = (model_input.input_tokens.sum(dim=-1) == 0)
+    # check if the current run is prefill
+    is_prefill_run = prefill_meta is not None
+
+    return self.vllm_config.kv_transfer_config.is_kv_consumer and (
+        not is_profile_run) and is_prefill_run
+
+def need_send_kv(self, model_input, kv_caches) -> bool:
+    """Check if we need to send kv-cache to the other worker.
+    We need to send KV when
+        1. current vLLM instance is KV cache producer/prefill vLLM instance
+        2. this batch is not a profiling run
+        3. this batch is a prefill run
+
+    Args:
+        model_input: input to the model executable
+        kv_caches: vLLM's paged memory
+    """
+
+    if self.vllm_config.kv_transfer_config is None:
+        return False
+
+    prefill_meta = model_input.attn_metadata.prefill_metadata
+
+    # check if the current run is profiling
+    # original vllm line cannot work with ms tuple;
+    # is_profile_run = (kv_caches[0].numel() == 0)
+    is_profile_run = (model_input.input_tokens.sum(dim=-1) == 0)
+    # check if the current run is prefill
+    is_prefill_run = prefill_meta is not None
+
+    return self.vllm_config.kv_transfer_config.is_kv_producer and (
+        not is_profile_run) and is_prefill_run
\ No newline at end of file
diff --git a/vllm_mindspore/worker/worker.py b/vllm_mindspore/worker/worker.py
index 8ce1bc91d511a43a83fd3c8b0e70d228b98b951b..8c9d4422ab5e18c0946523c916ad7f16d9efa658 100644
--- a/vllm_mindspore/worker/worker.py
+++ b/vllm_mindspore/worker/worker.py
@@ -96,4 +96,4 @@ def _warm_up_model(self) -> None:
 
     # Reset the seed to ensure that the random state is not affected by
     # the model initialization and profiling.
-    set_random_seed(self.model_config.seed)
+    set_random_seed(self.model_config.seed)
\ No newline at end of file
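
Note on the patch above: `need_recv_kv`/`need_send_kv` are monkey-patched onto `vllm.worker.model_runner.ModelRunner`, and the profiling-run check is changed because `kv_caches` is a tuple of MindSpore tensors here, so vLLM's `kv_caches[0].numel() == 0` test does not apply; the patch sums `model_input.input_tokens` instead. For reference, the added gating reduces to a pair of pure predicates. The sketch below uses hypothetical stand-in names (`KVTransferConfig`, `should_send_kv`, `should_recv_kv`), not the real vLLM/MindSpore classes, and only illustrates the producer/consumer logic under those assumptions.

```python
# Minimal sketch of the KV-transfer gating, with hypothetical stand-in types.
from dataclasses import dataclass
from typing import Optional


@dataclass
class KVTransferConfig:  # stand-in for vllm_config.kv_transfer_config
    is_kv_producer: bool = False
    is_kv_consumer: bool = False


def should_send_kv(cfg: Optional[KVTransferConfig],
                   is_profile_run: bool, is_prefill_run: bool) -> bool:
    # Only a producer (prefill) instance sends, and only on a real prefill step.
    if cfg is None:
        return False
    return cfg.is_kv_producer and not is_profile_run and is_prefill_run


def should_recv_kv(cfg: Optional[KVTransferConfig],
                   is_profile_run: bool, is_prefill_run: bool) -> bool:
    # Only a consumer (decode) instance receives, and only on a real prefill step.
    if cfg is None:
        return False
    return cfg.is_kv_consumer and not is_profile_run and is_prefill_run


if __name__ == "__main__":
    producer = KVTransferConfig(is_kv_producer=True)
    consumer = KVTransferConfig(is_kv_consumer=True)
    assert should_send_kv(producer, is_profile_run=False, is_prefill_run=True)
    assert should_recv_kv(consumer, is_profile_run=False, is_prefill_run=True)
    # Profiling runs and a missing kv_transfer_config never transfer KV.
    assert not should_send_kv(producer, is_profile_run=True, is_prefill_run=True)
    assert not should_recv_kv(None, is_profile_run=False, is_prefill_run=True)
```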