In this post, you will learn the notion behind Autoencoders as well as how to implement an autoencoder in TensorFlow.
Autoencoders are neural networks trained to reproduce their input at their output. They usually consist of two parts: an encoder and a decoder. The encoder maps the input into a hidden space (the hidden layer), and the decoder then reconstructs the input from that hidden representation. Autoencoders come in several variants.
In this post, we are going to design an undercomplete autoencoder in TensorFlow, that is, one whose hidden representation has fewer dimensions than its input, so that training yields a low-dimensional representation.
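To make the encoder/decoder split concrete, here is a minimal NumPy sketch of an undercomplete autoencoder's forward pass. It is only an illustration of the shapes involved: the weights `W_enc` and `W_dec` are hypothetical, random, and untrained, and the layers are plain linear maps rather than the convolutional layers used later in this post.

```python
import numpy as np

rng = np.random.default_rng(0)

input_dim, code_dim = 1024, 32   # code_dim < input_dim makes the model undercomplete
x = rng.random((4, input_dim))   # a batch of 4 flattened inputs

# Hypothetical, untrained weights; training would adjust these
W_enc = rng.normal(scale=0.01, size=(input_dim, code_dim))
W_dec = rng.normal(scale=0.01, size=(code_dim, input_dim))

code = x @ W_enc       # encoder: input -> hidden code
x_hat = code @ W_dec   # decoder: hidden code -> reconstruction

print(code.shape)      # (4, 32)
print(x_hat.shape)     # (4, 1024)
```

The reconstruction has the same shape as the input, while the code in the middle is much smaller. That bottleneck is what forces the network to learn a compact representation.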
We will build an autoencoder with a 3-layer encoder and a 3-layer decoder. The first two encoder layers halve the spatial dimensions of their input and the third reduces them by a factor of four, so a 32 x 32 image is compressed to 2 x 2. The decoder mirrors this: its first layer upsamples by a factor of four and the next two by a factor of two each.
import tensorflow as tf
import tensorflow.contrib.layers as lays  # tf.contrib is only available in TensorFlow 1.x

def autoencoder(inputs):
    # encoder
    # 32 x 32 x 1   ->  16 x 16 x 32
    # 16 x 16 x 32  ->  8 x 8 x 16
    # 8 x 8 x 16    ->  2 x 2 x 8
    net = lays.conv2d(inputs, 32, [5, 5], stride=2, padding='SAME')
    net = lays.conv2d(net, 16, [5, 5], stride=2, padding='SAME')
    net = lays.conv2d(net, 8, [5, 5], stride=4, padding='SAME')
    # decoder
    # 2 x 2 x 8     ->  8 x 8 x 16
    # 8 x 8 x 16    ->  16 x 16 x 32
    # 16 x 16 x 32  ->  32 x 32 x 1
    net = lays.conv2d_transpose(net, 16, [5, 5], stride=4, padding='SAME')
    net = lays.conv2d_transpose(net, 32, [5, 5], stride=2, padding='SAME')
    net = lays.conv2d_transpose(net, 1, [5, 5], stride=2, padding='SAME', activation_fn=tf.nn.tanh)
    return net
Figure 1: Autoencoder
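The shape comments in the function can be checked with a little arithmetic. With 'SAME' padding, a strided convolution produces a spatial size of ceil(input / stride), and a strided transposed convolution produces input * stride. The sketch below applies those two rules to the strides used above; the helper names `conv_same_out` and `conv_transpose_same_out` are made up for this illustration and are not TensorFlow functions.

```python
import math

def conv_same_out(size, stride):
    """Spatial size after a strided convolution with 'SAME' padding."""
    return math.ceil(size / stride)

def conv_transpose_same_out(size, stride):
    """Spatial size after a strided transposed convolution with 'SAME' padding."""
    return size * stride

size = 32
encoder_sizes = [size]
for stride in (2, 2, 4):          # encoder strides
    size = conv_same_out(size, stride)
    encoder_sizes.append(size)

decoder_sizes = [size]
for stride in (4, 2, 2):          # decoder strides
    size = conv_transpose_same_out(size, stride)
    decoder_sizes.append(size)

print(encoder_sizes)  # [32, 16, 8, 2]
print(decoder_sizes)  # [2, 8, 16, 32]

# Number of values in the input versus the 2 x 2 x 8 code
input_values = 32 * 32 * 1
code_values = 2 * 2 * 8
print(input_values / code_values)  # 32.0
```

So the bottleneck holds 32 values for every 1024 input values, which is what makes this autoencoder undercomplete.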
The MNIST dataset stores each 28 x 28 image as a flattened vector of 784 values. We therefore define a function that reshapes each batch of MNIST images to 28 x 28 and then resizes it to 32 x 32. Resizing to 32 x 32 makes the side length a power of two, so the strided downsampling and upsampling steps divide evenly.
import numpy as np
from skimage import transform

def resize_batch(imgs):
    """Resize a batch of MNIST images to (32, 32).

    Args:
        imgs: a numpy array of size [batch_size, 28 x 28].
    Returns:
        a numpy array of size [batch_size, 32, 32, 1].
    """
    imgs = imgs.reshape((-1, 28, 28, 1))
    resized_imgs = np.zeros((imgs.shape[0], 32, 32, 1))
    for i in range(imgs.shape[0]):
        # resize one image at a time, keeping the single channel axis
        resized_imgs[i, ..., 0] = transform.resize(imgs[i, ..., 0], (32, 32))
    return resized_imgs
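If you would rather avoid the scikit-image dependency, one possible alternative is to zero-pad each image by 2 pixels on every side instead of interpolating. This is a different technique from the resizing above (the digits keep their original scale and gain a black border), and `pad_batch` is a hypothetical helper written for this sketch, not part of the original code.

```python
import numpy as np

def pad_batch(imgs):
    """Zero-pad a batch of flattened 28 x 28 images to [batch_size, 32, 32, 1]."""
    imgs = imgs.reshape((-1, 28, 28, 1))
    # pad 2 pixels on each side of the height and width axes only
    return np.pad(imgs, ((0, 0), (2, 2), (2, 2), (0, 0)), mode='constant')

batch = np.random.rand(5, 784).astype(np.float32)  # stand-in for a batch of MNIST vectors
padded = pad_batch(batch)
print(padded.shape)  # (5, 32, 32, 1)
```

Either function returns an array that matches the network's input shape of (None, 32, 32, 1).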
Now we create the autoencoder and define a mean squared error loss and an optimizer.
import tensorflow as tf

lr = 0.001  # learning rate (must be defined before the optimizer is created)

ae_inputs = tf.placeholder(tf.float32, (None, 32, 32, 1))  # input to the network (MNIST images)
ae_outputs = autoencoder(ae_inputs)  # create the autoencoder network

# calculate the loss and optimize the network
loss = tf.reduce_mean(tf.square(ae_outputs - ae_inputs))  # calculate the mean squared error loss
train_op = tf.train.AdamOptimizer(learning_rate=lr).minimize(loss)

# initialize the network
init = tf.global_variables_initializer()
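The loss above is the mean, over every pixel of every image in the batch, of the squared difference between output and input. As a sanity check, the same quantity can be written in plain NumPy. The function name `mse` and the tiny arrays below are made up for illustration.

```python
import numpy as np

def mse(outputs, inputs):
    """Mean squared error over all elements, mirroring tf.reduce_mean(tf.square(...))."""
    return np.mean(np.square(outputs - inputs))

inputs = np.array([[0.0, 1.0], [0.5, 0.5]])
outputs = np.array([[0.0, 0.0], [0.5, 1.0]])

# squared differences: 0, 1, 0, 0.25 -> mean = 1.25 / 4
print(mse(outputs, inputs))  # 0.3125
```

A perfect reconstruction gives a loss of zero, and larger pixel errors are penalized quadratically.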
Now we can read the batches, train the network and finally test the network by reconstructing a batch of test images.
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data

batch_size = 500  # Number of samples in each batch
epoch_num = 5     # Number of epochs to train the network
lr = 0.001        # Learning rate (must be defined before the optimizer is created)

# read MNIST dataset
mnist = input_data.read_data_sets("MNIST_data", one_hot=True)

# calculate the number of batches per epoch
batch_per_ep = mnist.train.num_examples // batch_size

with tf.Session() as sess:
    sess.run(init)
    for ep in range(epoch_num):  # epochs loop
        for batch_n in range(batch_per_ep):  # batches loop
            batch_img, batch_label = mnist.train.next_batch(batch_size)  # read a batch
            batch_img = batch_img.reshape((-1, 28, 28, 1))  # reshape each sample to a (28, 28) image
            batch_img = resize_batch(batch_img)             # resize the images to (32, 32)
            _, c = sess.run([train_op, loss], feed_dict={ae_inputs: batch_img})
        # report the loss of the last batch in this epoch
        print('Epoch: {} - cost= {:.5f}'.format((ep + 1), c))

    # test the trained network
    batch_img, batch_label = mnist.test.next_batch(50)
    batch_img = resize_batch(batch_img)
    recon_img = sess.run([ae_outputs], feed_dict={ae_inputs: batch_img})[0]

    # plot the reconstructed images and their ground truths (inputs)
    plt.figure(1)
    plt.suptitle('Reconstructed Images')
    for i in range(50):
        plt.subplot(5, 10, i+1)
        plt.imshow(recon_img[i, ..., 0], cmap='gray')
    plt.figure(2)
    plt.suptitle('Input Images')
    for i in range(50):
        plt.subplot(5, 10, i+1)
        plt.imshow(batch_img[i, ..., 0], cmap='gray')
    plt.show()
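Besides comparing the two figures by eye, you may want a number for how well each test image was reconstructed. The sketch below computes a per-image mean squared error in NumPy. The helper `per_image_mse` is hypothetical, and the arrays are stand-in data; in the session above you would pass `batch_img` and `recon_img` instead.

```python
import numpy as np

def per_image_mse(originals, reconstructions):
    """Mean squared error of each image; both arrays have shape [batch, H, W, C]."""
    diff = reconstructions - originals
    return np.mean(np.square(diff), axis=(1, 2, 3))

# Stand-in data: three blank images, with the second reconstruction off by 0.5 everywhere
originals = np.zeros((3, 32, 32, 1))
reconstructions = np.zeros((3, 32, 32, 1))
reconstructions[1] += 0.5

errors = per_image_mse(originals, reconstructions)
print(errors.tolist())  # [0.0, 0.25, 0.0]
```

Sorting the test images by this error is one simple way to find the digits the autoencoder reconstructs worst.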