Model Agnostic Learning of Semantic Features

From statwiki

Presented by

Milad Sikaroudi

Introduction

Transfer learning is a line of research that focuses on storing knowledge gained in one domain (the source domain) and applying it to a similar problem in another domain (the target domain). Beyond regular transfer learning, "transfer metric learning" uses similarity relationships between samples [1], [2] to learn a metric space in which robust and discriminative data representations are formed. However, both kinds of techniques work well only as long as the domain shift between the source and target domains is negligible. Domain shift is the discrepancy between the data distributions of the source and target domains, and it can cause a deep neural network (DNN) to fail badly on the target domain. Multi-domain learning (MDL) is a remedy when the assumption that the source and target domains come from almost identical distributions does not hold. Two variants of MDL are easily confused: domain adaptation and domain generalization. In domain adaptation, some target-domain data are available during training, whereas in domain generalization they are not. This paper introduces a domain generalization technique based on two complementary losses that regularize the semantic structure of the feature space through an episodic training scheme inspired by model-agnostic meta-learning.

Previous Work

Domain generalization has been widely studied [3, 4, 5, 6, 7, 8, 9, 10, 11], and episodic training, which originated from model-agnostic meta-learning (MAML), has been leveraged by several of these works [3, 4, 5, 7]. Meta-Learning for Domain Generalization (MLDG) [4] closely follows MAML by back-propagating the gradients of an ordinary task loss computed on meta-test data, but relying on the task objective alone may be sub-optimal, since that objective only uses class probabilities. Most existing works (e.g. [3, 7]) lack explicit guidance from the semantics of the feature space, which contains crucial domain-independent "general knowledge" that is useful for domain generalization. The authors claim that their method is orthogonal to previous works.


Model Agnostic Meta Learning

Meta-learning is often described as "learning to learn". Model-agnostic meta-learning (MAML) is a meta-learning paradigm in which a good set of initial weights is found through episodic training, by minimizing a loss over a collection of similar tasks, each split into a meta-train and a meta-test set. Consider, for example, a 4-shot 2-class image classification problem, illustrated below:

p5.png

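For each training task, the shared initial weights are adapted with one or more gradient steps on that task's meta-train set. The loss of the adapted weights is then computed on the task's meta-test set, and the shared initial weights are updated using the gradient of the sum of these meta-test losses over all tasks, as summarized in the algorithm below.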

algo1.PNG
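To make the update concrete, here is a minimal sketch of second-order MAML on a deliberately tiny problem. It assumes a one-parameter linear model y = w * x with a mean-squared-error loss, so every gradient can be written by hand; the task data, the inner and outer learning rates `alpha` and `beta`, and all function names are hypothetical and chosen only for illustration. Practical implementations rely on automatic differentiation instead.

```python
# Toy second-order MAML for a scalar linear model y = w * x (illustrative only).

def mse_grad(w, xs, ys):
    """d/dw of the mean squared error of the prediction w * x."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

def mse_hess(xs):
    """d^2/dw^2 of the mean squared error (constant for a linear model)."""
    return sum(2 * x * x for x in xs) / len(xs)

def maml_step(w, tasks, alpha, beta):
    """One outer update of the shared initialization w.

    Each task is (x_tr, y_tr, x_te, y_te): a meta-train and a meta-test set.
    """
    meta_grad = 0.0
    for x_tr, y_tr, x_te, y_te in tasks:
        # Inner loop: one gradient step on the task's meta-train set.
        w_task = w - alpha * mse_grad(w, x_tr, y_tr)
        # Chain rule through the inner step: dw_task/dw = 1 - alpha * L_tr''(w).
        dwtask_dw = 1 - alpha * mse_hess(x_tr)
        # Gradient of the meta-test loss with respect to the initialization w.
        meta_grad += mse_grad(w_task, x_te, y_te) * dwtask_dw
    return w - beta * meta_grad / len(tasks)

def make_task(slope):
    """A hypothetical task: points on the line y = slope * x."""
    x_tr, x_te = [1.0, 2.0], [0.5, 1.5]
    return (x_tr, [slope * x for x in x_tr], x_te, [slope * x for x in x_te])

tasks = [make_task(1.0), make_task(3.0)]
w = 0.0
for _ in range(200):
    w = maml_step(w, tasks, alpha=0.05, beta=0.1)
print(round(w, 4))  # 2.0
```

With two tasks whose slopes are 1 and 3, the learned initialization settles at 2.0. By symmetry, this is the starting point from which one inner gradient step serves both tasks equally well.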

Method

In domain generalization, we assume that the inputs contain some domain-invariant patterns (e.g. semantic features), which can be extracted to learn a predictor that performs well across both seen and unseen domains. This paper further assumes that inter-class relationships are shared across domains. Overall, MASF (Model-Agnostic learning of Semantic Features) combines three terms: a task loss, a global class alignment term, and a local sample clustering term.

Task loss

Let [math] F_{\psi}: X \rightarrow Z[/math] be a feature extractor, where [math] Z [/math] is a feature space, and let [math] T_{\theta}: Z \rightarrow \mathbf {R}^{C}[/math] be a task network, where [math] C [/math] is the number of classes in [math] Y [/math]. The prediction is [math]\hat{y}= softmax(T_{\theta}(F_{\psi}(x))) [/math]. The parameters [math] (\psi, \theta) [/math] are optimized by minimizing a cross-entropy loss [math] \mathbf{L}_{task} [/math], formulated as:

[math] l_{task}(y, \hat{y}) = - \sum_{c} \mathbf{1}[y=c] \log(\hat{y}_{c})[/math]
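The cross-entropy term can be sketched in a few lines of plain Python. The raw class scores passed in stand for a hypothetical output of the task network applied to the extracted features. The indicator in the formula simply selects the log-probability of the true class, and a log-softmax is used here for numerical stability. The function names are illustrative, not taken from the authors' code.

```python
import math

def log_softmax(logits):
    """Numerically stable log of softmax(logits)."""
    m = max(logits)
    log_norm = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_norm for v in logits]

def task_loss(y, logits):
    """Cross-entropy l_task(y, y_hat) = -sum_c 1[y = c] log(y_hat_c).

    The indicator 1[y = c] keeps only the log-probability of the true class y.
    """
    return -sum(lp for c, lp in enumerate(log_softmax(logits)) if c == y)

# Hypothetical scores T_theta(F_psi(x)) for C = 4 classes, true class y = 2.
print(task_loss(2, [0.0, 0.0, 0.0, 0.0]))  # uniform prediction gives log(4)
print(task_loss(2, [0.0, 0.0, 5.0, 0.0]))  # confident, correct prediction gives a smaller loss
```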

Although minimizing the task loss yields a reasonable predictor, nothing prevents the model from overfitting to the source domains and degrading on unseen test domains. The additional loss terms below address this issue.

Model-Agnostic Learning with Episodic Training

The key to their learning procedure is an episodic training scheme, derived from model-agnostic meta-learning, that exposes the model optimization to distribution mismatch. In line with the goal of domain generalization, the model is trained on a sequence of simulated episodes with domain shift. Specifically, at each iteration, the available source domains [math]D[/math] are randomly split into meta-train domains [math]D_{tr}[/math] and meta-test domains [math]D_{te}[/math]. The model is trained to perform well semantically on the held-out [math]D_{te}[/math] after being optimized with one or more gradient descent steps on the [math]D_{tr}[/math] domains. In MASF, the parameters of the feature extractor and the task network, [math]\psi[/math] and [math]\theta[/math], are first updated using the task-specific supervised loss [math]L_{task}[/math] (e.g. cross-entropy for classification) computed on meta-train:

[math] (\psi^{'}, \theta^{'}) = (\psi, \theta) - \alpha \nabla_{\psi, \theta} L_{task}(D_{tr}; \psi, \theta) [/math]

where [math]\alpha[/math] is the inner-loop learning rate. The global and local semantic losses are then computed with the updated parameters [math](\psi^{'}, \theta^{'})[/math].
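The random per-iteration split of the source domains can be sketched as follows. This is a minimal illustration that assumes a fixed number of meta-test domains per episode (one by default). The function name `split_domains` and the domain names in the example are hypothetical.

```python
import random

def split_domains(domains, n_meta_test=1, rng=random):
    """Randomly partition the source domains for one training episode.

    Returns (meta_train, meta_test) as disjoint lists that together cover all domains.
    """
    shuffled = list(domains)
    rng.shuffle(shuffled)
    return shuffled[n_meta_test:], shuffled[:n_meta_test]

# Hypothetical example: three PACS source domains, with "sketch" held out as the target.
rng = random.Random(0)
for episode in range(3):
    d_tr, d_te = split_domains(["photo", "art_painting", "cartoon"], 1, rng)
    print(episode, d_tr, d_te)
```

Re-drawing the split at every iteration means that each source domain takes turns playing the "unseen" domain, which is what simulates domain shift during training.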

Global class alignment

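In the semantic space, we assume there are relationships between class concepts that are invariant to changes in the observation domain. Capturing and preserving such class relationships can help models generalize well to unseen data. To achieve this, a global layout is imposed on the extracted features such that their relative locations reflect their semantic similarity. Since [math] L_{task} [/math] focuses only on the dominant hard-label prediction, it disregards inter-class alignment across domains. Hence, MASF minimizes the symmetrized Kullback–Leibler (KL) divergence between the soft label distributions of each class in a meta-train domain [math]D_{i}[/math] and a meta-test domain [math]D_{j}[/math], averaged over all [math] C [/math] classes. For class [math]c[/math] in domain [math]D_{i}[/math], the class concept [math]\bar{z}_{c}^{(i)}[/math] is the mean of the features [math]F_{\psi^{'}}(x)[/math] of its samples, and its soft label distribution is [math]s_{c}^{(i)} = softmax(T_{\theta^{'}}(\bar{z}_{c}^{(i)})/\tau)[/math], where [math]\tau[/math] is a temperature: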

[math] l_{global}(D_{i}, D_{j}; \psi^{'}, \theta^{'}) = \frac{1}{C} \sum_{c=1}^{C} \frac{1}{2}\left[D_{KL}(s_{c}^{(i)} \| s_{c}^{(j)}) + D_{KL}(s_{c}^{(j)} \| s_{c}^{(i)})\right], [/math]

The authors report that using other symmetric divergences, such as the Jensen–Shannon (JS) divergence, made no significant difference compared with the symmetrized KL.
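The global alignment term can be sketched in plain Python. The inputs are assumed to be precomputed logits of each class's mean feature in each domain; the temperature-scaled softmax of these logits gives the soft label distributions that enter the symmetrized KL. The function names, the default temperature `tau=2.0`, and the toy numbers are hypothetical and chosen for illustration only.

```python
import math

def softmax(logits, tau=1.0):
    """Temperature-scaled softmax; a larger tau gives a softer distribution."""
    m = max(logits)
    exps = [math.exp((v - m) / tau) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def kl(p, q):
    """KL divergence D_KL(p || q) for strictly positive distributions."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))

def global_alignment_loss(concept_logits_i, concept_logits_j, tau=2.0):
    """Symmetrized KL between per-class soft labels of two domains, averaged over classes.

    concept_logits_x[c] holds the (hypothetical, precomputed) logits of the mean
    feature of class c in domain x.
    """
    num_classes = len(concept_logits_i)
    total = 0.0
    for c in range(num_classes):
        s_i = softmax(concept_logits_i[c], tau)
        s_j = softmax(concept_logits_j[c], tau)
        total += 0.5 * (kl(s_i, s_j) + kl(s_j, s_i))
    return total / num_classes

# Hypothetical concept logits for C = 3 classes in two domains.
domain_i = [[4.0, 1.0, 0.5], [0.5, 3.5, 1.5], [0.2, 1.0, 3.0]]
domain_j = [[3.5, 1.5, 0.2], [1.0, 3.0, 0.5], [0.5, 0.5, 3.5]]
print(global_alignment_loss(domain_i, domain_j))
print(global_alignment_loss(domain_i, domain_i))  # 0.0: identical soft labels
```

The temperature matters here: with `tau` above 1, the non-dominant class probabilities are no longer negligible. Those probabilities encode how each class relates to the others, and they are exactly the information the hard-label task loss ignores.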

Local sample clustering

While [math] L_{global} [/math] captures inter-class relationships, we also want semantic features of the same class to lie close to each other locally. Explicit metric learning losses, i.e. contrastive or triplet losses, are used to ensure that the semantic features cluster locally according to class labels only, regardless of domain. In MASF, these losses are computed on embeddings produced by a small metric-embedding network applied on top of the extracted features. The contrastive loss takes two samples as input and pulls samples of the same class closer together while pushing samples of different classes apart.

contrastive.png
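A common margin-based form of the contrastive loss is sketched below on a single pair of embeddings. The exact variant and margin used by MASF are not specified in this summary, so the squared-distance formulation, the default `margin=1.0`, and the function names are assumptions made for illustration.

```python
import math

def euclidean(a, b):
    """Euclidean distance between two embedding vectors."""
    return math.sqrt(sum((ai - bi) ** 2 for ai, bi in zip(a, b)))

def contrastive_loss(e1, e2, same_class, margin=1.0):
    """Margin-based contrastive loss on one pair of embeddings.

    Same-class pairs are penalized by their squared distance (pulled together);
    different-class pairs are penalized only if closer than `margin` (pushed apart).
    """
    d = euclidean(e1, e2)
    if same_class:
        return d ** 2
    return max(0.0, margin - d) ** 2

print(contrastive_loss([0.0, 0.0], [0.0, 0.5], same_class=True))   # 0.25
print(contrastive_loss([0.0, 0.0], [0.0, 0.5], same_class=False))  # 0.25
print(contrastive_loss([0.0, 0.0], [3.0, 4.0], same_class=False))  # 0.0: already beyond the margin
```

Because pairs are labeled only by class, two samples of the same class from different domains are pulled together, which is what makes the resulting clusters domain-agnostic.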

In contrast, the triplet loss takes three samples as input: an anchor, a positive (of the same class as the anchor), and a negative (of a different class). It encourages the anchor to be closer to the positive than to the negative by at least a margin:

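[math] l_{triplet}^{a,p,n} = \left[m + \|x_{a} - x_{p}\|_2^2 - \|x_{a} - x_{n}\|_2^2\right]_+, [/math]

where [math]x_{a}[/math], [math]x_{p}[/math] and [math]x_{n}[/math] are the embeddings of the anchor, positive and negative samples, [math]m[/math] is the margin, and [math][\cdot]_+ = \max(0, \cdot)[/math]. The local loss averages this term over the triplets formed within a batch.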
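The per-triplet hinge and its average over a batch can be sketched as below. The helper names, the default `margin=1.0`, and the toy 2-D embeddings are hypothetical. Mining triplets from a batch (choosing which anchor, positive, and negative samples to combine) is left out.

```python
def sq_dist(a, b):
    """Squared Euclidean distance between two embedding vectors."""
    return sum((ai - bi) ** 2 for ai, bi in zip(a, b))

def triplet_loss(anchor, positive, negative, margin=1.0):
    """[m + ||x_a - x_p||^2 - ||x_a - x_n||^2]_+ for a single triplet."""
    return max(0.0, margin + sq_dist(anchor, positive) - sq_dist(anchor, negative))

def mean_triplet_loss(triplets, margin=1.0):
    """Average of the triplet loss over (anchor, positive, negative) triplets."""
    return sum(triplet_loss(a, p, n, margin) for a, p, n in triplets) / len(triplets)

# Hypothetical 2-D embeddings: the negative is far enough away in the first
# triplet (zero loss) but too close in the second (positive loss).
easy = ([0.0, 0.0], [0.0, 1.0], [3.0, 0.0])
hard = ([0.0, 0.0], [0.0, 1.0], [1.0, 0.0])
print(triplet_loss(*easy), triplet_loss(*hard))  # 0.0 1.0
print(mean_triplet_loss([easy, hard]))           # 0.5
```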

Model agnostic learning of semantic features

These losses are combined in the episodic training scheme shown in the algorithm below:

algo2.PNG
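The structure of one episode can be illustrated with a deliberately tiny toy. This sketch assumes a single scalar parameter in place of (ψ, θ), quadratic stand-in losses, and central finite differences in place of automatic differentiation. `masf_style_step` and the losses are hypothetical. The point is to show how the meta loss, evaluated at the adapted parameters, is differentiated back through the inner gradient step. The sketch does not distinguish meta-train from meta-test data, and it omits any parameters used only by the local loss.

```python
def num_grad(f, w, eps=1e-4):
    """Central finite-difference derivative (a stand-in for autodiff)."""
    return (f(w + eps) - f(w - eps)) / (2 * eps)

def masf_style_step(w, task_loss_tr, meta_loss, alpha, eta):
    """One toy episode for a single scalar parameter w.

    1. Inner step: w' = w - alpha * dL_task/dw (on meta-train data).
    2. Objective: L_task(w) + L_meta(w'), where L_meta stands in for the weighted
       sum of the global and local losses evaluated at the adapted parameters.
    3. Outer step: differentiate the objective with respect to the ORIGINAL w,
       so the gradient flows back through the inner update.
    """
    def adapted(v):
        return v - alpha * num_grad(task_loss_tr, v)

    def objective(v):
        return task_loss_tr(v) + meta_loss(adapted(v))

    return w - eta * num_grad(objective, w)

def task_loss_tr(w):
    """Hypothetical task loss, minimized at w = 1."""
    return (w - 1.0) ** 2

def meta_loss(w):
    """Hypothetical meta loss, minimized when the adapted parameter equals 3."""
    return (w - 3.0) ** 2

w = 0.0
for _ in range(200):
    w = masf_style_step(w, task_loss_tr, meta_loss, alpha=0.1, eta=0.1)
print(round(w, 4))  # 1.9756 (analytically 6.48 / 3.28)
```

The result lies between the two minimizers: the outer update trades off fitting the task loss directly against making the post-update parameters satisfy the meta loss.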

The training architecture and the three losses are illustrated below:

Ashraf99.png

Experiments

The usefulness of the proposed method is demonstrated on two common domain generalization benchmarks, VLCS and PACS, as well as on a real-world MRI segmentation task. For the object recognition benchmarks, an AlexNet pre-trained on ImageNet is used as the backbone unless stated otherwise (deeper ResNet backbones are evaluated separately).

VLCS

VLCS [12] is an aggregation of images from four other datasets: PASCAL VOC2007 (V) [13], LabelMe (L) [14], Caltech (C) [15], and SUN09 (S) [16]. Leave-one-domain-out validation is used: each domain is held out in turn as the unseen target, and each domain is randomly divided into 70% training and 30% test data.
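The evaluation protocol can be sketched as below. The sketch assumes that models are trained on the 70% training parts of the source domains and evaluated on the 30% test part of the held-out domain. The function name, the seed, and the toy data are hypothetical.

```python
import random

def leave_one_domain_out(domain_samples, train_frac=0.7, seed=0):
    """Yield (target, source_train_sets, target_test_set) for each held-out domain.

    domain_samples maps a domain name to its list of samples. Each domain is split
    once at random into train_frac training data and the remaining test data.
    """
    rng = random.Random(seed)
    splits = {}
    for name, samples in domain_samples.items():
        shuffled = list(samples)
        rng.shuffle(shuffled)
        cut = round(train_frac * len(shuffled))
        splits[name] = (shuffled[:cut], shuffled[cut:])
    for target in domain_samples:
        # Train on the training parts of all other (source) domains.
        source_train = {d: splits[d][0] for d in domain_samples if d != target}
        yield target, source_train, splits[target][1]

# Hypothetical toy data: 10 sample ids per VLCS domain.
data = {d: list(range(10)) for d in ["V", "L", "C", "S"]}
folds = list(leave_one_domain_out(data))
for target, source_train, target_test in folds:
    print(target, sorted(source_train), len(target_test))
```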

Notably, MASF outperforms MLDG [4] on this dataset (table below). This suggests that guiding the meta-test objective with semantic properties of the feature space works better than using only the highly abstracted task loss. In the table, "DeepAll" denotes the baseline without domain generalization: the model is trained on the pooled source data using only the class labels, regardless of the domain each sample comes from.

table1 masf.PNG

PACS

PACS [17] is a more challenging domain generalization benchmark with a significant domain shift. It contains four domains (art painting, cartoon, photo, and sketch) with objects from seven classes: dog, elephant, giraffe, guitar, house, horse, and person.

As the table below shows, MASF significantly outperforms the state-of-the-art methods JiGen [18], MLDG [4], and MetaReg [3]. The largest improvement (6.20%) is achieved when the unseen domain is "sketch". This domain differs markedly from the others and therefore requires more general knowledge about semantic concepts.

Figure2 image.png
t-SNE visualizations of extracted features.


table2 masf.PNG

Ablation study over PACS

The ablation study over the PACS dataset shows the effectiveness of each loss term.

table3 masf.PNG

Deeper Architectures

To obtain stronger baselines, the authors performed additional experiments with deeper residual architectures, ResNet-18 and ResNet-50. The table below shows consistent improvements of MASF over the DeepAll baseline in all PACS splits for both architectures. This suggests that the proposed algorithm also benefits domain generalization with deeper feature extractors.

Paper18 PacResults.PNG

Multi-site Brain MRI image segmentation

MASF was also evaluated on brain MRI segmentation, using images collected from four different clinical centers denoted Set-A, Set-B, Set-C, and Set-D. Here, the domain shift arises from differences in scanner hardware, acquisition protocols, and other factors, and such shift hinders the translation of learning-based methods into real clinical practice. The authors segment the brain images into four classes: background, grey matter, white matter, and cerebrospinal fluid. Such tasks matter for clinical diagnosis and treatment. For example, a similar network that separates healthy from tumorous brain tissue could help surgeons during brain tumour resection.


The results show that MASF is more effective than training without domain generalization.

table5 masf.PNG

Conclusion

The work proposes a new domain generalization technique that takes advantage of global and local constraints to learn semantic feature spaces, and it outperforms state-of-the-art methods on the evaluated benchmarks. Its effectiveness is demonstrated on two domain generalization benchmarks and a real clinical dataset (MRI image segmentation). The code is publicly available at [19]. Since the proposed loss functions are orthogonal to other methods, integrating them with those methods and evaluating the benefit is an interesting direction for future work. Investigating the use of this learning procedure in the context of generative models is another promising research direction.

Critiques

The purpose of this paper is to guide learning in the semantic feature space by leveraging global class relationships and local similarity. The authors argue that this space contains essential domain-independent general knowledge for domain generalization. To exploit it, they adopt contrastive and triplet losses that encourage class-wise clustering. Leveraging across-domain class similarity information, which is valuable during learning, can help extract robust semantic features regardless of domain. If the learner cannot achieve both separation between classes in the domain-invariant feature space and class-specific cohesion across source domains, it suffers from indistinct decision boundaries. A major concern, which may become apparent with large datasets, is that such indistinct decision boundaries might still be sensitive to the unseen target domain.

References

[1]: Koch, Gregory, Richard Zemel, and Ruslan Salakhutdinov. "Siamese neural networks for one-shot image recognition." ICML deep learning workshop. Vol. 2. 2015.

[2]: Hoffer, Elad, and Nir Ailon. "Deep metric learning using triplet network." International Workshop on Similarity-Based Pattern Recognition. Springer, Cham, 2015.

[3]: Balaji, Yogesh, Swami Sankaranarayanan, and Rama Chellappa. "Metareg: Towards domain generalization using meta-regularization." Advances in Neural Information Processing Systems. 2018.

[4]: Li, Da, et al. "Learning to generalize: Meta-learning for domain generalization." arXiv preprint arXiv:1710.03463 (2017).

[5]: Li, Da, et al. "Episodic training for domain generalization." Proceedings of the IEEE International Conference on Computer Vision. 2019.

[6]: Li, Haoliang, et al. "Domain generalization with adversarial feature learning." Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2018.

[7]: Li, Yiying, et al. "Feature-critic networks for heterogeneous domain generalization." arXiv preprint arXiv:1901.11448 (2019).

[8]: Ghifary, Muhammad, et al. "Domain generalization for object recognition with multi-task autoencoders." Proceedings of the IEEE international conference on computer vision. 2015.

[9]: Li, Ya, et al. "Deep domain generalization via conditional invariant adversarial networks." Proceedings of the European Conference on Computer Vision (ECCV). 2018.

[10]: Motiian, Saeid, et al. "Unified deep supervised domain adaptation and generalization." Proceedings of the IEEE International Conference on Computer Vision. 2017.

[11]: Muandet, Krikamol, David Balduzzi, and Bernhard Schölkopf. "Domain generalization via invariant feature representation." International Conference on Machine Learning. 2013.

[12]: Fang, Chen, Ye Xu, and Daniel N. Rockmore. "Unbiased metric learning: On the utilization of multiple datasets and web images for softening bias." Proceedings of the IEEE International Conference on Computer Vision. 2013.

[13]: Everingham, Mark, et al. "The pascal visual object classes (voc) challenge." International journal of computer vision 88.2 (2010): 303-338.

[14]: Russell, Bryan C., et al. "LabelMe: a database and web-based tool for image annotation." International journal of computer vision 77.1-3 (2008): 157-173.

[15]: Fei-Fei, Li, et al. "Learning generative visual models from few training examples." Workshop on Generative-Model Based Vision, IEEE Proc. CVPR. 2004.

[16]: Chopra, Sumit, Raia Hadsell, and Yann LeCun. "Learning a similarity metric discriminatively, with application to face verification." 2005 IEEE Computer Society Conference on Computer Vision and Pattern Recognition (CVPR'05). Vol. 1. IEEE, 2005.

[17]: Li, Da, et al. "Deeper, broader and artier domain generalization." Proceedings of the IEEE International Conference on Computer Vision. 2017.

[18]: Carlucci, Fabio M., et al. "Domain generalization by solving jigsaw puzzles." Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2019.

[19]: https://github.com/biomedia-mira/masf