DBN implementation for the MNIST dataset
Let's look at how the DBN class implemented earlier is used for the MNIST dataset.
Loading the dataset
First, we load the MNIST dataset, which is stored in the idx3 (images) and idx1 (labels) formats, into training, validation, and test sets. Besides TensorFlow itself, we import the DBN model and the dataset utilities defined in the common module:
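The IDX formats are simple binary containers. Every file starts with a 4-byte magic number (two zero bytes, a data-type code, and the number of dimensions), followed by one big-endian 32-bit size per dimension and then the raw data. The following is a minimal NumPy sketch of what a loader has to do with these files. It assumes the files have already been decompressed (the copies distributed on the MNIST site are gzipped), and `parse_idx` is a hypothetical helper, not part of the common module.

```python
import struct

import numpy as np

# IDX data-type codes (third byte of the magic number) mapped to big-endian dtypes.
_IDX_DTYPES = {
    0x08: np.dtype(np.uint8),   # unsigned byte (used by MNIST)
    0x09: np.dtype(np.int8),    # signed byte
    0x0B: np.dtype(">i2"),      # short
    0x0C: np.dtype(">i4"),      # int
    0x0D: np.dtype(">f4"),      # float
    0x0E: np.dtype(">f8"),      # double
}


def parse_idx(raw: bytes) -> np.ndarray:
    """Parse an uncompressed IDX byte string (idx3 images, idx1 labels, ...)."""
    # Magic number: two zero bytes, a type code, and the number of dimensions.
    zero1, zero2, type_code, ndim = struct.unpack(">BBBB", raw[:4])
    if zero1 != 0 or zero2 != 0 or type_code not in _IDX_DTYPES:
        raise ValueError("not a valid IDX header")
    # Each dimension size is a big-endian unsigned 32-bit integer.
    dims = struct.unpack(">" + "I" * ndim, raw[4:4 + 4 * ndim])
    offset = 4 + 4 * ndim
    count = int(np.prod(dims))
    data = np.frombuffer(raw, dtype=_IDX_DTYPES[type_code], count=count, offset=offset)
    return data.reshape(dims)


# Tiny synthetic idx3 file: 2 "images" of 2x2 pixels.
images = parse_idx(struct.pack(">BBBBIII", 0, 0, 0x08, 3, 2, 2, 2) + bytes(range(8)))
# Tiny synthetic idx1 file: 3 labels.
labels = parse_idx(struct.pack(">BBBBI", 0, 0, 0x08, 1, 3) + bytes([7, 2, 1]))
print(images.shape, labels.tolist())
```

For the real training images, the header declares 3 dimensions (count, rows, columns), so the result has shape (60000, 28, 28); the helpers used in this section additionally flatten each image to 784 values.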
import tensorflow as tf
from common.models.boltzmann import dbn
from common.utils import datasets, utilities

trainX, trainY, validX, validY, testX, testY = datasets.load_mnist_dataset(mode='supervised')
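Before training, it can help to confirm that the six arrays line up. The sketch below is a hypothetical helper, not part of the common module; it assumes images are flattened to 784 pixels and labels are one-hot vectors over 10 digit classes.

```python
import numpy as np


def check_supervised_splits(trainX, trainY, validX, validY, testX, testY,
                            n_features=784, n_classes=10):
    """Sanity-check the six arrays returned in 'supervised' mode.

    Assumes flattened 28x28 images and one-hot labels; returns the
    number of examples in each split.
    """
    sizes = {}
    for name, X, Y in (("train", trainX, trainY),
                       ("validation", validX, validY),
                       ("test", testX, testY)):
        # Every image needs exactly one label row.
        if X.shape[0] != Y.shape[0]:
            raise ValueError(f"{name}: {X.shape[0]} images but {Y.shape[0]} labels")
        if X.shape[1:] != (n_features,):
            raise ValueError(f"{name}: expected {n_features} features, got {X.shape[1:]}")
        # One-hot rows must sum to 1 (a cheap necessary check, not a full proof).
        if Y.shape[1:] != (n_classes,) or not np.allclose(Y.sum(axis=1), 1.0):
            raise ValueError(f"{name}: labels are not one-hot over {n_classes} classes")
        sizes[name] = X.shape[0]
    return sizes
```

With TensorFlow 1.x's `read_data_sets` defaults (5,000 validation images carved off the 60,000 training images), calling `check_supervised_splits(trainX, trainY, validX, validY, testX, testY)` should report 55,000 training, 5,000 validation, and 10,000 test examples.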
You can find details about load_mnist_dataset() in the following code listing. Because mode='supervised' is set, the train, test, and validation labels are returned:
from tensorflow.examples.tutorials.mnist import input_data  # available in TensorFlow 1.x

def load_mnist_dataset(mode='supervised', one_hot=True):
    mnist = input_data.read_data_sets("MNIST_data/", one_hot=one_hot)

    # Training set
    trX = mnist.train.images
    trY = mnist.train.labels

    # Validation set
    vlX = mnist.validation.images
    vlY = mnist.validation.labels

    # Test set
    ...
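The listing stops before the test split and the return statement, so it does not show how `mode` changes what comes back. The sketch below is one plausible reading, not the module's actual code: `package_splits` is hypothetical, it assumes an `'unsupervised'` mode that returns images only (as unlabelled DBN pretraining would need), and it starts from integer labels to illustrate what `one_hot=True` produces. In the real function, `read_data_sets` performs the one-hot encoding itself.

```python
import numpy as np


def package_splits(splits, mode="supervised", one_hot=True, n_classes=10):
    """Hypothetical sketch of what `mode` and `one_hot` control.

    `splits` is a sequence of (images, integer_labels) pairs in the
    order train, validation, test.
    """
    out = []
    for images, labels in splits:
        labels = np.asarray(labels, dtype=np.int64)
        if one_hot:
            # Row i gets a single 1.0 in column labels[i].
            encoded = np.zeros((labels.shape[0], n_classes), dtype=np.float32)
            encoded[np.arange(labels.shape[0]), labels] = 1.0
            labels = encoded
        out.append((images, labels))
    if mode == "supervised":
        # Flattened as trX, trY, vlX, vlY, teX, teY.
        return tuple(arr for pair in out for arr in pair)
    if mode == "unsupervised":
        # Images only (assumed), e.g. for greedy layer-wise RBM pretraining.
        return tuple(images for images, _ in out)
    raise ValueError(f"unknown mode: {mode!r}")
```

In supervised mode, the six-way unpacking used earlier (`trainX, trainY, validX, validY, testX, testY = ...`) matches the order of the returned tuple.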