diff --git a/tutorials/source_en/advanced_use/parameter_server_training.md b/tutorials/source_en/advanced_use/parameter_server_training.md
index 375bdadc3b7aa9288622b50ddde5f90c33511b58..570bb28d7f6c21d533c17dfa56ab71bb0bba744c 100644
--- a/tutorials/source_en/advanced_use/parameter_server_training.md
+++ b/tutorials/source_en/advanced_use/parameter_server_training.md
@@ -39,20 +39,27 @@ Learn how to train a LeNet using the [MNIST dataset](http://yann.lecun.com/exdb/
 ### Parameter Setting
 
-In this training mode, you can use either of the following methods to control whether the training parameters are updated through the parameter server:
+1. First of all, use `mindspore.context.set_ps_context(enable_ps=True)` to enable Parameter Server training mode.
+
+- This method should be called before `mindspore.communication.management.init()`.
+- If you don't call this method, the [Environment Variable Setting](https://www.mindspore.cn/tutorial/en/master/advanced_use/parameter_server_training.html#environment-variable-setting) below will not take effect.
+- Use `mindspore.context.reset_ps_context()` to disable Parameter Server training mode.
+
+2. In this training mode, you can use either of the following methods to control whether the training parameters are updated through the parameter server:
 
 - Use `mindspore.nn.Cell.set_param_ps()` to set all weight recursions of `nn.Cell`.
 - Use `mindspore.common.Parameter.set_param_ps()` to set the weight.
 
-On the basis of the [original training script](https://gitee.com/mindspore/mindspore/blob/master/model_zoo/official/cv/lenet/train.py), set all LeNet model weights to be trained on the parameter server:
+3. On the basis of the [original training script](https://gitee.com/mindspore/mindspore/blob/master/model_zoo/official/cv/lenet/train.py), set all LeNet model weights to be trained on the parameter server:
 
 ```python
+context.set_ps_context(enable_ps=True)
 network = LeNet5(cfg.num_classes)
 network.set_param_ps()
 ```
 
 ### Environment Variable Setting
 
-MindSpore reads environment variables to control parameter server training. The environment variables include the following options (all scripts of `MS_SCHED_HOST` and `MS_SCHED_POST` must be consistent):
+MindSpore reads environment variables to control parameter server training. The environment variables include the following options (`MS_SCHED_HOST` and `MS_SCHED_PORT` must be consistent across all scripts):
 
 ```
 export PS_VERBOSE=1                         # Print ps-lite log
@@ -67,38 +74,7 @@ export MS_ROLE=MS_SCHED                     # The role of this process: MS_SCHED repre
 
 1. Shell scripts
 
-    Provide the shell scripts corresponding to the worker, server, and scheduler roles to start training, and the shell directory structure is as follows:
-
-    ```
-    └─mindspore
-      ├─model_zoo
-        └─official
-          └─cv
-            └─lenets
-              |  Scheduler.sh
-              |  Server.sh
-              |  Worker.sh
-    ```
-
-    The data directory structure is as follows:
-
-    ```
-    └─mindspore
-      ├─model_zoo
-        └─official
-          └─cv
-            └─lenets
-              └─Data
-                ├─test
-                │      t10k-images.idx3-ubyte
-                │      t10k-labels.idx1-ubyte
-                │
-                └─train
-                  |  train-images.idx3-ubyte
-                  |  train-labels.idx1-ubyte
-    ```
-
-    If it is an Ascend hardware, the content of the shell script is as follows. If it is a GPU device, then the `train.py` script need to specify `--device_ target="GPU"`.
+    Provide the shell scripts corresponding to the worker, server, and scheduler roles to start training:
 
     `Scheduler.sh`:
 
@@ -109,7 +85,7 @@ export MS_ROLE=MS_SCHED                     # The role of this process: MS_SCHED repre
     export MS_SCHED_HOST=XXX.XXX.XXX.XXX
     export MS_SCHED_PORT=XXXX
     export MS_ROLE=MS_SCHED
-    python train.py
+    python train.py --device_target=Ascend --data_path=path/to/dataset
     ```
 
     `Server.sh`:
@@ -121,7 +97,7 @@ export MS_ROLE=MS_SCHED                     # The role of this process: MS_SCHED repre
     export MS_SCHED_HOST=XXX.XXX.XXX.XXX
     export MS_SCHED_PORT=XXXX
     export MS_ROLE=MS_PSERVER
-    python train.py
+    python train.py --device_target=Ascend --data_path=path/to/dataset
     ```
 
     `Worker.sh`:
@@ -133,7 +109,7 @@ export MS_ROLE=MS_SCHED                     # The role of this process: MS_SCHED repre
     export MS_SCHED_HOST=XXX.XXX.XXX.XXX
    export MS_SCHED_PORT=XXXX
     export MS_ROLE=MS_WORKER
-    python train.py
+    python train.py --device_target=Ascend --data_path=path/to/dataset
     ```
 
     Run the following commands separately:
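The English hunks above encode an ordering constraint that is easy to miss when the pieces are read separately: `set_ps_context(enable_ps=True)` must run before `mindspore.communication.management.init()`, and without it the `MS_*` environment variables have no effect. A minimal sketch of how the modified `train.py` could fit together; the `src.config`/`src.lenet` import paths are assumptions about the original lenet script, not part of this patch:

```python
from mindspore import context
from mindspore.communication.management import init
from src.config import mnist_cfg as cfg  # assumed layout of the original lenet script
from src.lenet import LeNet5             # assumed layout of the original lenet script

# Enable Parameter Server mode before init(); skipping this call means the
# MS_* environment variables described in the tutorial are ignored.
context.set_ps_context(enable_ps=True)
init()

network = LeNet5(cfg.num_classes)
network.set_param_ps()  # recursively route every weight of the network through the parameter servers
```

As the new bullet list notes, `context.reset_ps_context()` switches the mode back off.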
diff --git a/tutorials/source_zh_cn/advanced_use/parameter_server_training.md b/tutorials/source_zh_cn/advanced_use/parameter_server_training.md
index 32fe3e0debf40b57e8df064cbfd7f6b2fed97be7..934e6bbeb260a2d07acd56512db016e0342382b1 100644
--- a/tutorials/source_zh_cn/advanced_use/parameter_server_training.md
+++ b/tutorials/source_zh_cn/advanced_use/parameter_server_training.md
@@ -38,20 +38,27 @@ Parameter Server(参数服务器)是分布式训练中一种广泛使用的架
 ### 参数设置
 
-在本训练模式下,有以下两种调用接口方式以控制训练参数是否通过Parameter Server进行更新:
+1. 首先调用`mindspore.context.set_ps_context(enable_ps=True)`开启Parameter Server训练模式。
 
-- 通过`mindspore.nn.Cell.set_param_ps()`对`nn.Cell`中所有权重递归设置
-- 通过`mindspore.common.Parameter.set_param_ps()`对此权重进行设置
+- 此接口需在`mindspore.communication.management.init()`之前调用。
+- 若没有调用此接口,下面的[环境变量设置](https://www.mindspore.cn/tutorial/zh-CN/master/advanced_use/parameter_server_training.html#id5)则不会生效。
+- 调用`mindspore.context.reset_ps_context()`可以关闭Parameter Server训练模式。
 
-在[原训练脚本](https://gitee.com/mindspore/mindspore/blob/master/model_zoo/official/cv/lenet/train.py)基础上,设置LeNet模型所有权重通过Parameter Server训练:
+2. 在本训练模式下,有以下两种调用接口方式以控制训练参数是否通过Parameter Server进行更新:
+
+- 通过`mindspore.nn.Cell.set_param_ps()`对`nn.Cell`中所有权重递归设置。
+- 通过`mindspore.common.Parameter.set_param_ps()`对此权重进行设置。
+
+3. 在[原训练脚本](https://gitee.com/mindspore/mindspore/blob/master/model_zoo/official/cv/lenet/train.py)基础上,设置LeNet模型所有权重通过Parameter Server训练:
 
 ```python
+context.set_ps_context(enable_ps=True)
 network = LeNet5(cfg.num_classes)
 network.set_param_ps()
 ```
 
 ### 环境变量设置
 
-MindSpore通过读取环境变量,控制Parameter Server训练,环境变量包括以下选项(其中MS_SCHED_HOST及MS_SCHED_POST所有脚本需保持一致):
+MindSpore通过读取环境变量,控制Parameter Server训练,环境变量包括以下选项(其中`MS_SCHED_HOST`及`MS_SCHED_PORT`需在所有脚本中保持一致):
 
 ```
 export PS_VERBOSE=1                         # Print ps-lite log
@@ -66,38 +73,7 @@ export MS_ROLE=MS_SCHED                     # The role of this process: MS_SCHED repre
 
 1. shell脚本
 
-    提供Worker,Server和Scheduler三个角色对应的shell脚本,以启动训练,shell脚本的结构如下:
-
-    ```
-    └─mindspore
-      ├─model_zoo
-        └─official
-          └─cv
-            └─lenets
-              |  Scheduler.sh
-              |  Server.sh
-              |  Worker.sh
-    ```
-
-    数据集的目录如下:
-
-    ```
-    └─mindspore
-      ├─model_zoo
-        └─official
-          └─cv
-            └─lenets
-              └─Data
-                ├─test
-                │      t10k-images.idx3-ubyte
-                │      t10k-labels.idx1-ubyte
-                │
-                └─train
-                  |  train-images.idx3-ubyte
-                  |  train-labels.idx1-ubyte
-    ```
-
-    如果是Ascend设备,那么脚本的内容如下所示,如果是GPU设备,那么`train.py`脚本需要指定`--device_target="GPU"`。
+    提供Worker,Server和Scheduler三个角色对应的shell脚本,以启动训练:
 
     `Scheduler.sh`:
 
@@ -109,7 +85,7 @@ export MS_ROLE=MS_SCHED                     # The role of this process: MS_SCHED repre
     export MS_SCHED_HOST=XXX.XXX.XXX.XXX
     export MS_SCHED_PORT=XXXX
     export MS_ROLE=MS_SCHED
-    python train.py
+    python train.py --device_target=Ascend --data_path=path/to/dataset
     ```
 
     `Server.sh`:
@@ -121,7 +97,7 @@ export MS_ROLE=MS_SCHED                     # The role of this process: MS_SCHED repre
     export MS_SCHED_HOST=XXX.XXX.XXX.XXX
     export MS_SCHED_PORT=XXXX
     export MS_ROLE=MS_PSERVER
-    python train.py
+    python train.py --device_target=Ascend --data_path=path/to/dataset
     ```
 
     `Worker.sh`:
@@ -133,7 +109,7 @@ export MS_ROLE=MS_SCHED                     # The role of this process: MS_SCHED repre
     export MS_SCHED_HOST=XXX.XXX.XXX.XXX
     export MS_SCHED_PORT=XXXX
     export MS_ROLE=MS_WORKER
-    python train.py
+    python train.py --device_target=Ascend --data_path=path/to/dataset
     ```
 
     最后分别执行:
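Both files leave `Scheduler.sh`, `Server.sh`, and `Worker.sh` to be started by hand. For a quick single-machine smoke test, the same launch sequence can be sketched in Python; the loopback host, the port `8081`, and the one-process-per-role layout are illustrative assumptions, since the patch itself only shows the `XXX.XXX.XXX.XXX`/`XXXX` placeholders:

```python
import os
import subprocess

# Shared settings from the tutorial's environment-variable table; host and
# port are placeholders and must be identical for all three roles.
base_env = dict(os.environ,
                PS_VERBOSE="1",
                MS_SERVER_NUM="1",
                MS_WORKER_NUM="1",
                MS_SCHED_HOST="127.0.0.1",  # placeholder for XXX.XXX.XXX.XXX
                MS_SCHED_PORT="8081")       # placeholder for XXXX

procs = []
for role in ("MS_SCHED", "MS_PSERVER", "MS_WORKER"):
    # One process per role, mirroring Scheduler.sh, Server.sh, and Worker.sh.
    env = dict(base_env, MS_ROLE=role)
    procs.append(subprocess.Popen(
        ["python", "train.py",
         "--device_target=Ascend",
         "--data_path=path/to/dataset"],
        env=env))

for p in procs:
    p.wait()
```

Every child process sees the same `MS_SCHED_HOST`/`MS_SCHED_PORT` pair, which is exactly the cross-script consistency requirement both tutorials now spell out.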