From b0647a0d3a665fdd535f63bff19fa8592220ec4c Mon Sep 17 00:00:00 2001 From: xiaxia3 Date: Wed, 8 Jun 2022 16:25:44 +0800 Subject: [PATCH] add python api UT --- test/test_api/test_python_api.py | 72 ++++++++++++++++++++++++++++++++ 1 file changed, 72 insertions(+) create mode 100644 test/test_api/test_python_api.py diff --git a/test/test_api/test_python_api.py b/test/test_api/test_python_api.py new file mode 100644 index 00000000000..83637d9b8a2 --- /dev/null +++ b/test/test_api/test_python_api.py @@ -0,0 +1,72 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.data import DataLoader, TensorDataset +import torch_npu + +from torch_npu.testing.testcase import TestCase, run_tests + + +class TestPythonApi(TestCase): + def test_is_storage(self): + a = torch.rand(2,3) + b = torch.FloatStorage([1,2,3,4,5,6]) + + self.assertFalse(torch.is_storage(a)) + self.assertTrue(torch.is_storage(b)) + + def test_storage_casts(self): + storage = torch.IntStorage([-1, 0, 1, 2, 3, 4]) + + storage_type = [ + [storage, 'torch.IntStorage', torch.int32 ], + [storage.float(), 'torch.FloatStorage', torch.float32], + [storage.half(), 'torch.HalfStorage', torch.float16], + [storage.long(), 'torch.LongStorage', torch.int64 ], + [storage.short(), 'torch.ShortStorage', torch.int16 ], + [storage.char(), 'torch.CharStorage', torch.int8 ] + ] + for item in storage_type: + self.assertEqual(item[0].size(), 6) + self.assertEqual(item[0].tolist(), [-1, 0, 1, 2, 3, 4]) + self.assertEqual(item[0].type(), item[1]) + self.assertEqual(item[0].int().tolist(), [-1, 0, 1, 2, 3, 4]) + self.assertIs(item[0].dtype, item[2]) + + def test_DataLoader(self): + class SimpleCustomBatch: + def __init__(self, data): + transposed_data = list(zip(*data)) + self.inp = torch.stack(transposed_data[0], 0) + self.tgt = torch.stack(transposed_data[1], 0) + + # custom memory pinning method on custom type + def pin_memory(self): + self.inp = self.inp.pin_memory() + self.tgt = self.tgt.pin_memory() + return self + def collate_wrapper(batch): + return SimpleCustomBatch(batch) + + inps = torch.arange(10 * 5, dtype=torch.float32).view(10, 5).npu() + tgts = torch.arange(10 * 5, dtype=torch.float32).view(10, 5).npu() + dataset = TensorDataset(inps, tgts) + + loader = DataLoader(dataset, batch_size=2, collate_fn=collate_wrapper, + pin_memory=True) + +if __name__ == "__main__": + run_tests() -- Gitee