diff --git a/tutorials/source_en/advanced/mixed_precision.md b/tutorials/source_en/advanced/mixed_precision.md
index 4eca395ec0388afdd7f41cc5c68bb71fb50cb15f..2867c31136d7e72941cd5f5a6abec98c5305bdd4 100644
--- a/tutorials/source_en/advanced/mixed_precision.md
+++ b/tutorials/source_en/advanced/mixed_precision.md
@@ -1,12 +1,14 @@
-# Enabling Mixed Precision
-
-## Overview
+# Automatic Mixed Precision
+
+Mixed precision training is a computing strategy that applies different numerical precision to different operations of a neural network during training. In neural network computation, some operations are insensitive to numerical precision, and using lower precision for them brings significant acceleration (such as conv and matmul), while other operations usually need to retain higher precision to guarantee correct results because of the large difference between their input and output values (such as log and softmax).
+
+Current AI accelerator cards usually provide hardware acceleration modules designed for computationally intensive, precision-insensitive operations, such as TensorCore on NVIDIA GPUs and Cube on Ascend NPUs. Neural networks in which operations such as conv and matmul account for a large share of the computation therefore usually obtain a large acceleration ratio.
 
-Generally, when a neural network model is trained, the default data type is FP32. In recent years, to accelerate training time, reduce memory occupied during network training, and store a trained model with same precision, more and more mixed-precision training methods are proposed in the industry. The mixed-precision training herein means that both single precision (FP32) and half precision (FP16) are used in a training process.
+The `mindspore.amp` module provides a convenient interface for automatic mixed precision, allowing users to obtain training acceleration on different hardware backends with simple interface calls. In the following, we first introduce the calculation principle of mixed precision, and then demonstrate the automatic mixed precision usage of MindSpore by example.
 
-## Floating-point Data Type
+## Principle of Mixed Precision Calculation
 
 Floating-point data types include double-precision (FP64), single-precision (FP32), and half-precision (FP16). In a training process of a neural network model, an FP32 data type is generally used by default to indicate a network model weight and other parameters. The following is a brief introduction to floating-point data types.
 
@@ -14,544 +16,283 @@ According to [IEEE 754](https://en.wikipedia.org/wiki/IEEE_754), floating-point
 
 ![fp16_vs_FP32](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/tutorials/source_en/advanced/images/fp16_vs_fp32.png)
 
-As shown in the figure, the storage space of FP16 is half that of FP32, and the storage space of FP32 is half that of FP64. It consists of three parts:
-
-- The highest bit indicates the sign bit.
-- The middle bits indicate exponent bits.
-- The low bits indicate fraction bits.
-
-FP16 is used as an example. The first sign bit sign indicates a positive or negative sign, and the next five bits indicate an exponent. All 0s and 1s have special uses, so the binary range is 00001~11110. The last 10 bits indicate a fraction. Suppose `S` denotes the decimal value of sign bit, `E` denotes the decimal value of exponent, and `fraction` denotes the decimal value of fraction. The formula is as follows:
-
-$$x=(-1)^{S}\times2^{E-15}\times(1+\frac{fraction}{1024})$$
-
-Similarly, suppose `M` is score value, the true value of a formatted FP32 is as follows:
-
-$$x=(-1)^{S}\times2^{E-127}\times(1.M)$$
-
-The true value of a formatted FP64 is as follows:
-
-$$x=(-1)^{S}\times2^{E-1023}\times(1.M)$$
-
-The maximum value that can be represented by FP16 is 0 11110 1111111111, which is calculated as follows:
-
-$$(-1)^0\times2^{30-15}\times1.1111111111 = 1.1111111111(b)\times2^15 = 1.9990234375(d)\times2^15 = 65504$$
-
-where `b` indicates binary value, and `d` indicates decimal value.
-
-The minimum value that can be represented by FP16 is 0 00001 0000000000, which is calculated as follows:
-
-$$ (-1)^{1}\times2^{1-15}=2^{-14}=6.104×10^{-5}=-65504$$
-
-Therefore, the maximum value range of FP16 is [-65504, 65504], and the precision range is $2^{-24}$. If the value is beyond this range, the value is set to 0.
-
-## FP16 Training Issues
+As shown in the figure, the storage space of FP16 is half that of FP32, and the storage space of FP32 is half that of FP64. Therefore, using FP16 for computing has the following advantages:
 
-Why do we need mixed-precision? Compared with FP32, FP16 has the following advantages:
+- Reduced memory usage: The bit width of FP16 is half that of FP32, so the memory used for parameters such as weights is also halved, and the saved memory can be used for larger network models or more training data.
+- Higher computational efficiency: On dedicated AI acceleration chips such as the Huawei Ascend 910 and 310 series, or GPUs based on the NVIDIA Volta architecture, FP16 operations execute faster than FP32 operations.
+- Higher communication efficiency: For distributed training, especially when training large models, communication overhead constrains the overall training performance. A smaller communication bit width means that communication performance can be improved, waiting time can be reduced, and data flow can be accelerated.
 
-- Reduced memory usage: The bit width of FP16 is half of that of FP32. Therefore, the memory occupied by parameters such as the weight is also half of the original memory. The saved memory can be used to store larger network models or train more data.
-- Higher communication efficiency: For distributed training, especially the large-scale model training, the communication overhead restricts the overall performance. A smaller communication bit width means that the communication performance can be improved, the waiting time can be reduced, and the data flow can be accelerated.
-- Higher computing efficiency: On special AI acceleration chips, such as Huawei Ascend 910 and 310 series, or GPUs of the NVIDIA VOLTA architecture, the computing performance of FP16 is faster than that of FP32.
+But the use of FP16 also poses a number of problems:
 
-However, using FP16 also brings some problems, the most important of which are precision overflow and rounding error.
+- Data overflow: The valid data representation range of FP16 is $[6.10\times10^{-5}, 65504]$, while that of FP32 is $[1.4\times10^{-45}, 1.7\times10^{38}]$. The effective range of FP16 is thus much narrower than that of FP32, so replacing FP32 with FP16 can cause overflow and underflow. In deep learning, the gradients (first-order derivatives) of the weights in the network model need to be calculated; the gradients are often even smaller than the weight values and are therefore especially prone to underflow.
+- Rounding error: When the backward gradient of the network model is very small, it can generally be represented in FP32, but after conversion to FP16 it may fall below the minimum representable interval and be rounded. For example, `0.00006666666` can be expressed normally in FP32 but becomes `0.000067` after conversion to FP16: numbers finer than the minimum interval of FP16 are forcibly rounded.
 
-- Data overflow: Data overflow is easliy to understand. The valid data range of FP16 is $[6.10\times10^{-5}, 65504]$, and that of FP32 is $[1.4\times10^{-45}, 1.7\times10^{38}]$. We can see that the valid range of FP16 is much narrower than that of FP32. When FP16 is used to replace FP32, overflow and underflow occur. In deep learning, a gradient (a first-order derivative) of a weight in a network model needs to be calculated. Therefore, the gradient is smaller than the weight value, and underflow often occurs.
-- Rounding error: Rounding error instruction is when the backward gradient of a network model is small, FP32 is usually used. However, when it is converted to FP16, the interval is smaller than the minimum interval, causing data overflow. For example, 0.00006666666 can be properly represented in FP32, but it will be represented as 0.000067 in FP16. The number that does not meet the minimum interval requirement of FP16 will be forcibly rounded off.
 
+Therefore, while using mixed precision to obtain training speedup and memory savings, the problems introduced by FP16 need to be addressed. Loss Scale (loss scaling) is a solution to the FP16 data underflow problem: when the loss value is calculated, the loss is expanded by a certain multiple. By the chain rule, the gradients are expanded accordingly, and they are scaled back down by the same multiple when the optimizer updates the weights, thus avoiding data underflow.
 
-## Mixed-precision Computing Process
+Based on the principles described above, a typical mixed precision computation process is shown in the following figure:
 
-The following figure shows the typical computation process of mixed precision in MindSpore.
 
+![mix precision](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/tutorials/experts/source_zh_cn/optimize/images/mix_precision_fp16.png)
-![mix precision](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/tutorials/experts/source_en/optimize/images/mix_precision_fp16.png)
 
+1. Parameters are stored in FP32.
+2. During the forward computation, when an FP16 operator is involved, the operator inputs and parameters are cast from FP32 to FP16 for computation.
+3. The Loss layer is set to FP32 for computation.
+4. During the backward computation, the loss is first multiplied by the Loss Scale value to avoid underflow caused by too small a gradient.
+5. FP16 parameters are involved in the gradient computation, and the results are cast back to FP32.
+6. The gradients are divided by the Loss Scale value to restore the amplified values.
+7. Check whether the gradients overflow: skip the update if they do; otherwise, the optimizer updates the original parameters in FP32.
 
-1. Parameters are stored in FP32 format.
-2. During the forward computation, if an FP16 operator is involved, the operator input and parameters need to be cast from FP32 to FP16.
-3. The Loss layer is set to FP32.
-4. During backward computation, the value is multiplied by Loss Scale to avoid underflow due to a small gradient.
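+
+As a toy numeric illustration of steps 4 and 6 above (a framework-independent sketch; the values are made up for illustration):
+
+```python
+loss_scale = 1024.0
+grad = 1e-7                      # below the FP16 lower bound of about 6.1e-5, would underflow
+scaled = grad * loss_scale       # 1.024e-4, safely representable in FP16
+restored = scaled / loss_scale   # dividing in FP32 recovers the original 1e-7
+```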
-5. The FP16 parameter is used for gradient computation, and the result is cast back to FP32. -6. Then, the value is divided by Loss scale to restore the multiplied gradient. -7. The optimizer checks whether the gradient overflows. If yes, the optimizer skips the update. If no, the optimizer uses FP32 to update the original parameters. +In the following, we demonstrate the automatic mixed precision implementation of MindSpore by importing the handwritten digit recognition model and dataset from [Quick Start](https://www.mindspore.cn/tutorials/en/master/beginner/quick_start.html). -This document describes the computation process by using examples of automatic and manual mixed precision. - -## Loss Scale - -Loss Scale is mainly used in the process of mixed-precision training. +```python +import mindspore as ms +from mindspore import nn +from mindspore import ops +from mindspore import value_and_grad +``` -In the process of mixed precision training, the FP16 type is used instead of the FP32 type for data storage, so as to achieve the effect of reducing memory and improving the computing speed. However, because the FP16 type is much smaller than the range represented by the FP32 type, data underflow occurs when parameters (such as gradients) become very small during training. The Loss Scale is proposed to solve the underflow of FP16 type data. +```python +from mindspore.dataset import vision, transforms +from mindspore.dataset import MnistDataset -The main idea is to enlarge the loss by a certain multiple when calculating the loss. Due to the existence of the chain rule, the gradient also expands accordingly, and then the corresponding multiple is reduced when the optimizer updates the weight, thus avoiding the situation of data underflow without affecting the calculation result. +# Download data from open datasets +from download import download -There are two ways of implementing Loss Scale in MindSpore, users can either use the functional programming writeup and manually call the `scale` and `unscale` methods of `StaticLossScaler` or `DynamicLossScaler` to scale the loss or gradient during training; or they can configure the loss or gradient based on the `Model` interface and configure the mixed precision `amp_level` and the Loss Scale method `loss_scale_manager` as `FixedLossScaleManager` or `DynamicLossScaleManager` when building the model by using `Model`. +url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/" \ + "notebook/datasets/MNIST_Data.zip" +path = download(url, "./", kind="zip", replace=True) -First, let's take a look at why mixing accuracy is needed. The advantages of using FP16 to train a neural network are: -- **Reduce memory occupation**: The bit width of FP16 is half that of FP32, so the memory occupied by parameters such as weights is also half of the original, and the saved memory can be used to put a larger network model or use more data for training. -- **Accelerate communication efficiency**: For distributed training, especially in the process of large model training, the overhead of communication restricts the overall performance of network model training, and the less bit width of communication means that communication performance can be improved. Waiting time is reduced, and data circulation can be accelerated. 
-- **Higher computing effciency**: On special AI-accelerated chips such as Huawei's Ascend 910 and 310 series, or GPUs of the Titan V and Tesla V100 of the NVIDIA VOLTA architecture, the performance of performing operations using FP16 is faster than that of the FP32. +def datapipe(path, batch_size): + image_transforms = [ + vision.Rescale(1.0 / 255.0, 0), + vision.Normalize(mean=(0.1307,), std=(0.3081,)), + vision.HWC2CHW() + ] + label_transform = transforms.TypeCast(ms.int32) -But using FP16 also brings some problems, the most important of which are precision overflow and rounding error, and Loss Scale is to solve the precision overflow and proposed. + dataset = MnistDataset(path) + dataset = dataset.map(image_transforms, 'image') + dataset = dataset.map(label_transform, 'label') + dataset = dataset.batch(batch_size) + return dataset -As shown in the figure, if only FP32 training is used, the model converges better, but if mixed-precision training is used, there will be a situation where the network model cannot converge. The reason is that the value of the gradient is too small, and using the FP16 representation will cause the problem of underflow under the data, resulting in the model not converging, as shown in the gray part of the figure. Loss Scale needs to be introduced. +train_dataset = datapipe('MNIST_Data/train', 64) -![loss_scale1](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/tutorials/experts/source_zh_cn/optimize/images/loss_scale1.png) +# Define model +class Network(nn.Cell): + def __init__(self): + super().__init__() + self.flatten = nn.Flatten() + self.dense_relu_sequential = nn.SequentialCell( + nn.Dense(28*28, 512), + nn.ReLU(), + nn.Dense(512, 512), + nn.ReLU(), + nn.Dense(512, 10) + ) -The following is in the network model training stage, a layer of activation function gradient distribution, of which 68% of the network model activation parameter bit 0. Another 4% of the accuracy in the $2^{-32}, 2^{-20}$ interval, directly use FP16 to represent the data inside, which truncates the underflow data. All gradient values will become 0. + def construct(self, x): + x = self.flatten(x) + logits = self.dense_relu_sequential(x) + return logits +``` -![loss_scale2](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/tutorials/experts/source_zh_cn/optimize/images/loss_scale2.png) +```text +Downloading data from https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/MNIST_Data.zip (10.3 MB) -In order to solve the problem of ladder overflowing over small data, the forward calculated Loss value is amplified, that is, the parameters of FP32 are multiplied by a factor coefficient, and the possible overflowing decimal data is moved forward and panned to the data range that FP16 can represent. According to the chain differentiation law, amplifying the Loss acts on each gradient of backpropagation, which is more efficient than amplifying on each gradient. +file_sizes: 100%|██████████████████████████| 10.8M/10.8M [00:07<00:00, 1.53MB/s] +Extracting zip file... 
+Successfully downloaded / unzipped to ./
+```
 
-![loss_scale3](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/tutorials/experts/source_zh_cn/optimize/images/loss_scale3.png)
+## Type Conversions
 
-Loss amplification needs to be achieved in combination with mixing accuracy, and its main main ideas are:
+Mixed precision computation requires type conversion for the operations that run at low precision: their inputs are converted to FP16, and the outputs are converted back to FP32 once obtained. MindSpore provides both automatic and manual type conversion methods to meet different needs for ease of use and flexibility, which are described below.
 
-- **Scale up stage**: After the network model forward calculation, the resulting loss change value DLoss is increased by a factor of $2^K$ before the repercussion propagation.
-- **Scale down stage**: After backpropagation, the weight gradient is reduced by $2^K$, and the FP32 value is restored for storage.
+### Automatic Type Conversion
 
-**Dynamic Loss Scale**: The loss scale mentioned above is to use a default value to scale the loss value, in order to make full use of the dynamic range of FP16, you can better mitigate the rounding error, and try to use a relatively large magnification. To summarize the dynamic loss scaling algorithm, it is to reduce the loss scale whenever the gradient overflows, and intermittently try to increase the loss scale, so as to achieve the use of the highest loss scale factor without causing overflow, and better restore accuracy.
+The `mindspore.amp.auto_mixed_precision` interface performs automatic type conversion for a network. Automatic type conversion follows a blacklist and whitelist mechanism, with four levels configured according to common operator precision conventions:
 
-The dynamic loss scale algorithm is as follows:
+- 'O0': The neural network keeps FP32.
+- 'O1': Operations are cast to FP16 according to the whitelist.
+- 'O2': Operations on the blacklist retain FP32, and the rest are cast to FP16.
+- 'O3': The neural network is fully cast to FP16.
 
-1. The algorithm of dynamic loss scaling starts with a relatively high scaling factor (such as $2^{24}$), then starts training and checks whether the number overflows in the iteration (Infs/Nans);
-2. If there is no gradient overflow, the scale factor is not adjusted and the iteration continues; if the gradient overflow is detected, the scale factor is halved and the gradient update is reconfirmed until the parameter does not appear in the overflow range;
-3. In the later stages of training, the loss has become stable and convergent, and the amplitude of the gradient update is often small, which can allow a higher loss scaling factor to prevent data underflow again.
-4. Therefore, the dynamic loss scaling algorithm attempts to increase the loss scaling by the F multiple every N (N=2000) iterations, and then performs step 2 to check for overflow.
+The following is an example of using automatic type conversion:
 
-## Using Mixed Precision and Loss Scale in MindSpore
+```python
+from mindspore.amp import auto_mixed_precision
 
-MindSpore provides two ways of using mixed precision and loss scale.
+model = Network()
+model = auto_mixed_precision(model, 'O2')
+```
 
-- Use functional programming: use `auto_mixed_precision` for automatic mixing accuracy, `all_finite` for overflow judgments, and `StaticLossScaler` and `DynamicLossScaler` for manual scaling of gradients and losses.
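+
+To sanity-check the conversion, you can run a dummy forward pass (a minimal sketch: the all-zeros input and its shape follow the MNIST pipeline defined above, and the exact output dtype depends on the casts inserted for the chosen level):
+
+```python
+import numpy as np
+
+# The converted model still accepts FP32 input; FP16 casts are inserted internally.
+dummy = ms.Tensor(np.zeros((1, 1, 28, 28)), ms.float32)
+out = model(dummy)
+print(out.shape, out.dtype)
+```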
+### Manual Type Conversion
-
-- Using the training interface `Model`: configure the input `amp_level` to set the execution policy for mixed precision and the input `loss_scale_manager` to `FixedLossScaleManager` or `DynamicLossScaleManager` to implement loss scaling.
+
+Usually, automatic type conversion satisfies most mixed precision training needs. But when users need to finely control the precision of operations in different parts of the neural network, manual type conversion can be used instead.
 
-## Using a Functional Programming for Mixed Precision and Loss Scale
+> Manual type conversion needs to take into account the precision of each module in the model, and is generally used only when extreme performance is pursued.
 
-MindSpore provides a functional interface for mixed precision scenarios. Users can use `auto_mixed_precision` for automatic mixed precision, `all_finite` for overflow judgments during training, and `StaticLossScaler` and `DynamicLossScaler` to manually perform gradient and loss scaling.
+Below we adapt the `Network` defined above to demonstrate the different ways of manual type conversion.
 
-Common uses of LossScaler under functional.
+#### Cell Granularity Type Conversion
 
-First import the relevant libraries and define a LeNet5 network:
+The `nn.Cell` class provides the `to_float` method to configure the operator precision of a module in one step, automatically casting the module's inputs to the specified precision.
 
 ```python
-import numpy as np
-import mindspore.nn as nn
-from mindspore.train import Accuracy
-import mindspore as ms
-from mindspore.common.initializer import Normal
-from mindspore import dataset as ds
-from mindspore.amp import auto_mixed_precision, DynamicLossScaler, all_finite
-from mindspore import ops
-
+class NetworkFP16(nn.Cell):
+    def __init__(self):
+        super().__init__()
+        self.flatten = nn.Flatten()
+        self.dense_relu_sequential = nn.SequentialCell(
+            nn.Dense(28*28, 512).to_float(ms.float16),
+            nn.ReLU(),
+            nn.Dense(512, 512).to_float(ms.float16),
+            nn.ReLU(),
+            nn.Dense(512, 10).to_float(ms.float16)
+        )
 
-class LeNet5(nn.Cell):
-    """
-    Lenet network
+    def construct(self, x):
+        x = self.flatten(x)
+        logits = self.dense_relu_sequential(x)
+        return logits
+```
 
-    Args:
-        num_class (int): Number of classes. Default: 10.
-        num_channel (int): Number of channels. Default: 1.
+#### Custom Granularity Type Conversion
 
-    Returns:
-        Tensor, output tensor
-    """
+When the user needs to configure the operation precision for a single operation or for a combination of several modules, Cell granularity is often insufficient. In this case, custom granularity control can be achieved by directly casting the type of the input data:
-    def __init__(self, num_class=10, num_channel=1):
-        super(LeNet5, self).__init__()
-        self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
-        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
-        self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
-        self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
-        self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
-        self.relu = nn.ReLU()
-        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
+```python
+class NetworkFP16Manual(nn.Cell):
+    def __init__(self):
+        super().__init__()
         self.flatten = nn.Flatten()
+        self.dense_relu_sequential = nn.SequentialCell(
+            nn.Dense(28*28, 512),
+            nn.ReLU(),
+            nn.Dense(512, 512),
+            nn.ReLU(),
+            nn.Dense(512, 10)
+        )
 
     def construct(self, x):
-        x = self.max_pool2d(self.relu(self.conv1(x)))
-        x = self.max_pool2d(self.relu(self.conv2(x)))
         x = self.flatten(x)
-        x = self.relu(self.fc1(x))
-        x = self.relu(self.fc2(x))
-        x = self.fc3(x)
-        return x
+        x = x.astype(ms.float16)
+        logits = self.dense_relu_sequential(x)
+        logits = logits.astype(ms.float32)
+        return logits
 ```
 
-Perform auto mixed precision on the network.
-
-`auto_mixed_precision` implements the meanings of automatic mixed precision configuration as follows:
+## Loss Scaling
 
-- 'O0': keep FP32.
-- 'O1': cast as FP16 by whitelist.
-- 'O2': keep FP32 by blacklist and the rest cast as FP16.
-- 'O3': fully cast to FP16.
+Two implementations of Loss Scale are provided in MindSpore, `StaticLossScaler` and `DynamicLossScaler`; they differ in whether the loss scale value is adjusted dynamically. Below we take `DynamicLossScaler` as an example and implement the neural network training logic according to the mixed precision computation process described above.
 
-> The current black and white list is Cell granularity.
+First, instantiate the LossScaler and manually scale up the loss value when defining the forward network.
 
 ```python
-net = LeNet5(10)
-auto_mixed_precision(net, 'O1')
-```
+from mindspore.amp import DynamicLossScaler
 
-Instantiate the LossScaler and manually scale up the loss value when defining the forward network.
-
-```python
-loss_fn = nn.BCELoss(reduction='mean')
-opt = nn.Adam(net.trainable_params(), learning_rate=0.01)
+# Instantiate loss function and optimizer
+loss_fn = nn.CrossEntropyLoss()
+optimizer = nn.SGD(model.trainable_params(), 1e-2)
 
 # Define LossScaler
 loss_scaler = DynamicLossScaler(scale_value=2**10, scale_factor=2, scale_window=50)
 
-def net_forward(data, label):
-    out = net(data)
-    loss_value = loss_fn(out, label)
+def forward_fn(data, label):
+    logits = model(data)
+    loss = loss_fn(logits, label)
     # scale up the loss value
-    scaled_loss = loss_scaler.scale(loss_value)
-    return scaled_loss, out
+    loss = loss_scaler.scale(loss)
+    return loss, logits
 ```
 
-Reverse acquisition of gradients.
+Next, a function transformation is performed to obtain the gradient function.
 
 ```python
-grad_fn = ms.value_and_grad(net_forward, None, net.trainable_params())
+grad_fn = value_and_grad(forward_fn, None, model.trainable_params())
 ```
 
-Define the training step: calculate the current gradient value and recover the loss. Use `all_finite` to determine whether there is a gradient underflow problem. If there is no overflow, recover the gradient and update the network weights, while if there is overflow, skip this step.
+Define the training step: calculate the current gradient values and recover the loss. Use `all_finite` to determine whether there is a gradient underflow problem. If there is no overflow, restore the gradients and update the network weights; if there is an overflow, skip this step.
 
 ```python
+from mindspore.amp import all_finite
+
 @ms.jit
-def train_step(x, y):
-    (loss_value, _), grads = grad_fn(x, y)
-    loss_value = loss_scaler.unscale(loss_value)
+def train_step(data, label):
+    (loss, _), grads = grad_fn(data, label)
+    loss = loss_scaler.unscale(loss)
 
     is_finite = all_finite(grads)
     if is_finite:
         grads = loss_scaler.unscale(grads)
-        loss_value = ops.depend(loss_value, opt(grads))
+        loss = ops.depend(loss, optimizer(grads))
     loss_scaler.adjust(is_finite)
-    return loss_value
-```
-
-Then a virtual random dataset is created for the data input of the sample model.
-
-```python
-# create dataset
-def get_data(num, img_size=(1, 32, 32), num_classes=10, is_onehot=True):
-    for _ in range(num):
-        img = np.random.randn(*img_size)
-        target = np.random.randint(0, num_classes)
-        target_ret = np.array([target]).astype(np.float16)
-        if is_onehot:
-            target_onehot = np.zeros(shape=(num_classes,))
-            target_onehot[target] = 1
-            target_ret = target_onehot.astype(np.float16)
-        yield img.astype(np.float16), target_ret
-
-def create_dataset(num_data=1024, batch_size=32, repeat_size=1):
-    input_data = ds.GeneratorDataset(list(get_data(num_data)), column_names=['data', 'label'])
-    input_data = input_data.batch(batch_size, drop_remainder=True)
-    input_data = input_data.repeat(repeat_size)
-    return input_data
-```
-
-Execute the training.
-
-```python
-datasets = create_dataset()
-epochs = 5
-for epoch in range(epochs):
-    for data, label in datasets:
-        loss = train_step(data, label)
-```
-
-## Mixed precision and Loss Scale by Using the Training Interface `Model`
-
-### Mixed-Precision
-
-The `Model` interface provides the input `amp_level` to achieve automatic mixed precision, or the user can set the operator involved in the Cell to FP16 via `to_float(ms.float16)` to achieve manual mixed precision.
-
-> This method only supports Ascend and GPU.
-
-#### Automatic Mixed-Precision
-
-To use the automatic mixed-precision, you need to call the `Model` API to transfer the network to be trained and optimizer as the input. This API converts the network model operators into FP16 operators.
-
-> Due to precision problems, the `BatchNorm` operator and operators involved in loss still use FP32.
-
-The specific implementation steps for using the `Model` interface are:
-
-1. Introduce the MindSpore model API `Model`.
-
-2. Define a network: This step is the same as that for defining a common network (no new configuration is required).
-
-3. Create a dataset: For this step, refer to [Data Processing](https://www.mindspore.cn/tutorials/en/master/advanced/dataset.html).
-
-4. Use the `Model` API to encapsulate the network model, optimizer, and loss function, and set the `amp_level` parameter. For details, see [MindSpore API](https://www.mindspore.cn/docs/en/master/api_python/train/mindspore.train.Model.html#mindspore.train.Model). In this step, MindSpore automatically selects an appropriate operator to convert FP32 to FP16.
-
-The following is a basic code example. First, import the required libraries and declarations.
-
-```python
-import numpy as np
-import mindspore.nn as nn
-from mindspore.train import Accuracy, Model
-import mindspore as ms
-from mindspore.common.initializer import Normal
-from mindspore import dataset as ds
-
-ms.set_context(mode=ms.GRAPH_MODE)
-ms.set_context(device_target="CPU")
+    return loss
 ```
 
-Create a virtual random dataset for data input of the sample model.
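+
+For reference, the dynamic adjustment performed by `adjust` can be summarized as follows (a conceptual sketch of the policy only, not the actual `mindspore.amp` implementation): the scale shrinks on overflow and is tentatively enlarged after `scale_window` consecutive overflow-free steps.
+
+```python
+class DynamicLossScaleSketch:
+    """Illustrative restatement of the dynamic loss scale policy."""
+    def __init__(self, scale_value=2**10, scale_factor=2, scale_window=50):
+        self.scale_value = scale_value
+        self.scale_factor = scale_factor
+        self.scale_window = scale_window
+        self.counter = 0
+
+    def adjust(self, grads_finite):
+        if grads_finite:
+            self.counter += 1
+            if self.counter == self.scale_window:
+                # after scale_window clean steps, try a larger scale
+                self.scale_value *= self.scale_factor
+                self.counter = 0
+        else:
+            # overflow detected: shrink the scale and restart the window
+            self.scale_value /= self.scale_factor
+            self.counter = 0
+```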
+Finally, we train one epoch and observe the loss convergence of training with automatic mixed precision.
 
 ```python
-# create dataset
-def get_data(num, img_size=(1, 32, 32), num_classes=10, is_onehot=True):
-    for _ in range(num):
-        img = np.random.randn(*img_size)
-        target = np.random.randint(0, num_classes)
-        target_ret = np.array([target]).astype(np.float32)
-        if is_onehot:
-            target_onehot = np.zeros(shape=(num_classes,))
-            target_onehot[target] = 1
-            target_ret = target_onehot.astype(np.float32)
-        yield img.astype(np.float32), target_ret
-
-def create_dataset(num_data=1024, batch_size=32, repeat_size=1):
-    input_data = ds.GeneratorDataset(list(get_data(num_data)), column_names=['data', 'label'])
-    input_data = input_data.batch(batch_size, drop_remainder=True)
-    input_data = input_data.repeat(repeat_size)
-    return input_data
+size = train_dataset.get_dataset_size()
+model.set_train()
+for batch, (data, label) in enumerate(train_dataset.create_tuple_iterator()):
+    loss = train_step(data, label)
+
+    if batch % 100 == 0:
+        loss, current = loss.asnumpy(), batch
+        print(f"loss: {loss:>7f}  [{current:>3d}/{size:>3d}]")
 ```
 
-Taking the LeNet5 as an example, set the `amp_level` parameter and use the `Model` API to encapsulate the network model, optimizer, and loss function.
-
-```python
-ds_train = create_dataset()
-
-# Initialize network
-network = LeNet5(10)
-
-# Define Loss and Optimizer
-net_loss = nn.SoftmaxCrossEntropyWithLogits(reduction="mean")
-net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.01, momentum=0.9)
-# Set amp level
-model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}, amp_level="O3")
-
-# Run training
-model.train(epoch=10, train_dataset=ds_train)
+```text
+loss: 2.305425  [  0/938]
+loss: 2.289585  [100/938]
+loss: 2.259094  [200/938]
+loss: 2.176874  [300/938]
+loss: 1.856715  [400/938]
+loss: 1.398342  [500/938]
+loss: 0.889620  [600/938]
+loss: 0.709884  [700/938]
+loss: 0.750509  [800/938]
+loss: 0.482525  [900/938]
 ```
 
-#### Manual Mixed-Precision
+It can be seen that the loss converges normally and no overflow problem occurs.
 
-MindSpore also supports manual mixed-precision. (Manual mixed-precision is not recommended unless you want to customize special networks and features.)
+## Automatic Mixed Precision for `Cell` Configuration
 
-Assume that only one Conv layer on the network uses FP16 for computation and other layers use FP32.
+MindSpore supports a programming paradigm that uses Cell to encapsulate the complete computational graph. In this case, the `mindspore.amp.build_train_network` interface can be used to perform the type conversion automatically and pass in the Loss Scale as part of the whole-graph computation. You only need to configure the mixed precision level and the `LossScaleManager` to obtain a computational graph with automatic mixed precision configured.
 
-> The mixed-precision is configured in the unit of Cell. The default type of a Cell is FP32.
+`FixedLossScaleManager` and `DynamicLossScaleManager` are the loss scale management interfaces for configuring automatic mixed precision with `Cell`, corresponding to `StaticLossScaler` and `DynamicLossScaler`, respectively. For detailed information, refer to [mindspore.amp](https://www.mindspore.cn/docs/en/master/api_python/mindspore.amp.html).
 
-The following are the implementation steps of manual mixed-precision:
-
-1. Define the network: This step is similar with the Step 2 in the automatic mixed-precision.
-2. Configure the mixed-precision: Use `to_float(mstype.float16)` to set the operators involved in the Cell to FP16.
-3. Use `TrainOneStepCell` to encapsulate the network model and optimizer.
-
-The following is a basic code example. First, import the required libraries and declarations.
+> Automated mixed precision training with `Cell` configuration supports only `GPU` and `Ascend`.
 
 ```python
-import numpy as np
-
-import mindspore.nn as nn
-from mindspore.train import Accuracy, Model
-import mindspore as ms
-from mindspore.common.initializer import Normal
-from mindspore import dataset as ds
-import mindspore.ops as ops
+from mindspore.amp import build_train_network, FixedLossScaleManager
 
-ms.set_context(mode=ms.GRAPH_MODE, device_target="GPU")
-```
-
-After initializing the network model, declare that the Conv1 layer in LeNet5 is computed by using FP16, i.e. `network.conv1.to_float(mstype.float16)`.
+model = Network()
+loss_scale_manager = FixedLossScaleManager()
 
-```python
-ds_train = create_dataset()
-network = LeNet5(10)
-net_loss = nn.SoftmaxCrossEntropyWithLogits(reduction="mean")
-net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.01, momentum=0.9)
-network.conv1.to_float(ms.float16)
-model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}, amp_level="O2")
-model.train(epoch=2, train_dataset=ds_train)
+model = build_train_network(model, optimizer, loss_fn, level="O2", loss_scale_manager=loss_scale_manager)
```

-> When mixed-precision is used, the backward network can be generated only by the automatic differential function, not by user-defined inverse networks. Otherwise, MindSpore may generate exception information indicating that the data format does not match.
-
-### Loss scale
+## Automatic Mixed Precision for `Model` Configuration
 
-The following two APIs in MindSpore that use the loss scaling algorithm are described separately [FixedLossScaleManager](https://www.mindspore.cn/docs/en/master/api_python/amp/mindspore.amp.FixedLossScaleManager.html#mindspore.amp.FixedLossScaleManager) and [DynamicLossScaleManager](https://www.mindspore.cn/docs/en/master/api_python/amp/mindspore.amp.DynamicLossScaleManager.html#mindspore.amp.DynamicLossScaleManager).
+`mindspore.train.Model` is a high-level encapsulation for fast training of neural networks. It encapsulates `mindspore.amp.build_train_network`, so again only the mixed precision level and the `LossScaleManager` need to be configured for automatic mixed precision training.
 
-#### FixedLossScaleManager
-
-`FixedLossScaleManager` does not change the size of the scale when scaling, and the value of the scale is controlled by the input parameter loss_scale, which can be specified by the user. The default value is taken if it is not specified.
-
-Another parameter of `FixedLossScaleManager` is `drop_overflow_update`, which controls whether parameters are updated in the event of an overflow.
-
-In general, the LossScale function does not need to be used with the optimizer, but when using `FixedLossScaleManager`, if `drop_overflow_update` is False, the optimizer needs to set the value of `loss_scale` and the value of `loss_scale` should be the same as that of `FixedLossScaleManager`.
-
-The detailed use of `FixedLossScaleManager` is as follows:
-
-Import the necessary libraries and declare execution using graph mode.
+> Automated mixed precision training with `Model` configuration supports only `GPU` and `Ascend`.
```python -import numpy as np -import mindspore as ms -import mindspore.nn as nn -from mindspore import amp -from mindspore.train import Accuracy, Model -from mindspore.common.initializer import Normal -from mindspore import dataset as ds - -ms.set_context(mode=ms.GRAPH_MODE, device_target="GPU") -``` - -Define the network model by using LeNet5 as an example; define the dataset and the interfaces commonly used in the training process. - -```python -ds_train = create_dataset() +from mindspore.train import Model, LossMonitor # Initialize network -network = LeNet5(10) -# Define Loss and Optimizer -net_loss = nn.SoftmaxCrossEntropyWithLogits(reduction="mean") -``` - -Use Loss Scale API to act in optimizers and models. - -```python -# Define Loss Scale, optimizer and model -#1) Drop the parameter update if there is an overflow -loss_scale_manager = amp.FixedLossScaleManager() -net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.01, momentum=0.9) -model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}, amp_level="O0", loss_scale_manager=loss_scale_manager) - -#2) Execute parameter update even if overflow occurs -loss_scale = 1024.0 -loss_scale_manager = amp.FixedLossScaleManager(loss_scale, False) -net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.01, momentum=0.9, loss_scale=loss_scale) -model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}, amp_level="O0", loss_scale_manager=loss_scale_manager) - -# Run training -model.train(epoch=10, train_dataset=ds_train, callbacks=[ms.LossMonitor()]) -``` - -#### LossScale and Optimizer +model = Network() -As mentioned earlier, the optimizer needs to be used together when using `FixedLossScaleManager` and `drop_overflow_update` is False. +loss_scale_manager = FixedLossScaleManager() +trainer = Model(model, loss_fn=loss_fn, optimizer=optimizer, metrics={'accuracy'}, amp_level="O2", loss_scale_manager=loss_scale_manager) -This is due to the fact that when configured in this way, the division between the gradient and the `loss_scale` coefficient is performed in the optimizer. The optimizer setting is the same `loss_scale` as `FixedLossScaleManager` and the training result is correct. - -> Subsequent MindSpore will optimize the use of overflow detection in different scenarios, and gradually remove the `loss_scale` parameter in the optimizer, so that there is no need to configure the `loss_scale` parameter of the optimizer. - -It should be noted that some of the optimizers provided by MindSpore, such as `AdamWeightDecay`, do not provide the `loss_scale` parameter. 
If you use `FixedLossScaleManager` and the `drop_overflow_update` is configured as False, and the division between the gradient and the `loss_scale` is not performed in the optimizer, you need to customize the `TrainOneStepCell` and divide the gradient by `loss_scale` in it so that the final calculation is correct, as defined as follows: - -```python -import mindspore as ms -from mindspore.train import Model -from mindspore import nn, ops - -grad_scale = ops.MultitypeFuncGraph("grad_scale") - -@grad_scale.register("Tensor", "Tensor") -def gradient_scale(scale, grad): - return grad * ops.cast(scale, ops.dtype(grad)) - -class CustomTrainOneStepCell(nn.TrainOneStepCell): - def __init__(self, network, optimizer, sens=1.0): - super(CustomTrainOneStepCell, self).__init__(network, optimizer, sens) - self.hyper_map = ops.HyperMap() - self.reciprocal_sense = ms.Tensor(1 / sens, ms.float32) - - def scale_grad(self, gradients): - gradients = self.hyper_map(ops.partial(grad_scale, self.reciprocal_sense), gradients) - return gradients - - def construct(self, *inputs): - loss = self.network(*inputs) - sens = ops.fill(loss.dtype, loss.shape, self.sens) - # calculate gradients, the sens will equal to the loss_scale - grads = self.grad(self.network, self.weights)(*inputs, sens) - # gradients / loss_scale - grads = self.scale_grad(grads) - # reduce gradients in distributed scenarios - grads = self.grad_reducer(grads) - loss = ops.depend(loss, self.optimizer(grads)) - return loss -``` - -- network: The network participating in the training, which contains the computational logic of the forward network and the loss function, input data and labels, and output loss function values. -- optimizer: The used optimizer. -- sens: Parameters are used to receive a user-specified `loss_scale` and the gradient value is magnified by a factor of `loss_scale` during training. -- scale_grad function: Used for division between the gradient and the `loss_scale` coefficient to restore the gradient. -- construct function: Referring to `nn. TrainOneStepCell`, defines the computational logic for `construct` and calls `scale_grad` after acquiring the gradient. - -After customizing `TrainOneStepCell`, the training network needs to be manually built, which is as follows: - -```python -from mindspore import nn -from mindspore import amp - -network = LeNet5(10) - -# Define Loss and Optimizer -net_loss = nn.SoftmaxCrossEntropyWithLogits(reduction="mean") -net_opt = nn.AdamWeightDecay(network.trainable_params(), learning_rate=0.01) - -# Define LossScaleManager -loss_scale = 1024.0 -loss_scale_manager = amp.FixedLossScaleManager(loss_scale, False) - -# Build train network -net_with_loss = nn.WithLossCell(network, net_loss) -net_with_train = CustomTrainOneStepCell(net_with_loss, net_opt, loss_scale) -``` - -After building the training network, it can be run directly or via Model: - -```python -epochs = 2 - -#1) Execute net_with_train -ds_train = create_dataset() - -for epoch in range(epochs): - for d in ds_train.create_dict_iterator(): - result = net_with_train(d["data"], d["label"]) - -#2) Define Model and run -model = Model(net_with_train) - -ds_train = create_dataset() - -model.train(epoch=epochs, train_dataset=ds_train) -``` - -When training with `Model` in this scenario, the `loss_scale_manager` and `amp_level` do not need to be configured, as the `CustomTrainOneStepCell` already includes mixed-precision calculation logic. 
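+
+After training, accuracy can be checked with the `Model.eval` interface (a brief sketch: it reuses the `datapipe` helper defined above and assumes the MNIST test split was extracted to `MNIST_Data/test`):
+
+```python
+# Evaluate the trained mixed precision model on the test split
+test_dataset = datapipe('MNIST_Data/test', 64)
+print(trainer.eval(test_dataset))
+```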
- -#### DynamicLossScaleManager - -`DynamicLossScaleManager` can dynamically change the size of the scale during training, keeping the scale as large as possible without overflow. - -`DynamicLossScaleManager` first sets scale to an initial value, which is controlled by the input init_loss_scale. - -During training, if no overflow occurs, after updating the parameters scale_window times, an attempt is made to expand the value of the scale, and if an overflow occurs, the parameter update is skipped and the value of the scale is reduced, and the scale_factor is to control the number of steps that are expanded or reduced. scale_window controls the maximum number of consecutive update steps when no overflow occurs. - -The detailed use is as follows and we only need to define LossScale in `FixedLossScaleManager` sample. The part code of the optimizer and model changes as the following code: - -```python -# Define Loss Scale, optimizer and model -scale_factor = 4 -scale_window = 3000 -loss_scale_manager = amp.DynamicLossScaleManager(scale_factor, scale_window) -net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.01, momentum=0.9) -model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}, amp_level="O0", loss_scale_manager=loss_scale_manager) +loss_callback = LossMonitor(100) +trainer.train(10, train_dataset, callbacks=[loss_callback]) ``` -> The pictures are cited from [automatic-mixed-precision](https://developer.nvidia.com/automatic-mixed-precision). +> The image is quoted from [automatic-mixed-precision](https://developer.nvidia.com/automatic-mixed-precision). diff --git a/tutorials/source_zh_cn/advanced/mixed_precision.ipynb b/tutorials/source_zh_cn/advanced/mixed_precision.ipynb index 7de305a42583753c9d63a902f0d88d0720a81d8e..c24af517b25239e097ce092a94b3b36ff5e7897c 100644 --- a/tutorials/source_zh_cn/advanced/mixed_precision.ipynb +++ b/tutorials/source_zh_cn/advanced/mixed_precision.ipynb @@ -17,7 +17,7 @@ "source": [ "混合精度(Mix Precision)训练是指在训练时,对神经网络不同的运算采用不同的数值精度的运算策略。在神经网络运算中,部分运算对数值精度不敏感,此时使用较低精度可以达到明显的加速效果(如conv、matmul等);而部分运算由于输入和输出的数值差异大,通常需要保留较高精度以保证结果的正确性(如log、softmax等)。\n", "\n", - "当前的AI加速卡通常通过针对计算密集、精度不敏感的运算设计了硬件加速模块,如NVIDIA GPU的TensorCore,Ascend NPU的Cube等。对于conv、matmul等运算占比较大的神经网络,其训练速度通常会有较大的加速比(2-3X)。\n", + "当前的AI加速卡通常通过针对计算密集、精度不敏感的运算设计了硬件加速模块,如NVIDIA GPU的TensorCore、Ascend NPU的Cube等。对于conv、matmul等运算占比较大的神经网络,其训练速度通常会有较大的加速比。\n", "\n", "`mindspore.amp`模块提供了便捷的自动混合精度接口,用户可以在不同的硬件后端通过简单的接口调用获得训练加速。下面我们对混合精度计算原理进行简介,而后通过实例介绍MindSpore的自动混合精度用法。" ] @@ -50,7 +50,7 @@ "但是使用FP16同样会带来一些问题:\n", "\n", "- 数据溢出:FP16的有效数据表示范围为 $[6.10\\times10^{-5}, 65504]$,FP32的有效数据表示范围为 $[1.4\\times10^{-45}, 1.7\\times10^{38}]$。可见FP16相比FP32的有效范围要窄很多,使用FP16替换FP32会出现上溢(Overflow)和下溢(Underflow)的情况。而在深度学习中,需要计算网络模型中权重的梯度(一阶导数),因此梯度会比权重值更加小,往往容易出现下溢情况。\n", - "- 舍入误差:Rounding Error指示是当网络模型的反向梯度很小,一般FP32能够表示,但是转换到FP16会小于当前区间内的最小间隔,会导致数据溢出。如`0.00006666666`在FP32中能正常表示,转换到FP16后会表示成为`0.000067`,不满足FP16最小间隔的数会强制舍入。\n", + "- 舍入误差:Rounding Error是指当网络模型的反向梯度很小,一般FP32能够表示,但是转换到FP16会小于当前区间内的最小间隔,会导致数据溢出。如`0.00006666666`在FP32中能正常表示,转换到FP16后会表示成为`0.000067`,不满足FP16最小间隔的数会强制舍入。\n", "\n", "因此,在使用混合精度获得训练加速和内存节省的同时,需要考虑FP16引入问题的解决。Loss Scale损失缩放,FP16类型数据下溢问题的解决方案,其主要思想是在计算损失值loss的时候,将loss扩大一定的倍数。根据链式法则,梯度也会相应扩大,然后在优化器更新权重时再缩小相应的倍数,从而避免了数据下溢。" ] @@ -78,7 +78,7 @@ "id": "81c60159-5242-4ea2-a73e-6015de4a675c", "metadata": {}, "source": [ - "下面我们导入[快速入门](https://www.mindspore.cn/tutorials/zh-CN/master/beginner/quick_start.html)中的手写数字识别模型及数据集,示例MindSpore的自动混合精度实现。" + 
"下面我们通过导入[快速入门](https://www.mindspore.cn/tutorials/zh-CN/master/beginner/quick_start.html)中的手写数字识别模型及数据集,演示MindSpore的自动混合精度实现。" ] }, { @@ -259,7 +259,7 @@ "source": [ "#### 自定义粒度类型转换\n", "\n", - "当用户需要在单个运算,或多个模块组合配置运算精度时,Cell粒度往往无法满足,此时可以直接通过对输入数据的类型进行cast来达到自定义粒度控制的目的:" + "当用户需要在单个运算,或多个模块组合配置运算精度时,Cell粒度往往无法满足,此时可以直接通过对输入数据的类型进行cast来达到自定义粒度控制的目的。" ] }, { @@ -296,7 +296,7 @@ "source": [ "## 损失缩放\n", "\n", - "MindSpore中提供了两种Loss Scale的实现, 分别为`StaticLossScaler`和`DynamicLossScaler`,其差异为损失缩放值scale value是否进行动态调整。下面以`DynamicLossScalar`为例,根据混合精度计算流程实现神经网络训练逻辑。\n", + "MindSpore中提供了两种Loss Scale的实现,分别为`StaticLossScaler`和`DynamicLossScaler`,其差异为损失缩放值scale value是否进行动态调整。下面以`DynamicLossScalar`为例,根据混合精度计算流程实现神经网络训练逻辑。\n", "\n", "首先,实例化LossScaler,并在定义前向网络时,手动放大loss值。" ]