diff --git a/docs/mindspore/source_en/migration_guide/enveriment_preparation.md b/docs/mindspore/source_en/migration_guide/enveriment_preparation.md
index 25483ac57f19bee00fd1edde3b2a4445b41e315d..634488fc37dcf1be78dc46b6e873e38e9c0fa2ab 100644
--- a/docs/mindspore/source_en/migration_guide/enveriment_preparation.md
+++ b/docs/mindspore/source_en/migration_guide/enveriment_preparation.md
@@ -56,8 +56,6 @@ ModelArts is a one-stop development platform for AI developers provided by HUAWE
The storage methods supported by the development environment are OBS and EFS, as shown in the following figure.
-
-
OBS: also called S3 bucket. The code, data, and pre-trained models stored on the OBS need to be transferred to the corresponding physical machine first in the development environment and training environment before the job is executed. [Upload local files to OBS](https://bbs.huaweicloud.com/blogs/212453).
[MoXing](https://bbs.huaweicloud.com/blogs/101129): MoXing is a network model development API provided by Huawei Cloud Deep Learning Service, the use of which needs to focus on the data copy interface.
diff --git a/docs/mindspore/source_en/migration_guide/model_development/dataset.md b/docs/mindspore/source_en/migration_guide/model_development/dataset.md
new file mode 100644
index 0000000000000000000000000000000000000000..b1e67584ea081f80395779130f6f047121801e6a
--- /dev/null
+++ b/docs/mindspore/source_en/migration_guide/model_development/dataset.md
@@ -0,0 +1,238 @@
+# Constructing Dataset
+
+
+
+This chapter focuses on considerations related to data processing in network migration. For basic data processing, please refer to:
+
+[Data Processing](https://www.mindspore.cn/tutorials/en/master/beginner/dataset.html)
+
+[Auto Augmentation](https://www.mindspore.cn/tutorials/experts/en/master/dataset/augment.html)
+
+[Lightweight Data Processing](https://www.mindspore.cn/tutorials/experts/en/master/dataset/eager.html)
+
+[Optimizing the Data Processing](https://www.mindspore.cn/tutorials/experts/en/master/dataset/optimize.html)
+
+## Basic Process of Data Construction
+
+The whole basic process of data construction consists of two main aspects: dataset loading and data augmentation.
+
+### Loading Dataset
+
+MindSpore provides interfaces for loading many common datasets. The most used ones are as follows:
+
+| Data Interfaces | Introduction |
+| -------| ---- |
+| [Cifar10Dataset](https://www.mindspore.cn/docs/en/master/api_python/dataset/mindspore.dataset.Cifar10Dataset.html#mindspore.dataset.Cifar10Dataset) | Cifar10 dataset read interface (you need to download the original bin file of Cifar10) |
+| [MnistDataset](https://www.mindspore.cn/docs/en/master/api_python/dataset/mindspore.dataset.MnistDataset.html#mindspore.dataset.MnistDataset) | MNIST handwritten digit recognition dataset (you need to download the original file) |
+| [ImageFolderDataset](https://www.mindspore.cn/docs/en/master/api_python/dataset/mindspore.dataset.ImageFolderDataset.html) | Dataset reading method of using file directories as the data organization format for classification (common for ImageNet) |
+| [MindDataset](https://www.mindspore.cn/docs/en/master/api_python/dataset/mindspore.dataset.MindDataset.html#mindspore.dataset.MindDataset) | MindRecord data read interface |
+| [GeneratorDataset](https://www.mindspore.cn/docs/en/master/api_python/dataset/mindspore.dataset.GeneratorDataset.html) | Customized data interface |
+| [FakeImageDataset](https://www.mindspore.cn/docs/en/master/api_python/dataset/mindspore.dataset.FakeImageDataset.html) | Constructing a fake image dataset |
+
+For common dataset interfaces in different fields, refer to [Loading interfaces to common datasets](https://www.mindspore.cn/docs/en/master/api_python/mindspore.dataset.html).
+
+#### GeneratorDataset, A Custom Dataset
+
+The basic process for constructing a custom Dataset object is as follows:
+
+First create an iterator class and define three methods `__init__`, `__getitem__` and `__len__` in that class.
+
+```python
+import numpy as np
+class MyDataset():
+ """Self Defined dataset."""
+ def __init__(self, n):
+ self.data = []
+ self.label = []
+ for _ in range(n):
+ self.data.append(np.zeros((3, 4, 5)))
+ self.label.append(np.ones((1)))
+ def __len__(self):
+ return len(self.data)
+ def __getitem__(self, idx):
+ data = self.data[idx]
+ label = self.label[idx]
+ return data, label
+```
+
+> It is better not to use MindSpore operators in the iterator class, because the data processing phase usually runs with multiple threads, which may cause problems.
+>
+> The output of the iterator needs to be a numpy array.
+>
+> The iterator must implement the `__len__` method, and the returned result must be the real dataset size. Setting it larger will cause problems when `__getitem__` fetches values.
+
+Then use GeneratorDataset to encapsulate the iterator class:
+
+```python
+import mindspore.dataset as ds
+
+my_dataset = MyDataset(10)
+# corresponding to torch.utils.data.DataLoader(my_dataset)
+dataset = ds.GeneratorDataset(my_dataset, column_names=["data", "label"])
+```
+
+When customizing the dataset, you need to set a name for each output, such as `column_names=["data", "label"]` above, indicating that the first output column of the iterator is `data` and the second is `label`. In the subsequent data augmentation and data iteration phases, the different columns can be processed separately by name, as sketched below.
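+
+For instance, the following minimal sketch applies a transform only to the `data` column of the dataset built above by selecting it by name (the `TypeCast` operation here is just an illustration):
+
+```python
+import mindspore as ms
+import mindspore.dataset.transforms as transforms
+
+# Cast only the "data" column to float32; the "label" column is left untouched
+dataset = dataset.map(operations=transforms.TypeCast(ms.float32), input_columns=["data"])
+for item in dataset.create_dict_iterator(num_epochs=1):
+    print(item["data"].dtype, item["label"].dtype)
+    break
+```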
+
+**All MindSpore dataset loading interfaces** share some general configuration attributes. Some of the common ones are described here:
+
+| Attributes | Introduction |
+| ---- | ---- |
+| num_samples(int) | Specify the total number of data samples |
+| shuffle(bool) | Whether to shuffle the data randomly |
+| sampler(Sampler) | Data sampler that customizes how the data is shuffled and allocated. The `sampler` setting is mutually exclusive with `num_shards` and `shard_id` |
+| num_shards(int) | Used in distributed scenarios to divide data into several parts, used in conjunction with `shard_id` |
+| shard_id(int) | Used in distributed scenarios to take the nth shard of the data (n ranges from 0 to `num_shards` - 1), used in conjunction with `num_shards` |
+| num_parallel_workers(int) | Number of parallel worker threads |
+
+For example:
+
+```python
+import mindspore.dataset as ds
+
+dataset = ds.FakeImageDataset(num_images=1000, image_size=(32, 32, 3), num_classes=10, base_seed=0)
+print(dataset.get_dataset_size())
+# 1000
+
+dataset = ds.FakeImageDataset(num_images=1000, image_size=(32, 32, 3), num_classes=10, base_seed=0, num_samples=3)
+print(dataset.get_dataset_size())
+# 3
+
+dataset = ds.FakeImageDataset(num_images=1000, image_size=(32, 32, 3), num_classes=10, base_seed=0,
+ num_shards=8, shard_id=0)
+print(dataset.get_dataset_size())
+# 1000 / 8 = 125
+```
+
+```text
+1000
+3
+125
+```
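+
+A `sampler` can also be passed instead of `num_shards`/`shard_id`. Below is a minimal sketch using `RandomSampler` (the sample count is only illustrative):
+
+```python
+import mindspore.dataset as ds
+
+sampler = ds.RandomSampler(num_samples=5)
+dataset = ds.FakeImageDataset(num_images=1000, image_size=(32, 32, 3), num_classes=10, base_seed=0,
+                              sampler=sampler)
+print(dataset.get_dataset_size())
+# 5
+```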
+
+### Data Processing and Augmentation
+
+MindSpore dataset objects use the map interface for data augmentation. See the [map interface](https://www.mindspore.cn/docs/en/master/api_python/dataset/dataset_method/operation/mindspore.dataset.Dataset.map.html#mindspore.dataset.Dataset.map).
+
+```text
+map(operations, input_columns=None, output_columns=None, column_order=None, num_parallel_workers=None, python_multiprocessing=False, cache=None, callbacks=None, max_rowsize=16, offload=None)
+```
+
+Given a list of data augmentation operations, the operations are applied to the dataset object in order.
+
+Each data augmentation operation takes one or more data columns in the dataset object as input and outputs the result of the data augmentation as one or more data columns. The first data augmentation operation takes the specified columns in `input_columns` as input. If there are multiple data augmentation operations in the data augmentation list, the output columns of the previous data augmentation will be used as input columns for the next data augmentation.
+
+The column name of output column in the last data augmentation is specified by `output_columns`. If `output_columns` is not specified, the output column name is the same as `input_columns`.
+
+The above introduction may be tedious, but in short, `map` performs the operations specified in `operations` on some columns of the dataset. Here `operations` can be the data augmentation operations provided by MindSpore:
+
+[audio](https://www.mindspore.cn/docs/en/master/api_python/mindspore.dataset.audio.html), [text](https://www.mindspore.cn/docs/en/master/api_python/mindspore.dataset.text.html), [vision](https://www.mindspore.cn/docs/en/master/api_python/mindspore.dataset.vision.html), and [transforms](https://www.mindspore.cn/docs/en/master/api_python/mindspore.dataset.transforms.html); for more details, refer to [Data Transforms](https://www.mindspore.cn/tutorials/en/master/beginner/transforms.html). `operations` can also be a Python function, in which you can use opencv, PIL, pandas, and other third-party methods. As with dataset loading, **don't use MindSpore operators**.
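+
+A minimal sketch of both usages on a `FakeImageDataset` (the specific transforms and the lambda are only illustrative):
+
+```python
+import numpy as np
+import mindspore.dataset as ds
+import mindspore.dataset.vision as vision
+
+dataset = ds.FakeImageDataset(num_images=8, image_size=(32, 32, 3), num_classes=10, base_seed=0)
+# MindSpore-provided augmentations applied to the "image" column
+dataset = dataset.map(operations=[vision.Resize((24, 24)), vision.HWC2CHW()], input_columns=["image"])
+# A plain Python function works as well, e.g. scaling pixel values with numpy
+dataset = dataset.map(operations=lambda x: (x / 255.0).astype(np.float32), input_columns=["image"])
+for item in dataset.create_dict_iterator(num_epochs=1):
+    print(item["image"].shape)
+    break
+```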
+
+MindSpore also provides some common random augmentation methods: [Auto Augmentation](https://www.mindspore.cn/tutorials/experts/en/master/dataset/augment.html). When using data augmentation, it is best to read [Optimizing the Data Processing](https://www.mindspore.cn/tutorials/experts/en/master/dataset/optimize.html) first and apply the augmentations in the recommended order.
+
+At the end of data augmentation, you can use the batch operation to merge `batch_size` consecutive pieces of data in the dataset into a single batch. For details, please refer to [batch](https://www.mindspore.cn/docs/en/master/api_python/dataset/dataset_method/batch/mindspore.dataset.Dataset.batch.html#mindspore.dataset.dataset.batch). Note that the parameter `drop_remainder` needs to be set to True during training and False during inference.
+
+```python
+import mindspore.dataset as ds
+
+dataset = ds.FakeImageDataset(num_images=1000, image_size=(32, 32, 3), num_classes=10, base_seed=0)\
+ .batch(32, drop_remainder=True)
+print(dataset.get_dataset_size())
+# 1000 // 32 = 31
+
+dataset = ds.FakeImageDataset(num_images=1000, image_size=(32, 32, 3), num_classes=10, base_seed=0)\
+ .batch(32, drop_remainder=False)
+print(dataset.get_dataset_size())
+# ceil(1000 / 32) = 32
+```
+
+```text
+31
+32
+```
+
+The batch operation can also apply some augmentation operations within a batch through its `per_batch_map` parameter, as sketched below. For details, see [YOLOv3](https://gitee.com/mindspore/models/blob/master/official/cv/yolov3_darknet53/src/yolo_dataset.py#L177).
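+
+A minimal sketch of such a per-batch operation (the horizontal flip is only illustrative; the callable receives one list per input column plus a `BatchInfo` object and must return the same number of lists):
+
+```python
+import mindspore.dataset as ds
+
+def per_batch_flip(images, labels, batch_info):
+    # images and labels are lists holding the whole batch as numpy arrays
+    images = [img[:, ::-1, :].copy() for img in images]  # horizontal flip of HWC images
+    return images, labels
+
+dataset = ds.FakeImageDataset(num_images=8, image_size=(32, 32, 3), num_classes=10, base_seed=0)
+dataset = dataset.batch(4, drop_remainder=True, per_batch_map=per_batch_flip,
+                        input_columns=["image", "label"])
+```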
+
+## Data Iteration
+
+MindSpore dataset objects can be iterated over in the following ways.
+
+### [create_dict_iterator](https://www.mindspore.cn/docs/en/master/api_python/dataset/dataset_method/iterator/mindspore.dataset.Dataset.create_dict_iterator.html#mindspore.dataset.Dataset.create_dict_iterator)
+
+Creates an iterator based on the dataset object, and the output data is of dictionary type.
+
+```python
+import mindspore.dataset as ds
+dataset = ds.FakeImageDataset(num_images=20, image_size=(32, 32, 3), num_classes=10, base_seed=0)
+dataset = dataset.batch(10, drop_remainder=True)
+iterator = dataset.create_dict_iterator()
+for data_dict in iterator:
+ for name in data_dict.keys():
+ print(name, data_dict[name].shape)
+ print("="*20)
+```
+
+```text
+image (10, 32, 32, 3)
+label (10,)
+====================
+image (10, 32, 32, 3)
+label (10,)
+====================
+```
+
+### [create_tuple_iterator](https://www.mindspore.cn/docs/en/master/api_python/dataset/dataset_method/iterator/mindspore.dataset.Dataset.create_tuple_iterator.html#mindspore.dataset.Dataset.create_tuple_iterator)
+
+Creates an iterator based on the dataset object, and the output data is a list of `numpy.ndarray` data.
+
+You can specify all column names and the order of the columns in the output by the parameter `columns`. If `columns` is not specified, the order of the columns will remain unchanged.
+
+```python
+import mindspore.dataset as ds
+dataset = ds.FakeImageDataset(num_images=20, image_size=(32, 32, 3), num_classes=10, base_seed=0)
+dataset = dataset.batch(10, drop_remainder=True)
+iterator = dataset.create_tuple_iterator()
+for data_tuple in iterator:
+ for data in data_tuple:
+ print(data.shape)
+ print("="*20)
+```
+
+```text
+(10, 32, 32, 3)
+(10,)
+====================
+(10, 32, 32, 3)
+(10,)
+====================
+```
+
+### Traversing Directly over dataset Objects
+
+> Note that this way of traversing does not `shuffle` the data after each epoch, which may affect precision when used in training. The above two methods are recommended when data needs to be iterated directly during training.
+
+```python
+import mindspore.dataset as ds
+dataset = ds.FakeImageDataset(num_images=20, image_size=(32, 32, 3), num_classes=10, base_seed=0)
+dataset = dataset.batch(10, drop_remainder=True)
+
+for data in dataset:
+    for item in data:
+        print(item.shape)
+ print("="*20)
+```
+
+```text
+(10, 32, 32, 3)
+(10,)
+====================
+(10, 32, 32, 3)
+(10,)
+====================
+```
+
+The latter two methods can be used directly when the order in which the data is read is the same as the order required by the network:
+
+```python
+for data in dataset:
+ loss = net(*data)
+```
diff --git a/docs/mindspore/source_en/migration_guide/model_development/learning_rate_and_optimizer.md b/docs/mindspore/source_en/migration_guide/model_development/learning_rate_and_optimizer.md
new file mode 100644
index 0000000000000000000000000000000000000000..10903e53d0ce2288d6d6ed00ce79a01fcc015470
--- /dev/null
+++ b/docs/mindspore/source_en/migration_guide/model_development/learning_rate_and_optimizer.md
@@ -0,0 +1,98 @@
+# Learning Rate and Optimizer
+
+
+
+Before reading this chapter, please read the official MindSpore tutorial [Optimizer](https://www.mindspore.cn/tutorials/en/master/advanced/modules/optim.html).
+
+The optimizer chapter of the official tutorial is already detailed, so this chapter introduces some special ways of using the MindSpore optimizer and the principle of the learning rate decay strategy.
+
+## Parameters Grouping
+
+The MindSpore optimizer supports some special operations, such as setting different learning rate (lr), weight_decay, and gradient_centralization strategies for different groups of trainable parameters in the network. For example:
+
+```python
+from mindspore import nn
+from mindvision.classification.models import resnet50
+
+def params_not_in(param, param_list):
+ # Use the Parameter id to determine if param is not in the param_list
+ param_id = id(param)
+ for p in param_list:
+ if id(p) == param_id:
+ return False
+ return True
+
+resnet = resnet50(pretrained=False)
+trainable_param = resnet.trainable_params()
+conv_weight, bn_weight, dense_weight = [], [], []
+for _, cell in resnet.cells_and_names():
+ # Determine what the API is and add the corresponding parameters to the different lists
+ if isinstance(cell, nn.Conv2d):
+ conv_weight.append(cell.weight)
+ elif isinstance(cell, nn.BatchNorm2d):
+ bn_weight.append(cell.gamma)
+ bn_weight.append(cell.beta)
+ elif isinstance(cell, nn.Dense):
+ dense_weight.append(cell.weight)
+
+other_param = []
+# The parameters in all groups cannot be duplicated, and the union of the groups is all the parameters that need to be updated
+for param in trainable_param:
+ if params_not_in(param, conv_weight) and params_not_in(param, bn_weight) and params_not_in(param, dense_weight):
+ other_param.append(param)
+
+group_param = [{'order_params': trainable_param}]
+# The parameter list for each group cannot be empty
+
+if conv_weight:
+ conv_weight_lr = nn.cosine_decay_lr(0., 1e-3, total_step=1000, step_per_epoch=100, decay_epoch=10)
+ group_param.append({'params': conv_weight, 'weight_decay': 1e-4, 'lr': conv_weight_lr})
+if bn_weight:
+ group_param.append({'params': bn_weight, 'weight_decay': 0., 'lr': 1e-4})
+if dense_weight:
+ group_param.append({'params': dense_weight, 'weight_decay': 1e-5, 'lr': 1e-3})
+if other_param:
+ group_param.append({'params': other_param})
+
+opt = nn.Momentum(group_param, learning_rate=1e-3, weight_decay=0.0, momentum=0.9)
+```
+
+The following points need to be noted:
+
+1. The list of parameters for each group cannot be empty.
+2. Use the values set in the optimizer if `weight_decay` and `lr` are not set, and use the values in the grouping parameter dictionary if they are set.
+3. `lr` in each group can be static or dynamic, but cannot be grouped again.
+4. `weight_decay` in each group needs to be a valid floating-point number.
+5. The parameters in all groups cannot be duplicated, and the union of the groups is all the parameters that need to be updated.
+
+## MindSpore Learning Rate Decay Strategy
+
+During training, the MindSpore learning rate exists as a parameter in the network. Before the optimizer updates the trainable parameters in the network, MindSpore calls [get_lr](https://www.mindspore.cn/docs/en/master/api_python/nn/mindspore.nn.Optimizer.html#mindspore.nn.Optimizer.get_lr) to get the learning rate value needed for the current step.
+
+The MindSpore learning rate supports static, dynamic, and grouped forms, where a static learning rate is a float32 Tensor in the network.
+
+There are two types of dynamic learning rates. One is a float32 Tensor in the network whose length is the total number of training steps, such as the values generated by the [Dynamic LR functions](https://www.mindspore.cn/docs/en/master/api_python/mindspore.nn.html#dynamic-lr-function). The optimizer holds a `global_step` parameter that is incremented by 1 on every optimizer update, and MindSpore internally obtains the learning rate value of the current step from `global_step` and `learning_rate`.
+
+The other generates the learning rate value through graph computation, such as the [LearningRateSchedule classes](https://www.mindspore.cn/docs/en/master/api_python/mindspore.nn.html#learningrateschedule-class).
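+
+A minimal sketch of the two dynamic forms (the cosine schedule and its hyperparameters are only illustrative):
+
+```python
+from mindspore import nn
+
+net = nn.Dense(3, 2)
+
+# Form 1: a precomputed list with one learning rate value per training step
+lr_list = nn.cosine_decay_lr(min_lr=0.0, max_lr=0.1, total_step=1000,
+                             step_per_epoch=100, decay_epoch=10)
+opt1 = nn.Momentum(net.trainable_params(), learning_rate=lr_list, momentum=0.9)
+
+# Form 2: a LearningRateSchedule that computes the value in the graph from the current step
+lr_schedule = nn.CosineDecayLR(min_lr=0.0, max_lr=0.1, decay_steps=1000)
+opt2 = nn.Momentum(net.trainable_params(), learning_rate=lr_schedule, momentum=0.9)
+```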
+
+The grouping learning rate is as described in parameter grouping in the previous section.
+
+Because the learning rate of MindSpore is a parameter, we can also modify the value of the learning rate during training by assigning a value to the `learning_rate` parameter, as in the [LearningRateScheduler Callback](https://www.mindspore.cn/docs/zh-CN/master/_modules/mindspore/train/callback/_lr_scheduler_callback.html#LearningRateScheduler). This method only supports static learning rates passed into the optimizer. The key code is as follows:
+
+```python
+import mindspore as ms
+from mindspore import ops, nn
+
+net = nn.Dense(1, 2)
+optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
+print(optimizer.learning_rate.data.asnumpy())
+new_lr = 0.01
+# Rewrite the value of the learning_rate parameter
+ops.assign(optimizer.learning_rate, ms.Tensor(new_lr, ms.float32))
+print(optimizer.learning_rate.data.asnumpy())
+```
+
+```text
+0.1
+0.01
+```
diff --git a/docs/mindspore/source_en/migration_guide/model_development/model_development.md b/docs/mindspore/source_en/migration_guide/model_development/model_development.md
new file mode 100644
index 0000000000000000000000000000000000000000..3d76c473b03f531ed905eb255c3281fc33c963a7
--- /dev/null
+++ b/docs/mindspore/source_en/migration_guide/model_development/model_development.md
@@ -0,0 +1,72 @@
+# Constructing MindSpore Network
+
+
+
+This chapter introduces the basic modules needed for training and inference when writing MindSpore scripts: datasets, network models and loss functions, optimizers, the training process, and the inference process. It also covers some functional techniques commonly used in network migration, such as network writing specifications, training and inference process templates, and dynamic shape mitigation strategies.
+
+## Network Training Principle
+
+The basic principle of network training is shown in the figure above.
+
+The training process of the whole network consists of 5 modules:
+
+- dataset: used to obtain data, containing the input of the network and the labels. MindSpore provides basic [common dataset processing interfaces](https://www.mindspore.cn/docs/en/master/api_python/mindspore.dataset.html), and also supports constructing datasets by using Python iterators.
+
+- network: network model implementation, typically encapsulated by using Cell. Declare the required modules and operators in init, and implement graph construction in construct.
+
+- loss: loss function. Used to measure the degree of difference between the predicted value and the true value. In deep learning, model training is the process of shrinking the loss function value by iterating continuously. Defining a good loss function can help the loss function value converge faster to achieve better precision. MindSpore provides many [common loss functions](https://www.mindspore.cn/docs/en/master/api_python/mindspore.nn.html#loss-function), but of course you can define and implement your own loss function.
+
+- Automatic gradient derivation: Generally, network and loss are encapsulated together as a forward network, and the forward network is given to the automatic gradient derivation module for gradient calculation. MindSpore provides an automatic gradient derivation interface, which shields the user from a large number of derivation details and procedures and greatly lowers the barrier to using the framework. When you need to customize the gradient, MindSpore also provides an [interface](https://www.mindspore.cn/tutorials/experts/en/master/network/custom_cell_reverse.html) to freely implement the gradient calculation.
+
+- Optimizer: used to calculate and update network parameters during model training. MindSpore provides a number of [general-purpose optimizers](https://www.mindspore.cn/docs/en/master/api_python/mindspore.nn.html#optimizer) for users to choose, and also supports users to customize the optimizers.
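+
+The following is a minimal sketch of how these five modules fit together (the one-layer network, loss, and single fake batch of data are only illustrative, not a full training script):
+
+```python
+import numpy as np
+import mindspore as ms
+from mindspore import nn
+
+# network and loss
+net = nn.Dense(10, 2)
+loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
+# optimizer
+opt = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)
+
+def forward_fn(data, label):
+    logits = net(data)
+    return loss_fn(logits, label)
+
+# automatic gradient derivation with respect to the optimizer parameters
+grad_fn = ms.value_and_grad(forward_fn, None, opt.parameters)
+
+# dataset: a single fake batch stands in for a real data pipeline here
+data = ms.Tensor(np.random.randn(4, 10), ms.float32)
+label = ms.Tensor(np.array([0, 1, 0, 1]), ms.int32)
+
+loss, grads = grad_fn(data, label)
+opt(grads)  # update the network parameters
+print(loss)
+```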
+
+## Principles of Network Inference
+
+The basic principles of network inference are shown in the figure above.
+
+The inference process of the whole network consists of 3 modules:
+
+- dataset: used to obtain data, including the input of the network and the labels. Since the entire inference dataset needs to be processed during inference, it is recommended to set batch_size to 1. If batch_size is not 1, note that drop_remainder=False needs to be set when batching. In addition, the inference process is a fixed process: loading the same parameters should always give the same inference results, and the inference process should not contain random data augmentation.
+
+- network: network model implementation, generally encapsulated by using Cell. The network structure during inference is generally the same as the network structure during training. It should be noted that Cell is tagged with set_train(False) for inference and set_train(True) for training, just like PyTorch model.eval() (model evaluation mode) and model.train() (model training mode).
+
+- metrics: When the training task is over, evaluation metrics (Metrics) and evaluation functions are used to assess whether the model works well. Commonly used evaluation metrics include Confusion Matrix, Accuracy, Precision, and Recall. The mindspore.nn module provides common [evaluation functions](https://www.mindspore.cn/docs/en/master/api_python/mindspore.train.html#evaluation-metrics), and users can also define their own evaluation metrics as needed. Customized Metrics functions need to inherit the nn.Metric parent class and reimplement its clear, update, and eval methods, as in the sketch below.
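+
+A minimal sketch of such a customized metric (top-1 accuracy; the inherited `_convert_data` helper used below to convert inputs to numpy is an assumption based on common usage):
+
+```python
+from mindspore import nn
+
+class MyAccuracy(nn.Metric):
+    """A minimal customized metric: top-1 accuracy."""
+    def __init__(self):
+        super().__init__()
+        self.clear()
+
+    def clear(self):
+        """Reset the internal statistics."""
+        self._correct = 0
+        self._total = 0
+
+    def update(self, *inputs):
+        """inputs[0] is the logits, inputs[1] is the labels."""
+        preds = self._convert_data(inputs[0]).argmax(axis=1)
+        labels = self._convert_data(inputs[1])
+        self._correct += (preds == labels).sum()
+        self._total += labels.shape[0]
+
+    def eval(self):
+        """Return the accumulated accuracy."""
+        return self._correct / self._total
+```
+
+It can then be passed to a Model through the `metrics` argument, e.g. `ms.Model(net, loss_fn=loss, metrics={"my_acc": MyAccuracy()})`.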
+
+## Constructing Network
+
+After understanding the process of network training and inference, the following describes how to implement the process of network training and inference on MindSpore.
+
+- [Constructing Dataset](https://www.mindspore.cn/docs/en/master/migration_guide/model_development/dataset.html)
+- Network Body and Loss Building
+- [Learning Rate and Optimizer](https://www.mindspore.cn/docs/en/master/migration_guide/model_development/learning_rate_and_optimizer.html)
+- [Training Network and Gradient Derivation](https://www.mindspore.cn/docs/en/master/migration_guide/model_development/training_and_gradient.html)
+- [Inference and Training Process](https://www.mindspore.cn/docs/en/master/migration_guide/model_development/training_and_evaluation_procession.html)
+
+> When doing network migration, we recommend doing inference validation of the model as a priority after completing the network scripting. This has several benefits:
+>
+> - Compared with training, the inference process is fixed and able to be compared with the reference implementation.
+> - Compared with training, the time required for inference is relatively short, enabling rapid verification of the correctness of the network structure and inference process.
+> - The trained results need to be validated through the inference process to verify results of the model. It is necessary that the correctness of the inference be ensured first, then to prove that the training is valid.
+
+## Considerations for MindSpore Network Authoring
+
+During MindSpore network implementation, there are some problem-prone areas. When you encounter problems, please prioritize troubleshooting for the following situations:
+
+1. MindSpore operators are used in data processing. Data processing usually runs with multiple threads/processes, and there are limitations on using MindSpore operators in this scenario. It is recommended to use third-party implementations such as numpy, opencv, pandas, and PIL instead of MindSpore operators in the data processing process.
+2. Control flow. For details, refer to [Flow Control Statements](https://www.mindspore.cn/tutorials/experts/en/master/network/control_flow.html). Compilation in graph mode can be slow when multiple layers of conditional control statements are called.
+3. Slicing operation. When slicing a Tensor, note whether the slice index is a variable; when it is a variable, there are restrictions. Please refer to network body and loss building for dynamic shape mitigation.
+4. Customized mixed precision conflicts with `amp_level` in Model, so don't set `amp_level` in Model if you use customized mixed precision.
+5. In the Ascend environment, Conv, Sort, and TopK can only be float16; add [loss scale](https://mindspore.cn/tutorials/experts/en/master/others/mixed_precision.html) to avoid overflow.
+6. In the Ascend environment, operators with the stride property such as Conv and Pooling have rules about the length of the stride, which need to be worked around.
+7. In a distributed environment, a seed must be set to ensure that the initialized parameters on all cards are consistent.
+8. In the case of using list of Cell or list of Parameter in the network, please convert the list to [CellList](https://www.mindspore.cn/docs/en/master/api_python/nn/mindspore.nn.CellList.html), [SequentialCell](https://www.mindspore.cn/docs/en/master/api_python/nn/mindspore.nn.SequentialCell.html), and [ParameterTuple](https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.ParameterTuple.html) in `init`.
+
+```python
+# Define the required layers for graph construction in init, and don't write it like this
+self.layer = [nn.Conv2d(1, 3, 3), nn.BatchNorm2d(3), nn.ReLU()]
+
+# Need to encapsulate as CellList or SequentialCell
+self.layer = nn.CellList([nn.Conv2d(1, 3, 3), nn.BatchNorm2d(3), nn.ReLU()])
+# Or
+self.layer = nn.SequentialCell([nn.Conv2d(1, 3, 3), nn.BatchNorm2d(3), nn.ReLU()])
+```
diff --git a/docs/mindspore/source_en/migration_guide/model_development/training_and_evaluation_procession.md b/docs/mindspore/source_en/migration_guide/model_development/training_and_evaluation_procession.md
new file mode 100644
index 0000000000000000000000000000000000000000..cfb4892452344ce50dedbf0001c03d67d18d69ad
--- /dev/null
+++ b/docs/mindspore/source_en/migration_guide/model_development/training_and_evaluation_procession.md
@@ -0,0 +1,255 @@
+# Inference and Training Process
+
+
+
+## General Operating Environment Settings
+
+We generally need to set up the operating environment before network training and inference, and a general operating environment configuration is given here.
+
+```python
+import mindspore as ms
+from mindspore.communication.management import init, get_rank, get_group_size
+
+def init_env(cfg):
+ """Initialize the operating environment."""
+ ms.set_seed(cfg.seed)
+ # If device_target is set to None, use the framework to get device_target automatically, otherwise use the set one.
+ if cfg.device_target != "None":
+ if cfg.device_target not in ["Ascend", "GPU", "CPU"]:
+ raise ValueError(f"Invalid device_target: {cfg.device_target}, "
+ f"should be in ['None', 'Ascend', 'GPU', 'CPU']")
+ ms.set_context(device_target=cfg.device_target)
+
+ # Configure operation mode, and support graph mode and PYNATIVE mode
+ if cfg.context_mode not in ["graph", "pynative"]:
+ raise ValueError(f"Invalid context_mode: {cfg.context_mode}, "
+ f"should be in ['graph', 'pynative']")
+ context_mode = ms.GRAPH_MODE if cfg.context_mode == "graph" else ms.PYNATIVE_MODE
+ ms.set_context(mode=context_mode)
+
+ cfg.device_target = ms.get_context("device_target")
+    # If running on CPU, do not configure the multi-card environment
+ if cfg.device_target == "CPU":
+ cfg.device_id = 0
+ cfg.device_num = 1
+ cfg.rank_id = 0
+
+ # Set the card to be used at runtime
+ if hasattr(cfg, "device_id") and isinstance(cfg.device_id, int):
+ ms.set_context(device_id=cfg.device_id)
+
+ if cfg.device_num > 1:
+ # The init method is used to initialize multiple cards, and does not distinguish between Ascend and GPU. get_group_size and get_rank can only be used after init
+ init()
+ print("run distribute!", flush=True)
+ group_size = get_group_size()
+ if cfg.device_num != group_size:
+ raise ValueError(f"the setting device_num: {cfg.device_num} not equal to the real group_size: {group_size}")
+ cfg.rank_id = get_rank()
+ ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.DATA_PARALLEL, gradients_mean=True)
+ if hasattr(cfg, "all_reduce_fusion_config"):
+ ms.set_auto_parallel_context(all_reduce_fusion_config=cfg.all_reduce_fusion_config)
+ else:
+ cfg.device_num = 1
+ cfg.rank_id = 0
+ print("run standalone!", flush=True)
+```
+
+`cfg` holds the parameters parsed from the yaml configuration file. Using this template requires at least the following parameters to be configured:
+
+```yaml
+seed: 1
+device_target: "None"
+context_mode: "graph" # should be in ['graph', 'pynative']
+device_num: 1
+device_id: 0
+```
+
+The above procedure is just a basic configuration of the operating environment. If you need to add some advanced features, please refer to [set_context](https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.set_context.html#mindspore.set_context).
+
+## Generic Scripting Framework
+
+The models repository provides a generic [script scaffolding](https://gitee.com/mindspore/models/tree/master/utils/model_scaffolding) that is used for:
+
+1. Parsing yaml parameter files and obtaining parameters
+2. Providing a unified tool for running both on ModelArts (on the cloud) and on-premise
+
+To use it, place its Python files into the `model_utils` directory under `src`, as in [resnet](https://gitee.com/mindspore/models/tree/master/official/cv/resnet/src/model_utils).
+
+## Inference Process
+
+A generic inference process is as follows:
+
+```python
+import mindspore as ms
+from mindspore import nn
+from src.model import Net
+from src.dataset import create_dataset
+from src.utils import init_env
+from src.model_utils.config import config
+
+# Initialize the operating environment
+init_env(config)
+# Constructing dataset objects
+dataset = create_dataset(config, is_train=False)
+# Network model, task-related
+net = Net()
+# Loss function, task-related
+loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
+# Load the trained parameters
+ms.load_checkpoint(config.checkpoint_path, net)
+# Encapsulation into Model
+model = ms.Model(net, loss_fn=loss, metrics={'top_1_accuracy', 'top_5_accuracy'})
+# Model inference
+res = model.eval(dataset)
+print("result:", res, "ckpt=", config.checkpoint_path)
+```
+
+Generally, the source code for network construction and data processing will be placed in the `src` directory, and the scripting framework will be placed in the `src.model_utils` directory. For example, you can refer to the implementation in [MindSpore models](https://gitee.com/mindspore/models).
+
+Sometimes the inference process cannot be encapsulated into a Model; in that case, it can be expanded into a for loop over the dataset, as sketched below. See [ssd inference](https://gitee.com/mindspore/models/blob/master/official/cv/ssd/eval.py) for a complete example.
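+
+A minimal sketch of such a loop, reusing `net` and `dataset` from the example above (the column names `image` and `label` and the top-1 accuracy computation are only illustrative):
+
+```python
+net.set_train(False)
+correct, total = 0, 0
+for item in dataset.create_dict_iterator():
+    logits = net(item["image"])
+    correct += (logits.argmax(1) == item["label"]).asnumpy().sum()
+    total += item["label"].shape[0]
+print("top-1 accuracy:", correct / total)
+```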
+
+### Inference Verification
+
+In the model analysis and preparation phase, we obtain the trained parameters of the reference implementation (from the reference implementation's README or by reproducing its training). Since the implementation of the model algorithm is not related to the framework, the trained parameters can first be converted into a MindSpore [checkpoint](https://www.mindspore.cn/tutorials/en/master/beginner/save_load.html) and loaded into the network for inference verification.
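+
+A minimal sketch of such a conversion, assuming the reference weights are a PyTorch state dict saved as `reference.pth` (the file names are hypothetical, and in practice the parameter names usually need to be remapped between frameworks):
+
+```python
+import torch
+import mindspore as ms
+
+torch_params = torch.load("reference.pth", map_location="cpu")
+ms_params = []
+for name, value in torch_params.items():
+    # A real conversion usually remaps parameter names (e.g. bn.weight -> bn.gamma) here
+    ms_params.append({"name": name, "data": ms.Tensor(value.numpy())})
+ms.save_checkpoint(ms_params, "reference_ms.ckpt")
+```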
+
+Please refer to resnet network migration for the whole process of inference verification.
+
+## Training Process
+
+A general training process is as follows:
+
+```python
+import mindspore as ms
+from mindspore import nn
+from mindspore import LossMonitor, TimeMonitor, CheckpointConfig, ModelCheckpoint
+from src.model import Net
+from src.dataset import create_dataset
+from src.utils import init_env
+from src.model_utils.config import config
+from src.model_utils.moxing_adapter import moxing_wrapper
+
+@moxing_wrapper()
+def train_net():
+ # Initialize the operating environment
+ init_env(config)
+ # Constructing dataset objects
+    dataset = create_dataset(config, is_train=True)
+ # Network model, task-related
+ net = Net()
+ # Loss function, task-related
+ loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
+ # Optimizer implementation, task-related
+ optimizer = nn.Adam(net.trainable_params(), config.lr, weight_decay=config.weight_decay)
+ # Encapsulation into Model
+ model = ms.Model(net, loss_fn=loss, metrics={'top_1_accuracy', 'top_5_accuracy'})
+ # checkpoint saving
+ config_ck = CheckpointConfig(save_checkpoint_steps=dataset.get_dataset_size(),
+ keep_checkpoint_max=5)
+ ckpt_cb = ModelCheckpoint(prefix="resnet", directory="./checkpoint", config=config_ck)
+ # Model training
+    model.train(config.epoch, dataset, callbacks=[LossMonitor(), TimeMonitor(), ckpt_cb])
+
+if __name__ == '__main__':
+ train_net()
+```
+
+Please refer to [Save and Load](https://www.mindspore.cn/tutorials/en/master/beginner/save_load.html) for checkpoint saving.
+
+### Distributed Training
+
+The multi-card distributed training process is the same as the single-card training process, except for the distributed-related configuration items and gradient aggregation. It should be noted that multi-card parallelism actually starts multiple Python processes. Before MindSpore 1.8, in the Ascend environment, the processes need to be started manually:
+
+```shell
+if [ $# != 4 ]
+then
+ echo "Usage: sh run_distribution_ascend.sh [DEVICE_NUM] [START_ID] [RANK_TABLE_FILE] [CONFIG_PATH]"
+exit 1
+fi
+
+get_real_path(){
+ if [ "${1:0:1}" == "/" ]; then
+ echo $1
+ else
+ echo "$(realpath -m ${PWD}/$1)"
+ fi
+}
+
+RANK_TABLE_FILE=$(get_real_path $3)
+CONFIG_PATH=$(get_real_path $4)
+
+if [ ! -f $RANK_TABLE_FILE ]
+then
+ echo "error: RANK_TABLE_FILE=$RANK_TABLE_FILE is not a file"
+exit 1
+fi
+
+if [ ! -f $CONFIG_PATH ]
+then
+ echo "error: CONFIG_PATH=$CONFIG_PATH is not a file"
+exit 1
+fi
+
+BASE_PATH=$(cd ./"`dirname $0`" || exit; pwd)
+
+export RANK_SIZE=$1
+START_ID=$2
+export RANK_TABLE_FILE=$RANK_TABLE_FILE
+
+cd $BASE_PATH
+for((i=0; i<${RANK_SIZE}; i++))
+do
+    export DEVICE_ID=$((START_ID + i))
+ export RANK_ID=$i
+ rm -rf ./train_parallel$i
+ mkdir ./train_parallel$i
+ cp -r ../src ./train_parallel$i
+ cp ../*.py ./train_parallel$i
+ echo "start training for rank $RANK_ID, device $DEVICE_ID"
+ cd ./train_parallel$i ||exit
+ env > env.log
+    python train.py --config_path=$CONFIG_PATH --device_num=$RANK_SIZE > log.txt 2>&1 &
+ cd ..
+done
+```
+
+After MindSpore 1.8, distributed training on Ascend can also be launched with mpirun, the same as on the GPU.
+
+```shell
+if [ $# != 2 ]
+then
+ echo "Usage: sh run_distribution_ascend.sh [DEVICE_NUM] [CONFIG_PATH]"
+exit 1
+fi
+
+get_real_path(){
+ if [ "${1:0:1}" == "/" ]; then
+ echo $1
+ else
+ echo "$(realpath -m ${PWD}/$1)"
+ fi
+}
+
+CONFIG_PATH=$(get_real_path $2)
+
+if [ ! -f $CONFIG_PATH ]
+then
+ echo "error: CONFIG_PATH=$CONFIG_PATH is not a file"
+exit 1
+fi
+
+BASE_PATH=$(cd ./"`dirname $0`" || exit; pwd)
+
+export RANK_SIZE=$1
+
+cd $BASE_PATH
+mpirun --allow-run-as-root -n $RANK_SIZE python ../train.py --config_path=$CONFIG_PATH --device_num=$RANK_SIZE > log.txt 2>&1 &
+```
+
+On the GPU, you can choose which cards to use with `export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7`. Specifying the card number this way is not currently supported on Ascend.
+
+Please refer to [Distributed Case](https://www.mindspore.cn/tutorials/experts/en/master/parallel/distributed_case.html) for more details.
+
+## Offline Inference
+
+In addition to online inference, MindSpore provides many offline inference methods for different environments. Please refer to [Model Inference](https://www.mindspore.cn/tutorials/experts/en/master/infer/inference.html) for details.
diff --git a/docs/mindspore/source_zh_cn/migration_guide/model_development/dataset.ipynb b/docs/mindspore/source_zh_cn/migration_guide/model_development/dataset.ipynb
index c27b03e071f30a95a8cdd3047b2f34354fd76d5c..440ff1bdafed774faaf6a8760ba0a0241fd4bfcf 100644
--- a/docs/mindspore/source_zh_cn/migration_guide/model_development/dataset.ipynb
+++ b/docs/mindspore/source_zh_cn/migration_guide/model_development/dataset.ipynb
@@ -27,7 +27,7 @@
"\n",
"### 数据集加载\n",
"\n",
- "MindSpore提供了很多常见数据集的加载接口, 用的比较多的接口有:\n",
+ "MindSpore提供了很多常见数据集的加载接口,用的比较多的接口有:\n",
"\n",
"| 数据接口 | 介绍 |\n",
"| -------| ---- |\n",
@@ -183,11 +183,11 @@
"\n",
"给定一组数据增强列表,按顺序将数据增强作用在数据集对象上。\n",
"\n",
- "每个数据增强操作将数据集对象中的一个或多个数据列作为输入,将数据增强的结果输出为一个或多个数据列。 第一个数据增强操作将 input_columns 中指定的列作为输入。 如果数据增强列表中存在多个数据增强操作,则上一个数据增强的输出列将作为下一个数据增强的输入列。\n",
+ "每个数据增强操作将数据集对象中的一个或多个数据列作为输入,将数据增强的结果输出为一个或多个数据列。第一个数据增强操作将 input_columns 中指定的列作为输入。如果数据增强列表中存在多个数据增强操作,则上一个数据增强的输出列将作为下一个数据增强的输入列。\n",
"\n",
"最后一个数据增强的输出列的列名由 `output_columns` 指定,如果没有指定 `output_columns` ,输出列名与 `input_columns` 一致。\n",
"\n",
- "上面的介绍可能比较繁琐,简单来说 `map` 就是在数据集的某些列上做 `operations` 里规定的操作。这里的`operations`可以是MindSpore提供的数据增强操作:\n",
+ "上面的介绍可能比较繁琐,简单来说 `map` 就是在数据集的某些列上做 `operations` 里规定的操作。这里的`operations`可以是MindSpore提供的数据增强操作:\n",
"[audio](https://www.mindspore.cn/docs/zh-CN/master/api_python/mindspore.dataset.audio.html)、[text](https://www.mindspore.cn/docs/zh-CN/master/api_python/mindspore.dataset.text.html)、[vision](https://www.mindspore.cn/docs/zh-CN/master/api_python/mindspore.dataset.vision.html)、[通用](https://www.mindspore.cn/docs/zh-CN/master/api_python/mindspore.dataset.transforms.html)。详情请参考[数据变换 Transforms](https://www.mindspore.cn/tutorials/zh-CN/master/beginner/transforms.html)\n",
"也可以是python的方法,里面可以用 opencv,PIL,pandas 等一些三方的方法,和数据集加载一样,**不要使用MindSpore的算子**。\n",
"\n",
diff --git a/docs/mindspore/source_zh_cn/migration_guide/model_development/model_development.md b/docs/mindspore/source_zh_cn/migration_guide/model_development/model_development.md
index 5312f628c53e94859b7fe7943042e402822ca495..c5c7f7db09b7785ba5648440ce8ad3ecfc0a553e 100644
--- a/docs/mindspore/source_zh_cn/migration_guide/model_development/model_development.md
+++ b/docs/mindspore/source_zh_cn/migration_guide/model_development/model_development.md
@@ -16,11 +16,11 @@
- network;网络模型实现,一般使用Cell包装。在init里声明需要的模块和算子,在construct里构图实现。
-- loss;损失函数。用于衡量预测值与真实值差异的程度。深度学习中,模型训练就是通过不停地迭代来缩小损失函数值的过程,定义一个好的损失函数可以帮助损失函数值更快收敛,达到更好的精度, MindSpore提供了很多[常见的loss函数](https://www.mindspore.cn/docs/zh-CN/master/api_python/mindspore.nn.html#%E6%8D%9F%E5%A4%B1%E5%87%BD%E6%95%B0),当然可以自己定义实现自己的loss函数。
+- loss;损失函数。用于衡量预测值与真实值差异的程度。深度学习中,模型训练就是通过不停地迭代来缩小损失函数值的过程,定义一个好的损失函数可以帮助损失函数值更快收敛,达到更好的精度,MindSpore提供了很多[常见的loss函数](https://www.mindspore.cn/docs/zh-CN/master/api_python/mindspore.nn.html#%E6%8D%9F%E5%A4%B1%E5%87%BD%E6%95%B0),当然可以自己定义实现自己的loss函数。
-- 自动梯度求导;一般将network和loss一起包装成正向网络一起给到自动梯度求导模块进行梯度计算。MindSpore提供了自动的梯度求导接口,该功能对用户屏蔽了大量的求导细节和过程,大大降低了框架的使用门槛。需要自定义梯度时,MindSpore也提供了[接口](https://www.mindspore.cn/tutorials/experts/zh-CN/master/network/custom_cell_reverse.html) 去自由实现梯度计算。
+- 自动梯度求导;一般将network和loss一起包装成正向网络一起给到自动梯度求导模块进行梯度计算。MindSpore提供了自动的梯度求导接口,该功能对用户屏蔽了大量的求导细节和过程,大大降低了框架的使用门槛。需要自定义梯度时,MindSpore也提供了[接口](https://www.mindspore.cn/tutorials/experts/zh-CN/master/network/custom_cell_reverse.html)去自由实现梯度计算。
-- 优化器;优化器在模型训练过程中,用于计算和更新网络参数。MindSpore提供了许多[通用的优化器](https://www.mindspore.cn/docs/zh-CN/master/api_python/mindspore.nn.html#%E4%BC%98%E5%8C%96%E5%99%A8) 供用户选择,同时也支持用户根据需要自定义优化器。
+- 优化器;优化器在模型训练过程中,用于计算和更新网络参数。MindSpore提供了许多[通用的优化器](https://www.mindspore.cn/docs/zh-CN/master/api_python/mindspore.nn.html#%E4%BC%98%E5%8C%96%E5%99%A8)供用户选择,同时也支持用户根据需要自定义优化器。
## 网络推理原理
@@ -57,7 +57,7 @@
在MindSpore网络实现过程中,有一些容易出现问题的地方,遇到问题请优先排查是否有以下情况:
1. 数据处理中使用MindSpore的算子。数据处理过程一般会有多线程/多进程,此场景下数据处理使用MindSpore的算子存在限制,数据处理过程中使用的算子建议使用三方的实现代替,如numpy,opencv,pandas,PIL等。
-2. 控制流。详情请参考[流程控制语句](https://mindspore.cn/tutorials/zh-CN/master/advanced/modules/control_flow.html)。当多层调用条件控制语句时在图模式下编译会很慢。
+2. 控制流。详情请参考[流程控制语句](https://www.mindspore.cn/tutorials/experts/zh-CN/master/network/control_flow.html)。当多层调用条件控制语句时在图模式下编译会很慢。
3. 切片操作,当遇到对一个Tesnor进行切片时需要注意,切片的下标是否是变量,当是变量时会有限制,请参考[网络主体和loss搭建](https://www.mindspore.cn/docs/zh-CN/master/migration_guide/model_development/model_and_loss.html)对动态shape规避。
4. 自定义混合精度和Model里的`amp_level`冲突,使用自定义的混合精度就不要设置Model里的`amp_level`。
5. 在Ascend环境下Conv,Sort,TopK只能是float16的,注意加[loss scale](https://mindspore.cn/tutorials/experts/zh-CN/master/others/mixed_precision.html)避免溢出。
diff --git a/docs/mindspore/source_zh_cn/migration_guide/model_development/training_and_evaluation_procession.md b/docs/mindspore/source_zh_cn/migration_guide/model_development/training_and_evaluation_procession.md
index 037cbc923c419900fdfae6fb838ca2052838312f..cbb68ab30f1bb64078c3628aca8c1eea72cfd5bc 100644
--- a/docs/mindspore/source_zh_cn/migration_guide/model_development/training_and_evaluation_procession.md
+++ b/docs/mindspore/source_zh_cn/migration_guide/model_development/training_and_evaluation_procession.md
@@ -13,7 +13,7 @@ from mindspore.communication.management import init, get_rank, get_group_size
def init_env(cfg):
"""初始化运行时环境."""
ms.set_seed(cfg.seed)
- # 如果device_target设置是None, 利用框架自动获取device_target,否则使用设置的。
+ # 如果device_target设置是None,利用框架自动获取device_target,否则使用设置的。
if cfg.device_target != "None":
if cfg.device_target not in ["Ascend", "GPU", "CPU"]:
raise ValueError(f"Invalid device_target: {cfg.device_target}, "
@@ -39,7 +39,7 @@ def init_env(cfg):
ms.set_context(device_id=cfg.device_id)
if cfg.device_num > 1:
- # init方法用于多卡的初始化,不区分Ascend和GPU, get_group_size和get_rank方法只能在init后使用
+ # init方法用于多卡的初始化,不区分Ascend和GPU,get_group_size和get_rank方法只能在init后使用
init()
print("run distribute!", flush=True)
group_size = get_group_size()
diff --git a/tutorials/source_en/beginner/model.md b/tutorials/source_en/beginner/model.md
index cf90b388fad896c4cd5adefb08c97d9e333d361f..2902e4dda1341fb3b18c852ffcf69550b151a5b4 100644
--- a/tutorials/source_en/beginner/model.md
+++ b/tutorials/source_en/beginner/model.md
@@ -42,7 +42,7 @@ After completing construction, instantiate the `Network` object and look at its
```python
model = Network()
-model
+print(model)
```
```text
@@ -65,7 +65,7 @@ We construct an input data and call the model directly to obtain a 10-dimensiona
```python
X = ops.ones((1, 28, 28), mindspore.float32)
logits = model(X)
-logits
+print(logits)
```
```text
@@ -91,7 +91,7 @@ In this section, we decompose each layer of the neural network model constructed
```python
input_image = ops.ones((3, 28, 28), mindspore.float32)
-input_image.shape
+print(input_image.shape)
```
```text
@@ -105,7 +105,7 @@ Initialize the `nn.Flatten` layer and convert a 28x28 2D tensor into a contiguou
```python
flatten = nn.Flatten()
flat_image = flatten(input_image)
-flat_image.shape
+print(flat_image.shape)
```
```text
@@ -119,7 +119,7 @@ flat_image.shape
```python
layer1 = nn.Dense(in_channels=28*28, out_channels=20)
hidden1 = layer1(flat_image)
-hidden1.shape
+print(hidden1.shape)
```
```text
@@ -178,7 +178,7 @@ seq_modules = nn.SequentialCell(
)
logits = seq_modules(input_image)
-logits.shape
+print(logits.shape)
```
```text
diff --git a/tutorials/source_en/beginner/quick_start.md b/tutorials/source_en/beginner/quick_start.md
index c1eea20856a4f8bee528d5c85af429a53f3d2370..53a509ed4f79f551c95f7fdc75279adf91138426 100644
--- a/tutorials/source_en/beginner/quick_start.md
+++ b/tutorials/source_en/beginner/quick_start.md
@@ -46,7 +46,7 @@ test_dataset = test_data.dataset
Print the names of the data columns contained in the dataset for dataset pre-processing.
```python
-train_dataset.column_names
+print(train_dataset.column_names)
```
```text
@@ -201,7 +201,7 @@ def test(model, dataset, loss_fn):
pred = model(data)
total += len(data)
test_loss += loss_fn(pred, label).asnumpy()
- correct += (pred.argmax(1) == label).sum().asnumpy()
+ correct += (pred.argmax(1) == label).asnumpy().sum()
test_loss /= num_batches
correct /= total
print(f"Test: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
diff --git a/tutorials/source_en/beginner/save_load.md b/tutorials/source_en/beginner/save_load.md
index 9944acaa0357b992bc128b3dae8c9f8ebac68272..a7a226011a2adbaa8b8d75d9941ef9228a139828 100644
--- a/tutorials/source_en/beginner/save_load.md
+++ b/tutorials/source_en/beginner/save_load.md
@@ -29,7 +29,7 @@ To load the model weights, you need to create instances of the same model and th
model = models.lenet() # we do not specify pretrained=True, i.e. do not load default weights
param_dict = mindspore.load_checkpoint("lenet.ckpt")
param_not_load = mindspore.load_param_into_net(model, param_dict)
-param_not_load
+print(param_not_load)
```
```text
@@ -52,11 +52,14 @@ mindspore.export(model, inputs, file_name="lenet", file_format="MINDIR")
The existing MindIR model can be easily loaded through the `load` interface and passed into `nn.GraphCell` for inference.
+> `nn.GraphCell` only supports graph mode.
+
```python
+mindspore.set_context(mode=mindspore.GRAPH_MODE)
graph = mindspore.load("lenet.mindir")
model = nn.GraphCell(graph)
outputs = model(inputs)
-outputs.shape
+print(outputs.shape)
```
```text
diff --git a/tutorials/source_en/beginner/train.md b/tutorials/source_en/beginner/train.md
index 18fed90a061e1ee5a72954b9cdcd0bdf56cc9618..246066ddebb75714ff14fb6f1a777fd207a963c7 100644
--- a/tutorials/source_en/beginner/train.md
+++ b/tutorials/source_en/beginner/train.md
@@ -167,7 +167,7 @@ def test_loop(model, dataset, loss_fn):
pred = model(data)
total += len(data)
test_loss += loss_fn(pred, label).asnumpy()
- correct += (pred.argmax(1) == label).sum().asnumpy()
+ correct += (pred.argmax(1) == label).asnumpy().sum()
test_loss /= num_batches
correct /= total
print(f"Test: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
diff --git a/tutorials/source_en/beginner/transforms.md b/tutorials/source_en/beginner/transforms.md
index f09160f574296c63e0d8dd34cba3c68903b0b9b2..c11119bb73f0a71f9bb5b0e11abe2865bb8e537d 100644
--- a/tutorials/source_en/beginner/transforms.md
+++ b/tutorials/source_en/beginner/transforms.md
@@ -35,7 +35,7 @@ train_dataset = training_data.dataset
```python
image, label = next(train_dataset.create_tuple_iterator())
-image.shape
+print(image.shape)
```
```text
@@ -55,7 +55,7 @@ composed = transforms.Compose(
```python
train_dataset = train_dataset.map(composed, 'image')
image, label = next(train_dataset.create_tuple_iterator())
-image.shape
+print(image.shape)
```
```text
@@ -82,7 +82,7 @@ Here we first use numpy to generate a random image with pixel values in \[0, 255
```python
random_np = np.random.randint(0, 255, (48, 48), np.uint8)
random_image = Image.fromarray(random_np)
-random_np
+print(random_np)
```
```text
@@ -100,7 +100,7 @@ To present a more visual comparison of the data before and after Transform, we u
```python
rescale = vision.Rescale(1.0 / 255.0, 0)
rescaled_image = rescale(random_image)
-rescaled_image
+print(rescaled_image)
```
```text
@@ -134,7 +134,7 @@ Each channel of the image will be adjusted according to `mean` and `std`, and th
```python
normalize = vision.Normalize(mean=(0.1307,), std=(0.3081,))
normalized_image = normalize(rescaled_image)
-normalized_image
+print(normalized_image)
```
```text
@@ -163,7 +163,7 @@ Here we first process the `normalized_image` in the previous section to HWC form
hwc_image = np.expand_dims(normalized_image, -1)
hwc2cwh = vision.HWC2CHW()
chw_image = hwc2cwh(hwc_image)
-hwc_image.shape, chw_image.shape
+print(hwc_image.shape, chw_image.shape)
```
```text
@@ -196,7 +196,7 @@ Tokenize is a basic method to process text data. MindSpore provides many differe
```python
test_dataset = test_dataset.map(text.BasicTokenizer())
-next(test_dataset.create_tuple_iterator())
+print(next(test_dataset.create_tuple_iterator()))
```
```text
@@ -214,7 +214,7 @@ vocab = text.Vocab.from_dataset(test_dataset)
After obtaining the vocabulary, we can use the `vocab` method to view the vocabulary.
```python
-vocab.vocab()
+print(vocab.vocab())
```
```text
@@ -237,7 +237,7 @@ After generating the vocabulary, you can perform the vocabulary mapping transfor
```python
test_dataset = test_dataset.map(text.Lookup(vocab))
-next(test_dataset.create_tuple_iterator())
+print(next(test_dataset.create_tuple_iterator()))
```
```text
@@ -253,7 +253,7 @@ Lambda functions are anonymous functions that do not require a name and consist
```python
test_dataset = GeneratorDataset([1, 2, 3], 'data', shuffle=False)
test_dataset = test_dataset.map(lambda x: x * 2)
-list(test_dataset.create_tuple_iterator())
+print(list(test_dataset.create_tuple_iterator()))
```
```text
@@ -274,7 +274,7 @@ test_dataset = test_dataset.map(lambda x: func(x))
```
```python
-list(test_dataset.create_tuple_iterator())
+print(list(test_dataset.create_tuple_iterator()))
```
```text
diff --git a/tutorials/source_zh_cn/beginner/save_load.ipynb b/tutorials/source_zh_cn/beginner/save_load.ipynb
index 63eae5ef3b72459b313c4a57ab3d9e3f171e8e06..37928126c70d270e15f2f07d858812faf76603fa 100644
--- a/tutorials/source_zh_cn/beginner/save_load.ipynb
+++ b/tutorials/source_zh_cn/beginner/save_load.ipynb
@@ -166,7 +166,7 @@
}
],
"source": [
- "mindspore.set_context(mode=mindspore.PYNATIVE_MODE)\n",
+ "mindspore.set_context(mode=mindspore.GRAPH_MODE)\n",
"\n",
"graph = mindspore.load(\"lenet.mindir\")\n",
"model = nn.GraphCell(graph)\n",