From a4765bdc812455d764a1756af439a84ea524a794 Mon Sep 17 00:00:00 2001 From: xiaxia3 Date: Mon, 14 Feb 2022 14:50:48 +0800 Subject: [PATCH 1/4] =?UTF-8?q?1.8.1=E7=89=88=E6=9C=AC=20save=E3=80=81load?= =?UTF-8?q?=E6=8F=92=E4=BB=B6=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- test/test_api/test_serialization.py | 75 +++++++++++++ torch_npu/utils/__init__.py | 0 torch_npu/utils/serialization.py | 159 ++++++++++++++++++++++++++++ torch_npu/utils/utils.py | 118 +++++++++++++++++++++ 4 files changed, 352 insertions(+) create mode 100644 test/test_api/test_serialization.py create mode 100644 torch_npu/utils/__init__.py create mode 100644 torch_npu/utils/serialization.py create mode 100644 torch_npu/utils/utils.py diff --git a/test/test_api/test_serialization.py b/test/test_api/test_serialization.py new file mode 100644 index 0000000000..83e2c497d7 --- /dev/null +++ b/test/test_api/test_serialization.py @@ -0,0 +1,75 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tempfile +import torch +import torch_npu +import torch.nn as nn +import torch.nn.functional as F + +from torch_npu.testing.common_utils import TestCase, run_tests +from torch_npu.testing.common_device_type import instantiate_device_type_tests + + +class NpuMNIST(nn.Module): + + def __init__(self): + super(XlaMNIST, self).__init__() + self.conv1 = nn.Conv2d(1, 10, kernel_size=5) + self.conv2 = nn.Conv2d(10, 20, kernel_size=5) + self.fc1 = nn.Linear(320, 50) + self.fc2 = nn.Linear(50, 10) + + def forward(self, x): + x = F.relu(F.max_pool2d(self.conv1(x), 2)) + x = F.relu(F.max_pool2d(self.conv2(x), 2)) + x = x.view(-1, 320) + x = F.relu(self.fc1(x)) + x = self.fc2(x) + return F.log_softmax(x, dim=1) + +class TestSerialization(TestCase): + def test_save(self): + x = torch.randn(5).npu() + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, 'data.pt') + torch_npu.save(x, path) + x_loaded = torch_npu.load(path) + self.assertTrue(x.cpu().sum(), x_loaded.cpu().sum()) + + def test_save_tuple(self): + x = torch.randn(5).npu() + number = 3 + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, 'data.pt') + torch_npu.save((x, number), path) + x_loaded, number_loaded = torch_npu.load(path) + self.assertTrue(x.cpu().sum(), x_loaded.cpu().sum()) + self.assertTrue(number, number_loaded) + + def test_serialization_api(self): + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, 'data.pt') + model = NpuMNIST().npu() + torch_npu.save(model.state_dict(), path) + state_dict = torch_npu.load(path) + cpu_model = NpuMNIST() + cpu_model.load_state_dict(state_dict) + loaded_model = cpu_model.npu() + self.assertTrue(loaded_model.state_dict()) + +instantiate_device_type_tests(TestSerialization, globals(), except_for="cpu") +if __name__ == "__main__": + run_tests() \ No newline at end of file diff --git a/torch_npu/utils/__init__.py b/torch_npu/utils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/torch_npu/utils/serialization.py b/torch_npu/utils/serialization.py new file mode 100644 index 0000000000..03847eaa93 --- /dev/null +++ b/torch_npu/utils/serialization.py @@ -0,0 +1,159 @@ +# Copyright (c) 2020 Huawei Technologies Co., Ltd +# Copyright (c) 2019, Facebook CORPORATION. +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import division +from __future__ import print_function + +import os +import shutil +import collections +import threading + +import torch +import torch_npu +import torch.nn.functional as F +import torch_npu.utils.utils as nu +from torch_npu.npu.utils import _get_device_index + + +class DeviceContext(object): + + def __init__(self, device): + self.device = device + +def is_npu_tensor(tensor): + return tensor.device.type == 'npu' + + +class ToNpuTensorArena(object): + + def __init__(self, convert_fn, select_fn): + self._convert_fn = convert_fn + self._select_fn = select_fn + self._tensors = [] + + def _add(self, tensor): + self._tensors.append(tensor) + + def _convert(self): + self._index = 0 + if self._tensors: + self._converted_tensors = self._convert_fn(self._tensors) + else: + self._converted_tensors = [] + + def _get_converted_tensor(self): + assert self._index < len(self._converted_tensors) + new_tensor = self._converted_tensors[self._index] + self._index += 1 + return new_tensor + + def _collect_tensors(self, inputs): + + def collect_fn(value): + self._add(value) + + nu.for_each_instance(inputs, lambda x: self._select_fn(x), collect_fn) + + def _replace_tensors(self, inputs): + + def convert_fn(value): + return self._get_converted_tensor() + + return nu.for_each_instance_rewrite(inputs, lambda x: self._select_fn(x), + convert_fn) + + def transform(self, inputs): + self._tensors = [] + self._collect_tensors(inputs) + self._convert() + return self._replace_tensors(inputs) + + +class TensorReference(object): + + def __init__(self, tid): + self.tid = tid + + +def _get_tensors_folder(path): + return path + '.tensors' + + +def _get_tensor_file(path, tid): + return os.path.join(path, 'tensor_{}.pt'.format(tid)) + + +def _rewrite_data(path, data): + + def convert_fn(tensors): + rewritten_tensors = [] + for i, t in enumerate(tensors): + torch.save(t.cpu(), _get_tensor_file(path, i)) + rewritten_tensors.append(TensorReference(i)) + return rewritten_tensors + + def select_fn(v): + return type(v) == torch.Tensor and is_npu_tensor(v) + + if os.path.isdir(path): + shutil.rmtree(path) + os.mkdir(path) + return ToNpuTensorArena(convert_fn, select_fn).transform(data) + + +def save(data, path): + """Saves the input data into a file. + + The saved data is transferred to PyTorch CPU device before being saved, so a + following `torch.load()` will load CPU data. + Care must be taken when working with views. Instead of saving views it's + recommended that you recreate them after the tensors have been loaded and + moved to their destination device(s). + + Args: + data: The input data to be saved. Any nested combination of Python objects + (list, tuples, sets, dicts, ...). + path: The destination file for the data saving operation. all the writes from + the same host will override each other. + """ + + ref_data = _rewrite_data(_get_tensors_folder(path), data) + torch.save(ref_data, path) + + +def load(path): + """Loads data previously saved with the `save()` API. + + Args: + path (str): The path passed to the `save()` API. + Returns: + The loaded data. + """ + ref_data = torch.load(path) + tensor_folder = _get_tensors_folder(path) + + def convert_fn(tensors): + rewritten_tensors = [] + for t in tensors: + rewritten_tensors.append( + torch.load(_get_tensor_file(tensor_folder, t.tid))) + return rewritten_tensors + + def select_fn(v): + return type(v) == TensorReference + + return ToNpuTensorArena(convert_fn, select_fn).transform(ref_data) diff --git a/torch_npu/utils/utils.py b/torch_npu/utils/utils.py new file mode 100644 index 0000000000..100b4bb8bd --- /dev/null +++ b/torch_npu/utils/utils.py @@ -0,0 +1,118 @@ +# Copyright (c) 2020 Huawei Technologies Co., Ltd +# Copyright (c) 2019, Facebook CORPORATION. +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import division +from __future__ import print_function + +import copy +import os + +class DataWrapper(object): + """Utility class to wrap data structures to be sent to device.""" + + def __init__(self): + pass + + def get_tensors(self): + """Returns the list of CPU tensors which must be sent to device.""" + raise NotImplementedError('The method is missing an implementation') + + def from_tensors(self, tensors): + """Build an instance of the wrapped object given the input tensors. + + The number of tensors is the same as the ones returned by the + `get_tensors()` API, and `tensors[i]` is the device copy of + `get_tensors()[i]`. + + Returns: + The unwrapped instance of the object with tensors on device. + """ + raise NotImplementedError('The method is missing an implementation') + + +def _for_each_instance(value, select_fn, fn, seen): + if id(value) in seen: + return + seen.add(id(value)) + if select_fn(value): + fn(value) + elif isinstance(value, dict): + for k, v in value.items(): + _for_each_instance(k, select_fn, fn, seen) + _for_each_instance(v, select_fn, fn, seen) + elif isinstance(value, (list, tuple, set)): + for x in value: + _for_each_instance(x, select_fn, fn, seen) + elif isinstance(value, DataWrapper): + for x in value.get_tensors(): + _for_each_instance(x, select_fn, fn, seen) + elif hasattr(value, '__dict__'): + for k in value.__dict__.keys(): + _for_each_instance(value.__dict__[k], select_fn, fn, seen) + + +def for_each_instance(value, select_fn, fn): + seen = set() + _for_each_instance(value, select_fn, fn, seen) + + +def _for_each_instance_rewrite(value, select_fn, fn, rwmap): + rvalue = rwmap.get(id(value), None) + if rvalue is not None: + return rvalue + result = value + if select_fn(value): + result = fn(value) + rwmap[id(value)] = result + elif isinstance(value, dict): + result = dict() + rwmap[id(value)] = result + for k, v in value.items(): + k = _for_each_instance_rewrite(k, select_fn, fn, rwmap) + result[k] = _for_each_instance_rewrite(v, select_fn, fn, rwmap) + elif isinstance(value, set): + result = set() + rwmap[id(value)] = result + for x in value: + result.add(_for_each_instance_rewrite(x, select_fn, fn, rwmap)) + elif isinstance(value, (list, tuple)): + # We transform tuples to lists here, as we need to set the object mapping + # before calling into the recursion. This code might break if user code + # expects a tuple. + result = list() + rwmap[id(value)] = result + for x in value: + result.append(_for_each_instance_rewrite(x, select_fn, fn, rwmap)) + elif isinstance(value, DataWrapper): + new_tensors = [] + for x in value.get_tensors(): + new_tensors.append(_for_each_instance_rewrite(x, select_fn, fn, rwmap)) + result = value.from_tensors(new_tensors) + rwmap[id(value)] = result + elif hasattr(value, '__dict__'): + result = copy.copy(value) + rwmap[id(value)] = result + for k in result.__dict__.keys(): + v = _for_each_instance_rewrite(result.__dict__[k], select_fn, fn, rwmap) + result.__dict__[k] = v + else: + rwmap[id(value)] = result + return result + + +def for_each_instance_rewrite(value, select_fn, fn): + rwmap = dict() + return _for_each_instance_rewrite(value, select_fn, fn, rwmap) \ No newline at end of file -- Gitee From dc1a1d7ae68d466deafd526ec255ede2e3967922 Mon Sep 17 00:00:00 2001 From: xiaxia3 Date: Mon, 14 Feb 2022 14:58:17 +0800 Subject: [PATCH 2/4] =?UTF-8?q?=E6=9B=B4=E6=96=B0device?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- test/test_api/test_serialization.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/test_api/test_serialization.py b/test/test_api/test_serialization.py index 83e2c497d7..58c40ec1f7 100644 --- a/test/test_api/test_serialization.py +++ b/test/test_api/test_serialization.py @@ -26,7 +26,7 @@ from torch_npu.testing.common_device_type import instantiate_device_type_tests class NpuMNIST(nn.Module): def __init__(self): - super(XlaMNIST, self).__init__() + super(NpuMNIST, self).__init__() self.conv1 = nn.Conv2d(1, 10, kernel_size=5) self.conv2 = nn.Conv2d(10, 20, kernel_size=5) self.fc1 = nn.Linear(320, 50) @@ -41,7 +41,7 @@ class NpuMNIST(nn.Module): return F.log_softmax(x, dim=1) class TestSerialization(TestCase): - def test_save(self): + def test_save(self, device): x = torch.randn(5).npu() with tempfile.TemporaryDirectory() as tmpdir: path = os.path.join(tmpdir, 'data.pt') @@ -49,7 +49,7 @@ class TestSerialization(TestCase): x_loaded = torch_npu.load(path) self.assertTrue(x.cpu().sum(), x_loaded.cpu().sum()) - def test_save_tuple(self): + def test_save_tuple(self, device): x = torch.randn(5).npu() number = 3 with tempfile.TemporaryDirectory() as tmpdir: @@ -59,7 +59,7 @@ class TestSerialization(TestCase): self.assertTrue(x.cpu().sum(), x_loaded.cpu().sum()) self.assertTrue(number, number_loaded) - def test_serialization_api(self): + def test_serialization_api(self, device): with tempfile.TemporaryDirectory() as tmpdir: path = os.path.join(tmpdir, 'data.pt') model = NpuMNIST().npu() -- Gitee From 504de5acf22106bb01cde0e942ecd2d55b7e73f4 Mon Sep 17 00:00:00 2001 From: xiaxia3 Date: Mon, 14 Feb 2022 15:02:36 +0800 Subject: [PATCH 3/4] =?UTF-8?q?=E6=9B=B4=E6=96=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- torch_npu/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch_npu/__init__.py b/torch_npu/__init__.py index cedd54404f..88ccbe7253 100644 --- a/torch_npu/__init__.py +++ b/torch_npu/__init__.py @@ -26,7 +26,8 @@ import torch_npu._C from .version import __version__ as __version__ -__all__ = [] +__all__ = ["save", "load"] +from .utils.serialization import (save, load) for name in dir(torch_npu._C._VariableFunctions): -- Gitee From d4f2a8e384ed221e947de453983b0633f531ea0a Mon Sep 17 00:00:00 2001 From: xiaxia3 Date: Mon, 14 Feb 2022 15:17:57 +0800 Subject: [PATCH 4/4] =?UTF-8?q?=E6=A0=BC=E5=BC=8F=E6=9B=B4=E6=96=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- torch_npu/utils/serialization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_npu/utils/serialization.py b/torch_npu/utils/serialization.py index 03847eaa93..eb008aaa3f 100644 --- a/torch_npu/utils/serialization.py +++ b/torch_npu/utils/serialization.py @@ -35,7 +35,7 @@ class DeviceContext(object): self.device = device def is_npu_tensor(tensor): - return tensor.device.type == 'npu' + return tensor.device.type == 'npu' class ToNpuTensorArena(object): -- Gitee