diff --git a/test/test_fake_tensor.py b/test/test_fake_tensor.py
index 17589f2e2a2d767f98770eedf5283c2defe29572..dd590bf0b2d7126b3d1b0704bc8d020f9c1ff14c 100644
--- a/test/test_fake_tensor.py
+++ b/test/test_fake_tensor.py
@@ -1265,9 +1265,9 @@ class TestNpuTranspose(TestCase):
 class TestPromptFlashAttention(TestCase):
     def testPromptFlashAttention(self):
         with FakeTensorMode():
-            q = torch.randn(1, 40, 1, 128, dtype=torch.float16).npu()
-            k = torch.randn(1, 40, 1, 128, dtype=torch.float16).npu()
-            v = torch.randn(1, 40, 1, 128, dtype=torch.float16).npu()
+            q = torch.randn(1, 1024, 1024, dtype=torch.float16).npu()
+            k = torch.randn(1, 1024, 1024, dtype=torch.float16).npu()
+            v = torch.randn(1, 1024, 1024, dtype=torch.float16).npu()
             q.requires_grad = True
             k.requires_grad = True
             v.requires_grad = True
@@ -1281,9 +1281,9 @@ class TestPromptFlashAttention(TestCase):
 class TestFusedInferAttentionScore(TestCase):
     def testFusedInferAttentionScore(self):
         with FakeTensorMode():
-            q = torch.randn(1, 40, 1, 128, dtype=torch.float16).npu()
-            k = torch.randn(1, 40, 1, 128, dtype=torch.float16).npu()
-            v = torch.randn(1, 40, 1, 128, dtype=torch.float16).npu()
+            q = torch.randn(1, 1024, 1024, dtype=torch.float16).npu()
+            k = torch.randn(1, 1024, 1024, dtype=torch.float16).npu()
+            v = torch.randn(1, 1024, 1024, dtype=torch.float16).npu()
             q.requires_grad = True
             k.requires_grad = True
             v.requires_grad = True
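
Context for the hunks above: both tests move the q/k/v inputs from 4-D (1, 40, 1, 128) tensors to 3-D (1, 1024, 1024) tensors, i.e. from a BNSD-style layout (batch, num_heads, seq_len, head_dim) to a BSH-style layout (batch, seq_len, hidden_size). The sketch below illustrates how such 3-D inputs are typically consumed; both hunks are truncated before the op call, so the call site, num_heads, and input_layout values below are assumptions for illustration, not part of this patch.

# Minimal sketch, assuming the test goes on to call
# torch_npu.npu_prompt_flash_attention (the patch cuts off before the
# call). Under FakeTensorMode no NPU kernel executes; only the op's
# registered fake/meta shape-and-dtype propagation is exercised.
import torch
import torch_npu
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    # 3-D inputs correspond to the "BSH" layout: (batch, seq_len, hidden).
    q = torch.randn(1, 1024, 1024, dtype=torch.float16).npu()
    k = torch.randn(1, 1024, 1024, dtype=torch.float16).npu()
    v = torch.randn(1, 1024, 1024, dtype=torch.float16).npu()
    # num_heads and input_layout here are illustrative assumptions,
    # not values taken from this patch.
    out = torch_npu.npu_prompt_flash_attention(
        q, k, v, num_heads=8, input_layout="BSH"
    )
    # The fake output is expected to mirror the query's shape and dtype.
    assert out.shape == q.shape
    assert out.dtype == torch.float16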