F18-STAT946-Proposal (statwiki revision http://wiki.math.uwaterloo.ca/statwiki/index.php?title=F18-STAT946-Proposal&diff=42452, 2018-12-14, contributor: Gsahu)
<hr />
<div><br />
'''Project # 0'''<br />
Group members:<br />
<br />
Last name, First name<br />
<br />
Last name, First name<br />
<br />
Last name, First name<br />
<br />
Last name, First name<br />
<br />
'''Title:''' Making a String Telephone<br />
<br />
'''Description:''' We use paper cups to make a string phone and talk with friends while learning about sound waves with this science project. (Explain your project in one or two paragraphs).<br />
<br />
--------------------------------------------------------------------<br />
<br />
'''Project # 1'''<br />
Group members:<br />
<br />
Zhang, Xinyue<br />
<br />
Zhang, Junyi<br />
<br />
Chen, Shala<br />
<br />
'''Title:''' Airbus Ship Detection Challenge<br />
<br />
'''Description:''' The idea and data for this project are taken from https://www.kaggle.com/c/airbus-ship-detection#description. The goal of this project is to build a model that detects all ships in satellite images and places an aligned bounding-box segment around each ship it locates. We will first extract the segmentation maps for the ships, then augment the images and train a simple CNN model to detect them.<br />
<br />
<br />
--------------------------------------------------------------------<br />
<br />
'''Project # 2'''<br />
Group members:<br />
<br />
Nekoei, Hadi<br />
<br />
Afify, Ahmed<br />
<br />
Carrillo, Juan<br />
<br />
Ganapathi Subramanian, Sriram<br />
<br />
'''Title:''' Algorithmic Analysis and Improvements in Multi-Agent Reinforcement Learning in Partially Observable Settings<br />
<br />
'''Description:''' Reinforcement learning (RL) is a branch of machine learning in which an agent learns to act optimally in an environment from weak reward signals, in contrast to the strong labels used in supervised learning. Multi-Agent Reinforcement Learning (MARL) involves multiple agents that may compete against each other or cooperate to achieve a common goal. <br />
<br />
Our project aims to investigate the performance of several state-of-the-art MARL algorithms in playing the game of Pommerman. This game will be used as a benchmark in a competition held at NIPS 2018 (https://www.pommerman.com). We plan to participate and compare the performance of our agents against agents created by other researchers. Our project also aims to make algorithmic improvements to state-of-the-art MARL algorithms and to develop a new algorithm that achieves the best performance in the partially observable multi-agent setting of Pommerman. <br />
<br />
In Pommerman, two competing teams, each with two agents, work to defeat the opposing team. The agents move around the board laying bombs, which eliminate other agents in their horizontal or vertical vicinity when they explode. The agents can collect bonuses such as extra bombs, increased bomb range, or the ability to kick placed bombs. Each agent can choose one of six actions: stop, move up, move left, move down, move right, or lay a bomb. Each agent receives an 11x11 grid of integer values representing the board state, together with additional information such as its own position, the positions of its teammate and enemies, its available bombs, blast strength, kicking ability, and the surrounding walls and bombs. <br />
<br />
The algorithms that we are considering are: <br />
<br />
- Monte Carlo Tree Search and Reinforcement Learning: Combining MCTS with deep neural networks.<br />
<br />
- Multi-Agent Deep Deterministic Policy Gradient (MADDPG): a technique developed by OpenAI that extends Deep Deterministic Policy Gradient (DDPG) to multi-agent settings and outperforms traditional reinforcement learning algorithms (DQN, DDPG, TRPO) in several environments.<br />
<br />
- Opponent Modelling in Deep Reinforcement Learning: a DQN-based method that models opponents through a Deep Reinforcement Opponent Network (DRON).<br />
<br />
We will use convolutional neural networks for pre-processing, extracting features from the raw inputs. All the algorithms we implement will also combine feed-forward deep networks with reinforcement learning frameworks (deep reinforcement learning).<br />
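As a rough illustration of the CNN pre-processing step, the sketch below turns one 11x11 integer board into a stack of binary channels, one per cell value, a common way to feed a categorical grid to a convolutional network. The number of distinct cell codes (<code>NUM_CELL_VALUES = 14</code>) and the function name <code>board_to_channels</code> are assumptions for illustration, not part of the Pommerman API.

```python
import numpy as np

BOARD_SIZE = 11          # Pommerman boards are 11x11 (from the description above)
NUM_CELL_VALUES = 14     # assumption: number of distinct integer cell codes

def board_to_channels(board: np.ndarray, num_values: int = NUM_CELL_VALUES) -> np.ndarray:
    """One-hot encode an integer board of shape (H, W) into (num_values, H, W)."""
    if board.min() < 0 or board.max() >= num_values:
        raise ValueError("board contains a cell code outside [0, num_values)")
    # np.eye(num_values)[board] has shape (H, W, num_values); move channels first.
    return np.eye(num_values, dtype=np.float32)[board].transpose(2, 0, 1)

rng = np.random.default_rng(0)
board = rng.integers(0, NUM_CELL_VALUES, size=(BOARD_SIZE, BOARD_SIZE))
channels = board_to_channels(board)
print(channels.shape)  # (14, 11, 11)
```

Each cell becomes a one-hot vector, so exactly one channel is active per board position; a CNN can then learn spatial patterns per object type instead of treating the integer codes as ordered magnitudes.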
<br />
--------------------------------------------------------------------<br />
<br />
'''Project # 3'''<br />
Group members:<br />
<br />
Fisher, Wesley<br />
<br />
Pafla, Marvin<br />
<br />
Rajendran, Vidyasagar<br />
<br />
'''Title:''' Deep Reinforcement Learning for Angry Birds<br />
<br />
'''Description:''' According to Artificial Intelligence (AI) researchers, AI’s performance in the game Angry Birds will exceed human performance in the next 3-4 years [1]. We propose a final project that will hopefully bring us closer to this goal by developing an AI model based on deep reinforcement learning to play Angry Birds. While AI has been applied to Angry Birds in the past, only a few approaches use deep learning, such as [3]. We plan to implement Yuan et al.’s recommendations by creating an Angry Birds reinforcement learning model with more learning dimensions [3]. To further add novelty to our research, we want to explore the potential of extending our model with evolutionary algorithms [2]. To realize this project, we plan to use an existing implementation of Angry Birds (either https://github.com/estevaofon/angry-birds-python or the one provided for the Angry Birds AI competition, which can be found at https://aibirds.org).<br />
<br />
References:<br />
<br />
[1] Grace, K., Salvatier, J., Dafoe, A., Zhang, B., & Evans, O. (2017). When will AI exceed human performance? Evidence from AI experts. arXiv preprint arXiv:1705.08807.<br />
<br />
[2] Risi, S., & Togelius, J. (2017). Neuroevolution in games: State of the art and open challenges. IEEE Transactions on Computational Intelligence and AI in Games, 9(1), 25-41.<br />
<br />
[3] Yuan, Y., Chen, Z., Wu, P., & Chang, L. Enhancing Deep Reinforcement Learning Agent for Angry Birds. https://aibirds.org/2017/aibirds_BNU.pdf<br />
<br />
--------------------------------------------------------------------<br />
'''Project # 4'''<br />
Group members:<br />
<br />
Heydari, Nargess<br />
<br />
Manuel, Jacob<br />
<br />
Ravi, Aravind<br />
<br />
'''Title:''' Deep Learning for Detection of Steady State Visually Evoked Potentials<br />
<br />
'''Description:''' Brain Computer Interfaces (BCIs) enable users to control an external device by modulating their neuronal activity. Steady state visual evoked potential (SSVEP) based BCIs are of particular interest due to their high information transfer rate (ITR) and the relatively small amount of training they require. SSVEP responses are elicited when a user focuses on a flickering light source and are observed prominently in the occipitoparietal area of the cortex. These responses manifest as an increase in the amplitude of the EEG signal's frequency components at the stimulus frequency and its harmonics. Therefore, by analyzing the dominant frequency components of the EEG signals recorded from the occipitoparietal area, the stimulus the user is attending to can be identified.<br />
<br />
The goal of this project is to identify and compare deep learning architectures for classifying SSVEP responses for use in BCIs. Different architectures will be compared with state-of-the-art classification methods (e.g. Canonical Correlation Analysis) through a sensitivity analysis of their accuracy across multiple BCI variables (e.g. analysis window size, subject variability, size of training data). The goal of this comparison is to establish a new system design that supports the application of deep neural networks in SSVEP-based BCIs. The proposed study will be performed on the SSVEP dataset collected by the eBionics Lab at the University of Waterloo.<br />
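The frequency-domain reasoning above can be illustrated with a simple power-spectrum detector: for each candidate stimulus frequency, sum the spectral power at the fundamental and its first harmonics, then pick the frequency with the largest total. This is a minimal baseline sketch on a synthetic signal, not the CCA method or the deep architectures the project will study; the sampling rate, window length, stimulus frequencies, and number of harmonics are assumptions for illustration.

```python
import numpy as np

def ssvep_power_score(signal, fs, freq, n_harmonics=2):
    """Sum FFT power at freq and its harmonics (nearest FFT bin for each)."""
    spectrum = np.abs(np.fft.rfft(signal)) ** 2
    bins = np.fft.rfftfreq(len(signal), d=1.0 / fs)
    score = 0.0
    for h in range(1, n_harmonics + 1):
        idx = np.argmin(np.abs(bins - h * freq))   # nearest frequency bin
        score += spectrum[idx]
    return score

def detect_stimulus(signal, fs, stimulus_freqs, n_harmonics=2):
    """Return the candidate stimulus frequency with the highest harmonic power."""
    scores = [ssvep_power_score(signal, fs, f, n_harmonics) for f in stimulus_freqs]
    return stimulus_freqs[int(np.argmax(scores))]

# Synthetic single-channel "EEG": a 12 Hz response (plus harmonic) in Gaussian noise.
fs = 256                        # assumed sampling rate in Hz
t = np.arange(0, 4, 1 / fs)     # assumed 4-second analysis window
rng = np.random.default_rng(0)
eeg = (np.sin(2 * np.pi * 12 * t) + 0.5 * np.sin(2 * np.pi * 24 * t)
       + rng.normal(scale=1.0, size=t.size))
print(detect_stimulus(eeg, fs, [8.0, 10.0, 12.0, 15.0]))
```

Shortening the window coarsens the frequency resolution (1 / window length), which is one reason analysis window size is among the BCI variables the study plans to vary.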
<br />
References<br />
<br />
N. S. Kwak, K. R. Müller, and S. W. Lee, “A convolutional neural network for steady state visual evoked potential classification under ambulatory environment,” PLoS One, 2017.<br />
<br />
--------------------------------------------------------------------<br />
'''Project # 5'''<br />
Group members:<br />
<br />
Khan, Salman<br />
<br />
Naik, Abdul<br />
<br />
Koundinya, Shubham<br />
<br />
'''Title:''' Deep Learning for Image Captioning <br />
<br />
'''Description:''' Image captioning is the automatic generation of textual descriptions from images. It involves identifying the contents of an image, understanding relationships between what has been detected and generating textual descriptions.<br />
<br />
It is a challenging task as it includes both Computer Vision and Natural Language Processing components. Furthermore, an image can be described by multiple text statements. We will explore various state-of-the-art translation models focussing primarily on different ways of describing an image. <br />
<br />
<br />
References<br />
<br />
StyleNet: Generating Attractive Visual Captions with Styles - https://ieeexplore.ieee.org/document/8099591<br />
<br />
-------------------------------------------------------------------<br />
<br />
'''Project # 6'''<br />
Group members:<br />
<br />
Ghabussi, Amirpasha<br />
<br />
Kumar, Dhruv<br />
<br />
Sahu, Gaurav<br />
<br />
Khan, Kashif<br />
<br />
'''Title:''' Bi-Directional Attention Flow for Question Answering<br />
<br />
'''Description:''' Question answering is a computer science discipline within the fields of information retrieval and natural language processing, which is concerned with building systems that automatically answer questions posed by humans in a natural language.<br />
<br />
We will try to improve the accuracy of models that have shown promising results on highly active datasets such as SQuAD and MS MARCO.<br />
<br />
References<br />
<br />
[1] Bi-Directional Attention Flow For Machine Comprehension - https://arxiv.org/abs/1611.01603<br />
<br />
[2] QANet: Combining Local Convolution with Global Self-Attention for Reading Comprehension - https://arxiv.org/abs/1804.09541<br />
<br />
--------------------------------------------------------------------<br />
<br />
'''Project # 7'''<br />
Group members:<br />
<br />
Minhas, Manpreet Singh<br />
<br />
Budnarain, Neil<br />
<br />
Ameli, Soroush<br />
<br />
Rezapour, Zahra<br />
<br />
'''Title:''' Select via Proxy: Efficient Data Selection for Training Deep Networks<br />
<br />
'''Description:''' We will participate in the ICLR 2019 Reproducibility Challenge by reproducing this paper. Its abstract reads: "At internet scale, applications collect a tremendous amount of data by logging user events, analyzing text, and collecting images. This data powers a variety of machine learning models for tasks such as image classification, language modeling, content recommendation, and advertising. However, training large models over all available data can be computationally expensive, creating a bottleneck in the development of new machine learning models. In this work, we develop a novel approach to efficiently select a subset of training data to achieve faster training with no loss in model predictive performance. In our approach, we first train a small proxy model quickly, which we then use to estimate the utility of individual training data points, and then select the most informative ones for training the large target model. Extensive experiments show that our approach leads to a 1.6× and 1.8× speed-up on CIFAR10 and SVHN by selecting 60% and 50% subsets of the data, while maintaining the predictive performance of the model trained on the entire dataset. Further, our method is robust to design choices." <br />
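The selection step in the abstract can be sketched as follows: the proxy's predicted class probabilities score each training example, and the highest-scoring fraction is kept for training the target model. Using predictive entropy as the utility score is our assumption (one of several uncertainty measures one could use), and the function names are hypothetical.

```python
import numpy as np

def entropy_scores(probs):
    """Predictive entropy of each row of an (n_samples, n_classes) probability matrix."""
    p = np.clip(probs, 1e-12, 1.0)        # avoid log(0)
    return -(p * np.log(p)).sum(axis=1)

def select_subset(proxy_probs, fraction):
    """Return sorted indices of the `fraction` of examples the proxy is least certain about."""
    n_keep = int(round(fraction * len(proxy_probs)))
    order = np.argsort(-entropy_scores(proxy_probs))   # most uncertain first
    return np.sort(order[:n_keep])

# Toy proxy outputs for 5 examples and 3 classes.
proxy_probs = np.array([
    [0.98, 0.01, 0.01],   # confident
    [0.34, 0.33, 0.33],   # very uncertain
    [0.70, 0.20, 0.10],
    [0.50, 0.50, 0.00],
    [0.90, 0.05, 0.05],
])
print(select_subset(proxy_probs, 0.6))  # [1 2 3]
```

The large target model would then be trained only on the returned indices, which is where the reported speed-ups come from.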
<br />
--------------------------------------------------------------------<br />
<br />
'''Project # 8'''<br />
Group members:<br />
<br />
Bhatt, Neel<br />
<br />
Chen, Henry<br />
<br />
Moosa, Johra Muhammad<br />
<br />
'''Title:''' Fast and Robust Pedestrian Detection: The Successor Of Fused-DNN+Semantic Segmentation<br />
<br />
'''Description:''' Object detection in computer vision and image processing deals with identifying semantic objects such as buildings, cars, or humans in digital images and videos. In particular, pedestrian detection has attracted much research interest in recent years due to its significance in robotics and autonomous driving applications. The accuracy of pedestrian detection algorithms has improved significantly, and much of this progress appears to be driven by breakthroughs in Deep Neural Networks (DNNs) and the availability of open-source pedestrian datasets. The current state-of-the-art model is Fused-DNN with a semantic segmentation mask, which achieves the lowest log-average miss rate (L-AMR), 8.2%, on the Caltech pedestrian dataset [1]. While these advancements are impressive, there is still room for improvement. For example, existing deep pedestrian detection models tend to rely on hand-crafted features and are generally hard to train. In addition, they seem to perform poorly when image quality is reduced or background interference is high. For this reason, we propose to study state-of-the-art pedestrian detection algorithms in depth and ultimately propose an improved DNN model that addresses some of these limitations.<br />
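For reference, the Caltech log-average miss rate is usually computed by sampling the miss rate at nine false-positives-per-image (FPPI) values evenly spaced in log space between 10^-2 and 10^0 and taking their geometric mean. The sketch below assumes the miss-rate/FPPI curve is given as arrays sorted by increasing FPPI and takes the miss rate at the largest FPPI not exceeding each reference point; the fallback to a miss rate of 1.0 when no such point exists, and the function name, are assumptions of this sketch rather than the official evaluation code.

```python
import numpy as np

def log_average_miss_rate(fppi, miss_rate):
    """Geometric mean of the miss rate sampled at 9 FPPI points in [1e-2, 1e0]."""
    fppi = np.asarray(fppi, dtype=float)
    miss_rate = np.asarray(miss_rate, dtype=float)
    refs = np.logspace(-2.0, 0.0, num=9)
    sampled = []
    for r in refs:
        idx = np.nonzero(fppi <= r)[0]
        # Miss rate at the largest FPPI that does not exceed r (assumed 1.0 if none).
        sampled.append(miss_rate[idx[-1]] if idx.size else 1.0)
    sampled = np.maximum(np.array(sampled), 1e-10)   # avoid log(0)
    return float(np.exp(np.mean(np.log(sampled))))

# A toy curve: with a constant miss rate of 8.2%, the L-AMR is also 8.2%.
fppi = np.logspace(-3, 1, num=50)
print(round(log_average_miss_rate(fppi, np.full(50, 0.082)), 3))  # 0.082
```

Because the average is taken in log space, a detector is rewarded for low miss rates across the whole low-FPPI range rather than at a single operating point.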
<br />
Reference:<br />
<br />
[1] Du, Xianzhi, et al. "Fused DNN: A deep neural network fusion approach to fast and robust pedestrian detection." Applications of Computer Vision (WACV), 2017 IEEE Winter Conference on. IEEE, 2017.<br />
--------------------------------------------------------------------<br />
<br />
'''Project # 9'''<br />
Group members:<br />
<br />
Sigeng Chen<br />
<br />
'''Title:''' Humpback Whale Identification<br />
<br />
<br />
'''Description:''' This project is based on an active Kaggle challenge: https://www.kaggle.com/c/humpback-whale-identification/submit<br />
--------------------------------------------------------------------<br />
<br />
'''Project # 10'''<br />
Group Members: Glen Chalatov, Ronnie Feng, Ki Beom Lee, Patrick Li<br />
<br />
'''Title:''' Approximation of Lift-and-Project Methods using Large Hidden Layers; A Comparison to Kernel Methods for Manifold Learning<br />
<br />
'''Description:''' Kernel methods aim to transform features into a higher-dimensional space at reasonable computational cost. Using the kernel trick, linear methods applied in this implicit space can capture nonlinear patterns in the original data. We wish to replicate the behaviour of such methods using neural networks. In contrast to autoencoders, which perform dimensionality reduction, we use the opposite network structure to perform a dimensionality lift (increasing the dimensionality of our data) as follows:<br />
<br />
- Feed d-dimensional data into a neural network.<br />
<br />
- Using hidden layer(s) of dimension p >> d, project the data into a higher-dimensional space.<br />
<br />
- Using a d-dimensional output layer and a loss function that penalizes the difference between output and input, train the network, then take the p-dimensional hidden layer as the lifted feature space.<br />
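A minimal NumPy sketch of this procedure, assuming a single ReLU hidden layer, a linear output layer, mean squared reconstruction error, and plain full-batch gradient descent (all choices of ours; the proposal does not fix them). After training, the p-dimensional hidden activations are returned as the lifted features.

```python
import numpy as np

def train_lifting_net(X, p, lr=0.02, epochs=300, seed=0):
    """Train x -> ReLU(x W1 + b1) W2 + b2 to reconstruct X; return (lift_fn, losses)."""
    rng = np.random.default_rng(seed)
    n, d = X.shape
    W1 = rng.normal(scale=1.0 / np.sqrt(d), size=(d, p)); b1 = np.zeros(p)
    W2 = rng.normal(scale=1.0 / np.sqrt(p), size=(p, d)); b2 = np.zeros(d)
    losses = []
    for _ in range(epochs):
        # Forward pass.
        Z = X @ W1 + b1
        H = np.maximum(Z, 0.0)            # p-dimensional lifted representation
        Y = H @ W2 + b2                   # d-dimensional reconstruction
        err = Y - X
        losses.append(np.mean(np.sum(err ** 2, axis=1)))
        # Backward pass for loss = mean over samples of ||Y - X||^2.
        dY = 2.0 * err / n
        dW2 = H.T @ dY; db2 = dY.sum(axis=0)
        dZ = (dY @ W2.T) * (Z > 0)        # ReLU derivative
        dW1 = X.T @ dZ; db1 = dZ.sum(axis=0)
        W1 -= lr * dW1; b1 -= lr * db1
        W2 -= lr * dW2; b2 -= lr * db2

    def lift(X_new):
        """Map d-dimensional inputs to the trained p-dimensional hidden layer."""
        return np.maximum(X_new @ W1 + b1, 0.0)
    return lift, losses

rng = np.random.default_rng(1)
X = rng.normal(size=(200, 3))                 # d = 3
lift, losses = train_lifting_net(X, p=64)     # p = 64 >> d
features = lift(X)
print(features.shape, losses[0] > losses[-1])  # (200, 64) True
```

Note that with p >> d the network can reconstruct its input without learning anything interesting, which is consistent with the hypothesis that the lifted features are prone to overfitting; regularization or a bottleneck elsewhere may be needed in practice.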
<br />
Our project will analyze and contrast the performance of our lifted feature space under a variety of conditions and applications. Our current hypothesis is that these methods will allow for greater flexibility in pattern recognition, but will be more prone to overfitting. If time allows, we will compare our method against traditional lift-and-project ideas from semidefinite optimization.<br />
<br />
--------------------------------------------------------------------<br />
<br />
'''Project # 11'''<br />
Group Members: Zheng Ma, Jiazhen Chen, Ruijie Zhang, Charupriya Sharma<br />
<br />
'''Title:''' Deep Learning Based Automatic Theorem Prover<br />
<br />
'''Description:''' <br />
<br />
“Formal logic is the science of deduction. It aims to provide systematic means for telling whether or not given conclusions follow from given premises, i.e., whether arguments are valid or invalid” [JEFFREY] <br />
<br />
Automatic theorem provers (ATPs) for first-order logic have been an active research area in mathematics and computer science. In recent years, several efforts have been made to incorporate machine learning into ATPs to boost their performance. Rocktäschel and Riedel [1] implemented a Neural Theorem Prover (NTP), an end-to-end differentiable version of an automated theorem prover. NTPs are differentiable with respect to symbol representations in a knowledge base, which makes it possible to learn representations of symbols in ground atoms and the parameters of first-order rules of predefined structure using backpropagation. S. Loos, et al. [2] incorporated convolutional neural networks into the refutation-based ATP "E", where the network provides heuristics for the ATP in place of human-engineered heuristics. C. Kaliszyk, et al. [3] modified the tableaux-based ATP leanCoP, using reinforcement learning and Monte-Carlo search as its guidance method.<br />
<br />
In our project, we aim to investigate ATPs guided by deep learning. In particular, we plan to adapt C. Kaliszyk's reinforcement learning implementation to build a deep reinforcement learning based ATP, and compare its performance with other contemporary ATPs. We would like to investigate whether and how deep learning can help automated reasoning.<br />
<br />
'''References'''<br />
<br />
[1] T. Rocktäschel, and S. Riedel. "Learning knowledge base inference with neural theorem provers." Proceedings of the 5th Workshop on Automated Knowledge Base Construction. 2016.<br />
<br />
[2] S. Loos, et al. "Deep Network Guided Proof Search." Proceedings of the 21st International Conference on Logic for Programming, Artificial Intelligence and Reasoning. 2017.<br />
<br />
[3] C. Kaliszyk, et al. "Reinforcement Learning of Theorem Proving." NIPS 2018.<br />
--------------------------------------------------------------------<br />
<br />
'''Project # 12'''<br />
Group Members: Travis Bender, Ivan Li, Aileen Li, Xudong Peng<br />
<br />
'''Title:''' Quick, Draw! Doodle Recognition Challenge<br />
<br />
'''Description:''' <br />
<br />
Our project is based on the Kaggle competition https://www.kaggle.com/c/quickdraw-doodle-recognition/leaderboard. The competition involves classifying doodles drawn by users of Quick, Draw!, a game made by Google. Convolutional neural networks are an established class of models that can perform this task effectively. In our project, we plan to explore different network architectures to build a fast and efficient classifier.</div>
Fix your classifier: the marginal value of training the last weight layer (statwiki revision http://wiki.math.uwaterloo.ca/statwiki/index.php?title=Fix_your_classifier:_the_marginal_value_of_training_the_last_weight_layer&diff=41999, 2018-11-30, contributor: Gsahu)
<hr />
<div><br />
The code for the proposed model is available at https://github.com/eladhoffer/fix_your_classifier.<br />
<br />
=Introduction=<br />
<br />
Deep neural networks have become a widely used model for machine learning, achieving state-of-the-art results on many tasks. Their most common use is classification, as when convolutional neural networks (CNNs) classify images into semantic categories. Typically, a learned affine transformation is placed at the end of such models, yielding a per-class value used for classification. This classifier can have a vast number of parameters, a number that grows linearly with the number of possible classes and therefore requires increasingly more computational resources.<br />
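To make the linear growth concrete: an affine classifier on an N-dimensional representation with C classes holds N·C weights plus C biases. The short calculation below uses N = 2048, the feature width of ResNet-50 (an example we chose; any width behaves the same way).

```python
def classifier_params(n_features: int, n_classes: int) -> int:
    """Weights (N x C) plus biases (C) of a final affine classifier."""
    return n_features * n_classes + n_classes

for c in (10, 100, 1000, 10000):
    print(c, classifier_params(2048, c))
# 10 20490
# 100 204900
# 1000 2049000
# 10000 20490000
```

Each tenfold increase in the number of classes multiplies the classifier's size by ten, while the rest of the network is unchanged.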
<br />
=Brief Overview=<br />
<br />
To alleviate this problem, the authors propose fixing the final layer of the classifier (up to a global scale constant). They argue that the method provides significant memory and computational benefits with little or no loss of accuracy on most classification tasks. In addition, they show that initializing the classifier with a Hadamard matrix can also make inference faster.<br />
<br />
=Previous Work=<br />
<br />
Training NN models and using them for inference requires large amounts of memory and computational resources; consequently, much recent research has aimed to reduce the size of networks, for example through:<br />
<br />
* Weight sharing and sparsification (Han et al., 2015)<br />
<br />
* Mixed precision to reduce the size of the neural networks by half (Micikevicius et al., 2017)<br />
<br />
* Low-rank approximations to speed up CNNs (Tai et al., 2015)<br />
<br />
* Quantization of weights, activations and gradients to further reduce computation during training (Hubara et al., 2016b; Li et al., 2016 and Zhou et al., 2016)<br />
<br />
Past work has also shown that predefined (Park & Sandberg, 1991) and random (Huang et al., 2006) projections can be combined with a learned affine transformation to achieve competitive results on many classification tasks. The present paper proposes the reverse: the representation is learned, while the final projection is fixed.<br />
<br />
=Background=<br />
<br />
Convolutional neural networks (CNNs) are commonly used to solve a variety of spatial and temporal tasks. CNNs are usually composed of a stack of parameterized convolutional layers, spatial pooling layers and fully connected layers, separated by non-linear activation functions. Earlier CNN architectures (LeCun et al., 1998; Krizhevsky et al., 2012) used a set of fully-connected layers at the later stages of the network, presumably to allow classification based on global features of an image.<br />
<br />
== Shortcomings of the Final Classification Layer and its Solution ==<br />
<br />
Despite the enormous number of trainable parameters these layers add to the model, they are known to have a rather marginal impact on the final performance of the network (Zeiler & Fergus, 2014).<br />
<br />
It has been shown previously that these layers can easily be compressed and reduced after a model is trained, by simple means such as matrix decomposition and sparsification (Han et al., 2015). Modern architectures are characterized by the removal of most of the fully connected layers (Lin et al., 2013; Szegedy et al., 2015; He et al., 2016), which led to better generalization and overall accuracy together with a huge decrease in the number of trainable parameters. Additionally, numerous works showed that CNNs can be trained in a metric learning regime (Bromley et al., 1994; Schroff et al., 2015; Hoffer & Ailon, 2015), where no explicit classification layer is introduced and the objective concerns only distance measures between intermediate representations. Hardt & Ma (2017) suggested an all-convolutional network variant in which the original initialization of the classification layer was kept fixed, with no negative impact on performance on the CIFAR-10 dataset.<br />
<br />
=Proposed Method=<br />
<br />
The aforementioned works provide evidence that fully-connected layers are in fact redundant and play a small role in learning and generalization. In this work, the authors suggest that the parameters of the final classification transform are completely redundant and can be replaced with a predetermined linear transform. This holds even for large-scale models and classification tasks, such as recent architectures trained on the ImageNet benchmark (Deng et al., 2009).<br />
<br />
==Using a Fixed Classifier==<br />
<br />
Suppose the final representation obtained by the network (the last hidden layer) is <math>x = F(z;\theta)</math>, where <math>F</math> is a deep neural network with input <math>z</math> and parameters <math>\theta</math> (e.g., a convolutional network) trained by backpropagation.<br />
<br />
In common NN models, this representation is followed by an additional affine transformation, <math>y = W^T x + b</math>, where <math>W</math> and <math>b</math> are also trained by backpropagation.<br />
<br />
For an input <math>x</math> of length <math>N</math> and <math>C</math> possible outputs, <math>W</math> must be an <math>N \times C</math> matrix. Training is done with the cross-entropy loss, by feeding the network outputs through a softmax activation:<br />
<br />
<math><br />
v_i = \frac{e^{y_i}}{\sum_{j=1}^{C} e^{y_j}}, \quad i \in \{1, \dots, C\}</math><br />
<br />
and reducing the expected negative log-likelihood with respect to the ground-truth target <math>t \in \{1, \dots, C\}</math>,<br />
by minimizing the loss function:<br />
<br />
<math><br />
L(x, t) = -\log v_t = -w_t \cdot x - b_t + \log\left(\sum_{j=1}^{C} e^{w_j \cdot x + b_j}\right)<br />
</math><br />
<br />
where <math>w_i</math> is the <math>i</math>-th column of <math>W</math>.<br />
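As a quick numerical check of the identity above, the NumPy snippet below computes the loss both as <math>-\log v_t</math> from the softmax probabilities and in the expanded form <math>-w_t \cdot x - b_t + \log \sum_j e^{w_j \cdot x + b_j}</math>. The sizes, the random values, and the 0-based class index are arbitrary choices for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
N, C = 8, 5
x = rng.normal(size=N)            # final representation x = F(z; theta)
W = rng.normal(size=(N, C))       # learned weights; column w_i scores class i
b = rng.normal(size=C)            # learned biases
t = 2                             # ground-truth class index (0-based here)

y = W.T @ x + b                                         # y = W^T x + b
v = np.exp(y - y.max()) / np.exp(y - y.max()).sum()     # softmax, shifted for stability
loss_softmax = -np.log(v[t])

# Expanded form, using a numerically stable log-sum-exp.
m = y.max()
logsumexp = m + np.log(np.sum(np.exp(y - m)))
loss_expanded = -W[:, t] @ x - b[t] + logsumexp

print(np.isclose(loss_softmax, loss_expanded))   # True
```

Subtracting the maximum logit before exponentiating does not change the softmax but prevents overflow for large logits.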
<br />
==Choosing the Projection Matrix==<br />
<br />
To evaluate the conjecture regarding the importance of the final classification transformation, the trainable parameter matrix <math>W</math> is replaced with a fixed orthonormal projection <math>Q \in \mathbb{R}^{N \times C}</math>, such that <math>\forall i \neq j: q_i \cdot q_j = 0</math> and <math>\|q_i\|_2 = 1</math>, where <math>q_i</math> is the <math>i</math>-th column of <math>Q</math>. Such a matrix is obtained by simple random sampling followed by a singular-value decomposition.<br />
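One way to realize "random sampling and singular-value decomposition" is sketched below: draw a Gaussian <math>N \times C</math> matrix and keep the left singular vectors of its thin SVD, which form an <math>N \times C</math> matrix with orthonormal columns (this requires <math>N \ge C</math>). The Gaussian sampling distribution and the function name are our assumptions.

```python
import numpy as np

def random_orthonormal_projection(n_features, n_classes, seed=0):
    """Return Q (N x C) with orthonormal columns; requires N >= C."""
    if n_features < n_classes:
        raise ValueError("need n_features >= n_classes for orthonormal columns")
    rng = np.random.default_rng(seed)
    A = rng.normal(size=(n_features, n_classes))
    U, _, _ = np.linalg.svd(A, full_matrices=False)   # thin SVD: U is N x C
    return U

Q = random_orthonormal_projection(64, 10)
print(Q.shape, np.allclose(Q.T @ Q, np.eye(10)))   # (64, 10) True
```

Q is computed once before training and then stored or regenerated from the seed; it never receives gradient updates.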
<br />
As the columns of the classifier weight matrix <math>Q</math> are fixed to an equal <math>L_{2}</math> norm, the authors find it beneficial<br />
to also restrict the representation <math>x</math> by normalizing it to lie on the unit sphere in <math>\mathbb{R}^{N}</math>:<br />
<br />
<center><math><br />
\hat{x} = \frac{x}{||x||_{2}}<br />
</math></center><br />
<br />
This allows faster training and convergence, as the network does not need to account for changes in the scale of its weights. However, since both <math>q_i</math> and <math>\hat{x}</math> have unit norm, the score <math>q_i \cdot \hat{x}</math> is now bounded between −1 and 1. This causes convergence issues: the softmax function is scale sensitive, and the network cannot rescale its inputs. The issue can be addressed with a fixed scale <math>T</math> applied to the softmax inputs, <math>f(y) = \mathrm{softmax}\left(\frac{1}{T} y\right)</math>, also known as the ''softmax temperature''. However, this introduces an additional hyper-parameter which may differ between networks and datasets. The authors therefore introduce a single learned scalar parameter <math>\alpha</math> for the softmax scale, which effectively acts as the inverse of the softmax temperature, <math>\frac{1}{T}</math>. This is similar to using normalized weights with an additional scale coefficient, except that a single scale is shared by all entries of the weight matrix. The additional vector of bias parameters <math>b \in \mathbb{R}^{C}</math> is kept unchanged, and the model is trained using the traditional negative log-likelihood criterion. Explicitly, the classifier output is now:<br />
<br />
<center><br />
<math><br />
v_i = \frac{e^{\alpha q_i \cdot \hat{x} + b_i}}{\sum_{j=1}^{C} e^{\alpha q_j \cdot \hat{x} + b_j}}, \quad i \in \{1, \dots, C\}<br />
</math><br />
</center><br />
<br />
and the loss to be minimized is:<br />
<br />
<center><math><br />
L(x, t) = -\alpha q_t \cdot \frac{x}{\|x\|_{2}} - b_t + \log\left(\sum_{i=1}^{C} \exp\left(\alpha q_i \cdot \frac{x}{\|x\|_{2}} + b_i\right)\right)<br />
</math></center><br />
<br />
where <math>x</math> is the final representation obtained by the network for a specific sample, and <math>t \in \{1, \dots, C\}</math> is the ground-truth label for that sample. The evolution of the parameter <math>\alpha</math> during training is shown in<br />
[[Media: figure1_log_behave.png| Figure 1]]: it grows roughly logarithmically, mirroring the behaviour of the norm of a learned classifier.<br />
<br />
<center>[[File:figure1_log_behave.png]]</center><br />
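The forward pass and loss of the fixed classifier can be sketched in NumPy as follows. Q stays fixed; only the scalar <math>\alpha</math> and the bias <math>b</math> would be learned, and the sketch computes their gradients and takes one plain gradient step on <math>\alpha</math> alone. Building Q with a QR decomposition instead of an SVD, and all function and variable names, are our choices; this is not the authors' released code, which is linked at the top of this page.

```python
import numpy as np

def fixed_classifier_loss(x, t, Q, alpha, b):
    """Return (loss, probs, dloss/dalpha, dloss/db) for the fixed-Q classifier."""
    x_hat = x / np.linalg.norm(x)          # project onto the unit sphere
    s = Q.T @ x_hat                        # scores q_i . x_hat, each in [-1, 1]
    logits = alpha * s + b
    m = logits.max()
    log_z = m + np.log(np.sum(np.exp(logits - m)))   # stable log-sum-exp
    v = np.exp(logits - log_z)             # softmax probabilities v_i
    loss = -logits[t] + log_z              # = -alpha q_t.x_hat - b_t + log sum exp(...)
    grad_alpha = -s[t] + v @ s             # d loss / d alpha
    grad_b = v.copy()
    grad_b[t] -= 1.0                       # d loss / d b = v - onehot(t)
    return loss, v, grad_alpha, grad_b

rng = np.random.default_rng(0)
N, C = 32, 10
Q, _ = np.linalg.qr(rng.normal(size=(N, C)))   # fixed N x C matrix, orthonormal columns
x = rng.normal(size=N)
t = int(np.argmax(Q.T @ x))                    # use the class x already scores highest on
alpha, b = 1.0, np.zeros(C)

loss0, _, g_alpha, _ = fixed_classifier_loss(x, t, Q, alpha, b)
alpha -= 0.5 * g_alpha                         # one gradient step on alpha only
loss1, *_ = fixed_classifier_loss(x, t, Q, alpha, b)
print(loss1 < loss0)                           # True: a larger alpha sharpens the softmax
```

When the correct class already has the highest score, the gradient with respect to <math>\alpha</math> is negative, so training pushes <math>\alpha</math> upward, which matches the growing scale described above.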
<br />
Since <math>-1 \le q_i \cdot \hat{x} \le 1</math>, a cosine angle loss is another possible choice:<br />
<br />
<center>[[File:caloss.png]]</center><br />
<br />
However, this loss yields a slight decrease in final validation accuracy compared to the original models.<br />
<br />
==Using a Hadamard Matrix==<br />
<br />
Recall that a Hadamard matrix (Hedayat et al., 1978) <math>H</math> is an <math>n \times n</math> matrix whose entries are all either +1 or −1.<br />
Furthermore, the rows of <math>H</math> are mutually orthogonal, so that <math>HH^{T} = nI_n</math>, where <math>I_n</math> is the identity matrix. Instead of the entire Hadamard matrix <math>H</math>, the final classification layer uses a truncated version <math>\hat{H} \in \{-1, 1\}^{C \times N}</math> whose <math>C</math> rows are all orthogonal, such that:<br />
<br />
<center><math><br />
y = \hat{H} \hat{x} + b<br />
</math></center><br />
<br />
This usage offers two main benefits:<br />
* A deterministic, low-memory and easily generated matrix that can be used for classification.<br />
* No need for a full matrix-matrix multiplication, since multiplying by a Hadamard matrix reduces to sign manipulation and addition.<br />
<br />
Here, <math>n</math> must be a multiple of 4 (apart from the trivial orders 1 and 2), but the matrix can easily be truncated to fit the dimensions of typical networks. Also, since the fixed classifier weights need only 1-bit precision, attention can shift to the features that precede the classifier.<br />
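A small sketch of the Hadamard variant, assuming the Sylvester construction (which only produces orders that are powers of 2) and truncation to the first <math>C</math> rows of an <math>N \times N</math> matrix, which keeps the rows orthogonal. The sizes <math>N = 16</math>, <math>C = 10</math> and the function names are our choices. The last function computes <math>\hat{H}\hat{x} + b</math> with additions and subtractions only, illustrating why no multiplications are needed.

```python
import numpy as np

def sylvester_hadamard(n):
    """Hadamard matrix of order n (n a power of 2) via Sylvester's construction."""
    if n < 1 or n & (n - 1):
        raise ValueError("Sylvester construction needs n to be a power of 2")
    H = np.array([[1]], dtype=np.int8)     # +/-1 entries fit in 8 bits
    while H.shape[0] < n:
        H = np.block([[H, H], [H, -H]])
    return H

def truncated_hadamard(n_classes, n_features):
    """C x N matrix of +/-1 entries with mutually orthogonal rows (needs C <= N)."""
    return sylvester_hadamard(n_features)[:n_classes]

def hadamard_logits(H_hat, x_hat, b):
    """Compute H_hat @ x_hat + b using only additions and subtractions."""
    out = np.empty(H_hat.shape[0])
    for i, row in enumerate(H_hat):
        out[i] = x_hat[row > 0].sum() - x_hat[row < 0].sum() + b[i]
    return out

H_hat = truncated_hadamard(10, 16)
Hi = H_hat.astype(np.int64)                # widen before multiplying to avoid int8 overflow
print(np.array_equal(Hi @ Hi.T, 16 * np.eye(10, dtype=np.int64)))  # True: rows orthogonal

rng = np.random.default_rng(0)
x = rng.normal(size=16)
x_hat = x / np.linalg.norm(x)
b = np.zeros(10)
print(np.allclose(hadamard_logits(H_hat, x_hat, b), H_hat @ x_hat + b))  # True
```

In practice a fast Walsh-Hadamard transform would be used instead of the explicit row loop; the loop is kept here only to make the sign-and-add structure visible.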
<br />
=Experimental Results=<br />
<br />
The authors have evaluated their proposed model on the following datasets:<br />
<br />
==CIFAR-10/100==<br />
<br />
===About the Dataset===<br />
<br />
CIFAR-10 is an image classification benchmark dataset containing 50,000 training images and 10,000 test images. The images are in color and contain 32×32 pixels. There are 10 possible classes of various animals and vehicles. CIFAR-100 holds the same number of images of the same size, but spans 100 different classes.<br />
<br />
===Training Details===<br />
<br />
The authors trained a 56-layer residual network (He et al., 2016) on the CIFAR-10 dataset, using the same hyper-parameters as in the original work. A comparison of the two variants, i.e., the learned classifier and the proposed classifier with a fixed transformation, is shown in [[Media: figure1_resnet_cifar10.png | Figure 2]].<br />
<br />
<center>[[File: figure1_resnet_cifar10.png]]</center><br />
<br />
These results demonstrate that although the training error is considerably lower for the network with a learned classifier, both models achieve the same classification accuracy on the validation set. The authors conjecture that, with the new fixed parameterization, the network can no longer increase the<br />
norm of a given sample’s representation, so fitting its label requires more effort. As this happens only for specific samples seen during training, it affects only the training error.<br />
<br />
The authors also compared a fixed scale <math>\alpha</math> at different values against the learned parameter. Results for <math>\alpha \in \{0.1, 1, 10\}</math> are depicted in [[Media: figure3_alpha_resnet_cifar.png| Figure 3]] for both training and validation error. As can be seen, similar validation accuracy can be obtained with a fixed scale value (here <math>\alpha</math> = 1 or 10 suffices), at the cost of another hyper-parameter to tune. In all further experiments, the scaling parameter <math>\alpha</math> was regularized with the same weight decay coefficient used on the original classifier. Although learning the scale is not necessary, it helps convergence during training.<br />
<br />
<center>[[File: figure3_alpha_resnet_cifar.png]]</center><br />
<br />
The authors then trained on the CIFAR-100 dataset, using the DenseNet-BC model of Huang et al. (2017) with a depth of 100 layers and growth rate k = 12. With ten times as many classes, the classifier grows accordingly and accounts for about 4% of the model's parameters. Nevertheless, the fixed-classifier model achieved the same validation accuracy as the original model, and its training curves showed the same behavior as before.<br />
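<br />
The 4% figure can be roughly checked with our own calculation, which is not from the paper. It assumes the standard DenseNet-BC configuration: 16 bottleneck layers per dense block, an initial width of 2k = 24, and transitions that halve the width. It also uses the roughly 0.8M parameters that Huang et al. (2017) report for this model. The final feature width <math>N</math> and the classifier size are then:<br />
<br />
<center><math><br />
N = \frac{\frac{24 + 16 \cdot 12}{2} + 16 \cdot 12}{2} + 16 \cdot 12 = 342, \qquad \frac{N \times C + C}{0.8 \times 10^{6}} = \frac{342 \times 100 + 100}{0.8 \times 10^{6}} \approx 4.3\%<br />
</math></center><br />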
<br />
==IMAGENET==<br />
<br />
===About the Dataset===<br />
<br />
The ImageNet dataset introduced by Deng et al. (2009) spans 1,000 visual classes and over 1.2 million samples. It is generally considered more challenging than CIFAR-10/100.<br />
<br />
===Experiment Details===<br />
<br />
The authors evaluated their fixed-classifier method on ImageNet using ResNet50 (He et al., 2016) and DenseNet169 (Huang et al., 2017), with the settings described in the original works. Fixing the classifier removed approximately 2 million parameters from each model, accounting for about 8% and 12% of the model parameters, respectively. The experiments revealed trends similar to those observed on CIFAR-10.<br />
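<br />
The percentages above can be checked with a quick calculation. The sketch is illustrative only. The feature widths (2048 for ResNet50, 1664 for DenseNet169) and the approximate totals (about 25.6M and 14.1M parameters) are assumed from the common torchvision implementations rather than taken from the paper. `classifier_params` is a hypothetical helper.<br />
<br />
```python
def classifier_params(n_features: int, n_classes: int, bias: bool = True) -> int:
    """Parameters in a final affine layer y = W^T x + b, with W of shape (N, C)."""
    return n_features * n_classes + (n_classes if bias else 0)


# Assumed (approximate) reference figures, not taken from the paper.
MODELS = {
    "ResNet50": {"features": 2048, "total": 25.6e6},
    "DenseNet169": {"features": 1664, "total": 14.1e6},
}

for name, spec in MODELS.items():
    fc = classifier_params(spec["features"], 1000)  # ImageNet has 1,000 classes
    share = 100 * fc / spec["total"]                # percentage of the whole model
    print(f"{name}: {fc:,} classifier parameters, about {share:.0f}% of the model")
```
<br />
This prints about 8% for ResNet50 and about 12% for DenseNet169, in line with the figures quoted above.<br />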
<br />
For a more stringent evaluation, the authors also trained a ShuffleNet architecture (Zhang et al., 2017b). ShuffleNet was designed for low-memory, limited-compute platforms, and its final classification layer holds the majority of its parameters. Fixing the classifier reduced the model's parameters to 0.86 million, compared with 0.96 million parameters in the final layer of the original model alone. Again, the modified model converged to a similar validation accuracy.<br />
<br />
The overall results of the fixed-classifier are summarized in [[Media: table1_fixed_results.png | Table 1]].<br />
<br />
<center>[[File: table1_fixed_results.png]]</center><br />
<br />
==Language Modelling==<br />
<br />
Recent works have empirically found that using the same weights for both the word embedding and the classifier can yield equal or better results than using a separate pair of weights (Inan et al., 2016; Press & Wolf, 2017). The authors therefore experimented with fixed classifiers on language modelling, which likewise requires classification over all tokens in the task vocabulary. They trained a recurrent model with two LSTM layers (Hochreiter & Schmidhuber, 1997) and an embedding and hidden size of 512 on the WikiText2 dataset (Merity et al., 2016), using the same settings as Merity et al. (2017). WikiText2 contains about 33K distinct words, so the embedding and classifier layers together were expected to hold about 34 million parameters. This is about 89% of the 38 million parameters of the whole model. However, using a random orthogonal transform yielded poor results compared with a learned embedding. The authors suspected this is because the embedding layer of a language model captures semantic relationships between words, which is not the case in image classification. This intuition was further supported by much better results when pre-trained embeddings were used, obtained either with the word2vec algorithm of Mikolov et al. (2013) or with PMI factorization as suggested by Levy & Goldberg (2014).<br />
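<br />
The parameter count quoted above can be reproduced approximately. This assumes a vocabulary × 512 matrix for both the embedding and the classifier and ignores biases:<br />
<br />
<center><math><br />
2 \times 33{,}000 \times 512 \approx 33.8 \times 10^{6} \approx 34\text{M}, \qquad \frac{34\text{M}}{38\text{M}} \approx 0.89<br />
</math></center><br />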
<br />
=Discussion=<br />
<br />
==Implications and Use Cases==<br />
<br />
As the number of classes in benchmark datasets grows, so do the computational demands of the final classifier. To illustrate the problem, the authors consider the work of Sun et al. (2017), which introduced JFT-300M, an internal Google dataset with over 18K classes. For a ResNet50 (He et al., 2016) with a 2048-dimensional representation, the final classification layer alone holds over 36M parameters, meaning that over 60% of the model's parameters reside in that layer. Sun et al. (2017) also describe the difficulty of distributing so many parameters across the training servers, which adds non-trivial overhead when synchronizing model updates. The authors argue that a fixed classifier would help considerably in this scenario, since it removes the need for any gradient synchronization of the final layer. Furthermore, using a Hadamard matrix removes the need to store the transformation at all, allowing considerable memory and computational savings.<br />
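<br />
The size of that layer follows from the dimensions quoted above. The fraction additionally assumes a ResNet50 backbone of about 23.5M parameters without its classification layer. That figure is typical of common implementations but is not from the paper:<br />
<br />
<center><math><br />
18{,}000 \times 2{,}048 \approx 36.9\text{M}, \qquad \frac{36.9\text{M}}{36.9\text{M} + 23.5\text{M}} \approx 0.61<br />
</math></center><br />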
<br />
==Possible Caveats==<br />
<br />
The good performance of fixed classifiers relies on the ability of the preceding layers to learn separable representations. This could be affected when the ratio between the number of learned features and the number of classes is small, that is, when <math> C > N</math>. However, the authors tested their method in such cases and report that it still performed well.<br />
Another factor that can affect performance is high correlation between classes. A fixed classifier cannot account for correlated classes, so the network may have difficulty learning. In language modelling, word classes tend to be highly correlated, which also makes learning difficult.<br />
<br />
Also, the approach eliminates only the computation of the classifier weights, so when there are few classes, the savings will be small.<br />
<br />
==Future Work==<br />
<br />
<br />
The use of fixed classifiers might be further simplified in Binarized Neural Networks (Hubara et al., 2016a), where activations and weights are restricted to ±1 during propagation. In that case, the norm of the last hidden layer would be constant for all samples (equal to the square root of the hidden layer width). This constant could then be absorbed into the scale constant <math>\alpha</math>, and no per-sample normalization would be needed.<br />
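<br />
The argument can be written out directly. For a binarized hidden layer of width <math>N</math>, every entry squares to 1, so the normalization reduces to a constant factor <math>1/\sqrt{N}</math> that can be folded into <math>\alpha</math>:<br />
<br />
<center><math><br />
x \in \{-1, +1\}^{N} \;\Rightarrow\; \|x\|_{2} = \sqrt{\sum_{i=1}^{N} x_i^{2}} = \sqrt{N}, \qquad \alpha\, q_i \cdot \hat{x} = \frac{\alpha}{\sqrt{N}}\, q_i \cdot x<br />
</math></center><br />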
<br />
Additionally, more efficient ways to learn word embeddings should be explored, since a similar redundancy in classifier weights may suggest simpler forms of token representation, such as low-rank or sparse versions.<br />
<br />
A related paper (Rosenfeld & Tsotsos, 2018) claims that fixing most of the parameters of a neural network achieves results comparable to learning all of them.<br />
<br />
=Conclusion=<br />
<br />
In this work, the authors argue that the final classification layer in deep neural networks is redundant and suggest removing its trainable parameters. Empirical results on the CIFAR and ImageNet datasets suggest that such a change leads to little or no decline in performance. Furthermore, using a Hadamard matrix as the classifier may yield computational benefits when properly implemented, and saves the memory otherwise spent on a large number of transformation coefficients.<br />
<br />
Another direction for future research is to find efficient methods for creating pre-defined word embeddings. Such embeddings otherwise require a huge number of parameters, which could possibly be avoided when learning a new task. More generally, more emphasis should be placed on the representations learned by the non-linear parts of neural networks up to the final classifier, as the classifier itself seems highly redundant.<br />
<br />
=Critique=<br />
<br />
The paper proposes an interesting idea with potential use in designing memory-efficient neural networks. The experiments are quite rigorous and support the authors' claims. However, it would have been helpful if the authors had described in more detail how to implement the Hadamard classifier efficiently and how to scale the method to tasks with many classes (cases with <math> C > N</math>).<br />
<br />
=References=<br />
<br />
Madhu S Advani and Andrew M Saxe. High-dimensional dynamics of generalization error in neural networks. arXiv preprint arXiv:1710.03667, 2017.<br />
<br />
Peter Bartlett, Dylan J Foster, and Matus Telgarsky. Spectrally-normalized margin bounds for neural networks. arXiv preprint arXiv:1706.08498, 2017.<br />
<br />
Jane Bromley, Isabelle Guyon, Yann LeCun, Eduard Sackinger, and Roopak Shah. Signature verification using a "siamese" time delay neural network. In Advances in Neural Information Processing Systems, pp. 737–744, 1994.<br />
<br />
Matthieu Courbariaux, Yoshua Bengio, and Jean-Pierre David. Binaryconnect: Training deep neural networks with binary weights during propagations. In Advances in Neural Information Processing Systems, pp. 3123–3131, 2015.<br />
<br />
Jia Deng, Wei Dong, Richard Socher, Li-Jia Li, Kai Li, and Li Fei-Fei. Imagenet: A large-scale hierarchical image database. In Computer Vision and Pattern Recognition, 2009. CVPR 2009. IEEE Conference on, pp. 248–255. IEEE, 2009.<br />
<br />
Suriya Gunasekar, Blake Woodworth, Srinadh Bhojanapalli, Behnam Neyshabur, and Nathan Srebro. Implicit regularization in matrix factorization. arXiv preprint arXiv:1705.09280, 2017.<br />
<br />
Song Han, Huizi Mao, and William J Dally. Deep compression: Compressing deep neural networks with pruning, trained quantization and huffman coding. arXiv preprint arXiv:1510.00149, 2015.<br />
<br />
Moritz Hardt and Tengyu Ma. Identity matters in deep learning. 2017.<br />
<br />
Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image recognition. In Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 770–778, 2016.<br />
<br />
A Hedayat, WD Wallis, et al. Hadamard matrices and their applications. The Annals of Statistics, 6(6):1184–1238, 1978.<br />
<br />
Sepp Hochreiter and Jürgen Schmidhuber. Long short-term memory. Neural Computation, 9(8):1735–1780, 1997.<br />
<br />
Elad Hoffer and Nir Ailon. Deep metric learning using triplet network. In International Workshop on Similarity-Based Pattern Recognition, pp. 84–92. Springer, 2015.<br />
<br />
Elad Hoffer, Itay Hubara, and Daniel Soudry. Train longer, generalize better: closing the generalization gap in large batch training of neural networks. 2017.<br />
<br />
Andrew G Howard, Menglong Zhu, Bo Chen, Dmitry Kalenichenko, Weijun Wang, Tobias Weyand, Marco Andreetto, and Hartwig Adam. Mobilenets: Efficient convolutional neural networks for mobile vision applications. arXiv preprint arXiv:1704.04861, 2017.<br />
<br />
Gao Huang, Zhuang Liu, Laurens van der Maaten, and Kilian Q Weinberger. Densely connected convolutional networks. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, 2017.<br />
<br />
Guang-Bin Huang, Qin-Yu Zhu, and Chee-Kheong Siew. Extreme learning machine: theory and applications. Neurocomputing, 70(1):489–501, 2006.<br />
<br />
Itay Hubara, Matthieu Courbariaux, Daniel Soudry, Ran El-Yaniv, and Yoshua Bengio. Binarized neural networks. In Advances in Neural Information Processing Systems 29 (NIPS’16), 2016a.<br />
<br />
Itay Hubara, Matthieu Courbariaux, Daniel Soudry, Ran El-Yaniv, and Yoshua Bengio. Quantized neural networks: Training neural networks with low precision weights and activations. arXiv preprint arXiv:1609.07061, 2016b.<br />
<br />
Hakan Inan, Khashayar Khosravi, and Richard Socher. Tying word vectors and word classifiers: A loss framework for language modeling. arXiv preprint arXiv:1611.01462, 2016.<br />
<br />
Max Jaderberg, Andrea Vedaldi, and Andrew Zisserman. Speeding up convolutional neural networks with low rank expansions. arXiv preprint arXiv:1405.3866, 2014.<br />
<br />
Alex Krizhevsky. Learning multiple layers of features from tiny images. 2009.<br />
<br />
Alex Krizhevsky, Ilya Sutskever, and Geoffrey E Hinton. Imagenet classification with deep convolutional neural networks. In Advances in neural information processing systems, pp. 1097–1105, 2012.<br />
<br />
Yann LeCun, Léon Bottou, Yoshua Bengio, and Patrick Haffner. Gradient-based learning applied to document recognition. Proceedings of the IEEE, 86(11):2278–2324, 1998.<br />
<br />
Omer Levy and Yoav Goldberg. Neural word embedding as implicit matrix factorization. In Advances in neural information processing systems, pp. 2177–2185, 2014.<br />
<br />
Fengfu Li, Bo Zhang, and Bin Liu. Ternary weight networks. arXiv preprint arXiv:1605.04711, 2016.<br />
<br />
Min Lin, Qiang Chen, and Shuicheng Yan. Network in network. arXiv preprint arXiv:1312.4400, 2013.<br />
<br />
Stephen Merity, Caiming Xiong, James Bradbury, and Richard Socher. Pointer sentinel mixture models. arXiv preprint arXiv:1609.07843, 2016.<br />
<br />
Stephen Merity, Nitish Shirish Keskar, and Richard Socher. Regularizing and Optimizing LSTM Language Models. arXiv preprint arXiv:1708.02182, 2017.<br />
<br />
Paulius Micikevicius, Sharan Narang, Jonah Alben, Gregory Diamos, Erich Elsen, David Garcia, Boris Ginsburg, Michael Houston, Oleksii Kuchaev, Ganesh Venkatesh, et al. Mixed precision training. arXiv preprint arXiv:1710.03740, 2017.<br />
<br />
Tomas Mikolov, Ilya Sutskever, Kai Chen, Greg S Corrado, and Jeff Dean. Distributed representations of words and phrases and their compositionality. In Advances in neural information processing systems, pp. 3111–3119, 2013.<br />
<br />
Behnam Neyshabur, Srinadh Bhojanapalli, David McAllester, and Nathan Srebro. Exploring generalization in deep learning. arXiv preprint arXiv:1706.08947, 2017.<br />
<br />
Jooyoung Park and Irwin W Sandberg. Universal approximation using radial-basis-function networks. Neural computation, 3(2):246–257, 1991.<br />
<br />
Ofir Press and Lior Wolf. Using the output embedding to improve language models. EACL 2017, pp. 157, 2017.<br />
<br />
Itay Safran and Ohad Shamir. On the quality of the initial basin in overspecified neural networks. In International Conference on Machine Learning, pp. 774–782, 2016.<br />
<br />
Tim Salimans and Diederik P Kingma. Weight normalization: A simple reparameterization to accelerate training of deep neural networks. In Advances in Neural Information Processing Systems, pp. 901–909, 2016.<br />
<br />
Florian Schroff, Dmitry Kalenichenko, and James Philbin. Facenet: A unified embedding for face recognition and clustering. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 815–823, 2015.<br />
<br />
Mahdi Soltanolkotabi, Adel Javanmard, and Jason D Lee. Theoretical insights into the optimization landscape of over-parameterized shallow neural networks. arXiv preprint arXiv:1707.04926, 2017.<br />
<br />
Daniel Soudry and Yair Carmon. No bad local minima: Data independent training error guarantees for multilayer neural networks. arXiv preprint arXiv:1605.08361, 2016.<br />
<br />
Daniel Soudry and Elad Hoffer. Exponentially vanishing sub-optimal local minima in multilayer neural networks. arXiv preprint arXiv:1702.05777, 2017.<br />
<br />
Daniel Soudry, Elad Hoffer, and Nathan Srebro. The implicit bias of gradient descent on separable data. 2018.<br />
<br />
Jost Tobias Springenberg, Alexey Dosovitskiy, Thomas Brox, and Martin Riedmiller. Striving for simplicity: The all convolutional net. arXiv preprint arXiv:1412.6806, 2014.<br />
<br />
Chen Sun, Abhinav Shrivastava, Saurabh Singh, and Abhinav Gupta. Revisiting unreasonable effectiveness of data in deep learning era. arXiv preprint arXiv:1707.02968, 2017.<br />
<br />
Christian Szegedy, Wei Liu, Yangqing Jia, Pierre Sermanet, Scott Reed, Dragomir Anguelov, Dumitru Erhan, Vincent Vanhoucke, and Andrew Rabinovich. Going deeper with convolutions. In Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 1–9, 2015.<br />
<br />
Christian Szegedy, Vincent Vanhoucke, Sergey Ioffe, Jon Shlens, and Zbigniew Wojna. Rethinking the inception architecture for computer vision. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 2818–2826, 2016.<br />
<br />
Cheng Tai, Tong Xiao, Yi Zhang, Xiaogang Wang, et al. Convolutional neural networks with low-rank regularization. arXiv preprint arXiv:1511.06067, 2015.<br />
<br />
Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and Illia Polosukhin. Attention is all you need. 2017.<br />
<br />
Ashia C Wilson, Rebecca Roelofs, Mitchell Stern, Nathan Srebro, and Benjamin Recht. The marginal value of adaptive gradient methods in machine learning. arXiv preprint arXiv:1705.08292, 2017.<br />
<br />
Bo Xie, Yingyu Liang, and Le Song. Diversity leads to generalization in neural networks. arXiv preprint arXiv:1611.03131, 2016.<br />
<br />
Matthew D Zeiler and Rob Fergus. Visualizing and understanding convolutional networks. In European conference on computer vision, pp. 818–833. Springer, 2014.<br />
<br />
Chiyuan Zhang, Samy Bengio, Moritz Hardt, Benjamin Recht, and Oriol Vinyals. Understanding deep learning requires rethinking generalization. In ICLR, 2017a. URL https://arxiv.org/abs/1611.03530.<br />
<br />
Xiangyu Zhang, Xinyu Zhou, Mengxiao Lin, and Jian Sun. Shufflenet: An extremely efficient convolutional neural network for mobile devices. arXiv preprint arXiv:1707.01083, 2017b.<br />
<br />
Shuchang Zhou, Zekun Ni, Xinyu Zhou, He Wen, Yuxin Wu, and Yuheng Zou. Dorefa-net: Training low bitwidth convolutional neural networks with low bitwidth gradients. arXiv preprint arXiv:1606.06160, 2016.<br />
<br />
A. Rosenfeld and J. K. Tsotsos. Intriguing properties of randomly weighted networks: Generalizing while learning next to nothing. arXiv preprint arXiv:1802.00844, 2018.</div>Gsahuhttp://wiki.math.uwaterloo.ca/statwiki/index.php?title=Fix_your_classifier:_the_marginal_value_of_training_the_last_weight_layer&diff=41998Fix your classifier: the marginal value of training the last weight layer2018-11-30T01:38:04Z<p>Gsahu: Remove repitition</p>
<hr />
<div>=Introduction=<br />
<br />
Deep neural networks have become widely used models in machine learning, achieving state-of-the-art results on many tasks. Their most common use is classification, as when convolutional neural networks (CNNs) classify images into semantic categories. Typically, a learned affine transformation is placed at the end of such models, yielding a per-class value used for classification. This classifier can have a vast number of parameters, which grows linearly with the number of possible classes and thus requires increasingly more computational resources.<br />
<br />
=Brief Overview=<br />
<br />
In order to alleviate the aforementioned problem, the authors propose that the final layer of the classifier be fixed (up to a global scale constant). They argue that, with little or no loss of accuracy for most classification tasks, the method provides significant memory and computational benefits. In addition, they show that initializing the classifier with a Hadamard matrix can make inference faster as well.<br />
<br />
=Previous Work=<br />
<br />
Training neural network models and using them for inference requires large amounts of memory and computational resources; thus, a large body of recent research has aimed to reduce the size of networks, including:<br />
<br />
* Weight sharing and sparsification (Han et al., 2015)<br />
<br />
* Mixed precision to reduce the size of the neural networks by half (Micikevicius et al., 2017)<br />
<br />
* Low-rank approximations to speed up CNNs (Tai et al., 2015)<br />
<br />
* Quantization of weights, activations and gradients to further reduce computation during training (Hubara et al., 2016b; Li et al., 2016 and Zhou et al., 2016)<br />
<br />
Past works have also shown that predefined (Park & Sandberg, 1991) and random (Huang et al., 2006) projections can be used together with a learned affine transformation to achieve competitive results on many classification tasks. The authors' proposal is the reverse: the representation is learned, while the final affine transformation is fixed.<br />
<br />
=Background=<br />
<br />
Convolutional neural networks (CNNs) are commonly used to solve a variety of spatial and temporal tasks. CNNs are usually composed of a stack of parameterized convolutional layers, spatial pooling layers and fully connected layers, separated by non-linear activation functions. Earlier CNN architectures (LeCun et al., 1998; Krizhevsky et al., 2012) used a set of fully-connected layers at the later stages of the network, presumably to allow classification based on global features of an image.<br />
<br />
== Shortcomings of the Final Classification Layer and its Solution ==<br />
<br />
Despite the enormous number of trainable parameters these layers added to the model, they are known to have a rather marginal impact on the final performance of the network (Zeiler & Fergus, 2014).<br />
<br />
It has been shown previously that these layers can be easily compressed and reduced after a model is trained, by simple means such as matrix decomposition and sparsification (Han et al., 2015). Modern architectures are characterized by the removal of most of the fully connected layers (Lin et al., 2013; Szegedy et al., 2015; He et al., 2016). This led to better generalization and overall accuracy, together with a huge decrease in the number of trainable parameters. Additionally, numerous works showed that CNNs can be trained in a metric learning regime (Bromley et al., 1994; Schroff et al., 2015; Hoffer & Ailon, 2015). In this regime, no explicit classification layer is introduced, and the objective depends only on distance measures between intermediate representations. Hardt & Ma (2017) suggested an all-convolutional network variant in which the classification layer was kept fixed at its original initialization, with no negative impact on performance on the CIFAR-10 dataset.<br />
<br />
=Proposed Method=<br />
<br />
The aforementioned works provide evidence that fully-connected layers are in fact redundant and play a small role in learning and generalization. In this work, the authors suggest that the parameters used for the final classification transform are completely redundant and can be replaced with a predetermined linear transform. This holds even for large-scale models and classification tasks, such as recent architectures trained on the ImageNet benchmark (Deng et al., 2009).<br />
<br />
==Using a Fixed Classifier==<br />
<br />
Suppose the final representation obtained by the network (the last hidden layer) is represented as <math>x = F(z;\theta)</math> where <math>F</math> is assumed to be a deep neural network with input z and parameters θ, e.g., a convolutional network, trained by backpropagation.<br />
<br />
In common NN models, this representation is followed by an additional affine transformation, <math>y = W^T x + b</math>, where <math>W</math> and <math>b</math> are also trained by back-propagation.<br />
<br />
For an input <math>x</math> of length <math>N</math> and <math>C</math> possible outputs, <math>W</math> must be an <math>N \times C</math> matrix. Training is done using the cross-entropy loss, by feeding the network outputs through a softmax activation<br />
<br />
<math><br />
v_i = \frac{e^{y_i}}{\sum_{j=1}^{C} e^{y_j}}, \quad i \in \{1, \dots, C\}<br />
</math><br />
<br />
and reducing the expected negative log-likelihood with respect to the ground-truth target <math>t \in \{1, \dots, C\}</math> by minimizing the loss function:<br />
<br />
<math><br />
L(x, t) = -\log v_t = -w_t \cdot x - b_t + \log \left( \sum_{j=1}^{C} e^{w_j \cdot x + b_j} \right)<br />
</math><br />
<br />
where <math>w_i</math> is the <math>i</math>-th column of <math>W</math>.<br />
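<br />
A small NumPy sketch can confirm that the two forms of the loss above agree. The helper name `learned_classifier_loss` and the toy sizes are hypothetical. A numerically stable log-sum-exp stands in for the naive sum.<br />
<br />
```python
import numpy as np


def learned_classifier_loss(x, t, W, b):
    """Cross-entropy loss of a learned affine classifier y = W^T x + b.

    x: (N,) representation, t: integer label, W: (N, C) weights, b: (C,) bias.
    Returns the loss computed two ways: via softmax, and via the closed form.
    """
    y = W.T @ x + b                                  # per-class logits, shape (C,)
    lse = y.max() + np.log(np.sum(np.exp(y - y.max())))  # stable log-sum-exp
    v = np.exp(y - lse)                              # softmax probabilities v_i
    loss_softmax = -np.log(v[t])                     # -log v_t
    loss_closed = -W[:, t] @ x - b[t] + lse          # -w_t.x - b_t + log sum_j exp(...)
    return loss_softmax, loss_closed


rng = np.random.default_rng(0)
N, C = 8, 5
x = rng.standard_normal(N)
W = rng.standard_normal((N, C))
b = rng.standard_normal(C)
l1, l2 = learned_classifier_loss(x, t=2, W=W, b=b)
print(np.isclose(l1, l2))
print("trainable classifier parameters:", W.size + b.size)  # N*C + C
```
<br />
The parameter count <math>N \times C + C</math> is what grows linearly with the number of classes.<br />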
<br />
==Choosing the Projection Matrix==<br />
<br />
To evaluate the conjecture regarding the importance of the final classification transformation, the trainable parameter matrix <math>W</math> is replaced with a fixed orthonormal projection <math>Q \in \mathbb{R}^{N \times C}</math>, such that <math>\forall i \neq j: q_i \cdot q_j = 0</math> and <math>\|q_i\|_{2} = 1</math>, where <math>q_i</math> is the <math>i</math>-th column of <math>Q</math>. This is ensured by simple random sampling followed by a singular-value decomposition.<br />
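<br />
The construction described above (random sampling followed by an SVD) can be sketched in NumPy as follows. This version assumes <math>C \le N</math>, since at most <math>N</math> orthonormal vectors exist in <math>\mathbb{R}^{N}</math>. The function name `random_orthonormal_projection` and the sizes are hypothetical.<br />
<br />
```python
import numpy as np


def random_orthonormal_projection(N, C, seed=0):
    """Fixed N x C matrix with orthonormal columns (requires C <= N).

    Samples a Gaussian matrix and keeps the left singular vectors of its
    thin SVD; the columns of U are orthonormal by construction.
    """
    if C > N:
        raise ValueError("orthonormal columns require C <= N")
    rng = np.random.default_rng(seed)
    A = rng.standard_normal((N, C))
    U, _, _ = np.linalg.svd(A, full_matrices=False)  # U has shape (N, C)
    return U


Q = random_orthonormal_projection(N=64, C=10)
print(Q.shape)
print(np.allclose(Q.T @ Q, np.eye(10)))  # q_i . q_j = 0 for i != j, ||q_i|| = 1
```
<br />
Because <math>Q</math> is generated from a seed and never updated, it needs no gradients and could be regenerated rather than stored.<br />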
<br />
As the columns <math>q_i</math> of the fixed classifier all have the same <math>L_{2}</math> norm, the authors find it beneficial to also restrict the representation <math>x</math> by normalizing it to lie on the unit sphere in <math>\mathbb{R}^{N}</math>:<br />
<br />
<center><math><br />
\hat{x} = \frac{x}{||x||_{2}}<br />
</math></center><br />
<br />
This allows faster training and convergence, as the network does not need to account for changes in the scale of its weights. However, it introduces a new issue: <math>q_i \cdot \hat{x}</math> is now bounded between −1 and 1. This causes convergence problems, because the softmax function is scale-sensitive and the network can no longer re-scale its input. The issue can be amended with a fixed scale <math>T</math> applied to the softmax inputs, <math>f(y) = \mathrm{softmax}\left(\frac{1}{T}y\right)</math>, also known as the ''softmax temperature''. However, this introduces an additional hyper-parameter which may differ between networks and datasets. The authors therefore introduce a single learned scalar parameter <math>\alpha</math> for the softmax scale, which acts as the inverse of the softmax temperature <math>\frac{1}{T}</math>. Like weight normalization (Salimans & Kingma, 2016), this combines normalized weights with a scale coefficient, but a single scale is used for all entries of the weight matrix rather than one per weight vector. The vector of bias parameters <math>b \in \mathbb{R}^{C}</math> is kept, and the model is trained using the traditional negative log-likelihood criterion. Explicitly, the classifier output is now:<br />
<br />
<center><br />
<math><br />
v_i = \frac{e^{\alpha q_i \cdot \hat{x} + b_i}}{\sum_{j=1}^{C} e^{\alpha q_j \cdot \hat{x} + b_j}}, \quad i \in \{1, \dots, C\}<br />
</math><br />
</center><br />
<br />
and the loss to be minimized is:<br />
<br />
<center><math><br />
L(x, t) = -\alpha q_t \cdot \frac{x}{||x||_{2}} - b_t + \log \left( \sum_{i=1}^{C} \exp\left(\alpha q_i \cdot \frac{x}{||x||_{2}} + b_i\right) \right)<br />
</math></center><br />
<br />
where <math>x</math> is the final representation obtained by the network for a specific sample, and <math>t \in \{1, \dots, C\}</math> is the ground-truth label for that sample. The learned parameter <math>\alpha</math> grows roughly logarithmically over the course of training, mirroring the behavior of the norm of a learned classifier, as shown in [[Media: figure1_log_behave.png| Figure 1]].<br />
<br />
<center>[[File:figure1_log_behave.png]]</center><br />
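<br />
The need for the scale <math>\alpha</math> can be made concrete with a simple bound (our own illustration, not taken from the paper). Ignoring the bias terms and taking <math>\alpha > 0</math>, every logit <math>\alpha q_j \cdot \hat{x}</math> lies in <math>[-\alpha, \alpha]</math>. The probability of the correct class is therefore bounded for every sample:<br />
<br />
<center><math><br />
v_t = \frac{e^{\alpha q_t \cdot \hat{x}}}{\sum_{j=1}^{C} e^{\alpha q_j \cdot \hat{x}}} \le \frac{e^{\alpha}}{e^{\alpha} + (C-1)\,e^{-\alpha}} = \frac{1}{1 + (C-1)\,e^{-2\alpha}}<br />
</math></center><br />
<br />
With <math>\alpha = 1</math>, this caps <math>v_t</math> at about 0.45 for <math>C = 10</math> and about 0.0074 for <math>C = 1000</math>. The loss therefore stays bounded away from zero even for a perfectly aligned sample. For <math>C = 1000</math>, allowing <math>v_t \geq 0.9</math> requires <math>\alpha \geq \frac{1}{2}\ln(9(C-1)) \approx 4.55</math>.<br />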
<br />
Since <math>-1 \le q_i \cdot \hat{x} \le 1</math>, a possible alternative is a cosine angle loss:<br />
<br />
<center>[[File:caloss.png]]</center><br />
<br />
However, this loss led to a slight decrease in final validation accuracy compared with the original models.<br />
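<br />
The fixed classifier described in this section can be sketched as follows. This minimal NumPy forward pass and loss uses toy sizes (<math>N = 64</math>, <math>C = 10</math>) and a hypothetical helper `fixed_classifier_loss`. It builds <math>Q</math> with NumPy's QR decomposition instead of the SVD mentioned above; both give orthonormal columns. In a real network only <math>\alpha</math>, <math>b</math> and the layers producing <math>x</math> would be trained.<br />
<br />
```python
import numpy as np


def fixed_classifier_loss(x, t, Q, alpha, b):
    """Negative log-likelihood of the fixed classifier for one sample.

    x: (N,) representation, t: ground-truth index, Q: (N, C) fixed matrix with
    orthonormal columns, alpha: learned scalar scale, b: (C,) learned bias.
    """
    x_hat = x / np.linalg.norm(x)              # normalize onto the unit sphere
    y = alpha * (Q.T @ x_hat) + b              # logits alpha * q_i . x_hat + b_i
    lse = y.max() + np.log(np.sum(np.exp(y - y.max())))  # stable log-sum-exp
    return lse - y[t]                          # -alpha q_t.x_hat - b_t + log sum exp(...)


rng = np.random.default_rng(0)
N, C = 64, 10
Q, _ = np.linalg.qr(rng.standard_normal((N, C)))  # fixed orthonormal columns
b = np.zeros(C)
x = 5.0 * Q[:, 3]                             # representation aligned with class 3

# The loss only approaches zero when the scale alpha is allowed to grow.
for alpha in (1.0, 10.0):
    print(alpha, fixed_classifier_loss(x, 3, Q, alpha, b))
```
<br />
With <math>\alpha = 1</math>, the loss of this perfectly aligned sample is <math>\log(9 + e) - 1 \approx 1.46</math>, while <math>\alpha = 10</math> drives it close to zero. The rescaling of <math>x</math> by 5 has no effect because of the normalization.<br />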
<br />
==Using a Hadamard Matrix==<br />
<br />
To recall, a Hadamard matrix (Hedayat et al., 1978) <math>H</math> is an <math>n \times n</math> matrix whose entries are all either +1 or −1.<br />
Furthermore, the rows of <math>H</math> are mutually orthogonal, such that <math>HH^{T} = nI_n</math>, where <math>I_n</math> is the identity matrix. Instead of the entire Hadamard matrix <math>H</math>, a truncated version <math>\hat{H} \in \{-1, 1\}^{C \times N}</math>, whose <math>C</math> rows are all orthogonal, is used as the final classification layer:<br />
<br />
<center><math><br />
y = \hat{H} \hat{x} + b<br />
</math></center><br />
<br />
This usage allows two main benefits:<br />
* A deterministic, low-memory and easily generated matrix that can be used for classification.<br />
* Removal of the need to perform a full matrix-matrix multiplication, since multiplying by a Hadamard matrix can be done by simple sign manipulation and addition.<br />
<br />
A Hadamard matrix of order <math>n > 2</math> can only exist when <math>n</math> is a multiple of 4, but it can easily be truncated to fit the dimensions of typical networks. Also, as the fixed classifier weights need only 1-bit precision, it is now possible to focus on the features preceding it.<br />
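<br />
Here is a minimal NumPy sketch of the Hadamard classifier. It assumes <math>N</math> is a power of 2 so that Sylvester's construction applies; this is one standard way to build Hadamard matrices and not necessarily the authors' choice. It also uses the fast Walsh–Hadamard transform to illustrate how <math>\hat{H}\hat{x}</math> reduces to additions and subtractions without storing <math>\hat{H}</math>. Helper names and sizes are hypothetical.<br />
<br />
```python
import numpy as np


def sylvester_hadamard(n):
    """n x n Hadamard matrix via Sylvester's construction (n a power of 2)."""
    if n < 1 or n & (n - 1):
        raise ValueError("Sylvester's construction needs n to be a power of 2")
    H = np.array([[1]])
    while H.shape[0] < n:
        H = np.block([[H, H], [H, -H]])   # doubles the order, rows stay orthogonal
    return H


def fwht(x):
    """Fast Walsh-Hadamard transform: returns H @ x using only additions/subtractions."""
    x = np.array(x, dtype=float)          # work on a copy
    h = 1
    while h < len(x):
        for i in range(0, len(x), 2 * h):
            a = x[i:i + h].copy()
            c = x[i + h:i + 2 * h].copy()
            x[i:i + h] = a + c
            x[i + h:i + 2 * h] = a - c
        h *= 2
    return x


N, C = 16, 10
H = sylvester_hadamard(N)
H_hat = H[:C]                             # truncated C x N classifier, entries +-1
print(np.array_equal(H @ H.T, N * np.eye(N)))   # H H^T = n I_n

rng = np.random.default_rng(0)
x = rng.standard_normal(N)
x_hat = x / np.linalg.norm(x)
b = np.zeros(C)
y_matmul = H_hat @ x_hat + b              # reference: explicit matrix product
y_fast = fwht(x_hat)[:C] + b              # same logits, H_hat never stored
print(np.allclose(y_matmul, y_fast))
```
<br />
Taking the first <math>C</math> outputs of the full transform gives exactly the truncated product, because <math>\hat{H}</math> here is simply the first <math>C</math> rows of <math>H</math>. Every output is computed with sign changes and additions only.<br />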
<br />
=Experimental Results=<br />
<br />
The authors have evaluated their proposed model on the following datasets:<br />
<br />
==CIFAR-10/100==<br />
<br />
===About the Dataset===<br />
<br />
CIFAR-10 is an image classification benchmark dataset containing 50,000 training images and 10,000 test images. The images are 32×32-pixel color images. There are 10 possible classes of various animals and vehicles. CIFAR-100 holds the same number of images of the same size, but contains 100 different classes.<br />
<br />
===Training Details===<br />
<br />
The authors trained a residual network (He et al., 2016) of depth 56 on the CIFAR-10 dataset, using the same hyper-parameters as in the original work. A comparison of the two variants, i.e., the learned classifier and the proposed classifier with a fixed transformation, is shown in [[Media: figure1_resnet_cifar10.png | Figure 2]].<br />
<br />
<center>[[File: figure1_resnet_cifar10.png]]</center><br />
<br />
These results demonstrate that although the training error is considerably lower for the network with the learned classifier, both models achieve the same classification accuracy on the validation set. The authors conjecture that with the new fixed parameterization, the network can no longer increase the norm of a given sample's representation, so learning its label requires more effort. Since this happens only for specific samples seen during training, it affects only the training error.<br />
<br />
The authors also compared the learned scale parameter <math>\alpha</math> with fixed values of it. Training and validation errors for <math>\alpha \in \{0.1, 1, 10\}</math> are depicted in [[Media: figure3_alpha_resnet_cifar.png| Figure 3]]. As can be seen, similar validation accuracy can be obtained with a fixed scale value (here <math>\alpha = 1</math> or <math>\alpha = 10</math> suffices), at the expense of another hyper-parameter to tune. In all further experiments, the scale parameter <math>\alpha</math> was regularized with the same weight decay coefficient used on the original classifier. Although learning the scale is not necessary, it helps convergence during training.<br />
<br />
<center>[[File: figure3_alpha_resnet_cifar.png]]</center><br />
<br />
The authors then trained on the CIFAR-100 dataset, using the DenseNet-BC model of Huang et al. (2017) with a depth of 100 layers and growth rate k = 12. With ten times as many classes, the classifier grows accordingly and accounts for about 4% of the model's parameters. Nevertheless, the fixed-classifier model achieved the same validation accuracy as the original model, and its training curves showed the same behavior as before.<br />
<br />
==IMAGENET==<br />
<br />
===About the Dataset===<br />
<br />
The ImageNet dataset introduced by Deng et al. (2009) spans 1,000 visual classes and over 1.2 million samples. It is generally considered more challenging than CIFAR-10/100.<br />
<br />
===Experiment Details===<br />
<br />
The authors evaluated their fixed-classifier method on ImageNet using ResNet50 (He et al., 2016) and DenseNet169 (Huang et al., 2017), with the settings described in the original works. Fixing the classifier removed approximately 2 million parameters from each model, accounting for about 8% and 12% of the model parameters, respectively. The experiments revealed trends similar to those observed on CIFAR-10.<br />
<br />
For a more stringent evaluation, the authors also trained a ShuffleNet architecture (Zhang et al., 2017b). ShuffleNet was designed for low-memory, limited-compute platforms, and its final classification layer holds the majority of its parameters. Fixing the classifier reduced the model's parameters to 0.86 million, compared with 0.96 million parameters in the final layer of the original model alone. Again, the modified model converged to a similar validation accuracy.<br />
<br />
The overall results of the fixed-classifier are summarized in [[Media: table1_fixed_results.png | Table 1]].<br />
<br />
<center>[[File: table1_fixed_results.png]]</center><br />
<br />
==Language Modelling==<br />
<br />
Recent works have empirically found that using the same weights for both the word embedding and the classifier can yield equal or better results than using a separate pair of weights. The authors therefore experimented with fixed classifiers for language modelling, which also requires classification over all possible tokens in the task vocabulary. They trained a recurrent model with a 2-layer LSTM (Hochreiter & Schmidhuber, 1997) and an embedding and hidden size of 512 on the WikiText2 dataset (Merity et al., 2016), using the same settings as Merity et al. (2017). WikiText2 contains about 33K different words, so the embedding and classifier layers were expected to hold about 34 million parameters, about 89% of the 38 million parameters of the whole model. However, using a random orthogonal transform as the fixed classifier yielded poor results compared with a learned embedding. The authors suspected this was due to the semantic relationships captured in the embedding layer of language models, which have no counterpart in image classification. This intuition was further supported by the much better results obtained when pre-trained embeddings, computed with the word2vec algorithm of Mikolov et al. (2013) or with PMI factorization as suggested by Levy & Goldberg (2014), were used.<br />
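<br />
The 34-million and 89% figures can be checked with simple arithmetic. The sketch below assumes an untied input embedding and output classifier (each <math>V \times d</math>, plus a bias on the classifier) and uses the approximate values from the text, <math>V \approx 33{,}000</math> words, <math>d = 512</math> and a 38-million-parameter model; the function name is hypothetical.<br />
<br />
```python
def embedding_and_classifier_params(vocab_size, dim, classifier_bias=True):
    """Parameters in an input embedding table plus an untied output classifier."""
    embedding = vocab_size * dim
    classifier = vocab_size * dim + (vocab_size if classifier_bias else 0)
    return embedding + classifier

vocab, dim, total_model = 33_000, 512, 38_000_000  # approximate figures from the text
n = embedding_and_classifier_params(vocab, dim)
print(f"{n / 1e6:.1f}M parameters, {100 * n / total_model:.0f}% of the model")
```
<br />
This prints about 33.8M parameters and 89%, in line with the text; fixing the classifier would leave only the embedding half of that budget to be learned.<br />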
<br />
=Discussion=<br />
<br />
==Implications and Use Cases==<br />
<br />
With the increasing number of classes in benchmark datasets, the computational demands of the final classifier increase as well. To illustrate the problem, the authors point to the work of Sun et al. (2017), which introduced JFT-300M, an internal Google dataset with over 18K different classes. Using a ResNet50 (He et al., 2016) with a 2048-sized representation led to a final classification layer with over 36M parameters, meaning that over 60% of the model parameters resided in that layer. Sun et al. (2017) also describe the difficulty of distributing so many parameters over the training servers, which incurs non-trivial overhead when synchronizing the model for updates. The authors claim that a fixed classifier would help considerably in this kind of scenario, since it removes the need for any gradient synchronization of the final layer. Furthermore, using a Hadamard matrix removes the need to store the transformation altogether, allowing considerable memory and computational savings.<br />
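<br />
A Hadamard classifier need not be stored because its entries can be generated on demand. A minimal sketch, assuming the Sylvester construction (so the feature dimension <math>N</math> is a power of two and <math>C \le N</math>), in which entry <math>(i, j)</math> is <math>(-1)^{\operatorname{popcount}(i \,\mathrm{AND}\, j)}</math>; the fast Walsh-Hadamard transform then computes all <math>N</math> projections with <math>O(N \log N)</math> additions and subtractions, without materializing the matrix. The helper names are hypothetical, and this is not necessarily how the authors implement it.<br />
<br />
```python
import math

def hadamard_entry(i, j):
    """Entry (i, j) of the Sylvester Hadamard matrix: (-1)^popcount(i & j)."""
    return -1 if bin(i & j).count("1") % 2 else 1

def fwht(vec):
    """Unnormalized fast Walsh-Hadamard transform: returns H @ vec in O(N log N).

    len(vec) must be a power of two; H itself is never stored.
    """
    a = list(vec)
    n = len(a)
    assert n > 0 and n & (n - 1) == 0, "length must be a power of two"
    h = 1
    while h < n:
        for start in range(0, n, 2 * h):
            for j in range(start, start + h):
                x, y = a[j], a[j + h]
                a[j], a[j + h] = x + y, x - y  # butterfly step
        h *= 2
    return a

def hadamard_logits(features, num_classes, alpha):
    """Fixed-classifier logits from the first num_classes rows of H (hypothetical helper).

    Every row of H has norm sqrt(N), so dividing by sqrt(N) * ||x|| yields
    alpha times the cosine similarity between x and each fixed row.
    """
    n = len(features)
    assert num_classes <= n, "a truncated Hadamard classifier needs C <= N"
    norm = math.sqrt(sum(v * v for v in features))
    projected = fwht(features)[:num_classes]
    return [alpha * p / (math.sqrt(n) * norm) for p in projected]

x = [0.5, -1.0, 2.0, 0.0, 1.5, 0.25, -0.75, 3.0]  # toy 8-dimensional feature vector
print(hadamard_logits(x, num_classes=5, alpha=1.0))
```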
<br />
==Possible Caveats==<br />
<br />
The good performance of fixed classifiers relies on the ability of the preceding layers to learn separable representations. This could be affected when the ratio between the number of learned features and the number of classes is small, that is, when <math> C > N</math>. However, the authors tested their method in such cases and report that the model still performed well.<br />
Another factor that can affect the performance of a fixed classifier is high correlation between classes. A fixed classifier cannot capture such correlations, so the network may have more difficulty learning. In language modelling, word classes tend to be highly correlated, which also makes learning more difficult.<br />
<br />
Also, the proposed approach only eliminates the classifier weights, so when there are few classes, the computational savings will be small.<br />
<br />
==Future Work==<br />
<br />
<br />
The use of fixed classifiers might be simplified further in Binarized Neural Networks (Hubara et al., 2016a), where activations and weights are restricted to ±1 during propagation. In that case, the norm of the last hidden layer is constant for all samples (equal to the square root of the hidden layer width). This constant could then be absorbed into the scale constant <math>\alpha</math>, removing the need for per-sample normalization.<br />
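<br />
The claim about binarized activations is easy to check numerically: a vector of <math>N</math> entries equal to <math>\pm 1</math> always has norm <math>\sqrt{N}</math>, so dividing by the per-sample norm is the same as multiplying by the constant <math>1/\sqrt{N}</math>, which can be folded into <math>\alpha</math>. A minimal sketch, assuming sign binarization and a hypothetical hidden-layer width of 256:<br />
<br />
```python
import math
import random

def binarize(values):
    """Sign binarization (an illustrative choice): +1 for v >= 0, otherwise -1."""
    return [1.0 if v >= 0 else -1.0 for v in values]

def logits_per_sample_norm(h, weights, alpha):
    """alpha * <h, w> / ||h||: per-sample normalization, as for real-valued features."""
    norm = math.sqrt(sum(v * v for v in h))
    return [alpha * sum(a * b for a, b in zip(h, w)) / norm for w in weights]

def logits_constant_scale(h, weights, alpha, width):
    """Same logits with ||h|| = sqrt(width) folded into one constant scale."""
    scale = alpha / math.sqrt(width)
    return [scale * sum(a * b for a, b in zip(h, w)) for w in weights]

width = 256  # hypothetical hidden-layer width
rng = random.Random(0)
h = binarize([rng.gauss(0.0, 1.0) for _ in range(width)])
W = [binarize([rng.gauss(0.0, 1.0) for _ in range(width)]) for _ in range(4)]  # 4 fixed classes
print(math.sqrt(sum(v * v for v in h)))  # 16.0 == sqrt(256) for any +/-1 vector
a = logits_per_sample_norm(h, W, alpha=2.0)
b = logits_constant_scale(h, W, alpha=2.0, width=width)
print(max(abs(x - y) for x, y in zip(a, b)))  # agreement up to floating-point rounding
```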
<br />
Additionally, more efficient ways to learn word embeddings should be explored, since a similar redundancy in classifier weights may suggest simpler forms of token representation, such as low-rank or sparse versions.<br />
<br />
A related paper (Rosenfeld & Tsotsos, 2018) claims that fixing most of the parameters of a neural network achieves results comparable to learning all of them.<br />
<br />
=Conclusion=<br />
<br />
In this work, the authors argue that the final classification layer in deep neural networks is redundant and suggest fixing it, removing its trainable parameters. Empirical results on the CIFAR and ImageNet datasets suggest that such a change leads to little or no decline in performance. Furthermore, using a Hadamard matrix as the classifier might lead to computational benefits when properly implemented, and saves the memory otherwise spent on a large number of transformation coefficients.<br />
<br />
Another possible direction for future research is to find efficient methods for creating pre-defined word embeddings, which require a huge number of parameters that could possibly be avoided when learning a new task. More emphasis should therefore be given to the representations learned by the non-linear parts of neural networks, up to the final classifier, as the classifier itself seems highly redundant.<br />
<br />
=Critique=<br />
<br />
The paper proposes an interesting idea with potential use in designing memory-efficient neural networks. The experiments shown in the paper are quite rigorous and support the authors' claims. However, it would have been helpful if the authors had described in more detail how to implement the Hadamard matrix efficiently and how to scale the method to datasets with many classes (cases with <math> C > N</math>).<br />
<br />
=References=<br />
<br />
The code for the proposed model is available at https://github.com/eladhoffer/fix_your_classifier.<br />
<br />
Madhu S Advani and Andrew M Saxe. High-dimensional dynamics of generalization error in neural networks. arXiv preprint arXiv:1710.03667, 2017.<br />
<br />
Peter Bartlett, Dylan J Foster, and Matus Telgarsky. Spectrally-normalized margin bounds for neural networks. arXiv preprint arXiv:1706.08498, 2017.<br />
<br />
Jane Bromley, Isabelle Guyon, Yann LeCun, Eduard Sackinger, and Roopak Shah. Signature verification using a “siamese” time delay neural network. In Advances in Neural Information Processing Systems, pp. 737–744, 1994.<br />
<br />
Matthieu Courbariaux, Yoshua Bengio, and Jean-Pierre David. Binaryconnect: Training deep neural networks with binary weights during propagations. In Advances in Neural Information Processing Systems, pp. 3123–3131, 2015.<br />
<br />
Jia Deng, Wei Dong, Richard Socher, Li-Jia Li, Kai Li, and Li Fei-Fei. Imagenet: A large-scale hierarchical image database. In Computer Vision and Pattern Recognition, 2009. CVPR 2009. IEEE Conference on, pp. 248–255. IEEE, 2009.<br />
<br />
Suriya Gunasekar, Blake Woodworth, Srinadh Bhojanapalli, Behnam Neyshabur, and Nathan Srebro. Implicit regularization in matrix factorization. arXiv preprint arXiv:1705.09280, 2017.<br />
<br />
Song Han, Huizi Mao, and William J Dally. Deep compression: Compressing deep neural networks with pruning, trained quantization and huffman coding. arXiv preprint arXiv:1510.00149, 2015.<br />
<br />
Moritz Hardt and Tengyu Ma. Identity matters in deep learning. 2017.<br />
<br />
Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image recognition. In Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 770–778, 2016.<br />
<br />
A Hedayat, WD Wallis, et al. Hadamard matrices and their applications. The Annals of Statistics, 6(6):1184–1238, 1978.<br />
<br />
Sepp Hochreiter and Jürgen Schmidhuber. Long short-term memory. Neural computation, 9(8):1735–1780, 1997.<br />
<br />
Elad Hoffer and Nir Ailon. Deep metric learning using triplet network. In International Workshop on Similarity-Based Pattern Recognition, pp. 84–92. Springer, 2015.<br />
<br />
Elad Hoffer, Itay Hubara, and Daniel Soudry. Train longer, generalize better: closing the generalization gap in large batch training of neural networks. 2017.<br />
<br />
Andrew G Howard, Menglong Zhu, Bo Chen, Dmitry Kalenichenko, Weijun Wang, Tobias Weyand, Marco Andreetto, and Hartwig Adam. Mobilenets: Efficient convolutional neural networks for mobile vision applications. arXiv preprint arXiv:1704.04861, 2017.<br />
<br />
Gao Huang, Zhuang Liu, Laurens van der Maaten, and Kilian Q Weinberger. Densely connected convolutional networks. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, 2017.<br />
<br />
Guang-Bin Huang, Qin-Yu Zhu, and Chee-Kheong Siew. Extreme learning machine: theory and applications. Neurocomputing, 70(1):489–501, 2006.<br />
<br />
Itay Hubara, Matthieu Courbariaux, Daniel Soudry, Ran El-Yaniv, and Yoshua Bengio. Binarized neural networks. In Advances in Neural Information Processing Systems 29 (NIPS’16), 2016a.<br />
<br />
Itay Hubara, Matthieu Courbariaux, Daniel Soudry, Ran El-Yaniv, and Yoshua Bengio. Quantized neural networks: Training neural networks with low precision weights and activations. arXiv preprint arXiv:1609.07061, 2016b.<br />
<br />
Hakan Inan, Khashayar Khosravi, and Richard Socher. Tying word vectors and word classifiers: A loss framework for language modeling. arXiv preprint arXiv:1611.01462, 2016.<br />
<br />
Max Jaderberg, Andrea Vedaldi, and Andrew Zisserman. Speeding up convolutional neural networks with low rank expansions. arXiv preprint arXiv:1405.3866, 2014.<br />
<br />
Alex Krizhevsky. Learning multiple layers of features from tiny images. 2009.<br />
<br />
Alex Krizhevsky, Ilya Sutskever, and Geoffrey E Hinton. Imagenet classification with deep convolutional neural networks. In Advances in neural information processing systems, pp. 1097–1105, 2012.<br />
<br />
Yann LeCun, Léon Bottou, Yoshua Bengio, and Patrick Haffner. Gradient-based learning applied to document recognition. Proceedings of the IEEE, 86(11):2278–2324, 1998.<br />
<br />
Omer Levy and Yoav Goldberg. Neural word embedding as implicit matrix factorization. In Advances in neural information processing systems, pp. 2177–2185, 2014.<br />
<br />
Fengfu Li, Bo Zhang, and Bin Liu. Ternary weight networks. arXiv preprint arXiv:1605.04711, 2016.<br />
<br />
Min Lin, Qiang Chen, and Shuicheng Yan. Network in network. arXiv preprint arXiv:1312.4400, 2013.<br />
<br />
Stephen Merity, Caiming Xiong, James Bradbury, and Richard Socher. Pointer sentinel mixture models. arXiv preprint arXiv:1609.07843, 2016.<br />
<br />
Stephen Merity, Nitish Shirish Keskar, and Richard Socher. Regularizing and Optimizing LSTM Language Models. arXiv preprint arXiv:1708.02182, 2017.<br />
<br />
Paulius Micikevicius, Sharan Narang, Jonah Alben, Gregory Diamos, Erich Elsen, David Garcia, Boris Ginsburg, Michael Houston, Oleksii Kuchaev, Ganesh Venkatesh, et al. Mixed precision training. arXiv preprint arXiv:1710.03740, 2017.<br />
<br />
Tomas Mikolov, Ilya Sutskever, Kai Chen, Greg S Corrado, and Jeff Dean. Distributed representations of words and phrases and their compositionality. In Advances in neural information processing systems, pp. 3111–3119, 2013.<br />
<br />
Behnam Neyshabur, Srinadh Bhojanapalli, David McAllester, and Nathan Srebro. Exploring generalization in deep learning. arXiv preprint arXiv:1706.08947, 2017.<br />
<br />
Jooyoung Park and Irwin W Sandberg. Universal approximation using radial-basis-function networks. Neural computation, 3(2):246–257, 1991.<br />
<br />
Ofir Press and Lior Wolf. Using the output embedding to improve language models. EACL 2017, pp. 157, 2017.<br />
<br />
Itay Safran and Ohad Shamir. On the quality of the initial basin in overspecified neural networks. In International Conference on Machine Learning, pp. 774–782, 2016.<br />
<br />
Tim Salimans and Diederik P Kingma. Weight normalization: A simple reparameterization to accelerate training of deep neural networks. In Advances in Neural Information Processing Systems, pp. 901–909, 2016.<br />
<br />
Florian Schroff, Dmitry Kalenichenko, and James Philbin. Facenet: A unified embedding for face recognition and clustering. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 815–823, 2015.<br />
<br />
Mahdi Soltanolkotabi, Adel Javanmard, and Jason D Lee. Theoretical insights into the optimization landscape of over-parameterized shallow neural networks. arXiv preprint arXiv:1707.04926, 2017.<br />
<br />
Daniel Soudry and Yair Carmon. No bad local minima: Data independent training error guarantees for multilayer neural networks. arXiv preprint arXiv:1605.08361, 2016.<br />
<br />
Daniel Soudry and Elad Hoffer. Exponentially vanishing sub-optimal local minima in multilayer neural networks. arXiv preprint arXiv:1702.05777, 2017.<br />
<br />
Daniel Soudry, Elad Hoffer, and Nathan Srebro. The implicit bias of gradient descent on separable data. 2018.<br />
<br />
Jost Tobias Springenberg, Alexey Dosovitskiy, Thomas Brox, and Martin Riedmiller. Striving for simplicity: The all convolutional net. arXiv preprint arXiv:1412.6806, 2014.<br />
<br />
Chen Sun, Abhinav Shrivastava, Saurabh Singh, and Abhinav Gupta. Revisiting unreasonable effectiveness of data in deep learning era. arXiv preprint arXiv:1707.02968, 2017.<br />
<br />
Christian Szegedy, Wei Liu, Yangqing Jia, Pierre Sermanet, Scott Reed, Dragomir Anguelov, Dumitru Erhan, Vincent Vanhoucke, and Andrew Rabinovich. Going deeper with convolutions. In Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 1–9, 2015.<br />
<br />
Christian Szegedy, Vincent Vanhoucke, Sergey Ioffe, Jon Shlens, and Zbigniew Wojna. Rethinking the inception architecture for computer vision. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 2818–2826, 2016.<br />
<br />
Cheng Tai, Tong Xiao, Yi Zhang, Xiaogang Wang, et al. Convolutional neural networks with lowrank regularization. arXiv preprint arXiv:1511.06067, 2015.<br />
<br />
Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and Illia Polosukhin. Attention is all you need. 2017.<br />
<br />
Ashia C Wilson, Rebecca Roelofs, Mitchell Stern, Nathan Srebro, and Benjamin Recht. The marginal value of adaptive gradient methods in machine learning. arXiv preprint arXiv:1705.08292, 2017.<br />
<br />
Bo Xie, Yingyu Liang, and Le Song. Diversity leads to generalization in neural networks. arXiv preprint arXiv:1611.03131, 2016.<br />
<br />
Matthew D Zeiler and Rob Fergus. Visualizing and understanding convolutional networks. In European conference on computer vision, pp. 818–833. Springer, 2014.<br />
<br />
Chiyuan Zhang, Samy Bengio, Moritz Hardt, Benjamin Recht, and Oriol Vinyals. Understanding deep learning requires rethinking generalization. In ICLR, 2017a. URL https://arxiv.org/abs/1611.03530.<br />
<br />
Xiangyu Zhang, Xinyu Zhou, Mengxiao Lin, and Jian Sun. Shufflenet: An extremely efficient convolutional neural network for mobile devices. arXiv preprint arXiv:1707.01083, 2017b.<br />
<br />
Shuchang Zhou, Zekun Ni, Xinyu Zhou, He Wen, Yuxin Wu, and Yuheng Zou. Dorefa-net: Training low bitwidth convolutional neural networks with low bitwidth gradients. arXiv preprint arXiv:1606.06160, 2016.<br />
<br />
A. Rosenfeld and J. K. Tsotsos. Intriguing properties of randomly weighted networks: Generalizing while learning next to nothing. arXiv preprint arXiv:1802.00844, 2018.</div>Gsahuhttp://wiki.math.uwaterloo.ca/statwiki/index.php?title=conditional_neural_process&diff=41997conditional neural process2018-11-30T01:08:47Z<p>Gsahu: /* Conditional Neural Process */</p>
<hr />
<div>== Motivation ==<br />
<br />
Deep neural networks are good at function approximation, yet they are typically trained from scratch for each new function. Bayesian methods such as Gaussian Processes (GPs), on the other hand, exploit prior knowledge to quickly infer the shape of a new function at test time, but GPs are computationally expensive, and it can be hard to design appropriate priors. Hence the authors propose a family of neural models, called Conditional Neural Processes (CNPs), that combine the benefits of both.<br />
<br />
== Introduction ==<br />
<br />
To train a model effectively, deep neural networks typically require large datasets. One approach to mitigating this data-efficiency problem is to learn in two phases: the first phase learns the statistics of a generic domain without committing to a specific learning task; the second phase learns a function for a specific task, using only a small number of data points by exploiting the domain-wide statistics already learned. Another approach is to take a probabilistic stance and specify a distribution over functions (a stochastic process), with Gaussian Processes being a commonly used example. Such Bayesian methods can, however, be computationally expensive.<br />
<br />
The authors propose a family of models that represent solutions to the supervised problem, together with an end-to-end approach to training them that combines neural networks with features reminiscent of Gaussian Processes. They call this family of models Conditional Neural Processes (CNPs). CNPs can make accurate predictions after observing only a few data points, while also having the capacity to scale to complex functions and large datasets.<br />
<br />
== Model ==<br />
Consider a data set <math display="inline"> \{(x_i, y_i)\} </math> with evaluations <math display="inline">y_i = f(x_i) </math> for some unknown function <math display="inline">f</math>. Let <math display="inline">g</math> be a function approximating <math display="inline">f</math>. The aim is to minimize the loss between <math display="inline">f</math> and <math display="inline">g</math> on the entire space <math display="inline">X</math>; in practice, this loss can only be evaluated on a finite set of observations.<br />
<br />
<br />
Let the training set (the observations) be <math display="inline"> O = \{(x_i, y_i)\}_{i = 0} ^{n-1}</math>, and the test set (the targets) be <math display="inline"> T = \{x_i\}_{i = n} ^ {n + m - 1} \subset X</math>, a set of unlabelled points.<br />
<br />
Let <math display="inline">P</math> be a probability distribution over functions <math display="inline"> f : X \to Y</math>, formally known as a stochastic process. Then <math display="inline">P</math> defines a joint distribution over the random variables <math display="inline"> \{f(x_i)\}_{i = 0} ^{n + m - 1}</math>. Given <math display="inline"> O</math>, our task is to predict the output values <math display="inline">f(x_i)</math> for <math display="inline"> x_i \in T</math>, that is, to evaluate the conditional distribution <math display="inline"> P(f(T) | O, T)</math>.<br />
<br />
A common assumption on <math>P</math> is that any finite set of function evaluations of <math display="inline"> f </math> is jointly Gaussian distributed; this class of random functions is called Gaussian Processes (GPs). This stochastic process framework allows a model to be data efficient, but it is hard to choose appropriate priors, and stochastic processes are computationally expensive, scaling poorly with <math>n</math> and <math>m</math>. For example, GPs have a running time of <math>O((n+m)^3)</math>.<br />
<br />
[[File:001.jpg|300px|center]]<br />
<br />
== Conditional Neural Process ==<br />
<br />
Conditional Neural Process models directly parametrize conditional stochastic processes without imposing consistency with respect to some prior process. CNPs parametrize distributions over <math display="inline">f(T)</math> given a distributed representation of <math display="inline">O</math> of fixed dimensionality. In doing so, the mathematical guarantees associated with stochastic processes are traded off for functional flexibility and scalability.<br />
<br />
A CNP is a conditional stochastic process <math display="inline">Q_\theta</math> that defines distributions over <math display="inline">f(x_i)</math> for <math display="inline">x_i \in T</math>, given a set of observations <math display="inline">O</math>. As for stochastic processes, the authors assume that <math display="inline">Q_{\theta}</math> is invariant to permutations, i.e. <math display="inline">Q_\theta(f(T) | O, T)= Q_\theta(f(T') | O, T')=Q_\theta(f(T) | O', T) </math> when <math> O', T'</math> are permutations of <math display="inline">O</math> and <math display="inline">T </math>. In this work, permutation invariance with respect to <math display="inline">T</math> is enforced by assuming a factored structure, which is the easiest way to ensure a valid stochastic process: <math display="inline">Q_\theta(f(T) | O, T) = \prod _{x \in T} Q_\theta(f(x) | O, x)</math>. The framework can, however, be extended to non-factored distributions.<br />
<br />
In detail, the following architecture is used:<br />
<br />
<math display="inline">r_i = h_\theta(x_i, y_i)</math> &forall; <math display="inline">(x_i, y_i) \in O</math>, where <math display="inline">h_\theta : X \times Y \to \mathbb{R} ^ d</math><br />
<br />
<math display="inline">r = r_1 * r_2 * \dots * r_n</math>, where <math display="inline">*</math> is a commutative operation that takes elements in <math display="inline">\mathbb{R}^d</math> and maps them into a single element of <math display="inline">\mathbb{R} ^ d</math><br />
<br />
<math display="inline">\Phi_i = g_\theta(x_i, r)</math> &forall; <math display="inline">x_i \in T</math>, where <math display="inline">g_\theta : X \times \mathbb{R} ^ d \to \mathbb{R} ^ e</math> and <math display="inline">\Phi_i</math> are the parameters of <math display="inline">Q_\theta(f(x_i) | O, x_i)</math><br />
<br />
Note that this architecture ensures permutation invariance and <math display="inline">O(n + m)</math> scaling for conditional prediction. Moreover, since <math display="inline">r_1 * \dots * r_n</math> can be computed in <math display="inline">O(1)</math> from <math display="inline">r_1 * \dots * r_{n-1}</math>, this architecture supports streaming observations with minimal overhead.<br />
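<br />
The three equations above can be sketched in a few lines of pure Python. The encoder and decoder below are hypothetical toy functions standing in for the learned networks <math>h_\theta</math> and <math>g_\theta</math>; the element-wise mean is used as the commutative operation (the choice made in the regression experiment), and the decoder returns a Gaussian mean and standard deviation as <math>\Phi_i</math>. The sketch also shows a constant-cost streaming update of the aggregate.<br />
<br />
```python
import math

def encoder(x, y):
    """Toy stand-in for h_theta: maps one observation (x, y) to R^3 (hypothetical)."""
    return [x, y, x * y]

def aggregate(reps):
    """Commutative aggregation: element-wise mean of the representations r_i."""
    n = len(reps)
    return [sum(r[k] for r in reps) / n for k in range(len(reps[0]))]

def decoder(x, r):
    """Toy stand-in for g_theta: returns Phi = (mean, std) of Q(f(x) | O, x) (hypothetical)."""
    mean = r[1] + (r[2] - r[0] * r[1]) * (x - r[0])  # arbitrary toy function of (x, r)
    std = math.log1p(math.exp(abs(x - r[0]) - 1.0))  # softplus keeps std > 0
    return mean, std

def cnp_predict(observations, targets):
    """O(n) encoding and aggregation, then O(m) independent (factored) predictions."""
    r = aggregate([encoder(x, y) for x, y in observations])
    return [decoder(x, r) for x in targets]

class StreamingAggregate:
    """Running mean of the r_i: adding one observation costs O(d), independent of n."""

    def __init__(self, dim):
        self.count = 0
        self.total = [0.0] * dim

    def add(self, x, y):
        self.count += 1
        for k, v in enumerate(encoder(x, y)):
            self.total[k] += v

    def value(self):
        return [t / self.count for t in self.total]

obs = [(0.0, 1.0), (1.0, 2.0), (2.0, 0.5)]
targets = [0.5, 1.5, 3.0]
print(cnp_predict(obs, targets))
print(cnp_predict(list(reversed(obs)), targets))  # same predictions: the order of O does not matter
```
<br />
Reversing the observations leaves every prediction unchanged, and each new observation updates the streaming aggregate without revisiting earlier ones.<br />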
<br />
The authors train <math display="inline">Q_\theta</math> by asking it to predict <math display="inline">O</math> conditioned on a randomly chosen subset of <math display="inline">O</math>. This gives the model a signal of the uncertainty over the space <math display="inline">X</math> inherent in the distribution <math display="inline">P</math> given a set of observations. Let <math display="inline"> f \sim P</math>, <math display="inline"> O = \{(x_i, y_i)\}_{i = 0} ^{n-1}</math>, and <math display="inline">N \sim \mathrm{uniform}\{0, 1, \dots, n-1\}</math>. The subset <math display="inline"> O_N = \{(x_i, y_i)\}_{i = 0} ^{N} \subset O</math>, consisting of the first <math display="inline">N+1</math> elements of <math display="inline">O</math>, is used as the conditioning set. The training objective is the negative conditional log probability<br />
<math display="block">\mathcal{L}(\theta)=-\mathbb{E}_{f \sim P}\left[\mathbb{E}_{N}\left[\log Q_\theta(\{y_i\}_{i = 0} ^{n-1}|O_{N}, \{x_i\}_{i = 0} ^{n-1})\right]\right]</math><br />
Thus, the targets on which <math display="inline">Q_\theta</math> is scored include both the observed and unobserved values. In practice, Monte Carlo estimates of the gradient of this loss are obtained by sampling <math display="inline">f</math> and <math display="inline">N</math>.<br />
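<br />
A Monte Carlo estimate of <math>\mathcal{L}(\theta)</math> can be sketched as follows. Everything model-specific is an assumption: <code>sample_function</code> is a hypothetical prior <math>P</math> over sine curves, <code>toy_predictor</code> is a placeholder for <math>Q_\theta</math> (a real CNP would use the encoder, aggregator and decoder defined above), and the factored Gaussian form of <math>Q_\theta</math> turns the log probability of all <math>n</math> targets into a sum. Only the loss value is estimated here; gradients would come from an autodiff framework.<br />
<br />
```python
import math
import random

def gaussian_log_prob(y, mean, std):
    """log N(y; mean, std^2)."""
    return -0.5 * math.log(2 * math.pi * std ** 2) - (y - mean) ** 2 / (2 * std ** 2)

def sample_function(rng):
    """Hypothetical prior P: sine curves with random amplitude and phase."""
    amp, phase = rng.uniform(0.5, 2.0), rng.uniform(0.0, math.pi)
    return lambda x: amp * math.sin(x + phase)

def toy_predictor(context, x):
    """Placeholder for Q_theta(f(x) | O_N, x): context mean with a fixed std of 1."""
    ys = [cy for _, cy in context]
    return sum(ys) / len(ys), 1.0

def cnp_loss_estimate(predict, n, num_samples, rng):
    """Monte Carlo estimate of L(theta) = -E_f E_N [log Q(y_0..y_{n-1} | O_N, x_0..x_{n-1})]."""
    total = 0.0
    for _ in range(num_samples):
        f = sample_function(rng)                          # f ~ P
        xs = [rng.uniform(-3.0, 3.0) for _ in range(n)]
        observations = [(x, f(x)) for x in xs]            # O = {(x_i, y_i)}, i = 0..n-1
        N = rng.randrange(n)                              # N ~ uniform{0, ..., n-1}
        context = observations[: N + 1]                   # O_N = {(x_i, y_i)}, i = 0..N
        # Factored Q: the joint log-prob is a sum over all n points,
        # so both observed (in O_N) and unobserved values are scored.
        log_q = sum(gaussian_log_prob(y, *predict(context, x)) for x, y in observations)
        total -= log_q
    return total / num_samples

rng = random.Random(0)
print(cnp_loss_estimate(toy_predictor, n=10, num_samples=100, rng=rng))
```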
<br />
This approach shifts the burden of imposing prior knowledge from an analytic prior to empirical data. This has the advantage of freeing practitioners from having to specify an analytic form for the prior, which is ultimately intended to summarize their empirical experience. Still, the authors emphasize that the <math display="inline">Q_\theta</math> are not necessarily a consistent set of conditionals for all observation sets, and the training routine does not guarantee that they are.<br />
<br />
In summary,<br />
<br />
1. A CNP is a conditional distribution over functions<br />
trained to model the empirical conditional distributions<br />
of functions <math display="inline">f \sim P</math>.<br />
<br />
2. A CNP is permutation invariant in <math display="inline">O</math> and <math display="inline">T</math>.<br />
<br />
3. A CNP is scalable, achieving a running time complexity<br />
of <math display="inline">O(n + m)</math> for making <math display="inline">m</math> predictions with <math display="inline">n</math><br />
observations.<br />
<br />
== Related Work ==<br />
<br />
===Gaussian Process Framework===<br />
<br />
A Gaussian Process (GP) is a non-parametric method used extensively for regression and classification problems in the machine learning community. A GP is defined as a collection of random variables, any finite number of which have a joint Gaussian distribution.<br />
A standard approach is to model data as <math>y = m(X, \phi) + \epsilon</math>,<br />
where <math>m</math> is the mean function with parameter vector <math>\phi</math>, and <math>\epsilon</math> represents independent and identically distributed (i.i.d.) Gaussian noise, <math>\epsilon \sim \mathcal{N}(0,\sigma^2)</math>.<br />
<br />
For more info on Gaussian Process Framework:<br />
[https://arxiv.org/abs/1506.07304 A Gaussian process framework for modeling instrumental systematics: application to transmission spectroscopy]<br />
<br />
Several papers attempt to address various issues with GPs. These include:<br />
* Using sparse GPs to aid in scaling (Snelson & Ghahramani, 2006)<br />
* Using Deep GPs to achieve more expressiveness (Damianou & Lawrence, 2013; Salimbeni & Deisenroth, 2017)<br />
* Using neural networks to learn more expressive kernels (Wilson et al., 2016)<br />
<br />
A Python resource for Gaussian Process implementation: [https://github.com/SheffieldML/GPy Gaussian Process Framework in Python]<br />
<br />
<br />
The goal of this paper is to combine ideas from standard neural networks and Gaussian processes in order to overcome the drawbacks of both. Bayesian techniques work better with less data, but complex Bayesian models become intractable on even moderately sized datasets. Neural networks, on the other hand, cannot make use of prior knowledge, often have to be retrained from scratch, and perform poorly without sufficient data. Combining both frameworks yields Conditional Neural Processes, which use neural networks to learn, in effect, the role played by the kernel of a Gaussian Process, and apply what they learn in a GP-like framework for prediction.<br />
<br />
===Meta Learning===<br />
<br />
Meta-learning attempts to allow neural networks to learn more generalizable functions, as opposed to approximating only one function. One way to do this is to learn deep generative models capable of few-shot estimation of data, which can be implemented with attention mechanisms or additional memory.<br />
<br />
Classification is another common task in meta-learning. Few-shot classification algorithms usually rely on some distance metric in feature space to compare target images with the observations. Matching networks (Vinyals et al., 2016; Bartunov & Vetrov, 2016) are closely related to CNPs.<br />
<br />
Finally, the latent variant of Conditional Neural Processes can also be seen as an approximate, amortized version of Bayesian deep learning (Gal & Ghahramani, 2016; Blundell et al., 2015; Louizos et al., 2017; Louizos & Welling, 2017). For example, Gal & Ghahramani (2016) develop a theoretical framework casting dropout training in deep neural networks as approximate Bayesian inference in deep Gaussian processes. Their theory extracts information from existing models and provides tools to model uncertainty.<br />
<br />
== Experimental Result I: Function Regression ==<br />
<br />
The first example is a classical 1D regression task commonly used as a baseline for GPs.<br />
The authors generated two different datasets consisting of functions generated from a GP with an exponential kernel. In the first dataset they used a kernel with fixed parameters; in the second, the function switched, at some random point on the real line, between two functions, each sampled with different kernel parameters. At every training step, they sampled a curve from the GP and selected a subset of n points as observations and a subset of t points as target points. The observed points are encoded using a three-layer MLP encoder h with a 128-dimensional output representation. The representations are aggregated into a single representation<br />
<math display="inline">r = \frac{1}{n} \sum_{i} r_i</math><br />
, which is concatenated to <math display="inline">x_t</math> and passed to a decoder g consisting of a five-layer MLP. The decoder outputs a Gaussian mean and variance for the target outputs. The model is trained to maximize the log-likelihood of the target points using the Adam optimizer.<br />
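<br />
The data pipeline described above can be sketched as follows: sample a curve from a GP prior at random inputs, then split it into context and target points. This is a sketch under stated assumptions: a squared-exponential kernel is assumed here, and its parameters, the input range, the point counts and the helper names are all illustrative choices rather than the paper's settings. Targets include the context points, in line with the training loss described earlier. The Cholesky factorization is the <math>O(n^3)</math> step that makes exact GPs expensive.<br />
<br />
```python
import math
import random

def sq_exp_kernel(x1, x2, length_scale=0.4, signal_var=1.0):
    """Squared-exponential kernel; form and parameters are assumptions for illustration."""
    return signal_var * math.exp(-((x1 - x2) ** 2) / (2 * length_scale ** 2))

def cholesky(a):
    """Lower-triangular L with L L^T = a, for a symmetric positive-definite matrix a."""
    n = len(a)
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = sum(L[i][k] * L[j][k] for k in range(j))
            if i == j:
                L[i][j] = math.sqrt(a[i][i] - s)
            else:
                L[i][j] = (a[i][j] - s) / L[j][j]
    return L

def sample_gp_curve(xs, kernel, rng, jitter=1e-4):
    """Draw f(xs) ~ N(0, K) as L z with z ~ N(0, I); the O(n^3) Cholesky dominates."""
    n = len(xs)
    K = [[kernel(xs[i], xs[j]) + (jitter if i == j else 0.0) for j in range(n)]
         for i in range(n)]
    L = cholesky(K)
    z = [rng.gauss(0.0, 1.0) for _ in range(n)]
    return [sum(L[i][k] * z[k] for k in range(i + 1)) for i in range(n)]

def training_example(rng, num_points=50, max_context=10):
    """One training step's data (hypothetical helper): a GP curve split into context/targets."""
    xs = [rng.uniform(-2.0, 2.0) for _ in range(num_points)]
    ys = sample_gp_curve(xs, sq_exp_kernel, rng)
    order = list(range(num_points))
    rng.shuffle(order)
    n_context = rng.randint(1, max_context)
    context = [(xs[i], ys[i]) for i in order[:n_context]]
    targets = [(xs[i], ys[i]) for i in order]  # targets also contain the context points
    return context, targets

context, targets = training_example(random.Random(0))
print(len(context), len(targets))
```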
<br />
Two examples of the regression results obtained for each<br />
of the datasets are shown in the following figure.<br />
<br />
[[File:007.jpg|300px|center]]<br />
<br />
They compared the model to the predictions generated by a GP with the correct hyperparameters, which constitutes an upper bound on the CNP's performance. Although the prediction generated by the GP is smoother than the CNP's prediction, both for the mean and for the variance, the model is able to learn to regress from a few context points for both the fixed kernels and switching kernels. As the number of context points grows, the accuracy of the model improves and its approximated uncertainty decreases. Crucially, the model learns to estimate its own uncertainty given the observations very accurately. Overall, it provides a good approximation that increases in accuracy as the number of context points increases. Furthermore, the model achieves similarly good performance on the switching kernel task. This type of regression task is not trivial for GPs, whereas for the CNP only the dataset used for training has to change.<br />
<br />
== Experimental Result II: Image Completion for Digits ==<br />
<br />
[[File:002.jpg|600px|center]]<br />
<br />
They also tested the CNP on the MNIST dataset, using the test set to evaluate its performance. As shown in the above figure, the model learns to make good predictions of the underlying digit even for a small number of context points. Crucially, when conditioned on only one non-informative context point, the model's prediction corresponds to the average over all MNIST digits. As the number of context points increases, the predictions become more similar to the underlying ground truth. This demonstrates the model's capacity to extract dataset-specific prior knowledge. It is worth mentioning that even with a complete set of observations, the model does not achieve pixel-perfect reconstruction, as there is a bottleneck at the representation level.<br />
Since this implementation of CNP returns factored outputs,<br />
the best prediction it can produce given limited context<br />
information is to average over all possible predictions that<br />
agree with the context. An alternative to this is to add<br />
latent variables in the model such that they can be sampled<br />
conditioned on the context to produce predictions with high<br />
probability in the data distribution. <br />
<br />
<br />
An important aspect of the model is its ability to estimate<br />
the uncertainty of the prediction. As shown in the bottom<br />
row of the above figure, as they added more observations, the variance<br />
shifts from being almost uniformly spread over the digit<br />
positions to being localized around areas that are specific<br />
to the underlying digit, specifically its edges. Being able to<br />
model the uncertainty given some context can be helpful for<br />
many tasks. One example is active exploration, where the<br />
model has a choice over where to observe.<br />
They tested this by<br />
comparing the predictions of CNP when the observations<br />
are chosen according to uncertainty (i.e. the pixel with the highest predicted variance at each step) versus at random. This method is a very simple way of doing active<br />
exploration, but it already produces better prediction results<br />
than selecting the conditioning points at random.<br />
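The uncertainty-driven selection described above can be sketched as a greedy loop: at each step, observe the unobserved pixel where the predicted variance is largest. This is a minimal NumPy sketch, not the authors' code; `predict_variance` is a hypothetical stand-in for a trained CNP, and `toy_variance` (squared distance to the nearest observed pixel) is purely an illustrative proxy for model uncertainty.<br />

```python
import numpy as np


def select_next_pixel(variance, observed):
    """Return the (row, col) of the unobserved pixel with the largest predicted variance."""
    masked = np.where(observed, -np.inf, variance)  # never re-select an observed pixel
    row, col = np.unravel_index(np.argmax(masked), variance.shape)
    return int(row), int(col)


def active_context(predict_variance, shape, num_points):
    """Greedily build a context set of `num_points` pixels, most uncertain first.

    `predict_variance(observed)` is a hypothetical stand-in for a trained CNP: it maps a
    boolean mask of already-observed pixels to a per-pixel predicted variance map.
    """
    observed = np.zeros(shape, dtype=bool)
    order = []
    for _ in range(num_points):
        row, col = select_next_pixel(predict_variance(observed), observed)
        observed[row, col] = True
        order.append((row, col))
    return observed, order


def toy_variance(observed):
    """Illustrative proxy only: squared distance to the nearest observed pixel."""
    if not observed.any():
        return np.ones(observed.shape)
    rows, cols = np.indices(observed.shape)
    pts = np.argwhere(observed)
    d2 = (rows[..., None] - pts[:, 0]) ** 2 + (cols[..., None] - pts[:, 1]) ** 2
    return d2.min(axis=-1).astype(float)


mask, order = active_context(toy_variance, (28, 28), num_points=5)
```

Random selection would simply replace `select_next_pixel` with a uniformly random choice among unobserved pixels, which is the baseline the comparison above refers to.<br />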
<br />
== Experimental Result III: Image Completion for Faces ==<br />
<br />
<br />
[[File:003.jpg|400px|center]]<br />
<br />
<br />
They also applied CNP to CelebA, a dataset of images of<br />
celebrity faces, and reported performance obtained on the<br />
test set.<br />
<br />
As shown in the above figure, the model is able to capture<br />
the complex shapes and colors of this dataset, with predictions<br />
conditioned on fewer than 10% of the pixels being<br />
already close to the ground truth. As before, given a few context<br />
points the model averages over all possible faces, but as<br />
the number of context pairs increases, the predictions capture<br />
image-specific details like face orientation and facial<br />
expression. Furthermore, as the number of context points<br />
increases, the variance is shifted towards the edges in the<br />
image.<br />
<br />
[[File:004.jpg|400px|center]]<br />
<br />
An important aspect of CNPs demonstrated in the above figure is<br />
their flexibility not only in the number of observations and<br />
targets they receive but also with regard to their input values.<br />
It is interesting to compare this property to GPs on one hand,<br />
and to trained generative models (van den Oord et al., 2016;<br />
Gregor et al., 2015) on the other hand.<br />
The first type of flexibility can be seen when conditioning on<br />
subsets that the model has not encountered during training.<br />
Consider conditioning the model on one half of the image,<br />
for example. This forces the model to predict the pixel<br />
values not only according to some stationary smoothness property of<br />
the images, but also according to global spatial properties,<br />
e.g. symmetry and the relative location of different parts of<br />
faces. As seen in the first row of the figure, CNPs are able to<br />
capture those properties. A GP with a stationary kernel cannot<br />
capture this, and in the absence of observations would<br />
revert to its mean (the mean itself can be non-stationary, but<br />
usually this would not be enough to capture the interesting<br />
properties).<br />
<br />
In addition, the model is flexible with regard to the target<br />
input values. This means, e.g., that the model can be queried<br />
at resolutions it has not seen during training. The authors take a<br />
model that has only been trained using pixel coordinates of<br />
a specific resolution and predict, at test time, subpixel values<br />
for targets between the original coordinates. As shown in<br />
Figure 5, with one forward pass the model can be queried at<br />
different resolutions. While GPs also exhibit this type of<br />
flexibility, it is not the case for trained generative models,<br />
which can only predict values for the pixel coordinates on<br />
which they were trained. In this sense, CNPs capture the best<br />
of both worlds: they are flexible with regard to the conditioning<br />
and prediction task and have the capacity to extract domain<br />
knowledge from a training set.<br />
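Because the decoder consumes each target coordinate independently, querying at a new resolution only changes the set of target inputs. The sketch below illustrates this under two assumptions not stated in the summary: pixel coordinates are rescaled to [0, 1] so that every resolution lives in the same input space, and `toy_decoder` is a hypothetical stand-in for a trained decoder g.<br />

```python
import numpy as np


def pixel_grid(height, width):
    """Target inputs for every pixel, scaled to [0, 1] so all resolutions share one space."""
    rr, cc = np.meshgrid(np.linspace(0.0, 1.0, height), np.linspace(0.0, 1.0, width), indexing="ij")
    return np.stack([rr.ravel(), cc.ravel()], axis=-1)  # shape (height * width, 2)


def query(decoder, r, height, width):
    """Evaluate a pointwise decoder g(x, r) on an arbitrary grid in one batched call."""
    x = pixel_grid(height, width)
    z = np.concatenate([x, np.tile(r, (len(x), 1))], axis=-1)
    mean, var = decoder(z)
    return mean.reshape(height, width), var.reshape(height, width)


def toy_decoder(z):
    """Hypothetical stand-in for a trained decoder: mean is a smooth function of (row, col)."""
    return z[:, 0] + z[:, 1], np.full(len(z), 0.1)


r = np.zeros(8)                              # a context representation (illustrative)
mean_32, _ = query(toy_decoder, r, 32, 32)   # training resolution
mean_64, _ = query(toy_decoder, r, 64, 64)   # 2x resolution: same model, same r
```

The same representation r is reused for both grids; only the number of decoder evaluations changes, so the cost stays linear in the number of queried pixels.<br />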
<br />
[[File:010.jpg|400px|center]]<br />
<br />
<br />
They compared CNPs quantitatively to two related models:<br />
kNNs and GPs. As shown in the above table, CNPs outperform<br />
these baselines when the number of context points is small (empirically,<br />
when half of the image or less is provided as context).<br />
When the majority of the image is given as context, exact<br />
methods like GPs and kNN perform better. From the table<br />
we can also see that the order in which the context points<br />
are provided matters less for CNPs, since providing the<br />
context points in order from top to bottom still results in<br />
good performance. Both insights point to the fact that CNPs<br />
learn a data-specific ‘prior’ that will generate good samples<br />
even when the number of context points is very small.<br />
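For reference, a kNN baseline for image completion can be as simple as averaging the values of the nearest observed pixels. This is one such baseline, sketched under assumptions: Euclidean distance on pixel coordinates and a hypothetical default of k = 3. The paper's exact kNN configuration is not described in this summary.<br />

```python
import numpy as np


def knn_complete(context_xy, context_values, target_xy, k=3):
    """Predict each target pixel as the mean value of its k nearest context pixels."""
    # Pairwise squared distances, shape (num_targets, num_context)
    d2 = ((target_xy[:, None, :] - context_xy[None, :, :]) ** 2).sum(axis=-1)
    k = min(k, len(context_xy))
    nearest = np.argsort(d2, axis=1)[:, :k]   # indices of the k closest context pixels
    return context_values[nearest].mean(axis=1)


ctx = np.array([[0.0, 0.0], [0.0, 1.0], [5.0, 5.0]])   # observed pixel coordinates
vals = np.array([1.0, 3.0, 10.0])                      # observed grayscale values
pred = knn_complete(ctx, vals, np.array([[0.0, 0.4]]), k=2)
```

Unlike a CNP, this baseline has no learned prior: with very few context pixels it can only copy nearby values, which is consistent with the gap reported for small context sizes.<br />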
<br />
== Experimental Result IV: Classification ==<br />
Finally, they applied the model to one-shot classification using the Omniglot dataset. This dataset consists of 1,623 classes of characters from 50 different alphabets. Each class has only 20 examples, and as such this dataset is particularly suitable for few-shot learning algorithms. The authors used 1,200 randomly selected classes as their training set and the remainder as the test set.<br />
<br />
Additionally, for data augmentation the authors cropped the images from 32 × 32 to 28 × 28, applied small random<br />
translations and rotations to the inputs, and also increased<br />
the number of classes by rotating every character by 90<br />
degrees and defining that to be a new class. They generated<br />
the labels for an N-way classification task by choosing N<br />
random classes at each training step and arbitrarily assigning<br />
the labels 0, ..., N − 1 to each.<br />
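The class augmentation and episode construction above can be sketched in a few lines of NumPy. This is an illustrative sketch, not the authors' pipeline: it assumes every multiple of 90 degrees defines a new class (the text only says "by 90 degrees"), omits cropping and the small random translations and rotations, and uses constant toy images instead of Omniglot.<br />

```python
import numpy as np


def augment_with_rotations(images_by_class):
    """Treat each 90-degree rotation of a character class as a new class (4x the classes)."""
    augmented = []
    for images in images_by_class:           # images: array of shape (num_examples, H, W)
        for k in range(4):                   # rotations by 0, 90, 180 and 270 degrees
            augmented.append(np.rot90(images, k=k, axes=(1, 2)))
    return augmented


def sample_episode(classes, n_way, rng):
    """Pick N random classes and arbitrarily assign them the labels 0, ..., N-1."""
    chosen = rng.choice(len(classes), size=n_way, replace=False)
    images, labels = [], []
    for label, class_index in enumerate(chosen):
        class_images = classes[class_index]
        images.append(class_images)
        labels.append(np.full(len(class_images), label))
    return np.concatenate(images), np.concatenate(labels)


rng = np.random.default_rng(0)
base = [np.full((20, 28, 28), c, dtype=float) for c in range(5)]  # 5 toy classes, 20 examples each
classes = augment_with_rotations(base)                            # 20 classes after rotation
images, labels = sample_episode(classes, n_way=5, rng=rng)
```

Because labels are reassigned at every step, the model cannot memorize a fixed class-to-label mapping and must rely on the context set instead.<br />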
<br />
<br />
[[File:008.jpg|400px|center]]<br />
<br />
Given that the input points are images, they modified the architecture<br />
of the encoder h to include convolution layers, as<br />
described in Section 2 of the paper. In addition, they only aggregated over<br />
inputs of the same class by using the information provided<br />
by the input label. The aggregated class-specific representations<br />
are then concatenated to form the final representation.<br />
Given that both the size of the class-specific representations<br />
and the number of classes are constant, the size of the final<br />
representation is still constant and thus the O(n + m)<br />
runtime still holds.<br />
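The class-conditioned aggregation can be illustrated as follows. This is a minimal sketch assuming mean aggregation within each class (as in the regression experiments) and a zero vector for classes without examples; the latter is an assumption, not something the summary specifies.<br />

```python
import numpy as np


def class_representation(embeddings, labels, n_way):
    """Average encoder outputs within each class, then concatenate the per-class means.

    embeddings: (num_context, d) outputs of the encoder h; labels: ints in [0, n_way).
    The result has fixed size n_way * d regardless of how many examples each class has.
    """
    d = embeddings.shape[1]
    per_class = np.zeros((n_way, d))
    for c in range(n_way):
        members = embeddings[labels == c]
        if len(members):                      # assumption: empty classes keep a zero vector
            per_class[c] = members.mean(axis=0)
    return per_class.reshape(-1)              # concatenation of class-specific representations


emb = np.array([[1.0, 0.0], [3.0, 0.0], [0.0, 2.0]])
r = class_representation(emb, np.array([0, 0, 1]), n_way=3)
```

Each context example is touched once, and the output size depends only on N and d, which is why the O(n + m) runtime is preserved.<br />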
The classification results are summarized in the accompanying table.<br />
CNPs achieve higher accuracy than models that are significantly<br />
more complex (like MANN). While CNPs do not<br />
beat the state of the art for one-shot classification, their accuracy<br />
values are comparable. Crucially, they reached those values<br />
using a significantly simpler architecture (three convolutional<br />
layers for the encoder and a three-layer MLP for the<br />
decoder) and with a lower runtime of O(n + m) at test time,<br />
as opposed to O(nm).<br />
<br />
== Conclusion ==<br />
<br />
The paper introduced Conditional Neural Processes,<br />
a model that is both flexible at test time and has the<br />
capacity to extract prior knowledge from training data.<br />
<br />
The authors demonstrated its ability to perform a variety of tasks<br />
including regression, classification and image completion.<br />
The paper compared CNPs to Gaussian Processes on one hand, and<br />
deep learning methods on the other, and also discussed the<br />
relation to meta-learning and few-shot learning.<br />
It is important to note that the specific CNP implementations<br />
described here are just simple proofs-of-concept and can<br />
be substantially extended, e.g. by including more elaborate<br />
architectures in line with modern deep learning advances.<br />
To summarize, this work can be seen as a step towards learning<br />
high-level abstractions, one of the grand challenges of<br />
contemporary machine learning. Functions learned by most<br />
conventional deep learning models are tied to a specific, constrained<br />
statistical context at any stage of training. A trained<br />
CNP is more general, in that it encapsulates the high-level<br />
statistics of a family of functions. As such it constitutes a<br />
high-level abstraction that can be reused for multiple tasks.<br />
In future work, the authors plan to explore how far these models can<br />
help in tackling the many key machine learning problems<br />
that seem to hinge on abstraction, such as transfer learning,<br />
meta-learning, and data efficiency.<br />
<br />
== Critiques ==<br />
<br />
This paper introduces a method for reducing the computational complexity of the better-known Gaussian Process model, but the stated complexity of O(n + m) is almost the same order as that of an RBF-kernel GP. With respect to performance on a sequence of tasks, the authors have not made metric comparisons to GP methods to demonstrate the superiority of their approach.<br />
<br />
It appears that the proposed model is effective at making accurate predictions from lower-quality inputs, for example a dataset with fewer data points or an image with fewer pixels. However, it is not clear whether the proposed algorithm can be trained with a smaller amount of input data.<br />
<br />
== Other Sources ==<br />
# Code for this model and a simpler explanation can be found at [https://github.com/deepmind/conditional-neural-process]<br />
# A newer version of the model is described in this paper [https://arxiv.org/pdf/1807.01622.pdf]<br />
# A good blog post on neural processes [https://kasparmartens.rbind.io/post/np/]<br />
<br />
== Reference ==<br />
Bartunov, S. and Vetrov, D. P. Fast adaptation in generative<br />
models with generative matching networks. arXiv<br />
preprint arXiv:1612.02192, 2016.<br />
<br />
Blundell, C., Cornebise, J., Kavukcuoglu, K., and Wierstra,<br />
D. Weight uncertainty in neural networks. arXiv preprint<br />
arXiv:1505.05424, 2015.<br />
<br />
Bornschein, J., Mnih, A., Zoran, D., and Rezende, D. J.<br />
Variational memory addressing in generative models. In<br />
Advances in Neural Information Processing Systems, pp.<br />
3923–3932, 2017.<br />
<br />
Damianou, A. and Lawrence, N. Deep Gaussian processes.<br />
In Artificial Intelligence and Statistics, pp. 207–215,<br />
2013.<br />
<br />
Devlin, J., Bunel, R. R., Singh, R., Hausknecht, M., and<br />
Kohli, P. Neural program meta-induction. In Advances in<br />
Neural Information Processing Systems, pp. 2077–2085,<br />
2017.<br />
<br />
Edwards, H. and Storkey, A. Towards a neural statistician.<br />
2016.<br />
<br />
Finn, C., Abbeel, P., and Levine, S. Model-agnostic meta-learning<br />
for fast adaptation of deep networks. arXiv<br />
preprint arXiv:1703.03400, 2017.<br />
<br />
Gal, Y. and Ghahramani, Z. Dropout as a Bayesian approximation:<br />
Representing model uncertainty in deep learning.<br />
In International Conference on Machine Learning, pp.<br />
1050–1059, 2016.<br />
<br />
Garnelo, M., Arulkumaran, K., and Shanahan, M. Towards<br />
deep symbolic reinforcement learning. arXiv preprint<br />
arXiv:1609.05518, 2016.<br />
<br />
Gregor, K., Danihelka, I., Graves, A., Rezende, D. J., and<br />
Wierstra, D. DRAW: A recurrent neural network for image<br />
generation. arXiv preprint arXiv:1502.04623, 2015.<br />
<br />
Hewitt, L., Gane, A., Jaakkola, T., and Tenenbaum, J. B. The<br />
variational homoencoder: Learning to infer high-capacity<br />
generative models from few examples. 2018.<br />
<br />
Rezende, D. J., Danihelka, I., Gregor, K., Wierstra, D.,<br />
et al. One-shot generalization in deep generative models.<br />
In International Conference on Machine Learning, pp.<br />
1521–1529, 2016.<br />
<br />
Kingma, D. P. and Ba, J. Adam: A method for stochastic<br />
optimization. arXiv preprint arXiv:1412.6980, 2014.<br />
<br />
Kingma, D. P. and Welling, M. Auto-encoding variational<br />
Bayes. arXiv preprint arXiv:1312.6114, 2013.<br />
<br />
Koch, G., Zemel, R., and Salakhutdinov, R. Siamese neural<br />
networks for one-shot image recognition. In ICML Deep<br />
Learning Workshop, volume 2, 2015.<br />
<br />
Lake, B. M., Salakhutdinov, R., and Tenenbaum, J. B.<br />
Human-level concept learning through probabilistic program<br />
induction. Science, 350(6266):1332–1338, 2015.<br />
<br />
Lake, B. M., Ullman, T. D., Tenenbaum, J. B., and Gershman,<br />
S. J. Building machines that learn and think like<br />
people. Behavioral and Brain Sciences, 40, 2017.<br />
<br />
LeCun, Y., Bottou, L., Bengio, Y., and Haffner, P. Gradient-based<br />
learning applied to document recognition. Proceedings<br />
of the IEEE, 86(11):2278–2324, 1998.<br />
<br />
Liu, Z., Luo, P., Wang, X., and Tang, X. Deep learning face<br />
attributes in the wild. In Proceedings of International<br />
Conference on Computer Vision (ICCV), December 2015.<br />
<br />
Louizos, C. and Welling, M. Multiplicative normalizing<br />
flows for variational Bayesian neural networks. arXiv<br />
preprint arXiv:1703.01961, 2017.<br />
<br />
Louizos, C., Ullrich, K., and Welling, M. Bayesian compression<br />
for deep learning. In Advances in Neural Information<br />
Processing Systems, pp. 3290–3300, 2017.<br />
<br />
Rasmussen, C. E. and Williams, C. K. Gaussian processes<br />
in machine learning. In Advanced lectures on machine<br />
learning, pp. 63–71. Springer, 2004.<br />
<br />
Reed, S., Chen, Y., Paine, T., Oord, A. v. d., Eslami, S.,<br />
Rezende, D. J., Vinyals, O., and de Freitas, N. Few-shot<br />
autoregressive density estimation: Towards learning to<br />
learn distributions. 2017.<br />
<br />
Rezende, D. J., Mohamed, S., and Wierstra, D. Stochastic<br />
backpropagation and approximate inference in deep generative<br />
models. arXiv preprint arXiv:1401.4082, 2014.<br />
<br />
Salimbeni, H. and Deisenroth, M. Doubly stochastic variational<br />
inference for deep Gaussian processes. In Advances<br />
in Neural Information Processing Systems, pp.<br />
4591–4602, 2017.<br />
<br />
Santoro, A., Bartunov, S., Botvinick, M., Wierstra, D., and<br />
Lillicrap, T. One-shot learning with memory-augmented<br />
neural networks. arXiv preprint arXiv:1605.06065, 2016.<br />
<br />
Snell, J., Swersky, K., and Zemel, R. Prototypical networks<br />
for few-shot learning. In Advances in Neural Information<br />
Processing Systems, pp. 4080–4090, 2017.<br />
<br />
Snelson, E. and Ghahramani, Z. Sparse Gaussian processes<br />
using pseudo-inputs. In Advances in Neural Information<br />
Processing Systems, pp. 1257–1264, 2006.<br />
<br />
van den Oord, A., Kalchbrenner, N., Espeholt, L., Vinyals,<br />
O., Graves, A., et al. Conditional image generation with<br />
PixelCNN decoders. In Advances in Neural Information<br />
Processing Systems, pp. 4790–4798, 2016.<br />
<br />
Vinyals, O., Blundell, C., Lillicrap, T., Wierstra, D., et al.<br />
Matching networks for one shot learning. In Advances in<br />
Neural Information Processing Systems, pp. 3630–3638,<br />
2016.<br />
<br />
Wang, J. X., Kurth-Nelson, Z., Tirumala, D., Soyer, H.,<br />
Leibo, J. Z., Munos, R., Blundell, C., Kumaran, D., and<br />
Botvinick, M. Learning to reinforcement learn. arXiv<br />
preprint arXiv:1611.05763, 2016.<br />
<br />
Wilson, A. G., Hu, Z., Salakhutdinov, R., and Xing, E. P.<br />
Deep kernel learning. In Artificial Intelligence and Statistics,<br />
pp. 370–378, 2016.<br />
</div>Gsahuhttp://wiki.math.uwaterloo.ca/statwiki/index.php?title=conditional_neural_process&diff=41995conditional neural process2018-11-30T01:02:57Z<p>Gsahu: /* Conclusion */</p>
<hr />
<div>== Motivation ==<br />
<br />
Deep neural networks are good at function approximation, yet they are typically trained from scratch for each new function. Bayesian methods, such as Gaussian Processes (GPs), instead exploit prior knowledge to quickly infer the shape of a new function at test time. Yet GPs<br />
are computationally expensive, and it can be hard to design appropriate priors. Hence the authors propose a family of neural models called Conditional Neural Processes (CNPs), which combine the benefits of both. <br />
<br />
== Introduction ==<br />
<br />
To train a model effectively, deep neural networks typically require large datasets. To mitigate this data efficiency problem, learning in two phases is one approach: the first phase learns the statistics of a generic domain without committing to a specific learning task; the second phase learns a function for a specific task but does so using only a small number of data points by exploiting the domain-wide statistics already learned. Taking a probabilistic stance and specifying a distribution over functions (stochastic processes) is another approach -- Gaussian Processes being a commonly used example of this. Such Bayesian methods can be computationally expensive. <br />
<br />
The authors of the paper propose a family of models that represent solutions to the supervised problem, and an end-to-end training approach to learning them that combines neural networks with features reminiscent of Gaussian Processes. They call this family of models Conditional Neural Processes (CNPs). CNPs can make accurate predictions after observing only a few data points, while they also have the capacity to scale to complex functions and large datasets.<br />
<br />
== Model ==<br />
Consider a data set <math display="inline"> \{x_i, y_i\} </math> with evaluations <math display="inline">y_i = f(x_i) </math> for some unknown function <math display="inline">f</math>. Assume <math display="inline">g</math> is an approximating function of f. The aim is to minimize the loss between <math display="inline">f</math> and <math display="inline">g</math> on the entire space <math display="inline">X</math>. In practice, the routine is evaluated on a finite set of observations.<br />
<br />
<br />
Let the set of observations be <math display="inline"> O = \{(x_i, y_i)\}_{i = 0} ^{n-1}</math>, and the target set be <math display="inline"> T = \{x_i\}_{i = n} ^ {n + m - 1} \subset X</math> of unlabelled points.<br />
<br />
Let P be a probability distribution over functions <math display="inline"> f : X \to Y</math>, formally known as a stochastic process. Thus, P defines a joint distribution over the random variables <math display="inline"> \{f(x_i)\}_{i = 0} ^{n + m - 1}</math>. Our task is then to predict the output values <math display="inline">f(x_i)</math> for <math display="inline"> x_i \in T</math> given <math display="inline"> O</math>, i.e. to evaluate the conditional distribution <math display="inline"> P(f(T) \mid O, T)</math>. <br />
<br />
A common assumption made on P is that all function evaluations of <math display="inline"> f </math> are jointly Gaussian distributed; this class of random functions is called Gaussian Processes (GPs). Such stochastic processes allow a model to be data efficient; however, it is hard to design appropriate priors, and stochastic processes are computationally expensive, scaling poorly with <math>n</math> and <math>m</math>. For example, exact inference in a GP has running time <math>O((n+m)^3)</math>.<br />
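To make the cost concrete, here is a minimal NumPy sketch of exact GP regression. It assumes a squared-exponential (RBF) kernel with hypothetical default hyperparameters and a small noise term. The Cholesky factorization of the n × n kernel matrix is the cubic step; handling the full joint covariance over the m targets as well leads to the O((n+m)^3) bound quoted above. The sketch returns only marginal variances.<br />

```python
import numpy as np


def rbf_kernel(a, b, length_scale=1.0, variance=1.0):
    """Squared-exponential kernel matrix between 1D input arrays a and b."""
    return variance * np.exp(-0.5 * (a[:, None] - b[None, :]) ** 2 / length_scale ** 2)


def gp_predict(x_obs, y_obs, x_target, noise=1e-4, **kernel_args):
    """Exact GP posterior mean and marginal variance at x_target given (x_obs, y_obs)."""
    K = rbf_kernel(x_obs, x_obs, **kernel_args) + noise * np.eye(len(x_obs))
    K_s = rbf_kernel(x_obs, x_target, **kernel_args)         # shape (n, m)
    K_ss = rbf_kernel(x_target, x_target, **kernel_args)     # shape (m, m)
    L = np.linalg.cholesky(K)                                # O(n^3): the expensive step
    alpha = np.linalg.solve(L.T, np.linalg.solve(L, y_obs))  # (K + noise I)^{-1} y
    v = np.linalg.solve(L, K_s)
    mean = K_s.T @ alpha
    var = np.diag(K_ss) - np.sum(v ** 2, axis=0)
    return mean, var


x_obs = np.array([-1.0, 0.0, 1.0])
y_obs = np.sin(x_obs)
mean_obs, var_obs = gp_predict(x_obs, y_obs, x_obs)               # near the data
mean_far, var_far = gp_predict(x_obs, y_obs, np.array([10.0]))    # far from the data
```

Far from the observations, the posterior reverts to the prior mean (0) and prior variance (1), which is the "revert to its mean" behaviour discussed in the image-completion experiments.<br />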
<br />
[[File:001.jpg|300px|center]]<br />
<br />
== Conditional Neural Process ==<br />
<br />
Conditional Neural Process models directly parametrize conditional stochastic processes without imposing consistency with respect to some prior process. CNPs parametrize distributions over <math display="inline">f(T)</math> given a distributed representation of <math display="inline">O</math> of fixed dimensionality. Thus, the mathematical guarantees associated with stochastic processes are traded off for functional flexibility and scalability.<br />
<br />
A CNP is a conditional stochastic process <math display="inline">Q_\theta</math> that defines distributions over <math display="inline">f(x_i)</math> for <math display="inline">x_i \in T</math>, given a set of observations <math display="inline">O</math>. As with stochastic processes, the authors assume that <math display="inline">Q_{\theta}</math> is invariant to permutations, i.e. <math display="inline">Q_\theta(f(T) | O, T)= Q_\theta(f(T') | O, T')=Q_\theta(f(T) | O', T) </math> when <math> O', T'</math> are permutations of <math display="inline">O</math> and <math display="inline">T </math>. In this work, the authors generally enforce permutation invariance with respect to <math display="inline">T</math> by assuming a factored structure, which is the easiest way to ensure a valid stochastic process. That is, <math display="inline">Q_\theta(f(T) | O, T) = \prod _{x \in T} Q_\theta(f(x) | O, x)</math>. Moreover, this framework can be extended to non-factored distributions.<br />
<br />
In detail, the following architecture is used:<br />
<br />
<math display="inline">r_i = h_\theta(x_i, y_i)</math> for any <math display="inline">(x_i, y_i) \in O</math>, where <math display="inline">h_\theta : X \times Y \to \mathbb{R} ^ d</math><br />
<br />
<math display="inline">r = r_0 * r_1 * \dots * r_{n-1}</math>, where <math display="inline">*</math> is a commutative operation that takes elements in <math display="inline">\mathbb{R}^d</math> and maps them into a single element of <math display="inline">\mathbb{R} ^ d</math><br />
<br />
<math display="inline">\Phi_i = g_\theta(x_i, r)</math> for any <math display="inline">x_i \in T</math>, where <math display="inline">g_\theta : X \times \mathbb{R} ^ d \to \mathbb{R} ^ e</math> and <math display="inline">\Phi_i</math> are parameters for <math display="inline">Q_\theta</math><br />
<br />
Note that this architecture ensures permutation invariance and <math display="inline">O(n + m)</math> scaling for conditional prediction. Moreover, since <math display="inline">r = r_0 * r_1 * \dots * r_{n-1}</math> can be computed in <math display="inline">O(1)</math> from <math display="inline">r_0 * \dots * r_{n-2}</math>, this architecture supports streaming observations with minimal overhead.<br />
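The constant-time update is easy to see for mean aggregation: keep a running sum and a count. This is a minimal sketch assuming mean aggregation (other commutative operations such as a sum or max admit similar updates); the class name `RunningMean` is illustrative.<br />

```python
import numpy as np


class RunningMean:
    """Streaming version of r = mean(r_i): each update costs O(d), independent of n."""

    def __init__(self, dim):
        self.count = 0
        self.total = np.zeros(dim)

    def update(self, r_i):
        self.count += 1
        self.total += r_i
        return self.total / self.count   # current aggregate r after this observation


rng = np.random.default_rng(1)
reps = rng.normal(size=(10, 4))          # r_i for ten observations arriving one at a time
forward, backward = RunningMean(4), RunningMean(4)
for r_i in reps:
    r_stream = forward.update(r_i)
for r_i in reps[::-1]:
    r_reversed = backward.update(r_i)
```

The streamed result matches the batch mean, and feeding the observations in reverse order gives the same r, illustrating both the streaming property and the permutation invariance.<br />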
<br />
We train <math display="inline">Q_\theta</math> by asking it to predict <math display="inline">O</math> conditioned on a randomly<br />
chosen subset of <math display="inline">O</math>. This gives the model a signal of the uncertainty over the space X inherent in the distribution<br />
P given a set of observations. The authors let <math display="inline"> f \sim P</math>, <math display="inline"> O = \{(x_i, y_i)\}_{i = 0} ^{n-1}</math>, and N ~ uniform[0, 1, ..... ,n-1]. Subset <math display="inline"> O = \{(x_i, y_i)\}_{i = 0} ^{N}</math> that is first N elements of <math display="inline">O</math> is regarded as condition. The negative conditional log probability is given by<br />
\[\mathcal{L}(\theta)=-\mathbb{E}_{f \sim p}[\mathbb{E}_{N}[\log Q_\theta(\{y_i\}_{i = 0} ^{n-1}|O_{N}, \{x_i\}_{i = 0} ^{n-1})]]\]<br />
Thus, the targets it scores <math display="inline">Q_\theta</math> on include both the observed <br />
and unobserved values. In practice, Monte Carlo estimates of the gradient of this loss is taken by sampling <math display="inline">f</math> and <math display="inline">N</math>. <br />
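The objective can be estimated by sampling exactly as described: draw a function, draw N, condition on O_N and score all n points. This is a minimal sketch of the loss value only (gradients would come from an autodiff framework in practice); `model` and `sample_function` are hypothetical interfaces, and the factored Gaussian likelihood matches the mean-and-variance decoder used in the experiments.<br />

```python
import numpy as np


def gaussian_nll(y, mean, sigma):
    """Negative log density of y under independent Gaussians N(mean, sigma^2), summed over points."""
    return np.sum(0.5 * np.log(2 * np.pi * sigma ** 2) + 0.5 * ((y - mean) / sigma) ** 2)


def cnp_loss_estimate(model, sample_function, n, rng, num_samples=16):
    """Monte Carlo estimate of L(theta): sample f ~ P and N, condition on O_N, score all n points.

    `model(x_ctx, y_ctx, x_all)` returns per-point (mean, sigma); `sample_function(rng, n)`
    draws the n observations (x, y) of one function f ~ P. Both are hypothetical interfaces.
    """
    total = 0.0
    for _ in range(num_samples):
        x, y = sample_function(rng, n)
        N = rng.integers(0, n)                  # N ~ uniform{0, ..., n-1}
        x_ctx, y_ctx = x[: N + 1], y[: N + 1]   # O_N: observations with index 0, ..., N
        mean, sigma = model(x_ctx, y_ctx, x)    # targets are all n points, observed or not
        total += gaussian_nll(y, mean, sigma)
    return total / num_samples


def toy_model(x_ctx, y_ctx, x_all):
    """Hypothetical model: predicts the context mean everywhere with unit scale."""
    return np.full(len(x_all), y_ctx.mean()), np.ones(len(x_all))


def constant_function(rng, n):
    """Hypothetical sampler for f(x) = 2 evaluated at n random inputs."""
    x = rng.uniform(-1.0, 1.0, n)
    return x, np.full(n, 2.0)


loss = cnp_loss_estimate(toy_model, constant_function, n=10, rng=np.random.default_rng(0))
```

For the constant toy function the prediction is exact, so the loss reduces to the Gaussian normalizing term, 10 × ½ log(2π).<br />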
<br />
This approach shifts the burden of imposing prior knowledge from an analytic prior to empirical data. This has the advantage of liberating a practitioner from having to specify an analytic form for the prior, which is ultimately<br />
intended to summarize their empirical experience. Still, the authors emphasize that the <math display="inline">Q_\theta</math> are not necessarily a consistent set of conditionals for all observation sets, and the training routine does not guarantee that.<br />
<br />
In summary,<br />
<br />
1. A CNP is a conditional distribution over functions<br />
trained to model the empirical conditional distributions<br />
of functions <math display="inline">f \sim P</math>.<br />
<br />
2. A CNP is permutation invariant in <math display="inline">O</math> and <math display="inline">T</math>.<br />
<br />
3. A CNP is scalable, achieving a running time complexity<br />
of <math display="inline">O(n + m)</math> for making <math display="inline">m</math> predictions with <math display="inline">n</math><br />
observations.<br />
<br />
== Related Work ==<br />
<br />
===Gaussian Process Framework===<br />
<br />
A Gaussian Process (GP) is a non-parametric method used extensively for regression and classification problems in the machine learning community. A GP is defined as a collection of random variables, any finite number of which have a joint Gaussian distribution.<br />
A standard approach is to model data as <math>y = m(X, \phi) + \epsilon</math><br />
where m is the mean function with parameter vector <math>\phi</math>, and <math>\epsilon</math> represents independent and identically distributed (i.i.d.) Gaussian noise: <math>\epsilon \sim N(0,\sigma^2)</math><br />
<br />
For more info on Gaussian Process Framework:<br />
[https://arxiv.org/abs/1506.07304 A Gaussian process framework for modeling instrumental systematics: application to transmission spectroscopy]<br />
<br />
Several papers attempt to address various issues with GPs. These include:<br />
* Using sparse GPs to aid in scaling (Snelson & Ghahramani, 2006)<br />
* Using Deep GPs to achieve more expressiveness (Damianou & Lawrence, 2013; Salimbeni & Deisenroth, 2017)<br />
* Using neural networks to learn more expressive kernels (Wilson et al., 2016)<br />
<br />
A Python implementation of Gaussian processes: [https://github.com/SheffieldML/GPy GPy: Gaussian Process Framework in Python]<br />
<br />
<br />
The goal of this paper is to combine ideas from standard neural networks and Gaussian processes in order to overcome the drawbacks of both. Bayesian techniques work better with less data, but complex Bayesian models become intractable on even moderately sized datasets. Neural networks, on the other hand, cannot make use of prior knowledge and often have to be retrained from scratch; without sufficient data, they also perform poorly. Conditional Neural Processes combine both frameworks: rather than relying on a hand-designed kernel, they use neural networks to learn the relevant prior knowledge from data and make predictions in a GP-like conditional framework.<br />
<br />
===Meta Learning===<br />
<br />
Meta-Learning attempts to allow neural networks to learn more generalizable functions, as opposed to only approximating one function. This can be done by learning deep generative models which can do few-shot estimations of data. This can be implemented with attention mechanisms or additional memory.<br />
<br />
Classification is another common task in meta-learning; few-shot classification algorithms usually rely on some distance metric in feature space to compare target images with the observations. Matching networks (Vinyals et al., 2016; Bartunov & Vetrov, 2016) are closely related to CNPs.<br />
<br />
Finally, the latest variant of Conditional Neural Process can also be seen as an approximate, amortized version of Bayesian deep learning (Gal & Ghahramani, 2016; Blundell et al., 2015; Louizos et al., 2017; Louizos & Welling, 2017). For example, Gal & Ghahramani (2016) develop a new theoretical framework casting dropout training in deep neural networks as approximate Bayesian inference in deep Gaussian processes. Their theory extracts information from existing models and gives us tools to model uncertainty.<br />
<br />
== Experimental Result I: Function Regression ==<br />
<br />
The first experiment is a classical 1D regression task that is commonly used as a baseline for GPs. <br />
They generated two different datasets that consisted of functions<br />
generated from a GP with an exponential kernel. In the first dataset they used a kernel with fixed parameters, and in the second dataset the function switched at some random point on the real line between two functions, each sampled with<br />
different kernel parameters. At every training step, they sampled a curve from the GP, selected<br />
a subset of n points as observations, and a subset of t points as target points. The observed points are encoded using a three-layer MLP encoder h with a 128-dimensional output representation. The representations are aggregated into a single representation<br />
<math display="inline">r = \frac{1}{n} \sum_i r_i</math>, which is concatenated to <math display="inline">x_t</math> and passed to a decoder g consisting of a five-layer<br />
MLP. The decoder outputs a Gaussian mean and variance for the target outputs. The model is trained to maximize the log-likelihood of the target points using the Adam optimizer. <br />
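The data-generation procedure for both datasets can be sketched as follows. This is an illustrative sketch under stated assumptions: it uses a squared-exponential kernel (the text only says "exponential kernel"), hypothetical length scales and variances, and draws the target set so that it contains the observations, matching the training objective in which both observed and unobserved points are scored.<br />

```python
import numpy as np


def sample_gp_curve(x, length_scale, signal_var, rng, jitter=1e-4):
    """Draw one function from a zero-mean GP with a squared-exponential kernel at inputs x."""
    K = signal_var * np.exp(-0.5 * (x[:, None] - x[None, :]) ** 2 / length_scale ** 2)
    L = np.linalg.cholesky(K + jitter * np.eye(len(x)))   # jitter keeps K numerically PD
    return L @ rng.standard_normal(len(x))


def sample_switching_curve(x, rng):
    """Switch at a random point between two GP draws with different (hypothetical) kernel parameters."""
    switch = rng.uniform(x.min(), x.max())
    f1 = sample_gp_curve(x, length_scale=0.4, signal_var=1.0, rng=rng)
    f2 = sample_gp_curve(x, length_scale=1.5, signal_var=0.5, rng=rng)
    return np.where(x < switch, f1, f2)


def split_context_target(x, y, num_context, num_target, rng):
    """Pick n observation points and t >= n target points; the targets include the observations."""
    idx = rng.permutation(len(x))
    ctx, tgt = idx[:num_context], idx[:num_target]
    return (x[ctx], y[ctx]), (x[tgt], y[tgt])


rng = np.random.default_rng(0)
x = np.linspace(-2.0, 2.0, 50)
y = sample_switching_curve(x, rng)
(x_c, y_c), (x_t, y_t) = split_context_target(x, y, num_context=10, num_target=30, rng=rng)
```

Swapping `sample_switching_curve` for a single `sample_gp_curve` call gives the fixed-kernel dataset; as noted below, only the training data changes, not the model.<br />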
<br />
Two examples of the regression results obtained for each<br />
of the datasets are shown in the following figure.<br />
<br />
[[File:007.jpg|300px|center]]<br />
<br />
They compared the model to the predictions generated by a GP with the correct<br />
hyperparameters, which constitutes an upper bound on the CNP's<br />
performance. Although the prediction generated by the GP<br />
is smoother than the CNP's prediction both for the mean<br />
and variance, the model is able to learn to regress from a few<br />
context points for both the fixed kernels and switching kernels.<br />
As the number of context points grows, the accuracy<br />
of the model improves and the approximated uncertainty<br />
of the model decreases. Crucially, the model learns<br />
to estimate its own uncertainty given the observations very<br />
accurately. Overall, it provides a good approximation<br />
that increases in accuracy as the number of context points<br />
increases.<br />
Furthermore, the model achieves similarly good performance<br />
on the switching kernel task. This type of regression task<br />
is not trivial for GPs, whereas for CNPs one only has to<br />
change the dataset used for training.<br />
<br />
== Experimental Result II: Image Completion for Digits ==<br />
<br />
[[File:002.jpg|600px|center]]<br />
<br />
They also tested CNP on the MNIST dataset and used the test<br />
set to evaluate its performance. As shown in the above figure, the<br />
model learns to make good predictions of the underlying<br />
digit even for a small number of context points. Crucially,<br />
when conditioned only on one non-informative context point, the model’s prediction corresponds<br />
to the average over all MNIST digits. As the number<br />
of context points increases, the predictions become more<br />
similar to the underlying ground truth. This demonstrates<br />
the model’s capacity to extract dataset-specific prior knowledge.<br />
It is worth mentioning that even with a complete set<br />
of observations, the model does not achieve pixel-perfect<br />
reconstruction, because of the bottleneck at the representation<br />
level.<br />
Since this implementation of CNP returns factored outputs,<br />
the best prediction it can produce given limited context<br />
information is to average over all possible predictions that<br />
agree with the context. An alternative is to add<br />
latent variables to the model so that they can be sampled<br />
conditioned on the context to produce predictions with high<br />
probability in the data distribution. <br />
<br />
<br />
An important aspect of the model is its ability to estimate<br />
the uncertainty of the prediction. As shown in the bottom<br />
row of the above figure, as they added more observations, the variance<br />
shifts from being almost uniformly spread over the digit<br />
positions to being localized around areas that are specific<br />
to the underlying digit, specifically its edges. Being able to<br />
model the uncertainty given some context can be helpful for<br />
many tasks. One example is active exploration, where the<br />
model has a choice over where to observe.<br />
They tested this by<br />
comparing the predictions of CNP when the observations<br />
are chosen according to uncertainty (i.e. the pixel with the highest predicted variance at each step) versus at random. This method is a very simple way of doing active<br />
exploration, but it already produces better prediction results<br />
than selecting the conditioning points at random.<br />
<br />
== Experimental Result III: Image Completion for Faces ==<br />
<br />
<br />
[[File:003.jpg|400px|center]]<br />
<br />
<br />
They also applied CNP to CelebA, a dataset of images of<br />
celebrity faces, and reported performance obtained on the<br />
test set.<br />
<br />
As shown in the above figure, the model is able to capture<br />
the complex shapes and colors of this dataset, with predictions<br />
conditioned on less than 10% of the pixels being<br />
already close to the ground truth. As before, given only a few context<br />
points the model averages over all possible faces, but as<br />
the number of context pairs increases the predictions capture<br />
image-specific details like face orientation and facial<br />
expression. Furthermore, as the number of context points<br />
increases, the variance shifts towards the edges in the<br />
image.<br />
<br />
[[File:004.jpg|400px|center]]<br />
<br />
An important aspect of CNPs demonstrated in the above figure is<br />
their flexibility not only in the number of observations and<br />
targets they receive but also with regard to their input values.<br />
It is interesting to compare this property to GPs on one hand,<br />
and to trained generative models (van den Oord et al., 2016;<br />
Gregor et al., 2015) on the other hand.<br />
The first type of flexibility can be seen when conditioning on<br />
subsets that the model has not encountered during training.<br />
Consider conditioning the model on one half of the image,<br />
for example. This forces the model to not only predict the pixel<br />
values according to some stationary smoothness property of<br />
the images, but also according to global spatial properties,<br />
e.g. symmetry and the relative location of different parts of<br />
faces. As seen in the first row of the figure, CNPs are able to<br />
capture those properties. A GP with a stationary kernel cannot<br />
capture this, and in the absence of observations would<br />
revert to its mean (the mean itself can be non-stationary but<br />
usually, this would not be enough to capture the interesting<br />
properties).<br />
<br />
In addition, the model is flexible with regards to the target<br />
input values. This means, e.g., we can query the model<br />
at resolutions it has not seen during training. We take a<br />
model that has only been trained using pixel coordinates of<br />
a specific resolution and predict at test time subpixel values<br />
for targets between the original coordinates. As shown in<br />
Figure 5, with one forward pass we can query the model at<br />
different resolutions. While GPs also exhibit this type of<br />
flexibility, this is not the case for trained generative models,<br />
which can only predict values for the pixel coordinates on<br />
which they were trained. In this sense, CNPs capture the best<br />
of both worlds: they are flexible with regard to the conditioning<br />
and prediction task and have the capacity to extract domain<br />
knowledge from a training set.<br />
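The subpixel querying described above can be sketched as follows. We assume (the text does not say so) that pixel coordinates are normalized to the unit square, so that every resolution spans the same continuous region. `query_grid` is a hypothetical helper, the resolutions are illustrative, and the trained CNP decoder that would consume these target coordinates in one forward pass is not shown.

```python
from typing import List, Tuple


def query_grid(resolution: int) -> List[Tuple[float, float]]:
    """Target coordinates for a resolution x resolution image, scaled to [0, 1].

    Assumption: coordinates are normalized, so a finer grid places new query
    points between the original pixel coordinates.
    """
    if resolution < 2:
        raise ValueError("resolution must be at least 2")
    step = resolution - 1
    return [(r / step, c / step) for r in range(resolution) for c in range(resolution)]


train_grid = query_grid(28)          # hypothetical training resolution
fine_grid = query_grid(2 * 28 - 1)   # 55 x 55: one subpixel point between neighbours
print(len(train_grid), len(fine_grid))     # 784 3025
print(set(train_grid) <= set(fine_grid))   # True
```

Because the decoder in this sketch sees only coordinates, nothing ties it to the training grid, which is the flexibility the paragraph contrasts with trained generative models.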
<br />
[[File:010.jpg|400px|center]]<br />
<br />
<br />
They compared CNPs quantitatively to two related models:<br />
kNNs and GPs. As shown in the above table, CNPs outperform<br />
both when the number of context points is small (empirically,<br />
when half of the image or less is provided as context).<br />
When the majority of the image is given as context, exact<br />
methods like GPs and kNN perform better. From the table<br />
we can also see that the order in which the context points<br />
are provided is less important for CNPs, since providing the<br />
context points in order from top to bottom still results in<br />
good performance. Both insights point to the fact that CNPs<br />
learn a data-specific ‘prior’ that will generate good samples<br />
even when the number of context points is very small.<br />
<br />
== Experimental Result IV: Classification ==<br />
Finally, they applied the model to one-shot classification using the Omniglot dataset. This dataset consists of 1,623 classes of characters from 50 different alphabets. Each class has only 20 examples and as such this dataset is particularly suitable for few-shot learning algorithms. The authors used 1,200 randomly selected classes as their training set and the remainder as the testing data set.<br />
<br />
Additionally, for data augmentation the authors cropped the images from 32 × 32 to 28 × 28, applied small random<br />
translations and rotations to the inputs, and also increased<br />
the number of classes by rotating every character by 90<br />
degrees and defining that to be a new class. They generated<br />
the labels for an N-way classification task by choosing N<br />
random classes at each training step and arbitrarily assigning<br />
the labels 0, ..., N − 1 to each.<br />
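The class split, rotation-based class augmentation and N-way label assignment described above can be sketched with the standard library. This is a hedged illustration. The helper names are hypothetical, and the seed and the choice of N = 5 are arbitrary. We assume that all four multiples of 90 degrees define new classes. The image-level crops, translations and rotations are not shown.

```python
import random
from typing import Dict, List, Sequence, Tuple

Class = Tuple[int, int]  # (original character class, number of quarter-turns)


def rotation_augmented_classes(class_ids: Sequence[int], n_rotations: int = 4) -> List[Class]:
    """Treat each rotation of a character class by a multiple of 90 degrees as a new class.

    Assumption: all four quarter-turns (0, 90, 180, 270 degrees) are used.
    """
    return [(c, k) for c in class_ids for k in range(n_rotations)]


def sample_n_way_labels(classes: Sequence[Class], n_way: int, rng: random.Random) -> Dict[Class, int]:
    """Pick N random classes and arbitrarily assign them the labels 0, ..., N - 1."""
    chosen = rng.sample(list(classes), n_way)
    return {cls: label for label, cls in enumerate(chosen)}


rng = random.Random(0)
all_classes = list(range(1623))
train_ids = rng.sample(all_classes, 1200)              # 1,200 randomly selected training classes
test_ids = sorted(set(all_classes) - set(train_ids))   # the remaining classes
train_classes = rotation_augmented_classes(train_ids)
episode = sample_n_way_labels(train_classes, n_way=5, rng=rng)
print(len(train_classes), len(test_ids))  # 4800 423
print(sorted(episode.values()))           # [0, 1, 2, 3, 4]
```

Drawing a fresh mapping at every training step means a given class receives different labels in different episodes, so the model cannot memorize class-to-label assignments.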
<br />
<br />
[[File:008.jpg|400px|center]]<br />
<br />
Given that the input points are images, they modified the architecture<br />
of the encoder h to include convolutional layers as<br />
mentioned in section 2 of the paper. In addition, they aggregated only over<br />
inputs of the same class, using the information provided<br />
by the input label. The aggregated class-specific representations<br />
are then concatenated to form the final representation.<br />
Given that both the size of the class-specific representations<br />
and the number of classes is constant, the size of the final<br />
representation is still constant and thus the O(n + m)<br />
runtime still holds.<br />
The results of the classification are summarized in the following table.<br />
CNPs achieve higher accuracy than models that are significantly<br />
more complex (like MANN). While CNPs do not<br />
beat the state of the art for one-shot classification, their accuracy<br />
values are comparable. Crucially, they reached those values<br />
using a significantly simpler architecture (three convolutional<br />
layers for the encoder and a three-layer MLP for the<br />
decoder) and with a lower runtime of O(n + m) at test time,<br />
as opposed to O(nm).<br />
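The class-wise aggregation described above can be sketched in plain Python. This assumes mean aggregation, as in the CNP aggregator, and hypothetical 2-d encoder outputs. The convolutional encoder and the MLP decoder are omitted, and the choice to fill a class with no context examples with zeros is our own assumption.

```python
from typing import List, Sequence


def class_conditional_representation(
    embeddings: Sequence[Sequence[float]], labels: Sequence[int], n_classes: int
) -> List[float]:
    """Mean-aggregate encoder outputs per class, then concatenate the class means.

    The result always has n_classes * dim entries, however many context
    examples there are, so the decoder input size stays constant.
    """
    dim = len(embeddings[0])
    sums = [[0.0] * dim for _ in range(n_classes)]
    counts = [0] * n_classes
    for emb, label in zip(embeddings, labels):  # a single O(n) pass over the context
        counts[label] += 1
        for d in range(dim):
            sums[label][d] += emb[d]
    representation: List[float] = []
    for c in range(n_classes):
        # assumption: a class with no context examples contributes zeros
        representation.extend(s / counts[c] if counts[c] else 0.0 for s in sums[c])
    return representation


# Hypothetical encoder outputs for three context images in a 2-way task
print(class_conditional_representation([[1.0, 0.0], [3.0, 0.0], [0.0, 2.0]], [0, 0, 1], 2))
# [2.0, 0.0, 0.0, 2.0]
```

Because the output length depends only on the number of classes and the embedding size, each of the m targets can be decoded against the same fixed-size vector. That is where the O(n + m) cost comes from.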
<br />
== Conclusion ==<br />
<br />
The paper introduced Conditional Neural Processes,<br />
a model that is both flexible at test time and has the<br />
capacity to extract prior knowledge from training data.<br />
<br />
The authors demonstrated its ability to perform a variety of tasks<br />
including regression, classification and image completion.<br />
The paper compared CNPs to Gaussian Processes on one hand, and<br />
deep learning methods on the other, and also discussed the<br />
relation to meta-learning and few-shot learning.<br />
It is important to note that the specific CNP implementations<br />
described here are just simple proofs-of-concept and can<br />
be substantially extended, e.g. by including more elaborate<br />
architectures in line with modern deep learning advances.<br />
To summarize, this work can be seen as a step towards learning<br />
high-level abstractions, one of the grand challenges of<br />
contemporary machine learning. Functions learned by most<br />
conventional deep learning models are tied to a specific, constrained<br />
statistical context at any stage of training. A trained<br />
CNP is more general, in that it encapsulates the high-level<br />
statistics of a family of functions. As such it constitutes a<br />
high-level abstraction that can be reused for multiple tasks.<br />
In future work, the authors plan to explore how far these models can<br />
help in tackling the many key machine learning problems<br />
that seem to hinge on abstraction, such as transfer learning,<br />
meta-learning, and data efficiency.<br />
<br />
== Critiques ==<br />
<br />
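This paper introduces a method for reducing the computational complexity of the better-known Gaussian process model. The authors report a complexity of O(n + m) for n context points and m target points, whereas exact GP inference scales cubically in the number of context points regardless of the kernel, including the RBF kernel. With respect to performance across the sequence of tasks, however, the authors report quantitative metric comparisons to GP methods only for image completion, which makes it harder to establish the superiority of their approach on the other tasks.<br />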
<br />
The proposed model appears to be effective at making accurate predictions from lower-quality inputs, for example a dataset with fewer data points or an image with fewer observed pixels. However, it is not clear whether the model can itself be trained effectively on a smaller amount of data.<br />
<br />
== Other Sources ==<br />
# Code for this model and a simpler explanation can be found at [https://github.com/deepmind/conditional-neural-process]<br />
# A newer version of the model is described in this paper [https://arxiv.org/pdf/1807.01622.pdf]<br />
# A good blog post on neural processes [https://kasparmartens.rbind.io/post/np/]<br />
<br />
== Reference ==<br />
Bartunov, S. and Vetrov, D. P. Fast adaptation in generative<br />
models with generative matching networks. arXiv<br />
preprint arXiv:1612.02192, 2016.<br />
<br />
Blundell, C., Cornebise, J., Kavukcuoglu, K., and Wierstra,<br />
D. Weight uncertainty in neural networks. arXiv preprint<br />
arXiv:1505.05424, 2015.<br />
<br />
Bornschein, J., Mnih, A., Zoran, D., and J. Rezende, D.<br />
Variational memory addressing in generative models. In<br />
Advances in Neural Information Processing Systems, pp.<br />
3923–3932, 2017.<br />
<br />
Damianou, A. and Lawrence, N. Deep gaussian processes.<br />
In Artificial Intelligence and Statistics, pp. 207–215,<br />
2013.<br />
<br />
Devlin, J., Bunel, R. R., Singh, R., Hausknecht, M., and<br />
Kohli, P. Neural program meta-induction. In Advances in<br />
Neural Information Processing Systems, pp. 2077–2085,<br />
2017.<br />
<br />
Edwards, H. and Storkey, A. Towards a neural statistician.<br />
2016.<br />
<br />
Finn, C., Abbeel, P., and Levine, S. Model-agnostic meta-learning<br />
for fast adaptation of deep networks. arXiv<br />
preprint arXiv:1703.03400, 2017.<br />
<br />
Gal, Y. and Ghahramani, Z. Dropout as a bayesian approximation:<br />
Representing model uncertainty in deep learning.<br />
In international conference on machine learning, pp.<br />
1050–1059, 2016.<br />
<br />
Garnelo, M., Arulkumaran, K., and Shanahan, M. Towards<br />
deep symbolic reinforcement learning. arXiv preprint<br />
arXiv:1609.05518, 2016.<br />
<br />
Gregor, K., Danihelka, I., Graves, A., Rezende, D. J., and<br />
Wierstra, D. Draw: A recurrent neural network for image<br />
generation. arXiv preprint arXiv:1502.04623, 2015.<br />
<br />
Hewitt, L., Gane, A., Jaakkola, T., and Tenenbaum, J. B. The<br />
variational homoencoder: Learning to infer high-capacity<br />
generative models from few examples. 2018.<br />
<br />
J. Rezende, D., Danihelka, I., Gregor, K., Wierstra, D.,<br />
et al. One-shot generalization in deep generative models.<br />
In International Conference on Machine Learning, pp.<br />
1521–1529, 2016.<br />
<br />
Kingma, D. P. and Ba, J. Adam: A method for stochastic<br />
optimization. arXiv preprint arXiv:1412.6980, 2014.<br />
<br />
Kingma, D. P. and Welling, M. Auto-encoding variational<br />
bayes. arXiv preprint arXiv:1312.6114, 2013.<br />
<br />
Koch, G., Zemel, R., and Salakhutdinov, R. Siamese neural<br />
networks for one-shot image recognition. In ICML Deep<br />
Learning Workshop, volume 2, 2015.<br />
<br />
Lake, B. M., Salakhutdinov, R., and Tenenbaum, J. B.<br />
Human-level concept learning through probabilistic program<br />
induction. Science, 350(6266):1332–1338, 2015.<br />
<br />
Lake, B. M., Ullman, T. D., Tenenbaum, J. B., and Gershman,<br />
S. J. Building machines that learn and think like<br />
people. Behavioral and Brain Sciences, 40, 2017.<br />
<br />
LeCun, Y., Bottou, L., Bengio, Y., and Haffner, P. Gradient-based<br />
learning applied to document recognition. Proceedings<br />
of the IEEE, 86(11):2278–2324, 1998.<br />
<br />
Liu, Z., Luo, P., Wang, X., and Tang, X. Deep learning face<br />
attributes in the wild. In Proceedings of International<br />
Conference on Computer Vision (ICCV), December 2015.<br />
<br />
Louizos, C. and Welling, M. Multiplicative normalizing<br />
flows for variational bayesian neural networks. arXiv<br />
preprint arXiv:1703.01961, 2017.<br />
<br />
Louizos, C., Ullrich, K., and Welling, M. Bayesian compression<br />
for deep learning. In Advances in Neural Information<br />
Processing Systems, pp. 3290–3300, 2017.<br />
<br />
Rasmussen, C. E. and Williams, C. K. Gaussian processes<br />
in machine learning. In Advanced lectures on machine<br />
learning, pp. 63–71. Springer, 2004.<br />
<br />
Reed, S., Chen, Y., Paine, T., Oord, A. v. d., Eslami, S.,<br />
J. Rezende, D., Vinyals, O., and de Freitas, N. Few-shot<br />
autoregressive density estimation: Towards learning to<br />
learn distributions. 2017.<br />
<br />
Rezende, D. J., Mohamed, S., and Wierstra, D. Stochastic<br />
backpropagation and approximate inference in deep generative<br />
models. arXiv preprint arXiv:1401.4082, 2014.<br />
<br />
Salimbeni, H. and Deisenroth, M. Doubly stochastic variational<br />
inference for deep gaussian processes. In Advances<br />
in Neural Information Processing Systems, pp.<br />
4591–4602, 2017.<br />
<br />
Santoro, A., Bartunov, S., Botvinick, M., Wierstra, D., and<br />
Lillicrap, T. One-shot learning with memory-augmented<br />
neural networks. arXiv preprint arXiv:1605.06065, 2016.<br />
<br />
Snell, J., Swersky, K., and Zemel, R. Prototypical networks<br />
for few-shot learning. In Advances in Neural Information<br />
Processing Systems, pp. 4080–4090, 2017.<br />
<br />
Snelson, E. and Ghahramani, Z. Sparse gaussian processes<br />
using pseudo-inputs. In Advances in neural information<br />
processing systems, pp. 1257–1264, 2006.<br />
<br />
van den Oord, A., Kalchbrenner, N., Espeholt, L., Vinyals,<br />
O., Graves, A., et al. Conditional image generation with<br />
pixelcnn decoders. In Advances in Neural Information<br />
Processing Systems, pp. 4790–4798, 2016.<br />
<br />
Vinyals, O., Blundell, C., Lillicrap, T., Wierstra, D., et al.<br />
Matching networks for one shot learning. In Advances in<br />
Neural Information Processing Systems, pp. 3630–3638,<br />
2016.<br />
<br />
Wang, J. X., Kurth-Nelson, Z., Tirumala, D., Soyer, H.,<br />
Leibo, J. Z., Munos, R., Blundell, C., Kumaran, D., and<br />
Botvinick, M. Learning to reinforcement learn. arXiv<br />
preprint arXiv:1611.05763, 2016.<br />
<br />
Wilson, A. G., Hu, Z., Salakhutdinov, R., and Xing, E. P.<br />
Deep kernel learning. In Artificial Intelligence and Statistics,<br />
pp. 370–378, 2016.<br />
</div>Gsahuhttp://wiki.math.uwaterloo.ca/statwiki/index.php?title=Mapping_Images_to_Scene_Graphs_with_Permutation-Invariant_Structured_Prediction&diff=41994Mapping Images to Scene Graphs with Permutation-Invariant Structured Prediction2018-11-30T01:01:43Z<p>Gsahu: /* References */</p>
<hr />
<div>The paper ''Mapping Images to Scene Graphs with Permutation-Invariant Structured Prediction'' was written by Roei Herzig* (Tel Aviv University), Moshiko Raboh* (Tel Aviv University), Gal Chechik (Google Brain and Bar-Ilan University), Jonathan Berant (Tel Aviv University), and Amir Globerson (Tel Aviv University). The paper is part of the NIPS 2018 conference, hosted in December 2018 in Montréal, Canada. This summary is based on version 3 of the pre-print (as of May 2018), obtained from [https://arxiv.org/pdf/1802.05451v3.pdf arXiv] <br />
<br />
(*) Equal contribution<br />
<br />
=Motivation=<br />
In the field of artificial intelligence, a major goal is to enable machines to understand complex images, including the relationships between the objects in each scene and the global context needed to interpret the scene. A natural modelling framework for capturing such effects is structured prediction, which optimizes over complex labels while modelling within-label interactions. Although such models capture both complex labels and the interactions between them, there is little guidance on how to design them when leveraging deep learning. This paper introduces a design principle for such models that stems from the concept of permutation invariance, and shows that models following this principle achieve state-of-the-art performance.<br />
<br />
The primary contributions that this paper makes include:<br />
# Deriving sufficient and necessary conditions for respecting graph-permutation invariance in deep structured prediction architectures<br />
# Empirically demonstrating the benefit of graph-permutation invariance<br />
# Developing a state-of-the-art model for scene graph predictions over a large set of complex visual scenes<br />
<br />
=Introduction=<br />
For a machine to interpret complex visual scenes, it must recognize and understand both the objects in the scene and the relationships between them. A '''scene graph''' is a representation of the set of objects and relations that exist in the scene, where objects are represented as nodes and relations are represented as edges connecting the nodes. Hence, predicting the scene graph amounts to jointly inferring the set of objects and relations of a visual scene.<br />
<br />
[[File:scene_graph_example.png|600px|center]]<br />
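As a concrete picture of the representation just described, here is a minimal sketch of a scene graph as a data structure. Object labels sit on the nodes, and each relation is a directed edge between node indices labelled with a predicate. The `SceneGraph` class and the example labels are hypothetical illustrations and are not taken from the paper or any dataset API.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Tuple


@dataclass
class SceneGraph:
    """Objects are nodes; relations are directed, labelled edges between node indices."""

    objects: List[str]
    relations: Dict[Tuple[int, int], str] = field(default_factory=dict)

    def triplets(self) -> List[Tuple[str, str, str]]:
        """Return (subject, predicate, object) triplets, ordered by node indices."""
        return [
            (self.objects[i], predicate, self.objects[j])
            for (i, j), predicate in sorted(self.relations.items())
        ]


# Hypothetical labels, for illustration only
graph = SceneGraph(objects=["man", "horse", "hat"], relations={(0, 1): "riding", (0, 2): "wearing"})
print(graph.triplets())  # [('man', 'riding', 'horse'), ('man', 'wearing', 'hat')]
```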
<br />
Because the objects in a scene depend on one another, the objects and relations must be predicted jointly. The field of structured prediction, which addresses the general problem of inferring multiple interdependent labels, is therefore relevant to this problem. Structured prediction has attracted considerable attention because it applies to many learning problems and poses unique theoretical and applied challenges (e.g., see Belanger et al., 2017; Chen et al., 2015; Taskar et al., 2004).<br />
<br />
In structured prediction models, a score function <math>s(x, y)</math> is defined to evaluate the compatibility between label <math>y</math> and input <math>x</math>. For instance, when interpreting the scene of an image, <math>x</math> refers to the image itself, and <math>y</math> refers to a complex label, which contains both the objects and the relations between objects. As with most other inference methods, the goal is to find the label <math>y^*</math> that maximizes <math>s(x,y)</math>, i.e. <math>y^* = \arg\max_y s(x,y)</math>. However, the major concern is that the space of possible label assignments grows exponentially with the input size. For example, even an image that seems very simple may admit a very large set of possible object labels, which makes the score function difficult to optimize.<br />
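The exponential growth can be made concrete with a short count. It rests on a simplifying assumption of ours: every one of the n objects takes one of K object classes, and every ordered pair of distinct objects takes one of R relation classes, with "no relation" counted as one of them. The parameter values below are illustrative only.

```python
def label_space_size(n_objects: int, n_object_classes: int, n_relation_classes: int) -> int:
    """Number of joint labelings of a scene graph, K**n * R**(n * (n - 1)).

    Assumption: one object class per node and one relation class (including
    "no relation") per ordered pair of distinct nodes.
    """
    n_pairs = n_objects * (n_objects - 1)
    return n_object_classes ** n_objects * n_relation_classes ** n_pairs


print(label_space_size(3, 2, 2))                 # 2**3 * 2**6 = 512
print(len(str(label_space_size(10, 100, 50))))   # 173 decimal digits
```

Even ten objects with these modest, hypothetical class counts give a number with 173 digits, which is why exhaustive maximization of <math>s(x,y)</math> is out of reach.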
<br />
The paper presents an alternative approach in which the input <math>x</math> is mapped to the structured output <math>y</math> by a "black box" neural network, without explicitly defining a score function. The main question in this approach is how to design the network architecture.<br />
<br />
The approach is evaluated by first demonstrating the importance of permutation invariance on a synthetic data set. The authors' architecture, which respects permutation invariance, is then compared to a competitive benchmark, where it achieves state-of-the-art results.<br />
<br />
=Structured prediction=<br />
This paper further considers structured prediction using score-based methods. In a score-based approach, a score function <math>s(x, y)</math> measures how compatible the label <math>y</math> is with the input <math>x</math>, and a label is inferred by maximizing <math>s(x, y)</math>. To make this optimization efficient, previous work has decomposed the score as <math>s(x,y) = \sum_i f_i(x,y)</math>, where each local score function <math>f_i</math> depends only on a small subset of the <math>y</math> variables, so that local maximizations such as <math>\max_y f_i(x,y)</math> are tractable.<br />
<br />
Recently, there has been growing interest in modeling the <math>f_i</math> functions as deep networks. In this area of structured prediction, the most commonly used score functions are the singleton score function <math>f_i(y_i, x)</math> and the pairwise score function <math>f_{ij} (y_i, y_j, x)</math>. Previous works explored two-stage architectures (which learn local scores independently of the structured prediction goal), end-to-end architectures (which include the inference algorithm within the computation graph), and the modelling of global factors. <br />
<br />
==Advantages of using score-based methods==<br />
# They allow intuitive specification of local dependencies between labels, and of how these map to global dependencies<br />
# Linear score functions offer natural convex surrogates<br />
# Inference in large label spaces is sometimes possible via exact algorithms or empirically accurate approximations<br />
<br />
The concern with modeling score functions using deep networks is that learning may no longer be convex. Hence, the paper studies how deep networks can be used for structured prediction by considering architectures that do not require explicit maximization of a score function.<br />
<br />
=Background, Notations, and Definitions=<br />
We denote by <math>y</math> a structured label, where <math>y = [y_1, \dots, y_n]</math>.<br />
<br />
'''Score functions:''' for score-based methods, the score is defined as either the sum of a set of singleton scores <math>f_i = f_i(y_i, x)</math> or the sum of pairwise scores <math>f_{ij} = f_{ij}(y_i, y_j, x)</math>.<br />
<br />
Let <math>s(x,y)</math> be the score of a score-based method. Then:<br />
<br />
<div align="center"><br />
<math>s(x,y) = \begin{cases}<br />
\sum_i f_i ~ \text{if we have a set of singleton scores}\\<br />
\sum_{ij} f_{ij} ~ \text{if we have a set of pairwise scores } \\<br />
\end{cases}</math><br />
</div><br />
<br />
'''Inference algorithm:''' an inference algorithm takes as input a set of local scores (either <math>f_i</math> or <math>f_{ij}</math>) and outputs an assignment of labels <math>y_1, \dots, y_n</math> that maximizes the score function <math>s(x,y)</math>.<br />
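To illustrate what an inference algorithm computes, and why its cost explodes, here is a brute-force sketch that enumerates every assignment. It assumes the score combines singleton and pairwise tables; leaving the pairwise dictionary empty recovers the singleton-only case. The table layout and the function name are hypothetical choices for this example and not an algorithm from the paper.

```python
from itertools import product
from typing import Dict, Sequence, Tuple


def brute_force_inference(
    singleton: Sequence[Sequence[float]],
    pairwise: Dict[Tuple[int, int], Sequence[Sequence[float]]],
    n_labels: int,
) -> Tuple[Tuple[int, ...], float]:
    """Exhaustive maximization of s(y) = sum_i f_i(y_i) + sum_ij f_ij(y_i, y_j).

    singleton[i][a] plays the role of f_i(y_i = a) and pairwise[(i, j)][a][b]
    of f_ij(y_i = a, y_j = b); the input x is folded into these tables.
    The loop visits n_labels ** n assignments.
    """
    n = len(singleton)
    best_y: Tuple[int, ...] = ()
    best_score = float("-inf")
    for y in product(range(n_labels), repeat=n):
        score = sum(singleton[i][y[i]] for i in range(n))
        score += sum(table[y[i]][y[j]] for (i, j), table in pairwise.items())
        if score > best_score:
            best_y, best_score = y, score
    return best_y, best_score


f_single = [[0.0, 1.0], [0.5, 0.0]]
f_pair = {(0, 1): [[0.0, 0.0], [0.0, 2.0]]}
print(brute_force_inference(f_single, f_pair, n_labels=2))  # ((1, 1), 3.0)
```

In the example, the pairwise bonus for labelling both variables 1 outweighs the singleton preference of the second variable for label 0.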
<br />
'''Graph labeling function:''' a graph labeling function <math>\mathcal{F} : (V,E) \rightarrow Y</math> is a function that takes as input an ordered set of node features <math>V = [z_1, \dots, z_n]</math> and an ordered set of edge features <math>E = [z_{1,2},\dots,z_{i,j},\dots,z_{n,n-1}]</math>, and outputs a set of node labels <math>\mathbf{y} = [y_1, \dots, y_n]</math>. For instance, <math>z_i</math> can be set equal to <math>f_i</math> and <math>z_{ij}</math> can be set equal to <math>f_{ij}</math>.<br />
<br />
For convenience, the joint set of node and edge features is denoted by <math>\mathbf{z}</math>, a vector of size <math>n^2</math> (<math>n</math> node features and <math>n(n-1)</math> edge features).<br />
<br />
'''Permutation:''' Let <math>\mathbf{z}</math> be a set of node and edge features. Given a permutation <math>\sigma</math> of <math>\{1,\dots,n\}</math>, let <math>\sigma(\mathbf{z})</math> be a new set of node and edge features given by <math>[\sigma(\mathbf{z})]_i = z_{\sigma(i)}</math> and <math>[\sigma(\mathbf{z})]_{i,j} = z_{\sigma(i), \sigma(j)}</math>.<br />
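The permutation definition translates directly into code. The sketch below uses 0-based indices, unlike the 1-based notation in the text, and stores edge features in a dictionary keyed by ordered node pairs; both are our own representation choices. With string features named after the scores, swapping the first two nodes reproduces the permuted input <math>\mathbf{z}'</math> from the three-variable example in the next section.

```python
from typing import Dict, List, Sequence, Tuple

Edges = Dict[Tuple[int, int], str]


def permute(nodes: Sequence[str], edges: Edges, sigma: Sequence[int]) -> Tuple[List[str], Edges]:
    """Apply sigma: [sigma(z)]_i = z_sigma(i) and [sigma(z)]_(i,j) = z_(sigma(i), sigma(j)).

    Indices are 0-based here, unlike the 1-based notation in the text.
    """
    n = len(nodes)
    new_nodes = [nodes[sigma[i]] for i in range(n)]
    new_edges = {(i, j): edges[(sigma[i], sigma[j])] for i in range(n) for j in range(n) if i != j}
    return new_nodes, new_edges


# Features named after the scores f_i and f_ij of the three-variable example
nodes = ["f1", "f2", "f3"]
edges = {(i, j): f"f{i + 1}{j + 1}" for i in range(3) for j in range(3) if i != j}
swap_first_two = [1, 0, 2]
new_nodes, new_edges = permute(nodes, edges, swap_first_two)
print(new_nodes)                                                # ['f2', 'f1', 'f3']
print(new_edges[(0, 1)], new_edges[(0, 2)], new_edges[(1, 2)])  # f21 f23 f13
```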
<br />
'''One-hot representation:''' let <math>\mathbf{1}[j]</math> be a one-hot vector with 1 in the <math>j^{th}</math> coordinate.<br />
<br />
=Permutation-Invariant Structured prediction=<br />
<br />
With permutation-invariant structured prediction, we would expect the algorithm to produce the same labels, permuted accordingly, when its input is permuted. For instance, consider a label space for 3 variables <math>y_1, y_2, y_3</math> with input <math>\mathbf{z} = (f_1, f_2, f_3, f_{12}, f_{13}, f_{23})</math> and output labels <math>\mathbf{y} = (y_1^*, y_2^*, y_3^*)</math>. If the algorithm is then run on the permuted input <math>\mathbf{z}' = (f_2, f_1, f_3, f_{21}, f_{23}, f_{13})</math>, which swaps variables 1 and 2, we would expect the output <math>\mathbf{y}' = (y_2^*, y_1^*, y_3^*)</math>, since the underlying score function is the same.<br />
<br />
'''Graph permutation invariance (GPI):''' a graph labeling function <math>\mathcal{F}</math> is graph-permutation invariant if, for all permutations <math>\sigma</math> of <math>\{1, \dots, n\}</math> and for all feature sets <math>\mathbf{z}</math>, <math>\mathcal{F}(\sigma(\mathbf{z})) = \sigma(\mathcal{F}(\mathbf{z}))</math>. Practically speaking, graph-permutation invariance means that the same graph is constructed regardless of the order in which its elements are processed. In scene graph generation approaches, Region Proposal Networks are often used as an initial pre-processing step. Their outputs (cropped images representing bounding boxes) are then fed sequentially through a vertex (or edge) detection network. The idea behind permutation invariance is that the final scene graph is identical no matter the order in which these are passed in. In effect, this prevents the model from connecting vertices that should not be connected simply because a more promising vertex has not yet been seen.<br />
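The GPI condition can be checked numerically on a small input by comparing <math>\mathcal{F}(\sigma(\mathbf{z}))</math> with <math>\sigma(\mathcal{F}(\mathbf{z}))</math> for every permutation. This is a hedged sketch: both labeling functions are hypothetical examples, and passing the check on one input is evidence, not a proof. The non-GPI function survives cyclic shifts but fails once two nodes are swapped, which is why the check tries all permutations.

```python
from itertools import permutations
from typing import Callable, Dict, List, Sequence, Tuple

Edges = Dict[Tuple[int, int], int]
LabelFn = Callable[[Sequence[int], Edges], List[int]]


def permute(nodes: Sequence[int], edges: Edges, sigma: Sequence[int]) -> Tuple[List[int], Edges]:
    """[sigma(z)]_i = z_sigma(i) and [sigma(z)]_(i,j) = z_(sigma(i), sigma(j)); 0-based."""
    n = len(nodes)
    return ([nodes[sigma[i]] for i in range(n)],
            {(i, j): edges[(sigma[i], sigma[j])] for i in range(n) for j in range(n) if i != j})


def is_gpi_on(F: LabelFn, nodes: Sequence[int], edges: Edges) -> bool:
    """Check F(sigma(z)) == sigma(F(z)), i.e. [F(sigma(z))]_k == [F(z)]_sigma(k), for all sigma."""
    labels = F(nodes, edges)
    for sigma in permutations(range(len(nodes))):
        if F(*permute(nodes, edges, sigma)) != [labels[s] for s in sigma]:
            return False
    return True


def node_plus_outgoing(nodes: Sequence[int], edges: Edges) -> List[int]:
    """GPI: label k depends on z_k and on all edges leaving k, summed in any order."""
    n = len(nodes)
    return [nodes[k] + sum(edges[(k, j)] for j in range(n) if j != k) for k in range(n)]


def node_plus_next_index(nodes: Sequence[int], edges: Edges) -> List[int]:
    """Not GPI: uses the edge to node k + 1, which depends on the input ordering."""
    n = len(nodes)
    return [nodes[k] + edges[(k, (k + 1) % n)] for k in range(n)]


z_nodes = [1, 2, 3]
z_edges = {(i, j): 10 * i + j for i in range(3) for j in range(3) if i != j}
print(is_gpi_on(node_plus_outgoing, z_nodes, z_edges))    # True
print(is_gpi_on(node_plus_next_index, z_nodes, z_edges))  # False
```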
<br />
The paper presents a theorem on the necessary and sufficient conditions for a function <math>\mathcal{F}</math> to be graph-permutation invariant. Intuitively, because <math>\mathcal{F}</math> takes an ordered set <math>\mathbf{z}</math> as input, its output on <math>\sigma(\mathbf{z})</math> could very well differ from the permuted output <math>\sigma(\mathcal{F}(\mathbf{z}))</math>, so <math>\mathcal{F}</math> needs some form of symmetry in order to satisfy <math>[\mathcal{F}(\sigma(\mathbf{z}))]_k = [\mathcal{F}(\mathbf{z})]_{\sigma(k)}</math>.<br />
<br />
[[File:graph_permutation_invariance.jpg|400px|center]]<br />
<br />
==Theorem 1==<br />
Let <math>\mathcal{F}</math> be a graph labeling function. Then <math>\mathcal{F}</math> is graph-permutation invariant if and only if there exist functions <math>\alpha, \rho, \phi</math> such that for all <math>k=1, \dots, n</math>:<br />
<math>[\mathcal{F}(\mathbf{z})]_k = \rho\left(\mathbf{z}_k, \sum_{i=1}^n \alpha\left(\mathbf{z}_i, \sum_{j\neq i} \phi(\mathbf{z}_i, \mathbf{z}_{i,j}, \mathbf{z}_j)\right)\right) \qquad (1)</math><br />
where <math>\phi: \mathbb{R}^{2d+e} \rightarrow \mathbb{R}^L, \alpha: \mathbb{R}^{d + L} \rightarrow \mathbb{R}^{W}, \rho: \mathbb{R}^{W+d} \rightarrow \mathbb{R}</math>.<br />
<br />
Here <math>d</math> denotes the dimension of the singleton (node) features <math>z_i</math> and <math>e</math> the dimension of the edge features <math>z_{i,j}</math>; <math>L</math> and <math>W</math> are the output dimensions of <math>\phi</math> and <math>\alpha</math>. <br />
<br />
[[File:GPI_architecture.jpg|thumb|A schematic representation of the GPI architecture. Singleton features <math>z_i</math> are omitted for simplicity. First, the features <math>z_{i,j}</math> are processed element-wise by <math>\phi</math>. Next, they are summed to create a vector <math>s_i</math>, which is concatenated with <math>z_i</math>. Third, a representation of the entire graph is created by applying <math>\alpha\ n</math> times and summing the created vector. The graph representation is then finally processed by <math>\rho</math> together with <math>z_k</math>.|600px|center]]<br />
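Equation (1) can be evaluated directly for scalar features (<math>d = e = L = W = 1</math>). In the sketch below, `phi`, `alpha` and `rho` are arbitrary hypothetical choices rather than learned networks. Whatever functions are plugged in, the nested sums make the output equivariant, and the code checks this over all permutations of a three-node input. Every `phi` call is independent of the others, which is the parallelism the text points out.

```python
from itertools import permutations
from typing import Dict, List, Sequence, Tuple

Edges = Dict[Tuple[int, int], int]


def permute(nodes: Sequence[int], edges: Edges, sigma: Sequence[int]) -> Tuple[List[int], Edges]:
    """[sigma(z)]_i = z_sigma(i) and [sigma(z)]_(i,j) = z_(sigma(i), sigma(j)); 0-based."""
    n = len(nodes)
    return ([nodes[sigma[i]] for i in range(n)],
            {(i, j): edges[(sigma[i], sigma[j])] for i in range(n) for j in range(n) if i != j})


def phi(z_i: int, z_ij: int, z_j: int) -> int:
    return z_ij * z_j          # hypothetical edge message


def alpha(z_i: int, s_i: int) -> int:
    return (z_i + 1) * s_i     # hypothetical per-node summary


def rho(z_k: int, g: int) -> int:
    return 3 * z_k + g         # hypothetical readout


def gpi_labels(nodes: Sequence[int], edges: Edges) -> List[int]:
    """[F(z)]_k = rho(z_k, sum_i alpha(z_i, sum_{j != i} phi(z_i, z_ij, z_j)))."""
    n = len(nodes)
    # s_i: order-independent sum over node i's edges; the phi calls could run in parallel
    s = [sum(phi(nodes[i], edges[(i, j)], nodes[j]) for j in range(n) if j != i) for i in range(n)]
    g = sum(alpha(nodes[i], s[i]) for i in range(n))  # graph-level representation
    return [rho(nodes[k], g) for k in range(n)]


z_nodes = [1, 2, 3]
z_edges = {(i, j): 10 * i + j for i in range(3) for j in range(3) if i != j}
labels = gpi_labels(z_nodes, z_edges)
print(labels)  # [405, 408, 411]
print(all(gpi_labels(*permute(z_nodes, z_edges, sigma)) == [labels[s] for s in sigma]
          for sigma in permutations(range(3))))  # True
```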
<br />
==Proof Sketch for Theorem 1==<br />
The proof of this theorem can be found in the paper. A proof sketch is provided below:<br />
<br />
'''For the forward direction''' (function that follows the form set out in equation (1) is GPI):<br />
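# Use the definition of the permutation <math>\sigma</math> to rewrite <math>[\mathcal{F}(\mathbf{z})]_{\sigma(k)}</math> in the form of equation (1)<br />
# The second argument of <math>\rho</math> is invariant under <math>\sigma</math>, since it sums over all indices <math>i</math> and, for each <math>i</math>, over all other indices <math>j \neq i</math>. Hence <math>[\mathcal{F}(\sigma(\mathbf{z}))]_k</math> and <math>[\mathcal{F}(\mathbf{z})]_{\sigma(k)}</math> both equal <math>\rho</math> applied to <math>\mathbf{z}_{\sigma(k)}</math> and the same invariant sum.<br />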
<br />
'''For the backward direction''' (any black-box GPI function can be expressed in the form of equation 1):<br />
# Construct <math>\phi, \alpha</math> such that the second argument of <math>\rho</math> contains all information about the graph features of <math>\mathbf{z}</math>, including which edge each feature originates from<br />
# Assume each <math>z_k</math> uniquely identifies its node and that <math>\mathcal{F}</math> is a function only of the pairwise features <math>z_{i,j}</math><br />
# Let <math>H</math> be a perfect hash function with <math>L</math> buckets, and let <math>\phi</math> map '''pairwise features''' to a vector of size <math>L</math><br />
# <math>*</math>Construct <math>\phi(z_i, z_{i,j}, z_j) = \mathbf{1}[H(z_j)] z_{i,j}</math>, which intuitively means that <math>\phi</math> stores <math>z_{i,j}</math> in the unique bucket for node <math>j</math><br />
# Construct a function <math>\alpha</math> that outputs a matrix in <math>\mathbb{R}^{L \times L}</math> placing each pairwise feature in a unique position: <math>\alpha(z_i, s_i) = \mathbf{1}[H(z_i)]s_i^T</math>, where <math>s_i = \sum_{j \neq i} \phi(z_i, z_{i,j}, z_j)</math><br />
# Form the matrix <math>M = \sum_i \alpha(z_i,s_i)</math> and discard the rows and columns of <math>M</math> that do not correspond to original nodes (which reduces its dimension to <math>n\times n</math>); then set <math>\rho</math> to apply <math>\mathcal{F}</math> to this reduced matrix, so that its output is the labels <math>\mathbf{y} = y_1, \dots, y_n</math><br />
<br />
<math>*</math>The paper presents the proof for the edge features <math>z_{ij}</math> being scalar (<math>e = 1</math>) for simplicity, which can be extended easily to vectors with additional indexing.<br />
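The backward-direction construction can be traced on a tiny graph with scalar edge features (<math>e = 1</math>). In this sketch, a dictionary from node identifiers to distinct buckets stands in for the perfect hash <math>H</math>, and the identifiers, the bucket count and the edge values are hypothetical. The resulting <math>M</math> is identical when the same graph is presented in a different node order. After discarding the unused bucket, row <math>H(z_i)</math>, column <math>H(z_j)</math> holds <math>z_{i,j}</math>.

```python
from typing import Dict, Hashable, List, Sequence, Tuple

Matrix = List[List[float]]


def one_hot(index: int, size: int) -> List[float]:
    vec = [0.0] * size
    vec[index] = 1.0
    return vec


def build_m(node_ids: Sequence[Hashable], edges: Dict[Tuple[int, int], float],
            H: Dict[Hashable, int], L: int) -> Matrix:
    """M = sum_i alpha(z_i, s_i) with phi = 1[H(z_j)] z_ij and alpha = 1[H(z_i)] s_i^T.

    H is a dict standing in for a perfect hash: distinct node ids -> distinct buckets.
    """
    n = len(node_ids)
    M = [[0.0] * L for _ in range(L)]
    for i in range(n):
        s_i = [0.0] * L
        for j in range(n):
            if j != i:
                # phi stores the scalar edge feature z_ij in the bucket of node j
                phi = [v * edges[(i, j)] for v in one_hot(H[node_ids[j]], L)]
                s_i = [a + b for a, b in zip(s_i, phi)]
        # alpha places s_i in the row belonging to node i (an outer product)
        row = one_hot(H[node_ids[i]], L)
        for r in range(L):
            for c in range(L):
                M[r][c] += row[r] * s_i[c]
    return M


ids = ["a", "b", "c"]
H = {"a": 0, "b": 1, "c": 2}  # hypothetical perfect hash into L = 4 buckets
L = 4
edges = {(0, 1): 1.0, (0, 2): 2.0, (1, 0): 3.0, (1, 2): 4.0, (2, 0): 5.0, (2, 1): 6.0}
M = build_m(ids, edges, H, L)

# The same graph presented in a different node order (nodes a and b swapped)
sigma = [1, 0, 2]
ids_perm = [ids[s] for s in sigma]
edges_perm = {(i, j): edges[(sigma[i], sigma[j])] for i in range(3) for j in range(3) if i != j}
print(build_m(ids_perm, edges_perm, H, L) == M)  # True

used = sorted(H[z] for z in ids)  # buckets of real nodes; the spare bucket is discarded
reduced = [[M[r][c] for c in used] for r in used]
print(reduced)  # [[0.0, 1.0, 2.0], [3.0, 0.0, 4.0], [5.0, 6.0, 0.0]]
```

Applying any labeling function to `reduced` then plays the role of <math>\rho</math>. Because `reduced` does not depend on the input order, the result is GPI.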
<br />
Although the results discussed so far apply to complete graphs (where every pair of nodes is connected by an edge), they extend easily to incomplete graphs. For incomplete graphs, the input to <math>\mathcal{F}</math> only contains the features corresponding to valid edges of the graph. The authors are only interested in invariances that preserve the graph structure, so permutation invariance is replaced by automorphism invariance.<br />
<br />
==Implications and Applications of Theorem 1==<br />
===Key Implications of Theorem 1===<br />
# Architecture "collects" information from the different edges of the graph, and does so in an invariant fashion using <math>\alpha</math> and <math>\phi</math><br />
# Architecture is parallelizable, since all <math>\phi</math> functions can be applied simultaneously. In contrast, recurrent models (Zellers et al. 2017) are harder to parallelize and are thus practically slower.<br />
<br />
===Some applications of Theorem 1===<br />
# '''Attention:''' the concept of attention can be implemented within the GPI characterization with slight alterations to the functions <math>\alpha</math> and <math>\phi</math>. In attention, each node aggregates the features of its neighbors, weighted by a function of each neighbor's relevance, so the label of an entity can depend strongly on a few closely related entities. The complete details can be found in the supplementary material of the paper.<br />
<br />
# '''RNN:''' recurrent architectures can maintain the GPI property, since GPI functions are closed under composition. The output of one step of <math>\mathcal{F}</math> acts as the input to the next step, and the GPI property is maintained throughout.<br />
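For the attention application listed above, one way to see why attention fits the form of equation (1) is that softmax attention is a ratio of two sums over neighbours: weighted values and weights. A <math>\phi</math> that outputs both quantities can be summed over <math>j</math>, and the division can happen afterwards. This is a hedged sketch of that idea only. It is not necessarily the construction in the paper's supplementary material, and `score_fn` and `value_fn` are hypothetical placeholders. Very large scores would overflow `math.exp`, and a practical version would subtract the maximum score first.

```python
import math
from typing import Callable, Dict, Sequence, Tuple


def attention_summary(
    i: int,
    nodes: Sequence[float],
    edges: Dict[Tuple[int, int], float],
    score_fn: Callable[[float, float, float], float],
    value_fn: Callable[[float], float],
) -> float:
    """Softmax attention of node i over its neighbours j != i.

    Both accumulators are sums over j, which matches the phi-then-sum pattern
    of equation (1); the final division can be done in alpha.
    """
    numerator = 0.0
    normalizer = 0.0
    for j in range(len(nodes)):
        if j == i:
            continue
        weight = math.exp(score_fn(nodes[i], edges[(i, j)], nodes[j]))  # relevance of j to i
        numerator += weight * value_fn(nodes[j])
        normalizer += weight
    return numerator / normalizer


nodes = [0.0, 1.0, 2.0, 3.0]
edges = {(i, j): 1.0 if (i, j) == (0, 3) else 0.0 for i in range(4) for j in range(4) if i != j}
uniform = attention_summary(0, nodes, edges, lambda zi, zij, zj: 0.0, lambda zj: zj)
focused = attention_summary(0, nodes, edges, lambda zi, zij, zj: 50.0 * zij, lambda zj: zj)
print(uniform)             # 2.0: equal scores give the plain mean of neighbours 1, 2, 3
print(round(focused, 6))   # 3.0: node 0 attends almost entirely to node 3
```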
<br />
=Related Work=<br />
# '''Architectural invariance:''' this was suggested recently in Deep Sets (Zaheer et al., 2017), which considers a more restrictive case of invariance: permutations of the elements of a set, with no pairwise (edge) features.<br />
# '''Deep structured prediction:''' previous work applied deep learning to structured prediction problems, for instance semantic segmentation. Some algorithms include message passing algorithms, gradient descent for maximizing score functions, and greedy decoding (inferring each label based on the previously inferred labels). For example, Xu et al. (2017) propose an end-to-end model that generates a structured scene representation; it solves the scene graph inference problem using standard RNNs and learns to iteratively improve its predictions via message passing. Apart from those algorithms, deep learning has been applied to other graph-based problems such as the Travelling Salesman Problem (Bello et al., 2016; Gilmer et al., 2017; Khalil et al., 2017). However, none of the previous work specifically addresses the notion of invariance in a general architecture; instead, it focuses on message passing architectures, which the characterization in this paper generalizes.<br />
# '''Scene graph prediction:''' scene graph extraction allows for reasoning, question answering, and image retrieval (Johnson et al., 2015; Lu et al., 2016; Raposo et al., 2017). Some other works in this area include object detection, action recognition, and even the detection of human-object interactions (Liao et al., 2016; Plummer et al., 2017). Additional work has been done with the use of message passing algorithms (Xu et al., 2017), word embeddings (Lu et al., 2016), and end-to-end prediction directly from pixels (Newell & Deng, 2017). A notable mention is NeuralMotif (Zellers et al., 2017), which the authors describe as the current state-of-the-art model for scene graph predictions on Visual Genome dataset. It uses an RNN that supplies global context by reading the independent predictions sequentially for each entity and relation and then conducts further refinement on the predictions. The NeuralMotif model has a fixed order in which the RNN reads its inputs and thereby maintains GPI. However, this fixed order is not guaranteed to be optimal.<br />
# '''Burst Image Deblurring Using Permutation Invariant Convolutional Neural Networks:''' this work applies similar ideas, using permutation-invariant CNNs to restore sharp, noise-free images from bursts of photographs affected by hand tremor and noise. It produced high-quality, detailed images on challenging datasets.<br />
<br />
=Experimental Results=<br />
<br />
The authors evaluated the advantage of GPI architectures empirically, first on a synthetic graph-labeling task and then on scene-graph classification from images.<br />
<br />
==Synthetic Graph Labeling==<br />
The authors created a synthetic problem to study GPI. It uses an input graph <math>G = (V,E)</math> in which each node <math>i</math> belongs to one of <math>K</math> sets, <math>\Gamma(i) \in \{1, \dots, K\}</math>. The task is to compute, for each node, the number of neighbours that belong to the same set, i.e. the label of node <math>i</math> is <math>y_i = \sum_{j \in N(i)} \mathbf{1}[\Gamma(i) = \Gamma(j)]</math>. Random graphs (each with 10 nodes) were generated by sampling the edges and the set <math>\Gamma(i)</math> of each node independently and uniformly.<br />
The node features <math>z_i \in \{0,1\}^K</math> are one-hot encodings of <math>\Gamma(i)</math>, and each pairwise edge feature <math>z_{ij} \in \{0, 1\}</math> denotes whether the edge <math>ij</math> is in the edge set <math>E</math>.<br />
Three architectures were studied:<br />
# '''GPI-architecture for graph prediction''' (without attention and RNN)<br />
# '''LSTM''': replacing <math>\sum \phi(\cdot)</math> and <math>\sum \alpha(\cdot)</math> in the form of Theorem 1 using two LSTMs with state size 200, reading their input in random order<br />
# '''Fully connected feed-forward network''': with 2 hidden layers, each layer containing 1,000 nodes; the input is a concatenation of all nodes and pairwise features, and the output is all node predictions<br />
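<br />
The synthetic data generation described above can be sketched in a few lines. This is a sketch under stated assumptions: the edge probability (<code>p_edge=0.5</code>), the number of sets (<code>k=3</code>) and the helper names are hypothetical, since the summary does not specify them, and sets are 0-indexed instead of <math>1, \dots, K</math>.<br />
<br />
```python
import random


def make_graph(n=10, k=3, p_edge=0.5, seed=0):
    """Sample Gamma(i) uniformly from k sets and each undirected edge independently.

    p_edge and k are illustrative assumptions; sets are 0-indexed here.
    """
    rng = random.Random(seed)
    gamma = [rng.randrange(k) for _ in range(n)]
    edge_set = {(i, j) for i in range(n) for j in range(i + 1, n)
                if rng.random() < p_edge}
    return gamma, edge_set


def features(gamma, edge_set, k):
    """z_i: one-hot encoding of Gamma(i); z_ij: 1 if edge ij is in E, else 0."""
    n = len(gamma)
    z_node = [[1 if c == g else 0 for c in range(k)] for g in gamma]
    z_edge = {(i, j): int((min(i, j), max(i, j)) in edge_set)
              for i in range(n) for j in range(n) if i != j}
    return z_node, z_edge


def labels(z_node, z_edge):
    """y_i = sum over neighbours j of 1[Gamma(i) == Gamma(j)]."""
    n = len(z_node)
    return [sum(z_edge[(i, j)] for j in range(n) if j != i and z_node[i] == z_node[j])
            for i in range(n)]


gamma, edge_set = make_graph()
z_node, z_edge = features(gamma, edge_set, k=3)
y = labels(z_node, z_edge)  # one integer label per node, between 0 and 9
```
<br />
Note that <code>labels</code> only compares features and reads edge indicators, never node positions, which is why the target itself is graph-permutation invariant.<br />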
<br />
The results show that the GPI architecture requires far fewer samples to converge to the correct solution.<br />
[[File:GPI_synthetic_example.jpg|450px|center]]<br />
<br />
This experiment is meant to demonstrate sample complexity. For fairness, all three models were constructed with a similar number of trainable parameters. The results support the authors' observation that a black-box model which does not respect the permutation-invariant structure wastes capacity learning that structure at training time, illustrating the advantage of an architecture with the proper inductive bias.<br />
<br />
==Scene-Graph Classification==<br />
Applying the concept of GPI to Scene-Graph Prediction (SGP) is the main task of this paper. The input to this problem is an image, along with a set of annotated bounding boxes for the entities in the image. The goal is to correctly label each entity within the bounding boxes and the relationship between every pair of entities, resulting in a coherent scene graph.<br />
<br />
The authors describe two different types of variables to predict. The first type is entity variables <math>[y_1, \dots, y_n]</math> for all bounding boxes, where each <math>y_i</math> can take one of L values and refers to objects such as "dog" or "man". The second type is relation variables <math>[y_{n+1}, \cdots, y_{n^2}]</math>, where each <math>y_i</math> represents the relation (e.g. "on", "below") between a pair of bounding boxes (entities).<br />
<br />
The scene graph contains two types of edges:<br />
# '''Entity-entity edges''': connecting two entities <math>y_i</math> and <math>y_j</math> for <math>1 \leq i \neq j \leq n</math><br />
# '''Entity-relation edges''': connecting every relation variable <math>y_k</math> for <math>k > n</math> to two entities<br />
<br />
The feature set <math>\mathbf{z}</math> is based on the baseline model from Zellers et al. (2017). For entity variables <math>y_i</math>, the vector <math>\mathbf{z}_i \in \mathbb{R}^L</math> holds the predicted probability of each entity label for bounding box <math>i</math>, and is augmented with the coordinates of the bounding box. Similarly, for relation variables <math>y_j</math>, the vector <math>\mathbf{z}_j \in \mathbb{R}^R</math> holds the probabilities of the relations between the two entities in <math>j</math>. The entity-entity pairwise features <math>\mathbf{z}_{i,j}</math> use a similar representation of the probabilities for the pair. The SGP model outputs probability distributions over all entities and relations, which are then fed back as input recurrently; since GPI functions are closed under composition, this preserves GPI. Finally, word embeddings of the most probable entity and relation labels are concatenated to the features.<br />
<br />
'''Components of the GPI architecture''' (ent for entity, rel for relation)<br />
# <math>\phi_{ent}</math>: a network that integrates two entity variables <math>y_i</math> and <math>y_j</math>; its input is <math>z_i, z_j, z_{i,j}</math> and its output is a vector in <math>\mathbb{R}^{n_1}</math><br />
# <math>\alpha_{ent}</math>: a network that takes the outputs of <math>\phi_{ent}</math> for all neighbours of an entity and uses an attention mechanism to output a vector in <math>\mathbb{R}^{n_2}</math><br />
# <math>\rho_{ent}</math>: a network that takes the various <math>\mathbb{R}^{n_2}</math> vectors as input and outputs <math>L</math> logits to predict the entity label<br />
# <math>\rho_{rel}</math>: a network that takes the <math>\alpha_{ent}</math> outputs of two entities and <math>z_{i,j}</math> as input and outputs <math>R</math> logits to predict the relation label<br />
<br />
==Set-up and Results==<br />
'''Dataset''': based on Visual Genome (VG) (Krishna et al., 2017), which contains 108,077 images annotated with bounding boxes, entities, and relations, with an average of 12 entities and 7 relations per image. For a fair comparison with previous work, the train and test splits from Xu et al. (2017) were used, along with the same 150 entities and 50 relations as in (Xu et al., 2017; Newell & Deng, 2017; Zellers et al., 2017). Hyperparameters were tuned using a 70K/5K/32K split for training, validation, and testing, respectively.<br />
<br />
'''Training''': all networks were trained using the Adam optimizer with a batch size of 20. The loss function was the sum of cross-entropy losses over all entities and relations. Penalties for misclassified entities were 4 times stronger than those for relations, and penalties for misclassified negative relations were 10 times weaker than those for positive relations.<br />
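<br />
The loss weighting above can be sketched as follows. This is an illustrative pure-Python version, not the authors' training code: it assumes per-example probability vectors, treats "negative relations" as pairs labelled with a hypothetical <code>NO_RELATION</code> class, and assumes a base weight of 1 for positive relations. Only the ratios (entities 4x relations, negative relations 0.1x positive relations) come from the text.<br />
<br />
```python
import math

ENTITY_WEIGHT = 4.0   # entity errors are penalised 4x more than relation errors
POS_REL_WEIGHT = 1.0  # assumed base weight for positive (annotated) relations
NEG_REL_WEIGHT = 0.1  # negative relations are penalised 10x less than positive ones
NO_RELATION = 0       # hypothetical index of the "no relation" class


def cross_entropy(probs, target):
    """Negative log-probability assigned to the target class."""
    return -math.log(probs[target])


def sgp_loss(entity_probs, entity_targets, rel_probs, rel_targets):
    """Weighted sum of cross-entropy losses over all entities and relations."""
    loss = sum(ENTITY_WEIGHT * cross_entropy(p, t)
               for p, t in zip(entity_probs, entity_targets))
    for p, t in zip(rel_probs, rel_targets):
        weight = NEG_REL_WEIGHT if t == NO_RELATION else POS_REL_WEIGHT
        loss += weight * cross_entropy(p, t)
    return loss
```
<br />
For one uncertain entity and two uncertain relations (one negative, one positive), each with probability 0.5 on the target, the loss is <math>(4 + 0.1 + 1)\ln 2 = 5.1 \ln 2</math>.<br />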
<br />
'''Evaluation''': there are three major scene-graph inference tasks; the authors focus on the following two:<br />
# '''SGCls''': given ground-truth entity bounding boxes, predict all entity and relation categories<br />
# '''PredCls''': given annotated bounding boxes with entity labels, predict all relations<br />
<br />
The evaluation metric Recall@K (shortened to R@K) is drawn from Lu et al. (2016). It is the fraction of ground-truth triplets that appear among the <math>K</math> most confident triplets predicted by the model. The graph-constrained protocol requires the top-<math>K</math> triplets to assign one consistent class per entity and relation; the unconstrained protocol does not enforce this constraint.<br />
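<br />
A simplified sketch of Recall@K under both protocols. The data layout, the function names and the triplet score (product of the two entity probabilities and the predicate probability) are assumptions for illustration; the official evaluation code may differ in details such as score combination. For simplicity, each entity keeps its most probable class in both protocols, and the unconstrained variant only relaxes the one-predicate-per-pair constraint.<br />
<br />
```python
def top_triplets(ent_probs, rel_probs, k, constrained=True):
    """Return the k most confident (subj, subj_label, predicate, obj, obj_label) triplets.

    ent_probs: per-entity lists of class probabilities.
    rel_probs: dict mapping an ordered entity pair (i, j) to predicate probabilities.
    """
    # Each entity gets one label (its most probable class) in both protocols here.
    ent_label = [max(range(len(p)), key=p.__getitem__) for p in ent_probs]
    candidates = []
    for (i, j), probs in rel_probs.items():
        predicates = range(len(probs))
        if constrained:  # graph-constrained: a single predicate per entity pair
            predicates = [max(predicates, key=probs.__getitem__)]
        for r in predicates:
            score = ent_probs[i][ent_label[i]] * ent_probs[j][ent_label[j]] * probs[r]
            candidates.append((score, (i, ent_label[i], r, j, ent_label[j])))
    candidates.sort(key=lambda c: c[0], reverse=True)
    return [triplet for _, triplet in candidates[:k]]


def recall_at_k(gt_triplets, ent_probs, rel_probs, k, constrained=True):
    """Fraction of ground-truth triplets found among the top-k predicted triplets."""
    predicted = set(top_triplets(ent_probs, rel_probs, k, constrained))
    return sum(t in predicted for t in gt_triplets) / len(gt_triplets)


ent = [[0.9, 0.1], [0.2, 0.8]]
rel = {(0, 1): [0.6, 0.4], (1, 0): [0.7, 0.3]}
gt = [(0, 0, 1, 1, 1)]  # entity 0 (class 0) --predicate 1--> entity 1 (class 1)
print(recall_at_k(gt, ent, rel, k=2))                     # 0.0
print(recall_at_k(gt, ent, rel, k=3, constrained=False))  # 1.0
```
<br />
The toy example shows why unconstrained recall is never lower: allowing several predicates per pair lets a second-best predicate enter the top-<math>K</math> list.<br />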
<br />
'''Models and baselines''': the authors compared variants of the GPI approach against four baselines, each a state-of-the-art model for scene-graph sub-tasks. For consistency, all models used the same training/testing split and the same preprocessing as Xu et al. (2017).<br />
<br />
'''Baselines from existing state-of-the-art models'''<br />
# (Lu et al., 2016): use of word embeddings to fine-tune the likelihood of predicted relations<br />
# (Xu et al., 2017): message passing algorithm between entities and relations to iteratively improve feature map for prediction<br />
# (Newell & Deng, 2017): Pixel2Graph, uses associative embeddings to produce a full graph from image<br />
# (Zellers et al., 2017): NeuralMotif method, which encodes global context to capture higher-order motifs in scene graphs; their Baseline variant outputs entity and relation distributions without using global context<br />
<br />
'''GPI models'''<br />
# '''GPI with no attention mechanism''': simply following Theorem 1's functional form, with summation over features<br />
# '''GPI NeighborAttention''': the same GPI model, but with attention over neighbours' features<br />
# '''GPI Linguistic''': similar to NeighborAttention model, but concatenates word embedding vectors<br />
<br />
'''Key Results''': the GPI Linguistic approach outperforms all baselines on SGCls and performs similarly to the state-of-the-art NeuralMotif method. The authors argue that PredCls is an easier task with less structure, on which existing state-of-the-art models already perform well.<br />
<br />
[[File:GPI_table_results.png|700px|center]]<br />
<br />
=Conclusion=<br />
<br />
This paper presented a deep learning approach to structured prediction that constrains the architecture to be invariant to structurally identical inputs. The approach relies on pairwise features, which can describe inter-label correlations, and inherits the intuitive aspect of score-based approaches. The output it produces is invariant to equivalent representations of the pairwise terms.<br />
<br />
As future work, the axiomatic approach can be extended; for example, in image labeling, geometric invariances such as invariance to shifts or rotations may be desired (in other cases, invariance to feature permutations may be desired). Another interesting extension is algorithms that discover symmetries for deep structured prediction when the invariant structure is unknown and must be learned from data.<br />
<br />
=Critique=<br />
The paper's contribution is the novelty of using permutation invariance as a design guideline for structured prediction. Although invariance was not explicitly considered in much of the previous work, architectural invariance had already been considered in Deep Sets (Zaheer et al., 2017); this paper relaxes the invariance condition relative to that work. To evaluate the benefit of GPI models, the paper used a synthetic problem to show that the GPI model needs far fewer samples to converge to 100% accuracy. However, on the real task of scene graph prediction, the GPI variants achieved only marginally higher Recall@K scores than the state-of-the-art baselines. The main benefit of the approach is that it avoids maximizing a score function (a computationally difficult problem) and instead directly produces output that is invariant to how the pairwise terms are represented.<br />
<br />
=References=<br />
<br />
Lu, Cewu, Krishna, Ranjay, Bernstein, Michael S., and Li, Fei-Fei. Visual relationship detection with language priors. In European Conference on Computer Vision, pp. 852–869, 2016.<br />
<br />
Herzig, Roei, Raboh, Moshiko, Chechik, Gal, Berant, Jonathan, and Globerson, Amir. Mapping images to scene graphs with permutation-invariant structured prediction. In Advances in Neural Information Processing Systems 31, 2018.<br />
<br />
Belanger, David, Yang, Bishan, and McCallum, Andrew. End-to-end learning for structured prediction energy networks. In Precup, Doina and Teh, Yee Whye (eds.), Proceedings of the 34th International Conference on Machine Learning, volume 70, pp. 429–439. PMLR, 2017.<br />
<br />
Chen, Liang Chieh, Schwing, Alexander G, Yuille, Alan L, and Urtasun, Raquel. Learning deep structured models. In Proc. ICML, 2015.<br />
<br />
Taskar, B., Guestrin, C., and Koller, D. Max margin Markov networks. In Thrun, S., Saul, L., and Schölkopf, B. (eds.), Advances in Neural Information Processing Systems 16, pp. 25–32. MIT Press, Cambridge, MA, 2004.<br />
<br />
Zheng, Shuai, Jayasumana, Sadeep, Romera-Paredes, Bernardino, Vineet, Vibhav, Su, Zhizhong, Du, Dalong, Huang, Chang, and Torr, Philip HS. Conditional random fields as recurrent neural networks. In Proceedings of the IEEE International Conference on Computer Vision, pp. 1529–1537, 2015.<br />
<br />
Additional resources from Moshiko Raboh's [https://github.com/shikorab/SceneGraph GitHub]</div>Gsahuhttp://wiki.math.uwaterloo.ca/statwiki/index.php?title=Mapping_Images_to_Scene_Graphs_with_Permutation-Invariant_Structured_Prediction&diff=41993Mapping Images to Scene Graphs with Permutation-Invariant Structured Prediction2018-11-30T00:59:26Z<p>Gsahu: /* References */</p>
<hr />
<div>The paper ''Mapping Images to Scene Graphs with Permutation-Invariant Structured Prediction'' was written by Roei Herzig* from Tel Aviv University, Moshiko Raboh* from Tel Aviv University, Gal Chechik from Google Brain and Bar-Ilan University, Jonathan Berant from Tel Aviv University, and Amir Globerson from Tel Aviv University. This paper is part of the NIPS 2018 conference, to be held in December 2018 in Montréal, Canada. This summary is based on version 3 of the preprint (as of May 2018), obtained from [https://arxiv.org/pdf/1802.05451v3.pdf arXiv].<br />
<br />
(*) Equal contribution<br />
<br />
=Motivation=<br />
In the field of artificial intelligence, a major goal is to enable machines to understand complex images, such as the underlying relationships between the objects in each scene and the global context needed to interpret the scene. A natural modelling framework for capturing such effects is structured prediction, which optimizes over complex labels while modelling within-label interactions. Although these models capture both complex labels and the interactions between labels, it is unclear what design guidelines should be followed when using deep learning for structured prediction. This paper introduces a design principle for such models that stems from the concept of permutation invariance, and demonstrates state-of-the-art performance for models that follow this principle.<br />
<br />
The primary contributions that this paper makes include:<br />
# Deriving sufficient and necessary conditions for respecting graph-permutation invariance in deep structured prediction architectures<br />
# Empirically demonstrating the benefit of graph-permutation invariance<br />
# Developing a state-of-the-art model for scene graph predictions over a large set of complex visual scenes<br />
<br />
=Introduction=<br />
For a machine to interpret complex visual scenes, it must recognize and understand both the objects and the relationships between the objects in the scene. A '''scene graph''' is a representation of the set of objects and relations that exist in the scene, where objects are represented as nodes and relations as edges connecting the nodes. Hence, predicting the scene graph amounts to jointly inferring the set of objects and relations in a visual scene.<br />
<br />
[[File:scene_graph_example.png|600px|center]]<br />
<br />
Since the objects in a scene are interdependent, the objects and relations must be predicted jointly. This makes structured prediction, the general problem of inferring multiple interdependent labels, relevant here. Structured prediction has attracted considerable attention because it applies to many learning problems and poses unique theoretical and applied challenges (e.g., see Belanger et al., 2017; Chen et al., 2015; Taskar et al., 2004).<br />
<br />
In structured prediction models, a score function <math>s(x, y)</math> is defined to evaluate the compatibility between a label <math>y</math> and an input <math>x</math>. For instance, when interpreting the scene of an image, <math>x</math> is the image itself and <math>y</math> is a complex label containing both the objects and the relations between them. As in most other inference methods, the goal is to find the label <math>y^*</math> that maximizes the score, <math>y^* = \arg\max_y s(x,y)</math>. The major concern is that the space of possible label assignments grows exponentially with the input size. For example, even for an image that seems very simple, the set of possible labels for its objects may be very large, making the score function difficult to optimize.<br />
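<br />
To make the exponential growth concrete: with <math>n</math> variables that each take one of <math>L</math> labels, exhaustive maximization must score <math>L^n</math> assignments. The brute-force sketch below uses a hypothetical toy score function <code>toy_score</code>; any function of <math>x</math> and <math>y</math> could be plugged in.<br />
<br />
```python
import itertools


def brute_force_argmax(score, x, n, num_labels):
    """Return (y*, number of assignments scored) by enumerating all L**n labelings."""
    best, best_val, count = None, float("-inf"), 0
    for y in itertools.product(range(num_labels), repeat=n):
        count += 1
        val = score(x, y)
        if val > best_val:
            best, best_val = y, val
    return best, count


def toy_score(x, y):
    """Hypothetical score: +1 per label matching x, -1 per equal adjacent pair."""
    agree = sum(int(a == b) for a, b in zip(x, y))
    clash = sum(int(y[i] == y[i + 1]) for i in range(len(y) - 1))
    return agree - clash


print(brute_force_argmax(toy_score, (0, 1, 2), n=3, num_labels=4))  # ((0, 1, 2), 64)
```
<br />
Even modest sizes are out of reach this way: for instance, 12 entities with 150 possible labels each would already give <math>150^{12} \approx 1.3 \times 10^{26}</math> assignments.<br />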
<br />
The paper presents an alternative approach, in which the input <math>x</math> is mapped to the structured output <math>y</math> by a "black box" neural network, without defining a score function. The main question for this approach is how to choose the network architecture.<br />
<br />
The model is evaluated by first demonstrating the importance of permutation invariance on a synthetic data set. The authors' approach is then shown to respect permutation invariance, and its results are compared to competitive benchmarks, achieving state-of-the-art results.<br />
<br />
=Structured prediction=<br />
This section reviews score-based approaches to structured prediction. In a score-based approach, a score function <math>s(x, y)</math> measures how compatible a label <math>y</math> is with an input <math>x</math>, and a label is inferred by maximizing <math>s(x, y)</math>. To make this optimization efficient, previous work has decomposed the score as <math>s(x,y) = \sum_i f_i(x,y)</math>, where each local score function <math>f_i</math> depends on only a small subset of the <math>y</math> variables and can therefore be maximized efficiently.<br />
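<br />
To see why decomposition helps, consider the extreme case in which every local score involves a single variable: the joint maximization then splits into independent per-variable maximizations. A small sketch, assuming the singleton scores are given as a hypothetical table <code>f[i][label]</code> (with <math>x</math> folded into the table):<br />
<br />
```python
import itertools


def singleton_score(f, y):
    """s(x, y) = sum_i f_i(y_i, x); x is folded into the score table f[i][label]."""
    return sum(f[i][label] for i, label in enumerate(y))


def local_argmax(f):
    """Maximise each f_i on its own: n * L evaluations instead of L ** n."""
    return tuple(max(range(len(row)), key=row.__getitem__) for row in f)


def exhaustive_argmax(f):
    """Reference answer: enumerate every joint labeling."""
    n, num_labels = len(f), len(f[0])
    return max(itertools.product(range(num_labels), repeat=n),
               key=lambda y: singleton_score(f, y))


f = [[0.1, 0.7, 0.2],
     [0.5, 0.3, 0.9],
     [0.4, 0.0, 0.1]]
print(local_argmax(f), exhaustive_argmax(f))  # (1, 2, 0) (1, 2, 0)
```
<br />
With pairwise terms <math>f_{ij}</math> the variables interact and this shortcut no longer applies in general, which is where inference algorithms such as message passing are used.<br />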
<br />
Recently, there has been growing interest in modelling the <math>f_i</math> functions as deep networks. In this setting, the most commonly used score functions are the singleton score function <math>f_i(y_i, x)</math> and the pairwise score function <math>f_{ij} (y_i, y_j, x)</math>. Previous work has explored two-stage architectures (which learn local scores independently of the structured prediction goal), end-to-end architectures (which include the inference algorithm within the computation graph), and models of global factors.<br />
<br />
==Advantages of using score-based methods==<br />
# Allow for intuitive specification of local dependencies between labels, and how they map to global dependencies<br />
# Linear score functions offer natural convex surrogates<br />
# Inference in large label space is sometimes possible via exact algorithms or empirically accurate approximations<br />
<br />
The concern with modelling score functions using deep networks is that learning may no longer be convex. Hence, the paper studies how deep networks can be used for structured prediction through architectures that do not require explicit maximization of a score function.<br />
<br />
=Background, Notations, and Definitions=<br />
We denote by <math>y = [y_1, \dots, y_n]</math> a structured label.<br />
<br />
'''Score functions:''' for score-based methods, the score is defined as either the sum of a set of singleton scores <math>f_i = f_i(y_i, x)</math> or the sum of pairwise scores <math>f_{ij} = f_{ij}(y_i, y_j, x)</math>.<br />
<br />
Let <math>s(x,y)</math> be the score of a score-based method. Then:<br />
<br />
<div align="center"><br />
<math>s(x,y) = \begin{cases}<br />
\sum_i f_i ~ \text{if we have a set of singleton scores}\\<br />
\sum_{ij} f_{ij} ~ \text{if we have a set of pairwise scores } \\<br />
\end{cases}</math><br />
</div><br />
<br />
'''Inference algorithm:''' an inference algorithm takes as input a set of local scores (either <math>f_i</math> or <math>f_{ij}</math>) and outputs an assignment of labels <math>y_1, \dots, y_n</math> that maximizes the score function <math>s(x,y)</math>.<br />
<br />
'''Graph labeling function:''' a graph labeling function <math>\mathcal{F} : (V,E) \rightarrow Y</math> takes as input an ordered set of node features <math>V = [z_1, \dots, z_n]</math> and an ordered set of edge features <math>E = [z_{1,2},\dots,z_{i,j},\dots,z_{n,n-1}]</math>, and outputs a set of node labels <math>\mathbf{y} = [y_1, \dots, y_n]</math>. For instance, <math>z_i</math> can be set equal to <math>f_i</math> and <math>z_{ij}</math> to <math>f_{ij}</math>.<br />
<br />
For convenience, the joint set of node and edge features is denoted by <math>\mathbf{z}</math>, a vector of size <math>n^2</math> (<math>n</math> node features and <math>n(n-1)</math> edge features).<br />
<br />
'''Permutation:''' let <math>z</math> be a set of node and edge features. Given a permutation <math>\sigma</math> of <math>\{1,\dots,n\}</math>, let <math>\sigma(z)</math> be the new set of node and edge features given by <math>[\sigma(z)]_i = z_{\sigma(i)}</math> and <math>[\sigma(z)]_{i,j} = z_{\sigma(i), \sigma(j)}</math>.<br />
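<br />
The permutation operation can be written directly from this definition. The sketch below uses 0-based indices and stores <math>z</math> as a list of node features plus a dict of edge features keyed by ordered pairs (a hypothetical layout). With symbolic features, swapping nodes 1 and 2 reproduces the three-variable example given in the next section.<br />
<br />
```python
def permute(nodes, edges, sigma):
    """Return sigma(z): [sigma(z)]_i = z_sigma(i), [sigma(z)]_{i,j} = z_{sigma(i),sigma(j)}.

    nodes: list of node features; edges: dict {(i, j): feature} for all i != j.
    """
    n = len(nodes)
    new_nodes = [nodes[sigma[i]] for i in range(n)]
    new_edges = {(i, j): edges[(sigma[i], sigma[j])]
                 for i in range(n) for j in range(n) if i != j}
    return new_nodes, new_edges


# Symbolic features for three variables; swapping nodes 1 and 2 is [1, 0, 2] 0-based.
nodes = ["f1", "f2", "f3"]
edges = {(i, j): f"f{i + 1}{j + 1}" for i in range(3) for j in range(3) if i != j}
new_nodes, new_edges = permute(nodes, edges, [1, 0, 2])
print(new_nodes)  # ['f2', 'f1', 'f3']
print([new_edges[(0, 1)], new_edges[(0, 2)], new_edges[(1, 2)]])  # ['f21', 'f23', 'f13']
```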
<br />
'''One-hot representation:''' let <math>\mathbf{1}[j]</math> denote a one-hot vector with 1 in the <math>j^{th}</math> coordinate.<br />
<br />
=Permutation-Invariant Structured prediction=<br />
<br />
With permutation-invariant structured prediction, we expect the algorithm to produce the same result, up to the corresponding reordering of the labels, when its input is reordered. For instance, consider three variables <math>y_1, y_2, y_3</math> with input <math>\mathbf{z} = (f_1, f_2, f_3, f_{12}, f_{13}, f_{23})</math>, for which the algorithm outputs <math>\mathbf{y} = (y_1^*, y_2^*, y_3^*)</math>. If the algorithm is run on the permuted input <math>\mathbf{z}' = (f_2, f_1, f_3, f_{21}, f_{23}, f_{13})</math>, we would expect the output <math>\mathbf{y}' = (y_2^*, y_1^*, y_3^*)</math>, since the score function is the same.<br />
<br />
'''Graph permutation invariance (GPI):''' a graph labeling function <math>\mathcal{F}</math> is graph-permutation invariant if, for all permutations <math>\sigma</math> of <math>\{1, \dots, n\}</math> and for all feature sets <math>\mathbf{z}</math>, <math>\mathcal{F}(\sigma(\mathbf{z})) = \sigma(\mathcal{F}(\mathbf{z}))</math>. Practically speaking, graph permutation invariance means that the same graph is constructed no matter the order in which the elements are predicted. In scene graph generation approaches, Region Proposal Networks are often used as an initial pre-processing step. Their outputs (cropped images representing bounding boxes) are then fed sequentially through the corresponding vertex (or edge) detection network. The idea behind permutation invariance is that the final scene graph is identical regardless of the order in which these are passed in. In effect, this means not connecting vertices that should not be connected simply because a more promising vertex has not yet been identified.<br />
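<br />
The GPI condition can be checked numerically for any candidate labeling function. In this sketch (all names hypothetical), <code>neighbour_sum</code> labels each node using only its own feature and its edges, so it is GPI, while <code>index_weighted</code> uses the node's position, so it is not. A check like this can refute GPI on the sampled inputs, but cannot prove it.<br />
<br />
```python
def permute(nodes, edges, sigma):
    """[sigma(z)]_i = z_sigma(i) and [sigma(z)]_{i,j} = z_{sigma(i),sigma(j)}."""
    n = len(nodes)
    return ([nodes[sigma[i]] for i in range(n)],
            {(i, j): edges[(sigma[i], sigma[j])]
             for i in range(n) for j in range(n) if i != j})


def neighbour_sum(nodes, edges):
    """y_k = z_k + sum_{j != k} z_kj * z_j: uses features only, so it is GPI."""
    n = len(nodes)
    return [nodes[k] + sum(edges[(k, j)] * nodes[j] for j in range(n) if j != k)
            for k in range(n)]


def index_weighted(nodes, edges):
    """y_k = k * z_k: depends on the position k itself, so it is not GPI."""
    return [k * z for k, z in enumerate(nodes)]


def satisfies_gpi(f, nodes, edges, sigma, tol=1e-9):
    """Check [F(sigma(z))]_k == [F(z)]_sigma(k) for one permutation sigma."""
    y = f(nodes, edges)
    y_perm = f(*permute(nodes, edges, sigma))
    return all(abs(y_perm[k] - y[sigma[k]]) <= tol for k in range(len(nodes)))


nodes = [1.0, 2.0, 3.0, 4.0]
edges = {(i, j): float(i + 2 * j) for i in range(4) for j in range(4) if i != j}
sigma = [3, 1, 0, 2]
print(satisfies_gpi(neighbour_sum, nodes, edges, sigma))   # True
print(satisfies_gpi(index_weighted, nodes, edges, sigma))  # False
```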
<br />
The paper presents a theorem giving necessary and sufficient conditions for a function <math>\mathcal{F}</math> to be graph-permutation invariant. Intuitively, because <math>\mathcal{F}</math> takes an ordered set <math>\mathbf{z}</math> as input, its output on <math>\sigma(\mathbf{z})</math> could very well differ from its output on <math>\mathbf{z}</math>, so <math>\mathcal{F}</math> needs some form of symmetry in order to satisfy <math>[\mathcal{F}(\sigma(\mathbf{z}))]_k = [\mathcal{F}(\mathbf{z})]_{\sigma(k)}</math>.<br />
<br />
[[File:graph_permutation_invariance.jpg|400px|center]]<br />
<br />
==Theorem 1==<br />
Let <math>\mathcal{F}</math> be a graph labeling function. Then <math>\mathcal{F}</math> is graph-permutation invariant if and only if there exist functions <math>\alpha, \rho, \phi</math> such that for all <math>k = 1, \dots, n</math>:<br />
<div align="center"><br />
<math>[\mathcal{F}(\mathbf{z})]_k = \rho\left(\mathbf{z}_k, \sum_{i=1}^n \alpha\left(\mathbf{z}_i, \sum_{j\neq i} \phi(\mathbf{z}_i, \mathbf{z}_{i,j}, \mathbf{z}_j)\right)\right) \qquad (1)</math><br />
</div><br />
where <math>\phi: \mathbb{R}^{2d+e} \rightarrow \mathbb{R}^L, \alpha: \mathbb{R}^{d + L} \rightarrow \mathbb{R}^{W}, \rho: \mathbb{R}^{W+d} \rightarrow \mathbb{R}</math>.<br />
<br />
Here <math>d</math> is the dimension of each singleton (node) feature <math>\mathbf{z}_i</math> and <math>e</math> is the dimension of each edge feature <math>\mathbf{z}_{i,j}</math>.<br />
<br />
[[File:GPI_architecture.jpg|thumb|A schematic representation of the GPI architecture. Singleton features <math>z_i</math> are omitted for simplicity. First, the features <math>z_{i,j}</math> are processed element-wise by <math>\phi</math>. Next, they are summed to create a vector <math>s_i</math>, which is concatenated with <math>z_i</math>. Third, a representation of the entire graph is created by applying <math>\alpha</math> <math>n</math> times and summing the resulting vectors. The graph representation is then finally processed by <math>\rho</math> together with <math>z_k</math>.|600px|center]]<br />
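<br />
A minimal pure-Python sketch of the functional form in Theorem 1, with scalar node and edge features (<math>d = e = 1</math>) and hand-written stand-ins for <math>\phi</math>, <math>\alpha</math> and <math>\rho</math>. These particular functions are hypothetical choices; in the paper they are learned networks. The check at the end illustrates the forward direction of the theorem: whatever the three functions are, the resulting labeling is GPI.<br />
<br />
```python
import random


def phi(z_i, z_ij, z_j):
    """Stand-in for phi: R^{2d+e} -> R^L with d = e = 1 and L = 2."""
    return [z_ij * z_j, z_ij + z_i * z_j]


def alpha(z_i, s_i):
    """Stand-in for alpha: R^{d+L} -> R^W with W = 2."""
    return [z_i * s_i[0], max(s_i[1], 0.0)]


def rho(z_k, g):
    """Stand-in for rho: R^{W+d} -> R."""
    return z_k * g[0] - g[1]


def gpi_label(nodes, edges):
    """[F(z)]_k = rho(z_k, sum_i alpha(z_i, sum_{j != i} phi(z_i, z_ij, z_j)))."""
    n = len(nodes)
    g = [0.0, 0.0]  # graph representation, shared by every output k
    for i in range(n):
        s_i = [0.0, 0.0]
        for j in range(n):
            if j != i:  # each phi call is independent, so these can run in parallel
                s_i = [a + b for a, b in zip(s_i, phi(nodes[i], edges[(i, j)], nodes[j]))]
        g = [a + b for a, b in zip(g, alpha(nodes[i], s_i))]
    return [rho(nodes[k], g) for k in range(n)]


def permute(nodes, edges, sigma):
    """[sigma(z)]_i = z_sigma(i) and [sigma(z)]_{i,j} = z_{sigma(i),sigma(j)}."""
    n = len(nodes)
    return ([nodes[sigma[i]] for i in range(n)],
            {(i, j): edges[(sigma[i], sigma[j])]
             for i in range(n) for j in range(n) if i != j})


rng = random.Random(0)
n = 6
nodes = [rng.uniform(-1, 1) for _ in range(n)]
edges = {(i, j): rng.uniform(-1, 1) for i in range(n) for j in range(n) if i != j}
sigma = list(range(n))
rng.shuffle(sigma)
y = gpi_label(nodes, edges)
y_perm = gpi_label(*permute(nodes, edges, sigma))
print(all(abs(y_perm[k] - y[sigma[k]]) < 1e-9 for k in range(n)))  # True
```
<br />
The only index-dependent input to <math>\rho</math> is <math>z_k</math> itself; the graph summary <code>g</code> is a sum over all nodes and edges, so reordering the input cannot change it.<br />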
<br />
==Proof Sketch for Theorem 1==<br />
The proof of this theorem can be found in the paper. A proof sketch is provided below:<br />
<br />
'''For the forward direction''' (function that follows the form set out in equation (1) is GPI):<br />
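# Use the definition of the permutation <math>\sigma</math> to rewrite <math>[\mathcal{F}(\sigma(\mathbf{z}))]_k</math> in the form of equation (1); its first argument becomes <math>\mathbf{z}_{\sigma(k)}</math><br />
# The second argument of <math>\rho</math> is invariant under <math>\sigma</math>, since it sums over all indices <math>i</math> and all other indices <math>j \neq i</math>; hence <math>[\mathcal{F}(\sigma(\mathbf{z}))]_k = [\mathcal{F}(\mathbf{z})]_{\sigma(k)}</math><br />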
<br />
'''For the backward direction''' (any black-box GPI function can be expressed in the form of equation 1):<br />
# Construct <math>\phi, \alpha</math> such that the second argument of <math>\rho</math> contains all information about the graph features of <math>z</math>, including the edges that the features originate from<br />
# Assume each <math>z_k</math> uniquely identifies its node and that <math>\mathcal{F}</math> is a function only of the pairwise features <math>z_{i,j}</math><br />
# Let <math>H</math> be a perfect hash function with <math>L</math> buckets, and let <math>\phi</math> map '''pairwise features''' to a vector of size <math>L</math><br />
# <math>*</math>Construct <math>\phi(z_i, z_{i,j}, z_j) = \mathbf{1}[H(z_j)] z_{i,j}</math>, which intuitively means that <math>\phi</math> stores <math>z_{i,j}</math> in the unique bucket for node <math>j</math><br />
# Construct <math>\alpha</math> to output a matrix in <math>\mathbb{R}^{L \times L}</math> that places each pairwise feature in a unique position: <math>\alpha(z_i, s_i) = \mathbf{1}[H(z_i)]s_i^T</math>, where <math>s_i = \sum_{j \neq i} \phi(z_i, z_{i,j}, z_j)</math><br />
# Construct the matrix <math>M = \sum_i \alpha(z_i,s_i)</math>, whose entry <math>(H(z_i), H(z_j))</math> is <math>z_{i,j}</math>. Discarding the rows and columns of <math>M</math> that do not correspond to original nodes reduces it to <math>n\times n</math>; <math>\rho</math> then applies <math>\mathcal{F}</math> to this matrix, so its output is the labels <math>\mathbf{y} = y_1, \dots, y_n</math><br />
<br />
<math>*</math>For simplicity, the paper presents the proof for scalar edge features <math>z_{ij}</math> (<math>e = 1</math>); it extends easily to vector features with additional indexing.<br />
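<br />
The hashing construction in the backward direction can be sketched concretely with scalar edge features, as in the proof. The assumptions here are hypothetical simplifications: node features are distinct integers in <code>range(L)</code>, so the identity function serves as a perfect hash <math>H</math>, and the final step in which <math>\rho</math> crops <math>M</math> and applies <math>\mathcal{F}</math> is omitted. Because <math>M</math> is indexed by hash buckets rather than input positions, it comes out identical for every ordering of the input.<br />
<br />
```python
L = 8  # number of hash buckets; must be at least the number of nodes


def H(z):
    """Hypothetical perfect hash: node features are distinct integers in range(L)."""
    return z


def phi(z_i, z_ij, z_j):
    """phi(z_i, z_ij, z_j) = 1[H(z_j)] z_ij: store z_ij in node j's bucket."""
    v = [0.0] * L
    v[H(z_j)] = z_ij
    return v


def alpha(z_i, s_i):
    """alpha(z_i, s_i) = 1[H(z_i)] s_i^T: an L x L matrix whose row H(z_i) is s_i."""
    m = [[0.0] * L for _ in range(L)]
    m[H(z_i)] = list(s_i)
    return m


def encode(nodes, edges):
    """M = sum_i alpha(z_i, s_i) with s_i = sum_{j != i} phi(z_i, z_ij, z_j)."""
    n = len(nodes)
    M = [[0.0] * L for _ in range(L)]
    for i in range(n):
        s_i = [0.0] * L
        for j in range(n):
            if j != i:
                s_i = [a + b for a, b in zip(s_i, phi(nodes[i], edges[(i, j)], nodes[j]))]
        a_i = alpha(nodes[i], s_i)
        M = [[x + y for x, y in zip(row_m, row_a)] for row_m, row_a in zip(M, a_i)]
    return M


nodes = [5, 2, 7]  # node features double as node identities
edges = {(0, 1): 1.5, (0, 2): -2.0, (1, 0): 0.25,
         (1, 2): 3.0, (2, 0): 4.0, (2, 1): -1.0}
M = encode(nodes, edges)
print(M[5][2], M[2][7], M[7][5])  # 1.5 3.0 4.0
```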
<br />
Although the results above apply to complete graphs (with edge features for all node pairs), they extend easily to incomplete graphs. For incomplete graphs, the input to <math>\mathcal{F}</math> contains only the features corresponding to valid edges of the graph. Since the authors are only interested in invariances that preserve the graph structure, permutation invariance is replaced by automorphism invariance.<br />
<br />
==Implications and Applications of Theorem 1==<br />
===Key Implications of Theorem 1===<br />
# Architecture "collects" information from the different edges of the graph, and does so in an invariant fashion using <math>\alpha</math> and <math>\phi</math><br />
# Architecture is parallelizable, since all <math>\phi</math> functions can be applied simultaneously. In contrast, recurrent models (Zellers et al., 2017) are harder to parallelize and are therefore slower in practice.<br />
<br />
===Some applications of Theorem 1===<br />
# '''Attention:''' attention can be implemented within the GPI characterization with slight alterations to the functions <math>\alpha</math> and <math>\phi</math>. With attention, each node aggregates its neighbours' features weighted by a function of each neighbour's relevance, so the label of an entity can depend strongly on its most relevant neighbours. The complete details can be found in the supplementary material of the paper.<br />
<br />
# '''RNN:''' recurrent architectures can maintain the GPI property, since GPI functions are closed under composition. The output of one step of <math>\mathcal{F}</math> acts as the input to the next step, and the GPI property is preserved throughout.<br />
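<br />
Closure under composition is what lets a recurrent variant keep the GPI property: if each step maps node and edge features to new node features in a GPI way, feeding the outputs back in as node features for the next step preserves GPI. A sketch with a hypothetical update <code>step</code> and a fixed number of iterations <code>steps=3</code>, both assumptions for illustration:<br />
<br />
```python
import math
import random


def step(nodes, edges):
    """One GPI update: y_k = tanh(z_k + sum_{j != k} z_kj * z_j)."""
    n = len(nodes)
    return [math.tanh(nodes[k] + sum(edges[(k, j)] * nodes[j] for j in range(n) if j != k))
            for k in range(n)]


def recurrent_label(nodes, edges, steps=3):
    """Apply the same GPI step repeatedly, feeding outputs back as node features."""
    for _ in range(steps):
        nodes = step(nodes, edges)
    return nodes


def permute(nodes, edges, sigma):
    """[sigma(z)]_i = z_sigma(i) and [sigma(z)]_{i,j} = z_{sigma(i),sigma(j)}."""
    n = len(nodes)
    return ([nodes[sigma[i]] for i in range(n)],
            {(i, j): edges[(sigma[i], sigma[j])]
             for i in range(n) for j in range(n) if i != j})


rng = random.Random(1)
n = 5
nodes = [rng.uniform(-1, 1) for _ in range(n)]
edges = {(i, j): rng.uniform(-1, 1) for i in range(n) for j in range(n) if i != j}
sigma = [4, 2, 0, 3, 1]
y = recurrent_label(nodes, edges)
y_perm = recurrent_label(*permute(nodes, edges, sigma))
print(all(abs(y_perm[k] - y[sigma[k]]) < 1e-9 for k in range(n)))  # True
```
<br />
The edge features are passed unchanged to every step; only the node features are updated, so after each step the permuted run is still exactly the permuted version of the original run.<br />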
<br />
=Related Work=<br />
# '''Architectural invariance:''' invariance as an architectural principle was recently proposed in Deep Sets (Zaheer et al., 2017), which considers a more restrictive form of invariance.<br />
# '''Deep structured prediction:''' previous work has applied deep learning to structured prediction tasks such as semantic segmentation. The algorithms used include message passing, gradient descent for maximizing score functions, and greedy decoding (inferring each label from the previously predicted labels). For example, Xu et al. (2017) propose an end-to-end model that generates a structured scene representation; it solves the scene graph inference problem using standard RNNs and learns to iteratively improve its predictions via message passing. Beyond these, deep learning has been applied to other graph-based problems such as the Travelling Salesman Problem (Bello et al., 2016; Gilmer et al., 2017; Khalil et al., 2017). However, none of the previous work specifically addresses invariance in the general architecture; it focuses instead on message-passing architectures, which the framework in this paper generalizes.<br />
# '''Scene graph prediction:''' scene graph extraction allows for reasoning, question answering, and image retrieval (Johnson et al., 2015; Lu et al., 2016; Raposo et al., 2017). Some other works in this area include object detection, action recognition, and even the detection of human-object interactions (Liao et al., 2016; Plummer et al., 2017). Additional work has been done with the use of message passing algorithms (Xu et al., 2017), word embeddings (Lu et al., 2016), and end-to-end prediction directly from pixels (Newell & Deng, 2017). A notable mention is NeuralMotif (Zellers et al., 2017), which the authors describe as the current state-of-the-art model for scene graph predictions on Visual Genome dataset. It uses an RNN that supplies global context by reading the independent predictions sequentially for each entity and relation and then conducts further refinement on the predictions. The NeuralMotif model has a fixed order in which the RNN reads its inputs and thereby maintains GPI. However, this fixed order is not guaranteed to be optimal.<br />
# '''Burst Image Deblurring Using Permutation Invariant Convolutional Neural Networks:''' similar ideas were applied here. Permutation-invariant CNNs restore sharp, noise-free images from bursts of photographs affected by hand tremor and noise, and they produce high-quality, detailed images on challenging datasets.<br />
<br />
=Experimental Results=<br />
<br />
The authors evaluated the benefit of GPI architectures empirically. They first used a synthetic graph-labeling task and then scene-graph classification, which maps images to scene graphs.<br />
<br />
==Synthetic Graph Labeling==<br />
The authors created a synthetic problem to study GPI. The input is a graph <math>G = (V,E)</math> in which each node <math>i</math> belongs to a set <math>\Gamma(i) \in \{1, \dots, K\}</math>, where <math>K</math> is the number of sets. The task is to compute, for each node, the number of its neighbours that belong to the same set. In other words, the label of node <math>i</math> is <math>y_i = \sum_{j \in N(i)} \mathbf{1}[\Gamma(i) = \Gamma(j)]</math>. Random graphs (each with 10 nodes) were generated by sampling the edges and the set <math>\Gamma(i) \in \{1, \dots, K\}</math> for each node, independently and uniformly.<br />
The node features <math>z_i \in \{0,1\}^K</math> are one-hot encodings of <math>\Gamma(i)</math>. Each pairwise edge feature <math>z_{ij} \in \{0, 1\}</math> denotes whether the edge <math>ij</math> is in the edge set <math>E</math>.<br />
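The synthetic task can be made concrete with a short generator. The sketch below is a minimal illustration and not the authors' code. The helper names and the edge probability `p_edge = 0.5` are hypothetical, because the summary only says that edges and set memberships are sampled. The labels follow the definition of <math>y_i</math> above, and the features follow the one-hot node encoding and binary edge encoding.<br />

```python
import random


def compute_labels(gamma, edges):
    # y_i = sum over neighbours j of 1[gamma(i) == gamma(j)]
    labels = [0] * len(gamma)
    for i, j in edges:
        if gamma[i] == gamma[j]:
            labels[i] += 1
            labels[j] += 1
    return labels


def make_instance(n_nodes=10, k_sets=3, p_edge=0.5, seed=None):
    """Sample one synthetic instance (hypothetical helper; p_edge is an assumption).

    Returns (gamma, edges, labels): gamma[i] in {1..K} is the set of node i,
    edges is a set of undirected pairs (i, j) with i < j.
    """
    rng = random.Random(seed)
    gamma = [rng.randint(1, k_sets) for _ in range(n_nodes)]
    edges = {(i, j) for i in range(n_nodes) for j in range(i + 1, n_nodes)
             if rng.random() < p_edge}
    return gamma, edges, compute_labels(gamma, edges)


def features(gamma, edges, k_sets):
    # Node features: one-hot encoding of gamma(i).
    # Edge features: z_ij = 1 if (i, j) is in E, else 0 (symmetric, zero diagonal).
    n = len(gamma)
    z_node = [[1 if gamma[i] == k + 1 else 0 for k in range(k_sets)] for i in range(n)]
    z_edge = [[1 if i != j and (min(i, j), max(i, j)) in edges else 0
               for j in range(n)] for i in range(n)]
    return z_node, z_edge
```

For example, with `gamma = [1, 1, 2, 1]` and edges `{(0, 1), (0, 2), (1, 3), (2, 3)}`, `compute_labels` returns `[1, 2, 0, 1]`.<br />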
Three architectures were compared:<br />
# '''GPI-architecture for graph prediction''' (without attention and RNN)<br />
# '''LSTM''': replaces <math>\sum \phi(\cdot)</math> and <math>\sum \alpha(\cdot)</math> in the form of Theorem 1 with two LSTMs of state size 200, each reading its input in random order<br />
# '''Fully connected feed-forward network''': with 2 hidden layers, each layer containing 1,000 nodes; the input is a concatenation of all nodes and pairwise features, and the output is all node predictions<br />
<br />
The results show that the GPI architecture requires far fewer samples to converge to the correct solution.<br />
[[File:GPI_synthetic_example.jpg|450px|center]]<br />
<br />
This experiment is meant to demonstrate the gain in sample complexity. For fairness, all three models were given a similar number of trainable parameters. The results support the authors' argument that a black-box model without built-in permutation invariance wastes capacity learning that invariance during training. This illustrates the advantage of an architecture with the proper inductive bias.<br />
<br />
==Scene-Graph Classification==<br />
Applying the concept of GPI to Scene-Graph Prediction (SGP) is the main task of this paper. The input to this problem is an image, along with a set of annotated bounding boxes for the entities in the image. The goal is to correctly label each entity within the bounding boxes and the relationship between every pair of entities, resulting in a coherent scene graph.<br />
<br />
The authors describe two different types of variables to predict. The first type is entity variables <math>[y_1, \dots, y_n]</math> for all bounding boxes, where each <math>y_i</math> can take one of L values and refers to objects such as "dog" or "man". The second type is relation variables <math>[y_{n+1}, \cdots, y_{n^2}]</math>, where each <math>y_i</math> represents the relation (e.g. "on", "below") between a pair of bounding boxes (entities).<br />
<br />
The scene graph contains two types of edges:<br />
# '''Entity-entity edges''': connecting two entities <math>y_i</math> and <math>y_j</math> for <math>1 \leq i \neq j \leq n</math><br />
# '''Entity-relation edges''': connecting every relation variable <math>y_k</math> for <math>k > n</math> to two entities<br />
<br />
The feature set <math>\mathbf{z}</math> is based on the baseline model from Zellers et al. (2017). For entity variables <math>y_i</math>, the vector <math>\mathbf{z}_i \in \mathbb{R}^L</math> models the probability of the entity appearing in <math>y_i</math>, and it is augmented with the coordinates of the bounding box. Similarly, for relation variables <math>y_j</math>, the vector <math>\mathbf{z}_j \in \mathbb{R}^R</math> models the probability of the relations between the two entities in <math>j</math>. Entity-entity pairwise features <math>\mathbf{z}_{i,j}</math> use a similar representation of the probabilities for the pair. The SGP model outputs probability distributions over all entities and relations, and these are fed back recurrently as input while maintaining GPI. Finally, word embeddings of the most probable entity and relation labels are also used and concatenated to the features.<br />
<br />
'''Components of the GPI architecture''' (ent for entity, rel for relation)<br />
# <math>\phi_{ent}</math>: a network that integrates two entity variables <math>y_i</math> and <math>y_j</math>. It takes <math>z_i, z_j, z_{i,j}</math> as input and outputs a vector in <math>\mathbb{R}^{n_1}</math><br />
# <math>\alpha_{ent}</math>: a network that takes the outputs of <math>\phi_{ent}</math> for all neighbours of an entity and uses an attention mechanism to output a vector in <math>\mathbb{R}^{n_2}</math><br />
# <math>\rho_{ent}</math>: a network that takes the <math>\mathbb{R}^{n_2}</math> vectors as input and outputs <math>L</math> logits to predict the entity value<br />
# <math>\rho_{rel}</math>: a network that takes the <math>\alpha_{ent}</math> outputs of the two entities and <math>z_{i,j}</math> as input, and outputs <math>R</math> logits<br />
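The components above can be understood through the generic GPI form. As we read Theorem 1 of the paper, that form is <math>F_k(z) = \rho\left(z_k, \sum_i \alpha\left(z_i, \sum_{j \neq i} \phi(z_i, z_{i,j}, z_j)\right)\right)</math>. Below is a minimal pure-Python sketch of it. The toy `phi`, `alpha` and `rho` are hypothetical stand-ins for the learned networks, with no attention and no training, and the pairwise feature is a scalar for simplicity. The sketch shows only the aggregation pattern: because every aggregation is a sum, relabelling the nodes relabels the outputs in the same way.<br />

```python
import math


def phi(z_i, z_ij, z_j):
    # Toy pairwise network (hypothetical): mixes both node features and the edge feature.
    return [math.tanh(a + 0.5 * b + z_ij) for a, b in zip(z_i, z_j)]


def alpha(z_i, msg):
    # Toy node-level network (hypothetical): combines a node with its summed messages.
    return [math.tanh(a * m) for a, m in zip(z_i, msg)]


def rho(z_k, context):
    # Toy output network (hypothetical): one score per feature dimension for node k.
    return [a + c for a, c in zip(z_k, context)]


def vec_sum(vectors, dim):
    total = [0.0] * dim
    for v in vectors:
        total = [t + x for t, x in zip(total, v)]
    return total


def gpi_forward(z_nodes, z_edges):
    """z_nodes: list of n feature vectors; z_edges[i][j]: scalar pairwise feature."""
    n, dim = len(z_nodes), len(z_nodes[0])
    per_node = []
    for i in range(n):
        # Inner sum over j != i of phi(z_i, z_ij, z_j), then alpha for node i.
        msg = vec_sum((phi(z_nodes[i], z_edges[i][j], z_nodes[j])
                       for j in range(n) if j != i), dim)
        per_node.append(alpha(z_nodes[i], msg))
    # Outer sum over all nodes gives a context that does not depend on node order.
    context = vec_sum(per_node, dim)
    return [rho(z_nodes[k], context) for k in range(n)]
```

If the node features are reordered by a permutation `perm`, and the pairwise features are reordered consistently, then output `k` of the permuted input equals output `perm[k]` of the original input, up to floating-point rounding.<br />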
<br />
==Set-up and Results==<br />
'''Dataset''': based on Visual Genome (VG) (Krishna et al., 2017), which contains a total of 108,077 images annotated with bounding boxes, entities, and relations. Each image has an average of 12 entities and 7 relations. For a fair comparison with previous work, the train and test splits of (Xu et al., 2017) were used. The authors used the same 150 entities and 50 relations as in (Xu et al., 2017; Newell & Deng, 2017; Zellers et al., 2017). Hyperparameters were tuned using a 70K/5K/32K split for training, validation, and testing respectively.<br />
<br />
'''Training''': all networks were trained with the Adam optimizer and a batch size of 20. The loss function was the sum of cross-entropy losses over all entities and relations. Penalties for misclassified entities were 4 times stronger than those for relations. Penalties for misclassified negative relations were 10 times weaker than those for positive relations.<br />
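The weighting scheme can be written out as a loss function. This is a minimal sketch of the relative weights described above and not the authors' implementation. It assumes per-variable softmax cross-entropy. It also assumes that a "negative" relation is one whose ground-truth label is a hypothetical "no relation" class at index 0, because the summary does not say how negatives are encoded.<br />

```python
import math


def cross_entropy(logits, target):
    # Numerically stable -log softmax(logits)[target].
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_z - logits[target]


def sgp_loss(entity_logits, entity_targets, rel_logits, rel_targets,
             entity_weight=4.0, neg_rel_weight=0.1, no_relation=0):
    """Weighted sum of cross-entropies.

    Entities weigh 4x a positive relation, and negative relations (assumed to be
    label `no_relation`) weigh 10x less than positive ones.
    """
    loss = 0.0
    for logits, y in zip(entity_logits, entity_targets):
        loss += entity_weight * cross_entropy(logits, y)
    for logits, y in zip(rel_logits, rel_targets):
        weight = neg_rel_weight if y == no_relation else 1.0
        loss += weight * cross_entropy(logits, y)
    return loss
```

With uniform logits, one entity over 2 classes, one positive relation and one negative relation over 3 classes give a loss of <math>4\log 2 + 1.1\log 3</math>.<br />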
<br />
'''Evaluation''': there are three major tasks when inferring scene graphs. The authors focus on the following two:<br />
# '''SGCls''': given ground-truth entity bounding boxes, predict all entity and relation categories<br />
# '''PredCls''': given annotated bounding boxes with entity labels, predict all relations<br />
<br />
The evaluation metric Recall@K (R@K) is taken from (Lu et al., 2016). It is the fraction of ground-truth triplets that appear among the <math>K</math> most confident triplets predicted by the model. The graph-constrained protocol requires the top-<math>K</math> triplets to assign one consistent class per entity and relation. The unconstrained protocol does not enforce this constraint.<br />
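A simplified sketch shows how R@K and the two protocols differ. It makes several assumptions. Triplets are `(subject_id, predicate, object_id)` tuples matched exactly, whereas real evaluations also match boxes by IoU. Entity labels are treated as fixed, as in PredCls. The graph constraint is approximated by keeping only the highest-scoring predicate for each entity pair. The function name is hypothetical.<br />

```python
def recall_at_k(scored_triplets, ground_truth, k, graph_constrained=True):
    """scored_triplets: list of ((subj, pred, obj), score); ground_truth: set of triplets."""
    candidates = sorted(scored_triplets, key=lambda t: t[1], reverse=True)
    if graph_constrained:
        # Keep only the best-scoring predicate for each (subject, object) pair,
        # so every pair gets one consistent relation class.
        seen_pairs = set()
        filtered = []
        for (s, p, o), score in candidates:
            if (s, o) not in seen_pairs:
                seen_pairs.add((s, o))
                filtered.append(((s, p, o), score))
        candidates = filtered
    top_k = {triplet for triplet, _ in candidates[:k]}
    hits = sum(1 for t in ground_truth if t in top_k)
    return hits / len(ground_truth)
```

The constrained protocol is stricter. If the model's second guess for a pair is correct, that guess counts toward unconstrained recall but never toward graph-constrained recall.<br />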
<br />
'''Models and baselines''': the authors compared variants of the GPI approach against four baselines, all state-of-the-art models for scene-graph sub-tasks. For consistency, all models used the same training/testing split and the same preprocessing as in (Xu et al., 2017).<br />
<br />
'''Baselines from existing state-of-the-art models'''<br />
# (Lu et al., 2016): use of word embeddings to fine-tune the likelihood of predicted relations<br />
# (Xu et al., 2017): message passing algorithm between entities and relations to iteratively improve feature map for prediction<br />
# (Newell & Deng, 2017): Pixel2Graph, uses associative embeddings to produce a full graph from image<br />
# (Zellers et al., 2017): NeuralMotif method, which encodes global context to capture higher-order motifs in scene graphs; its Baseline variant outputs entity and relation distributions without using global context<br />
<br />
'''GPI models'''<br />
# '''GPI with no attention mechanism''': simply following Theorem 1's functional form, with summation over features<br />
# '''GPI NeighborAttention''': the same GPI model, but with attention over the neighbours' features<br />
# '''GPI Linguistic''': similar to NeighborAttention model, but concatenates word embedding vectors<br />
<br />
'''Key Results''': the GPI Linguistic approach outperforms all baselines on SGCls and performs similarly to the state-of-the-art NeuralMotif method on PredCls. The authors argue that PredCls is an easier task with less structure, which is why existing state-of-the-art models already perform well on it.<br />
<br />
[[File:GPI_table_results.png|700px|center]]<br />
<br />
=Conclusion=<br />
<br />
The paper presents a deep learning approach to structured prediction that constrains the architecture to be invariant to structurally identical inputs. The approach relies on pairwise features, which can describe inter-label correlations, and it inherits the intuitive aspect of score-based approaches. The output is invariant to equivalent representations of the pairwise terms.<br />
<br />
As future work, the axiomatic approach can be extended. For example, in image labeling, invariance to geometric transformations such as shifts or rotations may be desired, and in other cases invariance to feature permutations may be desired. Another interesting extension is to explore algorithms that discover symmetries for deep structured prediction when the invariant structure is unknown and must be learned from data.<br />
<br />
=Critique=<br />
The paper's contribution comes from the novelty of using permutation invariance as a design guideline for structured prediction. Much of the previous work did not consider it explicitly, but invariance in architectures had already been studied in Deep Sets (Zaheer et al., 2017). This paper characterizes the invariance condition and relaxes it relative to previous work. To evaluate the benefit of GPI models, the paper used a synthetic problem to show that the GPI model needs far fewer samples to converge to 100% accuracy. However, on the real task of scene graph prediction, the GPI variants achieved only marginally higher Recall@K scores than the state-of-the-art baselines. The true benefit of this paper's approach is that it avoids maximizing a score function, which is a computationally difficult problem, and instead directly produces output that is invariant to how the pairwise terms are represented.<br />
<br />
=References=<br />
<br />
Lu, Cewu, Krishna, Ranjay, Bernstein, Michael S., and Li, Fei-Fei. Visual relationship detection with language priors. In European Conf. Comput. Vision, pp. 852–869, 2016.<br />
<br />
Herzig, Roei, Raboh, Moshiko, Chechik, Gal, Berant, Jonathan, and Globerson, Amir. Mapping images to scene graphs with permutation-invariant structured prediction. 2018.<br />
<br />
Belanger, David, Yang, Bishan, and McCallum, Andrew. End-to-end learning for structured prediction energy networks. In Precup, Doina and Teh, Yee Whye (eds.), Proceedings of the 34th International Conference on Machine Learning, volume 70, pp. 429–439. PMLR, 2017.<br />
<br />
Chen, Liang Chieh, Schwing, Alexander G, Yuille, Alan L, and Urtasun, Raquel. Learning deep structured models. In Proc. ICML, 2015.<br />
<br />
Taskar, B., Guestrin, C., and Koller, D. Max margin Markov networks. In Thrun, S., Saul, L., and Schölkopf, B. (eds.), Advances in Neural Information Processing Systems 16, pp. 25–32. MIT Press, Cambridge, MA, 2004.<br />
<br />
Additional resources from Moshiko Raboh's [https://github.com/shikorab/SceneGraph GitHub]</div>Gsahuhttp://wiki.math.uwaterloo.ca/statwiki/index.php?title=Learning_to_Navigate_in_Cities_Without_a_Map&diff=41991Learning to Navigate in Cities Without a Map2018-11-30T00:58:31Z<p>Gsahu: /* Future Works */</p>
<hr />
<div>Paper: <br />
[https://arxiv.org/pdf/1804.00168.pdf Learning to Navigate in Cities Without a Map]<br />
A video of the paper is available [https://sites.google.com/view/streetlearn here].<br />
<br />
== Introduction ==<br />
Navigation is a topic of interest in many research disciplines and technology-related domains, such as neuroscience and robotics. Most navigation algorithms are based on the following two steps:<br />
<br />
1. Building an explicit map<br />
<br />
2. Planning and acting using that map. <br />
<br />
Humans can learn to navigate through cities without special tools such as maps or GPS. Motivated by this, the authors propose methods showing that a neural network agent can do the same using only visual observations. To do so, they design an interactive environment based on Google StreetView images, together with a dual pathway agent architecture. As shown in Figure 1, parts of the environment are built from Google StreetView images of New York City (Times Square, Central Park) and London (St. Paul’s Cathedral). The green cone represents the agent’s location and orientation. Deep reinforcement learning (RL) has been used successfully to learn navigation from visual input in some domains, such as games and simulated environments. However, it suffers from data inefficiency and sensitivity to changes in the environment. It is therefore unclear whether this approach can be used for large-scale navigation, and this paper investigates that question.<br />
[[File:figure1-soroush.png|600px|thumb|center|Figure 1. Our environment is built of real-world places from StreetView. The figure shows diverse views and corresponding local maps (neither the map nor the current position is available to the agent) in New York City (Times Square, Central Park) and London (St. Paul’s Cathedral). The green cone represents the agent’s location and orientation.]]<br />
<br />
==Contribution==<br />
This paper has made the following contributions:<br />
<br />
1. Designing a dual pathway agent architecture. This agent is trained end-to-end with reinforcement learning and can navigate through a real city.<br />
<br />
2. Using goal-dependent learning. This means that the policy and value functions must adapt to a sequence of goals that are provided as input.<br />
<br />
3. Leveraging a recurrent neural architecture that makes navigation through a city possible and also scales to navigation in new cities. The architecture supports both locale-specific learning and general, transferable navigation behaviour. The authors achieve this by separating out a recurrent neural pathway that receives and interprets the current goal, and that encapsulates and memorizes the features of a single region.<br />
<br />
4. Using a new environment built on top of Google StreetView images, which provides real-world images as the agent’s observations. In this environment, the agent navigates from an arbitrary starting point to a goal, then to another goal, and so on. London, Paris, and New York City are used for navigation.<br />
<br />
The authors demonstrate that their proposed method provides a mechanism for transferring knowledge to new cities. As with humans, when the agent visits a new city, it is expected to learn a new set of landmarks. It is not expected to re-learn its visual representations or its behaviours (e.g., zooming forward along streets or turning at intersections). Therefore, using the MultiCity architecture, the paper first trains on a number of cities. It then freezes both the policy network and the visual convolutional network, and trains only a new locale-specific pathway on the new city. This approach lets the agent acquire new knowledge without forgetting what it has already learned, similarly to the progressive neural networks architecture.<br />
<br />
==Related Work==<br />
<br />
1. Localization from real-world imagery. For example, in (Weyand et al., 2016), a CNN achieved excellent results on a geolocation task. This paper differs by not using supervised training with ground-truth labels and by including planning as a goal. Other works improve localization by exploiting spatiotemporal continuity, or by estimating camera pose or depth from pixels. These methods rely on supervised training with ground-truth labels, which is not possible in every environment.<br />
<br />
2. Deep RL methods for navigation. For instance, (Mirowski et al., 2016; Jaderberg et al., 2016) used self-supervised auxiliary tasks to learn visual navigation in several synthetic mazes. Other research used text descriptions to incorporate goal instructions. Researchers have also developed higher-fidelity simulated environments to make experiments more realistic, but these still lack diversity. In contrast to many related papers in this area, this paper uses real-world data, which is diverse and visually realistic. However, the data contains no dynamic elements, and the street topology cannot be regenerated or altered.<br />
<br />
3. Deep RL for path planning and mapping. For example, (Zhang et al., 2017) created an RL agent with external memory that represents a global map. Other work uses a hierarchical control strategy with a structured memory, or Memory Augmented Control Maps. An explicit neural mapper and a navigation planner have also been trained jointly. Among these works, target-driven visual navigation with a goal-conditional policy is the most closely related to the approach of this paper.<br />
<br />
4. To make simulations resemble reality, researchers have developed higher-fidelity simulated environments (Dosovitskiy et al., 2017; Kolve et al., 2017; Shah et al., 2018; Wu et al., 2018). However, in spite of the photo-realism, the inherent problems of simulated environments pertain to the limited diversity of the environments and the idealistic cleanliness of the observations.<br />
<br />
==Environment==<br />
Google StreetView provides both high-resolution 360-degree imagery and graph connectivity, as well as a public API, which makes it a valuable resource. This work uses large areas of New York, Paris, and London that contain between 7,000 and 65,500 nodes (and between 7,200 and 128,600 edges, respectively). These areas have a mean node spacing of 10 m and cover a range of up to 5 km (Figure 2), and the underlying connections are not simplified. As a result, many areas are 'congested' with nodes, occlusions, available footpaths, etc. The agent only sees the RGB images visible in StreetView (Figure 1) and is not aware of the underlying graph.<br />
<br />
[[File:figure2-soroush.png|700px|thumb|center|Figure 2. Map of the 5 environments in New York City; our experiments focus on the NYU area as well as on transfer learning from the other areas to Wall Street (see Section 5.3). In the zoomed in area, each green dot corresponds to a unique panorama, the goal is marked in blue, and landmark locations are marked with red pins.]]<br />
<br />
==Agent Interface and the Courier Task==<br />
An RL environment must define observations and actions in addition to tasks. The inputs to the agent are the image <math>x_t</math> and the goal <math>g_t</math>. A first-person view of the 3D environment is simulated by cropping <math>x_t</math> to a 60-degree square RGB image that is scaled to 84×84 pixels. The action space consists of 5 movements: “slow” rotate left or right (±22.5°), “fast” rotate left or right (±67.5°), or move forward (implemented as a ''noop'' when this is not a viable action). If there are multiple edges in the agent's viewing cone, the most central one is chosen.<br />
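The action semantics can be illustrated with a toy step function. This is a hypothetical, simplified sketch and not the StreetLearn environment's API. It makes four assumptions. Headings are compass bearings in degrees, with "left" decreasing the bearing. The viewing cone has a half-width of 30° to match the 60-degree view. Moving forward keeps the current heading. The rotation magnitudes and the "most central edge" and no-op rules come from the text.<br />

```python
# Hypothetical action names; rotation magnitudes follow the text (degrees).
ACTIONS = {
    "slow_left": -22.5,
    "slow_right": 22.5,
    "fast_left": -67.5,
    "fast_right": 67.5,
    "forward": None,
}


def angle_diff(a, b):
    # Smallest signed difference a - b in degrees, in [-180, 180).
    return (a - b + 180.0) % 360.0 - 180.0


def step(heading, neighbours, action, cone_half_width=30.0):
    """Apply one action (toy model).

    heading: agent bearing in degrees; neighbours: dict node_id -> bearing of the
    edge leading to that node. Returns (new_heading, next_node_or_None).
    """
    delta = ACTIONS[action]
    if delta is not None:
        return (heading + delta) % 360.0, None
    # Forward: among edges inside the viewing cone, take the most central one.
    offsets = {node: abs(angle_diff(bearing, heading)) for node, bearing in neighbours.items()}
    in_cone = {node: off for node, off in offsets.items() if off <= cone_half_width}
    if not in_cone:
        return heading, None  # no viable edge, so "forward" acts as a no-op
    return heading, min(in_cone, key=in_cone.get)
```

For example, facing 10° with edges at 0°, 25° and 90°, moving forward follows the edge at 0° because it is the most central edge inside the cone.<br />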
<br />
There are many ways to specify the goal to the agent. In this paper, the current goal is represented by its proximity to a set <math>L = \{(\mathrm{Lat}_k, \mathrm{Long}_k)\}_k</math> of fixed landmarks, specified by latitude and longitude. Let <math>\{d^g_{t,k}\}_k</math> be the distances from the goal to each landmark <math>k</math>. The <math>i</math>-th component of the goal vector is then <math>g_{t,i} = \frac{\exp(-\alpha d^g_{t,i})}{\sum_k \exp(-\alpha d^g_{t,k})}</math>, with <math>\alpha = 0.002</math> (Figure 3).<br />
<br />
[[File:figure3-soroush.PNG|400px|thumb|center|Figure 3. We illustrate the goal description by showing a goal and a set of 5 landmarks that are nearby, plus 4 that are more distant. The code <math>g_i</math> is a vector with a softmax-normalised distance to each landmark.]]<br />
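The goal code is a softmax over negative scaled distances, which can be computed directly. The sketch below assumes the distances are great-circle distances in metres computed with the haversine formula and a mean Earth radius of 6,371 km. The summary does not say how the distances are measured, and the function names are hypothetical. Subtracting the smallest distance before exponentiating does not change the softmax, but it avoids underflow for distant landmarks.<br />

```python
import math


def haversine_m(lat1, lon1, lat2, lon2, radius=6_371_000.0):
    # Great-circle distance in metres between two (latitude, longitude) points.
    phi1, phi2 = math.radians(lat1), math.radians(lat2)
    dphi = phi2 - phi1
    dlmb = math.radians(lon2 - lon1)
    a = math.sin(dphi / 2) ** 2 + math.cos(phi1) * math.cos(phi2) * math.sin(dlmb / 2) ** 2
    return 2 * radius * math.asin(math.sqrt(a))


def goal_vector(distances_m, alpha=0.002):
    """g_i = exp(-alpha * d_i) / sum_k exp(-alpha * d_k), computed stably."""
    d_min = min(distances_m)
    weights = [math.exp(-alpha * (d - d_min)) for d in distances_m]
    total = sum(weights)
    return [w / total for w in weights]


def goal_from_coordinates(goal, landmarks, alpha=0.002):
    # goal: (lat, lon); landmarks: list of (lat, lon) pairs.
    return goal_vector([haversine_m(*goal, *lm) for lm in landmarks], alpha)
```

With <math>\alpha = 0.002</math>, a landmark 500 m closer to the goal than another receives <math>e^{1} \approx 2.7</math> times more weight. The code is therefore dominated by landmarks within a few kilometres of the goal.<br />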
<br />
This form of representation has several advantages: <br />
<br />
1. It could easily be extended to new environments.<br />
<br />
2. It is intuitive. Even humans and animals use landmarks to be able to move from one place to another.<br />
<br />
3. It does not rely on arbitrary map coordinates, and provides an absolute (as opposed to relative) goal.<br />
<br />
In this work, 644 landmarks are manually defined across New York, Paris, and London. The courier task is the problem of navigating to a series of random locations within a city. Each episode lasts 1000 agent steps, and the agent starts from a random place with a random orientation. When the agent gets within 100 meters of the current goal, the next goal is chosen at random. The reward is proportional to the shortest-path distance between the agent and the goal at the moment the goal is first assigned, so longer journeys give more reward. To solve the courier task, the agent therefore needs to learn the mapping between the images observed at the goal location and the goal vector. It must also learn the association between the images observed at its current location and the policy for reaching the goal destination.<br />
<br />
==Methods==<br />
<br />
===Goal-dependent Actor-Critic Reinforcement Learning===<br />
In this paper, the learning problem is formulated as a Markov Decision Process (MDP) with state space <math>\mathcal{S}</math>, action space <math>\mathcal{A}</math>, environment <math>\mathcal{E}</math>, and a set of possible goals <math>\mathcal{G}</math>. The reward function depends on the current goal and state: <math>\mathcal{R}: \mathcal{S} \times \mathcal{G} \times \mathcal{A} \to \mathbb{R}</math>. As usual in reinforcement learning, the aim is to find the policy that maximizes the expected return, defined as the sum of discounted rewards starting from state <math>s_0</math> with discount <math>\gamma</math>. Here, the expected return from a state <math>s_t</math> also depends on the goals that are sampled. The policy is defined as a distribution over actions, given the current state <math>s_t</math> and the goal <math>g_t</math>: <br />
<br />
\begin{align}<br />
\pi(\alpha \mid s,g)=\Pr(\alpha_t=\alpha \mid s_t=s, g_t=g)<br />
\end{align}<br />
<br />
The value function is the expected return obtained by sampling actions from policy <math>\pi</math>, starting from state <math>s_t</math> with goal <math>g_t</math>:<br />
<br />
\begin{align}<br />
V^{\pi}(s,g)=\mathbb{E}[R_t]=\mathbb{E}\left[\sum_{k=0}^{\infty}\gamma^k r_{t+k} \mid s_t=s, g_t=g\right]<br />
\end{align}<br />
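In practice, the infinite sum is truncated at the end of an episode, which is 1000 steps in the courier task. The sketch below computes the discounted return <math>R_t</math> for every step of a finite episode in one backward pass, using <math>R_t = r_t + \gamma R_{t+1}</math>. The default <math>\gamma = 0.99</math> is illustrative and not a value taken from the paper. Averaging these returns over episodes that share the same start state and goal gives a Monte Carlo estimate of <math>V^{\pi}(s,g)</math>.<br />

```python
def discounted_returns(rewards, gamma=0.99):
    """R_t = sum_k gamma^k r_{t+k} for every t of a finite episode (backward pass)."""
    returns = [0.0] * len(rewards)
    running = 0.0
    for t in range(len(rewards) - 1, -1, -1):
        # Recursion R_t = r_t + gamma * R_{t+1}, with R_T = 0 after the last step.
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns
```

For example, `discounted_returns([1, 0, 2], gamma=0.5)` gives `[1.5, 1.0, 2.0]`.<br />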
<br />
The authors also design an architecture with multiple pathways to support the two types of learning this problem requires. First, the agent needs a general internal representation that supports scene understanding. Second, the agent needs to remember features that are unique to each place, which help it organize and recall scenes.<br />
<br />
===Architectures===<br />
<br />
[[File:figure4-soroush.png|400px|thumb|center|Figure 4. Comparison of architectures. Left: GoalNav is a convolutional encoder plus policy LSTM with goal description input. Middle: CityNav is a single-city navigation architecture with a separate goal LSTM and optional auxiliary heading (θ). Right: MultiCityNav is a multi-city architecture with individual goal LSTM pathways for each city.]]<br />
<br />
The authors use neural networks to parameterize the policy and value functions, and these networks share weights in all layers except the final linear layer. The agent takes image pixels as input and passes them through a convolutional network. The output of the convolutional network is fed to a Long Short-Term Memory (LSTM) network, together with the past reward <math>r_{t-1}</math> and the previous action <math>\alpha_{t-1}</math>.<br />
<br />
Three different architectures are described below.<br />
<br />
The '''GoalNav''' architecture (Fig. 4a) consists of a convolutional encoder and a policy LSTM. The policy LSTM receives the goal description <math>g_t</math>, the previous action, and the reward as inputs.<br />
<br />
The '''CityNav''' architecture (Fig. 4b) extends GoalNav with an additional LSTM, called the goal LSTM, whose inputs are the visual features and the goal description. The CityNav agent also adds an auxiliary heading (θ) prediction task, where θ is the angle between north and the agent’s heading. This auxiliary task can speed up learning and provides relevant information.<br />
<br />
The '''MultiCityNav''' architecture (Fig. 4c) extends CityNav to learning in different cities. It uses one goal LSTM per city, connected in parallel, to encapsulate locale-specific features. After training on a number of cities, the convolutional encoder and the policy LSTM become general, so only a new goal LSTM needs to be trained for each new city.<br />
<br />
In this paper, the authors use IMPALA [1] to train the agents because IMPALA can get similar performance to A3C [2].<br />
<br />
===Prior on agent training: IMPALA and A3C===<br />
<br />
IMPALA (Importance Weighted Actor-Learner Architecture) is an actor-critic implementation of deep reinforcement learning that decouples acting from learning. IMPALA achieves performance comparable to A3C (Asynchronous Advantage Actor-Critic, DeepMind's earlier algorithm) on a single-city task, but it has been shown to handle multi-task learning better than A3C. The authors use 256 actors for CityNav and 512 actors for MultiCityNav, with batch sizes of 256 and 512 respectively, and sequences are unrolled to length 50.<br />
<br />
===Curriculum Learning===<br />
In curriculum learning, the model is first trained on simple examples. Once it has learned those, it is given more complex and difficult examples. In this paper, this approach is used to teach the agent to navigate to increasingly distant destinations. The courier task suffers from sparse rewards, a common problem in RL tasks (similar to Montezuma’s Revenge). To overcome this, the authors define a natural curriculum. In phase 1, each new goal is sampled within 500 m of the agent’s position. In phase 2, the maximum range is gradually increased to cover the full graph (3.5 km in the smaller New York areas, or 5 km for central London or Downtown Manhattan).<br />
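The two-phase goal curriculum can be sketched as a schedule for the sampling radius plus a goal sampler. Two parts of this sketch are assumptions. The linear ramp and the phase lengths `phase1_steps` and `phase2_steps` are hypothetical, because the text only says the range is "gradually increased". The `distance_fn` argument stands in for whatever distance the environment uses. The 500 m start and the 5 km full range come from the text.<br />

```python
import random


def max_goal_distance(step, phase1_steps, phase2_steps, start_m=500.0, full_m=5000.0):
    """Curriculum radius: fixed at start_m in phase 1, then a linear ramp (assumed) to full_m."""
    if step < phase1_steps:
        return start_m
    progress = min(1.0, (step - phase1_steps) / phase2_steps)
    return start_m + progress * (full_m - start_m)


def sample_goal(agent_node, nodes, distance_fn, radius, rng=random):
    """Pick a goal uniformly among nodes within `radius` metres of the agent."""
    candidates = [n for n in nodes
                  if n != agent_node and distance_fn(agent_node, n) <= radius]
    return rng.choice(candidates) if candidates else None
```

Halfway through phase 2, this schedule samples goals up to 2,750 m away. Early in training, goals stay close enough that the agent encounters rewards often.<br />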
<br />
Curriculum learning was formalized by Bengio et al. (2009), who describe it as a continuation method for non-convex optimization. It can speed up training by presenting easier, less noisy examples before harder ones. One example of curriculum learning outside this paper is outlined below:<br />
<br />
1. We aim to classify shapes into three classes: triangles, ellipses, and rectangles. We can create a curriculum by starting with a simplified dataset that contains only special cases of these classes: equilateral triangles, circles, and squares. Training on these special cases first, and only then introducing the full dataset, lets the algorithm converge more quickly towards a local minimum before it sees "harder" examples. Feeding only these specialized examples also places the classes on more distinct regions of the input manifold. With less overlap, the network performs better when noise is added later.<br />
<br />
==Results==<br />
In this section, the performance of the proposed architectures on the courier task is shown.<br />
<br />
[[File:figure5-2.png|600px|thumb|center|Figure 5. Average per-episode goal rewards (y-axis) are plotted vs. learning steps (x-axis) for the courier task in the NYU (New York City) environment (top), and in central London (bottom). We compare the GoalNav agent, the CityNav agent, and the CityNav agent without skip connection on the NYU environment, and the CityNav agent in London. We also compare the Oracle performance and a Heuristic agent, described below. The London agents were trained with a 2-phase curriculum; we indicate the end of phase 1 (500m only) and the end of phase 2 (500m to 5000m). Results on the Rive Gauche part of Paris (trained in the same way as in London) are comparable and the agent achieved mean goal reward 426.]]<br />
<br />
It is first shown that the CityNav agent, trained with curriculum learning, succeeds in learning the courier task in New York, London and Paris. Figure 5 compares the following agents:<br />
<br />
1. Goal Navigation agent (GoalNav).<br />
<br />
2. City Navigation agent (CityNav).<br />
<br />
3. A City Navigation agent without the skip connection from the vision layers to the policy LSTM. This variant is included because removing the skip connection regularises the interface between the goal LSTM and the policy LSTM in the multi-city transfer scenario.<br />
<br />
The authors also consider a lower bound (Heuristic) and an upper bound (Oracle) on performance. As described in the paper: "Heuristic is a random walk on the street graph, where the agent turns in a random direction if it cannot move forward; if at an intersection it will turn with a probability <math>P=0.95</math>. Oracle uses the full graph to compute the optimal path using breadth-first search.". As Figure 5 shows, the CityNav architecture attains higher performance and is more stable than the simpler GoalNav agent.<br />
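The Oracle relies on breadth-first search over the street graph. The sketch below is a minimal BFS and not the authors' code. It assumes the graph is an adjacency dict and that all edges count equally, which is reasonable given the roughly uniform 10 m node spacing. It ignores the turning actions that a real agent would also need to spend steps on.<br />

```python
from collections import deque


def bfs_shortest_path(graph, start, goal):
    """Shortest path (fewest edges) in an unweighted street graph.

    graph: dict node -> list of neighbour nodes. Returns the list of nodes from
    start to goal, or None if the goal is unreachable.
    """
    parents = {start: None}
    queue = deque([start])
    while queue:
        node = queue.popleft()
        if node == goal:
            # Walk back through the parent pointers to rebuild the path.
            path = []
            while node is not None:
                path.append(node)
                node = parents[node]
            return path[::-1]
        for nxt in graph.get(node, []):
            if nxt not in parents:
                parents[nxt] = node
                queue.append(nxt)
    return None
```

Because BFS expands nodes in order of edge count from the start, the first time it reaches the goal it has found a path with the fewest edges.<br />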
<br />
Figure 6 shows the trajectories of the trained agent over two 1000-step episodes, as well as the agent's value function while it navigates to a destination.<br />
<br />
[[File:figure6-soroush.png|400px|thumb|center|Figure 6. Trained CityNav agent’s performance in two environments: Central London (left panes), and NYU (right panes). Top: examples of the agent’s trajectory during one 1000-step episode, showing successful consecutive goal acquisitions. The arrows show the direction of travel of the agent. Bottom: We visualize the value function of the agent during 100 trajectories with random starting points and the same goal (respectively St Paul’s Cathedral and Washington Square). Thicker and warmer color lines correspond to higher value functions.]]<br />
<br />
Figure 7 shows that the agent successfully learns a navigation policy for reaching St Paul’s Cathedral in London and Washington Square in New York.<br />
[[File:figure7-soroush.png|400px|thumb|center|Figure 7. Number of steps required for the CityNav agent to reach a goal (Washington Square in New York or St Paul’s Cathedral in London) from 100 start locations vs. the straight-line distance to the goal in meters. One agent step corresponds to a forward movement of about 10m or a left/right turn by 22.5 or 67.5 degrees.]]<br />
<br />
To investigate the generalisation capability of a trained agent, the authors mask 25% of the possible goals and train on the remaining ones. Figure 8 shows that the agent is still able to traverse these areas; it simply never samples a goal there.<br />
[[File:fff8.png|600px|center]]<br />
<br />
A critical test in this paper is transferring the model to new cities by learning a new set of landmarks, without re-learning visual representations, behaviours, etc. To this end, the MultiCityNav agent is first trained on a number of cities; then both the policy LSTM and the convolutional encoder are frozen and a new locale-specific goal LSTM is trained on the target city. Performance is compared across three training regimes, illustrated in Fig. 9: training on the target city only (single training); training on multiple cities, including the target city, together (joint training); and joint training on all but the target city, followed by training on the target city with the rest of the architecture frozen (pre-train and transfer). Figure 10 shows that transferring to other cities is possible, and that pre-training on more cities makes transfer more effective. According to the paper: "Remarkably, the agent that is pre-trained on 4 regions and then transferred to Wall Street achieves comparable performance to an agent trained jointly on all the regions, and only slightly worse than single-city training on Wall Street alone". The skip connection from the convolutional encoder to the policy LSTM helps in single-city training, but not in multi-city transfer.<br />
[[File:figure9-soroush.png|400px|thumb|center|Figure 9. Illustration of training regimes: (a) training on a single city (equivalent to CityNav); (b) joint training over multiple cities with a dedicated per-city pathway and shared convolutional net and policy LSTM; (c) joint pre-training on a number of cities followed by training on a target city with convolutional net and policy LSTM frozen (only the target city pathway is optimized).]]<br />
[[File:figure10-soroush.png|400px|thumb|center|Figure 10. Joint multi-city training and transfer learning performance of variants of the MultiCityNav agent evaluated only on the target city (Wall Street). We compare single-city training on the target environment alone vs. joint training on multiple cities (3, 4, or 5-way joint training including Wall Street), vs. pre-training on multiple cities and then transferring to Wall Street while freezing the entire agent except for the new pathway (see Fig. 9). One variant has skip connections between the convolutional encoder and the policy LSTM, the other does not (no-skip).]]<br />
<br />
The paper also investigates giving early rewards before the agent reaches the goal, and adding random rewards (coins) to encourage exploration. Figure 11a suggests that coins by themselves are ineffective, as the task does not benefit from wide exploration. Based on these results, the authors start by sampling goals within a radius of 500m of the agent’s location, then progressively extend this radius to the maximum distance an agent could travel within the environment. In addition, to assess the importance of goal conditioning, a Goal-less CityNav agent is trained by removing the input <math>g_t</math>; its poor performance is clear in Figure 11b. Furthermore, as Figure 11b suggests, reducing the density of the landmarks by ratios of 50%, 25%, and 12.5% does not reduce performance much. Finally, alternative goal representations are investigated:<br />
<br />
a) Latitude and longitude scalar coordinates normalized to be between 0 and 1.<br />
<br />
b) Binned representation. <br />
<br />
The latitude and longitude scalar goal representations perform best. However, since the landmark-based representation also performs well while remaining independent of the coordinate system, the authors use it as the canonical representation.<br />
<br />
[[File:figure11-soroush.PNG|300px|thumb|center|Figure 11. Top: Learning curves of the CityNav agent on NYU, comparing reward shaping with different radii of early rewards (ER) vs. ER with random coins vs. curriculum learning with ER 200m and no coins (ER 200m, Curr.). Bottom: Learning curves for CityNav agents with different goal representations: landmark-based, as well as latitude and longitude classification-based and regression-based.]]<br />
<br />
==Conclusion==<br />
This paper presents a deep reinforcement learning approach to navigation in cities, built on Google StreetView for its photographic content and worldwide coverage. The authors also introduce the courier task and a multi-city neural network agent architecture that transfers to new cities. The resulting navigation architecture succeeds by integrating general policies with locale-specific knowledge.<br />
<br />
==Future Works==<br />
The paper uses static Google Street View images, which contain more information than just the route. Although this is not the central focus of the paper, incorporating such information could be very useful for more effective route building or planning.<br />
<br />
==Critique==<br />
1. It is not clear how this model applies to the real world. A real-world navigation system needs to detect objects, people, and cars, but it is not clear whether the authors model them. As far as I can tell, collisions are not considered at all, which undermines the claim that this is a real-world problem.<br />
<br />
2. This paper uses only static Google Street View images as its primary source of data. For a realistic model of navigation, the authors should at least complement them with dynamic data such as traffic and road-closure information. Also, while it is understandable not to use maps, it is not clear why the agent does not use GPS to know its position, and perhaps even build a map. This could be useful in an emergency, or for exploring places that are unknown or inaccessible. The resulting map could easily be compared with the real one and could also be used in training to achieve higher performance. Availability should not be a serious problem: if a real city is being simulated and the Google images are available, why should GPS not be? What is the intuition? At least, a complementary explanation of this choice would be helpful.<br />
<br />
3. The 'Transfer in Multi-City Experiments' results could be strengthened significantly via cross-validation (only Wall Street, which covers the smallest area of the four regions, is used as the test case). Additionally, the results do not show true 'multi-city' transfer learning, since all regions are within New York City. It is stated in the paper that not having to re-learn visual representations when transferring between cities is one of the outcomes, but the tests do not actually check for this. There are likely significant differences in the features that would be learned in NYC vs. Waterloo, for example, and this type of transfer has not been evaluated.<br />
<br />
4. The proposed navigation model could be limited by its reliance on pre-defined landmarks, which appear to be strategically placed so that they are spread evenly across each city. This could limit the agent's deployability to new cities.<br />
<br />
==Reference==<br />
[1] Espeholt, Lasse, Soyer, Hubert, Munos, Remi, Simonyan, Karen, Mnih, Volodymyr, Ward, Tom, Doron, Yotam, Firoiu, Vlad, Harley, Tim, Dunning, Iain, Legg, Shane, and Kavukcuoglu, Koray. IMPALA: Scalable distributed deep-RL with importance weighted actor-learner architectures. arXiv preprint arXiv:1802.01561, 2018.<br />
<br />
[2] Mnih, Volodymyr, Badia, Adria Puigdomenech, Mirza, Mehdi, Graves, Alex, Lillicrap, Timothy, Harley, Tim, Silver, David, and Kavukcuoglu, Koray. Asynchronous methods for deep reinforcement learning. In International Conference on Machine Learning, pp. 1928–1937, 2016.</div>Gsahuhttp://wiki.math.uwaterloo.ca/statwiki/index.php?title=Learning_to_Navigate_in_Cities_Without_a_Map&diff=41990Learning to Navigate in Cities Without a Map2018-11-30T00:57:42Z<p>Gsahu: /* Conclusion */</p>
<hr />
<div>Paper: <br />
[https://arxiv.org/pdf/1804.00168.pdf Learning to Navigate in Cities Without a Map]<br />
A video of the paper is available [https://sites.google.com/view/streetlearn here].<br />
<br />
== Introduction ==<br />
Navigation is an attractive topic in many research disciplines and technology-related domains such as neuroscience and robotics. Most navigation algorithms follow two steps:<br />
<br />
1. Building an explicit map<br />
<br />
2. Planning and acting using that map. <br />
<br />
In this article, motivated by the fact that humans can learn to navigate through cities without special tools such as maps or GPS, the authors propose methods showing that a neural network agent can do the same using visual observations. To do so, they design an interactive environment built from Google StreetView images, together with a dual pathway agent architecture. As shown in Figure 1, parts of the environment are built using Google StreetView images of New York City (Times Square, Central Park) and London (St. Paul’s Cathedral); the green cone represents the agent’s location and orientation. Learning to navigate from visual input with deep reinforcement learning (RL) has been successful in domains such as games and simulated environments, but it suffers from data inefficiency and sensitivity to changes in the environment. It is therefore unclear whether this approach can be used for large-scale navigation, which is the question this paper investigates.<br />
[[File:figure1-soroush.png|600px|thumb|center|Figure 1. Our environment is built of real-world places from StreetView. The figure shows diverse views and corresponding local maps (neither the map nor the current position is used by the agent) in New York City (Times Square, Central Park) and London (St. Paul’s Cathedral). The green cone represents the agent’s location and orientation.]]<br />
<br />
==Contribution==<br />
This paper has made the following contributions:<br />
<br />
1. Designing a dual pathway agent architecture. This agent can navigate through a real city and is trained with end-to-end reinforcement learning to handle real-world navigation.<br />
<br />
2. Using goal-dependent learning. This means that the policy and value functions must adapt to a sequence of goals that are provided as input.<br />
<br />
3. Leveraging a recurrent neural architecture that not only enables navigation through a city but also scales to navigation in new cities. This architecture supports both locale-specific learning and general, transferable navigation behaviour. The authors achieve this by separating out a recurrent neural pathway that receives and interprets the current goal and encapsulates and memorizes the features of a single region.<br />
<br />
4. Using a new environment built on top of Google StreetView images, which provides real-world images for the agent’s observations. In this environment, the agent can navigate from an arbitrary starting point to a goal, then to another goal, and so on. London, Paris, and New York City are chosen for navigation.<br />
<br />
The authors demonstrate that their proposed method provides a mechanism for transferring knowledge to new cities. As with humans, when the agent visits a new city, it is expected to learn a new set of landmarks, but not to re-learn its visual representations or its behaviours (e.g., zooming forward along streets or turning at intersections). Therefore, using the MultiCity architecture, the paper first trains on a number of cities, then freezes both the policy network and the visual convolutional network and trains only a new locale-specific pathway on a new city. This approach enables the agent to acquire new knowledge without forgetting what it has already learned, similar to the progressive neural networks architecture.<br />
<br />
==Related Work==<br />
<br />
1. Localization from real-world imagery. For example, in Weyand et al. (2016), a CNN achieved excellent results on a geolocation task. This paper departs from that work by not using supervised training with ground-truth labels, and by including planning as a goal. Other works improve localization by exploiting spatiotemporal continuity, or by estimating camera pose or depth from pixels. These methods rely on supervised training with ground-truth labels, which is not possible in every environment. <br />
<br />
2. Deep RL methods for navigation. For instance, Mirowski et al. (2016) and Jaderberg et al. (2016) used self-supervised auxiliary tasks to learn visual navigation in several synthetic mazes. Other work used text descriptions to provide goal instructions. Researchers also developed higher-fidelity simulated environments to make experiments more realistic, but these still lack diversity. In contrast to many related papers in this area, this paper uses real-world data. It is diverse and visually realistic, but it does not contain dynamic elements, and the street topology cannot be regenerated or altered.<br />
<br />
3. Deep RL for path planning and mapping. For example, Zhang et al. (2017) created an RL agent that represents a global map via external memory; other work uses a hierarchical control strategy to propose a structured memory and Memory Augmented Control Maps. An explicit neural mapper and navigation planner trained jointly has also been used. Among all these works, target-driven visual navigation with a goal-conditional policy is the approach most related to the method in this paper.<br />
<br />
4. To make simulations resemble reality, researchers have developed higher-fidelity simulated environments (Dosovitskiy et al., 2017; Kolve et al., 2017; Shah et al., 2018; Wu et al., 2018). However, in spite of the photo-realism, the inherent problems of simulated environments pertain to the limited diversity of the environments and the idealistic cleanliness of the observations.<br />
<br />
==Environment==<br />
Google StreetView provides both high-resolution 360-degree imagery and graph connectivity, as well as a public API, which makes it a valuable resource. In this work, the authors choose large areas of New York, Paris, and London that contain between 7,000 and 65,500 nodes (and between 7,200 and 128,600 edges, respectively), with a mean node spacing of 10m and covering a range of up to 5km (Figure 2), without simplifying the underlying connections. This means that many areas are 'congested' with nodes, occlusions, available footpaths, etc. The agent only sees the RGB images visible in StreetView (Figure 1) and is not aware of the underlying graph.<br />
<br />
[[File:figure2-soroush.png|700px|thumb|center|Figure 2. Map of the 5 environments in New York City; our experiments focus on the NYU area as well as on transfer learning from the other areas to Wall Street (see Section 5.3). In the zoomed in area, each green dot corresponds to a unique panorama, the goal is marked in blue, and landmark locations are marked with red pins.]]<br />
<br />
==Agent Interface and the Courier Task==<br />
In an RL environment, observations and actions must be defined in addition to tasks. The inputs to the agent are the image <math>x_t</math> and the goal <math>g_t</math>. A first-person view of the 3D environment is simulated by cropping the panorama at the agent's location to a 60-degree square RGB image that is scaled to 84×84 pixels. The action space consists of 5 movements: “slow” rotate left or right (±22.5°), “fast” rotate left or right (±67.5°), or move forward (implemented as a ''noop'' where this is not a viable action). If there are multiple edges in the agent’s viewing cone, the most central one is chosen.<br />
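The action rules above can be sketched in Python. This is a minimal illustration, not the StreetLearn environment's code: it assumes each panorama node stores the compass bearing (in degrees) of each outgoing edge, treats the viewing cone as the 60-degree field of view, and keeps the heading unchanged after a forward move. The names `ACTIONS`, `angle_diff`, and `step_action` are hypothetical.<br />

```python
# Hypothetical action table: heading change in degrees, or None for "move forward".
ACTIONS = {
    "slow_left": -22.5,
    "slow_right": 22.5,
    "fast_left": -67.5,
    "fast_right": 67.5,
    "forward": None,
}


def angle_diff(a, b):
    """Smallest signed difference a - b between two bearings, in [-180, 180)."""
    return (a - b + 180.0) % 360.0 - 180.0


def step_action(node, heading, action, edges, cone_deg=60.0):
    """Apply one action and return the new (node, heading).

    edges maps node -> list of (neighbor, bearing_deg). A forward move follows
    the edge whose bearing is most central in the viewing cone; if no edge lies
    inside the cone, forward is a no-op.
    """
    turn = ACTIONS[action]
    if turn is not None:
        return node, (heading + turn) % 360.0
    in_cone = [
        (abs(angle_diff(bearing, heading)), nbr)
        for nbr, bearing in edges.get(node, [])
        if abs(angle_diff(bearing, heading)) <= cone_deg / 2
    ]
    if not in_cone:
        return node, heading  # no-op: moving forward is not viable here
    _, best = min(in_cone, key=lambda pair: pair[0])
    return best, heading
```

Choosing the edge with the smallest angular offset is one simple reading of "most central edge"; the real environment may resolve ties or reorient the agent differently.<br />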
<br />
There are many ways to specify the goal to the agent. In this paper, the current goal is represented in terms of its proximity to a set <math>L</math> of fixed landmarks <math>L=\{(\mathrm{Lat}_k, \mathrm{Long}_k)\}_k</math>, specified by their latitude and longitude. Given the distances <math>\{d^g_{t,k}\}_k</math> from the goal to each landmark <math>k</math>, the goal vector has entries <math>g_{t,i}=\frac{\exp(-\alpha d^g_{t,i})}{\sum_k \exp(-\alpha d^g_{t,k})}</math> for the <math>i</math>-th landmark, with <math>\alpha=0.002</math> (Figure 3).<br />
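A small sketch of the landmark-based goal code. The softmax formula and <math>\alpha=0.002</math> come from the text; measuring <math>d^g_{t,k}</math> as a great-circle (haversine) distance in metres on a spherical Earth is an assumption, and the helper names are hypothetical.<br />

```python
import math

EARTH_RADIUS_M = 6_371_000.0  # mean Earth radius; spherical-Earth assumption


def haversine_m(lat1, lon1, lat2, lon2):
    """Great-circle distance in metres between two (lat, lon) points given in degrees."""
    p1, p2 = math.radians(lat1), math.radians(lat2)
    dphi = p2 - p1
    dlmb = math.radians(lon2 - lon1)
    a = math.sin(dphi / 2) ** 2 + math.cos(p1) * math.cos(p2) * math.sin(dlmb / 2) ** 2
    return 2 * EARTH_RADIUS_M * math.asin(math.sqrt(a))


def goal_vector(goal, landmarks, alpha=0.002):
    """Softmax of -alpha * distance from the goal to each (lat, lon) landmark."""
    d = [haversine_m(goal[0], goal[1], lat, lon) for lat, lon in landmarks]
    d_min = min(d)  # shifting by the minimum does not change the softmax
    w = [math.exp(-alpha * (di - d_min)) for di in d]
    total = sum(w)
    return [wi / total for wi in w]
```

The entries sum to one and the nearest landmarks dominate; subtracting the minimum distance before exponentiating is a standard numerical-stability step that leaves the result unchanged.<br />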
<br />
[[File:figure3-soroush.PNG|400px|thumb|center|Figure 3. We illustrate the goal description by showing a goal and a set of 5 landmarks that are nearby, plus 4 that are more distant. The code <math>g_i</math> is a vector with a softmax-normalised distance to each landmark.]]<br />
<br />
This form of representation has several advantages: <br />
<br />
1. It could easily be extended to new environments.<br />
<br />
2. It is intuitive. Even humans and animals use landmarks to be able to move from one place to another.<br />
<br />
3. It does not rely on arbitrary map coordinates, and provides an absolute (as opposed to relative) goal.<br />
<br />
In this work, 644 landmarks are manually defined across New York, Paris, and London. The courier task is the problem of navigating to a series of random locations within a city. At the start of each episode, the agent is placed at a random location with a random orientation. When the agent gets within 100 meters of the goal, the next goal is chosen at random. An episode ends after 1000 agent steps. The reward for reaching a goal is proportional to the shortest-path distance between the agent and the goal at the time the goal was first assigned, which gives more reward for longer journeys. Thus, to solve the courier task, the agent needs to learn the mapping between the images observed at the goal location and the goal vector. The agent must also learn the association between the images observed at its current location and the policy for reaching the goal destination.<br />
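The episode bookkeeping of the courier task can be sketched as follows. This is an illustration under stated assumptions, not the authors' implementation: node positions are planar (x, y) coordinates in metres, the 100m goal radius is checked with straight-line distance, the shortest path is computed with Dijkstra's algorithm over edge lengths, and `reward_scale` is a hypothetical proportionality constant. Agent orientation is omitted.<br />

```python
import heapq
import math
import random


def shortest_path_m(adj, src, dst):
    """Dijkstra over adj: node -> list of (neighbor, length_m). Returns metres."""
    dist = {src: 0.0}
    heap = [(0.0, src)]
    while heap:
        d, u = heapq.heappop(heap)
        if u == dst:
            return d
        if d > dist.get(u, math.inf):
            continue  # stale heap entry
        for v, w in adj.get(u, []):
            nd = d + w
            if nd < dist.get(v, math.inf):
                dist[v] = nd
                heapq.heappush(heap, (nd, v))
    return math.inf


class CourierTask:
    """Hypothetical courier-task bookkeeping; pos maps node -> (x, y) in metres."""

    def __init__(self, adj, pos, rng=None, radius_m=100.0, max_steps=1000,
                 reward_scale=1.0):
        self.adj, self.pos = adj, pos
        self.rng = rng if rng is not None else random.Random()
        self.radius_m, self.max_steps = radius_m, max_steps
        self.reward_scale = reward_scale

    def reset(self):
        self.steps = 0
        self.agent = self.rng.choice(list(self.pos))
        self._new_goal()
        return self.agent, self.goal

    def _new_goal(self):
        self.goal = self.rng.choice([n for n in self.pos if n != self.agent])
        # The reward is fixed when the goal is assigned: longer journeys pay more.
        self.goal_reward = self.reward_scale * shortest_path_m(
            self.adj, self.agent, self.goal)

    def step(self, next_node):
        """Move the agent to next_node; return (reward, episode_done)."""
        self.agent = next_node
        self.steps += 1
        reward = 0.0
        if math.dist(self.pos[self.agent], self.pos[self.goal]) <= self.radius_m:
            reward = self.goal_reward
            self._new_goal()
        return reward, self.steps >= self.max_steps
```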
<br />
==Methods==<br />
<br />
===Goal-dependent Actor-Critic Reinforcement Learning===<br />
In this paper, the learning problem is formulated as a Markov Decision Process (MDP) with state space <math>\mathcal{S}</math>, action space <math>\mathcal{A}</math>, environment <math>\mathcal{E}</math>, and a set of possible goals <math>\mathcal{G}</math>. The reward function depends on the current state, goal, and action: <math>\mathcal{R}: \mathcal{S} \times \mathcal{G} \times \mathcal{A} \rightarrow \mathbb{R}</math>. As usual in reinforcement learning, the objective is to find the policy that maximizes the expected return, defined as the sum of discounted rewards starting from state <math>s_0</math> with discount <math>\gamma</math>. Here, the expected return from a state <math>s_t</math> also depends on the goals that are sampled. The policy is a distribution over actions given the current state <math>s_t</math> and goal <math>g_t</math>: <br />
<br />
\begin{align}<br />
\pi(\alpha|s,g)=Pr(\alpha_t=\alpha|s_t=s, g_t=g)<br />
\end{align}<br />
<br />
The value function is the expected return obtained by sampling actions from policy <math>\pi</math>, starting from state <math>s_t</math> with goal <math>g_t</math>:<br />
<br />
\begin{align}<br />
V^{\pi}(s,g)=\mathbb{E}[R_t]=\mathbb{E}\left[\sum_{k=0}^{\infty}\gamma^k r_{t+k} \mid s_t=s, g_t=g\right]<br />
\end{align}<br />
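The return inside the expectation can be computed backwards with the recursion <math>R_t = r_t + \gamma R_{t+1}</math>, and <math>V^{\pi}(s,g)</math> can be estimated by averaging such returns over episodes that start from <math>s</math> with goal <math>g</math>. A minimal sketch; the optional `bootstrap` argument (a value estimate for a truncated unroll) and the helper names are assumptions for illustration.<br />

```python
def discounted_returns(rewards, gamma, bootstrap=0.0):
    """R_t = r_t + gamma * R_{t+1}, computed backwards from R_T = bootstrap."""
    returns = [0.0] * len(rewards)
    running = bootstrap
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns


def mc_value(episodes, gamma):
    """Monte Carlo estimate of V(s, g): mean return over reward sequences from (s, g)."""
    return sum(discounted_returns(ep, gamma)[0] for ep in episodes) / len(episodes)
```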
<br />
The authors also design an architecture with multiple pathways to support the two types of learning this problem requires. First, the agent needs a general internal representation that supports understanding of a scene. Second, the agent needs to remember features unique to each place, which help it organize and remember the scenes.<br />
<br />
===Architectures===<br />
<br />
[[File:figure4-soroush.png|400px|thumb|center|Figure 4. Comparison of architectures. Left: GoalNav is a convolutional encoder plus policy LSTM with goal description input. Middle: CityNav is a single-city navigation architecture with a separate goal LSTM and optional auxiliary heading (θ). Right: MultiCityNav is a multi-city architecture with individual goal LSTM pathways for each city.]]<br />
<br />
The authors use neural networks to parameterize the policy and value functions; these share weights in all layers except the final linear layer. The agent takes image pixels as input and passes them through a convolutional network. The output of the convolutional network is fed into a Long Short-Term Memory (LSTM) network, together with the past reward <math>r_{t-1}</math> and the previous action <math>\alpha_{t-1}</math>.<br />
<br />
Three different architectures are described below.<br />
<br />
The '''GoalNav''' architecture (Fig. 4a) consists of a convolutional encoder and a policy LSTM. In addition to the visual features, the inputs to this LSTM are the goal description <math>g_t</math>, the previous action, and the previous reward.<br />
<br />
The '''CityNav''' architecture (Fig. 4b) adds to the previous architecture a second LSTM, called the goal LSTM, whose inputs are the visual features and the goal description. The CityNav agent also adds an auxiliary heading (θ) prediction task, where the heading is defined as the angle between the north direction and the agent’s pose. This auxiliary task can speed up learning and provides relevant information. <br />
<br />
The '''MultiCityNav''' architecture (Fig. 4c) extends CityNav to learning in different cities. It connects one goal LSTM per city in parallel, each encapsulating locale-specific features. The convolutional encoder and the policy LSTM are shared and become general after training on a number of cities, so only a new goal LSTM needs to be trained for each new city.<br />
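Which modules are updated under each training regime can be summarised in a small helper. This is a hypothetical sketch: the module names are made up, and it only encodes the freezing rule described here (a shared convolutional encoder and policy LSTM, one goal LSTM per city, and only the new city's goal LSTM trained when transferring).<br />

```python
SHARED = frozenset({"conv_encoder", "policy_lstm"})  # hypothetical module names


def trainable_modules(regime, train_cities, target_city=None):
    """Return the set of module names whose weights are updated in a regime."""
    if regime == "single":
        # CityNav on one city: encoder, policy LSTM and its goal LSTM all train.
        (city,) = train_cities
        return set(SHARED) | {f"goal_lstm/{city}"}
    if regime == "joint":
        # Shared modules plus one goal pathway per city, trained together.
        return set(SHARED) | {f"goal_lstm/{c}" for c in train_cities}
    if regime == "transfer":
        # After pre-training, everything is frozen except the new city's pathway.
        if target_city is None:
            raise ValueError("transfer needs a target_city")
        return {f"goal_lstm/{target_city}"}
    raise ValueError(f"unknown regime: {regime}")
```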
<br />
In this paper, the authors train the agents with IMPALA [1], which achieves performance similar to A3C [2] while handling multi-task learning better.<br />
<br />
===Prior on agent training: IMPALA and A3C===<br />
<br />
IMPALA (Importance Weighted Actor-Learner Architecture) is an actor-critic deep reinforcement learning method that decouples acting from learning. On a single-city task, IMPALA achieves performance comparable to A3C (Asynchronous Advantage Actor-Critic, DeepMind's earlier algorithm), but it has been shown to handle multi-task learning better than A3C. The authors use 256 actors for CityNav and 512 actors for MultiCityNav, with batch sizes of 256 and 512 respectively, and sequences are unrolled to length 50.<br />
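The "importance weighted" part of IMPALA refers to its V-trace off-policy correction, which compensates for actors acting with a slightly stale policy <math>\mu</math> while the learner updates <math>\pi</math>. Below is a plain-Python sketch of V-trace value targets for one unrolled trajectory, following the backward recursion given in the IMPALA paper [1]; it is not the authors' training code, and the clipping thresholds default to 1 here as an assumption.<br />

```python
import math


def vtrace_targets(rewards, values, bootstrap_value, log_rhos, gamma,
                   rho_bar=1.0, c_bar=1.0):
    """V-trace targets v_t for one trajectory of length T.

    values[t] = V(x_t) for t < T, bootstrap_value = V(x_T), and
    log_rhos[t] = log pi(a_t | x_t) - log mu(a_t | x_t).
    Uses v_t - V(x_t) = delta_t + gamma * c_t * (v_{t+1} - V(x_{t+1})).
    """
    T = len(rewards)
    rhos = [math.exp(lr) for lr in log_rhos]
    clipped_rhos = [min(rho_bar, r) for r in rhos]  # truncated importance weights
    cs = [min(c_bar, r) for r in rhos]              # truncated "trace" coefficients
    next_values = list(values[1:]) + [bootstrap_value]
    targets = [0.0] * T
    acc = 0.0  # v_{t+1} - V(x_{t+1}); zero past the end of the unroll
    for t in reversed(range(T)):
        delta = clipped_rhos[t] * (rewards[t] + gamma * next_values[t] - values[t])
        acc = delta + gamma * cs[t] * acc
        targets[t] = values[t] + acc
    return targets
```

When the actor and learner policies coincide (all log-ratios zero) and the thresholds are 1, the targets reduce to ordinary bootstrapped n-step returns; when the learner assigns very low probability to the actor's actions, the targets stay close to the current value estimates.<br />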
<br />
===Curriculum Learning===<br />
In curriculum learning, the model is first trained on simple examples; once it has learned those, more complex and difficult examples are fed to it. In this paper, this approach is used to teach the agent to navigate to increasingly distant destinations. The courier task suffers from a problem common to RL tasks: sparse rewards (as in Montezuma’s Revenge). To overcome this, a natural curriculum is defined in which each new goal is sampled within 500m of the agent’s position; this is called phase 1. In phase 2, the maximum range is gradually increased to cover the full graph (3.5km in the smaller New York areas, or 5km for central London or Downtown Manhattan).<br />
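The two-phase goal-distance curriculum can be sketched as a schedule for the maximum goal radius plus a sampler. The linear ramp in phase 2, the phase lengths, and measuring the radius as straight-line distance are assumptions (the text only says the range is "gradually increased"); the function names are hypothetical.<br />

```python
import math
import random


def max_goal_radius_m(step, phase1_steps, phase2_steps, start_m=500.0, full_m=5000.0):
    """Phase 1: fixed start_m radius. Phase 2: assumed linear ramp up to full_m."""
    if step < phase1_steps:
        return start_m
    frac = min(1.0, (step - phase1_steps) / phase2_steps)
    return start_m + frac * (full_m - start_m)


def sample_goal(agent_xy, candidates, radius_m, rng=None):
    """Uniformly pick a goal whose straight-line distance is within radius_m.

    candidates maps goal name -> (x, y) in metres; returns None if none qualify.
    """
    rng = rng if rng is not None else random.Random()
    within = [n for n, xy in candidates.items() if math.dist(agent_xy, xy) <= radius_m]
    return rng.choice(within) if within else None
```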
<br />
Curriculum learning was first introduced by Bengio et al. in 2009. It can be seen as a continuation method for non-convex optimization, and it can speed up training, for example by sparing the learner noisy or hard examples early on. One example of curriculum learning outside this paper is outlined below:<br />
<br />
1. We aim to classify shapes into three classes: triangles, ellipses, and rectangles. We can create a curriculum by first training on a simplified dataset containing only special cases of these classes: equilateral triangles, circles, and squares. Training on these special cases first, and then introducing the full dataset, can let the algorithm converge more quickly towards a good local minimum before it sees "harder" examples. Feeding only these specialized examples also places the classes on more distinct regions of the data manifold; with less overlap, the network can perform better when noise is added later.<br />
<br />
==Results==<br />
This section reports the performance of the proposed architectures on the courier task.<br />
<br />
[[File:figure5-2.png|600px|thumb|center|Figure 5. Average per-episode goal rewards (y-axis) are plotted vs. learning steps (x-axis) for the courier task in the NYU (New York City) environment (top), and in central London (bottom). We compare the GoalNav agent, the CityNav agent, and the CityNav agent without skip connection on the NYU environment, and the CityNav agent in London. We also compare the Oracle performance and a Heuristic agent, described below. The London agents were trained with a 2-phase curriculum– we indicate the end of phase 1 (500m only) and the end of phase 2 (500m to 5000m). Results on the Rive Gauche part of Paris (trained in the same way<br />
as in London) are comparable and the agent achieved mean goal reward 426.]]<br />
<br />
It is first shown that the CityNav agent, trained with curriculum learning, succeeds in learning the courier task in New York, London and Paris. Figure 5 compares the following agents:<br />
<br />
1. Goal Navigation agent.<br />
<br />
2. City Navigation Agent.<br />
<br />
3. A City Navigation agent without the skip connection from the vision layers to the policy LSTM. Although this variant performs worse, it is needed to regularise the interface between the goal LSTM and the policy LSTM in the multi-city transfer scenario.<br />
<br />
Also, a lower bound (Heuristic) and an upper bound (Oracle) on performance are considered. As the paper states: "Heuristic is a random walk on the street graph, where the agent turns in a random direction if it cannot move forward; if at an intersection it will turn with a probability <math>P=0.95</math>. Oracle uses the full graph to compute the optimal path using breadth-first search." As Figure 5 shows, the CityNav architecture attains higher performance and is more stable than the simpler GoalNav agent.<br />
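The Oracle baseline can be sketched as breadth-first search over the street graph, which returns a path with the fewest edges. This is an illustration only: it counts graph hops, ignores the turning actions an agent would also need, and the adjacency-list format is an assumption.<br />

```python
from collections import deque


def oracle_path(adj, start, goal):
    """Breadth-first search on adj (node -> list of neighbours).

    Returns a fewest-hop list of nodes from start to goal, or None if the goal
    is unreachable.
    """
    parent = {start: None}
    queue = deque([start])
    while queue:
        u = queue.popleft()
        if u == goal:
            path = []
            while u is not None:  # walk the parent links back to start
                path.append(u)
                u = parent[u]
            return path[::-1]
        for v in adj.get(u, []):
            if v not in parent:
                parent[v] = u
                queue.append(v)
    return None
```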
<br />
The trajectories of the trained agent over two 1000-step episodes, and its value function while navigating to a destination, are shown in Figure 6.<br />
<br />
[[File:figure6-soroush.png|400px|thumb|center|Figure 6. Trained CityNav agent’s performance in two environments: Central London (left panes), and NYU (right panes). Top: examples of the agent’s trajectory during one 1000-step episode, showing successful consecutive goal acquisitions. The arrows show the direction of travel of the agent. Bottom: We visualize the value function of the agent during 100 trajectories with random starting points and the same goal (respectively St Paul’s Cathedral and Washington Square). Thicker and warmer color lines correspond to higher value functions.]]<br />
<br />
Figure 7 shows that the agent successfully learns a navigation policy for reaching St Paul’s Cathedral in London and Washington Square in New York.<br />
[[File:figure7-soroush.png|400px|thumb|center|Figure 7. Number of steps required for the CityNav agent to reach<br />
a goal (Washington Square in New York or St Paul’s Cathedral in<br />
London) from 100 start locations vs. the straight-line distance to<br />
the goal in meters. One agent step corresponds to a forward movement<br />
of about 10m or a left/right turn by 22.5 or 67.5 degrees.]]<br />
<br />
To investigate the generalisation capability of a trained agent, the authors mask 25% of the possible goals and train on the remaining ones. Figure 8 shows that the agent is still able to traverse these masked areas; it just never samples a goal there. <br />
[[File:fff8.png|600px|center]]<br />
<br />
A critical test in this paper is transferring the model to new cities by learning a new set of landmarks, without re-learning visual representations, behaviours, etc. To this end, the MultiCityNav agent is first trained on a number of cities; then both the policy LSTM and the convolutional encoder are frozen and a new locale-specific goal LSTM is trained on the target city. Performance is compared across three training regimes, illustrated in Fig. 9: training on the target city only (single training); training on multiple cities, including the target city, together (joint training); and joint training on all but the target city, followed by training on the target city with the rest of the architecture frozen (pre-train and transfer). Figure 10 shows that transferring to other cities is possible, and that pre-training on more cities makes transfer more effective. According to the paper: "Remarkably, the agent that is pre-trained on 4 regions and then transferred to Wall Street achieves comparable performance to an agent trained jointly on all the regions, and only slightly worse than single-city training on Wall Street alone". The skip connection from the convolutional encoder to the policy LSTM helps in single-city training, but not in multi-city transfer.<br />
[[File:figure9-soroush.png|400px|thumb|center|Figure 9. Illustration of training regimes: (a) training on a single city (equivalent to CityNav); (b) joint training over multiple cities with a dedicated per-city pathway and shared convolutional net and policy LSTM; (c) joint pre-training on a number of cities followed by training on a target city with convolutional net and policy LSTM frozen (only the target city pathway is optimized).]]<br />
[[File:figure10-soroush.png|400px|thumb|center|Figure 10. Joint multi-city training and transfer learning performance of variants of the MultiCityNav agent evaluated only on the target city (Wall Street). We compare single-city training on the target environment alone vs. joint training on multiple cities (3, 4, or 5-way joint training including Wall Street), vs. pre-training on multiple cities and then transferring to Wall Street while freezing the entire agent except for the new pathway (see Fig. 9). One variant has skip connections between the convolutional encoder and the policy LSTM, the other does not (no-skip).]]<br />
<br />
The paper also investigates giving early rewards before the agent reaches the goal, and adding random rewards (coins) to encourage exploration. Figure 11a suggests that coins by themselves are ineffective, as the task does not benefit from wide exploration. Based on these results, the authors start by sampling goals within a radius of 500m of the agent’s location, then progressively extend this radius to the maximum distance an agent could travel within the environment. In addition, to assess the importance of goal conditioning, a Goal-less CityNav agent is trained by removing the input <math>g_t</math>; its poor performance is clear in Figure 11b. Furthermore, as Figure 11b suggests, reducing the density of the landmarks by ratios of 50%, 25%, and 12.5% does not reduce performance much. Finally, alternative goal representations are investigated:<br />
<br />
a) Latitude and longitude scalar coordinates normalized to be between 0 and 1.<br />
<br />
b) Binned representation. <br />
<br />
The latitude and longitude scalar goal representations perform best. However, since the landmark-based representation also performs well while remaining independent of the coordinate system, the authors use it as the canonical representation.<br />
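For comparison with the landmark code, the two alternative goal encodings can be sketched as follows. The region bounding box, the number of bins, and concatenating separate one-hot latitude and longitude bins are assumptions for illustration; the paper's exact binning scheme is not described here.<br />

```python
def normalized_latlong(lat, lon, bounds):
    """Scalar goal: lat/lon rescaled to [0, 1] inside an assumed bounding box.

    bounds = (lat_min, lat_max, lon_min, lon_max) for the region.
    """
    lat_min, lat_max, lon_min, lon_max = bounds
    return [(lat - lat_min) / (lat_max - lat_min), (lon - lon_min) / (lon_max - lon_min)]


def binned_latlong(lat, lon, bounds, n_bins):
    """Binned goal: one-hot latitude bin concatenated with one-hot longitude bin."""
    u, v = normalized_latlong(lat, lon, bounds)

    def one_hot(x):
        i = min(int(x * n_bins), n_bins - 1)  # put x == 1.0 in the last bin
        return [1.0 if j == i else 0.0 for j in range(n_bins)]

    return one_hot(u) + one_hot(v)
```

The scalar version corresponds to regression-style targets and the binned version to classification-style targets, matching the two alternatives compared in Figure 11.<br />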
<br />
[[File:figure11-soroush.PNG|300px|thumb|center|Figure 11. Top: Learning curves of the CityNav agent on NYU, comparing reward shaping with different radii of early rewards (ER) vs. ER with random coins vs. curriculum learning with ER 200m and no coins (ER 200m, Curr.). Bottom: Learning curves for CityNav agents with different goal representations: landmark-based, as well as latitude and longitude classification-based and regression-based.]]<br />
<br />
==Conclusion==<br />
This paper presents a deep reinforcement learning approach to navigation in cities that uses Google StreetView for its photographic content and worldwide coverage. The authors also introduce a new courier task and a multi-city neural network agent architecture that transfers to new cities. The resulting navigation architecture succeeds by integrating general policies with locale-specific knowledge.<br />
<br />
==Future Works==<br />
The paper uses static Google Street View images. These images contain more information than is needed to follow a route. Although exploiting it is not the central focus of the paper, incorporating such information could make route building and planning considerably more effective.<br />
<br />
==Critique==<br />
1. It is not clear how this model applies to the real world. A real-world navigation system needs to detect objects, people, and cars, but it is not clear whether these are modeled. As far as I can tell, collisions are not modeled at all, which undermines the claim that this is a real-world problem.<br />
<br />
2. This paper uses only static Google Street View images as its primary source of data. For a realistic model of navigation, the authors should at least complement them with dynamic data such as traffic and road-blockage information. It is also understandable not to use maps, but it is not clear why the authors did not use GPS to let the agent know its position, and possibly even build up a map. This could be useful in an emergency, or for exploring places that are unknown or inaccessible. The resulting map could easily be compared with the real one and could also be used in training to achieve higher performance. Availability should not be a serious problem: if the simulation uses a real city for which Google images are available, why should GPS not be available too? At the very least, a short explanation of the intuition behind this choice would be helpful.<br />
<br />
3. The 'Transfer in Multi-City Experiments' results could be strengthened significantly via cross-validation (only Wall Street, which covers the smallest area of the four regions, is used as the test case). Additionally, the results do not show true 'multi-city' transfer learning, since all regions are within New York City. It is stated in the paper that not having to re-learn visual representations when transferring between cities is one of the outcomes, but the tests do not actually check for this. There are likely significant differences in the features that would be learned in NYC vs. Waterloo, for example, and this type of transfer has not been evaluated.<br />
<br />
4. The proposed navigation model could be limited by its reliance on pre-defined landmarks, which appear to be placed evenly across each city. This could limit the agent's deployability to new cities.<br />
<br />
==Reference==<br />
[1] Espeholt, Lasse, Soyer, Hubert, Munos, Remi, Simonyan, Karen, Mnih, Volodymir, Ward, Tom, Doron, Yotam, Firoiu, Vlad, Harley, Tim, Dunning, Iain, Legg, Shane, and Kavukcuoglu, Koray. Impala: Scalable distributed deep-rl with importance weighted actor-learner architectures. arXiv preprint arXiv:1802.01561, 2018.<br />
<br />
[2] Mnih, Volodymyr, Badia, Adria Puigdomenech, Mirza, Mehdi, Graves, Alex, Lillicrap, Timothy, Harley, Tim, Silver, David, and Kavukcuoglu, Koray. Asynchronous methods for deep reinforcement learning. In International Conference on Machine Learning, pp. 1928–1937, 2016.</div>Gsahuhttp://wiki.math.uwaterloo.ca/statwiki/index.php?title=Countering_Adversarial_Images_Using_Input_Transformations&diff=41985Countering Adversarial Images Using Input Transformations2018-11-30T00:46:34Z<p>Gsahu: /* Problem Definition */</p>
<hr />
<div>The code for this paper is available [https://github.com/facebookresearch/adversarial_image_defenses here].<br />
<br />
==Motivation ==<br />
As the use of machine intelligence has increased, robustness has become a critical feature for guaranteeing the reliability of deployed machine-learning systems. However, recent research has shown that existing models are not robust to small, adversarially designed perturbations of the input. Adversarial examples are inputs to machine learning models that an attacker has intentionally designed to cause the model to make a mistake. Adversarially perturbed examples have been deployed to attack image classification services (Liu et al., 2016)[11], speech recognition systems (Cisse et al., 2017a)[12], and robot vision (Melis et al., 2017)[13]. The existence of these adversarial examples has motivated proposals for approaches that increase the robustness of learning systems to such examples. In the example below (Goodfellow et al.) [17], a small perturbation is applied to the original image of a panda, changing the prediction to a gibbon.<br />
<br />
[[File:Panda.png|center]]<br />
<br />
==Introduction==<br />
The paper studies strategies that defend against adversarial example attacks on image classification systems by transforming the images before feeding them to a Convolutional Network Classifier. <br />
Generally, defenses against adversarial examples fall into two main categories:<br />
<br />
# Model-Specific – They enforce model properties such as smoothness and invariance via the learning algorithm. <br />
# Model-Agnostic – They try to remove adversarial perturbations from the input. <br />
<br />
Model-specific defense strategies make strong assumptions about expected adversarial attacks. As a result, they violate Kerckhoffs's principle, which holds that a defense should remain secure even when the adversary knows how it works: adversaries can circumvent model-specific defenses simply by changing how an attack is executed. This paper focuses on increasing the effectiveness of model-agnostic defense strategies. Specifically, the authors investigate the following image transformations as a means of protecting against adversarial images:<br />
<br />
# Image Cropping and Re-scaling (Graese et al., 2016)<br />
# Bit Depth Reduction (Xu et al., 2017)<br />
# JPEG Compression (Dziugaite et al., 2016)<br />
# Total Variance Minimization (Rudin et al., 1992)<br />
# Image Quilting (Efros & Freeman, 2001)<br />
<br />
These image transformations have been studied against adversarial attacks such as the fast gradient sign method (Goodfellow et al., 2015), its iterative extension (Kurakin et al., 2016a), DeepFool (Moosavi-Dezfooli et al., 2016), and the Carlini & Wagner (2017) <math>L_2</math> attack.<br />
<br />
The authors aim to increase the effectiveness of model-agnostic defense strategies through approaches that:<br />
# remove the adversarial perturbations from input images,<br />
# maintain sufficient information in input images to correctly classify them,<br />
# and are still effective in situations where the adversary has information about the defense strategy being used.<br />
<br />
In their experiments, the strongest defenses are based on Total Variance Minimization and Image Quilting. These defenses are non-differentiable and inherently random, which makes it difficult for an adversary to circumvent them.<br />
<br />
==Previous Work==<br />
Recently, a lot of research has focused on countering adversarial threats. Wang et al. [4] proposed an adversary-resistant technique that obstructs attackers from constructing impactful adversarial images by randomly nullifying features within images. Tramèr et al. [2] introduced the state-of-the-art Ensemble Adversarial Training method. It augments the training data with adversarial images crafted against the model being trained, and also with adversarial images generated against an ensemble of other models. Their method, implemented on an Inception V2 classifier, finished 1st among 70 submissions in the NIPS 2017 competition on Defenses against Adversarial Attacks. Graese et al. [3] showed how input transformations such as shifting, blurring, and adding noise can render the majority of adversarial examples non-adversarial. Xu et al. [5] demonstrated how feature squeezing methods, such as reducing the color bit depth of each pixel and spatial smoothing, defend against attacks. Dziugaite et al. [6] studied the effect of JPG compression on adversarial images. Chen et al. introduced a denoising algorithm with GAN-based noise modeling to improve blind denoising performance in low-level vision processing. The GAN is trained to estimate the noise distribution of the noisy input images and to generate noise samples. Although this method was designed for image denoising, it could be generalized to adversarial examples by treating the adversarial perturbation as noise produced by an unknown generating process.<br />
<br />
==Terminology==<br />
<br />
'''Gray Box Attack''': The model architecture and parameters are public, but the defense strategy is not.<br />

'''Black Box Attack''': The adversary does not have access to the model.<br />
<br />
An interesting and important property of adversarial examples is that they are generally not model- or architecture-specific: adversarial examples generated for one neural network architecture often transfer well to another architecture. In other words, to fool a model, one could train a separate model and create adversarial examples against it, and these examples will often fool the original model as well. This has important implications, because it means that adversarial examples can be created for a completely black-box model whose internal mechanics are entirely unknown. [https://ml.berkeley.edu/blog/2018/01/10/adversarial-examples/ reference]<br />
<br />
'''Non Targeted Adversarial Attack''': The goal of the attack is to modify a source image in a way such that the image will be classified incorrectly by the network.<br />
<br />
An example of a non-targeted adversarial attack ([https://ml.berkeley.edu/blog/2018/01/10/adversarial-examples/ reference]):<br />
[[File:non-targeted O.JPG| 600px|center]]<br />
<br />
'''Targeted Adversarial Attack''': The goal of the attack is to modify a source image in such a way that the image will be classified as a ''target'' class by the network.<br />

An example of a targeted adversarial attack ([https://ml.berkeley.edu/blog/2018/01/10/adversarial-examples/ reference]):<br />
[[File:Targeted O.JPG| 600px|center]]<br />
<br />
'''Defense''': A defense is a strategy that aims to make the prediction on an adversarial example h(x') equal to the prediction on the corresponding clean example h(x).<br />
<br />
== Problem Definition ==<br />
The paper discusses non-targeted adversarial attacks on image recognition systems. Consider the image space <math>\mathcal{X} = [0,1]^{H \times W \times C}</math>, a source image <math>x \in \mathcal{X}</math>, and a classifier <math>h(\cdot)</math>. A non-targeted adversarial example of <math>x</math> is a perturbed image <math>x'</math> such that <math>h(x) \neq h(x')</math> and <math>d(x, x') \leq \rho</math> for some dissimilarity function <math>d(\cdot, \cdot)</math> and <math>\rho \geq 0</math>. Ideally, <math>d(\cdot, \cdot)</math> measures the perceptual difference between the original image <math>x</math> and the perturbed image <math>x'</math>, but in practice the Euclidean distance <math>\|x - x'\|_2</math> or the Chebyshev distance <math>\|x - x'\|_{\infty}</math> is used.<br />

Given a set of N clean images <math>\{x_1, \ldots, x_N\}</math>, an adversarial attack aims to generate images <math>\{x'_1, \ldots, x'_N\}</math> such that each <math>x'_n</math> is an adversarial example of <math>x_n</math>.<br />
<br />
The success rate of an attack is given as:<br />

<center><math><br />
\frac{1}{N}\sum_{n=1}^{N}\mathbb{I}[h(x_n) \neq h(x'_n)],<br />
</math></center><br />

which is the proportion of predictions that were altered by the attack.<br />
<br />
The success rate is generally measured as a function of the magnitude of perturbations performed by the attack. In this paper, L2 perturbations are used and are quantified using the normalized L2-dissimilarity metric:<br />
<math> \frac{1}{N} \sum_{n=1}^N{\frac{\vert \vert x_n - x'_n \vert \vert_2}{\vert \vert x_n \vert \vert_2}} </math><br />
<br />
A strong adversarial attack has a high success rate while keeping its normalized L2-dissimilarity, as given by the above equation, low.<br />
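The two metrics above can be computed directly from predicted labels and image arrays. The following is a minimal NumPy sketch. It assumes images are stored as an array of shape (N, H, W, C) with values in [0, 1]. The function names are illustrative and do not come from the authors' code.<br />

```python
import numpy as np

def success_rate(preds_clean, preds_adv):
    """Fraction of examples whose predicted label was changed by the attack."""
    preds_clean = np.asarray(preds_clean)
    preds_adv = np.asarray(preds_adv)
    return float(np.mean(preds_clean != preds_adv))

def normalized_l2_dissimilarity(x_clean, x_adv):
    """Mean over examples of ||x_n - x'_n||_2 / ||x_n||_2.

    x_clean, x_adv: arrays of shape (N, H, W, C) with values in [0, 1].
    """
    n = x_clean.shape[0]
    diff = (x_clean - x_adv).reshape(n, -1)
    ref = x_clean.reshape(n, -1)
    return float(np.mean(np.linalg.norm(diff, axis=1) / np.linalg.norm(ref, axis=1)))

# Toy usage: random "images" and made-up predicted labels.
rng = np.random.default_rng(0)
x = rng.uniform(0.1, 0.9, size=(4, 8, 8, 3))
x_adv = np.clip(x + 0.01 * rng.standard_normal(x.shape), 0.0, 1.0)
print(success_rate([3, 1, 4, 1], [3, 7, 4, 2]))  # 2 of 4 labels changed -> 0.5
print(normalized_l2_dissimilarity(x, x_adv))
```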
<br />
==Adversarial Attacks==<br />
<br />
Although the exact effect that adversarial examples have on a network is not fully understood, Ian Goodfellow et al.'s Deep Learning book argues that adversarial examples exploit the approximately linear behavior of neural networks to perturb the cost function and force incorrect classifications. Images are often high resolution and thus have thousands of pixels (millions for HD images). When the dimensionality is in the thousands or millions, a perturbation within a small epsilon ball can change the cost function greatly, especially if it increases the loss through every pixel. Hence, although methods such as FGSM and Iterative FGSM below are very straightforward, they strongly influence the network under a white-box attack.<br />
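The linearity argument can be made concrete with the standard calculation from Goodfellow et al. [17]. Take a linear score <math>w^\top x</math> over <math>n</math> input dimensions and a perturbation <math>\eta = \epsilon \cdot \mathrm{sign}(w)</math>, which changes each input by exactly <math>\epsilon</math>:<br />

<center><math>w^\top (x + \eta) = w^\top x + \epsilon \|w\|_1 = w^\top x + \epsilon\, n\, m,</math></center><br />

Here <math>m</math> is the average magnitude of the weights. The per-pixel change stays at <math>\epsilon</math>, but the change in score grows linearly with <math>n</math>. For an illustrative example (these numbers are assumptions, not taken from the paper), a 224×224×3 input has <math>n = 150{,}528</math>. With <math>\epsilon = 1/255</math> and <math>m = 0.01</math>, the score shifts by about 5.9.<br />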
<br />
The paper studies the following four attacks:<br />
<br />
1. '''Fast Gradient Sign Method (FGSM; Goodfellow et al. (2015)) [17]''': Let <math>x</math> be a source input with true label <math>y</math>, and let <math>l(\cdot,\cdot)</math> be the differentiable loss function used to train the classifier <math>h(\cdot)</math>. The corresponding adversarial example is then given by:<br />
<br />
<math>x' = x + \epsilon \cdot sign(\nabla_x l(x, y))</math><br />
<br />
for some <math>\epsilon \gt 0</math> which controls the perturbation magnitude.<br />
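As a minimal sketch of the FGSM update, the example below uses a toy logistic-regression classifier instead of a ResNet. This is an assumption made so that the input gradient <math>(p - y)\,w</math> of the cross-entropy loss has a closed form. The final clipping keeps every pixel in [0, 1].<br />

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def loss_grad_x(x, y, w, b):
    """Gradient of the binary cross-entropy loss w.r.t. the input x
    for a logistic-regression classifier p(y=1|x) = sigmoid(w.x + b)."""
    p = sigmoid(np.dot(w, x) + b)
    return (p - y) * w

def fgsm(x, y, w, b, eps):
    """One-step FGSM: move every input dimension by eps in the direction
    that increases the loss, then clip back into the image space [0, 1]."""
    x_adv = x + eps * np.sign(loss_grad_x(x, y, w, b))
    return np.clip(x_adv, 0.0, 1.0)

rng = np.random.default_rng(0)
w = rng.standard_normal(100)
b = 0.0
x = rng.uniform(0.0, 1.0, 100)
y = 1.0 if np.dot(w, x) + b > 0 else 0.0  # use the model's own prediction as the label
x_adv = fgsm(x, y, w, b, eps=0.1)
print(np.max(np.abs(x_adv - x)))  # per-pixel change never exceeds eps
```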
<br />
2. '''Iterative FGSM (I-FGSM; Kurakin et al. (2016b)) [14]''': iteratively applies the FGSM update, where M is the number of iterations. It is given as:<br />

<math>x^{(m)} = x^{(m-1)} + \epsilon \cdot \mathrm{sign}(\nabla_{x^{(m-1)}} l(x^{(m-1)}, y))</math><br />
<br />
where <math>m = 1,...,M; x^{(0)} = x;</math> and <math>x' = x^{(M)}</math>. M is set such that <math>h(x) \neq h(x')</math>.<br />
<br />
Both FGSM and I-FGSM work by minimizing the Chebyshev distance between the inputs and the generated adversarial examples.<br />
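The iterative variant can be sketched in the same toy setting, again a hypothetical logistic-regression model with a closed-form input gradient. Following the description above, the loop runs for at most M iterations, stops as soon as the predicted label changes, and clips to [0, 1] after every step. Kurakin et al. additionally clip each iterate to an <math>\epsilon</math>-neighborhood of <math>x</math>; this sketch omits that step.<br />

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def predict(x, w, b):
    return 1.0 if np.dot(w, x) + b > 0 else 0.0

def i_fgsm(x, y, w, b, eps, max_iter):
    """Iterative FGSM for a logistic-regression classifier.

    Applies x^(m) = clip(x^(m-1) + eps * sign(grad_x loss)) for up to
    max_iter steps and stops early once the predicted label differs from y.
    """
    x_adv = x.copy()
    for _ in range(max_iter):
        p = sigmoid(np.dot(w, x_adv) + b)
        grad = (p - y) * w                      # d(BCE)/dx for this model
        x_adv = np.clip(x_adv + eps * np.sign(grad), 0.0, 1.0)
        if predict(x_adv, w, b) != y:
            break
    return x_adv

w = np.array([1.0, -1.0])
b = -0.05
x = np.array([0.6, 0.5])                        # clean score w.x + b = 0.05, label 1
x_adv = i_fgsm(x, predict(x, w, b), w, b, eps=0.02, max_iter=10)
print(x_adv, predict(x_adv, w, b))              # label flips to 0.0 after two steps
```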
<br />
3. '''DeepFool (Moosavi-Dezfooli et al., 2016) [15]''': projects <math>x</math> onto a linearization of the decision boundary defined by a binary classifier <math>h(\cdot)</math> for M iterations. This can be particularly effective when a network uses ReLU activation functions. It is given as:<br />
<br />
[[File:DeepFool.PNG|400px |]]<br />
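The DeepFool update appears above only as an image, so the sketch below implements the binary-classifier version from Moosavi-Dezfooli et al., where <math>h(x) = \mathrm{sign}(f(x))</math>. Each iteration jumps to the boundary of the local linearization, <math>r = -f(x)\,\nabla f(x) / \|\nabla f(x)\|_2^2</math>. The overshoot of 0.02, which pushes the point just across the boundary, is an assumed default. The functions <code>f</code> and <code>grad_f</code> are caller-supplied.<br />

```python
import numpy as np

def deepfool_binary(x, f, grad_f, max_iter=50, overshoot=0.02):
    """DeepFool for a binary classifier h(x) = sign(f(x)).

    Each step linearizes f at the current point and moves to the boundary
    of that linearization: r = -f(x) * grad / ||grad||^2. The accumulated
    perturbation is scaled by (1 + overshoot) so the point crosses over.
    """
    x = x.astype(float)
    x_adv = x.copy()
    r_total = np.zeros_like(x)
    orig_sign = np.sign(f(x))
    for _ in range(max_iter):
        if np.sign(f(x_adv)) != orig_sign:      # label changed: done
            break
        g = grad_f(x_adv)
        r = -f(x_adv) * g / np.dot(g, g)
        r_total += r
        x_adv = x + (1.0 + overshoot) * r_total
    return x_adv

# Usage on an affine classifier f(x) = w.x + b: a single step suffices.
w = np.array([3.0, 4.0])
b = -1.0
x0 = np.array([1.0, 1.0])
x_adv = deepfool_binary(x0, lambda v: float(np.dot(w, v) + b), lambda v: w)
print(x_adv)  # approximately [0.2656, 0.0208]
```

For an affine classifier the linearization is exact, so one step reaches the boundary. For a curved boundary, several iterations are needed.<br />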
<br />
4. '''Carlini-Wagner's L2 attack (CW-L2; Carlini & Wagner (2017)) [16]''': an optimization-based attack that combines a differentiable surrogate for the model's classification accuracy with an L2-penalty term that encourages the adversarial image to stay close to the original image. Let <math>Z(x)</math> be the operation that computes the logit vector (i.e., the output before the softmax layer) for an input <math>x</math>, and let <math>Z(x)_k</math> be the logit value corresponding to class <math>k</math>. The untargeted variant of CW-L2 finds a solution to the following unconstrained optimization problem:<br />
<br />
[[File:Carlini.PNG|500px |]]<br />
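The objective above appears only as an image. The sketch below assumes the standard untargeted form from Carlini & Wagner, <math>\|x - x'\|_2^2 + \lambda_f \max(Z(x')_y - \max_{k \neq y} Z(x')_k, -\kappa)</math>. It minimizes this by plain gradient descent for a hypothetical linear logit model <math>Z(x) = Wx + c</math>, clipping to [0, 1] instead of using the change of variables of the original attack. The default <math>\lambda_f = 10</math> matches the experimental setup. The usage sets <math>\kappa = 0.1</math> rather than 0, so the iterate does not stall exactly on the decision boundary.<br />

```python
import numpy as np

def cw_l2_untargeted(x, y, W, c, lam=10.0, kappa=0.0, lr=0.001, steps=1000):
    """Gradient-descent sketch of the untargeted CW-L2 objective
        ||x' - x||_2^2 + lam * max(Z(x')_y - max_{k != y} Z(x')_k, -kappa)
    for a linear logit model Z(x) = W @ x + c, with clipping to [0, 1].
    Returns the misclassified iterate with the smallest squared L2
    distortion, or None if no iterate was misclassified."""
    def margin_and_rival(z):
        others = [k for k in range(len(z)) if k != y]
        j = max(others, key=lambda k: z[k])      # strongest competing class
        return z[y] - z[j], j

    x = x.astype(float)
    x_adv = x.copy()
    best, best_dist = None, np.inf
    for _ in range(steps):
        margin, j = margin_and_rival(W @ x_adv + c)
        grad = 2.0 * (x_adv - x)                 # gradient of the squared L2 penalty
        if margin > -kappa:                      # hinge term is active
            grad = grad + lam * (W[y] - W[j])    # d(Z_y - Z_j)/dx' for linear logits
        x_adv = np.clip(x_adv - lr * grad, 0.0, 1.0)
        new_margin, _ = margin_and_rival(W @ x_adv + c)
        dist = float(np.sum((x_adv - x) ** 2))
        if new_margin < 0 and dist < best_dist:  # misclassified and closer to x
            best, best_dist = x_adv.copy(), dist
    return best

W = np.eye(2)
c = np.zeros(2)
x = np.array([0.6, 0.4])                         # classified as class 0
x_adv = cw_l2_untargeted(x, 0, W, c, kappa=0.1)
print(x_adv, np.argmax(W @ x_adv + c))           # now classified as class 1
```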
<br />
As mentioned earlier, the first two attacks minimize the Chebyshev distance whereas the last two attacks minimize the Euclidean distance between the inputs and the adversarial examples.<br />
<br />
All the methods described above maintain <math>x' \in \mathcal{X}</math> by performing value clipping. <br />
<br />
The figure below shows adversarial images and the corresponding perturbations at five levels of normalized L2-dissimilarity for all four attacks.<br />
<br />
[[File:Strength.PNG|thumb|center| 600px |Figure 1: Adversarial images and corresponding perturbations at five levels of normalized L2- dissimilarity for all four attacks.]]<br />
<br />
==Defenses==<br />
A defense is a strategy that aims to make the prediction on an adversarial example equal to the prediction on the corresponding clean example. Figure 1 shows the particular structure of the adversarial perturbations <math> x-x' </math>.<br />
Five image transformations that alter the structure of these perturbations are studied:<br />
# Image Cropping and Re-scaling, <br />
# Bit Depth Reduction, <br />
# JPEG Compression, <br />
# Total Variance Minimization, <br />
# Image Quilting.<br />
<br />
'''Image Cropping and Rescaling''' alters the spatial positioning of the adversarial perturbation. In this study, images are cropped and rescaled during training as part of data augmentation. At test time, the predictions on random crops of the image are averaged.<br />
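A minimal sketch of the test-time crop ensemble follows. It rests on several assumptions: a caller-supplied <code>classify</code> function that returns class probabilities (hypothetical here), a square crop size chosen by the user, and nearest-neighbor rescaling in place of the interpolation a real pipeline would likely use. The default of 30 crops follows the ensemble size reported in the experiments.<br />

```python
import numpy as np

def random_crop_rescale(img, crop_size, rng):
    """Take a random square crop of an (H, W, C) image and rescale it back
    to (H, W) with nearest-neighbor sampling."""
    h, w = img.shape[:2]
    top = rng.integers(0, h - crop_size + 1)
    left = rng.integers(0, w - crop_size + 1)
    crop = img[top:top + crop_size, left:left + crop_size]
    rows = np.arange(h) * crop_size // h       # source row for each output row
    cols = np.arange(w) * crop_size // w       # source column for each output column
    return crop[rows][:, cols]

def crop_ensemble_predict(img, classify, crop_size, n_crops=30, seed=0):
    """Average the class probabilities predicted on n_crops random crops."""
    rng = np.random.default_rng(seed)
    probs = [classify(random_crop_rescale(img, crop_size, rng)) for _ in range(n_crops)]
    return np.mean(probs, axis=0)

def toy_classify(image):
    """Hypothetical two-class 'classifier': softmax over [mean, 1 - mean]."""
    m = image.mean()
    e = np.exp([m, 1.0 - m])
    return e / e.sum()

img = np.random.default_rng(1).uniform(size=(32, 32, 3))
print(crop_ensemble_predict(img, toy_classify, crop_size=24))
```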
<br />
'''Bit Depth Reduction (Xu et al.) [5]''' performs a simple type of quantization that can remove small (adversarial) variations in pixel values from an image. Images are reduced to 3 bits in the experiments.<br />
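Bit depth reduction amounts to rounding each pixel to one of <math>2^b</math> levels. A minimal NumPy sketch, assuming pixel values in [0, 1]:<br />

```python
import numpy as np

def reduce_bit_depth(img, bits=3):
    """Quantize pixel values in [0, 1] to 2**bits evenly spaced levels."""
    levels = 2 ** bits - 1
    return np.round(img * levels) / levels

x = np.array([0.3, 0.31, 0.9])
print(reduce_bit_depth(x))  # 0.3 and 0.31 both map to 2/7; 0.9 maps to 6/7
```

With 3 bits the levels are spaced 1/7 apart. A perturbation of 0.01 is therefore removed unless it pushes a pixel across a rounding boundary.<br />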
<br />
'''JPEG Compression and Decompression (Dziugaite et al., 2016) [6]''' removes small perturbations by performing simple quantization. The authors use a quality level of 75 (out of 100) in their experiments.<br />
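Real JPEG encoding is best done with an image library. The sketch below shows only its core lossy step, quantizing the 2-D discrete cosine transform (DCT) coefficients of an 8×8 block, to illustrate why small perturbations tend to be discarded. The uniform step size is a hypothetical stand-in for JPEG's per-frequency quantization tables and does not correspond to the quality level of 75 used in the paper.<br />

```python
import numpy as np

def dct_matrix(n=8):
    """Orthonormal DCT-II basis; row k is the k-th cosine basis vector."""
    k = np.arange(n)[:, None]
    i = np.arange(n)[None, :]
    d = np.sqrt(2.0 / n) * np.cos(np.pi * (2 * i + 1) * k / (2 * n))
    d[0, :] = np.sqrt(1.0 / n)
    return d

def jpeg_like_block(block, step=0.1):
    """Quantize the 2-D DCT coefficients of one square block with a uniform
    step and transform back. Real JPEG also uses per-frequency quantization
    tables, chroma subsampling, and entropy coding."""
    d = dct_matrix(block.shape[0])
    coeffs = d @ block @ d.T                   # forward 2-D DCT
    coeffs = np.round(coeffs / step) * step    # lossy quantization
    return d.T @ coeffs @ d                    # inverse 2-D DCT

block = np.random.default_rng(0).uniform(size=(8, 8))
print(np.abs(jpeg_like_block(block) - block).max())
```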
<br />
'''Total Variance Minimization (Rudin et al.) [9]''':<br />
This approach combines pixel dropout with total variation minimization. It randomly selects a small set of pixels and reconstructs the “simplest” image that is consistent with the selected pixels. The reconstructed image does not contain the adversarial perturbations because these perturbations tend to be small and localized. Specifically, a random set of pixels is first selected by sampling a Bernoulli random variable <math>X(i, j, k)</math> for each pixel location <math>(i, j, k)</math>; a pixel is maintained when <math>X(i, j, k) = 1</math>. Next, total variation minimization is used to construct an image <math>z</math> that is similar to the (perturbed) input image <math>x</math> on the selected set of pixels while also being “simple” in terms of total variation, by solving:<br />
<br />
[[File:TV!.png|300px|]] , <br />
<br />
where <math>TV_{p}(z)</math> represents <math>L_{p}</math> total variation of '''z''' :<br />
<br />
[[File:TV2.png|500px|]]<br />
<br />
The total variation (TV) measures the amount of fine-scale variation in the image z, as a result of which TV minimization encourages removal of small (adversarial) perturbations in the image.<br />
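A minimal sketch of the pixel-dropout plus TV minimization defense for a single-channel image follows, under several simplifying assumptions. The objective is taken as <math>\|X \odot (z - x)\|_2^2 + \lambda_{TV}\, TV(z)</math> with a squared data term. <math>TV</math> is an anisotropic total variation smoothed by a small constant, so that it can be minimized by plain projected gradient descent; this stands in for whatever solver the authors used. The defaults <math>p = 0.5</math> and <math>\lambda_{TV} = 0.03</math> follow the experimental setup.<br />

```python
import numpy as np

def tv_smoothed(z, eps=1e-3):
    """Anisotropic, smoothed TV of a 2-D image: sum of sqrt(d^2 + eps) over
    horizontal and vertical neighbor differences d."""
    dx = z[:, 1:] - z[:, :-1]
    dy = z[1:, :] - z[:-1, :]
    return np.sum(np.sqrt(dx ** 2 + eps)) + np.sum(np.sqrt(dy ** 2 + eps))

def tv_grad(z, eps=1e-3):
    """Gradient of tv_smoothed with respect to every pixel of z."""
    dx = z[:, 1:] - z[:, :-1]
    dy = z[1:, :] - z[:-1, :]
    gx = dx / np.sqrt(dx ** 2 + eps)
    gy = dy / np.sqrt(dy ** 2 + eps)
    g = np.zeros_like(z)
    g[:, 1:] += gx
    g[:, :-1] -= gx
    g[1:, :] += gy
    g[:-1, :] -= gy
    return g

def tvm_defense(x, p_drop=0.5, lam=0.03, lr=0.05, steps=300, seed=0):
    """Pixel dropout + TV minimization on a single-channel image x in [0, 1]:
    minimize ||X * (z - x)||_2^2 + lam * TV(z), where the Bernoulli mask X
    keeps each pixel (X = 1) with probability 1 - p_drop."""
    rng = np.random.default_rng(seed)
    mask = (rng.random(x.shape) >= p_drop).astype(float)
    z = x.copy()
    for _ in range(steps):
        grad = 2.0 * mask * (z - x) + lam * tv_grad(z)
        z = np.clip(z - lr * grad, 0.0, 1.0)   # projected gradient step
    return z

rng = np.random.default_rng(0)
x = np.clip(0.5 + 0.1 * rng.standard_normal((16, 16)), 0.0, 1.0)
z = tvm_defense(x)
print(tv_smoothed(x), tv_smoothed(z))          # TV drops after reconstruction
```

The step size 0.05 is small enough that, for this smoothing constant, each projected gradient step does not increase the objective.<br />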
<br />
'''Image Quilting (Efros & Freeman, 2001) [8]'''<br />
Image Quilting is a non-parametric technique that synthesizes images by piecing together small patches taken from a database of image patches. The algorithm places appropriate patches from the database at a predefined set of grid points and computes minimum graph cuts in all overlapping boundary regions to remove edge artifacts. Image Quilting can be used to remove adversarial perturbations by constructing a patch database that contains only patches from "clean" images (without adversarial perturbations). For each patch of the adversarial image, the algorithm finds its K nearest neighbors (in pixel space) in the patch database and picks one of these neighbors uniformly at random to build the synthesized image. The motivation for this defense is that the resulting image consists only of pixels that were not modified by the adversary, because the database of real patches is unlikely to contain the structures that appear in adversarial images.<br />
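A simplified sketch of the quilting defense follows. It assumes single-channel images whose sides are multiples of the patch size, with non-overlapping patches on a regular grid. It omits the overlapping regions and minimum graph cuts of the full Efros & Freeman algorithm. The patch database and function names are hypothetical.<br />

```python
import numpy as np

def extract_patches(images, size):
    """All non-overlapping size x size patches from a list of 2-D images,
    flattened into the rows of one array."""
    patches = []
    for img in images:
        h, w = img.shape
        for i in range(0, h - size + 1, size):
            for j in range(0, w - size + 1, size):
                patches.append(img[i:i + size, j:j + size].ravel())
    return np.array(patches)

def quilt(x, database, size, k=10, seed=0):
    """Rebuild x from clean patches: for each grid cell, find the k nearest
    database patches in pixel space and pick one uniformly at random."""
    rng = np.random.default_rng(seed)
    out = np.zeros_like(x)
    h, w = x.shape
    for i in range(0, h - size + 1, size):
        for j in range(0, w - size + 1, size):
            query = x[i:i + size, j:j + size].ravel()
            dists = np.sum((database - query) ** 2, axis=1)
            nearest = np.argsort(dists)[:k]
            choice = database[rng.choice(nearest)]
            out[i:i + size, j:j + size] = choice.reshape(size, size)
    return out

rng = np.random.default_rng(0)
clean = [rng.uniform(size=(16, 16)) for _ in range(3)]
db = extract_patches(clean, 4)                 # 3 images x 16 patches = 48 rows
noisy = np.clip(clean[1] + 0.05 * rng.standard_normal((16, 16)), 0.0, 1.0)
restored = quilt(noisy, db, 4, k=5)            # every patch now comes from db
```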
<br />
=Experiments=<br />
<br />
Five experiments were performed to test the efficacy of the defenses. The first four experiments consider gray-box and black-box attacks. In the gray-box setting, the defenses are applied to adversarial input images before they are fed to the convolutional network; the adversary can read the model architecture and parameters but not the defense strategy. In the black-box setting, the convolutional network is replaced by a network trained on transformed images, which the adversary cannot access. The final experiment compares the authors' defenses with prior work.<br />
<br />
'''Set up:'''<br />
Experiments are performed on the ImageNet image classification dataset, which comprises 1.2 million training images and 50,000 test images, each belonging to one of 1000 classes. The adversarial images are produced by attacking a ResNet-50 model with the four attacks described above. The strength of an adversary is measured in terms of its normalized L2-dissimilarity, which is controlled for each attack as follows:<br />
<br />
- FGSM. Increasing the step size <math>\epsilon</math> increases the normalized L2-dissimilarity.<br />
<br />
- I-FGSM. We fix M=10, and increase <math>\epsilon</math> to increase the normalized L2-dissimilarity.<br />
<br />
- DeepFool. We fix M=5, and increase <math>\epsilon</math> to increase the normalized L2-dissimilarity.<br />
<br />
- CW-L2. We fix <math>\kappa</math>=0 and <math>\lambda_{f}</math>=10, and multiply the resulting perturbation by a scaling factor to adjust the normalized L2-dissimilarity.<br />
<br />
The hyperparameters of the defenses have been fixed in all the experiments. Specifically the pixel dropout probability was set to <math>p</math>=0.5 and regularization parameter of total variation minimizer <math>\lambda_{TV}</math>=0.03.<br />
<br />
The figure below shows how the setups of the experiments differ: the network is trained either on (a) regular images or (b) transformed images. The different settings are labeled 8.1, 8.2, and 8.3 in the figure.<br />
[[File:models3.png |center]] <br />
<br />
==GrayBox - Image Transformation at Test Time== <br />
This experiment applies a transformation to adversarial images at test time, before feeding them to a ResNet-50 that was trained to classify clean images. The figure below shows the Top-1 accuracy for the five transformations. A few observations stand out:
* All of the image transformations partly eliminate the effects of the attacks.
* The crop ensemble, with an ensemble size of 30, gives the best accuracy, around 40-60 percent.
* The accuracy of the image quilting defense hardly deteriorates as the strength of the adversary increases, although quilting does reduce accuracy on non-adversarial examples.<br />
<br />
[[File:sFig4.png|center|600px |]]<br />
<br />
==BlackBox - Image Transformation at Training and Test Time==<br />
A ResNet-50 model was trained on transformed ImageNet training images. Before the images were fed to the network for training, standard data augmentation (He et al.) was applied, along with bit depth reduction, JPEG compression, TV minimization, or image quilting. The figure below shows the classification accuracy on the same adversarial images as in the previous experiment. The adversary cannot access this trained model to generate new adversarial images, so this is treated as a black-box setting. The figure shows that training convolutional neural networks on images transformed in the same way as at test time dramatically improves the effectiveness of all transformation defenses. Nearly 80-90% of the attacks are defended successfully, even when the L2-dissimilarity is high.<br />
<br />
<br />
[[File:sFig5.png|center|600px |]]<br />
<br />
<br />
==Blackbox - Ensembling==<br />
Four networks (ResNet-50, ResNet-101, DenseNet-169, and Inception-v4) were studied along with ensembles of defenses, as shown in Table 1. The adversarial images are produced by attacking a ResNet-50 model. The results in the table show that Inception-v4 performs best, possibly because that network has higher accuracy even in non-adversarial settings. The best ensemble of defenses achieves an accuracy of about 71% against all attacks. The attacks deteriorate the accuracy of the best defenses (a combination of cropping, TVM, image quilting, and model transfer) by at most 6%. Ensembling different defenses yields gains of 1-2% in classification accuracy, while transferring attacks to different network architectures yields gains of 2-3%.<br />
<br />
<br />
[[File:sTab1.png|600px|thumb|center|Table 1. Top-1 classification accuracy of ensemble and model transfer defenses (columns) against four black-box attacks (rows). The four networks we use to classify images are ResNet-50 (RN50), ResNet-101 (RN101), DenseNet-169 (DN169), and Inception-v4 (Iv4). Adversarial images are generated by running attacks against the ResNet-50 model, aiming for an average normalized <math>L_2</math>-dissimilarity of 0.06. Higher is better. The best defense against each attack is typeset in boldface.]]<br />
<br />
==GrayBox - Image Transformation at Training and Test Time ==<br />
In this experiment, the adversary has access to the network and its parameters, but not to the input transformations applied at test time. Novel adversarial images were generated by the four attack methods against the networks trained in the previous black-box experiment (image transformation at training and test time). The results show that bit-depth reduction and JPEG compression are weak defenses in this gray-box setting. In contrast, image cropping and rescaling, total variation minimization, and image quilting are more robust against adversarial images in this setting.<br />
The results for this experiment are shown in the figure below. Networks using these defenses classify up to 50% of images correctly.<br />
<br />
[[File:sFig6.png|center| 600px |]]<br />
<br />
==Comparison With Ensemble Adversarial Training==<br />
The results are compared with the state-of-the-art ensemble adversarial training approach proposed by Tramèr et al. [2]. Ensemble adversarial training fits the parameters of a convolutional neural network on adversarial examples that were generated to attack an ensemble of pre-trained models. The model released by Tramèr et al. [2] is an Inception-ResNet-v2 trained on adversarial examples generated by FGSM against Inception-ResNet-v2 and Inception-v3 models. The authors compared it with their ResNet-50 models using image cropping, total variance minimization, and image quilting defenses. Two differences in assumptions should be noted: the authors' defenses assume that the input transformation is unknown to the adversary, and they use no prior knowledge of the attacks. The results of ensemble adversarial training and of the pre-processing techniques in this paper are shown in Table 2. Ensemble adversarial training works better against FGSM attacks (which it uses at training time), but it is outperformed by each of the transformation-based defenses on all other attacks.<br />
<br />
<br />
<br />
[[File:sTab2.png|600px|thumb|center|Table 2. Top-1 classification accuracy on images perturbed using attacks against ResNet-50 models trained on input-transformed images and an Inception-v4 model trained using ensemble adversarial training. Adversarial images are generated by running attacks against the models, aiming for an average normalized <math>L_2</math>-dissimilarity of 0.06. The best defense against each attack is typeset in boldface.]]<br />
<br />
=Discussion/Conclusions=<br />
The paper proposes reasonable approaches to countering adversarial images. The authors evaluated total variance minimization and image quilting and compared them with previously proposed ideas such as image cropping and rescaling, bit depth reduction, and JPEG compression and decompression on the challenging ImageNet dataset.<br />
Previous work by Wang et al. [10] shows that a strong input defense should be non-differentiable and randomized. Two of the defenses, namely total variation minimization and image quilting, possess both properties.<br />

Image quilting involves a discrete variable that selects a patch from the database, which is a non-differentiable operation. Total variation minimization randomly selects the pixels it uses to measure reconstruction error when creating the denoised image, and image quilting picks one of the K nearest neighbors uniformly at random. This inherent randomness makes it difficult to attack the model.<br />
<br />
As future work, the authors suggest applying the same techniques to other domains such as speech recognition and image segmentation. For example, in speech recognition, total variance minimization could be used to remove perturbations from waveforms, and "spectrogram quilting" techniques that reconstruct a spectrogram could be developed. The proposed input-transformation defenses could also be combined with ensemble adversarial training by Tramèr et al. [2] and evaluated against new attack methods.<br />
<br />
=Critiques=<br />
1. The terminology of black-box, white-box, and gray-box attacks is not precisely defined.<br />

2. White-box attacks, in which the adversary has full access to the model as well as to the pre-processing techniques, could have been considered.<br />

3. Although the authors did considerable work in showing the effect of four attacks on the ImageNet dataset, stronger attacks such as projected gradient descent (Madry et al.) [7] could have been evaluated.<br />

4. The authors claim that the success rate is generally measured as a function of the magnitude of the perturbations performed by the attack, using the L2-dissimilarity, but this claim is not supported by any references, and none of the previous work has used these metrics.<br />

5. ([https://openreview.net/forum?id=SyJ7ClWCb]) In the new draft of the paper, the authors add the sentence "our defenses assume that part of the defense strategy (viz., the input transformation) is unknown to the adversary".<br />
<br />
This is a completely unreasonable assumption. Any algorithm which hopes to be secure must allow the adversary to, at the very least, understand what the defense is that's being used. Consider a world where the defense here is implemented in practice: any attacker in the world could just go look up the paper, read the description of the algorithm, and know how it works.<br />
<br />
=References=<br />
<br />
1. Chuan Guo, Mayank Rana, Moustapha Cissé, and Laurens van der Maaten. Countering Adversarial Images Using Input Transformations.<br />
<br />
2. Florian Tramèr, Alexey Kurakin, Nicolas Papernot, Ian Goodfellow, Dan Boneh, Patrick McDaniel, Ensemble Adversarial Training: Attacks and defenses.<br />
<br />
3. Abigail Graese, Andras Rozsa, and Terrance E. Boult. Assessing threat of adversarial examples of deep neural networks. CoRR, abs/1610.04256, 2016. <br />
<br />
4. Qinglong Wang, Wenbo Guo, Kaixuan Zhang, Alexander G. Ororbia II, Xinyu Xing, C. Lee Giles, and Xue Liu. Adversary resistant deep neural networks with an application to malware detection. CoRR, abs/1610.01239, 2016a.<br />
<br />
5. Weilin Xu, David Evans, and Yanjun Qi. Feature squeezing: Detecting adversarial examples in deep neural networks. CoRR, abs/1704.01155, 2017. <br />
<br />
6. Gintare Karolina Dziugaite, Zoubin Ghahramani, and Daniel Roy. A study of the effect of JPG compression on adversarial images. CoRR, abs/1608.00853, 2016.<br />
<br />
7. Aleksander Madry, Aleksandar Makelov, Ludwig Schmidt, Dimitris Tsipras, and Adrian Vladu. Towards Deep Learning Models Resistant to Adversarial Attacks. arXiv:1706.06083v3.<br />
<br />
8. Alexei Efros and William Freeman. Image quilting for texture synthesis and transfer. In Proc. SIGGRAPH, pp. 341–346, 2001.<br />
<br />
9. Leonid Rudin, Stanley Osher, and Emad Fatemi. Nonlinear total variation based noise removal algorithms. Physica D, 60:259–268, 1992.<br />
<br />
10. Qinglong Wang, Wenbo Guo, Kaixuan Zhang, Alexander G. Ororbia II, Xinyu Xing, C. Lee Giles, and Xue Liu. Learning adversary-resistant deep neural networks. CoRR, abs/1612.01401, 2016b.<br />
<br />
11. Yanpei Liu, Xinyun Chen, Chang Liu, and Dawn Song. Delving into transferable adversarial examples and black-box attacks. CoRR, abs/1611.02770, 2016.<br />
<br />
12. Moustapha Cisse, Yossi Adi, Natalia Neverova, and Joseph Keshet. Houdini: Fooling deep structured prediction models. CoRR, abs/1707.05373, 2017.<br />
<br />
13. Marco Melis, Ambra Demontis, Battista Biggio, Gavin Brown, Giorgio Fumera, and Fabio Roli. Is deep learning safe for robot vision? Adversarial examples against the iCub humanoid. CoRR, abs/1708.06939, 2017.<br />
<br />
14. Alexey Kurakin, Ian J. Goodfellow, and Samy Bengio. Adversarial examples in the physical world. CoRR, abs/1607.02533, 2016b.<br />
<br />
15. Seyed-Mohsen Moosavi-Dezfooli, Alhussein Fawzi, and Pascal Frossard. Deepfool: A simple and accurate method to fool deep neural networks. In Proc. CVPR, pp. 2574–2582, 2016.<br />
<br />
16. Nicholas Carlini and David A. Wagner. Towards evaluating the robustness of neural networks. In IEEE Symposium on Security and Privacy, pp. 39–57, 2017.<br />
<br />
17. Ian Goodfellow, Jonathon Shlens, and Christian Szegedy. Explaining and harnessing adversarial examples. In Proc. ICLR, 2015.</div>Gsahuhttp://wiki.math.uwaterloo.ca/statwiki/index.php?title=Robot_Learning_in_Homes:_Improving_Generalization_and_Reducing_Dataset_Bias&diff=41979Robot Learning in Homes: Improving Generalization and Reducing Dataset Bias2018-11-30T00:37:26Z<p>Gsahu: /* References */</p>
<hr />
<div>==Introduction==<br />
<br />
<br />
The use of data-driven approaches in robotics has increased over the last decade. Instead of relying on hand-designed models, these approaches learn, from large-scale datasets, policies that map high-dimensional observations to actions. Since collecting data with a real robot in real time is very expensive, most data-driven approaches in robotics use simulators to collect simulated data. The concern is whether the resulting models are robust enough to domain shift to work on real-world data, given the wide reality gap between simulators and the real world.<br />
<br />
This has motivated the robotics community to increase its efforts to collect real-world physical interaction data for a variety of tasks, an effort accelerated by the declining cost of hardware. This approach has been quite successful for tasks such as grasping, pushing, poking, and imitation learning. However, the performance of these learned models is still not good enough and tends to plateau quickly. Furthermore, robotic action data has not led to gains comparable to those that large datasets brought to other areas such as computer vision and natural language processing. The paper argues that the remedy is more diverse "real data": current robotic datasets lack environmental diversity, so learning-based approaches need to move out of simulators and labs into real environments, such as real homes, and learn from the data collected there.<br />
<br />
Collecting real-world data in homes is difficult for several reasons. First, it requires cheap and compact robots, whereas current industrial robots (e.g., Sawyer and Baxter) are too expensive. Second, cheap robots are not accurate enough to collect reliable data. Third, constant supervision of data collection is not available in homes. Finally, home robotics faces a circular dependency: real-world data is needed to improve current robots, but current robots are not good enough to collect reliable data in homes. These challenges, along with other external factors, are likely to make the collected data noisy. This paper presents the first systematic effort to collect a dataset inside homes. To accomplish this, the authors:<br />
<br />
1. Build a cheap robot costing less than USD 3K which is appropriate for use in homes<br />
<br />
2. Collect training data in 6 different homes and testing data in 3 homes<br />
<br />
3. Propose a method for modelling the noise in the labelled data<br />
<br />
4. Demonstrate that the diversity in the collected data provides superior performance and requires little-to-no domain adaptation<br />
<br />
[[File:aa1.PNG|600px|thumb|center|]]<br />
<br />
==Overview==<br />
<br />
This paper emphasizes the importance of data diversity for robot learning in order to achieve better generalization, focusing on the task of grasping. A diverse dataset also helps remove biases in the data. On this basis, the paper argues that even for simple tasks like grasping, datasets collected in labs suffer from strong biases, such as simple backgrounds and identical environment dynamics. Hence, models learned from such data do not generalize well to real-world settings.<br />
<br />
Scaling data collection to a large number of homes in the future would require a low-cost robot. For this reason, the authors built a customized mobile manipulator: a Dobot Magician robotic arm mounted on a Kobuki, a low-cost mobile robot base equipped with sensors such as bumper contact sensors and wheel encoders. The resulting arm has five degrees of freedom (DOF) (x, y, z, roll, pitch). The gripper is a two-fingered electric gripper with a 0.3kg payload. They also add an Intel R200 RGBD camera mounted at a height of 1m above the ground. An onboard laptop with an Intel Core i5 processor performs all the processing. The whole system can run for 1.5 hours on a single charge.<br />
<br />
The low cost comes at the price of control accuracy. Because the robot is built from components much cheaper than those of expensive setups such as Baxter and Sawyer, it suffers from larger calibration and execution errors. As a result, the dataset collected this way is large and diverse but has noisy labels. To illustrate, suppose the robot is commanded to grasp at location <math> {(x, y)}</math>. Because of execution noise, it may actually perform the grasp at <math> {(x + \delta_{x}, y+ \delta_{y})}</math>, so the success or failure label of this action is assigned to the wrong location. To address this, the authors learn from the noisy data directly: they model the noise as a latent variable and use two networks, one that predicts the noise and one that predicts the action to execute.<br />
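The labelling problem can be made concrete with a small simulation. This is a sketch under stated assumptions, not the authors' code: the Gaussian execution noise, its scale `sigma_px`, and the circular "graspable" region are all hypothetical. It illustrates that the label is stored at the commanded <math>{(x, y)}</math> while the outcome was decided at the executed location.<br />

```python
import random

def execute_grasp(commanded_xy, sigma_px=5.0, rng=None):
    """Return the location actually reached when the robot is commanded to commanded_xy.

    Hypothetical noise model: independent Gaussian offsets (delta_x, delta_y), in pixels.
    """
    rng = rng if rng is not None else random.Random()
    x, y = commanded_xy
    return (x + rng.gauss(0.0, sigma_px), y + rng.gauss(0.0, sigma_px))

def record_datapoint(commanded_xy, grasp_succeeds, sigma_px=5.0, rng=None):
    """Simulate one grasp and log its label at the commanded location.

    grasp_succeeds is a hypothetical oracle saying whether a grasp at the
    executed location succeeds.
    """
    executed_xy = execute_grasp(commanded_xy, sigma_px, rng)
    g = int(grasp_succeeds(executed_xy))
    # The dataset only knows the commanded location, so g is attached to it
    # even though the grasp actually happened at executed_xy.
    return {"xy": commanded_xy, "g": g}

# Hypothetical scene: grasps succeed within 6 px of an object centred at (100, 100).
def on_object(p):
    return (p[0] - 100) ** 2 + (p[1] - 100) ** 2 <= 6 ** 2

rng = random.Random(0)
data = [record_datapoint((104.0, 100.0), on_object, sigma_px=5.0, rng=rng) for _ in range(1000)]
failure_rate = 1 - sum(d["g"] for d in data) / len(data)
```

Every record is stored at (104, 100), yet some of them fail only because execution drifted away from the object; this is the label noise that the latent-variable model is meant to absorb.<br />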
<br />
==Learning on low-cost robot data==<br />
<br />
This paper uses a patch grasping framework in its proposed architecture. As mentioned before, datasets collected by cheap, inaccurate robots tend to have noisy labels. The noise can come from hardware execution error, inaccurate kinematics, camera calibration, proprioception, wear and tear, etc. The following subsections describe the parts of the architecture that disentangle the noise between the low-cost robot's commanded and actual executions.<br />
<br />
===Grasping Formulation===<br />
<br />
Planar grasping is the object of interest in this architecture: all objects are grasped at the same height and vertically to the ground (i.e., with a fixed end-effector pitch). The goal is to find <math>{(x, y, \theta)}</math> given an observation <math> {I}</math> of the object, where <math> {x}</math> and <math> {y}</math> are the translational degrees of freedom and <math> {\theta}</math> is the rotational degree of freedom (roll of the end-effector). For comparison purposes, they used a model that does not predict <math>{(x, y, \theta)}</math> directly from the image <math> {I}</math>, but instead samples several smaller patches <math> {I_{P}}</math> at different locations <math>{(x, y)}</math>; the grasp angle <math> {\theta}</math> is then predicted from these patches. To allow multi-modal predictions, the angle <math> {\theta}</math> is discretized into steps <math> {\theta_{D}}</math>.<br />
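A minimal sketch of the two preprocessing steps above: cropping a patch <math>{I_P}</math> around a candidate <math>{(x, y)}</math> and discretizing <math>{\theta}</math> into <math>{\theta_D}</math>. The patch size, the zero padding at image borders, and the choice of 18 bins over <math>[0, \pi)</math> (the setting used by Pinto and Gupta [4]) are assumptions here, since the summary does not fix them.<br />

```python
import math

NUM_ANGLE_BINS = 18  # assumption: 18 bins of 10 degrees, as in Pinto and Gupta [4]

def discretize_angle(theta, num_bins=NUM_ANGLE_BINS):
    """Map a gripper roll angle (radians) to a discrete bin theta_D.

    A two-finger grasp at theta is the same as at theta + pi,
    so the angle is first wrapped into [0, pi).
    """
    wrapped = theta % math.pi
    bin_width = math.pi / num_bins
    return min(int(wrapped // bin_width), num_bins - 1)

def crop_patch(image, x, y, size):
    """Return the size x size patch I_P centred on pixel (x, y), zero-padded at borders.

    image is a list of rows (H x W); x indexes columns and y indexes rows.
    """
    half = size // 2
    h, w = len(image), len(image[0])
    patch = []
    for r in range(y - half, y - half + size):
        row = []
        for c in range(x - half, x - half + size):
            row.append(image[r][c] if 0 <= r < h and 0 <= c < w else 0)
        patch.append(row)
    return patch

# Usage on a tiny 4 x 4 "image".
img = [[r * 4 + c for c in range(4)] for r in range(4)]
patch = crop_patch(img, 1, 1, 3)
theta_d = discretize_angle(math.radians(95))
```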
<br />
Hence, each datapoint consists of an image <math> {I}</math>, the executed grasp <math>{(x, y, \theta)}</math>, and the grasp success/failure label <math> g </math>. The image <math> {I}</math> and the angle <math> {\theta}</math> are converted into an image patch <math> {I_{P}}</math> and a discrete angle <math> {\theta_{D}}</math>. A binary cross entropy loss between the predicted and ground-truth label <math> g </math> is then minimized. A convolutional neural network with weights initialized from pre-training on ImageNet is used for this formulation.<br />
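One common way to realize this loss for patch-based grasping, following Pinto and Gupta [4], is to let the network emit one logit per discrete angle and apply binary cross entropy only to the logit of the executed angle <math>{\theta_D}</math>, since the other angles were never tried. The output layout and the function names below are assumptions for illustration, written in plain Python rather than PyTorch so that the arithmetic is visible.<br />

```python
import math

def sigmoid(z):
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)

def grasp_bce_loss(angle_logits, theta_d, g, eps=1e-12):
    """Binary cross entropy for one datapoint (I_P, theta_D, g).

    angle_logits holds one logit per discrete angle (hypothetical output layout).
    Only the executed bin theta_d has a label, so the other bins do not affect the loss.
    """
    p = sigmoid(angle_logits[theta_d])
    p = min(max(p, eps), 1.0 - eps)  # keep log() finite
    return -(g * math.log(p) + (1 - g) * math.log(1.0 - p))

# Usage: 18 angle bins, grasp executed in bin 3 and succeeded (g = 1).
loss = grasp_bce_loss([0.0] * 18, theta_d=3, g=1)  # log(2), about 0.693
```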
<br />
(Note: On Cross Entropy:<br />
<br />
If we think of a distribution as the tool we use to encode symbols, then entropy measures the number of bits we need if we use the correct tool. This is optimal: we cannot encode the symbols using fewer bits on average.
In contrast, cross entropy is the number of bits we need if we encode symbols from <math> {y}</math> using the wrong tool <math> {\hat y}</math>. This amounts to encoding the <math> {i_{th}}</math> symbol using <math> {\log(\frac{1}{{\hat y_i}})}</math> bits instead of <math> {\log(\frac{1}{{ y_i}})}</math> bits. We still take the expected value with respect to the true distribution <math> {y}</math>, since it is the distribution that actually generates the symbols:<br />
<br />
<math>H(y,\hat y) = \sum_i y_i \log \frac{1}{\hat y_i}</math><br />
<br />
Cross entropy is always larger than entropy: encoding symbols according to the wrong distribution <math> {\hat y}</math> always uses more bits on average. The only exception is the trivial case where <math> {y}</math> and <math> {\hat y}</math> are equal, in which case entropy and cross entropy coincide.)<br />
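A small numeric check of the note above, in bits (base-2 logarithms). The two distributions are arbitrary examples chosen for illustration.<br />

```python
import math

def entropy(y):
    """H(y) = sum_i y_i * log2(1 / y_i), in bits; terms with y_i = 0 contribute 0."""
    return sum(p * math.log2(1.0 / p) for p in y if p > 0)

def cross_entropy(y, y_hat):
    """H(y, y_hat) = sum_i y_i * log2(1 / y_hat_i): symbols drawn from y, code built from y_hat."""
    return sum(p * math.log2(1.0 / q) for p, q in zip(y, y_hat) if p > 0)

y = [0.5, 0.5]        # true distribution
y_hat = [0.25, 0.75]  # "wrong tool"
print(entropy(y))               # 1.0
print(cross_entropy(y, y_hat))  # about 1.21, larger than H(y)
```

Binary cross entropy, used for the grasp label <math>g</math>, is the two-symbol case with <math>y = (g, 1-g)</math> and <math>\hat y = (p, 1-p)</math>, usually computed with natural logarithms instead of base 2.<br />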
<br />
===Modeling noise as latent variable===<br />
<br />
To handle the inaccurate position control and calibration of the cheap robots, the authors identify a structure in the noise that depends on the robot and its design. They model this structure as a latent variable and decouple it during training. The approach is shown in Figure 2:<br />
<br />
<br />
[[File:aa2.PNG|600px|thumb|center|]]<br />
<br />
The conventional approach models the grasp success probability for a given image patch at a given angle and ignores the environment variables that can introduce noise into the system; this is reasonable for expensive commercial robots, whose high accuracy makes such noise insignificant. In the low-cost setting, however, with multiple robots collecting data in parallel, this noise becomes an important consideration for learning.<br />
<br />
The grasp success probability for image patch <math> {I_{P}}</math> at angle <math> {\theta_{D}}</math> is represented as <math> {P(g|I_{P},\theta_{D}; \mathcal{R} )}</math> where <math> \mathcal{R}</math> represents environment variables that can add noise to the system.<br />
<br />
The conditional probability of grasping at a noisy image patch <math>I_P</math> for this model is computed by:<br />
<br />
<br />
<math>P(g \mid I_P, \theta_D, \mathcal{R}) = \sum_{\widehat{I_P} \in \mathcal{P}} P(g \mid z = \widehat{I_P}, \theta_D, \mathcal{R}) \cdot P(z = \widehat{I_P} \mid \theta_D, I_P, \mathcal{R})</math><br />
<br />
<br />
Here, <math> {z}</math> is the latent variable denoting the patch that was actually executed, and <math>\widehat{I_P}</math> ranges over a set of possible neighboring patches <math> \mathcal{P}</math>. <math> P(z=\widehat{I_P}|\theta_D,I_P,\mathcal{R})</math> models the noise caused by the <math>\mathcal{R}</math> variables and is implemented as the Noise Modelling Network (NMN). <math> {P(g|z=\widehat{I_P},\theta_{D}, \mathcal{R} )}</math> is the grasp success probability given the true patch and is implemented as the Grasp Prediction Network (GPN). The overall Robust-Grasp model is obtained by marginalizing over <math> {z}</math>, i.e., by combining the outputs of GPN and NMN as in the sum above.<br />
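The marginalization above is a weighted average of GPN predictions, with the weights given by the NMN. A minimal sketch, assuming both networks have already produced their outputs for the same list of candidate patches <math>\mathcal{P}</math> at a fixed <math>\theta_D</math>; the function name and input format are hypothetical.<br />

```python
def robust_grasp_probability(nmn_probs, gpn_probs):
    """P(g=1 | I_P, theta_D, R) = sum_k P(g=1 | z=k, theta_D, R) * P(z=k | R).

    nmn_probs[k]: NMN probability that candidate patch k was the one actually executed.
    gpn_probs[k]: GPN success probability for candidate patch k at angle theta_D.
    """
    if len(nmn_probs) != len(gpn_probs):
        raise ValueError("need one NMN weight per candidate patch")
    if abs(sum(nmn_probs) - 1.0) > 1e-6:
        raise ValueError("NMN output must be a probability distribution")
    return sum(w * p for w, p in zip(nmn_probs, gpn_probs))

# Usage: the commanded patch first, then 8 neighbours (made-up numbers).
nmn = [0.6] + [0.05] * 8
gpn = [0.2] + [0.9] * 8
p = robust_grasp_probability(nmn, gpn)  # about 0.48
```

In this example the commanded patch alone looks like a poor grasp (0.2), but because the NMN believes the robot often lands on a neighboring patch, the marginal success probability rises to about 0.48.<br />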
<br />
===Learning the latent noise model===<br />
<br />
This section concerns what the inputs to the NMN should be and how it can be trained. The authors assume that <math> {z}</math> is conditionally independent of the local patch-specific variables <math> {(I_{P}, \theta_{D})}</math> given the global information <math>\mathcal{R}</math>, i.e., <math> P(z=\widehat{I_P}|\theta_D,I_P,\mathcal{R}) \equiv P(z=\widehat{I_P}|\mathcal{R})</math>. Rather than the local patch <math> I_{P} </math>, the NMN therefore uses global information: the image of the entire scene, the ID of the robot, and the raw-pixel grasp location. The authors argue that the image of the full scene can contain essential information about the system, such as the location of the camera relative to the ground, which may change over the lifetime of the robot. Explicit labels for <math>z</math> are not available to train the NMN; although <math>z</math> could be estimated with a technique such as Expectation-Maximization, the authors use direct optimization to learn both NMN and GPN from the noisy labels. The output of the NMN is a probability distribution over the patches where the grasp may actually have been executed. Finally, a binary cross entropy loss is applied between the marginalized output of the two networks and the observed grasp label g.<br />
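Because only the marginal prediction is compared with <math>g</math>, the NMN never needs labels for <math>z</math>: the loss depends on the NMN output through the marginal, so minimizing it trains both networks. A minimal sketch of this loss, assuming the NMN emits one unnormalized score per candidate patch that a softmax turns into <math>P(z|\mathcal{R})</math>; the names are hypothetical, and gradients (which a framework such as PyTorch would compute by autograd) are omitted.<br />

```python
import math

def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def robust_grasp_loss(nmn_logits, gpn_probs, g, eps=1e-12):
    """BCE between the marginalized prediction and the noisy label g."""
    z_post = softmax(nmn_logits)                          # P(z = patch_k | R)
    p = sum(w * q for w, q in zip(z_post, gpn_probs))     # marginal P(g = 1 | ...)
    p = min(max(p, eps), 1.0 - eps)
    return -(g * math.log(p) + (1 - g) * math.log(1.0 - p))

# Usage: the GPN finds only patch 0 graspable; the grasp succeeded (g = 1).
gpn = [0.9] + [0.1] * 8
good = robust_grasp_loss([10.0] + [0.0] * 8, gpn, g=1)       # NMN points at patch 0
bad = robust_grasp_loss([0.0, 10.0] + [0.0] * 7, gpn, g=1)   # NMN points at patch 1
```

A successful grasp is cheaper to explain when the NMN puts its mass on a patch the GPN considers graspable (`good < bad`), which is the signal that pushes the NMN toward plausible corrections.<br />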
<br />
===Training details===<br />
<br />
They implemented their model in PyTorch using a pretrained ResNet-18 model. For the NMN, they concatenated the 512-dimensional ResNet feature with a one-hot vector of the robot ID and the raw pixel location of the grasp. The inputs of the GPN are the original noisy patch plus 8 other patches equidistant from the original one.<br />
Training starts with only the GPN for 5 epochs. Then the NMN and the marginalization operator are added, and NMN and GPN are trained jointly for the remaining 25 epochs.<br />
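The input layout and training schedule described above can be sketched as follows. Only the 512-dimensional feature, the one-hot robot ID, the raw pixel location, the 8 extra patches, and the 5 + 25 epoch split come from the summary; placing the 8 patches on a circle of a given `radius`, the number of robots, and all function names are assumptions.<br />

```python
import math

def nmn_input(scene_feature, robot_id, num_robots, grasp_px):
    """Build the NMN input: [512-d scene feature | one-hot robot ID | raw grasp pixel (x, y)]."""
    if len(scene_feature) != 512:
        raise ValueError("expected a 512-d ResNet-18 feature of the full scene")
    one_hot = [0.0] * num_robots
    one_hot[robot_id] = 1.0
    return list(scene_feature) + one_hot + [float(grasp_px[0]), float(grasp_px[1])]

def candidate_patch_centres(x, y, radius):
    """Centres of the 9 GPN inputs: the original noisy patch plus 8 patches at equal distance.

    Hypothetical layout: the 8 extra centres are spread evenly on a circle of `radius` pixels.
    """
    centres = [(x, y)]
    for k in range(8):
        angle = 2.0 * math.pi * k / 8
        centres.append((x + radius * math.cos(angle), y + radius * math.sin(angle)))
    return centres

def trainable_networks(epoch, gpn_only_epochs=5):
    """Networks updated at a given 0-based epoch: GPN alone first, then GPN and NMN jointly."""
    return ("GPN",) if epoch < gpn_only_epochs else ("GPN", "NMN")

# Usage with made-up values: robot 2 of 6, grasp at pixel (10, 20), 30 epochs in total.
vec = nmn_input([0.0] * 512, robot_id=2, num_robots=6, grasp_px=(10, 20))
centres = candidate_patch_centres(50.0, 60.0, radius=8.0)
schedule = [trainable_networks(e) for e in range(30)]
```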
<br />
==Results==<br />
<br />
In the results section, the authors show that collecting data in homes is essential for generalizing to unseen environments. They also show that modelling the noise of their Low-Cost Arm (LCA) improves grasping performance.<br />
They collected data in parallel using multiple robots in 6 different homes, as shown in Figure 3. Because home scenes are unstructured, they used an object detector, choosing tiny-YOLO because of the LCA's limited memory and computational capabilities. Once an object was detected, its class information was discarded and a grasp was attempted. The 3D grasp location was computed from point cloud data. To prevent the robot from colliding with obstacles, they scattered objects within a 2m area in each home, then let the robot move randomly and grasp objects. In total, they collected a dataset of 28K grasp results.<br />
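Computing a 3D grasp location from depth data typically amounts to back-projecting the chosen pixel through the camera's pinhole model. The sketch below shows that standard step; the intrinsic values are hypothetical, and it is not claimed to be the authors' pipeline, which worked from point clouds. Mapping the camera-frame point into the arm's frame additionally needs the camera extrinsics, which is where the calibration errors discussed earlier enter.<br />

```python
def pixel_to_camera_point(u, v, depth_m, fx, fy, cx, cy):
    """Back-project pixel (u, v) with depth (metres) into the camera frame (pinhole model).

    fx, fy: focal lengths in pixels; cx, cy: principal point. All values are camera-specific.
    """
    if depth_m <= 0:
        raise ValueError("missing or invalid depth reading")
    x = (u - cx) * depth_m / fx
    y = (v - cy) * depth_m / fy
    return (x, y, depth_m)

# Usage with hypothetical 640 x 480 intrinsics: a detection 300 px right of centre, 2 m away.
point = pixel_to_camera_point(620, 240, 2.0, fx=600.0, fy=600.0, cx=320.0, cy=240.0)  # (1.0, 0.0, 2.0)
```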
<br />
[[File:aa3.PNG|600px|thumb|center|]]<br />
<br />
To evaluate their approach quantitatively, they used three test settings:<br />
<br />
- The first is binary classification on held-out data. The test set is collected by performing random grasps on objects, and performance is measured by predicting the success or failure of a grasp given its location and angle. Binary classification allows many models to be tested without running them on real robots. They used held-out datasets collected by the LCA in the lab and in homes, as well as a held-out dataset for the Baxter robot.<br />
<br />
- The second is the Real Low-Cost Arm (Real-LCA) setting. Here, they evaluate their model by running it in three unseen homes, with 20 new objects placed in different orientations. Since both the objects and the environments are new, this test measures how well the model generalizes.<br />
<br />
- The third is Real Sawyer (Real-Sawyer). They evaluate the model by running it on the Sawyer robot, which is more accurate than the LCA. They tested it in the lab environment to show that training on data collected in homes can improve performance even in lab environments.<br />
<br />
They compared against baselines for both their data (collected in homes) and their model (Robust-Grasp). Two datasets serve as data baselines: the dataset collected with a Baxter robot in the lab (Lab-Baxter) and the dataset collected by their LCA in the lab (Lab-LCA).<br />
They compared their Robust-Grasp model with the noise-independent patch grasping model (Patch-Grasp) [4]. They also compared their data and model with DexNet-3.0 (DexNet) as a strong real-world grasping baseline.<br />
<br />
===Experiment 1: Performance on held-out data===<br />
<br />
Table 1 shows that the models trained on lab data do not generalize to the Home-LCA environment (i.e., they overfit to their respective environments and attain lower binary classification scores). In contrast, the model trained on Home-LCA performs well on both lab data and home data.<br />
<br />
[[File:aa4.PNG|600px|thumb|center|]]<br />
<br />
===Experiment 2: Performance on Real LCA Robot===<br />
<br />
In Table 2, the performance of the model trained on Home-LCA is compared against a pre-trained DexNet and the model trained on Lab-Baxter. Training on the Home-LCA dataset performs 43.7% better than training on the Lab-Baxter dataset and 33% better than DexNet. The low performance of DexNet may be explained by noise in the depth images caused by natural light: DexNet requires high-quality depth sensing and cannot perform well in these scenarios. Because the LCA data are collected with cheap commodity RGBD cameras, the model trained on them does not expect high-quality sensing, so noisy depth images are much less of a concern.<br />
<br />
[[File:aa5.PNG|600px|thumb|center|]]<br />
<br />
===Experiment 3: Performance on Real Sawyer===<br />
<br />
To compare the Robust-Grasp model against the Patch-Grasp model without collecting noise-free data, they used Lab-Baxter for benchmarking, since Baxter is a more accurate and better-calibrated robot. The Sawyer robot is used for testing to ensure that the test robot differs from both training robots. As shown in Table 3, the Robust-Grasp model trained on Home-LCA outperforms the Patch-Grasp model and achieves 77.5% accuracy. This accuracy is similar to that reported in several recent papers, even though this model was trained and tested in different environments. The Robust-Grasp model also outperforms Patch-Grasp by about 4% on binary classification. Furthermore, the visualizations of predicted noise corrections in Figure 4 show that the corrections depend on both the pixel location of the noisy grasp and the robot.<br />
<br />
[[File:aa6.PNG|600px|thumb|center|]]<br />
<br />
[[File:aa7.PNG|600px|thumb|center|]]<br />
<br />
==Related work==<br />
<br />
Over the last few years, interest in scaling up robot learning with large-scale datasets has grown, and many papers have been published in this area. Examples for grasping include a hand-annotated grasping dataset, a self-supervised grasping dataset, and grasping with reinforcement learning. These works used high-cost hardware and data labeling mechanisms. Many papers also addressed other robotic tasks, such as material recognition, pushing objects, and manipulating a rope. However, none of these papers used real data from real environments like homes; they all used lab data.<br />
<br />
Furthermore, since grasping is one of the basic problems in robotics, there have been many efforts to improve it. Classical approaches focused on the physics of grasping and required 3D models of the objects. Recent works instead focus on data-driven approaches that learn to grasp objects from visual observations. Both simulation and real-world robots have been used for large-scale data collection. One versatile grasping model was reported to achieve 90% performance on a bin-picking task. However, such approaches usually require high-quality depth as input, and the high hardware cost of high-quality depth sensing is a barrier to the practical use of robots in real environments.<br />
<br />
Most labs use industrial robots or standard collaborative hardware for their experiments, so there is little research using low-cost robots. One example is learning to stack multiple blocks with a cheap, inaccurate robot. Although mobile robots like iRobot's Roomba have been on the home consumer electronics market for a decade, it is not clear whether they use learning approaches alongside mapping and planning.<br />
<br />
Learning from noisy inputs is another challenge, especially in computer vision. A controversial question in this area is whether learning from noise can improve performance. Some works show that noise can hurt performance; others find it valuable when the noise is independent of, or statistically dependent on, the environment. In this paper, the authors use a model that can exploit the noise to learn a better grasping model.<br />
<br />
==Conclusion==<br />
<br />
All in all, the paper presents an approach for collecting large-scale robot data in real home environments. They implemented their approach by using a mobile manipulator which is a lot cheaper than the existing industrial robots. They collected a dataset of 28K grasps in six different homes. In order to solve the problem of noisy labels which were caused by their inaccurate robots, they presented a framework to factor out the noise in the data. They tested their model by physically grasping 20 new objects in three new homes and in the lab. The model trained with home dataset showed 43.7% improvement over the models trained with lab data. Their results also showed that their model can improve the grasping performance even in lab environments. They also demonstrated that their architecture for modeling the noise improved the performance by about 10%.<br />
<br />
==Critiques==<br />
<br />
This paper does not contain a significant algorithmic contribution; it mainly combines a large number of data engineering techniques for the robot learning problem. The authors claim 43.7% higher accuracy than baseline models, but the comparison does not seem fair, since the data for the other methods was collected in controlled lab settings, whereas the authors use their home dataset. The authors should also have discussed safety issues when training robots in real environments as opposed to controlled environments like labs: they encourage other researchers to look outside the labs but do not discuss the critical safety issues of this approach.<br />
<br />
Another strange finding is that the paper says it follows "a model architecture similar to [Pinto and Gupta [4]]," whereas the proposed model is in fact a fine-tuned ResNet-18 architecture. Pinto and Gupta implement a network similar to AlexNet, as shown below in Figure 5.<br />
<br />
[[File:Figure_5_PandG.JPG | 450px|thumb|center|Figure 5: AlexNet architecture implemented in Pinto and Gupta [4].]]<br />
<br />
<br />
The paper argues that the dataset collected by the LCA is noisy, since the robot is cheap and inaccurate, and that modelling this noise as a latent variable improves grasping performance. Although learning from noisy data and achieving good performance is valuable, it would be better to test the noise modelling network on other robots as well. Since the network takes robot information as an input, testing it on different inaccurate robots would show whether it generalizes.<br />
<br />
They also omit other aspects of the comparison, such as their training time relative to other models or the sizes of the other datasets.<br />
<br />
==References==<br />
<br />
#Josh Tobin, Rachel Fong, Alex Ray, Jonas Schneider, Wojciech Zaremba, and Pieter Abbeel. "Domain randomization for transferring deep neural networks from simulation to the real world." 2017. URL https://arxiv.org/abs/1703.06907.<br />
#Xue Bin Peng, Marcin Andrychowicz, Wojciech Zaremba, and Pieter Abbeel. "Sim-to-real transfer of robotic control with dynamics randomization." arXiv preprint arXiv:1710.06537, 2017.<br />
#Lerrel Pinto, Marcin Andrychowicz, Peter Welinder, Wojciech Zaremba, and Pieter Abbeel. "Asymmetric actor-critic for image-based robot learning." Robotics Science and Systems, 2018.<br />
#Lerrel Pinto and Abhinav Gupta. "Supersizing self-supervision: Learning to grasp from 50k tries and 700 robot hours." CoRR, abs/1509.06825, 2015. URL http://arxiv.org/abs/1509.06825.<br />
#Adithyavairavan Murali, Lerrel Pinto, Dhiraj Gandhi, and Abhinav Gupta. "CASSL: Curriculum accelerated self-supervised learning." International Conference on Robotics and Automation, 2018.<br />
# Sergey Levine, Chelsea Finn, Trevor Darrell, and Pieter Abbeel. "End-to-end training of deep visuomotor policies." The Journal of Machine Learning Research, 17(1):1334–1373, 2016.<br />
#Sergey Levine, Peter Pastor, Alex Krizhevsky, and Deirdre Quillen. "Learning hand-eye coordination for robotic grasping with deep learning and large-scale data collection." CoRR, abs/1603.02199, 2016. URL http://arxiv.org/abs/1603.02199.<br />
#Pulkit Agrawal, Ashvin Nair, Pieter Abbeel, Jitendra Malik, and Sergey Levine. "Learning to poke by poking: Experiential learning of intuitive physics." 2016. URL http://arxiv.org/abs/1606.07419.<br />
#Chelsea Finn, Ian Goodfellow, and Sergey Levine. "Unsupervised learning for physical interaction through video prediction." In Advances in neural information processing systems, 2016.<br />
#Ashvin Nair, Dian Chen, Pulkit Agrawal, Phillip Isola, Pieter Abbeel, Jitendra Malik, and Sergey Levine. "Combining self-supervised learning and imitation for vision-based rope manipulation." International Conference on Robotics and Automation, 2017.<br />
#Chen Sun, Abhinav Shrivastava, Saurabh Singh, and Abhinav Gupta. "Revisiting unreasonable effectiveness of data in deep learning era." ICCV, 2017.<br />
#Marc Peter Deisenroth, Carl Edward Rasmussen, and Dieter Fox. Learning to control a low-cost manipulator using data-efficient reinforcement learning. RSS, 2011.<br />
#David F Nettleton, Albert Orriols-Puig, and Albert Fornells. A study of the effect of different types of noise on the precision of supervised learning techniques. Artificial intelligence review, 33(4):275–306, 2010.<br />
#Benoît Frénay and Michel Verleysen. Classification in the presence of label noise: a survey. IEEE transactions on neural networks and learning systems, 25(5):845–869, 2014.<br />
#Tong Xiao, Tian Xia, Yi Yang, Chang Huang, and Xiaogang Wang. Learning from massive noisy labeled data for image classification. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pages 2691–2699, 2015.</div>Gsahuhttp://wiki.math.uwaterloo.ca/statwiki/index.php?title=Attend_and_Predict:_Understanding_Gene_Regulation_by_Selective_Attention_on_Chromatin&diff=41974Attend and Predict: Understanding Gene Regulation by Selective Attention on Chromatin2018-11-30T00:25:40Z<p>Gsahu: /* Conclusion */</p>
<hr />
<div>This page contains a summary of the paper [https://arxiv.org/abs/1708.00339 "Attend and Predict: Understanding Gene Regulation by Selective Attention on Chromatin."] by Singh, Ritambhara, et al. It was published at the Advances in Neural Information Processing Systems (NIPS) in 2017. The code for this paper is shared here[https://qdata.github.io/deep4biomed-web/].<br />
<br />
<br />
= Background =<br />
<br />
Gene regulation is the process of controlling which genes in a cell's DNA are turned 'on' (expressed) or 'off' (not expressed). When a gene is expressed, a functional product such as a protein is created. Even though all the cells of a multicellular organism (e.g., humans) contain the same DNA, different types of cells in that organism may express very different sets of genes. As a result, each cell type has distinct functionality; in other words, how a cell operates depends on the genes expressed in that cell. Many factors, including chromatin modification marks, influence which genes are expressed in a cell.<br />
<br />
The function of chromatin is to efficiently wrap DNA around bead-like histone structures into a condensed volume that fits into the nucleus of a cell, and to protect the DNA structure and sequence during cell division and replication. Different chemical modifications of the histones, known as histone marks, change the spatial arrangement of the condensed DNA structure, which in turn affects the expression of genes in the histone mark's neighboring region. Histone marks can promote (or obstruct) gene activation by making the gene region accessible (or restricted). The section of DNA where histone marks can potentially have an impact is known as the DNA flanking region or 'gene region', which is considered to cover 10k base pairs centered at the transcription start site (TSS), i.e., 5k base pairs in each direction. Unlike genetic mutations, histone modifications are reversible [1]. Therefore, understanding the influence of histone marks on gene regulation can assist in developing drugs for genetic diseases.<br />
<br />
= Introduction = <br />
<br />
Advances in genomic technologies now make it possible to profile genome-wide chromatin mark signals, so biologists can measure gene expression and chromatin signals of the 'gene region' for different cell types across the whole human genome. The Roadmap Epigenomics Mapping Consortium (REMC, publicly available) [2] recently released 2,804 genome-wide datasets of 100 separate “normal” (not diseased) human cells/tissues, of which 166 datasets are gene expression reads and the rest are signal reads of various histone marks. The goal is to understand which histone marks are the most important and how they interact in gene regulation for each cell type.<br />
<br />
Signal reads for histone marks are high-dimensional and spatially structured. A histone modification mark can exert its influence anywhere in the gene region (10k base pairs centered around the transcription start site of each gene), so it is important to understand how the impact of the mark on gene expression varies over the gene region, i.e., how histone signals over the gene region impact gene expression. Different types of histone marks in human chromatin can influence gene regulation. Researchers have identified five standard histone proteins, which can be altered by different combinations of chemical modifications, resulting in a large number of distinct histone modification marks. Different histone modification marks can act as a module, interacting with each other to influence gene expression.<br />
<br />
<br />
This paper proposes an attention-based deep learning model to determine how these chromatin factors (histone modification marks) contribute to the gene expression of a particular cell. AttentiveChrome [3] uses a hierarchy of multiple LSTMs to discover interactions between the signals of each histone mark and to learn dependencies among the marks in expressing a gene. The authors include two levels of soft attention: (1) to attend to the most relevant signals of each histone mark, and (2) to attend to the important marks and their interactions. In this context, ''attention'' refers to weighting the importance of different items differently.<br />
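Soft attention of the kind described above can be sketched as a softmax over per-item scores followed by a weighted sum; in AttentiveChrome it is applied at two levels, over the bins of one mark and over the marks. The dot-product scoring against a single vector `w` is an assumption for illustration, not necessarily the paper's exact parameterization.<br />

```python
import math

def soft_attention(hidden_states, w):
    """Return (alphas, context) for a list of hidden vectors h_t.

    Hypothetical scoring: score_t = w . h_t, then alphas = softmax(scores)
    and context = sum_t alphas_t * h_t.
    """
    scores = [sum(wi * hi for wi, hi in zip(w, h)) for h in hidden_states]
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    alphas = [e / total for e in exps]
    dim = len(hidden_states[0])
    context = [sum(a * h[d] for a, h in zip(alphas, hidden_states)) for d in range(dim)]
    return alphas, context

# Usage: two items with equal scores receive equal attention.
alphas, context = soft_attention([[1.0, 0.0], [0.0, 1.0]], [0.0, 0.0])
```

The weights `alphas` are what the authors inspect for interpretation: a large weight on a bin or mark means the model relied on it for that gene's prediction.<br />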
<br />
== Main Contributions ==<br />
The contributions of this work can be summarized as follows:<br />
<br />
* More accurate predictions than state-of-the-art baselines, measured on REMC datasets for 56 different cell types.<br />
* Better interpretation than state-of-the-art methods for visualizing deep learning models, assessed by correlating the model's attention scores with a mark signal from REMC.<br />
* As with earlier applications of attention models, which indirectly indicate the parts of the input the model deems important, AttentiveChrome can explain its decisions by indicating “what” and “where” it has focused.<br />
* According to the authors, this is the first application of an attention-based deep learning approach to a problem in molecular biology.<br />
* The ability to handle highly modular inputs.<br />
<br />
= Previous Works = <br />
<br />
Machine learning algorithms that classify gene expression from histone modification signals have been surveyed in [15]. These algorithms range from linear regression, support vector machines, and random forests to rule-based learning and CNNs. To accommodate the spatially structured, high-dimensional input data (histone modification signals), these studies applied different feature selection strategies. The authors' preceding study, DeepChrome [4], used a best-position selection strategy, in which the positions most highly correlated with gene expression are treated as the best positions. This CNN-based model can learn the relationships among the histone marks and outperforms all the previous works. However, these approaches either (1) failed to model the spatial dependencies among the marks or (2) required additional feature analysis. Only AttentiveChrome is reported to satisfy all eight desirable model properties listed by the authors.<br />
<br />
= AttentiveChrome: Model Formulation =<br />
<br />
The authors propose an end-to-end architecture that can simultaneously attend and predict. It uses recurrent neural networks (RNNs) composed of LSTM units to model the sequential spatial dependencies of the gene region and to predict the gene expression level from HM signals. The embedding vector <math> h_t </math> output by an LSTM module at step <math> t </math> encodes the learned representation of the feature dependencies from the first time step up to <math> t </math>. For this task, each bin position of the gene region is treated as a time step.<br />
<br />
The proposed AttentiveChrome framework contains the following five modules:<br />
<br />
* Bin-level LSTM encoder encoding the bin positions of the gene region (one for each HM mark)<br />
* Bin-level <math> \alpha </math>-Attention across all bin positions (one for each HM mark)<br />
* HM-level LSTM encoder (one encoder encoding all HM marks)<br />
* HM-level <math> \beta </math>-Attention among all HM marks (one)<br />
* The final classification module<br />
<br />
Figure 1 (Supplementary Figure 2) presents an overview of the proposed AttentiveChrome framework.<br />
<br />
<br />
[[File:supplemntary_figure_2.png|thumb|center| 800px |Figure 1: Overview of all five modules of the proposed AttentiveChrome framework]]<br />
<br />
<br />
<br />
== Input and Output ==<br />
<br />
Each dataset contains the gene expression labels and the histone signal reads for one specific cell type, and the authors evaluated AttentiveChrome on 56 different cell types. For each mark, a feature (input) vector contains the signal reads surrounding the gene’s TSS position (the gene region), so each histone mark has one feature vector per gene. The label of the input denotes the expression of that gene. This study considers binary labels, where <math> +1 </math> denotes that the gene is expressed (ON) and <math> -1 </math> denotes that it is not expressed (OFF). The authors reuse the input and output formulation of their previous work, DeepChrome [4].<br />
<br />
The input features are represented by a matrix <math> \textbf{X} </math> of size <math> M \times T </math>, where <math> M </math> is the number of HM marks considered and <math> T </math> is the number of bin positions used to represent the gene region. The <math> j^{th} </math> row of the matrix <math> \textbf{X} </math>, <math> x^j </math>, represents the sequentially structured signals of the <math> j^{th} </math> HM mark, where <math> j\in \{1, \cdots, M\} </math>. Therefore, the entry <math> x^j_t </math> of <math> \textbf{X} </math> is the value of the <math> t^{th} </math> bin of the <math> j^{th} </math> HM mark, where <math> t\in \{1, \cdots, T\} </math>. If the training set contains <math>N_{tr} </math> labeled pairs, the <math> n^{th} </math> pair is written as <math>( X^n, y^n)</math>, where <math> X^n </math> is a matrix of size <math> M \times T </math>, <math> y^n \in \{ -1, +1 \} </math> is its binary label, and <math> n \in \{ 1, \cdots, N_{tr} \} </math>.<br />
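<br />
The data layout can be illustrated with a minimal NumPy sketch. It assumes randomly generated Poisson read counts as hypothetical stand-ins for real REMC signals and uses the sizes from the experimental settings (<math>M=5</math>, <math>T=100</math>); note that NumPy indexes from 0, whereas the text indexes marks and bins from 1.<br />
<br />
```python
import numpy as np

M, T = 5, 100  # number of HM marks and number of bins per gene region

def make_sample(rng):
    """Build one hypothetical (X^n, y^n) pair from random read counts."""
    X = rng.poisson(lam=2.0, size=(M, T)).astype(np.float32)  # X^n, shape M x T
    y = int(rng.choice([-1, +1]))                               # y^n in {-1, +1}
    return X, y

rng = np.random.default_rng(0)
N_tr = 4
train = [make_sample(rng) for _ in range(N_tr)]  # N_tr labeled pairs

X, y = train[0]
x_2 = X[1]        # row of the 2nd HM mark (x^j with j = 2): T bin values
x_2_10 = X[1, 9]  # x^j_t with j = 2, t = 10
print(X.shape, x_2.shape)  # (5, 100) (100,)
```
<br />
Real inputs would replace the random counts with the binned REMC signal of each of the five marks for a given gene.<br />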
<br />
Figure 2 (see also Figure 1(a) and 1(b)) shows the input features and the output of AttentiveChrome for one particular gene (one sample).<br />
<br />
[[File:input-output-attentivechrome.png|center|thumb| 700px | Figure 2: Input and Output of the AttentiveChrome model]]<br />
<br />
== Bin-Level Encoder (one LSTM for each HM) ==<br />
For the <math> n^{th} </math> gene, the sequentially ordered bins of the gene region for the <math> j^{th} </math> HM mark are represented by the <math> j^{th} </math> row vector <math> x^j </math>. The authors treat each bin position as a time step for the LSTM. This study uses a bidirectional LSTM to model the dependencies among all <math> T </math> bin positions in the gene region. The bidirectional LSTM contains two LSTMs:<br />
* A forward LSTM, <math> \overrightarrow{LSTM_j} </math>, to model <math> x^j </math> from <math> x_1^j </math> to <math> x_T^j </math>, which outputs the embedding vector <math> \overrightarrow{h^j_t} </math>, of size <math> d </math> for each bin <math> t </math><br />
* A reverse LSTM, <math> \overleftarrow{LSTM_j} </math>, to model <math> x^j </math> from <math> x_T^j </math> to <math> x_1^j </math>, which outputs the embedding vector <math> \overleftarrow{h^j_t} </math>, of size <math> d </math> for each bin <math> t </math><br />
<br />
The final output of this layer, the embedding vector <math> h^j_t </math> at the <math> t^{th} </math> bin for the <math> j^{th} </math> HM, is obtained by concatenating the vectors from both directions, so <math> h^j_t = [ \overrightarrow{h^j_t}, \overleftarrow{h^j_t}]</math> has size <math> 2d </math>. Because these LSTM-based HM encoders are trained jointly with the final classification module, they learn to embed each HM mark by extracting the dependencies among its bins that matter for prediction. Figure 1(c) illustrates this module for <math> j=2 </math>.<br />
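<br />
A minimal NumPy sketch of the bin-level encoder for one HM is given here. It hand-rolls a standard LSTM cell; the gate layout, the Gaussian initialization, and the single scalar input per bin are illustrative assumptions, not details taken from the paper. The point is how the reverse pass is re-aligned to bin order and concatenated with the forward pass so that each <math>h^j_t</math> has size <math>2d</math>. In practice a framework layer such as PyTorch's <code>torch.nn.LSTM</code> with <code>bidirectional=True</code> would be used instead.<br />
<br />
```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def init_lstm(rng, input_size, hidden_size):
    """Randomly initialised LSTM weights; the N(0, 0.1) scheme is an assumption."""
    W = rng.normal(0.0, 0.1, size=(4 * hidden_size, input_size))   # input weights
    U = rng.normal(0.0, 0.1, size=(4 * hidden_size, hidden_size))  # recurrent weights
    b = np.zeros(4 * hidden_size)
    return W, U, b

def run_lstm(params, xs):
    """Run a standard LSTM over xs of shape (T, input_size); return (T, hidden_size)."""
    W, U, b = params
    d = U.shape[1]
    h, c = np.zeros(d), np.zeros(d)
    outputs = []
    for x in xs:                                   # one bin per time step
        z = W @ x + U @ h + b
        i, f, o = sigmoid(z[:d]), sigmoid(z[d:2 * d]), sigmoid(z[2 * d:3 * d])
        g = np.tanh(z[3 * d:])                     # candidate cell update
        c = f * c + i * g
        h = o * np.tanh(c)
        outputs.append(h)
    return np.stack(outputs)

def bin_level_encoder(fwd, bwd, x_j):
    """Encode one HM row x^j (length T) into h^j of shape (T, 2d)."""
    xs = x_j[:, None]                              # scalar read value per bin
    h_fwd = run_lstm(fwd, xs)                      # reads bins 1 .. T
    h_bwd = run_lstm(bwd, xs[::-1])[::-1]          # reads bins T .. 1, re-aligned
    return np.concatenate([h_fwd, h_bwd], axis=1)  # h^j_t = [fwd_t, bwd_t]

rng = np.random.default_rng(0)
d, T = 32, 100
fwd, bwd = init_lstm(rng, 1, d), init_lstm(rng, 1, d)  # separate weights per direction
x_j = rng.poisson(2.0, size=T).astype(float)
H_j = bin_level_encoder(fwd, bwd, x_j)
print(H_j.shape)  # (100, 64)
```
<br />
Row <math>t</math> of the result combines a forward state that has seen bins <math>1, \ldots, t</math> with a reverse state that has seen bins <math>t, \ldots, T</math>, so every bin embedding carries context from the whole gene region.<br />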
<br />
== Bin-Level <math> \alpha</math>-attention ==<br />
<br />
Each bin contributes differently to the encoding of the entire <math> j^{th} </math> mark. To automatically and adaptively highlight the most important bins for prediction, a soft attention weight vector <math> \alpha^j </math> of size <math> T </math> is learned for each <math> j </math>. To calculate the soft weight <math> \alpha^j_t </math> for each <math> t </math>, the embedding vectors <math> \{h^j_1, \cdots, h^j_T \} </math> of all the bins are used, as follows:<br />
<br />
<center><math> \alpha^j_t = \frac{exp(\textbf{W}_b h^j_t)}{\sum_{i=1}^T{exp(\textbf{W}_b h^j_i)}} </math></center><br />
<br />
<br />
<math> \alpha^j_t</math> is a scalar computed from the embedding vectors of all bins of the <math> j^{th} </math> HM. The bin-level context parameter <math> \textbf{W}_b </math> is initialized randomly and learned jointly with the other model parameters. Once the importance weight of each bin position is known, the <math> j^{th} </math> HM mark is represented by <math> m^j = \sum_{t=1}^T{\alpha^j_t \times h^j_t}</math>, where <math> h^j_t</math> is the embedding vector and <math> \alpha^j_t </math> is the importance weight of the <math> t^{th} </math> bin in the representation of the <math> j^{th} </math> HM mark. Intuitively, <math> \textbf{W}_b </math> learns which bins are important for the given cell type. Figure 1(d) shows this module for <math> HM_2 </math>.<br />
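<br />
The <math>\alpha</math>-attention step reduces to a softmax over per-bin scores followed by a weighted sum. The NumPy sketch below assumes <math>\textbf{W}_b</math> is a vector of length <math>2d</math>, so that <math>\textbf{W}_b h^j_t</math> is a scalar; the embeddings are random placeholders rather than outputs of a trained encoder.<br />
<br />
```python
import numpy as np

def alpha_attention(H_j, W_b):
    """Bin-level soft attention for one HM.

    H_j: (T, 2d) bin embeddings h^j_1 .. h^j_T
    W_b: (2d,) bin-level context vector (assumed to be a vector)
    Returns alpha_j with shape (T,) and m_j with shape (2d,).
    """
    scores = H_j @ W_b                    # W_b h^j_t for every bin t
    scores = scores - scores.max()        # stability shift; softmax is unchanged
    weights = np.exp(scores)
    alpha_j = weights / weights.sum()     # alpha^j_t, sums to 1 over bins
    m_j = alpha_j @ H_j                   # m^j = sum_t alpha^j_t h^j_t
    return alpha_j, m_j

rng = np.random.default_rng(1)
T, two_d = 100, 64
H_j = rng.normal(size=(T, two_d))         # placeholder bin embeddings
W_b = rng.normal(size=two_d)              # placeholder context vector
alpha_j, m_j = alpha_attention(H_j, W_b)
print(alpha_j.shape, m_j.shape)  # (100,) (64,)
```
<br />
Subtracting the maximum score before exponentiating does not change the weights (it cancels in the ratio) but avoids overflow when some scores are large.<br />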
<br />
== HM-level Encoder (one LSTM) ==<br />
<br />
Studies have observed that HMs work cooperatively to activate or repress gene expression [5]. The HM-level encoder (not shown in Figure 1) uses one bidirectional LSTM to capture these relationships among the HMs. Because the authors found no influence of any specific ordering of the HMs, the HMs are fed to the LSTM as a sequence in an arbitrary order. The input of this step is the representation <math> m^j </math> of the <math> j^{th} </math> HM, <math> HM_j </math>, computed by the bin-level attention layer. This set-based encoder outputs an embedding vector <math> s^j </math> of size <math> 2d' </math>, the encoding of the <math> j^{th} </math> HM, by concatenating the outputs of the forward and backward LSTMs:<br />
<br />
<math> s^j = [ \overrightarrow{LSTM_s}(m^j), \overleftarrow{LSTM_s}(m^j) ] </math><br />
<br />
The dependencies between <math> j^{th} </math> HM and the other HM marks are encoded in <math> s^j </math>, whereas <math> m^j </math> from the previous step encodes the bin dependencies of the <math> j^{th} </math> HM.<br />
<br />
<br />
== HM-Level <math> \beta</math>-attention ==<br />
This second soft attention level (Figure 1(e)) finds the HM marks that are important for classifying a gene’s expression by learning an importance weight <math> \beta^j </math> for each <math> HM_j </math>, where <math> j \in \{ 1, \cdots, M \} </math>:<br />
<br />
<center><math> \beta^j = \frac{exp(\textbf{W}_s s^j)}{\sum_{i=1}^M{exp(\textbf{W}_s s^i)}} </math></center><br />
<br />
The HM-level context parameter <math> \textbf{W}_s </math> is trained jointly with the rest of the model. Intuitively, <math> \textbf{W}_s </math> learns how important each HM is for a cell type. Finally, the entire gene region is encoded in a hidden representation <math> \textbf{v} </math>, the weighted sum of the embeddings of all HM marks:<br />
<br />
<br />
<center><math> \textbf{v} = \sum_{j=1}^M{\beta^j \times s^j}</math></center><br />
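<br />
The <math>\beta</math>-attention and the gene-level representation <math>\textbf{v}</math> follow the same pattern at the HM level. In this NumPy sketch, <math>\textbf{W}_s</math> is assumed to be a vector of length <math>2d'</math>, and the HM embeddings <math>s^j</math> are random placeholders.<br />
<br />
```python
import numpy as np

def softmax(z):
    """Numerically stable softmax over a 1-D array."""
    e = np.exp(z - z.max())
    return e / e.sum()

def beta_attention(S, W_s):
    """HM-level soft attention.

    S:   (M, 2d') HM embeddings s^1 .. s^M from the HM-level encoder
    W_s: (2d',) HM-level context vector (assumed to be a vector)
    Returns beta with shape (M,) and v with shape (2d',).
    """
    beta = softmax(S @ W_s)   # beta^j = exp(W_s s^j) / sum_i exp(W_s s^i)
    v = beta @ S              # v = sum_j beta^j s^j
    return beta, v

rng = np.random.default_rng(2)
M, two_d_prime = 5, 32
S = rng.normal(size=(M, two_d_prime))   # placeholder HM embeddings
W_s = rng.normal(size=two_d_prime)      # placeholder context vector
beta, v = beta_attention(S, W_s)
print(beta.shape, v.shape)  # (5,) (32,)
```
<br />
With an all-zero context vector every <math>\beta^j</math> equals <math>1/M</math>, so <math>\textbf{v}</math> becomes the plain average of the HM embeddings; learning <math>\textbf{W}_s</math> is what lets the model move away from this uniform weighting.<br />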
<br />
== End-to-end training ==<br />
<br />
The embedding vector <math> \textbf{v} </math> is fed to a simple classification module, <math> f(\textbf{v}) = </math>softmax<math> (\textbf{W}_c\textbf{v}+b_c) </math>, where <math> \textbf{W}_c </math>, and <math> b_c </math> are learnable parameters. The output is the probability of gene expression being high (expressed) or low (suppressed).<br />
The whole model, including the attention modules, is differentiable, so it can be trained end-to-end with backpropagation by minimizing the negative log-likelihood loss.<br />
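<br />
A sketch of the classification module and the negative log-likelihood follows. The mapping of the labels <math>-1</math>/<math>+1</math> to class indices 0/1 is an assumption made for illustration, the (v, y) pairs are hypothetical, and the gradients that a framework's automatic differentiation would compute for backpropagation are not shown.<br />
<br />
```python
import numpy as np

def classify(v, W_c, b_c):
    """f(v) = softmax(W_c v + b_c); index 0 = OFF, index 1 = ON (assumed order)."""
    logits = W_c @ v + b_c
    logits = logits - logits.max()        # stability shift
    p = np.exp(logits)
    return p / p.sum()

def nll_loss(probs, y):
    """Negative log-likelihood of a label y in {-1, +1} under that ordering."""
    k = 1 if y == +1 else 0
    return -np.log(probs[k])

rng = np.random.default_rng(3)
v = rng.normal(size=32)                   # placeholder gene representation
W_c, b_c = rng.normal(size=(2, 32)), np.zeros(2)
p = classify(v, W_c, b_c)
batch = [(v, +1), (-v, -1)]               # hypothetical (v, y) pairs
mean_loss = np.mean([nll_loss(classify(x, W_c, b_c), y) for x, y in batch])
print(p.shape)  # (2,)
```
<br />
With all-zero classifier weights both classes receive probability 0.5, so the loss for any gene is <math>\log 2 \approx 0.693</math>; training lowers the average loss below this chance level.<br />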
<br />
= Experimental Settings =<br />
<br />
This work uses the REMC dataset and evaluates AttentiveChrome on 56 different cell types. As in DeepChrome, this study considers the five core HM marks listed in Table 1 (<math> M=5 </math>), because these marks are uniformly profiled across all 56 cell types in the REMC study.<br />
<br />
[[File:HM.png|center|thumb| 700px | Table 1: Five core HM marks and their attributes considered in this paper]]<br />
<br />
<br />
<br />
For each gene, the region of 10k base pairs centred at the TSS (5k bp in each direction) is taken into account. These 10k base pairs are divided into <math> T=100 </math> bins, each consisting of 100 contiguous bp. Therefore, for each gene in a particular cell type, the input matrix has size <math> 5 \times 100 </math>. The gene expression values are normalized and discretized into binary labels. The dataset is divided into three equal-sized folds for training, validation, and testing.<br />
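<br />
The binning step can be sketched in plain Python. The input format (a mapping from genomic position to read count) and summing the counts within each bin are assumptions for illustration; the paper's exact preprocessing of the REMC signals may differ.<br />
<br />
```python
def bin_gene_region(reads, tss, half_width=5000, n_bins=100):
    """Aggregate per-position read counts around a TSS into equal-width bins.

    reads: dict mapping genomic position -> read count (hypothetical format)
    tss:   transcription start site coordinate
    Covers [tss - half_width, tss + half_width) and returns n_bins summed counts.
    """
    width = 2 * half_width
    if width % n_bins:
        raise ValueError("region width must be divisible by n_bins")
    bin_size = width // n_bins              # 10,000 bp / 100 bins = 100 bp per bin
    start = tss - half_width
    bins = [0] * n_bins
    for pos, count in reads.items():
        offset = pos - start
        if 0 <= offset < width:             # ignore reads outside the gene region
            bins[offset // bin_size] += count
    return bins

tss = 1_000_000
reads = {tss - 5000: 3, tss: 2, tss + 4999: 1, tss + 5000: 7}
profile = bin_gene_region(reads, tss)
print(len(profile), profile[0], profile[50], profile[99])  # 100 3 2 1
```
<br />
Applying such a function to each of the five marks and stacking the results row by row yields the <math>5 \times 100</math> input matrix for one gene.<br />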
<br />
== Model Variations and Two Baselines ==<br />
To evaluate the proposed model, the authors use two baselines: an LSTM without any attention, and their prior work DeepChrome. The results of five variants of AttentiveChrome are compared with these baselines. The variants are:<br />
<br />
* LSTM-Attn: one LSTM with attention on the input matrix (does not consider the modular nature of HM marks)<br />
* CNN-Attn: DeepChrome [4] with one attention mechanism incorporated. <br />
* LSTM-<math>\alpha , \beta</math>: the proposed architecture.<br />
* CNN-<math>\alpha , \beta</math>: the LSTM module of the proposed architecture replaced with a CNN. This variant keeps both attention levels: an <math>\alpha</math>-attention on top of a CNN module for each HM mark, and a <math>\beta</math>-attention that combines the HMs.<br />
* LSTM-<math>\alpha</math>: one LSTM and <math>\alpha</math>-attention per HM mark.<br />
<br />
== Hyperparameters ==<br />
<br />
For all variants of AttentiveChrome, the bin-level LSTM embedding size <math> d</math> is set to 32 and the HM-level LSTM embedding size <math>d'</math> is set to 16. Because the LSTMs are bidirectional, the embedding vectors <math> h^j_t</math> and <math>s^j</math> have sizes 64 and 32, respectively. The sizes of the context vectors are set accordingly.<br />
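<br />
The size arithmetic can be made explicit with a small bookkeeping function. It assumes that the context parameters <math>\textbf{W}_b</math> and <math>\textbf{W}_s</math> are vectors, so that each attention score is a scalar, and that the classifier has two output classes; neither shape is stated explicitly in the text.<br />
<br />
```python
def attentivechrome_shapes(M=5, T=100, d=32, d_prime=16, n_classes=2):
    """Shapes of the main tensors for one gene, given LSTM sizes d and d'."""
    return {
        "X": (M, T),                      # input matrix
        "h_j_t": (2 * d,),                # bin embedding: forward + reverse halves
        "W_b": (2 * d,),                  # bin-level context vector (assumed shape)
        "alpha_j": (T,),                  # one attention weight per bin
        "m_j": (2 * d,),                  # attention-weighted sum of bin embeddings
        "s_j": (2 * d_prime,),            # HM embedding from the HM-level BiLSTM
        "W_s": (2 * d_prime,),            # HM-level context vector (assumed shape)
        "beta": (M,),                     # one attention weight per HM
        "v": (2 * d_prime,),              # gene-level representation
        "W_c": (n_classes, 2 * d_prime),  # classifier weights (assumed 2 classes)
    }

for name, shape in attentivechrome_shapes().items():
    print(f"{name:8s} {shape}")
```
<br />
Because <math>m^j</math> is a weighted average of the bin embeddings, it keeps their size (64), while the HM-level BiLSTM maps it down to <math>s^j</math> of size 32.<br />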
<br />
= Performance Evaluation =<br />
<br />
== AUC Scores ==<br />
<br />
This study summarizes AUC scores across all 56 cell types on the test set to compare the methods.<br />
<br />
[[File:AUC.JPG|center|thumb| 700px | Table 2: AUC score performances for different variations of AttentiveChrome and baselines]]<br />
<br />
Overall, the LSTM-attention models perform better than the DeepChrome (CNN-based) and LSTM baselines. The authors argue that AttentiveChrome is a good choice because of its interpretability, even though its performance improvement over DeepChrome is marginal.<br />
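<br />
The AUC measures how well the predicted probabilities rank ON genes above OFF genes, independent of any decision threshold. The plain-Python sketch below computes it directly from that definition, counting ties as one half. It is quadratic in the number of genes and meant only for illustration; library routines such as scikit-learn's <code>roc_auc_score</code> are more efficient.<br />
<br />
```python
def roc_auc(labels, scores):
    """AUC as the probability that a random ON gene outscores a random OFF gene.

    labels: iterable of +1 (ON) / -1 (OFF); scores: predicted P(ON). Ties count 1/2.
    """
    pairs = list(zip(labels, scores))
    pos = [s for y, s in pairs if y == +1]
    neg = [s for y, s in pairs if y == -1]
    if not pos or not neg:
        raise ValueError("AUC needs at least one ON and one OFF gene")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))

# Hypothetical predictions for four genes.
print(roc_auc([+1, +1, -1, -1], [0.9, 0.4, 0.5, 0.1]))  # 0.75
```
<br />
Here three of the four (ON, OFF) pairs are ranked correctly, giving 0.75; a score of 0.5 corresponds to random ranking.<br />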
<br />
== Evaluation of Attention Scores for Interpretation ==<br />
<br />
To check whether the model focuses on the right regions, the authors use additional data from the REMC database. To validate the bin-level attention, they use the signal of a new histone mark, H3K27ac, referred to as <math>H_{active}</math> in this article. This histone mark is known to mark active regions when the gene is expressed (ON). Genome-wide reads of this HM mark are available for three important cell types: stem cells (H1-hESC), blood cells (GM12878), and leukemia cells (K562). This HM mark is used only to analyze the visualization results and is not used during learning. The authors evaluate both attention mechanisms in this section.<br />
<br />
=== Correlation of Importance Weight of <math>H_{prom}</math> with <math>H_{active}</math> ===<br />
<br />
The average read count of <math>H_{active}</math> in each of the 100 bins is calculated over all active genes (ON, labeled <math>+1</math>) in the three selected cell types. The proposed AttentiveChrome and LSTM-<math>\alpha</math> methods are compared with two widely used visualization techniques, (1) class-based optimization and (2) saliency maps, applied to the baseline DeepChrome model (the CNN-based prior work). Using each method, the authors calculate importance weights for <math>H_{prom}</math> (a promoter HM mark used in training) across the 100 bins, and then compute the Pearson correlation between these importance weights and the read counts of <math>H_{active}</math> (the HM mark held out for validation) across the same 100 bins. The <math>H_{active}</math> read counts indicate the actual active regions in those cells.<br />
<br />
[[File: pc.JPG|center|thumb| 700px | Figure 4: Pearson correlation between the importance weights of <math>H_{prom}</math> and the read counts of a known active HM mark, <math>H_{active}</math>]]<br />
<br />
<br />
The results show that the proposed models consistently achieve the highest correlation with <math>H_{active}</math> for all three cell types, suggesting that the proposed method successfully captures the important signals.<br />
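<br />
The correlation analysis can be sketched in plain Python: average a per-bin profile over genes, then compute the Pearson correlation between the averaged <math>H_{prom}</math> importance weights and the averaged <math>H_{active}</math> read counts across bins. The toy profiles below are hypothetical and use 4 bins instead of 100 to keep the example short.<br />
<br />
```python
import math

def average_over_genes(profiles):
    """Average equal-length per-bin profiles (one list per gene) bin by bin."""
    n = len(profiles)
    return [sum(col) / n for col in zip(*profiles)]

def pearson(a, b):
    """Pearson correlation between two equal-length sequences."""
    if len(a) != len(b) or len(a) < 2:
        raise ValueError("need two sequences of equal length >= 2")
    ma, mb = sum(a) / len(a), sum(b) / len(b)
    cov = sum((x - ma) * (y - mb) for x, y in zip(a, b))
    var_a = sum((x - ma) ** 2 for x in a)
    var_b = sum((y - mb) ** 2 for y in b)
    return cov / math.sqrt(var_a * var_b)

# Hypothetical toy data for two ON genes, 4 bins each.
attention_prom = average_over_genes([[0.1, 0.4, 0.4, 0.1], [0.1, 0.2, 0.6, 0.1]])
reads_active = average_over_genes([[1.0, 5.0, 6.0, 2.0], [3.0, 7.0, 8.0, 2.0]])
r = pearson(attention_prom, reads_active)
print(round(r, 3))
```
<br />
A value close to 1 means that the bins the model weights most heavily are also the bins where the held-out active mark is strongest.<br />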
<br />
=== Visualization of Attention Weight of bins for each HM of a specific cell type GM12878===<br />
<br />
To visualize the bin-level attention weights, the authors plot the average bin-level attention weights of each HM in a specific cell type, GM12878 (blood cell), separately for expressed (ON) and suppressed (OFF) genes.<br />
<br />
[[File: figure2.png|center|thumb| 700px |]]<br />
<br />
For the ON genes, the attention profiles are well defined for the HM marks <math>H_{prom}</math>, <math>H_{enhc}</math>, and <math>H_{struct}</math>, whereas the weights are low for <math>H_{reprA}</math> and <math>H_{reprB}</math>. The trend reverses for the OFF genes, where the repressor HM marks have more influence than <math>H_{prom}</math>, <math>H_{enhc}</math>, and <math>H_{struct}</math>. This observation agrees with the biological finding that <math>H_{prom}</math>, <math>H_{enhc}</math>, and <math>H_{struct}</math> promote gene activation, while <math>H_{reprA}</math> and <math>H_{reprB}</math> repress genes.<br />
<br />
=== Attention Weight of bins with <math>H_{active}</math>===<br />
<br />
The average read counts of <math>H_{active}</math> over the same 100 bins, across all active (ON) genes in GM12878, are plotted in Figure 2(b) of the paper. Alongside them, the bin-level attention weights of AttentiveChrome, averaged over all genes predicted to be ON for GM12878, are plotted. The plots show that the <math>H_{prom}</math> attention profile is similar to the <math>H_{active}</math> read profile.<br />
<br />
=== Visualization of HM-level Attention Weight for Gene PAX5 ===<br />
<br />
To visualize the HM-level attention weights, the authors produce a heatmap for a differentially regulated gene, PAX5, in the three aforementioned cell types (Figure 2(c) of the paper). PAX5 plays a significant role in gene regulation when stem cells differentiate into blood cells. The gene is OFF in stem cells (H1-hESC) but becomes activated when a stem cell is transformed into a blood cell (GM12878). The <math>\beta^j</math> weight for <math>H_{repr}</math> is high when the gene is OFF in H1-hESC and decreases when the gene is ON in GM12878. Conversely, the <math>\beta^j</math> weight for the <math>H_{prom}</math> mark increases from H1-hESC to GM12878 as the gene becomes activated. This pattern extracted by the deep learning model is consistent with the biological literature [16].<br />
<br />
= Related Works/Studies =<br />
<br />
In the last few years, deep learning models have achieved unprecedented success in diverse research fields. Although their adoption has been slower than in other fields, deep learning algorithms are gaining popularity among bioinformaticians.<br />
<br />
== Attention-based Deep Models ==<br />
<br />
The idea of attention in deep learning is inspired by the human visual perception system: when perceiving a scene, humans tend to focus on some parts more than others. Combined with deep neural networks, this mechanism has achieved excellent results in several research areas, such as machine translation. Various types of attention, e.g., soft [6], location-aware [7], and hard [8, 9] attention, have been proposed in the literature. In soft attention, a soft weight vector is computed over all the feature vectors, and the magnitude of each weight reflects how important the corresponding feature is for the prediction. In practice, RNNs are often used to implement such models.<br />
<br />
== Visualization and Apprehension of Deep Models ==<br />
<br />
Prior studies have mostly focused on interpreting convolutional neural networks (CNNs) for image classification. Deconvolution approaches [10] map hidden-layer representations back to the input space. Saliency maps [11, 12] use a first-order Taylor expansion to approximate the network and identify the most relevant input features. Class optimization [12] visualization techniques find the input that best represents each class. Some recent works [13, 14] have tried to understand recurrent neural networks (RNNs) for text-based problems. By examining the features a model attends to, one can interpret the output of a deep model.<br />
<br />
== Deep Learning in Bioinformatics ==<br />
Deep learning is also becoming popular in bioinformatics because it can extract meaningful representations from data. Researchers use deep learning to model protein and DNA sequences and to predict gene expression.<br />
<br />
== Previous model for gene expression predictions ==<br />
Multiple machine learning models, such as linear regression and support vector machines, have been used to predict gene expression. Strategies have included averaging signals across all relevant positions, selecting input signals at positions that are highly correlated with the target gene expression, and using CNNs to learn combinatorial interactions among histone modification marks.<br />
<br />
= Conclusion = <br />
<br />
The paper introduces an attention-based approach, AttentiveChrome, that addresses both understanding and prediction. It has several advantages over previous architectures: higher accuracy than state-of-the-art baselines, and clearer interpretation than saliency maps and class optimization, which allows users to view what the model ‘sees’ during prediction. Another advantage is that it can model modular feature inputs that are sequentially structured. According to the authors, this is the first use of deep attention to understand gene regulation and the first attention-based model applied to a molecular biology dataset. The authors expect that this deep attention mechanism will help biologists better understand epigenomic data. By revealing ‘what’ and ‘where’ AttentiveChrome has focused, the model provides insight into its predictions on biological data that are otherwise hard to interpret.<br />
<br />
= Critiques =<br />
<br />
This paper does not make a substantial algorithmic contribution; it applies existing methods to a new application. The deep learning method is shown to perform better than simple machine learning models such as linear regression and SVMs, but it is considerably harder to implement and has many more hyperparameters to tune. Training time is also considerably higher, especially because all the parameters are learned jointly. The dataset used here also seems to contain a limited number of samples for a model of this complexity. The model hyperparameters appear to have been chosen arbitrarily, without any explanation or intuition, and the authors do not cite relevant literature to justify these values.<br />
<br />
The discussion of attention scores for interpretation gives no clear definition and does not mention previous literature that uses them. A reference to literature on H3K27ac, explaining how its read counts represent active regions of a cell, should be included. No reasoning is given for why only one specific cell type is used to visualize the bin-level attention weights. Examples of other real-world problems where this model could be useful should also be provided.<br />
<br />
Moreover, the paper relies heavily on intuition. Given the complicated architecture, algorithmic or theoretical justifications are likely challenging to provide. As a result, there is no proper guidance on how to choose hyperparameters or how to treat other datasets.<br />
<br />
= Additional Resources =<br />
<br />
# [https://qdata.github.io/deep4biomed-web/ Official DeepChrome Website]<br />
# [http://papers.nips.cc/paper/7255-attend-and-predict-understanding-gene-regulation-by-selective-attention-on-chromatin-supplemental.zip Supplemental Resources]<br />
# [https://github.com/QData/AttentiveChrome/blob/master/NIPS%20poster.pdf Poster]<br />
# [https://www.youtube.com/watch?v=tfgmXvSgsQE&feature=youtu.be Video Presentation]<br />
<br />
= Reference =<br />
<br />
[1] Andrew J Bannister and Tony Kouzarides. Regulation of chromatin by histone modifications. Cell Research, 21(3):381–395, 2011.<br />
<br />
[2] Anshul Kundaje, Wouter Meuleman, Jason Ernst, Misha Bilenky, Angela Yen, Alireza Heravi-Moussavi, Pouya Kheradpour, Zhizhuo Zhang, Jianrong Wang, Michael J Ziller, et al. Integrative analysis of 111 reference human epigenomes. Nature, 518(7539):317–330, 2015.<br />
<br />
[3] Singh, Ritambhara, et al. "Attend and Predict: Understanding Gene Regulation by Selective Attention on Chromatin." Advances in Neural Information Processing Systems. 2017.<br />
<br />
[4] Ritambhara Singh, Jack Lanchantin, Gabriel Robins, and Yanjun Qi. DeepChrome: deep-learning for predicting gene expression from histone modifications. Bioinformatics, 32(17):i639–i648, 2016.<br />
<br />
[5] Joanna Boros, Nausica Arnoult, Vincent Stroobant, Jean-François Collet, and Anabelle Decottignies. Polycomb repressive complex 2 and H3K27me3 cooperate with H3K9 methylation to maintain heterochromatin protein 1α at chromatin. Molecular and Cellular Biology, 34(19):3662–3674, 2014.<br />
<br />
[6] Dzmitry Bahdanau, Kyunghyun Cho, and Yoshua Bengio. Neural machine translation by jointly learning to align and translate. arXiv preprint arXiv:1409.0473, 2014.<br />
<br />
[7] Jan K Chorowski, Dzmitry Bahdanau, Dmitriy Serdyuk, Kyunghyun Cho, and Yoshua Bengio. Attention-based models for speech recognition. In C. Cortes, N. D. Lawrence, D. D. Lee, M. Sugiyama, and R. Garnett, editors, Advances in Neural Information Processing Systems 28, pages 577–585. Curran Associates, Inc., 2015.<br />
<br />
[8] Minh-Thang Luong, Hieu Pham, and Christopher D. Manning. Effective approaches to attention-based neural machine translation. In Empirical Methods in Natural Language Processing (EMNLP), pages 1412–1421, Lisbon, Portugal, September 2015. Association for Computational Linguistics.<br />
<br />
[9] Huijuan Xu and Kate Saenko. Ask, attend and answer: Exploring question-guided spatial attention for visual question answering. In ECCV, 2016.<br />
<br />
[10] Matthew D Zeiler and Rob Fergus. Visualizing and understanding convolutional networks. In Computer Vision–ECCV 2014, pages 818–833. Springer, 2014.<br />
<br />
[11] David Baehrens, Timon Schroeter, Stefan Harmeling, Motoaki Kawanabe, Katja Hansen, and Klaus-Robert Müller. How to explain individual classification decisions. Journal of Machine Learning Research, 11:1803–1831, 2010.<br />
<br />
[12] Karen Simonyan, Andrea Vedaldi, and Andrew Zisserman. Deep inside convolutional networks: Visualising image classification models and saliency maps. arXiv preprint, 2013.<br />
<br />
[13] Andrej Karpathy, Justin Johnson, and Fei-Fei Li. Visualizing and understanding recurrent networks. arXiv preprint, 2015.<br />
<br />
[14] Jiwei Li, Xinlei Chen, Eduard Hovy, and Dan Jurafsky. Visualizing and understanding neural models in NLP. arXiv preprint, 2015.<br />
<br />
[15] Xianjun Dong and Zhiping Weng. The correlation between histone modifications and gene expression. Epigenomics, 5(2):113–116, 2013.<br />
<br />
[16] Shane McManus, Anja Ebert, Giorgia Salvagiotto, Jasna Medvedovic, Qiong Sun, Ido Tamir, Markus Jaritz, Hiromi Tagoh, and Meinrad Busslinger. The transcription factor pax5 regulates its target genes by recruiting chromatin-modifying proteins in committed b cells. The EMBO journal, 30(12):2388–2404, 2011.</div>Gsahuhttp://wiki.math.uwaterloo.ca/statwiki/index.php?title=Attend_and_Predict:_Understanding_Gene_Regulation_by_Selective_Attention_on_Chromatin&diff=41973Attend and Predict: Understanding Gene Regulation by Selective Attention on Chromatin2018-11-30T00:24:18Z<p>Gsahu: /* Conclusion */</p>
<hr />
<div>This page contains a summary of the paper [https://arxiv.org/abs/1708.00339 "Attend and Predict: Understanding Gene Regulation by Selective Attention on Chromatin."] by Singh, Ritambhara, et al. It was published in Advances in Neural Information Processing Systems (NIPS) 2017. The code for this paper is shared [https://qdata.github.io/deep4biomed-web/ here].<br />
<br />
<br />
= Background =<br />
<br />
Gene regulation is the process of controlling which genes in a cell's DNA are turned 'on' (expressed) or 'off' (not expressed). When a gene is expressed, a functional product such as a protein is created. Even though all the cells of a multicellular organism (e.g., humans) contain the same DNA, different types of cells in that organism may express very different sets of genes. As a result, each cell type has distinct functionality; in other words, how a cell operates depends on the genes expressed in that cell. Many factors, including chromatin modification marks, influence which genes are expressed in a cell.<br />
<br />
Chromatin efficiently packages DNA by wrapping it around bead-like histone structures, condensing it into a volume small enough to fit into the nucleus of a cell, and it protects the DNA structure and sequence during cell division and replication. Different chemical modifications of the histones in chromatin, known as histone marks, change the spatial arrangement of the condensed DNA structure, which in turn affects the expression of genes in the mark’s neighboring region. Histone marks can promote (or obstruct) gene activation by making the gene region accessible (or restricted). The section of DNA where histone marks can potentially have an impact is known as the DNA flanking region or ‘gene region’, which is considered to cover 10k base pairs centered at the transcription start site (TSS), i.e., 5k base pairs in each direction. Unlike genetic mutations, histone modifications are reversible [1]. Therefore, understanding the influence of histone marks on gene regulation can assist in developing drugs for genetic diseases.<br />
<br />
= Introduction = <br />
<br />
A revolution in genomic technologies now enables us to profile genome-wide chromatin mark signals. As a result, biologists can measure gene expression and the chromatin signals of the ‘gene region’ for different cell types across the whole human genome. The Roadmap Epigenome Project (REMC, publicly available) [2] recently released 2,804 genome-wide datasets of 100 separate “normal” (not diseased) human cells/tissues, among which 166 datasets are gene expression reads and the rest are signal reads of various histone marks. The goal is to understand which histone marks are the most important and how they interact in gene regulation for each cell type.<br />
<br />
Signal reads for histone marks are high-dimensional and spatially structured. A histone modification mark can exert its influence anywhere in the gene region (10k base pairs centered around the transcription start site of each gene), so it is important to understand how the impact of the mark on gene expression varies over the gene region, that is, how histone signals over the gene region affect gene expression. Many types of histone marks in human chromatin can influence gene regulation. Researchers have identified five standard histone proteins, which can be altered in different combinations by different chemical modifications, resulting in a large number of distinct histone modification marks. Different histone modification marks can act as a module, interacting with each other to influence gene expression.<br />
<br />
<br />
This paper proposes an attention-based deep learning model to find how chromatin factors, namely histone modification (HM) marks, contribute to the gene expression of a particular cell. AttentiveChrome [3] uses a hierarchy of multiple LSTMs to discover interactions among the signals of each histone mark and to learn dependencies among the marks in expressing a gene. The authors include two levels of soft attention: (1) one that attends to the most relevant signals of a histone mark, and (2) one that attends to the important marks and their interactions. In this context, ''attention'' means weighting different items by their importance.<br />
<br />
== Main Contributions ==<br />
The contributions of this work can be summarized as follows:<br />
<br />
* More accurate predictions than the state-of-the-art baselines, measured on REMC datasets for 56 different cell types.<br />
* Better interpretation than state-of-the-art methods for visualizing deep learning models, assessed by correlating the model's attention scores with an additional mark signal from REMC.<br />
* As with earlier applications of attention models, which indirectly hint at the parts of the input the model deems important, AttentiveChrome can explain its decisions by indicating “what” and “where” it has focused.<br />
* According to the authors, this is the first time an attention-based deep learning approach has been applied to a problem in molecular biology.<br />
* The ability to deal with highly modular inputs.<br />
<br />
= Previous Works = <br />
<br />
Machine learning algorithms for classifying gene expression from histone modification signals have been surveyed in [15]. These algorithms range from linear regression, support vector machines, and random forests to rule-based learning and CNNs. To accommodate the spatially structured, high-dimensional input data (histone modification signals), these studies applied different feature selection strategies. The authors' preceding study, DeepChrome [4], incorporated the best-position selection strategy, in which the positions most highly correlated with gene expression are taken as the best positions. This CNN-based model can learn the relationships between histone marks, and it outperformed all previous works. However, these approaches either (1) failed to model the spatial dependencies among the marks, or (2) required additional feature analysis. According to the authors, only AttentiveChrome satisfies all eight desirable properties they list for such a model.<br />
<br />
= AttentiveChrome: Model Formulation =<br />
<br />
The authors propose an end-to-end architecture that can simultaneously attend and predict. The method uses recurrent neural networks (RNNs) composed of LSTM units to model the sequential spatial dependencies within the gene region and to predict the gene expression level from the histone modification signals. The embedding vector <math> h_t </math> output by an LSTM module encodes the learned representation of the feature dependencies from the first time step up to time step <math> t </math>. For this task, each bin position of the gene region is treated as a time step.<br />
<br />
The proposed AttentiveChrome framework contains the following five modules:<br />
<br />
* Bin-level LSTM encoder encoding the bin positions of the gene region (one for each HM mark)<br />
* Bin-level <math> \alpha </math>-Attention across all bin positions (one for each HM mark)<br />
* HM-level LSTM encoder (one encoder encoding all HM marks)<br />
* HM-level <math> \beta </math>-Attention among all HM marks (one)<br />
* The final classification module<br />
<br />
Figure 1 (Supplementary Figure 2) presents an overview of the proposed AttentiveChrome framework.<br />
<br />
<br />
[[File:supplemntary_figure_2.png|thumb|center| 800px |Figure 1: Overview of all five modules of the proposed AttentiveChrome framework]]<br />
<br />
<br />
<br />
== Input and Output ==<br />
<br />
Each dataset contains the gene expression labels and the histone signal reads for one specific cell type; the authors evaluated AttentiveChrome on 56 different cell types. For each mark, the feature (input) vector contains the signal reads surrounding the gene’s TSS position (the gene region), and its label denotes the expression of that gene. This study uses binary labels, where <math> +1 </math> denotes that the gene is expressed (on) and <math> -1 </math> denotes that it is not expressed (off). Each histone mark therefore has one feature vector per gene. The authors reuse the input features and outputs of their previous work, DeepChrome [4]. The input is represented by a matrix <math> \textbf{X} </math> of size <math> M \times T </math>, where <math> M </math> is the number of HM marks considered and <math> T </math> is the number of bin positions used to represent the gene region. The <math> j^{th} </math> row of <math> \textbf{X} </math>, <math> x^j </math>, holds the sequentially structured signals of the <math> j^{th} </math> HM mark, where <math> j\in \{1, \cdots, M\} </math>, so the entry <math> x^j_t </math> is the value of the <math> t^{th}</math> bin of the <math> j^{th} </math> HM mark, where <math> t\in \{1, \cdots, T\} </math>. If the training set contains <math>N_{tr} </math> labeled pairs, the <math> n^{th} </math> pair is <math>( X^n, y^n)</math>, where <math> X^n </math> is an <math> M \times T </math> matrix, <math> y^n \in \{ -1, +1 \} </math> is its binary label, and <math> n \in \{ 1, \cdots, N_{tr} \} </math>.<br />
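<br />
The notation above can be made concrete with a short NumPy sketch. It only shows how one gene's binned signals might be packed into the <math> M \times T </math> matrix and how the <math>N_{tr}</math> pairs stack into a single array. The function name <code>make_example</code> and the random toy read counts are hypothetical, and the code uses 0-based indices where the text uses 1-based ones.<br />

```python
import numpy as np

M, T = 5, 100  # number of HM marks and number of bins

def make_example(hm_bins, label):
    """Pack one gene's binned HM signals into X (M x T) and check its label.

    hm_bins[j][t] corresponds to x^{j+1}_{t+1} in the text (code is 0-based).
    label is +1 (gene expressed, "on") or -1 (not expressed, "off").
    """
    X = np.asarray(hm_bins, dtype=np.float32)
    if X.shape != (M, T):
        raise ValueError(f"expected shape {(M, T)}, got {X.shape}")
    if label not in (-1, 1):
        raise ValueError("label must be -1 (off) or +1 (on)")
    return X, label

# Toy training set of N_tr = 3 genes with random read counts (illustration only).
rng = np.random.default_rng(0)
pairs = [make_example(rng.poisson(2.0, size=(M, T)), y) for y in (1, -1, 1)]
X_all = np.stack([X for X, _ in pairs])  # shape (N_tr, M, T)
y_all = np.array([y for _, y in pairs])  # shape (N_tr,)
```

Stacking the pairs into one <code>(N_tr, M, T)</code> array is simply a convenient layout for batching; nothing in the model depends on it.<br />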
<br />
Figure 2 (see also Figure 1(a) and 1(b)) shows the input features and the output of AttentiveChrome for a particular gene (one sample).<br />
<br />
[[File:input-output-attentivechrome.png|center|thumb| 700px | Figure 2: Input and Output of the AttentiveChrome model]]<br />
<br />
== Bin-Level Encoder (one LSTM for each HM) ==<br />
The sequentially ordered elements of the gene region of the <math> n^{th} </math> gene (each element is a bin position) are represented by the <math> j^{th} </math> row vector <math> x^j </math>. The authors treat each bin position as an LSTM time step. This study uses a bidirectional LSTM to model the overall dependencies among all <math> T </math> bin positions in the gene region. The bidirectional LSTM contains two LSTMs:<br />
* A forward LSTM, <math> \overrightarrow{LSTM_j} </math>, that models <math> x^j </math> from <math> x_1^j </math> to <math> x_T^j </math> and outputs an embedding vector <math> \overrightarrow{h^j_t} </math> of size <math> d </math> for each bin <math> t </math><br />
* A reverse LSTM, <math> \overleftarrow{LSTM_j} </math>, that models <math> x^j </math> from <math> x_T^j </math> to <math> x_1^j </math> and outputs an embedding vector <math> \overleftarrow{h^j_t} </math> of size <math> d </math> for each bin <math> t </math><br />
<br />
The final output of this layer, the embedding vector <math> h^j_t </math> of the <math> t^{th} </math> bin of the <math> j^{th} </math> HM, is obtained by concatenating the vectors from the two directions, <math> h^j_t = [ \overrightarrow{h^j_t}, \overleftarrow{h^j_t}]</math>, so it has size <math> 2d </math>. Because these LSTM-based HM encoders are trained jointly with the final classifier, they learn to embed each HM mark by extracting the dependencies among its bins. Figure 1(c) illustrates this module for <math> j=2 </math>.<br />
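<br />
A minimal NumPy sketch of the bin-level bidirectional encoder for a single HM mark follows. It is not the authors' code. It assumes the standard LSTM gate equations (the summary does not specify the cell variant), a 1-D input per bin (its read count), and random, untrained weights. The names <code>init_lstm</code>, <code>lstm_forward</code> and <code>bilstm</code> are hypothetical. The sketch uses <math> d = 32 </math>, the value given in the hyperparameter section.<br />

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def init_lstm(input_dim, hidden_dim, rng):
    """Random (untrained) parameters for one LSTM direction; gate rows stacked as i, f, o, g."""
    scale = 1.0 / np.sqrt(hidden_dim)
    return {
        "W": rng.uniform(-scale, scale, size=(4 * hidden_dim, input_dim)),
        "U": rng.uniform(-scale, scale, size=(4 * hidden_dim, hidden_dim)),
        "b": np.zeros(4 * hidden_dim),
    }

def lstm_forward(xs, p):
    """Run a standard LSTM over xs (T x input_dim) and return every hidden state (T x d)."""
    d = p["U"].shape[1]
    h, c = np.zeros(d), np.zeros(d)
    states = []
    for x in xs:
        z = p["W"] @ x + p["U"] @ h + p["b"]
        i, f, o = sigmoid(z[:d]), sigmoid(z[d:2 * d]), sigmoid(z[2 * d:3 * d])
        g = np.tanh(z[3 * d:])
        c = f * c + i * g   # cell state update
        h = o * np.tanh(c)  # hidden state = embedding for this time step
        states.append(h)
    return np.stack(states)

def bilstm(xs, p_fwd, p_bwd):
    """h_t = [forward h_t, backward h_t]; the backward LSTM reads the bins from T down to 1."""
    fwd = lstm_forward(xs, p_fwd)
    bwd = lstm_forward(xs[::-1], p_bwd)[::-1]  # flip back so that row t refers to bin t
    return np.concatenate([fwd, bwd], axis=1)  # shape (T, 2d)

# Bin-level encoder for one HM mark: T = 100 time steps, each a 1-D read count; d = 32.
rng = np.random.default_rng(0)
T, d = 100, 32
x_j = rng.poisson(2.0, size=(T, 1)).astype(float)
p_f, p_b = init_lstm(1, d, rng), init_lstm(1, d, rng)
H_j = bilstm(x_j, p_f, p_b)  # row t holds h^j_t
```

The same <code>bilstm</code> function could serve as the HM-level encoder. In that role it would run over the <math>M</math> HM representations instead of the <math>T</math> bins, with input size <math>2d</math> and hidden size <math>d'</math>. In practice, an autodiff framework's LSTM would replace this hand-written loop.<br />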
<br />
== Bin-Level <math> \alpha</math>-attention ==<br />
<br />
Each bin contributes differently to the encoding of the entire <math> j^{th} </math> mark. To automatically and adaptively highlight the bins that matter most for prediction, a soft attention weight vector <math> \alpha^j </math> of size <math> T </math> is learned for each <math> j </math>. The soft weight <math> \alpha^j_t </math> of each bin <math> t </math> is computed from the embedding vectors <math> \{h^j_1, \cdots, h^j_T \} </math> of all the bins using the following equation:<br />
<br />
<center><math> \alpha^j_t = \frac{exp(\textbf{W}_b h^j_t)}{\sum_{i=1}^T{exp(\textbf{W}_b h^j_i)}} </math></center><br />
<br />
<br />
<math> \alpha^j_t</math> is a scalar. The denominator normalizes over the embedding vectors of all <math>T</math> bins, so the weights <math> \alpha^j_1, \cdots, \alpha^j_T </math> are non-negative and sum to one. The context parameter <math> \textbf{W}_b </math> is initialized randomly and learned jointly with the other model parameters. Once the importance weight of each bin position is known, the <math> j^{th} </math> HM mark is represented by <math> m^j = \sum_{t=1}^T{\alpha^j_t \times h^j_t}</math>, where <math> h^j_t</math> is the embedding vector and <math> \alpha^j_t </math> is the importance weight of the <math> t^{th} </math> bin in the representation of the <math> j^{th} </math> HM mark. Intuitively, <math> \textbf{W}_b </math> learns which bins are important for the cell type being modeled. Figure 1(d) shows this module for <math> HM_2 </math>.<br />
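<br />
The bin-level <math>\alpha</math>-attention and the weighted sum <math>m^j</math> can be sketched in a few lines of NumPy. Here <code>w_b</code> stands in for <math>\textbf{W}_b</math> and is assumed to be a vector of length <math>2d</math>, so that <math>\textbf{W}_b h^j_t</math> is a scalar. The function name <code>bin_attention</code> and the random inputs are hypothetical, chosen for illustration only.<br />

```python
import numpy as np

def softmax(z):
    z = z - z.max()  # subtract the max for numerical stability; the result is unchanged
    e = np.exp(z)
    return e / e.sum()

def bin_attention(H_j, w_b):
    """alpha^j_t = softmax over t of (w_b . h^j_t);  m^j = sum_t alpha^j_t h^j_t.

    H_j: (T, 2d) bin embeddings of one HM mark; w_b: (2d,) bin-level context vector.
    """
    alpha = softmax(H_j @ w_b)  # (T,) one scalar weight per bin, summing to 1
    m_j = alpha @ H_j           # (2d,) attention-weighted sum of the bin embeddings
    return alpha, m_j

rng = np.random.default_rng(1)
H_j = rng.normal(size=(100, 64))  # stand-in for the bin-level LSTM output
w_b = rng.normal(size=64)         # randomly initialized; learned jointly in the real model
alpha, m_j = bin_attention(H_j, w_b)
```

With an all-zero context vector, every bin receives the same weight <math>1/T</math> and <math>m^j</math> reduces to the plain average of the bin embeddings. Learning <math>\textbf{W}_b</math> is what moves the model away from this uniform baseline.<br />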
<br />
== HM-level Encoder (one LSTM) ==<br />
<br />
Studies have observed that HMs work cooperatively to provoke or suppress gene expression [5]. The HM-level encoder (not shown in Figure 1) uses one bidirectional LSTM to capture these relationships between HMs. Because the authors found no influence of any specific ordering of the HMs, they treat the HMs as an arbitrary (random) sequence in order to formulate the sequential dependency. The input to this step is the representation <math> m^j </math> of the <math> j^{th} </math> HM, <math> HM_j </math>, computed by the bin-level attention layer. This set-based encoder outputs an embedding vector <math> s^j </math>, the encoding of the <math> j^{th} </math> HM. Its size is <math> 2d' </math> because the forward and backward outputs, each of size <math> d' </math>, are concatenated.<br />
<br />
<math> s^j = [ \overrightarrow{LSTM_s}(m^j), \overleftarrow{LSTM_s}(m^j) ] </math><br />
<br />
The dependencies between <math> j^{th} </math> HM and the other HM marks are encoded in <math> s^j </math>, whereas <math> m^j </math> from the previous step encodes the bin dependencies of the <math> j^{th} </math> HM.<br />
<br />
<br />
== HM-Level <math> \beta</math>-attention ==<br />
This second soft attention level (Figure 1(e)) finds the HM marks that matter for classifying a gene’s expression by learning an importance weight <math> \beta^j </math> for each <math> HM_j </math>, where <math> j \in \{ 1, \cdots, M \} </math>:<br />
<br />
<center><math> \beta^j = \frac{exp(\textbf{W}_s s^j)}{\sum_{i=1}^M{exp(\textbf{W}_s s^i)}} </math></center><br />
<br />
The HM-level context parameter <math> \textbf{W}_s </math> is trained jointly with the rest of the model. Intuitively, <math> \textbf{W}_s </math> learns how important each HM is for a cell type. Finally, the entire gene region is encoded in a hidden representation <math> \textbf{v} </math>, the weighted sum of the embeddings of all HM marks:<br />
<br />
<br />
<math> \textbf{v} = \sum_{j=1}^{M}{\beta^j \times s^j}</math><br />
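<br />
The HM-level <math>\beta</math>-attention has the same form as the bin-level attention, but it is applied across the <math>M</math> marks. A NumPy sketch follows, assuming <code>w_s</code> is a vector of length <math>2d'</math> standing in for <math>\textbf{W}_s</math>. The function name <code>hm_attention</code> and the random stand-in embeddings are hypothetical.<br />

```python
import numpy as np

def hm_attention(S, w_s):
    """beta^j = exp(w_s . s^j) / sum_i exp(w_s . s^i);  v = sum_j beta^j s^j.

    S: (M, 2d') matrix whose j-th row is s^j; w_s: (2d',) HM-level context vector.
    """
    scores = S @ w_s
    scores = scores - scores.max()                # numerical stability only
    beta = np.exp(scores) / np.exp(scores).sum()  # one weight per HM, normalized over all M marks
    v = beta @ S                                  # (2d',) encoding of the whole gene region
    return beta, v

rng = np.random.default_rng(2)
S = rng.normal(size=(5, 32))  # stand-in for HM-level LSTM outputs (M = 5, 2d' = 32)
w_s = rng.normal(size=32)
beta, v = hm_attention(S, w_s)
```

Because the weights are normalized over the marks, <math>\beta^j</math> can be read directly as the relative importance of <math>HM_j</math> for a given gene. This is what the PAX5 heatmap visualizes.<br />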
<br />
== End-to-end training ==<br />
<br />
The embedding vector <math> \textbf{v} </math> is fed to a simple classification module, <math> f(\textbf{v}) = </math>softmax<math> (\textbf{W}_c\textbf{v}+b_c) </math>, where <math> \textbf{W}_c </math> and <math> b_c </math> are learnable parameters. The output is the probability that the gene's expression is high (expressed) or low (suppressed).<br />
The whole model, including the attention modules, is differentiable, so it can be trained end-to-end with standard backpropagation. Training minimizes the negative log-likelihood loss.<br />
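<br />
The classification module and its loss are sketched below. This is a NumPy sketch, not the authors' implementation. It assumes that output row 0 means "off" and row 1 means "on", and it maps the <math>\{-1, +1\}</math> labels to those class indices. The names <code>classify</code> and <code>nll_loss</code> are hypothetical. Gradients are not shown; in practice an automatic-differentiation framework would compute them for end-to-end training.<br />

```python
import numpy as np

def classify(v, W_c, b_c):
    """f(v) = softmax(W_c v + b_c); row 0 is assumed to mean 'off' and row 1 'on'."""
    z = W_c @ v + b_c
    z = z - z.max()  # numerical stability only
    p = np.exp(z)
    return p / p.sum()

def nll_loss(V, y, W_c, b_c):
    """Mean negative log-likelihood over a batch; V is (N, dim), labels y are in {-1, +1}."""
    idx = (np.asarray(y) == 1).astype(int)  # map -1 -> class 0, +1 -> class 1
    probs = np.array([classify(v, W_c, b_c) for v in V])
    return float(-np.mean(np.log(probs[np.arange(len(idx)), idx])))

rng = np.random.default_rng(3)
V = rng.normal(size=(4, 32))  # stand-ins for gene-region encodings v
y = [1, -1, 1, -1]
W_c, b_c = rng.normal(size=(2, 32)) * 0.1, np.zeros(2)
loss = nll_loss(V, y, W_c, b_c)
```

With all-zero parameters, both classes get probability 0.5 and the loss equals <math>\log 2 \approx 0.693</math>, which is a useful sanity check for an untrained model.<br />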
<br />
= Experimental Settings =<br />
<br />
This work uses the REMC dataset, and AttentiveChrome is evaluated on 56 different cell types. Like DeepChrome, this study uses the following five core HM marks (<math> M=5 </math>), because these marks are uniformly profiled across all 56 cell types in the REMC study.<br />
<br />
[[File:HM.png|center|thumb| 700px | Table 1: Five core HM marks and their attributes considered in this paper]]<br />
<br />
<br />
<br />
For each gene, the region of 10k base pairs centred at the TSS (5k bp in each direction) is taken into account. These 10k base pairs are divided into <math> T=100 </math> bins, each consisting of 100 contiguous bp. Therefore, for each gene in a particular cell type, the input matrix has size <math> 5 \times 100 </math>. The gene expression values are normalized and discretized into binary labels. The dataset is divided into three equal-sized folds for training, validation, and testing.<br />
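<br />
The binning step can be sketched as follows. The sketch assumes a hypothetical input format: one per-base-pair read-count array per HM along a chromosome. It also assumes the reads in each 100 bp bin are summed. The summary does not state the exact aggregation, and strand orientation is ignored here. The names <code>bin_gene_region</code> and <code>gene_matrix</code> are hypothetical.<br />

```python
import numpy as np

def bin_gene_region(signal, tss, half_width=5000, n_bins=100):
    """Sum a per-bp HM signal over the window [tss - half_width, tss + half_width) in n_bins bins.

    With the defaults this gives 100 bins of 100 bp each around the TSS.
    """
    if tss - half_width < 0 or tss + half_width > signal.size:
        raise ValueError("TSS too close to the chromosome end for a full window")
    window = signal[tss - half_width : tss + half_width]
    if window.size % n_bins:
        raise ValueError("window length must be divisible by n_bins")
    return window.reshape(n_bins, window.size // n_bins).sum(axis=1)

def gene_matrix(signals, tss):
    """Stack the binned windows of the M = 5 core HM marks into a 5 x 100 input matrix."""
    return np.stack([bin_gene_region(s, tss) for s in signals])

# Toy chromosome of 20,000 bp with a constant signal of 1 read per bp for each of 5 marks.
signals = [np.ones(20000) for _ in range(5)]
X_gene = gene_matrix(signals, tss=10000)
```

Each row of <code>X_gene</code> is one HM mark and each column is one 100 bp bin, which matches the <math>5 \times 100</math> layout described above.<br />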
<br />
== Model Variations and Two Baselines ==<br />
To evaluate the proposed model, the authors use two baselines: an RNN (a plain LSTM without any attention) and their prior work, DeepChrome. The results of several variations of the AttentiveChrome model are compared with these baselines. The five AttentiveChrome variants evaluated are:<br />
<br />
* LSTM-Attn: one LSTM with attention on the input matrix (does not consider the modular nature of HM marks)<br />
* CNN-Attn: DeepChrome [4] with one attention mechanism incorporated. <br />
* LSTM-<math>\alpha , \beta</math>: the proposed architecture.<br />
* CNN-<math>\alpha , \beta</math>: the proposed architecture with the LSTM modules replaced by CNNs. This variant has two attention mechanisms: an <math>\alpha</math>-attention on top of a CNN module for each HM mark, and a <math>\beta</math>-attention that combines the HMs.<br />
* LSTM-<math>\alpha</math>: one LSTM and <math>\alpha</math>-attention per HM mark.<br />
<br />
== Hyperparameters ==<br />
<br />
For all variants of AttentiveChrome, the bin-level LSTM embedding size <math> d</math> is set to 32 and the HM-level LSTM embedding size <math>d'</math> is set to 16. Because the LSTMs are bidirectional, the embedding vectors <math> h^j_t</math> and <math>s^j</math> have sizes 64 and 32 respectively. The HM representation <math>m^j</math>, a weighted sum of the <math>h^j_t</math>, also has size 64. The sizes of the context vectors <math>\textbf{W}_b</math> and <math>\textbf{W}_s</math> are set accordingly.<br />
<br />
= Performance Evaluation =<br />
<br />
== AUC Scores ==<br />
<br />
This study summarizes AUC scores across all 56 cell types on the test set to compare the methods.<br />
<br />
[[File:AUC.JPG|center|thumb| 700px | Table 2: AUC score performances for different variations of AttentiveChrome and baselines]]<br />
<br />
Overall, the LSTM-attention models perform better than the DeepChrome (CNN-based) and LSTM baselines. The authors argue that AttentiveChrome is a good choice because of its interpretability, even though its performance improvement over DeepChrome is small.<br />
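<br />
As a reminder of what the reported score measures, here is a small pure-Python sketch. It computes AUC as the probability that a randomly chosen "on" gene receives a higher predicted score than a randomly chosen "off" gene, with ties counted as one half. This is the rank (Mann-Whitney) formulation of the area under the ROC curve. The function names and the toy scores are hypothetical. The median is used only as an example of summarizing per-cell-type AUCs. It is O(n<sub>on</sub> × n<sub>off</sub>) and meant for illustration, not for large test sets.<br />

```python
from statistics import median

def roc_auc(scores, labels):
    """P(score of a random +1 gene > score of a random -1 gene), counting ties as 1/2."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == -1]
    if not pos or not neg:
        raise ValueError("need at least one 'on' and one 'off' gene")
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))

def summarize(per_cell_type):
    """Per-cell-type test AUCs plus their median; maps name -> (scores, labels)."""
    aucs = {name: roc_auc(s, y) for name, (s, y) in per_cell_type.items()}
    return aucs, median(aucs.values())

# Toy predicted probabilities of "on" for two hypothetical cell types.
toy = {
    "cellA": ([0.9, 0.7, 0.4, 0.2], [1, 1, -1, -1]),
    "cellB": ([0.6, 0.8, 0.3, 0.5], [1, -1, -1, 1]),
}
aucs, med = summarize(toy)
```

An AUC of 0.5 corresponds to random ranking and 1.0 to a perfect separation of expressed and suppressed genes.<br />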
<br />
== Evaluation of Attention Scores for Interpretation ==<br />
<br />
To check whether the model focuses on the right regions, the authors use additional data from the REMC database. To validate the bin-level attention, they use signal data of a new histone mark, H3K27ac (referred to as <math>H_{active}</math> in this article), which is known to mark active regions when a gene is expressed (ON). Genome-wide reads of this HM mark are available for three important cell types: stem cells (H1-hESC), blood cells (GM12878), and leukemia cells (K562). This HM mark is used only to analyze the visualization results, not during learning. The authors discuss the performance of both attention mechanisms in this section.<br />
<br />
=== Correlation of Importance Weight of <math>H_{prom}</math> with <math>H_{active}</math> ===<br />
<br />
The average read count of <math>H_{active}</math> in each of the 100 bins is computed over all active genes (ON, labeled <math>+1</math>) in the three selected cell types. AttentiveChrome and LSTM-<math>\alpha</math> are compared with two widely used visualization techniques, (1) class-based optimization and (2) saliency maps. Both are applied to the baseline DeepChrome model, the authors' CNN-based prior work. Each method yields importance weights for <math>H_{prom}</math> (a promoter HM mark used in training) across the 100 bins. The Pearson correlation between these weights and the <math>H_{active}</math> read counts (the HM mark used for validation) over the same 100 bins is then computed. The <math>H_{active}</math> read counts indicate the actual active regions in those cells.<br />
<br />
[[File: pc.JPG|center|thumb| 700px | Figure 4: Pearson correlation between the bin importance weights and a known active HM mark (H3K27ac)]]<br />
<br />
<br />
The results show that the proposed models consistently achieve the highest correlation with <math>H_{active}</math> in all three cell types, suggesting that their attention weights capture the important signals.<br />
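<br />
The correlation analysis described above can be sketched with NumPy. The sketch assumes the per-gene <math>H_{prom}</math> bin weights and the binned <math>H_{active}</math> read counts are already available as <math>N \times T</math> arrays. The function names and the synthetic data are hypothetical, and the competing visualization methods (saliency maps, class optimization) are not reproduced.<br />

```python
import numpy as np

def pearson(a, b):
    """Pearson correlation coefficient between two equal-length profiles."""
    a = np.asarray(a, dtype=float) - np.mean(a)
    b = np.asarray(b, dtype=float) - np.mean(b)
    return float(a @ b / np.sqrt((a @ a) * (b @ b)))

def attention_vs_active(alpha_prom, active_reads, labels):
    """Correlate the mean H_prom bin weights with the mean H_active reads over ON genes.

    alpha_prom:   (N, T) bin-level importance weights of H_prom, one row per gene
    active_reads: (N, T) binned H_active read counts for the same genes
    labels:       (N,) labels in {-1, +1}; only +1 (ON) genes are averaged
    """
    on = np.asarray(labels) == 1
    return pearson(alpha_prom[on].mean(axis=0), active_reads[on].mean(axis=0))

# Synthetic example: both profiles peak around the TSS (bin 50).
rng = np.random.default_rng(4)
profile = np.exp(-((np.arange(100) - 50) / 10.0) ** 2)
labels = np.array([1, -1] * 25)
active = rng.poisson(5 * profile + 1, size=(50, 100))
alpha_prom = profile + 0.05 * rng.random((50, 100))
alpha_prom = alpha_prom / alpha_prom.sum(axis=1, keepdims=True)  # each row sums to 1 like alpha^j
r = attention_vs_active(alpha_prom, active, labels)
```

Averaging over genes before correlating compares the typical attention profile with the typical active-region profile, rather than comparing noisy gene-by-gene signals.<br />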
<br />
=== Visualization of Attention Weight of bins for each HM of a specific cell type GM12878===<br />
<br />
To visualize bin level attention weights, the authors plotted the average bin-level attention weights for each HM for a specific cell type GM12878 (blood cell) for expressed (ON) genes and suppressed (OFF) genes separately. <br />
<br />
[[File: figure2.png|center|thumb| 700px |]]<br />
<br />
For the “ON” genes, the attention profiles of the <math>H_{prom}</math>, <math>H_{enhc}</math>, and <math>H_{struct}</math> marks are well defined, while the weights of <math>H_{reprA}</math> and <math>H_{reprB}</math> are low. The average trend reverses for the “OFF” genes, where the repressor HM marks have more influence than <math>H_{prom}</math>, <math>H_{enhc}</math>, and <math>H_{struct}</math>. This agrees with the biological finding that the <math>H_{prom}</math>, <math>H_{enhc}</math>, and <math>H_{struct}</math> marks stimulate gene activation, while the <math>H_{reprA}</math> and <math>H_{reprB}</math> marks repress genes.<br />
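<br />
The per-class averaging behind this plot is simple to express. The NumPy sketch below assumes the bin-level attention weights for all genes of one cell type are collected in an <math>N \times M \times T</math> array. The function name <code>average_profiles</code> and the random Dirichlet-distributed weights are hypothetical stand-ins for a trained model's output.<br />

```python
import numpy as np

def average_profiles(alpha, labels):
    """Average bin-level attention profiles per HM, separately for ON and OFF genes.

    alpha:  (N, M, T) attention weights alpha^j_t for N genes, M marks, T bins
    labels: (N,) labels in {-1, +1}
    Returns {"on": (M, T), "off": (M, T)}.
    """
    labels = np.asarray(labels)
    return {
        "on": alpha[labels == 1].mean(axis=0),
        "off": alpha[labels == -1].mean(axis=0),
    }

# Random stand-in weights: 6 genes, 5 marks, 100 bins; each (gene, mark) row sums to 1.
rng = np.random.default_rng(5)
alpha = rng.dirichlet(np.ones(100), size=(6, 5))
prof = average_profiles(alpha, [1, 1, 1, -1, -1, -1])
```

Because each gene's weights for a mark sum to one, the averaged profiles do too. This keeps the ON and OFF curves on the same scale for visual comparison.<br />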
<br />
=== Attention Weight of bins with <math>H_{active}</math>===<br />
<br />
The average read counts of <math>H_{active}</math> over the same 100 bins across all active (ON) genes of cell type GM12878 are plotted in panel (b) of the figure above (Figure 2(b) of the paper). For AttentiveChrome, the bin-level attention weights averaged over all genes PREDICTED to be ON in GM12878 are also plotted. The plots show that the <math>H_{prom}</math> attention profile is similar to the <math>H_{active}</math> profile.<br />
<br />
=== Visualization of HM-level Attention Weight for Gene PAX5 ===<br />
<br />
To visualize the HM-level attention weights, the authors produce a heatmap for a differentially regulated gene, PAX5, in the three cell types mentioned above (panel (c) of the figure above, Figure 2(c) of the paper). PAX5 plays a significant role in gene regulation when stem cells differentiate into blood cells: the gene is OFF in stem cells (H1-hESC) but becomes activated once a stem cell has turned into a blood cell (GM12878). The <math>\beta^j</math> weight of <math>H_{repr}</math> is high when the gene is OFF in H1-hESC and decreases when the gene is ON in GM12878. Conversely, the <math>\beta^j</math> weight of the <math>H_{prom}</math> mark increases from H1-hESC to GM12878 as the gene becomes activated. This information extracted by the deep learning model is also supported by the biological literature [16].<br />
<br />
= Related Works/Studies =<br />
<br />
In the last few years, deep learning models have achieved unprecedented success in diverse research fields. Although adoption has been slower than in other fields, deep learning based algorithms are gaining popularity among bioinformaticians.<br />
<br />
== Attention-based Deep Models ==<br />
<br />
The idea of attention in deep learning is adapted from the human visual perception system: while perceiving a scene, humans tend to focus on some parts more than others. Augmenting deep neural networks with this mechanism has produced excellent results in several research areas, such as machine translation. Various types of attention models, e.g., soft [6], location-aware [7], and hard [8, 9] attention, have been proposed in the literature. In a soft attention model, a soft weight vector is computed over the feature vectors, and the size of each weight reflects the importance of the corresponding feature for the prediction. In practice, RNNs are often used to implement such models.<br />
<br />
== Visualization and Apprehension of Deep Models ==<br />
<br />
Prior studies mostly focused on interpreting convolutional neural networks (CNNs) for image classification. Deconvolution approaches [10] map hidden-layer representations back to the input space. Saliency maps [11, 12] use a first-order Taylor expansion to approximate the network locally and identify the most relevant input features. Class optimization [12] based visualization techniques search for the input that best represents each class. Some recent works [13, 14] have tried to understand recurrent neural networks (RNNs) for text-based problems. By examining the features a model attends to, we can interpret the output of a deep model.<br />
<br />
== Deep Learning in Bioinformatics ==<br />
Deep learning is also becoming popular in bioinformatics because it can extract meaningful representations from datasets. Researchers use deep learning to model protein sequences and DNA sequences and to predict gene expression.<br />
<br />
== Previous model for gene expression predictions ==<br />
Multiple machine learning models have been used to predict gene expression, such as linear regression and support vector machines. Strategies included averaging signals across all relevant positions, and selecting input signals at the positions most highly correlated with the target gene expression, then using a CNN to learn combinatorial interactions among histone modification marks.<br />
<br />
= Conclusion = <br />
<br />
The paper introduces an attention-based approach called "AttentiveChrome" that addresses both understanding and prediction, with several advantages over previous architectures. These include higher accuracy than state-of-the-art baselines, and clearer interpretation than saliency maps and class optimization, which lets users view what the model ‘sees’ during prediction. Another advantage is that it can model modular feature inputs that are sequentially structured. Finally, according to the authors, AttentiveChrome is the first application of deep attention to understanding gene regulation, and the first attention-based model applied to a molecular biology dataset. The authors expect this deep attention mechanism to help biologists better understand epigenomic data, since the model supports both the prediction and the interpretation of hard-to-interpret biological data.<br />
<br />
= Critiques =<br />
<br />
This paper does not make a considerable algorithmic contribution; it applies existing methods to a new application. The deep learning method is shown to perform better than simple machine learning models such as linear regression and SVMs, but it is considerably harder to implement and has many more hyperparameters to tune. Training time is also considerably higher, especially because all the parameters are learned jointly. The dataset used here also seems to contain only a limited number of samples for a problem of such high complexity. The model hyperparameters appear to have been chosen arbitrarily, without any explanation or intuition, and the authors do not cite any literature explaining where these values came from.<br />
<br />
The discussion of attention scores for interpretation does not give a clear definition of them or mention previous literature that uses them. References to the literature on H3K27ac, and on how its read counts represent the active regions of a cell, should be included. No reason is given for using only one specific cell type to visualize the bin-level attention weights. Examples of other real-world problems where this model could be useful should also be provided.<br />
<br />
Moreover, this paper relies heavily on intuition. Given the complicated model structure, providing algorithmic or theoretical justification is admittedly challenging. As a result, however, there is no proper guidance on how the hyperparameters should be chosen or how the method should be adapted to other datasets.<br />
<br />
= Additional Resources =<br />
<br />
# [https://qdata.github.io/deep4biomed-web/ Official DeepChrome Website]<br />
# [http://papers.nips.cc/paper/7255-attend-and-predict-understanding-gene-regulation-by-selective-attention-on-chromatin-supplemental.zip Supplemental Resources]<br />
# [https://github.com/QData/AttentiveChrome/blob/master/NIPS%20poster.pdf Poster]<br />
# [https://www.youtube.com/watch?v=tfgmXvSgsQE&feature=youtu.be Video Presentation]<br />
<br />
= Reference =<br />
<br />
[1] Andrew J Bannister and Tony Kouzarides. Regulation of chromatin by histone modifications. Cell Research, 21(3):381–395, 2011.<br />
<br />
[2] Anshul Kundaje, Wouter Meuleman, Jason Ernst, Misha Bilenky, Angela Yen, Alireza Heravi-Moussavi, Pouya Kheradpour, Zhizhuo Zhang, Jianrong Wang, Michael J Ziller, et al. Integrative analysis of 111 reference human epigenomes. Nature, 518(7539):317–330, 2015.<br />
<br />
[3] Singh, Ritambhara, et al. "Attend and Predict: Understanding Gene Regulation by Selective Attention on Chromatin." Advances in Neural Information Processing Systems. 2017.<br />
<br />
[4] Ritambhara Singh, Jack Lanchantin, Gabriel Robins, and Yanjun Qi. Deepchrome: deep-learning for predicting gene expression from histone modifications. Bioinformatics, 32(17):i639–i648, 2016.<br />
<br />
[5] Joanna Boros, Nausica Arnoult, Vincent Stroobant, Jean-François Collet, and Anabelle Decottignies. Polycomb repressive complex 2 and h3k27me3 cooperate with h3k9 methylation to maintain heterochromatin protein 1α at chromatin. Molecular and cellular biology, 34(19):3662–3674, 2014.<br />
<br />
[6] Dzmitry Bahdanau, Kyunghyun Cho, and Yoshua Bengio. Neural machine translation by jointly learning to align and translate. arXiv preprint arXiv:1409.0473, 2014.<br />
<br />
[7] Jan K Chorowski, Dzmitry Bahdanau, Dmitriy Serdyuk, Kyunghyun Cho, and Yoshua Bengio. Attention-based models for speech recognition. In C. Cortes, N. D. Lawrence, D. D. Lee, M. Sugiyama, and R. Garnett, editors, Advances in Neural Information Processing Systems 28, pages 577–585. Curran Associates, Inc., 2015.<br />
<br />
[8] Minh-Thang Luong, Hieu Pham, and Christopher D. Manning. Effective approaches to attention-based neural machine translation. In Empirical Methods in Natural Language Processing (EMNLP), pages 1412–1421, Lisbon, Portugal, September 2015. Association for Computational Linguistics.<br />
<br />
[9] Huijuan Xu and Kate Saenko. Ask, attend and answer: Exploring question-guided spatial attention for visual question answering. In ECCV, 2016.<br />
<br />
[10] Matthew D Zeiler and Rob Fergus. Visualizing and understanding convolutional networks. In Computer Vision–ECCV 2014, pages 818–833. Springer, 2014.<br />
<br />
[11] David Baehrens, Timon Schroeter, Stefan Harmeling, Motoaki Kawanabe, Katja Hansen, and Klaus-Robert Müller. How to explain individual classification decisions. Journal of Machine Learning Research, 11:1803–1831, 2010.<br />
<br />
[12] Karen Simonyan, Andrea Vedaldi, and Andrew Zisserman. Deep inside convolutional networks: Visualising image classification models and saliency maps. 2013.<br />
<br />
[13] Andrej Karpathy, Justin Johnson, and Fei-Fei Li. Visualizing and understanding recurrent networks. 2015.<br />
<br />
[14] Jiwei Li, Xinlei Chen, Eduard Hovy, and Dan Jurafsky. Visualizing and understanding neural models in nlp. 2015.<br />
<br />
[15] Xianjun Dong and Zhiping Weng. The correlation between histone modifications and gene expression. Epigenomics, 5(2):113–116, 2013.<br />
<br />
[16] Shane McManus, Anja Ebert, Giorgia Salvagiotto, Jasna Medvedovic, Qiong Sun, Ido Tamir, Markus Jaritz, Hiromi Tagoh, and Meinrad Busslinger. The transcription factor pax5 regulates its target genes by recruiting chromatin-modifying proteins in committed b cells. The EMBO journal, 30(12):2388–2404, 2011.</div>Gsahuhttp://wiki.math.uwaterloo.ca/statwiki/index.php?title=Attend_and_Predict:_Understanding_Gene_Regulation_by_Selective_Attention_on_Chromatin&diff=41971Attend and Predict: Understanding Gene Regulation by Selective Attention on Chromatin2018-11-30T00:23:20Z<p>Gsahu: /* Conclusion */</p>
<hr />
<div>This page contains a summary of the paper [https://arxiv.org/abs/1708.00339 "Attend and Predict: Understanding Gene Regulation by Selective Attention on Chromatin."] by Singh, Ritambhara, et al. It was published in Advances in Neural Information Processing Systems (NIPS) 2017. The code for this paper is shared [https://qdata.github.io/deep4biomed-web/ here].<br />
<br />
<br />
= Background =<br />
<br />
Gene regulation is the process of controlling which genes in a cell's DNA are turned 'on' (expressed) or 'off' (not expressed). When a gene is expressed, a functional product such as a protein is created. Even though all the cells of a multicellular organism (e.g., a human) contain the same DNA, different types of cells in that organism may express very different sets of genes. As a result, each cell type has distinct functionality; in other words, how a cell operates depends on the genes expressed in that cell. Many factors, including ‘chromatin modification marks’, influence which genes are expressed in a cell.<br />
<br />
The function of chromatin is to efficiently wrap DNA around bead-like histone structures into a condensed volume that fits into the nucleus of a cell, and to protect the DNA structure and sequence during cell division and replication. Different chemical modifications of the histones in chromatin, known as histone marks, change the spatial arrangement of the condensed DNA structure, which in turn affects the expression of genes in the histone mark's neighboring region. Histone marks can promote (obstruct) a gene being turned on by making the gene region accessible (restricted). The section of DNA where histone marks can potentially have an impact is known as the DNA flanking region or ‘gene region’. It is considered to cover 10k base pairs centered at the transcription start site (TSS), i.e., 5k base pairs in each direction. Unlike genetic mutations, histone modifications are reversible [1]. Therefore, understanding the influence of histone marks on gene regulation can assist in developing drugs for genetic diseases.<br />
<br />
= Introduction = <br />
<br />
A revolution in genomic technologies now enables us to profile genome-wide chromatin mark signals. As a result, biologists can measure gene expression and chromatin signals over the ‘gene region’ for different cell types across the whole human genome. The Roadmap Epigenome Project (REMC, publicly available) [2] recently released 2,804 genome-wide datasets covering 100 separate “normal” (not diseased) human cells/tissues; 166 of these datasets are gene expression reads and the rest are signal reads of various histone marks. The goal is to understand which histone marks are the most important and how they interact in gene regulation for each cell type.<br />
<br />
Signal reads for histone marks are high-dimensional and spatially structured. A histone modification mark can exert its influence anywhere in the gene region (10k base pairs centered on the Transcription Start Site of each gene), so it is important to understand how the mark's impact on gene expression varies across that region. Many types of histone marks in human chromatin can influence gene regulation: researchers have identified five standard histone proteins, and these can be altered in different combinations by different chemical modifications, resulting in a large number of distinct histone modification marks. Different histone modification marks can also act as a module, interacting with each other to influence gene expression.<br />
<br />
<br />
This paper proposes an attention-based deep learning model to find how these chromatin factors (histone modification marks) contribute to the gene expression of a particular cell. AttentiveChrome [3] uses a hierarchy of multiple LSTMs to discover interactions between the signals of each histone mark, and to learn dependencies among the marks in expressing a gene. The authors include two levels of soft attention: (1) one to attend to the most relevant signals of each histone mark, and (2) one to attend to the important marks and their interactions. In this context, ''attention'' means assigning different importance weights to different items.<br />
<br />
== Main Contributions ==<br />
The contributions of this work can be summarized as follows:<br />
<br />
* More accurate predictions than the state-of-the-art baselines. This is measured using datasets from REMC on 56 different cell types.<br />
* Better interpretation than the state-of-the-art methods for visualizing deep learning models. This is evaluated by computing the correlation of the model's attention scores with a mark signal from REMC that was not used in training.<br />
* Like earlier attention models, which indirectly indicate the parts of the input that the model deemed important, AttentiveChrome can explain its decisions by indicating “what” and “where” it has focused.<br />
* This is the first time that an attention-based deep learning approach has been applied to a problem in molecular biology.<br />
* Ability to deal with highly modular inputs.<br />
<br />
= Previous Works = <br />
<br />
Machine learning algorithms that classify gene expression from histone modification signals have been surveyed in [15]. These algorithms range from linear regression, support vector machines, and random forests to rule-based learning and CNNs. To accommodate the spatially structured, high-dimensional input data (histone modification signals), these studies applied different feature selection strategies. The authors' preceding study, DeepChrome [4], used a best-position selection strategy, in which the positions most highly correlated with gene expression are considered the best positions. DeepChrome, a CNN-based model, can learn the relationships between the histone marks and outperforms all the previous works. However, these approaches either (1) failed to model the spatial dependencies among the marks, or (2) required additional feature analysis. Only AttentiveChrome is reported to satisfy all eight desirable properties of a model that the authors compare.<br />
<br />
= AttentiveChrome: Model Formulation =<br />
<br />
The authors propose an end-to-end architecture that can simultaneously attend and predict. This method uses recurrent neural networks (RNNs) composed of LSTM units to model the sequential spatial dependencies of the gene regions and to predict gene expression levels from histone modification signals. The embedding vector <math> h_t </math> output by an LSTM module encodes the learned representation of the feature dependencies up to time step <math> t </math>. For this task, each bin position of the gene region is considered a time step.<br />
<br />
The proposed AttentiveChrome framework contains the following five modules:<br />
<br />
* Bin-level LSTM encoder encoding the bin positions of the gene region (one for each HM mark)<br />
* Bin-level <math> \alpha </math>-Attention across all bin positions (one for each HM mark)<br />
* HM-level LSTM encoder (one encoder encoding all HM marks)<br />
* HM-level <math> \beta </math>-Attention among all HM marks (one)<br />
* The final classification module<br />
<br />
Figure 1 (Supplementary Figure 2) presents an overview of the proposed AttentiveChrome framework.<br />
<br />
<br />
[[File:supplemntary_figure_2.png|thumb|center| 800px |Figure 1: Overview of all five modules of the proposed AttentiveChrome framework]]<br />
<br />
<br />
<br />
== Input and Output ==<br />
<br />
Each dataset contains the gene expression labels and the histone signal reads for one specific cell type. The authors evaluated AttentiveChrome on 56 different cell types. For each mark, we have a feature (input) vector containing the signal reads surrounding the gene’s TSS position (the gene region) for that histone mark. The label of this input denotes the expression of the specific gene. This study uses binary labels, where <math> +1 </math> denotes that the gene is expressed (on) and <math> -1 </math> denotes that the gene is not expressed (off). Each histone mark has one feature vector for each gene. The authors adopt the input features and output labels of their previous work, DeepChrome [4]. The input feature is represented by a matrix <math> \textbf{X} </math> of size <math> M \times T </math>, where <math> M </math> is the number of HM marks considered in the input, and <math> T </math> is the number of bin positions used to represent the gene region. The <math> j^{th} </math> row of the matrix <math> \textbf{X} </math>, <math> x_j</math>, represents the sequentially structured signals from the <math> j^{th} </math> HM mark, where <math> j\in \{1, \cdots, M\} </math>. Therefore, <math> x_j^t</math> in the matrix <math> \textbf{X} </math> represents the value from the <math> t^{th}</math> bin belonging to the <math> j^{th} </math> HM mark, where <math> t\in \{1, \cdots, T\} </math>. If the training set contains <math>N_{tr} </math> labeled pairs, the <math> n^{th} </math> pair is specified as <math>( X^n, y^n)</math>, where <math> X^n </math> is a matrix of size <math> M \times T </math>, <math> y^n \in \{ -1, +1 \} </math> is the binary label, and <math> n \in \{ 1, \cdots, N_{tr} \} </math>.<br />
<br />
Figure 2 (see also Figure 1(a) and 1(b)) shows the input feature and the output of AttentiveChrome for a particular gene (one sample).<br />
<br />
[[File:input-output-attentivechrome.png|center|thumb| 700px | Figure 2: Input and Output of the AttentiveChrome model]]<br />
<br />
== Bin-Level Encoder (one LSTM for each HM) ==<br />
For the <math> n^{th} </math> gene, the sequentially ordered bins of the gene region for the <math> j^{th} </math> HM are represented by the <math> j^{th} </math> row vector <math> x^j </math> of its input matrix. The authors treat each bin position as a time step for the LSTM. This study uses a bidirectional LSTM to model the overall dependencies among all <math> T </math> bin positions in the gene region. The bidirectional LSTM contains two LSTMs:<br />
* A forward LSTM, <math> \overrightarrow{LSTM_j} </math>, to model <math> x^j </math> from <math> x_1^j </math> to <math> x_T^j </math>, which outputs the embedding vector <math> \overrightarrow{h^j_t} </math> of size <math> d </math> for each bin <math> t </math><br />
* A reverse LSTM, <math> \overleftarrow{LSTM_j} </math>, to model <math> x^j </math> from <math> x_T^j </math> to <math> x_1^j </math>, which outputs the embedding vector <math> \overleftarrow{h^j_t} </math> of size <math> d </math> for each bin <math> t </math><br />
<br />
The final output of this layer, the embedding vector <math> h^j_t </math> at the <math> t^{th} </math> bin for the <math> j^{th} </math> HM, is obtained by concatenating the vectors from both directions: <math> h^j_t = [ \overrightarrow{h^j_t}, \overleftarrow{h^j_t}]</math>, which therefore has size <math> 2d </math>. Because these LSTM-based HM encoders are trained jointly with the final classification module, they learn to embed each HM mark by capturing the dependencies among its bins. Figure 1(c) illustrates this module for <math> j=2 </math>.<br />
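The bidirectional bin-level encoder can be sketched in plain Python as below. This is a minimal, untrained illustration and not the authors' implementation: the names `ScalarInputLSTM` and `bin_level_encoder` are hypothetical, the weights are random rather than learned, and each time step takes a single scalar bin value. It shows how the forward and reverse passes are aligned bin by bin and concatenated, so that each <math> h^j_t </math> has size <math> 2d </math>.<br />

```python
import math
import random


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


class ScalarInputLSTM:
    """Hypothetical minimal LSTM: one scalar input per time step, hidden size d.

    Weights are random here; in AttentiveChrome they would be learned end-to-end.
    """

    def __init__(self, d, seed=0):
        rng = random.Random(seed)
        self.d = d
        # Parameters for the input (i), forget (f), output (o) gates and the candidate (g).
        self.params = {
            gate: {
                "w": [rng.uniform(-0.1, 0.1) for _ in range(d)],  # input -> gate
                "u": [[rng.uniform(-0.1, 0.1) for _ in range(d)] for _ in range(d)],  # hidden -> gate
                "b": [0.0] * d,
            }
            for gate in ("i", "f", "o", "g")
        }

    def _preact(self, gate, x, h):
        p = self.params[gate]
        return [p["w"][k] * x + sum(p["u"][k][m] * h[m] for m in range(self.d)) + p["b"][k]
                for k in range(self.d)]

    def run(self, xs):
        """Return the hidden state h_t after each time step t."""
        h, c, outputs = [0.0] * self.d, [0.0] * self.d, []
        for x in xs:
            i = [sigmoid(z) for z in self._preact("i", x, h)]
            f = [sigmoid(z) for z in self._preact("f", x, h)]
            o = [sigmoid(z) for z in self._preact("o", x, h)]
            g = [math.tanh(z) for z in self._preact("g", x, h)]
            c = [f[k] * c[k] + i[k] * g[k] for k in range(self.d)]
            h = [o[k] * math.tanh(c[k]) for k in range(self.d)]
            outputs.append(h)
        return outputs


def bin_level_encoder(x_j, d, seed=0):
    """Encode one HM row x^j (T bin values) into T embeddings h^j_t of size 2d."""
    forward = ScalarInputLSTM(d, seed=seed)
    backward = ScalarInputLSTM(d, seed=seed + 1)
    h_fwd = forward.run(x_j)               # reads bins 1..T
    h_bwd = backward.run(x_j[::-1])[::-1]  # reads bins T..1, then re-aligned to bin order
    return [hf + hb for hf, hb in zip(h_fwd, h_bwd)]  # list concatenation = [->h, <-h]


x_j = [0.0, 2.0, 5.0, 1.0]  # toy read counts for T = 4 bins (illustrative only)
h_j = bin_level_encoder(x_j, d=3)
print(len(h_j), len(h_j[0]))  # 4 6
```

Reversing the backward outputs before concatenation matters. Without that step, the reverse-direction state for bin <math> t </math> would be paired with the forward state of the wrong bin.<br />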
<br />
== Bin-Level <math> \alpha</math>-attention ==<br />
<br />
Each bin contributes differently to the encoding of the entire <math> j^{th} </math> mark. To automatically and adaptively highlight the most important bins for prediction, a soft attention weight vector <math> \alpha^j </math> of size <math> T </math> is learned for each <math> j </math>. To calculate the soft weight <math> \alpha^j_t </math> for each <math> t </math>, the embedding vectors <math> \{h^j_1, \cdots, h^j_T \} </math> of all the bins are used, as follows:<br />
<br />
<center><math> \alpha^j_t = \frac{exp(\textbf{W}_b h^j_t)}{\sum_{i=1}^T{exp(\textbf{W}_b h^j_i)}} </math></center><br />
<br />
<br />
<math> \alpha^j_t</math> is a scalar computed from the embedding vectors of all bins of the <math> j^{th} </math> HM. The context parameter <math> \textbf{W}_b </math> is initialized randomly and learned jointly with the other model parameters. Once the importance weight of each bin position is known, the <math> j^{th} </math> HM mark is represented by <math> m^j = \sum_{t=1}^T{\alpha^j_t \times h^j_t}</math>. Here, <math> h^j_t</math> is the embedding vector and <math> \alpha^j_t </math> is the importance weight of the <math> t^{th} </math> bin in the representation of the <math> j^{th} </math> HM mark. Intuitively, <math> \textbf{W}_b </math> learns which bin positions are important for the cell type. Figure 1(d) shows this module for <math> HM_2 </math>.<br />
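The bin-level attention and the weighted sum <math> m^j </math> can be illustrated with the short sketch below. It assumes that <math> \textbf{W}_b </math> is a vector, so that <math> \textbf{W}_b h^j_t </math> is a scalar score. The function names and toy numbers are hypothetical; in the model, both the embeddings and <math> \textbf{W}_b </math> would come from training.<br />

```python
import math


def softmax(scores):
    top = max(scores)  # subtracting the max avoids overflow and does not change the result
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def bin_attention(h_j, w_b):
    """Return (alpha^j, m^j) for one HM, given its bin embeddings h^j_1..h^j_T."""
    scores = [sum(w * x for w, x in zip(w_b, h_t)) for h_t in h_j]  # W_b h^j_t
    alpha = softmax(scores)                                          # alpha^j_t
    dim = len(h_j[0])
    m_j = [sum(alpha[t] * h_j[t][k] for t in range(len(h_j))) for k in range(dim)]
    return alpha, m_j


h_j = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # toy embeddings for T = 3 bins
w_b = [1.0, 1.0]                            # toy context vector W_b
alpha, m_j = bin_attention(h_j, w_b)
print([round(a, 3) for a in alpha])  # [0.212, 0.212, 0.576]
```

The third bin has the largest score, so it receives the largest weight. <math> m^j </math> keeps the same size as each <math> h^j_t </math>.<br />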
<br />
== HM-level Encoder (one LSTM) ==<br />
<br />
Studies have observed that HMs work cooperatively to provoke or subdue gene expression [5]. The HM-level encoder (not shown in Figure 1) uses one bidirectional LSTM to capture these relationships between the HMs. Because the authors found no influence of any specific ordering of the HMs, the HMs are arranged in an arbitrary order to form the input sequence. The representation <math> m^j </math> of the <math> j^{th} </math> HM, <math> HM_j </math>, computed by the bin-level attention layer, is the input to this step. This set-based encoder outputs an embedding vector <math> s^j </math> of size <math> 2d' </math>, which is the encoding for the <math> j^{th} </math> HM.<br />
<br />
<math> s^j = [ \overrightarrow{LSTM_s}(m^j), \overleftarrow{LSTM_s}(m^j) ] </math><br />
<br />
The dependencies between <math> j^{th} </math> HM and the other HM marks are encoded in <math> s^j </math>, whereas <math> m^j </math> from the previous step encodes the bin dependencies of the <math> j^{th} </math> HM.<br />
<br />
<br />
== HM-Level <math> \beta</math>-attention ==<br />
This second soft attention level (Figure 1(e)) finds the important HM marks for classifying a gene’s expression by learning the importance weights <math> \beta^j </math> for each <math> HM_j </math>, where <math> j \in \{ 1, \cdots, M \} </math>. The equation is <br />
<br />
<math> \beta^j = \frac{exp(\textbf{W}_s s^j)}{\sum_{i=1}^M{exp(\textbf{W}_s s^i)}} </math><br />
<br />
The HM-level context parameter <math> \textbf{W}_s </math> is trained jointly with the rest of the model. Intuitively, <math> \textbf{W}_s </math> learns how significant each HM is for a cell type. Finally, the entire gene region is encoded in a hidden representation <math> \textbf{v} </math>, using the weighted sum of the embeddings of all HM marks: <br />
<br />
<br />
<math> \textbf{v} = \sum_{j=1}^M{\beta^j \times s^j}</math><br />
<br />
== End-to-end training ==<br />
<br />
The embedding vector <math> \textbf{v} </math> is fed to a simple classification module, <math> f(\textbf{v}) = </math>softmax<math> (\textbf{W}_c\textbf{v}+b_c) </math>, where <math> \textbf{W}_c </math>, and <math> b_c </math> are learnable parameters. The output is the probability of gene expression being high (expressed) or low (suppressed).<br />
The whole model, including the attention modules, is differentiable. It can therefore be trained end-to-end with backpropagation by minimizing the negative log-likelihood loss.<br />
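The HM-level <math>\beta</math>-attention, the classification module, and the loss can be combined in a small sketch. The toy HM embeddings <math> s^j </math>, the parameter values, and the function names are hypothetical. Mapping class index 0 to OFF (<math>-1</math>) and index 1 to ON (<math>+1</math>) is an assumption made for illustration.<br />

```python
import math


def softmax(scores):
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def hm_attention_and_classify(s, w_s, w_c, b_c):
    """s: list of M HM embeddings s^j. Returns (beta, v, class probabilities)."""
    beta = softmax([dot(w_s, s_j) for s_j in s])                              # beta^j
    dim = len(s[0])
    v = [sum(beta[j] * s[j][k] for j in range(len(s))) for k in range(dim)]   # v = sum_j beta^j s^j
    probs = softmax([dot(row, v) + b for row, b in zip(w_c, b_c)])            # softmax(W_c v + b_c)
    return beta, v, probs


def nll_loss(probs, y):
    """Negative log-likelihood for one gene; index 0 = OFF (y = -1), index 1 = ON (y = +1)."""
    return -math.log(probs[1] if y == +1 else probs[0])


s = [[0.5, -0.2], [0.1, 0.3], [-0.4, 0.8], [0.0, 0.0], [0.2, 0.2]]  # M = 5 toy HM embeddings
w_s = [1.0, 0.5]                  # toy HM-level context vector W_s
w_c = [[0.3, -0.1], [-0.3, 0.1]]  # toy classifier weights W_c (2 classes)
b_c = [0.0, 0.0]                  # toy bias b_c
beta, v, probs = hm_attention_and_classify(s, w_s, w_c, b_c)
loss = nll_loss(probs, y=+1)
```

In training, this loss would be averaged over genes and its gradient propagated back through both attention levels and both LSTM levels.<br />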
<br />
= Experimental Settings =<br />
<br />
This work uses the REMC dataset, and AttentiveChrome is evaluated on 56 different cell types. Following DeepChrome, this study considers the five core HM marks listed in Table 1 (<math> M=5 </math>), because these marks are uniformly profiled across all 56 cell types in the REMC study.<br />
<br />
[[File:HM.png|center|thumb| 700px | Table 1: Five core HM marks and their attributes considered in this paper]]<br />
<br />
<br />
<br />
For each gene, the gene region spans 10k base pairs centred at the TSS (5k bp in each direction). These 10k base pairs are divided into <math> T=100 </math> bins, each consisting of 100 contiguous bp. Therefore, for each gene in a particular cell type, the input matrix has size <math> 5 \times 100 </math>. The gene expression labels are normalized and discretized into binary labels. The dataset is divided into three equal-sized folds for training, validation, and testing.<br />
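The binning step can be sketched as follows. Purely for illustration, this sketch assumes that the raw data for each HM is a list of read positions (genomic coordinates) and that a bin's value is the number of reads falling in it. The function names are hypothetical, and the actual REMC/DeepChrome preprocessing may differ.<br />

```python
def bin_reads(read_positions, tss, half_window=5_000, n_bins=100):
    """Count reads in n_bins equal bins covering [tss - half_window, tss + half_window)."""
    start = tss - half_window
    window = 2 * half_window
    bin_size = window // n_bins  # 10,000 bp / 100 bins = 100 bp per bin
    counts = [0] * n_bins
    for pos in read_positions:
        offset = pos - start
        if 0 <= offset < window:  # reads outside the gene region are ignored
            counts[offset // bin_size] += 1
    return counts


def build_input_matrix(reads_per_hm, tss):
    """Stack one binned row per HM mark into an M x T matrix X (list of lists)."""
    return [bin_reads(reads, tss) for reads in reads_per_hm]


tss = 1_000_000  # hypothetical TSS coordinate
reads_per_hm = [                        # toy read positions for M = 5 marks
    [tss - 5_000, tss, tss + 4_999],
    [tss + 50],
    [],
    [tss - 5_001, tss + 5_000],         # both just outside the window
    [tss - 1],
]
X = build_input_matrix(reads_per_hm, tss)
print(len(X), len(X[0]))  # 5 100
```

The first and last bins cover the two ends of the 10k bp window, and the TSS falls at the start of bin index 50.<br />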
<br />
== Model Variations and Two Baselines ==<br />
To evaluate the proposed model, the authors use two baselines: an RNN method (a plain LSTM without any attention) and their prior work, DeepChrome. The results of multiple variations of the AttentiveChrome model are compared with these baselines. The authors consider five variants of AttentiveChrome:<br />
<br />
* LSTM-Attn: one LSTM with attention on the input matrix (does not consider the modular nature of HM marks)<br />
* CNN-Attn: DeepChrome [4] with one attention mechanism incorporated. <br />
* LSTM-<math>\alpha , \beta</math>: the proposed architecture.<br />
* CNN-<math>\alpha , \beta</math>: the proposed architecture with its LSTM module replaced by a CNN. This variant includes two attention mechanisms: the first is an <math>\alpha</math>-attention on top of a CNN module for each HM mark, and the second, a <math>\beta</math>-attention, combines the HMs.<br />
* LSTM-<math>\alpha</math>: one LSTM and <math>\alpha</math>-attention per HM mark.<br />
<br />
== Hyperparameters ==<br />
<br />
For all variants of AttentiveChrome, the bin-level LSTM embedding size <math> d</math> is set to 32 and the HM-level LSTM embedding size <math>d'</math> is set to 16. Because the LSTMs are bidirectional, the embedding vectors <math> h_t</math> (and hence <math>m^j</math>, their weighted sum) have size 64, and <math>s^j</math> has size 32. The sizes of the context vectors are set accordingly.<br />
<br />
= Performance Evaluation =<br />
<br />
== AUC Scores ==<br />
<br />
This study summarizes AUC scores across all 56 cell types on the test set to compare the methods.<br />
<br />
[[File:AUC.JPG|center|thumb| 700px | Table 2: AUC score performances for different variations of AttentiveChrome and baselines]]<br />
<br />
Overall, the LSTM-attention models perform better than the DeepChrome (CNN-based) and LSTM baselines. The authors argue that AttentiveChrome is a good choice because of its interpretability, even though its performance improvement over DeepChrome is small.<br />
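As a reminder of what the AUC measures, the sketch below computes it for one cell type as the probability that a randomly chosen ON gene receives a higher predicted score than a randomly chosen OFF gene (ties count as one half). It then summarizes the per-cell-type AUCs. The scores, labels, and cell-type names are toy values, not results from the paper. The pairwise loop is O(P×N) and is chosen for clarity rather than speed.<br />

```python
import statistics


def roc_auc(labels, scores):
    """AUC for labels in {-1, +1}: P(score of an ON gene > score of an OFF gene), ties = 1/2."""
    pos = [s for s, y in zip(scores, labels) if y == +1]
    neg = [s for s, y in zip(scores, labels) if y == -1]
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


# Toy predictions for two hypothetical cell types (illustrative values only).
per_cell_type = {
    "cell_type_A": ([+1, -1, +1, -1], [0.9, 0.2, 0.6, 0.7]),
    "cell_type_B": ([+1, -1, -1], [0.8, 0.3, 0.8]),
}
aucs = {name: roc_auc(y, s) for name, (y, s) in per_cell_type.items()}
print(aucs)  # {'cell_type_A': 0.75, 'cell_type_B': 0.75}
print(statistics.median(aucs.values()))  # 0.75
```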
<br />
== Evaluation of Attention Scores for Interpretation ==<br />
<br />
To check whether the model focuses on the right regions, the authors use additional data from the REMC database. To validate the bin-level attention, they take signal data for a new histone mark, H3K27ac (referred to as <math>H_{active}</math> in this article), from the REMC database. This histone mark is known to mark active regions when the gene is expressed (ON). Genome-wide reads of this HM mark are available for three important cell types: stem cell (H1-hESC), blood cell (GM12878), and leukemia cell (K562). This HM mark is used only to analyze the visualization results and is not used in the learning phase. This section evaluates both attention mechanisms.<br />
<br />
=== Correlation of Importance Weight of <math>H_{prom}</math> with <math>H_{active}</math> ===<br />
<br />
The average read count of <math>H_{active}</math> at each of the 100 bins, over all active genes (ON, labeled <math>+1</math>), is calculated for each of the three selected cell types. The proposed AttentiveChrome and LSTM-<math>\alpha</math> methods are compared with two widely used visualization techniques, (1) class-based optimization and (2) saliency maps, applied to the baseline DeepChrome model (the CNN-based prior work). Using each of these methods, the authors compute the importance weights of <math>H_{prom}</math> (a promoter HM mark used in training) across the 100 bins. They then compute the Pearson correlation between these importance weights and the read counts of <math>H_{active}</math> (the HM mark used for validation) across the same 100 bins. The <math>H_{active}</math> read counts indicate the actual active regions of those cells.<br />
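The correlation analysis amounts to averaging two per-bin profiles over the active genes and computing Pearson's r between them. In the sketch below, the per-gene attention weights and <math>H_{active}</math> read counts are toy values over 3 bins instead of 100, and the function names are hypothetical. Averaging the importance weights over active genes is an assumption about how the per-bin weights are aggregated.<br />

```python
import math


def pearson(a, b):
    """Pearson correlation coefficient between two equal-length sequences."""
    n = len(a)
    mean_a = sum(a) / n
    mean_b = sum(b) / n
    cov = sum((x - mean_a) * (y - mean_b) for x, y in zip(a, b))
    var_a = sum((x - mean_a) ** 2 for x in a)
    var_b = sum((y - mean_b) ** 2 for y in b)
    return cov / math.sqrt(var_a * var_b)


def average_profile(rows):
    """Average per-bin values over genes; rows is a list of length-T lists."""
    T = len(rows[0])
    return [sum(r[t] for r in rows) / len(rows) for t in range(T)]


# Toy data for two active genes over T = 3 bins (illustrative only).
alpha_prom_on_genes = [[0.1, 0.6, 0.3], [0.2, 0.5, 0.3]]  # H_prom attention weights
h_active_on_genes = [[1, 9, 4], [3, 7, 2]]                # H_active read counts
r = pearson(average_profile(alpha_prom_on_genes), average_profile(h_active_on_genes))
print(round(r, 3))
```

A value of r close to 1 means that the bins the model attends to for <math>H_{prom}</math> line up with the bins where the independent <math>H_{active}</math> signal is high.<br />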
<br />
[[File: pc.JPG|center|thumb| 700px | Figure 4: Pearson correlation between the importance weights of the promoter mark and the read counts of a known active HM mark (H3K27ac)]]<br />
<br />
<br />
The results show that the proposed models consistently achieve the highest correlation with <math>H_{active}</math> for all three cell types, which suggests that the proposed method successfully captures the important signals.<br />
<br />
=== Visualization of Attention Weight of bins for each HM of a specific cell type GM12878===<br />
<br />
To visualize bin level attention weights, the authors plotted the average bin-level attention weights for each HM for a specific cell type GM12878 (blood cell) for expressed (ON) genes and suppressed (OFF) genes separately. <br />
<br />
[[File: figure2.png|center|thumb| 700px |]]<br />
<br />
For the “ON” genes, the attention profiles are well defined for the HM marks <math>H_{prom}</math>, <math>H_{enhc}</math>, and <math>H_{struct}</math>, whereas the weights are low for <math>H_{reprA}</math> and <math>H_{reprB}</math>. The average trend reverses for the “OFF” genes, where the repressor HM marks have more influence than <math>H_{prom}</math>, <math>H_{enhc}</math>, and <math>H_{struct}</math>. This observation agrees with the biological finding that the <math>H_{prom}</math>, <math>H_{enhc}</math>, and <math>H_{struct}</math> marks stimulate gene activation, while the <math>H_{reprA}</math> and <math>H_{reprB}</math> marks repress genes.<br />
<br />
=== Attention Weight of bins with <math>H_{active}</math>===<br />
<br />
The average read counts of <math>H_{active}</math> over the same 100 bins across all active (ON) genes of cell type GM12878 are plotted in Figure 2(b). For AttentiveChrome, the figure also plots the bin-level attention weights of <math>H_{prom}</math>, averaged over all genes that are predicted ON for GM12878. The plots show that the <math>H_{prom}</math> attention profile is similar to the <math>H_{active}</math> profile.<br />
<br />
=== Visualization of HM-level Attention Weight for Gene PAX5 ===<br />
<br />
To visualize the HM-level attention weights, the authors produce a heatmap for a differentially regulated gene, PAX5, in the three aforementioned cell types (Figure 2(c)). PAX5 plays a significant role in gene regulation when stem cells differentiate into blood cells. This gene is OFF in stem cells (H1-hESC), but it becomes activated when the stem cell is transformed into a blood cell (GM12878). The <math>\beta_j</math> weight for <math>H_{repr}</math> is high when the gene is OFF in H1-hESC and decreases when the gene is ON in GM12878. Conversely, the <math>\beta_j</math> weight for the <math>H_{prom}</math> mark increases from H1-hESC to GM12878 as the gene becomes activated. This information extracted by the deep learning model is also supported by the biological literature [16].<br />
<br />
= Related Works/Studies =<br />
<br />
In the last few years, deep learning models have achieved unprecedented success in diverse research fields. Although not as rapidly as in other fields, deep learning based algorithms are also gaining popularity among bioinformaticians.<br />
<br />
== Attention-based Deep Models ==<br />
<br />
The attention technique in deep learning is inspired by the human visual perception system: while perceiving a scene, humans tend to focus on some parts more than others. Neural networks augmented with this mechanism have achieved excellent results in several research areas, such as machine translation. Various types of attention models, e.g., soft [6], location-aware [7], and hard [8, 9] attention, have been proposed in the literature. In the soft attention model, a soft weight vector is calculated over all the feature vectors, and the magnitude of each weight reflects the importance of the corresponding feature for the prediction. In practice, RNNs are often used to implement such models.<br />
<br />
== Visualization and Apprehension of Deep Models ==<br />
<br />
Prior studies mostly focused on interpreting convolutional neural networks (CNNs) for image classification. Deconvolution approaches [10] map hidden layer representations back to the input space. Saliency maps [11, 12] use a first-order Taylor expansion to approximate the network and identify the most relevant input features. Class-optimization-based visualization techniques [12] attempt to find the best example member of each class. Some recent works [13, 14] have tried to understand recurrent neural networks (RNNs) for text-based problems. By looking at the features a model attends to, we can interpret the output of a deep model.<br />
<br />
== Deep Learning in Bioinformatics ==<br />
Deep learning is also becoming popular in bioinformatics because it can extract meaningful representations from datasets. Researchers use deep learning to model protein and DNA sequences and to predict gene expression.<br />
<br />
== Previous model for gene expression predictions ==<br />
Multiple machine learning models, such as linear regression and support vector machines, have been used to predict gene expression. Their strategies included averaging signals across all relevant positions, or selecting input signals at positions that are highly correlated with the target gene expression and then using a CNN to learn combinatorial interactions among histone modification marks.<br />
<br />
= Conclusion = <br />
<br />
The paper introduces an attention-based approach called "AttentiveChrome" that addresses both understanding and prediction. It has several advantages over previous architectures: higher accuracy than state-of-the-art baselines, and clearer interpretation than saliency maps and class optimization, which lets one view what the model ‘sees’ during prediction. Another advantage of this approach is that it can model modular feature inputs that are sequentially structured. Finally, according to the authors, this is the first application of deep attention to understanding gene regulation, and AttentiveChrome is claimed to be the first attention-based model applied to a molecular biology dataset. The authors expect that this deep attention mechanism will help biologists better understand epigenomic data. The model supports both understanding and prediction on hard-to-interpret biological data.<br />
<br />
= Critiques =<br />
<br />
This paper does not make a substantial algorithmic contribution; it applies existing methods to a new application. This deep learning method is shown to perform better than simple machine learning models such as linear regression and SVMs, but it is considerably harder to implement and has many more hyperparameters to tune. Training time is also considerably higher, especially because all the parameters are learned jointly. The dataset used here also seems to contain only a limited number of samples for a model of this complexity. The hyperparameters appear to have been chosen without any explanation of the intuition behind them, and the authors do not cite any relevant literature to explain where these values came from.<br />
<br />
The discussion of attention scores for interpretation neither defines them clearly nor mentions previous literature that uses them. References to the literature on H3K27ac, and on how its read counts represent the active regions of a cell, should be included. No reasoning is given for why only one specific cell type is used to visualize the bin-level attention weights. Examples of other real-world problems where this model could be useful should be provided.<br />
<br />
Moreover, this paper relies heavily on intuition. Given the complicated structure of the model, it would be challenging to provide algorithmic or theoretical justifications. As a result, there is no proper guidance on how the hyperparameters should be chosen, or on how the approach should be applied to other datasets.<br />
<br />
= Additional Resources =<br />
<br />
# [https://qdata.github.io/deep4biomed-web/ Official DeepChrome Website]<br />
# [http://papers.nips.cc/paper/7255-attend-and-predict-understanding-gene-regulation-by-selective-attention-on-chromatin-supplemental.zip Supplemental Resources]<br />
# [https://github.com/QData/AttentiveChrome/blob/master/NIPS%20poster.pdf Poster]<br />
# [https://www.youtube.com/watch?v=tfgmXvSgsQE&feature=youtu.be Video Presentation]<br />
<br />
= Reference =<br />
<br />
[1] Andrew J Bannister and Tony Kouzarides. Regulation of chromatin by histone modifications. Cell Research, 21(3):381–395, 2011.<br />
<br />
[2] Anshul Kundaje, Wouter Meuleman, Jason Ernst, Misha Bilenky, Angela Yen, Alireza Heravi-Moussavi, Pouya Kheradpour, Zhizhuo Zhang, Jianrong Wang, Michael J Ziller, et al. Integrative analysis of 111 reference human epigenomes. Nature, 518(7539):317–330, 2015.<br />
<br />
[3] Ritambhara Singh, Jack Lanchantin, Arshdeep Sekhon, and Yanjun Qi. Attend and predict: Understanding gene regulation by selective attention on chromatin. In Advances in Neural Information Processing Systems 30, 2017.<br />
<br />
[4] Ritambhara Singh, Jack Lanchantin, Gabriel Robins, and Yanjun Qi. Deepchrome: deep-learning for predicting gene expression from histone modifications. Bioinformatics, 32(17):i639–i648, 2016.<br />
<br />
[5] Joanna Boros, Nausica Arnoult, Vincent Stroobant, Jean-François Collet, and Anabelle Decottignies. Polycomb repressive complex 2 and h3k27me3 cooperate with h3k9 methylation to maintain heterochromatin protein 1α at chromatin. Molecular and cellular biology, 34(19):3662–3674, 2014.<br />
<br />
[6] Dzmitry Bahdanau, Kyunghyun Cho, and Yoshua Bengio. Neural machine translation by jointly learning to align and translate. arXiv preprint arXiv:1409.0473, 2014.<br />
<br />
[7] Jan K Chorowski, Dzmitry Bahdanau, Dmitriy Serdyuk, Kyunghyun Cho, and Yoshua Bengio. Attention-based models for speech recognition. In C. Cortes, N. D. Lawrence, D. D. Lee, M. Sugiyama, and R. Garnett, editors, Advances in Neural Information Processing Systems 28, pages 577–585. Curran Associates, Inc., 2015.<br />
<br />
[8] Minh-Thang Luong, Hieu Pham, and Christopher D. Manning. Effective approaches to attention-based neural machine translation. In Empirical Methods in Natural Language Processing (EMNLP), pages 1412–1421, Lisbon, Portugal, September 2015. Association for Computational Linguistics.<br />
<br />
[9] Huijuan Xu and Kate Saenko. Ask, attend and answer: Exploring question-guided spatial attention for visual question answering. In ECCV, 2016.<br />
<br />
[10] Matthew D Zeiler and Rob Fergus. Visualizing and understanding convolutional networks. In Computer Vision–ECCV 2014, pages 818–833. Springer, 2014.<br />
<br />
[11] David Baehrens, Timon Schroeter, Stefan Harmeling, Motoaki Kawanabe, Katja Hansen, and Klaus-Robert Müller. How to explain individual classification decisions. Journal of Machine Learning Research, 11:1803–1831, 2010.<br />
<br />
[12] Karen Simonyan, Andrea Vedaldi, and Andrew Zisserman. Deep inside convolutional networks: Visualising image classification models and saliency maps. 2013.<br />
<br />
[13] Andrej Karpathy, Justin Johnson, and Fei-Fei Li. Visualizing and understanding recurrent networks. 2015.<br />
<br />
[14] Jiwei Li, Xinlei Chen, Eduard Hovy, and Dan Jurafsky. Visualizing and understanding neural models in nlp. 2015.<br />
<br />
[15] Xianjun Dong and Zhiping Weng. The correlation between histone modifications and gene expression. Epigenomics, 5(2):113–116, 2013.<br />
<br />
[16] Shane McManus, Anja Ebert, Giorgia Salvagiotto, Jasna Medvedovic, Qiong Sun, Ido Tamir, Markus Jaritz, Hiromi Tagoh, and Meinrad Busslinger. The transcription factor pax5 regulates its target genes by recruiting chromatin-modifying proteins in committed b cells. The EMBO journal, 30(12):2388–2404, 2011.</div>Gsahuhttp://wiki.math.uwaterloo.ca/statwiki/index.php?title=Attend_and_Predict:_Understanding_Gene_Regulation_by_Selective_Attention_on_Chromatin&diff=41970Attend and Predict: Understanding Gene Regulation by Selective Attention on Chromatin2018-11-30T00:22:19Z<p>Gsahu: /* Bin-Level \alpha-attention */</p>
<hr />
<div>This page contains a summary of the paper [https://arxiv.org/abs/1708.00339 "Attend and Predict: Understanding Gene Regulation by Selective Attention on Chromatin."] by Singh, Ritambhara, et al. It was published in Advances in Neural Information Processing Systems (NIPS) 2017. The code for this paper is shared [https://qdata.github.io/deep4biomed-web/ here].<br />
<br />
<br />
= Background =<br />
<br />
Gene regulation is the process of controlling which genes in a cell's DNA are turned 'on' (expressed) or 'off' (not expressed). When a gene is expressed, a functional product such as a protein is created. Even though all the cells of a multicellular organism (e.g., humans) contain the same DNA, different types of cells in that organism may express very different sets of genes. As a result, each cell type has distinct functionality; in other words, how a cell operates depends on the genes expressed in that cell. Many factors, including ‘chromatin modification marks’, influence which genes are expressed in a cell.<br />
<br />
The function of chromatin is to efficiently wrap DNA around bead-like structures of histones into a condensed volume that fits into the nucleus of a cell, and to protect the DNA structure and sequence during cell division and replication. Different chemical modifications of the histones in the chromatin, known as histone marks, change the spatial arrangement of the condensed DNA structure. This in turn affects the expression of genes in the histone mark’s neighboring region. Histone marks can promote (or obstruct) gene activation by making the gene region accessible (or restricted). This section of the DNA, where histone marks can potentially have an impact, is known as the DNA flanking region or ‘gene region’. It is considered to cover 10k base pairs centered at the transcription start site (TSS), i.e., 5k base pairs in each direction. Unlike genetic mutations, histone modifications are reversible [1]. Therefore, understanding the influence of histone marks on gene regulation can assist in developing drugs for genetic diseases.<br />
<br />
= Introduction = <br />
<br />
Advances in genomic technologies now enable us to profile genome-wide chromatin mark signals. As a result, biologists can measure gene expression and the chromatin signals of the ‘gene region’ for different cell types across the whole human genome. The Roadmap Epigenome Project (REMC, publicly available) [2] recently released 2,804 genome-wide datasets of 100 separate “normal” (not diseased) human cells/tissues, of which 166 datasets are gene expression reads and the rest are signal reads of various histone marks. The goal is to understand which histone marks are the most important and how they interact in gene regulation for each cell type.<br />
<br />
Signal reads for histone marks are high-dimensional and spatially structured. A histone modification mark can exert its influence anywhere in the gene region (10k base pairs centered around the Transcription Start Site of each gene), so it is important to understand how the impact of the mark on gene expression varies over the gene region, that is, how histone signals across the gene region affect gene expression. Many types of histone marks in human chromatin can influence gene regulation. Researchers have found five standard histone proteins, which can be altered in different combinations with different chemical modifications, resulting in a large number of distinct histone modification marks. Different histone modification marks can act as a module, interacting with each other to influence gene expression.<br />
<br />
<br />
This paper proposes an attention-based deep learning model to determine how these chromatin factors (histone modification marks) contribute to the gene expression of a particular cell. AttentiveChrome [3] uses a hierarchy of multiple LSTMs to discover interactions between the signals of each histone mark and to learn dependencies among the marks in determining a gene's expression. The authors include two levels of soft attention: (1) to attend to the most relevant signals of a histone mark, and (2) to attend to the important marks and their interactions. In this context, ''attention'' refers to weighting the importance of different items differently.<br />
<br />
== Main Contributions ==<br />
The contributions of this work can be summarized as follows:<br />
<br />
* More accurate predictions than state-of-the-art baselines, measured on REMC datasets for 56 different cell types.<br />
* Better interpretation than state-of-the-art methods for visualizing deep learning models, measured by correlating the model's attention scores with an additional mark signal from REMC.<br />
* Like earlier attention models, which indirectly indicate the parts of the input the model deemed important, AttentiveChrome can explain its decisions by indicating “what” and “where” it has focused on.<br />
* According to the authors, this is the first application of attention-based deep learning to a problem in molecular biology.<br />
* The ability to handle highly modular inputs.<br />
<br />
= Previous Works = <br />
<br />
Machine learning algorithms that classify gene expression from histone modification signals have been surveyed in [15]. These algorithms range from linear regression, support vector machines, and random forests to rule-based learning and CNNs. To accommodate the spatially structured, high-dimensional input data (histone modification signals), these studies applied different feature selection strategies. The authors' preceding study, DeepChrome [4], incorporated the best-position selection strategy, where the best positions are those most highly correlated with gene expression. The CNN-based DeepChrome model can learn the relationships between the histone marks and outperforms all previous works. However, these approaches either (1) failed to model the spatial dependencies among the marks, or (2) required additional feature analysis. Only AttentiveChrome is reported to satisfy all eight desirable model properties listed by the authors.<br />
<br />
= AttentiveChrome: Model Formulation =<br />
<br />
The authors propose an end-to-end architecture that can simultaneously attend and predict. The method uses recurrent neural networks (RNNs) composed of LSTM units to model the sequential spatial dependencies of the gene regions and to predict the gene expression level from the histone modification signals. The embedding vector <math> h_t </math> output by an LSTM module encodes the learned representation of the feature dependencies from the first time step up to time step <math> t </math>. For this task, each bin position of the gene region is treated as a time step.<br />
<br />
The proposed AttentiveChrome framework contains the following five modules:<br />
<br />
* Bin-level LSTM encoder encoding the bin positions of the gene region (one for each HM mark)<br />
* Bin-level <math> \alpha </math>-Attention across all bin positions (one for each HM mark)<br />
* HM-level LSTM encoder (one encoder encoding all HM marks)<br />
* HM-level <math> \beta </math>-Attention among all HM marks (one)<br />
* The final classification module<br />
<br />
Figure 1 (Supplementary Figure 2) presents an overview of the proposed AttentiveChrome framework.<br />
<br />
<br />
[[File:supplemntary_figure_2.png|thumb|center| 800px |Figure 1: Overview of all five modules of the proposed AttentiveChrome framework]]<br />
<br />
<br />
<br />
== Input and Output ==<br />
<br />
Each dataset contains the gene expression labels and the histone signal reads for one specific cell type. The authors evaluated AttentiveChrome on 56 different cell types. For each mark, there is a feature (input) vector containing the signal reads surrounding the gene’s TSS position (the gene region) for that histone mark. The label denotes the expression of the gene. This study uses binary labels, where <math> +1 </math> denotes that the gene is expressed (ON) and <math> -1 </math> denotes that it is not expressed (OFF). Each histone mark has one feature vector per gene. The authors adopt the input features and output labels of their previous work, DeepChrome [4]. The input features of a gene are represented by a matrix <math> \textbf{X} </math> of size <math> M \times T </math>, where <math> M </math> is the number of HM marks considered in the input and <math> T </math> is the number of bin positions used to represent the gene region. The <math> j^{th} </math> row of the matrix <math> \textbf{X} </math>, <math> x_j</math>, represents the sequentially structured signals from the <math> j^{th} </math> HM mark, where <math> j\in \{1, \cdots, M\} </math>. Therefore, <math> x_j^t</math> in the matrix <math> \textbf{X} </math> is the value of the <math> t^{th}</math> bin of the <math> j^{th} </math> HM mark, where <math> t\in \{1, \cdots, T\} </math>. If the training set contains <math>N_{tr} </math> labeled pairs, the <math> n^{th} </math> pair is written as <math>( X^n, y^n)</math>, where <math> X^n </math> is a matrix of size <math> M \times T </math>, <math> y^n \in \{ -1, +1 \} </math> is the binary label, and <math> n \in \{ 1, \cdots, N_{tr} \} </math>.<br />
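<br />
The data layout just described can be sketched with NumPy as follows. This is a minimal illustration, not the authors' code: the array names, the toy number of genes, and the synthetic Poisson read counts are assumptions; only the shapes (one <math>M \times T</math> matrix per gene, labels in <math>\{-1, +1\}</math>) follow the text. NumPy indices are 0-based, whereas the text counts marks and bins from 1.<br />
<br />
```python
import numpy as np

# Illustrative sizes: M = 5 marks and T = 100 bins match the experiments;
# N_tr = 4 genes is an arbitrary toy number (assumption).
M, T, N_tr = 5, 100, 4

rng = np.random.default_rng(0)
# X[n] plays the role of the M x T matrix X^n; filled with synthetic read counts.
X = rng.poisson(lam=2.0, size=(N_tr, M, T)).astype(float)
# y[n] in {-1, +1}: +1 = gene expressed (ON), -1 = not expressed (OFF).
y = np.array([+1, -1, -1, +1])

n, j, t = 0, 1, 9      # gene, mark and bin indices (0-based here)
x_j = X[n, j]          # row vector x_j: signal of mark j across all T bins
x_j_t = X[n, j, t]     # scalar x_j^t: value of mark j in bin t
print(X.shape, x_j.shape, x_j_t == x_j[t])
```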
<br />
Figure 2 (see also Figure 1(a) and 1(b)) shows the input features and the output of AttentiveChrome for a particular gene (one sample).<br />
<br />
[[File:input-output-attentivechrome.png|center|thumb| 700px | Figure 2: Input and Output of the AttentiveChrome model]]<br />
<br />
== Bin-Level Encoder (one LSTM for each HM) ==<br />
The sequentially ordered elements (each element is a bin position) of the gene region of the <math> n^{th} </math> gene are represented by the <math> j^{th} </math> row vector <math> x^j </math>. The authors treat each bin position as a time step of the LSTM. This study uses a bidirectional LSTM to model the dependencies among all <math> T </math> bin positions in the gene region. The bidirectional LSTM contains two LSTMs:<br />
* A forward LSTM, <math> \overrightarrow{LSTM_j} </math>, which models <math> x^j </math> from <math> x_1^j </math> to <math> x_T^j </math> and outputs an embedding vector <math> \overrightarrow{h^j_t} </math> of size <math> d </math> for each bin <math> t </math><br />
* A reverse LSTM, <math> \overleftarrow{LSTM_j} </math>, which models <math> x^j </math> from <math> x_T^j </math> to <math> x_1^j </math> and outputs an embedding vector <math> \overleftarrow{h^j_t} </math> of size <math> d </math> for each bin <math> t </math><br />
<br />
The final output of this layer, the embedding vector <math> h^j_t </math> at the <math> t^{th} </math> bin for the <math> j^{th} </math> HM, is obtained by concatenating the vectors from both directions, <math> h^j_t = [ \overrightarrow{h^j_t}, \overleftarrow{h^j_t}]</math>, so it has size <math> 2d </math>. Because these LSTM-based HM encoders are trained jointly with the final classifier, each encoder learns to embed its HM mark by capturing the dependencies among the bins. Figure 1(c) illustrates this module for <math> j=2 </math>.<br />
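<br />
A minimal NumPy sketch of the bin-level encoder for one HM row follows. It assumes a textbook LSTM cell with one fused gate matrix, random untrained weights, and hypothetical helper names (init_lstm, run_lstm, bin_level_encoder); it is not the authors' implementation. Each bin feeds one scalar per time step, the reverse LSTM's outputs are flipped back into bin order, and the two directions are concatenated, giving a <math>T \times 2d</math> output.<br />
<br />
```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def init_lstm(input_size, hidden_size, rng):
    # One weight matrix for the four gates (input, forget, output, candidate)
    # applied to the concatenation [x_t, h_{t-1}]; random, untrained (assumption).
    W = rng.normal(0.0, 0.1, size=(4 * hidden_size, input_size + hidden_size))
    b = np.zeros(4 * hidden_size)
    return W, b

def run_lstm(xs, params):
    """Run a standard LSTM over xs (T x input_size); return T x hidden_size outputs."""
    W, b = params
    d = W.shape[0] // 4
    h, c = np.zeros(d), np.zeros(d)
    outputs = []
    for x_t in xs:
        z = W @ np.concatenate([x_t, h]) + b
        i, f, o = sigmoid(z[:d]), sigmoid(z[d:2 * d]), sigmoid(z[2 * d:3 * d])
        g = np.tanh(z[3 * d:])
        c = f * c + i * g          # update the cell state
        h = o * np.tanh(c)         # new hidden state
        outputs.append(h)
    return np.stack(outputs)

def bin_level_encoder(x_j, fwd, bwd):
    """Bidirectional encoding of one HM row x_j (length T) into T x 2d embeddings."""
    xs = x_j.reshape(-1, 1)                  # one scalar input per bin / time step
    h_fwd = run_lstm(xs, fwd)                # reads bins 1 .. T
    h_bwd = run_lstm(xs[::-1], bwd)[::-1]    # reads bins T .. 1, re-aligned to bin order
    return np.concatenate([h_fwd, h_bwd], axis=1)   # h_t^j = [forward_t, backward_t]

rng = np.random.default_rng(0)
d, T = 32, 100
fwd, bwd = init_lstm(1, d, rng), init_lstm(1, d, rng)
x_j = rng.poisson(2.0, size=T).astype(float)   # synthetic signal for one mark
H_j = bin_level_encoder(x_j, fwd, bwd)
print(H_j.shape)
```

In the model there is one such encoder (its own pair of LSTMs) per HM mark.<br />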
<br />
== Bin-Level <math> \alpha</math>-attention ==<br />
<br />
Each bin contributes differently to the encoding of the entire <math> j^{th} </math> mark. To automatically and adaptively highlight the most important bins for prediction, a soft attention weight vector <math> \alpha^j </math> of size <math> T </math> is learned for each <math> j </math>. To calculate the soft weight <math> \alpha^j_t </math> for each <math> t </math>, the embedding vectors <math> \{h^j_1, \cdots, h^j_T \} </math> of all the bins are used:<br />
<br />
<center><math> \alpha^j_t = \frac{exp(\textbf{W}_b h^j_t)}{\sum_{i=1}^T{exp(\textbf{W}_b h^j_i)}} </math></center><br />
<br />
<br />
<math> \alpha^j_t</math> is a scalar computed from the embedding vectors of all the bins of the <math> j^{th} </math> HM. The context parameter <math> \textbf{W}_b </math> is randomly initialized and learned jointly with the other model parameters during training. Once the importance weight of each bin position is known, the <math> j^{th} </math> HM mark is represented by <math> m^j = \sum_{t=1}^T{\alpha^j_t \times h^j_t}</math>, where <math> h^j_t</math> is the embedding vector and <math> \alpha^j_t </math> is the importance weight of the <math> t^{th} </math> bin in the representation of the <math> j^{th} </math> HM mark. Intuitively, <math> \textbf{W}_b </math> learns the context of the task, such as the cell type. Figure 1(d) shows this module for <math> HM_2 </math>.<br />
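<br />
The <math>\alpha</math>-attention is a softmax over per-bin scores followed by a weighted sum. A minimal NumPy sketch, assuming <math>\textbf{W}_b</math> is a vector so that each score <math>\textbf{W}_b h^j_t</math> is a scalar; the bin embeddings are random stand-ins for the bin-level encoder output, and the function name alpha_attention is hypothetical.<br />
<br />
```python
import numpy as np

def alpha_attention(H_j, W_b):
    """Bin-level soft attention over the T embeddings of one HM.

    H_j: T x D matrix whose rows are h_t^j; W_b: context vector of length D.
    Returns (alpha_j, m_j).
    """
    scores = H_j @ W_b                 # W_b h_t^j for every bin t
    scores = scores - scores.max()     # subtract the max for numerical stability
    alpha = np.exp(scores) / np.exp(scores).sum()   # softmax over the T bins
    m_j = alpha @ H_j                  # m^j = sum_t alpha_t^j h_t^j
    return alpha, m_j

rng = np.random.default_rng(1)
T, D = 100, 64
H_j = rng.normal(size=(T, D))   # stand-in for the bidirectional bin-level outputs
W_b = rng.normal(size=D)        # random here; learned jointly during training
alpha, m_j = alpha_attention(H_j, W_b)
print(alpha.shape, m_j.shape)
```

Subtracting the maximum score does not change the softmax result; it only avoids overflow in the exponentials.<br />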
<br />
== HM-level Encoder (one LSTM) ==<br />
<br />
Studies have observed that HMs work cooperatively to promote or repress gene expression [5]. The HM-level encoder (not shown in Figure 1) uses one bidirectional LSTM to capture these relationships between the HMs. To model the dependencies sequentially, the HMs are arranged in an arbitrary order, because the authors did not find that any specific ordering of the HMs had an influence. The input of this step is the representation <math> m^j </math> of the <math> j^{th} </math> HM, <math> HM_j </math>, computed by the bin-level attention layer. This set-based encoder outputs an embedding vector <math> s^j </math> of size <math> 2d' </math> (the concatenation of the two directions, each of size <math> d' </math>), which is the encoding of the <math> j^{th} </math> HM.<br />
<br />
<math> s^j = [ \overrightarrow{LSTM_s}(m^j), \overleftarrow{LSTM_s}(m^j) ] </math><br />
<br />
The dependencies between the <math> j^{th} </math> HM and the other HM marks are encoded in <math> s^j </math>, whereas <math> m^j </math> from the previous step encodes the dependencies among the bins of the <math> j^{th} </math> HM.<br />
<br />
<br />
== HM-Level <math> \beta</math>-attention ==<br />
This second soft attention level (Figure 1(e)) finds the HM marks that are important for classifying a gene’s expression by learning an importance weight <math> \beta^j </math> for each <math> HM_j </math>, where <math> j \in \{ 1, \cdots, M \} </math>:<br />
<br />
<math> \beta^j = \frac{exp(\textbf{W}_s s^j)}{\sum_{i=1}^M{exp(\textbf{W}_s s^i)}} </math><br />
<br />
The HM-level context parameter <math> \textbf{W}_s </math> is trained jointly with the rest of the model. Intuitively, <math> \textbf{W}_s </math> learns how important each HM is for a cell type. Finally, the entire gene region is encoded in a hidden representation <math> \textbf{v} </math>, the weighted sum of the embeddings of all HM marks:<br />
<br />
<br />
<math> \textbf{v} = \sum_{j=1}^M{\beta^j \times s^j}</math><br />
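<br />
The HM-level attention mirrors the bin-level one, only over <math>M</math> marks instead of <math>T</math> bins. A minimal NumPy sketch, assuming <math>\textbf{W}_s</math> is a vector so that each score is a scalar; the HM embeddings <math>s^j</math> are random stand-ins rather than outputs of a trained HM-level LSTM, and the function names are hypothetical.<br />
<br />
```python
import numpy as np

def softmax(z):
    z = z - z.max()          # numerical stability
    e = np.exp(z)
    return e / e.sum()

def beta_attention(S, W_s):
    """HM-level soft attention.

    S: M x D' matrix whose rows are the HM embeddings s^j; W_s: context vector.
    Returns (beta, v) with v = sum_j beta_j s^j.
    """
    beta = softmax(S @ W_s)  # one importance weight per HM mark
    v = beta @ S             # weighted sum over the M marks
    return beta, v

rng = np.random.default_rng(2)
M, D_s = 5, 32                 # 5 marks; 2 * d' = 32 for a bidirectional HM-level LSTM
S = rng.normal(size=(M, D_s))  # stand-in for the HM-level encoder output
W_s = rng.normal(size=D_s)     # random here; learned jointly during training
beta, v = beta_attention(S, W_s)
print(beta.round(3), v.shape)
```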
<br />
== End-to-end training ==<br />
<br />
The embedding vector <math> \textbf{v} </math> is fed to a simple classification module, <math> f(\textbf{v}) = \mathrm{softmax}(\textbf{W}_c\textbf{v}+b_c) </math>, where <math> \textbf{W}_c </math> and <math> b_c </math> are learnable parameters. The output is the probability of the gene expression being high (expressed) or low (suppressed).<br />
The whole model, including the attention modules, is differentiable, so it can be trained end-to-end with backpropagation by minimizing the negative log-likelihood loss.<br />
<br />
= Experimental Settings =<br />
<br />
This work uses the REMC dataset and evaluates AttentiveChrome on 56 different cell types. As in DeepChrome, this study considers the five core HM marks listed in Table 1 (<math> M=5 </math>), because these marks are uniformly profiled across all 56 cell types in the REMC study.<br />
<br />
[[File:HM.png|center|thumb| 700px | Table 1: Five core HM marks and their attributes considered in this paper]]<br />
<br />
<br />
<br />
For each gene, a region of 10k base pairs centred at the TSS (5k bp in each direction) is considered. This region is divided into <math> T=100 </math> bins of 100 contiguous bp each. Therefore, for each gene in a particular cell type, the input matrix has size <math> 5 \times 100 </math>. The gene expression values are normalized and discretized into binary labels. The samples are divided into three equal-sized folds for training, validation, and testing.<br />
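<br />
The binning and the three-fold split can be sketched as follows. The per-base-pair input, summing read counts within each bin, and a uniformly random split are assumptions made for illustration; the text only fixes the 10k bp region, the 100 bins of 100 bp, and the three equal folds. The function name bin_signal is hypothetical.<br />
<br />
```python
import numpy as np

def bin_signal(per_bp, n_bins=100):
    """Collapse a per-base-pair signal around the TSS into n_bins equal-width bins.

    Read counts are summed within each bin (the aggregation rule is an assumption).
    """
    per_bp = np.asarray(per_bp, dtype=float)
    assert per_bp.size % n_bins == 0, "region length must be divisible by n_bins"
    return per_bp.reshape(n_bins, -1).sum(axis=1)

rng = np.random.default_rng(4)
region = rng.poisson(0.05, size=(5, 10_000))            # 5 marks x 10,000 bp (TSS +/- 5 kbp), synthetic
X_gene = np.stack([bin_signal(row) for row in region])  # 5 x 100 input matrix for one gene
print(X_gene.shape)

# Three equal-sized folds for training, validation and testing (random split assumed).
n_genes = 300
perm = rng.permutation(n_genes)
train_idx, val_idx, test_idx = np.array_split(perm, 3)
print(len(train_idx), len(val_idx), len(test_idx))
```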
<br />
== Model Variations and Two Baselines ==<br />
To evaluate the proposed model, the authors use two baselines: an RNN (a plain LSTM without any attention) and their prior work, DeepChrome. Several variants of AttentiveChrome are compared with these baselines. The authors considered five variants during the performance evaluation:<br />
<br />
* LSTM-Attn: one LSTM with attention on the input matrix (does not consider the modular nature of HM marks)<br />
* CNN-Attn: DeepChrome [4] with one attention mechanism incorporated. <br />
* LSTM-<math>\alpha , \beta</math>: the proposed architecture.<br />
* CNN-<math>\alpha , \beta</math>: the proposed architecture with the LSTM modules replaced by CNNs. This variant includes two attention mechanisms: an <math>\alpha</math>-attention on top of a CNN module for each HM mark, and a <math>\beta</math>-attention that combines the HMs.<br />
* LSTM-<math>\alpha</math>: one LSTM and <math>\alpha</math>-attention per HM mark.<br />
<br />
== Hyperparameters ==<br />
<br />
For all variants of AttentiveChrome, the bin-level LSTM embedding size <math> d</math> is set to 32 and the HM-level LSTM embedding size <math>d'</math> is set to 16. Because the LSTMs are bidirectional, the embedding vectors <math> h_t</math> and <math>m_j</math> have size 64, and the HM embedding <math>s_j</math> has size 32. The sizes of the context parameters <math>\textbf{W}_b</math> and <math>\textbf{W}_s</math> are set accordingly.<br />
<br />
= Performance Evaluation =<br />
<br />
== AUC Scores ==<br />
<br />
This study summarizes AUC scores across all 56 cell types on the test set to compare the methods.<br />
<br />
[[File:AUC.JPG|center|thumb| 700px | Table 2: AUC score performances for different variations of AttentiveChrome and baselines]]<br />
<br />
Overall, the LSTM-attention models perform better than the DeepChrome (CNN-based) and LSTM baselines. The authors argue that AttentiveChrome is a good choice because of its interpretability, even though its performance improvement over DeepChrome is insignificant.<br />
<br />
== Evaluation of Attention Scores for Interpretation ==<br />
<br />
To check whether the model focuses on the right regions, the authors use additional data from the REMC database. To validate the bin-level attention, they use the signal of an additional histone mark, H3K27ac, referred to as <math>H_{active}</math> in this article. This histone mark is known to mark active regions when a gene is expressed (ON). Genome-wide reads of this mark are available for three important cell types: a stem cell (H1-hESC), a blood cell (GM12878), and a leukemia cell (K562). This mark is used only to analyze the visualization results, not during training. This section discusses the performance of both attention mechanisms.<br />
<br />
=== Correlation of Importance Weight of <math>H_{prom}</math> with <math>H_{active}</math> ===<br />
<br />
The average read count of <math>H_{active}</math> in each of the 100 bins, over all active genes (ON, labeled <math>+1</math>), is calculated for the three selected cell types. AttentiveChrome and LSTM-<math>\alpha</math> are compared with two widely used visualization techniques, (1) class-based optimization and (2) saliency maps, applied to the baseline DeepChrome model (the CNN-based prior work). Using each of these methods, the authors compute importance weights for <math>H_{prom}</math> (a promoter HM mark used in training) across the 100 bins. They then compute the Pearson correlation between these importance weights and the read counts of <math>H_{active}</math> (the HM mark held out for validation) across the same 100 bins. The <math>H_{active}</math> read counts indicate the actual active regions in those cells.<br />
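<br />
The validation step amounts to a Pearson correlation between two 100-bin profiles. In the sketch below both profiles are synthetic stand-ins (a bump near the TSS for the averaged <math>H_{active}</math> read counts and a similar bump for the attention weights); they are not data from the paper. The hand-written formula is cross-checked against NumPy's np.corrcoef.<br />
<br />
```python
import numpy as np

def pearson(a, b):
    """Pearson correlation between two equal-length profiles."""
    a = np.asarray(a, dtype=float) - np.mean(a)
    b = np.asarray(b, dtype=float) - np.mean(b)
    return float(a @ b / np.sqrt((a @ a) * (b @ b)))

rng = np.random.default_rng(5)
bins = np.arange(100)
# Synthetic stand-ins: averaged H_active read counts peaking near the TSS (bin ~50),
# and an attention profile over the same bins that roughly tracks it.
h_active = 10 * np.exp(-((bins - 50) / 8.0) ** 2) + rng.normal(0, 0.5, 100)
attention = np.exp(-((bins - 52) / 10.0) ** 2)
attention /= attention.sum()   # attention weights sum to 1, which does not affect Pearson r

r = pearson(attention, h_active)
print(round(r, 3), np.isclose(r, np.corrcoef(attention, h_active)[0, 1]))
```

Because Pearson correlation is invariant to scaling, attention weights and raw read counts can be compared directly.<br />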
<br />
[[File: pc.JPG|center|thumb| 700px | Figure 4: Pearson correlation between the importance weights of <math>H_{prom}</math> and the read counts of a known active HM mark, <math>H_{active}</math>]]<br />
<br />
<br />
The results indicate that the proposed models consistently achieve the highest correlation with <math>H_{active}</math> for all three cell types, suggesting that the proposed method captures the important signals.<br />
<br />
=== Visualization of Bin-Level Attention Weights for Each HM in Cell Type GM12878 ===<br />
<br />
To visualize the bin-level attention weights, the authors plot the average bin-level attention weights of each HM for the blood cell type GM12878, separately for expressed (ON) and suppressed (OFF) genes.<br />
<br />
[[File: figure2.png|center|thumb| 700px |]]<br />
<br />
For the “ON” genes, the attention profiles are well defined for the marks <math>H_{prom}</math>, <math>H_{enhc}</math>, and <math>H_{struct}</math>, while the weights are low for <math>H_{reprA}</math> and <math>H_{reprB}</math>. The average trend reverses for the “OFF” genes, where the repressor marks have more influence than <math>H_{prom}</math>, <math>H_{enhc}</math>, and <math>H_{struct}</math>. This agrees with the biological finding that the <math>H_{prom}</math>, <math>H_{enhc}</math>, and <math>H_{struct}</math> marks stimulate gene activation, while <math>H_{reprA}</math> and <math>H_{reprB}</math> repress genes.<br />
<br />
=== Bin-Level Attention Weights and <math>H_{active}</math> ===<br />
<br />
The average read counts of <math>H_{active}</math> over the same 100 bins across all active (ON) genes in GM12878 are plotted (Figure 2(b) in the paper). Alongside, the bin-level attention weights of AttentiveChrome, averaged over all genes predicted ON for GM12878, are plotted. The plots show that the attention profile of <math>H_{prom}</math> is similar to the <math>H_{active}</math> profile.<br />
<br />
=== Visualization of HM-level Attention Weight for Gene PAX5 ===<br />
<br />
To visualize the HM-level attention weights, the authors produce a heatmap for a differentially regulated gene, PAX5, in the three cell types mentioned above (Figure 2(c) in the paper). PAX5 plays a significant role in gene regulation when stem cells differentiate into blood cells. The gene is OFF in stem cells (H1-hESC) but becomes activated when a stem cell is transformed into a blood cell (GM12878). The <math>\beta^j</math> weight for <math>H_{repr}</math> is high when the gene is OFF in H1-hESC and decreases when the gene is ON in GM12878. Conversely, the <math>\beta^j</math> weight for the <math>H_{prom}</math> mark increases from H1-hESC to GM12878 as the gene becomes activated. This pattern extracted by the model is consistent with the biological literature [16].<br />
<br />
= Related Works/Studies =<br />
<br />
In the last few years, deep learning models have achieved unprecedented success in diverse research fields. Although adoption has been slower than in other fields, deep learning algorithms are gaining popularity among bioinformaticians.<br />
<br />
== Attention-based Deep Models ==<br />
<br />
The attention technique in deep learning is inspired by the human visual perception system: when perceiving a scene, humans focus on some parts more than others. Deep neural networks augmented with this mechanism have achieved excellent results in several research areas, such as machine translation. Various types of attention models, e.g., soft [6], location-aware [7], and hard [8, 9] attention, have been proposed in the literature. In a soft attention model, a soft weight is computed for each feature vector, and the magnitude of the weight reflects the importance of that feature for the prediction. In practice, RNNs are often used to implement such models.<br />
<br />
== Visualization and Apprehension of Deep Models ==<br />
<br />
Prior studies mostly focused on interpreting convolutional neural networks (CNNs) for image classification. Deconvolution approaches [10] map hidden-layer representations back to the input space. Saliency maps [11, 12] use a Taylor expansion to approximate the network and identify the most relevant input features. Class optimization techniques [12] search for the input that best represents each class. Some recent works [13, 14] try to understand recurrent neural networks (RNNs) for text-based problems. By examining the features a model attends to, we can interpret the output of a deep model.<br />
<br />
== Deep Learning in Bioinformatics ==<br />
Deep learning is also becoming popular in bioinformatics because it can extract meaningful representations from data. Researchers use deep learning to model protein and DNA sequences and to predict gene expression.<br />
<br />
== Previous Models for Gene Expression Prediction ==<br />
Multiple machine learning models, such as linear regression and support vector machines, have been used to predict gene expression. The strategies included averaging signals across all relevant positions, selecting input signals at positions highly correlated with the target gene expression, and using a CNN to learn combinatorial interactions among histone modification marks.<br />
<br />
= Conclusion = <br />
<br />
The paper introduces an attention-based approach called "AttentiveChrome" that addresses both prediction and understanding, with several advantages over previous architectures: higher accuracy than state-of-the-art baselines, and clearer interpretation than saliency maps and class optimization, which lets one view what the model ‘sees’ during prediction. Another advantage of this approach is that it can model modular feature inputs that are sequentially structured. Finally, according to the authors, this is the first use of deep attention to understand gene regulation, and AttentiveChrome is the first attention-based model applied to a molecular biology dataset. The authors expect that this deep attention mechanism will help biologists better understand epigenomic data, since the model supports both prediction and interpretation of hard-to-interpret biological data.<br />
<br />
= Critiques =<br />
<br />
This paper does not make a substantial algorithmic contribution; it applies existing methods to a new application. The deep learning method is shown to perform better than simple machine learning models such as linear regression and SVMs, but it is considerably harder to implement and has many more hyperparameters to tune. The training time is also considerably higher, especially because all the parameters are learned jointly. The dataset also seems to contain only a limited number of samples for a model of this complexity. The model hyperparameters appear to have been chosen arbitrarily, without any explanation or intuition, and the authors do not cite any literature explaining where these values came from.<br />
<br />
The discussion of attention scores for interpretation neither gives a clear definition nor mentions previous literature that uses them. References to the literature on H3K27ac, and on how its read counts represent the active regions of a cell, should be included. No reason is given for using only one specific cell type to visualize the bin-level attention weights. Examples of other real-world problems where this model could be useful should also be provided.<br />
<br />
Moreover, this paper relies heavily on intuition. Because of the model's complicated structure, providing algorithmic or theoretical justification is likely challenging. As a result, there is no proper guidance on how the hyperparameters should be chosen, or on how the method should be applied to other datasets.<br />
<br />
= Additional Resources =<br />
<br />
# [https://qdata.github.io/deep4biomed-web/ Official DeepChrome Website]<br />
# [http://papers.nips.cc/paper/7255-attend-and-predict-understanding-gene-regulation-by-selective-attention-on-chromatin-supplemental.zip Supplemental Resources]<br />
# [https://github.com/QData/AttentiveChrome/blob/master/NIPS%20poster.pdf Poster]<br />
# [https://www.youtube.com/watch?v=tfgmXvSgsQE&feature=youtu.be Video Presentation]<br />
<br />
= References =<br />
<br />
[1] Andrew J Bannister and Tony Kouzarides. Regulation of chromatin by histone modifications. Cell Research, 21(3):381–395, 2011.<br />
<br />
[2] Anshul Kundaje, Wouter Meuleman, Jason Ernst, Misha Bilenky, Angela Yen, Alireza Heravi-Moussavi, Pouya Kheradpour, Zhizhuo Zhang, Jianrong Wang, Michael J Ziller, et al. Integrative analysis of 111 reference human epigenomes. Nature, 518(7539):317–330, 2015.<br />
<br />
[3] Ritambhara Singh, Jack Lanchantin, Arshdeep Sekhon, and Yanjun Qi. Attend and predict: Understanding gene regulation by selective attention on chromatin. In Advances in Neural Information Processing Systems, 2017.<br />
<br />
[4] Ritambhara Singh, Jack Lanchantin, Gabriel Robins, and Yanjun Qi. DeepChrome: deep-learning for predicting gene expression from histone modifications. Bioinformatics, 32(17):i639–i648, 2016.<br />
<br />
[5] Joanna Boros, Nausica Arnoult, Vincent Stroobant, Jean-François Collet, and Anabelle Decottignies. Polycomb repressive complex 2 and H3K27me3 cooperate with H3K9 methylation to maintain heterochromatin protein 1α at chromatin. Molecular and Cellular Biology, 34(19):3662–3674, 2014.<br />
<br />
[6] Dzmitry Bahdanau, Kyunghyun Cho, and Yoshua Bengio. Neural machine translation by jointly learning to align and translate. arXiv preprint arXiv:1409.0473, 2014.<br />
<br />
[7] Jan K Chorowski, Dzmitry Bahdanau, Dmitriy Serdyuk, Kyunghyun Cho, and Yoshua Bengio. Attention-based models for speech recognition. In C. Cortes, N. D. Lawrence, D. D. Lee, M. Sugiyama, and R. Garnett, editors, Advances in Neural Information Processing Systems 28, pages 577–585. Curran Associates, Inc., 2015.<br />
<br />
[8] Minh-Thang Luong, Hieu Pham, and Christopher D. Manning. Effective approaches to attention-based neural machine translation. In Empirical Methods in Natural Language Processing (EMNLP), pages 1412–1421, Lisbon, Portugal, September 2015. Association for Computational Linguistics.<br />
<br />
[9] Huijuan Xu and Kate Saenko. Ask, attend and answer: Exploring question-guided spatial attention for visual question answering. In ECCV, 2016.<br />
<br />
[10] Matthew D Zeiler and Rob Fergus. Visualizing and understanding convolutional networks. In Computer Vision–ECCV 2014, pages 818–833. Springer, 2014.<br />
<br />
[11] David Baehrens, Timon Schroeter, Stefan Harmeling, Motoaki Kawanabe, Katja Hansen, and Klaus-Robert Müller. How to explain individual classification decisions. Journal of Machine Learning Research, 11:1803–1831, 2010.<br />
<br />
[12] Karen Simonyan, Andrea Vedaldi, and Andrew Zisserman. Deep inside convolutional networks: Visualising image classification models and saliency maps. 2013.<br />
<br />
[13] Andrej Karpathy, Justin Johnson, and Fei-Fei Li. Visualizing and understanding recurrent networks. 2015.<br />
<br />
[14] Jiwei Li, Xinlei Chen, Eduard Hovy, and Dan Jurafsky. Visualizing and understanding neural models in nlp. 2015.<br />
<br />
[15] Xianjun Dong and Zhiping Weng. The correlation between histone modifications and gene expression. Epigenomics, 5(2):113–116, 2013.<br />
<br />
[16] Shane McManus, Anja Ebert, Giorgia Salvagiotto, Jasna Medvedovic, Qiong Sun, Ido Tamir, Markus Jaritz, Hiromi Tagoh, and Meinrad Busslinger. The transcription factor pax5 regulates its target genes by recruiting chromatin-modifying proteins in committed b cells. The EMBO journal, 30(12):2388–2404, 2011.</div>Gsahuhttp://wiki.math.uwaterloo.ca/statwiki/index.php?title=DON%27T_DECAY_THE_LEARNING_RATE_,_INCREASE_THE_BATCH_SIZE&diff=41969DON'T DECAY THE LEARNING RATE , INCREASE THE BATCH SIZE2018-11-30T00:19:02Z<p>Gsahu: /* THE EFFECTIVE LEARNING RATE AND THE ACCUMULATION VARIABLE */</p>
<hr />
<div>Summary of the ICLR 2018 paper: '''Don't Decay the Learning Rate, Increase the Batch Size'''<br />
<br />
Link: [https://arxiv.org/pdf/1711.00489.pdf]<br />
<br />
Summarized by: Afify, Ahmed [ID: 20700841]<br />
<br />
==INTUITION==<br />
Nowadays, it is common practice not to use a single fixed learning rate throughout the training of a neural network. Instead, the learning rate is reduced during training. The intuition is that when we are far from the minimum it is beneficial to take large steps towards it, since fewer steps are needed to get there, but as we approach it the step size should decrease; otherwise we may keep oscillating around the minimum. In practice, this is generally achieved with learning rate decay schedules, often combined with optimizers such as SGD with momentum, Nesterov momentum, or Adam. However, the core claim of this paper is that the same effect can be achieved by increasing the batch size during training while keeping the learning rate constant. The paper further argues that this approach reduces the number of parameter updates required to reach the minimum, leading to greater parallelism and shorter training times.<br />
<br />
== INTRODUCTION ==<br />
Although stochastic gradient descent (SGD) is widely used to train deep networks because it finds minima that generalize well (Zhang et al., 2016; Wilson et al., 2017), the optimization process is slow. According to Goyal et al. (2017), Hoffer et al. (2017) and You et al. (2017a), this has motivated researchers to speed up optimization by taking bigger steps, reducing the number of parameter updates through large batch training, which can be divided across many machines.<br />
<br />
However, increasing the batch size tends to reduce test set accuracy (Keskar et al., 2016; Goyal et al., 2017). Smith and Le (2017) argued that SGD has a scale of random fluctuations, or noise scale, <math> g = \epsilon (\frac{N}{B}-1) </math>, where <math> \epsilon </math> is the learning rate, <math> N </math> is the number of training samples, and <math> B </math> is the batch size. They concluded that, when <math> B \ll N </math>, there is an optimal noise scale <math> g </math> that maximizes test set accuracy, and hence an optimal batch size proportional to the learning rate.<br />
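<br />
The noise-scale formula makes the paper's central trade-off easy to check numerically. The sketch below uses illustrative values (a CIFAR-10-sized training set of N = 50,000, a learning rate of 0.1 and B = 128; these are assumptions, not settings from the text) and compares dividing the learning rate by 5 with multiplying the batch size by 5. It also checks that scaling <math>B \propto \epsilon</math> leaves <math>g</math> almost unchanged.<br />
<br />
```python
def noise_scale(lr, n_train, batch_size):
    """SGD noise scale g = lr * (N / B - 1), following Smith and Le (2017)."""
    return lr * (n_train / batch_size - 1)

N = 50_000         # illustrative: size of the CIFAR-10 training set
lr, B = 0.1, 128   # illustrative starting learning rate and batch size

g0 = noise_scale(lr, N, B)
g_lr_decay = noise_scale(lr / 5, N, B)     # option 1: divide the learning rate by 5
g_batch_up = noise_scale(lr, N, 5 * B)     # option 2: multiply the batch size by 5
g_both_up = noise_scale(2 * lr, N, 2 * B)  # B proportional to lr keeps g roughly fixed

for name, g in [("start", g0), ("lr / 5", g_lr_decay),
                ("B * 5", g_batch_up), ("lr * 2, B * 2", g_both_up)]:
    print(f"{name:>14}: g = {g:.2f}")
```

Both options shrink <math>g</math> by roughly the same factor of 5 because <math>N/B \gg 1</math>; the small difference comes from the <math>-1</math> term.<br />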
<br />
In this paper, the authors' main goal is to provide evidence that, for the same number of training epochs, increasing the batch size is quantitatively equivalent to decreasing the learning rate in reducing the scale of random fluctuations, but requires significantly fewer parameter updates. The number of parameter updates can be reduced further by increasing the learning rate and scaling <math> B \propto \epsilon </math>, or further still by increasing the momentum coefficient <math> m </math> and scaling <math> B \propto \frac{1}{1-m} </math>, although the latter reduces the test accuracy. These claims are supported by experiments on CIFAR-10 and on ImageNet, the latter using the ResNet-50 and Inception-ResNet-V2 architectures.<br />
<br />
== STOCHASTIC GRADIENT DESCENT AND CONVEX OPTIMIZATION ==<br />
As mentioned in the previous section, the drawback of SGD compared to full-batch training is the noise it introduces, which hinders optimization. According to Robbins & Monro (1951), two conditions on the learning rate guarantee convergence to the minimum of a convex function (<math> \epsilon_i </math> denotes the learning rate at the <math> i^{th} </math> gradient update):<br />
<br />
<math> \sum_{i=1}^{\infty} \epsilon_i = \infty </math>. This condition ensures that we can reach the minimum, no matter how far from it the parameters are initialized.<br />
<br />
<math> \sum_{i=1}^{\infty} \epsilon^2_i < \infty </math>. This condition, which assumes a fixed batch size, ensures that the learning rate decays fast enough for us to converge to the minimum rather than bouncing around it due to gradient noise.<br />
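<br />
As a standard illustration (not taken from the paper), the schedule <math>\epsilon_i = \epsilon_0 / i</math> with <math>\epsilon_0 > 0</math> satisfies both conditions, while a constant learning rate does not:<br />
<br />
<center><math> \sum_{i=1}^{\infty} \frac{\epsilon_0}{i} = \infty, \qquad \sum_{i=1}^{\infty} \frac{\epsilon_0^2}{i^2} = \frac{\pi^2 \epsilon_0^2}{6} < \infty, \qquad \sum_{i=1}^{\infty} \epsilon_0^2 = \infty. </math></center><br />
The first sum is the divergent harmonic series and the second converges (the Basel problem). The third shows that a fixed learning rate <math>\epsilon_i = \epsilon_0</math> violates the second condition, so with a fixed batch size SGD keeps fluctuating around the minimum.<br />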
<br />
These two conditions imply that the learning rate must decay during training, and the second holds only when the batch size is constant. To account for a changing batch size, Smith and Le (2017) proposed interpreting SGD as integrating the stochastic differential equation <math> \frac{dw}{dt} = -\frac{dC}{dw} + \eta(t) </math>, where C is the cost function, w denotes the parameters, and η is Gaussian random noise. They further showed that the noise scale g, which controls the magnitude of the random fluctuations in the training dynamics, is <math> g = \epsilon (\frac{N}{B}-1) </math>, where <math> \epsilon </math> is the learning rate, N is the training set size and B is the batch size. Since usually <math> B \ll N </math>, this can be approximated as <math> g \approx \epsilon \frac{N}{B} </math>. This explains why decaying the learning rate reduces g and lets SGD converge to the minimum of the cost function. Increasing the batch size has the same effect: it reduces g while the learning rate stays constant. In this work, the batch size is increased until <math> B \approx \frac{N}{10} </math>; after that, the learning rate is decayed in the conventional way.<br />
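<br />
The sketch below shows the resulting schedule, using a hypothetical helper <code>noise_decay_schedule</code>. At each step it divides the approximate noise scale <math> g \approx \epsilon N / B </math> by a fixed factor. It does this first by growing the batch size, and then, once the batch would exceed a cap (N/10 here, following the text above), by decaying the learning rate instead. The initial values and decay factor are illustrative assumptions.<br />
<br />
```python
def noise_decay_schedule(lr0, batch0, n_train, factor, n_steps, max_batch=None):
    """Return (lr, batch_size) pairs for n_steps + 1 phases.

    At every step the approximate noise scale lr * N / B is divided by `factor`:
    by multiplying the batch size while it stays within max_batch (default N // 10),
    and by dividing the learning rate once that cap would be exceeded.
    """
    if max_batch is None:
        max_batch = n_train // 10
    lr, batch = lr0, batch0
    phases = [(lr, batch)]
    for _ in range(n_steps):
        if batch * factor <= max_batch:
            batch *= factor
        else:
            lr /= factor
        phases.append((lr, batch))
    return phases


N = 50_000  # CIFAR-10 training set size, so the default cap is 5,000
for lr, batch in noise_decay_schedule(lr0=0.1, batch0=128, n_train=N, factor=5, n_steps=3):
    print(f"lr = {lr:.4f}, B = {batch:5d}, g ~ {lr * N / batch:.4f}")
```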
<br />
== SIMULATED ANNEALING AND THE GENERALIZATION GAP ==<br />
'''Simulated Annealing:''' Introducing random noise or fluctuations whose scale falls during training.<br />
<br />
'''Generalization Gap:''' Models trained with small batches generalize better to the test set than models trained with large batches.<br />
<br />
Smith and Le (2017) found that there is an optimal batch size, corresponding to an optimal noise scale g <math> (g \approx \epsilon \frac{N}{B}) </math>, that maximizes test set accuracy, and concluded that <math> B_{opt} \propto \epsilon N </math>. This suggests that gradient noise is helpful, because it helps SGD escape sharp minima, which do not generalize well.<br />
<br />
Simulated annealing is a well-known technique in non-convex optimization. Starting training with a large amount of noise lets us explore a wide range of parameters. Once we are near a good optimum, the noise is reduced to fine-tune the final parameters. However, researchers increasingly prefer sharper decay schedules, such as cosine decay or step-function drops. In the physical sciences, slowly annealing (decaying) the temperature helps the system converge to the global minimum, which may be sharp; here, the temperature corresponds to the noise scale. Decaying the temperature in discrete steps, by contrast, can trap the system in a local minimum with higher cost but lower curvature. The authors suspect that a similar intuition holds in deep learning.<br />
<br />
== THE EFFECTIVE LEARNING RATE AND THE ACCUMULATION VARIABLE ==<br />
'''The Effective Learning Rate:''' <math> \epsilon_{\mathrm{eff}} = \frac{\epsilon}{1-m} </math><br />
<br />
Smith and Le (2017) extended the vanilla SGD noise scale defined above to include momentum: <math> g = \frac{\epsilon}{1-m}(\frac{N}{B}-1)\approx \frac{\epsilon N}{B(1-m)} </math>. This reduces to the previous expression as m goes to 0. They found that increasing the learning rate and the momentum coefficient while scaling <math> B \propto \frac{\epsilon }{1-m} </math> reduces the number of parameter updates, but that test accuracy decreases as the momentum coefficient is increased.<br />
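<br />
The sketch below uses hypothetical helper names. It computes the effective learning rate and the momentum noise scale, checks that m = 0 recovers the vanilla expression, and rescales the batch size according to <math> B \propto \frac{\epsilon}{1-m} </math>. The specific learning rates and momentum values are illustrative assumptions, not settings from the paper.<br />
<br />
```python
import math


def effective_lr(lr, momentum):
    """Effective learning rate epsilon_eff = epsilon / (1 - m)."""
    return lr / (1.0 - momentum)


def noise_scale_momentum(lr, momentum, n_train, batch_size):
    """Noise scale with momentum: g = epsilon / (1 - m) * (N / B - 1)."""
    return effective_lr(lr, momentum) * (n_train / batch_size - 1)


def rescale_batch(batch0, lr0, m0, lr1, m1):
    """Batch size keeping B / epsilon_eff fixed, i.e. B proportional to epsilon / (1 - m)."""
    return batch0 * effective_lr(lr1, m1) / effective_lr(lr0, m0)


N = 50_000
# With m = 0 the expression reduces to the vanilla SGD noise scale epsilon * (N / B - 1).
assert math.isclose(noise_scale_momentum(0.1, 0.0, N, 128), 0.1 * (N / 128 - 1))

# Hypothetical change: double epsilon and raise m from 0.9 to 0.95 (epsilon_eff grows 4x).
b1 = rescale_batch(128, 0.1, 0.9, 0.2, 0.95)
g0 = effective_lr(0.1, 0.9) * N / 128   # approximate noise scale before the change
g1 = effective_lr(0.2, 0.95) * N / b1   # approximate noise scale after the change
print(f"new batch size ~ {b1:.1f}; approximate g before {g0:.3f}, after {g1:.3f}")
```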
<br />
To understand why increasing the momentum coefficient lowers test accuracy, we analyze the momentum update equations below:<br />
<br />
<center><math><br />
\Delta A = -(1-m)A + \frac{d\widehat{C}}{dw} <br />
</math><br />
<br />
<math><br />
\Delta w = -A\epsilon<br />
</math><br />
</center><br />
<br />
The accumulation variable A is initialized to 0 and then grows exponentially towards its steady-state value over roughly <math> \frac{B}{N(1-m)} </math> training epochs. During this period the magnitude of the parameter updates <math> \Delta w </math> is suppressed, which can slow convergence. Moreover, high momentum introduces four challenges:<br />
<br />
1- Additional epochs are needed to catch up with the accumulation.<br />
<br />
2- The accumulation needs roughly <math> \frac{B}{N(1-m)} </math> epochs to forget old gradients.<br />
<br />
3- Until this time has elapsed, the accumulation cannot adapt to changes in the loss landscape.<br />
<br />
4- In the early stages of training, large batch sizes can lead to instabilities.<br />
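<br />
The sketch below shows the warm-up of the accumulation variable. It iterates <math> \Delta A = -(1-m)A + \frac{d\widehat{C}}{dw} </math> from A = 0, assuming for simplicity a constant gradient of 1. A then approaches its steady state of <math> \frac{1}{1-m} </math> times the gradient, which is why <math> \frac{\epsilon}{1-m} </math> acts as the effective learning rate. Two choices are illustrative only: the (1 - 1/e) threshold used to define the warm-up time is a convention, and the (m, B) pairs are borrowed from the CIFAR-10 schedules in the experiments section.<br />
<br />
```python
import math


def accumulation_trajectory(momentum, grad=1.0, n_steps=500):
    """Iterate Delta A = -(1 - m) A + grad from A = 0 with a constant gradient."""
    a, trajectory = 0.0, []
    for _ in range(n_steps):
        a += -(1.0 - momentum) * a + grad  # equivalently a = m * a + grad
        trajectory.append(a)
    return trajectory


def warmup_updates(momentum, grad=1.0):
    """Updates until A reaches (1 - 1/e) of its steady state grad / (1 - m)."""
    steady_state = grad / (1.0 - momentum)
    for step, a in enumerate(accumulation_trajectory(momentum, grad), start=1):
        if a >= (1.0 - 1.0 / math.e) * steady_state:
            return step
    return None  # threshold not reached within the simulated horizon


N = 50_000  # CIFAR-10 training set size
for m, batch in [(0.9, 128), (0.98, 3200)]:
    k = warmup_updates(m)
    print(f"m = {m}: {k} updates = {k * batch / N:.4f} epochs; "
          f"B / (N (1 - m)) = {batch / (N * (1 - m)):.4f} epochs")
```
With m = 0.98 and B = 3200, A needs about 50 updates (roughly 3.2 epochs) to warm up. This illustrates the first challenge above.<br />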
<br />
== EXPERIMENTS ==<br />
=== SIMULATED ANNEALING IN A WIDE RESNET ===<br />
<br />
'''Dataset:''' CIFAR-10 (50,000 training images)<br />
<br />
'''Network Architecture:''' “16-4” wide ResNet<br />
<br />
'''Training schedules (shown in the figure below):'''<br />
<br />
- Decaying learning rate: the learning rate decays by a factor of 5 at each of a sequence of “steps”, and the batch size is constant.<br />
<br />
- Increasing batch size: learning rate is constant, and the batch size is increased by a factor of 5 at every step.<br />
<br />
- Hybrid: initially, the learning rate is held constant and the batch size is increased by a factor of 5; then the learning rate decays by a factor of 5 at each subsequent step while the batch size is held constant. This schedule is useful when hardware limits the maximum batch size.<br />
<br />
[[File:Paper_40_Fig_1.png | 800px|center]]<br />
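<br />
The sketch below compares the three schedules under several assumptions that this summary does not state:<br />
* an initial learning rate of 0.1 and an initial batch size of 128;<br />
* steps at epochs 60, 120 and 160 of 200 epochs;<br />
* a single batch-size increase in the hybrid schedule;<br />
* the last partial batch of each epoch is kept.<br />
It checks that the approximate noise scale <math> \epsilon N / B </math> is identical across schedules in every phase, while the number of parameter updates differs.<br />
<br />
```python
N = 50_000                       # CIFAR-10 training images
PHASE_EPOCHS = [60, 60, 40, 40]  # assumed phases: steps at epochs 60, 120 and 160 of 200
FACTOR = 5                       # noise scale divided by 5 at each step
HYBRID_BATCH_STEPS = 1           # assumed: the hybrid schedule grows the batch only at step 1


def schedule(kind, lr0=0.1, batch0=128):
    """Return a list of (lr, batch_size) pairs, one per phase, for one schedule."""
    lr, batch = lr0, batch0
    phases = [(lr, batch)]
    for step in range(1, len(PHASE_EPOCHS)):
        grow_batch = kind == "increasing" or (kind == "hybrid" and step <= HYBRID_BATCH_STEPS)
        if grow_batch:
            batch *= FACTOR
        else:
            lr /= FACTOR
        phases.append((lr, batch))
    return phases


def parameter_updates(phases):
    """Sum of epochs * ceil(N / B) over phases (the last partial batch is kept)."""
    return sum(epochs * -(-N // batch) for epochs, (_, batch) in zip(PHASE_EPOCHS, phases))


for kind in ("decaying", "increasing", "hybrid"):
    phases = schedule(kind)
    noise = [round(lr * N / batch, 4) for lr, batch in phases]
    print(f"{kind:>10}: g ~ {noise}, updates = {parameter_updates(phases)}")
```
Under these assumptions, the sketch gives 29,000 updates for the increasing-batch schedule and 78,200 for the decaying-learning-rate schedule. These counts come from the sketch, not from the paper's figures.<br />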
<br />
As shown in the figure below, the three training-set learning curves are nearly identical when plotted against the number of epochs (Figure 2a). Figure 2b shows that increasing the batch size substantially reduces the number of parameter updates.<br />
This suggests that it is the noise scale, not the learning rate itself, that needs to be decayed.<br />
[[File:Paper_40_Fig_2.png | 800px|center]] <br />
<br />
To check that these results also hold on the test set, Figure 3 shows that the three test-set learning curves are nearly identical for both SGD with momentum and Nesterov momentum.<br />
[[File:Paper_40_Fig_3.png | 800px|center]]<br />
<br />
To check whether other optimizers behave the same way, the figure below repeats the experiment of Figure 3 (test-set learning curves for the three schedules) with vanilla SGD and Adam.<br />
[[File:Paper_40_Fig_4.png | 800px|center]]<br />
<br />
'''Conclusion:''' Decreasing the learning rate and increasing the batch size during training are equivalent<br />
<br />
=== INCREASING THE EFFECTIVE LEARNING RATE===<br />
<br />
'''Dataset:''' CIFAR-10 (50,000 training images)<br />
<br />
'''Network Architecture:''' “16-4” wide ResNet<br />
<br />
'''Training Parameters:''' Optimization Algorithm: SGD with momentum / Maximum batch size = 5120<br />
<br />
'''Training Schedules:''' <br />
<br />
Four training schedules, all of which decay the noise scale by a factor of five in a series of three steps with the same number of epochs.<br />
<br />
Original training schedule: initial learning rate of 0.1 which decays by a factor of 5 at each step, a momentum coefficient of 0.9, and a batch size of 128. <br />
<br />
Increasing batch size: learning rate of 0.1, momentum coefficient of 0.9, initial batch size of 128 that increases by a factor of 5 at each step. <br />
<br />
Increased initial learning rate: initial learning rate of 0.5, initial batch size of 640 that increases during training.<br />
<br />
Increased momentum coefficient: increased initial learning rate of 0.5, initial batch size of 3200 that increases during training, and an increased momentum coefficient of 0.98.<br />
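<br />
The four schedules above can be checked against the momentum noise scale <math> g = \frac{\epsilon}{1-m}(\frac{N}{B}-1) </math>. The sketch below assumes that the "increased initial learning rate" schedule keeps a momentum coefficient of 0.9, which is not stated above. Under the <math> B \ll N </math> approximation all four schedules start with the same noise scale (about 390.6). The exact expression differs by up to roughly 6%, because B = 3200 is not negligible compared with N = 50,000.<br />
<br />
```python
N = 50_000  # CIFAR-10 training images

# (schedule, initial learning rate, momentum, initial batch size) from the list above.
# Assumption: the "increased initial learning rate" schedule keeps momentum 0.9.
CONFIGS = [
    ("original", 0.1, 0.9, 128),
    ("increasing batch size", 0.1, 0.9, 128),
    ("increased initial learning rate", 0.5, 0.9, 640),
    ("increased momentum coefficient", 0.5, 0.98, 3200),
]


def g_exact(lr, m, batch):
    """g = lr / (1 - m) * (N / B - 1)."""
    return lr / (1 - m) * (N / batch - 1)


def g_approx(lr, m, batch):
    """g ~ lr * N / (B * (1 - m)), the B << N approximation."""
    return lr * N / (batch * (1 - m))


for name, lr, m, batch in CONFIGS:
    print(f"{name:>32}: exact g = {g_exact(lr, m, batch):8.3f}, "
          f"approx g = {g_approx(lr, m, batch):8.3f}")
```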
<br />
The results of all four training schedules are summarized in the table below and plotted in the figure that follows:<br />
<br />
[[File:Paper_40_Table_1.png | 800px|center]]<br />
<br />
[[File:Paper_40_Fig_5.png | 800px|center]]<br />
<br />
'''Conclusion:''' Increasing the effective learning rate and scaling the batch size results in further reduction in the number of parameter updates<br />
<br />
=== TRAINING IMAGENET IN 2500 PARAMETER UPDATES===<br />
<br />
'''A) Experiment Goal:''' Control Batch Size<br />
<br />
'''Dataset:''' ImageNet (1.28 million training images)<br />
<br />
The paper modified the setup of Goyal et al. (2017) and used the following configuration:<br />
<br />
'''Network Architecture:''' Inception-ResNet-V2 <br />
<br />
'''Training Parameters:''' <br />
<br />
90 epochs / noise decayed at epochs 30, 60, and 80 by a factor of 10 / Initial ghost batch size = 32 / Learning rate = 3 / Momentum coefficient = 0.9 / Initial batch size = 8192<br />
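<br />
Ghost batch normalization, introduced by Hoffer et al. (2017), computes batch-normalization statistics over small sub-batches ("ghost batches") rather than over the full batch. This keeps the normalization statistics comparable to small-batch training as the overall batch size grows. The pure-Python sketch below normalizes a 1-D list of activations per ghost batch. It omits the learnable scale and shift and the running averages used at inference time. The function name and toy inputs are illustrative assumptions.<br />
<br />
```python
import math


def ghost_batch_norm(activations, ghost_size, eps=1e-5):
    """Normalize each consecutive chunk of `ghost_size` values with its own mean and variance."""
    if ghost_size <= 0:
        raise ValueError("ghost_size must be positive")
    out = []
    for start in range(0, len(activations), ghost_size):
        ghost = activations[start:start + ghost_size]
        mean = sum(ghost) / len(ghost)
        var = sum((x - mean) ** 2 for x in ghost) / len(ghost)  # biased variance, as in batch norm
        out.extend((x - mean) / math.sqrt(var + eps) for x in ghost)
    return out


# Toy batch of 8 activations split into two ghost batches of 4 (the setup above uses 32).
batch = [1.0, 2.0, 3.0, 4.0, 10.0, 20.0, 30.0, 40.0]
print([round(v, 3) for v in ghost_batch_norm(batch, ghost_size=4)])
```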
<br />
Two training schedules were used:<br />
<br />
“Decaying learning rate”, where the batch size is fixed and the learning rate is decayed.<br />
<br />
“Increasing batch size”, where the batch size is increased to 81920 and the learning rate is then decayed at the two remaining steps.<br />
<br />
[[File:Paper_40_Table_2.png | 800px|center]]<br />
<br />
[[File:Paper_40_Fig_6.png | 800px|center]]<br />
<br />
'''Conclusion:''' Increasing the batch size resulted in reducing the number of parameter updates from 14,000 to 6,000.<br />
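<br />
The reported reduction in parameter updates can be roughly reproduced by counting updates per phase. The sketch below assumes N = 1.28 million training images and ignores the rounding of partial batches. It uses the phase boundaries from the training parameters above (noise decays at epochs 30, 60 and 80 of 90). It gives about 14,060 updates for the decaying-learning-rate schedule and 5,625 for the increasing-batch-size schedule. This is consistent with the approximate figures of 14,000 and 6,000 quoted in the conclusion.<br />
<br />
```python
N = 1_280_000  # ImageNet training images (approximate, as quoted above)


def updates(phases):
    """phases: list of (epochs, batch_size); returns total parameter updates (fractional)."""
    return sum(epochs * N / batch for epochs, batch in phases)


# Noise decays at epochs 30, 60 and 80, giving phases of 30, 30, 20 and 10 epochs.
decaying_lr = [(30, 8192), (30, 8192), (20, 8192), (10, 8192)]
# Batch grows 8192 -> 81920 at epoch 30; the learning rate then decays at epochs 60 and 80.
increasing_batch = [(30, 8192), (30, 81920), (20, 81920), (10, 81920)]

print(f"decaying learning rate: {updates(decaying_lr):,.1f} updates")
print(f"increasing batch size:  {updates(increasing_batch):,.1f} updates")
```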
<br />
'''B) Experiment Goal:''' Control Batch Size and Momentum Coefficient<br />
<br />
'''Training Parameters:''' Ghost batch size = 64 / noise decayed at epochs 30, 60, and 80 by a factor of 10.<br />
<br />
The table below shows the number of parameter updates and the accuracy for different sets of training parameters:<br />
<br />
[[File:Paper_40_Table_3.png | 800px|center]]<br />
<br />
[[File:Paper_40_Fig_7.png | 800px|center]]<br />
<br />
'''Conclusion:''' Increasing the momentum reduces the number of parameter updates, but leads to a drop in the test accuracy.<br />
<br />
=== TRAINING IMAGENET IN 30 MINUTES===<br />
<br />
'''Dataset:''' ImageNet (Already introduced in the previous section)<br />
<br />
'''Network Architecture:''' ResNet-50<br />
<br />
The paper replicated the setup of Goyal et al. (2017) while varying the number of TPU devices, the batch size and the learning rate. In each of the following experiments it measured the time to complete 90 epochs and the resulting accuracy:<br />
<br />
[[File:Paper_40_Table_4.png | 800px|center]]<br />
<br />
'''Conclusion:''' Model training times can be reduced by increasing the batch size during training.<br />
<br />
== RELATED WORK ==<br />
Main related work mentioned in the paper is as follows:<br />
<br />
- Smith & Le (2017) interpreted stochastic gradient descent as integrating a stochastic differential equation; this paper builds on that idea to account for decaying learning rates.<br />
<br />
- Mandt et al. (2017) analyzed SGD as approximate Bayesian posterior sampling.<br />
<br />
- Keskar et al. (2016) focused on the analysis of noise once the training is started.<br />
<br />
- Moreover, Goyal et al. (2017) exploited the proportional relationship between batch size and learning rate to successfully train ResNet-50 on ImageNet in one hour.<br />
<br />
- Furthermore, You et al. (2017a) presented Layer-wise Adaptive Rate Scaling (LARS), which applies a different learning rate to each layer, to train ImageNet in 14 minutes with 74.9% accuracy.<br />
<br />
- Finally, asynchronous SGD (Recht et al., 2011; Dean et al., 2012) is another strategy, which allows the use of multiple GPUs even with small batch sizes.<br />
<br />
== CONCLUSIONS ==<br />
Increasing the batch size during training has the same benefits as decaying the learning rate, while also reducing the number of parameter updates, which can translate into faster training. Experiments on different image datasets, with various optimizers and training schedules, were performed to support this result. The paper also proposed increasing the learning rate and the momentum parameter m while scaling <math> B \propto \frac{\epsilon}{1-m} </math>. This achieves fewer parameter updates but slightly lower test set accuracy, as detailed in the experiments section. In summary, on the ImageNet dataset, Inception-ResNet-V2 achieved 77% validation accuracy in under 2500 parameter updates, and ResNet-50 achieved 76.1% validation set accuracy on TPU in less than 30 minutes. A notable aspect of the paper is that hyperparameters were taken from the literature, so no hyperparameter tuning was needed.<br />
<br />
== CRITIQUE ==<br />
'''Pros:'''<br />
<br />
- The paper showed empirically that increasing batch size and decaying learning rate are equivalent.<br />
<br />
- Several experiments were performed on different optimizers such as SGD and Adam.<br />
<br />
- Includes several comparisons with previous experimental setups.<br />
<br />
'''Cons:'''<br />
<br />
- All datasets used are image datasets. Other experiments should have been done on datasets from different domains to ensure generalization. <br />
<br />
- The number of parameter updates was used as the comparison criterion. Wall-clock times would have provided an additional practical measure, although they depend on the hardware used.<br />
<br />
- Special hardware is needed for large batch training, which is not always feasible.<br />
<br />
- In section 5.2 (Increasing the Effective Learning Rate), the authors did not test a range of learning rate values and used only 0.1 and 0.5. Additional results from varying the initial learning rate from 0.1 to 3.2 are provided in the appendix; they indicate that test accuracy begins to fall for initial learning rates greater than ~0.4. However, the appended results do not show validation set accuracy curves like those in Figure 6. It would be useful to see whether they resemble the original 0.1 and 0.5 initial learning rate baselines.<br />
<br />
- Although the main idea of the paper is interesting, its results do not seem very surprising compared with other recent papers on the subject.<br />
<br />
- The paper could benefit from using some other models to demonstrate its claim and generalize its idea by adding some comparisons with other models as well as other recent methods to increase batch size.<br />
<br />
- The paper presents interesting ideas, but it lacks mathematical and theoretical analysis beyond the main idea. Since the experiments are primarily on image datasets and little theory is provided, the paper's applicability to other data types is limited.<br />
<br />
== REFERENCES ==<br />
- Takuya Akiba, Shuji Suzuki, and Keisuke Fukuda. Extremely large minibatch sgd: Training resnet-50 on imagenet in 15 minutes. arXiv preprint arXiv:1711.04325, 2017.<br />
<br />
- Lukas Balles, Javier Romero, and Philipp Hennig. Coupling adaptive batch sizes with learning rates. arXiv preprint arXiv:1612.05086, 2016.<br />
<br />
- Léon Bottou, Frank E Curtis, and Jorge Nocedal. Optimization methods for large-scale machine learning. arXiv preprint arXiv:1606.04838, 2016.<br />
<br />
- Richard H Byrd, Gillian M Chin, Jorge Nocedal, and Yuchen Wu. Sample size selection in optimization methods for machine learning. Mathematical programming, 134(1):127–155, 2012.<br />
<br />
- Pratik Chaudhari, Anna Choromanska, Stefano Soatto, and Yann LeCun. Entropy-SGD: Biasing gradient descent into wide valleys. arXiv preprint arXiv:1611.01838, 2016.<br />
<br />
- Soham De, Abhay Yadav, David Jacobs, and Tom Goldstein. Automated inference with adaptive batches. In Artificial Intelligence and Statistics, pp. 1504–1513, 2017.<br />
<br />
- Jeffrey Dean, Greg Corrado, Rajat Monga, Kai Chen, Matthieu Devin, Mark Mao, Andrew Senior, Paul Tucker, Ke Yang, Quoc V Le, et al. Large scale distributed deep networks. In Advances in neural information processing systems, pp. 1223–1231, 2012.<br />
<br />
- Michael P Friedlander and Mark Schmidt. Hybrid deterministic-stochastic methods for data fitting. SIAM Journal on Scientific Computing, 34(3):A1380–A1405, 2012.<br />
<br />
- Priya Goyal, Piotr Dollár, Ross Girshick, Pieter Noordhuis, Lukasz Wesolowski, Aapo Kyrola, Andrew Tulloch, Yangqing Jia, and Kaiming He. Accurate, large minibatch SGD: Training imagenet in 1 hour. arXiv preprint arXiv:1706.02677, 2017.<br />
<br />
- Sepp Hochreiter and Jürgen Schmidhuber. Flat minima. Neural Computation, 9(1):1–42, 1997.<br />
<br />
- Elad Hoffer, Itay Hubara, and Daniel Soudry. Train longer, generalize better: closing the generalization gap in large batch training of neural networks. arXiv preprint arXiv:1705.08741, 2017.<br />
<br />
- Norman P Jouppi, Cliff Young, Nishant Patil, David Patterson, Gaurav Agrawal, Raminder Bajwa, Sarah Bates, Suresh Bhatia, Nan Boden, Al Borchers, et al. In-datacenter performance analysis of a tensor processing unit. In Proceedings of the 44th Annual International Symposium on Computer Architecture, pp. 1–12. ACM, 2017.<br />
<br />
- Nitish Shirish Keskar, Dheevatsa Mudigere, Jorge Nocedal, Mikhail Smelyanskiy, and Ping Tak Peter Tang. On large-batch training for deep learning: Generalization gap and sharp minima. arXiv preprint arXiv:1609.04836, 2016.<br />
<br />
- Diederik Kingma and Jimmy Ba. Adam: A method for stochastic optimization. arXiv preprint arXiv:1412.6980, 2014.<br />
<br />
- Alex Krizhevsky. One weird trick for parallelizing convolutional neural networks. arXiv preprint arXiv:1404.5997, 2014.<br />
<br />
- Qianxiao Li, Cheng Tai, and E Weinan. Stochastic modified equations and adaptive stochastic gradient algorithms. arXiv preprint arXiv:1511.06251, 2017.<br />
<br />
- Ilya Loshchilov and Frank Hutter. SGDR: stochastic gradient descent with restarts. arXiv preprint arXiv:1608.03983, 2016.<br />
<br />
- Stephan Mandt, Matthew D Hoffman, and David M Blei. Stochastic gradient descent as approximate bayesian inference. arXiv preprint arXiv:1704.04289, 2017.<br />
<br />
- James Martens and Roger Grosse. Optimizing neural networks with kronecker-factored approximate curvature. In International Conference on Machine Learning, pp. 2408–2417, 2015.<br />
<br />
- Yurii Nesterov. A method of solving a convex programming problem with convergence rate O(1/k^2). In Soviet Mathematics Doklady, volume 27, pp. 372–376, 1983.<br />
<br />
- Lutz Prechelt. Early stopping-but when? Neural Networks: Tricks of the trade, pp. 553–553, 1998.<br />
<br />
- Benjamin Recht, Christopher Re, Stephen Wright, and Feng Niu. Hogwild: A lock-free approach to parallelizing stochastic gradient descent. In Advances in neural information processing systems, pp. 693–701, 2011.<br />
<br />
- Herbert Robbins and Sutton Monro. A stochastic approximation method. The annals of mathematical statistics, pp. 400–407, 1951.<br />
<br />
- Samuel L. Smith and Quoc V. Le. A bayesian perspective on generalization and stochastic gradient descent. arXiv preprint arXiv:1710.06451, 2017.<br />
<br />
- Christian Szegedy, Sergey Ioffe, Vincent Vanhoucke, and Alexander A Alemi. Inception-v4, Inception-ResNet and the impact of residual connections on learning. In AAAI, pp. 4278–4284, 2017.<br />
<br />
- Max Welling and Yee W Teh. Bayesian learning via stochastic gradient langevin dynamics. In Proceedings of the 28th International Conference on Machine Learning (ICML-11), pp. 681–688, 2011.<br />
<br />
- Ashia C Wilson, Rebecca Roelofs, Mitchell Stern, Nathan Srebro, and Benjamin Recht. The marginal value of adaptive gradient methods in machine learning. arXiv preprint arXiv:1705.08292, 2017.<br />
<br />
- Yang You, Igor Gitman, and Boris Ginsburg. Scaling SGD batch size to 32k for imagenet training. arXiv preprint arXiv:1708.03888, 2017a.<br />
<br />
- Yang You, Zhao Zhang, C Hsieh, James Demmel, and Kurt Keutzer. Imagenet training in minutes. CoRR, abs/1709.05011, 2017b.<br />
<br />
- Sergey Zagoruyko and Nikos Komodakis. Wide residual networks. arXiv preprint arXiv:1605.07146, 2016.<br />
<br />
- Chiyuan Zhang, Samy Bengio, Moritz Hardt, Benjamin Recht, and Oriol Vinyals. Understanding deep learning requires rethinking generalization. arXiv preprint arXiv:1611.03530, 2016.</div>Gsahuhttp://wiki.math.uwaterloo.ca/statwiki/index.php?title=DON%27T_DECAY_THE_LEARNING_RATE_,_INCREASE_THE_BATCH_SIZE&diff=41968DON'T DECAY THE LEARNING RATE , INCREASE THE BATCH SIZE2018-11-30T00:17:54Z<p>Gsahu: </p>
<hr />
<div>Summary of the ICLR 2018 paper: '''Don't Decay the Learning Rate, Increase the Batch Size'''<br />
<br />
Link: [https://arxiv.org/pdf/1711.00489.pdf]<br />
<br />
Summarized by: Afify, Ahmed [ID: 20700841]<br />
<br />
==INTUITION==<br />
Nowadays, it is common practice not to use a single fixed learning rate throughout the training of a neural network. Instead, the learning rate is decayed during training. The intuition is that when we are far from the minimum, large steps are beneficial because fewer of them are needed to approach it. As we get closer, the step size should decrease; otherwise we may keep oscillating around the minimum. In practice, this is generally achieved by decaying the learning rate of optimizers such as SGD with momentum, Nesterov momentum, and Adam. The core claim of this paper, however, is that the same effect can be achieved by increasing the batch size during training while keeping the learning rate constant. In addition, the paper argues that this approach reduces the number of parameter updates required to reach the minimum, enabling greater parallelism and shorter training times.<br />
<br />
== INTRODUCTION ==<br />
Although stochastic gradient descent (SGD) is widely used to train deep networks because it finds minima that generalize well (Zhang et al., 2016; Wilson et al., 2017), optimization with SGD is slow. This has motivated researchers to speed it up by taking bigger steps, reducing the number of parameter updates through large-batch training, which can be divided across many machines (Goyal et al., 2017; Hoffer et al., 2017; You et al., 2017a).<br />
<br />
However, increasing the batch size tends to decrease test set accuracy (Keskar et al., 2016; Goyal et al., 2017). Smith and Le (2017) argued that SGD has a characteristic scale of random fluctuations <math> g = \epsilon (\frac{N}{B}-1) </math>, where <math> \epsilon </math> is the learning rate, N is the number of training samples, and B is the batch size. They concluded that there is an optimal fluctuation scale g that maximizes test set accuracy, and hence, when <math> B \ll N </math>, an optimal batch size proportional to the learning rate.<br />
<br />
In this paper, the authors' main goal is to provide evidence that, for the same number of training epochs, increasing the batch size is quantitatively equivalent to decreasing the learning rate in reducing the scale of random fluctuations, while requiring significantly fewer parameter updates. Moreover, the number of parameter updates can be reduced further by increasing the learning rate and scaling <math> B \propto \epsilon </math>, or further still by increasing the momentum coefficient and scaling <math> B \propto \frac{1}{1-m} </math>, although the latter decreases the test accuracy. This is demonstrated by experiments on CIFAR-10 with a wide ResNet and on ImageNet with Inception-ResNet-V2 and ResNet-50.<br />
<br />
== STOCHASTIC GRADIENT DESCENT AND CONVEX OPTIMIZATION ==<br />
Compared with full-batch training, the main drawback of SGD is the noise it introduces, which hinders optimization. According to Robbins & Monro (1951), two conditions on the learning rate schedule guarantee convergence to the minimum of a convex function (<math> \epsilon_i </math> denotes the learning rate at the <math> i^{th} </math> gradient update):<br />
<br />
<math> \sum_{i=1}^{\infty} \epsilon_i = \infty </math>. This condition ensures that we can reach the minimum regardless of the starting point.<br />
<br />
<math> \sum_{i=1}^{\infty} \epsilon^2_i < \infty </math>. This condition, which assumes a fixed batch size, ensures that the learning rate decays quickly enough for SGD to settle at the minimum rather than bounce around it due to gradient noise.<br />
<br />
These two conditions imply that the learning rate must decay during training, and the second holds only when the batch size is constant. To account for a changing batch size, Smith and Le (2017) proposed interpreting SGD as integrating the stochastic differential equation <math> \frac{dw}{dt} = -\frac{dC}{dw} + \eta(t) </math>, where C is the cost function, w denotes the parameters, and η is Gaussian random noise. They further showed that the noise scale g, which controls the magnitude of the random fluctuations in the training dynamics, is <math> g = \epsilon (\frac{N}{B}-1) </math>, where <math> \epsilon </math> is the learning rate, N is the training set size and B is the batch size. Since usually <math> B \ll N </math>, this can be approximated as <math> g \approx \epsilon \frac{N}{B} </math>. This explains why decaying the learning rate reduces g and lets SGD converge to the minimum of the cost function. Increasing the batch size has the same effect: it reduces g while the learning rate stays constant. In this work, the batch size is increased until <math> B \approx \frac{N}{10} </math>; after that, the learning rate is decayed in the conventional way.<br />
<br />
== SIMULATED ANNEALING AND THE GENERALIZATION GAP ==<br />
'''Simulated Annealing:''' Introducing random noise or fluctuations whose scale falls during training.<br />
<br />
'''Generalization Gap:''' Models trained with small batches generalize better to the test set than models trained with large batches.<br />
<br />
Smith and Le (2017) found that there is an optimal batch size, corresponding to an optimal noise scale g <math> (g \approx \epsilon \frac{N}{B}) </math>, that maximizes test set accuracy, and concluded that <math> B_{opt} \propto \epsilon N </math>. This suggests that gradient noise is helpful, because it helps SGD escape sharp minima, which do not generalize well.<br />
<br />
Simulated annealing is a well-known technique in non-convex optimization. Starting training with a large amount of noise lets us explore a wide range of parameters. Once we are near a good optimum, the noise is reduced to fine-tune the final parameters. However, researchers increasingly prefer sharper decay schedules, such as cosine decay or step-function drops. In the physical sciences, slowly annealing (decaying) the temperature helps the system converge to the global minimum, which may be sharp; here, the temperature corresponds to the noise scale. Decaying the temperature in discrete steps, by contrast, can trap the system in a local minimum with higher cost but lower curvature. The authors suspect that a similar intuition holds in deep learning.<br />
<br />
== THE EFFECTIVE LEARNING RATE AND THE ACCUMULATION VARIABLE ==<br />
'''The Effective Learning Rate:''' <math> \epsilon_{\mathrm{eff}} = \frac{\epsilon}{1-m} </math><br />
<br />
Smith and Le (2017) extended the vanilla SGD noise scale defined above to include momentum: <math> g = \frac{\epsilon}{1-m}(\frac{N}{B}-1)\approx \frac{\epsilon N}{B(1-m)} </math>. This reduces to the previous expression as m goes to 0. They found that increasing the learning rate and the momentum coefficient while scaling <math> B \propto \frac{\epsilon }{1-m} </math> reduces the number of parameter updates, but that test accuracy decreases as the momentum coefficient is increased.<br />
<br />
To understand why increasing the momentum coefficient lowers test accuracy, we analyze the momentum update equations below:<br />
<br />
<math><br />
\Delta A = -(1-m)A + \frac{d\widehat{C}}{dw} <br />
</math><br />
<br />
<math><br />
\Delta w = -A\epsilon<br />
</math><br />
<br />
The accumulation variable A is initialized to 0 and then grows exponentially towards its steady-state value over roughly <math> \frac{B}{N(1-m)} </math> training epochs. During this period the magnitude of the parameter updates <math> \Delta w </math> is suppressed, which can slow convergence. Moreover, high momentum introduces four challenges:<br />
<br />
1- Additional epochs are needed to catch up with the accumulation.<br />
<br />
2- The accumulation needs roughly <math> \frac{B}{N(1-m)} </math> epochs to forget old gradients.<br />
<br />
3- Until this time has elapsed, the accumulation cannot adapt to changes in the loss landscape.<br />
<br />
4- In the early stages of training, large batch sizes can lead to instabilities.<br />
<br />
== EXPERIMENTS ==<br />
=== SIMULATED ANNEALING IN A WIDE RESNET ===<br />
<br />
'''Dataset:''' CIFAR-10 (50,000 training images)<br />
<br />
'''Network Architecture:''' “16-4” wide ResNet<br />
<br />
'''Training schedules (shown in the figure below):'''<br />
<br />
- Decaying learning rate: the learning rate decays by a factor of 5 at each of a sequence of “steps”, and the batch size is constant.<br />
<br />
- Increasing batch size: learning rate is constant, and the batch size is increased by a factor of 5 at every step.<br />
<br />
- Hybrid: initially, the learning rate is held constant and the batch size is increased by a factor of 5; then the learning rate decays by a factor of 5 at each subsequent step while the batch size is held constant. This schedule is useful when hardware limits the maximum batch size.<br />
<br />
[[File:Paper_40_Fig_1.png | 800px|center]]<br />
<br />
As shown in the figure below, the three training-set learning curves are nearly identical when plotted against the number of epochs (Figure 2a). Figure 2b shows that increasing the batch size substantially reduces the number of parameter updates.<br />
This suggests that it is the noise scale, not the learning rate itself, that needs to be decayed.<br />
[[File:Paper_40_Fig_2.png | 800px|center]] <br />
<br />
To check that these results also hold on the test set, Figure 3 shows that the three test-set learning curves are nearly identical for both SGD with momentum and Nesterov momentum.<br />
[[File:Paper_40_Fig_3.png | 800px|center]]<br />
<br />
To check whether other optimizers behave the same way, the figure below repeats the experiment of Figure 3 (test-set learning curves for the three schedules) with vanilla SGD and Adam.<br />
[[File:Paper_40_Fig_4.png | 800px|center]]<br />
<br />
'''Conclusion:''' Decreasing the learning rate and increasing the batch size during training are equivalent<br />
<br />
=== INCREASING THE EFFECTIVE LEARNING RATE===<br />
<br />
'''Dataset:''' CIFAR-10 (50,000 training images)<br />
<br />
'''Network Architecture:''' “16-4” wide ResNet<br />
<br />
'''Training Parameters:''' Optimization Algorithm: SGD with momentum / Maximum batch size = 5120<br />
<br />
'''Training Schedules:''' <br />
<br />
Four training schedules, all of which decay the noise scale by a factor of five in a series of three steps with the same number of epochs.<br />
<br />
Original training schedule: initial learning rate of 0.1 which decays by a factor of 5 at each step, a momentum coefficient of 0.9, and a batch size of 128. <br />
<br />
Increasing batch size: learning rate of 0.1, momentum coefficient of 0.9, initial batch size of 128 that increases by a factor of 5 at each step. <br />
<br />
Increased initial learning rate: initial learning rate of 0.5, initial batch size of 640 that increases during training.<br />
<br />
Increased momentum coefficient: increased initial learning rate of 0.5, initial batch size of 3200 that increases during training, and an increased momentum coefficient of 0.98.<br />
<br />
The results of all four training schedules are summarized in the table below and plotted in the figure that follows:<br />
<br />
[[File:Paper_40_Table_1.png | 800px|center]]<br />
<br />
[[File:Paper_40_Fig_5.png | 800px|center]]<br />
<br />
'''Conclusion:''' Increasing the effective learning rate and scaling the batch size results in further reduction in the number of parameter updates<br />
<br />
=== TRAINING IMAGENET IN 2500 PARAMETER UPDATES===<br />
<br />
'''A) Experiment Goal:''' Control Batch Size<br />
<br />
'''Dataset:''' ImageNet (1.28 million training images)<br />
<br />
The paper modified the setup of Goyal et al. (2017), and used the following configuration:<br />
<br />
'''Network Architecture:''' Inception-ResNet-V2 <br />
<br />
'''Training Parameters:''' <br />
<br />
90 epochs / noise decayed at epochs 30, 60, and 80 by a factor of 10 / Initial ghost batch size = 32 / Learning rate = 3 / momentum coefficient = 0.9 / Initial batch size = 8192<br />
<br />
Two training schedules were used:<br />
<br />
“Decaying learning rate”, where batch size is fixed and the learning rate is decayed<br />
<br />
“Increasing batch size”, where the batch size is increased to 81920 at the first step, and the learning rate is then decayed at the remaining two steps.<br />
<br />
[[File:Paper_40_Table_2.png | 800px|center]]<br />
<br />
[[File:Paper_40_Fig_6.png | 800px|center]]<br />
<br />
'''Conclusion:''' Increasing the batch size resulted in reducing the number of parameter updates from 14,000 to 6,000.<br />
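The update counts in this conclusion can be roughly reproduced from the schedules alone. The sketch below assumes one parameter update per batch, <math>N</math> = 1.28 million images, and that the batch size increase in the second schedule happens at the first decay step (epoch 30); <code>count_updates</code> is a hypothetical helper.<br />

```python
# Sketch: approximate parameter-update counts for the two ImageNet schedules.
# Assumes N / B updates per epoch (no rounding; ghost batches ignored).

N_TRAIN = 1.28e6   # ImageNet training images (approximate)
TOTAL_EPOCHS = 90

def count_updates(phases):
    """phases: list of (start_epoch, batch_size); each phase lasts until the next one starts."""
    total = 0.0
    for i, (start, batch) in enumerate(phases):
        end = phases[i + 1][0] if i + 1 < len(phases) else TOTAL_EPOCHS
        total += (end - start) * N_TRAIN / batch
    return total

# "Decaying learning rate": the batch size stays at 8192 for all 90 epochs.
decaying_lr = [(0, 8192)]
# "Increasing batch size": 8192 until epoch 30, then 81920 for the rest of training
# (the later decays at epochs 60 and 80 act on the learning rate, not the batch size).
increasing_batch = [(0, 8192), (30, 81920)]

print(f"decaying learning rate: {count_updates(decaying_lr):,.0f} updates")
print(f"increasing batch size:  {count_updates(increasing_batch):,.0f} updates")
```

This gives roughly 14,060 and 5,625 updates, close to the reported figures of 14,000 and 6,000; the remaining gap presumably comes from rounding and details of the actual schedule.<br />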
<br />
'''B) Experiment Goal:''' Control Batch Size and Momentum Coefficient<br />
<br />
'''Training Parameters:''' Ghost batch size = 64 / noise decayed at epochs 30, 60, and 80 by a factor of 10.<br />
<br />
The table below shows the number of parameter updates and the accuracy for different sets of training parameters:<br />
<br />
[[File:Paper_40_Table_3.png | 800px|center]]<br />
<br />
[[File:Paper_40_Fig_7.png | 800px|center]]<br />
<br />
'''Conclusion:''' Increasing the momentum reduces the number of parameter updates, but leads to a drop in the test accuracy.<br />
<br />
=== TRAINING IMAGENET IN 30 MINUTES===<br />
<br />
'''Dataset:''' ImageNet (Already introduced in the previous section)<br />
<br />
'''Network Architecture:''' ResNet-50<br />
<br />
The paper replicated the setup of Goyal et al. (2017) while varying the number of TPU devices, the batch size, and the learning rate, and measured both the time needed to complete 90 epochs and the resulting accuracy. The experiments are summarized in the table below:<br />
<br />
[[File:Paper_40_Table_4.png | 800px|center]]<br />
<br />
'''Conclusion:''' Model training times can be reduced by increasing the batch size during training.<br />
<br />
== RELATED WORK ==<br />
The main related work mentioned in the paper is as follows:<br />
<br />
- Smith & Le (2017) interpreted stochastic gradient descent as a stochastic differential equation; the paper builds on this idea to include a decaying learning rate.<br />
<br />
- Mandt et al. (2017) analyzed SGD as a form of approximate Bayesian posterior sampling.<br />
<br />
- Keskar et al. (2016) focused on analyzing the noise once training has started.<br />
<br />
- Moreover, Goyal et al. (2017) exploited the proportional relationship between batch size and learning rate to successfully train ResNet-50 on ImageNet in one hour.<br />
<br />
- Furthermore, You et al. (2017a) presented Layer-wise Adaptive Rate Scaling (LARS), which applies different learning rates to different layers, and used it to train ImageNet in 14 minutes with 74.9% accuracy.<br />
<br />
- Finally, Asynchronous SGD (Recht et al., 2011; Dean et al., 2012) is another strategy that allows multiple GPUs to be used even with small batch sizes.<br />
<br />
== CONCLUSIONS ==<br />
Increasing the batch size during training has the same benefits as decaying the learning rate, and additionally reduces the number of parameter updates, which translates into faster training. Experiments on different image datasets, with various optimizers and training schedules, support this result. The paper also proposed increasing the learning rate and the momentum parameter m while scaling <math> B \propto \frac{\epsilon}{1-m} </math>, which achieves fewer parameter updates but slightly lower test set accuracy, as detailed in the experiments section. In summary, on the ImageNet dataset, Inception-ResNet-V2 achieved 77% validation accuracy in under 2500 parameter updates, and ResNet-50 achieved 76.1% validation set accuracy on TPU in less than 30 minutes. A notable strength of the paper is that it used hyperparameters from the literature, so no hyperparameter tuning was needed.<br />
<br />
== CRITIQUE ==<br />
'''Pros:'''<br />
<br />
- The paper showed empirically that increasing batch size and decaying learning rate are equivalent.<br />
<br />
- Several experiments were performed on different optimizers such as SGD and Adam.<br />
<br />
- The paper includes several comparisons with previous experimental setups.<br />
<br />
'''Cons:'''<br />
<br />
- All datasets used are image datasets. Other experiments should have been done on datasets from different domains to ensure generalization. <br />
<br />
- The number of parameter updates was used as the comparison criterion; wall-clock times would have provided an additional, more tangible measure, although they depend on the hardware used.<br />
<br />
- Special hardware is needed for large batch training, which is not always feasible.<br />
<br />
- In section 5.2 (Increasing the Effective Learning Rate), the authors did not test a range of learning rate values and used only 0.1 and 0.5. Additional results from varying the initial learning rate from 0.1 to 3.2 are provided in the appendix, and they indicate that the test accuracy begins to fall for initial learning rates greater than ~0.4. However, the appended results do not show validation set accuracy curves like those in Figure 6. It would be beneficial to see whether they are similar to the original 0.1 and 0.5 initial learning rate baselines.<br />
<br />
- Although the main idea of the paper is interesting, its results do not seem too surprising in comparison with other recent papers on the subject.<br />
<br />
- The paper could benefit from using other models to demonstrate and generalize its claim, and from adding comparisons with other models as well as other recent methods for increasing the batch size.<br />
<br />
- The paper presents interesting ideas. However, it lacks mathematical and theoretical analysis beyond the core idea. Since the experiments are primarily on image datasets and the paper does not provide sufficient theory, its applicability to other types of data is limited.<br />
<br />
== REFERENCES ==<br />
- Takuya Akiba, Shuji Suzuki, and Keisuke Fukuda. Extremely large minibatch sgd: Training resnet-50 on imagenet in 15 minutes. arXiv preprint arXiv:1711.04325, 2017.<br />
<br />
- Lukas Balles, Javier Romero, and Philipp Hennig. Coupling adaptive batch sizes with learning rates. arXiv preprint arXiv:1612.05086, 2016.<br />
<br />
- Léon Bottou, Frank E Curtis, and Jorge Nocedal. Optimization methods for large-scale machine learning. arXiv preprint arXiv:1606.04838, 2016.<br />
<br />
- Richard H Byrd, Gillian M Chin, Jorge Nocedal, and Yuchen Wu. Sample size selection in optimization methods for machine learning. Mathematical programming, 134(1):127–155, 2012.<br />
<br />
- Pratik Chaudhari, Anna Choromanska, Stefano Soatto, and Yann LeCun. Entropy-SGD: Biasing gradient descent into wide valleys. arXiv preprint arXiv:1611.01838, 2016.<br />
<br />
- Soham De, Abhay Yadav, David Jacobs, and Tom Goldstein. Automated inference with adaptive batches. In Artificial Intelligence and Statistics, pp. 1504–1513, 2017.<br />
<br />
- Jeffrey Dean, Greg Corrado, Rajat Monga, Kai Chen, Matthieu Devin, Mark Mao, Andrew Senior, Paul Tucker, Ke Yang, Quoc V Le, et al. Large scale distributed deep networks. In Advances in neural information processing systems, pp. 1223–1231, 2012.<br />
<br />
- Michael P Friedlander and Mark Schmidt. Hybrid deterministic-stochastic methods for data fitting. SIAM Journal on Scientific Computing, 34(3):A1380–A1405, 2012.<br />
<br />
- Priya Goyal, Piotr Dollár, Ross Girshick, Pieter Noordhuis, Lukasz Wesolowski, Aapo Kyrola, Andrew Tulloch, Yangqing Jia, and Kaiming He. Accurate, large minibatch SGD: Training imagenet in 1 hour. arXiv preprint arXiv:1706.02677, 2017.<br />
<br />
- Sepp Hochreiter and Jürgen Schmidhuber. Flat minima. Neural Computation, 9(1):1–42, 1997.<br />
<br />
- Elad Hoffer, Itay Hubara, and Daniel Soudry. Train longer, generalize better: closing the generalization gap in large batch training of neural networks. arXiv preprint arXiv:1705.08741, 2017.<br />
<br />
- Norman P Jouppi, Cliff Young, Nishant Patil, David Patterson, Gaurav Agrawal, Raminder Bajwa, Sarah Bates, Suresh Bhatia, Nan Boden, Al Borchers, et al. In-datacenter performance analysis of a tensor processing unit. In Proceedings of the 44th Annual International Symposium on Computer Architecture, pp. 1–12. ACM, 2017.<br />
<br />
- Nitish Shirish Keskar, Dheevatsa Mudigere, Jorge Nocedal, Mikhail Smelyanskiy, and Ping Tak Peter Tang. On large-batch training for deep learning: Generalization gap and sharp minima. arXiv preprint arXiv:1609.04836, 2016.<br />
<br />
- Diederik Kingma and Jimmy Ba. Adam: A method for stochastic optimization. arXiv preprint arXiv:1412.6980, 2014.<br />
<br />
- Alex Krizhevsky. One weird trick for parallelizing convolutional neural networks. arXiv preprint arXiv:1404.5997, 2014.<br />
<br />
- Qianxiao Li, Cheng Tai, and E Weinan. Stochastic modified equations and adaptive stochastic gradient algorithms. arXiv preprint arXiv:1511.06251, 2017.<br />
<br />
- Ilya Loshchilov and Frank Hutter. SGDR: stochastic gradient descent with restarts. arXiv preprint arXiv:1608.03983, 2016.<br />
<br />
- Stephan Mandt, Matthew D Hoffman, and David M Blei. Stochastic gradient descent as approximate bayesian inference. arXiv preprint arXiv:1704.04289, 2017.<br />
<br />
- James Martens and Roger Grosse. Optimizing neural networks with kronecker-factored approximate curvature. In International Conference on Machine Learning, pp. 2408–2417, 2015.<br />
<br />
- Yurii Nesterov. A method of solving a convex programming problem with convergence rate O(1/k<sup>2</sup>). In Soviet Mathematics Doklady, volume 27, pp. 372–376, 1983.<br />
<br />
- Lutz Prechelt. Early stopping - but when? Neural Networks: Tricks of the trade, pp. 553–553, 1998.<br />
<br />
- Benjamin Recht, Christopher Re, Stephen Wright, and Feng Niu. Hogwild: A lock-free approach to parallelizing stochastic gradient descent. In Advances in neural information processing systems, pp. 693–701, 2011.<br />
<br />
- Herbert Robbins and Sutton Monro. A stochastic approximation method. The annals of mathematical statistics, pp. 400–407, 1951.<br />
<br />
- Samuel L. Smith and Quoc V. Le. A bayesian perspective on generalization and stochastic gradient descent. arXiv preprint arXiv:1710.06451, 2017.<br />
<br />
- Christian Szegedy, Sergey Ioffe, Vincent Vanhoucke, and Alexander A Alemi. Inception-v4, Inception-ResNet and the impact of residual connections on learning. In AAAI, pp. 4278–4284, 2017.<br />
<br />
- Max Welling and Yee W Teh. Bayesian learning via stochastic gradient langevin dynamics. In Proceedings of the 28th International Conference on Machine Learning (ICML-11), pp. 681–688, 2011.<br />
<br />
- Ashia C Wilson, Rebecca Roelofs, Mitchell Stern, Nathan Srebro, and Benjamin Recht. The marginal value of adaptive gradient methods in machine learning. arXiv preprint arXiv:1705.08292, 2017.<br />
<br />
- Yang You, Igor Gitman, and Boris Ginsburg. Scaling SGD batch size to 32k for imagenet training. arXiv preprint arXiv:1708.03888, 2017a.<br />
<br />
- Yang You, Zhao Zhang, C Hsieh, James Demmel, and Kurt Keutzer. Imagenet training in minutes. CoRR, abs/1709.05011, 2017b.<br />
<br />
- Sergey Zagoruyko and Nikos Komodakis. Wide residual networks. arXiv preprint arXiv:1605.07146, 2016.<br />
<br />
- Chiyuan Zhang, Samy Bengio, Moritz Hardt, Benjamin Recht, and Oriol Vinyals. Understanding deep learning requires rethinking generalization. arXiv preprint arXiv:1611.03530, 2016.</div>Gsahuhttp://wiki.math.uwaterloo.ca/statwiki/index.php?title=Deep_Reinforcement_Learning_in_Continuous_Action_Spaces_a_Case_Study_in_the_Game_of_Simulated_Curling&diff=41967Deep Reinforcement Learning in Continuous Action Spaces a Case Study in the Game of Simulated Curling2018-11-30T00:17:12Z<p>Gsahu: /* Conclusion & Critique */</p>
<hr />
<div>This page provides a summary and critique of the paper '''Deep Reinforcement Learning in Continuous Action Spaces: a Case Study in the Game of Simulated Curling''' [[http://proceedings.mlr.press/v80/lee18b/lee18b.pdf Online Source]], published in ICML 2018. The source code for this paper is available [https://github.com/leekwoon/KR-DL-UCT here]<br />
<br />
= Introduction and Motivation =<br />
<br />
In recent years, Reinforcement Learning methods have been applied to many different games, such as chess and checkers. More recently, CNNs have allowed neural networks to outperform humans in many difficult games, such as Go. However, many of these cases involve a discrete state or action space, i.e. the number of actions a player can take and/or the number of possible game states is finite.<br />
<br />
Interacting with the real world (e.g., a scenario that involves moving physical objects) typically involves a continuous action space, so it is important to develop strategies for dealing with such spaces. Deep neural networks that are designed to succeed in finite action spaces are not necessarily suitable for continuous action space problems, because a deterministic discretization of a continuous action space causes strong biases in policy evaluation and improvement.<br />
<br />
This paper introduces a method for learning with continuous action spaces. A CNN is trained on discretized state and action spaces, and a continuous action search is then performed starting from these discrete results.<br />
<br />
Curling is chosen as the test domain because of its large action space, its potential for complicated strategies, and its need for precise interactions.<br />
<br />
== Curling ==<br />
<br />
Curling is a sport played by two teams on a long sheet of ice. Roughly, the goal is for each team to slide rocks closer to the target at the other end of the sheet than the other team's rocks. The next sections provide background on game play and on potential challenges for learning algorithms, followed by a terminology section.<br />
<br />
=== Game play ===<br />
<br />
A game of curling is divided into ends. In each end, players from both teams alternate throwing (sliding) rocks, eight per team, towards the other end of the ice sheet, where the target area known as the house is located. Rocks must land in a certain area in order to stay in play, and must touch or be inside the concentric rings (12 ft in diameter and smaller) in order to score points. Once all rocks of an end have been thrown, the team with the rock closest to the center of the house scores points.<br />
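The scoring rule is also what bounds the result of an end to the range [-8, 8] used by the value network later on. Under the standard rule, only the team with the stone closest to the center scores, earning one point for each of its stones in the house that is closer than the opponent's closest stone. The sketch below encodes this rule; the function name, the coordinate convention (positions relative to the center of the house, in metres) and the radii are illustrative assumptions.<br />

```python
import math

HOUSE_RADIUS = 1.829   # 6 ft (12-ft diameter rings), in metres
STONE_RADIUS = 0.145   # approximate stone radius in metres (assumption)

def end_score(team_a, team_b):
    """Score of one end from team A's point of view (negative if team B scores).

    team_a, team_b: lists of (x, y) stone positions relative to the center of the house.
    """
    def in_house(stones):
        # A stone counts only if it touches or lies inside the outer ring.
        dists = (math.hypot(x, y) for x, y in stones)
        return sorted(d for d in dists if d <= HOUSE_RADIUS + STONE_RADIUS)

    a, b = in_house(team_a), in_house(team_b)
    if not a and not b:
        return 0                                   # blank end: no stones in the house
    if a and (not b or a[0] < b[0]):
        best_b = b[0] if b else math.inf
        return sum(1 for d in a if d < best_b)     # team A scores
    best_a = a[0] if a else math.inf
    return -sum(1 for d in b if d < best_a)        # team B scores (an exact tie gives 0)

print(end_score([(0.1, 0.0), (0.5, 0.0)], [(1.0, 0.0)]))  # prints 2
```

Since each team throws eight stones per end, the result of this function always lies in [-8, 8].<br />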
<br />
When throwing a rock, the thrower can spin it. This makes the rock 'curl' along its path towards the house and can allow it to travel around other rocks. Team members are also able to sweep the ice in front of a moving rock to decrease friction, which allows fine-tuning of the distance travelled (though the physics of sweeping are not implemented in the simulation used).<br />
<br />
Curling offers many possible high-level actions, which one team member directs the thrower to make. An example set includes:<br />
<br />
* Draw: Throw a rock to a target location<br />
* Freeze: Draw a rock up against another rock<br />
* Takeout: Knock another rock out of the house. Can be combined with different ricochet directions<br />
* Guard: Place a rock in front of another, to block other rocks (ex: takeouts)<br />
<br />
=== Challenges for AI ===<br />
<br />
Curling poses many challenges for AI because of its physics and rules. This section lists a few of them.<br />
<br />
The effect of changing an action can be highly nonlinear and discontinuous: a 1-cm deviation in a rock's path can make the difference between a high-speed collision and no collision at all.<br />
<br />
Curling requires both offensive and defensive strategies. For example, the team that throws the last rock of an end only needs to place that rock closer than the opposing team's rocks to score a point and invalidate any opposing rocks in the house. The opposing team must therefore consider how to prevent this, in addition to scoring points themselves.<br />
<br />
Curling also has a concept known as 'the hammer'. The hammer belongs to the team that throws the last rock of an end, which provides an advantage, and it is given to the team that does not score points in that end. It can therefore be a good strategy not to score a single point in an end (e.g., if already ahead in points), as scoring would give the advantage to the opposing team.<br />
<br />
Finally, curling has a rule known as the 'Free Guard Zone'. This applies to the first 4 rocks thrown (2 from each team). If they land short of the house, but still in play, then the rocks are not allowed to be removed (via collisions) until all of the first 4 rocks have been thrown.<br />
<br />
=== Terminology ===<br />
<br />
* End: A round of the game<br />
* House: The target area at the end of the sheet of ice, which contains the concentric scoring rings<br />
* Hammer: The team that throws the last rock of an end 'has the hammer'<br />
* Hog Line: Thick line drawn in front of the house, orthogonal to the length of the ice sheet. Rocks must pass this line to remain in play.<br />
* Back Line: Line drawn just behind the house. Rocks that pass this line are removed from play.<br />
<br />
<br />
== Related Work ==<br />
<br />
=== AlphaGo Lee ===<br />
<br />
AlphaGo Lee (Silver et al., 2016, [5]) refers to an algorithm used to play the game Go, which was able to defeat international champion Lee Sedol. <br />
<br />
<br />
Go game:<br />
* The game starts with an empty 19x19 board<br />
* One player takes the black stones and the other takes the white stones<br />
* The two players take turns placing stones on the board<br />
* Rules:<br />
1. If a connected group of stones is completely surrounded by the opponent's stones, it is removed from the board<br />
<br />
2. Ko rule: Forbids a move that repeats a previous board position<br />
* The game ends when there are no valuable moves left on the board.<br />
* The territory of both players is counted.<br />
* 7.5 points are added to white's score (called Komi).<br />
[[File:go.JPG|700px|center]]<br />
<br />
Two neural networks were trained on the moves of human experts, to act as both a policy network and a value network. A Monte Carlo Tree Search algorithm was used for policy improvement.<br />
<br />
The AlphaGo Lee policy network predicts the best move given a board configuration. It has a CNN architecture with 13 hidden layers, and it is trained using expert game play data and improved through self-play.<br />
<br />
The value network evaluates the probability of winning given a board configuration. It consists of a CNN with 14 hidden layers, and it is trained using self-play data from the policy network. <br />
<br />
Finally, the two networks are combined using Monte-Carlo Tree Search, which performs look ahead search to select the actions for game play.<br />
<br />
The use of both policy and value networks is reflected in this paper's work.<br />
<br />
=== AlphaGo Zero ===<br />
<br />
AlphaGo Zero (Silver et al., 2017, [6]) is an improvement on the AlphaGo Lee algorithm. AlphaGo Zero uses a unified neural network in place of the separate policy and value networks, and is trained through self-play, without the need for expert data.<br />
<br />
The unification of networks and self-play are also reflected in this paper.<br />
<br />
=== Curling Algorithms ===<br />
<br />
Some past algorithms have been proposed to deal with continuous action spaces. For example, Yamamoto et al. (2015, [7]) use game tree search methods in a discretized space. The value of an action is taken as the average of nearby values, with respect to some knowledge of execution uncertainty.<br />
<br />
=== Monte Carlo Tree Search ===<br />
<br />
Monte Carlo Tree Search algorithms have been applied to continuous action spaces. These algorithms, discussed in further detail below, balance the exploration of different states with knowledge of paths of execution from past games. An MCTS variant called KR-UCT, which uses kernel regression (KR) and kernel density estimation (KDE) to estimate rewards from neighborhood information, is able to find effective selections and has been applied to continuous action spaces.<br />
<br />
For bandit problems, hierarchical optimistic optimization (HOO) has been used to build a cover tree that divides the action space into progressively smaller ranges at greater depths, so that the most promising nodes receive finer-grained estimates.<br />
<br />
=== Curling Physics and Simulation ===<br />
<br />
Several references in the paper concern the study and simulation of curling physics. Friction coefficients between curling stones and ice have been analyzed; since modelling the changes in friction on the ice is not possible, a fixed, predefined friction coefficient is used in the simulation. The behavior of the stones was also modeled, with important parameters estimated from the play of professional players. The authors used the same parameters in this paper.<br />
<br />
== General Background of Algorithms ==<br />
<br />
=== Policy and Value Functions ===<br />
<br />
A policy function is trained to provide the best action to take, given a current state. Policy iteration is an algorithm used to improve a policy over time. This is done by alternating between policy evaluation and policy improvement.<br />
<br />
POLICY IMPROVEMENT: LEARNING ACTION POLICY<br />
<br />
Action policy <math> p_{\sigma}(a|s) </math> outputs a probability distribution over all eligible moves <math> a </math>. Here <math> \sigma </math> denotes the weights of a neural network that approximates the policy, <math>s</math> denotes a state, and <math>a</math> denotes an action taken in the environment. The policy is thus a function that returns an action (or a distribution over actions) given the state the agent is in. Policy gradient reinforcement learning can be used to train the action policy: the weights are updated by stochastic gradient ascent in the direction that maximizes the expected outcome at each time step <math>t</math>,<br />
<math> \Delta \sigma \propto \frac{\partial \log p_{\sigma}(a_t|s_t)}{\partial \sigma} r(s_t) </math><br />
where <math> r(s_t) </math> is the return.<br />
<br />
POLICY EVALUATION: LEARNING VALUE FUNCTIONS<br />
<br />
A value function <math> v_{\theta}(s) </math> with parameters <math> \theta </math> is trained to estimate the value of being in a certain state. It is trained on records of state-reward pairs <math> (s, r(s)) </math>, using stochastic gradient descent to minimize the mean squared error (MSE) between the predicted regression value and the corresponding outcome,<br />
<math> \Delta \theta \propto \frac{\partial v_{\theta}(s)}{\partial \theta}(r(s)-v_{\theta}(s)) </math><br />
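To make the two update rules concrete, the sketch below applies one step of each to toy linear models instead of the paper's CNN: a linear-softmax policy <math> p_{\sigma}(a|s) = \mathrm{softmax}(\sigma \phi(s))_a </math> and a linear value function <math> v_{\theta}(s) = \theta^\top \phi(s) </math>. The feature vector <math>\phi(s)</math>, the sizes and the learning rate are arbitrary assumptions.<br />

```python
import numpy as np

def softmax(z: np.ndarray) -> np.ndarray:
    z = z - z.max()  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum()

def policy_gradient_step(sigma, phi, action, ret, lr=0.1):
    """One REINFORCE step for p_sigma(a|s) = softmax(sigma @ phi(s)).

    sigma: (n_actions, n_features) weights; phi: (n_features,) state features.
    """
    probs = softmax(sigma @ phi)
    # d log p(a|s) / d logits = onehot(action) - probs
    grad_logits = -probs
    grad_logits[action] += 1.0
    # Gradient ascent: Delta sigma is proportional to d log p / d sigma * r(s_t)
    return sigma + lr * ret * np.outer(grad_logits, phi)

def value_step(theta, phi, ret, lr=0.1):
    """One SGD step on (r(s) - v_theta(s))^2 for v_theta(s) = theta @ phi."""
    v = theta @ phi
    # d v / d theta = phi, so the update is proportional to phi * (r(s) - v)
    return theta + lr * (ret - v) * phi

rng = np.random.default_rng(0)
phi = rng.normal(size=4)      # toy state features
sigma = np.zeros((3, 4))      # 3 toy actions
theta = np.zeros(4)

new_sigma = policy_gradient_step(sigma, phi, action=1, ret=1.0)
new_theta = value_step(theta, phi, ret=1.0)
print("p(a=1|s):", softmax(sigma @ phi)[1], "->", softmax(new_sigma @ phi)[1])
print("v(s):    ", theta @ phi, "->", new_theta @ phi)
```

With a positive return, the probability of the taken action rises, and the value estimate moves towards the observed return.<br />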
<br />
=== Monte Carlo Tree Search ===<br />
<br />
Monte Carlo Tree Search (MCTS) is a search algorithm used for finite-horizon tasks (ex: in curling, only 16 moves, i.e. stone throws, are made in each end).<br />
<br />
MCTS is a tree search algorithm similar to minimax. However, MCTS is probabilistic and does not need to explore a full game tree, or even a tree reduced with alpha-beta pruning. This makes it tractable for games such as Go and curling.<br />
<br />
Nodes of the tree are game states, and branches represent actions. Each node stores statistics on how many times it has been visited by the MCTS, as well as the number of wins encountered by playouts from that position. A node has been considered 'visited' if a full playout has started from that node. A node is considered 'expanded' if all its children have been visited.<br />
<br />
MCTS begins with the '''selection''' phase, which traverses already-known states and actions. Starting at the root node, the algorithm selects the child with the highest 'score', and repeats this from each successive node, descending the tree towards a leaf.<br />
<br />
The next phase, '''expansion''', begins when the algorithm reaches a node where not all children have been visited (ie: the node has not been fully expanded). In the expansion phase, a new child of that node is added to the tree.<br />
<br />
Once the new child is added, '''simulation''' takes place. This refers to a full playout of the game from the state of the new node, and can use many strategies, such as randomly chosen moves, heuristics, etc.<br />
<br />
The final phase is '''update''' or '''back-propagation''' (unrelated to the neural network algorithm). In this phase, the result of the '''simulation''' (ie: win/lose) is used to update the statistics of all nodes on the path back to the root.<br />
<br />
A selection function known as Upper Confidence Bounds applied to Trees (UCT) can be used to decide which node to select. Its formula is shown below [[https://www.baeldung.com/java-monte-carlo-tree-search source]]. The first term is the average score of games played from a given node. The second term grows as the parent's total number of simulations increases while the node's own visit count stays fixed, i.e. when its sibling nodes are visited. This means that rarely explored nodes gradually increase their UCT score and are eventually selected.<br />
<br />
<math> \frac{w_i}{n_i} + c \sqrt{\frac{\ln t}{n_i}} </math><br />
<br />
In which<br />
<br />
* <math> w_i = </math> number of wins after <math> i</math>th move<br />
* <math> n_i = </math> number of simulations after <math> i</math>th move<br />
* <math> c = </math> exploration parameter (theoretically equal to <math> \sqrt{2}</math>)<br />
* <math> t = </math> total number of simulations for the parent node<br />
<br />
<br />
Sources: 2,3,4<br />
<br />
[[File:MCTS_Diagram.jpg | 500px|center]]<br />
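The UCT rule above can be sketched as follows. Two details are assumptions rather than part of the formula: the parent's total <math>t</math> is taken as the sum of its children's visit counts, and unvisited children receive an infinite score so that each is tried once before any is revisited.<br />

```python
import math

def uct_score(wins: float, visits: int, parent_visits: int, c: float = math.sqrt(2)) -> float:
    """w_i / n_i + c * sqrt(ln t / n_i); unvisited nodes get +inf so they are tried first."""
    if visits == 0:
        return math.inf
    return wins / visits + c * math.sqrt(math.log(parent_visits) / visits)

def select_child(children: list[tuple[float, int]]) -> int:
    """Index of the child, given as (wins, visits) pairs, with the highest UCT score."""
    t = sum(visits for _, visits in children)  # parent's total simulations (assumed)
    return max(range(len(children)), key=lambda i: uct_score(*children[i], t))

# Child 1 has a lower win rate but far fewer visits, so its exploration bonus wins.
print(select_child([(6, 10), (1, 2)]))  # prints 1
```

Lowering <math>c</math> shifts the balance towards exploiting the child with the best average result.<br />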
<br />
=== Kernel Regression ===<br />
<br />
Kernel regression is a form of weighted averaging that uses a kernel function as a weight to estimate the conditional expectation of a random variable. Given two data points '''x''' and '''x<sub>i</sub>''', each with an associated value '''y''', a kernel '''K''' outputs a weighting factor for the pair. The value of a new, unseen point is then estimated as the weighted average of the values of the surrounding points.<br />
<br />
A typical kernel is a Gaussian kernel, shown below. The formula for calculating estimated value is shown below as well (sources: Lee et al.).<br />
<br />
[[File:gaussian_kernel.png | 400 px]]<br />
<br />
[[File:kernel_regression.png | 250 px]]<br />
<br />
The denominator of the conditional expectation is related to kernel density estimation, which is defined as <math display="inline">W(x)=\sum_{i=0}^n K(x,x_i)</math>.<br />
<br />
Here, the two combine to weight the scores of samples closest to '''x''' more strongly.<br />
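The kernel formulas are only shown as images, so the sketch below assumes the common Gaussian form <math> K(x, x_i) = \exp\left(-\frac{\|x - x_i\|^2}{2h^2}\right) </math> with a hand-chosen bandwidth <math>h</math>. Any constant normalization factor in the image would cancel in the weighted average, though it would rescale <math>W(x)</math>.<br />

```python
import numpy as np

def gaussian_kernel(x: np.ndarray, xs: np.ndarray, bandwidth: float) -> np.ndarray:
    """K(x, x_i) = exp(-||x - x_i||^2 / (2 h^2)) for every row x_i of xs (shape (n, d))."""
    sq_dist = np.sum((xs - x) ** 2, axis=1)
    return np.exp(-sq_dist / (2.0 * bandwidth**2))

def kernel_regression(x, xs, ys, bandwidth=1.0):
    """Return (estimate of y at x, kernel density W(x))."""
    x = np.asarray(x, dtype=float)
    xs = np.asarray(xs, dtype=float)
    ys = np.asarray(ys, dtype=float)
    weights = gaussian_kernel(x, xs, bandwidth)
    density = weights.sum()                    # W(x) = sum_i K(x, x_i)
    estimate = np.dot(weights, ys) / density   # weighted average of the neighbours' values
    return float(estimate), float(density)

# The far-away sample at 4.0 gets almost no weight when estimating at 0.5.
print(kernel_regression([0.5], [[0.0], [1.0], [4.0]], [0.0, 1.0, 10.0], bandwidth=0.5))
```

A larger bandwidth smooths over more neighbours; a smaller one makes the estimate more local.<br />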
<br />
= Methods =<br />
<br />
== Variable Definitions ==<br />
<br />
The following variables are used often in the paper:<br />
<br />
* <math>s</math>: A state in the game, represented as the network input described below<br />
* <math>s_t</math>: The state at a certain time-step of the game. Time-steps refer to full turns in the game<br />
* <math>a_t</math>: The action taken in state <math>s_t</math><br />
* <math>A_t</math>: The actions taken for sibling nodes related to <math>a_t</math> in MCTS<br />
* <math>n_{a_t}</math>: The number of visits to node <math>a_t</math> in MCTS<br />
* <math>v_{a_t}</math>: The MCTS value estimate of a node<br />
<br />
== Network Design ==<br />
<br />
The authors design a CNN called the 'policy-value' network. The network consists of a common network structure, which is then split into 'policy' and 'value' outputs. This network is trained to learn a probability distribution of actions to take, and expected rewards, given an input state.<br />
<br />
=== Shared Structure ===<br />
<br />
The network consists of 1 convolutional layer followed by 9 residual blocks, each block consisting of 2 convolutional layers with 32 3x3 filters. The structure of this network is shown below:<br />
<br />
<br />
[[File:curling_network_layers.png|600px|thumb|center|Figure 2. A detailed description of our policy-value network. The shared network is composed of one convolutional layer and nine residual blocks. Each residual block (explained in b) has two convolutional layers with batch normalization (Ioffe & Szegedy, 2015[11]) followed by the addition of the input and the residual block. Each layer in the shared network uses 3x3 filters. The policy head has two more convolutional layers, while the value head has two fully connected layers on top of a convolutional layer. For the activation function of each convolutional layer, ReLU (Nair & Hinton[12]) is used.]]<br />
<br />
<br />
<br />
The input to this network is the following:<br />
* Location of stones<br />
* Order of the stones by distance to the tee (the center of the house)<br />
* A 32x32 grid representation of the ice sheet, indicating which stones are present in each grid cell.<br />
<br />
The authors do not describe how the stone-based information is added to the 32x32 grid as input to the network.<br />
<br />
=== Policy Network ===<br />
<br />
The policy head is created by adding 2 convolutional layers with two 3x3 filters to the main body of the network. The output of the policy head is a probability distribution over a 32x32x2 set of discrete actions, from which the best shot is selected. The actions represent target locations in the grid and the spin direction of the stone.<br />
<br />
[[File:policy-value-net.PNG | 700px]]<br />
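Because the policy head outputs one score per (row, column, spin) combination, its output can be treated as a single categorical distribution over 32 x 32 x 2 = 2048 discrete actions. The sketch below assumes a (row, column, spin) layout and a softmax over all entries; the paper's exact ordering is not stated here, and the helper names are hypothetical.<br />

```python
import numpy as np

GRID, SPINS = 32, 2  # 32x32 target cells, 2 spin directions

def policy_distribution(logits: np.ndarray) -> np.ndarray:
    """Softmax over all 32*32*2 policy-head outputs, returned in (row, col, spin) shape."""
    flat = logits.reshape(-1)
    flat = np.exp(flat - flat.max())  # shift for numerical stability
    return (flat / flat.sum()).reshape(GRID, GRID, SPINS)

def best_action(probs: np.ndarray) -> tuple[int, int, int]:
    """(row, col, spin) of the most probable discrete action."""
    row, col, spin = np.unravel_index(np.argmax(probs), probs.shape)
    return int(row), int(col), int(spin)

logits = np.zeros((GRID, GRID, SPINS))
logits[10, 20, 1] = 5.0
print(best_action(policy_distribution(logits)))  # prints (10, 20, 1)
```

The selected cell only fixes a coarse target; the continuous action search described next refines it.<br />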
<br />
=== Value Network ===<br />
<br />
The value head is created by adding a convolutional layer with one 3x3 filter, followed by dense layers of 256 and 17 units, to the shared network. The 17 output units represent a probability distribution over the scores in the range [-8, 8], which are the possible scores of an end in a curling game.<br />
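Since the 17 outputs form a distribution over end scores, an expected score follows directly. The sketch below assumes that output unit <math>i</math> corresponds to score <math>i - 8</math> (from one fixed team's perspective) and that the outputs are logits normalized by a softmax; both are assumptions about details not given here.<br />

```python
import numpy as np

SCORES = np.arange(-8, 9)  # the 17 possible end scores: -8, -7, ..., 8

def score_distribution(logits: np.ndarray) -> np.ndarray:
    """Softmax over the 17 value-head outputs; entry i is P(score = i - 8)."""
    e = np.exp(logits - logits.max())
    return e / e.sum()

def expected_score(logits: np.ndarray) -> float:
    """Expected end score under the predicted distribution."""
    return float(score_distribution(logits) @ SCORES)

print(expected_score(np.zeros(17)))  # a uniform distribution gives an expected score of about 0
```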
<br />
== Continuous Action Search ==<br />
<br />
The policy head of the network only outputs actions from a discretized action space. For real-life interactions, and especially in curling, this will not suffice, as very fine adjustments to actions can make significant differences in outcomes.<br />
<br />
Actions in the continuous space are generated using an MCTS algorithm, with the following steps:<br />
<br />
=== Selection ===<br />
<br />
From a given state, the list of already-visited actions is denoted as A<sub>t</sub>. Scores and the number of visits to each node are estimated using the equations below (the first equation gives the expected end value for one-end games). They are presumably estimated with kernels, rather than taken directly from the MCTS statistics, so that information from nearby actions can be shared in the continuous action space.<br />
<br />
[[File:curling_kernel_equations.png | 400px]]<br />
<br />
The UCB formula is then used to select an action to expand.<br />
<br />
The actions that are taken in the simulator appear to be drawn from a Gaussian centered around <math>a_t</math>. This allows exploration in the continuous action space.<br />
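The kernel equations appear only as an image here. A common formulation, used in KR-UCT and assumed here to match the image, smooths statistics over the visited actions <math>A_t</math> as <math> \mathbb{E}[v|a] = \frac{\sum_{b \in A_t} K(a,b)\, v_b\, n_b}{\sum_{b \in A_t} K(a,b)\, n_b} </math> and <math> W(a) = \sum_{b \in A_t} K(a,b)\, n_b </math>, then selects the action maximizing <math> \mathbb{E}[v|a] + C\sqrt{\frac{\log \sum_{b \in A_t} W(b)}{W(a)}} </math>. The sketch below implements that form with a Gaussian kernel and adds Gaussian execution noise to the chosen action; the bandwidth, <math>C</math> and noise level are hypothetical values.<br />

```python
import numpy as np

def gaussian_kernel_matrix(actions: np.ndarray, bandwidth: float) -> np.ndarray:
    """K[i, j] = exp(-||a_i - a_j||^2 / (2 h^2)) for actions of shape (k, d)."""
    diff = actions[:, None, :] - actions[None, :, :]
    return np.exp(-np.sum(diff**2, axis=-1) / (2.0 * bandwidth**2))

def kr_ucb_select(actions, values, visits, c=1.0, bandwidth=0.1):
    """Index of the visited action with the highest kernel-smoothed UCB score.

    actions: (k, d) visited actions A_t; values: (k,) mean values v_b; visits: (k,) counts n_b >= 1.
    """
    actions = np.asarray(actions, dtype=float)
    values = np.asarray(values, dtype=float)
    visits = np.asarray(visits, dtype=float)
    K = gaussian_kernel_matrix(actions, bandwidth)
    W = K @ visits                          # W(a) = sum_b K(a, b) n_b
    expected = (K @ (values * visits)) / W  # E[v | a]
    ucb = expected + c * np.sqrt(np.log(W.sum()) / W)
    return int(np.argmax(ucb))

def perturb(action, noise_std, rng):
    """Action actually executed: Gaussian noise centered on the chosen action (assumed model)."""
    return np.asarray(action, dtype=float) + rng.normal(scale=noise_std, size=np.shape(action))

rng = np.random.default_rng(0)
best = kr_ucb_select([[0.0], [1.0], [2.0]], [0.2, 0.8, 0.5], [10, 10, 10], c=0.5, bandwidth=0.01)
print(best, perturb([1.0], 0.05, rng))
```

With a very small bandwidth the kernel matrix becomes the identity and this reduces to ordinary UCB; larger bandwidths let nearby actions share their statistics.<br />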
<br />
=== Expansion ===<br />
<br />
The authors use a variant of regular UCT for expansion. In this case, they expand a new node only when existing nodes have been visited a certain number of times. The authors utilize a widening approach to overcome problems with standard UCT performing a shallow search when there is a large action space.<br />
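The exact widening rule is not given here. One common choice, progressive widening, lets a node that has been visited <math>N</math> times have at most <math>\lceil k N^{\alpha} \rceil</math> children; the sketch below uses that rule with hypothetical constants <math>k = 1</math> and <math>\alpha = 0.5</math>.<br />

```python
import math

def allowed_children(visits: int, k: float = 1.0, alpha: float = 0.5) -> int:
    """Maximum number of children under progressive widening: ceil(k * N^alpha), at least 1."""
    return max(1, math.ceil(k * visits ** alpha))

def should_expand(num_children: int, visits: int, k: float = 1.0, alpha: float = 0.5) -> bool:
    """Add a new child only when the node has been visited often enough."""
    return num_children < allowed_children(visits, k, alpha)

for n in (1, 4, 5, 100):
    print(n, allowed_children(n))  # 1 1, 4 2, 5 3, 100 10
```

With these constants, the number of children grows with the square root of the visit count, so most visits refine existing actions while new actions are still added occasionally.<br />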
<br />
=== Simulation ===<br />
<br />
Instead of simulating with a random game playout, the authors use the value network to estimate the likely score associated with a state. This speeds up simulation (assuming the network is well trained), as the game does not actually need to be simulated.<br />
<br />
=== Backpropagation ===<br />
<br />
Standard backpropagation is used, updating both the values and number of visits stored in the path of parent nodes.<br />
<br />
<br />
== Supervised Learning ==<br />
<br />
During supervised training, data is gathered from the program AyumuGAT'16 ([8]). This program is also based on an MCTS algorithm combined with a high-performance AI curling program. 400,000 state-action pairs were generated for this training.<br />
<br />
=== Policy Network ===<br />
<br />
The policy network was trained to learn the action taken in each state. Here, the likelihood of the taken action was set to be 1, and the likelihood of other actions to be 0.<br />
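<br />
To make the one-hot target concrete, the sketch below flattens a (grid x, grid y, spin) action from the 32x32x2 policy output into a single index and builds the corresponding target vector. The index layout and the 0/1 spin encoding are assumptions for illustration.<br />
<br />
```python
GRID = 32   # 32x32 target locations
SPINS = 2   # two spin directions, assumed encoded as 0 and 1


def action_index(x, y, spin):
    """Flatten a (x, y, spin) action into an index in [0, GRID * GRID * SPINS)."""
    if not (0 <= x < GRID and 0 <= y < GRID and spin in (0, 1)):
        raise ValueError("action outside the 32x32x2 action grid")
    return (spin * GRID + y) * GRID + x


def one_hot_target(x, y, spin):
    """Supervised policy target: 1 for the expert's action, 0 for all others."""
    target = [0.0] * (GRID * GRID * SPINS)
    target[action_index(x, y, spin)] = 1.0
    return target
```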
<br />
=== Value Network ===<br />
<br />
The value network was trained by 'd-depth simulations and bootstrapping of the prediction to handle the high variance in rewards resulting from a sequence of stochastic moves' (quote taken from paper). In this case, ''m'' state-action pairs were sampled from the training data. For each pair <math>(s_t, a_t)</math>, a state ''d'' steps ahead, <math>s_{t+d}</math>, was generated. This rollout handled uncertainty by treating all of its actions as deterministic except the last one, ''a<sub>t+d-1</sub>'', which was executed with uncertainty. The value network is then used to predict the value of this state, and that prediction, <math>z_t</math>, is used as the target for learning the value at ''s<sub>t</sub>''.<br />
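<br />
A sketch of the d-depth bootstrapping procedure as described above. <code>step</code> and <code>value_fn</code> stand in for the curling simulator and the value network; both, and the toy versions below, are hypothetical placeholders rather than the authors' code.<br />
<br />
```python
import random


def bootstrap_target(s_t, recorded_actions, step, value_fn, d, rng):
    """Bootstrapped value target z_t for s_t from a d-step rollout.

    recorded_actions holds a_t, ..., a_{t+d-1} from the training data.
    step(state, action, noisy, rng) is a hypothetical simulator call.
    """
    if len(recorded_actions) < d:
        raise ValueError("need at least d recorded actions")
    state = s_t
    for i in range(d):
        # Only the last action a_{t+d-1} is executed with execution noise.
        state = step(state, recorded_actions[i], noisy=(i == d - 1), rng=rng)
    return value_fn(state)  # prediction at s_{t+d}, used as the target for s_t


# Toy stand-ins (hypothetical): a scalar "state", additive actions, Gaussian noise.
def toy_step(state, action, noisy, rng, sigma=0.1):
    return state + action + (rng.gauss(0.0, sigma) if noisy else 0.0)


def toy_value(state):
    return 2.0 * state
```
<br />
For instance, <code>bootstrap_target(0.0, [1.0, 2.0, 3.0], toy_step, toy_value, d=3, rng=random.Random(0))</code> returns a value close to 12, perturbed only by the noise on the final action.<br />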
<br />
=== Policy-Value Network ===<br />
<br />
The policy-value network was trained to maximize the similarity of the predicted policy and value, and the actual policy and value from a state. The learning algorithm parameters are:<br />
<br />
* Algorithm: stochastic gradient descent<br />
* Batch size: 256<br />
* Momentum: 0.9<br />
* L2 regularization: 0.0001<br />
* Training time: ~100 epochs<br />
* Learning rate: initialized at 0.01, reduced twice<br />
<br />
A multi-task loss function was used. This takes the summation of the cross-entropy losses of each prediction:<br />
<br />
[[File:curling_loss_function.png | 300px]]<br />
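<br />
A plain-Python sketch of the multi-task loss: the policy cross-entropy plus the value cross-entropy over the 17 score bins, with an optional L2 penalty. Treating the 0.0001 L2 coefficient as an explicit penalty term (instead of optimizer weight decay) is an assumption.<br />
<br />
```python
import math


def cross_entropy(target, predicted, eps=1e-12):
    """-sum_i target_i * log(predicted_i); eps guards against log(0)."""
    return -sum(t * math.log(max(p, eps)) for t, p in zip(target, predicted))


def policy_value_loss(pi, p, z, v, params=None, c=1e-4):
    """Policy cross-entropy + value cross-entropy (+ c * ||params||^2 if given).

    pi, z: target action and score distributions; p, v: network predictions.
    """
    loss = cross_entropy(pi, p) + cross_entropy(z, v)
    if params is not None:
        loss += c * sum(w * w for w in params)
    return loss
```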
<br />
== Self-Play Reinforcement Learning ==<br />
<br />
After initialization by supervised learning, the algorithm uses self-play to further train itself. During this training, the policy network learns probabilities from the MCTS process, while the value network learns from game outcomes.<br />
<br />
At a game state ''s<sub>t</sub>'':<br />
<br />
1) the algorithm outputs a prediction ''z<sub>t</sub>'', an estimate of the game score probabilities. It is based on similar past actions and computed using kernel regression.<br />
<br />
2) the algorithm outputs a prediction <math>\pi_t</math>, representing a probability distribution of actions. These are proportional to estimated visit counts from MCTS, based on kernel density estimation.<br />
<br />
It is not clear how these predictions are created. It would seem likely that the policy-value network generates these, but the wording of the paper suggests they are generated from MCTS statistics.<br />
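<br />
Under the reading that <math>\pi_t</math> comes from MCTS statistics, the target policy could be computed as below: each visited action's kernel-smoothed visit count <math>W(a) = \sum_b K(a,b)\, n_b</math> is normalized so the values sum to one. This is a sketch; the Gaussian kernel and its bandwidth are assumptions.<br />
<br />
```python
import math


def kde_policy_target(actions, visits, bandwidth=0.5):
    """Target policy pi(a) proportional to the kernel-smoothed visit count W(a).

    Requires at least one action with a positive visit count.
    """
    def kernel(a, b):
        sq_dist = sum((x - y) ** 2 for x, y in zip(a, b))
        return math.exp(-sq_dist / (2.0 * bandwidth ** 2))

    w = [sum(kernel(a, b) * n for b, n in zip(actions, visits)) for a in actions]
    total = sum(w)
    return [wi / total for wi in w]
```
<br />
Unlike raw visit counts, the smoothed version gives probability mass to an action that was rarely visited but lies close to a heavily visited one.<br />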
<br />
The policy-value network is updated by sampling data <math>(s, \pi, z)</math> from recent history of self-play. The same loss function is used as before.<br />
<br />
It is not clear how the improved network is used, as MCTS seems to be the driving process at this point.<br />
<br />
== Long-Term Strategy Learning ==<br />
<br />
Finally, the authors implement a new strategy to augment their algorithm for long-term play. In this context, this refers to playing a game over many ends, where the strategy to win a single end may not be a good strategy to win a full game. For example, scoring one point in an end while being one point ahead gives the advantage to the other team in the next end (as they will throw the last stone). The other team could then use the advantage to score two points, taking the lead.<br />
<br />
The authors build a 'winning percentage' table. This table stores the percentage of games won, based on the number of ends left and the difference in score (current team - opposing team). It can be computed iteratively, one end at a time, using the estimated probability distribution of one-end scores.<br />
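<br />
The winning-percentage table can be sketched as a recursion over the remaining ends. The example below assumes a fixed one-end score distribution for the team holding the hammer (the numbers are made up; in the paper this would come from the learned one-end estimates), passes the hammer to the team that did not score (a blank end keeps it with the same team), and counts a tie after the last end as half a win instead of playing an extra end.<br />
<br />
```python
from functools import lru_cache

# Made-up one-end score distribution from the hammer team's point of view
# (positive: hammer team scores, negative: the other team steals).
HAMMER_DIST = {-1: 0.15, 0: 0.15, 1: 0.35, 2: 0.25, 3: 0.10}


@lru_cache(maxsize=None)
def win_prob(ends_left, diff, we_have_hammer):
    """Winning probability given ends left, score difference (us - them), and hammer."""
    if ends_left == 0:
        if diff > 0:
            return 1.0
        if diff < 0:
            return 0.0
        return 0.5  # assumption: a tie counts as half a win (no extra end)
    total = 0.0
    for score, p in HAMMER_DIST.items():
        gain = score if we_have_hammer else -score
        # The hammer passes only if the hammer team scores; steals and blank ends keep it.
        next_hammer = (not we_have_hammer) if score > 0 else we_have_hammer
        total += p * win_prob(ends_left - 1, diff + gain, next_hammer)
    return total


def winning_table(max_ends, max_diff, we_have_hammer=True):
    """Table[(ends_left, diff)] -> winning probability for diff in [-max_diff, max_diff]."""
    return {(e, d): win_prob(e, d, we_have_hammer)
            for e in range(1, max_ends + 1)
            for d in range(-max_diff, max_diff + 1)}
```
<br />
A multi-end strategy can then prefer the shot whose predicted score distribution maximizes the expected table value, rather than the expected score of the current end.<br />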
<br />
== Final Algorithms ==<br />
<br />
The authors make use of the following versions of their algorithm:<br />
<br />
=== KR-DL ===<br />
<br />
''Kernel regression-deep learning'': This algorithm is trained only by supervised learning.<br />
<br />
=== KR-DRL ===<br />
<br />
''Kernel regression-deep reinforcement learning'': This algorithm is trained by supervised learning (i.e., initialized as the KR-DL algorithm) and then further trained by self-play. During self-play, each shot is selected after 400 MCTS simulations of k=20 randomly selected actions. Data for self-play was collected over a week on 5 GPUs and generated 5 million game positions. The policy-value network was continually updated using samples from the latest 1 million game positions.<br />
<br />
=== KR-DRL-MES ===<br />
<br />
''Kernel regression-deep reinforcement learning-multi-ends-strategy'': This algorithm makes use of the winning percentage table generated from self-play.<br />
<br />
= Testing and Results =<br />
The authors use data from the public program AyumuGAT’16 to test. Testing is done with a simulated curling program [9]. This simulator does not deal with changing ice conditions, or sweeping, but does deal with stone trajectories and collisions.<br />
<br />
== Comparison of KR-DL-UCT and DL-UCT ==<br />
<br />
The first test compares an algorithm trained with kernel regression to an algorithm trained without kernel regression, to show the contribution of kernel regression to performance. Both algorithms have networks initialized by supervised learning and are then trained with two different self-play algorithms. KR-DL-UCT uses the algorithm described above. The authors do not go into detail on how DL-UCT selects shots, but state that a constant is set to allow exploration.<br />
<br />
As an evaluation, both algorithms play 2000 games against the DL-UCT algorithm frozen after supervised training: 1000 games with the algorithm taking the first shot and 1000 with it taking the second. The games were two-end games. The figure below shows each algorithm's winning percentage given different amounts of training data. While the self-play-trained DL-UCT outperforms the supervised-only DL-UCT algorithm, the KR-DL-UCT algorithm performs much better.<br />
<br />
<center>[[File:curling_KR_test.png | 400px]]</center><br />
<br />
== Matches ==<br />
<br />
Finally, to test the performance of their algorithms, the authors run matches between their algorithms and other existing programs. Each algorithm plays a 200-game match against each other program: 100 games as the first-playing team and 100 as the second-playing team. Only one program was able to outperform the KR-DRL algorithm. The authors state that this program, ''JiritsukunGAT'17'', also uses a deep network and hand-crafted features. However, the KR-DRL-MES algorithm was still able to outperform it. Figure 4 shows the Elo ratings of the different programs; the programs in blue are those created by the authors. The authors also played 8-end games between KR-DRL-MES and notable programs. Table 1 shows the details of these match results: ''JiritsukunGAT'17'' shows a similar level of performance, but KR-DRL-MES is still the winner.<br />
<br />
<br />
<br />
[[File:curling_ratings.png|600px|thumb|center|Figure 4. Elo rating and winning percentages of our models and GAT rankers. Each match has 200 games (each program plays 100 pre-ordered games), because the player which has the last shot (the hammer shot) in each end would have an advantage.]]<br />
<br />
<br />
[[File:ttt.png|600px|thumb|center|Table 1. The 8-end game results for KR-DRL-MES against other programs alternating the opening player each game. The matches are held by following the rules of the latest GAT competition.]]<br />
<br />
= Conclusion & Critique =<br />
<br />
The authors have presented a new framework which combines a deep neural network for learning game strategy with a kernel-based Monte Carlo tree search over a continuous action space. Without the use of any hand-crafted features, their policy-value network is successfully trained using supervised learning followed by reinforcement learning with a high-fidelity simulator for the Olympic sport of curling. My critiques of the paper follow:<br />
<br />
== Strengths ==<br />
<br />
This algorithm out-performs other high-performance algorithms (including past competition champions).<br />
<br />
I think the paper does a decent job of comparing the performance of their algorithm to others. They are able to clearly show the benefits of many of their additions.<br />
<br />
The authors do seem to be able to adapt strategies similar to those used in Go and other games to the continuous action-space domain. In addition, the final strategy needs no hand-crafted features for learning.<br />
<br />
== Weaknesses ==<br />
<br />
Sometimes I found this paper difficult to follow. One problem was that the algorithms were introduced first, and only afterwards was it described how they were used. So when the paper stated that self-play shots were taken after 400 simulations, it was unclear what simulations were being run and at what stage of the algorithm (e.g., MCTS simulations, simulations sped up by using the value network, or full simulations on the curling simulator). In particular, both the MCTS statistics and the policy-value network could be used to estimate both action probabilities and state values, so it is difficult to tell which is used in which case. There was also no clear distinction between discrete-space actions and continuous-space actions.<br />
<br />
While I think the comparison of different algorithms was done well, I believe it still lacked significant details. Some claims were mentioned only in passing, and it would have been nice to see them backed by results. These include the statement that having a single policy-value network in place of two separate networks led to better performance.<br />
<br />
At this point, the algorithms used still rely on initialization by a pre-made program.<br />
<br />
There was little theoretical development or justification done in this paper.<br />
<br />
While curling is an interesting choice for demonstrating the algorithm, the fact that the simulations used did not support many of the key aspects of curling (ice conditions, sweeping) is a significant limitation. Another game, such as pool, would likely have offered some of the same challenges while allowing higher-fidelity simulation and training.<br />
<br />
While the spatial placements of stones were discretized in a grid, the curl of thrown stones was discretized to only +/-1. This may limit learning high- and low-spin moves. It should be noted that, to the best of my knowledge, throwing with zero spin is uncommon.<br />
<br />
=References=<br />
# Lee, K., Kim, S., Choi, J. & Lee, S. "Deep Reinforcement Learning in Continuous Action Spaces: a Case Study in the Game of Simulated Curling." Proceedings of the 35th International Conference on Machine Learning, in PMLR 80:2937-2946 (2018)<br />
# https://www.baeldung.com/java-monte-carlo-tree-search<br />
# https://jeffbradberry.com/posts/2015/09/intro-to-monte-carlo-tree-search/<br />
# https://int8.io/monte-carlo-tree-search-beginners-guide/<br />
# https://en.wikipedia.org/wiki/Monte_Carlo_tree_search<br />
# Silver, D., Huang, A., Maddison, C., Guez, A., Sifre, L., Van Den Driessche, G., Schrittwieser, J., Antonoglou, I., Panneershelvam, V., Lanctot, M., Dieleman, S., Grewe, D., Nham, J., Kalchbrenner, N., Sutskever, I., Lillicrap, T., Leach, M., Kavukcuoglu, K., Graepel, T., and Hassabis, D. Mastering the game of Go with deep neural networks and tree search. Nature, pp. 484–489, 2016.<br />
# Silver, D., Schrittwieser, J., Simonyan, K., Antonoglou, I., Huang, A., Guez, A., Hubert, T., Baker, L., Lai, M., Bolton, A., Chen, Y., Lillicrap, T., Hui, F., Sifre, L., van den Driessche, G., Graepel, T., and Hassabis, D. Mastering the game of Go without human knowledge. Nature, pp. 354–359, 2017.<br />
# Yamamoto, M., Kato, S., and Iizuka, H. Digital curling strategy based on game tree search. In Proceedings of the IEEE Conference on Computational Intelligence and Games, CIG, pp. 474–480, 2015.<br />
# Ohto, K. and Tanaka, T. A curling agent based on the Monte Carlo tree search considering the similarity of the best action among similar states. In Proceedings of Advances in Computer Games, ACG, pp. 151–164, 2017.<br />
# Ito, T. and Kitasei, Y. Proposal and implementation of digital curling. In Proceedings of the IEEE Conference on Computational Intelligence and Games, CIG, pp. 469–473, 2015.<br />
# Ioffe, S. and Szegedy, C. Batch normalization: Accelerating deep network training by reducing internal covariate shift. In Proceedings of the International Conference on Machine Learning, ICML, pp. 448–456, 2015.<br />
# Nair, V. and Hinton, G. Rectified linear units improve restricted boltzmann machines.</div>Gsahuhttp://wiki.math.uwaterloo.ca/statwiki/index.php?title=Deep_Reinforcement_Learning_in_Continuous_Action_Spaces_a_Case_Study_in_the_Game_of_Simulated_Curling&diff=41966Deep Reinforcement Learning in Continuous Action Spaces a Case Study in the Game of Simulated Curling2018-11-30T00:16:50Z<p>Gsahu: /* Critique */</p>
<hr />
<div>This page provides a summary and critique of the paper '''Deep Reinforcement Learning in Continuous Action Spaces: a Case Study in the Game of Simulated Curling''' [[http://proceedings.mlr.press/v80/lee18b/lee18b.pdf Online Source]], published in ICML 2018. The source code for this paper is available [https://github.com/leekwoon/KR-DL-UCT here]<br />
<br />
= Introduction and Motivation =<br />
<br />
In recent years, Reinforcement Learning methods have been applied to many different games, such as chess and checkers. More recently, the use of CNNs has allowed neural networks to outperform humans in many difficult games, such as Go. However, many of these cases involve a discrete state or action space: the number of actions a player can take and/or the number of possible game states is finite.<br />
<br />
Interacting with the real world (e.g., a scenario that involves moving physical objects) typically involves working with a continuous action space. It is thus important to develop strategies for dealing with continuous action spaces. Deep neural networks that are designed to succeed in finite action spaces are not necessarily suitable for continuous action space problems, because deterministic discretization of a continuous action space causes strong biases in policy evaluation and improvement.<br />
<br />
This paper introduces a method that allows learning with continuous action spaces. A CNN is used to perform learning on discretized state and action spaces, and a continuous action search is then performed starting from these discrete results.<br />
<br />
Curling is chosen as a domain to test the network on. Curling was chosen due to its large action space, potential for complicated strategies, and need for precise interactions.<br />
<br />
== Curling ==<br />
<br />
Curling is a sport played by two teams on a long sheet of ice. Roughly, the goal is for each team to slide rocks closer to the target at the other end of the sheet than the other team. The next sections provide background on gameplay and on potential challenges for learning algorithms, followed by a terminology section.<br />
<br />
=== Game play ===<br />
<br />
A game of curling is divided into ends. In each end, the two teams alternate throwing (sliding) rocks, eight per team, toward the house: a target of concentric rings at the far end of the ice sheet. Rocks must land in a certain area in order to stay in play, and must touch or be inside the concentric rings (12 ft diameter and smaller) in order to score points. At the end of each end, the team with the rock closest to the center of the house scores one point for each of its rocks that is closer to the center than the opponent's closest rock.<br />
<br />
When throwing a rock, the thrower can spin it. This makes the rock 'curl' along its path towards the house and can allow rocks to travel around other rocks. Team members are also able to sweep the ice in front of a moving rock in order to decrease friction, which allows for fine-tuning of distance (though the physics of sweeping are not implemented in the simulation used).<br />
<br />
Curling offers many possible high-level actions, which are directed by a team member to the throwing member. An example set of these includes:<br />
<br />
* Draw: Throw a rock to a target location<br />
* Freeze: Draw a rock up against another rock<br />
* Takeout: Knock another rock out of the house. Can be combined with different ricochet directions<br />
* Guard: Place a rock in front of another, to block other rocks (ex: takeouts)<br />
<br />
=== Challenges for AI ===<br />
<br />
Curling offers many challenges for AI based on its physics and rules. This section lists a few of them.<br />
<br />
The effect of changing actions can be highly nonlinear and discontinuous. This can be seen when considering that a 1-cm deviation in a path can make the difference between a high-speed collision, or lack of collision.<br />
<br />
Curling will require both offensive and defensive strategies. For example, consider the fact that the last team to throw a rock each end only needs to place that rock closer than the opposing team's rocks to score a point and invalidate any opposing rocks in the house. The opposing team should thus be considering how to prevent this from happening, in addition to scoring points themselves.<br />
<br />
Curling also has a concept known as 'the hammer'. The hammer belongs to the team which throws the last rock of an end, which is an advantage, and it passes to the team that did not score in that end (after a blank end, it stays with the same team). It could very well be a good strategy to try not to score a single point in an end (if already ahead in points, etc.), as this would give the advantage to the opposing team.<br />
<br />
Finally, curling has a rule known as the 'Free Guard Zone'. This applies to the first 4 rocks thrown (2 from each team). If they land short of the house, but still in play, then the rocks are not allowed to be removed (via collisions) until all of the first 4 rocks have been thrown.<br />
<br />
=== Terminology ===<br />
<br />
* End: A round of the game<br />
* House: The target at the end of the sheet of ice, which contains the concentric scoring rings<br />
* Hammer: The team that throws the last rock of an end 'has the hammer'<br />
* Hog Line: thick line that is drawn in front of the house, orthogonal to the length of the ice sheet. Rocks must pass this line to remain in play.<br />
* Back Line: thin line drawn just behind the house. Rocks that pass this line are removed from play.<br />
<br />
<br />
== Related Work ==<br />
<br />
=== AlphaGo Lee ===<br />
<br />
AlphaGo Lee (Silver et al., 2016, [5]) refers to an algorithm used to play the game Go, which was able to defeat international champion Lee Sedol. <br />
<br />
<br />
Go game:<br />
* Start with an empty 19x19 board.<br />
* One player takes the black stones and the other takes the white stones.<br />
* The two players take turns placing stones on the board.<br />
* Rules:<br />
*# If a connected group of stones is completely surrounded by the opponent's stones, it is removed from the board.<br />
*# Ko rule: a move may not repeat a previous board position.<br />
* The game ends when there are no valuable moves left on the board.<br />
* The territory of both players is counted.<br />
* 7.5 points are added to White's score (called komi).<br />
[[File:go.JPG|700px|center]]<br />
<br />
Two neural networks were trained on the moves of human experts, to act as both a policy network and a value network. A Monte Carlo Tree Search algorithm was used for policy improvement.<br />
<br />
The AlphaGo Lee policy network predicts the best move given a board configuration. It has a CNN architecture with 13 hidden layers, and it is trained using expert game play data and improved through self-play.<br />
<br />
The value network evaluates the probability of winning given a board configuration. It consists of a CNN with 14 hidden layers, and it is trained using self-play data from the policy network. <br />
<br />
Finally, the two networks are combined using Monte Carlo Tree Search, which performs lookahead search to select the actions for game play.<br />
<br />
The use of both policy and value networks is reflected in this paper's work.<br />
<br />
=== AlphaGo Zero ===<br />
<br />
AlphaGo Zero (Silver et al., 2017, [6]) is an improvement on the AlphaGo Lee algorithm. AlphaGo Zero uses a unified neural network in place of the separate policy and value networks and is trained on self-play, without the need for expert training data.<br />
<br />
The unification of networks and self-play are also reflected in this paper.<br />
<br />
=== Curling Algorithms ===<br />
<br />
Some past algorithms have been proposed to deal with continuous action spaces. For example, Yamamoto et al. (2015, [7]) use game tree search methods in a discretized space. The value of an action is taken as the average of nearby values, taking into account some knowledge of execution uncertainty.<br />
<br />
=== Monte Carlo Tree Search ===<br />
<br />
Monte Carlo Tree Search algorithms have been applied to continuous action spaces. These algorithms, discussed in further detail below, balance exploration of different states with knowledge of paths of execution through past games. One such MCTS variant, called KR-UCT, finds effective selections and uses kernel regression (KR) and kernel density estimation (KDE) to estimate rewards from neighborhood information; it has been applied to continuous action spaces.<br />
<br />
For bandit problems, hierarchical optimistic optimization (HOO) has been used to build a cover tree that divides the action space into small ranges at different depths, so that the most promising nodes receive fine-grained estimates.<br />
<br />
=== Curling Physics and Simulation ===<br />
<br />
Several references in the paper refer to the study and simulation of curling physics. Researchers have analyzed the friction coefficients between curling stones and ice. Because modelling changes in ice friction is not feasible, a fixed friction coefficient was predefined in the simulation. The behavior of the stones was also modeled, with important parameters estimated from the shots of professional players. The authors used the same parameters in this paper.<br />
<br />
== General Background of Algorithms ==<br />
<br />
=== Policy and Value Functions ===<br />
<br />
A policy function is trained to provide the best action to take, given a current state. Policy iteration is an algorithm used to improve a policy over time. This is done by alternating between policy evaluation and policy improvement.<br />
<br />
POLICY IMPROVEMENT: LEARNING ACTION POLICY<br />
<br />
Action policy <math> p_{\sigma}(a|s) </math> outputs a probability distribution over all eligible moves <math> a </math>. Here <math> \sigma </math> denotes the weights of a neural network that approximates the policy, <math>s</math> denotes a state, and <math>a</math> denotes an action taken in the environment. The policy is thus a function that maps the agent's current state to a distribution over actions. Policy gradient reinforcement learning can be used to train the action policy. It is updated by stochastic gradient ascent in the direction that maximizes the expected outcome at each time step ''t'',<br />
<math> \Delta \sigma \propto \frac{\partial \log p_{\sigma}(a_t|s_t)}{\partial \sigma} r(s_t) </math><br />
where <math> r(s_t) </math> is the return.<br />
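<br />
A toy, stateless illustration of the policy-gradient update, using the standard REINFORCE form with the gradient of <math>\log p_{\sigma}</math>: a softmax policy over three actions is trained so that <math>\Delta \sigma \propto \partial \log p_{\sigma}(a) / \partial \sigma \cdot r</math>. The bandit setting, learning rate, and reward values are hypothetical.<br />
<br />
```python
import math
import random


def softmax(prefs):
    """Numerically stable softmax over action preferences."""
    m = max(prefs)
    exps = [math.exp(x - m) for x in prefs]
    total = sum(exps)
    return [e / total for e in exps]


def reinforce_step(sigma, action, reward, lr=0.1):
    """sigma += lr * r * d log p_sigma(a) / d sigma, where the gradient is onehot(a) - p."""
    probs = softmax(sigma)
    return [s + lr * reward * ((1.0 if j == action else 0.0) - probs[j])
            for j, s in enumerate(sigma)]


def train_bandit(rewards, steps=2000, seed=0):
    """Sample actions from the current policy and apply one update per sample."""
    rng = random.Random(seed)
    sigma = [0.0] * len(rewards)
    for _ in range(steps):
        probs = softmax(sigma)
        a = rng.choices(range(len(rewards)), weights=probs)[0]
        sigma = reinforce_step(sigma, a, rewards[a])
    return softmax(sigma)
```
<br />
With rewards <code>[0.0, 0.0, 1.0]</code> (a "win" only for the third action), only winning actions change the weights, so the policy steadily concentrates on that action.<br />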
<br />
POLICY EVALUATION: LEARNING VALUE FUNCTIONS<br />
<br />
A value function <math> v_{\theta}(s) </math>, with parameters <math> \theta </math>, is trained to estimate the value of being in a certain state. It is trained on records of state-return pairs <math> (s, r(s)) </math> by using stochastic gradient descent to minimize the mean squared error (MSE) between the predicted value and the corresponding outcome,<br />
<math> \Delta \theta \propto \frac{\partial v_{\theta}(s)}{\partial \theta}(r(s)-v_{\theta}(s)) </math><br />
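<br />
A minimal sketch of this value update with a linear value function <math>v_{\theta}(s) = \theta^\top \phi(s)</math>, for which <math>\partial v_{\theta} / \partial \theta = \phi(s)</math>. The features and toy returns are hypothetical; the paper uses a CNN, not a linear model.<br />
<br />
```python
def value(theta, features):
    """Linear value estimate v_theta(s) = theta . phi(s)."""
    return sum(t * f for t, f in zip(theta, features))


def value_sgd_step(theta, features, target, lr=0.1):
    """theta += lr * (r(s) - v_theta(s)) * dv/dtheta, where dv/dtheta = phi(s)."""
    error = target - value(theta, features)
    return [t + lr * error * f for t, f in zip(theta, features)]


# Toy data (hypothetical): two one-hot "states" with returns +1 and -1.
data = [([1.0, 0.0], 1.0), ([0.0, 1.0], -1.0)]
theta = [0.0, 0.0]
for _ in range(200):
    for feats, r in data:
        theta = value_sgd_step(theta, feats, r)
```
<br />
After these updates, <code>theta</code> is very close to <code>[1.0, -1.0]</code>, i.e. the value estimates match the returns of the two states.<br />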
<br />
=== Monte Carlo Tree Search ===<br />
<br />
Monte Carlo Tree Search (MCTS) is a search algorithm used for finite-horizon tasks (e.g., in curling, only 16 moves, i.e. stone throws, are taken in each end).<br />
<br />
MCTS is a tree search algorithm similar to minimax. However, MCTS is probabilistic and does not need to explore a full game tree or even a tree reduced with alpha-beta pruning. This makes it tractable for games such as Go and curling.<br />
<br />
Nodes of the tree are game states, and branches represent actions. Each node stores statistics on how many times it has been visited by the MCTS, as well as the number of wins encountered by playouts from that position. A node is considered 'visited' if a full playout has started from that node. A node is considered 'fully expanded' if all its children have been visited.<br />
<br />
MCTS begins with the '''selection''' phase, which traverses already-known states and actions. Starting at the root node, the algorithm repeatedly selects the child with the highest 'score', descending until it reaches a node that is not fully expanded (or a leaf node).<br />
<br />
The next phase, '''expansion''', begins when the algorithm reaches a node where not all children have been visited (i.e., the node has not been fully expanded). In the expansion phase, an unvisited child of that node is added to the tree.<br />
<br />
Once the new child is expanded, '''simulation''' takes place. This refers to a full playout of the game from the point of the current node, and can involve many strategies, such as randomly taken moves, the use of heuristics, etc.<br />
<br />
The final phase is '''update''' or '''back-propagation''' (unrelated to the neural network algorithm). In this phase, the result of the '''simulation''' (i.e., win/lose) is used to update the statistics of all nodes on the path back to the root.<br />
<br />
A selection function known as Upper Confidence Bounds applied to Trees (UCT) can be used to decide which node to select. The formula is shown below [[https://www.baeldung.com/java-monte-carlo-tree-search source]]. Note that the first term is the average score of games played through a given node. The second term grows as the parent's visit count <math>t</math> increases while the node's own visit count <math>n_i</math> stays fixed. This means that rarely visited nodes gradually increase their UCT score and will be selected in the future.<br />
<br />
<math> \frac{w_i}{n_i} + c \sqrt{\frac{\ln t}{n_i}} </math><br />
<br />
In which<br />
<br />
* <math> w_i = </math> number of wins after <math> i</math>th move<br />
* <math> n_i = </math> number of simulations after <math> i</math>th move<br />
* <math> c = </math> exploration parameter (theoretically equal to <math> \sqrt{2}</math>)<br />
* <math> t = </math> total number of simulations for the parent node<br />
<br />
<br />
Sources: 2,3,4<br />
<br />
[[File:MCTS_Diagram.jpg | 500px|center]]<br />
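<br />
To tie the four phases and the UCT formula together, here is a small, self-contained MCTS sketch on a toy game (Nim: players alternately remove 1 to 3 stones, and whoever takes the last stone wins), using random playouts for simulation. The game, iteration count, and seed are illustrative choices unrelated to curling.<br />
<br />
```python
import math
import random


def uct_score(wins, visits, parent_visits, c=math.sqrt(2)):
    """UCT value w_i / n_i + c * sqrt(ln t / n_i); unvisited children come first."""
    if visits == 0:
        return float("inf")
    return wins / visits + c * math.sqrt(math.log(parent_visits) / visits)


class Node:
    """Toy Nim state: `pile` stones remain; a move removes 1 to 3 stones."""

    def __init__(self, pile, parent=None, move=None):
        self.pile = pile
        self.parent = parent
        self.move = move
        self.children = []
        self.untried = [m for m in (1, 2, 3) if m <= pile]
        self.wins = 0.0  # wins for the player who made `move`
        self.visits = 0


def rollout(pile, rng):
    """Random playout; True if the player to move at `pile` takes the last stone."""
    mover = 0
    while pile > 0:
        pile -= rng.choice([m for m in (1, 2, 3) if m <= pile])
        if pile == 0:
            return mover == 0
        mover ^= 1
    return False  # empty pile: the previous player already took the last stone


def mcts_best_move(pile, iterations=3000, seed=0):
    """Run MCTS from a pile of at least one stone and return the recommended move."""
    rng = random.Random(seed)
    root = Node(pile)
    for _ in range(iterations):
        node = root
        # 1) Selection: descend through fully expanded nodes by maximum UCT.
        while not node.untried and node.children:
            parent = node
            node = max(parent.children,
                       key=lambda ch: uct_score(ch.wins, ch.visits, parent.visits))
        # 2) Expansion: add one unvisited child.
        if node.untried:
            move = node.untried.pop(rng.randrange(len(node.untried)))
            child = Node(node.pile - move, parent=node, move=move)
            node.children.append(child)
            node = child
        # 3) Simulation: random playout from the new node.
        mover_wins = rollout(node.pile, rng)
        # 4) Backpropagation: the result alternates between players up the path.
        credit = not mover_wins  # result for the player who made node.move
        while node is not None:
            node.visits += 1
            node.wins += 1.0 if credit else 0.0
            credit = not credit
            node = node.parent
    # Recommend the most visited move at the root.
    return max(root.children, key=lambda ch: ch.visits).move
```
<br />
From a pile of 5 the search should recommend taking 1 stone (leaving the opponent a multiple of 4), and from a pile of 6 taking 2.<br />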
<br />
=== Kernel Regression ===<br />
<br />
Kernel regression is a form of weighted averaging which uses a kernel function as a weight to estimate the conditional expectation of a random variable. Given two data points, '''x''' and '''x<sub>i</sub>''', each with an associated value '''y''', a kernel '''K''' outputs a weighting factor that is larger the closer the two points are. An estimate of the value at a new, unseen point is then calculated as the weighted average of the values of surrounding points.<br />
<br />
A typical kernel is a Gaussian kernel, shown below. The formula for calculating estimated value is shown below as well (sources: Lee et al.).<br />
<br />
[[File:gaussian_kernel.png | 400 px]]<br />
<br />
[[File:kernel_regression.png | 250 px]]<br />
<br />
The denominator of the conditional expectation is related to kernel density estimation, which is defined as <math display="inline">W(x)=\sum_{i=0}^n K(x,x_i)</math>.<br />
<br />
Together, the two weigh the scores of samples closest to '''x''' most strongly.<br />
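<br />
A one-dimensional sketch of kernel regression (the Nadaraya-Watson estimator) and the kernel density term <math>W(x)</math>, assuming a Gaussian kernel parameterized as <math>K(x, x_i) = \exp(-(x - x_i)^2 / 2h^2)</math> with a hypothetical bandwidth ''h''.<br />
<br />
```python
import math


def gaussian_kernel(x, xi, h=1.0):
    """Gaussian kernel with bandwidth h."""
    return math.exp(-((x - xi) ** 2) / (2.0 * h * h))


def kernel_density(x, xs, h=1.0):
    """W(x) = sum_i K(x, x_i): the unnormalized kernel density estimate."""
    return sum(gaussian_kernel(x, xi, h) for xi in xs)


def kernel_regression(x, xs, ys, h=1.0):
    """Nadaraya-Watson estimate: sum_i K(x, x_i) * y_i / sum_i K(x, x_i)."""
    weights = [gaussian_kernel(x, xi, h) for xi in xs]
    return sum(w * y for w, y in zip(weights, ys)) / sum(weights)
```
<br />
For example, with samples at 0 and 2 whose values are 0 and 4, the estimate at <code>x = 1</code> is 2, since both samples are equally close; a smaller bandwidth makes the estimate follow the nearest samples more closely.<br />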
<br />
= Methods =<br />
<br />
== Variable Definitions ==<br />
<br />
The following variables are used often in the paper:<br />
<br />
* <math>s</math>: A state in the game, as described below as the input to the network.<br />
* <math>s_t</math>: The state at a certain time-step of the game. Time-steps refer to full turns in the game<br />
* <math>a_t</math>: The action taken in state <math>s_t</math><br />
* <math>A_t</math>: The actions taken for sibling nodes related to <math>a_t</math> in MCTS<br />
* <math>n_{a_t}</math>: The number of visits to node <math>a_t</math> in MCTS<br />
* <math>v_{a_t}</math>: The MCTS value estimate of a node<br />
<br />
== Network Design ==<br />
<br />
The authors design a CNN called the 'policy-value' network. The network consists of a common network structure, which is then split into 'policy' and 'value' outputs. This network is trained to learn a probability distribution of actions to take, and expected rewards, given an input state.<br />
<br />
=== Shared Structure ===<br />
<br />
The network consists of 1 convolutional layer followed by 9 residual blocks, each block consisting of 2 convolutional layers with 32 3x3 filters. The structure of this network is shown below:<br />
<br />
<br />
[[File:curling_network_layers.png|600px|thumb|center|Figure 2. A detailed description of our policy-value network. The shared network is composed of one convolutional layer and nine residual blocks. Each residual block (explained in b) has two convolutional layers with batch normalization (Ioffe & Szegedy, 2015[11]) followed by the addition of the input and the residual block. Each layer in the shared network uses 3x3 filters. The policy head has two more convolutional layers, while the value head has two fully connected layers on top of a convolutional layer. For the activation function of each convolutional layer, ReLU (Nair & Hinton[12]) is used.]]<br />
<br />
<br />
<br />
The input to this network is the following:<br />
* Location of stones<br />
* Order to tee (stones ordered by distance to the tee, the center of the house)<br />
* A 32x32 grid representation of the ice sheet, indicating which stones are present in each grid cell<br />
<br />
The authors do not describe how the stone-based information is added to the 32x32 grid as input to the network.<br />
<br />
=== Policy Network ===<br />
<br />
The policy head is created by adding two convolutional layers, each with two 3x3 filters, to the main body of the network. The output of the policy head is a probability distribution over a 32x32x2 set of discrete actions, from which the best shot is selected. The actions represent target locations in the grid and the spin direction of the stone.<br />
<br />
[[File:policy-value-net.PNG | 700px]]<br />
<br />
=== Value Network ===<br />
<br />
The value head is created by adding a convolutional layer with one 3x3 filter, and dense layers of 256 and 17 units, to the shared network. The 17 output units represent a probability distribution over scores in the range [-8, 8], which are the possible scores in a single end of a curling game.<br />
<br />
== Continuous Action Search ==<br />
<br />
The policy head of the network only outputs actions from a discretized action space. For real-life interactions, and especially in curling, this will not suffice, as very fine adjustments to actions can make significant differences in outcomes.<br />
<br />
Actions in the continuous space are generated using an MCTS algorithm, with the following steps:<br />
<br />
=== Selection ===<br />
<br />
From a given state, the set of already-visited actions is denoted <math>A_t</math>. The expected score and the number of visits for each node are estimated using the equations below (the first equation gives the expected end value for one-end games). These quantities are likely estimated by kernel regression and kernel density estimation, rather than read directly from the MCTS statistics, because in a continuous action space almost every sampled action is distinct, so raw per-action counts would be very sparse.<br />
<br />
[[File:curling_kernel_equations.png | 400px]]<br />
<br />
The UCB formula is then used to select an action to expand.<br />
<br />
The actions that are taken in the simulator appear to be drawn from a Gaussian centered around <math>a_t</math>. This allows exploration in the continuous action space.<br />
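The kernel equations are only given in the figure, so this sketch assumes the KR-UCT form. The value estimate is <math>\hat v(a)=\sum_b K(a,b)\,v_b n_b / \sum_b K(a,b)\,n_b</math>, the kernel-smoothed visit count is <math>W(a)=\sum_b K(a,b)\,n_b</math>, and the exploration bonus is <math>C\sqrt{\log \sum_b W(b) / W(a)}</math>. Actions are simplified to 2-D target positions with spin omitted. The bandwidth, <math>C</math>, and the Gaussian execution noise are hypothetical values.<br />

```python
import math
import random


def gaussian_kernel(a, b, bandwidth=0.1):
    """Unnormalized Gaussian kernel between two action vectors."""
    d2 = sum((ai - bi) ** 2 for ai, bi in zip(a, b))
    return math.exp(-d2 / (2 * bandwidth ** 2))


def kr_ucb_select(actions, values, visits, c=1.0, bandwidth=0.1):
    """Return the index of the visited action with the highest kernel-regression UCB score.

    actions: list of action tuples already in the tree (A_t)
    values:  mean value observed for each action
    visits:  visit count for each action (each must be >= 1)
    """
    smoothed_visits, estimates = [], []
    for a in actions:
        ks = [gaussian_kernel(a, b, bandwidth) for b in actions]
        w = sum(k * n for k, n in zip(ks, visits))  # W(a), >= n_a because K(a, a) = 1
        smoothed_visits.append(w)
        estimates.append(sum(k * v * n for k, v, n in zip(ks, values, visits)) / w)
    total = sum(smoothed_visits)
    scores = [e + c * math.sqrt(math.log(total) / w)
              for e, w in zip(estimates, smoothed_visits)]
    return max(range(len(actions)), key=scores.__getitem__)


def execute_with_noise(action, sigma=0.05, rng=random):
    """Perturb the chosen action with Gaussian execution noise before simulating it."""
    return tuple(x + rng.gauss(0.0, sigma) for x in action)
```

The kernel lets a well-visited action lend its statistics to nearby, rarely visited actions. That is what makes UCB-style selection usable when no two continuous actions are exactly the same.<br />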
<br />
=== Expansion ===<br />
<br />
The authors use a variant of regular UCT for expansion. In this case, they expand a new node only when existing nodes have been visited a certain number of times. The authors utilize a widening approach to overcome problems with standard UCT performing a shallow search when there is a large action space.<br />
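The exact widening criterion is not given here. The sketch below uses a common progressive-widening rule as an assumption: a node with <math>N</math> visits may hold at most about <math>k N^{\alpha}</math> children, so new actions are only added once the existing ones have been visited enough. The constants <math>k = 1</math> and <math>\alpha = 0.5</math> are illustrative.<br />

```python
def should_expand(num_children, node_visits, k=1.0, alpha=0.5):
    """Progressive widening: allow a new child only while |children| < k * N^alpha."""
    return num_children < k * max(node_visits, 1) ** alpha


def children_after(total_visits, k=1.0, alpha=0.5):
    """Count how many children a node accumulates over `total_visits` visits."""
    children = 0
    for n in range(1, total_visits + 1):
        if should_expand(children, n, k, alpha):
            children += 1
    return children
```

With these constants, a node visited 100 times has only 10 children, so the search goes deeper instead of spreading thinly over a huge action space.<br />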
<br />
=== Simulation ===<br />
<br />
Instead of simulating with a random game playout, the authors use the value network to estimate the likely score associated with a state. This speeds up simulation (assuming the network is well trained), as the game does not actually need to be simulated.<br />
<br />
=== Backpropagation ===<br />
<br />
Standard backpropagation is used, updating both the values and number of visits stored in the path of parent nodes.<br />
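The simulation and backpropagation steps can be sketched together. The leaf is scored by a value estimator instead of a playout, and that score is added to every node on the path back to the root. `value_fn` is a hypothetical stand-in for the value network. Values are assumed to be stored from one fixed team's perspective, so any sign flip for the opponent's turns would happen during selection.<br />

```python
class Node:
    """Minimal search-tree node holding MCTS statistics."""

    def __init__(self, parent=None):
        self.parent = parent
        self.visits = 0
        self.value_sum = 0.0

    @property
    def value(self):
        """Mean backed-up value (0.0 before the first visit)."""
        return self.value_sum / self.visits if self.visits else 0.0


def evaluate_leaf(state, value_fn):
    """Simulation step: score the leaf with a value estimator instead of a full playout."""
    return value_fn(state)


def backpropagate(leaf, value):
    """Add the leaf value and one visit to every node on the path back to the root."""
    node = leaf
    while node is not None:
        node.visits += 1
        node.value_sum += value
        node = node.parent
```
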
<br />
<br />
== Supervised Learning ==<br />
<br />
During supervised training, data is gathered from the program AyumuGAT'16 ([8]), a high-performance AI curling program that is also based on an MCTS algorithm. In total, 400,000 state-action pairs were generated for this training.<br />
<br />
=== Policy Network ===<br />
<br />
The policy network was trained to learn the action taken in each state. Here, the likelihood of the taken action was set to be 1, and the likelihood of other actions to be 0.<br />
<br />
=== Value Network ===<br />
<br />
The value network was trained using 'd-depth simulations and bootstrapping of the prediction to handle the high variance in rewards resulting from a sequence of stochastic moves' (quote taken from the paper). Here, ''m'' state-action pairs were sampled from the training data. For each pair <math>(s_t, a_t)</math>, a state ''d'' steps ahead, <math>s_{t+d}</math>, was generated. Uncertainty was handled by treating all actions in this rollout as having no uncertainty, except for the last action, ''a<sub>t+d-1</sub>''. The value network then predicts the value of this future state, <math>z_t</math>, and this prediction is used as the target for learning the value at ''s<sub>t</sub>''.<br />
<br />
=== Policy-Value Network ===<br />
<br />
The policy-value network was trained so that its predicted policy and value match the target policy and value for each state. The learning algorithm parameters are:<br />
<br />
* Algorithm: stochastic gradient descent<br />
* Batch size: 256<br />
* Momentum: 0.9<br />
* L2 regularization: 0.0001<br />
* Training time: ~100 epochs<br />
* Learning rate: initialized at 0.01, reduced twice<br />
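The listed hyperparameters correspond to a standard momentum SGD update with L2 regularization, sketched below on plain Python lists. Only "reduced twice" is stated, so the reduction epochs (40 and 70) and the reduction factor of 10 are assumptions.<br />

```python
def sgd_momentum_step(params, grads, velocity, lr, momentum=0.9, l2=1e-4):
    """In-place SGD step with classical momentum.

    The L2 term adds l2 * p to each gradient, i.e. it is the gradient of (l2 / 2) * ||theta||^2.
    """
    for i, (p, g) in enumerate(zip(params, grads)):
        g_total = g + l2 * p
        velocity[i] = momentum * velocity[i] - lr * g_total
        params[i] = p + velocity[i]


def learning_rate(epoch, base=0.01, milestones=(40, 70), factor=0.1):
    """Step schedule: start at 0.01 and divide by 10 at each (assumed) milestone epoch."""
    drops = sum(1 for m in milestones if epoch >= m)
    return base * factor ** drops
```
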
<br />
A multi-task loss function was used. It is the sum of the cross-entropy losses of the policy and value predictions:<br />
<br />
[[File:curling_loss_function.png | 300px]]<br />
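The multi-task loss can be sketched as the sum of two cross-entropies. The first compares the target action distribution <math>\pi</math> with the policy output, and the second compares the target score distribution <math>z</math> with the value output. During supervised learning the policy target is one-hot, as described above. Here the L2 term is left to the optimizer, and the function names are illustrative.<br />

```python
import math


def one_hot(index, size):
    """Supervised target: probability 1 for the taken action, 0 elsewhere."""
    return [1.0 if i == index else 0.0 for i in range(size)]


def cross_entropy(target, predicted, eps=1e-12):
    """-sum_i target_i * log(predicted_i), clipped to avoid log(0)."""
    return -sum(t * math.log(max(p, eps)) for t, p in zip(target, predicted))


def policy_value_loss(pi_target, p_pred, z_target, v_pred):
    """Multi-task loss: policy cross-entropy plus value (score-distribution) cross-entropy."""
    return cross_entropy(pi_target, p_pred) + cross_entropy(z_target, v_pred)
```
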
<br />
== Self-Play Reinforcement Learning ==<br />
<br />
After initialization by supervised learning, the algorithm uses self-play to further train itself. During this training, the policy network learns probabilities from the MCTS process, while the value network learns from game outcomes.<br />
<br />
At a game state ''s<sub>t</sub>'':<br />
<br />
1) The algorithm outputs a prediction ''z<sub>t</sub>''. This is an estimate of game score probabilities. It is based on similar past actions and computed using kernel regression.<br />
<br />
2) The algorithm outputs a prediction <math>\pi_t</math>, representing a probability distribution of actions. These probabilities are proportional to estimated visit counts from MCTS, based on kernel density estimation.<br />
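This is a sketch of point 2). It assumes <math>\pi_t</math> is proportional to the kernel density estimate of visit counts <math>W(a)=\sum_b K(a,b)\,n_b</math>, evaluated at a set of candidate actions such as the discrete grid targets. The bandwidth and the function name `policy_target` are assumptions.<br />

```python
import math


def gaussian_kernel(a, b, bandwidth):
    """Unnormalized Gaussian kernel between two action vectors."""
    d2 = sum((x - y) ** 2 for x, y in zip(a, b))
    return math.exp(-d2 / (2 * bandwidth ** 2))


def policy_target(candidates, visited, visits, bandwidth=0.5):
    """pi(a) proportional to W(a) = sum_b K(a, b) * n_b over the visited MCTS actions."""
    w = [sum(gaussian_kernel(a, b, bandwidth) * n for b, n in zip(visited, visits))
         for a in candidates]
    total = sum(w)
    return [x / total for x in w]
```
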
<br />
It is not clear how these predictions are created. It would seem likely that the policy-value network generates these, but the wording of the paper suggests they are generated from MCTS statistics.<br />
<br />
The policy-value network is updated by sampling data <math>(s, \pi, z)</math> from recent history of self-play. The same loss function is used as before.<br />
<br />
It is not clear how the improved network is used, as MCTS seems to be the driving process at this point.<br />
<br />
== Long-Term Strategy Learning ==<br />
<br />
Finally, the authors implement a new strategy to augment their algorithm for long-term play. In this context, this refers to playing a game over many ends, where the strategy to win a single end may not be a good strategy to win a full game. For example, scoring one point in an end while one point ahead gives the advantage to the other team in the next end (as they will throw the last stone). The other team could then use this advantage to score two points and take the lead.<br />
<br />
The authors build a 'winning percentage' table. This table stores the percentage of games won, based on the number of ends left and the difference in score (current team minus opposing team). It can be computed iteratively from the estimated probability distribution of one-end scores.<br />
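The iterative computation can be sketched as a small dynamic program. It makes two simplifying assumptions. First, the same one-end score distribution applies in every end, ignoring which team holds the hammer. Second, a tie after the last end counts as half a win instead of going to an extra end.<br />

```python
from functools import lru_cache


def make_win_table(one_end_probs):
    """one_end_probs: dict {score for the current team: probability} for a single end.

    Returns win(ends_left, diff): probability that the current team wins the game when
    `ends_left` ends remain and it leads by `diff` points.
    """

    @lru_cache(maxsize=None)
    def win(ends_left, diff):
        if ends_left == 0:
            # A tie counts as half a win (assumption, in place of an extra end).
            return 1.0 if diff > 0 else (0.5 if diff == 0 else 0.0)
        return sum(p * win(ends_left - 1, diff + s) for s, p in one_end_probs.items())

    return win
```
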
<br />
== Final Algorithms ==<br />
<br />
The authors make use of the following versions of their algorithm:<br />
<br />
=== KR-DL ===<br />
<br />
''Kernel regression-deep learning'': This algorithm is trained only by supervised learning.<br />
<br />
=== KR-DRL ===<br />
<br />
''Kernel regression-deep reinforcement learning'': This algorithm is trained by supervised learning (i.e., initialized as the KR-DL algorithm) and then further trained by self-play. During self-play, each shot is selected after 400 MCTS simulations of k=20 randomly selected actions. Self-play data was collected over a week on 5 GPUs, generating 5 million game positions. The policy-value network was continually updated using samples from the latest 1 million game positions.<br />
<br />
=== KR-DRL-MES ===<br />
<br />
''Kernel regression-deep reinforcement learning-multi-ends-strategy'': This algorithm makes use of the winning percentage table generated from self-play.<br />
<br />
= Testing and Results =<br />
The authors use data from the public program AyumuGAT’16 to test. Testing is done with a simulated curling program [9]. This simulator does not deal with changing ice conditions, or sweeping, but does deal with stone trajectories and collisions.<br />
<br />
== Comparison of KR-DL-UCT and DL-UCT ==<br />
<br />
The first test compares an algorithm trained with kernel regression to an algorithm trained without it, to show the contribution that kernel regression makes to performance. Both algorithms have networks initialized with supervised learning and are then trained with two different self-play algorithms. KR-DL-UCT uses the algorithm described above. The authors do not go into detail on how DL-UCT selects shots, but state that a constant is set to allow exploration.<br />
<br />
As an evaluation, both algorithms play 2000 games against the DL-UCT algorithm that is frozen after supervised training: 1000 games with the evaluated algorithm taking the first shot and 1000 with it taking the second shot. The games were two-end games. The figure below shows each algorithm's winning percentage for different amounts of training data. While DL-UCT trained with self-play outperforms the supervised-only DL-UCT, the KR-DL-UCT algorithm performs much better.<br />
<br />
<center>[[File:curling_KR_test.png | 400px]]</center><br />
<br />
== Matches ==<br />
<br />
Finally, to test the performance of their algorithms, the authors run matches between their algorithms and other existing programs. Each match consists of 200 games against another program, 100 played as the first-playing team and 100 as the second-playing team. Only one program, ''JiritsukunGAT'17'', was able to outperform the KR-DRL algorithm; the authors state that it also uses a deep network and hand-crafted features. However, the KR-DRL-MES algorithm was still able to outperform it. Figure 4 shows the Elo ratings of the different programs; the programs in blue are those created by the authors. The authors also played games between KR-DRL-MES and notable programs. Table 1 shows the details of these match results: ''JiritsukunGAT'17'' shows a similar level of performance, but KR-DRL-MES is still the winner.<br />
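The standard logistic Elo model can help in reading Figure 4, since it relates a rating gap to an expected win rate. The page does not say exactly how the ratings were fitted, so this is only the textbook relationship.<br />

```python
import math


def elo_expected(rating_a, rating_b):
    """Expected score of A against B under the logistic Elo model."""
    return 1.0 / (1.0 + 10 ** ((rating_b - rating_a) / 400))


def elo_gap(win_rate):
    """Rating difference implied by an observed win rate (0 < win_rate < 1)."""
    return -400 * math.log10(1 / win_rate - 1)
```

For example, winning 150 of the 200 games in a match corresponds to a gap of about 191 rating points.<br />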
<br />
<br />
<br />
[[File:curling_ratings.png|600px|thumb|center|Figure 4. Elo rating and winning percentages of our models and GAT rankers. Each match has 200 games (each program plays 100 pre-ordered games), because the player which has the last shot (the hammer shot) in each end would have an advantage.]]<br />
<br />
<br />
[[File:ttt.png|600px|thumb|center|Table 1. The 8-end game results for KR-DRL-MES against other programs alternating the opening player each game. The matches are held by following the rules of the latest GAT competition.]]<br />
<br />
= Conclusion & Critique =<br />
<br />
The authors have presented a new framework which incorporates a deep neural network for learning game strategy with a kernel-based Monte Carlo tree search from a continuous space. Without the use of any hand-crafted feature, their policy-value network is successfully trained using supervised learning followed by reinforcement learning with a high-fidelity simulator for the Olympic sport of curling.<br />
<br />
== Strengths ==<br />
<br />
This algorithm out-performs other high-performance algorithms (including past competition champions).<br />
<br />
I think the paper does a decent job of comparing the performance of their algorithm to others. They are able to clearly show the benefits of many of their additions.<br />
<br />
The authors do seem to be able to adapt strategies similar to those used in Go and other games to the continuous action-space domain. In addition, the final strategy needs no hand-crafted features for learning.<br />
<br />
== Weaknesses ==<br />
<br />
At times, I found this paper difficult to follow. One problem was that the algorithms were introduced first, and only then was it described how they were used. So when the paper stated that self-play shots were taken after 400 simulations, it was unclear what simulations were being run and at what stage of the algorithm (e.g., MCTS simulations, simulations sped up by using the value network, or full simulations on the curling simulator). In particular, both the MCTS statistics and the policy-value network could be used to estimate both action probabilities and state values, so it is difficult to tell which is used in which case. There was also no clear distinction between discrete-space actions and continuous-space actions.<br />
<br />
While I think the comparison of different algorithms was done well, I believe it still lacked significant details. Some claims were mentioned only in passing and would have been nice to see as results, such as the statement that using a single policy-value network in place of two separate networks led to better performance.<br />
<br />
At this point, the algorithms used still rely on initialization by a pre-made program.<br />
<br />
There was little theoretical development or justification done in this paper.<br />
<br />
While curling is an interesting choice for demonstrating the algorithm, the fact that the simulations used did not support many of the key points of curling (ice conditions, sweeping) seems very limited. Another game, such as pool, would likely have offered some of the same challenges but offered more high-fidelity simulations/training.<br />
<br />
While the spatial placements of stones were discretized in a grid, the curl of thrown stones was discretized to only +/-1. This may limit learning high- and low-spin moves. It should be noted that, to the best of my knowledge, throws with zero spin are not commonly used.<br />
<br />
=References=<br />
# Lee, K., Kim, S., Choi, J. & Lee, S. "Deep Reinforcement Learning in Continuous Action Spaces: a Case Study in the Game of Simulated Curling." Proceedings of the 35th International Conference on Machine Learning, in PMLR 80:2937-2946 (2018)<br />
# https://www.baeldung.com/java-monte-carlo-tree-search<br />
# https://jeffbradberry.com/posts/2015/09/intro-to-monte-carlo-tree-search/<br />
# https://int8.io/monte-carlo-tree-search-beginners-guide/<br />
# https://en.wikipedia.org/wiki/Monte_Carlo_tree_search<br />
# Silver, D., Huang, A., Maddison, C., Guez, A., Sifre, L., Van Den Driessche, G., Schrittwieser, J., Antonoglou, I., Panneershelvam, V., Lanctot, M., Dieleman, S., Grewe, D., Nham, J., Kalchbrenner, N., Sutskever, I., Lillicrap, T., Leach, M., Kavukcuoglu, K., Graepel, T., and Hassabis, D. Mastering the game of Go with deep neural networks and tree search. Nature, pp. 484–489, 2016.<br />
# Silver, D., Schrittwieser, J., Simonyan, K., Antonoglou, I., Huang, A., Guez, A., Hubert, T., Baker, L., Lai, M., Bolton, A., Chen, Y., Lillicrap, T., Hui, F., Sifre, L., van den Driessche, G., Graepel, T., and Hassabis, D. Mastering the game of Go without human knowledge. Nature, pp. 354–359, 2017.<br />
# Yamamoto, M., Kato, S., and Iizuka, H. Digital curling strategy based on game tree search. In Proceedings of the IEEE Conference on Computational Intelligence and Games, CIG, pp. 474–480, 2015.<br />
# Ohto, K. and Tanaka, T. A curling agent based on the Monte Carlo tree search considering the similarity of the best action among similar states. In Proceedings of Advances in Computer Games, ACG, pp. 151–164, 2017.<br />
# Ito, T. and Kitasei, Y. Proposal and implementation of digital curling. In Proceedings of the IEEE Conference on Computational Intelligence and Games, CIG, pp. 469–473, 2015.<br />
# Ioffe, S. and Szegedy, C. Batch normalization: Accelerating deep network training by reducing internal covariate shift. In Proceedings of the International Conference on Machine Learning, ICML, pp. 448–456, 2015.<br />
# Nair, V. and Hinton, G. Rectified linear units improve restricted boltzmann machines.</div>Gsahuhttp://wiki.math.uwaterloo.ca/statwiki/index.php?title=Deep_Reinforcement_Learning_in_Continuous_Action_Spaces_a_Case_Study_in_the_Game_of_Simulated_Curling&diff=41964Deep Reinforcement Learning in Continuous Action Spaces a Case Study in the Game of Simulated Curling2018-11-30T00:14:01Z<p>Gsahu: /* Weaknesses */</p>
<hr />
<div>This page provides a summary and critique of the paper '''Deep Reinforcement Learning in Continuous Action Spaces: a Case Study in the Game of Simulated Curling''' [[http://proceedings.mlr.press/v80/lee18b/lee18b.pdf Online Source]], published in ICML 2018. The source code for this paper is available [https://github.com/leekwoon/KR-DL-UCT here]<br />
<br />
= Introduction and Motivation =<br />
<br />
In recent years, Reinforcement Learning methods have been applied to many different games, such as chess and checkers. More recently, the use of CNNs has allowed neural networks to outperform humans in many difficult games, such as Go. However, many of these cases involve a discrete state or action space: the number of actions a player can take and/or the number of possible game states is finite. <br />
<br />
Interacting with the real world (e.g., a scenario that involves moving physical objects) typically involves working with a continuous action space. It is thus important to develop strategies for dealing with continuous action spaces. Deep neural networks that are designed to succeed in finite action spaces are not necessarily suitable for continuous action space problems, because deterministic discretization of a continuous action space causes strong biases in policy evaluation and improvement. <br />
<br />
This paper introduces a method that allows learning with continuous action spaces. A CNN is used to learn on discretized state and action spaces, and a continuous action search is then performed based on these discrete results.<br />
<br />
Curling was chosen as the domain for testing the network due to its large action space, potential for complicated strategies, and need for precise interactions.<br />
<br />
== Curling ==<br />
<br />
Curling is a sport played by two teams on a long sheet of ice. Roughly, the goal is for each team to slide rocks closer to the target at the other end of the sheet than the other team. The next sections provide background on the game play and potential challenges/concerns for learning algorithms, followed by a terminology section.<br />
<br />
=== Game play ===<br />
<br />
A game of curling is divided into ends. In each end, the two teams alternate throwing (sliding) rocks, eight per team, toward the house, a target of concentric rings at the other end of the ice sheet. Rocks must land in a certain area in order to stay in play, and must touch or be inside the concentric rings (12 ft diameter and smaller) in order to score points. At the end of each end, only the team with the rock closest to the center of the house scores. It earns one point for each of its rocks that is closer to the center than the opposing team's closest rock.<br />
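The scoring rule can be made concrete with a short function. It assumes positions in metres with the tee at the origin, a house radius of 6 ft (1.829 m), and an approximate stone radius of 0.145 m, so that stones touching the outer ring count. Exact ties in distance are not handled. The result always lies in [-8, 8], matching the 17 value-head outputs.<br />

```python
import math

HOUSE_RADIUS = 1.829  # 12 ft diameter house -> 6 ft radius, in metres
STONE_RADIUS = 0.145  # approximate stone radius in metres (assumption)


def end_score(stones):
    """stones: list of (x, y, team) with the tee at (0, 0) and team in {0, 1}.

    Returns the points scored by team 0 (positive) or team 1 (negative); 0 for a blank end.
    """
    in_house = sorted((math.hypot(x, y), team) for x, y, team in stones
                      if math.hypot(x, y) <= HOUSE_RADIUS + STONE_RADIUS)
    if not in_house:
        return 0
    scoring_team = in_house[0][1]
    points = 0
    # Count the scoring team's stones until the opponent's closest stone is reached.
    for _, team in in_house:
        if team != scoring_team:
            break
        points += 1
    return points if scoring_team == 0 else -points
```
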
<br />
When throwing a rock, the curler can spin the rock. This allows the rock to 'curl' its path towards the house and can allow rocks to travel around other rocks. Team members are also able to sweep the ice in front of a moving rock in order to decrease friction, which allows for fine-tuning of distance (though the physics of sweeping are not implemented in the simulation used).<br />
<br />
Curling offers many possible high-level actions, which a team member calls for the throwing member to execute. An example set of these includes:<br />
<br />
* Draw: Throw a rock to a target location<br />
* Freeze: Draw a rock up against another rock<br />
* Takeout: Knock another rock out of the house. Can be combined with different ricochet directions<br />
* Guard: Place a rock in front of another, to block other rocks (ex: takeouts)<br />
<br />
=== Challenges for AI ===<br />
<br />
Curling poses many challenges for AI due to its physics and rules. This section lists a few of them.<br />
<br />
The effect of changing actions can be highly nonlinear and discontinuous. This can be seen when considering that a 1-cm deviation in a path can make the difference between a high-speed collision, or lack of collision.<br />
<br />
Curling will require both offensive and defensive strategies. For example, consider the fact that the last team to throw a rock each end only needs to place that rock closer than the opposing team's rocks to score a point and invalidate any opposing rocks in the house. The opposing team should thus be considering how to prevent this from happening, in addition to scoring points themselves.<br />
<br />
Curling also has a concept known as 'the hammer'. The hammer belongs to the team which throws the last rock each end, providing an advantage, and is given to the team that does not score points each end. It could very well be a good strategy to try not to win a single point in an end (if already ahead in points, etc), as this would give the advantage to the opposing team.<br />
<br />
Finally, curling has a rule known as the 'Free Guard Zone'. This applies to the first 4 rocks thrown (2 from each team). If they land short of the house, but still in play, then the rocks are not allowed to be removed (via collisions) until all of the first 4 rocks have been thrown.<br />
<br />
=== Terminology ===<br />
<br />
* End: A round of the game<br />
* House: The target area of concentric rings at the end of the sheet of ice, centered on the tee<br />
* Hammer: The team that throws the last rock of an end 'has the hammer'<br />
* Hog Line: thick line that is drawn in front of the house, orthogonal to the length of the ice sheet. Rocks must pass this line to remain in play.<br />
* Back Line: thin line drawn just behind the house. Rocks that pass this line are removed from play.<br />
<br />
<br />
== Related Work ==<br />
<br />
=== AlphaGo Lee ===<br />
<br />
AlphaGo Lee (Silver et al., 2016, [5]) refers to an algorithm used to play the game Go, which was able to defeat international champion Lee Sedol. <br />
<br />
<br />
Go game:<br />
* Start with an empty 19x19 board.<br />
* One player takes the black stones and the other takes the white stones.<br />
* The two players take turns placing stones on the board.<br />
* Rules:<br />
*# If a connected group is completely surrounded by the opponent's stones, it is removed from the board.<br />
*# Ko rule: a move may not repeat a previous board position.<br />
* The game ends when there are no valuable moves left on the board.<br />
* Count the territory of both players.<br />
* Add 7.5 points to White's score (called komi).<br />
[[File:go.JPG|700px|center]]<br />
<br />
Two neural networks were trained on the moves of human experts, to act as both a policy network and a value network. A Monte Carlo Tree Search algorithm was used for policy improvement.<br />
<br />
The AlphaGo Lee policy network predicts the best move given a board configuration. It has a CNN architecture with 13 hidden layers, and it is trained using expert game play data and improved through self-play.<br />
<br />
The value network evaluates the probability of winning given a board configuration. It consists of a CNN with 14 hidden layers, and it is trained using self-play data from the policy network. <br />
<br />
Finally, the two networks are combined using Monte Carlo Tree Search, which performs a look-ahead search to select the actions for game play.<br />
<br />
The use of both policy and value networks are reflected in this paper's work.<br />
<br />
=== AlphaGo Zero ===<br />
<br />
AlphaGo Zero (Silver et al., 2017, [6]) is an improvement on the AlphaGo Lee algorithm. AlphaGo Zero uses a unified neural network in place of the separate policy and value networks and is trained on self-play, without the need for expert training data.<br />
<br />
The unification of networks and self-play are also reflected in this paper.<br />
<br />
=== Curling Algorithms ===<br />
<br />
Some past algorithms have been proposed to deal with continuous action spaces. For example, Yamamoto et al. (2015, [7]) use game tree search methods in a discretized space. The value of an action is taken as the average of nearby values, taking into account some knowledge of execution uncertainty.<br />
<br />
=== Monte Carlo Tree Search ===<br />
<br />
Monte Carlo Tree Search algorithms have been applied to continuous action spaces. These algorithms, discussed in further detail below, balance the exploration of different states with knowledge of paths of execution through past games. One such MCTS variant, <math>KR-UCT</math>, uses kernel regression (KR) and kernel density estimation (KDE) to estimate rewards from neighborhood information, which allows it to make effective selections in a continuous action space.<br />
<br />
For bandit problems, researchers have used hierarchical optimistic optimization (HOO) to build a cover tree that divides the action space into small ranges at different depths, so that the most promising nodes receive finer-grained estimates.<br />
<br />
=== Curling Physics and Simulation ===<br />
<br />
Several references in the paper study and simulate curling physics. Researchers have analyzed the friction coefficients between curling stones and ice. Since changes in friction on the ice cannot be modelled, a fixed friction coefficient was predefined in the simulation. The behavior of the stones was also modeled, with important parameters derived from professional players. The authors used the same parameters in this paper.<br />
<br />
== General Background of Algorithms ==<br />
<br />
=== Policy and Value Functions ===<br />
<br />
A policy function is trained to provide the best action to take, given a current state. Policy iteration is an algorithm used to improve a policy over time. This is done by alternating between policy evaluation and policy improvement.<br />
<br />
POLICY IMPROVEMENT: LEARNING ACTION POLICY<br />
<br />
Action policy <math> p_{\sigma}(a|s) </math> outputs a probability distribution over all eligible moves <math> a </math>. Here <math> \sigma </math> denotes the weights of a neural network that approximates the policy, <math>s</math> denotes a state, and <math>a</math> denotes an action taken in the environment. The policy is thus a function that, given the state the agent is in, returns a distribution over actions. Policy-gradient reinforcement learning can be used to train the action policy. The policy is updated by stochastic gradient ascent in the direction that maximizes the expected outcome at each time step t,<br />
\[ \Delta \sigma \propto \frac{\partial \log p_{\sigma}(a_t|s_t)}{\partial \sigma} r(s_t) \]<br />
where <math> r(s_t) </math> is the return.<br />
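The update above is the REINFORCE rule. For a single state with a softmax policy over discrete actions, the gradient of <math>\log p_{\sigma}(a|s)</math> with respect to the logits equals the one-hot vector of the taken action minus the action probabilities. The sketch below applies this to a 3-armed bandit; the bandit, learning rate, and step count are illustrative.<br />

```python
import math
import random


def softmax(logits):
    """Numerically stable softmax."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def reinforce_update(logits, action, ret, lr=0.1):
    """Gradient ascent on r * log p(action): d log p(a) / d logit_j = 1[j == a] - p_j."""
    probs = softmax(logits)
    return [z + lr * ret * ((1.0 if j == action else 0.0) - p)
            for j, (z, p) in enumerate(zip(logits, probs))]


def train_bandit(rewards, steps=2000, seed=0):
    """Learn a softmax policy over the arms of a deterministic bandit with REINFORCE."""
    rng = random.Random(seed)
    logits = [0.0] * len(rewards)
    for _ in range(steps):
        probs = softmax(logits)
        action = rng.choices(range(len(logits)), weights=probs)[0]
        logits = reinforce_update(logits, action, rewards[action])
    return softmax(logits)
```
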
<br />
POLICY EVALUATION: LEARNING VALUE FUNCTIONS<br />
<br />
A value function <math> v_{\theta}(s) </math> with parameters <math> \theta </math> is trained to estimate the value of being in a certain state. It is trained on records of state-return pairs <math> (s, r(s)) </math> by using stochastic gradient descent to minimize the mean squared error (MSE) between the predicted regression value and the corresponding outcome,<br />
\[ \Delta \theta \propto \frac{\partial v_{\theta}(s)}{\partial \theta}(r(s)-v_{\theta}(s)) \]<br />
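The sketch below uses a linear value function <math>v_{\theta}(s)=\theta^\top \phi(s)</math> as a simple stand-in for the network. In that case <math>\partial v_{\theta}/\partial\theta = \phi(s)</math>, and the update reduces to the code below. The feature vectors are hypothetical.<br />

```python
def value_update(theta, features, ret, lr=0.1):
    """One SGD step on 0.5 * (r - v)^2 for v = theta . features."""
    v = sum(t * f for t, f in zip(theta, features))
    error = ret - v
    return [t + lr * error * f for t, f in zip(theta, features)]


def fit(samples, steps=500, lr=0.1):
    """Cycle through (features, return) samples, applying one update per step."""
    theta = [0.0] * len(samples[0][0])
    for i in range(steps):
        features, ret = samples[i % len(samples)]
        theta = value_update(theta, features, ret, lr)
    return theta
```
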
<br />
=== Monte Carlo Tree Search ===<br />
<br />
Monte Carlo Tree Search (MCTS) is a search algorithm used for finite-horizon tasks (e.g., in curling, only 16 moves, i.e. stone throws, are taken in each end).<br />
<br />
MCTS is a tree search algorithm similar to minimax. However, MCTS is probabilistic and does not need to explore a full game tree or even a tree reduced with alpha-beta pruning. This makes it tractable for games such as Go and curling.<br />
<br />
Nodes of the tree are game states, and branches represent actions. Each node stores statistics on how many times it has been visited by the MCTS, as well as the number of wins encountered by playouts from that position. A node has been considered 'visited' if a full playout has started from that node. A node is considered 'expanded' if all its children have been visited.<br />
<br />
MCTS begins with the '''selection''' phase, which traverses known states/actions. Starting at the root node, the algorithm repeatedly selects the child with the highest 'score', descending the tree in this fashion until it reaches a node that has not been fully expanded (or a leaf node).<br />
<br />
The next phase, '''expansion''', begins when the algorithm reaches a node where not all children have been visited (ie: the node has not been fully expanded). In the expansion phase, children of the node are visited, and '''simulations''' run from their states.<br />
<br />
Once the new child is expanded, '''simulation''' takes place. This refers to a full playout of the game from the point of the current node, and can involve many strategies, such as randomly taken moves, the use of heuristics, etc.<br />
<br />
The final phase is '''update''' or '''back-propagation''' (unrelated to the neural network algorithm). In this phase, the result of the '''simulation''' (i.e., win/loss) is used to update the statistics of all parent nodes along the path to the root.<br />
<br />
A selection function known as Upper Confidence Bound applied to Trees (UCT) can be used to decide which node to select. The formula for this equation is shown below [[https://www.baeldung.com/java-monte-carlo-tree-search source]]. Note that the first term acts as the average score of games played from a given node. The second term, meanwhile, grows as the parent's total number of simulations <math>t</math> increases while the node's own count <math>n_i</math> stays fixed, i.e. whenever sibling nodes are explored instead. This means that rarely visited nodes gradually increase their UCT score and are eventually selected.<br />
<br />
<math> \frac{w_i}{n_i} + c \sqrt{\frac{\ln t}{n_i}} </math><br />
<br />
In which<br />
<br />
* <math> w_i = </math> number of wins after <math> i</math>th move<br />
* <math> n_i = </math> number of simulations after <math> i</math>th move<br />
* <math> c = </math> exploration parameter (theoretically equal to <math> \sqrt{2}</math>)<br />
* <math> t = </math> total number of simulations for the parent node<br />
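The four phases and the UCT rule can be tied together in a minimal, self-contained sketch. It uses a toy game in place of curling, purely to illustrate the mechanics: Nim, where players alternately take 1 to 3 stones and whoever takes the last stone wins. Simulations are random playouts.<br />

```python
import math
import random


def uct_score(wins, visits, parent_visits, c=math.sqrt(2)):
    """w_i / n_i + c * sqrt(ln t / n_i); unvisited children score +inf so they are tried first."""
    if visits == 0:
        return math.inf
    return wins / visits + c * math.sqrt(math.log(parent_visits) / visits)


class Node:
    """Tree node for toy Nim (take 1-3 stones; taking the last stone wins)."""

    def __init__(self, stones, parent=None, move=None):
        self.stones, self.parent, self.move = stones, parent, move
        self.children = []
        self.untried = [m for m in (1, 2, 3) if m <= stones]
        self.wins = 0.0  # wins for the player who made `move` into this node
        self.visits = 0


def rollout(stones, rng):
    """Random playout; returns 1 if the player to move at `stones` wins, else 0."""
    player = 0
    while stones > 0:
        stones -= rng.randint(1, min(3, stones))
        player ^= 1
    return 1 if player == 1 else 0  # player 0 moved last exactly when player == 1


def mcts(stones, iterations=1000, seed=0):
    """Return the most-visited root move after `iterations` rounds of UCT search (stones >= 1)."""
    rng = random.Random(seed)
    root = Node(stones)
    for _ in range(iterations):
        node = root
        # Selection: descend through fully expanded nodes using UCT.
        while not node.untried and node.children:
            parent = node
            node = max(parent.children,
                       key=lambda ch: uct_score(ch.wins, ch.visits, parent.visits))
        # Expansion: add one untried move as a new child.
        if node.untried:
            move = node.untried.pop()
            child = Node(node.stones - move, parent=node, move=move)
            node.children.append(child)
            node = child
        # Simulation: score from the view of the player who moved into `node`.
        result = 1 - rollout(node.stones, rng)
        # Backpropagation: update statistics, flipping perspective at each level.
        while node is not None:
            node.visits += 1
            node.wins += result
            result = 1 - result
            node = node.parent
    return max(root.children, key=lambda ch: ch.visits).move
```

From 5 stones, the search settles on taking 1 stone. That leaves the opponent 4 stones, a losing position.<br />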
<br />
<br />
Sources: 2,3,4<br />
<br />
[[File:MCTS_Diagram.jpg | 500px|center]]<br />
<br />
=== Kernel Regression ===<br />
<br />
Kernel regression is a form of weighted averaging which uses a kernel function as a weight to estimate the conditional expectation of a random variable. Given two data points '''x''' and '''x'''<sub>i</sub>, each with an associated value '''y''', a kernel function '''K''' outputs a weighting factor <math>K(x, x_i)</math>. The value of a new, unseen point is then estimated as the weighted average of the values of the surrounding points.<br />
<br />
A typical kernel is a Gaussian kernel, shown below. The formula for calculating estimated value is shown below as well (sources: Lee et al.).<br />
<br />
[[File:gaussian_kernel.png | 400 px]]<br />
<br />
[[File:kernel_regression.png | 250 px]]<br />
<br />
The denominator of the conditional expectation is related to kernel density estimation, which is defined as <math display="inline">W(x)=\sum_{i=0}^n K(x,x_i)</math>.<br />
<br />
In this case, the two act together to weight the scores of samples closest to '''x''' more strongly.<br />
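A one-dimensional sketch of this estimator, known as the Nadaraya-Watson form, follows. It uses an unnormalized Gaussian kernel with a hypothetical bandwidth <math>h</math>. Any normalizing constant in the kernel cancels between the numerator and the denominator of the regression.<br />

```python
import math


def gaussian_kernel(x, xi, h=1.0):
    """Unnormalized Gaussian kernel with bandwidth h."""
    return math.exp(-((x - xi) ** 2) / (2 * h ** 2))


def kernel_density(x, xs, h=1.0):
    """W(x) = sum_i K(x, x_i)."""
    return sum(gaussian_kernel(x, xi, h) for xi in xs)


def kernel_regression(x, xs, ys, h=1.0):
    """Estimate E[y | x] = sum_i K(x, x_i) * y_i / W(x)."""
    weights = [gaussian_kernel(x, xi, h) for xi in xs]
    return sum(w * y for w, y in zip(weights, ys)) / sum(weights)
```
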
<br />
= Methods =<br />
<br />
== Variable Definitions ==<br />
<br />
The following variables are used often in the paper:<br />
<br />
* <math>s</math>: A state in the game, as described below as the input to the network.<br />
* <math>s_t</math>: The state at a certain time-step of the game. Time-steps refer to full turns in the game.<br />
* <math>a_t</math>: The action taken in state <math>s_t</math><br />
* <math>A_t</math>: The actions taken for sibling nodes related to <math>a_t</math> in MCTS<br />
* <math>n_{a_t}</math>: The number of visits to node <math>a_t</math> in MCTS<br />
* <math>v_{a_t}</math>: The MCTS value estimate of a node<br />
<br />
== Network Design ==<br />
<br />
The authors design a CNN called the 'policy-value' network. The network consists of a common network structure, which is then split into 'policy' and 'value' outputs. This network is trained to learn a probability distribution of actions to take, and expected rewards, given an input state.<br />
<br />
=== Shared Structure ===<br />
<br />
The network consists of 1 convolutional layer followed by 9 residual blocks, each block consisting of 2 convolutional layers with 32 3x3 filters. The structure of this network is shown below:<br />
<br />
<br />
[[File:curling_network_layers.png|600px|thumb|center|Figure 2. A detail description of our policy-value network. The shared network is composed of one convolutional layer and nine residual blocks. Each residual block (explained in b) has two convolutional layer with batch normalization (Ioffe & Szegedy, 2015[11]) followed by the addition of the input and the residual block. Each layer in the shared network uses 3x3 filters. The policy head<br />
has two more convolutional layers, while the value head has two fully connected layers on top of a convolutional layer. For the activation function of each convolutional layer, ReLU (Nair & Hinton[12]) is used.]]<br />
<br />
<br />
<br />
The input to this network consists of the following:<br />
* Locations of the stones<br />
* Order of the stones by distance to the tee (the center of the house)<br />
* A 32x32 grid representation of the ice sheet, indicating which stones are present in each grid cell.<br />
<br />
The authors do not describe how the stone-based information is added to the 32x32 grid as input to the network.<br />
<br />
=== Policy Network ===<br />
<br />
The policy head is created by adding two convolutional layers, each with two 3x3 filters, to the shared body of the network. Its output is a probability distribution over a discrete set of 32x32x2 actions, from which the best shot is selected. Each action specifies a target location on the grid and one of two spin directions for the stone.<br />
<br />
[[File:policy-value-net.PNG | 700px]]<br />
<br />
=== Value Network ===<br />
<br />
The value head is created by adding a convolutional layer with one 3x3 filter, followed by dense layers of 256 and 17 units, to the shared network. The 17 output units represent a probability distribution over the scores in the range [-8, 8], which are the possible scores in each end of a curling game.<br />
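The sketch below shows one way to read the 17-way value output. It applies a softmax, then either takes the expectation over the scores -8 to 8 or sums the probabilities of positive scores. The helper names are hypothetical. The sketch assumes that output index 0 corresponds to a score of -8.

```python
import math

SCORES = list(range(-8, 9))  # the 17 possible one-end scores, -8 .. 8


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def expected_end_score(logits):
    """Expected one-end score under the predicted 17-way distribution."""
    if len(logits) != len(SCORES):
        raise ValueError("expected 17 value-head outputs")
    return sum(p * s for p, s in zip(softmax(logits), SCORES))


def end_win_probability(logits):
    """Probability that the end is won (a strictly positive score)."""
    return sum(p for p, s in zip(softmax(logits), SCORES) if s > 0)
```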
<br />
== Continuous Action Search ==<br />
<br />
The policy head of the network only outputs actions from a discretized action space. For real-life interactions, and especially in curling, this will not suffice, as very fine adjustments to actions can make significant differences in outcomes.<br />
<br />
Actions in the continuous space are generated using an MCTS algorithm, with the following steps:<br />
<br />
=== Selection ===<br />
<br />
From a given state, the list of already-visited actions is denoted as A<sub>t</sub>. Scores and the number of visits to each node are estimated using the equations below (the first equation shows the expectation of the end value for one-end games). These are likely estimated rather than simply taken from the MCTS statistics to help account for the differences in a continuous action space.<br />
<br />
[[File:curling_kernel_equations.png | 400px]]<br />
<br />
The UCB formula is then used to select an action to expand.<br />
<br />
The actions that are taken in the simulator appear to be drawn from a Gaussian centered around <math>a_t</math>. This allows exploration in the continuous action space.<br />
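The selection step can be sketched in Python as below. This is a hedged illustration. It assumes the equations in the figure follow the KR-UCT formulation, with kernel-weighted value <math>E[v|a] = \frac{\sum_{b \in A_t} K(a,b)\, v_b\, n_b}{\sum_{b \in A_t} K(a,b)\, n_b}</math> and kernel density <math>W(a) = \sum_{b \in A_t} K(a,b)\, n_b</math>. The selection score is then <math>E[v|a] + C\sqrt{\frac{\log \sum_{b \in A_t} W(b)}{W(a)}}</math>. The Gaussian bandwidth, the constant C and the execution noise level are hypothetical values, not taken from the paper.

```python
import math
import random


def gaussian_kernel(a, b, bandwidth):
    """Similarity of two continuous actions (tuples of floats)."""
    d2 = sum((x - y) ** 2 for x, y in zip(a, b))
    return math.exp(-d2 / (2.0 * bandwidth ** 2))


def kr_ucb_select(actions, values, visits, c=1.0, bandwidth=0.1):
    """Index of the visited action with the highest kernel-regression UCB score.

    actions: visited continuous actions; values: their mean values v_b;
    visits: their visit counts n_b (each assumed >= 1, so log(sum W) >= 0).
    """
    kernel = [[gaussian_kernel(a, b, bandwidth) for b in actions] for a in actions]
    # W(a) = sum_b K(a, b) n_b: kernel density of visits around a
    density = [sum(k * n for k, n in zip(row, visits)) for row in kernel]
    # E[v | a] = sum_b K(a, b) v_b n_b / W(a)
    expected = [sum(k * v * n for k, v, n in zip(row, values, visits)) / w
                for row, w in zip(kernel, density)]
    log_total = math.log(sum(density))
    scores = [e + c * math.sqrt(log_total / w) for e, w in zip(expected, density)]
    return max(range(len(actions)), key=lambda i: scores[i])


def perturb(action, sigma, rng=random):
    """Execute an action with Gaussian noise centred on the chosen action."""
    return tuple(rng.gauss(x, sigma) for x in action)
```

Because nearby actions share kernel weight, a rarely tried action close to a well-explored good action inherits part of its value estimate. This is the motivation for estimating these statistics instead of reading raw node counts.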
<br />
=== Expansion ===<br />
<br />
The authors use a variant of regular UCT for expansion. In this case, they expand a new node only when existing nodes have been visited a certain number of times. The authors utilize a widening approach to overcome problems with standard UCT performing a shallow search when there is a large action space.<br />
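Widening rules of this kind are often written as "add a new child only while the number of children is below <math>k N^{\alpha}</math>", where <math>N</math> is the parent's visit count. The Python sketch below implements that generic progressive-widening test. The constants k and alpha are hypothetical, and the exact criterion used by the authors may differ.

```python
def should_add_child(parent_visits, num_children, k=1.0, alpha=0.5):
    """Generic progressive-widening test (k and alpha are hypothetical constants).

    A new child may be expanded only while num_children < k * N**alpha,
    so the branching factor grows sublinearly with the parent's visit count N.
    """
    return num_children < k * max(parent_visits, 1) ** alpha


def children_allowed(parent_visits, k=1.0, alpha=0.5):
    """Largest number of children the test above permits after parent_visits visits."""
    n = 0
    while should_add_child(parent_visits, n, k, alpha):
        n += 1
    return n
```

With alpha = 0.5, a node visited 100 times may hold at most 10 children. The search therefore goes deeper instead of spreading its visits across an ever-growing set of siblings.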
<br />
=== Simulation ===<br />
<br />
Instead of simulating with a random game playout, the authors use the value network to estimate the likely score associated with a state. This speeds up simulation (assuming the network is well trained), as the game does not actually need to be simulated.<br />
<br />
=== Backpropagation ===<br />
<br />
Standard backpropagation is used, updating both the values and number of visits stored in the path of parent nodes.<br />
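A minimal Python sketch of this update follows. Each node on the path to the root gains one visit, and the new result is folded into a running mean. The sketch assumes that every backed-up result is expressed from one fixed team's point of view. The source does not say how the two teams' perspectives are handled; some implementations negate the value at alternating levels. The `Node` class is hypothetical.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Node:
    visits: int = 0
    value: float = 0.0              # running mean of backed-up results
    parent: Optional["Node"] = None


def backpropagate(leaf, result):
    """Add one visit and fold `result` into the mean value of every node up to the root."""
    node = leaf
    while node is not None:
        node.visits += 1
        node.value += (result - node.value) / node.visits  # incremental mean update
        node = node.parent
```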
<br />
<br />
== Supervised Learning ==<br />
<br />
During supervised training, data is gathered from the program AyumuGAT'16 ([9]), a high-performance AI curling program that is also based on MCTS. 400,000 state-action pairs were generated for this training.<br />
<br />
=== Policy Network ===<br />
<br />
The policy network was trained to predict the action taken in each state. Here, the target probability of the taken action was set to 1, and the target probability of every other action to 0.<br />
<br />
=== Value Network ===<br />
<br />
The value network was trained by 'd-depth simulations and bootstrapping of the prediction to handle the high variance in rewards resulting from a sequence of stochastic moves' (quote taken from paper). In this case, ''m'' state-action pairs were sampled from the training data. For each pair, <math>(s_t, a_t)</math>, a state ''d'' steps ahead, <math>s_{t+d}</math>, was generated. This process dealt with uncertainty by treating all actions in this rollout as deterministic, and allowing uncertainty only in the last action, ''a<sub>t+d-1</sub>''. The value network is then used to predict the value of this state, <math>z_t</math>, and this prediction serves as the learning target for the value at ''s<sub>t</sub>''.<br />
<br />
=== Policy-Value Network ===<br />
<br />
The policy-value network was trained to maximize the similarity of the predicted policy and value, and the actual policy and value from a state. The learning algorithm parameters are:<br />
<br />
* Algorithm: stochastic gradient descent<br />
* Batch size: 256<br />
* Momentum: 0.9<br />
* L2 regularization: 0.0001<br />
* Training time: ~100 epochs<br />
* Learning rate: initialized at 0.01, reduced twice<br />
<br />
A multi-task loss function was used. It sums the cross-entropy losses of the policy and value predictions:<br />
<br />
[[File:curling_loss_function.png | 300px]]<br />
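Below is a minimal Python sketch of the multi-task loss as described above. It adds the policy cross-entropy between the target policy and the predicted policy and the value cross-entropy between the target and predicted score distributions. An L2 penalty using the 0.0001 coefficient from the parameter list is included as well. For supervised learning, the policy target is one-hot (the action actually taken). Function names are hypothetical, and plain lists stand in for network outputs.

```python
import math


def cross_entropy(target, predicted, eps=1e-12):
    """-sum_i target_i * log(predicted_i); eps guards against log(0)."""
    return -sum(t * math.log(max(p, eps)) for t, p in zip(target, predicted))


def one_hot(index, size):
    return [1.0 if i == index else 0.0 for i in range(size)]


def policy_value_loss(pi, p, z, v, weights=(), l2=1e-4):
    """Policy cross-entropy + value cross-entropy + L2 penalty on the weights."""
    penalty = l2 * sum(w * w for w in weights)
    return cross_entropy(pi, p) + cross_entropy(z, v) + penalty
```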
<br />
== Self-Play Reinforcement Learning ==<br />
<br />
After initialization by supervised learning, the algorithm uses self-play to further train itself. During this training, the policy network learns probabilities from the MCTS process, while the value network learns from game outcomes.<br />
<br />
At a game state ''s<sub>t</sub>'':<br />
<br />
1) the algorithm outputs a prediction ''z<sub>t</sub>''. This is an estimate of game score probabilities, based on similar past actions and computed using kernel regression.<br />
<br />
2) the algorithm outputs a prediction <math>\pi_t</math>, representing a probability distribution of actions. These are proportional to estimated visit counts from MCTS, based on kernel density estimation.<br />
<br />
It is not clear how these predictions are created. It would seem likely that the policy-value network generates these, but the wording of the paper suggests they are generated from MCTS statistics.<br />
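One reading of step 2 is that <math>\pi_t</math> is proportional to kernel density estimates of the MCTS visit counts. The Python sketch below normalizes <math>W(a) = \sum_b K(a,b)\, n_b</math> over the visited continuous actions. Both the interpretation and the Gaussian bandwidth are assumptions. How these continuous actions are projected back onto the 32x32x2 policy grid is not shown.

```python
import math


def gaussian_kernel(a, b, bandwidth):
    d2 = sum((x - y) ** 2 for x, y in zip(a, b))
    return math.exp(-d2 / (2.0 * bandwidth ** 2))


def policy_target(actions, visits, bandwidth=0.1):
    """Normalize kernel-smoothed visit counts W(a) = sum_b K(a, b) n_b into a distribution."""
    density = [sum(gaussian_kernel(a, b, bandwidth) * n for b, n in zip(actions, visits))
               for a in actions]
    total = sum(density)
    return [w / total for w in density]
```

For well-separated actions this reduces to plain visit-count proportions. Actions that lie close together share probability mass.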
<br />
The policy-value network is updated by sampling data <math>(s, \pi, z)</math> from recent history of self-play. The same loss function is used as before.<br />
<br />
It is not clear how the improved network is used, as MCTS seems to be the driving process at this point.<br />
<br />
== Long-Term Strategy Learning ==<br />
<br />
Finally, the authors implement a new strategy to augment their algorithm for long-term play. In this context, this refers to playing a game over many ends, where the strategy to win a single end may not be a good strategy to win a full game. For example, scoring one point in an end, while being one point ahead, gives the advantage to the other team in the next round (as they will throw the last stone). The other team could then use the advantage to score two points, taking the lead.<br />
<br />
The authors build a 'winning percentage' table. This table stores the percentage of games won, given the number of ends left and the difference in score (current team minus opposing team). It can be computed iteratively from the estimated probability distribution of one-end scores.<br />
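The iterative computation can be sketched as a small dynamic program in Python. This is a hypothetical illustration, not the authors' code. The sketch also tracks who holds the hammer, since the one-end score distribution depends on it; whether the authors' table does the same is not stated here. Three further assumptions apply: a blank end (score 0) leaves the hammer in place, ties after the last end count as half a win, and the same one-end distribution applies to every end.

```python
from functools import lru_cache


def win_probability(score_dist, ends_left, diff, has_hammer):
    """Probability that our team wins the game (ties count as half a win).

    score_dist: {s: p} for the one-end score s from the hammer team's view
                (s > 0: hammer team scores s, s < 0: the other team steals -s).
    diff: our score minus the opponent's score.
    """
    dist = tuple(sorted(score_dist.items()))

    @lru_cache(maxsize=None)
    def win(ends, d, hammer):
        if ends == 0:
            return 1.0 if d > 0 else 0.0 if d < 0 else 0.5  # tie assumed to be a coin flip
        total = 0.0
        for s, p in dist:
            if hammer:
                new_d = d + s
                next_hammer = s <= 0  # keep the hammer on a blank end or after a steal
            else:
                new_d = d - s
                next_hammer = s > 0   # we get the hammer when the opponent scores
            total += p * win(ends - 1, new_d, next_hammer)
        return total

    return win(ends_left, diff, has_hammer)
```

Filling the table amounts to evaluating `win_probability` for each (ends left, score difference) pair. A multi-end strategy can then rank candidate shots by the table value of the resulting end scores rather than by the expected end score alone.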
<br />
== Final Algorithms ==<br />
<br />
The authors make use of the following versions of their algorithm:<br />
<br />
=== KR-DL ===<br />
<br />
''Kernel regression-deep learning'': This algorithm is trained only by supervised learning.<br />
<br />
=== KR-DRL ===<br />
<br />
''Kernel regression-deep reinforcement learning'': This algorithm is first trained by supervised learning (i.e., initialized as the KR-DL algorithm), and then further trained by self-play. During self-play, each shot is selected after 400 MCTS simulations of k=20 randomly selected actions. Self-play data was collected over a week on 5 GPUs, generating 5 million game positions. The policy-value network was continually updated using samples from the latest 1 million game positions.<br />
<br />
=== KR-DRL-MES ===<br />
<br />
''Kernel regression-deep reinforcement learning-multi-ends-strategy'': This algorithm makes use of the winning percentage table generated from self-play.<br />
<br />
= Testing and Results =<br />
The authors use data from the public program AyumuGAT’16 for testing. Testing is done with a simulated curling program [10]. This simulator does not model changing ice conditions or sweeping, but it does model stone trajectories and collisions.<br />
<br />
== Comparison of KR-DL-UCT and DL-UCT ==<br />
<br />
The first test compares an algorithm trained with kernel regression against one trained without it, to show the contribution of kernel regression to performance. Both algorithms have networks initialised with supervised learning and are then trained with two different algorithms for self-play. KR-DL-UCT uses the algorithm described above. The authors do not go into detail on how DL-UCT selects shots, but state that a constant is set to allow exploration.<br />
<br />
As an evaluation, both algorithms play 2000 games against the DL-UCT algorithm frozen after supervised training: 1000 games with the algorithm taking the first shot, and 1000 with it taking the second shot. The games were two-end games. The figure below shows each algorithm's winning percentage given different amounts of training data. While the self-play-trained DL-UCT outperforms the supervised-only DL-UCT, the KR-DL-UCT algorithm performs much better.<br />
<br />
<center>[[File:curling_KR_test.png | 400px]]</center><br />
<br />
== Matches ==<br />
<br />
Finally, to test the performance of their algorithms, the authors run matches between their algorithms and other existing programs. Each algorithm plays a 200-game match against each other program, with 100 games played as the first-playing team and 100 as the second-playing team. Only one program, ''JiritsukunGAT'17'', was able to outperform the KR-DRL algorithm; the authors state that this program also uses a deep network and hand-crafted features. However, the KR-DRL-MES algorithm was still able to outperform it. Figure 4 shows the Elo ratings of the different programs; the programs in blue are those created by the authors. The authors also played some games between KR-DRL-MES and notable programs. Table 1 shows the details of these match results: ''JiritsukunGAT'17'' shows a similar level of performance, but KR-DRL-MES is still the winner.<br />
<br />
<br />
<br />
[[File:curling_ratings.png|600px|thumb|center|Figure 4. Elo rating and winning percentages of our models and GAT rankers. Each match has 200 games (each program plays 100 pre-ordered games), because the player which has the last shot (the hammer shot) in each end would have an advantage.]]<br />
<br />
<br />
[[File:ttt.png|600px|thumb|center|Table 1. The 8-end game results for KR-DRL-MES against other programs alternating the opening player each game. The matches are held by following the rules of the latest GAT competition.]]<br />
<br />
= Critique =<br />
<br />
== Strengths ==<br />
<br />
This algorithm out-performs other high-performance algorithms (including past competition champions).<br />
<br />
I think the paper does a decent job of comparing the performance of their algorithm to others. They are able to clearly show the benefits of many of their additions.<br />
<br />
The authors do seem to be able to adapt strategies similar to those used in Go and other games to the continuous action-space domain. In addition, the final strategy needs no hand-crafted features for learning.<br />
<br />
== Weaknesses ==<br />
<br />
Sometimes, I found this paper difficult to follow. One problem was that the algorithms were introduced first, and only afterwards was it described how they were used. So when the paper stated that self-play shots were taken after 400 simulations, it was unclear which simulations were being run and at what stage of the algorithm (ex: MCTS simulations, simulations sped up by using the value network, full simulations on the curling simulator). In particular, both the MCTS statistics and the policy-value network could be used to estimate both action probabilities and state values, so it is difficult to tell which is used in which case. There was also no clear distinction between discrete-space actions and continuous-space actions.<br />
<br />
While I think the comparison of different algorithms was done well, I believe it still lacked significant details. Several claims were mentioned only in passing, and it would have been nice to see them backed by results. These include the statement that using a single policy-value network in place of two separate networks led to better performance.<br />
<br />
At this point, the algorithms used still rely on initialization by a pre-made program.<br />
<br />
There was little theoretical development or justification done in this paper.<br />
<br />
While curling is an interesting choice for demonstrating the algorithm, the simulations used did not support many key aspects of curling (ice conditions, sweeping), which seems very limiting. Another game, such as pool, would likely have offered some of the same challenges while allowing higher-fidelity simulation and training.<br />
<br />
While the spatial placements of stones were discretized in a grid, the curl of thrown stones was discretized to only +/-1. This seems like it may limit learning high- and low-spin moves. It should be noted that, to the best of my knowledge, throwing with zero spin is not common.<br />
<br />
=References=<br />
# Lee, K., Kim, S., Choi, J. & Lee, S. "Deep Reinforcement Learning in Continuous Action Spaces: a Case Study in the Game of Simulated Curling." Proceedings of the 35th International Conference on Machine Learning, in PMLR 80:2937-2946 (2018)<br />
# https://www.baeldung.com/java-monte-carlo-tree-search<br />
# https://jeffbradberry.com/posts/2015/09/intro-to-monte-carlo-tree-search/<br />
# https://int8.io/monte-carlo-tree-search-beginners-guide/<br />
# https://en.wikipedia.org/wiki/Monte_Carlo_tree_search<br />
# Silver, D., Huang, A., Maddison, C., Guez, A., Sifre, L., Van Den Driessche, G., Schrittwieser, J., Antonoglou, I., Panneershelvam, V., Lanctot, M., Dieleman, S., Grewe, D., Nham, J., Kalchbrenner, N., Sutskever, I., Lillicrap, T., Leach, M., Kavukcuoglu, K., Graepel, T., and Hassabis, D. Mastering the game of Go with deep neural networks and tree search. Nature, pp. 484–489, 2016.<br />
# Silver, D., Schrittwieser, J., Simonyan, K., Antonoglou, I., Huang, A., Guez, A., Hubert, T., Baker, L., Lai, M., Bolton, A., Chen, Y., Lillicrap, T., Hui, F., Sifre, L., van den Driessche, G., Graepel, T., and Hassabis, D. Mastering the game of Go without human knowledge. Nature, pp. 354–359, 2017.<br />
# Yamamoto, M., Kato, S., and Iizuka, H. Digital curling strategy based on game tree search. In Proceedings of the IEEE Conference on Computational Intelligence and Games, CIG, pp. 474–480, 2015.<br />
# Ohto, K. and Tanaka, T. A curling agent based on the Monte Carlo tree search considering the similarity of the best action among similar states. In Proceedings of Advances in Computer Games, ACG, pp. 151–164, 2017.<br />
# Ito, T. and Kitasei, Y. Proposal and implementation of digital curling. In Proceedings of the IEEE Conference on Computational Intelligence and Games, CIG, pp. 469–473, 2015.<br />
# Ioffe, S. and Szegedy, C. Batch normalization: Accelerating deep network training by reducing internal covariate shift. In Proceedings of the International Conference on Machine Learning, ICML, pp. 448–456, 2015.<br />
# Nair, V. and Hinton, G. Rectified linear units improve restricted boltzmann machines.</div>Gsahuhttp://wiki.math.uwaterloo.ca/statwiki/index.php?title=Deep_Reinforcement_Learning_in_Continuous_Action_Spaces_a_Case_Study_in_the_Game_of_Simulated_Curling&diff=41963Deep Reinforcement Learning in Continuous Action Spaces a Case Study in the Game of Simulated Curling2018-11-30T00:12:48Z<p>Gsahu: /* Comparison of KR-DL-UCT and DL-UCT */</p>
<hr />
<div>This page provides a summary and critique of the paper '''Deep Reinforcement Learning in Continuous Action Spaces: a Case Study in the Game of Simulated Curling''' [[http://proceedings.mlr.press/v80/lee18b/lee18b.pdf Online Source]], published in ICML 2018. The source code for this paper is available [https://github.com/leekwoon/KR-DL-UCT here]<br />
<br />
= Introduction and Motivation =<br />
<br />
In recent years, Reinforcement Learning methods have been applied to many different games, such as chess and checkers. More recently, the use of CNNs has allowed neural networks to outperform humans in many difficult games, such as Go. However, many of these cases involve a discrete state or action space; the number of actions a player can take and/or the number of possible game states are finite. <br />
<br />
Interacting with the real world (e.g., a scenario that involves moving physical objects) typically involves working with a continuous action space. It is thus important to develop strategies for dealing with continuous action spaces. Deep neural networks that are designed to succeed in finite action spaces are not necessarily suitable for continuous action space problems, because deterministic discretization of a continuous action space causes strong biases in policy evaluation and improvement. <br />
<br />
This paper introduces a method to allow learning with continuous action spaces. A CNN is used to perform learning on discretized state and action spaces, and then a continuous action search is performed on these discrete results.<br />
<br />
Curling is chosen as a domain to test the network on. Curling was chosen due to its large action space, potential for complicated strategies, and need for precise interactions.<br />
<br />
== Curling ==<br />
<br />
Curling is a sport played by two teams on a long sheet of ice. Roughly, the goal is for each team to slide rocks closer to the target at the other end of the sheet than the other team. The next sections provide background on the game play and on potential challenges for learning algorithms, followed by a terminology section.<br />
<br />
=== Game play ===<br />
<br />
A game of curling is divided into ends. In each end, players from the two teams alternate throwing (sliding) rocks, eight per team, toward the house, a set of concentric rings at the other end of the ice sheet. Rocks must land in a certain area in order to stay in play, and must touch or be inside the rings (12 ft diameter and smaller) in order to score points. At the end of each end, only the team with the rock closest to the center of the house scores, earning one point for each of its rocks that is closer to the center than the opponent's closest rock.<br />
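The scoring rule can be made concrete with a short Python sketch that takes the distances of each team's remaining stones from the tee and returns the end score. The stone radius, the treatment of exact ties as a blank end, and the function name are assumptions made for illustration. Because each team throws eight stones, the result always lies in [-8, 8], which matches the 17 outputs of the value head described later.

```python
HOUSE_RADIUS_M = 1.829   # 6 ft: the 12 ft ring measured from the tee
STONE_RADIUS_M = 0.145   # approximate stone radius (assumption)


def score_end(dist_a, dist_b,
              house_radius=HOUSE_RADIUS_M, stone_radius=STONE_RADIUS_M):
    """Score of one end from team A's view (positive: A scores, negative: B scores).

    dist_a, dist_b: distances (m) from the tee to the centres of each team's
    stones in play when the end finishes. A stone counts only if it touches the house.
    Exact ties for the closest stone are treated as a blank end (simplification).
    """
    reach = house_radius + stone_radius
    a = sorted(d for d in dist_a if d <= reach)
    b = sorted(d for d in dist_b if d <= reach)
    best_a = a[0] if a else float("inf")
    best_b = b[0] if b else float("inf")
    if best_a < best_b:
        return sum(1 for d in a if d < best_b)
    if best_b < best_a:
        return -sum(1 for d in b if d < best_a)
    return 0  # nobody in the house, or an exact tie
```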
<br />
When throwing a rock, the thrower can spin it. The spin makes the rock 'curl' along its path toward the house and can allow it to travel around other rocks. Team members can also sweep the ice in front of a moving rock to decrease friction, which allows fine-tuning of distance (though the physics of sweeping are not implemented in the simulation used).<br />
<br />
Curling offers many possible high-level shots, which a team member calls for the thrower. Examples include:<br />
<br />
* Draw: Throw a rock to a target location<br />
* Freeze: Draw a rock up against another rock<br />
* Takeout: Knock another rock out of the house. Can be combined with different ricochet directions<br />
* Guard: Place a rock in front of another, to block other rocks (ex: takeouts)<br />
<br />
=== Challenges for AI ===<br />
<br />
Curling poses many challenges for AI because of its physics and rules. This section lists a few of them.<br />
<br />
The effect of changing an action can be highly nonlinear and discontinuous. For example, a 1-cm deviation in a path can make the difference between a high-speed collision and no collision at all.<br />
<br />
Curling will require both offensive and defensive strategies. For example, consider the fact that the last team to throw a rock each end only needs to place that rock closer than the opposing team's rocks to score a point and invalidate any opposing rocks in the house. The opposing team should thus be considering how to prevent this from happening, in addition to scoring points themselves.<br />
<br />
Curling also has a concept known as 'the hammer'. The hammer belongs to the team which throws the last rock each end, providing an advantage, and is given to the team that does not score points each end. It could very well be a good strategy to try not to win a single point in an end (if already ahead in points, etc), as this would give the advantage to the opposing team.<br />
<br />
Finally, curling has a rule known as the 'Free Guard Zone'. This applies to the first 4 rocks thrown (2 from each team). If they land short of the house, but still in play, then the rocks are not allowed to be removed (via collisions) until all of the first 4 rocks have been thrown.<br />
<br />
=== Terminology ===<br />
<br />
* End: A round of the game<br />
* House: The target area at the end of the sheet of ice, which contains the concentric scoring rings<br />
* Hammer: The team that throws the last rock of an end 'has the hammer'<br />
* Hog Line: thick line that is drawn in front of the house, orthogonal to the length of the ice sheet. Rocks must pass this line to remain in play.<br />
* Back Line: thin line drawn just behind the house. Rocks that pass this line are removed from play.<br />
<br />
<br />
== Related Work ==<br />
<br />
=== AlphaGo Lee ===<br />
<br />
AlphaGo Lee (Silver et al., 2016, [6]) refers to an algorithm used to play the game Go, which was able to defeat international champion Lee Sedol. <br />
<br />
<br />
Go game:<br />
* Start with an empty 19x19 board<br />
* One player takes the black stones and the other takes the white stones<br />
* The two players take turns placing stones on the board<br />
* Rules:<br />
*# If a connected group of stones is completely surrounded by the opponent's stones, it is removed from the board<br />
*# Ko rule: a move may not repeat a previous board position<br />
* The game ends when there are no valuable moves left on the board.<br />
* The territory of both players is counted.<br />
* White receives 7.5 extra points (called komi).<br />
[[File:go.JPG|700px|center]]<br />
<br />
Two neural networks were trained on the moves of human experts, to act as both a policy network and a value network. A Monte Carlo Tree Search algorithm was used for policy improvement.<br />
<br />
The AlphaGo Lee policy network predicts the best move given a board configuration. It has a CNN architecture with 13 hidden layers, and it is trained using expert game play data and improved through self-play.<br />
<br />
The value network evaluates the probability of winning given a board configuration. It consists of a CNN with 14 hidden layers, and it is trained using self-play data from the policy network. <br />
<br />
Finally, the two networks are combined using Monte Carlo Tree Search, which performs lookahead search to select the actions for game play.<br />
<br />
The use of both policy and value networks is reflected in this paper's work.<br />
<br />
=== AlphaGo Zero ===<br />
<br />
AlphaGo Zero (Silver et al., 2017, [7]) is an improvement on the AlphaGo Lee algorithm. AlphaGo Zero uses a unified neural network in place of the separate policy and value networks and is trained by self-play, without the need for expert data.<br />
<br />
The unification of networks and self-play are also reflected in this paper.<br />
<br />
=== Curling Algorithms ===<br />
<br />
Some past algorithms have been proposed to deal with continuous action spaces. For example, Yamamoto et al. (2015, [8]) use game tree search methods in a discretized space. The value of an action is taken as the average of nearby values, taking into account some knowledge of execution uncertainty.<br />
<br />
=== Monte Carlo Tree Search ===<br />
<br />
Monte Carlo Tree Search algorithms have been applied to continuous action spaces. These algorithms, discussed in further detail below, balance exploration of different states with knowledge of paths of execution through past games. An MCTS variant called KR-UCT, which uses kernel regression (KR) and kernel density estimation (KDE) to estimate rewards from neighborhood information and make effective selections, has previously been applied to continuous action spaces. <br />
<br />
For bandit problems, hierarchical optimistic optimization (HOO) has been used to build a cover tree that divides the action space into small ranges at different depths, so that the most promising nodes receive fine-grained estimates.<br />
<br />
=== Curling Physics and Simulation ===<br />
<br />
Several references in the paper concern the study and simulation of curling physics. Researchers have analyzed friction coefficients between curling stones and ice. Because modelling changes in friction on the ice is not possible, a fixed friction coefficient was predefined in the simulation. The behavior of the stones was also modeled, with important parameters estimated from the play of professional players. The authors used the same parameters in this paper.<br />
<br />
== General Background of Algorithms ==<br />
<br />
=== Policy and Value Functions ===<br />
<br />
A policy function is trained to provide the best action to take, given a current state. Policy iteration is an algorithm used to improve a policy over time. This is done by alternating between policy evaluation and policy improvement.<br />
<br />
POLICY IMPROVEMENT: LEARNING ACTION POLICY<br />
<br />
The action policy <math> p_{\rho}(a|s) </math> outputs a probability distribution over all eligible moves <math> a </math>. Here <math> \rho </math> denotes the weights of a neural network that approximates the policy, <math>s</math> denotes a state, and <math>a</math> denotes an action taken in the environment. The policy is a function that returns an action given the state the agent is in. Policy gradient reinforcement learning can be used to train the action policy. It is updated by stochastic gradient ascent in the direction that maximizes the expected outcome at each time step t,<br />
<math> \Delta \rho \propto \frac{\partial \log p_{\rho}(a_t|s_t)}{\partial \rho} r(s_t) </math><br />
where <math> r(s_t) </math> is the return.<br />
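A minimal Python sketch of this update for a softmax policy over a few discrete actions follows. It uses the standard REINFORCE estimator, which follows the gradient of the log-probability of the taken action scaled by the return. To keep it self-contained, the logits for a single state are stored directly instead of being produced by a neural network; that tabular parameterization is an assumption for illustration.

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def reinforce_step(logits, action, ret, lr=0.1):
    """One policy-gradient (REINFORCE) ascent step for a softmax policy over discrete actions.

    For a softmax, d log p(action) / d logit_k = 1[k == action] - p_k.
    """
    probs = softmax(logits)
    return [x + lr * ret * ((1.0 if k == action else 0.0) - p)
            for k, (x, p) in enumerate(zip(logits, probs))]
```

A positive return raises the probability of the action that was taken, and a negative return lowers it.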
<br />
POLICY EVALUATION: LEARNING VALUE FUNCTIONS<br />
<br />
A value function <math> v_{\theta}(s) </math> with parameters <math> \theta </math> is trained to estimate the value of being in a given state. It is trained on records of state-reward pairs <math> (s, r(s)) </math> by using stochastic gradient descent to minimize the mean squared error (MSE) between the predicted regression value and the corresponding outcome,<br />
<math> \Delta \theta \propto \frac{\partial v_{\theta}(s)}{\partial \theta}(r(s)-v_{\theta}(s)) </math><br />
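The update can be illustrated with a linear value function <math>v_{\theta}(s) = \theta \cdot \phi(s)</math>, for which <math>\partial v_{\theta}/\partial \theta = \phi(s)</math>. The linear form and the feature vector are assumptions for this sketch only; the paper uses a CNN. Each step moves the parameters along the features, scaled by the prediction error.

```python
def value_sgd_step(theta, features, target, lr=0.1):
    """One SGD step on 0.5 * (target - v)^2 for a linear value v = theta . features."""
    v = sum(t * f for t, f in zip(theta, features))
    err = target - v
    # dv/dtheta = features, so theta moves along features * (target - v)
    return [t + lr * err * f for t, f in zip(theta, features)]
```

Repeating the step on a single (state, return) pair drives the prediction toward that return.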
<br />
=== Monte Carlo Tree Search ===<br />
<br />
Monte Carlo Tree Search (MCTS) is a search algorithm used for finite-horizon tasks (e.g., in curling, only 16 stones are thrown in each end).<br />
<br />
MCTS is a tree search algorithm similar to minimax. However, MCTS is probabilistic and does not need to explore a full game tree, or even a tree reduced with alpha-beta pruning. This makes it tractable for games such as Go and curling.<br />
<br />
Nodes of the tree are game states, and branches represent actions. Each node stores statistics on how many times it has been visited by the MCTS, as well as the number of wins encountered by playouts from that position. A node is considered 'visited' if a full playout has started from that node. A node is considered 'fully expanded' if all its children have been visited.<br />
<br />
MCTS begins with the '''selection''' phase, which involves traversing known states/actions. Starting at the root node, the algorithm selects the child with the highest score, and repeats this from each successive node, descending the tree until it reaches a node that is not fully expanded.<br />
<br />
The next phase, '''expansion''', begins when the algorithm reaches a node where not all children have been visited (i.e., the node has not been fully expanded). In the expansion phase, an unvisited child of that node is added to the tree.<br />
<br />
Once the new child is expanded, '''simulation''' takes place. This refers to a full playout of the game from the point of the current node, and can involve many strategies, such as randomly taken moves, the use of heuristics, etc.<br />
<br />
The final phase is '''update''' or '''back-propagation''' (unrelated to the neural network algorithm). In this phase, the result of the '''simulation''' (i.e., win/loss) is used to update the statistics of all nodes on the path back to the root.<br />
<br />
A selection function known as Upper Confidence Bound (UCB), called UCT when applied to trees, can be used to decide which node to select. The formula is shown below [[https://www.baeldung.com/java-monte-carlo-tree-search source]]. Note that the first term is the average score (win rate) of games played from a given node. The second term grows as the parent is visited more often while the node itself is not (i.e., while its siblings are being explored). This means that rarely explored nodes gradually increase their UCT score and are eventually selected.<br />
<br />
<math> \frac{w_i}{n_i} + c \sqrt{\frac{\ln t}{n_i}} </math><br />
<br />
In which<br />
<br />
* <math> w_i = </math> number of wins after <math> i</math>th move<br />
* <math> n_i = </math> number of simulations after <math> i</math>th move<br />
* <math> c = </math> exploration parameter (theoretically equal to <math> \sqrt{2}</math>)<br />
* <math> t = </math> total number of simulations for the parent node<br />
<br />
<br />
Sources: 2,3,4<br />
<br />
[[File:MCTS_Diagram.jpg | 500px|center]]<br />
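The UCT formula above translates directly into Python. In this sketch, scoring unvisited children as +infinity (so each child is tried once before the formula applies) is a common convention rather than something stated in the cited sources. The function names are hypothetical.

```python
import math


def uct_score(wins, visits, parent_visits, c=math.sqrt(2)):
    """w_i / n_i + c * sqrt(ln t / n_i), with unvisited children scored as +infinity."""
    if visits == 0:
        return math.inf  # make sure every child is tried at least once
    return wins / visits + c * math.sqrt(math.log(parent_visits) / visits)


def select_child(children, parent_visits, c=math.sqrt(2)):
    """children: list of (wins, visits) pairs; returns the index with the highest UCT score."""
    return max(range(len(children)),
               key=lambda i: uct_score(children[i][0], children[i][1], parent_visits, c))
```

For example, a child with 1 win in 2 visits beats a child with 60 wins in 100 visits when the parent has 102 visits. The exploration term outweighs the lower win rate of the rarely visited child.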
<br />
=== Kernel Regression ===<br />
<br />
Kernel regression is a form of weighted averaging that uses a kernel function as a weight to estimate the conditional expectation of a random variable. Given a kernel '''K''' and two data points, each of which has an associated value '''y''', the kernel function outputs a weighting factor for the pair based on how similar the points are. An estimate of the value at a new, unseen point '''x''' is then calculated as the kernel-weighted average of the values of the surrounding points.<br />
<br />
A typical kernel is a Gaussian kernel, shown below. The formula for calculating estimated value is shown below as well (sources: Lee et al.).<br />
<br />
[[File:gaussian_kernel.png | 400 px]]<br />
<br />
[[File:kernel_regression.png | 250 px]]<br />
<br />
The denominator of the conditional expectation is related to kernel density estimation, which is defined as <math display="inline">W(x)=\sum_{i=0}^n K(x,x_i)</math>.<br />
<br />
In this case, the two combine to weight the scores of samples closest to '''x''' more strongly.<br />
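A minimal one-dimensional Python sketch of this estimator (the Nadaraya-Watson form) with a Gaussian kernel follows. The bandwidth values are arbitrary choices for illustration. The denominator is the kernel density <math>W(x)</math> defined above.

```python
import math


def gaussian_kernel(x, xi, bandwidth=1.0):
    return math.exp(-((x - xi) ** 2) / (2.0 * bandwidth ** 2))


def kernel_density(x, xs, bandwidth=1.0):
    """W(x): sum of kernel weights around x (unnormalized density)."""
    return sum(gaussian_kernel(x, xi, bandwidth) for xi in xs)


def kernel_regression(x, xs, ys, bandwidth=1.0):
    """Estimate E[y | x] as the kernel-weighted average of the observed ys."""
    weighted = sum(gaussian_kernel(x, xi, bandwidth) * yi for xi, yi in zip(xs, ys))
    return weighted / kernel_density(x, xs, bandwidth)
```

With a narrow bandwidth, the estimate at <math>x = 0.9</math> from samples <math>(0, 0)</math> and <math>(1, 10)</math> is almost exactly 10, because the nearer sample dominates the weights.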
<br />
= Methods =<br />
<br />
== Variable Definitions ==<br />
<br />
The following variables are used often in the paper:<br />
<br />
* <math>s</math>: A state in the game, described below as the input to the network.<br />
* <math>s_t</math>: The state at a certain time-step of the game. Time-steps refer to full turns in the game.<br />
* <math>a_t</math>: The action taken in state <math>s_t</math><br />
* <math>A_t</math>: The actions taken for sibling nodes related to <math>a_t</math> in MCTS<br />
* <math>n_{a_t}</math>: The number of visits to node <math>a_t</math> in MCTS<br />
* <math>v_{a_t}</math>: The MCTS value estimate of a node<br />
<br />
== Network Design ==<br />
<br />
The authors design a CNN called the 'policy-value' network. The network consists of a common network structure, which is then split into 'policy' and 'value' outputs. This network is trained to learn a probability distribution of actions to take, and expected rewards, given an input state.<br />
<br />
=== Shared Structure ===<br />
<br />
The network consists of 1 convolutional layer followed by 9 residual blocks, each block consisting of 2 convolutional layers with 32 3x3 filters. The structure of this network is shown below:<br />
<br />
<br />
[[File:curling_network_layers.png|600px|thumb|center|Figure 2. A detailed description of our policy-value network. The shared network is composed of one convolutional layer and nine residual blocks. Each residual block (explained in b) has two convolutional layers with batch normalization (Ioffe & Szegedy, 2015[11]) followed by the addition of the input and the residual block. Each layer in the shared network uses 3x3 filters. The policy head<br />
has two more convolutional layers, while the value head has two fully connected layers on top of a convolutional layer. For the activation function of each convolutional layer, ReLU (Nair & Hinton[12]) is used.]]<br />
<br />
<br />
<br />
The input to this network is the following:<br />
* Location of stones<br />
* Order to tee (the tee is the center of the house)<br />
* A 32x32 grid representation of the ice sheet, indicating which stones are present in each grid cell.<br />
<br />
The authors do not describe how the stone-based information is added to the 32x32 grid as input to the network.<br />
<br />
=== Policy Network ===<br />
<br />
The policy head is created by adding 2 convolutional layers with two 3x3 filters to the main body of the network. Its output is a probability distribution over a discrete set of 32x32x2 actions, from which the best shot is selected. Each action represents a target location in the grid and a spin direction for the stone.<br />
<br />
[[File:policy-value-net.PNG | 700px]]<br />
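The mapping from the 32x32x2 discrete action set to a concrete shot can be sketched as follows: the policy head's 2048 outputs are turned into probabilities with a softmax, and a flat index is decoded into a grid row, a grid column, and a spin direction. This is a minimal sketch, assuming a flattening order of spin first, then row, then column, and the spin labels -1 and +1; the paper specifies neither, and the function names are hypothetical.<br />

```python
import math

GRID = 32           # the paper's 32x32 grid over the ice sheet
SPINS = (-1, +1)    # assumed labels for the two spin directions


def softmax(logits):
    """Numerically stable softmax over a flat list of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def decode_action(index):
    """Map a flat index in [0, 2*32*32) to (row, col, spin).

    Assumed (hypothetical) layout: index = spin_idx * 32*32 + row * 32 + col.
    """
    if not 0 <= index < len(SPINS) * GRID * GRID:
        raise ValueError("action index out of range")
    spin_idx, cell = divmod(index, GRID * GRID)
    row, col = divmod(cell, GRID)
    return row, col, SPINS[spin_idx]


def best_discrete_action(logits):
    """Return the decoded most-probable action and its probability."""
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    return decode_action(best), probs[best]


# Usage: 2048 logits, with one clearly preferred action.
logits = [0.0] * (len(SPINS) * GRID * GRID)
logits[GRID * GRID + 33] = 5.0
action, prob = best_discrete_action(logits)
print(action)   # (1, 1, 1): row 1, column 1, spin +1
```

Keeping spin as the slowest-varying index places each spin direction in one contiguous 32x32 plane, which would line up with a final convolutional layer that has two output filters.<br />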
<br />
=== Value Network ===<br />
<br />
The value head is created by adding a convolutional layer with one 3x3 filter, followed by dense layers of 256 and 17 units, to the shared network. The 17 output units represent a probability distribution over the scores in the range [-8, 8], which are the possible scores of a single end of a curling game.<br />
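The value head returns a distribution over the 17 possible end scores rather than a single number, so a scalar estimate such as the expected end score has to be read off from it. The sketch below assumes the 17 outputs are ordered from -8 to +8 and already sum to one; both the ordering and the helper names are assumptions made for illustration.<br />

```python
SCORES = list(range(-8, 9))   # the 17 possible one-end scores, -8 .. +8


def expected_end_score(probs):
    """Expected one-end score from the value head's 17 output probabilities.

    Assumes probs[i] is the probability of score SCORES[i] and that probs sums to 1.
    """
    if len(probs) != len(SCORES):
        raise ValueError("expected 17 probabilities")
    return sum(p * s for p, s in zip(probs, SCORES))


def prob_scoring(probs):
    """Probability that the evaluated team scores at least one point."""
    return sum(p for p, s in zip(probs, SCORES) if s > 0)


# Usage: a hypothetical output that puts 60% on +1 and 40% on -1.
probs = [0.0] * 17
probs[SCORES.index(1)] = 0.6
probs[SCORES.index(-1)] = 0.4
print(round(expected_end_score(probs), 6))   # 0.2
```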
<br />
== Continuous Action Search ==<br />
<br />
The policy head of the network only outputs actions from a discretized action space. For real-life interactions, and especially in curling, this will not suffice, as very fine adjustments to actions can make significant differences in outcomes.<br />
<br />
Actions in the continuous space are generated using an MCTS algorithm, with the following steps:<br />
<br />
=== Selection ===<br />
<br />
From a given state, the list of already-visited actions is denoted as A<sub>t</sub>. Scores and the number of visits to each node are estimated using the equations below (the first equation shows the expectation of the end value for one-end games). These are likely estimated rather than simply taken from the MCTS statistics to help account for the differences in a continuous action space.<br />
<br />
[[File:curling_kernel_equations.png | 400px]]<br />
<br />
The UCB formula is then used to select an action to expand.<br />
<br />
The actions that are taken in the simulator appear to be drawn from a Gaussian centered around <math>a_t</math>. This allows exploration in the continuous action space.<br />
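The selection step can be sketched in the style of KR-UCT. Each visited action's mean value and visit count are smoothed over nearby actions with a kernel, and an upper-confidence rule of the form E[v|a] + C sqrt(log(sum_b W(b)) / W(a)) picks the action to expand. This is a sketch under stated assumptions, not the authors' code. The Gaussian kernel, its bandwidth of 0.1, the exploration constant, the noise level of the simulated throw, and the two-number action encoding in the usage example are all hypothetical. The authors' exact estimates are the ones in the equations above.<br />

```python
import math
import random


def gaussian_kernel(a, b, bandwidth=0.1):
    """Similarity of two actions (tuples of floats); the bandwidth is hypothetical."""
    sq_dist = sum((x - y) ** 2 for x, y in zip(a, b))
    return math.exp(-sq_dist / (2 * bandwidth ** 2))


def kernel_estimates(a, actions, values, visits, kernel=gaussian_kernel):
    """Return (E[v | a], W(a)): kernel-smoothed mean value and visit count at a.

    values[i] is the mean value observed for actions[i], visits[i] its visit count.
    """
    weights = [kernel(a, b) * n for b, n in zip(actions, visits)]
    w = sum(weights)
    expected_value = sum(wt * v for wt, v in zip(weights, values)) / w
    return expected_value, w


def kr_ucb_select(actions, values, visits, c=1.0, kernel=gaussian_kernel):
    """Choose argmax_a E[v|a] + c * sqrt(log(sum_b W(b)) / W(a)).

    Assumes every action in `actions` has been visited at least once.
    """
    stats = [kernel_estimates(a, actions, values, visits, kernel) for a in actions]
    total_w = sum(w for _, w in stats)
    scores = [ev + c * math.sqrt(math.log(total_w) / w) for ev, w in stats]
    best = max(range(len(actions)), key=scores.__getitem__)
    return actions[best]


def execute_with_noise(action, sigma=0.05, rng=random):
    """Mimic execution uncertainty: the thrown action is Gaussian around the chosen one."""
    return tuple(x + rng.gauss(0.0, sigma) for x in action)


# Usage: two well-separated candidate actions (e.g. normalised speed and angle).
A_t = [(0.2, 0.5), (0.8, 0.5)]
print(kr_ucb_select(A_t, values=[0.6, 0.5], visits=[100, 1]))  # (0.8, 0.5)
```

In the usage example the rarely visited action wins because its exploration bonus outweighs its slightly lower value estimate. Two actions that are very close together would instead share most of their statistics through the kernel.<br />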
<br />
=== Expansion ===<br />
<br />
The authors use a variant of regular UCT for expansion. In this case, they expand a new node only when existing nodes have been visited a certain number of times. The authors utilize a widening approach to overcome problems with standard UCT performing a shallow search when there is a large action space.<br />
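One common way to implement such a widening rule is progressive widening. Under this rule, a node may only gain a new child while its number of children is below k * n^alpha for visit count n. The sketch below uses hypothetical constants k = 1 and alpha = 0.5; the paper's exact criterion and constants may differ.<br />

```python
import math


def may_add_child(node_visits, num_children, k=1.0, alpha=0.5):
    """Progressive widening: allow a new child only while
    num_children < max(1, floor(k * node_visits ** alpha)).

    k and alpha are hypothetical constants, not values from the paper.
    """
    budget = max(1, math.floor(k * node_visits ** alpha))
    return num_children < budget


def children_after(total_visits, k=1.0, alpha=0.5):
    """Count how many children a node would have after `total_visits` visits."""
    children = 0
    for n in range(1, total_visits + 1):
        if may_add_child(n, children, k, alpha):
            children += 1
    return children


# Usage: with alpha = 0.5 the number of children grows roughly like sqrt(visits).
print([children_after(n) for n in (1, 10, 120)])  # [1, 3, 10]
```

With these constants, a node visited 120 times has at most 10 children. The search therefore keeps deepening promising actions instead of spreading visits over an ever-growing set of siblings.<br />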
<br />
=== Simulation ===<br />
<br />
Instead of simulating with a random game playout, the authors use the value network to estimate the likely score associated with a state. This speeds up simulation (assuming the network is well trained), as the game does not actually need to be simulated.<br />
<br />
=== Backpropagation ===
<br />
Standard backpropagation is used, updating both the values and number of visits stored in the path of parent nodes.<br />
<br />
<br />
== Supervised Learning ==<br />
<br />
During supervised training, data is gathered from the program AyumuGAT'16 ([9]), a high-performance AI curling program that is also based on an MCTS algorithm. 400,000 state-action pairs were generated from it for training.<br />
<br />
=== Policy Network ===<br />
<br />
The policy network was trained to learn the action taken in each state. The target probability of the taken action was set to 1, and that of all other actions to 0 (a one-hot target).<br />
<br />
=== Value Network ===<br />
<br />
The value network was trained by 'd-depth simulations and bootstrapping of the prediction to handle the high variance in rewards resulting from a sequence of stochastic moves' (quote taken from the paper). In this case, ''m'' state-action pairs were sampled from the training data. For each pair <math>(s_t, a_t)</math>, a state ''d'' steps ahead, <math>s_{t+d}</math>, was generated. This process handled execution uncertainty by treating every action in the rollout as deterministic except the last one, ''a<sub>t+d-1</sub>''. The value network is then used to predict the value of this state, <math>z_t</math>, and that prediction is used as the target for learning the value at ''s<sub>t</sub>''.<br />
<br />
=== Policy-Value Network ===<br />
<br />
The policy-value network was trained so that its predicted policy and value match the target policy and value for each state. The learning algorithm parameters are:<br />
<br />
* Algorithm: stochastic gradient descent<br />
* Batch size: 256<br />
* Momentum: 0.9<br />
* L2 regularization: 0.0001<br />
* Training time: ~100 epochs<br />
* Learning rate: initialized at 0.01, reduced twice<br />
<br />
A multi-task loss function was used. It sums the cross-entropy losses of the two predictions:<br />
<br />
[[File:curling_loss_function.png | 300px]]<br />
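A minimal sketch of such a multi-task loss follows, assuming it takes the AlphaGo Zero-style form. It has three terms: the cross-entropy between the target and predicted action distributions, the cross-entropy between the target score distribution and the 17-bin value prediction, and an L2 penalty with the 0.0001 coefficient listed above. Whether the L2 term sits in the loss or is applied as weight decay in the optimizer is an assumption here, and the argument names are hypothetical.<br />

```python
import math


def cross_entropy(target, predicted, eps=1e-12):
    """-sum_i target_i * log(predicted_i); eps avoids log(0)."""
    return -sum(t * math.log(max(p, eps)) for t, p in zip(target, predicted))


def multitask_loss(pi, p, z, v, weights, l2=1e-4):
    """Policy cross-entropy + value cross-entropy + L2 penalty (assumed form).

    pi: target action distribution     p: predicted action distribution
    z:  target score distribution      v: predicted score distribution (17 bins)
    weights: flat list of network parameters, used only for the L2 term
    """
    policy_loss = cross_entropy(pi, p)
    value_loss = cross_entropy(z, v)
    l2_penalty = l2 * sum(w * w for w in weights)
    return policy_loss + value_loss + l2_penalty


# Usage with tiny made-up distributions (real ones have 2048 and 17 entries).
loss = multitask_loss(pi=[1.0, 0.0], p=[0.5, 0.5],
                      z=[0.0, 1.0], v=[0.5, 0.5],
                      weights=[0.1, -0.2])
print(round(loss, 4))  # 1.3863, i.e. 2 * ln 2 plus a negligible L2 term
```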
<br />
== Self-Play Reinforcement Learning ==<br />
<br />
After initialization by supervised learning, the algorithm uses self-play to further train itself. During this training, the policy network learns probabilities from the MCTS process, while the value network learns from game outcomes.<br />
<br />
At a game state ''s<sub>t</sub>'':<br />
<br />
1) the algorithm outputs a prediction ''z<sub>t</sub>''. This is an estimate of game score probabilities. It is based on similar past actions, and computed using kernel regression.<br />
<br />
2) the algorithm outputs a prediction <math>\pi_t</math>, representing a probability distribution of actions. These are proportional to estimated visit counts from MCTS, based on kernel density estimation.<br />
<br />
It is not clear how these predictions are created. It would seem likely that the policy-value network generates these, but the wording of the paper suggests they are generated from MCTS statistics.<br />
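Under the reading that both targets come from the search statistics, they could be computed as follows. <math>\pi_t</math> is proportional to the kernel-density-smoothed visit count of each candidate action, and <math>z_t</math> is a kernel-regression average of the score distributions observed near the chosen action. This is only a sketch: the Gaussian kernel, its bandwidth, the candidate set, and the storage of one score distribution per visited action are all assumptions.<br />

```python
import math


def gaussian_kernel(a, b, bandwidth=0.1):
    """Similarity of two actions (tuples of floats); the bandwidth is hypothetical."""
    sq_dist = sum((x - y) ** 2 for x, y in zip(a, b))
    return math.exp(-sq_dist / (2 * bandwidth ** 2))


def policy_target(candidates, visited, visits, kernel=gaussian_kernel):
    """pi(a) proportional to W(a) = sum_b K(a, b) * n_b for each candidate action a."""
    w = [sum(kernel(a, b) * n for b, n in zip(visited, visits)) for a in candidates]
    total = sum(w)
    return [x / total for x in w]


def value_target(action, visited, visits, score_dists, kernel=gaussian_kernel):
    """z: kernel-regression average of the score distributions observed near `action`."""
    weights = [kernel(action, b) * n for b, n in zip(visited, visits)]
    total = sum(weights)
    n_bins = len(score_dists[0])
    return [sum(wt * d[i] for wt, d in zip(weights, score_dists)) / total
            for i in range(n_bins)]


# Usage: two visited actions far apart, visited 3 times and once.
visited = [(0.2,), (0.8,)]
print([round(x, 3) for x in policy_target(visited, visited, [3, 1])])  # [0.75, 0.25]
```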
<br />
The policy-value network is updated by sampling data <math>(s, \pi, z)</math> from the recent history of self-play. The same loss function is used as before.<br />
<br />
It is not clear how the improved network is used, as MCTS seems to be the driving process at this point.<br />
<br />
== Long-Term Strategy Learning ==<br />
<br />
Finally, the authors implement a new strategy to augment their algorithm for long-term play. In this context, this refers to playing a game over many ends, where the strategy to win a single end may not be a good strategy to win a full game. For example, scoring one point in an end, while being one point ahead, gives the advantage to the other team in the next round (as they will throw the last stone). The other team could then use the advantage to score two points, taking the lead.<br />
<br />
The authors build a 'winning percentage' table. This table stores the percentage of games won, based on the number of ends left and the difference in score (current team - opposing team). It can be computed iteratively from the estimated probability distribution of one-end scores.<br />
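The iterative computation can be sketched as a small dynamic program over the number of ends left and the score difference. It is fed by a one-end score distribution such as the value head's 17-way output. To make the recursion well defined, the sketch also adds three rules of its own. It tracks which team has the hammer. A blank end keeps the hammer, and otherwise the hammer passes to the team that did not score. A tie after the last end counts as half a win. These rules and the function names are assumptions, not details from the paper.<br />

```python
from functools import lru_cache


def make_win_table(end_score_probs):
    """Return win_prob(ends_left, diff, has_hammer) built from one-end score odds.

    end_score_probs: {score: probability}, scores from the hammer team's point of
    view (positive = the team with the hammer scores). Assumptions, not from the
    paper: a blank end (score 0) keeps the hammer, otherwise the hammer passes to
    the team that did not score, and a tie after the last end counts as half a win.
    """
    @lru_cache(maxsize=None)
    def win_prob(ends_left, diff, has_hammer):
        # diff = our score minus the opponent's score
        if ends_left == 0:
            return 1.0 if diff > 0 else 0.0 if diff < 0 else 0.5
        total = 0.0
        for s, p in end_score_probs.items():
            if has_hammer:      # we have the hammer: s > 0 means we score
                new_diff, next_hammer = diff + s, s <= 0
            else:               # they have the hammer: s > 0 means they score
                new_diff, next_hammer = diff - s, s > 0
            total += p * win_prob(ends_left - 1, new_diff, next_hammer)
        return total

    return win_prob


# Usage with a made-up distribution: the hammer team always scores exactly one.
win = make_win_table({1: 1.0})
print(win(2, 0, True))   # 0.5
```

In this toy distribution, taking a single point in the first end hands the hammer to the opponent, who then equalizes in the last end, so the table reports an even game. This is the same hammer effect described above.<br />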
<br />
== Final Algorithms ==<br />
<br />
The authors make use of the following versions of their algorithm:<br />
<br />
=== KR-DL ===<br />
<br />
''Kernel regression-deep learning'': This algorithm is trained only by supervised learning.<br />
<br />
=== KR-DRL ===<br />
<br />
''Kernel regression-deep reinforcement learning'': This algorithm is trained by supervised learning (i.e., initialized as the KR-DL algorithm), and again on self-play. During self-play, each shot is selected after 400 MCTS simulations of k=20 randomly selected actions. Data for self-play was collected over a week on 5 GPUs and generated 5 million game positions. The policy-value network was continually updated using samples from the latest 1 million game positions.<br />
<br />
=== KR-DRL-MES ===<br />
<br />
''Kernel regression-deep reinforcement learning-multi-ends-strategy'': This algorithm makes use of the winning percentage table generated from self-play.<br />
<br />
= Testing and Results =<br />
The authors use data from the public program AyumuGAT’16 for testing. Testing is done with a simulated curling program [10]. This simulator does not model changing ice conditions or sweeping, but it does model stone trajectories and collisions.<br />
<br />
== Comparison of KR-DL-UCT and DL-UCT ==<br />
<br />
The first test compares an algorithm trained with kernel regression with an algorithm trained without kernel regression, to show the contribution that kernel regression makes to performance. Both algorithms have networks initialised by supervised learning, which are then trained with two different algorithms for self-play. KR-DL-UCT uses the algorithm described above. The authors do not go into detail on how DL-UCT selects shots, but state that a constant is set to allow exploration.<br />
<br />
As an evaluation, both algorithms play 2000 games against the DL-UCT algorithm, which is frozen after supervised training: 1000 games with the evaluated algorithm taking the first shot and 1000 with it taking the second. The games were two-end games. The figure below shows each algorithm's winning percentage given different amounts of training data. While DL-UCT improves on the frozen, supervised-training-only DL-UCT, the KR-DL-UCT algorithm performs much better.<br />
<br />
<center>[[File:curling_KR_test.png | 400px]]</center><br />
<br />
== Matches ==<br />
<br />
Finally, to test the performance of their multiple algorithms, the authors run matches between their algorithms and other existing programs. Each algorithm plays 200 games against each other program, 100 as the first-playing team and 100 as the second-playing team. Only one program was able to out-perform the KR-DRL algorithm. The authors state that this program, ''JiritsukunGAT'17'', also uses a deep network and hand-crafted features. However, the KR-DRL-MES algorithm was still able to out-perform it. Figure 4 shows the Elo ratings of the different programs; the programs in blue are those created by the authors. The authors also played some games between KR-DRL-MES and notable programs, and Table 1 shows the details of the match results. ''JiritsukunGAT'17'' shows a similar level of performance, but KR-DRL-MES is still the winner.<br />
<br />
<br />
<br />
[[File:curling_ratings.png|600px|thumb|center|Figure 4. Elo rating and winning percentages of our models and GAT rankers. Each match has 200 games (each program plays 100 pre-ordered games), because the player which has the last shot (the hammer shot) in each end would have an advantage.]]<br />
<br />
<br />
[[File:ttt.png|600px|thumb|center|Table 1. The 8-end game results for KR-DRL-MES against other programs alternating the opening player each game. The matches are held by following the rules of the latest GAT competition.]]<br />
<br />
= Critique =<br />
<br />
== Strengths ==<br />
<br />
This algorithm out-performs other high-performance algorithms (including past competition champions).<br />
<br />
I think the paper does a decent job of comparing the performance of their algorithm to others. They are able to clearly show the benefits of many of their additions.<br />
<br />
The authors do seem to be able to adapt strategies similar to those used in Go and other games to the continuous action-space domain. In addition, the final strategy needs no hand-crafted features for learning.<br />
<br />
== Weaknesses ==<br />
<br />
Sometimes, I found this paper difficult to follow. One problem was that the algorithms were introduced first, and then how they were used was described. So when the paper stated that self-play shots were taken after 400 simulations, it seemed unclear what simulations were being run and at what stage of the algorithm (e.g., MCTS simulations, simulations sped up by using the value network, full simulations on the curling simulator). In particular, both the MCTS statistics and the policy-value network could be used to estimate both action probabilities and state values, so it is difficult to tell which is used in which case. There was also no clear distinction between discrete-space actions and continuous-space actions.<br />
<br />
While I think the comparison of different algorithms was done well, I believe it still lacked some detail. There were one-off mentions in the paper which would have been nice to see as results. These include the statement that having a policy-value network in place of two networks led to better performance.<br />
<br />
At this point, the algorithms used still rely on initialization by a pre-made program.<br />
<br />
There was little theoretical development or justification done in this paper.<br />
<br />
While curling is an interesting choice for demonstrating the algorithm, the fact that the simulations used did not support many of the key aspects of curling (ice conditions, sweeping) seems quite limiting. Another game, such as pool, would likely have offered some of the same challenges while allowing higher-fidelity simulations/training.<br />
<br />
While the spatial placements of stones were discretized in a grid, the curl of thrown stones was discretized to only +/-1. This seems like it may limit learning high- and low-spin moves. It should be noted that, to the best of my knowledge, throwing a stone with zero spin is uncommon.<br />
<br />
=References=<br />
# Lee, K., Kim, S., Choi, J. & Lee, S. "Deep Reinforcement Learning in Continuous Action Spaces: a Case Study in the Game of Simulated Curling." Proceedings of the 35th International Conference on Machine Learning, in PMLR 80:2937-2946 (2018)<br />
# https://www.baeldung.com/java-monte-carlo-tree-search<br />
# https://jeffbradberry.com/posts/2015/09/intro-to-monte-carlo-tree-search/<br />
# https://int8.io/monte-carlo-tree-search-beginners-guide/<br />
# https://en.wikipedia.org/wiki/Monte_Carlo_tree_search<br />
# Silver, D., Huang, A., Maddison, C., Guez, A., Sifre, L., van den Driessche, G., Schrittwieser, J., Antonoglou, I., Panneershelvam, V., Lanctot, M., Dieleman, S., Grewe, D., Nham, J., Kalchbrenner, N., Sutskever, I., Lillicrap, T., Leach, M., Kavukcuoglu, K., Graepel, T., and Hassabis, D. Mastering the game of Go with deep neural networks and tree search. Nature, pp. 484–489, 2016.<br />
# Silver, D., Schrittwieser, J., Simonyan, K., Antonoglou, I., Huang, A., Guez, A., Hubert, T., Baker, L., Lai, M., Bolton, A., Chen, Y., Lillicrap, T., Hui, F., Sifre, L., van den Driessche, G., Graepel, T., and Hassabis, D. Mastering the game of Go without human knowledge. Nature, pp. 354–359, 2017.<br />
# Yamamoto, M., Kato, S., and Iizuka, H. Digital curling strategy based on game tree search. In Proceedings of the IEEE Conference on Computational Intelligence and Games, CIG, pp. 474–480, 2015.<br />
# Ohto, K. and Tanaka, T. A curling agent based on the Monte Carlo tree search considering the similarity of the best action among similar states. In Proceedings of Advances in Computer Games, ACG, pp. 151–164, 2017.<br />
# Ito, T. and Kitasei, Y. Proposal and implementation of digital curling. In Proceedings of the IEEE Conference on Computational Intelligence and Games, CIG, pp. 469–473, 2015.<br />
# Ioffe, S. and Szegedy, C. Batch normalization: Accelerating deep network training by reducing internal covariate shift. In Proceedings of the International Conference on Machine Learning, ICML, pp. 448–456, 2015.<br />
# Nair, V. and Hinton, G. Rectified linear units improve restricted Boltzmann machines.</div>Gsahuhttp://wiki.math.uwaterloo.ca/statwiki/index.php?title=Deep_Reinforcement_Learning_in_Continuous_Action_Spaces_a_Case_Study_in_the_Game_of_Simulated_Curling&diff=41961Deep Reinforcement Learning in Continuous Action Spaces a Case Study in the Game of Simulated Curling2018-11-30T00:10:52Z<p>Gsahu: /* Policy and Value Functions */</p>
<hr />
<div>This page provides a summary and critique of the paper '''Deep Reinforcement Learning in Continuous Action Spaces: a Case Study in the Game of Simulated Curling''' [[http://proceedings.mlr.press/v80/lee18b/lee18b.pdf Online Source]], published in ICML 2018. The source code for this paper is available [https://github.com/leekwoon/KR-DL-UCT here].<br />
<br />
= Introduction and Motivation =<br />
<br />
In recent years, Reinforcement Learning methods have been applied to many different games, such as chess and checkers. More recently, the use of CNNs has allowed neural networks to out-perform humans in many difficult games, such as Go. However, many of these cases involve a discrete state or action space; the number of actions a player can take and/or the number of possible game states are finite.<br />
<br />
Interacting with the real world (e.g., a scenario that involves moving physical objects) typically involves working with a continuous action space. It is thus important to develop strategies for dealing with continuous action spaces. Deep neural networks that are designed to succeed in finite action spaces are not necessarily suitable for continuous action space problems, because deterministic discretization of a continuous action space causes strong biases in policy evaluation and improvement.<br />
<br />
This paper introduces a method to allow learning with continuous action spaces. A CNN is used to perform learning on a discretized state and action space, and then a continuous action search is performed on these discrete results.<br />
<br />
Curling is chosen as a domain to test the network on. Curling was chosen due to its large action space, potential for complicated strategies, and need for precise interactions.<br />
<br />
== Curling ==<br />
<br />
Curling is a sport played by two teams on a long sheet of ice. Roughly, the goal is for each team to slide rocks closer to the target at the other end of the sheet than the other team's rocks. The next sections will provide a background on the game play, and potential challenges/concerns for learning algorithms. A terminology section follows.<br />
<br />
=== Game play ===<br />
<br />
A game of curling is divided into ends. In each end, the two teams alternate throwing (sliding) rocks, eight per team, toward the house, the target at the other end of the ice sheet. Rocks must land in a certain area in order to stay in play, and must touch or be inside concentric rings (12 ft diameter and smaller) in order to score points. At the end of each end, the team with the rock closest to the center of the house scores one point for each of its rocks that is closer to the center than the opponent's closest rock.<br />
<br />
When throwing a rock, the curler can spin the rock. This allows the rock to 'curl' its path towards the house and can allow rocks to travel around other rocks. Team members are also able to sweep the ice in front of a moving rock in order to decrease friction, which allows for fine-tuning of distance (though the physics of sweeping are not implemented in the simulation used).<br />
<br />
Curling offers many possible high-level actions, which one team member calls for and the throwing member executes. An example set of these includes:<br />
<br />
* Draw: Throw a rock to a target location<br />
* Freeze: Draw a rock up against another rock<br />
* Takeout: Knock another rock out of the house. Can be combined with different ricochet directions<br />
* Guard: Place a rock in front of another, to block other rocks (ex: takeouts)<br />
<br />
=== Challenges for AI ===<br />
<br />
Curling poses many challenges for AI because of its physics and rules. This section lists a few concerns.<br />
<br />
The effect of changing actions can be highly nonlinear and discontinuous. For example, a 1-cm deviation in a path can make the difference between a high-speed collision and no collision at all.<br />
<br />
Curling will require both offensive and defensive strategies. For example, consider the fact that the last team to throw a rock each end only needs to place that rock closer than the opposing team's rocks to score a point and invalidate any opposing rocks in the house. The opposing team should thus be considering how to prevent this from happening, in addition to scoring points themselves.<br />
<br />
Curling also has a concept known as 'the hammer'. The hammer belongs to the team which throws the last rock each end, providing an advantage, and is given to the team that does not score points each end. It could very well be a good strategy to try not to win a single point in an end (if already ahead in points, etc), as this would give the advantage to the opposing team.<br />
<br />
Finally, curling has a rule known as the 'Free Guard Zone'. This applies to the first 4 rocks thrown (2 from each team). If they land short of the house, but still in play, then the rocks are not allowed to be removed (via collisions) until all of the first 4 rocks have been thrown.<br />
<br />
=== Terminology ===<br />
<br />
* End: A round of the game<br />
* House: The target at the end of the sheet of ice, which contains the concentric scoring rings<br />
* Hammer: The team that throws the last rock of an end 'has the hammer'<br />
* Hog Line: thick line that is drawn in front of the house, orthogonal to the length of the ice sheet. Rocks must pass this line to remain in play.<br />
* Back Line: thin line drawn just behind the house. Rocks that pass this line are removed from play.<br />
<br />
<br />
== Related Work ==<br />
<br />
=== AlphaGo Lee ===<br />
<br />
AlphaGo Lee (Silver et al., 2016, [6]) refers to an algorithm used to play the game of Go, which was able to defeat the international champion Lee Sedol.<br />
<br />
<br />
Go game:<br />
* The game starts with an empty 19x19 board.<br />
* One player takes the black stones and the other takes the white stones.<br />
* The two players take turns placing stones on the board.<br />
* Rules:<br />
*# If a connected group of stones is completely surrounded by the opponent's stones, it is removed from the board.<br />
*# Ko rule: a move may not repeat a previous board position.<br />
* The game ends when there are no valuable moves left on the board.<br />
* The territory of both players is then counted.<br />
* White receives 7.5 extra points (called komi).<br />
[[File:go.JPG|700px|center]]<br />
<br />
Two neural networks were trained to act as a policy network and a value network, with the policy network first learning from the moves of human experts. A Monte Carlo Tree Search algorithm was used for policy improvement.<br />
<br />
The AlphaGo Lee policy network predicts the best move given a board configuration. It has a CNN architecture with 13 hidden layers, and it is trained using expert game play data and improved through self-play.<br />
<br />
The value network evaluates the probability of winning given a board configuration. It consists of a CNN with 14 hidden layers, and it is trained using self-play data from the policy network. <br />
<br />
Finally, the two networks are combined using Monte Carlo Tree Search, which performs a look-ahead search to select the actions for game play.<br />

The use of both policy and value networks is reflected in this paper's work.<br />
<br />
=== AlphaGo Zero ===<br />
<br />
AlphaGo Zero (Silver et al., 2017, [7]) is an improvement on the AlphaGo Lee algorithm. AlphaGo Zero uses a unified neural network in place of the separate policy and value networks, and is trained on self-play without the need for expert game data.<br />
<br />
The unification of networks and self-play are also reflected in this paper.<br />
<br />
=== Curling Algorithms ===<br />
<br />
Some past algorithms have been proposed to deal with continuous action spaces. For example, Yamamoto et al. (2015, [8]) use game tree search methods in a discretized space. The value of an action is taken as the average of nearby values, with respect to some knowledge of execution uncertainty.<br />
<br />
=== Monte Carlo Tree Search ===<br />
<br />
Monte Carlo Tree Search algorithms have also been applied to continuous action spaces. These algorithms, discussed in more detail below, balance the exploration of new states against knowledge gained from past playouts. One such MCTS variant, KR-UCT, finds effective selections by using kernel regression (KR) and kernel density estimation (KDE) to estimate rewards from nearby actions, and it has been applied to continuous action spaces.<br />
<br />
For bandit problems, hierarchical optimistic optimization (HOO) has been used to build a cover tree that divides the action space into progressively smaller ranges at greater depths, so that the most promising nodes receive finer-grained estimates.<br />
<br />
=== Curling Physics and Simulation ===<br />
<br />
Several references in the paper concern the study and simulation of curling physics. Researchers have analyzed the friction coefficients between curling stones and ice. Since modelling the changes in friction on the ice is not possible, a fixed friction coefficient was predefined in the simulation. The behavior of the stones was also modeled, with important parameters estimated from the play of professional players. The authors used the same parameters in this paper.<br />
<br />
== General Background of Algorithms ==<br />
<br />
=== Policy and Value Functions ===<br />
<br />
A policy function is trained to provide the best action to take, given a current state. Policy iteration is an algorithm used to improve a policy over time. This is done by alternating between policy evaluation and policy improvement.<br />
<br />
==== Policy Improvement: Learning Action Policy ====
<br />
The action policy <math> p_{\sigma}(a|s) </math> outputs a probability distribution over all eligible moves <math> a </math>. Here <math> \sigma </math> denotes the weights of a neural network that approximates the policy, <math>s</math> denotes a state, and <math>a</math> denotes an action taken in the environment. In other words, the policy maps the state the agent is in to a distribution over actions. Policy-gradient reinforcement learning can be used to train the action policy: writing the weights of the policy being trained as <math> \rho </math>, they are updated by stochastic gradient ascent in the direction that maximizes the expected outcome at each time step <math>t</math>,<br />
<math> \Delta \rho \propto \frac{\partial \log p_{\rho}(a_t|s_t)}{\partial \rho} r(s_t) </math><br />
where <math> r(s_t) </math> is the return.<br />
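The update above can be illustrated on the smallest possible case: a single state, a softmax policy over three discrete actions, and a fixed, made-up return for each action. This is the textbook REINFORCE form of the policy-gradient update, not the authors' training code; the problem, learning rate, and step count are hypothetical.<br />

```python
import math
import random


def softmax(prefs):
    """Numerically stable softmax."""
    m = max(prefs)
    exps = [math.exp(x - m) for x in prefs]
    total = sum(exps)
    return [e / total for e in exps]


def reinforce(returns, steps=2000, lr=0.1, seed=0):
    """REINFORCE on a single-state problem with a softmax policy p_rho(a).

    returns[a] is the (fixed, made-up) return r obtained after action a.
    For a softmax, d log p_rho(a) / d rho_b = 1[a == b] - p_rho(b).
    """
    rng = random.Random(seed)
    rho = [0.0] * len(returns)                         # policy parameters
    for _ in range(steps):
        p = softmax(rho)
        a = rng.choices(range(len(rho)), weights=p)[0]  # sample a_t ~ p_rho
        r = returns[a]
        for b in range(len(rho)):
            grad_log_p = (1.0 if b == a else 0.0) - p[b]
            rho[b] += lr * grad_log_p * r              # gradient ascent on expected return
    return softmax(rho)


# Usage: only action 2 is rewarded, so probability mass moves onto it.
probs = reinforce([0.0, 0.0, 1.0])
print(probs[2] > 0.9)   # True
```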
<br />
==== Policy Evaluation: Learning Value Functions ====
<br />
A value function <math> v_{\theta}(s) </math> with parameters <math> \theta </math> is trained to estimate the value of being in a certain state. It is trained on records of state-return pairs <math> (s, r(s)) </math> by using stochastic gradient descent to minimize the mean squared error (MSE) between the predicted value and the corresponding outcome,<br />
<math> \Delta \theta \propto \frac{\partial v_{\theta}(s)}{\partial \theta}\left(r(s)-v_{\theta}(s)\right) </math><br />
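For a concrete instance of this update, the sketch below fits a linear value function with hand-chosen features by plain SGD on the squared error. For a linear model, <math>\partial v_\theta(s) / \partial \theta</math> is simply the feature vector, so each step reduces to adding the learning rate times the prediction error times the features. The linear model, the features, and the data are assumptions for illustration; in the paper the value function is a CNN.<br />

```python
import random


def train_value(samples, features, lr=0.02, epochs=1000, seed=0):
    """Fit a linear value function v_theta(s) = theta . phi(s) by SGD on squared error.

    samples: list of (state, return) pairs; features(state) -> list of floats.
    Because v is linear, dv/dtheta = phi(s), so each step is
    theta <- theta + lr * (r - v_theta(s)) * phi(s).
    """
    rng = random.Random(seed)
    data = list(samples)
    theta = [0.0] * len(features(data[0][0]))
    for _ in range(epochs):
        rng.shuffle(data)
        for s, r in data:
            phi = features(s)
            error = r - sum(t * f for t, f in zip(theta, phi))
            theta = [t + lr * error * f for t, f in zip(theta, phi)]
    return theta


# Usage: made-up returns r(s) = 1 + 2s for s = 0..4, with features phi(s) = [1, s].
samples = [(s, 1.0 + 2.0 * s) for s in range(5)]
theta = train_value(samples, features=lambda s: [1.0, float(s)])
print([round(t, 3) for t in theta])   # [1.0, 2.0]
```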
<br />
=== Monte Carlo Tree Search ===<br />
<br />
Monte Carlo Tree Search (MCTS) is a search algorithm used for finite-horizon tasks (e.g., in curling, only 16 shots are thrown in each end).<br />
<br />
MCTS is a tree search algorithm similar to minimax. However, MCTS is probabilistic and does not need to explore a full game tree, or even a tree reduced with alpha-beta pruning. This makes it tractable for games such as Go and curling.<br />
<br />
Nodes of the tree are game states, and branches represent actions. Each node stores statistics on how many times it has been visited by the MCTS, as well as the number of wins encountered by playouts from that position. A node is considered 'visited' if a full playout has started from that node, and 'fully expanded' if all of its children have been visited.<br />
<br />
MCTS begins with the '''selection''' phase, which traverses already-known states and actions. Starting at the root node, the algorithm selects the child with the highest 'score', and repeats this from each successive node, following a path down the tree until it reaches a node that is not fully expanded.<br />
<br />
The next phase, '''expansion''', begins when the algorithm reaches a node where not all children have been visited (i.e., the node has not been fully expanded). In the expansion phase, an unvisited child of the node is added to the tree, and a '''simulation''' is run from its state.<br />
<br />
Once the new child is expanded, '''simulation''' takes place. This refers to a full playout of the game from the point of the current node, and can involve many strategies, such as randomly taken moves, the use of heuristics, etc.<br />
<br />
The final phase is '''update''' or '''back-propagation''' (unrelated to the neural network algorithm). In this phase, the result of the '''simulation''' (i.e., win/lose) is used to update the statistics of every node on the path back to the root.<br />
<br />
A selection function known as Upper Confidence Bounds applied to Trees (UCT) can be used to decide which node to select. The formula for this function is shown below [[https://www.baeldung.com/java-monte-carlo-tree-search source]]. Note that the first term essentially acts as the average score of games played from a certain node. The second term, meanwhile, grows whenever the parent is visited but this node is not (i.e., when sibling nodes are selected). This means that rarely explored nodes will gradually increase their UCT score and be selected in the future.<br />
<br />
<math> \frac{w_i}{n_i} + c \sqrt{\frac{\ln t}{n_i}} </math><br />
<br />
In which<br />
<br />
* <math> w_i = </math> number of wins after <math> i</math>th move<br />
* <math> n_i = </math> number of simulations after <math> i</math>th move<br />
* <math> c = </math> exploration parameter (theoretically equal to <math> \sqrt{2}</math>)<br />
* <math> t = </math> total number of simulations for the parent node<br />
<br />
<br />
Sources: 2,3,4<br />
<br />
[[File:MCTS_Diagram.jpg | 500px|center]]<br />
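The four phases and the UCT rule can be put together in a compact sketch. To keep it self-contained, it plays a toy subtraction game rather than curling: players alternately remove one or two stones from a pile, and whoever takes the last stone wins. The game, the 2000 iterations, and <math>c = \sqrt{2}</math> are illustrative choices. Choosing the final move as the most-visited child is a common convention, but not the only one.<br />

```python
import math
import random


class Node:
    """A search-tree node; `pile` is the number of stones left in this state."""

    def __init__(self, pile, parent=None, move=None):
        self.pile = pile
        self.parent = parent
        self.move = move                  # move that led here from the parent
        self.children = []
        self.untried = [m for m in (1, 2) if m <= pile]
        self.visits = 0                   # n_i
        self.wins = 0.0                   # w_i, for the player who made `move`


def uct(child, parent_visits, c=math.sqrt(2)):
    """w_i / n_i + c * sqrt(ln t / n_i)."""
    return child.wins / child.visits + c * math.sqrt(math.log(parent_visits) / child.visits)


def rollout(pile, rng):
    """Random playout; True if the player to move at `pile` takes the last stone."""
    turn = 0
    while pile > 0:
        pile -= rng.choice([m for m in (1, 2) if m <= pile])
        turn ^= 1
    return turn == 1


def mcts(pile, iterations=2000, seed=0):
    rng = random.Random(seed)
    root = Node(pile)
    for _ in range(iterations):
        node = root
        # 1. Selection: follow UCT while the node is fully expanded and non-terminal.
        while not node.untried and node.children:
            node = max(node.children, key=lambda ch: uct(ch, node.visits))
        # 2. Expansion: add one unvisited child.
        if node.untried:
            move = node.untried.pop()
            child = Node(node.pile - move, parent=node, move=move)
            node.children.append(child)
            node = child
        # 3. Simulation: did the player who moved into `node` win the playout?
        mover_won = not rollout(node.pile, rng)
        # 4. Backpropagation: update every node on the path, flipping perspective.
        while node is not None:
            node.visits += 1
            node.wins += 1.0 if mover_won else 0.0
            mover_won = not mover_won
            node = node.parent
    return max(root.children, key=lambda ch: ch.visits).move


print(mcts(4), mcts(5))   # 1 2: leave the opponent a multiple of 3
```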
<br />
=== Kernel Regression ===<br />
<br />
Kernel regression is a form of weighted averaging which uses a kernel function as a weight to estimate the conditional expectation of a random variable. Given two data points '''x''', each with an associated value '''y''', a kernel function '''K''' outputs a weighting factor that reflects how similar the two points are. An estimate of the value at a new, unseen point is then calculated as the weighted average of the values of the surrounding points.<br />
<br />
A typical kernel is a Gaussian kernel, shown below. The formula for calculating estimated value is shown below as well (sources: Lee et al.).<br />
<br />
[[File:gaussian_kernel.png | 400 px]]<br />
<br />
[[File:kernel_regression.png | 250 px]]<br />
<br />
The denominator of the conditional expectation is related to kernel density estimation, which is defined as <math display="inline">W(x)=\sum_{i=0}^n K(x,x_i)</math>.<br />
<br />
In this case, the kernel weights in the numerator and the density in the denominator combine to weigh the values of samples closest to '''x''' more strongly.<br />
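The estimator described above, often called the Nadaraya-Watson estimator, can be sketched in one dimension as follows. The Gaussian kernel's bandwidth and the function names are assumptions. The normalizing constant of the Gaussian is left out because it cancels between the numerator and <math>W(x)</math>.<br />

```python
import math


def gaussian_kernel(x, xi, bandwidth=1.0):
    """exp(-(x - xi)^2 / (2 h^2)); the normalising constant is omitted because it
    cancels in the kernel-regression ratio."""
    return math.exp(-((x - xi) ** 2) / (2 * bandwidth ** 2))


def kernel_density(x, xs, bandwidth=1.0):
    """W(x) = sum_i K(x, x_i): unnormalised kernel density at x."""
    return sum(gaussian_kernel(x, xi, bandwidth) for xi in xs)


def kernel_regression(x, xs, ys, bandwidth=1.0):
    """Estimate E[y | x] as the kernel-weighted average of the observed ys."""
    numerator = sum(gaussian_kernel(x, xi, bandwidth) * yi for xi, yi in zip(xs, ys))
    return numerator / kernel_density(x, xs, bandwidth)


# Usage: halfway between two observations, both get the same weight.
print(kernel_regression(0.5, [0.0, 1.0], [0.0, 1.0]))   # 0.5
# Far-away points barely count: the estimate at 0 is dominated by y = 1.0.
print(round(kernel_regression(0.0, [0.0, 10.0], [1.0, 5.0]), 6))   # 1.0
```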
<br />
= Methods =<br />
<br />
== Variable Definitions ==<br />
<br />
The following variables are used often in the paper:<br />
<br />
* <math>s</math>: A state in the game, encoded as the network input described below.<br />
* <math>s_t</math>: The state at a certain time-step of the game. Time-steps refer to full turns in the game.<br />
* <math>a_t</math>: The action taken in state <math>s_t</math><br />
* <math>A_t</math>: The actions taken for sibling nodes related to <math>a_t</math> in MCTS<br />
* <math>n_{a_t}</math>: The number of visits to the node for action <math>a_t</math> in MCTS<br />
* <math>v_{a_t}</math>: The MCTS value estimate of a node<br />
<br />
== Network Design ==<br />
<br />
The authors design a CNN called the 'policy-value' network. The network consists of a common network structure, which is then split into 'policy' and 'value' outputs. This network is trained to learn a probability distribution of actions to take, and expected rewards, given an input state.<br />
<br />
=== Shared Structure ===<br />
<br />
The network consists of 1 convolutional layer followed by 9 residual blocks, each block consisting of 2 convolutional layers with 32 3x3 filters. The structure of this network is shown below:<br />
<br />
<br />
[[File:curling_network_layers.png|600px|thumb|center|Figure 2. A detailed description of our policy-value network. The shared network is composed of one convolutional layer and nine residual blocks. Each residual block (explained in b) has two convolutional layers with batch normalization (Ioffe & Szegedy, 2015[11]) followed by the addition of the input and the residual block. Each layer in the shared network uses 3x3 filters. The policy head has two more convolutional layers, while the value head has two fully connected layers on top of a convolutional layer. For the activation function of each convolutional layer, ReLU (Nair & Hinton[12]) is used.]]<br />
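<br />
One residual block of the shared network can be sketched in NumPy to show how the skip connection works on the 32x32 grid. This is an illustration under stated assumptions: the layer order (conv, normalization, ReLU, conv, normalization, add, ReLU) follows the common AlphaGo Zero-style block, the `normalize` helper is a single-example stand-in for batch normalization, and the weights are random rather than trained.<br />
<br />
```python
import numpy as np


def conv3x3(x, w):
    """'Same'-padded 3x3 convolution (cross-correlation, as in DL frameworks).
    x: (C_in, H, W) feature map, w: (C_out, C_in, 3, 3) filters -> (C_out, H, W)."""
    _, h, wd = x.shape
    xp = np.pad(x, ((0, 0), (1, 1), (1, 1)))          # zero-pad the spatial borders
    out = np.zeros((w.shape[0], h, wd))
    for dy in range(3):
        for dx in range(3):
            patch = xp[:, dy:dy + h, dx:dx + wd]        # input shifted by (dy-1, dx-1)
            out += np.einsum("oc,chw->ohw", w[:, :, dy, dx], patch)
    return out


def normalize(x, eps=1e-5):
    """Per-channel normalisation over the grid; a single-example stand-in for
    batch normalisation (no running statistics, no learned scale or shift)."""
    mean = x.mean(axis=(1, 2), keepdims=True)
    var = x.var(axis=(1, 2), keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def relu(x):
    return np.maximum(x, 0.0)


def residual_block(x, w1, w2):
    """conv -> norm -> ReLU -> conv -> norm, add the block input, then ReLU."""
    y = relu(normalize(conv3x3(x, w1)))
    y = normalize(conv3x3(y, w2))
    return relu(x + y)


rng = np.random.default_rng(0)
x = rng.standard_normal((32, 32, 32))                   # 32 channels on the 32x32 grid
blocks = [(0.05 * rng.standard_normal((32, 32, 3, 3)),
           0.05 * rng.standard_normal((32, 32, 3, 3))) for _ in range(9)]
for w1, w2 in blocks:                                    # nine residual blocks
    x = residual_block(x, w1, w2)
print(x.shape)  # (32, 32, 32)
```
<br />
Because each block adds its input back, the output keeps the input's shape, which is what allows nine blocks to be stacked.<br />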
<br />
<br />
<br />
The input to this network consists of:<br />
* Location of stones<br />
* Order of the stones by distance to the tee (the center of the house)<br />
* A 32x32 grid representation of the ice sheet, indicating which stones are present in each grid cell.<br />
<br />
The authors do not describe how the stone-based information is added to the 32x32 grid as input to the network.<br />
<br />
=== Policy Network ===<br />
<br />
The policy head is created by adding 2 convolutional layers with two 3x3 filters to the shared body of the network. The output of the policy head is a probability distribution over a 32x32x2 set of discrete actions, from which the best shot is selected. Each action represents a target location on the grid and a spin direction for the stone.<br />
<br />
[[File:policy-value-net.PNG | 700px]]<br />
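<br />
The mapping from the policy head's output to a concrete shot can be sketched as follows. The (spin, row, column) layout of the 2-filter output and the assignment of channel 0 to spin -1 and channel 1 to spin +1 are assumptions made for illustration; the paper's exact encoding is not described here.<br />
<br />
```python
import numpy as np

GRID, SPINS = 32, 2


def softmax(logits):
    z = logits - logits.max()          # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum()


def decode_action(index):
    """Flat index -> (row, col, spin), assuming a (spin, row, col) layout and
    spin channel 0 = -1, channel 1 = +1 (hypothetical convention)."""
    spin_channel, rest = divmod(index, GRID * GRID)
    row, col = divmod(rest, GRID)
    return row, col, (-1 if spin_channel == 0 else 1)


def best_discrete_action(policy_logits):
    """policy_logits: (2, 32, 32) array from the policy head."""
    probs = softmax(policy_logits.reshape(-1))   # distribution over 2048 discrete shots
    idx = int(np.argmax(probs))
    return decode_action(idx), float(probs[idx])
```
<br />
The discrete choice only fixes a grid cell and a spin; refining it into a continuous shot is the job of the search described below.<br />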
<br />
=== Value Network ===<br />
<br />
The value head is created by adding a convolutional layer with one 3x3 filter, followed by dense layers of 256 and 17 units, to the shared network. The 17 output units represent a probability distribution over the integer scores in the range [-8, 8], which are the possible scores of a single end of a curling game.<br />
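<br />
A 17-way score distribution can be summarized into an expected end score and the probability of scoring. This is a minimal sketch: it assumes the head ends in a softmax and that scores are from the perspective of the team to move; neither detail is stated explicitly above.<br />
<br />
```python
import numpy as np

END_SCORES = np.arange(-8, 9)      # the 17 possible one-end scores, -8 ... 8


def summarize_value_head(logits):
    """Turn 17 value-head logits into a score distribution and two summaries."""
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()
    expected_score = float(probs @ END_SCORES)
    p_positive = float(probs[END_SCORES > 0].sum())   # P(this team scores in the end)
    return probs, expected_score, p_positive
```
<br />
Keeping the full distribution, rather than a single expected value, is what later allows the winning-percentage table to reason about whole games.<br />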
<br />
== Continuous Action Search ==<br />
<br />
The policy head of the network only outputs actions from a discretized action space. For real-life interactions, and especially in curling, this will not suffice, as very fine adjustments to actions can make significant differences in outcomes.<br />
<br />
Actions in the continuous space are generated using an MCTS algorithm, with the following steps:<br />
<br />
=== Selection ===<br />
<br />
From a given state, the list of already-visited actions is denoted as A<sub>t</sub>. Scores and the number of visits to each node are estimated using the equations below (the first equation shows the expectation of the end value for one-end games). These are likely estimated rather than simply taken from the MCTS statistics to help account for the differences in a continuous action space.<br />
<br />
[[File:curling_kernel_equations.png | 400px]]<br />
<br />
The UCB formula is then used to select an action to expand.<br />
<br />
The actions that are taken in the simulator appear to be drawn from a Gaussian centered around <math>a_t</math>. This allows exploration in the continuous action space.<br />
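<br />
The selection step can be sketched assuming the equations in the figure take the usual kernel-regression form: <math>W(a)=\sum_b K(a,b)\,n_b</math>, <math>E[v|a]=\sum_b K(a,b)\,v_b\,n_b / W(a)</math>, and a UCB score <math>E[v|a] + C\sqrt{\log \sum_b W(b) / W(a)}</math>. The Gaussian kernel, the bandwidth and the constant <math>C</math> are illustrative assumptions, not values from the paper.<br />
<br />
```python
import numpy as np


def gaussian_kernel_matrix(a, b, bandwidth):
    """Pairwise K(a_i, b_j) for action arrays a: (n, d) and b: (m, d)."""
    sq = np.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
    return np.exp(-sq / (2.0 * bandwidth ** 2))


def kr_ucb_select(actions, mean_values, visits, c=1.0, bandwidth=0.5):
    """Pick a child among already-visited continuous actions (all visits >= 1)."""
    K = gaussian_kernel_matrix(actions, actions, bandwidth)
    W = K @ visits                                   # kernel-smoothed visit count W(a)
    expected = K @ (mean_values * visits) / W        # kernel-regressed value E[v | a]
    bonus = c * np.sqrt(np.log(W.sum()) / W)         # UCB exploration term
    return int(np.argmax(expected + bonus)), expected, W


def execute_with_noise(action, sigma, rng):
    """The simulator plays a shot drawn from a Gaussian centred on the chosen action."""
    return action + rng.normal(0.0, sigma, size=np.shape(action))
```
<br />
Because nearby actions share statistics through the kernel, a rarely visited action next to a well-explored one does not look completely unexplored.<br />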
<br />
=== Expansion ===<br />
<br />
The authors use a variant of regular UCT for expansion. In this case, they expand a new node only when existing nodes have been visited a certain number of times. The authors utilize a widening approach to overcome problems with standard UCT performing a shallow search when there is a large action space.<br />
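<br />
A common form of the widening idea is progressive widening, where a node may only hold a number of children that grows sublinearly with its visit count. The rule and the constants `c` and `alpha` below are a generic sketch, not the specific criterion used by the authors.<br />
<br />
```python
import math


def max_children(visits, c=1.0, alpha=0.5):
    """Progressive widening: a node visited `visits` times may hold
    ceil(c * visits**alpha) children (at least one)."""
    return max(1, math.ceil(c * visits ** alpha))


def should_add_new_action(visits, n_children, c=1.0, alpha=0.5):
    """True when the node has been visited often enough to try a new action;
    otherwise the search keeps refining the existing children."""
    return n_children < max_children(visits, c, alpha)
```
<br />
With `c=1` and `alpha=0.5`, a second child is allowed after 2 visits, a third after 5 and a fourth after 10, so the tree deepens instead of spreading across the whole action space.<br />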
<br />
=== Simulation ===<br />
<br />
Instead of simulating with a random game playout, the authors use the value network to estimate the likely score associated with a state. This speeds up simulation (assuming the network is well trained), as the game does not actually need to be simulated.<br />
<br />
=== Backpropagation ===<br />
<br />
Standard backpropagation is used, updating both the values and number of visits stored in the path of parent nodes.<br />
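<br />
The simulation and backpropagation steps together can be sketched as follows: the leaf is scored by a value function instead of a playout, and that score is added to every node on the path. `value_fn` is a hypothetical stand-in for the value network. For simplicity, values are assumed to be stored from one fixed team's perspective; a two-player implementation would typically flip signs or store per-player values.<br />
<br />
```python
from dataclasses import dataclass


@dataclass
class Node:
    visits: int = 0
    value_sum: float = 0.0

    @property
    def mean_value(self) -> float:
        return self.value_sum / self.visits if self.visits else 0.0


def evaluate_and_backpropagate(path, leaf_state, value_fn):
    """Score the leaf with the value network instead of a random playout,
    then update visit counts and value sums on every node along the path."""
    leaf_value = value_fn(leaf_state)
    for node in path:
        node.visits += 1
        node.value_sum += leaf_value
    return leaf_value
```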
<br />
<br />
== Supervised Learning ==<br />
<br />
During supervised training, data is gathered from the program AyumuGAT'16 ([8]). This program is also based on an MCTS algorithm and on a high-performance AI curling program. 400,000 state-action pairs were generated for this training.<br />
<br />
=== Policy Network ===<br />
<br />
The policy network was trained to learn the action taken in each state. Here, the likelihood of the taken action was set to be 1, and the likelihood of other actions to be 0.<br />
<br />
=== Value Network ===<br />
<br />
The value network was trained by 'd-depth simulations and bootstrapping of the prediction to handle the high variance in rewards resulting from a sequence of stochastic moves' (quote taken from paper). In this case, ''m'' state-action pairs were sampled from the training data. For each pair <math>(s_t, a_t)</math>, a state ''d'' steps ahead, <math>s_{t+d}</math>, was generated. This process dealt with uncertainty by treating all actions in this rollout as noise-free, except for the last action, ''a<sub>t+d-1</sub>''. The value network predicts the value of this state, <math>z_t</math>, and this prediction is used as the learning target for the value at ''s<sub>t</sub>''.<br />
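<br />
The bootstrapped target can be sketched with the simulator and the value network passed in as functions. `step`, `noisy_step` and `value_fn` are hypothetical placeholders for the curling simulator (without and with execution noise) and the value network; how the rollout actions are chosen is not specified above.<br />
<br />
```python
def bootstrapped_value_target(state, rollout_actions, step, noisy_step, value_fn):
    """Target for s_t: apply a_t ... a_{t+d-2} without execution noise, apply the
    last action a_{t+d-1} with noise, and return the value prediction at s_{t+d}."""
    *exact_actions, last_action = rollout_actions    # d actions in total
    for action in exact_actions:
        state = step(state, action)                  # treated as deterministic
    state = noisy_step(state, last_action)           # only the final shot is stochastic
    return value_fn(state)
```
<br />
Limiting the noise to one shot keeps the variance of the target low, while the value network supplies the rest of the game's outcome.<br />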
<br />
=== Policy-Value Network ===<br />
<br />
The policy-value network was trained to maximize the similarity of the predicted policy and value, and the actual policy and value from a state. The learning algorithm parameters are:<br />
<br />
* Algorithm: stochastic gradient descent<br />
* Batch size: 256<br />
* Momentum: 0.9<br />
* L2 regularization: 0.0001<br />
* Training time: ~100 epochs<br />
* Learning rate: initialized at 0.01, reduced twice<br />
<br />
A multi-task loss function was used. This takes the summation of the cross-entropy losses of each prediction:<br />
<br />
[[File:curling_loss_function.png | 300px]]<br />
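<br />
Assuming the loss in the figure has the usual form (cross-entropy between the target and predicted action distributions, plus cross-entropy between the target and predicted score distributions, plus an L2 penalty with the 0.0001 coefficient listed above), it can be sketched as below. In supervised learning the policy target is one-hot, as described earlier; the function names are illustrative.<br />
<br />
```python
import numpy as np


def log_softmax(logits):
    shifted = logits - logits.max()
    return shifted - np.log(np.exp(shifted).sum())


def policy_value_loss(policy_logits, value_logits, pi_target, z_target, params, l2=1e-4):
    """Cross-entropy(pi, p) + cross-entropy(z, v) + L2 weight penalty.
    pi_target: distribution over the 2048 discrete actions (one-hot in supervised
    learning); z_target: distribution over the 17 end scores."""
    policy_ce = -np.sum(pi_target * log_softmax(policy_logits))
    value_ce = -np.sum(z_target * log_softmax(value_logits))
    l2_term = l2 * sum(np.sum(w ** 2) for w in params)
    return float(policy_ce + value_ce + l2_term)
```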
<br />
== Self-Play Reinforcement Learning ==<br />
<br />
After initialization by supervised learning, the algorithm uses self-play to further train itself. During this training, the policy network learns probabilities from the MCTS process, while the value network learns from game outcomes.<br />
<br />
At a game state ''s<sub>t</sub>'':<br />
<br />
1) the algorithm outputs a prediction ''z<sub>t</sub>'', an estimate of the game score probabilities. It is based on similar past actions and computed using kernel regression.<br />
<br />
2) the algorithm outputs a prediction <math>\pi_t</math>, representing a probability distribution of actions. These are proportional to estimated visit counts from MCTS, based on kernel density estimation.<br />
<br />
It is not clear how these predictions are created. It would seem likely that the policy-value network generates these, but the wording of the paper suggests they are generated from MCTS statistics.<br />
<br />
The policy-value network is updated by sampling data <math>(s, \pi, z)</math> from recent history of self-play. The same loss function is used as before.<br />
<br />
It is not clear how the improved network is used, as MCTS seems to be the driving process at this point.<br />
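<br />
Under one reading of the description, the policy target is proportional to the kernel-estimated visit counts <math>W(a)</math> of the candidate actions, and training samples are drawn from a buffer of recent self-play positions (KR-DRL keeps the latest 1 million). The sketch below encodes that reading only; `ReplayBuffer` and its methods are hypothetical names.<br />
<br />
```python
import random
from collections import deque

import numpy as np


def policy_target_from_kernel_visits(W):
    """pi_t proportional to the kernel-estimated visit counts W(a) of candidate actions."""
    W = np.asarray(W, dtype=float)
    return W / W.sum()


class ReplayBuffer:
    """Keeps only the most recent `capacity` (s, pi, z) tuples from self-play."""

    def __init__(self, capacity):
        self.data = deque(maxlen=capacity)      # oldest positions drop out automatically

    def add(self, state, pi, z):
        self.data.append((state, pi, z))

    def sample(self, batch_size, rng=random):
        return rng.sample(list(self.data), min(batch_size, len(self.data)))
```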
<br />
== Long-Term Strategy Learning ==<br />
<br />
Finally, the authors implement a new strategy to augment their algorithm for long-term play. In this context, this refers to playing a game over many ends, where the strategy to win a single end may not be a good strategy to win a full game. For example, scoring one point in an end, while being one point ahead, gives the advantage to the other team in the next round (as they will throw the last stone). The other team could then use the advantage to score two points, taking the lead.<br />
<br />
The authors build a 'winning percentage' table, which stores the percentage of games won as a function of the number of ends left and the score difference (current team minus opposing team). The table can be computed iteratively from the estimated probability distribution of one-end scores.<br />
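<br />
The iterative computation can be sketched as dynamic programming over ends. This is a simplified illustration, not the authors' procedure: it assumes a fixed one-end score distribution for the team holding the hammer, the rule that the scoring team gives up the hammer (a blank end keeps it), score differences clipped to ±16, and a tie after the last end counted as half a win.<br />
<br />
```python
import numpy as np

SCORES = np.arange(-8, 9)   # one-end scores, from the point of view of the hammer team


def winning_table(p_end, n_ends, max_diff=16):
    """table[e, i, h]: probability that 'we' win with e ends left, score difference
    diffs[i] (we - they), and h = 1 if we hold the hammer (last stone).
    p_end[k]: probability that the hammer team scores SCORES[k] in one end."""
    diffs = np.arange(-max_diff, max_diff + 1)
    table = np.zeros((n_ends + 1, len(diffs), 2))
    # No ends left: win if ahead, lose if behind; a tie counts as 0.5
    # (extra ends are not modelled here).
    final = np.where(diffs > 0, 1.0, np.where(diffs < 0, 0.0, 0.5))
    table[0, :, 0] = final
    table[0, :, 1] = final
    for e in range(1, n_ends + 1):
        for i, d in enumerate(diffs):
            for h in (0, 1):
                total = 0.0
                for k, s in enumerate(SCORES):
                    gain = s if h == 1 else -s            # change in (we - they)
                    nd = int(np.clip(d + gain, -max_diff, max_diff))
                    # Assumed hammer rule: the scoring team gives up the hammer,
                    # a blank end (0) leaves it where it is.
                    nh = 0 if gain > 0 else (1 if gain < 0 else h)
                    total += p_end[k] * table[e - 1, nd + max_diff, nh]
                table[e, i, h] = total
    return diffs, table
```
<br />
Choosing a shot then amounts to maximizing the expected table entry after the end, rather than the expected score of the end itself.<br />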
<br />
== Final Algorithms ==<br />
<br />
The authors make use of the following versions of their algorithm:<br />
<br />
=== KR-DL ===<br />
<br />
''Kernel regression-deep learning'': This algorithm is trained only by supervised learning.<br />
<br />
=== KR-DRL ===<br />
<br />
''Kernel regression-deep reinforcement learning'': This algorithm is trained by supervised learning (i.e., initialized as the KR-DL algorithm), and again by self-play. During self-play, each shot is selected after 400 MCTS simulations of k=20 randomly selected actions. Data for self-play was collected over a week on 5 GPUs and generated 5 million game positions. The policy-value network was continually updated using samples from the latest 1 million game positions.<br />
<br />
=== KR-DRL-MES ===<br />
<br />
''Kernel regression-deep reinforcement learning-multi-ends-strategy'': This algorithm makes use of the winning percentage table generated from self-play.<br />
<br />
= Testing and Results =<br />
The authors use data from the public program AyumuGAT’16 to test. Testing is done with a simulated curling program [9]. This simulator does not deal with changing ice conditions, or sweeping, but does deal with stone trajectories and collisions.<br />
<br />
== Comparison of KR-DL-UCT and DL-UCT ==<br />
<br />
The first test compares an algorithm trained with kernel regression with an algorithm trained without kernel regression, to show the contribution that kernel regression adds to the performance. Both algorithms have networks initialized by supervised learning, and are then trained with two different self-play algorithms. KR-DL-UCT uses the algorithm described above. The authors do not go into detail on how DL-UCT selects shots, but state that a constant is set to allow exploration.<br />
<br />
As an evaluation, both algorithms play 2000 games against the DL-UCT algorithm frozen after supervised training: 1000 games with the evaluated algorithm taking the first shot, and 1000 with it taking the second. The games were two-end games. The figure below shows each algorithm's winning percentage for different amounts of training data. While DL-UCT outperforms the supervised-only DL-UCT baseline, KR-DL-UCT performs much better.<br />
<br />
[[File:curling_KR_test.png | 400px]]<br />
<br />
== Matches ==<br />
<br />
Finally, to test the performance of their algorithms, the authors run matches between their algorithms and other existing programs. Each algorithm plays 200 games against each other program, 100 as the first-playing team and 100 as the second-playing team. Only one program was able to outperform the KR-DRL algorithm: ''JiritsukunGAT'17'', which the authors state also uses a deep network and hand-crafted features. However, the KR-DRL-MES algorithm was still able to outperform it. Figure 4 shows the Elo ratings of the different programs; the programs in blue are those created by the authors. The authors also played some games between KR-DRL-MES and notable programs. Table 1 shows the details of the match results: ''JiritsukunGAT'17'' shows a similar level of performance, but KR-DRL-MES is still the winner.<br />
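<br />
To relate the winning percentages to the Elo ratings, the standard logistic Elo model gives the rating gap implied by a head-to-head winning rate. This is only the textbook formula; the exact procedure the authors used to fit the ratings in Figure 4 is not described here.<br />
<br />
```python
import math


def elo_difference(win_rate):
    """Elo gap implied by a head-to-head winning rate under the standard
    logistic Elo model: p = 1 / (1 + 10 ** (-diff / 400))."""
    if not 0.0 < win_rate < 1.0:
        raise ValueError("win rate must be strictly between 0 and 1")
    return 400.0 * math.log10(win_rate / (1.0 - win_rate))
```
<br />
For example, winning 150 of 200 games (75%) corresponds to a gap of about 191 Elo points.<br />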
<br />
<br />
<br />
[[File:curling_ratings.png|600px|thumb|center|Figure 4. Elo rating and winning percentages of our models and GAT rankers. Each match has 200 games (each program plays 100 pre-ordered games), because the player which has the last shot (the hammer shot) in each end would have an advantage.]]<br />
<br />
<br />
[[File:ttt.png|600px|thumb|center|Table 1. The 8-end game results for KR-DRL-MES against other programs alternating the opening player each game. The matches are held by following the rules of the latest GAT competition.]]<br />
<br />
= Critique =<br />
<br />
== Strengths ==<br />
<br />
This algorithm out-performs other high-performance algorithms (including past competition champions).<br />
<br />
I think the paper does a decent job of comparing the performance of their algorithm to others. They are able to clearly show the benefits of many of their additions.<br />
<br />
The authors do seem able to adapt strategies similar to those used in Go and other games to the continuous action-space domain. In addition, the final strategy needs no hand-crafted features for learning.<br />
<br />
== Weaknesses ==<br />
<br />
Sometimes, I found this paper difficult to follow. One problem was that the algorithms were introduced first, and only then was it described how they were used. So when the paper stated that self-play shots were taken after 400 simulations, it was unclear what simulations were being run and at what stage of the algorithm (e.g. MCTS simulations, simulations sped up by using the value network, or full simulations on the curling simulator). In particular, both the MCTS statistics and the policy-value network could be used to estimate action probabilities and state values, so it is difficult to tell which is used in which case. There was also no clear distinction between discrete-space actions and continuous-space actions.<br />
<br />
While I think the comparison of different algorithms was done well, I believe it still lacked some detail. There were one-off mentions in the paper that would have been nice to see as results, including the statement that having a single policy-value network in place of two networks led to better performance.<br />
<br />
At this point, the algorithms used still rely on initialization by a pre-made program.<br />
<br />
There was little theoretical development or justification done in this paper.<br />
<br />
While curling is an interesting choice for demonstrating the algorithm, the fact that the simulations used did not support many of the key aspects of curling (ice conditions, sweeping) seems quite limiting. Another game, such as pool, would likely have offered some of the same challenges, with higher-fidelity simulations available for training.<br />
<br />
While the spatial placements of stones were discretized on a grid, the curl of thrown stones was discretized to only +/-1. This may limit the learning of high- and low-spin shots. It should be noted that, to the best of my knowledge, shots with zero spin are not commonly used.<br />
<br />
=References=<br />
# Lee, K., Kim, S., Choi, J. & Lee, S. "Deep Reinforcement Learning in Continuous Action Spaces: a Case Study in the Game of Simulated Curling." Proceedings of the 35th International Conference on Machine Learning, in PMLR 80:2937-2946 (2018)<br />
# https://www.baeldung.com/java-monte-carlo-tree-search<br />
# https://jeffbradberry.com/posts/2015/09/intro-to-monte-carlo-tree-search/<br />
# https://int8.io/monte-carlo-tree-search-beginners-guide/<br />
# https://en.wikipedia.org/wiki/Monte_Carlo_tree_search<br />
# Silver, D., Huang, A., Maddison, C., Guez, A., Sifre, L., Van Den Driessche, G., Schrittwieser, J., Antonoglou, I., Panneershelvam, V., Lanctot, M., Dieleman, S., Grewe, D., Nham, J., Kalchbrenner, N., Sutskever, I., Lillicrap, T., Leach, M., Kavukcuoglu, K., Graepel, T., and Hassabis, D. Mastering the game of Go with deep neural networks and tree search. Nature, pp. 484–489, 2016.<br />
# Silver, D., Schrittwieser, J., Simonyan, K., Antonoglou, I., Huang, A., Guez, A., Hubert, T., Baker, L., Lai, M., Bolton, A., Chen, Y., Lillicrap, T., Hui, F., Sifre, L., van den Driessche, G., Graepel, T., and Hassabis, D. Mastering the game of Go without human knowledge. Nature, pp. 354–359, 2017.<br />
# Yamamoto, M., Kato, S., and Iizuka, H. Digital curling strategy based on game tree search. In Proceedings of the IEEE Conference on Computational Intelligence and Games, CIG, pp. 474–480, 2015.<br />
# Ohto, K. and Tanaka, T. A curling agent based on the Monte Carlo tree search considering the similarity of the best action among similar states. In Proceedings of Advances in Computer Games, ACG, pp. 151–164, 2017.<br />
# Ito, T. and Kitasei, Y. Proposal and implementation of digital curling. In Proceedings of the IEEE Conference on Computational Intelligence and Games, CIG, pp. 469–473, 2015.<br />
# Ioffe, S. and Szegedy, C. Batch normalization: Accelerating deep network training by reducing internal covariate shift. In Proceedings of the International Conference on Machine Learning, ICML, pp. 448–456, 2015.<br />
# Nair, V. and Hinton, G. Rectified linear units improve restricted boltzmann machines.</div>Gsahuhttp://wiki.math.uwaterloo.ca/statwiki/index.php?title=DETECTING_STATISTICAL_INTERACTIONS_FROM_NEURAL_NETWORK_WEIGHTS&diff=41959DETECTING STATISTICAL INTERACTIONS FROM NEURAL NETWORK WEIGHTS2018-11-30T00:08:53Z<p>Gsahu: /* Experiment */</p>
<hr />
<div>=Introduction=<br />
<br />
It has been commonly believed that one major advantage of neural networks is their capability of modelling complex statistical interactions between features for automatic feature learning. Statistical interactions capture important information on where features often have joint effects with other features on predicting an outcome. The discovery of interactions is especially useful for scientific discoveries and hypothesis validation. For example, physicists may be interested in understanding what joint factors provide evidence for new elementary particles; doctors may want to know what interactions are accounted for in risk prediction models, to compare against known interactions from existing medical literature.<br />
<br />
With the growth in available computational power, neural networks have been able to solve many complex tasks in a wide variety of fields, mainly due to their ability to model complex, non-linear interactions. However, neural networks have traditionally been treated as “black box” models, preventing their adoption in many application domains, such as those where explainability is desirable. It has been noted that complex machine learning models can learn unintended patterns from data, raising significant risks to stakeholders [14]. Therefore, in applications where machine learning models are intended for making critical decisions, such as healthcare or finance, it is paramount to understand how they make predictions [9]. In several areas, such as computational social science, interpretability is of utmost importance. Since we do not understand how a neural network reaches its decision, practitioners in these areas tend to prefer simpler, much more interpretable models such as linear regression and decision trees. In this paper, we present one way of adding interpretability to a neural network.<br />
<br />
Existing approaches to interpreting neural networks can be summarized into two types. One type is direct interpretation, which focuses on 1) explaining individual feature importance, for example by computing input gradients [13] and decomposing predictions [8], 2) developing attention-based models, which illustrate where neural networks focus during inference [11], and 3) providing model-specific visualizations, such as feature map and gate activation visualizations [15]. The other type is indirect interpretation, for example post-hoc interpretations of feature importance [12] and knowledge distillation to simpler interpretable models [10].<br />
<br />
In this paper, the authors propose Neural Interaction Detection (NID), which can detect any order or form of statistical interaction captured by a feedforward neural network by examining its weight matrices.<br />
<br />
Note that this paper only considers one specific type of neural network, the feedforward neural network. The authors suggest that the same methodology could be used to build interpretation methods for other types of networks as well.<br />
<br />
=Related Work=<br />
<br />
1. Interaction Detection approaches: <br />
* Conduct individual tests for all combinations of features, e.g. ANOVA and Additive Groves.<br />
* Define all interaction forms of interest, then find the important ones.<br />
- The paper's goal is to detect interactions without pre-specifying their functional forms. The proposed method accomplishes higher-order interaction detection, which has the benefit of avoiding a high false positive or false discovery rate.<br />
<br />
2. Interpretability: a lot of work has also been done in this area, and it can be divided into the following broad categories:<br />
* Feature Importance through Decomposition: methods like Input Gradient (Sundararajan et al., 2017) learn the importance of features through a gradient-based approach similar to backpropagation. Works like Li et al. (2017), Murdoch (2017) and Murdoch (2018) study the interpretability of LSTMs by looking at phrase- and word-level importance scores. Bach et al. (2015) and Shrikumar et al. (2016) (DeepLift) study pixel importance in CNNs.<br />
* Studying Visualizations in Models: Karpathy et al. (2015) worked with character-generating LSTMs and studied the activation and firing of certain hidden units for meaningful attributes. Yosinski et al. (2015) study feature map visualizations.<br />
* Attention-Based Models: these models, such as Bahdanau et al. (2014), use attention modules (of various architectures) to help the neural network decide which parts of the input to look at more closely or give more importance to. Inspecting the outputs of such models gives an indirect sense of interpretability.<br />
<br />
The approach in this paper is to extract non-additive interactions between variables from the neural network weights.<br />
<br />
=Notations=<br />
Before diving into the methodology, we define some notation. Most of it is standard.<br />
<br />
1. Vector: vectors are denoted by bold lowercase letters, '''v, w'''<br />
<br />
2. Matrix: matrices are denoted by bold uppercase letters, '''V, W'''<br />
<br />
3. Integer set: for an integer <math>p \in \mathbb{Z}</math>, we define <math>[p] := \{1, 2, 3, \ldots, p\}</math><br />
<br />
=Interaction=<br />
First of all, in order to explain the model, we need to be able to explain the interactions and their effects on the output. We therefore define an 'interaction' between variables as below.<br />
<br />
[[File:def_interaction.PNG|900px|center]]<br />
<br />
From the definition above, for a function such as <math>x_1x_2 + \sin(x_3 + x_4 + x_5)</math>, the interactions are <math>\{x_1, x_2\}</math> and <math>\{x_3, x_4, x_5\}</math>; the latter is called a 3-way interaction.<br />
<br />
Note that it follows from the definition that a d-way interaction can exist only if all of its corresponding (d-1)-way interactions exist. For example, the 3-way interaction above implies the 2-way interactions <math>\{x_3, x_4\}</math>, <math>\{x_4, x_5\}</math> and <math>\{x_3, x_5\}</math>.<br />
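<br />
For smooth functions, a pair of features interacts exactly when the mixed partial derivative with respect to the two features is not identically zero. The sketch below checks this numerically for the example function with central finite differences. It is an illustration of the definition only, not part of NID, and the step size, tolerance and number of random points are arbitrary assumptions.<br />
<br />
```python
import numpy as np


def f(x):
    """The example function x1*x2 + sin(x3 + x4 + x5) (0-based indices here)."""
    return x[0] * x[1] + np.sin(x[2] + x[3] + x[4])


def mixed_partial(func, x, i, j, h=1e-3):
    """Central finite-difference estimate of d^2 f / (dx_i dx_j) at x."""
    def shifted(di, dj):
        y = np.array(x, dtype=float)
        y[i] += di * h
        y[j] += dj * h
        return func(y)
    return (shifted(1, 1) - shifted(1, -1) - shifted(-1, 1) + shifted(-1, -1)) / (4 * h * h)


def pair_interacts(func, i, j, p=5, n_points=20, tol=1e-4, seed=0):
    """True if the mixed partial is clearly non-zero at some random point."""
    rng = np.random.default_rng(seed)
    return any(abs(mixed_partial(func, rng.uniform(-1, 1, p), i, j)) > tol
               for _ in range(n_points))
```
<br />
Here features 0 and 1 interact, 0 and 2 do not, and every pair among 2, 3 and 4 interacts, matching the 2-way interactions implied by the 3-way one.<br />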
<br />
One thing to keep in mind is that in models like neural networks, most interactions happen within the hidden layers. This means we need a proper way of measuring interaction strength.<br />
<br />
The key observation is that, for any interaction, there must be some hidden unit in some hidden layer that has all of the interacting features as ancestors. In graph-theoretic terms, the network can be viewed as a directed acyclic graph, and for any interaction <math>\Gamma \subseteq [p]</math>, there exists at least one vertex that has all features in <math>\Gamma</math> as ancestors. This statement is made rigorous as follows:<br />
<br />
<br />
[[File:prop2.PNG|900px|center]]<br />
<br />
The statement above guarantees that interactions can be measured at ANY hidden layer. For example, if we want to study interactions at a specific hidden layer, we now know that there exist corresponding vertices between that hidden layer and the output layer. All that remains is to find an appropriate measure that summarizes the information between those two layers.<br />
Before doing so, consider a single-hidden-layer neural network. A single hidden unit can have up to <math>2^{\|\mathbf{W}_{i,:}\|_0}</math> possible interactions, where <math>\|\mathbf{W}_{i,:}\|_0</math> is the number of non-zero weights feeding into unit <math>i</math>. The search space therefore grows very quickly, and is even larger for multi-layer networks, so a practical way to restrict it is needed. The authors achieve fast interaction detection by limiting the search to interactions created at the first hidden layer.<br />
[[File:network1.PNG|500px|center]]<br />
<br />
==Measuring influence in hidden layers==<br />
As discussed above, to consider interactions between units in any layer, we need to look at their outgoing paths. However, in a fully connected multi-layer network, the number of such paths is too large to compare directly. Instead, we use an upper bound on the gradient along the outgoing paths. To represent the influence of the outgoing paths of the <math>l</math>-th hidden layer, we define the cumulative impact of the weights connecting layer <math>l</math> to the output layer, called the aggregated weights:<br />
<br />
[[File:def3.PNG|900px|center]]<br />
<br />
<br />
Note that <math>z^{(l)} \in \mathbb{R}^{p_l}</math>, where <math>p_l</math> is the number of hidden units in layer <math>l</math>.<br />
Moreover, <math>z^{(l)}</math> upper-bounds the magnitude of the gradient of the output with respect to the units of layer <math>l</math> when the activation functions are 1-Lipschitz (e.g. ReLU). The gradient is an important measure of feature influence, especially since the derivative with respect to the input points in the direction normal to the decision boundary.<br />
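<br />
Assuming the definition in the figure is the product of absolute weight matrices, <math>z^{(l)} = |\mathbf{w}^y|^\top |\mathbf{W}^{(L)}| \cdots |\mathbf{W}^{(l+1)}|</math>, the aggregated weights of every hidden layer can be computed by one backward pass of matrix products. The function name and argument layout are illustrative choices.<br />
<br />
```python
import numpy as np


def aggregated_weights(weights, w_out):
    """z^(l) for every hidden layer l = 1..L.
    weights[k]: W^(k+1) with shape (p_{k+1}, p_k); W^(1) maps inputs to layer 1.
    w_out: (p_L,) weights from the last hidden layer to the scalar output."""
    z = np.abs(w_out)                      # z^(L) = |w^y|
    zs = [z]
    for W in reversed(weights[1:]):        # multiply by |W^(L)|, ..., |W^(2)|
        z = z @ np.abs(W)
        zs.append(z)
    return zs[::-1]                        # [z^(1), ..., z^(L)]
```
<br />
Taking absolute values before multiplying means weights of opposite sign cannot cancel, which is why the result is an upper bound rather than the exact gradient.<br />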
<br />
==Quantifying influence==<br />
For hidden unit <math>i</math> in the first hidden layer (the layer closest to the input), we define the strength of an interaction as<br />
<br />
[[File:measure1.PNG|900px|center]]<br />
<br />
The function <math>\mu</math> is discussed below. In words, the strength of an interaction at unit <math>i</math> is the product of that unit's aggregated weight and some measure of the unit's incoming weights from the interacting input features.<br />
<br />
For <math>\mu</math>, any positive real-valued function, such as max, min or average, is a candidate. The effects of these choices are tested later.<br />
<br />
Based on these definitions, the authors propose the following greedy algorithm for ranking interactions among the input features:<br />
<br />
[[File:algorithm1.PNG|850px|center]]<br />
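<br />
The greedy ranking can be sketched as follows: for each first-layer unit, input features are added in order of decreasing absolute weight, and every resulting candidate set of size two or more accumulates the strength <math>z_i \, \mu(|\mathbf{W}_{i,\mathcal{I}}|)</math>. This follows our reading of the algorithm in the figure, with <math>\mu</math> defaulting to the minimum; names and data structures are illustrative.<br />
<br />
```python
from collections import defaultdict

import numpy as np


def nid_rank(W1, z1, mu=np.min):
    """Greedy interaction ranking in the spirit of the algorithm above.
    W1: (p_1, p) first-layer weights; z1: (p_1,) aggregated weights of layer 1.
    For each hidden unit, features are added in order of decreasing |weight|, and
    every candidate set of size >= 2 receives strength z1[i] * mu(|W1[i, I]|)."""
    strengths = defaultdict(float)
    abs_w = np.abs(W1)
    for i in range(abs_w.shape[0]):
        order = np.argsort(-abs_w[i])                 # strongest input weights first
        for k in range(2, abs_w.shape[1] + 1):
            subset = frozenset(int(j) for j in order[:k])
            strengths[subset] += z1[i] * mu(abs_w[i, sorted(subset)])
    return sorted(strengths.items(), key=lambda kv: kv[1], reverse=True)
```
<br />
Each unit contributes only <math>p-1</math> candidate sets instead of all <math>2^p</math> subsets, which is what keeps the search fast.<br />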
<br />
=Cut off Model=<br />
Using the greedy algorithm above, we can rank the interactions by strength. However, to decide which of the ranked interactions are true interactions, the authors build a cutoff model, a generalized additive model (GAM) with interaction terms:<br />
<br />
<center><math><br />
c_K(\mathbf{x}) = \sum_{i=1}^{p} g_i(x_i) + \sum_{i=1}^{K} g_i^\prime(\mathbf{x}_{\mathcal{I}_i})<br />
</math></center><br />
<br />
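In this model, each <math>g_i</math> and <math>g_i^\prime</math> is a feedforward neural network, and <math>\mathcal{I}_i</math> is the <math>i</math>-th ranked interaction. Interactions are added in order of decreasing strength until the model's performance plateaus.<br />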
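<br />
The plateau criterion can be sketched with a simple relative-improvement rule on validation errors. The rule and the `rel_tol` threshold are hypothetical: the text above only says interactions are added until performance plateaus.<br />
<br />
```python
def choose_cutoff(val_errors, rel_tol=0.01):
    """val_errors[K]: validation error of the cutoff GAM c_K with the top-K
    interactions (K = 0 is the main-effects-only model). Returns the smallest K
    after which adding the next interaction improves the error by less than
    rel_tol (relative). Assumes all errors are positive."""
    for K in range(len(val_errors) - 1):
        gain = (val_errors[K] - val_errors[K + 1]) / val_errors[K]
        if gain < rel_tol:
            return K
    return len(val_errors) - 1
```
<br />
For validation errors of 1.0, 0.6, 0.4, 0.398 and 0.397, this keeps the top 2 interactions.<br />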
<br />
=Experiment=<br />
For the experiment, the authors compare three neural network models with traditional statistical interaction detection algorithms. The first neural network model is a plain MLP; the second is MLP-M, an MLP with additional univariate networks at the output; the third is the cutoff model defined above, denoted MLP-Cutoff. The MLP-M model is shown below.<br />
<br />
[[File:output11.PNG|300px|center]]<br />
<br />
The authors study their interaction detection framework on both simulated and real-world data. The simulated experiments use the 10 synthetic functions shown in Table 1.<br />
<br />
[[File:synthetic.PNG|900px|center]]<br />
<br />
We use four real-world datasets, of which two are regression datasets and the other two are binary classification datasets. The datasets are a mixture of common prediction tasks in the cal housing and bike sharing datasets, a scientific discovery task in the higgs boson dataset, and an example of very-high order interaction detection in the letter dataset.<br />
<br />
The authors also report comparisons between the models. On average, the neural-network-based methods perform better. Compared to traditional methods such as ANOVA, the MLP and MLP-M methods show a 20% increase in performance.<br />
<br />
[[File:performance_mlpm.PNG|900px|center]]<br />
<br />
<br />
[[File:performance2_mlpm.PNG|900px|center]]<br />
<br />
The results above show that MLP-M almost perfectly captures the most influential pairwise interactions.<br />
<br />
=Limitations=<br />
Even though the MLP-based methods performed well in the synthetic experiments above, the method still has limitations. For example, for a function like <math>x_1x_2 + x_2x_3 + x_1x_3</math>, the neural network cannot distinguish these interlinked pairwise interactions from a single higher-order interaction. Moreover, correlation between features degrades the network's ability to distinguish interactions, although correlation is a problem for most interaction detection algorithms.<br />
<br />
Because this method relies on the neural network fitting the data well, there are some additional concerns. Notably, if the network underfits or overfits, the detected interactions will be flawed. This can happen when the dataset is too small or too noisy, which is common in practice.<br />
<br />
=Conclusion=<br />
Here we presented a method for detecting interactions using an MLP. Compared to other state-of-the-art methods such as Additive Groves (AG), its performance is competitive while requiring far less computation. The method is therefore likely to be useful for practitioners with limited computational resources. Moreover, the NID algorithm greatly reduces the size of the interaction search. Most importantly, the algorithm lets users of neural networks add interpretability to their models, which could make neural networks considerably more usable for practitioners outside machine learning and deep learning.<br />
For future work, the authors want to detect feature interactions by using the common units in the intermediate hidden layers of feedforward networks, and also want to use such interaction detection to interpret weights in other deep neural networks.<br />
<br />
=Critique=<br />
1. The authors should run large-scale experiments, instead of only experimenting on synthetic datasets with small feature dimensionality, to strengthen their claims.<br />
<br />
2. Although the method proposed in this paper is interesting, the paper would benefit from providing some more explanations to support its idea and fill the possible gaps in its experimental evaluation. In some parts there are repetitive explanations that could be replaced by other essential clarifications.<br />
<br />
=Reference=<br />
<br />
[1] Jacob Bien, Jonathan Taylor, and Robert Tibshirani. A lasso for hierarchical interactions. Annals of statistics, 41(3):1111, 2013. <br />
<br />
[2] G David Garson. Interpreting neural-network connection weights. AI Expert, 6(4):46–51, 1991.<br />
<br />
[3] Yotam Hechtlinger. Interpretation of prediction models using the input gradient. arXiv preprint arXiv:1611.07634, 2016.<br />
<br />
[4] Shiyu Liang and R Srikant. Why deep neural networks for function approximation? 2016. <br />
<br />
[5] David Rolnick and Max Tegmark. The power of deeper networks for expressing natural functions. International Conference on Learning Representations, 2018. <br />
<br />
[6] Daria Sorokina, Rich Caruana, and Mirek Riedewald. Additive groves of regression trees. Machine Learning: ECML 2007, pp. 323–334, 2007.<br />
<br />
[7] Simon Wood. Generalized additive models: an introduction with R. CRC press, 2006<br />
<br />
[8] Sebastian Bach, Alexander Binder, Grégoire Montavon, Frederick Klauschen, Klaus-Robert Müller, and Wojciech Samek. On pixel-wise explanations for non-linear classifier decisions by layer-wise relevance propagation. PloS one, 10(7):e0130140, 2015.<br />
<br />
[9] Rich Caruana, Yin Lou, Johannes Gehrke, Paul Koch, Marc Sturm, and Noemie Elhadad. Intelligible models for healthcare: Predicting pneumonia risk and hospital 30-day readmission. In Proceedings of the 21th ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, pp. 1721–1730. ACM, 2015.<br />
<br />
[10] Zhengping Che, Sanjay Purushotham, Robinder Khemani, and Yan Liu. Interpretable deep models for ICU outcome prediction. In AMIA Annual Symposium Proceedings, volume 2016, pp. 371. American Medical Informatics Association, 2016.<br />
<br />
[11] Laurent Itti, Christof Koch, and Ernst Niebur. A model of saliency-based visual attention for rapid scene analysis. IEEE Transactions on Pattern Analysis and Machine Intelligence, 20(11):1254–1259, 1998.<br />
<br />
[12] Marco Tulio Ribeiro, Sameer Singh, and Carlos Guestrin. Why should I trust you?: Explaining the predictions of any classifier. In Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, pp. 1135–1144. ACM, 2016.<br />
<br />
[13] Karen Simonyan, Andrea Vedaldi, and Andrew Zisserman. Deep inside convolutional networks: Visualising image classification models and saliency maps. arXiv preprint arXiv:1312.6034, 2013.<br />
<br />
[14] Kush R Varshney and Homa Alemzadeh. On the safety of machine learning: Cyber-physical systems, decision sciences, and data products. arXiv preprint arXiv:1610.01256, 2016.<br />
<br />
[15] Jason Yosinski, Jeff Clune, Anh Nguyen, Thomas Fuchs, and Hod Lipson. Understanding neural networks through deep visualization. arXiv preprint arXiv:1506.06579, 2015.</div>Gsahuhttp://wiki.math.uwaterloo.ca/statwiki/index.php?title=DETECTING_STATISTICAL_INTERACTIONS_FROM_NEURAL_NETWORK_WEIGHTS&diff=41958DETECTING STATISTICAL INTERACTIONS FROM NEURAL NETWORK WEIGHTS2018-11-30T00:08:28Z<p>Gsahu: /* Cut off Model */</p>
<hr />
<div>=Introduction=<br />
<br />
It is commonly believed that one major advantage of neural networks is their capability of modelling complex statistical interactions between features for automatic feature learning. Statistical interactions capture where features have joint effects on predicting an outcome, beyond their individual effects. The discovery of interactions is especially useful for scientific discovery and hypothesis validation. For example, physicists may be interested in understanding what joint factors provide evidence for new elementary particles; doctors may want to know what interactions are accounted for in risk prediction models, to compare them against known interactions from the existing medical literature.<br />
<br />
With the growth in available computational power, neural networks have been able to solve many complex tasks in a wide variety of fields, mainly because of their ability to model complex, non-linear interactions. However, neural networks have traditionally been treated as “black box” models, which has prevented their adoption in many application domains, such as those where explainability is desirable. It has been noted that complex machine learning models can learn unintended patterns from data, raising significant risks to stakeholders [14]. Therefore, in applications where machine learning models are intended for making critical decisions, such as healthcare or finance, it is paramount to understand how they make predictions [9]. In several areas, such as computational social science, interpretability is of utmost importance. Because it is not clear how a neural network arrives at its decision, practitioners in these areas tend to prefer simpler, more interpretable models such as linear regression and decision trees. The paper presents one way of making a neural network interpretable.<br />
<br />
Existing approaches to interpreting neural networks can be summarized into two types. One type is direct interpretation, which focuses on 1) explaining individual feature importance, for example by computing input gradients [13] and decomposing predictions [8], 2) developing attention-based models, which illustrate where neural networks focus during inference [11], and 3) providing model-specific visualizations, such as feature map and gate activation visualizations [15]. The other type is indirect interpretation, for example post-hoc interpretations of feature importance [12] and knowledge distillation to simpler interpretable models [10].<br />
<br />
In this paper, the authors propose Neural Interaction Detection (NID), which detects statistical interactions of any order or form captured by a feedforward neural network by examining its weight matrices.<br />
<br />
Note that the paper considers only one type of neural network, the feedforward neural network. The authors suggest that the same methodology could be used to build interpretation methods for other types of networks.<br />
<br />
=Related Work=<br />
<br />
1. Interaction detection approaches: <br />
* Conducting individual tests for every combination of features, as in ANOVA and Additive Groves.<br />
* Defining all interaction forms of interest in advance and then identifying the important ones.<br />
The goal of this paper is to detect interactions without having to specify their functional forms. The proposed method also detects higher-order interactions directly, which has the benefit of avoiding high false positive or false discovery rates.<br />
<br />
2. Interpretability: A lot of work has also been done in this area, and it can be divided into the following broad categories:<br />
* Feature importance through decomposition: Input-gradient methods (e.g. Sundararajan et al., 2017) estimate the importance of features from gradients of the output with respect to the input, computed by backpropagation. Li et al. (2017), Murdoch (2017) and Murdoch (2018) study the interpretability of LSTMs through phrase-level and word-level importance scores. Bach et al. (2015) and Shrikumar et al. (2016) (DeepLIFT) study pixel importance in CNNs.<br />
* Visualizations of model internals: Karpathy et al. (2015) studied character-generating LSTMs and examined how activations of certain hidden units correspond to meaningful attributes. Yosinski et al. (2015) study feature map visualizations.<br />
* Attention-based models: These models (e.g. Bahdanau et al., 2014) use attention modules of various architectures to help the network decide which parts of the input to focus on or give more importance to. Inspecting where such a model attends gives an indirect sense of interpretability.<br />
<br />
The approach in this paper is to extract non-additive interactions between variables from the neural network weights.<br />
<br />
=Notations=<br />
Before describing the methodology, we define some notation, most of which is standard.<br />
<br />
1. Vectors: Vectors are denoted by bold lowercase letters, such as '''v''' and '''w'''.<br />
<br />
2. Matrices: Matrices are denoted by bold uppercase letters, such as '''V''' and '''W'''.<br />
<br />
3. Integer sets: For a positive integer <math>p \in \mathbb{Z}</math>, we define <math>[p] := \{1, 2, \ldots, p\}</math>.<br />
<br />
=Interaction=<br />
First, to explain the model, we need to be able to explain the interactions between variables and their effects on the output. We therefore define an interaction between variables as follows. <br />
<br />
[[File:def_interaction.PNG|900px|center]]<br />
<br />
By this definition, a function such as <math>x_1x_2 + \sin(x_3 + x_4 + x_5)</math> has the interactions <math>\{x_1, x_2\}</math> and <math>\{x_3, x_4, x_5\}</math>. The latter is called a 3-way interaction.<br />
<br />
Note that the definition implies that a d-way interaction can exist only if all of its corresponding (d-1)-way interactions exist. For example, the 3-way interaction above implies the 2-way interactions <math>\{x_3, x_4\}</math>, <math>\{x_4, x_5\}</math> and <math>\{x_3, x_5\}</math>.<br />
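<br />
For a twice-differentiable function, a pairwise interaction between <math>x_i</math> and <math>x_j</math> shows up as a nonzero mixed partial derivative <math>\partial^2 f / \partial x_i \partial x_j</math>. The sketch below is not part of NID; it is only a numerical sanity check of the example above using central finite differences at one arbitrary point (the helper name mixed_partial, the evaluation point and the step size are our own choices). Pairs inside the same group have nonzero mixed partials, while pairs across the two groups do not.<br />
<br />
```python
import math

def f(x):
    # Example function from the text, with 0-based indices: x1*x2 + sin(x3 + x4 + x5)
    return x[0] * x[1] + math.sin(x[2] + x[3] + x[4])

def mixed_partial(func, x, i, j, h=1e-4):
    """Central finite-difference estimate of d^2 func / (dx_i dx_j) at point x."""
    def shifted(di, dj):
        y = list(x)
        y[i] += di
        y[j] += dj
        return func(y)
    return (shifted(h, h) - shifted(h, -h) - shifted(-h, h) + shifted(-h, -h)) / (4 * h * h)

x0 = [0.3, -0.7, 0.2, 0.5, 0.1]  # arbitrary evaluation point
for i, j in [(0, 1), (0, 2), (2, 3), (3, 4)]:
    print(f"x{i + 1}, x{j + 1}: {mixed_partial(f, x0, i, j):+.4f}")
```
<br />
The pair (x1, x2) gives about 1, (x1, x3) gives about 0, and (x3, x4) and (x4, x5) give about <math>-\sin(0.8)</math>.<br />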
<br />
One thing to keep in mind is that in models such as neural networks, most interactions are created within the hidden layers. This means that we need a proper way of measuring interaction strength at hidden units.<br />
<br />
The key observation is that, for any interaction, there is some hidden unit in some hidden layer that has all of the interacting features as ancestors. In graph-theoretic language, the network can be viewed as a directed graph, and for any interaction <math>\Gamma \subseteq [p]</math>, there exists at least one vertex that has all features of <math>\Gamma</math> as ancestors. This statement can be made rigorous as follows:<br />
<br />
<br />
[[File:prop2.PNG|900px|center]]<br />
<br />
The statement above guarantees that interaction strengths can be measured at hidden layers. For example, to study the interactions created at a specific hidden layer, we only need to consider the paths from that layer's units to the output layer, so all that remains is to find an appropriate measure that summarizes the influence along those paths.<br />
Before doing so, consider a single-hidden-layer neural network. A single hidden unit <math>i</math> can model up to <math>2^{\|\mathbf{W}_{i,:}\|_0}</math> interactions, one for each subset of the inputs with nonzero weights into that unit. For multi-layer networks this search space can become far too large, so we need a reasonable way to reduce it. In practice, the authors achieve fast interaction detection by limiting the search to quantifying interactions created at the first hidden layer.<br />
[[File:network1.PNG|500px|center]]<br />
<br />
==Measuring influence in hidden layers==<br />
As discussed above, to measure interactions between units in any layer, we need to consider their outgoing paths. However, in a fully connected multi-layer neural network the number of such paths can be too large to compare directly. Therefore, we use an upper bound on the gradients along the outgoing paths. To represent the influence of the outgoing paths of the <math>l</math>-th hidden layer, we define the cumulative impact of the weights between layer <math>l+1</math> and the output layer, called the aggregated weights, as<br />
<br />
[[File:def3.PNG|900px|center]]<br />
<br />
<br />
Note that <math>z^{(l)} \in \mathbb{R}^{p_l}</math>, where <math>p_l</math> is the number of hidden units in layer <math>l</math>.<br />
Moreover, when the activation functions are 1-Lipschitz (e.g. ReLU), <math>z^{(l)}</math> is an elementwise upper bound on the magnitude of the gradient of the output with respect to the units of layer <math>l</math>. Gradients are an important tool for measuring the influence of features; in particular, the gradient with respect to the input gives the direction normal to the decision boundary.<br />
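<br />
A minimal NumPy sketch of the aggregated weights, assuming the recursion <math>z^{(L)} = |\mathbf{w}^y|^\top</math> and <math>z^{(l)} = z^{(l+1)}|\mathbf{W}^{(l+1)}|</math> (our reading of the definition above, with absolute values taken elementwise), weight matrices of shape <math>p_l \times p_{l-1}</math>, a linear output layer and no biases. The layer sizes, random weights and function names are hypothetical. The second function backpropagates through a ReLU network to check the gradient bound numerically.<br />
<br />
```python
import numpy as np

def aggregated_weights(weights, w_y):
    """Aggregated weights [z^(1), ..., z^(L)] of a feedforward network.

    weights: [W1, ..., WL], where W_l has shape (p_l, p_{l-1})
    w_y:     final linear output weights, shape (p_L,)
    """
    z = [np.abs(w_y)]                        # z^(L) = |w^y|
    for W in reversed(weights[1:]):          # W^(L), ..., W^(2)
        z.insert(0, z[0] @ np.abs(W))        # z^(l) = z^(l+1) |W^(l+1)|
    return z

def grad_wrt_first_hidden(weights, w_y, x):
    """dy/dh^(1) for y = w_y . h^(L), h^(l) = relu(W_l h^(l-1)), no biases."""
    pre, h = [], x
    for W in weights:                        # forward pass, keeping pre-activations
        a = W @ h
        pre.append(a)
        h = np.maximum(a, 0.0)
    g = w_y                                  # dy/dh^(L)
    for k in range(len(weights) - 1, 0, -1):
        g = (g * (pre[k] > 0)) @ weights[k]  # dy/dh^(k) from dy/dh^(k+1)
    return g

rng = np.random.default_rng(0)
sizes = [5, 8, 6, 4]                         # 5 inputs, hidden layers of 8, 6, 4 units
weights = [rng.normal(size=(sizes[k + 1], sizes[k])) for k in range(len(sizes) - 1)]
w_y = rng.normal(size=sizes[-1])
z = aggregated_weights(weights, w_y)
print([zl.shape for zl in z])                # [(8,), (6,), (4,)]

x = rng.normal(size=sizes[0])
print(np.all(np.abs(grad_wrt_first_hidden(weights, w_y, x)) <= z[0] + 1e-12))  # True
```
<br />
The bound holds because ReLU derivatives lie in {0, 1}, so each backpropagated gradient entry is at most the corresponding sum of absolute weight products.<br />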
<br />
==Quantifying influence==<br />
For hidden unit <math>i</math> in the first hidden layer (the layer closest to the input), we define the strength of an interaction as<br />
<br />
[[File:measure1.PNG|900px|center]]<br />
<br />
The function <math>\mu</math> is discussed below. Essentially, the strength of an interaction is the product of the aggregated weight of the first-hidden-layer unit and a summary <math>\mu</math> of the magnitudes of the weights connecting the interacting input features to that unit. <br />
<br />
Any positive real-valued function, such as the maximum, minimum or average, is a candidate for <math>\mu</math>; the authors examine the effects of these choices experimentally.<br />
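<br />
The interaction strength at a single first-layer unit can be sketched as follows, assuming the formula above has the form <math>\omega_i(\Gamma) = z^{(1)}_i \, \mu(|\mathbf{W}^{(1)}_{i,\Gamma}|)</math>, i.e. the unit's aggregated weight times a summary of the magnitudes of its incoming weights from the features in <math>\Gamma</math>. The weights below are made-up numbers and the function name is hypothetical. With <math>\mu = \min</math>, an interaction is only as strong as its weakest incoming connection.<br />
<br />
```python
import numpy as np

def interaction_strength(W1, z1, i, interaction, mu=np.min):
    """omega_i(Gamma) = z1[i] * mu(|W1[i, Gamma]|) for first-hidden-layer unit i.

    W1: input-to-first-hidden-layer weights, shape (p_1, p)
    z1: aggregated weights of the first hidden layer, shape (p_1,)
    interaction: 0-based indices of the input features in the candidate interaction
    mu: positive real-valued summary of weight magnitudes (np.min, np.max, np.mean, ...)
    """
    w = np.abs(W1[i, list(interaction)])
    return float(z1[i] * mu(w))

W1 = np.array([[2.0, -0.5, 0.1],
               [0.3, 1.0, -1.2]])
z1 = np.array([3.0, 0.5])
print(interaction_strength(W1, z1, 0, (0, 1)))              # 3.0 * min(2.0, 0.5) = 1.5
print(interaction_strength(W1, z1, 0, (0, 1), mu=np.max))   # 3.0 * max(2.0, 0.5) = 6.0
print(interaction_strength(W1, z1, 0, (0, 1), mu=np.mean))  # 3.0 * 1.25 = 3.75
```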
<br />
Based on these definitions, the authors propose the following algorithm for ranking influential interactions between input features:<br />
<br />
[[File:algorithm1.PNG|850px|center]]<br />
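<br />
A compact sketch of the greedy ranking, based on our reading of the algorithm above: for each first-layer hidden unit, sort its incoming weights by magnitude and score only the top-2, top-3, ..., top-p feature sets, which reduces the candidates per unit from <math>2^p</math> subsets to <math>p-1</math>; each candidate's strength is summed over hidden units. The default <math>\mu = \min</math>, the function name and the toy weights are assumptions for illustration.<br />
<br />
```python
from collections import defaultdict

import numpy as np

def rank_interactions(W1, z1, mu=np.min):
    """Greedy ranking of interaction candidates from first-layer weights.

    W1: shape (p_1, p); z1: aggregated weights of the first hidden layer, shape (p_1,).
    Returns (interaction, strength) pairs sorted by decreasing strength.
    """
    p1, p = W1.shape
    absW = np.abs(W1)
    strength = defaultdict(float)
    for i in range(p1):
        order = np.argsort(-absW[i])                  # features, largest |weight| first
        for k in range(2, p + 1):
            I = tuple(sorted(order[:k].tolist()))     # top-k features of unit i
            strength[I] += float(z1[i] * mu(absW[i, list(I)]))
    return sorted(strength.items(), key=lambda kv: kv[1], reverse=True)

W1 = np.array([[3.0, 2.5, 0.1, 0.0],     # unit 0: features 0 and 1 dominate
               [0.0, 0.1, 2.0, 1.8]])    # unit 1: features 2 and 3 dominate
z1 = np.array([1.0, 1.0])
ranked = rank_interactions(W1, z1)
for I, s in ranked[:2]:
    print(I, s)                          # (0, 1) 2.5, then (2, 3) 1.8
```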
<br />
=Cut off Model=<br />
Using the greedy algorithm above, we can rank interactions by their strength. However, to decide which of the ranked interactions are true interactions, the authors fit a cutoff model, which is a generalized additive model (GAM) of the form<br />
<br />
<center><math><br />
c_K(\mathbf{x}) = \sum_{i=1}^{p} g_i(x_i) + \sum_{i=1}^{K} g_i'(\mathbf{x}_{\Gamma_i})<br />
</math></center><br />
<br />
Here each <math>g_i</math> and <math>g_i'</math> is a feedforward neural network, and <math>\Gamma_i</math> is the <math>i</math>-th highest-ranked interaction. Interactions are added one at a time (increasing <math>K</math>) until the performance of the model plateaus.<br />
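<br />
Choosing <math>K</math> can be sketched as scanning the validation errors of <math>c_K</math> for <math>K = 0, 1, 2, \ldots</math> and stopping when adding the next interaction no longer helps. The relative-improvement threshold below is a hypothetical plateau rule, not necessarily the authors' exact criterion; fitting the networks <math>g_i</math> and <math>g_i'</math> is omitted, and the error values in the usage line are made up.<br />
<br />
```python
def choose_cutoff(val_errors, rel_tol=0.01):
    """Pick K for the cutoff model c_K from validation errors.

    val_errors[K] is the validation error of c_K, the GAM with the top-K ranked
    interactions added (K = 0 is the main-effects-only model).
    Hypothetical rule: stop at the first K whose next interaction improves the
    error by less than rel_tol (relative to the current error).
    """
    for K in range(len(val_errors) - 1):
        current, nxt = val_errors[K], val_errors[K + 1]
        if current - nxt < rel_tol * current:
            return K
    return len(val_errors) - 1

# Made-up errors: the third and later interactions barely help
print(choose_cutoff([1.00, 0.60, 0.35, 0.349, 0.348]))  # 2
```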
<br />
=Experiment=<br />
In the experiments, three neural network models are compared with traditional statistical interaction detection algorithms. The first neural network model is a standard MLP; the second is MLP-M, an MLP with additional univariate networks summed at the output; the third is the cutoff model defined above, denoted MLP-cutoff. MLP-M is represented graphically below.<br />
<br />
[[File:output11.PNG|300px|center]]<br />
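<br />
A minimal NumPy forward pass for MLP-M, based on our reading of the figure: a standard MLP over all <math>p</math> features plus <math>p</math> small univariate networks, one per feature, whose outputs are summed. Layer sizes, ReLU activations, the random initialization and the function names are hypothetical, and training is omitted. Since each univariate network sees only one feature, it can model main effects but not interactions; the apparent intent is to let main effects be absorbed there, so that the main MLP, whose first-layer weights NID inspects, concentrates on interactions.<br />
<br />
```python
import numpy as np

def init_mlp(rng, n_in, hidden):
    """Random parameters for a ReLU MLP with a scalar linear output (hypothetical init)."""
    params, prev = [], n_in
    for n in hidden:
        params.append((rng.normal(scale=0.5, size=(n, prev)), np.zeros(n)))
        prev = n
    params.append((rng.normal(scale=0.5, size=prev), 0.0))  # output weights and bias
    return params

def mlp(params, x):
    """Forward pass: ReLU hidden layers followed by a linear scalar output."""
    h = x
    for W, b in params[:-1]:
        h = np.maximum(W @ h + b, 0.0)
    w_out, b_out = params[-1]
    return float(w_out @ h + b_out)

def mlp_m(main, univariate, x):
    """MLP-M: main MLP on all p features plus one univariate MLP per feature, summed."""
    return mlp(main, x) + sum(mlp(g, x[i:i + 1]) for i, g in enumerate(univariate))

rng = np.random.default_rng(0)
p = 5
main = init_mlp(rng, p, [16, 8])                        # standard MLP over all features
univariate = [init_mlp(rng, 1, [4]) for _ in range(p)]  # one small network per feature
x = rng.normal(size=p)
print(mlp_m(main, univariate, x))
```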
<br />
The interaction detection framework is evaluated on both simulated and real-world data. The simulated experiments use the 10 synthetic functions shown in Table 1.<br />
<br />
[[File:synthetic.PNG|900px|center]]<br />
<br />
Four real-world datasets are used: two regression datasets and two binary classification datasets. They cover common prediction tasks (the cal housing and bike sharing datasets), a scientific discovery task (the higgs boson dataset), and an example of very-high-order interaction detection (the letter dataset).<br />
<br />
The authors also report comparisons between the models. As the results below show, the neural-network-based models perform better on average: compared with traditional methods such as ANOVA, MLP and MLP-M show a 20% increase in performance.<br />
<br />
[[File:performance_mlpm.PNG|900px|center]]<br />
<br />
<br />
[[File:performance2_mlpm.PNG|900px|center]]<br />
<br />
The results above show that MLP-M almost perfectly captures the most influential pairwise interactions.<br />
<br />
=Limitations=<br />
Even though the MLP-based methods performed best in the synthetic experiments above, the method still has some limitations. For example, for a function such as <math>x_1x_2 + x_2x_3 + x_1x_3</math>, the neural network cannot distinguish these interlinked pairwise interactions from a single higher-order interaction. Moreover, correlation between features deteriorates the network's ability to distinguish interactions, although correlation is a problem for most interaction detection algorithms. <br />
<br />
Because this method relies on the neural network fitting the data well, there are some additional concerns. Notably, if the network does not fit the data appropriately (underfitting or overfitting), the detected interactions will be flawed. This can occur when datasets are too small or too noisy, which is common in practical settings. <br />
<br />
=Conclusion=<br />
This summary presented a method for detecting interactions using an MLP. Compared with other state-of-the-art methods such as Additive Groves (AG), its performance is competitive while requiring far less computation. The method should therefore be very useful for practitioners with comparatively limited computational resources. Moreover, the NID algorithm substantially reduces the size of the search. Most importantly, the algorithm lets users of neural networks add interpretability to their models, which could greatly increase the usefulness of neural networks for practitioners outside machine learning and deep learning.<br />
For future work, the authors want to detect feature interactions by using the common units in the intermediate hidden layers of feedforward networks, and also want to use such interaction detection to interpret weights in other deep neural networks.<br />
<br />
=Critique=<br />
1. The authors should run large-scale experiments, instead of only experimenting on synthetic datasets with small feature dimensionality, to make their claims stronger.<br />
<br />
2. Although the method proposed in this paper is interesting, the paper would benefit from more explanation supporting its ideas and filling possible gaps in its experimental evaluation. Some parts contain repetitive explanations that could be replaced by other essential clarifications.<br />
<br />
=Reference=<br />
<br />
[1] Jacob Bien, Jonathan Taylor, and Robert Tibshirani. A lasso for hierarchical interactions. Annals of Statistics, 41(3):1111, 2013. <br />
<br />
[2] G David Garson. Interpreting neural-network connection weights. AI Expert, 6(4):46–51, 1991.<br />
<br />
[3] Yotam Hechtlinger. Interpretation of prediction models using the input gradient. arXiv preprint arXiv:1611.07634, 2016.<br />
<br />
[4] Shiyu Liang and R Srikant. Why deep neural networks for function approximation? 2016. <br />
<br />
[5] David Rolnick and Max Tegmark. The power of deeper networks for expressing natural functions. International Conference on Learning Representations, 2018. <br />
<br />
[6] Daria Sorokina, Rich Caruana, and Mirek Riedewald. Additive groves of regression trees. Machine Learning: ECML 2007, pp. 323–334, 2007.<br />
<br />
[7] Simon Wood. Generalized additive models: an introduction with R. CRC Press, 2006.<br />
<br />
[8] Sebastian Bach, Alexander Binder, Grégoire Montavon, Frederick Klauschen, Klaus-Robert Müller, and Wojciech Samek. On pixel-wise explanations for non-linear classifier decisions by layer-wise relevance propagation. PLoS ONE, 10(7):e0130140, 2015.<br />
<br />
[9] Rich Caruana, Yin Lou, Johannes Gehrke, Paul Koch, Marc Sturm, and Noemie Elhadad. Intelligible models for healthcare: Predicting pneumonia risk and hospital 30-day readmission. In Proceedings of the 21th ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, pp. 1721–1730. ACM, 2015.<br />
<br />
[10] Zhengping Che, Sanjay Purushotham, Robinder Khemani, and Yan Liu. Interpretable deep models for ICU outcome prediction. In AMIA Annual Symposium Proceedings, volume 2016, pp. 371. American Medical Informatics Association, 2016.<br />
<br />
[11] Laurent Itti, Christof Koch, and Ernst Niebur. A model of saliency-based visual attention for rapid scene analysis. IEEE Transactions on Pattern Analysis and Machine Intelligence, 20(11):1254–1259, 1998.<br />
<br />
[12] Marco Tulio Ribeiro, Sameer Singh, and Carlos Guestrin. Why should I trust you?: Explaining the predictions of any classifier. In Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, pp. 1135–1144. ACM, 2016.<br />
<br />
[13] Karen Simonyan, Andrea Vedaldi, and Andrew Zisserman. Deep inside convolutional networks: Visualising image classification models and saliency maps. arXiv preprint arXiv:1312.6034, 2013.<br />
<br />
[14] Kush R Varshney and Homa Alemzadeh. On the safety of machine learning: Cyber-physical systems, decision sciences, and data products. arXiv preprint arXiv:1610.01256, 2016.<br />
<br />
[15] Jason Yosinski, Jeff Clune, Anh Nguyen, Thomas Fuchs, and Hod Lipson. Understanding neural networks through deep visualization. arXiv preprint arXiv:1506.06579, 2015.</div>Gsahuhttp://wiki.math.uwaterloo.ca/statwiki/index.php?title=User:Gsahu&diff=41951User:Gsahu2018-11-29T23:55:01Z<p>Gsahu: Created page with "Gaurav Sahu. You can visit https://demfier.github.io/ for more information"</p>
<hr />
<div>Gaurav Sahu.<br />
<br />
You can visit https://demfier.github.io/ for more information</div>Gsahuhttp://wiki.math.uwaterloo.ca/statwiki/index.php?title=a_neural_representation_of_sketch_drawings&diff=41949a neural representation of sketch drawings2018-11-29T23:47:39Z<p>Gsahu: /* Training */</p>
<hr />
<div><br />
== Introduction ==<br />
In this paper, the authors present a recurrent neural network, sketch-rnn, that can be used to construct stroke-based drawings. Besides new robust training methods, they also outline a framework for conditional and unconditional sketch generation.<br />
<br />
Neural networks have been heavily used as image generation tools. For example, Generative Adversarial Networks, Variational Inference, and Autoregressive models have been used. Most of those models are designed to generate pixels to construct images. However, people learn to draw using sequences of strokes, beginning when they are young. The authors propose a new generative model that creates vector images so that it might generalize abstract concepts in a manner more similar to how humans do. <br />
<br />
The model is trained with hand-drawn sketches as input sequences. The model is able to produce sketches in vector format. In the conditional generation model, they also explore the latent space representation for vector images and discuss a few future applications of this model. The model and dataset are now available as an open source project ([https://magenta.tensorflow.org/sketch_rnn link]).<br />
<br />
=== Terminology ===<br />
Pixel images, also referred to as raster or bitmap images, are files that encode image data as a grid of pixels. These are the most common image type, with extensions such as .png, .jpg and .bmp. <br />
<br />
Vector images are files that encode image data as paths between points. SVG and EPS file types are used to store vector images. <br />
<br />
For a visual comparison of raster and vector images, see this [https://www.youtube.com/watch?v=-Fs2t6P5AjY video]. In general, vector images are simpler and more abstract, whereas raster images are used to store detailed images. <br />
<br />
For this paper, the important distinction between the two is that the encoding of images in the model will be inherently more abstract because of the vector representation. The intuition is that generating abstract representations is more effective using a vector representation. <br />
<br />
== Related Work ==<br />
Earlier work generated images with similar stroke-based approaches, such as portrait drawing by Paul the Robot [26, 28] and reinforcement learning approaches [28] that discover a set of paintbrush strokes that best represents a given input photograph. These methods mostly mimic digitized photographs. There are also neural-network-based approaches, but they mostly deal with pixel images, and little work has been done on generating vector images. Some models use Hidden Markov Models [25] or Mixture Density Networks [2] to generate human sketches, continuous data points (e.g. modelling Chinese characters as sequences of pen stroke actions) or vectorized Kanji characters [9, 29].<br />
<br />
The model also allows us to explore the latent space representation of vector images. Previous works have explored similar latent spaces, for example by combining Sequence-to-Sequence models with a Variational Autoencoder to model sentences in a latent space, and by using probabilistic program induction to model the Omniglot dataset.<br />
<br />
The dataset used in this paper contains 50 million vector sketches. Earlier datasets include the Sketch dataset with 20k vector sketches, the Sketchy dataset with 70k vector sketches along with pixel images, and the ShadowDraw system, which used 30k raster images along with extracted vectorized features. All of these are comparatively small.<br />
<br />
== Major Contributions ==<br />
This paper makes the following major contributions:<br />
* A framework for both unconditional and conditional generation of vector images composed of a sequence of lines, built on a recurrent-neural-network-based generative model that can produce sketches of common objects in vector format.<br />
* A training procedure unique to vector images that makes training more robust.<br />
* A large dataset of hand-drawn vector images, released to encourage further development of generative modelling for vector images, together with an open source implementation of the model.<br />
<br />
== Methodology ==<br />
=== Dataset ===<br />
QuickDraw is a dataset of 50 million vector drawings collected by the online game [https://quickdraw.withgoogle.com/# Quick Draw!], in which players are asked to draw an object belonging to a particular class in less than 20 seconds. The dataset contains hundreds of classes; each class has 70k training samples, 2.5k validation samples and 2.5k test samples.<br />
<br />
Each sample is represented as a sequence of pen stroke actions, starting from the origin, the initial coordinate of the drawing. A sketch is a list of points, and each point consists of 5 elements <math> (\Delta x, \Delta y, p_{1}, p_{2}, p_{3})</math>, where <math>\Delta x</math> and <math>\Delta y</math> are the offsets in the x and y directions from the previous point. The parameters <math>p_{1}, p_{2}, p_{3}</math> form a one-hot vector over three possible pen states: <math>p_{1}</math> indicates that the pen is touching the paper (a line will be drawn to the next point), <math>p_{2}</math> indicates that the pen will be lifted after the current point, and <math>p_{3}</math> indicates that the drawing has ended.<br />
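<br />
To make the 5-element format concrete, here is a small converter, assuming the input is a list of strokes where each stroke is a list of absolute (x, y) pen positions. This input format and the function name are hypothetical; it is not the QuickDraw file format.<br />
<br />
```python
def to_stroke5(strokes):
    """Convert strokes (lists of absolute (x, y) points) to rows (dx, dy, p1, p2, p3)."""
    rows, prev_x, prev_y = [], 0.0, 0.0         # offsets start from the origin
    for stroke in strokes:
        for k, (x, y) in enumerate(stroke):
            dx, dy = x - prev_x, y - prev_y
            if k == len(stroke) - 1:
                rows.append((dx, dy, 0, 1, 0))  # pen lifted after the last point of a stroke
            else:
                rows.append((dx, dy, 1, 0, 0))  # pen stays on the paper
            prev_x, prev_y = x, y
    rows.append((0.0, 0.0, 0, 0, 1))            # end-of-drawing row
    return rows

# Two strokes: a horizontal segment, then a diagonal one
for row in to_stroke5([[(0, 0), (1, 0)], [(1, 1), (2, 2)]]):
    print(row)
```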
<br />
=== Sketch-RNN ===<br />
[[File:sketchfig2.png|700px|center]]<br />
<br />
The model is a Sequence-to-Sequence Variational Autoencoder (VAE). <br />
<br />
==== Encoder ====<br />
The encoder is a bidirectional RNN. It takes a sketch sequence <math>S =\{S_0, S_1, ... S_{N_{s}}\}</math> and the same sequence reversed, <math>S_{reverse} = \{S_{N_{s}},S_{N_{s}-1}, ... S_0\}</math>. The final hidden states of the two encoding RNNs, <math>(h_{ \rightarrow}, h_{ \leftarrow})</math>, are concatenated to form a vector <math>h</math>,<br />
<br />
\begin{split}<br />
&h_{ \rightarrow} = encode_{ \rightarrow }(S), \\<br />
&h_{ \leftarrow} = encode_{ \leftarrow }(S_{reverse}), \\<br />
&h = [h_{\rightarrow}; h_{\leftarrow}].<br />
\end{split}<br />
<br />
The authors then project <math>h</math> into two vectors <math>\mu</math> and <math>\hat{\sigma}</math>, each of size <math>N_{z}</math>, using a fully connected layer. These two vectors parameterize a Gaussian distribution over the latent space, conditioned on the input sketch. Because standard deviations cannot be negative, an exponential function is applied to <math>\hat{\sigma}</math> to obtain a positive standard deviation <math>\sigma</math>. Next, a random vector with mean <math>\mu</math> and standard deviation <math>\sigma</math> is constructed by scaling a normalized IID Gaussian, <math>\mathcal{N}(0,I)</math>, <br />
<br />
\begin{split}<br />
& \mu = W_\mu h + b_\mu, \\<br />
& \hat \sigma = W_\sigma h + b_\sigma, \\<br />
& \sigma = exp( \frac{\hat \sigma}{2}), \\<br />
& z = \mu + \sigma \odot \mathcal{N}(0,I). <br />
\end{split}<br />
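<br />
The projection and sampling steps above can be sketched in NumPy as follows. The sizes, random weights and function name are hypothetical, and h stands in for the concatenated encoder states. Writing the sample as <math>\mu + \sigma \odot \epsilon</math> with noise <math>\epsilon \sim \mathcal{N}(0, I)</math> (the reparameterization trick) keeps <math>z</math> differentiable with respect to <math>\mu</math> and <math>\hat\sigma</math>, which is what allows the encoder to be trained by backpropagation.<br />
<br />
```python
import numpy as np

def sample_latent(h, W_mu, b_mu, W_sigma, b_sigma, rng):
    """Reparameterized sample z = mu + sigma * eps with eps ~ N(0, I)."""
    mu = W_mu @ h + b_mu                     # mean, shape (N_z,)
    sigma_hat = W_sigma @ h + b_sigma        # unconstrained, shape (N_z,)
    sigma = np.exp(sigma_hat / 2.0)          # positive standard deviation
    eps = rng.standard_normal(mu.shape)      # IID standard normal noise
    return mu + sigma * eps, mu, sigma_hat

rng = np.random.default_rng(0)
n_h, n_z = 16, 8                             # hypothetical sizes of h and of the latent space
h = rng.normal(size=n_h)
W_mu = rng.normal(scale=0.1, size=(n_z, n_h))
W_sigma = rng.normal(scale=0.1, size=(n_z, n_h))
b_mu, b_sigma = np.zeros(n_z), np.zeros(n_z)
z, mu, sigma_hat = sample_latent(h, W_mu, b_mu, W_sigma, b_sigma, rng)
print(z.shape)                               # (8,)
```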
<br />
<br />
Note that <math>z</math> is not deterministic but a random vector that can be conditioned on an input sketch sequence.<br />
<br />
==== Decoder ====<br />
The decoder is an autoregressive RNN. The initial hidden and cell states are generated using <math>[h_0;c_0] = \tanh(W_z z + b_z)</math>. Here, <math>c_0</math> is used only if applicable (e.g. if an LSTM decoder is used). <math>S_0</math> is defined as <math>(0,0,1,0,0)</math> (the pen is touching the paper at location (0, 0)). <br />
<br />
For each step <math>i</math> in the decoder, the input <math>x_i</math> is the concatenation of the previous point <math>S_{i-1}</math> and the latent vector <math>z</math>. The outputs of the RNN decoder <math>y_i</math> are parameters for a probability distribution that will generate the next point <math>S_i</math>. <br />
<br />
The authors model <math>(\Delta x,\Delta y)</math> as a Gaussian mixture model (GMM) with <math>M</math> normal distributions and model the ground truth data <math>(p_1, p_2, p_3)</math> as a categorical distribution <math>(q_1, q_2, q_3)</math> where <math>q_1, q_2\ \text{and}\ q_3</math> sum up to 1,<br />
<br />
\begin{align*}<br />
p(\Delta x, \Delta y) = \sum_{j=1}^{M} \Pi_j \mathcal{N}(\Delta x,\Delta y | \mu_{x,j}, \mu_{y,j}, \sigma_{x,j},\sigma_{y,j}, \rho _{xy,j}), \quad \text{where} \ \sum_{j=1}^{M}\Pi_j = 1<br />
\end{align*}<br />
<br />
Here <math>\mathcal{N}(\Delta x,\Delta y | \mu_{x,j}, \mu_{y,j}, \sigma_{x,j},\sigma_{y,j}, \rho _{xy,j})</math> is a bivariate normal distribution with means <math>\mu_x, \mu_y</math>, standard deviations <math>\sigma_x, \sigma_y</math> and correlation parameter <math>\rho_{xy}</math>; there are <math>M</math> such distributions. <math>\Pi</math> is a categorical distribution vector of length <math>M</math> that holds the mixture weights of the Gaussian mixture model.<br />
<br />
The output vector <math>y_i</math> is obtained by applying a fully connected layer to the hidden state <math>h_i</math> of the RNN:<br />
<br />
\begin{split}<br />
&x_i = [S_{i-1}; z], \\<br />
&[h_i; c_i] = forward(x_i,[h_{i-1}; c_{i-1}]), \\<br />
&y_i = W_y h_i + b_y, \\<br />
&y_i \in \mathbb{R}^{6M+3}. \\<br />
\end{split}<br />
<br />
The output <math>y_i</math> consists of the parameters of the probability distribution of the next data point:<br />
<br />
\begin{align*}<br />
[(\hat\Pi\ \mu_x\ \mu_y\ \hat\sigma_x\ \hat\sigma_y\ \hat\rho_{xy})_1\ (\hat\Pi\ \mu_x\ \mu_y\ \hat\sigma_x\ \hat\sigma_y\ \hat\rho_{xy})_2\ ...\ (\hat\Pi\ \mu_x\ \mu_y\ \hat\sigma_x\ \hat\sigma_y\ \hat\rho_{xy})_M\ (\hat{q_1}\ \hat{q_2}\ \hat{q_3})] = y_i<br />
\end{align*}<br />
<br />
<math>\exp</math> and <math>\tanh</math> operations are applied to ensure that the standard deviations are non-negative and the correlation value is between -1 and 1.<br />
<br />
\begin{align*}<br />
\sigma_x = \exp (\hat \sigma_x),\ <br />
\sigma_y = \exp (\hat \sigma_y),\ <br />
\rho_{xy} = \tanh(\hat \rho_{xy}). <br />
\end{align*}<br />
<br />
The categorical probabilities <math>(q_1, q_2, q_3)</math> of the pen states <math>(p_1, p_2, p_3)</math>, as well as the mixture weights <math>\Pi</math>, are obtained with a softmax:<br />
<br />
\begin{align*}<br />
q_k = \frac{\exp{(\hat q_k)}}{ \sum\nolimits_{j = 1}^{3} \exp {(\hat q_j)}},<br />
k \in \left\{1,2,3\right\}, <br />
\Pi _k = \frac{\exp{(\hat \Pi_k)}}{ \sum\nolimits_{j = 1}^{M} \exp {(\hat \Pi_j)}},<br />
k \in \left\{1,...,M\right\}.<br />
\end{align*}<br />
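<br />
The post-processing of <math>y_i</math> can be sketched as below, assuming the parameter layout shown in the vector above (M groups of six values followed by three pen logits); actual implementations may order the entries differently. The function names are hypothetical. The softmax subtracts the maximum logit before exponentiating, which leaves the result unchanged but avoids overflow.<br />
<br />
```python
import numpy as np

def softmax(a):
    e = np.exp(a - np.max(a))                 # shift for numerical stability
    return e / e.sum()

def decode_output(y, M):
    """Split decoder output y (length 6M + 3) into GMM and pen-state parameters."""
    if y.shape != (6 * M + 3,):
        raise ValueError("expected a vector of length 6M + 3")
    mix = y[:6 * M].reshape(M, 6)             # one row per mixture component
    pi_hat, mu_x, mu_y, s_hat_x, s_hat_y, rho_hat = mix.T
    return {
        "pi": softmax(pi_hat),                # mixture weights, sum to 1
        "mu_x": mu_x,
        "mu_y": mu_y,
        "sigma_x": np.exp(s_hat_x),           # positive standard deviations
        "sigma_y": np.exp(s_hat_y),
        "rho_xy": np.tanh(rho_hat),           # correlations in (-1, 1)
        "q": softmax(y[6 * M:]),              # pen-state probabilities, sum to 1
    }

y = np.random.default_rng(0).normal(size=6 * 3 + 3)   # a random stand-in for y_i with M = 3
params = decode_output(y, 3)
print(params["pi"], params["q"])
```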
<br />
It is hard for the model to learn when to stop drawing because the probabilities of the three pen events <math>(p_1, p_2, p_3)</math> are very unbalanced. Previous work used a different weight for each pen event probability, but the authors found this approach inelegant and inadequate. Instead, they define a hyperparameter <math>N_{max}</math>, the length of the longest sketch in the training set, and set <math>S_i</math> to <math>(0, 0, 0, 0, 1)</math> for <math>i > N_s</math>.<br />
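<br />
A sketch of the padding step, assuming sequences are lists of 5-tuples; the function name is hypothetical. Since the pen-state loss defined below sums over all <math>N_{max}</math> steps, every padded step supplies another training example of the end-of-drawing state.<br />
<br />
```python
END = (0, 0, 0, 0, 1)   # end-of-drawing row used for every step i > N_s

def pad_to_max(rows, n_max):
    """Pad a stroke-5 sequence with end-of-drawing rows up to length n_max."""
    if len(rows) > n_max:
        raise ValueError("sequence longer than N_max")
    return list(rows) + [END] * (n_max - len(rows))

seq = [(1, 0, 1, 0, 0), (0, 1, 0, 1, 0)]
print(pad_to_max(seq, 4))
# [(1, 0, 1, 0, 0), (0, 1, 0, 1, 0), (0, 0, 0, 0, 1), (0, 0, 0, 0, 1)]
```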
<br />
During sampling, an outcome <math>S_i^{'}</math> is generated at each time step and fed as the input for the next time step. The process stops when <math>p_3 = 1</math> or <math>i = N_{max}</math>. The output is therefore not deterministic, but a random sequence conditioned on the latent vector <math>z</math>. The level of randomness can be controlled with a temperature parameter <math>\tau</math>:<br />
<br />
\begin{align*}<br />
\hat q_k \rightarrow \frac{\hat q_k}{\tau}, <br />
\hat \Pi_k \rightarrow \frac{\hat \Pi_k}{\tau}, <br />
\sigma_x^2 \rightarrow \sigma_x^2\tau, <br />
\sigma_y^2 \rightarrow \sigma_y^2\tau. <br />
\end{align*}<br />
<br />
The temperature <math>\tau</math> ranges from 0 to 1. When <math>\tau = 0</math>, the output is deterministic, because each sample consists of the point at the peak of the probability density function.<br />
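<br />
Sampling one point at temperature <math>\tau</math> can be sketched as follows, assuming the mixture parameters have already been decoded (standard deviations and correlations already transformed) and <math>\tau > 0</math>, since the logits are divided by <math>\tau</math>; the <math>\tau = 0</math> case corresponds to the limit of picking the most likely component and pen state and returning the component mean. The function name is hypothetical. The correlated Gaussian draw uses the standard construction from two independent standard normals.<br />
<br />
```python
import numpy as np

def softmax(a):
    e = np.exp(a - np.max(a))
    return e / e.sum()

def sample_next_point(pi_hat, mu_x, mu_y, sigma_x, sigma_y, rho_xy, q_hat, tau, rng):
    """Sample (dx, dy, p1, p2, p3) from the decoder's output distribution at temperature tau > 0."""
    pi = softmax(pi_hat / tau)                # temperature-scaled mixture weights
    q = softmax(q_hat / tau)                  # temperature-scaled pen-state probabilities
    j = rng.choice(len(pi), p=pi)             # pick a mixture component
    sx = sigma_x[j] * np.sqrt(tau)            # variances scaled by tau
    sy = sigma_y[j] * np.sqrt(tau)
    e1, e2 = rng.standard_normal(2)           # independent standard normals
    dx = mu_x[j] + sx * e1
    dy = mu_y[j] + sy * (rho_xy[j] * e1 + np.sqrt(1.0 - rho_xy[j] ** 2) * e2)
    pen = np.zeros(3)
    pen[rng.choice(3, p=q)] = 1.0             # one-hot pen state
    return np.array([dx, dy, *pen])

rng = np.random.default_rng(1)
point = sample_next_point(np.array([0.0, 1.0]), np.array([0.0, 2.0]), np.array([0.0, -1.0]),
                          np.array([1.0, 0.5]), np.array([1.0, 0.5]), np.array([0.0, 0.3]),
                          np.array([2.0, 0.0, -1.0]), 0.5, rng)
print(point)
```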
<br />
=== Unconditional Generation ===<br />
In a special case, only the decoder RNN is trained, and it works as a standalone autoregressive model without latent variables. In this case, the initial states are 0 and the input <math>x_i</math> is only <math>S_{i-1}</math> or <math>S_{i-1}^{'}</math>. Figure 3 shows sketches generated unconditionally with temperatures ranging from <math>\tau = 0.2</math> at the top (blue) to <math>\tau = 0.9</math> at the bottom (red).<br />
<br />
[[File:sketchfig3.png|700px|center]]<br />
<br />
=== Training ===<br />
Training follows the same procedure as for a Variational Autoencoder. The loss function is the sum of the reconstruction loss <math>L_R</math> and the Kullback-Leibler divergence loss <math>L_{KL}</math>. The reconstruction loss <math>L_R</math> is computed from the generated pdf parameters and the training data <math>S</math>. It is the sum of <math>L_s</math> and <math>L_p</math>, the log losses of the offsets <math>(\Delta x, \Delta y)</math> and of the pen states <math>(p_1, p_2, p_3)</math>, respectively.<br />
<br />
\begin{align*}<br />
L_s = - \frac{1 }{N_{max}} \sum_{i = 1}^{N_s} \log\left(\sum_{j = 1}^{M} \Pi_{j,i} \mathcal{N}(\Delta x_i,\Delta y_i | \mu_{x,j,i}, \mu_{y,j,i}, \sigma_{x,j,i},\sigma_{y,j,i}, \rho _{xy,j,i})\right), <br />
\end{align*}<br />
\begin{align*}<br />
L_p = - \frac{1 }{N_{max}} \sum_{i = 1}^{N_{max}} \sum_{k = 1}^{3} p_{k,i} \log (q_{k,i}), \\<br />
L_R = L_s + L_p.<br />
\end{align*}<br />
<br />
<br />
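Both terms are normalized by <math>N_{max}</math>. Note that <math>L_s</math> sums only over the <math>N_s</math> points of the actual sketch, whereas <math>L_p</math> sums over all <math>N_{max}</math> steps, including the padded end-of-sketch states, so that the model learns when to stop drawing.<br />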
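As an illustration of <math>L_R</math>, the following NumPy sketch computes the loss for a single sketch padded to <math>N_{max}</math> steps, summing the offset term over the first <math>N_s</math> steps and the pen-state term over all <math>N_{max}</math> steps. The array layout, the function names, and the small constant <code>eps</code> (added to avoid taking the log of zero) are assumptions for this example rather than details from the paper.<br />
<br />
```python
import numpy as np


def bivariate_normal(dx, dy, mu_x, mu_y, sigma_x, sigma_y, rho):
    """Density of a bivariate normal with correlation rho (elementwise over arrays)."""
    zx = (dx - mu_x) / sigma_x
    zy = (dy - mu_y) / sigma_y
    one_minus_r2 = 1.0 - rho ** 2
    z = zx ** 2 + zy ** 2 - 2.0 * rho * zx * zy
    norm = 2.0 * np.pi * sigma_x * sigma_y * np.sqrt(one_minus_r2)
    return np.exp(-z / (2.0 * one_minus_r2)) / norm


def reconstruction_loss(strokes, pi, mu_x, mu_y, sigma_x, sigma_y, rho, q, n_s, eps=1e-8):
    """L_R = L_s + L_p for one sketch padded to length N_max.

    strokes: (N_max, 5) array of target points (dx, dy, p1, p2, p3).
    pi, mu_x, mu_y, sigma_x, sigma_y, rho: (N_max, M) mixture parameters.
    q: (N_max, 3) pen-state probabilities; n_s: number of real (unpadded) points.
    """
    n_max = strokes.shape[0]
    dx, dy = strokes[:, 0:1], strokes[:, 1:2]                 # trailing axis broadcasts over M
    dens = bivariate_normal(dx, dy, mu_x, mu_y, sigma_x, sigma_y, rho)
    mix = np.sum(pi * dens, axis=1)                           # GMM likelihood at each step
    l_s = -np.sum(np.log(mix[:n_s] + eps)) / n_max            # offsets: first N_s steps only
    l_p = -np.sum(strokes[:, 2:] * np.log(q + eps)) / n_max   # pen states: all N_max steps
    return l_s + l_p
```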
<br />
<math>L_{KL}</math> measures the difference between the distribution of the latent vector <math>z</math> and an IID Gaussian vector with zero mean and unit variance.<br />
<br />
\begin{align*}<br />
L_{KL} = - \frac{1}{2 N_z} \sum_{k=1}^{N_z} \left(1+\hat \sigma_k - \mu_k^2 - \exp(\hat \sigma_k)\right)<br />
\end{align*}<br />
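A direct NumPy translation of <math>L_{KL}</math>, assuming <math>\hat \sigma</math> is the log-variance (consistent with <math>\sigma = \exp(\hat \sigma / 2)</math> in the encoder) and that the sum runs over the <math>N_z</math> latent dimensions; <code>kl_loss</code> is a hypothetical name:<br />
<br />
```python
import numpy as np


def kl_loss(mu, sigma_hat):
    """KL divergence between N(mu, exp(sigma_hat)) and N(0, I), averaged over N_z dimensions.

    sigma_hat is the log-variance, matching sigma = exp(sigma_hat / 2).
    """
    n_z = mu.shape[-1]
    return -np.sum(1.0 + sigma_hat - mu ** 2 - np.exp(sigma_hat), axis=-1) / (2.0 * n_z)
```
<br />
For example, a latent distribution that already equals the prior (<math>\mu = 0</math>, <math>\hat \sigma = 0</math>) gives <math>L_{KL} = 0</math>.<br />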
<br />
The overall loss is weighted as:<br />
<br />
\begin{align*}<br />
Loss = L_R + w_{KL} L_{KL}<br />
\end{align*}<br />
<br />
When <math>w_{KL} = 0</math>, the <math>L_{KL}</math> term vanishes and only <math>L_{R}</math> is optimized. There is a trade-off between the two terms: as <math>w_{KL}</math> approaches 0, the model approaches a pure autoencoder, sacrificing the ability to enforce a prior over the latent space in exchange for a better reconstruction loss.<br />
<br />
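Although this loss function could be used directly, the authors found that annealing the KL term, as shown below, produces better results. Here <math>R</math> is a constant slightly less than 1, so the weight <math>\eta_{step}</math> starts at <math>\eta_{min}</math> and approaches 1 as training progresses, and <math>KL_{min}</math> is a floor below which the KL term is no longer optimized.<br />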
<br />
<center><math><br />
\eta_{step} = 1 - (1 - \eta_{min})R^{step}<br />
</math></center><br />
<br />
<center><math><br />
Loss_{train} = L_R + w_{KL} \eta_{step} \max(L_{KL}, KL_{min})<br />
</math></center><br />
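The annealed objective can be sketched in plain Python. The default values for <math>\eta_{min}</math>, <math>R</math>, <math>w_{KL}</math>, and <math>KL_{min}</math> below are illustrative assumptions, not settings reported in this summary, and <code>eta</code> and <code>train_loss</code> are hypothetical names.<br />
<br />
```python
def eta(step, eta_min=0.01, r=0.99995):
    """eta_step = 1 - (1 - eta_min) * R**step; starts at eta_min and approaches 1 (illustrative defaults)."""
    return 1.0 - (1.0 - eta_min) * r ** step


def train_loss(l_r, l_kl, step, w_kl=0.5, kl_min=0.2, eta_min=0.01, r=0.99995):
    # Flooring at KL_min makes the KL term constant (no gradient through l_kl)
    # once L_KL drops below KL_min; all default values are illustrative.
    return l_r + w_kl * eta(step, eta_min, r) * max(l_kl, kl_min)
```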
<br />
As shown in Figure 4, the <math>L_{R} </math> of the standalone decoder model is an upper bound on the <math>L_{R} </math> of models that use a latent vector. This is because the unconditional model does not have access to the entire sketch it needs to generate.<br />
<br />
[[File:s.png|600px|thumb|center|Figure 4. Left: trade-off between <math>L_{R} </math> and <math>L_{KL} </math> for two models trained on single-class datasets.<br />
Right: validation loss for models trained on the Yoga dataset using various <math>w_{KL} </math>.]]<br />
<br />
== Experiments ==<br />
The authors experiment with sketch-rnn under different settings and record both losses. They use a Long Short-Term Memory (LSTM) encoder and a HyperLSTM decoder. HyperLSTM is a type of RNN cell that excels at sequence generation tasks; its ability to dynamically augment its own weights lets it adapt to the many different regimes in a large, diverse dataset. They also run experiments on multi-class datasets. The results are as follows.<br />
<br />
[[File:sketchtable1.png|700px|center]]<br />
<br />
The table clearly shows the trade-off between <math>L_R</math> and <math>L_{KL}</math>: <math>L_R</math> decreases as <math>w_{KL} </math> is halved.<br />
<br />
=== Conditional Reconstruction ===<br />
The authors reconstruct given input sketches using different <math>\tau</math> values. Reconstructions with higher <math>\tau</math> values (toward the right) are more random.<br />
<br />
[[File:sketchfig5.png|700px|center]]<br />
<br />
They also feed the model input sketches from a class other than the one it was trained on. The reconstructions still retain features of the class the model was trained on.<br />
<br />
=== Latent Space Interpolation ===<br />
The authors visualize the sketches reconstructed while interpolating between latent vectors, for models trained with different <math>w_{KL}</math> values. Models trained with higher <math>w_{KL}</math> values produce more coherent interpolations.<br />
<br />
[[File:sketchfig6.png|700px|center]]<br />
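This summary does not specify how the latent vectors are interpolated. One common choice for Gaussian latent spaces is spherical linear interpolation (slerp), discussed in White's Sampling Generative Networks [27]; the sketch below assumes slerp and is not necessarily the authors' exact procedure. Each interpolated vector would then be passed to the decoder to draw a sketch.<br />
<br />
```python
import numpy as np


def slerp(z0, z1, t):
    """Spherical linear interpolation between latent vectors z0 and z1 at fraction t in [0, 1]."""
    cos_omega = np.dot(z0, z1) / (np.linalg.norm(z0) * np.linalg.norm(z1))
    omega = np.arccos(np.clip(cos_omega, -1.0, 1.0))  # angle between the two vectors
    if np.isclose(omega, 0.0):
        return (1.0 - t) * z0 + t * z1                # nearly parallel: fall back to linear interpolation
    so = np.sin(omega)
    return np.sin((1.0 - t) * omega) / so * z0 + np.sin(t * omega) / so * z1
```
<br />
For instance, <code>[slerp(z_a, z_b, t) for t in np.linspace(0.0, 1.0, 5)]</code> would give five latent vectors from <code>z_a</code> to <code>z_b</code>, where <code>z_a</code> and <code>z_b</code> are hypothetical encodings of two sketches.<br />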
<br />
=== Sketch Drawing Analogies ===<br />
Since the latent vector <math>z</math> encodes conceptual features of a sketch, those features can be used to augment other sketches that lack them. This works for models trained with low <math>L_{KL}</math> values. The authors perform vector arithmetic on the latent vectors of different sketches and explore how the model generates sketches from the resulting latent vectors.<br />
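Latent vector arithmetic itself is element-wise addition and subtraction. In the sketch below, <code>analogy</code> is a hypothetical helper that removes the features encoded by one latent vector and adds those of another; decoding the result with the trained decoder (not shown) would produce the augmented sketch.<br />
<br />
```python
import numpy as np


def analogy(z_a, z_b, z_c):
    """Return z_a - z_b + z_c: remove the features of z_b from z_a and add those of z_c."""
    return np.asarray(z_a) - np.asarray(z_b) + np.asarray(z_c)
```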
<br />
=== Predicting Different Endings of Incomplete Sketches === <br />
The model can also complete an unfinished sketch. The incomplete sketch is first encoded into a hidden state <math>h</math> by running it through the decoder RNN, and <math>h</math> is then used as the initial hidden state for generating the remainder of the sketch. The authors train decoder-only models on individual classes and sample the completions with <math>\tau = 0.8</math>. Figure 7 shows the results.<br />
<br />
[[File:sketchfig7.png|700px|center]]<br />
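The completion procedure can be sketched as a loop that first feeds the existing strokes through the decoder to build up its hidden state and then samples the remaining points. Here <code>step</code> and <code>sample</code> are hypothetical callables standing in for one step of a trained decoder RNN and for drawing a point from its output distribution (for example at <math>\tau = 0.8</math>); they are assumptions, not part of the authors' code.<br />
<br />
```python
def complete_sketch(prefix, step, sample, h0, n_max):
    """Complete an unfinished sketch with an autoregressive decoder.

    prefix: stroke-5 points already drawn, as tuples (dx, dy, p1, p2, p3).
    step(h, x) -> (h_next, params): one decoder RNN step (hypothetical callable).
    sample(params) -> point: draws the next stroke-5 point from the output distribution.
    """
    h = h0
    x = (0, 0, 1, 0, 0)                 # start token S_0
    for point in prefix:                # run the prefix through the decoder to set the hidden state
        h, _ = step(h, x)
        x = point
    out = list(prefix)
    while len(out) < n_max:             # sample the remaining points
        h, params = step(h, x)
        x = sample(params)
        out.append(x)
        if x[4] == 1:                   # p3 = 1: the drawing has ended
            break
    return out
```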
<br />
== Applications and Future Work ==<br />
The authors believe this model can assist artists by suggesting how to finish a sketch, by helping them find interesting intersections between different drawings or objects, or by generating many similar but distinct designs. In the simplest use, pattern designers can apply sketch-rnn to generate a large number of similar but unique designs for textile or wallpaper prints. Creative designers can also come up with abstract designs that resonate more with their target audience.<br />
<br />
This model may also find use in teaching students how to draw. Even with the simple sketches in QuickDraw, the authors of this work became much more proficient at drawing animals, insects, and various sea creatures after conducting these experiments.<br />
When the model is trained with a high <math>w_{KL}</math> and sampled with a low <math>\tau</math>, it may help turn a poor sketch into a more aesthetically pleasing one. Latent vector augmentation could also help create better drawings if user-rating data is used during training.<br />
<br />
The authors conclude by providing the following future directions to this work:<br />
# Investigate using user-rating data to augment the latent vector in the direction that maximizes the aesthetics of the drawing.<br />
# Look into combining variations of sequence-generation models with unsupervised, cross-domain pixel image generation models.<br />
<br />
It would be exciting to combine this model with other unsupervised, cross-domain pixel image generation models to create photorealistic images from sketches.<br />
<br />
The authors also consider the opposite direction, converting a photograph of an object into an unrealistic but similar-looking sketch composed of a minimal number of lines, to be a more interesting problem.<br />
<br />
It would also be interesting to see how varying the loss function is reflected in the generated drawings. An unconventional loss function may change how the network behaves, which could lead to new applications.<br />
<br />
== Conclusion ==<br />
The paper presents a methodology for modelling sketch drawings with recurrent neural networks. It introduces sketch-rnn, a model that can encode and decode sketches, generate new sketches, and complete unfinished ones. In addition, the authors demonstrate how to interpolate between latent vectors of sketches from different classes and how to use the latent space to augment sketches or generate similar-looking sketches. They also show that enforcing a prior distribution on the latent vector is important for generating coherent interpolations. Finally, they create a large dataset of sketch drawings for future research.<br />
<br />
== Critique ==<br />
This paper presents both a novel large dataset of sketches and a new RNN architecture for generating sketches. It is an exciting read, but there are still some aspects to improve.<br />
<br />
* The performance of the decoder model is hard to evaluate. The authors present the decoder's performance by showing generated sketches, which is clear and straightforward but not very efficient. It would be better if the authors proposed a method or metric for evaluating how well sketches are generated, rather than printing them out and relying on human judgment. The authors do not present an evaluation of the algorithms either. They provide <math>L_R</math> and <math>L_{KL}</math> for reference, but a lower loss does not necessarily mean better performance: training loss alone likely does not capture the quality of a sketch.<br />
<br />
* The algorithm is not compared with the prior state of the art on standard metrics, which makes its novelty unclear. Using strokes as inputs is a novel and innovative move, but the paper provides no baseline or comparison with other methods. The paper mentions other research using similar but smaller datasets; it would be useful if the authors had used some basic or existing methods as baselines and compared them with the new algorithm.<br />
<br />
* Besides comparisons with other algorithms, it would also help if the authors removed or replaced individual components of the model (an ablation study) to show whether each part is necessary, or explained why they chose to include a specific component.<br />
<br />
* The authors propose a few future applications for the model, but the current outputs still seem far from those descriptions. Still, I believe this is a very good beginning: with the release of the sketch dataset, it should attract more researchers to study and improve on it.<br />
<br />
== References == <br />
# Jimmy L. Ba, Jamie R. Kiros, and Geoffrey E. Hinton. Layer normalization. NIPS, 2016.<br />
# Christopher M. Bishop. Mixture density networks. Technical Report, 1994. URL http://publications.aston.ac.uk/373/.<br />
# Samuel R. Bowman, Luke Vilnis, Oriol Vinyals, Andrew M. Dai, Rafal Józefowicz, and Samy Bengio. Generating Sentences from a Continuous Space. CoRR, abs/1511.06349, 2015. URL http://arxiv.org/abs/1511.06349.<br />
# H. Dong, P. Neekhara, C. Wu, and Y. Guo. Unsupervised Image-to-Image Translation with Generative Adversarial Networks. ArXiv e-prints, January 2017.<br />
# David H. Douglas and Thomas K. Peucker. Algorithms for the reduction of the number of points required to represent a digitized line or its caricature. Cartographica: The International Journal for Geographic Information and Geovisualization, 10(2):112–122, October 1973. doi: 10.3138/fm57-6770-u75u-7727. URL http://dx.doi.org/10.3138/fm57-6770-u75u-7727.<br />
# Mathias Eitz, James Hays, and Marc Alexa. How Do Humans Sketch Objects? ACM Trans. Graph.(Proc. SIGGRAPH), 31(4):44:1–44:10, 2012.<br />
# I. Goodfellow. NIPS 2016 Tutorial: Generative Adversarial Networks. ArXiv e-prints, December 2016.<br />
# Alex Graves. Generating sequences with recurrent neural networks. arXiv:1308.0850, 2013.<br />
# David Ha. Recurrent Net Dreams Up Fake Chinese Characters in Vector Format with TensorFlow, 2015.<br />
# David Ha, Andrew M. Dai, and Quoc V. Le. HyperNetworks. In ICLR, 2017.<br />
# Sepp Hochreiter and Juergen Schmidhuber. Long short-term memory. Neural Computation, 1997.<br />
# P. Isola, J.-Y. Zhu, T. Zhou, and A. A. Efros. Image-to-Image Translation with Conditional Adversarial Networks. ArXiv e-prints, November 2016.<br />
# Jonas Jongejan, Henry Rowley, Takashi Kawashima, Jongmin Kim, and Nick Fox-Gieg. The Quick, Draw! - A.I. Experiment, 2016. URL https://quickdraw.withgoogle.com/.<br />
# C. Kaae Sønderby, T. Raiko, L. Maaløe, S. Kaae Sønderby, and O. Winther. Ladder Variational Autoencoders. ArXiv e-prints, February 2016.<br />
# T. Kim, M. Cha, H. Kim, J. Lee, and J. Kim. Learning to Discover cross-domain Relations with Generative Adversarial Networks. ArXiv e-prints, March 2017.<br />
# D. P Kingma and M. Welling. Auto-Encoding Variational Bayes. ArXiv e-prints, December 2013.<br />
# Diederik Kingma and Jimmy Ba. Adam: A method for stochastic optimization. In ICLR, 2015.<br />
# Diederik P. Kingma, Tim Salimans, and Max Welling. Improving variational inference with inverse autoregressive flow. CoRR, abs/1606.04934, 2016. URL http://arxiv.org/abs/1606.04934.<br />
# Brenden M. Lake, Ruslan Salakhutdinov, and Joshua B. Tenenbaum. Human level concept learning through probabilistic program induction. Science, 350(6266):1332–1338, December 2015. ISSN 1095-9203. doi: 10.1126/science.aab3050. URL http://dx.doi.org/10.1126/science.aab3050.<br />
# Yong Jae Lee, C. Lawrence Zitnick, and Michael F. Cohen. Shadowdraw: Real-time user guidance for freehand drawing. In ACM SIGGRAPH 2011 Papers, SIGGRAPH ’11, pp. 27:1–27:10, New York, NY, USA, 2011. ACM. ISBN 978-1-4503-0943-1. doi: 10.1145/1964921.1964922. URL http://doi.acm.org/10.1145/1964921.1964922.<br />
# M.-Y. Liu, T. Breuel, and J. Kautz. Unsupervised Image-to-Image Translation Networks. ArXiv e-prints, March 2017.<br />
# S. Reed, A. van den Oord, N. Kalchbrenner, S. Gómez Colmenarejo, Z. Wang, D. Belov, and N. de Freitas. Parallel Multiscale Autoregressive Density Estimation. ArXiv e-prints, March 2017.<br />
# Patsorn Sangkloy, Nathan Burnell, Cusuh Ham, and James Hays. The Sketchy Database: Learning to Retrieve Badly Drawn Bunnies. ACM Trans. Graph., 35(4):119:1–119:12, July 2016. ISSN 0730-0301. doi: 10.1145/2897824.2925954. URL http://doi.acm.org/10.1145/2897824.2925954.<br />
# Mike Schuster and Kuldip K. Paliwal. Bidirectional recurrent neural networks. IEEE Transactions on Signal Processing, 1997.<br />
# Saul Simhon and Gregory Dudek. Sketch interpretation and refinement using statistical models. In Proceedings of the Fifteenth Eurographics Conference on Rendering Techniques, EGSR’04, pp. 23–32, Aire-la-Ville, Switzerland, Switzerland, 2004. Eurographics Association. ISBN 3-905673-12-6. doi: 10.2312/EGWR/EGSR04/023-032. URL http://dx.doi.org/10.2312/EGWR/EGSR04/023-032.<br />
# Patrick Tresset and Frederic Fol Leymarie. Portrait drawing by paul the robot. Comput. Graph.,37(5):348–363, August 2013. ISSN 0097-8493. doi: 10.1016/j.cag.2013.01.012. URL http://dx.doi.org/10.1016/j.cag.2013.01.012.<br />
# T. White. Sampling Generative Networks. [https://arxiv.org/abs/1609.04468 ArXiv e-prints], September 2016.<br />
# Ning Xie, Hirotaka Hachiya, and Masashi Sugiyama. Artist agent: A reinforcement learning approach to automatic stroke generation in oriental ink painting. In ICML. icml.cc / Omnipress, 2012. URL http://dblp.uni-trier.de/db/conf/icml/icml2012.html#XieHS12.<br />
# Xu-Yao Zhang, Fei Yin, Yan-Ming Zhang, Cheng-Lin Liu, and Yoshua Bengio. Drawing and Recognizing Chinese Characters with Recurrent Neural Network. CoRR, abs/1606.06539, 2016. URL http://arxiv.org/abs/1606.06539.</div>Gsahuhttp://wiki.math.uwaterloo.ca/statwiki/index.php?title=a_neural_representation_of_sketch_drawings&diff=41943a neural representation of sketch drawings2018-11-29T23:36:55Z<p>Gsahu: /* References */</p>
<hr />
<div><br />
== Introduction ==<br />
In this paper, the authors present a recurrent neural network, sketch-rnn, that can be used to construct stroke-based drawings. Besides new robust training methods, they also outline a framework for conditional and unconditional sketch generation.<br />
<br />
Neural networks are widely used as image generation tools, for example Generative Adversarial Networks, Variational Inference, and autoregressive models. Most of these models are designed to generate pixel images. People, however, learn from a young age to draw using sequences of strokes. The authors propose a new generative model that creates vector images, so that it might generalize abstract concepts in a manner more similar to how humans do.<br />
<br />
The model is trained on hand-drawn sketches represented as input sequences and produces sketches in vector format. Using the conditional generation model, the authors also explore the latent space representation of vector images and discuss a few future applications of the model. The model and dataset are now available as an open source project ([https://magenta.tensorflow.org/sketch_rnn link]).<br />
<br />
=== Terminology ===<br />
Pixel images, also referred to as raster or bitmap images, are files that encode image data as a set of pixels. These are the most common image type, with extensions such as .png, .jpg, and .bmp.<br />
<br />
Vector images are files that encode image data as paths between points. SVG and EPS file types are used to store vector images. <br />
<br />
For a visual comparison of raster and vector images, see this [https://www.youtube.com/watch?v=-Fs2t6P5AjY video]. Vector images are generally simpler and more abstract, whereas raster images are generally used to store detailed images.<br />
<br />
For this paper, the important distinction between the two is that the encoding of images in the model will be inherently more abstract because of the vector representation. The intuition is that generating abstract representations is more effective using a vector representation. <br />
<br />
== Related Work ==<br />
Earlier work used similar approaches to generate images, such as Portrait Drawing by Paul the Robot [26] and reinforcement learning approaches that discover a set of paint brush strokes that best represents a given input photograph [28]. These methods mostly mimic digitized photographs. There are also neural-network-based approaches, but those mostly deal with pixel images; little work has been done on generating vector images. Some models use Hidden Markov Models [25] or Mixture Density Networks [2] to generate human sketches, continuous data points (e.g., modelling Chinese characters as sequences of pen stroke actions), or vectorized Kanji characters [9, 29].<br />
<br />
The model also allows the latent space representation of vector images to be explored. Previous works achieved similar functionality, such as combining sequence-to-sequence models with a Variational Autoencoder to model sentences in a latent space [3], and using probabilistic program induction to model the Omniglot dataset [19].<br />
<br />
The dataset used in this paper contains 50 million vector sketches. Earlier datasets include the Sketch dataset with 20k vector sketches [6], the Sketchy dataset with 70k vector sketches along with pixel images [23], and the ShadowDraw system, which used 30k raster images along with extracted vectorized features [20]. All of these are comparatively small.<br />
<br />
== Major Contributions ==<br />
This paper makes the following major contributions:<br />
* A framework for both unconditional and conditional generation of vector images composed of a sequence of lines. The recurrent-neural-network-based generative model can produce sketches of common objects in vector format.<br />
* A training procedure unique to vector images that makes training more robust.<br />
* A large dataset of hand-drawn vector images, released to encourage further development of generative modelling for vector images, together with an open source implementation of the model.<br />
<br />
== Methodology ==<br />
=== Dataset ===<br />
QuickDraw is a dataset of 50 million vector drawings collected through the online game [https://quickdraw.withgoogle.com/# Quick Draw!], in which players must draw an object of a given class in less than 20 seconds. It contains hundreds of classes; each class has 70k training samples, 2.5k validation samples, and 2.5k test samples.<br />
<br />
Each sample is represented as a sequence of pen stroke actions, and the drawing starts at the origin. A sketch is a list of points, where each point is a vector of 5 elements <math> (\Delta x, \Delta y, p_{1}, p_{2}, p_{3})</math>. Here <math>\Delta x</math> and <math>\Delta y</math> are the offsets in the x and y directions from the previous point. The last three elements form a binary one-hot vector of three possible pen states: <math>p_{1}</math> indicates that the pen is touching the paper and a line will be drawn from the current point to the next one, <math>p_{2}</math> indicates that the pen will be lifted after the current point, and <math>p_{3}</math> indicates that the drawing has ended.<br />
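To make the format concrete, the sketch below converts a drawing given as a list of strokes of absolute (x, y) points into this 5-element representation. The input format and the function name <code>strokes_to_stroke5</code> are assumptions for illustration and may differ from how the released dataset stores drawings.<br />
<br />
```python
import numpy as np


def strokes_to_stroke5(strokes):
    """Convert strokes of absolute (x, y) points into rows (dx, dy, p1, p2, p3).

    strokes: list of strokes, each a list of (x, y) tuples (assumed input format).
    Returns an array of shape (number_of_points + 1, 5).
    """
    rows = []
    prev_x, prev_y = 0.0, 0.0                      # the drawing starts at the origin
    for stroke in strokes:
        for idx, (x, y) in enumerate(stroke):
            last_in_stroke = idx == len(stroke) - 1
            # p1: pen stays down and draws to the next point; p2: pen lifts after this point
            p1, p2 = (0, 1) if last_in_stroke else (1, 0)
            rows.append([x - prev_x, y - prev_y, p1, p2, 0])
            prev_x, prev_y = x, y
    rows.append([0, 0, 0, 0, 1])                   # p3: end of drawing
    return np.array(rows, dtype=np.float32)
```
<br />
For example, two strokes <code>[[(0, 0), (3, 4)], [(5, 5), (5, 8)]]</code> become five rows, the last of which is the end-of-drawing state <math>(0, 0, 0, 0, 1)</math>.<br />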
<br />
=== Sketch-RNN ===<br />
[[File:sketchfig2.png|700px|center]]<br />
<br />
The model is a Sequence-to-Sequence Variational Autoencoder (VAE).<br />
<br />
==== Encoder ====<br />
The encoder is a bidirectional RNN. Its inputs are a sketch sequence <math>S =\{S_0, S_1, ... S_{N_{s}}\}</math> and the same sequence in reverse order, <math>S_{reverse} = \{S_{N_{s}},S_{N_{s}-1}, ... S_0\}</math>. The final hidden states of the two encoding RNNs, <math>(h_{ \rightarrow}, h_{ \leftarrow})</math>, are concatenated to form a single hidden vector <math>h</math>:<br />
<br />
\begin{split}<br />
&h_{ \rightarrow} = encode_{ \rightarrow }(S), \\<br />
&h_{ \leftarrow} = encode_{ \leftarrow }(S_{reverse}), \\<br />
&h = [h_{\rightarrow}; h_{\leftarrow}].<br />
\end{split}<br />
<br />
The authors then project <math>h</math> with a fully connected layer into two vectors, <math>\mu</math> and <math>\hat{\sigma}</math>, each of size <math>N_{z}</math>. These vectors are the parameters of a Gaussian distribution over the latent space that approximates the distribution of the input data. Because a standard deviation cannot be negative, an exponential is applied to <math>\hat{\sigma}</math> to obtain a positive standard deviation <math>\sigma</math>. Next, a random vector with mean <math>\mu</math> and standard deviation <math>\sigma</math> is constructed by scaling and shifting a sample from the IID standard Gaussian <math>\mathcal{N}(0,I)</math>:<br />
<br />
\begin{split}<br />
& \mu = W_\mu h + b_\mu, \\<br />
& \hat \sigma = W_\sigma h + b_\sigma, \\<br />
& \sigma = \exp\left( \frac{\hat \sigma}{2}\right), \\<br />
& z = \mu + \sigma \odot \mathcal{N}(0,I). <br />
\end{split}<br />
<br />
<br />
Note that <math>z</math> is not deterministic but a random vector that can be conditioned on an input sketch sequence.<br />
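The projection and sampling steps above can be sketched in NumPy as follows. This is a minimal illustration assuming the two final encoder states are already available as vectors; <code>encode_latent</code> and its argument names are hypothetical, and a real implementation would run inside an automatic-differentiation framework so that gradients flow through <math>\mu</math> and <math>\sigma</math> (the reparameterization trick).<br />
<br />
```python
import numpy as np


def encode_latent(h_fwd, h_bwd, W_mu, b_mu, W_sigma, b_sigma, rng):
    """Map the bidirectional encoder's final states to a latent sample z.

    h_fwd, h_bwd: final hidden states of the forward and backward RNNs.
    W_mu, W_sigma: (N_z, len(h)) projection weights; b_mu, b_sigma: (N_z,) biases.
    """
    h = np.concatenate([h_fwd, h_bwd])       # h = [h_fwd; h_bwd]
    mu = W_mu @ h + b_mu
    sigma_hat = W_sigma @ h + b_sigma
    sigma = np.exp(sigma_hat / 2.0)          # positive standard deviation
    eps = rng.standard_normal(mu.shape)      # sample from N(0, I)
    z = mu + sigma * eps                     # elementwise scale and shift
    return z, mu, sigma_hat
```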
<br />
==== Decoder ====<br />
The decoder is an autoregressive RNN. Its initial hidden and cell states are computed as <math>[h_0;c_0] = \tanh(W_z z + b_z)</math>, where <math>c_0</math> is used only if applicable (e.g., if an LSTM decoder is used). The first input <math>S_0</math> is defined as <math>(0,0,1,0,0)</math>, meaning the pen is touching the paper at the origin <math>(0, 0)</math>.<br />
<br />
At each decoder step <math>i</math>, the input <math>x_i</math> is the concatenation of the previous point <math>S_{i-1}</math> and the latent vector <math>z</math>. The decoder output <math>y_i</math> contains the parameters of a probability distribution from which the next point <math>S_i</math> is generated.<br />
<br />
The authors model <math>(\Delta x,\Delta y)</math> as a Gaussian mixture model (GMM) with <math>M</math> normal distributions and model the ground truth data <math>(p_1, p_2, p_3)</math> as a categorical distribution <math>(q_1, q_2, q_3)</math> where <math>q_1, q_2\ \text{and}\ q_3</math> sum up to 1,<br />
<br />
\begin{align*}<br />
p(\Delta x, \Delta y) = \sum_{j=1}^{M} \Pi_j \mathcal{N}(\Delta x,\Delta y | \mu_{x,j}, \mu_{y,j}, \sigma_{x,j},\sigma_{y,j}, \rho _{xy,j}), \quad \text{where} \quad \sum_{j=1}^{M}\Pi_j = 1<br />
\end{align*}<br />
<br />
Here <math>\mathcal{N}(\Delta x,\Delta y | \mu_{x,j}, \mu_{y,j}, \sigma_{x,j},\sigma_{y,j}, \rho _{xy,j})</math> is a bivariate normal distribution with means <math>\mu_x, \mu_y</math>, standard deviations <math>\sigma_x, \sigma_y</math>, and correlation <math>\rho_{xy}</math>. There are <math>M</math> such distributions. <math>\Pi</math> is a categorical distribution vector of length <math>M</math> whose entries are the mixture weights of the Gaussian mixture model.<br />
<br />
The output vector <math>y_i</math> is computed from the RNN hidden state <math>h_i</math> by a fully connected layer:<br />
<br />
\begin{split}<br />
&x_i = [S_{i-1}; z], \\<br />
&[h_i; c_i] = forward(x_i,[h_{i-1}; c_{i-1}]), \\<br />
&y_i = W_y h_i + b_y, \\<br />
&y_i \in \mathbb{R}^{6M+3}. \\<br />
\end{split}<br />
<br />
The output vector contains the parameters of the probability distribution of the next data point:<br />
<br />
\begin{align*}<br />
[(\hat\Pi\ \mu_x\ \mu_y\ \hat\sigma_x\ \hat\sigma_y\ \hat\rho_{xy})_1\ (\hat\Pi\ \mu_x\ \mu_y\ \hat\sigma_x\ \hat\sigma_y\ \hat\rho_{xy})_2\ ...\ (\hat\Pi\ \mu_x\ \mu_y\ \hat\sigma_x\ \hat\sigma_y\ \hat\rho_{xy})_M\ (\hat{q_1}\ \hat{q_2}\ \hat{q_3})] = y_i<br />
\end{align*}<br />
<br />
<math>\exp</math> and <math>\tanh</math> operations are applied to ensure that the standard deviations are non-negative and the correlation value is between -1 and 1.<br />
<br />
\begin{align*}<br />
\sigma_x = \exp (\hat \sigma_x),\ <br />
\sigma_y = \exp (\hat \sigma_y),\ <br />
\rho_{xy} = \tanh(\hat \rho_{xy}). <br />
\end{align*}<br />
<br />
The categorical probabilities <math>(q_1, q_2, q_3)</math> for the pen states <math>(p_1, p_2, p_3)</math> and the mixture weights <math>\Pi</math> are obtained with a softmax:<br />
<br />
\begin{align*}<br />
q_k = \frac{\exp{(\hat q_k)}}{ \sum\nolimits_{j = 1}^{3} \exp {(\hat q_j)}},<br />
k \in \left\{1,2,3\right\}, \\<br />
\Pi _k = \frac{\exp{(\hat \Pi_k)}}{ \sum\nolimits_{j = 1}^{M} \exp {(\hat \Pi_j)}},<br />
k \in \left\{1,...,M\right\}.<br />
\end{align*}<br />
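The post-processing of <math>y_i</math> described above can be sketched in NumPy. The layout assumed here (<math>M</math> groups of six values followed by the three pen logits) follows the equation for <math>y_i</math> in this summary and may differ from the ordering used in the released implementation; <code>split_output</code> is a hypothetical name.<br />
<br />
```python
import numpy as np


def softmax(logits):
    # Subtract the max before exponentiating for numerical stability
    e = np.exp(logits - np.max(logits))
    return e / e.sum()


def split_output(y, num_mixtures):
    """Split a decoder output y_i of length 6M + 3 into GMM and pen-state parameters.

    Assumes M groups of (Pi_hat, mu_x, mu_y, sigma_hat_x, sigma_hat_y, rho_hat)
    followed by (q_hat_1, q_hat_2, q_hat_3).
    """
    m = num_mixtures
    assert y.shape == (6 * m + 3,)
    gmm = y[:6 * m].reshape(m, 6)       # one row per mixture component
    pi = softmax(gmm[:, 0])             # mixture weights sum to 1
    mu_x, mu_y = gmm[:, 1], gmm[:, 2]
    sigma_x = np.exp(gmm[:, 3])         # standard deviations are positive
    sigma_y = np.exp(gmm[:, 4])
    rho_xy = np.tanh(gmm[:, 5])         # correlation in (-1, 1)
    q = softmax(y[6 * m:])              # pen-state probabilities sum to 1
    return pi, mu_x, mu_y, sigma_x, sigma_y, rho_xy, q
```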
<br />
It is hard for the model to decide when to stop drawing because the probabilities of the three pen events <math>(p_1, p_2, p_3)</math> are highly imbalanced. Earlier work weighted each pen event probability differently, but the authors found this approach inelegant and inadequate. Instead, they define a hyperparameter <math>N_{max}</math>, the length of the longest sketch in the training set, and pad every shorter sketch by setting <math>S_i = (0, 0, 0, 0, 1)</math> for <math>i > N_s</math>.<br />
<br />
During sampling, a point <math>S_i^{'}</math> is drawn at each time step and fed back as the input for the next time step. Sampling stops when <math>p_3 = 1</math> or <math>i = N_{max}</math>. The output is therefore not deterministic but a random sequence conditioned on the input. The level of randomness can be controlled with a temperature parameter <math>\tau</math>:<br />
<br />
\begin{align*}<br />
\hat q_k \rightarrow \frac{\hat q_k}{\tau}, <br />
\hat \Pi_k \rightarrow \frac{\hat \Pi_k}{\tau}, <br />
\sigma_x^2 \rightarrow \sigma_x^2\tau, <br />
\sigma_y^2 \rightarrow \sigma_y^2\tau. <br />
\end{align*}<br />
<br />
<math>\tau</math> typically ranges from 0 to 1. As <math>\tau</math> approaches 0, the output becomes deterministic, because each sample is taken from the peak of the probability density function.<br />
<br />
=== Unconditional Generation ===<br />
As a special case, only the decoder RNN module is trained. The decoder then works as a standalone autoregressive model without latent variables: its initial states are set to 0, and its input <math>x_i</math> is only <math>S_{i-1}</math> or <math>S_{i-1}^{'}</math>. Figure 3 shows sketches generated unconditionally with the temperature parameter ranging from <math>\tau = 0.2</math> at the top (blue) to <math>\tau = 0.9</math> at the bottom (red).<br />
<br />
[[File:sketchfig3.png|700px|center]]<br />
<br />
=== Training ===<br />
Training follows the same procedure as for a Variational Autoencoder. The loss function is the sum of the reconstruction loss <math>L_R</math> and the Kullback-Leibler divergence loss <math>L_{KL}</math>. The reconstruction loss <math>L_R</math> is computed from the generated pdf parameters and the training data <math>S</math>. It is the sum of <math>L_s</math> and <math>L_p</math>, the log losses of the offsets <math>(\Delta x, \Delta y)</math> and of the pen states <math>(p_1, p_2, p_3)</math>, respectively.<br />
<br />
\begin{align*}<br />
L_s = - \frac{1 }{N_{max}} \sum_{i = 1}^{N_s} \log\left(\sum_{j = 1}^{M} \Pi_{j,i} \mathcal{N}(\Delta x_i,\Delta y_i | \mu_{x,j,i}, \mu_{y,j,i}, \sigma_{x,j,i},\sigma_{y,j,i}, \rho _{xy,j,i})\right), <br />
\end{align*}<br />
\begin{align*}<br />
L_p = - \frac{1 }{N_{max}} \sum_{i = 1}^{N_{max}} \sum_{k = 1}^{3} p_{k,i} \log (q_{k,i}), \\<br />
L_R = L_s + L_p.<br />
\end{align*}<br />
<br />
<br />
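Both terms are normalized by <math>N_{max}</math>. Note that <math>L_s</math> sums only over the <math>N_s</math> points of the actual sketch, whereas <math>L_p</math> sums over all <math>N_{max}</math> steps, including the padded end-of-sketch states, so that the model learns when to stop drawing.<br />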
<br />
<math>L_{KL}</math> measures the difference between the distribution of the latent vector <math>z</math> and an IID Gaussian vector with zero mean and unit variance.<br />
<br />
\begin{align*}<br />
L_{KL} = - \frac{1}{2 N_z} \sum_{k=1}^{N_z} \left(1+\hat \sigma_k - \mu_k^2 - \exp(\hat \sigma_k)\right)<br />
\end{align*}<br />
<br />
The overall loss is weighted as:<br />
<br />
\begin{align*}<br />
Loss = L_R + w_{KL} L_{KL}<br />
\end{align*}<br />
<br />
When <math>w_{KL} = 0</math>, the <math>L_{KL}</math> term vanishes and only <math>L_{R}</math> is optimized. There is a trade-off between the two terms: as <math>w_{KL}</math> approaches 0, the model approaches a pure autoencoder, sacrificing the ability to enforce a prior over the latent space in exchange for a better reconstruction loss.<br />
<br />
<br />
As shown in Figure 4, the <math>L_{R} </math> of the standalone decoder model is an upper bound on the <math>L_{R} </math> of models that use a latent vector. This is because the unconditional model does not have access to the entire sketch it needs to generate.<br />
<br />
[[File:s.png|600px|thumb|center|Figure 4. Left: trade-off between <math>L_{R} </math> and <math>L_{KL} </math> for two models trained on single-class datasets.<br />
Right: validation loss for models trained on the Yoga dataset using various <math>w_{KL} </math>.]]<br />
<br />
== Experiments ==<br />
The authors experiment with sketch-rnn under different settings and record both losses. They use a Long Short-Term Memory (LSTM) encoder and a HyperLSTM decoder. HyperLSTM is a type of RNN cell that excels at sequence generation tasks; its ability to dynamically augment its own weights lets it adapt to the many different regimes in a large, diverse dataset. They also run experiments on multi-class datasets. The results are as follows.<br />
<br />
[[File:sketchtable1.png|700px|center]]<br />
<br />
The table clearly shows the trade-off between <math>L_R</math> and <math>L_{KL}</math>: <math>L_R</math> decreases as <math>w_{KL} </math> is halved.<br />
<br />
=== Conditional Reconstruction ===<br />
The authors reconstruct given input sketches using different <math>\tau</math> values. Reconstructions with higher <math>\tau</math> values (toward the right) are more random.<br />
<br />
[[File:sketchfig5.png|700px|center]]<br />
<br />
They also feed the model input sketches from a class other than the one it was trained on. The reconstructions still retain features of the class the model was trained on.<br />
<br />
=== Latent Space Interpolation ===<br />
The authors visualize the sketches reconstructed while interpolating between latent vectors, for models trained with different <math>w_{KL}</math> values. Models trained with higher <math>w_{KL}</math> values produce more coherent interpolations.<br />
<br />
[[File:sketchfig6.png|700px|center]]<br />
<br />
=== Sketch Drawing Analogies ===<br />
Since the latent vector <math>z</math> encodes conceptual features of a sketch, those features can be used to augment other sketches that lack them. This works for models trained with low <math>L_{KL}</math> values. The authors perform vector arithmetic on the latent vectors of different sketches and explore how the model generates sketches from the resulting latent vectors.<br />
<br />
=== Predicting Different Endings of Incomplete Sketches === <br />
This model can complete an unfinished sketch. It first encodes the sketch into a hidden state <math>h</math> using the decoder, then uses <math>h</math> as the initial hidden state to generate the rest of the sketch. The authors train decoder-only models on individual classes and set <math>\tau = 0.8</math> to complete the samples. Figure 7 shows the results.<br />
<br />
[[File:sketchfig7.png|700px|center]]<br />
<br />
== Applications and Future Work ==<br />
The authors believe this model can assist artists in several ways: suggesting how to finish a sketch, helping them find interesting intersections between different drawings or objects, or generating many similar but different designs. In the simplest use, pattern designers can apply sketch-rnn to generate a large number of similar but unique designs for textile or wallpaper prints. Creative designers can also come up with abstract designs that resonate more with their target audience.<br />
<br />
This model may also find a place in teaching students how to draw. Even with the simple sketches in QuickDraw, the authors of this work became much more proficient at drawing animals, insects, and various sea creatures after conducting these experiments.<br />
When the model is trained with a high <math>w_{KL}</math> and sampled with a low <math>\tau</math>, it may help turn a poor sketch into a more aesthetically pleasing one. Augmenting the latent vector with user-rating data during training could also help produce better drawings.<br />
<br />
The authors conclude by providing the following future directions to this work:<br />
# Investigate using user-rating data to augment the latent vector in the direction that maximizes the aesthetics of the drawing.<br />
# Look into combining variations of sequence-generation models with unsupervised, cross-domain pixel image generation models.<br />
<br />
It would be exciting to see this model combined with other unsupervised, cross-domain pixel image generation models to create photorealistic images from sketches.<br />
<br />
The authors also consider the opposite direction to be a more interesting problem: converting a photograph of an object into an unrealistic but similar-looking sketch of the object, composed of a minimal number of lines.<br />
<br />
Moreover, it would be interesting to see how varying the loss function is reflected in the generated drawings. Some exotic form of loss function may change the way the network behaves, which could lead to various applications.<br />
<br />
== Conclusion ==<br />
The paper presents a methodology for modeling sketch drawings with recurrent neural networks. It introduces sketch-rnn, a model that can encode and decode sketches and can generate and complete unfinished sketches. In addition, the authors demonstrate how to interpolate between latent vectors of different classes, and how to use latent vectors to augment sketches or generate similar-looking ones. They also show the importance of enforcing a prior distribution on the latent vector for producing coherent interpolations. Finally, they create a large dataset of sketch drawings for future research.<br />
<br />
== Critique ==<br />
This paper presents both a novel large dataset of sketches and a new RNN architecture for generating sketches. It is very exciting to read, but there are still some aspects to improve.<br />
<br />
* The performance of the decoder model is hard to evaluate. The authors present it by showing generated sketches, which is clear and straightforward but not very efficient. It would be better if the authors proposed a metric for how well the sketches are generated, rather than printing them out and relying on human judgment. The authors did not present a quantitative evaluation of the algorithms either. They reported <math>L_R</math> and <math>L_{KL}</math> for reference, but a lower loss does not necessarily mean better performance, since training loss alone likely does not capture the quality of a sketch.<br />
<br />
* The algorithm is not compared to the prior state of the art on standard metrics, which makes its novelty unclear. Using strokes as inputs is a novel and innovative move, but the paper provides no baseline and no comparison with other methods. Other work mentioned in the paper used similar but smaller datasets. It would be better if the authors used some basic or existing methods as baselines and compared them with the new algorithm.<br />
<br />
* Besides comparisons with other algorithms, an ablation study that removes or replaces individual components of the model would show whether each part is necessary and would justify the decision to include it.<br />
<br />
* The authors proposed a few future applications for the model, but the current outputs still seem somewhat far from those descriptions. Nonetheless, this is a very good beginning. With the release of the sketch dataset, it will likely attract more researchers to build on and improve the model.<br />
<br />
== References == <br />
# Jimmy L. Ba, Jamie R. Kiros, and Geoffrey E. Hinton. Layer normalization. NIPS, 2016.<br />
# Christopher M. Bishop. Mixture density networks. Technical Report, 1994. URL http://publications.aston.ac.uk/373/.<br />
# Samuel R. Bowman, Luke Vilnis, Oriol Vinyals, Andrew M. Dai, Rafal Józefowicz, and Samy Bengio. Generating Sentences from a Continuous Space. CoRR, abs/1511.06349, 2015. URL http://arxiv.org/abs/1511.06349.<br />
# H. Dong, P. Neekhara, C. Wu, and Y. Guo. Unsupervised Image-to-Image Translation with Generative Adversarial Networks. ArXiv e-prints, January 2017.<br />
# David H. Douglas and Thomas K. Peucker. Algorithms for the reduction of the number of points required to represent a digitized line or its caricature. Cartographica: The International Journal for Geographic Information and Geovisualization, 10(2):112–122, October 1973. doi: 10.3138/fm57-6770-u75u-7727. URL http://dx.doi.org/10.3138/fm57-6770-u75u-7727.<br />
# Mathias Eitz, James Hays, and Marc Alexa. How Do Humans Sketch Objects? ACM Trans. Graph. (Proc. SIGGRAPH), 31(4):44:1–44:10, 2012.<br />
# I. Goodfellow. NIPS 2016 Tutorial: Generative Adversarial Networks. ArXiv e-prints, December 2016.<br />
# Alex Graves. Generating sequences with recurrent neural networks. arXiv:1308.0850, 2013.<br />
# David Ha. Recurrent Net Dreams Up Fake Chinese Characters in Vector Format with TensorFlow, 2015.<br />
# David Ha, Andrew M. Dai, and Quoc V. Le. HyperNetworks. In ICLR, 2017.<br />
# Sepp Hochreiter and Juergen Schmidhuber. Long short-term memory. Neural Computation, 1997.<br />
# P. Isola, J.-Y. Zhu, T. Zhou, and A. A. Efros. Image-to-Image Translation with Conditional Adversarial Networks. ArXiv e-prints, November 2016.<br />
# Jonas Jongejan, Henry Rowley, Takashi Kawashima, Jongmin Kim, and Nick Fox-Gieg. The Quick, Draw! - A.I. Experiment, 2016. URL https://quickdraw.withgoogle.com/.<br />
# C. Kaae Sønderby, T. Raiko, L. Maaløe, S. Kaae Sønderby, and O. Winther. Ladder Variational Autoencoders. ArXiv e-prints, February 2016.<br />
# T. Kim, M. Cha, H. Kim, J. Lee, and J. Kim. Learning to Discover cross-domain Relations with Generative Adversarial Networks. ArXiv e-prints, March 2017.<br />
# D. P Kingma and M. Welling. Auto-Encoding Variational Bayes. ArXiv e-prints, December 2013.<br />
# Diederik Kingma and Jimmy Ba. Adam: A method for stochastic optimization. In ICLR, 2015.<br />
# Diederik P. Kingma, Tim Salimans, and Max Welling. Improving variational inference with inverse autoregressive flow. CoRR, abs/1606.04934, 2016. URL http://arxiv.org/abs/1606.04934.<br />
# Brenden M. Lake, Ruslan Salakhutdinov, and Joshua B. Tenenbaum. Human level concept learning through probabilistic program induction. Science, 350(6266):1332–1338, December 2015. ISSN 1095-9203. doi: 10.1126/science.aab3050. URL http://dx.doi.org/10.1126/science.aab3050.<br />
# Yong Jae Lee, C. Lawrence Zitnick, and Michael F. Cohen. Shadowdraw: Real-time user guidance for freehand drawing. In ACM SIGGRAPH 2011 Papers, SIGGRAPH ’11, pp. 27:1–27:10, New York, NY, USA, 2011. ACM. ISBN 978-1-4503-0943-1. doi: 10.1145/1964921.1964922. URL http://doi.acm.org/10.1145/1964921.1964922.<br />
# M.-Y. Liu, T. Breuel, and J. Kautz. Unsupervised Image-to-Image Translation Networks. ArXiv e-prints, March 2017.<br />
# S. Reed, A. van den Oord, N. Kalchbrenner, S. Gómez Colmenarejo, Z. Wang, D. Belov, and N. de Freitas. Parallel Multiscale Autoregressive Density Estimation. ArXiv e-prints, March 2017.<br />
# Patsorn Sangkloy, Nathan Burnell, Cusuh Ham, and James Hays. The Sketchy Database: Learning to Retrieve Badly Drawn Bunnies. ACM Trans. Graph., 35(4):119:1–119:12, July 2016. ISSN 0730-0301. doi: 10.1145/2897824.2925954. URL http://doi.acm.org/10.1145/2897824.2925954.<br />
# Mike Schuster and Kuldip K. Paliwal. Bidirectional recurrent neural networks. IEEE Transactions on Signal Processing, 1997.<br />
# Saul Simhon and Gregory Dudek. Sketch interpretation and refinement using statistical models. In Proceedings of the Fifteenth Eurographics Conference on Rendering Techniques, EGSR’04, pp. 23–32, Aire-la-Ville, Switzerland, 2004. Eurographics Association. ISBN 3-905673-12-6. doi: 10.2312/EGWR/EGSR04/023-032. URL http://dx.doi.org/10.2312/EGWR/EGSR04/023-032.<br />
# Patrick Tresset and Frederic Fol Leymarie. Portrait drawing by paul the robot. Comput. Graph., 37(5):348–363, August 2013. ISSN 0097-8493. doi: 10.1016/j.cag.2013.01.012. URL http://dx.doi.org/10.1016/j.cag.2013.01.012.<br />
# T. White. Sampling Generative Networks. [https://arxiv.org/abs/1609.04468 ArXiv e-prints], September 2016.<br />
# Ning Xie, Hirotaka Hachiya, and Masashi Sugiyama. Artist agent: A reinforcement learning approach to automatic stroke generation in oriental ink painting. In ICML. icml.cc / Omnipress, 2012. URL http://dblp.uni-trier.de/db/conf/icml/icml2012.html#XieHS12.<br />
# Xu-Yao Zhang, Fei Yin, Yan-Ming Zhang, Cheng-Lin Liu, and Yoshua Bengio. Drawing and Recognizing Chinese Characters with Recurrent Neural Network. CoRR, abs/1606.06539, 2016. URL http://arxiv.org/abs/1606.06539.</div>Gsahuhttp://wiki.math.uwaterloo.ca/statwiki/index.php?title=Wasserstein_Auto-encoders&diff=41940Wasserstein Auto-encoders2018-11-29T23:22:18Z<p>Gsahu: /* Future Work */</p>
<hr />
<div>The first version of this work was published in 2017, and this version (the third revision) was presented at ICLR 2018. Source code for the first version is available [https://github.com/tolstikhin/wae here].<br />
<br />
=Introduction=<br />
Early successes in the field of representation learning were based on supervised approaches, which used large labeled datasets to achieve impressive results. Popular unsupervised generative modeling methods, on the other hand, mainly consisted of probabilistic approaches focusing on low-dimensional data. In recent years, several models have been proposed that try to combine these two approaches. One popular method is the variational auto-encoder (VAE). VAEs are theoretically elegant but have a major drawback: they generate blurry samples when used to model natural images. Generative adversarial networks (GANs), in comparison, produce much sharper samples but have their own problems: they lack an encoder, are harder to train, and suffer from "mode collapse". Mode collapse refers to the inability of the model to capture all the variability in the true data distribution. There has been a lot of activity around designing and evaluating numerous GAN architectures and around combining VAEs and GANs. However, a model that combines the best of both is yet to be discovered.<br />
<br />
This paper builds on the theoretical work of Bousquet et al. (2017) [4]. The authors tackle generative modeling using optimal transport (OT), where the OT cost serves as a measure of distance between probability distributions.<br />
<br />
More specifically, the OT problem is defined as follows:<br />
<br />
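Let <math>c : X \times Y \rightarrow \mathcal{R}</math> be a cost function, and let <math>\mu</math> and <math>\nu</math> be distributions on <math>X</math> and <math>Y</math>. The optimal transport cost between them is <math>C(\mu, \nu) := \inf_{\pi \in \Pi(\mu, \nu)} \int_{X \times Y} c(x, y)\, d\pi(x, y)</math>, where <math>\Pi(\mu, \nu)</math> is the set of joint distributions on <math>X \times Y</math> with marginals <math>\mu</math> and <math>\nu</math>. The goal is to find a plan <math>\pi</math> that attains this infimum.<br />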
<br />
The measures <math>\pi \in \Pi(\mu, \nu)</math> are called transport plans or transference plans, and those achieving the infimum are called optimal transport plans. The classical interpretation is the problem of minimizing the total cost <math>C(\mu, \nu)</math> of transporting the mass distribution <math>\mu</math> to the mass distribution <math>\nu</math>. Here, the cost of transporting one unit of mass from the point <math>x \in X</math> to the point <math>y \in Y</math> is given by the cost function <math>c(x, y)</math>.<br />
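To get a concrete feel for transport plans, here is a small brute-force sketch of a special case. It assumes that <math>\mu</math> and <math>\nu</math> are uniform distributions over the same number <math>n</math> of points. In that case an optimal plan can be taken to be a permutation that matches each point of <math>\mu</math> to exactly one point of <math>\nu</math> (a consequence of the Birkhoff-von Neumann theorem), so we can simply try every permutation. This only works for tiny <math>n</math>; general weights would need a linear-program solver.

```python
from itertools import permutations

def ot_cost_uniform(xs, ys, cost):
    """Exact OT cost C(mu, nu) between uniform distributions on xs and ys (same size).

    Returns (cost, perm), where point xs[i] is transported to ys[perm[i]].
    Brute force over all n! permutations, so only suitable for tiny n.
    """
    n = len(xs)
    if n != len(ys):
        raise ValueError("this sketch assumes equally sized point sets")
    best_cost, best_perm = float("inf"), None
    for perm in permutations(range(n)):
        total = sum(cost(xs[i], ys[perm[i]]) for i in range(n)) / n   # each point has mass 1/n
        if total < best_cost:
            best_cost, best_perm = total, perm
    return best_cost, best_perm

def abs_cost(x, y):
    return abs(x - y)

def squared_cost(x, y):
    return (x - y) ** 2

mu_points = [0.0, 1.0, 2.0]
nu_points = [3.0, 1.0, 2.0]
print(ot_cost_uniform(mu_points, nu_points, abs_cost))      # cost 1.0 (several plans tie)
print(ot_cost_uniform(mu_points, nu_points, squared_cost))  # cost 1.0, perm (1, 2, 0)
```

With the absolute cost several plans tie. The squared cost singles out the monotone matching 0→1, 1→2, 2→3. This is one way the choice of <math>c</math> changes which transport plan is optimal.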
<br />
A beneficial feature of the OT cost is that it induces a much weaker topology than other costs, including the f-divergences associated with the original GAN algorithms.<br />
This matters in applications where the data is supported on low-dimensional manifolds in the input space. In such cases, stronger notions of distance such as f-divergences often max out and provide no useful gradients for training. The OT cost, in comparison, has been claimed to behave much more nicely [5, 8]. Even so, its implementation, which is similar to that of GANs, still requires adding a constraint or a regularization term to the objective function.<br />
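Here is a tiny numerical illustration of the claim above, using a toy example of our own rather than one from the paper. Put all the mass of <math>P</math> at grid point 0 and all the mass of <math>Q</math> at grid point <math>\theta</math>. Because their supports are disjoint, the Jensen-Shannon divergence stays at <math>\log 2</math> no matter how far apart they are, so it gives no signal about which way to move <math>Q</math>. The 1-Wasserstein distance, by contrast, equals <math>\theta</math> and shrinks smoothly as the distributions approach each other.

```python
import math

def js_divergence(p, q):
    """Jensen-Shannon divergence (natural log) between two discrete distributions."""
    js = 0.0
    for pi, qi in zip(p, q):
        mi = 0.5 * (pi + qi)
        if pi > 0:
            js += 0.5 * pi * math.log(pi / mi)
        if qi > 0:
            js += 0.5 * qi * math.log(qi / mi)
    return js

def w1_on_grid(p, q, step=1.0):
    """1-Wasserstein distance on an evenly spaced 1-D grid: area between the two CDFs."""
    total, cdf_p, cdf_q = 0.0, 0.0, 0.0
    for pi, qi in zip(p, q):
        cdf_p += pi
        cdf_q += qi
        total += abs(cdf_p - cdf_q) * step
    return total

def point_mass(position, size):
    """Distribution on grid points 0..size-1 with all mass at `position`."""
    return [1.0 if i == position else 0.0 for i in range(size)]

grid_size = 20
p = point_mass(0, grid_size)
for theta in (1, 5, 10):
    q = point_mass(theta, grid_size)
    print(theta, round(js_divergence(p, q), 4), w1_on_grid(p, q))
# JS stays at log 2 ≈ 0.6931 for every theta > 0, while W1 equals theta.
```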
<br />
==Original Contributions==<br />
Let <math>P_X</math> be the true but unknown data distribution, and let <math>P_G</math> be the latent variable model specified by two parts: the prior distribution <math>P_Z</math> of latent codes <math>Z \in \mathcal{Z}</math>, and the generative model <math>P_G(X|Z)</math> of the data points <math>X \in \mathcal{X}</math> given <math>Z</math>. The goal of this paper is to minimize the OT cost <math>W_c(P_X, P_G)</math>.<br />
<br />
The main contributions are given below:<br />
<br />
* A new class of auto-encoders called Wasserstein Auto-Encoders (WAE). WAEs minimize the optimal transport cost <math>W_c(P_X, P_G)</math> for any cost function <math>c</math>. As with VAEs, the WAE objective is made up of two terms. The first is the c-reconstruction cost. The second is a regularizer <math>\mathcal{D}_Z(P_Z, Q_Z)</math> that penalizes the discrepancy between two distributions in <math>\mathcal{Z}</math>, namely <math>P_Z</math> and <math>Q_Z</math>. Here <math>Q_Z := \mathbb{E}_{P_X}[Q(Z|X)]</math> is the distribution of encoded points. Note that when <math>c</math> is the squared cost and the regularizer is the GAN objective, WAE is equivalent to the adversarial auto-encoders described in [2].<br />
<br />
* Experimental results of using WAE on MNIST and CelebA datasets with squared cost <math>c(x, y) = ||x - y||_2^2</math>. The results of these experiments show that WAEs have the good features of VAEs such as stable training, encoder-decoder architecture, and a nice latent manifold structure while simultaneously improving the quality of the generated samples.<br />
<br />
* Two different regularizers. One is based on GANs and adversarial training in the latent space <math>\mathcal{Z}</math>. The other is based on the maximum mean discrepancy (MMD), which is known to perform well when matching high-dimensional standard normal distributions. The second regularizer turns the problem into a fully adversary-free min-min optimization problem and removes the need to tune a GAN.<br />
<br />
* The final contribution is the mathematical analysis used to derive the WAE objective. In particular, it shows that for generative models, the primal form of <math>W_c(P_X, P_G)</math> is equivalent to a problem involving the optimization of a probabilistic encoder <math>Q(Z|X)</math>.<br />
<br />
The paper provides an ostensibly simple recipe for implementing a generative, VAE-like model that does not produce blurry samples. It offers what looks like an elegant and logical way to bring the Wasserstein distance into the VAE/GAN setup.<br />
The paper also gives three instructive comparisons with related VAE/GAN models and unifies them thematically: Adversarial Auto-Encoders (AAE), Adversarial Variational Bayes (AVB), and the original Variational Auto-Encoders (VAE). These generalizations arise in the case of random decoders. The paper first introduces the idea with deterministic decoders and then extends it to random ones. The comparisons center on the VAE regularizer, which AAE and AVB replace with a GAN-based term.<br />
<br />
=Proposed Method=<br />
The method proposed by the authors uses a novel auto-encoder architecture to minimize the optimal transport cost <math>W_c(P_X, P_G)</math>. In the resulting optimization problem, the decoder tries to reconstruct the data points accurately, as measured by the cost function <math>c</math>. The encoder tries to achieve two conflicting goals at the same time. First, it tries to match the distribution of the encoded data points, <math>Q_Z := \mathbb{E}_{P_X}[Q(Z|X)]</math>, to the prior distribution <math>P_Z</math>, as measured by the divergence <math>\mathcal{D}_Z(P_Z, Q_Z)</math>. Second, it must ensure that the encoded latent vectors contain enough information to reconstruct the data points with high quality. The figure below illustrates this:<br />
<br />
[[File:ka2khan_figure_1.png|800px|thumb|center|Figure 1]]<br />
<br />
Figure 1: The objectives of both VAE and WAE have two terms: the reconstruction cost and a regularizer that penalizes the discrepancy between <math>P_Z</math> and the distribution induced by the encoder. VAE forces <math>Q(Z|X = x)</math> to match <math>P_Z</math> for every training example <math>x</math> drawn from <math>P_X</math>. In the figure, each red ball represents <math>Q(Z|X = x)</math> for one example and is forced to match <math>P_Z</math>, which is depicted as the white shape. The red balls therefore start to intersect, which causes reconstruction problems. WAE, in contrast, forces only the mixture <math>Q_Z := \int Q(Z|X)\, dP_X</math> to match <math>P_Z</math>, depicted as the green ball. This gives the latent codes of different examples a chance to stay far apart from each other, which leads to higher reconstruction quality.<br />
<br />
==Preliminaries and Notations==<br />
The authors use calligraphic letters for sets (for example, <math>\mathcal{X}</math>), capital letters for random variables (for example, <math>X</math>), and lowercase letters for their values (for example, <math>x</math>). Probability distributions are also denoted with capital letters (for example, <math>P(X)</math>), and the corresponding densities with lowercase letters (for example, <math>p(x)</math>).<br />
<br />
The authors also use several measures of difference between probability distributions. These include f-divergences, given by <math>D_f(P_X \| P_G) := \int f\left(\frac{p_X(x)}{p_G(x)}\right) p_G(x)\, dx</math>, where <math>f: (0, \infty) \rightarrow \mathcal{R}</math> is any convex function satisfying <math>f(1) = 0</math>. Other divergences used include the KL divergence (<math>D_{KL}</math>) and the Jensen-Shannon divergence (<math>D_{JS}</math>).<br />
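For discrete distributions, the f-divergence definition becomes concrete once the integral is replaced with a sum. The sketch below assumes every <math>q(x)</math> is positive. It uses two standard generator functions, both convex with <math>f(1) = 0</math>. The first, <math>f(t) = t \log t</math>, gives the KL divergence. The second, <math>f(t) = \tfrac{1}{2}[t \log t - (t + 1)\log\tfrac{t + 1}{2}]</math>, gives the Jensen-Shannon divergence.

```python
import math

def f_divergence(p, q, f):
    """D_f(P || Q) = sum_x q(x) * f(p(x) / q(x)) for discrete P, Q with all q(x) > 0."""
    return sum(qi * f(pi / qi) for pi, qi in zip(p, q))

def f_kl(t):
    """Generator of the KL divergence; f(0) = 0 by the convention 0 log 0 = 0."""
    return t * math.log(t) if t > 0 else 0.0

def f_js(t):
    """Generator of the Jensen-Shannon divergence (natural log)."""
    return 0.5 * (f_kl(t) - (t + 1.0) * math.log((t + 1.0) / 2.0))

p = [0.5, 0.3, 0.2]
q = [0.2, 0.3, 0.5]
kl = f_divergence(p, q, f_kl)
js = f_divergence(p, q, f_js)

# Cross-check against the textbook formulas for KL and JS.
kl_direct = sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))
m = [(pi + qi) / 2.0 for pi, qi in zip(p, q)]
js_direct = (0.5 * sum(pi * math.log(pi / mi) for pi, mi in zip(p, m))
             + 0.5 * sum(qi * math.log(qi / mi) for qi, mi in zip(q, m)))
print(kl, kl_direct, js, js_direct)
```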
<br />
==Optimal Transport and its Dual Formations==<br />
<br />
A rich class of distance measures between probability distributions is motivated by the optimal transport problem. One formulation of this problem is Kantorovich's formulation, given by:<br />
<br />
<center><math><br />
W_c(P_X, P_G) := \inf_{\Gamma \in \mathcal{P}(X \sim P_X, Y \sim P_G)} \mathbb{E}_{(X,Y) \sim \Gamma}[c(X,Y)],<br />
\text{where} \ c(x, y): \mathcal{X} \times \mathcal{X} \rightarrow \mathcal{R}_{+}<br />
</math></center><br />
<br />
is any measurable cost function and <math>\mathcal{P}(X \sim P_X, Y \sim P_G)</math> is the set of all joint distributions of <math>(X, Y)</math> with marginals <math>P_X</math> and <math>P_G</math>, respectively.<br />
<br />
A particularly interesting case is when <math>(\mathcal{X}, d)</math> is a metric space and <math>c(x, y) = d^p(x, y)</math> for <math>p \geq 1</math>. In this case <math>W_p</math>, the <math>p</math>-th root of <math>W_c</math>, is called the <math>p</math>-Wasserstein distance.<br />
<br />
When <math>c(x, y) = d(x, y)</math> the following Kantorovich-Rubinstein duality holds:<br />
<br />
<math>W_1(P_X, P_G) = \sup_{f \in \mathcal{F}_L} \mathbb{E}_{X \sim P_X}[f(X)] - \mathbb{E}_{Y \sim P_G}[f(Y)]</math><br />
where <math>\mathcal{F}_L</math> is the class of all bounded 1-Lipschitz functions on <math>(\mathcal{X}, d)</math>.<br />
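The duality can be checked numerically in one dimension. On an evenly spaced grid, <math>W_1</math> equals the area between the two cumulative distribution functions <math>F_P</math> and <math>F_Q</math>. A 1-Lipschitz function that attains the supremum can be built step by step: let <math>f</math> rise by one grid step wherever <math>F_Q > F_P</math> and fall by one step wherever <math>F_Q < F_P</math>. This construction is specific to the 1-D case and is our own illustration, not part of the paper.

```python
def w1_and_optimal_potential(p, q, step=1.0):
    """W_1 between discrete P and Q on a 1-D grid, plus a 1-Lipschitz f attaining the dual sup.

    W_1 = sum_k |F_P(t_k) - F_Q(t_k)| * step. The potential f starts at 0 and moves by
    +step where F_Q > F_P and by -step where F_Q < F_P, so |f(t_{k+1}) - f(t_k)| <= step.
    """
    cdf_p = cdf_q = 0.0
    w1, f = 0.0, [0.0]
    for pk, qk in zip(p[:-1], q[:-1]):   # at the last grid point both CDFs equal 1
        cdf_p += pk
        cdf_q += qk
        gap = cdf_q - cdf_p
        w1 += abs(gap) * step
        slope = 1.0 if gap > 0 else (-1.0 if gap < 0 else 0.0)
        f.append(f[-1] + slope * step)
    return w1, f

def expectation(weights, values):
    return sum(w * v for w, v in zip(weights, values))

p = [0.1, 0.4, 0.0, 0.3, 0.2]   # P on grid points 0, 1, 2, 3, 4
q = [0.3, 0.0, 0.2, 0.1, 0.4]   # Q on the same grid
w1, f = w1_and_optimal_potential(p, q)
dual_value = expectation(p, f) - expectation(q, f)                    # E_P[f] - E_Q[f]
identity_value = expectation(p, range(5)) - expectation(q, range(5))  # f(t) = t is also 1-Lipschitz
print(w1, dual_value, identity_value)   # W_1 equals dual_value and is at least identity_value
```

Summation by parts shows that <math>\mathbb{E}_P[f] - \mathbb{E}_Q[f] = \sum_k |F_Q(t_k) - F_P(t_k)| \cdot \text{step}</math> for this <math>f</math>. Any other 1-Lipschitz function, such as the identity, gives a value no larger.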
<br />
==Application to Generative Models: Wasserstein auto-encoders==<br />
Modern generative models like VAEs and GANs can be understood as minimizing specific discrepancy measures between the data distribution <math>P_X</math> and the model <math>P_G</math>. Unfortunately, with current knowledge and tools, most standard discrepancy measures are very hard or even impossible to compute. This is especially true when <math>P_X</math> is unknown and <math>P_G</math> is parametrized by deep neural networks. However, certain tricks can be used to get around this difficulty.<br />
<br />
To minimize the KL-divergence <math>D_{KL}(P_X, P_G)</math>, or equivalently to maximize the marginal log-likelihood <math>\mathbb{E}_{P_X}[\log p_G(X)]</math>, one can use the well-known variational lower bound, which provides a theoretically grounded framework. VAEs have used this approach quite successfully. In the general case of minimizing an f-divergence <math>D_f(P_X, P_G)</math>, one can use its dual formulation together with adversarial training, as in f-GANs. Finally, the OT cost <math>W_c(P_X, P_G)</math> with <math>c = d</math> (that is, <math>W_1</math>) can be minimized by expressing the Kantorovich-Rubinstein duality as an adversarial objective. The Wasserstein GAN implements this idea.<br />
<br />
In this paper, the authors focus on latent variable models <math>P_G</math> defined by a two-step procedure. First, a code <math>Z</math> is sampled from a fixed distribution <math>P_Z</math> on a latent space <math>\mathcal{Z}</math>. Second, <math>Z</math> is mapped to the image <math>X \in \mathcal{X} = \mathcal{R}^d</math> by a (possibly random) transformation. This gives a density of the form<br />
<br />
<center><math><br />
p_G(x) := \int\limits_{\mathcal{Z}}{p_G(x|z)p_Z(z)}dz,\ \forall x \in \mathcal{X}, <br />
</math></center><br />
<br />
provided all the probabilities involved are properly defined. To keep things simple, the authors focus on non-random decoders, i.e., generative models <math>P_G(X|Z)</math> that deterministically map <math>Z</math> to <math>X = G(Z)</math> using a fixed map <math>G: \mathcal{Z} \rightarrow \mathcal{X}</math>. Similar results hold for random decoders, as shown by the authors in Appendix B.1.<br />
<br />
Under this model, the authors find that the OT cost takes a much simpler form, because the transport plan factors through the map <math>G</math>. Instead of finding a coupling <math>\Gamma</math> between two random variables in <math>\mathcal{X}</math>, one distributed according to <math>P_X</math> and the other according to <math>P_G</math>, it is enough to find a conditional distribution <math>Q(Z|X)</math> whose <math>Z</math>-marginal <math>Q_Z(Z) := \mathbb{E}_{X \sim P_X}[Q(Z|X)]</math> is identical to the prior distribution <math>P_Z</math>. This is formalized by the theorem below, which was proven in [4].<br />
<br />
'''Theorem 1.''' For <math>P_G</math> defined as above with deterministic <math>P_G(X|Z)</math> and any function <math>G:\mathcal{Z} \rightarrow \mathcal{X}</math>,<br />
<br />
<math><br />
\inf_{\Gamma \in \mathcal{P}(X \sim P_X, Y \sim P_G)} \mathbb{E}_{(X,Y) \sim \Gamma}[c(X,Y)] = \inf_{Q:\, Q_Z = P_Z} \mathbb{E}_{P_X} \mathbb{E}_{Q(Z|X)}[c(X, G(Z))]<br />
</math><br />
<br />
where <math>Q_Z</math> is the marginal distribution of <math>Z</math> when <math>X \sim P_X</math> and <math>Z \sim Q(Z|X)</math>.<br />
<br />
According to the authors, this result allows optimizing over random encoders <math>Q(Z|X)</math> instead of over all couplings of <math>X</math> and <math>Y</math>. Both problems are still constrained. To find a numerical solution, the authors relax the constraint on <math>Q_Z</math> by adding a regularizer term to the objective. This gives the WAE objective:<br />
<br />
<math><br />
D_{WAE}(P_X, P_G) := \inf_{Q(Z|X) \in \mathcal{Q}} \mathbb{E}_{P_X} \mathbb{E}_{Q(Z|X)}[c(X, G(Z))] + \lambda \cdot \mathcal{D}_Z(Q_Z, P_Z)<br />
</math><br />
<br />
where <math>\mathcal{Q}</math> is any nonparametric set of probabilistic encoders, <math>\mathcal{D}_Z</math> is an arbitrary measure of distance between <math>Q_Z</math> and <math>P_Z</math>, and <math>\lambda > 0</math> is a hyperparameter. As with VAEs, the authors propose using deep neural networks to parameterize both the encoder <math>Q</math> and the decoder <math>G</math>. Note that, unlike VAEs, WAE allows non-random encoders that deterministically map their inputs to latent codes.<br />
<br />
The authors propose two different regularizers <math>\mathcal{D}_Z(Q_Z, P_Z)</math>:<br />
<br />
===GAN-based <math>\mathcal{D}_Z</math>===<br />
One option is to choose <math>\mathcal{D}_Z(Q_Z, P_Z) = D_{JS}(Q_Z, P_Z)</math> and estimate it with adversarial training. In particular, a discriminator (adversary) operating in the latent space <math>\mathcal{Z}</math> tries to distinguish "true" points sampled from <math>P_Z</math> from "fake" ones sampled from <math>Q_Z</math>. This leads to WAE-GAN, described in Algorithm 1 listed below. WAE-GAN still relies on max-min optimization, but it has one positive feature: it moves the adversary from the input (pixel) space <math>\mathcal{X}</math> to the latent space <math>\mathcal{Z}</math>. Additionally, the true latent distribution <math>P_Z</math> may have a nice shape with a single mode (for a Gaussian prior). That makes matching much easier than matching an unknown, complex, and possibly multi-modal distribution, as is usually required in GANs. This also motivates the second penalty.<br />
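The two WAE-GAN objectives can be written out for a minibatch. The sketch below follows the structure of Algorithm 1 as we read it from the paper. The adversary maximizes <math>\lambda</math> times the average of <math>\log D(z_i) + \log(1 - D(\tilde z_i))</math>, computed over prior samples <math>z_i</math> and encoded samples <math>\tilde z_i</math>. The encoder and decoder minimize the average squared reconstruction cost minus <math>\lambda \log D(\tilde z_i)</math>. The linear encoder and decoder, the logistic discriminator, and all sizes are hypothetical placeholders. The gradient steps, which would normally come from an autodiff framework, are not shown.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def wae_gan_objectives(x, z_prior, encode, decode, disc, lam, eps=1e-12):
    """Minibatch WAE-GAN objectives with squared cost c(x, y) = ||x - y||^2.

    Returns (adversary objective to MAXIMIZE, encoder/decoder objective to MINIMIZE).
    """
    z_tilde = encode(x)        # codes of the data: samples from Q_Z
    d_real = disc(z_prior)     # D(z) on samples z ~ P_Z ("true" points)
    d_fake = disc(z_tilde)     # D(z~) on encoded points ("fake" points)
    adv_obj = lam * np.mean(np.log(d_real + eps) + np.log(1.0 - d_fake + eps))
    recon = np.mean(np.sum((x - decode(z_tilde)) ** 2, axis=1))
    ae_obj = recon - lam * np.mean(np.log(d_fake + eps))
    return adv_obj, ae_obj

# Hypothetical linear encoder/decoder and logistic discriminator, for illustration only.
rng = np.random.default_rng(0)
n, d_x, d_z = 64, 10, 2
x = rng.standard_normal((n, d_x))
W_enc = 0.1 * rng.standard_normal((d_x, d_z))
W_dec = 0.1 * rng.standard_normal((d_z, d_x))
w_disc = rng.standard_normal(d_z)
z_prior = rng.standard_normal((n, d_z))   # Gaussian prior P_Z = N(0, I)

adv_obj, ae_obj = wae_gan_objectives(
    x, z_prior,
    encode=lambda v: v @ W_enc,
    decode=lambda z: z @ W_dec,
    disc=lambda z: sigmoid(z @ w_disc),
    lam=1.0,
)
print(adv_obj, ae_obj)
```

The adversary only ever sees low-dimensional latent codes, never images. That is the practical benefit of moving the adversary into <math>\mathcal{Z}</math>.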
<br />
===MMD-based <math>\mathcal{D}_Z</math>===<br />
For a positive-definite reproducing kernel <math>k: \mathcal{Z} \times \mathcal{Z} \rightarrow \mathcal{R}</math>, the maximum mean discrepancy (MMD) is defined as:<br />
<br />
<center><math><br />
MMD_k(P_Z, Q_Z) = \left \Vert \int \limits_{\mathcal{Z}} {k(z, \cdot)dP_Z(z)} - \int \limits_{\mathcal{Z}} {k(z, \cdot)dQ_Z(z)} \right \|_{\mathcal{H}_k}<br />
</math>,</center><br />
<br />
where <math>\mathcal{H}_k</math> is the RKHS (reproducing kernel Hilbert space) of real-valued functions mapping <math>\mathcal{Z}</math> to <math>\mathcal{R}</math>. If <math>k</math> is characteristic, then <math>MMD_k</math> defines a metric and can be used as a distance measure. The authors propose to use <math>\mathcal{D}_Z(P_Z, Q_Z) = MMD_k(P_Z, Q_Z)</math>. MMD also has an unbiased U-statistic estimator that can be used with stochastic gradient descent (SGD) methods. This gives WAE-MMD, described in Algorithm 2 listed below. MMD is known to perform well when matching high-dimensional standard normal distributions, so this penalty is expected to work well when the prior <math>P_Z</math> is Gaussian.<br />
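Below is a minimal NumPy sketch of two pieces. The first is the unbiased U-statistic estimator of <math>MMD_k^2</math>, the square of the quantity defined above. The second is a minibatch WAE-MMD objective that adds <math>\lambda</math> times this estimate to the squared reconstruction cost. The inverse multiquadratic kernel <math>k(u, v) = C / (C + \|u - v\|^2)</math> and the constant <math>C</math> are assumptions chosen for illustration; any characteristic kernel could be substituted. The usage lines compare a fresh sample from the prior against a sample shifted away from it. They also evaluate the objective with a hypothetical linear encoder and decoder.

```python
import numpy as np

def imq_kernel(a, b, scale):
    """Inverse multiquadratic kernel C / (C + ||u - v||^2) for all row pairs of a and b."""
    sq_dists = (np.sum(a ** 2, axis=1)[:, None]
                + np.sum(b ** 2, axis=1)[None, :]
                - 2.0 * a @ b.T)
    return scale / (scale + np.maximum(sq_dists, 0.0))

def mmd2_unbiased(z_prior, z_enc, scale):
    """Unbiased U-statistic estimate of MMD_k(P_Z, Q_Z)^2 from equal-size samples."""
    n = z_prior.shape[0]
    off_diag = ~np.eye(n, dtype=bool)   # exclude i == j from the within-sample sums
    k_pp = imq_kernel(z_prior, z_prior, scale)[off_diag].sum() / (n * (n - 1))
    k_qq = imq_kernel(z_enc, z_enc, scale)[off_diag].sum() / (n * (n - 1))
    k_pq = imq_kernel(z_prior, z_enc, scale).sum() / (n * n)
    return k_pp + k_qq - 2.0 * k_pq

def wae_mmd_objective(x, z_prior, encode, decode, lam, scale):
    """Minibatch WAE-MMD objective with squared cost and a deterministic encoder."""
    z_enc = encode(x)
    recon = np.mean(np.sum((x - decode(z_enc)) ** 2, axis=1))
    return recon + lam * mmd2_unbiased(z_prior, z_enc, scale)

rng = np.random.default_rng(0)
d_z = 2
scale = 2.0 * d_z                 # placeholder kernel constant C
z_p = rng.standard_normal((200, d_z))
same = mmd2_unbiased(z_p, rng.standard_normal((200, d_z)), scale)
shifted = mmd2_unbiased(z_p, rng.standard_normal((200, d_z)) + 3.0, scale)
print(round(same, 4), round(shifted, 4))   # the shifted sample gives a much larger value

# Hypothetical linear encoder/decoder to evaluate one WAE-MMD minibatch objective.
x = rng.standard_normal((200, 10))
W_enc = 0.1 * rng.standard_normal((10, d_z))
W_dec = 0.1 * rng.standard_normal((d_z, 10))
objective = wae_mmd_objective(x, z_p, lambda v: v @ W_enc, lambda z: z @ W_dec, lam=10.0, scale=scale)
```

Because the penalty is a plain function of two samples, the whole objective can be minimized with ordinary SGD, with no adversary to train.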
<br />
[[File:ka2khan_figure_2.png|800px|thumb|center|Algorithms- WAE-GAN on left and WAE-MMD on right]]<br />
<br />
=Related Work=<br />
==Literature on auto-encoders==<br />
Classical unregularized auto-encoders have an objective function that only minimizes the reconstruction cost. As a result, distinct data points are encoded into distinct zones distributed chaotically across the latent space <math>\mathcal{Z}</math>. In this scenario, <math>\mathcal{Z}</math> contains huge "holes" on which the decoder <math>P_G(X|Z)</math> has never been trained. In general, encoders trained this way do not provide very useful representations, and sampling from the latent space <math>\mathcal{Z}</math> becomes difficult [12].<br />
<br />
VAEs [1] minimize an upper bound on the KL-divergence <math>D_{KL}(P_X, P_G)</math>. This bound consists of the reconstruction cost and the regularizer <math>\mathbb{E}_{P_X}[D_{KL}(Q(Z|X), P_Z)]</math>. The regularizer penalizes the discrepancy between the encoding distribution <math>Q(Z|X = x)</math> of each training image and the prior <math>P_Z</math>. However, this penalty does not guarantee that the overall encoded distribution <math>Q_Z</math> matches the prior, as WAE does. In addition, VAEs require non-degenerate (i.e. non-deterministic) Gaussian encoders and random decoders. Later, [11] proposed a method that allows non-Gaussian encoders in VAEs. WAE, in contrast, minimizes <math>W_c(P_X, P_G)</math> and allows both probabilistic and deterministic encoder-decoder pairs.<br />
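Consider the common case of a Gaussian encoder <math>Q(Z|X=x) = \mathcal{N}(\mu(x), \mathrm{diag}\,\sigma^2(x))</math> with a standard normal prior <math>P_Z = \mathcal{N}(0, I)</math>. There, the per-example VAE regularizer has a closed form. The sketch below assumes that parameterization; the encoder outputs <math>\mu</math> and <math>\log \sigma^2</math> are made-up numbers. Averaging over a minibatch estimates <math>\mathbb{E}_{P_X}[D_{KL}(Q(Z|X), P_Z)]</math>.

```python
import numpy as np

def kl_diag_gaussian_to_std_normal(mu, log_var):
    """KL( N(mu, diag(exp(log_var))) || N(0, I) ) per example, summed over latent dims."""
    return 0.5 * np.sum(np.exp(log_var) + mu ** 2 - 1.0 - log_var, axis=-1)

# Hypothetical encoder outputs for a minibatch of 3 images with a 2-D latent space.
mu = np.array([[0.0, 0.0], [1.0, 0.0], [0.5, -0.5]])
log_var = np.array([[0.0, 0.0], [0.0, 0.0], [-1.0, -1.0]])
per_example = kl_diag_gaussian_to_std_normal(mu, log_var)
vae_regularizer = per_example.mean()   # minibatch estimate of E_{P_X}[D_KL(Q(Z|X), P_Z)]
print(per_example, vae_regularizer)
```

Each per-example term is zero only when that example's <math>Q(Z|X=x)</math> equals the prior. This is exactly the pressure on every individual red ball in Figure 1, whereas WAE penalizes only the mixture <math>Q_Z</math>.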
<br />
With appropriately chosen parameters, WAE generalizes AAE in two ways. It can use any cost function <math>c</math> in the input space <math>\mathcal{X}</math>, and it can use any discrepancy measure <math>\mathcal{D}_Z</math> in the latent space <math>\mathcal{Z}</math>, not only the adversarial one.<br />
<br />
Related work on regularized auto-encoders includes InfoVAE [14], whose objective is similar to that of [4] but is derived from different motivations and arguments.<br />
<br />
WAEs explicitly define the cost function <math>c(x,y)</math>, whereas VAEs rely on one implicitly through a negative log-likelihood term. In theory, that term can induce an arbitrary cost function. In practice, however, this may require estimating a normalizing constant that can differ across values of <math>z</math>.<br />
<br />
==Literature on Optimal Transport (OT)==<br />
[15] provides methods for computing the OT cost for large-scale data using SGD and sampling. WGAN [5] proposes a generative model that minimizes the 1-Wasserstein distance <math>W_1(P_X, P_G)</math>. The WGAN algorithm does not provide an encoder and cannot easily be applied to an arbitrary cost <math>W_c</math>. The model proposed in [5] uses the dual form. The model proposed in this paper uses the primal form instead, which allows any cost function <math>c</math> and naturally comes with an encoder.<br />
<br />
To compute <math>W_c(P_X, P_G)</math> or <math>W_1(P_X, P_G)</math>, a model needs to handle various non-trivial constraints. Various methods have been proposed in the literature ([5], [2], [8], [16], [15], [17], [18]) to get around this difficulty.<br />
<br />
==Literature on GANs==<br />
Many of the GAN variants proposed in the literature, such as WGAN and f-GAN, come without an encoder. These models fall short in applications that need the latent codes of data points in order to use the learned manifold.<br />
<br />
Numerous models in the literature try to combine the adversarial training of GANs with auto-encoder architectures, for example [19], [20], [21], and [22]. There has also been work that uses reproducing kernels in the context of GANs ([23], [24]).<br />
<br />
=Experiments=<br />
The authors evaluate the proposed WAE model empirically.<br />
<br />
'''Experimental setup'''<br />
<br />
In the experimental setup, the authors used an isotropic Gaussian prior <math> \small P_Z</math> and the squared cost function <math> \small c(x,y) = \|x - y\|_2^2</math> for data points.<br />
Deterministic encoder-decoder pairs were used. The authors conducted experiments on two real-world datasets: (1) MNIST [27], made up of 70k images, and (2) CelebA [28], consisting of approximately 203k images. For test reconstructions and interpolations, a pair of held-out images <math>(x,y)</math> from the test set is auto-encoded (separately) to produce <math>(z_x, z_y)</math> in the latent space.<br />
<br />
The experiments assessed whether the WAE model can simultaneously achieve:<br />
<br />
<ol><br />
<li>accurate reconstruction of the data points</li><br />
<li>reasonable geometry of the latent manifold</li><br />
<li>generation of high quality random samples</li><br />
</ol><br />
<br />
For the model to generalize well, (1) and (2) should hold on both the training and test sets.<br />
<br />
The proposed model achieves reasonably good results, as highlighted in the figures below:<br />
<br />
[[File:ka2khan_figure_3.png|800px|thumb|center|Using CelebA dataset]]<br />
<br />
[[File:ka2khan_figure_4.png|800px|thumb|center|Using CelebA dataset, FID (Fréchet Inception Distance [32]): smaller is better, sharpness: larger is better]]<br />
<br />
=Conclusion=<br />
The authors proposed Wasserstein Auto-Encoders, a new class of algorithms for building generative models based on the optimal transport cost. They related the proposed model to existing probabilistic modeling techniques and evaluated it empirically on two real-world datasets. Comparing their results with those of VAEs on the same datasets, they showed that WAEs generate higher-quality samples while being easy to train and reconstructing data points well.<br />
<br />
The authors state that in future work they will further explore criteria for matching the encoding distribution <math>Q_Z</math> to the prior distribution <math>P_Z</math>, evaluate whether it is feasible to adversarially train the cost function <math>c</math> in the input space <math>\mathcal{X}</math>, and analyze the dual formulations of WAE-GAN and WAE-MMD theoretically.<br />
<br />
=Future Work=<br />
Following this paper, [34] introduced another generative model based on optimal transport. Optimal transport measures the distance between two probability distributions as the minimal cost of transporting one distribution onto the other (hence the name). The model of [34], called Sliced-Wasserstein Autoencoders (SWAE), is simple to implement and provides the capabilities of Wasserstein Autoencoders.<br />
<br />
([https://openreview.net/forum?id=HkL7n1-0b]) The results on the MNIST and CelebA datasets look convincing, though the paper could include additional evaluation comparing the adversarial loss with the more straightforward MMD metric and discuss their pros and cons. Given the challenges in evaluating and comparing closely related auto-encoder solutions, the authors could also design demonstrative experiments showing cases where the Wasserstein distance helps, as well as its potential limitations.<br />
<br />
=Critique=<br />
<br />
Although this paper presents empirical tests that illustrate its method appropriately, it would benefit from clearer notation, including the details of the architectures used in the experiments. It could also benefit from comparisons between its results and those of similar works. As a reviewer pointed out, the closest work to this paper is the Adversarial Variational Bayes framework of Mescheder et al. [11], which also attempts to unify VAEs and GANs. Although the authors describe the conceptual differences and advantages over that approach, it would be beneficial to include actual comparisons in the results section.<br />
Moreover, the performance gain over previous VAE algorithms is not significant, and empirical tests on a wider range of datasets would help characterize it. However, the methodology is flexible and unifies several other algorithms, which is a major benefit.<br />
<br />
=References=<br />
[1] D. P. Kingma and M. Welling. Auto-encoding variational Bayes. In ICLR, 2014.<br />
<br />
[2] A. Makhzani, J. Shlens, N. Jaitly, and I. Goodfellow. Adversarial autoencoders. In ICLR, 2016.<br />
<br />
[3] I. Goodfellow, J. Pouget-Abadie, M. Mirza, B. Xu, D. Warde-Farley, S. Ozair, A. Courville, and Y. Bengio. Generative adversarial nets. In NIPS, pages 2672–2680, 2014.<br />
<br />
[4] O. Bousquet, S. Gelly, I. Tolstikhin, C. J. Simon-Gabriel, and B. Schölkopf. From optimal transport to generative modeling: the VEGAN cookbook, 2017.<br />
<br />
[5] M. Arjovsky, S. Chintala, and L. Bottou. Wasserstein GAN, 2017.<br />
<br />
[6] C. Villani. Topics in Optimal Transportation. AMS Graduate Studies in Mathematics, 2003.<br />
<br />
[7] S. Nowozin, B. Cseke, and R. Tomioka. f-GAN: Training generative neural samplers using variational divergence minimization. In NIPS, 2016.<br />
<br />
[8] I. Gulrajani, F. Ahmed, M. Arjovsky, V. Dumoulin, and A. Courville. Improved training of Wasserstein GANs, 2017.<br />
<br />
[9] A. Gretton, K. M. Borgwardt, M. J. Rasch, B. Schölkopf, and A. J. Smola. A kernel two-sample test. Journal of Machine Learning Research, 13:723–773, 2012.<br />
<br />
[10] F. Liese and K.-J. Miescke. Statistical Decision Theory. Springer, 2008.<br />
<br />
[11] L. Mescheder, S. Nowozin, and A. Geiger. Adversarial variational Bayes: Unifying variational autoencoders and generative adversarial networks, 2017.<br />
<br />
[12] Y. Bengio, A. Courville, and P. Vincent. Representation learning: A review and new perspectives. IEEE Transactions on Pattern Analysis and Machine Intelligence, 35, 2013.<br />
<br />
[13] M. D. Hoffman and M. Johnson. ELBO surgery: yet another way to carve up the variational evidence lower bound. In NIPS Workshop on Advances in Approximate Bayesian Inference, 2016.<br />
<br />
[14] S. Zhao, J. Song, and S. Ermon. InfoVAE: Information maximizing variational autoencoders, 2017.<br />
<br />
[15] A. Genevay, M. Cuturi, G. Peyré, and F. R. Bach. Stochastic optimization for large-scale optimal transport. In Advances in Neural Information Processing Systems, pages 3432–3440, 2016. <br />
<br />
[16] M. Cuturi. Sinkhorn distances: Lightspeed computation of optimal transport. In Advances in Neural Information Processing Systems, pages 2292–2300, 2013.<br />
<br />
[17] L. Chizat, G. Peyré, B. Schmitzer, and F.-X. Vialard. Unbalanced optimal transport: geometry and Kantorovich formulation. arXiv preprint arXiv:1508.05216, 2015.<br />
<br />
[18] M. Liero, A. Mielke, and G. Savaré. Optimal entropy-transport problems and a new Hellinger-Kantorovich distance between positive measures. arXiv preprint arXiv:1508.07941, 2015.<br />
<br />
[19] J. Zhao, M. Mathieu, and Y. LeCun. Energy-based generative adversarial network. In ICLR, 2017.<br />
<br />
[20] V. Dumoulin, I. Belghazi, B. Poole, A. Lamb, M. Arjovsky, O. Mastropietro, and A. Courville. Adversarially learned inference. In ICLR, 2017.<br />
<br />
[21] D. Ulyanov, A. Vedaldi, and V. Lempitsky. It takes (only) two: Adversarial generator-encoder networks, 2017.<br />
<br />
[22] D. Berthelot, T. Schumm, and L. Metz. BEGAN: Boundary equilibrium generative adversarial networks, 2017.<br />
<br />
[23] Y. Li, K. Swersky, and R. Zemel. Generative moment matching networks. In ICML, 2015. <br />
<br />
[24] G. K. Dziugaite, D. M. Roy, and Z. Ghahramani. Training generative neural networks via maximum mean discrepancy optimization. In UAI, 2015.<br />
<br />
[25] S. Reddi, A. Ramdas, A. Singh, B. Poczos, and L. Wasserman. On the high-dimensional power of a linear-time two sample test under mean-shift alternatives. In AISTATS, 2015.<br />
<br />
[26] C. L. Li, W. C. Chang, Y. Cheng, Y. Yang, and B. Poczos. MMD GAN: Towards deeper understanding of moment matching network, 2017.<br />
<br />
[27] Y. LeCun, L. Bottou, Y. Bengio, and P. Haffner. Gradient-based learning applied to document recognition. In Proceedings of the IEEE, volume 86(11), pages 2278–2324, 1998.<br />
<br />
[28] Z. Liu, P. Luo, X. Wang, and X. Tang. Deep learning face attributes in the wild. In Proceedings of International Conference on Computer Vision (ICCV), 2015.<br />
<br />
[29] D. P. Kingma and J. Ba. Adam: A method for stochastic optimization, 2014.<br />
<br />
[30] A. Radford, L. Metz, and S. Chintala. Unsupervised representation learning with deep convolutional generative adversarial networks. In ICLR, 2016.<br />
<br />
[31] S. Ioffe and C. Szegedy. Batch normalization: Accelerating deep network training by reducing internal covariate shift, 2015.<br />
<br />
[32] M. Heusel, H. Ramsauer, T. Unterthiner, B. Nessler, G. Klambauer, and S. Hochreiter. GANs trained by a two time-scale update rule converge to a nash equilibrium. arXiv preprint arXiv:1706.08500, 2017.<br />
<br />
[33] B. Poole, A. Alemi, J. Sohl-Dickstein, and A. Angelova. Improved generator objectives for GANs, 2016.<br />
<br />
[34] S. Kolouri, C. E. Martin, and G. K. Rohde. Sliced-Wasserstein autoencoder: An embarrassingly simple generative model. arXiv preprint arXiv:1804.01947, 2018.</div>Gsahuhttp://wiki.math.uwaterloo.ca/statwiki/index.php?title=Wasserstein_Auto-encoders&diff=41936Wasserstein Auto-encoders2018-11-29T23:16:16Z<p>Gsahu: /* Application to Generative Models: Wasserstein auto-encoders */</p>
<hr />
<div>The first version of this work was published in 2017; this version (the third revision) was presented at ICLR 2018. Source code for the first version is available [https://github.com/tolstikhin/wae here].<br />
<br />
=Introduction=<br />
Early successes in the field of representation learning were based on supervised approaches, which used large labeled datasets to achieve impressive results. Popular unsupervised generative modeling methods, on the other hand, mainly consisted of probabilistic approaches focusing on low-dimensional data. In recent years, models have been proposed that try to combine these two approaches. One popular such method is the variational auto-encoder (VAE). VAEs are theoretically elegant but have the major drawback of generating blurry samples when used to model natural images. In comparison, generative adversarial networks (GANs) produce much sharper samples but have their own problems, including the lack of an encoder, harder training, and "mode collapse". Mode collapse refers to the inability of the model to capture all the variability in the true data distribution. There has been much recent activity around designing and evaluating GAN architectures and around combining VAEs and GANs, but a model that combines the best of both is yet to be found.<br />
<br />
This paper builds upon the theoretical work of Bousquet et al. (2017) [4] and tackles generative modeling using optimal transport (OT). The OT cost is a measure of distance between probability distributions.<br />
<br />
More specifically, the OT problem is defined as follows.<br />
<br />
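Given a cost function <math>c : \mathcal{X} \times \mathcal{Y} \to \mathcal{R}</math> and probability measures <math>\mu</math> on <math>\mathcal{X}</math> and <math>\nu</math> on <math>\mathcal{Y}</math>, one seeks a transport plan attaining <math>C(\mu, \nu) := \underset{\pi \in \Pi(\mu, \nu)}{inf} \int_{\mathcal{X} \times \mathcal{Y}} c(x, y)\, d\pi(x, y)</math>, where <math>\Pi(\mu, \nu)</math> is the set of joint distributions on <math>\mathcal{X} \times \mathcal{Y}</math> with marginals <math>\mu</math> and <math>\nu</math>.<br />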
<br />
The measures <math>\pi \in \Pi(\mu, \nu)</math> are called transport plans (or transference plans), and those achieving the infimum are called optimal transport plans. Classically, the problem is interpreted as minimizing the total cost <math>C(\mu, \nu)</math> of transporting the mass distribution <math>\mu</math> onto the mass distribution <math>\nu</math>, where moving one unit of mass from a point <math>x \in \mathcal{X}</math> to a point <math>y \in \mathcal{Y}</math> costs <math>c(x, y)</math>.<br />
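To make the definition concrete, here is a small Python sketch (not from the paper; function names are hypothetical) that computes <math>C(\mu, \nu)</math> exactly for two uniform empirical measures with the same number of atoms. It relies on the standard fact (a consequence of the Birkhoff-von Neumann theorem) that for equal uniform weights some optimal plan is a permutation matching, so brute-forcing all <math>n!</math> matchings is exact; this is only practical for tiny <math>n</math>.

```python
from itertools import permutations

# Illustrative sketch; function names are hypothetical, not from the paper.

def discrete_ot_cost(xs, ys, cost):
    """Exact OT cost C(mu, nu) between uniform empirical measures
    mu = (1/n) sum_i delta_{xs[i]} and nu = (1/n) sum_j delta_{ys[j]}.

    With equal uniform weights some optimal plan is a permutation matching, so
    trying all n! matchings is exact, but only feasible for very small n."""
    n = len(xs)
    if n == 0 or n != len(ys):
        raise ValueError("both measures need the same, non-zero number of atoms")
    best_cost, best_perm = float("inf"), None
    for perm in permutations(range(n)):
        # The plan for this permutation moves mass 1/n from xs[i] to ys[perm[i]].
        total = sum(cost(xs[i], ys[j]) for i, j in enumerate(perm)) / n
        if total < best_cost:
            best_cost, best_perm = total, perm
    return best_cost, best_perm

def squared_cost(x, y):
    return sum((a - b) ** 2 for a, b in zip(x, y))

mu_atoms = [(0.0, 0.0), (1.0, 0.0)]
nu_atoms = [(2.0, 0.0), (1.0, 1.0)]
ot_cost, ot_plan = discrete_ot_cost(mu_atoms, nu_atoms, squared_cost)
print(ot_cost, ot_plan)
```

For these two toy measures the straight matching costs 2.5, while the crossed matching (1, 0) costs 1.5 and is therefore optimal.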
<br />
A beneficial feature of the OT cost is that it induces a much weaker topology than other costs, including the f-divergences associated with the original GAN algorithms.<br />
This feature is crucial in applications where the data is supported on low-dimensional manifolds in the input space: stronger notions of distance, such as f-divergences, often max out in this setting and provide no useful gradients for training. In comparison, the OT cost has been claimed to behave much more nicely [5, 8]. Nevertheless, implementations based on it, which are similar to GANs, still require adding a constraint or a regularization term to the objective function.<br />
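A toy example in the spirit of the WGAN paper [5] illustrates this point. Take two point masses on the real line, one at 0 and one at <math>\theta</math>; their supports are disjoint for any <math>\theta \neq 0</math>, mimicking distributions living on non-overlapping low-dimensional manifolds. The Python sketch below (hypothetical helper names; discrete distributions stored as {point: probability} dictionaries) compares the Jensen-Shannon divergence, an f-divergence, with the OT cost for <math>c(x, y) = |x - y|</math>, i.e. the 1-Wasserstein distance.

```python
import math

# Illustrative sketch; function names are hypothetical, not from the paper.

def js_divergence(p, q):
    """Jensen-Shannon divergence (natural log) between discrete distributions {point: prob}."""
    support = set(p) | set(q)
    m = {x: 0.5 * (p.get(x, 0.0) + q.get(x, 0.0)) for x in support}

    def kl(a, b):
        # KL(a || b); points with a(x) = 0 contribute nothing (0 log 0 = 0).
        return sum(a[x] * math.log(a[x] / b[x]) for x in a if a[x] > 0)

    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

def wasserstein_1_on_line(p, q):
    """W_1 between discrete distributions on the real line: integral of |F_p - F_q|."""
    xs = sorted(set(p) | set(q))
    total, cdf_p, cdf_q = 0.0, 0.0, 0.0
    for left, right in zip(xs, xs[1:]):
        cdf_p += p.get(left, 0.0)
        cdf_q += q.get(left, 0.0)
        total += abs(cdf_p - cdf_q) * (right - left)
    return total

# Point masses at 0 and theta: disjoint supports whenever theta != 0.
for theta in [0.0, 0.5, 1.0, 2.0]:
    p, q = {0.0: 1.0}, {theta: 1.0}
    print(f"theta={theta}: JS={js_divergence(p, q):.4f}, W1={wasserstein_1_on_line(p, q)}")
```

The Jensen-Shannon divergence stays at <math>\log 2 \approx 0.6931</math> for every <math>\theta \neq 0</math>, so it carries no information about how far apart the supports are, while the 1-Wasserstein distance equals <math>|\theta|</math> and shrinks smoothly as the distributions move closer.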
<br />
==Original Contributions==<br />
Let <math>P_X</math> be the true but unknown data distribution, and let <math>P_G</math> be the latent variable model specified by the prior distribution <math>P_Z</math> of latent codes <math>Z \in \mathcal{Z}</math> and the generative model <math>P_G(X|Z)</math> of the data points <math>X \in \mathcal{X}</math> given <math>Z</math>. The goal of this paper is to minimize the OT cost <math>W_c(P_X, P_G)</math>.<br />
<br />
The main contributions are given below:<br />
<br />
* A new class of auto-encoders called Wasserstein Auto-Encoders (WAE). WAEs minimize the optimal transport cost <math>W_c(P_X, P_G)</math> for any cost function <math>c</math>. As with VAEs, the WAE objective is made up of two terms: the c-reconstruction cost and a regularizer <math>\mathcal{D}_Z(P_Z, Q_Z)</math> that penalizes the discrepancy between two distributions in <math>\mathcal{Z}</math>, <math>P_Z</math> and <math>Q_Z</math>, where <math>Q_Z</math> is the distribution of encoded points, i.e. <math>Q_Z := \mathbb{E}_{P_X}[Q(Z|X)]</math>. Note that when <math>c</math> is the squared cost and the regularizer is the GAN objective, WAE is equivalent to the adversarial auto-encoders described in [2].<br />
<br />
* Experimental results of using WAE on MNIST and CelebA datasets with squared cost <math>c(x, y) = ||x - y||_2^2</math>. The results of these experiments show that WAEs have the good features of VAEs such as stable training, encoder-decoder architecture, and a nice latent manifold structure while simultaneously improving the quality of the generated samples.<br />
<br />
* Two different regularizers: one based on GANs and adversarial training in the latent space <math>\mathcal{Z}</math>, and one based on the Maximum Mean Discrepancy (MMD), which is known to perform well when matching high-dimensional standard normal distributions. The second regularizer turns the problem into a fully adversary-free min-min optimization problem and removes the need to tune a GAN.<br />
<br />
* The final contribution is the mathematical analysis used to derive the WAE objective. In particular, it shows that in the case of generative models, the primal form of <math>W_c(P_X, P_G)</math> is equivalent to a problem involving the optimization of a probabilistic encoder <math>Q(Z|X)</math>.<br />
<br />
The paper provides an ostensibly simple recipe for a generative, non-blurry alternative to the VAE, and an elegant and logical way to cast the Wasserstein distance into the VAE/GAN setup.<br />
The paper also gives three instructive comparisons with related models and unifies them thematically: Adversarial Autoencoders (AAE), Adversarial Variational Bayes (AVB), and the original Variational Autoencoders (VAE). These connections arise in the case of random decoders: the paper introduces the idea with deterministic decoders and then extends it to random ones, playing on the VAE regularizer, which AAE and AVB replace with a GAN.<br />
<br />
=Proposed Method=<br />
The method proposed by the authors uses a novel auto-encoder architecture to minimize the optimal transport cost <math>W_c(P_X, P_G)</math>. In the optimization problem that follows, the decoder tries to accurately reconstruct the data points as measured by the cost function <math>c</math>. The encoder tries to achieve two conflicting goals at the same time: (1) match the distribution of the encoded data points <math>Q_Z := \mathbb{E}_{P_X}[Q(Z|X)]</math> to the prior distribution <math>P_Z</math>, as measured by the divergence <math>\mathcal{D}_Z(P_Z, Q_Z)</math>, and (2) ensure that the encoded latent vectors contain enough information for the data points to be reconstructed with high quality. The figure below illustrates this:<br />
<br />
[[File:ka2khan_figure_1.png|800px|thumb|center|Figure 1]]<br />
<br />
Figure 1: Both VAE and WAE have objectives composed of two terms: the reconstruction cost and a regularizer penalizing the divergence between <math>P_Z</math> and <math>Q_Z</math>. VAE forces <math>Q(Z|X = x)</math> to match <math>P_Z</math> for every training example <math>x</math> drawn from <math>P_X</math>. As shown in the figure above, every red ball, representing <math>Q(Z|X = x)</math> for one example, is forced to match <math>P_Z</math>, depicted as whitish triangles. As a result, the red balls intersect, which causes problems with reconstruction. WAE, on the other hand, forces the mixture <math>Q_Z := \int{Q(Z|X)\ dP_X}</math> to match <math>P_Z</math>, as shown in the figure above. This gives the latent codes of different examples a chance to stay far from each other, which leads to higher reconstruction quality.<br />
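The difference between the two penalties can be illustrated numerically. The hypothetical 1-D Python sketch below uses four narrow Gaussian encodings <math>Q(Z|X=x_i) = \mathcal{N}(\mu_i, \sigma^2)</math> with well-separated means and a standard normal prior <math>P_Z</math>. The VAE-style penalty averages the closed-form KL divergence of each encoding to <math>P_Z</math> and stays large (here exactly <math>\log 10 \approx 2.30</math> nats), while the equal-weight mixture <math>Q_Z</math> already has the prior's mean 0 and variance 1. Matching two moments does not mean <math>Q_Z</math> equals <math>P_Z</math>; the point is only that a penalty on <math>Q_Z</math>, unlike the per-example KL, does not by itself force each encoding to spread over the whole prior.

```python
import math

# Illustrative sketch with hypothetical encoder outputs; not the paper's code.

def kl_gaussian_to_standard_normal(mu, sigma):
    """Closed-form KL( N(mu, sigma^2) || N(0, 1) ) for a 1-D Gaussian, in nats."""
    return 0.5 * (sigma ** 2 + mu ** 2 - 1.0 - math.log(sigma ** 2))

# Hypothetical encoder outputs for four training points: narrow and well separated.
sigma = 0.1
raw_means = [-1.5, -0.5, 0.5, 1.5]
# Rescale the means so that the mixture variance mean(mu^2) + sigma^2 equals 1.
scale = math.sqrt((1.0 - sigma ** 2) / (sum(m * m for m in raw_means) / len(raw_means)))
means = [scale * m for m in raw_means]

# VAE-style penalty: average KL of each individual encoding to the prior P_Z = N(0, 1).
vae_penalty = sum(kl_gaussian_to_standard_normal(m, sigma) for m in means) / len(means)

# First two moments of the aggregate Q_Z = (1/n) sum_i N(mu_i, sigma^2).
mix_mean = sum(means) / len(means)
mix_var = sum(m * m for m in means) / len(means) + sigma ** 2 - mix_mean ** 2

print(f"VAE penalty: {vae_penalty:.4f} nats, Q_Z mean: {mix_mean:.4f}, Q_Z variance: {mix_var:.4f}")
```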
<br />
==Preliminaries and Notations==<br />
The authors use calligraphic letters to denote sets (for example, <math>\mathcal{X}</math>), capital letters for random variables (for example, <math>X</math>), and lowercase letters for their values (for example, <math>x</math>). Probability distributions are also denoted with capital letters (for example, <math>P(X)</math>) and the corresponding densities with lowercase letters (for example, <math>p(x)</math>).<br />
<br />
The authors use several measures of difference between probability distributions. These include the f-divergences <math>D_f(p_X||p_G) := \int{f(\frac{p_X(x)}{p_G(x)})p_G(x)}dx</math>, where <math>f:(0, \infty) \to \mathcal{R}</math> is any convex function satisfying <math>f(1) = 0</math>, in particular the Kullback-Leibler (<math>D_{KL}</math>) and Jensen-Shannon (<math>D_{JS}</math>) divergences.<br />
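The f-divergence formula can be checked on a discrete example, where the integral becomes a sum over a finite set of points. In the Python sketch below (illustrative, hypothetical names), choosing the generator <math>f(t) = t \log t</math> recovers <math>D_{KL}(p_X || p_G)</math>, while <math>f(t) = -\log t</math> gives the reverse KL divergence <math>D_{KL}(p_G || p_X)</math>.

```python
import math

# Illustrative sketch; function names are hypothetical, not from the paper.

def f_divergence(p, q, f):
    """Discrete analogue of D_f(p || q) = sum_x q(x) * f(p(x) / q(x)).

    p and q are probability vectors over the same points; assumes q(x) > 0
    wherever p(x) > 0, so that the ratio is always defined."""
    return sum(qx * f(px / qx) for px, qx in zip(p, q) if qx > 0)

def f_kl(t):
    # f(t) = t log t: convex with f(1) = 0; yields D_KL(p || q). Uses 0 log 0 = 0.
    return t * math.log(t) if t > 0 else 0.0

def f_reverse_kl(t):
    # f(t) = -log t: convex with f(1) = 0; yields D_KL(q || p) (needs p(x) > 0 where q(x) > 0).
    return -math.log(t)

p_x = [0.5, 0.3, 0.2]
p_g = [0.4, 0.4, 0.2]
print(f_divergence(p_x, p_g, f_kl), f_divergence(p_x, p_g, f_reverse_kl))
```

Because <math>f(1) = 0</math>, the divergence of any distribution from itself is zero, whichever generator is used.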
<br />
==Optimal Transport and its Dual Formations==<br />
<br />
A rich class of distance measures between probability distributions is motivated by the optimal transport problem. One formulation of this problem is Kantorovich's formulation:<br />
<br />
<center><math><br />
W_c(P_X, P_G) := \underset{\Gamma \in \mathcal{P}(X \sim P_X ,Y \sim P_G)}{inf} \mathbb{E}_{(X,Y) \sim \Gamma}[c(X,Y)],<br />
\text{where} \ c(x, y): \mathcal{X} \times \mathcal{X} \to \mathcal{R}_{+}<br />
</math></center><br />
<br />
is any measurable cost function and <math>\mathcal{P}(X \sim P_X, Y \sim P_G)</math> is the set of all joint distributions of <math>(X, Y)</math> with marginals <math>P_X</math> and <math>P_G</math>, respectively.<br />
<br />
A particularly interesting case is when <math>(\mathcal{X}, d)</math> is a metric space and <math>c(x, y) = d^p(x, y)</math> for <math>p \geq 1</math>. In this case <math>W_p</math>, the <math>p</math>-th root of <math>W_c</math>, is called the p-Wasserstein distance.<br />
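For distributions on the real line, <math>W_p</math> is easy to compute from samples, which makes the definition tangible. The Python sketch below (hypothetical helper name) assumes two equally weighted samples of the same size and uses the standard fact that, in one dimension and for <math>c(x, y) = |x - y|^p</math> with <math>p \geq 1</math>, matching the sorted samples gives an optimal transport plan.

```python
# Illustrative sketch; the function name is hypothetical, not from the paper.

def wasserstein_p_1d(xs, ys, p=2.0):
    """W_p between two empirical distributions on the real line.

    Both samples must have the same size n and equal weights 1/n; the cost is
    c(x, y) = |x - y|^p with p >= 1. In 1-D the monotone (sorted) matching is an
    optimal coupling for such costs, so no optimization is needed."""
    if len(xs) != len(ys) or not xs:
        raise ValueError("need two non-empty samples of the same size")
    if p < 1:
        raise ValueError("p must be >= 1")
    w_c = sum(abs(a - b) ** p for a, b in zip(sorted(xs), sorted(ys))) / len(xs)
    return w_c ** (1.0 / p)  # W_p is the p-th root of W_c

print(wasserstein_p_1d([0.0, 1.0, 2.0], [3.0, 4.0, 5.0], p=2))  # shifting a sample by 3 gives W_2 = 3.0
```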
<br />
When <math>c(x, y) = d(x, y)</math>, the following Kantorovich-Rubinstein duality holds:<br />
<br />
<math>W_1(P_X, P_G) = \underset{f \in \mathcal{F}_L}{sup} \mathbb{E}_{X \sim P_X}[f(X)] - \mathbb{E}_{Y \sim P_G}[f(Y)]</math><br />
where <math>\mathcal{F}_L</math> is the class of all bounded 1-Lipschitz functions on <math>(\mathcal{X}, d)</math>.<br />
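In one dimension the duality can be checked directly. For discrete distributions on the real line, the Python sketch below computes the primal value <math>W_1 = \int |F_P(t) - F_Q(t)|\,dt</math> and builds a potential <math>f</math> whose slope between consecutive support points is the sign of <math>F_Q - F_P</math>. This <math>f</math> is 1-Lipschitz, and the dual objective <math>\mathbb{E}_P[f] - \mathbb{E}_Q[f]</math> evaluated at it reproduces the primal value. The construction is a standard 1-D argument rather than something from the paper, and the function name is hypothetical.

```python
# Illustrative sketch; the function name is hypothetical, not from the paper.

def w1_primal_and_dual_1d(p, q):
    """For discrete distributions p, q on the real line ({point: probability}), return
    (primal W_1, dual objective at a constructed 1-Lipschitz potential, the potential)."""
    xs = sorted(set(p) | set(q))
    f = {xs[0]: 0.0}  # potential values on the support; only differences matter
    primal, cdf_p, cdf_q = 0.0, 0.0, 0.0
    for left, right in zip(xs, xs[1:]):
        cdf_p += p.get(left, 0.0)
        cdf_q += q.get(left, 0.0)
        gap = right - left
        primal += abs(cdf_p - cdf_q) * gap        # area between the two CDFs
        slope = (cdf_q > cdf_p) - (cdf_q < cdf_p)  # sign(F_q - F_p) in {-1, 0, 1}
        f[right] = f[left] + slope * gap          # |slope| <= 1, so f is 1-Lipschitz
    dual = sum(w * f[x] for x, w in p.items()) - sum(w * f[x] for x, w in q.items())
    return primal, dual, f

p = {0.0: 0.5, 1.0: 0.5}
q = {0.5: 0.5, 3.0: 0.5}
primal, dual, potential = w1_primal_and_dual_1d(p, q)
print(primal, dual)
```

For this pair the primal and dual values coincide at 1.25, the same value obtained by matching the sorted samples 0 to 0.5 and 1 to 3.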
<br />
==Application to Generative Models: Wasserstein auto-encoders==<br />
The intuition behind modern generative models such as VAEs and GANs is that they try to minimize specific distance measures between the data distribution <math>P_X</math> and the model <math>P_G</math>. Unfortunately, most standard discrepancy measures are very hard or even impossible to compute with current tools, especially when <math>P_X</math> is unknown and <math>P_G</math> is parametrized by deep neural networks. However, certain tricks can be employed to get around this difficulty.<br />
<br />
For KL-divergence <math>D_{KL}(P_X, P_G)</math> minimization, or equivalently marginal log-likelihood <math>\mathbb{E}_{P_X}[\log p_G(X)]</math> maximization, one can use the well-known variational lower bound, which provides a theoretically grounded framework and has been used very successfully by VAEs. In the general case of minimizing an f-divergence <math>D_f(P_X, P_G)</math>, its dual formulation can be combined with adversarial training, as in f-GANs. Finally, the OT cost <math>W_c(P_X, P_G)</math> can be expressed as an adversarial objective through the Kantorovich-Rubinstein duality (in the case <math>c = d</math>, i.e. for <math>W_1</math>); Wasserstein GAN implements this idea.<br />
<br />
In this paper, the authors focus on latent variable models <math>P_G</math> defined by a two-step procedure. First, a code <math>Z</math> is sampled from a fixed distribution <math>P_Z</math> on a latent space <math>\mathcal{Z}</math>. Second, <math>Z</math> is mapped to the image <math>X \in \mathcal{X} = \mathcal{R}^d</math> with a (possibly random) transformation. This gives a density of the form<br />
<br />
<center><math><br />
p_G(x) := \int\limits_{\mathcal{Z}}{p_G(x|z)p_z(z)}dz,\ \forall x \in \mathcal{X}, <br />
</math></center><br />
<br />
provided all the probabilities involved are properly defined. To keep things simple, the authors focus on non-random decoders, i.e., generative models <math>P_G(X|Z)</math> that deterministically map <math>Z</math> to <math>X = G(Z)</math> using a fixed map <math>G: \mathcal{Z} \to \mathcal{X}</math>. Similar results hold for random decoders, as the authors show in Appendix B.1.<br />
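The two-step sampling procedure also suggests a simple Monte Carlo estimate of <math>p_G(x)</math>: sample codes from <math>P_Z</math> and average the decoder density. The Python sketch below assumes a hypothetical 1-D model with <math>P_Z = \mathcal{N}(0, 1)</math> and a random Gaussian decoder <math>p_G(x|z) = \mathcal{N}(x; z, s^2)</math>, chosen because its marginal <math>\mathcal{N}(0, 1 + s^2)</math> is known in closed form and serves as a check. With a deterministic decoder <math>G</math>, <math>p_G(x|z)</math> is a point mass and <math>P_G</math> may have no density at all, which is why the caveat about properly defined probabilities matters.

```python
import math
import random

# Illustrative sketch with a hypothetical toy model; function names are not from the paper.

def normal_pdf(x, mean=0.0, std=1.0):
    """Density of N(mean, std^2) at x."""
    return math.exp(-0.5 * ((x - mean) / std) ** 2) / (std * math.sqrt(2.0 * math.pi))

def mc_marginal_density(x, decoder_std, n_samples=100_000, seed=0):
    """Monte Carlo estimate of p_G(x) = E_{z ~ P_Z}[p_G(x | z)] for the toy model
    P_Z = N(0, 1), p_G(x | z) = N(x; z, decoder_std^2)."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_samples):
        z = rng.gauss(0.0, 1.0)                          # step 1: draw a code from the prior
        total += normal_pdf(x, mean=z, std=decoder_std)  # step 2: decoder density at x
    return total / n_samples

estimate = mc_marginal_density(0.5, decoder_std=0.5)
exact = normal_pdf(0.5, std=math.sqrt(1.0 + 0.5 ** 2))  # closed form for this toy model
print(estimate, exact)
```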
<br />
Under this model, the authors find that the OT cost takes a much simpler form because the transport plan factors through the map <math>G</math>: instead of finding a coupling <math>\Gamma</math> between two random variables in <math>\mathcal{X}</math>, one distributed according to <math>P_X</math> and the other according to <math>P_G</math>, it is enough to find a conditional distribution <math>Q(Z|X)</math> whose <math>Z</math> marginal <math>Q_Z(Z) := \mathbb{E}_{X \sim P_X}[Q(Z|X)]</math> equals the prior distribution <math>P_Z</math>. This is formalized by the theorem below, which the authors proved in [4].<br />
<br />
'''Theorem 1.''' For <math>P_G</math> defined as above with deterministic <math>P_G(X|Z)</math> and any function <math>G:\mathcal{Z} \to \mathcal{X}</math>,<br />
<br />
<math><br />
\underset{\Gamma \in \mathcal{P}(X \sim P_X ,Y \sim P_G)}{inf} \mathbb{E}_{(X,Y) \sim \Gamma}[c(X,Y)] = \underset{Q: Q_Z = P_Z}{inf} \mathbb{E}_{P_X} \mathbb{E}_{Q(Z|X)}[c(X, G(Z))]<br />
</math><br />
<br />
where <math>Q_Z</math> is the marginal distribution of <math>Z</math> when <math>X \sim P_X</math> and <math>Z \sim Q(Z|X)</math>.<br />
<br />
According to the authors, this result allows optimizing over random encoders <math>Q(Z|X)</math> instead of over all couplings of <math>X</math> and <math>Y</math>. Both problems are still constrained. To find a numerical solution, the authors relax the constraint on <math>Q_Z</math> by adding a penalty term to the objective, which gives the WAE objective:<br />
<br />
<math><br />
D_{WAE}(P_X, P_G) := \underset{Q(Z|X) \in \mathcal{Q}}{inf} \mathbb{E}_{P_X} \mathbb{E}_{Q(Z|X)}[c(X, G(Z))] + \lambda \cdot \mathcal{D}_Z(Q_Z, P_Z)<br />
</math><br />
<br />
where <math>\mathcal{Q}</math> is any nonparametric set of probabilistic encoders, <math>\mathcal{D}_Z</math> is an arbitrary measure of discrepancy between <math>Q_Z</math> and <math>P_Z</math>, and <math>\lambda > 0</math> is a hyperparameter. As with VAEs, the<br />
authors propose using deep neural networks to parameterize both the encoder <math>Q</math> and the decoder <math>G</math>. Note that, unlike VAEs, WAE allows non-random encoders that deterministically map their inputs to latent codes.<br />
<br />
The authors propose two different regularizers <math>\mathcal{D}_Z(Q_Z, P_Z)</math>:<br />
<br />
===GAN-based <math>\mathcal{D}_Z</math>===<br />
One option is to choose <math>\mathcal{D}_Z(Q_Z, P_Z) = D_{JS}(Q_Z, P_Z)</math> and estimate it with adversarial training. In particular, a discriminator (adversary) in the latent space <math>\mathcal{Z}</math> tries to separate "true" points sampled from <math>P_Z</math> from "fake" ones sampled from <math>Q_Z</math>. This leads to WAE-GAN, described in Algorithm 1 below. Even though WAE-GAN still relies on max-min optimization, one positive feature is that it moves the adversary from the input (pixel) space <math>\mathcal{X}</math> to the latent space <math>\mathcal{Z}</math>. Additionally, the true latent distribution <math>P_Z</math> may have a nice shape with a single mode (for a Gaussian prior), making matching much easier than matching an unknown, complex, and possibly multi-modal distribution, as is usually the case in GANs. This observation also motivates the second penalty.<br />
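The two objectives of WAE-GAN can be written down for a single mini-batch. Following Algorithm 1 (shown in the figure further below) as we read it, the adversary ascends <math>\frac{\lambda}{n}\sum_i [\log D(z_i) + \log(1 - D(\tilde{z}_i))]</math> with <math>z_i \sim P_Z</math> and <math>\tilde{z}_i \sim Q(Z|x_i)</math>, while the encoder and decoder descend <math>\frac{1}{n}\sum_i [c(x_i, G(\tilde{z}_i)) - \lambda \log D(\tilde{z}_i)]</math>. The Python sketch below only evaluates these quantities (in practice they would be optimized with an automatic-differentiation framework); the discriminator, codes, and decoder are passed in as plain callables and lists, and all names are hypothetical.

```python
import math

# Illustrative sketch; function and argument names are hypothetical, not from the paper.

def wae_gan_objectives(xs, z_prior, z_encoded, decode, discriminator, cost, lam):
    """Mini-batch objectives of the two WAE-GAN players.

    xs:            data points x_1..x_n
    z_prior:       codes z_1..z_n sampled from the prior P_Z ("true" points)
    z_encoded:     codes sampled from Q(Z | x_i) ("fake" points)
    decode:        the decoder G
    discriminator: maps a code to the probability, in (0, 1), that it came from P_Z
    Returns (objective the adversary ascends, loss the encoder/decoder descend)."""
    n = len(xs)
    adversary_objective = lam / n * sum(
        math.log(discriminator(z)) + math.log(1.0 - discriminator(z_t))
        for z, z_t in zip(z_prior, z_encoded)
    )
    autoencoder_loss = sum(
        cost(x, decode(z_t)) - lam * math.log(discriminator(z_t))
        for x, z_t in zip(xs, z_encoded)
    ) / n
    return adversary_objective, autoencoder_loss

def squared_cost(x, y):
    return (x - y) ** 2

# An undecided adversary that outputs 0.5 everywhere, with an identity decoder on scalars.
adv, ae = wae_gan_objectives(
    xs=[1.0, 2.0], z_prior=[0.3, -0.7], z_encoded=[1.0, 1.0],
    decode=lambda z: z, discriminator=lambda z: 0.5, cost=squared_cost, lam=1.0)
print(adv, ae)
```

With a constant discriminator the adversary's objective is <math>-2\lambda \log 2</math>, and the auto-encoder loss is the mean reconstruction cost plus <math>\lambda \log 2</math>.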
<br />
===MMD-based <math>\mathcal{D}_Z</math>===<br />
For a positive-definite reproducing kernel <math>k: \mathcal{Z} \times \mathcal{Z} \to \mathcal{R}</math>, the maximum mean discrepancy (MMD) is defined as:<br />
<br />
<center><math><br />
MMD_k(P_Z, Q_Z) = \left \Vert \int \limits_{\mathcal{Z}} {k(z, \cdot)dP_Z(z)} - \int \limits_{\mathcal{Z}} {k(z, \cdot)dQ_Z(z)} \right \|_{\mathcal{H}_k}<br />
</math>,</center><br />
<br />
where <math>\mathcal{H}_k</math> is the RKHS (reproducing kernel Hilbert space) of real-valued functions mapping <math>\mathcal{Z}</math> to <math>\mathcal{R}</math>. If <math>k</math> is characteristic, then <math>MMD_k</math> defines a metric and can be used as a distance measure. The authors propose to use <math>\mathcal{D}_Z(P_Z, Q_Z) = MMD_k(P_Z, Q_Z)</math>. MMD has an unbiased U-statistic estimator that can be used with stochastic gradient descent (SGD) methods. This gives WAE-MMD, described in Algorithm 2 below. Since MMD is known to perform well when matching high-dimensional standard normal distributions, this penalty is expected to work well when the prior <math>P_Z</math> is Gaussian.<br />
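The U-statistic estimator can be sketched in a few lines. The Python code below (hypothetical names) estimates the squared MMD between a prior sample and a batch of encoded codes, which as far as we can tell is the form of the penalty used in Algorithm 2, and adds it to the mean reconstruction cost to form a WAE-MMD mini-batch loss. As an illustrative characteristic kernel it uses the inverse multiquadratics kernel <math>k(z, z') = C/(C + \|z - z'\|_2^2)</math>; the scale <math>C</math> is an assumption to be tuned. Note that the unbiased estimator can be slightly negative when both samples come from the same distribution.

```python
# Illustrative sketch; function names and the kernel scale are assumptions, not from the paper.

def imq_kernel(a, b, scale=1.0):
    """Inverse multiquadratics kernel C / (C + ||a - b||^2) on tuples of floats."""
    sq_dist = sum((ai - bi) ** 2 for ai, bi in zip(a, b))
    return scale / (scale + sq_dist)

def mmd_squared_unbiased(z_prior, z_encoded, kernel=imq_kernel):
    """Unbiased U-statistic estimate of MMD_k^2(P_Z, Q_Z) from two samples of size n >= 2."""
    n = len(z_prior)
    if n < 2 or len(z_encoded) != n:
        raise ValueError("need two samples of equal size n >= 2")
    within_prior = sum(kernel(z_prior[i], z_prior[j]) for i in range(n) for j in range(n) if i != j)
    within_enc = sum(kernel(z_encoded[i], z_encoded[j]) for i in range(n) for j in range(n) if i != j)
    across = sum(kernel(zp, ze) for zp in z_prior for ze in z_encoded)
    return (within_prior + within_enc) / (n * (n - 1)) - 2.0 * across / (n * n)

def wae_mmd_loss(xs, z_prior, z_encoded, decode, cost, lam, kernel=imq_kernel):
    """Mini-batch WAE-MMD loss: mean c(x_i, G(z~_i)) + lam * squared-MMD estimate."""
    reconstruction = sum(cost(x, decode(z)) for x, z in zip(xs, z_encoded)) / len(xs)
    return reconstruction + lam * mmd_squared_unbiased(z_prior, z_encoded, kernel)
```

Unlike the WAE-GAN objective, this loss involves no adversary: the encoder and decoder simply minimize it, which is the adversary-free min-min formulation mentioned earlier.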
<br />
[[File:ka2khan_figure_2.png|800px|thumb|center|Algorithms- WAE-GAN on left and WAE-MMD on right]]<br />
<br />
=Related Work=<br />
==Literature on auto-encoders==<br />
Classical unregularized auto-encoders minimize only the reconstruction cost. As a result, distinct data points are encoded into distinct zones scattered chaotically across the latent space <math>\mathcal{Z}</math>, which then contains large "holes" on which the decoder <math>P_G(X|Z)</math> has never been trained. In general, encoders trained this way do not provide very useful representations, and sampling from the latent space <math>\mathcal{Z}</math> becomes difficult [12].<br />
<br />
VAEs [1] minimize a variational upper bound on the KL-divergence <math>D_{KL}(P_X, P_G)</math> (up to an additive constant), which consists of the reconstruction cost and the regularizer <math>\mathbb{E}_{P_X}[D_{KL}(Q(Z|X), P_Z)]</math>. The regularizer penalizes the discrepancy between the encoding of each training image and the prior <math>P_Z</math>. However, this penalty does not ensure that the overall encoded distribution matches the prior, as WAE does. In addition, VAEs require non-degenerate (i.e. non-deterministic) Gaussian encoders and random decoders. A later paper [11] proposed a method that allows non-Gaussian encoders in VAEs. WAE, in contrast, minimizes <math>W_{c}(P_X, P_G)</math> and allows both probabilistic and deterministic encoder-decoder pairs.<br />
<br />
With appropriate choices, WAE generalizes AAE [2] in two ways: it can use any cost function in the input space <math>\mathcal{X}</math>, and any discrepancy measure <math>\mathcal{D}_Z</math> in the latent space <math>\mathcal{Z}</math>, not only the adversarial one.<br />
<br />
InfoVAE [14], a regularized auto-encoder, has an objective similar to that of [4], but is derived from different motivations and arguments.<br />
<br />
WAEs explicitly define the cost function <math>c(x,y)</math>, whereas VAEs rely on an implicit one through a negative log-likelihood term. In theory, this can induce any cost function, but in practice it may require estimating a normalizing constant that can differ across values of <math>z</math>.<br />
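A common example of such an implicit cost, assumed here for illustration, is a Gaussian decoder <math>p_G(x|z) = \mathcal{N}(x; G(z), \sigma^2 I_d)</math>. Its negative log-likelihood equals <math>\|x - G(z)\|_2^2 / (2\sigma^2) + \frac{d}{2}\log(2\pi\sigma^2)</math>, a rescaled squared cost plus a normalizing term. The Python sketch below (hypothetical values and names) checks this numerically; if the variance depended on <math>z</math>, the normalizing term would change with <math>z</math> as well.

```python
import math

# Illustrative sketch with hypothetical values; not the paper's code.

def gaussian_nll(x, mean, var):
    """-log N(x; mean, var * I), summed over coordinates (x, mean: lists of floats)."""
    return -sum(
        math.log(math.exp(-((xi - mi) ** 2) / (2.0 * var)) / math.sqrt(2.0 * math.pi * var))
        for xi, mi in zip(x, mean)
    )

def squared_cost(x, y):
    return sum((a - b) ** 2 for a, b in zip(x, y))

x = [0.2, -1.0, 0.5]
g_z = [0.0, -0.5, 0.5]  # hypothetical decoder output G(z)

for var in (0.3, 1.0):  # e.g. a decoder variance that depends on z
    normalizer = 0.5 * len(x) * math.log(2.0 * math.pi * var)
    # NLL = squared cost / (2 var) + normalizer: the constant moves when var changes.
    print(var, gaussian_nll(x, g_z, var), squared_cost(x, g_z) / (2.0 * var) + normalizer)
```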
<br />
==Literature on Optimal Transport (OT)==<br />
[15] provides methods for computing the OT cost on large-scale data using SGD and sampling. The WGAN [5] is a generative model that minimizes the 1-Wasserstein distance <math>W_1(P_X, P_G)</math>. The WGAN algorithm does not provide an encoder and cannot easily be applied to an arbitrary cost <math>W_c</math>. The model in [5] uses the dual form of the OT problem; in contrast, the model proposed in this paper uses the primal form, which allows any cost function <math>c</math> and naturally comes with an encoder.<br />
<br />
Computing <math>W_c(P_X, P_G)</math> or <math>W_1(P_X, P_G)</math> requires handling non-trivial constraints, and various methods have been proposed in the literature ([5], [2], [8], [16], [15], [17], [18]) to get around this difficulty.<br />
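One of the cited approaches, the entropic regularization of Cuturi [16], replaces the exact OT linear program by a smoothed problem that can be solved with simple alternating scaling (Sinkhorn) iterations. The pure-Python sketch below is illustrative only: it computes an approximate plan for small discrete distributions, the regularization strength `epsilon` and the iteration count are arbitrary assumptions, and a practical implementation would work in the log domain to avoid underflow when cost/epsilon is large. The reported transport cost is that of the regularized plan, which approaches the exact OT cost as `epsilon` goes to 0.

```python
import math

# Illustrative sketch; the function name and default parameters are assumptions.

def sinkhorn_plan(a, b, cost, epsilon=0.05, n_iters=500):
    """Entropy-regularized OT between discrete distributions a (length m) and b (length n).

    cost is an m x n matrix (list of lists). Returns the approximate transport plan and
    its (unregularized) transport cost sum_ij P_ij * C_ij."""
    m, n = len(a), len(b)
    kernel = [[math.exp(-cost[i][j] / epsilon) for j in range(n)] for i in range(m)]
    u, v = [1.0] * m, [1.0] * n
    for _ in range(n_iters):
        # Alternately rescale rows and columns so the plan's marginals approach a and b.
        u = [a[i] / sum(kernel[i][j] * v[j] for j in range(n)) for i in range(m)]
        v = [b[j] / sum(kernel[i][j] * u[i] for i in range(m)) for j in range(n)]
    plan = [[u[i] * kernel[i][j] * v[j] for j in range(n)] for i in range(m)]
    transport_cost = sum(plan[i][j] * cost[i][j] for i in range(m) for j in range(n))
    return plan, transport_cost

# Two points each; staying in place is free, swapping costs 1.
plan, approx_cost = sinkhorn_plan([0.5, 0.5], [0.5, 0.5], [[0.0, 1.0], [1.0, 0.0]])
print(plan, approx_cost)
```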
<br />
==Literature on GANs==<br />
Many of the GAN variants proposed in the literature, such as WGAN and f-GAN, come without an encoder. These models fall short in applications where data points must be mapped back to the latent space, for example to make use of the learned manifold.<br />
<br />
Numerous models proposed in the literature try to combine the adversarial training of GANs with auto-encoder architectures; examples include [19], [20], [21], and [22]. Reproducing kernels have also been used in the context of GANs ([23], [24]).<br />
<br />
=Experiments=<br />
The authors evaluated the proposed WAE model empirically.<br />
<br />
'''Experimental setup'''<br />
<br />
For the experimental setup, the authors used a Gaussian prior <math>P_Z</math> and the squared cost function <math>c(x,y) = \|x - y\|_2^2</math> for data points.<br />
Deterministic encoder-decoder pairs were used. The authors conducted experiments on two real-world datasets: (1) MNIST [27], made up of 70k images, and (2) CelebA [28], consisting of roughly 203k images. For test reconstructions and interpolations, a pair of held-out images <math>(x,y)</math> from the test set is auto-encoded (separately) to produce <math>(z_x, z_y)</math> in the latent space; the codes are then linearly interpolated in <math>\mathcal{Z}</math> and decoded.<br />
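The interpolation step between <math>z_x</math> and <math>z_y</math> can be sketched as follows, assuming straight-line (linear) interpolation in <math>\mathcal{Z}</math>: evenly spaced points on the segment joining the two codes are computed and would then be decoded. The Python helper below (hypothetical name; the trained decoder itself is not shown) only produces the latent points.

```python
# Illustrative sketch; the function name is hypothetical, and the decoder G is not shown.

def interpolate_latent(z_x, z_y, num_steps):
    """Evenly spaced points on the segment from z_x to z_y (both endpoints included)."""
    if num_steps < 2:
        raise ValueError("num_steps must be at least 2")
    path = []
    for k in range(num_steps):
        t = k / (num_steps - 1)
        path.append([(1.0 - t) * a + t * b for a, b in zip(z_x, z_y)])
    return path

# Hypothetical 2-D codes of two held-out test images.
z_x, z_y = [0.0, 0.0], [1.0, 2.0]
path = interpolate_latent(z_x, z_y, num_steps=3)
# Each point would then be passed through the trained decoder G to produce an image.
print(path)
```

Since a good latent geometry is one of the evaluation criteria, decoded images along such a path should change smoothly from one test image to the other.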
<br />
The experiments assessed whether the WAE model can simultaneously achieve:<br />
<br />
<ol><br />
<li>accurate reconstruction of the data points</li><br />
<li>reasonable geometry of the latent manifold</li><br />
<li>generation of high quality random samples</li><br />
</ol><br />
<br />
For the model to generalize well, (1) and (2) should hold on both the training and test sets.<br />
<br />
The proposed model achieves reasonably good results, as highlighted in the figures below:<br />
<br />
[[File:ka2khan_figure_3.png|800px|thumb|center|Using CelebA dataset]]<br />
<br />
[[File:ka2khan_figure_4.png|800px|thumb|center|Using CelebA dataset, FID (Fréchet Inception Distance [32]): smaller is better, sharpness: larger is better]]<br />
<br />
=Conclusion=<br />
The authors proposed Wasserstein Auto-Encoders (WAEs), a new class of algorithms for building generative models based on the optimal transport cost. They related the proposed model to existing probabilistic modeling techniques and evaluated it empirically on two real-world datasets, MNIST and CelebA. Comparing their results with those of VAEs on the same datasets, they showed that WAEs generate samples of higher quality while retaining stable training and good reconstruction quality of the data points.<br />
<br />
As future work, the authors propose three directions. The first is to further explore criteria for matching the encoding distribution <math>Q_Z</math> to the prior distribution <math>P_Z</math>. The second is to evaluate whether the cost function <math>c</math> in the input space <math>\mathcal{X}</math> can be trained adversarially. The third is to analyze the dual formulations of WAE-GAN and WAE-MMD theoretically.<br />
<br />
=Future Work=<br />
Building on this paper, Kolouri et al. [34] introduced another generative model based on optimal transport. Optimal transport measures the distance between two probability distributions by the minimal cost of transporting one of them onto the other, hence the name. Their model, the Sliced-Wasserstein Auto-Encoder (SWAE), is easy to implement and provides the capabilities of Wasserstein Auto-Encoders. It replaces the latent-space penalty with the sliced-Wasserstein distance between <math>Q_Z</math> and <math>P_Z</math>, which only requires one-dimensional projections of the samples.<br />
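The sliced-Wasserstein idea can be made concrete with a minimal Monte Carlo sketch of the squared sliced-Wasserstein-2 distance between two equal-size samples of latent codes. This is our own illustration of the general technique (random unit directions, then one-dimensional OT by sorting), not the SWAE authors' code. The function name and its defaults are hypothetical.

```python
import numpy as np

def sliced_wasserstein_sq(z_a, z_b, n_projections=50, rng=None):
    """Monte Carlo estimate of the squared sliced-Wasserstein-2 distance.

    z_a, z_b: (n, d) samples with the same n (e.g. encoded codes and prior draws).
    Each random unit direction reduces the problem to 1-D, where optimal
    transport between equal-size samples simply matches sorted values.
    """
    rng = np.random.default_rng(rng)
    d = z_a.shape[1]
    theta = rng.normal(size=(n_projections, d))
    theta /= np.linalg.norm(theta, axis=1, keepdims=True)   # random unit directions
    proj_a = np.sort(z_a @ theta.T, axis=0)                  # (n, n_projections)
    proj_b = np.sort(z_b @ theta.T, axis=0)
    return float(np.mean((proj_a - proj_b) ** 2))

rng = np.random.default_rng(0)
prior = rng.standard_normal((500, 2))                # samples from P_Z = N(0, I)
codes = 0.5 * rng.standard_normal((500, 2)) + 1.0    # hypothetical encoded codes
print(sliced_wasserstein_sq(codes, prior, rng=0))
```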
<br />
As raised in the OpenReview discussion ([https://openreview.net/forum?id=HkL7n1-0b]), the results on the MNIST and CelebA datasets look convincing. However, the paper could include additional evaluation comparing the adversarial loss with the more straightforward MMD penalty and discussing their pros and cons. Evaluating and comparing closely related auto-encoder solutions is challenging, so the authors could also design demonstrative experiments showing cases where the Wasserstein distance helps, as well as its potential limitations.<br />
<br />
<br />
<br />
=Critique=<br />
<br />
Although the paper presents empirical tests that illustrate its method appropriately, it would benefit from clearer notation and more details on the architectures used in the experiments. It could also benefit from comparing its results with those of similar works. As pointed out by a reviewer, the closest work to this paper is the Adversarial Variational Bayes framework of Mescheder et al. [11], which also attempts to unify VAEs and GANs. The authors describe the conceptual differences from that approach and their advantages over it, but it would be beneficial to include an actual comparison in the results section.<br />
Moreover, the performance of the algorithm is not a significant improvement over previous VAE algorithms, and empirical tests on a wider variety of datasets would be needed to characterize it better. However, the methodology is flexible and unifies several other types of algorithms, which is a major benefit.<br />
<br />
=References=<br />
[1] D. P. Kingma and M. Welling. Auto-encoding variational Bayes. In ICLR, 2014.<br />
<br />
[2] A. Makhzani, J. Shlens, N. Jaitly, and I. Goodfellow. Adversarial autoencoders. In ICLR, 2016.<br />
<br />
[3] Ian Goodfellow, Jean Pouget-Abadie, Mehdi Mirza, Bing Xu, David Warde-Farley, Sherjil Ozair, Aaron Courville, and Yoshua Bengio. Generative adversarial nets. In NIPS, pages 2672–2680, 2014.<br />
<br />
[4] O. Bousquet, S. Gelly, I. Tolstikhin, C. J. Simon-Gabriel, and B. Schölkopf. From optimal transport to generative modeling: the VEGAN cookbook, 2017.<br />
<br />
[5] M. Arjovsky, S. Chintala, and L. Bottou. Wasserstein GAN, 2017.<br />
<br />
[6] C. Villani. Topics in Optimal Transportation. AMS Graduate Studies in Mathematics, 2003.<br />
<br />
[7] Sebastian Nowozin, Botond Cseke, and Ryota Tomioka. f-GAN: Training generative neural samplers using variational divergence minimization. In NIPS, 2016.<br />
<br />
[8] I. Gulrajani, F. Ahmed, M. Arjovsky, V. Dumoulin, and A. Courville. Improved training of Wasserstein GANs, 2017.<br />
<br />
[9] A. Gretton, K. M. Borgwardt, M. J. Rasch, B. Schölkopf, and A. J. Smola. A kernel two-sample test. Journal of Machine Learning Research, 13:723–773, 2012.<br />
<br />
[10] F. Liese and K.-J. Miescke. Statistical Decision Theory. Springer, 2008.<br />
<br />
[11] L. Mescheder, S. Nowozin, and A. Geiger. Adversarial variational Bayes: Unifying variational autoencoders and generative adversarial networks, 2017.<br />
<br />
[12] Y. Bengio, A. Courville, and P. Vincent. Representation learning: A review and new perspectives. Pattern Analysis and Machine Intelligence, 35, 2013.<br />
<br />
[13] M. D. Hoffman and M. Johnson. ELBO surgery: yet another way to carve up the variational evidence lower bound. In NIPS Workshop on Advances in Approximate Bayesian Inference, 2016.<br />
<br />
[14] S. Zhao, J. Song, and S. Ermon. InfoVAE: Information maximizing variational autoencoders, 2017.<br />
<br />
[15] A. Genevay, M. Cuturi, G. Peyré, and F. R. Bach. Stochastic optimization for large-scale optimal transport. In Advances in Neural Information Processing Systems, pages 3432–3440, 2016. <br />
<br />
[16] M. Cuturi. Sinkhorn distances: Lightspeed computation of optimal transport. In Advances in Neural Information Processing Systems, pages 2292–2300, 2013.<br />
<br />
[17] Lenaic Chizat, Gabriel Peyré, Bernhard Schmitzer, and François-Xavier Vialard. Unbalanced optimal transport: geometry and Kantorovich formulation. arXiv preprint arXiv:1508.05216, 2015.<br />
<br />
[18] Matthias Liero, Alexander Mielke, and Giuseppe Savaré. Optimal entropy-transport problems and a new Hellinger-Kantorovich distance between positive measures. arXiv preprint arXiv:1508.07941, 2015.<br />
<br />
[19] J. Zhao, M. Mathieu, and Y. LeCun. Energy-based generative adversarial network. In ICLR, 2017.<br />
<br />
[20] V. Dumoulin, I. Belghazi, B. Poole, A. Lamb, M. Arjovsky, O. Mastropietro, and A. Courville. Adversarially learned inference. In ICLR, 2017.<br />
<br />
[21] D. Ulyanov, A. Vedaldi, and V. Lempitsky. It takes (only) two: Adversarial generator-encoder networks, 2017.<br />
<br />
[22] D. Berthelot, T. Schumm, and L. Metz. BEGAN: Boundary equilibrium generative adversarial networks, 2017.<br />
<br />
[23] Y. Li, K. Swersky, and R. Zemel. Generative moment matching networks. In ICML, 2015. <br />
<br />
[24] G. K. Dziugaite, D. M. Roy, and Z. Ghahramani. Training generative neural networks via maximum mean discrepancy optimization. In UAI, 2015.<br />
<br />
[25] R. Reddi, A. Ramdas, A. Singh, B. Poczos, and L. Wasserman. On the high-dimensional power of a linear-time two sample test under mean-shift alternatives. In AISTATS, 2015.<br />
<br />
[26] C. L. Li, W. C. Chang, Y. Cheng, Y. Yang, and B. Poczos. MMD GAN: Towards deeper understanding of moment matching network, 2017.<br />
<br />
[27] Y. LeCun, L. Bottou, Y. Bengio, and P. Haffner. Gradient-based learning applied to document recognition. In Proceedings of the IEEE, volume 86(11), pages 2278–2324, 1998.<br />
<br />
[28] Ziwei Liu, Ping Luo, Xiaogang Wang, and Xiaoou Tang. Deep learning face attributes in the wild. In Proceedings of International Conference on Computer Vision (ICCV), 2015.<br />
<br />
[29] D. P. Kingma and J. Ba. Adam: A method for stochastic optimization, 2014.<br />
<br />
[30] A. Radford, L. Metz, and S. Chintala. Unsupervised representation learning with deep convolutional generative adversarial networks. In ICLR, 2016.<br />
<br />
[31] S. Ioffe and C. Szegedy. Batch normalization: Accelerating deep network training by reducing internal covariate shift, 2015.<br />
<br />
[32] Martin Heusel, Hubert Ramsauer, Thomas Unterthiner, Bernhard Nessler, Günter Klambauer, and Sepp Hochreiter. GANs trained by a two time-scale update rule converge to a nash equilibrium. arXiv preprint arXiv:1706.08500, 2017.<br />
<br />
[33] B. Poole, A. Alemi, J. Sohl-Dickstein, and A. Angelova. Improved generator objectives for GANs, 2016.<br />
<br />
[34] S. Kolouri, C. E. Martin, and G. K. Rohde. Sliced-wasserstein autoencoder: An embarrassingly simple generative model. arXiv preprint arXiv:1804.01947, 2018.</div>Gsahuhttp://wiki.math.uwaterloo.ca/statwiki/index.php?title=Wasserstein_Auto-encoders&diff=41935Wasserstein Auto-encoders2018-11-29T23:15:47Z<p>Gsahu: /* Optimal Transport and its Dual Formations */</p>
<hr />
<div>The first version of this work was published in 2017, and this version (the third revision) was presented at ICLR 2018. Source code for the first version is available [https://github.com/tolstikhin/wae here].<br />
<br />
=Introduction=<br />
Early successes in the field of representation learning were based on supervised approaches, which used large labeled datasets to achieve impressive results. In contrast, popular unsupervised generative modeling methods mainly consisted of probabilistic approaches focusing on low-dimensional data. In recent years, several models have been proposed that try to combine these two approaches. One popular method is the variational auto-encoder (VAE). VAEs are theoretically elegant but have the major drawback of generating blurry samples when used to model natural images. Generative adversarial networks (GANs), in comparison, produce much sharper samples but have their own list of problems: they lack an encoder, are harder to train, and suffer from "mode collapse". Mode collapse refers to the inability of the model to capture all the variability in the true data distribution. There has been a lot of recent activity around designing and evaluating GAN architectures and combining VAEs with GANs, but a model that combines the best of both is yet to be discovered.<br />
<br />
This paper builds upon the theoretical work of Bousquet et al. (2017) [4]. The authors approach generative modeling from the point of view of optimal transport (OT). The OT cost is a way to measure the distance between probability distributions.<br />
<br />
More precisely, the OT problem is defined as follows.<br />
<br />
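Given a cost function <math>c : \mathcal{X} \times \mathcal{Y} \to \mathcal{R}</math>, one seeks a transport plan <math>\pi</math> attaining <math>C(\mu, \nu) := \inf_{\pi \in \Pi(\mu, \nu)} \int_{\mathcal{X} \times \mathcal{Y}} c(x, y)\, d\pi(x, y)</math>. Here <math>\Pi(\mu, \nu)</math> is the set of all joint probability measures on <math>\mathcal{X} \times \mathcal{Y}</math> with marginals <math>\mu</math> and <math>\nu</math>.<br />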
<br />
The measures <math>\pi \in \Pi(\mu, \nu)</math> are called transport plans or transference plans, and those achieving the infimum are called optimal transport plans. The classical interpretation of this problem is minimizing the total cost <math>C(\mu, \nu)</math> of transporting the mass distribution <math>\mu</math> to the mass distribution <math>\nu</math>. The cost of transporting one unit of mass from the point <math>x \in \mathcal{X}</math> to the point <math>y \in \mathcal{Y}</math> is given by the cost function <math>c(x, y)</math>.<br />
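For discrete measures, transport plans are nonnegative matrices with prescribed row and column sums. When both measures are uniform with the same number of atoms, an optimal plan can be taken to be a permutation (by the Birkhoff-von Neumann theorem), so a tiny instance can be solved by brute force. The sketch below is purely illustrative, since its cost grows factorially, and the helper name is hypothetical.

```python
import itertools
import numpy as np

def discrete_ot_uniform(x, y, cost):
    """Exact OT cost between two uniform empirical measures with n atoms each.

    With uniform weights and equal sizes, an optimal plan can be chosen to be
    a permutation, so we try all n! of them (only feasible for tiny n).
    Returns (average cost, permutation mapping x[i] -> y[perm[i]]).
    """
    n = len(x)
    best_cost, best_perm = np.inf, None
    for perm in itertools.permutations(range(n)):
        total = sum(cost(x[i], y[j]) for i, j in enumerate(perm)) / n
        if total < best_cost:
            best_cost, best_perm = total, perm
    return best_cost, best_perm

x = np.array([0.0, 2.0, 5.0])
y = np.array([1.0, 4.0, 6.0])
print(discrete_ot_uniform(x, y, lambda u, v: (u - v) ** 2))
```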
<br />
A beneficial feature of the OT cost is that it induces a much weaker topology than other costs, including the f-divergences associated with the original GAN algorithms.<br />
This feature is crucial in applications where the data is supported on low-dimensional manifolds in the input space. Stronger notions of distance such as f-divergences then often max out and provide no useful gradients for training. For example, the Jensen-Shannon divergence between two distributions with disjoint supports equals <math>\log 2</math> no matter how far apart the supports are. In comparison, the OT cost has been claimed to behave much more nicely [5, 8]. Despite this, implementations based on the OT cost, like GANs, still require adding a constraint or a regularization term to the objective function.<br />
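The saturation of f-divergences on disjoint supports can be checked numerically. The sketch below is our own illustration with hypothetical helper names. It places two point masses on a grid: the Jensen-Shannon divergence stays at <math>\log 2</math> however far apart they are, while the 1-Wasserstein distance, computed in 1-D as the area between the two CDFs, grows with the distance.

```python
import numpy as np

def js_divergence(p, q):
    """Jensen-Shannon divergence (natural log) between histograms on the same grid."""
    m = 0.5 * (p + q)
    def kl(a, b):
        mask = a > 0                      # 0 * log(0 / b) is taken to be 0
        return float(np.sum(a[mask] * np.log(a[mask] / b[mask])))
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

def w1_on_grid(p, q, grid):
    """1-Wasserstein distance between histograms on a sorted 1-D grid,
    computed as the integral of |CDF_p - CDF_q|."""
    cdf_gap = np.abs(np.cumsum(p) - np.cumsum(q))[:-1]
    return float(np.sum(cdf_gap * np.diff(grid)))

grid = np.linspace(0.0, 1.0, 11)
p = np.zeros_like(grid)
p[0] = 1.0                                # point mass at 0
for k in (1, 5, 10):                      # point mass at grid[k]
    q = np.zeros_like(grid)
    q[k] = 1.0
    print(grid[k], js_divergence(p, q), w1_on_grid(p, q, grid))
```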
<br />
==Original Contributions==<br />
Let <math>P_X</math> be the true but unknown data distribution. Let <math>P_G</math> be the latent variable model specified by the prior distribution <math>P_Z</math> of latent codes <math>Z \in \mathcal{Z}</math> and the generative model <math>P_G(X|Z)</math> of the data points <math>X \in \mathcal{X}</math> given <math>Z</math>. The goal of the paper is to minimize the OT cost <math>W_c(P_X, P_G)</math>.<br />
<br />
The main contributions are given below:<br />
<br />
* A new class of auto-encoders called Wasserstein Auto-Encoders (WAE). WAEs minimize the optimal transport cost <math>W_c(P_X, P_G)</math> for any cost function <math>c</math>. As with VAEs, the WAE objective is made up of two terms: the c-reconstruction cost and a regularizer <math>\mathcal{D}_Z(P_Z, Q_Z)</math> that penalizes the discrepancy between two distributions in <math>\mathcal{Z}</math>, namely <math>P_Z</math> and <math>Q_Z</math>. Here <math>Q_Z</math> is the distribution of encoded points, i.e. <math>Q_Z := \mathbb{E}_{P_X}[Q(Z|X)]</math>. Note that when <math>c</math> is the squared cost and the regularizer is the GAN objective, WAE is equivalent to the adversarial auto-encoders described in [2].<br />
<br />
* Experimental results of using WAE on the MNIST and CelebA datasets with the squared cost <math>c(x, y) = ||x - y||_2^2</math>. These experiments show that WAEs keep the good features of VAEs, such as stable training, an encoder-decoder architecture, and a nice latent manifold structure, while improving the quality of the generated samples.<br />
<br />
* Two different regularizers. The first is based on GANs and adversarial training in the latent space <math>\mathcal{Z}</math>. The second is based on the Maximum Mean Discrepancy (MMD), which is known to perform well when matching high-dimensional standard normal distributions. The second regularizer also makes the problem a fully adversary-free min-min optimization problem, removing the need to tune a GAN.<br />
<br />
* The final contribution is the mathematical analysis used to derive the WAE objective function. In particular, it shows that for the latent variable models considered here, the primal form of <math>W_c(P_X, P_G)</math> is equivalent to an optimization problem over probabilistic encoders <math>Q(Z|X)</math>.<br />
<br />
The paper provides an ostensibly simple recipe for implementing a non-blurry VAE, which is still a generative model. It also provides what looks like an elegant and logical way to use the Wasserstein distance to set up the VAE/GAN problem.<br />
The paper also gives three instructive comparisons of VAE/GAN-type models, unifying them thematically: Adversarial Auto-encoders (AAE), Adversarial Variational Bayes (AVB), and the original Variational Auto-encoders (VAE). These generalizations arise in the case of random decoders. The paper introduces the idea with deterministic decoders and then extends it to random decoders, playing on the regularizer of the VAE, which these approaches replace with a GAN.<br />
<br />
=Proposed Method=<br />
The method proposed by the authors uses a novel auto-encoder architecture to minimize the optimal transport cost <math>W_c(P_X, P_G)</math>. In the resulting optimization problem, the decoder tries to reconstruct the data points accurately, as measured by the cost function <math>c</math>. The encoder, meanwhile, tries to achieve two conflicting goals at the same time. First, it must match the distribution of the encoded data points <math>Q_Z := \mathbb{E}_{P_X}[Q(Z|X)]</math> to the prior distribution <math>P_Z</math>, as measured by the divergence <math>\mathcal{D}_Z(P_Z, Q_Z)</math>. Second, it must make sure that the latent codes contain enough information for the data points to be reconstructed with high quality. The figure below illustrates this:<br />
<br />
[[File:ka2khan_figure_1.png|800px|thumb|center|Figure 1]]<br />
<br />
Figure 1: The objectives of both VAE and WAE are composed of two terms: the reconstruction cost and a regularizer penalizing the discrepancy between <math>P_Z</math> and <math>Q_Z</math>. VAE forces <math>Q(Z|X = x)</math> to match <math>P_Z</math> for every training example <math>x</math> drawn from <math>P_X</math>. As shown in the figure, every red ball, representing <math>Q(Z|X = x)</math> for one input <math>x</math>, is forced to match <math>P_Z</math>, depicted as the whitish triangle. The red balls then start to intersect, which causes reconstruction problems. WAE, on the other hand, forces the mixture <math>Q_Z := \int{Q(Z|X)\ dP_X}</math> to match <math>P_Z</math>, as depicted by the green ball. This gives the latent codes of different examples a better chance to stay far from each other, which results in higher reconstruction quality.<br />
<br />
==Preliminaries and Notations==<br />
The authors use calligraphic letters to denote sets (for example, <math>\mathcal{X}</math>), capital letters for random variables (for example, <math>X</math>), and lowercase letters for their values (for example, <math>x</math>). Probability distributions are also denoted with capital letters (for example, <math>P(X)</math>) and the corresponding densities with lowercase letters (for example, <math>p(x)</math>).<br />
<br />
Several measures of difference between probability distributions are also used. These include the f-divergences <math>D_f(p_X \| p_G) := \int{f\left(\frac{p_X(x)}{p_G(x)}\right)p_G(x)}\,dx</math>, where <math>f:(0, \infty) \to \mathcal{R}</math> is any convex function satisfying <math>f(1) = 0</math>. Special cases used by the authors include the Kullback-Leibler (<math>D_{KL}</math>) and Jensen-Shannon (<math>D_{JS}</math>) divergences.<br />
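Here is a minimal sketch of the f-divergence definition for discrete distributions, with sums in place of integrals, assuming <math>q > 0</math> everywhere. The two generators shown are standard examples: <math>f(t) = t \log t</math> gives the KL divergence and <math>f(t) = |t - 1|/2</math> gives the total variation distance. The function names are hypothetical.

```python
import numpy as np

def f_divergence(p, q, f):
    """D_f(p || q) = sum_x f(p(x) / q(x)) q(x) for histograms with q > 0 everywhere."""
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    return float(np.sum(f(p / q) * q))

def kl_f(t):
    # f(t) = t log t gives the KL divergence; 0 log 0 is taken as 0.
    t = np.asarray(t, dtype=float)
    return np.where(t > 0, t * np.log(np.where(t > 0, t, 1.0)), 0.0)

def tv_f(t):
    # f(t) = |t - 1| / 2 gives the total variation distance.
    return 0.5 * np.abs(t - 1.0)

p = np.array([0.5, 0.3, 0.2])
q = np.array([0.2, 0.3, 0.5])
print(f_divergence(p, q, kl_f))   # equals sum p log(p / q)
print(f_divergence(p, q, tv_f))   # equals 0.5 * sum |p - q| = 0.3
```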
<br />
==Optimal Transport and its Dual Formations==<br />
<br />
The optimal transport problem motivates a rich class of distances between probability distributions. One formulation of the optimal transport problem is Kantorovich's formulation:<br />
<br />
<center><math><br />
W_c(P_X, P_G) := \underset{\Gamma \in \mathcal{P}(X \sim P_X ,Y \sim P_G)}{\inf} \mathbb{E}_{(X,Y) \sim \Gamma}[c(X,Y)],<br />
\text{where} \ c(x, y): \mathcal{X} \times \mathcal{X} \to \mathcal{R}_{+}<br />
</math></center><br />
<br />
is any measurable cost function and <math>\mathcal{P}(X \sim P_X, Y \sim P_G)</math> is the set of all joint distributions of <math>(X, Y)</math> with marginals <math>P_X</math> and <math>P_G</math>, respectively.<br />
<br />
A particularly interesting case is when <math>(\mathcal{X}, d)</math> is a metric space and <math>c(x, y) = d^p(x, y)</math> for <math>p \geq 1</math>. In this case <math>W_p</math>, the <math>p</math>-th root of <math>W_c</math>, is called the p-Wasserstein distance.<br />
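In one dimension, the p-Wasserstein distance between two empirical distributions with the same number of equally weighted samples has a simple closed form. For <math>p \geq 1</math>, the monotone coupling that matches sorted samples is optimal. The sketch below uses this fact, assumes equal sample sizes, and uses a hypothetical function name.

```python
import numpy as np

def wasserstein_p_1d(x, y, p=2):
    """p-Wasserstein distance between two 1-D empirical distributions with the
    same number of equally weighted samples.

    In 1-D with the convex cost |x - y|^p (p >= 1), matching the i-th smallest x
    with the i-th smallest y is optimal, so W_p^p is the mean of
    |x_(i) - y_(i)|^p and W_p is its p-th root.
    """
    x = np.sort(np.asarray(x, dtype=float))
    y = np.sort(np.asarray(y, dtype=float))
    if x.shape != y.shape:
        raise ValueError("x and y must have the same number of samples")
    return float(np.mean(np.abs(x - y) ** p) ** (1.0 / p))

# Two Gaussians that differ only by a shift of 3 have W_p = 3 for every p,
# so both estimates below should be near 3.
rng = np.random.default_rng(0)
a = rng.normal(0.0, 1.0, size=10_000)
b = rng.normal(3.0, 1.0, size=10_000)
print(wasserstein_p_1d(a, b, p=1), wasserstein_p_1d(a, b, p=2))
```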
<br />
When <math>c(x, y) = d(x, y)</math> the following Kantorovich-Rubinstein duality holds:<br />
<br />
<math>W_1(P_X, P_G) = \underset{f \in \mathcal{F}_L}{\sup} \mathbb{E}_{X \sim P_X}[f(X)] - \mathbb{E}_{Y \sim P_G}[f(Y)]</math><br />
where <math>\mathcal{F}_L</math> is the class of all bounded 1-Lipschitz functions on <math>(\mathcal{X}, d)</math>.<br />
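The duality can be sanity-checked in 1-D. Every bounded 1-Lipschitz critic <math>f</math> gives a lower bound on <math>W_1</math>, and a suitable critic reaches the supremum. In the toy setup below (our own illustration), <math>P_G</math> is <math>P_X</math> shifted by 0.7, so <math>W_1 = 0.7</math>. The critic <math>-t</math>, clipped far outside the data to keep it bounded, attains this value.

```python
import numpy as np

rng = np.random.default_rng(1)
x = rng.normal(0.0, 1.0, size=5_000)    # samples standing in for P_X
y = x + 0.7                              # P_G: the same distribution shifted by 0.7

# Primal side: in 1-D, W_1 between equal-size samples is the mean gap of sorted samples.
w1 = float(np.mean(np.abs(np.sort(x) - np.sort(y))))

# Dual side: bounded 1-Lipschitz critics f, each giving E[f(X)] - E[f(Y)] <= W_1.
critics = {
    "sin(t)": np.sin,
    "tanh(t)": np.tanh,
    "clip(t, -1, 1)": lambda t: np.clip(t, -1.0, 1.0),
    "-clip(t, -10, 10)": lambda t: -np.clip(t, -10.0, 10.0),  # acts like -t on these samples
}
dual_values = {name: float(np.mean(f(x)) - np.mean(f(y))) for name, f in critics.items()}
for name, value in dual_values.items():
    print(f"{name:18s} {value:+.4f} <= W_1 = {w1:.4f}")
```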
<br />
==Application to Generative Models: Wasserstein auto-encoders==<br />
Modern generative models such as VAEs and GANs can be understood as minimizing specific discrepancy measures between the data distribution <math>P_X</math> and the model <math>P_G</math>. Unfortunately, most standard discrepancy measures are hard or even impossible to compute, especially when <math>P_X</math> is unknown and <math>P_G</math> is parametrized by deep neural networks. However, certain tricks can be used to get around this difficulty.<br />
<br />
To minimize the KL-divergence <math>D_{KL}(P_X, P_G)</math>, or equivalently to maximize the marginal log-likelihood <math>\mathbb{E}_{P_X}[\log p_G(X)]</math>, one can use the well-known variational lower bound, which provides a theoretically grounded framework. VAEs have used it quite successfully. The two problems are equivalent because <math>D_{KL}(P_X, P_G) = \mathbb{E}_{P_X}[\log p_X(X)] - \mathbb{E}_{P_X}[\log p_G(X)]</math>, and the first term does not depend on the model. In the general case of minimizing an f-divergence <math>D_f(P_X, P_G)</math>, one can use its dual formulation together with adversarial training, as in f-GANs. Finally, the 1-Wasserstein distance <math>W_1(P_X, P_G)</math>, a particular OT cost, can be minimized using the Kantorovich-Rubinstein duality expressed as an adversarial objective. Wasserstein GAN implements this idea.<br />
<br />
In this paper, the authors focus on latent variable models <math>P_G</math> defined by a two-step procedure. First, a code <math>Z</math> is sampled from a fixed distribution <math>P_Z</math> on a latent space <math>\mathcal{Z}</math>. Second, <math>Z</math> is mapped to the image <math>X \in \mathcal{X} = \mathcal{R}^d</math> by a (possibly random) transformation. This gives a density of the form<br />
<br />
<math><br />
p_G(x) := \int\limits_{\mathcal{Z}}{p_G(x|z)p_z(z)}dz,\ \forall x \in \mathcal{X}, <br />
</math><br />
<br />
provided all the probabilities involved are properly defined. To keep things simple, the authors focus on non-random decoders, i.e., generative models <math>P_G(X|Z)</math> that deterministically map <math>Z</math> to <math>X = G(Z)</math> using a fixed map <math>G: \mathcal{Z} \to \mathcal{X}</math>. Similar results hold for random decoders, as shown by the authors in Appendix B.1.<br />
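For a random decoder, the integral defining <math>p_G(x)</math> can be approximated by Monte Carlo: sample <math>z \sim P_Z</math> and average <math>p_G(x|z)</math>. The sketch below assumes, purely for illustration, a 1-D standard normal prior and a Gaussian decoder centred on a hypothetical linear map. With this choice the exact density is available for comparison.

```python
import numpy as np

def normal_pdf(x, mean, std):
    return np.exp(-0.5 * ((x - mean) / std) ** 2) / (std * np.sqrt(2.0 * np.pi))

def p_G_monte_carlo(x, decoder_mean, decoder_std, n_samples=200_000, seed=0):
    """Estimate p_G(x) = integral of p_G(x|z) p_Z(z) dz by averaging p_G(x|z) over z ~ P_Z.

    Assumes (for illustration only) a 1-D standard normal prior P_Z and a
    Gaussian decoder p_G(x|z) = N(x; decoder_mean(z), decoder_std^2).
    """
    z = np.random.default_rng(seed).standard_normal(n_samples)
    return float(np.mean(normal_pdf(x, decoder_mean(z), decoder_std)))

# Hypothetical decoder mean G(z) = 2z + 1 with noise std 0.5, so in closed form
# p_G(x) = N(x; 1, 2^2 + 0.5^2).
estimate = p_G_monte_carlo(0.3, lambda z: 2.0 * z + 1.0, 0.5)
exact = float(normal_pdf(0.3, 1.0, np.sqrt(4.25)))
print(estimate, exact)
```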
<br />
Under this model, the authors find that the OT cost takes a much simpler form, because the transport plan factors through the map <math>G</math>. Instead of finding a coupling <math>\Gamma</math> between two random variables in the space <math>\mathcal{X}</math>, one distributed according to <math>P_X</math> and the other according to <math>P_G</math>, it is enough to find a conditional distribution <math>Q(Z|X)</math> whose <math>Z</math>-marginal <math>Q_Z(Z) := \mathbb{E}_{X \sim P_X}[Q(Z|X)]</math> is the same as the prior distribution <math>P_Z</math>. The theorem below, proven in [4], formalizes this.<br />
<br />
'''Theorem 1.''' For <math>P_G</math> defined as above with deterministic <math>P_G(X|Z)</math> and any function <math>G:\mathcal{Z} \to \mathcal{X}</math>,<br />
<br />
<math><br />
\underset{\Gamma \in \mathcal{P}(X \sim P_X ,Y \sim P_G)}{\inf} \mathbb{E}_{(X,Y) \sim \Gamma}[c(X,Y)] = \underset{Q: Q_Z = P_Z}{\inf} \mathbb{E}_{P_X} \mathbb{E}_{Q(Z|X)}[c(X, G(Z))]<br />
</math><br />
<br />
where <math>Q_Z</math> is the marginal distribution of <math>Z</math> when <math>X \sim P_X</math> and <math>Z \sim Q(Z|X)</math>.<br />
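Theorem 1 can be checked by hand in a toy case. Take <math>P_X = \mathcal{N}(0, s^2)</math>, <math>P_Z = \mathcal{N}(0, 1)</math>, the decoder <math>G(z) = a z</math> with <math>a > 0</math>, and the squared cost. Then <math>P_G = \mathcal{N}(0, a^2)</math>, and the left-hand side equals <math>(s - a)^2</math>, the squared 2-Wasserstein distance between these two Gaussians. The deterministic encoder <math>x \mapsto x / s</math> satisfies <math>Q_Z = P_Z</math> and has cost <math>\mathbb{E}[(X - aX/s)^2] = (s - a)^2</math>, so it attains the infimum on the right-hand side. The simulation below is our own illustration, not taken from the paper, and checks both sides numerically.

```python
import numpy as np

rng = np.random.default_rng(0)
s, a = 2.0, 0.5                  # P_X = N(0, s^2); decoder G(z) = a * z
n = 200_000

def encode(x):
    # Deterministic encoder; maps N(0, s^2) onto N(0, 1), so Q_Z = P_Z.
    return x / s

def decode(z):
    # Deterministic decoder G.
    return a * z

x = rng.normal(0.0, s, size=n)   # X ~ P_X
z = encode(x)
print("encoded codes mean/std:", z.mean(), z.std())   # about 0 and 1

# Right-hand side of Theorem 1 for this particular encoder.
encoder_cost = float(np.mean((x - decode(z)) ** 2))

# Left-hand side: primal OT cost between P_X and P_G from independent samples,
# using the optimal monotone (sorted) coupling in 1-D.
y = decode(rng.standard_normal(n))                    # Y ~ P_G
primal_cost = float(np.mean((np.sort(x) - np.sort(y)) ** 2))

print(encoder_cost, primal_cost, (s - a) ** 2)        # all close to 2.25
```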
<br />
According to the authors, this result allows optimizing over random encoders <math>Q(Z|X)</math> instead of over all couplings of <math>X</math> and <math>Y</math>. Both problems are still constrained. To obtain a numerical solution, the authors relax the constraint on <math>Q_Z</math> by adding a penalty to the objective, which gives the WAE objective:<br />
<br />
<math><br />
D_{WAE}(P_X, P_G) := \underset{Q(Z|X) \in \mathcal{Q}}{\inf} \mathbb{E}_{P_X} \mathbb{E}_{Q(Z|X)}[c(X, G(Z))] + \lambda \cdot \mathcal{D}_Z(Q_Z, P_Z)<br />
</math><br />
<br />
where <math>\mathcal{Q}</math> is any nonparametric set of probabilistic encoders, <math>\mathcal{D}_Z</math> is an arbitrary divergence between <math>Q_Z</math> and <math>P_Z</math>, and <math>\lambda > 0</math> is a hyperparameter. As with VAEs, the authors propose using deep neural networks to parameterize both the encoders <math>Q</math> and the decoders <math>G</math>. Note that, unlike VAEs, WAEs allow non-random encoders that deterministically map their inputs to latent codes.<br />
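The structure of the objective can be illustrated with a mini-batch estimate for a deterministic encoder and the squared cost. This is a hypothetical sketch: the linear encoder and decoder, the value of <code>lam</code>, and especially <code>moment_penalty</code> are our own assumptions. The penalty is a crude moment-matching stand-in for <math>\mathcal{D}_Z</math>; the authors' choices of <math>\mathcal{D}_Z</math> are the GAN- and MMD-based penalties described in the following subsections. The sketch only evaluates the objective. In practice the encoder and decoder are neural networks trained on it with SGD.

```python
import numpy as np

def wae_objective(x, encode, decode, prior_samples, penalty, lam=10.0):
    """Mini-batch estimate of the WAE objective for a deterministic encoder.

    x:             (n, d_x) batch of data points drawn from P_X
    encode/decode: maps X -> Z and Z -> X (the encoder Q and decoder G)
    prior_samples: (n, d_z) draws from the prior P_Z
    penalty:       estimate of D_Z(Q_Z, P_Z) from two samples of codes
    lam:           the hyperparameter lambda > 0
    Uses the squared cost c(x, y) = ||x - y||_2^2.
    """
    z = encode(x)                                        # samples from Q_Z
    recon = np.mean(np.sum((x - decode(z)) ** 2, axis=1))
    return float(recon + lam * penalty(z, prior_samples))

def moment_penalty(z, z_prior):
    # Crude stand-in for D_Z: squared gaps between sample means and covariances.
    mean_gap = np.sum((z.mean(axis=0) - z_prior.mean(axis=0)) ** 2)
    cov_gap = np.sum((np.cov(z, rowvar=False) - np.cov(z_prior, rowvar=False)) ** 2)
    return mean_gap + cov_gap

rng = np.random.default_rng(0)
W_enc = 0.5 * rng.normal(size=(4, 2))    # hypothetical linear encoder weights
W_dec = 0.5 * rng.normal(size=(2, 4))    # hypothetical linear decoder weights
x = rng.normal(size=(256, 4))
z_prior = rng.standard_normal((256, 2))
print(wae_objective(x, lambda v: v @ W_enc, lambda v: v @ W_dec, z_prior, moment_penalty))
```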
<br />
The authors propose two different regularizers <math>\mathcal{D}_Z(Q_Z, P_Z)</math>:<br />
<br />
===GAN-based <math>\mathcal{D}_Z</math>===<br />
One option is to use <math>\mathcal{D}_Z(Q_Z, P_Z) = \mathcal{D}_{JS}(Q_Z, P_Z)</math> and estimate it with adversarial training. In particular, a discriminator (adversary) in the latent space <math>\mathcal{Z}</math> tries to separate "true" points sampled from <math>P_Z</math> from "fake" ones sampled from <math>Q_Z</math>. This leads to WAE-GAN, described in Algorithm 1 below. WAE-GAN still involves max-min optimization, but it moves the adversary from the input (pixel) space <math>\mathcal{X}</math> to the latent space <math>\mathcal{Z}</math>. Moreover, the true latent distribution <math>P_Z</math> may have a nice shape with a single mode (for a Gaussian prior). This makes matching much easier than matching an unknown, complex, and possibly multi-modal distribution, as is usually required in GANs. This is also a reason for the second penalty.<br />
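Here is a hedged sketch of the two mini-batch losses behind one WAE-GAN step, written with the standard GAN log-losses and the squared cost. The logistic discriminator, the linear maps, and the function names are hypothetical; consult Algorithm 1 in the figure below for the authors' exact updates. The discriminator ascends the first quantity, and the encoder and decoder descend the second, where the adversarial term pushes encoded codes to look like prior samples.

```python
import numpy as np

def sigmoid(t):
    return 1.0 / (1.0 + np.exp(-t))

def wae_gan_losses(x, z_prior, encode, decode, discriminator, lam=1.0, eps=1e-12):
    """Mini-batch losses for one WAE-GAN step (squared cost).

    z_prior:    samples from P_Z, labelled "true" for the discriminator
    encode(x):  codes from Q(Z|X), labelled "fake"
    Returns (discriminator_objective, encoder_decoder_loss): the discriminator
    ascends the first; the encoder and decoder descend the second.
    """
    z_fake = encode(x)
    d_true = discriminator(z_prior)
    d_fake = discriminator(z_fake)
    disc_objective = lam * np.mean(np.log(d_true + eps) + np.log(1.0 - d_fake + eps))
    recon = np.mean(np.sum((x - decode(z_fake)) ** 2, axis=1))
    enc_dec_loss = recon - lam * np.mean(np.log(d_fake + eps))
    return float(disc_objective), float(enc_dec_loss)

rng = np.random.default_rng(0)
w = np.array([0.3, -0.2])                    # hypothetical discriminator weights

def disc(z):
    return sigmoid(z @ w)

x = rng.normal(size=(128, 2))
z_prior = rng.standard_normal((128, 2))
print(wae_gan_losses(x, z_prior, lambda v: 0.5 * v, lambda v: 2.0 * v, disc))
```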
<br />
===MMD-based <math>\mathcal{D}_Z</math>===<br />
For a positive-definite reproducing kernel <math>k: \mathcal{Z} \times \mathcal{Z} \to \mathcal{R}</math>, the maximum mean discrepancy (MMD) is defined as:<br />
<br />
<center><math><br />
MMD_k(P_Z, Q_Z) = \left \Vert \int \limits_{\mathcal{Z}} {k(z, \cdot)dP_Z(z)} - \int \limits_{\mathcal{Z}} {k(z, \cdot)dQ_Z(z)} \right \|_{\mathcal{H}_k}<br />
</math>,</center><br />
<br />
where <math>\mathcal{H}_k</math> is the RKHS (reproducing kernel Hilbert space) of real-valued functions mapping <math>\mathcal{Z}</math> to <math>\mathcal{R}</math>. If <math>k</math> is characteristic, then <math>MMD_k</math> defines a metric and can be used as a distance measure. The authors propose to use <math>\mathcal{D}_Z(P_Z, Q_Z) = MMD_k(P_Z, Q_Z)</math>. MMD has an unbiased U-statistic estimator (of <math>MMD_k^2</math>) that can be used with stochastic gradient descent (SGD) methods. This gives WAE-MMD, described in Algorithm 2 below. MMD is known to perform well when matching high-dimensional standard normal distributions, so this penalty is expected to work well when the prior <math>P_Z</math> is Gaussian.<br />
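Here is a minimal sketch of the unbiased U-statistic estimator of <math>MMD_k^2</math> from <math>n</math> prior samples and <math>n</math> encoded samples. The Gaussian RBF kernel and its bandwidth are assumptions made for illustration, and any characteristic kernel could be substituted. As far as we can tell, the same three kernel sums appear in the penalty of Algorithm 2 in the figure below.

```python
import numpy as np

def rbf_kernel(a, b, bandwidth=1.0):
    """Gaussian RBF kernel matrix k(a_i, b_j) = exp(-||a_i - b_j||^2 / (2 h^2))."""
    sq_dists = np.sum(a**2, axis=1)[:, None] + np.sum(b**2, axis=1)[None, :] - 2.0 * a @ b.T
    return np.exp(-np.maximum(sq_dists, 0.0) / (2.0 * bandwidth**2))

def mmd2_unbiased(z_prior, z_encoded, kernel=rbf_kernel):
    """Unbiased U-statistic estimate of MMD_k^2(P_Z, Q_Z) from n samples of each."""
    n = z_prior.shape[0]
    k_pp = kernel(z_prior, z_prior)
    k_qq = kernel(z_encoded, z_encoded)
    k_pq = kernel(z_prior, z_encoded)
    # Within-sample sums exclude the diagonal (i == j) terms.
    term_pp = (k_pp.sum() - np.trace(k_pp)) / (n * (n - 1))
    term_qq = (k_qq.sum() - np.trace(k_qq)) / (n * (n - 1))
    term_pq = 2.0 * k_pq.sum() / (n * n)
    return float(term_pp + term_qq - term_pq)

rng = np.random.default_rng(0)
z_p = rng.standard_normal((500, 2))              # samples from P_Z = N(0, I)
z_same = rng.standard_normal((500, 2))           # Q_Z equal to P_Z
z_shift = rng.standard_normal((500, 2)) + 1.5    # Q_Z shifted away from P_Z
print(mmd2_unbiased(z_p, z_same), mmd2_unbiased(z_p, z_shift))
```

Being unbiased, the estimate can be slightly negative when the two distributions coincide.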
<br />
[[File:ka2khan_figure_2.png|800px|thumb|center|Algorithms- WAE-GAN on left and WAE-MMD on right]]<br />
<br />
=Related Work=<br />
==Literature on auto-encoders==<br />
Classical unregularized auto-encoders minimize only the reconstruction cost. As a result, distinct data points are encoded into distinct zones scattered chaotically across the latent space <math>\mathcal{Z}</math>. The latent space then contains huge "holes" on which the decoder <math>P_G(X|Z)</math> has never been trained. In general, encoders trained this way do not provide very useful representations, and sampling from the latent space <math>\mathcal{Z}</math> becomes difficult [12].<br />
<br />
VAEs [1] minimize an upper bound on the KL-divergence <math>D_{KL}(P_X, P_G)</math> (up to an additive constant), composed of the reconstruction cost and the regularizer <math>\mathbb{E}_{P_X}[D_{KL}(Q(Z|X), P_Z)]</math>. The regularizer penalizes the discrepancy between the encoding distribution of each training image and the prior <math>P_Z</math>. Unlike the WAE penalty, this does not guarantee that the overall encoded distribution matches the prior. In addition, VAEs require non-degenerate (i.e. non-deterministic) Gaussian encoders and random decoders. Another paper [11] later proposed a method that allows the use of non-Gaussian encoders with VAEs. WAE, in contrast, minimizes <math>W_{c}(P_X, P_G)</math> and allows both probabilistic and deterministic encoder-decoder pairs.<br />
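For a diagonal Gaussian encoder and a standard normal prior, the per-example VAE regularizer <math>D_{KL}(Q(Z|X = x), P_Z)</math> has a well-known closed form. The sketch below is our own illustration with hypothetical encoder outputs. It computes the closed form and checks it against a Monte Carlo estimate. Note that this term is evaluated separately for every input <math>x</math>, whereas WAE penalizes only the aggregate <math>Q_Z</math>.

```python
import numpy as np

def gaussian_kl_to_standard(mu, log_var):
    """KL( N(mu, diag(exp(log_var))) || N(0, I) ), summed over latent dimensions."""
    return 0.5 * np.sum(mu**2 + np.exp(log_var) - 1.0 - log_var, axis=-1)

# One hypothetical encoder output Q(Z|X = x) = N(mu, diag(sigma^2)) in 2-D.
mu = np.array([0.8, -0.3])
log_var = np.log(np.array([0.25, 0.5]))

# Monte Carlo check: KL = E_{z ~ Q}[log q(z) - log p(z)].
rng = np.random.default_rng(0)
sigma = np.exp(0.5 * log_var)
z = mu + sigma * rng.standard_normal((200_000, 2))
log_q = np.sum(-0.5 * ((z - mu) / sigma) ** 2 - np.log(sigma) - 0.5 * np.log(2 * np.pi), axis=1)
log_p = np.sum(-0.5 * z**2 - 0.5 * np.log(2 * np.pi), axis=1)
print(gaussian_kl_to_standard(mu, log_var), np.mean(log_q - log_p))
```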
<br />
With an appropriate choice of cost and penalty, WAE generalizes AAE [2] in two ways. It can use any cost function <math>c</math> in the input space <math>\mathcal{X}</math>, and it can use any discrepancy measure <math>\mathcal{D}_Z</math> in the latent space <math>\mathcal{Z}</math>, not only the adversarial one.<br />
<br />
Zhao et al. proposed a regularized auto-encoder called InfoVAE [14], whose objective is similar to that of [4] but is derived from different motivations and arguments.<br />
<br />
WAEs define the cost function <math>c(x,y)</math> explicitly, whereas VAEs specify it implicitly through the negative log-likelihood term. In theory the likelihood can induce any cost function, but in practice this may require estimating a normalizing constant that can differ across values of <math>z</math>.<br />
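A Gaussian decoder with fixed variance is a standard example of such an implicit cost. Its negative log-likelihood is <math>-\log \mathcal{N}(x; G(z), \sigma^2 I) = \frac{1}{2\sigma^2}\|x - G(z)\|_2^2 + \text{const}</math>, i.e. a rescaled squared cost. The sketch below checks this numerically with hypothetical values. If <math>\sigma</math> depended on <math>z</math>, the log-normalizer would no longer be constant.

```python
import numpy as np

def gaussian_nll(x, mean, sigma):
    """-log N(x; mean, sigma^2 I) for a single d-dimensional point."""
    d = x.shape[-1]
    return 0.5 * np.sum((x - mean) ** 2) / sigma**2 + 0.5 * d * np.log(2 * np.pi * sigma**2)

rng = np.random.default_rng(0)
x = rng.normal(size=5)
g_z = rng.normal(size=5)          # stands in for a decoder output G(z)
sigma = 0.3                       # fixed decoder noise level (hypothetical)

nll = gaussian_nll(x, g_z, sigma)
squared_cost = np.sum((x - g_z) ** 2)
const = 0.5 * 5 * np.log(2 * np.pi * sigma**2)   # log-normalizer, independent of z here
# The NLL is the squared cost rescaled by 1 / (2 sigma^2), plus the constant.
print(nll, squared_cost / (2 * sigma**2) + const)
```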
<br />
==Literature on Optimal Transport (OT)==<br />
[15] provides methods for computing OT cost for large-scale data using SGD and sampling. The WGAN [5] proposes a generative model which minimizes 1-Wasserstein distance <math>W_1(P_X, P_G)</math>. The WGAN algorithm does not provide an encoder and cannot be easily applied to any arbitrary cost <math>W_C</math>. The model proposed in [5] uses the dual form, in contrast, the model proposed in this paper uses the primal form. The primal form allows the use of any arbitrary cost function <math>c</math> and naturally, comes with an encoder. <br />
<br />
In order to compute <math>W_c(P_X, P_G)</math> or <math>W_1(P_X, P_G)</math>, the model needs to handle various non-trivial constraints, various methods has be proposed in the literature ([5], [2], [8], [16], [15], [17], [18]) to avoid this difficulty .<br />
<br />
==Literature on GANs==<br />
A lot of the GAN variations which have been proposed in the literature come without an encoder. Examples include WGAN and f-GAN. These models are deficient in cases where a reconstruction of latent space is needed to use the learned manifold.<br />
<br />
There have been numerous models proposed in the literature which try to combine the adversarial training of GANs with auto-encoder architectures. Some examples are [19], [20], [21], and [22]. There has also been work done in which reproducing kernels have been used in the context of GANS ([23], [24]).<br />
<br />
=Experiments=<br />
Experiments were used to empirically evaluate the proposed WAE model. <br />
<br />
'''Experimental setup'''<br />
<br />
For experimental setup, authors used <math> \small P_Z</math> and squared cost function <math> \small c(x,y)</math> for data points.<br />
Deterministic encoder-decoder pairs were used.The authors conducted experiments using the following two real-world datasets: (1) MNIST [27] made up of 70k images, and (2) CelebA [28] consisting of approximately 203k images. For test reconstruction and interpolations a pair of of held out images, <math>(x,y)</math> from the test set are Auto-encoded (separately), to produce <math>(z_x, z_y)</math> in the latent space<br />
<br />
The main evaluation criteria were to see if the WAE model can simultaneously achieve: <br />
<br />
<ol><br />
<li>accurate reconstruction of the data points</li><br />
<li>resonable geometry of the latent manifold</li><br />
<li>generation of high quality random samples</li><br />
</ol><br />
<br />
For the model to generalize well (1) and (2) should be met on both the training and test data set.<br />
<br />
The proposed model achieve reasonably good results as highlighted in the figures given below:<br />
<br />
[[File:ka2khan_figure_3.png|800px|thumb|center|Using CelebA dataset]]<br />
<br />
[[File:ka2khan_figure_4.png|800px|thumb|center|Using CelebA dataset, FID (Fréchet Inception Distance<br />
[32]): smaller is better, sharpness: larger is better]]<br />
<br />
=Conclusion=<br />
The authors proposed a new class of algorithms for building a generative model called Wasserstein Autoencoders based on optimal transport cost. They related the newly proposed model to the existing probabilistic modeling techniques. They empirically evaluated the proposed models using two real-world datasets. They compared the results obtained using their proposed model with the results obtained using VAEs on the same dataset to show that the proposed models generate sample images of higher quality in addition to being easier to train and having good reconstruction quality of the data points.<br />
<br />
The authors state that in future work they will further explore criteria for matching the encoding distribution <math>Q_Z</math> to the prior distribution <math>P_Z</math>, evaluate whether it is feasible to adversarially train the cost function <math>c</math> in the input space <math>\mathcal{X}</math>, and carry out a theoretical analysis of the dual formulations for WAE-GAN and WAE-MMD.<br />
<br />
=Future Work=<br />
Building on this paper, Kolouri et al. [34] introduced another generative model based on optimal transport, which measures the distance between two probability distributions by the cost of transporting one distribution onto the other (hence the name). Their model, called the Sliced-Wasserstein Auto-encoder (SWAE), is simple to implement and provides the capabilities of Wasserstein Auto-encoders.<br />
<br />
As noted in the OpenReview discussion ([https://openreview.net/forum?id=HkL7n1-0b]), the results on the MNIST and CelebA datasets look convincing, but the paper could include additional evaluation comparing the adversarial loss with the more straightforward MMD penalty and discussing their pros and cons. Given the challenges of evaluating and comparing closely related auto-encoder solutions, the authors could also design illustrative experiments showing cases where the Wasserstein distance helps, as well as its potential limitations.<br />
<br />
<br />
<br />
=Critique=<br />
<br />
Although this paper presents some empirical tests that illustrate its method appropriately, it would benefit from clearer notation, including details of the architectures used in the experiments. It could also compare its results with those of other similar works. As pointed out by a reviewer, the closest work to this paper is the Adversarial Variational Bayes framework of Mescheder et al. [11], which also attempts to unify VAEs and GANs. Although the authors describe the conceptual differences from and advantages over that approach, it would be beneficial to include actual comparisons in the results section.<br />
Moreover, the performance of the algorithm is not a significant improvement over the earlier VAE algorithm. The performance could be characterized better if the authors performed empirical tests on more datasets. However, the methodology is flexible and unifies several other types of algorithms, which is a major benefit.<br />
<br />
=References=<br />
[1] D. P. Kingma and M. Welling. Auto-encoding variational Bayes. In ICLR, 2014.<br />
<br />
[2] A. Makhzani, J. Shlens, N. Jaitly, and I. Goodfellow. Adversarial autoencoders. In ICLR, 2016.<br />
<br />
[3] Ian Goodfellow, Jean Pouget-Abadie, Mehdi Mirza, Bing Xu, David Warde-Farley, Sherjil Ozair, Aaron Courville, and Yoshua Bengio. Generative adversarial nets. In NIPS, pages 2672–2680, 2014.<br />
<br />
[4] O. Bousquet, S. Gelly, I. Tolstikhin, C. J. Simon-Gabriel, and B. Schölkopf. From optimal transport to generative modeling: the VEGAN cookbook, 2017.<br />
<br />
[5] M. Arjovsky, S. Chintala, and L. Bottou. Wasserstein GAN, 2017.<br />
<br />
[6] C. Villani. Topics in Optimal Transportation. AMS Graduate Studies in Mathematics, 2003.<br />
<br />
[7] Sebastian Nowozin, Botond Cseke, and Ryota Tomioka. f-GAN: Training generative neural samplers using variational divergence minimization. In NIPS, 2016.<br />
<br />
[8] I. Gulrajani, F. Ahmed, M. Arjovsky, V. Dumoulin, and A. Courville. Improved training of Wasserstein GANs, 2017.<br />
<br />
[9] A. Gretton, K. M. Borgwardt, M. J. Rasch, B. Schölkopf, and A. J. Smola. A kernel two-sample test. Journal of Machine Learning Research, 13:723–773, 2012.<br />
<br />
[10] F. Liese and K.-J. Miescke. Statistical Decision Theory. Springer, 2008.<br />
<br />
[11] L. Mescheder, S. Nowozin, and A. Geiger. Adversarial variational Bayes: Unifying variational autoencoders and generative adversarial networks, 2017.<br />
<br />
[12] Y. Bengio, A. Courville, and P. Vincent. Representation learning: A review and new perspectives. IEEE Transactions on Pattern Analysis and Machine Intelligence, 35, 2013.<br />
<br />
[13] M. D. Hoffman and M. Johnson. ELBO surgery: yet another way to carve up the variational evidence lower bound. In NIPS Workshop on Advances in Approximate Bayesian Inference, 2016.<br />
<br />
[14] S. Zhao, J. Song, and S. Ermon. InfoVAE: Information maximizing variational autoencoders, 2017.<br />
<br />
[15] A. Genevay, M. Cuturi, G. Peyré, and F. R. Bach. Stochastic optimization for large-scale optimal transport. In Advances in Neural Information Processing Systems, pages 3432–3440, 2016. <br />
<br />
[16] M. Cuturi. Sinkhorn distances: Lightspeed computation of optimal transport. In Advances in Neural Information Processing Systems, pages 2292–2300, 2013.<br />
<br />
[17] Lenaic Chizat, Gabriel Peyré, Bernhard Schmitzer, and François-Xavier Vialard. Unbalanced optimal transport: geometry and Kantorovich formulation. arXiv preprint arXiv:1508.05216, 2015.<br />
<br />
[18] Matthias Liero, Alexander Mielke, and Giuseppe Savaré. Optimal entropy-transport problems and a new Hellinger-Kantorovich distance between positive measures. arXiv preprint arXiv:1508.07941, 2015.<br />
<br />
[19] J. Zhao, M. Mathieu, and Y. LeCun. Energy-based generative adversarial network. In ICLR, 2017.<br />
<br />
[20] V. Dumoulin, I. Belghazi, B. Poole, A. Lamb, M. Arjovsky, O. Mastropietro, and A. Courville. Adversarially learned inference. In ICLR, 2017.<br />
<br />
[21] D. Ulyanov, A. Vedaldi, and V. Lempitsky. It takes (only) two: Adversarial generator-encoder networks, 2017.<br />
<br />
[22] D. Berthelot, T. Schumm, and L. Metz. BEGAN: Boundary equilibrium generative adversarial networks, 2017.<br />
<br />
[23] Y. Li, K. Swersky, and R. Zemel. Generative moment matching networks. In ICML, 2015. <br />
<br />
[24] G. K. Dziugaite, D. M. Roy, and Z. Ghahramani. Training generative neural networks via maximum mean discrepancy optimization. In UAI, 2015.<br />
<br />
[25] S. Reddi, A. Ramdas, A. Singh, B. Poczos, and L. Wasserman. On the high-dimensional power of a linear-time two sample test under mean-shift alternatives. In AISTATS, 2015.<br />
<br />
[26] C. L. Li, W. C. Chang, Y. Cheng, Y. Yang, and B. Poczos. MMD GAN: Towards deeper understanding of moment matching network, 2017.<br />
<br />
[27] Y. LeCun, L. Bottou, Y. Bengio, and P. Haffner. Gradient-based learning applied to document recognition. In Proceedings of the IEEE, volume 86(11), pages 2278–2324, 1998.<br />
<br />
[28] Ziwei Liu, Ping Luo, Xiaogang Wang, and Xiaoou Tang. Deep learning face attributes in the wild. In Proceedings of International Conference on Computer Vision (ICCV), 2015.<br />
<br />
[29] D. P. Kingma and J. Ba. Adam: A method for stochastic optimization, 2014.<br />
<br />
[30] A. Radford, L. Metz, and S. Chintala. Unsupervised representation learning with deep convolutional generative adversarial networks. In ICLR, 2016.<br />
<br />
[31] S. Ioffe and C. Szegedy. Batch normalization: Accelerating deep network training by reducing internal covariate shift, 2015.<br />
<br />
[32] Martin Heusel, Hubert Ramsauer, Thomas Unterthiner, Bernhard Nessler, Günter Klambauer, and Sepp Hochreiter. GANs trained by a two time-scale update rule converge to a Nash equilibrium. arXiv preprint arXiv:1706.08500, 2017.<br />
<br />
[33] B. Poole, A. Alemi, J. Sohl-Dickstein, and A. Angelova. Improved generator objectives for GANs, 2016.<br />
<br />
[34] S. Kolouri, C. E. Martin, and G. K. Rohde. Sliced-Wasserstein autoencoder: An embarrassingly simple generative model. arXiv preprint arXiv:1804.01947, 2018.</div>Gsahuhttp://wiki.math.uwaterloo.ca/statwiki/index.php?title=Wasserstein_Auto-encoders&diff=41934Wasserstein Auto-encoders2018-11-29T23:15:24Z<p>Gsahu: /* MMD-based \mathcal{D}_z */</p>
<hr />
<div>The first version of this work was published in 2017, and this version (the third revision) was presented at ICLR 2018. Source code for the first version is available [https://github.com/tolstikhin/wae here].<br />
<br />
=Introduction=<br />
Early successes in the field of representation learning were based on supervised approaches, which used large labeled datasets to achieve impressive results. On the other hand, popular unsupervised generative modeling methods mainly consisted of probabilistic approaches focusing on low-dimensional data. In recent years, several models have been proposed that try to combine these two approaches. One popular method is the variational auto-encoder (VAE). VAEs are theoretically elegant but have the major drawback of generating blurry samples when used to model natural images. In comparison, generative adversarial networks (GANs) produce much sharper samples but have their own list of problems, including the lack of an encoder, difficult training, and the "mode collapse" problem. Mode collapse refers to the inability of the model to capture all the variability in the true data distribution. There has been a lot of activity around finding and evaluating numerous GAN architectures and combining VAEs and GANs, but a model that combines the best of both GANs and VAEs is yet to be discovered.<br />
<br />
The work in this paper builds upon the theoretical work of Bousquet et al. (2017) [4]. The authors tackle generative modeling using optimal transport (OT). The OT cost is a way to measure the distance between probability distributions.<br />
<br />
More specifically, the OT problem can be stated as follows.<br />
<br />
Given a cost function <math>c : \mathcal{X} \times \mathcal{Y} \to \mathcal{R}</math>, one seeks a transport plan attaining <math> C(\mu, \nu) := \underset{\pi \in \Pi(\mu, \nu)}{inf} \int_{\mathcal{X} \times \mathcal{Y}}{c(x, y)\,d\pi(x, y)}</math>, where <math>\Pi(\mu, \nu)</math> is the set of joint measures on <math>\mathcal{X} \times \mathcal{Y}</math> with marginals <math>\mu</math> and <math>\nu</math>.<br />
<br />
The measures <math>\pi \in \Pi(\mu, \nu)</math> are called transport plans (or transference plans), and those achieving the infimum are called optimal transport plans. The classical interpretation of this problem is to minimize the total cost <math>C(\mu, \nu)</math> of transporting the mass distribution <math>\mu</math> onto the mass distribution <math>\nu</math>, where the cost of transporting one unit of mass from the point <math>x \in \mathcal{X}</math> to the point <math>y \in \mathcal{Y}</math> is given by the cost function <math>c(x, y)</math>.<br />
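<br />
To make the definition concrete, the following is a minimal sketch (our own illustration, not code from the paper) that computes <math>C(\mu, \nu)</math> for two empirical measures with the same number <math>n</math> of equally weighted points. In that special case an optimal plan can be taken to be a permutation (Birkhoff-von Neumann theorem), so for tiny <math>n</math> we can simply try every matching. The function names and toy points are hypothetical, and brute force is only practical for a handful of points.<br />
```python
from itertools import permutations


def discrete_ot_cost(xs, ys, cost):
    """OT cost between uniform empirical measures on xs and ys (same size n).

    With equal weights 1/n an optimal plan can be taken to be a permutation,
    so this searches all n! matchings (only feasible for very small n).
    """
    n = len(xs)
    if n != len(ys):
        raise ValueError("this sketch assumes equally sized samples")
    best_cost, best_perm = float("inf"), None
    for perm in permutations(range(n)):
        total = sum(cost(xs[i], ys[perm[i]]) for i in range(n)) / n
        if total < best_cost:
            best_cost, best_perm = total, perm
    return best_cost, best_perm


def squared_cost(x, y):
    # c(x, y) = |x - y|^2 on the real line
    return (x - y) ** 2


xs = [0.0, 1.0, 2.0]
ys = [2.5, 0.5, 1.5]
c, plan = discrete_ot_cost(xs, ys, squared_cost)
print(c, plan)  # 0.25 (1, 2, 0): each x is sent to the y that lies 0.5 to its right
```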
<br />
A beneficial feature of the OT cost is that it induces a much weaker topology than other costs, including the f-divergences associated with the original GAN algorithms. <br />
This feature is crucial in applications where the data are supported on low-dimensional manifolds in the input space. In that setting, stronger notions of distance such as f-divergences often max out and provide no useful gradients for training, because the data and model distributions may have disjoint supports. In comparison, the OT cost has been claimed to behave much more nicely [5, 8]. Nevertheless, GAN-like implementations of OT still require adding a constraint or a regularization term to the objective function.<br />
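<br />
A toy example in the spirit of the one used to motivate Wasserstein GAN [5] makes the "maxing out" concrete: take two point masses at 0 and at <math>\theta</math>. For every <math>\theta \neq 0</math> their supports are disjoint, so the Jensen-Shannon divergence stays at <math>\log 2</math> and gives no signal about how far apart they are, whereas the 1-Wasserstein distance equals <math>|\theta|</math> and shrinks smoothly as <math>\theta \to 0</math>. The sketch below checks this for discrete distributions; the helper names are our own.<br />
```python
import math


def js_divergence(p, q):
    """Jensen-Shannon divergence (natural log) between two discrete distributions.

    p and q map support points to probabilities; points with zero mass contribute 0.
    """
    support = set(p) | set(q)
    m = {x: 0.5 * (p.get(x, 0.0) + q.get(x, 0.0)) for x in support}

    def kl_to_m(r):
        return sum(r[x] * math.log(r[x] / m[x]) for x in r if r[x] > 0)

    return 0.5 * kl_to_m(p) + 0.5 * kl_to_m(q)


def w1_point_masses(a, b):
    # The only coupling of two point masses moves all the mass from a to b.
    return abs(a - b)


for theta in [2.0, 1.0, 0.1, 0.0]:
    p = {0.0: 1.0}
    q = {theta: 1.0}
    # JS is log 2 (about 0.6931) for every theta != 0, while W1 tracks |theta|.
    print(theta, round(js_divergence(p, q), 4), w1_point_masses(0.0, theta))
```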
<br />
==Original Contributions==<br />
Let <math>P_X</math> be the true but unknown data distribution, and let <math>P_G</math> be the latent variable model specified by the prior distribution <math>P_Z</math> of latent codes <math>Z \in \mathcal{Z}</math> and the generative model <math>P_G(X|Z)</math> of the data points <math>X \in \mathcal{X}</math> given <math>Z</math>. The goal of this paper is to minimize the OT cost <math>W_c(P_X, P_G)</math>.<br />
<br />
The main contributions are given below:<br />
<br />
* A new class of auto-encoders called Wasserstein Auto-Encoders (WAE). WAEs minimize the optimal transport cost <math>W_c(P_X, P_G)</math> for any cost function <math>c</math>. As with VAEs, the WAE objective is made up of two terms: the c-reconstruction cost and a regularizer <math>\mathcal{D}_Z(P_Z, Q_Z)</math> that penalizes the discrepancy between two distributions in <math>\mathcal{Z}</math>: <math>P_Z</math> and <math>Q_Z</math>. Here <math>Q_Z</math> is the distribution of encoded points, i.e. <math>Q_Z := \mathbb{E}_{P_X}[Q(Z|X)]</math>. Note that when <math>c</math> is the squared cost and the regularizer is the GAN objective, WAE is equivalent to the adversarial auto-encoders described in [2].<br />
<br />
* Experimental results of using WAE on MNIST and CelebA datasets with squared cost <math>c(x, y) = ||x - y||_2^2</math>. The results of these experiments show that WAEs have the good features of VAEs such as stable training, encoder-decoder architecture, and a nice latent manifold structure while simultaneously improving the quality of the generated samples.<br />
<br />
* Two different regularizers. One is based on GANs and adversarial training in the latent space <math>\mathcal{Z}</math>. The other is based on the maximum mean discrepancy (MMD), which is known to perform well when matching high-dimensional standard normal distributions. The second regularizer also turns the problem into a fully adversary-free min-min optimization problem, removing the need to tune a GAN.<br />
<br />
* The final contribution is the mathematical analysis used to derive the WAE objective. In particular, it shows that for generative models, the primal form of <math>W_c(P_X, P_G)</math> is equivalent to a problem involving the optimization of a probabilistic encoder <math>Q(Z|X)</math>.<br />
<br />
The paper provides an ostensibly simple recipe for a generative auto-encoder that, unlike the VAE, does not produce blurry samples. It provides what looks like an elegant and logical way to use the Wasserstein distance to set up the VAE/GAN problem.<br />
The paper gives three instructive comparisons with related VAE/GAN models and unifies them thematically: Adversarial Auto-encoders (AAE), Adversarial Variational Bayes (AVB), and the original Variational Auto-encoders (VAE). These connections arise in the case of random decoders: the paper introduces the idea with deterministic decoders and then extends it to random decoders, and the comparisons center on the VAE's regularizer, which AAE and AVB replace with a GAN.<br />
<br />
=Proposed Method=<br />
The method proposed by the authors uses a novel auto-encoder architecture to minimize the optimal transport cost <math>W_c(P_X, P_G)</math>. In the optimization problem that follows, the decoder tries to accurately reconstruct the data points as measured by the cost function <math>c</math>. The encoder tries to achieve two conflicting goals at the same time: (1) match the distribution of the encoded data points <math>Q_Z := \mathbb{E}_{P_X}[Q(Z|X)]</math> to the prior distribution <math>P_Z</math>, as measured by the divergence <math>\mathcal{D}_Z(P_Z, Q_Z)</math>, and (2) make sure that the encoded latent vectors contain enough information for the data points to be reconstructed with high quality. The figure below illustrates this:<br />
<br />
[[File:ka2khan_figure_1.png|800px|thumb|center|Figure 1]]<br />
<br />
Figure 1: Both VAE and WAE have objectives composed of two terms: the reconstruction cost and a regularizer that penalizes the divergence between <math>P_Z</math> and <math>Q_Z</math>. VAE forces <math>Q(Z|X = x)</math> to match <math>P_Z</math> for all the different training examples <math>x</math> drawn from <math>P_X</math>. In the figure above, every red ball, representing <math>Q(Z|X = x)</math> for one example, is forced to match <math>P_Z</math>, depicted as whitish triangles. The red balls therefore start to intersect, which causes problems with reconstruction. WAE, on the other hand, forces the mixture <math>Q_Z := \int{Q(Z|X)\ dP_X}</math> to match <math>P_Z</math>. As a result, the latent codes of different examples have a better chance of staying far from each other, which leads to higher reconstruction quality.<br />
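<br />
The difference between matching each <math>Q(Z|X=x)</math> and matching the mixture <math>Q_Z</math> can be checked numerically. In the hypothetical one-dimensional encoder below (our own toy construction), each <math>Q(Z|X=x)</math> is a narrow Gaussian <math>N(\mu(x), s^2)</math>, so no single conditional looks like the prior <math>N(0, 1)</math>; yet because the means <math>\mu(x)</math> are spread with variance <math>1 - s^2</math> across the data, the mixture <math>Q_Z</math> has mean 0 and variance 1. The encoder, the value <math>s = 0.1</math>, and the sample sizes are illustrative assumptions.<br />
```python
import random
import statistics


def sample_encoder(x, s=0.1):
    """Hypothetical stochastic encoder Q(Z | X = x) = N(mu(x), s^2) in 1D."""
    mu = x * (1.0 - s ** 2) ** 0.5  # input-dependent mean
    return random.gauss(mu, s)


def sample_aggregate_posterior(data, n_samples):
    """Draw from Q_Z = E_{P_X}[Q(Z|X)]: pick a data point, then encode it."""
    return [sample_encoder(random.choice(data)) for _ in range(n_samples)]


random.seed(0)
data = [random.gauss(0.0, 1.0) for _ in range(5000)]  # stand-in for samples of P_X

# A single conditional Q(Z | X = x) is far narrower than the N(0, 1) prior ...
one_code = [sample_encoder(data[0]) for _ in range(5000)]
print(round(statistics.stdev(one_code), 2))
# ... but the mixture over the data has roughly zero mean and unit variance.
zs = sample_aggregate_posterior(data, 20000)
print(round(statistics.mean(zs), 2), round(statistics.variance(zs), 2))
```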
<br />
==Preliminaries and Notations==<br />
The authors use calligraphic letters to denote sets (for example, <math>\mathcal{X}</math>), capital letters for random variables (for example, <math>X</math>), and lowercase letters for their values (for example, <math>x</math>). Probability distributions are also denoted with capital letters (for example, <math>P(X)</math>), and the corresponding densities are denoted with lowercase letters (for example, <math>p(x)</math>).<br />
<br />
Several measures of difference between probability distributions are also used. These include the f-divergences, given by <math>D_f(P_X \| P_G) := \int{f\left(\frac{p_X(x)}{p_G(x)}\right)p_G(x)}\,dx</math>, where <math>f:(0, \infty) \to \mathcal{R}</math> is any convex function satisfying <math>f(1) = 0</math>. This class includes the Kullback-Leibler (<math>D_{KL}</math>) and Jensen-Shannon (<math>D_{JS}</math>) divergences, which are also used.<br />
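<br />
For discrete distributions the integral becomes a sum, and particular choices of the generator <math>f</math> recover familiar divergences: <math>f(t) = t \log t</math> gives KL, and <math>f(t) = \tfrac{1}{2}\left[t \log t - (t+1)\log\tfrac{t+1}{2}\right]</math> gives Jensen-Shannon. The sketch below is our own illustration; it assumes the second distribution puts positive mass on every point so the ratio is defined.<br />
```python
import math


def f_divergence(p, q, f):
    """D_f(P || Q) = sum_x q(x) * f(p(x) / q(x)) for discrete P and Q.

    Assumes q(x) > 0 for every point x in the support.
    """
    return sum(qx * f(px / qx) for px, qx in zip(p, q))


def f_kl(t):
    # Generator of the KL divergence: f(t) = t log t (with 0 log 0 = 0)
    return t * math.log(t) if t > 0 else 0.0


def f_js(t):
    # Generator of the Jensen-Shannon divergence
    return 0.5 * (f_kl(t) - (t + 1) * math.log((t + 1) / 2))


p = [0.2, 0.5, 0.3]
q = [0.4, 0.4, 0.2]
kl_direct = sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))
print(f_divergence(p, q, f_kl), kl_direct)  # same value, up to rounding
print(f_kl(1.0), f_js(1.0))  # both generators satisfy f(1) = 0
```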
<br />
==Optimal Transport and its Dual Formations==<br />
<br />
A rich class of distance measures between probability distributions is motivated by the optimal transport problem. One formulation of this problem is Kantorovich's formulation:<br />
<br />
<math><br />
W_c(P_X, P_G) := \underset{\Gamma \in \mathcal{P}(X \sim P_X ,Y \sim P_G)}{inf} \mathbb{E}_{(X,Y) \sim \Gamma}[c(X,Y)],<br />
\text{where} \ c(x, y): \mathcal{X} \times \mathcal{X} \to \mathcal{R}_{+}<br />
</math><br />
<br />
is any measurable cost function and <math>\mathcal{P}(X \sim P_X, Y \sim P_G)</math> is the set of all joint distributions of <math>(X, Y)</math> with marginals <math>P_X</math> and <math>P_G</math>, respectively.<br />
<br />
A particularly interesting case is when <math>(\mathcal{X}, d)</math> is a metric space and <math>c(x, y) = d^p(x, y)</math> for <math>p \ge 1</math>. In this case <math>W_p</math>, the <math>p</math>-th root of <math>W_c</math>, is called the p-Wasserstein distance.<br />
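<br />
On the real line with <math>d(x, y) = |x - y|</math>, the p-Wasserstein distance between two empirical distributions with the same number of equally weighted samples has a simple closed form: an optimal plan matches the i-th smallest point of one sample with the i-th smallest point of the other. A minimal sketch under that one-dimensional, equal-size assumption (the function name is ours):<br />
```python
def wasserstein_p_1d(xs, ys, p=2.0):
    """W_p between uniform empirical measures on equally sized 1D samples.

    In one dimension the monotone (sorted) matching is optimal for
    c(x, y) = |x - y|^p with p >= 1, so no optimization is needed.
    """
    if len(xs) != len(ys):
        raise ValueError("this sketch assumes samples of equal size")
    if p < 1:
        raise ValueError("p must be >= 1")
    n = len(xs)
    w_c = sum(abs(x - y) ** p for x, y in zip(sorted(xs), sorted(ys))) / n
    return w_c ** (1.0 / p)  # W_p is the p-th root of W_c


print(wasserstein_p_1d([0.0, 1.0, 2.0], [2.5, 0.5, 1.5], p=2))  # 0.5
print(wasserstein_p_1d([0.0, 1.0, 2.0], [3.0, 4.0, 5.0], p=1))  # 3.0
```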
<br />
When <math>c(x, y) = d(x, y)</math> the following Kantorovich-Rubinstein duality holds:<br />
<br />
<math>W_1(P_X, P_G) = \underset{f \in \mathcal{F}_L}{sup}\ \mathbb{E}_{X \sim P_X}[f(X)] - \mathbb{E}_{Y \sim P_G}[f(Y)]</math><br />
where <math>\mathcal{F}_L</math> is the class of all bounded 1-Lipschitz functions on <math>(\mathcal{X}, d)</math>.<br />
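<br />
As a small worked example of the duality (our own illustration, not from the paper), take two point masses <math>P_X = \delta_0</math> and <math>P_G = \delta_\theta</math> on the real line with <math>d(x, y) = |x - y|</math>. The only coupling moves all the mass from 0 to <math>\theta</math>, so the primal side gives <math>W_1 = |\theta|</math>. On the dual side, every 1-Lipschitz <math>f</math> satisfies <math>f(0) - f(\theta) \le |\theta|</math>, and the bound is attained by <math>f(x) = -\operatorname{sign}(\theta)\, x</math> clipped to <math>[-|\theta|, |\theta|]</math> (which keeps it bounded and 1-Lipschitz), so both sides agree:<br />
<br />
<math><br />
W_1(\delta_0, \delta_\theta) = \underset{f \in \mathcal{F}_L}{sup}\ (f(0) - f(\theta)) = |\theta|.<br />
</math><br />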
<br />
==Application to Generative Models: Wasserstein auto-encoders==<br />
The intuition behind modern generative models like VAEs and GANs is that they try to minimize specific distance measures between the data distribution <math>P_X</math> and the model <math>P_G</math>. Unfortunately, with current knowledge and tools, most standard discrepancy measures are hard or even impossible to compute, especially when <math>P_X</math> is unknown and <math>P_G</math> is parametrized by deep neural networks. However, several tricks can be used to get around this difficulty.<br />
<br />
For minimizing the KL-divergence <math>D_{KL}(P_X, P_G)</math>, or equivalently maximizing the marginal log-likelihood <math>\mathbb{E}_{P_X}[\log p_G(X)]</math>, one can use the well-known variational lower bound, which provides a theoretically grounded framework and has been used quite successfully by VAEs. In the general case of minimizing an f-divergence <math>D_f(P_X, P_G)</math>, one can use its dual formulation together with adversarial training, as in f-GANs. Finally, the OT cost <math>W_1(P_X, P_G)</math> (the case <math>c = d</math>) can be minimized using the Kantorovich-Rubinstein duality expressed as an adversarial objective; Wasserstein GAN implements this idea.<br />
<br />
In this paper, the authors focus on latent variable models <math>P_G</math> defined by a two-step procedure. First, a code <math>Z</math> is sampled from a fixed distribution <math>P_Z</math> on a latent space <math>\mathcal{Z}</math>. Second, <math>Z</math> is mapped to the image <math>X \in \mathcal{X} = \mathcal{R}^d</math> with a (possibly random) transformation. This gives a density of the form<br />
<br />
<math><br />
p_G(x) := \int\limits_{\mathcal{Z}}{p_G(x|z)p_Z(z)}dz,\ \forall x \in \mathcal{X}, <br />
</math><br />
<br />
provided all the probabilities involved are properly defined. To keep things simple, the authors focus on non-random decoders, i.e., generative models <math>P_G(X|Z)</math> that deterministically map <math>Z</math> to <math>X = G(Z)</math> using a fixed map <math>G: \mathcal{Z} \to \mathcal{X}</math>. Similar results hold for random decoders, as the authors show in Appendix B.1.<br />
<br />
Working under this model, the authors find that the OT cost takes a much simpler form, because the transportation plan factors through the map <math>G</math>: instead of finding a coupling <math>\Gamma</math> between two random variables in the space <math>\mathcal{X}</math>, one distributed according to <math>P_X</math> and the other according to <math>P_G</math>, it is enough to find a conditional distribution <math>Q(Z|X)</math> such that its <math>Z</math> marginal, <math>Q_Z(Z) := \mathbb{E}_{X \sim P_X}[Q(Z|X)]</math>, is the same as the prior distribution <math>P_Z</math>. This is formalized by the following theorem, which was proven in [4].<br />
<br />
'''Theorem 1.''' For <math>P_G</math> defined as above with deterministic <math>P_G(X|Z)</math> and any function <math>G:\mathcal{Z} \to \mathcal{X}</math>,<br />
<br />
<math><br />
\underset{\Gamma \in \mathcal{P}(X \sim P_X ,Y \sim P_G)}{inf} \mathbb{E}_{(X,Y) \sim \Gamma}[c(X,Y)] = \underset{Q: Q_Z = P_Z}{inf} \mathbb{E}_{P_X} \mathbb{E}_{Q(Z|X)}[c(X, G(Z))]<br />
</math><br />
<br />
where <math>Q_Z</math> is the marginal distribution of <math>Z</math> when <math>X \sim P_X</math> and <math>Z \sim Q(Z|X)</math>.<br />
<br />
According to the authors, this result allows optimization over random encoders <math>Q(Z|X)</math> instead of over all couplings of <math>X</math> and <math>Y</math>. Both problems are still constrained. To find a numerical solution, the authors relax the constraint <math>Q_Z = P_Z</math> by adding a penalty term to the objective. This gives the WAE objective:<br />
<br />
<math><br />
D_{WAE}(P_X, P_G) := \underset{Q(Z|X) \in \mathcal{Q}}{inf} \mathbb{E}_{P_X} \mathbb{E}_{Q(Z|X)}[c(X, G(Z))] + \lambda \cdot \mathcal{D}_Z(Q_Z, P_Z)<br />
</math><br />
<br />
where <math>\mathcal{Q}</math> is any nonparametric set of probabilistic encoders, <math>\mathcal{D}_Z</math> is an arbitrary measure of distance between <math>Q_Z</math> and <math>P_Z</math>, and <math>\lambda > 0</math> is a hyperparameter. As with VAEs, the<br />
authors propose using deep neural networks to parameterize both the encoder <math>Q</math> and the decoder <math>G</math>. Note that, unlike VAEs, WAEs allow non-random encoders that deterministically map their inputs to latent codes.<br />
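<br />
On a minibatch, the relaxed objective is estimated as an average reconstruction cost plus <math>\lambda</math> times a latent penalty. The sketch below is a toy, hypothetical instance with a deterministic 1D encoder and decoder (our own linear maps), the squared cost, and, purely as a stand-in for <math>\mathcal{D}_Z</math>, a simple moment-matching penalty on the encoded batch. That stand-in is not one of the two regularizers the authors propose; it only shows how the two terms combine.<br />
```python
import random
import statistics


def squared_cost(x, y):
    return (x - y) ** 2


def moment_penalty(codes, prior_samples):
    """Stand-in D_Z: squared gaps between the means and variances of two samples."""
    dm = statistics.mean(codes) - statistics.mean(prior_samples)
    dv = statistics.variance(codes) - statistics.variance(prior_samples)
    return dm ** 2 + dv ** 2


def wae_objective(batch, encode, decode, cost, penalty, prior_samples, lam):
    """Minibatch estimate of E_{P_X} E_{Q(Z|X)}[c(X, G(Z))] + lam * D_Z(Q_Z, P_Z)."""
    codes = [encode(x) for x in batch]  # deterministic encoder
    recon = statistics.mean(cost(x, decode(z)) for x, z in zip(batch, codes))
    return recon + lam * penalty(codes, prior_samples)


# Hypothetical encoder/decoder pairs; both reconstruct the data exactly.
def standardize(x):
    return (x - 3.0) / 2.0


def unstandardize(z):
    return 2.0 * z + 3.0


def identity(v):
    return v


random.seed(1)
batch = [random.gauss(3.0, 2.0) for _ in range(256)]  # toy data from P_X
prior = [random.gauss(0.0, 1.0) for _ in range(256)]  # samples from P_Z = N(0, 1)

good = wae_objective(batch, standardize, unstandardize, squared_cost,
                     moment_penalty, prior, lam=10.0)
bad = wae_objective(batch, identity, identity, squared_cost,
                    moment_penalty, prior, lam=10.0)
print(good < bad)  # only the first pair also keeps its codes close to P_Z
```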
<br />
The authors propose two different regularizers <math>\mathcal{D}_Z(Q_Z, P_Z)</math>:<br />
<br />
===GAN-based <math>\mathcal{D}_z</math>===<br />
One option is to use <math>\mathcal{D}_Z(Q_Z, P_Z) = \mathcal{D}_{JS}(Q_Z, P_Z)</math> and estimate it with adversarial training. In particular, a discriminator (adversary) is trained in the latent space <math>\mathcal{Z}</math> to separate "true" points sampled from <math>P_Z</math> from "fake" ones sampled from <math>Q_Z</math>. This leads to WAE-GAN, described in Algorithm 1 below. Even though WAE-GAN still relies on max-min optimization, it has the advantage of moving the adversary from the input (pixel) space <math>\mathcal{X}</math> to the latent space <math>\mathcal{Z}</math>. Moreover, the target latent distribution <math>P_Z</math> may have a nice shape with a single mode (for a Gaussian prior), which makes matching much easier than matching an unknown, complex, and possibly multi-modal distribution, as is usually the case in GANs. This also motivates the second penalty.<br />
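<br />
Following our reading of Algorithm 1, each WAE-GAN iteration uses two objectives computed from the latent discriminator's outputs: the discriminator ascends <math>\frac{\lambda}{n}\sum_i \log D(z_i) + \log(1 - D(\tilde{z}_i))</math>, where the <math>z_i</math> are prior samples and the <math>\tilde{z}_i</math> are encoded samples, and the encoder and decoder descend <math>\frac{1}{n}\sum_i c(x_i, G(\tilde{z}_i)) - \lambda \log D(\tilde{z}_i)</math>. The sketch below only evaluates these two quantities from given discriminator probabilities; the networks and gradient steps are omitted, and the function names are hypothetical.<br />
```python
import math


def discriminator_objective(d_prior, d_encoded, lam):
    """Quantity the latent discriminator ascends.

    d_prior[i]   = D(z_i)  for z_i ~ P_Z         ("true" latent points)
    d_encoded[i] = D(zt_i) for zt_i ~ Q(Z | x_i)  ("fake" encoded points)
    """
    n = len(d_prior)
    return lam / n * sum(math.log(dp) + math.log(1.0 - de)
                         for dp, de in zip(d_prior, d_encoded))


def encoder_decoder_loss(recon_costs, d_encoded, lam):
    """Quantity the encoder and decoder descend: mean of c(x_i, G(zt_i)) - lam * log D(zt_i)."""
    n = len(recon_costs)
    return sum(c - lam * math.log(de) for c, de in zip(recon_costs, d_encoded)) / n


# If the discriminator cannot tell codes from prior samples (D = 0.5 everywhere),
# the adversarial part of the encoder loss is lam * log 2 per point.
print(encoder_decoder_loss([0.0, 0.0], [0.5, 0.5], lam=1.0))  # about 0.6931
```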
<br />
===MMD-based <math>\mathcal{D}_z</math>===<br />
For a positive-definite reproducing kernel <math>k: \mathcal{Z} \times \mathcal{Z} &rarr; \mathcal{R}</math>, the maximum mean discrepancy (MMD) is defined as:<br />
<br />
<center><math><br />
MMD_k(P_Z, Q_Z) = \left \Vert \int \limits_{\mathcal{Z}} {k(z, \cdot)dP_Z(z)} - \int \limits_{\mathcal{Z}} {k(z, \cdot)dQ_Z(z)} \right \|_{\mathcal{H}_k}<br />
</math>,</center><br />
<br />
where <math>\mathcal{H}_k</math> is the RKHS (reproducing kernel Hilbert space) of real-valued functions mapping <math>\mathcal{Z}</math> to <math>\mathcal{R}</math>. If <math>k</math> is characteristic, then <math>MMD_k</math> defines a metric and can be used as a distance measure. The authors propose to use <math>\mathcal{D}_Z(P_Z, Q_Z) = MMD_k(P_Z, Q_Z)</math>. The squared MMD has an unbiased U-statistic estimator, which can be used together with stochastic gradient descent (SGD) methods. This gives WAE-MMD, described in Algorithm 2 below. Note that MMD is known to perform well when matching high-dimensional standard normal distributions, so this penalty is expected to work well when the prior <math>P_Z</math> is Gaussian.<br />
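<br />
The unbiased U-statistic estimate of the squared MMD between prior samples <math>z_1, \dots, z_n</math> and encoded samples <math>\tilde{z}_1, \dots, \tilde{z}_m</math> averages the kernel over distinct pairs within each sample and subtracts twice the average over all cross pairs. The sketch below assumes an inverse multiquadratics kernel <math>k(z, z') = C / (C + \|z - z'\|^2)</math>, one standard characteristic kernel; the constant <math>C = 1</math> and the toy codes are arbitrary choices. Being unbiased, the estimate can come out slightly negative when the two samples are very similar.<br />
```python
def imq_kernel(a, b, c=1.0):
    """Inverse multiquadratics kernel C / (C + ||a - b||^2) for vectors given as lists."""
    sq_dist = sum((ai - bi) ** 2 for ai, bi in zip(a, b))
    return c / (c + sq_dist)


def mmd2_unbiased(zs, zt, kernel=imq_kernel):
    """Unbiased U-statistic estimate of MMD^2 between samples zs ~ P_Z and zt ~ Q_Z."""
    n, m = len(zs), len(zt)
    within_s = sum(kernel(zs[i], zs[j]) for i in range(n) for j in range(n) if i != j)
    within_t = sum(kernel(zt[i], zt[j]) for i in range(m) for j in range(m) if i != j)
    cross = sum(kernel(a, b) for a in zs for b in zt)
    return (within_s / (n * (n - 1)) + within_t / (m * (m - 1))
            - 2.0 * cross / (n * m))


prior = [[0.0, 0.1], [0.5, -0.3], [-0.4, 0.2], [0.1, -0.1]]
close = [[0.1, 0.0], [0.4, -0.2], [-0.3, 0.3], [0.0, -0.2]]
far = [[5.0, 5.0], [5.5, 4.7], [4.6, 5.2], [5.1, 4.9]]
print(mmd2_unbiased(prior, close) < mmd2_unbiased(prior, far))  # True
```
In WAE-MMD, this estimate multiplied by <math>\lambda</math> takes the place of the adversarial term in the loss that the encoder and decoder minimize, so no discriminator is needed.<br />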
<br />
[[File:ka2khan_figure_2.png|800px|thumb|center|Algorithms- WAE-GAN on left and WAE-MMD on right]]<br />
<br />
=Related Work=<br />
==Literature on auto-encoders==<br />
Classical unregularized auto-encoders have an objective function that only minimizes the reconstruction cost. As a result, distinct data points are encoded into distinct zones scattered chaotically across the latent space <math>\mathcal{Z}</math>. The latent space <math>\mathcal{Z}</math> then contains huge "holes" on which the decoder <math>P_G(X|Z)</math> has never been trained. In general, encoders trained this way do not provide very useful representations, and sampling from the latent space <math>\mathcal{Z}</math> becomes difficult [12].<br />
<br />
VAEs [1] minimize an upper bound on the KL-divergence <math>D_{KL}(P_X, P_G)</math>, which consists of the reconstruction cost and the regularizer <math>\mathbb{E}_{P_X}[D_{KL}(Q(Z|X), P_Z)]</math>. The regularizer penalizes the discrepancy between the encoding distribution of every individual training image and the prior <math>P_Z</math>. This differs from WAE, which penalizes only the discrepancy between the overall encoded distribution <math>Q_Z</math> and the prior. In addition, VAEs require non-degenerate (i.e., non-deterministic) Gaussian encoders and random decoders. A later paper [11] proposed a method that allows the use of non-Gaussian encoders with VAEs. WAE, in contrast, minimizes <math>W_{c}(P_X, P_G)</math> and allows both probabilistic and deterministic encoder-decoder pairs.<br />
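<br />
For the usual VAE choice of a diagonal Gaussian encoder <math>Q(Z|X=x) = N(\mu(x), \operatorname{diag}\sigma^2(x))</math> and a standard normal prior, the per-example regularizer has the closed form <math>\tfrac{1}{2}\sum_j (\sigma_j^2 + \mu_j^2 - 1 - \log \sigma_j^2)</math>. The sketch below (function name ours) shows why this term pulls every individual conditional toward the prior: it is zero only when <math>\mu(x) = 0</math> and <math>\sigma(x) = 1</math> for that particular <math>x</math>, so sharp, well-separated codes are penalized.<br />
```python
import math


def kl_diag_gaussian_to_std_normal(mu, sigma):
    """KL( N(mu, diag(sigma^2)) || N(0, I) ) for a single encoded example."""
    return 0.5 * sum(s ** 2 + m ** 2 - 1.0 - math.log(s ** 2)
                     for m, s in zip(mu, sigma))


print(kl_diag_gaussian_to_std_normal([0.0, 0.0], [1.0, 1.0]))  # 0.0
print(kl_diag_gaussian_to_std_normal([2.0, 0.0], [0.1, 1.0]))  # about 3.81: a sharp, shifted code
```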
<br />
With the squared cost and a GAN-based latent penalty, WAE recovers AAE [2]; more generally, it extends AAE in two ways: it can use any cost function <math>c</math> in the input space <math>\mathcal{X}</math>, and any discrepancy measure <math>\mathcal{D}_Z</math> in the latent space <math>\mathcal{Z}</math>, not only the adversarial one.<br />
<br />
Another regularized auto-encoder, InfoVAE [14], has an objective similar to that of [4] but was derived from different motivations and arguments.<br />
<br />
WAEs explicitly define the cost function <math>c(x,y)</math>, whereas VAEs define it implicitly through the negative log-likelihood term. In theory, this term can induce any cost function, but in practice it may require estimating a normalizing constant that can differ across values of <math>z</math>.<br />
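<br />
A standard example of such an implicit cost: with a Gaussian decoder <math>P_G(X|Z=z) = N(G(z), \sigma^2 I)</math> of fixed variance, the negative log-likelihood equals <math>\|x - G(z)\|^2 / (2\sigma^2)</math> plus a constant that does not depend on <math>z</math>, so it induces the squared cost up to scale. The sketch checks this numerically from the density itself; the decoder output, <math>\sigma</math>, and the points are arbitrary toy values. If <math>\sigma</math> were allowed to depend on <math>z</math>, the normalizing term would change with <math>z</math>, which is the complication mentioned above.<br />
```python
import math


def gaussian_nll(x, mean, sigma):
    """-log of the product of independent 1D normal densities N(x_i; mean_i, sigma^2)."""
    nll = 0.0
    for xi, mi in zip(x, mean):
        density = math.exp(-(xi - mi) ** 2 / (2 * sigma ** 2)) / (sigma * math.sqrt(2 * math.pi))
        nll -= math.log(density)
    return nll


x = [1.0, -2.0, 0.5]
g_z = [0.8, -1.5, 0.0]  # toy decoder output G(z)
sigma = 0.5

sq_cost = sum((xi - gi) ** 2 for xi, gi in zip(x, g_z))
const = 0.5 * len(x) * math.log(2.0 * math.pi * sigma ** 2)  # does not depend on z
print(math.isclose(gaussian_nll(x, g_z, sigma), sq_cost / (2 * sigma ** 2) + const))  # True
```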
<br />
==Literature on Optimal Transport (OT)==<br />
[15] provides methods for computing the OT cost on large-scale data using SGD and sampling. WGAN [5] is a generative model that minimizes the 1-Wasserstein distance <math>W_1(P_X, P_G)</math>. The WGAN algorithm does not provide an encoder and cannot easily be applied to an arbitrary cost <math>c</math>. The model proposed in [5] uses the dual form; in contrast, the model proposed in this paper uses the primal form, which allows any cost function <math>c</math> and naturally comes with an encoder. <br />
<br />
To compute <math>W_c(P_X, P_G)</math> or <math>W_1(P_X, P_G)</math>, a model needs to handle various non-trivial constraints, and various methods have been proposed in the literature ([5], [2], [8], [16], [15], [17], [18]) to get around this difficulty.<br />
<br />
==Literature on GANs==<br />
Many of the GAN variants proposed in the literature come without an encoder; examples include WGAN and f-GAN. These models fall short in applications that need to map data back into the latent space in order to use the learned manifold.<br />
<br />
Numerous models proposed in the literature try to combine the adversarial training of GANs with auto-encoder architectures; examples include [19], [20], [21], and [22]. Reproducing kernels have also been used in the context of GANs ([23], [24]).<br />
<br />
=Experiments=<br />
The authors ran experiments to empirically evaluate the proposed WAE model. <br />
<br />
'''Experimental setup'''<br />
<br />
For the experimental setup, the authors used an isotropic Gaussian prior <math> \small P_Z</math> over a Euclidean latent space and the squared cost function <math> \small c(x,y) = ||x - y||_2^2</math> for data points.<br />
Deterministic encoder-decoder pairs were used. The authors conducted experiments on two real-world datasets: (1) MNIST [27], made up of 70k images, and (2) CelebA [28], consisting of approximately 203k images. For test reconstructions and interpolations, a pair of held-out images <math>(x,y)</math> from the test set is auto-encoded (separately) to produce <math>(z_x, z_y)</math> in the latent space.<br />
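<br />
The interpolation protocol can be sketched as follows, assuming a trained deterministic encoder and decoder. Here <code>encode</code> and <code>decode</code> are hypothetical stand-ins (toy linear maps so the snippet runs), and decoding evenly spaced points on the straight line between <math>z_x</math> and <math>z_y</math> is our assumption about how the intermediate codes are chosen.<br />
```python
def lerp(a, b, t):
    """Point at fraction t along the straight line from vector a to vector b."""
    return [(1.0 - t) * ai + t * bi for ai, bi in zip(a, b)]


def interpolate_pair(x, y, encode, decode, steps=5):
    """Encode x and y separately, then decode evenly spaced points between their codes."""
    z_x, z_y = encode(x), encode(y)
    ts = [k / (steps - 1) for k in range(steps)]
    return [decode(lerp(z_x, z_y, t)) for t in ts]


# Toy stand-ins for a trained encoder/decoder pair (hypothetical, 2D data and codes).
def encode(v):
    return [0.5 * v[0], 0.5 * v[1]]


def decode(z):
    return [2.0 * z[0], 2.0 * z[1]]


for image in interpolate_pair([0.0, 4.0], [4.0, 0.0], encode, decode):
    print(image)  # [0.0, 4.0], [1.0, 3.0], [2.0, 2.0], [3.0, 1.0], [4.0, 0.0]
```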
<br />
The main evaluation criteria were to see if the WAE model can simultaneously achieve: <br />
<br />
<ol><br />
<li>accurate reconstruction of the data points</li><br />
<li>reasonable geometry of the latent manifold</li><br />
<li>generation of high quality random samples</li><br />
</ol><br />
<br />
For the model to generalize well, (1) and (2) should hold on both the training and test sets.<br />
<br />
The proposed model achieves reasonably good results, as highlighted in the figures below:<br />
<br />
[[File:ka2khan_figure_3.png|800px|thumb|center|Using CelebA dataset]]<br />
<br />
[[File:ka2khan_figure_4.png|800px|thumb|center|Using CelebA dataset, FID (Fréchet Inception Distance [32]): smaller is better, sharpness: larger is better]]<br />
<br />
=Conclusion=<br />
The authors proposed a new class of algorithms for building generative models, called Wasserstein Auto-encoders, based on the optimal transport cost. They related the new model to existing probabilistic modeling techniques and evaluated it empirically on two real-world datasets. Comparing their results with those of VAEs on the same datasets, they showed that WAEs generate samples of higher quality while keeping stable training and good reconstruction quality of the data points.<br />
<br />
The authors state that in future work they will further explore criteria for matching the encoding distribution <math>Q_Z</math> to the prior distribution <math>P_Z</math>, evaluate whether it is feasible to adversarially train the cost function <math>c</math> in the input space <math>\mathcal{X}</math>, and carry out a theoretical analysis of the dual formulations for WAE-GAN and WAE-MMD.<br />
<br />
=Future Work=<br />
Building on this paper, Kolouri et al. [34] introduced another generative model based on optimal transport, which measures the distance between two probability distributions by the cost of transporting one distribution onto the other (hence the name). Their model, called the Sliced-Wasserstein Auto-encoder (SWAE), is simple to implement and provides the capabilities of Wasserstein Auto-encoders.<br />
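<br />
The core idea behind the sliced-Wasserstein distance used by SWAE [34] is to project both samples onto random one-dimensional directions, where Wasserstein distances reduce to sorting, and to average over directions. The following is a minimal Monte Carlo sketch for equally sized samples; the number of projections, the squared cost, and the function names are our assumptions, and this is not the code of [34].<br />
```python
import math
import random


def random_direction(dim, rng):
    """Uniformly random unit vector (a normalized Gaussian sample)."""
    v = [rng.gauss(0.0, 1.0) for _ in range(dim)]
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def sliced_w2_squared(xs, ys, n_projections=50, seed=0):
    """Monte Carlo estimate of the squared sliced 2-Wasserstein distance.

    Assumes xs and ys contain the same number of points of the same dimension.
    """
    rng = random.Random(seed)
    dim, n = len(xs[0]), len(xs)
    total = 0.0
    for _ in range(n_projections):
        theta = random_direction(dim, rng)
        px = sorted(sum(a * t for a, t in zip(x, theta)) for x in xs)
        py = sorted(sum(a * t for a, t in zip(y, theta)) for y in ys)
        total += sum((a - b) ** 2 for a, b in zip(px, py)) / n  # 1D W_2^2 via sorting
    return total / n_projections


codes = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]
shifted = [[x + 3.0, y] for x, y in codes]
print(sliced_w2_squared(codes, codes), sliced_w2_squared(codes, shifted) > 0.0)  # 0.0 True
```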
<br />
([https://openreview.net/forum?id=HkL7n1-0b]) The results on the MNIST and CelebA datasets look convincing. However, the paper could include additional evaluation comparing the adversarial loss with the more straightforward MMD metric, along with a discussion of their pros and cons. Evaluating and comparing closely related auto-encoder solutions is challenging, so the authors could also design demonstrative experiments showing where the Wasserstein distance helps and what its potential limitations are.<br />
<br />
<br />
<br />
=Critique=<br />
<br />
Although this paper presents some empirical tests that explain its method appropriately, it would benefit from clearer notation, including details of the architectures used in the experiments. It could also benefit from comparisons between its results and those of similar works. As pointed out by a reviewer, the closest work to this paper is the adversarial variational Bayes framework of Mescheder et al. [11], which also attempts to unify VAEs and GANs. The authors describe the conceptual differences from that approach and their advantages over it, but it would be beneficial to include an empirical comparison in the results section.<br />
Moreover, the improvement in performance over the previous VAE algorithm is not substantial. Testing the method on a wider variety of datasets would give a clearer picture of its performance. However, the methodology is flexible and unifies several other types of algorithms, which is a significant benefit.<br />
<br />
=References=<br />
[1] D. P. Kingma and M. Welling. Auto-encoding variational Bayes. In ICLR, 2014.<br />
<br />
[2] A. Makhzani, J. Shlens, N. Jaitly, and I. Goodfellow. Adversarial autoencoders. In ICLR, 2016.<br />
<br />
[3] Ian Goodfellow, Jean Pouget-Abadie, Mehdi Mirza, Bing Xu, David Warde-Farley, Sherjil Ozair, Aaron Courville, and Yoshua Bengio. Generative adversarial nets. In NIPS, pages 2672–2680, 2014.<br />
<br />
[4] O. Bousquet, S. Gelly, I. Tolstikhin, C. J. Simon-Gabriel, and B. Schölkopf. From optimal transport to generative modeling: the VEGAN cookbook, 2017.<br />
<br />
[5] M. Arjovsky, S. Chintala, and L. Bottou. Wasserstein GAN, 2017.<br />
<br />
[6] C. Villani. Topics in Optimal Transportation. AMS Graduate Studies in Mathematics, 2003.<br />
<br />
[7] Sebastian Nowozin, Botond Cseke, and Ryota Tomioka. f-GAN: Training generative neural samplers using variational divergence minimization. In NIPS, 2016.<br />
<br />
[8] I. Gulrajani, F. Ahmed, M. Arjovsky, V. Dumoulin, and A. Courville. Improved training of Wasserstein GANs, 2017.<br />
<br />
[9] A. Gretton, K. M. Borgwardt, M. J. Rasch, B. Schölkopf, and A. J. Smola. A kernel two-sample test. Journal of Machine Learning Research, 13:723–773, 2012.<br />
<br />
[10] F. Liese and K.-J. Miescke. Statistical Decision Theory. Springer, 2008.<br />
<br />
[11] L. Mescheder, S. Nowozin, and A. Geiger. Adversarial variational bayes: Unifying variational autoencoders and generative adversarial networks, 2017.<br />
<br />
[12] Y. Bengio, A. Courville, and P. Vincent. Representation learning: A review and new perspectives. Pattern Analysis and Machine Intelligence, 35, 2013.<br />
<br />
[13] M. D. Hoffman and M. Johnson. ELBO surgery: yet another way to carve up the variational evidence lower bound. In NIPS Workshop on Advances in Approximate Bayesian Inference, 2016.<br />
<br />
[14] S. Zhao, J. Song, and S. Ermon. InfoVAE: Information maximizing variational autoencoders, 2017.<br />
<br />
[15] A. Genevay, M. Cuturi, G. Peyré, and F. R. Bach. Stochastic optimization for large-scale optimal transport. In Advances in Neural Information Processing Systems, pages 3432–3440, 2016. <br />
<br />
[16] M. Cuturi. Sinkhorn distances: Lightspeed computation of optimal transport. In Advances in Neural Information Processing Systems, pages 2292–2300, 2013.<br />
<br />
[17] Lenaic Chizat, Gabriel Peyré, Bernhard Schmitzer, and François-Xavier Vialard. Unbalanced optimal transport: geometry and Kantorovich formulation. arXiv preprint arXiv:1508.05216, 2015.<br />
<br />
[18] Matthias Liero, Alexander Mielke, and Giuseppe Savaré. Optimal entropy-transport problems and a new Hellinger-Kantorovich distance between positive measures. arXiv preprint arXiv:1508.07941, 2015.<br />
<br />
[19] J. Zhao, M. Mathieu, and Y. LeCun. Energy-based generative adversarial network. In ICLR, 2017.<br />
<br />
[20] V. Dumoulin, I. Belghazi, B. Poole, A. Lamb, M. Arjovsky, O. Mastropietro, and A. Courville. Adversarially learned inference. In ICLR, 2017.<br />
<br />
[21] D. Ulyanov, A. Vedaldi, and V. Lempitsky. It takes (only) two: Adversarial generator-encoder networks, 2017.<br />
<br />
[22] D. Berthelot, T. Schumm, and L. Metz. BEGAN: Boundary equilibrium generative adversarial networks, 2017.<br />
<br />
[23] Y. Li, K. Swersky, and R. Zemel. Generative moment matching networks. In ICML, 2015. <br />
<br />
[24] G. K. Dziugaite, D. M. Roy, and Z. Ghahramani. Training generative neural networks via maximum mean discrepancy optimization. In UAI, 2015.<br />
<br />
[25] S. Reddi, A. Ramdas, A. Singh, B. Poczos, and L. Wasserman. On the high-dimensional power of a linear-time two sample test under mean-shift alternatives. In AISTATS, 2015.<br />
<br />
[26] C. L. Li, W. C. Chang, Y. Cheng, Y. Yang, and B. Poczos. MMD GAN: Towards deeper understanding of moment matching network, 2017.<br />
<br />
[27] Y. LeCun, L. Bottou, Y. Bengio, and P. Haffner. Gradient-based learning applied to document recognition. In Proceedings of the IEEE, volume 86(11), pages 2278–2324, 1998.<br />
<br />
[28] Ziwei Liu, Ping Luo, Xiaogang Wang, and Xiaoou Tang. Deep learning face attributes in the wild. In Proceedings of International Conference on Computer Vision (ICCV), 2015.<br />
<br />
[29] D. P. Kingma and J. Ba. Adam: A method for stochastic optimization, 2014.<br />
<br />
[30] A. Radford, L. Metz, and S. Chintala. Unsupervised representation learning with deep convolutional generative adversarial networks. In ICLR, 2016.<br />
<br />
[31] S. Ioffe and C. Szegedy. Batch normalization: Accelerating deep network training by reducing internal covariate shift, 2015.<br />
<br />
[32] Martin Heusel, Hubert Ramsauer, Thomas Unterthiner, Bernhard Nessler, Günter Klambauer, and Sepp Hochreiter. GANs trained by a two time-scale update rule converge to a nash equilibrium. arXiv preprint arXiv:1706.08500, 2017.<br />
<br />
[33] B. Poole, A. Alemi, J. Sohl-Dickstein, and A. Angelova. Improved generator objectives for GANs, 2016.<br />
<br />
[34] S. Kolouri, C. E. Martin, and G. K. Rohde. Sliced-wasserstein autoencoder: An embarrassingly simple generative model. arXiv preprint arXiv:1804.01947, 2018.</div>Gsahuhttp://wiki.math.uwaterloo.ca/statwiki/index.php?title=Wasserstein_Auto-encoders&diff=41933Wasserstein Auto-encoders2018-11-29T23:14:59Z<p>Gsahu: /* MMD-based \mathcal{D}_z */</p>
<hr />
<div>The first version of this work was published in 2017, and this version (the third revision) was presented at ICLR 2018. Source code for the first version is available [https://github.com/tolstikhin/wae here].<br />
<br />
=Introduction=<br />
Early successes in the field of representation learning were based on supervised approaches, which used large labeled datasets to achieve impressive results. By contrast, popular unsupervised generative modeling methods mainly consisted of probabilistic approaches focusing on low-dimensional data. In recent years, several models have been proposed that try to combine these two approaches. One popular method is the variational auto-encoder (VAE). VAEs are theoretically elegant, but they have the major drawback of generating blurry samples when used to model natural images. Generative adversarial networks (GANs) produce much sharper samples but have their own problems. These include the lack of an encoder, harder training, and "mode collapse", i.e., the inability of the model to capture all the variability in the true data distribution. There has been much recent activity around proposing and evaluating GAN architectures and combining VAEs with GANs, but a model that combines the best of both is yet to be found.<br />
<br />
This paper builds upon the theoretical work of Bousquet et al. [2017] [4]. The authors tackle generative modeling using optimal transport (OT). The OT cost is a way of measuring the distance between probability distributions.<br />
<br />
More precisely, the OT problem is stated as follows.<br />
<br />
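Given a cost function <math>c : \mathcal{X} \times \mathcal{Y} \to \mathcal{R}</math>, the problem is to compute <math>C(\mu, \nu) := \inf_{\pi \in \Pi(\mu, \nu)} \int_{\mathcal{X} \times \mathcal{Y}} c(x, y)\, d\pi(x, y)</math>, where <math>\Pi(\mu, \nu)</math> is the set of all joint distributions on <math>\mathcal{X} \times \mathcal{Y}</math> with marginals <math>\mu</math> and <math>\nu</math>.<br />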
<br />
The measures <math>\pi \in \Pi(\mu, \nu)</math> are called transport plans or transference plans, and those achieving the infimum are called optimal transport plans. In the classical interpretation, the goal is to minimize the total cost <math>C(\mu, \nu)</math> of transporting the mass distribution <math>\mu</math> onto the mass distribution <math>\nu</math>. The cost of transporting one unit of mass from the point <math>x \in \mathcal{X}</math> to the point <math>y \in \mathcal{Y}</math> is given by the cost function <math>c(x, y)</math>.<br />
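<br />
The sketch below is a toy illustration of the definition above, not a method from the paper. It solves the problem exactly when <math>\mu</math> and <math>\nu</math> are uniform empirical measures with the same number n of atoms. In that special case an optimal plan can be taken to be a permutation (Birkhoff-von Neumann theorem), so brute-force enumeration suffices for tiny n. The function names <code>ot_cost_uniform</code> and <code>squared</code> are hypothetical, and the approach is far too slow for realistic data.<br />

```python
from itertools import permutations
from typing import Callable, Sequence


def ot_cost_uniform(xs: Sequence[float], ys: Sequence[float],
                    cost: Callable[[float, float], float]) -> float:
    """Exact OT cost C(mu, nu) between two uniform empirical measures.

    mu puts mass 1/n on each point of xs and nu puts mass 1/n on each point
    of ys. For such measures an optimal plan can be chosen to be a
    permutation (Birkhoff-von Neumann), so all n! matchings are enumerated.
    Only feasible for very small n.
    """
    n = len(xs)
    if n == 0 or n != len(ys):
        raise ValueError("this sketch assumes two non-empty samples of equal size")
    best = min(
        sum(cost(xs[i], ys[perm[i]]) for i in range(n))
        for perm in permutations(range(n))
    )
    return best / n


def squared(x: float, y: float) -> float:
    # Squared cost c(x, y) = (x - y)^2 on the real line
    return (x - y) ** 2
```

For example, with the squared cost, transporting the points {0, 1, 2} onto {1, 2, 3} costs 1.0, because every unit of mass moves one step to the right.<br />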
<br />
A useful property of the OT cost is that it induces a much weaker topology than other costs, including the f-divergences associated with the original GAN algorithms.<br />
This property is crucial in applications where the data are supported on low-dimensional manifolds in the input space. In such cases, stronger notions of distance such as f-divergences often max out and provide no useful gradients for training. By comparison, the OT cost has been claimed to behave much more nicely [5, 8]. Even so, GAN-like implementations of OT-based training still require adding a constraint or a regularization term to the objective function.<br />
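<br />
The weak-topology argument can be made concrete with a toy example in the spirit of the WGAN paper [5]: compare a point mass at 0 with a point mass at <math>\theta</math>. The minimal sketch below uses hypothetical helper names. The Jensen-Shannon divergence stays at <math>\log 2</math> for every <math>\theta \neq 0</math>, so it carries no gradient signal in <math>\theta</math>. By contrast, <math>W_1 = |\theta|</math> shrinks smoothly as the two masses approach each other.<br />

```python
import math
from typing import List, Tuple


def kl(p: List[float], q: List[float]) -> float:
    # KL(p || q) for discrete distributions; terms with p_i = 0 contribute 0
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def js(p: List[float], q: List[float]) -> float:
    # Jensen-Shannon divergence (natural log); bounded above by log 2
    m = [(pi + qi) / 2 for pi, qi in zip(p, q)]
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)


def diracs(theta: float) -> Tuple[float, float]:
    """Return (JS, W1) between point masses at 0 and at theta on the real line."""
    if theta == 0:
        p, q = [1.0], [1.0]  # identical distributions
    else:
        p, q = [1.0, 0.0], [0.0, 1.0]  # disjoint supports {0} and {theta}
    return js(p, q), abs(theta)  # W1 between two Diracs is |0 - theta|
```
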
<br />
==Original Contributions==<br />
Let <math>P_X</math> be the true but unknown data distribution. Let <math>P_G</math> be the latent variable model specified by a prior distribution <math>P_Z</math> over latent codes <math>Z \in \mathcal{Z}</math> and a generative model <math>P_G(X|Z)</math> of data points <math>X \in \mathcal{X}</math> given <math>Z</math>. The goal of this paper is to minimize the OT cost <math>W_c(P_X, P_G)</math>.<br />
<br />
The main contributions are given below:<br />
<br />
* A new class of auto-encoders called Wasserstein Auto-Encoders (WAE). WAEs minimize the optimal transport cost <math>W_c(P_X, P_G)</math> for any cost function <math>c</math>. As with VAEs, the WAE objective is made up of two terms: the c-reconstruction cost and a regularizer <math>\mathcal{D}_Z(P_Z, Q_Z)</math>. The regularizer penalizes the discrepancy between two distributions in <math>\mathcal{Z}</math>, namely <math>P_Z</math> and <math>Q_Z</math>, where <math>Q_Z := \mathbb{E}_{P_X}[Q(Z|X)]</math> is the distribution of encoded points. Note that when <math>c</math> is the squared cost and the regularizer is the GAN objective, WAE is equivalent to the adversarial auto-encoders described in [2].<br />
<br />
* Experimental results of using WAE on MNIST and CelebA datasets with squared cost <math>c(x, y) = ||x - y||_2^2</math>. The results of these experiments show that WAEs have the good features of VAEs such as stable training, encoder-decoder architecture, and a nice latent manifold structure while simultaneously improving the quality of the generated samples.<br />
<br />
* Two different regularizers. One is based on GANs and adversarial training in the latent space <math>\mathcal{Z}</math>. The other is based on the maximum mean discrepancy (MMD), which is known to perform well when matching high-dimensional standard normal distributions. The second regularizer also turns the problem into a fully adversary-free min-min optimization problem, which removes the need to tune a GAN.<br />
<br />
* The final contribution is the mathematical analysis used to derive the WAE objective. In particular, it shows that for generative models the primal form of <math>W_c(P_X, P_G)</math> is equivalent to a problem involving the optimization of a probabilistic encoder <math>Q(Z|X)</math>.<br />
<br />
The paper provides an ostensibly simple recipe for a non-blurry, VAE-like generative model. It offers what looks like an elegant and logical way to cast the Wasserstein distance into the VAE/GAN setup.<br />
The paper also gives three instructive comparisons with related models and unifies them thematically: Adversarial Auto-Encoders (AAE), Adversarial Variational Bayes (AVB), and the original Variational Auto-Encoders (VAE). These connections arise for random decoders. The paper introduces the idea with deterministic decoders and then extends it to random ones. The connections hinge on the VAE regularizer, which the AAE and AVB papers replace with a GAN.<br />
<br />
=Proposed Method=<br />
The method proposed by the authors uses a novel auto-encoder architecture to minimize the optimal transport cost <math>W_c(P_X, P_G)</math>. In the optimization problem that follows, the decoder tries to accurately reconstruct the data points as measured by the cost function <math>c</math>. The encoder tries to achieve two conflicting goals at the same time: (1) match the distribution of encoded data points <math>Q_Z := \mathbb{E}_{P_X}[Q(Z|X)]</math> to the prior distribution <math>P_Z</math>, as measured by the divergence <math>\mathcal{D}_Z(P_Z, Q_Z)</math>, and (2) make sure that the encoded latent vectors contain enough information for the data points to be reconstructed with high quality. The figure below illustrates this:<br />
<br />
[[File:ka2khan_figure_1.png|800px|thumb|center|Figure 1]]<br />
<br />
Figure 1: Both VAE and WAE have objectives composed of two terms: the reconstruction cost and a regularizer that penalizes the divergence between <math>P_Z</math> and <math>Q_Z</math>. VAE forces <math>Q(Z|X = x)</math> to match <math>P_Z</math> for every training example <math>x</math> drawn from <math>P_X</math>. As shown in the figure above, every red ball representing <math>Q(Z|X = x)</math> is forced to match <math>P_Z</math>, depicted as whitish triangles. This makes the red balls intersect, which results in reconstruction problems. In contrast, WAE forces the mixture <math>Q_Z := \int{Q(Z|X)\ dP_X}</math> to match <math>P_Z</math>, as shown in the figure above. This gives the latent codes of different examples a better chance to stay far apart from each other, and as a consequence, higher reconstruction quality is achieved.<br />
<br />
==Preliminaries and Notations==<br />
The authors use calligraphic letters to denote sets (for example, <math>\mathcal{X}</math>), capital letters for random variables (for example, <math>X</math>), and lowercase letters for their values (for example, <math>x</math>). Probability distributions are also denoted with capital letters (for example, <math>P(X)</math>), and the corresponding densities with lowercase letters (for example, <math>p(x)</math>).<br />
<br />
The authors also use several measures of difference between probability distributions. These include the f-divergences <math>D_f(p_X \| p_G) := \int f\left(\frac{p_X(x)}{p_G(x)}\right) p_G(x)\, dx</math>, where <math>f: (0, \infty) \to \mathcal{R}</math> is any convex function satisfying <math>f(1) = 0</math>. Special cases include the Kullback-Leibler (<math>D_{KL}</math>) and Jensen-Shannon (<math>D_{JS}</math>) divergences.<br />
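<br />
For discrete distributions the integral becomes a sum, which makes the role of the generator function f easy to check. The sketch below assumes that q is strictly positive and shows two standard choices. The first is <math>f(t) = t \log t</math>, which recovers <math>D_{KL}</math>. The second is <math>f(t) = |t - 1|/2</math>, which gives the total variation distance. The function names are illustrative only.<br />

```python
import math
from typing import Callable, Sequence


def f_divergence(p: Sequence[float], q: Sequence[float],
                 f: Callable[[float], float]) -> float:
    """Discrete analogue of D_f(p || q) = sum_x f(p(x) / q(x)) q(x).

    Assumes q(x) > 0 for every x (support mismatch is not handled).
    """
    return sum(f(pi / qi) * qi for pi, qi in zip(p, q))


def f_kl(t: float) -> float:
    # f(t) = t log t is convex with f(1) = 0 and yields KL(p || q)
    return t * math.log(t) if t > 0 else 0.0


def f_tv(t: float) -> float:
    # f(t) = |t - 1| / 2 yields the total variation distance
    return abs(t - 1.0) / 2.0
```
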
<br />
==Optimal Transport and its Dual Formations==<br />
<br />
A rich class of distances between probability distributions is motivated by the optimal transport problem. One formulation of this problem is Kantorovich's formulation:<br />
<br />
<math><br />
W_c(P_X, P_G) := \inf_{\Gamma \in \mathcal{P}(X \sim P_X, Y \sim P_G)} \mathbb{E}_{(X,Y) \sim \Gamma}[c(X,Y)],<br />
\text{where} \ c(x, y): \mathcal{X} \times \mathcal{X} \to \mathcal{R}_{+}<br />
</math><br />
<br />
is any measurable cost function and <math>\mathcal{P}(X \sim P_X, Y \sim P_G)</math> is the set of all joint distributions of <math>(X, Y)</math> with marginals <math>P_X</math> and <math>P_G</math>, respectively.<br />
<br />
A particularly interesting case is when <math>(\mathcal{X}, d)</math> is a metric space and <math>c(x, y) = d^p(x, y)</math> for <math>p \geq 1</math>. In this case <math>W_p</math>, the <math>p</math>-th root of <math>W_c</math>, is called the p-Wasserstein distance.<br />
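<br />
In one dimension with <math>d(x, y) = |x - y|</math>, the p-Wasserstein distance between two empirical measures of equal size has a closed form. For every <math>p \geq 1</math>, sorting both samples and matching them in order is optimal. The sketch below relies on that fact. The function name and the equal-size, uniform-weight assumptions are ours, not the paper's.<br />

```python
from typing import Sequence


def wasserstein_p_1d(xs: Sequence[float], ys: Sequence[float], p: float = 1.0) -> float:
    """p-Wasserstein distance between two 1-D empirical measures.

    Assumes both samples have the same size n with weights 1/n, d(x, y) = |x - y|
    and p >= 1. In one dimension the monotone (sorted) matching is an optimal
    plan for c = d^p, so W_c is a simple average and W_p is its p-th root.
    """
    if p < 1:
        raise ValueError("p must be >= 1")
    if not xs or len(xs) != len(ys):
        raise ValueError("this sketch assumes non-empty samples of equal size")
    w_c = sum(abs(a - b) ** p for a, b in zip(sorted(xs), sorted(ys))) / len(xs)
    return w_c ** (1.0 / p)
```

For instance, <code>wasserstein_p_1d([0.0, 0.0], [0.0, 2.0], p=2)</code> averages the squared gaps 0 and 4. This gives <math>W_c = 2</math> and <math>W_2 = \sqrt{2}</math>.<br />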
<br />
When <math>c(x, y) = d(x, y)</math> the following Kantorovich-Rubinstein duality holds:<br />
<br />
<math>W_1(P_X, P_G) = \sup_{f \in \mathcal{F}_L} \mathbb{E}_{X \sim P_X}[f(X)] - \mathbb{E}_{Y \sim P_G}[f(Y)]</math>,<br />
where <math>\mathcal{F}_L</math> is the class of all bounded 1-Lipschitz functions on <math>(\mathcal{X}, d)</math>.<br />
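<br />
The duality can be checked numerically in one dimension. Any 1-Lipschitz critic f gives a lower bound <math>\mathbb{E}_{P_X}[f(X)] - \mathbb{E}_{P_G}[f(Y)] \leq W_1(P_X, P_G)</math>, and the supremum over all such critics attains <math>W_1</math>. The sketch below uses hypothetical helper names. It evaluates a handful of hand-picked critics, all bounded on the bounded sample range, and compares the best of them with the primal value computed by sorted matching. WGAN-style methods [5] instead parameterize the critic with a neural network.<br />

```python
from typing import Callable, List, Sequence


def w1_1d(xs: Sequence[float], ys: Sequence[float]) -> float:
    # Primal W1 for equal-size 1-D samples: mean gap of the sorted matching
    return sum(abs(a - b) for a, b in zip(sorted(xs), sorted(ys))) / len(xs)


def dual_objective(f: Callable[[float], float], xs: Sequence[float],
                   ys: Sequence[float]) -> float:
    # E_{X ~ P_X}[f(X)] - E_{Y ~ P_G}[f(Y)] for a single critic f
    return sum(f(x) for x in xs) / len(xs) - sum(f(y) for y in ys) / len(ys)


# A few hand-picked 1-Lipschitz critics; the true sup ranges over all of F_L.
critics: List[Callable[[float], float]] = [
    lambda t: t,
    lambda t: -t,
    lambda t: abs(t - 1.0),
    lambda t: min(t, 1.0),
]
```

For the samples {0, 1, 2} and {1, 2, 3}, the critic f(t) = -t already attains the primal value <math>W_1 = 1</math>.<br />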
<br />
==Application to Generative Models: Wasserstein auto-encoders==<br />
The intuition behind modern generative models like VAEs and GANs is that they try to minimize specific discrepancy measures between the data distribution <math>P_X</math> and the model <math>P_G</math>. Unfortunately, most standard discrepancy measures are hard or even impossible to compute with current tools, especially when <math>P_X</math> is unknown and <math>P_G</math> is parametrized by deep neural networks. However, several tricks can be used to get around this difficulty.<br />
<br />
To minimize the KL-divergence <math>D_{KL}(P_X, P_G)</math>, or equivalently to maximize the marginal log-likelihood <math>\mathbb{E}_{P_X}[\log p_G(X)]</math>, one can use the well-known variational lower bound. This bound provides a theoretically grounded framework and has been used quite successfully by VAEs. In the general case of minimizing an f-divergence <math>D_f(P_X, P_G)</math>, one can use its dual formulation together with adversarial training, as in f-GANs [7]. Finally, the OT cost <math>W_c(P_X, P_G)</math> can be minimized by expressing the Kantorovich-Rubinstein duality as an adversarial objective. Wasserstein-GAN [5] implements this idea.<br />
<br />
In this paper, the authors focus on latent variable models <math>P_G</math> defined by a two-step procedure. First, a code <math>Z</math> is sampled from a fixed distribution <math>P_Z</math> on a latent space <math>\mathcal{Z}</math>. Second, <math>Z</math> is mapped to the image <math>X \in \mathcal{X} = \mathcal{R}^d</math> by a (possibly random) transformation. This results in a density of the form<br />
<br />
<math><br />
p_G(x) := \int\limits_{\mathcal{Z}}{p_G(x|z)p_Z(z)}dz,\ \forall x \in \mathcal{X},<br />
</math><br />
<br />
provided all the probabilities involved are properly defined. To keep things simple, the authors focus on non-random decoders, i.e., generative models <math>P_G(X|Z)</math> that deterministically map <math>Z</math> to <math>X = G(Z)</math> using a fixed map <math>G: \mathcal{Z} \to \mathcal{X}</math>. Similar results hold for random decoders, as the authors show in Appendix B.1.<br />
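<br />
The two-step procedure is easy to simulate. The sketch below assumes a one-dimensional standard normal prior <math>P_Z</math> and a hypothetical affine decoder <math>G(z) = 2z + 1</math>, so that <math>P_G = \mathcal{N}(1, 4)</math>. In a WAE, G is a deep neural network and <math>\mathcal{Z}</math> and <math>\mathcal{X}</math> are multi-dimensional. Because the decoder is deterministic, all randomness in X comes from the prior.<br />

```python
import random
from typing import Callable, List


def sample_latent_variable_model(decoder: Callable[[float], float], n: int,
                                 seed: int = 0) -> List[float]:
    """Draw n samples from P_G: first Z ~ P_Z = N(0, 1), then X = G(Z).

    The decoder is deterministic, so all randomness comes from the prior.
    """
    rng = random.Random(seed)
    return [decoder(rng.gauss(0.0, 1.0)) for _ in range(n)]


def toy_decoder(z: float) -> float:
    # Hypothetical affine decoder G(z) = 2 z + 1, so P_G = N(1, 4)
    return 2.0 * z + 1.0
```
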
<br />
Working under this model, the authors find that the OT cost takes a much simpler form, because the transport plan factors through the map <math>G</math>. Instead of finding a coupling <math>\Gamma</math> between two random variables in <math>\mathcal{X}</math>, one distributed according to <math>P_X</math> and the other according to <math>P_G</math>, it is sufficient to find a conditional distribution <math>Q(Z|X)</math> whose <math>Z</math> marginal, <math>Q_Z(Z) := \mathbb{E}_{X \sim P_X}[Q(Z|X)]</math>, is identical to the prior distribution <math>P_Z</math>. This is formalized by the theorem below, which was proven in [4].<br />
<br />
'''Theorem 1.''' For <math>P_G</math> defined as above with deterministic <math>P_G(X|Z)</math> and any function <math>G:\mathcal{Z} \to \mathcal{X}</math>,<br />
<br />
<math><br />
\inf_{\Gamma \in \mathcal{P}(X \sim P_X, Y \sim P_G)} \mathbb{E}_{(X,Y) \sim \Gamma}[c(X,Y)] = \inf_{Q: Q_Z = P_Z} \mathbb{E}_{P_X} \mathbb{E}_{Q(Z|X)}[c(X, G(Z))]<br />
</math><br />
<br />
where <math>Q_Z</math> is the marginal distribution of <math>Z</math> when <math>X \sim P_X</math> and <math>Z \sim Q(Z|X)</math>.<br />
<br />
According to the authors, this result allows optimizing over random encoders <math>Q(Z|X)</math> instead of over all couplings of <math>X</math> and <math>Y</math>. Both problems are still constrained. To obtain a numerical solution, the authors relax the constraint on <math>Q_Z</math> by adding a penalty term to the objective. This gives the WAE objective:<br />
<br />
<math><br />
D_{WAE}(P_X, P_G) := \inf_{Q(Z|X) \in \mathcal{Q}} \mathbb{E}_{P_X} \mathbb{E}_{Q(Z|X)}[c(X, G(Z))] + \lambda \cdot \mathcal{D}_Z(Q_Z, P_Z)<br />
</math><br />
<br />
where <math>\mathcal{Q}</math> is any nonparametric set of probabilistic encoders, <math>\mathcal{D}_Z</math> is an arbitrary measure of discrepancy between <math>Q_Z</math> and <math>P_Z</math>, and <math>\lambda &gt; 0</math> is a hyperparameter. As with VAEs, the authors propose using deep neural networks to parameterize both the encoders <math>Q</math> and the decoders <math>G</math>. Note that, unlike VAEs, WAEs allow non-random encoders that deterministically map their inputs to latent codes.<br />
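<br />
For a minibatch, the WAE objective can be estimated by averaging the reconstruction cost and adding the weighted latent penalty. The sketch below makes several assumptions. The encoder is deterministic, so <math>Q(Z|X = x)</math> is a point mass. The cost is the squared cost. The penalty <math>\mathcal{D}_Z</math> is a user-supplied function of encoded codes and prior samples, for example an MMD estimate or a GAN-based term. All names are hypothetical, and in practice the encoder and decoder are neural networks trained by SGD.<br />

```python
from typing import Callable, List, Sequence

Vector = List[float]


def squared_cost(x: Sequence[float], y: Sequence[float]) -> float:
    # c(x, y) = ||x - y||_2^2
    return sum((a - b) ** 2 for a, b in zip(x, y))


def wae_objective(xs: Sequence[Vector],
                  encoder: Callable[[Vector], Vector],
                  decoder: Callable[[Vector], Vector],
                  prior_samples: Sequence[Vector],
                  penalty: Callable[[Sequence[Vector], Sequence[Vector]], float],
                  lam: float) -> float:
    """Minibatch estimate of E_{P_X} E_{Q(Z|X)}[c(X, G(Z))] + lam * D_Z(Q_Z, P_Z).

    Assumes a deterministic encoder, so Q(Z|X = x) is a point mass at
    encoder(x); D_Z is estimated from the encoded codes versus prior samples.
    """
    codes = [encoder(x) for x in xs]
    recon = sum(squared_cost(x, decoder(z)) for x, z in zip(xs, codes)) / len(xs)
    return recon + lam * penalty(codes, prior_samples)
```
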
<br />
The authors propose two different regularizers <math>\mathcal{D}_Z(Q_Z, P_Z)</math>:<br />
<br />
===GAN-based <math>\mathcal{D}_z</math>===<br />
One option is to choose <math>\mathcal{D}_Z(Q_Z, P_Z) = \mathcal{D}_{JS}(Q_Z, P_Z)</math> and estimate it with adversarial training. In particular, a discriminator (adversary) in the latent space <math>\mathcal{Z}</math> is trained to separate "true" points sampled from <math>P_Z</math> from "fake" ones sampled from <math>Q_Z</math>. This leads to WAE-GAN, described in Algorithm 1 below. WAE-GAN still relies on max-min optimization, but it has one positive feature: it moves the adversary from the input (pixel) space <math>\mathcal{X}</math> to the latent space <math>\mathcal{Z}</math>. Moreover, the true latent distribution <math>P_Z</math> may have a nice shape with a single mode (for a Gaussian prior). Matching it is therefore much easier than matching an unknown, complex, and possibly multi-modal distribution, as is usually required in GANs. This also motivates the second penalty.<br />
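<br />
Below is a sketch of the two losses in WAE-GAN, following our reading of Algorithm 1 in the figure below. The latent discriminator D maximizes <math>\frac{\lambda}{n}\sum_i \log D(z_i) + \log(1 - D(\tilde{z}_i))</math> with <math>z_i \sim P_Z</math> and <math>\tilde{z}_i \sim Q(Z|x_i)</math>. The encoder and decoder minimize <math>\frac{1}{n}\sum_i c(x_i, G(\tilde{z}_i)) - \lambda \log D(\tilde{z}_i)</math>. The functions take precomputed discriminator outputs and reconstruction costs, so no network is included. The names are hypothetical.<br />

```python
import math
from typing import Sequence


def discriminator_objective(d_prior: Sequence[float], d_encoded: Sequence[float],
                            lam: float) -> float:
    """Quantity the latent discriminator D maximizes.

    d_prior[i]   = D(z_i)  for z_i ~ P_Z         ("true" latent points)
    d_encoded[i] = D(z~_i) for z~_i ~ Q(Z|x_i)   ("fake" latent points)
    All outputs must lie strictly inside (0, 1).
    """
    n = len(d_prior)
    return lam / n * sum(math.log(a) + math.log(1.0 - b)
                         for a, b in zip(d_prior, d_encoded))


def encoder_decoder_objective(recon_costs: Sequence[float],
                              d_encoded: Sequence[float], lam: float) -> float:
    """Quantity the encoder and decoder minimize: reconstruction cost plus a
    term rewarding codes that the discriminator mistakes for prior samples."""
    n = len(recon_costs)
    return sum(c - lam * math.log(d) for c, d in zip(recon_costs, d_encoded)) / n
```
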
<br />
===MMD-based <math>\mathcal{D}_z</math>===<br />
For a positive-definite reproducing kernel <math>k: \mathcal{Z} \times \mathcal{Z} \to \mathcal{R}</math>, the maximum mean discrepancy (MMD) is defined as<br />
<br />
<center><math><br />
MMD_k(P_Z, Q_Z) = \left \Vert \int \limits_{\mathcal{Z}} {k(z, \cdot)dP_Z(z)} - \int \limits_{\mathcal{Z}} {k(z, \cdot)dQ_Z(z)} \right \|_{\mathcal{H}_k}<br />
</math></center>,<br />
<br />
where <math>\mathcal{H}_k</math> is the reproducing kernel Hilbert space (RKHS) of real-valued functions mapping <math>\mathcal{Z}</math> to <math>\mathcal{R}</math>. If <math>k</math> is characteristic, then <math>MMD_k</math> defines a metric and can be used as a distance measure. The authors propose to use <math>\mathcal{D}_Z(P_Z, Q_Z) = MMD_k(P_Z, Q_Z)</math>. MMD has an unbiased U-statistic estimator, which can be used together with stochastic gradient descent (SGD) methods. This gives WAE-MMD, described in Algorithm 2 below. Since MMD is known to perform well when matching high-dimensional standard normal distributions, this penalty is expected to work well when the prior <math>P_Z</math> is Gaussian.<br />
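<br />
The U-statistic estimator can be written directly from kernel evaluations. The sketch below computes the unbiased estimate of the squared MMD from <math>n</math> prior samples and <math>n</math> encoded codes. The within-sample kernel averages exclude the diagonal, while the cross term averages over all pairs. We assume an inverse multiquadratic kernel <math>k(x, y) = C/(C + \|x - y\|_2^2)</math> with a hypothetical default <math>C = 1</math>. The paper's exact kernel settings should be taken from the original text, and the function names are illustrative.<br />

```python
from typing import Callable, Sequence

Point = Sequence[float]


def imq_kernel(x: Point, y: Point, c: float = 1.0) -> float:
    # Inverse multiquadratic kernel k(x, y) = C / (C + ||x - y||^2)
    sq = sum((a - b) ** 2 for a, b in zip(x, y))
    return c / (c + sq)


def mmd2_unbiased(z_prior: Sequence[Point], z_enc: Sequence[Point],
                  k: Callable[[Point, Point], float] = imq_kernel) -> float:
    """Unbiased U-statistic estimate of MMD_k(P_Z, Q_Z)^2.

    z_prior: n samples from P_Z; z_enc: n encoded codes from Q_Z (n >= 2).
    """
    n = len(z_prior)
    if n < 2 or len(z_enc) != n:
        raise ValueError("need two samples of equal size n >= 2")
    within_p = sum(k(z_prior[i], z_prior[j]) for i in range(n) for j in range(n) if i != j)
    within_q = sum(k(z_enc[i], z_enc[j]) for i in range(n) for j in range(n) if i != j)
    cross = sum(k(a, b) for a in z_prior for b in z_enc)
    return (within_p + within_q) / (n * (n - 1)) - 2.0 * cross / (n * n)
```

When both samples are the same point mass, the estimate is exactly 0. Moving the codes away from the prior samples increases it. For example, two copies of 0 versus two copies of 3 in one dimension give 1.8.<br />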
<br />
[[File:ka2khan_figure_2.png|800px|thumb|center|Algorithms- WAE-GAN on left and WAE-MMD on right]]<br />
<br />
=Related Work=<br />
==Literature on auto-encoders==<br />
Classical unregularized auto-encoders have an objective function that only minimizes the reconstruction cost. As a result, distinct data points are encoded into distinct zones scattered chaotically across the latent space <math>\mathcal{Z}</math>. The latent space then contains large "holes" on which the decoder <math>P_G(X|Z)</math> has never been trained. In general, encoders trained this way do not provide very useful representations, and sampling from the latent space <math>\mathcal{Z}</math> becomes difficult [12].<br />
<br />
VAEs [1] minimize a variational upper bound on the KL-divergence <math>D_{KL}(P_X, P_G)</math>, which consists of the reconstruction cost and the regularizer <math>\mathbb{E}_{P_X}[D_{KL}(Q(Z|X), P_Z)]</math>. The regularizer penalizes the discrepancy between the encoding distribution of each individual training image and the prior <math>P_Z</math>. However, this penalty does not directly force the overall encoded distribution <math>Q_Z</math> to match the prior, which is what the WAE regularizer targets. In addition, VAEs require non-degenerate (i.e., non-deterministic) Gaussian encoders and random decoders. A later paper [11] proposed a method that allows non-Gaussian encoders in VAEs. WAE, meanwhile, minimizes <math>W_{c}(P_X, P_G)</math> and allows both probabilistic and deterministic encoder-decoder pairs.<br />
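<br />
The contrast can be made concrete. Consider a Gaussian encoder <math>Q(Z|X = x) = \mathcal{N}(\mu(x), \mathrm{diag}(\sigma^2(x)))</math> and a standard normal prior. The per-example VAE regularizer then has a well-known closed form, <math>D_{KL} = \frac{1}{2}\sum_j (\mu_j^2 + \sigma_j^2 - 1 - \log \sigma_j^2)</math>. The sketch below evaluates it (the function name is hypothetical). Note that it grows without bound as <math>\sigma_j \to 0</math>. This is why VAEs need non-degenerate encoders, while WAE penalizes only the aggregate <math>Q_Z</math> and can therefore use deterministic encoders.<br />

```python
import math
from typing import Sequence


def gaussian_kl_to_standard_normal(mu: Sequence[float], sigma: Sequence[float]) -> float:
    """KL(N(mu, diag(sigma^2)) || N(0, I)) in closed form.

    This is the per-example VAE regularizer D_KL(Q(Z|X=x), P_Z) for a
    Gaussian encoder and a standard normal prior P_Z; sigma must be > 0.
    """
    return 0.5 * sum(m * m + s * s - 1.0 - math.log(s * s) for m, s in zip(mu, sigma))
```
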
<br />
With appropriate choices, WAE generalizes AAE [2] in two ways. It can use any cost function <math>c</math> in the input space <math>\mathcal{X}</math>, and it can use any discrepancy measure <math>\mathcal{D}_Z</math> in the latent space <math>\mathcal{Z}</math>, not only the adversarial one.<br />
<br />
InfoVAE [14] is a regularized auto-encoder whose objective is similar to that of [4], although it was derived from different motivations and arguments.<br />
<br />
WAEs explicitly define the cost function <math>c(x,y)</math>, whereas VAEs specify it implicitly through the negative log-likelihood term. In theory, this term can induce any cost function. In practice, doing so may require estimating a normalizing constant that can differ across values of <math>z</math>.<br />
<br />
==Literature on Optimal Transport (OT)==<br />
[15] provides methods for computing the OT cost on large-scale data using SGD and sampling. WGAN [5] is a generative model that minimizes the 1-Wasserstein distance <math>W_1(P_X, P_G)</math>. The WGAN algorithm does not provide an encoder and cannot easily be applied to an arbitrary cost <math>W_c</math>. The model proposed in [5] uses the dual form. In contrast, the model proposed in this paper uses the primal form, which allows any cost function <math>c</math> and naturally comes with an encoder.<br />
<br />
To compute <math>W_c(P_X, P_G)</math> or <math>W_1(P_X, P_G)</math>, a model has to handle various non-trivial constraints. Various methods have been proposed in the literature to avoid this difficulty ([5], [2], [8], [16], [15], [17], [18]).<br />
<br />
==Literature on GANs==<br />
Many of the GAN variants proposed in the literature, such as WGAN and f-GAN, come without an encoder. These models are of limited use when data points must be mapped back to the latent space in order to use the learned manifold.<br />
<br />
Numerous models in the literature try to combine the adversarial training of GANs with auto-encoder architectures. Examples include [19], [20], [21], and [22]. Reproducing kernels have also been used in the context of GANs ([23], [24]).<br />
<br />
=Experiments=<br />
Experiments were used to empirically evaluate the proposed WAE model. <br />
<br />
'''Experimental setup'''<br />
<br />
For the experimental setup, the authors used a Gaussian prior <math> \small P_Z</math> and the squared cost function <math> \small c(x,y) = ||x - y||_2^2</math> for data points.<br />
Deterministic encoder-decoder pairs were used. The authors conducted experiments on two real-world datasets: (1) MNIST [27], made up of 70k images, and (2) CelebA [28], consisting of approximately 203k images. For test reconstructions and interpolations, a pair of held-out images <math>(x,y)</math> from the test set is auto-encoded (separately) to produce <math>(z_x, z_y)</math> in the latent space.<br />
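<br />
The interpolation experiment can be sketched as follows. We assume, as is common, that the codes <math>z_x</math> and <math>z_y</math> are linearly interpolated and that each intermediate code is decoded. Here <code>encoder</code> and <code>decoder</code> are hypothetical placeholders for the trained networks, and the other function names are illustrative.<br />

```python
from typing import Callable, List, Sequence

Vector = List[float]


def interpolate_latents(z_x: Sequence[float], z_y: Sequence[float], steps: int) -> List[Vector]:
    """Evenly spaced points on the segment from z_x to z_y (endpoints included)."""
    if steps < 2:
        raise ValueError("steps must be at least 2")
    return [[(1 - t) * a + t * b for a, b in zip(z_x, z_y)]
            for t in (i / (steps - 1) for i in range(steps))]


def interpolation_path(x: Vector, y: Vector,
                       encoder: Callable[[Vector], Vector],
                       decoder: Callable[[Vector], Vector],
                       steps: int = 5) -> List[Vector]:
    # Encode each held-out image separately, interpolate linearly, decode every point
    z_x, z_y = encoder(x), encoder(y)
    return [decoder(z) for z in interpolate_latents(z_x, z_y, steps)]
```
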
<br />
The main evaluation criteria were to see if the WAE model can simultaneously achieve: <br />
<br />
<ol><br />
<li>accurate reconstruction of the data points</li><br />
<li>reasonable geometry of the latent manifold</li><br />
<li>generation of high quality random samples</li><br />
</ol><br />
<br />
For the model to generalize well, criteria (1) and (2) should be met on both the training and test sets.<br />
<br />
The proposed model achieves reasonably good results, as highlighted in the figures below:<br />
<br />
[[File:ka2khan_figure_3.png|800px|thumb|center|Using CelebA dataset]]<br />
<br />
[[File:ka2khan_figure_4.png|800px|thumb|center|Using CelebA dataset, FID (Fréchet Inception Distance<br />
[32]): smaller is better, sharpness: larger is better]]<br />
<br />
=Conclusion=<br />
The authors proposed a new class of algorithms for building a generative model called Wasserstein Autoencoders based on optimal transport cost. They related the newly proposed model to the existing probabilistic modeling techniques. They empirically evaluated the proposed models using two real-world datasets. They compared the results obtained using their proposed model with the results obtained using VAEs on the same dataset to show that the proposed models generate sample images of higher quality in addition to being easier to train and having good reconstruction quality of the data points.<br />
<br />
The authors state that in future work they will further explore criteria for matching the encoding distribution <math>Q_Z</math> to the prior distribution <math>P_Z</math>, evaluate whether it is feasible to adversarially train the cost function <math>c</math> in the input space <math>\mathcal{X}</math>, and provide a theoretical analysis of the dual formulations for WAE-GAN and WAE-MMD.<br />
<br />
=Future Work=<br />
Following this paper, [34] introduced another generative model based on optimal transport, which measures the distance between two probability distributions as the minimal cost of transporting one distribution onto the other (hence the name). The resulting model, called "Sliced-Wasserstein Autoencoders" (SWAE), is simple to implement and provides the capabilities of Wasserstein Autoencoders.<br />
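<br />
The slicing idea can be sketched in a few lines: project both samples onto random unit directions, compute the closed-form one-dimensional <math>W_2^2</math> by sorting, and average over directions. This is a generic Monte Carlo sketch of the sliced-Wasserstein distance under our own assumptions (equal sample sizes, uniform weights, random directions obtained by normalizing Gaussian vectors). It is not the implementation from [34], and the function name is hypothetical.<br />

```python
import math
import random
from typing import Sequence

Point = Sequence[float]


def sliced_wasserstein2(xs: Sequence[Point], ys: Sequence[Point],
                        num_projections: int = 200, seed: int = 0) -> float:
    """Monte Carlo estimate of the squared sliced 2-Wasserstein distance.

    Each random unit direction theta projects both samples to 1-D, where W_2^2
    between equal-size empirical measures is the mean squared gap after
    sorting. Assumes len(xs) == len(ys) and points of a common dimension.
    """
    if not xs or len(xs) != len(ys):
        raise ValueError("this sketch assumes non-empty samples of equal size")
    rng = random.Random(seed)
    dim = len(xs[0])
    total = 0.0
    for _ in range(num_projections):
        theta = [rng.gauss(0.0, 1.0) for _ in range(dim)]
        norm = math.sqrt(sum(t * t for t in theta))
        theta = [t / norm for t in theta]  # uniformly random direction on the sphere
        px = sorted(sum(t * v for t, v in zip(theta, x)) for x in xs)
        py = sorted(sum(t * v for t, v in zip(theta, y)) for y in ys)
        total += sum((a - b) ** 2 for a, b in zip(px, py)) / len(xs)
    return total / num_projections
```

Because each projection only requires sorting, it costs <math>O(n \log n)</math>. This is what makes sliced distances cheap compared to solving a full OT problem.<br />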
<br />
([https://openreview.net/forum?id=HkL7n1-0b]) The results on the MNIST and CelebA datasets look convincing. However, the paper could include additional evaluation comparing the adversarial loss with the more straightforward MMD metric, along with a discussion of their pros and cons. Evaluating and comparing closely related auto-encoder solutions is challenging, so the authors could also design demonstrative experiments showing where the Wasserstein distance helps and what its potential limitations are.<br />
<br />
<br />
<br />
=Critique=<br />
<br />
Although this paper presented some empirical tests to explain its method in an appropriate way, it would be better to provide some clearer notations including the details of the architectures in their experiments. Furthermore, they could benefit from performing some comparisons between the results of their work and other similar works. As pointed out by a reviewer, the closest work to this paper is the adversarial variational bayes framework by Mescheder et.al. which also attempts at unifying VAEs and GANs. Although the authors describe the conceptual differences and advantages over that approach, it will be beneficial to actually include some comparisons in the results section.<br />
Moreover, the reported performance is not a significant improvement over previous VAE algorithms, and the performance claims would be better supported by experiments on a wider variety of datasets. On the other hand, the methodology is flexible and provides a unified framework that encompasses other types of algorithms, which is a major benefit.<br />
<br />
=References=<br />
[1] D. P. Kingma and M. Welling. Auto-encoding variational Bayes. In ICLR, 2014.<br />
<br />
[2] A. Makhzani, J. Shlens, N. Jaitly, and I. Goodfellow. Adversarial autoencoders. In ICLR, 2016.<br />
<br />
[3] I. Goodfellow, J. Pouget-Abadie, M. Mirza, B. Xu, D. Warde-Farley, S. Ozair, A. Courville, and Y. Bengio. Generative adversarial nets. In NIPS, pages 2672–2680, 2014.<br />
<br />
[4] O. Bousquet, S. Gelly, I. Tolstikhin, C. J. Simon-Gabriel, and B. Schölkopf. From optimal transport to generative modeling: the VEGAN cookbook, 2017.<br />
<br />
[5] M. Arjovsky, S. Chintala, and L. Bottou. Wasserstein GAN, 2017.<br />
<br />
[6] C. Villani. Topics in Optimal Transportation. AMS Graduate Studies in Mathematics, 2003.<br />
<br />
[7] S. Nowozin, B. Cseke, and R. Tomioka. f-GAN: Training generative neural samplers using variational divergence minimization. In NIPS, 2016.<br />
<br />
[8] I. Gulrajani, F. Ahmed, M. Arjovsky, V. Dumoulin, and A. Courville. Improved training of Wasserstein GANs, 2017.<br />
<br />
[9] A. Gretton, K. M. Borgwardt, M. J. Rasch, B. Schölkopf, and A. J. Smola. A kernel two-sample test. Journal of Machine Learning Research, 13:723–773, 2012.<br />
<br />
[10] F. Liese and K.-J. Miescke. Statistical Decision Theory. Springer, 2008.<br />
<br />
[11] L. Mescheder, S. Nowozin, and A. Geiger. Adversarial variational Bayes: Unifying variational autoencoders and generative adversarial networks, 2017.<br />
<br />
[12] Y. Bengio, A. Courville, and P. Vincent. Representation learning: A review and new perspectives. IEEE Transactions on Pattern Analysis and Machine Intelligence, 35, 2013.<br />
<br />
[13] M. D. Hoffman and M. Johnson. ELBO surgery: yet another way to carve up the variational evidence lower bound. In NIPS Workshop on Advances in Approximate Bayesian Inference, 2016.<br />
<br />
[14] S. Zhao, J. Song, and S. Ermon. InfoVAE: Information maximizing variational autoencoders, 2017.<br />
<br />
[15] A. Genevay, M. Cuturi, G. Peyré, and F. R. Bach. Stochastic optimization for large-scale optimal transport. In Advances in Neural Information Processing Systems, pages 3432–3440, 2016. <br />
<br />
[16] M. Cuturi. Sinkhorn distances: Lightspeed computation of optimal transport. In Advances in Neural Information Processing Systems, pages 2292–2300, 2013.<br />
<br />
[17] L. Chizat, G. Peyré, B. Schmitzer, and F.-X. Vialard. Unbalanced optimal transport: geometry and Kantorovich formulation. arXiv preprint arXiv:1508.05216, 2015.<br />
<br />
[18] M. Liero, A. Mielke, and G. Savaré. Optimal entropy-transport problems and a new Hellinger-Kantorovich distance between positive measures. arXiv preprint arXiv:1508.07941, 2015.<br />
<br />
[19] J. Zhao, M. Mathieu, and Y. LeCun. Energy-based generative adversarial network. In ICLR, 2017.<br />
<br />
[20] V. Dumoulin, I. Belghazi, B. Poole, A. Lamb, M. Arjovsky, O. Mastropietro, and A. Courville. Adversarially learned inference. In ICLR, 2017.<br />
<br />
[21] D. Ulyanov, A. Vedaldi, and V. Lempitsky. It takes (only) two: Adversarial generator-encoder networks, 2017.<br />
<br />
[22] D. Berthelot, T. Schumm, and L. Metz. BEGAN: Boundary equilibrium generative adversarial networks, 2017.<br />
<br />
[23] Y. Li, K. Swersky, and R. Zemel. Generative moment matching networks. In ICML, 2015. <br />
<br />
[24] G. K. Dziugaite, D. M. Roy, and Z. Ghahramani. Training generative neural networks via maximum mean discrepancy optimization. In UAI, 2015.<br />
<br />
[25] R. Reddi, A. Ramdas, A. Singh, B. Poczos, and L. Wasserman. On the high-dimensional power of a linear-time two sample test under mean-shift alternatives. In AISTATS, 2015.<br />
<br />
[26] C. L. Li, W. C. Chang, Y. Cheng, Y. Yang, and B. Poczos. MMD GAN: Towards deeper understanding of moment matching network, 2017.<br />
<br />
[27] Y. LeCun, L. Bottou, Y. Bengio, and P. Haffner. Gradient-based learning applied to document recognition. In Proceedings of the IEEE, volume 86(11), pages 2278–2324, 1998.<br />
<br />
[28] Z. Liu, P. Luo, X. Wang, and X. Tang. Deep learning face attributes in the wild. In Proceedings of International Conference on Computer Vision (ICCV), 2015.<br />
<br />
[29] D. P. Kingma and J. Ba. Adam: A method for stochastic optimization, 2014.<br />
<br />
[30] A. Radford, L. Metz, and S. Chintala. Unsupervised representation learning with deep convolutional generative adversarial networks. In ICLR, 2016.<br />
<br />
[31] S. Ioffe and C. Szegedy. Batch normalization: Accelerating deep network training by reducing internal covariate shift, 2015.<br />
<br />
[32] M. Heusel, H. Ramsauer, T. Unterthiner, B. Nessler, G. Klambauer, and S. Hochreiter. GANs trained by a two time-scale update rule converge to a Nash equilibrium. arXiv preprint arXiv:1706.08500, 2017.<br />
<br />
[33] B. Poole, A. Alemi, J. Sohl-Dickstein, and A. Angelova. Improved generator objectives for GANs, 2016.<br />
<br />
[34] S. Kolouri, C. E. Martin, and G. K. Rohde. Sliced-Wasserstein autoencoder: An embarrassingly simple generative model. arXiv preprint arXiv:1804.01947, 2018.</div>Gsahuhttp://wiki.math.uwaterloo.ca/statwiki/index.php?title=Wasserstein_Auto-encoders&diff=41932Wasserstein Auto-encoders2018-11-29T23:14:19Z<p>Gsahu: /* Critique */</p>
<hr />
<div>The first version of this work was published in 2017; this version (the third revision) was presented at ICLR 2018. Source code for the first version is available [https://github.com/tolstikhin/wae here].<br />
<br />
=Introduction=<br />
Early successes in the field of representation learning were based on supervised approaches, which used large labeled datasets to achieve impressive results. On the other hand, popular unsupervised generative modeling methods mainly consisted of probabilistic approaches focusing on low-dimensional data. In recent years, there have been models proposed which