From 63750877297e8d0ec58da748a40804e6fb4e7877 Mon Sep 17 00:00:00 2001 From: HuangLi Date: Sat, 17 Aug 2024 16:01:59 +0800 Subject: [PATCH] fix FIA ut --- test/test_fake_tensor.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/test/test_fake_tensor.py b/test/test_fake_tensor.py index 17589f2e2a2..dd590bf0b2d 100644 --- a/test/test_fake_tensor.py +++ b/test/test_fake_tensor.py @@ -1265,9 +1265,9 @@ class TestNpuTranspose(TestCase): class TestPromptFlashAttention(TestCase): def testPromptFlashAttention(self): with FakeTensorMode(): - q = torch.randn(1, 40, 1, 128, dtype=torch.float16).npu() - k = torch.randn(1, 40, 1, 128, dtype=torch.float16).npu() - v = torch.randn(1, 40, 1, 128, dtype=torch.float16).npu() + q = torch.randn(1, 1024, 1024, dtype=torch.float16).npu() + k = torch.randn(1, 1024, 1024, dtype=torch.float16).npu() + v = torch.randn(1, 1024, 1024, dtype=torch.float16).npu() q.requires_grad = True k.requires_grad = True v.requires_grad = True @@ -1281,9 +1281,9 @@ class TestPromptFlashAttention(TestCase): class TestFusedInferAttentionScore(TestCase): def testFusedInferAttentionScore(self): with FakeTensorMode(): - q = torch.randn(1, 40, 1, 128, dtype=torch.float16).npu() - k = torch.randn(1, 40, 1, 128, dtype=torch.float16).npu() - v = torch.randn(1, 40, 1, 128, dtype=torch.float16).npu() + q = torch.randn(1, 1024, 1024, dtype=torch.float16).npu() + k = torch.randn(1, 1024, 1024, dtype=torch.float16).npu() + v = torch.randn(1, 1024, 1024, dtype=torch.float16).npu() q.requires_grad = True k.requires_grad = True v.requires_grad = True -- Gitee