From ac99a0a0643655ba1de4c9273df351ba87417ac5 Mon Sep 17 00:00:00 2001
From: zhangyi
Date: Tue, 24 May 2022 15:24:47 +0800
Subject: [PATCH] modify the files

---
 .../source_en/design/auto_gradient.md       | 140 ++++++++++++------
 .../source_en/design/dataset_offload.md     |  54 +++----
 .../source_zh_cn/design/auto_gradient.ipynb |   4 +-
 3 files changed, 123 insertions(+), 75 deletions(-)

diff --git a/docs/mindspore/source_en/design/auto_gradient.md b/docs/mindspore/source_en/design/auto_gradient.md
index 9c3826d478..98b2c9fc03 100644
--- a/docs/mindspore/source_en/design/auto_gradient.md
+++ b/docs/mindspore/source_en/design/auto_gradient.md
@@ -1,65 +1,75 @@
-# MindSpore Automatic Differentiation
+# Functional Differential Programming

## Automatic Differentiation Overview

-Modern AI algorithm, such as deep learning, uses huge amount of data to train a model with parameters. This training process often uses loss back-propagation to update parameters. Automatic differentiation (AD) is one of the key techniques.
+Modern AI algorithms, such as deep learning, use a large amount of data to learn and fit an optimized model with parameters. This training process often uses loss back-propagation to update the parameters. **Automatic differentiation (AD)** is one of the key techniques.

-Automatic differentiation is a method between neumerical differentiation and symbolic differentiation. The key concept of AD is to divide the calculation of the computer program into a finite set with basic operations. The gradients of all the basic operations are known. After calculating the gradient of all the basic operations, AD uses chain rule to combine them and gets the final gradient.
+Automatic differentiation is a differentiation method between numerical differentiation and symbolic differentiation. The key concept of AD is to divide the calculation of the computer program into a finite set of basic operations. The derivatives of all the basic operations are known. After calculating the derivatives of all the basic operations, AD uses the chain rule to combine them and obtains the final gradient.

-The formula of chain rule is: $(f\circ g)^{'}(x)=f^{'}(g(x))g^{'}(x)$
+The formula of the chain rule is:
+$$
+(f\circ g)^{'}(x)=f^{'}(g(x))g^{'}(x) \tag{1}
+$$
+Based on how the gradients of the basic components are connected, AD can be divided into **forward mode AD** and **reverse mode AD**.

-Based on how to connect the gradient of basic components, AD can be divided into forward mode AD and reverse mode AD.

+- Forward Automatic Differentiation (also known as tangent linear mode AD), also called forward cumulative gradient (forward mode).
+- Reverse Automatic Differentiation (also known as adjoint mode AD), also called reverse cumulative gradient (reverse mode).

-For example, if we define function $f$
-$y=f(x_{1},x_{2})=ln(x_{1})+x_{1}x_{2}-sin(x_{2})$ and we want to use forward mode AD to calculate $\frac{\partial y}{\partial x_{1}}$ when $x_{1}=2,x_{2}=5$.
+Let's take formula (2) as an example to introduce the specific calculation methods of forward and reverse differentiation:
+$$
+y=f(x_{1},x_{2})=ln(x_{1})+x_{1}x_{2}-sin(x_{2}) \tag{2}
+$$
+When we use forward automatic differentiation to compute $\frac{\partial y}{\partial x_{1}}$ of formula (2) at $x_{1}=2, x_{2}=5$, the direction of differentiation is consistent with the evaluation direction of the original function, so the original function result and the differential result can be obtained at the same time.
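+As a small, framework-independent illustration of this idea, forward mode AD propagates a (value, derivative) pair through every basic operation. The `Dual` class and the `ln`/`sin` helpers below are only illustrative names for this sketch, not MindSpore APIs:
+
+```python
+import math
+
+class Dual:
+    """A (value, derivative) pair propagated by forward mode AD."""
+    def __init__(self, value, grad):
+        self.value = value
+        self.grad = grad
+
+    def __add__(self, other):
+        return Dual(self.value + other.value, self.grad + other.grad)
+
+    def __sub__(self, other):
+        return Dual(self.value - other.value, self.grad - other.grad)
+
+    def __mul__(self, other):
+        return Dual(self.value * other.value,
+                    self.grad * other.value + self.value * other.grad)
+
+def ln(a):
+    return Dual(math.log(a.value), a.grad / a.value)
+
+def sin(a):
+    return Dual(math.sin(a.value), math.cos(a.value) * a.grad)
+
+# Differentiate with respect to x1: seed x1 with derivative 1 and x2 with 0.
+x1 = Dual(2.0, 1.0)
+x2 = Dual(5.0, 0.0)
+y = ln(x1) + x1 * x2 - sin(x2)
+print(y.value)  # f(2, 5)
+print(y.grad)   # df/dx1 = 1/x1 + x2 = 5.5
+```
+
+The figure below shows the corresponding forward computation trace.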
![image](./images/forward_ad.png)

-The calculation direction of the origin function is the same as the calculation direction of forward mode AD. The function output and the gradient can be calculated simultaneously.
-When we use reverse mode AD:
+When using reverse automatic differentiation, the direction of differentiation is opposite to the evaluation direction of the original function, and the differential result depends on the running result of the original function.

![image](./images/backward_ad.png)

+MindSpore first developed automatic differentiation based on the reverse mode, and then implemented forward mode differentiation on the basis of this method.

-The calculation direction of the origin function is opposite to the calculation direction of reverse mode AD. The calculation of the gradient relies on the output of the original function.
-MindSpore first developed method GradOperation based on reverse mode AD and then used the GradOperation to develop forward mode AD method Jvp.
-
-In order to explain the differences between forward mode AD and reverse mode AD in further. We define an origin function $F$ with N inputs and M outputs:
-$ (Y_{1},Y_{2},...,Y_{M})=F(X_{1},X_{2},...,X_{N})$
-The gradient of function $F$ is a Jacobian matrix.
-$
- \left[
+To further explain the differences between forward mode AD and reverse mode AD, we generalize the differentiated function to $F$, which has $N$ inputs and $M$ outputs:
+$$
+(Y_{1},Y_{2},...,Y_{M})=F(X_{1},X_{2},...,X_{N}) \tag{3}
+$$
+The gradient of the function $F()$ is a Jacobian matrix.
+$$
+\left[
\begin{matrix}
\frac{\partial Y_{1}}{\partial X_{1}}& ... & \frac{\partial Y_{1}}{\partial X_{N}} \\
... & ... & ... \\
\frac{\partial Y_{M}}{\partial X_{1}} & ... & \frac{\partial Y_{M}}{\partial X_{N}}
\end{matrix}
\right]
-$
+\tag{4}
+$$

### Forward Mode AD

In forward mode AD, the calculation of gradient starts from inputs. So, for each calculation, we can get the gradient of outputs with respect to one input, which is one column of the Jacobian matrix.

$$
- \left[
+\left[
\begin{matrix}
\frac{\partial Y_{1}}{\partial X_{1}}\\
... \\
\frac{\partial Y_{M}}{\partial X_{1}}
\end{matrix}
\right]
+\tag{5}
$$

-In order to get this value, AD divies the program into a series of basic operations. The gradient rules of these basic operations is known. The basic operation can also be represented as a function $f$ with n inputs and m outputs:
-
-$$ (y_{1},y_{2},...,y_{m})=f(x_{1},x_{2},...,x_{n})$$
+In order to get this value, AD divides the program into a series of basic operations. The gradient rules of these basic operations are known. The basic operation can also be represented as a function $f$ with $n$ inputs and $m$ outputs:
+$$
+(y_{1},y_{2},...,y_{m})=f(x_{1},x_{2},...,x_{n}) \tag{6}
+$$

Since we have defined the gradient rule of $f$, we know the Jacobian matrix of $f$. So we can calculate the Jacobian-vector-product (Jvp) and use the chain rule to get the gradient output.

$$
- \left[
+\left[
\begin{matrix}
\frac{\partial y_{1}}{\partial X_{i}}\\
... \\
@@ -78,6 +88,7 @@ $$
\frac{\partial x_{n}}{\partial X_{i}}
\end{matrix}
\right]
+\tag{7}
$$

### Reverse Mode AD

In reverse mode AD, the calculation of gradient starts from outputs. So, for each calculation, we can get the gradient of one output with respect to inputs, which is one row of the Jacobian matrix.

$$
- \left[
+\left[
\begin{matrix}
\frac{\partial Y_{1}}{\partial X_{1}}& ... & \frac{\partial Y_{1}}{\partial X_{N}} \\
\end{matrix}
\right]
+\tag{8}
$$

In order to get this value, AD divides the program into a series of basic operations. The gradient rules of these basic operations are known. The basic operation can also be represented as a function $f$ with n inputs and m outputs:
-$$ (y_{1},y_{2},...,y_{m})=f(x_{1},x_{2},...,x_{n})$$
-
+$$
+(y_{1},y_{2},...,y_{m})=f(x_{1},x_{2},...,x_{n}) \tag{9}
+$$
Since we have defined the gradient rule of $f$, we know the Jacobian matrix of $f$. So we can calculate the Vector-Jacobian-product (Vjp) and use the chain rule to get the gradient output.

$$
- \left[
+\left[
\begin{matrix}
\frac{\partial Y_{j}}{\partial x_{1}}& ... & \frac{\partial Y_{j}}{\partial x_{N}} \\
\end{matrix}
\right]
@@ -114,6 +127,7 @@ $$
\frac{\partial y_{m}}{\partial x_{1}} & ... & \frac{\partial y_{m}}{\partial x_{n}}
\end{matrix}
\right]
+\tag{10}
$$

## GradOperation

GradOperation uses reverse mode AD, which calculates gradients starting from the network outputs.

### GradOperation Design

-Define origin function $f(g(x, y, z))$ , then:
-
-$$\frac{df}{dx}=\frac{df}{dg}\frac{dg}{dx}\frac{dx}{dx}+\frac{df}{dg}\frac{dg}{dy}\frac{dy}{dx}+\frac{df}{dg}\frac{dg}{dz}\frac{dz}{dx}$$
+Assuming that the origin function of the model to be differentiated is as follows:
+$$
+f(g(x, y, z)) \tag{11}
+$$
+Then the gradient of $f()$ with respect to $x$ is:
+$$
+\frac{df}{dx}=\frac{df}{dg}\frac{dg}{dx}\frac{dx}{dx}+\frac{df}{dg}\frac{dg}{dy}\frac{dy}{dx}+\frac{df}{dg}\frac{dg}{dz}\frac{dz}{dx} \tag{12}
+$$
The formulas of $\frac{df}{dy}$ and $\frac{df}{dz}$ are similar to that of $\frac{df}{dx}$.

-Based on chain rule, we define gradient function `bprop: dout->(df, dinputs)` for every functions (including operators and graph). Here, `df` means gradients with respect to free variables and `dinputs` is gradients to function inputs. Then we use total derivative rule to accumulate `(df, dinputs)` to correspond variables.
+Based on the chain rule, we define a gradient function `bprop: dout->(df, dinputs)` for every function (including operators and graphs). Here, `df` means the gradients with respect to free variables (variables defined outside the function) and `dinputs` means the gradients with respect to the function inputs. Then we use the total derivative rule to accumulate `(df, dinputs)` onto the corresponding variables.

MindSporeIR has developed the formulas for branching, loops and closures. So if we define the gradient rules correctly, we can get the correct gradient.
+
Define operator K, backward mode AD can be represented as:

```text
v = (result, bprop) = K(f)(K(x))
df, dx = bprop(dout)
```
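+To make the `(result, bprop)` convention concrete, here is a small plain-Python sketch. It is an illustration only, not MindSpore's actual `K` implementation, and the free-variable gradients `df` are omitted for brevity:
+
+```python
+import math
+
+# Each basic operation returns (result, bprop), where bprop maps dout -> gradients of the inputs.
+def k_sin(x):
+    def bprop(dout):
+        return (math.cos(x) * dout,)
+    return math.sin(x), bprop
+
+def k_cos(x):
+    def bprop(dout):
+        return (-math.sin(x) * dout,)
+    return math.cos(x), bprop
+
+def k_fn(x):
+    """K(cos o sin): run the results forward and chain the bprop closures in reverse order."""
+    a, bprop_sin = k_sin(x)
+    out, bprop_cos = k_cos(a)
+
+    def bprop(dout):
+        (da,) = bprop_cos(dout)
+        (dx,) = bprop_sin(da)
+        return (dx,)
+
+    return out, bprop
+
+result, bprop = k_fn(0.7)
+(dx,) = bprop(1.0)  # dout = 1
+print(result, dx)   # cos(sin(0.7)) and -sin(sin(0.7)) * cos(0.7)
+```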
### GradOperation Implementation

-In GradOperation process, the function that needs to calculate gradient will be taken out and used as the input of automatic differentiation module. AD module will map input function to gradient `fprop`. The output gradient has form `fprop = (forward_result, bprop)`. `forward_result` is the output node of the origin function. `bprop` is the gradient function which relies on the closure object of `fprop`. `bprop` has only one input `dout`. `inputs` and `outputs` are the inputs and outputs of `fprop`.
+In the GradOperation process, the function whose gradient needs to be calculated is taken out and used as the input of the automatic differentiation module.
+
+The AD module maps the input function to the gradient function `fprop`.
+
+The generated gradient function has the form `fprop = (forward_result, bprop)`. `forward_result` is the output node of the origin function. `bprop` is the gradient function, which relies on the closure objects of `fprop`. `bprop` has only one input `dout`. `inputs` and `outputs` are the inputs and outputs of `fprop`.

```c++
MapObject(); // Map ValueNode/Parameter/FuncGraph/Primitive object
MapMorphism(); // Map CNode morphism
-res = k_graph(); // res is fprop object
+res = k_graph(); // res is the fprop object of the gradient function
```

-When generating gradient function object, we need to do a series of mapping from origin function to gradient function. These mapping will generate gradient function nodes and we will connect these nodes according to reverse mode AD rules. For each subgraph of origin function, we will create an `DFunctor` object. `Dfunctor` will run `MapObject` and `MapMorphism` to do the mapping.
+When generating the gradient function object, we need to do a series of mappings from the origin function to the gradient function. These mappings generate the gradient function nodes, and we connect these nodes according to the reverse mode AD rules.

-`MapObject` maps nodes of origin function to nodes of gradient function, including free variable nodes, parameter nodes and ValueNodes.
+For each subgraph of the origin function, we create a `DFunctor` object, which maps the original function object to a gradient function object. `DFunctor` runs `MapObject` and `MapMorphism` to do the mapping.
+
+`MapObject` implements the mapping of the original function nodes to the gradient function nodes, including the mapping of free variables, parameter nodes, and ValueNodes.

```c++
MapFvObject(); // map free variables
MapParamObject(); // map parameters
MapValueObject(); // map ValueNodes
```

-`MapFvObject` maps free variables, `MapParamObject` maps parameter nodes. `MapValueObject` mainly maps `Primitive` and `FuncGraph` objects. For `FuncGraph`, we need to create another `DFunctor` object and perform the mapping. This is a recursion process. `Primitive` defines the type of the operator. We need to define gradient function for every `Primitive`. MindSpore defines these gradient functions in Python, for example:
+- `MapFvObject` maps free variables.
+- `MapParamObject` maps parameter nodes.
+- `MapValueObject` mainly maps `Primitive` and `FuncGraph` objects.
+
+For `FuncGraph`, we need to create another `DFunctor` object and perform the mapping, which is a recursive process. `Primitive` defines the type of the operator, and we need to define a gradient function for every `Primitive`.
+
+MindSpore defines these gradient functions in Python, taking the `sin` operator for example:

```python
@bprop_getters.register(P.Sin)
def get_bprop_sin(self):
    """Grad definition for `Sin` operation."""
    cos = P.Cos()

    def bprop(x, out, dout):
        dx = cos(x) * dout
        return (dx,)

    return bprop
```

-This code is the gradient function for `sin`. `x` is the input to `sin`, `y` is the output to `sin` and `dout` is the accumulated gradient.
+`x` is the input of the original function object `sin`, `out` is the output of the original function object `sin`, and `dout` is the gradient accumulated so far and passed in at this step.
-After `MapObject` process, `MapMorphism` maps `CNode` morphism starting from the output of origin function and establishes the connectiontion between AD nodes.
+After `MapObject` completes the mapping of the above nodes, `MapMorphism` recursively maps the `CNode` morphism starting from the output node of the original function, establishing the back-propagation links between nodes and realizing gradient accumulation.
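+The same `(inputs, out, dout)` convention is also visible at the Python API level. As a rough sketch (assuming a recent MindSpore version in which `nn.Cell` allows a user-defined `bprop` method; the cell name below is only illustrative), a user can override the gradient of a cell with exactly the rule registered for `P.Sin` above:
+
+```python
+import numpy as np
+import mindspore as ms
+import mindspore.nn as nn
+import mindspore.ops as ops
+
+class SinWithCustomBprop(nn.Cell):
+    """Illustrative cell whose gradient follows the (inputs, out, dout) convention."""
+    def __init__(self):
+        super(SinWithCustomBprop, self).__init__()
+        self.sin = ops.Sin()
+        self.cos = ops.Cos()
+
+    def construct(self, x):
+        return self.sin(x)
+
+    def bprop(self, x, out, dout):
+        # Same rule as the registered bprop for P.Sin: dx = cos(x) * dout
+        return (self.cos(x) * dout,)
+
+x = ms.Tensor(np.array([0.5, 1.0], np.float32))
+grad_fn = ops.GradOperation()(SinWithCustomBprop())
+print(grad_fn(x))  # expected to be close to cos([0.5, 1.0])
+```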
### GradOperation Example

-We build a simple network and calculate the gradient according to `x`. The structure of the network is:
+Let's build a simple network to represent the formula:
+$$
+f(x) = cos(sin(x)) \tag{13}
+$$
+Then differentiate formula (13) with respect to the input `x`:
+$$
+f'(x) = -sin(sin(x)) * cos(x) \tag{14}
+$$
+The network structure of formula (13) is implemented in MindSpore as follows:

```python
+import mindspore.nn as nn
+
class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
@@ -199,21 +241,25 @@ class Net(nn.Cell):
        return out
```

-Origin network structure is:
+The structure of the forward network is:
+
+![auto-gradient-foward](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/docs/mindspore/source_zh_cn/design/images/auto_gradient_foward.png)
+
+After reverse mode differentiation of the network, the resulting gradient network structure is:

-![image](./images/origin_net.png)
+![auto-gradient-forward2](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/docs/mindspore/source_zh_cn/design/images/auto_gradient_forward2.png)

-After reverse mode AD, the network structure is:
+## Jacobian-Vector-Product Implementation

-![image](./images/backward_net.png)
+Besides GradOperation, MindSpore has developed the forward mode automatic differentiation method Jvp (Jacobian-Vector-Product).

-## Forward Mode Implementation
+Compared to reverse mode AD, forward mode AD is more suitable for networks whose input dimension is smaller than the output dimension. MindSpore forward mode AD is developed based on the reverse mode GradOperation function.

-Besides GradOperation, Mindspore has developed forward mode automatic differentiation method Jvp (Jacobian-Vector-Product). Compared to reverse mode AD, forward mode AD is more suitable for networks whose input dimension is smaller than output dimension. Mindspore forward mode AD is developed based on reversed mode GradOperation function.
+![auto-gradient-jvp](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/docs/mindspore/source_zh_cn/design/images/auto_gradient_jvp.png)

-![image](./images/Jvp.png)
+The network in black is the origin function. After the first differentiation with respect to the input $x$, we get the network in blue. The blue network is then differentiated with respect to the vector $v$, resulting in the yellow network.

-The network in black is the origin function. After the first derivative based on one input x, we get the network in blue. The we compute the gradient of blue network with respect to vector v and we can get the yellow network. This yellow network is the forward mode AD gradient network of black network. Since blue network is a linear network for vector v, there will be no connection between blue network and yellow network. So, all the nodes in blue are dangling nodes. We can use only blue and yellow nodes to calculate the gradient.
+This yellow network is the forward mode AD gradient network of the black network. Since the blue network is linear with respect to the vector $v$, there is no connection between the blue network and the yellow network, so all the nodes in blue are dangling nodes. We only need the black and yellow nodes to calculate the forward mode gradient.
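+The following is a rough sketch of the double-differentiation trick shown in the figure, built only on `GradOperation` and the `Net` from formula (13); it is not MindSpore's internal Jvp implementation, and the cell names `VjpTimesV` and `JvpSketch` are illustrative:
+
+```python
+import numpy as np
+import mindspore as ms
+import mindspore.nn as nn
+import mindspore.ops as ops
+
+class Net(nn.Cell):
+    """f(x) = cos(sin(x)), as in formula (13)."""
+    def __init__(self):
+        super(Net, self).__init__()
+        self.sin = ops.Sin()
+        self.cos = ops.Cos()
+
+    def construct(self, x):
+        return self.cos(self.sin(x))
+
+class VjpTimesV(nn.Cell):
+    """Blue graph contracted with v: h(u, x, v) = <u^T J(x), v>."""
+    def __init__(self, net):
+        super(VjpTimesV, self).__init__()
+        self.net = net
+        self.grad = ops.GradOperation(sens_param=True)  # reverse mode with an explicit dout
+        self.reduce_sum = ops.ReduceSum()
+
+    def construct(self, u, x, v):
+        vjp = self.grad(self.net)(x, u)  # u^T J(x), linear in u
+        return self.reduce_sum(vjp * v)
+
+class JvpSketch(nn.Cell):
+    """Yellow graph: differentiate h with respect to u to obtain J(x) v."""
+    def __init__(self, net):
+        super(JvpSketch, self).__init__()
+        self.net = net
+        self.inner = VjpTimesV(net)
+        self.grad_u = ops.GradOperation()  # gradient with respect to the first input u
+        self.ones_like = ops.OnesLike()
+
+    def construct(self, x, v):
+        u = self.ones_like(self.net(x))  # dummy cotangent; the result does not depend on u
+        return self.grad_u(self.inner)(u, x, v)
+
+x = ms.Tensor(np.array([0.5, 1.0], np.float32))
+v = ms.Tensor(np.array([1.0, 0.0], np.float32))  # tangent direction
+print(JvpSketch(Net())(x, v))                    # close to -sin(sin(x)) * cos(x) * v
+```
+
+Recent MindSpore versions also expose a packaged interface for this pattern (for example `mindspore.nn.Jvp`), which is the more convenient way to compute Jacobian-vector products in practice.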
### References

diff --git a/docs/mindspore/source_en/design/dataset_offload.md b/docs/mindspore/source_en/design/dataset_offload.md
index 2be522198b..b4b7a06bcc 100644
--- a/docs/mindspore/source_en/design/dataset_offload.md
+++ b/docs/mindspore/source_en/design/dataset_offload.md
@@ -1,4 +1,4 @@
-# Enabling Offload for Dataset
+# Enabling Heterogeneous Acceleration for Data
@@ -10,40 +10,40 @@ Currently this heterogeneous hardware acceleration technology (introduced as the

The offload feature will move only the supported dataset operations applied on the specific input column at the end of the pipeline to the accelerator. This includes consecutive data augmentation operators which are used in the map data processing operator, granted they come at the end of the dataset pipeline for a specific input column.

-The current supported data augmentation operators which can be offloaded are:
+The data augmentation operators that currently support heterogeneous acceleration are:

-| Operator Name | Operator Path | Operator Introduction |
-| -------------------- | -------------------------------------------- | --------------------------------------------------------------------------------------------------- |
-| HWC2CHW | mindspore.dataset.vision.c_transforms.py | Transpose a Numpy image array from shape (H, W, C) to shape (C, H, W) |
-| Normalize | mindspore.dataset.vision.c_transforms.py | Normalize the input Numpy image array of shape (H, W, C) with the given mean and standard deviation |
-| RandomColorAdjust | mindspore.dataset.vision.c_transforms.py | Perform a random brightness, contrast, saturation, and hue adjustment on the input PIL image |
-| RandomHorizontalFlip | mindspore.dataset.vision.c_transforms.py | Randomly flip the input image horizontally with a given probability |
-| RandomSharpness | mindspore.dataset.vision.c_transforms.py | Adjust the sharpness of the input PIL Image by a random degree |
-| RandomVerticalFlip | mindspore.dataset.vision.c_transforms.py | Randomly flip the input image vertically with a given probability |
-| Rescale | mindspore.dataset.vision.c_transforms.py | Rescale the input image with the given rescale and shift |
-| TypeCast | mindspore.dataset.transforms.c_transforms.py | Cast tensor to a given MindSpore data type |
+| Operator Name | Operator Path | Operator Introduction |
+| -------------------- | -------------------------------------------- | ------------------------------------------------------------ |
+| HWC2CHW | mindspore.dataset.vision.c_transforms.py | Transpose a Numpy image array from shape (H, W, C) to shape (C, H, W) |
+| Normalize | mindspore.dataset.vision.c_transforms.py | Normalize the input image with the given mean and standard deviation |
+| RandomColorAdjust | mindspore.dataset.vision.c_transforms.py | Perform a random brightness, contrast, saturation, and hue adjustment on the input PIL image |
+| RandomHorizontalFlip | mindspore.dataset.vision.c_transforms.py | Randomly flip the input image horizontally with a given probability |
+| RandomSharpness | mindspore.dataset.vision.c_transforms.py | Adjust the sharpness of the input PIL Image by a random degree |
+| RandomVerticalFlip | mindspore.dataset.vision.c_transforms.py | Randomly flip the input image vertically with a given probability |
+| Rescale | mindspore.dataset.vision.c_transforms.py | Rescale the input image with the given rescale and shift |
+| TypeCast | mindspore.dataset.transforms.c_transforms.py | Cast tensor to a given MindSpore data type |

## Offload Process

-The following figures show the typical computation process of how to use the offload feature in the given dataset pipeline.
+The following figures show the typical computation process of how to use heterogeneous acceleration in the given dataset pipeline.

![offload](images/offload_process.PNG)

-Offload has two new API changes to let users enable this functionality:
+Heterogeneous acceleration introduces two API changes that let users enable this functionality:

-1. A new argument “offload” is added to the map dataset processing operator.
+1. The map data processing operator adds an `offload` input parameter.

-2. A new API “set_auto_offload” is introduced to the dataset config.
+2. The global dataset configuration `mindspore.dataset.config` adds a `set_auto_offload` interface.

-To check if the data augmentation operators are offloaded to the accelerator, users can save and check the computation graph IR files which will have the related operators written before the model structure. The offload feature is currently available for both dataset sink mode (dataset_sink_mode=True) and dataset non-sink mode (dataset_sink_mode=False).
+To check whether the data augmentation operators are moved to the accelerator, users can save and check the computation graph IR files, which will have the related operators written before the model structure. Heterogeneous acceleration is currently available for both dataset sink mode (dataset_sink_mode=True) and dataset non-sink mode (dataset_sink_mode=False).

-## Enabling Offload
+## Enabling Heterogeneous Acceleration for Data

-There are two options to enable offload.
+There are two options provided by MindSpore to enable heterogeneous acceleration.

### Option 1

-Use the global config to set automatic offload. In this case, the offload argument for all map data processing operators will be set to True (see Option 2). However, if the offload argument is given for a specific map operator, it will have priority over the global config option.
+Use the global config to set automatic heterogeneous acceleration. In this case, the offload argument for all map data processing operators will be set to True (see Option 2). It should be noted that if the offload argument is given for a specific map operator, it will take priority over the global config option.

```python
import mindspore.dataset as ds
@@ -70,7 +70,7 @@ dataset = dataset.map(operations=type_cast_op, input_columns="label", offload=Tr
dataset = dataset.map(operations=image_ops , input_columns="image", offload=True)
```

-The offload feature supports being applied on multi-column dataset as the below example shows.
+Heterogeneous acceleration can also be applied to a multi-column dataset, as the example below shows.

```python
dataset = dataset.map(operations=type_cast_op, input_columns="label")
dataset = dataset.map(operations=copy_column, input_columns=["image"], output_columns=["image1", "image2"], column_order=["image1", "image2", "label"])
dataset = dataset.map(operations=image_ops, input_columns=["image1"], offload=True)
dataset = dataset.map(operations=image_ops, input_columns=["image2"], offload=True)
```

## Constraints

-The offload feature is still under development. The current usage is limited under the following constraints:
+The heterogeneous acceleration feature is still under development. The current usage is limited by the following constraints:

-1. Offload feature does not support concatenated or zipped datasets currently.
+1. The feature currently does not support concatenated or zipped datasets.

-2. The map operation(s) you wish to offload must be the last map operation(s) in the pipeline for their specific input column(s). There is no limitation of the input columns' order. For example, a map operation applied to the “label” data column like
+2. The operations to be offloaded must be the last one or more consecutive data augmentation operations acting on a particular input column, but there is no restriction on the order in which the input columns are processed. For example, a map operation applied to the "label" data column like

```python
dataset = dataset.map(operations=type_cast_op, input_columns="label", offload=True)
```

- can be offloaded even if non-offload map operations applied on different data column(s) occur afterwards, such as
+ can still perform heterogeneous acceleration even if it is followed by a map operation on a different data column that does not enable offload, such as

```python
dataset = dataset.map(operations=image_ops, input_columns="image", offload=False)
```

-3. Offload feature does not support map operations with a user specified `output_columns`.
+ That is, even if the map operator acting on the "image" column is not set to offload, the earlier map operator acting on the "label" column can still be offloaded.
+
+3. This feature does not currently support map operations with user-specified output columns (`output_columns`).

diff --git a/docs/mindspore/source_zh_cn/design/auto_gradient.ipynb b/docs/mindspore/source_zh_cn/design/auto_gradient.ipynb
index 2adf4e0303..743103bd97 100644
--- a/docs/mindspore/source_zh_cn/design/auto_gradient.ipynb
+++ b/docs/mindspore/source_zh_cn/design/auto_gradient.ipynb
@@ -82,7 +82,7 @@
    "(y_{1},y_{2},...,y_{m})=f(x_{1},x_{2},...,x_{n}) \\tag{6}\n",
    "$$\n",
    "\n",
-   "由于我们的已知基础函数 $f$ 的求导规则,即 $f$ 的雅可比矩阵是已知的。 于是我们可以对$f$计算雅可比向量积(Jvp, Jacobian-vector-product),并应用链式求导法则获得导数结果。\n",
+   "由于我们的已知基础函数 $f$ 的求导规则,即 $f$ 的雅可比矩阵是已知的。于是我们可以对$f$计算雅可比向量积(Jvp, Jacobian-vector-product),并应用链式求导法则获得导数结果。\n",
    "\n",
    "$$\n",
    "\\left[\n",
@@ -126,7 +126,7 @@
    "(y_{1},y_{2},...,y_{m})=f(x_{1},x_{2},...,x_{n}) \\tag{9}\n",
    "$$\n",
    "\n",
-   "由于我们的已知基础函数$f$的求导规则,即f的雅可比矩阵是已知的。 于是我们可以对$f$计算向量雅可比积(Vjp, Vector-jacobian-product),并应用链式求导法则获得导数结果。\n",
+   "由于我们的已知基础函数$f$的求导规则,即f的雅可比矩阵是已知的。于是我们可以对$f$计算向量雅可比积(Vjp, Vector-jacobian-product),并应用链式求导法则获得导数结果。\n",
    "$$\n",
    "\\left[\n",
--
Gitee