When Does Self-Supervision Improve Few-Shot Learning?


Presented by

Arash Moayyedi

Introduction

This paper proposes a technique utilizing self-supervised learning (SSL) to improve the generalization of few-shot learned representations on small labeled data sets.

Few-shot learning refers to training a classifier on very small datasets, in contrast to the usual practice of using massive amounts of data, in the hope of correctly classifying previously unseen but related classes.

Self-supervised learning aims to teach the model the internal structure of images by giving it tasks such as predicting the degree of rotation applied to an image; rotation prediction is a typical proxy task in self-supervision. The proposed method can help with generalization issues where the model cannot distinguish between newly introduced objects. Self-supervision is a powerful way to take advantage of the vast amount of available unlabeled data.

Previous Work

This work builds on few-shot learning, where the aim is to learn general representations so that, when facing novel classes, the model can differentiate between them after training on just a few samples. Many few-shot learning methods exist; this paper focuses on Prototypical Networks, or ProtoNets[1] for short. One section of the paper also compares this model with the model-agnostic meta-learner (MAML)[2]. [note 1]
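
The prototype idea behind ProtoNets[1] can be illustrated with a minimal sketch. It assumes the embeddings have already been computed by some backbone; the function names and the toy two-dimensional vectors are hypothetical and are not taken from the paper's code. Each class prototype is the mean of its support embeddings, and a query is scored by a softmax over negative squared Euclidean distances to the prototypes.

```python
import math


def prototypes(support):
    """Average the embeddings of each class to get one prototype per class.

    `support` maps a class label to a list of embedding vectors (lists of floats).
    """
    protos = {}
    for label, vectors in support.items():
        dim = len(vectors[0])
        protos[label] = [sum(v[d] for v in vectors) / len(vectors) for d in range(dim)]
    return protos


def squared_distance(a, b):
    """Squared Euclidean distance between two vectors of equal length."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def class_probabilities(query, protos):
    """Softmax over negative squared distances from the query to each prototype."""
    logits = {label: -squared_distance(query, p) for label, p in protos.items()}
    top = max(logits.values())  # subtract the maximum for numerical stability
    exps = {label: math.exp(z - top) for label, z in logits.items()}
    total = sum(exps.values())
    return {label: e / total for label, e in exps.items()}


# Toy support set: two classes with two (made-up) 2-D embeddings each.
support = {
    "sparrow": [[0.0, 0.0], [0.0, 2.0]],
    "finch": [[4.0, 0.0], [4.0, 2.0]],
}
protos = prototypes(support)
probs = class_probabilities([1.0, 1.0], protos)
predicted = max(probs, key=probs.get)
```

In this toy case the query lies closer to the "sparrow" prototype, so that class receives the higher probability.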


The other machine learning technique that this paper builds on is self-supervised learning. This technique uses unlabelled data, which avoids the expense of labeling and maintaining a massive data set. Images already contain structural information that can be exploited. Many SSL tasks exist, such as removing part of the data and asking the model to reconstruct the missing part. Other tasks include predicting rotations, predicting relative patch locations, etc.

The work in this paper is also related to multi-task learning, in which training proceeds on multiple tasks concurrently so that they improve each other. Training on multiple tasks is known to reduce performance on individual tasks[3], and it seems to work only for very specific combinations and architectures. This paper shows that self-supervised tasks and few-shot learning are mutually beneficial, which has significant practical implications since self-supervised tasks do not require any annotations.

Method

The authors of this paper suggest a framework, as seen in Fig. 1, that combines few-shot learning with self-supervised learning.

In this framework, a feed-forward convolutional network [math]\displaystyle{ f(x) }[/math] maps either a labelled image or an augmented unlabelled image to an embedding space. Depending on the input type, the embedding is then mapped to one of two label spaces by either a classifier [math]\displaystyle{ g }[/math] or a function [math]\displaystyle{ h }[/math]. When evaluating the accuracy of the model, only the mappings of labelled images by the classifier [math]\displaystyle{ g }[/math] are considered, whereas during training both the mappings of labelled images by [math]\displaystyle{ g }[/math] and of unlabelled images by [math]\displaystyle{ h }[/math] are used. The labelled training data consists of a set of base classes given as pairs of images and labels, and its domain is denoted by [math]\displaystyle{ \mathcal{D}_s }[/math]. Similarly, the domain of the unlabelled images used for the self-supervised tasks is denoted by [math]\displaystyle{ \mathcal{D}_{ss} }[/math]. Within this domain, augmentations are applied to the images. The authors consider two augmentation types, the jigsaw puzzle and rotation. They also compare the effect on accuracy of having the unlabelled image be an augmentation of the input labelled image (i.e. [math]\displaystyle{ \mathcal{D}_s = \mathcal{D}_{ss} }[/math]) versus an augmentation of a different image (i.e. [math]\displaystyle{ \mathcal{D}_s \neq \mathcal{D}_{ss} }[/math]).

Figure 1: Combining supervised and self-supervised losses for few-shot learning. This paper investigates how the performance on the supervised learning task is influenced by the choice of the self-supervision task.

The training procedure consists of mapping a labelled image and an unlabelled augmented image to separate embeddings using the shared feature backbone of the feed-forward convolutional network [math]\displaystyle{ f }[/math]. The network is then trained using a loss function [math]\displaystyle{ \mathcal{L} }[/math] which combines a classification loss term [math]\displaystyle{ \mathcal{L}_s }[/math] involving the labelled image embedding and a self-supervised loss term [math]\displaystyle{ \mathcal{L}_{ss} }[/math] involving the unlabelled augmented image embedding.

The classification loss [math]\displaystyle{ \mathcal{L}_s }[/math] is defined as:

[math]\displaystyle{ \mathcal{L}_s := \sum_{(x_i,y_i)\in \mathcal{D}_s} \ell(g \circ f(x_i), y_i) + \mathcal{R}(f,g), }[/math]

where it is common to use the cross-entropy loss for the loss function [math]\displaystyle{ \ell }[/math] and the [math]\displaystyle{ \ell_2 }[/math] norm for the regularization [math]\displaystyle{ \mathcal{R} }[/math].
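
As a worked illustration of the classification loss, the following sketch sums the cross-entropy over a batch of logits (standing in for the outputs of g applied to f) and adds a squared ℓ2 penalty on the weights. The function names, the flat weight list, and the coefficient value are assumptions made for this example, not details from the paper.

```python
import math


def cross_entropy(logits, target):
    """Cross-entropy of one example: log-sum-exp of the logits minus the target logit."""
    top = max(logits)  # subtract the maximum for numerical stability
    log_sum_exp = top + math.log(sum(math.exp(z - top) for z in logits))
    return log_sum_exp - logits[target]


def l2_penalty(weights, coeff):
    """Squared l2 norm of the weights, scaled by a regularization coefficient."""
    return coeff * sum(w * w for w in weights)


def supervised_loss(batch_logits, batch_labels, weights, coeff=1e-4):
    """Sum of per-example cross-entropy terms plus the regularizer."""
    data_term = sum(cross_entropy(z, y) for z, y in zip(batch_logits, batch_labels))
    return data_term + l2_penalty(weights, coeff)


# One example with uniform logits over two classes, and two toy weights.
loss = supervised_loss([[0.0, 0.0]], [0], weights=[1.0, 2.0], coeff=0.1)
```

Here the data term equals log 2 (a uniform prediction over two classes) and the penalty equals 0.1 × (1 + 4) = 0.5.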

The task prediction loss [math]\displaystyle{ \mathcal{L}_{ss} }[/math] utilizes a separate function [math]\displaystyle{ h }[/math] which maps the embeddings of unlabelled images to a separate label space. Here a target label [math]\displaystyle{ \hat{y} }[/math] is derived from the augmentation that was applied to the unlabelled image. In the case of the jigsaw task, the label is the index of the permutation applied to the original image. In the case of rotation, the label is the angle of rotation applied to the original image. If we define a set of labelled pairs for the previously unlabelled augmented images as [math]\displaystyle{ \forall x \in \mathcal{D}_{ss}, x \rightarrow (\hat{x}, \hat{y}) }[/math], where [math]\displaystyle{ \hat{x} }[/math] is the augmented version of [math]\displaystyle{ x }[/math], then the task prediction loss can be defined as:

[math]\displaystyle{ \mathcal{L}_{ss} := \sum_{x_i\in \mathcal{D}_{ss}} \ell(h \circ f(\hat{x_i}), \hat{y_i}). }[/math]


The final loss is [math]\displaystyle{ \mathcal{L} := \mathcal{L}_s + \mathcal{L}_{ss} }[/math], so the self-supervised loss acts as a data-dependent regularizer for representation learning. The gradient updates are performed on this combined loss. Note that in the case [math]\displaystyle{ \mathcal{D}_s \neq \mathcal{D}_{ss} }[/math], a separate forward pass is done on a batch from each dataset, and the two losses are then combined.
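
The way the two terms share a backbone can be sketched as follows. This is a toy illustration under stated assumptions: the backbone f is replaced by an elementwise scaling, the heads g and h are plain linear maps, and all names (`backbone`, `linear_head`, `combined_loss`) are hypothetical. Gradient computation is omitted; the sketch only shows how the two batches pass through the same f and how their losses are added.

```python
import math


def cross_entropy(logits, target):
    """Cross-entropy of one example, computed with a stable log-sum-exp."""
    top = max(logits)
    return top + math.log(sum(math.exp(z - top) for z in logits)) - logits[target]


def backbone(x, w):
    """Toy stand-in for the shared network f: elementwise scaling of the input."""
    return [w_i * x_i for w_i, x_i in zip(w, x)]


def linear_head(z, rows):
    """Toy stand-in for g or h: one weight row per output class."""
    return [sum(r_i * z_i for r_i, z_i in zip(row, z)) for row in rows]


def combined_loss(labelled, augmented, w, g_rows, h_rows):
    """Return (L, L_s, L_ss), where both terms use the same backbone weights w."""
    loss_s = sum(
        cross_entropy(linear_head(backbone(x, w), g_rows), y) for x, y in labelled
    )
    loss_ss = sum(
        cross_entropy(linear_head(backbone(x_hat, w), h_rows), y_hat)
        for x_hat, y_hat in augmented
    )
    return loss_s + loss_ss, loss_s, loss_ss


# Labelled batch (2 classes) and augmented batch (4 self-supervised labels).
labelled = [([1.0, 2.0], 0), ([0.5, 0.1], 1)]
augmented = [([2.0, 1.0], 3)]
w = [1.0, 1.0]
g_rows = [[0.0, 0.0], [0.0, 0.0]]
h_rows = [[0.0, 0.0] for _ in range(4)]
total, loss_s, loss_ss = combined_loss(labelled, augmented, w, g_rows, h_rows)
```

With all head weights at zero, every prediction is uniform, so the supervised term is 2 log 2 and the self-supervised term is log 4.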

Experiments

To assess the proposed method, several datasets have been employed, including Caltech-UCSD birds, Stanford cars, FGVC aircraft, Stanford dogs, Oxford flowers, mini-ImageNet, and tiered-ImageNet. Each dataset is divided into three disjoint sets: the base set for training the parameters, the val set for validation, and the novel set for testing with a few examples per class, as shown in Figure 2. Data augmentation has been used with all these datasets to improve the results.

Figure 2: Used datasets and their base, validation and test splits.

The authors used a meta-learning method based on prototypical networks, where training and testing are done in stages called meta-training and meta-testing. These networks are similar to distance-based learners and metric-based learners that train on label similarity. Two tasks have been used for the self-supervised learning part: rotation and the Jigsaw puzzle[4]. In the rotation task, the image is rotated by an angle [math]\displaystyle{ \theta \in \{0^{\circ}, 90^{\circ}, 180^{\circ}, 270^{\circ}\} }[/math] to produce the input, and the target label is the index of the rotation in the list. In the Jigsaw puzzle task, the image is split into [math]\displaystyle{ 3\times3 }[/math] tiles, which are then shuffled to produce the input image. The target is the index of the applied permutation within a reduced set of 35 permutations, which are chosen based on the Hamming distance between permutations.
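
The two proxy tasks can be sketched on images stored as nested lists. Everything here is illustrative: the rotation direction (clockwise), the helper names, and the tiny two-entry permutation list are assumptions, whereas the paper uses a set of 35 permutations that is not reproduced here.

```python
def rotate90(image):
    """Rotate a 2-D list-of-lists image by 90 degrees clockwise."""
    return [list(row) for row in zip(*image[::-1])]


def rotation_example(image, index):
    """Return (rotated image, label); label indexes the list [0, 90, 180, 270]."""
    rotated = [list(row) for row in image]
    for _ in range(index):
        rotated = rotate90(rotated)
    return rotated, index


def split_tiles(image, grid=3):
    """Cut a square image into grid x grid tiles, listed row by row."""
    size = len(image) // grid
    return [
        [row[c * size:(c + 1) * size] for row in image[r * size:(r + 1) * size]]
        for r in range(grid)
        for c in range(grid)
    ]


def join_tiles(tiles, grid=3):
    """Inverse of split_tiles: stitch tiles back into one image."""
    size = len(tiles[0])
    image = []
    for r in range(grid):
        for i in range(size):
            row = []
            for c in range(grid):
                row.extend(tiles[r * grid + c][i])
            image.append(row)
    return image


def jigsaw_example(image, permutations, index):
    """Shuffle the tiles with permutations[index]; the label is that index."""
    tiles = split_tiles(image)
    shuffled = [tiles[p] for p in permutations[index]]
    return join_tiles(shuffled), index


def hamming(a, b):
    """Number of positions at which two permutations differ."""
    return sum(x != y for x, y in zip(a, b))


# Hypothetical permutation set with only two entries, for illustration.
identity = list(range(9))
reverse = identity[::-1]
permutations = [identity, reverse]

image = [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
rotated, rot_label = rotation_example(image, 1)
shuffled, jig_label = jigsaw_example(image, permutations, 1)
```

In both tasks the label is produced by the transformation itself, so no human annotation is needed.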

Results

An N-way k-shot classification task contains N unique classes with k labeled images per class. The results on 5-way 5-shot classification accuracy can be seen in Fig. 3. ProtoNet has been used as a baseline and is compared with the Jigsaw task, the rotation task, and both of them combined. The result is that the Jigsaw task always improves the result. However, the rotation task seems to not provide much improvement on the flowers and the aircraft datasets. The authors speculate that this might be because flowers are mostly symmetrical, making the task too hard, and that the planes are usually horizontal, making the task too simple.
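
An N-way k-shot episode can be drawn as in the sketch below. The dataset layout (a dictionary from class name to a list of image identifiers), the number of query images per class, and the function name are assumptions for illustration only.

```python
import random


def sample_episode(dataset, n_way, k_shot, n_query, rng):
    """Draw one episode: n_way classes, k_shot support and n_query query items each.

    `dataset` maps a class label to a list of items; `rng` is a random.Random.
    """
    classes = rng.sample(sorted(dataset), n_way)
    support, query = [], []
    for label in classes:
        # Sample without replacement so support and query never overlap.
        items = rng.sample(dataset[label], k_shot + n_query)
        support += [(x, label) for x in items[:k_shot]]
        query += [(x, label) for x in items[k_shot:]]
    return support, query


# Toy dataset: 8 classes with 30 placeholder image identifiers each.
dataset = {
    f"class_{c}": [f"class_{c}_img_{i}" for i in range(30)] for c in range(8)
}
rng = random.Random(0)
support, query = sample_episode(dataset, n_way=5, k_shot=5, n_query=3, rng=rng)
```

The support set holds the k labelled images per class that the classifier may use, while the query set is used to measure accuracy on the episode.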

Figure 3: Benefits of SSL for few-shot learning tasks.

In another experiment, it is shown that the improvements from self-supervised learning are much larger on more difficult few-shot learning problems. As can be seen in Fig. 4, SSL is more beneficial with greyscale or low-resolution images, which make classification harder for natural and man-made objects, respectively.

Figure 4: Benefits of SSL for harder few-shot learning tasks.

Self-supervision has also been combined with two other meta-learners in this work, MAML and a standard feature extractor trained with cross-entropy loss (softmax). Fig. 5 summarizes these results, and even though there is an accuracy gain in all scenarios (except for two), the ProtoNet + Jigsaw combination seems to work best.

Figure 5: Performance on few-shot learning using different meta-learners.

Fig. 6 shows the effects of the size and domain of the SSL dataset on 5-way 5-shot classification accuracy. In these experiments, only 20 percent of the data is used for meta-learning. Fig. 6(a) shows how the accuracy changes as the percentage of images from the whole dataset used for SSL increases. Increasing the size of the SSL dataset has a positive effect, with diminishing returns. Fig. 6(b) shows the effect of shifting the domain of the SSL dataset by replacing a percentage of its images with pictures from other datasets. This has a negative effect; moreover, training with SSL on only the 20 percent of images used for meta-learning is often better than increasing the size while shifting the domain. This is shown as crosses on the chart.
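
One way to picture the domain-shift setting of Fig. 6(b) is the following sketch, which builds an SSL pool of fixed size in which a given fraction of images is replaced by images from another dataset. This is a hedged reconstruction of the idea only; the function name, the sampling scheme, and the rounding are assumptions and do not come from the paper's code.

```python
import random


def build_ssl_pool(in_domain, other_domain, pool_size, shift_fraction, rng):
    """Mix in-domain and out-of-domain images into one unlabelled pool.

    `shift_fraction` is the share of the pool drawn from `other_domain`.
    """
    n_other = round(pool_size * shift_fraction)
    n_in = pool_size - n_other
    pool = rng.sample(in_domain, n_in) + rng.sample(other_domain, n_other)
    rng.shuffle(pool)
    return pool


# Placeholder identifiers standing in for images from two datasets.
birds = [f"bird_{i}" for i in range(100)]
cars = [f"car_{i}" for i in range(100)]
rng = random.Random(0)
pool = build_ssl_pool(birds, cars, pool_size=10, shift_fraction=0.4, rng=rng)
n_shifted = sum(name.startswith("car_") for name in pool)
```

Setting `shift_fraction` to 0 recovers a purely in-domain pool, which corresponds to the setting of Fig. 6(a) as the pool size varies.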

Figure 6: (a) Effect of number of images on SSL. (b) Effect of domain shift on SSL.


Figure 7 shows the accuracy of the meta-learner with SSL on different domains as a function of the distance between the supervised domain [math]\displaystyle{ \mathcal{D}_s }[/math] and the self-supervised domain [math]\displaystyle{ \mathcal{D}_{ss} }[/math]. Once again, the effectiveness of SSL decreases with the distance from the supervised domain across all datasets.

Figure 7: Effectiveness of SSL as a function of domain distance between Ds and Dss (shown on top).

The improvements obtained here generalize to other meta-learners as well. For instance, 5-way 5-shot accuracies across five fine-grained datasets for softmax, MAML, and ProtoNet improve when combined with the jigsaw puzzle task.

Results also show that self-supervision alone is not enough. A ResNet18 trained with SSL alone achieved 32.9% (w/ jigsaw) and 33.7% (w/ rotation) 5-way 5-shot accuracy averaged across five fine-grained datasets. While this is better than a random initialization (29.5%), it is dramatically worse than a network trained with a simple cross-entropy loss on the labels (85.5%).

Source Code

The source code can be found here: https://github.com/cvl-umass/fsl_ssl .

Conclusion

The authors of this paper provide great insight into the effects of using SSL as a regularizer for few-shot learning methods. They show that SSL is beneficial in almost every case, and that the improvements are much larger on more difficult tasks. They also show that the dataset used for SSL does not need to be large. Increasing the size of this dataset can help, but only if the added images are from the same or a similar domain.

Critiques

The authors of this paper could have analyzed other SSL tasks in addition to the Jigsaw puzzle and the rotation task, e.g. counting the number of objects or predicting a removed patch. Additionally, while analyzing the effects of the data used for SSL, they did not experiment with adding data from other domains while fully utilizing the base dataset. Moreover, comparing their work with previous works (Fig. 8), we can see they have used mini-ImageNet with an image size of [math]\displaystyle{ 224\times224 }[/math], in contrast to other methods that have used an [math]\displaystyle{ 84\times84 }[/math] image size. This gives them a large advantage; nevertheless, other methods with smaller images have still achieved higher accuracy.

Moreover, in Fig. 8 the authors considered same-domain learning for different examples, and they indicated that adding more unlabeled data from the base classes increases the accuracy. I would be curious to see their approach applied to cross-domain learning, where the base and novel classes come from very different domains. I believe it might add some robustness and improve accuracy further. Also, comparing cross-domain with same-domain learning might strengthen their point that the rotation task brings little improvement, especially on the flowers dataset, where the images are mostly symmetrical.

Figure 8: Comparison with prior works on mini_ImageNet.

I believe that both the strength and the weakness of this paper lie in its experiments. The experiments compare a variety of self-supervised learning algorithms, which is a good point. However, as the reviewers also pointed out, there are some concerns, including the level of novelty of the work, the way the unlabeled pool is created, and the use of a ResNet-101 pre-trained on ImageNet and mini-ImageNet in the experiments.

Notes

1. Model-Agnostic Meta-Learning (MAML): Neural networks perform very well at many tasks, but they often require large datasets. In contrast, humans are able to learn new skills from few examples. MAML is trained on a collection of different tasks, which play the role of training sets, and is then used to learn new tasks, which play the role of test sets. As a result, MAML is able to perform well on tasks with small training sets without overfitting to the data.[5]

References

[1]: Snell, J., Swersky, K., Zemel, R.: Prototypical networks for few-shot learning. In: NeurIPS (2017)

[2]: Finn, C., Abbeel, P., Levine, S.: Model-agnostic meta-learning for fast adaptation of deep networks. In: ICML (2017)

[3]: Kokkinos, I.: Ubernet: Training a universal convolutional neural network for low-, mid-, and high-level vision using diverse datasets and limited memory. In: CVPR (2017)

[4]: Noroozi, M., Favaro, P.: Unsupervised learning of visual representations by solving jigsaw puzzles. In: ECCV (2016)

[5]: Chelsea Finn, Pieter Abbeel, and Sergey Levine. Model-agnostic meta-learning for fast adaptation of deep networks. arXiv preprint arXiv:1703.03400, 2017.