When Does Self-Supervision Improve Few-Shot Learning?
Presented by
Arash Moayyedi
Introduction
This paper proposes a technique that uses self-supervised learning (SSL) to improve the generalization of few-shot learned representations on small labelled datasets.
Few-shot learning refers to training a classifier on very small datasets, in contrast to the usual practice of using massive amounts of data, in the hope of successfully classifying previously unseen but related classes.
Self-supervised learning aims to teach the agent the internal structure of images by giving it tasks such as predicting the degree of rotation of an image. This can help with generalization issues where the agent cannot distinguish between newly introduced objects.
Previous Work
This work leverages few-shot learning, where the aim is to learn general representations so that, when facing novel classes, the agent can differentiate between them after training on just a few samples. Many few-shot learning methods currently exist; this paper focuses on Prototypical Networks, or ProtoNets[1] for short. One section of the paper also compares this model with the model-agnostic meta-learner (MAML)[2].
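To make the ProtoNet idea concrete, here is a minimal sketch of nearest-prototype classification as described in [1]: each class prototype is the mean of its support embeddings, and a query is assigned to the closest prototype. The embeddings below are hand-written toy vectors (in the paper they would be produced by the learned network), and the function names are hypothetical, chosen only for this illustration.

```python
import math

def mean_vector(vectors):
    """Element-wise mean of a list of equal-length vectors."""
    n = len(vectors)
    return [sum(col) / n for col in zip(*vectors)]

def squared_distance(a, b):
    """Squared Euclidean distance between two vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))

def prototypes(support):
    """support maps a class label to a list of embedding vectors."""
    return {label: mean_vector(embs) for label, embs in support.items()}

def classify(query, protos):
    """Softmax over negative squared distances to each prototype."""
    logits = {label: -squared_distance(query, p) for label, p in protos.items()}
    m = max(logits.values())  # subtract the max for numerical stability
    exps = {label: math.exp(v - m) for label, v in logits.items()}
    z = sum(exps.values())
    probs = {label: e / z for label, e in exps.items()}
    return max(probs, key=probs.get), probs

# Toy 2-way 2-shot episode with made-up 2-D embeddings.
support = {
    "bird": [[0.0, 1.0], [0.2, 0.8]],
    "car": [[1.0, 0.0], [0.8, 0.2]],
}
protos = prototypes(support)
label, probs = classify([0.1, 0.9], protos)
print(label, probs)
```

The query sits on top of the "bird" prototype here, so that class receives the highest probability.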
The other machine learning technique that this paper builds on is self-supervised learning. This technique uses unlabelled data, which avoids the cost of labelling and maintaining a massive dataset. Images already contain structural information that can be exploited. Many SSL tasks exist, such as removing part of the data and asking the agent to reconstruct the missing part. Other tasks include predicting rotations, relative patch locations, etc.
The work in this paper is also related to multi-task learning, in which training proceeds on multiple tasks concurrently so that they improve each other. Training on multiple tasks is known to degrade performance on individual tasks[3], and multi-task learning seems to work only for very specific combinations of tasks and architectures. This paper shows that self-supervised tasks and few-shot learning are mutually beneficial, which has significant practical implications since self-supervised tasks do not require any annotations.
Method
The authors of this paper suggest a framework, as seen in Fig. 1, that combines few-shot learning with self-supervised learning.
In this framework, a feed-forward convolutional network [math]\displaystyle{ f(x) }[/math] maps either a labelled image or an augmented unlabelled image to an embedding space. Depending on the input type, the embedding is then mapped to one of two label spaces by either a classifier [math]\displaystyle{ g }[/math] or a function [math]\displaystyle{ h }[/math]. The labelled training data consists of image-label pairs drawn from a set of base classes, and its domain is denoted by [math]\displaystyle{ \mathcal{D}_s }[/math]. Similarly, the domain of the unlabelled images used for the self-supervised tasks is denoted by [math]\displaystyle{ \mathcal{D}_{ss} }[/math]. Augmentations are applied to the images in this domain. The authors consider two augmentation types: jigsaw puzzle and rotation. They also compare the effect on accuracy of having the unlabelled image be an augmentation of the input labelled image (i.e. [math]\displaystyle{ \mathcal{D}_s = \mathcal{D}_{ss} }[/math]) versus having the unlabelled image be an augmentation of a different image (i.e. [math]\displaystyle{ \mathcal{D}_s \neq \mathcal{D}_{ss} }[/math]).
The training procedure consists of mapping a labelled image and an unlabelled augmented image to separate embeddings using the shared feature backbone of the feed-forward convolutional network [math]\displaystyle{ f(x) }[/math]. The network is then trained using a loss function [math]\displaystyle{ \mathcal{L} }[/math] which combines a classification loss term [math]\displaystyle{ \mathcal{L}_s }[/math] involving the labelled image embedding and a self-supervised loss term [math]\displaystyle{ \mathcal{L}_{ss} }[/math] involving the unlabelled augmented image embedding.
The classification loss [math]\displaystyle{ \mathcal{L}_s }[/math] is defined as:
[math]\displaystyle{ \mathcal{L}_s := \sum_{(x_i,y_i)\in \mathcal{D}_s} \ell(g \circ f(x_i), y_i) + \mathcal{R}(f,g), }[/math]
where it is common to use cross-entropy loss for the loss function [math]\displaystyle{ \ell }[/math] and the [math]\displaystyle{ \ell_2 }[/math] norm for the regularization [math]\displaystyle{ \mathcal{R} }[/math].
The task prediction loss [math]\displaystyle{ \mathcal{L}_{ss} }[/math] uses a separate function [math]\displaystyle{ h }[/math] which maps the embeddings of unlabelled images to a separate label space. Here the target label [math]\displaystyle{ \hat{y} }[/math] is determined by the augmentation that was applied to produce the unlabelled image [math]\displaystyle{ \hat{x} }[/math]. In the case of jigsaw, the label is the index of the permutation applied to the original image. In the case of rotation, the label is the angle of rotation applied to the original image. If we define a set of labelled pairs for the previously unlabelled augmented images as [math]\displaystyle{ \forall x \in \mathcal{D}_{ss}, x \rightarrow (\hat{x_i}, \hat{y_i}) }[/math], then the task prediction loss can be defined as:
[math]\displaystyle{ \mathcal{L}_{ss} := \sum_{x_i\in \mathcal{D}_{ss}} \ell(h \circ f(\hat{x_i}), \hat{y_i}). }[/math]
The final loss is [math]\displaystyle{ \mathcal{L} := \mathcal{L}_s + \mathcal{L}_{ss} }[/math], so the self-supervised loss acts as a data-dependent regularizer for representation learning. The gradient updates are performed on this combined loss. Note that in the case [math]\displaystyle{ \mathcal{D}_s \neq \mathcal{D}_{ss} }[/math], a forward pass is done on one batch from each dataset, and the two losses are then combined.
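The combined objective can be illustrated with a small numerical sketch. The logits below are hand-written stand-ins for the outputs of [math]\displaystyle{ g \circ f }[/math] and [math]\displaystyle{ h \circ f }[/math], the flat `weights` list and the coefficient `coeff` are hypothetical placeholders for the regularizer [math]\displaystyle{ \mathcal{R}(f,g) }[/math], and no gradient step is shown; it only evaluates the formulas above.

```python
import math

def cross_entropy(logits, target):
    """Cross-entropy of softmax(logits) against the integer class `target`."""
    m = max(logits)  # subtract the max for numerical stability
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return log_z - logits[target]

def l2_penalty(weights, coeff):
    """Squared l2 norm of the parameters, scaled by `coeff` (assumed form of R)."""
    return coeff * sum(w * w for w in weights)

def combined_loss(sup_batch, ss_batch, weights, coeff=1e-4):
    """Return (L, L_s, L_ss).

    sup_batch: list of (class logits, label y) for labelled images.
    ss_batch:  list of (task logits, label y_hat) for augmented images.
    """
    loss_s = sum(cross_entropy(l, y) for l, y in sup_batch)
    loss_s += l2_penalty(weights, coeff)
    loss_ss = sum(cross_entropy(l, y) for l, y in ss_batch)
    return loss_s + loss_ss, loss_s, loss_ss

# One labelled example (3 base classes) and one rotated example (4 angles).
sup_batch = [([2.0, 0.0, 0.0], 0)]
ss_batch = [([0.0, 0.0, 0.0, 0.0], 2)]
weights = [0.5, -0.5, 1.0]
total, loss_s, loss_ss = combined_loss(sup_batch, ss_batch, weights)
print(total, loss_s, loss_ss)
```

With uniform logits over the four rotation classes, the self-supervised term equals log 4, which is the loss of an uninformed predictor on that task.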
Experiments
The authors experimented on the following datasets: Caltech-UCSD birds, Stanford cars, FGVC aircraft, Stanford dogs, Oxford flowers, mini-ImageNet, and tiered-ImageNet. Each dataset is divided into three disjoint sets: a base set for training the parameters, a val set for validation, and a novel set for testing with a few examples per class, as shown in Figure 2. Data augmentation was used with all of these datasets to improve the results.
The authors used a meta-learning method based on prototypical networks, where training and testing are done in stages called meta-training and meta-testing. These networks are similar to distance-based and metric-based learners that train on label similarity. Two tasks were used for the self-supervised learning part: rotation and the Jigsaw puzzle[4]. In the rotation task, the image is rotated by an angle [math]\displaystyle{ \theta \in \{0^{\circ}, 90^{\circ}, 180^{\circ}, 270^{\circ}\} }[/math] to produce the input, and the target label is the index of the rotation in the list. In the Jigsaw puzzle task, the image is divided into [math]\displaystyle{ 3\times3 }[/math] tiles, which are then shuffled to produce the input image. The target is the index of the applied permutation within a reduced set of 35 permutations selected based on Hamming distance.
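The two self-supervised tasks can be sketched on a toy image stored as a nested list. This is only an illustration under stated assumptions: the rotation direction (clockwise) is an arbitrary convention, the helper names are hypothetical, and the small permutation set built here is a random placeholder, not the 35 Hamming-distance-based permutations used in the paper.

```python
import random

ANGLES = [0, 90, 180, 270]

def rotate90(img):
    """Rotate a 2-D list of pixels by 90 degrees clockwise."""
    return [list(row) for row in zip(*img[::-1])]

def rotation_sample(img, rng):
    """Return (rotated image, label); label is an index into ANGLES."""
    label = rng.randrange(len(ANGLES))
    out = [list(row) for row in img]
    for _ in range(label):
        out = rotate90(out)
    return out, label

def split_tiles(img, grid=3):
    """Cut a square image into grid x grid tiles, in row-major order."""
    t = len(img) // grid
    return [[row[c * t:(c + 1) * t] for row in img[r * t:(r + 1) * t]]
            for r in range(grid) for c in range(grid)]

def join_tiles(tiles, grid=3):
    """Inverse of split_tiles: reassemble tiles into one image."""
    rows = []
    for r in range(grid):
        row_tiles = tiles[r * grid:(r + 1) * grid]
        for i in range(len(row_tiles[0])):
            rows.append([p for tile in row_tiles for p in tile[i]])
    return rows

def jigsaw_sample(img, permutations, rng):
    """Return (shuffled image, label); label indexes `permutations`."""
    label = rng.randrange(len(permutations))
    tiles = split_tiles(img)
    shuffled = [tiles[i] for i in permutations[label]]
    return join_tiles(shuffled), label

# Toy 6x6 "image" whose pixel values are 0..35.
img = [[r * 6 + c for c in range(6)] for r in range(6)]
rng = random.Random(0)

# Placeholder permutation set: identity plus three random shuffles.
perms = [list(range(9))]
for _ in range(3):
    p = list(range(9))
    rng.shuffle(p)
    perms.append(p)

rotated, rot_label = rotation_sample(img, rng)
puzzle, jig_label = jigsaw_sample(img, perms, rng)
print(rot_label, jig_label)
```

In both cases the label comes for free from the transformation itself, which is why these tasks need no human annotation.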
Results
The 5-way 5-shot classification accuracies can be seen in Fig. 3. ProtoNet is used as the baseline and is compared against ProtoNet with the Jigsaw task, with the rotation task, and with both combined. The Jigsaw task always improves accuracy. However, the rotation task does not seem to provide much improvement on the flowers and aircraft datasets. The authors speculate that this might be because flowers are mostly symmetrical, making the task too hard, and because planes are usually horizontal, making the task too simple.
In another experiment, the authors show that the improvements from self-supervised learning are much larger on more difficult few-shot learning problems. As can be observed from Fig. 4, SSL is more beneficial with greyscale or low-resolution images, which make classification harder for natural and man-made objects, respectively.
Self-supervision has also been combined with two other meta-learners in this work, MAML and a standard feature extractor trained with cross-entropy loss (softmax). Fig. 5 summarizes these results, and even though there is an accuracy gain in all scenarios (except for two), the ProtoNet + Jigsaw combination seems to work best.
Fig. 6 shows the effects of the size and domain of the SSL data on 5-way 5-shot classification accuracy. First, only 20 percent of the data is used for meta-learning. Fig. 6(a) shows how accuracy changes as the percentage of images from the whole dataset used for SSL increases. Increasing the size of the SSL dataset has a positive effect, with diminishing returns. Fig. 6(b) shows the effect of shifting the domain of the SSL dataset by replacing a percentage of the images with pictures from other datasets. This has a negative effect; moreover, training with SSL on only the 20 percent of images used for meta-learning is often better than increasing the size while shifting the domain. This is shown as crosses on the chart.
Figure 7 shows the accuracy of the meta-learner with SSL on different domains as a function of the distance between the supervised domain [math]\displaystyle{ \mathcal{D}_s }[/math] and the self-supervised domain [math]\displaystyle{ \mathcal{D}_{ss} }[/math]. Once again, the effectiveness of SSL decreases with the distance from the supervised domain across all datasets.
The improvements obtained here generalize to other meta-learners as well. For instance, 5-way 5-shot accuracies across five fine-grained datasets for softmax, MAML, and ProtoNet improve when combined with the jigsaw puzzle task.
The results also show that self-supervision alone is not enough. A ResNet18 trained with SSL alone achieved 32.9% (w/ jigsaw) and 33.7% (w/ rotation) 5-way 5-shot accuracy averaged across five fine-grained datasets. While this is better than a random initialization (29.5%), it is dramatically worse than a network trained with a simple cross-entropy loss on the labels (85.5%).
Conclusion
The authors provide valuable insight into the effects of using SSL as a regularizer for few-shot learning methods. They show that SSL is beneficial in almost every case, and that the improvements are much larger on more difficult tasks. They also show that the dataset used for SSL need not be large. Increasing the size of this dataset can help, but only if the added images come from the same or a similar domain.
Critiques
The authors could have analyzed other SSL tasks in addition to the Jigsaw puzzle and the rotation task, e.g. counting the number of objects or predicting a removed patch. Additionally, while analyzing the effects of the data used for SSL, they did not experiment with adding data from other domains while fully utilizing the base dataset. Moreover, comparing their work with previous works (Fig. 6), we can see that they used mini-ImageNet with an image size of [math]\displaystyle{ 224\times224 }[/math], in contrast to other methods that used an [math]\displaystyle{ 84\times84 }[/math] image size. This gives them a large advantage; nevertheless, some other methods using smaller images still achieve higher accuracy.
Moreover, in Fig. 8 the authors considered same-domain learning for different examples and indicated that adding more unlabelled data from the base classes increases accuracy. It would be interesting to apply their approach to cross-domain learning, where the base and novel classes come from very different domains. I believe this might add some robustness and improve accuracy further. Also, comparing cross-domain with same-domain learning might strengthen their observation that the rotation task brings little improvement, especially on the flowers dataset, where the images are mostly symmetrical.
References
[1]: Snell, J., Swersky, K., Zemel, R.: Prototypical networks for few-shot learning. In: NeurIPS (2017)
[2]: Finn, C., Abbeel, P., Levine, S.: Model-agnostic meta-learning for fast adaptation of deep networks. In: ICML (2017)
[3]: Kokkinos, I.: Ubernet: Training a universal convolutional neural network for low-, mid-, and high-level vision using diverse datasets and limited memory. In: CVPR (2017)
[4]: Noroozi, M., Favaro, P.: Unsupervised learning of visual representations by solving jigsaw puzzles. In: ECCV (2016)