Model Agnostic Learning of Semantic Features
Presented by
Milad Sikaroudi
Introduction
Transfer learning is a line of research in machine learning that focuses on storing knowledge gained while solving a problem in one domain (the source domain) and applying it to a similar problem in another domain (the target domain). In addition to regular transfer learning, one can use "transfer metric learning", in which a more robust and discriminative data representation is formed by exploiting similarity relationships between samples [1], [2]. However, both kinds of techniques work only insofar as the domain shift between the source and target domains is negligible. Domain shift is the deviation between the distributions of the source domain and the target domain, and it can cause a DNN model to fail completely. Multi-domain learning (MDL) is the solution when the assumption that "the source domain and the target domain come from almost the same distribution" may not hold. Two variants of MDL in the literature are easily confused: domain generalization and domain adaptation. In domain adaptation, some target domain data is accessible during training, while that is not the case in domain generalization. This paper introduces a technique for domain generalization based on two complementary losses that regularize the semantic structure of the feature space through an episodic training scheme originally inspired by model-agnostic meta-learning.
Previous Work
Originating from model-agnostic meta-learning (MAML), episodic training has been widely leveraged for addressing domain generalization [3, 4, 5, 6, 7, 8, 9, 10, 11]. MLDG [4] closely follows MAML in back-propagating the gradients from an ordinary task loss on meta-test data, but its use of the task objective might be sub-optimal, since that objective only uses class probabilities. Most works in the literature [3, 7] lack notable guidance from the semantics of the feature space, which contains crucial domain-independent 'general knowledge' that can be useful for domain generalization. The authors claim that their method is orthogonal to previous works.
Model Agnostic Meta Learning
Also known as "learning to learn", Model-agnostic Meta Learning is a learning paradigm in which optimal initial weights are found incrementally (episodic training) by minimizing a loss function over some similar tasks (meta-train, meta-test sets). Imagine a 4-shot 2-class image classification task as below:
Each of the training tasks provides an improved initial weight for the next round of training. By considering all of these sets of updates together with the meta-test set, the updated weights are calculated using the algorithm below.
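The inner and outer loops described above can be illustrated with a minimal sketch. It is not the paper's setup: it assumes a hypothetical one-parameter regression model `y = w * x` instead of an image classifier, so that the gradient through the inner step can be written by hand. The names `maml_step`, `make_task`, `alpha` (inner step size) and `beta` (meta step size) are illustrative choices.

```python
import random

def loss(w, data):
    """Mean squared error of the 1-D model y = w * x."""
    return sum((w * x - y) ** 2 for x, y in data) / len(data)

def grad(w, data):
    """Derivative of the mean squared error with respect to w."""
    return sum(2 * x * (w * x - y) for x, y in data) / len(data)

def hessian(data):
    """Second derivative of the loss (constant in w for this model)."""
    return sum(2 * x * x for x, _ in data) / len(data)

def maml_step(w, tasks, alpha=0.05, beta=0.05):
    """One meta-update over a batch of (support, query) tasks."""
    meta_grad = 0.0
    for support, query in tasks:
        # Inner loop: adapt to the task with one gradient step.
        w_adapted = w - alpha * grad(w, support)
        # Outer loop: differentiate the query loss through the inner step.
        # For this model, d(w_adapted)/dw = 1 - alpha * hessian(support).
        meta_grad += grad(w_adapted, query) * (1 - alpha * hessian(support))
    return w - beta * meta_grad / len(tasks)

def make_task(slope, rng, n=4):
    """Hypothetical toy task: n noiseless points on the line y = slope * x."""
    def sample():
        return [(x, slope * x) for x in (rng.uniform(-1, 1) for _ in range(n))]
    return sample(), sample()  # (support set, query set)

def post_adaptation_loss(w, tasks, alpha=0.05):
    """Average query loss after one inner step from the shared initialization."""
    return sum(loss(w - alpha * grad(w, s), q) for s, q in tasks) / len(tasks)

rng = random.Random(0)
tasks = [make_task(slope, rng) for slope in (1.5, 2.0, 2.5)]

w = 0.0
before = post_adaptation_loss(w, tasks)
for _ in range(200):
    w = maml_step(w, tasks)
after = post_adaptation_loss(w, tasks)
print(before, after)  # the meta-learned initialization adapts better
```

The quantity being minimized is the loss after adaptation, not the loss at the initialization itself, which is the distinguishing feature of this family of methods.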
Method
In domain generalization, we assume that there are some domain-invariant patterns in the inputs (e.g. semantic features). These features can be extracted to learn a predictor that performs well across seen and unseen domains. This paper assumes that there are inter-class relationships that hold across domains. In total, MASF is composed of a task loss, a global class alignment term, and a local sample clustering term.
Task loss
Let [math]\displaystyle{ F_{\psi}: X \rightarrow Z }[/math] be the feature extractor, where [math]\displaystyle{ Z }[/math] is a feature space, and let [math]\displaystyle{ T_{\theta}: Z \rightarrow \mathbf {R}^{C} }[/math] be the task network, where [math]\displaystyle{ C }[/math] is the number of classes in [math]\displaystyle{ Y }[/math]. The prediction is [math]\displaystyle{ \hat{y}= softmax(T_{\theta}(F_{\psi}(x))) }[/math]. The parameters [math]\displaystyle{ (\psi, \theta) }[/math] are optimized by minimizing a cross-entropy loss [math]\displaystyle{ L_{task} }[/math], formulated as:
[math]\displaystyle{ l_{task}(y, \hat{y}) = - \sum_{c} 1[y=c] \log(\hat{y}_{c}) }[/math]
Although minimizing the task loss yields a decent predictor, nothing prevents the model from overfitting to the source domains and degrading on unseen test domains. The other loss terms address this issue.
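As a concrete reading of the cross-entropy above, the following minimal sketch computes the loss for one sample. The logits stand in for the output of the task network and are made-up numbers; `softmax` and `task_loss` are illustrative helper names.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def task_loss(label, logits):
    """Cross-entropy for one sample: minus the log-probability of the true class."""
    probs = softmax(logits)
    return -math.log(probs[label])

# Hypothetical network output for a 3-class problem.
logits = [2.0, 0.5, -1.0]
print(task_loss(0, logits))  # true class has the largest logit: small loss
print(task_loss(2, logits))  # true class has the smallest logit: larger loss
```

Because the indicator in the sum selects only the true class, the loss depends on a single entry of the predicted distribution, which is why it says little about how the remaining classes relate to each other.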
Model-Agnostic Learning with Episodic Training
The key to the learning procedure is an episodic training scheme, originating from model-agnostic meta-learning, that exposes the model optimization to distribution mismatch. In line with the goal of domain generalization, the model is trained on a sequence of simulated episodes with domain shift. Specifically, at each iteration, the available domains [math]\displaystyle{ D }[/math] are randomly split into sets of meta-train [math]\displaystyle{ D_{tr} }[/math] and meta-test [math]\displaystyle{ D_{te} }[/math] domains. The model is trained to semantically perform well on the held-out [math]\displaystyle{ D_{te} }[/math] after being optimized with one or more steps of gradient descent on the [math]\displaystyle{ D_{tr} }[/math] domains. The feature extractor's and task network's parameters, [math]\displaystyle{ \psi }[/math] and [math]\displaystyle{ \theta }[/math], are first updated using the task-specific supervised loss [math]\displaystyle{ L_{task} }[/math] (e.g. cross-entropy for classification), computed on meta-train:
[math]\displaystyle{ (\psi^{'}, \theta^{'}) = (\psi, \theta) - \alpha \nabla_{\psi, \theta} L_{task}(D_{tr}; \psi, \theta), }[/math]
where [math]\displaystyle{ \alpha }[/math] is the learning rate of this inner step.
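The episode construction and the first update can be sketched as follows. This is only an illustration under stated assumptions: parameters are a flat list of floats, the gradient values are made up, and `split_domains`, `inner_update` and the step size `alpha` are hypothetical names for a plain gradient-descent step on the meta-train task loss.

```python
import random

def split_domains(domains, n_meta_test, rng):
    """Randomly partition the source domains into meta-train and meta-test sets."""
    shuffled = list(domains)
    rng.shuffle(shuffled)
    return shuffled[n_meta_test:], shuffled[:n_meta_test]

def inner_update(params, grads, alpha):
    """One gradient-descent step: params' = params - alpha * grads."""
    return [p - alpha * g for p, g in zip(params, grads)]

rng = random.Random(0)
# Example source domains; a fourth domain would be held out as the unseen target.
source_domains = ["photo", "art_painting", "cartoon"]

for episode in range(3):
    meta_train, meta_test = split_domains(source_domains, 1, rng)
    print(episode, "meta-train:", meta_train, "meta-test:", meta_test)

# Made-up parameters and task-loss gradients computed on meta-train.
params = [0.5, -0.2]
grads = [0.1, -0.4]
updated = inner_update(params, grads, alpha=0.1)
print(updated)
```

Resampling the split at every iteration means each source domain plays the role of a held-out domain many times, which is how the simulated domain shift enters training.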
Global class alignment
In semantic space, we assume there are relationships between class concepts. These relationships are invariant to changes in observation domains. Capturing and preserving such class relationships can help models generalize well on unseen data. To achieve this, a global layout is imposed on the extracted features such that their relative locations reflect their semantic similarity. Since [math]\displaystyle{ L_{task} }[/math] focuses only on the dominant hard label prediction, the inter-class alignment across domains is disregarded. Hence, the symmetrized Kullback–Leibler (KL) divergence across domains, averaged over all [math]\displaystyle{ C }[/math] classes, is minimized:
[math]\displaystyle{ l_{global}(D_{i}, D_{j}; \psi^{'}, \theta^{'}) = \frac{1}{C} \sum_{c=1}^{C} \frac{1}{2}[D_{KL}(s_{c}^{(i)}||s_{c}^{(j)}) + D_{KL}(s_{c}^{(j)}||s_{c}^{(i)})], }[/math]
where [math]\displaystyle{ s_{c}^{(i)} }[/math] is the soft label distribution of class [math]\displaystyle{ c }[/math] computed from domain [math]\displaystyle{ D_{i} }[/math].
The authors stated that other symmetric divergences, such as Jensen–Shannon (JS), showed no significant difference compared with the symmetrized KL divergence.
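The symmetrized KL term can be made concrete with a small sketch. It assumes that each domain has already been summarized by one soft label distribution per class. Here these are produced from made-up per-class logits with a softened softmax, and the temperature `tau` and helper names are assumptions of this example, not details taken from the summary.

```python
import math

def softmax(logits, tau=1.0):
    """Softmax with temperature tau; larger tau gives a softer distribution."""
    m = max(logits)
    exps = [math.exp((v - m) / tau) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def kl(p, q):
    """D_KL(p || q) for two discrete distributions with positive entries."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))

def global_loss(soft_i, soft_j):
    """Symmetrized KL between per-class soft labels, averaged over the C classes."""
    per_class = [0.5 * (kl(si, sj) + kl(sj, si)) for si, sj in zip(soft_i, soft_j)]
    return sum(per_class) / len(per_class)

# Made-up logits: one row per class (C = 3), for two domains.
logits_a = [[3.0, 1.0, 0.0], [0.5, 2.5, 0.0], [0.0, 1.0, 3.0]]
logits_b = [[2.5, 1.5, 0.0], [0.0, 3.0, 0.5], [0.5, 0.5, 2.0]]
soft_a = [softmax(z, tau=2.0) for z in logits_a]
soft_b = [softmax(z, tau=2.0) for z in logits_b]

print(global_loss(soft_a, soft_b))  # positive: the class layouts differ
print(global_loss(soft_a, soft_a))  # zero: identical layouts
```

Unlike the hard-label task loss, this term compares whole distributions, so it penalizes two domains that agree on the top class but disagree on how the other classes rank.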
Local sample clustering
While [math]\displaystyle{ L_{global} }[/math] captures inter-class relationships, we also want semantic features of the same class to lie close to each other locally. Explicit metric learning losses, i.e. contrastive or triplet losses, are used to ensure that the semantic features cluster locally according to class labels only, regardless of the domain. Contrastive loss takes two samples as input and pulls samples of the same class closer while pushing samples of different classes apart.
In contrast, triplet loss takes three samples as input: one anchor, one positive, and one negative. It tries to place the anchor closer to the positive than to the negative by at least a margin.
[math]\displaystyle{ l_{triplet}^{a,p,n} = \sum_{i=1}^{b} \sum_{k=1}^{c-1} \sum_{\ell=1}^{c-1}\! [m\!+\!\|x_{i}\!- \!x_{k}\|_2^2 \!-\! \|x_{i}\!-\!x_{\ell}\|_2^2 ]_+, }[/math]
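The two metric learning losses can be sketched for single pairs and triplets as follows. The embeddings are made-up 2-D points, `margin` plays the role of [math]\displaystyle{ m }[/math], and the contrastive form shown is one common variant (squared distance for same-class pairs, squared hinge for different-class pairs), assumed here for illustration.

```python
import math

def sq_dist(u, v):
    """Squared Euclidean distance between two embeddings."""
    return sum((a - b) ** 2 for a, b in zip(u, v))

def triplet_loss(anchor, positive, negative, margin=1.0):
    """Hinge on squared distances: [m + ||a - p||^2 - ||a - n||^2]_+ ."""
    return max(0.0, margin + sq_dist(anchor, positive) - sq_dist(anchor, negative))

def contrastive_loss(u, v, same_class, margin=1.0):
    """Pull same-class pairs together; push other pairs out to the margin."""
    d = math.sqrt(sq_dist(u, v))
    if same_class:
        return d ** 2
    return max(0.0, margin - d) ** 2

anchor, positive = [0.0, 0.0], [0.1, 0.0]
easy_negative, hard_negative = [2.0, 0.0], [0.5, 0.0]

print(triplet_loss(anchor, positive, easy_negative))  # negative is far: no penalty
print(triplet_loss(anchor, positive, hard_negative))  # negative is close: penalized
print(contrastive_loss(anchor, hard_negative, same_class=False))
```

Neither function looks at the domain a sample came from; only the class relation between the samples matters, which is what makes the clustering domain-independent.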
Model agnostic learning of semantic features
These losses are used in an episodic training scheme shown in the figure below:
The training architecture and three losses are also illustrated as below:
Experiments
The usefulness of the proposed method has been demonstrated using two common benchmark datasets for domain generalization, i.e. VLCS and PACS, alongside a real-world MRI medical imaging segmentation task. In all of their experiments, the AlexNet with ImageNet pre-trained weights has been utilized.
VLCS
VLCS [12] is an aggregation of images from four other datasets: PASCAL VOC2007 (V) [13], LabelMe (L) [14], Caltech (C) [15], and SUN09 (S) [16]. Evaluation uses leave-one-domain-out validation, with each domain randomly divided into 70% training and 30% test.
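One plausible reading of this protocol is sketched below: every domain is split 70/30 once, and then each domain in turn is held out, with training on the training partitions of the remaining domains and testing on the held-out domain's test partition. The integer sample ids, the seed, and the function names are placeholders, and the exact protocol used by the authors may differ in detail.

```python
import random

def train_test_split(samples, train_frac, rng):
    """Shuffle the samples and cut them into train and test partitions."""
    shuffled = list(samples)
    rng.shuffle(shuffled)
    cut = int(round(train_frac * len(shuffled)))
    return shuffled[:cut], shuffled[cut:]

def leave_one_domain_out(data_by_domain, train_frac=0.7, seed=0):
    """Yield (held_out_name, source_train, target_test) for each domain."""
    rng = random.Random(seed)
    splits = {name: train_test_split(samples, train_frac, rng)
              for name, samples in data_by_domain.items()}
    for held_out in data_by_domain:
        # Train on the training partitions of all other domains.
        source_train = {name: splits[name][0]
                        for name in data_by_domain if name != held_out}
        # Evaluate on the test partition of the unseen domain.
        target_test = splits[held_out][1]
        yield held_out, source_train, target_test

# Placeholder data: ten sample ids per domain.
data = {name: list(range(10)) for name in ("V", "L", "C", "S")}
folds = list(leave_one_domain_out(data))
for held_out, source_train, target_test in folds:
    print(held_out, sorted(source_train), len(target_test))
```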
VLCS dataset
Notably, as the table below shows, MASF outperforms MLDG [4] on this dataset, indicating that regularizing the semantic properties of the feature space provides better performance than relying purely on a highly abstracted task loss on meta-test. "DeepAll" in the table is the baseline without domain generalization: only the class labels are used, regardless of the domain each sample comes from.
PACS
A more challenging domain generalization benchmark with a significant domain shift is the PACS dataset [17]. This dataset contains four domains (art painting, cartoon, photo, and sketch) with objects from seven classes: dog, elephant, giraffe, guitar, house, horse, and person.
PACS dataset sample
As the table below shows, MASF significantly outperforms the state-of-the-art methods JiGen [18], MLDG [4], and MetaReg [3]. In addition, the largest improvement (6.20%) is achieved when the unseen domain is "sketch", which requires more general knowledge about semantic concepts since it differs significantly from the other domains.
Ablation study over PACS
The ablation study over the PACS dataset shows the effectiveness of each loss term.
Deeper Architectures
For stronger baseline results, the authors performed additional experiments using deeper residual architectures, namely ResNet-18 and ResNet-50. The table below shows strong and consistent improvements of MASF over the DeepAll baseline in all PACS splits for both network architectures. This suggests that the proposed algorithm is also beneficial for domain generalization with deeper feature extractors.
Multi-site Brain MRI image segmentation
The effectiveness of MASF has also been demonstrated on a segmentation task with MRI images gathered from four different clinical centers, denoted Set-A, Set-B, Set-C, and Set-D. The domain shift in this case occurs due to differences in hardware, acquisition protocols, and many other factors, which hinders the translation of learning-based methods to real clinical practice. The authors attempted to segment the brain images into four classes: background, grey matter, white matter, and cerebrospinal fluid. Tasks such as these have enormous impact on clinical diagnosis and treatment. For example, designing a similar network to distinguish healthy brain tissue from tumorous brain tissue could aid surgeons in brain tumour resection.
MRI dataset
The results showed the effectiveness of MASF compared with not using domain generalization.
Conclusion
A new domain generalization technique was presented that incorporates global and local constraints for learning semantic feature spaces and outperforms the state of the art. The power and effectiveness of this method have been demonstrated using two domain generalization benchmarks and a real clinical dataset (MRI image segmentation). The code is publicly available at [19]. As future work, it would be interesting to integrate the proposed loss functions with other methods, as they are orthogonal to each other, and to evaluate the benefit of doing so. Also, investigating the usage of the current learning procedure in the context of generative models would be an interesting research direction.
Critiques
The purpose of this paper is to guide learning in the semantic feature space by leveraging local similarity. The authors argue that this space may contain essential domain-independent general knowledge for domain generalization. To exploit it, they adopt contrastive and triplet losses to encourage class-wise clustering. Robust semantic features can be learned regardless of the domain by leveraging across-domain class similarity information, which is an important signal during learning. The learner would suffer from indistinct decision boundaries if it could not achieve both separation between classes in the domain-invariant feature space and domain-independent class-specific cohesion. A major problem that may be revealed with large datasets is that these decision boundaries might still be sensitive to the unseen target domain.
References
[1]: Koch, Gregory, Richard Zemel, and Ruslan Salakhutdinov. "Siamese neural networks for one-shot image recognition." ICML deep learning workshop. Vol. 2. 2015.
[2]: Hoffer, Elad, and Nir Ailon. "Deep metric learning using triplet network." International Workshop on Similarity-Based Pattern Recognition. Springer, Cham, 2015.
[3]: Balaji, Yogesh, Swami Sankaranarayanan, and Rama Chellappa. "Metareg: Towards domain generalization using meta-regularization." Advances in Neural Information Processing Systems. 2018.
[4]: Li, Da, et al. "Learning to generalize: Meta-learning for domain generalization." arXiv preprint arXiv:1710.03463 (2017).
[5]: Li, Da, et al. "Episodic training for domain generalization." Proceedings of the IEEE International Conference on Computer Vision. 2019.
[6]: Li, Haoliang, et al. "Domain generalization with adversarial feature learning." Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2018.
[7]: Li, Yiying, et al. "Feature-critic networks for heterogeneous domain generalization." arXiv preprint arXiv:1901.11448 (2019).
[8]: Ghifary, Muhammad, et al. "Domain generalization for object recognition with multi-task autoencoders." Proceedings of the IEEE international conference on computer vision. 2015.
[9]: Li, Ya, et al. "Deep domain generalization via conditional invariant adversarial networks." Proceedings of the European Conference on Computer Vision (ECCV). 2018.
[10]: Motiian, Saeid, et al. "Unified deep supervised domain adaptation and generalization." Proceedings of the IEEE International Conference on Computer Vision. 2017.
[11]: Muandet, Krikamol, David Balduzzi, and Bernhard Schölkopf. "Domain generalization via invariant feature representation." International Conference on Machine Learning. 2013.
[12]: Fang, Chen, Ye Xu, and Daniel N. Rockmore. "Unbiased metric learning: On the utilization of multiple datasets and web images for softening bias." Proceedings of the IEEE International Conference on Computer Vision. 2013.
[13]: Everingham, Mark, et al. "The pascal visual object classes (voc) challenge." International journal of computer vision 88.2 (2010): 303-338.
[14]: Russell, Bryan C., et al. "LabelMe: a database and web-based tool for image annotation." International journal of computer vision 77.1-3 (2008): 157-173.
[15]: Fei-Fei, Li. "Learning generative visual models from few training examples." Workshop on Generative-Model Based Vision, IEEE Proc. CVPR. 2004.
[16]: Chopra, Sumit, Raia Hadsell, and Yann LeCun. "Learning a similarity metric discriminatively, with application to face verification." 2005 IEEE Computer Society Conference on Computer Vision and Pattern Recognition (CVPR'05). Vol. 1. IEEE, 2005.
[17]: Li, Da, et al. "Deeper, broader and artier domain generalization." Proceedings of the IEEE International Conference on Computer Vision. 2017.
[18]: Carlucci, Fabio M., et al. "Domain generalization by solving jigsaw puzzles." Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2019.