diff --git a/vllm_mindspore/v1/attention/backends/flash_attn.py b/vllm_mindspore/v1/attention/backends/flash_attn.py
index b5c5629ee51fc7faf969f18a7b596e60d939387f..03c976df841e532ae3c59e85e75f7645ab86e318 100644
--- a/vllm_mindspore/v1/attention/backends/flash_attn.py
+++ b/vllm_mindspore/v1/attention/backends/flash_attn.py
@@ -10,13 +10,8 @@
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType)
 from vllm.logger import init_logger
-
-from vllm_mindspore.utils import MsKVCache
-
 import mindspore as ms
-from mindspore import mutable
-from mindspore._c_expression import swap_cache
-
+from mindspore.common.api import _pynative_executor
 
 logger = init_logger(__name__)
 
@@ -42,7 +37,7 @@ class FlashAttentionBackend(AttentionBackend):
         return FlashAttentionMetadata
 
     @staticmethod
-    def get_builder_cls() -> Type["AttentionMetadataBuilder"]:
+    def get_builder_cls():
         return FlashAttentionMetadataBuilder
 
     @staticmethod
@@ -62,6 +57,7 @@ class FlashAttentionBackend(AttentionBackend):
 
 
 class MLABackend(AttentionBackend):
+
     @staticmethod
     def get_name() -> str:
         return "MS_MLA"
@@ -193,38 +189,45 @@ class MsAttentionImpl(AttentionImpl):
 
 
 class FlashAttentionMetadataBuilder:
-    def __init__(self, runner: "GPUModelRunner"):
+
+    def __init__(self, runner):
         self.runner = runner
 
-    def reorder_batch(self, input_batch: "InputBatch",
-                      scheduler_output: "SchedulerOutput") -> bool:
+    def reorder_batch(self, input_batch,
+                      scheduler_output) -> bool:
         return False
 
     def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
               common_prefix_len: int):
         # do not manually call 'tensor.move_to("Ascend", blocking=False)' here,
         # because it will cause a certain amount of host time.
-        query_start_loc = ms.from_numpy(self.runner.query_start_loc_np[:num_reqs + 1])
-        max_context_lens = self.runner.input_batch.num_computed_tokens_cpu[:num_reqs].max()
-        slot_mapping = ms.from_numpy(self.runner.slot_mapping_np[:num_actual_tokens])
+        query_start_loc = ms.from_numpy(
+            self.runner.query_start_loc_np[:num_reqs + 1])
+        max_context_lens = self.runner.input_batch.num_computed_tokens_cpu[:
+                                                                            num_reqs].max(
+                                                                            )
+        slot_mapping = ms.from_numpy(
+            self.runner.slot_mapping_np[:num_actual_tokens])
         seq_lens_np = self.runner.seq_lens_np[:num_reqs]
         max_seq_len = seq_lens_np.max()
         seq_lens = ms.from_numpy(seq_lens_np)
-        context_lens = ms.from_numpy(self.runner.input_batch.num_computed_tokens_cpu[:num_reqs])
+        context_lens = ms.from_numpy(
+            self.runner.input_batch.num_computed_tokens_cpu[:num_reqs])
         q_seq_lens_np = np.diff(self.runner.query_start_loc_np[:num_reqs + 1])
         q_seq_lens = ms.from_numpy(q_seq_lens_np)
 
+        _pynative_executor.sync()
         attn_metadata = FlashAttentionMetadata(
             seq_lens=seq_lens,
             seq_lens_np=seq_lens_np,
-            block_tables=(self.runner.input_batch.block_table.get_device_tensor()[:num_reqs]),
+            block_tables=(self.runner.input_batch.block_table.
+                          get_device_tensor()[:num_reqs]),
             slot_mapping=slot_mapping,
             q_seq_lens=q_seq_lens,
             q_seq_lens_np=q_seq_lens_np,
             max_seq_len=max_seq_len,
             context_lens=context_lens,
             max_context_lens=max_context_lens,
-            query_start_loc = query_start_loc
-        )
+            query_start_loc = query_start_loc)
 
         return attn_metadata