Weighting techniques
Let’s continue to use the imbalanced MNIST dataset from the previous chapter, which has a long-tailed data distribution, as shown in the following bar chart (Figure 8.1):
Figure 8.1 – Imbalanced MNIST dataset
Here, the x-axis shows the class labels, and the y-axis shows the number of samples in each class. In the next section, we will see how to use the weight parameter in PyTorch.
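Before looking at the model, here is a minimal sketch of the core idea: computing inverse-frequency class weights and passing them to a loss function via its weight parameter. The per-class counts below are illustrative stand-ins, not the actual imbalanced MNIST counts:

```python
import torch

# Hypothetical per-class sample counts for a long-tailed, 10-class dataset
# (illustrative values, not the real imbalanced MNIST counts).
class_counts = torch.tensor(
    [5000, 2500, 1200, 600, 300, 150, 80, 40, 20, 10], dtype=torch.float
)

# Inverse-frequency weights: rarer classes receive larger weights.
weights = 1.0 / class_counts
weights = weights / weights.sum() * len(class_counts)  # normalize to mean 1

# CrossEntropyLoss accepts a per-class weight tensor via its weight argument.
criterion = torch.nn.CrossEntropyLoss(weight=weights)

logits = torch.randn(8, 10)           # a batch of 8 model outputs
targets = torch.randint(0, 10, (8,))  # random ground-truth labels
loss = criterion(logits, targets)
```

With these weights, misclassifying a sample from a rare class contributes more to the loss than misclassifying a sample from a frequent class, which counteracts the long tail during training.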
We will use the following model code for all the vision-related tasks in this chapter. We have defined a PyTorch neural network class called Net with two convolutional layers, a dropout layer, and two fully connected layers. The forward method applies these layers sequentially, along with ReLU activations and max-pooling, to process the input, x. Finally, it returns the log_softmax activation of the output:
class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self...