Wavelet Pooling CNN
Introduction
Convolutional Neural Networks (CNNs) generally outperform vector-based deep learning techniques. As such, the fundamental building blocks of CNNs are good candidates for innovation aimed at improving that performance. The pooling layer is one of these building blocks. Various pooling methods exist, ranging from simple deterministic ones (max pooling and average pooling) to probabilistic ones (mixed pooling and stochastic pooling), but all of them take a neighborhood approach to sub-sampling which, albeit fast and simple, can produce artifacts such as blurring, aliasing, and edge halos (Parker et al., 1983).
This paper introduces a novel pooling method based on the discrete wavelet transform. Specifically, it uses a second-level wavelet decomposition for the sub-sampling. Instead of nearest neighbor interpolation, the method uses a sub-band approach that the authors claim produces fewer artifacts and represents the underlying features more accurately. Therefore, if pooling is viewed as a lossy process, the reason for employing a wavelet approach is to try to minimize this loss.
Pooling Background
Pooling essentially means sub-sampling. After the pooling layer, the spatial dimensions of the data are reduced to some degree, with the goal being to compress the data rather than discard some of it. Typical approaches to pooling reduce the dimensionality by combining a region of values into one. For max pooling, this can be represented by the equation $a_{kij} = \max_{(p,q) \in R_{ij}} a_{kpq}$, where $a_{kij}$ is the output activation of the $k$th feature map at $(i,j)$, $a_{kpq}$ is the input activation at $(p,q)$ within $R_{ij}$, and $|R_{ij}|$ is the size of the pooling region $R_{ij}$. Mean pooling can be represented by the equation $a_{kij} = \frac{1}{|R_{ij}|} \sum_{(p,q) \in R_{ij}} a_{kpq}$ with everything defined as before. Figure 1 provides a numerical example that can be followed.
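The two equations above can be sketched in a few lines of NumPy. This is a minimal illustration, not the paper's implementation: the function name pool2d, the non-overlapping regions, and the 4x4 example values are all assumptions made for the example.

```python
import numpy as np

def pool2d(x, size=2, mode="max"):
    """Non-overlapping pooling over size x size regions of a 2D feature map."""
    h, w = x.shape
    if h % size or w % size:
        raise ValueError("input dimensions must be divisible by the pooling size")
    # Reshape so that every pooling region R_ij gets its own pair of axes (1 and 3).
    blocks = x.reshape(h // size, size, w // size, size)
    if mode == "max":
        return blocks.max(axis=(1, 3))    # a_kij = max over R_ij
    if mode == "mean":
        return blocks.mean(axis=(1, 3))   # a_kij = sum over R_ij divided by |R_ij|
    raise ValueError(f"unknown mode: {mode}")

# Hypothetical 4x4 feature map, chosen only for illustration.
feature_map = np.array([[1, 3, 2, 4],
                        [5, 7, 6, 8],
                        [9, 2, 1, 0],
                        [3, 4, 5, 6]], dtype=float)

max_pooled = pool2d(feature_map, mode="max")    # [[7, 8], [9, 6]]
mean_pooled = pool2d(feature_map, mode="mean")  # [[4, 5], [4.5, 3]]
```

Both results are 2x2: each spatial dimension is halved, and every output value summarizes one 2x2 region of the input.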
The paper mentions that these pooling methods, although simple and effective, have shortcomings. Max pooling can omit details from an image if the important features have lower intensity than the insignificant ones, and it also commonly overfits. Average pooling, on the other hand, can dilute important features if they are averaged with values of significantly lower intensity. Figure 2 illustrates both effects.
Wavelet Background
Data or signals tend to be composed of slowly changing trends (low frequency) as well as fast changing transients (high frequency). Similarly, images have smooth regions of intensity which are perturbed by edges or abrupt changes. We know that these abrupt changes can represent features that are of great importance to us when we perform deep learning. Wavelets are a class of functions that are well localized in both time and frequency. Compare this to the Fourier transform, which represents signals as a sum of sine waves that oscillate forever (not localized in time or space). This localization is what makes wavelets well suited to detecting the abrupt changes in an image.
Essentially, a wavelet is a fast-decaying, oscillating signal with zero mean that only exists for a fixed duration and can be scaled and shifted in time. There are some well-defined types of wavelets, as shown in Figure 3. The key characteristic of wavelets for us is that they have a band-pass characteristic, and the band can be adjusted through the scaling, while shifting moves the wavelet in time.
The paper uses the discrete wavelet transform, and more specifically a faster variant called the Fast Wavelet Transform (FWT), with the Haar wavelet. A continuous wavelet transform also exists. The main difference between the two is how the scale and shift parameters are selected.
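To make the discrete case concrete, the sketch below applies one level of the Haar transform to a 1D signal. In the discrete setting the scale changes by a factor of two per level and the shift moves in steps of two samples, which is why the code simply works on pairs of samples. The function names, the orthonormal 1/sqrt(2) normalization, and the sign convention of the detail coefficients are assumptions made for this example.

```python
import numpy as np

def haar_step(signal):
    """One level of the orthonormal Haar transform of an even-length 1D signal."""
    s = np.asarray(signal, dtype=float)
    if s.size % 2:
        raise ValueError("signal length must be even")
    even, odd = s[0::2], s[1::2]
    approx = (even + odd) / np.sqrt(2.0)  # low-pass: slowly changing trend
    detail = (even - odd) / np.sqrt(2.0)  # high-pass: abrupt local change
    return approx, detail

def haar_step_inverse(approx, detail):
    """Rebuild the signal from its approximation and detail coefficients."""
    out = np.empty(2 * approx.size)
    out[0::2] = (approx + detail) / np.sqrt(2.0)
    out[1::2] = (approx - detail) / np.sqrt(2.0)
    return out

# A flat signal with a single abrupt change inside the third pair of samples.
signal = np.array([4.0, 4.0, 4.0, 4.0, 10.0, 2.0, 4.0, 4.0])
approx, detail = haar_step(signal)
restored = haar_step_inverse(approx, detail)
```

The detail coefficients are zero wherever a pair of samples is constant and nonzero only at the pair (10, 2), which is the localization property described above. The inverse step recovers the original signal exactly.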
Fast Wavelet Transform
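As a rough sketch, the FWT is usually written in filter-bank form. The notation below follows the standard textbook formulation and is an assumption, not a transcription of the paper's equations: $\varphi$ and $\psi$ are the scaling and wavelet functions, $W_\varphi$ and $W_\psi$ the approximation and detail coefficients, $h_\varphi$ and $h_\psi$ the scaling and wavelet vectors, $j+1$ the next (coarser) level, and the superscript "up" marks coefficients upsampled by two with zeros inserted between samples.

```latex
\begin{align}
W_\varphi[j+1, k] &= \left. h_\varphi[-n] * W_\varphi[j, n] \right|_{n = 2k,\; k \ge 0} \\
W_\psi[j+1, k]    &= \left. h_\psi[-n] * W_\varphi[j, n] \right|_{n = 2k,\; k \ge 0} \\
W_\varphi[j, k]   &= \left. h_\varphi[k] * W_\varphi^{\mathrm{up}}[j+1, k]
                     + h_\psi[k] * W_\psi^{\mathrm{up}}[j+1, k] \right|_{k \ge 0}
\end{align}
```

The first two lines are one analysis step: filter the current approximation coefficients with the time-reversed vectors and keep every second sample. The third line is the matching synthesis step used by the inverse transform.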
Proposed Method
The proposed method uses the subbands from the second-level FWT and discards the first-level subbands. The authors postulate that this method is more 'organic' in capturing the data compression and will create fewer artifacts that may affect the image classification.
Forward Propagation
FWT in two dimensions can be expressed by [Insert equation] and [Insert Equation] where $\varphi$ is the approximation function, $\psi$ is the detail function, $W_\varphi$ and $W_\psi$ are the approximation and detail coefficients, $h_\varphi[-n]$ and $h_\psi[-n]$ are the time-reversed scaling and wavelet vectors, $n$ represents the sample in the vector, and $j$ denotes the resolution level. To apply the transform to images, FWT is first applied on the rows and then on the columns. If a low (L) and a high (H) subband are extracted from the rows, and similarly from the columns, then at each level there are four subbands (LH, HL, HH, and LL), of which LL is further decomposed to give the level 2 decomposition.
Using the level 2 decomposition subbands, the Inverse Fast Wavelet Transform (IFWT) is applied to obtain the resulting image, which is subsampled by a factor of two. The equation for IFWT is [] where the parameters are the same as previously explained. Figure 4 displays the algorithm for the forward propagation.
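The rows-then-columns procedure can be sketched with the Haar wavelet as follows. The function names and the orthonormal normalization are assumptions for this example, as is the subband naming: here the first letter refers to the filter applied along the rows and the second to the filter applied along the columns, and other texts swap LH and HL.

```python
import numpy as np

SQRT2 = np.sqrt(2.0)

def haar_rows(x):
    """Split every row into a low-pass half and a high-pass half (orthonormal Haar)."""
    even, odd = x[:, 0::2], x[:, 1::2]
    return (even + odd) / SQRT2, (even - odd) / SQRT2

def haar2d(x):
    """One level of the 2D Haar FWT: filter the rows first, then the columns."""
    x = np.asarray(x, dtype=float)
    low, high = haar_rows(x)        # row pass: L and H, each with half the columns
    ll, lh = haar_rows(low.T)       # column pass on L (transpose turns columns into rows)
    hl, hh = haar_rows(high.T)      # column pass on H
    return ll.T, lh.T, hl.T, hh.T

# Hypothetical 8x8 input, used only to show the shapes at each level.
image = np.arange(64, dtype=float).reshape(8, 8)
ll1, lh1, hl1, hh1 = haar2d(image)  # four 4x4 level-1 subbands
ll2, lh2, hl2, hh2 = haar2d(ll1)    # four 2x2 level-2 subbands, from LL only
```

Each level halves both spatial dimensions, and only the LL subband is carried forward to the next level.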
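The forward pass described above can be sketched end to end for a single 2D feature map. This is a minimal illustration under stated assumptions, not the authors' code: the helper names are hypothetical, the Haar filters use the orthonormal normalization, and the input height and width are assumed to be divisible by four.

```python
import numpy as np

SQRT2 = np.sqrt(2.0)

def _split(x):
    """Orthonormal Haar analysis along the last axis."""
    even, odd = x[..., 0::2], x[..., 1::2]
    return (even + odd) / SQRT2, (even - odd) / SQRT2

def _merge(low, high):
    """Orthonormal Haar synthesis along the last axis (inverse of _split)."""
    out = np.empty(low.shape[:-1] + (2 * low.shape[-1],))
    out[..., 0::2] = (low + high) / SQRT2
    out[..., 1::2] = (low - high) / SQRT2
    return out

def haar2d(x):
    """One-level 2D Haar FWT: rows first, then columns."""
    low, high = _split(x)
    ll, lh = _split(low.T)
    hl, hh = _split(high.T)
    return ll.T, lh.T, hl.T, hh.T

def ihaar2d(ll, lh, hl, hh):
    """One-level 2D inverse Haar FWT: undo the column pass, then the row pass."""
    low = _merge(ll.T, lh.T).T
    high = _merge(hl.T, hh.T).T
    return _merge(low, high)

def wavelet_pool(x):
    """Pool a 2D feature map by a factor of two using level-2 Haar subbands."""
    x = np.asarray(x, dtype=float)
    if x.shape[0] % 4 or x.shape[1] % 4:
        raise ValueError("height and width must be divisible by 4")
    ll1, _lh1, _hl1, _hh1 = haar2d(x)  # level 1: detail subbands are discarded
    level2 = haar2d(ll1)               # level 2: LL2, LH2, HL2, HH2
    return ihaar2d(*level2)            # IFWT from the level-2 subbands only

image = np.arange(64, dtype=float).reshape(8, 8)
pooled = wavelet_pool(image)           # shape (4, 4)
```

In this sketch the output has half the height and width of the input. Because the Haar transform reconstructs perfectly, applying the IFWT to all four level-2 subbands returns the level-1 LL subband, so with this normalization each output value is the sum of a 2x2 input block divided by two.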
[Insert Figure 4]
Back Propagation
Back propagation is simply the reverse of the forward propagation. The FWT of the backpropagated image is computed and upsampled so that it can be used as the level 2 decomposition. The IFWT is then performed to obtain an image of the original size, which amounts to upsampling by a factor of two using wavelet methods. Figure 5 displays the algorithm.
[Insert Figure 5]
Results
The authors tested on MNIST, CIFAR-10, SVHN, and KDEF, and the paper provides comprehensive results for each. Stochastic gradient descent was used for training, and the Haar wavelet was chosen for its even, square subbands. The network for all datasets except MNIST is loosely based on (Zeiler & Fergus, 2013). The authors keep the network consistent for each dataset and change only the pooling method. They also experiment with dropout and Batch Normalization to examine the effects of regularization on their method. All pooling methods compared use a 2x2 window. The overall results suggest that the pooling method should be chosen according to the type of data at hand. In some cases wavelet pooling performs best, and in others a different method performs better because the data is more suited to that type of pooling.
MNIST
Figure 6 shows the network and Table 1 shows the accuracy. It can be seen that wavelet pooling achieves the best accuracy of all the pooling methods compared.
CIFAR-10
Figure 7 shows the network and Tables 2 and 3 show the accuracy without and with dropout. Average pooling achieves the best accuracy, but wavelet pooling is still competitive.
Computational Complexity
The authors explain that their paper is a proof of concept and is not meant to implement wavelet pooling in the most efficient way. The table below displays a comparison of the number of mathematical operations for each method according to the dataset. It can be seen that wavelet pooling requires significantly more operations than the other methods. The authors argue that with good implementation and coding practices, the method can prove to be viable.
Criticism
Positive
- Wavelet pooling achieves competitive performance with standard go-to pooling methods
- Leads to comparison of discrete transformation techniques for pooling (DCT, DFT)
Negative
- Only 2x2 pooling window used for comparison
- Highly computationally expensive
- Not as simple as other pooling methods
- Only one wavelet used (Haar wavelet)