Project two – classifying MNIST handwritten digits
For this classification project, we are going to classify handwritten digits from the MNIST dataset. In the previous section, we covered the four essential steps for machine learning in PyTorch in detail; we will repeat those steps in this section.
You will recall that in Chapter 12 you learned how to load available datasets from the torchvision module. First, we are going to load the MNIST dataset using the torchvision module.
- The setup step includes loading the dataset and specifying hyperparameters (the size of the train set and test set, and the size of mini-batches):
>>> import torchvision
>>> from torchvision import transforms
>>> image_path = './'
>>> transform = transforms.Compose([
...     transforms.ToTensor()
... ])
>>> mnist_train_dataset = torchvision.datasets.MNIST(
...     root=image_path, train=True,
...     transform=transform, download...
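To make the mini-batch hyperparameter from the setup step concrete, here is a minimal sketch of how a DataLoader splits a dataset into mini-batches. It is an illustration under stated assumptions, not the book's listing: the batch size of 64, the dataset size of 200, and the names images, labels, and loader are hypothetical, and a small synthetic tensor dataset stands in for MNIST so that nothing needs to be downloaded. The synthetic images use the 1x28x28 shape and floating-point pixel range that transforms.ToTensor() produces for MNIST images.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(1)

# Hypothetical hyperparameter chosen for illustration only.
batch_size = 64

# Synthetic stand-in for MNIST (an assumption, not the real data):
# 200 grayscale "images" of shape 1x28x28 with float pixel values,
# plus one integer class label (0-9) per image.
num_examples = 200
images = torch.rand(num_examples, 1, 28, 28)
labels = torch.randint(0, 10, (num_examples,))
dataset = TensorDataset(images, labels)

# shuffle=True reorders the examples at the start of every epoch.
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# One pass over the loader is one epoch; collect the size of each batch.
batch_sizes = [x.shape[0] for x, _ in loader]

# Inspect a single mini-batch.
first_images, first_labels = next(iter(loader))
print(len(loader), batch_sizes)
print(tuple(first_images.shape), first_labels.dtype)
```

Because 200 is not a multiple of 64, the last mini-batch is smaller than the others. Passing drop_last=True to DataLoader would discard that incomplete batch instead.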