diff --git a/tutorials/training/source_zh_cn/advanced_use/advanced_usage_of_checkpoint.md b/tutorials/training/source_zh_cn/advanced_use/advanced_usage_of_checkpoint.md
index 28b4333444dec361237b53be675dfdbbb8f805bb..483b76b2a3343cd9fa344ca79941b7daf9ca8077 100644
--- a/tutorials/training/source_zh_cn/advanced_use/advanced_usage_of_checkpoint.md
+++ b/tutorials/training/source_zh_cn/advanced_use/advanced_usage_of_checkpoint.md
@@ -170,3 +170,85 @@ def pytorch2mindspore('torch_resnet.pth'):
params_list.append(param_dict)
save_checkpoint(params_list, 'ms_resnet.ckpt')
```
+
+## 保存模型
+
+`Linux` `Ascend` `GPU` `CPU` `模型保存` `中级` `高级`
+
+
+
+- [保存模型](#保存模型)
+ - [概述](#概述)
+ - [高阶保存方式](#高阶保存方式)
+ - [高阶载入方式](#高阶载入方式)
+
+
+
+
+
+### 概述
+
+MindSpore高阶CheckPoint保存方式支持根据用户需求个性化储存需要保存的结果,并且可以根据用户设置加载过滤掉不关注的模型信息。使用`save_checkpoint`可以进行个性化网络参数保存,`load_checkpoint`进行个性化网络参数载入。
+
+### 高阶保存方式
+
+1. 准备模型代码。训练保存的代码可参见:,其中,`train.py`为训练的主函数所在,`src/`目录中包含LeNet模型的定义、数据处理和配置信息等,`script/`目录中包含不同配置下的训练脚本。
+
+2. 准备数据集。请参考[mnist](http://yann.lecun.com/exdb/mnist/)链接下载数据集,并利用脚本`src/dataset.py`创建训练datasets。
+
+3. 训练保存模型。可以根据需要保存网络的全部和部分参数。例如:保存全部网络和优化器的参数可以使用以下代码。
+
+ ```python
+
+ from mindspore import save_checkpoint
+
+ network = LeNet5(cfg.num_classes)
+ save_checkpoint(network, "lenet.ckpt")
+
+ ```
+
+ 也可以存储部分网络信息,同样以上述LeNet为例,参考`/src/lenet.py`下的网络结构,可使用以下代码。
+
+ ```python
+
+ from mindspore import save_checkpoint
+
+ network = LeNet5(cfg.num_classes)
+ save_checkpoint(network, "lenet.ckpt", [{"name": "conv1", "data": network.conv1.weight}])
+
+ ```
+
+ - `save_checkpoint`中integrated_save负责组合并行模式下的网络权重,当储存全网参数并且打开并行模式时需要调用Callback机制传入回调函数ModelCheckpoint对象,可以保存模型参数。
+
+ - `save_checkpoint`中async_save负责异步存储参数,当网络比较大存储耗时比较久时,可以异步在训练时同步储存网络参数。
+
+4. 保存用户定义的子网络。
+
+ ```python
+
+ network = LeNet5(cfg.num_classes)
+ config_ck = CheckpointConfig(save_checkpoint_steps=32, keep_checkpoint_max=10, saved_network=network)
+ ckpoint_cb = ModelCheckpoint(prefix="LeNet5", config=config_ck)
+ model.train(10, dataset, callbacks=ckpoint_cb)
+
+ ```
+
+### 高阶载入方式
+
+1. 载入模型全部参数。
+
+ ```python
+
+ param_dict = load_checkpoint("lenet.ckpt")
+
+ ```
+
+2. 载入时过滤部分网络参数。
+
+ ```python
+
+ param_dict = load_checkpoint("lenet.ckpt", filter_prefix="conv1")
+
+ ```
+
+ - `load_checkpoint`中的strict_load开启后会严格匹配参数名,否则只匹配相同尾缀。
diff --git a/tutorials/tutorial_code/advanced_usage_of_checkpoint/load_checkpoint.py b/tutorials/tutorial_code/advanced_usage_of_checkpoint/load_checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5728fa086ccf6c9ad0ae910310d1b2aae7579ed
--- /dev/null
+++ b/tutorials/tutorial_code/advanced_usage_of_checkpoint/load_checkpoint.py
@@ -0,0 +1,53 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+######################## load checkpoint example ########################
+eval lenet according to model file:
+python eval.py --data_path /YourDataPath --ckpt_path Your.ckpt
+"""
+
+import os
+import argparse
+import mindspore.nn as nn
+from mindspore import context
+from mindspore.train.serialization import load_checkpoint, load_param_into_net
+from mindspore.train import Model
+from mindspore.nn.metrics import Accuracy
+from src.dataset import create_dataset
+from src.config import mnist_cfg as cfg
+from src.lenet import LeNet5
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description='MindSpore Lenet Example')
+ parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'],
+ help='device where the code will be implemented (default: Ascend)')
+ parser.add_argument('--data_path', type=str, default="./Data",
+ help='path where the dataset is saved')
+ parser.add_argument('--ckpt_path', type=str, default="", help='if mode is test, must provide\
+ path where the trained ckpt file')
+
+ args = parser.parse_args()
+
+ context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
+
+ network = LeNet5(cfg.num_classes)
+ net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
+ repeat_size = cfg.epoch_size
+ net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
+ model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
+
+ print("============== Starting loading ==============")
+ param_dict = load_checkpoint(args.ckpt_path)
+ load_param_into_net(network, param_dict)
diff --git a/tutorials/tutorial_code/advanced_usage_of_checkpoint/save_checkpoint.py b/tutorials/tutorial_code/advanced_usage_of_checkpoint/save_checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..62c72f5306b7ba9c440c4880800159a5a1185a81
--- /dev/null
+++ b/tutorials/tutorial_code/advanced_usage_of_checkpoint/save_checkpoint.py
@@ -0,0 +1,68 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+######################## save checkpoint example ########################
+train lenet and get network model files(.ckpt) :
+python train.py --data_path /YourDataPath
+"""
+
+import os
+import argparse
+from src.config import mnist_cfg as cfg
+from src.dataset import create_dataset
+from src.lenet import LeNet5
+import mindspore.nn as nn
+from mindspore import context
+from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
+from mindspore.train import Model
+from mindspore.nn.metrics import Accuracy
+from mindspore.common import set_seed
+
+
+parser = argparse.ArgumentParser(description='MindSpore Lenet Example')
+parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'],
+ help='device where the code will be implemented (default: Ascend)')
+parser.add_argument('--data_path', type=str, default="./Data",
+ help='path where the dataset is saved')
+parser.add_argument('--ckpt_path', type=str, default="./ckpt", help='if is test, must provide\
+ path where the trained ckpt file')
+args = parser.parse_args()
+set_seed(1)
+
+
+if __name__ == "__main__":
+ context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
+ ds_train = create_dataset(os.path.join(args.data_path, "train"), cfg.batch_size)
+ if ds_train.get_dataset_size() == 0:
+ raise ValueError("Please check dataset size > 0 and batch_size <= dataset size")
+
+ network = LeNet5(cfg.num_classes)
+ net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
+ net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
+ time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
+ config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
+ keep_checkpoint_max=cfg.keep_checkpoint_max)
+ ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", directory=args.ckpt_path, config=config_ck)
+
+ if args.device_target != "Ascend":
+ model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
+ else:
+ if context.get_context("mode") == context.PYNATIVE_MODE:
+ model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
+ else:
+ model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}, amp_level="O3")
+
+ print("============== Starting Training ==============")
+ model.train(cfg['epoch_size'], ds_train, callbacks=[time_cb, ckpoint_cb, LossMonitor()])
diff --git a/tutorials/tutorial_code/advanced_usage_of_checkpoint/src/__init__.py b/tutorials/tutorial_code/advanced_usage_of_checkpoint/src/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/tutorials/tutorial_code/advanced_usage_of_checkpoint/src/config.py b/tutorials/tutorial_code/advanced_usage_of_checkpoint/src/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..e191906a073088422959f0b40eb2d11e632550d6
--- /dev/null
+++ b/tutorials/tutorial_code/advanced_usage_of_checkpoint/src/config.py
@@ -0,0 +1,33 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+network config setting, will be used in train.py
+"""
+
+from easydict import EasyDict as edict
+
+mnist_cfg = edict({
+ 'num_classes': 10,
+ 'lr': 0.01,
+ 'momentum': 0.9,
+ 'epoch_size': 10,
+ 'batch_size': 32,
+ 'buffer_size': 1000,
+ 'image_height': 32,
+ 'image_width': 32,
+ 'save_checkpoint_steps': 1875,
+ 'keep_checkpoint_max': 10,
+ 'air_name': "lenet",
+})
diff --git a/tutorials/tutorial_code/advanced_usage_of_checkpoint/src/dataset.py b/tutorials/tutorial_code/advanced_usage_of_checkpoint/src/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..df9eecda1fe327a405057be7993e401c5ae08187
--- /dev/null
+++ b/tutorials/tutorial_code/advanced_usage_of_checkpoint/src/dataset.py
@@ -0,0 +1,60 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+Produce the dataset
+"""
+
+import mindspore.dataset as ds
+import mindspore.dataset.vision.c_transforms as CV
+import mindspore.dataset.transforms.c_transforms as C
+from mindspore.dataset.vision import Inter
+from mindspore.common import dtype as mstype
+
+
+def create_dataset(data_path, batch_size=32, repeat_size=1,
+ num_parallel_workers=1):
+ """
+ create dataset for train or test
+ """
+ # define dataset
+ mnist_ds = ds.MnistDataset(data_path)
+
+ resize_height, resize_width = 32, 32
+ rescale = 1.0 / 255.0
+ shift = 0.0
+ rescale_nml = 1 / 0.3081
+ shift_nml = -1 * 0.1307 / 0.3081
+
+ # define map operations
+ resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR) # Bilinear mode
+ rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)
+ rescale_op = CV.Rescale(rescale, shift)
+ hwc2chw_op = CV.HWC2CHW()
+ type_cast_op = C.TypeCast(mstype.int32)
+
+ # apply map operations on images
+ mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=num_parallel_workers)
+ mnist_ds = mnist_ds.map(operations=resize_op, input_columns="image", num_parallel_workers=num_parallel_workers)
+ mnist_ds = mnist_ds.map(operations=rescale_op, input_columns="image", num_parallel_workers=num_parallel_workers)
+ mnist_ds = mnist_ds.map(operations=rescale_nml_op, input_columns="image", num_parallel_workers=num_parallel_workers)
+ mnist_ds = mnist_ds.map(operations=hwc2chw_op, input_columns="image", num_parallel_workers=num_parallel_workers)
+
+ # apply DatasetOps
+ buffer_size = 10000
+ mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size) # 10000 as in LeNet train script
+ mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
+ mnist_ds = mnist_ds.repeat(repeat_size)
+
+ return mnist_ds
diff --git a/tutorials/tutorial_code/advanced_usage_of_checkpoint/src/lenet.py b/tutorials/tutorial_code/advanced_usage_of_checkpoint/src/lenet.py
new file mode 100644
index 0000000000000000000000000000000000000000..f34dedbb6c40808aa97d891ac7a28391058b9f82
--- /dev/null
+++ b/tutorials/tutorial_code/advanced_usage_of_checkpoint/src/lenet.py
@@ -0,0 +1,61 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""LeNet."""
+import mindspore.nn as nn
+from mindspore.common.initializer import Normal
+
+
+class LeNet5(nn.Cell):
+ """
+ Lenet network
+
+ Args:
+ num_class (int): Number of classes. Default: 10.
+ num_channel (int): Number of channels. Default: 1.
+
+ Returns:
+ Tensor, output tensor
+ Examples:
+ >>> LeNet(num_class=10)
+
+ """
+ def __init__(self, num_class=10, num_channel=1, include_top=True):
+ super(LeNet5, self).__init__()
+ self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
+ self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
+ self.relu = nn.ReLU()
+ self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
+ self.include_top = include_top
+ if self.include_top:
+ self.flatten = nn.Flatten()
+ self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
+ self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
+ self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
+
+
+ def construct(self, x):
+ x = self.conv1(x)
+ x = self.relu(x)
+ x = self.max_pool2d(x)
+ x = self.conv2(x)
+ x = self.relu(x)
+ x = self.max_pool2d(x)
+ if not self.include_top:
+ return x
+ x = self.flatten(x)
+ x = self.relu(self.fc1(x))
+ x = self.relu(self.fc2(x))
+ x = self.fc3(x)
+ return x