From 672a02f7364287b90f3feb2b5abeacdd68ff750d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=83=A1=E7=8A=87?= Date: Fri, 14 Mar 2025 02:45:47 +0800 Subject: [PATCH] modify semi auto --- .../mindspore/source_en/model_train/index.rst | 5 +- .../model_train/parallel/operator_parallel.md | 201 +++------ .../parallel/optimizer_parallel.md | 211 +-------- .../model_train/parallel/overview.md | 78 ---- .../model_train/parallel/pipeline_parallel.md | 418 +----------------- .../parallel/semi_auto_parallel.rst | 22 - .../source_zh_cn/model_train/index.rst | 5 +- .../model_train/parallel/operator_parallel.md | 199 +++------ .../parallel/optimizer_parallel.md | 212 +-------- .../model_train/parallel/overview.md | 79 ---- .../model_train/parallel/pipeline_parallel.md | 418 +----------------- .../parallel/semi_auto_parallel.rst | 22 - 12 files changed, 133 insertions(+), 1737 deletions(-) delete mode 100644 docs/mindspore/source_en/model_train/parallel/overview.md delete mode 100644 docs/mindspore/source_en/model_train/parallel/semi_auto_parallel.rst delete mode 100644 docs/mindspore/source_zh_cn/model_train/parallel/overview.md delete mode 100644 docs/mindspore/source_zh_cn/model_train/parallel/semi_auto_parallel.rst diff --git a/docs/mindspore/source_en/model_train/index.rst b/docs/mindspore/source_en/model_train/index.rst index 41848bfdcb..ed118831a4 100644 --- a/docs/mindspore/source_en/model_train/index.rst +++ b/docs/mindspore/source_en/model_train/index.rst @@ -54,10 +54,11 @@ Model Building and Training :hidden: :caption: Distributed Parallelism - parallel/overview parallel/startup_method parallel/data_parallel - parallel/semi_auto_parallel + parallel/operator_parallel + parallel/optimizer_parallel + parallel/pipeline_parallel parallel/auto_parallel parallel/manual_parallel parallel/parameter_server_training diff --git a/docs/mindspore/source_en/model_train/parallel/operator_parallel.md b/docs/mindspore/source_en/model_train/parallel/operator_parallel.md index f63c8c21ef..395a2c9605 100644 --- a/docs/mindspore/source_en/model_train/parallel/operator_parallel.md +++ b/docs/mindspore/source_en/model_train/parallel/operator_parallel.md @@ -14,7 +14,7 @@ For a list of operators that currently support parallelism, see [Usage Constrain Related interfaces: -1. `mindspore.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)`: Sets the semi-automatic parallel mode, which must be called before initializing the network. +1. `AutoParallel(network, parallel_mode="semi_auto")`: Encapsulates the specified parallel mode via static graph parallelism, where `network` is the top-level `Cell` or function to be encapsulated, and `parallel_mode` takes the value `semi_auto`, indicating a semi-automatic parallel mode. The return type of this interface is `Cell`. 2. `mindspore.ops.Primitive.shard()`: Specify the operator slicing strategy, see [Basic Principle](#basic-principle) in this chapter for detailed examples. 
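
The following is a minimal sketch of how the two interfaces above are typically combined. It is illustrative only: it assumes a 4-card environment that has already been launched with one of the supported startup methods, and the sharding strategy shown is an arbitrary example rather than a recommended configuration.

```python
import mindspore as ms
import mindspore.nn as nn
from mindspore import ops
from mindspore.communication import init
from mindspore.parallel.auto_parallel import AutoParallel

ms.set_context(mode=ms.GRAPH_MODE)
init()  # initialize HCCL/NCCL communication for the launched cluster

class MatMulNet(nn.Cell):
    def __init__(self):
        super().__init__()
        # slice the rows of the first input across 4 devices; keep the weight replicated
        self.matmul = ops.MatMul().shard(((4, 1), (1, 1)))

    def construct(self, x, w):
        return self.matmul(x, w)

net = MatMulNet()
# wrap the top-level Cell to run it in semi-automatic parallel mode
parallel_net = AutoParallel(net, parallel_mode="semi_auto")
```

After wrapping, `parallel_net` is used in place of `net` when building the training step, so the strategies configured through `shard()` take effect during graph compilation.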
@@ -49,9 +49,7 @@ Users can set the sharding strategy of the operator by using the shard() interfa ```python import mindspore.nn as nn from mindspore import ops -import mindspore as ms - -ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL, device_num=4) +from mindspore.parallel.auto_parallel import AutoParallel class DenseMatMulNet(nn.Cell): def __init__(self): @@ -62,181 +60,90 @@ class DenseMatMulNet(nn.Cell): y = self.matmul1(x, w) z = self.matmul2(y, v) return z + +net = DenseMatMulNet() +paralell_net = AutoParallel(net, parallel_mode='semi_auto') ``` In the above example, the user computes two consecutive two-dimensional matrix multiplications on 4 cards: `Z = (X * W) * V` . For the first matrix multiplication `Y = X * W`, the user wants to slice X by rows in 4 parts (i.e. data parallelism), while for the second matrix multiplication `Z = Y * V`, the user wants to slice V by columns in 4 parts (i.e. model parallelism): Since the Tensor Layout output from the first operator is the 0th dimensional sliced to the cluster, while the second operator requires the first input Tensor to be replicated on the cluster. So in the graph compilation stage, the difference in Tensor Layout between the two operator outputs/inputs is automatically recognized, thus the algorithm for Tensor redistribution is automatically derived. The Tensor redistribution required for this example is an AllGather operator (note: MindSpore AllGather operator automatically merges multiple input Tensors in dimension 0) -## Operation Practices +# Higher-order Operator-level Parallelism -The following is an illustration of operator-level parallelism by taking an Ascend or GPU single-machine 8-card as an example. +[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/master/docs/mindspore/source_en/model_train/parallel/advanced_operator_parallel.md) -### Sample Code Description +## Overview -> Download the complete sample code here: [distributed_operator_parallel](https://gitee.com/mindspore/docs/tree/master/docs/sample_code/distributed_operator_parallel). +[Operator-level Parallelism](https://www.mindspore.cn/docs/en/master/model_train/parallel/operator_parallel.html) is a commonly used parallelism technique in large model training inference, which can slice the tensor across multiple cards and effectively reduce GPU memory on a single card. -The directory structure is as follows: +The configuration of operator-level parallelism in MindSpore is implemented through mindspore.ops.Primitive.shard() interface, which describes the way each input tensor is sliced through tuples, is suitable for most scenarios and has a simpler configuration process. However, this slicing approach only describes the tensor slicing logic, but hides the specific arrangement of the tensor on the device rank. Therefore, it has limitations in expressing the mapping relationship between tensor slicing and device ranking, and cannot meet the requirements of some complex scenarios. -```text -└─ sample_code - ├─ distributed_operator_parallel - ├── distributed_operator_parallel.py - └── run.sh - ... -``` +To cope with these complex scenarios, this tutorial introduces a higher-order operator-level parallel configuration method with an open device arrangement description. -Among them, `distributed_operator_parallel.py` is the script that defines the network structure and the training process. 
`run.sh` is the execution script. +> Hardware platforms supported for advanced operator-level parallel models include Ascend, GPU, and need to be run in Graph mode. -### Configuring the Distributed Environment +## Background -Specify the run mode, run device, run card number, etc. through the context interface. Unlike single-card scripts, parallel scripts also need to specify the parallel mode `parallel_mode` to be semi-automatic parallel mode, and initialize HCCL or NCCL communication through init. +[Operator-level Parallelism](https://www.mindspore.cn/docs/en/master/model_train/parallel/operator_parallel.html) describes MindSpore basic slicing logic for tensors, but cannot express all the slicing scenarios. For example, for a 2D tensor "[[a0, a1, a2, a3], [a4, a5, a6, a7]]", the tensor layout is shown below: -In addition, on the Ascend hardware platform, a portion of the memory needs to be set aside in order to ensure that there is sufficient device memory for communications. `max_size` is set to limit the maximum amount of device memory a model can have, and GPU does not need to set. If `device_target` is not set here, it will be automatically specified as the backend hardware device corresponding to the MindSpore package. +![image](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/docs/mindspore/source_zh_cn/model_train/parallel/images/advanced_operator_parallel_view1.PNG) -```python -import mindspore as ms -from mindspore.communication import init - -ms.set_context(mode=ms.GRAPH_MODE) -ms.runtime.set_memory(max_size="28GB") -ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL) -init() -ms.set_seed(1) -``` +*Figure: Schematic of 2D tensor arrangement* -### Loading the Dataset +It can be seen that the 0-axis of the tensor, e.g. "[a0, a1, a2, a3]" slices to the discontinuous card "[Rank0, Rank4, Rank2, Rank6]" and the tensor is sliced according to strategy=(2, 4), the arrangement should be as follows: -In the operator-level parallel scenario, the dataset is loaded in the same way as single-card is loaded, with the following code: +![image](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/docs/mindspore/source_zh_cn/model_train/parallel/images/advanced_operator_parallel_view2.PNG) -```python -import os -import mindspore.dataset as ds - -def create_dataset(batch_size): - dataset_path = os.getenv("DATA_PATH") - dataset = ds.MnistDataset(dataset_path) - image_transforms = [ - ds.vision.Rescale(1.0 / 255.0, 0), - ds.vision.Normalize(mean=(0.1307,), std=(0.3081,)), - ds.vision.HWC2CHW() - ] - label_transform = ds.transforms.TypeCast(ms.int32) - dataset = dataset.map(image_transforms, 'image') - dataset = dataset.map(label_transform, 'label') - dataset = dataset.batch(batch_size) - return dataset - -data_set = create_dataset(32) -``` +*Figure: Schematic of a 2D tensor arranged according to a sharding strategy* -### Defining the Network +Therefore, directly slicing the input and output tensor of the operator according to the number of slices fails to express some slicing scenarios with special requirements. -In the current semi-automatic parallel mode, the network needs to be defined with ops operators(Primitive). 
Users can manually configure the slicing strategy for some operators based on a single-card network, e.g., the network structure after configuring the strategy is: +## Interface Configuration -```python -import mindspore as ms -from mindspore import nn, ops +In order to express sharding as in the above scenario, functional extensions are made to the [shard](https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.shard.html) interface. -class Network(nn.Cell): - def __init__(self): - super().__init__() - self.flatten = ops.Flatten() - self.fc1_weight = ms.Parameter(initializer("normal", [28*28, 512], ms.float32)) - self.fc2_weight = ms.Parameter(initializer("normal", [512, 512], ms.float32)) - self.fc3_weight = ms.Parameter(initializer("normal", [512, 10], ms.float32)) - self.matmul1 = ops.MatMul() - self.relu1 = ops.ReLU() - self.matmul2 = ops.MatMul() - self.relu2 = ops.ReLU() - self.matmul3 = ops.MatMul() - - def construct(self, x): - x = self.flatten(x) - x = self.matmul1(x, self.fc1_weight) - x = self.relu1(x) - x = self.matmul2(x, self.fc2_weight) - x = self.relu2(x) - logits = self.matmul3(x, self.fc3_weight) - return logits - -net = Network() -net.matmul1.shard(((2, 4), (4, 1))) -net.relu1.shard(((4, 1),)) -net.matmul2.shard(((1, 8), (8, 1))) -net.relu2.shard(((8, 1),)) -``` +The parameters in_strategy and out_strategy both additionally receive the new quantity type tuple(Layout) type. [Layout](https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.Layout.html) is initialized using the device matrix, while requiring an alias for each axis of the device matrix. For example: "layout = Layout((8, 4, 4), name = ("dp", "sp", "mp"))" means that the device has 128 cards in total, which are arranged in the shape of (8, 4, 4), and aliases "dp", "sp", "mp" are given to each axis. -The `ops.MatMul()` and `ops.ReLU()` operators for the above networks are configured with slicing strategy, in the case of `net.matmul1.shard(((2, 4), (4, 1)))`, which has a slicing strategy of: rows of the first input are sliced in 2 parts and columns in 4 parts; rows of the second input are sliced in 4 parts. For `net.relu2.shard(((8, 1),))`, its slicing strategy is: the row of the first input is sliced in 8 parts. Note that since the two `ops.ReLU()` here have different slicing strategies, have to be defined twice separately. +By passing in the aliases for these axes when calling Layout, each tensor determines which axis of the device matrix each dimension is mapped to based on its shape (shape), and the corresponding number of slice shares. For example: -### Training the Network +- "dp" denotes 8 cuts within 8 devices in the highest dimension of the device layout. +- "sp" denotes 4 cuts within 4 devices in the middle dimension of the device layout. +- "mp" denotes 4 cuts within 4 devices in the lowest dimension of the device layout. -In this step, we need to define the loss function, the optimizer, and the training process, which is the same as that of the single-card: +In particular, one dimension of the tensor may be mapped to multiple dimensions of the device to express multiple slices in one dimension. 
+ +The above example of "[[a0, a1, a2, a3], [a4, a5, a6, a7]]" sliced to discontinuous cards can be expressed by Layout as follows: ```python -import mindspore as ms -from mindspore import nn - -optimizer = nn.SGD(net.trainable_params(), 1e-2) -loss_fn = nn.CrossEntropyLoss() - -def forward_fn(data, target): - logits = net(data) - loss = loss_fn(logits, target) - return loss, logits - -grad_fn = ms.value_and_grad(forward_fn, None, net.trainable_params(), has_aux=True) - -@ms.jit -def train_step(inputs, targets): - (loss_value, _), grads = grad_fn(inputs, targets) - optimizer(grads) - return loss_value - -for epoch in range(10): - i = 0 - for image, label in data_set: - loss_output = train_step(image, label) - if i % 10 == 0: - print("epoch: %s, step: %s, loss is %s" % (epoch, i, loss_output)) - i += 1 +from mindspore import Layout +a = [[a0, a1, a2, a3], [a4, a5, a6, a7]] +layout = Layout((2, 2, 2), name = ("dp", "sp", "mp")) +a_strategy = layout("mp", ("sp", "dp")) ``` -### Running the Single-machine Eight-card Script - -Next, the corresponding scripts are invoked by commands. As an example, the 8-card distributed training script uses the `mpirun` startup method for distributed training: +It can be seen that the "[a0, a1, a2, a3]" of the tensor a is sliced twice to the "sp" and "mp" axes of the device, so that the result comes out as: -```bash -bash run.sh -``` +![image](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/docs/mindspore/source_zh_cn/model_train/parallel/images/advanced_operator_parallel_view1.PNG) -After training, the log files are saved to the `log_output` directory, where part of the file directory structure is as follows: +The following is exemplified by a concrete example in which the user computes a two-dimensional matrix multiplication over 8 cards: `Y = (X * W)` , where the devices are organized according to `2 * 2 * 2`, and the cut of X coincides with the cut of the tensor a. The code is as follows: -```text -└─ log_output - └─ 1 - ├─ rank.0 - | └─ stdout - ├─ rank.1 - | └─ stdout -... -``` +```python +import mindspore.nn as nn +from mindspore import ops, Layout +from mindspore.parallel.auto_parallel import AutoParallel -The results on the Loss section are saved in `log_output/1/rank.*/stdout`, and example is as follows: - -```text -epoch: 0, step: 0, loss is 2.3026192 -epoch: 0, step: 10, loss is 2.2928686 -epoch: 0, step: 20, loss is 2.279024 -epoch: 0, step: 30, loss is 2.2548661 -epoch: 0, step: 40, loss is 2.192434 -epoch: 0, step: 50, loss is 2.0514572 -epoch: 0, step: 60, loss is 1.7082529 -epoch: 0, step: 70, loss is 1.1759918 -epoch: 0, step: 80, loss is 0.94476485 -epoch: 0, step: 90, loss is 0.73854053 -epoch: 0, step: 100, loss is 0.71934 -... -``` +class DenseMatMulNet(nn.Cell): + def __init__(self): + super(DenseMatMulNet, self).__init__() + layout = Layout((2, 2, 2), name = ("dp", "sp", "mp")) + in_strategy = (layout("mp", ("sp", "dp")), layout(("sp", "dp"), "None")) + out_strategy = (layout(("mp", "sp", "dp"), "None"), ) + self.matmul1 = ops.MatMul().shard(in_strategy, out_strategy) + def construct(self, x, w): + y = self.matmul1(x, w) + return y -Other startup methods such as dynamic networking and `rank table` startup can be found in [startup methods](https://www.mindspore.cn/docs/en/master/model_train/parallel/startup_method.html). 
\ No newline at end of file
diff --git a/docs/mindspore/source_en/model_train/parallel/optimizer_parallel.md b/docs/mindspore/source_en/model_train/parallel/optimizer_parallel.md
index 6c4967e926..73eaa54acc 100644
--- a/docs/mindspore/source_en/model_train/parallel/optimizer_parallel.md
+++ b/docs/mindspore/source_en/model_train/parallel/optimizer_parallel.md
@@ -18,9 +18,11 @@ In either mode, the optimizer parallelism does not affect the compute graph of t
 
 Related interfaces:
 
-1. `mindspore.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, enable_parallel_optimizer=True)`: Set semi-automatic parallel mode and enable optimizer parallel, must be called before initializing the network. When `enable_parallel_optimizer` is turned on, the optimizer slices by default for all **parameters occupying no less than 64KB** of memory. See [Advanced Interfaces](#advanced-interfaces) in this chapter.
+1. `AutoParallel(network, parallel_mode="semi_auto")`: Encapsulates the specified parallel mode via static graph parallelism, where `network` is the top-level `Cell` or function to be encapsulated, and `parallel_mode` takes the value `semi_auto`, indicating a semi-automatic parallel mode. The return type of this interface is `Cell`.
 
-2. `Cell.set_comm_fusion(fusion_type=NUM)`: In automatic/semi-automatic mode, each parameter generates a corresponding AllGather operation and ReduceScatter operation. These communication operators are automatically inserted by the auto-parallel framework. However, as the number of parameters increases, the number of corresponding communication operators also increases, and the scheduling and startup of operators generated by communication operations incurs more overhead. Therefore, it is possible to manually configure fusion markers NUM for the AllGather and ReduceScatter operations corresponding to parameters within each `Cell` through the `set_comm_fusion` method provided by `Cell` in order to improve communication efficiency. MindSpore will fuse the communication operators corresponding to the same NUM parameters to minimize communication overhead.
+2. `AutoParallel.hsdp(shard_size=-1, threshold=64, optimizer_level="level1")`: Configures and enables optimizer parallelism through this interface. `shard_size` specifies the size of the communication group for optimizer weight sharding. `threshold` defines the minimum memory size (in KB) required for a parameter to be sharded. Parameters smaller than this threshold will not be sharded during parameter partitioning. `optimizer_level` specifies the splitting level for optimizer sharding. When optimizer_level=`level1`, splitting is performed on weights and optimizer state. When optimizer_level=`level2`, splitting is performed on weights, optimizer state, and gradients. When optimizer_level=`level3`, splitting is performed on weights, optimizer state, and gradients; in addition, the weights are gathered again with an AllGather before the backward pass so that the memory used by the forward-pass AllGather can be released.
+
+3. `Cell.set_comm_fusion(fusion_type=NUM)`: In automatic/semi-automatic mode, each parameter generates a corresponding AllGather operation and ReduceScatter operation. These communication operators are automatically inserted by the auto-parallel framework. 
However, as the number of parameters increases, the number of corresponding communication operators also increases, and the scheduling and startup of operators generated by communication operations incurs more overhead. Therefore, it is possible to manually configure fusion markers NUM for the AllGather and ReduceScatter operations corresponding to parameters within each `Cell` through the `set_comm_fusion` method provided by `Cell` in order to improve communication efficiency. MindSpore will fuse the communication operators corresponding to the same NUM parameters to minimize communication overhead. ## Basic Principles @@ -46,208 +48,3 @@ In the test validation of the actual network training, we found that the memory Optimizer parameter slicing implemented by MindSpore also has the advantage of being mixed with operator-level parallelism. When the number of sliced parts in the operator-level model parallel parameters are smaller than the number of dimensions, the optimizer parameters can continue to be sliced in the dimension of data parallelism, increasing the utilization of machine resources and thus improving the end-to-end performance. -## Operation Practice - -The following is an illustration of optimizer parallel operation using an Ascend or GPU single-machine 8-card example: - -### Sample Code Description - -> Download the full sample code: [distributed_optimizer_parallel](https://gitee.com/mindspore/docs/tree/master/docs/sample_code/distributed_optimizer_parallel). - -The directory structure is as follows: - -```text -└─ sample_code - ├─ distributed_optimizer_parallel - ├── distributed_optimizer_parallel.py - └── run.sh - ... -``` - -Among them, `distributed_optimizer_parallel.py` is the script that defines the network structure and the training process. `run.sh` is the execution script. - -### Configuring the Distributed Environment - -Specify the run mode, run device, run card number through the context interface. Unlike single-card scripts, parallel scripts also need to specify the parallel mode `parallel_mode` to be semi-automatic parallel mode, and initialize HCCL or NCCL communication through init. In addition, optimizer parallel should be turned on, configuring `enable_parallel_optimizer=True`. If `device_target` is not set here, it will be automatically specified as the backend hardware device corresponding to the MindSpore package. 
- -```python -import mindspore as ms -from mindspore.communication import init - -ms.set_context(mode=ms.GRAPH_MODE) -ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL, enable_parallel_optimizer=True) -init() -ms.set_seed(1) -``` - -### Loading the Dataset - -In the optimizer parallel scenario, the dataset is loaded in the same way as single-card is loaded, with the following code: - -```python -import os -import mindspore.dataset as ds - -def create_dataset(batch_size): - """create dataset""" - dataset_path = os.getenv("DATA_PATH") - dataset = ds.MnistDataset(dataset_path) - image_transforms = [ - ds.vision.Rescale(1.0 / 255.0, 0), - ds.vision.Normalize(mean=(0.1307,), std=(0.3081,)), - ds.vision.HWC2CHW() - ] - label_transform = ds.transforms.TypeCast(ms.int32) - dataset = dataset.map(image_transforms, 'image') - dataset = dataset.map(label_transform, 'label') - dataset = dataset.batch(batch_size) - return dataset - -data_set = create_dataset(32) -``` - -### Defining the Network - -The optimizer parallel network structure is essentially the same as the single card network structure, with the difference being the addition of a configuration for communication operator fusion: - -```python -from mindspore import nn - -class Network(nn.Cell): - def __init__(self): - super().__init__() - self.flatten = nn.Flatten() - self.layer1 = nn.Dense(28*28, 512) - self.layer2 = nn.Dense(512, 512) - self.layer3 = nn.Dense(512, 10) - self.relu = nn.ReLU() - - def construct(self, x): - x = self.flatten(x) - x = self.layer1(x) - x = self.relu(x) - x = self.layer2(x) - x = self.relu(x) - logits = self.layer3(x) - return logits - -net = Network() -net.layer1.set_comm_fusion(0) -net.layer2.set_comm_fusion(1) -net.layer3.set_comm_fusion(2) -``` - -> Here communication fusion is configured for different layers in order to reduce the communication cost. Details can be found in [Communication Operator Fusion](https://www.mindspore.cn/docs/en/master/model_train/parallel/comm_fusion.html). - -### Training the Network - -In this step, we need to define the loss function, the optimizer, and the training process, which is the same as that of the single-card: - -```python -import mindspore as ms -from mindspore import nn - -optimizer = nn.SGD(net.trainable_params(), 1e-2) -loss_fn = nn.CrossEntropyLoss() - -def forward_fn(data, target): - logits = net(data) - loss = loss_fn(logits, target) - return loss, logits - -grad_fn = ms.value_and_grad(forward_fn, None, net.trainable_params(), has_aux=True) - -@ms.jit -def train_step(inputs, targets): - (loss_value, _), grads = grad_fn(inputs, targets) - optimizer(grads) - return loss_value - -for epoch in range(10): - i = 0 - for image, label in data_set: - loss_output = train_step(image, label) - if i % 10 == 0: - print("epoch: %s, step: %s, loss is %s" % (epoch, i, loss_output)) - i += 1 -``` - -### Running the Single-machine Eight-card Script - -Next, the corresponding scripts are invoked by commands. As an example, the 8-card distributed training script uses the `mpirun` startup method for distributed training: - -```bash -bash run.sh -``` - -After training, the log files are saved to the `log_output` directory, where part of the file directory structure is as follows: - -```text -└─ log_output - └─ 1 - ├─ rank.0 - | └─ stdout - ├─ rank.1 - | └─ stdout -... 
-``` - -The results are saved in `log_output/1/rank.*/stdout`, and example is as follows: - -```text -epoch: 0, step: 0, loss is 2.3024087 -epoch: 0, step: 10, loss is 2.2921634 -epoch: 0, step: 20, loss is 2.278274 -epoch: 0, step: 30, loss is 2.2537143 -epoch: 0, step: 40, loss is 2.1638 -epoch: 0, step: 50, loss is 1.984318 -epoch: 0, step: 60, loss is 1.6061916 -epoch: 0, step: 70, loss is 1.20966 -epoch: 0, step: 80, loss is 0.98156196 -epoch: 0, step: 90, loss is 0.77229893 -epoch: 0, step: 100, loss is 0.6854114 -... -``` - -Other startup methods such as dynamic networking and `rank table` startup can be found in [startup methods](https://www.mindspore.cn/docs/en/master/model_train/parallel/startup_method.html). - -## Advanced Interfaces - -1. `parallel_optimizer_config`: The optimizer parallel feature also provides a configuration dictionary `parallel_optimizer_config={}`. Different effects can be achieved by configuring different key values in `mindspore.set_auto_parallel_context()`: - - - `gradient_accumulation_shard`: If True, the cumulative gradient variables will be sliced on the data parallelism. When accumulating gradients, an additional communication (ReduceScatter) will be introduced in each accumulation iteration to ensure computational consistency, but saves a large amount of compute device memory (e.g. GPU video memory), thus allowing the model to be trained in larger batches. This configuration is valid only if the model is set in pipelined parallel training or gradient accumulation and has a data parallel dimension. The default value is True. - - ```python - import mindspore as ms - ms.set_auto_parallel_context(parallel_optimizer_config={"gradient_accumulation_shard": True}, enable_parallel_optimizer=True) - ``` - - - `parallel_optimizer_threshold(int)`: This value indicates the minimum value of memory required for the target parameter when slicing the parameter. When the target parameter is smaller than this value, it will not be sliced. The default value is 64 in KB. - - ```python - import numpy as np - import mindspore as ms - param = ms.Parameter(ms.Tensor(np.ones((10, 2)), dtype=ms.float32), name='weight1') - # The float32 type occupies 4 Bytes of memory: - # param_size = np.prod(list(param.shape)) * 4 = (10 * 2) * 4 = 80B < 24KB, not be sliced - ms.set_auto_parallel_context(parallel_optimizer_config={"parallel_optimizer_threshold": 24}) - ``` - - - `optimizer_weight_shard_size`:Set the size of the communication domain split by the optimizer weight. The numerical range can be (0, device_num]. If pipeline parallel is enabled, the numerical range is (0, device_num/stage]. If the size of data parallel communication domain of the parameter cannot be divided by `optimizer_weight_shard_size`, then the specified size of the communication domain split by the optimizer weight will not take effect. Default value is ``-1`` , which means the size of the communication domain split by the optimizer weight will be the size of data parallel communication domain of each parameter. - - ```python - import mindspore as ms - ms.set_auto_parallel_context(parallel_optimizer_config={"optimizer_weight_shard_size": 2}, enable_parallel_optimizer=True) - ``` - -2. 
`Parameter.parallel_optimizer`: This interface also allows the user to customize whether certain weights are sliced by the optimizer, as shown below: - - ```python - import numpy as np - import mindspore as ms - param = ms.Parameter(ms.Tensor(np.ones((10, 2))), name='weight1', parallel_optimizer=True) - - # Another way to set the parallel_optimizer attribute - param2 = ms.Parameter(ms.Tensor(np.ones((10, 2))), name='weight2') - param2.parallel_optimizer = False - ``` \ No newline at end of file diff --git a/docs/mindspore/source_en/model_train/parallel/overview.md b/docs/mindspore/source_en/model_train/parallel/overview.md deleted file mode 100644 index 1fb711e3eb..0000000000 --- a/docs/mindspore/source_en/model_train/parallel/overview.md +++ /dev/null @@ -1,78 +0,0 @@ -# Distributed Parallelism Overview - -[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/master/docs/mindspore/source_en/model_train/parallel/overview.md) - -In deep learning, as the size of the dataset and number of parameters grows, the time and hardware resources required for training will increase and eventually become a bottleneck that constrains training. Distributed parallel training, which can reduce the demand on hardware such as memory and computational performance, is an important optimization means to perform training. In addition, distributed parallel is important for large model training and inference, which provides powerful computational capabilities and performance advantages for handling large-scale data and complex models. - -To implement distributed parallel training and inference, you can refer to the following guidelines: - -## Parallel Modes - -Currently MindSpore can take the following parallel mode, and you can choose according to your needs: - -- **Data Parallel Mode**: In data parallel mode, the dataset can be split in sample dimensions and distributed to different cards. If your dataset is large and the model parameters scale is able to operate on a single card, you can choose this parallel model. Refer to the [Data Parallel](https://www.mindspore.cn/docs/en/master/model_train/parallel/data_parallel.html) tutorial for more information. -- **Automatic Parallel Mode**: a distributed parallel mode that combines data parallel and operator-level model parallel. It can automatically build cost models, find the parallel strategy with shorter training time, and select the appropriate parallel mode for the user. If your dataset and model parameters are large in size, and you want to configure the parallel strategy automatically, you can choose this parallel model. Refer to the [Automatic Parallelism](https://www.mindspore.cn/docs/en/master/model_train/parallel/auto_parallel.html) tutorial for more information. -- **Semi-Automatic Parallel Mode**: Compared with automatic parallel, this mode requires the user to manually configure a slice strategy for the operators to realize parallel. If your dataset and model parameters are large, and you are familiar with the structure of the model, and know which "key operators" are prone to become computational bottlenecks to configure the appropriate slice strategy for the "key operators" to achieve better performance, you can choose this mode. Parallel mode. This mode also allows you to manually configure optimizer parallel and pipeline parallel. 
Refer to the [Semi-Automatic Parallel](https://www.mindspore.cn/docs/en/master/model_train/parallel/semi_auto_parallel.html) tutorial for more information. -- **Manual Parallel Mode**: In manual parallel mode, you can manually implement parallel communication of models in distributed systems by transferring data based on communication operators such as `AllReduce`, `AllGather`, `Broadcast` and other communication operators. You can refer to the [Manual Parallel](https://www.mindspore.cn/docs/en/master/model_train/parallel/manual_parallel.html) tutorial for more information. -- **Parameter Server Mode**: parameter servers offer better flexibility and scalability than synchronized training methods. You can refer to the [Parameter Server](https://www.mindspore.cn/docs/en/master/model_train/parallel/parameter_server_training.html) mode tutorial for more information. - -## Saving and Loading Models - -Model saving can be categorized into merged and non-merged saving, which can be selected via the `integrated_save` parameter in `mindspore.save_checkpoint` or `mindspore.train.CheckpointConfig`. Model parameters are automatically aggregated and saved to the model file in merged save mode, while each card saves slices of the parameters on their respective cards in non-merged saving mode. You can refer to the [Model Saving](https://www.mindspore.cn/docs/en/master/model_train/parallel/model_saving.html) tutorial for more information about model saving in each parallel mode. - -Model loading can be categorized into full loading and slice loading. If the model file is saved with complete parameters, the model file can be loaded directly through the `load_checkpoint` interface. If the model file is a parameter-sliced file under multi-card, we need to consider whether the distributed slice strategy or cluster size has changed after loading. If the distributed slice strategy or cluster size remains unchanged, the corresponding parameter slice file for each card can be loaded via the `load_distributed_checkpoint` interface, which can be found in [model loading](https://www.mindspore.cn/docs/en/master/model_train/parallel/model_loading.html) tutorial. - -In the case that the saved and loaded distributed slice strategy or cluster size changes, the Checkpoint file under distribution needs to be converted to adapt to the new distributed slice strategy or cluster size. You can refer to [Model Transformation](https://www.mindspore.cn/docs/en/master/model_train/parallel/model_transformation.html) for more information. - -## Fault Recovery - -During the distributed parallel training process, problems such as failures of computing nodes or communication interruptions may be encountered. MindSpore provides three recovery methods to ensure the stability and continuity of training: - -- **Recovery based on full Checkpoint**:Before saving the Checkpoint file, the complete parameters of the model are aggregated by the AllGather operator, and the complete model parameter file is saved for each card, which can be loaded directly for recovery. Multiple checkpoints copies improve the fault tolerance of the model, but for large models, the aggregation process leads to excessive overhead of various resources. Refer to the [Model Loading](https://www.mindspore.cn/docs/en/master/model_train/parallel/model_loading.html) tutorial for details. 
-- **Disaster Recovery in Dynamic Cluster Scenarios**: In dynamic cluster, if a process fails, the other processes will enter a waiting state, and the training task can be resumed by pulling up the fault process (only GPU hardware platforms are supported at present). Compared with other methods, this fault recovery method does not require restarting the cluster. For details, please refer to [Disaster Recovery in Dynamic Cluster Scenarios](https://www.mindspore.cn/docs/en/master/model_train/parallel/disaster_recover.html) tutorial. -- **Recovery of redundant information based on parameter slicing**: In large model training, devices that are divided according to the dimension of data parallel have the same model parameters. According to this principle, these redundant parameter information can be utilized as a backup. When one node fails, another node utilizing the same parameters can recover the failed node. For details, please refer to the [Fault Recovery Based on Redundant Information](https://www.mindspore.cn/docs/en/master/model_train/parallel/fault_recover.html) tutorial. - -## Optimization Methods - -If there is a requirement on performance, throughput, or scale, or if you don't know how to choose a parallel strategy, consider the following optimization techniques: - -- **Parallel strategy optimization**: - - **Strategy Selection**: Depending on the size of your model and the amount of data, you can refer to the [Strategy Selection](https://www.mindspore.cn/docs/en/master/model_train/parallel/strategy_select.html) tutorial to select different parallel strategies to improve training efficiency and resource utilization. - - **Slicing Techniques**: Slicing techniques are also key to achieving efficient parallel computation. In the [Slicing Techniques](https://www.mindspore.cn/docs/en/master/model_train/parallel/split_technique.html) tutorial, you can learn how to apply a variety of slicing techniques to improve efficiency through concrete examples. - - **Multi-copy Parallel**: Under the existing single-copy mode, certain underlying operators cannot perform computation at the same time when communicating, which leads to resource waste. Multi-copy mode slicing the data into multiple copies in accordance with the batch size dimension can make one copy communicate while the other copy performs computational operations, which improves the resource utilization rate. For details, please refer to the [Multi-copy Parallel](https://www.mindspore.cn/docs/en/master/model_train/parallel/multiple_copy.html) tutorial. -- **Memory optimization**: - - **Gradient Accumulation**: Gradient accumulation updates the parameters of a neural network by computing gradients over multiple MicroBatches and summing them up, then applying this accumulated gradient at once. In this way, a small number of devices can be trained on a large batch size, effectively minimizing memory spikes. For detailed information, refer to [Gradient Accumulation](https://www.mindspore.cn/docs/en/master/model_train/parallel/distributed_gradient_accumulation.html) tutorial. - - **Recompute**: Recompute saves memory space by not saving the result of the forward operators. When calculating the backward operators, you need to use the forward result before recalculating the forward operators. For details, please refer to the [recompute](https://www.mindspore.cn/docs/en/master/model_train/parallel/recompute.html) tutorial. 
- - **Dataset Slicing**: When a dataset is too large for a single piece of data, the data can be sliced for distributed training. Dataset slicing with model parallel is an effective way to reduce graphics memory usage. For details, please refer to the [dataset slicing](https://www.mindspore.cn/docs/en/master/model_train/parallel/dataset_slice.html) tutorial. - - **Host&Device Heterogeneous**: When the number of parameters exceeds the upper limit of Device memory, you can put some operators with large memory usage and small computation amount on the Host side, so that you can utilize the large memory on the Host side and the fast computation on the Device side at the same time, and improve the utilization rate of the device. For details, please refer to [Host&Device Heterogeneous](https://www.mindspore.cn/docs/en/master/model_train/parallel/host_device_training.html) tutorial. - - **Heterogeneous Storage**: large models are currently limited by the size of the graphics memory, making it difficult to train on a single card. In large-scale distributed cluster training, with communication becoming more and more costly, boosting the graphics memory of a single machine and reducing communication can also improve training performance. Heterogeneous storage can copy the parameters or intermediate results that are not needed temporarily to the memory or hard disk on the Host side, and then restore them to the Device side when needed. For details, please refer to [Heterogeneous Storage](https://www.mindspore.cn/docs/en/master/model_train/parallel/memory_offload.html) tutorial. -- **Communication Optimization**: - - **Communication fusion**: communication fusion can merge the communication operators of the same source and target nodes into a single communication process, avoiding the extra overhead of multiple communications. For details, please refer to [Communication Fusion](https://www.mindspore.cn/docs/en/master/model_train/parallel/comm_fusion.html). - - **Communication Subgraph Extraction and Reuse**: By extracting communication subgraphs for communication operators and replacing the original communication operators, we can reduce the communication time consumption and also reduce the model compilation time. For details, please refer to [Communication Subgraph Extraction and Reuse](https://www.mindspore.cn/docs/en/master/model_train/parallel/comm_subgraph.html). - -## Differences in Different Platforms - -In distributed training, different hardware platforms (Ascend, CPU or GPU) support different characters, and users can choose the corresponding distributed startup method, parallel mode and optimization method according to their platforms. - -### Differences in Startup Methods - -- Ascend supports msrun, dynamic cluster, mpirun, and rank table startup. -- GPU supports msrun, dynamic cluster and mpirun startup. -- CPU supports msrun and dynamic cluster startup. - -For the detailed process, refer to [startup methods](https://www.mindspore.cn/docs/en/master/model_train/parallel/startup_method.html). - -### Differences in Parallel Methods - -- Ascend and GPUs support all methods of parallel, including data parallel, semi-automatic parallel, automatic parallel, and more. -- CPU only supports data parallel. 
- -For the detailed process, refer to [data parallel](https://www.mindspore.cn/docs/en/master/model_train/parallel/data_parallel.html), [semi-automatic parallel](https://www.mindspore.cn/docs/en/master/model_train/parallel/semi_auto_parallel.html), [auto-parallel](https://www.mindspore.cn/docs/en/master/model_train/parallel/auto_parallel.html). - -### Differences in Optimization Feature Support - -- Ascend supports all optimization features. -- GPU support optimization features other than communication subgraph extraction and multiplexing. -- CPU does not support optimization features. - -For the detailed process, refer to [optimization methods](https://www.mindspore.cn/docs/en/master/model_train/parallel/optimize_technique.html). diff --git a/docs/mindspore/source_en/model_train/parallel/pipeline_parallel.md b/docs/mindspore/source_en/model_train/parallel/pipeline_parallel.md index 56e0d2a994..2bf6103bcc 100644 --- a/docs/mindspore/source_en/model_train/parallel/pipeline_parallel.md +++ b/docs/mindspore/source_en/model_train/parallel/pipeline_parallel.md @@ -10,13 +10,16 @@ In recent years, the scale of neural networks has increased exponentially. Limit Related interfaces: -1. `mindspore.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, pipeline_stages=NUM, pipeline_result_broadcast=True)`: Set semi-automatic parallel mode and set `pipeline_stages` to indicate that the total number of stages is NUM and call it before initializing the network. `pipeline_result_broadcast`: A switch that broadcast the last stage result to all other stage in pipeline parallel inference. +1. `AutoParallel(network, parallel_mode="semi_auto")`: Encapsulates the specified parallel mode via static graph parallelism, where `network` is the top-level `Cell` or function to be encapsulated, and `parallel_mode` takes the value `semi_auto`, indicating a semi-automatic parallel mode. The return type of this interface is `Cell`. -2. `nn.PipelineCell(loss_cell, micro_size)`: pipeline parallelism requires wrapping a layer of `PipelineCell` around the LossCell and specifying the size of the MicroBatch. In order to improve machine utilization, MindSpore slices the MiniBatch into finer-grained MicroBatches, and the final loss is the sum of the loss values computed by all MicroBatches, where the size of the MicroBatch must be greater than or equal to the number of stages. +2. `AutoParallel.pipeline(stages=1, output_broadcast=False, interleave=False, scheduler='1f1b')`:Configures pipeline parallelism settings. `stages` specifies the total number of partitions for pipeline parallelism. `output_broadcast` determines whether to broadcast the output of the final pipeline stage to all other stages during inference. `interleave` shows that whether to enable interleaving scheduling.`scheduler` defines the pipeline scheduling strategy. Supported values: `gpipe` and `1f1b`. -3. `nn.PipelineGradReducer(parameters)`: pipeline parallelism requires using `PipelineGradReducer` for gradient reduction. Because the output of pipeline parallelism is derived by the addition of several micro-batch outputs, as the gradient do. +3. `mindspore.parallel.Pipeline(network, micro_size=1, stage_config={"cell1":0, "cell2":1})`:Pipeline parallelism requires wrapping the network with an additional layer of `Pipeline`, `micro_size` specifies the number of MicroBatch,which are finer-grained splits of a MiniBatch to improve hardware utilization. The final loss is the accumulation of losses from all MicroBatches. 
`stage_config` indicates the stage assignment for each Cell in the network. `micro_size` must be greater than or equal to the number of `stages`. + +4. `mindspore.parallel.PipelineGradReducer(parameters, scale_sense=1.0, opt_shard=None)`:pipeline parallelism requires using `PipelineGradReducer` for gradient reduction. Because the output of pipeline parallelism is derived by the addition of several micro-batch outputs, as the gradient do. + +5. `mindspore.parallel.sync_pipeline_shared_parameters(net)`: Synchronize pipeline parallel stage shared parameters. -4. `mindspore.parallel.sync_pipeline_shared_parameters(net)`: Synchronize pipeline parallel stage shared parameters. ## Basic Principle @@ -64,411 +67,4 @@ MindSpore has made memory optimization based on Megatron LM interleaved pipeline *Figure 5: MindSpore Scheduler of Interleaved Pipeline* -## Training Operation Practices - -The following is an illustration of pipeline parallel operation using Ascend or GPU single-machine 8-card as an example: - -### Sample Code Description - -> Download the complete sample code: [distributed_pipeline_parallel](https://gitee.com/mindspore/docs/tree/master/docs/sample_code/distributed_pipeline_parallel). - -The directory structure is as follows: - -```text -└─ sample_code - ├─ distributed_pipeline_parallel - ├── distributed_pipeline_parallel.py - └── run.sh - ... -``` - -`distributed_pipeline_parallel.py` is the script that defines the network structure and training process. `run.sh` is the execution script. - -### Configuring the Distributed Environment - -Specify the run mode, run device, run card number, etc. via the context interface. Unlike single-card scripts, parallel scripts also need to specify the parallel mode `parallel_mode` to be semi-automatic parallel mode and initialize HCCL or NCCL communication via init. In addition, `pipeline_stages=2` should be configured to specify the total number of stages. Not setting `device_target` here automatically specifies the backend hardware device corresponding to the MindSpore package. - -```python -import mindspore as ms -from mindspore.communication import init - -ms.set_context(mode=ms.GRAPH_MODE) -ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL, pipeline_stages=2) -init() -ms.set_seed(1) -``` - -If you need to run interleaved pipeline scheduling, you also need to configure: ` pipeline_config={'pipeline_scheduler ':'1f1b', 'pipeline_interleave': True} `. It should be noted that MindSpore's interleaved pipeline scheduling is still in the improvement stage and currently performs better in the kernel by kernel mode. 
- -```python -import mindspore as ms - -ms.set_auto_parallel_context(pipeline_config={'pipeline_scheduler':'1f1b', 'pipeline_interleave':True}) -``` - -### Loading the Dataset - -In the pipeline parallel scenario, the dataset is loaded in the same way as a single card is loaded, with the following code: - -```python -import os -import mindspore.dataset as ds - -def create_dataset(batch_size): - dataset_path = os.getenv("DATA_PATH") - dataset = ds.MnistDataset(dataset_path) - image_transforms = [ - ds.vision.Rescale(1.0 / 255.0, 0), - ds.vision.Normalize(mean=(0.1307,), std=(0.3081,)), - ds.vision.HWC2CHW() - ] - label_transform = ds.transforms.TypeCast(ms.int32) - dataset = dataset.map(image_transforms, 'image') - dataset = dataset.map(label_transform, 'label') - dataset = dataset.batch(batch_size) - return dataset - -data_set = create_dataset(32) -``` - -### Defining the Network - -The pipeline parallel network structure is basically the same as the single-card network structure, and the difference is the addition of pipeline parallel strategy configuration. Pipeline parallel requires the user to define the parallel strategy by calling the `pipeline_stage` interface to specify the stage on which each layer is to be executed. The granularity of the `pipeline_stage` interface is `Cell`. All `Cells` containing training parameters need to be configured with `pipeline_stage`, and `pipeline_stage` should be configured in the order of network execution, from smallest to largest. If you want to enable interleaved pipeline scheduling, the `pipeline_stage` should be configured according to the non-continuous model layer introduced in [Interleaved Pipeline Scheduler](#interleaved-pipeline-scheduler). After adding `pipeline_stage` configuration based on the single-card model is as follows: - -> - Under pipeline parallelism scenario, when enabling Print/Summary/TensorDump related operators, the operator needs to be used in a Cell with the pipeline_stage attribute. Otherwise, there is a possibility that the operator will not take effect due to pipeline parallel split. -> - Under pipeline parallelism scenario, the output of the network does not support dynamic shapes. - -```python -from mindspore import nn, ops, Parameter -from mindspore.common.initializer import initializer, HeUniform - -import math - -class MatMulCell(nn.Cell): - """ - MatMulCell definition. 
- """ - def __init__(self, param=None, shape=None): - super().__init__() - if shape is None: - shape = [28 * 28, 512] - weight_init = HeUniform(math.sqrt(5)) - self.param = Parameter(initializer(weight_init, shape), name="param") - if param is not None: - self.param = param - self.print = ops.Print() - self.matmul = ops.MatMul() - - def construct(self, x): - out = self.matmul(x, self.param) - self.print("out is:", out) - return out - - -class Network(nn.Cell): - def __init__(self): - super().__init__() - self.flatten = nn.Flatten() - self.layer1 = MatMulCell() - self.relu1 = nn.ReLU() - self.layer2 = nn.Dense(512, 512) - self.relu2 = nn.ReLU() - self.layer3 = nn.Dense(512, 10) - - def construct(self, x): - x = self.flatten(x) - x = self.layer1(x) - x = self.relu1(x) - x = self.layer2(x) - x = self.relu2(x) - logits = self.layer3(x) - return logits - -net = Network() -net.layer1.pipeline_stage = 0 -net.relu1.pipeline_stage = 0 -net.layer2.pipeline_stage = 0 -net.relu2.pipeline_stage = 1 -net.layer3.pipeline_stage = 1 -``` - -To enable interleaved pipeline scheduling, the non-contiguous model layers need to be configured in an interleaved manner as follows: - -```python -net.layer1.pipeline_stage = 0 -net.relu1.pipeline_stage = 1 -net.layer2.pipeline_stage = 0 -net.relu2.pipeline_stage = 1 -net.layer3.pipeline_stage = 1 -``` - -Stage 0 includes layer 0 and layer 2, while stage 1 includes layer 1, layer 3, and layer 4. - -### Training the Network - -In this step, we need to define the loss function, the optimizer, and the training process, and unlike the single-card model, two interfaces need to be called in this section to configure the pipeline parallel: - -- First define the LossCell. In this case the `nn.WithLossCell` interface is called to encapsulate the network and loss functions. -- Finally, wrap the LossCell with `nn.PipelineCell`, and specify the size of MicroBatch. For detailed information, refer to the related interfaces in the overview. - -Besides, the interface `nn.PipelineGradReducer` is needed to handle gradient of pipeline parallelism, the first parameter of this interface is the network parameter to be updated. - -```python -import mindspore as ms -from mindspore import nn, ops - -optimizer = nn.SGD(net.trainable_params(), 1e-2) -loss_fn = nn.CrossEntropyLoss() -net_with_loss = nn.PipelineCell(nn.WithLossCell(net, loss_fn), 4) -net_with_loss.set_train() - -def forward_fn(inputs, target): - loss = net_with_loss(inputs, target) - return loss - -grad_fn = ops.value_and_grad(forward_fn, None, optimizer.parameters) -pp_grad_reducer = nn.PipelineGradReducer(optimizer.parameters) - -@ms.jit -def train_one_step(inputs, target): - loss, grads = grad_fn(inputs, target) - grads = pp_grad_reducer(grads) - optimizer(grads) - return loss, grads - -for epoch in range(10): - i = 0 - for data, label in data_set: - loss, grads = train_one_step(data, label) - if i % 10 == 0: - print("epoch: %s, step: %s, loss is %s" % (epoch, i, loss)) - i += 1 -``` - -> - Currently pipeline parallel does not support the automatic mixed precision. -> - Pipeline parallel training is more suitable to use `model.train` approach, because the TrainOneStep logic under pipeline parallelism is complex, while `model.train` internally encapsulates the TrainOneStepCell for pipeline parallel, which is much easier to use. - -### Running the Single-host with 8 Devices Script - -Next, the corresponding scripts are invoked by commands. 
As an example, the 8-card distributed training script uses the `mpirun` startup method for distributed training: - -```bash -bash run.sh -``` - -After training, the log files are saved to the `log_output` directory, where part of the file directory structure is as follows: - -```text -└─ log_output - └─ 1 - ├─ rank.0 - | └─ stdout - ├─ rank.1 - | └─ stdout -... -``` - -The results are saved in `log_output/1/rank.*/stdout`, and the example is as below: - -```text -epoch: 0 step: 0, loss is 9.137518 -epoch: 0 step: 10, loss is 8.826559 -epoch: 0 step: 20, loss is 8.675843 -epoch: 0 step: 30, loss is 8.307994 -epoch: 0 step: 40, loss is 7.856993 -epoch: 0 step: 50, loss is 7.0662785 -... -``` - -The results of operator `Print` is: - -```text -out is: -Tensor(shape=[8, 512], dtype=Float32, value= -[[ 4.61914062e-01 5.78613281e-01 1.34995094e-01 ... 8.54492188e-02 7.91992188e-01 2.13378906e-01] -... -[ 4.89746094e-01 3.56689453e-01 -4.90966797e-01 ... -3.30078125e-e01 -2.38525391e-01 7.33398438e-01]]) -``` - -Other startup methods such as dynamic cluster and `rank table` startup can be found in [startup methods](https://www.mindspore.cn/docs/en/master/model_train/parallel/startup_method.html). - -## Inference Operation Practices - -The following is an illustration of pipeline parallel inference operation using Ascend or GPU single-machine 8-card as an example: - -### Sample Code Description - -> Download the complete sample code: [distributed_pipeline_parallel](https://gitee.com/mindspore/docs/tree/master/docs/sample_code/distributed_pipeline_parallel). - -The directory structure is as follows: - -```text - -└─ sample_code - ├─ distributed_pipeline_parallel - ├── distributed_pipeline_parallel_inference.py - └── run_inference.sh - ... - -``` - -`distributed_pipeline_parallel_inference.py` is the script that defines the network structure and inference process. `run_inference.sh` is the execution script. - -### Configuring the Distributed Environment - -Specify the run mode, run device, run card number, etc. via the context interface. Unlike single-card scripts, parallel scripts also need to specify the parallel mode `parallel_mode` to be semi-automatic parallel mode and initialize HCCL or NCCL communication via init. In addition, `pipeline_stages=4` should be configured to specify the total number of stages. Not setting `device_target` here automatically specifies the backend hardware device corresponding to the MindSpore package. `pipeline_result_broadcast=True` specifies broadcast last stage inference to other stages. It is useful during auto-regression inference. - -```python - -import mindspore as ms -from mindspore.communication import init - -ms.set_context(mode=ms.GRAPH_MODE) -ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL, dataset_strategy="full_batch", - pipeline_stages=4, pipeline_result_broadcast=True) -init() -ms.set_seed(1) - -``` - -### Defining the Network - -The pipeline parallel network structure is basically the same as the single-card network structure, and the difference is the addition of pipeline parallel strategy configuration. Pipeline parallel requires the user to define the parallel strategy by calling the `pipeline_stage` interface to specify the stage on which each layer is to be executed. The granularity of the `pipeline_stage` interface is `Cell`. All `Cells` containing training parameters need to be configured with `pipeline_stage`, and `pipeline_stage` should be configured in the order of network execution, from smallest to largest. 
The configuration after adding `pipeline_stage` to the single-card model is as follows:

```python

import numpy as np
import mindspore as ms
from mindspore import lazy_inline, nn, ops, Tensor, Parameter, sync_pipeline_shared_parameters

class VocabEmbedding(nn.Cell):
    """Vocab Embedding"""
    def __init__(self, vocab_size, embedding_size):
        super().__init__()
        self.embedding_table = Parameter(Tensor(np.ones([vocab_size, embedding_size]), ms.float32),
                                         name='embedding_table')
        self.gather = ops.Gather()

    def construct(self, x):
        output = self.gather(self.embedding_table, x, 0)
        output = output.squeeze(1)
        return output, self.embedding_table.value()


class Head(nn.Cell):
    def __init__(self):
        super().__init__()
        self.matmul = ops.MatMul(transpose_b=True)

    def construct(self, state, embed):
        return self.matmul(state, embed)


class Network(nn.Cell):
    """Network"""
    @lazy_inline
    def __init__(self):
        super().__init__()
        self.word_embedding = VocabEmbedding(vocab_size=32, embedding_size=32)
        self.layer1 = nn.Dense(32, 32)
        self.layer2 = nn.Dense(32, 32)
        self.head = Head()

    def construct(self, x):
        x, embed = self.word_embedding(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.head(x, embed)
        return x

# Define network and set pipeline stage
net = Network()
net.word_embedding.pipeline_stage = 0
net.layer1.pipeline_stage = 1
net.layer2.pipeline_stage = 2
net.head.pipeline_stage = 3

```

### Inferring the Network

Wrap the network with `PipelineCellInference` and specify the number of MicroBatches. `PipelineCellInference` splits the input into several MicroBatches, executes the network on each of them, and finally concatenates the results along the batch axis via the `ops.Concat` operator.

In the previous step, the parameter `embed` is shared by the `self.word_embedding` and `self.head` layers, and these two layers are split into different stages. Before inference, execute `inference_network.compile()` and `sync_pipeline_shared_parameters(inference_network)`; the framework will then synchronize the shared parameters between stages automatically.

```python

from mindspore import nn, ops

class PipelineCellInference(nn.Cell):
    """Pipeline Cell Inference wrapper"""
    def __init__(self, network, micro_batch_num):
        super().__init__()
        self.network = network
        self.micro_batch_num = micro_batch_num
        self.concat = ops.Concat()

    def construct(self, x):
        """Apply the pipeline inference"""
        ret = ()
        for i in range(self.micro_batch_num):
            micro_batch_size = x.shape[0] // self.micro_batch_num
            start = micro_batch_size * i
            end = micro_batch_size * (i + 1)

            micro_input = x[start:end]
            micro_output = self.network(micro_input)
            ret = ret + (micro_output,)

        ret = self.concat(ret)
        return ret

inference_network = PipelineCellInference(network=net, micro_batch_num=4)
inference_network.set_train(False)

# Compile and synchronize shared parameter.
input_ids = Tensor(np.random.randint(low=0, high=32, size=(8, 1)), ms.int32)
inference_network.compile(input_ids)
sync_pipeline_shared_parameters(inference_network)

# Execute the inference network
logits = inference_network(input_ids)
print(logits.asnumpy())

```

### Running the Single-host with 8 Devices Script

Next, the corresponding scripts are invoked by commands.
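The exact contents of `run_inference.sh` live in the sample repository. As a sketch only (the port, worker counts, and log directory are assumptions based on the log layout shown below, not the authoritative script), an `msrun`-based launcher for 8 devices might look like this:

```bash
#!/bin/bash
# Sketch of an msrun-based launcher; see the sample repository for the real run_inference.sh.
# msrun spawns the local workers plus a scheduler and writes scheduler.log / worker_*.log to --log_dir.
msrun --worker_num=8 --local_worker_num=8 --master_port=8118 \
    --log_dir=pipeline_inference_logs --join=True \
    distributed_pipeline_parallel_inference.py
```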
As an example, the 8-card distributed training script uses the `msrun` startup method for distributed training: - -```bash - -bash run_inference.sh - -``` - -After training, the log files are saved to the `log_output` directory, where part of the file directory structure is as follows: - -```text - -└─ pipeline_inference_logs -   ├── scheduler.log -   ├── worker_0.log -   ├── worker_1.log -   ├── worker_2.log -... - -``` - -The results are saved in `pipeline_inference_logs/worker_0.log`, and the example is as below: - -```text - -[[0.01181556 0.01181556 0.01181556 0.01181556 0.01181556 0.01181556 0.01181556 - 0.01181556 0.01181556 0.01181556 0.01181556 0.01181556 0.01181556 0.01181556 - 0.01181556 0.01181556 0.01181556 0.01181556 0.01181556 0.01181556 0.01181556 - 0.01181556 0.01181556 0.01181556 0.01181556 0.01181556 0.01181556 0.01181556 - 0.01181556 0.01181556 0.01181556 0.01181556 0.01181556 0.01181556 0.01181556 - 0.01181556 0.01181556] - ...] -``` diff --git a/docs/mindspore/source_en/model_train/parallel/semi_auto_parallel.rst b/docs/mindspore/source_en/model_train/parallel/semi_auto_parallel.rst deleted file mode 100644 index 6df2eff578..0000000000 --- a/docs/mindspore/source_en/model_train/parallel/semi_auto_parallel.rst +++ /dev/null @@ -1,22 +0,0 @@ -Semi-automatic Parallel -=========================== - -.. image:: https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg - :target: https://gitee.com/mindspore/docs/blob/master/docs/mindspore/source_en/model_train/parallel/semi_auto_parallel.rst - :alt: View Source on Gitee - -.. toctree:: - :maxdepth: 1 - :hidden: - - operator_parallel - advanced_operator_parallel - optimizer_parallel - pipeline_parallel - -Semi-automatic parallel supports the automatic mixing of multiple parallel modes, including: - -- `Operator-level parallel `_: refers to slicing the input tensor and model parameters into multiple devices for computation on an operator basis to improve overall speed. -- `Higher-order Operator-level Parallelism `_: refers to operator-level parallelism that allows customized device layout with tensor layout for more complex sharding logic. -- `Optimizer parallel `_: reduces redundant computations on multiple devices for the same weight updates, spreading the computation over multiple devices. -- `Pipeline parallel `_: means that the model is sliced by layer, with each device processing only a certain part of the model. diff --git a/docs/mindspore/source_zh_cn/model_train/index.rst b/docs/mindspore/source_zh_cn/model_train/index.rst index 8abb8c706b..2edb594b24 100644 --- a/docs/mindspore/source_zh_cn/model_train/index.rst +++ b/docs/mindspore/source_zh_cn/model_train/index.rst @@ -54,10 +54,11 @@ :hidden: :caption: 分布式并行 - parallel/overview parallel/startup_method parallel/data_parallel - parallel/semi_auto_parallel + parallel/operator_parallel + parallel/optimizer_parallel + parallel/pipeline_parallel parallel/auto_parallel parallel/manual_parallel parallel/parameter_server_training diff --git a/docs/mindspore/source_zh_cn/model_train/parallel/operator_parallel.md b/docs/mindspore/source_zh_cn/model_train/parallel/operator_parallel.md index edce5aea1a..bc1e536d24 100644 --- a/docs/mindspore/source_zh_cn/model_train/parallel/operator_parallel.md +++ b/docs/mindspore/source_zh_cn/model_train/parallel/operator_parallel.md @@ -14,7 +14,7 @@ 相关接口: -1. `mindspore.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)`:设置半自动并行模式,必须在初始化网络之前调用。 +1. 
`AutoParallel(network, parallel_mode="semi_auto")`:通过静态图并行封装指定并行模式,其中`network`是待封装的顶层`Cell`或函数,`parallel_mode`取值`semi_auto`,表示半自动并行模式。该接口返回类型是`Cell`。 2. `mindspore.ops.Primitive.shard()`:指定算子切分策略,详细案例请参考本章的[基本原理](#基本原理)。 @@ -57,9 +57,7 @@ Tensor Redistribution用于处理不同Tensor Layout之间的转换,它能在 ```python import mindspore.nn as nn from mindspore import ops -import mindspore as ms - -ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL, device_num=4) +from mindspore.parallel.auto_parallel import AutoParallel class DenseMatMulNet(nn.Cell): def __init__(self): @@ -70,6 +68,9 @@ class DenseMatMulNet(nn.Cell): y = self.matmul1(x, w) z = self.matmul2(y, v) return z + +net = DenseMatMulNet() +paralell_net = AutoParallel(net, parallel_mode='semi_auto') ``` 在以上例子中,用户在4个卡上计算两个连续的二维矩阵乘:`Z = (X * W) * V` 。第一个矩阵乘`Y = X * W`,用户想把X按行切4份(即数据并行);而第二个矩阵乘`Z = Y * V`,用户想把V按列切4份(即模型并行): @@ -78,175 +79,79 @@ class DenseMatMulNet(nn.Cell): ![image](images/operator_parallel_image_4_zh.png) -## 操作实践 +# 高阶算子级并行 -下面以Ascend或者GPU单机8卡为例,进行算子级并行操作说明: +## 概述 -### 样例代码说明 +[算子级并行](https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/operator_parallel.html) 是大模型训练推理中常用的并行技术,它可以将张量切分到多卡上,有效降低单卡上的显存。 -> 下载完整的样例代码:[distributed_operator_parallel](https://gitee.com/mindspore/docs/tree/master/docs/sample_code/distributed_operator_parallel)。 +在MindSpore中,算子级并行的配置是通过mindspore.ops.Primitive.shard()接口实现的。该接口通过tuple描述每个输入张量的切分方式,适用于大多数场景,配置过程较为简单。然而,这种切分方式仅描述了张量的切分逻辑,却隐藏了张量在设备rank上的具体排布。因此,它在表达张量切分与设备排布之间的映射关系时存在局限性,无法满足一些复杂场景的需求。 -目录结构如下: +为了应对这些复杂场景,本章节将介绍一种开放设备排布描述的高阶算子级并行配置方法。 -```text -└─ sample_code - ├─ distributed_operator_parallel - ├── distributed_operator_parallel.py - └── run.sh - ... -``` +> 高阶算子级并行模型支持的硬件平台包括Ascend、GPU,需要在Graph模式下运行。 -其中,`distributed_operator_parallel.py`是定义网络结构和训练过程的脚本。`run.sh`是执行脚本。 +## 背景 -### 配置分布式环境 +[算子级并行](https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/operator_parallel.html) 章节中介绍了MindSpore对张量的基本切分逻辑,但不能表达出所有的切分场景。例如,对于一个二维张量 "[[a0, a1, a2, a3], [a4, a5, a6, a7]]",其张量排布如下图所示: -通过context接口,用户可以指定运行模式、运行设备和运行卡号等参数。与单卡脚本不同,并行脚本需要额外设置并行模式`parallel_mode`为半自动并行模式,并通过init初始化HCCL或NCCL通信。 +![image](images/advanced_operator_parallel_view1.PNG) -此外,在Ascend硬件平台上,为确保通信有足够的设备内存,需要预留部分内存,则可通过设置`max_size`参数限制模型可使用的最大设备内存;在GPU上则不需要预留。此处,若不设置`device_target`,则会自动指定为MindSpore包对应的后端硬件设备。 +*图:二维张量排布示意* -```python -import mindspore as ms -from mindspore.communication import init - -ms.set_context(mode=ms.GRAPH_MODE) -ms.runtime.set_memory(max_size="28GB") -ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL) -init() -ms.set_seed(1) -``` +由图可知,张量的0轴,如"[a0, a1, a2, a3]"切分到了不连续的卡"[Rank0, Rank4, Rank2, Rank6]"上,而该张量按照strategy=(2, 4)切分,排布应该如下图所示: -### 数据集加载 +![image](images/advanced_operator_parallel_view2.PNG) -在算子级并行场景下,数据集加载方式与单卡加载方式一致,代码如下: +*图:二维张量按照切分策略排布示意* -```python -import os -import mindspore.dataset as ds - -def create_dataset(batch_size): - dataset_path = os.getenv("DATA_PATH") - dataset = ds.MnistDataset(dataset_path) - image_transforms = [ - ds.vision.Rescale(1.0 / 255.0, 0), - ds.vision.Normalize(mean=(0.1307,), std=(0.3081,)), - ds.vision.HWC2CHW() - ] - label_transform = ds.transforms.TypeCast(ms.int32) - dataset = dataset.map(image_transforms, 'image') - dataset = dataset.map(label_transform, 'label') - dataset = dataset.batch(batch_size) - return dataset - -data_set = create_dataset(32) -``` +因此,直接对算子的输入/输出张量按照切分数目进行切分,无法表达出一些有特殊诉求的切分场景。 -### 定义网络 +## 接口配置 -在当前半自动并行模式下,需要用ops算子(Primitive)定义网络。用户可以在单卡网络的基础上手动配置一些算子的切分策略,例如配置策略后的网络结构为: 
+为了表达出如上述场景下的切分,[shard](https://www.mindspore.cn/docs/zh-CN/master/api_python/mindspore/mindspore.shard.html) 接口进行了功能扩展。 -```python -import mindspore as ms -from mindspore import nn, ops +入参in_strategy和out_strategy都额外接收新的数量类型——tuple(Layout)。其中[Layout](https://www.mindspore.cn/docs/zh-CN/master/api_python/mindspore/mindspore.Layout.html) 通过设备矩阵进行初始化,并同时要求给设备矩阵的每个轴取一个别名。例如:"layout = Layout((8, 4, 4), name = ("dp", "sp", "mp"))"表示该设备共有128张卡,按照(8, 4, 4)的形状进行排列,并为每个轴分别取了别名"dp"、"sp"、"mp"。 -class Network(nn.Cell): - def __init__(self): - super().__init__() - self.flatten = ops.Flatten() - self.fc1_weight = ms.Parameter(initializer("normal", [28*28, 512], ms.float32)) - self.fc2_weight = ms.Parameter(initializer("normal", [512, 512], ms.float32)) - self.fc3_weight = ms.Parameter(initializer("normal", [512, 10], ms.float32)) - self.matmul1 = ops.MatMul() - self.relu1 = ops.ReLU() - self.matmul2 = ops.MatMul() - self.relu2 = ops.ReLU() - self.matmul3 = ops.MatMul() - - def construct(self, x): - x = self.flatten(x) - x = self.matmul1(x, self.fc1_weight) - x = self.relu1(x) - x = self.matmul2(x, self.fc2_weight) - x = self.relu2(x) - logits = self.matmul3(x, self.fc3_weight) - return logits - -net = Network() -net.matmul1.shard(((2, 4), (4, 1))) -net.relu1.shard(((4, 1),)) -net.matmul2.shard(((1, 8), (8, 1))) -net.relu2.shard(((8, 1),)) -``` +在调用Layout时,通过传入这些轴的别名,每个张量根据其形状(shape)决定每个维度映射到设备矩阵的哪个轴,以及对应的切分份数。例如: -以上网络的`ops.MatMul()`和`ops.ReLU()`算子都配置了切分策略,以`net.matmul1.shard(((2, 4), (4, 1)))`为例,它的切分策略为:第一个输入的行切分2份,列切分4份;第二个输入的行切分4份;对于`net.relu2.shard(((8, 1),))`,它的切分策略为:第一个输入的行切分8份。需要注意的是,由于此处的两个`ops.ReLU()`的切分策略不同,所以要分别定义两次。 +- "dp"表示在设备排布的最高维度的8个设备内切分为8份; +- "sp"表示在设备排布的中间维度的4个设备内切分为4份; +- "mp"表示在设备排布的最低维度的4个设备内切分为4份。 -### 训练网络 +特别地,张量的一个维度可以映射到设备的多个维度,以表达在一个维度进行多次切分。 -在这一步,我们需要定义损失函数、优化器以及训练过程,这部分与单卡写法一致: +针对上述例子中"[[a0, a1, a2, a3], [a4, a5, a6, a7]]"切分到不连续卡上的情况,可以通过如下Layout表达: ```python -import mindspore as ms -from mindspore import nn - -optimizer = nn.SGD(net.trainable_params(), 1e-2) -loss_fn = nn.CrossEntropyLoss() - -def forward_fn(data, target): - logits = net(data) - loss = loss_fn(logits, target) - return loss, logits - -grad_fn = ms.value_and_grad(forward_fn, None, net.trainable_params(), has_aux=True) - -@ms.jit -def train_step(inputs, targets): - (loss_value, _), grads = grad_fn(inputs, targets) - optimizer(grads) - return loss_value - -for epoch in range(10): - i = 0 - for image, label in data_set: - loss_output = train_step(image, label) - if i % 10 == 0: - print("epoch: %s, step: %s, loss is %s" % (epoch, i, loss_output)) - i += 1 +from mindspore import Layout +a = [[a0, a1, a2, a3], [a4, a5, a6, a7]] +layout = Layout((2, 2, 2), name = ("dp", "sp", "mp")) +a_strategy = layout("mp", ("sp", "dp")) ``` -### 运行单机8卡脚本 +可以看到,在张量a的"[a0, a1, a2, a3]"上进行了两次切分,从而切分到了设备的"sp"与"mp"两个轴上,这样出来的结果才是: -接下来通过命令调用对应的脚本,以8卡的分布式训练脚本为例,使用`mpirun`启动方式进行分布式训练: +![image](images/advanced_operator_parallel_view1.PNG) -```bash -bash run.sh -``` - -训练完后,日志文件保存到`log_output`目录下,其中部分文件目录结构如下: +下面,通过一个具体的例子,演示用户在8个卡上计算二维矩阵乘:`Y = (X * W)` 。其中,设备按照`2 * 2 * 2`进行组织;X的切分与上述的张量a切分一致。代码如下所示: -```text -└─ log_output - └─ 1 - ├─ rank.0 - | └─ stdout - ├─ rank.1 - | └─ stdout -... 
-``` +```python +import mindspore.nn as nn +from mindspore import ops, Layout +from mindspore.parallel.auto_parallel import AutoParallel -关于Loss部分结果保存在`log_output/1/rank.*/stdout`中,示例如下: - -```text -epoch: 0, step: 0, loss is 2.3026192 -epoch: 0, step: 10, loss is 2.2928686 -epoch: 0, step: 20, loss is 2.279024 -epoch: 0, step: 30, loss is 2.2548661 -epoch: 0, step: 40, loss is 2.192434 -epoch: 0, step: 50, loss is 2.0514572 -epoch: 0, step: 60, loss is 1.7082529 -epoch: 0, step: 70, loss is 1.1759918 -epoch: 0, step: 80, loss is 0.94476485 -epoch: 0, step: 90, loss is 0.73854053 -epoch: 0, step: 100, loss is 0.71934 -... -``` +class DenseMatMulNet(nn.Cell): + def __init__(self): + super(DenseMatMulNet, self).__init__() + layout = Layout((2, 2, 2), name = ("dp", "sp", "mp")) + in_strategy = (layout("mp", ("sp", "dp")), layout(("sp", "dp"), "None")) + out_strategy = (layout(("mp", "sp", "dp"), "None"), ) + self.matmul1 = ops.MatMul().shard(in_strategy, out_strategy) + def construct(self, x, w): + y = self.matmul1(x, w) + return y -其他启动方式如动态组网、`rank table`的启动可参考[启动方式](https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/startup_method.html)。 +net = DenseMatMulNet() +paralell_net = AutoParallel(net, parallel_mode='semi_auto') +``` \ No newline at end of file diff --git a/docs/mindspore/source_zh_cn/model_train/parallel/optimizer_parallel.md b/docs/mindspore/source_zh_cn/model_train/parallel/optimizer_parallel.md index 602d13fc49..841edb3b54 100644 --- a/docs/mindspore/source_zh_cn/model_train/parallel/optimizer_parallel.md +++ b/docs/mindspore/source_zh_cn/model_train/parallel/optimizer_parallel.md @@ -18,9 +18,11 @@ 相关接口: -1. `mindspore.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, enable_parallel_optimizer=True)`:设置半自动并行模式,且开启优化器并行。该接口必须在初始化网络之前调用。`enable_parallel_optimizer`开启后,默认对所有**占用内存不小于64KB**的参数进行优化器切分,请参考本章的[高级接口](#高级接口)。 +1. `AutoParallel(network, parallel_mode="semi_auto")`:通过静态图并行封装指定并行模式,其中`network`是待封装的顶层`Cell`或函数,`parallel_mode`取值`semi_auto`,表示半自动并行模式。该接口返回类型是`Cell`。 -2. `Cell.set_comm_fusion(fusion_type=NUM)`:在自动/半自动模式下,每个参数都会产生一个对应的AllGather操作和ReduceScatter操作。这些通信算子是自动并行框架自动插入的。然而,随着参数量增多,对应的通信算子也会增多,通信操作中的算子调度和启动都会产生更多的开销。因此,可以通过`Cell`提供的`set_comm_fusion`方法,手动对每个`Cell`内参数对应的AllGather和ReduceScatter操作配置融合标记NUM,以提高通信效率。MindSpore将融合相同NUM参数对应的通信算子,以减少通信开销。 +2. `AutoParallel.hsdp(shard_size=-1, threshold=64, optimizer_level="level1")`:通过该接口设置优化器并行的配置,并开启优化器并行。其中`shard_size`指定优化器权重切分通信域的大小。`threshold`表示切分参数时,要求目标参数所占内存的最小值。当目标参数小于该值时,将不会被切分。 `optimizer_level`是优化器切分级别,当级别为`level1`时,对权重和优化器状态进行切分;当级别为`level2`时,对权重、优化器状态和梯度进行切分;当级别为`level3`时,除了对权重、优化器状态和梯度进行切分外,在反向传播前,还会对权重进行all gather通信,以释放前向传播allgather占用的内存。 + +3. 
`Cell.set_comm_fusion(fusion_type=NUM)`:在自动/半自动模式下,每个参数都会产生一个对应的AllGather操作和ReduceScatter操作。这些通信算子是自动并行框架自动插入的。然而,随着参数量增多,对应的通信算子也会增多,通信操作中的算子调度和启动都会产生更多的开销。因此,可以通过`Cell`提供的`set_comm_fusion`方法,手动对每个`Cell`内参数对应的AllGather和ReduceScatter操作配置融合标记NUM,以提高通信效率。MindSpore将融合相同NUM参数对应的通信算子,以减少通信开销。 ## 基本原理 @@ -51,209 +53,3 @@ 在实际网络训练的测试验证中,我们发现参数切分带来的内存收益是显著的。尤其是对于大规模网络模型而言,通常选择当下流行的Adaptive Moment estimation (Adam)和Layer-wise Adaptive Moments optimizer for Batching training (LAMB)训练网络,优化器自身的参数量和计算量不容忽视。经过参数分组,网络中的权重参数和优化器中的两份状态参数都减少了N-1/N倍,极大节省了静态内存空间。这为增大单轮迭代样本数量、提升整体训练吞吐量提供了可能,有效解决了大规模网络训练的内存压力。 MindSpore实现的优化器参数切分还具有与算子级并行混合使用的优势。当算子级模型并行参数未切满时,可以继续在数据并行的维度上进行优化器参数切分,增大机器资源的利用率,从而提升端到端性能。 - -## 操作实践 - -下面以Ascend或者GPU单机8卡为例,进行优化器并行操作说明: - -### 样例代码说明 - -> 下载完整的样例代码:[distributed_optimizer_parallel](https://gitee.com/mindspore/docs/tree/master/docs/sample_code/distributed_optimizer_parallel)。 - -目录结构如下: - -```text -└─ sample_code - ├─ distributed_optimizer_parallel - ├── distributed_optimizer_parallel.py - └── run.sh - ... -``` - -其中,`distributed_optimizer_parallel.py`是定义网络结构和训练过程的脚本。`run.sh`是执行脚本。 - -### 配置分布式环境 - -通过context接口指定运行模式、运行设备、运行卡号等。与单卡脚本不同,并行脚本还需指定并行模式`parallel_mode`为半自动并行模式,并通过init初始化HCCL或NCCL通信。此外,还需开启优化器并行,配置`enable_parallel_optimizer=True`。此处未设置`device_target`,会自动指定为MindSpore包对应的后端硬件设备。 - -```python -import mindspore as ms -from mindspore.communication import init - -ms.set_context(mode=ms.GRAPH_MODE) -ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL, enable_parallel_optimizer=True) -init() -ms.set_seed(1) -``` - -### 数据集加载 - -在优化器并行场景下,数据集加载方式与单卡加载方式一致,代码如下: - -```python -import os -import mindspore.dataset as ds - -def create_dataset(batch_size): - """create dataset""" - dataset_path = os.getenv("DATA_PATH") - dataset = ds.MnistDataset(dataset_path) - image_transforms = [ - ds.vision.Rescale(1.0 / 255.0, 0), - ds.vision.Normalize(mean=(0.1307,), std=(0.3081,)), - ds.vision.HWC2CHW() - ] - label_transform = ds.transforms.TypeCast(ms.int32) - dataset = dataset.map(image_transforms, 'image') - dataset = dataset.map(label_transform, 'label') - dataset = dataset.batch(batch_size) - return dataset - -data_set = create_dataset(32) -``` - -### 定义网络 - -优化器并行网络结构与单卡网络结构基本一致,区别在于增加了通信算子融合的配置: - -```python -from mindspore import nn - -class Network(nn.Cell): - def __init__(self): - super().__init__() - self.flatten = nn.Flatten() - self.layer1 = nn.Dense(28*28, 512) - self.layer2 = nn.Dense(512, 512) - self.layer3 = nn.Dense(512, 10) - self.relu = nn.ReLU() - - def construct(self, x): - x = self.flatten(x) - x = self.layer1(x) - x = self.relu(x) - x = self.layer2(x) - x = self.relu(x) - logits = self.layer3(x) - return logits - -net = Network() -net.layer1.set_comm_fusion(0) -net.layer2.set_comm_fusion(1) -net.layer3.set_comm_fusion(2) -``` - -> 这里为了减少通信成本,为不同层配置了通信融合,详细可以参考[通信算子融合](https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/comm_fusion.html)。 - -### 训练网络 - -在这一步,我们需要定义损失函数、优化器以及训练过程,这部分与单卡写法一致: - -```python -import mindspore as ms -from mindspore import nn - -optimizer = nn.SGD(net.trainable_params(), 1e-2) -loss_fn = nn.CrossEntropyLoss() - -def forward_fn(data, target): - logits = net(data) - loss = loss_fn(logits, target) - return loss, logits - -grad_fn = ms.value_and_grad(forward_fn, None, net.trainable_params(), has_aux=True) - -@ms.jit -def train_step(inputs, targets): - (loss_value, _), grads = grad_fn(inputs, targets) - optimizer(grads) - return loss_value - -for epoch in range(10): - i = 0 - for image, label in data_set: - 
loss_output = train_step(image, label) - if i % 10 == 0: - print("epoch: %s, step: %s, loss is %s" % (epoch, i, loss_output)) - i += 1 -``` - -### 运行单机8卡脚本 - -接下来通过命令调用对应的脚本,以8卡的分布式训练脚本为例,使用`mpirun`启动方式进行分布式训练: - -```bash -bash run.sh -``` - -训练完后,日志文件保存到`log_output`目录下,其中部分文件目录结构如下: - -```text -└─ log_output - └─ 1 - ├─ rank.0 - | └─ stdout - ├─ rank.1 - | └─ stdout -... -``` - -结果保存在`log_output/1/rank.*/stdout`中,示例如下: - -```text -epoch: 0, step: 0, loss is 2.3024087 -epoch: 0, step: 10, loss is 2.2921634 -epoch: 0, step: 20, loss is 2.278274 -epoch: 0, step: 30, loss is 2.2537143 -epoch: 0, step: 40, loss is 2.1638 -epoch: 0, step: 50, loss is 1.984318 -epoch: 0, step: 60, loss is 1.6061916 -epoch: 0, step: 70, loss is 1.20966 -epoch: 0, step: 80, loss is 0.98156196 -epoch: 0, step: 90, loss is 0.77229893 -epoch: 0, step: 100, loss is 0.6854114 -... -``` - -其他启动方式如动态组网、`rank table`的启动可参考[启动方式](https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/startup_method.html)。 - -## 高级接口 - -1. `parallel_optimizer_config`:优化器并行特性还提供了配置字典`parallel_optimizer_config={}`。通过在`mindspore.set_auto_parallel_context()`中配置不同的键值,可以达到不同的效果: - - - `gradient_accumulation_shard`:如果为True,则累加梯度变量将在数据并行维度上进行分片。在累加梯度时,每个累加迭代中将会引入额外的通信(ReduceScatter)以保证计算的一致性,但节省了大量的计算设备内存(例如GPU显存),因此可以使模型以更大的批量进行训练。仅当模型在流水线并行训练或梯度累加中设置此配置,并且具有数据并行维度时,此配置才会有效。默认值为True。 - - ```python - import mindspore as ms - ms.set_auto_parallel_context(parallel_optimizer_config={"gradient_accumulation_shard": True}, enable_parallel_optimizer=True) - ``` - - - `parallel_optimizer_threshold(int)`:该值表示切分参数时,要求目标参数所占内存的最小值。当目标参数小于该值时,将不会被切分。默认值为64,单位为KB。 - - ```python - import numpy as np - import mindspore as ms - param = ms.Parameter(ms.Tensor(np.ones((10, 2)), dtype=ms.float32), name='weight1') - # float32类型占用内存4Bytes: - # param_size = np.prod(list(param.shape)) * 4 = (10 * 2) * 4 = 80B < 24KB, 不会被切分 - ms.set_auto_parallel_context(parallel_optimizer_config={"parallel_optimizer_threshold": 24}) - ``` - - - `optimizer_weight_shard_size`:设置指定优化器权重切分通信域的大小。数值范围可以是(0, device_num],若同时开启流水线并行,数值范围则为(0, device_num/stage]。如果参数的数据并行通信域大小不能被 `optimizer_weight_shard_size` 整除,那么指定的优化器权重切分通信域大小就不会生效。默认值为 ``-1`` ,表示优化器权重切片通信域大小是每个参数的数据并行通信域大小。 - - ```python - import mindspore as ms - ms.set_auto_parallel_context(parallel_optimizer_config={"optimizer_weight_shard_size": 2}, enable_parallel_optimizer=True) - ``` - -2. 
`Parameter.parallel_optimizer`:用户还可以通过此接口自定义某些权重是否进行优化器切分,如下所示: - - ```python - import numpy as np - import mindspore as ms - param = ms.Parameter(ms.Tensor(np.ones((10, 2))), name='weight1', parallel_optimizer=True) - - # 设置 parallel_optimizer 属性的另一种方法 - param2 = ms.Parameter(ms.Tensor(np.ones((10, 2))), name='weight2') - param2.parallel_optimizer = False - ``` \ No newline at end of file diff --git a/docs/mindspore/source_zh_cn/model_train/parallel/overview.md b/docs/mindspore/source_zh_cn/model_train/parallel/overview.md deleted file mode 100644 index 59df4de1ff..0000000000 --- a/docs/mindspore/source_zh_cn/model_train/parallel/overview.md +++ /dev/null @@ -1,79 +0,0 @@ -# 分布式并行概述 - -[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/master/docs/mindspore/source_zh_cn/model_train/parallel/overview.md) - -在深度学习中,当数据集和参数量的规模越来越大,训练所需的时间和硬件资源会随之增加,最后会变成制约训练的瓶颈。分布式并行训练,可以降低对内存、计算性能等硬件的需求,是进行训练的重要优化手段。此外,分布式并行对大模型训练和推理有着重要的意义,它为处理大规模数据和复杂模型提供了强大的计算能力和性能优势。 - -要实现分布式并行训练和推理,您可以参考以下指引: - -## 并行模式 - -目前MindSpore可以采取下述的几种并行模式,您可以按需求选择: - -- **数据并行模式**:数据并行模式下,数据集可以在样本维度拆分并下发到不同的卡上。如果您的数据集较大,而模型参数规模能在单卡运算,您可以选择这种并行模型。参考[数据并行](https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/data_parallel.html)教程了解更多信息。 -- **自动并行模式**:融合了数据并行、算子级模型并行的分布式并行模式,可以自动建立代价模型,找到训练时间较短的并行策略,为用户选择合适的并行模式。如果您的数据集和模型参数规模都较大,且希望自动配置并行策略,您可以选择这种并行模型。参考[自动并行](https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/auto_parallel.html)教程了解更多信息。 -- **半自动并行模式**:相较于自动并行,该模式需要用户对算子手动配置切分策略实现并行。如果您的数据集和模型参数规模都较大,且您对模型的结构比较熟悉,知道哪些“关键算子”容易成为计算瓶颈,为“关键算子”配置合适的切分策略可以获得更好的性能,您可以选择这种并行模式。此外该模式还可以手动配置优化器并行和流水线并行。参考[半自动并行](https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/semi_auto_parallel.html)教程了解更多信息。 -- **手动并行模式**:在手动并行模式下,您可以基于通信原语例如`AllReduce`、`AllGather`、`Broadcast`等通信算子进行数据传输,手动实现分布式系统下模型的并行通信。您可以参考[手动并行](https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/manual_parallel.html)教程了解更多信息。 -- **参数服务器模式**:相比于同步的训练方法,参数服务器具有更好的灵活性、可拓展性。您可以参考[参数服务器](https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/parameter_server_training.html)模式教程了解更多信息。 - -## 保存和加载模型 - -模型的保存可以分为合并保存和非合并保存,可以通过`mindspore.save_checkpoint`或者`mindspore.train.CheckpointConfig`中的`integrated_save`参数选择是否合并保存。合并保存模式下,模型参数会自动聚合保存到模型文件中,而非合并保存模式下,每张卡保存各自卡上的参数切片。关于各并行模式下的模型保存可以参考[模型保存](https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/model_saving.html)教程。 - -模型的加载可以分为完整加载和切片加载。若保存的是完整参数的模型文件,则可以直接通过`load_checkpoint`接口加载模型文件。若保存的是多卡下的参数切片文件,则需要考虑加载后的分布式切分策略或集群规模是否有变化。如果分布式切分策略或集群规模不变,则可以通过`load_distributed_checkpoint`接口加载各卡对应的参数切片文件,可以参考[模型加载](https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/model_loading.html)教程。 - -若保存和加载的分布式切分策略或集群卡数改变的情况下,则需要对分布式下的Checkpoint文件进行转换以适配新的分布式切分策略或集群卡数。您可以参考[模型转换](https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/model_transformation.html)了解更多信息。 - -## 故障恢复 - -在分布式并行训练过程中,可能会遇到计算节点的故障或通信中断等问题。MindSpore提供了三种恢复方式以保证训练的稳定性和连续性: - -- **根据完整Checkpoint恢复**:在保存Checkpoint文件前,通过AllGather算子汇聚模型的完整参数,每张卡均保存了完整的模型参数文件,可以直接加载恢复。多个checkpoints副本提高了模型的容错性,但是对于大模型来说,汇聚的过程会导致各种资源开销过大。详细可参考[模型加载](https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/model_loading.html)教程。 -- **动态组网场景下故障恢复**:在动态组网中,若某个进程出现故障,其他进程会进入等待状态,可以通过重新拉起故障进程使得训练任务继续进行(目前仅支持GPU硬件平台)。和其他方式相比,该故障恢复方式无需重启集群。详细可参考[动态组网场景下故障恢复](https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/disaster_recover.html)教程。 -- 
**根据参数切分的冗余信息恢复**:在大模型训练中,根据数据并行的维度所划分的设备,他们的模型参数是相同的。根据这个原理,可以利用这些冗余的参数信息作为备份,在一个节点故障时,利用相同参数的另一节点就可以恢复故障的节点。详细可参考[基于冗余信息的故障恢复](https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/fault_recover.html)教程。 - -## 优化方法 - -如果对性能、吞吐量或规模有要求,或者不知道如何选择并行策略,可以考虑以下优化技术: - -- **并行策略优化**: - - **策略选择**:根据您的模型规模和数据量大小,您可以参考[策略选择](https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/strategy_select.html)教程来选择不同的并行策略,以提高训练效率和资源利用率。 - - **切分技巧**:切分技巧也是实现高效并行计算的关键,在[切分技巧](https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/split_technique.html)教程中,您可以通过具体案例了解到如何应用各种切分技巧来提升效率。 - - **多副本并行**:在现有的单副本模式下,某些底层算子在进行通信的时候,无法同时进行计算,从而导致资源浪费。多副本并行通过对数据按照Batch Size维度进行切分为多个副本,可以使一个副本在通信时,另一副本进行计算操作,提升了资源利用率,详细可参考[多副本并行](https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/multiple_copy.html)教程。 -- **内存优化**: - - **梯度累加**:梯度累加通过在多个MicroBatch上计算梯度并将它们累加起来,然后一次性应用这个累加梯度来更新神经网络的参数。通过这种方法少量设备也能训练大Batch Size,有效减低内存峰值,详细可参考[梯度累加](https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/distributed_gradient_accumulation.html)教程。 - - **重计算**:重计算通过不保存某些正向算子的计算结果,以节省内存空间,在计算反向算子时,需要用到正向结果再重新计算正向算子。详细可参考[重计算](https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/recompute.html)教程。 - - **数据集切分**:数据集单个数据过大的时候,可以对数据进行切分,进行分布式训练。数据集切分配合模型并行是有效降低显存占用的方式。详细可参考[数据集切分](https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/dataset_slice.html)教程。 - - **Host&Device异构**:在遇到参数量超过Device内存上限的时候,可以把一些内存占用量大且计算量少的算子放在Host端,这样能同时利用Host端内存大,Device端计算快的特性,提升了设备的利用率。详细可参考[Host&Device异构](https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/host_device_training.html)教程。 - - **异构存储**:大模型目前受限显存大小,难以在单卡上训练。大规模分布式集群训练中,在通信代价越来越大的情况下,提升单机的显存,减少通信,也能提升训练性能。异构存储可以将暂时不需要用到的参数或中间结果拷贝到Host端内存或者硬盘,在需要时再恢复至Device端。详细可参考[异构存储](https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/memory_offload.html)教程。 -- **通信优化**: - - **通信融合**:通信融合可以将相同源节点和目标节点的通信算子合并到一次通信过程,避免多次通信带来额外开销。详细可参考[通信融合](https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/comm_fusion.html)。 - - **通信子图提取与复用**:通过对通信算子提取通信子图,替换原本的通信算子,可以减少通信耗时,同时减少模型编译时间。详细可参考[通信子图提取与复用](https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/comm_subgraph.html)。 - -## 不同平台差异 - -在分布式训练中,不同硬件平台(Ascend、CPU或者GPU)支持的特性也有所不同,用户可以根据自己的平台选择对应的分布式启动方式、并行模式和优化方法。 - -### 启动方式的差异 - -- Ascend支持msrun、动态组网、mpirun以及rank table启动四种启动方式。 -- GPU支持msrun、动态组网和mpirun三种启动方式。 -- CPU支持msrun和动态组网启动。 - -详细过程请参考[启动方式](https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/startup_method.html)。 - -### 并行方式的差异 - -- Ascend和GPU支持所有并行方式,包括数据并行、半自动并行、自动并行等。 -- CPU仅支持数据并行。 - -详细过程请参考[数据并行](https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/data_parallel.html)、[半自动并行](https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/semi_auto_parallel.html)、[自动并行](https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/auto_parallel.html)。 - -### 优化特性支持的差异 - -- Ascend支持所有的优化特性。 -- GPU支持除了通信子图提取与复用以外的优化特性。 -- CPU不支持优化特性。 - -详细过程请参考[优化方法](https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/optimize_technique.html)。 - diff --git a/docs/mindspore/source_zh_cn/model_train/parallel/pipeline_parallel.md b/docs/mindspore/source_zh_cn/model_train/parallel/pipeline_parallel.md index b0d493b156..e26de43216 100644 --- a/docs/mindspore/source_zh_cn/model_train/parallel/pipeline_parallel.md +++ b/docs/mindspore/source_zh_cn/model_train/parallel/pipeline_parallel.md @@ -10,13 +10,15 @@ 相关接口: -1. 
`mindspore.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, pipeline_stages=NUM, pipeline_result_broadcast=True)`:设置半自动并行模式,且设置`pipeline_stages`用来表明Stage的总数为NUM,必须在初始化网络之前调用。`pipeline_result_broadcast`表示流水线并行推理时,最后一个stage的结果是否广播给其他stage。 +1. `AutoParallel(network, parallel_mode="semi_auto")`:通过静态图并行封装指定并行模式,其中`network`是待封装的顶层`Cell`或函数,`parallel_mode`取值`semi_auto`,表示半自动并行模式。该接口返回类型是`Cell`。 -2. `nn.PipelineCell(loss_cell, micro_size)`:流水线并行需要在LossCell外再包一层`PipelineCell`,并指定MicroBatch的size。为了提升机器的利用率,MindSpore将MiniBatch切分成了更细粒度的MicroBatch,最终的loss则是所有MicroBatch计算的loss值累加。其中,MicroBatch的size必须大于等于Stage的数量。 +2. `AutoParallel.pipeline(stages=1, output_broadcast=False, interleave=False, scheduler='1f1b')`:设置流水线并行配置。`stages`表示流水线并行需要设置的切分总数,`output_broadcast`表示流水线并行推理时,最后一个stage的结果是否广播给其他stage,`interleave`表示是否开启interleave优化策略,`scheduler`表示流水线并行的调度策略,当前支持`gpipe`和`1f1b`。 -3. `nn.PipelineGradReducer(parameters)`:流水线并行需要使用`PipelineGradReducer`来完成梯度聚合。这是因为流水线并行中,其输出是由多个`micro-batch`的结果相加得到,因此其梯度也需要进行累加。 +3. `mindspore.parallel.Pipeline(network, micro_size=1, stage_config={"cell1":0, "cell2":1})`:流水线并行需要需要在network外再添加一层`Pipeline`,并通过`micro_size`指定MicroBatch的个数,以及指出网络中各Cell在哪个`stage`中执行。为了提升机器的利用率,MindSpore将MiniBatch切分成了更细粒度的MicroBatch,最终的loss则是所有MicroBatch计算的loss值累加。其中,micro_size必须大于等于stages的数量。 -4. `mindspore.parallel.sync_pipeline_shared_parameters(net)`: 在推理场景下,用于同步不同stage之间共享权重。 +4. `mindspore.parallel.PipelineGradReducer(parameters, scale_sense=1.0, opt_shard=None)`:流水线并行需要使用`PipelineGradReducer`来完成梯度聚合。这是因为流水线并行中,其输出是由多个`MicroBatch`的结果相加得到,因此其梯度也需要进行累加。 + +5. `mindspore.parallel.sync_pipeline_shared_parameters(net)`: 在推理场景下,用于同步不同stage之间共享权重。 ## 基本原理 @@ -64,412 +66,4 @@ MindSpore在Megatron-LM的interleaved pipeline调度的基础上做了内存优 *图5: MindSpore的interleaved pipeline调度* -## 训练操作实践 - -下面以Ascend或者GPU单机8卡为例,进行流水线并行操作说明: - -### 样例代码说明 - -> 下载完整的样例代码:[distributed_pipeline_parallel](https://gitee.com/mindspore/docs/tree/master/docs/sample_code/distributed_pipeline_parallel)。 - -目录结构如下: - -```text -└─ sample_code - ├─ distributed_pipeline_parallel - ├── distributed_pipeline_parallel.py - └── run.sh - ... 
-``` - -其中,`distributed_pipeline_parallel.py`是定义网络结构和训练过程的脚本。`run.sh`是执行脚本。 - -### 配置分布式环境 - -通过context接口指定运行模式、运行设备、运行卡号等。与单卡脚本不同,并行脚本还需指定并行模式`parallel_mode`为半自动并行模式,并通过init初始化HCCL或NCCL通信。此外,还需配置`pipeline_stages=2`指定Stage的总数。此处未设置`device_target`,会自动指定为MindSpore包对应的后端硬件设备。 - -```python -import mindspore as ms -from mindspore.communication import init - -ms.set_context(mode=ms.GRAPH_MODE) -ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL, pipeline_stages=2) -init() -ms.set_seed(1) -``` - -如果需要执行interleaved pipeline调度,还需要配置`pipeline_config={'pipeline_scheduler':'1f1b', 'pipeline_interleave':True}`。需要注意的是,MindSpore的interleaved pipeline调度功能还在完善阶段,目前在O0或者O1模式下表现会更好。 - -```python -import mindspore as ms - -ms.set_auto_parallel_context(pipeline_config={'pipeline_scheduler':'1f1b', 'pipeline_interleave':True}) -``` - -### 数据集加载 - -在流水线并行场景下,数据集加载方式与单卡加载方式一致,代码如下: - -```python -import os -import mindspore.dataset as ds - -def create_dataset(batch_size): - dataset_path = os.getenv("DATA_PATH") - dataset = ds.MnistDataset(dataset_path) - image_transforms = [ - ds.vision.Rescale(1.0 / 255.0, 0), - ds.vision.Normalize(mean=(0.1307,), std=(0.3081,)), - ds.vision.HWC2CHW() - ] - label_transform = ds.transforms.TypeCast(ms.int32) - dataset = dataset.map(image_transforms, 'image') - dataset = dataset.map(label_transform, 'label') - dataset = dataset.batch(batch_size) - return dataset - -data_set = create_dataset(32) -``` - -### 定义网络 - -流水线并行网络结构与单卡网络结构基本一致,区别在于增加了流水线并行策略配置。流水线并行需要用户去定义并行的策略,通过调用`pipeline_stage`接口来指定每个layer要在哪个stage上去执行。`pipeline_stage`接口的粒度为`Cell`。所有包含训练参数的`Cell`都需要配置`pipeline_stage`,并且`pipeline_stage`要按照网络执行的先后顺序,从小到大进行配置。如果需要使能interleaved pipeline调度,`pipeline_stage`需按照前面章节中介绍的非连续模型层进行[交错式配置](#interleaved-pipeline调度)。在单卡模型基础上,增加`pipeline_stage`配置后如下: - -> - pipeline并行场景下,使能Print/Summary/TensorDump相关算子时,需要把该算子放到有pipeline_stage属性的Cell中使用,否则有概率因为pipeline并行切分导致算子不生效。 -> - pipeline并行场景下,网络的输出不支持动态shape。 - -```python -from mindspore import nn, ops, Parameter -from mindspore.common.initializer import initializer, HeUniform - -import math - -class MatMulCell(nn.Cell): - """ - MatMulCell definition. 
- """ - def __init__(self, param=None, shape=None): - super().__init__() - if shape is None: - shape = [28 * 28, 512] - weight_init = HeUniform(math.sqrt(5)) - self.param = Parameter(initializer(weight_init, shape), name="param") - if param is not None: - self.param = param - self.print = ops.Print() - self.matmul = ops.MatMul() - - def construct(self, x): - out = self.matmul(x, self.param) - self.print("out is:", out) - return out - - -class Network(nn.Cell): - def __init__(self): - super().__init__() - self.flatten = nn.Flatten() - self.layer1 = MatMulCell() - self.relu1 = nn.ReLU() - self.layer2 = nn.Dense(512, 512) - self.relu2 = nn.ReLU() - self.layer3 = nn.Dense(512, 10) - - def construct(self, x): - x = self.flatten(x) - x = self.layer1(x) - x = self.relu1(x) - x = self.layer2(x) - x = self.relu2(x) - logits = self.layer3(x) - return logits - - -net = Network() -net.layer1.pipeline_stage = 0 -net.relu1.pipeline_stage = 0 -net.layer2.pipeline_stage = 0 -net.relu2.pipeline_stage = 1 -net.layer3.pipeline_stage = 1 -``` - -使能interleaved pipeline调度,`pipeline_stage`的非连续模型层需要进行交错式配置,配置如下: - -```python -net.layer1.pipeline_stage = 0 -net.relu1.pipeline_stage = 1 -net.layer2.pipeline_stage = 0 -net.relu2.pipeline_stage = 1 -net.layer3.pipeline_stage = 1 -``` - -stage0有第0层和第2层,stage1有第1层、第3层和第4层。 - -### 训练网络 - -在这一步,我们需要定义损失函数、优化器以及训练过程。与单卡模型不同,这里调用两个接口来配置流水线并行: - -- 首先需要定义LossCell,本例中调用了`nn.WithLossCell`接口封装网络和损失函数。 -- 然后需要在LossCell外包一层`nn.PipelineCell`,并指定MicroBatch的size。详细请参考本章概述中的相关接口。 - -除此之外, 还需要增加 `nn.PipelineGradReducer` 接口,用于处理流水线并行下的梯度,该接口的第一个参数为需要更新的网络参数。 - -```python -import mindspore as ms -from mindspore import nn, ops - -optimizer = nn.SGD(net.trainable_params(), 1e-2) -loss_fn = nn.CrossEntropyLoss() -net_with_loss = nn.PipelineCell(nn.WithLossCell(net, loss_fn), 4) -net_with_loss.set_train() - -def forward_fn(inputs, target): - loss = net_with_loss(inputs, target) - return loss - -grad_fn = ops.value_and_grad(forward_fn, None, optimizer.parameters) -pp_grad_reducer = nn.PipelineGradReducer(optimizer.parameters) - -@ms.jit -def train_one_step(inputs, target): - loss, grads = grad_fn(inputs, target) - grads = pp_grad_reducer(grads) - optimizer(grads) - return loss, grads - -for epoch in range(10): - i = 0 - for data, label in data_set: - loss, grads = train_one_step(data, label) - if i % 10 == 0: - print("epoch: %s, step: %s, loss is %s" % (epoch, i, loss)) - i += 1 -``` - -> - 目前流水线并行不支持自动混合精度特性。 -> - 流水线并行训练更适合用`model.train`的方式,这是因为流水线并行下的TrainOneStep逻辑复杂,而`model.train`内部封装了针对流水线并行的TrainOneStepCell,易用性更好。 - -### 运行单机8卡脚本 - -接下来通过命令调用对应的脚本,以8卡的分布式训练脚本为例,使用`mpirun`启动方式进行分布式训练: - -```bash -bash run.sh -``` - -训练完后,日志文件保存到`log_output`目录下,其中部分文件目录结构如下: - -```text -└─ log_output - └─ 1 - ├─ rank.0 - | └─ stdout - ├─ rank.1 - | └─ stdout -... -``` - -结果保存在`log_output/1/rank.*/stdout`中,示例如下: - -```text -epoch: 0 step: 0, loss is 9.137518 -epoch: 0 step: 10, loss is 8.826559 -epoch: 0 step: 20, loss is 8.675843 -epoch: 0 step: 30, loss is 8.307994 -epoch: 0 step: 40, loss is 7.856993 -epoch: 0 step: 50, loss is 7.0662785 -... -``` - -`Print` 算子的结果为: - -```text -out is: -Tensor(shape=[8, 512], dtype=Float32, value= -[[ 4.61914062e-01 5.78613281e-01 1.34995094e-01 ... 8.54492188e-02 7.91992188e-01 2.13378906e-01] -... -[ 4.89746094e-01 3.56689453e-01 -4.90966797e-01 ... 
-3.30078125e-e01 -2.38525391e-01 7.33398438e-01]]) -``` - -其他启动方式如动态组网、`rank table`的启动可参考[启动方式](https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/startup_method.html)。 - -## 推理操作实践 - -下面以Ascend或者GPU单机8卡为例,进行流水线并行操作说明: - -### 样例代码说明 - -> 下载完整的样例代码:[distributed_pipeline_parallel](https://gitee.com/mindspore/docs/tree/master/docs/sample_code/distributed_pipeline_parallel)。 - -目录结构如下: - -```text - -└─ sample_code - ├─ distributed_pipeline_parallel - ├── distributed_pipeline_parallel_inference.py - └── run_inference.sh - ... - -``` - -其中,`distributed_pipeline_parallel_inference.py`是定义网络结构和推理过程的脚本。`run_inference.sh`是执行脚本。 - -### 配置分布式环境 - -通过context接口指定运行模式、运行设备、运行卡号等,与单卡脚本不同。并行脚本还需指定并行模式`parallel_mode`为半自动并行模式,并通过init初始化HCCL或NCCL通信。此外,还需配置`pipeline_stages=4`指定Stage的总数。此处未设置`device_target`,会自动指定为MindSpore包对应的后端硬件设备。`pipeline_result_broadcast=True`表示流水线并行推理时,将最后一个stage的结果广播给其他stage,可以用于自回归推理场景。 - -```python - -import mindspore as ms -from mindspore.communication import init - -ms.set_context(mode=ms.GRAPH_MODE) -ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL, dataset_strategy="full_batch", - pipeline_stages=4, pipeline_result_broadcast=True) -init() -ms.set_seed(1) - -``` - -### 定义网络 - -流水线并行需要用户去定义并行的策略,通过调用`pipeline_stage`接口来指定每个layer要在哪个stage上去执行。`pipeline_stage`接口的粒度为`Cell`。所有包含训练参数的`Cell`都需要配置`pipeline_stage`,并且`pipeline_stage`要按照网络执行的先后顺序,从小到大进行配置。在单卡模型基础上,增加`pipeline_stage`配置后如下: - -```python - -import numpy as np -from mindspore import lazy_inline, nn, ops, Tensor, Parameter, sync_pipeline_shared_parameters - -class VocabEmbedding(nn.Cell): - """Vocab Embedding""" - def __init__(self, vocab_size, embedding_size): - super().__init__() - self.embedding_table = Parameter(Tensor(np.ones([vocab_size, embedding_size]), ms.float32), - name='embedding_table') - self.gather = ops.Gather() - - def construct(self, x): - output = self.gather(self.embedding_table, x, 0) - output = output.squeeze(1) - return output, self.embedding_table.value() - - -class Head(nn.Cell): - def __init__(self): - super().__init__() - self.matmul = ops.MatMul(transpose_b=True) - - def construct(self, state, embed): - return self.matmul(state, embed) - - -class Network(nn.Cell): - """Network""" - @lazy_inline - def __init__(self): - super().__init__() - self.word_embedding = VocabEmbedding(vocab_size=32, embedding_size=32) - self.layer1 = nn.Dense(32, 32) - self.layer2 = nn.Dense(32, 32) - self.head = Head() - - def construct(self, x): - x, embed = self.word_embedding(x) - x = self.layer1(x) - x = self.layer2(x) - x = self.head(x, embed) - return x - -# Define network and set pipeline stage -net = Network() -net.word_embedding.pipeline_stage = 0 -net.layer1.pipeline_stage = 1 -net.layer2.pipeline_stage = 2 -net.head.pipeline_stage = 3 - -``` - -### 推理网络 - -在network外包一层`PipelineCellInference`,并指定MicroBatch的size。`PipelineCellInference`中将输入切分为若干个micro batch,执行推理网络,最后将若干个micro batch推理结果通过`ops.Concat`算子沿batch轴拼接后返回。 - -在上一步中,`embed`被`self.word_embedding`和`self.head`两层共享,并且这两层被切分到了不同的stage上。在执行推理前,先编译计算图`inference_network.compile()`,再调用`sync_pipeline_shared_parameters(inference_network)`接口,框架自动同步stage间的共享权重。 - -```python - -from mindspore import nn, ops - -class PipelineCellInference(nn.Cell): - """Pipeline Cell Inference wrapper""" - def __init__(self, network, micro_batch_num): - super().__init__() - self.network = network - self.micro_batch_num = micro_batch_num - self.concat = ops.Concat() - - def construct(self, x): - """Apply the pipeline inference""" - ret = () - for i in 
range(self.micro_batch_num): - micro_batch_size = x.shape[0] // self.micro_batch_num - start = micro_batch_size * i - end = micro_batch_size * (i + 1) - - micro_input = x[start:end] - micro_output = self.network(micro_input) - ret = ret + (micro_output,) - - ret = self.concat(ret) - return ret - -inference_network = PipelineCellInference(network=net, micro_batch_num=4) -inference_network.set_train(False) - -# Compile and synchronize shared parameter. -input_ids = Tensor(np.random.randint(low=0, high=32, size=(8, 1)), ms.int32) -inference_network.compile(input_ids) -sync_pipeline_shared_parameters(inference_network) - -# Execute the inference network -logits = inference_network(input_ids) -print(logits.asnumpy()) - -``` - -### 运行单机8卡脚本 - -接下来通过命令调用对应的脚本,以8卡的分布式推理脚本为例,使用`msrun`启动方式进行分布式训练: - -```bash - -bash run_inference.sh - -``` - -训练完后,日志文件保存到`pipeline_inference_logs`目录下,其中部分文件目录结构如下: - -```text - -└─ pipeline_inference_logs -   ├── scheduler.log -   ├── worker_0.log -   ├── worker_1.log -   ├── worker_2.log -... - -``` - -结果保存在`pipeline_inference_logs/worker_0.log`中,示例如下: - -```text - -[[0.01181556 0.01181556 0.01181556 0.01181556 0.01181556 0.01181556 0.01181556 - 0.01181556 0.01181556 0.01181556 0.01181556 0.01181556 0.01181556 0.01181556 - 0.01181556 0.01181556 0.01181556 0.01181556 0.01181556 0.01181556 0.01181556 - 0.01181556 0.01181556 0.01181556 0.01181556 0.01181556 0.01181556 0.01181556 - 0.01181556 0.01181556 0.01181556 0.01181556 0.01181556 0.01181556 0.01181556 - 0.01181556 0.01181556] - ...] -``` diff --git a/docs/mindspore/source_zh_cn/model_train/parallel/semi_auto_parallel.rst b/docs/mindspore/source_zh_cn/model_train/parallel/semi_auto_parallel.rst deleted file mode 100644 index b10241cb69..0000000000 --- a/docs/mindspore/source_zh_cn/model_train/parallel/semi_auto_parallel.rst +++ /dev/null @@ -1,22 +0,0 @@ -半自动并行 -======================== - -.. image:: https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg - :target: https://gitee.com/mindspore/docs/blob/master/docs/mindspore/source_zh_cn/model_train/parallel/semi_auto_parallel.rst - :alt: 查看源文件 - -.. toctree:: - :maxdepth: 1 - :hidden: - - operator_parallel - advanced_operator_parallel - optimizer_parallel - pipeline_parallel - -半自动并行支持多种并行模式的自动混合使用,包括: - -- `算子级并行 `_:以算子为单位,把输入张量和模型参数切分到多台设备上进行计算,提升整体速度。 -- `高阶算子级并行 `_:允许自定义设备排布与张量排布的算子级并行,以实现更复杂的切分逻辑。 -- `优化器并行 `_:减少多台设备对于相同权重更新的冗余计算,将计算量分散到多个设备上。 -- `流水线并行 `_:将模型按层切分,每个设备只处理模型中某一部分。 -- Gitee