diff --git a/test/test_api/test_torch/test_generators.py b/test/test_api/test_torch/test_generators.py
new file mode 100644
index 0000000000000000000000000000000000000000..38fb1675e538814e28bf4be7fbd04ff72e8b8be3
--- /dev/null
+++ b/test/test_api/test_torch/test_generators.py
@@ -0,0 +1,42 @@
+# Copyright (c) 2020, Huawei Technologies. All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch_npu
+
+from torch_npu.testing.testcase import TestCase, run_tests
+
+device = 'npu:0'
+torch.npu.set_device(device)
+
+def get_npu_type(type_name):
+    if isinstance(type_name, type):
+        type_name = '{}.{}'.format(type_name.__module__, type_name.__name__)
+    module, name = type_name.rsplit('.', 1)
+    assert module == 'torch'
+    return getattr(torch.npu, name)
+
+class TestGenerators(TestCase):
+    def test_generator(self):
+        g_npu = torch.Generator(device=device)
+        print(g_npu.device)
+        self.assertExpectedInline(str(g_npu.device), '''npu:0''')
+
+    def test_default_generator(self):
+        output = torch.default_generator
+        print(output)
+
+
+if __name__ == "__main__":
+    run_tests()
\ No newline at end of file
diff --git a/test/test_api/test_torch/test_locally_disabling_gradient_computation.py b/test/test_api/test_torch/test_locally_disabling_gradient_computation.py
new file mode 100644
index 0000000000000000000000000000000000000000..15899c3ca5b9781aa6f2efd48f5d286cee2f345f
--- /dev/null
+++ b/test/test_api/test_torch/test_locally_disabling_gradient_computation.py
@@ -0,0 +1,68 @@
+# Copyright (c) 2020, Huawei Technologies. All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch_npu
+
+from torch_npu.testing.testcase import TestCase, run_tests
+from torch.testing._internal.common_utils import freeze_rng_state
+
+device = 'npu:0'
+torch.npu.set_device(device)
+
+class TestLDGComputation(TestCase):
+    def test_no_grad(self):
+        x = torch.tensor([1], dtype=torch.float32, device=device, requires_grad=True)
+        with torch.no_grad():
+            y = x * 2
+        self.assertFalse(y.requires_grad)
+
+        @torch.no_grad()
+        def doubler(x):
+            return x * 2
+        z = doubler(x)
+        self.assertFalse(z.requires_grad)
+
+    def test_enable_grad(self):
+        x = torch.tensor([1], dtype=torch.float32, device=device, requires_grad=True)
+        with torch.no_grad():
+            with torch.enable_grad():
+                y = x * 2
+        self.assertTrue(y.requires_grad)
+
+        @torch.enable_grad()
+        def doubler(x):
+            return x * 2
+        with torch.no_grad():
+            z = doubler(x)
+        self.assertTrue(z.requires_grad)
+
+    def test_set_grad_enabled(self):
+        x = torch.tensor([1.], device=device, requires_grad=True)
+        with torch.set_grad_enabled(False):
+            y = x * 2
+        self.assertFalse(y.requires_grad)
+        with torch.set_grad_enabled(True):
+            y = x * 2
+        self.assertTrue(y.requires_grad)
+        with torch.set_grad_enabled(False):
+            torch.set_grad_enabled(True)
+            y = x * 2
+        self.assertTrue(y.requires_grad)
+
+
+if __name__ == "__main__":
+    run_tests()
+
+
diff --git a/test/test_api/test_torch/test_parallelism.py b/test/test_api/test_torch/test_parallelism.py
new file mode 100644
index 0000000000000000000000000000000000000000..eea1b5eb130c0674a13761ea1055af14f329eae7
--- /dev/null
+++ b/test/test_api/test_torch/test_parallelism.py
@@ -0,0 +1,42 @@
+# Copyright (c) 2020, Huawei Technologies. All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch_npu
+
+from torch_npu.testing.testcase import TestCase, run_tests
+from torch.testing._internal.common_utils import freeze_rng_state
+
+device = 'npu:0'
+torch.npu.set_device(device)
+
+class TestParallelism(TestCase):
+    def test_set_num_threads(self):
+        torch.set_num_threads(2)
+
+    def test_get_num_threads(self):
+        output = torch.get_num_threads()
+        print(output)
+
+    def test_set_num_interop_threads(self):
+        torch.set_num_interop_threads(2)
+
+    def test_get_num_interop_threads(self):
+        output = torch.get_num_interop_threads()
+        print(output)
+
+if __name__ == "__main__":
+    run_tests()
+
+
diff --git a/test/test_api/test_serialization.py b/test/test_api/test_torch/test_serialization.py
similarity index 100%
rename from test/test_api/test_serialization.py
rename to test/test_api/test_torch/test_serialization.py
diff --git a/test/test_api/test_torch/test_utilities.py b/test/test_api/test_torch/test_utilities.py
new file mode 100644
index 0000000000000000000000000000000000000000..fddd2df87cea6c6df34fc194c6b20cca459d2edf
--- /dev/null
+++ b/test/test_api/test_torch/test_utilities.py
@@ -0,0 +1,52 @@
+# Copyright (c) 2020, Huawei Technologies. All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch_npu
+import numpy as np
+
+from torch_npu.testing.testcase import TestCase, run_tests
+
+device = 'npu:0'
+torch.npu.set_device(device)
+
+class TestUtilities(TestCase):
+    def test_compiled_with_cxx11_abi(self):
+        output = torch.compiled_with_cxx11_abi()
+        self.assertTrue(output)
+
+    def test_result_type(self):
+        self.assertEqual(torch.result_type(torch.tensor(1, dtype=torch.int, device=device), 1), torch.int)
+        self.assertEqual(torch.result_type(1, torch.tensor(1, dtype=torch.int, device=device)), torch.int)
+        self.assertEqual(torch.result_type(1, 1.), torch.get_default_dtype())
+        self.assertEqual(torch.result_type(torch.tensor(1, device=device), 1.), torch.get_default_dtype())
+        self.assertEqual(torch.result_type(torch.tensor(1, dtype=torch.long, device=device),
+                                           torch.tensor([1, 1], dtype=torch.int, device=device)),
+                         torch.int)
+        self.assertEqual(torch.result_type(torch.tensor([1., 1.], dtype=torch.float, device=device), 1.), torch.float)
+        self.assertEqual(torch.result_type(torch.tensor(1., dtype=torch.float, device=device),
+                                           torch.tensor(1, dtype=torch.double, device=device)),
+                         torch.double)
+
+    def test_can_cast(self):
+        self.assertTrue(torch.can_cast(torch.double, torch.float))
+        self.assertFalse(torch.can_cast(torch.float, torch.int))
+
+    def test_promote_types(self):
+        self.assertEqual(torch.promote_types(torch.float, torch.int), torch.float)
+        self.assertEqual(torch.promote_types(torch.float, torch.double), torch.double)
+        self.assertEqual(torch.promote_types(torch.int, torch.uint8), torch.int)
+
+if __name__ == "__main__":
+    run_tests()