Reconstructing sample images
We will also reconstruct a few sample images to see how well the model performs. We will use the following images as input:
The code for reconstructing the preceding images is as follows:
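import tensorflow as tf  # TensorFlow 1.x API (tf.train.Saver, tf.Session)

def reconstruct_sample(model, n_samples=5):
    # Take the first batch of test images and labels as samples
    x_test, y_test = load_data(load_type='test')
    sample_images, sample_labels = x_test[:BATCH_SIZE], y_test[:BATCH_SIZE]
    # Restore the most recently saved weights from the checkpoint directory
    saver = tf.train.Saver()
    ckpt = tf.train.get_checkpoint_state(CHECKPOINT_PATH_DIR)
    if ckpt is None or not ckpt.model_checkpoint_path:
        raise FileNotFoundError('No checkpoint found in {}'.format(CHECKPOINT_PATH_DIR))
    with tf.Session() as sess:
        saver.restore(sess, ckpt.model_checkpoint_path)
        feed_dict_samples = {model.X: sample_images, model.Y: sample_labels}
        # Get the decoder's reconstructions and the predicted classes
        decoder_out, y_predicted = sess.run([model.decoder_output, model.y_predicted],
                                            feed_dict=feed_dict_samples)
    # Plot and save the first n_samples originals alongside their reconstructions
    reconstruction(sample_images, sample_labels, decoder_out, y_predicted, n_samples)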
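Visual inspection can be complemented by a number per sample. The sketch below is a hypothetical helper, `summarize_reconstructions`, and is not part of the book's code. It computes the mean squared error between each input image and its reconstruction and pairs it with the true and predicted class. It assumes that `decoder_out` holds one flattened reconstruction per image (for example, 784 values for a 28x28 MNIST digit) and that labels and predictions are either class ids or one-hot/score vectors.

```python
import numpy as np


def _to_class_ids(a):
    """Hypothetical helper: turn one-hot/score vectors or (batch, 1) arrays into class ids."""
    a = np.asarray(a)
    # Vectors with more than one entry per sample are treated as one-hot/scores
    if a.ndim > 1 and a.shape[-1] > 1:
        return a.argmax(axis=-1).reshape(-1)
    # Otherwise the values are already class ids; just flatten them
    return a.reshape(-1)


def summarize_reconstructions(images, labels, decoder_out, y_pred, n_samples=5):
    """Return (true_label, predicted_label, mse) for the first n_samples images.

    Assumes decoder_out contains one flattened reconstruction per image with the
    same number of pixels as the corresponding input image.
    """
    # Flatten inputs and reconstructions to (batch, pixels) so they line up
    flat_images = np.asarray(images, dtype=np.float32).reshape(len(images), -1)
    flat_recon = np.asarray(decoder_out, dtype=np.float32).reshape(len(decoder_out), -1)
    true_ids = _to_class_ids(labels)
    pred_ids = _to_class_ids(y_pred)

    rows = []
    for i in range(min(n_samples, len(flat_images))):
        # Mean squared pixel error between the input and its reconstruction
        mse = float(np.mean((flat_images[i] - flat_recon[i]) ** 2))
        rows.append((int(true_ids[i]), int(pred_ids[i]), mse))
    return rows
```

For example, it could be called inside `reconstruct_sample` right after `sess.run`, as `summarize_reconstructions(sample_images, sample_labels, decoder_out, y_predicted, n_samples)`. A low error combined with a correct prediction suggests that the capsules encode the digit well.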
The reconstruction function for plotting the images and saving them is given as follows:
def reconstruction(x, y, decoder_output, y_pred, n_samples):
    '''
    This function is used...