From 292c2c65acedb7d71d7375d7dc21a9d95b0a1c31 Mon Sep 17 00:00:00 2001
From: withHades <244036962@qq.com>
Date: Thu, 4 Sep 2025 22:44:37 +0800
Subject: [PATCH] [Feat] Support auto dispatch of paged attention during
 aclgraph running

Signed-off-by: withHades <244036962@qq.com>
---
 test/npu/test_aclgraph_update.py | 155 ++++++++++++++++++++++++++++++-
 torch_npu/npu/graphs.py          |  31 ++++++-
 2 files changed, 182 insertions(+), 4 deletions(-)

diff --git a/test/npu/test_aclgraph_update.py b/test/npu/test_aclgraph_update.py
index 7db212734f3..9214b79190f 100644
--- a/test/npu/test_aclgraph_update.py
+++ b/test/npu/test_aclgraph_update.py
@@ -1,6 +1,9 @@
 import unittest
+from dataclasses import dataclass
 from itertools import chain
+import random
 
+import numpy as np
 import torch
 import torch_npu
@@ -8,7 +11,7 @@ from torch_npu.testing.common_utils import SupportedDevices
 from torch_npu.testing.testcase import TestCase, run_tests
 
 
-class TestAclgraphUpdate(TestCase):
+class TestIFAAclgraphUpdate(TestCase):
 
     @SupportedDevices(['Ascend910B'])
     def test_ifa_update(self):
@@ -170,5 +173,155 @@ class TestAclgraphUpdate(TestCase):
         self.assertEqual(output.cpu(), res_src[0].cpu())
         self.assertEqual(softmax_lse.cpu(), res_src[1].cpu())
 
+
+@dataclass
+class PAAttentionParamsNumpy:
+    query: np.ndarray
+    key_cache: np.ndarray
+    value_cache: np.ndarray
+    block_table: np.ndarray
+    context_lens: np.ndarray
+
+
+@dataclass
+class PAAttentionParamsTensor:
+    query: torch.Tensor
+    key_cache: torch.Tensor
+    value_cache: torch.Tensor
+    block_table: torch.Tensor
+    context_lens: torch.Tensor
+    output: torch.Tensor
+
+
+class TestPAAclgraphUpdate(TestCase):
+    num_blocks = 64
+    num_tokens = 2
+    block_size = 128
+    kv_heads = 16
+    head_size = 288
+    num_heads = 32
+    head_size_v = 96
+    scale = 0.38888
+
+    def group_matmul(self, head, kv_head, A, B):
+        group_num = head // kv_head
+        score = []
+        for i in range(kv_head):
+            group_A = A[i * group_num: (i + 1) * group_num]
+            group_B = B[i: i + 1]
+            score.append(np.matmul(group_A, group_B))
+        return np.concatenate(score, axis=0)
+
+    def ref_masked_attention(self, query, key, value):
+        """Reference attention computation."""
+        # Rearrange to [num_heads, seq_len, head_size]
+        query = query * self.scale
+        query = query.transpose(1, 0, 2)
+        key = key.transpose(1, 2, 0)
+
+        # Compute QK^T
+        sim = self.group_matmul(query.shape[0], key.shape[0], query, key)
+
+        # Softmax normalization
+        sim = sim - np.max(sim, axis=-1, keepdims=True)
+        exp_sim = np.exp(sim.astype(np.float32))
+        p = exp_sim / np.sum(exp_sim, axis=-1, keepdims=True)
+        p = p.astype(np.float16)
+
+        # Weight the values by the attention probabilities
+        value = value.transpose(1, 0, 2)
+        out = self.group_matmul(p.shape[0], key.shape[0], p, value)
+        return out.transpose(1, 0, 2)
+
+    def golden_attention_impl(self, params_np):
+        output = np.zeros((self.num_tokens, self.num_heads, self.head_size_v), dtype=np.float16)
+
+        for i in range(self.num_tokens):
+            # Gather the K/V of the current sequence from the caches
+            seq_blocks = params_np.block_table[i]
+            context_len = params_np.context_lens[i]
+
+            keys = []
+            values = []
+            for pos in range(context_len):
+                block_id = seq_blocks[pos // self.block_size]
+                offset = pos % self.block_size
+                keys.append(params_np.key_cache[block_id, offset].reshape(self.kv_heads, -1))
+                values.append(params_np.value_cache[block_id, offset].reshape(self.kv_heads, -1))
+
+            # Run the reference attention computation
+            out = self.ref_masked_attention(
+                params_np.query[i:i + 1],
+                np.stack(keys),
+                np.stack(values)
+            )
+            output[i] = out.reshape(self.num_heads, -1)
+        return output
+
+    def preprocess(self):
+        """Generate test input data."""
+        query_np = np.random.uniform(-1, 1, (self.num_tokens, self.num_heads, self.head_size)).astype(np.float16)
+        key_cache_np = np.random.uniform(-1, 1, (self.num_blocks, self.block_size, self.kv_heads, self.head_size)).astype(np.float16)
+        value_cache_np = np.random.uniform(-1, 1, (self.num_blocks, self.block_size, self.kv_heads, self.head_size_v)).astype(np.float16)
+        max_blocks_per_seq = (1024 + self.block_size - 1) // self.block_size
+        block_table_np = np.array([
+            [random.randint(0, self.num_blocks - 1) for _ in range(max_blocks_per_seq)]
+            for _ in range(self.num_tokens)
+        ], dtype=np.int32)
+        context_lens_np = np.full(self.num_tokens, random.randint(1, 1024), dtype=np.int32)
+        params_np = PAAttentionParamsNumpy(query_np, key_cache_np, value_cache_np, block_table_np, context_lens_np)
+        golden_output = self.golden_attention_impl(params_np)
+        golden_output = torch.from_numpy(golden_output)
+
+        query = torch.from_numpy(query_np).npu()
+        key_cache = torch.from_numpy(key_cache_np).npu()
+        value_cache = torch.from_numpy(value_cache_np).npu()
+        block_table = torch.from_numpy(block_table_np).npu()
+        context_lens = torch.from_numpy(context_lens_np)  # stays on the CPU so graph.update() can refresh it
+        output = torch.zeros_like(query[:, :, :self.head_size_v]).npu()
+        params_tensor = PAAttentionParamsTensor(query, key_cache, value_cache, block_table, context_lens, output)
+        return params_tensor, golden_output
+
+    def atb_paged_attention(self, params):
+        torch_npu._npu_paged_attention(
+            query=params.query,
+            key_cache=params.key_cache,
+            value_cache=params.value_cache,
+            num_kv_heads=self.kv_heads,
+            num_heads=self.num_heads,
+            scale_value=self.scale,
+            block_table=params.block_table,
+            context_lens=params.context_lens,
+            out=params.output,
+        )
+        return params.output
+
+    @SupportedDevices(['Ascend910B'])
+    def test_paged_attention_aclgraph_update(self):
+        params, golden_output = self.preprocess()
+        output = None
+
+        # Capture the paged-attention call into an NPU graph
+        graph = torch.npu.NPUGraph()
+        with torch.npu.graph(graph,
+                             stream=torch.npu.Stream(),
+                             pool=None,
+                             auto_dispatch_capture=True):
+            output = self.atb_paged_attention(params)
+
+        graph.update(cpu_update_input=[{"context_lens": params.context_lens}])
+        graph.replay()
+        torch.npu.synchronize()
+        self.assertRtolEqual(output, golden_output)
+
+        params_new, golden_output = self.preprocess()
+        params.query.copy_(params_new.query)
+        params.key_cache.copy_(params_new.key_cache)
+        params.value_cache.copy_(params_new.value_cache)
+        params.block_table.copy_(params_new.block_table)
+        graph.update(cpu_update_input=[{"context_lens": params_new.context_lens}])
+        graph.replay()
+        torch.npu.synchronize()
+        self.assertRtolEqual(output, golden_output)
+
 if __name__ == "__main__":
     run_tests()
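The user-facing flow the new test exercises boils down to the sketch below. It is a minimal distillation of test_paged_attention_aclgraph_update, not an independently verified API reference: the shapes, scale, and the _npu_paged_attention call all mirror the test, and context_lens deliberately stays on the CPU so graph.update() can refresh it between replays.

    import torch
    import torch_npu

    num_tokens, num_heads, kv_heads = 2, 32, 16
    num_blocks, block_size, head_size, head_size_v = 64, 128, 288, 96

    query = torch.rand(num_tokens, num_heads, head_size).half().npu()
    key_cache = torch.rand(num_blocks, block_size, kv_heads, head_size).half().npu()
    value_cache = torch.rand(num_blocks, block_size, kv_heads, head_size_v).half().npu()
    block_table = torch.randint(0, num_blocks, (num_tokens, 8), dtype=torch.int32).npu()
    context_lens = torch.full((num_tokens,), 512, dtype=torch.int32)  # CPU tensor
    out = torch.zeros(num_tokens, num_heads, head_size_v).half().npu()

    # Capture: auto_dispatch_capture records the op so it can be updated later.
    graph = torch.npu.NPUGraph()
    with torch.npu.graph(graph, stream=torch.npu.Stream(), pool=None,
                         auto_dispatch_capture=True):
        torch_npu._npu_paged_attention(
            query=query, key_cache=key_cache, value_cache=value_cache,
            num_kv_heads=kv_heads, num_heads=num_heads, scale_value=0.38888,
            block_table=block_table, context_lens=context_lens, out=out)

    # Replay with refreshed host-side lengths: hand the new CPU tensor to
    # update(), which re-issues the captured op, then replay the graph.
    context_lens_new = torch.full((num_tokens,), 128, dtype=torch.int32)
    graph.update(cpu_update_input=[{"context_lens": context_lens_new}])
    graph.replay()
    torch.npu.synchronize()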
diff --git a/torch_npu/npu/graphs.py b/torch_npu/npu/graphs.py
index 7e21ce5ed9a..b7b312db64a 100644
--- a/torch_npu/npu/graphs.py
+++ b/torch_npu/npu/graphs.py
@@ -111,9 +111,20 @@ class _GraphDispatchMode(torch.utils._python_dispatch.TorchDispatchMode):
         with torch.npu.stream(self.update_stream):
             for graph_dispatch_record, update_input in zip(self.graph_dispatch_records, cpu_update_input):
                 graph_task_update_begin(self.update_stream, graph_dispatch_record.handle)
-                for key in update_input:
-                    graph_dispatch_record.kwargs[key] = update_input[key]
-                graph_dispatch_record.op_cache_entry(*graph_dispatch_record.args, **graph_dispatch_record.kwargs)
+                if graph_dispatch_record.op_cache_entry.__name__ in ["_npu_paged_attention.default", "_npu_paged_attention"]:
+                    args = list(graph_dispatch_record.args)
+                    # When the parameters were passed positionally, context_lens is the second-to-last argument.
+                    if len(args) >= 2:
+                        args[-2] = update_input["context_lens"]
+                        graph_dispatch_record.op_cache_entry(*args)
+                    else:
+                        for key in update_input:
+                            graph_dispatch_record.kwargs[key] = update_input[key]
+                        graph_dispatch_record.op_cache_entry(*graph_dispatch_record.args, **graph_dispatch_record.kwargs)
+                elif graph_dispatch_record.op_cache_entry.__name__ in ["npu_fused_infer_attention_score", "npu_fused_infer_attention_score.out"]:
+                    for key in update_input:
+                        graph_dispatch_record.kwargs[key] = update_input[key]
+                    graph_dispatch_record.op_cache_entry(*graph_dispatch_record.args, **graph_dispatch_record.kwargs)
                 graph_task_update_end(self.update_stream)
                 graph_dispatch_record.event.record(self.update_stream)
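The fast path in this hunk leans on the calling convention of _npu_paged_attention as used in the test: for a fully positional call the recorded args end with (..., block_table, context_lens, out), so context_lens sits at index -2; keyword calls fall through to the kwargs update instead. A free-standing sketch of that rule (the helper name and the tuple layout are illustrative, not part of this patch):

    def patch_context_lens(args, kwargs, new_context_lens):
        # Positional call: context_lens is the second-to-last argument.
        if len(args) >= 2:
            args = list(args)
            args[-2] = new_context_lens
            return tuple(args), kwargs
        # Keyword call: refresh through kwargs instead.
        return args, {**kwargs, "context_lens": new_context_lens}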
@@ -170,6 +181,20 @@ class _GraphDispatchMode(torch.utils._python_dispatch.TorchDispatchMode):
             self.graph_dispatch_records.append(
                 self._append_dispatch_record(event, handle, args, kwargs, func))
             return kwargs["out"]
+        elif func.__name__ in ["_npu_paged_attention.default", "_npu_paged_attention"]:
+            self.update_schema(str(func.__name__), str(func._schema))
+            stream = torch_npu.npu.current_stream()
+            event = torch.npu.ExternalEvent()
+            event.wait(stream)
+            event.reset(stream)
+            # Begin the graph task group so this launch can be patched by update()
+            graph_task_group_begin(stream)
+            func(*args, **kwargs)
+            handle = graph_task_group_end(stream)
+            # Save the dispatch state needed for a later update
+            self.graph_dispatch_records.append(
+                self._append_dispatch_record(event, handle, args, kwargs, func))
+            return None
         else:
             return func(*args, **kwargs)
--
Gitee
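One usage note that follows from the first graphs.py hunk: update() zips graph_dispatch_records with cpu_update_input, so callers should pass one dict per auto-dispatched op, ordered exactly as the ops were captured. A partial sketch for a graph that captured two paged-attention layers (the graph and the context_lens tensors are hypothetical, set up as in the test above):

    graph.update(cpu_update_input=[
        {"context_lens": context_lens_layer0},  # first captured op
        {"context_lens": context_lens_layer1},  # second captured op
    ])
    graph.replay()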