diff --git a/test/test_fake_tensor.py b/test/test_fake_tensor.py index a67833ff803038017202a5081164acf146d5dad9..6c9dbf9ab921d715096e8515f6c4a662d04ed8a6 100644 --- a/test/test_fake_tensor.py +++ b/test/test_fake_tensor.py @@ -1493,6 +1493,15 @@ class TestAntiQuant(TestCase): x = x.to(dstType) self.assertTrue(x.numel() * x.element_size() == res.numel() * res.element_size()) +class TestRFFT(TestCase): + def test_npu_rfft(self): + with FakeTensorMode(): + shape = [64,64,1024] + length = shape[-1] + x = torch.randn(shape, dtype=torch.float32).npu() + res = torch.fft.rfft(x, length, norm = "backward") + self.assertTrue(res.shape[2] == (length / 2 + 1)) + instantiate_parametrized_tests(FakeTensorTest) instantiate_device_type_tests(FakeTensorOpInfoTest, globals(), only_for="cpu")