From fc438758d1a4de3f7f3310f511ec3a1892cee495 Mon Sep 17 00:00:00 2001 From: wukesong Date: Thu, 9 Jul 2020 16:11:02 +0800 Subject: [PATCH] modify lenet alexent --- .../computer_vision_application.md | 10 +--- .../advanced_use/dashboard_and_lineage.md | 56 ++++++++++++++++++- .../computer_vision_application.md | 10 +--- .../advanced_use/dashboard_and_lineage.md | 56 ++++++++++++++++++- 4 files changed, 114 insertions(+), 18 deletions(-) diff --git a/tutorials/source_en/advanced_use/computer_vision_application.md b/tutorials/source_en/advanced_use/computer_vision_application.md index be0522cd85..845b2f0336 100644 --- a/tutorials/source_en/advanced_use/computer_vision_application.md +++ b/tutorials/source_en/advanced_use/computer_vision_application.md @@ -36,12 +36,7 @@ def classify(image): The key point is to select a proper model. The model generally refers to a deep convolutional neural network (CNN), such as AlexNet, VGG, GoogleNet, and ResNet. -MindSpore presets a typical CNN, such as LeNet, which can be directly used by developers. The usage method is as follows: - -```python -from mindspore.model_zoo.lenet import LeNet5 -network = LeNet(num_classes) -``` +MindSpore implements typical CNNs; developers can visit [model_zoo](https://gitee.com/mindspore/mindspore/tree/master/model_zoo) for more details. MindSpore supports the following image classification networks: LeNet, AlexNet, and ResNet. @@ -148,10 +143,9 @@ CNN is a standard algorithm for image classification tasks. CNN uses a layered s ResNet is recommended. First, it is deep enough with 34 layers, 50 layers, or 101 layers. The deeper the hierarchy, the stronger the representation capability, and the higher the classification accuracy. Second, it is learnable. The residual structure is used. The lower layer is directly connected to the upper layer through the shortcut connection, which solves the problem of gradient disappearance caused by the network depth during the reverse propagation. 
In addition, the ResNet network has good performance, including the recognition accuracy, model size, and parameter quantity. -MindSpore Model Zoo has a built-in ResNet model. In this example, the ResNet-50 network is used. The calling method is as follows: +MindSpore Model Zoo has a ResNet [model](https://gitee.com/mindspore/mindspore/blob/master/model_zoo/resnet/src/resnet.py). The calling method is as follows: ```python -from mindspore.model_zoo.resnet import resnet50 network = resnet50(class_num=10) ``` diff --git a/tutorials/source_en/advanced_use/dashboard_and_lineage.md b/tutorials/source_en/advanced_use/dashboard_and_lineage.md index 551be944c7..7d604ba935 100644 --- a/tutorials/source_en/advanced_use/dashboard_and_lineage.md +++ b/tutorials/source_en/advanced_use/dashboard_and_lineage.md @@ -52,9 +52,63 @@ import mindspore.nn as nn from mindspore import context from mindspore import Tensor from mindspore.train import Model -from mindspore.model_zoo.alexnet import AlexNet +from mindspore.common.initializer import TruncatedNormal +from mindspore.ops import operations as P from mindspore.train.callback import SummaryCollector +"""AlexNet initial.""" +def conv(in_channels, out_channels, kernel_size, stride=1, padding=0, pad_mode="valid"): + weight = weight_variable() + return nn.Conv2d(in_channels, out_channels, + kernel_size=kernel_size, stride=stride, padding=padding, + weight_init=weight, has_bias=False, pad_mode=pad_mode) + +def fc_with_initialize(input_channels, out_channels): + weight = weight_variable() + bias = weight_variable() + return nn.Dense(input_channels, out_channels, weight, bias) + +def weight_variable(): + return TruncatedNormal(0.02) # 0.02 + + +class AlexNet(nn.Cell): + def __init__(self, num_classes=10, channel=3): + super(AlexNet, self).__init__() + self.conv1 = conv(channel, 96, 11, stride=4) + self.conv2 = conv(96, 256, 5, pad_mode="same") + self.conv3 = conv(256, 384, 3, pad_mode="same") + self.conv4 = conv(384, 384, 3, 
pad_mode="same") + self.conv5 = conv(384, 256, 3, pad_mode="same") + self.relu = nn.ReLU() + self.max_pool2d = P.MaxPool(ksize=3, strides=2) + self.flatten = nn.Flatten() + self.fc1 = fc_with_initialize(6*6*256, 4096) + self.fc2 = fc_with_initialize(4096, 4096) + self.fc3 = fc_with_initialize(4096, num_classes) + + def construct(self, x): + x = self.conv1(x) + x = self.relu(x) + x = self.max_pool2d(x) + x = self.conv2(x) + x = self.relu(x) + x = self.max_pool2d(x) + x = self.conv3(x) + x = self.relu(x) + x = self.conv4(x) + x = self.relu(x) + x = self.conv5(x) + x = self.relu(x) + x = self.max_pool2d(x) + x = self.flatten(x) + x = self.fc1(x) + x = self.relu(x) + x = self.fc2(x) + x = self.relu(x) + x = self.fc3(x) + return x + context.set_context(mode=context.GRAPH_MODE) network = AlexNet(num_classes=10) diff --git a/tutorials/source_zh_cn/advanced_use/computer_vision_application.md b/tutorials/source_zh_cn/advanced_use/computer_vision_application.md index d8839c1b94..5592f05548 100644 --- a/tutorials/source_zh_cn/advanced_use/computer_vision_application.md +++ b/tutorials/source_zh_cn/advanced_use/computer_vision_application.md @@ -36,12 +36,7 @@ def classify(image): 选择合适的model是关键。这里的model一般指的是深度卷积神经网络,如AlexNet、VGG、GoogLeNet、ResNet等等。 -MindSpore预置了典型的卷积神经网络,开发者可以直接使用,如LeNet,使用方式如下: - -```python -from mindspore.model_zoo.lenet import LeNet5 -network = LeNet(num_classes) -``` +MindSpore实现了典型的卷积神经网络,开发者可以参考[model_zoo](https://gitee.com/mindspore/mindspore/tree/master/model_zoo)。 MindSpore当前支持的图像分类网络包括:典型网络LeNet、AlexNet、ResNet。 @@ -150,10 +145,9 @@ tar -zvxf cifar-10-binary.tar.gz ResNet通常是较好的选择。首先,它足够深,常见的有34层,50层,101层。通常层次越深,表征能力越强,分类准确率越高。其次,可学习,采用了残差结构,通过shortcut连接把低层直接跟高层相连,解决了反向传播过程中因为网络太深造成的梯度消失问题。此外,ResNet网络的性能很好,既表现为识别的准确率,也包括它本身模型的大小和参数量。 -MindSpore Model Zoo中已经内置了ResNet模型,可以采用ResNet-50网络,调用方法如下: +MindSpore Model Zoo中已经实现了ResNet模型,可以采用[ResNet-50](https://gitee.com/mindspore/mindspore/blob/master/model_zoo/resnet/src/resnet.py)。调用方法如下: ```python -from 
mindspore.model_zoo.resnet import resnet50 network = resnet50(class_num=10) ``` diff --git a/tutorials/source_zh_cn/advanced_use/dashboard_and_lineage.md b/tutorials/source_zh_cn/advanced_use/dashboard_and_lineage.md index 2cf578aae9..a721d2e1a8 100644 --- a/tutorials/source_zh_cn/advanced_use/dashboard_and_lineage.md +++ b/tutorials/source_zh_cn/advanced_use/dashboard_and_lineage.md @@ -54,9 +54,63 @@ import mindspore.nn as nn from mindspore import context from mindspore import Tensor from mindspore.train import Model -from mindspore.model_zoo.alexnet import AlexNet +from mindspore.common.initializer import TruncatedNormal +from mindspore.ops import operations as P from mindspore.train.callback import SummaryCollector +"""AlexNet initial.""" +def conv(in_channels, out_channels, kernel_size, stride=1, padding=0, pad_mode="valid"): + weight = weight_variable() + return nn.Conv2d(in_channels, out_channels, + kernel_size=kernel_size, stride=stride, padding=padding, + weight_init=weight, has_bias=False, pad_mode=pad_mode) + +def fc_with_initialize(input_channels, out_channels): + weight = weight_variable() + bias = weight_variable() + return nn.Dense(input_channels, out_channels, weight, bias) + +def weight_variable(): + return TruncatedNormal(0.02) # 0.02 + + +class AlexNet(nn.Cell): + def __init__(self, num_classes=10, channel=3): + super(AlexNet, self).__init__() + self.conv1 = conv(channel, 96, 11, stride=4) + self.conv2 = conv(96, 256, 5, pad_mode="same") + self.conv3 = conv(256, 384, 3, pad_mode="same") + self.conv4 = conv(384, 384, 3, pad_mode="same") + self.conv5 = conv(384, 256, 3, pad_mode="same") + self.relu = nn.ReLU() + self.max_pool2d = P.MaxPool(ksize=3, strides=2) + self.flatten = nn.Flatten() + self.fc1 = fc_with_initialize(6*6*256, 4096) + self.fc2 = fc_with_initialize(4096, 4096) + self.fc3 = fc_with_initialize(4096, num_classes) + + def construct(self, x): + x = self.conv1(x) + x = self.relu(x) + x = self.max_pool2d(x) + x = self.conv2(x) + x = 
self.relu(x) + x = self.max_pool2d(x) + x = self.conv3(x) + x = self.relu(x) + x = self.conv4(x) + x = self.relu(x) + x = self.conv5(x) + x = self.relu(x) + x = self.max_pool2d(x) + x = self.flatten(x) + x = self.fc1(x) + x = self.relu(x) + x = self.fc2(x) + x = self.relu(x) + x = self.fc3(x) + return x + context.set_context(mode=context.GRAPH_MODE) network = AlexNet(num_classes=10) -- Gitee