diff --git a/test/custom_ops/test_npu_scaled_dot_product_attention.py b/test/custom_ops/test_npu_scaled_dot_product_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..25db052e8c01ce8dfeb13cb6e6d385098af65083
--- /dev/null
+++ b/test/custom_ops/test_npu_scaled_dot_product_attention.py
@@ -0,0 +1,62 @@
+import math
+import unittest
+import numpy as np
+import torch
+import torch_npu
+from torch_npu.testing.testcase import TestCase, run_tests
+from torch_npu.testing.common_utils import SupportedDevices
+
+
+class TestNPUScaledDotProductAttention(TestCase):
+    def get_atten_masks(self, query):
+        shape = [query.shape[-2], query.shape[-1]]
+        atten_masks = torch.from_numpy(np.tril(np.ones(shape), k=0))
+        return atten_masks
+
+    def scaled_dot_product_exec(self, query, key, value, dx, atten_mask, case):
+        dropout_p = 0
+        query.requires_grad = True
+        key.requires_grad = True
+        value.requires_grad = True
+        if case[2]:
+            res = torch.nn.functional.scaled_dot_product_attention(
+                query, key, value, None, dropout_p, True)
+        else:
+            res = torch.nn.functional.scaled_dot_product_attention(
+                query, key, value, atten_mask, dropout_p, False)
+        res.backward(dx)
+        dq = query.grad
+        dk = key.grad
+        dv = value.grad
+        return res, dq, dk, dv
+
+    def check_result(self, query, key, value, dx, case):
+        atten_masks = self.get_atten_masks(query)
+        out_cpu, dq_cpu, dk_cpu, dv_cpu = self.scaled_dot_product_exec(query.to(torch.float), key.to(torch.float),
+                                                                       value.to(torch.float), dx.to(torch.float), torch.tensor(atten_masks).bool(),
+                                                                       case)
+        out_npu, dq_npu, dk_npu, dv_npu = self.scaled_dot_product_exec(query.npu(), key.npu(), value.npu(),
+                                                                       dx.npu(), torch.tensor(atten_masks).to(torch.float16).bool().npu(), case)
+        self.assertRtolEqual(out_cpu.to(torch.float16), out_npu, prec=0.01, prec16=0.01)
+        self.assertRtolEqual(dq_cpu.to(torch.float16), dq_npu, prec=0.01, prec16=0.01)
+        self.assertRtolEqual(dk_cpu.to(torch.float16), dk_npu, prec=0.01, prec16=0.01)
+        self.assertRtolEqual(dv_cpu.to(torch.float16), dv_npu, prec=0.01, prec16=0.01)
+
+    @SupportedDevices(['Ascend910B'])
+    def test_npu_scaled_dot_product_attention(self, device="npu"):
+        # case: [qshape, vshape, is_causal]
+        case_list = [
+            [[1, 8, 256, 256], [1, 8, 256, 256], True],
+            [[1, 8, 256, 256], [1, 8, 256, 256], False],
+            [[1, 8, 128, 256], [1, 8, 256, 256], False]
+        ]
+
+        for case in case_list:
+            query = torch.randn(*case[0], dtype=torch.float16)
+            key = torch.randn(*case[1], dtype=torch.float16)
+            value = torch.randn(*case[1], dtype=torch.float16)
+            dx = torch.randn(*case[0], dtype=torch.float16)
+            self.check_result(query, key, value, dx, case)
+
+if __name__ == "__main__":
+    run_tests()
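
For context, below is a minimal standalone sketch of the CPU reference path this test exercises: an fp32 forward and backward through `torch.nn.functional.scaled_dot_product_attention`, using either `is_causal=True` or an explicit lower-triangular boolean mask. The helper name `reference_sdpa` and the shapes are illustrative assumptions, not part of the patch; the patch's `get_atten_masks` builds its mask from the query's last two dimensions, whereas the sketch uses the query and key sequence lengths.

```python
import torch


def reference_sdpa(query, key, value, dx, is_causal):
    # Promote fp16 inputs to fp32 and track gradients, mirroring the CPU path in the test.
    query = query.float().requires_grad_(True)
    key = key.float().requires_grad_(True)
    value = value.float().requires_grad_(True)
    if is_causal:
        out = torch.nn.functional.scaled_dot_product_attention(
            query, key, value, attn_mask=None, dropout_p=0.0, is_causal=True)
    else:
        # Boolean mask: True marks positions that may be attended to (lower triangle).
        mask = torch.ones(query.shape[-2], key.shape[-2]).tril().bool()
        out = torch.nn.functional.scaled_dot_product_attention(
            query, key, value, attn_mask=mask, dropout_p=0.0, is_causal=False)
    # Backpropagate the given upstream gradient, as the test does with dx.
    out.backward(dx.float())
    return out.detach(), query.grad, key.grad, value.grad


q = torch.randn(1, 8, 256, 256, dtype=torch.float16)
k = torch.randn(1, 8, 256, 256, dtype=torch.float16)
v = torch.randn(1, 8, 256, 256, dtype=torch.float16)
dx = torch.randn(1, 8, 256, 256, dtype=torch.float16)
out, dq, dk, dv = reference_sdpa(q, k, v, dx, is_causal=False)
print(out.shape, dq.shape, dk.shape, dv.shape)
```

The test compares these fp32 CPU results, cast back to fp16, against the NPU outputs and gradients with `assertRtolEqual` at a 0.01 tolerance.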