From 7b100fd0e9cad8f49a4672a7b7c9c43417416c9b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=90=B4=E9=AA=86=E5=A6=83?= Date: Mon, 17 Jun 2024 16:02:33 +0800 Subject: [PATCH] ifa fix ut bug --- test/custom_ops/test_incre_flash_attention.py | 1 + .../custom_ops/test_prompt_flash_attention.py | 24 ++++++++++++------- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/test/custom_ops/test_incre_flash_attention.py b/test/custom_ops/test_incre_flash_attention.py index e349fe49db0..7285f072a1f 100644 --- a/test/custom_ops/test_incre_flash_attention.py +++ b/test/custom_ops/test_incre_flash_attention.py @@ -43,6 +43,7 @@ class TestIncreFlashAttention(TestCase): head_dim = 128 hidden_size = 4096 + supported_output = self.supported_op_exec(q, k, v, head_dim, hidden_size) custom_output = self.custom_op_exec(q_FA, k_FA, v_FA, head_dim, hidden_size) self.assertRtolEqual(supported_output, custom_output) diff --git a/test/custom_ops/test_prompt_flash_attention.py b/test/custom_ops/test_prompt_flash_attention.py index 603a73df549..ee820226384 100644 --- a/test/custom_ops/test_prompt_flash_attention.py +++ b/test/custom_ops/test_prompt_flash_attention.py @@ -25,9 +25,9 @@ class TestPromptFlashAttention(TestCase): def custom_op_exec_test_quantscale2(self, query, key, value, head_dim): scale = 1 / 0.0078125 - deq_scale1 = torch.tensor([1], dtype=torch.float32).npu() - quant_scale1 = torch.tensor([1], dtype=torch.float32).npu() - deq_scale2 = torch.tensor([1], dtype=torch.float32).npu() + deq_scale1 = None + quant_scale1 = None + deq_scale2 = None quant_scale2 = torch.tensor([1], dtype=torch.float32).npu() quant_offset2 = torch.tensor([0], dtype=torch.float32).npu() return torch_npu.npu_prompt_flash_attention( @@ -57,6 +57,12 @@ class TestPromptFlashAttention(TestCase): self.assertRtolEqual(fake_result.dtype, query.dtype) self.assertRtolEqual(fake_result.device, query.device) self.assertTrue(isinstance(fake_result, FakeTensor)) + + def custom_op_exec_test_fp16_fake_tensor(self, query, key, value, head_dim): + scale = 1 / 0.0078125 + fake_result = torch.ops.npu.npu_prompt_flash_attention( + query, key, value, num_heads=32, input_layout="BNSD", scale_value=scale, pre_tokens=65535, next_tokens=65535, sparse_mode=0) + self.assertRtolEqual(fake_result.shape, query.shape) @SupportedDevices(['Ascend910B']) def test_npu_prompt_flash_attention(self, device="npu"): @@ -70,19 +76,19 @@ class TestPromptFlashAttention(TestCase): custom_output = self.custom_op_exec(query, key, value, head_dim) self.assertRtolEqual(supported_output, custom_output) - query_test2 = torch.randn(1, 32, 2048, 128, dtype=torch.int8).npu() - key_test2 = torch.randn(1, 32, 2048, 128, dtype=torch.int8).npu() - value_test2 = torch.randn(1, 32, 2048, 128, dtype=torch.int8).npu() + query_test2 = torch.randn(5, 32, 2048, 128, dtype=torch.float16).npu() + key_test2 = torch.randn(5, 32, 2048, 128, dtype=torch.float16).npu() + value_test2 = torch.randn(5, 32, 2048, 128, dtype=torch.float16).npu() supported_output = self.supported_op_exec(query_test2, key_test2, value_test2, head_dim) - custom_output = self.custom_op_exec_test_quantscale2(query_test2, key_test2, value_test2, head_dim) + custom_output = self.custom_op_exec(query_test2, key_test2, value_test2, head_dim) self.assertRtolEqual(supported_output, custom_output) supported_output = self.supported_op_exec(query_test2, key_test2, value_test2, head_dim) - custom_output = self.vcustom_op_exec_test_int8(query_test2, key_test2, value_test2, head_dim) + custom_output = self.custom_op_exec(query_test2, key_test2, value_test2, head_dim) self.assertRtolEqual(supported_output, custom_output) - self.custom_op_exec_test_int8_fake_tensor(query_test2, key_test2, value_test2, head_dim) + self.custom_op_exec_test_fp16_fake_tensor(query_test2, key_test2, value_test2, head_dim) if __name__ == "__main__": run_tests() -- Gitee