Wavelet Pooling CNN

From statwiki

Introduction

Convolutional Neural Networks (CNNs) generally outperform vector-based deep learning techniques. This makes the building blocks of CNNs natural targets for innovation aimed at improving that performance further. The pooling layer is one such building block. Existing methods range from deterministic and simple (max pooling and average pooling) to probabilistic (mixed pooling and stochastic pooling). However, all of them sub-sample using a neighborhood approach. That approach is fast and simple, but it can produce artifacts such as blurring, aliasing, and edge halos (Parker et al., 1983).

This paper introduces a novel pooling method based on the discrete wavelet transform. Specifically, it uses a second-level wavelet decomposition for the sub-sampling. Instead of nearest-neighbor interpolation, the method works with sub-bands, which the authors claim produces fewer artifacts and represents the underlying features more accurately. If pooling is viewed as a lossy process, the motivation for a wavelet approach is to minimize this loss.

Pooling Background

Pooling is essentially sub-sampling. After the pooling layer, the spatial dimensions of the data are reduced to some degree, with the goal of compressing the data rather than discarding part of it. Typical approaches reduce the dimensionality by combining a region of values into a single value. For max pooling, this can be represented by the equation [math]\displaystyle{ a_{kij} = \max_{(p,q) \in R_{ij}} (a_{kpq}) }[/math].

Here [math]\displaystyle{ a_{kij} }[/math] is the output activation of the [math]\displaystyle{ k^{th} }[/math] feature map at [math]\displaystyle{ (i,j) }[/math], and [math]\displaystyle{ a_{kpq} }[/math] is the input activation at [math]\displaystyle{ (p,q) }[/math] within the pooling region [math]\displaystyle{ R_{ij} }[/math]. The size of the pooling region is [math]\displaystyle{ |R_{ij}| }[/math]. Mean (average) pooling can be represented by the equation [math]\displaystyle{ a_{kij} = \frac{1}{|R_{ij}|} \sum_{(p,q) \in R_{ij}} a_{kpq} }[/math], with everything defined as before. Figure 1 provides a numerical example that can be followed.
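The two pooling rules above can be sketched in a few lines of Python. This is a minimal single-channel illustration, not code from the paper. The function name `pool2x2` is our own choice. The 4x4 input is also our own choice: it reuses the pixel values from the Haar example later in this article.

```python
from typing import List

Matrix = List[List[float]]


def pool2x2(a: Matrix, mode: str = "max") -> Matrix:
    """Non-overlapping 2x2 pooling of one feature map (a single channel k).

    Each output value combines the 2x2 region R_ij of the input, either by
    its maximum ("max") or by its mean ("avg").
    """
    rows, cols = len(a), len(a[0])
    if rows % 2 or cols % 2:
        raise ValueError("this sketch assumes even spatial dimensions")
    out: Matrix = []
    for i in range(0, rows, 2):
        out_row = []
        for j in range(0, cols, 2):
            region = [a[i][j], a[i][j + 1], a[i + 1][j], a[i + 1][j + 1]]
            if mode == "max":
                out_row.append(max(region))
            elif mode == "avg":
                out_row.append(sum(region) / len(region))  # (1/|R_ij|) * sum
            else:
                raise ValueError(f"unknown mode: {mode!r}")
        out.append(out_row)
    return out


feature_map = [[100, 50, 60, 150],
               [20, 60, 40, 30],
               [50, 90, 70, 82],
               [74, 66, 90, 58]]

print(pool2x2(feature_map, "max"))  # [[100, 150], [90, 90]]
print(pool2x2(feature_map, "avg"))  # [[57.5, 70.0], [70.0, 75.0]]
```

In the top-right region, the single large value 150 determines the max-pooled output, while averaging pulls it down to 70. These are the two behaviours discussed in the next paragraph.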

The paper notes that these pooling methods, although simple and effective, have shortcomings. Max pooling can omit details from an image if the important features have lower intensity than the insignificant ones, and it also commonly overfits. Average pooling, on the other hand, can dilute important features when they are averaged with values of significantly lower intensity. Figure 2 illustrates these effects.

Wavelet Background

Data or signals tend to be composed of slowly changing trends (low frequency) as well as fast-changing transients (high frequency). Similarly, images have smooth regions of intensity that are perturbed by edges or abrupt changes. These abrupt changes can represent features of great importance in deep learning.

Wavelets are a class of functions that are well localized in both time and frequency. Compare this with the Fourier transform, which represents signals as sums of sine waves that oscillate forever. Sine waves are therefore not localized in time (or, for images, in space). Their localization in time and space is what makes wavelets well suited to detecting abrupt changes in an image.

Essentially, a wavelet is a fast-decaying, oscillating signal with zero mean. It exists only for a finite duration and can be scaled and shifted in time. Several well-defined families of wavelets exist, as shown in Figure 3. For our purposes, the key characteristic of wavelets is that they act as band-pass filters: scaling a wavelet adjusts its pass band, while shifting moves it in time.

The paper uses the discrete wavelet transform, computed with a fast algorithm called the Fast Wavelet Transform (FWT), and the Haar wavelet. A continuous wavelet transform also exists. The main difference between the two is how the scale and shift parameters are selected.
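For reference, the Haar wavelet and the two ways of choosing scale and shift can be written as follows. These are standard textbook definitions rather than equations from the paper, and normalization conventions vary between sources.

[math]\displaystyle{ \psi(t) = \begin{cases} 1 & 0 \le t < \tfrac{1}{2} \\ -1 & \tfrac{1}{2} \le t < 1 \\ 0 & \text{otherwise} \end{cases} \qquad \int_{-\infty}^{\infty} \psi(t)\,dt = 0 }[/math]

[math]\displaystyle{ \underbrace{\psi_{a,b}(t) = \frac{1}{\sqrt{|a|}}\,\psi\!\left(\frac{t-b}{a}\right),\ a \neq 0,\ b \in \mathbb{R}}_{\text{continuous}} \qquad \underbrace{\psi_{j,k}(t) = 2^{j/2}\,\psi(2^{j} t - k),\ j,k \in \mathbb{Z}}_{\text{discrete}} }[/math]

Here [math]\displaystyle{ a }[/math] is the scale and [math]\displaystyle{ b }[/math] is the shift. The continuous transform lets them vary continuously. The discrete transform restricts them to the dyadic grid [math]\displaystyle{ a = 2^{-j},\ b = k\,2^{-j} }[/math], which is what makes the fast filter-bank implementation possible.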

Discrete Wavelet Transform General

The discrete wavelet transform of an image essentially applies a low-pass and a high-pass filter to the image, keeping every second output sample. The transfer functions of the two filters are related to each other and are determined by the type of wavelet used (Haar in this paper). This is shown in the accompanying figures, which also show the recursive nature of the transform.

For an image, the per-row transform is taken first. This produces a new image whose left half is the low-frequency sub-band and whose right half is the high-frequency sub-band. This new image is then transformed again column by column, resulting in four sub-bands. Generally, the low-frequency content approximates the image and the high-frequency content represents abrupt changes. Therefore, one can simply take the LL band and apply the transformation again to sub-sample further.
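The recursive structure can be sketched as below. This is an illustrative Python sketch, not the paper's implementation. It makes the following assumptions:

* It uses the averaging Haar convention from the worked example that follows (averages and half-differences, rather than the orthonormal 1/sqrt(2) scaling).
* Input sides are divisible by 2 at every level.
* The helper names `haar_1d`, `haar_2d` and `haar_decompose` are hypothetical.

Sub-band naming (LH vs. HL) differs between sources, so the three detail quadrants are labelled by position.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def haar_1d(v: List[float]) -> List[float]:
    """One Haar level on a 1D sequence: pairwise averages (low band) followed
    by pairwise half-differences (high band)."""
    pairs = list(zip(v[0::2], v[1::2]))
    return [(p + q) / 2 for p, q in pairs] + [(p - q) / 2 for p, q in pairs]


def transpose(m: Matrix) -> Matrix:
    return [list(col) for col in zip(*m)]


def haar_2d(img: Matrix) -> Tuple[Matrix, Tuple[Matrix, Matrix, Matrix]]:
    """One 2D level: transform every row, then every column, then split the
    result into the LL quadrant and the three detail quadrants."""
    rows_done = [haar_1d(r) for r in img]
    coeffs = transpose([haar_1d(c) for c in transpose(rows_done)])
    h, w = len(coeffs) // 2, len(coeffs[0]) // 2
    ll = [r[:w] for r in coeffs[:h]]
    details = ([r[w:] for r in coeffs[:h]],   # top-right quadrant
               [r[:w] for r in coeffs[h:]],   # bottom-left quadrant
               [r[w:] for r in coeffs[h:]])   # bottom-right quadrant (HH)
    return ll, details


def haar_decompose(img: Matrix, levels: int):
    """Recursive DWT: keep each level's detail bands and re-transform LL."""
    all_details = []
    ll = img
    for _ in range(levels):
        ll, details = haar_2d(ll)
        all_details.append(details)
    return ll, all_details


image = [[float(8 * r + c) for c in range(8)] for r in range(8)]
ll, details = haar_decompose(image, levels=2)
print(len(details[0][0]), len(details[1][0]), len(ll))  # 4 2 2
print(ll)  # [[13.5, 17.5], [45.5, 49.5]]
```

With this averaging convention, each LL coefficient is the mean of the corresponding input block. After two levels, that block is 4x4.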

DWT example using Haar Wavelet

Suppose we have an image represented by the following pixels: [math]\displaystyle{ \begin{bmatrix} 100 & 50 & 60 & 150 \\ 20 & 60 & 40 & 30 \\ 50 & 90 & 70 & 82 \\ 74 & 66 & 90 & 58 \\ \end{bmatrix} }[/math]

For each level of the DWT using the Haar wavelet, we will perform the transform on the rows first and then the columns. For the row pass, we transform each row as follows:

  • Take row i = [ i1, i2, i3, i4], and let i_t = [a1, a2, d1, d2] represent the transformed row
  • a1 = (i1 + i2)/2
  • a2 = (i3 + i4)/2
  • d1 = (i1 - i2)/2
  • d2 = (i3 - i4)/2

After the row transforms, the image looks as follows: [math]\displaystyle{ \begin{bmatrix} 75 & 105 & 25 & -45 \\ 40 & 35 & -20 & 5 \\ 70 & 76 & -20 & -6 \\ 70 & 74 & 4 & 16 \\ \end{bmatrix} }[/math]
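The row pass can be reproduced with a short Python sketch; the helper name `haar_row` is hypothetical. It applies the four bullet rules above to each row of the example image:

```python
from typing import List


def haar_row(row: List[float]) -> List[float]:
    """Haar row transform with the averaging convention from the list above.

    For a 4-element row [i1, i2, i3, i4] this returns [a1, a2, d1, d2],
    with a = (pair sum) / 2 and d = (pair difference) / 2.
    """
    pairs = list(zip(row[0::2], row[1::2]))  # (i1, i2), (i3, i4), ...
    averages = [(p + q) / 2 for p, q in pairs]
    differences = [(p - q) / 2 for p, q in pairs]
    return averages + differences


image = [[100, 50, 60, 150],
         [20, 60, 40, 30],
         [50, 90, 70, 82],
         [74, 66, 90, 58]]

row_pass = [haar_row(r) for r in image]
for r in row_pass:
    print(r)
# [75.0, 105.0, 25.0, -45.0]
# [40.0, 35.0, -20.0, 5.0]
# [70.0, 76.0, -20.0, -6.0]
# [70.0, 74.0, 4.0, 16.0]
```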

We then apply the same transform to each column of this intermediate image.
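To complete the example, the following self-contained sketch applies the same rule to the columns of the row-transformed image. The helper names are hypothetical. The column pass is written as the row transform applied to the transposed array.

```python
from typing import List

Matrix = List[List[float]]


def haar_row(row: List[float]) -> List[float]:
    """Pairwise averages followed by pairwise half-differences."""
    pairs = list(zip(row[0::2], row[1::2]))
    return [(p + q) / 2 for p, q in pairs] + [(p - q) / 2 for p, q in pairs]


def transpose(m: Matrix) -> Matrix:
    return [list(col) for col in zip(*m)]


def haar_level(img: Matrix) -> Matrix:
    """One full level: row pass, then column pass."""
    row_pass = [haar_row(r) for r in img]
    # The column pass is the row transform applied to the transposed array,
    # followed by a transpose back to the original orientation.
    return transpose([haar_row(c) for c in transpose(row_pass)])


image = [[100, 50, 60, 150],
         [20, 60, 40, 30],
         [50, 90, 70, 82],
         [74, 66, 90, 58]]

for r in haar_level(image):
    print(r)
# [57.5, 70.0, 2.5, -20.0]
# [70.0, 75.0, -8.0, 5.0]
# [17.5, 35.0, 22.5, -25.0]
# [0.0, 1.0, -12.0, -11.0]
```

The top-left 2x2 block, [[57.5, 70.0], [70.0, 75.0]], is the LL sub-band. With this averaging convention, each entry is the mean of the corresponding 2x2 block of the original image. The other three 2x2 blocks hold the detail sub-bands.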

Proposed Method

The proposed method uses the sub-bands from the second-level FWT and discards the first-level detail sub-bands. The authors postulate that this method captures the data compression more 'organically' and creates fewer artifacts that could affect image classification.

Forward Propagation

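The FWT can be expressed by [math]\displaystyle{ W_\varphi(j,k) = h_\varphi(-n) * W_\varphi(j+1,n) \big|_{n=2k,\ k \geq 0} }[/math] and [math]\displaystyle{ W_\psi(j,k) = h_\psi(-n) * W_\varphi(j+1,n) \big|_{n=2k,\ k \geq 0} }[/math]. In these equations:

* [math]\displaystyle{ \varphi }[/math] is the approximation (scaling) function and [math]\displaystyle{ \psi }[/math] is the detail (wavelet) function.
* [math]\displaystyle{ W_\varphi }[/math] and [math]\displaystyle{ W_\psi }[/math] are the approximation and detail coefficients.
* [math]\displaystyle{ h_\varphi(-n) }[/math] and [math]\displaystyle{ h_\psi(-n) }[/math] are the time-reversed scaling and wavelet vectors, and [math]\displaystyle{ * }[/math] denotes convolution.
* [math]\displaystyle{ n }[/math] indexes the samples in the vector, and [math]\displaystyle{ j }[/math] denotes the resolution level.

Evaluating the convolution only at [math]\displaystyle{ n = 2k }[/math] downsamples the output by two. To apply the FWT to images, it is applied first to the rows and then to the columns. A low (L) and a high (H) sub-band are extracted from the rows, and likewise from the columns. Each level therefore has four sub-bands (LL, LH, HL, and HH), and LL is decomposed further to obtain the level-2 decomposition.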
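The analysis step can be checked numerically on the first row of the example image. The sketch below assumes the orthonormal Haar filters h_phi = [1/sqrt(2), 1/sqrt(2)] and h_psi = [1/sqrt(2), -1/sqrt(2)]. This is a common normalization, not one confirmed by the paper. The function names are hypothetical. Convolving with the time-reversed filter and keeping only the even positions n = 2k reduces to summing h(m) w(2k + m):

```python
import math
from typing import List, Sequence, Tuple

# Orthonormal Haar filters (assumed normalization): scaling vector h_phi(n)
# and wavelet vector h_psi(n) = (-1)^n * h_phi(1 - n), for n = 0, 1.
H_PHI = [1 / math.sqrt(2), 1 / math.sqrt(2)]
H_PSI = [1 / math.sqrt(2), -1 / math.sqrt(2)]


def analysis(w_next: Sequence[float], h: Sequence[float]) -> List[float]:
    """Convolve w_next with the time-reversed filter h(-n) and keep n = 2k.

    (h(-n) * w)(2k) = sum_m h(m) * w(2k + m). Indices past the end are
    skipped (treated as zero), which never happens for even-length input
    with the 2-tap Haar filters.
    """
    out = []
    for k in range(len(w_next) // 2):
        acc = 0.0
        for m, h_m in enumerate(h):
            if 2 * k + m < len(w_next):
                acc += h_m * w_next[2 * k + m]
        out.append(acc)
    return out


def fwt_step(w_next: Sequence[float]) -> Tuple[List[float], List[float]]:
    """One FWT level: approximation W_phi(j, .) and detail W_psi(j, .)."""
    return analysis(w_next, H_PHI), analysis(w_next, H_PSI)


approx, detail = fwt_step([100, 50, 60, 150])
# Rescaling by 1/sqrt(2) recovers the averaging convention of the worked example.
print([round(a / math.sqrt(2), 6) for a in approx])  # [75.0, 105.0]
print([round(d / math.sqrt(2), 6) for d in detail])  # [25.0, -45.0]
```

Up to a factor of sqrt(2) per level, these coefficients are the same pairwise averages and half-differences as in the worked Haar example above.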

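The Inverse Fast Wavelet Transform (IFWT) is applied to the level-2 decomposition sub-bands to obtain the resulting image. This image is sub-sampled by a factor of two relative to the input. The equation for the IFWT is [math]\displaystyle{ W_\varphi(j+1,k) = h_\varphi(k) * W_\varphi^{\mathrm{up}}(j,k) + h_\psi(k) * W_\psi^{\mathrm{up}}(j,k) \big|_{k \geq 0} }[/math]. The parameters are the same as previously explained, and the superscript up denotes upsampling by a factor of two (inserting zeros between samples). Figure 4 displays the algorithm for the forward propagation.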
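Putting the pieces together, the forward pass described above might look like the following single-channel sketch. It is not the authors' code, and it makes these assumptions:

* It uses the averaging Haar convention from the worked example. An orthonormal implementation would differ by a constant scale factor.
* Input sides are divisible by 4.
* The function names `dwt2`, `idwt2` and `wavelet_pool` are hypothetical.

```python
from typing import List, Tuple

Matrix = List[List[float]]
Details = Tuple[Matrix, Matrix, Matrix]


def fwd_1d(v: List[float]) -> List[float]:
    """Haar analysis, averaging convention: pair averages then half-differences."""
    pairs = list(zip(v[0::2], v[1::2]))
    return [(p + q) / 2 for p, q in pairs] + [(p - q) / 2 for p, q in pairs]


def inv_1d(v: List[float]) -> List[float]:
    """Haar synthesis: each (a, d) pair becomes (a + d, a - d)."""
    half = len(v) // 2
    out: List[float] = []
    for a, d in zip(v[:half], v[half:]):
        out.extend([a + d, a - d])
    return out


def transpose(m: Matrix) -> Matrix:
    return [list(col) for col in zip(*m)]


def dwt2(img: Matrix) -> Tuple[Matrix, Details]:
    """One 2D level (rows, then columns), split into LL and three detail bands."""
    coeffs = transpose([fwd_1d(c) for c in transpose([fwd_1d(r) for r in img])])
    h, w = len(coeffs) // 2, len(coeffs[0]) // 2
    ll = [r[:w] for r in coeffs[:h]]
    details = ([r[w:] for r in coeffs[:h]],
               [r[:w] for r in coeffs[h:]],
               [r[w:] for r in coeffs[h:]])
    return ll, details


def idwt2(ll: Matrix, details: Details) -> Matrix:
    """Inverse of dwt2: reassemble the quadrants, undo columns, then rows."""
    top_right, bottom_left, bottom_right = details
    coeffs = ([a + b for a, b in zip(ll, top_right)]
              + [a + b for a, b in zip(bottom_left, bottom_right)])
    rows = transpose([inv_1d(c) for c in transpose(coeffs)])
    return [inv_1d(r) for r in rows]


def wavelet_pool(img: Matrix) -> Matrix:
    """Hypothetical forward pass: 2nd-level decomposition, drop the level-1
    details, IFWT of the level-2 bands -> output at half the input size."""
    ll1, _level1_details = dwt2(img)   # level-1 detail bands are discarded
    ll2, level2_details = dwt2(ll1)    # level-2 decomposition of LL1
    return idwt2(ll2, level2_details)


image = [[100, 50, 60, 150],
         [20, 60, 40, 30],
         [50, 90, 70, 82],
         [74, 66, 90, 58]]
print(wavelet_pool(image))  # [[57.5, 70.0], [70.0, 75.0]]
```

In this sketch, the level-2 analysis and synthesis steps invert each other exactly. The output therefore equals the level-1 LL band of the input, which for the example image is the set of 2x2 block means.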

Back Propagation

Back propagation is essentially the reverse of forward propagation. The FWT of the incoming gradient is upsampled to serve as the level-2 decomposition. The IFWT is then performed to obtain a gradient map upsampled by a factor of two using wavelet methods, which matches the spatial size of the pooling layer's input. Figure 5 displays the algorithm.
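The backward pass is described only briefly, so the sketch below is one possible reading rather than the authors' algorithm. It rests on these hypothetical assumptions:

* The one-level FWT of the incoming gradient serves as the level-2 decomposition.
* Nearest-neighbour upsampled copies of its detail bands serve as the level-1 detail bands.
* The averaging Haar convention from the worked example is used.

All helper names are hypothetical.

```python
from typing import List, Tuple

Matrix = List[List[float]]
Details = Tuple[Matrix, Matrix, Matrix]


def fwd_1d(v: List[float]) -> List[float]:
    pairs = list(zip(v[0::2], v[1::2]))
    return [(p + q) / 2 for p, q in pairs] + [(p - q) / 2 for p, q in pairs]


def inv_1d(v: List[float]) -> List[float]:
    half = len(v) // 2
    out: List[float] = []
    for a, d in zip(v[:half], v[half:]):
        out.extend([a + d, a - d])
    return out


def transpose(m: Matrix) -> Matrix:
    return [list(col) for col in zip(*m)]


def dwt2(img: Matrix) -> Tuple[Matrix, Details]:
    coeffs = transpose([fwd_1d(c) for c in transpose([fwd_1d(r) for r in img])])
    h, w = len(coeffs) // 2, len(coeffs[0]) // 2
    return ([r[:w] for r in coeffs[:h]],
            ([r[w:] for r in coeffs[:h]],
             [r[:w] for r in coeffs[h:]],
             [r[w:] for r in coeffs[h:]]))


def idwt2(ll: Matrix, details: Details) -> Matrix:
    top_right, bottom_left, bottom_right = details
    coeffs = ([a + b for a, b in zip(ll, top_right)]
              + [a + b for a, b in zip(bottom_left, bottom_right)])
    rows = transpose([inv_1d(c) for c in transpose(coeffs)])
    return [inv_1d(r) for r in rows]


def upsample2(m: Matrix) -> Matrix:
    """Nearest-neighbour upsampling by two in each direction (an assumption;
    the upsampling scheme is not specified in the description above)."""
    out: Matrix = []
    for row in m:
        wide = [v for v in row for _ in range(2)]
        out.extend([wide, list(wide)])
    return out


def wavelet_unpool(grad: Matrix) -> Matrix:
    """Hypothetical backward pass: the FWT of the incoming gradient acts as the
    level-2 decomposition, upsampled detail bands act as the level-1 details,
    and a two-level IFWT doubles the spatial size."""
    ll2, level2_details = dwt2(grad)
    level1_details = tuple(upsample2(d) for d in level2_details)
    ll1 = idwt2(ll2, level2_details)      # level-2 synthesis (recovers grad)
    return idwt2(ll1, level1_details)     # level-1 synthesis: 2x larger map
```

A constant gradient has no detail content, so in this sketch it is spread unchanged over an output twice as large in each direction.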

Results

The authors tested on MNIST, CIFAR-10, SVHN, and KDEF, and the paper provides comprehensive results for each. Stochastic gradient descent was used. The Haar wavelet was chosen because of its even, square sub-bands. The network for all datasets except MNIST is loosely based on (Zeiler & Fergus, 2013). For each dataset, the authors keep the network fixed and vary only the pooling method. They also experiment with dropout and Batch Normalization to examine the effects of regularization on their method. All pooling methods compared use a 2x2 window.

The overall results suggest that the pooling method should be chosen according to the type of data. In some cases wavelet pooling performs best, while in others the data are better suited to a different pooling method.

MNIST

Figure 6 shows the network and Table 1 shows the accuracy. Wavelet pooling achieves the best accuracy among all pooling methods compared.

CIFAR-10

Figure 7 shows the network, and Tables 2 and 3 show the accuracy without and with dropout, respectively. Average pooling achieves the best accuracy, but wavelet pooling is still competitive.


Computational Complexity

The authors explain that their paper is a proof of concept and is not meant to implement wavelet pooling in the most efficient way. The table below compares the number of mathematical operations required by each method on each dataset. Wavelet pooling requires significantly more operations than the other methods. The authors argue that with a more efficient implementation the method could still prove viable.

Criticism

Positive

  • Wavelet pooling achieves performance competitive with standard, go-to pooling methods
  • Opens the door to comparing other discrete transforms (e.g. DCT, DFT) for pooling

Negative

  • Only a 2x2 pooling window is used for comparison
  • Highly computationally expensive
  • Not as simple as other pooling methods
  • Only one wavelet is used (the Haar wavelet)

References