diff --git a/tutorials/application/source_en/cv/images/output_6_1.png b/tutorials/application/source_en/cv/images/output_6_1.png
new file mode 100644
index 0000000000000000000000000000000000000000..d33371ef83081fc45775f8e2a331a5120d59099f
Binary files /dev/null and b/tutorials/application/source_en/cv/images/output_6_1.png differ
diff --git a/tutorials/application/source_en/cv/resnet50.md b/tutorials/application/source_en/cv/resnet50.md
index 35ec45609824cc3bec40867da389b74f050a9492..094f1386aa03fc0ddec0bb07393cf088dd522e67 100644
--- a/tutorials/application/source_en/cv/resnet50.md
+++ b/tutorials/application/source_en/cv/resnet50.md
@@ -18,71 +18,155 @@ In ResNet, a residual network is proposed to alleviate the degradation problem,
 
 ## Preparing and Loading Datasets
 
-[The CIFAR-10 dataset](http://www.cs.toronto.edu/~kriz/cifar.html) contains 60,000 32 x 32 color images in 10 classes, with 6,000 images per class. There are 50,000 training images and 10,000 test images. The following example uses the `mindvision.classification.dataset.Cifar10` API to download and load the CIFAR-10 dataset.
+[The CIFAR-10 dataset](http://www.cs.toronto.edu/~kriz/cifar.html) contains 60,000 32 x 32 color images in 10 classes, with 6,000 images per class. There are 50,000 training images and 10,000 test images. First, the following example uses the `download` interface to download and decompress the CIFAR-10 archive. Note that only the binary version of the dataset (CIFAR-10 binary version) is currently supported for parsing.
 
 ```python
-from mindvision.classification.dataset import Cifar10
-
-# Dataset root directory
-data_dir = "./datasets"
-
-# Download, decompress, and load the CIFAR-10 training dataset.
-dataset_train = Cifar10(path=data_dir, split='train', batch_size=6, resize=32, download=True)
-ds_train = dataset_train.run()
-step_size = ds_train.get_dataset_size()
-# Download, decompress, and load the CIFAR-10 test dataset.
-dataset_val = Cifar10(path=data_dir, split='test', batch_size=6, resize=32, download=True)
-ds_val = dataset_val.run()
+from download import download
+
+url = "http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz"
+
+download(url, "./datasets-cifar10-bin", kind="tar.gz")
+```
+
+```text
+Creating data folder...
+Downloading data from http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz (162.2 MB)
+
+file_sizes: 100%|████████████████████████████| 170M/170M [00:26<00:00, 6.38MB/s]
+Extracting tar.gz file...
+Successfully downloaded / unzipped to ./datasets-cifar10-bin
 ```
 
 The directory structure of the CIFAR-10 dataset file is as follows:
 
 ```Text
-datasets/
-├── cifar-10-batches-py
-│   ├── batches.meta
-│   ├── data_batch_1
-│   ├── data_batch_2
-│   ├── data_batch_3
-│   ├── data_batch_4
-│   ├── data_batch_5
-│   ├── readme.html
-│   └── test_batch
-└── cifar-10-python.tar.gz
+datasets-cifar10-bin/cifar-10-batches-bin
+├── batches.meta.txt
+├── data_batch_1.bin
+├── data_batch_2.bin
+├── data_batch_3.bin
+├── data_batch_4.bin
+├── data_batch_5.bin
+├── readme.html
+└── test_batch.bin
 ```
 
-Visualize the CIFAR-10 training dataset.
+Then, the `mindspore.dataset.Cifar10Dataset` interface is used to load the dataset and perform the associated image transforms.
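+
+For reference, each record in the `*.bin` files stores one label byte followed by 3 x 32 x 32 = 3,072 pixel bytes, saved plane by plane (1,024 red, then 1,024 green, then 1,024 blue values). The following minimal sketch is for illustration only and is not required by the rest of the tutorial; it decodes the first record by hand to show what `Cifar10Dataset` parses internally:
+
+```python
+import numpy as np
+
+record_bytes = 1 + 3 * 32 * 32  # 1 label byte + 3,072 pixel bytes per record
+with open("./datasets-cifar10-bin/cifar-10-batches-bin/data_batch_1.bin", "rb") as f:
+    buffer = f.read(record_bytes)  # read only the first record
+
+label = buffer[0]  # class index in the range [0, 9]
+image = np.frombuffer(buffer, dtype=np.uint8, offset=1).reshape(3, 32, 32)
+print(label, image.shape)
+```
+
+The full loading and augmentation pipeline is defined below.
+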
 ```python
+import mindspore.dataset as ds
+import mindspore.dataset.vision as vision
+import mindspore as ms
 import numpy as np
+
+from mindspore import nn, ops
+
+
+data_dir = "./datasets-cifar10-bin/cifar-10-batches-bin"  # Dataset root directory
+batch_size = 6  # Batch size
+image_size = 32  # Spatial size of the training images
+workers = 4  # Number of parallel threads
+num_classes = 10  # Number of classes
+
+
+def create_dataset_cifar10(dataset_dir, usage, resize, batch_size, workers):
+    data_set = ds.Cifar10Dataset(dataset_dir=dataset_dir,
+                                 usage=usage,
+                                 num_parallel_workers=workers,
+                                 shuffle=True)
+
+    trans = []
+    if usage == "train":
+        trans += [
+            vision.RandomCrop((32, 32), (4, 4, 4, 4)),
+            vision.RandomHorizontalFlip(prob=0.5)
+        ]
+
+    trans += [
+        vision.Resize(resize),
+        vision.Rescale(1.0 / 255.0, 0.0),
+        vision.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
+        vision.HWC2CHW()
+    ]
+
+    target_trans = [(lambda x: np.array([x]).astype(np.int32)[0])]
+
+    # Data mapping operations
+    data_set = data_set.map(
+        operations=trans,
+        input_columns='image',
+        num_parallel_workers=workers)
+
+    data_set = data_set.map(
+        operations=target_trans,
+        input_columns='label',
+        num_parallel_workers=workers)
+
+    # Batch operation
+    data_set = data_set.batch(batch_size)
+
+    return data_set
+
+
+# Obtain the processed training and test datasets
+dataset_train = create_dataset_cifar10(dataset_dir=data_dir,
+                                       usage="train",
+                                       resize=image_size,
+                                       batch_size=batch_size,
+                                       workers=workers)
+step_size_train = dataset_train.get_dataset_size()
+
+dataset_val = create_dataset_cifar10(dataset_dir=data_dir,
+                                     usage="test",
+                                     resize=image_size,
+                                     batch_size=batch_size,
+                                     workers=workers)
+step_size_val = dataset_val.get_dataset_size()
+```
+
+Visualize the CIFAR-10 training dataset.
+
+```python
 import matplotlib.pyplot as plt
+import numpy as np
 
-data = next(ds_train.create_dict_iterator())
+data_iter = next(dataset_train.create_dict_iterator())
 
-images = data["image"].asnumpy()
-labels = data["label"].asnumpy()
+images = data_iter["image"].asnumpy()
+labels = data_iter["label"].asnumpy()
 print(f"Image shape: {images.shape}, Label: {labels}")
 
+classes = []
+
+with open(data_dir+"/batches.meta.txt", "r") as f:
+    for line in f:
+        line = line.rstrip()
+        if line != '':
+            classes.append(line)
+
 plt.figure()
-for i in range(1, 7):
-    plt.subplot(2, 3, i)
-    image_trans = np.transpose(images[i - 1], (1, 2, 0))
+for i in range(6):
+    plt.subplot(2, 3, i+1)
+    image_trans = np.transpose(images[i], (1, 2, 0))
     mean = np.array([0.4914, 0.4822, 0.4465])
     std = np.array([0.2023, 0.1994, 0.2010])
     image_trans = std * image_trans + mean
     image_trans = np.clip(image_trans, 0, 1)
-    plt.title(f"{dataset_train.index2label[labels[i - 1]]}")
+    plt.title(f"{classes[labels[i]]}")
     plt.imshow(image_trans)
     plt.axis("off")
 plt.show()
 ```
 
-```python
-    Image shape: (6, 3, 32, 32), Label: [6 4 4 5 2 1]
+```text
+Image shape: (6, 3, 32, 32), Label: [5 8 0 3 0 9]
 ```
 
-![png](images/output_3_1.png)
+![](images/output_6_1.png)
 
 ## Building a Network
 
@@ -100,22 +184,24 @@ There are two residual network structures. One is the building block, which is a
 
 The following figure shows the structure of the building block. The main body has two convolutional layers.
 
-+ On the first-layer network of the main body. 64 input channels are used. Then, 64 output channels are obtained through the $3\times3$ convolutional layer, the Batch Normalization layer, and the ReLU activation function layer.
++ On the first-layer network of the main body, 64 input channels are used. Then, 64 output channels are obtained through the $3\times3$ convolutional layer, the Batch Normalization layer, and the ReLU activation function layer.
 + On the second-layer network of the main body, 64 input channels are also used. Then, 64 output channels are obtained through the $3\times3$ convolutional layer, the Batch Normalization layer, and the ReLU activation function layer.
 
 Finally, the feature matrix output by the main body is added to the feature matrix output by the shortcut. After the ReLU activation function is used, the final output of the building block is obtained.
 
 ![building-block-5](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/tutorials/application/source_zh_cn/cv/images/resnet_5.png)
 
-When adding the feature matrix output by the main body to that output by the shortcut, ensure that the shape of the feature matrix output by the main body is the same as that of the feature matrix output by the shortcut. If the shapes are different, for example, when the number of output channels is twice that of input channels, the number of convolution kernels used by the shortcut for convolution operations is the same as that of the output channels and the size is $1\times1$. If the size of the output image is half of that of the input image, `stride` in the convolution operation of the shortcut must be set to **2**, and `stride` in the first-layer convolution operation of the main body must also be set to **2**.
+When adding the feature matrix output by the main body to that output by the shortcut, ensure that the shape of the feature matrix output by the main body is the same as that of the feature matrix output by the shortcut. If the shapes are different, for example, when the number of output channels is twice that of input channels, the number of convolution kernels used by the shortcut for convolution operations is the same as that of the output channels and the size is $1\times1$. If the size of the output image is half of that of the input image, `stride` in the convolution operation of the shortcut must be set to 2, and `stride` in the first-layer convolution operation of the main body must also be set to 2.
 
-The following code defines the `ResidualBlockBase` class to implement the building block structure:
+The following code defines the `ResidualBlockBase` class to implement the building block structure.
 
 ```python
 from typing import Type, Union, List, Optional
 
-from mindvision.classification.models.blocks import ConvNormActivation
 from mindspore import nn
+from mindspore.common.initializer import Normal
+
+
+weight_init = Normal(mean=0, sigma=0.02)
+gamma_init = Normal(mean=1, sigma=0.02)
+
 
 class ResidualBlockBase(nn.Cell):
     expansion: int = 1  # The number of convolution kernels at the last layer is the same as that of convolution kernels at the first layer.
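+
+    # Forward computation (see construct below):
+    #     out = ReLU(norm2(conv2(ReLU(norm1(conv1(x))))) + identity)
+    # where identity is x itself or, when down_sample is given, a 1 x 1
+    # convolution (with Batch Normalization) of x that matches the shape
+    # of the main body's output.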
@@ -125,12 +211,15 @@ class ResidualBlockBase(nn.Cell):
                  down_sample: Optional[nn.Cell] = None) -> None:
         super(ResidualBlockBase, self).__init__()
         if not norm:
-            norm = nn.BatchNorm2d
-
-        self.conv1 = ConvNormActivation(in_channel, out_channel,
-                                        kernel_size=3, stride=stride, norm=norm)
-        self.conv2 = ConvNormActivation(out_channel, out_channel,
-                                        kernel_size=3, norm=norm, activation=None)
+            self.norm1 = nn.BatchNorm2d(out_channel)
+            self.norm2 = nn.BatchNorm2d(out_channel)
+        else:
+            self.norm1 = norm
+            self.norm2 = norm
+
+        self.conv1 = nn.Conv2d(in_channel, out_channel,
+                               kernel_size=3, stride=stride,
+                               weight_init=weight_init)
+        self.conv2 = nn.Conv2d(out_channel, out_channel,
+                               kernel_size=3, weight_init=weight_init)
         self.relu = nn.ReLU()
         self.down_sample = down_sample
 
@@ -139,11 +228,14 @@ class ResidualBlockBase(nn.Cell):
         identity = x  # shortcut
 
         out = self.conv1(x)  # First layer of the main body: 3 x 3 convolutional layer
+        out = self.norm1(out)
+        out = self.relu(out)
         out = self.conv2(out)  # Second layer of the main body: 3 x 3 convolutional layer
+        out = self.norm2(out)
 
-        if self.down_sample:
+        if self.down_sample is not None:
             identity = self.down_sample(x)
         out += identity  # The output is the sum of the main body and the shortcut.
         out = self.relu(out)
 
         return out
@@ -151,37 +243,39 @@
 
 #### Bottleneck
 
-The following figure shows the bottleneck structure. With the same input, the bottleneck structure has fewer parameters than the building block structure. Therefore, the bottleneck structure is more suitable for a deep network. The residual structure used by ResNet-50 is bottleneck. The main body of this structure has three convolutional layers, namely, $1\times1$, $3\times3$, and $1\times1$. $1\times1$ is used for dimension reduction and dimension rollup.
+The following figure shows the bottleneck structure. With the same input, the bottleneck structure has fewer parameters than the building block structure. Therefore, the bottleneck structure is more suitable for a deep network. The residual structure used by ResNet-50 is the bottleneck. The main body of this structure has three convolutional layers, namely, a $1\times1$ convolutional layer, a $3\times3$ convolutional layer, and another $1\times1$ convolutional layer, where the first $1\times1$ convolutional layer reduces the dimensionality and the last one restores it.
 
 + On the first-layer network of the main body, 256 input channels are used. Dimension reduction is performed by using 64 convolution kernels with a size of $1\times1$. Then, 64 output channels are obtained through the Batch Normalization layer and the ReLU activation function layer.
 + On the second-layer network of the main body, features are extracted by using 64 convolution kernels with a size of $3\times3$. Then, 64 output channels are obtained through the Batch Normalization layer and the ReLU activation function layer.
 + On the third-layer network of the main body, dimension rollup is performed by using 256 convolution kernels with a size of $1\times1$. Then, 256 output channels are obtained through the Batch Normalization layer.
 
-Finally, the feature matrix output by the main body is added to the feature matrix output by the shortcut. After the ReLU activation function is used, the final output of the bottleneck is obtained.
+Finally, the feature matrix output by the main body is added to that output by the shortcut. After the ReLU activation function is used, the final output of the bottleneck is obtained.
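+
+As a concrete check of the parameter saving mentioned above (counting only convolution weights and ignoring Batch Normalization parameters), a building block operating on 256 channels needs $2 \times 3 \times 3 \times 256 \times 256 = 1179648$ weights, while a bottleneck with the $256 \rightarrow 64 \rightarrow 64 \rightarrow 256$ channel flow needs only $1 \times 1 \times 256 \times 64 + 3 \times 3 \times 64 \times 64 + 1 \times 1 \times 64 \times 256 = 69632$ weights.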
 
 ![building-block-6](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/tutorials/application/source_zh_cn/cv/images/resnet_6.png)
 
-When adding the feature matrix output by the main body to that output by the shortcut, ensure that the shape of the feature matrix output by the main body is the same as that of the feature matrix output by the shortcut. If the shapes are different, for example, when the number of output channels is twice that of input channels, the number of convolution kernels used by the shortcut for convolution operations is the same as that of the output channels and the size is $1\times1$. If the size of the output image is half of that of the input image, `stride` in the convolution operation of the shortcut must be set to **2**, and `stride` in the second-layer convolution operation of the main body must also be set to **2**.
+When adding the feature matrix output by the main body to that output by the shortcut, ensure that the shape of the feature matrix output by the main body is the same as that of the feature matrix output by the shortcut. If the shapes are different, for example, when the number of output channels is twice that of input channels, the number of convolution kernels used by the shortcut for convolution operations is the same as that of the output channels and the size is $1\times1$. If the size of the output image is half of that of the input image, `stride` in the convolution operation of the shortcut must be set to 2, and `stride` in the second-layer convolution operation of the main body must also be set to 2.
 
-The following code defines the `ResidualBlock` class to implement the bottleneck structure:
+The following code defines the `ResidualBlock` class to implement the bottleneck structure.
 
 ```python
 class ResidualBlock(nn.Cell):
     expansion = 4  # The number of convolution kernels at the last layer is four times that of convolution kernels at the first layer.
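+
+    # Channel flow of the main body (with out_channel = 64, as in the figure):
+    #     1 x 1 convolution: in_channel -> 64      (dimension reduction)
+    #     3 x 3 convolution: 64 -> 64               (feature extraction)
+    #     1 x 1 convolution: 64 -> 64 * 4 = 256     (dimension restoration)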
     def __init__(self, in_channel: int, out_channel: int,
-                 stride: int = 1, norm: Optional[nn.Cell] = None,
-                 down_sample: Optional[nn.Cell] = None) -> None:
+                 stride: int = 1, down_sample: Optional[nn.Cell] = None) -> None:
         super(ResidualBlock, self).__init__()
-        if not norm:
-            norm = nn.BatchNorm2d
-
-        self.conv1 = ConvNormActivation(in_channel, out_channel,
-                                        kernel_size=1, norm=norm)
-        self.conv2 = ConvNormActivation(out_channel, out_channel,
-                                        kernel_size=3, stride=stride, norm=norm)
-        self.conv3 = ConvNormActivation(out_channel, out_channel * self.expansion,
-                                        kernel_size=1, norm=norm, activation=None)
+
+        self.conv1 = nn.Conv2d(in_channel, out_channel,
+                               kernel_size=1, weight_init=weight_init)
+        self.norm1 = nn.BatchNorm2d(out_channel)
+        self.conv2 = nn.Conv2d(out_channel, out_channel,
+                               kernel_size=3, stride=stride,
+                               weight_init=weight_init)
+        self.norm2 = nn.BatchNorm2d(out_channel)
+        self.conv3 = nn.Conv2d(out_channel, out_channel * self.expansion,
+                               kernel_size=1, weight_init=weight_init)
+        self.norm3 = nn.BatchNorm2d(out_channel * self.expansion)
+
+        self.relu = nn.ReLU()
         self.down_sample = down_sample
 
@@ -189,10 +283,15 @@ class ResidualBlock(nn.Cell):
         identity = x  # shortcut
 
         out = self.conv1(x)  # First layer of the main body: 1 x 1 convolutional layer
+        out = self.norm1(out)
+        out = self.relu(out)
         out = self.conv2(out)  # Second layer of the main body: 3 x 3 convolutional layer
+        out = self.norm2(out)
+        out = self.relu(out)
         out = self.conv3(out)  # Third layer of the main body: 1 x 1 convolutional layer
+        out = self.norm3(out)
 
-        if self.down_sample:
+        if self.down_sample is not None:
             identity = self.down_sample(x)
 
         out += identity  # The output is the sum of the main body and the shortcut.
@@ -220,19 +319,25 @@ The following example defines `make_layer` to build residual blocks. The paramet
 
 ```python
 def make_layer(last_out_channel, block: Type[Union[ResidualBlockBase, ResidualBlock]],
                channel: int, block_nums: int, stride: int = 1):
     down_sample = None  # shortcut
 
     if stride != 1 or last_out_channel != channel * block.expansion:
-        down_sample = ConvNormActivation(last_out_channel, channel * block.expansion,
-                                         kernel_size=1, stride=stride, norm=nn.BatchNorm2d, activation=None)
+        down_sample = nn.SequentialCell([
+            nn.Conv2d(last_out_channel, channel * block.expansion,
+                      kernel_size=1, stride=stride, weight_init=weight_init),
+            nn.BatchNorm2d(channel * block.expansion, gamma_init=gamma_init)
+        ])
 
     layers = []
-    layers.append(block(last_out_channel, channel, stride=stride, down_sample=down_sample, norm=nn.BatchNorm2d))
+    layers.append(block(last_out_channel, channel, stride=stride, down_sample=down_sample))
 
     in_channel = channel * block.expansion
     # Stack residual networks.
     for _ in range(1, block_nums):
-        layers.append(block(in_channel, channel, norm=nn.BatchNorm2d))
+        layers.append(block(in_channel, channel))
 
     return nn.SequentialCell(layers)
 ```
@@ -248,35 +353,41 @@ ResNet-50 has five convolution structures, one average pooling layer, and one fu
 
 The following sample code is used to build a ResNet-50 model. You can call the `resnet50` function to construct it. The parameters of the `resnet50` function are as follows:
 
-+ `num_classes`: number of classes. The default value is **1000**.
-+ `pretrained`: Used to download the corresponding training model and load the parameters in the pre-trained model to the network.
++ `num_classes`: number of classes. The default value is 1000.
++ `pretrained`: whether to use the pre-trained model. If set to True, the corresponding pre-trained model is downloaded and its parameters are loaded into the network.
 
 ```python
-from mindvision.classification.models.classifiers import BaseClassifier
-from mindvision.classification.models.head import DenseHead
-from mindvision.classification.models.neck import GlobalAvgPooling
-from mindvision.classification.utils.model_urls import model_urls
-from mindvision.utils.load_pretrained_model import LoadPretrainedModel
+from mindspore import load_checkpoint, load_param_into_net
 
 
 class ResNet(nn.Cell):
     def __init__(self, block: Type[Union[ResidualBlockBase, ResidualBlock]],
-                 layer_nums: List[int], norm: Optional[nn.Cell] = None) -> None:
+                 layer_nums: List[int], num_classes: int, input_channel: int) -> None:
         super(ResNet, self).__init__()
-        if not norm:
-            norm = nn.BatchNorm2d
+
+        self.relu = nn.ReLU()
         # At the first convolutional layer, the number of the input channels is 3 (color image) and that of the output channels is 64.
-        self.conv1 = ConvNormActivation(3, 64, kernel_size=7, stride=2, norm=norm)
+        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, weight_init=weight_init)
+        self.norm = nn.BatchNorm2d(64)
         # Maximum pooling layer, reducing the image size
         self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')
-        # Definition of each residual network structure block
+        # Define each residual network structure block
         self.layer1 = make_layer(64, block, 64, layer_nums[0])
         self.layer2 = make_layer(64 * block.expansion, block, 128, layer_nums[1], stride=2)
         self.layer3 = make_layer(128 * block.expansion, block, 256, layer_nums[2], stride=2)
         self.layer4 = make_layer(256 * block.expansion, block, 512, layer_nums[3], stride=2)
+        # Average pooling layer
+        self.avg_pool = nn.AvgPool2d()
+        # Flatten layer
+        self.flatten = nn.Flatten()
+        # Fully-connected layer
+        self.fc = nn.Dense(in_channels=input_channel, out_channels=num_classes)
 
     def construct(self, x):
+        x = self.conv1(x)
+        x = self.norm(x)
+        x = self.relu(x)
         x = self.max_pool(x)
 
         x = self.layer1(x)
@@ -284,79 +395,137 @@ class ResNet(nn.Cell):
         x = self.layer3(x)
         x = self.layer4(x)
 
-        return x
+        x = self.avg_pool(x)
+        x = self.flatten(x)
+        x = self.fc(x)
 
+        return x
+```
 
-def _resnet(arch: str, block: Type[Union[ResidualBlockBase, ResidualBlock]],
-            layers: List[int], num_classes: int, pretrained: bool, input_channel: int):
-    backbone = ResNet(block, layers)
-    neck = GlobalAvgPooling()  # Average pooling layer
-    head = DenseHead(input_channel=input_channel, num_classes=num_classes)  # Fully-connected layer
-    model = BaseClassifier(backbone, neck, head)  # Connect the backbone layer, neck layer, and head layer.
+```python
+def _resnet(model_url: str, block: Type[Union[ResidualBlockBase, ResidualBlock]],
+            layers: List[int], num_classes: int, pretrained: bool, pretrained_ckpt: str,
+            input_channel: int):
+    model = ResNet(block, layers, num_classes, input_channel)
 
     if pretrained:
-        # Download and load the pre-trained model.
-        LoadPretrainedModel(model, model_urls[arch]).run()
+        # Load the pre-trained model.
+        download(url=model_url, path=pretrained_ckpt)
+        param_dict = load_checkpoint(pretrained_ckpt)
+        load_param_into_net(model, param_dict)
 
     return model
 
 
 def resnet50(num_classes: int = 1000, pretrained: bool = False):
-    "ResNet-50 model"
-    return _resnet("resnet50", ResidualBlock, [3, 4, 6, 3], num_classes, pretrained, 2048)
+    "ResNet50 model"
+    resnet50_url = "https://obs.dualstack.cn-north-4.myhuaweicloud.com/mindspore-website/notebook/models/application/resnet50_224_new.ckpt"
+    resnet50_ckpt = "./LoadPretrainedModel/resnet50_224_new.ckpt"
+    return _resnet(resnet50_url, ResidualBlock, [3, 4, 6, 3], num_classes,
+                   pretrained, resnet50_ckpt, 2048)
 ```
 
 ## Model Training and Evaluation
 
-In this part, [a ResNet-50 pre-trained model](https://download.mindspore.cn/vision/classification/resnet50_224.ckpt) is used for fine-tuning. Call `resnet50` to build a ResNet-50 model and set `pretrained` to **True**. The ResNet-50 pre-trained model is automatically downloaded and the parameters of the pre-trained model are loaded to the network. Define an optimizer and a loss function, train the network by using the `model.train` API, and transfer the `mindvision.engine.callback.ValAccMonitor` API in MindSpore Vision to the callback function. The loss value and evaluation accuracy of the training are printed, and the CKPT file (**best.ckpt**) with the highest evaluation accuracy is saved to the current directory.
+In this part, [a ResNet-50 pre-trained model](https://download.mindspore.cn/vision/classification/resnet50_224.ckpt) is used for fine-tuning. Call `resnet50` to build a ResNet50 model and set `pretrained` to True, so that the ResNet50 pre-trained model is automatically downloaded and its parameters are loaded into the network. Then define the optimizer and the loss function, print the loss value and evaluation accuracy epoch by epoch during training, and save the checkpoint file with the highest evaluation accuracy (resnet50-best.ckpt) to ./BestCheckpoint in the current directory.
 
 ```python
-from mindspore.train import Model, Accuracy
-from mindvision.engine.callback import ValAccMonitor
-
-# Define the ResNet-50 network.
+import mindspore as ms
+
+# Define the ResNet50 network.
 network = resnet50(pretrained=True)
 
 # Size of the input layer of the fully-connected layer
-in_channel = network.head.dense.in_channels
-head = DenseHead(input_channel=in_channel, num_classes=10)
+in_channel = network.fc.in_channels
+fc = nn.Dense(in_channels=in_channel, out_channels=10)
 # Reset the fully-connected layer.
-network.head = head
-# Set the learning rate.
+network.fc = fc
+
+for param in network.get_parameters():
+    param.requires_grad = True
+```
+
+```text
+Replace is False and data exists, so doing nothing. Use replace=True to re-download the data.
+```
+
+```python
+# Set the learning rate
 num_epochs = 40
-lr = nn.cosine_decay_lr(min_lr=0.00001, max_lr=0.001, total_step=step_size * num_epochs,
-                        step_per_epoch=step_size, decay_epoch=num_epochs)
-# Define an optimizer and a loss function.
+lr = nn.cosine_decay_lr(min_lr=0.00001, max_lr=0.001, total_step=step_size_train * num_epochs,
+                        step_per_epoch=step_size_train, decay_epoch=num_epochs)
+# Define the optimizer and the loss function
 opt = nn.Momentum(params=network.trainable_params(), learning_rate=lr, momentum=0.9)
-loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
-# Instantiate the model.
-model = Model(network, loss, opt, metrics={"Accuracy": Accuracy()})
-# Perform model training.
-model.train(num_epochs, ds_train, callbacks=[ValAccMonitor(model, ds_val, num_epochs)])
+loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
+
+
+def forward_fn(inputs, targets):
+    logits = network(inputs)
+    loss = loss_fn(logits, targets)
+
+    return loss
+
+
+grad_fn = ops.value_and_grad(forward_fn, None, opt.parameters)
+
+
+def train_step(inputs, targets):
+    loss, grads = grad_fn(inputs, targets)
+    opt(grads)
+    return loss
+
+# Instantiate the model
+model = ms.Model(network, loss_fn, opt, metrics={"Accuracy": nn.Accuracy()})
 ```
 
-```Text
---------------------
-Epoch: [  0 /  40], Train Loss: [2.733], Accuracy:  0.274
---------------------
-Epoch: [  1 /  40], Train Loss: [2.877], Accuracy:  0.319
---------------------
-Epoch: [  2 /  40], Train Loss: [2.438], Accuracy:  0.249
---------------------
-Epoch: [  3 /  40], Train Loss: [1.532], Accuracy:  0.386
-
-······
-
-Epoch: [ 37 /  40], Train Loss: [1.142], Accuracy:  0.738
---------------------
-Epoch: [ 38 /  40], Train Loss: [0.402], Accuracy:  0.727
---------------------
-Epoch: [ 39 /  40], Train Loss: [2.031], Accuracy:  0.735
---------------------
-Epoch: [ 40 /  40], Train Loss: [0.582], Accuracy:  0.745
-================================================================================
-End of validation the best Accuracy is:  0.754, save the best ckpt file in ./best.ckpt
+```python
+# Create the iterators
+data_loader_train = dataset_train.create_tuple_iterator(num_epochs=num_epochs)
+data_loader_val = dataset_val.create_tuple_iterator(num_epochs=num_epochs)
+
+# Storage path for the best model
+best_acc = 0
+best_ckpt_dir = "./BestCheckpoint"
+best_ckpt_path = "./BestCheckpoint/resnet50-best.ckpt"
+```
+
+```python
+import os
+
+# Start the training loop
+print("Start Training Loop ...")
+
+for epoch in range(num_epochs):
+    losses = []
+    network.set_train()
+
+    # Read in the data for each training epoch
+    for i, (images, labels) in enumerate(data_loader_train):
+        loss = train_step(images, labels)
+        if i % 100 == 0 or i == step_size_train - 1:
+            print('Epoch: [%3d/%3d], Steps: [%3d/%3d], Train Loss: [%5.3f]' % (
+                epoch+1, num_epochs, i+1, step_size_train, loss))
+        losses.append(loss)
+
+    # Evaluate the accuracy after each epoch
+    acc = model.eval(dataset_val)['Accuracy']
+
+    print("-" * 50)
+    print("Epoch: [%3d/%3d], Average Train Loss: [%5.3f], Accuracy: [%5.3f]" % (
+        epoch+1, num_epochs, sum(losses)/len(losses), acc
+    ))
+    print("-" * 50)
+
+    if acc > best_acc:
+        best_acc = acc
+        if not os.path.exists(best_ckpt_dir):
+            os.mkdir(best_ckpt_dir)
+        if os.path.exists(best_ckpt_path):
+            os.remove(best_ckpt_path)
+        ms.save_checkpoint(network, best_ckpt_path)
+
+print("=" * 80)
+print(f"End of validation the best Accuracy is: {best_acc: 5.3f}, "
+      f"save the best ckpt file in {best_ckpt_path}", flush=True)
 ```
 
 ## Visualizing Model Prediction Results
 
@@ -365,33 +534,40 @@ Define the `visualize_model` function, use the model with the highest validation accuracy to predict the CIFAR-10 test dataset, and visualize the prediction results.
 
 ```python
 import matplotlib.pyplot as plt
-import mindspore as ms
-from mindspore.train import Model
 
-def visualize_model(best_ckpt_path, val_ds):
+def visualize_model(best_ckpt_path, dataset_val):
     num_class = 10  # There are 10 classes in the CIFAR-10 dataset.
     net = resnet50(num_class)
     # Load model parameters.
     param_dict = ms.load_checkpoint(best_ckpt_path)
     ms.load_param_into_net(net, param_dict)
-    model = Model(net)
+    model = ms.Model(net)
     # Load the validation dataset.
-    data = next(val_ds.create_dict_iterator())
+    data = next(dataset_val.create_dict_iterator())
     images = data["image"].asnumpy()
     labels = data["label"].asnumpy()
     # Predict the image type.
     output = model.predict(ms.Tensor(data['image']))
     pred = np.argmax(output.asnumpy(), axis=1)
 
+    # Read the class names.
+    classes = []
+
+    with open(data_dir+"/batches.meta.txt", "r") as f:
+        for line in f:
+            line = line.rstrip()
+            if line != '':
+                classes.append(line)
+
     # Display the image and the predicted value of the image.
     plt.figure()
-    for i in range(1, 7):
-        plt.subplot(2, 3, i)
+    for i in range(6):
+        plt.subplot(2, 3, i+1)
         # If the prediction is correct, the color is blue. If the prediction is incorrect, the color is red.
-        color = 'blue' if pred[i - 1] == labels[i - 1] else 'red'
-        plt.title('predict:{}'.format(dataset_val.index2label[pred[i - 1]]), color=color)
-        picture_show = np.transpose(images[i - 1], (1, 2, 0))
+        color = 'blue' if pred[i] == labels[i] else 'red'
+        plt.title('predict:{}'.format(classes[pred[i]]), color=color)
+        picture_show = np.transpose(images[i], (1, 2, 0))
         mean = np.array([0.4914, 0.4822, 0.4465])
         std = np.array([0.2023, 0.1994, 0.2010])
         picture_show = std * picture_show + mean
@@ -402,7 +578,7 @@ def visualize_model(best_ckpt_path, val_ds):
     plt.show()
 
 # Use the test dataset for validation.
-visualize_model('best.ckpt', ds_val)
+visualize_model(best_ckpt_path=best_ckpt_path, dataset_val=dataset_val)
 ```
 
 ![png](images/output_161_0.png)
diff --git a/tutorials/application/source_zh_cn/cv/resnet50.ipynb b/tutorials/application/source_zh_cn/cv/resnet50.ipynb
index 7774028aed561eeb985fdc34cf3ce0a02ab015d6..298513b061c1f744bdfe750d8692d803bfd4a51f 100644
--- a/tutorials/application/source_zh_cn/cv/resnet50.ipynb
+++ b/tutorials/application/source_zh_cn/cv/resnet50.ipynb
@@ -31,7 +31,7 @@
    "source": [
     "## 数据集准备与加载\n",
     "\n",
-    "[CIFAR-10数据集](http://www.cs.toronto.edu/~kriz/cifar.html)共有60000张32*32的彩色图像,分为10个类别,每类有6000张图,数据集一共有50000张训练图片和10000张评估图片。首先,如下示例使用`download`接口下载并解压,目前仅支持解析二进制版本的CIFAR-10文件(CIFAR-10 binary version)"
+    "[CIFAR-10数据集](http://www.cs.toronto.edu/~kriz/cifar.html)共有60000张32*32的彩色图像,分为10个类别,每类有6000张图,数据集一共有50000张训练图片和10000张评估图片。首先,如下示例使用`download`接口下载并解压,目前仅支持解析二进制版本的CIFAR-10文件(CIFAR-10 binary version)。"
    ]
   },
  {