From e3884091cbb73701e4ba9cba9334258e217f19a9 Mon Sep 17 00:00:00 2001 From: wukesong Date: Wed, 1 Apr 2020 19:33:14 +0800 Subject: [PATCH] update lenet alexnet --- tutorials/source_zh_cn/quick_start/quick_start.md | 6 ++---- tutorials/tutorial_code/lenet.py | 5 ++--- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/tutorials/source_zh_cn/quick_start/quick_start.md b/tutorials/source_zh_cn/quick_start/quick_start.md index c00b41ef18..96bbb5b396 100644 --- a/tutorials/source_zh_cn/quick_start/quick_start.md +++ b/tutorials/source_zh_cn/quick_start/quick_start.md @@ -228,8 +228,6 @@ def fc_with_initialize(input_channels, out_channels): 神经网络的各层需要预先在`__init__()`方法中定义,然后通过定义`construct()`方法来完成神经网络的前向构造。按照LeNet的网络结构,定义网络各层如下: ```python -import mindspore.ops.operations as P - class LeNet5(nn.Cell): """ Lenet network structure @@ -245,7 +243,7 @@ class LeNet5(nn.Cell): self.fc3 = fc_with_initialize(84, 10) self.relu = nn.ReLU() self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) - self.reshape = P.Reshape() + self.flatten = nn.Flatten() #use the preceding operators to construct networks def construct(self, x): @@ -255,7 +253,7 @@ class LeNet5(nn.Cell): x = self.conv2(x) x = self.relu(x) x = self.max_pool2d(x) - x = self.reshape(x, (self.batch_size, -1)) + x = self.flatten(x) x = self.fc1(x) x = self.relu(x) x = self.fc2(x) diff --git a/tutorials/tutorial_code/lenet.py b/tutorials/tutorial_code/lenet.py index f3899de146..a9b4571ffb 100644 --- a/tutorials/tutorial_code/lenet.py +++ b/tutorials/tutorial_code/lenet.py @@ -24,7 +24,6 @@ from mindspore import context from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor from mindspore.train import Model -import mindspore.ops.operations as P from mindspore.common.initializer import TruncatedNormal import mindspore.dataset.transforms.vision.c_transforms as CV import mindspore.dataset.transforms.c_transforms as C @@ -150,7 +149,7 @@ class LeNet5(nn.Cell): self.fc3 = fc_with_initialize(84, 10) self.relu = nn.ReLU() self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) - self.reshape = P.Reshape() + self.flatten = nn.Flatten() # use the preceding operators to construct networks def construct(self, x): @@ -160,7 +159,7 @@ class LeNet5(nn.Cell): x = self.conv2(x) x = self.relu(x) x = self.max_pool2d(x) - x = self.reshape(x, (self.batch_size, -1)) + x = self.flatten(x) x = self.fc1(x) x = self.relu(x) x = self.fc2(x) -- Gitee