From a5738dfb2e236527e6829ec062d5d993ed9a5b21 Mon Sep 17 00:00:00 2001
From: zhang_xu_hao1230
Date: Fri, 4 Jul 2025 14:59:12 +0800
Subject: [PATCH] Add synchronization, fix pa error
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
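Reviewer note, kept out of the commit message: the functional change below is
the `_pynative_executor.sync()` call added before FlashAttentionMetadata is
assembled; the import removals and reformatting are cosmetic. A minimal sketch
of the pattern, assuming the pa error stems from wrapping host-side numpy
buffers while asynchronous PyNative ops are still in flight (`host_buf` is an
illustrative stand-in, not a name from this repo):

    import numpy as np
    import mindspore as ms
    from mindspore.common.api import _pynative_executor

    host_buf = np.zeros(8, dtype=np.int32)  # host staging buffer, e.g. a slot mapping
    # ... previously launched PyNative ops may still be running here ...
    _pynative_executor.sync()         # drain the async PyNative op queue
    tensor = ms.from_numpy(host_buf)  # safe: host data is now stable
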
 .../v1/attention/backends/flash_attn.py      | 37 ++++++++++---------
 1 file changed, 20 insertions(+), 17 deletions(-)

diff --git a/vllm_mindspore/v1/attention/backends/flash_attn.py b/vllm_mindspore/v1/attention/backends/flash_attn.py
index b5c5629e..03c976df 100644
--- a/vllm_mindspore/v1/attention/backends/flash_attn.py
+++ b/vllm_mindspore/v1/attention/backends/flash_attn.py
@@ -10,13 +10,8 @@
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType)
 from vllm.logger import init_logger
-
-from vllm_mindspore.utils import MsKVCache
-
 import mindspore as ms
-from mindspore import mutable
-from mindspore._c_expression import swap_cache
-
+from mindspore.common.api import _pynative_executor
 
 logger = init_logger(__name__)
 
@@ -42,7 +37,7 @@ class FlashAttentionBackend(AttentionBackend):
         return FlashAttentionMetadata
 
     @staticmethod
-    def get_builder_cls() -> Type["AttentionMetadataBuilder"]:
+    def get_builder_cls():
         return FlashAttentionMetadataBuilder
 
     @staticmethod
@@ -62,6 +57,7 @@ class FlashAttentionBackend(AttentionBackend):
 
 
 class MLABackend(AttentionBackend):
+
     @staticmethod
     def get_name() -> str:
         return "MS_MLA"
@@ -193,38 +189,45 @@ class MsAttentionImpl(AttentionImpl):
 
 
 class FlashAttentionMetadataBuilder:
-    def __init__(self, runner: "GPUModelRunner"):
+
+    def __init__(self, runner):
         self.runner = runner
 
-    def reorder_batch(self, input_batch: "InputBatch",
-                      scheduler_output: "SchedulerOutput") -> bool:
+    def reorder_batch(self, input_batch,
+                      scheduler_output) -> bool:
         return False
 
     def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
               common_prefix_len: int):
         # do not manually call 'tensor.move_to("Ascend", blocking=False)' here,
         # because it will cause a certain amount of host time.
-        query_start_loc = ms.from_numpy(self.runner.query_start_loc_np[:num_reqs + 1])
-        max_context_lens = self.runner.input_batch.num_computed_tokens_cpu[:num_reqs].max()
-        slot_mapping = ms.from_numpy(self.runner.slot_mapping_np[:num_actual_tokens])
+        query_start_loc = ms.from_numpy(
+            self.runner.query_start_loc_np[:num_reqs + 1])
+        max_context_lens = self.runner.input_batch.num_computed_tokens_cpu[:
+                                                                           num_reqs].max(
+                                                                           )
+        slot_mapping = ms.from_numpy(
+            self.runner.slot_mapping_np[:num_actual_tokens])
         seq_lens_np = self.runner.seq_lens_np[:num_reqs]
         max_seq_len = seq_lens_np.max()
         seq_lens = ms.from_numpy(seq_lens_np)
-        context_lens = ms.from_numpy(self.runner.input_batch.num_computed_tokens_cpu[:num_reqs])
+        context_lens = ms.from_numpy(
+            self.runner.input_batch.num_computed_tokens_cpu[:num_reqs])
         q_seq_lens_np = np.diff(self.runner.query_start_loc_np[:num_reqs + 1])
         q_seq_lens = ms.from_numpy(q_seq_lens_np)
+        _pynative_executor.sync()
 
         attn_metadata = FlashAttentionMetadata(
             seq_lens=seq_lens,
             seq_lens_np=seq_lens_np,
-            block_tables=(self.runner.input_batch.block_table.get_device_tensor()[:num_reqs]),
+            block_tables=(self.runner.input_batch.block_table.
+                          get_device_tensor()[:num_reqs]),
             slot_mapping=slot_mapping,
             q_seq_lens=q_seq_lens,
             q_seq_lens_np=q_seq_lens_np,
             max_seq_len=max_seq_len,
             context_lens=context_lens,
             max_context_lens=max_context_lens,
-            query_start_loc = query_start_loc
-        )
+            query_start_loc = query_start_loc)
 
         return attn_metadata
-- 
Gitee