diff --git a/test/custom_ops/test_incre_flash_attention.py b/test/custom_ops/test_incre_flash_attention.py index e349fe49db0bec784dc35a93d01c030452a73f33..49372f409fc8b7ec2addc2a0c5d5c4d278880d36 100644 --- a/test/custom_ops/test_incre_flash_attention.py +++ b/test/custom_ops/test_incre_flash_attention.py @@ -45,7 +45,7 @@ class TestIncreFlashAttention(TestCase): supported_output = self.supported_op_exec(q, k, v, head_dim, hidden_size) custom_output = self.custom_op_exec(q_FA, k_FA, v_FA, head_dim, hidden_size) - self.assertRtolEqual(supported_output, custom_output) + self.assertRtolEqual(supported_output, custom_output, prec16=0.05) if __name__ == "__main__": diff --git a/test/custom_ops/test_prompt_flash_attention.py b/test/custom_ops/test_prompt_flash_attention.py index c92c64825330926a03567cd6686d06dfeea7d99d..fc9a0cd0ec501f5721b62396ca7f55ca9f48af3e 100644 --- a/test/custom_ops/test_prompt_flash_attention.py +++ b/test/custom_ops/test_prompt_flash_attention.py @@ -25,9 +25,9 @@ class TestPromptFlashAttention(TestCase): def custom_op_exec_test_quantscale2(self, query, key, value, head_dim): scale = 1 / 0.0078125 - deq_scale1 = torch.tensor([1], dtype=torch.float32).npu() - quant_scale1 = torch.tensor([1], dtype=torch.float32).npu() - deq_scale2 = torch.tensor([1], dtype=torch.float32).npu() + deq_scale1 = None + quant_scale1 = None + deq_scale2 = None quant_scale2 = torch.tensor([1], dtype=torch.float32).npu() quant_offset2 = torch.tensor([0], dtype=torch.float32).npu() return torch_npu.npu_prompt_flash_attention( @@ -47,6 +47,12 @@ class TestPromptFlashAttention(TestCase): self.assertRtolEqual(fake_result.dtype, query.dtype) self.assertRtolEqual(fake_result.device, query.device) self.assertTrue(isinstance(fake_result, FakeTensor)) + + def custom_op_exec_test_fp16_fake_tensor(self, query, key, value, head_dim): + scale = 1 / 0.0078125 + fake_result = torch.ops.npu.npu_prompt_flash_attention( + query, key, value, num_heads=32, input_layout="BNSD", scale_value=scale, pre_tokens=65535, next_tokens=65535, sparse_mode=0) + self.assertRtolEqual(fake_result.shape, query.shape) def custom_op_exec_test_int8(self, query, key, value, head_dim): scale = 1 / 0.0078125 @@ -70,19 +76,19 @@ class TestPromptFlashAttention(TestCase): custom_output = self.custom_op_exec(query, key, value, head_dim) self.assertRtolEqual(supported_output, custom_output) - query_test2 = torch.randn(1, 32, 2048, 128, dtype=torch.int8).npu() - key_test2 = torch.randn(1, 32, 2048, 128, dtype=torch.int8).npu() - value_test2 = torch.randn(1, 32, 2048, 128, dtype=torch.int8).npu() + query_test2 = torch.randn(5, 32, 2048, 128, dtype=torch.float16).npu() + key_test2 = torch.randn(5, 32, 2048, 128, dtype=torch.float16).npu() + value_test2 = torch.randn(5, 32, 2048, 128, dtype=torch.float16).npu() supported_output = self.supported_op_exec(query_test2, key_test2, value_test2, head_dim) - custom_output = self.custom_op_exec_test_quantscale2(query_test2, key_test2, value_test2, head_dim) + custom_output = self.custom_op_exec(query_test2, key_test2, value_test2, head_dim) self.assertRtolEqual(supported_output, custom_output) supported_output = self.supported_op_exec(query_test2, key_test2, value_test2, head_dim) - custom_output = self.vcustom_op_exec_test_int8(query_test2, key_test2, value_test2, head_dim) + custom_output = self.custom_op_exec(query_test2, key_test2, value_test2, head_dim) self.assertRtolEqual(supported_output, custom_output) - self.custom_op_exec_test_int8_fake_tensor(query_test2, key_test2, value_test2, head_dim) + self.custom_op_exec_test_fp16_fake_tensor(query_test2, key_test2, value_test2, head_dim) if __name__ == "__main__":