Training a multiclass classification neural network
In this recipe, we will look at another very common task: multiclass classification with neural networks, here using PyTorch. We will work on an iconic deep learning dataset: MNIST handwritten digit recognition. It consists of small 28x28-pixel grayscale images depicting handwritten digits from 0 to 9, and thus has 10 classes.
Getting ready
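In classical machine learning, multiclass classification is not always handled natively: many models are binary by design and must be extended to cover more classes. For example, logistic regression can be applied to a three-class task (e.g., the Iris dataset) with the one-versus-the-rest method, which trains three binary models, one per class. Older versions of scikit-learn applied this method automatically, while recent versions fit a single multinomial model by default.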
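The one-versus-the-rest method mentioned above can be made explicit with scikit-learn's `OneVsRestClassifier`. The following is a minimal sketch, assuming scikit-learn is installed; the variable names (`ovr`, `n_models`, `predictions`) and the `max_iter` value are illustrative choices, not part of the recipe.

```python
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier

# Iris: 150 samples, 4 features, 3 classes (bundled with scikit-learn)
X, y = load_iris(return_X_y=True)

# Explicitly wrap logistic regression in a one-versus-the-rest scheme.
# max_iter=1000 is an assumption, chosen to give the solver room to converge.
ovr = OneVsRestClassifier(LogisticRegression(max_iter=1000))
ovr.fit(X, y)

# One fitted binary model per class
n_models = len(ovr.estimators_)
print(n_models)

# For each sample, the predicted class is the one whose
# binary model gives the highest score
predictions = ovr.predict(X[:5])
print(predictions)
```

Each entry of `ovr.estimators_` is a separate binary classifier that separates one class from all the others, which is why the number of fitted models equals the number of classes.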
In deep learning, a single model can natively handle more than two classes. To do so, only a few changes are required compared to binary classification:
- The output layer has as many units as classes: this way, each unit will be responsible...