diff --git a/tutorials/source_en/advanced_use/checkpoint_for_hybrid_parallel.md b/tutorials/source_en/advanced_use/checkpoint_for_hybrid_parallel.md
index d9bd74fdc4335914999a283cc9638b7bbd8298cc..0ac2c4beb7f7d2c50d507f4f4311ad19284eefce 100644
--- a/tutorials/source_en/advanced_use/checkpoint_for_hybrid_parallel.md
+++ b/tutorials/source_en/advanced_use/checkpoint_for_hybrid_parallel.md
@@ -9,8 +9,8 @@
     - [Integrating the Saved Checkpoint Files](#integrating-the-saved-checkpoint-files)
         - [Overall Process](#overall-process)
        - [Preparations](#preparations)
-            - [Importing the Checkpoint Files to the Network](#importing-the-checkpoint-files-to-the-network)
-            - [Obtaining a List of All Parameters on the Network](#obtaining-a-list-of-all-parameters-on-the-network)
+            - [Importing the Checkpoint Files in Rank ID Order](#importing-the-checkpoint-files-in-rank-id-order)
+            - [Obtaining the Slice Strategy of the Model](#obtaining-the-slice-strategy-of-the-model)
         - [Integrate the Model Parallel Parameters](#integrate-the-model-parallel-parameters)
         - [Saving the Data and Generating a New Checkpoint File](#saving-the-data-and-generating-a-new-checkpoint-file)
     - [Loading the Integrated and Saved Checkpoint File](#loading-the-integrated-and-saved-checkpoint-file)
@@ -71,7 +71,7 @@ For example, in the training stage 1, the training environment with 64 devices i
 
 ### Overall Process
 
-Import the checkpoint files to be integrated to the network and obtain the list of all parameters through the API provided by MindSpore. See steps 1 and 2 in the following figure.
+Import the checkpoint files to be integrated into the network in rank ID order, obtain the list of all parameters through the APIs provided by MindSpore, and then obtain the slice strategy of the model. See steps 1 and 2 in the following figure.
 
 Then, update the parameter list and integrate the model parallel parameters. See step 3 in the following figure.
 
@@ -81,120 +81,67 @@ Finally, save the updated parameter list to a file through the API provided by M
 
 ### Preparations
 
-#### Importing the Checkpoint Files to the Network
+#### Importing the Checkpoint Files in Rank ID Order
 
-Define the network, call the `load_checkpoint` and `load_param_into_net` APIs, and import the checkpoint files to the network.
+Define the network, call the `load_checkpoint` and `load_param_into_net` APIs to import the checkpoint files to the network in rank ID order, and then call the `parameters_and_names` API to obtain all the parameters on the network.
 
 ```
-param_dict = load_checkpoint(./CKP_1-4_32.ckpt) # checkpoint file name
 net = Net()
 opt = Momentum(learning_rate=0.01, momentum=0.9, params=net.get_parameters())
 net = TrainOneStepCell(net, opt)
-load_param_into_net(net, param_dict)
+param_dicts = []
+for i in range(rank_size):
+    file_name = os.path.join("./node"+str(i), "CKP_1-4_32.ckpt") # checkpoint file name of current node
+    param_dict = load_checkpoint(file_name)
+    load_param_into_net(net, param_dict)
+    param_dict = {}
+    for _, param in net.parameters_and_names():
+        param_dict[param.name] = param
+    param_dicts.append(param_dict)
 ```
 
 In the preceding information:
 
+- `rank_size`: number of nodes used in the previous distributed training.
 - `load_checkpoint`: loads the checkpoint model parameter file and returns a parameter dictionary.
 - `load_param_into_net`: loads model parameter data to the network.
-- `CKP_1-4_32.ckpt`: name of the saved checkpoint model parameter file.
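+
+For example, you can check that every rank's checkpoint slice is present before loading (a minimal sketch; the `./node<i>` directory layout, the file name, and the `rank_size` value of 4 are assumptions carried over from the snippet above):
+
+```
+import os
+
+rank_size = 4  # assumed number of nodes in the previous distributed training
+for i in range(rank_size):
+    file_name = os.path.join("./node" + str(i), "CKP_1-4_32.ckpt")
+    if not os.path.isfile(file_name):
+        raise FileNotFoundError("Missing checkpoint slice: " + file_name)
+```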
-
-> If a new checkpoint file is directly saved in the training environment based on the current training data and the parameter values already exist on the network, skip this step and you do not need to import the checkpoint files.
 
-#### Obtaining a List of All Parameters on the Network
+#### Obtaining the Slice Strategy of the Model
 
-Call the `parameters_and_names` API to obtain all parameter data on the network.
+Call the `build_searched_strategy` API to obtain the slice strategy of the model.
 
 ```
-param_dict = {}
-for _, param in net.parameters_and_names():
-    param_dict[param.name] = param
+strategy = build_searched_strategy("./strategy_train.ckpt")
 ```
 
-### Integrate the Model Parallel Parameters
+In the preceding information:
 
-The following uses a model parameter as an example to describe a specific integration process.
+- `strategy_train.ckpt`: name of the model slice strategy file. It is generated before training when the user calls the `set_auto_parallel_context` API and sets the `strategy_ckpt_save_file` parameter. The strategy file saved on each node is the same.
 
-The parameter name is model\_parallel\_weight and the data is Tensor \[\[1, 2, 3, 4], \[5, 6, 7, 8]].
+### Integrate the Model Parallel Parameters
 
-The dividing strategy is to perform dividing in a 4-device scenario based on \[2, 2]. That is, the data is first divided into two slices in the row dimension, then the two slices are respectively divided into two smaller slices in the column dimension, and finally four slices are obtained. Data distribution after dividing is as follows:
+The following uses a model parameter as an example to describe the integration process.
 
-| Device0      | Device1      | Device2      | Device3      |
-|--------------|--------------|--------------|--------------|
-| Value [1, 2] | Value [3, 4] | Value [5, 6] | Value [7, 8] |
+The parameter name is model\_parallel\_weight, and the parameter is divided across four devices according to the slice strategy.
 
-1. Obtain the data value on the current node for model parallel parameters.
+1. Obtain the parameter slices of a model parallel parameter on all nodes.
 
    ```
-   param_data = param_dict[“model_parallel_weight”]
-   param_data_moments = param_dict[“moments.model_parallel_weight”]
+    sliced_parameters = []
+    for i in range(4):
+        parameter = param_dicts[i].get("model_parallel_weight")
+        sliced_parameters.append(parameter)
   ```
 
   > To ensure that the parameter update speed remains unchanged, you need to integrate the parameters saved in the optimizer, for example, moments.model\_parallel\_weight.
 
-2. Define, instantiate, and execute the `AllGather` Cell, and obtain data on all devices.
-
-   ```
-   from mindspore.nn.cell import Cell
-   from mindspore.ops.operations.comm_ops import AllGather
-
-   class AllGatherCell(Cell):
-       """
-       Allgather cell, used in model parallel scenario.
-       To allgather the selected parameter slice from each device.
-       """
-       def __init__(self):
-           super(AllGatherCell, self).__init__(auto_prefix=False)
-           self.allgather = AllGather()
-
-       def construct(self, x):
-           x = self.allgather(x)
-           return x
-
-   allgather_net = AllGatherCell()
-   param_data = allgather_net(param_data)
-   param_data_moments = allgather_net(param_data_moments)
-   ```
-
-   The value of `param_data` is the integration of data on each device in dimension 0. The data value is \[\[1, 2], \[3, 4], \[5, 6], \[7, 8]], and the shape is \[4, 2]. The raw data value of `param_data` is \[\[1, 2, 3, 4], \[5, 6, 7, 8]], and the shape is \[2, 4]. The data needs to be redivided and integrated.
 
-3. Divide the data obtained from `AllGather`.
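+
+    For example, the optimizer copy of the parameter can be collected and merged the same way (a sketch mirroring the loop above and the merge call in step 2; the names `sliced_moments` and `merged_moments` are introduced here for illustration):
+
+    ```
+    sliced_moments = []
+    for i in range(4):
+        sliced_moments.append(param_dicts[i].get("moments.model_parallel_weight"))
+    merged_moments = merge_sliced_parameter(sliced_moments, strategy)
+    ```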
-
-   ```
-   slice_list = np.split(param_data.asnumpy(), 4, axis=0)   # 4: group_size, number of nodes in cluster
-   slice_lis_moments = np.split(param_data_moments.asnumpy(), 4, axis=0)  # 4: group_size, number of nodes in cluster
-   ```
-
-   The result of `param_data` is as follows:
-
-       slice_list[0] --- [1, 2]    Slice data on device0
-       slice_list[1] --- [3, 4]    Slice data on device1
-       slice_list[2] --- [5, 6]    Slice data on device2
-       slice_list[3] --- [7, 8]    Slice data on device3
-
-4. Reassemble data based on the site requirements.
-
-   In the following code, slice 1 and slice 2, slice 3 and slice 4 are first spliced by column, and then the obtained data is spliced by row.
-
-   ```
-   slice_line1 = np.concatenate((slice_list[0], slice_list[1]), axis=1) # result [1,2,3,4]
-   slice_line2 = np.concatenate((slice_list[2], slice_list[3]), axis=1) # result [5,6,7,8]
-   whole_data = np.concatenate((slice_line1, slice_line2), axis=0) # result [[1, 2, 3, 4], [5, 6, 7, 8]]
+2. Call the `merge_sliced_parameter` API to merge the sliced parameters.
 
-   slice_moments_line1 = np.concatenate((slice_lis_moments[0], slice_lis_moments[1]), axis=1)
-   slice_moments_line2 = np.concatenate((slice_lis_moments[2], slice_lis_moments[3]), axis=1)
-   whole_moments_data = np.concatenate((slice_moments_line1, slice_moments_line2), axis=0)
    ```
-
-5. Assign values to model parameters.
-
+    merged_parameter = merge_sliced_parameter(sliced_parameters, strategy)
    ```
-   param_data = Tensor(whole_data)
-   param_data_moments = Tensor(whole_moments_data)
-   ```
-
-> 1. If there are multiple model parallel parameters, repeat steps 1 to 5 to process them one by one.
-> 2. If the data obtained in step 2 is the final data, skip the following steps. That is, the dividing strategy is to perform dividing only on shape0 and each device loads different slice data.
+
+> If there are multiple model parallel parameters, repeat steps 1 and 2 to process them one by one.
 
 ### Saving the Data and Generating a New Checkpoint File
 
@@ -324,106 +271,89 @@ User process:
 
    ```
-    python ./integrate_checkpoint.py "Path and name of the checkpoint file to be integrated" "Path and name of the checkpoint file generated after integration"
+    python ./integrate_checkpoint.py "Name of the checkpoint file to be integrated" "Path and name of the checkpoint file generated after integration" "Path and name of the strategy file" "Number of nodes"
   ```
 
  integrate\_checkpoint.py:
 
   ```
-    import numpy as np
-    import os
-    import mindspore.nn as nn
-    from mindspore import context
-    from mindspore import Tensor, Parameter
-    from mindspore.ops import operations as P
-    from mindspore.ops.operations.comm_ops import AllGather
-    from mindspore.communication.management import init
-    from mindspore.train.serialization import save_checkpoint, load_checkpoint
-    devid = int(os.getenv('DEVICE_ID'))
-    context.set_context(mode=context.GRAPH_MODE, device_target='Ascend', save_graphs=True, device_id=devid)
-    init()
-
-    class Net(nn.Cell):
-        def __init__(self,weight_init):
-            super(Net, self).__init__()
-            self.weight = Parameter(Tensor(weight_init), "model_parallel_weight", layerwise_parallel=True)
-            self.fc = P.MatMul(transpose_b=True)
-
-        def construct(self, x):
-            x = self.fc(x, self.weight1)
-            return x
-
-    class AllGatherNet(Cell):
-        """
-        Allgather cell, used in model parallel scenario.
-        To allgather the selected parameter slice from each device.
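+    # Note (sketch): as described above, this script assumes that the checkpoint
+    # slice of rank i is stored as ./node<i>/<old_ckpt_file>, and that the strategy
+    # file was saved during training through the strategy_ckpt_save_file argument
+    # of the set_auto_parallel_context API.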
-        """
-        def __init__(self):
-            super().__init__()
-            self.allgather = AllGather()
-
-        def construct(self, x):
-            x = self.allgather(x)
-            return x
-
-    def integrate_ckpt_file(old_ckpt_file, new_ckpt_file):
-        weight = np.ones([2, 8]).astype(np.float32)
-        net = Net(weight)
-        opt = Momentum(learning_rate=0.01, momentum=0.9, params=net.get_parameters())
-        net = TrainOneStepCell(net, opt)
-
-        # load CheckPoint into net
-        param_dict = load_checkpoint(old_ckpt_file)
-        load_param_into_net(net, param_dict)
-        param_dict = {}
-        for _, param in net.parameters_and_names():
-            param_dict[param.name] = param
-
-        for paramname in ["model_parallel_weight", "moments.model_parallel_weight"]:
-            # get layer wise model parallel parameter
-            layerwise_param = param_dict[paramname]
-            if isinstance(layerwise_param.data, Tensor):
-                param_data = layerwise_param.data
-            else:
-                param_data = Tensor(layerwise_param.data)
-            # merge the parallel parameters of the model
-            allgather_net = get_allgather_cell()
-            param_data = allgather_net(param_data)
-            layerwise_param.set_parameter_data(param_data, True)
-
-        # convert param_dict to list type data
-        param_list = []
-        for (key, value) in param_dict.items():
-            each_param = {}
-            each_param["name"] = key
-            if isinstance(value.data, Tensor):
-                param_data = value.data
-            else:
-                param_data = Tensor(value.data)
-            each_param["data"] = param_data
-            param_list.append(each_param)
-
-        # call the API to generate a new CheckPoint file
-        save_checkpoint(param_list, new_ckpt_file)
-
-        return
-
-    if __name__ == "__main__":
-        try:
-            old_ckpt_file = sys.argv[1]
-            new_ckpt_file = sys.argv[2]
-            integrate(old_ckpt_file, new_ckpt_file)
-        except:
-            print("Fail to integrate checkpoint file)
-            sys.exit(-1)
+    import sys
+    import numpy as np
+    import os
+    import mindspore.nn as nn
+    from mindspore.nn import Momentum, TrainOneStepCell
+    from mindspore import Tensor, Parameter
+    from mindspore.ops import operations as P
+    from mindspore.train.serialization import save_checkpoint, load_checkpoint, load_param_into_net, build_searched_strategy, merge_sliced_parameter
+
+    class Net(nn.Cell):
+        def __init__(self,weight_init):
+            super(Net, self).__init__()
+            self.weight = Parameter(Tensor(weight_init), "model_parallel_weight", layerwise_parallel=True)
+            self.fc = P.MatMul(transpose_b=True)
+
+        def construct(self, x):
+            x = self.fc(x, self.weight)
+            return x
+
+    def integrate_ckpt_file(old_ckpt_file, new_ckpt_file, strategy_file, rank_size):
+        weight = np.ones([2, 8]).astype(np.float32)
+        net = Net(weight)
+        opt = Momentum(learning_rate=0.01, momentum=0.9, params=net.get_parameters())
+        net = TrainOneStepCell(net, opt)
+
+        # load CheckPoint into net in rank id order
+        param_dicts = []
+        for i in range(rank_size):
+            file_name = os.path.join("./node"+str(i), old_ckpt_file)
+            param_dict = load_checkpoint(file_name)
+            load_param_into_net(net, param_dict)
+            param_dict = {}
+            for _, param in net.parameters_and_names():
+                param_dict[param.name] = param
+            param_dicts.append(param_dict)
+
+        strategy = build_searched_strategy(strategy_file)
+        param_dict = {}
+
+        for paramname in ["model_parallel_weight", "moments.model_parallel_weight"]:
+            # get layer wise model parallel parameter
+            sliced_parameters = []
+            for i in range(rank_size):
+                parameter = param_dicts[i].get(paramname)
+                sliced_parameters.append(parameter)
+
+            # merge the parallel parameters of the model
+            merged_parameter = merge_sliced_parameter(sliced_parameters, strategy)
+            param_dict[paramname] = merged_parameter
+
+        # convert param_dict to list type data
+        param_list = []
+        for (key, value) in param_dict.items():
+            each_param = {}
+            each_param["name"] = key
+            if isinstance(value.data, Tensor):
+                param_data = value.data
+            else:
+                param_data = Tensor(value.data)
+            each_param["data"] = param_data
+            param_list.append(each_param)
+
+        # call the API to generate a new CheckPoint file
+        save_checkpoint(param_list, new_ckpt_file)
+
+        return
+
+    if __name__ == "__main__":
+        try:
+            old_ckpt_file = sys.argv[1]
+            new_ckpt_file = sys.argv[2]
+            strategy_file = sys.argv[3]
+            rank_size = int(sys.argv[4])
+            integrate_ckpt_file(old_ckpt_file, new_ckpt_file, strategy_file, rank_size)
+        except:
+            print("Fail to integrate checkpoint file")
+            sys.exit(-1)
    ```
 
-    In the preceding information:
-
-    - `mode=context.GRAPH_MODE`: sets the running mode to graph mode for distributed training. (The PyNative mode does not support parallel running.)
-    - `device_id`: physical sequence number of a device, that is, the actual sequence number of the device on a computer where the device is located.
-    - `init`: completes the distributed training initialization.
-
    The command output is as follows.
 
    Before the script is executed, the parameter values in the checkpoint files are as follows:
 
@@ -523,6 +453,7 @@
     import os
     import mindspore.nn as nn
     from mindspore import context
+    from mindspore.communication.management import init
     from mindspore import Tensor, Parameter
     from mindspore.ops import operations as P
     from mindspore.train.serialization import load_checkpoint, load_param_into_net
@@ -570,6 +501,12 @@
        label = np.random.random((4, 4)).astype(np.float32)
        train_mindspore_impl_fc(input, label, weight1)
    ```
+
+    In the preceding information:
+
+    - `mode=context.GRAPH_MODE`: sets the running mode to graph mode for distributed training. (The PyNative mode does not support parallel running.)
+    - `device_id`: physical sequence number of a device, that is, the actual sequence number of the device on a computer where the device is located.
+    - `init`: completes the distributed training initialization.
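+
+    For example, these items correspond to the following setup at the top of the loading script (a minimal sketch; the `DEVICE_ID` environment variable is assumed to be set by the launch environment):
+
+    ```
+    devid = int(os.getenv('DEVICE_ID'))
+    context.set_context(mode=context.GRAPH_MODE, device_target='Ascend', device_id=devid)
+    init()
+    ```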
 
 Parameter values after loading:
 
diff --git a/tutorials/source_en/advanced_use/images/checkpoint_integration_process.jpg b/tutorials/source_en/advanced_use/images/checkpoint_integration_process.jpg
index a3d190897587c64d027a3284edade1850e0bce2f..e344782ec9ed7df74f21e1c7c3fc4643b29eebbf 100644
Binary files a/tutorials/source_en/advanced_use/images/checkpoint_integration_process.jpg and b/tutorials/source_en/advanced_use/images/checkpoint_integration_process.jpg differ
diff --git a/tutorials/source_zh_cn/advanced_use/checkpoint_for_hybrid_parallel.md b/tutorials/source_zh_cn/advanced_use/checkpoint_for_hybrid_parallel.md
index 9ece21de9af2400b2aa2f30025b4f86fc3677d75..8a7853a27c8e1e3bf773c13e471e919027b98208 100644
--- a/tutorials/source_zh_cn/advanced_use/checkpoint_for_hybrid_parallel.md
+++ b/tutorials/source_zh_cn/advanced_use/checkpoint_for_hybrid_parallel.md
@@ -11,8 +11,8 @@
     - [对保存的CheckPoint文件做合并处理](#对保存的checkpoint文件做合并处理)
         - [整体流程](#整体流程)
        - [准备工作](#准备工作)
-            - [导入CheckPoint文件到网络](#导入checkpoint文件到网络)
-            - [获取网络中全量参数列表](#获取网络中全量参数列表)
+            - [按逻辑顺序导入CheckPoint文件](#按逻辑顺序导入checkpoint文件)
+            - [获取模型参数切分策略](#获取模型参数切分策略)
         - [对模型并行的参数做合并处理](#对模型并行的参数做合并处理)
         - [保存数据生成新的CheckPoint文件](#保存数据生成新的checkpoint文件)
     - [加载合并保存的CheckPoint文件](#加载合并保存的checkpoint文件)
@@ -79,7 +79,7 @@ MindSpore模型并行场景下,每个实例进程只保存有本节点对应
 
 ### 整体流程
 
-首先,执行准备工作,将待合并处理的CheckPoint文件导入网络,并通过MindSpore提供的API获取全量参数列表。对应下图中的Step1和Step2。
+首先,执行准备工作,按逻辑顺序将待合并处理的CheckPoint文件导入网络,获取模型全量参数并添加至列表中,再获取模型参数切分策略。对应下图中的Step1和Step2。
 
 其次,更新参数列表,对涉及模型并行的参数做合并处理。对应下图中的Step3。
 
@@ -89,119 +89,64 @@ MindSpore模型并行场景下,每个实例进程只保存有本节点对应
 
 ### 准备工作
 
-#### 导入CheckPoint文件到网络
+#### 按逻辑顺序导入CheckPoint文件
 
-定义网络,并调用`load_checkpoint`、`load_param_into_net`接口,将CheckPoint文件导入网络。
+定义网络,调用`load_checkpoint`、`load_param_into_net`接口,按逻辑顺序将CheckPoint文件导入网络,之后调用`parameters_and_names`接口获取网络里所有的参数数据。
 
 ```
-param_dict = load_checkpoint(./CKP_1-4_32.ckpt) # checkpoint file name
 net = Net()
 opt = Momentum(learning_rate=0.01, momentum=0.9, params=net.get_parameters())
 net = TrainOneStepCell(net, opt)
-load_param_into_net(net, param_dict)
+param_dicts = []
+for i in range(rank_size):
+    file_name = os.path.join("./node"+str(i), "CKP_1-4_32.ckpt") # checkpoint file name of current node
+    param_dict = load_checkpoint(file_name)
+    load_param_into_net(net, param_dict)
+    param_dict = {}
+    for _, param in net.parameters_and_names():
+        param_dict[param.name] = param
+    param_dicts.append(param_dict)
 ```
 
 其中,
 
+- `rank_size`:之前分布式训练的节点数。
 - `load_checkpoint`:通过该接口加载CheckPoint模型参数文件,返回一个参数字典。
 - `load_param_into_net`:模型参数数据加载到网络中。
-- `CKP_1-4_32.ckpt`:之前保存的CheckPoint模型参数文件名称。
-
-> 如果直接在训练环境上,基于当前训练得到的数据直接保存新的CheckPoint文件,参数值已经存在在网络中,则可以省略该步骤,无需导入CheckPoint文件。
 
-#### 获取网络中全量参数列表
+#### 获取模型参数切分策略
 
-调用`parameters_and_names`接口,获取网络里所有的参数数据。
+调用`build_searched_strategy`接口,得到模型各个参数的切分策略。
 
 ```
-param_dict = {}
-for _, param in net.parameters_and_names():
-    param_dict[param.name] = param
+strategy = build_searched_strategy("./strategy_train.ckpt")
 ```
 
-### 对模型并行的参数做合并处理
+其中,
 
-下面以一个具体的模型参数为例,说明下参数合并处理的具体流程。
+- `strategy_train.ckpt`:保存的模型参数切分策略文件名称,训练网络之前由用户调用`set_auto_parallel_context`接口自定义`strategy_ckpt_save_file`参数生成,各个节点上保存的策略文件相同。
 
-参数名称为"model_parallel_weight",数据为Tensor [[1, 2, 3, 4], [5, 6, 7, 8]]。
+### 对模型并行的参数做合并处理
 
-切分逻辑为4卡场景,按[2, 2]切分,即先在行维度切分为2个切片,之后再对得到的2个切片,分别在列维度分再切分为2个更小的切片,最后得到4个切片。
-切分后数据分布情况如下:
+下面以一个具体的模型参数为例,说明下参数合并处理的具体流程。
 
-| Device0       | Device1      | Device2       | Device3       |
-| ------------- | ------------ | ------------- | ------------- |
-| Value [1, 2]  | Value [3, 4] | Value [5, 6]  | Value [7, 8]  |
 
+参数名称为"model_parallel_weight",切分逻辑为4卡场景。
 
-1. 针对涉及模型并行的参数,获取本节点上的数据值。
+1. 针对涉及模型并行的参数,获取所有节点上的参数数据。
 
    ```
-   param_data = param_dict[“model_parallel_weight”]
-   param_data_moments = param_dict[“moments.model_parallel_weight”]
+    sliced_parameters = []
+    for i in range(4):
+        parameter = param_dicts[i].get("model_parallel_weight")
+        sliced_parameters.append(parameter)
   ```
 
  > 如果要保证参数更新速度不变,需要对优化器中保存的参数,如“moments.model_parallel_weight”,同样做合并处理。
 
-2. 定义`AllGather`类型子图,并实例化和执行,获取所有卡上的数据。
-
-   ```
-   from mindspore.nn.cell import Cell
-   from mindspore.ops.operations.comm_ops import AllGather
-
-   class AllGatherCell(Cell):
-       """
-       Allgather cell, used in model parallel scenario.
-       To allgather the selected parameter slice from each device.
-       """
-       def __init__(self):
-           super(AllGatherCell, self).__init__(auto_prefix=False)
-           self.allgather = AllGather()
-
-       def construct(self, x):
-           x = self.allgather(x)
-           return x
-
-   allgather_net = AllGatherCell()
-   param_data = allgather_net(param_data)
-   param_data_moments = allgather_net(param_data_moments)
-   ```
-
-   得到的数据`param_data`为每卡上的数据在维度0上的合并,数据值为 [[1, 2], [3, 4], [5, 6], [7, 8]],shape为[4, 2]。
-   `param_data`原始数据值为[[1, 2, 3, 4], [5, 6, 7, 8]],shape为[2, 4],需要对数据重新切分合并。
-
-3. 切分通过`AllGather`得到的数据。
-
-   ```
-   slice_list = np.split(param_data.asnumpy(), 4, axis=0)   # 4: group_size, number of nodes in cluster
-   slice_lis_moments = np.split(param_data_moments.asnumpy(), 4, axis=0)  # 4: group_size, number of nodes in cluster
-   ```
-
-   得到结果`param_data`为:
-
-       slice_list[0] --- [1, 2]    device0上的切片数据
-       slice_list[1] --- [3, 4]    device1上的切片数据
-       slice_list[2] --- [5, 6]    device2上的切片数据
-       slice_list[3] --- [7, 8]    device3上的切片数据
-
-4. 按照实际情况,重新组装数据。
-
-   如下代码,先分别对切片1和切片2,切片3和切片4按列拼接,之后对前两步得到的数据按行拼接。
-
-   ```
-   slice_line1 = np.concatenate((slice_list[0], slice_list[1]), axis=1) # result [1,2,3,4]
-   slice_line2 = np.concatenate((slice_list[2], slice_list[3]), axis=1) # result [5,6,7,8]
-   whole_data = np.concatenate((slice_line1, slice_line2), axis=0) # result [[1, 2, 3, 4], [5, 6, 7, 8]]
-
-   slice_moments_line1 = np.concatenate((slice_lis_moments[0], slice_lis_moments[1]), axis=1)
-   slice_moments_line2 = np.concatenate((slice_lis_moments[2], slice_lis_moments[3]), axis=1)
-   whole_moments_data = np.concatenate((slice_moments_line1, slice_moments_line2), axis=0)
-   ```
-
-5. 对模型参数赋值。
+2. 调用`merge_sliced_parameter`接口进行参数合并。
 
    ```
-   param_data = Tensor(whole_data)
-   param_data_moments = Tensor(whole_moments_data)
+    merged_parameter = merge_sliced_parameter(sliced_parameters, strategy)
   ```
 
-> 1. 如果存在多个模型并行的参数,则需要重复步骤1到步骤5循环逐个处理。
-> 2. 如果步骤2执行`allgather`子图获取的数据,已经是最终的数据,则后面的步骤可省略。
-> 即本身切分逻辑是仅在shape0上切分,每个卡加载不同切片数据。
+> 如果存在多个模型并行的参数,则需要重复步骤1和步骤2循环逐个处理。
 
 ### 保存数据生成新的CheckPoint文件
 
@@ -327,7 +272,7 @@ load_param_into_net(opt, param_dict)
 脚本执行命令:
 
    ```
-    python ./integrate_checkpoint.py "待合并的CheckPoint文件路径&名称" "合并生成的CheckPoint文件路径&名称"
+    python ./integrate_checkpoint.py "待合并的CheckPoint文件名称" "合并生成的CheckPoint文件路径&名称" "策略文件路径&名称" "节点数"
   ```
 
 integrate_checkpoint.py:
 
@@ -336,15 +281,9 @@ load_param_into_net(opt, param_dict)
+    import sys
     import numpy as np
     import os
     import mindspore.nn as nn
+    from mindspore.nn import Momentum, TrainOneStepCell
-    from mindspore import context
     from mindspore import Tensor, Parameter
     from mindspore.ops import operations as P
-    from mindspore.ops.operations.comm_ops import AllGather
-    from mindspore.communication.management import init
-    from mindspore.train.serialization import save_checkpoint, load_checkpoint
-    devid = int(os.getenv('DEVICE_ID'))
-    context.set_context(mode=context.GRAPH_MODE, device_target='Ascend', save_graphs=True, device_id=devid)
-    init()
+    from mindspore.train.serialization import save_checkpoint, load_checkpoint, load_param_into_net, build_searched_strategy, merge_sliced_parameter
 
     class Net(nn.Cell):
         def __init__(self,weight_init):
@@ -356,43 +295,36 @@ load_param_into_net(opt, param_dict)
-            x = self.fc(x, self.weight1)
+            x = self.fc(x, self.weight)
             return x
 
-    class AllGatherNet(Cell):
-        """
-        Allgather cell, used in model parallel scenario.
-        To allgather the selected parameter slice from each device.
-        """
-        def __init__(self):
-            super().__init__()
-            self.allgather = AllGather()
-
-        def construct(self, x):
-            x = self.allgather(x)
-            return x
-
-    def integrate_ckpt_file(old_ckpt_file, new_ckpt_file):
+    def integrate_ckpt_file(old_ckpt_file, new_ckpt_file, strategy_file, rank_size):
         weight = np.ones([2, 8]).astype(np.float32)
         net = Net(weight)
         opt = Momentum(learning_rate=0.01, momentum=0.9, params=net.get_parameters())
         net = TrainOneStepCell(net, opt)
 
-        # load CheckPoint into net
-        param_dict = load_checkpoint(old_ckpt_file)
-        load_param_into_net(net, param_dict)
+        # load CheckPoint into net in rank id order
+        param_dicts = []
+        for i in range(rank_size):
+            file_name = os.path.join("./node"+str(i), old_ckpt_file)
+            param_dict = load_checkpoint(file_name)
+            load_param_into_net(net, param_dict)
+            param_dict = {}
+            for _, param in net.parameters_and_names():
+                param_dict[param.name] = param
+            param_dicts.append(param_dict)
+
+        strategy = build_searched_strategy(strategy_file)
         param_dict = {}
-        for _, param in net.parameters_and_names():
-            param_dict[param.name] = param
-
+
         for paramname in ["model_parallel_weight", "moments.model_parallel_weight"]:
             # get layer wise model parallel parameter
-            layerwise_param = param_dict[paramname]
-            if isinstance(layerwise_param.data, Tensor):
-                param_data = layerwise_param.data
-            else:
-                param_data = Tensor(layerwise_param.data)
+            sliced_parameters = []
+            for i in range(rank_size):
+                parameter = param_dicts[i].get(paramname)
+                sliced_parameters.append(parameter)
+
             # merge the parallel parameters of the model
-            allgather_net = get_allgather_cell()
-            param_data = allgather_net(param_data)
-            layerwise_param.set_parameter_data(param_data, True)
+            merged_parameter = merge_sliced_parameter(sliced_parameters, strategy)
+            param_dict[paramname] = merged_parameter
 
         # convert param_dict to list type data
         param_list = []
@@ -415,18 +347,14 @@ load_param_into_net(opt, param_dict)
         try:
             old_ckpt_file = sys.argv[1]
             new_ckpt_file = sys.argv[2]
-            integrate(old_ckpt_file, new_ckpt_file)
+            strategy_file = sys.argv[3]
+            rank_size = int(sys.argv[4])
+            integrate_ckpt_file(old_ckpt_file, new_ckpt_file, strategy_file, rank_size)
         except:
-            print("Fail to integrate checkpoint file)
+            print("Fail to integrate checkpoint file")
             sys.exit(-1)
    ```
-
-    其中,
-
-    - `mode=context.GRAPH_MODE`:使用分布式训练需要指定运行模式为图模式(PyNative模式不支持并行)。
-    - `device_id`:卡物理序号,即卡所在机器中的实际序号。
-    - `init`:完成分布式训练初始化操作。
-
+
    执行结果:
 
    脚本执行前,CheckPoint文件中参数值:
@@ -526,6 +454,7 @@ load_param_into_net(opt, param_dict)
     import os
     import mindspore.nn as nn
     from mindspore import context
+    from mindspore.communication.management import init
     from mindspore import Tensor, Parameter
     from mindspore.ops import operations as P
     from mindspore.train.serialization import load_checkpoint, load_param_into_net
@@ -573,6 +502,12 @@ load_param_into_net(opt, param_dict)
        label = np.random.random((4, 4)).astype(np.float32)
        train_mindspore_impl_fc(input, label, weight1)
    ```
+
+    其中,
+
+    - `mode=context.GRAPH_MODE`:使用分布式训练需要指定运行模式为图模式(PyNative模式不支持并行)。
+    - `device_id`:卡物理序号,即卡所在机器中的实际序号。
+    - `init`:完成分布式训练初始化操作。
 
 加载后的参数值:
 
diff --git a/tutorials/source_zh_cn/advanced_use/images/checkpoint_integration_process.jpg b/tutorials/source_zh_cn/advanced_use/images/checkpoint_integration_process.jpg
index 518bfaaed758e283631025e6bcd49e9facae7cb5..7766e9aa32e79b823021de6d815e407c9f60e91b 100644
Binary files a/tutorials/source_zh_cn/advanced_use/images/checkpoint_integration_process.jpg and b/tutorials/source_zh_cn/advanced_use/images/checkpoint_integration_process.jpg differ