Explicit loss function modification
In PyTorch, we can formulate a custom loss function by subclassing nn.Module and overriding its forward() method. For a loss function, forward() accepts the predicted and actual outputs as inputs and returns the computed loss value.
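As a minimal sketch of this pattern, the class below implements mean squared error by hand; the class name CustomMSELoss is our own choice for illustration, not a PyTorch built-in:

```python
import torch
import torch.nn as nn

class CustomMSELoss(nn.Module):
    """A custom loss: subclass nn.Module and override forward()."""

    def forward(self, predicted, actual):
        # predicted and actual are tensors of the same shape;
        # forward() returns a scalar loss tensor.
        return torch.mean((predicted - actual) ** 2)

loss_fn = CustomMSELoss()
pred = torch.tensor([1.0, 2.0, 3.0])
target = torch.tensor([1.0, 2.0, 5.0])
loss = loss_fn(pred, target)  # calling the module invokes forward()
```

Because the computation uses differentiable tensor operations, autograd can backpropagate through the custom loss just as it does through the built-in ones.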
Class weighting assigns different weights to balance the majority and minority class examples, but this alone is often insufficient, especially under extreme class imbalance. What we would also like is to reduce the loss contribution of easily classified examples: such examples usually belong to the majority class, and because they are so numerous, they dominate the training loss. This is the main idea behind focal loss, and it allows for a more nuanced handling of examples irrespective of the class they belong to. We'll look at it in this section.
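The idea can be sketched as a custom module in the same subclassing style. The version below is a binary focal loss in the spirit of Lin et al.: the factor (1 - p_t)**gamma down-weights easy examples (where p_t, the probability assigned to the true class, is near 1), while alpha provides class balancing. The class name and default hyperparameters are illustrative assumptions, not a library API:

```python
import torch
import torch.nn as nn

class BinaryFocalLoss(nn.Module):
    """Sketch of binary focal loss; alpha balances classes,
    gamma down-weights easily classified examples."""

    def __init__(self, alpha=0.25, gamma=2.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, logits, targets):
        # Per-example binary cross-entropy, kept unreduced so each
        # example can be modulated individually.
        bce = nn.functional.binary_cross_entropy_with_logits(
            logits, targets, reduction="none")
        p_t = torch.exp(-bce)  # probability of the true class
        alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
        # (1 - p_t)**gamma shrinks toward 0 for easy examples (p_t near 1).
        return (alpha_t * (1 - p_t) ** self.gamma * bce).mean()
```

With gamma = 0 and alpha = 0.5, this reduces (up to the constant factor 0.5) to ordinary class-balanced cross-entropy; increasing gamma shifts the training signal toward the hard, misclassified examples.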
Understanding the forward() method in PyTorch
In PyTorch, you’...