From 7ffbb57f0b121679277ab6a827dbc6c859f7caf1 Mon Sep 17 00:00:00 2001
From: bai-yangfan
Date: Sat, 20 Feb 2021 16:51:17 +0800
Subject: [PATCH] add advanced_usage_of_checkpoint

---
 .../advanced_usage_of_checkpoint.md           | 82 +++++++++++++++++++
 .../load_checkpoint.py                        | 53 ++++++++++++
 .../save_checkpoint.py                        | 68 +++++++++++++++
 .../src/__init__.py                           |  0
 .../src/config.py                             | 33 ++++++++
 .../src/dataset.py                            | 60 ++++++++++++++
 .../advanced_usage_of_checkpoint/src/lenet.py | 61 ++++++++++++++
 7 files changed, 357 insertions(+)
 create mode 100644 tutorials/tutorial_code/advanced_usage_of_checkpoint/load_checkpoint.py
 create mode 100644 tutorials/tutorial_code/advanced_usage_of_checkpoint/save_checkpoint.py
 create mode 100644 tutorials/tutorial_code/advanced_usage_of_checkpoint/src/__init__.py
 create mode 100644 tutorials/tutorial_code/advanced_usage_of_checkpoint/src/config.py
 create mode 100644 tutorials/tutorial_code/advanced_usage_of_checkpoint/src/dataset.py
 create mode 100644 tutorials/tutorial_code/advanced_usage_of_checkpoint/src/lenet.py

diff --git a/tutorials/training/source_zh_cn/advanced_use/advanced_usage_of_checkpoint.md b/tutorials/training/source_zh_cn/advanced_use/advanced_usage_of_checkpoint.md
index 28b4333444..483b76b2a3 100644
--- a/tutorials/training/source_zh_cn/advanced_use/advanced_usage_of_checkpoint.md
+++ b/tutorials/training/source_zh_cn/advanced_use/advanced_usage_of_checkpoint.md
@@ -170,3 +170,85 @@ def pytorch2mindspore('torch_resnet.pth'):
         params_list.append(param_dict)
     save_checkpoint(params_list, 'ms_resnet.ckpt')
 ```
+
+## Saving Models
+
+`Linux` `Ascend` `GPU` `CPU` `Model Saving` `Intermediate` `Advanced`
+
+- [Saving Models](#saving-models)
+    - [Overview](#overview)
+    - [Advanced Saving](#advanced-saving)
+    - [Advanced Loading](#advanced-loading)
+
+### Overview
+
+MindSpore's advanced CheckPoint features let you save only the network information you need and, when loading, filter out the parameters you are not interested in. Use `save_checkpoint` to save network parameters selectively and `load_checkpoint` to load them selectively.
+
+### Advanced Saving
+
+1. Prepare the model code. In the training and saving code, `train.py` contains the main training function, the `src/` directory contains the LeNet model definition, data processing, and configuration, and the `script/` directory contains the training scripts for different configurations.
+
+2. Prepare the dataset. Download it from the [MNIST](http://yann.lecun.com/exdb/mnist/) page and use `src/dataset.py` to create the training dataset.
+
+3. Train and save the model. You can save all of the network parameters or only some of them. For example, all of the network parameters can be saved with the following code.
+
+    ```python
+    from mindspore import save_checkpoint
+
+    network = LeNet5(cfg.num_classes)
+    save_checkpoint(network, "lenet.ckpt")
+    ```
+
+    You can also save only part of the network information. Taking the same LeNet as an example and referring to the network structure in `src/lenet.py`, the following code saves only the `conv1` weight by passing a parameter list to `save_checkpoint`.
+
+    ```python
+    from mindspore import save_checkpoint
+
+    network = LeNet5(cfg.num_classes)
+    save_checkpoint([{"name": "conv1", "data": network.conv1.weight}], "lenet.ckpt")
+    ```
+
+    - The `integrated_save` argument of `save_checkpoint` merges network weights that are split across devices in parallel mode. To save the complete set of parameters when parallel mode is enabled, use the callback mechanism and pass in a `ModelCheckpoint` callback object, which saves the model parameters.
+
+    - The `async_save` argument of `save_checkpoint` saves parameters asynchronously. When the network is large and saving is slow, it lets the parameters be written to disk while training continues.
+
+4. Save a user-defined sub-network by passing it to `CheckpointConfig` through `saved_network`.
+
+    ```python
+    network = LeNet5(cfg.num_classes)
+    config_ck = CheckpointConfig(save_checkpoint_steps=32, keep_checkpoint_max=10, saved_network=network)
+    ckpoint_cb = ModelCheckpoint(prefix="LeNet5", config=config_ck)
+    model.train(10, dataset, callbacks=ckpoint_cb)
+    ```
+
+### Advanced Loading
+
+1. Load all of the model parameters.
+
+    ```python
+    param_dict = load_checkpoint("lenet.ckpt")
+    ```
+
+2. Filter out part of the network parameters when loading.
+
+    ```python
+    param_dict = load_checkpoint("lenet.ckpt", filter_prefix="conv1")
+    ```
+
+    - When `strict_load` of `load_checkpoint` is enabled, parameter names are matched strictly; otherwise only parameters with the same suffix are matched.
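+
+    The loaded dictionary can then be applied to a network instance with `load_param_into_net`. The snippet below is a minimal sketch, assuming the `LeNet5` definition and `cfg` from `src/` used in the earlier examples; `strict_load=False` is shown only for illustration.
+
+    ```python
+    from mindspore.train.serialization import load_checkpoint, load_param_into_net
+
+    network = LeNet5(cfg.num_classes)
+    # Parameters whose names start with "conv1" are dropped before loading.
+    param_dict = load_checkpoint("lenet.ckpt", filter_prefix="conv1")
+    load_param_into_net(network, param_dict, strict_load=False)
+    ```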
diff --git a/tutorials/tutorial_code/advanced_usage_of_checkpoint/load_checkpoint.py b/tutorials/tutorial_code/advanced_usage_of_checkpoint/load_checkpoint.py
new file mode 100644
index 0000000000..e5728fa086
--- /dev/null
+++ b/tutorials/tutorial_code/advanced_usage_of_checkpoint/load_checkpoint.py
@@ -0,0 +1,53 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+######################## load checkpoint example ########################
+eval lenet according to model file:
+python load_checkpoint.py --data_path /YourDataPath --ckpt_path Your.ckpt
+"""
+
+import os
+import argparse
+import mindspore.nn as nn
+from mindspore import context
+from mindspore.train.serialization import load_checkpoint, load_param_into_net
+from mindspore.train import Model
+from mindspore.nn.metrics import Accuracy
+from src.dataset import create_dataset
+from src.config import mnist_cfg as cfg
+from src.lenet import LeNet5
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description='MindSpore Lenet Example')
+    parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'],
+                        help='device where the code will be implemented (default: Ascend)')
+    parser.add_argument('--data_path', type=str, default="./Data",
+                        help='path where the dataset is saved')
+    parser.add_argument('--ckpt_path', type=str, default="",
+                        help='path of the trained ckpt file to load')
+
+    args = parser.parse_args()
+
+    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
+
+    network = LeNet5(cfg.num_classes)
+    net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
+    net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
+    model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
+
+    print("============== Starting loading ==============")
+    param_dict = load_checkpoint(args.ckpt_path)
+    load_param_into_net(network, param_dict)
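+
+    # Illustrative follow-up, not part of the original script: evaluate the
+    # restored network on the MNIST test split, as promised by the module
+    # docstring. It assumes the test images are stored under "<data_path>/test".
+    ds_eval = create_dataset(os.path.join(args.data_path, "test"), cfg.batch_size)
+    acc = model.eval(ds_eval, dataset_sink_mode=False)
+    print("============== Accuracy:{} ==============".format(acc))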
diff --git a/tutorials/tutorial_code/advanced_usage_of_checkpoint/save_checkpoint.py b/tutorials/tutorial_code/advanced_usage_of_checkpoint/save_checkpoint.py
new file mode 100644
index 0000000000..62c72f5306
--- /dev/null
+++ b/tutorials/tutorial_code/advanced_usage_of_checkpoint/save_checkpoint.py
@@ -0,0 +1,68 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+######################## save checkpoint example ########################
+train lenet and get network model files (.ckpt):
+python save_checkpoint.py --data_path /YourDataPath
+"""
+
+import os
+import argparse
+from src.config import mnist_cfg as cfg
+from src.dataset import create_dataset
+from src.lenet import LeNet5
+import mindspore.nn as nn
+from mindspore import context
+from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
+from mindspore.train import Model
+from mindspore.nn.metrics import Accuracy
+from mindspore.common import set_seed
+
+
+parser = argparse.ArgumentParser(description='MindSpore Lenet Example')
+parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'],
+                    help='device where the code will be implemented (default: Ascend)')
+parser.add_argument('--data_path', type=str, default="./Data",
+                    help='path where the dataset is saved')
+parser.add_argument('--ckpt_path', type=str, default="./ckpt",
+                    help='directory where the trained ckpt files will be saved')
+args = parser.parse_args()
+set_seed(1)
+
+
+if __name__ == "__main__":
+    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
+    ds_train = create_dataset(os.path.join(args.data_path, "train"), cfg.batch_size)
+    if ds_train.get_dataset_size() == 0:
+        raise ValueError("Please check dataset size > 0 and batch_size <= dataset size")
+
+    network = LeNet5(cfg.num_classes)
+    net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
+    net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
+    time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
+    config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
+                                 keep_checkpoint_max=cfg.keep_checkpoint_max)
+    ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", directory=args.ckpt_path, config=config_ck)
+
+    # Use O3 mixed precision only for Ascend graph mode; otherwise build a plain model.
+    if args.device_target != "Ascend" or context.get_context("mode") == context.PYNATIVE_MODE:
+        model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
+    else:
+        model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}, amp_level="O3")
+
+    print("============== Starting Training ==============")
+    model.train(cfg.epoch_size, ds_train, callbacks=[time_cb, ckpoint_cb, LossMonitor()])
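+
+    # Illustrative follow-up, not part of the original script: write the final
+    # trained weights out once more with save_checkpoint, using async_save=True
+    # so the file is written in the background (see the tutorial's note on
+    # async_save). The file name "lenet_final.ckpt" is arbitrary.
+    from mindspore import save_checkpoint
+    save_checkpoint(network, os.path.join(args.ckpt_path, "lenet_final.ckpt"), async_save=True)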
diff --git a/tutorials/tutorial_code/advanced_usage_of_checkpoint/src/__init__.py b/tutorials/tutorial_code/advanced_usage_of_checkpoint/src/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tutorials/tutorial_code/advanced_usage_of_checkpoint/src/config.py b/tutorials/tutorial_code/advanced_usage_of_checkpoint/src/config.py
new file mode 100644
index 0000000000..e191906a07
--- /dev/null
+++ b/tutorials/tutorial_code/advanced_usage_of_checkpoint/src/config.py
@@ -0,0 +1,33 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+network config setting, used by save_checkpoint.py and load_checkpoint.py
+"""
+
+from easydict import EasyDict as edict
+
+mnist_cfg = edict({
+    'num_classes': 10,
+    'lr': 0.01,
+    'momentum': 0.9,
+    'epoch_size': 10,
+    'batch_size': 32,
+    'buffer_size': 1000,
+    'image_height': 32,
+    'image_width': 32,
+    'save_checkpoint_steps': 1875,
+    'keep_checkpoint_max': 10,
+    'air_name': "lenet",
+})
diff --git a/tutorials/tutorial_code/advanced_usage_of_checkpoint/src/dataset.py b/tutorials/tutorial_code/advanced_usage_of_checkpoint/src/dataset.py
new file mode 100644
index 0000000000..df9eecda1f
--- /dev/null
+++ b/tutorials/tutorial_code/advanced_usage_of_checkpoint/src/dataset.py
@@ -0,0 +1,60 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+Produce the dataset
+"""
+
+import mindspore.dataset as ds
+import mindspore.dataset.vision.c_transforms as CV
+import mindspore.dataset.transforms.c_transforms as C
+from mindspore.dataset.vision import Inter
+from mindspore.common import dtype as mstype
+
+
+def create_dataset(data_path, batch_size=32, repeat_size=1,
+                   num_parallel_workers=1):
+    """
+    create dataset for train or test
+    """
+    # define dataset
+    mnist_ds = ds.MnistDataset(data_path)
+
+    resize_height, resize_width = 32, 32
+    rescale = 1.0 / 255.0
+    shift = 0.0
+    rescale_nml = 1 / 0.3081
+    shift_nml = -1 * 0.1307 / 0.3081
+
+    # define map operations
+    resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)  # Bilinear mode
+    rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)
+    rescale_op = CV.Rescale(rescale, shift)
+    hwc2chw_op = CV.HWC2CHW()
+    type_cast_op = C.TypeCast(mstype.int32)
+
+    # apply map operations on images
+    mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=num_parallel_workers)
+    mnist_ds = mnist_ds.map(operations=resize_op, input_columns="image", num_parallel_workers=num_parallel_workers)
+    mnist_ds = mnist_ds.map(operations=rescale_op, input_columns="image", num_parallel_workers=num_parallel_workers)
+    mnist_ds = mnist_ds.map(operations=rescale_nml_op, input_columns="image", num_parallel_workers=num_parallel_workers)
+    mnist_ds = mnist_ds.map(operations=hwc2chw_op, input_columns="image", num_parallel_workers=num_parallel_workers)
+
+    # apply DatasetOps
+    buffer_size = 10000
+    mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)  # 10000 as in LeNet train script
+    mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
+    mnist_ds = mnist_ds.repeat(repeat_size)
+
+    return mnist_ds
diff --git a/tutorials/tutorial_code/advanced_usage_of_checkpoint/src/lenet.py b/tutorials/tutorial_code/advanced_usage_of_checkpoint/src/lenet.py
new file mode 100644
index 0000000000..f34dedbb6c
--- /dev/null
+++ b/tutorials/tutorial_code/advanced_usage_of_checkpoint/src/lenet.py
@@ -0,0 +1,61 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""LeNet."""
+import mindspore.nn as nn
+from mindspore.common.initializer import Normal
+
+
+class LeNet5(nn.Cell):
+    """
+    LeNet-5 network
+
+    Args:
+        num_class (int): Number of classes. Default: 10.
+        num_channel (int): Number of channels. Default: 1.
+        include_top (bool): Whether to keep the flatten and fully connected layers. Default: True.
+
+    Returns:
+        Tensor, output tensor
+
+    Examples:
+        >>> LeNet5(num_class=10)
+
+    """
+    def __init__(self, num_class=10, num_channel=1, include_top=True):
+        super(LeNet5, self).__init__()
+        self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
+        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
+        self.relu = nn.ReLU()
+        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
+        self.include_top = include_top
+        if self.include_top:
+            self.flatten = nn.Flatten()
+            self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
+            self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
+            self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
+
+    def construct(self, x):
+        x = self.conv1(x)
+        x = self.relu(x)
+        x = self.max_pool2d(x)
+        x = self.conv2(x)
+        x = self.relu(x)
+        x = self.max_pool2d(x)
+        if not self.include_top:
+            return x
+        x = self.flatten(x)
+        x = self.relu(self.fc1(x))
+        x = self.relu(self.fc2(x))
+        x = self.fc3(x)
+        return x
-- 
Gitee