
# Debugging by Using IR Graph

`Linux` `Ascend` `GPU` `Model Development` `Beginner` `Intermediate` `Expert`

<!-- TOC -->

- [Debugging by Using IR Graph](#debugging-by-using-ir-graph)
    - [Overview](#overview)
    - [Generating IR Files](#generating-ir-files)
    - [Introduction to IR File Contents](#introduction-to-ir-file-contents)
    - [Dumping the Wanted Data from IR Files](#dumping-the-wanted-data-from-ir-files)

<!-- /TOC -->

## Overview

When a model developed with MindSpore runs in graph mode `context.set_context(mode=context.GRAPH_MODE)` and `context.set_context(save_graphs=True)` is set, some intermediate files are generated during graph compilation. They are called IR files. There are three main kinds of IR files:

- IR file with the `ir` suffix: a human-readable text file that describes the model intuitively and can be opened directly with a text editor. How to read this type of file is introduced below.
- IR file with the `dat` suffix: compared with the `ir` file, its format definition is more compact and its content is richer. It can also be viewed directly with a text editor.
- IR file with the `dot` suffix: describes the topological relationships between nodes. It can be passed to [graphviz](http://graphviz.org) to generate a picture, so that users can easily view the model structure. For models with many operators, it is recommended to use the visualization component [MindInsight](https://www.mindspore.cn/tutorial/training/zh-CN/master/advanced_use/dashboard.html#id5) to view the computational graph instead.

In this tutorial, we use LeNet from ModelZoo in an Ascend environment as the example. The related scripts can be found in [ModelZoo/LeNet](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/lenet).
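For the `dot` files, a picture can be generated with the graphviz command-line tool. A minimal sketch, assuming graphviz is installed and using `00_parse.dot` as a placeholder name for one of the generated files:

```bash
# Render a dot-suffix IR file into a PNG picture with graphviz.
# "00_parse.dot" is a placeholder; substitute any generated .dot file.
dot -Tpng 00_parse.dot -o 00_parse.png
```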

## Generating IR Files

In the file `train.py`, we can add the following code to the `set_context` call. When the training script runs, MindSpore automatically saves the IR files produced during compilation to the given path.

```python
if __name__ == "__main__":
    context.set_context(save_graphs=True, save_graphs_path="path/to/ir/files")
```

In this tutorial, we run the single-device version of the training script. When the script runs on several computing devices, MindSpore generates a separate process for each device. It is then recommended to read the current device ID in the multi-device training script and set a distinct `save_graphs_path` for each device, so that the IR files of each device are saved to their own path. For example:

```python
import os

# DEVICE_ID may be unset in single-device runs, so fall back to "0".
device_id = os.getenv("DEVICE_ID", "0")
context.set_context(save_graphs=True, save_graphs_path="path/to/ir/files" + device_id)
```

After running the training command, the following files are created in the given path. The IR files whose names start with a number and an underscore are produced during ME graph compilation: each stage of the `pipeline` saves the computational graph once. The more important stages are:

- `parse`: parses the entry `construct` function.
- `symbol_resolve`: recursively resolves other functions and objects referenced directly or indirectly by the entry function.
- `abstract_specialize`: infers the data type and `shape` of each node.
- `optimize`: performs hardware-independent optimizations, as well as automatic differentiation and automatic parallelism.
- `validate`: validates the compiled computational graph.
- `task_emit`: passes the computational graph to the backend for further processing.
- `execute`: executes the computational graph.

```bash
.
├── 00_parse_[xxxx].ir
├── 00_parse.dat
├── 00_parse.dot
├── 01_symbol_resolve_[xxxx].ir
├── 01_symbol_resolve.dat
├── 01_symbol_resolve.dot
├── 02_combine_like_graphs_[xxxx].ir
├── 02_combine_like_graphs.dat
├── 02_combine_like_graphs.dot
├── 03_inference_opt_prepare_[xxxx].ir
├── 03_inference_opt_prepare.dat
├── 03_inference_opt_prepare.dot
├── 04_abstract_specialize_[xxxx].ir
├── 04_abstract_specialize.dat
├── 04_abstract_specialize.dot
├── 05_inline_[xxxx].ir
├── 05_inline.dat
├── 05_inline.dot
├── 06_py_pre_ad_[xxxx].ir
├── 06_py_pre_ad.dat
├── 06_py_pre_ad.dot
├── 07_pipeline_split_[xxxx].ir
├── 07_pipeline_split.dat
├── 07_pipeline_split.dot
├── 08_optimize_[xxxx].ir
├── 08_optimize.dat
├── 08_optimize.dot
├── 09_py_opt_[xxxx].ir
├── 09_py_opt.dat
├── 09_py_opt.dot
├── 10_validate_[xxxx].ir
├── 10_validate.dat
├── 10_validate.dot
├── 11_task_emit_[xxxx].ir
├── 11_task_emit.dat
├── 11_task_emit.dot
├── 12_execute_[xxxx].ir
├── 12_execute.dat
├── 12_execute.dot
...
```
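To check at which compilation stage a given operator first appears, one simple approach is to search across the staged `.ir` files. A minimal sketch, assuming the IR files were saved under `path/to/ir/files` and using `Conv2D` as a placeholder operator name:

```bash
# List the staged .ir files that mention a given operator name;
# the numeric prefix of each file name indicates the pipeline stage.
grep -l "Conv2D" path/to/ir/files/*.ir | sort
```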

## Introduction to IR File Contents

The following simple example illustrates the contents of an IR file.

```python
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import dtype as mstype

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_context(save_graphs=True, save_graphs_path="./ir_files")

class Net(nn.Cell):
    def __init__(self):
        super().__init__()

    def construct(self, x, y):
        x = x + y
        x = x * y
        return x

x = Tensor(3, mstype.float32)
y = Tensor(2, mstype.float32)
net = Net()
out = net(x, y)
print(out)
```

Open the file `12_execute_[xxxx].ir` with a text editor (e.g. vi). Its contents are as follows:

```text
  1 #IR entry      : @6_5_1_construct_wrapper.15
  2 #attrs         :
  3 check_set_strategy_valid_once_only : 1
  4 #Total params  : 2
  5
  6 %para1_x : <Tensor[Float32]x[]>
  7 %para2_y : <Tensor[Float32]x[]>
  8
  9 #Total subgraph : 1
 10
 11 subgraph attr:
 12 check_set_strategy_valid_once_only : 1
 13 subgraph @6_5_1_construct_wrapper.15() {
 14   %0([CNode]8) = Add(%para1_x, %para2_y) primitive_attrs: {output_names: [output], input_names: [x, y]}
 15       : (<Tensor[Float32]x[]>, <Tensor[Float32]x[]>) -> (<Tensor[Float32]x[]>)
 16       # In file /home/workspace/mindspore/mindspore/ops/composite/multitype_ops/add_impl.py(129)/    return F.tensor_add(x, y)/
 17       # In file demo.py(14)/        x = x + y/
 18   %1([CNode]10) = Mul(%0, %para2_y) primitive_attrs: {output_names: [output], input_names: [x, y]}
 19       : (<Tensor[Float32]x[]>, <Tensor[Float32]x[]>) -> (<Tensor[Float32]x[]>)
 20       # In file /home/workspace/mindspore/mindspore/ops/composite/multitype_ops/mul_impl.py(48)/    return F.tensor_mul(x, y)/
 21       # In file demo.py(15)/        x = x * y/
 22   return(%1)
 23       : (<Tensor[Float32]x[]>)
 24 }
```

The content above can be divided into two parts: the first part is the input list and the second part is the graph structure. The first row tells us that the top graph of this network is `@6_5_1_construct_wrapper.15`, that is, the entry graph. The fourth row tells us how many inputs the network has. The 6th and 7th rows are the input list, in the format `%para[order number]_[name] : <[data_type]x[shape]>`. The 9th row gives the number of subgraphs. The 11th to 24th rows are the graph structure, which consists of several nodes, namely `CNode`s. There are only two nodes in this example: `Add` in the 14th row and `Mul` in the 18th row.

A `CNode` is described in the format below, which includes the node name, attributes, input nodes, output information, format, and the source code parsing call stack. Since the ANF graph is a directed acyclic graph, connections between nodes are represented only through input relationships. The source code parsing call stack shows the relationship between a `CNode` and the source code; for example, the 20th row is parsed from the 21st row, and the 21st row corresponds to `x = x * y` in the script.

```text
  %[order number]([debug_name]) = [OpName]([arg], ...) primitive_attrs: {[key]: [value], ...}
      : (<[input data_type]x[input shape]>, ...) -> (<[output data_type]x[output shape]>, ...)
      # source code parsing call stack
```

> Note: After some compiler optimizations (e.g. operator splitting or operator fusion), nodes may be transformed, so the source code parsing call stack may not match the nodes one-to-one. It is only an auxiliary means of locating code.
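To make the `CNode` format above concrete, here is a minimal sketch (not part of MindSpore; all names are our own) that scans an `.ir` file with a regular expression built from that format and prints the order number, debug name, and operator name of each `CNode`; the file name `12_execute_0000.ir` is a placeholder:

```python
import re

# Matches lines of the form: %0([CNode]8) = Add(%para1_x, %para2_y) ...
CNODE_PATTERN = re.compile(r"%(\d+)\(([^)]*)\) = (\w+)\(")

def list_cnodes(ir_path):
    """Return (order number, debug name, op name) for each CNode line."""
    nodes = []
    with open(ir_path) as f:
        for line in f:
            match = CNODE_PATTERN.search(line)
            if match:
                nodes.append((int(match.group(1)), match.group(2), match.group(3)))
    return nodes

# Placeholder file name; substitute an IR file generated on your machine.
for order, debug_name, op_name in list_cnodes("12_execute_0000.ir"):
    print(order, debug_name, op_name)
```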

## Dumping the Wanted Data from IR Files

The following code comes from `lenet.py` in the LeNet model of ModelZoo. Suppose we want to dump the data of the first convolutional layer, that is, `x = self.conv1(x)` in the code below.

```python
import mindspore.nn as nn
from mindspore.common.initializer import Normal

class LeNet5(nn.Cell):
    def __init__(self, num_class=10, num_channel=1, include_top=True):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.include_top = include_top
        if self.include_top:
            self.flatten = nn.Flatten()
            self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
            self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
            self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))

    def construct(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        if not self.include_top:
            return x
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x
```

Generally speaking, graph 0 on the backend represents the data subgraph (when dataset sink mode is enabled) and graph 1 represents the backbone network. So we can search for `x = self.conv1(x)` in the dumped file `hwopt_d_end_graph_1_[xxxx].ir`, which yields four results, three of which are `Cast` and `TransData` nodes. Skipping over these precision-conversion (`Cast`) and format-conversion (`TransData`) nodes, we finally locate rows 213 to 221, `%24(equivoutput) = Conv2D(%23, %19)...`, which correspond to `conv1` of the network. The full name of the operator in the compiled graph can then be read from the parentheses in row 216: `Default/network-TrainOneStepWithLossScaleCell/network-WithLossCell/_backbone-LeNet5/conv1-Conv2d/Conv2D-op89`.

```bash
...
 213 %24(equivoutput) = Conv2D(%23, %19) {instance name: conv2d} primitive_attrs: {pri_format: NC1HWC0, stride: (1, 1, 1, 1), pad: (0, 0, 0, 0), pad_mode: valid, out_channel: 6, mode: 1, dilation: (1, 1, 1, 1), output_names: [output], group: 1, format: NCHW, visited: true, offset_a: 0, kernel_size: (5, 5), groups: 1, input_names: [x, w], pad_list: (0, 0, 0, 0), IsFeatureMapOutput: true, IsFeatureMapInputList: (0)}
 214     : (<Tensor[Float16]x[32, 1, 32, 32]>, <Tensor[Float16]x[6, 1, 5, 5]>) -> (<Tensor[Float16]x[32, 6, 28, 28]>)
 215     : (<Float16xNC1HWC0[32, 1, 32, 32, 16]>, <Float16xFracZ[25, 1, 16, 16]>) -> (<Float16xNC1HWC0[32, 6, 28, 28, 16]>)
 216     : (Default/network-TrainOneStepWithLossScaleCell/network-WithLossCell/_backbone-LeNet5/conv1-Conv2d/Conv2D-op89)
 217     # In file /home/workspace/mindspore/build/package/mindspore/nn/layer/conv.py(263)/        output = self.conv2d(x, self.weight)/
 218     # In file /home/workspace/mindspore/model_zoo/official/cv/lenet/src/lenet.py(49)/        x = self.conv1(x)/
 219     # In file /home/workspace/mindspore/build/package/mindspore/train/amp.py(101)/            out = self._backbone(data)/
 220     # In file /home/workspace/mindspore/build/package/mindspore/nn/wrap/loss_scale.py(323)/        grads = self.grad(self.network, weights)(*inputs, scaling_sens_filled)/
 221     # In file /home/workspace/mindspore/build/package/mindspore/train/dataset_helper.py(87)/        return self.network(*outputs)/
...
```

After getting the full operator name, we can use the Dump function to save the input and output of the operator, which makes debugging more convenient. Here we introduce the method called synchronous Dump.

1. Create the configuration file `data_dump.json`. This file stores the information about the operators to dump. Copy the full operator name located in the previous step into the list under the `kernels` key.
    For more information about this file, refer to [custom debugging info](https://www.mindspore.cn/tutorial/training/zh-CN/master/advanced_use/custom_debugging_info.html#id5).

    ```json
    {
        "common_dump_settings": {
            "dump_mode": 1,
            "path": "/absolute_path",
            "net_name": "LeNet",
            "iteration": 0,
            "input_output": 0,
            "kernels": ["Default/network-TrainOneStepWithLossScaleCell/network-WithLossCell/_backbone-LeNet5/conv1-Conv2d/Conv2D-op89"],
            "support_device": [0,1,2,3,4,5,6,7]
        },
        "e2e_dump_settings": {
            "enable": true,
            "trans_flag": false
        }
    }
    ```

2. Set the environment variable that points to the configuration file.

    ```bash
    export MINDSPORE_DUMP_CONFIG={absolute path of data_dump.json}
    ```

3. Run the case to dump the data. MindSpore dumps the input and output data of the specified operator to the given path.

    In this case, we get the following files, which correspond to the two inputs and the output of the operator respectively.

    ```bash
    .
    ├── Default--network-TrainOneStepWithLossScaleCell--network-WithLossCell--_backbone-LeNet5--conv1-Conv2d--Conv2D-op89_input_0_shape_32_1_32_32_16_Float16_NC1HWC0.bin
    ├── Default--network-TrainOneStepWithLossScaleCell--network-WithLossCell--_backbone-LeNet5--conv1-Conv2d--Conv2D-op89_input_1_shape_25_1_16_16_Float16_FracZ.bin
    └── Default--network-TrainOneStepWithLossScaleCell--network-WithLossCell--_backbone-LeNet5--conv1-Conv2d--Conv2D-op89_output_0_shape_32_1_28_28_16_Float16_NC1HWC0.bin
    ```

4. Analyze the dumped data.

    We can read the file generated in the previous step with `numpy.fromfile`. The resulting `ndarray` is the input/output of the corresponding operator.

    ```python
    import numpy
    output = numpy.fromfile("Default--network-TrainOneStepWithLossScaleCell--network-WithLossCell--_backbone-LeNet5--conv1-Conv2d--Conv2D-op89_input_0_shape_32_1_32_32_16_Float16_NC1HWC0.bin")
    print(output)
    ```

    The output is:

    ```text
    [1.17707155e-17 4.07526143e-17 5.84038559e-18 ... 0.00000000e+00 0.00000000e+00 0.00000000e+00]
    ```
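    Note that because `trans_flag` is `false` in the configuration above, the dumped file keeps the device data type and format, which the file name records (`Float16`, `NC1HWC0`, shape `32_1_32_32_16`), while `numpy.fromfile` reads `float64` by default. A minimal sketch, assuming the file holds exactly the `Float16` elements its name describes, that interprets and reshapes the raw data:

    ```python
    import numpy

    # The file name encodes the dtype (Float16) and the NC1HWC0 device shape
    # (32, 1, 32, 32, 16); read with an explicit dtype and restore the shape.
    data = numpy.fromfile(
        "Default--network-TrainOneStepWithLossScaleCell--network-WithLossCell--_backbone-LeNet5--conv1-Conv2d--Conv2D-op89_input_0_shape_32_1_32_32_16_Float16_NC1HWC0.bin",
        dtype=numpy.float16,
    )
    print(data.reshape(32, 1, 32, 32, 16).shape)
    ```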