diff --git a/tutorials/experts/source_en/parallel/comm_fusion.md b/tutorials/experts/source_en/parallel/comm_fusion.md
new file mode 100644
index 0000000000000000000000000000000000000000..912237f1b02f888f928de98753ace5e05e217c7e
--- /dev/null
+++ b/tutorials/experts/source_en/parallel/comm_fusion.md
@@ -0,0 +1,248 @@
+# Distributed Training Communication Fusion
+
+
+
+## Overview
+
+In distributed parallel training scenarios to train large-scale parameter models (e.g., GPT-3, Pangu-$\alpha$), data transmission of cross-device or even cross-node is a bottleneck that limits scalability as well as operator power utilization [1]. Communication fusion is an important method to improve network resource utilization and accelerate data transmission efficiency by encapsulating the communication operator of the same source and destination nodes for simultaneous execution to avoid the extra overhead caused by multiple single operator executions.
+
+MindSpore supports the fusion of three common communication operators (`AllReduce`, `AllGather` and `ReduceScatter`) in distributed training, and provides a simple and easy-to-use interface for user configuration. The communication fusion plays an important role in the long and steady training mission support.
+
+## Basic Principle
+
+This section firstly introduces the relationship between computation and communication in distributed training with the example of data parallelism, and secondly introduces the necessity of communication fusion in distributed training scenarios.
+
+### Computation and Communication in Distributed Training
+
+The whole process of distributed training can be roughly divided into two processes: local model computation and cross-device network data interaction.
+
+The following is an example of data parallelism [2] to introduce the overall training process. For other parallel approaches, such as model parallelism [3], pipeline parallelism [4], please refer to related papers.
+
+As shown in the figure below, each node backs up the complete neural network model and uses the local dataset partition to train a mini-batch for forward and backward computation. The gradient obtained from the backward computation is synchronized across the nodes, and the training of the next mini-batch continues after synchronization, and so on, until the accuracy/loss reaches a threshold, or a certain number of epochs are trained. It can be seen that computation and communication alternate in the distributed training process. Work has been done on how to do pipelining of interdependent computation and transmission to reduce the percentage of cross-node data synchronization in the overall training duration [5-6], which will not be repeated here.
+
+
+
+### The Necessity of Communication Fusion
+
+The time overhead of network communication can be measured by the following equation, where $m$ is the size of the data transmission, $\alpha$ is the network transmission rate, and $\beta$ is the inherent overhead of network startup. As can be seen, when the number of transmitted messages becomes larger, the inherent overhead share of network shartup rises, transmitting small messages does not make efficient use of network bandwidth resources. Even communication primitives in the HPC domain, such as `AllReduce` and `AllGather`, follow this principle. Therefore, communication fusion technology can effectively improve network resource utilization and reduce network synchronization delay.
+
+$$t = \alpha m+\beta$$
+
+### Communication Fusion Implementation
+
+Currently, fusion is supported for each of the three communication operators `AllReduce`, `AllGather` and `ReduceScatter`, with the configuration item being a dict type, e.g.
+
+comm_fusion={"allreduce": {"mode": "auto", "config": None}}, where "mode" has three options:
+
+"auto": Automatic operator fusion according to the data volume threshold of 64MB, with the configuration parameter "config" as None.
+
+"size": Communication operator fusion is performed by manually setting the data volume threshold, with the configuration parameter "config" of type int, in MB.
+
+"index": Only "allreduce" supports the configuration of index, which indicates the way of fusion according to the sequence number of communication operator, and the configuration parameter "config" is of type list. For example, [20, 35], means the first 20 AllReduce are fused into 1, the 20th to 35th AllReduce are fused into 1, and the remaining AllReduce are fused into 1.
+
+### Communication Fusion Usage
+
+MindSpore provides two interfaces to enable communication fusion, each of which is described below.
+
+#### Configuration in the Automatic Parallel Scenario
+
+In automatic parallel or semi-automatic parallel scenarios, users can use the `comm_fusion` parameter provided by this interface to set the parallel strategy when configuring the parallel strategy via `set_auto_parallel_context`. Users can specify whether to use the index method or the fusion buffer method.
+
+#### Using the Interfaces Provided by `Cell`
+
+Regardless of the parallel mode scenario, users can set the index for the parameters of a layer in the model through the `Cell.set_comm_fusion` interface, and MindSpore will fuse the parameters with the same index. In auto-parallel and semi-auto-parallel scenarios, it is recommended that the `comm_fusion` parameter be used in preference for configuration.
+
+## Operation Practice
+
+### Sample Code Description
+
+> You can download the full sample code here:
+>
+> .
+
+The directory structure is as follows:
+
+```text
+└─sample_code
+ ├─distributed_comm_fusion
+ ├── fusion_example_cell.py
+ ├── rank_table_2pcs.json
+ ├── rank_table_8pcs.json
+ └── run_fusion_example.sh
+```
+
+The function of each file is as follows:
+
+- fusion_example_cell.py: Example of communication fusion by using the interface provided by `Cell`.
+- rank_table_2pcs.json: 2-card configuration file of RANK_TABLE_FILE.
+- rank_table_8pcs.json: 8-card configuration file of RANK_TABLE_FILE.
+- run_fusion_example.sh: Startup script for communication fusion.
+
+### Configuring the Communication Fusion
+
+The following introduces the configuration of two usage methods through the practical sample.
+
+#### `comm_fusion` Parameter
+
+As shown in the following code, the `comm_fusion` parameter of the `set_auto_parallel_context` interface is used to configure the fusion mode for the `AllReduce` operator to be `auto`, implying that the fusion buffer size is set to 64MB by default.
+
+```python
+from mindspore.communication import init
+from mindspore import nn
+import mindspore as ms
+ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL)
+ms.set_auto_parallel_context(comm_fusion={"allreduce": {"mode": "auto", "config": None}})
+init()
+```
+
+If all similar communication operators are fused into one operator, in the current training iteration, the transmission needs to wait until the computation is completely finished before it can be executed, which will cause the device to wait.
+
+In order to avoid the above problem, the network parameters can be fused in groups: while the next group of parameters is computed, the communication of the previous group of parameters is carried out, so that the computation and communication can be hidden from each other, to perform group fusion either by limiting the size of the fusion buffer, or by index partitioning.
+
+For more usage, you can refer to MindSpore [test cases](https://gitee.com/mindspore/mindspore/blob/master/tests/ut/python/parallel/test_comm_fusion.py).
+
+> Users can try the size and index modes of `comm_fusion` on their own, which are essentially methods of the fusion buffer class.
+
+#### `Cell.set_comm_fusion` Interface
+
+As shown in the following code, the `set_comm_fusion` method is called for the instantiated DenseLayer to set the fusion value for each layer.
+
+```python
+"""Cell Fusion Example"""
+import os
+from mindspore.communication import init
+from mindspore import nn
+import mindspore as ms
+
+ms.set_context(mode=ms.GRAPH_MODE, device_target="Ascend", device_id=int(os.environ["DEVICE_ID"]))
+ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL)
+init()
+
+class DenseLayer(nn.Cell):
+ """A base layer with two dense layer"""
+ def __init__(self):
+ super().__init__()
+ self.input_mapping = nn.Dense(10, 10)
+ self.output_mapping = nn.Dense(10, 10)
+ def construct(self, x):
+ x = self.input_mapping(x)
+ return self.output_mapping(x)
+
+class Net(nn.Cell):
+ """An network with many dense layers"""
+ def __init__(self):
+ super().__init__()
+ self.layer1 = DenseLayer()
+ self.layer2 = DenseLayer()
+ self.layer3 = DenseLayer()
+ self.layer1.set_comm_fusion(0)
+ self.layer2.set_comm_fusion(1)
+ self.layer3.set_comm_fusion(2)
+ def construct(self, x):
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ return x
+
+net = Net()
+for item in net.trainable_params():
+ print(f"The parameter {item.name}'s fusion id is {item.comm_fusion}")
+```
+
+The corresponding output, representing the fusion index value for each layer of a particular dense, is as follows:
+
+```text
+The parameter layer1.input_mapping.weight's fusion id is 0
+The parameter layer1.input_mapping.bias's fusion id is 0
+The parameter layer1.output_mapping.weight's fusion id is 0
+The parameter layer1.output_mapping.bias's fusion id is 0
+The parameter layer2.input_mapping.weight's fusion id is 1
+The parameter layer2.input_mapping.bias's fusion id is 1
+The parameter layer2.output_mapping.weight's fusion id is 1
+The parameter layer2.output_mapping.bias's fusion id is 1
+The parameter layer3.input_mapping.weight's fusion id is 2
+The parameter layer3.input_mapping.bias's fusion id is 2
+The parameter layer3.output_mapping.weight's fusion id is 2
+The parameter layer3.output_mapping.bias's fusion id is 2
+```
+
+### Running the Code
+
+The above code needs to be configured with distributed variables before it can run. The Ascend environment needs to be configured with RANK_TABLE_FILE, RANK_ID and DEVICE_ID. For the configuration process, refer to [here](https://www.mindspore.cn/tutorials/experts/en/master/parallel/train_ascend.html#configuring-distributed-environment-variables). The GPU environment needs to be configured with [OpenMPI](https://www.mindspore.cn/tutorials/experts/en/master/parallel/train_gpu.html#configuring-distributed-environment), NCCL and [HOST_FILE](https://www.mindspore.cn/tutorials/experts/en/master/parallel/train_gpu.html#multi-host-training). For the configuration process, refer to [here](https://www.mindspore.cn/tutorials/experts/en/master/parallel/train_gpu.html#configuring-distributed-environment).
+
+Environment variables related to Ascend distributed are:
+
+- RANK_TABLE_FILE: the path of networking information file. The rank_table_file file can be generated by using hccl_tools.py in the models code repository, which can be obtained from [here](https://gitee.com/mindspore/models/tree/master/utils/hccl_tools).
+- DEVICE_ID: The actual serial number of the current card on the machine.
+- RANK_ID: The logical serial number of the current card.
+
+Environment variables related to GPU distributed are:
+
+- HOST_FILE: describes the IP and number of devices for multi-card training. Each line of the file has the format [hostname] slots=[slotnum], and hostname can be an ip or hostname. Note that the username needs to be the same on different machines, but the hostname cannot be the same.
+
+The user can access the above script in this document via [here](https://gitee.com/mindspore/docs/tree/master/docs/sample_code/distributed_optimizer_parallel). Execute the following `bash` script to run the program and output the log in the device0/train.log0 file.
+
+```bash
+#!/bin/bash
+set -e
+echo "=============================================================================================================="
+echo "Please run the script as: "
+echo "bash run_fusion_example.sh DATA_PATH RANK_SIZE"
+echo "For example: bash run_fusion_example.sh 8"
+echo "It is better to use the absolute path."
+echo "This example is expected to run on the Ascend environment."
+echo "=============================================================================================================="
+RANK_SIZE=$1
+
+EXEC_PATH=$(pwd)
+
+test_dist_8pcs()
+{
+ export RANK_TABLE_FILE=${EXEC_PATH}/rank_table_8pcs.json
+ export RANK_SIZE=8
+}
+
+test_dist_2pcs()
+{
+ export RANK_TABLE_FILE=${EXEC_PATH}/rank_table_2pcs.json
+ export RANK_SIZE=2
+}
+
+test_dist_${RANK_SIZE}pcs
+
+for((i=0;i<${RANK_SIZE};i++))
+do
+ rm -rf device$i
+ mkdir device$i
+ cp ./fusion_example_cell.py ./device$i
+ cd ./device$i
+ export DEVICE_ID=$i
+ export RANK_ID=$i
+ echo "start training for device $i"
+ env > env$i.log
+ pytest -s -v ./fusion_example_cell.py > train.log$i 2>&1 &
+ cd ../
+done
+echo "The program launch succeed, the log is under device0/train.log0."
+```
+
+After configuring RANK_TABLE_FILE in the current directory, the following command requires the user to have 8 Ascend 910 devices. Run the command as follows:
+
+```bash
+bash run_fusion_example.sh 8
+```
+
+## References
+
+[1] Xu Y, Lee H J, Chen D, et al. GSPMD: general and scalable parallelization for ML computation graphs[J]. arXiv preprint arXiv:2105.04663, 2021.
+
+[2] Li M, Zhou L, Yang Z, et al. Parameter server for distributed machine learning[C]//Big learning NIPS workshop. 2013, 6: 2.
+
+[3] Dean J, Corrado G, Monga R, et al. Large scale distributed deep networks[J]. Advances in neural information processing systems, 2012, 25.
+
+[4] Narayanan D, Harlap A, Phanishayee A, et al. PipeDream: generalized pipeline parallelism for DNN training[C]//Proceedings of the 27th ACM Symposium on Operating Systems Principles. 2019: 1-15.
+
+[5] Zhang H, Zheng Z, Xu S, et al. Poseidon: An efficient communication architecture for distributed deep learning on {GPU} clusters[C]//2017 USENIX Annual Technical Conference (USENIX ATC 17). 2017: 181-193.
+
+[6] Peng Y, Zhu Y, Chen Y, et al. A generic communication scheduler for distributed dnn training acceleration[C]//Proceedings of the 27th ACM Symposium on Operating Systems Principles. 2019: 16-29.
+
diff --git a/tutorials/experts/source_en/parallel/dataset_slice.md b/tutorials/experts/source_en/parallel/dataset_slice.md
new file mode 100644
index 0000000000000000000000000000000000000000..51e08390ba42c740d7bfc53e7a134dc5bd303117
--- /dev/null
+++ b/tutorials/experts/source_en/parallel/dataset_slice.md
@@ -0,0 +1,117 @@
+# Dataset Slicing
+
+
+
+## Overview
+
+When performing distributed training, taking image data as an example, when the size of a single image is too large, such as large-format images of remote sensing satellites, even one image is too large, it is necessary to slice the images and read a portion of each card to perform distributed training. Scenarios that deal with dataset slicing need to be combined with model parallelism to achieve the desired effect of reducing video memory, so this feature is provided based on automatic parallelism. The sample used in this tutorial is ResNet50, not a large-format network, and is intended as an example only. Real-life applications to large-format networks often require detailed design of parallel strategies.
+
+## Operation Practices
+
+### Sample Code Description
+
+> You can download the full sample code here:
+>
+>
+
+The directory structure is as follows:
+
+```text
+└─sample_code
+ ├─distributed_training
+ │ rank_table_16pcs.json
+ │ rank_table_8pcs.json
+ │ rank_table_2pcs.json
+ │ resnet.py
+ │ resnet50_distributed_training_dataset_slice.py
+ │ run_dataset_slice.sh
+```
+
+### Creating the Dataset
+
+> Dataset slicing is only supported in full/semi-automatic mode and is not involved in data parallel mode.
+
+When using dataset slicing, you need to call the [SlicePatches](https://www.mindspore.cn/docs/en/master/api_python/dataset_vision/mindspore.dataset.vision.SlicePatches.html) interface to construct the dataset at the same time. To ensure that the read-in data is consistent across cards, the dataset needs to be fixed with a random number seed.
+
+The dataset definition section is as follows.
+
+```python
+import mindspore as ms
+import mindspore.dataset as ds
+import mindspore.dataset.vision as vision
+import mindspore.dataset.transforms as transforms
+from mindspore.communication import init, get_rank, get_group_size
+
+ms.set_context(mode=ms.GRAPH_MODE, device_target="Ascend")
+init()
+ds.config.set_seed(1000) # set dataset seed to make sure that all cards read the same data
+def create_dataset(data_path, repeat_num=1, batch_size=32, slice_h_num=1, slice_w_num=1):
+ resize_height = 224
+ resize_width = 224
+ rescale = 1.0 / 255.0
+ shift = 0.0
+
+ rank_id = get_rank()
+
+ # create a full dataset before slicing
+ data_set = ds.Cifar10Dataset(data_path, shuffle=True)
+
+ # define map operations
+ random_crop_op = vision.RandomCrop((32, 32), (4, 4, 4, 4))
+ random_horizontal_op = vision.RandomHorizontalFlip()
+ resize_op = vision.Resize((resize_height, resize_width))
+ rescale_op = vision.Rescale(rescale, shift)
+ normalize_op = vision.Normalize((0.4465, 0.4822, 0.4914), (0.2010, 0.1994, 0.2023))
+ changeswap_op = vision.HWC2CHW()
+ type_cast_op = transforms.TypeCast(ms.int32)
+
+ c_trans = [random_crop_op, random_horizontal_op]
+ c_trans += [resize_op, rescale_op, normalize_op]
+
+ # apply map operations on images
+ data_set = data_set.map(operations=type_cast_op, input_columns="label")
+ # in random map function, using num_parallel_workers=1 to avoid the dataset random seed not working.
+ data_set = data_set.map(operations=c_trans, input_columns="image", num_parallel_workers=1)
+ # slice image
+ slice_patchs_img_op = vision.SlicePatchs(slice_h_num, slice_w_num)
+ img_cols = ['img' + str(x) for x in range(slice_h_num * slice_w_num)]
+ data_set = data_set.map(operations=slice_patchs_img_op, input_columns="image", output_columns=img_cols)
+ data_set = data_set.project([img_cols[rank_id % (slice_h_num * slice_w_num)], "label"])
+ # change hwc to chw
+ data_set = data_set.map(operations=changeswap_op, input_columns=img_cols[rank_id % (slice_h_num * slice_w_num)])
+ # apply batch operations
+ data_set = data_set.batch(batch_size=batch_size, drop_remainder=True)
+
+ # apply repeat operations
+ data_set = data_set.repeat(repeat_num)
+
+ return data_set
+```
+
+### Configuring Dataset Slicing Strategy
+
+> Dataset slicing is only supported in full/semi-automatic mode and is not involved in data parallel mode.
+
+The `dataset_strategy` option is provided in `mindspore.auto_parallel_context` to configure the slicing strategy for the dataset.
+
+The dataset_strategy interface also has the following limitations:
+
+1. Each input is allowed to be sliced in at most one dimension. If support `set_auto_parallel_context(dataset_strategy=((1, 1, 1, 8), (8,))))` or `dataset_strategy=((1, 1, 1, 8), (1,)))`, each input is sliced to just one dimension, but does not support `dataset_strategy=((1, 1, 4, 2), (1,))`, whose first input is sliced to two dimensions.
+
+2. The number of slices for one input with the highest dimension, must be more than the other dimensions. If support `dataset_strategy=((1, 1, 1, 8), (8,)))` or `dataset_strategy=((1, 1, 1, 1, 1), (1,)))` is supported, the input with the most dimensions is the first input, the number of slices is 8, and the rest of the inputs are sliced by no more than 8 parts. However, it does not support `dataset_strategy=((1, 1, 1, 1), (8,)`, whose input with the most dimensions is the first dimension and the number of slices is 1, but the number of slices of second input is 8, which exceeds the number of slices of the first input.
+
+```python
+import os
+import mindspore as ms
+ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.AUTO_PARALLEL, gradients_mean=True)
+slice_h_num = 1
+slice_w_num = 8
+batch_size = 256
+ms.set_auto_parallel_context(dataset_strategy=(((1, 1, slice_h_num, slice_w_num), (1,))))
+data_path = os.getenv('DATA_PATH')
+dataset = create_dataset(data_path, batch_size=batch_size, slice_h_num=slice_h_num, slice_w_num=slice_w_num)
+```
+
+### Running the Code
+
+The data, code and execution of the above process can be found at: . The difference is that the execution script is changed to run_dataset_slice.sh.
diff --git a/tutorials/experts/source_en/parallel/other_features.rst b/tutorials/experts/source_en/parallel/other_features.rst
index 5ca5a68d1851595fada807cf6e8838d9d99ab2f7..a2b86bfc984ec4d15fb7906e64b3d2389f2387f6 100644
--- a/tutorials/experts/source_en/parallel/other_features.rst
+++ b/tutorials/experts/source_en/parallel/other_features.rst
@@ -10,6 +10,8 @@ Other Features
sharding_propagation
parameter_server_training
+ comm_fusion
+ dataset_slice
pynative_shard_function_parallel
ms_operator
@@ -42,8 +44,8 @@ some node abnormalities, and under the architecture of parameter
servers, such failures can be easily handled without affecting the tasks
in training.
-Communication Operator Fusion
------------------------------
+`Communication Operator Fusion `__
+---------------------------------------------------------------------------------------------------------------------
In the distributed training scenario, cross-device or even cross-node
data transmission is a bottleneck that restricts scalability and
@@ -54,8 +56,8 @@ communication operators of the same source node and the destination node
and executes them at the same time to avoid the additional overhead
caused by multiple single operator execution.
-Dataset Splitting
------------------
+`Dataset Slicing `__
+--------------------------------------------------------------------------------------------------------
When doing distributed training, you need to import the training dataset
to each device. There are two common ways to import: 1) Import in
diff --git a/tutorials/experts/source_zh_cn/parallel/comm_fusion.md b/tutorials/experts/source_zh_cn/parallel/comm_fusion.md
index 044c6a9f9a5c56d2b1e4754003818e3669ddf431..2a67d4082f392c39170223d3c3e67f71ee23dca8 100644
--- a/tutorials/experts/source_zh_cn/parallel/comm_fusion.md
+++ b/tutorials/experts/source_zh_cn/parallel/comm_fusion.md
@@ -6,7 +6,7 @@
在分布式并行训练场景下训练大规模参数量的模型(如GPT-3, Pangu-$\alpha$),跨设备甚至跨节点的数据传输是制约扩展性以及算力利用率的瓶颈[1]。通信融合是一种提升网络资源利用率、加速数据传输效率的重要方法,其将相同源节点和目的节点的通信算子打包同时执行,以避免多个单算子执行带来的额外开销。
-MindSpore支持对分布式训练中三种常用通信算子(`AllReduce`, `AllGather`, `ReduceScatter`)的融合,并提供简洁易用的接口方便用户自行配置。在长稳训练任务支撑中,通信融合特性发挥了重要作用。
+MindSpore支持对分布式训练中三种常用通信算子(`AllReduce`、`AllGather`、`ReduceScatter`)的融合,并提供简洁易用的接口方便用户自行配置。在长稳训练任务支撑中,通信融合特性发挥了重要作用。
## 基本原理
@@ -22,13 +22,13 @@ MindSpore支持对分布式训练中三种常用通信算子(`AllReduce`, `All
### 通信融合的必要性
-网络通信的时间开销可以用以下公式衡量,其中,$m$是传输数据的大小,$\alpha$是网络传输速率,$\beta$是网络启动的固有开销。可见,当传输的message数变多,网络启动的固有开销占比会上升,并且传输小message,并不能有效利用网络带宽资源。即便是HPC领域的通信原语,如`AllReduce`, `AllGather`等,也遵循该原则。因此,通信融合技术能够有效提升网络资源利用率,降低网络同步时延。
+网络通信的时间开销可以用以下公式衡量,其中,$m$是传输数据的大小,$\alpha$是网络传输速率,$\beta$是网络启动的固有开销。可见,当传输的message数变多,网络启动的固有开销占比会上升,并且传输小message,并不能有效利用网络带宽资源。即便是HPC领域的通信原语,如`AllReduce`,`AllGather`等,也遵循该原则。因此,通信融合技术能够有效提升网络资源利用率,降低网络同步时延。
$$t = \alpha m+\beta$$
### 通信融合的实现
-当前支持对`AllReduce`, `AllGather`和`ReduceScatter`三种通信算子分别进行融合,配置项为一个dict类型,如:
+当前支持对`AllReduce`,`AllGather`和`ReduceScatter`三种通信算子分别进行融合,配置项为一个dict类型,如:
comm_fusion={"allreduce": {"mode": "auto", "config": None}}。其中,"mode"有三种选项:
@@ -168,7 +168,7 @@ The parameter layer3.output_mapping.bias's fusion id is 2
上述代码需要在配置分布式变量后才可以运行。Ascend环境需要配置RANK_TABLE_FILE、RANK_ID和DEVICE_ID。配置的过程请参考[此处](https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/train_ascend.html#配置分布式环境变量),GPU环境需要配置[OpenMPI](https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/train_gpu.html#配置分布式环境)、NCCL和[HOST_FILE](https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/train_gpu.html#多机多卡训练),配置的过程请参考[此处](https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/train_gpu.html#配置分布式环境)。
-Ascend分布式相关的环境变量有:
+Ascend分布式相关的环境变量有:
- RANK_TABLE_FILE:组网信息文件的路径。rank_table_file文件可以使用models代码仓中的hccl_tools.py生成,可以从[此处](https://gitee.com/mindspore/models/tree/master/utils/hccl_tools)获取。
- DEVICE_ID:当前卡在机器上的实际序号。
@@ -176,7 +176,7 @@ Ascend分布式相关的环境变量有:
GPU分布式相关的环境变量:
-- HOST_FILE: 描述多卡训练时的设备IP和个数。文件每一行格式为[hostname] slots=[slotnum],hostname可以是ip或者主机名。需要注意的是,不同机器上的用户名需要相同,但是hostname不可以相同。
+- HOST_FILE:描述多卡训练时的设备IP和个数。文件每一行格式为[hostname] slots=[slotnum],hostname可以是ip或者主机名。需要注意的是,不同机器上的用户名需要相同,但是hostname不可以相同。
用户可以通过[此处](https://gitee.com/mindspore/docs/tree/master/docs/sample_code/distributed_optimizer_parallel)获取上述的此文档中的脚本。执行下述的`bash`脚本即可运行程序,输出日志在device0/train.log0文件。