diff --git a/docs/source_zh_cn/design/mindspore/ir.md b/docs/source_zh_cn/design/mindspore/ir.md
index 77bc45014d5301fa6f76a9505ce38a491c09b9b4..362544c7f113919404560d8b52b838d632eda6e5 100644
--- a/docs/source_zh_cn/design/mindspore/ir.md
+++ b/docs/source_zh_cn/design/mindspore/ir.md
@@ -1,6 +1,6 @@
# MindSpore IR(MindIR)
-`Linux` `框架开发` `中级` `高级` `贡献者`
+`Linux` `Windows` `框架开发` `中级` `高级` `贡献者`
diff --git a/tutorials/source_en/advanced_use/images/cifar10_c_transforms.png b/tutorials/source_en/advanced_use/images/cifar10_c_transforms.png
new file mode 100644
index 0000000000000000000000000000000000000000..10dc267dc650764566f6d20b7f090e20c12f8e11
Binary files /dev/null and b/tutorials/source_en/advanced_use/images/cifar10_c_transforms.png differ
diff --git a/tutorials/source_en/advanced_use/images/compose.png b/tutorials/source_en/advanced_use/images/compose.png
new file mode 100644
index 0000000000000000000000000000000000000000..97b8ca59f4438852526b56a8a7ce00ff63771b40
Binary files /dev/null and b/tutorials/source_en/advanced_use/images/compose.png differ
diff --git a/tutorials/source_en/advanced_use/images/data_enhancement_performance_scheme.png b/tutorials/source_en/advanced_use/images/data_enhancement_performance_scheme.png
new file mode 100644
index 0000000000000000000000000000000000000000..6417031a63dd2bade4902a83934c05aeee6be195
Binary files /dev/null and b/tutorials/source_en/advanced_use/images/data_enhancement_performance_scheme.png differ
diff --git a/tutorials/source_en/advanced_use/images/data_loading_performance_scheme.png b/tutorials/source_en/advanced_use/images/data_loading_performance_scheme.png
new file mode 100644
index 0000000000000000000000000000000000000000..44c84c1f14dee40cdd76926994ab670494abc006
Binary files /dev/null and b/tutorials/source_en/advanced_use/images/data_loading_performance_scheme.png differ
diff --git a/tutorials/source_en/advanced_use/images/operator_fusion.png b/tutorials/source_en/advanced_use/images/operator_fusion.png
new file mode 100644
index 0000000000000000000000000000000000000000..4aa6ee89a0970889abc84f1b74b95297f2ae2db4
Binary files /dev/null and b/tutorials/source_en/advanced_use/images/operator_fusion.png differ
diff --git a/tutorials/source_en/advanced_use/images/pipeline.png b/tutorials/source_en/advanced_use/images/pipeline.png
new file mode 100644
index 0000000000000000000000000000000000000000..bbb1a391f8378bc02f4d821d657f2c74c21ff24e
Binary files /dev/null and b/tutorials/source_en/advanced_use/images/pipeline.png differ
diff --git a/tutorials/source_en/advanced_use/images/shuffle_performance_scheme.png b/tutorials/source_en/advanced_use/images/shuffle_performance_scheme.png
new file mode 100644
index 0000000000000000000000000000000000000000..f4c72a99fbade41067f9e6dfe6383634d06433a8
Binary files /dev/null and b/tutorials/source_en/advanced_use/images/shuffle_performance_scheme.png differ
diff --git a/tutorials/source_en/advanced_use/optimize_the_performance_of_data_preparation.md b/tutorials/source_en/advanced_use/optimize_the_performance_of_data_preparation.md
new file mode 100644
index 0000000000000000000000000000000000000000..a37367e9b59d5c057c13d762a186b189749490d4
--- /dev/null
+++ b/tutorials/source_en/advanced_use/optimize_the_performance_of_data_preparation.md
@@ -0,0 +1,389 @@
+# Optimizing the Data Preparation Performance
+
+`Linux` `Ascend` `GPU` `CPU` `Data Preparation` `Beginner` `Intermediate` `Expert`
+
+
+
+- [Optimizing the Data Preparation Performance](#optimizing-the-data-preparation-performance)
+ - [Overview](#overview)
+ - [Overall Process](#overall-process)
+ - [Preparations](#preparations)
+ - [Importing Modules](#importing-modules)
+ - [Downloading the Required Dataset](#downloading-the-required-dataset)
+ - [Optimizing the Data Loading Performance](#optimizing-the-data-loading-performance)
+ - [Performance Optimization Solution](#performance-optimization-solution)
+ - [Code Example](#code-example)
+ - [Optimizing the Shuffle Performance](#optimizing-the-shuffle-performance)
+ - [Performance Optimization Solution](#performance-optimization-solution-1)
+ - [Code Example](#code-example-1)
+ - [Optimizing the Data Augmentation Performance](#optimizing-the-data-augmentation-performance)
+ - [Performance Optimization Solution](#performance-optimization-solution-2)
+ - [Code Example](#code-example-2)
+ - [Performance Optimization Solution Summary](#performance-optimization-solution-summary)
+ - [Multi-thread Optimization Solution](#multi-thread-optimization-solution)
+ - [Multi-process Optimization Solution](#multi-process-optimization-solution)
+ - [Compose Optimization Solution](#compose-optimization-solution)
+ - [Operator Fusion Optimization Solution](#operator-fusion-optimization-solution)
+
+
+
+
+
+
+
+## Overview
+
+Data is the most important factor of deep learning. Data quality determines the upper limit of deep learning result, whereas model quality enables the result to approach the upper limit.Therefore, high-quality data input is beneficial to the entire deep neural network. During the entire data processing and data augmentation process, data continuously flows through a "pipeline" to the training system, as shown in the following figure:
+
+
+
+MindSpore provides data processing and data augmentation functions for users. In the pipeline process, if each step can be properly used, the data performance will be greatly improved. This section describes how to optimize performance during data loading, data processing, and data augmentation based on the CIFAR-10 dataset.
+
+## Overall Process
+- Prepare data.
+- Optimize the data loading performance.
+- Optimize the shuffle performance.
+- Optimize the data augmentation performance.
+- Summarize the performance optimization solution.
+
+## Preparations
+
+### Importing Modules
+
+The `dataset` module provides APIs for loading and processing datasets.
+
+
+```python
+import mindspore.dataset as ds
+```
+
+The `numpy` module is used to generate ndarrays.
+
+
+```python
+import numpy as np
+```
+
+### Downloading the Required Dataset
+
+1. Create the `./dataset/Cifar10Data` directory in the current working directory. The dataset used for this practice is stored in this directory.
+2. Create the `./transform` directory in the current working directory. The dataset generated during the practice is stored in this directory.
+3. Download [the CIFAR-10 dataset in binary format](https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz) and decompress the dataset file to the `./dataset/Cifar10Data/cifar-10-batches-bin` directory. The dataset will be used during data loading.
+4. Download [the CIFAR-10 Python dataset in file-format](https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz) and decompress the dataset file to the `./dataset/Cifar10Data/cifar-10-batches-py` directory. The dataset will be used for data conversion.
+
+The directory structure is as follows:
+
+
+ dataset/Cifar10Data
+ ├── cifar-10-batches-bin
+ │ ├── batches.meta.txt
+ │ ├── data_batch_1.bin
+ │ ├── data_batch_2.bin
+ │ ├── data_batch_3.bin
+ │ ├── data_batch_4.bin
+ │ ├── data_batch_5.bin
+ │ ├── readme.html
+ │ └── test_batch.bin
+ └── cifar-10-batches-py
+ ├── batches.meta
+ ├── data_batch_1
+ ├── data_batch_2
+ ├── data_batch_3
+ ├── data_batch_4
+ ├── data_batch_5
+ ├── readme.html
+ └── test_batch
+
+In the preceding information:
+- The `cifar-10-batches-bin` directory is the directory for storing the CIFAR-10 dataset in binary format.
+- The `cifar-10-batches-py` directory is the directory for storing the CIFAR-10 dataset in Python file format.
+
+## Optimizing the Data Loading Performance
+
+MindSpore provides multiple data loading methods, including common dataset loading, user-defined dataset loading, and MindSpore data format loading. For details, see [Loading Datasets](https://www.mindspore.cn/tutorial/en/master/use/data_preparation/loading_the_datasets.html). The dataset loading performance varies depending on the underlying implementation method.
+
+| | Common Dataset | User-defined Dataset | MindRecord Dataset |
+| :----: | :----: | :----: | :----: |
+| Underlying implementation | C++ | Python | C++ |
+| Performance | High | Medium | High |
+
+### Performance Optimization Solution
+
+
+
+Suggestions on data loading performance optimization are as follows:
+- Built-in loading operators are preferred for supported dataset formats. For details, see [Built-in Loading Operators](https://www.mindspore.cn/api/en/master/api/python/mindspore/mindspore.dataset.html). If the performance cannot meet the requirements, use the multi-thread concurrency solution. For details, see [Multi-thread Optimization Solution](#multi-thread-optimization-solution).
+- For a dataset format that is not supported, convert the format to MindSpore data format and then use the `MindDataset` class to load the dataset. For details, see [Converting Datasets into MindSpore Data Format](https://www.mindspore.cn/tutorial/en/r0.7/use/data_preparation/converting_datasets.html). If the performance cannot meet the requirements, use the multi-thread concurrency solution, for details, see [Multi-thread Optimization Solution](#multi-thread-optimization-solution).
+- For dataset formats that are not supported, the user-defined `GeneratorDataset` class is preferred for implementing fast algorithm verification. If the performance cannot meet the requirements, the multi-process concurrency solution can be used. For details, see [Multi-process Optimization Solution](#multi-process-optimization-solution).
+
+### Code Example
+
+Based on the preceding suggestions of data loading performance optimization, the `Cifar10Dataset` class of built-in loading operators, the `MindDataset` class after data conversion, and the `GeneratorDataset` class are used to load data. The sample code is displayed as follows:
+
+1. Use the `Cifar10Dataset` class of built-in operators to load the CIFAR-10 dataset in binary format. The multi-thread optimization solution is used for data loading. Four threads are enabled to concurrently complete the task. Finally, a dictionary iterator is created for the data and a data record is read through the iterator.
+
+
+ ```python
+ cifar10_path = "./dataset/Cifar10Data/cifar-10-batches-bin/"
+
+ # create Cifar10Dataset for reading data
+ cifar10_dataset = ds.Cifar10Dataset(cifar10_path, num_parallel_workers=4)
+ # create a dictionary iterator and read a data record through the iterator
+ print(next(cifar10_dataset.create_dict_iterator()))
+ ```
+
+ The output is as follows:
+ ```
+ {'image': array([[[235, 235, 235],
+ [230, 230, 230],
+ [234, 234, 234],
+ ...,
+ [248, 248, 248],
+ [248, 248, 248],
+ [249, 249, 249]],
+ ...,
+ [120, 120, 119],
+ [146, 146, 146],
+ [177, 174, 190]]], dtype=uint8), 'label': array(9, dtype=uint32)}
+ ```
+
+2. Use the `Cifar10ToMR` class to convert the CIFAR-10 dataset into MindSpore data format. In this example, the CIFAR-10 dataset in Python file format is used. Then use the `MindDataset` class to load the dataset in MindSpore data format. The multi-thread optimization solution is used for data loading. Four threads are enabled to concurrently complete the task. Finally, a dictionary iterator is created for data and a data record is read through the iterator.
+
+
+ ```python
+ from mindspore.mindrecord import Cifar10ToMR
+
+ cifar10_path = './dataset/Cifar10Data/cifar-10-batches-py/'
+ cifar10_mindrecord_path = './transform/cifar10.record'
+
+ cifar10_transformer = Cifar10ToMR(cifar10_path, cifar10_mindrecord_path)
+ # executes transformation from Cifar10 to MindRecord
+ cifar10_transformer.transform(['label'])
+
+ # create MindDataset for reading data
+ cifar10_mind_dataset = ds.MindDataset(dataset_file=cifar10_mindrecord_path, num_parallel_workers=4)
+ # create a dictionary iterator and read a data record through the iterator
+ print(next(cifar10_mind_dataset.create_dict_iterator()))
+ ```
+
+ The output is as follows:
+ ```
+ {'data': array([255, 216, 255, ..., 63, 255, 217], dtype=uint8), 'id': array(30474, dtype=int64), 'label': array(2, dtype=int64)}
+ ```
+
+3. The `GeneratorDataset` class is used to load the user-defined dataset, and the multi-process optimization solution is used. Four processes are enabled to concurrently complete the task. Finally, a dictionary iterator is created for the data, and a data record is read through the iterator.
+
+
+ ```python
+ def generator_func(num):
+ for i in range(num):
+ yield (np.array([i]),)
+
+ # create GeneratorDataset for reading data
+ dataset = ds.GeneratorDataset(source=generator_func(5), column_names=["data"], num_parallel_workers=4)
+ # create a dictionary iterator and read a data record through the iterator
+ print(next(dataset.create_dict_iterator()))
+ ```
+
+ The output is as follows:
+ ```
+ {'data': array([0], dtype=int64)}
+ ```
+
+## Optimizing the Shuffle Performance
+
+The shuffle operation is used to shuffle ordered datasets or repeated datasets. MindSpore provides the `shuffle` function for users. A larger value of `buffer_size` indicates a higher shuffling degree, consuming more time and computing resources. This API allows users to shuffle the data at any time during the entire pipeline process. For details, see [Shuffle Processing](https://www.mindspore.cn/tutorial/en/master/use/data_preparation/data_processing_and_augmentation.html#shuffle). However, because the underlying implementation methods are different, the performance of this method is not as good as that of setting the `shuffle` parameter to directly shuffle data by referring to the [Built-in Loading Operators](https://www.mindspore.cn/api/en/master/api/python/mindspore/mindspore.dataset.html).
+
+### Performance Optimization Solution
+
+
+
+Suggestions on shuffle performance optimization are as follows:
+- Use the `shuffle` parameter of built-in loading operators to shuffle data.
+- If the `shuffle` function is used and the performance still cannot meet the requirements, increase the value of the `buffer_size` parameter to improve the performance.
+
+### Code Example
+
+Based on the preceding shuffle performance optimization suggestions, the `shuffle` parameter of the `Cifar10Dataset` class of built-in loading operators and the `Shuffle` function are used to shuffle data. The sample code is displayed as follows:
+
+1. Use the built-in operator in `Cifar10Dataset` class to load the CIFAR-10 dataset. In this example, the CIFAR-10 dataset in binary format is used, and the `shuffle` parameter is set to True to perform data shuffle. Finally, a dictionary iterator is created for the data and a data record is read through the iterator.
+
+
+ ```python
+ cifar10_path = "./dataset/Cifar10Data/cifar-10-batches-bin/"
+
+ # create Cifar10Dataset for reading data
+ cifar10_dataset = ds.Cifar10Dataset(cifar10_path, shuffle=True)
+ # create a dictionary iterator and read a data record through the iterator
+ print(next(cifar10_dataset.create_dict_iterator()))
+ ```
+
+ The output is as follows:
+ ```
+ {'image': array([[[254, 254, 254],
+ [255, 255, 254],
+ [255, 255, 254],
+ ...,
+ [232, 234, 244],
+ [226, 230, 242],
+ [228, 232, 243]],
+ ...,
+ [ 64, 61, 63],
+ [ 63, 58, 60],
+ [ 61, 56, 58]]], dtype=uint8), 'label': array(9, dtype=uint32)}
+ ```
+
+2. Use the `shuffle` function to shuffle data. Set `buffer_size` to 3 and use the `GeneratorDataset` class to generate data.
+
+
+ ```python
+ def generator_func():
+ for i in range(5):
+ yield (np.array([i, i+1, i+2, i+3, i+4]),)
+
+ ds1 = ds.GeneratorDataset(source=generator_func, column_names=["data"])
+ print("before shuffle:")
+ for data in ds1.create_dict_iterator():
+ print(data["data"])
+
+ ds2 = ds1.shuffle(buffer_size=3)
+ print("after shuffle:")
+ for data in ds2.create_dict_iterator():
+ print(data["data"])
+ ```
+ ```
+ The output is as follows:
+
+ before shuffle:
+ [0 1 2 3 4]
+ [1 2 3 4 5]
+ [2 3 4 5 6]
+ [3 4 5 6 7]
+ [4 5 6 7 8]
+ after shuffle:
+ [2 3 4 5 6]
+ [0 1 2 3 4]
+ [4 5 6 7 8]
+ [1 2 3 4 5]
+ [3 4 5 6 7]
+ ```
+
+## Optimizing the Data Augmentation Performance
+
+During image classification training, especially when the dataset is small, users can use data augmentation to preprocess images to enrich the dataset. MindSpore provides multiple data augmentation methods, including:
+- Use the built-in C operator (`c_transforms` module) to perform data augmentation.
+- Use the built-in Python operator (`py_transforms` module) to perform data augmentation.
+- Users can define Python functions as needed to perform data augmentation.
+
+For details, see [Data Augmentation](https://www.mindspore.cn/tutorial/en/master/use/data_preparation/data_processing_and_augmentation.html#id3). The performance varies according to the underlying implementation methods.
+
+| Module | Underlying API | Description |
+| :----: | :----: | :----: |
+| c_transforms | C++ (based on OpenCV) | High performance |
+| py_transforms | Python (based on PIL) | This module provides multiple image augmentation functions and the method for converting PIL images into NumPy arrays. |
+
+
+### Performance Optimization Solution
+
+
+
+
+Suggestions on data augmentation performance optimization are as follows:
+- The `c_transforms` module is preferentially used to perform data augmentation for its highest performance. If the performance cannot meet the requirements, refer to [Multi-thread Optimization Solution](#multi-thread-optimization-solution), [Compose Optimization Solution](#compose-optimization-solution), or [Operator Fusion Optimization Solution](#operator-fusion-optimization-solution).
+- If the `py_transforms` module is used to perform data augmentation and the performance still cannot meet the requirements, refer to [Multi-thread Optimization Solution](#multi-thread-optimization-solution), [Multi-process Optimization Solution](#multi-process-optimization-solution), [Compose Optimization Solution](#compose-optimization-solution), or [Operator Fusion Optimization Solution](#operator-fusion-optimization-solution).
+- The `c_transforms` module maintains buffer management in C++, and the `py_transforms` module maintains buffer management in Python. Because of the performance cost of switching between Python and C++, it is advised not to use different operator types together.
+- If the user-defined Python functions are used to perform data augmentation and the performance still cannot meet the requirements, use the [Multi-thread Optimization Solution](#multi-thread-optimization-solution) or [Multi-process Optimization Solution](#multi-process-optimization-solution). If the performance still cannot be improved, in this case, optimize the user-defined Python code.
+
+### Code Example
+
+Based on the preceding suggestions of data augmentation performance optimization, the `c_transforms` module and user-defined Python function are used to perform data augmentation. The code is displayed as follows:
+
+1. The `c_transforms` module is used to perform data augmentation. During data augmentation, the multi-thread optimization solution is used. Four threads are enabled to concurrently complete the task. The operator fusion optimization solution is used and the `RandomResizedCrop` fusion class is used to replace the `RandomResize` and `RandomCrop` classes.
+
+
+ ```python
+ import mindspore.dataset.transforms.c_transforms as c_transforms
+ import mindspore.dataset.vision.c_transforms as C
+ import matplotlib.pyplot as plt
+ cifar10_path = "./dataset/Cifar10Data/cifar-10-batches-bin/"
+
+ # create Cifar10Dataset for reading data
+ cifar10_dataset = ds.Cifar10Dataset(cifar10_path, num_parallel_workers=4)
+ transforms = C.RandomResizedCrop((800, 800))
+ # apply the transform to the dataset through dataset.map()
+ cifar10_dataset = cifar10_dataset.map(operations=transforms, input_columns="image", num_parallel_workers=4)
+
+ data = next(cifar10_dataset.create_dict_iterator())
+ plt.imshow(data["image"])
+ plt.show()
+ ```
+
+ The output is as follows:
+
+ 
+
+
+2. A user-defined Python function is used to perform data augmentation. During data augmentation, the multi-process optimization solution is used, and four processes are enabled to concurrently complete the task.
+
+
+ ```python
+ def generator_func():
+ for i in range(5):
+ yield (np.array([i, i+1, i+2, i+3, i+4]),)
+
+ ds3 = ds.GeneratorDataset(source=generator_func, column_names=["data"])
+ print("before map:")
+ for data in ds3.create_dict_iterator():
+ print(data["data"])
+
+ func = lambda x:x**2
+ ds4 = ds3.map(operations=func, input_columns="data", python_multiprocessing=True, num_parallel_workers=4)
+ print("after map:")
+ for data in ds4.create_dict_iterator():
+ print(data["data"])
+ ```
+
+ The output is as follows:
+ ```
+ before map:
+ [0 1 2 3 4]
+ [1 2 3 4 5]
+ [2 3 4 5 6]
+ [3 4 5 6 7]
+ [4 5 6 7 8]
+ after map:
+ [ 0 1 4 9 16]
+ [ 1 4 9 16 25]
+ [ 4 9 16 25 36]
+ [ 9 16 25 36 49]
+ [16 25 36 49 64]
+ ```
+
+## Performance Optimization Solution Summary
+
+### Multi-thread Optimization Solution
+
+During the data pipeline process, the number of threads for related operators can be set to improve the concurrency and performance. For example:
+- During data loading, the `num_parallel_workers` parameter in the built-in data loading class is used to set the number of threads.
+- During data augmentation, the `num_parallel_workers` parameter in the `map` function is used to set the number of threads.
+- During batch processing, the `num_parallel_workers` parameter in the `batch` function is used to set the number of threads.
+
+For details, see [Built-in Loading Operators](https://www.mindspore.cn/api/en/master/api/python/mindspore/mindspore.dataset.html).
+
+### Multi-process Optimization Solution
+
+During data processing, operators implemented by Python support the multi-process mode. For example:
+- By default, the `GeneratorDataset` class is in multi-process mode. The `num_parallel_workers` parameter indicates the number of enabled processes. The default value is 1. For details, see [Generator Dataset](https://www.mindspore.cn/api/en/master/api/python/mindspore/mindspore.dataset.html#mindspore.dataset.GeneratorDataset)
+- If the user-defined Python function or the `py_transforms` module is used to perform data augmentation and the `python_multiprocessing` parameter of the `map` function is set to True, the `num_parallel_workers` parameter indicates the number of processes and the default value of the `python_multiprocessing` parameter is False. In this case, the `num_parallel_workers` parameter indicates the number of threads. For details, see [Built-in Loading Operators](https://www.mindspore.cn/api/en/master/api/python/mindspore/mindspore.dataset.html).
+
+### Compose Optimization Solution
+
+Map operators can receive the Tensor operator list and apply all these operators based on a specific sequence. Compared with the Map operator used by each Tensor operator, such Fat Map operators can achieve better performance, as shown in the following figure:
+
+
+
+### Operator Fusion Optimization Solution
+
+Some fusion operators are provided to aggregate the functions of two or more operators into one operator. For details, see [Data Augmentation Operators](https://www.mindspore.cn/api/en/master/api/python/mindspore/mindspore.dataset.vision.html). Compared with the pipelines of their components, such fusion operators provide better performance. As shown in the figure:
+
+
diff --git a/tutorials/source_en/index.rst b/tutorials/source_en/index.rst
index 0c0e217dec12430a11983202537e4c2176226531..214277f3d0712631a32882415950ae7e14f47e13 100644
--- a/tutorials/source_en/index.rst
+++ b/tutorials/source_en/index.rst
@@ -31,6 +31,7 @@ MindSpore Tutorials
advanced_use/computer_vision_application
advanced_use/nlp_application
+ advanced_use/optimize_the_performance_of_data_preparation
advanced_use/synchronization_training_and_evaluation.md
.. toctree::
diff --git a/tutorials/source_zh_cn/advanced_use/bert_poetry.md b/tutorials/source_zh_cn/advanced_use/bert_poetry.md
index 238e968cfab0a324b75a7fbbe798f961db424243..ffe45bb508dbef8dc5ffd28658e964beb1e5e0f3 100644
--- a/tutorials/source_zh_cn/advanced_use/bert_poetry.md
+++ b/tutorials/source_zh_cn/advanced_use/bert_poetry.md
@@ -18,7 +18,7 @@
- [训练](#训练)
- [推理验证](#推理验证)
- [服务部署](#服务部署)
- - [参考资料](#参考资料)
+ - [参考文献](#参考文献)
diff --git a/tutorials/source_zh_cn/advanced_use/computer_vision_application.md b/tutorials/source_zh_cn/advanced_use/computer_vision_application.md
index 4bed55f52a566f9e0265a6a39c1beb3f80018529..10e9d5a0dc0d12582d37e18710b1a31bfacea4e8 100644
--- a/tutorials/source_zh_cn/advanced_use/computer_vision_application.md
+++ b/tutorials/source_zh_cn/advanced_use/computer_vision_application.md
@@ -9,11 +9,11 @@
- [图像分类](#图像分类)
- [任务描述及准备](#任务描述及准备)
- [下载CIFAR-10数据集](#下载cifar-10数据集)
- - [数据预加载和预处理](#数据预加载和预处理)
- - [定义卷积神经网络](#定义卷积神经网络)
- - [定义损失函数和优化器](#定义损失函数和优化器)
- - [调用`Model`高阶API进行训练和保存模型文件](#调用model高阶api进行训练和保存模型文件)
- - [加载保存的模型,并进行验证](#加载保存的模型并进行验证)
+ - [数据预加载和预处理](#数据预加载和预处理)
+ - [定义卷积神经网络](#定义卷积神经网络)
+ - [定义损失函数和优化器](#定义损失函数和优化器)
+ - [调用`Model`高阶API进行训练和保存模型文件](#调用model高阶api进行训练和保存模型文件)
+ - [加载保存的模型,并进行验证](#加载保存的模型并进行验证)
- [参考文献](#参考文献)
diff --git a/tutorials/source_zh_cn/advanced_use/differential_privacy.md b/tutorials/source_zh_cn/advanced_use/differential_privacy.md
index 6813d7f001952205689e75654ab29a043128747d..9f3fbf032d23009bd735a816f2a4470d51e9ee8a 100644
--- a/tutorials/source_zh_cn/advanced_use/differential_privacy.md
+++ b/tutorials/source_zh_cn/advanced_use/differential_privacy.md
@@ -12,7 +12,7 @@
- [预处理数据集](#预处理数据集)
- [建立模型](#建立模型)
- [引入差分隐私](#引入差分隐私)
- - [引用](#引用)
+ - [参考文献](#参考文献)
@@ -346,7 +346,7 @@ ds_train = generate_mnist_dataset(os.path.join(cfg.data_path, "train"),
============== Accuracy: 0.9698 ==============
```
-### 引用
+## 参考文献
[1] C. Dwork and J. Lei. Differential privacy and robust statistics. In STOC, pages 371–380. ACM, 2009.
diff --git a/tutorials/source_zh_cn/advanced_use/distributed_training_ascend.md b/tutorials/source_zh_cn/advanced_use/distributed_training_ascend.md
index 570a8ea4c42d2d9dc260f2d4afe0cbb491424f4a..ff7a1ed05a4298dec0f2c1d5d83642e1055406ed 100644
--- a/tutorials/source_zh_cn/advanced_use/distributed_training_ascend.md
+++ b/tutorials/source_zh_cn/advanced_use/distributed_training_ascend.md
@@ -33,9 +33,9 @@
### 下载数据集
-本样例采用`CIFAR-10`数据集,由10类32*32的彩色图片组成,每类包含6000张图片。其中训练集共50000张图片,测试集共10000张图片。
+本样例采用CIFAR-10数据集,由10类32*32的彩色图片组成,每类包含6000张图片。其中训练集共50000张图片,测试集共10000张图片。
-> `CIFAR-10`数据集下载链接:。
+> CIFAR-10数据集下载链接:。
将数据集下载并解压到本地路径下,解压后的文件夹为`cifar-10-batches-bin`。
@@ -320,6 +320,7 @@ cd ../
- `RANK_TABLE_FILE`:组网信息文件的路径。
- `DEVICE_ID`:当前卡在机器上的实际序号。
- `RANK_ID`: 当前卡的逻辑序号。
+
其余环境变量请参考安装教程中的配置项。
运行时间大约在5分钟内,主要时间是用于算子的编译,实际训练时间在20秒内。用户可以通过`ps -ef | grep pytest`来监控任务进程。
diff --git a/tutorials/source_zh_cn/advanced_use/distributed_training_gpu.md b/tutorials/source_zh_cn/advanced_use/distributed_training_gpu.md
index 13cb619328c7d2121b9d557e9b4294647922734f..918e0b503079e4592e59c72d45815898b6fc1b07 100644
--- a/tutorials/source_zh_cn/advanced_use/distributed_training_gpu.md
+++ b/tutorials/source_zh_cn/advanced_use/distributed_training_gpu.md
@@ -27,7 +27,7 @@
### 下载数据集
-本样例采用`CIFAR-10`数据集,数据集的下载以及加载方式和Ascend 910 AI处理器一致。
+本样例采`CIFAR-10数据集,数据集的下载以及加载方式和Ascend 910 AI处理器一致。
> 数据集的下载和加载方式参考:
>
@@ -109,7 +109,7 @@ echo "start training"
mpirun -n 8 pytest -s -v ./resnet50_distributed_training.py > train.log 2>&1 &
```
-脚本需要传入变量`DATA_PATH`,表示数据集的路径。此外,我们需要修改下`resnet50_distributed_training.py`文件,由于在GPU上,我们无需设置`DEVICE_ID`环境变量,因此,在脚本中不需要调用`int(os.getenv('DEVICE_ID'))`来获取卡的物理序号,同时`context`中也无需传入`device_id`。我们需要将`device_target`设置为`GPU`,并调用`init("nccl")`来使能NCCL。日志文件保存到device目录下,关于Loss部分结果保存在train.log中。将loss值grep出来后,示例如下:
+脚本需要传入变量`DATA_PATH`,表示数据集的路径。此外,我们需要修改下`resnet50_distributed_training.py`文件,由于在GPU上,我们无需设置`DEVICE_ID`环境变量,因此,在脚本中不需要调用`int(os.getenv('DEVICE_ID'))`来获取卡的物理序号,同时`context`中也无需传入`device_id`。我们需要将`device_target`设置为`GPU`,并调用`init("nccl")`来使能NCCL。日志文件保存到`device`目录下,关于Loss部分结果保存在`train.log`中。将loss值grep出来后,示例如下:
```
epoch: 1 step: 1, loss is 2.3025854
@@ -124,7 +124,7 @@ epoch: 1 step: 1, loss is 2.3025854
## 运行多机脚本
-若训练涉及多机,则需要额外在`mpirun`命令中设置多机配置。你可以直接在`mpirun`命令中用`-H`选项进行设置,比如`mpirun -n 16 -H DEVICE1_IP:8,DEVICE2_IP:8 python hello.py`,表示在ip为DEVICE1_IP和DEVICE2_IP的机器上分别起8个进程运行程序;或者也可以构造一个如下这样的hostfile文件,并将其路径传给`mpirun`的`--hostfile`的选项。hostfile文件每一行格式为`[hostname] slots=[slotnum]`,hostname可以是ip或者主机名。
+若训练涉及多机,则需要额外在`mpirun`命令中设置多机配置。你可以直接在`mpirun`命令中用`-H`选项进行设置,比如`mpirun -n 16 -H DEVICE1_IP:8,DEVICE2_IP:8 python hello.py`,表示在ip为DEVICE1_IP和DEVICE2_IP的机器上分别起8个进程运行程序;或者也可以构造一个如下这样的`hostfile`文件,并将其路径传给`mpirun`的`--hostfile`的选项。`hostfile`文件每一行格式为`[hostname] slots=[slotnum]`,`hostname`可以是ip或者主机名。
```bash
DEVICE1 slots=8
DEVICE2 slots=8
diff --git a/tutorials/source_zh_cn/advanced_use/fuzzer.md b/tutorials/source_zh_cn/advanced_use/fuzzer.md
index 516b9a7467383ff7e36ea4f712dcb30cf495bcd0..9bf85defe5dd50920dcc5e34cefed16126dd0f31 100644
--- a/tutorials/source_zh_cn/advanced_use/fuzzer.md
+++ b/tutorials/source_zh_cn/advanced_use/fuzzer.md
@@ -60,7 +60,7 @@ context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target)
### 运用Fuzzer
-1. 建立LeNet模型,加载MNIST数据集,操作同[模型安全]()
+1. 建立LeNet模型,加载MNIST数据集,操作同[模型安全]()。
```python
...
@@ -72,10 +72,10 @@ context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target)
ds = generate_mnist_dataset(data_list, batch_size, sparse=False)
train_images = []
for data in ds.create_tuple_iterator():
- images = data[0].astype(np.float32)
- train_images.append(images)
- train_images = np.concatenate(train_images, axis=0)
-
+ images = data[0].astype(np.float32)
+ train_images.append(images)
+ train_images = np.concatenate(train_images, axis=0)
+
# get test data
data_list = "./MNIST_unzip/test"
batch_size = 32
@@ -83,10 +83,10 @@ context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target)
test_images = []
test_labels = []
for data in ds.create_tuple_iterator():
- images = data[0].astype(np.float32)
- labels = data[1]
- test_images.append(images)
- test_labels.append(labels)
+ images = data[0].astype(np.float32)
+ labels = data[1]
+ test_images.append(images)
+ test_labels.append(labels)
test_images = np.concatenate(test_images, axis=0)
test_labels = np.concatenate(test_labels, axis=0)
```
@@ -95,9 +95,9 @@ context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target)
设置数据变异方法及参数。目前支持的数据变异方法包含三类:
- - 图像仿射变换方法:Translate、Scale、Shear、Rotate。
- - 基于图像像素值变化的方法: Contrast、Brightness、Blur、Noise。
- - 基于对抗攻击的白盒、黑盒对抗样本生成方法:FGSM、PGD、MDIIM。
+ - 图像仿射变换方法:`Translate`、`Scale`、`Shear`、`Rotate`。
+ - 基于图像像素值变化的方法: `Contrast`、`Brightness`、`Blur`、`Noise`。
+ - 基于对抗攻击的白盒、黑盒对抗样本生成方法:`FGSM`、`PGD`、`MDIIM`。
数据变异方法一定要包含基于图像像素值变化的方法。
@@ -172,8 +172,8 @@ context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target)
```python
if metrics:
- for key in metrics:
- LOGGER.info(TAG, key + ': %s', metrics[key])
+ for key in metrics:
+ LOGGER.info(TAG, key + ': %s', metrics[key])
```
Fuzz测试后结果如下:
@@ -190,6 +190,6 @@ context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target)

- Fuzz生成的变异图片:
+ Fuzz生成的变异图片:

\ No newline at end of file
diff --git a/tutorials/source_zh_cn/advanced_use/gradient_accumulation.md b/tutorials/source_zh_cn/advanced_use/gradient_accumulation.md
index 8a232a39b320b1edef473ad28a5a3bc54b2b610f..73924f96690fd72e95c10737d03122c09b756b00 100644
--- a/tutorials/source_zh_cn/advanced_use/gradient_accumulation.md
+++ b/tutorials/source_zh_cn/advanced_use/gradient_accumulation.md
@@ -6,7 +6,7 @@
- [梯度累积](#梯度累积)
- [概述](#概述)
- - [建立梯度累积模型](#建立梯度累积模型)
+ - [创建梯度累积模型](#创建梯度累积模型)
- [导入需要的库文件](#导入需要的库文件)
- [加载数据集](#加载数据集)
- [定义网络](#定义网络)
@@ -14,7 +14,6 @@
- [定义训练过程](#定义训练过程)
- [训练并保存模型](#训练并保存模型)
- [实验结果](#实验结果)
-
@@ -59,17 +58,17 @@ from model_zoo.official.cv.lenet.src.lenet import LeNet5
### 加载数据集
-利用MindSpore的dataset提供的`MnistDataset`接口加载MNIST数据集,此部分代码由model_zoo中lenet目录下的[dataset.py]()导入。
+利用MindSpore的`dataset`提供的`MnistDataset`接口加载MNIST数据集,此部分代码由`model_zoo`中`lenet`目录下的[dataset.py]()导入。
### 定义网络
-这里以LeNet网络为例进行介绍,当然也可以使用其它的网络,如ResNet-50、BERT等, 此部分代码由model_zoo中lenet目录下的[lenet.py]()导入。
+这里以LeNet网络为例进行介绍,当然也可以使用其它的网络,如ResNet-50、BERT等, 此部分代码由`model_zoo`中`lenet`目录下的[lenet.py]()导入。
### 定义训练模型
将训练流程拆分为正向反向训练、参数更新和累积梯度清理三个部分:
-- `TrainForwardBackward`计算loss和梯度,利用grad_sum实现梯度累加。
+- `TrainForwardBackward`计算loss和梯度,利用`grad_sum`实现梯度累加。
- `TrainOptim`实现参数更新。
-- `TrainClear`实现对梯度累加变量grad_sum清零。
+- `TrainClear`实现对梯度累加变量`grad_sum`清零。
```python
_sum_op = C.MultitypeFuncGraph("grad_sum_op")
@@ -135,8 +134,7 @@ class TrainClear(Cell):
```
### 定义训练过程
-每个Mini-batch通过正反向训练计算loss和梯度,通过mini_steps控制每次更新参数前的累加次数。达到累加次数后进行参数更新和
-累加梯度变量清零。
+每个Mini-batch通过正反向训练计算loss和梯度,通过`mini_steps`控制每次更新参数前的累加次数。达到累加次数后进行参数更新和累加梯度变量清零。
```python
class GradientAccumulation:
@@ -253,7 +251,7 @@ if __name__ == "__main__":
**验证模型**
-通过model_zoo中lenet目录下的[eval.py](),使用保存的CheckPoint文件,加载验证数据集,进行验证。
+通过`model_zoo`中`lenet`目录下的[eval.py](),使用保存的CheckPoint文件,加载验证数据集,进行验证。
```shell
$ python eval.py --data_path=./MNIST_Data --ckpt_path=./gradient_accumulation.ckpt
diff --git a/tutorials/source_zh_cn/advanced_use/mixed_precision.md b/tutorials/source_zh_cn/advanced_use/mixed_precision.md
index 6d31c1d743c6f6038bf7407aa0428078c97ec338..82d5a35d2b1f5983b186fa2bd1997b4e5deebb0d 100644
--- a/tutorials/source_zh_cn/advanced_use/mixed_precision.md
+++ b/tutorials/source_zh_cn/advanced_use/mixed_precision.md
@@ -102,7 +102,7 @@ MindSpore还支持手动混合精度。假定在网络中只有一个Dense Layer
2. 配置混合精度: 通过`net.to_float(mstype.float16)`,把该Cell及其子Cell中所有的算子都配置成FP16;然后,将模型中的dense算子手动配置成FP32;
-3. 使用TrainOneStepCell封装网络模型和优化器。
+3. 使用`TrainOneStepCell`封装网络模型和优化器。
代码样例如下: