From 5087bc76f5147b56724d9a4d0cba1d2013e69074 Mon Sep 17 00:00:00 2001
From: huanxiaoling <3174348550@qq.com>
Date: Thu, 10 Nov 2022 14:07:39 +0800
Subject: [PATCH] update the en files in tutorials
---
tutorials/experts/source_en/index.rst | 1 +
.../source_en/parallel/distributed_case.rst | 1 +
.../source_en/parallel/fault_recover.md | 2 +-
.../source_en/parallel/introduction.md | 2 +-
.../experts/source_en/parallel/pangu_alpha.md | 330 ++++++++++++++++++
.../parallel/parallel_training_quickstart.md | 281 +++++++++++++++
.../source_zh_cn/parallel/pangu_alpha.md | 4 +-
.../parallel/parallel_training_quickstart.md | 8 +-
8 files changed, 620 insertions(+), 9 deletions(-)
create mode 100644 tutorials/experts/source_en/parallel/pangu_alpha.md
create mode 100644 tutorials/experts/source_en/parallel/parallel_training_quickstart.md
diff --git a/tutorials/experts/source_en/index.rst b/tutorials/experts/source_en/index.rst
index 4be570fbe5..95f0143940 100644
--- a/tutorials/experts/source_en/index.rst
+++ b/tutorials/experts/source_en/index.rst
@@ -73,6 +73,7 @@ For Experts
:titlesonly:
parallel/introduction
+ parallel/parallel_training_quickstart
parallel/communicate_ops
parallel/distributed_case
parallel/distributed_inference
diff --git a/tutorials/experts/source_en/parallel/distributed_case.rst b/tutorials/experts/source_en/parallel/distributed_case.rst
index bab008d539..e728e6210c 100644
--- a/tutorials/experts/source_en/parallel/distributed_case.rst
+++ b/tutorials/experts/source_en/parallel/distributed_case.rst
@@ -7,3 +7,4 @@ Distributed Case
train_ascend
train_gpu
transformer
+ pangu_alpha
\ No newline at end of file
diff --git a/tutorials/experts/source_en/parallel/fault_recover.md b/tutorials/experts/source_en/parallel/fault_recover.md
index 435c4262b2..65908c3bba 100644
--- a/tutorials/experts/source_en/parallel/fault_recover.md
+++ b/tutorials/experts/source_en/parallel/fault_recover.md
@@ -85,7 +85,7 @@ rank_list = ms.restore_group_info_list("./ckpt_dir0/group_info.pb")
print(rank_list) // [0, 4]
```
-Distributed fault recovery requires prior access to the slicing scores, thus, it is necessary to first call [model.build](https://www.mindspore.cn/docs/zh-CN/master/api_python/train/mindspore.train.Model.html#mindspore.train.model.build) to compile and then perform the training.
+Distributed fault recovery requires prior access to the slicing scores, thus, it is necessary to first call [model.build](https://www.mindspore.cn/docs/zh-CN/master/api_python/train/mindspore.train.Model.html#mindspore.train.Model.build) to compile and then perform the training.
```python
import os
diff --git a/tutorials/experts/source_en/parallel/introduction.md b/tutorials/experts/source_en/parallel/introduction.md
index e1741c9f65..485050a729 100644
--- a/tutorials/experts/source_en/parallel/introduction.md
+++ b/tutorials/experts/source_en/parallel/introduction.md
@@ -195,7 +195,7 @@ Therefore, the inserted rescheduling operators may be `AllGather`, `Split`, `Con
Pipeline parallel is also possible in automatic and semi-automatic modes by configuring the `pipeline_stage` property on the `Cell`. The corresponding tutorial on pipeline parallelism can be found in [Applying Pipeline Parallel](https://www.mindspore.cn/tutorials/experts/en/master/parallel/pipeline_parallel.html).
-### Automatic Parallelism
+### Fully Automatic Parallelism
Automatic parallel mode, a distributed parallel mode that combines data parallel, model parallel and hybrid parallel in, can automatically build cost models, find parallel strategies with shorter training time, and select the appropriate parallel mode for users. MindSpore provides the following three different strategy search algorithms:
diff --git a/tutorials/experts/source_en/parallel/pangu_alpha.md b/tutorials/experts/source_en/parallel/pangu_alpha.md
new file mode 100644
index 0000000000..106d2939c2
--- /dev/null
+++ b/tutorials/experts/source_en/parallel/pangu_alpha.md
@@ -0,0 +1,330 @@
+# PengCheng·PanGu Model Network Multi-dimension Hydrid Parallel Analysis
+
+
+
+## Overview
+
+In the PengCheng·PanGu model [1] published by MindSpore, we see that distributed training of very large Transformer networks can be achieved with the help of multi-dimensional automatic hybrid parallelism. This article will explain the sharding method of each component in the model in detail, starting from the network script.
+
+> For the complete code, refer to [pangu_alpha](https://gitee.com/mindspore/models/tree/master/official/nlp/pangu_alpha)
+
+In the training entry script train.py, the semi-automatic parallel mode `SEMI_AUTO_PARALLEL` is enabled by the `set_auto_parallel_context` interface, indicating that users can automatically complete the sharding with the help of the framework by configuring the sharding strategy for the operator. According to the features of operation volume and calculation methods in different network layers, choosing the appropriate sharding strategy is the focus of this paper. In addition, you can configure the optimizer parallelism and pipeline parallelism through the `enable_parallel_optimizer` and `pipeline_stages` parameters.
+
+## Embedding Layer
+
+In language model training, the input data are sentences composed of words, and we usually use the embedding algorithm to implement word vectorization, which maps the words and their location information into word vectors of size dimension `config.hidden_size`. The Embedding layer in the PanGu model consists of two parts, location encoding and word embedding, and implements basic data parallelism and model parallelism logic through `mindspore.nn.transformer.VocabEmbedding`.
+
+The following code shows that the `Gather` operator takes two inputs and finds the corresponding vectors in the lookup table `embedding_table` according to the index `input_ids`. The lookup table is a parameter to be learned during training and statically occupies memory resources on the card. We can decide to use a data parallel strategy for the `Gather` operator to slice the index batch dimension or a model parallel strategy to row slice the lookup table depending on the size of the lookup table. When the word list range `config.vocab_size` is large, it is recommended to choose a model parallel strategy for `word_embedding`, and the framework will automatically introduce computation and communication operators to handle out-of-bounds lookup cases.
+
+- Data parallel strategy `gather.shard(((1, 1), (parallel_config.data_parallel, 1)))`
+
+- Model parallel strategy `gather.shard(((parallel_config.model_parallel, 1), (1, 1)))`
+
+> The scripts and articles use config.data_parallel and config.model_parallel to refer to the data parallel slice dimension size and the model parallel slice dimension size.
+
+```python
+import mindspore as ms
+from mindspore.common.initializer import initializer
+import mindspore.ops as ops
+from mindspore.nn import Cell
+from mindspore.nn.transformer import EmbeddingOpParallelConfig
+default_embedding_parallel_config = EmbeddingOpParallelConfig()
+class VocabEmbedding(Cell):
+ def __init__(self, vocab_size, hidden_size, parallel_config=default_embedding_parallel_config,
+ param_init='normal'):
+ super(VocabEmbedding, self).__init__()
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.embedding_table = ms.Parameter(initializer(param_init, [self.vocab_size, self.hidden_size]),
+ name='embedding_table', parallel_optimizer=False)
+ if parallel_config.vocab_emb_dp:
+ self.gather = ops.GatherV2().shard(((1, 1), (parallel_config.data_parallel, 1)))
+ else:
+ self.gather = ops.GatherV2().shard(((parallel_config.model_parallel, 1), (1, 1)))
+ def construct(self, input_ids):
+ output = self.gather(self.embedding_table, input_ids, 0)
+ return output, self.embedding_table
+```
+
+Based on `mindspore.nn.transformer.VocabEmbedding`, we can implement the summation of word embedding vectors and location embedding vectors. We define the `Add` and `Dropout` operators and set the strategy corresponding to these two operators to be data parallelism.
+
+```python
+from mindspore.common.initializer import initializer
+import mindspore.ops as ops
+from mindspore import nn
+from mindspore.nn.transformer import VocabEmbedding
+class EmbeddingLayer(nn.Cell):
+ """Embedding layer of the PanGUAlpha Model"""
+ def __init__(self, config):
+ super(EmbeddingLayer, self).__init__()
+ self.word_embedding = VocabEmbedding(vocab_size=config.vocab_size,
+ hidden_size=config.hidden_size,
+ param_init=initializer("normal", [config.vocab_size, config.hidden_size],
+ dtype=config.param_init_type),
+ parallel_config=config.parallel_config.embedding_dp_mp_config)
+ self.position_embedding = VocabEmbedding(vocab_size=config.seq_length,
+ hidden_size=config.hidden_size,
+ param_init=initializer("normal",
+ [config.seq_length, config.hidden_size],
+ dtype=config.param_init_type),
+ parallel_config=config.parallel_config.embedding_dp_mp_config)
+ self.add = ops.Add().shard(
+ ((config.parallel_config.data_parallel, 1, 1), (config.parallel_config.data_parallel, 1, 1)))
+ self.dropout = nn.Dropout(1 - config.dropout_rate)
+ self.dropout.dropout.shard(((config.parallel_config.data_parallel, 1, 1),))
+ self.is_first_iteration = True
+ self.use_past = config.use_past
+ self.batch_size = config.batch_size
+
+ def construct(self, input_ids, input_position, init_reset, batch_valid_length):
+ word_embedding, word_table = self.word_embedding(input_ids)
+ if self.use_past and not self.is_first_iteration:
+ _, seq_length = ops.shape(input_ids)
+ input_position = batch_valid_length.view(self.batch_size, seq_length)
+ position_embedding, _ = self.position_embedding(input_position)
+ embed = self.add(word_embedding, position_embedding)
+ embed = self.dropout(embed)
+ return embed, word_table
+```
+
+## Decoder Layer
+
+The key difficulty in training large-scale Transformer networks is how to solve the computational and memory bottlenecks caused by the increasing number of layers, and it is especially important to choose a reasonable slicing. The main network of the PengCheng-PanGu model consists of multiple Decoders with the same structure but do not share weights, and the Decoder is composed of two parts, Self-Attention and FeedForward. The principle of slicing is to minimize the communication, and their slicing can be referred to the following figure [1]:
+
+
+
+### Self-Attention
+
+Self-Attention can be implemented directly via `mindspore.nn.transformer.MultiHeadAttention`. In the process of computing Attention, the input vector needs to be projected to the Query, Key, and Value vectors, and then the output of attention needs to be passed through the Dense layer again after the calculation of attention is completed. The following describes the strategy configuration of these three sections respectively.
+
+- Three Dense Matrix Multiplication
+
+ Here project the input tensor with shape `[batch*sequence_length, hidden_size]` into three vectors as the Query, Key, and Value vectors for the Attention calculation.
+
+ Hybrid parallel slicing of the input batch dimension and the output_channel dimension of the weight:
+
+ `matmul.shard(((parallel_config.data_parallel, 1), (parallel_config.model_parallel, 1)))`.
+
+ Output matrix rows and sliced columns, plus the sliced bias term.
+
+ `bias_add.shard(((parallel_config.data_parallel, parallel_config.model_parallel), (parallel_config.model_parallel,)))`.
+
+ ```python
+ self.dense1 = nn.Dense(hidden_size,
+ hidden_size).to_float(compute_dtype)
+ self.dense1.matmul.shard(((parallel_config.data_parallel, 1), (parallel_config.model_parallel, 1)))
+ self.dense1.bias_add.shard(((parallel_config.data_parallel, parallel_config.model_parallel), (parallel_config.model_parallel,)))
+ ```
+
+- `Softmax` and `BatchMatMul`
+
+ The matrix multiplication of Query and Key vectors is implemented by `BatchMatMul` in the process of computing Attention. Here the input shape of `softmax` is `[batch, sequence_length, num_heads, size_per_head]`. Because each `head` is independent from each other in computing the attention score, the `softmax` operator can be sliced in the `batch` dimension and the `heads` dimension.
+
+ ```python
+ self.softmax = nn.Softmax()
+ self.softmax.softmax.shard(((parallel_config.data_parallel, parallel_config.model_parallel, 1),))
+ self.batch_matmul = ops.BatchMatMul().shard(
+ ((parallel_config.data_parallel, parallel_config.model_parallel, 1, 1),
+ (parallel_config.data_parallel, parallel_config.model_parallel, 1, 1)))
+ ```
+
+- Projection Layer
+
+ Projection projects the output of attention once. The relevant dimension in the `MatMul` operator is sliced.
+
+ ```python
+ self.projection = nn.Dense(hidden_size,
+ hidden_size).to_float(compute_dtype)
+ self.projection.matmul.shard(((parallel_config.data_parallel, 1), (1, parallel_config.model_parallel)))
+ ```
+
+### FeedForward
+
+FeedForward can be implemented by calling `mindspore.nn.transformer.FeedForward` directly. The FeedForward network layer consists of two matrix multiplications. The first matrix multiplication slices in the same way as attention, outputting matrix rows and sliced columns, i.e., in the `batch` dimension and the `output dimension`. In order to avoid introducing redistribution communication between operators, the second matrix multiplication slices the input_channel dimension of the weights, i.e. `matmul.shard(((parallel_config.data_parallel, parallel_config.model_parallel), ( parallel_config.model_parallel, 1)))`. The framework automatically inserts the `AllReduce` operator when the relevant dimension is sliced, and accumulates the slicing results in the model parallel dimension. The output matrix is sliced in the `batch` dimension only, plus the bias term `add.shard(((parallel_config.data_parallel, 1), (1,)))`.
+
+```python
+from mindspore.common.initializer import initializer
+import mindspore as ms
+import mindspore.ops as ops
+from mindspore import nn
+from mindspore.nn import get_activation
+from mindspore.nn.transformer import OpParallelConfig
+
+default_dpmp_config = OpParallelConfig()
+class Linear(nn.Cell):
+ """
+ The dense connected layer. Once the parallel mode is enabled, the input shape should be
+ a 3-D tensor.
+ """
+ def __init__(self,
+ in_channels,
+ out_channels,
+ weight_init='normal',
+ bias_init='zeros',
+ has_bias=True,
+ activation=None,
+ transpose_b=True,
+ expert_num=1,
+ param_init_type=ms.float32,
+ compute_dtype=ms.float16):
+ super(Linear, self).__init__()
+ if transpose_b:
+ weight_shape = [out_channels, in_channels]
+ else:
+ weight_shape = [in_channels, out_channels]
+ self.expert_num = expert_num
+ if self.expert_num > 1:
+ self.expert_flag = True
+ self.weight = ms.Parameter(initializer(weight_init, [self.expert_num] + weight_shape, param_init_type),
+ name="weight")
+ self.matmul = ops.BatchMatMul(transpose_b=transpose_b)
+ else:
+ self.expert_flag = False
+ self.weight = ms.Parameter(initializer(weight_init, weight_shape, param_init_type), name="weight")
+ self.matmul = ops.MatMul(transpose_b=transpose_b)
+ self.bias = None
+ self.has_bias = has_bias
+ if self.has_bias:
+ if isinstance(bias_init, ms.Tensor):
+ if bias_init.ndim != 1 or bias_init.shape[0] != out_channels:
+ raise ValueError("Bias init shape error.")
+ self.bias = ms.Parameter(initializer(bias_init, [out_channels], param_init_type), name="bias")
+ self.bias_add = ops.Add()
+ self.act_name = activation
+ self.activation = get_activation(activation) if isinstance(activation, str) else activation
+ self.activation_flag = self.activation is not None
+ self.dtype = compute_dtype
+ self.cast = ops.Cast()
+
+ def construct(self, x):
+ out_shape = ops.Shape()(x)[:-1] + (self.out_channels,)
+ x = ops.Reshape()(x, (-1, self.in_channels))
+ if self.expert_flag is True:
+ x = ops.Reshape()(x, (self.expert_num, -1, self.in_channels))
+ weight = self.cast(self.weight, self.dtype)
+ x = self.matmul(x, weight)
+ if self.has_bias:
+ x = self.bias_add(x, self.cast(self.bias, self.dtype))
+ output = ops.Reshape()(x, out_shape)
+ if self.activation_flag:
+ output = self.activation(output)
+ return output
+
+ def shard(self, strategy_matmul, strategy_bias=None, strategy_activation=None):
+ """
+ Set the shard for the linear. the strategy size should be equal to the inputs.
+ """
+ self.matmul.shard(strategy_matmul)
+ if self.has_bias:
+ self.bias_add.shard(strategy_bias)
+ if self.activation_flag:
+ getattr(self.activation, self.act_name).shard(strategy_activation)
+ return self
+
+class FeedForward(nn.Cell):
+ """
+ The multilayer perceptron with two linear layers with dropout applied at final output. The first linear
+ will project the input dimension from hidden_size to ffn_hidden_size, the second linear will project the
+ dimension from ffn_hidden_size to hidden_size. The first linear is sharded on the relative dimension,
+ the second linear is sharded on the output dimension.
+ """
+ def __init__(self, hidden_size,
+ ffn_hidden_size,
+ dropout_rate,
+ hidden_act='gelu',
+ expert_num=1,
+ param_init_type=ms.float32,
+ parallel_config=default_dpmp_config):
+ super(FeedForward, self).__init__()
+ dp = parallel_config.data_parallel
+ mp = parallel_config.model_parallel
+ input_size = hidden_size
+ output_size = ffn_hidden_size
+ # Here, 'ep' stands for expert parallel number, which is equal to data parallel number.
+ ep = dp
+ # Project to ffn_hidden_size
+ self.mapping = Linear(in_channels=input_size,
+ out_channels=output_size,
+ activation=hidden_act,
+ transpose_b=False,
+ expert_num=expert_num,
+ param_init_type=param_init_type)
+ self.mapping.shard(strategy_matmul=((dp, 1), (1, mp)),
+ strategy_bias=((dp, mp), (mp,)),
+ strategy_activation=((dp, mp),))
+ # Project back to hidden_size
+ self.projection = Linear(in_channels=output_size,
+ out_channels=input_size,
+ transpose_b=False,
+ expert_num=expert_num,
+ param_init_type=param_init_type)
+ self.projection.shard(strategy_matmul=((dp, mp), (mp, 1)),
+ strategy_bias=((dp, 1), (1,)))
+ self.projection.bias.parallel_optimizer = False
+ self.dropout = nn.Dropout(1 - dropout_rate)
+ self.dropout.dropout.shard(((dp, 1),))
+ self.cast = ops.Cast()
+
+ def construct(self, x):
+ x = self.cast(x, ms.float16)
+ # returned shape is [bs, seq_length, ffn_hidden_size] or [bs * seq_length, ffn_hidden_size]
+ hidden = self.mapping(x)
+ output = self.projection(hidden)
+ # returned shape is [bs, seq_length, ffn_hidden_size] or [bs * seq_length, ffn_hidden_size]
+ output = self.dropout(output)
+ return output
+```
+
+## Residual Layer
+
+A detail of the Transformer structure that should be noted is that each sublayer is connected with residuals and follows the layernorm operation. Although the layernorm also contains weights, it is only a one-dimensional vector of size `hidden_size`, which accounts for a very small proportion of the network weights, so data parallel slicing is directly used here.
+
+```python
+from mindspore import nn
+
+layernorm1 = nn.LayerNorm((hidden_size,))
+layernorm1.shard(((parallel_config.data_parallel, 1),))
+```
+
+## Prediction Layer
+
+A fully-connected layer is needed to map the output features from `config.hidden_size` back to the `config.vocab_size` dimension to get logits before calculating the loss. Here the fully-connected layer and the `word_embedding` operation share weights, so the slicing of the fully connected layer weights is required to be consistent with that of the embedding layer.
+
+```python
+import mindspore.ops as ops
+from mindspore import nn
+class PanguAlpha_Head(nn.Cell):
+ """
+ Head for PanguAlpha to get the logits of each token in the vocab
+ Args:
+ config(PanguAlphaConfig): the config of network
+ Inputs:
+ state: the output of the backbone
+ embedding_table: the embedding table of the vocabulary
+ Returns:
+ logits: Tensor, the logits of the corresponding inputs
+ """
+
+ def __init__(self, config):
+ super(PanguAlpha_Head, self).__init__()
+ if config.word_emb_dp:
+ self.matmul = ops.MatMul(transpose_b=True).shard(((parallel_config.dp, 1), (1, 1)))
+ else:
+ self.matmul = ops.MatMul(transpose_b=True).shard(((parallel_config.dp, 1), (parallel_config.model_parallel, 1)))
+ self.hidden_size = config.hidden_size
+ self.log_softmax = ops.LogSoftmax(axis=-1)
+ self.dtype = config.compute_dtype
+ self.cast = ops.Cast()
+
+ def construct(self, state, embedding_table):
+ state = ops.Reshape()(state, (-1, self.hidden_size))
+ # output logits over vocabulary [bs*seq_length, vocab_size]
+ logits = self.matmul(state, self.cast(embedding_table, self.dtype))
+ return logits
+```
+
+In this article, we learn how to quickly implement distributed training of Transformer-like networks on the basis of a stand-alone script by configuring an operator sharding strategy. When specific to the network structure, embedding layer, decoder layer, residual layer and linear layer all have their own slicing features, and users can improve the distributed training and tuning efficiency by mastering the operator strategy configuration method.
+
+## References
+
+[1] Zeng W, Ren X, Su T, et al. PanGu-$\\alpha$: Large-scale Autoregressive Pretrained Chinese Language Models with Auto-parallel Computation. 2021.
diff --git a/tutorials/experts/source_en/parallel/parallel_training_quickstart.md b/tutorials/experts/source_en/parallel/parallel_training_quickstart.md
new file mode 100644
index 0000000000..8035505212
--- /dev/null
+++ b/tutorials/experts/source_en/parallel/parallel_training_quickstart.md
@@ -0,0 +1,281 @@
+# Quick Start Distributed Parallel Training
+
+
+
+## Overview
+
+This tutorial shows how to perform MindSpore distributed parallel training in a single 8-card **GPU** environment via **OpenMPI** with a simple example of a single hidden layer fully connected neural network.
+
+A tutorial on distributed parallel training of ResNet networks on a GPU platform is available at [Sample Distributed Parallel Training Basics (GPU)](https://www.mindspore.cn/tutorials/experts/en/master/parallel/train_gpu.html). In contrast: (1) the example uses a more complex ResNet network; (2) in addition to pull-up training by using OpenMPI, the example also introduces pull-up training by using a scripted approach.
+
+> You can download the complete sample code here:
+>
+>
+
+The directory structure is as follows:
+
+```text
+└─sample_code
+ ├─distributed_training_quickstart
+ ├── net.py
+ ├── run_with_mpi.sh
+ ...
+```
+
+where `net.py` is the network definition script and `run_with_mpi.sh` is the execution script.
+
+> In addition, tutorials for distributed parallel training on Ascend 910 platform are available in [Distributed Parallel Training Example (Ascend)](https://www.mindspore.cn/tutorials/experts/en/master/parallel/train_ascend.html) and [Distributed Parallel Training of Transformer Models](https://www.mindspore.cn/tutorials/experts/en/master/parallel/transformer.html).
+
+## Preparation
+
+### Datasets
+
+This sample example constructs a random set of input data and labels, with the following code:
+
+```python
+import numpy as np
+
+def get_dataset(batch_size, in_dim, out_dim, step_per_epoch):
+ np.random.seed(1)
+ input_data = np.random.rand(batch_size, in_dim).astype(np.float32)
+ label_data = np.random.rand(batch_size, out_dim).astype(np.float32)
+ def generate():
+ for _ in range(step_per_epoch):
+ yield (input_data, label_data)
+ return generate
+```
+
+where `step_per_epoch` is the number of steps performed per epoch for training, `batch_size` is the batch size, `in_dim` is the input vector length, and `out_dim` is the output vector length.
+
+### Network Structure
+
+The network code used in this sample is as follows:
+
+```python
+class Net(Cell):
+ """define net"""
+ def __init__(self, in_dim, hidden_dim, out_dim):
+ super().__init__()
+ self.in_dim = in_dim
+ self.hidden_dim = hidden_dim
+ self.out_dim = out_dim
+ self.weight = Parameter(initializer("normal", [self.in_dim, self.hidden_dim]), "w")
+ self.weight2 = Parameter(initializer("normal", [self.hidden_dim, self.out_dim]), "w2")
+ self.matmul = ops.MatMul()
+ self.relu = ops.ReLU()
+ self.matmul2 = ops.MatMul()
+
+ def construct(self, x):
+ out = self.matmul(x, self.weight)
+ out = self.relu(out)
+ out = self.matmul2(out, self.weight2)
+ return out
+```
+
+where `in_dim` is the network input dimension, `out_dim` is the output dimension, which needs to match the data dimension, and `hidden_dim` is the number of nodes in the hidden layer of the network.
+
+## Semi-automatic Parallel Distributed Training via OpenMPI
+
+### OpenMPI Environment Configuration
+
+[OpenMPI](https://www.open-mpi.org/) is a high-performance messaging library, a multi-process communication library adopted by MindSpore. For the related environment configuration, see [Running the Script through OpenMPI](https://www.mindspore.cn/tutorials/experts/en/master/parallel/train_ascend.html#running-the-script-through-openmpi).
+
+> In addition, MindSpore also supports distributed training without relying on OpenMPI. For the details, see [Training without Relying on OpenMPI](https://www.mindspore.cn/tutorials/experts/en/master/parallel/train_gpu.html#training-without-relying-on-openmpi).
+
+### Semi-automatic Parallelism
+
+Currently MindSpore supports four parallel modes, and see [Distributed Parallel Training Modes](https://www.mindspore.cn/tutorials/experts/en/master/parallel/introduction.html#distributed-parallel-training-mode-1) for details.
+
+This example demonstrates fully automatic parallelism, which is achieved by configuring `parallel_mode=ms.ParallelMode.AUTO_PARALLEL` through the `set_auto_parallel_context()` interface.
+There are three configurable parallel strategy search algorithms under fully automatic parallelism, see: [Fully automatic parallelism](https://www.mindspore.cn/tutorials/experts/en/master/parallel/introduction.html#fully-automatic-parallelism) for details. In this example, the **sharding strategy propagation algorithm** is selected, which is implemented by configuring `search_mode="sharding_propagation"` through the `set_auto_parallel_context()` interface, and manually setting the `matmul` operator sharding strategy. The sharding strategy of other operators is given by the parallel strategy search algorithm automatically. The code is as follows:
+
+```python
+class Net(Cell):
+ """define net"""
+ def __init__(self, in_dim, hidden_dim, out_dim):
+ super().__init__()
+ self.in_dim = in_dim
+ self.hidden_dim = hidden_dim
+ self.out_dim = out_dim
+ self.weight = Parameter(initializer("normal", [self.in_dim, self.hidden_dim]), "w")
+ self.weight2 = Parameter(initializer("normal", [self.hidden_dim, self.out_dim]), "w2")
+
+ # Set the sharding strategy manually for the matmul operator
+ # where (2, 4) means that the input data of matmul operator is sliced into two parts in batch dimension and four parts in width dimension
+ # (4, 1) indicates that the weights of the matmul operator are sliced into four parts in the HEIGHT dimension
+ self.matmul = ops.MatMul().shard(((2, 4), (4, 1)))
+
+ self.relu = ops.ReLU()
+ self.matmul2 = ops.MatMul()
+
+ def construct(self, x):
+ out = self.matmul(x, self.weight)
+ out = self.relu(out)
+ out = self.matmul2(out, self.weight2)
+ return out
+```
+
+where the `shard()` method is described in detail in [Principles of Automatic Parallelism](https://www.mindspore.cn/docs/en/master/design/distributed_training_design.html#principle-of-automatic-parallelism). The inference introduction is in [functional operator sharding](https://www.mindspore.cn/tutorials/experts/en/master/parallel/pynative_shard_function_parallel.html)
+
+For the parallel sharding strategy set in the above example, the `matmul` operator computation process for the forward propagation process in a single-machine 8-card environment is schematically shown as follows:
+
+
+
+The top half of the diagram shows the data sharding, and the bottom half shows the calculation and communication process performed by each GPU card at logical number (rank) 0-7.
+
+#### Code Running
+
+In this example, the loss function, optimizer and training procedure are defined similarly to single card training, with the following code:
+
+```python
+var_step_per_epoch = 4
+var_single_batch_size = 2
+var_in_dim = 32
+var_hidden_dim = 16
+var_out_dim = 16
+
+ms.set_context(mode=ms.GRAPH_MODE, device_target="GPU", save_graphs=True, save_graphs_path="../saved_graph")
+
+# Single-machine 8-card environment. Parallel mode is fully automatic parallelism, and strategy search is set to strategy propagation algorithm
+ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.AUTO_PARALLEL, search_mode="sharding_propagation", dataset_strategy="data_parallel")
+
+# Initialize the communication environment and get the logical serial number of the current card, i.e. rank_id
+init("nccl")
+rank_id = get_rank()
+
+# Randomly constructed datasets
+fake_dataset = get_dataset(var_single_batch_size, var_step_per_epoch, var_in_dim, var_out_dim)
+dataset = ds.GeneratorDataset(fake_dataset, ["input", "label"])
+
+# Define the network structure
+net = Net(var_in_dim, var_hidden_dim, var_out_dim)
+
+# Define the loss function and callback
+loss = MSELoss()
+callback = [LossMonitor(), ModelCheckpoint(directory="{}".format(rank_id))]
+
+# Define the optimizer
+learning_rate = 0.2
+momentum = 0.1
+epoch_size = 5
+opt = Momentum(net.trainable_params(), learning_rate, momentum)
+
+# Model training
+model = Model(net, loss_fn=loss, optimizer=opt)
+model.train(epoch_size, dataset, callbacks=callback, dataset_sink_mode=False)
+```
+
+Training can be performed with `mpirun` command of OpenMPI, as specified in the script `run_with_mpi.sh`.
+
+After running, the script is performed in the background and the training log is saved in the `. /device` directory, and the model of the card with the logical number `rank_id` is saved in the `. /device/{rank_id}` directory.
+
+In addition, `save_graphs=True` is configured via the `ms.set_context()` interface to save the model intermediate representation `MindIR`, and the `MindIR` of the card with the logical number `rank_id` is saved in the `. /saved_graph/{rank_id}` directory. MindSpore IR (MindIR) is a program representation between the source language and the target language during the compilation of MindSpore framework programs to facilitate program analysis and optimization by the compiler, see [MindIR](https://www.mindspore.cn/docs/en/master/design/mindir.html).
+
+#### Verification
+
+After running the `run_with_mpi.sh` script, the recorded loss should decrease, e.g.
+
+```text
+# ./device/train.log: #
+...
+epoch: 3 step: 2, loss is 0.367389976978302
+epoch: 3 step: 2, loss is 0.367389976978302
+epoch: 3 step: 2, loss is 0.367389976978302
+epoch: 3 step: 2, loss is 0.367389976978302
+epoch: 3 step: 2, loss is 0.367389976978302
+epoch: 3 step: 2, loss is 0.367389976978302
+epoch: 3 step: 2, loss is 0.367389976978302
+epoch: 3 step: 2, loss is 0.367389976978302
+epoch: 3 step: 3, loss is 0.35383114218711853
+epoch: 3 step: 3, loss is 0.35383114218711853
+epoch: 3 step: 3, loss is 0.35383114218711853
+epoch: 3 step: 3, loss is 0.35383114218711853
+epoch: 3 step: 3, loss is 0.35383114218711853
+epoch: 3 step: 3, loss is 0.35383114218711853
+epoch: 3 step: 3, loss is 0.35383114218711853
+epoch: 3 step: 3, loss is 0.35383114218711853
+epoch: 3 step: 4, loss is 0.3312329947948456
+epoch: 3 step: 4, loss is 0.3312329947948456
+epoch: 3 step: 4, loss is 0.3312329947948456
+epoch: 3 step: 4, loss is 0.3312329947948456
+epoch: 3 step: 4, loss is 0.3312329947948456
+epoch: 3 step: 4, loss is 0.3312329947948456
+epoch: 3 step: 4, loss is 0.3312329947948456
+epoch: 3 step: 4, loss is 0.3312329947948456
+epoch: 4 step: 1, loss is 0.295515775680542
+epoch: 4 step: 1, loss is 0.295515775680542
+epoch: 4 step: 1, loss is 0.295515775680542
+epoch: 4 step: 1, loss is 0.295515775680542
+epoch: 4 step: 1, loss is 0.295515775680542
+epoch: 4 step: 1, loss is 0.295515775680542
+epoch: 4 step: 1, loss is 0.295515775680542
+epoch: 4 step: 1, loss is 0.295515775680542
+epoch: 4 step: 2, loss is 0.2440134435892105
+epoch: 4 step: 2, loss is 0.2440134435892105
+epoch: 4 step: 2, loss is 0.2440134435892105
+epoch: 4 step: 2, loss is 0.2440134435892105
+epoch: 4 step: 2, loss is 0.2440134435892105
+epoch: 4 step: 2, loss is 0.2440134435892105
+epoch: 4 step: 2, loss is 0.2440134435892105
+epoch: 4 step: 2, loss is 0.2440134435892105
+...
+```
+
+You can check the configuration of the sharding strategy for each operator in `. /saved_graph/rank_x/step_parallel_begin_xxxx.ir` to see the configuration of the sharding strategy for each operator, e.g.
+
+```text
+# ./saved_graph/rank_0/step_parallel_begin_0041.ir: #
+...
+%3(out) = MatMul(%1, %2) {instance name: matmul} primitive_attrs: {input_names: [x1, x2], out_strategy: None, transpose_x2: false, transpose_b: false, in_strategy: ((2, 4), (4, 1)), output_names: [output], transpose_a: false, transpose_x1: false} {in_strategy: ((2, 4), (4, 1))}
+ : (, ) -> ()
+ # scope: (Default/network-WithLossCell/_backbone-Net)
+%4(out) = ReLU(%3) {instance name: relu} primitive_attrs: {output_names: [output], input_names: [x]} {in_strategy: ((2, 4))}
+ : () -> ()
+ # scope: (Default/network-WithLossCell/_backbone-Net)
+%5([CNode]472) = Load($(@1_construct_wrapper.337:para4_w2), %para12_u)
+ : ([, ) -> ()
+ # scope: (Default/network-WithLossCell)
+%6(out) = MatMul(%4, %5) {instance name: matmul2} primitive_attrs: {output_names: [output], transpose_a: false, input_names: [x1, x2], transpose_x2: false, transpose_x1: false, transpose_b: false} {in_strategy: ((2, 4), (4, 1))}
+ : (, ) -> ()
+ # scope: (Default/network-WithLossCell/_backbone-Net)
+...
+```
+
+It can be seen that the `relu` operator corresponding to the `%4(out)` line and the `matmul2` operator corresponding to the `%6(out)` line are automatically configured with a sharding strategy.
+
+Further, you can view `. /saved_graph/rank_x/18_execute_xxxx.ir` to see the actual execution of the slice operator dimension for each card, e.g.
+
+```text
+# ./saved_graph/rank_0/18_execute_0185.ir: #
+...
+%12(equivout) = MatMul(%10, %11) {instance name: matmul} primitive_attrs: {input_names: [x1, x2], out_strategy: None, transpose_x2: false, transpose_b: false, in_strategy: ((2, 4), (4, 1)), output_names: [output], transpose_a: false, transpose_x1: false} {in_strategy: ((2, 4), (4, 1))}
+ : (, ) -> ()
+ # scope: (Default/network-WithLossCell/_backbone-Net)
+ # In file /home/jenkins/my_dir/parallel_training_quick_start/device/./matmul.py(45)/ out = self.matmul(x, self.weight)/
+ # In file /home/miniconda3/envs/my_env/lib/python3.9/site-packages/mindspore/nn/wrap/cell_wrapper.py(114)/ out = self._backbone(data)/
+ # In file /home/miniconda3/envs/my_env/lib/python3.9/site-packages/mindspore/nn/wrap/cell_wrapper.py(376)/ loss = self.network(*inputs)/
+%13(equiv[CNode]520) = AllReduce(%12) {instance name: forward_op_11795743325248501408} primitive_attrs: {group: 4-6301172352641561019, fusion: 0, op: sum, rank_list: (0, 1, 2, 3), group_ranks: 0-1-2-3, index: 0, group_rank_ids: (0, 1, 2, 3), no_eliminate: true} cnode_attrs: {comm_reuse: true}
+ : () -> ()
+ # scope: (Default/network-WithLossCell/_backbone-Net)
+%14(equiv[CNode]519) = StridedSlice(%13, (0, 0), (8, 4), (1, 1)) {instance name: redistribution_op_16390315056374637535StridedSlice} primitive_attrs: {new_axis_mask: 0, shrink_axis_mask: 0, end_mask: 0, input_names: [x, begin, end, strides], output_names: [output], keep_value_node_input: true, begin_mask: 0, ellipsis_mask: 0}
+ : (, (0, 0), elements_use_flags: {ptr: 0x560e8fef5fa0, value: [const vector][1, 1]}}, node={ValueNode (0, 0), elements_use_flags: {ptr: 0x560e8fef5fa0, value: [const vector][1, 1]}}, node={}, node={ValueNode (0, 0), elements_use_flags: {ptr: 0x560e8fef5fa0, value: [const vector][1, 1]}}, node={ValueNode (0, 0), elements_use_flags: {ptr: 0x560e8fef5fa0, value: [const vector][1, 1]}}, node={ValueNode (0, 0), elements_use_flags: {ptr: 0x560e8fef5fa0, value: [const vector][1, 1]}}, node={ValueNode (0, 0), elements_use_flags: {ptr: 0x560e8fef5fa0, value: [const vector][1, 1]}}, node={ValueNode (0, 0), elements_use_flags: {ptr: 0x560e8fef5fa0, value: [const vector][1, 1]}}, node={}}>, (8, 4), elements_use_flags: {ptr: 0x560e8fed50d0, value: [const vector][1, 1]}}, node={ValueNode (8, 4), elements_use_flags: {ptr: 0x560e8fed50d0, value: [const vector][1, 1]}}, node={}, node={ValueNode (8, 4), elements_use_flags: {ptr: 0x560e8fed50d0, value: [const vector][1, 1]}}, node={ValueNode (8, 4), elements_use_flags: {ptr: 0x560e8fed50d0, value: [const vector][1, 1]}}, node={ValueNode (8, 4), elements_use_flags: {ptr: 0x560e8fed50d0, value: [const vector][1, 1]}}, node={ValueNode (8, 4), elements_use_flags: {ptr: 0x560e8fed50d0, value: [const vector][1, 1]}}, node={ValueNode (8, 4), elements_use_flags: {ptr: 0x560e8fed50d0, value: [const vector][1, 1]}}, node={}}>, (1, 1), elements_use_flags: {ptr: 0x560e8ffb4ff0, value: [const vector][1, 1]}}, node={ValueNode (1, 1), elements_use_flags: {ptr: 0x560e8ffb4ff0, value: [const vector][1, 1]}}, node={}, node={ValueNode (1, 1), elements_use_flags: {ptr: 0x560e8ffb4ff0, value: [const vector][1, 1]}}, node={ValueNode (1, 1), elements_use_flags: {ptr: 0x560e8ffb4ff0, value: [const vector][1, 1]}}, node={ValueNode (1, 1), elements_use_flags: {ptr: 0x560e8ffb4ff0, value: [const vector][1, 1]}}, node={ValueNode (1, 1), elements_use_flags: {ptr: 0x560e8ffb4ff0, value: [const vector][1, 1]}}, node={ValueNode (1, 1), elements_use_flags: {ptr: 0x560e8ffb4ff0, value: [const vector][1, 1]}}, node={}}>) -> ()
+ # scope: (Default/network-WithLossCell/_backbone-Net)
+%15(equivout) = ReLU(%14) {instance name: relu} primitive_attrs: {output_names: [output], input_names: [x]} {in_strategy: ((2, 4))}
+ : () -> ()
+ # scope: (Default/network-WithLossCell/_backbone-Net)
+ # In file /home/jenkins/my_dir/parallel_training_quick_start/device/./matmul.py(46)/ out = self.relu(out)/
+ # In file /home/miniconda3/envs/my_env/lib/python3.9/site-packages/mindspore/nn/wrap/cell_wrapper.py(114)/ out = self._backbone(data)/
+ # In file /home/miniconda3/envs/my_env/lib/python3.9/site-packages/mindspore/nn/wrap/cell_wrapper.py(376)/ loss = self.network(*inputs)/
+%16(equiv[CNode]472) = Load(%para4_w2, U)
+ : (][, ) -> ()
+ # scope: (Default/network-WithLossCell)
+%17(equivout) = MatMul(%15, %16) {instance name: matmul2} primitive_attrs: {output_names: [output], transpose_a: false, input_names: [x1, x2], transpose_x2: false, transpose_x1: false, transpose_b: false} {in_strategy: ((2, 4), (4, 1))}
+ : (, ) -> ()
+ # scope: (Default/network-WithLossCell/_backbone-Net)
+ # In file /home/jenkins/my_dir/parallel_training_quick_start/device/./matmul.py(47)/ out = self.matmul2(out, self.weight2)/
+ # In file /home/miniconda3/envs/my_env/lib/python3.9/site-packages/mindspore/nn/wrap/cell_wrapper.py(114)/ out = self._backbone(data)/
+ # In file /home/miniconda3/envs/my_env/lib/python3.9/site-packages/mindspore/nn/wrap/cell_wrapper.py(376)/ loss = self.network(*inputs)/
+...
+```
+
+It can be seen that the dimension of the `matmul` operator corresponding to the `%12(equivout)` line is the same as that shown in the figure.
diff --git a/tutorials/experts/source_zh_cn/parallel/pangu_alpha.md b/tutorials/experts/source_zh_cn/parallel/pangu_alpha.md
index c0762dbe17..0013fce5de 100644
--- a/tutorials/experts/source_zh_cn/parallel/pangu_alpha.md
+++ b/tutorials/experts/source_zh_cn/parallel/pangu_alpha.md
@@ -118,7 +118,7 @@ Self-Attention可以直接通过`mindspore.nn.transformer.MultiHeadAttention`实
- `Softmax`以及`BatchMatMul`
- 在计算Attention的过程中,通过`BatchMatMul`实现Query和Key向量的矩阵乘法。此处,`softmax`的输入shape为`[batch, sequence_length, num_heads, size_per_head]`。因为每个`head`之间在计算attention score时是独立的,所以可以在`batch`维度和`heads`维度对`softmax`算子进行切分。
+ 在计算Attention的过程中,通过`BatchMatMul`实现Query和Key向量的矩阵乘法。此处,`softmax`的输入shape为`[batch, sequence_length, num_heads, size_per_head]`。因为每个`head`之间在计算attention score时是独立的,所以可以在`batch`维度和`heads`维度对`softmax`算子进行切分。
```python
self.softmax = nn.Softmax()
@@ -327,4 +327,4 @@ class PanguAlpha_Head(nn.Cell):
## 参考文献
-[1] Zeng W , Ren X , Su T , et al. PanGu-$\\alpha$: Large-scale Autoregressive Pretrained Chinese Language Models with Auto-parallel Computation. 2021.
+[1] Zeng W, Ren X, Su T, et al. PanGu-$\\alpha$: Large-scale Autoregressive Pretrained Chinese Language Models with Auto-parallel Computation. 2021.
diff --git a/tutorials/experts/source_zh_cn/parallel/parallel_training_quickstart.md b/tutorials/experts/source_zh_cn/parallel/parallel_training_quickstart.md
index 6a9b65a4f5..bfc83dfb25 100644
--- a/tutorials/experts/source_zh_cn/parallel/parallel_training_quickstart.md
+++ b/tutorials/experts/source_zh_cn/parallel/parallel_training_quickstart.md
@@ -24,7 +24,7 @@
其中,`net.py`为网络定义脚本,`run_with_mpi.sh`是执行脚本。
->此外,在Ascend 910平台上进行分布式并行训练的教程详见[分布式并行训练基础样例(Ascend)](https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/train_ascend.html)与[分布式并行训练Transformer模型](https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/transformer.html)。
+> 此外,在Ascend 910平台上进行分布式并行训练的教程详见[分布式并行训练基础样例(Ascend)](https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/train_ascend.html)与[分布式并行训练Transformer模型](https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/transformer.html)。
## 准备环节
@@ -80,7 +80,7 @@ class Net(Cell):
[OpenMPI](https://www.open-mpi.org/)是一种高性能消息传递库,是MindSpore采用的多进程通讯库,相关环境配置见:[通过OpenMPI运行脚本](https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/train_ascend.html#通过openmpi运行脚本)。
->此外,MindSpore还支持不依赖OpenMPI进行分布式训练,详见:[不依赖OpenMPI进行训练](https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/train_gpu.html#不依赖openmpi进行训练)。
+> 此外,MindSpore还支持不依赖OpenMPI进行分布式训练,详见:[不依赖OpenMPI进行训练](https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/train_gpu.html#不依赖openmpi进行训练)。
### 半自动并行
@@ -115,7 +115,7 @@ class Net(Cell):
return out
```
-其中,`shard()`方法的详细介绍见[自动并行原理](https://www.mindspore.cn/docs/zh-CN/master/design/distributed_training_design.html#自动并行原理),接口介绍见[函数式算子切分](https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/pynative_shard_function_parallel.html)
+其中,`shard()`方法的详细介绍见[自动并行原理](https://www.mindspore.cn/docs/zh-CN/master/design/distributed_training_design.html#自动并行原理),接口介绍见[函数式算子切分](https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/pynative_shard_function_parallel.html)。
对于上述例子中设置的并行切分策略,在单机8卡环境下,前向传播过程的`matmul`算子计算过程示意图如下:
@@ -279,5 +279,3 @@ epoch: 4 step: 2, loss is 0.2440134435892105
```
可以看到,`%12(equivout)`行对应的`matmul`算子维度与图示中一致。
-
-
--
Gitee
]