Understanding EfficientNets and the future of CNN architectures
So far in our exploration from LeNet to DenseNet, we have noticed an underlying theme in the advancement of CNN architectures. That theme is the expansion or scaling of the CNN model through one of the following:
- An increase in the number of layers
- An increase in the number of feature maps or channels in a convolutional layer
- An increase in the spatial dimension of the input, going from 32x32 pixel images in LeNet to 224x224 pixel images in AlexNet, and so on
These three aspects along which scaling can be performed are referred to as depth, width, and resolution, respectively. Instead of scaling these attributes manually, which often leads to suboptimal results, EfficientNets use neural architecture search to find the optimal scaling factors for each of them.
Scaling up depth is deemed important because the deeper the network, the more capacity it has to learn highly complex features. However, there is a trade-off: with increasing depth, the vanishing gradient problem escalates, along with the general problem of overfitting.
Similarly, scaling up width should theoretically help, as with a greater number of channels, the network should learn more fine-grained features. However, for extremely wide models, the accuracy tends to saturate quickly.
Finally, higher-resolution images should, in theory, work better, as they contain more fine-grained information. Empirically, however, the increase in resolution does not yield a linearly equivalent increase in model performance. All of this is to say that there are trade-offs to be made when deciding the scaling factors, which is why neural architecture search helps in finding the optimal ones.
EfficientNet proposes finding the architecture that has the right balance between depth, width, and resolution, with all three of these aspects scaled together using a global scaling factor. The EfficientNet architecture is built in two steps. First, a basic architecture (called the base network) is devised by fixing the scaling factor to 1. At this stage, the relative importance of depth, width, and resolution is decided for the given task and dataset. The base network thus obtained is quite similar to a well-known CNN architecture – MnasNet, short for Mobile Neural Architecture Search Network.
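To make the role of the global scaling factor concrete, here is a minimal sketch of the compound scaling rule described in the EfficientNet paper; the coefficients for depth, width, and resolution (1.2, 1.1, and 1.15) are the values reported there, and compound_scale is a hypothetical helper written purely for illustration:
# Compound scaling: a single global factor phi scales depth, width,
# and resolution together (coefficients below are from the EfficientNet
# paper; they are chosen so that each unit increase in phi roughly
# doubles the computational cost)
alpha, beta, gamma = 1.2, 1.1, 1.15

def compound_scale(phi):
    depth_mult = alpha ** phi   # multiplier for the number of layers
    width_mult = beta ** phi    # multiplier for the number of channels
    res_mult = gamma ** phi     # multiplier for the input resolution
    return depth_mult, width_mult, res_mult

# phi = 0 recovers the base network; larger values of phi
# yield the scaled-up variants
print(compound_scale(2))  # (1.44, 1.21, 1.3225)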
PyTorch offers the pretrained MnasNet model, which can be loaded as shown here:
import torchvision.models as models

# Load MnasNet (depth multiplier 1.0) with pretrained ImageNet weights
model = models.mnasnet1_0(pretrained=True)
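As a quick sanity check, the loaded model can be run on a dummy input; the shape below assumes the standard 224x224 ImageNet input size used by the pretrained weights:
import torch

model.eval()  # switch to inference mode
dummy_input = torch.rand(1, 3, 224, 224)  # a batch of one 224x224 RGB image
with torch.no_grad():
    output = model(dummy_input)
print(output.shape)  # torch.Size([1, 1000]) - one score per ImageNet class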
Once the base network is obtained in the first step, the optimal global scaling factor is then computed with the aim of maximizing the model's accuracy and minimizing the number of computations (or FLOPs). The base network is called EfficientNet B0, and the subsequent networks derived for different optimal scaling factors are called EfficientNet B1-B7. PyTorch provides pretrained models for all of these variants:
import torchvision.models as models

# Load the pretrained EfficientNet variants, from B0 (smallest) to B7 (largest)
efficientnet_b0 = models.efficientnet_b0(pretrained=True)
efficientnet_b1 = models.efficientnet_b1(pretrained=True)
...
efficientnet_b7 = models.efficientnet_b7(pretrained=True)
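One way to see the effect of the growing scaling factor is to compare parameter counts across the variants; a small sketch using the models loaded above (B0 has roughly 5 million parameters, while B7 has over 60 million):
# Count the learnable parameters of the smallest and largest variants
for name, net in [("B0", efficientnet_b0), ("B7", efficientnet_b7)]:
    num_params = sum(p.numel() for p in net.parameters())
    print(f"EfficientNet {name}: {num_params / 1e6:.1f}M parameters")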
As we go forward, the efficient scaling of CNN architectures is going to be a prominent direction of research, along with the development of more sophisticated modules inspired by the inception, residual, and dense modules. Another aspect of CNN architecture development is minimizing the model size while retaining performance. MobileNets [9] are a prime example, and there is a lot of ongoing research on this front.
Besides the top-down approach of looking at architectural modifications of a pre-existing model, there will be continued efforts to adopt the bottom-up view of fundamentally rethinking the units of CNNs, such as the convolutional kernels, the pooling mechanism, more effective ways of flattening, and so on. One concrete example of this is CapsuleNet [10], which revamped the convolutional units to cater to the third dimension (depth) in images.
CNNs are a huge topic of study in themselves. In this chapter, we have touched upon the architectural development of CNNs, mostly in the context of image classification. However, these same architectures are used across a wide variety of applications. One well-known example is the use of ResNets for object detection and segmentation in the form of R-CNNs [11].
Some of the improved variants of the R-CNN are Faster R-CNN, Mask R-CNN, and Keypoint R-CNN. PyTorch provides pretrained models for all three variants:
import torchvision.models as models

# Load the pretrained detection models (ResNet-50 backbone with FPN)
faster_rcnn = models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
mask_rcnn = models.detection.maskrcnn_resnet50_fpn(pretrained=True)
keypoint_rcnn = models.detection.keypointrcnn_resnet50_fpn(pretrained=True)
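In inference mode, torchvision's detection models take a list of image tensors and return one dictionary of predictions per image; a minimal sketch using the Faster R-CNN model loaded above:
import torch

faster_rcnn.eval()  # switch to inference mode
images = [torch.rand(3, 300, 400)]  # a list of (channels, height, width) tensors
with torch.no_grad():
    predictions = faster_rcnn(images)
# each prediction is a dict with 'boxes', 'labels', and 'scores' keys
print(predictions[0]["boxes"].shape)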
PyTorch also provides pretrained models for ResNets that are applied to video-related tasks such as video classification. Two such ResNet-based models used for video classification are ResNet3D and ResNet Mixed Convolution:
# Load pretrained video classification models (Kinetics-400 weights)
resnet_3d = models.video.r3d_18(pretrained=True)
resnet_mixed_conv = models.video.mc3_18(pretrained=True)
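These video models expect a five-dimensional input of shape (batch, channels, frames, height, width); a minimal sketch, assuming a single 16-frame clip of 112x112 pixels:
import torch

resnet_3d.eval()  # switch to inference mode
video_clip = torch.rand(1, 3, 16, 112, 112)  # (batch, channels, frames, height, width)
with torch.no_grad():
    output = resnet_3d(video_clip)
print(output.shape)  # torch.Size([1, 400]) - one score per Kinetics-400 class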
While we do not extensively cover these different applications and corresponding CNN models in this chapter, we encourage you to read more on them. PyTorch’s website can be a good starting point [12].