http://wiki.math.uwaterloo.ca/statwiki/api.php?action=feedcontributions&user=X249wang&feedformat=atomstatwiki - User contributions [US]2021-04-13T14:31:00ZUser contributionsMediaWiki 1.28.3http://wiki.math.uwaterloo.ca/statwiki/index.php?title=Dynamic_Routing_Between_Capsules_STAT946&diff=36073Dynamic Routing Between Capsules STAT9462018-04-03T15:58:01Z<p>X249wang: /* Other data sets */</p>
<hr />
<div>= Presented by =<br />
<br />
Yang, Tong(Richard)<br />
<br />
= Contributions =<br />
<br />
This paper introduces the concept of "capsules" and an approach to implementing them in neural networks. A capsule is a group of neurons used to represent various properties of an entity or object present in the image, such as pose, deformation, and even the existence of the entity. Instead of the obvious choice of a logistic unit for the probability of existence, the paper explores using the length of the capsule output vector to represent existence, and its orientation to represent other properties of the entity. The paper makes the following major contributions:<br />
<br />
* Proposed an alternative approach to max-pooling, which is called routing-by-agreement.<br />
* Demonstrated a mathematical structure for capsule layers and a routing mechanism that together form a prototype architecture for capsule networks. <br />
* Presented promising results for CapsNet that support its value as a new direction for development in deep learning.<br />
<br />
= Hinton's Critiques on CNN =<br />
<br />
In an earlier talk, Hinton explained why he considers max-pooling the biggest problem in current convolutional network architectures. Here are some highlights from that talk. <br />
<br />
== Four arguments against pooling ==<br />
<br />
* It is a bad fit to the psychology of shape perception: It does not explain why we assign intrinsic coordinate frames to objects and why they have such huge effects.<br />
<br />
* It solves the wrong problem: We want equivariance, not invariance. Disentangling rather than discarding.<br />
<br />
* It fails to use the underlying linear structure: It does not make use of the natural linear manifold that perfectly handles the largest source of variance in images.<br />
<br />
* Pooling is a poor way to do dynamic routing: We need to route each part of the input to the neurons that know how to deal with it. Finding the best routing is equivalent to parsing the image.<br />
<br />
===Intuition Behind Capsules ===<br />
Max-pooling is used to achieve viewpoint invariance in the activities of neurons. Invariance here means that when the input changes a little, the output stays the same, where the activity is simply the output signal of a neuron. In other words, when the object we want to detect is shifted slightly in the input image, the network's activities (the outputs of its neurons) do not change because of max-pooling, and the network still detects the object. However, spatial relationships are not preserved by this approach, so capsules are used instead: they encapsulate all the important information about the state of the features they detect in the form of a vector. A capsule encodes the probability that a feature is detected as the length of its output vector, and the state of the detected feature as the direction in which that vector points. So when a detected feature moves around the image or its state changes, the probability stays the same (the length of the vector does not change), but the orientation of the vector changes.<br />
<br />
== Equivariance ==<br />
<br />
To deal with the invariance problem of CNNs, Hinton proposes the concept of equivariance, which is the foundation of the capsule concept.<br />
<br />
=== Two types of equivariance ===<br />
<br />
==== Place-coded equivariance ====<br />
If a low-level part moves to a very different position it will be represented by a different capsule.<br />
<br />
==== Rate-coded equivariance ====<br />
If a part only moves a small distance it will be represented by the same capsule but the pose outputs of the capsule will change.<br />
<br />
Higher-level capsules have bigger domains so low-level place-coded equivariance gets converted into high-level rate-coded equivariance.<br />
<br />
= Dynamic Routing =<br />
<br />
In the second section of the paper, the authors give a mathematical description of two key components of the routing algorithm in a capsule network: squashing and agreement. The algorithm is described for two arbitrary capsules i and j, where capsule j is an arbitrary capsule in a given capsule layer and capsule i is an arbitrary capsule in the layer below. The purpose of the routing algorithm is to compute the vector output of capsule j together with the coupling between capsule i and capsule j. These quantities determine how strongly the output of capsule i is routed to capsule j. <br />
<br />
== Routing Algorithm ==<br />
<br />
The routing algorithm is as follows:<br />
<br />
[[File:DRBC_Figure_1.png|650px|center||Source: Sabour, Frosst, Hinton, 2017]]<br />
<br />
In the following sections, each part of this algorithm is explained in detail.<br />
<br />
=== Log Prior Probability ===<br />
<br />
<math>b_{ij}</math> represents the log prior probability that capsule i should be coupled to capsule j, and it is updated in each routing iteration. As line 2 of the algorithm indicates, the initial values of <math>b_{ij}</math> for all pairs of capsules are set to 0, so in the very first routing iteration <math>b_{ij}</math> equals zero. In each routing iteration, <math>b_{ij}</math> is updated by adding the agreement value, which is explained later.<br />
<br />
=== Coupling Coefficient === <br />
<br />
<math>c_{ij}</math> represents the coupling coefficient between capsule j and capsule i. It is calculated by applying the softmax function on the log prior probability <math>b_{ij}</math>. The mathematical transformation is shown below (Equation 3 in paper): <br />
<br />
\begin{align}<br />
c_{ij} = \frac{\exp(b_{ij})}{\sum_{k}\exp(b_{ik})}<br />
\end{align}<br />
<br />
The <math>c_{ij}</math> serve as weights for computing the weighted sum of prediction vectors, and they can be interpreted as probabilities. As probabilities, they have the following properties:<br />
<br />
\begin{align}<br />
c_{ij} \geq 0, \forall i, j<br />
\end{align}<br />
<br />
and, <br />
<br />
\begin{align}<br />
\sum_{j}c_{ij} = 1, \forall i<br />
\end{align}<br />
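<br />
As an illustration, the softmax above can be sketched in NumPy. This is a minimal sketch rather than the authors' implementation: the function name <code>coupling_coefficients</code> and the layout of <code>b</code> (rows indexed by lower-level capsule i, columns by higher-level capsule j) are assumptions made for the example. The softmax is taken over j for each fixed i.<br />

```python
import numpy as np

def coupling_coefficients(b):
    """Softmax over higher-level capsules j for each lower-level capsule i.

    b is assumed to have shape (num_lower, num_upper).
    """
    # Subtracting the row-wise maximum does not change the result
    # but avoids overflow in exp.
    b_shifted = b - b.max(axis=1, keepdims=True)
    e = np.exp(b_shifted)
    return e / e.sum(axis=1, keepdims=True)

# With all logits initialised to zero, every coupling is uniform.
b = np.zeros((1152, 10))
c = coupling_coefficients(b)
```

With zero logits, each lower-level capsule spreads its output equally over the 10 higher-level capsules, so every entry of <code>c</code> equals 0.1.<br />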
<br />
=== Predicted Output from Layer Below === <br />
<br />
<math>u_{i}</math> is the output vector of capsule i in the lower layer, and <math>\hat{u}_{j|i}</math> is the input vector for capsule j, called the "prediction vector" from capsule i in the layer below. <math>\hat{u}_{j|i}</math> is produced by multiplying <math>u_{i}</math> by a weight matrix <math>W_{ij}</math>:<br />
<br />
\begin{align}<br />
\hat{u}_{j|i} = W_{ij}u_i<br />
\end{align}<br />
<br />
where <math>W_{ij}</math> encodes some spatial relationship between capsule j and capsule i.<br />
<br />
=== Capsule ===<br />
<br />
By using the definitions from previous sections, the total input vector for an arbitrary capsule j can be defined as:<br />
<br />
\begin{align}<br />
s_j = \sum_{i}c_{ij}\hat{u}_{j|i}<br />
\end{align}<br />
<br />
which is a weighted sum over all prediction vectors by using coupling coefficients.<br />
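<br />
The prediction vectors and the weighted sum can be sketched together in NumPy. The sizes below and the storage of <code>W</code> as a single 4-D array are assumptions for illustration only. Each <code>W[i, j]</code> is stored with shape (16, 8) so that it can left-multiply an 8-dimensional <code>u[i]</code>.<br />

```python
import numpy as np

rng = np.random.default_rng(0)
num_lower, num_upper, d_in, d_out = 6, 3, 8, 16  # toy sizes, chosen arbitrarily

u = rng.normal(size=(num_lower, d_in))                    # outputs u_i
W = rng.normal(size=(num_lower, num_upper, d_out, d_in))  # one matrix per (i, j)

# Prediction vectors: u_hat[i, j] = W[i, j] @ u[i]
u_hat = np.einsum("ijkl,il->ijk", W, u)

# Uniform coupling coefficients, as in the first routing iteration.
c = np.full((num_lower, num_upper), 1.0 / num_upper)

# Total input: s[j] = sum_i c[i, j] * u_hat[i, j]
s = np.einsum("ij,ijk->jk", c, u_hat)
```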
<br />
=== Squashing ===<br />
<br />
The length of <math>s_j</math> is unbounded, which needs to be addressed. The next step is to map its length into the range between 0 and 1, since we want the length of the output vector of a capsule to represent the probability that the entity represented by the capsule is present in the current input. The "squashing" function is shown below:<br />
<br />
\begin{align}<br />
v_j = \frac{||s_j||^2}{1+||s_j||^2}\frac{s_j}{||s_j||}<br />
\end{align}<br />
<br />
Notice that "squashing" does not simply normalize the vector to unit length. It applies a non-linear transformation that preserves the direction of the vector while ensuring that short vectors are shrunk to almost zero length and long vectors are shrunk to a length slightly below 1. This makes the routing decision between capsule layers, called "routing by agreement", much easier to make.<br />
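<br />
A minimal NumPy sketch of the squashing function follows. The small constant <code>eps</code> is an assumption added here to avoid division by zero for an all-zero input. It is not part of the formula above.<br />

```python
import numpy as np

def squash(s, axis=-1, eps=1e-9):
    """Scale s to a length in [0, 1) while keeping its direction."""
    sq_norm = np.sum(s ** 2, axis=axis, keepdims=True)
    norm = np.sqrt(sq_norm + eps)  # eps guards against division by zero
    return (sq_norm / (1.0 + sq_norm)) * (s / norm)

short_vec = squash(np.array([0.1, 0.0]))   # length shrinks to about 0.01
long_vec = squash(np.array([10.0, 0.0]))   # length shrinks to about 0.99
```

For an input of length 0.1 the output length is 0.01 / 1.01, and for an input of length 10 it is 100 / 101, which matches the behaviour described above.<br />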
<br />
=== Agreement ===<br />
<br />
The final step of a routing iteration is to compute the routing agreement <math>a_{ij}</math>, which is a scalar product:<br />
<br />
\begin{align}<br />
a_{ij} = v_{j} \cdot \hat{u}_{j|i}<br />
\end{align}<br />
<br />
As mentioned in the "Squashing" section, the length of <math>v_{j}</math> is typically pushed close to 0 or close to 1, which affects the magnitude of <math>a_{ij}</math>. The value of <math>a_{ij}</math> therefore indicates how strongly the routing algorithm agrees on taking the route between capsule i and capsule j. In each routing iteration, the log prior probability <math>b_{ij}</math> is updated by adding the agreement value, which affects how the coupling coefficients are computed in the next routing iteration. Because of the "squashing" process, we eventually end up with a capsule j whose <math>v_{j}</math> has length close to 1 while all other capsules have <math>v_{j}</math> with length close to 0, which indicates that this capsule j should be activated.<br />
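<br />
Putting the pieces together, one possible NumPy sketch of the routing loop for a single input is given below. It follows the steps described in this section (softmax, weighted sum, squashing, agreement update), but the function names, the array layout and the <code>eps</code> constant are assumptions made for this example, not the authors' code.<br />

```python
import numpy as np

def squash(s, eps=1e-9):
    sq_norm = np.sum(s ** 2, axis=-1, keepdims=True)
    return (sq_norm / (1.0 + sq_norm)) * (s / np.sqrt(sq_norm + eps))

def softmax(b, axis):
    e = np.exp(b - b.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def dynamic_routing(u_hat, num_iterations=3):
    """u_hat[i, j] is the prediction vector from lower capsule i to upper capsule j."""
    num_lower, num_upper, _ = u_hat.shape
    b = np.zeros((num_lower, num_upper))        # log priors start at zero
    for _ in range(num_iterations):
        c = softmax(b, axis=1)                  # coupling coefficients c_ij
        s = np.einsum("ij,ijk->jk", c, u_hat)   # weighted sum s_j
        v = squash(s)                           # capsule outputs v_j
        a = np.einsum("ijk,jk->ij", u_hat, v)   # agreement a_ij = v_j . u_hat_j|i
        b = b + a                               # update the log priors
    return v, c

rng = np.random.default_rng(0)
u_hat = rng.normal(size=(20, 4, 16))  # toy sizes: 20 lower and 4 upper capsules
v, c = dynamic_routing(u_hat)
```

The returned <code>c</code> holds the coupling coefficients that were used to compute the returned <code>v</code>. In a real network the loop would run per example inside the forward pass, and only the weight matrices would be trained by backpropagation.<br />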
<br />
= CapsNet Architecture =<br />
<br />
The second part of the paper discusses experimental results from a 3-layer CapsNet. The architecture can be divided into two parts: an encoder and a decoder. <br />
<br />
== Encoder == <br />
<br />
[[File:DRBC_Architecture.png|650px|center||Source: Sabour, Frosst, Hinton, 2017]]<br />
<br />
=== How many routing iterations to use? === <br />
In Appendix A of the paper, the authors show empirical results from 500 epochs of training with different numbers of routing iterations. According to their observations, more routing iterations increase the capacity of CapsNet but tend to bring additional risk of overfitting. Moreover, CapsNet with fewer than three routing iterations is not effective in general. As a result, they suggest 3 routing iterations for all experiments.<br />
<br />
=== Margin loss for digit existence ===<br />
<br />
The experiments include segmenting overlapping digits on the MultiMNIST data set, so the loss function has to allow for the presence of multiple digits. The margin loss <math>L_k</math> for each digit capsule k is calculated by:<br />
<br />
\begin{align}<br />
L_k = T_k \max(0, m^+ - ||v_k||)^2 + \lambda(1 - T_k) \max(0, ||v_k|| - m^-)^2<br />
\end{align}<br />
<br />
where <math>m^+ = 0.9</math>, <math>m^- = 0.1</math>, and <math>\lambda = 0.5</math>.<br />
<br />
<math>T_k</math> is an indicator for the presence of a digit of class k: it takes the value 1 if and only if class k is present. If class k is absent, <math>\lambda</math> down-weights its loss term, which stops the initial learning from shrinking the lengths of the activity vectors of all the digit capsules. The loss is designed this way because we would like the top-level capsule for digit class k to have a long instantiation vector if and only if that digit class is present in the input.<br />
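<br />
The loss can be sketched in NumPy as follows. The function name <code>margin_loss</code> and the example capsule lengths are illustrative assumptions. The constants are those given above, and the per-class terms are summed over all digit capsules.<br />

```python
import numpy as np

def margin_loss(v_norm, T, m_plus=0.9, m_minus=0.1, lam=0.5):
    """v_norm holds the capsule lengths ||v_k||, T the 0/1 presence indicators."""
    present = T * np.maximum(0.0, m_plus - v_norm) ** 2
    absent = lam * (1.0 - T) * np.maximum(0.0, v_norm - m_minus) ** 2
    return np.sum(present + absent, axis=-1)

# Toy example with three classes, where only class 0 is present.
v_norm = np.array([0.95, 0.05, 0.30])
T = np.array([1.0, 0.0, 0.0])
loss = margin_loss(v_norm, T)
```

In this example only the third capsule contributes: it is absent but has length 0.30, so its term is 0.5 * (0.30 - 0.10)^2 = 0.02.<br />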
<br />
=== Layer 1: Conv1 === <br />
<br />
Conv1 is the first layer of CapsNet. As in a CNN, it is simply a convolutional layer that converts pixel intensities to the activities of local feature detectors. <br />
<br />
* Layer Type: Convolutional Layer.<br />
* Input: <math>28 \times 28</math> pixels.<br />
* Kernel size: <math>9 \times 9</math>.<br />
* Number of Kernels: 256.<br />
* Activation function: ReLU.<br />
* Output: <math>20 \times 20 \times 256</math> tensor.<br />
<br />
=== Layer 2: PrimaryCapsules ===<br />
<br />
The second layer consists of 32 channels of primary 8D capsules. Here 8D means that each primary capsule contains 8 convolutional units with a <math>9 \times 9</math> kernel and a stride of 2. Each capsule channel takes the <math>20 \times 20 \times 256</math> tensor from Conv1 and produces a <math>6 \times 6 \times 8</math> output tensor.<br />
<br />
* Layer Type: Convolutional Layer.<br />
* Input: <math>20 \times 20 \times 256</math> tensor.<br />
* Number of capsule channels: 32.<br />
* Number of convolutional units in each capsule: 8.<br />
* Output size of each convolutional unit: <math>6 \times 6</math>.<br />
* Output: <math>6 \times 6 \times 32</math> 8-dimensional vectors.<br />
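<br />
The spatial sizes of the first two layers can be checked with a short calculation. The sketch assumes convolutions without padding and a stride of 1 in Conv1. The helper name <code>conv_output_size</code> is made up for this example.<br />

```python
def conv_output_size(input_size, kernel_size, stride=1):
    """Output width of a convolution without padding."""
    return (input_size - kernel_size) // stride + 1

conv1_size = conv_output_size(28, 9, stride=1)             # 20
primary_size = conv_output_size(conv1_size, 9, stride=2)   # 6

# 32 capsule channels on a 6 x 6 grid give the number of primary capsules.
num_primary_capsules = primary_size * primary_size * 32    # 1152
```

Each of these 1152 primary capsules outputs an 8-dimensional vector, and these vectors are the inputs to the routing step of the next layer.<br />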
<br />
=== Layer 3: DigitsCaps ===<br />
<br />
The last layer has 10 16D capsules, one for each digit. Unlike the PrimaryCapsules layer, this layer is fully connected. Since this is the top capsule layer, the dynamic routing mechanism is applied between PrimaryCapsules and DigitsCaps. The process begins by transforming the outputs of the PrimaryCapsules layer into prediction vectors. Each output is an 8-dimensional vector, which needs to be mapped to a 16-dimensional space. Therefore, each weight matrix <math>W_{ij}</math> is an <math>8 \times 16</math> matrix. The next step is to obtain the coupling coefficients from the routing algorithm and to apply "squashing" to get the output. <br />
<br />
* Layer Type: Fully connected layer.<br />
* Input: <math>6 \times 6 \times 32</math> 8-dimensional vectors.<br />
* Output: <math>16 \times 10 </math> matrix.<br />
<br />
=== The loss function ===<br />
<br />
The target used by the loss function is a ten-dimensional one-hot encoded vector, with 9 zeros and a single one at the position of the correct digit.<br />
<br />
<br />
== Regularization Method: Reconstruction ==<br />
<br />
Reconstruction is a regularization method introduced in the implementation of CapsNet. The method adds a reconstruction loss (scaled down by 0.0005) to the margin loss during training. The authors argue that this encourages the digit capsules to encode the instantiation parameters of the input digits. During training, all reconstructions use the true label of the input image. The experimental results also confirm that adding the reconstruction regularizer enforces the pose encoding in CapsNet and thus boosts the performance of the routing procedure. <br />
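<br />
A hedged sketch of how the two losses might be combined is shown below. It assumes the reconstruction loss is the sum of squared differences between the input pixels and the decoder output, and that masking keeps only the capsule of the true class. The function names are hypothetical.<br />

```python
import numpy as np

def mask_digit_caps(digit_caps, label):
    """Zero out every capsule except the one for the true label, then flatten."""
    masked = np.zeros_like(digit_caps)
    masked[label] = digit_caps[label]
    return masked.reshape(-1)

def total_loss(margin, image, reconstruction, scale=0.0005):
    """Margin loss plus the scaled-down sum of squared pixel differences."""
    reconstruction_loss = np.sum((image - reconstruction) ** 2)
    return margin + scale * reconstruction_loss

# Toy data: 10 capsules of dimension 16, and a 784-pixel image.
digit_caps = np.arange(1.0, 161.0).reshape(10, 16)
decoder_input = mask_digit_caps(digit_caps, label=3)
loss = total_loss(1.0, np.ones(784), np.zeros(784))
```

In the toy call every pixel differs by 1, so the reconstruction term is 0.0005 * 784 = 0.392. This shows how small the regularizer is relative to the margin loss.<br />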
<br />
=== Decoder ===<br />
<br />
The decoder consists of 3 fully connected layers that map the output of the digit capsules to pixel intensities. The size of each layer and the activation functions used are indicated in the figure below:<br />
<br />
[[File:DRBC_Decoder.png|650px|center||Source: Sabour, Frosst, Hinton, 2017]]<br />
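<br />
The decoder's forward pass can be sketched as below. The layer sizes (160, 512, 1024, 784) and the activations (ReLU, ReLU, sigmoid) are taken here as assumptions based on the figure in the paper. The random initialisation and the function names are purely illustrative.<br />

```python
import numpy as np

def relu(x):
    return np.maximum(0.0, x)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

# Assumed sizes: flattened 10 x 16 masked capsules in, 28 x 28 pixels out.
LAYER_SIZES = [160, 512, 1024, 784]

def init_decoder(rng):
    """Return a list of (weights, bias) pairs with small random weights."""
    return [(rng.normal(scale=0.01, size=(n_in, n_out)), np.zeros(n_out))
            for n_in, n_out in zip(LAYER_SIZES[:-1], LAYER_SIZES[1:])]

def decode(masked_caps, params):
    h = masked_caps
    for idx, (W, b) in enumerate(params):
        z = h @ W + b
        # The last layer uses a sigmoid so outputs lie in (0, 1) like pixel values.
        h = sigmoid(z) if idx == len(params) - 1 else relu(z)
    return h.reshape(28, 28)

params = init_decoder(np.random.default_rng(0))
image = decode(np.zeros(160), params)
```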
<br />
=== Result ===<br />
<br />
The authors include CapsNet classification test results to justify the use of reconstruction. For CapsNet with 1 routing iteration and CapsNet with 3 routing iterations, adding reconstruction yields significant improvements on both the MNIST and MultiMNIST data sets. These improvements show the importance of routing and the reconstruction regularizer. <br />
<br />
[[File:DRBC_Reconstruction.png|650px|center||Source: Sabour, Frosst, Hinton, 2017]]<br />
<br />
= Experiment Results for CapsNet = <br />
<br />
In this part, the authors present experimental results for CapsNet on several data sets: MNIST and variations of it, such as expanded MNIST, affNIST, and MultiMNIST. They also briefly discuss the performance on other popular data sets such as CIFAR10. <br />
<br />
== MNIST ==<br />
<br />
=== Highlights ===<br />
<br />
* CapsNet achieves state-of-the-art performance on MNIST with significantly fewer parameters (the 3-layer baseline CNN model has 35.4M parameters, compared to 8.2M for CapsNet with the reconstruction network).<br />
* CapsNet with a shallow structure (3 layers) achieves performance that was previously achieved only by deeper networks.<br />
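<br />
The 8.2M figure can be roughly reproduced from the layer descriptions in this summary. The count below is a back-of-the-envelope sketch under stated assumptions: biases on the convolutional and fully connected layers, no bias on the capsule transformation matrices, and decoder layers of 512, 1024 and 784 units.<br />

```python
# Conv1: 256 kernels of size 9 x 9 on a single input channel, plus biases.
conv1 = 9 * 9 * 1 * 256 + 256

# PrimaryCapsules: 32 channels x 8 units = 256 kernels of size 9 x 9 x 256, plus biases.
primary_caps = 9 * 9 * 256 * (32 * 8) + 32 * 8

# DigitsCaps: one 8 x 16 matrix for each of the 1152 x 10 capsule pairs.
digit_caps = 1152 * 10 * 8 * 16

# Decoder: fully connected layers 160 -> 512 -> 1024 -> 784, with biases.
decoder = (160 * 512 + 512) + (512 * 1024 + 1024) + (1024 * 784 + 784)

encoder_total = conv1 + primary_caps + digit_caps
total = encoder_total + decoder
```

Under these assumptions the total is 8,215,568 parameters, of which 1,411,344 belong to the decoder. This is consistent with the 8.2M quoted above.<br />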
<br />
=== Interpretation of Each Capsule ===<br />
<br />
The authors report evidence that some dimensions of a capsule consistently capture a particular kind of variation of the digit, while others represent global combinations of different variations, which opens up possibilities for interpreting capsules in the future. After computing the activity vector for the correct digit capsule, the authors fed perturbed versions of that activity vector to the decoder to examine the effect on the reconstruction. Some results from these perturbations are shown below, where each row shows the reconstructions when one of the 16 dimensions in the DigitCaps representation is tweaked in steps of 0.05 over the range [-0.25, 0.25]: <br />
<br />
[[File:DRBC_Dimension.png|650px|center||Source: Sabour, Frosst, Hinton, 2017]]<br />
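<br />
The perturbation procedure can be sketched as follows. The helper <code>perturbations</code> is a hypothetical name, and the zero activity vector is a stand-in for a real capsule output. Each returned row would be fed to the decoder to produce one reconstruction.<br />

```python
import numpy as np

def perturbations(activity, dim, low=-0.25, high=0.25, step=0.05):
    """Return copies of `activity` with dimension `dim` shifted across [low, high]."""
    n_steps = int(round((high - low) / step)) + 1
    deltas = np.linspace(low, high, n_steps)
    out = np.tile(activity, (n_steps, 1))   # one row per perturbation
    out[:, dim] += deltas
    return out

activity = np.zeros(16)            # stand-in for a 16D digit capsule output
perturbed = perturbations(activity, dim=5)
```

With a step of 0.05 over [-0.25, 0.25] there are 11 perturbed vectors per dimension, and only the chosen dimension changes.<br />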
<br />
== affNIST == <br />
<br />
The affNIST data set contains affine transformations of the original MNIST data set. According to the capsule concept, CapsNet should gain robustness from its equivariant nature, and the result confirms this: compared to the baseline CNN, CapsNet achieves a 13 percentage point improvement in accuracy.<br />
<br />
== MultiMNIST ==<br />
<br />
MultiMNIST is essentially an overlapped version of MNIST. An important point is that this data set is generated by overlaying a digit on top of another digit from the same set but of a different class. In other words, stacking digits from the same class is not allowed in MultiMNIST. For example, stacking a 5 on a 0 is allowed, but stacking a 5 on another 5 is not. The reason is that CapsNet suffers from the "crowding" effect, which is discussed in the section on the weaknesses of capsule networks. <br />
<br />
== Other data sets ==<br />
<br />
CapsNet is also applied to other data sets such as CIFAR10, smallNORB and SVHN. The results are not comparable with state-of-the-art performance, but they are still promising since this architecture is the very first of its kind, while other networks have been under development for a long time. The authors point out that one drawback of CapsNet is that it tends to account for everything in the input image: in the CIFAR10 data set, the image backgrounds were too varied to model in a reasonably sized network, which partly explains the poorer results.<br />
<br />
= Conclusion = <br />
<br />
This paper discusses a specific part of capsule networks, namely the routing-by-agreement mechanism. The authors suggest this is a promising approach to solving the current problem with max-pooling in convolutional neural networks. Moreover, as the authors mention, the approach in this paper is only one possible implementation of the capsule concept. The preliminary results from experiments using a simple shallow CapsNet also demonstrate strong performance, which indicates that capsules are a direction worth exploring. <br />
<br />
= Weakness of Capsule Network =<br />
<br />
* The routing algorithm introduces internal loops for each capsule. As the number of capsules and layers increases, these internal loops may substantially increase the training time. <br />
* Capsule networks suffer from a perceptual phenomenon called "crowding", which is common in human vision as well. To address this weakness, capsules have to make a very strong representational assumption: at each location of the image, there is at most one instance of the type of entity that a capsule represents. This is also the reason for not allowing overlaid digits from the same class when generating MultiMNIST.<br />
* Other criticisms include that the design of capsule networks requires domain knowledge or feature engineering, contrary to the abstraction-oriented goals of deep learning.<br />
<br />
= Implementations = <br />
1) TensorFlow implementation: https://github.com/naturomics/CapsNet-Tensorflow<br />
<br />
2) Keras implementation: https://github.com/XifengGuo/CapsNet-Keras<br />
<br />
= References =<br />
# S. Sabour, N. Frosst, and G. E. Hinton, “Dynamic routing between capsules,” arXiv preprint arXiv:1710.09829v2, 2017<br />
# “XifengGuo/CapsNet-Keras.” GitHub, 14 Dec. 2017, github.com/XifengGuo/CapsNet-Keras. <br />
# “Naturomics/CapsNet-Tensorflow.” GitHub, 6 Mar. 2018, github.com/naturomics/CapsNet-Tensorflow.</div>X249wanghttp://wiki.math.uwaterloo.ca/statwiki/index.php?title=Dynamic_Routing_Between_Capsules_STAT946&diff=36072Dynamic Routing Between Capsules STAT9462018-04-03T15:57:35Z<p>X249wang: /* Other data sets */</p>
<hr />
<div>= Presented by =<br />
<br />
Yang, Tong(Richard)<br />
<br />
= Contributions =<br />
<br />
This paper introduces the concept of "capsules" and an approach to implementing them in neural networks. A capsule is a group of neurons used to represent various properties of an entity or object present in the image, such as pose, deformation, and even the existence of the entity. Instead of the obvious choice of a logistic unit for the probability of existence, the paper explores using the length of the capsule output vector to represent existence, and its orientation to represent other properties of the entity. The paper makes the following major contributions:<br />
<br />
* Proposed an alternative approach to max-pooling, which is called routing-by-agreement.<br />
* Demonstrated a mathematical structure for capsule layers and a routing mechanism that together form a prototype architecture for capsule networks. <br />
* Presented promising results for CapsNet that support its value as a new direction for development in deep learning.<br />
<br />
= Hinton's Critiques on CNN =<br />
<br />
In an earlier talk, Hinton explained why he considers max-pooling the biggest problem in current convolutional network architectures. Here are some highlights from that talk. <br />
<br />
== Four arguments against pooling ==<br />
<br />
* It is a bad fit to the psychology of shape perception: It does not explain why we assign intrinsic coordinate frames to objects and why they have such huge effects.<br />
<br />
* It solves the wrong problem: We want equivariance, not invariance. Disentangling rather than discarding.<br />
<br />
* It fails to use the underlying linear structure: It does not make use of the natural linear manifold that perfectly handles the largest source of variance in images.<br />
<br />
* Pooling is a poor way to do dynamic routing: We need to route each part of the input to the neurons that know how to deal with it. Finding the best routing is equivalent to parsing the image.<br />
<br />
===Intuition Behind Capsules ===<br />
Max-pooling is used to achieve viewpoint invariance in the activities of neurons. Invariance here means that when the input changes a little, the output stays the same, where the activity is simply the output signal of a neuron. In other words, when the object we want to detect is shifted slightly in the input image, the network's activities (the outputs of its neurons) do not change because of max-pooling, and the network still detects the object. However, spatial relationships are not preserved by this approach, so capsules are used instead: they encapsulate all the important information about the state of the features they detect in the form of a vector. A capsule encodes the probability that a feature is detected as the length of its output vector, and the state of the detected feature as the direction in which that vector points. So when a detected feature moves around the image or its state changes, the probability stays the same (the length of the vector does not change), but the orientation of the vector changes.<br />
<br />
== Equivariance ==<br />
<br />
To deal with the invariance problem of CNNs, Hinton proposes the concept of equivariance, which is the foundation of the capsule concept.<br />
<br />
=== Two types of equivariance ===<br />
<br />
==== Place-coded equivariance ====<br />
If a low-level part moves to a very different position it will be represented by a different capsule.<br />
<br />
==== Rate-coded equivariance ====<br />
If a part only moves a small distance it will be represented by the same capsule but the pose outputs of the capsule will change.<br />
<br />
Higher-level capsules have bigger domains so low-level place-coded equivariance gets converted into high-level rate-coded equivariance.<br />
<br />
= Dynamic Routing =<br />
<br />
In the second section of the paper, the authors give a mathematical description of two key components of the routing algorithm in a capsule network: squashing and agreement. The algorithm is described for two arbitrary capsules i and j, where capsule j is an arbitrary capsule in a given capsule layer and capsule i is an arbitrary capsule in the layer below. The purpose of the routing algorithm is to compute the vector output of capsule j together with the coupling between capsule i and capsule j. These quantities determine how strongly the output of capsule i is routed to capsule j. <br />
<br />
== Routing Algorithm ==<br />
<br />
The routing algorithm is as follows:<br />
<br />
[[File:DRBC_Figure_1.png|650px|center||Source: Sabour, Frosst, Hinton, 2017]]<br />
<br />
In the following sections, each part of this algorithm is explained in detail.<br />
<br />
=== Log Prior Probability ===<br />
<br />
<math>b_{ij}</math> represents the log prior probability that capsule i should be coupled to capsule j, and it is updated in each routing iteration. As line 2 of the algorithm indicates, the initial values of <math>b_{ij}</math> for all pairs of capsules are set to 0, so in the very first routing iteration <math>b_{ij}</math> equals zero. In each routing iteration, <math>b_{ij}</math> is updated by adding the agreement value, which is explained later.<br />
<br />
=== Coupling Coefficient === <br />
<br />
<math>c_{ij}</math> represents the coupling coefficient between capsule j and capsule i. It is calculated by applying the softmax function on the log prior probability <math>b_{ij}</math>. The mathematical transformation is shown below (Equation 3 in paper): <br />
<br />
\begin{align}<br />
c_{ij} = \frac{\exp(b_{ij})}{\sum_{k}\exp(b_{ik})}<br />
\end{align}<br />
<br />
The <math>c_{ij}</math> serve as weights for computing the weighted sum of prediction vectors, and they can be interpreted as probabilities. As probabilities, they have the following properties:<br />
<br />
\begin{align}<br />
c_{ij} \geq 0, \forall i, j<br />
\end{align}<br />
<br />
and, <br />
<br />
\begin{align}<br />
\sum_{j}c_{ij} = 1, \forall i<br />
\end{align}<br />
<br />
=== Predicted Output from Layer Below === <br />
<br />
<math>u_{i}</math> is the output vector of capsule i in the lower layer, and <math>\hat{u}_{j|i}</math> is the input vector for capsule j, called the "prediction vector" from capsule i in the layer below. <math>\hat{u}_{j|i}</math> is produced by multiplying <math>u_{i}</math> by a weight matrix <math>W_{ij}</math>:<br />
<br />
\begin{align}<br />
\hat{u}_{j|i} = W_{ij}u_i<br />
\end{align}<br />
<br />
where <math>W_{ij}</math> encodes some spatial relationship between capsule j and capsule i.<br />
<br />
=== Capsule ===<br />
<br />
By using the definitions from previous sections, the total input vector for an arbitrary capsule j can be defined as:<br />
<br />
\begin{align}<br />
s_j = \sum_{i}c_{ij}\hat{u}_{j|i}<br />
\end{align}<br />
<br />
which is a weighted sum over all prediction vectors by using coupling coefficients.<br />
<br />
=== Squashing ===<br />
<br />
The length of <math>s_j</math> is unbounded, which needs to be addressed. The next step is to map its length into the range between 0 and 1, since we want the length of the output vector of a capsule to represent the probability that the entity represented by the capsule is present in the current input. The "squashing" function is shown below:<br />
<br />
\begin{align}<br />
v_j = \frac{||s_j||^2}{1+||s_j||^2}\frac{s_j}{||s_j||}<br />
\end{align}<br />
<br />
Notice that "squashing" does not simply normalize the vector to unit length. It applies a non-linear transformation that preserves the direction of the vector while ensuring that short vectors are shrunk to almost zero length and long vectors are shrunk to a length slightly below 1. This makes the routing decision between capsule layers, called "routing by agreement", much easier to make.<br />
<br />
=== Agreement ===<br />
<br />
The final step of a routing iteration is to compute the routing agreement <math>a_{ij}</math>, which is a scalar product:<br />
<br />
\begin{align}<br />
a_{ij} = v_{j} \cdot \hat{u}_{j|i}<br />
\end{align}<br />
<br />
As mentioned in the "Squashing" section, the length of <math>v_{j}</math> is typically pushed close to 0 or close to 1, which affects the magnitude of <math>a_{ij}</math>. The value of <math>a_{ij}</math> therefore indicates how strongly the routing algorithm agrees on taking the route between capsule i and capsule j. In each routing iteration, the log prior probability <math>b_{ij}</math> is updated by adding the agreement value, which affects how the coupling coefficients are computed in the next routing iteration. Because of the "squashing" process, we eventually end up with a capsule j whose <math>v_{j}</math> has length close to 1 while all other capsules have <math>v_{j}</math> with length close to 0, which indicates that this capsule j should be activated.<br />
<br />
= CapsNet Architecture =<br />
<br />
The second part of the paper discusses experimental results from a 3-layer CapsNet. The architecture can be divided into two parts: an encoder and a decoder. <br />
<br />
== Encoder == <br />
<br />
[[File:DRBC_Architecture.png|650px|center||Source: Sabour, Frosst, Hinton, 2017]]<br />
<br />
=== How many routing iterations to use? === <br />
In Appendix A of the paper, the authors show empirical results from 500 epochs of training with different numbers of routing iterations. According to their observations, more routing iterations increase the capacity of CapsNet but tend to bring additional risk of overfitting. Moreover, CapsNet with fewer than three routing iterations is not effective in general. As a result, they suggest 3 routing iterations for all experiments.<br />
<br />
=== Margin loss for digit existence ===<br />
<br />
The experiments include segmenting overlapping digits on the MultiMNIST data set, so the loss function has to allow for the presence of multiple digits. The margin loss <math>L_k</math> for each digit capsule k is calculated by:<br />
<br />
\begin{align}<br />
L_k = T_k \max(0, m^+ - ||v_k||)^2 + \lambda(1 - T_k) \max(0, ||v_k|| - m^-)^2<br />
\end{align}<br />
<br />
where <math>m^+ = 0.9</math>, <math>m^- = 0.1</math>, and <math>\lambda = 0.5</math>.<br />
<br />
<math>T_k</math> is an indicator for the presence of a digit of class k: it takes the value 1 if and only if class k is present. If class k is absent, <math>\lambda</math> down-weights its loss term, which stops the initial learning from shrinking the lengths of the activity vectors of all the digit capsules. The loss is designed this way because we would like the top-level capsule for digit class k to have a long instantiation vector if and only if that digit class is present in the input.<br />
<br />
=== Layer 1: Conv1 === <br />
<br />
Conv1 is the first layer of CapsNet. As in a CNN, it is simply a convolutional layer that converts pixel intensities to the activities of local feature detectors. <br />
<br />
* Layer Type: Convolutional Layer.<br />
* Input: <math>28 \times 28</math> pixels.<br />
* Kernel size: <math>9 \times 9</math>.<br />
* Number of Kernels: 256.<br />
* Activation function: ReLU.<br />
* Output: <math>20 \times 20 \times 256</math> tensor.<br />
<br />
=== Layer 2: PrimaryCapsules ===<br />
<br />
The second layer consists of 32 channels of primary 8D capsules. Here 8D means that each primary capsule contains 8 convolutional units with a <math>9 \times 9</math> kernel and a stride of 2. Each capsule channel takes the <math>20 \times 20 \times 256</math> tensor from Conv1 and produces a <math>6 \times 6 \times 8</math> output tensor.<br />
<br />
* Layer Type: Convolutional Layer<br />
* Input: <math>20 \times 20 \times 256</math> tensor.<br />
* Number of capsules: 32.<br />
* Number of convolutional units in each capsule: 8.<br />
* Size of each convolutional unit: <math>6 \times 6</math>.<br />
* Output: <math>32 \times 6 \times 6</math> (1152 in total) 8-dimensional vectors.<br />
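<br />
The spatial sizes quoted for Conv1 and PrimaryCapsules follow from the usual output-size formula for a convolution without padding. The sketch below checks that arithmetic; the helper name <code>conv_output_size</code> is hypothetical, and the Conv1 stride of 1 and the absence of padding are assumptions taken from the paper rather than from the lists above.<br />

```python
def conv_output_size(input_size, kernel_size, stride):
    """Output width of an unpadded convolution along one spatial axis."""
    return (input_size - kernel_size) // stride + 1


# Conv1: 28x28 input, 9x9 kernel, stride 1 (assumed) gives a 20x20 grid.
conv1_size = conv_output_size(28, 9, 1)

# PrimaryCapsules: 20x20 input, 9x9 kernel, stride 2 gives a 6x6 grid.
primary_size = conv_output_size(conv1_size, 9, 2)

# 32 capsule channels on a 6x6 grid, each emitting one 8-dimensional vector.
num_primary_capsules = 32 * primary_size * primary_size
```

With these numbers the DigitsCaps layer receives one 8-dimensional vector from each of the <code>num_primary_capsules</code> primary capsules.<br />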
<br />
=== Layer 3: DigitsCaps ===<br />
<br />
The last layer has 10 16D capsules, one for each digit. Unlike the PrimaryCapsules layer, this layer is fully connected. Since this is the top capsule layer, the dynamic routing mechanism is applied between PrimaryCapsules and DigitsCaps. The process begins by transforming the outputs of the PrimaryCapsules layer into prediction vectors. Each output is an 8-dimensional vector, which needs to be mapped to a 16-dimensional space. Therefore, each weight matrix <math>W_{ij}</math> is an <math>8 \times 16</math> matrix. The next step is to obtain the coupling coefficients from the routing algorithm and to apply "squashing" to get the output. <br />
<br />
* Layer Type: Fully connected layer.<br />
* Input: <math>32 \times 6 \times 6</math> (1152 in total) 8-dimensional vectors.<br />
* Output: <math>16 \times 10 </math> matrix.<br />
<br />
=== The loss function ===<br />
<br />
The target used by the loss function is a ten-dimensional one-hot encoded vector, with 9 zeros and a single 1 at the position of the correct digit.<br />
<br />
<br />
== Regularization Method: Reconstruction ==<br />
<br />
Reconstruction is the regularization method introduced in the implementation of CapsNet. The method adds a reconstruction loss (scaled down by 0.0005) to the margin loss during training. The authors argue that this encourages the digit capsules to encode the instantiation parameters of the input digits. During training, all reconstructions use the true label of the input image. The experimental results also confirm that adding the reconstruction regularizer enforces the pose encoding in CapsNet and thus boosts the performance of the routing procedure. <br />
<br />
=== Decoder ===<br />
<br />
The decoder consists of 3 fully connected layers that map the output of the correct digit capsule to pixel intensities. The size of each layer and the activation functions used are indicated in the figure below:<br />
<br />
[[File:DRBC_Decoder.png|650px|center||Source: Sabour, Frosst, Hinton, 2017]]<br />
<br />
=== Result ===<br />
<br />
The authors include some CapsNet classification test results to justify the use of reconstruction. For CapsNet with 1 routing iteration and CapsNet with 3 routing iterations, adding reconstruction shows significant improvements on both the MNIST and MultiMNIST data sets. These improvements show the importance of routing and of the reconstruction regularizer. <br />
<br />
[[File:DRBC_Reconstruction.png|650px|center||Source: Sabour, Frosst, Hinton, 2017]]<br />
<br />
= Experiment Results for CapsNet = <br />
<br />
In this part, the authors report experimental results of CapsNet on different data sets, such as MNIST and variations of MNIST, including expanded MNIST, affNIST, and MultiMNIST. They also briefly discuss the performance on other popular data sets such as CIFAR10. <br />
<br />
== MNIST ==<br />
<br />
=== Highlights ===<br />
<br />
* CapsNet achieves state-of-the-art performance on MNIST with significantly fewer parameters (the 3-layer baseline CNN model has 35.4M parameters, compared to 8.2M for CapsNet with the reconstruction network).<br />
* CapsNet with a shallow structure (3 layers) achieves performance that was previously achieved only by deeper networks.<br />
<br />
=== Interpretation of Each Capsule ===<br />
<br />
The authors report evidence that some dimensions of a capsule always capture one kind of variation of the digit, while others represent global combinations of different variations. This opens up possibilities for interpreting capsules in the future. After computing the activity vector for the correct digit capsule, the authors fed perturbed versions of this activity vector to the decoder to examine the effect on the reconstruction. Some results of these perturbations are shown below, where each row shows the reconstructions when one of the 16 dimensions in the DigitCaps representation is tweaked in steps of 0.05 over the range [-0.25, 0.25]: <br />
<br />
[[File:DRBC_Dimension.png|650px|center||Source: Sabour, Frosst, Hinton, 2017]]<br />
<br />
== affNIST == <br />
<br />
The affNIST data set contains different affine transformations of the original MNIST data set. By the concept of capsules, CapsNet should gain robustness from its equivariant nature, and the result confirms this. Compared with the baseline CNN, CapsNet achieves a 13% improvement in accuracy.<br />
<br />
== MultiMNIST ==<br />
<br />
MultiMNIST is essentially an overlapped version of MNIST. An important point is that this data set is generated by overlaying a digit on top of another digit from the same set but of a different class. In other words, stacking digits from the same class is not allowed in MultiMNIST. For example, stacking a 5 on a 0 is allowed, but stacking a 5 on another 5 is not. The reason is that CapsNet suffers from the "crowding" effect, which is discussed in the section on the weaknesses of capsule networks. <br />
<br />
== Other data sets ==<br />
<br />
CapsNet is also applied to other data sets such as CIFAR10, smallNORB and SVHN. The results are not comparable with state-of-the-art performance, but they are still promising since this architecture is the very first of its kind, while other networks have been in development for a long time. The authors point out that one drawback of CapsNet is that it tends to account for everything in the input image. For example, in the CIFAR10 data set, the image backgrounds were too varied to model in a reasonably sized network, which partly explains the poorer results.<br />
<br />
= Conclusion = <br />
<br />
This paper discusses a specific part of capsule networks, namely the routing-by-agreement mechanism. The authors suggest that it is a promising approach to the current problems with max-pooling in convolutional neural networks. Moreover, as the authors mention, the approach in this paper is only one possible implementation of the capsule concept. The preliminary results from experiments with a simple shallow CapsNet also demonstrate performance indicating that capsules are a direction worth exploring. <br />
<br />
= Weakness of Capsule Network =<br />
<br />
* The routing algorithm introduces internal loops for each capsule. As the number of capsules and layers increases, these internal loops may considerably increase the training time. <br />
* Capsule networks suffer from a perceptual phenomenon called "crowding", which is common in human vision as well. To address this weakness, capsules have to make a very strong representational assumption: at each location of the image, there is at most one instance of the type of entity that a capsule represents. This is also the reason for not allowing overlaid digits from the same class when generating MultiMNIST.<br />
* Other criticisms include that the design of capsule networks requires domain knowledge or feature engineering, contrary to the abstraction-oriented goals of deep learning.<br />
<br />
= Implementations = <br />
1) Tensorflow Implementation : https://github.com/naturomics/CapsNet-Tensorflow<br />
<br />
2) Keras Implementation. : https://github.com/XifengGuo/CapsNet-Keras<br />
<br />
= References =<br />
# S. Sabour, N. Frosst, and G. E. Hinton, “Dynamic routing between capsules,” arXiv preprint arXiv:1710.09829v2, 2017<br />
# “XifengGuo/CapsNet-Keras.” GitHub, 14 Dec. 2017, github.com/XifengGuo/CapsNet-Keras. <br />
# “Naturomics/CapsNet-Tensorflow.” GitHub, 6 Mar. 2018, github.com/naturomics/CapsNet-Tensorflow.</div>X249wanghttp://wiki.math.uwaterloo.ca/statwiki/index.php?title=Dynamic_Routing_Between_Capsules_STAT946&diff=36071Dynamic Routing Between Capsules STAT9462018-04-03T15:44:52Z<p>X249wang: /* Highlights */</p>
<hr />
<div>= Presented by =<br />
<br />
Yang, Tong(Richard)<br />
<br />
= Contributions =<br />
<br />
This paper introduces the concept of "capsules" and an approach to implementing this concept in neural networks. A capsule is a group of neurons used to represent various properties of an entity/object present in the image, such as pose, deformation, and even the existence of the entity. Instead of the obvious representation of a logistic unit for the probability of existence, the paper explores using the length of the capsule output vector to represent existence, and the orientation to represent other properties of the entity. The paper has the following major contributions:<br />
<br />
* Proposed an alternative approach to max-pooling, which is called routing-by-agreement.<br />
* Demonstrated a mathematical structure for capsule layers and a routing mechanism that builds a prototype architecture for capsule networks. <br />
* Presented promising results for CapsNet that confirm its value as a new direction for development in deep learning.<br />
<br />
= Hinton's Critiques on CNN =<br />
<br />
In a past talk, Hinton tried to explain why max-pooling is the biggest problem in the current convolutional network structure. Here are some highlights from his talk. <br />
<br />
== Four arguments against pooling ==<br />
<br />
* It is a bad fit to the psychology of shape perception: It does not explain why we assign intrinsic coordinate frames to objects and why they have such huge effects.<br />
<br />
* It solves the wrong problem: We want equivariance, not invariance. Disentangling rather than discarding.<br />
<br />
* It fails to use the underlying linear structure: It does not make use of the natural linear manifold that perfectly handles the largest source of variance in images.<br />
<br />
* Pooling is a poor way to do dynamic routing: We need to route each part of the input to the neurons that know how to deal with it. Finding the best routing is equivalent to parsing the image.<br />
<br />
===Intuition Behind Capsules ===<br />
We try to achieve viewpoint invariance in the activities of neurons by doing max-pooling. Invariance here means that when the input changes a little, the output stays the same; the activity is simply the output signal of a neuron. In other words, when the object we want to detect is shifted slightly in the input image, the network's activities (outputs of neurons) do not change because of max-pooling, and the network still detects the object. However, spatial relationships are not taken care of in this approach, so capsules are used instead, because they encapsulate all the important information about the state of the features they detect in the form of a vector. Capsules encode the probability of detection of a feature as the length of their output vector, and the state of the detected feature as the direction in which that vector points. So when a detected feature moves around the image or its state changes, the probability stays the same (the length of the vector does not change), but the orientation of the vector changes.<br />
<br />
== Equivariance ==<br />
<br />
To deal with the invariance problem of CNNs, Hinton proposes the concept of equivariance, which is the foundation of the capsule concept.<br />
<br />
=== Two types of equivariance ===<br />
<br />
==== Place-coded equivariance ====<br />
If a low-level part moves to a very different position it will be represented by a different capsule.<br />
<br />
==== Rate-coded equivariance ====<br />
If a part only moves a small distance it will be represented by the same capsule but the pose outputs of the capsule will change.<br />
<br />
Higher-level capsules have bigger domains so low-level place-coded equivariance gets converted into high-level rate-coded equivariance.<br />
<br />
= Dynamic Routing =<br />
<br />
In the second section of the paper, the authors give a mathematical representation of two key features of the routing algorithm in capsule networks: squashing and agreement. The general setting for this algorithm is between two arbitrary capsules i and j. Capsule j is assumed to be an arbitrary capsule from the first layer of capsules, and capsule i is an arbitrary capsule from the layer below. The purpose of the routing algorithm is to generate a vector output for the routing decision between capsule i and capsule j. This vector output is then used to decide the dynamic routing. <br />
<br />
== Routing Algorithm ==<br />
<br />
The routing algorithm is as follows:<br />
<br />
[[File:DRBC_Figure_1.png|650px|center||Source: Sabour, Frosst, Hinton, 2017]]<br />
<br />
In the following sections, each part of this algorithm is explained in detail.<br />
<br />
=== Log Prior Probability ===<br />
<br />
<math>b_{ij}</math> represents the log prior probability that capsule i should be coupled to capsule j, and it is updated in each routing iteration. As line 2 of the algorithm suggests, the initial values of <math>b_{ij}</math> for all possible pairs of capsules are set to 0, so in the very first routing iteration <math>b_{ij}</math> equals zero. In each routing iteration, <math>b_{ij}</math> is updated by the value of the agreement, which is explained later.<br />
<br />
=== Coupling Coefficient === <br />
<br />
<math>c_{ij}</math> represents the coupling coefficient between capsule j and capsule i. It is calculated by applying the softmax function on the log prior probability <math>b_{ij}</math>. The mathematical transformation is shown below (Equation 3 in paper): <br />
<br />
\begin{align}<br />
c_{ij} = \frac{\exp(b_{ij})}{\sum_{k}\exp(b_{ik})}<br />
\end{align}<br />
<br />
<math>c_{ij}</math> serve as weights for computing the weighted sum and can be read as probabilities. Therefore, they have the following properties:<br />
<br />
\begin{align}<br />
c_{ij} \geq 0, \forall i, j<br />
\end{align}<br />
<br />
and, <br />
<br />
\begin{align}<br />
\sum_{j}c_{ij} = 1, \forall i<br />
\end{align}<br />
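<br />
As a small illustration of the softmax step, the sketch below computes the coupling coefficients of a single lower-level capsule i from its log priors. The function name <code>coupling_coefficients</code> is hypothetical, and subtracting the maximum logit is a standard numerical-stability trick that is not mentioned in the text above.<br />

```python
import math


def coupling_coefficients(b_i):
    """Softmax over the log priors b_i[j] of one lower capsule i.

    The result c_i[j] is non-negative and sums to 1 over the upper capsules j.
    """
    shift = max(b_i)  # subtracting the maximum does not change the softmax
    exps = [math.exp(b - shift) for b in b_i]
    total = sum(exps)
    return [e / total for e in exps]


# In the first routing iteration all log priors are zero, so a lower capsule
# distributes its output uniformly over the 10 digit capsules.
initial = coupling_coefficients([0.0] * 10)
```

After later iterations the log priors differ, and the capsule with the largest <math>b_{ij}</math> receives the largest share.<br />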
<br />
=== Predicted Output from Layer Below === <br />
<br />
<math>u_{i}</math> is the output vector of capsule i in the lower layer, and <math>\hat{u}_{j|i}</math> is the input vector for capsule j, that is, the "prediction vector" from capsule i in the layer below. <math>\hat{u}_{j|i}</math> is produced by multiplying <math>u_{i}</math> by a weight matrix <math>W_{ij}</math>, as follows:<br />
<br />
\begin{align}<br />
\hat{u}_{j|i} = W_{ij}u_i<br />
\end{align}<br />
<br />
where <math>W_{ij}</math> encodes some spatial relationship between capsule j and capsule i.<br />
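<br />
The prediction vector is an ordinary matrix-vector product. The sketch below uses a toy matrix that maps a 2-dimensional capsule output to a 3-dimensional prediction; the function name <code>predict</code> and the toy values are hypothetical. The matrix is stored as a list of rows, so in CapsNet it would have 16 rows and 8 columns under this storage convention, which is an assumption of the sketch.<br />

```python
def predict(W_ij, u_i):
    """Return the prediction vector u_hat_{j|i} = W_ij u_i.

    W_ij is a list of rows; each row has as many entries as u_i.
    """
    return [sum(w * x for w, x in zip(row, u_i)) for row in W_ij]


# Toy example: map a 2-dimensional output to a 3-dimensional prediction.
W = [
    [1.0, 0.0],
    [0.0, 1.0],
    [1.0, 1.0],
]
u_hat = predict(W, [2.0, 3.0])
```

Each pair (i, j) has its own matrix, so a lower capsule makes a separate prediction for every capsule in the layer above.<br />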
<br />
=== Capsule ===<br />
<br />
By using the definitions from previous sections, the total input vector for an arbitrary capsule j can be defined as:<br />
<br />
\begin{align}<br />
s_j = \sum_{i}c_{ij}\hat{u}_{j|i}<br />
\end{align}<br />
<br />
which is a weighted sum over all prediction vectors by using coupling coefficients.<br />
<br />
=== Squashing ===<br />
<br />
The length of <math>s_j</math> is arbitrary, which needs to be addressed. The next step is to map its length into the range between 0 and 1, since we want the length of the output vector of a capsule to represent the probability that the entity represented by the capsule is present in the current input. The "squashing" process is shown below:<br />
<br />
\begin{align}<br />
v_j = \frac{||s_j||^2}{1+||s_j||^2}\frac{s_j}{||s_j||}<br />
\end{align}<br />
<br />
Notice that "squashing" is not just normalizing the vector to unit length. It also applies a non-linear transformation that ensures that short vectors get shrunk to almost zero length and long vectors get shrunk to a length slightly below 1. The reason for doing this is to make the routing decision, which is called "routing by agreement", much easier to make between capsule layers.<br />
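<br />
A minimal sketch of the squashing function in plain Python is given below. The name <code>squash</code> and the small constant <code>eps</code>, which guards against division by zero for an all-zero input, are assumptions of this sketch and are not part of the formula above.<br />

```python
import math


def squash(s, eps=1e-9):
    """Shrink s to length |s|^2 / (1 + |s|^2) while keeping its direction."""
    norm_sq = sum(x * x for x in s)
    norm = math.sqrt(norm_sq)
    # eps only matters when s is the zero vector.
    scale = norm_sq / (1.0 + norm_sq) / (norm + eps)
    return [scale * x for x in s]


def length(v):
    return math.sqrt(sum(x * x for x in v))


long_vector = squash([3.0, 4.0])   # input length 5, output length about 25/26
short_vector = squash([0.1, 0.0])  # input length 0.1, output length about 0.0099
```

The two examples show the behaviour described above: the long vector ends up just below length 1, while the short vector is pushed towards zero.<br />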
<br />
=== Agreement ===<br />
<br />
The final step of a routing iteration is to form a routing agreement <math>a_{ij}</math>, which is represented as a scalar product:<br />
<br />
\begin{align}<br />
a_{ij} = v_{j} \cdot \hat{u}_{j|i}<br />
\end{align}<br />
<br />
As mentioned in the "squashing" section, the length of <math>v_{j}</math> is either close to 0 or close to 1, which affects the magnitude of <math>a_{ij}</math>. Therefore, the magnitude of <math>a_{ij}</math> indicates how strongly the routing algorithm agrees on taking the route between capsule i and capsule j. In each routing iteration, the log prior probability <math>b_{ij}</math> is updated by adding the agreement value, which affects how the coupling coefficients are computed in the next routing iteration. Because of the "squashing" process, we eventually end up with a capsule j whose <math>v_{j}</math> has a length close to 1 while all other capsules have a <math>v_{j}</math> with a length close to 0, which indicates that this capsule j should be activated.<br />
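<br />
Putting the pieces together, one possible reading of the whole routing loop is sketched below in plain Python. It is an illustration under stated assumptions, not the authors' implementation: the function names, the nested-list layout <code>u_hat[i][j]</code> and the toy prediction vectors are all hypothetical, and the prediction vectors are taken as given instead of being computed from weight matrices.<br />

```python
import math


def softmax(logits):
    """Coupling coefficients of one lower capsule over all upper capsules."""
    shift = max(logits)
    exps = [math.exp(b - shift) for b in logits]
    total = sum(exps)
    return [e / total for e in exps]


def squash(s, eps=1e-9):
    """Shrink s to length |s|^2 / (1 + |s|^2) while keeping its direction."""
    norm_sq = sum(x * x for x in s)
    scale = norm_sq / (1.0 + norm_sq) / (math.sqrt(norm_sq) + eps)
    return [scale * x for x in s]


def dynamic_routing(u_hat, num_iterations=3):
    """u_hat[i][j] is the prediction of lower capsule i for upper capsule j."""
    n_lower, n_upper, dim = len(u_hat), len(u_hat[0]), len(u_hat[0][0])
    b = [[0.0] * n_upper for _ in range(n_lower)]  # log priors start at zero
    for _ in range(num_iterations):
        c = [softmax(row) for row in b]
        v = []
        for j in range(n_upper):
            # Weighted sum of the predictions for capsule j, then squashing.
            s_j = [sum(c[i][j] * u_hat[i][j][d] for i in range(n_lower))
                   for d in range(dim)]
            v.append(squash(s_j))
        for i in range(n_lower):
            for j in range(n_upper):
                # Agreement is the scalar product of output and prediction.
                b[i][j] += sum(p * q for p, q in zip(u_hat[i][j], v[j]))
    return v, c


# Two lower capsules agree about upper capsule 0 and disagree about capsule 1.
toy_u_hat = [
    [[1.0, 0.0], [1.0, 0.0]],
    [[1.0, 0.0], [-1.0, 0.0]],
]
v, c = dynamic_routing(toy_u_hat)
```

In the toy example the conflicting predictions for upper capsule 1 cancel out, so its output stays short, while the coupling coefficients towards upper capsule 0 grow with every iteration.<br />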
<br />
= CapsNet Architecture =<br />
<br />
The second part of the paper discusses the experimental results from a 3-layer CapsNet. The architecture can be divided into two parts, an encoder and a decoder. <br />
<br />
== Encoder == <br />
<br />
[[File:DRBC_Architecture.png|650px|center||Source: Sabour, Frosst, Hinton, 2017]]<br />
<br />
=== How many routing iterations to use? === <br />
In Appendix A of the paper, the authors show the empirical results from 500 epochs of training with different numbers of routing iterations. According to their observations, more routing iterations increase the capacity of CapsNet but tend to bring an additional risk of overfitting. Moreover, CapsNet with fewer than three routing iterations is not effective in general. As a result, they suggest 3 routing iterations for all experiments.<br />
<br />
=== Margin loss for digit existence ===<br />
<br />
The experiments include segmenting overlapping digits on the MultiMNIST data set, so the loss function has to be adjusted for the presence of multiple digits. The margin loss <math>L_k</math> for each digit capsule k is calculated by:<br />
<br />
\begin{align}<br />
L_k = T_k \max(0, m^+ - \|v_k\|)^2 + \lambda (1 - T_k) \max(0, \|v_k\| - m^-)^2<br />
\end{align}<br />
<br />
where <math>m^+ = 0.9</math>, <math>m^- = 0.1</math>, and <math>\lambda = 0.5</math>.<br />
<br />
<math>T_k</math> is an indicator for the presence of a digit of class k: it takes the value 1 if and only if class k is present. If class k is absent, <math>\lambda</math> down-weights the loss. This down-weighting stops the initial learning from shrinking the lengths of the activity vectors of all the digit capsules. The loss is designed this way because we would like the top-level capsule for digit class k to have a long instantiation vector if and only if that digit class is present in the input. The total loss is the sum of the losses of all digit capsules.<br />
<br />
=== Layer 1: Conv1 === <br />
<br />
Conv1 is the first layer of CapsNet. As in a CNN, it is an ordinary convolutional layer that converts pixel intensities into the activities of local feature detectors. <br />
<br />
* Layer Type: Convolutional Layer.<br />
* Input: <math>28 \times 28</math> pixels.<br />
* Kernel size: <math>9 \times 9</math>.<br />
* Number of Kernels: 256.<br />
* Activation function: ReLU.<br />
* Output: <math>20 \times 20 \times 256</math> tensor.<br />
<br />
=== Layer 2: PrimaryCapsules ===<br />
<br />
The second layer is formed by 32 primary 8D capsules. By 8D, it means that each primary capsule contains 8 convolutional units with a <math>9 \times 9</math> kernel and a stride of 2. Each capsule will take a <math>20 \times 20 \times 256</math> tensor from Conv1 and produce an output of a <math>6 \times 6 \times 8</math> tensor.<br />
<br />
* Layer Type: Convolutional Layer<br />
* Input: <math>20 \times 20 \times 256</math> tensor.<br />
* Number of capsules: 32.<br />
* Number of convolutional units in each capsule: 8.<br />
* Size of each convolutional unit: <math>6 \times 6</math>.<br />
* Output: <math>32 \times 6 \times 6</math> (1152 in total) 8-dimensional vectors.<br />
<br />
=== Layer 3: DigitsCaps ===<br />
<br />
The last layer has 10 16D capsules, one for each digit. Unlike the PrimaryCapsules layer, this layer is fully connected. Since this is the top capsule layer, the dynamic routing mechanism is applied between PrimaryCapsules and DigitsCaps. The process begins by transforming the outputs of the PrimaryCapsules layer into prediction vectors. Each output is an 8-dimensional vector, which needs to be mapped to a 16-dimensional space. Therefore, each weight matrix <math>W_{ij}</math> is an <math>8 \times 16</math> matrix. The next step is to obtain the coupling coefficients from the routing algorithm and to apply "squashing" to get the output. <br />
<br />
* Layer Type: Fully connected layer.<br />
* Input: <math>32 \times 6 \times 6</math> (1152 in total) 8-dimensional vectors.<br />
* Output: <math>16 \times 10 </math> matrix.<br />
<br />
=== The loss function ===<br />
<br />
The target used by the loss function is a ten-dimensional one-hot encoded vector, with 9 zeros and a single 1 at the position of the correct digit.<br />
<br />
<br />
== Regularization Method: Reconstruction ==<br />
<br />
Reconstruction is the regularization method introduced in the implementation of CapsNet. The method adds a reconstruction loss (scaled down by 0.0005) to the margin loss during training. The authors argue that this encourages the digit capsules to encode the instantiation parameters of the input digits. During training, all reconstructions use the true label of the input image. The experimental results also confirm that adding the reconstruction regularizer enforces the pose encoding in CapsNet and thus boosts the performance of the routing procedure. <br />
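<br />
How the two loss terms might be combined is sketched below. The function names are hypothetical. Two details are assumptions taken from the paper rather than from the text above: the reconstruction loss is the sum of squared differences between the decoder output and the pixel intensities, and only the capsule of the true class is passed to the decoder while all others are zeroed out.<br />

```python
def mask_capsules(digit_caps, true_class):
    """Zero every capsule except the true class and flatten the result."""
    masked = []
    for k, vec in enumerate(digit_caps):
        masked.extend(vec if k == true_class else [0.0] * len(vec))
    return masked


def total_loss(margin, reconstruction, image, scale=0.0005):
    """Margin loss plus the scaled-down sum of squared pixel errors."""
    squared_error = sum((r - x) ** 2 for r, x in zip(reconstruction, image))
    return margin + scale * squared_error


# Toy example with two 2-dimensional capsules and a 2-pixel "image".
decoder_input = mask_capsules([[1.0, 2.0], [3.0, 4.0]], true_class=1)
loss = total_loss(0.24, reconstruction=[0.0, 1.0], image=[1.0, 1.0])
```

The small scale factor keeps the reconstruction term from dominating the margin loss during training.<br />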
<br />
=== Decoder ===<br />
<br />
The decoder consists of 3 fully connected layers that map the output of the correct digit capsule to pixel intensities. The size of each layer and the activation functions used are indicated in the figure below:<br />
<br />
[[File:DRBC_Decoder.png|650px|center||Source: Sabour, Frosst, Hinton, 2017]]<br />
<br />
=== Result ===<br />
<br />
The authors include some CapsNet classification test results to justify the use of reconstruction. For CapsNet with 1 routing iteration and CapsNet with 3 routing iterations, adding reconstruction shows significant improvements on both the MNIST and MultiMNIST data sets. These improvements show the importance of routing and of the reconstruction regularizer. <br />
<br />
[[File:DRBC_Reconstruction.png|650px|center||Source: Sabour, Frosst, Hinton, 2017]]<br />
<br />
= Experiment Results for CapsNet = <br />
<br />
In this part, the authors report experimental results of CapsNet on different data sets, such as MNIST and variations of MNIST, including expanded MNIST, affNIST, and MultiMNIST. They also briefly discuss the performance on other popular data sets such as CIFAR10. <br />
<br />
== MNIST ==<br />
<br />
=== Highlights ===<br />
<br />
* CapsNet achieves state-of-the-art performance on MNIST with significantly fewer parameters (the 3-layer baseline CNN model has 35.4M parameters, compared to 8.2M for CapsNet with the reconstruction network).<br />
* CapsNet with a shallow structure (3 layers) achieves performance that was previously achieved only by deeper networks.<br />
<br />
=== Interpretation of Each Capsule ===<br />
<br />
The authors report evidence that some dimensions of a capsule always capture one kind of variation of the digit, while others represent global combinations of different variations. This opens up possibilities for interpreting capsules in the future. After computing the activity vector for the correct digit capsule, the authors fed perturbed versions of this activity vector to the decoder to examine the effect on the reconstruction. Some results of these perturbations are shown below, where each row shows the reconstructions when one of the 16 dimensions in the DigitCaps representation is tweaked in steps of 0.05 over the range [-0.25, 0.25]: <br />
<br />
[[File:DRBC_Dimension.png|650px|center||Source: Sabour, Frosst, Hinton, 2017]]<br />
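<br />
The perturbation experiment can be mimicked by generating, for one chosen dimension, the 11 shifted copies of an activity vector that would be fed to the decoder. The sketch below only builds these inputs; the function name <code>perturb_dimension</code> and the all-zero example vector are hypothetical, and no decoder is included.<br />

```python
def perturb_dimension(activity, dim, low=-0.25, high=0.25, step=0.05):
    """Return copies of activity with entry dim shifted over [low, high]."""
    num_steps = round((high - low) / step)
    perturbed = []
    for k in range(num_steps + 1):
        shifted = list(activity)      # copy so the original stays unchanged
        shifted[dim] += low + k * step
        perturbed.append(shifted)
    return perturbed


# One row of the figure: 11 versions of a 16-dimensional activity vector
# that differ only in dimension 3.
row = perturb_dimension([0.0] * 16, dim=3)
```

Repeating this for every value of <code>dim</code> from 0 to 15 gives one row of decoder inputs per capsule dimension.<br />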
<br />
== affNIST == <br />
<br />
The affNIST data set contains different affine transformations of the original MNIST data set. By the concept of capsules, CapsNet should gain robustness from its equivariant nature, and the result confirms this. Compared with the baseline CNN, CapsNet achieves a 13% improvement in accuracy.<br />
<br />
== MultiMNIST ==<br />
<br />
MultiMNIST is essentially an overlapped version of MNIST. An important point is that this data set is generated by overlaying a digit on top of another digit from the same set but of a different class. In other words, stacking digits from the same class is not allowed in MultiMNIST. For example, stacking a 5 on a 0 is allowed, but stacking a 5 on another 5 is not. The reason is that CapsNet suffers from the "crowding" effect, which is discussed in the section on the weaknesses of capsule networks. <br />
<br />
== Other data sets ==<br />
<br />
CapsNet is also applied to other data sets such as CIFAR10, smallNORB and SVHN. The results are not comparable with state-of-the-art performance, but they are still promising since this architecture is the very first of its kind, while other networks have been in development for a long time.<br />
<br />
= Conclusion = <br />
<br />
This paper discusses a specific part of capsule networks, namely the routing-by-agreement mechanism. The authors suggest that it is a promising approach to the current problems with max-pooling in convolutional neural networks. Moreover, as the authors mention, the approach in this paper is only one possible implementation of the capsule concept. The preliminary results from experiments with a simple shallow CapsNet also demonstrate performance indicating that capsules are a direction worth exploring. <br />
<br />
= Weakness of Capsule Network =<br />
<br />
* The routing algorithm introduces internal loops for each capsule. As the number of capsules and layers increases, these internal loops may considerably increase the training time. <br />
* Capsule networks suffer from a perceptual phenomenon called "crowding", which is common in human vision as well. To address this weakness, capsules have to make a very strong representational assumption: at each location of the image, there is at most one instance of the type of entity that a capsule represents. This is also the reason for not allowing overlaid digits from the same class when generating MultiMNIST.<br />
* Other criticisms include that the design of capsule networks requires domain knowledge or feature engineering, contrary to the abstraction-oriented goals of deep learning.<br />
<br />
= Implementations = <br />
1) Tensorflow Implementation : https://github.com/naturomics/CapsNet-Tensorflow<br />
<br />
2) Keras Implementation. : https://github.com/XifengGuo/CapsNet-Keras<br />
<br />
= References =<br />
# S. Sabour, N. Frosst, and G. E. Hinton, “Dynamic routing between capsules,” arXiv preprint arXiv:1710.09829v2, 2017<br />
# “XifengGuo/CapsNet-Keras.” GitHub, 14 Dec. 2017, github.com/XifengGuo/CapsNet-Keras. <br />
# “Naturomics/CapsNet-Tensorflow.” GitHub, 6 Mar. 2018, github.com/naturomics/CapsNet-Tensorflow.</div>X249wanghttp://wiki.math.uwaterloo.ca/statwiki/index.php?title=Dynamic_Routing_Between_Capsules_STAT946&diff=36069Dynamic Routing Between Capsules STAT9462018-04-03T15:42:29Z<p>X249wang: /* Interpretation of Each Capsule */</p>
<hr />
<div>= Presented by =<br />
<br />
Yang, Tong(Richard)<br />
<br />
= Contributions =<br />
<br />
This paper introduces the concept of "capsules" and an approach to implementing this concept in neural networks. A capsule is a group of neurons used to represent various properties of an entity/object present in the image, such as pose, deformation, and even the existence of the entity. Instead of the obvious representation of a logistic unit for the probability of existence, the paper explores using the length of the capsule output vector to represent existence, and the orientation to represent other properties of the entity. The paper has the following major contributions:<br />
<br />
* Proposed an alternative approach to max-pooling, which is called routing-by-agreement.<br />
* Demonstrated a mathematical structure for capsule layers and a routing mechanism that builds a prototype architecture for capsule networks. <br />
* Presented promising results for CapsNet that confirm its value as a new direction for development in deep learning.<br />
<br />
= Hinton's Critiques on CNN =<br />
<br />
In a past talk, Hinton tried to explain why max-pooling is the biggest problem in the current convolutional network structure. Here are some highlights from his talk. <br />
<br />
== Four arguments against pooling ==<br />
<br />
* It is a bad fit to the psychology of shape perception: It does not explain why we assign intrinsic coordinate frames to objects and why they have such huge effects.<br />
<br />
* It solves the wrong problem: We want equivariance, not invariance. Disentangling rather than discarding.<br />
<br />
* It fails to use the underlying linear structure: It does not make use of the natural linear manifold that perfectly handles the largest source of variance in images.<br />
<br />
* Pooling is a poor way to do dynamic routing: We need to route each part of the input to the neurons that know how to deal with it. Finding the best routing is equivalent to parsing the image.<br />
<br />
===Intuition Behind Capsules ===<br />
We try to achieve viewpoint invariance in the activities of neurons by doing max-pooling. Invariance here means that when the input changes a little, the output stays the same; the activity is simply the output signal of a neuron. In other words, when the object we want to detect is shifted slightly in the input image, the network's activities (outputs of neurons) do not change because of max-pooling, and the network still detects the object. However, spatial relationships are not taken care of in this approach, so capsules are used instead, because they encapsulate all the important information about the state of the features they detect in the form of a vector. Capsules encode the probability of detection of a feature as the length of their output vector, and the state of the detected feature as the direction in which that vector points. So when a detected feature moves around the image or its state changes, the probability stays the same (the length of the vector does not change), but the orientation of the vector changes.<br />
<br />
== Equivariance ==<br />
<br />
To deal with the invariance problem of CNNs, Hinton proposes the concept of equivariance, which is the foundation of the capsule concept.<br />
<br />
=== Two types of equivariance ===<br />
<br />
==== Place-coded equivariance ====<br />
If a low-level part moves to a very different position it will be represented by a different capsule.<br />
<br />
==== Rate-coded equivariance ====<br />
If a part only moves a small distance it will be represented by the same capsule but the pose outputs of the capsule will change.<br />
<br />
Higher-level capsules have bigger domains so low-level place-coded equivariance gets converted into high-level rate-coded equivariance.<br />
<br />
= Dynamic Routing =<br />
<br />
In the second section of the paper, the authors give a mathematical representation of two key features of the routing algorithm in capsule networks: squashing and agreement. The general setting for this algorithm is between two arbitrary capsules i and j. Capsule j is assumed to be an arbitrary capsule from the first layer of capsules, and capsule i is an arbitrary capsule from the layer below. The purpose of the routing algorithm is to generate a vector output for the routing decision between capsule i and capsule j. This vector output is then used to decide the dynamic routing. <br />
<br />
== Routing Algorithm ==<br />
<br />
The routing algorithm is as follows:<br />
<br />
[[File:DRBC_Figure_1.png|650px|center||Source: Sabour, Frosst, Hinton, 2017]]<br />
<br />
In the following sections, each part of this algorithm is explained in detail.<br />
<br />
=== Log Prior Probability ===<br />
<br />
<math>b_{ij}</math> represents the log prior probability that capsule i should be coupled to capsule j, and it is updated in each routing iteration. As line 2 of the algorithm indicates, the initial values of <math>b_{ij}</math> for all pairs of capsules are set to 0, so in the very first routing iteration <math>b_{ij}</math> equals zero. In each routing iteration, <math>b_{ij}</math> is updated by adding the agreement value, which is explained later.<br />
<br />
=== Coupling Coefficient === <br />
<br />
<math>c_{ij}</math> represents the coupling coefficient between capsule i and capsule j. It is calculated by applying the softmax function to the log prior probabilities <math>b_{ij}</math>. The transformation is shown below (Equation 3 in the paper): <br />
<br />
\begin{align}<br />
c_{ij} = \frac{\exp(b_{ij})}{\sum_{k}\exp(b_{ik})}<br />
\end{align}<br />
<br />
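The <math>c_{ij}</math> serve as weights for computing the weighted sum of prediction vectors, and for each lower-level capsule i they form a probability distribution over the capsules j in the layer above. As probabilities, they have the following properties:<br />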
<br />
\begin{align}<br />
c_{ij} \geq 0, \forall i, j<br />
\end{align}<br />
<br />
and, <br />
<br />
\begin{align}<br />
\sum_{j}c_{ij} = 1, \forall i<br />
\end{align}<br />
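<br />
As a minimal sketch of the softmax step (plain Python; the function name coupling_coefficients and the example logits are illustrative and not taken from the paper), the coupling coefficients of a single lower-level capsule i can be computed from its logits <math>b_{ij}</math> as follows. For a fixed i, the results are non-negative and sum to 1 over the parent capsules j.<br />
<br />
```python
import math


def coupling_coefficients(b_i):
    """Softmax over the logits b_ij of a single lower-level capsule i."""
    b_max = max(b_i)  # subtracting the maximum keeps exp() numerically stable
    exps = [math.exp(b - b_max) for b in b_i]
    total = sum(exps)
    return [e / total for e in exps]


# Initial state: all logits are zero, so capsule i is coupled equally
# to each of the four (hypothetical) parent capsules.
c_initial = coupling_coefficients([0.0, 0.0, 0.0, 0.0])

# After some routing updates the logits differ, and so do the couplings.
c_later = coupling_coefficients([2.0, 0.0, -1.0, 0.0])
```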
<br />
=== Predicted Output from Layer Below === <br />
<br />
<math>u_{i}</math> is the output vector of capsule i in the lower layer, and <math>\hat{u}_{j|i}</math> is the corresponding input vector for capsule j, i.e. the "prediction vector" from capsule i in the layer below. <math>\hat{u}_{j|i}</math> is produced by multiplying <math>u_{i}</math> by a weight matrix <math>W_{ij}</math>:<br />
<br />
\begin{align}<br />
\hat{u}_{j|i} = W_{ij}u_i<br />
\end{align}<br />
<br />
where <math>W_{ij}</math> encodes some spatial relationship between capsule j and capsule i.<br />
<br />
=== Capsule ===<br />
<br />
By using the definitions from previous sections, the total input vector for an arbitrary capsule j can be defined as:<br />
<br />
\begin{align}<br />
s_j = \sum_{i}c_{ij}\hat{u}_{j|i}<br />
\end{align}<br />
<br />
which is a weighted sum over all prediction vectors by using coupling coefficients.<br />
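<br />
The prediction and weighted-sum steps can be sketched on a toy example. This is only an illustration in plain Python: the helper names predict and total_input, the matrix values and the fixed coupling coefficients are assumptions made for the example, and each <math>W_{ij}</math> is stored as a list of rows so that it maps a 2-D output to a 3-D prediction.<br />
<br />
```python
def predict(W_ij, u_i):
    """Prediction vector u_hat_{j|i} = W_ij u_i (W_ij stored as a list of rows)."""
    return [sum(w * x for w, x in zip(row, u_i)) for row in W_ij]


def total_input(c_j, u_hats_j):
    """s_j = sum_i c_ij * u_hat_{j|i} for one parent capsule j."""
    dim = len(u_hats_j[0])
    return [sum(c * u_hat[d] for c, u_hat in zip(c_j, u_hats_j))
            for d in range(dim)]


# Two lower-level capsules with 2-D outputs and one parent capsule with 3-D input.
u_1, u_2 = [1.0, 0.0], [0.0, 2.0]
W_1j = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
W_2j = [[0.0, 1.0], [1.0, 0.0], [0.0, 0.0]]

u_hat_1 = predict(W_1j, u_1)   # [1.0, 0.0, 1.0]
u_hat_2 = predict(W_2j, u_2)   # [2.0, 0.0, 0.0]

# With equal coupling coefficients the total input is the average prediction.
s_j = total_input([0.5, 0.5], [u_hat_1, u_hat_2])   # [1.5, 0.0, 0.5]
```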
<br />
=== Squashing ===<br />
<br />
The length of <math>s_j</math> is arbitrary, which needs to be addressed. The next step is to map its length into the interval between 0 and 1, since we want the length of the output vector of a capsule to represent the probability that the entity represented by the capsule is present in the current input. The "squashing" process is shown below:<br />
<br />
\begin{align}<br />
v_j = \frac{||s_j||^2}{1+||s_j||^2}\frac{s_j}{||s_j||}<br />
\end{align}<br />
<br />
Notice that "squashing" does not simply normalize the vector to unit length. It applies an additional non-linear transformation so that short vectors are shrunk to almost zero length and long vectors are shrunk to a length slightly below 1. This makes the routing decision between capsule layers, called "routing by agreement", much easier to make.<br />
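<br />
A small numerical sketch of the squashing function in plain Python shows this behaviour. The helper names squash and length are illustrative, and returning a zero vector for a zero input is an assumption added here to avoid dividing by zero.<br />
<br />
```python
import math


def length(v):
    """Euclidean length of a vector."""
    return math.sqrt(sum(x * x for x in v))


def squash(s):
    """v = (|s|^2 / (1 + |s|^2)) * (s / |s|); keeps direction, bounds length by 1."""
    norm_sq = sum(x * x for x in s)
    if norm_sq == 0.0:
        return [0.0 for _ in s]  # assumption: map the zero vector to itself
    scale = math.sqrt(norm_sq) / (1.0 + norm_sq)  # equals |s|^2 / (1 + |s|^2) / |s|
    return [scale * x for x in s]


short_v = squash([0.1, 0.0])   # length close to 0.0099
long_v = squash([10.0, 0.0])   # length close to 0.9901
tilted = squash([3.0, 4.0])    # same direction as [3, 4], shorter length
```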
<br />
=== Agreement ===<br />
<br />
The final step of a routing iteration is to compute the routing agreement <math>a_{ij}</math>, which is defined as a scalar product:<br />
<br />
\begin{align}<br />
a_{ij} = v_{j} \cdot \hat{u}_{j|i}<br />
\end{align}<br />
<br />
As mentioned in the "squashing" section, the length of <math>v_{j}</math> is pushed towards either 0 or 1, which affects the magnitude of <math>a_{ij}</math>. The magnitude of <math>a_{ij}</math> therefore indicates how strongly the routing algorithm agrees on taking the route between capsule i and capsule j. In each routing iteration, the log prior probability <math>b_{ij}</math> is updated by adding the agreement value, which affects how the coupling coefficients are computed in the next routing iteration. Because of the "squashing" process, we eventually end up with a capsule j whose <math>v_{j}</math> has length close to 1 while the <math>v_{j}</math> of all other capsules have length close to 0, which indicates that this capsule j should be activated.<br />
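<br />
Putting the pieces together, the whole routing loop can be sketched in plain Python. This is a toy illustration of the steps described above rather than the authors' implementation: the function name route, the nested-list layout u_hat[i][j] and the example prediction vectors are assumptions. In the example, all three lower-level capsules agree on parent 0 and disagree on parent 1, so routing shifts the coupling towards parent 0.<br />
<br />
```python
import math


def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def squash(s):
    norm_sq = sum(x * x for x in s)
    if norm_sq == 0.0:
        return [0.0] * len(s)
    scale = math.sqrt(norm_sq) / (1.0 + norm_sq)
    return [scale * x for x in s]


def length(v):
    return math.sqrt(sum(x * x for x in v))


def route(u_hat, num_iterations=3):
    """u_hat[i][j] is the prediction vector of lower capsule i for parent j."""
    n_lower, n_upper, dim = len(u_hat), len(u_hat[0]), len(u_hat[0][0])
    b = [[0.0] * n_upper for _ in range(n_lower)]       # log priors start at 0
    for _ in range(num_iterations):
        c = [softmax(b_i) for b_i in b]                 # coupling coefficients
        v = []
        for j in range(n_upper):
            s_j = [sum(c[i][j] * u_hat[i][j][d] for i in range(n_lower))
                   for d in range(dim)]                 # weighted sum
            v.append(squash(s_j))                       # squashing
        for i in range(n_lower):
            for j in range(n_upper):                    # agreement update
                b[i][j] += sum(x * y for x, y in zip(v[j], u_hat[i][j]))
    return v, c


u_hat = [
    [[1.0, 0.0], [1.0, 0.0]],    # capsule 0: predictions for parents 0 and 1
    [[1.0, 0.0], [-1.0, 0.0]],   # capsule 1
    [[1.0, 0.0], [0.0, 1.0]],    # capsule 2
]
v, c = route(u_hat)
```

After three iterations the output of parent 0 is longer than that of parent 1, and every lower-level capsule sends more than half of its output to parent 0.<br />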
<br />
= CapsNet Architecture =<br />
<br />
The second part of the paper discusses the experimental results from a 3-layer CapsNet. The architecture can be divided into two parts: an encoder and a decoder. <br />
<br />
== Encoder == <br />
<br />
[[File:DRBC_Architecture.png|650px|center||Source: Sabour, Frosst, Hinton, 2017]]<br />
<br />
=== How many routing iteration to use? === <br />
In Appendix A of the paper, the authors show empirical results from 500 epochs of training with different numbers of routing iterations. According to their observations, more routing iterations increase the capacity of CapsNet but tend to bring an additional risk of overfitting. Moreover, CapsNet with fewer than three routing iterations is not effective in general. As a result, they suggest 3 iterations of routing for all experiments.<br />
<br />
=== Marginal loss for digit existence ===<br />
<br />
The experiments include segmenting overlapping digits on the MultiMNIST data set, so the loss function has to allow for the presence of multiple digits. The margin loss <math>L_k</math> for each digit capsule k is calculated by:<br />
<br />
\begin{align}<br />
L_k = T_k \max(0, m^+ - ||v_k||)^2 + \lambda(1 - T_k) \max(0, ||v_k|| - m^-)^2<br />
\end{align}<br />
<br />
where <math>m^+ = 0.9</math>, <math>m^- = 0.1</math>, and <math>\lambda = 0.5</math>.<br />
<br />
<math>T_k</math> is an indicator for the presence of a digit of class k: it takes the value 1 if and only if class k is present. If class k is absent, <math>\lambda</math> down-weights the loss. This down-weighting stops the initial learning from shrinking the lengths of the activity vectors of all the digit capsules. The loss is designed this way because we would like the top-level capsule for digit class k to have a long instantiation vector if and only if that digit class is present in the input.<br />
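<br />
As a worked example of the formula, the following plain-Python sketch evaluates the loss for three hypothetical digit capsules. The function name margin_loss and the capsule lengths are illustrative, and summing the per-capsule losses into a single value is an assumption of this sketch.<br />
<br />
```python
def margin_loss(lengths, targets, m_plus=0.9, m_minus=0.1, lam=0.5):
    """Sum of L_k over capsules; lengths[k] = ||v_k||, targets[k] = T_k."""
    total = 0.0
    for v_len, t_k in zip(lengths, targets):
        present = t_k * max(0.0, m_plus - v_len) ** 2
        absent = lam * (1 - t_k) * max(0.0, v_len - m_minus) ** 2
        total += present + absent
    return total


# Class 0 is present and its capsule is long enough (0.95 > 0.9): no penalty.
# Class 1 is absent and its capsule is short enough (0.05 < 0.1): no penalty.
# Class 2 is absent but its capsule has length 0.5: penalty 0.5 * 0.4^2 = 0.08.
loss = margin_loss([0.95, 0.05, 0.5], [1, 0, 0])
```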
<br />
=== Layer 1: Conv1 === <br />
<br />
This is the first layer of CapsNet. As in a CNN, it is simply a convolutional layer that converts pixel intensities to the activities of local feature detectors. <br />
<br />
* Layer Type: Convolutional Layer.<br />
* Input: <math>28 \times 28</math> pixels.<br />
* Kernel size: <math>9 \times 9</math>.<br />
* Number of Kernels: 256.<br />
* Activation function: ReLU.<br />
* Output: <math>20 \times 20 \times 256</math> tensor.<br />
<br />
=== Layer 2: PrimaryCapsules ===<br />
<br />
The second layer is formed by 32 primary 8D capsules. Here 8D means that each primary capsule contains 8 convolutional units with a <math>9 \times 9</math> kernel and a stride of 2. Each capsule takes the <math>20 \times 20 \times 256</math> tensor from Conv1 and produces a <math>6 \times 6 \times 8</math> output tensor.<br />
<br />
* Layer Type: Convolutional Layer<br />
* Input: <math>20 \times 20 \times 256</math> tensor.<br />
* Number of capsules: 32.<br />
* Number of convolutional units in each capsule: 8.<br />
* Size of each convolutional unit: <math>6 \times 6</math>.<br />
* Output: <math>6 \times 6 \times 32</math> 8-dimensional vectors.<br />
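<br />
The spatial sizes quoted for Conv1 and PrimaryCapsules follow from the usual output-size formula for a convolution. The sketch below assumes no padding and a stride of 1 for Conv1 (neither is stated above, but both are consistent with a <math>28 \times 28</math> input giving a <math>20 \times 20</math> output); the function name is illustrative.<br />
<br />
```python
def conv_output_size(input_size, kernel_size, stride=1):
    """Spatial output size of a convolution without padding."""
    return (input_size - kernel_size) // stride + 1


conv1_size = conv_output_size(28, 9, stride=1)             # 20
primary_size = conv_output_size(conv1_size, 9, stride=2)   # 6

# 32 primary capsules, each producing a 6 x 6 grid of 8-D vectors.
num_primary_vectors = primary_size * primary_size * 32     # 1152
```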
<br />
=== Layer 3: DigitsCaps ===<br />
<br />
The last layer has 10 16D capsules, one for each digit. Unlike the PrimaryCapsules layer, this layer is fully connected. Since this is the top capsule layer, the dynamic routing mechanism is applied between PrimaryCapsules and DigitsCaps. The process begins by transforming the outputs of the PrimaryCapsules layer into prediction vectors. Each output is an 8-dimensional vector, which needs to be mapped to a 16-dimensional space. Therefore, the weight matrix <math>W_{ij}</math> is an <math>8 \times 16</math> matrix. The next step is to obtain the coupling coefficients from the routing algorithm and to perform "squashing" to get the output. <br />
<br />
* Layer Type: Fully connected layer.<br />
* Input: <math>6 \times 6 \times 32</math> 8-dimensional vectors.<br />
* Output: <math>16 \times 10 </math> matrix.<br />
<br />
=== The loss function ===<br />
<br />
The target for the loss function is a ten-dimensional one-hot encoded vector with 9 zeros and a single one at the position of the correct class.<br />
<br />
<br />
== Regularization Method: Reconstruction ==<br />
<br />
This is a regularization method introduced in the implementation of CapsNet. The method adds a reconstruction loss (scaled down by 0.0005) to the margin loss during training. The authors argue this encourages the digit capsules to encode the instantiation parameters of the input digits. During training, every reconstruction uses the true label of the input image. The experimental results also confirm that adding the reconstruction regularizer enforces the pose encoding in CapsNet and thus boosts the performance of the routing procedure. <br />
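<br />
The combined training objective can be sketched as follows. This is a plain-Python illustration under two assumptions that are not spelled out above: the reconstruction loss is taken to be the sum of squared differences between input and reconstructed pixels, and "using the true label" is taken to mean that only the activity vector of the true class is passed to the decoder. The function names and the toy values are hypothetical.<br />
<br />
```python
def mask_digit_caps(digit_caps, true_class):
    """Keep only the activity vector of the true class; zero out the others."""
    return [list(v) if k == true_class else [0.0] * len(v)
            for k, v in enumerate(digit_caps)]


def total_loss(margin, image, reconstruction, recon_weight=0.0005):
    """Margin loss plus the scaled-down sum of squared pixel differences."""
    sse = sum((p - r) ** 2 for p, r in zip(image, reconstruction))
    return margin + recon_weight * sse


# Three toy 2-D digit capsules; class 1 is the true label.
masked = mask_digit_caps([[0.1, 0.2], [0.7, 0.6], [0.0, 0.1]], true_class=1)

# A 4-pixel "image" and its reconstruction differ in one pixel by 0.5.
loss = total_loss(0.08, [0.0, 1.0, 1.0, 0.0], [0.0, 0.5, 1.0, 0.0])
```

With the small weight of 0.0005, the reconstruction term (here 0.000125) does not dominate the margin loss (here 0.08).<br />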
<br />
=== Decoder ===<br />
<br />
The decoder consists of 3 fully connected layers that map the output of the digit capsules to pixel intensities. The size of each layer and the activation functions used are indicated in the figure below:<br />
<br />
[[File:DRBC_Decoder.png|650px|center||Source: Sabour, Frosst, Hinton, 2017]]<br />
<br />
=== Result ===<br />
<br />
The authors include CapsNet classification test accuracies to justify the use of reconstruction. For CapsNet with 1 routing iteration and CapsNet with 3 routing iterations, adding reconstruction gives significant improvements on both the MNIST and MultiMNIST data sets. These improvements show the importance of routing and of the reconstruction regularizer. <br />
<br />
[[File:DRBC_Reconstruction.png|650px|center||Source: Sabour, Frosst, Hinton, 2017]]<br />
<br />
= Experiment Results for CapsNet = <br />
<br />
In this part, the authors present experimental results for CapsNet on different data sets, namely MNIST and several variations of MNIST such as expanded MNIST, affNIST and MultiMNIST. They also briefly discuss the performance on some other popular data sets such as CIFAR10. <br />
<br />
== MNIST ==<br />
<br />
=== Highlights ===<br />
<br />
* CapsNet achieves state-of-the-art performance on MNIST.<br />
* CapsNet with a shallow structure (3 layers) achieves performance that was previously achieved only by deeper networks.<br />
<br />
=== Interpretation of Each Capsule ===<br />
<br />
The authors report evidence that some dimensions of a capsule always capture a particular variation of the digit, while others represent global combinations of different variations, which opens up possibilities for interpreting capsules in the future. After computing the activity vector for the correct digit capsule, the authors fed perturbed versions of that activity vector to the decoder to examine the effect on the reconstruction. Some results of these perturbations are shown below, where each row shows the reconstructions when one of the 16 dimensions in the DigitCaps representation is tweaked in steps of 0.05 over the range [-0.25, 0.25]: <br />
<br />
[[File:DRBC_Dimension.png|650px|center||Source: Sabour, Frosst, Hinton, 2017]]<br />
<br />
== affNIST == <br />
<br />
The affNIST data set contains different affine transformations of the original MNIST data set. According to the capsule concept, CapsNet should gain robustness from its equivariant nature, and the result confirms this: compared with the baseline CNN, CapsNet achieves a 13% improvement in accuracy.<br />
<br />
== MultiMNIST ==<br />
<br />
MultiMNIST is essentially an overlapped version of MNIST. An important point to notice here is that this data set is generated by overlaying a digit on top of another digit from the same set but a different class. In other words, stacking digits from the same class is not allowed in MultiMNIST. For example, stacking a 5 on a 0 is allowed, but stacking a 5 on another 5 is not. The reason is that CapsNet suffers from the "crowding" effect, which is discussed in the section on the weakness of capsule networks. <br />
<br />
== Other data sets ==<br />
<br />
CapsNet is also applied to other data sets such as CIFAR10, smallNORB and SVHN. The results are not comparable with state-of-the-art performance, but they are still promising since this architecture is the very first of its kind, while other networks have been under development for a long time.<br />
<br />
= Conclusion = <br />
<br />
This paper discusses a specific part of the capsule network, the routing-by-agreement mechanism. The authors suggest this is a promising approach to solving the current problem with max-pooling in convolutional neural networks. Moreover, as the authors mention, the approach in this paper is only one possible implementation of the capsule concept. The preliminary results from experiments using a simple shallow CapsNet also demonstrate strong performance, which indicates that capsules are a direction worth exploring. <br />
<br />
= Weakness of Capsule Network =<br />
<br />
* The routing algorithm introduces internal loops for each capsule. As the number of capsules and layers increases, these internal loops may exponentially expand the training time. <br />
* Capsule networks suffer from a perceptual phenomenon called "crowding", which is common in human vision as well. To address this weakness, capsules have to make a very strong representational assumption: at each location of the image, there is at most one instance of the type of entity that a capsule represents. This is also the reason for not allowing overlaid digits from the same class when generating MultiMNIST.<br />
* Other criticisms include that the design of capsule networks requires domain knowledge or feature engineering, contrary to the abstraction-oriented goals of deep learning.<br />
<br />
= Implementations = <br />
1) TensorFlow implementation: https://github.com/naturomics/CapsNet-Tensorflow<br />
<br />
2) Keras implementation: https://github.com/XifengGuo/CapsNet-Keras<br />
<br />
= References =<br />
# S. Sabour, N. Frosst, and G. E. Hinton, “Dynamic routing between capsules,” arXiv preprint arXiv:1710.09829v2, 2017<br />
# “XifengGuo/CapsNet-Keras.” GitHub, 14 Dec. 2017, github.com/XifengGuo/CapsNet-Keras. <br />
# “Naturomics/CapsNet-Tensorflow.” GitHub, 6 Mar. 2018, github.com/naturomics/CapsNet-Tensorflow.</div>X249wanghttp://wiki.math.uwaterloo.ca/statwiki/index.php?title=stat946w18/Tensorized_LSTMs&diff=35879stat946w18/Tensorized LSTMs2018-03-28T22:57:26Z<p>X249wang: /* Part 3: Extending to LSTMs */</p>
<hr />
<div>= Presented by =<br />
<br />
Chen, Weishi(Edward)<br />
<br />
= Introduction =<br />
<br />
Long Short-Term Memory (LSTM) is a popular approach to boosting the ability of Recurrent Neural Networks to store longer term temporal information. The capacity of an LSTM network can be increased by widening and adding layers (illustrations will be provided later). <br />
<br />
<br />
However, the LSTM model itself introduces additional parameters, and adding layers or widening layers increases the time required for model training and evaluation. As an alternative, the paper "Wider and Deeper, Cheaper and Faster: Tensorized LSTMs for Sequence Learning" proposes a model based on the LSTM called the Tensorized LSTM, in which the hidden states are represented by '''tensors''' and updated via a '''cross-layer convolution'''. <br />
<br />
* By increasing the tensor size, the network can be widened efficiently without additional parameters since the parameters are shared across different locations in the tensor<br />
* By delaying the output, the network can be deepened implicitly with little additional runtime since deep computations for each time step are merged into temporal computations of the sequence. <br />
<br />
<br />
The paper also presents experiments on five challenging sequence learning tasks that show the potential of the proposed model.<br />
<br />
= A Quick Introduction to RNN and LSTM =<br />
<br />
We consider the time-series prediction task of producing a desired output <math>y_t</math> at each time-step <math>t \in \{1, ..., T\}</math> given an observed input sequence <math>x_{1:t} = \{x_1, x_2, \cdots, x_t\}</math>, where <math>x_t \in R^R</math> and <math>y_t \in R^S</math> are vectors. An RNN learns to use a hidden state vector <math>h_t \in R^M</math> to encapsulate the relevant features of the entire input history <math>x_{1:t}</math> (all inputs from the initial time-step up to time-step t; an illustration is given below).<br />
<br />
\begin{align}<br />
h_{t-1}^{cat} = [x_t, h_{t-1}] \hspace{2cm} (1)<br />
\end{align}<br />
<br />
Where <math>h_{t-1}^{cat} ∈R^{R+M}</math> is the concatenation of the current input <math>x_t</math> and the previous hidden state <math>h_{t−1}</math>, which expands the dimensionality of intermediate information.<br />
<br />
The update of the hidden state <math>h_t</math> is defined as:<br />
<br />
\begin{align}<br />
a_{t} =h_{t-1}^{cat} W^h + b^h \hspace{2cm} (2)<br />
\end{align}<br />
<br />
and<br />
<br />
\begin{align}<br />
h_t = \Phi(a_t) \hspace{2cm} (3)<br />
\end{align}<br />
<br />
<math>W^h \in R^{(R+M) \times M}</math> is the weight matrix, which guarantees that each hidden state has dimension M, <math>b^h \in R^M</math> is the bias, <math>a_t \in R^M</math> is the hidden activation, and <math>\Phi(\cdot)</math> is the element-wise "tanh" function. Finally, the output <math>y_t</math> at time-step t is generated by:<br />
<br />
\begin{align}<br />
y_t = \varphi(h_{t} W^y + b^y) \hspace{2cm} (4)<br />
\end{align}<br />
<br />
where <math>W^y \in R^{M \times S}</math> and <math>b^y \in R^S</math>, and <math>\varphi(\cdot)</math> can be any differentiable function. Note that <math>\Phi</math> is the element-wise function that introduces non-linearity and generates the next '''hidden state''', while <math>\varphi</math> is applied to generate the '''output'''.<br />
<br />
[[File:StdRNN.png|650px|center||Figure 1: Recurrent Neural Network]]<br />
<br />
However, one shortfall of RNNs is the problem of vanishing/exploding gradients. This shortfall is especially significant when modelling long-range dependencies. One alternative is to apply the LSTM (Long Short-Term Memory): LSTMs alleviate these problems by employing memory cells to preserve information for longer, and by adopting gating mechanisms to modulate the information flow. Since the LSTM is successful in sequence modelling, it is natural to consider how to increase the complexity of the model to accommodate more complex analytical needs.<br />
<br />
[[File:LSTM_Gated.png|650px|center||Figure 2: LSTM]]<br />
<br />
= Structural Measurement of Sequential Model =<br />
<br />
We can consider the capacity of a network to consist of two components: the '''width''' (the amount of information handled in parallel) and the '''depth''' (the number of computation steps). <br />
<br />
A way to '''widen''' the LSTM is to increase the number of units in a hidden layer; however, the parameter number scales quadratically with the number of units. To deepen the LSTM, the popular Stacked LSTM (sLSTM) stacks multiple LSTM layers. The drawback of sLSTM, however, is that runtime is proportional to the number of layers and information from the input is potentially lost (due to gradient vanishing/explosion) as it propagates vertically through the layers. This paper introduces a way to both widen and deepen the LSTM whilst keeping the parameter number and runtime largely unchanged. In summary, the paper makes the following contributions:<br />
<br />
'''(a)''' Tensorize RNN hidden state vectors into higher-dimensional tensors, to enable more flexible parameter sharing so that the network can be widened more efficiently without additional parameters.<br />
<br />
'''(b)''' Based on (a), merge RNN deep computations into its temporal computations so that the network can be deepened with little additional runtime, resulting in a Tensorized RNN (tRNN).<br />
<br />
'''(c)''' Extend the tRNN to an LSTM, namely the Tensorized LSTM (tLSTM), which integrates a novel memory cell convolution to help to prevent the vanishing/exploding gradients.<br />
<br />
= Method =<br />
<br />
Go through the methodology.<br />
<br />
== Part 1: Tensorize RNN hidden State vectors ==<br />
<br />
'''Definition:''' Tensorization is defined as the transformation or mapping of lower-order data to higher-order data. For example, the low-order data can be a vector, and the tensorized result is a matrix, a third-order tensor or a higher-order tensor. The ‘low-order’ data can also be a matrix or a third-order tensor, for example. In the latter case, tensorization can take place along one or multiple modes.<br />
<br />
[[File:VecTsor.png|320px|center||Figure 3: Third-order tensorization of a vector]]<br />
<br />
'''Optimization Methodology Part 1:''' It can be seen that in an RNN, the parameter number scales quadratically with the size of the hidden state. A popular way to limit the parameter number when widening the network is to organize parameters as higher-dimensional tensors which can be factorized into lower-rank sub-tensors that contain significantly fewer elements, which is known as tensor factorization. <br />
<br />
'''Optimization Methodology Part 2:''' Another common way to reduce the parameter number is to share a small set of parameters across different locations in the hidden state, similar to Convolutional Neural Networks (CNNs).<br />
<br />
'''Effects:''' This '''widens''' the network since the hidden state vectors are in fact broadcast to interact with the tensorized parameters. <br />
<br />
<br />
<br />
We adopt parameter sharing to cut down the parameter number for RNNs, since compared with factorization it has the following advantages: <br />
<br />
(i) '''Scalability,''' the number of shared parameters can be set independent of the hidden state size<br />
<br />
(ii) '''Separability,''' the information flow can be carefully managed by controlling the receptive field, allowing one to shift RNN deep computations to the temporal domain<br />
<br />
<br />
<br />
We also explicitly tensorize the RNN hidden state vectors, since compared with vectors, tensors have better: <br />
<br />
(i) '''Flexibility,''' one can specify which dimensions to share parameters and then can just increase the size of those dimensions without introducing additional parameters<br />
<br />
(ii) '''Efficiency,''' with higher-dimensional tensors, the network can be widened faster w.r.t. its depth when fixing the parameter number (explained later). <br />
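<br />
The scalability point can be made concrete by counting hidden-to-hidden parameters. The sketch below is only an illustration under simplifying assumptions: it ignores input and output weights, treats the untensorized hidden state as a vector of length P·M updated by a dense matrix, and treats the tensorized state as P locations with M channels updated by a shared kernel of size K (the function names are hypothetical).<br />
<br />
```python
def dense_hidden_params(hidden_size):
    """Dense hidden-to-hidden weights plus bias: grows quadratically."""
    return hidden_size * hidden_size + hidden_size


def shared_kernel_params(kernel_size, channels):
    """Shared convolution kernel (K x M x M) plus bias: independent of P."""
    return kernel_size * channels * channels + channels


M, K = 8, 3
for P in (4, 8, 16):
    dense = dense_hidden_params(P * M)       # 1056, 4160, 16512
    shared = shared_kernel_params(K, M)      # 200 for every P
    print(P, dense, shared)
```

Doubling P roughly quadruples the dense count, while the shared-kernel count stays fixed because the same kernel is reused at every location.<br />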
<br />
<br />
'''Illustration:''' For ease of exposition, we first consider 2D tensors (matrices): we tensorize the hidden state <math>h_t \in R^{M}</math> to become <math>H_t \in R^{P \times M}</math>, '''where P is the tensor size,''' and '''M the channel size'''. We locally-connect the first dimension of <math>H_t</math> (which is P - the tensor size) in order to share parameters, and fully-connect the second dimension of <math>H_t</math> (which is M - the channel size) to allow global interactions. This is analogous to the CNN which fully-connects one dimension (e.g., the RGB channel for input images) to globally fuse different feature planes. One can also compare <math>H_t</math> to the hidden state of a Stacked RNN (sRNN) (see figure below). <br />
<br />
[[File:Screen_Shot_2018-03-26_at_11.28.37_AM.png|160px|center||Figure 4: Stacked RNN]]<br />
<br />
[[File:ind.png|60px|center||Figure 4: Stacked RNN]]<br />
<br />
In this comparison, P is akin to the number of stacked hidden layers (vertical length in the graph), and M to the size of each hidden layer (each white node in the graph). We start by describing the model based on 2D tensors, and finally show how to strengthen it with higher-dimensional tensors.<br />
<br />
== Part 2: Merging Deep Computations ==<br />
<br />
Since an RNN is already deep in its temporal direction, we can deepen an input-to-output computation by associating the input <math>x_t</math> with a (delayed) future output. In doing this, we need to ensure that the output <math>y_t</math> is separable, i.e., not influenced by any future input <math>x_{t^{'}}</math> <math>(t^{'}>t)</math>. Thus, we concatenate the projection of <math>x_t</math> to the top of the previous hidden state <math>H_{t−1}</math>, then gradually shift the input information down when the temporal computation proceeds, and finally generate <math>y_t</math> from the bottom of <math>H_{t+L−1}</math>, where L−1 is the number of delayed time-steps for computations of depth L. <br />
<br />
An example with L = 3 is shown in the figure below.<br />
<br />
[[File:tRNN.png|160px|center||Figure 5: skewed sRNN]]<br />
<br />
[[File:ind.png|60px|center||Figure 5: skewed sRNN]]<br />
<br />
<br />
This is in fact a skewed sRNN (or tRNN without feedback). However, the method does not need to change the network structure and also allows different kinds of interactions as long as the output is separable; for example, one can increase the local connections and '''use feedback''' (shown in figure below), which can be beneficial for sRNNs (or tRNN). <br />
<br />
[[File:tRNN_wF.png|160px|center||Figure 5: skewed sRNN with F]]<br />
<br />
[[File:ind.png|60px|center||Figure 5: skewed sRNN with F]]<br />
<br />
'''In order to share parameters, we update <math>H_t</math> using a convolution with a learnable kernel.''' In this manner we increase the complexity of the input-to-output mapping (by delaying outputs) and limit parameter growth (by sharing transition parameters using convolutions).<br />
<br />
To examine the resulting model mathematically, let <math>H^{cat}_{t−1}∈R^{(P+1)×M}</math> be the concatenated hidden state, and <math>p∈Z_+</math> the location at a tensor. The channel vector <math>h^{cat}_{t−1, p }∈R^M</math> at location p of <math>H^{cat}_{t−1}</math> (the p-th channel of H) is defined as:<br />
<br />
\begin{align}<br />
h^{cat}_{t-1, p} = x_t W^x + b^x \hspace{1cm} \text{if } p = 1 \hspace{1cm} (5)<br />
\end{align}<br />
<br />
\begin{align}<br />
h^{cat}_{t-1, p} = h_{t-1, p-1} \hspace{1cm} \text{if } p > 1 \hspace{1cm} (6)<br />
\end{align}<br />
<br />
where <math>W^x ∈ R^{R×M}</math> and <math>b^x ∈ R^M</math> (recall the dimension of input x is R). Then, the update of tensor <math>H_t</math> is implemented via a convolution:<br />
<br />
\begin{align}<br />
A_t = H^{cat}_{t-1} \circledast \{W^h, b^h \} \hspace{2cm} (7)<br />
\end{align}<br />
<br />
\begin{align}<br />
H_t = \Phi(A_t) \hspace{2cm} (8)<br />
\end{align}<br />
<br />
where <math>W^h∈R^{K×M^i×M^o}</math> is the kernel weight of size K, with <math>M^i =M</math> input channels and <math>M^o =M</math> output channels, <math>b^h ∈ R^{M^o}</math> is the kernel bias, <math>A_t ∈ R^{P×M^o}</math> is the hidden activation, and <math>\circledast</math> is the convolution operator. Since the kernel convolves across different hidden layers, we call it the cross-layer convolution. The kernel enables interaction, both bottom-up and top-down across layers. Finally, we generate <math>y_t</math> from the channel vector <math>h_{t+L−1,P}∈R^M</math> which is located at the bottom of <math>H_{t+L−1}</math>:<br />
<br />
\begin{align}<br />
y_t = \varphi(h_{t+L-1, P} W^y + b^y) \hspace{2cm} (9)<br />
\end{align}<br />
<br />
where <math>W^y \in R^{M \times S}</math> and <math>b^y \in R^S</math>. The delay L−1 must be chosen to guarantee that the receptive field of <math>y_t</math> only covers the current and previous inputs <math>x_{1:t}</math> (see the skewed sRNN again below):<br />
<br />
[[File:tRNN_wF.png|160px|center||Figure 5: skewed sRNN with F]]<br />
<br />
[[File:ind.png|60px|center||Figure 5: skewed sRNN with F]]<br />
<br />
=== Quick Summary of Set of Parameters ===<br />
<br />
'''1. <math> W^x</math> and <math>b^x</math>''' connect the input to the first hidden node<br />
<br />
'''2. <math> W^h</math> and <math>b^h</math>''' convolve across layers<br />
<br />
'''3. <math> W^y</math> and <math>b^y</math>''' produce the output at each stage<br />
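<br />
How the first two parameter sets are used in Equations (5) to (8) can be sketched in plain Python. This is a toy illustration, not the paper's implementation: it assumes a kernel of size K = 2 applied without padding, so that location p combines the row above it in the concatenated state with its own previous value (no feedback from deeper locations). The function name trnn_step and all weights are hypothetical.<br />
<br />
```python
import math


def trnn_step(x_t, H_prev, W_x, b_x, W_h, b_h):
    """One tRNN update with a size-2 cross-layer kernel and no padding."""
    M = len(b_x)
    # Eq. (5): project the input to M channels; it becomes the top row.
    x_proj = [sum(x_t[r] * W_x[r][m] for r in range(len(x_t))) + b_x[m]
              for m in range(M)]
    # Eq. (5)-(6): concatenated state with P + 1 rows.
    H_cat = [x_proj] + [list(h) for h in H_prev]
    K = len(W_h)
    assert K == 2, "this sketch only covers K = 2 without padding"
    H_t = []
    for p in range(len(H_prev)):
        a = list(b_h)                          # Eq. (7): cross-layer convolution
        for k in range(K):
            row = H_cat[p + k]
            for mo in range(M):
                a[mo] += sum(row[mi] * W_h[k][mi][mo] for mi in range(M))
        H_t.append([math.tanh(v) for v in a])  # Eq. (8)
    return H_t


# Toy sizes: R = 1, M = 1, P = 2. The kernel copies the row above (k = 0)
# and ignores the location's own previous value (k = 1).
W_x, b_x = [[1.0]], [0.0]
W_h, b_h = [[[1.0]], [[0.0]]], [0.0]

H_1 = trnn_step([1.0], [[0.0], [0.0]], W_x, b_x, W_h, b_h)
H_2 = trnn_step([0.0], H_1, W_x, b_x, W_h, b_h)
```

In this toy run the input reaches the top location at the first step and the bottom location one step later, which is the delay that the output equation accounts for.<br />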
<br />
<br />
== Part 3: Extending to LSTMs==<br />
<br />
Similar to the standard RNN, to allow the tRNN (skewed sRNN) to capture long-range temporal dependencies, one can straightforwardly extend it to a tLSTM by replacing the tRNN tensors:<br />
<br />
\begin{align}<br />
[A^g_t, A^i_t, A^f_t, A^o_t] = H^{cat}_{t-1} \circledast \{W^h, b^h \} \hspace{2cm} (10)<br />
\end{align}<br />
<br />
\begin{align}<br />
[G_t, I_t, F_t, O_t]= [\Phi{(A^g_t)}, \sigma(A^i_t), \sigma(A^f_t), \sigma(A^o_t)] \hspace{2cm} (11)<br />
\end{align}<br />
<br />
These are quite similar to the tRNN case; the main difference lies in the memory cells of the tLSTM (<math>C_t</math>):<br />
<br />
\begin{align}<br />
C_t= G_t \odot I_t + C_{t-1} \odot F_t \hspace{2cm} (12)<br />
\end{align}<br />
<br />
\begin{align}<br />
H_t= \Phi{(C_t )} \odot O_t \hspace{2cm} (13)<br />
\end{align}<br />
<br />
Summary of the terms: <br />
<br />
1. '''<math>G_t</math>:''' Activation of new content<br />
<br />
2. '''<math>I_t</math>:''' Input gate<br />
<br />
3. '''<math>F_t</math>:''' Forget gate<br />
<br />
4. '''<math>O_t</math>:''' Output gate<br />
<br />
Then, see graph below for illustration:<br />
<br />
[[File:tLSTM_wo_MC.png |160px|center||Figure 5: tLSTM wo MC]]<br />
<br />
[[File:ind.png|60px|center||Figure 5: tLSTM wo MC]]<br />
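<br />
The gating in Equations (11) to (13) is purely element-wise, which the following plain-Python sketch makes explicit. The function name tlstm_cell_update and the toy activations are illustrative; the activations are assumed to have already been produced by the cross-layer convolution of Equation (10).<br />
<br />
```python
import math


def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))


def tlstm_cell_update(A_g, A_i, A_f, A_o, C_prev):
    """Element-wise tLSTM update for P x M tensors stored as nested lists."""
    C_t, H_t = [], []
    for p in range(len(C_prev)):
        c_row, h_row = [], []
        for m in range(len(C_prev[p])):
            g = math.tanh(A_g[p][m])           # new content
            i = sigmoid(A_i[p][m])             # input gate
            f = sigmoid(A_f[p][m])             # forget gate
            o = sigmoid(A_o[p][m])             # output gate
            c = g * i + C_prev[p][m] * f       # Eq. (12)
            c_row.append(c)
            h_row.append(math.tanh(c) * o)     # Eq. (13)
        C_t.append(c_row)
        H_t.append(h_row)
    return C_t, H_t


# One location, one channel: a closed input gate and an open forget gate
# carry the previous cell value through almost unchanged.
C_t, H_t = tlstm_cell_update([[1.0]], [[-50.0]], [[50.0]], [[0.0]], [[0.7]])
```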
<br />
To further evolve the tLSTM and capture long-range dependencies from multiple directions, the authors additionally introduce a novel '''Memory Cell Convolution''', by which the memory cells can have a larger receptive field (figure provided below). <br />
<br />
[[File:tLSTM_w_MC.png |160px|center||Figure 5: tLSTM w MC]]<br />
<br />
[[File:ind.png|60px|center||Figure 5: tLSTM w MC]]<br />
<br />
One can also dynamically generate this convolution kernel so that it is both time- and location-dependent, allowing for flexible control over long-range dependencies from different directions. Mathematically, it can be represented by the following formulas:<br />
<br />
\begin{align}<br />
[A^g_t, A^i_t, A^f_t, A^o_t, A^q_t] = H^{cat}_{t-1} \circledast \{W^h, b^h \} \hspace{2cm} (14)<br />
\end{align}<br />
<br />
\begin{align}<br />
[G_t, I_t, F_t, O_t, Q_t]= [\Phi{(A^g_t)}, \sigma(A^i_t), \sigma(A^f_t), \sigma(A^o_t), \varsigma(A^q_t)] \hspace{2cm} (15)<br />
\end{align}<br />
<br />
\begin{align}<br />
W_t^c(p) = reshape(q_{t,p}, [K, 1, 1]) \hspace{2cm} (16)<br />
\end{align}<br />
<br />
\begin{align}<br />
C_{t-1}^{conv}= C_{t-1} \circledast W_t^c(p) \hspace{2cm} (17)<br />
\end{align}<br />
<br />
\begin{align}<br />
C_t= G_t \odot I_t + C_{t-1}^{conv} \odot F_t \hspace{2cm} (18)<br />
\end{align}<br />
<br />
\begin{align}<br />
H_t= \Phi{(C_t )} \odot O_t \hspace{2cm} (19)<br />
\end{align}<br />
<br />
where the kernel <math>\{W^h, b^h\}</math> has additional <math>\langle K \rangle</math> output channels to generate the activation <math>A^q_t \in R^{P \times \langle K \rangle}</math> for the dynamic kernel bank <math>Q_t \in R^{P \times \langle K \rangle}</math>, <math>q_{t,p} \in R^{\langle K \rangle}</math> is the vectorized adaptive kernel at location p of <math>Q_t</math>, and <math>W^c_t(p) \in R^{K \times 1 \times 1}</math> is the dynamic kernel of size K with a single input/output channel, which is reshaped from <math>q_{t,p}</math>. Note that the paper also employs a softmax function <math>\varsigma(\cdot)</math> to normalize the channel dimension of <math>Q_t</math>, which can also stabilize the values of the memory cells and help to prevent vanishing/exploding gradients. An illustration of the process is provided below:<br />
<br />
[[File:MCC.png |240px|center||Figure 5: MCC]]<br />
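<br />
The memory cell convolution of Equations (15) to (17) can be sketched in plain Python as follows. Several details are assumptions of this sketch rather than statements from the paper: the kernel size K is odd, the kernel is centred on location p, positions outside the tensor are treated as zero, and the same kernel is applied to every channel at a location. The function name is hypothetical.<br />
<br />
```python
import math


def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def memory_cell_convolution(C_prev, A_q):
    """Convolve C_{t-1} along the tensor dimension with a per-location kernel.

    C_prev is P x M; A_q is P x K and holds the kernel activations A^q_t.
    """
    P, M = len(C_prev), len(C_prev[0])
    K = len(A_q[0])
    half = K // 2
    C_conv = []
    for p in range(P):
        kernel = softmax(A_q[p])              # q_{t,p}: normalized dynamic kernel
        row = [0.0] * M
        for k in range(K):
            src = p + k - half                # neighbour seen by kernel tap k
            if 0 <= src < P:                  # assumption: zero outside the tensor
                for m in range(M):
                    row[m] += kernel[k] * C_prev[src][m]
        C_conv.append(row)
    return C_conv


C_prev = [[1.0], [2.0], [3.0]]

# Equal activations give an averaging kernel of [1/3, 1/3, 1/3].
averaged = memory_cell_convolution(C_prev, [[0.0, 0.0, 0.0]] * 3)

# A strongly peaked centre tap leaves the memory cells almost unchanged.
kept = memory_cell_convolution(C_prev, [[-50.0, 50.0, -50.0]] * 3)
```

Because the softmax makes each kernel sum to 1, the convolved cell at an interior location is a weighted average of neighbouring cells, which is one way to read the stabilizing effect mentioned above.<br />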
<br />
To improve training, the authors introduced a new normalization technique for ''t''LSTM termed channel normalization (adapted from layer normalization), in which the channel vectors at different locations are normalized with their own statistics. Note that layer normalization does not work well with ''t''LSTM, because lower level information is near the input and higher level information is near the output. Channel normalization (CN) is defined as: <br />
<br />
\begin{align}<br />
\mathrm{CN}(\mathbf{Z}; \mathbf{\Gamma}, \mathbf{B}) = \mathbf{\hat{Z}} \odot \mathbf{\Gamma} + \mathbf{B} \hspace{2cm} (20)<br />
\end{align}<br />
<br />
where <math>\mathbf{Z}</math>, <math>\mathbf{\hat{Z}}</math>, <math>\mathbf{\Gamma}</math>, <math>\mathbf{B} \in \mathbb{R}^{P \times M^z}</math> are the original tensor, normalized tensor, gain parameter and bias parameter, respectively. The <math>m^z</math>-th channel of <math>\mathbf{Z}</math> is normalized element-wise: <br />
<br />
\begin{align}<br />
\hat{z}_{m^z} = (z_{m^z} - z^\mu)/z^{\sigma} \hspace{2cm} (21)<br />
\end{align}<br />
<br />
where <math>z^{\mu}</math>, <math>z^{\sigma} \in \mathbb{R}^P</math> are the mean and standard deviation along the channel dimension of <math>\mathbf{Z}</math>, and <math>\hat{z}_{m^z} \in \mathbb{R}^P</math> is the <math>m^z</math>-th channel of <math>\mathbf{\hat{Z}}</math>. Channel normalization introduces very few additional parameters compared to the number of other parameters in the model.<br />
<br />
= Results and Evaluation =<br />
<br />
Summary of the models compared (the tLSTM family and its baseline):<br />
<br />
(a) sLSTM (baseline): the implementation of sLSTM with parameters shared across all layers.<br />
<br />
(b) 2D tLSTM: the standard 2D tLSTM.<br />
<br />
(c) 2D tLSTM–M: removing memory (M) cell convolutions from (b).<br />
<br />
(d) 2D tLSTM–F: removing (–) feedback (F) connections from (b).<br />
<br />
(e) 3D tLSTM: tensorizing (b) into 3D tLSTM.<br />
<br />
(f) 3D tLSTM+LN: applying (+) Layer Normalization.<br />
<br />
(g) 3D tLSTM+CN: applying (+) Channel Normalization.<br />
<br />
=== Efficiency Analysis ===<br />
<br />
'''Setup:''' For each configuration, the parameter number is fixed and the tensor size is increased to see whether the performance of tLSTM can be boosted without adding parameters. The authors also investigate how the runtime is affected by the depth, where the runtime is measured as the average GPU milliseconds spent on a forward and backward pass over one timestep of a single example. <br />
<br />
'''Dataset:''' The Hutter Prize Wikipedia dataset consists of 100 million characters drawn from 205 distinct characters, including letters, XML markup and special symbols. The dataset is modelled at the character level, with the goal of predicting the next character of the input sequence.<br />
<br />
All configurations are evaluated with depths L = 1, 2, 3, 4. Bits-per-character (BPC) is used to measure model performance, and the results are shown in the figures below.<br />
[[File:wiki.png |280px|center||Figure 5: WikiPerf]]<br />
[[File:Wiki_Performance.png |480px|center||Figure 5: WikiPerf]]<br />
<br />
=== Accuracy Analysis ===<br />
<br />
The MNIST dataset [35] consists of 50000/10000/10000 handwritten digit images of size 28×28 for training/validation/test. We have two tasks on this dataset:<br />
<br />
(a) '''Sequential MNIST:''' The goal is to classify the digit after sequentially reading the pixels in a scan-line order. It is therefore a 784 time-step sequence learning task where a single output is produced at the last time-step; the task requires very long range dependencies in the sequence.<br />
<br />
(b) '''Sequential Permuted MNIST:''' We permute the original image pixels in a fixed random order, resulting in a permuted MNIST (pMNIST) problem that has even longer range dependencies across pixels and is harder.<br />
<br />
In both tasks, all configurations are evaluated with M = 100 and L= 1, 3, 5. The model performance is measured by the classification accuracy and results are shown in the figure below.<br />
<br />
[[File:MNISTperf.png |480px|center]]<br />
<br />
<br />
<br />
[[File:Acc_res.png |480px|center||Figure 5: MNIST]]<br />
<br />
[[File:33_mnist.PNG|center|thumb|800px| This figure displays a visualization of the means of the diagonal channels of the tLSTM memory cells per task. The columns indicate the time steps and the rows indicate the diagonal locations. The values are normalized between 0 and 1.]]<br />
<br />
= Conclusions =<br />
<br />
The paper introduced the Tensorized LSTM, which employs tensors to share parameters and utilizes the temporal computation to perform the deep computation for sequential tasks. The model was then validated on a variety of tasks, showing its potential over other popular approaches.<br />
<br />
= Critique(to be edited) =<br />
<br />
= References =<br />
#Zhen He, Shaobing Gao, Liang Xiao, Daxue Liu, Hangen He, and David Barber. ''Wider and Deeper, Cheaper and Faster: Tensorized LSTMs for Sequence Learning'' (2017)<br />
#Ali Ghodsi, <Deep Learning: STAT 946 - Winter 2018></div>X249wanghttp://wiki.math.uwaterloo.ca/statwiki/index.php?title=stat946w18/Tensorized_LSTMs&diff=35878stat946w18/Tensorized LSTMs2018-03-28T22:55:29Z<p>X249wang: /* Added summary on channel normalization */</p>
<hr />
<div>= Presented by =<br />
<br />
Chen, Weishi(Edward)<br />
<br />
= Introduction =<br />
<br />
Long Short-Term Memory (LSTM) is a popular approach to boosting the ability of Recurrent Neural Networks to store longer term temporal information. The capacity of an LSTM network can be increased by widening and adding layers (illustrations will be provided later). <br />
<br />
<br />
However, widening an LSTM introduces additional parameters, while deepening it with additional layers increases the time required for model training and evaluation. As an alternative, the paper ''Wider and Deeper, Cheaper and Faster: Tensorized LSTMs for Sequence Learning'' proposes an LSTM-based model called the Tensorized LSTM, in which the hidden states are represented by '''tensors''' and updated via a '''cross-layer convolution'''. <br />
<br />
* By increasing the tensor size, the network can be widened efficiently without additional parameters since the parameters are shared across different locations in the tensor<br />
* By delaying the output, the network can be deepened implicitly with little additional runtime since deep computations for each time step are merged into temporal computations of the sequence. <br />
<br />
<br />
The paper also presents experiments on five challenging sequence learning tasks that show the potential of the proposed model.<br />
<br />
= A Quick Introduction to RNN and LSTM =<br />
<br />
We consider the time-series prediction task of producing a desired output <math>y_t</math> at each time-step t∈ {1, ..., T} given an observed input sequence <math>x_{1:t} = \{x_1, x_2, \cdots, x_t\}</math>, where <math>x_t∈R^R</math> and <math>y_t∈R^S</math> are vectors. An RNN learns to use a hidden state vector <math>h_t ∈ R^M</math> to encapsulate the relevant features of the entire input history <math>x_{1:t}</math> (all inputs from the initial time-step up to and including time-step t; an illustration is given below).<br />
<br />
\begin{align}<br />
h_{t-1}^{cat} = [x_t, h_{t-1}] \hspace{2cm} (1)<br />
\end{align}<br />
<br />
Where <math>h_{t-1}^{cat} ∈R^{R+M}</math> is the concatenation of the current input <math>x_t</math> and the previous hidden state <math>h_{t−1}</math>, which expands the dimensionality of intermediate information.<br />
<br />
The update of the hidden state <math>h_t</math> is defined as:<br />
<br />
\begin{align}<br />
a_{t} =h_{t-1}^{cat} W^h + b^h \hspace{2cm} (2)<br />
\end{align}<br />
<br />
and<br />
<br />
\begin{align}<br />
h_t = \Phi(a_t) \hspace{2cm} (3)<br />
\end{align}<br />
<br />
<math>W^h∈R^{(R+M)×M}</math> is the weight, whose shape guarantees that each updated hidden state has dimension M, <math>a_t ∈R^M</math> is the hidden activation, and <math>\Phi(·)</math> is the element-wise tanh function. Finally, the output <math>y_t</math> at time-step t is generated by:<br />
<br />
\begin{align}<br />
y_t = \varphi(h_{t} W^y + b^y) \hspace{2cm} (4)<br />
\end{align}<br />
<br />
where <math>W^y∈R^{M×S}</math> and <math>b^y∈R^S</math>, and <math>\varphi(·)</math> can be any differentiable function. Note that <math>\Phi</math> is the element-wise non-linearity that produces the next '''hidden state''', while <math>\varphi</math> is applied to generate the '''output'''.<br />
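<br />
The recurrence in Eqs. (1) to (3) can be sketched in a few lines of plain Python. This is only an illustrative sketch: the function name <code>rnn_step</code>, the toy sizes and the constant weights are assumptions made here, not taken from the paper.<br />
<br />
```python
import math

def rnn_step(x_t, h_prev, W_h, b_h):
    """One vanilla RNN update: h_t = tanh([x_t, h_{t-1}] W^h + b^h)."""
    h_cat = list(x_t) + list(h_prev)              # Eq. (1): length R + M
    M = len(b_h)
    a_t = [sum(h_cat[i] * W_h[i][m] for i in range(len(h_cat))) + b_h[m]
           for m in range(M)]                     # Eq. (2): hidden activation
    return [math.tanh(a) for a in a_t]            # Eq. (3): element-wise tanh

# Toy sizes (illustrative only): R = 2 inputs, M = 3 hidden units.
R, M = 2, 3
W_h = [[0.1] * M for _ in range(R + M)]           # shape (R + M) x M
b_h = [0.0] * M
h = [0.0] * M
for x in ([1.0, 0.0], [0.0, 1.0]):                # a length-2 input sequence
    h = rnn_step(x, h, W_h, b_h)
```
<br />
Because <math>W^h</math> has <math>(R+M) \times M</math> entries, the state keeps dimension M at every step regardless of the sequence length.<br />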
<br />
[[File:StdRNN.png|650px|center||Figure 1: Recurrent Neural Network]]<br />
<br />
However, one shortfall of RNNs is the problem of vanishing/exploding gradients, which is especially significant when modelling long-range dependencies. One alternative is the LSTM (Long Short-Term Memory), which alleviates these problems by employing memory cells to preserve information for longer and adopting gating mechanisms to modulate the information flow. Since the LSTM has been successful in sequence modelling, it is natural to consider how to increase the complexity of the model to accommodate more complex analytical needs.<br />
<br />
[[File:LSTM_Gated.png|650px|center||Figure 2: LSTM]]<br />
<br />
= Structural Measurement of Sequential Model =<br />
<br />
The capacity of a network can be considered to consist of two components: the '''width''' (the amount of information handled in parallel) and the '''depth''' (the number of computation steps). <br />
<br />
A way to '''widen''' the LSTM is to increase the number of units in a hidden layer; however, the parameter number scales quadratically with the number of units. To deepen the LSTM, the popular Stacked LSTM (sLSTM) stacks multiple LSTM layers. The drawback of sLSTM, however, is that runtime is proportional to the number of layers and information from the input is potentially lost (due to gradient vanishing/explosion) as it propagates vertically through the layers. This paper introduced a way to both widen and deepen the LSTM whilst keeping the parameter number and runtime largely unchanged. In summary, we make the following contributions:<br />
<br />
'''(a)''' Tensorize RNN hidden state vectors into higher-dimensional tensors, which enables more flexible parameter sharing so that the network can be widened more efficiently without additional parameters.<br />
<br />
'''(b)''' Based on (a), merge RNN deep computations into its temporal computations so that the network can be deepened with little additional runtime, resulting in a Tensorized RNN (tRNN).<br />
<br />
'''(c)''' We extend the tRNN to an LSTM, namely the Tensorized LSTM (tLSTM), which integrates a novel memory cell convolution to help to prevent the vanishing/exploding gradients.<br />
<br />
= Method =<br />
<br />
Go through the methodology.<br />
<br />
== Part 1: Tensorize RNN hidden State vectors ==<br />
<br />
'''Definition:''' Tensorization is defined as the transformation or mapping of lower-order data to higher-order data. For example, the low-order data can be a vector, and the tensorized result is a matrix, a third-order tensor or a higher-order tensor. The ‘low-order’ data can also be a matrix or a third-order tensor, for example. In the latter case, tensorization can take place along one or multiple modes.<br />
<br />
[[File:VecTsor.png|320px|center||Figure 3: Third-order tensorization of a vector]]<br />
<br />
'''Optimization Methodology Part 1:''' In an RNN, the parameter number scales quadratically with the size of the hidden state. A popular way to limit the parameter number when widening the network is to organize parameters as higher-dimensional tensors which can be factorized into lower-rank sub-tensors that contain significantly fewer elements; this is known as tensor factorization. <br />
<br />
'''Optimization Methodology Part 2:''' Another common way to reduce the parameter number is to share a small set of parameters across different locations in the hidden state, similar to Convolutional Neural Networks (CNNs).<br />
<br />
'''Effects:''' This '''widens''' the network since the hidden state vectors are in fact broadcast to interact with the tensorized parameters. <br />
<br />
<br />
<br />
We adopt parameter sharing to cut down the parameter number for RNNs since, compared with factorization, it has the following advantages: <br />
<br />
(i) '''Scalability,''' the number of shared parameters can be set independent of the hidden state size<br />
<br />
(ii) '''Separability,''' the information flow can be carefully managed by controlling the receptive field, allowing one to shift RNN deep computations to the temporal domain<br />
<br />
<br />
<br />
We also explicitly tensorize the RNN hidden state vectors since, compared with vectors, tensors offer better: <br />
<br />
(i) '''Flexibility,''' one can specify which dimensions to share parameters and then can just increase the size of those dimensions without introducing additional parameters<br />
<br />
(ii) '''Efficiency,''' with higher-dimensional tensors, the network can be widened faster w.r.t. its depth when fixing the parameter number (explained later). <br />
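<br />
The scalability and flexibility arguments above can be checked with simple parameter counting. The sketch below is an illustration under stated assumptions: the helper names are hypothetical, and the shared kernel is assumed to have size K with M input and M output channels (the shape used for the cross-layer convolution in Part 2).<br />
<br />
```python
def rnn_params(R, M):
    """Vanilla RNN: W^h is (R + M) x M and b^h has M entries."""
    return (R + M) * M + M

def shared_kernel_params(K, M):
    """Convolution kernel of size K with M input and M output channels,
    plus M biases. The count does not depend on the tensor size P,
    because the same kernel is reused at every location."""
    return K * M * M + M

# Doubling the hidden size of a vanilla RNN roughly quadruples its parameters.
small = rnn_params(R=10, M=100)
large = rnn_params(R=10, M=200)

# A tensorized state of shape P x M can grow in P at no parameter cost.
shared = shared_kernel_params(K=3, M=100)
```
<br />
With these toy numbers the RNN grows from 11,100 to 42,200 parameters when M doubles, while the shared kernel stays at 30,100 parameters whether the state has P = 2 or P = 20 locations.<br />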
<br />
<br />
'''Illustration:''' For ease of exposition, we first consider 2D tensors (matrices): we tensorize the hidden state <math>h_t∈R^{M}</math> to become <math>H_t∈R^{P×M}</math>, '''where P is the tensor size,''' and '''M the channel size'''. We locally-connect the first dimension of <math>H_t</math> (which is P - the tensor size) in order to share parameters, and fully-connect the second dimension of <math>H_t</math> (which is M - the channel size) to allow global interactions. This is analogous to the CNN which fully-connects one dimension (e.g., the RGB channel for input images) to globally fuse different feature planes. One can also compare <math>H_t</math> to the hidden state of a Stacked RNN (sRNN), shown in the figure below. <br />
<br />
[[File:Screen_Shot_2018-03-26_at_11.28.37_AM.png|160px|center||Figure 4: Stacked RNN]]<br />
<br />
[[File:ind.png|60px|center||Figure 4: Stacked RNN]]<br />
<br />
Then P is akin to the number of stacked hidden layers (vertical length in the graph), and M the size of each hidden layer (each white node in the graph). We start to describe our model based on 2D tensors, and finally show how to strengthen the model with higher-dimensional tensors.<br />
<br />
== Part 2: Merging Deep Computations ==<br />
<br />
Since an RNN is already deep in its temporal direction, we can deepen an input-to-output computation by associating the input <math>x_t</math> with a (delayed) future output. In doing this, we need to ensure that the output <math>y_t</math> is separable, i.e., not influenced by any future input <math>x_{t^{'}}</math> <math>(t^{'}>t)</math>. Thus, we concatenate the projection of <math>x_t</math> to the top of the previous hidden state <math>H_{t−1}</math>, then gradually shift the input information down when the temporal computation proceeds, and finally generate <math>y_t</math> from the bottom of <math>H_{t+L−1}</math>, where L−1 is the number of delayed time-steps for computations of depth L. <br />
<br />
An example with L = 3 is shown in the figure below.<br />
<br />
[[File:tRNN.png|160px|center||Figure 5: skewed sRNN]]<br />
<br />
[[File:ind.png|60px|center||Figure 5: skewed sRNN]]<br />
<br />
<br />
This is in fact a skewed sRNN (or tRNN without feedback). However, the method does not need to change the network structure and also allows different kinds of interactions as long as the output is separable; for example, one can increase the local connections and '''use feedback''' (shown in figure below), which can be beneficial for sRNNs (or tRNN). <br />
<br />
[[File:tRNN_wF.png|160px|center||Figure 5: skewed sRNN with F]]<br />
<br />
[[File:ind.png|60px|center||Figure 5: skewed sRNN with F]]<br />
<br />
'''In order to share parameters, we update <math>H_t</math> using a convolution with a learnable kernel.''' In this manner we increase the complexity of the input-to-output mapping (by delaying outputs) and limit parameter growth (by sharing transition parameters using convolutions).<br />
<br />
To examine the resulting model mathematically, let <math>H^{cat}_{t−1}∈R^{(P+1)×M}</math> be the concatenated hidden state, and <math>p∈Z_+</math> a location in the tensor. The channel vector <math>h^{cat}_{t−1, p }∈R^M</math> at location p of <math>H^{cat}_{t−1}</math> (the p-th row of <math>H^{cat}_{t−1}</math>) is defined as:<br />
<br />
\begin{align}<br />
h^{cat}_{t-1, p} = x_t W^x + b^x \hspace{1cm} \text{if } p = 1 \hspace{1cm} (5)<br />
\end{align}<br />
<br />
\begin{align}<br />
h^{cat}_{t-1, p} = h_{t-1, p-1} \hspace{1cm} \text{if } p > 1 \hspace{1cm} (6)<br />
\end{align}<br />
<br />
where <math>W^x ∈ R^{R×M}</math> and <math>b^x ∈ R^M</math> (recall the dimension of input x is R). Then, the update of tensor <math>H_t</math> is implemented via a convolution:<br />
<br />
\begin{align}<br />
A_t = H^{cat}_{t-1} \circledast \{W^h, b^h \} \hspace{2cm} (7)<br />
\end{align}<br />
<br />
\begin{align}<br />
H_t = \Phi(A_t) \hspace{2cm} (8)<br />
\end{align}<br />
<br />
where <math>W^h∈R^{K×M^i×M^o}</math> is the kernel weight of size K, with <math>M^i =M</math> input channels and <math>M^o =M</math> output channels, <math>b^h ∈ R^{M^o}</math> is the kernel bias, <math>A_t ∈ R^{P×M^o}</math> is the hidden activation, and <math>\circledast</math> is the convolution operator. Since the kernel convolves across different hidden layers, we call it the cross-layer convolution. The kernel enables interaction, both bottom-up and top-down across layers. Finally, we generate <math>y_t</math> from the channel vector <math>h_{t+L−1,P}∈R^M</math> which is located at the bottom of <math>H_{t+L−1}</math>:<br />
<br />
\begin{align}<br />
y_t = \varphi(h_{t+L-1, P} W^y + b^y) \hspace{2cm} (9)<br />
\end{align}<br />
<br />
where <math>W^y ∈R^{M×S}</math> and <math>b^y ∈R^S</math>. The depth L, tensor size P and kernel size K must be chosen so that the receptive field of <math>y_t</math> covers only the current and previous inputs <math>x_{1:t}</math> (see the skewed sRNN again below):<br />
<br />
[[File:tRNN_wF.png|160px|center||Figure 5: skewed sRNN with F]]<br />
<br />
[[File:ind.png|60px|center||Figure 5: skewed sRNN with F]]<br />
<br />
=== Quick Summary of Set of Parameters ===<br />
<br />
'''1. <math> W^x</math> and <math>b^x</math>''' connect the input to the first hidden node<br />
<br />
'''2. <math> W^h</math> and <math>b^h</math>''' convolve across layers<br />
<br />
'''3. <math> W^y</math> and <math>b^y</math>''' produce the output at each stage<br />
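<br />
To make the roles of these parameter sets concrete, here is a minimal sketch of one tRNN update (Eqs. 5 to 8) for a 2D hidden state. It is not the authors' implementation: the function name <code>trnn_step</code>, the toy weights, and in particular the alignment of the kernel (output location p reads concatenated locations p to p+K−1, with zeros past the end) are assumptions chosen for illustration.<br />
<br />
```python
import math

def trnn_step(x_t, H_prev, W_x, b_x, W_h, b_h):
    """One tRNN update for a hidden state H of shape P x M.

    W_x: R x M, b_x: M, W_h: K x M x M (shared over locations), b_h: M.
    Assumption: output location p reads concatenated locations
    p .. p+K-1, with zeros past the end of the tensor.
    """
    M, P, K = len(b_x), len(H_prev), len(W_h)
    # Eq. (5): project the input to M channels and place it on top.
    top = [sum(x_t[r] * W_x[r][m] for r in range(len(x_t))) + b_x[m]
           for m in range(M)]
    H_cat = [top] + [list(row) for row in H_prev]   # Eq. (6): (P + 1) x M
    zero = [0.0] * M
    H_new = []
    for p in range(P):
        a = list(b_h)
        for k in range(K):
            src = H_cat[p + k] if p + k < len(H_cat) else zero
            for mi in range(M):
                for mo in range(M):
                    a[mo] += src[mi] * W_h[k][mi][mo]   # Eq. (7)
        H_new.append([math.tanh(v) for v in a])         # Eq. (8)
    return H_new

# Toy run (illustrative sizes): R = 1, M = 2, P = 3, K = 2.
W_x, b_x = [[1.0, 1.0]], [0.0, 0.0]
W_h = [[[0.5, 0.5], [0.5, 0.5]] for _ in range(2)]
b_h = [0.0, 0.0]
H = [[0.0, 0.0] for _ in range(3)]
history = []
for x in ([1.0], [0.0], [0.0]):        # an impulse followed by silence
    H = trnn_step(x, H, W_x, b_x, W_h, b_h)
    history.append(H)
```
<br />
In this toy run the impulse moves down one location per time step and only reaches the bottom location on the third update, which is why the output for an input at step t is read two steps later (L − 1 = 2 for L = 3).<br />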
<br />
<br />
== Part 3: Extending to LSTMs==<br />
<br />
Similar to the standard RNN, to allow the tRNN (skewed sRNN) to capture long-range temporal dependencies, one can straightforwardly extend it to a tLSTM by replacing the tRNN tensors:<br />
<br />
\begin{align}<br />
[A^g_t, A^i_t, A^f_t, A^o_t] = H^{cat}_{t-1} \circledast \{W^h, b^h \} \hspace{2cm} (10)<br />
\end{align}<br />
<br />
\begin{align}<br />
[G_t, I_t, F_t, O_t]= [\Phi(A^g_t), \sigma(A^i_t), \sigma(A^f_t), \sigma(A^o_t)] \hspace{2cm} (11)<br />
\end{align}<br />
<br />
These are similar to the tRNN case; the main difference lies in the memory cells of the tLSTM (<math>C_t</math>):<br />
<br />
\begin{align}<br />
C_t= G_t \odot I_t + C_{t-1} \odot F_t \hspace{2cm} (12)<br />
\end{align}<br />
<br />
\begin{align}<br />
H_t= \Phi{(C_t )} \odot O_t \hspace{2cm} (13)<br />
\end{align}<br />
<br />
Summary of the terms: <br />
<br />
1. '''<math>G_t</math>:''' Activation of new content<br />
<br />
2. '''<math>I_t</math>:''' Input gate<br />
<br />
3. '''<math>F_t</math>:''' Forget gate<br />
<br />
4. '''<math>O_t</math>:''' Output gate<br />
<br />
Then, see graph below for illustration:<br />
<br />
[[File:tLSTM_wo_MC.png |160px|center||Figure 5: tLSTM wo MC]]<br />
<br />
[[File:ind.png|60px|center||Figure 5: tLSTM wo MC]]<br />
<br />
To further evolve the tLSTM and capture long-range dependencies from multiple directions, the authors additionally introduce a novel '''Memory Cell Convolution''', by which the memory cells can have a larger receptive field (figure provided below). <br />
<br />
[[File:tLSTM_w_MC.png |160px|center||Figure 5: tLSTM w MC]]<br />
<br />
[[File:ind.png|60px|center||Figure 5: tLSTM w MC]]<br />
<br />
One can also dynamically generate this convolution kernel so that it is both time- and location-dependent, allowing for flexible control over long-range dependencies from different directions. Mathematically, it can be represented with the following formulas:<br />
<br />
\begin{align}<br />
[A^g_t, A^i_t, A^f_t, A^o_t, A^q_t] = H^{cat}_{t-1} \circledast \{W^h, b^h \} \hspace{2cm} (14)<br />
\end{align}<br />
<br />
\begin{align}<br />
[G_t, I_t, F_t, O_t, Q_t]= [\Phi(A^g_t), \sigma(A^i_t), \sigma(A^f_t), \sigma(A^o_t), \varsigma(A^q_t)] \hspace{2cm} (15)<br />
\end{align}<br />
<br />
\begin{align}<br />
W_t^c(p) = reshape(q_{t,p}, [K, 1, 1]) \hspace{2cm} (16)<br />
\end{align}<br />
<br />
\begin{align}<br />
C_{t-1}^{conv}= C_{t-1} \circledast W_t^c(p) \hspace{2cm} (17)<br />
\end{align}<br />
<br />
\begin{align}<br />
C_t= G_t \odot I_t + C_{t-1}^{conv} \odot F_t \hspace{2cm} (18)<br />
\end{align}<br />
<br />
\begin{align}<br />
H_t= \Phi{(C_t )} \odot O_t \hspace{2cm} (19)<br />
\end{align}<br />
<br />
where the kernel <math>\{W^h, b^h\}</math> has additional <math>\langle K \rangle</math> output channels to generate the activation <math>A^q_t ∈ R^{P×\langle K \rangle}</math> for the dynamic kernel bank <math>Q_t∈R^{P × \langle K \rangle}</math>, <math>q_{t,p}∈R^{\langle K \rangle}</math> is the vectorized adaptive kernel at location p of <math>Q_t</math>, and <math>W^c_t(p) ∈ R^{K×1×1}</math> is the dynamic kernel of size K with a single input/output channel, which is reshaped from <math>q_{t,p}</math>. Note that the paper also employs a softmax function ς(·) to normalize the channel dimension of <math>Q_t</math>, which also stabilizes the values of the memory cells and helps to prevent vanishing/exploding gradients. The process is illustrated below:<br />
<br />
[[File:MCC.png |240px|center||Figure 5: MCC]]<br />
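<br />
The memory cell convolution of Eqs. (15) to (19) can be sketched as follows for the 2D case. This is a hedged illustration only: the helper names, the toy values, and the choice to centre the kernel on location p with zero padding outside the tensor are assumptions made here rather than details confirmed by the summary above.<br />
<br />
```python
import math

def softmax(v):
    """Numerically stable softmax of a list of floats."""
    m = max(v)
    e = [math.exp(x - m) for x in v]
    s = sum(e)
    return [x / s for x in e]

def memory_cell_conv(C_prev, A_q):
    """Convolve C_{t-1} (P x M) along the location axis with a kernel
    that differs per location: Q_t[p] = softmax(A_q[p]), of size K.
    Assumption: kernel centred on p, zeros outside the tensor."""
    P, M, K = len(C_prev), len(C_prev[0]), len(A_q[0])
    half = (K - 1) // 2
    C_conv = []
    for p in range(P):
        q = softmax(A_q[p])                    # Eqs. (15)-(16): kernel at p
        row = [0.0] * M
        for k in range(K):
            src = p - half + k
            if 0 <= src < P:
                for m in range(M):
                    row[m] += q[k] * C_prev[src][m]   # Eq. (17)
        C_conv.append(row)
    return C_conv

def cell_update(G, I, F, O, C_conv):
    """Eqs. (18)-(19), applied element-wise to P x M tensors."""
    C = [[g * i + c * f for g, i, f, c in zip(gr, ir, fr, cr)]
         for gr, ir, fr, cr in zip(G, I, F, C_conv)]
    H = [[math.tanh(c) * o for c, o in zip(cr, orow)]
         for cr, orow in zip(C, O)]
    return C, H

# Toy check (illustrative sizes): P = 3 locations, M = 2 channels, K = 3.
P, M, K = 3, 2, 3
C_prev = [[1.0] * M for _ in range(P)]
A_q = [[0.0] * K for _ in range(P)]            # uniform kernel after softmax
C_conv = memory_cell_conv(C_prev, A_q)
ones = [[1.0] * M for _ in range(P)]
halves = [[0.5] * M for _ in range(P)]
C, H = cell_update(halves, ones, ones, ones, C_conv)
```
<br />
Because the softmax makes each kernel a set of non-negative weights summing to one, every convolved cell value is a weighted average of neighbouring cells, which is one way to see why the normalization keeps the memory cell values bounded.<br />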
<br />
To improve training, the authors introduced a new normalization technique for ''t''LSTM termed channel normalization (adapted from layer normalization), in which the channel vectors at different locations are normalized with their own statistics. Layer normalization does not work well with ''t''LSTM, because lower-level information is near the input and higher-level information is near the output. Channel normalization (CN) is defined as: <br />
<br />
\begin{align}<br />
\mathrm{CN}(\mathbf{Z}; \mathbf{\Gamma}, \mathbf{B}) = \mathbf{\hat{Z}} \odot \mathbf{\Gamma} + \mathbf{B} \hspace{2cm} (20)<br />
\end{align}<br />
<br />
where <math>\mathbf{Z}</math>, <math>\mathbf{\hat{Z}}</math>, <math>\mathbf{\Gamma}</math>, <math>\mathbf{B} \in \mathbb{R}^{P \times M^z}</math> are the original tensor, normalized tensor, gain parameter and bias parameter, respectively. The <math>m^z</math>-th channel of <math>\mathbf{Z}</math> is normalized element-wise: <br />
<br />
\begin{align}<br />
\hat{z}_{m^z} = (z_{m^z} - z^\mu)/z^{\sigma} \hspace{2cm} (21)<br />
\end{align}<br />
<br />
where <math>z^{\mu}</math>, <math>z^{\sigma} \in \mathbb{R}^P</math> are the mean and standard deviation along the channel dimension of <math>\mathbf{Z}</math>, and <math>\hat{z}_{m^z} \in \mathbb{R}^P</math> is the <math>m^z</math>-th channel of <math>\mathbf{\hat{Z}}</math>. Channel normalization introduces very few additional parameters compared to the number of other parameters in the model.<br />
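<br />
A minimal sketch of channel normalization as defined in Eqs. (20) and (21) is given below. The function name and the small constant <code>eps</code> added for numerical stability are assumptions introduced here; they do not appear in the definition above.<br />
<br />
```python
import math

def channel_norm(Z, Gamma, B, eps=1e-5):
    """Normalize each location's channel vector with its own statistics.

    Z, Gamma, B are P x M^z nested lists. eps is an assumed stabilizer
    that is not part of Eq. (21).
    """
    out = []
    for z, gamma, beta in zip(Z, Gamma, B):
        mu = sum(z) / len(z)                          # mean over channels
        var = sum((v - mu) ** 2 for v in z) / len(z)
        sd = math.sqrt(var + eps)                     # std over channels
        out.append([(v - mu) / sd * g + b             # Eqs. (20)-(21)
                    for v, g, b in zip(z, gamma, beta)])
    return out

# Two locations whose activations live on very different scales.
Z = [[1.0, 2.0, 3.0], [10.0, 20.0, 30.0]]
Gamma = [[1.0] * 3 for _ in Z]
B = [[0.0] * 3 for _ in Z]
Z_hat = channel_norm(Z, Gamma, B)
```
<br />
Both rows end up with nearly identical normalized values even though their raw scales differ by a factor of ten, which is the point of giving each location its own mean and standard deviation.<br />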
<br />
= Results and Evaluation =<br />
<br />
Summary of the models compared (the tLSTM family and its baseline):<br />
<br />
(a) sLSTM (baseline): the implementation of sLSTM with parameters shared across all layers.<br />
<br />
(b) 2D tLSTM: the standard 2D tLSTM.<br />
<br />
(c) 2D tLSTM–M: removing memory (M) cell convolutions from (b).<br />
<br />
(d) 2D tLSTM–F: removing (–) feedback (F) connections from (b).<br />
<br />
(e) 3D tLSTM: tensorizing (b) into 3D tLSTM.<br />
<br />
(f) 3D tLSTM+LN: applying (+) Layer Normalization.<br />
<br />
(g) 3D tLSTM+CN: applying (+) Channel Normalization.<br />
<br />
=== Efficiency Analysis ===<br />
<br />
'''Setup:''' For each configuration, the parameter number is fixed and the tensor size is increased to see whether the performance of tLSTM can be boosted without adding parameters. The authors also investigate how the runtime is affected by the depth, where the runtime is measured as the average GPU milliseconds spent on a forward and backward pass over one timestep of a single example. <br />
<br />
'''Dataset:''' The Hutter Prize Wikipedia dataset consists of 100 million characters drawn from 205 distinct characters, including letters, XML markup and special symbols. The dataset is modelled at the character level, with the goal of predicting the next character of the input sequence.<br />
<br />
All configurations are evaluated with depths L = 1, 2, 3, 4. Bits-per-character (BPC) is used to measure model performance, and the results are shown in the figures below.<br />
[[File:wiki.png |280px|center||Figure 5: WikiPerf]]<br />
[[File:Wiki_Performance.png |480px|center||Figure 5: WikiPerf]]<br />
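<br />
For readers unfamiliar with the metric, bits-per-character is the average negative base-2 log-probability that the model assigns to the true next character. The short sketch below illustrates this standard definition; the function name and example probabilities are made up for illustration.<br />
<br />
```python
import math

def bits_per_character(probs):
    """Average of -log2(p) over the probabilities the model assigned
    to each true next character. Lower is better."""
    return -sum(math.log2(p) for p in probs) / len(probs)

# A model that always guesses uniformly over the 205 characters.
uniform_bpc = bits_per_character([1.0 / 205] * 10)

# A model that gives the true characters probability 0.5 and 0.25.
example_bpc = bits_per_character([0.5, 0.25])
```
<br />
A uniform guess over 205 characters costs <math>\log_2 205</math> bits per character, so any useful model must score well below that.<br />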
<br />
=== Accuracy Analysis ===<br />
<br />
The MNIST dataset [35] consists of 50000/10000/10000 handwritten digit images of size 28×28 for training/validation/test. We have two tasks on this dataset:<br />
<br />
(a) '''Sequential MNIST:''' The goal is to classify the digit after sequentially reading the pixels in a scan-line order. It is therefore a 784 time-step sequence learning task where a single output is produced at the last time-step; the task requires very long range dependencies in the sequence.<br />
<br />
(b) '''Sequential Permuted MNIST:''' We permute the original image pixels in a fixed random order, resulting in a permuted MNIST (pMNIST) problem that has even longer range dependencies across pixels and is harder.<br />
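<br />
The two input formats can be sketched as follows. This is an illustrative sketch: the helper names and the seed value are assumptions, and a synthetic 28×28 grid stands in for a real MNIST image.<br />
<br />
```python
import random

def to_pixel_sequence(image):
    """Flatten a 28x28 image (list of rows) into a 784-step sequence
    in scan-line order, one pixel per time step."""
    return [pixel for row in image for pixel in row]

def make_permutation(n=784, seed=0):
    """A fixed random order, generated once and reused for every image.
    The seed value is an arbitrary choice for this sketch."""
    rng = random.Random(seed)
    perm = list(range(n))
    rng.shuffle(perm)
    return perm

def permute(seq, perm):
    """Reorder a pixel sequence with the shared permutation (pMNIST)."""
    return [seq[i] for i in perm]

# Synthetic stand-in for an MNIST digit: pixel value = its scan-line index.
image = [[r * 28 + c for c in range(28)] for r in range(28)]
seq = to_pixel_sequence(image)
perm = make_permutation()
pseq = permute(seq, perm)
```
<br />
The permutation keeps every pixel but destroys the spatial locality of neighbouring time steps, which is why the permuted task demands longer-range dependencies.<br />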
<br />
In both tasks, all configurations are evaluated with M = 100 and L= 1, 3, 5. The model performance is measured by the classification accuracy and results are shown in the figure below.<br />
<br />
[[File:MNISTperf.png |480px|center]]<br />
<br />
<br />
<br />
[[File:Acc_res.png |480px|center||Figure 5: MNIST]]<br />
<br />
[[File:33_mnist.PNG|center|thumb|800px| This figure displays a visualization of the means of the diagonal channels of the tLSTM memory cells per task. The columns indicate the time steps and the rows indicate the diagonal locations. The values are normalized between 0 and 1.]]<br />
<br />
= Conclusions =<br />
<br />
The paper introduced the Tensorized LSTM, which employs tensors to share parameters and utilizes the temporal computation to perform the deep computation for sequential tasks. The model was then validated on a variety of tasks, showing its potential over other popular approaches.<br />
<br />
= Critique(to be edited) =<br />
<br />
= References =<br />
#Zhen He, Shaobing Gao, Liang Xiao, Daxue Liu, Hangen He, and David Barber. ''Wider and Deeper, Cheaper and Faster: Tensorized LSTMs for Sequence Learning'' (2017)<br />
#Ali Ghodsi, <Deep Learning: STAT 946 - Winter 2018></div>X249wanghttp://wiki.math.uwaterloo.ca/statwiki/index.php?title=Training_And_Inference_with_Integers_in_Deep_Neural_Networks&diff=35877Training And Inference with Integers in Deep Neural Networks2018-03-28T22:17:29Z<p>X249wang: /* Bitwidth of Gradients */</p>
<hr />
<div>== Introduction ==<br />
<br />
Deep neural networks have enjoyed much success in all manner of tasks, but it is common for these networks to be complicated, requiring large amounts of energy-intensive memory and floating-point operations. Therefore, in order to use state-of-the-art networks in applications where energy is limited or where the hardware has packaging limitations, such as anything not connected to the power grid, the energy costs must be reduced while preserving as much performance as practical.<br />
<br />
Most existing methods focus on reducing the energy requirements during inference rather than training, that is, on how to compress a model for inference. Since training with SGD requires accumulation, training usually has a higher precision demand than inference. This paper proposes a framework to reduce complexity during both training and inference through the use of integers instead of floats. The authors address how to quantize all operations and operands, and examine the bitwidth requirement for SGD computation and accumulation. Using integers instead of floats results in energy savings because integer operations are more efficient than floating-point operations (see the table below). Also, there already exists dedicated hardware for deep learning that uses integer operations (such as the 1st generation of Google TPU), so understanding the best way to use integers is well-motivated.<br />
{| class="wikitable"<br />
|+Rough Energy Costs in 45nm 0.9V [1]<br />
!<br />
! colspan="2" |Energy(pJ)<br />
! colspan="2" |Area(<math>\mu m^2</math>)<br />
|-<br />
!Operation<br />
!MUL<br />
!ADD<br />
!MUL<br />
!ADD<br />
|-<br />
|8-bit INT<br />
|0.2<br />
|0.03<br />
|282<br />
|36<br />
|-<br />
|16-bit FP<br />
|1.1<br />
|0.4<br />
|1640<br />
|1360<br />
|-<br />
|32-bit FP<br />
|3.7<br />
|0.9<br />
|7700<br />
|4184<br />
|}<br />
The authors call the framework WAGE because they consider how best to handle the '''W'''eights, '''A'''ctivations, '''G'''radients, and '''E'''rrors separately.<br />
<br />
== Related Work ==<br />
<br />
=== Weight and Activation ===<br />
Existing works to train DNNs on binary weights and activations [2] add noise to weights and activations as a form of regularization. The use of high-precision accumulation is required for SGD optimization since real-valued gradients are obtained from real-valued variables. XNOR-Net [11] uses bitwise operations to approximate convolutions in a highly memory-efficient manner, and applies a filter-wise scaling factor for weights to improve performance. However, these floating-point factors are calculated simultaneously during training, which aggravates the training effort. Ternary weight networks (TWN) [3] and Trained ternary quantization (TTQ) [9] offer more expressive ability than binary weight networks by constraining the weights to be ternary-valued {-1,0,1} using two symmetric thresholds.<br />
<br />
=== Gradient Computation and Accumulation ===<br />
DoReFa-Net quantizes gradients to low-bitwidth floating-point numbers with discrete states in the backward pass. In order to reduce the overhead of gradient synchronization in distributed training, the TernGrad method quantizes the gradient updates to ternary values. In both works the weights are still stored and updated with float32, and the quantization of batch normalization and its derivative is ignored.<br />
<br />
== WAGE Quantization ==<br />
The core idea of the proposed method is to constrain the following to low-bitwidth integers on each layer:<br />
* '''W:''' weight in inference<br />
* '''a:''' activation in inference<br />
* '''e:''' error in backpropagation<br />
* '''g:''' gradient in backpropagation<br />
[[File:p32fig1.PNG|center|thumb|800px|Four operators QW (·), QA(·), QG(·), QE(·) added in WAGE computation dataflow to reduce precision, bitwidth of signed integers are below or on the right of arrows, activations are included in MAC for concision.]]<br />
The error and gradient are defined as:<br />
<br />
<math>e^i = \frac{\partial L}{\partial a^i}, g^i = \frac{\partial L}{\partial W^i}</math><br />
<br />
where L is the loss function.<br />
<br />
The precision in bits of the errors, activations, gradients, and weights are <math>k_E</math>, <math>k_A</math>, <math>k_G</math>, and <math>k_W</math> respectively. As shown in the above figure, each quantity also has a quantization operator to reduce the bitwidth increase caused by multiply-accumulate (MAC) operations. Also, note that since this is a layer-by-layer approach, each layer may be followed or preceded by a layer with different precision, or even a layer using floating point math.<br />
<br />
=== Shift-Based Linear Mapping and Stochastic Mapping ===<br />
The proposed method makes use of a linear mapping where continuous, unbounded values are discretized for each bitwidth <math>k</math> with a uniform spacing of<br />
<br />
<math>\sigma(k) = 2^{1-k}, k \in Z_+ </math><br />
With this, the full quantization function is<br />
<br />
<math>Q(x,k) = Clip\left \{ \sigma(k) \cdot round\left [ \frac{x}{\sigma(k)} \right ], -1 + \sigma(k), 1 - \sigma(k) \right \}</math>, <br />
<br />
where <math>round</math> approximates continuous values to their nearest discrete state, and <math>Clip</math> is the saturation function that clips unbounded values to <math>[-1 + \sigma, 1 - \sigma]</math>. Note that this function is only used when simulating integer operations on floating-point hardware; on native integer hardware, this is done automatically. In addition to this quantization function, a distribution scaling factor is used in some quantization operators to preserve as much variance as possible when applying the quantization function above. The scaling factor is defined below.<br />
<br />
<math>Shift(x) = 2^{round(log_2(x))}</math><br />
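<br />
The three definitions above can be sketched in plain Python as follows. This is only an illustration on scalar floats, not the authors' implementation; the function names and the tie-breaking rule (round half up) are assumptions made here.<br />

```python
import math


def sigma(k: int) -> float:
    """Minimum step size for bitwidth k: sigma(k) = 2^(1 - k)."""
    return 2.0 ** (1 - k)


def clip(x: float, lo: float, hi: float) -> float:
    """Saturate x to the closed interval [lo, hi]."""
    return max(lo, min(hi, x))


def quantize(x: float, k: int) -> float:
    """Q(x, k): snap x to the nearest multiple of sigma(k), then saturate."""
    s = sigma(k)
    nearest = s * math.floor(x / s + 0.5)  # round half up (assumed tie rule)
    return clip(nearest, -1 + s, 1 - s)


def shift(x: float) -> float:
    """Shift(x): the power of two nearest to x on a log2 scale (x > 0)."""
    return 2.0 ** math.floor(math.log2(x) + 0.5)
```

With <math>k = 2</math> the step is 0.5, so <code>quantize</code> can only return -0.5, 0 or 0.5; with <math>k = 8</math> the largest representable value is <math>1 - 2^{-7}</math>.<br />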
<br />
Finally, stochastic rounding is substituted for small or real-valued updates during gradient accumulation.<br />
<br />
A visual representation of these operations is below.<br />
[[File:p32fig2.PNG|center|thumb|800px|Quantization methods used in WAGE. The notation <math>P, x, \lfloor \cdot \rfloor, \lceil \cdot \rceil</math> denotes probability, vector, floor and ceil, respectively. <math>Shift(\cdot)</math> refers to distribution shifting with a certain argument]]<br />
<br />
=== Weight Initialization ===<br />
In this work, batch normalization is simplified to a constant scaling layer in order to sidestep the problem of normalizing outputs without floating point math, and to remove the extra memory requirement of batch normalization. As such, some care must be taken when initializing weights. The authors use a modified initialization method based on MSRA [4].<br />
<br />
<math>W \thicksim U(-L, +L),L = max \left \{ \sqrt{6/n_{in}}, L_{min} \right \}, L_{min} = \beta \sigma</math><br />
<br />
<math>n_{in}</math> is the layer fan-in number, and <math>U</math> denotes the uniform distribution. The original initialization method is modified by adding the condition that the distribution limit <math>L</math> should be at least <math>\beta \sigma</math>, where <math>\beta</math> is a constant greater than 1 and <math>\sigma</math> is the minimum step size defined above. This prevents the weights from all being initialised to zero when the bitwidth is low or the fan-in number is high.<br />
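<br />
A minimal sketch of this initialization rule is given below. The default <code>beta=1.5</code>, the function names and the nested-list weight layout are assumptions for illustration; the text above only requires <math>\beta > 1</math>.<br />

```python
import math
import random


def sigma(k: int) -> float:
    """Minimum step size for bitwidth k."""
    return 2.0 ** (1 - k)


def init_limit(n_in: int, k_w: int, beta: float = 1.5) -> float:
    """L = max(sqrt(6 / n_in), beta * sigma(k_w)); beta is an assumed value."""
    l_min = beta * sigma(k_w)
    return max(math.sqrt(6.0 / n_in), l_min)


def init_weights(n_in: int, n_out: int, k_w: int, beta: float = 1.5, seed: int = 0):
    """Draw an n_out x n_in weight matrix from U(-L, +L)."""
    rng = random.Random(seed)
    limit = init_limit(n_in, k_w, beta)
    return [[rng.uniform(-limit, limit) for _ in range(n_in)]
            for _ in range(n_out)]
```

For a small fan-in the usual limit dominates (<code>init_limit(6, 2)</code> is 1.0), while for a large fan-in the floor takes over (<code>init_limit(600, 2)</code> is 0.75).<br />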
<br />
=== Quantization Details ===<br />
<br />
==== Weight <math>Q_W(\cdot)</math> ====<br />
<math>W_q = Q_W(W) = Q(W, k_W)</math><br />
<br />
The quantization operator is simply the quantization function previously introduced. <br />
<br />
==== Activation <math>Q_A(\cdot)</math> ====<br />
The authors say that the variance of the weights passed through this function will be scaled compared to the variance of the weights as initialized. To prevent this effect from blowing up the network outputs, they introduce a scaling factor <math>\alpha</math>. Notice that it is constant for each layer.<br />
<br />
<math>\alpha = max \left \{ Shift(L_{min} / L), 1 \right \}</math><br />
<br />
The quantization operator is then<br />
<br />
<math>a_q = Q_A(a) = Q(a/\alpha, k_A)</math><br />
<br />
The scaling factor approximates batch normalization.<br />
<br />
==== Error <math>Q_E(\cdot)</math> ====<br />
The authors note that the magnitude of the error can vary greatly, and that a previous approach (DoReFa-Net [5]) solves the issue by using an affine transform to map the error to the range <math>[-1, 1]</math>, applying quantization, and then applying the inverse transform. However, the authors claim that this approach still requires float32, and that the magnitude of the error is unimportant; what matters is its orientation. Thus, they only scale the error distribution to the range <math>\left [ -\sqrt2, \sqrt2 \right ]</math> and quantise:<br />
<br />
<math>e_q = Q_E(e) = Q(e/Shift(max\{|e|\}), k_E)</math><br />
<br />
Here <math>max\{|e|\}</math> is the largest absolute value among all elements of the layer's error. Note that this discards any error elements that fall below the minimum step size after scaling.<br />
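<br />
The error operator can be sketched as below, treating the layer's error as a flat list of floats. The helper names, the round-half-up tie rule and the all-zero guard are assumptions added for illustration.<br />

```python
import math


def sigma(k: int) -> float:
    return 2.0 ** (1 - k)


def quantize(x: float, k: int) -> float:
    """Q(x, k): nearest multiple of sigma(k), saturated to [-1 + s, 1 - s]."""
    s = sigma(k)
    nearest = s * math.floor(x / s + 0.5)  # round half up (assumed tie rule)
    return max(-1 + s, min(1 - s, nearest))


def shift(x: float) -> float:
    return 2.0 ** math.floor(math.log2(x) + 0.5)


def quantize_error(errors, k_e: int):
    """Q_E: divide by Shift(max |e|) over the whole layer, then apply Q."""
    peak = max(abs(e) for e in errors)
    if peak == 0.0:  # guard added here; log2(0) is undefined
        return [0.0 for _ in errors]
    scale = shift(peak)
    return [quantize(e / scale, k_e) for e in errors]


q = quantize_error([0.003, -0.0001, 0.0015, 1e-6], 8)
```

In this example the smallest element (<code>1e-6</code>) ends up below half a step after scaling and is rounded to zero, which is the discarding effect mentioned above.<br />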
<br />
==== Gradient <math>Q_G(\cdot)</math> ====<br />
Similar to the activations and errors, the gradients are rescaled:<br />
<br />
<math>g_s = \eta \cdot g/Shift(max\{|g|\})</math><br />
<br />
<math> \eta </math> is a shift-based learning rate. It is an integer power of 2. The shifted gradients are represented in units of minimum step sizes <math> \sigma(k) </math>. When reducing the bitwidth of the gradients (remember that the gradients are coming out of a MAC operation, so the bitwidth may have increased) stochastic rounding is used as a substitute for small gradient accumulation.<br />
<br />
<math>\Delta W = Q_G(g) = \sigma(k_G) \cdot sgn(g_s) \cdot \left \{ \lfloor | g_s | \rfloor + Bernoulli(|g_s|<br />
- \lfloor | g_s | \rfloor) \right \}</math><br />
<br />
This randomly rounds the result of the MAC operation up or down to the nearest quantization level for the given gradient bitwidth. The weights are updated with the resulting discrete increments:<br />
<br />
<math>W_{t+1} = Clip \left \{ W_t - \Delta W_t, -1 + \sigma(k_G), 1 - \sigma(k_G) \right \}</math><br />
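<br />
The gradient operator and the weight update can be sketched together as follows. This is an illustrative reading of the formulas above on flat lists of floats; the function names, the explicit random generator argument and the all-zero guard are assumptions.<br />

```python
import math
import random


def sigma(k: int) -> float:
    return 2.0 ** (1 - k)


def shift(x: float) -> float:
    return 2.0 ** math.floor(math.log2(x) + 0.5)


def quantize_gradient(grads, k_g: int, eta: float, rng: random.Random):
    """Q_G: rescale by eta / Shift(max |g|), then stochastically round."""
    peak = max(abs(g) for g in grads)
    if peak == 0.0:  # guard added here; log2(0) is undefined
        return [0.0 for _ in grads]
    scale = shift(peak)
    step = sigma(k_g)
    deltas = []
    for g in grads:
        g_s = eta * g / scale
        whole = math.floor(abs(g_s))
        frac = abs(g_s) - whole
        bump = 1 if rng.random() < frac else 0  # Bernoulli(frac)
        sign = (g_s > 0) - (g_s < 0)
        deltas.append(step * sign * (whole + bump))
    return deltas


def update_weights(weights, deltas, k_g: int):
    """W_{t+1} = Clip(W_t - dW, -1 + sigma(k_G), 1 - sigma(k_G))."""
    step = sigma(k_g)
    return [max(-1 + step, min(1 - step, w - d))
            for w, d in zip(weights, deltas)]
```

When a shifted gradient is already an integer the Bernoulli term is always zero, so the update is deterministic; otherwise the fractional part is kept only in expectation.<br />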
<br />
=== Miscellaneous ===<br />
To train WAGE networks, the authors used pure SGD exclusively because more complicated techniques such as Momentum or RMSProp increase memory consumption and are complicated by the rescaling that happens within each quantization operator.<br />
<br />
The quantization and stochastic rounding are a form of regularization.<br />
<br />
The authors didn't use a traditional softmax with cross-entropy loss for the experiments because there does not yet exist a softmax layer for low-bit integers. Instead, they use a sum of squared error loss. This works for tasks with a small number of categories, but does not scale well.<br />
<br />
== Experiments ==<br />
For all experiments, the default layer bitwidth configuration is 2-8-8-8 for weights, activations, gradients, and errors, respectively. The weight bitwidth is set to 2 because that results in ternary weights, and therefore no multiplication during inference. The authors argue that the bitwidth for activation and errors should be the same because the computation graph for each is similar and might use the same hardware. During training, the weight bitwidth is 8. For inference the weights are ternarized.<br />
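<br />
To illustrate why ternary weights remove multiplication at inference, the sketch below maps weights to integer codes in {-1, 0, 1} using the quantization function with <math>k_W = 2</math> and then evaluates a dot product with additions and subtractions only. The function names and the integer-code representation are assumptions made for this example.<br />

```python
import math


def quantize(x: float, k: int) -> float:
    """Q(x, k) with round-half-up as an assumed tie rule."""
    s = 2.0 ** (1 - k)
    nearest = s * math.floor(x / s + 0.5)
    return max(-1 + s, min(1 - s, nearest))


def ternary_code(w: float) -> int:
    """Integer code of Q(w, 2) in units of sigma(2) = 0.5: one of -1, 0, 1."""
    return int(quantize(w, 2) / 0.5)


def ternary_dot(codes, activations):
    """Dot product with ternary codes: add, subtract, or skip each term."""
    total = 0
    for c, a in zip(codes, activations):
        if c == 1:
            total += a
        elif c == -1:
            total -= a
    return total


codes = [ternary_code(w) for w in (0.8, 0.1, -0.3, -0.2)]
result = ternary_dot(codes, [3, 5, 2, 4])
```

The true output differs from <code>result</code> only by the constant factor <math>\sigma(2) = 0.5</math>, which is a power of two and can be applied as a bit shift.<br />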
<br />
=== Implementation Details ===<br />
MNIST: Network is LeNet-5 variant [6] with 32C5-MP2-64C5-MP2-512FC-10SSE.<br />
<br />
SVHN & CIFAR10: VGG variant [7] with 2×(128C3)-MP2-2×(256C3)-MP2-2×(512C3)-MP2-1024FC-10SSE. For CIFAR10, training follows the data augmentation of Lee et al. (2015) [10].<br />
<br />
ImageNet: AlexNet variant [8] on ILSVRC12 dataset.<br />
{| class="wikitable"<br />
|+Test or validation error rates (%) in previous works and WAGE on multiple datasets. Opt denotes gradient descent optimizer, withM means SGD with momentum, BN represents batch normalization, 32 bit refers to float32, and ImageNet top-k format: top1/top5.<br />
!Method<br />
!<math>k_W</math><br />
!<math>k_A</math><br />
!<math>k_G</math><br />
!<math>k_E</math><br />
!Opt<br />
!BN<br />
!MNIST<br />
!SVHN<br />
!CIFAR10<br />
!ImageNet<br />
|-<br />
|BC<br />
|1<br />
|32<br />
|32<br />
|32<br />
|Adam<br />
|yes<br />
|1.29<br />
|2.30<br />
|9.90<br />
|<br />
|-<br />
|BNN<br />
|1<br />
|1<br />
|32<br />
|32<br />
|Adam<br />
|yes <br />
|0.96<br />
|2.53<br />
|10.15<br />
|<br />
|-<br />
|BWN<br />
|1<br />
|32<br />
|32<br />
|32<br />
|withM<br />
|yes<br />
|<br />
|<br />
|<br />
|43.2/20.6<br />
|-<br />
|XNOR<br />
|1<br />
|1<br />
|32<br />
|32<br />
|Adam<br />
|yes<br />
|<br />
|<br />
|<br />
|55.8/30.8<br />
|-<br />
|TWN<br />
|2<br />
|32<br />
|32<br />
|32<br />
|withM<br />
|yes<br />
|0.65<br />
|<br />
|7.44<br />
|'''34.7/13.8'''<br />
|-<br />
|TTQ<br />
|2<br />
|32<br />
|32<br />
|32<br />
|Adam<br />
|yes<br />
|<br />
|<br />
|6.44<br />
|42.5/20.3<br />
|-<br />
|DoReFa<br />
|8<br />
|8<br />
|32<br />
|8<br />
|Adam<br />
|yes<br />
|<br />
|2.30<br />
|<br />
|47.0/<br />
|-<br />
|TernGrad<br />
|32<br />
|32<br />
|2<br />
|32<br />
|Adam<br />
|yes<br />
|<br />
|<br />
|14.36<br />
|42.4/19.5<br />
|-<br />
|WAGE<br />
|2<br />
|8<br />
|8<br />
|8<br />
|SGD<br />
|no<br />
|'''0.40'''<br />
|'''1.92'''<br />
|'''6.78'''<br />
|51.6/27.8<br />
|}<br />
<br />
=== Training Curves and Regularization ===<br />
The authors compare the 2-8-8-8 WAGE configuration introduced above, a 2-8-f-f (meaning float32) configuration, and a completely floating point version on CIFAR10. The test error is plotted against epoch. For training these networks, the learning rate is divided by 8 at the 200th epoch and again at the 250th epoch.<br />
[[File:p32fig3.PNG|center|thumb|800px|Training curves of WAGE variations and a vanilla CNN on CIFAR10]]<br />
The 2-8-8-8 configuration converges comparably to the vanilla CNN and outperforms the 2-8-f-f variant. The authors speculate that this is because the extra discretization acts as a regularizer.<br />
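<br />
The step schedule described in this subsection can be written as a small function. The starting value passed as <code>base_lr</code> is a placeholder, and treating "at the 200th epoch" as <code>epoch >= 200</code> is an assumption.<br />

```python
def learning_rate(epoch: int, base_lr: float, drops=(200, 250), factor: float = 8.0) -> float:
    """Divide the learning rate by `factor` at each epoch listed in `drops`."""
    lr = base_lr
    for drop_epoch in drops:
        if epoch >= drop_epoch:
            lr /= factor
    return lr
```

Because the factor 8 is a power of two, each drop corresponds to a 3-bit right shift, which keeps the learning rate a power of two as the shift-based scheme requires.<br />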
<br />
=== Bitwidth of Errors ===<br />
The CIFAR10 test accuracy is plotted against bitwidth below and the error density for a single layer is compared with the Vanilla network.<br />
[[File:p32fig4.PNG|center|thumb|520x522px|The 10 run accuracies of different <math>k_E</math>]]<br />
<br />
[[File:32_error.png|center|thumb|520x522px|Histogram of errors for Vanilla network and WAGE network. After being quantized and shifted each layer, the error is reshaped and so most orientation information is retained. ]]<br />
<br />
The table below shows the test error rates on CIFAR10 when the upper boundary is left-shifted by a factor γ. The table suggests that large error values play a critical role in backpropagation training even though there are few of them, while the majority of small values are mostly noise.<br />
<br />
[[File:testerror_rate.png|center]]<br />
<br />
=== Bitwidth of Gradients ===<br />
<br />
The authors next investigated the choice of a proper <math>k_G</math> for gradients using the CIFAR10 dataset. <br />
<br />
{| class="wikitable"<br />
|+Test error rates (%) on CIFAR10 with different <math>k_G</math><br />
!<math>k_G</math><br />
!2<br />
!3<br />
!4<br />
!5<br />
!6<br />
!7<br />
!8<br />
!9<br />
!10<br />
!11<br />
!12<br />
|-<br />
|error<br />
|54.22<br />
|51.57<br />
|28.22<br />
|18.01<br />
|11.48<br />
|7.61<br />
|6.78<br />
|6.63<br />
|6.43<br />
|6.55<br />
|6.57<br />
|}<br />
<br />
The results show bitwidth requirements similar to those of the last experiment for <math>k_E</math>.<br />
<br />
The authors also examined the effect of bitwidth on the ImageNet implementation.<br />
<br />
Here, C denotes 12 bits (the hexadecimal digit) and BN refers to batch normalization being added. Seven models are compared: the vanilla baseline, which is the original AlexNet architecture; 2888 from the first experiment; 288C for more accurate errors (12 bits); 28C8 for larger buffer space; 28f8 for non-quantization of gradients; 28ff for errors and gradients in float32; and 28ff with BN added.<br />
<br />
{| class="wikitable"<br />
|+Top-5 error rates (%) on ImageNet with different <math>k_G</math> and <math>k_E</math><br />
!Pattern<br />
!vanilla<br />
!28ff-BN<br />
!28ff<br />
!28f8<br />
!28C8<br />
!288C<br />
!2888<br />
|-<br />
|error<br />
|19.29<br />
|20.67<br />
|24.14<br />
|23.92<br />
|26.88<br />
|28.06<br />
|27.82<br />
|}<br />
<br />
The comparison between 28C8 and 288C shows that the model may perform better if it has more buffer space <math>k_G</math> for gradient accumulation than if it has high-resolution orientation <math>k_E</math>. The authors also noted that batch normalization and <math>k_G</math> are more important for ImageNet because the training set samples are highly variant.<br />
<br />
== Discussion ==<br />
The authors have a few areas they believe this approach could be improved.<br />
<br />
'''MAC Operation:''' The 2-8-8-8 configuration was chosen because the low weight bitwidth means there is no multiplication during inference. However, this does not remove the requirement for multiplication during training. A 2-2-8-8 configuration would remove that requirement, but it is difficult to train and detrimental to accuracy.<br />
<br />
'''Non-linear Quantization:''' The linear mapping used in this approach is simple, but there might be a more effective mapping. For example, a logarithmic mapping could be more effective if the weights and activations have a log-normal distribution.<br />
<br />
'''Normalization:''' Normalization layers (softmax, batch normalization) were not used in this paper. Quantized versions are an area of future work.<br />
<br />
== Conclusion ==<br />
<br />
A framework for training and inference without the use of floating-point representation is presented. Future work may further improve compression and memory requirements.<br />
== References ==<br />
<br />
# Sze, Vivienne; Chen, Yu-Hsin; Yang, Tien-Ju; Emer, Joel (2017-03-27). [http://arxiv.org/abs/1703.09039 "Efficient Processing of Deep Neural Networks: A Tutorial and Survey"]. arXiv:1703.09039 [cs].<br />
# Courbariaux, Matthieu; Bengio, Yoshua; David, Jean-Pierre (2015-11-01). [http://arxiv.org/abs/1511.00363 "BinaryConnect: Training Deep Neural Networks with binary weights during propagations"]. arXiv:1511.00363 [cs].<br />
# Li, Fengfu; Zhang, Bo; Liu, Bin (2016-05-16). [http://arxiv.org/abs/1605.04711 "Ternary Weight Networks"]. arXiv:1605.04711 [cs].<br />
# He, Kaiming; Zhang, Xiangyu; Ren, Shaoqing; Sun, Jian (2015-02-06). [http://arxiv.org/abs/1502.01852 "Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification"]. arXiv:1502.01852 [cs].<br />
# Zhou, Shuchang; Wu, Yuxin; Ni, Zekun; Zhou, Xinyu; Wen, He; Zou, Yuheng (2016-06-20). [http://arxiv.org/abs/1606.06160 "DoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low Bitwidth Gradients"]. arXiv:1606.06160 [cs].<br />
# Lecun, Y.; Bottou, L.; Bengio, Y.; Haffner, P. (November 1998). [http://ieeexplore.ieee.org/document/726791/?reload=true "Gradient-based learning applied to document recognition"]. Proceedings of the IEEE. 86 (11): 2278–2324. doi:10.1109/5.726791. ISSN 0018-9219.<br />
# Simonyan, Karen; Zisserman, Andrew (2014-09-04). [http://arxiv.org/abs/1409.1556 "Very Deep Convolutional Networks for Large-Scale Image Recognition"]. arXiv:1409.1556 [cs].<br />
# Krizhevsky, Alex; Sutskever, Ilya; Hinton, Geoffrey E (2012). Pereira, F.; Burges, C. J. C.; Bottou, L.; Weinberger, K. Q., eds. [http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf Advances in Neural Information Processing Systems 25 (PDF)]. Curran Associates, Inc. pp. 1097–1105.<br />
# Chenzhuo Zhu, Song Han, Huizi Mao, and William J Dally. Trained ternary quantization. arXiv preprint arXiv:1612.01064, 2016.<br />
# Chen-Yu Lee, Saining Xie, Patrick Gallagher, Zhengyou Zhang, and Zhuowen Tu. Deeply-supervised nets. In Artificial Intelligence and Statistics, pp. 562–570, 2015.<br />
# Mohammad Rastegari, Vicente Ordonez, Joseph Redmon, and Ali Farhadi. Xnor-net: Imagenet classification using binary convolutional neural networks. In European Conference on Computer Vision, pp. 525–542. Springer, 2016.</div>X249wanghttp://wiki.math.uwaterloo.ca/statwiki/index.php?title=Training_And_Inference_with_Integers_in_Deep_Neural_Networks&diff=35876Training And Inference with Integers in Deep Neural Networks2018-03-28T21:32:15Z<p>X249wang: /* Shift-Based Linear Mapping and Stochastic Mapping */</p>
<hr />
<div>== Introduction ==<br />
<br />
Deep neural networks have enjoyed much success in all manner of tasks, but it is common for these networks to be complicated, requiring large amounts of energy-intensive memory and floating-point operations. Therefore, in order to use state-of-the-art networks in applications where energy is limited or where hardware packaging is constrained, such as anything not connected to the power grid, the energy costs must be reduced while preserving as much performance as practical.<br />
<br />
Most existing methods focus on compressing a model to reduce the energy requirements of inference rather than training. Since training with SGD requires accumulation, training usually has a higher precision demand than inference. This paper proposes a framework to reduce complexity during both training and inference through the use of integers instead of floats. The authors address how to quantize all operations and operands, and examine the bitwidth requirement for SGD computation and accumulation. Using integers instead of floats saves energy because integer operations are more efficient than floating-point ones (see the table below). Also, there already exists dedicated hardware for deep learning that uses integer operations (such as the 1st generation of Google TPU), so understanding the best way to use integers is well-motivated.<br />
{| class="wikitable"<br />
|+Rough Energy Costs in 45nm 0.9V [1]<br />
!<br />
! colspan="2" |Energy(pJ)<br />
! colspan="2" |Area(<math>\mu m^2</math>)<br />
|-<br />
!Operation<br />
!MUL<br />
!ADD<br />
!MUL<br />
!ADD<br />
|-<br />
|8-bit INT<br />
|0.2<br />
|0.03<br />
|282<br />
|36<br />
|-<br />
|16-bit FP<br />
|1.1<br />
|0.4<br />
|1640<br />
|1360<br />
|-<br />
|32-bit FP<br />
|3.7<br />
|0.9<br />
|7700<br />
|4184<br />
|}<br />
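<br />
The savings implied by the table can be made concrete with a short calculation. The sketch below only re-uses the energy figures listed above; the dictionary layout and function names are illustrative.<br />

```python
# Energy per operation in pJ, copied from the table above (45nm, 0.9V).
ENERGY_PJ = {
    "int8": {"mul": 0.2, "add": 0.03},
    "fp16": {"mul": 1.1, "add": 0.4},
    "fp32": {"mul": 3.7, "add": 0.9},
}


def energy_ratio(op: str, expensive: str = "fp32", cheap: str = "int8") -> float:
    """How many times more energy `expensive` needs than `cheap` for one op."""
    return ENERGY_PJ[expensive][op] / ENERGY_PJ[cheap][op]


def mac_energy(fmt: str) -> float:
    """Energy of one multiply-accumulate, taken as one MUL plus one ADD."""
    return ENERGY_PJ[fmt]["mul"] + ENERGY_PJ[fmt]["add"]


mul_ratio = energy_ratio("mul")                        # about 18.5
add_ratio = energy_ratio("add")                        # about 30
mac_ratio = mac_energy("fp32") / mac_energy("int8")    # about 20
```

By these rough figures, a 32-bit floating-point multiply-accumulate costs about 20 times the energy of an 8-bit integer one.<br />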
The authors call the framework WAGE because they consider how best to handle the '''W'''eights, '''A'''ctivations, '''G'''radients, and '''E'''rrors separately.<br />
<br />
== Related Work ==<br />
<br />
=== Weight and Activation ===<br />
Existing works to train DNNs on binary weights and activations [2] add noise to weights and activations as a form of regularization. The use of high-precision accumulation is required for SGD optimization since real-valued gradients are obtained from real-valued variables. XNOR-Net [11] uses bitwise operations to approximate convolutions in a highly memory-efficient manner, and applies a filter-wise scaling factor for weights to improve performance. However, these floating-point factors are calculated simultaneously during training, which aggravates the training effort. Ternary weight networks (TWN) [3] and Trained ternary quantization (TTQ) [9] offer more expressive ability than binary weight networks by constraining the weights to be ternary-valued {-1,0,1} using two symmetric thresholds.<br />
<br />
=== Gradient Computation and Accumulation ===<br />
The DoReFa-Net quantizes gradients to low-bitwidth floating point numbers with discrete states in the backward pass. In order to reduce the overhead of gradient synchronization in distributed training, the TernGrad method quantizes the gradient updates to ternary values. In both works the weights are still stored and updated with float32, and the quantization of batch normalization and its derivative is ignored.<br />
<br />
== WAGE Quantization ==<br />
The core idea of the proposed method is to constrain the following to low-bitwidth integers on each layer:<br />
* '''W:''' weight in inference<br />
* '''a:''' activation in inference<br />
* '''e:''' error in backpropagation<br />
* '''g:''' gradient in backpropagation<br />
[[File:p32fig1.PNG|center|thumb|800px|Four operators QW (·), QA(·), QG(·), QE(·) added in WAGE computation dataflow to reduce precision, bitwidth of signed integers are below or on the right of arrows, activations are included in MAC for concision.]]<br />
The error and gradient are defined as:<br />
<br />
<math>e^i = \frac{\partial L}{\partial a^i}, g^i = \frac{\partial L}{\partial W^i}</math><br />
<br />
where L is the loss function.<br />
<br />
The precisions in bits of the errors, activations, gradients, and weights are <math>k_E</math>, <math>k_A</math>, <math>k_G</math>, and <math>k_W</math> respectively. As shown in the above figure, each quantity also has a quantization operator to reduce the bitwidth increases caused by multiply-accumulate (MAC) operations. Also, note that since this is a layer-by-layer approach, each layer may be followed or preceded by a layer with different precision, or even a layer using floating point math.<br />
<br />
=== Shift-Based Linear Mapping and Stochastic Mapping ===<br />
The proposed method makes use of a linear mapping where continuous, unbounded values are discretized for each bitwidth <math>k</math> with a uniform spacing of<br />
<br />
<math>\sigma(k) = 2^{1-k}, k \in Z_+ </math><br />
With this, the full quantization function is<br />
<br />
<math>Q(x,k) = Clip\left \{ \sigma(k) \cdot round\left [ \frac{x}{\sigma(k)} \right ], -1 + \sigma(k), 1 - \sigma(k) \right \}</math>, <br />
<br />
where <math>round</math> approximates continuous values to their nearest discrete state, and <math>Clip</math> is the saturation function that clips unbounded values to <math>[-1 + \sigma, 1 - \sigma]</math>. Note that this function is only used when simulating integer operations on floating-point hardware; on native integer hardware, this is done automatically. In addition to this quantization function, a distribution scaling factor is used in some quantization operators to preserve as much variance as possible when applying the quantization function above. The scaling factor is defined below.<br />
<br />
<math>Shift(x) = 2^{round(log_2(x))}</math><br />
<br />
Finally, stochastic rounding is substituted for small or real-valued updates during gradient accumulation.<br />
<br />
A visual representation of these operations is below.<br />
[[File:p32fig2.PNG|center|thumb|800px|Quantization methods used in WAGE. The notation <math>P, x, \lfloor \cdot \rfloor, \lceil \cdot \rceil</math> denotes probability, vector, floor and ceil, respectively. <math>Shift(\cdot)</math> refers to distribution shifting with a certain argument]]<br />
<br />
=== Weight Initialization ===<br />
In this work, batch normalization is simplified to a constant scaling layer in order to sidestep the problem of normalizing outputs without floating point math, and to remove the extra memory requirement of batch normalization. As such, some care must be taken when initializing weights. The authors use a modified initialization method based on MSRA [4].<br />
<br />
<math>W \thicksim U(-L, +L),L = max \left \{ \sqrt{6/n_{in}}, L_{min} \right \}, L_{min} = \beta \sigma</math><br />
<br />
<math>n_{in}</math> is the layer fan-in number, and <math>U</math> denotes the uniform distribution. The original initialization method is modified by adding the condition that the distribution limit <math>L</math> should be at least <math>\beta \sigma</math>, where <math>\beta</math> is a constant greater than 1 and <math>\sigma</math> is the minimum step size defined above. This prevents the weights from all being initialised to zero when the bitwidth is low or the fan-in number is high.<br />
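<br />
As a worked example under assumed values (<math>\beta = 1.5</math> and a fan-in of <math>n_{in} = 1152</math>, neither of which is taken from the text above), with <math>k_W = 2</math>:<br />
<br />
<math>\sigma(2) = 2^{-1} = 0.5, \quad L_{min} = 1.5 \cdot 0.5 = 0.75, \quad \sqrt{6/1152} \approx 0.072, \quad L = max \left \{ 0.072, 0.75 \right \} = 0.75</math><br />
<br />
Here the plain fan-in limit is far smaller than half a step <math>\sigma</math>, so without the <math>L_{min}</math> floor every weight would round to zero.<br />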
<br />
=== Quantization Details ===<br />
<br />
==== Weight <math>Q_W(\cdot)</math> ====<br />
<math>W_q = Q_W(W) = Q(W, k_W)</math><br />
<br />
The quantization operator is simply the quantization function previously introduced. <br />
<br />
==== Activation <math>Q_A(\cdot)</math> ====<br />
The authors say that the variance of the weights passed through this function will be scaled compared to the variance of the weights as initialized. To prevent this effect from blowing up the network outputs, they introduce a scaling factor <math>\alpha</math>. Notice that it is constant for each layer.<br />
<br />
<math>\alpha = max \left \{ Shift(L_{min} / L), 1 \right \}</math><br />
<br />
The quantization operator is then<br />
<br />
<math>a_q = Q_A(a) = Q(a/\alpha, k_A)</math><br />
<br />
The scaling factor approximates batch normalization.<br />
<br />
==== Error <math>Q_E(\cdot)</math> ====<br />
The authors note that the magnitude of the error can vary greatly, and that a previous approach (DoReFa-Net [5]) solves the issue by using an affine transform to map the error to the range <math>[-1, 1]</math>, applying quantization, and then applying the inverse transform. However, the authors claim that this approach still requires float32, and that the magnitude of the error is unimportant; what matters is its orientation. Thus, they only scale the error distribution to the range <math>\left [ -\sqrt2, \sqrt2 \right ]</math> and quantise:<br />
<br />
<math>e_q = Q_E(e) = Q(e/Shift(max\{|e|\}), k_E)</math><br />
<br />
Here <math>max\{|e|\}</math> is the largest absolute value among all elements of the layer's error. Note that this discards any error elements that fall below the minimum step size after scaling.<br />
<br />
==== Gradient <math>Q_G(\cdot)</math> ====<br />
Similar to the activations and errors, the gradients are rescaled:<br />
<br />
<math>g_s = \eta \cdot g/Shift(max\{|g|\})</math><br />
<br />
<math> \eta </math> is a shift-based learning rate. It is an integer power of 2. The shifted gradients are represented in units of minimum step sizes <math> \sigma(k) </math>. When reducing the bitwidth of the gradients (remember that the gradients are coming out of a MAC operation, so the bitwidth may have increased) stochastic rounding is used as a substitute for small gradient accumulation.<br />
<br />
<math>\Delta W = Q_G(g) = \sigma(k_G) \cdot sgn(g_s) \cdot \left \{ \lfloor | g_s | \rfloor + Bernoulli(|g_s|<br />
- \lfloor | g_s | \rfloor) \right \}</math><br />
<br />
This randomly rounds the result of the MAC operation up or down to the nearest quantization level for the given gradient bitwidth. The weights are updated with the resulting discrete increments:<br />
<br />
<math>W_{t+1} = Clip \left \{ W_t - \Delta W_t, -1 + \sigma(k_G), 1 - \sigma(k_G) \right \}</math><br />
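<br />
As a worked example of the stochastic rounding step, take a hypothetical shifted gradient <math>g_s = 2.3</math>:<br />
<br />
<math>\lfloor 2.3 \rfloor = 2, \quad P(\Delta W = 3\sigma(k_G)) = 0.3, \quad P(\Delta W = 2\sigma(k_G)) = 0.7, \quad E[\Delta W] = (0.3 \cdot 3 + 0.7 \cdot 2) \cdot \sigma(k_G) = 2.3 \cdot \sigma(k_G)</math><br />
<br />
The rounding is therefore unbiased: the fractional part that a single update cannot represent is preserved on average over many updates.<br />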
<br />
=== Miscellaneous ===<br />
To train WAGE networks, the authors used pure SGD exclusively because more complicated techniques such as Momentum or RMSProp increase memory consumption and are complicated by the rescaling that happens within each quantization operator.<br />
<br />
The quantization and stochastic rounding are a form of regularization.<br />
<br />
The authors didn't use a traditional softmax with cross-entropy loss for the experiments because there does not yet exist a softmax layer for low-bit integers. Instead, they use a sum of squared error loss. This works for tasks with a small number of categories, but does not scale well.<br />
<br />
== Experiments ==<br />
For all experiments, the default layer bitwidth configuration is 2-8-8-8 for weights, activations, gradients, and errors, respectively. The weight bitwidth is set to 2 because that results in ternary weights, and therefore no multiplication during inference. The authors argue that the bitwidth for activation and errors should be the same because the computation graph for each is similar and might use the same hardware. During training, the weight bitwidth is 8. For inference the weights are ternarized.<br />
<br />
=== Implementation Details ===<br />
MNIST: Network is LeNet-5 variant [6] with 32C5-MP2-64C5-MP2-512FC-10SSE.<br />
<br />
SVHN & CIFAR10: VGG variant [7] with 2×(128C3)-MP2-2×(256C3)-MP2-2×(512C3)-MP2-1024FC-10SSE. For CIFAR10, training follows the data augmentation of Lee et al. (2015) [10].<br />
<br />
ImageNet: AlexNet variant [8] on ILSVRC12 dataset.<br />
{| class="wikitable"<br />
|+Test or validation error rates (%) in previous works and WAGE on multiple datasets. Opt denotes gradient descent optimizer, withM means SGD with momentum, BN represents batch normalization, 32 bit refers to float32, and ImageNet top-k format: top1/top5.<br />
!Method<br />
!<math>k_W</math><br />
!<math>k_A</math><br />
!<math>k_G</math><br />
!<math>k_E</math><br />
!Opt<br />
!BN<br />
!MNIST<br />
!SVHN<br />
!CIFAR10<br />
!ImageNet<br />
|-<br />
|BC<br />
|1<br />
|32<br />
|32<br />
|32<br />
|Adam<br />
|yes<br />
|1.29<br />
|2.30<br />
|9.90<br />
|<br />
|-<br />
|BNN<br />
|1<br />
|1<br />
|32<br />
|32<br />
|Adam<br />
|yes <br />
|0.96<br />
|2.53<br />
|10.15<br />
|<br />
|-<br />
|BWN<br />
|1<br />
|32<br />
|32<br />
|32<br />
|withM<br />
|yes<br />
|<br />
|<br />
|<br />
|43.2/20.6<br />
|-<br />
|XNOR<br />
|1<br />
|1<br />
|32<br />
|32<br />
|Adam<br />
|yes<br />
|<br />
|<br />
|<br />
|55.8/30.8<br />
|-<br />
|TWN<br />
|2<br />
|32<br />
|32<br />
|32<br />
|withM<br />
|yes<br />
|0.65<br />
|<br />
|7.44<br />
|'''34.7/13.8'''<br />
|-<br />
|TTQ<br />
|2<br />
|32<br />
|32<br />
|32<br />
|Adam<br />
|yes<br />
|<br />
|<br />
|6.44<br />
|42.5/20.3<br />
|-<br />
|DoReFa<br />
|8<br />
|8<br />
|32<br />
|8<br />
|Adam<br />
|yes<br />
|<br />
|2.30<br />
|<br />
|47.0/<br />
|-<br />
|TernGrad<br />
|32<br />
|32<br />
|2<br />
|32<br />
|Adam<br />
|yes<br />
|<br />
|<br />
|14.36<br />
|42.4/19.5<br />
|-<br />
|WAGE<br />
|2<br />
|8<br />
|8<br />
|8<br />
|SGD<br />
|no<br />
|'''0.40'''<br />
|'''1.92'''<br />
|'''6.78'''<br />
|51.6/27.8<br />
|}<br />
<br />
=== Training Curves and Regularization ===<br />
The authors compare the 2-8-8-8 WAGE configuration introduced above, a 2-8-f-f (meaning float32) configuration, and a completely floating point version on CIFAR10. The test error is plotted against epoch. For training these networks, the learning rate is divided by 8 at the 200th epoch and again at the 250th epoch.<br />
[[File:p32fig3.PNG|center|thumb|800px|Training curves of WAGE variations and a vanilla CNN on CIFAR10]]<br />
The 2-8-8-8 configuration converges comparably to the vanilla CNN and outperforms the 2-8-f-f variant. The authors speculate that this is because the extra discretization acts as a regularizer.<br />
<br />
=== Bitwidth of Errors ===<br />
The CIFAR10 test accuracy is plotted against bitwidth below and the error density for a single layer is compared with the Vanilla network.<br />
[[File:p32fig4.PNG|center|thumb|520x522px|The 10 run accuracies of different <math>k_E</math>]]<br />
<br />
[[File:32_error.png|center|thumb|520x522px|Histogram of errors for Vanilla network and WAGE network. After being quantized and shifted each layer, the error is reshaped and so most orientation information is retained. ]]<br />
<br />
The table below shows the test error rates on CIFAR10 when the upper boundary is left-shifted by a factor γ. The table suggests that large error values play a critical role in backpropagation training even though there are few of them, while the majority of small values are mostly noise.<br />
<br />
[[File:testerror_rate.png|center]]<br />
<br />
=== Bitwidth of Gradients ===<br />
{| class="wikitable"<br />
|+Test error rates (%) on CIFAR10 with different <math>k_G</math><br />
!<math>k_G</math><br />
!2<br />
!3<br />
!4<br />
!5<br />
!6<br />
!7<br />
!8<br />
!9<br />
!10<br />
!11<br />
!12<br />
|-<br />
|error<br />
|54.22<br />
|51.57<br />
|28.22<br />
|18.01<br />
|11.48<br />
|7.61<br />
|6.78<br />
|6.63<br />
|6.43<br />
|6.55<br />
|6.57<br />
|}<br />
The authors also examined the effect of bitwidth on the ImageNet implementation.<br />
<br />
{| class="wikitable"<br />
|+Top-5 error rates (%) on ImageNet with different <math>k_G</math> and <math>k_E</math><br />
!Pattern<br />
!vanilla<br />
!28ff-BN<br />
!28ff<br />
!28f8<br />
!28C8<br />
!288C<br />
!2888<br />
|-<br />
|error<br />
|19.29<br />
|20.67<br />
|24.14<br />
|23.92<br />
|26.88<br />
|28.06<br />
|27.82<br />
|}<br />
Here, C denotes 12 bits (the hexadecimal digit) and BN refers to batch normalization being added.<br />
<br />
== Discussion ==<br />
The authors have a few areas they believe this approach could be improved.<br />
<br />
'''MAC Operation:''' The 2-8-8-8 configuration was chosen because the low weight bitwidth means there is no multiplication during inference. However, this does not remove the requirement for multiplication during training. A 2-2-8-8 configuration would remove that requirement, but it is difficult to train and detrimental to accuracy.<br />
<br />
'''Non-linear Quantization:''' The linear mapping used in this approach is simple, but there might be a more effective mapping. For example, a logarithmic mapping could be more effective if the weights and activations have a log-normal distribution.<br />
<br />
'''Normalization:''' Normalization layers (softmax, batch normalization) were not used in this paper. Quantized versions are an area of future work.<br />
<br />
== Conclusion ==<br />
<br />
A framework for training and inference without the use of floating-point representation is presented. Future work may further improve compression and memory requirements.<br />
== References ==<br />
<br />
# Sze, Vivienne; Chen, Yu-Hsin; Yang, Tien-Ju; Emer, Joel (2017-03-27). [http://arxiv.org/abs/1703.09039 "Efficient Processing of Deep Neural Networks: A Tutorial and Survey"]. arXiv:1703.09039 [cs].<br />
# Courbariaux, Matthieu; Bengio, Yoshua; David, Jean-Pierre (2015-11-01). [http://arxiv.org/abs/1511.00363 "BinaryConnect: Training Deep Neural Networks with binary weights during propagations"]. arXiv:1511.00363 [cs].<br />
# Li, Fengfu; Zhang, Bo; Liu, Bin (2016-05-16). [http://arxiv.org/abs/1605.04711 "Ternary Weight Networks"]. arXiv:1605.04711 [cs].<br />
# He, Kaiming; Zhang, Xiangyu; Ren, Shaoqing; Sun, Jian (2015-02-06). [http://arxiv.org/abs/1502.01852 "Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification"]. arXiv:1502.01852 [cs].<br />
# Zhou, Shuchang; Wu, Yuxin; Ni, Zekun; Zhou, Xinyu; Wen, He; Zou, Yuheng (2016-06-20). [http://arxiv.org/abs/1606.06160 "DoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low Bitwidth Gradients"]. arXiv:1606.06160 [cs].<br />
# Lecun, Y.; Bottou, L.; Bengio, Y.; Haffner, P. (November 1998). [http://ieeexplore.ieee.org/document/726791/?reload=true "Gradient-based learning applied to document recognition"]. Proceedings of the IEEE. 86 (11): 2278–2324. doi:10.1109/5.726791. ISSN 0018-9219.<br />
# Simonyan, Karen; Zisserman, Andrew (2014-09-04). [http://arxiv.org/abs/1409.1556 "Very Deep Convolutional Networks for Large-Scale Image Recognition"]. arXiv:1409.1556 [cs].<br />
# Krizhevsky, Alex; Sutskever, Ilya; Hinton, Geoffrey E (2012). Pereira, F.; Burges, C. J. C.; Bottou, L.; Weinberger, K. Q., eds. [http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf Advances in Neural Information Processing Systems 25 (PDF)]. Curran Associates, Inc. pp. 1097–1105.<br />
# Chenzhuo Zhu, Song Han, Huizi Mao, and William J Dally. Trained ternary quantization. arXiv preprint arXiv:1612.01064, 2016.<br />
# Chen-Yu Lee, Saining Xie, Patrick Gallagher, Zhengyou Zhang, and Zhuowen Tu. Deeply-supervised nets. In Artificial Intelligence and Statistics, pp. 562–570, 2015.<br />
# Mohammad Rastegari, Vicente Ordonez, Joseph Redmon, and Ali Farhadi. Xnor-net: Imagenet classification using binary convolutional neural networks. In European Conference on Computer Vision, pp. 525–542. Springer, 2016.</div>X249wanghttp://wiki.math.uwaterloo.ca/statwiki/index.php?title=Training_And_Inference_with_Integers_in_Deep_Neural_Networks&diff=35875Training And Inference with Integers in Deep Neural Networks2018-03-28T21:21:56Z<p>X249wang: /* References */</p>
<hr />
<div>== Introduction ==<br />
<br />
Deep neural networks have enjoyed much success in all manner of tasks, but these networks are commonly complicated, requiring large amounts of energy-intensive memory and floating-point operations. Therefore, in order to use state-of-the-art networks in applications where energy is limited or hardware packaging is constrained, such as anything not connected to the power grid, the energy costs must be reduced while preserving as much performance as practical.<br />
<br />
Most existing methods focus on reducing the energy requirements during inference rather than training, that is, on compressing a model for inference. Since training with SGD requires accumulation, training usually demands higher precision than inference. This paper proposes a framework to reduce complexity during both training and inference through the use of integers instead of floats. The authors address how to quantize all operations and operands, and examine the bitwidth requirement for SGD computation and accumulation. Using integers instead of floats saves energy because integer operations are more efficient than floating-point operations (see the table below). Also, dedicated hardware for deep learning that uses integer operations already exists (such as the 1st generation of Google TPU), so understanding the best way to use integers is well-motivated.<br />
{| class="wikitable"<br />
|+Rough Energy Costs in 45nm 0.9V [1]<br />
!<br />
! colspan="2" |Energy(pJ)<br />
! colspan="2" |Area(<math>\mu m^2</math>)<br />
|-<br />
!Operation<br />
!MUL<br />
!ADD<br />
!MUL<br />
!ADD<br />
|-<br />
|8-bit INT<br />
|0.2<br />
|0.03<br />
|282<br />
|36<br />
|-<br />
|16-bit FP<br />
|1.1<br />
|0.4<br />
|1640<br />
|1360<br />
|-<br />
|32-bit FP<br />
|3.7<br />
|0.9<br />
|7700<br />
|4184<br />
|}<br />
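As a rough worked example of the savings implied by the table, the ratios below are computed directly from its entries. The dictionary layout and the function name <code>energy_ratio</code> are illustrative only and do not come from the paper.<br />

```python
# Energy per operation in pJ, copied from the table above (45nm, 0.9V).
ENERGY_PJ = {
    "8-bit INT": {"MUL": 0.2, "ADD": 0.03},
    "16-bit FP": {"MUL": 1.1, "ADD": 0.4},
    "32-bit FP": {"MUL": 3.7, "ADD": 0.9},
}


def energy_ratio(expensive: str, cheap: str, op: str) -> float:
    """Return how many times more energy `expensive` needs than `cheap` for one `op`."""
    return ENERGY_PJ[expensive][op] / ENERGY_PJ[cheap][op]


if __name__ == "__main__":
    for op in ("MUL", "ADD"):
        ratio = energy_ratio("32-bit FP", "8-bit INT", op)
        print(f"32-bit FP vs 8-bit INT {op}: about {ratio:.1f}x")
```

On these figures, a 32-bit floating-point multiply costs roughly 18.5 times the energy of an 8-bit integer multiply, and an add roughly 30 times.<br />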
The authors call the framework WAGE because they consider how best to handle the '''W'''eights, '''A'''ctivations, '''G'''radients, and '''E'''rrors separately.<br />
<br />
== Related Work ==<br />
<br />
=== Weight and Activation ===<br />
Existing works to train DNNs on binary weights and activations [2] add noise to weights and activations as a form of regularization. The use of high-precision accumulation is required for SGD optimization since real-valued gradients are obtained from real-valued variables. XNOR-Net [11] uses bitwise operations to approximate convolutions in a highly memory-efficient manner, and applies a filter-wise scaling factor for weights to improve performance. However, these floating-point factors are calculated simultaneously during training, which aggravates the training effort. Ternary weight networks (TWN) [3] and Trained ternary quantization (TTQ) [9] offer more expressive ability than binary weight networks by constraining the weights to be ternary-valued {-1,0,1} using two symmetric thresholds.<br />
<br />
=== Gradient Computation and Accumulation ===<br />
DoReFa-Net [5] quantizes gradients to low-bitwidth floating-point numbers with discrete states in the backward pass. To reduce the overhead of gradient synchronization in distributed training, the TernGrad method quantizes the gradient updates to ternary values. In both works the weights are still stored and updated with float32, and the quantization of batch normalization and its derivative is ignored.<br />
<br />
== WAGE Quantization ==<br />
The core idea of the proposed method is to constrain the following to low-bitwidth integers on each layer:<br />
* '''W:''' weight in inference<br />
* '''a:''' activation in inference<br />
* '''e:''' error in backpropagation<br />
* '''g:''' gradient in backpropagation<br />
[[File:p32fig1.PNG|center|thumb|800px|Four operators QW (·), QA(·), QG(·), QE(·) added in WAGE computation dataflow to reduce precision, bitwidth of signed integers are below or on the right of arrows, activations are included in MAC for concision.]]<br />
The error and gradient are defined as:<br />
<br />
<math>e^i = \frac{\partial L}{\partial a^i}, g^i = \frac{\partial L}{\partial W^i}</math><br />
<br />
where L is the loss function.<br />
<br />
The precisions in bits of the errors, activations, gradients, and weights are <math>k_E</math>, <math>k_A</math>, <math>k_G</math>, and <math>k_W</math> respectively. As shown in the above figure, each quantity also has a quantization operator to reduce the bitwidth increases caused by multiply-accumulate (MAC) operations. Also, note that since this is a layer-by-layer approach, each layer may be followed or preceded by a layer with different precision, or even a layer using floating-point math.<br />
<br />
=== Shift-Based Linear Mapping and Stochastic Mapping ===<br />
The proposed method makes use of a linear mapping where continuous, unbounded values are discretized for each bitwidth <math>k</math> with a uniform spacing of<br />
<br />
<math>\sigma(k) = 2^{1-k}, k \in Z_+ </math><br />
With this, the full quantization function is<br />
<br />
<math>Q(x,k) = Clip\left \{ \sigma(k) \cdot round\left [ \frac{x}{\sigma(k)} \right ], -1 + \sigma(k), 1 - \sigma(k) \right \}</math><br />
<br />
Note that this function is only used when simulating integer operations on floating-point hardware; on native integer hardware, this is done automatically.<br />
<br />
In addition to this quantization function, a distribution scaling factor is used in some quantization operators to preserve as much variance as possible during quantization. The scaling factor is defined below.<br />
<br />
<math>Shift(x) = 2^{round(log_2(x))}</math><br />
<br />
Finally, stochastic rounding is substituted for small or real-valued updates during gradient accumulation.<br />
<br />
A visual representation of these operations is below.<br />
[[File:p32fig2.PNG|center|thumb|800px|Quantization methods used in WAGE. The notation <math>P, x, \lfloor \cdot \rfloor, \lceil \cdot \rceil</math> denotes probability, vector, floor and ceil, respectively. <math>Shift(\cdot)</math> refers to distribution shifting with a certain argument]]<br />
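The step size <math>\sigma(k)</math>, the quantization function <math>Q(x,k)</math> and the scaling factor <math>Shift(x)</math> can be sketched in a few lines of Python that simulate them in floating point. This is a minimal sketch and not the authors' implementation: the function names are chosen here for illustration, and the tie-breaking of Python's built-in <code>round</code> (round half to even) is an assumption, since the summary does not specify how ties are rounded.<br />

```python
import math


def sigma(k: int) -> float:
    """Minimum step size for bitwidth k: sigma(k) = 2^(1-k)."""
    return 2.0 ** (1 - k)


def clip(x: float, lo: float, hi: float) -> float:
    """Saturate x to the closed interval [lo, hi]."""
    return max(lo, min(hi, x))


def quantize(x: float, k: int) -> float:
    """Q(x, k): snap x to a multiple of sigma(k), then clip to [-1 + sigma, 1 - sigma]."""
    s = sigma(k)
    # Assumption: ties are resolved by Python's round (half to even).
    return clip(s * round(x / s), -1.0 + s, 1.0 - s)


def shift(x: float) -> float:
    """Shift(x): the power of two nearest to x on a log2 scale (requires x > 0)."""
    return 2.0 ** round(math.log2(x))


if __name__ == "__main__":
    print(sigma(8))           # step size for 8 bits
    print(quantize(0.3, 2))   # 2-bit quantization of 0.3
    print(quantize(2.0, 8))   # out-of-range input saturates
    print(shift(3.0))         # nearest power of two to 3
```

Because <code>shift</code> always returns a power of two, dividing by it corresponds to a bit shift on integer hardware, which is presumably why the method is called shift-based.<br />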
<br />
=== Weight Initialization ===<br />
In this work, batch normalization is simplified to a constant scaling layer in order to sidestep the problem of normalizing outputs without floating-point math, and to remove the extra memory requirement of batch normalization. As such, some care must be taken when initializing weights. The authors use a modified initialization method based on MSRA [4].<br />
<br />
<math>W \thicksim U(-L, +L),L = max \left \{ \sqrt{6/n_{in}}, L_{min} \right \}, L_{min} = \beta \sigma</math><br />
<br />
<math>n_{in}</math> is the layer fan-in number, and <math>U</math> denotes the uniform distribution. The original initialization method is modified by adding the condition that the distribution width should be at least <math>\beta \sigma</math>, where <math>\beta</math> is a constant greater than 1 and <math>\sigma</math> is the minimum step size defined above. This prevents the weights from all being initialised to zero when the bitwidth is low or the fan-in number is high.<br />
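The effect of the lower limit <math>L_{min}</math> is easiest to see numerically. The sketch below computes <math>L</math> for a given fan-in and weight bitwidth and draws uniform weights. The value <code>beta=1.5</code>, the function names and the example layer sizes are assumptions made for illustration; the text above only requires <math>\beta > 1</math>.<br />

```python
import math
import random


def init_limit(n_in: int, k_w: int, beta: float = 1.5) -> float:
    """L = max(sqrt(6 / n_in), L_min) with L_min = beta * sigma(k_w)."""
    step = 2.0 ** (1 - k_w)        # sigma(k_w), the minimum step size
    l_min = beta * step            # beta is an assumed constant > 1
    return max(math.sqrt(6.0 / n_in), l_min)


def init_weights(n_in: int, n_out: int, k_w: int, beta: float = 1.5, seed: int = 0):
    """Draw an n_out x n_in weight matrix from U(-L, +L)."""
    rng = random.Random(seed)
    limit = init_limit(n_in, k_w, beta)
    return [[rng.uniform(-limit, limit) for _ in range(n_in)] for _ in range(n_out)]


if __name__ == "__main__":
    # Small fan-in, 8-bit weights: the usual sqrt(6 / n_in) limit dominates.
    print(init_limit(n_in=6, k_w=8))
    # Large fan-in, 2-bit weights: sqrt(6 / n_in) is far below one step,
    # so L_min takes over and the weights do not all round to zero.
    print(init_limit(n_in=1152, k_w=2))
```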
<br />
=== Quantization Details ===<br />
<br />
==== Weight <math>Q_W(\cdot)</math> ====<br />
<math>W_q = Q_W(W) = Q(W, k_W)</math><br />
<br />
The quantization operator is simply the quantization function previously introduced. <br />
<br />
==== Activation <math>Q_A(\cdot)</math> ====<br />
The authors say that the variance of the weights passed through this function will be scaled compared to the variance of the weights as initialized. To prevent this effect from blowing up the network outputs, they introduce a scaling factor <math>\alpha</math>. Notice that it is constant for each layer.<br />
<br />
<math>\alpha = max \left \{ Shift(L_{min} / L), 1 \right \}</math><br />
<br />
The quantization operator is then<br />
<br />
<math>a_q = Q_A(a) = Q(a/\alpha, k_A)</math><br />
<br />
The scaling factor approximates batch normalization.<br />
<br />
==== Error <math>Q_E(\cdot)</math> ====<br />
The authors note that the magnitude of the error can vary greatly, and that a previous approach (DoReFa-Net [5]) solves the issue by using an affine transform to map the error to the range <math>[-1, 1]</math>, applying quantization, and then applying the inverse transform. However, the authors claim that this approach still requires float32, and that what matters is the orientation of the error rather than its magnitude. Thus, they only scale the error distribution to the range <math>\left [ -\sqrt2, \sqrt2 \right ]</math> and quantise:<br />
<br />
<math>e_q = Q_E(e) = Q(e/Shift(max\{|e|\}), k_E)</math><br />
<br />
Here <math>max\{|e|\}</math> is the maximum absolute value over all elements of the layer's error. Note that this discards any error elements less than the minimum step size.<br />
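A minimal sketch of the error operator on a plain Python list is given below. The helper names are illustrative, the error is assumed to contain at least one non-zero element, and ties are rounded with Python's built-in <code>round</code>. Since <math>Shift</math> rounds in the log domain, the largest scaled element lands between <math>1/\sqrt2</math> and <math>\sqrt2</math>, which is where the range quoted above comes from.<br />

```python
import math


def sigma(k: int) -> float:
    """Minimum step size sigma(k) = 2^(1-k)."""
    return 2.0 ** (1 - k)


def quantize(x: float, k: int) -> float:
    """Q(x, k): round to a multiple of sigma(k) and clip to [-1 + sigma, 1 - sigma]."""
    s = sigma(k)
    return max(-1.0 + s, min(1.0 - s, s * round(x / s)))


def shift(x: float) -> float:
    """Shift(x) = 2^round(log2(x)), defined for x > 0."""
    return 2.0 ** round(math.log2(x))


def quantize_error(e, k_e: int):
    """Q_E(e): divide by Shift(max|e|), then apply Q(., k_E) to every element."""
    scale = shift(max(abs(v) for v in e))  # assumes e is not all zeros
    return [quantize(v / scale, k_e) for v in e]


if __name__ == "__main__":
    errors = [0.003, -0.0001, 0.0015, 1e-6]
    print(quantize_error(errors, 8))
```

In this example the last element is several orders of magnitude smaller than the largest one, so it falls below one step after scaling and is quantized to zero.<br />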
<br />
==== Gradient <math>Q_G(\cdot)</math> ====<br />
Similar to the activations and errors, the gradients are rescaled:<br />
<br />
<math>g_s = \eta \cdot g/Shift(max\{|g|\})</math><br />
<br />
<math> \eta </math> is a shift-based learning rate. It is an integer power of 2. The shifted gradients are represented in units of minimum step sizes <math> \sigma(k) </math>. When reducing the bitwidth of the gradients (remember that the gradients are coming out of a MAC operation, so the bitwidth may have increased) stochastic rounding is used as a substitute for small gradient accumulation.<br />
<br />
<math>\Delta W = Q_G(g) = \sigma(k_G) \cdot sgn(g_s) \cdot \left \{ \lfloor | g_s | \rfloor + Bernoulli(|g_s|<br />
- \lfloor | g_s | \rfloor) \right \}</math><br />
<br />
This randomly rounds the result of the MAC operation up or down to the nearest quantization for the given gradient bitwidth. The weights are updated with the resulting discrete increments:<br />
<br />
<math>W_{t+1} = Clip \left \{ W_t - \Delta W_t, -1 + \sigma(k_G), 1 - \sigma(k_G) \right \}</math><br />
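The gradient operator and the weight update can be sketched as follows. This is an illustration under stated assumptions rather than the authors' code: the function names are invented here, the Bernoulli draw is simulated with Python's <code>random</code> module, and the gradient is assumed to contain at least one non-zero element.<br />

```python
import math
import random


def sigma(k: int) -> float:
    """Minimum step size sigma(k) = 2^(1-k)."""
    return 2.0 ** (1 - k)


def shift(x: float) -> float:
    """Shift(x) = 2^round(log2(x)), defined for x > 0."""
    return 2.0 ** round(math.log2(x))


def quantize_gradient(g, k_g: int, eta: float, rng: random.Random):
    """Q_G(g): rescale, then stochastically round |g_s| to an integer number of steps."""
    scale = shift(max(abs(v) for v in g))  # assumes g is not all zeros
    step = sigma(k_g)
    updates = []
    for v in g:
        g_s = eta * v / scale              # eta is a power-of-two learning rate
        magnitude = abs(g_s)
        lower = math.floor(magnitude)
        # Bernoulli(magnitude - lower): round up with probability equal to the fraction.
        round_up = 1 if rng.random() < (magnitude - lower) else 0
        updates.append(step * math.copysign(1.0, g_s) * (lower + round_up))
    return updates


def update_weight(w: float, delta_w: float, k_g: int) -> float:
    """W_{t+1} = Clip(W_t - delta_w, -1 + sigma(k_G), 1 - sigma(k_G))."""
    s = sigma(k_g)
    return max(-1.0 + s, min(1.0 - s, w - delta_w))


if __name__ == "__main__":
    rng = random.Random(0)
    delta = quantize_gradient([0.5, -0.25, 0.1], k_g=8, eta=2, rng=rng)
    print(delta)
    print([update_weight(0.0, d, 8) for d in delta])
```

Stochastic rounding is unbiased: a scaled gradient of 0.3 steps becomes one full step about 30% of the time and zero otherwise, so small gradients still contribute on average even though no fractional part is stored.<br />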
<br />
=== Miscellaneous ===<br />
To train WAGE networks, the authors used pure SGD exclusively because more complicated techniques such as Momentum or RMSProp increase memory consumption and are complicated by the rescaling that happens within each quantization operator.<br />
<br />
The quantization and stochastic rounding are a form of regularization.<br />
<br />
The authors didn't use a traditional softmax with cross-entropy loss for the experiments because there does not yet exist a softmax layer for low-bit integers. Instead, they use a sum of squared error loss. This works for tasks with a small number of categories, but does not scale well.<br />
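For concreteness, a sum of squared error loss for a single example can be written as below. The one-hot target encoding and the function name are assumptions made for this sketch; the summary only states that a sum of squared errors replaces softmax with cross-entropy.<br />

```python
def sse_loss(outputs, label: int) -> float:
    """Sum of squared errors between the network outputs and a one-hot target."""
    # Assumption: the target is 1.0 for the true class and 0.0 elsewhere.
    targets = [1.0 if i == label else 0.0 for i in range(len(outputs))]
    return sum((o - t) ** 2 for o, t in zip(outputs, targets))


if __name__ == "__main__":
    print(sse_loss([0.0, 1.0, 0.0], label=1))  # perfect prediction
    print(sse_loss([0.5, 0.5], label=0))       # undecided prediction
```

Unlike softmax, this loss needs only subtraction, multiplication and addition, none of which requires an exponential or a floating-point division.<br />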
<br />
== Experiments ==<br />
For all experiments, the default layer bitwidth configuration is 2-8-8-8 for Weights, Activations, Gradients, and Error bits. The weight bitwidth is set to 2 because that results in ternary weights, and therefore no multiplication during inference. The authors argue that the bitwidth for activations and errors should be the same because the computation graph for each is similar and might use the same hardware. During training, the weight bitwidth is 8. For inference the weights are ternarized.<br />
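To see why a weight bitwidth of 2 gives ternary weights, the quantization function <math>Q(x,k)</math> defined earlier can be applied with <math>k = 2</math> to a grid of inputs. The sketch below redefines <math>Q</math> locally so that it is self-contained; the grid and the function name are illustrative.<br />

```python
def quantize(x: float, k: int) -> float:
    """Q(x, k) with step sigma(k) = 2^(1-k), clipped to [-1 + sigma, 1 - sigma]."""
    s = 2.0 ** (1 - k)
    return max(-1.0 + s, min(1.0 - s, s * round(x / s)))


if __name__ == "__main__":
    # Quantize inputs from -1.5 to 1.5 in steps of 0.01 and collect the distinct outputs.
    levels = sorted({quantize(x / 100.0, 2) for x in range(-150, 151)})
    print(levels)
```

With <math>k = 2</math> the step is <math>\sigma(2) = 0.5</math> and the clipping bounds are <math>\pm 0.5</math>, so only three levels survive. Up to the constant factor 0.5 these are the ternary values {-1, 0, 1}, which is why a multiply by a weight reduces to a sign change or a zero.<br />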
<br />
=== Implementation Details ===<br />
MNIST: Network is LeNet-5 variant [6] with 32C5-MP2-64C5-MP2-512FC-10SSE.<br />
<br />
SVHN & CIFAR10: VGG variant [7] with 2×(128C3)-MP2-2×(256C3)-MP2-2×(512C3)-MP2-1024FC-10SSE. For the CIFAR10 dataset, training follows the data augmentation of Lee et al. (2015) [10].<br />
<br />
ImageNet: AlexNet variant [8] on ILSVRC12 dataset.<br />
{| class="wikitable"<br />
|+Test or validation error rates (%) in previous works and WAGE on multiple datasets. Opt denotes gradient descent optimizer, withM means SGD with momentum, BN represents batch normalization, 32 bit refers to float32, and ImageNet top-k format: top1/top5.<br />
!Method<br />
!<math>k_W</math><br />
!<math>k_A</math><br />
!<math>k_G</math><br />
!<math>k_E</math><br />
!Opt<br />
!BN<br />
!MNIST<br />
!SVHN<br />
!CIFAR10<br />
!ImageNet<br />
|-<br />
|BC<br />
|1<br />
|32<br />
|32<br />
|32<br />
|Adam<br />
|yes<br />
|1.29<br />
|2.30<br />
|9.90<br />
|<br />
|-<br />
|BNN<br />
|1<br />
|1<br />
|32<br />
|32<br />
|Adam<br />
|yes <br />
|0.96<br />
|2.53<br />
|10.15<br />
|<br />
|-<br />
|BWN<br />
|1<br />
|32<br />
|32<br />
|32<br />
|withM<br />
|yes<br />
|<br />
|<br />
|<br />
|43.2/20.6<br />
|-<br />
|XNOR<br />
|1<br />
|1<br />
|32<br />
|32<br />
|Adam<br />
|yes<br />
|<br />
|<br />
|<br />
|55.8/30.8<br />
|-<br />
|TWN<br />
|2<br />
|32<br />
|32<br />
|32<br />
|withM<br />
|yes<br />
|0.65<br />
|<br />
|7.44<br />
|'''34.7/13.8'''<br />
|-<br />
|TTQ<br />
|2<br />
|32<br />
|32<br />
|32<br />
|Adam<br />
|yes<br />
|<br />
|<br />
|6.44<br />
|42.5/20.3<br />
|-<br />
|DoReFa<br />
|8<br />
|8<br />
|32<br />
|8<br />
|Adam<br />
|yes<br />
|<br />
|2.30<br />
|<br />
|47.0/<br />
|-<br />
|TernGrad<br />
|32<br />
|32<br />
|2<br />
|32<br />
|Adam<br />
|yes<br />
|<br />
|<br />
|14.36<br />
|42.4/19.5<br />
|-<br />
|WAGE<br />
|2<br />
|8<br />
|8<br />
|8<br />
|SGD<br />
|no<br />
|'''0.40'''<br />
|'''1.92'''<br />
|'''6.78'''<br />
|51.6/27.8<br />
|}<br />
<br />
=== Training Curves and Regularization ===<br />
The authors compare the 2-8-8-8 WAGE configuration introduced above, a 2-8-f-f (meaning float32) configuration, and a completely floating point version on CIFAR10. The test error is plotted against epoch. For training these networks, the learning rate is divided by 8 at the 200th epoch and again at the 250th epoch.<br />
[[File:p32fig3.PNG|center|thumb|800px|Training curves of WAGE variations and a vanilla CNN on CIFAR10]]<br />
The 2-8-8-8 configuration converges comparably to the vanilla CNN and outperforms the 2-8-f-f variant. The authors speculate that this is because the extra discretization acts as a regularizer.<br />
<br />
=== Bitwidth of Errors ===<br />
The CIFAR10 test accuracy is plotted against bitwidth below and the error density for a single layer is compared with the Vanilla network.<br />
[[File:p32fig4.PNG|center|thumb|520x522px|The 10 run accuracies of different <math>k_E</math>]]<br />
<br />
[[File:32_error.png|center|thumb|520x522px|Histogram of errors for the vanilla network and the WAGE network. After being quantized and shifted at each layer, the error distribution is reshaped, yet most orientation information is retained. ]]<br />
<br />
The table below shows the test error rates on CIFAR10 when the upper boundary is left-shifted by a factor γ. The table suggests that the few large values play a critical role in backpropagation training, while the majority of small values act mostly as noise.<br />
<br />
[[File:testerror_rate.png|center]]<br />
<br />
=== Bitwidth of Gradients ===<br />
{| class="wikitable"<br />
|+Test error rates (%) on CIFAR10 with different <math>k_G</math><br />
!<math>k_G</math><br />
!2<br />
!3<br />
!4<br />
!5<br />
!6<br />
!7<br />
!8<br />
!9<br />
!10<br />
!11<br />
!12<br />
|-<br />
|error<br />
|54.22<br />
|51.57<br />
|28.22<br />
|18.01<br />
|11.48<br />
|7.61<br />
|6.78<br />
|6.63<br />
|6.43<br />
|6.55<br />
|6.57<br />
|}<br />
The authors also examined the effect of bitwidth on the ImageNet implementation.<br />
<br />
{| class="wikitable"<br />
|+Top-5 error rates (%) on ImageNet with different <math>k_G</math> and <math>k_E</math><br />
!Pattern<br />
!vanilla<br />
!28ff-BN<br />
!28ff<br />
!28f8<br />
!28C8<br />
!288C<br />
!2888<br />
|-<br />
|error<br />
|19.29<br />
|20.67<br />
|24.14<br />
|23.92<br />
|26.88<br />
|28.06<br />
|27.82<br />
|}<br />
Here, C denotes 12 bits (hexadecimal) and BN indicates that batch normalization was added.<br />
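The pattern names in the table pack the four bitwidths into one hexadecimal digit each, in the order weights, activations, gradients, errors, with <code>f</code> standing for float32 as in the 2-8-f-f configuration mentioned earlier. A small decoder makes the convention explicit; the function name and the returned dictionary keys are illustrative, and the <code>vanilla</code> entry is deliberately not handled.<br />

```python
def parse_pattern(pattern: str) -> dict:
    """Decode a pattern such as '28C8' or '28ff-BN' into per-quantity bitwidths."""
    digits, _, suffix = pattern.partition("-")
    if len(digits) != 4:
        raise ValueError(f"expected four bitwidth characters, got {digits!r}")

    def width(ch: str):
        # 'f' marks float32; any other character is read as a hexadecimal digit.
        return "float32" if ch.lower() == "f" else int(ch, 16)

    config = dict(zip(("k_W", "k_A", "k_G", "k_E"), (width(ch) for ch in digits)))
    config["BN"] = suffix == "BN"
    return config


if __name__ == "__main__":
    for name in ("28ff-BN", "28C8", "288C", "2888"):
        print(name, parse_pattern(name))
```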
<br />
== Discussion ==<br />
The authors identify a few areas where they believe this approach could be improved.<br />
<br />
'''MAC Operation:''' The 2-8-8-8 configuration was chosen because the low weight bitwidth means there are no multiplications during inference. However, this does not remove the requirement for multiplication during training. A 2-2-8-8 configuration would remove it, but it is difficult to train and detrimental to accuracy.<br />
<br />
'''Non-linear Quantization:''' The linear mapping used in this approach is simple, but there might be a more effective mapping. For example, a logarithmic mapping could be more effective if the weights and activations have a log-normal distribution.<br />
<br />
'''Normalization:''' Normalization layers (softmax, batch normalization) were not used in this paper. Quantized versions are an area of future work.<br />
<br />
== Conclusion ==<br />
<br />
A framework for training and inference without the use of floating-point representation is presented. Future work may further improve compression and memory requirements.<br />
== References ==<br />
<br />
# Sze, Vivienne; Chen, Yu-Hsin; Yang, Tien-Ju; Emer, Joel (2017-03-27). [http://arxiv.org/abs/1703.09039 "Efficient Processing of Deep Neural Networks: A Tutorial and Survey"]. arXiv:1703.09039 [cs].<br />
# Courbariaux, Matthieu; Bengio, Yoshua; David, Jean-Pierre (2015-11-01). [http://arxiv.org/abs/1511.00363 "BinaryConnect: Training Deep Neural Networks with binary weights during propagations"]. arXiv:1511.00363 [cs].<br />
# Li, Fengfu; Zhang, Bo; Liu, Bin (2016-05-16). [http://arxiv.org/abs/1605.04711 "Ternary Weight Networks"]. arXiv:1605.04711 [cs].<br />
# He, Kaiming; Zhang, Xiangyu; Ren, Shaoqing; Sun, Jian (2015-02-06). [http://arxiv.org/abs/1502.01852 "Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification"]. arXiv:1502.01852 [cs].<br />
# Zhou, Shuchang; Wu, Yuxin; Ni, Zekun; Zhou, Xinyu; Wen, He; Zou, Yuheng (2016-06-20). [http://arxiv.org/abs/1606.06160 "DoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low Bitwidth Gradients"]. arXiv:1606.06160 [cs].<br />
# Lecun, Y.; Bottou, L.; Bengio, Y.; Haffner, P. (November 1998). [http://ieeexplore.ieee.org/document/726791/?reload=true "Gradient-based learning applied to document recognition"]. Proceedings of the IEEE. 86 (11): 2278–2324. doi:10.1109/5.726791. ISSN 0018-9219.<br />
# Simonyan, Karen; Zisserman, Andrew (2014-09-04). [http://arxiv.org/abs/1409.1556 "Very Deep Convolutional Networks for Large-Scale Image Recognition"]. arXiv:1409.1556 [cs].<br />
# Krizhevsky, Alex; Sutskever, Ilya; Hinton, Geoffrey E (2012). Pereira, F.; Burges, C. J. C.; Bottou, L.; Weinberger, K. Q., eds. [http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf Advances in Neural Information Processing Systems 25 (PDF)]. Curran Associates, Inc. pp. 1097–1105.<br />
# Chenzhuo Zhu, Song Han, Huizi Mao, and William J Dally. Trained ternary quantization. arXiv preprint arXiv:1612.01064, 2016.<br />
# Chen-Yu Lee, Saining Xie, Patrick Gallagher, Zhengyou Zhang, and Zhuowen Tu. Deeply-supervised nets. In Artificial Intelligence and Statistics, pp. 562–570, 2015.<br />
# Mohammad Rastegari, Vicente Ordonez, Joseph Redmon, and Ali Farhadi. Xnor-net: Imagenet classification using binary convolutional neural networks. In European Conference on Computer Vision, pp. 525–542. Springer, 2016.</div>X249wanghttp://wiki.math.uwaterloo.ca/statwiki/index.php?title=Training_And_Inference_with_Integers_in_Deep_Neural_Networks&diff=35874Training And Inference with Integers in Deep Neural Networks2018-03-28T21:21:12Z<p>X249wang: /* Weight and Activation */</p>
<hr />
<div>== Introduction ==<br />
<br />
Deep neural networks have enjoyed much success in all manner of tasks, but these networks are commonly complicated, requiring large amounts of energy-intensive memory and floating-point operations. Therefore, in order to use state-of-the-art networks in applications where energy is limited or hardware packaging is constrained, such as anything not connected to the power grid, the energy costs must be reduced while preserving as much performance as practical.<br />
<br />
Most existing methods focus on reducing the energy requirements during inference rather than training, that is, on compressing a model for inference. Since training with SGD requires accumulation, training usually demands higher precision than inference. This paper proposes a framework to reduce complexity during both training and inference through the use of integers instead of floats. The authors address how to quantize all operations and operands, and examine the bitwidth requirement for SGD computation and accumulation. Using integers instead of floats saves energy because integer operations are more efficient than floating-point operations (see the table below). Also, dedicated hardware for deep learning that uses integer operations already exists (such as the 1st generation of Google TPU), so understanding the best way to use integers is well-motivated.<br />
{| class="wikitable"<br />
|+Rough Energy Costs in 45nm 0.9V [1]<br />
!<br />
! colspan="2" |Energy(pJ)<br />
! colspan="2" |Area(<math>\mu m^2</math>)<br />
|-<br />
!Operation<br />
!MUL<br />
!ADD<br />
!MUL<br />
!ADD<br />
|-<br />
|8-bit INT<br />
|0.2<br />
|0.03<br />
|282<br />
|36<br />
|-<br />
|16-bit FP<br />
|1.1<br />
|0.4<br />
|1640<br />
|1360<br />
|-<br />
|32-bit FP<br />
|3.7<br />
|0.9<br />
|7700<br />
|4184<br />
|}<br />
The authors call the framework WAGE because they consider how best to handle the '''W'''eights, '''A'''ctivations, '''G'''radients, and '''E'''rrors separately.<br />
<br />
== Related Work ==<br />
<br />
=== Weight and Activation ===<br />
Existing works to train DNNs on binary weights and activations [2] add noise to weights and activations as a form of regularization. The use of high-precision accumulation is required for SGD optimization since real-valued gradients are obtained from real-valued variables. XNOR-Net [11] uses bitwise operations to approximate convolutions in a highly memory-efficient manner, and applies a filter-wise scaling factor for weights to improve performance. However, these floating-point factors are calculated simultaneously during training, which aggravates the training effort. Ternary weight networks (TWN) [3] and Trained ternary quantization (TTQ) [9] offer more expressive ability than binary weight networks by constraining the weights to be ternary-valued {-1,0,1} using two symmetric thresholds.<br />
<br />
=== Gradient Computation and Accumulation ===<br />
DoReFa-Net [5] quantizes gradients to low-bitwidth floating-point numbers with discrete states in the backward pass. To reduce the overhead of gradient synchronization in distributed training, the TernGrad method quantizes the gradient updates to ternary values. In both works the weights are still stored and updated with float32, and the quantization of batch normalization and its derivative is ignored.<br />
<br />
== WAGE Quantization ==<br />
The core idea of the proposed method is to constrain the following to low-bitwidth integers on each layer:<br />
* '''W:''' weight in inference<br />
* '''a:''' activation in inference<br />
* '''e:''' error in backpropagation<br />
* '''g:''' gradient in backpropagation<br />
[[File:p32fig1.PNG|center|thumb|800px|Four operators QW (·), QA(·), QG(·), QE(·) added in WAGE computation dataflow to reduce precision, bitwidth of signed integers are below or on the right of arrows, activations are included in MAC for concision.]]<br />
The error and gradient are defined as:<br />
<br />
<math>e^i = \frac{\partial L}{\partial a^i}, g^i = \frac{\partial L}{\partial W^i}</math><br />
<br />
where L is the loss function.<br />
<br />
The precisions in bits of the errors, activations, gradients, and weights are <math>k_E</math>, <math>k_A</math>, <math>k_G</math>, and <math>k_W</math> respectively. As shown in the above figure, each quantity also has a quantization operator to reduce the bitwidth increases caused by multiply-accumulate (MAC) operations. Also, note that since this is a layer-by-layer approach, each layer may be followed or preceded by a layer with different precision, or even a layer using floating-point math.<br />
<br />
=== Shift-Based Linear Mapping and Stochastic Mapping ===<br />
The proposed method makes use of a linear mapping where continuous, unbounded values are discretized for each bitwidth <math>k</math> with a uniform spacing of<br />
<br />
<math>\sigma(k) = 2^{1-k}, k \in Z_+ </math><br />
With this, the full quantization function is<br />
<br />
<math>Q(x,k) = Clip\left \{ \sigma(k) \cdot round\left [ \frac{x}{\sigma(k)} \right ], -1 + \sigma(k), 1 - \sigma(k) \right \}</math><br />
<br />
Note that this function is only used when simulating integer operations on floating-point hardware; on native integer hardware, this is done automatically.<br />
<br />
In addition to this quantization function, a distribution scaling factor is used in some quantization operators to preserve as much variance as possible during quantization. The scaling factor is defined below.<br />
<br />
<math>Shift(x) = 2^{round(log_2(x))}</math><br />
<br />
Finally, stochastic rounding is substituted for small or real-valued updates during gradient accumulation.<br />
<br />
A visual representation of these operations is below.<br />
[[File:p32fig2.PNG|center|thumb|800px|Quantization methods used in WAGE. The notation <math>P, x, \lfloor \cdot \rfloor, \lceil \cdot \rceil</math> denotes probability, vector, floor and ceil, respectively. <math>Shift(\cdot)</math> refers to distribution shifting with a certain argument]]<br />
<br />
=== Weight Initialization ===<br />
In this work, batch normalization is simplified to a constant scaling layer in order to sidestep the problem of normalizing outputs without floating-point math, and to remove the extra memory requirement of batch normalization. As such, some care must be taken when initializing weights. The authors use a modified initialization method based on MSRA [4].<br />
<br />
<math>W \thicksim U(-L, +L),L = max \left \{ \sqrt{6/n_{in}}, L_{min} \right \}, L_{min} = \beta \sigma</math><br />
<br />
<math>n_{in}</math> is the layer fan-in number, and <math>U</math> denotes the uniform distribution. The original initialization method is modified by adding the condition that the distribution width should be at least <math>\beta \sigma</math>, where <math>\beta</math> is a constant greater than 1 and <math>\sigma</math> is the minimum step size defined above. This prevents the weights from all being initialised to zero when the bitwidth is low or the fan-in number is high.<br />
<br />
=== Quantization Details ===<br />
<br />
==== Weight <math>Q_W(\cdot)</math> ====<br />
<math>W_q = Q_W(W) = Q(W, k_W)</math><br />
<br />
The quantization operator is simply the quantization function previously introduced. <br />
<br />
==== Activation <math>Q_A(\cdot)</math> ====<br />
The authors say that the variance of the weights passed through this function will be scaled compared to the variance of the weights as initialized. To prevent this effect from blowing up the network outputs, they introduce a scaling factor <math>\alpha</math>. Notice that it is constant for each layer.<br />
<br />
<math>\alpha = max \left \{ Shift(L_{min} / L), 1 \right \}</math><br />
<br />
The quantization operator is then<br />
<br />
<math>a_q = Q_A(a) = Q(a/\alpha, k_A)</math><br />
<br />
The scaling factor approximates batch normalization.<br />
<br />
==== Error <math>Q_E(\cdot)</math> ====<br />
The authors note that the magnitude of the error can vary greatly, and that a previous approach (DoReFa-Net [5]) solves the issue by using an affine transform to map the error to the range <math>[-1, 1]</math>, applying quantization, and then applying the inverse transform. However, the authors claim that this approach still requires float32, and that what matters is the orientation of the error rather than its magnitude. Thus, they only scale the error distribution to the range <math>\left [ -\sqrt2, \sqrt2 \right ]</math> and quantise:<br />
<br />
<math>e_q = Q_E(e) = Q(e/Shift(max\{|e|\}), k_E)</math><br />
<br />
Here <math>max\{|e|\}</math> is the maximum absolute value over all elements of the layer's error. Note that this discards any error elements less than the minimum step size.<br />
<br />
==== Gradient <math>Q_G(\cdot)</math> ====<br />
Similar to the activations and errors, the gradients are rescaled:<br />
<br />
<math>g_s = \eta \cdot g/Shift(max\{|g|\})</math><br />
<br />
<math> \eta </math> is a shift-based learning rate. It is an integer power of 2. The shifted gradients are represented in units of minimum step sizes <math> \sigma(k) </math>. When reducing the bitwidth of the gradients (remember that the gradients are coming out of a MAC operation, so the bitwidth may have increased) stochastic rounding is used as a substitute for small gradient accumulation.<br />
<br />
<math>\Delta W = Q_G(g) = \sigma(k_G) \cdot sgn(g_s) \cdot \left \{ \lfloor | g_s | \rfloor + Bernoulli(|g_s| - \lfloor | g_s | \rfloor) \right \}</math><br />
<br />
This randomly rounds the result of the MAC operation up or down to a neighbouring quantization level for the given gradient bitwidth, with the probability of rounding up equal to the fractional part. The weights are updated with the resulting discrete increments:<br />
<br />
<math>W_{t+1} = Clip \left \{ W_t - \Delta W_t, -1 + \sigma(k_G), 1 - \sigma(k_G) \right \}</math><br />
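<br />
The gradient path can be sketched in Python as follows. This is an illustration under stated assumptions, not the authors' implementation: <math>\sigma(k) = 2^{1-k}</math> and <math>Shift(x) = 2^{round(\log_2 x)}</math> are assumed, and the names <code>stochastic_round</code> and <code>weight_update</code> are hypothetical.<br />
<br />
```python
import math
import random

def sigma(k):
    """Minimum step size of a k-bit value (assumed definition: 2^(1-k))."""
    return 2.0 ** (1 - k)

def shift(x):
    """Nearest power of two to a positive scale (assumed definition)."""
    return 2.0 ** round(math.log2(x))

def stochastic_round(x, rng):
    """sgn(x) * (floor(|x|) + Bernoulli(|x| - floor(|x|)))."""
    magnitude = abs(x)
    lower = math.floor(magnitude)
    bump = 1 if rng.random() < magnitude - lower else 0
    return math.copysign(lower + bump, x)

def weight_update(weights, grads, eta, k_g, rng):
    """One update: rescale gradients, round stochastically, clip weights."""
    largest = max(abs(g) for g in grads)
    if largest == 0.0:
        return list(weights)  # no gradient signal, weights unchanged
    scale = shift(largest)
    s = sigma(k_g)
    updated = []
    for w, g in zip(weights, grads):
        g_s = eta * g / scale                    # shifted gradient
        delta = s * stochastic_round(g_s, rng)   # integer multiple of sigma
        updated.append(max(-1 + s, min(1 - s, w - delta)))
    return updated

rng = random.Random(0)  # seeded only to make the sketch reproducible
new_weights = weight_update([0.5, 0.99], [0.004, -1.0], eta=2, k_g=8, rng=rng)
```
<br />
The rounding is unbiased: the expected value of <code>stochastic_round(x, rng)</code> equals <code>x</code>, which is why it can stand in for the accumulation of small gradients.<br />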
<br />
=== Miscellaneous ===<br />
To train WAGE networks, the authors used pure SGD exclusively because more complicated techniques such as Momentum or RMSProp increase memory consumption and are complicated by the rescaling that happens within each quantization operator.<br />
<br />
The quantization and stochastic rounding are a form of regularization.<br />
<br />
The authors didn't use a traditional softmax with cross-entropy loss for the experiments because there does not yet exist a softmax layer for low-bit integers. Instead, they use a sum of squared error loss. This works for tasks with a small number of categories, but does not scale well.<br />
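<br />
For reference, a sum of squared error loss over class outputs can be written as below. The one-hot target encoding is an assumption made for this sketch; the summary does not state how the targets are encoded.<br />
<br />
```python
def sse_loss(outputs, label, num_classes):
    """Sum of squared errors against a one-hot target (assumed encoding)."""
    target = [1.0 if c == label else 0.0 for c in range(num_classes)]
    return sum((o - t) ** 2 for o, t in zip(outputs, target))

# Hypothetical 3-class example: the correct class is 0.
loss = sse_loss([0.5, 0.5, 0.0], label=0, num_classes=3)
```
<br />
Unlike softmax, this loss needs no exponentials or normalization across classes, which is what makes it easy to compute with low-bit integers.<br />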
<br />
== Experiments ==<br />
For all experiments, the default layer bitwidth configuration is 2-8-8-8 for Weights, Activations, Gradients, and Error bits. The weight bitwidth is set to 2 because that results in ternary weights, and therefore no multiplication during inference. The authors argue that the bitwidth for activation and errors should be the same because the computation graph for each is similar and might use the same hardware. During training, the weight bitwidth is 8. For inference the weights are ternarized.<br />
<br />
=== Implementation Details ===<br />
MNIST: Network is LeNet-5 variant [6] with 32C5-MP2-64C5-MP2-512FC-10SSE.<br />
<br />
SVHN & CIFAR10: VGG variant [7] with 2×(128C3)-MP2-2×(256C3)-MP2-2×(512C3)-MP2-1024FC-10SSE. For the CIFAR10 dataset, training uses the data augmentation of Lee et al. (2015) [10].<br />
<br />
ImageNet: AlexNet variant [8] on ILSVRC12 dataset.<br />
{| class="wikitable"<br />
|+Test or validation error rates (%) in previous works and WAGE on multiple datasets. Opt denotes gradient descent optimizer, withM means SGD with momentum, BN represents batch normalization, 32 bit refers to float32, and ImageNet top-k format: top1/top5.<br />
!Method<br />
!<math>k_W</math><br />
!<math>k_A</math><br />
!<math>k_G</math><br />
!<math>k_E</math><br />
!Opt<br />
!BN<br />
!MNIST<br />
!SVHN<br />
!CIFAR10<br />
!ImageNet<br />
|-<br />
|BC<br />
|1<br />
|32<br />
|32<br />
|32<br />
|Adam<br />
|yes<br />
|1.29<br />
|2.30<br />
|9.90<br />
|<br />
|-<br />
|BNN<br />
|1<br />
|1<br />
|32<br />
|32<br />
|Adam<br />
|yes <br />
|0.96<br />
|2.53<br />
|10.15<br />
|<br />
|-<br />
|BWN<br />
|1<br />
|32<br />
|32<br />
|32<br />
|withM<br />
|yes<br />
|<br />
|<br />
|<br />
|43.2/20.6<br />
|-<br />
|XNOR<br />
|1<br />
|1<br />
|32<br />
|32<br />
|Adam<br />
|yes<br />
|<br />
|<br />
|<br />
|55.8/30.8<br />
|-<br />
|TWN<br />
|2<br />
|32<br />
|32<br />
|32<br />
|withM<br />
|yes<br />
|0.65<br />
|<br />
|7.44<br />
|'''34.7/13.8'''<br />
|-<br />
|TTQ<br />
|2<br />
|32<br />
|32<br />
|32<br />
|Adam<br />
|yes<br />
|<br />
|<br />
|6.44<br />
|42.5/20.3<br />
|-<br />
|DoReFa<br />
|8<br />
|8<br />
|32<br />
|8<br />
|Adam<br />
|yes<br />
|<br />
|2.30<br />
|<br />
|47.0/<br />
|-<br />
|TernGrad<br />
|32<br />
|32<br />
|2<br />
|32<br />
|Adam<br />
|yes<br />
|<br />
|<br />
|14.36<br />
|42.4/19.5<br />
|-<br />
|WAGE<br />
|2<br />
|8<br />
|8<br />
|8<br />
|SGD<br />
|no<br />
|'''0.40'''<br />
|'''1.92'''<br />
|'''6.78'''<br />
|51.6/27.8<br />
|}<br />
<br />
=== Training Curves and Regularization ===<br />
The authors compare the 2-8-8-8 WAGE configuration introduced above, a 2-8-f-f (meaning float32) configuration, and a completely floating point version on CIFAR10. The test error is plotted against epoch. For training these networks, the learning rate is divided by 8 at the 200th epoch and again at the 250th epoch.<br />
[[File:p32fig3.PNG|center|thumb|800px|Training curves of WAGE variations and a vanilla CNN on CIFAR10]]<br />
The 2-8-8-8 configuration converges comparably to the vanilla CNN and outperforms the 2-8-f-f variant. The authors speculate that this is because the extra discretization acts as a regularizer.<br />
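<br />
The step schedule used for these curves can be sketched as follows. Dividing by 8 keeps a power-of-two learning rate a power of two, so it remains shift-based. The function name and the convention that the drop takes effect from epoch index 200 onward are assumptions.<br />
<br />
```python
def learning_rate(epoch, base_lr):
    """Divide the learning rate by 8 at epoch 200 and again at epoch 250."""
    lr = base_lr
    for milestone in (200, 250):
        if epoch >= milestone:
            lr /= 8  # equivalent to a right shift by 3 bits
    return lr

# Hypothetical base rate of 8: 8 -> 1 -> 0.125 across the two milestones.
schedule = [learning_rate(e, 8.0) for e in (0, 199, 200, 249, 250)]
```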
<br />
=== Bitwidth of Errors ===<br />
The CIFAR10 test accuracy is plotted against bitwidth below and the error density for a single layer is compared with the Vanilla network.<br />
[[File:p32fig4.PNG|center|thumb|520x522px|The 10 run accuracies of different <math>k_E</math>]]<br />
<br />
[[File:32_error.png|center|thumb|520x522px|Histogram of errors for the vanilla network and the WAGE network. After being quantized and shifted in each layer, the error is reshaped and so most orientation information is retained. ]]<br />
<br />
The table below shows the test error rates on CIFAR10 when the upper boundary is left-shifted by a factor γ. The table suggests that the large error values play a critical role in backpropagation training even though there are few of them, while the majority of values, which are small, act mostly as noise.<br />
<br />
[[File:testerror_rate.png|center]]<br />
<br />
=== Bitwidth of Gradients ===<br />
{| class="wikitable"<br />
|+Test error rates (%) on CIFAR10 with different <math>k_G</math><br />
!<math>k_G</math><br />
!2<br />
!3<br />
!4<br />
!5<br />
!6<br />
!7<br />
!8<br />
!9<br />
!10<br />
!11<br />
!12<br />
|-<br />
|error<br />
|54.22<br />
|51.57<br />
|28.22<br />
|18.01<br />
|11.48<br />
|7.61<br />
|6.78<br />
|6.63<br />
|6.43<br />
|6.55<br />
|6.57<br />
|}<br />
The authors also examined the effect of bitwidth on the ImageNet implementation.<br />
<br />
{| class="wikitable"<br />
|+Top-5 error rates (%) on ImageNet with different <math>k_G</math> and <math>k_E</math><br />
!Pattern<br />
!vanilla<br />
!28ff-BN<br />
!28ff<br />
!28f8<br />
!28C8<br />
!288C<br />
!2888<br />
|-<br />
|error<br />
|19.29<br />
|20.67<br />
|24.14<br />
|23.92<br />
|26.88<br />
|28.06<br />
|27.82<br />
|}<br />
Here, C denotes 12 bits (hexadecimal) and BN refers to batch normalization being added.<br />
<br />
== Discussion ==<br />
The authors have a few areas they believe this approach could be improved.<br />
<br />
'''MAC Operation:''' The 2-8-8-8 configuration was chosen because the low weight bitwidth means there are no multiplications during inference. However, this does not remove the requirement for multiplication during training. A 2-2-8-8 configuration satisfies this requirement, but it is difficult to train and detrimental to accuracy.<br />
<br />
'''Non-linear Quantization:''' The linear mapping used in this approach is simple, but there might be a more effective mapping. For example, a logarithmic mapping could be more effective if the weights and activations have a log-normal distribution.<br />
<br />
'''Normalization:''' Normalization layers (softmax, batch normalization) were not used in this paper. Quantized versions are an area of future work.<br />
<br />
== Conclusion ==<br />
<br />
A framework for training and inference without the use of floating-point representation is presented. Future work may further improve compression and memory requirements.<br />
== References ==<br />
<br />
# Sze, Vivienne; Chen, Yu-Hsin; Yang, Tien-Ju; Emer, Joel (2017-03-27). [http://arxiv.org/abs/1703.09039 "Efficient Processing of Deep Neural Networks: A Tutorial and Survey"]. arXiv:1703.09039 [cs].<br />
# Courbariaux, Matthieu; Bengio, Yoshua; David, Jean-Pierre (2015-11-01). [http://arxiv.org/abs/1511.00363 "BinaryConnect: Training Deep Neural Networks with binary weights during propagations"]. arXiv:1511.00363 [cs].<br />
# Li, Fengfu; Zhang, Bo; Liu, Bin (2016-05-16). [http://arxiv.org/abs/1605.04711 "Ternary Weight Networks"]. arXiv:1605.04711 [cs].<br />
# He, Kaiming; Zhang, Xiangyu; Ren, Shaoqing; Sun, Jian (2015-02-06). [http://arxiv.org/abs/1502.01852 "Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification"]. arXiv:1502.01852 [cs].<br />
# Zhou, Shuchang; Wu, Yuxin; Ni, Zekun; Zhou, Xinyu; Wen, He; Zou, Yuheng (2016-06-20). [http://arxiv.org/abs/1606.06160 "DoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low Bitwidth Gradients"]. arXiv:1606.06160 [cs].<br />
# Lecun, Y.; Bottou, L.; Bengio, Y.; Haffner, P. (November 1998). [http://ieeexplore.ieee.org/document/726791/?reload=true "Gradient-based learning applied to document recognition"]. Proceedings of the IEEE. 86 (11): 2278–2324. doi:10.1109/5.726791. ISSN 0018-9219.<br />
# Simonyan, Karen; Zisserman, Andrew (2014-09-04). [http://arxiv.org/abs/1409.1556 "Very Deep Convolutional Networks for Large-Scale Image Recognition"]. arXiv:1409.1556 [cs].<br />
# Krizhevsky, Alex; Sutskever, Ilya; Hinton, Geoffrey E (2012). Pereira, F.; Burges, C. J. C.; Bottou, L.; Weinberger, K. Q., eds. [http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf Advances in Neural Information Processing Systems 25 (PDF)]. Curran Associates, Inc. pp. 1097–1105.<br />
# Zhu, Chenzhuo; Han, Song; Mao, Huizi; Dally, William J. (2016). [http://arxiv.org/abs/1612.01064 "Trained Ternary Quantization"]. arXiv:1612.01064 [cs].<br />
# Lee, Chen-Yu; Xie, Saining; Gallagher, Patrick; Zhang, Zhengyou; Tu, Zhuowen (2015). "Deeply-Supervised Nets". Artificial Intelligence and Statistics. pp. 562–570.</div>X249wanghttp://wiki.math.uwaterloo.ca/statwiki/index.php?title=MarrNet:_3D_Shape_Reconstruction_via_2.5D_Sketches&diff=35648MarrNet: 3D Shape Reconstruction via 2.5D Sketches2018-03-27T05:00:21Z<p>X249wang: /* Results */</p>
<hr />
<div>= Introduction =<br />
Humans are able to quickly recognize 3D shapes from images, even in spite of drastic differences in object texture, material, lighting, and background.<br />
<br />
[[File:marrnet_intro_image.png|700px|thumb|center|Objects in real images. The appearance of the same shaped object varies based on colour, texture, lighting, background, etc. However, the 2.5D sketches (e.g. depth or normal maps) of the object remain constant, and can be seen as an abstraction of the object which is used to reconstruct the 3D shape.]]<br />
<br />
In this work, the authors propose a novel end-to-end trainable model that sequentially estimates 2.5D sketches and 3D object shape from images and also enforces reprojection consistency between the 3D shape and the estimated sketches. The two-step approach makes the network more robust to differences in object texture, material, lighting, and background. Based on the idea from [Marr, 1982] that human 3D perception relies on recovering 2.5D sketches, which include depth maps (the distance of each visible surface point from the viewpoint) and surface normal maps (the orientation of the surface at each visible point), the authors design an end-to-end trainable pipeline which they call MarrNet. MarrNet first estimates depth, normal maps, and silhouette, followed by a 3D shape. MarrNet uses an encoder-decoder structure for the sub-components of the framework. <br />
<br />
The authors claim several unique advantages to their method. Single image 3D reconstruction is a highly under-constrained problem, requiring strong prior knowledge of object shapes. As well, accurate 3D object annotations for real images are not common, and many previous approaches rely on purely synthetic data. However, most of these methods suffer from domain adaptation problems due to imperfect rendering.<br />
<br />
Using 2.5D sketches can alleviate the challenges of domain transfer. It is straightforward to generate perfect object surface normals and depths using a graphics engine. Since 2.5D sketches contain only depth, surface normal, and silhouette information, the second step of recovering 3D shape can be trained purely from synthetic data. As well, the introduction of differentiable constraints between 2.5D sketches and 3D shape makes it possible to fine-tune the system, even without any annotations.<br />
<br />
The framework is evaluated on both synthetic objects from ShapeNet, and real images from PASCAL 3D+, showing good qualitative and quantitative performance in 3D shape reconstruction.<br />
<br />
= Related Work =<br />
<br />
== 2.5D Sketch Recovery ==<br />
Researchers have explored recovering 2.5D information from shading, texture, and colour images in the past. More recently, the development of depth sensors has led to the creation of large RGB-D datasets, and papers on estimating depth, surface normals, and other intrinsic images using deep networks. While this method employs 2.5D estimation, the final output is a full 3D shape of an object.<br />
<br />
[[File:2-5d_example.PNG|700px|thumb|center|Results from the paper: Learning Non-Lambertian Object Intrinsics across ShapeNet Categories. The results show that neural networks can be trained to recover 2.5D information from an image. The top row predicts the albedo and the bottom row predicts the shading. It can be observed that the results are still blurry and the fine details are not fully recovered.]]<br />
<br />
== Single Image 3D Reconstruction ==<br />
The development of large-scale shape repositories like ShapeNet has allowed for the development of models encoding shape priors for single image 3D reconstruction. These methods normally regress voxelized 3D shapes, relying on synthetic data or 2D masks for training. The formulation in the paper tackles domain adaptation better, since the network can be fine-tuned on images without any annotations.<br />
<br />
== 2D-3D Consistency ==<br />
Intuitively, the 3D shape can be constrained to be consistent with 2D observations. This idea has been explored for decades, with the use of depth and silhouettes, as well as some papers enforcing differentiable 2D-3D constraints for joint training of deep networks. In this work, this idea is exploited to develop differentiable constraints for consistency between the 2.5D sketches and 3D shape.<br />
<br />
= Approach =<br />
The 3D structure is recovered from a single RGB view using three steps, shown in the figure below. The first step estimates 2.5D sketches, including depth, surface normal, and silhouette of the object. The second step estimates a 3D voxel representation of the object. The third step uses a reprojection consistency function to enforce the 2.5D sketch and 3D structure alignment.<br />
<br />
[[File:marrnet_model_components.png|700px|thumb|center|MarrNet architecture. 2.5D sketches of normals, depths, and silhouette are first estimated. The sketches are then used to estimate the 3D shape. Finally, re-projection consistency is used to ensure consistency between the sketch and 3D output.]]<br />
<br />
== 2.5D Sketch Estimation ==<br />
The first step takes a 2D RGB image and predicts the surface normal, depth, and silhouette of the object. The goal is to estimate intrinsic object properties from the image, while discarding non-essential information. An encoder-decoder architecture is used. The encoder is a ResNet-18 network, which takes a 256 x 256 RGB image and produces 8 x 8 x 512 feature maps. The decoder is four sets of 5 x 5 convolutional and ReLU layers, followed by four sets of 1 x 1 convolutional and ReLU layers. The output is 256 x 256 resolution depth, surface normal, and silhouette images.<br />
<br />
== 3D Shape Estimation ==<br />
The second step estimates a voxelized 3D shape using the 2.5D sketches from the first step. The goal here is for the network to learn a shape prior that explains the input well. Because it takes only surface normal and depth images as input, it can be trained on synthetic data without suffering from the domain adaptation problem. The network architecture is inspired by the TL network and 3D-VAE-GAN, with an encoder-decoder structure. The normal and depth images, masked by the estimated silhouette, are passed into 5 sets of convolutional, ReLU, and pooling layers, followed by two fully connected layers, with a final output width of 200. The 200-dimensional vector is passed into a decoder of 5 convolutional and ReLU layers, outputting a 128 x 128 x 128 voxelized estimate of the input.<br />
<br />
== Re-projection Consistency ==<br />
The third step consists of a depth re-projection loss and surface normal re-projection loss. Here, <math>v_{x, y, z}</math> represents the value at position <math>(x, y, z)</math> in a 3D voxel grid, with <math>v_{x, y, z} \in [0, 1] \ \forall x, y, z</math>. <math>d_{x, y}</math> denotes the estimated depth at position <math>(x, y)</math>, <math>n_{x, y} = (n_a, n_b, n_c)</math> denotes the estimated surface normal. Orthographic projection is used.<br />
<br />
[[File:marrnet_reprojection_consistency.png|700px|thumb|center|Reprojection consistency for voxels. Left and middle: criteria for depth and silhouettes. Right: criterion for surface normals]]<br />
<br />
=== Depths ===<br />
The voxel at the estimated depth, <math>v_{x, y, d_{x, y}}</math>, should be 1, while all voxels in front of it should be 0. This ensures the estimated 3D shape matches the estimated depth values. The projected depth loss and its gradient are defined as follows:<br />
<br />
<math><br />
L_{depth}(x, y, z)=<br />
\left\{<br />
\begin{array}{ll}<br />
v^2_{x, y, z}, & z < d_{x, y} \\<br />
(1 - v_{x, y, z})^2, & z = d_{x, y} \\<br />
0, & z > d_{x, y} \\<br />
\end{array}<br />
\right.<br />
</math><br />
<br />
<math><br />
\frac{\partial L_{depth}(x, y, z)}{\partial v_{x, y, z}} =<br />
\left\{<br />
\begin{array}{ll}<br />
2v_{x, y, z}, & z < d_{x, y} \\<br />
2(v_{x, y, z} - 1), & z = d_{x, y} \\<br />
0, & z > d_{x, y} \\<br />
\end{array}<br />
\right.<br />
</math><br />
<br />
When <math>d_{x, y} = \infty</math>, all voxels in front of it should be 0.<br />
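<br />
The piecewise loss and gradient above can be sketched in Python for a single <math>(x, y)</math> location. The sketch assumes integer depth indices and sums the per-voxel losses along the <math>z</math> axis, which is one plausible way to aggregate them; the function name and the list-based voxel column are illustrative.<br />
<br />
```python
import math

def depth_loss_and_grad(column, d):
    """Depth reprojection loss for one (x, y) ray.

    column: voxel values v[x][y][z] for z = 0, 1, ..., each in [0, 1].
    d: integer depth index, or math.inf when the ray hits nothing.
    Returns the summed loss and the gradient with respect to each voxel.
    """
    loss = 0.0
    grad = [0.0] * len(column)
    for z, v in enumerate(column):
        if z < d:                 # in front of the surface: should be empty
            loss += v ** 2
            grad[z] = 2 * v
        elif z == d:              # at the surface: should be occupied
            loss += (1 - v) ** 2
            grad[z] = 2 * (v - 1)
        # z > d: behind the surface, unconstrained (zero loss and gradient)
    return loss, grad

column = [0.5, 0.0, 0.25, 1.0]
loss_hit, grad_hit = depth_loss_and_grad(column, 2)
loss_miss, grad_miss = depth_loss_and_grad(column, math.inf)
```
<br />
With <code>d = math.inf</code> every voxel satisfies <code>z < d</code>, so the whole ray is pushed toward 0, matching the silhouette case described above.<br />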
<br />
=== Surface Normals ===<br />
Since the vectors <math>n_{x} = (0, -n_{c}, n_{b})</math> and <math>n_{y} = (-n_{c}, 0, n_{a})</math> are orthogonal to the normal vector <math>n_{x, y} = (n_{a}, n_{b}, n_{c})</math>, they can be rescaled to obtain <math>n'_{x} = (0, -1, n_{b}/n_{c})</math> and <math>n'_{y} = (-1, 0, n_{a}/n_{c})</math> on the estimated surface plane at <math>(x, y, z)</math>. The projected surface normal loss tries to guarantee that the voxels at <math>(x, y, z) \pm n'_{x}</math> and <math>(x, y, z) \pm n'_{y}</math> are 1, to match the estimated normal. The constraints are only applied when the target voxels are inside the estimated silhouette.<br />
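<br />
As a quick check of the orthogonality claim, both dot products with the normal vanish:<br />
<br />
<math>n_{x} \cdot n_{x, y} = 0 \cdot n_{a} + (-n_{c}) n_{b} + n_{b} n_{c} = 0, \qquad n_{y} \cdot n_{x, y} = (-n_{c}) n_{a} + 0 \cdot n_{b} + n_{a} n_{c} = 0</math><br />
<br />
Dividing each vector by <math>n_{c}</math> (assuming <math>n_{c} \neq 0</math>) leaves it in the surface plane and gives it a unit step along <math>y</math> or <math>x</math>, so the constrained voxels are the immediate grid neighbours in those directions.<br />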
<br />
The projected surface normal loss is defined as follows, with <math>z = d_{x, y}</math>:<br />
<br />
<math><br />
L_{normal}(x, y, z) =<br />
(1 - v_{x, y-1, z+\frac{n_b}{n_c}})^2 + (1 - v_{x, y+1, z-\frac{n_b}{n_c}})^2 + <br />
(1 - v_{x-1, y, z+\frac{n_a}{n_c}})^2 + (1 - v_{x+1, y, z-\frac{n_a}{n_c}})^2<br />
</math><br />
<br />
Gradients along x are:<br />
<br />
<math><br />
\frac{dL_{normal}(x, y, z)}{dv_{x-1, y, z+\frac{n_a}{n_c}}} = 2(v_{x-1, y, z+\frac{n_a}{n_c}}-1)<br />
</math><br />
and<br />
<math><br />
\frac{dL_{normal}(x, y, z)}{dv_{x+1, y, z-\frac{n_a}{n_c}}} = 2(v_{x+1, y, z-\frac{n_a}{n_c}}-1)<br />
</math><br />
<br />
Gradients along y are similar to x.<br />
<br />
= Training =<br />
The 2.5D and 3D estimation components are first pre-trained separately on synthetic data from ShapeNet, and then fine-tuned on real images.<br />
<br />
For pre-training, the 2.5D sketch estimator is trained on synthetic ShapeNet depth, surface normal, and silhouette ground truth, using an L2 loss. The 3D estimator is trained with ground truth voxels using a cross-entropy loss.<br />
<br />
Reprojection consistency loss is used to fine-tune the 3D estimation using real images, using the predicted depth, normals, and silhouette. A straightforward implementation leads to shapes that explain the 2.5D sketches well but have an unrealistic 3D appearance due to overfitting.<br />
<br />
Instead, the decoder of the 3D estimator is fixed, and only the encoder is fine-tuned. The model is fine-tuned separately on each image for 40 iterations, which takes up to 10 seconds on the GPU. Without fine-tuning, testing time takes around 100 milliseconds. SGD is used for optimization with batch size of 4, learning rate of 0.001, and momentum of 0.9.<br />
<br />
= Evaluation =<br />
Qualitative and quantitative results are provided using different variants of the framework. The framework is evaluated on both synthetic and real images on three datasets.<br />
<br />
== ShapeNet ==<br />
Synthesized images of 6,778 chairs from ShapeNet are rendered from 20 random viewpoints. The chairs are placed in front of random backgrounds from the SUN dataset, and the RGB, depth, normal, and silhouette images are rendered using the physics-based renderer Mitsuba for more realistic images.<br />
<br />
=== Method ===<br />
MarrNet is trained without the final fine-tuning stage, since 3D shapes are available. A baseline is created that directly predicts the 3D shape using the same 3D shape estimator architecture with no 2.5D sketch estimation.<br />
<br />
=== Results ===<br />
The baseline output is compared to the full framework, and the figure below shows that MarrNet provides model outputs with more details and smoother surfaces than the baseline. The estimated normal and depth images are able to extract intrinsic information about object shape while leaving behind non-essential information such as textures from the original images. Quantitatively, the full model also achieves 0.57 IoU, higher than the direct prediction baseline.<br />
<br />
[[File:marrnet_shapenet_results.png|700px|thumb|center|ShapeNet results.]]<br />
<br />
== PASCAL 3D+ ==<br />
This dataset provides rough 3D models for objects in real-life images.<br />
<br />
=== Method ===<br />
Each module is pre-trained on the ShapeNet dataset, and then fine-tuned on the PASCAL 3D+ dataset. Three variants of the model are tested. The first is trained using ShapeNet data only with no fine-tuning. The second is fine-tuned without fixing the decoder. The third is fine-tuned with a fixed decoder.<br />
<br />
=== Results ===<br />
The figure below shows the results of the ablation study. The model trained only on synthetic data provides reasonable estimates. However, fine-tuning without fixing the decoder leads to impossible shapes from certain views. The third model keeps the shape prior, providing more details in the final shape.<br />
<br />
[[File:marrnet_pascal_3d_ablation.png|600px|thumb|center|Ablation studies using the PASCAL 3D+ dataset.]]<br />
<br />
Additional comparisons are made with the state-of-the-art (DRC) on the provided ground truth shapes. MarrNet achieves 0.39 IoU, while DRC achieves 0.34. However, the authors claim that the IoU metric is sub-optimal for three reasons. First, there is no emphasis on details since the metric prefers models that predict mean shapes consistently. Second, all possible scales are searched during the IoU computation, making it less efficient. Third, PASCAL 3D+ only has rough annotations, with only 10 CAD chair models for all images, and computing IoU with these shapes is not very informative. Instead, human studies are conducted: MarrNet reconstructions are preferred 74% of the time over DRC, and 42% of the time over the ground truth. This suggests that MarrNet produces plausible shapes and also highlights that the ground truth shapes are of limited quality.<br />
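<br />
For concreteness, the IoU between a predicted and a ground truth voxel grid can be computed as in the sketch below. The 0.5 occupancy threshold and the flattened-list representation are assumptions, and the scale search mentioned above is not included.<br />
<br />
```python
def voxel_iou(pred, target, threshold=0.5):
    """Intersection over union of two voxel grids given as flat lists.

    pred holds occupancy probabilities, target holds 0/1 labels.
    The threshold is an assumed binarization rule.
    """
    intersection = 0
    union = 0
    for p, t in zip(pred, target):
        p_occupied = p >= threshold
        t_occupied = t >= 0.5
        intersection += p_occupied and t_occupied
        union += p_occupied or t_occupied
    # Two empty grids agree perfectly by convention.
    return intersection / union if union else 1.0

# Hypothetical 4-voxel example: 2 voxels overlap, 3 are occupied in total.
score = voxel_iou([0.9, 0.8, 0.1, 0.6], [1, 0, 0, 1])
```
<br />
Because every occupied voxel counts equally, a coarse blob that covers the bulk of the object can score as well as a detailed reconstruction, which is the first of the authors' objections.<br />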
<br />
[[File:human_studies.png|400px|thumb|center|Human preferences on chairs in PASCAL 3D+ (Xiang et al. 2014). The numbers show the percentage of how often humans preferred the 3D shape from DRC (state-of-the-art), MarrNet, or GT.]]<br />
<br />
<br />
[[File:marrnet_pascal_3d_drc_comparison.png|600px|thumb|center|Comparison between DRC and MarrNet results.]]<br />
<br />
Several failure cases are shown in the figure below. Specifically, the framework does not seem to work well on thin structures.<br />
<br />
[[File:marrnet_pascal_3d_failure_cases.png|500px|thumb|center|Failure cases on PASCAL 3D+. The algorithm cannot recover thin structures.]]<br />
<br />
== IKEA ==<br />
This dataset contains images of IKEA furniture, with accurate 3D shape and pose annotations. Objects are often heavily occluded or truncated.<br />
<br />
=== Results ===<br />
Qualitative results are shown in the figure below. The model is shown to deal with mild occlusions in real life scenarios. Human studies show that MarrNet reconstructions are preferred 61% of the time to 3D-VAE-GAN.<br />
<br />
[[File:marrnet_ikea_results.png|700px|thumb|center|Results on chairs in the IKEA dataset, and comparison with 3D-VAE-GAN.]]<br />
<br />
== Other Data ==<br />
MarrNet is also applied on cars and airplanes. Shown below, smaller details such as the horizontal stabilizer and rear-view mirrors are recovered.<br />
<br />
[[File:marrnet_airplanes_and_cars.png|700px|thumb|center|Results on airplanes and cars from the PASCAL 3D+ dataset, and comparison with DRC.]]<br />
<br />
MarrNet is also jointly trained on three object categories, and successfully recovers the shapes of different categories. Results are shown in the figure below.<br />
<br />
[[File:marrnet_multiple_categories.png|700px|thumb|center|Results when trained jointly on all three object categories (cars, airplanes, and chairs).]]<br />
<br />
= Commentary =<br />
Qualitatively, the results look quite impressive. The 2.5D sketch estimation seems to distill the useful information for more realistic looking 3D shape estimation. The disentanglement of 2.5D and 3D estimation steps also allows for easier training and domain adaptation from synthetic data.<br />
<br />
As the authors mention, the IoU metric is not very descriptive, and most of the comparisons in this paper are only qualitative, mainly being human preference studies. A better quantitative evaluation metric would greatly help in making an unbiased comparison between different results.<br />
<br />
As seen in several of the results, the network does not deal well with objects that have thin structures, which is particularly noticeable with many of the chair arm rests. As well, on closer inspection of some results, fine-tuning only the 3D encoder does not seem to transfer well to unseen objects, since shape priors have already been learned by the decoder.<br />
<br />
= Conclusion =<br />
The proposed MarrNet employs a novel model to estimate 2.5D sketches for 3D shape reconstruction. The sketches are shown to improve the model’s performance, and make it easy to adapt to images across different domains and categories. Differentiable loss functions are created such that the model can be fine-tuned end-to-end on images without ground truth. The experiments show that the model performs well, and human studies show that the results are preferred over other methods.<br />
<br />
= Implementation =<br />
The following repository provides the source code for the paper, as written by the authors: https://github.com/jiajunwu/marrnet<br />
<br />
= References =<br />
# David Marr. Vision: A computational investigation into the human representation and processing of visual information. W. H. Freeman and Company, 1982.<br />
# Shubham Tulsiani, Tinghui Zhou, Alexei A Efros, and Jitendra Malik. Multi-view supervision for single-view reconstruction via differentiable ray consistency. In CVPR, 2017.<br />
# Jiajun Wu, Chengkai Zhang, Tianfan Xue, William T Freeman, and Joshua B Tenenbaum. Learning a Probabilistic Latent Space of Object Shapes via 3D Generative-Adversarial Modeling. In NIPS, 2016.<br />
# Wu, J. (n.d.). Jiajunwu/marrnet. Retrieved March 25, 2018, from https://github.com/jiajunwu/marrnet</div>X249wanghttp://wiki.math.uwaterloo.ca/statwiki/index.php?title=MarrNet:_3D_Shape_Reconstruction_via_2.5D_Sketches&diff=35647MarrNet: 3D Shape Reconstruction via 2.5D Sketches2018-03-27T04:59:51Z<p>X249wang: Undo revision 35646 by X249wang (talk)</p>
<hr />
<div>= Introduction =<br />
Humans are able to quickly recognize 3D shapes from images, even in spite of drastic differences in object texture, material, lighting, and background.<br />
<br />
[[File:marrnet_intro_image.png|700px|thumb|center|Objects in real images. The appearance of the same shaped object varies based on colour, texture, lighting, background, etc. However, the 2.5D sketches (e.g. depth or normal maps) of the object remain constant, and can be seen as an abstraction of the object which is used to reconstruct the 3D shape.]]<br />
<br />
In this work, the authors propose a novel end-to-end trainable model that sequentially estimates 2.5D sketches and 3D object shape from images and also enforces reprojection consistency between the 3D shape and the estimated sketches. The two-step approach makes the network more robust to differences in object texture, material, lighting, and background. Based on the idea from [Marr, 1982] that human 3D perception relies on recovering 2.5D sketches, which include depth maps (the distance of each visible surface point from the viewpoint) and surface normal maps (the orientation of the surface at each visible point), the authors design an end-to-end trainable pipeline which they call MarrNet. MarrNet first estimates depth, normal maps, and silhouette, followed by a 3D shape. MarrNet uses an encoder-decoder structure for the sub-components of the framework. <br />
<br />
The authors claim several unique advantages to their method. Single image 3D reconstruction is a highly under-constrained problem, requiring strong prior knowledge of object shapes. As well, accurate 3D object annotations for real images are not common, and many previous approaches rely on purely synthetic data. However, most of these methods suffer from domain adaptation problems due to imperfect rendering.<br />
<br />
Using 2.5D sketches can alleviate the challenges of domain transfer. It is straightforward to generate perfect object surface normals and depths using a graphics engine. Since 2.5D sketches contain only depth, surface normal, and silhouette information, the second step of recovering 3D shape can be trained purely from synthetic data. As well, the introduction of differentiable constraints between 2.5D sketches and 3D shape makes it possible to fine-tune the system, even without any annotations.<br />
<br />
The framework is evaluated on both synthetic objects from ShapeNet, and real images from PASCAL 3D+, showing good qualitative and quantitative performance in 3D shape reconstruction.<br />
<br />
= Related Work =<br />
<br />
== 2.5D Sketch Recovery ==<br />
Researchers have explored recovering 2.5D information from shading, texture, and colour images in the past. More recently, the development of depth sensors has led to the creation of large RGB-D datasets, and papers on estimating depth, surface normals, and other intrinsic images using deep networks. While this method employs 2.5D estimation, the final output is a full 3D shape of an object.<br />
<br />
[[File:2-5d_example.PNG|700px|thumb|center|Results from the paper: Learning Non-Lambertian Object Intrinsics across ShapeNet Categories. The results show that neural networks can be trained to recover 2.5D information from an image. The top row predicts the albedo and the bottom row predicts the shading. It can be observed that the results are still blurry and the fine details are not fully recovered.]]<br />
<br />
== Single Image 3D Reconstruction ==<br />
The development of large-scale shape repositories like ShapeNet has allowed for the development of models encoding shape priors for single image 3D reconstruction. These methods normally regress voxelized 3D shapes, relying on synthetic data or 2D masks for training. The formulation in the paper tackles domain adaptation better, since the network can be fine-tuned on images without any annotations.<br />
<br />
== 2D-3D Consistency ==<br />
Intuitively, the 3D shape can be constrained to be consistent with 2D observations. This idea has been explored for decades, with the use of depth and silhouettes, as well as some papers enforcing differentiable 2D-3D constraints for joint training of deep networks. In this work, this idea is exploited to develop differentiable constraints for consistency between the 2.5D sketches and 3D shape.<br />
<br />
= Approach =<br />
The 3D structure is recovered from a single RGB view using three steps, shown in the figure below. The first step estimates 2.5D sketches, including depth, surface normal, and silhouette of the object. The second step estimates a 3D voxel representation of the object. The third step uses a reprojection consistency function to enforce the 2.5D sketch and 3D structure alignment.<br />
<br />
[[File:marrnet_model_components.png|700px|thumb|center|MarrNet architecture. 2.5D sketches of normals, depths, and silhouette are first estimated. The sketches are then used to estimate the 3D shape. Finally, re-projection consistency is used to ensure consistency between the sketch and 3D output.]]<br />
<br />
== 2.5D Sketch Estimation ==<br />
The first step takes a 2D RGB image and predicts the surface normal, depth, and silhouette of the object. The goal is to estimate intrinsic object properties from the image, while discarding non-essential information. An encoder-decoder architecture is used. The encoder is a ResNet-18 network, which takes a 256 x 256 RGB image and produces 8 x 8 x 512 feature maps. The decoder is four sets of 5 x 5 convolutional and ReLU layers, followed by four sets of 1 x 1 convolutional and ReLU layers. The output is 256 x 256 resolution depth, surface normal, and silhouette images.<br />
<br />
== 3D Shape Estimation ==<br />
The second step estimates a voxelized 3D shape using the 2.5D sketches from the first step. The goal is for the network to learn a shape prior that explains the input well. Because it takes only surface normal and depth images as input, it can be trained on synthetic data without suffering from the domain adaptation problem. The network architecture is inspired by the TL network and 3D-VAE-GAN, with an encoder-decoder structure. The normal and depth images, masked by the estimated silhouette, are passed into 5 sets of convolutional, ReLU, and pooling layers, followed by two fully connected layers, with a final output width of 200. The 200-dimensional vector is passed into a decoder of 5 convolutional and ReLU layers, outputting a 128 x 128 x 128 voxelized estimate of the input.<br />
<br />
== Re-projection Consistency ==<br />
The third step consists of a depth re-projection loss and a surface normal re-projection loss. Here, <math>v_{x, y, z}</math> represents the value at position <math>(x, y, z)</math> in a 3D voxel grid, with <math>v_{x, y, z} \in [0, 1] \; \forall x, y, z</math>. <math>d_{x, y}</math> denotes the estimated depth at position <math>(x, y)</math>, and <math>n_{x, y} = (n_a, n_b, n_c)</math> denotes the estimated surface normal at that position. Orthographic projection is used.<br />
<br />
[[File:marrnet_reprojection_consistency.png|700px|thumb|center|Reprojection consistency for voxels. Left and middle: criteria for depth and silhouettes. Right: criterion for surface normals]]<br />
<br />
=== Depths ===<br />
The voxel at the estimated depth, <math>v_{x, y, d_{x, y}}</math>, should be 1, while all voxels in front of it should be 0. This ensures the estimated 3D shape matches the estimated depth values. The projected depth loss and its gradient are defined as follows:<br />
<br />
<math><br />
L_{depth}(x, y, z)=<br />
\left\{<br />
\begin{array}{ll}<br />
v^2_{x, y, z}, & z < d_{x, y} \\<br />
(1 - v_{x, y, z})^2, & z = d_{x, y} \\<br />
0, & z > d_{x, y} \\<br />
\end{array}<br />
\right.<br />
</math><br />
<br />
<math><br />
\frac{\partial L_{depth}(x, y, z)}{\partial v_{x, y, z}} =<br />
\left\{<br />
\begin{array}{ll}<br />
2v_{x, y, z}, & z < d_{x, y} \\<br />
2(v_{x, y, z} - 1), & z = d_{x, y} \\<br />
0, & z > d_{x, y} \\<br />
\end{array}<br />
\right.<br />
</math><br />
<br />
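When <math>d_{x, y} = \infty</math>, the line through <math>(x, y)</math> does not intersect the object, so the depth criterion reduces to a silhouette criterion: all voxels along that line should be 0.<br />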
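
The piecewise definition above can be sketched for a single line of voxels behind one pixel. This is a minimal illustration rather than the authors' implementation: the function name, the plain-list representation, and summing the per-voxel terms along the line are assumptions made here.

```python
import math


def depth_loss_and_grad(ray, d):
    """Depth re-projection loss along one line of voxels.

    ray: voxel values v[x][y][z] for a fixed (x, y), indexed by z.
    d:   integer depth index of the visible surface, or math.inf
         when the line does not hit the object.
    Returns (summed loss, list of per-voxel gradients).
    """
    loss = 0.0
    grad = [0.0] * len(ray)
    for z, v in enumerate(ray):
        if z < d:
            # Voxel in front of the surface: should be empty.
            loss += v * v
            grad[z] = 2.0 * v
        elif z == d:
            # Voxel at the surface: should be occupied.
            loss += (1.0 - v) ** 2
            grad[z] = 2.0 * (v - 1.0)
        # Voxels behind the surface (z > d) are unconstrained.
    return loss, grad
```

Passing `math.inf` as `d` makes every voxel on the line fall into the first branch, which matches the case where all voxels should be 0.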
<br />
=== Surface Normals ===<br />
Since vectors <math>n_{x} = (0, -n_{c}, n_{b})</math> and <math>n_{y} = (-n_{c}, 0, n_{a})</math> are orthogonal to the normal vector <math>n_{x, y} = (n_{a}, n_{b}, n_{c})</math>, they can be normalized to obtain <math>n'_{x} = (0, -1, n_{b}/n_{c})</math> and <math>n'_{y} = (-1, 0, n_{a}/n_{c})</math> on the estimated surface plane at <math>(x, y, z)</math>. The projected surface normal loss tries to guarantee that the voxels at <math>(x, y, z) \pm n'_{x}</math> and <math>(x, y, z) \pm n'_{y}</math> are 1, so that the shape matches the estimated normal. The constraints are only applied when the target voxels are inside the estimated silhouette.<br />
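
The orthogonality claim can be checked directly with dot products. The rescaling step below assumes <math>n_c \neq 0</math>, which the summary does not state explicitly:

<math>
n_{x} \cdot n_{x, y} = 0 \cdot n_a - n_c n_b + n_b n_c = 0, \qquad
n_{y} \cdot n_{x, y} = -n_c n_a + 0 \cdot n_b + n_a n_c = 0
</math>

Dividing each vector by <math>n_c</math> gives <math>(0, -1, n_b/n_c)</math> and <math>(-1, 0, n_a/n_c)</math>, so each one moves exactly one voxel along y or x and the matching amount along z.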
<br />
The projected surface normal loss is defined as follows, with <math>z = d_{x, y}</math>:<br />
<br />
<math><br />
L_{normal}(x, y, z) =<br />
(1 - v_{x, y-1, z+\frac{n_b}{n_c}})^2 + (1 - v_{x, y+1, z-\frac{n_b}{n_c}})^2 + <br />
(1 - v_{x-1, y, z+\frac{n_a}{n_c}})^2 + (1 - v_{x+1, y, z-\frac{n_a}{n_c}})^2<br />
</math><br />
<br />
Gradients along x are:<br />
<br />
<math><br />
\frac{dL_{normal}(x, y, z)}{dv_{x-1, y, z+\frac{n_a}{n_c}}} = 2(v_{x-1, y, z+\frac{n_a}{n_c}}-1)<br />
</math><br />
and<br />
<math><br />
\frac{dL_{normal}(x, y, z)}{dv_{x+1, y, z-\frac{n_a}{n_c}}} = 2(v_{x+1, y, z-\frac{n_a}{n_c}}-1)<br />
</math><br />
<br />
Gradients along y are similar to x.<br />
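
As a rough sketch of the surface normal loss at one pixel, the four target voxels can be looked up and penalized for not being 1. The summary does not say how the fractional z offsets are mapped to voxel indices, so rounding to the nearest index is an assumption here, as are the function name and the dictionary representation (which is taken to hold only voxels inside the estimated silhouette).

```python
def normal_loss(voxels, x, y, z, normal):
    """Surface normal re-projection loss at one pixel, with z = d[x][y].

    voxels: dict mapping integer (x, y, z) tuples to occupancy in [0, 1];
            assumed to contain only voxels inside the silhouette.
    normal: estimated normal (n_a, n_b, n_c) with n_c != 0.
    """
    n_a, n_b, n_c = normal
    dz_x = n_b / n_c  # z-component of n'_x = (0, -1, n_b / n_c)
    dz_y = n_a / n_c  # z-component of n'_y = (-1, 0, n_a / n_c)
    targets = [
        (x, y - 1, z + dz_x),
        (x, y + 1, z - dz_x),
        (x - 1, y, z + dz_y),
        (x + 1, y, z - dz_y),
    ]
    loss = 0.0
    for tx, ty, tz in targets:
        key = (tx, ty, round(tz))  # assumption: snap to the nearest voxel
        if key in voxels:  # skip targets outside the silhouette
            loss += (1.0 - voxels[key]) ** 2
    return loss
```

With a normal pointing straight at the camera, `(0, 0, 1)`, the offsets vanish and the loss simply asks the four in-plane neighbours at the same depth to be occupied.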
<br />
= Training =<br />
The 2.5D and 3D estimation components are first pre-trained separately on synthetic data from ShapeNet, and then fine-tuned on real images.<br />
<br />
For pre-training, the 2.5D sketch estimator is trained on synthetic ShapeNet depth, surface normal, and silhouette ground truth, using an L2 loss. The 3D estimator is trained with ground truth voxels using a cross-entropy loss.<br />
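
The two pre-training losses can be sketched on flat lists of values. This is only an illustration: averaging over elements, the clamping constant `eps`, and treating the voxel cross-entropy as a binary (occupied or empty) cross-entropy are assumptions, since the summary names the losses without giving their exact form.

```python
import math


def l2_loss(pred, target):
    """Mean squared error between predicted and ground-truth sketch values."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def binary_cross_entropy(pred, target, eps=1e-7):
    """Mean binary cross-entropy between predicted occupancies and 0/1 voxels."""
    total = 0.0
    for p, t in zip(pred, target):
        p = min(max(p, eps), 1.0 - eps)  # clamp to avoid log(0)
        total -= t * math.log(p) + (1 - t) * math.log(1.0 - p)
    return total / len(pred)
```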
<br />
Reprojection consistency loss is used to fine-tune the 3D estimation on real images, using the predicted depth, normals, and silhouette. A straightforward implementation leads to shapes that explain the 2.5D sketches well but have an unrealistic 3D appearance due to overfitting.<br />
<br />
Instead, the decoder of the 3D estimator is fixed, and only the encoder is fine-tuned. The model is fine-tuned separately on each image for 40 iterations, which takes up to 10 seconds on the GPU. Without fine-tuning, testing takes around 100 milliseconds. SGD is used for optimization with a batch size of 4, a learning rate of 0.001, and a momentum of 0.9.<br />
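
The idea of freezing the decoder while updating only the encoder can be illustrated with a toy model that has one scalar "encoder" weight and one scalar "decoder" weight. Everything here is hypothetical: the quadratic loss stands in for the re-projection loss, and the momentum update is one common formulation, not necessarily the one used by the authors. Only the step count, learning rate, and momentum come from the text.

```python
def loss_and_grads(params, target=4.0):
    """Toy stand-in for the re-projection loss: (decoder * encoder - target)^2."""
    residual = params["decoder"] * params["encoder"] - target
    grads = {
        "encoder": 2.0 * residual * params["decoder"],
        "decoder": 2.0 * residual * params["encoder"],
    }
    return residual ** 2, grads


def fine_tune(params, trainable, steps=40, lr=0.001, momentum=0.9):
    """Run SGD with momentum, updating only the parameters named in `trainable`."""
    params = dict(params)  # leave the caller's parameters untouched
    velocity = {name: 0.0 for name in trainable}
    for _ in range(steps):
        _, grads = loss_and_grads(params)
        for name in trainable:  # frozen parameters are never visited
            velocity[name] = momentum * velocity[name] + grads[name]
            params[name] -= lr * velocity[name]
    return params
```

Calling `fine_tune({"encoder": 1.0, "decoder": 2.0}, trainable=("encoder",))` lowers the toy loss while leaving the decoder weight exactly as it was, which mirrors how the fixed decoder preserves the learned shape prior.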
<br />
= Evaluation =<br />
Qualitative and quantitative results are provided using different variants of the framework. The framework is evaluated on both synthetic and real images on three datasets.<br />
<br />
== ShapeNet ==<br />
Synthesized images of 6,778 chairs from ShapeNet are rendered from 20 random viewpoints. The chairs are placed in front of random backgrounds from the SUN dataset, and the RGB, depth, normal, and silhouette images are rendered using the physics-based renderer Mitsuba for more realistic images.<br />
<br />
=== Method ===<br />
MarrNet is trained without the final fine-tuning stage, since 3D shapes are available. A baseline is created that directly predicts the 3D shape using the same 3D shape estimator architecture with no 2.5D sketch estimation.<br />
<br />
=== Results ===<br />
The baseline output is compared to the full framework, and the figure below shows that MarrNet provides model outputs with more details and smoother surfaces than the baseline. Quantitatively, the full model also achieves 0.57 IoU, higher than the direct prediction baseline.<br />
<br />
[[File:marrnet_shapenet_results.png|700px|thumb|center|ShapeNet results.]]<br />
<br />
== PASCAL 3D+ ==<br />
This dataset provides rough 3D models for objects in real-life images.<br />
<br />
=== Method ===<br />
Each module is pre-trained on the ShapeNet dataset, and then fine-tuned on the PASCAL 3D+ dataset. Three variants of the model are tested. The first is trained using ShapeNet data only with no fine-tuning. The second is fine-tuned without fixing the decoder. The third is fine-tuned with a fixed decoder.<br />
<br />
=== Results ===<br />
The figure below shows the results of the ablation study. The model trained only on synthetic data provides reasonable estimates. However, fine-tuning without fixing the decoder leads to impossible shapes from certain views. The third model keeps the shape prior, providing more details in the final shape.<br />
<br />
[[File:marrnet_pascal_3d_ablation.png|600px|thumb|center|Ablation studies using the PASCAL 3D+ dataset.]]<br />
<br />
Additional comparisons are made with the state-of-the-art (DRC) on the provided ground truth shapes. MarrNet achieves 0.39 IoU, while DRC achieves 0.34. However, the authors claim that the IoU metric is sub-optimal for three reasons. First, there is no emphasis on details since the metric prefers models that predict mean shapes consistently. Second, all possible scales are searched during the IoU computation, making it less efficient. Third, PASCAL 3D+ only has rough annotations, with only 10 CAD chair models for all images, and computing IoU with these shapes is not very informative. Instead, human studies are conducted: MarrNet reconstructions are preferred 74% of the time over DRC, and 42% of the time over the ground truth. This suggests that MarrNet produces plausible shapes, and also highlights that the ground truth shapes are only rough approximations.<br />
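
For reference, voxel IoU is the number of voxels occupied in both shapes divided by the number occupied in either. The sketch below assumes a 0.5 occupancy threshold and returns 1.0 when both shapes are empty; both are conventions chosen here, and the search over scales that the authors describe is not reproduced.

```python
def voxel_iou(pred, truth, threshold=0.5):
    """Intersection over union of two voxel grids given as flat lists.

    pred:  predicted occupancies in [0, 1].
    truth: ground-truth occupancies (0 or 1).
    """
    intersection = 0
    union = 0
    for p, t in zip(pred, truth):
        p_occupied = p >= threshold
        t_occupied = t >= threshold
        intersection += p_occupied and t_occupied
        union += p_occupied or t_occupied
    # Assumed convention: two empty grids count as a perfect match.
    return intersection / union if union else 1.0
```

Because the score only counts occupied voxels, a prediction that captures the bulk of a chair but misses its thin legs loses few voxels, which is consistent with the authors' complaint that the metric places little emphasis on details.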
<br />
[[File:human_studies.png|400px|thumb|center|Human preferences on chairs in PASCAL 3D+ (Xiang et al. 2014). The numbers show how often (in percent) humans preferred the 3D shape from DRC (state-of-the-art), MarrNet, or the ground truth (GT).]]<br />
<br />
<br />
[[File:marrnet_pascal_3d_drc_comparison.png|600px|thumb|center|Comparison between DRC and MarrNet results.]]<br />
<br />
Several failure cases are shown in the figure below. Specifically, the framework does not seem to work well on thin structures.<br />
<br />
[[File:marrnet_pascal_3d_failure_cases.png|500px|thumb|center|Failure cases on PASCAL 3D+. The algorithm cannot recover thin structures.]]<br />
<br />
== IKEA ==<br />
This dataset contains images of IKEA furniture, with accurate 3D shape and pose annotations. Objects are often heavily occluded or truncated.<br />
<br />
=== Results ===<br />
Qualitative results are shown in the figure below. The model is shown to deal with mild occlusions in real-life scenarios. Human studies show that MarrNet reconstructions are preferred 61% of the time over 3D-VAE-GAN.<br />
<br />
[[File:marrnet_ikea_results.png|700px|thumb|center|Results on chairs in the IKEA dataset, and comparison with 3D-VAE-GAN.]]<br />
<br />
== Other Data ==<br />
MarrNet is also applied to cars and airplanes. As shown below, smaller details such as the horizontal stabilizer and rear-view mirrors are recovered.<br />
<br />
[[File:marrnet_airplanes_and_cars.png|700px|thumb|center|Results on airplanes and cars from the PASCAL 3D+ dataset, and comparison with DRC.]]<br />
<br />
MarrNet is also jointly trained on three object categories, and successfully recovers the shapes of different categories. Results are shown in the figure below.<br />
<br />
[[File:marrnet_multiple_categories.png|700px|thumb|center|Results when trained jointly on all three object categories (cars, airplanes, and chairs).]]<br />
<br />
= Commentary =<br />
Qualitatively, the results look quite impressive. The 2.5D sketch estimation seems to distill the useful information for more realistic looking 3D shape estimation. The disentanglement of 2.5D and 3D estimation steps also allows for easier training and domain adaptation from synthetic data.<br />
<br />
As the authors mention, the IoU metric is not very descriptive, and most of the comparisons in this paper are only qualitative, mainly being human preference studies. A better quantitative evaluation metric would greatly help in making an unbiased comparison between different results.<br />
<br />
As seen in several of the results, the network does not deal well with objects that have thin structures, which is particularly noticeable with many of the chair armrests. As well, a closer look at some results suggests that fine-tuning only the 3D encoder does not transfer well to unseen objects, since shape priors have already been learned by the decoder.<br />
<br />
= Conclusion =<br />
The proposed MarrNet employs a novel model to estimate 2.5D sketches for 3D shape reconstruction. The sketches are shown to improve the model’s performance, and make it easy to adapt to images across different domains and categories. Differentiable loss functions are created such that the model can be fine-tuned end-to-end on images without ground truth. The experiments show that the model performs well, and human studies show that the results are preferred over other methods.<br />
<br />
= Implementation =<br />
The authors' source code for the paper is available in the following repository: https://github.com/jiajunwu/marrnet<br />
<br />
= References =<br />
# David Marr. Vision: A computational investigation into the human representation and processing of visual information. W. H. Freeman and Company, 1982.<br />
# Shubham Tulsiani, Tinghui Zhou, Alexei A Efros, and Jitendra Malik. Multi-view supervision for single-view reconstruction via differentiable ray consistency. In CVPR, 2017.<br />
# Jiajun Wu, Chengkai Zhang, Tianfan Xue, William T Freeman, and Joshua B Tenenbaum. Learning a Probabilistic Latent Space of Object Shapes via 3D Generative-Adversarial Modeling. In NIPS, 2016b.<br />
# Wu, J. (n.d.). Jiajunwu/marrnet. Retrieved March 25, 2018, from https://github.com/jiajunwu/marrnet</div>X249wanghttp://wiki.math.uwaterloo.ca/statwiki/index.php?title=MarrNet:_3D_Shape_Reconstruction_via_2.5D_Sketches&diff=35646MarrNet: 3D Shape Reconstruction via 2.5D Sketches2018-03-27T04:58:10Z<p>X249wang: /* Results */</p>
<hr />
<div>= Introduction =<br />
Humans are able to quickly recognize 3D shapes from images, in spite of drastic differences in object texture, material, lighting, and background.<br />
<br />
[[File:marrnet_intro_image.png|700px|thumb|center|Objects in real images. The appearance of the same shaped object varies based on colour, texture, lighting, background, etc. However, the 2.5D sketches (e.g. depth or normal maps) of the object remain constant, and can be seen as an abstraction of the object which is used to reconstruct the 3D shape.]]<br />
<br />
In this work, the authors propose a novel end-to-end trainable model that sequentially estimates 2.5D sketches and 3D object shape from images, and also enforces re-projection consistency between the 3D shape and the estimated sketches. The two-step approach makes the network more robust to differences in object texture, material, lighting, and background. The design is based on the idea from [Marr, 1982] that human 3D perception relies on recovering 2.5D sketches, which include depth maps (which record the distance of visible surfaces from the viewpoint) and surface normal maps (which record the orientation of the visible surface at each pixel). The authors call the resulting pipeline MarrNet. MarrNet first estimates depth, normal maps, and silhouette, followed by a 3D shape. MarrNet uses an encoder-decoder structure for the sub-components of the framework. <br />
<br />
The authors claim several unique advantages to their method. Single image 3D reconstruction is a highly under-constrained problem, requiring strong prior knowledge of object shapes. As well, accurate 3D object annotations for real images are scarce, and many previous approaches rely on purely synthetic data. However, most of these methods suffer from domain adaptation problems due to imperfect rendering.<br />
<br />
Using 2.5D sketches can alleviate the challenges of domain transfer. It is straightforward to generate perfect object surface normals and depths using a graphics engine. Since 2.5D sketches contain only depth, surface normal, and silhouette information, the second step of recovering 3D shape can be trained purely from synthetic data. As well, the introduction of differentiable constraints between 2.5D sketches and 3D shape makes it possible to fine-tune the system, even without any annotations.<br />
<br />
The framework is evaluated on both synthetic objects from ShapeNet, and real images from PASCAL 3D+, showing good qualitative and quantitative performance in 3D shape reconstruction.<br />
<br />
= Related Work =<br />
<br />
== 2.5D Sketch Recovery ==<br />
Researchers have explored recovering 2.5D information from shading, texture, and colour images in the past. More recently, the development of depth sensors has led to the creation of large RGB-D datasets, and papers on estimating depth, surface normals, and other intrinsic images using deep networks. While this method employs 2.5D estimation, the final output is a full 3D shape of an object.<br />
<br />
[[File:2-5d_example.PNG|700px|thumb|center|Results from the paper: Learning Non-Lambertian Object Intrinsics across ShapeNet Categories. The results show that neural networks can be trained to recover 2.5D information from an image. The top row predicts the albedo and the bottom row predicts the shading. It can be observed that the results are still blurry and the fine details are not fully recovered.]]<br />
<br />
== Single Image 3D Reconstruction ==<br />
The development of large-scale shape repositories like ShapeNet has allowed for the development of models encoding shape priors for single image 3D reconstruction. These methods normally regress voxelized 3D shapes, relying on synthetic data or 2D masks for training. The formulation in the paper tackles domain adaptation better, since the network can be fine-tuned on images without any annotations.<br />
<br />
== 2D-3D Consistency ==<br />
Intuitively, the 3D shape can be constrained to be consistent with 2D observations. This idea has been explored for decades, with the use of depth and silhouettes, as well as some papers enforcing differentiable 2D-3D constraints for joint training of deep networks. In this work, this idea is exploited to develop differentiable constraints for consistency between the 2.5D sketches and 3D shape.<br />
<br />
= Approach =<br />
The 3D structure is recovered from a single RGB view using three steps, shown in the figure below. The first step estimates 2.5D sketches, including depth, surface normal, and silhouette of the object. The second step estimates a 3D voxel representation of the object. The third step uses a reprojection consistency function to enforce the 2.5D sketch and 3D structure alignment.<br />
<br />
[[File:marrnet_model_components.png|700px|thumb|center|MarrNet architecture. 2.5D sketches of normals, depths, and silhouette are first estimated. The sketches are then used to estimate the 3D shape. Finally, re-projection consistency is used to ensure consistency between the sketch and 3D output.]]<br />
<br />
== 2.5D Sketch Estimation ==<br />
The first step takes a 2D RGB image and predicts the surface normal, depth, and silhouette of the object. The goal is to estimate intrinsic object properties from the image, while discarding non-essential information. An encoder-decoder architecture is used. The encoder is a ResNet-18 network, which takes a 256 x 256 RGB image and produces 8 x 8 x 512 feature maps. The decoder is four sets of 5 x 5 convolutional and ReLU layers, followed by four sets of 1 x 1 convolutional and ReLU layers. The output is 256 x 256 resolution depth, surface normal, and silhouette images.<br />
<br />
== 3D Shape Estimation ==<br />
The second step estimates a voxelized 3D shape using the 2.5D sketches from the first step. The goal is for the network to learn a shape prior that explains the input well. Because it takes only surface normal and depth images as input, it can be trained on synthetic data without suffering from the domain adaptation problem. The network architecture is inspired by the TL network and 3D-VAE-GAN, with an encoder-decoder structure. The normal and depth images, masked by the estimated silhouette, are passed into 5 sets of convolutional, ReLU, and pooling layers, followed by two fully connected layers, with a final output width of 200. The 200-dimensional vector is passed into a decoder of 5 convolutional and ReLU layers, outputting a 128 x 128 x 128 voxelized estimate of the input.<br />
<br />
== Re-projection Consistency ==<br />
The third step consists of a depth re-projection loss and a surface normal re-projection loss. Here, <math>v_{x, y, z}</math> represents the value at position <math>(x, y, z)</math> in a 3D voxel grid, with <math>v_{x, y, z} \in [0, 1] \; \forall x, y, z</math>. <math>d_{x, y}</math> denotes the estimated depth at position <math>(x, y)</math>, and <math>n_{x, y} = (n_a, n_b, n_c)</math> denotes the estimated surface normal at that position. Orthographic projection is used.<br />
<br />
[[File:marrnet_reprojection_consistency.png|700px|thumb|center|Reprojection consistency for voxels. Left and middle: criteria for depth and silhouettes. Right: criterion for surface normals]]<br />
<br />
=== Depths ===<br />
The voxel at the estimated depth, <math>v_{x, y, d_{x, y}}</math>, should be 1, while all voxels in front of it should be 0. This ensures the estimated 3D shape matches the estimated depth values. The projected depth loss and its gradient are defined as follows:<br />
<br />
<math><br />
L_{depth}(x, y, z)=<br />
\left\{<br />
\begin{array}{ll}<br />
v^2_{x, y, z}, & z < d_{x, y} \\<br />
(1 - v_{x, y, z})^2, & z = d_{x, y} \\<br />
0, & z > d_{x, y} \\<br />
\end{array}<br />
\right.<br />
</math><br />
<br />
<math><br />
\frac{\partial L_{depth}(x, y, z)}{\partial v_{x, y, z}} =<br />
\left\{<br />
\begin{array}{ll}<br />
2v_{x, y, z}, & z < d_{x, y} \\<br />
2(v_{x, y, z} - 1), & z = d_{x, y} \\<br />
0, & z > d_{x, y} \\<br />
\end{array}<br />
\right.<br />
</math><br />
<br />
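When <math>d_{x, y} = \infty</math>, the line through <math>(x, y)</math> does not intersect the object, so the depth criterion reduces to a silhouette criterion: all voxels along that line should be 0.<br />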
<br />
=== Surface Normals ===<br />
Since vectors <math>n_{x} = (0, -n_{c}, n_{b})</math> and <math>n_{y} = (-n_{c}, 0, n_{a})</math> are orthogonal to the normal vector <math>n_{x, y} = (n_{a}, n_{b}, n_{c})</math>, they can be normalized to obtain <math>n'_{x} = (0, -1, n_{b}/n_{c})</math> and <math>n'_{y} = (-1, 0, n_{a}/n_{c})</math> on the estimated surface plane at <math>(x, y, z)</math>. The projected surface normal loss tries to guarantee that the voxels at <math>(x, y, z) \pm n'_{x}</math> and <math>(x, y, z) \pm n'_{y}</math> are 1, so that the shape matches the estimated normal. The constraints are only applied when the target voxels are inside the estimated silhouette.<br />
<br />
The projected surface normal loss is defined as follows, with <math>z = d_{x, y}</math>:<br />
<br />
<math><br />
L_{normal}(x, y, z) =<br />
(1 - v_{x, y-1, z+\frac{n_b}{n_c}})^2 + (1 - v_{x, y+1, z-\frac{n_b}{n_c}})^2 + <br />
(1 - v_{x-1, y, z+\frac{n_a}{n_c}})^2 + (1 - v_{x+1, y, z-\frac{n_a}{n_c}})^2<br />
</math><br />
<br />
Gradients along x are:<br />
<br />
<math><br />
\frac{dL_{normal}(x, y, z)}{dv_{x-1, y, z+\frac{n_a}{n_c}}} = 2(v_{x-1, y, z+\frac{n_a}{n_c}}-1)<br />
</math><br />
and<br />
<math><br />
\frac{dL_{normal}(x, y, z)}{dv_{x+1, y, z-\frac{n_a}{n_c}}} = 2(v_{x+1, y, z-\frac{n_a}{n_c}}-1)<br />
</math><br />
<br />
Gradients along y are similar to x.<br />
<br />
= Training =<br />
The 2.5D and 3D estimation components are first pre-trained separately on synthetic data from ShapeNet, and then fine-tuned on real images.<br />
<br />
For pre-training, the 2.5D sketch estimator is trained on synthetic ShapeNet depth, surface normal, and silhouette ground truth, using an L2 loss. The 3D estimator is trained with ground truth voxels using a cross-entropy loss.<br />
<br />
Reprojection consistency loss is used to fine-tune the 3D estimation on real images, using the predicted depth, normals, and silhouette. A straightforward implementation leads to shapes that explain the 2.5D sketches well but have an unrealistic 3D appearance due to overfitting.<br />
<br />
Instead, the decoder of the 3D estimator is fixed, and only the encoder is fine-tuned. The model is fine-tuned separately on each image for 40 iterations, which takes up to 10 seconds on the GPU. Without fine-tuning, testing takes around 100 milliseconds. SGD is used for optimization with a batch size of 4, a learning rate of 0.001, and a momentum of 0.9.<br />
<br />
= Evaluation =<br />
Qualitative and quantitative results are provided using different variants of the framework. The framework is evaluated on both synthetic and real images on three datasets.<br />
<br />
== ShapeNet ==<br />
Synthesized images of 6,778 chairs from ShapeNet are rendered from 20 random viewpoints. The chairs are placed in front of random backgrounds from the SUN dataset, and the RGB, depth, normal, and silhouette images are rendered using the physics-based renderer Mitsuba for more realistic images.<br />
<br />
=== Method ===<br />
MarrNet is trained without the final fine-tuning stage, since 3D shapes are available. A baseline is created that directly predicts the 3D shape using the same 3D shape estimator architecture with no 2.5D sketch estimation.<br />
<br />
=== Results ===<br />
The baseline output is compared to the full framework, and the figure below shows that MarrNet provides model outputs with more details and smoother surfaces than the baseline. The estimated normal and depth images are able to extract intrinsic information about object shape while leaving behind non-essential information such as textures from the original images. Quantitatively, the full model also achieves 0.57 IoU, higher than the direct prediction baseline.<br />
<br />
<br />
[[File:marrnet_shapenet_results.png|700px|thumb|center|Results on rendered images of ShapeNet objects. From left to right: input, estimated normal map, estimated depth map, model prediction, a baseline algorithm that predicts 3D shape directly from RGB input without modeling 2.5D sketch, and ground truth. Both normal and depth maps are masked by predicted silhouettes. The method proposed in the paper is able to recover shapes with smoother surfaces and finer details.]]<br />
<br />
== PASCAL 3D+ ==<br />
This dataset provides rough 3D models for objects in real-life images.<br />
<br />
=== Method ===<br />
Each module is pre-trained on the ShapeNet dataset, and then fine-tuned on the PASCAL 3D+ dataset. Three variants of the model are tested. The first is trained using ShapeNet data only with no fine-tuning. The second is fine-tuned without fixing the decoder. The third is fine-tuned with a fixed decoder.<br />
<br />
=== Results ===<br />
The figure below shows the results of the ablation study. The model trained only on synthetic data provides reasonable estimates. However, fine-tuning without fixing the decoder leads to impossible shapes from certain views. The third model keeps the shape prior, providing more details in the final shape.<br />
<br />
[[File:marrnet_pascal_3d_ablation.png|600px|thumb|center|Ablation studies using the PASCAL 3D+ dataset.]]<br />
<br />
Additional comparisons are made with the state-of-the-art (DRC) on the provided ground truth shapes. MarrNet achieves 0.39 IoU, while DRC achieves 0.34. However, the authors claim that the IoU metric is sub-optimal for three reasons. First, there is no emphasis on details since the metric prefers models that predict mean shapes consistently. Second, all possible scales are searched during the IoU computation, making it less efficient. Third, PASCAL 3D+ only has rough annotations, with only 10 CAD chair models for all images, and computing IoU with these shapes is not very informative. Instead, human studies are conducted: MarrNet reconstructions are preferred 74% of the time over DRC, and 42% of the time over the ground truth. This suggests that MarrNet produces plausible shapes, and also highlights that the ground truth shapes are only rough approximations.<br />
<br />
[[File:human_studies.png|400px|thumb|center|Human preferences on chairs in PASCAL 3D+ (Xiang et al. 2014). The numbers show how often (in percent) humans preferred the 3D shape from DRC (state-of-the-art), MarrNet, or the ground truth (GT).]]<br />
<br />
<br />
[[File:marrnet_pascal_3d_drc_comparison.png|600px|thumb|center|Comparison between DRC and MarrNet results.]]<br />
<br />
Several failure cases are shown in the figure below. Specifically, the framework does not seem to work well on thin structures.<br />
<br />
[[File:marrnet_pascal_3d_failure_cases.png|500px|thumb|center|Failure cases on PASCAL 3D+. The algorithm cannot recover thin structures.]]<br />
<br />
== IKEA ==<br />
This dataset contains images of IKEA furniture, with accurate 3D shape and pose annotations. Objects are often heavily occluded or truncated.<br />
<br />
=== Results ===<br />
Qualitative results are shown in the figure below. The model is shown to deal with mild occlusions in real-life scenarios. Human studies show that MarrNet reconstructions are preferred 61% of the time over 3D-VAE-GAN.<br />
<br />
[[File:marrnet_ikea_results.png|700px|thumb|center|Results on chairs in the IKEA dataset, and comparison with 3D-VAE-GAN.]]<br />
<br />
== Other Data ==<br />
MarrNet is also applied to cars and airplanes. As shown below, smaller details such as the horizontal stabilizer and rear-view mirrors are recovered.<br />
<br />
[[File:marrnet_airplanes_and_cars.png|700px|thumb|center|Results on airplanes and cars from the PASCAL 3D+ dataset, and comparison with DRC.]]<br />
<br />
MarrNet is also jointly trained on three object categories, and successfully recovers the shapes of different categories. Results are shown in the figure below.<br />
<br />
[[File:marrnet_multiple_categories.png|700px|thumb|center|Results when trained jointly on all three object categories (cars, airplanes, and chairs).]]<br />
<br />
= Commentary =<br />
Qualitatively, the results look quite impressive. The 2.5D sketch estimation seems to distill the useful information for more realistic looking 3D shape estimation. The disentanglement of 2.5D and 3D estimation steps also allows for easier training and domain adaptation from synthetic data.<br />
<br />
As the authors mention, the IoU metric is not very descriptive, and most of the comparisons in this paper are only qualitative, mainly being human preference studies. A better quantitative evaluation metric would greatly help in making an unbiased comparison between different results.<br />
<br />
As seen in several of the results, the network does not deal well with objects that have thin structures, which is particularly noticeable with many of the chair armrests. As well, a closer look at some results suggests that fine-tuning only the 3D encoder does not transfer well to unseen objects, since shape priors have already been learned by the decoder.<br />
<br />
= Conclusion =<br />
The proposed MarrNet employs a novel model to estimate 2.5D sketches for 3D shape reconstruction. The sketches are shown to improve the model’s performance, and make it easy to adapt to images across different domains and categories. Differentiable loss functions are created such that the model can be fine-tuned end-to-end on images without ground truth. The experiments show that the model performs well, and human studies show that the results are preferred over other methods.<br />
<br />
= Implementation =<br />
The authors' source code for the paper is available in the following repository: https://github.com/jiajunwu/marrnet<br />
<br />
= References =<br />
# David Marr. Vision: A computational investigation into the human representation and processing of visual information. W. H. Freeman and Company, 1982.<br />
# Shubham Tulsiani, Tinghui Zhou, Alexei A Efros, and Jitendra Malik. Multi-view supervision for single-view reconstruction via differentiable ray consistency. In CVPR, 2017.<br />
# Jiajun Wu, Chengkai Zhang, Tianfan Xue, William T Freeman, and Joshua B Tenenbaum. Learning a Probabilistic Latent Space of Object Shapes via 3D Generative-Adversarial Modeling. In NIPS, 2016b.<br />
# Wu, J. (n.d.). Jiajunwu/marrnet. Retrieved March 25, 2018, from https://github.com/jiajunwu/marrnet</div>X249wanghttp://wiki.math.uwaterloo.ca/statwiki/index.php?title=MarrNet:_3D_Shape_Reconstruction_via_2.5D_Sketches&diff=35645MarrNet: 3D Shape Reconstruction via 2.5D Sketches2018-03-27T04:51:08Z<p>X249wang: /* Introduction */</p>
<hr />
<div>= Introduction =<br />
Humans are able to quickly recognize 3D shapes from images, even in spite of drastic differences in object texture, material, lighting, and background.<br />
<br />
[[File:marrnet_intro_image.png|700px|thumb|center|Objects in real images. The appearance of the same shaped object varies based on colour, texture, lighting, background, etc. However, the 2.5D sketches (e.g. depth or normal maps) of the object remain constant, and can be seen as an abstraction of the object which is used to reconstruct the 3D shape.]]<br />
<br />
In this work, the authors propose a novel end-to-end trainable model that sequentially estimates 2.5D sketches and 3D object shape from images, and also enforces re-projection consistency between the 3D shape and the estimated sketches. The two-step approach makes the network more robust to differences in object texture, material, lighting, and background. Based on the idea from [Marr, 1982] that human 3D perception relies on recovering 2.5D sketches, which include depth maps (the distance of each visible surface point from the viewpoint) and surface normal maps (the orientation of the surface at each pixel), the authors design an end-to-end trainable pipeline which they call MarrNet. MarrNet first estimates depth, normal maps, and silhouette, followed by a 3D shape. MarrNet uses an encoder-decoder structure for the sub-components of the framework. <br />
<br />
The authors claim several unique advantages to their method. Single image 3D reconstruction is a highly under-constrained problem, requiring strong prior knowledge of object shapes. As well, accurate 3D object annotations for real images are not common, and many previous approaches rely on purely synthetic data. However, most of these methods suffer from the domain adaptation problem due to imperfect rendering.<br />
<br />
Using 2.5D sketches can alleviate the challenges of domain transfer. It is straightforward to generate perfect object surface normals and depths using a graphics engine. Since 2.5D sketches contain only depth, surface normal, and silhouette information, the second step of recovering 3D shape can be trained purely from synthetic data. As well, the introduction of differentiable constraints between 2.5D sketches and 3D shape makes it possible to fine-tune the system, even without any annotations.<br />
<br />
The framework is evaluated on both synthetic objects from ShapeNet, and real images from PASCAL 3D+, showing good qualitative and quantitative performance in 3D shape reconstruction.<br />
<br />
= Related Work =<br />
<br />
== 2.5D Sketch Recovery ==<br />
Researchers have explored recovering 2.5D information from shading, texture, and colour images in the past. More recently, the development of depth sensors has led to the creation of large RGB-D datasets, and papers on estimating depth, surface normals, and other intrinsic images using deep networks. While this method employs 2.5D estimation, the final output is a full 3D shape of an object.<br />
<br />
[[File:2-5d_example.PNG|700px|thumb|center|Results from the paper: Learning Non-Lambertian Object Intrinsics across ShapeNet Categories. The results show that neural networks can be trained to recover 2.5D information from an image. The top row predicts the albedo and the bottom row predicts the shading. It can be observed that the results are still blurry and the fine details are not fully recovered.]]<br />
<br />
== Single Image 3D Reconstruction ==<br />
The development of large-scale shape repositories like ShapeNet has allowed for the development of models encoding shape priors for single image 3D reconstruction. These methods normally regress voxelized 3D shapes, relying on synthetic data or 2D masks for training. The formulation in the paper tackles domain adaptation better, since the network can be fine-tuned on images without any annotations.<br />
<br />
== 2D-3D Consistency ==<br />
Intuitively, the 3D shape can be constrained to be consistent with 2D observations. This idea has been explored for decades, with the use of depth and silhouettes, as well as some papers enforcing differentiable 2D-3D constraints for joint training of deep networks. In this work, this idea is exploited to develop differentiable constraints for consistency between the 2.5D sketches and 3D shape.<br />
<br />
= Approach =<br />
The 3D structure is recovered from a single RGB view using three steps, shown in the figure below. The first step estimates 2.5D sketches, including depth, surface normal, and silhouette of the object. The second step estimates a 3D voxel representation of the object. The third step uses a reprojection consistency function to enforce the 2.5D sketch and 3D structure alignment.<br />
<br />
[[File:marrnet_model_components.png|700px|thumb|center|MarrNet architecture. 2.5D sketches of normals, depths, and silhouette are first estimated. The sketches are then used to estimate the 3D shape. Finally, re-projection consistency is used to ensure consistency between the sketch and 3D output.]]<br />
<br />
== 2.5D Sketch Estimation ==<br />
The first step takes a 2D RGB image and predicts the surface normal, depth, and silhouette of the object. The goal is to estimate intrinsic object properties from the image, while discarding non-essential information. An encoder-decoder architecture is used. The encoder is a ResNet-18 network, which takes a 256 x 256 RGB image and produces 8 x 8 x 512 feature maps. The decoder is four sets of 5 x 5 convolutional and ReLU layers, followed by four sets of 1 x 1 convolutional and ReLU layers. The output is 256 x 256 resolution depth, surface normal, and silhouette images.<br />
<br />
== 3D Shape Estimation ==<br />
The second step estimates a voxelized 3D shape using the 2.5D sketches from the first step. The focus here is for the network to learn the shape prior that can explain the input well, and can be trained on synthetic data without suffering from the domain adaptation problem since it only takes in surface normal and depth images as input. The network architecture is inspired by the TL network, and 3D-VAE-GAN, with an encoder-decoder structure. The normal and depth image, masked by the estimated silhouette, are passed into 5 sets of convolutional, ReLU, and pooling layers, followed by two fully connected layers, with a final output width of 200. The 200-dimensional vector is passed into a decoder of 5 convolutional and ReLU layers, outputting a 128 x 128 x 128 voxelized estimate of the input.<br />
<br />
== Re-projection Consistency ==<br />
The third step consists of a depth re-projection loss and surface normal re-projection loss. Here, <math>v_{x, y, z}</math> represents the value at position <math>(x, y, z)</math> in a 3D voxel grid, with <math>v_{x, y, z} \in [0, 1] ∀ x, y, z</math>. <math>d_{x, y}</math> denotes the estimated depth at position <math>(x, y)</math>, <math>n_{x, y} = (n_a, n_b, n_c)</math> denotes the estimated surface normal. Orthographic projection is used.<br />
<br />
[[File:marrnet_reprojection_consistency.png|700px|thumb|center|Reprojection consistency for voxels. Left and middle: criteria for depth and silhouettes. Right: criterion for surface normals]]<br />
<br />
=== Depths ===<br />
The voxel at the estimated depth, <math>v_{x, y, d_{x, y}}</math>, should be 1, while all voxels in front of it should be 0. This ensures the estimated 3D shape matches the estimated depth values. The projected depth loss and its gradient are defined as follows:<br />
<br />
<math><br />
L_{depth}(x, y, z)=<br />
\left\{<br />
\begin{array}{ll}<br />
v^2_{x, y, z}, & z < d_{x, y} \\<br />
(1 - v_{x, y, z})^2, & z = d_{x, y} \\<br />
0, & z > d_{x, y} \\<br />
\end{array}<br />
\right.<br />
</math><br />
<br />
<math><br />
\frac{∂L_{depth}(x, y, z)}{∂v_{x, y, z}} =<br />
\left\{<br />
\begin{array}{ll}<br />
2v_{x, y, z}, & z < d_{x, y} \\<br />
2(v_{x, y, z} - 1), & z = d_{x, y} \\<br />
0, & z > d_{x, y} \\<br />
\end{array}<br />
\right.<br />
</math><br />
<br />
When <math>d_{x, y} = \infty</math> (the ray through <math>(x, y)</math> does not hit the object), every voxel along that ray counts as being in front of it and should be 0.<br />
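<br />
The depth criterion above can be sketched in NumPy as follows. This is only an illustrative sketch, not the authors' implementation: the function name <code>depth_reprojection_loss</code>, the array layout <code>v[x, y, z]</code>, and the convention that depths are given as voxel indices along z are all assumptions made for the example.<br />

```python
import numpy as np

def depth_reprojection_loss(v, d):
    """Sum of L_depth over all voxels, and its gradient with respect to v.

    v: float array of shape (X, Y, Z) with occupancy values in [0, 1].
    d: float array of shape (X, Y) holding the depth as a voxel index along z;
       np.inf marks rays that do not hit the object (assumed convention).
    """
    z = np.arange(v.shape[2])[None, None, :]   # voxel index along each ray
    depth = d[:, :, None]
    in_front = z < depth                       # these voxels should be 0
    on_surface = z == depth                    # this voxel should be 1

    loss = np.where(in_front, v ** 2, 0.0) + np.where(on_surface, (1.0 - v) ** 2, 0.0)
    grad = np.where(in_front, 2.0 * v, 0.0) + np.where(on_surface, 2.0 * (v - 1.0), 0.0)
    return loss.sum(), grad
```

Voxels behind the surface receive neither loss nor gradient, so the shape prior alone decides what the hidden part of the object looks like.<br />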
<br />
=== Surface Normals ===<br />
Since the vectors <math>n_{x} = (0, -n_{c}, n_{b})</math> and <math>n_{y} = (-n_{c}, 0, n_{a})</math> are orthogonal to the normal vector <math>n_{x, y} = (n_{a}, n_{b}, n_{c})</math>, they can be rescaled by <math>1/n_{c}</math> to obtain <math>n'_{x} = (0, -1, n_{b}/n_{c})</math> and <math>n'_{y} = (-1, 0, n_{a}/n_{c})</math>, both of which lie on the estimated surface plane at <math>(x, y, z)</math>. The projected surface normal loss tries to guarantee that the voxels at <math>(x, y, z) \pm n'_{x}</math> and <math>(x, y, z) \pm n'_{y}</math> are 1, to match the estimated normal. The constraints are only applied when the target voxels are inside the estimated silhouette.<br />
<br />
The projected surface normal loss is defined as follows, with <math>z = d_{x, y}</math>:<br />
<br />
<math><br />
L_{normal}(x, y, z) =<br />
(1 - v_{x, y-1, z+\frac{n_b}{n_c}})^2 + (1 - v_{x, y+1, z-\frac{n_b}{n_c}})^2 + <br />
(1 - v_{x-1, y, z+\frac{n_a}{n_c}})^2 + (1 - v_{x+1, y, z-\frac{n_a}{n_c}})^2<br />
</math><br />
<br />
Gradients along x are:<br />
<br />
<math><br />
\frac{dL_{normal}(x, y, z)}{dv_{x-1, y, z+\frac{n_a}{n_c}}} = 2(v_{x-1, y, z+\frac{n_a}{n_c}}-1)<br />
</math><br />
and<br />
<math><br />
\frac{dL_{normal}(x, y, z)}{dv_{x+1, y, z-\frac{n_a}{n_c}}} = 2(v_{x+1, y, z-\frac{n_a}{n_c}}-1)<br />
</math><br />
<br />
Gradients along y are similar to x.<br />
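<br />
The surface normal criterion can be sketched for a single surface voxel as follows. This is a hedged illustration rather than the authors' code: the function name is made up for the example, the fractional offsets <math>n_a/n_c</math> and <math>n_b/n_c</math> are rounded to the nearest voxel (an assumption, since the summary does not say how they are discretized), <math>n_c \neq 0</math> is assumed, and the silhouette check is replaced by a simple bounds check.<br />

```python
import numpy as np

def normal_reprojection_loss(v, x, y, z, n):
    """L_normal at one surface voxel (x, y, z), where z = d[x, y].

    v: float array of shape (X, Y, Z) with occupancy values in [0, 1].
    n: estimated normal (n_a, n_b, n_c) with n_c != 0 (assumed).
    Returns the loss and its gradient with respect to v.
    """
    n_a, n_b, n_c = n
    dz_y = int(round(n_b / n_c))   # z offset paired with a step along y
    dz_x = int(round(n_a / n_c))   # z offset paired with a step along x
    neighbours = [
        (x, y - 1, z + dz_y), (x, y + 1, z - dz_y),
        (x - 1, y, z + dz_x), (x + 1, y, z - dz_x),
    ]
    loss = 0.0
    grad = np.zeros_like(v)
    for i, j, k in neighbours:
        # Skip neighbours outside the grid (stand-in for the silhouette test).
        if 0 <= i < v.shape[0] and 0 <= j < v.shape[1] and 0 <= k < v.shape[2]:
            loss += (1.0 - v[i, j, k]) ** 2
            grad[i, j, k] += 2.0 * (v[i, j, k] - 1.0)
    return loss, grad
```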
<br />
= Training =<br />
The 2.5D and 3D estimation components are first pre-trained separately on synthetic data from ShapeNet, and then fine-tuned on real images.<br />
<br />
For pre-training, the 2.5D sketch estimator is trained on synthetic ShapeNet depth, surface normal, and silhouette ground truth, using an L2 loss. The 3D estimator is trained with ground truth voxels using a cross-entropy loss.<br />
<br />
The reprojection consistency loss is used to fine-tune the 3D estimation on real images, using the predicted depth, normals, and silhouette. A straightforward implementation leads to shapes that explain the 2.5D sketches well but have an unrealistic 3D appearance due to overfitting.<br />
<br />
Instead, the decoder of the 3D estimator is fixed, and only the encoder is fine-tuned. The model is fine-tuned separately on each image for 40 iterations, which takes up to 10 seconds on the GPU. Without fine-tuning, testing takes around 100 milliseconds. SGD is used for optimization with a batch size of 4, a learning rate of 0.001, and a momentum of 0.9.<br />
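<br />
The idea of fixing the decoder while fine-tuning the encoder can be sketched with a plain SGD-with-momentum update. The parameter dictionaries, the names <code>"encoder"</code> and <code>"decoder"</code>, and the particular momentum formulation are assumptions for illustration; only the learning rate of 0.001 and the momentum of 0.9 come from the text.<br />

```python
import numpy as np

def sgd_momentum_step(params, grads, velocity, trainable, lr=0.001, momentum=0.9):
    """One SGD-with-momentum update applied only to the parameters in `trainable`.

    params, grads, velocity: dicts mapping a parameter name to an array.
    Parameters not listed in `trainable` (e.g. the decoder) stay fixed.
    """
    for name in trainable:
        velocity[name] = momentum * velocity[name] - lr * grads[name]
        params[name] = params[name] + velocity[name]
    return params, velocity

# Hypothetical usage: 40 fine-tuning iterations on the encoder only.
params = {"encoder": np.array([1.0]), "decoder": np.array([1.0])}
velocity = {name: np.zeros_like(p) for name, p in params.items()}
for _ in range(40):
    grads = {name: 2.0 * p for name, p in params.items()}  # toy gradient of p**2
    params, velocity = sgd_momentum_step(params, grads, velocity, ["encoder"])
```

Because the decoder never changes, every fine-tuned output still lies in the space of shapes the decoder learned from synthetic data.<br />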
<br />
= Evaluation =<br />
Qualitative and quantitative results are provided using different variants of the framework. The framework is evaluated on both synthetic and real images on three datasets.<br />
<br />
== ShapeNet ==<br />
Synthesized images of 6,778 chairs from ShapeNet are rendered from 20 random viewpoints. The chairs are placed in front of random backgrounds from the SUN dataset, and the RGB, depth, normal, and silhouette images are rendered using the physics-based renderer Mitsuba for more realistic images.<br />
<br />
=== Method ===<br />
MarrNet is trained without the final fine-tuning stage, since 3D shapes are available. A baseline is created that directly predicts the 3D shape using the same 3D shape estimator architecture with no 2.5D sketch estimation.<br />
<br />
=== Results ===<br />
The baseline output is compared to the full framework, and the figure below shows that MarrNet provides model outputs with more details and smoother surfaces than the baseline. Quantitatively, the full model also achieves 0.57 IoU, higher than the direct prediction baseline.<br />
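<br />
For reference, the intersection-over-union (IoU) between a predicted and a ground-truth voxel grid can be computed as in the sketch below. The 0.5 binarization threshold and the function name are assumptions; the summary does not state the threshold used by the authors.<br />

```python
import numpy as np

def voxel_iou(pred, target, threshold=0.5):
    """Intersection over union of two voxel grids after binarization.

    pred, target: arrays of the same shape with values in [0, 1].
    threshold: occupancy cut-off (assumed value, not taken from the paper).
    """
    p = pred > threshold
    t = target > threshold
    union = np.logical_or(p, t).sum()
    if union == 0:
        return 1.0  # both grids are empty, so they agree trivially
    return np.logical_and(p, t).sum() / union
```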
<br />
[[File:marrnet_shapenet_results.png|700px|thumb|center|ShapeNet results.]]<br />
<br />
== PASCAL 3D+ ==<br />
This dataset provides rough 3D models for objects in real-life images.<br />
<br />
=== Method ===<br />
Each module is pre-trained on the ShapeNet dataset, and then fine-tuned on the PASCAL 3D+ dataset. Three variants of the model are tested. The first is trained using ShapeNet data only with no fine-tuning. The second is fine-tuned without fixing the decoder. The third is fine-tuned with a fixed decoder.<br />
<br />
=== Results ===<br />
The figure below shows the results of the ablation study. The model trained only on synthetic data provides reasonable estimates. However, fine-tuning without fixing the decoder leads to impossible shapes from certain views. The third model keeps the shape prior, providing more details in the final shape.<br />
<br />
[[File:marrnet_pascal_3d_ablation.png|600px|thumb|center|Ablation studies using the PASCAL 3D+ dataset.]]<br />
<br />
Additional comparisons are made with the state-of-the-art method (DRC) on the provided ground truth shapes. MarrNet achieves 0.39 IoU, while DRC achieves 0.34. However, the authors claim that the IoU metric is sub-optimal for three reasons. First, there is no emphasis on details since the metric prefers models that predict mean shapes consistently. Second, all possible scales are searched during the IoU computation, making it less efficient. Third, PASCAL 3D+ only has rough annotations, with only 10 CAD chair models for all images, and computing IoU with these shapes is not very informative. Instead, human studies are conducted: MarrNet reconstructions are preferred over DRC 74% of the time, and over the ground truth 42% of the time. This suggests that MarrNet produces plausible shapes, and also highlights that the ground truth shapes are of limited quality.<br />
<br />
[[File:human_studies.png|400px|thumb|center|Human preferences on chairs in PASCAL 3D+ (Xiang et al. 2014). The numbers show the percentage of how often humans preferred the 3D shape from DRC (state-of-the-art), MarrNet, or GT.]]<br />
<br />
<br />
[[File:marrnet_pascal_3d_drc_comparison.png|600px|thumb|center|Comparison between DRC and MarrNet results.]]<br />
<br />
Several failure cases are shown in the figure below. Specifically, the framework does not seem to work well on thin structures.<br />
<br />
[[File:marrnet_pascal_3d_failure_cases.png|500px|thumb|center|Failure cases on PASCAL 3D+. The algorithm cannot recover thin structures.]]<br />
<br />
== IKEA ==<br />
This dataset contains images of IKEA furniture, with accurate 3D shape and pose annotations. Objects are often heavily occluded or truncated.<br />
<br />
=== Results ===<br />
Qualitative results are shown in the figure below. The model is shown to deal with mild occlusions in real life scenarios. Human studies show that MarrNet reconstructions are preferred 61% of the time to 3D-VAE-GAN.<br />
<br />
[[File:marrnet_ikea_results.png|700px|thumb|center|Results on chairs in the IKEA dataset, and comparison with 3D-VAE-GAN.]]<br />
<br />
== Other Data ==<br />
MarrNet is also applied to cars and airplanes. As shown below, smaller details such as the horizontal stabilizer and rear-view mirrors are recovered.<br />
<br />
[[File:marrnet_airplanes_and_cars.png|700px|thumb|center|Results on airplanes and cars from the PASCAL 3D+ dataset, and comparison with DRC.]]<br />
<br />
MarrNet is also jointly trained on three object categories, and successfully recovers the shapes of different categories. Results are shown in the figure below.<br />
<br />
[[File:marrnet_multiple_categories.png|700px|thumb|center|Results when trained jointly on all three object categories (cars, airplanes, and chairs).]]<br />
<br />
= Commentary =<br />
Qualitatively, the results look quite impressive. The 2.5D sketch estimation seems to distill the useful information for more realistic looking 3D shape estimation. The disentanglement of 2.5D and 3D estimation steps also allows for easier training and domain adaptation from synthetic data.<br />
<br />
As the authors mention, the IoU metric is not very descriptive, and most of the comparisons in this paper are only qualitative, mainly being human preference studies. A better quantitative evaluation metric would greatly help in making an unbiased comparison between different results.<br />
<br />
As seen in several of the results, the network does not deal well with objects that have thin structures, which is particularly noticeable with many of the chair arm rests. In addition, a closer look at some results suggests that fine-tuning only the 3D encoder does not transfer well to unseen objects, since the shape priors have already been learned by the decoder.<br />
<br />
= Conclusion =<br />
The proposed MarrNet employs a novel model to estimate 2.5D sketches for 3D shape reconstruction. The sketches are shown to improve the model’s performance, and make it easy to adapt to images across different domains and categories. Differentiable loss functions are created such that the model can be fine-tuned end-to-end on images without ground truth. The experiments show that the model performs well, and human studies show that the results are preferred over other methods.<br />
<br />
= Implementation =<br />
The source code, as written by the authors, is available in the following repository: https://github.com/jiajunwu/marrnet<br />
<br />
= References =<br />
# David Marr. Vision: A computational investigation into the human representation and processing of visual information. W. H. Freeman and Company, 1982.<br />
# Shubham Tulsiani, Tinghui Zhou, Alexei A Efros, and Jitendra Malik. Multi-view supervision for single-view reconstruction via differentiable ray consistency. In CVPR, 2017.<br />
# Jiajun Wu, Chengkai Zhang, Tianfan Xue, William T Freeman, and Joshua B Tenenbaum. Learning a Probabilistic Latent Space of Object Shapes via 3D Generative-Adversarial Modeling. In NIPS, 2016b.<br />
# Wu, J. (n.d.). Jiajunwu/marrnet. Retrieved March 25, 2018, from https://github.com/jiajunwu/marrnet</div>X249wanghttp://wiki.math.uwaterloo.ca/statwiki/index.php?title=MarrNet:_3D_Shape_Reconstruction_via_2.5D_Sketches&diff=35643MarrNet: 3D Shape Reconstruction via 2.5D Sketches2018-03-27T04:28:13Z<p>X249wang: /* 3D Shape Estimation */</p>
<hr />
<div>= Introduction =<br />
Humans are able to quickly recognize 3D shapes from images, even in spite of drastic differences in object texture, material, lighting, and background.<br />
<br />
[[File:marrnet_intro_image.png|700px|thumb|center|Objects in real images. The appearance of the same shaped object varies based on colour, texture, lighting, background, etc. However, the 2.5D sketches (e.g. depth or normal maps) of the object remain constant, and can be seen as an abstraction of the object which is used to reconstruct the 3D shape.]]<br />
<br />
In this work, the authors propose a novel end-to-end trainable model that sequentially estimates 2.5D sketches and 3D object shape from images, and also enforces re-projection consistency between the 3D shape and the estimated sketches. The two-step approach makes the network more robust to differences in object texture, material, lighting, and background. Based on the idea from [Marr, 1982] that human 3D perception relies on recovering 2.5D sketches, which include depth and surface normal maps, the authors design an end-to-end trainable pipeline which they call MarrNet. MarrNet first estimates depth, normal maps, and silhouette, followed by a 3D shape. MarrNet uses an encoder-decoder structure for the sub-components of the framework. <br />
<br />
The authors claim several unique advantages to their method. Single image 3D reconstruction is a highly under-constrained problem, requiring strong prior knowledge of object shapes. As well, accurate 3D object annotations for real images are not common, and many previous approaches rely on purely synthetic data. However, most of these methods suffer from the domain adaptation problem due to imperfect rendering.<br />
<br />
Using 2.5D sketches can alleviate the challenges of domain transfer. It is straightforward to generate perfect object surface normals and depths using a graphics engine. Since 2.5D sketches contain only depth, surface normal, and silhouette information, the second step of recovering 3D shape can be trained purely from synthetic data. As well, the introduction of differentiable constraints between 2.5D sketches and 3D shape makes it possible to fine-tune the system, even without any annotations.<br />
<br />
The framework is evaluated on both synthetic objects from ShapeNet, and real images from PASCAL 3D+, showing good qualitative and quantitative performance in 3D shape reconstruction.<br />
<br />
= Related Work =<br />
<br />
== 2.5D Sketch Recovery ==<br />
Researchers have explored recovering 2.5D information from shading, texture, and colour images in the past. More recently, the development of depth sensors has led to the creation of large RGB-D datasets, and papers on estimating depth, surface normals, and other intrinsic images using deep networks. While this method employs 2.5D estimation, the final output is a full 3D shape of an object.<br />
<br />
[[File:2-5d_example.PNG|700px|thumb|center|Results from the paper: Learning Non-Lambertian Object Intrinsics across ShapeNet Categories. The results show that neural networks can be trained to recover 2.5D information from an image. The top row predicts the albedo and the bottom row predicts the shading. It can be observed that the results are still blurry and the fine details are not fully recovered.]]<br />
<br />
== Single Image 3D Reconstruction ==<br />
The development of large-scale shape repositories like ShapeNet has allowed for the development of models encoding shape priors for single image 3D reconstruction. These methods normally regress voxelized 3D shapes, relying on synthetic data or 2D masks for training. The formulation in the paper tackles domain adaptation better, since the network can be fine-tuned on images without any annotations.<br />
<br />
== 2D-3D Consistency ==<br />
Intuitively, the 3D shape can be constrained to be consistent with 2D observations. This idea has been explored for decades, with the use of depth and silhouettes, as well as some papers enforcing differentiable 2D-3D constraints for joint training of deep networks. In this work, this idea is exploited to develop differentiable constraints for consistency between the 2.5D sketches and 3D shape.<br />
<br />
= Approach =<br />
The 3D structure is recovered from a single RGB view using three steps, shown in the figure below. The first step estimates 2.5D sketches, including depth, surface normal, and silhouette of the object. The second step estimates a 3D voxel representation of the object. The third step uses a reprojection consistency function to enforce the 2.5D sketch and 3D structure alignment.<br />
<br />
[[File:marrnet_model_components.png|700px|thumb|center|MarrNet architecture. 2.5D sketches of normals, depths, and silhouette are first estimated. The sketches are then used to estimate the 3D shape. Finally, re-projection consistency is used to ensure consistency between the sketch and 3D output.]]<br />
<br />
== 2.5D Sketch Estimation ==<br />
The first step takes a 2D RGB image and predicts the surface normal, depth, and silhouette of the object. The goal is to estimate intrinsic object properties from the image, while discarding non-essential information. An encoder-decoder architecture is used. The encoder is a ResNet-18 network, which takes a 256 x 256 RGB image and produces 8 x 8 x 512 feature maps. The decoder is four sets of 5 x 5 convolutional and ReLU layers, followed by four sets of 1 x 1 convolutional and ReLU layers. The output is 256 x 256 resolution depth, surface normal, and silhouette images.<br />
<br />
== 3D Shape Estimation ==<br />
The second step estimates a voxelized 3D shape using the 2.5D sketches from the first step. The focus here is for the network to learn the shape prior that can explain the input well, and can be trained on synthetic data without suffering from the domain adaptation problem since it only takes in surface normal and depth images as input. The network architecture is inspired by the TL network, and 3D-VAE-GAN, with an encoder-decoder structure. The normal and depth image, masked by the estimated silhouette, are passed into 5 sets of convolutional, ReLU, and pooling layers, followed by two fully connected layers, with a final output width of 200. The 200-dimensional vector is passed into a decoder of 5 convolutional and ReLU layers, outputting a 128 x 128 x 128 voxelized estimate of the input.<br />
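<br />
The masking step mentioned above, in which the normal and depth images are masked by the estimated silhouette before entering the 3D encoder, might look like the following sketch. Stacking the two sketches into one 4-channel array, the 0.5 silhouette threshold, and the function name are assumptions made for illustration.<br />

```python
import numpy as np

def mask_sketches(normal, depth, silhouette, threshold=0.5):
    """Stack normal and depth maps and zero out everything outside the silhouette.

    normal: array of shape (H, W, 3); depth and silhouette: arrays of shape (H, W).
    Returns an array of shape (H, W, 4).
    """
    mask = (silhouette > threshold).astype(normal.dtype)
    stacked = np.concatenate([normal, depth[..., None]], axis=-1)
    return stacked * mask[..., None]
```

Background pixels become zero in every channel, so the 3D estimator never sees texture or background clutter from the original image.<br />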
<br />
== Re-projection Consistency ==<br />
The third step consists of a depth re-projection loss and surface normal re-projection loss. Here, <math>v_{x, y, z}</math> represents the value at position <math>(x, y, z)</math> in a 3D voxel grid, with <math>v_{x, y, z} \in [0, 1] ∀ x, y, z</math>. <math>d_{x, y}</math> denotes the estimated depth at position <math>(x, y)</math>, <math>n_{x, y} = (n_a, n_b, n_c)</math> denotes the estimated surface normal. Orthographic projection is used.<br />
<br />
[[File:marrnet_reprojection_consistency.png|700px|thumb|center|Reprojection consistency for voxels. Left and middle: criteria for depth and silhouettes. Right: criterion for surface normals]]<br />
<br />
=== Depths ===<br />
The voxel at the estimated depth, <math>v_{x, y, d_{x, y}}</math>, should be 1, while all voxels in front of it should be 0. This ensures the estimated 3D shape matches the estimated depth values. The projected depth loss and its gradient are defined as follows:<br />
<br />
<math><br />
L_{depth}(x, y, z)=<br />
\left\{<br />
\begin{array}{ll}<br />
v^2_{x, y, z}, & z < d_{x, y} \\<br />
(1 - v_{x, y, z})^2, & z = d_{x, y} \\<br />
0, & z > d_{x, y} \\<br />
\end{array}<br />
\right.<br />
</math><br />
<br />
<math><br />
\frac{∂L_{depth}(x, y, z)}{∂v_{x, y, z}} =<br />
\left\{<br />
\begin{array}{ll}<br />
2v_{x, y, z}, & z < d_{x, y} \\<br />
2(v_{x, y, z} - 1), & z = d_{x, y} \\<br />
0, & z > d_{x, y} \\<br />
\end{array}<br />
\right.<br />
</math><br />
<br />
When <math>d_{x, y} = \infty</math>, all voxels in front of it should be 0.<br />
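<br />
To see what the depth criterion asks for, it can help to go the other way and read a depth map off a voxel grid under orthographic projection: the depth at <math>(x, y)</math> is the index of the first occupied voxel along z, or infinity if the ray is empty. The sketch below assumes a 0.5 occupancy threshold and a <code>v[x, y, z]</code> layout; neither is specified in the summary.<br />

```python
import numpy as np

def project_depth(v, threshold=0.5):
    """Orthographic depth map of a voxel grid viewed along the z axis.

    v: float array of shape (X, Y, Z) with occupancy values in [0, 1].
    Returns an (X, Y) array with the index of the first occupied voxel on
    each ray, or np.inf where no voxel on the ray is occupied.
    """
    occupied = v > threshold
    first = occupied.argmax(axis=2).astype(float)  # first True along z (0 if none)
    first[~occupied.any(axis=2)] = np.inf          # empty rays have no depth
    return first
```

A voxel grid whose projected depth equals the estimated depth map has zero depth re-projection loss when its occupancies are exactly 0 or 1.<br />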
<br />
=== Surface Normals ===<br />
Since the vectors <math>n_{x} = (0, -n_{c}, n_{b})</math> and <math>n_{y} = (-n_{c}, 0, n_{a})</math> are orthogonal to the normal vector <math>n_{x, y} = (n_{a}, n_{b}, n_{c})</math>, they can be rescaled by <math>1/n_{c}</math> to obtain <math>n'_{x} = (0, -1, n_{b}/n_{c})</math> and <math>n'_{y} = (-1, 0, n_{a}/n_{c})</math>, both of which lie on the estimated surface plane at <math>(x, y, z)</math>. The projected surface normal loss tries to guarantee that the voxels at <math>(x, y, z) \pm n'_{x}</math> and <math>(x, y, z) \pm n'_{y}</math> are 1, to match the estimated normal. The constraints are only applied when the target voxels are inside the estimated silhouette.<br />
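<br />
The orthogonality claim is easy to check numerically. The short sketch below builds the two tangent vectors from a normal and returns both the original and the rescaled versions; the function name is hypothetical and <math>n_c \neq 0</math> is assumed.<br />

```python
import numpy as np

def surface_tangents(n):
    """Tangent vectors of the surface plane with normal n = (n_a, n_b, n_c).

    Returns n_x, n_y and their rescaled versions n_x / n_c and n_y / n_c,
    which have a unit step along y and x respectively (requires n_c != 0).
    """
    n_a, n_b, n_c = n
    n_x = np.array([0.0, -n_c, n_b])
    n_y = np.array([-n_c, 0.0, n_a])
    return n_x, n_y, n_x / n_c, n_y / n_c

n = np.array([0.2, 0.3, 0.5])
n_x, n_y, n_x_scaled, n_y_scaled = surface_tangents(n)
# Both dot products vanish, so both vectors lie in the tangent plane.
dots = (np.dot(n, n_x), np.dot(n, n_y))
```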
<br />
The projected surface normal loss is defined as follows, with <math>z = d_{x, y}</math>:<br />
<br />
<math><br />
L_{normal}(x, y, z) =<br />
(1 - v_{x, y-1, z+\frac{n_b}{n_c}})^2 + (1 - v_{x, y+1, z-\frac{n_b}{n_c}})^2 + <br />
(1 - v_{x-1, y, z+\frac{n_a}{n_c}})^2 + (1 - v_{x+1, y, z-\frac{n_a}{n_c}})^2<br />
</math><br />
<br />
Gradients along x are:<br />
<br />
<math><br />
\frac{dL_{normal}(x, y, z)}{dv_{x-1, y, z+\frac{n_a}{n_c}}} = 2(v_{x-1, y, z+\frac{n_a}{n_c}}-1)<br />
</math><br />
and<br />
<math><br />
\frac{dL_{normal}(x, y, z)}{dv_{x+1, y, z-\frac{n_a}{n_c}}} = 2(v_{x+1, y, z-\frac{n_a}{n_c}}-1)<br />
</math><br />
<br />
Gradients along y are similar to x.<br />
<br />
= Training =<br />
The 2.5D and 3D estimation components are first pre-trained separately on synthetic data from ShapeNet, and then fine-tuned on real images.<br />
<br />
For pre-training, the 2.5D sketch estimator is trained on synthetic ShapeNet depth, surface normal, and silhouette ground truth, using an L2 loss. The 3D estimator is trained with ground truth voxels using a cross-entropy loss.<br />
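<br />
The two pre-training objectives can be written down in a few lines. This is a minimal sketch: averaging over all elements (rather than summing) and clipping the predictions for numerical stability are assumptions, and the function names are made up for the example.<br />

```python
import numpy as np

def l2_loss(pred, target):
    """Mean squared error, used here for depth, normal, and silhouette maps."""
    return np.mean((pred - target) ** 2)

def voxel_cross_entropy(pred, target, eps=1e-7):
    """Binary cross-entropy between predicted occupancies and ground-truth voxels.

    pred: predicted occupancy probabilities in [0, 1].
    target: ground-truth occupancies (0 or 1), same shape as pred.
    eps: clipping constant that keeps the logarithms finite (assumed value).
    """
    p = np.clip(pred, eps, 1.0 - eps)
    return -np.mean(target * np.log(p) + (1.0 - target) * np.log(1.0 - p))
```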
<br />
The reprojection consistency loss is used to fine-tune the 3D estimation on real images, using the predicted depth, normals, and silhouette. A straightforward implementation leads to shapes that explain the 2.5D sketches well but have an unrealistic 3D appearance due to overfitting.<br />
<br />
Instead, the decoder of the 3D estimator is fixed, and only the encoder is fine-tuned. The model is fine-tuned separately on each image for 40 iterations, which takes up to 10 seconds on the GPU. Without fine-tuning, testing takes around 100 milliseconds. SGD is used for optimization with a batch size of 4, a learning rate of 0.001, and a momentum of 0.9.<br />
<br />
= Evaluation =<br />
Qualitative and quantitative results are provided using different variants of the framework. The framework is evaluated on both synthetic and real images on three datasets.<br />
<br />
== ShapeNet ==<br />
Synthesized images of 6,778 chairs from ShapeNet are rendered from 20 random viewpoints. The chairs are placed in front of random backgrounds from the SUN dataset, and the RGB, depth, normal, and silhouette images are rendered using the physics-based renderer Mitsuba for more realistic images.<br />
<br />
=== Method ===<br />
MarrNet is trained without the final fine-tuning stage, since 3D shapes are available. A baseline is created that directly predicts the 3D shape using the same 3D shape estimator architecture with no 2.5D sketch estimation.<br />
<br />
=== Results ===<br />
The baseline output is compared to the full framework, and the figure below shows that MarrNet provides model outputs with more details and smoother surfaces than the baseline. Quantitatively, the full model also achieves 0.57 IoU, higher than the direct prediction baseline.<br />
<br />
[[File:marrnet_shapenet_results.png|700px|thumb|center|ShapeNet results.]]<br />
<br />
== PASCAL 3D+ ==<br />
This dataset provides rough 3D models for objects in real-life images.<br />
<br />
=== Method ===<br />
Each module is pre-trained on the ShapeNet dataset, and then fine-tuned on the PASCAL 3D+ dataset. Three variants of the model are tested. The first is trained using ShapeNet data only with no fine-tuning. The second is fine-tuned without fixing the decoder. The third is fine-tuned with a fixed decoder.<br />
<br />
=== Results ===<br />
The figure below shows the results of the ablation study. The model trained only on synthetic data provides reasonable estimates. However, fine-tuning without fixing the decoder leads to impossible shapes from certain views. The third model keeps the shape prior, providing more details in the final shape.<br />
<br />
[[File:marrnet_pascal_3d_ablation.png|600px|thumb|center|Ablation studies using the PASCAL 3D+ dataset.]]<br />
<br />
Additional comparisons are made with the state-of-the-art method (DRC) on the provided ground truth shapes. MarrNet achieves 0.39 IoU, while DRC achieves 0.34. However, the authors claim that the IoU metric is sub-optimal for three reasons. First, there is no emphasis on details since the metric prefers models that predict mean shapes consistently. Second, all possible scales are searched during the IoU computation, making it less efficient. Third, PASCAL 3D+ only has rough annotations, with only 10 CAD chair models for all images, and computing IoU with these shapes is not very informative. Instead, human studies are conducted: MarrNet reconstructions are preferred over DRC 74% of the time, and over the ground truth 42% of the time. This suggests that MarrNet produces plausible shapes, and also highlights that the ground truth shapes are of limited quality.<br />
<br />
[[File:human_studies.png|400px|thumb|center|Human preferences on chairs in PASCAL 3D+ (Xiang et al. 2014). The numbers show the percentage of how often humans preferred the 3D shape from DRC (state-of-the-art), MarrNet, or GT.]]<br />
<br />
<br />
[[File:marrnet_pascal_3d_drc_comparison.png|600px|thumb|center|Comparison between DRC and MarrNet results.]]<br />
<br />
Several failure cases are shown in the figure below. Specifically, the framework does not seem to work well on thin structures.<br />
<br />
[[File:marrnet_pascal_3d_failure_cases.png|500px|thumb|center|Failure cases on PASCAL 3D+. The algorithm cannot recover thin structures.]]<br />
<br />
== IKEA ==<br />
This dataset contains images of IKEA furniture, with accurate 3D shape and pose annotations. Objects are often heavily occluded or truncated.<br />
<br />
=== Results ===<br />
Qualitative results are shown in the figure below. The model is shown to deal with mild occlusions in real-life scenarios. Human studies show that MarrNet reconstructions are preferred 61% of the time over 3D-VAE-GAN.<br />
<br />
[[File:marrnet_ikea_results.png|700px|thumb|center|Results on chairs in the IKEA dataset, and comparison with 3D-VAE-GAN.]]<br />
<br />
== Other Data ==<br />
MarrNet is also applied on cars and airplanes. Shown below, smaller details such as the horizontal stabilizer and rear-view mirrors are recovered.<br />
<br />
[[File:marrnet_airplanes_and_cars.png|700px|thumb|center|Results on airplanes and cars from the PASCAL 3D+ dataset, and comparison with DRC.]]<br />
<br />
MarrNet is also jointly trained on three object categories, and successfully recovers the shapes of different categories. Results are shown in the figure below.<br />
<br />
[[File:marrnet_multiple_categories.png|700px|thumb|center|Results when trained jointly on all three object categories (cars, airplanes, and chairs).]]<br />
<br />
= Commentary =<br />
Qualitatively, the results look quite impressive. The 2.5D sketch estimation seems to distill the useful information for more realistic looking 3D shape estimation. The disentanglement of 2.5D and 3D estimation steps also allows for easier training and domain adaptation from synthetic data.<br />
<br />
As the authors mention, the IoU metric is not very descriptive, and most of the comparisons in this paper are only qualitative, mainly being human preference studies. A better quantitative evaluation metric would greatly help in making an unbiased comparison between different results.<br />
<br />
As seen in several of the results, the network does not deal well with objects that have thin structures, which is particularly noticeable with many of the chair arm rests. As well, looking more carefully at some results, it seems that fine-tuning only the 3D encoder does not seem to transfer well to unseen objects, since shape priors have already been learned by the decoder.<br />
<br />
= Conclusion =<br />
The proposed MarrNet employs a novel model to estimate 2.5D sketches for 3D shape reconstruction. The sketches are shown to improve the model’s performance, and make it easy to adapt to images across different domains and categories. Differentiable loss functions are created such that the model can be fine-tuned end-to-end on images without ground truth. The experiments show that the model performs well, and human studies show that the results are preferred over other methods.<br />
<br />
= Implementation =<br />
The source code for the paper, as written by the authors, is available in the following repository: https://github.com/jiajunwu/marrnet<br />
<br />
= References =<br />
# David Marr. Vision: A computational investigation into the human representation and processing of visual information. W. H. Freeman and Company, 1982.<br />
# Shubham Tulsiani, Tinghui Zhou, Alexei A Efros, and Jitendra Malik. Multi-view supervision for single-view reconstruction via differentiable ray consistency. In CVPR, 2017.<br />
# Jiajun Wu, Chengkai Zhang, Tianfan Xue, William T Freeman, and Joshua B Tenenbaum. Learning a Probabilistic Latent Space of Object Shapes via 3D Generative-Adversarial Modeling. In NIPS, 2016.<br />
# Wu, J. (n.d.). Jiajunwu/marrnet. Retrieved March 25, 2018, from https://github.com/jiajunwu/marrnet</div>X249wanghttp://wiki.math.uwaterloo.ca/statwiki/index.php?title=MarrNet:_3D_Shape_Reconstruction_via_2.5D_Sketches&diff=35639MarrNet: 3D Shape Reconstruction via 2.5D Sketches2018-03-27T04:06:54Z<p>X249wang: /* Depths */</p>
<hr />
<div>= Introduction =<br />
Humans are able to quickly recognize 3D shapes from images, even in spite of drastic differences in object texture, material, lighting, and background.<br />
<br />
[[File:marrnet_intro_image.png|700px|thumb|center|Objects in real images. The appearance of the same shaped object varies based on colour, texture, lighting, background, etc. However, the 2.5D sketches (e.g. depth or normal maps) of the object remain constant, and can be seen as an abstraction of the object which is used to reconstruct the 3D shape.]]<br />
<br />
In this work, the authors propose a novel end-to-end trainable model that sequentially estimates 2.5D sketches and 3D object shape from images, and also enforces re-projection consistency between the 3D shape and the estimated sketches. The two-step approach makes the network more robust to differences in object texture, material, lighting, and background. Based on the idea from [Marr, 1982] that human 3D perception relies on recovering 2.5D sketches, which include depth and surface normal maps, the authors design an end-to-end trainable pipeline which they call MarrNet. MarrNet first estimates depth, normal maps, and silhouette, followed by a 3D shape. MarrNet uses an encoder-decoder structure for the sub-components of the framework. <br />
<br />
The authors claim several unique advantages to their method. Single image 3D reconstruction is a highly under-constrained problem, requiring strong prior knowledge of object shapes. As well, accurate 3D object annotations for real images are not common, and many previous approaches rely on purely synthetic data. However, most of these methods suffer from domain adaptation problems due to imperfect rendering.<br />
<br />
Using 2.5D sketches can alleviate the challenges of domain transfer. It is straightforward to generate perfect object surface normals and depths using a graphics engine. Since 2.5D sketches contain only depth, surface normal, and silhouette information, the second step of recovering 3D shape can be trained purely from synthetic data. As well, the introduction of differentiable constraints between 2.5D sketches and 3D shape makes it possible to fine-tune the system, even without any annotations.<br />
<br />
The framework is evaluated on both synthetic objects from ShapeNet, and real images from PASCAL 3D+, showing good qualitative and quantitative performance in 3D shape reconstruction.<br />
<br />
= Related Work =<br />
<br />
== 2.5D Sketch Recovery ==<br />
Researchers have explored recovering 2.5D information from shading, texture, and colour images in the past. More recently, the development of depth sensors has led to the creation of large RGB-D datasets, and papers on estimating depth, surface normals, and other intrinsic images using deep networks. While this method employs 2.5D estimation, the final output is a full 3D shape of an object.<br />
<br />
[[File:2-5d_example.PNG|700px|thumb|center|Results from the paper: Learning Non-Lambertian Object Intrinsics across ShapeNet Categories. The results show that neural networks can be trained to recover 2.5D information from an image. The top row predicts the albedo and the bottom row predicts the shading. It can be observed that the results are still blurry and the fine details are not fully recovered.]]<br />
<br />
== Single Image 3D Reconstruction ==<br />
The development of large-scale shape repositories like ShapeNet has allowed for the development of models encoding shape priors for single image 3D reconstruction. These methods normally regress voxelized 3D shapes, relying on synthetic data or 2D masks for training. The formulation in the paper tackles domain adaptation better, since the network can be fine-tuned on images without any annotations.<br />
<br />
== 2D-3D Consistency ==<br />
Intuitively, the 3D shape can be constrained to be consistent with 2D observations. This idea has been explored for decades, with the use of depth and silhouettes, as well as some papers enforcing differentiable 2D-3D constraints for joint training of deep networks. In this work, this idea is exploited to develop differentiable constraints for consistency between the 2.5D sketches and 3D shape.<br />
<br />
= Approach =<br />
The 3D structure is recovered from a single RGB view using three steps, shown in the figure below. The first step estimates 2.5D sketches, including depth, surface normal, and silhouette of the object. The second step estimates a 3D voxel representation of the object. The third step uses a reprojection consistency function to enforce the 2.5D sketch and 3D structure alignment.<br />
<br />
[[File:marrnet_model_components.png|700px|thumb|center|MarrNet architecture. 2.5D sketches of normals, depths, and silhouette are first estimated. The sketches are then used to estimate the 3D shape. Finally, re-projection consistency is used to ensure consistency between the sketch and 3D output.]]<br />
<br />
== 2.5D Sketch Estimation ==<br />
The first step takes a 2D RGB image and predicts the surface normal, depth, and silhouette of the object. The goal is to estimate intrinsic object properties from the image, while discarding non-essential information. An encoder-decoder architecture is used. The encoder is a ResNet-18 network, which takes a 256 x 256 RGB image and produces 8 x 8 x 512 feature maps. The decoder is four sets of 5 x 5 convolutional and ReLU layers, followed by four sets of 1 x 1 convolutional and ReLU layers. The output is 256 x 256 resolution depth, surface normal, and silhouette images.<br />
<br />
== 3D Shape Estimation ==<br />
The second step estimates a voxelized 3D shape using the 2.5D sketches from the first step. The focus here is for the network to learn the shape prior that can explain the input well, and can be trained on synthetic data without suffering from the domain adaptation problem. The network architecture is inspired by the TL network, and 3D-VAE-GAN, with an encoder-decoder structure. The normal and depth image, masked by the estimated silhouette, are passed into 5 sets of convolutional, ReLU, and pooling layers, followed by two fully connected layers, with a final output width of 200. The 200-dimensional vector is passed into a decoder of 5 convolutional and ReLU layers, outputting a 128 x 128 x 128 voxelized estimate of the input.<br />
<br />
== Re-projection Consistency ==<br />
The third step consists of a depth re-projection loss and a surface normal re-projection loss. Here, <math>v_{x, y, z}</math> represents the value at position <math>(x, y, z)</math> in a 3D voxel grid, with <math>v_{x, y, z} \in [0, 1] \; \forall x, y, z</math>. <math>d_{x, y}</math> denotes the estimated depth at position <math>(x, y)</math>, and <math>n_{x, y} = (n_a, n_b, n_c)</math> denotes the estimated surface normal. Orthographic projection is used.<br />
<br />
[[File:marrnet_reprojection_consistency.png|700px|thumb|center|Reprojection consistency for voxels. Left and middle: criteria for depth and silhouettes. Right: criterion for surface normals]]<br />
<br />
=== Depths ===<br />
The voxel at the estimated depth, <math>v_{x, y, d_{x, y}}</math>, should be 1, while all voxels in front of it should be 0. This ensures the estimated 3D shape matches the estimated depth values. The projected depth loss and its gradient are defined as follows:<br />
<br />
<math><br />
L_{depth}(x, y, z)=<br />
\left\{<br />
\begin{array}{ll}<br />
v^2_{x, y, z}, & z < d_{x, y} \\<br />
(1 - v_{x, y, z})^2, & z = d_{x, y} \\<br />
0, & z > d_{x, y} \\<br />
\end{array}<br />
\right.<br />
</math><br />
<br />
<math><br />
\frac{\partial L_{depth}(x, y, z)}{\partial v_{x, y, z}} =<br />
\left\{<br />
\begin{array}{ll}<br />
2v_{x, y, z}, & z < d_{x, y} \\<br />
2(v_{x, y, z} - 1), & z = d_{x, y} \\<br />
0, & z > d_{x, y} \\<br />
\end{array}<br />
\right.<br />
</math><br />
<br />
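When <math>d_{x, y} = \infty</math>, the line of sight through <math>(x, y)</math> does not intersect the object, so all voxels along it should be 0. In this case the depth criterion reduces to a silhouette criterion.<br />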
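The piecewise depth loss and its gradient can be sketched in NumPy as follows. This is a minimal illustration, not the authors' code: the function name, the array layout <code>(X, Y, Z)</code>, the use of integer voxel indices for depth, and the choice to sum the per-voxel losses into one scalar are all assumptions made for the example.

```python
import numpy as np

def depth_reprojection_loss(voxels, depth):
    """Depth re-projection loss and its gradient (illustrative sketch).

    voxels: float array of shape (X, Y, Z) with values in [0, 1].
    depth:  float array of shape (X, Y) holding the voxel index along z of the
            visible surface, or np.inf where the line of sight misses the object.
    Returns the summed loss (summing is an assumption) and dL/dv per voxel.
    """
    size_z = voxels.shape[2]
    z = np.arange(size_z).reshape(1, 1, size_z)
    d = depth[:, :, None]

    in_front = z < d      # these voxels should be empty (0)
    on_surface = z == d   # this voxel should be occupied (1)
    # Voxels behind the surface (z > d) contribute nothing.

    loss = np.where(in_front, voxels ** 2, 0.0) \
         + np.where(on_surface, (1.0 - voxels) ** 2, 0.0)
    grad = np.where(in_front, 2.0 * voxels, 0.0) \
         + np.where(on_surface, 2.0 * (voxels - 1.0), 0.0)
    return loss.sum(), grad


# One line of sight with four voxels, all set to 0.5, surface at z = 2.
voxels = np.full((1, 1, 4), 0.5)
loss, grad = depth_reprojection_loss(voxels, np.array([[2.0]]))

# Same line of sight with no intersection (depth = infinity).
loss_inf, grad_inf = depth_reprojection_loss(voxels, np.array([[np.inf]]))
```

With an infinite depth, the comparison <code>z < d</code> holds for every voxel on the line of sight, so the same code covers the silhouette case without a separate branch.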
<br />
=== Surface Normals ===<br />
Since the vectors <math>n_{x} = (0, -n_{c}, n_{b})</math> and <math>n_{y} = (-n_{c}, 0, n_{a})</math> are orthogonal to the normal vector <math>n_{x, y} = (n_{a}, n_{b}, n_{c})</math>, they can be normalized (dividing by <math>n_{c}</math>) to obtain <math>n'_{x} = (0, -1, n_{b}/n_{c})</math> and <math>n'_{y} = (-1, 0, n_{a}/n_{c})</math>, both of which lie on the estimated surface plane at <math>(x, y, z)</math>. The projected surface normal loss tries to guarantee that the voxels at <math>(x, y, z) \pm n'_{x}</math> and <math>(x, y, z) \pm n'_{y}</math> are 1, to match the estimated normal. The constraints are only applied when the target voxels are inside the estimated silhouette.<br />
<br />
The projected surface normal loss is defined as follows, with <math>z = d_{x, y}</math>:<br />
<br />
<math><br />
L_{normal}(x, y, z) =<br />
(1 - v_{x, y-1, z+\frac{n_b}{n_c}})^2 + (1 - v_{x, y+1, z-\frac{n_b}{n_c}})^2 + <br />
(1 - v_{x-1, y, z+\frac{n_a}{n_c}})^2 + (1 - v_{x+1, y, z-\frac{n_a}{n_c}})^2<br />
</math><br />
<br />
Gradients along x are:<br />
<br />
<math><br />
\frac{dL_{normal}(x, y, z)}{dv_{x-1, y, z+\frac{n_a}{n_c}}} = 2(v_{x-1, y, z+\frac{n_a}{n_c}}-1)<br />
</math><br />
and<br />
<math><br />
\frac{dL_{normal}(x, y, z)}{dv_{x+1, y, z-\frac{n_a}{n_c}}} = 2(v_{x+1, y, z-\frac{n_a}{n_c}}-1)<br />
</math><br />
<br />
Gradients along y are similar to x.<br />
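The surface normal loss at a single pixel can be sketched in the same way. Several details here are assumptions for illustration, since the summary does not specify them: the fractional offsets <math>n_a/n_c</math> and <math>n_b/n_c</math> are rounded to the nearest voxel index, targets that fall outside the grid are skipped, and pixels with <math>n_c</math> close to zero contribute no loss. The function name and argument layout are hypothetical.

```python
import numpy as np

def normal_reprojection_loss(voxels, silhouette, x, y, z, normal):
    """Surface normal re-projection loss at pixel (x, y), with z = d[x, y].

    voxels:     float array (X, Y, Z) with values in [0, 1].
    silhouette: bool array (X, Y); constraints apply only inside it.
    normal:     estimated normal (n_a, n_b, n_c) at (x, y).
    Returns the loss and its gradient with respect to the voxel grid.
    """
    n_a, n_b, n_c = normal
    grad = np.zeros_like(voxels)
    if abs(n_c) < 1e-8:  # assumption: skip normals perpendicular to the view axis
        return 0.0, grad

    # (x, y, z) +/- n'_x and (x, y, z) +/- n'_y from the text above.
    targets = [
        (x, y - 1, z + n_b / n_c),
        (x, y + 1, z - n_b / n_c),
        (x - 1, y, z + n_a / n_c),
        (x + 1, y, z - n_a / n_c),
    ]
    size_x, size_y, size_z = voxels.shape
    loss = 0.0
    for tx, ty, tz in targets:
        tz = int(round(tz))  # assumption: snap to the nearest voxel
        if not (0 <= tx < size_x and 0 <= ty < size_y and 0 <= tz < size_z):
            continue
        if not silhouette[tx, ty]:
            continue
        loss += (1.0 - voxels[tx, ty, tz]) ** 2
        grad[tx, ty, tz] += 2.0 * (voxels[tx, ty, tz] - 1.0)
    return loss, grad


voxels = np.full((3, 3, 5), 0.5)
silhouette = np.ones((3, 3), dtype=bool)

# Normal pointing straight at the camera: all four neighbours share z = 2.
loss_flat, grad_flat = normal_reprojection_loss(voxels, silhouette, 1, 1, 2, (0.0, 0.0, 1.0))

# Surface tilted along x: the x-neighbours move to z = 3 and z = 1.
loss_tilt, grad_tilt = normal_reprojection_loss(voxels, silhouette, 1, 1, 2, (1.0, 0.0, 1.0))

# Neighbour (0, 1) removed from the silhouette: one term drops out.
silhouette_cut = silhouette.copy()
silhouette_cut[0, 1] = False
loss_cut, _ = normal_reprojection_loss(voxels, silhouette_cut, 1, 1, 2, (1.0, 0.0, 1.0))
```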
<br />
= Training =<br />
The 2.5D and 3D estimation components are first pre-trained separately on synthetic data from ShapeNet, and then fine-tuned on real images.<br />
<br />
For pre-training, the 2.5D sketch estimator is trained on synthetic ShapeNet depth, surface normal, and silhouette ground truth, using an L2 loss. The 3D estimator is trained with ground truth voxels using a cross-entropy loss.<br />
<br />
The reprojection consistency loss is used to fine-tune the 3D estimation component on real images, using the predicted depth, normals, and silhouette. A straightforward implementation leads to shapes that explain the 2.5D sketches well but have an unrealistic 3D appearance due to overfitting.<br />
<br />
Instead, the decoder of the 3D estimator is fixed, and only the encoder is fine-tuned. The model is fine-tuned separately on each image for 40 iterations, which takes up to 10 seconds on the GPU. Without fine-tuning, testing time takes around 100 milliseconds. SGD is used for optimization with batch size of 4, learning rate of 0.001, and momentum of 0.9.<br />
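The fixed-decoder fine-tuning scheme can be illustrated with a toy SGD-with-momentum update that simply skips frozen parameters. This is only a sketch under stated assumptions: the parameter names <code>"encoder"</code> and <code>"decoder"</code>, the helper <code>sgd_momentum_step</code>, and the particular momentum formulation (velocity accumulates gradients, then is scaled by the learning rate) are illustrative choices, not details taken from the paper or its code.

```python
import numpy as np

def sgd_momentum_step(params, grads, velocity, lr=0.001, momentum=0.9, frozen=()):
    """One SGD-with-momentum step that leaves frozen parameters untouched.

    params, grads, velocity: dicts mapping a parameter name to a NumPy array.
    frozen: names of parameters that must not be updated (e.g. the decoder).
    """
    for name in params:
        if name in frozen:
            continue
        velocity[name] = momentum * velocity[name] + grads[name]
        params[name] = params[name] - lr * velocity[name]
    return params, velocity


# Toy stand-ins for the 3D estimator's encoder and decoder weights.
params = {"encoder": np.array([1.0]), "decoder": np.array([1.0])}
velocity = {name: np.zeros_like(value) for name, value in params.items()}
grads = {"encoder": np.array([2.0]), "decoder": np.array([2.0])}

# Two fine-tuning iterations with the decoder frozen.
for _ in range(2):
    params, velocity = sgd_momentum_step(params, grads, velocity, frozen=("decoder",))
```

In the real pipeline the gradients would come from backpropagating the reprojection consistency loss through the fixed decoder into the encoder, repeated for 40 iterations per image.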
<br />
= Evaluation =<br />
Qualitative and quantitative results are provided using different variants of the framework. The framework is evaluated on both synthetic and real images on three datasets.<br />
<br />
== ShapeNet ==<br />
Synthesized images of 6,778 chairs from ShapeNet are rendered from 20 random viewpoints. The chairs are placed in front of random backgrounds from the SUN dataset, and the RGB, depth, normal, and silhouette images are rendered using the physics-based renderer Mitsuba for more realistic images.<br />
<br />
=== Method ===<br />
MarrNet is trained without the final fine-tuning stage, since 3D shapes are available. A baseline is created that directly predicts the 3D shape using the same 3D shape estimator architecture with no 2.5D sketch estimation.<br />
<br />
=== Results ===<br />
The baseline output is compared to the full framework, and the figure below shows that MarrNet provides model outputs with more details and smoother surfaces than the baseline. Quantitatively, the full model also achieves 0.57 IoU, higher than the direct prediction baseline.<br />
<br />
[[File:marrnet_shapenet_results.png|700px|thumb|center|ShapeNet results.]]<br />
<br />
== PASCAL 3D+ ==<br />
Rough 3D models are provided from real-life images.<br />
<br />
=== Method ===<br />
Each module is pre-trained on the ShapeNet dataset, and then fine-tuned on the PASCAL 3D+ dataset. Three variants of the model are tested. The first is trained using ShapeNet data only with no fine-tuning. The second is fine-tuned without fixing the decoder. The third is fine-tuned with a fixed decoder.<br />
<br />
=== Results ===<br />
The figure below shows the results of the ablation study. The model trained only on synthetic data provides reasonable estimates. However, fine-tuning without fixing the decoder leads to impossible shapes from certain views. The third model keeps the shape prior, providing more details in the final shape.<br />
<br />
[[File:marrnet_pascal_3d_ablation.png|600px|thumb|center|Ablation studies using the PASCAL 3D+ dataset.]]<br />
<br />
Additional comparisons are made with the state-of-the-art (DRC) on the provided ground truth shapes. MarrNet achieves 0.39 IoU, while DRC achieves 0.34. However, the authors claim that the IoU metric is sub-optimal for three reasons. First, there is no emphasis on details since the metric prefers models that predict mean shapes consistently. Second, all possible scales are searched during the IoU computation, making it less efficient. Third, PASCAL 3D+ only has rough annotations, with only 10 CAD chair models for all images, and computing IoU with these shapes is not very informative. Instead, human studies are conducted and MarrNet reconstructions are preferred 74% of the time over DRC, and 42% of the time to ground truth. This shows how MarrNet produces nice shapes and also highlights the fact that ground truth shapes are not very good.<br />
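For reference, the IoU between a predicted and a ground truth voxel grid can be computed as in the sketch below. The binarization threshold of 0.5 and the convention of returning 1.0 when both grids are empty are assumptions for this example; the search over scales mentioned above is not included.

```python
import numpy as np

def voxel_iou(pred, target, threshold=0.5):
    """Intersection over union between two voxel grids of the same shape.

    pred, target: arrays of occupancy values; a voxel counts as occupied
    when its value is at least `threshold` (an assumed convention).
    """
    p = pred >= threshold
    t = target >= threshold
    union = np.logical_or(p, t).sum()
    if union == 0:
        return 1.0  # assumption: two empty grids are treated as a perfect match
    return np.logical_and(p, t).sum() / union


pred = np.zeros((2, 2, 2))
pred[0] = 0.9        # four occupied voxels
target = np.zeros((2, 2, 2))
target[:, 0] = 1.0   # four occupied voxels, two shared with pred

iou = voxel_iou(pred, target)  # 2 shared voxels out of 6 in the union
```

Because the score depends only on voxel overlap, a blurry "mean" shape can score as well as a detailed one, which matches the first limitation the authors point out.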
<br />
[[File:human_studies.png|400px|thumb|center|Human preferences on chairs in PASCAL 3D+ (Xiang et al. 2014). The numbers show the percentage of how often humans preferred the 3D shape from DRC (state-of-the-art), MarrNet, or GT.]]<br />
<br />
<br />
[[File:marrnet_pascal_3d_drc_comparison.png|600px|thumb|center|Comparison between DRC and MarrNet results.]]<br />
<br />
Several failure cases are shown in the figure below. Specifically, the framework does not seem to work well on thin structures.<br />
<br />
[[File:marrnet_pascal_3d_failure_cases.png|500px|thumb|center|Failure cases on PASCAL 3D+. The algorithm cannot recover thin structures.]]<br />
<br />
== IKEA ==<br />
This dataset contains images of IKEA furniture, with accurate 3D shape and pose annotations. Objects are often heavily occluded or truncated.<br />
<br />
=== Results ===<br />
Qualitative results are shown in the figure below. The model is shown to deal with mild occlusions in real-life scenarios. Human studies show that MarrNet reconstructions are preferred 61% of the time over 3D-VAE-GAN.<br />
<br />
[[File:marrnet_ikea_results.png|700px|thumb|center|Results on chairs in the IKEA dataset, and comparison with 3D-VAE-GAN.]]<br />
<br />
== Other Data ==<br />
MarrNet is also applied on cars and airplanes. Shown below, smaller details such as the horizontal stabilizer and rear-view mirrors are recovered.<br />
<br />
[[File:marrnet_airplanes_and_cars.png|700px|thumb|center|Results on airplanes and cars from the PASCAL 3D+ dataset, and comparison with DRC.]]<br />
<br />
MarrNet is also jointly trained on three object categories, and successfully recovers the shapes of different categories. Results are shown in the figure below.<br />
<br />
[[File:marrnet_multiple_categories.png|700px|thumb|center|Results when trained jointly on all three object categories (cars, airplanes, and chairs).]]<br />
<br />
= Commentary =<br />
Qualitatively, the results look quite impressive. The 2.5D sketch estimation seems to distill the useful information for more realistic looking 3D shape estimation. The disentanglement of 2.5D and 3D estimation steps also allows for easier training and domain adaptation from synthetic data.<br />
<br />
As the authors mention, the IoU metric is not very descriptive, and most of the comparisons in this paper are only qualitative, mainly being human preference studies. A better quantitative evaluation metric would greatly help in making an unbiased comparison between different results.<br />
<br />
As seen in several of the results, the network does not deal well with objects that have thin structures, which is particularly noticeable with many of the chair arm rests. As well, looking more carefully at some results, it seems that fine-tuning only the 3D encoder does not seem to transfer well to unseen objects, since shape priors have already been learned by the decoder.<br />
<br />
= Conclusion =<br />
The proposed MarrNet employs a novel model to estimate 2.5D sketches for 3D shape reconstruction. The sketches are shown to improve the model’s performance, and make it easy to adapt to images across different domains and categories. Differentiable loss functions are created such that the model can be fine-tuned end-to-end on images without ground truth. The experiments show that the model performs well, and human studies show that the results are preferred over other methods.<br />
<br />
= Implementation =<br />
The source code for the paper, as written by the authors, is available in the following repository: https://github.com/jiajunwu/marrnet<br />
<br />
= References =<br />
# David Marr. Vision: A computational investigation into the human representation and processing of visual information. W. H. Freeman and Company, 1982.<br />
# Shubham Tulsiani, Tinghui Zhou, Alexei A Efros, and Jitendra Malik. Multi-view supervision for single-view reconstruction via differentiable ray consistency. In CVPR, 2017.<br />
# Jiajun Wu, Chengkai Zhang, Tianfan Xue, William T Freeman, and Joshua B Tenenbaum. Learning a Probabilistic Latent Space of Object Shapes via 3D Generative-Adversarial Modeling. In NIPS, 2016.<br />
# Wu, J. (n.d.). Jiajunwu/marrnet. Retrieved March 25, 2018, from https://github.com/jiajunwu/marrnet</div>X249wanghttp://wiki.math.uwaterloo.ca/statwiki/index.php?title=MarrNet:_3D_Shape_Reconstruction_via_2.5D_Sketches&diff=35636MarrNet: 3D Shape Reconstruction via 2.5D Sketches2018-03-27T03:55:16Z<p>X249wang: /* 2.5D Sketch Estimation */</p>
<hr />
<div>= Introduction =<br />
Humans are able to quickly recognize 3D shapes from images, even in spite of drastic differences in object texture, material, lighting, and background.<br />
<br />
[[File:marrnet_intro_image.png|700px|thumb|center|Objects in real images. The appearance of the same shaped object varies based on colour, texture, lighting, background, etc. However, the 2.5D sketches (e.g. depth or normal maps) of the object remain constant, and can be seen as an abstraction of the object which is used to reconstruct the 3D shape.]]<br />
<br />
In this work, the authors propose a novel end-to-end trainable model that sequentially estimates 2.5D sketches and 3D object shape from images, and also enforces re-projection consistency between the 3D shape and the estimated sketches. The two-step approach makes the network more robust to differences in object texture, material, lighting, and background. Based on the idea from [Marr, 1982] that human 3D perception relies on recovering 2.5D sketches, which include depth and surface normal maps, the authors design an end-to-end trainable pipeline which they call MarrNet. MarrNet first estimates depth, normal maps, and silhouette, followed by a 3D shape. MarrNet uses an encoder-decoder structure for the sub-components of the framework. <br />
<br />
The authors claim several unique advantages to their method. Single image 3D reconstruction is a highly under-constrained problem, requiring strong prior knowledge of object shapes. As well, accurate 3D object annotations for real images are not common, and many previous approaches rely on purely synthetic data. However, most of these methods suffer from domain adaptation problems due to imperfect rendering.<br />
<br />
Using 2.5D sketches can alleviate the challenges of domain transfer. It is straightforward to generate perfect object surface normals and depths using a graphics engine. Since 2.5D sketches contain only depth, surface normal, and silhouette information, the second step of recovering 3D shape can be trained purely from synthetic data. As well, the introduction of differentiable constraints between 2.5D sketches and 3D shape makes it possible to fine-tune the system, even without any annotations.<br />
<br />
The framework is evaluated on both synthetic objects from ShapeNet, and real images from PASCAL 3D+, showing good qualitative and quantitative performance in 3D shape reconstruction.<br />
<br />
= Related Work =<br />
<br />
== 2.5D Sketch Recovery ==<br />
Researchers have explored recovering 2.5D information from shading, texture, and colour images in the past. More recently, the development of depth sensors has led to the creation of large RGB-D datasets, and papers on estimating depth, surface normals, and other intrinsic images using deep networks. While this method employs 2.5D estimation, the final output is a full 3D shape of an object.<br />
<br />
[[File:2-5d_example.PNG|700px|thumb|center|Results from the paper: Learning Non-Lambertian Object Intrinsics across ShapeNet Categories. The results show that neural networks can be trained to recover 2.5D information from an image. The top row predicts the albedo and the bottom row predicts the shading. It can be observed that the results are still blurry and the fine details are not fully recovered.]]<br />
<br />
== Single Image 3D Reconstruction ==<br />
The development of large-scale shape repositories like ShapeNet has allowed for the development of models encoding shape priors for single image 3D reconstruction. These methods normally regress voxelized 3D shapes, relying on synthetic data or 2D masks for training. The formulation in the paper tackles domain adaptation better, since the network can be fine-tuned on images without any annotations.<br />
<br />
== 2D-3D Consistency ==<br />
Intuitively, the 3D shape can be constrained to be consistent with 2D observations. This idea has been explored for decades, with the use of depth and silhouettes, as well as some papers enforcing differentiable 2D-3D constraints for joint training of deep networks. In this work, this idea is exploited to develop differentiable constraints for consistency between the 2.5D sketches and 3D shape.<br />
<br />
= Approach =<br />
The 3D structure is recovered from a single RGB view using three steps, shown in the figure below. The first step estimates 2.5D sketches, including depth, surface normal, and silhouette of the object. The second step estimates a 3D voxel representation of the object. The third step uses a reprojection consistency function to enforce the 2.5D sketch and 3D structure alignment.<br />
<br />
[[File:marrnet_model_components.png|700px|thumb|center|MarrNet architecture. 2.5D sketches of normals, depths, and silhouette are first estimated. The sketches are then used to estimate the 3D shape. Finally, re-projection consistency is used to ensure consistency between the sketch and 3D output.]]<br />
<br />
== 2.5D Sketch Estimation ==<br />
The first step takes a 2D RGB image and predicts the surface normal, depth, and silhouette of the object. The goal is to estimate intrinsic object properties from the image, while discarding non-essential information. An encoder-decoder architecture is used. The encoder is a ResNet-18 network, which takes a 256 x 256 RGB image and produces 8 x 8 x 512 feature maps. The decoder is four sets of 5 x 5 convolutional and ReLU layers, followed by four sets of 1 x 1 convolutional and ReLU layers. The output is 256 x 256 resolution depth, surface normal, and silhouette images.<br />
<br />
== 3D Shape Estimation ==<br />
The second step estimates a voxelized 3D shape using the 2.5D sketches from the first step. The focus here is for the network to learn the shape prior that can explain the input well, and can be trained on synthetic data without suffering from the domain adaptation problem. The network architecture is inspired by the TL network, and 3D-VAE-GAN, with an encoder-decoder structure. The normal and depth image, masked by the estimated silhouette, are passed into 5 sets of convolutional, ReLU, and pooling layers, followed by two fully connected layers, with a final output width of 200. The 200-dimensional vector is passed into a decoder of 5 convolutional and ReLU layers, outputting a 128 x 128 x 128 voxelized estimate of the input.<br />
<br />
== Re-projection Consistency ==<br />
The third step consists of a depth re-projection loss and a surface normal re-projection loss. Here, <math>v_{x, y, z}</math> represents the value at position <math>(x, y, z)</math> in a 3D voxel grid, with <math>v_{x, y, z} \in [0, 1] \; \forall x, y, z</math>. <math>d_{x, y}</math> denotes the estimated depth at position <math>(x, y)</math>, and <math>n_{x, y} = (n_a, n_b, n_c)</math> denotes the estimated surface normal. Orthographic projection is used.<br />
<br />
[[File:marrnet_reprojection_consistency.png|700px|thumb|center|Reprojection consistency for voxels. Left and middle: criteria for depth and silhouettes. Right: criterion for surface normals]]<br />
<br />
=== Depths ===<br />
The voxel at the estimated depth, <math>v_{x, y, d_{x, y}}</math>, should be 1, while all voxels in front of it should be 0. The projected depth loss and its gradient are defined as follows:<br />
<br />
<math><br />
L_{depth}(x, y, z)=<br />
\left\{<br />
\begin{array}{ll}<br />
v^2_{x, y, z}, & z < d_{x, y} \\<br />
(1 - v_{x, y, z})^2, & z = d_{x, y} \\<br />
0, & z > d_{x, y} \\<br />
\end{array}<br />
\right.<br />
</math><br />
<br />
<math><br />
\frac{\partial L_{depth}(x, y, z)}{\partial v_{x, y, z}} =<br />
\left\{<br />
\begin{array}{ll}<br />
2v_{x, y, z}, & z < d_{x, y} \\<br />
2(v_{x, y, z} - 1), & z = d_{x, y} \\<br />
0, & z > d_{x, y} \\<br />
\end{array}<br />
\right.<br />
</math><br />
<br />
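When <math>d_{x, y} = \infty</math>, the line of sight through <math>(x, y)</math> does not intersect the object, so all voxels along it should be 0. In this case the depth criterion reduces to a silhouette criterion.<br />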
<br />
=== Surface Normals ===<br />
Since the vectors <math>n_{x} = (0, -n_{c}, n_{b})</math> and <math>n_{y} = (-n_{c}, 0, n_{a})</math> are orthogonal to the normal vector <math>n_{x, y} = (n_{a}, n_{b}, n_{c})</math>, they can be normalized (dividing by <math>n_{c}</math>) to obtain <math>n'_{x} = (0, -1, n_{b}/n_{c})</math> and <math>n'_{y} = (-1, 0, n_{a}/n_{c})</math>, both of which lie on the estimated surface plane at <math>(x, y, z)</math>. The projected surface normal loss tries to guarantee that the voxels at <math>(x, y, z) \pm n'_{x}</math> and <math>(x, y, z) \pm n'_{y}</math> are 1, to match the estimated normal. The constraints are only applied when the target voxels are inside the estimated silhouette.<br />
<br />
The projected surface normal loss is defined as follows, with <math>z = d_{x, y}</math>:<br />
<br />
<math><br />
L_{normal}(x, y, z) =<br />
(1 - v_{x, y-1, z+\frac{n_b}{n_c}})^2 + (1 - v_{x, y+1, z-\frac{n_b}{n_c}})^2 + <br />
(1 - v_{x-1, y, z+\frac{n_a}{n_c}})^2 + (1 - v_{x+1, y, z-\frac{n_a}{n_c}})^2<br />
</math><br />
<br />
Gradients along x are:<br />
<br />
<math><br />
\frac{dL_{normal}(x, y, z)}{dv_{x-1, y, z+\frac{n_a}{n_c}}} = 2(v_{x-1, y, z+\frac{n_a}{n_c}}-1)<br />
</math><br />
and<br />
<math><br />
\frac{dL_{normal}(x, y, z)}{dv_{x+1, y, z-\frac{n_a}{n_c}}} = 2(v_{x+1, y, z-\frac{n_a}{n_c}}-1)<br />
</math><br />
<br />
Gradients along y are similar to x.<br />
<br />
= Training =<br />
The 2.5D and 3D estimation components are first pre-trained separately on synthetic data from ShapeNet, and then fine-tuned on real images.<br />
<br />
For pre-training, the 2.5D sketch estimator is trained on synthetic ShapeNet depth, surface normal, and silhouette ground truth, using an L2 loss. The 3D estimator is trained with ground truth voxels using a cross-entropy loss.<br />
<br />
The reprojection consistency loss is used to fine-tune the 3D estimator on real images, using the predicted depth, normals, and silhouette. A straightforward implementation leads to shapes that explain the 2.5D sketches well but look unrealistic in 3D due to overfitting.<br />
<br />
Instead, the decoder of the 3D estimator is fixed, and only the encoder is fine-tuned. The model is fine-tuned separately on each image for 40 iterations, which takes up to 10 seconds on the GPU. Without fine-tuning, testing takes around 100 milliseconds. SGD is used for optimization with a batch size of 4, a learning rate of 0.001, and a momentum of 0.9.<br />
<br />
= Evaluation =<br />
Qualitative and quantitative results are provided using different variants of the framework. The framework is evaluated on both synthetic and real images on three datasets.<br />
<br />
== ShapeNet ==<br />
Synthesized images of 6,778 chairs from ShapeNet are rendered from 20 random viewpoints. The chairs are placed in front of random backgrounds from the SUN dataset, and the RGB, depth, normal, and silhouette images are rendered using the physics-based renderer Mitsuba for more realistic images.<br />
<br />
=== Method ===<br />
MarrNet is trained without the final fine-tuning stage, since 3D shapes are available. A baseline is created that directly predicts the 3D shape using the same 3D shape estimator architecture with no 2.5D sketch estimation.<br />
<br />
=== Results ===<br />
The baseline output is compared to the full framework, and the figure below shows that MarrNet provides model outputs with more details and smoother surfaces than the baseline. Quantitatively, the full model also achieves 0.57 IoU, higher than the direct prediction baseline.<br />
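<br />
For reference, voxel IoU (intersection over union) can be computed as in the sketch below. The function name <code>voxel_iou</code>, the flat-list representation, and the 0.5 binarization threshold are assumptions for illustration; the summary above does not state the threshold used by the authors.<br />

```python
def voxel_iou(pred, target, threshold=0.5):
    """Intersection over union between two voxel grids.

    pred      : flat list of predicted occupancy probabilities
    target    : flat list of ground-truth occupancies (0 or 1)
    threshold : assumed cut-off for binarizing the prediction
    """
    intersection = 0
    union = 0
    for p, t in zip(pred, target):
        p_occ = p >= threshold
        t_occ = t >= 0.5
        if p_occ and t_occ:
            intersection += 1
        if p_occ or t_occ:
            union += 1
    # Two empty grids are treated as a perfect match.
    return intersection / union if union else 1.0


iou = voxel_iou([0.9, 0.8, 0.2, 0.6], [1, 0, 0, 1])
```

In this toy example two voxels are occupied in both grids and three are occupied in at least one, giving an IoU of 2/3.<br />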
<br />
[[File:marrnet_shapenet_results.png|700px|thumb|center|ShapeNet results.]]<br />
<br />
== PASCAL 3D+ ==<br />
The PASCAL 3D+ dataset provides real-life images annotated with rough 3D models.<br />
<br />
=== Method ===<br />
Each module is pre-trained on the ShapeNet dataset, and then fine-tuned on the PASCAL 3D+ dataset. Three variants of the model are tested. The first is trained using ShapeNet data only with no fine-tuning. The second is fine-tuned without fixing the decoder. The third is fine-tuned with a fixed decoder.<br />
<br />
=== Results ===<br />
The figure below shows the results of the ablation study. The model trained only on synthetic data provides reasonable estimates. However, fine-tuning without fixing the decoder leads to impossible shapes from certain views. The third model keeps the shape prior, providing more details in the final shape.<br />
<br />
[[File:marrnet_pascal_3d_ablation.png|600px|thumb|center|Ablation studies using the PASCAL 3D+ dataset.]]<br />
<br />
Additional comparisons are made with the state-of-the-art (DRC) on the provided ground truth shapes. MarrNet achieves 0.39 IoU, while DRC achieves 0.34. However, the authors claim that the IoU metric is sub-optimal for three reasons. First, there is no emphasis on details since the metric prefers models that predict mean shapes consistently. Second, all possible scales are searched during the IoU computation, making it less efficient. Third, PASCAL 3D+ only has rough annotations, with only 10 CAD chair models for all images, and computing IoU with these shapes is not very informative. Instead, human studies are conducted, and MarrNet reconstructions are preferred 74% of the time over DRC and 42% of the time over the ground truth. This suggests that MarrNet reconstructions are visually convincing, and it also highlights that the ground truth shapes are only rough approximations.<br />
<br />
[[File:human_studies.png|400px|thumb|center|Human preferences on chairs in PASCAL 3D+ (Xiang et al. 2014). The numbers show the percentage of how often humans preferred the 3D shape from DRC (state-of-the-art), MarrNet, or GT.]]<br />
<br />
<br />
[[File:marrnet_pascal_3d_drc_comparison.png|600px|thumb|center|Comparison between DRC and MarrNet results.]]<br />
<br />
Several failure cases are shown in the figure below. Specifically, the framework does not seem to work well on thin structures.<br />
<br />
[[File:marrnet_pascal_3d_failure_cases.png|500px|thumb|center|Failure cases on PASCAL 3D+. The algorithm cannot recover thin structures.]]<br />
<br />
== IKEA ==<br />
This dataset contains images of IKEA furniture, with accurate 3D shape and pose annotations. Objects are often heavily occluded or truncated.<br />
<br />
=== Results ===<br />
Qualitative results are shown in the figure below. The model is shown to deal with mild occlusions in real-life scenarios. Human studies show that MarrNet reconstructions are preferred 61% of the time over 3D-VAE-GAN.<br />
<br />
[[File:marrnet_ikea_results.png|700px|thumb|center|Results on chairs in the IKEA dataset, and comparison with 3D-VAE-GAN.]]<br />
<br />
== Other Data ==<br />
MarrNet is also applied to cars and airplanes. As shown below, smaller details such as the horizontal stabilizers of airplanes and the rear-view mirrors of cars are recovered.<br />
<br />
[[File:marrnet_airplanes_and_cars.png|700px|thumb|center|Results on airplanes and cars from the PASCAL 3D+ dataset, and comparison with DRC.]]<br />
<br />
MarrNet is also jointly trained on three object categories, and successfully recovers the shapes of different categories. Results are shown in the figure below.<br />
<br />
[[File:marrnet_multiple_categories.png|700px|thumb|center|Results when trained jointly on all three object categories (cars, airplanes, and chairs).]]<br />
<br />
= Commentary =<br />
Qualitatively, the results look quite impressive. The 2.5D sketch estimation seems to distill the useful information for more realistic looking 3D shape estimation. The disentanglement of 2.5D and 3D estimation steps also allows for easier training and domain adaptation from synthetic data.<br />
<br />
As the authors mention, the IoU metric is not very descriptive, and most of the comparisons in this paper are only qualitative, mainly being human preference studies. A better quantitative evaluation metric would greatly help in making an unbiased comparison between different results.<br />
<br />
As seen in several of the results, the network does not deal well with objects that have thin structures, which is particularly noticeable with many of the chair arm rests. Moreover, on closer inspection of some results, fine-tuning only the 3D encoder does not seem to transfer well to unseen objects, since shape priors have already been learned by the decoder.<br />
<br />
= Conclusion =<br />
The proposed MarrNet employs a novel model to estimate 2.5D sketches for 3D shape reconstruction. The sketches are shown to improve the model’s performance, and make it easy to adapt to images across different domains and categories. Differentiable loss functions are created such that the model can be fine-tuned end-to-end on images without ground truth. The experiments show that the model performs well, and human studies show that the results are preferred over other methods.<br />
<br />
= Implementation =<br />
The following repository provides the source code for the paper, as written by the authors: https://github.com/jiajunwu/marrnet<br />
<br />
= References =<br />
# David Marr. Vision: A computational investigation into the human representation and processing of visual information. W. H. Freeman and Company, 1982.<br />
# Shubham Tulsiani, Tinghui Zhou, Alexei A Efros, and Jitendra Malik. Multi-view supervision for single-view reconstruction via differentiable ray consistency. In CVPR, 2017.<br />
# Jiajun Wu, Chengkai Zhang, Tianfan Xue, William T Freeman, and Joshua B Tenenbaum. Learning a Probabilistic Latent Space of Object Shapes via 3D Generative-Adversarial Modeling. In NIPS, 2016.<br />
# Wu, J. (n.d.). Jiajunwu/marrnet. Retrieved March 25, 2018, from https://github.com/jiajunwu/marrnet</div>X249wanghttp://wiki.math.uwaterloo.ca/statwiki/index.php?title=Multi-scale_Dense_Networks_for_Resource_Efficient_Image_Classification&diff=35634Multi-scale Dense Networks for Resource Efficient Image Classification2018-03-27T03:42:12Z<p>X249wang: /* Anytime Prediction */</p>
<hr />
<div>= Introduction = <br />
<br />
Multi-Scale Dense Networks, MSDNets, are designed to address the growing demand for efficient object recognition. Existing recognition networks are either small, efficient networks that do poorly on hard examples, or large networks that do well on all examples but require a large amount of resources.<br />
<br />
To be efficient across all levels of difficulty, MSDNets propose a structure that can output accurate classifications under varying computational budgets. The two settings used to evaluate the network are:<br />
* Anytime prediction: what is the best prediction the network can provide when suddenly prompted?<br />
* Budgeted batch classification: given a maximum amount of computational resources, how well does the network do on the batch?<br />
<br />
= Related Networks =<br />
<br />
== Computationally Efficient Networks ==<br />
<br />
Existing methods for refining an accurate network to be more efficient include weight pruning, quantization of weights (during or after training), and knowledge distillation, which trains a smaller network to match a teacher network.<br />
<br />
== Resource Efficient Networks == <br />
<br />
Unlike the above, resource efficient concepts consider limited resources as a part of the structure/loss.<br />
Examples of work in this area include: <br />
* Efficient variants to existing state of the art networks<br />
* Gradient boosted decision trees, which incorporate computational limitations into the training<br />
* Fractal nets<br />
* Adaptive computation time method<br />
<br />
== Related architectures ==<br />
<br />
MSDNets pull on concepts from a number of existing networks:<br />
* Neural fabrics and similar architectures are used to quickly establish a low-resolution feature map, which is integral for classification.<br />
* Deeply supervised nets introduced the incorporation of multiple classifiers throughout the network.<br />
* The feature concatenation method from DenseNets allows the later classifiers to not be disrupted by the weight updates from earlier classifiers.<br />
<br />
= Problem Setup =<br />
The authors consider two settings that impose computational constraints at prediction time.<br />
<br />
== Anytime Prediction ==<br />
In the anytime prediction setting (Grubb & Bagnell, 2012), there is a finite computational budget <math>B > 0</math> available for each test example <math>x</math>. The budget is nondeterministic and varies per test instance.<br />
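<br />
For a model with several intermediate classifiers, anytime prediction amounts to returning the output of the deepest classifier that fits within the budget. The sketch below is only an illustration of the setting: the function name <code>anytime_predict</code> and the idea of listing cumulative costs per classifier are assumptions, not part of the paper's code.<br />

```python
def anytime_predict(classifier_outputs, cumulative_costs, budget):
    """Return the prediction of the deepest classifier affordable within `budget`.

    classifier_outputs : predictions of the intermediate classifiers, in order
    cumulative_costs   : cost of reaching each classifier (non-decreasing)
    budget             : computational budget B available for this example
    """
    prediction = None  # stays None if even the first classifier is too costly
    for output, cost in zip(classifier_outputs, cumulative_costs):
        if cost > budget:
            break
        prediction = output
    return prediction


label = anytime_predict(["cat", "dog", "dog"], [1.0, 2.5, 4.0], budget=3.0)
```

Here the budget covers the first two classifiers, so the second classifier's prediction is returned.<br />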
<br />
== Budgeted Batch Classification ==<br />
In the budgeted batch classification setting, the model needs to classify a set of examples <math>D_{test} = \{x_1, \ldots, x_M\}</math> within a finite computational budget <math>B > 0</math> that is known in advance.<br />
<br />
= Multi-Scale Dense Networks =<br />
<br />
== Integral Contributions ==<br />
<br />
MSDNets aim to provide efficient classification under varying computational costs by creating one network that outputs results at multiple depths. While this may seem trivial, as intermediate classifiers can be inserted into any existing network, two major problems arise.<br />
<br />
=== Coarse Level Features Needed For Classification ===<br />
<br />
[[File:paper29 fig3.png | 700px|thumb|center]]<br />
<br />
Coarse-level features are needed to capture the context of the scene. In typical CNN-based networks, the features propagate from fine to coarse. Classifiers attached to the early, fine-featured layers do not output accurate predictions due to the lack of context.<br />
<br />
Figure 3 depicts the relative accuracies of the intermediate classifiers and shows that the accuracy of a classifier is highly correlated with its position in the network. It is easy to see, particularly in the case of ResNet, that the classifiers improve in a staircase pattern. All of the experiments were performed on the CIFAR-100 dataset, and the intermediate classifiers perform worse than the final classifier, highlighting the lack of coarse-level features early on.<br />
<br />
To address this issue, MSDNets propose an architecture that uses multi-scale feature maps. The feature maps at a particular layer and scale are computed by concatenating results from up to two convolutions: a standard convolution is first applied to same-scale features from the previous layer to pass on high-resolution information that subsequent layers can use to construct better coarse features, and if possible, a strided convolution is also applied on the finer-scale feature map from the previous layer to produce coarser features amenable to classification. The network is quickly formed to contain a set number of scales ranging from fine to coarse. These scales are propagated throughout, so that for the length of the network there are always coarse-level features for classification and fine features for learning more difficult representations.<br />
<br />
=== Training of Early Classifiers Interferes with Later Classifiers ===<br />
<br />
When training a network containing intermediate classifiers, the training of early classifiers will cause the early layers to focus on features for that classifier. These learned features may not be as useful to the later classifiers and degrade their accuracy.<br />
<br />
MSDNets use dense connectivity to avoid this issue. By concatenating all prior layers to learn future layers, the gradient propagation is spread throughout the available features. This allows later layers to not be reliant on any single prior, providing opportunities to learn new features that priors have ignored.<br />
<br />
== Architecture ==<br />
<br />
[[File:MSDNet_arch.png | 700px|thumb|center|Left: the MSDNet architecture. Right: example calculations for each output given 3 scales and 4 layers.]]<br />
<br />
The architecture of MSDNet is a structure of convolutions with a set number of layers and a set number of scales. Layers allow the network to build on the previous information to generate more accurate predictions, while the scales allow the network to maintain coarse level features throughout.<br />
<br />
The first layer is a special, mini-CNN-network, that quickly fills all required scales with features. The following layers are generated through the convolutions of the previous layers and scales.<br />
<br />
Each output at a given scale <math>s</math> is given by the convolution of all prior outputs of the same scale, and the strided convolution of all prior outputs from the previous (finer) scale. <br />
<br />
The classifiers are run on the concatenation of all of the coarsest outputs from the preceding layers.<br />
<br />
=== Loss Function ===<br />
<br />
The loss is calculated as a weighted sum of each classifier's logistic loss, averaged over a set of training samples. The weights can be determined from a budget of computational power, but results show that setting all weights to 1 is also acceptable.<br />
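<br />
A minimal sketch of this weighted multi-classifier loss is given below, using cross-entropy as the per-classifier logistic loss. The function names and the nested-list layout of the predictions are assumptions for illustration, not the authors' implementation.<br />

```python
import math


def cross_entropy(probs, label):
    """Logistic (cross-entropy) loss of one predicted class distribution."""
    return -math.log(probs[label])


def msdnet_loss(batch_probs, labels, weights):
    """Weighted sum of per-classifier losses, averaged over the batch.

    batch_probs[n][k] : class-probability vector of classifier k for sample n
    labels[n]         : true class index of sample n
    weights[k]        : weight of classifier k (all 1 in the simplest case)
    """
    total = 0.0
    for per_classifier, label in zip(batch_probs, labels):
        total += sum(
            w * cross_entropy(p, label) for w, p in zip(weights, per_classifier)
        )
    return total / len(labels)


# One sample, two classifiers, two classes, uniform weights.
loss = msdnet_loss([[[0.5, 0.5], [0.25, 0.75]]], labels=[1], weights=[1.0, 1.0])
```

With uniform weights every classifier contributes equally, so the early classifiers are trained as seriously as the final one.<br />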
<br />
=== Computational Limit Inclusion ===<br />
<br />
When running in a budgeted batch scenario, the network attempts to provide the best overall accuracy. To do this with a set limit on computational resources, it spends less of the budget on easy examples in order to allow more time to be spent on hard ones. <br />
To facilitate this, each classifier lets a sample exit when the confidence of its classification exceeds a preset threshold. The thresholds must be chosen so that <math>|D_{test}|\sum_{k}(q_k C_k) \leq B </math> holds, where <math>|D_{test}|</math> is the total number of test samples, <math>C_k</math> is the computational requirement to get an output from the <math>k</math>th classifier, and <math>q_k </math> is the probability that a sample exits at the <math>k</math>th classifier. Assuming that every classifier has the same probability <math>q</math> of letting a sample that reaches it exit, <math>q_k = z(1-q)^{k-1}q</math>, where <math>z</math> is a normalizing constant that makes the <math>q_k</math> sum to 1. The budget constraint can then be solved for <math>q</math>, and the thresholds are set on a validation set so that approximately a fraction <math>q_k</math> of the samples exits at the <math>k</math>th classifier.<br />
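<br />
The budget constraint can be solved numerically for the base exit probability <math>q</math>. The sketch below assumes the geometric form <math>q_k = z(1-q)^{k-1}q</math> for the exit probabilities, with <math>z</math> a normalizing constant, and assumes that the costs <math>C_k</math> increase with <math>k</math>; the function names and the use of bisection are illustrative choices, not the authors' code.<br />

```python
def exit_distribution(q, num_classifiers):
    """Normalized exit probabilities q_k = z (1 - q)^(k - 1) q for k = 1..K."""
    raw = [q * (1.0 - q) ** k for k in range(num_classifiers)]
    z = 1.0 / sum(raw)
    return [z * r for r in raw]


def expected_cost(q, costs, num_samples):
    """|D_test| * sum_k q_k C_k for a given base exit probability q."""
    q_k = exit_distribution(q, len(costs))
    return num_samples * sum(p * c for p, c in zip(q_k, costs))


def solve_exit_probability(costs, num_samples, budget, iterations=60):
    """Smallest q in (0, 1) whose expected cost fits the budget (bisection).

    Relies on the expected cost decreasing as q grows, which holds when
    the costs are increasing in k (an assumption of this sketch).
    """
    low, high = 1e-9, 1.0 - 1e-9
    if expected_cost(high, costs, num_samples) > budget:
        raise ValueError("budget too small even if all samples exit first")
    if expected_cost(low, costs, num_samples) <= budget:
        return low
    for _ in range(iterations):
        mid = (low + high) / 2.0
        if expected_cost(mid, costs, num_samples) > budget:
            low = mid
        else:
            high = mid
    return high


costs = [1.0, 2.0, 4.0]  # hypothetical cost of reaching each classifier
q = solve_exit_probability(costs, num_samples=100, budget=150.0)
q_k = exit_distribution(q, len(costs))
```

Once <math>q</math> is known, the per-classifier confidence thresholds would be picked on held-out data so that roughly a fraction <math>q_k</math> of samples exits at classifier <math>k</math>.<br />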
<br />
=== Network Reduction and Lazy Evaluation ===<br />
There are two ways to reduce the computational needs of MSDNets:<br />
<br />
# Reduce the size of the network by splitting it into <math>S</math> blocks along the depth dimension and keeping only the coarsest <math>(S-i+1)</math> scales in the <math>i^{\text{th}}</math> block.<br />
# Remove unnecessary computations: Group the computation in "diagonal blocks"; this propagates the example along paths that are required for the evaluation of the next classifier.<br />
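<br />
The first reduction can be made concrete with a small sketch that lists which scales survive in each block. The function name and the convention that scale 1 is the finest and scale <math>S</math> the coarsest are assumptions for illustration.<br />

```python
def scales_kept_per_block(num_scales):
    """Scales kept in each of the S blocks (scale 1 = finest, S = coarsest).

    Block i keeps S - i + 1 scales; this sketch assumes they are the
    coarsest ones, i.e. scales i..S.
    """
    return {
        block: list(range(block, num_scales + 1))
        for block in range(1, num_scales + 1)
    }


kept = scales_kept_per_block(3)
```

For <math>S = 3</math>, the first block keeps all three scales, the second drops the finest one, and the last block keeps only the coarsest scale.<br />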
<br />
= Experiments = <br />
<br />
When evaluating on CIFAR-10 and CIFAR-100, ensembles and multi-classifier versions of ResNets and DenseNets, as well as FractalNet, are compared with MSDNet. <br />
<br />
When evaluating on ImageNet, ensembles and individual versions of ResNets and DenseNets are compared with MSDNets.<br />
<br />
== Anytime Prediction ==<br />
<br />
In anytime prediction, MSDNets are shown to be highly accurate with very little budget, and they remain above the alternative methods as the budget increases. The authors attribute this to the fact that MSDNets are able to produce low-resolution feature maps well-suited for classification after just a few layers, in contrast to the high-resolution feature maps in early layers of ResNets or DenseNets. Ensemble networks need to recompute similar low-level features each time a new model is evaluated, so their accuracy does not increase as fast when the computational budget increases. <br />
<br />
[[File:MSDNet_anytime.png | 700px|thumb|center|Accuracy of the anytime classification models.]]<br />
<br />
== Budget Batch ==<br />
<br />
For budgeted batch classification, three MSDNets are designed with classifiers set up for different ranges of budget constraints. On both dataset options, the MSDNets exceed all alternative methods with a fraction of the budget required.<br />
<br />
[[File:MSDNet_budgetbatch.png | 700px|thumb|center|Accuracy of the budget batch classification models.]]<br />
<br />
= Critique = <br />
<br />
The problem formulation and evaluation scenarios were well designed, and according to independent reviews, the results were reproducible. The paper could improve its explanation of how to implement the threshold; it does not clearly explain how the validation set is used to set the threshold value.<br />
<br />
= Implementation =<br />
The following repository provides the source code for the paper, written by the authors: https://github.com/gaohuang/MSDNet<br />
<br />
= Sources =<br />
# Huang, G., Chen, D., Li, T., Wu, F., van der Maaten, L., & Weinberger, K. Q. (2018). Multi-Scale Dense Networks for Resource Efficient Image Classification. ICLR 2018. arXiv:1703.09844 <br />
# Huang, G. (n.d.). Gaohuang/MSDNet. Retrieved March 25, 2018, from https://github.com/gaohuang/MSDNet</div>X249wanghttp://wiki.math.uwaterloo.ca/statwiki/index.php?title=Multi-scale_Dense_Networks_for_Resource_Efficient_Image_Classification&diff=35631Multi-scale Dense Networks for Resource Efficient Image Classification2018-03-27T03:34:19Z<p>X249wang: /* Anytime Prediction */</p>
<hr />
<div>= Introduction = <br />
<br />
Multi-Scale Dense Networks, MSDNets, are designed to address the growing demand for efficient object recognition. Existing recognition networks are either small, efficient networks that do poorly on hard examples, or large networks that do well on all examples but require a large amount of resources.<br />
<br />
To be efficient across all levels of difficulty, MSDNets propose a structure that can output accurate classifications under varying computational budgets. The two settings used to evaluate the network are:<br />
* Anytime prediction: what is the best prediction the network can provide when suddenly prompted?<br />
* Budgeted batch classification: given a maximum amount of computational resources, how well does the network do on the batch?<br />
<br />
= Related Networks =<br />
<br />
== Computationally Efficient Networks ==<br />
<br />
Existing methods for refining an accurate network to be more efficient include weight pruning, quantization of weights (during or after training), and knowledge distillation, which trains a smaller network to match a teacher network.<br />
<br />
== Resource Efficient Networks == <br />
<br />
Unlike the above, resource efficient concepts consider limited resources as a part of the structure/loss.<br />
Examples of work in this area include: <br />
* Efficient variants to existing state of the art networks<br />
* Gradient boosted decision trees, which incorporate computational limitations into the training<br />
* Fractal nets<br />
* Adaptive computation time method<br />
<br />
== Related architectures ==<br />
<br />
MSDNets pull on concepts from a number of existing networks:<br />
* Neural fabrics and similar architectures are used to quickly establish a low-resolution feature map, which is integral for classification.<br />
* Deeply supervised nets introduced the incorporation of multiple classifiers throughout the network.<br />
* The feature concatenation method from DenseNets allows the later classifiers to not be disrupted by the weight updates from earlier classifiers.<br />
<br />
= Problem Setup =<br />
The authors consider two settings that impose computational constraints at prediction time.<br />
<br />
== Anytime Prediction ==<br />
In the anytime prediction setting (Grubb & Bagnell, 2012), there is a finite computational budget <math>B > 0</math> available for each test example <math>x</math>. The budget is nondeterministic and varies per test instance.<br />
<br />
== Budgeted Batch Classification ==<br />
In the budgeted batch classification setting, the model needs to classify a set of examples <math>D_{test} = \{x_1, \ldots, x_M\}</math> within a finite computational budget <math>B > 0</math> that is known in advance.<br />
<br />
= Multi-Scale Dense Networks =<br />
<br />
== Integral Contributions ==<br />
<br />
MSDNets aim to provide efficient classification under varying computational costs by creating one network that outputs results at multiple depths. While this may seem trivial, as intermediate classifiers can be inserted into any existing network, two major problems arise.<br />
<br />
=== Coarse Level Features Needed For Classification ===<br />
<br />
[[File:paper29 fig3.png | 700px|thumb|center]]<br />
<br />
Coarse-level features are needed to capture the context of the scene. In typical CNN-based networks, the features propagate from fine to coarse. Classifiers attached to the early, fine-featured layers do not output accurate predictions due to the lack of context.<br />
<br />
Figure 3 depicts the relative accuracies of the intermediate classifiers and shows that the accuracy of a classifier is highly correlated with its position in the network. It is easy to see, particularly in the case of ResNet, that the classifiers improve in a staircase pattern. All of the experiments were performed on the CIFAR-100 dataset, and the intermediate classifiers perform worse than the final classifier, highlighting the lack of coarse-level features early on.<br />
<br />
To address this issue, MSDNets propose an architecture that uses multi-scale feature maps. The feature maps at a particular layer and scale are computed by concatenating results from up to two convolutions: a standard convolution is first applied to same-scale features from the previous layer to pass on high-resolution information that subsequent layers can use to construct better coarse features, and if possible, a strided convolution is also applied on the finer-scale feature map from the previous layer to produce coarser features amenable to classification. The network is quickly formed to contain a set number of scales ranging from fine to coarse. These scales are propagated throughout, so that for the length of the network there are always coarse-level features for classification and fine features for learning more difficult representations.<br />
<br />
=== Training of Early Classifiers Interferes with Later Classifiers ===<br />
<br />
When training a network containing intermediate classifiers, the training of early classifiers will cause the early layers to focus on features for that classifier. These learned features may not be as useful to the later classifiers and degrade their accuracy.<br />
<br />
MSDNets use dense connectivity to avoid this issue. By concatenating all prior layers to learn future layers, the gradient propagation is spread throughout the available features. This allows later layers to not be reliant on any single prior, providing opportunities to learn new features that priors have ignored.<br />
<br />
== Architecture ==<br />
<br />
[[File:MSDNet_arch.png | 700px|thumb|center|Left: the MSDNet architecture. Right: example calculations for each output given 3 scales and 4 layers.]]<br />
<br />
The architecture of MSDNet is a structure of convolutions with a set number of layers and a set number of scales. Layers allow the network to build on the previous information to generate more accurate predictions, while the scales allow the network to maintain coarse level features throughout.<br />
<br />
The first layer is a special, mini-CNN-network, that quickly fills all required scales with features. The following layers are generated through the convolutions of the previous layers and scales.<br />
<br />
Each output at a given scale <math>s</math> is given by the convolution of all prior outputs of the same scale, and the strided convolution of all prior outputs from the previous (finer) scale. <br />
<br />
The classifiers are run on the concatenation of all of the coarsest outputs from the preceding layers.<br />
<br />
=== Loss Function ===<br />
<br />
The loss is calculated as a weighted sum of each classifier's logistic loss, averaged over a set of training samples. The weights can be determined from a budget of computational power, but results show that setting all weights to 1 is also acceptable.<br />
<br />
=== Computational Limit Inclusion ===<br />
<br />
When running in a budgeted batch scenario, the network attempts to provide the best overall accuracy. To do this with a set limit on computational resources, it spends less of the budget on easy examples in order to allow more time to be spent on hard ones. <br />
To facilitate this, each classifier lets a sample exit when the confidence of its classification exceeds a preset threshold. The thresholds must be chosen so that <math>|D_{test}|\sum_{k}(q_k C_k) \leq B </math> holds, where <math>|D_{test}|</math> is the total number of test samples, <math>C_k</math> is the computational requirement to get an output from the <math>k</math>th classifier, and <math>q_k </math> is the probability that a sample exits at the <math>k</math>th classifier. Assuming that every classifier has the same probability <math>q</math> of letting a sample that reaches it exit, <math>q_k = z(1-q)^{k-1}q</math>, where <math>z</math> is a normalizing constant that makes the <math>q_k</math> sum to 1. The budget constraint can then be solved for <math>q</math>, and the thresholds are set on a validation set so that approximately a fraction <math>q_k</math> of the samples exits at the <math>k</math>th classifier.<br />
<br />
=== Network Reduction and Lazy Evaluation ===<br />
There are two ways to reduce the computational needs of MSDNets:<br />
<br />
# Reduce the size of the network by splitting it into <math>S</math> blocks along the depth dimension and keeping only the coarsest <math>(S-i+1)</math> scales in the <math>i^{\text{th}}</math> block.<br />
# Remove unnecessary computations: Group the computation in "diagonal blocks"; this propagates the example along paths that are required for the evaluation of the next classifier.<br />
<br />
= Experiments = <br />
<br />
When evaluating on CIFAR-10 and CIFAR-100, ensembles and multi-classifier versions of ResNets and DenseNets, as well as FractalNet, are compared with MSDNet. <br />
<br />
When evaluating on ImageNet, ensembles and individual versions of ResNets and DenseNets are compared with MSDNets.<br />
<br />
== Anytime Prediction ==<br />
<br />
In anytime prediction, MSDNets are shown to be highly accurate with very little budget, and they remain above the alternative methods as the budget increases. MSDNets produced low-resolution feature maps well-suited for classification after just a few layers, in contrast to the high-resolution feature maps in early layers of ResNets or DenseNets. <br />
<br />
[[File:MSDNet_anytime.png | 700px|thumb|center|Accuracy of the anytime classification models.]]<br />
<br />
== Budget Batch ==<br />
<br />
For budgeted batch classification, three MSDNets are designed with classifiers set up for different ranges of budget constraints. On both dataset options, the MSDNets exceed all alternative methods with a fraction of the budget required.<br />
<br />
[[File:MSDNet_budgetbatch.png | 700px|thumb|center|Accuracy of the budget batch classification models.]]<br />
<br />
= Critique = <br />
<br />
The problem formulation and evaluation scenarios were well designed, and according to independent reviews, the results were reproducible. The paper could improve its explanation of how to implement the threshold; it does not clearly explain how the validation set is used to set the threshold value.<br />
<br />
= Implementation =<br />
The following repository provides the source code for the paper, written by the authors: https://github.com/gaohuang/MSDNet<br />
<br />
= Sources =<br />
# Huang, G., Chen, D., Li, T., Wu, F., van der Maaten, L., & Weinberger, K. Q. (2018). Multi-Scale Dense Networks for Resource Efficient Image Classification. ICLR 2018. arXiv:1703.09844 <br />
# Huang, G. (n.d.). Gaohuang/MSDNet. Retrieved March 25, 2018, from https://github.com/gaohuang/MSDNet</div>X249wanghttp://wiki.math.uwaterloo.ca/statwiki/index.php?title=Multi-scale_Dense_Networks_for_Resource_Efficient_Image_Classification&diff=35630Multi-scale Dense Networks for Resource Efficient Image Classification2018-03-27T03:21:24Z<p>X249wang: /* Coarse Level Features Needed For Classification */</p>
<hr />
<div>= Introduction = <br />
<br />
Multi-Scale Dense Networks, MSDNets, are designed to address the growing demand for efficient object recognition. Existing recognition networks are either small, efficient networks that do poorly on hard examples, or large networks that do well on all examples but require a large amount of resources.<br />
<br />
To be efficient across all levels of difficulty, MSDNets propose a structure that can output accurate classifications under varying computational budgets. The two settings used to evaluate the network are:<br />
* Anytime prediction: what is the best prediction the network can provide when suddenly prompted?<br />
* Budgeted batch classification: given a maximum amount of computational resources, how well does the network do on the batch?<br />
<br />
= Related Networks =<br />
<br />
== Computationally Efficient Networks ==<br />
<br />
Existing methods for refining an accurate network to be more efficient include weight pruning, quantization of weights (during or after training), and knowledge distillation, which trains a smaller network to match a teacher network.<br />
<br />
== Resource Efficient Networks == <br />
<br />
Unlike the above, resource efficient concepts consider limited resources as a part of the structure/loss.<br />
Examples of work in this area include: <br />
* Efficient variants to existing state of the art networks<br />
* Gradient boosted decision trees, which incorporate computational limitations into the training<br />
* Fractal nets<br />
* Adaptive computation time method<br />
<br />
== Related architectures ==<br />
<br />
MSDNets pull on concepts from a number of existing networks:<br />
* Neural fabrics and similar architectures are used to quickly establish a low-resolution feature map, which is integral for classification.<br />
* Deeply supervised nets introduced the incorporation of multiple classifiers throughout the network.<br />
* The feature concatenation method from DenseNets allows the later classifiers to not be disrupted by the weight updates from earlier classifiers.<br />
<br />
= Problem Setup =<br />
The authors consider two settings that impose computational constraints at prediction time.<br />
<br />
== Anytime Prediction ==<br />
In the anytime prediction setting (Grubb & Bagnell, 2012), there is a finite computational budget <math>B > 0</math> available for each test example <math>x</math>. The budget is nondeterministic and varies per test instance.<br />
<br />
== Budgeted Batch Classification ==<br />
In the budgeted batch classification setting, the model needs to classify a set of examples <math>D_{test} = \{x_1, \ldots, x_M\}</math> within a finite computational budget <math>B > 0</math> that is known in advance.<br />
<br />
= Multi-Scale Dense Networks =<br />
<br />
== Integral Contributions ==<br />
<br />
MSDNets aim to provide efficient classification at varying computational costs by using a single network that outputs results at several depths. While this may seem trivial, since intermediate classifiers can be inserted into any existing network, two major problems arise.<br />
<br />
=== Coarse Level Features Needed For Classification ===<br />
<br />
[[File:paper29 fig3.png | 700px|thumb|center]]<br />
<br />
Coarse level features are needed to capture the context of the scene. In typical CNN-based networks, the features propagate from fine to coarse. Classifiers attached to the early, fine-featured layers do not output accurate predictions because they lack this context.<br />
<br />
Figure 3 depicts the relative accuracies of the intermediate classifiers and shows that the accuracy of a classifier is highly correlated with its position in the network. It is easy to see, particularly for ResNet, that the classifiers improve in a staircase pattern. All of the experiments were performed on the CIFAR-100 dataset, and the intermediate classifiers perform worse than the final classifiers, which highlights the problem caused by the lack of coarse level features early on.<br />
<br />
To address this issue, MSDNets propose an architecture that uses multi-scale feature maps. The feature maps at a particular layer and scale are computed by concatenating results from up to two convolutions: a standard convolution is first applied to same-scale features from the previous layer to pass on high-resolution information that subsequent layers can use to construct better coarse features, and, if possible, a strided convolution is also applied to the finer-scale feature map from the previous layer to produce coarser features amenable to classification. The network is quickly formed to contain a set number of scales ranging from fine to coarse. These scales are propagated throughout, so that along the whole length of the network there are always coarse level features for classification and fine features for learning more difficult representations.<br />
<br />
=== Training of Early Classifiers Interferes with Later Classifiers ===<br />
<br />
When training a network containing intermediate classifiers, the training of early classifiers will cause the early layers to focus on features for that classifier. These learned features may not be as useful to the later classifiers and degrade their accuracy.<br />
<br />
MSDNets use dense connectivity to avoid this issue. Because each layer takes the concatenation of all prior layers' features as input, the gradient is spread across all available features. Later layers therefore do not rely on any single prior layer and can learn new features that earlier layers have ignored.<br />
<br />
== Architecture ==<br />
<br />
[[File:MSDNet_arch.png | 700px|thumb|center|Left: the MSDNet architecture. Right: example calculations for each output given 3 scales and 4 layers.]]<br />
<br />
The architecture of MSDNet is a structure of convolutions with a set number of layers and a set number of scales. Layers allow the network to build on the previous information to generate more accurate predictions, while the scales allow the network to maintain coarse level features throughout.<br />
<br />
The first layer is a special, mini-CNN-network, that quickly fills all required scales with features. The following layers are generated through the convolutions of the previous layers and scales.<br />
<br />
Each output at a given scale <math>s</math> is the concatenation of a convolution of all prior outputs at the same scale and a strided convolution of all prior outputs at the previous (finer) scale. <br />
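<br />
As a sketch in symbols (the notation <math>x_\ell^s</math>, <math>h_\ell^s</math> and <math>\tilde{h}_\ell^s</math> is assumed here for illustration and is not defined elsewhere in this summary), the output of layer <math>\ell > 1</math> at scale <math>s > 1</math> can be written as<br />
<br />
<math>x_\ell^s = \left[\, h_\ell^s\left(\left[x_1^s, \ldots, x_{\ell-1}^s\right]\right),\; \tilde{h}_\ell^s\left(\left[x_1^{s-1}, \ldots, x_{\ell-1}^{s-1}\right]\right) \right]</math><br />
<br />
where <math>[\ldots]</math> denotes concatenation, <math>h_\ell^s</math> is a regular convolution and <math>\tilde{h}_\ell^s</math> is a strided convolution. At the finest scale, <math>s = 1</math>, there is no finer scale to draw from, so only the first term is present.<br />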
<br />
The classifiers are run on the concatenation of all of the coarsest outputs from the preceding layers.<br />
<br />
=== Loss Function ===<br />
<br />
The loss is calculated as a weighted sum of each classifier's logistic loss, averaged over a set of training samples. The weights can be chosen according to a computational budget, but the results show that setting all of them to 1 is also acceptable.<br />
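<br />
The weighted sum described above can be sketched in a few lines of Python. This is only an illustration under stated assumptions: the logistic loss is taken to be the usual cross-entropy on class probabilities, and the names <code>cross_entropy</code> and <code>msdnet_loss</code> are hypothetical, not taken from the authors' code.<br />

```python
import math


def cross_entropy(probs, label):
    """Negative log-probability assigned to the true class (assumed loss form)."""
    return -math.log(probs[label])


def msdnet_loss(batch_outputs, labels, weights):
    """Weighted sum of per-classifier losses, averaged over the batch.

    batch_outputs[n][k] is the list of class probabilities produced by
    classifier k for sample n; weights[k] is the weight of classifier k.
    """
    total = 0.0
    for outputs, label in zip(batch_outputs, labels):
        total += sum(w * cross_entropy(p, label) for w, p in zip(weights, outputs))
    return total / len(labels)


# One sample, two classifiers, all weights set to 1.
example_loss = msdnet_loss([[[0.5, 0.5], [0.25, 0.75]]], [1], [1.0, 1.0])
```

With all weights equal to 1, every classifier contributes equally, which matches the simple setting mentioned above.<br />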
<br />
=== Computational Limit Inclusion ===<br />
<br />
When running in a budgeted batch scenario, the network attempts to provide the best overall accuracy. To do this with a set limit on computational resources, it spends less of the budget on easy examples so that more can be spent on hard ones. <br />
To facilitate this, each classifier lets a sample exit when the confidence of its classification exceeds a preset threshold. The thresholds are chosen so that the budget constraint <math>|D_{test}|\sum_{k}(q_k C_k) \leq B </math> holds, where <math>|D_{test}|</math> is the total number of test samples, <math>C_k</math> is the computational requirement to get an output from the <math>k</math>th classifier, and <math>q_k </math> is the probability that a sample exits at the <math>k</math>th classifier. Assuming that all classifiers have the same base exit probability <math>q</math>, each <math>q_k</math> can be expressed in terms of <math>q</math>, which in turn determines the threshold for each classifier.<br />
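<br />
The budget constraint can be checked numerically. The sketch below assumes a geometric exit distribution, <math>q_k \propto q(1-q)^{k-1}</math>, normalised to sum to 1; this is one natural reading of "the same base probability" and should be checked against the paper. The function names, the grid search and the example costs are all hypothetical.<br />

```python
def exit_probabilities(q, num_classifiers):
    """q_k proportional to q * (1 - q)^(k - 1), normalised to sum to 1 (assumed form)."""
    raw = [q * (1.0 - q) ** k for k in range(num_classifiers)]
    z = 1.0 / sum(raw)
    return [z * r for r in raw]


def expected_cost(q, costs, num_test):
    """Left-hand side of the constraint: |D_test| * sum_k q_k * C_k."""
    probs = exit_probabilities(q, len(costs))
    return num_test * sum(p * c for p, c in zip(probs, costs))


def smallest_feasible_q(costs, num_test, budget, steps=1000):
    """Grid-search the smallest q in (0, 1) whose expected cost fits the budget.

    A smaller q sends more samples to deeper classifiers, so the smallest
    feasible q uses as much of the budget as the grid allows.
    Returns None if no q on the grid satisfies the constraint.
    """
    for i in range(1, steps):
        q = i / steps
        if expected_cost(q, costs, num_test) <= budget:
            return q
    return None


# Four classifiers with increasing (made-up) costs, 100 test samples.
costs = [1.0, 2.0, 3.0, 4.0]
q_star = smallest_feasible_q(costs, num_test=100, budget=200.0)
```

Once <math>q</math> is fixed, the confidence threshold of each classifier would be picked so that roughly a fraction <math>q_k</math> of held-out samples exits there; that step is not shown here.<br />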
<br />
=== Network Reduction and Lazy Evaluation ===<br />
There are two ways to reduce the computational needs of MSDNets:<br />
<br />
# Reduce the size of the network by splitting it into <math>S</math> blocks along the depth dimension and keeping only the coarsest <math>(S-i+1)</math> scales in the <math>i^{\text{th}}</math> block.<br />
# Remove unnecessary computations: group the computation in "diagonal blocks" so that an example is propagated only along the paths required for the evaluation of the next classifier.<br />
<br />
= Experiments = <br />
<br />
When evaluating on CIFAR-10 and CIFAR-100, ensembles and multi-classifier versions of ResNets and DenseNets, as well as FractalNet, are used for comparison with MSDNet. <br />
<br />
When evaluating on ImageNet, ensembles and individual versions of ResNets and DenseNets are compared with MSDNets.<br />
<br />
== Anytime Prediction ==<br />
<br />
In anytime prediction, MSDNets are shown to be highly accurate with very little budget, and they remain above the alternative methods as the budget increases.<br />
<br />
[[File:MSDNet_anytime.png | 700px|thumb|center|Accuracy of the anytime classification models.]]<br />
<br />
== Budget Batch ==<br />
<br />
In the budgeted batch setting, three MSDNets are designed, with classifiers set up for different ranges of budget constraints. On both dataset options, the MSDNets exceed all alternative methods while using only a fraction of the budget.<br />
<br />
[[File:MSDNet_budgetbatch.png | 700px|thumb|center|Accuracy of the budget batch classification models.]]<br />
<br />
= Critique = <br />
<br />
The problem formulation and the evaluation scenarios were well designed, and according to independent reviews, the results were reproducible. The paper could improve its explanation of how to implement the threshold; it does not explain well how the validation set is used to set the threshold value.<br />
<br />
= Implementation =<br />
The following repository provides the source code for the paper, written by the authors: https://github.com/gaohuang/MSDNet<br />
<br />
= Sources =<br />
# Huang, G., Chen, D., Li, T., Wu, F., van der Maaten, L., & Weinberger, K. Q. (2018). Multi-Scale Dense Networks for Resource Efficient Image Classification. ICLR 2018. arXiv:1703.09844 <br />
# Huang, G. (n.d.). Gaohuang/MSDNet. Retrieved March 25, 2018, from https://github.com/gaohuang/MSDNet</div>X249wanghttp://wiki.math.uwaterloo.ca/statwiki/index.php?title=PointNet%2B%2B:_Deep_Hierarchical_Feature_Learning_on_Point_Sets_in_a_Metric_Space&diff=35628PointNet++: Deep Hierarchical Feature Learning on Point Sets in a Metric Space2018-03-27T03:00:24Z<p>X249wang: /* Problem Statement */</p>
<hr />
<div>= Introduction =<br />
This paper builds off of ideas from PointNet (Qi et al., 2017). The name PointNet is derived from the network's input - a point cloud. A point cloud is a set of three dimensional points that each have coordinates <math> (x,y,z) </math>. These coordinates usually represent the surface of an object. For example, a point cloud describing the shape of a torus is shown below.<br />
<br />
[[File:Point cloud torus.gif|thumb|center|Point cloud torus]]<br />
<br />
<br />
Processing point clouds is important in applications such as autonomous driving where point clouds are collected from an onboard LiDAR sensor. These point clouds can then be used for object detection. However, point clouds are challenging to process because:<br />
<br />
# They are unordered. If <math> N </math> is the number of points in a point cloud, then there are <math> N! </math> orderings in which the same point cloud can be represented.<br />
# The spatial arrangement of the points contains useful information, thus it needs to be encoded.<br />
# The function processing the point cloud needs to be invariant to transformations such as rotations and translations of all points. <br />
<br />
Previously, typical point cloud processing methods handled the challenges of point clouds by transforming the data with a 3D voxel grid or by representing the point cloud with multiple 2D images. When PointNet was introduced, it was novel because it directly took points as its input. PointNet++ improves on PointNet by using a hierarchical method to better capture local structures of the point cloud. <br />
<br />
[[File:point_cloud.png | 400px|thumb|center|Examples of point clouds and their associated task. Classification (left), part segmentation (centre), scene segmentation (right) ]]<br />
<br />
= Review of PointNet =<br />
<br />
The PointNet architecture is shown below. The input of the network is <math> n </math> points, each of which has <math> (x,y,z) </math> coordinates. Each point is processed individually through a multi-layer perceptron (MLP). This network creates an encoding for each point; in the diagram, each point is represented by a 1024-dimensional vector. Then, a max pool layer creates a vector that represents the "global signature" of the point cloud. If classification is the task, this global signature is processed by another MLP to compute the classification scores. If segmentation is the task, this global signature is appended to each point from the "nx64" layer, and these points are processed by an MLP to compute a semantic category score for each point.<br />
<br />
The core idea of the network is to learn a symmetric function on transformed points. Through the T-Nets and the MLP network, a transformation is learned with the hopes of making points invariant to point cloud transformations. Learning a symmetric function solves the challenge imposed by having unordered points; a symmetric function will produce the same value no matter the order of the input. This symmetric function is represented by the max pool layer.<br />
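<br />
The order-invariance argument can be illustrated with a toy example. In the sketch below, <code>per_point_features</code> is a hypothetical stand-in for the shared MLP (a fixed hand-written function, not a learned network), and the column-wise maximum plays the role of the max pool layer.<br />

```python
import random


def per_point_features(point):
    """Hypothetical stand-in for the shared per-point MLP."""
    x, y, z = point
    return [x + y, y * z, x - z, x * x]


def global_signature(points):
    """Column-wise max over per-point features: a symmetric function of the set."""
    feats = [per_point_features(p) for p in points]
    return [max(column) for column in zip(*feats)]


cloud = [(0.0, 1.0, 2.0), (1.5, -0.5, 0.3), (-2.0, 0.7, 1.1), (0.4, 0.4, -0.9)]
shuffled = cloud[:]
random.Random(0).shuffle(shuffled)

# The signature is the same regardless of the order of the points.
same = global_signature(cloud) == global_signature(shuffled)
```

Because each point is encoded independently and the maximum does not depend on order, any permutation of the input yields the same global signature.<br />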
<br />
[[File:pointnet_arch.png | 700px|thumb|center|PointNet architecture. The blue highlighted region is when it is used for classification, and the beige highlighted region is when it is used for segmentation.]]<br />
<br />
= PointNet++ =<br />
<br />
The motivation for PointNet++ is that PointNet does not capture local, fine-grained details. Since PointNet performs a max pool layer over all of its points, information such as the local interaction between points is lost.<br />
<br />
== Problem Statement ==<br />
<br />
Suppose <math> X = (M,d) </math> is a discrete metric space whose metric <math>d</math> is inherited from a Euclidean space <math>\pmb{\mathbb{R}}^n</math>, where <math> M \subseteq \pmb{\mathbb{R}}^n </math> is the set of points. The goal is to learn functions <math>f</math> that take <math>X</math> as the input and produce information of semantic interest about it. In practice, <math>f</math> can be a classification function that outputs a class label for <math>X</math> or a segmentation function that outputs a per-point label for each member of <math>M</math>.<br />
<br />
== Method ==<br />
<br />
=== High Level Overview ===<br />
[[File:point_net++.png | 700px|thumb|right|PointNet++ architecture]]<br />
<br />
The PointNet++ architecture is shown on the right. The core idea is that a hierarchical architecture is used, and at each level of the hierarchy a set of points is processed and abstracted to a new set with fewer points, i.e.,<br />
<br />
\begin{aligned}<br />
\text{Input at each level: } N \times (d + c) \text{ matrix}<br />
\end{aligned}<br />
<br />
where <math>N</math> is the number of points, <math>d</math> is the dimension of the point coordinates (3 for <math>(x,y,z)</math>) and <math>c</math> is the dimension of each point's feature vector, and<br />
<br />
\begin{aligned}<br />
\text{Output at each level: } N' \times (d + c') \text{ matrix}<br />
\end{aligned}<br />
<br />
where <math>N'</math> is the new (smaller) number of points and <math>c'</math> is the dimension of the new feature vector.<br />
<br />
<br />
Each level has three layers: Sampling, Grouping, and PointNet. The Sampling layer selects points that will act as centroids of local regions within the point cloud. The Grouping layer then finds points near these centroids. Lastly, the PointNet layer performs PointNet on each group to encode local information.<br />
<br />
=== Sampling Layer ===<br />
<br />
The input of this layer is a set of points <math>\{x_1, x_2, \ldots, x_n\}</math>. The goal of this layer is to select a subset of these points <math>\{\hat{x}_1, \hat{x}_2, \ldots, \hat{x}_m\}</math> that will define the centroids of local regions.<br />
<br />
To select these points, farthest point sampling is used: <math>\hat{x}_j</math> is the point that is most distant from the already selected set <math>\{\hat{x}_1, \hat{x}_2, \ldots, \hat{x}_{j-1}\}</math>. Compared with random sampling, this gives better coverage of the entire point cloud.<br />
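<br />
Farthest point sampling can be sketched with a simple greedy loop. This is a minimal illustration using only the Python standard library; the function name, the choice of the first point (<code>start</code>) and the example data are assumptions, and it expects <code>m</code> to be no larger than the number of points.<br />

```python
import math


def farthest_point_sampling(points, m, start=0):
    """Greedily pick m point indices, each farthest from those already chosen."""
    selected = [start]
    # min_dist[i] is the distance from point i to its nearest selected centroid.
    min_dist = [math.dist(p, points[start]) for p in points]
    while len(selected) < m:
        nxt = max(range(len(points)), key=lambda i: min_dist[i])
        selected.append(nxt)
        for i, p in enumerate(points):
            min_dist[i] = min(min_dist[i], math.dist(p, points[nxt]))
    return selected


pts = [(0.0, 0.0), (1.0, 0.0), (10.0, 0.0), (10.0, 1.0), (5.0, 5.0)]
centroid_ids = farthest_point_sampling(pts, m=3)
```

In this example the second centroid is the point farthest from the starting point, and the third is the point farthest from both, so the three centroids are spread across the cloud rather than clustered.<br />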
<br />
=== Grouping Layer ===<br />
<br />
The objective of the grouping layer is to form local regions around each centroid by grouping points near the selected centroids. The input is a point set of size <math>N \times (d + c)</math> and the coordinates of the centroids <math>N' \times d</math>. The output is the groups of points within each region <math>N' \times k \times (d+c)</math> where <math>k</math> is the number of points in each region.<br />
<br />
Note that <math>k</math> can vary per group. Later, the PointNet layer creates a feature vector that is the same size for all regions at a hierarchical level.<br />
<br />
To determine which points belong to a group, a ball query is used: all points within a given radius of the centroid are grouped. This is advantageous over nearest-neighbour search because it guarantees a fixed region scale, which is important when learning local structure.<br />
<br />
=== PointNet Layer ===<br />
<br />
After grouping, PointNet is applied to the points. However, first the coordinates of points in a local region are converted to a local coordinate frame by <math> x_i = x_i - \bar{x}</math> where <math>\bar{x}</math> is the coordinates of the centroid.<br />
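<br />
The ball query and the conversion to a local coordinate frame can be sketched together as follows. This is an illustrative brute-force version; the function names and the example points are hypothetical, and a practical implementation would also cap the number of points per group.<br />

```python
import math


def ball_query(points, centroids, radius):
    """For each centroid, collect every point within `radius` of it."""
    return [[p for p in points if math.dist(p, c) <= radius] for c in centroids]


def to_local_frame(group, centroid):
    """Translate the points of one group so that the centroid is the origin."""
    return [tuple(pi - ci for pi, ci in zip(p, centroid)) for p in group]


pts = [(0.0, 0.0), (1.0, 0.0), (0.0, 2.0), (5.0, 5.0), (6.0, 5.0)]
centroids = [(0.0, 0.0), (5.0, 5.0)]

groups = ball_query(pts, centroids, radius=1.5)
local_groups = [to_local_frame(g, c) for g, c in zip(groups, centroids)]
```

After re-centering, both groups in this example have the same local coordinates, which shows how the local frame lets the PointNet layer see the same pattern regardless of where it occurs in the cloud.<br />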
<br />
=== Robust Feature Learning under Non-Uniform Sampling Density ===<br />
<br />
The previous description of grouping uses a single scale. This is not optimal because the density varies per section of the point cloud. At each level, it would be better if the PointNet layer was applied to adaptively sized groups depending on the point cloud density.<br />
<br />
The two grouping methods the authors propose are shown in the diagram below. Multi-scale grouping (MSG) applies PointNet at various scales per group. The features from the various scales are concatenated to form a multi-scale feature. To train the network to learn an optimal strategy for combining the multi-scale features, the authors proposed random input dropout: for each training point set, a dropout ratio <math>\theta</math> is sampled uniformly from <math>[0, p]</math>, and each input point is then dropped independently with probability <math>\theta</math>. The authors used <math>p = 0.95</math> to avoid generating empty point sets. As shown in the experiments section below, dropout provides robustness to input point density variations. At test time all points are used. MSG, however, is computationally expensive because it runs PointNet over large-scale neighborhoods for every centroid. <br />
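<br />
Random input dropout can be sketched as follows. The sketch assumes the dropout ratio <math>\theta</math> is drawn uniformly from <math>[0, p_{max}]</math> once per point set, with 0.95 as the upper bound; the function name and the fallback that keeps one point when everything is dropped are assumptions made for this illustration.<br />

```python
import random


def random_input_dropout(points, p_max=0.95, rng=None):
    """Drop each point independently with a ratio theta drawn once per point set."""
    rng = rng or random.Random()
    theta = rng.uniform(0.0, p_max)  # one dropout ratio for the whole point set
    kept = [pt for pt in points if rng.random() >= theta]
    if not kept:
        # Assumed safeguard: never return an empty point set.
        kept = [rng.choice(points)]
    return kept


cloud = [(float(i), 0.0, 0.0) for i in range(100)]
sparse_cloud = random_input_dropout(cloud, rng=random.Random(0))
```

Because <math>\theta</math> changes from one training point set to the next, the network sees both dense and sparse versions of the same shapes during training.<br />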
<br />
On the other hand, multi-resolution grouping (MRG) is less computationally expensive but still adaptively collects features. As shown in the diagram, the features of a region at a given level are a concatenation of two vectors. The left vector is obtained by applying PointNet to three points, each of which summarizes information from one of three groups. This vector is then concatenated with a second vector that is created by applying PointNet to all the points in the level below. The second vector can be weighted more heavily when the local region is sparse, since the first vector is based on subregions that are even sparser and suffer more from sampling deficiency. Conversely, when the density of a local region is high, the first vector can be weighted more heavily, as it allows inspection at higher resolutions in the lower levels to obtain finer details. <br />
<br />
[[File:grouping.png | 300px|thumb|center|Example of the two ways to perform grouping]]<br />
<br />
== Point Cloud Segmentation ==<br />
<br />
If the task is segmentation, the architecture is slightly modified since we want a semantic score for each point. To achieve this, distance-based interpolation and skip-connections are used.<br />
<br />
=== Distance-based Interpolation ===<br />
<br />
Here, point features from <math>N_l \times (d + C)</math> points are propagated to <math>N_{l-1} \times (d + C)</math> points where <math>N_{l-1}</math> is greater than <math>N_l</math>.<br />
<br />
To propagate features, an inverse-distance-weighted average based on the <math>k</math> nearest neighbors is used, with weights <math>w_i = 1/d(x, x_i)^p</math>. The authors use <math>p=2</math> and <math>k=3</math>.<br />
<br />
[[File:prop_feature.png | 500px|thumb|center|Feature interpolation during segmentation]]<br />
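<br />
The interpolation step can be sketched as follows, with <math>k=3</math> and <math>p=2</math> as defaults. The function name, the small constant <code>eps</code> (added to avoid division by zero when a query coincides with a known point) and the example data are assumptions made for this illustration.<br />

```python
import math


def interpolate_features(query, known_points, known_feats, k=3, p=2, eps=1e-8):
    """Inverse-distance-weighted average of the features of the k nearest points."""
    order = sorted(range(len(known_points)),
                   key=lambda i: math.dist(query, known_points[i]))
    nearest = order[:k]
    weights = [1.0 / (math.dist(query, known_points[i]) ** p + eps) for i in nearest]
    total = sum(weights)
    dim = len(known_feats[0])
    return [sum(w * known_feats[i][c] for w, i in zip(weights, nearest)) / total
            for c in range(dim)]


coarse_points = [(0.0, 0.0), (2.0, 0.0), (0.0, 2.0), (10.0, 10.0)]
coarse_feats = [[1.0], [3.0], [5.0], [100.0]]

# The query is equidistant from the first three points, so it gets their mean.
centre_feat = interpolate_features((1.0, 1.0), coarse_points, coarse_feats)
```

A query that is equidistant from its three nearest neighbours receives their plain average, while a query lying on top of a known point receives essentially that point's feature.<br />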
<br />
=== Skip-connections ===<br />
<br />
In addition, skip connections are used (see the PointNet++ architecture diagram). The features from the skip layers are concatenated with the interpolated features. Next, a "unit-wise" PointNet is applied, which the authors describe as similar to a one-by-one convolution.<br />
<br />
== Experiments ==<br />
To validate the effectiveness of PointNet++, experiments in three areas were performed - classification in Euclidean metric space, semantic scene labelling, and classification in non-Euclidean space.<br />
<br />
=== Point Set Classification in Euclidean Metric Space ===<br />
<br />
The digit dataset, MNIST, was converted to a 2D point cloud. Pixel intensities were normalized in the range of <math>[0, 1]</math>, and only pixels with intensities larger than 0.5 were considered. The coordinate system was set at the centre of the image. PointNet++ achieved a classification error of 0.51%. The original PointNet had 0.78% classification error. The table below compares these results to the state-of-the-art.<br />
<br />
[[File:mnist_results.png | 300px|thumb|center|MNIST classification results.]]<br />
<br />
In addition, the ModelNet40 dataset was used. This dataset consists of CAD models. Three dimensional point clouds were sampled from mesh surfaces of the ModelNet40 shapes. The classification results from this dataset are shown below.<br />
<br />
[[File:modelnet40.png | 300px|thumb|center|ModelNet40 classification results.]]<br />
<br />
An experiment was performed to show how the accuracy was affected by the number of points used. With PointNet++ using multi-scale grouping and dropout, the performance decreased by less than 1% when 1024 test points were reduced to 256. On the other hand, PointNet's performance was impacted by the decrease in points.<br />
<br />
[[File:paper28_fig4_chair.png | 300px|thumb|center|An example showing the reduction of points visually. At 256 points, the points making up the object are very sparse; however, the accuracy is reduced by less than 1%.]][[File:num_points_acc.png | 300px|thumb|center|Relationship between accuracy and the number of points used for classification.]]<br />
<br />
=== Semantic Scene Labelling ===<br />
<br />
The ScanNet dataset was used for experiments in semantic scene labelling. This dataset consists of laser scans of indoor scenes where the goal is to predict a semantic label for each point. Example results are shown below.<br />
<br />
[[File:scannet.png | 300px|thumb|center|Example ScanNet semantic segmentation results.]]<br />
<br />
To compare to other methods, the authors convert their point labels to a voxel format, and accuracy is determined on a per voxel basis. The accuracy compared to other methods is shown below.<br />
<br />
[[File:scannet_acc.png | 500px|thumb|center|ScanNet semantic segmentation classification comparison to other methods.]]<br />
<br />
To test how the trained model performed on scans with non-uniform sampling density, virtual scans of ScanNet scenes were synthesized and the network was evaluated on this data. The figure above shows that the performance of the single-scale grouping (SSG) network falls sharply under the sampling density shift. The MRG network, on the other hand, is more robust to the shift, since it is able to switch automatically to features depicting coarser granularity when the sampling is sparse. This supports the effectiveness of the proposed density-adaptive layer design.<br />
<br />
=== Classification in Non-Euclidean Metric Space ===<br />
<br />
[[File:shrec.png | 300px|thumb|right|Example of shapes from the SHREC15 dataset.]]<br />
<br />
Lastly, experiments were performed on the SHREC15 dataset. This dataset contains shapes that have different poses. This experiment shows that PointNet++ is able to generalize to non-Euclidean spaces. Results from this dataset are provided below.<br />
<br />
[[File:shrec15_results.png | 500px|thumb|center|Results from the SHREC15 dataset.]]<br />
<br />
=== Feature Visualization ===<br />
The figure below visualizes what is learned by the first layer kernels of the network. The model is trained on a dataset that mostly consisted of furniture, which explains the lines, corners, and planes visible in the visualization. Visualization is performed by creating a voxel grid in space and aggregating only the point sets that activate specific neurons the most.<br />
<br />
[[File:26_8.PNG | 500px|thumb|center|Pointclouds learned from first layer kernels (red is near, blue is far)]]<br />
<br />
== Critique ==<br />
<br />
It seems clear that PointNet fails to capture local context between points. PointNet++ seems to be an important extension, but the improvements in the experimental results seem small. Some computational efficiency experiments would have been useful, for example on the processing speed of the network and the computational efficiency of MRG over MSG.<br />
<br />
== Code ==<br />
<br />
Code for PointNet++ can be found at: https://github.com/charlesq34/pointnet2 <br />
<br />
<br />
=Sources=<br />
1. Charles R. Qi, Li Yi, Hao Su, Leonidas J. Guibas. PointNet++: Deep Hierarchical Feature Learning on Point Sets in a Metric Space, 2017<br />
<br />
2. Charles R. Qi, Hao Su, Kaichun Mo, Leonidas J. Guibas. PointNet: Deep Learning on Point Sets for 3D Classification and Segmentation, 2017</div>X249wanghttp://wiki.math.uwaterloo.ca/statwiki/index.php?title=PointNet%2B%2B:_Deep_Hierarchical_Feature_Learning_on_Point_Sets_in_a_Metric_Space&diff=35622PointNet++: Deep Hierarchical Feature Learning on Point Sets in a Metric Space2018-03-27T02:54:54Z<p>X249wang: /* Problem Statement */</p>
<hr />
<div>= Introduction =<br />
This paper builds off of ideas from PointNet (Qi et al., 2017). The name PointNet is derived from the network's input - a point cloud. A point cloud is a set of three dimensional points that each have coordinates <math> (x,y,z) </math>. These coordinates usually represent the surface of an object. For example, a point cloud describing the shape of a torus is shown below.<br />
<br />
[[File:Point cloud torus.gif|thumb|center|Point cloud torus]]<br />
<br />
<br />
Processing point clouds is important in applications such as autonomous driving where point clouds are collected from an onboard LiDAR sensor. These point clouds can then be used for object detection. However, point clouds are challenging to process because:<br />
<br />
# They are unordered. If <math> N </math> is the number of points in a point cloud, then there are <math> N! </math> orderings in which the same point cloud can be represented.<br />
# The spatial arrangement of the points contains useful information, thus it needs to be encoded.<br />
# The function processing the point cloud needs to be invariant to transformations such as rotations and translations of all points. <br />
<br />
Previously, typical point cloud processing methods handled the challenges of point clouds by transforming the data with a 3D voxel grid or by representing the point cloud with multiple 2D images. When PointNet was introduced, it was novel because it directly took points as its input. PointNet++ improves on PointNet by using a hierarchical method to better capture local structures of the point cloud. <br />
<br />
[[File:point_cloud.png | 400px|thumb|center|Examples of point clouds and their associated task. Classification (left), part segmentation (centre), scene segmentation (right) ]]<br />
<br />
= Review of PointNet =<br />
<br />
The PointNet architecture is shown below. The input of the network is <math> n </math> points, each of which has <math> (x,y,z) </math> coordinates. Each point is processed individually through a multi-layer perceptron (MLP). This network creates an encoding for each point; in the diagram, each point is represented by a 1024-dimensional vector. Then, a max pool layer creates a vector that represents the "global signature" of the point cloud. If classification is the task, this global signature is processed by another MLP to compute the classification scores. If segmentation is the task, this global signature is appended to each point from the "nx64" layer, and these points are processed by an MLP to compute a semantic category score for each point.<br />
<br />
The core idea of the network is to learn a symmetric function on transformed points. Through the T-Nets and the MLP network, a transformation is learned with the hopes of making points invariant to point cloud transformations. Learning a symmetric function solves the challenge imposed by having unordered points; a symmetric function will produce the same value no matter the order of the input. This symmetric function is represented by the max pool layer.<br />
<br />
[[File:pointnet_arch.png | 700px|thumb|center|PointNet architecture. The blue highlighted region is when it is used for classification, and the beige highlighted region is when it is used for segmentation.]]<br />
<br />
= PointNet++ =<br />
<br />
The motivation for PointNet++ is that PointNet does not capture local, fine-grained details. Since PointNet performs a max pool layer over all of its points, information such as the local interaction between points is lost.<br />
<br />
== Problem Statement ==<br />
<br />
Suppose <math> X = (M,d) </math> is a discrete metric space whose metric <math>d</math> is inherited from a Euclidean space <math>\pmb{\mathbb{R}}^n</math>, where <math> M \subseteq \pmb{\mathbb{R}}^n </math> is the set of points. The goal is to learn functions <math>f</math> that take <math>X</math> as the input and produce information of semantic interest about it. In practice, <math>f</math> can be a classification function that outputs a class label for <math>X</math> or a segmentation function that outputs a per-point label for each member of <math>M</math>.<br />
<br />
== Method ==<br />
<br />
=== High Level Overview ===<br />
[[File:point_net++.png | 700px|thumb|right|PointNet++ architecture]]<br />
<br />
The PointNet++ architecture is shown on the right. The core idea is that a hierarchical architecture is used, and at each level of the hierarchy a set of points is processed and abstracted to a new set with fewer points, i.e.,<br />
<br />
\begin{aligned}<br />
\text{Input at each level: } N \times (d + c) \text{ matrix}<br />
\end{aligned}<br />
<br />
where <math>N</math> is the number of points, <math>d</math> is the dimension of the point coordinates (3 for <math>(x,y,z)</math>) and <math>c</math> is the dimension of each point's feature vector, and<br />
<br />
\begin{aligned}<br />
\text{Output at each level: } N' \times (d + c') \text{ matrix}<br />
\end{aligned}<br />
<br />
where <math>N'</math> is the new (smaller) number of points and <math>c'</math> is the dimension of the new feature vector.<br />
<br />
<br />
Each level has three layers: Sampling, Grouping, and PointNet. The Sampling layer selects points that will act as centroids of local regions within the point cloud. The Grouping layer then finds points near these centroids. Lastly, the PointNet layer performs PointNet on each group to encode local information.<br />
<br />
=== Sampling Layer ===<br />
<br />
The input of this layer is a set of points <math>\{x_1, x_2, \ldots, x_n\}</math>. The goal of this layer is to select a subset of these points <math>\{\hat{x}_1, \hat{x}_2, \ldots, \hat{x}_m\}</math> that will define the centroids of local regions.<br />
<br />
To select these points, farthest point sampling is used: <math>\hat{x}_j</math> is the point that is most distant from the already selected set <math>\{\hat{x}_1, \hat{x}_2, \ldots, \hat{x}_{j-1}\}</math>. Compared with random sampling, this gives better coverage of the entire point cloud.<br />
<br />
=== Grouping Layer ===<br />
<br />
The objective of the grouping layer is to form local regions around each centroid by grouping points near the selected centroids. The input is a point set of size <math>N \times (d + c)</math> and the coordinates of the centroids <math>N' \times d</math>. The output is the groups of points within each region <math>N' \times k \times (d+c)</math> where <math>k</math> is the number of points in each region.<br />
<br />
Note that <math>k</math> can vary per group. Later, the PointNet layer creates a feature vector that is the same size for all regions at a hierarchical level.<br />
<br />
To determine which points belong to a group, a ball query is used: all points within a given radius of the centroid are grouped. This is advantageous over nearest-neighbour search because it guarantees a fixed region scale, which is important when learning local structure.<br />
<br />
=== PointNet Layer ===<br />
<br />
After grouping, PointNet is applied to the points. However, first the coordinates of points in a local region are converted to a local coordinate frame by <math> x_i = x_i - \bar{x}</math> where <math>\bar{x}</math> is the coordinates of the centroid.<br />
<br />
=== Robust Feature Learning under Non-Uniform Sampling Density ===<br />
<br />
The previous description of grouping uses a single scale. This is not optimal because the density varies per section of the point cloud. At each level, it would be better if the PointNet layer was applied to adaptively sized groups depending on the point cloud density.<br />
<br />
The two grouping methods the authors propose are shown in the diagram below. Multi-scale grouping (MSG) applies PointNet at several scales per group and concatenates the resulting features into a multi-scale feature. To train the network to learn a good strategy for combining the multi-scale features, the authors propose random input dropout: for each training point set, a dropout ratio <math>\theta</math> is sampled uniformly from <math>[0, p]</math>, and each input point is then dropped independently with probability <math>\theta</math>. The authors set <math>p = 0.95</math>. As shown in the experiments section below, dropout provides robustness to variations in input point density. At test time, all points are used. MSG, however, is computationally expensive because it runs PointNet on large-scale neighbourhoods around every centroid. <br />
<br />
On the other hand, multi-resolution grouping (MRG) is less computationally expensive but still collects features adaptively. As shown in the diagram, the feature of a region at a given level is a concatenation of two vectors. The left vector is obtained by applying PointNet to three points, each of which summarizes information from a subregion (group) in the level below. This vector is concatenated with a second vector obtained by applying PointNet directly to all the points of the region in the level below. The second vector can be weighted more heavily when the region is sparsely sampled, since the first vector is based on subregions that are even sparser and suffer more from sampling deficiency. Conversely, when the density of a local region is high, the first vector can be weighted more heavily, as it inspects the lower levels at higher resolution and therefore provides finer details. <br />
<br />
[[File:grouping.png | 300px|thumb|center|Example of the two ways to perform grouping]]<br />
<br />
== Point Cloud Segmentation ==<br />
<br />
If the task is segmentation, the architecture is slightly modified since we want a semantic score for each point. To achieve this, distance-based interpolation and skip-connections are used.<br />
<br />
=== Distance-based Interpolation ===<br />
<br />
Here, point features from <math>N_l \times (d + C)</math> points are propagated to <math>N_{l-1} \times (d + C)</math> points where <math>N_{l-1}</math> is greater than <math>N_l</math>.<br />
<br />
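To propagate features, an inverse-distance-weighted average over the <math>k</math> nearest neighbors is used, where each neighbor's weight is its inverse distance raised to the power <math>p</math>. The authors use <math>p=2</math> and <math>k=3</math>.<br />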
<br />
[[File:prop_feature.png | 500px|thumb|center|Feature interpolation during segmentation]]<br />
<br />
=== Skip-connections ===<br />
<br />
In addition, skip connections are used (see the PointNet++ architecture diagram). The features from the skip layers are concatenated with the interpolated features. Next, a "unit-wise" PointNet is applied, which the authors describe as similar to a one-by-one convolution.<br />
<br />
== Experiments ==<br />
To validate the effectiveness of PointNet++, experiments in three areas were performed - classification in Euclidean metric space, semantic scene labelling, and classification in non-Euclidean space.<br />
<br />
=== Point Set Classification in Euclidean Metric Space ===<br />
<br />
The digit dataset, MNIST, was converted to a 2D point cloud. Pixel intensities were normalized in the range of <math>[0, 1]</math>, and only pixels with intensities larger than 0.5 were considered. The coordinate system was set at the centre of the image. PointNet++ achieved a classification error of 0.51%. The original PointNet had 0.78% classification error. The table below compares these results to the state-of-the-art.<br />
<br />
[[File:mnist_results.png | 300px|thumb|center|MNIST classification results.]]<br />
<br />
In addition, the ModelNet40 dataset was used. This dataset consists of CAD models. Three dimensional point clouds were sampled from mesh surfaces of the ModelNet40 shapes. The classification results from this dataset are shown below.<br />
<br />
[[File:modelnet40.png | 300px|thumb|center|ModelNet40 classification results.]]<br />
<br />
An experiment was performed to show how the accuracy was affected by the number of points used. With PointNet++ using multi-scale grouping and dropout, the performance decreased by less than 1% when 1024 test points were reduced to 256. On the other hand, PointNet's performance was impacted by the decrease in points.<br />
<br />
[[File:paper28_fig4_chair.png | 300px|thumb|center|An example showing the reduction of points visually. At 256 points, the points making up the object are very sparse; however, the accuracy is reduced by less than 1%]][[File:num_points_acc.png | 300px|thumb|center|Relationship between accuracy and the number of points used for classification.]]<br />
<br />
=== Semantic Scene Labelling ===<br />
<br />
The ScanNet dataset was used for experiments in semantic scene labelling. This dataset consists of laser scans of indoor scenes where the goal is to predict a semantic label for each point. Example results are shown below.<br />
<br />
[[File:scannet.png | 300px|thumb|center|Example ScanNet semantic segmentation results.]]<br />
<br />
To compare to other methods, the authors convert their point labels to a voxel format, and accuracy is determined on a per voxel basis. The accuracy compared to other methods is shown below.<br />
<br />
[[File:scannet_acc.png | 500px|thumb|center|ScanNet semantic segmentation classification comparison to other methods.]]<br />
<br />
To test how the trained model performed on scans with non-uniform sampling density, virtual scans of ScanNet scenes were synthesized and the network was evaluated on this data. It can be seen from the above figures that the performance of single-scale grouping (SSG) falls greatly due to the sampling density shift. The MRG network, on the other hand, is more robust to the shift since it is able to automatically switch to features depicting coarser granularity when the sampling is sparse. This supports the effectiveness of the proposed density-adaptive layer design.<br />
<br />
=== Classification in Non-Euclidean Metric Space ===<br />
<br />
[[File:shrec.png | 300px|thumb|right|Example of shapes from the SHREC15 dataset.]]<br />
<br />
Lastly, experiments were performed on the SHREC15 dataset. This dataset contains shapes that have different poses. This experiment shows that PointNet++ is able to generalize to non-Euclidean spaces. Results from this dataset are provided below.<br />
<br />
[[File:shrec15_results.png | 500px|thumb|center|Results from the SHREC15 dataset.]]<br />
<br />
=== Feature Visualization ===<br />
The figure below visualizes what is learned by just the first layer kernels of the network. The model is trained on a dataset that mostly consisted of furniture, which explains the lines, corners, and planes visible in the visualization. Visualization is performed by creating a voxel grid in space and only aggregating point sets that activate specific neurons the most.<br />
<br />
[[File:26_8.PNG | 500px|thumb|center|Pointclouds learned from first layer kernels (red is near, blue is far)]]<br />
<br />
== Critique ==<br />
<br />
It seems clear that PointNet lacks the ability to capture local context between points. PointNet++ seems to be an important extension, but the improvements in the experimental results seem small. Some computational efficiency experiments would have been nice, for example on the processing speed of the network and on the computational efficiency of MRG over MSG.<br />
<br />
== Code ==<br />
<br />
Code for PointNet++ can be found at: https://github.com/charlesq34/pointnet2 <br />
<br />
<br />
=Sources=<br />
1. Charles R. Qi, Li Yi, Hao Su, Leonidas J. Guibas. PointNet++: Deep Hierarchical Feature Learning on Point Sets in a Metric Space, 2017<br />
<br />
2. Charles R. Qi, Hao Su, Kaichun Mo, Leonidas J. Guibas. PointNet: Deep Learning on Point Sets for 3D Classification and Segmentation, 2017</div>X249wanghttp://wiki.math.uwaterloo.ca/statwiki/index.php?title=PointNet%2B%2B:_Deep_Hierarchical_Feature_Learning_on_Point_Sets_in_a_Metric_Space&diff=35620PointNet++: Deep Hierarchical Feature Learning on Point Sets in a Metric Space2018-03-27T02:45:08Z<p>X249wang: /* Robust Feature Learning under Non-Uniform Sampling Density */</p>
<hr />
<div>= Introduction =<br />
This paper builds off of ideas from PointNet (Qi et al., 2017). The name PointNet is derived from the network's input - a point cloud. A point cloud is a set of three dimensional points that each have coordinates <math> (x,y,z) </math>. These coordinates usually represent the surface of an object. For example, a point cloud describing the shape of a torus is shown below.<br />
<br />
[[File:Point cloud torus.gif|thumb|center|Point cloud torus]]<br />
<br />
<br />
Processing point clouds is important in applications such as autonomous driving where point clouds are collected from an onboard LiDAR sensor. These point clouds can then be used for object detection. However, point clouds are challenging to process because:<br />
<br />
# They are unordered. If <math> N </math> is the number of points in a point cloud, then there are <math> N! </math> orderings in which the same point cloud can be represented.<br />
# The spatial arrangement of the points contains useful information, thus it needs to be encoded.<br />
# The function processing the point cloud needs to be invariant to transformations such as rotation and translations of all points. <br />
<br />
Previously, typical point cloud processing methods handled the challenges of point clouds by transforming the data with a 3D voxel grid or by representing the point cloud with multiple 2D images. When PointNet was introduced, it was novel because it directly took points as its input. PointNet++ improves on PointNet by using a hierarchical method to better capture local structures of the point cloud. <br />
<br />
[[File:point_cloud.png | 400px|thumb|center|Examples of point clouds and their associated task. Classification (left), part segmentation (centre), scene segmentation (right) ]]<br />
<br />
= Review of PointNet =<br />
<br />
The PointNet architecture is shown below. The input of the network is <math> n </math> points, each of which has <math> (x,y,z) </math> coordinates. Each point is processed individually through a multi-layer perceptron (MLP). This network creates an encoding for each point; in the diagram, each point is represented by a 1024-dimensional vector. Then, using a max pool layer, a vector is created that represents the "global signature" of a point cloud. If classification is the task, this global signature is processed by another MLP to compute the classification scores. If segmentation is the task, this global signature is appended to each point from the "nx64" layer, and these points are processed by an MLP to compute a semantic category score for each point.<br />
<br />
The core idea of the network is to learn a symmetric function on transformed points. Through the T-Nets and the MLP network, a transformation is learned with the hopes of making points invariant to point cloud transformations. Learning a symmetric function solves the challenge imposed by having unordered points; a symmetric function will produce the same value no matter the order of the input. This symmetric function is represented by the max pool layer.<br />
<br />
[[File:pointnet_arch.png | 700px|thumb|center|PointNet architecture. The blue highlighted region is when it is used for classification, and the beige highlighted region is when it is used for segmentation.]]<br />
<br />
= PointNet++ =<br />
<br />
The motivation for PointNet++ is that PointNet does not capture local, fine-grained details. Since PointNet performs a max pool layer over all of its points, information such as the local interaction between points is lost.<br />
<br />
== Problem Statement ==<br />
<br />
There is a metric space <math> X = (M,d) </math> where <math>d</math> is the metric from a Euclidean space <math>\pmb{\mathbb{R}}^n</math> and <math> M \subseteq \pmb{\mathbb{R}}^n </math> is the set of points. The goal is to learn a function that takes <math>X</math> as input and outputs either a class label for <math>X</math> or a per-point label for each member of <math>M</math>.<br />
<br />
== Method ==<br />
<br />
=== High Level Overview ===<br />
[[File:point_net++.png | 700px|thumb|right|PointNet++ architecture]]<br />
<br />
The PointNet++ architecture is shown on the right. The core idea is that a hierarchical architecture is used, and at each level of the hierarchy a set of points is processed and abstracted to a new set with fewer points, i.e.,<br />
<br />
\begin{aligned}<br />
\text{Input at each level: } N \times (d + c) \text{ matrix}<br />
\end{aligned}<br />
<br />
where <math>N</math> is the number of points, <math>d</math> is the dimension of the point coordinates (3 for <math>(x,y,z)</math>) and <math>c</math> is the dimension of the feature representation of each point, and<br />
<br />
\begin{aligned}<br />
\text{Output at each level: } N' \times (d + c') \text{ matrix}<br />
\end{aligned}<br />
<br />
where <math>N'</math> is the new (smaller) number of points and <math>c'</math> is the dimension of the new feature vector.<br />
<br />
<br />
Each level has three layers: Sampling, Grouping, and PointNet. The Sampling layer selects points that will act as centroids of local regions within the point cloud. The Grouping layer then finds points near these centroids. Lastly, the PointNet layer performs PointNet on each group to encode local information.<br />
<br />
=== Sampling Layer ===<br />
<br />
The input of this layer is a set of points <math>\{x_1, x_2, \ldots, x_n\}</math>. The goal of this layer is to select a subset of these points <math>\{\hat{x}_1, \hat{x}_2, \ldots, \hat{x}_m\}</math> that will define the centroids of local regions.<br />
<br />
To select these points, iterative farthest point sampling (FPS) is used: <math>\hat{x}_j</math> is chosen as the point most distant from the already selected set <math>\{\hat{x}_1, \hat{x}_2, \ldots, \hat{x}_{j-1}\}</math>. For the same number of centroids, this gives better coverage of the entire point cloud than random sampling.<br />
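<br />
The selection rule above can be sketched in a few lines of plain Python. This is only an illustration under stated assumptions: the function name <code>farthest_point_sampling</code>, the choice of the first point (<code>start</code>) and the toy 2D cloud are hypothetical and are not taken from the authors' code, and "most distant from a set" is interpreted as maximizing the distance to the nearest already selected point.<br />

```python
import math

def farthest_point_sampling(points, m, start=0):
    """Return indices of m points chosen by iterative farthest point sampling.

    Assumption: the first centroid is points[start]; the source text does not
    say how the first point is chosen.
    """
    if not 0 < m <= len(points):
        raise ValueError("m must be between 1 and len(points)")
    selected = [start]
    # min_dist[i] = distance from point i to its nearest already selected point
    min_dist = [math.dist(p, points[start]) for p in points]
    while len(selected) < m:
        # The next centroid is the point farthest from the selected set.
        nxt = max(range(len(points)), key=min_dist.__getitem__)
        selected.append(nxt)
        for i, p in enumerate(points):
            min_dist[i] = min(min_dist[i], math.dist(p, points[nxt]))
    return selected

# Toy 2D cloud: two nearly coincident points and three far corners.
cloud = [(0.0, 0.0), (0.1, 0.0), (1.0, 0.0), (0.0, 1.0), (1.0, 1.0)]
centroid_idx = farthest_point_sampling(cloud, 3)
```

Note that the near-duplicate point at index 1 is not picked while spread-out points remain, which is the coverage property the text refers to.<br />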
<br />
=== Grouping Layer ===<br />
<br />
The objective of the grouping layer is to form local regions around each centroid by grouping points near the selected centroids. The input is a point set of size <math>N \times (d + c)</math> and the coordinates of the centroids <math>N' \times d</math>. The output is the groups of points within each region <math>N' \times k \times (d+c)</math> where <math>k</math> is the number of points in each region.<br />
<br />
Note that <math>k</math> can vary per group. Later, the PointNet layer creates a feature vector that is the same size for all regions at a hierarchical level.<br />
<br />
To determine which points belong to a group, a ball query is used: all points within a given radius of the centroid are grouped. This is advantageous over nearest-neighbour search because it guarantees a fixed region scale, which is important when learning local structure.<br />
<br />
=== PointNet Layer ===<br />
<br />
After grouping, PointNet is applied to the points of each local region. First, however, the coordinates of the points in a local region are translated into a local frame relative to the centroid, <math>x_i \leftarrow x_i - \bar{x}</math>, where <math>\bar{x}</math> is the coordinate vector of the centroid.<br />
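<br />
As a rough illustration of the grouping and local-frame steps described above, the sketch below performs a ball query around one centroid and then subtracts the centroid from every grouped point. The helper names <code>ball_query</code> and <code>to_local_frame</code>, the optional cap <code>max_k</code> and the toy coordinates are assumptions made for this example only.<br />

```python
import math

def ball_query(points, centroid, radius, max_k=None):
    """Indices of all points within `radius` of `centroid`.

    `max_k` is a hypothetical optional cap on the group size.
    """
    idx = [i for i, p in enumerate(points) if math.dist(p, centroid) <= radius]
    return idx if max_k is None else idx[:max_k]

def to_local_frame(points, idx, centroid):
    """Translate the grouped points so that the centroid becomes the origin."""
    return [tuple(a - c for a, c in zip(points[i], centroid)) for i in idx]

cloud = [(0.0, 0.0, 0.0), (0.5, 0.0, 0.0), (0.0, 2.0, 0.0), (1.0, 1.0, 0.0)]
centre = (1.0, 0.0, 0.0)
group = ball_query(cloud, centre, radius=1.0)   # point 2 lies outside the ball
local = to_local_frame(cloud, group, centre)    # coordinates relative to centre
```

Because the radius is fixed, the number of points per group depends on the local density, which is why <math>k</math> can vary between groups.<br />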
<br />
=== Robust Feature Learning under Non-Uniform Sampling Density ===<br />
<br />
The previous description of grouping uses a single scale. This is not optimal because the density varies per section of the point cloud. At each level, it would be better if the PointNet layer was applied to adaptively sized groups depending on the point cloud density.<br />
<br />
The two grouping methods the authors propose are shown in the diagram below. Multi-scale grouping (MSG) applies PointNet at several scales per group and concatenates the resulting features into a multi-scale feature. To train the network to learn a good strategy for combining the multi-scale features, the authors propose random input dropout: for each training point set, a dropout ratio <math>\theta</math> is sampled uniformly from <math>[0, p]</math>, and each input point is then dropped independently with probability <math>\theta</math>. The authors set <math>p = 0.95</math>. As shown in the experiments section below, dropout provides robustness to variations in input point density. At test time, all points are used. MSG, however, is computationally expensive because it runs PointNet on large-scale neighbourhoods around every centroid. <br />
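<br />
Random input dropout can be sketched as follows. The sketch assumes, following the original paper, that one dropout ratio <math>\theta</math> is drawn uniformly from <math>[0, p]</math> for each training point set and then applied independently to every point. The function name, the <code>rng</code> argument and the fallback that keeps one point when everything is dropped are hypothetical choices for this example.<br />

```python
import random

def random_input_dropout(points, p_max=0.95, rng=random):
    """Randomly drop points from one training point set.

    theta is drawn once per point set from [0, p_max]; each point is then
    dropped independently with probability theta.
    """
    theta = rng.uniform(0.0, p_max)
    kept = [pt for pt in points if rng.random() >= theta]
    # Assumption for this sketch: never return an empty point set.
    return kept if kept else [rng.choice(points)]

# Example: thin out a toy point set with a seeded generator.
example_rng = random.Random(0)
example_set = [(float(i), 0.0, 0.0) for i in range(100)]
thinned = random_input_dropout(example_set, rng=example_rng)
```

Since <math>\theta</math> changes from one point set to the next, the network sees both dense and very sparse versions of the training shapes.<br />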
<br />
On the other hand, multi-resolution grouping (MRG) is less computationally expensive but still collects features adaptively. As shown in the diagram, the feature of a region at a given level is a concatenation of two vectors. The left vector is obtained by applying PointNet to three points, each of which summarizes information from a subregion (group) in the level below. This vector is concatenated with a second vector obtained by applying PointNet directly to all the points of the region in the level below. The second vector can be weighted more heavily when the region is sparsely sampled, since the first vector is based on subregions that are even sparser and suffer more from sampling deficiency. Conversely, when the density of a local region is high, the first vector can be weighted more heavily, as it inspects the lower levels at higher resolution and therefore provides finer details. <br />
<br />
[[File:grouping.png | 300px|thumb|center|Example of the two ways to perform grouping]]<br />
<br />
== Point Cloud Segmentation ==<br />
<br />
If the task is segmentation, the architecture is slightly modified since we want a semantic score for each point. To achieve this, distance-based interpolation and skip-connections are used.<br />
<br />
=== Distance-based Interpolation ===<br />
<br />
Here, point features from <math>N_l \times (d + C)</math> points are propagated to <math>N_{l-1} \times (d + C)</math> points where <math>N_{l-1}</math> is greater than <math>N_l</math>.<br />
<br />
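To propagate features, an inverse-distance-weighted average over the <math>k</math> nearest neighbors is used, where each neighbor's weight is its inverse distance raised to the power <math>p</math>. The authors use <math>p=2</math> and <math>k=3</math>.<br />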
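<br />
A minimal sketch of this interpolation step is given below, assuming the weight of each neighbor is <math>1/d^p</math> with <math>d</math> the distance to the query point. The function name, the small constant <code>eps</code> that avoids division by zero and the toy data are assumptions for this example.<br />

```python
import math

def interpolate_features(query, coords, feats, k=3, p=2, eps=1e-8):
    """Inverse-distance-weighted average of the features of the k nearest points."""
    nearest = sorted(range(len(coords)),
                     key=lambda i: math.dist(query, coords[i]))[:k]
    # eps (an assumption of this sketch) keeps the weight finite when the
    # query coincides with a known point.
    weights = [1.0 / (math.dist(query, coords[i]) ** p + eps) for i in nearest]
    total = sum(weights)
    dim = len(feats[0])
    return [sum(w * feats[i][c] for w, i in zip(weights, nearest)) / total
            for c in range(dim)]

# Coarse level: four points with one scalar feature each.
coarse_xy = [(0.0, 0.0), (2.0, 0.0), (0.0, 2.0), (10.0, 10.0)]
coarse_f = [[1.0], [3.0], [5.0], [100.0]]
# The query is equidistant from the first three points, so it gets their mean.
dense_f = interpolate_features((1.0, 1.0), coarse_xy, coarse_f)
```

The distant fourth point does not influence the result because only the three nearest neighbors are used.<br />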
<br />
[[File:prop_feature.png | 500px|thumb|center|Feature interpolation during segmentation]]<br />
<br />
=== Skip-connections ===<br />
<br />
In addition, skip connections are used (see the PointNet++ architecture diagram). The features from the skip layers are concatenated with the interpolated features. Next, a "unit-wise" PointNet is applied, which the authors describe as similar to a one-by-one convolution.<br />
<br />
== Experiments ==<br />
To validate the effectiveness of PointNet++, experiments in three areas were performed - classification in Euclidean metric space, semantic scene labelling, and classification in non-Euclidean space.<br />
<br />
=== Point Set Classification in Euclidean Metric Space ===<br />
<br />
The digit dataset, MNIST, was converted to a 2D point cloud. Pixel intensities were normalized in the range of <math>[0, 1]</math>, and only pixels with intensities larger than 0.5 were considered. The coordinate system was set at the centre of the image. PointNet++ achieved a classification error of 0.51%. The original PointNet had 0.78% classification error. The table below compares these results to the state-of-the-art.<br />
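<br />
The conversion from an image to a 2D point cloud can be illustrated as below. This is a sketch under assumptions: the image is given as rows of intensities already normalized to <math>[0, 1]</math>, the origin is placed at the geometric centre of the pixel grid, and the function name and the tiny 3x3 image are made up for the example.<br />

```python
def image_to_point_cloud(image, threshold=0.5):
    """Turn a 2D intensity grid into a list of (x, y) points.

    Only pixels with intensity strictly above `threshold` are kept, and
    coordinates are expressed relative to the centre of the image.
    """
    height, width = len(image), len(image[0])
    cy, cx = (height - 1) / 2.0, (width - 1) / 2.0
    return [(x - cx, y - cy)
            for y, row in enumerate(image)
            for x, value in enumerate(row)
            if value > threshold]

tiny_image = [
    [0.0, 1.0, 0.0],
    [0.6, 0.2, 0.9],
    [0.0, 0.5, 0.0],
]
digit_points = image_to_point_cloud(tiny_image)
```

In this toy image the pixel with intensity exactly 0.5 is discarded, matching the "larger than 0.5" rule in the text.<br />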
<br />
[[File:mnist_results.png | 300px|thumb|center|MNIST classification results.]]<br />
<br />
In addition, the ModelNet40 dataset was used. This dataset consists of CAD models. Three dimensional point clouds were sampled from mesh surfaces of the ModelNet40 shapes. The classification results from this dataset are shown below.<br />
<br />
[[File:modelnet40.png | 300px|thumb|center|ModelNet40 classification results.]]<br />
<br />
An experiment was performed to show how the accuracy was affected by the number of points used. With PointNet++ using multi-scale grouping and dropout, the performance decreased by less than 1% when 1024 test points were reduced to 256. On the other hand, PointNet's performance was impacted by the decrease in points.<br />
<br />
[[File:paper28_fig4_chair.png | 300px|thumb|center|An example showing the reduction of points visually. At 256 points, the points making up the object are very sparse; however, the accuracy is reduced by less than 1%]][[File:num_points_acc.png | 300px|thumb|center|Relationship between accuracy and the number of points used for classification.]]<br />
<br />
=== Semantic Scene Labelling ===<br />
<br />
The ScanNet dataset was used for experiments in semantic scene labelling. This dataset consists of laser scans of indoor scenes where the goal is to predict a semantic label for each point. Example results are shown below.<br />
<br />
[[File:scannet.png | 300px|thumb|center|Example ScanNet semantic segmentation results.]]<br />
<br />
To compare to other methods, the authors convert their point labels to a voxel format, and accuracy is determined on a per voxel basis. The accuracy compared to other methods is shown below.<br />
<br />
[[File:scannet_acc.png | 500px|thumb|center|ScanNet semantic segmentation classification comparison to other methods.]]<br />
<br />
To test how the trained model performed on scans with non-uniform sampling density, virtual scans of ScanNet scenes were synthesized and the network was evaluated on this data. It can be seen from the above figures that the performance of single-scale grouping (SSG) falls greatly due to the sampling density shift. The MRG network, on the other hand, is more robust to the shift since it is able to automatically switch to features depicting coarser granularity when the sampling is sparse. This supports the effectiveness of the proposed density-adaptive layer design.<br />
<br />
=== Classification in Non-Euclidean Metric Space ===<br />
<br />
[[File:shrec.png | 300px|thumb|right|Example of shapes from the SHREC15 dataset.]]<br />
<br />
Lastly, experiments were performed on the SHREC15 dataset. This dataset contains shapes that have different poses. This experiment shows that PointNet++ is able to generalize to non-Euclidean spaces. Results from this dataset are provided below.<br />
<br />
[[File:shrec15_results.png | 500px|thumb|center|Results from the SHREC15 dataset.]]<br />
<br />
=== Feature Visualization ===<br />
The figure below visualizes what is learned by just the first layer kernels of the network. The model is trained on a dataset that mostly consisted of furniture, which explains the lines, corners, and planes visible in the visualization. Visualization is performed by creating a voxel grid in space and only aggregating point sets that activate specific neurons the most.<br />
<br />
[[File:26_8.PNG | 500px|thumb|center|Pointclouds learned from first layer kernels (red is near, blue is far)]]<br />
<br />
== Critique ==<br />
<br />
It seems clear that PointNet lacks the ability to capture local context between points. PointNet++ seems to be an important extension, but the improvements in the experimental results seem small. Some computational efficiency experiments would have been nice, for example on the processing speed of the network and on the computational efficiency of MRG over MSG.<br />
<br />
== Code ==<br />
<br />
Code for PointNet++ can be found at: https://github.com/charlesq34/pointnet2 <br />
<br />
<br />
=Sources=<br />
1. Charles R. Qi, Li Yi, Hao Su, Leonidas J. Guibas. PointNet++: Deep Hierarchical Feature Learning on Point Sets in a Metric Space, 2017<br />
<br />
2. Charles R. Qi, Hao Su, Kaichun Mo, Leonidas J. Guibas. PointNet: Deep Learning on Point Sets for 3D Classification and Segmentation, 2017</div>X249wanghttp://wiki.math.uwaterloo.ca/statwiki/index.php?title=Neural_Audio_Synthesis_of_Musical_Notes_with_WaveNet_autoencoders&diff=35397Neural Audio Synthesis of Musical Notes with WaveNet autoencoders2018-03-23T20:46:22Z<p>X249wang: /* WaveNet Autoencoder */</p>
<hr />
<div>= Introduction =<br />
The authors of this paper point out that most musical notes are currently synthesized with hand-designed instruments that modify pitch, velocity and filter parameters to produce the required tone, timbre and dynamics of a sound. The authors suggest that this may be a problem and therefore propose a data-driven approach to audio synthesis. To train such a data-hungry model, the authors highlight the need for a large dataset for music, much like ImageNet for images. <br />
<br />
= Contributions =<br />
To solve the problem highlighted above the authors propose two main contributions of their paper: <br />
* A WaveNet-style autoencoder that learns temporal codes capturing long-term audio structure without requiring external conditioning<br />
* NSynth: a large dataset of musical notes inspired by the emergence of large image datasets<br />
<br />
<br />
= Models =<br />
<br />
[[File:paper26-figure1-models.png|center]]<br />
<br />
== WaveNet Autoencoder ==<br />
<br />
While the proposed autoencoder structure is very similar to that of WaveNet the authors argue that the algorithm is novel in two ways:<br />
* It is able to attain consistent long-term structure without any external conditioning <br />
* It creates meaningful embeddings that can be interpolated between<br />
The authors accomplish this by passing the raw audio through the encoder to produce an embedding <math>Z = f(x) </math>; next, the input is shifted and fed into the decoder, which reproduces the input. The resulting probability distribution is: <br />
<br />
\begin{align}<br />
p(x) = \prod_{i=1}^N p(x_i \mid x_1, \ldots, x_{i-1}, f(x))<br />
\end{align}<br />
<br />
A detailed block diagram of the modified WaveNet structure can be seen in figure 1b. This diagram shows the encoder as a 30-layer network in which each node is a ReLU nonlinearity followed by a non-causal dilated convolution. Dilated convolution (also known as convolution with holes) is a type of convolution in which the filter skips input values with a certain step (a step size of 1 is equivalent to the standard convolution), effectively allowing the network to operate at a coarser scale than traditional convolutional layers and to have very large receptive fields. The resulting convolution has 128 channels, which are fed into another ReLU nonlinearity and then into a 1x1 convolution before being downsampled with average pooling to produce a 16-dimensional <math>Z </math> distribution. Each <math>Z </math> encoding covers a specific temporal resolution, which the authors of the paper tuned to 32 ms. This means that there are 125 16-dimensional <math>Z </math> encodings for each 4-second note in the NSynth database (1984 embeddings). <br />
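<br />
To make the idea of a dilated convolution concrete, here is a small one-dimensional sketch in plain Python. It is only an illustration of the "filter skips input values with a certain step" description: the function name, the absence of padding and the toy signal are assumptions, and it is not the encoder used in the paper.<br />

```python
def dilated_conv1d(signal, kernel, dilation=1):
    """1D dilated convolution (cross-correlation form, no padding).

    Kernel tap j reads signal[t + j * dilation], so dilation=1 is the
    standard convolution and larger values skip input samples.
    """
    span = (len(kernel) - 1) * dilation
    return [sum(w * signal[t + j * dilation] for j, w in enumerate(kernel))
            for t in range(len(signal) - span)]

samples = [1, 2, 3, 4, 5, 6]
standard = dilated_conv1d(samples, [1, 1], dilation=1)  # adds adjacent samples
dilated = dilated_conv1d(samples, [1, 1], dilation=2)   # adds samples two apart
```

With the same two-tap kernel, the dilated version covers a span of three input samples instead of two, which is how stacked dilated layers reach very large receptive fields.<br />
<br />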
Before the <math>Z </math> embedding enters the decoder it is first upsampled to the original audio rate using nearest neighbor interpolation. The embedding then passes through the decoder to recreate the original audio note. The input audio data is first quantized using 8-bit mu-law encoding into 256 possible values, and the output prediction is the softmax over the possible values.<br />
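<br />
The 8-bit mu-law quantization mentioned above can be sketched as follows, assuming audio samples in <math>[-1, 1]</math> and the standard companding formula <math>\operatorname{sign}(x)\,\ln(1 + \mu|x|) / \ln(1 + \mu)</math> with <math>\mu = 255</math>. The function names and the exact rounding rule are assumptions for this example and may differ from the authors' implementation.<br />

```python
import math

MU = 255  # 8-bit mu-law gives 256 quantization levels

def mu_law_encode(x, mu=MU):
    """Map a sample in [-1, 1] to an integer class in [0, mu]."""
    x = max(-1.0, min(1.0, x))
    compressed = math.copysign(math.log1p(mu * abs(x)) / math.log1p(mu), x)
    # Shift from [-1, 1] to [0, mu] and round to the nearest level.
    return int((compressed + 1.0) / 2.0 * mu + 0.5)

def mu_law_decode(q, mu=MU):
    """Approximate inverse of mu_law_encode."""
    compressed = 2.0 * q / mu - 1.0
    return math.copysign(math.expm1(abs(compressed) * math.log1p(mu)) / mu,
                         compressed)

level = mu_law_encode(0.5)          # integer class for a sample of 0.5
restored = mu_law_decode(level)     # close to 0.5, up to quantization error
```

The decoder's softmax is then taken over these 256 classes rather than over raw amplitudes.<br />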
<br />
== Baseline: Spectral Autoencoder ==<br />
Being unable to find an alternative fully deep model against which to compare their proposed WaveNet autoencoder, the authors built a strong baseline of their own: a spectral autoencoder. The block diagram of its architecture can be seen in figure 1a. The baseline network is 10 layers deep. Each layer has 4x4 kernels with 2x2 strides, followed by a leaky ReLU (0.1) and batch normalization. The final hidden vector (Z) was set to 1984 dimensions to exactly match the hidden vector of the WaveNet autoencoder. <br />
<br />
The authors attempted to train the baseline on multiple inputs: raw waveforms, FFT, and the log magnitude of the spectrum, finding the latter to be best correlated with perceptual distortion. The authors also explored several representations of phase, finding that estimating magnitude and using established iterative techniques to reconstruct phase was most effective. A final heuristic used by the authors to increase the accuracy of the baseline was weighting the mean square error (MSE) loss by frequency, starting at 10 for 0 Hz and decreasing linearly to 1 at 4000 Hz and above. This is reasonable, as the fundamental frequencies of most instruments are found at lower frequencies. <br />
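<br />
The frequency weighting of the MSE loss described above can be written down directly. The sketch below assumes one weight per frequency bin and a plain average over bins. The function names and the example values are hypothetical.<br />

```python
def frequency_weight(freq_hz, low_weight=10.0, high_weight=1.0, cutoff_hz=4000.0):
    """Loss weight: 10 at 0 Hz, decreasing linearly to 1 at 4000 Hz and above."""
    if freq_hz >= cutoff_hz:
        return high_weight
    frac = max(freq_hz, 0.0) / cutoff_hz
    return low_weight + (high_weight - low_weight) * frac

def weighted_mse(pred, target, freqs_hz):
    """Mean of the squared errors, each scaled by its frequency weight."""
    errors = [frequency_weight(f) * (p - t) ** 2
              for p, t, f in zip(pred, target, freqs_hz)]
    return sum(errors) / len(errors)

# The same unit error costs ten times more at 0 Hz than at 4000 Hz.
example_loss = weighted_mse([1.0, 1.0], [0.0, 0.0], [0.0, 4000.0])
```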
<br />
== Training ==<br />
Both the modified WaveNet and the baseline autoencoder were trained with stochastic gradient descent using the Adam optimizer. The authors trained the baseline autoencoder asynchronously for 1,800,000 iterations with a batch size of 8 and a learning rate of 1e-4, whereas the WaveNet models were trained synchronously for 250,000 iterations with a batch size of 32 and a learning rate decaying from 2e-4 to 6e-6.<br />
<br />
= The NSynth Dataset =<br />
The NSynth dataset has 306,043 unique musical notes, all 4 seconds in length and sampled at 16,000 Hz. The dataset consists of 1006 different instruments playing on average 65.4 different pitches at on average 4.75 different velocities. Averages are reported because not all instruments can reach all 88 MIDI pitches or the 5 velocities desired by the authors. The dataset has the following split: training set with 289,205 notes, validation set with 12,678 notes, and test set with 4,096 notes.<br />
<br />
Along with each note the authors also included the following annotations:<br />
* Source - The way each sound was produced. There were 3 classes: ‘acoustic’, ‘electronic’ and ‘synthetic’<br />
* Family - The family of instruments that produced each note. There are 11 classes, including ‘bass’, ‘brass’, ‘vocal’, etc.<br />
* Qualities - Sonic qualities about each note<br />
<br />
The full dataset is publicly available here: https://magenta.tensorflow.org/datasets/nsynth.<br />
<br />
<br />
<br />
= Evaluation =<br />
<br />
To fully analyze all aspects of WaveNet the authors proposed three evaluations:<br />
* Reconstruction - Both Quantitative and Qualitative analysis were considered<br />
* Interpolation in Timbre and Dynamics<br />
* Entanglement of Pitch and Timbre <br />
<br />
Sound is historically very difficult to quantify from a picture representation, as it requires training and expertise to analyze. Even with expertise it can be difficult to complete a full analysis, as two very different sounds can look quite similar in their respective pictorial representations. This is why the authors recommend that all readers listen to the created notes, which can be found here: https://magenta.tensorflow.org/nsynth.<br />
<br />
However, even taking this into consideration, the authors do demonstrate pictorially the differences between the two proposed algorithms and the original note, as it is hard to publish a paper with sound included. To do so, the authors display each note using the constant-Q transform (CQT), which is able to capture the dynamics of timbre along with representing the frequencies of the sound.<br />
<br />
== Reconstruction ==<br />
<br />
[[File:paper27-figure2-reconstruction.png|center]]<br />
<br />
=== Qualitative Comparison ===<br />
For the glockenspiel, the WaveNet autoencoder is able to reproduce the magnitude and phase of the fundamental frequency (A and C in figure 2) and the attack (B in figure 2) of the instrument, whereas the baseline autoencoder introduces non-existent harmonics (D in figure 2). The flugelhorn, on the other hand, presents the starkest difference between the WaveNet and baseline autoencoders. The WaveNet, while not perfect, is able to reproduce the vibrato (I and J in figure 2) across multiple frequencies, which results in a natural-sounding note. The baseline not only fails to do this but also adds extra noise (K in figure 2). The authors do add that the WaveNet produces some strikes (L in figure 2); however, they argue that they are inaudible.<br />
<br />
[[File:paper27-table1.png|center]]<br />
<br />
=== Quantitative Comparison ===<br />
For a quantitative comparison, the authors trained a separate multi-task classifier to predict the pitch or the quality of a given note. The reconstructions from both the baseline and the WaveNet autoencoder were then fed into this classifier. As seen in table 1, WaveNet significantly outperformed the baseline on both metrics, posting a ~70% increase when only considering pitch.<br />
<br />
== Interpolation in Timbre and Dynamics ==<br />
<br />
[[File:paper27-figure3-interpolation.png|center]]<br />
<br />
For this evaluation, the authors reconstructed audio from linear interpolations in Z space between different instruments and compared these to the superposition of the original two instruments. Not surprisingly, the models fuse aspects of both instruments during the recreation. The authors claim, however, that WaveNet produces much more realistic-sounding results. <br />
To support their claim, the authors point to WaveNet's ability to create a dynamic mixing of overtones in time, even jumping to higher harmonics (A in figure 3), capturing the timbre and dynamics of both the bass and the flute. This can be seen once again in B in figure 3, where WaveNet adds additional harmonics as well as sub-harmonics to the original flute note. <br />
<br />
<br />
== Entanglement of Pitch and Timbre ==<br />
<br />
[[File:paper27-table2.png|center]]<br />
<br />
[[File:paper27-figure4-entanglement.png|center]]<br />
<br />
To study the entanglement between pitch and Z space, the authors constructed a classifier whose accuracy was expected to drop if the representations of pitch and timbre are disentangled, as it relies heavily on pitch information. This is demonstrated by the first two rows of table 2, where WaveNet relies more strongly on pitch than the baseline algorithm. The authors provide a more qualitative demonstration in figure 4. They show a situation in which a classifier may be confused: a note with a pitch of +12 is almost exactly the same as the original apart from the emergence of sub-harmonics.<br />
<br />
Further insight into the relationship between pitch and timbre can be gained by studying the trend of the network embeddings across pitches for specific instruments. This is depicted in figure 5 for several instruments across their entire 88-note range at velocity 127. The figure shows that each instrument separates into two or more registers, within which the embeddings of notes with different pitches are similar. This is expected, since instrumental dynamics and timbre vary dramatically over the range of the instrument.<br />
<br />
= Future Directions =<br />
<br />
One significant area in which the authors say great improvement is needed is the large memory requirement of their algorithm. Because of it, the current WaveNet autoencoder must rely on downsampling and is therefore unable to fully capture global context. <br />
<br />
= Open Source Code base =<br />
<br />
Google has released all code related to this paper at the following open source repository: https://github.com/tensorflow/magenta/tree/master/magenta/models/nsynth<br />
<br />
= References =<br />
<br />
# Engel, J., Resnick, C., Roberts, A., Dieleman, S., Norouzi, M., Eck, D. & Simonyan, K. (2017). Neural Audio Synthesis of Musical Notes with WaveNet Autoencoders. Proceedings of the 34th International Conference on Machine Learning, in PMLR 70:1068-1077<br />
# NSynth: Neural Audio Synthesis. (2017, April 06). Retrieved March 19, 2018, from https://magenta.tensorflow.org/nsynth <br />
# The NSynth Dataset. (2017, April 05). Retrieved March 19, 2018, from https://magenta.tensorflow.org/datasets/nsynth</div>X249wanghttp://wiki.math.uwaterloo.ca/statwiki/index.php?title=Neural_Audio_Synthesis_of_Musical_Notes_with_WaveNet_autoencoders&diff=35396Neural Audio Synthesis of Musical Notes with WaveNet autoencoders2018-03-23T20:17:21Z<p>X249wang: /* WaveNet Autoencoder */</p>
<hr />
<div>= Introduction =<br />
The authors of this paper point out that most synthesized notes are created by hand-designed instruments that modify pitch, velocity and filter parameters to produce the required tone, timbre and dynamics of a sound. The authors suggest that this may be a limitation and propose a data-driven approach to audio synthesis instead. To train such a data-hungry model, the authors highlight the need for a large dataset of musical notes, much as ImageNet serves for images. <br />
<br />
= Contributions =<br />
To solve the problem highlighted above, the authors make two main contributions: <br />
* A WaveNet-style autoencoder that learns to encode temporal data over long-term audio structure without requiring external conditioning<br />
* NSynth: a large dataset of musical notes inspired by the emergence of large image datasets<br />
<br />
<br />
= Models =<br />
<br />
[[File:paper26-figure1-models.png|center]]<br />
<br />
== WaveNet Autoencoder ==<br />
<br />
While the proposed autoencoder structure is very similar to that of WaveNet the authors argue that the algorithm is novel in two ways:<br />
* It is able to attain consistent long-term structure without any external conditioning <br />
* It creates meaningful embeddings that can be interpolated between<br />
The authors accomplish this by passing the raw audio through the encoder to produce an embedding <math>Z = f(x) </math>; the input is then shifted and fed into the decoder, which reproduces the input. The resulting probability distribution is: <br />
<br />
\begin{align}<br />
p(x) = \prod_{i=1}^N p(x_i | x_1, \ldots , x_{i-1}, f(x))<br />
\end{align}<br />
<br />
A detailed block diagram of the modified WaveNet structure can be seen in figure 1b. This diagram shows the encoder as a 30-layer network in which each node is a ReLU nonlinearity followed by a non-causal dilated convolution. Dilated convolution (also known as convolution with holes) is a type of convolution in which the filter skips input values with a certain step (a step size of 1 is equivalent to standard convolution), effectively allowing the network to operate at a coarser scale than traditional convolutional layers and to have very large receptive fields. The resulting convolution has 128 channels, which are all fed into another ReLU nonlinearity and then into another 1x1 convolution, before being downsampled with average pooling to produce a 16-dimensional <math>Z </math> embedding. Each <math>Z </math> encoding covers a specific temporal resolution, which the authors of the paper tuned to 32 ms. This means that there are 125 16-dimensional <math>Z </math> encodings for each 4-second note present in the NSynth database (1984 embeddings). <br />
Before the <math>Z </math> embedding enters the decoder it is first upsampled to the original audio rate using nearest neighbor interpolation. The embedding then passes through the decoder to recreate the original audio note.<br />
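<br />
The pooling and upsampling steps described above can be sketched in NumPy. This is only an illustration under stated assumptions: the encoder output is replaced by random numbers, and the function names are hypothetical rather than taken from the released code. With 16,000 Hz audio and 32 ms frames, each frame spans 512 samples, so a 4-second note yields 125 frames of 16 values (2000 values in total). The figure of 1984 quoted above equals 124 x 16, and this sketch does not attempt to resolve that difference.<br />
<br />
```python
import numpy as np

SAMPLE_RATE = 16_000   # Hz, as stated for NSynth
NOTE_SECONDS = 4       # note length in seconds
HOP_MS = 32            # temporal resolution of one embedding frame
Z_DIM = 16             # embedding channels per frame

hop = SAMPLE_RATE * HOP_MS // 1000        # samples covered by one frame
n_samples = SAMPLE_RATE * NOTE_SECONDS    # samples in one note
n_frames = n_samples // hop               # embedding frames per note


def average_pool(features, hop):
    """Average (time, channels) features over non-overlapping windows of `hop` steps."""
    t, c = features.shape
    usable = (t // hop) * hop
    return features[:usable].reshape(-1, hop, c).mean(axis=1)


def upsample_nearest(z, hop):
    """Nearest-neighbor upsampling: repeat each embedding frame `hop` times in time."""
    return np.repeat(z, hop, axis=0)


rng = np.random.default_rng(0)
features = rng.standard_normal((n_samples, Z_DIM))  # stand-in for the encoder output
z = average_pool(features, hop)                     # one 16-dim vector per 32 ms
z_up = upsample_nearest(z, hop)                     # back at the audio rate
```
<br />
After upsampling, every sample inside a 32 ms window carries the same embedding vector, which is what the decoder would be conditioned on.<br />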
<br />
== Baseline: Spectral Autoencoder ==<br />
Being unable to find an alternative fully deep model to compare against their proposed WaveNet autoencoder, the authors built a strong baseline of their own. The baseline algorithm that the authors developed is a spectral autoencoder. The block diagram of its architecture can be seen in figure 1a. The baseline network is 10 layers deep. Each layer has 4x4 kernels with 2x2 strides, followed by a leaky ReLU (0.1) and batch normalization. The size of the final hidden vector (Z) was set to 1984 to exactly match the hidden vector of the WaveNet autoencoder. <br />
<br />
The authors attempted to train the baseline on multiple input representations: raw waveforms, FFT, and the log magnitude of the spectrum, finding the latter to be best correlated with perceptual distortion. The authors also explored several representations of phase, finding that estimating magnitude and using established iterative techniques to reconstruct phase was most effective. A final heuristic used by the authors to increase the accuracy of the baseline was weighting the mean squared error (MSE) loss, starting at 10 for 0 Hz and decreasing linearly to 1 at 4000 Hz and above. This is reasonable, as the fundamental frequencies of most instruments are found at lower frequencies. <br />
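<br />
The frequency weighting of the MSE loss can be written down directly from the description above. The following is a minimal sketch, assuming the weight is applied per frequency bin of a (frequency, time) spectrogram; the function names and the array layout are hypothetical and not taken from the authors' code.<br />
<br />
```python
import numpy as np


def mse_frequency_weights(freqs_hz, w_low=10.0, w_high=1.0, cutoff_hz=4000.0):
    """Weight that falls linearly from w_low at 0 Hz to w_high at cutoff_hz and stays there."""
    freqs_hz = np.asarray(freqs_hz, dtype=float)
    frac = np.clip(freqs_hz / cutoff_hz, 0.0, 1.0)
    return w_low + (w_high - w_low) * frac


def weighted_mse(pred, target, weights):
    """Weighted MSE for (freq_bins, frames) arrays with one weight per frequency bin."""
    err = (np.asarray(pred) - np.asarray(target)) ** 2
    return float(np.mean(weights[:, None] * err))


freqs = np.array([0.0, 2000.0, 4000.0, 8000.0])
weights = mse_frequency_weights(freqs)
```
<br />
Errors in the lowest bins count ten times as much as errors above 4000 Hz, which pushes the model to get fundamentals right first.<br />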
<br />
== Training ==<br />
Both the modified WaveNet and the baseline autoencoder were trained using stochastic gradient descent with an Adam optimizer. The authors trained the baseline autoencoder asynchronously for 1,800,000 iterations with a batch size of 8 and a learning rate of 1e-4, whereas the WaveNet models were trained synchronously for 250,000 iterations with a batch size of 32 and a decaying learning rate ranging from 2e-4 to 6e-6.<br />
<br />
= The NSynth Dataset =<br />
The NSynth dataset has 306,043 unique musical notes, all 4 seconds in length and sampled at 16,000 Hz. The dataset consists of 1006 different instruments, each playing on average 65.4 different pitches at on average 4.75 different velocities. Averages are reported because not all instruments can reach all 88 MIDI pitches or all 5 velocities desired by the authors. The dataset has the following split: training set with 289,205 notes, validation set with 12,678 notes, and test set with 4,096 notes.<br />
<br />
Along with each note the authors also included the following annotations:<br />
* Source - The way each sound was produced. There are 3 classes: ‘acoustic’, ‘electronic’ and ‘synthetic’<br />
* Family - The instrument family that produced each note. There are 11 classes, including ‘bass’, ‘brass’, ‘vocal’, etc.<br />
* Qualities - Sonic qualities about each note<br />
<br />
The full dataset is publicly available here: https://magenta.tensorflow.org/datasets/nsynth.<br />
<br />
<br />
<br />
= Evaluation =<br />
<br />
To fully analyze all aspects of WaveNet the authors proposed three evaluations:<br />
* Reconstruction - Both Quantitative and Qualitative analysis were considered<br />
* Interpolation in Timbre and Dynamics<br />
* Entanglement of Pitch and Timbre <br />
<br />
Sound is historically very difficult to quantify from a picture representation, as it requires training and expertise to analyze. Even with expertise it can be difficult to complete a full analysis, as two very different sounds can look quite similar in their respective pictorial representations. This is why the authors recommend that all readers listen to the created notes, which can be found here: https://magenta.tensorflow.org/nsynth.<br />
<br />
However, even taking this into consideration, the authors do demonstrate pictorially the differences between the two proposed algorithms and the original note, as it is hard to publish a paper with sound included. To do so, the authors plot each note using the constant-Q transform (CQT), which captures the dynamics of timbre while also representing the frequencies of the sound.<br />
<br />
== Reconstruction ==<br />
<br />
[[File:paper27-figure2-reconstruction.png|center]]<br />
<br />
=== Qualitative Comparison ===<br />
In the glockenspiel, the WaveNet autoencoder is able to reproduce the magnitude and phase of the fundamental frequency (A and C in figure 2) and the attack (B in figure 2) of the instrument, whereas the baseline autoencoder introduces non-existent harmonics (D in figure 2). The flugelhorn, on the other hand, presents the starkest difference between the WaveNet and baseline autoencoders. The WaveNet autoencoder, while not perfect, is able to reproduce the vibrato (I and J in figure 2) across multiple frequencies, which results in a natural-sounding note. The baseline not only fails to do this but also adds extra noise (K in figure 2). The authors do add that the WaveNet autoencoder produces some strikes (L in figure 2); however, they argue that they are inaudible.<br />
<br />
[[File:paper27-table1.png|center]]<br />
<br />
=== Quantitative Comparison ===<br />
For a quantitative comparison, the authors trained a separate multi-task classifier to predict the pitch or the quality of a given note. The reconstructions from both the baseline and the WaveNet autoencoder were then fed into this classifier. As seen in table 1, WaveNet significantly outperformed the baseline on both metrics, posting a ~70% increase when only considering pitch.<br />
<br />
== Interpolation in Timbre and Dynamics ==<br />
<br />
[[File:paper27-figure3-interpolation.png|center]]<br />
<br />
For this evaluation, the authors reconstructed audio from linear interpolations in Z space between different instruments and compared these to the superposition of the original two instruments. Not surprisingly, the models fuse aspects of both instruments during the recreation. The authors claim, however, that WaveNet produces much more realistic-sounding results. <br />
To support their claim, the authors point to WaveNet's ability to create a dynamic mixing of overtones in time, even jumping to higher harmonics (A in figure 3), capturing the timbre and dynamics of both the bass and the flute. This can be seen once again in B in figure 3, where WaveNet adds additional harmonics as well as sub-harmonics to the original flute note. <br />
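<br />
The two things being compared in this evaluation can be made concrete with a short sketch. The decoder itself is not shown, and both function names are hypothetical: one mixes the embeddings of two notes (the result would then be decoded to audio), the other mixes the raw waveforms directly.<br />
<br />
```python
import numpy as np


def interpolate_embeddings(z_a, z_b, alpha):
    """Linear interpolation in Z space; the result would be passed to the decoder."""
    return (1.0 - alpha) * np.asarray(z_a) + alpha * np.asarray(z_b)


def superimpose_audio(x_a, x_b, alpha):
    """Reference signal: the same mixing weights applied directly to the waveforms."""
    return (1.0 - alpha) * np.asarray(x_a) + alpha * np.asarray(x_b)


# Toy stand-ins for two notes: (frames, channels) embeddings.
z_bass = np.zeros((125, 16))
z_flute = np.ones((125, 16))
z_mid = interpolate_embeddings(z_bass, z_flute, 0.5)
```
<br />
Because the decoder is nonlinear, decoding an interpolated embedding is not the same as superimposing the two decoded notes, which is why the two can sound so different.<br />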
<br />
<br />
== Entanglement of Pitch and Timbre ==<br />
<br />
[[File:paper27-table2.png|center]]<br />
<br />
[[File:paper27-figure4-entanglement.png|center]]<br />
<br />
To study the entanglement between pitch and Z space, the authors constructed a classifier whose accuracy was expected to drop if the representations of pitch and timbre are disentangled, as it relies heavily on pitch information. This is demonstrated by the first two rows of table 2, where WaveNet relies more strongly on pitch than the baseline algorithm. The authors provide a more qualitative demonstration in figure 4. They show a situation in which a classifier may be confused: a note with a pitch of +12 is almost exactly the same as the original apart from the emergence of sub-harmonics.<br />
<br />
Further insight into the relationship between pitch and timbre can be gained by studying the trend of the network embeddings across pitches for specific instruments. This is depicted in figure 5 for several instruments across their entire 88-note range at velocity 127. The figure shows that each instrument separates into two or more registers, within which the embeddings of notes with different pitches are similar. This is expected, since instrumental dynamics and timbre vary dramatically over the range of the instrument.<br />
<br />
= Future Directions =<br />
<br />
One significant area in which the authors say great improvement is needed is the large memory requirement of their algorithm. Because of it, the current WaveNet autoencoder must rely on downsampling and is therefore unable to fully capture global context. <br />
<br />
= Open Source Code base =<br />
<br />
Google has released all code related to this paper at the following open source repository: https://github.com/tensorflow/magenta/tree/master/magenta/models/nsynth<br />
<br />
= References =<br />
<br />
# Engel, J., Resnick, C., Roberts, A., Dieleman, S., Norouzi, M., Eck, D. & Simonyan, K. (2017). Neural Audio Synthesis of Musical Notes with WaveNet Autoencoders. Proceedings of the 34th International Conference on Machine Learning, in PMLR 70:1068-1077<br />
# NSynth: Neural Audio Synthesis. (2017, April 06). Retrieved March 19, 2018, from https://magenta.tensorflow.org/nsynth <br />
# The NSynth Dataset. (2017, April 05). Retrieved March 19, 2018, from https://magenta.tensorflow.org/datasets/nsynth</div>X249wanghttp://wiki.math.uwaterloo.ca/statwiki/index.php?title=Spherical_CNNs&diff=35149Spherical CNNs2018-03-22T05:54:16Z<p>X249wang: /* Correlations on the Sphere and Rotation Group */</p>
<hr />
<div>= Introduction =<br />
Convolutional Neural Networks (CNNs), or network architectures involving CNNs, are the current state of the art for learning 2D image processing tasks such as semantic segmentation and object detection. CNNs work well in large part due to the property of being translationally equivariant. This property allows a network trained to detect a certain type of object to still detect the object even if it is translated to another position in the image. However, this does not correspond well to spherical signals since projecting a spherical signal onto a plane will result in distortions, as demonstrated in Figure 1. There are many different types of spherical projections onto a 2D plane, as most people know from the various types of world maps, none of which provide all the necessary properties for rotation-invariant learning.<br />
<br />
[[File:paper26-fig1.png|center]]<br />
<br />
= Notation =<br />
Below are listed several important terms:<br />
* '''Unit Sphere''' <math>S^2</math> is defined as the set of points at distance 1 from the origin. The unit sphere can be parameterized by the spherical coordinates <math>\alpha \in [0, 2\pi]</math> and <math>\beta \in [0, \pi]</math>. This is a two-dimensional manifold with respect to <math>\alpha</math> and <math>\beta</math>.<br />
* '''<math>S^2</math> Sphere''' The two-dimensional surface of a sphere embedded in 3D space<br />
* '''Spherical Signals''' In the paper, spherical images and filters are modeled as continuous functions <math>f : S^2 \to \mathbb{R}^K</math>, where K is the number of channels. Just as RGB images have 3 channels, a spherical signal can have numerous channels describing the data. Examples of channels which were used can be found in the experiments section.<br />
* '''Rotations - SO(3)''' The group of 3D rotations on an <math>S^2</math> sphere. Sometimes called the "special orthogonal group". In this paper the ZYZ-Euler parameterization is used to represent SO(3) rotations with <math>\alpha, \beta</math>, and <math>\gamma</math>. Any rotation can be broken down into first a rotation (<math>\alpha</math>) about the Z-axis, then a rotation (<math>\beta</math>) about the new Y-axis (Y'), followed by a rotation (<math>\gamma</math>) about the new Z axis (Z"). [In the rest of this paper, to integrate functions on SO(3), the authors use a rotationally invariant probability measure on the Borel subsets of SO(3). This measure is an example of a Haar measure. Haar measures generalize the idea of rotationally invariant probability measures to general topological groups. For more on Haar measures, see (Feldman 2002) ]<br />
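<br />
The ZYZ-Euler parameterization in the last item can be turned into a rotation matrix with a few lines of NumPy. This is a sketch with hypothetical helper names; it uses the standard fact that intrinsic rotations about Z, then the new Y', then the new Z" compose as the matrix product <math>R_z(\alpha) R_y(\beta) R_z(\gamma)</math>.<br />
<br />
```python
import numpy as np


def rot_z(a):
    """Rotation by angle a about the Z axis."""
    c, s = np.cos(a), np.sin(a)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])


def rot_y(b):
    """Rotation by angle b about the Y axis."""
    c, s = np.cos(b), np.sin(b)
    return np.array([[c, 0.0, s], [0.0, 1.0, 0.0], [-s, 0.0, c]])


def zyz_to_matrix(alpha, beta, gamma):
    """Rotation matrix for ZYZ-Euler angles (intrinsic Z, Y', Z'' rotations)."""
    return rot_z(alpha) @ rot_y(beta) @ rot_z(gamma)


R = zyz_to_matrix(0.3, 1.1, -0.7)
north_pole = np.array([0.0, 0.0, 1.0])
point = R @ north_pole   # gamma has no effect on the north pole
```
<br />
Applying the rotation to the north pole gives the point with spherical coordinates <math>(\alpha, \beta)</math>, which ties the Euler angles back to the parameterization of <math>S^2</math> above.<br />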
<br />
= Related Work =<br />
The related work presented in this paper is very brief, in large part due to the novelty of spherical CNNs and the length of the rest of the paper. The authors enumerate numerous papers which attempt to exploit larger groups of symmetries such as the translational symmetries of CNNs but do not go into specific details for any of these attempts. They do state that all the previous works are limited to discrete groups with the exception of SO(2)-steerable networks.<br />
The authors also mention that previous works exist that analyze spherical images but that these do not have an equivariant architecture. They claim that Spherical CNNs are "the first to achieve equivariance to a continuous, non-commutative group (SO(3))". They also claim to be the first to use the generalized Fourier transform for speed effective performance of group correlation.<br />
<br />
= Correlations on the Sphere and Rotation Group =<br />
Spherical correlation is like planar correlation except instead of translation, there is rotation. The definitions for each are provided as follows:<br />
<br />
'''Planar correlation''' The value of the output feature map at translation <math>\small x \in \mathbb{Z}^2</math> is computed as an inner product between the input feature map and a filter, shifted by <math>\small x</math>.<br />
<br />
'''Spherical correlation''' The value of the output feature map evaluated at rotation <math>\small R \in SO(3)</math> is computed as an inner product between the input feature map and a filter, rotated by <math>\small R</math>.<br />
<br />
'''Rotation of Spherical Signals''' The paper introduces the rotation operator <math>L_R</math>. The rotation operator takes a function <math>f(x)</math> and produces the rotated function <math>[L_R f](x) = f(R^{-1}x)</math>, which is what allows the spherical filters to be rotated. With this definition we have the property that <math>L_{RR'} = L_R L_{R'}</math>.<br />
<br />
'''Inner Products''' The inner product of two spherical signals is the integral over the entire sphere of their channel-wise products, summed over the channels.<br />
<br />
<math>\langle\psi , f \rangle = \int_{S^2} \sum_{k=1}^K \psi_k (x)f_k (x)dx</math><br />
<br />
<math>dx</math> here is SO(3) rotation invariant and is equivalent to <math>d \alpha \sin(\beta) d \beta / 4 \pi </math> in spherical coordinates. This comes from the ZYZ-Euler parameterization, where any rotation can be broken down into first a rotation about the Z-axis, then a rotation about the new Y-axis (Y'), followed by a rotation about the new Z axis (Z"). More details on this are given in Appendix A in the paper.<br />
<br />
By this definition, the invariance of the inner product is then guaranteed for any rotation <math>R ∈ SO(3)</math>. In other words, when subjected to rotations, the volume under a spherical heightmap does not change. The following equations show that <math>L_R</math> has a distinct adjoint (<math>L_{R^{-1}}</math>) and that <math>L_R</math> is unitary and thus preserves orthogonality and distances.<br />
<br />
<math>\langle L_R \psi \,, f \rangle = \int_{S^2} \sum_{k=1}^K \psi_k (R^{-1} x)f_k (x)dx</math><br />
<br />
::::<math>= \int_{S^2} \sum_{k=1}^K \psi_k (x)f_k (Rx)dx</math><br />
<br />
::::<math>= \langle \psi , L_{R^{-1}} f \rangle</math><br />
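<br />
Both the normalized measure and the adjoint identity above can be checked numerically. The sketch below is an illustration, not the authors' implementation: it uses a single channel (K = 1), two arbitrary polynomial test functions, and a product quadrature (uniform in <math>\alpha</math>, Gauss-Legendre in <math>\cos\beta</math>) that is exact for such low-degree polynomials. All names are hypothetical.<br />
<br />
```python
import numpy as np


def rot_z(a):
    c, s = np.cos(a), np.sin(a)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])


def rot_y(b):
    c, s = np.cos(b), np.sin(b)
    return np.array([[c, 0.0, s], [0.0, 1.0, 0.0], [-s, 0.0, c]])


def sphere_quadrature(n_alpha=32, n_beta=16):
    """Points and weights for integrating with dx = d(alpha) sin(beta) d(beta) / (4 pi)."""
    alpha = 2.0 * np.pi * np.arange(n_alpha) / n_alpha
    u, w_u = np.polynomial.legendre.leggauss(n_beta)   # u = cos(beta)
    beta = np.arccos(u)
    A, B = np.meshgrid(alpha, beta, indexing="ij")
    weights = (2.0 * np.pi / n_alpha) * np.broadcast_to(w_u, A.shape) / (4.0 * np.pi)
    points = np.stack([np.cos(A) * np.sin(B), np.sin(A) * np.sin(B), np.cos(B)], axis=-1)
    return points.reshape(-1, 3), weights.reshape(-1)


def psi(p):
    """Arbitrary single-channel test filter on the sphere."""
    return p[:, 0] + 2.0 * p[:, 1] * p[:, 2]


def f(p):
    """Arbitrary single-channel test signal on the sphere."""
    return p[:, 2] ** 2 + p[:, 0] * p[:, 1] - p[:, 1]


def inner(a_vals, b_vals, w):
    """Quadrature approximation of the inner product <a, b>."""
    return float(np.sum(w * a_vals * b_vals))


pts, w = sphere_quadrature()
R = rot_z(0.3) @ rot_y(1.1) @ rot_z(-0.7)

total = float(np.sum(w))                      # measure of the whole sphere
# Rows of `pts @ R` are R^{-1} x, rows of `pts @ R.T` are R x.
lhs = inner(psi(pts @ R), f(pts), w)          # <L_R psi, f>
rhs = inner(psi(pts), f(pts @ R.T), w)        # <psi, L_{R^{-1}} f>
both = inner(psi(pts @ R), f(pts @ R), w)     # <L_R psi, L_R f>
plain = inner(psi(pts), f(pts), w)            # <psi, f>
```
<br />
The total measure comes out as 1, the two sides of the adjoint identity agree, and rotating both arguments leaves the inner product unchanged.<br />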
<br />
'''Spherical Correlation''' With the above knowledge the definition of spherical correlation of two signals <math>f</math> and <math>\psi</math> is:<br />
<br />
<math>[\psi \star f](R) = \langle L_R \psi \,, f \rangle = \int_{S^2} \sum_{k=1}^K \psi_k (R^{-1} x)f_k (x)dx</math><br />
<br />
The output of the above equation is a function on SO(3). This can be thought of as follows: for each rotation given by a combination of <math>\alpha , \beta , \gamma </math> there is a different volume under the correlation. The authors make a point of noting that previous work by Driscoll and Healy only ensures circular symmetries about the Z axis, whereas their new formulation ensures symmetry about any rotation.<br />
<br />
'''Rotation of SO(3) Signals''' The first layer of Spherical CNNs take a function on the sphere (<math>S^2</math>) and output a function on SO(3). Therefore, if a Spherical CNN with more than one layer is going to be built there needs to be a way to find the correlation between two signals on SO(3). The authors then generalize the rotation operator (<math>L_R</math>) to encompass acting on signals from SO(3). This new definition of <math>L_R</math> is as follows: (where <math>R^{-1}Q</math> is a composition of rotations, i.e. multiplication of rotation matrices)<br />
<br />
<math>[L_Rf](Q)=f(R^{-1} Q)</math><br />
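<br />
A small numerical check makes the definition above concrete and confirms that the generalized operator still satisfies <math>L_{RR'} = L_R L_{R'}</math>. The sketch represents rotations as 3x3 matrices and uses an arbitrary test function on SO(3); the helper names and the test function are hypothetical.<br />
<br />
```python
import numpy as np


def rot_z(a):
    c, s = np.cos(a), np.sin(a)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])


def rot_y(b):
    c, s = np.cos(b), np.sin(b)
    return np.array([[c, 0.0, s], [0.0, 1.0, 0.0], [-s, 0.0, c]])


def zyz(alpha, beta, gamma):
    return rot_z(alpha) @ rot_y(beta) @ rot_z(gamma)


def rotate_signal(f, R):
    """Return L_R f, defined by [L_R f](Q) = f(R^{-1} Q)."""
    return lambda Q: f(R.T @ Q)   # R^{-1} = R^T for rotation matrices


A = np.arange(9.0).reshape(3, 3)


def f(Q):
    """Arbitrary scalar test signal on SO(3)."""
    return float(np.trace(A @ Q) + Q[0, 2] ** 2)


R1 = zyz(0.4, 1.0, -0.2)
R2 = zyz(-1.3, 0.5, 2.0)
Q = zyz(0.9, 2.1, 0.3)

lhs = rotate_signal(f, R1 @ R2)(Q)                   # [L_{R1 R2} f](Q)
rhs = rotate_signal(rotate_signal(f, R2), R1)(Q)     # [L_{R1} L_{R2} f](Q)
```
<br />
Both sides evaluate <math>f</math> at <math>(R_1 R_2)^{-1} Q</math>, so they agree up to floating-point error.<br />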
<br />
'''Rotation Group Correlation''' The correlation of two signals (<math>f,\psi</math>) on SO(3) with K channels is defined as the following:<br />
<br />
<math>[\psi \star f](R) = \langle L_R \psi , f \rangle = \int_{SO(3)} \sum_{k=1}^K \psi_k (R^{-1} Q)f_k (Q)dQ</math><br />
<br />
where dQ is the invariant measure on SO(3), which in ZYZ-Euler angles is <math>d \alpha \sin(\beta) d \beta d \gamma / 8 \pi^2 </math>. A complete derivation of this can be found in Appendix A.<br />
<br />
'''Equivariance''' The equivariance for the rotation group correlation is similarly demonstrated. A layer is equivariant if for some operator <math>T_R</math>, <math>\Phi \circ L_R = T_R \circ \Phi</math>, and: <br />
<br />
<math>[\psi \star [L_Qf]](R) = \langle L_R \psi , L_Qf \rangle = \langle L_{Q^{-1} R} \psi , f \rangle = [\psi \star f](Q^{-1}R) = [L_Q[\psi \star f]](R) </math>.<br />
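<br />
The same chain of equalities can be checked on a much simpler group. The sketch below replaces SO(3) with the cyclic group of integer shifts modulo n, which is an analogy chosen for illustration and not the spherical case itself: the "rotation" operator becomes a circular shift and the correlation becomes a circular cross-correlation. All names are hypothetical.<br />
<br />
```python
import numpy as np


def shift(f, t):
    """[L_t f](y) = f(y - t) on the cyclic group Z_n."""
    return np.roll(f, t)


def correlate(psi, f):
    """[psi * f](x) = <L_x psi, f> = sum_y psi(y - x) f(y)."""
    n = len(f)
    return np.array([np.dot(shift(psi, x), f) for x in range(n)])


rng = np.random.default_rng(0)
f = rng.standard_normal(12)
psi = rng.standard_normal(12)
t = 5

shift_then_correlate = correlate(psi, shift(f, t))   # psi * (L_t f)
correlate_then_shift = shift(correlate(psi, f), t)   # L_t (psi * f)
```
<br />
Shifting the input and then correlating gives the same result as correlating first and shifting the output, which is exactly the equivariance property stated above with rotations in place of shifts.<br />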
<br />
= Implementation with GFFT =<br />
The authors leverage the Generalized Fourier Transform (GFT) and Generalized Fast Fourier Transform (GFFT) algorithms to compute the correlations outlined in the previous section. The Fast Fourier Transform (FFT) can compute correlations and convolutions efficiently by means of the Fourier theorem, which states that the Fourier transform of a convolution or correlation of two functions can be obtained as a pointwise product of their Fourier transforms (with one of them conjugated in the case of correlation). The Fourier transform can be generalized to <math>S^2</math> and SO(3) and is then called the GFT. The GFT is a linear projection of a function onto orthogonal basis functions. The basis functions are a set of irreducible unitary representations for a group (such as for <math>S^2</math> or SO(3)). For <math>S^2</math> the basis functions are the spherical harmonics <math>Y_m^l(x)</math>. For SO(3) these basis functions are called the Wigner D-functions <math>D_{mn}^l(R)</math>. For both sets of functions the indices are restricted to <math>l\geq0</math> and <math>-l \leq m,n \leq l</math>. The Wigner D-functions are also orthogonal, so the Fourier coefficients can be computed by the inner product with the Wigner D-functions (see Appendix C for the complete proof). The Wigner D-functions are complete, which means that any well-behaved function on SO(3) can be expressed as a linear combination of the Wigner D-functions. The GFT of a function on SO(3) is thus:<br />
<br />
<math>\hat{f}^l = \int_X f(x) \overline{D^l(x)}dx</math><br />
<br />
where <math>\hat{f}</math> represents the Fourier coefficients. For <math>S^2</math> we have the same equation but with the basis functions <math>Y^l</math>.<br />
<br />
The inverse SO(3) Fourier transform is:<br />
<br />
<math>f(R)=[\mathcal{F}^{-1} \hat{f}](R) = \sum_{l=0}^b (2l + 1) \sum_{m=-l}^l \sum_{n=-l}^l \hat{f_{mn}^l} D_{mn}^l(R) </math><br />
<br />
The bandwidth b represents the maximum frequency and is related to the resolution of the spatial grid. Kostelec and Rockmore are referenced for more knowledge on this topic.<br />
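<br />
The index ranges in the inverse transform determine how many Fourier coefficients a bandlimited signal has. The sketch below counts them, following the summation limits exactly as written above (<math>l</math> from 0 to <math>b</math> inclusive); some references stop at <math>b-1</math> instead, so the inclusive limit is an assumption carried over from this summary. The function names are hypothetical.<br />
<br />
```python
def num_so3_coefficients(b):
    """Count coefficients with l = 0..b and -l <= m, n <= l (a (2l+1) x (2l+1) block per l)."""
    return sum((2 * l + 1) ** 2 for l in range(b + 1))


def num_s2_coefficients(b):
    """Count spherical-harmonic coefficients with l = 0..b and -l <= m <= l."""
    return sum(2 * l + 1 for l in range(b + 1))


so3_count = num_so3_coefficients(10)
s2_count = num_s2_coefficients(10)
```
<br />
The SO(3) count grows cubically in the bandwidth while the <math>S^2</math> count grows quadratically, which is one reason SO(3) layers are the more expensive ones.<br />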
<br />
The authors give proofs (Appendix D) that the SO(3) correlation satisfies the Fourier theorem and the <math>S^2</math> correlation of spherical signals can be computed by the outer products of the <math>S^2</math>-FTs (Shown in Figure 2).<br />
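<br />
The idea behind the Fourier theorem is easiest to see in the ordinary one-dimensional case. The sketch below is an analogy on the cyclic group using NumPy's standard FFT, not the generalized transform on <math>S^2</math> or SO(3): it computes a circular cross-correlation directly and via a pointwise product in the Fourier domain.<br />
<br />
```python
import numpy as np


def correlate_direct(psi, f):
    """[psi * f](x) = sum_y psi(y - x) f(y), computed with an explicit loop over x."""
    n = len(f)
    return np.array([np.dot(np.roll(psi, x), f) for x in range(n)])


def correlate_fft(psi, f):
    """Same correlation via the Fourier theorem: pointwise product in the Fourier domain."""
    spectrum = np.conj(np.fft.fft(psi)) * np.fft.fft(f)
    return np.real(np.fft.ifft(spectrum))   # inputs are real, so the result is real


rng = np.random.default_rng(1)
f = rng.standard_normal(64)
psi = rng.standard_normal(64)

direct = correlate_direct(psi, f)
fast = correlate_fft(psi, f)
```
<br />
The direct version costs on the order of <math>n^2</math> operations while the FFT version costs on the order of <math>n \log n</math>; the GFFT plays the same role for correlations over rotations.<br />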
<br />
[[File:paper26-fig2.png|center]]<br />
<br />
The GFFT algorithm details are taken from Kostelec and Rockmore. The authors claim they have the first automatically differentiable implementation of the GFT for <math>S^2</math> and SO(3). The authors do not provide any run time comparisons for real time applications (they just mentioned that FFT can be computed in <math>O(n\mathrm{log}n)</math> time) or any comparisons on training times with/without GFFT. However, they do provide the source code of their implementation at: https://github.com/jonas-koehler/s2cnn.<br />
<br />
= Experiments =<br />
The authors provide several experiments. The first set of experiments are designed to show the numerical stability and accuracy of the outlined methods. The second group of experiments demonstrates how the algorithms can be applied to current problem domains.<br />
<br />
==Equivariance Error==<br />
In this experiment the authors try to show experimentally that their theory of equivariance holds. They express that they had doubts about the equivariance in practice due to potential discretization artifacts, since equivariance was proven only for the continuous case; if equivariance did not hold, the weight sharing scheme would become less effective. The experiment is set up by first testing the equivariance of the SO(3) correlation at different resolutions. 500 random rotations and feature maps (with 10 channels) are sampled. They then calculate the approximation error <math>\small\Delta = \dfrac{1}{n} \sum_{i=1}^n \mathrm{std}(L_{R_i} \Phi(f_i) - \Phi(L_{R_i} f_i))/\mathrm{std}(\Phi(f_i))</math><br />
Note: The authors do not say what the std function is; it is most likely the standard deviation, as 'std' is the command for standard deviation in MATLAB.<br />
<math>\Phi</math> is a composition of SO(3) correlation layers with randomly initialized filters. The authors expect <math>\Delta</math> to be zero in the case of perfect equivariance, because, as proven earlier, the two terms in the difference <math>\small L_{R_i} \Phi(f_i) - \Phi(L_{R_i} f_i)</math> are equal in the continuous case. The results are shown in Figure 3. <br />
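<br />
Given the three sets of feature maps involved, the error measure can be computed in a few lines. This is a sketch under stated assumptions: std is taken to be the standard deviation over all elements of a feature map, and the network and rotation operators are assumed to have been applied elsewhere. The function and argument names are hypothetical.<br />
<br />
```python
import numpy as np


def equivariance_error(rotated_outputs, outputs_of_rotated, outputs):
    """Mean over samples of std(L_R Phi(f) - Phi(L_R f)) / std(Phi(f)).

    rotated_outputs[i]    : L_{R_i} Phi(f_i)
    outputs_of_rotated[i] : Phi(L_{R_i} f_i)
    outputs[i]            : Phi(f_i)
    """
    ratios = [
        np.std(np.asarray(a) - np.asarray(b)) / np.std(np.asarray(c))
        for a, b, c in zip(rotated_outputs, outputs_of_rotated, outputs)
    ]
    return float(np.mean(ratios))


rng = np.random.default_rng(0)
phi_f = [rng.standard_normal((10, 8, 8, 8)) for _ in range(3)]   # toy feature maps

perfect = equivariance_error(phi_f, phi_f, phi_f)                # both orders agree exactly
noisy = equivariance_error(
    [x + 0.01 * rng.standard_normal(x.shape) for x in phi_f], phi_f, phi_f
)
```
<br />
When the two orders of operation agree exactly the error is 0, and small discrepancies show up as a small positive value relative to the scale of the feature maps.<br />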
<br />
[[File:paper26-fig3.png|center]]<br />
<br />
<math>\Delta</math> only grows with resolution and number of layers when there is no activation function. With ReLU activation, the error stays roughly constant as the resolution increases. The authors indicate that the error must therefore come from the feature map rotation, since this rotation is exact only for bandlimited functions.<br />
<br />
==MNIST Data==<br />
The experiment using MNIST data was created by projecting MNIST digits onto a sphere using stereographic projection to create the resulting images as seen in Figure 4.<br />
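<br />
A stereographic mapping between a flat image and the sphere can be sketched as follows. The exact convention used in the paper (which pole, how the image is scaled and centered) is not given in this summary, so the choice below is an assumption: pixel coordinates on the plane <math>z = 0</math> are mapped to the unit sphere by inverse stereographic projection from the north pole. The function name is hypothetical.<br />
<br />
```python
import numpy as np


def plane_to_sphere(x, y):
    """Inverse stereographic projection from the north pole onto the unit sphere."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    r2 = x ** 2 + y ** 2
    return np.stack([2.0 * x, 2.0 * y, r2 - 1.0], axis=-1) / (r2 + 1.0)[..., None]


# A small grid standing in for the pixel centers of a digit image.
xs, ys = np.meshgrid(np.linspace(-1.0, 1.0, 5), np.linspace(-1.0, 1.0, 5))
sphere_points = plane_to_sphere(xs, ys)   # shape (5, 5, 3)
```
<br />
With this convention the image center lands on the south pole and the unit circle lands on the equator, so an image scaled to fit inside the unit circle covers the lower hemisphere.<br />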
<br />
[[File:paper26-fig4.png|center]]<br />
<br />
The authors created two datasets, one with the projected digits and the other with the same projected digits which were then subjected to a random rotation. The spherical CNN architecture used was <math>\small S^2</math>conv-ReLU-SO(3)conv-ReLU-FC-softmax and was attempted with bandwidths of 30,10,6 and 20,40,10 channels for each layer respectively. This model was compared to a baseline CNN with layers conv-ReLU-conv-ReLU-FC-softmax with 5x5 filters, 32,64,10 channels and stride of 3. For comparison this leads to approximately 68K parameters for the baseline and 58K parameters for the spherical CNN. Results can be seen in Table 1. It is clear from the results that the spherical CNN architecture made the network rotationally invariant. Performance on the rotated set is almost identical to the non-rotated set. This is true even when trained on the non-rotated set and tested on the rotated set. Compare this to the non-spherical architecture which becomes unusable when rotating the digits.<br />
<br />
[[File:paper26-tab1.png|center]]<br />
<br />
==SHREC17==<br />
The SHREC dataset contains 3D models from the ShapeNet dataset which are classified into categories. It consists of a regularly aligned dataset and a rotated dataset. The models from the SHREC17 dataset were projected onto a sphere by means of raycasting. Different properties of the objects obtained from the raycast of the original model and the convex hull of the model make up the different channels which are input into the spherical CNN.<br />
<br />
<br />
[[File:paper26-fig5.png|center]]<br />
<br />
<br />
The network architecture used is an initial <math>\small S^2</math>conv-BN-ReLU block which is followed by two SO(3)conv-BN-ReLU blocks. The output is then fed into a MaxPool-BN block then a linear layer to the output for final classification. The architecture for this experiment has ~1.4M parameters, far exceeding the scale of the spherical CNNs in the other experiments.<br />
<br />
This architecture achieves results competitive with the state of the art on the SHREC17 tasks. The model would place 2nd or 3rd in all categories, but it was not submitted as the SHREC17 task is closed. Table 2 shows the comparison of results with the top 3 submissions in each category. In the table, P@N stands for precision, R@N stands for recall, F1@N stands for F-score, mAP stands for mean average precision, and NDCG stands for normalized discounted cumulative gain in relevance based on whether the category and subcategory labels are predicted correctly. The authors claim the results show empirical proof of the usefulness of spherical CNNs. They elaborate that this is largely due to the fact that most architectures in the SHREC17 competition are highly specialized whereas their model is fairly general.<br />
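<br />
For readers unfamiliar with the retrieval metrics in the table, the first three can be illustrated for a single query. This is a generic sketch of precision, recall and F-score for one retrieved list; the exact SHREC17 evaluation protocol (how N is chosen and how scores are averaged) is not described in this summary and is not reproduced here. The function name is hypothetical.<br />
<br />
```python
def precision_recall_f1(retrieved, relevant):
    """Precision, recall and F-score of one retrieved list against the relevant set."""
    retrieved = list(retrieved)
    relevant = set(relevant)
    hits = sum(1 for item in retrieved if item in relevant)
    precision = hits / len(retrieved) if retrieved else 0.0
    recall = hits / len(relevant) if relevant else 0.0
    f1 = 2.0 * precision * recall / (precision + recall) if hits else 0.0
    return precision, recall, f1


# Four models retrieved for a query, three models are actually relevant.
p, r, f1 = precision_recall_f1(["m1", "m2", "m3", "m4"], {"m2", "m4", "m5"})
```
<br />
Precision penalizes irrelevant results in the returned list, recall penalizes relevant models that were missed, and the F-score is their harmonic mean.<br />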
<br />
<br />
[[File:paper26-tab2.png|center]]<br />
<br />
==Molecular Atomization==<br />
In this experiment a spherical CNN is implemented with an architecture resembling that of ResNet. They use the QM7 dataset which has the task of predicting atomization energy of molecules. The positions and charges given in the dataset are projected onto the sphere using potential functions. A summary of their results is shown in Table 3 along with some of the spherical CNN architecture details. It shows the different RMSE obtained from different methods. The results from this final experiment also seem to be promising as the network the authors present achieves the second best score. They also note that the first place method grows exponentially with the number of atoms per molecule so is unlikely to scale well.<br />
<br />
[[File:paper26-tab3.png|center]]<br />
<br />
= Conclusions =<br />
This paper presents a novel architecture called Spherical CNNs. The paper defines <math>\small S^2</math> and SO(3) cross correlations, shows the theory behind their rotational equivariance for continuous functions, and demonstrates empirically that the equivariance also holds approximately in the discrete case. An efficient GFFT algorithm was implemented and evaluated on two very different datasets with close to state of the art results, demonstrating that there are practical applications for Spherical CNNs.<br />
<br />
For future work the authors believe that improvements can be obtained by generalizing the algorithms to the SE(3) group (SE(3) simply adds translations in 3D space to the SO(3) group). The authors also briefly mention their excitement for applying Spherical CNNs to omnidirectional vision such as in drones and autonomous cars. They state that there is very little publicly available omnidirectional image data which could be why they did not conduct any experiments in this area.<br />
<br />
= Commentary =<br />
The reviews of Spherical CNNs are very positive, and the paper is ranked in the top 1% of papers submitted to ICLR 2018. Positive points are the novelty of the architecture, the wide variety of experiments performed, and the writing. One critique of the original submission is that the related works section only lists previous methods instead of describing them, and that a description of the methods would have provided more clarity. The authors have since expanded the section; however, I found that it is still limited, which the authors attribute to length limitations. Another critique is that the evaluation does not provide enough depth. For example, it would have been great to see an example of omnidirectional vision for spherical networks. However, this is to be expected, as this paper only introduces spherical CNNs and more work is sure to come.<br />
<br />
= Source Code =<br />
Source code is available at:<br />
https://github.com/jonas-koehler/s2cnn<br />
<br />
= Sources =<br />
* T. Cohen et al. Spherical CNNs, 2018.<br />
* J. Feldman. Haar Measure. http://www.math.ubc.ca/~feldman/m606/haar.pdf<br />
* P. Kostelec, D. Rockmore. FFTs on the Rotation Group, 2008.</div>X249wanghttp://wiki.math.uwaterloo.ca/statwiki/index.php?title=Spherical_CNNs&diff=35148Spherical CNNs2018-03-22T05:37:20Z<p>X249wang: Undo revision 35145 by X249wang (talk)</p>
<hr />
<div>= Introduction =<br />
Convolutional Neural Networks (CNNs), or network architectures involving CNNs, are the current state of the art for learning 2D image processing tasks such as semantic segmentation and object detection. CNNs work well in large part due to the property of being translationally equivariant. This property allows a network trained to detect a certain type of object to still detect the object even if it is translated to another position in the image. However, this does not correspond well to spherical signals since projecting a spherical signal onto a plane will result in distortions, as demonstrated in Figure 1. There are many different types of spherical projections onto a 2D plane, as most people know from the various types of world maps, none of which provide all the necessary properties for rotation-invariant learning.<br />
<br />
[[File:paper26-fig1.png|center]]<br />
<br />
= Notation =<br />
Below are listed several important terms:<br />
* '''Unit Sphere''' <math>S^2</math> is defined as the set of points in <math>\mathbb{R}^3</math> at distance 1 from the origin. The unit sphere can be parameterized by the spherical coordinates <math>\alpha \in [0, 2\pi]</math> and <math>\beta \in [0, \pi]</math>. This is a two-dimensional manifold with respect to <math>\alpha</math> and <math>\beta</math>.<br />
* '''<math>S^2</math> Sphere''' The two-dimensional surface of a ball in 3D space.<br />
* '''Spherical Signals''' In the paper, spherical images and filters are modeled as continuous functions <math>f : S^2 \to \mathbb{R}^K</math>, where K is the number of channels. Just as RGB images have 3 channels, a spherical signal can have several channels describing the data. Examples of the channels used can be found in the experiments section.<br />
* '''Rotations - SO(3)''' The group of 3D rotations on an <math>S^2</math> sphere. Sometimes called the "special orthogonal group". In this paper the ZYZ-Euler parameterization is used to represent SO(3) rotations with <math>\alpha, \beta</math>, and <math>\gamma</math>. Any rotation can be broken down into first a rotation (<math>\alpha</math>) about the Z-axis, then a rotation (<math>\beta</math>) about the new Y-axis (Y'), followed by a rotation (<math>\gamma</math>) about the new Z axis (Z"). [In the rest of this paper, to integrate functions on SO(3), the authors use a rotationally invariant probability measure on the Borel subsets of SO(3). This measure is an example of a Haar measure. Haar measures generalize the idea of rotationally invariant probability measures to general topological groups. For more on Haar measures, see (Feldman 2002) ]<br />
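<br />
The ZYZ-Euler parameterization above can be made concrete with a small NumPy sketch. This is an illustration only, not code from the paper: the helper names <code>rot_z</code>, <code>rot_y</code> and <code>rot_zyz</code> are hypothetical, and the sketch assumes the usual convention that the intrinsic Z-Y'-Z" sequence corresponds to the matrix product <math>R_z(\alpha) R_y(\beta) R_z(\gamma)</math>.<br />

```python
import numpy as np

def rot_z(angle):
    """Rotation by `angle` radians about the Z axis."""
    c, s = np.cos(angle), np.sin(angle)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])

def rot_y(angle):
    """Rotation by `angle` radians about the Y axis."""
    c, s = np.cos(angle), np.sin(angle)
    return np.array([[c, 0.0, s], [0.0, 1.0, 0.0], [-s, 0.0, c]])

def rot_zyz(alpha, beta, gamma):
    """ZYZ-Euler rotation matrix (assumed convention: Rz(alpha) Ry(beta) Rz(gamma))."""
    return rot_z(alpha) @ rot_y(beta) @ rot_z(gamma)

alpha, beta, gamma = 0.3, 1.1, -0.7
R = rot_zyz(alpha, beta, gamma)

# Rotating the north pole ignores gamma and lands on the point of S^2
# with spherical coordinates (alpha, beta).
north_pole = np.array([0.0, 0.0, 1.0])
point = R @ north_pole
```

Under this convention the first two angles pick a point on <math>S^2</math> and the third is a rotation about the axis through that point, which is one way to see why SO(3) needs three parameters while the sphere needs only two.<br />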
<br />
= Related Work =<br />
The related work presented in this paper is very brief, in large part due to the novelty of spherical CNNs and the length of the rest of the paper. The authors enumerate numerous papers which attempt to exploit larger groups of symmetries such as the translational symmetries of CNNs but do not go into specific details for any of these attempts. They do state that all the previous works are limited to discrete groups with the exception of SO(2)-steerable networks.<br />
The authors also mention that previous works exist that analyze spherical images but that these do not have an equivariant architecture. They claim that Spherical CNNs are "the first to achieve equivariance to a continuous, non-commutative group (SO(3))". They also claim to be the first to use the generalized Fourier transform to compute group correlations efficiently.<br />
<br />
= Correlations on the Sphere and Rotation Group =<br />
Spherical correlation is like planar correlation except instead of translation, there is rotation. The definitions for each are provided as follows:<br />
<br />
'''Planar correlation''' The value of the output feature map at translation <math>\small x ∈ Z^2</math> is computed as an inner product between the input feature map and a filter, shifted by <math>\small x</math>.<br />
<br />
'''Spherical correlation''' The value of the output feature map evaluated at rotation <math>\small R ∈ SO(3)</math> is computed as an inner product between the input feature map and a filter, rotated by <math>\small R</math>.<br />
<br />
'''Rotation of Spherical Signals''' The paper introduces the rotation operator <math>L_R</math>. The rotation operator takes a function <math>f</math> and produces the rotated function <math>[L_R f](x) = f(R^{-1}x)</math>, which is what allows the spherical filters to be rotated. With this definition we have the property that <math>L_{RR'} = L_R L_{R'}</math>.<br />
<br />
'''Inner Products''' The inner product of two spherical signals is the integral over the entire sphere of their pointwise product, summed over the channels.<br />
<br />
<math>\langle\psi , f \rangle = \int_{S^2} \sum_{k=1}^K \psi_k (x)f_k (x)dx</math><br />
<br />
Here <math>dx</math> is the rotation-invariant integration measure on the sphere, which is equivalent to <math>d \alpha \sin(\beta) d \beta / 4 \pi </math> in spherical coordinates. This comes from the ZYZ-Euler parameterization where any rotation can be broken down into first a rotation about the Z-axis, then a rotation about the new Y-axis (Y'), followed by a rotation about the new Z axis (Z"). More details on this are given in Appendix A in the paper.<br />
<br />
By this definition, the invariance of the inner product is then guaranteed for any rotation <math>R ∈ SO(3)</math>. In other words, when subjected to rotations, the volume under a spherical heightmap does not change. The following equations show that <math>L_R</math> has a distinct adjoint (<math>L_{R^{-1}}</math>) and that <math>L_R</math> is unitary and thus preserves orthogonality and distances.<br />
<br />
<math>\langle L_R \psi \,, f \rangle = \int_{S^2} \sum_{k=1}^K \psi_k (R^{-1} x)f_k (x)dx</math><br />
<br />
::::<math>= \int_{S^2} \sum_{k=1}^K \psi_k (x)f_k (Rx)dx</math><br />
<br />
::::<math>= \langle \psi , L_{R^{-1}} f \rangle</math><br />
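<br />
The adjoint identity above can be checked numerically. The sketch below is not from the paper: it assumes a single channel (K = 1), low-degree polynomial test signals, and a Gauss-Legendre grid in <math>\cos\beta</math> combined with a uniform grid in <math>\alpha</math>, which integrates such polynomials exactly under the normalized measure <math>d \alpha \sin(\beta) d \beta / 4 \pi</math>. All function names are hypothetical.<br />

```python
import numpy as np

def sphere_quadrature(n_beta=8, n_alpha=16):
    """Quadrature points (N, 3) and weights (N,) for the normalized measure on S^2."""
    t, w = np.polynomial.legendre.leggauss(n_beta)   # nodes in cos(beta)
    alpha = 2.0 * np.pi * np.arange(n_alpha) / n_alpha
    sin_beta = np.sqrt(1.0 - t ** 2)
    x = np.outer(sin_beta, np.cos(alpha))
    y = np.outer(sin_beta, np.sin(alpha))
    z = np.outer(t, np.ones(n_alpha))
    points = np.stack([x, y, z], axis=-1).reshape(-1, 3)
    weights = np.outer(w, np.full(n_alpha, 1.0 / (2.0 * n_alpha))).reshape(-1)
    return points, weights

def inner(psi, f, points, weights):
    """Discrete version of <psi, f> for single-channel signals."""
    return np.sum(weights * psi(points) * f(points))

def rotate(func, R):
    """[L_R func](x) = func(R^{-1} x); for row vectors, R^{-1} x is x @ R."""
    return lambda x: func(x @ R)

rng = np.random.default_rng(0)
Q, _ = np.linalg.qr(rng.normal(size=(3, 3)))
if np.linalg.det(Q) < 0:          # flip one column so that det = +1
    Q[:, 0] = -Q[:, 0]
R = Q                              # a random rotation matrix

a, b, c = rng.normal(size=(3, 3))
psi = lambda x: x @ a                      # degree-1 test filter
f = lambda x: (x @ b) ** 2 + x @ c         # degree-2 test signal

points, weights = sphere_quadrature()
lhs = inner(rotate(psi, R), f, points, weights)       # <L_R psi, f>
rhs = inner(psi, rotate(f, R.T), points, weights)     # <psi, L_{R^-1} f>
norm_before = inner(psi, psi, points, weights)
norm_after = inner(rotate(psi, R), rotate(psi, R), points, weights)
```

The two inner products <code>lhs</code> and <code>rhs</code> agree up to floating-point error, and the norm of <code>psi</code> is unchanged by rotation, which is the unitarity property stated above.<br />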
<br />
'''Spherical Correlation''' With the above knowledge the definition of spherical correlation of two signals <math>f</math> and <math>\psi</math> is:<br />
<br />
<math>[\psi \star f](R) = \langle L_R \psi \,, f \rangle = \int_{S^2} \sum_{k=1}^K \psi_k (R^{-1} x)f_k (x)dx</math><br />
<br />
The output of the above equation is a function on SO(3): each rotation, given by a combination of <math>\alpha , \beta , \gamma </math>, yields a different value of the inner product. The authors make a point of noting that previous work by Driscoll and Healy only ensures circular symmetries about the Z axis, whereas their new formulation ensures symmetry about any rotation.<br />
<br />
'''Rotation of SO(3) Signals''' The first layer of a Spherical CNN takes a function on the sphere (<math>S^2</math>) and outputs a function on SO(3). Therefore, to build a Spherical CNN with more than one layer, there needs to be a way to compute the correlation between two signals on SO(3). The authors therefore generalize the rotation operator (<math>L_R</math>) to act on signals on SO(3). This new definition of <math>L_R</math> is as follows (where <math>R^{-1}Q</math> is a composition of rotations, i.e. multiplication of rotation matrices):<br />
<br />
<math>[L_Rf](Q)=f(R^{-1} Q)</math><br />
<br />
'''Rotation Group Correlation''' The correlation of two signals (<math>f,\psi</math>) on SO(3) with K channels is defined as the following:<br />
<br />
<math>[\psi \star f](R) = \langle L_R \psi , f \rangle = \int_{SO(3)} \sum_{k=1}^K \psi_k (R^{-1} Q)f_k (Q)dQ</math><br />
<br />
where <math>dQ</math> is the rotation-invariant measure on SO(3), which in ZYZ-Euler angles is <math>d \alpha \sin(\beta) d \beta d \gamma / 8 \pi^2 </math>. A complete derivation of this can be found in Appendix A.<br />
<br />
The equivariance for the rotation group correlation is similarly demonstrated.<br />
<br />
= Implementation with GFFT =<br />
The authors leverage the Generalized Fourier Transform (GFT) and Generalized Fast Fourier Transform (GFFT) algorithms to compute the correlations outlined in the previous section. The Fast Fourier Transform (FFT) can compute correlations and convolutions efficiently by means of the Fourier theorem, which states that the Fourier transform of a correlation (or convolution) of two functions can be computed from the product of their Fourier transforms. The Fourier transform can be generalized to <math>S^2</math> and SO(3) and is then called the GFT. The GFT is a linear projection of a function onto orthogonal basis functions, and the resulting coefficients are called Fourier coefficients. The basis functions are the matrix elements of irreducible unitary representations. For <math>S^2</math> the basis functions are the spherical harmonics <math>Y_m^l(x)</math>. For SO(3) these basis functions are called the Wigner D-functions <math>D_{mn}^l(R)</math>. For both sets of functions the indices are restricted to <math>l\geq0</math> and <math>-l \leq m,n \leq l</math>. The Wigner D-functions are also orthogonal, so the Fourier coefficients can be computed by the inner product with the Wigner D-functions (see Appendix C for the complete proof). The Wigner D-functions are complete, which means that any well-behaved function on SO(3) can be expressed as a linear combination of the Wigner D-functions. The GFT of a function on SO(3) is thus:<br />
<br />
<math>\hat{f}^l = \int_X f(x) \overline{D^l(x)}dx</math><br />
<br />
where <math>\hat{f}</math> represents the Fourier coefficients. For <math>S^2</math> we have the same equation but with the basis functions <math>Y^l</math>.<br />
<br />
The inverse SO(3) Fourier transform is:<br />
<br />
<math>f(R)=[\mathcal{F}^{-1} \hat{f}](R) = \sum_{l=0}^b (2l + 1) \sum_{m=-l}^l \sum_{n=-l}^l \hat{f_{mn}^l} D_{mn}^l(R) </math><br />
<br />
The bandwidth b represents the maximum frequency and is related to the resolution of the spatial grid. Kostelec and Rockmore are referenced for more knowledge on this topic.<br />
<br />
The authors give proofs (Appendix D) that the SO(3) correlation satisfies the Fourier theorem and the <math>S^2</math> correlation of spherical signals can be computed by the outer products of the <math>S^2</math>-FTs (Shown in Figure 2).<br />
<br />
[[File:paper26-fig2.png|center]]<br />
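<br />
As a rough illustration of the Fourier theorem that the GFFT relies on, the sketch below checks it in the simplest setting: correlation on a cyclic group of n points, using NumPy's ordinary FFT. This is only a one-dimensional analogue chosen for illustration, not the <math>S^2</math> or SO(3) transform from the paper, and the function names are hypothetical.<br />

```python
import numpy as np

def circular_correlation_direct(psi, f):
    """[psi * f](s) = sum_x psi(x - s) f(x), computed shift by shift in O(n^2)."""
    n = len(f)
    return np.array([np.sum(np.roll(psi, s) * f) for s in range(n)])

def circular_correlation_fft(psi, f):
    """Same correlation in O(n log n): multiply in the Fourier domain."""
    spectrum = np.fft.fft(f) * np.conj(np.fft.fft(psi))
    return np.fft.ifft(spectrum).real

rng = np.random.default_rng(0)
psi = rng.normal(size=16)   # real-valued filter
f = rng.normal(size=16)     # real-valued signal

direct = circular_correlation_direct(psi, f)
via_fft = circular_correlation_fft(psi, f)
max_abs_diff = np.max(np.abs(direct - via_fft))
```

The two results agree up to floating-point error. In the spherical case the complex exponentials are replaced by spherical harmonics and Wigner D-functions, and the pointwise product of spectra is replaced by products of coefficient blocks, as shown in Figure 2.<br />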
<br />
The GFFT algorithm details are taken from Kostelec and Rockmore. The authors claim they have the first automatically differentiable implementation of the GFT for <math>S^2</math> and SO(3). The authors do not provide any run time comparisons for real time applications (they only mention that the FFT can be computed in <math>O(n \log n)</math> time) or any comparisons of training times with and without the GFFT. However, they do provide the source code of their implementation at: https://github.com/jonas-koehler/s2cnn.<br />
<br />
= Experiments =<br />
The authors provide several experiments. The first set of experiments is designed to show the numerical stability and accuracy of the outlined methods. The second group of experiments demonstrates how the algorithms can be applied to current problem domains.<br />
<br />
==Equivariance Error==<br />
In this experiment the authors test empirically whether the equivariance predicted by their theory holds. Equivariance was proven only for the continuous case, so the authors had doubts about whether it would hold in practice because of potential discretization artifacts. If equivariance did not hold, the weight sharing scheme would become less effective. The experiment first tests the equivariance of the SO(3) correlation at different resolutions. 500 random rotations <math>R_i</math> and feature maps <math>f_i</math> (with 10 channels) are sampled. They then calculate the approximation error <math>\small\Delta = \dfrac{1}{n} \sum_{i=1}^n std(L_{R_i} \Phi(f_i) - \Phi(L_{R_i} f_i))/std(\Phi(f_i))</math><br />
Note: The authors do not define the std function; it is most likely the standard deviation, as 'std' is the command for standard deviation in MATLAB.<br />
<math>\Phi</math> is a composition of SO(3) correlation layers with randomly initialized filters. The authors expected <math>\Delta</math> to be zero in the case of perfect equivariance because, as proven earlier, the two terms are equal in the continuous case: <math>\small L_{R_i} \Phi(f_i) = \Phi(L_{R_i} f_i)</math>. The results are shown in Figure 3. <br />
<br />
[[File:paper26-fig3.png|center]]<br />
<br />
<math>\Delta</math> only grows with the resolution and the number of layers when there is no activation function. With ReLU activation the error stays constant beyond the very lowest resolutions. The authors conclude that the error must therefore come from the feature map rotation, since this rotation is exact only for bandlimited functions.<br />
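<br />
The error measure <math>\Delta</math> is easy to compute once a network and a rotation operator are available. The sketch below is a toy analogue rather than the authors' experiment: it assumes signals on a cyclic grid, uses cyclic shifts in place of SO(3) rotations, and takes 'std' to be the standard deviation. The names <code>equivariance_error</code>, <code>phi_equivariant</code> and <code>phi_broken</code> are hypothetical.<br />

```python
import numpy as np

def equivariance_error(phi, rotate, signals, rotations):
    """Mean over samples of std(L_R phi(f) - phi(L_R f)) / std(phi(f))."""
    ratios = []
    for f, r in zip(signals, rotations):
        out = phi(f)
        diff = rotate(out, r) - phi(rotate(f, r))
        ratios.append(np.std(diff) / np.std(out))
    return float(np.mean(ratios))

n = 32
rng = np.random.default_rng(0)
filt = rng.normal(size=n)          # randomly initialized filter
ramp = np.linspace(0.0, 1.0, n)    # position-dependent gain that breaks equivariance

def phi_equivariant(f):
    """Circular correlation followed by ReLU; commutes with cyclic shifts."""
    corr = np.fft.ifft(np.fft.fft(f) * np.conj(np.fft.fft(filt))).real
    return np.maximum(corr, 0.0)

def phi_broken(f):
    """Same layer with a position-dependent gain, so it is not shift equivariant."""
    return phi_equivariant(f) * ramp

signals = rng.normal(size=(500, n))
shifts = rng.integers(1, n, size=500)      # nonzero cyclic shifts

err_equivariant = equivariance_error(phi_equivariant, np.roll, signals, shifts)
err_broken = equivariance_error(phi_broken, np.roll, signals, shifts)
```

In this toy setting the shift is exact on the grid, so the equivariant layer gives an error at the level of floating-point round-off, while the deliberately broken layer gives a clearly nonzero error. On the sphere the rotation of a sampled feature map is itself approximate, which is the source of the nonzero <math>\Delta</math> discussed above.<br />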
<br />
==MNIST Data==<br />
The experiment using MNIST data was created by projecting MNIST digits onto a sphere using stereographic projection to create the resulting images as seen in Figure 4.<br />
<br />
[[File:paper26-fig4.png|center]]<br />
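<br />
For readers unfamiliar with stereographic projection, the sketch below shows one common form of it: the inverse projection from the plane onto the unit sphere, taken from the north pole. The summary does not state the exact grid, scaling or pole the authors used, so the 28x28 grid on <math>[-1, 1]^2</math> and the function names here are assumptions made for illustration.<br />

```python
import numpy as np

def plane_to_sphere(u, v):
    """Inverse stereographic projection (from the north pole) of plane points (u, v)."""
    r2 = u ** 2 + v ** 2
    return np.stack([2 * u / (r2 + 1), 2 * v / (r2 + 1), (r2 - 1) / (r2 + 1)], axis=-1)

def sphere_to_plane(p):
    """Stereographic projection of sphere points back to the plane (undefined at the pole)."""
    x, y, z = p[..., 0], p[..., 1], p[..., 2]
    return x / (1 - z), y / (1 - z)

# Assumed layout: pixel centres of a 28x28 image placed on the square [-1, 1]^2.
side = 28
coords = np.linspace(-1.0, 1.0, side)
u, v = np.meshgrid(coords, coords)
points = plane_to_sphere(u, v)          # shape (28, 28, 3), one sphere point per pixel
u_back, v_back = sphere_to_plane(points)
```

Each pixel value can then be attached to its point on the sphere, and the rotated dataset is obtained by applying a random rotation to those points before resampling onto the spherical grid.<br />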
<br />
The authors created two datasets, one with the projected digits and the other with the same projected digits which were then subjected to a random rotation. The spherical CNN architecture used was <math>\small S^2</math>conv-ReLU-SO(3)conv-ReLU-FC-softmax and was attempted with bandwidths of 30,10,6 and 20,40,10 channels for each layer respectively. This model was compared to a baseline CNN with layers conv-ReLU-conv-ReLU-FC-softmax with 5x5 filters, 32,64,10 channels and stride of 3. For comparison this leads to approximately 68K parameters for the baseline and 58K parameters for the spherical CNN. Results can be seen in Table 1. It is clear from the results that the spherical CNN architecture made the network rotationally invariant. Performance on the rotated set is almost identical to the non-rotated set. This is true even when trained on the non-rotated set and tested on the rotated set. Compare this to the non-spherical architecture which becomes unusable when rotating the digits.<br />
<br />
[[File:paper26-tab1.png|center]]<br />
<br />
==SHREC17==<br />
The SHREC dataset contains 3D models from the ShapeNet dataset which are classified into categories. It consists of a regularly aligned dataset and a rotated dataset. The models from the SHREC17 dataset were projected onto a sphere by means of raycasting. Different properties of the objects obtained from the raycast of the original model and the convex hull of the model make up the different channels which are input into the spherical CNN.<br />
<br />
<br />
[[File:paper26-fig5.png|center]]<br />
<br />
<br />
The network architecture used is an initial <math>\small S^2</math>conv-BN-ReLU block which is followed by two SO(3)conv-BN-ReLU blocks. The output is then fed into a MaxPool-BN block then a linear layer to the output for final classification. The architecture for this experiment has ~1.4M parameters, far exceeding the scale of the spherical CNNs in the other experiments.<br />
<br />
This architecture achieves results competitive with the state of the art on the SHREC17 tasks. The model would place 2nd or 3rd in all categories, but it was not submitted because the SHREC17 task is closed. Table 2 shows the comparison of results with the top 3 submissions in each category. In the table, P@N stands for precision, R@N stands for recall, F1@N stands for F-score, mAP stands for mean average precision, and NDCG stands for normalized discounted cumulative gain in relevance based on whether the category and subcategory labels are predicted correctly. The authors argue that the results give empirical evidence of the usefulness of spherical CNNs. They add that this is notable because most architectures in the SHREC17 competition are highly specialized, whereas their model is fairly general.<br />
<br />
<br />
[[File:paper26-tab2.png|center]]<br />
<br />
==Molecular Atomization==<br />
In this experiment a spherical CNN is implemented with an architecture resembling that of ResNet. They use the QM7 dataset which has the task of predicting atomization energy of molecules. The positions and charges given in the dataset are projected onto the sphere using potential functions. A summary of their results is shown in Table 3 along with some of the spherical CNN architecture details. It shows the different RMSE obtained from different methods. The results from this final experiment also seem to be promising as the network the authors present achieves the second best score. They also note that the first place method grows exponentially with the number of atoms per molecule so is unlikely to scale well.<br />
<br />
[[File:paper26-tab3.png|center]]<br />
<br />
= Conclusions =<br />
This paper presents a novel architecture called Spherical CNNs. The paper defines <math>\small S^2</math> and SO(3) cross correlations, shows the theory behind their rotational equivariance for continuous functions, and demonstrates empirically that the equivariance also holds approximately in the discrete case. An efficient GFFT algorithm was implemented and evaluated on two very different datasets with close to state of the art results, demonstrating that there are practical applications for Spherical CNNs.<br />
<br />
For future work the authors believe that improvements can be obtained by generalizing the algorithms to the SE(3) group (SE(3) simply adds translations in 3D space to the SO(3) group). The authors also briefly mention their excitement for applying Spherical CNNs to omnidirectional vision such as in drones and autonomous cars. They state that there is very little publicly available omnidirectional image data which could be why they did not conduct any experiments in this area.<br />
<br />
= Commentary =<br />
The reviews of Spherical CNNs are very positive, and the paper is ranked in the top 1% of papers submitted to ICLR 2018. Positive points are the novelty of the architecture, the wide variety of experiments performed, and the writing. One critique of the original submission is that the related works section only lists previous methods instead of describing them, and that a description of the methods would have provided more clarity. The authors have since expanded the section; however, I found that it is still limited, which the authors attribute to length limitations. Another critique is that the evaluation does not provide enough depth. For example, it would have been great to see an example of omnidirectional vision for spherical networks. However, this is to be expected, as this paper only introduces spherical CNNs and more work is sure to come.<br />
<br />
= Source Code =<br />
Source code is available at:<br />
https://github.com/jonas-koehler/s2cnn<br />
<br />
= Sources =<br />
* T. Cohen et al. Spherical CNNs, 2018.<br />
* J. Feldman. Haar Measure. http://www.math.ubc.ca/~feldman/m606/haar.pdf<br />
* P. Kostelec, D. Rockmore. FFTs on the Rotation Group, 2008.</div>X249wanghttp://wiki.math.uwaterloo.ca/statwiki/index.php?title=Spherical_CNNs&diff=35145Spherical CNNs2018-03-22T05:33:13Z<p>X249wang: /* Experiments */</p>
<hr />
<div>= Introduction =<br />
Convolutional Neural Networks (CNNs), or network architectures involving CNNs, are the current state of the art for learning 2D image processing tasks such as semantic segmentation and object detection. CNNs work well in large part due to the property of being translationally equivariant. This property allows a network trained to detect a certain type of object to still detect the object even if it is translated to another position in the image. However, this does not correspond well to spherical signals since projecting a spherical signal onto a plane will result in distortions, as demonstrated in Figure 1. There are many different types of spherical projections onto a 2D plane, as most people know from the various types of world maps, none of which provide all the necessary properties for rotation-invariant learning.<br />
<br />
[[File:paper26-fig1.png|center]]<br />
<br />
= Notation =<br />
Below are listed several important terms:<br />
* '''Unit Sphere''' <math>S^2</math> is defined as the set of points in <math>\mathbb{R}^3</math> at distance 1 from the origin. The unit sphere can be parameterized by the spherical coordinates <math>\alpha \in [0, 2\pi]</math> and <math>\beta \in [0, \pi]</math>. This is a two-dimensional manifold with respect to <math>\alpha</math> and <math>\beta</math>.<br />
* '''<math>S^2</math> Sphere''' The two-dimensional surface of a ball in 3D space.<br />
* '''Spherical Signals''' In the paper, spherical images and filters are modeled as continuous functions <math>f : S^2 \to \mathbb{R}^K</math>, where K is the number of channels. Just as RGB images have 3 channels, a spherical signal can have several channels describing the data. Examples of the channels used can be found in the experiments section.<br />
* '''Rotations - SO(3)''' The group of 3D rotations on an <math>S^2</math> sphere. Sometimes called the "special orthogonal group". In this paper the ZYZ-Euler parameterization is used to represent SO(3) rotations with <math>\alpha, \beta</math>, and <math>\gamma</math>. Any rotation can be broken down into first a rotation (<math>\alpha</math>) about the Z-axis, then a rotation (<math>\beta</math>) about the new Y-axis (Y'), followed by a rotation (<math>\gamma</math>) about the new Z axis (Z"). [In the rest of this paper, to integrate functions on SO(3), the authors use a rotationally invariant probability measure on the Borel subsets of SO(3). This measure is an example of a Haar measure. Haar measures generalize the idea of rotationally invariant probability measures to general topological groups. For more on Haar measures, see (Feldman 2002) ]<br />
<br />
= Related Work =<br />
The related work presented in this paper is very brief, in large part due to the novelty of spherical CNNs and the length of the rest of the paper. The authors enumerate numerous papers which attempt to exploit larger groups of symmetries such as the translational symmetries of CNNs but do not go into specific details for any of these attempts. They do state that all the previous works are limited to discrete groups with the exception of SO(2)-steerable networks.<br />
The authors also mention that previous works exist that analyze spherical images but that these do not have an equivariant architecture. They claim that Spherical CNNs are "the first to achieve equivariance to a continuous, non-commutative group (SO(3))". They also claim to be the first to use the generalized Fourier transform to compute group correlations efficiently.<br />
<br />
= Correlations on the Sphere and Rotation Group =<br />
Spherical correlation is like planar correlation except instead of translation, there is rotation. The definitions for each are provided as follows:<br />
<br />
'''Planar correlation''' The value of the output feature map at translation <math>\small x ∈ Z^2</math> is computed as an inner product between the input feature map and a filter, shifted by <math>\small x</math>.<br />
<br />
'''Spherical correlation''' The value of the output feature map evaluated at rotation <math>\small R ∈ SO(3)</math> is computed as an inner product between the input feature map and a filter, rotated by <math>\small R</math>.<br />
<br />
'''Rotation of Spherical Signals''' The paper introduces the rotation operator <math>L_R</math>. The rotation operator takes a function <math>f</math> and produces the rotated function <math>[L_R f](x) = f(R^{-1}x)</math>, which is what allows the spherical filters to be rotated. With this definition we have the property that <math>L_{RR'} = L_R L_{R'}</math>.<br />
<br />
'''Inner Products''' The inner product of two spherical signals is the integral over the entire sphere of their pointwise product, summed over the channels.<br />
<br />
<math>\langle\psi , f \rangle = \int_{S^2} \sum_{k=1}^K \psi_k (x)f_k (x)dx</math><br />
<br />
Here <math>dx</math> is the rotation-invariant integration measure on the sphere, which is equivalent to <math>d \alpha \sin(\beta) d \beta / 4 \pi </math> in spherical coordinates. This comes from the ZYZ-Euler parameterization where any rotation can be broken down into first a rotation about the Z-axis, then a rotation about the new Y-axis (Y'), followed by a rotation about the new Z axis (Z"). More details on this are given in Appendix A in the paper.<br />
<br />
By this definition, the invariance of the inner product is then guaranteed for any rotation <math>R ∈ SO(3)</math>. In other words, when subjected to rotations, the volume under a spherical heightmap does not change. The following equations show that <math>L_R</math> has a distinct adjoint (<math>L_{R^{-1}}</math>) and that <math>L_R</math> is unitary and thus preserves orthogonality and distances.<br />
<br />
<math>\langle L_R \psi \,, f \rangle = \int_{S^2} \sum_{k=1}^K \psi_k (R^{-1} x)f_k (x)dx</math><br />
<br />
::::<math>= \int_{S^2} \sum_{k=1}^K \psi_k (x)f_k (Rx)dx</math><br />
<br />
::::<math>= \langle \psi , L_{R^{-1}} f \rangle</math><br />
<br />
'''Spherical Correlation''' With the above knowledge the definition of spherical correlation of two signals <math>f</math> and <math>\psi</math> is:<br />
<br />
<math>[\psi \star f](R) = \langle L_R \psi \,, f \rangle = \int_{S^2} \sum_{k=1}^K \psi_k (R^{-1} x)f_k (x)dx</math><br />
<br />
The output of the above equation is a function on SO(3): each rotation, given by a combination of <math>\alpha , \beta , \gamma </math>, yields a different value of the inner product. The authors make a point of noting that previous work by Driscoll and Healy only ensures circular symmetries about the Z axis, whereas their new formulation ensures symmetry about any rotation.<br />
<br />
'''Rotation of SO(3) Signals''' The first layer of a Spherical CNN takes a function on the sphere (<math>S^2</math>) and outputs a function on SO(3). Therefore, to build a Spherical CNN with more than one layer, there needs to be a way to compute the correlation between two signals on SO(3). The authors therefore generalize the rotation operator (<math>L_R</math>) to act on signals on SO(3). This new definition of <math>L_R</math> is as follows (where <math>R^{-1}Q</math> is a composition of rotations, i.e. multiplication of rotation matrices):<br />
<br />
<math>[L_Rf](Q)=f(R^{-1} Q)</math><br />
<br />
'''Rotation Group Correlation''' The correlation of two signals (<math>f,\psi</math>) on SO(3) with K channels is defined as the following:<br />
<br />
<math>[\psi \star f](R) = \langle L_R \psi , f \rangle = \int_{SO(3)} \sum_{k=1}^K \psi_k (R^{-1} Q)f_k (Q)dQ</math><br />
<br />
where <math>dQ</math> is the rotation-invariant measure on SO(3), which in ZYZ-Euler angles is <math>d \alpha \sin(\beta) d \beta d \gamma / 8 \pi^2 </math>. A complete derivation of this can be found in Appendix A.<br />
<br />
The equivariance for the rotation group correlation is similarly demonstrated.<br />
<br />
= Implementation with GFFT =<br />
The authors leverage the Generalized Fourier Transform (GFT) and Generalized Fast Fourier Transform (GFFT) algorithms to compute the correlations outlined in the previous section. The Fast Fourier Transform (FFT) can compute correlations and convolutions efficiently by means of the Fourier theorem, which states that the Fourier transform of a correlation (or convolution) of two functions can be computed from the product of their Fourier transforms. The Fourier transform can be generalized to <math>S^2</math> and SO(3) and is then called the GFT. The GFT is a linear projection of a function onto orthogonal basis functions, and the resulting coefficients are called Fourier coefficients. The basis functions are the matrix elements of irreducible unitary representations. For <math>S^2</math> the basis functions are the spherical harmonics <math>Y_m^l(x)</math>. For SO(3) these basis functions are called the Wigner D-functions <math>D_{mn}^l(R)</math>. For both sets of functions the indices are restricted to <math>l\geq0</math> and <math>-l \leq m,n \leq l</math>. The Wigner D-functions are also orthogonal, so the Fourier coefficients can be computed by the inner product with the Wigner D-functions (see Appendix C for the complete proof). The Wigner D-functions are complete, which means that any well-behaved function on SO(3) can be expressed as a linear combination of the Wigner D-functions. The GFT of a function on SO(3) is thus:<br />
<br />
<math>\hat{f}^l = \int_X f(x) \overline{D^l(x)}dx</math><br />
<br />
where <math>\hat{f}</math> represents the Fourier coefficients. For <math>S^2</math> we have the same equation but with the basis functions <math>Y^l</math>.<br />
<br />
The inverse SO(3) Fourier transform is:<br />
<br />
<math>f(R)=[\mathcal{F}^{-1} \hat{f}](R) = \sum_{l=0}^b (2l + 1) \sum_{m=-l}^l \sum_{n=-l}^l \hat{f_{mn}^l} D_{mn}^l(R) </math><br />
<br />
The bandwidth b represents the maximum frequency and is related to the resolution of the spatial grid. Kostelec and Rockmore are referenced for more knowledge on this topic.<br />
<br />
The authors give proofs (Appendix D) that the SO(3) correlation satisfies the Fourier theorem and the <math>S^2</math> correlation of spherical signals can be computed by the outer products of the <math>S^2</math>-FTs (Shown in Figure 2).<br />
<br />
[[File:paper26-fig2.png|center]]<br />
<br />
The GFFT algorithm details are taken from Kostelec and Rockmore. The authors claim they have the first automatically differentiable implementation of the GFT for <math>S^2</math> and SO(3). The authors do not provide any run time comparisons for real time applications (they just mention that the FFT can be computed in <math>O(n \log n)</math> time) or any comparisons of training times with and without the GFFT. However, they do provide the source code of their implementation at: https://github.com/jonas-koehler/s2cnn.<br />
<br />
= Experiments =<br />
The authors provide several experiments. The first set of experiments are designed to show the numerical stability and accuracy of the outlined methods. The second group of experiments (MNIST, SHREC17 and Molecular Atomization) demonstrates how the algorithms can be applied to current problem domains.<br />
<br />
==Equivariance Error==<br />
In this experiment the authors test empirically whether their theory of equivariance holds. Equivariance was proven only for the continuous case, so they were concerned that discretization artifacts could break it in practice, which would make the weight sharing scheme less effective. The experiment first tests the equivariance of the SO(3) correlation at different resolutions. 500 random rotations and feature maps (with 10 channels) are sampled. They then calculate the approximation error <math>\small\Delta = \dfrac{1}{n} \sum_{i=1}^n \mathrm{std}(L_{R_i} \Phi(f_i) - \Phi(L_{R_i} f_i))/\mathrm{std}(\Phi(f_i))</math><br />
Note: The authors do not define the std function; it is most likely the standard deviation, as 'std' is the command for standard deviation in MATLAB.<br />
<math>\Phi</math> is a composition of SO(3) correlation layers with randomly initialized filters. The authors expect <math>\Delta</math> to be zero in the case of perfect equivariance because, as proven earlier, the two terms in the difference <math>\small L_{R_i} \Phi(f_i) - \Phi(L_{R_i} f_i)</math> are equal in the continuous case. The results are shown in Figure 3. <br />
<br />
[[File:paper26-fig3.png|center]]<br />
<br />
<math>\Delta</math> only grows with the resolution and the number of layers when there is no activation function. With ReLU activations the error stays roughly constant once the resolution is slightly above 0. The authors indicate that the error must therefore come from the feature map rotation, since this rotation is exact only for bandlimited functions.<br />
<br />
==MNIST Data==<br />
The experiment using MNIST data was created by projecting MNIST digits onto a sphere using stereographic projection to create the resulting images as seen in Figure 4.<br />
<br />
[[File:paper26-fig4.png|center]]<br />
<br />
The authors created two datasets, one with the projected digits and the other with the same projected digits which were then subjected to a random rotation. The spherical CNN architecture used was <math>\small S^2</math>conv-ReLU-SO(3)conv-ReLU-FC-softmax and was attempted with bandwidths of 30,10,6 and 20,40,10 channels for each layer respectively. This model was compared to a baseline CNN with layers conv-ReLU-conv-ReLU-FC-softmax with 5x5 filters, 32,64,10 channels and stride of 3. For comparison this leads to approximately 68K parameters for the baseline and 58K parameters for the spherical CNN. Results can be seen in Table 1. It is clear from the results that the spherical CNN architecture made the network rotationally invariant. Performance on the rotated set is almost identical to the non-rotated set. This is true even when trained on the non-rotated set and tested on the rotated set. Compare this to the non-spherical architecture which becomes unusable when rotating the digits.<br />
<br />
[[File:paper26-tab1.png|center]]<br />
<br />
==SHREC17==<br />
The SHREC dataset contains 3D models from the ShapeNet dataset which are classified into categories. It consists of a regularly aligned dataset and a rotated dataset. The models from the SHREC17 dataset were projected onto a sphere by means of raycasting. Different properties of the objects obtained from the raycast of the original model and the convex hull of the model make up the different channels which are input into the spherical CNN.<br />
<br />
<br />
[[File:paper26-fig5.png|center]]<br />
<br />
<br />
The network architecture used is an initial <math>\small S^2</math>conv-BN-ReLU block which is followed by two SO(3)conv-BN-ReLU blocks. The output is then fed into a MaxPool-BN block then a linear layer to the output for final classification. The architecture for this experiment has ~1.4M parameters, far exceeding the scale of the spherical CNNs in the other experiments.<br />
<br />
This architecture achieves results close to the state of the art on the SHREC17 tasks. The model would place 2nd or 3rd in all categories, but it was not submitted because the SHREC17 task is closed. Table 2 shows the comparison of results with the top 3 submissions in each category. In the table, P@N stands for precision, R@N stands for recall, F1@N stands for F-score, mAP stands for mean average precision, and NDCG stands for normalized discounted cumulative gain in relevance based on whether the category and subcategory labels are predicted correctly. The authors claim the results are empirical evidence of the usefulness of spherical CNNs. They elaborate that this is notable because most architectures in the SHREC17 competition are highly specialized, whereas their model is fairly general.<br />
<br />
<br />
[[File:paper26-tab2.png|center]]<br />
<br />
==Molecular Atomization==<br />
In this experiment a spherical CNN is implemented with an architecture resembling that of ResNet. They use the QM7 dataset which has the task of predicting atomization energy of molecules. The positions and charges given in the dataset are projected onto the sphere using potential functions. A summary of their results is shown in Table 3 along with some of the spherical CNN architecture details. It shows the different RMSE obtained from different methods. The results from this final experiment also seem to be promising as the network the authors present achieves the second best score. They also note that the first place method grows exponentially with the number of atoms per molecule so is unlikely to scale well.<br />
<br />
[[File:paper26-tab3.png|center]]<br />
<br />
= Conclusions =<br />
This paper presents a novel architecture called Spherical CNNs. The paper defines <math>\small S^2</math> and SO(3) cross correlations, shows the theory behind their rotational equivariance for continuous functions, and demonstrates empirically that the equivariance also holds approximately in the discrete case. An effective GFFT algorithm was implemented and evaluated on two very different datasets with close to state of the art results, demonstrating that there are practical applications to Spherical CNNs.<br />
<br />
For future work the authors believe that improvements can be obtained by generalizing the algorithms to the SE(3) group (SE(3) simply adds translations in 3D space to the SO(3) group). The authors also briefly mention their excitement for applying Spherical CNNs to omnidirectional vision such as in drones and autonomous cars. They state that there is very little publicly available omnidirectional image data which could be why they did not conduct any experiments in this area.<br />
<br />
= Commentary =<br />
The reviews on Spherical CNNs are very positive and it is ranked in the top 1% of papers submitted to ICLR 2018. Positive points are the novelty of the architecture, the wide variety of experiments performed, and the writing. One critique of the original submission is that the related works section only lists previous methods instead of describing them, and that a description of the methods would have provided more clarity. The authors have since expanded the section; however, I found that it is still limited, which the authors attribute to length limitations. Another critique is that the evaluation does not provide enough depth. For example, it would have been great to see an example of omnidirectional vision for spherical networks. However, this is to be expected as it is just the introduction of spherical CNNs and more work is sure to come.<br />
<br />
= Source Code =<br />
Source code is available at:<br />
https://github.com/jonas-koehler/s2cnn<br />
<br />
= Sources =<br />
* T. Cohen et al. Spherical CNNs, 2018.<br />
* J. Feldman. Haar Measure. http://www.math.ubc.ca/~feldman/m606/haar.pdf<br />
* P. Kostelec, D. Rockmore. FFTs on the Rotation Group, 2008.</div>X249wanghttp://wiki.math.uwaterloo.ca/statwiki/index.php?title=Spherical_CNNs&diff=35139Spherical CNNs2018-03-22T04:57:16Z<p>X249wang: /* SHREC17 */</p>
<hr />
<div>= Introduction =<br />
Convolutional Neural Networks (CNNs), or network architectures involving CNNs, are the current state of the art for learning 2D image processing tasks such as semantic segmentation and object detection. CNNs work well in large part due to the property of being translationally equivariant. This property allows a network trained to detect a certain type of object to still detect the object even if it is translated to another position in the image. However, this does not correspond well to spherical signals since projecting a spherical signal onto a plane will result in distortions, as demonstrated in Figure 1. There are many different types of spherical projections onto a 2D plane, as most people know from the various types of world maps, none of which provide all the necessary properties for rotation-invariant learning.<br />
<br />
[[File:paper26-fig1.png|center]]<br />
<br />
= Notation =<br />
Below are listed several important terms:<br />
* '''Unit Sphere''' <math>S^2</math> is defined as the set of points at distance 1 from the origin. The unit sphere can be parameterized by the spherical coordinates <math>\alpha \in [0, 2\pi]</math> and <math>\beta \in [0, \pi]</math>. This is a two-dimensional manifold with respect to <math>\alpha</math> and <math>\beta</math>.<br />
* '''<math>S^2</math> Sphere''' The two-dimensional surface of a ball in three-dimensional space.<br />
* '''Spherical Signals''' In the paper spherical images and filters are modeled as continuous functions <math>f : S^2 \to \mathbb{R}^K</math>, where K is the number of channels. Just as RGB images have 3 channels, a spherical signal can have numerous channels describing the data. Examples of channels which were used can be found in the experiments section.<br />
* '''Rotations - SO(3)''' The group of 3D rotations on an <math>S^2</math> sphere. Sometimes called the "special orthogonal group". In this paper the ZYZ-Euler parameterization is used to represent SO(3) rotations with <math>\alpha, \beta</math>, and <math>\gamma</math>. Any rotation can be broken down into first a rotation (<math>\alpha</math>) about the Z-axis, then a rotation (<math>\beta</math>) about the new Y-axis (Y'), followed by a rotation (<math>\gamma</math>) about the new Z axis (Z"). [In the rest of this paper, to integrate functions on SO(3), the authors use a rotationally invariant probability measure on the Borel subsets of SO(3). This measure is an example of a Haar measure. Haar measures generalize the idea of rotationally invariant probability measures to general topological groups. For more on Haar measures, see (Feldman 2002) ]<br />
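<br />
The ZYZ-Euler parameterization above can be sketched numerically. This is a minimal illustration and not the authors' code: the helper names (<code>rot_z</code>, <code>rot_y</code>, <code>zyz_to_matrix</code>, <code>sphere_point</code>) are hypothetical, and it assumes the usual right-handed convention in which the intrinsic Z-Y'-Z" sequence equals the matrix product of a Z rotation, a Y rotation and a Z rotation.<br />

```python
import numpy as np

def rot_z(angle):
    """Rotation by `angle` radians about the Z axis."""
    c, s = np.cos(angle), np.sin(angle)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])

def rot_y(angle):
    """Rotation by `angle` radians about the Y axis."""
    c, s = np.cos(angle), np.sin(angle)
    return np.array([[c, 0.0, s], [0.0, 1.0, 0.0], [-s, 0.0, c]])

def zyz_to_matrix(alpha, beta, gamma):
    """ZYZ-Euler angles to a 3x3 rotation matrix.

    Intrinsic rotations compose by right-multiplication, so rotating about
    Z, then the new Y', then the new Z'' gives Rz(alpha) Ry(beta) Rz(gamma).
    """
    return rot_z(alpha) @ rot_y(beta) @ rot_z(gamma)

def sphere_point(alpha, beta):
    """Point on the unit sphere with spherical coordinates (alpha, beta)."""
    return np.array([np.cos(alpha) * np.sin(beta),
                     np.sin(alpha) * np.sin(beta),
                     np.cos(beta)])

alpha, beta, gamma = 0.3, 1.1, -0.7
R = zyz_to_matrix(alpha, beta, gamma)
north_pole = np.array([0.0, 0.0, 1.0])
```

In this convention the rotation <math>\gamma</math> about Z is applied to the pole first and leaves it fixed, so <code>R @ north_pole</code> lands on the point with spherical coordinates <math>(\alpha, \beta)</math>. This is one way to see how a point on <math>S^2</math> relates to a rotation in SO(3).<br />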
<br />
= Related Work =<br />
The related work presented in this paper is very brief, in large part due to the novelty of spherical CNNs and the length of the rest of the paper. The authors enumerate numerous papers which attempt to exploit larger groups of symmetries such as the translational symmetries of CNNs but do not go into specific details for any of these attempts. They do state that all the previous works are limited to discrete groups with the exception of SO(2)-steerable networks.<br />
The authors also mention that previous works exist that analyze spherical images but that these do not have an equivariant architecture. They claim that Spherical CNNs are "the first to achieve equivariance to a continuous, non-commutative group (SO(3))". They also claim to be the first to use the generalized Fourier transform for fast computation of group correlation.<br />
<br />
= Correlations on the Sphere and Rotation Group =<br />
Spherical correlation is like planar correlation except instead of translation, there is rotation. The definitions for each are provided as follows:<br />
<br />
'''Planar correlation''' The value of the output feature map at translation <math>\small x ∈ Z^2</math> is computed as an inner product between the input feature map and a filter, shifted by <math>\small x</math>.<br />
<br />
'''Spherical correlation''' The value of the output feature map evaluated at rotation <math>\small R ∈ SO(3)</math> is computed as an inner product between the input feature map and a filter, rotated by <math>\small R</math>.<br />
<br />
'''Rotation of Spherical Signals''' The paper introduces the rotation operator <math>L_R</math>. The rotation operator takes a function <math>f(x)</math> and produces the rotated function <math>[L_R f](x) = f(R^{-1}x)</math>, which allows us to rotate the spherical filters. With this definition we have the property that <math>L_{RR'} = L_R L_{R'}</math>.<br />
<br />
'''Inner Products''' The inner product of spherical signals is the integral over the entire sphere of the channel-wise product of the two signals, summed over channels.<br />
<br />
<math>\langle\psi , f \rangle = \int_{S^2} \sum_{k=1}^K \psi_k (x)f_k (x)dx</math><br />
<br />
<math>dx</math> here is invariant under SO(3) rotations and is equivalent to <math>d \alpha \sin(\beta) d \beta / 4 \pi </math> in spherical coordinates. This comes from the ZYZ-Euler parameterization where any rotation can be broken down into first a rotation about the Z-axis, then a rotation about the new Y-axis (Y'), followed by a rotation about the new Z axis (Z"). More details on this are given in Appendix A in the paper.<br />
<br />
By this definition, the invariance of the inner product is then guaranteed for any rotation <math>R ∈ SO(3)</math>. In other words, when subjected to rotations, the volume under a spherical heightmap does not change. The following equations show that <math>L_R</math> has a distinct adjoint (<math>L_{R^{-1}}</math>) and that <math>L_R</math> is unitary and thus preserves orthogonality and distances.<br />
<br />
<math>\langle L_R \psi \,, f \rangle = \int_{S^2} \sum_{k=1}^K \psi_k (R^{-1} x)f_k (x)dx</math><br />
<br />
::::<math>= \int_{S^2} \sum_{k=1}^K \psi_k (x)f_k (Rx)dx</math><br />
<br />
::::<math>= \langle \psi , L_{R^{-1}} f \rangle</math><br />
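<br />
The adjoint identity above can be checked numerically. The following is a sketch under stated assumptions rather than anything from the paper: the two-channel test signals <code>psi</code> and <code>f</code> are arbitrary low-degree polynomials chosen for illustration, and the quadrature rule (Gauss-Legendre nodes in <math>\cos\beta</math>, uniform nodes in <math>\alpha</math>) is a standard choice that integrates such polynomials exactly with respect to the normalized measure <math>d \alpha \sin(\beta) d \beta / 4 \pi</math>.<br />

```python
import numpy as np

def sphere_quadrature(n_beta=8, n_alpha=16):
    """Quadrature points and weights for the normalized measure on S^2."""
    u, w = np.polynomial.legendre.leggauss(n_beta)   # u = cos(beta)
    alpha = 2.0 * np.pi * np.arange(n_alpha) / n_alpha
    U, A = np.meshgrid(u, alpha, indexing="ij")
    s = np.sqrt(1.0 - U**2)                          # sin(beta)
    pts = np.stack([s * np.cos(A), s * np.sin(A), U], axis=-1).reshape(-1, 3)
    weights = np.repeat(w, n_alpha) / (2.0 * n_alpha)  # weights sum to 1
    return pts, weights

def psi(x):
    """Hypothetical 2-channel filter, x has shape (N, 3)."""
    return np.stack([x[:, 0] * x[:, 2], x[:, 1] ** 2], axis=-1)

def f(x):
    """Hypothetical 2-channel signal, x has shape (N, 3)."""
    return np.stack([x[:, 0] + x[:, 2] ** 2, 2.0 * x[:, 1] * x[:, 2]], axis=-1)

def inner(a_vals, b_vals, weights):
    """<a, b>: integrate the channel-wise product, summed over channels."""
    return float(np.sum(weights * np.sum(a_vals * b_vals, axis=-1)))

def zyz_to_matrix(alpha, beta, gamma):
    """ZYZ-Euler angles to a rotation matrix, Rz(alpha) Ry(beta) Rz(gamma)."""
    def rz(t):
        return np.array([[np.cos(t), -np.sin(t), 0.0],
                         [np.sin(t), np.cos(t), 0.0],
                         [0.0, 0.0, 1.0]])
    def ry(t):
        return np.array([[np.cos(t), 0.0, np.sin(t)],
                         [0.0, 1.0, 0.0],
                         [-np.sin(t), 0.0, np.cos(t)]])
    return rz(alpha) @ ry(beta) @ rz(gamma)

pts, weights = sphere_quadrature()
R = zyz_to_matrix(0.4, 0.9, 1.7)

# Rows of `pts @ R` are R^{-1} x, rows of `pts @ R.T` are R x.
lhs = inner(psi(pts @ R), f(pts), weights)    # <L_R psi, f>
rhs = inner(psi(pts), f(pts @ R.T), weights)  # <psi, L_{R^-1} f>
```

Under these assumptions <code>lhs</code> and <code>rhs</code> agree to floating-point precision, which is the discrete counterpart of moving the rotation from the filter onto the signal.<br />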
<br />
'''Spherical Correlation''' With the above knowledge the definition of spherical correlation of two signals <math>f</math> and <math>\psi</math> is:<br />
<br />
<math>[\psi \star f](R) = \langle L_R \psi \,, f \rangle = \int_{S^2} \sum_{k=1}^K \psi_k (R^{-1} x)f_k (x)dx</math><br />
<br />
The output of the above equation is a function on SO(3): each rotation, i.e. each combination of <math>\alpha , \beta , \gamma </math>, yields a different value of the correlation. The authors make a point of noting that previous work by Driscoll and Healey only ensures circular symmetries about the Z axis, whereas their new formulation ensures symmetry about any rotation.<br />
<br />
'''Rotation of SO(3) Signals''' The first layer of a Spherical CNN takes a function on the sphere (<math>S^2</math>) and outputs a function on SO(3). Therefore, to build a Spherical CNN with more than one layer, there needs to be a way to compute the correlation between two signals on SO(3). The authors therefore generalize the rotation operator (<math>L_R</math>) to act on signals on SO(3). This new definition of <math>L_R</math> is as follows (where <math>R^{-1}Q</math> is a composition of rotations, i.e. multiplication of rotation matrices):<br />
<br />
<math>[L_Rf](Q)=f(R^{-1} Q)</math><br />
<br />
'''Rotation Group Correlation''' The correlation of two signals (<math>f,\psi</math>) on SO(3) with K channels is defined as the following:<br />
<br />
<math>[\psi \star f](R) = \langle L_R \psi , f \rangle = \int_{SO(3)} \sum_{k=1}^K \psi_k (R^{-1} Q)f_k (Q)dQ</math><br />
<br />
where dQ is the invariant measure on SO(3), which in ZYZ-Euler angles is <math>d \alpha \sin(\beta) d \beta d \gamma / 8 \pi^2 </math>. A complete derivation of this can be found in Appendix A.<br />
<br />
The equivariance for the rotation group correlation is similarly demonstrated.<br />
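<br />
As a sketch of that argument, using only the adjoint property of <math>L_R</math> and the composition rule <math>L_{RR'} = L_R L_{R'}</math> stated above (and assuming, as for the sphere, that the measure <math>dQ</math> is rotation invariant), for any rotation <math>Q</math>:<br />
<br />
<math>[\psi \star [L_Q f]](R) = \langle L_R \psi , L_Q f \rangle = \langle L_{Q^{-1}} L_R \psi , f \rangle = \langle L_{Q^{-1}R} \psi , f \rangle = [\psi \star f](Q^{-1}R) = [L_Q[\psi \star f]](R)</math><br />
<br />
In words, correlating a rotated input gives the same result as rotating the output of the correlation, which is the equivariance property.<br />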
<br />
= Implementation with GFFT =<br />
The authors leverage the Generalized Fourier Transform (GFT) and Generalized Fast Fourier Transform (GFFT) algorithms to compute the correlations outlined in the previous section. The Fast Fourier Transform (FFT) can compute correlations and convolutions efficiently by means of the Fourier theorem. Fourier analysis expresses a continuous periodic function as a sum of sine and cosine terms, whose weights are called the Fourier coefficients. The Fourier transform can be generalized to <math>S^2</math> and SO(3), where it is called the GFT. The GFT is a linear projection of a function onto orthogonal basis functions. The basis functions are a set of irreducible unitary representations for a group (such as for <math>S^2</math> or SO(3)). For <math>S^2</math> the basis functions are the spherical harmonics <math>Y_m^l(x)</math>. For SO(3) these basis functions are called the Wigner D-functions <math>D_{mn}^l(R)</math>. For both sets of functions the indices are restricted to <math>l\geq0</math> and <math>-l \leq m,n \leq l</math>. The Wigner D-functions are also orthogonal, so the Fourier coefficients can be computed by the inner product with the Wigner D-functions (see Appendix C for the complete proof). The Wigner D-functions are complete, which means that any well-behaved function on SO(3) can be expressed as a linear combination of the Wigner D-functions. The GFT of a function on SO(3) is thus:<br />
<br />
<math>\hat{f^l} = \int_X f(x) D^l(x)dx</math><br />
<br />
where <math>\hat{f}</math> represents the Fourier coefficients. For <math>S^2</math> we have the same equation but with the basis functions <math>Y^l</math>.<br />
<br />
The inverse SO(3) Fourier transform is:<br />
<br />
<math>f(R)=[\mathcal{F}^{-1} \hat{f}](R) = \sum_{l=0}^b (2l + 1) \sum_{m=-l}^l \sum_{n=-l}^l \hat{f_{mn}^l} D_{mn}^l(R) </math><br />
<br />
The bandwidth b represents the maximum frequency and is related to the resolution of the spatial grid. Kostelec and Rockmore are referenced for more knowledge on this topic.<br />
<br />
The authors give proofs (Appendix D) that the SO(3) correlation satisfies the Fourier theorem and the <math>S^2</math> correlation of spherical signals can be computed by the outer products of the <math>S^2</math>-FTs (Shown in Figure 2).<br />
<br />
[[File:paper26-fig2.png|center]]<br />
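<br />
The planar analogue of this Fourier theorem is easy to check numerically. The sketch below is not the spherical GFFT and does not use the authors' library; it only illustrates, with NumPy's ordinary FFT on a small periodic grid, that the correlation <math>[\psi \star f](x) = \sum_y \psi(y - x) f(y)</math> can be computed as a pointwise product in the Fourier domain. The function names are hypothetical.<br />

```python
import numpy as np

def correlate_direct(psi, f):
    """Circular cross-correlation on a 2D grid, computed from the definition."""
    n0, n1 = f.shape
    out = np.empty_like(f)
    for x0 in range(n0):
        for x1 in range(n1):
            # np.roll(psi, x)[y] == psi[y - x], i.e. the filter shifted by x.
            shifted = np.roll(psi, (x0, x1), axis=(0, 1))
            out[x0, x1] = np.sum(shifted * f)
    return out

def correlate_fft(psi, f):
    """Same correlation via the Fourier theorem: conj(FFT(psi)) * FFT(f)."""
    spectrum = np.conj(np.fft.fft2(psi)) * np.fft.fft2(f)
    return np.fft.ifft2(spectrum).real

rng = np.random.default_rng(0)
psi = rng.standard_normal((8, 8))
f = rng.standard_normal((8, 8))

direct = correlate_direct(psi, f)
fast = correlate_fft(psi, f)
```

The direct version costs on the order of <math>n^2</math> operations for a grid with <math>n</math> points, while the FFT version costs <math>O(n \log n)</math>. This saving is what the GFFT aims to carry over to <math>S^2</math> and SO(3).<br />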
<br />
The GFFT algorithm details are taken from Kostelec and Rockmore. The authors claim they have the first automatically differentiable implementation of the GFT for <math>S^2</math> and SO(3). The authors do not provide any run time comparisons for real time applications (they just mention that the FFT can be computed in <math>O(n \log n)</math> time) or any comparisons of training times with and without the GFFT. However, they do provide the source code of their implementation at: https://github.com/jonas-koehler/s2cnn.<br />
<br />
= Experiments =<br />
The authors provide several experiments. The first set of experiments are designed to show the numerical stability and accuracy of the outlined methods. The second group of experiments demonstrates how the algorithms can be applied to current problem domains.<br />
<br />
==Equivariance Error==<br />
In this experiment the authors test empirically whether their theory of equivariance holds. Equivariance was proven only for the continuous case, so they were concerned that discretization artifacts could break it in practice, which would make the weight sharing scheme less effective. The experiment first tests the equivariance of the SO(3) correlation at different resolutions. 500 random rotations and feature maps (with 10 channels) are sampled. They then calculate the approximation error <math>\small\Delta = \dfrac{1}{n} \sum_{i=1}^n \mathrm{std}(L_{R_i} \Phi(f_i) - \Phi(L_{R_i} f_i))/\mathrm{std}(\Phi(f_i))</math><br />
Note: The authors do not define the std function; it is most likely the standard deviation, as 'std' is the command for standard deviation in MATLAB.<br />
<math>\Phi</math> is a composition of SO(3) correlation layers with randomly initialized filters. The authors expect <math>\Delta</math> to be zero in the case of perfect equivariance because, as proven earlier, the two terms in the difference <math>\small L_{R_i} \Phi(f_i) - \Phi(L_{R_i} f_i)</math> are equal in the continuous case. The results are shown in Figure 3. <br />
<br />
[[File:paper26-fig3.png|center]]<br />
<br />
<math>\Delta</math> only grows with the resolution and the number of layers when there is no activation function. With ReLU activations the error stays roughly constant once the resolution is slightly above 0. The authors indicate that the error must therefore come from the feature map rotation, since this rotation is exact only for bandlimited functions.<br />
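<br />
To make the error measure <math>\Delta</math> concrete, here is a minimal sketch that evaluates it on a planar stand-in. It is an illustration under assumptions, not the authors' experiment: cyclic shifts play the role of the rotations <math>L_{R_i}</math>, a circular correlation followed by ReLU plays the role of <math>\Phi</math>, and all names are hypothetical. A position-dependent mask is included as a deliberately non-equivariant map for contrast.<br />

```python
import numpy as np

def circular_correlation(psi, f):
    """Circular cross-correlation of a filter and a signal via the FFT."""
    return np.fft.ifft2(np.conj(np.fft.fft2(psi)) * np.fft.fft2(f)).real

def equivariance_error(phi, transform, signals, params):
    """Mean of std(L phi(f) - phi(L f)) / std(phi(f)) over all samples."""
    ratios = []
    for f, r in zip(signals, params):
        out = phi(f)
        diff = transform(out, r) - phi(transform(f, r))
        ratios.append(np.std(diff) / np.std(out))
    return float(np.mean(ratios))

rng = np.random.default_rng(0)
n = 16
psi = rng.standard_normal((n, n))   # randomly initialized filter
mask = rng.random((n, n))           # position-dependent weights
signals = [rng.standard_normal((n, n)) for _ in range(20)]
# Non-zero shifts along both axes, one pair per signal.
shifts = [tuple(int(v) for v in rng.integers(1, n, size=2)) for _ in range(20)]

def shift(x, s):
    """Cyclic shift, the planar analogue of the rotation operator L_R."""
    return np.roll(x, s, axis=(0, 1))

def equivariant_layer(x):
    """Correlation followed by ReLU: commutes with cyclic shifts."""
    return np.maximum(circular_correlation(psi, x), 0.0)

def non_equivariant_layer(x):
    """Pointwise mask that depends on position: does not commute with shifts."""
    return x * mask

delta_good = equivariance_error(equivariant_layer, shift, signals, shifts)
delta_bad = equivariance_error(non_equivariant_layer, shift, signals, shifts)
```

On a periodic grid cyclic shifts are exact, so <code>delta_good</code> is zero up to floating-point error, while <code>delta_bad</code> is clearly non-zero. In the spherical setting the rotation of a discretized feature map is itself only approximate, which is why the authors attribute their non-zero <math>\Delta</math> to the rotation.<br />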
<br />
==MNIST Data==<br />
The experiment using MNIST data was created by projecting MNIST digits onto a sphere using stereographic projection to create the resulting images as seen in Figure 4.<br />
<br />
[[File:paper26-fig4.png|center]]<br />
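<br />
The projection step can be sketched as follows. This summary does not state the exact projection parameters the authors used, so the code is an assumption-laden illustration: it uses the textbook inverse stereographic projection from the north pole onto the plane <math>z = 0</math>, and the function names and the <code>extent</code> parameter (half-width of the image in plane coordinates) are hypothetical.<br />

```python
import numpy as np

def inverse_stereographic(X, Y):
    """Map plane coordinates (X, Y) to points on the unit sphere.

    Projection is from the north pole (0, 0, 1) onto the plane z = 0,
    so the plane origin maps to the south pole.
    """
    d = 1.0 + X**2 + Y**2
    return np.stack([2.0 * X / d, 2.0 * Y / d, (X**2 + Y**2 - 1.0) / d], axis=-1)

def image_to_sphere(image, extent=1.0):
    """Attach each pixel value to a point on the sphere.

    `extent` is a hypothetical scale: the image covers
    [-extent, extent]^2 in the plane before projection.
    """
    h, w = image.shape
    ys = np.linspace(-extent, extent, h)
    xs = np.linspace(-extent, extent, w)
    X, Y = np.meshgrid(xs, ys)
    points = inverse_stereographic(X, Y).reshape(-1, 3)
    return points, image.reshape(-1)

image = np.arange(28 * 28, dtype=float).reshape(28, 28)  # stand-in for a digit
points, values = image_to_sphere(image)
```

This produces scattered samples on the sphere. An actual pipeline would still need to resample them onto whatever regular spherical grid the network expects, which is outside the scope of this sketch.<br />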
<br />
The authors created two datasets, one with the projected digits and the other with the same projected digits which were then subjected to a random rotation. The spherical CNN architecture used was <math>\small S^2</math>conv-ReLU-SO(3)conv-ReLU-FC-softmax and was attempted with bandwidths of 30,10,6 and 20,40,10 channels for each layer respectively. This model was compared to a baseline CNN with layers conv-ReLU-conv-ReLU-FC-softmax with 5x5 filters, 32,64,10 channels and stride of 3. For comparison this leads to approximately 68K parameters for the baseline and 58K parameters for the spherical CNN. Results can be seen in Table 1. It is clear from the results that the spherical CNN architecture made the network rotationally invariant. Performance on the rotated set is almost identical to the non-rotated set. This is true even when trained on the non-rotated set and tested on the rotated set. Compare this to the non-spherical architecture which becomes unusable when rotating the digits.<br />
<br />
[[File:paper26-tab1.png|center]]<br />
<br />
==SHREC17==<br />
The SHREC dataset contains 3D models from the ShapeNet dataset which are classified into categories. It consists of a regularly aligned dataset and a rotated dataset. The models from the SHREC17 dataset were projected onto a sphere by means of raycasting. Different properties of the objects obtained from the raycast of the original model and the convex hull of the model make up the different channels which are input into the spherical CNN.<br />
<br />
<br />
[[File:paper26-fig5.png|center]]<br />
<br />
<br />
The network architecture used is an initial <math>\small S^2</math>conv-BN-ReLU block which is followed by two SO(3)conv-BN-ReLU blocks. The output is then fed into a MaxPool-BN block then a linear layer to the output for final classification. The architecture for this experiment has ~1.4M parameters, far exceeding the scale of the spherical CNNs in the other experiments.<br />
<br />
This architecture achieves results close to the state of the art on the SHREC17 tasks. The model would place 2nd or 3rd in all categories, but it was not submitted because the SHREC17 task is closed. Table 2 shows the comparison of results with the top 3 submissions in each category. In the table, P@N stands for precision, R@N stands for recall, F1@N stands for F-score, mAP stands for mean average precision, and NDCG stands for normalized discounted cumulative gain in relevance based on whether the category and subcategory labels are predicted correctly. The authors claim the results are empirical evidence of the usefulness of spherical CNNs. They elaborate that this is notable because most architectures in the SHREC17 competition are highly specialized, whereas their model is fairly general.<br />
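<br />
For readers unfamiliar with these retrieval metrics, the following sketch shows generic textbook definitions. It is not the official SHREC17 evaluation code, whose exact conventions (cut-off N, relevance grades, averaging) are not given in this summary; the function names and the example gains are hypothetical.<br />

```python
import numpy as np

def precision_recall_f1(retrieved_relevant, n_relevant_total):
    """Precision, recall and F1 for one retrieved list.

    `retrieved_relevant` holds one boolean per retrieved item (True if the
    item is relevant); `n_relevant_total` counts all relevant items.
    """
    hits = sum(retrieved_relevant)
    p = hits / len(retrieved_relevant) if retrieved_relevant else 0.0
    r = hits / n_relevant_total if n_relevant_total else 0.0
    f1 = 2.0 * p * r / (p + r) if p + r > 0 else 0.0
    return p, r, f1

def ndcg(gains):
    """Normalized discounted cumulative gain of a ranked list of gains.

    Simplification: the ideal ranking is taken over the retrieved items only.
    """
    gains = np.asarray(gains, dtype=float)
    discounts = 1.0 / np.log2(np.arange(2, len(gains) + 2))
    dcg = np.sum(gains * discounts)
    ideal = np.sum(np.sort(gains)[::-1] * discounts)
    return float(dcg / ideal) if ideal > 0 else 0.0

# Example: 4 retrieved items, 3 of them relevant, 6 relevant items overall.
p, r, f1 = precision_recall_f1([True, True, False, True], 6)

# Example graded gains, e.g. 2 if category and subcategory match, 1 if only
# the category matches, 0 otherwise (an illustrative grading, not SHREC17's).
score_sorted = ndcg([2, 2, 1, 0])
score_swapped = ndcg([0, 2])
```

NDCG rewards placing the most relevant items first, so a list that is already in the best order scores 1 and a list with its best item ranked second scores lower.<br />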
<br />
<br />
[[File:paper26-tab2.png|center]]<br />
<br />
==Molecular Atomization==<br />
In this experiment a spherical CNN is implemented with an architecture resembling that of ResNet. They use the QM7 dataset which has the task of predicting atomization energy of molecules. The positions and charges given in the dataset are projected onto the sphere using potential functions. A summary of their results is shown in Table 3 along with some of the spherical CNN architecture details. It shows the different RMSE obtained from different methods. The results from this final experiment also seem to be promising as the network the authors present achieves the second best score. They also note that the first place method grows exponentially with the number of atoms per molecule so is unlikely to scale well.<br />
<br />
[[File:paper26-tab3.png|center]]<br />
<br />
= Conclusions =<br />
This paper presents a novel architecture called Spherical CNNs. The paper defines <math>\small S^2</math> and SO(3) cross correlations, shows the theory behind their rotational equivariance for continuous functions, and demonstrates empirically that the equivariance also holds approximately in the discrete case. An effective GFFT algorithm was implemented and evaluated on two very different datasets with close to state of the art results, demonstrating that there are practical applications to Spherical CNNs.<br />
<br />
For future work the authors believe that improvements can be obtained by generalizing the algorithms to the SE(3) group (SE(3) simply adds translations in 3D space to the SO(3) group). The authors also briefly mention their excitement for applying Spherical CNNs to omnidirectional vision such as in drones and autonomous cars. They state that there is very little publicly available omnidirectional image data which could be why they did not conduct any experiments in this area.<br />
<br />
= Commentary =<br />
The reviews on Spherical CNNs are very positive and it is ranked in the top 1% of papers submitted to ICLR 2018. Positive points are the novelty of the architecture, the wide variety of experiments performed, and the writing. One critique of the original submission is that the related works section only lists previous methods instead of describing them, and that a description of the methods would have provided more clarity. The authors have since expanded the section; however, I found that it is still limited, which the authors attribute to length limitations. Another critique is that the evaluation does not provide enough depth. For example, it would have been great to see an example of omnidirectional vision for spherical networks. However, this is to be expected as it is just the introduction of spherical CNNs and more work is sure to come.<br />
<br />
= Source Code =<br />
Source code is available at:<br />
https://github.com/jonas-koehler/s2cnn<br />
<br />
= Sources =<br />
* T. Cohen et al. Spherical CNNs, 2018.<br />
* J. Feldman. Haar Measure. http://www.math.ubc.ca/~feldman/m606/haar.pdf<br />
* P. Kostelec, D. Rockmore. FFTs on the Rotation Group, 2008.</div>X249wanghttp://wiki.math.uwaterloo.ca/statwiki/index.php?title=Spherical_CNNs&diff=35130Spherical CNNs2018-03-22T04:13:24Z<p>X249wang: /* Equivariance Error */</p>
<hr />
<div>= Introduction =<br />
Convolutional Neural Networks (CNNs), or network architectures involving CNNs, are the current state of the art for learning 2D image processing tasks such as semantic segmentation and object detection. CNNs work well in large part due to the property of being translationally equivariant. This property allows a network trained to detect a certain type of object to still detect the object even if it is translated to another position in the image. However, this does not correspond well to spherical signals since projecting a spherical signal onto a plane will result in distortions, as demonstrated in Figure 1. There are many different types of spherical projections onto a 2D plane, as most people know from the various types of world maps, none of which provide all the necessary properties for rotation-invariant learning.<br />
<br />
[[File:paper26-fig1.png|center]]<br />
<br />
= Notation =<br />
Below are listed several important terms:<br />
* '''Unit Sphere''' <math>S^2</math> is defined as the set of points at distance 1 from the origin. The unit sphere can be parameterized by the spherical coordinates <math>\alpha \in [0, 2\pi]</math> and <math>\beta \in [0, \pi]</math>. This is a two-dimensional manifold with respect to <math>\alpha</math> and <math>\beta</math>.<br />
* '''<math>S^2</math> Sphere''' The two-dimensional surface of a ball in three-dimensional space.<br />
* '''Spherical Signals''' In the paper spherical images and filters are modeled as continuous functions <math>f : S^2 \to \mathbb{R}^K</math>, where K is the number of channels. Just as RGB images have 3 channels, a spherical signal can have numerous channels describing the data. Examples of channels which were used can be found in the experiments section.<br />
* '''Rotations - SO(3)''' The group of 3D rotations on an <math>S^2</math> sphere. Sometimes called the "special orthogonal group". In this paper the ZYZ-Euler parameterization is used to represent SO(3) rotations with <math>\alpha, \beta</math>, and <math>\gamma</math>. Any rotation can be broken down into first a rotation (<math>\alpha</math>) about the Z-axis, then a rotation (<math>\beta</math>) about the new Y-axis (Y'), followed by a rotation (<math>\gamma</math>) about the new Z axis (Z"). [In the rest of this paper, to integrate functions on SO(3), the authors use a rotationally invariant probability measure on the Borel subsets of SO(3). This measure is an example of a Haar measure. Haar measures generalize the idea of rotationally invariant probability measures to general topological groups. For more on Haar measures, see (Feldman 2002) ]<br />
<br />
= Related Work =<br />
The related work presented in this paper is very brief, in large part due to the novelty of spherical CNNs and the length of the rest of the paper. The authors enumerate numerous papers which attempt to exploit larger groups of symmetries such as the translational symmetries of CNNs but do not go into specific details for any of these attempts. They do state that all the previous works are limited to discrete groups with the exception of SO(2)-steerable networks.<br />
The authors also mention that previous works exist that analyze spherical images but that these do not have an equivariant architecture. They claim that Spherical CNNs are "the first to achieve equivariance to a continuous, non-commutative group (SO(3))". They also claim to be the first to use the generalized Fourier transform for fast computation of group correlation.<br />
<br />
= Correlations on the Sphere and Rotation Group =<br />
Spherical correlation is like planar correlation except instead of translation, there is rotation. The definitions for each are provided as follows:<br />
<br />
'''Planar correlation''' The value of the output feature map at translation <math>\small x ∈ Z^2</math> is computed as an inner product between the input feature map and a filter, shifted by <math>\small x</math>.<br />
<br />
'''Spherical correlation''' The value of the output feature map evaluated at rotation <math>\small R ∈ SO(3)</math> is computed as an inner product between the input feature map and a filter, rotated by <math>\small R</math>.<br />
<br />
'''Rotation of Spherical Signals''' The paper introduces the rotation operator <math>L_R</math>. The rotation operator rotates a function (which allows us to rotate the spherical filters) by evaluating it at <math>R^{-1}x</math>, i.e. <math>[L_R f](x) = f(R^{-1}x)</math>. With this definition we have the property that <math>L_{RR'} = L_R L_{R'}</math>.<br />
<br />
'''Inner Products''' The inner product of spherical signals is simply the integral summation on the vector space over the entire sphere.<br />
<br />
<math>\langle\psi , f \rangle = \int_{S^2} \sum_{k=1}^K \psi_k (x)f_k (x)dx</math><br />
<br />
<math>dx</math> here is the SO(3) rotation-invariant measure on the sphere and is equivalent to <math>d \alpha \sin(\beta) d \beta / 4 \pi </math> in spherical coordinates. This comes from the ZYZ-Euler parameterization where any rotation can be broken down into first a rotation about the Z-axis, then a rotation about the new Y-axis (Y'), followed by a rotation about the new Z axis (Z"). More details on this are given in Appendix A in the paper.<br />
<br />
By this definition, the invariance of the inner product is then guaranteed for any rotation <math>R ∈ SO(3)</math>. In other words, when subjected to rotations, the volume under a spherical heightmap does not change. The following equations show that <math>L_R</math> has a distinct adjoint (<math>L_{R^{-1}}</math>) and that <math>L_R</math> is unitary and thus preserves orthogonality and distances.<br />
<br />
<math>\langle L_R \psi \,, f \rangle = \int_{S^2} \sum_{k=1}^K \psi_k (R^{-1} x)f_k (x)dx</math><br />
<br />
::::<math>= \int_{S^2} \sum_{k=1}^K \psi_k (x)f_k (Rx)dx</math><br />
<br />
::::<math>= \langle \psi , L_{R^{-1}} f \rangle</math><br />
<br />
'''Spherical Correlation''' With the above knowledge the definition of spherical correlation of two signals <math>f</math> and <math>\psi</math> is:<br />
<br />
<math>[\psi \star f](R) = \langle L_R \psi \,, f \rangle = \int_{S^2} \sum_{k=1}^K \psi_k (R^{-1} x)f_k (x)dx</math><br />
<br />
The output of the above equation is a function on SO(3): each rotation, i.e. each combination of <math>\alpha , \beta , \gamma </math>, gives a different value of the correlation. The authors make a point of noting that previous work by Driscoll and Healy only ensures circular symmetries about the Z axis, whereas their new formulation ensures symmetry about any rotation.<br />
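<br />
As a rough numerical illustration of the spherical correlation and of the adjoint identity <math>\langle L_R \psi , f \rangle = \langle \psi , L_{R^{-1}} f \rangle</math>, the sketch below approximates the inner product with a simple midpoint rule on an <math>(\alpha, \beta)</math> grid. It is not the authors' implementation (they work in the spectral domain); the function names, the grid sizes and the two toy signals are assumptions made only for this example.<br />

```python
import math

def rot_y(t):
    """Rotation by angle t about the Y axis, as a 3x3 nested list."""
    c, s = math.cos(t), math.sin(t)
    return [[c, 0.0, s], [0.0, 1.0, 0.0], [-s, 0.0, c]]

def transpose(r):
    return [list(col) for col in zip(*r)]

def apply(r, p):
    return [sum(r[i][k] * p[k] for k in range(3)) for i in range(3)]

def inner(psi, f, n_beta=100, n_alpha=200):
    """Midpoint-rule estimate of <psi, f> = int_{S^2} sum_k psi_k(x) f_k(x) dx,
    with the normalized measure dx = d_alpha sin(beta) d_beta / (4 pi)."""
    d_beta = math.pi / n_beta
    d_alpha = 2.0 * math.pi / n_alpha
    total = 0.0
    for i in range(n_beta):
        beta = (i + 0.5) * d_beta
        sb, cb = math.sin(beta), math.cos(beta)
        for j in range(n_alpha):
            alpha = (j + 0.5) * d_alpha
            p = [math.cos(alpha) * sb, math.sin(alpha) * sb, cb]
            # Sum over the K channels, weighted by sin(beta).
            total += sum(a * b for a, b in zip(psi(p), f(p))) * sb
    return total * d_alpha * d_beta / (4.0 * math.pi)

def rotate_signal(f, r):
    """Return L_R f, i.e. the signal x -> f(R^{-1} x); R^{-1} is R transposed."""
    r_inv = transpose(r)
    return lambda p: f(apply(r_inv, p))

def spherical_correlation(psi, f, r):
    """[psi * f](R) = <L_R psi, f>, evaluated at a single rotation R."""
    return inner(rotate_signal(psi, r), f)

# Toy two-channel signals (K = 2), chosen only for illustration.
psi = lambda p: (p[2], p[1])
f = lambda p: (p[0], p[1])
R = rot_y(0.7)

lhs = spherical_correlation(psi, f, R)            # <L_R psi, f>
rhs = inner(psi, rotate_signal(f, transpose(R)))  # <psi, L_{R^-1} f>
print(lhs, rhs)
```

The two printed numbers should agree up to quadrature error. Evaluating <code>spherical_correlation</code> over a grid of rotations would give a sampled function on SO(3), but at a cost that grows quickly with resolution, which is the motivation for the Fourier-domain approach described later.<br />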
<br />
'''Rotation of SO(3) Signals''' The first layer of a Spherical CNN takes a function on the sphere (<math>S^2</math>) and outputs a function on SO(3). Therefore, if a Spherical CNN with more than one layer is going to be built, there needs to be a way to find the correlation between two signals on SO(3). The authors therefore generalize the rotation operator (<math>L_R</math>) to act on signals on SO(3). This new definition of <math>L_R</math> is as follows (where <math>R^{-1}Q</math> is a composition of rotations, i.e. multiplication of rotation matrices):<br />
<br />
<math>[L_Rf](Q)=f(R^{-1} Q)</math><br />
<br />
'''Rotation Group Correlation''' The correlation of two signals (<math>f,\psi</math>) on SO(3) with K channels is defined as the following:<br />
<br />
<math>[\psi \star f](R) = \langle L_R \psi , f \rangle = \int_{SO(3)} \sum_{k=1}^K \psi_k (R^{-1} Q)f_k (Q)dQ</math><br />
<br />
where <math>dQ</math> is the invariant measure on SO(3), which in ZYZ-Euler angles is <math>d \alpha \sin(\beta) d \beta d \gamma / 8 \pi^2 </math>. A complete derivation of this can be found in Appendix A.<br />
<br />
The equivariance for the rotation group correlation is similarly demonstrated.<br />
<br />
= Implementation with GFFT =<br />
The authors leverage the Generalized Fourier Transform (GFT) and Generalized Fast Fourier Transform (GFFT) algorithms to compute the correlations outlined in the previous section. The Fast Fourier Transform (FFT) can compute correlations and convolutions efficiently by means of the Fourier theorem, which states that the Fourier transform of a convolution is the product of the Fourier transforms of the two signals (for a correlation, one of the factors is conjugated). The underlying idea of Fourier analysis is that a well-behaved periodic function can be expressed as a weighted sum of sine and cosine terms; the weights are called Fourier coefficients. The Fourier transform can be generalized to <math>S^2</math> and SO(3) and is then called the GFT. The GFT is a linear projection of a function onto orthogonal basis functions. The basis functions are a set of irreducible unitary representations for a group (such as for <math>S^2</math> or SO(3)). For <math>S^2</math> the basis functions are the spherical harmonics <math>Y_m^l(x)</math>. For SO(3) these basis functions are called the Wigner D-functions <math>D_{mn}^l(R)</math>. For both sets of functions the indices are restricted to <math>l\geq0</math> and <math>-l \leq m,n \leq l</math>. The Wigner D-functions are also orthogonal, so the Fourier coefficients can be computed by the inner product with the Wigner D-functions (see Appendix C for the complete proof). The Wigner D-functions are complete, which means that any well-behaved function on SO(3) can be expressed as a linear combination of the Wigner D-functions. The GFT of a function on SO(3) is thus:<br />
<br />
<math>\hat{f^l} = \int_X f(x) D^l(x)dx</math><br />
<br />
where <math>\hat{f}</math> represents the Fourier coefficients. For <math>S^2</math> we have the same equation but with the basis functions <math>Y^l</math>.<br />
<br />
The inverse SO(3) Fourier transform is:<br />
<br />
<math>f(R)=[\mathcal{F}^{-1} \hat{f}](R) = \sum_{l=0}^b (2l + 1) \sum_{m=-l}^l \sum_{n=-l}^l \hat{f_{mn}^l} D_{mn}^l(R) </math><br />
<br />
The bandwidth b represents the maximum frequency and is related to the resolution of the spatial grid. Kostelec and Rockmore are referenced for more knowledge on this topic.<br />
<br />
The authors give proofs (Appendix D) that the SO(3) correlation satisfies the Fourier theorem and the <math>S^2</math> correlation of spherical signals can be computed by the outer products of the <math>S^2</math>-FTs (Shown in Figure 2).<br />
<br />
[[File:paper26-fig2.png|center]]<br />
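<br />
To make the role of the Fourier theorem more concrete, the following sketch checks its classical counterpart on a much simpler group: cyclic shifts of a 1D signal, with the ordinary discrete Fourier transform standing in for the GFT. This is a deliberately swapped-in toy setting, not the <math>S^2</math>/SO(3) transform of the paper, and the function names are hypothetical. The naive <math>O(n^2)</math> DFT is used only to keep the example self-contained; a real implementation would use an FFT.<br />

```python
import cmath

def dft(x):
    """Naive discrete Fourier transform (O(n^2)), for illustration only."""
    n = len(x)
    return [sum(x[j] * cmath.exp(-2j * cmath.pi * j * k / n) for j in range(n))
            for k in range(n)]

def idft(X):
    """Inverse of dft."""
    n = len(X)
    return [sum(X[k] * cmath.exp(2j * cmath.pi * j * k / n) for k in range(n)) / n
            for j in range(n)]

def correlate_direct(filt, signal):
    """Circular correlation: out[i] = sum_j filt[(j - i) mod n] * signal[j]."""
    n = len(signal)
    return [sum(filt[(j - i) % n] * signal[j] for j in range(n)) for i in range(n)]

def correlate_fourier(filt, signal):
    """Same correlation via the Fourier theorem: multiply the signal spectrum
    by the conjugated filter spectrum, then transform back (real inputs)."""
    F, S = dft(filt), dft(signal)
    return [v.real for v in idft([s * f.conjugate() for s, f in zip(S, F)])]

signal = [1.0, 2.0, 0.0, -1.0, 3.0, 0.5]
filt = [0.5, -1.0, 0.0, 0.0, 2.0, 1.0]
print(correlate_direct(filt, signal))
print(correlate_fourier(filt, signal))
```

Both routes should print the same values up to rounding. In the spherical setting the scalar product of spectra is replaced by products of coefficient blocks per frequency <math>l</math>, as indicated in Figure 2.<br />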
<br />
The GFFT algorithm details are taken from Kostelec and Rockmore. The authors claim they have the first automatically differentiable implementation of the GFT for <math>S^2</math> and SO(3). The authors do not provide any run time comparisons for real time applications (they only mention that the FFT can be computed in <math>O(n \log n)</math> time) or any comparisons of training times with and without the GFFT. However, they do provide the source code of their implementation at: https://github.com/jonas-koehler/s2cnn.<br />
<br />
= Experiments =<br />
The authors provide several experiments. The first set of experiments are designed to show the numerical stability and accuracy of the outlined methods. The second group of experiments demonstrates how the algorithms can be applied to current problem domains.<br />
<br />
==Equivariance Error==<br />
In this experiment the authors try to show experimentally that their theory of equivariance holds. They express that they had doubts about the equivariance in practice due to potential discretization artifacts, since equivariance was proven only for the continuous case; if equivariance did not hold, the weight sharing scheme would become less effective. The experiment is set up by first testing the equivariance of the SO(3) correlation at different resolutions. 500 random rotations and feature maps (with 10 channels) are sampled. They then calculate the approximation error <math>\small\Delta = \dfrac{1}{n} \sum_{i=1}^n std(L_{R_i} \Phi(f_i) - \Phi(L_{R_i} f_i))/std(\Phi(f_i))</math><br />
Note: The authors do not say what the std function is; it is most likely the standard deviation, as 'std' is the command for standard deviation in MATLAB.<br />
<math>\Phi</math> is a composition of SO(3) correlation layers with filters which have been randomly initialized. The authors mention that they were expecting <math>\Delta</math> to be zero in the case of perfect equivariance. This is because, as proven earlier, the following two terms are equal in the continuous case: <math>\small L_{R_i} \Phi(f_i) = \Phi(L_{R_i} f_i)</math>. The results are shown in Figure 3. <br />
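<br />
The error measure <math>\Delta</math> can be illustrated on a toy problem. The sketch below is not the authors' experiment: it replaces SO(3) rotations with cyclic shifts of a 1D signal, takes 'std' to be the population standard deviation (an assumption, as noted above), and uses hypothetical layer names. It compares a circular correlation layer, which is exactly shift-equivariant, with a position-dependent layer, which is not.<br />

```python
import random
import statistics

def shift(signal, r):
    """Cyclic shift by r positions; plays the role of L_R in this toy setting."""
    n = len(signal)
    return [signal[(i - r) % n] for i in range(n)]

def make_correlation_layer(filt):
    """Circular correlation with a fixed filter: an equivariant layer Phi."""
    def phi(signal):
        n = len(signal)
        return [sum(filt[(j - i) % n] * signal[j] for j in range(n))
                for i in range(n)]
    return phi

def position_dependent_layer(signal):
    """Scales each entry by its index, which breaks shift equivariance."""
    return [i * v for i, v in enumerate(signal)]

def equivariance_error(phi, signals, shifts):
    """Delta = mean_i std(L_i Phi(f_i) - Phi(L_i f_i)) / std(Phi(f_i))."""
    ratios = []
    for f, r in zip(signals, shifts):
        out = phi(f)
        diff = [a - b for a, b in zip(shift(out, r), phi(shift(f, r)))]
        ratios.append(statistics.pstdev(diff) / statistics.pstdev(out))
    return sum(ratios) / len(ratios)

rng = random.Random(0)
n = 16
signals = [[rng.gauss(0.0, 1.0) for _ in range(n)] for _ in range(50)]
shifts = [rng.randrange(1, n) for _ in signals]
phi = make_correlation_layer([rng.gauss(0.0, 1.0) for _ in range(n)])

delta_equivariant = equivariance_error(phi, signals, shifts)
delta_broken = equivariance_error(position_dependent_layer, signals, shifts)
print(delta_equivariant, delta_broken)
```

For the correlation layer the error should be at the level of floating-point rounding, while the position-dependent layer gives a clearly non-zero value. In the paper's setting the non-zero values arise from discretization rather than from a non-equivariant layer.<br />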
<br />
[[File:paper26-fig3.png|center]]<br />
<br />
<math>\Delta</math> only grows with resolution/layers when there is no activation function. With ReLU activation the error stays constant once the resolution is slightly higher than 0. The authors indicate that the error must therefore come from the feature map rotation rather than from the network layers, since this rotation is exact only for bandlimited functions.<br />
<br />
==MNIST Data==<br />
The experiment using MNIST data was created by projecting MNIST digits onto a sphere using stereographic projection to create the resulting images as seen in Figure 4.<br />
<br />
[[File:paper26-fig4.png|center]]<br />
<br />
The authors created two datasets, one with the projected digits and the other with the same projected digits which were then subjected to a random rotation. The spherical CNN architecture used was <math>\small S^2</math>conv-ReLU-SO(3)conv-ReLU-FC-softmax and was attempted with bandwidths of 30,10,6 and 20,40,10 channels for each layer respectively. This model was compared to a baseline CNN with layers conv-ReLU-conv-ReLU-FC-softmax with 5x5 filters, 32,64,10 channels and stride of 3. For comparison this leads to approximately 68K parameters for the baseline and 58K parameters for the spherical CNN. Results can be seen in Table 1. It is clear from the results that the spherical CNN architecture made the network rotationally invariant. Performance on the rotated set is almost identical to the non-rotated set. This is true even when trained on the non-rotated set and tested on the rotated set. Compare this to the non-spherical architecture which becomes unusable when rotating the digits.<br />
<br />
[[File:paper26-tab1.png|center]]<br />
<br />
==SHREC17==<br />
The SHREC dataset contains 3D models from the ShapeNet dataset which are classified into categories. It consists of a regularly aligned dataset and a rotated dataset. The models from the SHREC17 dataset were projected onto a sphere by means of raycasting. Different properties of the objects obtained from the raycast of the original model and the convex hull of the model make up the different channels which are input into the spherical CNN.<br />
<br />
<br />
[[File:paper26-fig5.png|center]]<br />
<br />
<br />
The network architecture used is an initial <math>\small S^2</math>conv-BN-ReLU block which is followed by two SO(3)conv-BN-ReLU blocks. The output is then fed into a MaxPool-BN block then a linear layer to the output for final classification. The architecture for this experiment has ~1.4M parameters, far exceeding the scale of the spherical CNNs in the other experiments.<br />
<br />
This architecture achieves results close to the state of the art on the SHREC17 tasks: the model places 2nd or 3rd in all categories, but was not submitted as the SHREC17 task is closed. Table 2 shows the comparison of results with the top 3 submissions in each category. The authors claim the results show empirical proof of the usefulness of spherical CNNs. They stress that most architectures in the SHREC17 competition are highly specialized whereas their model is fairly general.<br />
<br />
<br />
[[File:paper26-tab2.png|center]]<br />
<br />
==Molecular Atomization==<br />
In this experiment a spherical CNN is implemented with an architecture resembling that of ResNet. They use the QM7 dataset which has the task of predicting atomization energy of molecules. The positions and charges given in the dataset are projected onto the sphere using potential functions. A summary of their results is shown in Table 3 along with some of the spherical CNN architecture details. It shows the different RMSE obtained from different methods. The results from this final experiment also seem to be promising as the network the authors present achieves the second best score. They also note that the first place method grows exponentially with the number of atoms per molecule so is unlikely to scale well.<br />
<br />
[[File:paper26-tab3.png|center]]<br />
<br />
= Conclusions =<br />
This paper presents a novel architecture called Spherical CNNs. The paper defines <math>\small S^2</math> and SO(3) cross correlations, shows the theory behind their rotational invariance for continuous functions, and demonstrates that the invariance also applies to the discrete case. An effective GFFT algorithm was implemented and evaluated on two very different datasets with close to state of the art results, demonstrating that there are practical applications to Spherical CNNs.<br />
<br />
For future work the authors believe that improvements can be obtained by generalizing the algorithms to the SE(3) group (SE(3) simply adds translations in 3D space to the SO(3) group). The authors also briefly mention their excitement for applying Spherical CNNs to omnidirectional vision such as in drones and autonomous cars. They state that there is very little publicly available omnidirectional image data which could be why they did not conduct any experiments in this area.<br />
<br />
= Commentary =<br />
The reviews on Spherical CNNs are very positive and it is ranked in the top 1% of papers submitted to ICLR 2018. Positive points are the novelty of the architecture, the wide variety of experiments performed, and the writing. One critique of the original submission is that the related works section only lists, instead of describing, previous methods and that a description of the methods would have provided more clarity. The authors have since expanded the section; however, I found that it is still limited, which the authors attribute to length limitations. Another critique is that the evaluation does not provide enough depth. For example, it would have been great to see an example of omnidirectional vision for spherical networks. However, this is to be expected as it is just the introduction of spherical CNNs and more work is sure to come.<br />
<br />
= Source Code =<br />
Source code is available at:<br />
https://github.com/jonas-koehler/s2cnn<br />
<br />
= Sources =<br />
* T. Cohen et al. Spherical CNNs, 2018.<br />
* J. Feldman. Haar Measure. http://www.math.ubc.ca/~feldman/m606/haar.pdf<br />
* P. Kostelec, D. Rockmore. FFTs on the Rotation Group, 2008.</div>X249wanghttp://wiki.math.uwaterloo.ca/statwiki/index.php?title=Spherical_CNNs&diff=35120Spherical CNNs2018-03-22T03:48:48Z<p>X249wang: /* Implementation with GFFT */</p>
<hr />
<div>= Introduction =<br />
Convolutional Neural Networks (CNNs), or network architectures involving CNNs, are the current state of the art for learning 2D image processing tasks such as semantic segmentation and object detection. CNNs work well in large part due to the property of being translationally equivariant. This property allows a network trained to detect a certain type of object to still detect the object even if it is translated to another position in the image. However, this does not correspond well to spherical signals since projecting a spherical signal onto a plane will result in distortions, as demonstrated in Figure 1. There are many different types of spherical projections onto a 2D plane, as most people know from the various types of world maps, none of which provide all the necessary properties for rotation-invariant learning.<br />
<br />
[[File:paper26-fig1.png|center]]<br />
<br />
= Notation =<br />
Below are listed several important terms:<br />
* '''Unit Sphere''' <math>S^2</math> is defined as a sphere where all of its points are at a distance of 1 from the origin. The unit sphere can be parameterized by the spherical coordinates <math>\alpha \in [0, 2\pi]</math> and <math>\beta \in [0, \pi]</math>. This is a two-dimensional manifold with respect to <math>\alpha</math> and <math>\beta</math>.<br />
* '''<math>S^2</math> Sphere''' The two-dimensional surface of a 3D sphere.<br />
* '''Spherical Signals''' In the paper spherical images and filters are modeled as continuous functions <math>f : S^2 \to \mathbb{R}^K</math>, where K is the number of channels. Just as RGB images have 3 channels, a spherical signal can have numerous channels describing the data. Examples of channels which were used can be found in the experiments section.<br />
* '''Rotations - SO(3)''' The group of 3D rotations on an <math>S^2</math> sphere. Sometimes called the "special orthogonal group". In this paper the ZYZ-Euler parameterization is used to represent SO(3) rotations with <math>\alpha, \beta</math>, and <math>\gamma</math>. Any rotation can be broken down into first a rotation (<math>\alpha</math>) about the Z-axis, then a rotation (<math>\beta</math>) about the new Y-axis (Y'), followed by a rotation (<math>\gamma</math>) about the new Z axis (Z"). [In the rest of this paper, to integrate functions on SO(3), the authors use a rotationally invariant probability measure on the Borel subsets of SO(3). This measure is an example of a Haar measure. Haar measures generalize the idea of rotationally invariant probability measures to general topological groups. For more on Haar measures, see (Feldman 2002) ]<br />
<br />
= Related Work =<br />
The related work presented in this paper is very brief, in large part due to the novelty of spherical CNNs and the length of the rest of the paper. The authors enumerate numerous papers which attempt to exploit larger groups of symmetries such as the translational symmetries of CNNs but do not go into specific details for any of these attempts. They do state that all the previous works are limited to discrete groups with the exception of SO(2)-steerable networks.<br />
The authors also mention that previous works exist that analyze spherical images but that these do not have an equivariant architecture. They claim that Spherical CNNs are "the first to achieve equivariance to a continuous, non-commutative group (SO(3))". They also claim to be the first to use the generalized Fourier transform for speed effective performance of group correlation.<br />
<br />
= Correlations on the Sphere and Rotation Group =<br />
Spherical correlation is like planar correlation except instead of translation, there is rotation. The definitions for each are provided as follows:<br />
<br />
'''Planar correlation''' The value of the output feature map at translation <math>\small x ∈ Z^2</math> is computed as an inner product between the input feature map and a filter, shifted by <math>\small x</math>.<br />
<br />
'''Spherical correlation''' The value of the output feature map evaluated at rotation <math>\small R ∈ SO(3)</math> is computed as an inner product between the input feature map and a filter, rotated by <math>\small R</math>.<br />
<br />
'''Rotation of Spherical Signals''' The paper introduces the rotation operator <math>L_R</math>. The rotation operator rotates a function (which allows us to rotate the spherical filters) by evaluating it at <math>R^{-1}x</math>, i.e. <math>[L_R f](x) = f(R^{-1}x)</math>. With this definition we have the property that <math>L_{RR'} = L_R L_{R'}</math>.<br />
<br />
'''Inner Products''' The inner product of spherical signals is simply the integral summation on the vector space over the entire sphere.<br />
<br />
<math>\langle\psi , f \rangle = \int_{S^2} \sum_{k=1}^K \psi_k (x)f_k (x)dx</math><br />
<br />
<math>dx</math> here is the SO(3) rotation-invariant measure on the sphere and is equivalent to <math>d \alpha \sin(\beta) d \beta / 4 \pi </math> in spherical coordinates. This comes from the ZYZ-Euler parameterization where any rotation can be broken down into first a rotation about the Z-axis, then a rotation about the new Y-axis (Y'), followed by a rotation about the new Z axis (Z"). More details on this are given in Appendix A in the paper.<br />
<br />
By this definition, the invariance of the inner product is then guaranteed for any rotation <math>R ∈ SO(3)</math>. In other words, when subjected to rotations, the volume under a spherical heightmap does not change. The following equations show that <math>L_R</math> has a distinct adjoint (<math>L_{R^{-1}}</math>) and that <math>L_R</math> is unitary and thus preserves orthogonality and distances.<br />
<br />
<math>\langle L_R \psi \,, f \rangle = \int_{S^2} \sum_{k=1}^K \psi_k (R^{-1} x)f_k (x)dx</math><br />
<br />
::::<math>= \int_{S^2} \sum_{k=1}^K \psi_k (x)f_k (Rx)dx</math><br />
<br />
::::<math>= \langle \psi , L_{R^{-1}} f \rangle</math><br />
<br />
'''Spherical Correlation''' With the above knowledge the definition of spherical correlation of two signals <math>f</math> and <math>\psi</math> is:<br />
<br />
<math>[\psi \star f](R) = \langle L_R \psi \,, f \rangle = \int_{S^2} \sum_{k=1}^K \psi_k (R^{-1} x)f_k (x)dx</math><br />
<br />
The output of the above equation is a function on SO(3): each rotation, i.e. each combination of <math>\alpha , \beta , \gamma </math>, gives a different value of the correlation. The authors make a point of noting that previous work by Driscoll and Healy only ensures circular symmetries about the Z axis, whereas their new formulation ensures symmetry about any rotation.<br />
<br />
'''Rotation of SO(3) Signals''' The first layer of a Spherical CNN takes a function on the sphere (<math>S^2</math>) and outputs a function on SO(3). Therefore, if a Spherical CNN with more than one layer is going to be built, there needs to be a way to find the correlation between two signals on SO(3). The authors therefore generalize the rotation operator (<math>L_R</math>) to act on signals on SO(3). This new definition of <math>L_R</math> is as follows (where <math>R^{-1}Q</math> is a composition of rotations, i.e. multiplication of rotation matrices):<br />
<br />
<math>[L_Rf](Q)=f(R^{-1} Q)</math><br />
<br />
'''Rotation Group Correlation''' The correlation of two signals (<math>f,\psi</math>) on SO(3) with K channels is defined as the following:<br />
<br />
<math>[\psi \star f](R) = \langle L_R \psi , f \rangle = \int_{SO(3)} \sum_{k=1}^K \psi_k (R^{-1} Q)f_k (Q)dQ</math><br />
<br />
where <math>dQ</math> is the invariant measure on SO(3), which in ZYZ-Euler angles is <math>d \alpha \sin(\beta) d \beta d \gamma / 8 \pi^2 </math>. A complete derivation of this can be found in Appendix A.<br />
<br />
The equivariance for the rotation group correlation is similarly demonstrated.<br />
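<br />
As a quick sanity check on the measure <math>dQ</math>, the sketch below integrates functions of the ZYZ-Euler angles against <math>d \alpha \sin(\beta) d \beta d \gamma / 8 \pi^2</math> with a midpoint rule. The function name and grid sizes are assumptions made for this example; it is only meant to show that the measure is normalized to total mass 1, in line with its description as a probability measure.<br />

```python
import math

def so3_integrate(func, n_alpha=8, n_beta=64, n_gamma=8):
    """Midpoint-rule estimate of int_{SO(3)} func(alpha, beta, gamma) dQ,
    where dQ = d_alpha sin(beta) d_beta d_gamma / (8 pi^2)."""
    d_alpha = 2.0 * math.pi / n_alpha
    d_beta = math.pi / n_beta
    d_gamma = 2.0 * math.pi / n_gamma
    total = 0.0
    for i in range(n_alpha):
        alpha = (i + 0.5) * d_alpha
        for j in range(n_beta):
            beta = (j + 0.5) * d_beta
            weight = math.sin(beta)  # density of the measure in beta
            for k in range(n_gamma):
                gamma = (k + 0.5) * d_gamma
                total += func(alpha, beta, gamma) * weight
    return total * d_alpha * d_beta * d_gamma / (8.0 * math.pi ** 2)

# Total mass of the measure, and the mean of cos(beta)^2 under it.
print(so3_integrate(lambda a, b, g: 1.0))
print(so3_integrate(lambda a, b, g: math.cos(b) ** 2))
```

The first value should be close to 1 and the second close to 1/3, up to the error of the coarse quadrature grid.<br />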
<br />
= Implementation with GFFT =<br />
The authors leverage the Generalized Fourier Transform (GFT) and Generalized Fast Fourier Transform (GFFT) algorithms to compute the correlations outlined in the previous section. The Fast Fourier Transform (FFT) can compute correlations and convolutions efficiently by means of the Fourier theorem, which states that the Fourier transform of a convolution is the product of the Fourier transforms of the two signals (for a correlation, one of the factors is conjugated). The underlying idea of Fourier analysis is that a well-behaved periodic function can be expressed as a weighted sum of sine and cosine terms; the weights are called Fourier coefficients. The Fourier transform can be generalized to <math>S^2</math> and SO(3) and is then called the GFT. The GFT is a linear projection of a function onto orthogonal basis functions. The basis functions are a set of irreducible unitary representations for a group (such as for <math>S^2</math> or SO(3)). For <math>S^2</math> the basis functions are the spherical harmonics <math>Y_m^l(x)</math>. For SO(3) these basis functions are called the Wigner D-functions <math>D_{mn}^l(R)</math>. For both sets of functions the indices are restricted to <math>l\geq0</math> and <math>-l \leq m,n \leq l</math>. The Wigner D-functions are also orthogonal, so the Fourier coefficients can be computed by the inner product with the Wigner D-functions (see Appendix C for the complete proof). The Wigner D-functions are complete, which means that any well-behaved function on SO(3) can be expressed as a linear combination of the Wigner D-functions. The GFT of a function on SO(3) is thus:<br />
<br />
<math>\hat{f^l} = \int_X f(x) D^l(x)dx</math><br />
<br />
where <math>\hat{f}</math> represents the Fourier coefficients. For <math>S^2</math> we have the same equation but with the basis functions <math>Y^l</math>.<br />
<br />
The inverse SO(3) Fourier transform is:<br />
<br />
<math>f(R)=[\mathcal{F}^{-1} \hat{f}](R) = \sum_{l=0}^b (2l + 1) \sum_{m=-l}^l \sum_{n=-l}^l \hat{f_{mn}^l} D_{mn}^l(R) </math><br />
<br />
The bandwidth b represents the maximum frequency and is related to the resolution of the spatial grid. Kostelec and Rockmore are referenced for more knowledge on this topic.<br />
<br />
The authors give proofs (Appendix D) that the SO(3) correlation satisfies the Fourier theorem and the <math>S^2</math> correlation of spherical signals can be computed by the outer products of the <math>S^2</math>-FTs (Shown in Figure 2).<br />
<br />
[[File:paper26-fig2.png|center]]<br />
<br />
The GFFT algorithm details are taken from Kostelec and Rockmore. The authors claim they have the first automatically differentiable implementation of the GFT for <math>S^2</math> and SO(3). The authors do not provide any run time comparisons for real time applications (they only mention that the FFT can be computed in <math>O(n \log n)</math> time) or any comparisons of training times with and without the GFFT. However, they do provide the source code of their implementation at: https://github.com/jonas-koehler/s2cnn.<br />
<br />
= Experiments =<br />
The authors provide several experiments. The first set of experiments are designed to show the numerical stability and accuracy of the outlined methods. The second group of experiments demonstrates how the algorithms can be applied to current problem domains.<br />
<br />
==Equivariance Error==<br />
In this experiment the authors try to show experimentally that their theory of equivariance holds. They express that they had doubts about the equivariance in practice due to potential discretization artifacts, since equivariance was proven only for the continuous case. The experiment is set up by first testing the equivariance of the SO(3) correlation at different resolutions. 500 random rotations and feature maps (with 10 channels) are sampled. They then calculate <math>\small\Delta = \dfrac{1}{n} \sum_{i=1}^n std(L_{R_i} \Phi(f_i) - \Phi(L_{R_i} f_i))/std(\Phi(f_i))</math><br />
Note: The authors do not say what the std function is; it is most likely the standard deviation, as 'std' is the command for standard deviation in MATLAB.<br />
<math>\Phi</math> is a composition of SO(3) correlation layers with filters which have been randomly initialized. The authors mention that they were expecting <math>\Delta</math> to be zero in the case of perfect equivariance. This is because, as proven earlier, the following two terms are equal in the continuous case: <math>\small L_{R_i} \Phi(f_i) = \Phi(L_{R_i} f_i)</math>. The results are shown in Figure 3. <br />
<br />
[[File:paper26-fig3.png|center]]<br />
<br />
<math>\Delta</math> only grows with resolution/layers when there is no activation function. With ReLU activation the error stays constant once the resolution is slightly higher than 0. The authors indicate that the error must therefore come from the feature map rotation rather than from the network layers, since this rotation is exact only for bandlimited functions.<br />
<br />
==MNIST Data==<br />
The experiment using MNIST data was created by projecting MNIST digits onto a sphere using stereographic projection to create the resulting images as seen in Figure 4.<br />
<br />
[[File:paper26-fig4.png|center]]<br />
<br />
The authors created two datasets, one with the projected digits and the other with the same projected digits which were then subjected to a random rotation. The spherical CNN architecture used was <math>\small S^2</math>conv-ReLU-SO(3)conv-ReLU-FC-softmax and was attempted with bandwidths of 30,10,6 and 20,40,10 channels for each layer respectively. This model was compared to a baseline CNN with layers conv-ReLU-conv-ReLU-FC-softmax with 5x5 filters, 32,64,10 channels and stride of 3. For comparison this leads to approximately 68K parameters for the baseline and 58K parameters for the spherical CNN. Results can be seen in Table 1. It is clear from the results that the spherical CNN architecture made the network rotationally invariant. Performance on the rotated set is almost identical to the non-rotated set. This is true even when trained on the non-rotated set and tested on the rotated set. Compare this to the non-spherical architecture which becomes unusable when rotating the digits.<br />
<br />
[[File:paper26-tab1.png|center]]<br />
<br />
==SHREC17==<br />
The SHREC dataset contains 3D models from the ShapeNet dataset which are classified into categories. It consists of a regularly aligned dataset and a rotated dataset. The models from the SHREC17 dataset were projected onto a sphere by means of raycasting. Different properties of the objects obtained from the raycast of the original model and the convex hull of the model make up the different channels which are input into the spherical CNN.<br />
<br />
<br />
[[File:paper26-fig5.png|center]]<br />
<br />
<br />
The network architecture used is an initial <math>\small S^2</math>conv-BN-ReLU block which is followed by two SO(3)conv-BN-ReLU blocks. The output is then fed into a MaxPool-BN block then a linear layer to the output for final classification. The architecture for this experiment has ~1.4M parameters, far exceeding the scale of the spherical CNNs in the other experiments.<br />
<br />
This architecture achieves results close to the state of the art on the SHREC17 tasks: the model places 2nd or 3rd in all categories, but was not submitted as the SHREC17 task is closed. Table 2 shows the comparison of results with the top 3 submissions in each category. The authors claim the results show empirical proof of the usefulness of spherical CNNs. They stress that most architectures in the SHREC17 competition are highly specialized whereas their model is fairly general.<br />
<br />
<br />
[[File:paper26-tab2.png|center]]<br />
<br />
==Molecular Atomization==<br />
In this experiment a spherical CNN is implemented with an architecture resembling that of ResNet. They use the QM7 dataset which has the task of predicting atomization energy of molecules. The positions and charges given in the dataset are projected onto the sphere using potential functions. A summary of their results is shown in Table 3 along with some of the spherical CNN architecture details. It shows the different RMSE obtained from different methods. The results from this final experiment also seem to be promising as the network the authors present achieves the second best score. They also note that the first place method grows exponentially with the number of atoms per molecule so is unlikely to scale well.<br />
<br />
[[File:paper26-tab3.png|center]]<br />
<br />
= Conclusions =<br />
This paper presents a novel architecture called Spherical CNNs. The paper defines <math>\small S^2</math> and SO(3) cross correlations, shows the theory behind their rotational invariance for continuous functions, and demonstrates that the invariance also applies to the discrete case. An effective GFFT algorithm was implemented and evaluated on two very different datasets with close to state of the art results, demonstrating that there are practical applications to Spherical CNNs.<br />
<br />
For future work the authors believe that improvements can be obtained by generalizing the algorithms to the SE(3) group (SE(3) simply adds translations in 3D space to the SO(3) group). The authors also briefly mention their excitement for applying Spherical CNNs to omnidirectional vision such as in drones and autonomous cars. They state that there is very little publicly available omnidirectional image data which could be why they did not conduct any experiments in this area.<br />
<br />
= Commentary =<br />
The reviews of Spherical CNNs are very positive and the paper is ranked in the top 1% of papers submitted to ICLR 2018. Positive points are the novelty of the architecture, the wide variety of experiments performed, and the writing. One critique of the original submission is that the related works section only lists previous methods instead of describing them, and that a description of the methods would have provided more clarity. The authors have since expanded the section; however, I found that it is still limited, which the authors attribute to length limitations. Another critique is that the evaluation does not provide enough depth. For example, it would have been great to see an example of omnidirectional vision for spherical networks. However, this is to be expected as the paper is only the introduction of spherical CNNs and more work is sure to come.<br />
<br />
= Source Code =<br />
Source code is available at:<br />
https://github.com/jonas-koehler/s2cnn<br />
<br />
= Sources =<br />
* T. Cohen et al. Spherical CNNs, 2018.<br />
* J. Feldman. Haar Measure. http://www.math.ubc.ca/~feldman/m606/haar.pdf<br />
* P. Kostelec, D. Rockmore. FFTs on the Rotation Group, 2008.</div>X249wanghttp://wiki.math.uwaterloo.ca/statwiki/index.php?title=Do_Deep_Neural_Networks_Suffer_from_Crowding&diff=35084Do Deep Neural Networks Suffer from Crowding2018-03-22T02:39:29Z<p>X249wang: /* Experiments and its Set-Up */</p>
<hr />
<div>= Introduction =<br />
Ever since the advent of deep networks, a tremendous amount of research and effort has gone into making machines capable of recognizing objects the same way humans do. Humans can recognize objects in a way that is invariant to scale, translation, and clutter. Crowding is another visual effect experienced by humans, in which an object that can be recognized in isolation can no longer be recognized when other objects, called flankers, are placed close to it. This is a very common real-life experience. This paper studies the impact of crowding on Deep Neural Networks (DNNs) by adding clutter to the images and then analyzing which models and settings suffer less from such effects. <br />
<br />
[[File:paper25_fig_crowding_ex.png|center|600px]]<br />
The figure shows a visual example of crowding [3]. Keep your eyes still and look at the dot in the center and try to identify the "A" in the two circles. You should see that it is much easier to make out the "A" in the right than in the left circle. The same "A" exists in both circles, however the left circle contains flankers which are those line segments.<br />
<br />
The paper investigates two types of DNNs for crowding: traditional deep convolutional neural networks (DCNNs) and a multi-scale eccentricity-dependent model. The latter is an extension of the DCNN inspired by the retina, in which the receptive field size of the convolutional filters grows with increasing distance from the center of the image (the eccentricity), as explained below. The authors focus on the dependence of crowding on image factors such as flanker configuration, target-flanker similarity, and target eccentricity, and in particular on premature pooling.<br />
<br />
= Models =<br />
== Deep Convolutional Neural Networks ==<br />
The DCNN is a basic architecture with 3 convolutional layers, spatial 3x3 max-pooling with varying strides and a fully connected layer for classification as shown in the below figure. <br />
[[File:DCNN.png|800px|center]]<br />
<br />
The network is fed with images resized to 60x60, with minibatches of 128 images, 32 feature channels for all convolutional layers, and convolutional filters of size 5x5 and stride 1.<br />
<br />
As highlighted earlier, the effect of pooling is the main consideration, and hence three different configurations have been investigated: <br />
<br />
1. '''No total pooling''' Feature map sizes decrease only due to boundary effects, as the 3x3 max pooling has stride 1. The side lengths of the square feature maps after each pooling layer are 60-54-48-42.<br />
<br />
2. '''Progressive pooling''' 3x3 pooling with a stride of 2 halves the square size of the feature maps, until we pool over what remains in the final layer, getting rid of any spatial information before the fully connected layer. (60-27-11-1).<br />
<br />
3. '''At end pooling''' Same as no total pooling, but before the fully connected layer, max-pool over the entire feature map. (60-54-48-1).<br />
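<br />
The feature map sizes quoted in the three configurations above can be reproduced with simple arithmetic. The following is a minimal sketch, assuming unpadded ("valid") 5x5 convolutions with stride 1 followed by unpadded 3x3 max-pooling; the padding choice is not stated in this summary and is inferred from the listed sizes, and the function names are hypothetical.<br />

```python
def valid_output_size(size, kernel, stride=1):
    """Side length after an unpadded ('valid') convolution or pooling step."""
    return (size - kernel) // stride + 1


def feature_map_sizes(input_size, pool_strides, global_pool_at_end):
    """Return the side length of the input and of the map after each block.

    Each block is assumed to be a 5x5 convolution (stride 1) followed by
    3x3 max-pooling with the given stride. If global_pool_at_end is True,
    the last pooling step covers the whole remaining map (side length 1).
    """
    sizes = [input_size]
    size = input_size
    for block, stride in enumerate(pool_strides):
        size = valid_output_size(size, kernel=5)  # convolution shrinks by 4
        is_last = block == len(pool_strides) - 1
        if is_last and global_pool_at_end:
            size = 1  # max-pool over the entire remaining feature map
        else:
            size = valid_output_size(size, kernel=3, stride=stride)
        sizes.append(size)
    return sizes


no_total_pooling = feature_map_sizes(60, [1, 1, 1], global_pool_at_end=False)
progressive_pooling = feature_map_sizes(60, [2, 2, 2], global_pool_at_end=True)
at_end_pooling = feature_map_sizes(60, [1, 1, 1], global_pool_at_end=True)
print(no_total_pooling, progressive_pooling, at_end_pooling)
```

Under these assumptions the three lists come out as 60-54-48-42, 60-27-11-1 and 60-54-48-1, matching the sizes given for the three pooling configurations.<br />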
<br />
===What is the problem in CNNs?===<br />
CNNs fall short in explaining human perceptual invariance. First, CNNs typically take input at a single uniform resolution. Biological measurements suggest that resolution is not uniform across the human visual field, but rather decays with eccentricity, i.e. distance from the center of focus. Even more importantly, CNNs rely on data augmentation to achieve transformation invariance, which requires a lot of additional processing.<br />
<br />
==Eccentricity-dependent Model==<br />
In order to take care of the scale invariance in the input image, the eccentricity dependent DNN is utilized. The main intuition behind this architecture is that as we increase eccentricity, the receptive fields also increase and hence the model will become invariant to changing input scales. It emphasizes scale invariance over translation invariance, in contrast to traditional DCNNs. In this model the input image is cropped into varying scales (11 crops increasing by a factor of <math>\sqrt{2}</math> which are then resized to 60x60 pixels) and then fed to the network. The model computes an invariant representation of the input by sampling the inverted pyramid at a discrete set of scales with the same number of filters at each scale. Since the same number of filters are used for each scale, the smaller crops will be sampled at a high resolution while the larger crops will be sampled with a low resolution. These scales are fed into the network as an input channel to the convolutional layers and share the weights across scale and space.<br />
[[File:EDM.png|2000x450px|center]]<br />
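<br />
The crop pyramid described above can be illustrated with a short calculation. This is only a sketch: the side length of the smallest crop (60 pixels, so that it is fed to the network without resampling) is an assumption made for illustration and is not stated in this summary, and the function names are hypothetical.<br />

```python
import math


def crop_side_lengths(smallest=60, n_scales=11, factor=math.sqrt(2)):
    """Side lengths (in pixels) of the nested crops, smallest first.

    smallest=60 is an assumed value; only the number of scales (11) and
    the growth factor (sqrt(2)) come from the text.
    """
    return [round(smallest * factor ** i) for i in range(n_scales)]


def downsampling_factors(side_lengths, network_input=60):
    """How much each crop is shrunk when resized to the network input size."""
    return [side / network_input for side in side_lengths]


sides = crop_side_lengths()
for side, shrink in zip(sides, downsampling_factors(sides)):
    print(f"crop {side:4d} px -> 60 px (downsampled by {shrink:.2f}x)")
```

Because every crop is resized to the same 60x60 input, the smallest crop keeps its full resolution while each larger crop is sampled more coarsely, which is the resolution fall-off with eccentricity that the model is meant to mimic.<br />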
<br />
The architecture of this model is the same as the previous DCNN model, with the only change being the extra filters added for each of the scales, so the number of parameters remains the same as in the DCNN models. The authors perform two kinds of pooling: spatial pooling (the aforementioned ''At end pooling'' is used here) and scale pooling, which reduces the number of scales by taking the maximum value of corresponding locations in the feature maps across multiple scales. Scale pooling has three configurations: (1) at the beginning, in which all the different scales are pooled together after the first layer, 11-1-1-1-1; (2) progressively, 11-7-5-3-1; and (3) at the end, 11-11-11-11-1, in which all 11 scales are pooled together at the last layer.<br />
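<br />
Scale pooling as described here is an element-wise maximum over feature maps from different scales. The sketch below shows that operation on plain nested lists; it pools all given scales into one map, and does not attempt to reproduce how scales are grouped in the progressive 11-7-5-3-1 configuration, which this summary does not specify. The function name is hypothetical.<br />

```python
def scale_max_pool(maps_per_scale):
    """Element-wise maximum over equally sized 2-D feature maps.

    maps_per_scale is a list with one feature map per scale; each map is
    a list of rows. The result has the same height and width as one map.
    """
    return [
        [max(values) for values in zip(*rows)]  # values at one location, all scales
        for rows in zip(*maps_per_scale)        # the same row taken from every scale
    ]


# Two toy 2x2 feature maps standing in for two scales.
fine_scale = [[1, 5], [0, 2]]
coarse_scale = [[3, 4], [7, 1]]
pooled = scale_max_pool([fine_scale, coarse_scale])
print(pooled)
```

Each output location keeps the strongest response found at that location across the scales, so the pooled representation no longer records which scale produced it.<br />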
<br />
===Contrast Normalization===<br />
Since we have multiple scales of input image, in some experiments, we perform normalization such that the sum of the pixel intensities in each scale is in the same range [0,1] (this is to prevent smaller crops, which have more non-black pixels, from disproportionately dominating max-pooling across scales). The normalized pixel intensities are then divided by a factor proportional to the crop area [[File:sqrtf.png|60px]] where i=1 is the smallest crop.<br />
<br />
=Experiments and its Set-Up =<br />
Targets are the set of objects to be recognized and flankers are the set of objects the model has not been trained to recognize, which act as clutter with respect to these target objects. The target objects are the even MNIST numbers having translational variance (shifted at different locations of the image along the horizontal axis), while flankers are from odd MNIST numbers, notMNIST dataset (contains alphabet letters) and Omniglot dataset (contains characters). Examples of the target and flanker configurations are shown below: <br />
[[File:eximages.png|800px|center]]<br />
<br />
The target and the flanker are referred to as ''a'' and ''x'' respectively, with the following four configurations: (1) No flankers, only the target object (''a'' in the plots). (2) One central flanker closer to the center of the image than the target (''xa''). (3) One peripheral flanker closer to the boundary of the image than the target (''ax''). (4) Two identical flankers spaced equally around the target (''xax'').<br />
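<br />
The four configurations can be made concrete by computing where the flankers sit along the horizontal axis. This is a hypothetical sketch, not the authors' data-generation code: positions are measured in pixels from the image center, and the function and argument names are invented for illustration.<br />

```python
def flanker_positions(config, target_x, spacing):
    """Horizontal flanker positions (pixels from the image center, x = 0).

    config is one of 'a', 'xa', 'ax', 'xax'. The 'central' flanker is placed
    on the side of the target facing the image center and the 'peripheral'
    flanker on the side facing the image boundary. For a target exactly at
    the center (target_x == 0) the two sides are equivalent, and the central
    flanker is arbitrarily placed on the left.
    """
    direction = 1 if target_x >= 0 else -1  # side of the center the target is on
    central = target_x - direction * spacing
    peripheral = target_x + direction * spacing
    layouts = {
        "a": [],
        "xa": [central],
        "ax": [peripheral],
        "xax": [central, peripheral],
    }
    return layouts[config]


print(flanker_positions("xax", target_x=240, spacing=120))
print(flanker_positions("xa", target_x=-240, spacing=120))
```

Note that the central flanker is only actually closer to the center than the target when the spacing is less than twice the target eccentricity; the sketch does not enforce this.<br />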
<br />
==DNNs trained with Target and Flankers==<br />
This is a constant-spacing training setup in which identical flankers are placed at a distance of 120 pixels on either side of the target (''xax''), with the target having translational variance. The models evaluated are (i) a DCNN with at end pooling, and (ii) an eccentricity-dependent model with 11-11-11-11-1 scale pooling, at end spatial pooling, and contrast normalization. The test data has the different flanker configurations described above.<br />
[[File:result1.png|x450px|center]]<br />
<br />
===Observations===<br />
* With the same flanker configuration as in training, models are better at recognizing objects in clutter than isolated objects for all image locations.<br />
<br />
* If the target-flanker spacing is changed, the models perform worse.<br />
<br />
* The eccentricity model is much better than the DCNN at recognizing objects in isolation, because the multi-scale crops divide the image into discrete regions, letting the model learn from image parts as well as the whole image.<br />
<br />
* Only the eccentricity-dependent model is robust to different flanker configurations not included in training, when the target is centered.<br />
<br />
==DNNs trained with Images with the Target in Isolation==<br />
Here the target objects are in isolation and with translational variance while the test-set is the same set of flanker configurations as used before.<br />
[[File:result2.png|750x400px|center]]<br />
===DCNN Observations===<br />
* The recognition gets worse with the increase in the number of flankers.<br />
<br />
* Convolutional networks are capable of being invariant to translations.<br />
<br />
* In the constant target eccentricity setup, where the target is fixed at the center of the image with varying target-flanker spacing, recognition gets better as the distance between target and flankers increases.<br />
<br />
* Spatial pooling helps in learning invariance.<br />
<br />
* Flankers similar to the target object help recognition since they do not activate the convolutional filters more.<br />
<br />
* notMNIST flankers lead to more crowding since they have many more edges and white image pixels, which activate the convolutional layers more.<br />
===Eccentricity-dependent Model===<br />
The set-up is the same as explained earlier.<br />
[[File:result3.png|750x400px|center]]<br />
====Observations====<br />
* If the target is placed at the center and no contrast normalization is done, then the recognition accuracy is high since this model concentrates the most on the central region of the image.<br />
<br />
* If contrast normalization is done, then all the scales contribute equally and hence the eccentricity dependence is removed.<br />
<br />
* Early pooling is harmful since it might discard information that would be useful to later layers of the network.<br />
<br />
==Complex Clutter==<br />
Here, the targets are randomly embedded into images of the Places dataset and shifted horizontally in order to investigate model robustness when the target is not at the image center. Tests are performed on the DCNN and on the eccentricity model with and without contrast normalization, using at end pooling. The results are shown in Figure 9 below. <br />
<br />
[[File:result4.png|750x400px|center]]<br />
<br />
====Observations====<br />
* Only the eccentricity model without contrast normalization can recognize the target, and only when the target is close to the image center.<br />
<br />
* The eccentricity model does not need to be trained on different types of clutter to become robust to those types of clutter, but it needs to fixate on the relevant part of the image to recognize the target.<br />
<br />
=Conclusions=<br />
One might expect that training a network on data similar to the test data would also give good results in a more general scenario, but that is not the case: the models trained with flankers did not give ideal results on the target objects.<br />
*'''Flanker Configuration''': When models are trained with images of objects in isolation, adding flankers harms recognition. Adding two flankers is the same or worse than adding just one, and the smaller the spacing between flanker and target, the more crowding occurs. This is because the pooling operation merges nearby responses, such as those of the target and flankers if they are close.<br />
<br />
*'''Similarity between target and flanker''': Flankers more similar to targets cause more crowding, because of the selectivity property of the learned DNN filters.<br />
<br />
*'''Dependence on target location and contrast normalization''': In DCNNs and eccentricity-dependent models with contrast normalization, recognition accuracy is the same across all eccentricities. In eccentricity-dependent networks without contrast normalization, recognition does not decrease despite the presence of clutter when the target is at the center of the image.<br />
<br />
*'''Effect of pooling''': Adding pooling leads to better recognition accuracy of the models. Yet, in the eccentricity model, pooling across the scales too early in the hierarchy leads to lower accuracy.<br />
<br />
=Critique=<br />
This paper only examines how flankers affect target recognition, i.e. how crowding affects recognition, but it does not propose any novel architecture to handle this type of crowding. The eccentricity-based model does well only when the target is placed at the center of the image; windowing over the frames, instead of taking crops starting from the middle, might help.<br />
<br />
=References=<br />
1) Volokitin A, Roig G, Poggio T:"Do Deep Neural Networks Suffer from Crowding?" Conference on Neural Information Processing Systems (NIPS). 2017<br />
<br />
2) Francis X. Chen, Gemma Roig, Leyla Isik, Xavier Boix and Tomaso Poggio: "Eccentricity Dependent Deep Neural Networks for Modeling Human Vision" Journal of Vision. 17. 808. 10.1167/17.10.808.<br />
<br />
3) J Harrison, W & W Remington, R & Mattingley, Jason. (2014). Visual crowding is anisotropic along the horizontal meridian during smooth pursuit. Journal of vision. 14. 10.1167/14.1.21. http://willjharrison.com/2014/01/new-paper-visual-crowding-is-anisotropic-along-the-horizontal-meridian-during-smooth-pursuit/</div>X249wanghttp://wiki.math.uwaterloo.ca/statwiki/index.php?title=Do_Deep_Neural_Networks_Suffer_from_Crowding&diff=35082Do Deep Neural Networks Suffer from Crowding2018-03-22T02:29:55Z<p>X249wang: /* Eccentricity-dependent Model */</p>
<hr />
<div>= Introduction =<br />
Ever since the advent of deep networks, a tremendous amount of research and effort has gone into making machines capable of recognizing objects the same way humans do. Humans can recognize objects in a way that is invariant to scale, translation, and clutter. Crowding is another visual effect experienced by humans, in which an object that can be recognized in isolation can no longer be recognized when other objects, called flankers, are placed close to it. This is a very common real-life experience. This paper studies the impact of crowding on Deep Neural Networks (DNNs) by adding clutter to the images and then analyzing which models and settings suffer less from such effects. <br />
<br />
[[File:paper25_fig_crowding_ex.png|center|600px]]<br />
The figure shows a visual example of crowding [3]. Keep your eyes still and look at the dot in the center and try to identify the "A" in the two circles. You should see that it is much easier to make out the "A" in the right than in the left circle. The same "A" exists in both circles, however the left circle contains flankers which are those line segments.<br />
<br />
The paper investigates two types of DNNs for crowding: traditional deep convolutional neural networks (DCNNs) and a multi-scale eccentricity-dependent model. The latter is an extension of the DCNN inspired by the retina, in which the receptive field size of the convolutional filters grows with increasing distance from the center of the image (the eccentricity), as explained below. The authors focus on the dependence of crowding on image factors such as flanker configuration, target-flanker similarity, and target eccentricity, and in particular on premature pooling.<br />
<br />
= Models =<br />
== Deep Convolutional Neural Networks ==<br />
The DCNN is a basic architecture with 3 convolutional layers, spatial 3x3 max-pooling with varying strides and a fully connected layer for classification as shown in the below figure. <br />
[[File:DCNN.png|800px|center]]<br />
<br />
The network is fed with images resized to 60x60, with minibatches of 128 images, 32 feature channels for all convolutional layers, and convolutional filters of size 5x5 and stride 1.<br />
<br />
As highlighted earlier, the effect of pooling is the main consideration, and hence three different configurations have been investigated: <br />
<br />
1. '''No total pooling''' Feature map sizes decrease only due to boundary effects, as the 3x3 max pooling has stride 1. The side lengths of the square feature maps after each pooling layer are 60-54-48-42.<br />
<br />
2. '''Progressive pooling''' 3x3 pooling with a stride of 2 halves the square size of the feature maps, until we pool over what remains in the final layer, getting rid of any spatial information before the fully connected layer. (60-27-11-1).<br />
<br />
3. '''At end pooling''' Same as no total pooling, but before the fully connected layer, max-pool over the entire feature map. (60-54-48-1).<br />
<br />
===What is the problem in CNNs?===<br />
CNNs fall short in explaining human perceptual invariance. First, CNNs typically take input at a single uniform resolution. Biological measurements suggest that resolution is not uniform across the human visual field, but rather decays with eccentricity, i.e. distance from the center of focus. Even more importantly, CNNs rely on data augmentation to achieve transformation invariance, which requires a lot of additional processing.<br />
<br />
==Eccentricity-dependent Model==<br />
In order to take care of the scale invariance in the input image, the eccentricity dependent DNN is utilized. The main intuition behind this architecture is that as we increase eccentricity, the receptive fields also increase and hence the model will become invariant to changing input scales. It emphasizes scale invariance over translation invariance, in contrast to traditional DCNNs. In this model the input image is cropped into varying scales (11 crops increasing by a factor of <math>\sqrt{2}</math> which are then resized to 60x60 pixels) and then fed to the network. The model computes an invariant representation of the input by sampling the inverted pyramid at a discrete set of scales with the same number of filters at each scale. Since the same number of filters are used for each scale, the smaller crops will be sampled at a high resolution while the larger crops will be sampled with a low resolution. These scales are fed into the network as an input channel to the convolutional layers and share the weights across scale and space.<br />
[[File:EDM.png|2000x450px|center]]<br />
<br />
The architecture of this model is the same as the previous DCNN model, with the only change being the extra filters added for each of the scales, so the number of parameters remains the same as in the DCNN models. The authors perform two kinds of pooling: spatial pooling (the aforementioned ''At end pooling'' is used here) and scale pooling, which reduces the number of scales by taking the maximum value of corresponding locations in the feature maps across multiple scales. Scale pooling has three configurations: (1) at the beginning, in which all the different scales are pooled together after the first layer, 11-1-1-1-1; (2) progressively, 11-7-5-3-1; and (3) at the end, 11-11-11-11-1, in which all 11 scales are pooled together at the last layer.<br />
<br />
===Contrast Normalization===<br />
Since we have multiple scales of input image, in some experiments, we perform normalization such that the sum of the pixel intensities in each scale is in the same range [0,1] (this is to prevent smaller crops, which have more non-black pixels, from disproportionately dominating max-pooling across scales). The normalized pixel intensities are then divided by a factor proportional to the crop area [[File:sqrtf.png|60px]] where i=1 is the smallest crop.<br />
<br />
=Experiments and its Set-Up =<br />
Targets are the set of objects to be recognized and flankers act as clutter with respect to these target objects. The target objects are the even MNIST numbers having translational variance (shifted at different locations of the image along the horizontal axis). Examples of the target and flanker configurations are shown below: <br />
[[File:eximages.png|800px|center]]<br />
<br />
The target and the flanker are referred to as ''a'' and ''x'' respectively, with the following four configurations: (1) No flankers, only the target object (''a'' in the plots). (2) One central flanker closer to the center of the image than the target (''xa''). (3) One peripheral flanker closer to the boundary of the image than the target (''ax''). (4) Two identical flankers spaced equally around the target (''xax'').<br />
<br />
==DNNs trained with Target and Flankers==<br />
This is a constant-spacing training setup in which identical flankers are placed at a distance of 120 pixels on either side of the target (''xax''), with the target having translational variance. The models evaluated are (i) a DCNN with at end pooling, and (ii) an eccentricity-dependent model with 11-11-11-11-1 scale pooling, at end spatial pooling, and contrast normalization. The test data has the different flanker configurations described above.<br />
[[File:result1.png|x450px|center]]<br />
<br />
===Observations===<br />
* With the same flanker configuration as in training, models are better at recognizing objects in clutter than isolated objects for all image locations.<br />
<br />
* If the target-flanker spacing is changed, the models perform worse.<br />
<br />
* The eccentricity model is much better than the DCNN at recognizing objects in isolation, because the multi-scale crops divide the image into discrete regions, letting the model learn from image parts as well as the whole image.<br />
<br />
* Only the eccentricity-dependent model is robust to different flanker configurations not included in training, when the target is centered.<br />
<br />
==DNNs trained with Images with the Target in Isolation==<br />
Here the target objects are in isolation and with translational variance while the test-set is the same set of flanker configurations as used before.<br />
[[File:result2.png|750x400px|center]]<br />
===DCNN Observations===<br />
* The recognition gets worse with the increase in the number of flankers.<br />
<br />
* Convolutional networks are capable of being invariant to translations.<br />
<br />
* In the constant target eccentricity setup, where the target is fixed at the center of the image with varying target-flanker spacing, recognition gets better as the distance between target and flankers increases.<br />
<br />
* Spatial pooling helps in learning invariance.<br />
<br />
* Flankers similar to the target object help recognition since they do not activate the convolutional filters more.<br />
<br />
* notMNIST flankers lead to more crowding since they have many more edges and white image pixels, which activate the convolutional layers more.<br />
===Eccentricity-dependent Model===<br />
The set-up is the same as explained earlier.<br />
[[File:result3.png|750x400px|center]]<br />
====Observations====<br />
* If the target is placed at the center and no contrast normalization is done, then the recognition accuracy is high since this model concentrates the most on the central region of the image.<br />
<br />
* If contrast normalization is done, then all the scales contribute equally and hence the eccentricity dependence is removed.<br />
<br />
* Early pooling is harmful since it might discard information that would be useful to later layers of the network.<br />
<br />
==Complex Clutter==<br />
Here, the targets are randomly embedded into images of the Places dataset and shifted horizontally in order to investigate model robustness when the target is not at the image center. Tests are performed on the DCNN and on the eccentricity model with and without contrast normalization, using at end pooling. The results are shown in Figure 9 below. <br />
<br />
[[File:result4.png|750x400px|center]]<br />
<br />
====Observations====<br />
* Only the eccentricity model without contrast normalization can recognize the target, and only when the target is close to the image center.<br />
<br />
* The eccentricity model does not need to be trained on different types of clutter to become robust to those types of clutter, but it needs to fixate on the relevant part of the image to recognize the target.<br />
<br />
=Conclusions=<br />
One might expect that training a network on data similar to the test data would also give good results in a more general scenario, but that is not the case: the models trained with flankers did not give ideal results on the target objects.<br />
*'''Flanker Configuration''': When models are trained with images of objects in isolation, adding flankers harms recognition. Adding two flankers is the same or worse than adding just one, and the smaller the spacing between flanker and target, the more crowding occurs. This is because the pooling operation merges nearby responses, such as those of the target and flankers if they are close.<br />
<br />
*'''Similarity between target and flanker''': Flankers more similar to targets cause more crowding, because of the selectivity property of the learned DNN filters.<br />
<br />
*'''Dependence on target location and contrast normalization''': In DCNNs and eccentricity-dependent models with contrast normalization, recognition accuracy is the same across all eccentricities. In eccentricity-dependent networks without contrast normalization, recognition does not decrease despite the presence of clutter when the target is at the center of the image.<br />
<br />
*'''Effect of pooling''': Adding pooling leads to better recognition accuracy of the models. Yet, in the eccentricity model, pooling across the scales too early in the hierarchy leads to lower accuracy.<br />
<br />
=Critique=<br />
This paper only examines how flankers affect target recognition, i.e. how crowding affects recognition, but it does not propose any novel architecture to handle this type of crowding. The eccentricity-based model does well only when the target is placed at the center of the image; windowing over the frames, instead of taking crops starting from the middle, might help.<br />
<br />
=References=<br />
1) Volokitin A, Roig G, Poggio T:"Do Deep Neural Networks Suffer from Crowding?" Conference on Neural Information Processing Systems (NIPS). 2017<br />
<br />
2) Francis X. Chen, Gemma Roig, Leyla Isik, Xavier Boix and Tomaso Poggio: "Eccentricity Dependent Deep Neural Networks for Modeling Human Vision" Journal of Vision. 17. 808. 10.1167/17.10.808.<br />
<br />
3) J Harrison, W & W Remington, R & Mattingley, Jason. (2014). Visual crowding is anisotropic along the horizontal meridian during smooth pursuit. Journal of vision. 14. 10.1167/14.1.21. http://willjharrison.com/2014/01/new-paper-visual-crowding-is-anisotropic-along-the-horizontal-meridian-during-smooth-pursuit/</div>X249wanghttp://wiki.math.uwaterloo.ca/statwiki/index.php?title=Do_Deep_Neural_Networks_Suffer_from_Crowding&diff=35075Do Deep Neural Networks Suffer from Crowding2018-03-22T01:48:23Z<p>X249wang: /* Complex Clutter */</p>
<hr />
<div>= Introduction =<br />
Ever since the advent of deep networks, a tremendous amount of research and effort has gone into making machines capable of recognizing objects the same way humans do. Humans can recognize objects in a way that is invariant to scale, translation, and clutter. Crowding is another visual effect experienced by humans, in which an object that can be recognized in isolation can no longer be recognized when other objects, called flankers, are placed close to it. This is a very common real-life experience. This paper studies the impact of crowding on Deep Neural Networks (DNNs) by adding clutter to the images and then analyzing which models and settings suffer less from such effects. <br />
<br />
[[File:paper25_fig_crowding_ex.png|center|600px]]<br />
The figure shows a visual example of crowding [3]. Keep your eyes still and look at the dot in the center and try to identify the "A" in the two circles. You should see that it is much easier to make out the "A" in the right than in the left circle. The same "A" exists in both circles, however the left circle contains flankers which are those line segments.<br />
<br />
The paper investigates two types of DNNs for crowding: traditional deep convolutional neural networks (DCNNs) and a multi-scale eccentricity-dependent model. The latter is an extension of the DCNN inspired by the retina, in which the receptive field size of the convolutional filters grows with increasing distance from the center of the image (the eccentricity), as explained below. The authors focus on the dependence of crowding on image factors such as flanker configuration, target-flanker similarity, and target eccentricity, and in particular on premature pooling.<br />
<br />
= Models =<br />
== Deep Convolutional Neural Networks ==<br />
The DCNN is a basic architecture with 3 convolutional layers, spatial 3x3 max-pooling with varying strides and a fully connected layer for classification as shown in the below figure. <br />
[[File:DCNN.png|800px|center]]<br />
<br />
The network is fed with images resized to 60x60, with minibatches of 128 images, 32 feature channels for all convolutional layers, and convolutional filters of size 5x5 and stride 1.<br />
<br />
As highlighted earlier, the effect of pooling is the main consideration, and hence three different configurations have been investigated: <br />
<br />
1. '''No total pooling''' Feature map sizes decrease only due to boundary effects, as the 3x3 max pooling has stride 1. The side lengths of the square feature maps after each pooling layer are 60-54-48-42.<br />
<br />
2. '''Progressive pooling''' 3x3 pooling with a stride of 2 halves the square size of the feature maps, until we pool over what remains in the final layer, getting rid of any spatial information before the fully connected layer. (60-27-11-1).<br />
<br />
3. '''At end pooling''' Same as no total pooling, but before the fully connected layer, max-pool over the entire feature map. (60-54-48-1).<br />
<br />
===What is the problem in CNNs?===<br />
CNNs fall short in explaining human perceptual invariance. First, CNNs typically take input at a single uniform resolution. Biological measurements suggest that resolution is not uniform across the human visual field, but rather decays with eccentricity, i.e. distance from the center of focus. Even more importantly, CNNs rely on data augmentation to achieve transformation invariance, which requires a lot of additional processing.<br />
<br />
==Eccentricity-dependent Model==<br />
In order to take care of the scale invariance in the input image, the eccentricity dependent DNN is utilized. The main intuition behind this architecture is that as we increase eccentricity, the receptive fields also increase and hence the model will become invariant to changing input scales. In this model the input image is cropped into varying scales(11 crops increasing by a factor of <math>\sqrt{2}</math> which are then resized to 60x60 pixels) and then fed to the network. The model computes an invariant representation of the input by sampling the inverted pyramid at a discrete set of scales with the same number of filters at each scale. Since the same number of filters are used for each scale, the smaller crops will be sampled at a high resolution while the larger crops will be sampled with a low resolution. These scales are fed into the network as an input channel to the convolutional layers and share the weights across scale and space.<br />
[[File:EDM.png|2000x450px|center]]<br />
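<br />
As a rough illustration of the multi-scale sampling, the sketch below lists the crop side lengths and the resulting downsampling factors. It assumes that the factor <math>\sqrt{2}</math> applies to the crop side length and that the smallest crop is 60 pixels wide; neither value is stated in this summary, so both are labeled as assumptions in the code.<br />

```python
import math

NUM_SCALES = 11      # number of crops (from the text)
OUTPUT_SIZE = 60     # every crop is resized to 60x60 pixels (from the text)
SMALLEST_CROP = 60   # assumption: side length of the smallest crop, in pixels

# Assumption: each crop side is sqrt(2) times larger than the previous one.
crop_sides = [SMALLEST_CROP * math.sqrt(2) ** i for i in range(NUM_SCALES)]

# How strongly each crop is shrunk when resized to OUTPUT_SIZE.
# A factor of 1 means full resolution; larger crops lose more detail.
downsample = [side / OUTPUT_SIZE for side in crop_sides]
```

With these assumed values the largest crop is 32 times wider than the smallest one, so it covers far more of the image but at a much coarser resolution.<br />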
<br />
The architecture of this model is the same as the previous DCNN model, the only change being the extra filters added for each of the scales. The authors perform spatial pooling (the aforementioned ''At end pooling'') and scale pooling, which reduces the number of scales by taking the maximum value at corresponding locations in the feature maps across multiple scales. Scale pooling has three configurations: (1) at the beginning, in which all the different scales are pooled together after the first layer (11-1-1-1-1); (2) progressively (11-7-5-3-1); and (3) at the end (11-11-11-11-1), in which all 11 scales are pooled together at the last layer.<br />
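<br />
Scale pooling as described above is an element-wise maximum across feature maps of equal size. A minimal sketch, using nested lists in place of real feature maps and a hypothetical helper name <code>scale_pool</code>:<br />

```python
def scale_pool(feature_maps):
    """Element-wise max over a list of equally sized 2-D feature maps."""
    rows = len(feature_maps[0])
    cols = len(feature_maps[0][0])
    return [
        [max(fm[r][c] for fm in feature_maps) for c in range(cols)]
        for r in range(rows)
    ]


# Three toy 2x2 feature maps, one per scale (made-up values).
scales = [
    [[0.1, 0.9], [0.3, 0.2]],
    [[0.5, 0.4], [0.1, 0.8]],
    [[0.2, 0.2], [0.7, 0.1]],
]
pooled = scale_pool(scales)
```

Pooling all scales into one, as here, corresponds to the last step of each configuration; the progressive configuration would apply the same operation to groups of scales, and how those groups are formed is not specified in this summary.<br />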
<br />
===Contrast Normalization===<br />
Since we have multiple scales of input image, we perform normalization such that the sum of the pixel intensities in each scale is in the same range [0,1]. These are then divided by a factor proportional to the crop area [[File:sqrtf.png|60px]] where i=1 is the smallest crop.<br />
<br />
=Experiments and Set-Up=<br />
Targets are the set of objects to be recognized, and flankers act as clutter with respect to these target objects. The target objects are the even MNIST numbers with translational variance (shifted to different locations of the image along the horizontal axis). Examples of the target and flanker configurations are shown below: <br />
[[File:eximages.png|800px|center]]<br />
<br />
The target and the flanker are referred to as ''a'' and ''x'' respectively, with the following four configurations: (1) No flankers, only the target object (''a'' in the plots). (2) One central flanker closer to the center of the image than the target (''xa''). (3) One peripheral flanker closer to the boundary of the image than the target (''ax''). (4) Two flankers spaced equally around the target, both being the same object (''xax'').<br />
<br />
==DNNs trained with Target and Flankers==<br />
This is a constant-spacing training setup where identical flankers are placed at a distance of 120 pixels on either side of the target (xax), with the target having translational variance. The tests are evaluated on (i) the DCNN with at end pooling, and (ii) the eccentricity-dependent model with 11-11-11-11-1 scale pooling, at end spatial pooling and contrast normalization. The test data has the different flanker configurations described above.<br />
[[File:result1.png|x450px|center]]<br />
<br />
===Observations===<br />
* With the same flanker configuration as in training, models are better at recognizing objects in clutter than isolated objects, for all image locations.<br />
<br />
* If the target-flanker spacing is changed, the models perform worse.<br />
<br />
* The eccentricity model is much better at recognizing objects in isolation than the DCNN, because the multi-scale crops divide the image into discrete regions, letting the model learn from image parts as well as the whole image.<br />
<br />
* Only the eccentricity-dependent model is robust to different flanker configurations not included in training, when the target is centered.<br />
<br />
==DNNs trained with Images with the Target in Isolation==<br />
Here the models are trained on target objects in isolation, with translational variance, while the test set uses the same flanker configurations as before.<br />
[[File:result2.png|750x400px|center]]<br />
===DCNN Observations===<br />
* The recognition gets worse with the increase in the number of flankers.<br />
<br />
* Convolutional networks are capable of being invariant to translations.<br />
<br />
* In the constant target eccentricity setup, where the target is fixed at the center of the image with varying target-flanker spacing, recognition gets better as the distance between target and flankers increases.<br />
<br />
* Spatial pooling helps in learning invariance.<br />
<br />
* Flankers similar to the target object help in recognition, since they don't activate the convolutional filter more.<br />
<br />
* notMNIST flankers lead to more crowding, since they have many more edges and white image pixels, which activate the convolutional layers more.<br />
===Eccentricity Model===<br />
The set-up is the same as explained earlier.<br />
[[File:result3.png|750x400px|center]]<br />
====Observations====<br />
* If the target is placed at the center and no contrast normalization is done, the recognition accuracy is high, since this model concentrates the most on the central region of the image.<br />
<br />
* If contrast normalization is done, all the scales contribute an equal amount, and hence the eccentricity dependence is removed.<br />
<br />
* Early pooling is harmful, since it might discard information very early that would have been useful to the network.<br />
<br />
==Complex Clutter==<br />
Here, the targets are randomly embedded into images of the Places dataset and shifted horizontally in order to investigate model robustness when the target is not at the image center. Tests are performed on the DCNN and on the eccentricity model, with and without contrast normalization, using at end pooling. The results are shown in Figure 9 below. <br />
<br />
[[File:result4.png|750x400px|center]]<br />
<br />
====Observations====<br />
* Only the eccentricity model without contrast normalization can recognize the target, and only when the target is close to the image center.<br />
<br />
* The eccentricity model does not need to be trained on different types of clutter to become robust to those types of clutter, but it needs to fixate on the relevant part of the image to recognize the target.<br />
<br />
=Conclusions=<br />
It is often assumed that training a network on data similar to the test data will give good results in a general scenario too, but that is not the case here: the models trained with flankers did not give the ideal results for the target objects.<br />
*'''Flanker Configuration''': When models are trained with images of objects in isolation, adding flankers harms recognition. Adding two flankers is the same or worse than adding just one, and the smaller the spacing between flanker and target, the more crowding occurs. This is because the pooling operation merges nearby responses, such as those of the target and flankers if they are close.<br />
<br />
*'''Similarity between target and flanker''': Flankers more similar to targets cause more crowding, because of the selectivity property of the learned DNN filters.<br />
<br />
*'''Dependence on target location and contrast normalization''': In DCNNs and eccentricity-dependent models with contrast normalization, recognition accuracy is the same across all eccentricities. In eccentricity-dependent networks without contrast normalization, recognition does not decrease despite the presence of clutter when the target is at the center of the image.<br />
<br />
*'''Effect of pooling''': Adding pooling leads to better recognition accuracy of the models. Yet, in the eccentricity model, pooling across the scales too early in the hierarchy leads to lower accuracy.<br />
<br />
=Critique=<br />
This paper only examines how flankers affect targets, i.e. how crowding affects recognition, and does not propose anything novel in terms of architecture to handle this type of crowding. The eccentricity-based model does well only when the target is placed at the center of the image; windowing over the frames, instead of taking crops starting from the middle, might help.<br />
<br />
=References=<br />
1) Volokitin A, Roig G, Poggio T:"Do Deep Neural Networks Suffer from Crowding?" Conference on Neural Information Processing Systems (NIPS). 2017<br />
<br />
2) Francis X. Chen, Gemma Roig, Leyla Isik, Xavier Boix and Tomaso Poggio: "Eccentricity Dependent Deep Neural Networks for Modeling Human Vision" Journal of Vision. 17. 808. 10.1167/17.10.808.<br />
<br />
3) J Harrison, W & W Remington, R & Mattingley, Jason. (2014). Visual crowding is anisotropic along the horizontal meridian during smooth pursuit. Journal of vision. 14. 10.1167/14.1.21. http://willjharrison.com/2014/01/new-paper-visual-crowding-is-anisotropic-along-the-horizontal-meridian-during-smooth-pursuit/</div>X249wanghttp://wiki.math.uwaterloo.ca/statwiki/index.php?title=End-to-End_Differentiable_Adversarial_Imitation_Learning&diff=35032End-to-End Differentiable Adversarial Imitation Learning2018-03-21T20:20:46Z<p>X249wang: /* Discussion */</p>
<hr />
<div>= Introduction =<br />
The ability to imitate an expert policy is very beneficial in the case of automating human demonstrated tasks. Assuming that a sequence of state action pairs (trajectories) of an expert policy are available, a new policy can be trained that imitates the expert without having access to the original reward signal used by the expert. There are two main approaches to solve the problem of imitating a policy; they are Behavioural Cloning (BC) and Inverse Reinforcement Learning (IRL). BC directly learns the conditional distribution of actions over states in a supervised fashion by training on single time-step state-action pairs. The disadvantage of BC is that the training requires large amounts of expert data, which is hard to obtain. In addition, an agent trained using BC is unaware of how its action can affect future state distribution. The second method using IRL involves recovering a reward signal under which the expert is uniquely optimal; the main disadvantage is that it’s an ill-posed problem.<br />
<br />
To address the problem of imitating an expert policy, techniques based on Generative Adversarial Networks (GANs) have been proposed in recent years. GANs use a discriminator to guide the generative model towards producing patterns like those of the expert. This idea was used by Ho & Ermon (2016) in their work titled Generative Adversarial Imitation Learning (GAIL) to imitate an expert policy in a model-free setup. A model-free setup is one where the agent cannot predict the next state and reward before it takes each action, since the transition function from one state to the next is not learned. The disadvantage of GAIL’s model-free approach is that backpropagation requires gradient estimation, which tends to suffer from high variance and results in the need for large sample sizes and variance reduction methods. This paper proposes a model-based method (MGAIL) to address these issues by training a policy using the information propagated from the discriminator to the generator.<br />
<br />
= Background =<br />
== Markov Decision Process ==<br />
Consider an infinite-horizon discounted Markov decision process (MDP), defined by the tuple <math>(S, A, P, r, \rho_0, \gamma)</math> where <math>S</math> is the set of states, <math>A</math> is a set of actions, <math>P : S \times A \times S \rightarrow [0, 1]</math> is the transition probability distribution, <math>r : (S \times A) \rightarrow \mathbb{R}</math> is the reward function, <math>\rho_0 : S \rightarrow [0, 1]</math> is the distribution over initial states, and <math>\gamma \in (0, 1)</math> is the discount factor. Let <math>\pi</math> denote a stochastic policy <math>\pi : S \times A \rightarrow [0, 1]</math>, <math>R(\pi)</math> denote its expected discounted reward: <math>\mathbb{E}_\pi R = \mathbb{E}_\pi [\sum_{t=0}^T \gamma^t r_t]</math> and <math>\tau</math> denote a trajectory of states and actions <math>\tau = \{s_0, a_0, s_1, a_1, ...\}</math>.<br />
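<br />
The discounted reward inside the expectation is simply a weighted sum over one trajectory. A minimal sketch for a single sampled reward sequence (the reward values are made up for illustration):<br />

```python
def discounted_return(rewards, gamma):
    """Sum of gamma**t * r_t over one trajectory, accumulated backwards."""
    total = 0.0
    for r in reversed(rewards):
        total = r + gamma * total
    return total


rewards = [1.0, 0.0, 2.0]                     # r_0, r_1, r_2 (illustrative)
ret = discounted_return(rewards, gamma=0.5)   # 1 + 0.5 * 0 + 0.25 * 2
```

Averaging this quantity over many trajectories sampled under a policy gives a Monte-Carlo estimate of its expected discounted reward.<br />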
<br />
== Imitation Learning ==<br />
A common technique for performing imitation learning is to train a policy <math> \pi </math> that minimizes some loss function <math> l(s, \pi(s)) </math> with respect to a discounted state distribution encountered by the expert: <math> d_\pi(s) = (1-\gamma)\sum_{t=0}^{\infty}\gamma^t p(s_t) </math>. This can be obtained using any supervised learning (SL) algorithm, but the policy's prediction affects future state distributions; this violates the independent and identically distributed (i.i.d) assumption made by most SL algorithms. This process is susceptible to compounding errors since a slight deviation in the learner's behavior can lead to different state distributions not encountered by the expert policy. <br />
<br />
This issue was overcome through the use of the Forward Training (FT) algorithm, which trains a non-stationary policy iteratively over time. At each time step a new policy is trained on the state distribution induced by the previously trained policies <math>\pi_0</math>, <math>\pi_1</math>, ...<math>\pi_{t-1}</math>. This is continued until the end of the time horizon to obtain a policy that can mimic the expert policy. The requirement to train a policy at each time step makes the FT algorithm impractical for cases where the time horizon is very large or undefined. This shortcoming is resolved by the Stochastic Mixing Iterative Learning (SMILe) algorithm. SMILe trains a stochastic stationary policy over several iterations under the trajectory distribution induced by the previously trained policy: <math> \pi_t = \pi_{t-1} + \alpha (1 - \alpha)^{t-1}(\hat{\pi}_t - \pi_0)</math>, with <math>\pi_0</math> following the expert's policy at the start of training.<br />
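<br />
Unrolling the SMILe update shows that <math>\pi_t</math> is a probability mixture of the expert policy <math>\pi_0</math> and the learned policies <math>\hat{\pi}_1, ..., \hat{\pi}_t</math>. The sketch below tracks only the mixture weights implied by the update rule; the function name and the value of <math>\alpha</math> are illustrative.<br />

```python
def smile_weights(alpha, t):
    """Weights [w_0, w_1, ..., w_t] of pi_t over (pi_0, pi_hat_1, ..., pi_hat_t)."""
    weights = [1.0] + [0.0] * t          # pi_0 is the expert policy itself
    for i in range(1, t + 1):
        step = alpha * (1 - alpha) ** (i - 1)
        weights[i] += step               # + alpha (1 - alpha)^(i-1) * pi_hat_i
        weights[0] -= step               # - alpha (1 - alpha)^(i-1) * pi_0
    return weights


w = smile_weights(alpha=0.1, t=5)
```

The weights always sum to one, and the weight on the expert decays as <math>(1 - \alpha)^t</math>, so the mixture gradually shifts from the expert to the learned policies.<br />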
<br />
== Generative Adversarial Networks ==<br />
GANs learn a generative model that can fool the discriminator by using a two-player zero-sum game:<br />
<br />
\begin{align} <br />
\underset{G}{\operatorname{argmin}}\; \underset{D\in (0,1)}{\operatorname{argmax}}\; \mathbb{E}_{x\sim p_E}[\log D(x)]\ +\ \mathbb{E}_{z\sim p_z}[\log(1 - D(G(z)))]<br />
\end{align}<br />
<br />
In the above equation, <math> p_E </math> represents the expert distribution and <math> p_z </math> represents the input noise distribution from which the input to the generator is sampled. The generator produces patterns and the discriminator judges if the pattern was generated or from the expert data. When the discriminator cannot distinguish between the two distributions the game ends and the generator has learned to mimic the expert. GANs rely on basic ideas such as binary classification and algorithms such as backpropagation in order to learn the expert distribution.<br />
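<br />
The end of the game can be illustrated on a small discrete example. The sketch below evaluates the GAN objective for distributions over three outcomes, writing the distribution of generated samples <math>G(z)</math> directly as <code>p_gen</code> and using the well-known optimal discriminator <math>D^*(x) = p_E(x) / (p_E(x) + p_G(x))</math>; the distributions are made up for illustration.<br />

```python
import math


def gan_value(p_expert, p_gen, disc):
    """E_{x~p_E}[log D(x)] + E_{x~p_G}[log(1 - D(x))] for discrete distributions."""
    return sum(
        pe * math.log(d) + pg * math.log(1.0 - d)
        for pe, pg, d in zip(p_expert, p_gen, disc)
    )


def optimal_discriminator(p_expert, p_gen):
    """Best response of D for fixed distributions: p_E / (p_E + p_G)."""
    return [pe / (pe + pg) for pe, pg in zip(p_expert, p_gen)]


p_expert = [0.7, 0.2, 0.1]
p_bad = [0.1, 0.2, 0.7]          # a generator that is far from the expert

value_bad = gan_value(p_expert, p_bad, optimal_discriminator(p_expert, p_bad))
value_matched = gan_value(
    p_expert, p_expert, optimal_discriminator(p_expert, p_expert)
)
```

When the generator matches the expert, the optimal discriminator outputs 0.5 everywhere and the value equals <math>-\log 4</math>, which is the smallest value the generator can reach against an optimal discriminator.<br />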
<br />
GAIL applies GANs to the task of imitating an expert policy in a model-free approach. GAIL uses similar objective functions like GANs, but the expert distribution in GAIL represents the joint distribution over state action tuples:<br />
<br />
\begin{align} <br />
\underset{\pi}{\operatorname{argmin}}\; \underset{D\in (0,1)}{\operatorname{argmax}}\; \mathbb{E}_{\pi}[\log D(s,a)]\ +\ \mathbb{E}_{\pi_E}[\log(1 - D(s,a))] - \lambda H(\pi)<br />
\end{align}<br />
<br />
where <math> H(\pi) \triangleq \mathbb{E}_{\pi}[-log\: \pi(a|s)]</math> is the entropy.<br />
<br />
This problem cannot be solved using the standard methods described for GANs because the generator in GAIL represents a stochastic policy. The exact form of the first term in the above equation is given by: <math> \mathbb{E}_{s\sim \rho_\pi(s)}\mathbb{E}_{a\sim \pi(\cdot |s)} [\log D(s,a)] </math>.<br />
<br />
The two-player game now depends on the stochastic properties (<math> \theta </math>) of the policy, and it is unclear how to differentiate the above equation with respect to <math> \theta </math>. This problem can be overcome using score function estimators such as REINFORCE to obtain an unbiased gradient estimate:<br />
<br />
\begin{align}<br />
\nabla_\theta\mathbb{E}_{\pi} [log\; D(s,a)] \cong \hat{\mathbb{E}}_{\tau_i}[\nabla_\theta\; log\; \pi_\theta(a|s)Q(s,a)]<br />
\end{align}<br />
<br />
where <math> Q(\hat{s},\hat{a}) </math> is the score function of the gradient:<br />
<br />
\begin{align}<br />
Q(\hat{s},\hat{a}) = \hat{\mathbb{E}}_{\tau_i}[log\; D(s,a) | s_0 = \hat{s}, a_0 = \hat{a}]<br />
\end{align}<br />
<br />
<br />
REINFORCE gradients suffer from high variance which makes them difficult to work with even after applying variance reduction techniques. In order to better understand the changes required to fool the discriminator we need access to the gradients of the discriminator network, which can be obtained from the Jacobian of the discriminator. This paper demonstrates the use of a forward model along with the Jacobian of the discriminator to train a policy, without using high-variance gradient estimations.<br />
<br />
= Algorithm =<br />
This section first analyzes the characteristics of the discriminator network, then describes how a forward model can enable policy imitation through GANs. Lastly, the model based adversarial imitation learning algorithm is presented.<br />
<br />
== The discriminator network ==<br />
The discriminator network is trained to predict the conditional distribution: <math> D(s,a) = p(y|s,a) </math> where <math> y \in \{\pi_E, \pi\} </math>.<br />
<br />
The discriminator is trained on an even distribution of expert and generated examples; hence <math> p(\pi) = p(\pi_E) = \frac{1}{2} </math>. Given this and applying Bayes' theorem, we can rearrange and factor <math> D(s,a) </math> to obtain:<br />
<br />
\begin{aligned}<br />
D(s,a) &= p(\pi|s,a) \\<br />
& = \frac{p(s,a|\pi)p(\pi)}{p(s,a|\pi)p(\pi) + p(s,a|\pi_E)p(\pi_E)} \\<br />
& = \frac{p(s,a|\pi)}{p(s,a|\pi) + p(s,a|\pi_E)} \\<br />
& = \frac{1}{1 + \frac{p(s,a|\pi_E)}{p(s,a|\pi)}} \\<br />
& = \frac{1}{1 + \frac{p(a|s,\pi_E)}{p(a|s,\pi)} \cdot \frac{p(s|\pi_E)}{p(s|\pi)}} \\<br />
\end{aligned}<br />
<br />
Define <math> \varphi(s,a) </math> and <math> \psi(s) </math> to be:<br />
<br />
\begin{aligned}<br />
\varphi(s,a) = \frac{p(a|s,\pi_E)}{p(a|s,\pi)}, \psi(s) = \frac{p(s|\pi_E)}{p(s|\pi)}<br />
\end{aligned}<br />
<br />
to get the final expression for <math> D(s,a) </math>:<br />
\begin{aligned}<br />
D(s,a) = \frac{1}{1 + \varphi(s,a)\cdot \psi(s)}<br />
\end{aligned}<br />
<br />
<math> \varphi(s,a) </math> represents a policy likelihood ratio, and <math> \psi(s) </math> represents a state distribution likelihood ratio. Based on these expressions, the paper states that the discriminator makes its decisions by answering two questions. The first question relates to state distribution: what is the likelihood of encountering state <math> s </math> under the distribution induced by <math> \pi_E </math> vs <math> \pi </math>? The second question is about behavior: given a state <math> s </math>, how likely is action <math> a </math> under <math> \pi_E </math> vs <math> \pi </math>? The desired change in state is given by <math> \psi_s \equiv \partial \psi / \partial s </math>; this information can be obtained from the partial derivatives of <math> D(s,a) </math>, which is why these derivatives are proposed for training policies (see following sections):<br />
<br />
\begin{aligned}<br />
\nabla_aD &= - \frac{\varphi_a(s,a)\psi(s)}{(1 + \varphi(s,a)\psi(s))^2} \\<br />
\nabla_sD &= - \frac{\varphi_s(s,a)\psi(s) + \varphi(s,a)\psi_s(s)}{(1 + \varphi(s,a)\psi(s))^2} \\<br />
\end{aligned}<br />
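<br />
The two partial derivatives can be checked numerically. The sketch below uses hypothetical closed forms for <math> \varphi </math> and <math> \psi </math> (chosen only because they are smooth and positive, not taken from the paper) and compares the analytic expressions above against central finite differences.<br />

```python
import math


def phi(s, a):
    """Hypothetical policy likelihood ratio (illustrative choice)."""
    return math.exp(0.5 * a - 0.2 * s * a)


def psi(s):
    """Hypothetical state likelihood ratio (illustrative choice)."""
    return 1.0 + s * s


def D(s, a):
    return 1.0 / (1.0 + phi(s, a) * psi(s))


def grad_D(s, a):
    """Analytic (dD/da, dD/ds) from the formulas above."""
    p, q = phi(s, a), psi(s)
    phi_a = p * (0.5 - 0.2 * s)
    phi_s = p * (-0.2 * a)
    psi_s = 2.0 * s
    denom = (1.0 + p * q) ** 2
    return -phi_a * q / denom, -(phi_s * q + p * psi_s) / denom


s, a, eps = 0.7, -0.3, 1e-6
numeric_a = (D(s, a + eps) - D(s, a - eps)) / (2 * eps)
numeric_s = (D(s + eps, a) - D(s - eps, a)) / (2 * eps)
analytic_a, analytic_s = grad_D(s, a)
```

Note that <math> \psi_s </math> appears only in the state derivative, which is why an approach that discards <math> \nabla_sD </math> cannot use information about the state distribution.<br />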
<br />
== Backpropagating through stochastic units ==<br />
There is interest in training stochastic policies because stochasticity encourages exploration for Policy Gradient methods. This is a problem for algorithms that build differentiable computation graphs where the gradients flow from one component to another since it is unclear how to backpropagate through stochastic units. The following subsections show how to estimate the gradients of continuous and categorical stochastic elements for continuous and discrete action domains respectively.<br />
<br />
=== Continuous Action Distributions ===<br />
In the case of continuous action policies, re-parameterization is used to enable computing the derivatives of stochastic models. Assuming that the stochastic policy has a Gaussian distribution <math> \mathcal{N}(\mu_{\theta} (s), \sigma_{\theta}^2 (s))</math>, where the mean and standard deviation are given by some deterministic functions <math>\mu_{\theta}</math> and <math>\sigma_{\theta}</math>, an action sampled from the policy <math> \pi </math> can be written as <math> \pi_\theta(a|s) = \mu_\theta(s) + \xi \sigma_\theta(s) </math>, where <math> \xi \sim \mathcal{N}(0,1) </math>. This way, the authors are able to get a Monte-Carlo estimator of the derivative of the expected value of <math> D(s, a) </math> with respect to <math> \theta </math>:<br />
<br />
\begin{align}<br />
\nabla_\theta\mathbb{E}_{\pi(a|s)}D(s,a) = \mathbb{E}_{\rho (\xi )}\nabla_a D(s,a) \nabla_\theta \pi_\theta(a|s) \cong \frac{1}{M}\sum_{i=1}^{M} \nabla_a D(s,a) \nabla_\theta \pi_\theta(a|s)\Bigr|_{\substack{\xi=\xi_i}}<br />
\end{align}<br />
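<br />
To make the contrast with REINFORCE concrete, the sketch below estimates the same gradient in two ways for a one-dimensional Gaussian policy. It uses a stand-in function <math> f(a) = a^2 </math> in place of <math> D(s,a) </math> at a fixed state and differentiates with respect to the mean only; these choices are assumptions made so that the true gradient (<math> 2\mu </math>) is known.<br />

```python
import random
import statistics


def f(a):
    """Stand-in for D(s, a) at a fixed state (illustrative choice)."""
    return a * a


def f_prime(a):
    return 2.0 * a


mu, sigma, M = 1.0, 1.0, 20000
rng = random.Random(0)
xis = [rng.gauss(0.0, 1.0) for _ in range(M)]

# Re-parameterization: a = mu + xi * sigma, so d f(a) / d mu = f'(a) * 1.
reparam = [f_prime(mu + xi * sigma) for xi in xis]

# Score function (REINFORCE): f(a) * d log N(a; mu, sigma^2) / d mu,
# where d log N / d mu = (a - mu) / sigma^2 = xi / sigma.
score = [f(mu + xi * sigma) * (xi / sigma) for xi in xis]

reparam_mean = statistics.fmean(reparam)
reparam_var = statistics.variance(reparam)
score_mean = statistics.fmean(score)
score_var = statistics.variance(score)
```

Both estimators are unbiased for the true gradient, but in this toy setting the per-sample variance of the score-function estimator is several times larger, which is the motivation for propagating <math> \nabla_a D </math> through the sampling step.<br />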
<br />
=== Categorical Action Distributions ===<br />
In the case of discrete action domains, the paper uses categorical re-parameterization with Gumbel-Softmax. This method relies on the Gumbel-Max trick which is a method for drawing samples from a categorical distribution with class probabilities <math> \pi(a_1|s),\pi(a_2|s),...,\pi(a_N|s) </math>:<br />
<br />
\begin{align}<br />
a_{argmax} = \underset{i}{argmax}[g_i + log\ \pi(a_i|s)]\textrm{, where } g_i \sim Gumbel(0, 1).<br />
\end{align}<br />
<br />
Gumbel-Softmax provides a differentiable approximation of the samples obtained using the Gumbel-Max trick:<br />
<br />
\begin{align}<br />
a_{softmax} = \frac{exp[\frac{1}{\tau}(g_i + log\ \pi(a_i|s))]}{\sum_{j=1}^{N}exp[\frac{1}{\tau}(g_j + log\ \pi(a_j|s))]}<br />
\end{align}<br />
<br />
<br />
In the above equation, the hyper-parameter <math> \tau </math> (temperature) trades bias for variance. As <math> \tau </math> gets closer to zero, the softmax operator acts like argmax, resulting in low bias but high variance; conversely, a large <math> \tau </math> gives high bias but low variance.<br />
<br />
The authors use <math> a_{softmax} </math> to interact with the environment; argmax is applied over <math> a_{softmax} </math> to obtain a single “pure” action, but the continuous approximation is used in the backward pass using the estimation: <math> \nabla_\theta\; a_{argmax} \approx \nabla_\theta\; a_{softmax} </math>.<br />
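<br />
The relationship between the two sampling schemes can be sketched in a few lines. The example below draws Gumbel noise, forms the hard Gumbel-Max action and the relaxed Gumbel-Softmax sample from the same noise, and checks empirically that the hard actions follow the class probabilities. The probabilities and temperature are illustrative.<br />

```python
import math
import random


def sample_gumbel(rng):
    """Draw g ~ Gumbel(0, 1) via inverse transform sampling."""
    u = rng.random()
    while u == 0.0:                      # avoid log(0)
        u = rng.random()
    return -math.log(-math.log(u))


def gumbel_softmax(probs, tau, rng):
    """Return (hard action index, relaxed sample) sharing the same noise."""
    noisy = [sample_gumbel(rng) + math.log(p) for p in probs]
    hard = max(range(len(probs)), key=lambda i: noisy[i])    # Gumbel-Max
    top = max(noisy)
    exps = [math.exp((x - top) / tau) for x in noisy]        # stable softmax
    total = sum(exps)
    soft = [e / total for e in exps]                         # Gumbel-Softmax
    return hard, soft


rng = random.Random(0)
probs = [0.2, 0.5, 0.3]
hard, soft = gumbel_softmax(probs, tau=0.5, rng=rng)

# Empirical check that Gumbel-Max samples follow the class probabilities.
counts = [0, 0, 0]
for _ in range(20000):
    h, relaxed = gumbel_softmax(probs, 0.5, rng)
    counts[h] += 1
freqs = [c / 20000 for c in counts]
```

The largest entry of the relaxed sample always coincides with the hard action, which is what allows the hard action to be used in the forward pass and the relaxed sample in the backward pass.<br />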
<br />
== Backpropagating through a Forward model ==<br />
The above subsections presented the means for extracting the partial derivative <math> \nabla_aD </math>. The main contribution of this paper is incorporating the use of <math> \nabla_sD </math>. In a model-free approach the state <math> s </math> is treated as a fixed input, therefore <math> \nabla_sD </math> is discarded. This is illustrated in Figure 1. This work uses a model-based approach which makes incorporating <math> \nabla_sD </math> more involved. In the model-based approach, a state <math> s_t </math> can be written as a function of the previous state action pair: <math> s_t = f(s_{t-1}, a_{t-1}) </math>, where <math> f </math> represents the forward model. Using the forward model and the law of total derivatives we get:<br />
<br />
\begin{align}<br />
\nabla_\theta D(s_t,a_t)\Bigr|_{\substack{s=s_t, a=a_t}} &= \frac{\partial D}{\partial a}\frac{\partial a}{\partial \theta}\Bigr|_{\substack{a=a_t}} + \frac{\partial D}{\partial s}\frac{\partial s}{\partial \theta}\Bigr|_{\substack{s=s_t}} \\<br />
&= \frac{\partial D}{\partial a}\frac{\partial a}{\partial \theta}\Bigr|_{\substack{a=a_t}} + \frac{\partial D}{\partial s}\left (\frac{\partial f}{\partial s}\frac{\partial s}{\partial \theta}\Bigr|_{\substack{s=s_{t-1}}} + \frac{\partial f}{\partial a}\frac{\partial a}{\partial \theta}\Bigr|_{\substack{a=a_{t-1}}} \right )<br />
\end{align}<br />
<br />
<br />
Using this formula, the error regarding deviations of future states <math> (\psi_s) </math> propagates back in time and influences the actions of policies at earlier times. This is summarized in Figure 2.<br />
<br />
[[File:modelFree_blockDiagram.PNG|300px]]<br />
<br />
Figure 1: Block-diagram of the model-free approach: given a state <math> s </math>, the policy outputs <math> \mu </math> which is fed to a stochastic sampling unit. An action <math> a </math> is sampled, and together with <math> s </math> are presented to the discriminator network. In the backward phase, the error message <math> \delta_a </math> is blocked at the stochastic sampling unit. From there, a high-variance gradient estimation is used (<math> \delta_{HV} </math>). Meanwhile, the error message <math> \delta_s </math> is flushed.<br />
<br />
[[File:modelBased_blockDiagram.PNG|1000px]]<br />
<br />
Figure 2: Block diagram of model-based adversarial imitation learning. This diagram describes the computation graph for training the policy (i.e. G). The discriminator network D is fixed at this stage and is trained separately. At time <math> t </math> of the forward pass, <math> \pi </math> outputs a distribution over actions: <math> \mu_t = \pi(s_t) </math>, from which an action <math> a_t </math> is sampled. For example, in the continuous case, this is done using the re-parametrization trick: <math> a_t = \mu_t + \xi \cdot \sigma </math>, where <math> \xi \sim N(0,1) </math>. The next state <math> s_{t+1} = f(s_t, a_t) </math> is computed using the forward model (which is also trained separately), and the entire process repeats for time <math> t+1 </math>. In the backward pass, the gradient of <math> \pi </math> is comprised of (a) the error message <math> \delta_a </math> (Green), which propagates through the differentiable approximation of the sampling process, and (b) the error message <math> \delta_s </math> (Blue) of future time-steps, which propagates back through the differentiable forward model.<br />
<br />
== MGAIL Algorithm ==<br />
Shalev-Shwartz et al. (2016) and Heess et al. (2015) built a multi-step computation graph for describing the familiar policy gradient objective; in this case it is given by:<br />
<br />
\begin{align}<br />
J(\theta) = \mathbb{E}\left [ \sum_{t=0}^{T} \gamma ^t D(s_t,a_t)|\theta\right ]<br />
\end{align}<br />
<br />
<br />
Using the results from Heess et al. (2015) this paper demonstrates how to differentiate <math> J(\theta) </math> over a trajectory of <math>(s,a,s') </math> transitions:<br />
<br />
\begin{align}<br />
J_s &= \mathbb{E}_{p(a|s)}\mathbb{E}_{p(s'|s,a)}\left [ D_s + D_a \pi_s + \gamma J'_{s'}(f_s + f_a \pi_s) \right] \\<br />
J_\theta &= \mathbb{E}_{p(a|s)}\mathbb{E}_{p(s'|s,a)}\left [ D_a \pi_\theta + \gamma (J'_{s'} f_a \pi_\theta + J'_\theta) \right]<br />
\end{align}<br />
<br />
The policy gradient <math> \nabla_\theta J </math> is calculated by applying the two equations above recursively for <math> T </math> iterations. The MGAIL algorithm is presented below.<br />
<br />
[[File:MGAIL_alg.PNG]]<br />
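<br />
The backward recursion for <math> J_s </math> and <math> J_\theta </math> can be illustrated on a scalar toy problem. The sketch below is not the authors' implementation: it assumes a deterministic linear policy, a linear forward model and a simple smooth function in place of the discriminator (all hypothetical choices), so the expectations drop out and the result can be compared against a finite-difference gradient of <math> J(\theta) </math>.<br />

```python
GAMMA = 0.9
T = 5                      # last time index; the trajectory has T + 1 steps
F_S, F_A = 0.9, 0.5        # partial derivatives of the toy forward model


def policy(s, theta):
    return theta * s                       # pi(s; theta), deterministic toy


def forward(s, a):
    return F_S * s + F_A * a               # f(s, a), toy forward model


def D(s, a):
    return 1.0 / (1.0 + s * s + a * a)     # stand-in for the discriminator


def rollout(theta, s0):
    states, actions, s = [], [], s0
    for _ in range(T + 1):
        a = policy(s, theta)
        states.append(s)
        actions.append(a)
        s = forward(s, a)
    return states, actions


def objective(theta, s0):
    states, actions = rollout(theta, s0)
    return sum(GAMMA ** t * D(s, a)
               for t, (s, a) in enumerate(zip(states, actions)))


def policy_gradient(theta, s0):
    states, actions = rollout(theta, s0)
    J_s, J_theta = 0.0, 0.0                # J'_{s'} and J'_theta past the horizon
    for s, a in zip(reversed(states), reversed(actions)):
        denom = (1.0 + s * s + a * a) ** 2
        D_s, D_a = -2.0 * s / denom, -2.0 * a / denom
        pi_s, pi_theta = theta, s
        # J_theta must be updated first, since it uses the next step's J_s.
        J_theta = D_a * pi_theta + GAMMA * (J_s * F_A * pi_theta + J_theta)
        J_s = D_s + D_a * pi_s + GAMMA * J_s * (F_S + F_A * pi_s)
    return J_theta


theta, s0, eps = 0.3, 1.0, 1e-6
analytic = policy_gradient(theta, s0)
numeric = (objective(theta + eps, s0) - objective(theta - eps, s0)) / (2 * eps)
```

In this toy setting the recursive gradient agrees with the finite-difference estimate, which shows how the term <math> J'_{s'} f_a \pi_\theta </math> carries the effect of an action on all future states back to the policy parameters.<br />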
<br />
== Forward Model Structure ==<br />
The stability of the learning process depends on the prediction accuracy of the forward model, but learning an accurate forward model is challenging by itself. The authors propose methods for improving the performance of the forward model based on two aspects of its functionality. First, the forward model should learn to use the action as an operator over the state space. To accomplish this, the actions and states, which are sampled from different distributions, need to be first represented in a shared space. This is done by encoding the state and action using two separate neural networks and combining their outputs to form a single vector. Second, multiple previous states are used to predict the next state by representing the environment as an <math> n^{th} </math> order MDP. A gated recurrent unit (GRU, a simpler variant of the LSTM model) layer is incorporated into the state encoder to enable recurrent connections from previous states. Using these modifications, the model is able to achieve better and more stable results compared to the standard forward model based on a feed-forward neural network. The comparison is presented in Figure 3.<br />
<br />
[[File:performance_comparison.PNG]]<br />
<br />
Figure 3: Performance comparison between a basic forward model (Blue), and the advanced forward model (Green).<br />
<br />
= Experiments =<br />
The proposed algorithm is evaluated on three discrete control tasks (Cartpole, Mountain-Car, Acrobot), and five continuous control tasks (Hopper, Walker, Half-Cheetah, Ant, and Humanoid), which are modeled by the MuJoCo physics simulator (Todorov et al., 2012). Expert policies are trained using the Trust Region Policy Optimization (TRPO) algorithm (Schulman et al., 2015). Different numbers of trajectories are used to train the expert for each task, but all trajectories are of length 1000.<br />
The discriminator and generator (policy) networks contain two hidden layers with ReLU non-linearity and are trained using the ADAM optimizer. The total reward received over a period of <math> N </math> steps using BC, GAIL and MGAIL is presented in Table 1. The proposed algorithm achieved the highest reward for most environments while exhibiting performance comparable to the expert over all of them.<br />
<br />
[[File:mgail_test_results.PNG]]<br />
<br />
Table 1. Policy performance, boldface indicates better results, <math> \pm </math> represents one standard deviation.<br />
<br />
= Discussion =<br />
This paper presented a model-based algorithm for imitation learning. It demonstrated how a forward model can be used to train policies using the exact gradient of the discriminator network. A downside of this approach is the need to learn a forward model, since this could be difficult in certain domains. Learning the system dynamics directly from raw images is considered as one line of future work. Another direction for future work is to address the violation of the fundamental assumption made by all supervised learning algorithms, which requires the data to be i.i.d. This problem arises because the discriminator and forward models are trained in a supervised learning fashion using data sampled from a dynamic distribution. The authors tried a solution proposed by another paper, which is to reset the learning rate several times during the training period, but it did not result in significant improvements.<br />
<br />
= Source =<br />
# Baram, Nir, et al. "End-to-end differentiable adversarial imitation learning." International Conference on Machine Learning. 2017.<br />
# Ho, Jonathan, and Stefano Ermon. "Generative adversarial imitation learning." Advances in Neural Information Processing Systems. 2016.<br />
# Shalev-Shwartz, Shai, et al. "Long-term planning by short-term prediction." arXiv preprint arXiv:1602.01580 (2016).<br />
# Heess, Nicolas, et al. "Learning continuous control policies by stochastic value gradients." Advances in Neural Information Processing Systems. 2015.<br />
# Schulman, John, et al. "Trust region policy optimization." International Conference on Machine Learning. 2015.</div>X249wanghttp://wiki.math.uwaterloo.ca/statwiki/index.php?title=End-to-End_Differentiable_Adversarial_Imitation_Learning&diff=35029End-to-End Differentiable Adversarial Imitation Learning2018-03-21T20:08:19Z<p>X249wang: /* Imitation Learning */</p>
<hr />
<div>= Introduction =<br />
The ability to imitate an expert policy is very beneficial in the case of automating human demonstrated tasks. Assuming that a sequence of state action pairs (trajectories) of an expert policy are available, a new policy can be trained that imitates the expert without having access to the original reward signal used by the expert. There are two main approaches to solve the problem of imitating a policy; they are Behavioural Cloning (BC) and Inverse Reinforcement Learning (IRL). BC directly learns the conditional distribution of actions over states in a supervised fashion by training on single time-step state-action pairs. The disadvantage of BC is that the training requires large amounts of expert data, which is hard to obtain. In addition, an agent trained using BC is unaware of how its action can affect future state distribution. The second method using IRL involves recovering a reward signal under which the expert is uniquely optimal; the main disadvantage is that it’s an ill-posed problem.<br />
<br />
To address the problem of imitating an expert policy, techniques based on Generative Adversarial Networks (GANs) have been proposed in recent years. GANs use a discriminator to guide the generative model towards producing patterns like those of the expert. This idea was used by Ho & Ermon (2016) in their work titled Generative Adversarial Imitation Learning (GAIL) to imitate an expert policy in a model-free setup. A model-free setup is one where the agent cannot predict the next state and reward before it takes each action, since the transition function from one state to the next is not learned. The disadvantage of GAIL’s model-free approach is that backpropagation requires gradient estimation, which tends to suffer from high variance and results in the need for large sample sizes and variance reduction methods. This paper proposes a model-based method (MGAIL) to address these issues by training a policy using the information propagated from the discriminator to the generator.<br />
<br />
= Background =<br />
== Markov Decision Process ==<br />
Consider an infinite-horizon discounted Markov decision process (MDP), defined by the tuple <math>(S, A, P, r, \rho_0, \gamma)</math> where <math>S</math> is the set of states, <math>A</math> is the set of actions, <math>P : S \times A \times S \rightarrow [0, 1]</math> is the transition probability distribution, <math>r : (S \times A) \rightarrow \mathbb{R}</math> is the reward function, <math>\rho_0 : S \rightarrow [0, 1]</math> is the distribution over initial states, and <math>\gamma \in (0, 1)</math> is the discount factor. Let <math>\pi</math> denote a stochastic policy <math>\pi : S \times A \rightarrow [0, 1]</math>, <math>R(\pi)</math> denote its expected discounted reward: <math>E_\pi R = E_\pi [\sum_{t=0}^T \gamma^t r_t]</math> and <math>\tau</math> denote a trajectory of states and actions <math>\tau = \{s_0, a_0, s_1, a_1, ...\}</math>.<br />
<br />
== Imitation Learning ==<br />
A common technique for performing imitation learning is to train a policy <math> \pi </math> that minimizes some loss function <math> l(s, \pi(s)) </math> with respect to a discounted state distribution encountered by the expert: <math> d_\pi(s) = (1-\gamma)\sum_{t=0}^{\infty}\gamma^t p(s_t) </math>. This can be obtained using any supervised learning (SL) algorithm, but the policy's prediction affects future state distributions; this violates the independent and identically distributed (i.i.d) assumption made by most SL algorithms. This process is susceptible to compounding errors since a slight deviation in the learner's behavior can lead to different state distributions not encountered by the expert policy. <br />
<br />
This issue was overcome through the use of the Forward Training (FT) algorithm which trains a non-stationary policy iteratively over time. At each time step a new policy is trained on the state distribution induced by the previously trained policies <math>\pi_0</math>, <math>\pi_1</math>, ...<math>\pi_{t-1}</math>. This is continued till the end of the time horizon to obtain a policy that can mimic the expert policy. This requirement to train a policy at each time step till the end makes the FT algorithm impractical for cases where the time horizon is very large or undefined. This shortcoming is resolved using the Stochastic Mixing Iterative Learning (SMILe) algorithm. SMILe trains a stochastic stationary policy over several iterations under the trajectory distribution induced by the previously trained policy: <math> \pi_t = \pi_{t-1} + \alpha (1 - \alpha)^{t-1}(\hat{\pi}_t - \pi_0)</math>, with <math>\pi_0</math> following expert's policy at the start of training.<br />
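<br />
As a worked step, the SMILe update can be unrolled to show how much weight the expert-following policy <math>\pi_0</math> retains. This is a sketch that assumes <math>\alpha \in (0,1)</math> is a fixed mixing coefficient, which the text above does not state explicitly. Using <math>\sum_{i=1}^{t}\alpha(1-\alpha)^{i-1} = 1 - (1-\alpha)^t</math>:<br />
<br />
\begin{align}<br />
\pi_t = \pi_0 + \sum_{i=1}^{t}\alpha(1-\alpha)^{i-1}(\hat{\pi}_i - \pi_0) = (1-\alpha)^t\,\pi_0 + \alpha\sum_{i=1}^{t}(1-\alpha)^{i-1}\hat{\pi}_i<br />
\end{align}<br />
<br />
Under that assumption the weight on <math>\pi_0</math> decays geometrically, so the mixture relies less and less on the expert as the iterations proceed.<br />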
<br />
== Generative Adversarial Networks ==<br />
GANs learn a generative model that can fool the discriminator by using a two-player zero-sum game:<br />
<br />
\begin{align} <br />
\underset{G}{\operatorname{argmin}}\; \underset{D\in (0,1)}{\operatorname{argmax}}\; \mathbb{E}_{x\sim p_E}[\log(D(x))]\ +\ \mathbb{E}_{z\sim p_z}[\log(1 - D(G(z)))]<br />
\end{align}<br />
<br />
In the above equation, <math> p_E </math> represents the expert distribution and <math> p_z </math> represents the input noise distribution from which the input to the generator is sampled. The generator produces patterns and the discriminator judges if the pattern was generated or from the expert data. When the discriminator cannot distinguish between the two distributions the game ends and the generator has learned to mimic the expert. GANs rely on basic ideas such as binary classification and algorithms such as backpropagation in order to learn the expert distribution.<br />
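<br />
The value of the two-player objective can be estimated from discriminator outputs alone. The sketch below is illustrative only: the function name <code>gan_value</code> and the sample scores are hypothetical, not taken from the paper. It shows that a discriminator which outputs 0.5 everywhere (i.e. one that cannot tell expert from generated samples) yields the value <math>-\log 4</math>.<br />

```python
import math


def gan_value(d_expert, d_generated):
    """Monte-Carlo estimate of E[log D(x)] + E[log(1 - D(G(z)))].

    d_expert: discriminator outputs D(x) on expert samples, each in (0, 1).
    d_generated: discriminator outputs D(G(z)) on generated samples.
    """
    expert_term = sum(math.log(d) for d in d_expert) / len(d_expert)
    generated_term = sum(math.log(1.0 - d) for d in d_generated) / len(d_generated)
    return expert_term + generated_term


# Hypothetical scores from a discriminator that separates the two sources well.
confident = gan_value([0.9, 0.95, 0.99], [0.05, 0.1, 0.02])

# A discriminator that cannot distinguish the sources outputs 0.5 everywhere.
confused = gan_value([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])

print(confident, confused, -math.log(4))
```

The discriminator tries to push this value up towards 0, while the generator tries to push it down.<br />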
<br />
GAIL applies GANs to the task of imitating an expert policy in a model-free approach. GAIL uses similar objective functions like GANs, but the expert distribution in GAIL represents the joint distribution over state action tuples:<br />
<br />
\begin{align} <br />
\underset{\pi}{\operatorname{argmin}}\; \underset{D\in (0,1)}{\operatorname{argmax}}\; \mathbb{E}_{\pi}[\log(D(s,a))]\ +\ \mathbb{E}_{\pi_E}[\log(1 - D(s,a))] - \lambda H(\pi)<br />
\end{align}<br />
<br />
where <math> H(\pi) \triangleq \mathbb{E}_{\pi}[-log\: \pi(a|s)]</math> is the entropy.<br />
<br />
This problem cannot be solved using the standard methods described for GANs because the generator in GAIL represents a stochastic policy. The exact form of the first term in the above equation is given by: <math> \mathbb{E}_{s\sim \rho_\pi(s)}\mathbb{E}_{a\sim \pi(\cdot |s)} [\log(D(s,a))] </math>.<br />
<br />
The two-player game now depends on the stochastic properties (<math> \theta </math>) of the policy, and it is unclear how to differentiate the above equation with respect to <math> \theta </math>. This problem can be overcome using score functions such as REINFORCE to obtain an unbiased gradient estimation:<br />
<br />
\begin{align}<br />
\nabla_\theta\mathbb{E}_{\pi} [log\; D(s,a)] \cong \hat{\mathbb{E}}_{\tau_i}[\nabla_\theta\; log\; \pi_\theta(a|s)Q(s,a)]<br />
\end{align}<br />
<br />
where <math> Q(\hat{s},\hat{a}) </math> is the score function of the gradient:<br />
<br />
\begin{align}<br />
Q(\hat{s},\hat{a}) = \hat{\mathbb{E}}_{\tau_i}[log\; D(s,a) | s_0 = \hat{s}, a_0 = \hat{a}]<br />
\end{align}<br />
<br />
<br />
REINFORCE gradients suffer from high variance which makes them difficult to work with even after applying variance reduction techniques. In order to better understand the changes required to fool the discriminator we need access to the gradients of the discriminator network, which can be obtained from the Jacobian of the discriminator. This paper demonstrates the use of a forward model along with the Jacobian of the discriminator to train a policy, without using high-variance gradient estimations.<br />
<br />
= Algorithm =<br />
This section first analyzes the characteristics of the discriminator network, then describes how a forward model can enable policy imitation through GANs. Lastly, the model based adversarial imitation learning algorithm is presented.<br />
<br />
== The discriminator network ==<br />
The discriminator network is trained to predict the conditional distribution: <math> D(s,a) = p(y|s,a) </math> where <math> y \in \{\pi_E, \pi\} </math>.<br />
<br />
The discriminator is trained on an even distribution of expert and generated examples; hence <math> p(\pi) = p(\pi_E) = \frac{1}{2} </math>. Given this and applying Bayes' theorem, we can rearrange and factor <math> D(s,a) </math> to obtain:<br />
<br />
\begin{aligned}<br />
D(s,a) &= p(\pi|s,a) \\<br />
& = \frac{p(s,a|\pi)p(\pi)}{p(s,a|\pi)p(\pi) + p(s,a|\pi_E)p(\pi_E)} \\<br />
& = \frac{p(s,a|\pi)}{p(s,a|\pi) + p(s,a|\pi_E)} \\<br />
& = \frac{1}{1 + \frac{p(s,a|\pi_E)}{p(s,a|\pi)}} \\<br />
& = \frac{1}{1 + \frac{p(a|s,\pi_E)}{p(a|s,\pi)} \cdot \frac{p(s|\pi_E)}{p(s|\pi)}} \\<br />
\end{aligned}<br />
<br />
Define <math> \varphi(s,a) </math> and <math> \psi(s) </math> to be:<br />
<br />
\begin{aligned}<br />
\varphi(s,a) = \frac{p(a|s,\pi_E)}{p(a|s,\pi)}, \psi(s) = \frac{p(s|\pi_E)}{p(s|\pi)}<br />
\end{aligned}<br />
<br />
to get the final expression for <math> D(s,a) </math>:<br />
\begin{aligned}<br />
D(s,a) = \frac{1}{1 + \varphi(s,a)\cdot \psi(s)}<br />
\end{aligned}<br />
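<br />
The factorization can be checked numerically on a small discrete problem. The sketch below uses made-up state marginals and policies over two states and two actions (all numbers are hypothetical) and confirms that the Bayes form of <math> D(s,a) </math> and the factored form <math> 1/(1 + \varphi\psi) </math> agree.<br />

```python
# Hypothetical state marginals p(s | pi) and p(s | pi_E) over two states.
p_s_pi = [0.7, 0.3]
p_s_exp = [0.4, 0.6]

# Hypothetical policies p(a | s, pi) and p(a | s, pi_E); rows are states.
p_a_given_s_pi = [[0.9, 0.1], [0.5, 0.5]]
p_a_given_s_exp = [[0.2, 0.8], [0.6, 0.4]]


def d_bayes(s, a):
    """D(s, a) = p(pi | s, a) with equal priors p(pi) = p(pi_E) = 1/2."""
    joint_pi = p_a_given_s_pi[s][a] * p_s_pi[s]
    joint_exp = p_a_given_s_exp[s][a] * p_s_exp[s]
    return joint_pi / (joint_pi + joint_exp)


def d_factored(s, a):
    """D(s, a) = 1 / (1 + phi(s, a) * psi(s))."""
    phi = p_a_given_s_exp[s][a] / p_a_given_s_pi[s][a]  # policy likelihood ratio
    psi = p_s_exp[s] / p_s_pi[s]                        # state likelihood ratio
    return 1.0 / (1.0 + phi * psi)


for s in range(2):
    for a in range(2):
        print(s, a, d_bayes(s, a), d_factored(s, a))
```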
<br />
<math> \varphi(s,a) </math> represents a policy likelihood ratio, and <math> \psi(s) </math> represents a state distribution likelihood ratio. Based on these expressions, the paper states that the discriminator makes its decisions by answering two questions. The first question relates to state distribution: what is the likelihood of encountering state <math> s </math> under the distribution induced by <math> \pi_E </math> vs <math> \pi </math>? The second question is about behavior: given a state <math> s </math>, how likely is action <math> a </math> under <math> \pi_E </math> vs <math> \pi </math>? The desired change in state is given by <math> \psi_s \equiv \partial \psi / \partial s </math>; this information can be obtained from the partial derivatives of <math> D(s,a) </math>, which is why these derivatives are proposed to be used for training policies (see following sections):<br />
<br />
\begin{aligned}<br />
\nabla_aD &= - \frac{\varphi_a(s,a)\psi(s)}{(1 + \varphi(s,a)\psi(s))^2} \\<br />
\nabla_sD &= - \frac{\varphi_s(s,a)\psi(s) + \varphi(s,a)\psi_s(s)}{(1 + \varphi(s,a)\psi(s))^2} \\<br />
\end{aligned}<br />
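<br />
Both expressions follow from the chain rule applied to <math> D = (1 + \varphi\psi)^{-1} </math>. As a worked step for the action derivative, note that <math> \psi(s) </math> does not depend on <math> a </math>, so only <math> \varphi </math> contributes:<br />
<br />
\begin{aligned}<br />
\nabla_aD &= \frac{\partial}{\partial a}\left(1 + \varphi(s,a)\psi(s)\right)^{-1} \\<br />
&= -\left(1 + \varphi(s,a)\psi(s)\right)^{-2}\,\frac{\partial}{\partial a}\left(\varphi(s,a)\psi(s)\right) \\<br />
&= -\frac{\varphi_a(s,a)\psi(s)}{(1 + \varphi(s,a)\psi(s))^2} \\<br />
\end{aligned}<br />
<br />
For the state derivative both factors depend on <math> s </math>, so the product rule gives the numerator <math> \varphi_s(s,a)\psi(s) + \varphi(s,a)\psi_s(s) </math>.<br />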
<br />
== Backpropagating through stochastic units ==<br />
There is interest in training stochastic policies because stochasticity encourages exploration for Policy Gradient methods. This is a problem for algorithms that build differentiable computation graphs where the gradients flow from one component to another since it is unclear how to backpropagate through stochastic units. The following subsections show how to estimate the gradients of continuous and categorical stochastic elements for continuous and discrete action domains respectively.<br />
<br />
=== Continuous Action Distributions ===<br />
In the case of continuous action policies, re-parameterization was used to enable computing the derivatives of stochastic models. Assuming that the stochastic policy has a Gaussian distribution <math> \mathcal{N}(\mu_{\theta} (s), \sigma_{\theta}^2 (s))</math>, where the mean and variance are given by some deterministic functions <math>\mu_{\theta}</math> and <math>\sigma_{\theta}</math>, then the policy <math> \pi </math> can be written as <math> \pi_\theta(a|s) = \mu_\theta(s) + \xi \sigma_\theta(s) </math>, where <math> \xi \sim N(0,1) </math>. This way, the authors are able to get a Monte-Carlo estimator of the derivative of the expected value of <math> D(s, a) </math> with respect to <math> \theta </math>:<br />
<br />
\begin{align}<br />
\nabla_\theta\mathbb{E}_{\pi(a|s)}D(s,a) = \mathbb{E}_{\rho (\xi )}\nabla_a D(s,a) \nabla_\theta \pi_\theta(a|s) \cong \frac{1}{M}\sum_{i=1}^{M} \nabla_a D(s,a) \nabla_\theta \pi_\theta(a|s)\Bigr|_{\substack{\xi=\xi_i}}<br />
\end{align}<br />
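<br />
To illustrate why the re-parameterized estimator is attractive, the sketch below compares it with a score-function (REINFORCE-style) estimator on a toy problem. Everything here is an assumption made for illustration: the function <code>f(a) = a*a</code> is a hypothetical smooth stand-in for <math> D(s,a) </math> at a fixed state, and the only policy parameter is the Gaussian mean <math> \mu </math>. For this toy case the true gradient is <math> 2\mu </math>.<br />

```python
import random
import statistics


def f(a):
    # Hypothetical smooth stand-in for D(s, a) at a fixed state.
    return a * a


def f_prime(a):
    return 2.0 * a


def gradient_samples(mu, sigma, n, seed=0):
    """Per-sample estimates of d/dmu E[f(a)] for a ~ N(mu, sigma^2)."""
    rng = random.Random(seed)
    reparam, score = [], []
    for _ in range(n):
        xi = rng.gauss(0.0, 1.0)
        a = mu + sigma * xi  # re-parameterized action
        # Pathwise estimate: df/da * da/dmu, where da/dmu = 1.
        reparam.append(f_prime(a))
        # Score-function estimate: f(a) * d log N(a; mu, sigma^2) / dmu.
        score.append(f(a) * (a - mu) / sigma ** 2)
    return reparam, score


reparam, score = gradient_samples(mu=1.5, sigma=1.0, n=20000)
print("true gradient:", 2 * 1.5)
print("reparam mean/var:", statistics.mean(reparam), statistics.variance(reparam))
print("score   mean/var:", statistics.mean(score), statistics.variance(score))
```

Both estimators target the same gradient, but in this toy setting the pathwise samples have a much smaller variance than the score-function samples.<br />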
<br />
=== Categorical Action Distributions ===<br />
In the case of discrete action domains, the paper uses categorical re-parameterization with Gumbel-Softmax. This method relies on the Gumbel-Max trick which is a method for drawing samples from a categorical distribution with class probabilities <math> \pi(a_1|s),\pi(a_2|s),...,\pi(a_N|s) </math>:<br />
<br />
\begin{align}<br />
a_{argmax} = \underset{i}{argmax}[g_i + log\ \pi(a_i|s)]\textrm{, where } g_i \sim Gumbel(0, 1).<br />
\end{align}<br />
<br />
Gumbel-Softmax provides a differentiable approximation of the samples obtained using the Gumbel-Max trick:<br />
<br />
\begin{align}<br />
a_{softmax} = \frac{exp[\frac{1}{\tau}(g_i + log\ \pi(a_i|s))]}{\sum_{j=1}^{N}exp[\frac{1}{\tau}(g_j + log\ \pi(a_j|s))]}<br />
\end{align}<br />
<br />
<br />
In the above equation, the hyper-parameter <math> \tau </math> (temperature) trades bias for variance. When <math> \tau </math> gets closer to zero, the softmax operator acts like argmax resulting in a low bias, but high variance; vice versa when the <math> \tau </math> is large.<br />
<br />
The authors use <math> a_{softmax} </math> to interact with the environment; argmax is applied over <math> a_{softmax} </math> to obtain a single “pure” action, but the continuous approximation is used in the backward pass using the estimation: <math> \nabla_\theta\; a_{argmax} \approx \nabla_\theta\; a_{softmax} </math>.<br />
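<br />
The two sampling schemes can be sketched in a few lines. This is a minimal illustration, not the authors' implementation: the helper names and the class probabilities are hypothetical. It draws one set of Gumbel noise, takes the hard Gumbel-Max action, and computes the Gumbel-Softmax relaxation at a low and a high temperature using the same noise.<br />

```python
import math
import random


def sample_gumbel(rng):
    # Gumbel(0, 1) via inverse transform: g = -log(-log(u)), u ~ Uniform(0, 1).
    u = rng.random()
    while u == 0.0:  # avoid log(0)
        u = rng.random()
    return -math.log(-math.log(u))


def gumbel_max(probs, noise):
    """Index of the sampled action: argmax_i [g_i + log pi(a_i | s)]."""
    scores = [g + math.log(p) for g, p in zip(noise, probs)]
    return max(range(len(probs)), key=scores.__getitem__)


def gumbel_softmax(probs, tau, noise):
    """Differentiable relaxation: softmax((g_i + log pi(a_i | s)) / tau)."""
    logits = [(g + math.log(p)) / tau for g, p in zip(noise, probs)]
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


rng = random.Random(0)
probs = [0.2, 0.5, 0.3]  # hypothetical class probabilities pi(a_i | s)
noise = [sample_gumbel(rng) for _ in probs]

hard = gumbel_max(probs, noise)
soft_cold = gumbel_softmax(probs, tau=0.1, noise=noise)
soft_hot = gumbel_softmax(probs, tau=10.0, noise=noise)
print(hard, soft_cold, soft_hot)
```

With a small <math> \tau </math> the relaxed vector is close to one-hot at the Gumbel-Max action; with a large <math> \tau </math> it is closer to uniform.<br />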
<br />
== Backpropagating through a Forward model ==<br />
The above subsections presented the means for extracting the partial derivative <math> \nabla_aD </math>. The main contribution of this paper is incorporating the use of <math> \nabla_sD </math>. In a model-free approach the state <math> s </math> is treated as a fixed input, therefore <math> \nabla_sD </math> is discarded. This is illustrated in Figure 1. This work uses a model-based approach which makes incorporating <math> \nabla_sD </math> more involved. In the model-based approach, a state <math> s_t </math> can be written as a function of the previous state action pair: <math> s_t = f(s_{t-1}, a_{t-1}) </math>, where <math> f </math> represents the forward model. Using the forward model and the law of total derivatives we get:<br />
<br />
\begin{align}<br />
\nabla_\theta D(s_t,a_t)\Bigr|_{\substack{s=s_t, a=a_t}} &= \frac{\partial D}{\partial a}\frac{\partial a}{\partial \theta}\Bigr|_{\substack{a=a_t}} + \frac{\partial D}{\partial s}\frac{\partial s}{\partial \theta}\Bigr|_{\substack{s=s_t}} \\<br />
&= \frac{\partial D}{\partial a}\frac{\partial a}{\partial \theta}\Bigr|_{\substack{a=a_t}} + \frac{\partial D}{\partial s}\left (\frac{\partial f}{\partial s}\frac{\partial s}{\partial \theta}\Bigr|_{\substack{s=s_{t-1}}} + \frac{\partial f}{\partial a}\frac{\partial a}{\partial \theta}\Bigr|_{\substack{a=a_{t-1}}} \right )<br />
\end{align}<br />
<br />
<br />
Using this formula, the error regarding deviations of future states <math> (\psi_s) </math> propagates back in time and influences the actions of policies in earlier times. This is summarized in Figure 2.<br />
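<br />
A one-step scalar example makes the chain rule concrete. The sketch below is built entirely on assumptions chosen for illustration: a linear policy <code>a = theta * s</code>, a linear forward model, and a logistic discriminator with made-up weights. It applies the chain rule in full, including the dependence of <math> a_1 </math> on <math> s_1 </math> through the policy, and compares the result with a finite-difference estimate.<br />

```python
import math

W_S, W_A = 0.8, -0.5  # hypothetical discriminator weights
F_S, F_A = 0.9, 0.3   # hypothetical linear forward model: f(s, a) = F_S*s + F_A*a


def discriminator(s, a):
    return 1.0 / (1.0 + math.exp(-(W_S * s + W_A * a)))


def rollout_d1(theta, s0):
    """D(s_1, a_1) after one step of the policy a = theta * s."""
    a0 = theta * s0
    s1 = F_S * s0 + F_A * a0
    a1 = theta * s1
    return discriminator(s1, a1)


def grad_d1(theta, s0):
    """Analytic derivative of D(s_1, a_1) with respect to theta."""
    a0 = theta * s0
    s1 = F_S * s0 + F_A * a0
    a1 = theta * s1
    d = discriminator(s1, a1)
    d_s = d * (1.0 - d) * W_S  # partial D / partial s
    d_a = d * (1.0 - d) * W_A  # partial D / partial a
    # s_0 is fixed, so s_1 depends on theta only through a_0 = theta * s_0.
    ds1_dtheta = F_A * s0
    # a_1 = theta * s_1: direct term plus the dependence through s_1.
    da1_dtheta = s1 + theta * ds1_dtheta
    return d_a * da1_dtheta + d_s * ds1_dtheta


theta, s0, h = 0.4, 1.2, 1e-6
numeric = (rollout_d1(theta + h, s0) - rollout_d1(theta - h, s0)) / (2 * h)
print(grad_d1(theta, s0), numeric)
```

In a model-free setting the <code>d_s * ds1_dtheta</code> term would be unavailable, because the state would be treated as a fixed input.<br />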
<br />
[[File:modelFree_blockDiagram.PNG|300px]]<br />
<br />
Figure 1: Block-diagram of the model-free approach: given a state <math> s </math>, the policy outputs <math> \mu </math> which is fed to a stochastic sampling unit. An action <math> a </math> is sampled, and together with <math> s </math> are presented to the discriminator network. In the backward phase, the error message <math> \delta_a </math> is blocked at the stochastic sampling unit. From there, a high-variance gradient estimation is used (<math> \delta_{HV} </math>). Meanwhile, the error message <math> \delta_s </math> is flushed.<br />
<br />
[[File:modelBased_blockDiagram.PNG|1000px]]<br />
<br />
Figure 2: Block diagram of model-based adversarial imitation learning. This diagram describes the computation graph for training the policy (i.e. G). The discriminator network D is fixed at this stage and is trained separately. At time <math> t </math> of the forward pass, <math> \pi </math> outputs a distribution over actions: <math> \mu_t = \pi(s_t) </math>, from which an action <math> a_t </math> is sampled. For example, in the continuous case, this is done using the re-parametrization trick: <math> a_t = \mu_t + \xi \cdot \sigma </math>, where <math> \xi \sim N(0,1) </math>. The next state <math> s_{t+1} = f(s_t, a_t) </math> is computed using the forward model (which is also trained separately), and the entire process repeats for time <math> t+1 </math>. In the backward pass, the gradient of <math> \pi </math> is comprised of (a) the error message <math> \delta_a </math> (Green) that propagates fluently through the differentiable approximation of the sampling process, and (b) the error message <math> \delta_s </math> (Blue) of future time-steps, which propagates back through the differentiable forward model.<br />
<br />
== MGAIL Algorithm ==<br />
Shalev-Shwartz et al. (2016) and Heess et al. (2015) built a multi-step computation graph for describing the familiar policy gradient objective; in this case it is given by:<br />
<br />
\begin{align}<br />
J(\theta) = \mathbb{E}\left [ \sum_{t=0}^{T} \gamma ^t D(s_t,a_t)|\theta\right ]<br />
\end{align}<br />
<br />
<br />
Using the results from Heess et al. (2015) this paper demonstrates how to differentiate <math> J(\theta) </math> over a trajectory of <math>(s,a,s') </math> transitions:<br />
<br />
\begin{align}<br />
J_s &= \mathbb{E}_{p(a|s)}\mathbb{E}_{p(s'|s,a)}\left [ D_s + D_a \pi_s + \gamma J'_{s'}(f_s + f_a \pi_s) \right] \\<br />
J_\theta &= \mathbb{E}_{p(a|s)}\mathbb{E}_{p(s'|s,a)}\left [ D_a \pi_\theta + \gamma (J'_{s'} f_a \pi_\theta + J'_\theta) \right]<br />
\end{align}<br />
<br />
The policy gradient <math> \nabla_\theta J </math> is calculated by applying the two equations above (equations 12 and 13 in the paper) recursively for <math> T </math> iterations. The MGAIL algorithm is presented below.<br />
<br />
[[File:MGAIL_alg.PNG]]<br />
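<br />
The backward recursion for <math> J_s </math> and <math> J_\theta </math> can be sketched for a deterministic scalar system. This is not the MGAIL algorithm itself: the linear policy, linear forward model, logistic discriminator, and all constants below are hypothetical, and the expectations collapse because nothing is sampled. The sketch runs a forward rollout, applies the recursion from the last step back to the first, and checks the result against a finite-difference gradient of <math> J(\theta) </math>.<br />

```python
import math

GAMMA = 0.9
W_S, W_A = 0.8, -0.5  # hypothetical discriminator weights
F_S, F_A = 0.9, 0.3   # hypothetical linear forward model: f(s, a) = F_S*s + F_A*a


def discriminator(s, a):
    return 1.0 / (1.0 + math.exp(-(W_S * s + W_A * a)))


def objective(theta, s0, horizon):
    """J(theta) = sum_{t=0}^{T} gamma^t D(s_t, a_t) with a_t = theta * s_t."""
    total, s = 0.0, s0
    for t in range(horizon + 1):
        a = theta * s
        total += GAMMA ** t * discriminator(s, a)
        s = F_S * s + F_A * a
    return total


def policy_gradient(theta, s0, horizon):
    """dJ/dtheta via the backward recursion for J_s and J_theta."""
    # Forward pass: record the visited states s_0, ..., s_T.
    states = [s0]
    for _ in range(horizon):
        s = states[-1]
        states.append(F_S * s + F_A * theta * s)

    # Backward pass: J'_{s'} and J'_theta are zero beyond the horizon.
    j_s, j_theta = 0.0, 0.0
    for s in reversed(states):
        a = theta * s
        d = discriminator(s, a)
        d_s = d * (1.0 - d) * W_S
        d_a = d * (1.0 - d) * W_A
        pi_s, pi_theta = theta, s  # derivatives of the policy a = theta * s
        new_j_theta = d_a * pi_theta + GAMMA * (j_s * F_A * pi_theta + j_theta)
        new_j_s = d_s + d_a * pi_s + GAMMA * j_s * (F_S + F_A * pi_s)
        j_s, j_theta = new_j_s, new_j_theta
    return j_theta


theta, s0, horizon, h = 0.4, 1.2, 5, 1e-6
numeric = (objective(theta + h, s0, horizon) - objective(theta - h, s0, horizon)) / (2 * h)
print(policy_gradient(theta, s0, horizon), numeric)
```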
<br />
== Forward Model Structure ==<br />
The stability of the learning process depends on the prediction accuracy of the forward model, but learning an accurate forward model is challenging by itself. The authors propose methods for improving the performance of the forward model based on two aspects of its functionality. First, the forward model should learn to use the action as an operator over the state space. To accomplish this, the actions and states, which are sampled from different distributions, need to be first represented in a shared space. This is done by encoding the state and action using two separate neural networks and combining their outputs to form a single vector. Second, multiple previous states are used to predict the next state by representing the environment as an <math> n^{th} </math> order MDP. A gated recurrent unit (GRU, a simpler variant of the LSTM model) layer is incorporated into the state encoder to enable recurrent connections from previous states. Using these modifications, the model is able to achieve better and more stable results compared to the standard forward model based on a feed-forward neural network. The comparison is presented in Figure 3.<br />
<br />
[[File:performance_comparison.PNG]]<br />
<br />
Figure 3: Performance comparison between a basic forward model (Blue), and the advanced forward model (Green).<br />
<br />
= Experiments =<br />
The proposed algorithm is evaluated on three discrete control tasks (Cartpole, Mountain-Car, Acrobot), and five continuous control tasks (Hopper, Walker, Half-Cheetah, Ant, and Humanoid), which are modeled by the MuJoCo physics simulator (Todorov et al., 2012). Expert policies are trained using the Trust Region Policy Optimization (TRPO) algorithm (Schulman et al., 2015). Different numbers of trajectories are used to train the expert for each task, but all trajectories are of length 1000.<br />
The discriminator and generator (policy) networks contain two hidden layers with ReLU non-linearity and are trained using the ADAM optimizer. The total reward received over a period of <math> N </math> steps using BC, GAIL and MGAIL is presented in Table 1. The proposed algorithm achieved the highest reward for most environments while exhibiting performance comparable to the expert over all of them.<br />
<br />
[[File:mgail_test_results.PNG]]<br />
<br />
Table 1. Policy performance, boldface indicates better results, <math> \pm </math> represents one standard deviation.<br />
<br />
= Discussion =<br />
This paper presented a model-based algorithm for imitation learning. It demonstrated how a forward model can be used to train policies using the exact gradient of the discriminator network. A downside of this approach is the need to learn a forward model, since this could be difficult in certain domains. Learning the system dynamics directly from raw images is considered as one line of future work. Another direction for future work is to address the violation of the fundamental assumption made by all supervised learning algorithms, which requires the data to be i.i.d. This problem arises because the discriminator and forward models are trained in a supervised learning fashion using data sampled from a dynamic distribution.<br />
<br />
= Source =<br />
# Baram, Nir, et al. "End-to-end differentiable adversarial imitation learning." International Conference on Machine Learning. 2017.<br />
# Ho, Jonathan, and Stefano Ermon. "Generative adversarial imitation learning." Advances in Neural Information Processing Systems. 2016.<br />
# Shalev-Shwartz, Shai, et al. "Long-term planning by short-term prediction." arXiv preprint arXiv:1602.01580 (2016).<br />
# Heess, Nicolas, et al. "Learning continuous control policies by stochastic value gradients." Advances in Neural Information Processing Systems. 2015.<br />
# Schulman, John, et al. "Trust region policy optimization." International Conference on Machine Learning. 2015.</div>X249wanghttp://wiki.math.uwaterloo.ca/statwiki/index.php?title=End-to-End_Differentiable_Adversarial_Imitation_Learning&diff=35028End-to-End Differentiable Adversarial Imitation Learning2018-03-21T19:39:20Z<p>X249wang: /* The discriminator network */</p>
<hr />
<div>= Introduction =<br />
The ability to imitate an expert policy is very beneficial in the case of automating human demonstrated tasks. Assuming that a sequence of state action pairs (trajectories) of an expert policy are available, a new policy can be trained that imitates the expert without having access to the original reward signal used by the expert. There are two main approaches to solve the problem of imitating a policy; they are Behavioural Cloning (BC) and Inverse Reinforcement Learning (IRL). BC directly learns the conditional distribution of actions over states in a supervised fashion by training on single time-step state-action pairs. The disadvantage of BC is that the training requires large amounts of expert data, which is hard to obtain. In addition, an agent trained using BC is unaware of how its action can affect future state distribution. The second method using IRL involves recovering a reward signal under which the expert is uniquely optimal; the main disadvantage is that it’s an ill-posed problem.<br />
<br />
To address the problem of imitating an expert policy, techniques based on Generative Adversarial Networks (GANs) have been proposed in recent years. GANs use a discriminator to guide the generative model towards producing patterns like those of the expert. This idea was used by Ho & Ermon (2016) in Generative Adversarial Imitation Learning (GAIL) to imitate an expert policy in a model-free setup. In a model-free setup, the transition function from one state to the next is not learned, so the agent cannot predict the next state and reward before it takes an action. The disadvantage of GAIL’s model-free approach is that backpropagation requires gradient estimation, which tends to suffer from high variance and therefore calls for large sample sizes and variance reduction methods. This paper proposes a model-based method (MGAIL) that addresses these issues by training the policy with gradient information propagated from the discriminator to the generator.<br />
<br />
= Background =<br />
== Markov Decision Process ==<br />
Consider an infinite-horizon discounted Markov decision process (MDP), defined by the tuple <math>(S, A, P, r, \rho_0, \gamma)</math> where <math>S</math> is the set of states, <math>A</math> is the set of actions, <math>P : S \times A \times S \rightarrow [0, 1]</math> is the transition probability distribution, <math>r : (S \times A) \rightarrow \mathbb{R}</math> is the reward function, <math>\rho_0 : S \rightarrow [0, 1]</math> is the distribution over initial states, and <math>\gamma \in (0, 1)</math> is the discount factor. Let <math>\pi</math> denote a stochastic policy <math>\pi : S \times A \rightarrow [0, 1]</math>, <math>R(\pi)</math> denote its expected discounted reward: <math>E_\pi R = E_\pi [\sum_{t=0}^T \gamma^t r_t]</math> and <math>\tau</math> denote a trajectory of states and actions <math>\tau = \{s_0, a_0, s_1, a_1, ...\}</math>.<br />
<br />
== Imitation Learning ==<br />
A common technique for performing imitation learning is to train a policy <math> \pi </math> that minimizes some loss function <math> l(s, \pi(s)) </math> with respect to a discounted state distribution encountered by the expert: <math> d_\pi(s) = (1-\gamma)\sum_{t=0}^{\infty}\gamma^t p(s_t) </math>. This can be obtained using any supervised learning (SL) algorithm, but the policy's prediction affects future state distributions; this violates the independent and identically distributed (i.i.d) assumption made by most SL algorithms. This process is susceptible to compounding errors since a slight deviation in the learner's behavior can lead to different state distributions not encountered by the expert policy. <br />
<br />
This issue was overcome through the use of the Forward Training (FT) algorithm which trains a non-stationary policy iteratively over time. At each time step a new policy is trained on the state distribution induced by the previously trained policies. This is continued till the end of the time horizon to obtain a policy that can mimic the expert policy. This requirement to train a policy at each time step till the end makes the FT algorithm impractical for cases where the time horizon is very large or undefined. This shortcoming is resolved using the Stochastic Mixing Iterative Learning (SMILe) algorithm. SMILe trains a stochastic stationary policy over several iterations under the trajectory distribution induced by the previously trained policy.<br />
<br />
== Generative Adversarial Networks ==<br />
GANs learn a generative model that can fool the discriminator by using a two-player zero-sum game:<br />
<br />
\begin{align} <br />
\underset{G}{\operatorname{argmin}}\; \underset{D\in (0,1)}{\operatorname{argmax}}\; \mathbb{E}_{x\sim p_E}[\log(D(x))]\ +\ \mathbb{E}_{z\sim p_z}[\log(1 - D(G(z)))]<br />
\end{align}<br />
<br />
In the above equation, <math> p_E </math> represents the expert distribution and <math> p_z </math> represents the input noise distribution from which the input to the generator is sampled. The generator produces patterns and the discriminator judges if the pattern was generated or from the expert data. When the discriminator cannot distinguish between the two distributions the game ends and the generator has learned to mimic the expert. GANs rely on basic ideas such as binary classification and algorithms such as backpropagation in order to learn the expert distribution.<br />
<br />
GAIL applies GANs to the task of imitating an expert policy in a model-free approach. GAIL uses similar objective functions like GANs, but the expert distribution in GAIL represents the joint distribution over state action tuples:<br />
<br />
\begin{align} <br />
\underset{\pi}{\operatorname{argmin}}\; \underset{D\in (0,1)}{\operatorname{argmax}}\; \mathbb{E}_{\pi}[\log(D(s,a))]\ +\ \mathbb{E}_{\pi_E}[\log(1 - D(s,a))] - \lambda H(\pi)<br />
\end{align}<br />
<br />
where <math> H(\pi) \triangleq \mathbb{E}_{\pi}[-log\: \pi(a|s)]</math> is the entropy.<br />
<br />
This problem cannot be solved using the standard methods described for GANs because the generator in GAIL represents a stochastic policy. The exact form of the first term in the above equation is given by: <math> \mathbb{E}_{s\sim \rho_\pi(s)}\mathbb{E}_{a\sim \pi(\cdot |s)} [\log(D(s,a))] </math>.<br />
<br />
The two-player game now depends on the stochastic properties (<math> \theta </math>) of the policy, and it is unclear how to differentiate the above equation with respect to <math> \theta </math>. This problem can be overcome using score functions such as REINFORCE to obtain an unbiased gradient estimation:<br />
<br />
\begin{align}<br />
\nabla_\theta\mathbb{E}_{\pi} [log\; D(s,a)] \cong \hat{\mathbb{E}}_{\tau_i}[\nabla_\theta\; log\; \pi_\theta(a|s)Q(s,a)]<br />
\end{align}<br />
<br />
where <math> Q(\hat{s},\hat{a}) </math> is the score function of the gradient:<br />
<br />
\begin{align}<br />
Q(\hat{s},\hat{a}) = \hat{\mathbb{E}}_{\tau_i}[log\; D(s,a) | s_0 = \hat{s}, a_0 = \hat{a}]<br />
\end{align}<br />
<br />
<br />
REINFORCE gradients suffer from high variance which makes them difficult to work with even after applying variance reduction techniques. In order to better understand the changes required to fool the discriminator we need access to the gradients of the discriminator network, which can be obtained from the Jacobian of the discriminator. This paper demonstrates the use of a forward model along with the Jacobian of the discriminator to train a policy, without using high-variance gradient estimations.<br />
<br />
= Algorithm =<br />
This section first analyzes the characteristics of the discriminator network, then describes how a forward model can enable policy imitation through GANs. Lastly, the model based adversarial imitation learning algorithm is presented.<br />
<br />
== The discriminator network ==<br />
The discriminator network is trained to predict the conditional distribution: <math> D(s,a) = p(y|s,a) </math> where <math> y \in \{\pi_E, \pi\} </math>.<br />
<br />
The discriminator is trained on an even distribution of expert and generated examples; hence <math> p(\pi) = p(\pi_E) = \frac{1}{2} </math>. Given this and applying Bayes' theorem, we can rearrange and factor <math> D(s,a) </math> to obtain:<br />
<br />
\begin{aligned}<br />
D(s,a) &= p(\pi|s,a) \\<br />
& = \frac{p(s,a|\pi)p(\pi)}{p(s,a|\pi)p(\pi) + p(s,a|\pi_E)p(\pi_E)} \\<br />
& = \frac{p(s,a|\pi)}{p(s,a|\pi) + p(s,a|\pi_E)} \\<br />
& = \frac{1}{1 + \frac{p(s,a|\pi_E)}{p(s,a|\pi)}} \\<br />
& = \frac{1}{1 + \frac{p(a|s,\pi_E)}{p(a|s,\pi)} \cdot \frac{p(s|\pi_E)}{p(s|\pi)}} \\<br />
\end{aligned}<br />
<br />
Define <math> \varphi(s,a) </math> and <math> \psi(s) </math> to be:<br />
<br />
\begin{aligned}<br />
\varphi(s,a) = \frac{p(a|s,\pi_E)}{p(a|s,\pi)}, \psi(s) = \frac{p(s|\pi_E)}{p(s|\pi)}<br />
\end{aligned}<br />
<br />
to get the final expression for <math> D(s,a) </math>:<br />
\begin{aligned}<br />
D(s,a) = \frac{1}{1 + \varphi(s,a)\cdot \psi(s)}<br />
\end{aligned}<br />
<br />
<math> \varphi(s,a) </math> represents a policy likelihood ratio, and <math> \psi(s) </math> represents a state distribution likelihood ratio. Based on these expressions, the paper states that the discriminator makes its decisions by answering two questions. The first question relates to state distribution: what is the likelihood of encountering state <math> s </math> under the distribution induced by <math> \pi_E </math> vs <math> \pi </math>? The second question is about behavior: given a state <math> s </math>, how likely is action <math> a </math> under <math> \pi_E </math> vs <math> \pi </math>? The desired change in state is given by <math> \psi_s \equiv \partial \psi / \partial s </math>; this information can be obtained from the partial derivatives of <math> D(s,a) </math>, which is why these derivatives are proposed to be used for training policies (see following sections):<br />
<br />
\begin{aligned}<br />
\nabla_aD &= - \frac{\varphi_a(s,a)\psi(s)}{(1 + \varphi(s,a)\psi(s))^2} \\<br />
\nabla_sD &= - \frac{\varphi_s(s,a)\psi(s) + \varphi(s,a)\psi_s(s)}{(1 + \varphi(s,a)\psi(s))^2} \\<br />
\end{aligned}<br />
<br />
== Backpropagating through stochastic units ==<br />
There is interest in training stochastic policies because stochasticity encourages exploration for Policy Gradient methods. This is a problem for algorithms that build differentiable computation graphs where the gradients flow from one component to another since it is unclear how to backpropagate through stochastic units. The following subsections show how to estimate the gradients of continuous and categorical stochastic elements for continuous and discrete action domains respectively.<br />
<br />
=== Continuous Action Distributions ===<br />
In the case of continuous action policies, re-parameterization was used to enable computing the derivatives of stochastic models. Assuming that the stochastic policy has a Gaussian distribution <math> \mathcal{N}(\mu_{\theta} (s), \sigma_{\theta}^2 (s))</math>, where the mean and variance are given by some deterministic functions <math>\mu_{\theta}</math> and <math>\sigma_{\theta}</math>, then the policy <math> \pi </math> can be written as <math> \pi_\theta(a|s) = \mu_\theta(s) + \xi \sigma_\theta(s) </math>, where <math> \xi \sim N(0,1) </math>. This way, the authors are able to get a Monte-Carlo estimator of the derivative of the expected value of <math> D(s, a) </math> with respect to <math> \theta </math>:<br />
<br />
\begin{align}<br />
\nabla_\theta\mathbb{E}_{\pi(a|s)}D(s,a) = \mathbb{E}_{\rho (\xi )}\nabla_a D(s,a) \nabla_\theta \pi_\theta(a|s) \cong \frac{1}{M}\sum_{i=1}^{M} \nabla_a D(s,a) \nabla_\theta \pi_\theta(a|s)\Bigr|_{\substack{\xi=\xi_i}}<br />
\end{align}<br />
<br />
=== Categorical Action Distributions ===<br />
In the case of discrete action domains, the paper uses categorical re-parameterization with Gumbel-Softmax. This method relies on the Gumbel-Max trick which is a method for drawing samples from a categorical distribution with class probabilities <math> \pi(a_1|s),\pi(a_2|s),...,\pi(a_N|s) </math>:<br />
<br />
\begin{align}<br />
a_{argmax} = \underset{i}{argmax}[g_i + log\ \pi(a_i|s)]\textrm{, where } g_i \sim Gumbel(0, 1).<br />
\end{align}<br />
<br />
Gumbel-Softmax provides a differentiable approximation of the samples obtained using the Gumbel-Max trick:<br />
<br />
\begin{align}<br />
a_{softmax} = \frac{\exp[\frac{1}{\tau}(g_i + \log \pi(a_i|s))]}{\sum_{j=1}^{N}\exp[\frac{1}{\tau}(g_j + \log \pi(a_j|s))]}<br />
\end{align}<br />
<br />
<br />
In the above equation, the hyper-parameter <math> \tau </math> (temperature) trades bias for variance. As <math> \tau </math> approaches zero, the softmax operator acts like argmax, resulting in low bias but high variance; conversely, a large <math> \tau </math> gives higher bias but lower variance.<br />
<br />
The authors use <math> a_{softmax} </math> to interact with the environment: in the forward pass, argmax is applied over <math> a_{softmax} </math> to obtain a single “pure” action, while the continuous approximation is used in the backward pass through the estimate <math> \nabla_\theta\; a_{argmax} \approx \nabla_\theta\; a_{softmax} </math>.<br />
<br />
== Backpropagating through a Forward model ==<br />
The above subsections presented the means for extracting the partial derivative <math> \nabla_aD </math>. The main contribution of this paper is incorporating the use of <math> \nabla_sD </math>. In a model-free approach the state <math> s </math> is treated as a fixed input, therefore <math> \nabla_sD </math> is discarded. This is illustrated in Figure 1. This work uses a model-based approach which makes incorporating <math> \nabla_sD </math> more involved. In the model-based approach, a state <math> s_t </math> can be written as a function of the previous state action pair: <math> s_t = f(s_{t-1}, a_{t-1}) </math>, where <math> f </math> represents the forward model. Using the forward model and the law of total derivatives we get:<br />
<br />
\begin{align}<br />
\nabla_\theta D(s_t,a_t)\Bigr|_{\substack{s=s_t, a=a_t}} &= \frac{\partial D}{\partial a}\frac{\partial a}{\partial \theta}\Bigr|_{\substack{a=a_t}} + \frac{\partial D}{\partial s}\frac{\partial s}{\partial \theta}\Bigr|_{\substack{s=s_t}} \\<br />
&= \frac{\partial D}{\partial a}\frac{\partial a}{\partial \theta}\Bigr|_{\substack{a=a_t}} + \frac{\partial D}{\partial s}\left (\frac{\partial f}{\partial s}\frac{\partial s}{\partial \theta}\Bigr|_{\substack{s=s_{t-1}}} + \frac{\partial f}{\partial a}\frac{\partial a}{\partial \theta}\Bigr|_{\substack{a=a_{t-1}}} \right )<br />
\end{align}<br />
<br />
<br />
Using this formula, the error regarding deviations of future states <math> (\psi_s) </math> propagates back in time and influences the actions of policies at earlier times. This is summarized in Figure 2.<br />
<br />
[[File:modelFree_blockDiagram.PNG|300px]]<br />
<br />
Figure 1: Block diagram of the model-free approach: given a state <math> s </math>, the policy outputs <math> \mu </math> which is fed to a stochastic sampling unit. An action <math> a </math> is sampled and, together with <math> s </math>, is presented to the discriminator network. In the backward phase, the error message <math> \delta_a </math> is blocked at the stochastic sampling unit. From there, a high-variance gradient estimation is used (<math> \delta_{HV} </math>). Meanwhile, the error message <math> \delta_s </math> is flushed.<br />
<br />
[[File:modelBased_blockDiagram.PNG|1000px]]<br />
<br />
Figure 2: Block diagram of model-based adversarial imitation learning. This diagram describes the computation graph for training the policy (i.e. G). The discriminator network D is fixed at this stage and is trained separately. At time <math> t </math> of the forward pass, <math> \pi </math> outputs a distribution over actions: <math> \mu_t = \pi(s_t) </math>, from which an action <math> a_t </math> is sampled. For example, in the continuous case, this is done using the re-parametrization trick: <math> a_t = \mu_t + \xi \cdot \sigma </math>, where <math> \xi \sim N(0,1) </math>. The next state <math> s_{t+1} = f(s_t, a_t) </math> is computed using the forward model (which is also trained separately), and the entire process repeats for time <math> t+1 </math>. In the backward pass, the gradient of <math> \pi </math> consists of (a) the error message <math> \delta_a </math> (Green), which propagates smoothly through the differentiable approximation of the sampling process, and (b) the error message <math> \delta_s </math> (Blue) of future time-steps, which propagates back through the differentiable forward model.<br />
<br />
== MGAIL Algorithm ==<br />
Shalev-Shwartz et al. (2016) and Heess et al. (2015) built a multi-step computation graph for describing the familiar policy gradient objective; in this case it is given by:<br />
<br />
\begin{align}<br />
J(\theta) = \mathbb{E}\left [ \sum_{t=0}^{T} \gamma ^t D(s_t,a_t)|\theta\right ]<br />
\end{align}<br />
<br />
<br />
Using the results from Heess et al. (2015), this paper demonstrates how to differentiate <math> J(\theta) </math> over a trajectory of <math>(s,a,s') </math> transitions:<br />
<br />
\begin{align}<br />
J_s &= \mathbb{E}_{p(a|s)}\mathbb{E}_{p(s'|s,a)}\left [ D_s + D_a \pi_s + \gamma J'_{s'}(f_s + f_a \pi_s) \right] \\<br />
J_\theta &= \mathbb{E}_{p(a|s)}\mathbb{E}_{p(s'|s,a)}\left [ D_a \pi_\theta + \gamma (J'_{s'} f_a \pi_\theta + J'_\theta) \right]<br />
\end{align}<br />
<br />
The policy gradient <math> \nabla_\theta J </math> is calculated by applying the two equations above recursively for <math> T </math> iterations. The MGAIL algorithm is presented below.<br />
<br />
[[File:MGAIL_alg.PNG]]<br />
<br />
== Forward Model Structure ==<br />
The stability of the learning process depends on the prediction accuracy of the forward model, but learning an accurate forward model is challenging by itself. The authors propose methods for improving the performance of the forward model based on two aspects of its functionality. First, the forward model should learn to use the action as an operator over the state space. To accomplish this, the actions and states, which are sampled from different distributions, need to be first represented in a shared space. This is done by encoding the state and action using two separate neural networks and combining their outputs to form a single vector. Second, multiple previous states are used to predict the next state by representing the environment as an <math> n^{th} </math> order MDP. A gated recurrent unit (GRU, a simpler variant of the LSTM model) layer is incorporated into the state encoder to enable recurrent connections from previous states. Using these modifications, the model is able to achieve better and more stable results compared to the standard forward model based on a feed-forward neural network. The comparison is presented in Figure 3.<br />
<br />
[[File:performance_comparison.PNG]]<br />
<br />
Figure 3: Performance comparison between a basic forward model (Blue), and the advanced forward model (Green).<br />
<br />
= Experiments =<br />
The proposed algorithm is evaluated on three discrete control tasks (Cartpole, Mountain-Car, Acrobot), and five continuous control tasks (Hopper, Walker, Half-Cheetah, Ant, and Humanoid), which are modeled by the MuJoCo physics simulator (Todorov et al., 2012). Expert policies are trained using the Trust Region Policy Optimization (TRPO) algorithm (Schulman et al., 2015). Different numbers of trajectories are used to train the expert for each task, but all trajectories are of length 1000.<br />
The discriminator and generator (policy) networks contain two hidden layers with ReLU non-linearity and are trained using the ADAM optimizer. The total reward received over a period of <math> N </math> steps using BC, GAIL and MGAIL is presented in Table 1. The proposed algorithm achieved the highest reward for most environments while exhibiting performance comparable to the expert over all of them.<br />
<br />
[[File:mgail_test_results.PNG]]<br />
<br />
Table 1. Policy performance, boldface indicates better results, <math> \pm </math> represents one standard deviation.<br />
<br />
= Discussion =<br />
This paper presented a model-based algorithm for imitation learning. It demonstrated how a forward model can be used to train policies using the exact gradient of the discriminator network. A downside of this approach is the need to learn a forward model, since this could be difficult in certain domains. Learning the system dynamics directly from raw images is considered as one line of future work. Another line of future work is to address the violation of the fundamental assumption made by all supervised learning algorithms, which requires the data to be i.i.d. This problem arises because the discriminator and forward models are trained in a supervised learning fashion using data sampled from a dynamic distribution.<br />
<br />
= Source =<br />
# Baram, Nir, et al. "End-to-end differentiable adversarial imitation learning." International Conference on Machine Learning. 2017.<br />
# Ho, Jonathan, and Stefano Ermon. "Generative adversarial imitation learning." Advances in Neural Information Processing Systems. 2016.<br />
# Shalev-Shwartz, Shai, et al. "Long-term planning by short-term prediction." arXiv preprint arXiv:1602.01580 (2016).<br />
# Heess, Nicolas, et al. "Learning continuous control policies by stochastic value gradients." Advances in Neural Information Processing Systems. 2015.<br />
# Schulman, John, et al. "Trust region policy optimization." International Conference on Machine Learning. 2015.</div>X249wanghttp://wiki.math.uwaterloo.ca/statwiki/index.php?title=End-to-End_Differentiable_Adversarial_Imitation_Learning&diff=35027End-to-End Differentiable Adversarial Imitation Learning2018-03-21T19:24:49Z<p>X249wang: /* Continuous Action Distributions */</p>
<hr />
<div>= Introduction =<br />
The ability to imitate an expert policy is very beneficial in the case of automating human demonstrated tasks. Assuming that a sequence of state action pairs (trajectories) of an expert policy are available, a new policy can be trained that imitates the expert without having access to the original reward signal used by the expert. There are two main approaches to solve the problem of imitating a policy; they are Behavioural Cloning (BC) and Inverse Reinforcement Learning (IRL). BC directly learns the conditional distribution of actions over states in a supervised fashion by training on single time-step state-action pairs. The disadvantage of BC is that the training requires large amounts of expert data, which is hard to obtain. In addition, an agent trained using BC is unaware of how its action can affect future state distribution. The second method using IRL involves recovering a reward signal under which the expert is uniquely optimal; the main disadvantage is that it’s an ill-posed problem.<br />
<br />
To address the problem of imitating an expert policy, techniques based on Generative Adversarial Networks (GANs) have been proposed in recent years. GANs use a discriminator to guide the generative model towards producing patterns like those of the expert. This idea was used by (Ho & Ermon, 2016) in their work titled Generative Adversarial Imitation Learning (GAIL) to imitate an expert policy in a model-free setup. A model-free setup is one where the agent cannot predict the next state and reward before it takes an action, since the transition function from one state to the next is not learned. The disadvantage of GAIL’s model-free approach is that training relies on gradient estimation, which tends to suffer from high variance and in turn calls for large sample sizes and variance reduction methods. This paper proposes a model-based method (MGAIL) that addresses these issues by training the policy using the information propagated from the discriminator to the generator.<br />
<br />
= Background =<br />
== Markov Decision Process ==<br />
Consider an infinite-horizon discounted Markov decision process (MDP), defined by the tuple <math>(S, A, P, r, \rho_0, \gamma)</math> where <math>S</math> is the set of states, <math>A</math> is a set of actions, <math>P : S \times A \times S \rightarrow [0, 1]</math> is the transition probability distribution, <math>r : S \times A \rightarrow \mathbb{R}</math> is the reward function, <math>\rho_0 : S \rightarrow [0, 1]</math> is the distribution over initial states, and <math>\gamma \in (0, 1)</math> is the discount factor. Let <math>\pi</math> denote a stochastic policy <math>\pi : S \times A \rightarrow [0, 1]</math>, <math>R(\pi)</math> denote its expected discounted reward: <math>\mathbb{E}_\pi R = \mathbb{E}_\pi [\sum_{t=0}^T \gamma^t r_t]</math> and <math>\tau</math> denote a trajectory of states and actions <math>\tau = \{s_0, a_0, s_1, a_1, ...\}</math>.<br />
<br />
== Imitation Learning ==<br />
A common technique for performing imitation learning is to train a policy <math> \pi </math> that minimizes some loss function <math> l(s, \pi(s)) </math> with respect to a discounted state distribution encountered by the expert: <math> d_\pi(s) = (1-\gamma)\sum_{t=0}^{\infty}\gamma^t p(s_t) </math>. This can be obtained using any supervised learning (SL) algorithm, but the policy's prediction affects future state distributions; this violates the independent and identically distributed (i.i.d.) assumption made by most SL algorithms. This process is susceptible to compounding errors since a slight deviation in the learner's behavior can lead to different state distributions not encountered by the expert policy. <br />
<br />
This issue was overcome through the use of the Forward Training (FT) algorithm, which trains a non-stationary policy iteratively over time. At each time step a new policy is trained on the state distribution induced by the previously trained policies. This is continued until the end of the time horizon to obtain a policy that can mimic the expert policy. The requirement to train a policy at each time step makes the FT algorithm impractical for cases where the time horizon is very large or undefined. This shortcoming is resolved using the Stochastic Mixing Iterative Learning (SMILe) algorithm. SMILe trains a stochastic stationary policy over several iterations under the trajectory distribution induced by the previously trained policy.<br />
<br />
== Generative Adversarial Networks ==<br />
GANs learn a generative model that can fool the discriminator by using a two-player zero-sum game:<br />
<br />
\begin{align} <br />
\underset{G}{\operatorname{argmin}}\; \underset{D\in (0,1)}{\operatorname{argmax}}\; \mathbb{E}_{x\sim p_E}[\log D(x)]\ +\ \mathbb{E}_{z\sim p_z}[\log(1 - D(G(z)))]<br />
\end{align}<br />
<br />
In the above equation, <math> p_E </math> represents the expert distribution and <math> p_z </math> represents the input noise distribution from which the input to the generator is sampled. The generator produces patterns and the discriminator judges if the pattern was generated or from the expert data. When the discriminator cannot distinguish between the two distributions the game ends and the generator has learned to mimic the expert. GANs rely on basic ideas such as binary classification and algorithms such as backpropagation in order to learn the expert distribution.<br />
<br />
GAIL applies GANs to the task of imitating an expert policy in a model-free approach. GAIL uses an objective function similar to that of GANs, but the expert distribution in GAIL represents the joint distribution over state-action tuples:<br />
<br />
\begin{align} <br />
\underset{\pi}{\operatorname{argmin}}\; \underset{D\in (0,1)}{\operatorname{argmax}}\; \mathbb{E}_{\pi}[\log D(s,a)]\ +\ \mathbb{E}_{\pi_E}[\log(1 - D(s,a))] - \lambda H(\pi)<br />
\end{align}<br />
<br />
where <math> H(\pi) \triangleq \mathbb{E}_{\pi}[-log\: \pi(a|s)]</math> is the entropy.<br />
<br />
This problem cannot be solved using the standard methods described for GANs because the generator in GAIL represents a stochastic policy. The exact form of the first term in the above equation is given by: <math> \mathbb{E}_{s\sim \rho_\pi(s)}\mathbb{E}_{a\sim \pi(\cdot |s)} [\log D(s,a)] </math>.<br />
<br />
The two-player game now depends on the stochastic properties (<math> \theta </math>) of the policy, and it is unclear how to differentiate the above equation with respect to <math> \theta </math>. This problem can be overcome using score functions such as REINFORCE to obtain an unbiased gradient estimation:<br />
<br />
\begin{align}<br />
\nabla_\theta\mathbb{E}_{\pi} [log\; D(s,a)] \cong \hat{\mathbb{E}}_{\tau_i}[\nabla_\theta\; log\; \pi_\theta(a|s)Q(s,a)]<br />
\end{align}<br />
<br />
where <math> Q(\hat{s},\hat{a}) </math> is the score function of the gradient:<br />
<br />
\begin{align}<br />
Q(\hat{s},\hat{a}) = \hat{\mathbb{E}}_{\tau_i}[log\; D(s,a) | s_0 = \hat{s}, a_0 = \hat{a}]<br />
\end{align}<br />
<br />
<br />
REINFORCE gradients suffer from high variance which makes them difficult to work with even after applying variance reduction techniques. In order to better understand the changes required to fool the discriminator we need access to the gradients of the discriminator network, which can be obtained from the Jacobian of the discriminator. This paper demonstrates the use of a forward model along with the Jacobian of the discriminator to train a policy, without using high-variance gradient estimations.<br />
<br />
= Algorithm =<br />
This section first analyzes the characteristics of the discriminator network, then describes how a forward model can enable policy imitation through GANs. Lastly, the model based adversarial imitation learning algorithm is presented.<br />
<br />
== The discriminator network ==<br />
The discriminator network is trained to predict the conditional distribution: <math> D(s,a) = p(y|s,a) </math> where <math> y \in \{\pi_E, \pi\} </math>.<br />
<br />
The discriminator is trained on an even distribution of expert and generated examples; hence <math> p(\pi) = p(\pi_E) = \frac{1}{2} </math>. Given this, we can rearrange and factor <math> D(s,a) </math> to obtain:<br />
<br />
\begin{aligned}<br />
D(s,a) &= p(\pi|s,a) \\<br />
& = \frac{p(s,a|\pi)p(\pi)}{p(s,a|\pi)p(\pi) + p(s,a|\pi_E)p(\pi_E)} \\<br />
& = \frac{p(s,a|\pi)}{p(s,a|\pi) + p(s,a|\pi_E)} \\<br />
& = \frac{1}{1 + \frac{p(s,a|\pi_E)}{p(s,a|\pi)}} \\<br />
& = \frac{1}{1 + \frac{p(a|s,\pi_E)}{p(a|s,\pi)} \cdot \frac{p(s|\pi_E)}{p(s|\pi)}} \\<br />
\end{aligned}<br />
<br />
Define <math> \varphi(s,a) </math> and <math> \psi(s) </math> to be:<br />
<br />
\begin{aligned}<br />
\varphi(s,a) = \frac{p(a|s,\pi_E)}{p(a|s,\pi)}, \psi(s) = \frac{p(s|\pi_E)}{p(s|\pi)}<br />
\end{aligned}<br />
<br />
to get the final expression for <math> D(s,a) </math>:<br />
\begin{aligned}<br />
D(s,a) = \frac{1}{1 + \varphi(s,a)\cdot \psi(s)}<br />
\end{aligned}<br />
<br />
<math> \varphi(s,a) </math> represents a policy likelihood ratio, and <math> \psi(s) </math> represents a state distribution likelihood ratio. Based on these expressions, the paper states that the discriminator makes its decisions by answering two questions. The first question relates to state distribution: what is the likelihood of encountering state <math> s </math> under the distribution induced by <math> \pi_E </math> vs <math> \pi </math>? The second question is about behavior: given a state <math> s </math>, how likely is action <math> a </math> under <math> \pi_E </math> vs <math> \pi </math>? The desired change in state is given by <math> \psi_s \equiv \partial \psi / \partial s </math>; this information can be obtained from the partial derivatives of <math> D(s,a) </math>:<br />
<br />
\begin{aligned}<br />
\nabla_aD &= - \frac{\varphi_a(s,a)\psi(s)}{(1 + \varphi(s,a)\psi(s))^2} \\<br />
\nabla_sD &= - \frac{\varphi_s(s,a)\psi(s) + \varphi(s,a)\psi_s(s)}{(1 + \varphi(s,a)\psi(s))^2} \\<br />
\end{aligned}<br />
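<br />
The factorization and its partial derivatives can be checked numerically. The sketch below is only an illustration: the Gaussian state and action densities are hypothetical stand-ins chosen for this example (they do not come from the paper), and the derivatives of <math> \varphi </math> and <math> \psi </math> are taken by central finite differences.<br />
<br />
```python
import math


def normal_pdf(x, mean, std=1.0):
    """Density of a univariate Gaussian."""
    z = (x - mean) / std
    return math.exp(-0.5 * z * z) / (std * math.sqrt(2.0 * math.pi))


# Hypothetical toy densities, assumed only for this illustration.
def state_density(s, expert):
    """p(s | pi_E) if expert else p(s | pi)."""
    return normal_pdf(s, 0.5 if expert else 0.0)


def action_density(a, s, expert):
    """p(a | s, pi_E) if expert else p(a | s, pi)."""
    return normal_pdf(a, (1.0 if expert else 0.6) * s)


def phi(s, a):
    """Policy likelihood ratio."""
    return action_density(a, s, True) / action_density(a, s, False)


def psi(s):
    """State distribution likelihood ratio."""
    return state_density(s, True) / state_density(s, False)


def D_factored(s, a):
    """D(s,a) = 1 / (1 + phi * psi)."""
    return 1.0 / (1.0 + phi(s, a) * psi(s))


def D_bayes(s, a):
    """D(s,a) = p(pi | s,a) from Bayes' rule with equal priors."""
    joint_pi = action_density(a, s, False) * state_density(s, False)
    joint_e = action_density(a, s, True) * state_density(s, True)
    return joint_pi / (joint_pi + joint_e)


def central_diff(fn, x, eps=1e-5):
    return (fn(x + eps) - fn(x - eps)) / (2.0 * eps)


def grad_a_formula(s, a):
    phi_a = central_diff(lambda x: phi(s, x), a)
    return -phi_a * psi(s) / (1.0 + phi(s, a) * psi(s)) ** 2


def grad_s_formula(s, a):
    phi_s = central_diff(lambda x: phi(x, a), s)
    psi_s = central_diff(psi, s)
    numerator = phi_s * psi(s) + phi(s, a) * psi_s
    return -numerator / (1.0 + phi(s, a) * psi(s)) ** 2


s, a = 0.3, -0.2
print("D (factored) :", D_factored(s, a), " D (Bayes):", D_bayes(s, a))
print("grad_a D     :", grad_a_formula(s, a))
print("grad_s D     :", grad_s_formula(s, a))
```
<br />
Differentiating <code>D_factored</code> directly with <code>central_diff</code> should agree with both formulas up to finite-difference error.<br />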
<br />
<br />
== Backpropagating through stochastic units ==<br />
There is interest in training stochastic policies because stochasticity encourages exploration for Policy Gradient methods. This is a problem for algorithms that build differentiable computation graphs where the gradients flow from one component to another since it is unclear how to backpropagate through stochastic units. The following subsections show how to estimate the gradients of continuous and categorical stochastic elements for continuous and discrete action domains respectively.<br />
<br />
=== Continuous Action Distributions ===<br />
In the case of continuous action policies, re-parameterization is used to enable computing the derivatives of stochastic models. Assuming that the stochastic policy has a Gaussian distribution <math> \mathcal{N}(\mu_{\theta} (s), \sigma_{\theta}^2 (s))</math>, where the mean and standard deviation are given by deterministic functions <math>\mu_{\theta}</math> and <math>\sigma_{\theta}</math>, a sample from the policy <math> \pi </math> can be written as <math> \pi_\theta(a|s) = \mu_\theta(s) + \xi \sigma_\theta(s) </math>, where <math> \xi \sim N(0,1) </math>. Because the randomness now enters only through <math> \xi </math>, which does not depend on <math> \theta </math>, the authors are able to get a Monte-Carlo estimator of the derivative of the expected value of <math> D(s, a) </math> with respect to <math> \theta </math>:<br />
<br />
\begin{align}<br />
\nabla_\theta\mathbb{E}_{\pi(a|s)}D(s,a) = \mathbb{E}_{\rho (\xi )}\nabla_a D(s,a) \nabla_\theta \pi_\theta(a|s) \cong \frac{1}{M}\sum_{i=1}^{M} \nabla_a D(s,a) \nabla_\theta \pi_\theta(a|s)\Bigr|_{\substack{\xi=\xi_i}}<br />
\end{align}<br />
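<br />
To make the estimator concrete, the following sketch compares the re-parameterized gradient with a score-function (REINFORCE-style) gradient on a one-dimensional toy problem. Everything here is an assumption made for illustration: <code>D</code> is a hypothetical smooth function standing in for the discriminator, the state is dropped, and the only parameter is the mean <math> \mu </math>, so the exact gradient of <math> \mathbb{E}[a^2] = \mu^2 + \sigma^2 </math> is <math> 2\mu </math>.<br />
<br />
```python
import random
import statistics


def D(a):
    """Hypothetical smooth stand-in for the discriminator output."""
    return a * a


def dD_da(a):
    return 2.0 * a


def gradient_samples(mu, sigma, num_samples, seed=0):
    """Per-sample estimates of d/d(mu) E[D(a)] for a ~ N(mu, sigma^2)."""
    rng = random.Random(seed)
    reparam, score = [], []
    for _ in range(num_samples):
        xi = rng.gauss(0.0, 1.0)
        a = mu + xi * sigma  # re-parameterized sample
        # Re-parameterization: grad_a D times d(a)/d(mu), where d(a)/d(mu) = 1.
        reparam.append(dD_da(a) * 1.0)
        # Score function: D(a) times d/d(mu) log N(a; mu, sigma^2).
        score.append(D(a) * (a - mu) / sigma ** 2)
    return reparam, score


reparam, score = gradient_samples(mu=1.0, sigma=0.5, num_samples=20000)
print("exact gradient       :", 2.0 * 1.0)
print("re-parameterized mean:", statistics.mean(reparam),
      "variance:", statistics.variance(reparam))
print("score-function mean  :", statistics.mean(score),
      "variance:", statistics.variance(score))
```
<br />
Both estimators target the same gradient, but in this toy setting the per-sample variance of the re-parameterized one is much smaller, which is the motivation given in the text for avoiding score-function estimates.<br />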
<br />
=== Categorical Action Distributions ===<br />
In the case of discrete action domains, the paper uses categorical re-parameterization with Gumbel-Softmax. This method relies on the Gumbel-Max trick which is a method for drawing samples from a categorical distribution with class probabilities <math> \pi(a_1|s),\pi(a_2|s),...,\pi(a_N|s) </math>:<br />
<br />
\begin{align}<br />
a_{argmax} = \underset{i}{argmax}[g_i + log\ \pi(a_i|s)]\textrm{, where } g_i \sim Gumbel(0, 1).<br />
\end{align}<br />
<br />
Gumbel-Softmax provides a differentiable approximation of the samples obtained using the Gumbel-Max trick:<br />
<br />
\begin{align}<br />
a_{softmax} = \frac{\exp[\frac{1}{\tau}(g_i + \log \pi(a_i|s))]}{\sum_{j=1}^{N}\exp[\frac{1}{\tau}(g_j + \log \pi(a_j|s))]}<br />
\end{align}<br />
<br />
<br />
In the above equation, the hyper-parameter <math> \tau </math> (temperature) trades bias for variance. As <math> \tau </math> approaches zero, the softmax operator acts like argmax, resulting in low bias but high variance; conversely, a large <math> \tau </math> gives higher bias but lower variance.<br />
<br />
The authors use <math> a_{softmax} </math> to interact with the environment: in the forward pass, argmax is applied over <math> a_{softmax} </math> to obtain a single “pure” action, while the continuous approximation is used in the backward pass through the estimate <math> \nabla_\theta\; a_{argmax} \approx \nabla_\theta\; a_{softmax} </math>.<br />
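<br />
A minimal sketch of the two sampling rules is given below, using only the Python standard library. The class probabilities and temperatures are hypothetical values picked for illustration, and the helper names are not taken from the paper or from any library.<br />
<br />
```python
import math
import random


def sample_gumbel(rng):
    """Draw g ~ Gumbel(0, 1) by inverse-transform sampling."""
    u = max(rng.random(), 1e-12)  # guard against log(0)
    return -math.log(-math.log(u))


def gumbel_max(log_probs, noise):
    """Index of the sampled action: argmax_i [g_i + log pi(a_i|s)]."""
    scores = [g + lp for g, lp in zip(noise, log_probs)]
    return max(range(len(scores)), key=scores.__getitem__)


def gumbel_softmax(log_probs, noise, tau):
    """Relaxed one-hot sample; it sharpens towards one-hot as tau -> 0."""
    logits = [(g + lp) / tau for g, lp in zip(noise, log_probs)]
    top = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - top) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


rng = random.Random(0)
probs = [0.2, 0.5, 0.3]  # hypothetical pi(a_i | s)
log_probs = [math.log(p) for p in probs]
noise = [sample_gumbel(rng) for _ in probs]

action = gumbel_max(log_probs, noise)                    # discrete action (forward pass)
soft_hot = gumbel_softmax(log_probs, noise, tau=1.0)     # smooth relaxation
soft_cold = gumbel_softmax(log_probs, noise, tau=0.05)   # close to one-hot
print("sampled action index:", action)
print("tau = 1.0 :", soft_hot)
print("tau = 0.05:", soft_cold)
```
<br />
With the same noise, the largest entry of the relaxed sample sits at the index chosen by Gumbel-Max, and lowering <code>tau</code> pushes that entry towards one.<br />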
<br />
== Backpropagating through a Forward model ==<br />
The above subsections presented the means for extracting the partial derivative <math> \nabla_aD </math>. The main contribution of this paper is incorporating the use of <math> \nabla_sD </math>. In a model-free approach the state <math> s </math> is treated as a fixed input, therefore <math> \nabla_sD </math> is discarded. This is illustrated in Figure 1. This work uses a model-based approach which makes incorporating <math> \nabla_sD </math> more involved. In the model-based approach, a state <math> s_t </math> can be written as a function of the previous state action pair: <math> s_t = f(s_{t-1}, a_{t-1}) </math>, where <math> f </math> represents the forward model. Using the forward model and the law of total derivatives we get:<br />
<br />
\begin{align}<br />
\nabla_\theta D(s_t,a_t)\Bigr|_{\substack{s=s_t, a=a_t}} &= \frac{\partial D}{\partial a}\frac{\partial a}{\partial \theta}\Bigr|_{\substack{a=a_t}} + \frac{\partial D}{\partial s}\frac{\partial s}{\partial \theta}\Bigr|_{\substack{s=s_t}} \\<br />
&= \frac{\partial D}{\partial a}\frac{\partial a}{\partial \theta}\Bigr|_{\substack{a=a_t}} + \frac{\partial D}{\partial s}\left (\frac{\partial f}{\partial s}\frac{\partial s}{\partial \theta}\Bigr|_{\substack{s=s_{t-1}}} + \frac{\partial f}{\partial a}\frac{\partial a}{\partial \theta}\Bigr|_{\substack{a=a_{t-1}}} \right )<br />
\end{align}<br />
<br />
<br />
Using this formula, the error regarding deviations of future states <math> (\psi_s) </math> propagates back in time and influences the actions of policies at earlier times. This is summarized in Figure 2.<br />
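<br />
The total derivative can be illustrated on a scalar toy problem. The linear policy, linear forward model and quadratic <code>D</code> below are assumptions made purely for this sketch; the policy is deterministic, so the term for <math> a_t </math> contains both its direct dependence on <math> \theta </math> and the path through <math> s_t </math>. The analytic value can be compared with a central finite difference.<br />
<br />
```python
THETA = 0.7
S_PREV = 1.5  # s_{t-1}, treated as a fixed input


def policy(theta, s):
    """Deterministic stand-in for pi_theta (hypothetical)."""
    return theta * s


def forward_model(s, a):
    """Hypothetical differentiable dynamics f(s, a)."""
    return 0.9 * s + 0.1 * a


def D(s, a):
    """Hypothetical smooth discriminator."""
    return s * s + s * a


def D_at_next_step(theta):
    a_prev = policy(theta, S_PREV)
    s_t = forward_model(S_PREV, a_prev)
    a_t = policy(theta, s_t)
    return D(s_t, a_t)


def analytic_gradient(theta):
    a_prev = policy(theta, S_PREV)
    s_t = forward_model(S_PREV, a_prev)
    a_t = policy(theta, s_t)
    D_s, D_a = 2.0 * s_t + a_t, s_t      # partial derivatives of D
    f_s, f_a = 0.9, 0.1                  # partial derivatives of f
    ds_prev_dtheta = 0.0                 # s_{t-1} does not depend on theta here
    da_prev_dtheta = S_PREV              # d(theta * s_{t-1}) / d(theta)
    ds_t_dtheta = f_s * ds_prev_dtheta + f_a * da_prev_dtheta
    da_t_dtheta = s_t + theta * ds_t_dtheta  # direct term plus the path through s_t
    return D_a * da_t_dtheta + D_s * ds_t_dtheta


def numeric_gradient(theta, eps=1e-6):
    return (D_at_next_step(theta + eps) - D_at_next_step(theta - eps)) / (2.0 * eps)


print("analytic:", analytic_gradient(THETA))
print("numeric :", numeric_gradient(THETA))
```
<br />
Setting <code>ds_t_dtheta</code> to zero reproduces the model-free situation in which the state is treated as a fixed input and the contribution of <math> \nabla_sD </math> is lost.<br />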
<br />
[[File:modelFree_blockDiagram.PNG|300px]]<br />
<br />
Figure 1: Block diagram of the model-free approach: given a state <math> s </math>, the policy outputs <math> \mu </math> which is fed to a stochastic sampling unit. An action <math> a </math> is sampled and, together with <math> s </math>, is presented to the discriminator network. In the backward phase, the error message <math> \delta_a </math> is blocked at the stochastic sampling unit. From there, a high-variance gradient estimation is used (<math> \delta_{HV} </math>). Meanwhile, the error message <math> \delta_s </math> is flushed.<br />
<br />
[[File:modelBased_blockDiagram.PNG|1000px]]<br />
<br />
Figure 2: Block diagram of model-based adversarial imitation learning. This diagram describes the computation graph for training the policy (i.e. G). The discriminator network D is fixed at this stage and is trained separately. At time <math> t </math> of the forward pass, <math> \pi </math> outputs a distribution over actions: <math> \mu_t = \pi(s_t) </math>, from which an action <math> a_t </math> is sampled. For example, in the continuous case, this is done using the re-parametrization trick: <math> a_t = \mu_t + \xi \cdot \sigma </math>, where <math> \xi \sim N(0,1) </math>. The next state <math> s_{t+1} = f(s_t, a_t) </math> is computed using the forward model (which is also trained separately), and the entire process repeats for time <math> t+1 </math>. In the backward pass, the gradient of <math> \pi </math> consists of (a) the error message <math> \delta_a </math> (Green), which propagates smoothly through the differentiable approximation of the sampling process, and (b) the error message <math> \delta_s </math> (Blue) of future time-steps, which propagates back through the differentiable forward model.<br />
<br />
== MGAIL Algorithm ==<br />
Shalev-Shwartz et al. (2016) and Heess et al. (2015) built a multi-step computation graph for describing the familiar policy gradient objective; in this case it is given by:<br />
<br />
\begin{align}<br />
J(\theta) = \mathbb{E}\left [ \sum_{t=0}^{T} \gamma ^t D(s_t,a_t)|\theta\right ]<br />
\end{align}<br />
<br />
<br />
Using the results from Heess et al. (2015), this paper demonstrates how to differentiate <math> J(\theta) </math> over a trajectory of <math>(s,a,s') </math> transitions:<br />
<br />
\begin{align}<br />
J_s &= \mathbb{E}_{p(a|s)}\mathbb{E}_{p(s'|s,a)}\left [ D_s + D_a \pi_s + \gamma J'_{s'}(f_s + f_a \pi_s) \right] \\<br />
J_\theta &= \mathbb{E}_{p(a|s)}\mathbb{E}_{p(s'|s,a)}\left [ D_a \pi_\theta + \gamma (J'_{s'} f_a \pi_\theta + J'_\theta) \right]<br />
\end{align}<br />
<br />
The policy gradient <math> \nabla_\theta J </math> is calculated by applying the two equations above recursively for <math> T </math> iterations. The MGAIL algorithm is presented below.<br />
<br />
[[File:MGAIL_alg.PNG]]<br />
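<br />
The recursion for <math> J_s </math> and <math> J_\theta </math> can be sketched for a deterministic scalar case, where the expectations disappear. This is not the MGAIL implementation: the linear policy, the linear forward model, the quadratic <code>D</code> and all constants are hypothetical choices that make the result easy to compare with a finite-difference gradient of <math> J(\theta) </math>.<br />
<br />
```python
GAMMA = 0.95
F_S, F_A = 0.9, 0.1  # partial derivatives of the linear forward model


def policy(theta, s):
    """Deterministic stand-in for pi_theta (hypothetical)."""
    return theta * s


def forward_model(s, a):
    """Hypothetical dynamics f(s, a)."""
    return F_S * s + F_A * a


def D(s, a):
    """Hypothetical smooth discriminator."""
    return -(s * s + a * a)


def rollout(theta, s0, horizon):
    """Return [(s_0, a_0), ..., (s_T, a_T)] under the deterministic model."""
    steps, s = [], s0
    for _ in range(horizon + 1):
        a = policy(theta, s)
        steps.append((s, a))
        s = forward_model(s, a)
    return steps


def objective(theta, s0, horizon):
    """J(theta) = sum_t gamma^t D(s_t, a_t)."""
    return sum(GAMMA ** t * D(s, a)
               for t, (s, a) in enumerate(rollout(theta, s0, horizon)))


def policy_gradient(theta, s0, horizon):
    """Backward recursion; J'_{s'} and J'_theta are zero past the horizon."""
    J_s_next, J_theta_next = 0.0, 0.0
    for s, a in reversed(rollout(theta, s0, horizon)):
        D_s, D_a = -2.0 * s, -2.0 * a
        pi_s, pi_theta = theta, s
        J_theta = D_a * pi_theta + GAMMA * (J_s_next * F_A * pi_theta + J_theta_next)
        J_s = D_s + D_a * pi_s + GAMMA * J_s_next * (F_S + F_A * pi_s)
        J_s_next, J_theta_next = J_s, J_theta
    return J_theta_next


theta, s0, horizon, eps = 0.3, 1.0, 5, 1e-6
numeric = (objective(theta + eps, s0, horizon)
           - objective(theta - eps, s0, horizon)) / (2.0 * eps)
print("recursive gradient:", policy_gradient(theta, s0, horizon))
print("finite difference :", numeric)
```
<br />
The loop walks the trajectory backwards, so each step reuses the values of <math> J'_{s'} </math> and <math> J'_\theta </math> computed for the following time step.<br />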
<br />
== Forward Model Structure ==<br />
The stability of the learning process depends on the prediction accuracy of the forward model, but learning an accurate forward model is challenging by itself. The authors propose methods for improving the performance of the forward model based on two aspects of its functionality. First, the forward model should learn to use the action as an operator over the state space. To accomplish this, the actions and states, which are sampled from different distributions, need to be first represented in a shared space. This is done by encoding the state and action using two separate neural networks and combining their outputs to form a single vector. Second, multiple previous states are used to predict the next state by representing the environment as an <math> n^{th} </math> order MDP. A gated recurrent unit (GRU, a simpler variant of the LSTM model) layer is incorporated into the state encoder to enable recurrent connections from previous states. Using these modifications, the model is able to achieve better and more stable results compared to the standard forward model based on a feed-forward neural network. The comparison is presented in Figure 3.<br />
<br />
[[File:performance_comparison.PNG]]<br />
<br />
Figure 3: Performance comparison between a basic forward model (Blue), and the advanced forward model (Green).<br />
<br />
= Experiments =<br />
The proposed algorithm is evaluated on three discrete control tasks (Cartpole, Mountain-Car, Acrobot), and five continuous control tasks (Hopper, Walker, Half-Cheetah, Ant, and Humanoid), which are modeled by the MuJoCo physics simulator (Todorov et al., 2012). Expert policies are trained using the Trust Region Policy Optimization (TRPO) algorithm (Schulman et al., 2015). Different numbers of trajectories are used to train the expert for each task, but all trajectories are of length 1000.<br />
The discriminator and generator (policy) networks contain two hidden layers with ReLU non-linearity and are trained using the ADAM optimizer. The total reward received over a period of <math> N </math> steps using BC, GAIL and MGAIL is presented in Table 1. The proposed algorithm achieved the highest reward for most environments while exhibiting performance comparable to the expert over all of them.<br />
<br />
[[File:mgail_test_results.PNG]]<br />
<br />
Table 1. Policy performance, boldface indicates better results, <math> \pm </math> represents one standard deviation.<br />
<br />
= Discussion =<br />
This paper presented a model-based algorithm for imitation learning. It demonstrated how a forward model can be used to train policies using the exact gradient of the discriminator network. A downside of this approach is the need to learn a forward model, since this could be difficult in certain domains. Learning the system dynamics directly from raw images is considered as one line of future work. Another line of future work is to address the violation of the fundamental assumption made by all supervised learning algorithms, which requires the data to be i.i.d. This problem arises because the discriminator and forward models are trained in a supervised learning fashion using data sampled from a dynamic distribution.<br />
<br />
= Source =<br />
# Baram, Nir, et al. "End-to-end differentiable adversarial imitation learning." International Conference on Machine Learning. 2017.<br />
# Ho, Jonathan, and Stefano Ermon. "Generative adversarial imitation learning." Advances in Neural Information Processing Systems. 2016.<br />
# Shalev-Shwartz, Shai, et al. "Long-term planning by short-term prediction." arXiv preprint arXiv:1602.01580 (2016).<br />
# Heess, Nicolas, et al. "Learning continuous control policies by stochastic value gradients." Advances in Neural Information Processing Systems. 2015.<br />
# Schulman, John, et al. "Trust region policy optimization." International Conference on Machine Learning. 2015.</div>X249wanghttp://wiki.math.uwaterloo.ca/statwiki/index.php?title=End-to-End_Differentiable_Adversarial_Imitation_Learning&diff=35026End-to-End Differentiable Adversarial Imitation Learning2018-03-21T19:17:39Z<p>X249wang: /* Categorical Action Distributions */</p>
<hr />
<div>= Introduction =<br />
The ability to imitate an expert policy is very beneficial in the case of automating human demonstrated tasks. Assuming that a sequence of state action pairs (trajectories) of an expert policy are available, a new policy can be trained that imitates the expert without having access to the original reward signal used by the expert. There are two main approaches to solve the problem of imitating a policy; they are Behavioural Cloning (BC) and Inverse Reinforcement Learning (IRL). BC directly learns the conditional distribution of actions over states in a supervised fashion by training on single time-step state-action pairs. The disadvantage of BC is that the training requires large amounts of expert data, which is hard to obtain. In addition, an agent trained using BC is unaware of how its action can affect future state distribution. The second method using IRL involves recovering a reward signal under which the expert is uniquely optimal; the main disadvantage is that it’s an ill-posed problem.<br />
<br />
To address the problem of imitating an expert policy, techniques based on Generative Adversarial Networks (GANs) have been proposed in recent years. GANs use a discriminator to guide the generative model towards producing patterns like those of the expert. Ho & Ermon (2016) used this idea in Generative Adversarial Imitation Learning (GAIL) to imitate an expert policy in a model-free setup. In a model-free setup, the agent cannot predict the next state and reward before it takes an action, because the transition function from one state to the next is not learned. The disadvantage of GAIL’s model-free approach is that backpropagation requires gradient estimation, which tends to suffer from high variance and therefore calls for large sample sizes and variance reduction methods. This paper proposes a model-based method (MGAIL) that addresses these issues by training the policy with the information propagated from the discriminator to the generator.<br />
<br />
= Background =<br />
== Markov Decision Process ==<br />
Consider an infinite-horizon discounted Markov decision process (MDP), defined by the tuple <math>(S, A, P, r, \rho_0, \gamma)</math> where <math>S</math> is the set of states, <math>A</math> is a set of actions, <math>P : S \times A \times S \to [0, 1]</math> is the transition probability distribution, <math>r : (S \times A) \to \mathbb{R}</math> is the reward function, <math>\rho_0 : S \to [0, 1]</math> is the distribution over initial states, and <math>\gamma \in (0, 1)</math> is the discount factor. Let <math>\pi</math> denote a stochastic policy <math>\pi : S \times A \to [0, 1]</math>, <math>R(\pi)</math> denote its expected discounted reward: <math>E_\pi R = E_\pi [\sum_{t=0}^T \gamma^t r_t]</math> and <math>\tau</math> denote a trajectory of states and actions <math>\tau = \{s_0, a_0, s_1, a_1, ...\}</math>.<br />
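<br />
As a small illustration of the discounted reward defined above, the following sketch computes <math>\sum_{t=0}^T \gamma^t r_t</math> for one sampled trajectory. The function name <code>discounted_return</code> and the example rewards are hypothetical and are not taken from the paper.<br />

```python
def discounted_return(rewards, gamma):
    """Return sum_t gamma^t * r_t for a single trajectory of rewards."""
    total = 0.0
    # Accumulate from the last step backwards: G_t = r_t + gamma * G_{t+1}.
    for r in reversed(rewards):
        total = r + gamma * total
    return total


# Three steps of reward 1 with gamma = 0.5: 1 + 0.5 + 0.25
example_return = discounted_return([1.0, 1.0, 1.0], 0.5)
```

Averaging this quantity over many trajectories sampled from <math>\pi</math> gives a Monte-Carlo estimate of the expected discounted reward.<br />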
<br />
== Imitation Learning ==<br />
A common technique for performing imitation learning is to train a policy <math> \pi </math> that minimizes some loss function <math> l(s, \pi(s)) </math> with respect to a discounted state distribution encountered by the expert: <math> d_\pi(s) = (1-\gamma)\sum_{t=0}^{\infty}\gamma^t p(s_t) </math>. This can be done using any supervised learning (SL) algorithm, but the policy's predictions affect future state distributions; this violates the independent and identically distributed (i.i.d.) assumption made by most SL algorithms. The process is therefore susceptible to compounding errors, since a slight deviation in the learner's behavior can lead to state distributions not encountered by the expert policy.<br />
<br />
This issue was overcome through the use of the Forward Training (FT) algorithm, which trains a non-stationary policy iteratively over time. At each time step a new policy is trained on the state distribution induced by the previously trained policies. This is continued until the end of the time horizon to obtain a policy that can mimic the expert policy. The requirement to train a policy at each time step makes the FT algorithm impractical for cases where the time horizon is very large or undefined. This shortcoming is resolved by the Stochastic Mixing Iterative Learning (SMILe) algorithm. SMILe trains a stochastic stationary policy over several iterations under the trajectory distribution induced by the previously trained policy.<br />
<br />
== Generative Adversarial Networks ==<br />
GANs learn a generative model that can fool the discriminator by using a two-player zero-sum game:<br />
<br />
\begin{align} <br />
\underset{G}{\operatorname{argmin}}\; \underset{D\in (0,1)}{\operatorname{argmax}}\; \mathbb{E}_{x\sim p_E}[\log D(x)]\ +\ \mathbb{E}_{z\sim p_z}[\log(1 - D(G(z)))]<br />
\end{align}<br />
<br />
In the above equation, <math> p_E </math> represents the expert distribution and <math> p_z </math> represents the input noise distribution from which the input to the generator is sampled. The generator produces patterns and the discriminator judges if the pattern was generated or from the expert data. When the discriminator cannot distinguish between the two distributions the game ends and the generator has learned to mimic the expert. GANs rely on basic ideas such as binary classification and algorithms such as backpropagation in order to learn the expert distribution.<br />
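<br />
To make the end point of the game concrete, the sketch below evaluates a sample-based version of the objective above from discriminator outputs. The function name <code>gan_value</code> and the input lists are illustrative assumptions. When the discriminator cannot tell the two distributions apart it outputs 0.5 for every input, and the objective equals <math>-\log 4</math>.<br />

```python
import math


def gan_value(d_on_expert, d_on_generated):
    """Sample estimate of E[log D(x)] + E[log(1 - D(G(z)))].

    d_on_expert:    discriminator outputs D(x) on expert samples
    d_on_generated: discriminator outputs D(G(z)) on generated samples
    """
    expert_term = sum(math.log(d) for d in d_on_expert) / len(d_on_expert)
    generated_term = sum(math.log(1.0 - d) for d in d_on_generated) / len(d_on_generated)
    return expert_term + generated_term


# A discriminator that cannot distinguish the two sources outputs 0.5 everywhere.
value_at_equilibrium = gan_value([0.5] * 4, [0.5] * 4)
```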
<br />
GAIL applies GANs to the task of imitating an expert policy in a model-free approach. GAIL uses an objective function similar to that of GANs, but the expert distribution in GAIL represents the joint distribution over state-action tuples:<br />
<br />
\begin{align} <br />
\underset{\pi}{\operatorname{argmin}}\; \underset{D\in (0,1)}{\operatorname{argmax}}\; \mathbb{E}_{\pi}[\log D(s,a)]\ +\ \mathbb{E}_{\pi_E}[\log(1 - D(s,a))] - \lambda H(\pi)<br />
\end{align}<br />
<br />
where <math> H(\pi) \triangleq \mathbb{E}_{\pi}[-log\: \pi(a|s)]</math> is the entropy.<br />
<br />
This problem cannot be solved using the standard methods described for GANs because the generator in GAIL represents a stochastic policy. The exact form of the first term in the above equation is given by: <math> \mathbb{E}_{s\sim \rho_\pi(s)}\mathbb{E}_{a\sim \pi(\cdot |s)} [\log D(s,a)] </math>.<br />
<br />
The two-player game now depends on the stochastic properties (<math> \theta </math>) of the policy, and it is unclear how to differentiate the above equation with respect to <math> \theta </math>. This problem can be overcome using score-function methods such as REINFORCE to obtain an unbiased gradient estimate:<br />
<br />
\begin{align}<br />
\nabla_\theta\mathbb{E}_{\pi} [log\; D(s,a)] \cong \hat{\mathbb{E}}_{\tau_i}[\nabla_\theta\; log\; \pi_\theta(a|s)Q(s,a)]<br />
\end{align}<br />
<br />
where <math> Q(\hat{s},\hat{a}) </math> is the score function of the gradient:<br />
<br />
\begin{align}<br />
Q(\hat{s},\hat{a}) = \hat{\mathbb{E}}_{\tau_i}[log\; D(s,a) | s_0 = \hat{s}, a_0 = \hat{a}]<br />
\end{align}<br />
<br />
<br />
REINFORCE gradients suffer from high variance which makes them difficult to work with even after applying variance reduction techniques. In order to better understand the changes required to fool the discriminator we need access to the gradients of the discriminator network, which can be obtained from the Jacobian of the discriminator. This paper demonstrates the use of a forward model along with the Jacobian of the discriminator to train a policy, without using high-variance gradient estimations.<br />
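<br />
The score-function estimator and its variance can be illustrated on a one-dimensional toy problem. This is only a sketch under stated assumptions: the policy is a Gaussian <math> a \sim N(\mu, \sigma^2) </math>, the function <code>f</code> is a hypothetical stand-in for <math> Q(s,a) </math>, and the gradient is taken with respect to <math> \mu </math> only. With <code>f(a) = a * a</code> the exact gradient is <math> 2\mu </math>, which makes the estimate easy to check.<br />

```python
import random


def score_function_gradient(f, mu, sigma, n_samples, rng):
    """Estimate d/dmu E_{a ~ N(mu, sigma^2)}[f(a)] with the score function.

    Returns the estimate and the per-sample variance of the estimator.
    """
    samples = []
    for _ in range(n_samples):
        a = rng.gauss(mu, sigma)
        # d/dmu log N(a; mu, sigma^2) = (a - mu) / sigma^2
        samples.append(f(a) * (a - mu) / sigma ** 2)
    mean = sum(samples) / n_samples
    variance = sum((x - mean) ** 2 for x in samples) / (n_samples - 1)
    return mean, variance


rng = random.Random(0)
grad_estimate, per_sample_variance = score_function_gradient(
    lambda a: a * a, mu=1.0, sigma=1.0, n_samples=200_000, rng=rng
)
```

For this toy setting the per-sample variance works out analytically to 30, so many samples are needed before the average settles near the exact gradient of 2.<br />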
<br />
= Algorithm =<br />
This section first analyzes the characteristics of the discriminator network, then describes how a forward model can enable policy imitation through GANs. Lastly, the model-based adversarial imitation learning algorithm is presented.<br />
<br />
== The discriminator network ==<br />
The discriminator network is trained to predict the conditional distribution: <math> D(s,a) = p(y|s,a) </math> where <math> y \in \{\pi_E, \pi\} </math>.<br />
<br />
The discriminator is trained on an even distribution of expert and generated examples; hence <math> p(\pi) = p(\pi_E) = \frac{1}{2} </math>. Given this, we can rearrange and factor <math> D(s,a) </math> to obtain:<br />
<br />
\begin{aligned}<br />
D(s,a) &= p(\pi|s,a) \\<br />
& = \frac{p(s,a|\pi)p(\pi)}{p(s,a|\pi)p(\pi) + p(s,a|\pi_E)p(\pi_E)} \\<br />
& = \frac{p(s,a|\pi)}{p(s,a|\pi) + p(s,a|\pi_E)} \\<br />
& = \frac{1}{1 + \frac{p(s,a|\pi_E)}{p(s,a|\pi)}} \\<br />
& = \frac{1}{1 + \frac{p(a|s,\pi_E)}{p(a|s,\pi)} \cdot \frac{p(s|\pi_E)}{p(s|\pi)}} \\<br />
\end{aligned}<br />
<br />
Define <math> \varphi(s,a) </math> and <math> \psi(s) </math> to be:<br />
<br />
\begin{aligned}<br />
\varphi(s,a) = \frac{p(a|s,\pi_E)}{p(a|s,\pi)}, \psi(s) = \frac{p(s|\pi_E)}{p(s|\pi)}<br />
\end{aligned}<br />
<br />
to get the final expression for <math> D(s,a) </math>:<br />
\begin{aligned}<br />
D(s,a) = \frac{1}{1 + \varphi(s,a)\cdot \psi(s)}<br />
\end{aligned}<br />
<br />
<math> \varphi(s,a) </math> represents a policy likelihood ratio, and <math> \psi(s) </math> represents a state distribution likelihood ratio. Based on these expressions, the paper states that the discriminator makes its decisions by answering two questions. The first question relates to the state distribution: what is the likelihood of encountering state <math> s </math> under the distribution induced by <math> \pi_E </math> vs <math> \pi </math>? The second question is about behavior: given a state <math> s </math>, how likely is action <math> a </math> under <math> \pi_E </math> vs <math> \pi </math>? The desired change in state is given by <math> \psi_s \equiv \partial \psi / \partial s </math>; this information can be obtained from the partial derivatives of <math> D(s,a) </math>:<br />
<br />
\begin{aligned}<br />
\nabla_aD &= - \frac{\varphi_a(s,a)\psi(s)}{(1 + \varphi(s,a)\psi(s))^2} \\<br />
\nabla_sD &= - \frac{\varphi_s(s,a)\psi(s) + \varphi(s,a)\psi_s(s)}{(1 + \varphi(s,a)\psi(s))^2} \\<br />
\end{aligned}<br />
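<br />
The factorization and the expression for <math> \nabla_aD </math> can be checked numerically on a toy example. All densities below are assumptions chosen for illustration only: scalar states and actions, unit-variance Gaussians, an expert policy centred at <math> s </math>, and a learner policy centred at <math> 0.5s </math>. The sketch compares the factored form with a direct application of Bayes' rule.<br />

```python
import math


def normal_pdf(x, mean):
    """Density of a unit-variance Gaussian."""
    return math.exp(-0.5 * (x - mean) ** 2) / math.sqrt(2.0 * math.pi)


def phi(s, a):
    """Policy likelihood ratio p(a|s, pi_E) / p(a|s, pi) for the toy policies."""
    return normal_pdf(a, s) / normal_pdf(a, 0.5 * s)


def psi(s):
    """State likelihood ratio p(s|pi_E) / p(s|pi) for the toy state densities."""
    return normal_pdf(s, 0.0) / normal_pdf(s, 1.0)


def d_factored(s, a):
    return 1.0 / (1.0 + phi(s, a) * psi(s))


def d_bayes(s, a):
    """D(s,a) = p(pi|s,a) computed directly from the joint densities."""
    joint_learner = normal_pdf(a, 0.5 * s) * normal_pdf(s, 1.0)
    joint_expert = normal_pdf(a, s) * normal_pdf(s, 0.0)
    return joint_learner / (joint_learner + joint_expert)


def grad_a_d(s, a):
    """Closed form -phi_a * psi / (1 + phi * psi)^2 for the toy densities."""
    # For these Gaussians d/da log phi = 0.5 * s, so phi_a = 0.5 * s * phi.
    phi_a = 0.5 * s * phi(s, a)
    return -phi_a * psi(s) / (1.0 + phi(s, a) * psi(s)) ** 2


s0, a0, eps = 0.3, -0.2, 1e-6
finite_difference = (d_factored(s0, a0 + eps) - d_factored(s0, a0 - eps)) / (2 * eps)
```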
<br />
<br />
== Backpropagating through stochastic units ==<br />
There is interest in training stochastic policies because stochasticity encourages exploration for Policy Gradient methods. This is a problem for algorithms that build differentiable computation graphs where the gradients flow from one component to another since it is unclear how to backpropagate through stochastic units. The following subsections show how to estimate the gradients of continuous and categorical stochastic elements for continuous and discrete action domains respectively.<br />
<br />
=== Continuous Action Distributions ===<br />
In the case of continuous action policies, re-parameterization was used to enable computing the derivatives of stochastic models. Assuming that the stochastic policy has a Gaussian distribution the policy <math> \pi </math> can be written as <math> \pi_\theta(a|s) = \mu_\theta(s) + \xi \sigma_\theta(s) </math>, where <math> \xi \sim N(0,1) </math>. This way, the authors are able to get a Monte-Carlo estimator of the derivative of the expected value of <math> D(s, a) </math> with respect to <math> \theta </math>:<br />
<br />
\begin{align}<br />
\nabla_\theta\mathbb{E}_{\pi(a|s)}D(s,a) = \mathbb{E}_{\rho (\xi )}\nabla_a D(s,a) \nabla_\theta \pi_\theta(a|s) \cong \frac{1}{M}\sum_{i=1}^{M} \nabla_a D(s,a) \nabla_\theta \pi_\theta(a|s)\Bigr|_{\substack{\xi=\xi_i}}<br />
\end{align}<br />
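<br />
A minimal sketch of this Monte-Carlo estimator follows, assuming a scalar action and treating <math> \mu </math> and <math> \sigma </math> themselves as the parameters <math> \theta </math>. The function <code>grad_f</code> is a hypothetical stand-in for <math> \nabla_a D(s,a) </math> at a fixed state. With <code>grad_f(a) = 2 * a</code>, the derivative of <math> a^2 </math>, the exact gradients are <math> 2\mu </math> and <math> 2\sigma </math>.<br />

```python
import random


def reparam_gradient(grad_f, mu, sigma, n_samples, rng):
    """Estimate the gradient of E[f(a)], a = mu + xi * sigma, w.r.t. (mu, sigma)."""
    g_mu = 0.0
    g_sigma = 0.0
    for _ in range(n_samples):
        xi = rng.gauss(0.0, 1.0)   # noise is sampled independently of the parameters
        a = mu + xi * sigma        # re-parameterized action
        df_da = grad_f(a)
        g_mu += df_da              # chain rule with da/dmu = 1
        g_sigma += df_da * xi      # chain rule with da/dsigma = xi
    return g_mu / n_samples, g_sigma / n_samples


rng = random.Random(0)
grad_mu, grad_sigma = reparam_gradient(
    lambda a: 2.0 * a, mu=1.0, sigma=1.0, n_samples=100_000, rng=rng
)
```

Because the noise <math> \xi </math> does not depend on the parameters, the gradient passes through the sampled action by the ordinary chain rule.<br />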
<br />
<br />
=== Categorical Action Distributions ===<br />
In the case of discrete action domains, the paper uses categorical re-parameterization with Gumbel-Softmax. This method relies on the Gumbel-Max trick which is a method for drawing samples from a categorical distribution with class probabilities <math> \pi(a_1|s),\pi(a_2|s),...,\pi(a_N|s) </math>:<br />
<br />
\begin{align}<br />
a_{argmax} = \underset{i}{argmax}[g_i + log\ \pi(a_i|s)]\textrm{, where } g_i \sim Gumbel(0, 1).<br />
\end{align}<br />
<br />
Gumbel-Softmax provides a differentiable approximation of the samples obtained using the Gumbel-Max trick:<br />
<br />
\begin{align}<br />
a_{softmax} = \frac{exp[\frac{1}{\tau}(g_i + log\ \pi(a_i|s))]}{\sum_{j=1}^{N}exp[\frac{1}{\tau}(g_j + log\ \pi(a_j|s))]}<br />
\end{align}<br />
<br />
<br />
In the above equation, the hyper-parameter <math> \tau </math> (temperature) trades bias for variance. When <math> \tau </math> gets closer to zero, the softmax operator acts like argmax, resulting in low bias but high variance; when <math> \tau </math> is large, the bias is high but the variance is low.<br />
<br />
The authors use <math> a_{softmax} </math> to interact with the environment; argmax is applied over <math> a_{softmax} </math> to obtain a single “pure” action, but the continuous approximation is used in the backward pass using the estimation: <math> \nabla_\theta\; a_{argmax} \approx \nabla_\theta\; a_{softmax} </math>.<br />
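<br />
Both sampling schemes can be sketched with the standard library alone. The helper names (<code>sample_gumbel</code>, <code>gumbel_max</code>, <code>gumbel_softmax</code>) and the example probabilities are assumptions made for illustration; this is not the authors' implementation.<br />

```python
import math
import random


def sample_gumbel(rng):
    """Draw g ~ Gumbel(0, 1) by inverse transform sampling."""
    u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)  # keep both logs finite
    return -math.log(-math.log(u))


def gumbel_max(probs, rng):
    """Exact categorical sample: index maximizing g_i + log pi(a_i|s)."""
    scores = [sample_gumbel(rng) + math.log(p) for p in probs]
    return max(range(len(probs)), key=lambda i: scores[i])


def gumbel_softmax(probs, tau, rng):
    """Differentiable relaxation: softmax of (g_i + log pi(a_i|s)) / tau."""
    scores = [(sample_gumbel(rng) + math.log(p)) / tau for p in probs]
    top = max(scores)  # subtract the maximum for numerical stability
    exps = [math.exp(z - top) for z in scores]
    total = sum(exps)
    return [e / total for e in exps]


rng = random.Random(1)
probs = [0.2, 0.3, 0.5]
n_draws = 20_000
counts = [0, 0, 0]
for _ in range(n_draws):
    counts[gumbel_max(probs, rng)] += 1
frequencies = [c / n_draws for c in counts]

soft_action = gumbel_softmax(probs, tau=0.5, rng=rng)
pure_action = max(range(len(soft_action)), key=lambda i: soft_action[i])
```

Here <code>frequencies</code> should be close to <code>probs</code>, <code>soft_action</code> is a point on the probability simplex, and <code>pure_action</code> is the single action obtained by applying argmax to it.<br />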
<br />
== Backpropagating through a Forward model ==<br />
The above subsections presented the means for extracting the partial derivative <math> \nabla_aD </math>. The main contribution of this paper is incorporating the use of <math> \nabla_sD </math>. In a model-free approach the state <math> s </math> is treated as a fixed input, therefore <math> \nabla_sD </math> is discarded. This is illustrated in Figure 1. This work uses a model-based approach which makes incorporating <math> \nabla_sD </math> more involved. In the model-based approach, a state <math> s_t </math> can be written as a function of the previous state action pair: <math> s_t = f(s_{t-1}, a_{t-1}) </math>, where <math> f </math> represents the forward model. Using the forward model and the law of total derivatives we get:<br />
<br />
\begin{align}<br />
\nabla_\theta D(s_t,a_t)\Bigr|_{\substack{s=s_t, a=a_t}} &= \frac{\partial D}{\partial a}\frac{\partial a}{\partial \theta}\Bigr|_{\substack{a=a_t}} + \frac{\partial D}{\partial s}\frac{\partial s}{\partial \theta}\Bigr|_{\substack{s=s_t}} \\<br />
&= \frac{\partial D}{\partial a}\frac{\partial a}{\partial \theta}\Bigr|_{\substack{a=a_t}} + \frac{\partial D}{\partial s}\left (\frac{\partial f}{\partial s}\frac{\partial s}{\partial \theta}\Bigr|_{\substack{s=s_{t-1}}} + \frac{\partial f}{\partial a}\frac{\partial a}{\partial \theta}\Bigr|_{\substack{a=a_{t-1}}} \right )<br />
\end{align}<br />
<br />
<br />
Using this formula, the error regarding deviations of future states <math> (\psi_s) </math> propagates back in time and influences the actions of policies at earlier times. This is summarized in Figure 2.<br />
<br />
[[File:modelFree_blockDiagram.PNG|300px]]<br />
<br />
Figure 1: Block diagram of the model-free approach: given a state <math> s </math>, the policy outputs <math> \mu </math>, which is fed to a stochastic sampling unit. An action <math> a </math> is sampled and, together with <math> s </math>, is presented to the discriminator network. In the backward phase, the error message <math> \delta_a </math> is blocked at the stochastic sampling unit. From there, a high-variance gradient estimation is used (<math> \delta_{HV} </math>). Meanwhile, the error message <math> \delta_s </math> is discarded.<br />
<br />
[[File:modelBased_blockDiagram.PNG|1000px]]<br />
<br />
Figure 2: Block diagram of model-based adversarial imitation learning. This diagram describes the computation graph for training the policy (i.e. G). The discriminator network D is fixed at this stage and is trained separately. At time <math> t </math> of the forward pass, <math> \pi </math> outputs a distribution over actions: <math> \mu_t = \pi(s_t) </math>, from which an action <math> a_t </math> is sampled. For example, in the continuous case, this is done using the re-parametrization trick: <math> a_t = \mu_t + \xi \cdot \sigma </math>, where <math> \xi \sim N(0,1) </math>. The next state <math> s_{t+1} = f(s_t, a_t) </math> is computed using the forward model (which is also trained separately), and the entire process repeats for time <math> t+1 </math>. In the backward pass, the gradient of <math> \pi </math> is composed of (a) the error message <math> \delta_a </math> (Green), which propagates through the differentiable approximation of the sampling process, and (b) the error message <math> \delta_s </math> (Blue) of future time-steps, which propagates back through the differentiable forward model.<br />
<br />
== MGAIL Algorithm ==<br />
Shalev-Shwartz et al. (2016) and Heess et al. (2015) built a multi-step computation graph for describing the familiar policy gradient objective; in this case it is given by:<br />
<br />
\begin{align}<br />
J(\theta) = \mathbb{E}\left [ \sum_{t=0}^{T} \gamma ^t D(s_t,a_t)|\theta\right ]<br />
\end{align}<br />
<br />
<br />
Using the results from Heess et al. (2015), this paper demonstrates how to differentiate <math> J(\theta) </math> over a trajectory of <math>(s,a,s') </math> transitions:<br />
<br />
\begin{align}<br />
J_s &= \mathbb{E}_{p(a|s)}\mathbb{E}_{p(s'|s,a)}\left [ D_s + D_a \pi_s + \gamma J'_{s'}(f_s + f_a \pi_s) \right] \\<br />
J_\theta &= \mathbb{E}_{p(a|s)}\mathbb{E}_{p(s'|s,a)}\left [ D_a \pi_\theta + \gamma (J'_{s'} f_a \pi_\theta + J'_\theta) \right]<br />
\end{align}<br />
<br />
The policy gradient <math> \nabla_\theta J </math> is calculated by applying the two equations above recursively for <math> T </math> iterations. The MGAIL algorithm is presented below.<br />
<br />
[[File:MGAIL_alg.PNG]]<br />
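<br />
The backward recursion for <math> J_s </math> and <math> J_\theta </math> can be sketched on a deterministic scalar system and checked against finite differences. Everything concrete here is an assumption for illustration: a linear policy <math> a = \theta s </math>, a linear forward model <math> s' = As + Ba </math> with constants <code>A</code> and <code>B</code>, and a quadratic function <code>discriminator</code> standing in for <math> D(s,a) </math>. Since the toy system is deterministic, the expectations in the recursion reduce to single evaluations.<br />

```python
A, B = 0.9, 0.5  # hypothetical forward-model coefficients: s' = A*s + B*a


def discriminator(s, a):
    """Smooth stand-in for D(s, a); not a trained network."""
    return -(s * s + a * a)


def rollout(theta, s0, horizon):
    """Forward pass: unroll the policy a = theta*s through the forward model."""
    states, actions = [], []
    s = s0
    for _ in range(horizon + 1):
        a = theta * s
        states.append(s)
        actions.append(a)
        s = A * s + B * a
    return states, actions


def objective(theta, s0, horizon, gamma):
    """J(theta) = sum_t gamma^t D(s_t, a_t)."""
    states, actions = rollout(theta, s0, horizon)
    return sum(gamma ** t * discriminator(s, a)
               for t, (s, a) in enumerate(zip(states, actions)))


def policy_gradient(theta, s0, horizon, gamma):
    """Backward pass: apply the J_s and J_theta recursion from t = T down to 0."""
    states, actions = rollout(theta, s0, horizon)
    j_s, j_theta = 0.0, 0.0  # J'_{s'} and J'_theta beyond the last step
    for s, a in zip(reversed(states), reversed(actions)):
        d_s, d_a = -2.0 * s, -2.0 * a  # partial derivatives of the stand-in D
        pi_s, pi_theta = theta, s      # partial derivatives of a = theta*s
        f_s, f_a = A, B                # partial derivatives of the forward model
        new_j_theta = d_a * pi_theta + gamma * (j_s * f_a * pi_theta + j_theta)
        new_j_s = d_s + d_a * pi_s + gamma * j_s * (f_s + f_a * pi_s)
        j_s, j_theta = new_j_s, new_j_theta
    return j_theta


theta0, eps = -0.3, 1e-6
analytic = policy_gradient(theta0, s0=1.0, horizon=5, gamma=0.95)
numeric = (objective(theta0 + eps, 1.0, 5, 0.95)
           - objective(theta0 - eps, 1.0, 5, 0.95)) / (2 * eps)
```

The two values <code>analytic</code> and <code>numeric</code> should agree closely, which confirms that the recursion accounts for the effect of <math> \theta </math> on future states through the forward model.<br />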
<br />
== Forward Model Structure ==<br />
The stability of the learning process depends on the prediction accuracy of the forward model, but learning an accurate forward model is challenging by itself. The authors propose methods for improving the performance of the forward model based on two aspects of its functionality. First, the forward model should learn to use the action as an operator over the state space. To accomplish this, the actions and states, which are sampled from different distributions, first need to be represented in a shared space. This is done by encoding the state and action using two separate neural networks and combining their outputs to form a single vector. Second, multiple previous states are used to predict the next state by representing the environment as an <math> n^{th} </math> order MDP. A gated recurrent unit (GRU, a simpler variant of the LSTM model) layer is incorporated into the state encoder to enable recurrent connections from previous states. With these modifications, the model achieves better and more stable results than the standard forward model based on a feed-forward neural network. The comparison is presented in Figure 3.<br />
<br />
[[File:performance_comparison.PNG]]<br />
<br />
Figure 3: Performance comparison between a basic forward model (Blue), and the advanced forward model (Green).<br />
<br />
= Experiments =<br />
The proposed algorithm is evaluated on three discrete control tasks (Cartpole, Mountain-Car, Acrobot), and five continuous control tasks (Hopper, Walker, Half-Cheetah, Ant, and Humanoid), which are modeled by the MuJoCo physics simulator (Todorov et al., 2012). Expert policies are trained using the Trust Region Policy Optimization (TRPO) algorithm (Schulman et al., 2015). A different number of trajectories is used to train the expert for each task, but all trajectories are of length 1000.<br />
The discriminator and generator (policy) networks each contain two hidden layers with ReLU non-linearity and are trained using the ADAM optimizer. The total reward received over a period of <math> N </math> steps using BC, GAIL and MGAIL is presented in Table 1. The proposed algorithm achieved the highest reward for most environments while exhibiting performance comparable to the expert over all of them.<br />
<br />
[[File:mgail_test_results.PNG]]<br />
<br />
Table 1. Policy performance, boldface indicates better results, <math> \pm </math> represents one standard deviation.<br />
<br />
= Discussion =<br />
This paper presented a model-based algorithm for imitation learning. It demonstrated how a forward model can be used to train policies using the exact gradient of the discriminator network. A downside of this approach is the need to learn a forward model, which could be difficult in certain domains. Learning the system dynamics directly from raw images is considered as one line of future work. Another direction is to address the violation of the i.i.d. assumption that supervised learning algorithms make about their data. This problem arises because the discriminator and forward models are trained in a supervised learning fashion using data sampled from a dynamic distribution.<br />
<br />
= Source =<br />
# Baram, Nir, et al. "End-to-end differentiable adversarial imitation learning." International Conference on Machine Learning. 2017.<br />
# Ho, Jonathan, and Stefano Ermon. "Generative adversarial imitation learning." Advances in Neural Information Processing Systems. 2016.<br />
# Shalev-Shwartz, Shai, et al. "Long-term planning by short-term prediction." arXiv preprint arXiv:1602.01580 (2016).<br />
# Heess, Nicolas, et al. "Learning continuous control policies by stochastic value gradients." Advances in Neural Information Processing Systems. 2015.<br />
# Schulman, John, et al. "Trust region policy optimization." International Conference on Machine Learning. 2015.</div>X249wanghttp://wiki.math.uwaterloo.ca/statwiki/index.php?title=End-to-End_Differentiable_Adversarial_Imitation_Learning&diff=35025End-to-End Differentiable Adversarial Imitation Learning2018-03-21T19:10:43Z<p>X249wang: /* Forward Model Structure */</p>
<hr />
<div>= Introduction =<br />
The ability to imitate an expert policy is very beneficial in the case of automating human demonstrated tasks. Assuming that a sequence of state action pairs (trajectories) of an expert policy are available, a new policy can be trained that imitates the expert without having access to the original reward signal used by the expert. There are two main approaches to solve the problem of imitating a policy; they are Behavioural Cloning (BC) and Inverse Reinforcement Learning (IRL). BC directly learns the conditional distribution of actions over states in a supervised fashion by training on single time-step state-action pairs. The disadvantage of BC is that the training requires large amounts of expert data, which is hard to obtain. In addition, an agent trained using BC is unaware of how its action can affect future state distribution. The second method using IRL involves recovering a reward signal under which the expert is uniquely optimal; the main disadvantage is that it’s an ill-posed problem.<br />
<br />
To address the problem of imitating an expert policy, techniques based on Generative Adversarial Networks (GANs) have been proposed in recent years. GANs use a discriminator to guide the generative model towards producing patterns like those of the expert. Ho & Ermon (2016) used this idea in Generative Adversarial Imitation Learning (GAIL) to imitate an expert policy in a model-free setup. In a model-free setup, the agent cannot predict the next state and reward before it takes an action, because the transition function from one state to the next is not learned. The disadvantage of GAIL’s model-free approach is that backpropagation requires gradient estimation, which tends to suffer from high variance and therefore calls for large sample sizes and variance reduction methods. This paper proposes a model-based method (MGAIL) that addresses these issues by training the policy with the information propagated from the discriminator to the generator.<br />
<br />
= Background =<br />
== Markov Decision Process ==<br />
Consider an infinite-horizon discounted Markov decision process (MDP), defined by the tuple <math>(S, A, P, r, \rho_0, \gamma)</math> where <math>S</math> is the set of states, <math>A</math> is a set of actions, <math>P : S \times A \times S \to [0, 1]</math> is the transition probability distribution, <math>r : (S \times A) \to \mathbb{R}</math> is the reward function, <math>\rho_0 : S \to [0, 1]</math> is the distribution over initial states, and <math>\gamma \in (0, 1)</math> is the discount factor. Let <math>\pi</math> denote a stochastic policy <math>\pi : S \times A \to [0, 1]</math>, <math>R(\pi)</math> denote its expected discounted reward: <math>E_\pi R = E_\pi [\sum_{t=0}^T \gamma^t r_t]</math> and <math>\tau</math> denote a trajectory of states and actions <math>\tau = \{s_0, a_0, s_1, a_1, ...\}</math>.<br />
<br />
== Imitation Learning ==<br />
A common technique for performing imitation learning is to train a policy <math> \pi </math> that minimizes some loss function <math> l(s, \pi(s)) </math> with respect to a discounted state distribution encountered by the expert: <math> d_\pi(s) = (1-\gamma)\sum_{t=0}^{\infty}\gamma^t p(s_t) </math>. This can be done using any supervised learning (SL) algorithm, but the policy's predictions affect future state distributions; this violates the independent and identically distributed (i.i.d.) assumption made by most SL algorithms. The process is therefore susceptible to compounding errors, since a slight deviation in the learner's behavior can lead to state distributions not encountered by the expert policy.<br />
<br />
This issue was overcome through the use of the Forward Training (FT) algorithm, which trains a non-stationary policy iteratively over time. At each time step a new policy is trained on the state distribution induced by the previously trained policies. This is continued until the end of the time horizon to obtain a policy that can mimic the expert policy. The requirement to train a policy at each time step makes the FT algorithm impractical for cases where the time horizon is very large or undefined. This shortcoming is resolved by the Stochastic Mixing Iterative Learning (SMILe) algorithm. SMILe trains a stochastic stationary policy over several iterations under the trajectory distribution induced by the previously trained policy.<br />
<br />
== Generative Adversarial Networks ==<br />
GANs learn a generative model that can fool the discriminator by using a two-player zero-sum game:<br />
<br />
\begin{align} <br />
\underset{G}{\operatorname{argmin}}\; \underset{D\in (0,1)}{\operatorname{argmax}}\; \mathbb{E}_{x\sim p_E}[\log D(x)]\ +\ \mathbb{E}_{z\sim p_z}[\log(1 - D(G(z)))]<br />
\end{align}<br />
<br />
In the above equation, <math> p_E </math> represents the expert distribution and <math> p_z </math> represents the input noise distribution from which the input to the generator is sampled. The generator produces patterns and the discriminator judges if the pattern was generated or from the expert data. When the discriminator cannot distinguish between the two distributions the game ends and the generator has learned to mimic the expert. GANs rely on basic ideas such as binary classification and algorithms such as backpropagation in order to learn the expert distribution.<br />
<br />
GAIL applies GANs to the task of imitating an expert policy in a model-free approach. GAIL uses an objective function similar to that of GANs, but the expert distribution in GAIL represents the joint distribution over state-action tuples:<br />
<br />
\begin{align} <br />
\underset{\pi}{\operatorname{argmin}}\; \underset{D\in (0,1)}{\operatorname{argmax}}\; \mathbb{E}_{\pi}[\log D(s,a)]\ +\ \mathbb{E}_{\pi_E}[\log(1 - D(s,a))] - \lambda H(\pi)<br />
\end{align}<br />
<br />
where <math> H(\pi) \triangleq \mathbb{E}_{\pi}[-log\: \pi(a|s)]</math> is the entropy.<br />
<br />
This problem cannot be solved using the standard methods described for GANs because the generator in GAIL represents a stochastic policy. The exact form of the first term in the above equation is given by: <math> \mathbb{E}_{s\sim \rho_\pi(s)}\mathbb{E}_{a\sim \pi(\cdot |s)} [\log D(s,a)] </math>.<br />
<br />
The two-player game now depends on the stochastic properties (<math> \theta </math>) of the policy, and it is unclear how to differentiate the above equation with respect to <math> \theta </math>. This problem can be overcome using score-function methods such as REINFORCE to obtain an unbiased gradient estimate:<br />
<br />
\begin{align}<br />
\nabla_\theta\mathbb{E}_{\pi} [log\; D(s,a)] \cong \hat{\mathbb{E}}_{\tau_i}[\nabla_\theta\; log\; \pi_\theta(a|s)Q(s,a)]<br />
\end{align}<br />
<br />
where <math> Q(\hat{s},\hat{a}) </math> is the score function of the gradient:<br />
<br />
\begin{align}<br />
Q(\hat{s},\hat{a}) = \hat{\mathbb{E}}_{\tau_i}[log\; D(s,a) | s_0 = \hat{s}, a_0 = \hat{a}]<br />
\end{align}<br />
<br />
<br />
REINFORCE gradients suffer from high variance which makes them difficult to work with even after applying variance reduction techniques. In order to better understand the changes required to fool the discriminator we need access to the gradients of the discriminator network, which can be obtained from the Jacobian of the discriminator. This paper demonstrates the use of a forward model along with the Jacobian of the discriminator to train a policy, without using high-variance gradient estimations.<br />
<br />
= Algorithm =<br />
This section first analyzes the characteristics of the discriminator network, then describes how a forward model can enable policy imitation through GANs. Lastly, the model-based adversarial imitation learning algorithm is presented.<br />
<br />
== The discriminator network ==<br />
The discriminator network is trained to predict the conditional distribution: <math> D(s,a) = p(y|s,a) </math> where <math> y \in \{\pi_E, \pi\} </math>.<br />
<br />
The discriminator is trained on an even distribution of expert and generated examples; hence <math> p(\pi) = p(\pi_E) = \frac{1}{2} </math>. Given this, we can rearrange and factor <math> D(s,a) </math> to obtain:<br />
<br />
\begin{aligned}<br />
D(s,a) &= p(\pi|s,a) \\<br />
& = \frac{p(s,a|\pi)p(\pi)}{p(s,a|\pi)p(\pi) + p(s,a|\pi_E)p(\pi_E)} \\<br />
& = \frac{p(s,a|\pi)}{p(s,a|\pi) + p(s,a|\pi_E)} \\<br />
& = \frac{1}{1 + \frac{p(s,a|\pi_E)}{p(s,a|\pi)}} \\<br />
& = \frac{1}{1 + \frac{p(a|s,\pi_E)}{p(a|s,\pi)} \cdot \frac{p(s|\pi_E)}{p(s|\pi)}} \\<br />
\end{aligned}<br />
<br />
Define <math> \varphi(s,a) </math> and <math> \psi(s) </math> to be:<br />
<br />
\begin{aligned}<br />
\varphi(s,a) = \frac{p(a|s,\pi_E)}{p(a|s,\pi)}, \psi(s) = \frac{p(s|\pi_E)}{p(s|\pi)}<br />
\end{aligned}<br />
<br />
to get the final expression for <math> D(s,a) </math>:<br />
\begin{aligned}<br />
D(s,a) = \frac{1}{1 + \varphi(s,a)\cdot \psi(s)}<br />
\end{aligned}<br />
<br />
<math> \varphi(s,a) </math> represents a policy likelihood ratio, and <math> \psi(s) </math> represents a state distribution likelihood ratio. Based on these expressions, the paper states that the discriminator makes its decisions by answering two questions. The first question relates to the state distribution: what is the likelihood of encountering state <math> s </math> under the distribution induced by <math> \pi_E </math> vs <math> \pi </math>? The second question is about behavior: given a state <math> s </math>, how likely is action <math> a </math> under <math> \pi_E </math> vs <math> \pi </math>? The desired change in state is given by <math> \psi_s \equiv \partial \psi / \partial s </math>; this information can be obtained from the partial derivatives of <math> D(s,a) </math>:<br />
<br />
\begin{aligned}<br />
\nabla_aD &= - \frac{\varphi_a(s,a)\psi(s)}{(1 + \varphi(s,a)\psi(s))^2} \\<br />
\nabla_sD &= - \frac{\varphi_s(s,a)\psi(s) + \varphi(s,a)\psi_s(s)}{(1 + \varphi(s,a)\psi(s))^2} \\<br />
\end{aligned}<br />
<br />
<br />
== Backpropagating through stochastic units ==<br />
There is interest in training stochastic policies because stochasticity encourages exploration for Policy Gradient methods. This is a problem for algorithms that build differentiable computation graphs where the gradients flow from one component to another since it is unclear how to backpropagate through stochastic units. The following subsections show how to estimate the gradients of continuous and categorical stochastic elements for continuous and discrete action domains respectively.<br />
<br />
=== Continuous Action Distributions ===<br />
In the case of continuous action policies, re-parameterization was used to enable computing the derivatives of stochastic models. Assuming that the stochastic policy has a Gaussian distribution the policy <math> \pi </math> can be written as <math> \pi_\theta(a|s) = \mu_\theta(s) + \xi \sigma_\theta(s) </math>, where <math> \xi \sim N(0,1) </math>. This way, the authors are able to get a Monte-Carlo estimator of the derivative of the expected value of <math> D(s, a) </math> with respect to <math> \theta </math>:<br />
<br />
\begin{align}<br />
\nabla_\theta\mathbb{E}_{\pi(a|s)}D(s,a) = \mathbb{E}_{\rho (\xi )}\nabla_a D(s,a) \nabla_\theta \pi_\theta(a|s) \cong \frac{1}{M}\sum_{i=1}^{M} \nabla_a D(s,a) \nabla_\theta \pi_\theta(a|s)\Bigr|_{\substack{\xi=\xi_i}}<br />
\end{align}<br />
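<br />
The estimator above can be illustrated on a toy problem. The sketch below is an assumption-laden example, not the paper's setup: it uses a scalar policy mean <math> \mu_\theta(s) = \theta s </math>, a fixed <math> \sigma </math>, and a made-up differentiable score <math> D(s,a) = -(a-1)^2 </math>, for which the exact gradient is known in closed form.<br />

```python
import random

def reparam_gradient(theta, s, sigma=0.5, num_samples=20000, seed=0):
    """Monte-Carlo estimate of d/d(theta) E[D(s, a)] with a = theta*s + xi*sigma."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(num_samples):
        xi = rng.gauss(0.0, 1.0)       # noise sampled independently of theta
        a = theta * s + xi * sigma     # re-parameterized action
        dD_da = -2.0 * (a - 1.0)       # gradient of the toy D w.r.t. the action
        da_dtheta = s                  # gradient of the action w.r.t. theta
        total += dD_da * da_dtheta
    return total / num_samples

def exact_gradient(theta, s):
    # E[D] = -((theta*s - 1)^2 + sigma^2), so the derivative w.r.t. theta is:
    return -2.0 * (theta * s - 1.0) * s

print(reparam_gradient(0.3, 2.0), exact_gradient(0.3, 2.0))
```

Because the randomness enters only through <math> \xi </math>, the gradient passes through the sampled action as if it were a deterministic function of <math> \theta </math>.<br />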
<br />
<br />
=== Categorical Action Distributions ===<br />
In the case of discrete action domains, the paper uses categorical re-parameterization with Gumbel-Softmax. This method relies on the Gumbel-Max trick, which is a method for drawing samples from a categorical distribution with class probabilities <math> \pi(a_1|s),\pi(a_2|s),...,\pi(a_N|s) </math>:<br />
<br />
\begin{align}<br />
a_{argmax} = \underset{i}{\operatorname{argmax}}\,[g_i + \log \pi(a_i|s)]<br />
\end{align}<br />
<br />
<br />
where <math> g_i </math> are i.i.d. samples drawn from the Gumbel(0, 1) distribution. Gumbel-Softmax provides a differentiable approximation of the samples obtained using the Gumbel-Max trick:<br />
<br />
\begin{align}<br />
a_{softmax} = \frac{\exp[\frac{1}{\tau}(g_i + \log \pi(a_i|s))]}{\sum_{j=1}^{N}\exp[\frac{1}{\tau}(g_j + \log \pi(a_j|s))]}<br />
\end{align}<br />
<br />
<br />
In the above equation, the hyper-parameter <math> \tau </math> (temperature) trades bias for variance. As <math> \tau </math> approaches zero, the softmax operator behaves like argmax, resulting in low bias but high variance; conversely, a large <math> \tau </math> gives low variance but high bias.<br />
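<br />
A minimal sketch of both sampling schemes follows, using only the Python standard library. The helper names and the example probabilities are illustrative assumptions; Gumbel noise is drawn by inverse-transform sampling, <math> g = -\log(-\log u) </math> with <math> u \sim U(0,1) </math>.<br />

```python
import math
import random

def sample_gumbel(rng):
    # Gumbel(0, 1) noise via inverse transform sampling.
    u = rng.random()
    while u == 0.0:                      # avoid log(0)
        u = rng.random()
    return -math.log(-math.log(u))

def gumbel_max(probs, gumbels):
    # Index of the largest perturbed log-probability (a hard, one-hot sample).
    scores = [g + math.log(p) for g, p in zip(gumbels, probs)]
    return max(range(len(probs)), key=lambda i: scores[i])

def gumbel_softmax(probs, gumbels, tau):
    # Differentiable relaxation: softmax of the perturbed log-probabilities.
    logits = [(g + math.log(p)) / tau for g, p in zip(gumbels, probs)]
    top = max(logits)                    # subtract the max for numerical stability
    exps = [math.exp(l - top) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]

def empirical_frequencies(probs, num_samples, rng):
    # Fraction of Gumbel-Max draws that select each class.
    counts = [0] * len(probs)
    for _ in range(num_samples):
        gumbels = [sample_gumbel(rng) for _ in probs]
        counts[gumbel_max(probs, gumbels)] += 1
    return [c / num_samples for c in counts]

rng = random.Random(0)
probs = [0.2, 0.5, 0.3]
gumbels = [sample_gumbel(rng) for _ in probs]
hard = gumbel_max(probs, gumbels)
soft = gumbel_softmax(probs, gumbels, tau=0.5)
print(hard, soft)
```

For the same noise, the largest entry of the soft sample sits at the index picked by Gumbel-Max, and the empirical class frequencies of the hard samples approach the class probabilities.<br />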
<br />
The authors use <math> a_{softmax} </math> to interact with the environment; argmax is applied over <math> a_{softmax} </math> to obtain a single “pure” action, but the continuous approximation is used in the backward pass using the estimation: <math> \nabla_\theta\; a_{argmax} \approx \nabla_\theta\; a_{softmax} </math>.<br />
<br />
== Backpropagating through a Forward model ==<br />
The above subsections presented the means for extracting the partial derivative <math> \nabla_aD </math>. The main contribution of this paper is incorporating the use of <math> \nabla_sD </math>. In a model-free approach the state <math> s </math> is treated as a fixed input, therefore <math> \nabla_sD </math> is discarded. This is illustrated in Figure 1. This work uses a model-based approach which makes incorporating <math> \nabla_sD </math> more involved. In the model-based approach, a state <math> s_t </math> can be written as a function of the previous state action pair: <math> s_t = f(s_{t-1}, a_{t-1}) </math>, where <math> f </math> represents the forward model. Using the forward model and the law of total derivatives we get:<br />
<br />
\begin{align}<br />
\nabla_\theta D(s_t,a_t)\Bigr|_{\substack{s=s_t, a=a_t}} &= \frac{\partial D}{\partial a}\frac{\partial a}{\partial \theta}\Bigr|_{\substack{a=a_t}} + \frac{\partial D}{\partial s}\frac{\partial s}{\partial \theta}\Bigr|_{\substack{s=s_t}} \\<br />
&= \frac{\partial D}{\partial a}\frac{\partial a}{\partial \theta}\Bigr|_{\substack{a=a_t}} + \frac{\partial D}{\partial s}\left (\frac{\partial f}{\partial s}\frac{\partial s}{\partial \theta}\Bigr|_{\substack{s=s_{t-1}}} + \frac{\partial f}{\partial a}\frac{\partial a}{\partial \theta}\Bigr|_{\substack{a=a_{t-1}}} \right )<br />
\end{align}<br />
<br />
<br />
Using this formula, the error regarding deviations of future states <math> (\psi_s) </math> propagates back in time and influences the actions of the policy at earlier times. This is summarized in Figure 2.<br />
<br />
[[File:modelFree_blockDiagram.PNG|300px]]<br />
<br />
Figure 1: Block diagram of the model-free approach: given a state <math> s </math>, the policy outputs <math> \mu </math>, which is fed to a stochastic sampling unit. An action <math> a </math> is sampled and, together with <math> s </math>, is presented to the discriminator network. In the backward phase, the error message <math> \delta_a </math> is blocked at the stochastic sampling unit. From there, a high-variance gradient estimation is used (<math> \delta_{HV} </math>). Meanwhile, the error message <math> \delta_s </math> is flushed.<br />
<br />
[[File:modelBased_blockDiagram.PNG|1000px]]<br />
<br />
Figure 2: Block diagram of model-based adversarial imitation learning. This diagram describes the computation graph for training the policy (i.e. G). The discriminator network D is fixed at this stage and is trained separately. At time <math> t </math> of the forward pass, <math> \pi </math> outputs a distribution over actions: <math> \mu_t = \pi(s_t) </math>, from which an action <math> a_t </math> is sampled. For example, in the continuous case, this is done using the re-parameterization trick: <math> a_t = \mu_t + \xi \cdot \sigma </math>, where <math> \xi \sim N(0,1) </math>. The next state <math> s_{t+1} = f(s_t, a_t) </math> is computed using the forward model (which is also trained separately), and the entire process repeats for time <math> t+1 </math>. In the backward pass, the gradient of <math> \pi </math> is composed of (a) the error message <math> \delta_a </math> (green), which propagates through the differentiable approximation of the sampling process, and (b) the error message <math> \delta_s </math> (blue) of future time-steps, which propagates back through the differentiable forward model.<br />
<br />
== MGAIL Algorithm ==<br />
Shalev-Shwartz et al. (2016) and Heess et al. (2015) built a multi-step computation graph for describing the familiar policy gradient objective; in this case it is given by:<br />
<br />
\begin{align}<br />
J(\theta) = \mathbb{E}\left [ \sum_{t=0}^{T} \gamma ^t D(s_t,a_t)|\theta\right ]<br />
\end{align}<br />
<br />
<br />
Using the results from Heess et al. (2015), this paper demonstrates how to differentiate <math> J(\theta) </math> over a trajectory of <math>(s,a,s') </math> transitions:<br />
<br />
\begin{align}<br />
J_s &= \mathbb{E}_{p(a|s)}\mathbb{E}_{p(s'|s,a)}\left [ D_s + D_a \pi_s + \gamma J'_{s'}(f_s + f_a \pi_s) \right] \\<br />
J_\theta &= \mathbb{E}_{p(a|s)}\mathbb{E}_{p(s'|s,a)}\left [ D_a \pi_\theta + \gamma (J'_{s'} f_a \pi_\theta + J'_\theta) \right]<br />
\end{align}<br />
<br />
The policy gradient <math> \nabla_\theta J </math> is calculated by applying equations 12 and 13 recursively for <math> T </math> iterations. The MGAIL algorithm is presented below.<br />
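<br />
The recursion can be illustrated in a deterministic scalar setting, where the expectations disappear. The sketch below is only an illustration under stated assumptions: a hypothetical linear policy <math> a = \theta s </math>, a hypothetical linear forward model <math> s' = \alpha s + \beta a </math>, and a made-up differentiable score <math> D(s,a) = -(s-1)^2 - 0.1a^2 </math>. It applies the two equations backward in time and compares the result with a finite-difference estimate of <math> \nabla_\theta J </math>.<br />

```python
GAMMA = 0.95
ALPHA, BETA = 0.9, 0.5   # hypothetical linear forward model: s' = ALPHA*s + BETA*a

def D(s, a):
    # Toy differentiable score standing in for the discriminator output.
    return -(s - 1.0) ** 2 - 0.1 * a ** 2

def rollout(theta, s0, T):
    # Unroll the policy a = theta*s through the forward model for T+1 steps.
    states, actions = [], []
    s = s0
    for _ in range(T + 1):
        a = theta * s
        states.append(s)
        actions.append(a)
        s = ALPHA * s + BETA * a
    return states, actions

def objective(theta, s0, T):
    states, actions = rollout(theta, s0, T)
    return sum(GAMMA ** t * D(s, a)
               for t, (s, a) in enumerate(zip(states, actions)))

def policy_gradient(theta, s0, T):
    states, actions = rollout(theta, s0, T)
    J_s_next, J_theta_next = 0.0, 0.0    # no value beyond the horizon
    for t in range(T, -1, -1):
        s, a = states[t], actions[t]
        D_s, D_a = -2.0 * (s - 1.0), -0.2 * a
        pi_s, pi_theta = theta, s        # derivatives of a = theta*s
        f_s, f_a = ALPHA, BETA           # derivatives of the forward model
        J_s = D_s + D_a * pi_s + GAMMA * J_s_next * (f_s + f_a * pi_s)
        J_theta = D_a * pi_theta + GAMMA * (J_s_next * f_a * pi_theta + J_theta_next)
        J_s_next, J_theta_next = J_s, J_theta
    return J_theta_next

eps = 1e-5
fd = (objective(0.3 + eps, 1.5, 5) - objective(0.3 - eps, 1.5, 5)) / (2.0 * eps)
print(policy_gradient(0.3, 1.5, 5), fd)
```

The <math> J'_{s'} f_a \pi_\theta </math> term is what lets errors at later states influence the gradient of earlier actions.<br />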
<br />
[[File:MGAIL_alg.PNG]]<br />
<br />
== Forward Model Structure ==<br />
The stability of the learning process depends on the prediction accuracy of the forward model, but learning an accurate forward model is challenging by itself. The authors propose methods for improving the performance of the forward model based on two aspects of its functionality. First, the forward model should learn to use the action as an operator over the state space. To accomplish this, the actions and states, which are sampled from different distributions, first need to be represented in a shared space. This is done by encoding the state and action using two separate neural networks and combining their outputs to form a single vector. Additionally, multiple previous states are used to predict the next state by representing the environment as an <math> n^{th} </math>-order MDP. A gated recurrent unit (GRU, a simpler variant of the LSTM model) layer is incorporated into the state encoder to enable recurrent connections from previous states. Using these modifications, the model is able to achieve better and more stable results compared to the standard forward model based on a feed-forward neural network. The comparison is presented in Figure 3.<br />
<br />
[[File:performance_comparison.PNG]]<br />
<br />
Figure 3: Performance comparison between a basic forward model (Blue), and the advanced forward model (Green).<br />
<br />
= Experiments =<br />
The proposed algorithm is evaluated on three discrete control tasks (Cartpole, Mountain-Car, Acrobot), and five continuous control tasks (Hopper, Walker, Half-Cheetah, Ant, and Humanoid), which are modeled by the MuJoCo physics simulator (Todorov et al., 2012). Expert policies are trained using the Trust Region Policy Optimization (TRPO) algorithm (Schulman et al., 2015). A different number of trajectories is used to train the expert for each task, but all trajectories are of length 1000.<br />
The discriminator and generator (policy) networks contain two hidden layers with ReLU non-linearity and are trained using the Adam optimizer. The total reward received over a period of <math> N </math> steps using BC, GAIL and MGAIL is presented in Table 1. The proposed algorithm achieved the highest reward for most environments while exhibiting performance comparable to the expert over all of them.<br />
<br />
[[File:mgail_test_results.PNG]]<br />
<br />
Table 1. Policy performance, boldface indicates better results, <math> \pm </math> represents one standard deviation.<br />
<br />
= Discussion =<br />
This paper presented a model-based algorithm for imitation learning. It demonstrated how a forward model can be used to train policies using the exact gradient of the discriminator network. A downside of this approach is the need to learn a forward model, which could be difficult in certain domains. Learning the system dynamics directly from raw images is considered as one line of future work. Another direction for future work is to address the violation of the fundamental assumption made by all supervised learning algorithms, which requires the data to be i.i.d. This problem arises because the discriminator and forward models are trained in a supervised learning fashion using data sampled from a dynamic distribution.<br />
<br />
= Source =<br />
# Baram, Nir, et al. "End-to-end differentiable adversarial imitation learning." International Conference on Machine Learning. 2017.<br />
# Ho, Jonathan, and Stefano Ermon. "Generative adversarial imitation learning." Advances in Neural Information Processing Systems. 2016.<br />
# Shalev-Shwartz, Shai, et al. "Long-term planning by short-term prediction." arXiv preprint arXiv:1602.01580 (2016).<br />
# Heess, Nicolas, et al. "Learning continuous control policies by stochastic value gradients." Advances in Neural Information Processing Systems. 2015.<br />
# Schulman, John, et al. "Trust region policy optimization." International Conference on Machine Learning. 2015.</div>X249wanghttp://wiki.math.uwaterloo.ca/statwiki/index.php?title=stat946w18/IMPROVING_GANS_USING_OPTIMAL_TRANSPORT&diff=35019stat946w18/IMPROVING GANS USING OPTIMAL TRANSPORT2018-03-21T16:58:03Z<p>X249wang: /* Mini-Batch Energy Distance */</p>
<hr />
<div>== Introduction ==<br />
Generative Adversarial Networks (GANs) are powerful generative models. A GAN model consists of a generator and a discriminator or critic. The generator is a neural network which is trained to generate data having a distribution matched with the distribution of the real data. The critic is also a neural network, which is trained to separate the generated data from the real data. A loss function that measures the distribution distance between the generated data and the real one is important to train the generator.<br />
<br />
Optimal transport theory evaluates the distance between the distribution of the generated data and that of the training data based on a metric, which provides another way to train the generator. Its main advantage over the distance measurement in the original GAN is a closed-form solution that makes the training process tractable. However, it can also lead to statistically inconsistent estimation, because the gradients are biased when mini-batches are used (Bellemare et al., 2017).<br />
<br />
This paper presents a GAN variant named OT-GAN, which incorporates a discriminative metric called 'Mini-batch Energy Distance' into its critic in order to overcome the issue of biased gradients.<br />
<br />
== GANs and Optimal Transport ==<br />
<br />
===Generative Adversarial Nets===<br />
The paper first reviews the original GAN, whose objective function is: <br />
<br />
[[File:equation1.png|700px]]<br />
<br />
The goal of GANs is to train the generator g and the discriminator d to find a pair (g,d) that achieves a Nash equilibrium (such that neither of them can reduce its cost without changing the other's parameters). However, training may fail to converge, since the generator and the discriminator are trained with gradient descent techniques.<br />
<br />
===Wasserstein Distance (Earth-Mover Distance)===<br />
<br />
In order to solve the problem of convergence failure, Arjovsky et al. (2017) suggested the Wasserstein distance (Earth-Mover distance), which is based on optimal transport theory.<br />
<br />
[[File:equation2.png|600px]]<br />
<br />
where <math> \Pi (p,g) </math> is the set of all joint distributions <math> \gamma (x,y) </math> with marginals <math> p(x) </math> (real data) and <math> g(y) </math> (generated data). <math> c(x,y) </math> is a cost function; Arjovsky et al. used the Euclidean distance. <br />
<br />
The Wasserstein distance can be interpreted as the minimum cost of moving probability mass from the generator distribution <math> g(y) </math> so that it matches the real data distribution <math> p(x) </math>.<br />
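<br />
For intuition only, there is one special case where the optimal transport plan is easy to write down: two one-dimensional samples of equal size with cost <math> c(x,y) = |x-y| </math>. In that case the optimal plan pairs the sorted points. This toy case is an illustration added here and is not part of the paper.<br />

```python
def wasserstein_1d(xs, ys):
    """Exact earth-mover distance between two equally sized 1-D samples."""
    if len(xs) != len(ys):
        raise ValueError("this sketch assumes equally sized samples")
    # In 1-D with cost |x - y|, matching sorted points is an optimal plan.
    pairs = zip(sorted(xs), sorted(ys))
    return sum(abs(x - y) for x, y in pairs) / len(xs)

print(wasserstein_1d([0.0, 1.0, 2.0], [1.0, 2.0, 3.0]))
```

In higher dimensions no such shortcut exists, which is why approximations such as the dual formulation or the Sinkhorn distance are used.<br />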
<br />
Computing the Wasserstein distance exactly is intractable. The proposed Wasserstein GAN (W-GAN) provides an approximate solution by recasting the optimal transport problem in its Kantorovich-Rubinstein dual formulation over the set of 1-Lipschitz functions. A neural network can then be used to obtain an estimate.<br />
<br />
[[File:equation3.png|600px]]<br />
<br />
W-GAN helps to stabilize the training process of the original GAN, but it solves the optimal transport problem only approximately; computing the exact distance remains intractable.<br />
<br />
===Sinkhorn Distance===<br />
Genevay et al. (2017) proposed using the primal formulation of optimal transport instead of the dual formulation for generative modeling. They introduced the Sinkhorn distance, which is a smoothed generalization of the Wasserstein distance.<br />
[[File: equation4.png|600px]]<br />
<br />
It adds an entropy restriction (<math> \beta </math>) to the set of joint distributions <math> \Pi_{\beta} (p,g) </math>. This distance can be approximated on mini-batches of data <math> X, Y </math>, each consisting of <math> K </math> vectors <math> x, y </math>. The <math> (i, j) </math>-th entry of the cost matrix <math> C </math> can be interpreted as the cost of transporting <math> x_i </math> in mini-batch <math> X </math> to <math> y_j </math> in mini-batch <math> Y </math>. The resulting distance is:<br />
<br />
[[File: equation5.png|550px]]<br />
<br />
where <math> M </math> is a <math> K \times K </math> matrix with positive entries that plays the role of the joint distribution <math> \gamma (x,y) </math>. Each row and each column of <math> M </math> sums to 1. <br />
<br />
This mini-batch Sinkhorn distance is not only fully tractable but also capable of solving the instability problem of GANs. However, it is not a valid metric over probability distributions when taking the expectation of <math> \mathcal{W}_{c} </math>, and the gradients are biased when the mini-batch size is fixed.<br />
<br />
===Energy Distance (Cramer Distance)===<br />
In order to solve the above problem, Bellemare et al. proposed Energy distance:<br />
<br />
[[File: equation6.png|700px]]<br />
<br />
where <math> x, x' </math> and <math> y, y'</math> are independent samples from data distribution <math> p </math> and generator distribution <math> g </math>, respectively. Based on the energy distance, Cramer GAN minimizes this distance metric when training the generator.<br />
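<br />
The squared energy distance, <math> 2\mathbb{E}\|x-y\| - \mathbb{E}\|x-x'\| - \mathbb{E}\|y-y'\| </math>, can be estimated directly from samples. The sketch below is a simple illustration on one-dimensional Gaussian samples; the sample sizes, seeds and distributions are arbitrary assumptions.<br />

```python
import random

def mean_abs_diff(us, vs):
    # Sample estimate of E|u - v| over all pairs.
    return sum(abs(u - v) for u in us for v in vs) / (len(us) * len(vs))

def energy_distance_sq(xs, xs2, ys, ys2):
    """xs, xs2: independent samples from p; ys, ys2: independent samples from g."""
    return (2.0 * mean_abs_diff(xs, ys)
            - mean_abs_diff(xs, xs2)
            - mean_abs_diff(ys, ys2))

def gaussian_sample(mean, n, seed):
    rng = random.Random(seed)
    return [rng.gauss(mean, 1.0) for _ in range(n)]

n = 400
same = energy_distance_sq(gaussian_sample(0.0, n, 1), gaussian_sample(0.0, n, 2),
                          gaussian_sample(0.0, n, 3), gaussian_sample(0.0, n, 4))
far = energy_distance_sq(gaussian_sample(0.0, n, 1), gaussian_sample(0.0, n, 2),
                         gaussian_sample(3.0, n, 3), gaussian_sample(3.0, n, 4))
print(same, far)
```

The estimate is close to zero when both sample sets come from the same distribution and clearly positive when the distributions differ.<br />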
<br />
==Mini-Batch Energy Distance==<br />
Salimans et al. (2016) observed that a GAN is more powerful when it works with distributions over mini-batches, <math> g(X) </math> and <math> p(X) </math>, than with distributions over individual images. The distance measure is therefore defined over mini-batches.<br />
<br />
===Generalized Energy Distance===<br />
The generalized energy distance allows the use of non-Euclidean distance functions <math> d </math>. It is also valid for mini-batches, which is considered better than working with individual data points.<br />
<br />
[[File: equation7.png|670px]]<br />
<br />
As in the energy distance, <math> X, X' </math> and <math> Y, Y'</math> can be independent samples from the data distribution <math> p </math> and the generator distribution <math> g </math>, respectively. In the generalized energy distance, <math> X, X' </math> and <math> Y, Y'</math> can also be mini-batches. <math> D_{GED}(p,g) </math> is a metric whenever <math> d </math> is a metric. Thus, by the triangle inequality of <math> d </math>, <math> D(p,g) \geq 0,</math> and <math> D(p,g)=0 </math> when <math> p=g </math>.<br />
<br />
===Mini-Batch Energy Distance===<br />
Since <math> d </math> can be chosen freely, the authors proposed the Mini-batch Energy Distance, which uses the entropy-regularized Wasserstein distance as <math> d </math>. <br />
<br />
[[File: equation8.png|650px]]<br />
<br />
where <math> X, X' </math> and <math> Y, Y'</math> are independently sampled mini-batches from the data distribution <math> p </math> and the generator distribution <math> g </math>, respectively. This distance metric combines the energy distance with the primal form of optimal transport over the mini-batch distributions <math> g(Y) </math> and <math> p(X) </math>. Inside the generalized energy distance, the Sinkhorn distance is a valid metric between mini-batches. By adding the <math> - \mathcal{W}_c (Y,Y')</math> and <math> \mathcal{W}_c (X,Y)</math> terms to equation (5) and using the energy distance, the objective becomes statistically consistent (meaning the objective converges to the true parameter value for large sample sizes) and the mini-batch gradients are unbiased.<br />
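<br />
The way the terms combine can be sketched as a single-sample estimate of <math> 2\mathcal{W}_c(X,Y) - \mathcal{W}_c(X,X') - \mathcal{W}_c(Y,Y') </math>. The code below is a toy illustration under assumptions: the mini-batches are lists of scalars, the cost is <math> |x - y| </math> instead of a learned cost, and the Sinkhorn regularization <code>reg</code> and iteration count are arbitrary.<br />

```python
import math

def sinkhorn_distance(xs, ys, reg=0.1, num_iters=200):
    """Entropy-regularized transport cost between two equally sized 1-D batches."""
    k = len(xs)
    cost = [[abs(x - y) for y in ys] for x in xs]
    kernel = [[math.exp(-c / reg) for c in row] for row in cost]
    u, v = [1.0] * k, [1.0] * k
    for _ in range(num_iters):
        # Alternate row and column rescaling so the matching has unit sums.
        u = [1.0 / sum(kernel[i][j] * v[j] for j in range(k)) for i in range(k)]
        v = [1.0 / sum(kernel[i][j] * u[i] for i in range(k)) for j in range(k)]
    return sum(u[i] * kernel[i][j] * v[j] * cost[i][j]
               for i in range(k) for j in range(k))

def minibatch_energy_distance_sq(x, x2, y, y2):
    # x, x2: mini-batches of real data; y, y2: mini-batches of generated data.
    attract = 2.0 * sinkhorn_distance(x, y)                       # real vs generated
    repel = sinkhorn_distance(x, x2) + sinkhorn_distance(y, y2)   # within-distribution terms
    return attract - repel

real_a, real_b = [0.0, 1.0], [0.1, 0.9]
fake_a, fake_b = [5.0, 6.0], [5.1, 5.9]
print(minibatch_energy_distance_sq(real_a, real_b, fake_a, fake_b))
```

The within-distribution terms are what distinguish this objective from the plain mini-batch Sinkhorn distance.<br />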
<br />
==Optimal Transport GAN (OT-GAN)==<br />
<br />
In order to secure statistical efficiency (i.e. being able to tell <math>p</math> and <math>g</math> apart without requiring an enormous sample size when their distance is close to zero), the authors suggested using the cosine distance between vectors <math> v_\eta (x) </math> and <math> v_\eta (y) </math>, where <math> v_\eta </math> is a deep neural network that maps the mini-batch data to a learned latent space. Euclidean distance is not used because it performs poorly in high-dimensional spaces. The transportation cost is:<br />
<br />
[[File: euqation9.png|370px]]<br />
<br />
where the <math> v_\eta </math> is chosen to maximize the resulting minibatch energy distance.<br />
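<br />
The cosine-distance cost, one minus the cosine similarity of the two embeddings, is straightforward to compute once the embeddings are available. In the sketch below the critic network is omitted: the input vectors are assumed to already be the outputs <math> v_\eta(x) </math> and <math> v_\eta(y) </math>, and the function names are illustrative.<br />

```python
import math

def cosine_cost(vx, vy):
    """1 - cosine similarity between two embedding vectors."""
    dot = sum(a * b for a, b in zip(vx, vy))
    norm_x = math.sqrt(sum(a * a for a in vx))
    norm_y = math.sqrt(sum(b * b for b in vy))
    return 1.0 - dot / (norm_x * norm_y)

def cost_matrix(embedded_x, embedded_y):
    # C[i][j] = cosine cost between the i-th and j-th embedded batch elements.
    return [[cosine_cost(vx, vy) for vy in embedded_y] for vx in embedded_x]

print(cost_matrix([[1.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [-2.0, 0.0]]))
```

The cost depends only on the direction of the embeddings, not on their length, and ranges from 0 (same direction) to 2 (opposite directions).<br />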
<br />
Unlike common practice with the original GAN, the generator is trained more often than the critic, which keeps the cost function from degenerating. The resulting generator in OT-GAN has a well-defined and statistically consistent objective throughout the training process.<br />
<br />
The algorithm is defined below. By the envelope theorem, gradients do not need to be backpropagated through the computation of the matchings <math> M </math>. Stochastic gradient descent is used as the optimization method in Algorithm 1 below, although other optimizers are also possible; Adam was used in the experiments. <br />
<br />
[[File: al.png|600px]]<br />
<br />
<br />
[[File: al_figure.png|600px]]<br />
<br />
==Experiments==<br />
<br />
In order to demonstrate the superior performance of OT-GAN, the authors compared it with the original GAN and other popular models in four experiments: dataset recovery, a CIFAR-10 test, an ImageNet test, and a conditional image synthesis test.<br />
<br />
===Mixture of Gaussian Dataset===<br />
OT-GAN has a statistically consistent objective, unlike the original GAN (DC-GAN), so the generator does not update in a wrong direction even if the signal provided by the cost function is poor. To demonstrate this advantage, the authors compared OT-GAN with the original GAN loss (DAN-S) on a simple task: recovering all 8 modes of a mixture of 8 Gaussians whose means are arranged in a circle. MLPs with ReLU activation functions were used in this task. The critic was updated for only 15K iterations, after which the generator distribution was tracked for another 25K iterations. The results showed that the original GAN experiences mode collapse after the discriminator is fixed, while OT-GAN recovers all 8 modes of the Gaussian mixture.<br />
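<br />
A dataset of this kind can be generated in a few lines. The sketch below is only a plausible reconstruction: the radius, the standard deviation, and the tolerance used to count recovered modes are assumptions, since the text does not state them.<br />

```python
import math
import random

def ring_centers(num_modes=8, radius=2.0):
    # Means of the Gaussians, evenly spaced on a circle.
    return [(radius * math.cos(2.0 * math.pi * k / num_modes),
             radius * math.sin(2.0 * math.pi * k / num_modes))
            for k in range(num_modes)]

def sample_ring_mixture(num_samples, num_modes=8, radius=2.0, std=0.02, seed=0):
    rng = random.Random(seed)
    centers = ring_centers(num_modes, radius)
    points = []
    for _ in range(num_samples):
        cx, cy = centers[rng.randrange(num_modes)]   # pick a mode uniformly
        points.append((rng.gauss(cx, std), rng.gauss(cy, std)))
    return points

def modes_covered(points, num_modes=8, radius=2.0, tol=0.2):
    # Number of modes that have at least one point within distance tol.
    covered = set()
    for x, y in points:
        for k, (cx, cy) in enumerate(ring_centers(num_modes, radius)):
            if math.hypot(x - cx, y - cy) < tol:
                covered.add(k)
    return len(covered)

data = sample_ring_mixture(1000)
print(modes_covered(data))
```

Applying <code>modes_covered</code> to generator samples gives a simple count of recovered modes; a value below 8 indicates that some modes were dropped.<br />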
<br />
[[File: 5_1.png|600px]]<br />
<br />
===CIFAR-10===<br />
<br />
The CIFAR-10 dataset was then used to inspect the effect of batch size on the training process and the image quality. OT-GAN and four other methods were compared using the inception score as the criterion. Figure 3 shows how the inception score (y-axis) changes as the number of iterations increases. Scores for four different batch sizes (200, 800, 3200 and 8000) were compared. The results show that a larger batch size, which is more likely to cover more modes of the data distribution, leads to a more stable model with a higher inception score. However, a large batch size also requires a high-performance computational environment. The sample quality of all 5 methods, run with a batch size of 8000, is compared in Table 1, where OT-GAN has the best score.<br />
<br />
[[File: 5_2.png|600px]]<br />
<br />
===ImageNet Dogs===<br />
<br />
In order to investigate the performance of OT-GAN on high-quality images, the dog subset of ImageNet (<math> 128 \times 128 </math>) was used to train the model. Figure 6 shows that OT-GAN produces fewer nonsensical images and has a higher inception score compared to DC-GAN. <br />
<br />
[[File: 5_3.png|600px]]<br />
<br />
===Conditional Generation of Birds===<br />
<br />
The last experiment compared OT-GAN with three popular GAN models on text-to-image generation, demonstrating its performance on conditional image synthesis. As shown in Table 2, OT-GAN achieved a higher inception score than the other three models. <br />
<br />
[[File: 5_4.png|600px]]<br />
<br />
The results above were obtained with a conditional variant of '''Algorithm 1''' that includes conditional information <math>s</math>, such as a text description of an image. The modified algorithm is outlined in '''Algorithm 2'''.<br />
<br />
[[File: paper23_alg2.png|600px]]<br />
<br />
==Conclusion==<br />
<br />
In this paper, the OT-GAN method was proposed based on optimal transport theory. A distance metric that combines the primal form of optimal transport with the energy distance was presented for realizing OT-GAN. One of the advantages of OT-GAN over other GAN models is that it can stay on the correct track with an unbiased gradient even if the training of the critic is stopped or the critic provides a weak cost signal. The performance of OT-GAN is maintained as the batch size increases, though the computational cost has to be taken into consideration.<br />
<br />
==Critique==<br />
<br />
The paper presents a variant of GANs by defining a new distance metric based on the primal form of optimal transport and the mini-batch energy distance. Its stability was demonstrated in four experiments that compare OT-GAN with other popular methods. However, limitations in computational efficiency are not discussed much. Furthermore, section 2 of the paper lacks an explanation of why mini-batches, rather than individual vectors, are used as input when applying the Sinkhorn distance. The explanation in section 4 of how M is chosen to minimize <math> \mathcal{W}_c </math> is also confusing. Lastly, the paper lacks a side-by-side comparison with existing GAN variants, so readers may feel that it jumps from one algorithm to another without the necessary explanations.<br />
<br />
==Reference==<br />
Salimans, Tim, Han Zhang, Alec Radford, and Dimitris Metaxas. "Improving GANs using optimal transport." (2018).</div>X249wanghttp://wiki.math.uwaterloo.ca/statwiki/index.php?title=stat946w18/IMPROVING_GANS_USING_OPTIMAL_TRANSPORT&diff=35018stat946w18/IMPROVING GANS USING OPTIMAL TRANSPORT2018-03-21T16:50:50Z<p>X249wang: /* Optimal Transport GAN (OT-GAN) */</p>
<hr />
<div>== Introduction ==<br />
Generative Adversarial Networks (GANs) are powerful generative models. A GAN model consists of a generator and a discriminator or critic. The generator is a neural network which is trained to generate data having a distribution matched with the distribution of the real data. The critic is also a neural network, which is trained to separate the generated data from the real data. A loss function that measures the distribution distance between the generated data and the real one is important to train the generator.<br />
<br />
Optimal transport theory evaluates the distance between the distribution of the generated data and that of the training data based on a metric, which provides another way to train the generator. Its main advantage over the distance measurement in the original GAN is a closed-form solution that makes the training process tractable. However, it can also lead to statistically inconsistent estimation, because the gradients are biased when mini-batches are used (Bellemare et al., 2017).<br />
<br />
This paper presents a GAN variant named OT-GAN, which incorporates a discriminative metric called 'Mini-batch Energy Distance' into its critic in order to overcome the issue of biased gradients.<br />
<br />
== GANs and Optimal Transport ==<br />
<br />
===Generative Adversarial Nets===<br />
The paper first reviews the original GAN, whose objective function is: <br />
<br />
[[File:equation1.png|700px]]<br />
<br />
The goal of GANs is to train the generator g and the discriminator d to find a pair (g,d) that achieves a Nash equilibrium (such that neither of them can reduce its cost without changing the other's parameters). However, training may fail to converge, since the generator and the discriminator are trained with gradient descent techniques.<br />
<br />
===Wasserstein Distance (Earth-Mover Distance)===<br />
<br />
In order to solve the problem of convergence failure, Arjovsky et al. (2017) suggested the Wasserstein distance (Earth-Mover distance), which is based on optimal transport theory.<br />
<br />
[[File:equation2.png|600px]]<br />
<br />
where <math> \Pi (p,g) </math> is the set of all joint distributions <math> \gamma (x,y) </math> with marginals <math> p(x) </math> (real data) and <math> g(y) </math> (generated data). <math> c(x,y) </math> is a cost function; Arjovsky et al. used the Euclidean distance. <br />
<br />
The Wasserstein distance can be interpreted as the minimum cost of moving probability mass from the generator distribution <math> g(y) </math> so that it matches the real data distribution <math> p(x) </math>.<br />
<br />
Computing the Wasserstein distance exactly is intractable. The proposed Wasserstein GAN (W-GAN) provides an approximate solution by recasting the optimal transport problem in its Kantorovich-Rubinstein dual formulation over the set of 1-Lipschitz functions. A neural network can then be used to obtain an estimate.<br />
<br />
[[File:equation3.png|600px]]<br />
<br />
W-GAN helps to stabilize the training process of the original GAN, but it solves the optimal transport problem only approximately; computing the exact distance remains intractable.<br />
<br />
===Sinkhorn Distance===<br />
Genevay et al. (2017) proposed using the primal formulation of optimal transport instead of the dual formulation for generative modeling. They introduced the Sinkhorn distance, which is a smoothed generalization of the Wasserstein distance.<br />
[[File: equation4.png|600px]]<br />
<br />
It adds an entropy restriction (<math> \beta </math>) to the set of joint distributions <math> \Pi_{\beta} (p,g) </math>. This distance can be approximated on mini-batches of data <math> X, Y </math>, each consisting of <math> K </math> vectors <math> x, y </math>. The <math> (i, j) </math>-th entry of the cost matrix <math> C </math> can be interpreted as the cost of transporting <math> x_i </math> in mini-batch <math> X </math> to <math> y_j </math> in mini-batch <math> Y </math>. The resulting distance is:<br />
<br />
[[File: equation5.png|550px]]<br />
<br />
where <math> M </math> is a <math> K \times K </math> matrix with positive entries that plays the role of the joint distribution <math> \gamma (x,y) </math>. Each row and each column of <math> M </math> sums to 1. <br />
<br />
This mini-batch Sinkhorn distance is not only fully tractable but also capable of solving the instability problem of GANs. However, it is not a valid metric over probability distributions when taking the expectation of <math> \mathcal{W}_{c} </math>, and the gradients are biased when the mini-batch size is fixed.<br />
<br />
===Energy Distance (Cramer Distance)===<br />
In order to solve the above problem, Bellemare et al. proposed Energy distance:<br />
<br />
[[File: equation6.png|700px]]<br />
<br />
where <math> x, x' </math> and <math> y, y'</math> are independent samples from data distribution <math> p </math> and generator distribution <math> g </math>, respectively. Based on the energy distance, Cramer GAN minimizes this distance metric when training the generator.<br />
<br />
==Mini-Batch Energy Distance==<br />
Salimans et al. (2016) observed that a GAN is more powerful when it works with distributions over mini-batches, <math> g(X) </math> and <math> p(X) </math>, than with distributions over individual images. The distance measure is therefore defined over mini-batches.<br />
<br />
===Generalized Energy Distance===<br />
The generalized energy distance allows the use of non-Euclidean distance functions <math> d </math>. It is also valid for mini-batches, which is considered better than working with individual data points.<br />
<br />
[[File: equation7.png|670px]]<br />
<br />
As in the energy distance, <math> X, X' </math> and <math> Y, Y'</math> can be independent samples from the data distribution <math> p </math> and the generator distribution <math> g </math>, respectively. In the generalized energy distance, <math> X, X' </math> and <math> Y, Y'</math> can also be mini-batches. <math> D_{GED}(p,g) </math> is a metric whenever <math> d </math> is a metric. Thus, by the triangle inequality of <math> d </math>, <math> D(p,g) \geq 0,</math> and <math> D(p,g)=0 </math> when <math> p=g </math>.<br />
<br />
===Mini-Batch Energy Distance===<br />
Since <math> d </math> can be chosen freely, the authors proposed the Mini-batch Energy Distance, which uses the entropy-regularized Wasserstein distance as <math> d </math>. <br />
<br />
[[File: equation8.png|650px]]<br />
<br />
where <math> X, X' </math> and <math> Y, Y'</math> are independently sampled mini-batches from the data distribution <math> p </math> and the generator distribution <math> g </math>, respectively. This distance metric combines the energy distance with the primal form of optimal transport over the mini-batch distributions <math> g(Y) </math> and <math> p(X) </math>. Inside the generalized energy distance, the Sinkhorn distance is a valid metric between mini-batches. By adding the <math> - \mathcal{W}_c (Y,Y')</math> and <math> \mathcal{W}_c (X,Y)</math> terms to equation (5) and using the energy distance, the objective becomes statistically consistent and the mini-batch gradients are unbiased.<br />
<br />
==Optimal Transport GAN (OT-GAN)==<br />
<br />
In order to secure statistical efficiency (i.e. being able to tell <math>p</math> and <math>g</math> apart without requiring an enormous sample size when their distance is close to zero), the authors suggested using the cosine distance between vectors <math> v_\eta (x) </math> and <math> v_\eta (y) </math>, where <math> v_\eta </math> is a deep neural network that maps the mini-batch data to a learned latent space. Euclidean distance is not used because it performs poorly in high-dimensional spaces. The transportation cost is:<br />
<br />
[[File: euqation9.png|370px]]<br />
<br />
where <math> v_\eta </math> is chosen to maximize the resulting mini-batch energy distance.<br />
<br />
Unlike common practice with the original GANs, the generator was trained more often than the critic, which keeps the cost function from degenerating. The resulting generator in OT-GAN has a well-defined and statistically consistent objective throughout the training process.<br />
<br />
The algorithm is defined below. By the envelope theorem, gradients do not need to be backpropagated through the computation of the matchings <math> M </math>. Stochastic gradient descent is used as the optimization method in Algorithm 1 below, although other optimizers are also possible. In fact, Adam was used in the experiments. <br />
<br />
[[File: al.png|600px]]<br />
<br />
<br />
[[File: al_figure.png|600px]]<br />
<br />
==Experiments==<br />
<br />
To demonstrate the superior performance of OT-GAN, the authors compared it with the original GAN and other popular models in four experiments: dataset recovery, a CIFAR-10 test, an ImageNet test, and a conditional image synthesis test.<br />
<br />
===Mixture of Gaussian Dataset===<br />
Compared with the original GAN (DC-GAN), OT-GAN has a statistically consistent objective, so the generator should not update in a wrong direction even if the signal provided by the cost function is poor. To demonstrate this advantage, the authors compared OT-GAN with the original GAN loss (DAN-S) on a simple task: recovering all 8 modes of a mixture of 8 Gaussians whose means are arranged in a circle. MLPs with ReLU activation functions were used in this task. The critic was only updated for 15K iterations, and the generator distribution was tracked for another 25K iterations. The results showed that the original GAN experiences mode collapse after the discriminator is fixed, while OT-GAN recovers all 8 modes of the Gaussian mixture.<br />
<br />
[[File: 5_1.png|600px]]<br />
<br />
===CIFAR-10===<br />
<br />
The CIFAR-10 dataset was then used to inspect the effect of batch size on the training process and the image quality. OT-GAN and four other methods were compared using the "inception score" as the criterion. Figure 3 shows how the inception score (y-axis) changes as the number of iterations increases. Scores for four different batch sizes (200, 800, 3200 and 8000) were compared. The results show that a larger batch size, which is more likely to cover more modes of the data distribution, leads to a more stable model with a higher inception score. However, a large batch size also requires a high-performance computational environment. The sample quality of all 5 methods, run with a batch size of 8000, is compared in Table 1, where OT-GAN has the best score.<br />
<br />
[[File: 5_2.png|600px]]<br />
<br />
===ImageNet Dogs===<br />
<br />
To investigate the performance of OT-GAN on higher-resolution images, the dog subset of ImageNet (128*128) was used to train the model. Figure 6 shows that OT-GAN produces fewer nonsensical images and has a higher inception score compared to DC-GAN. <br />
<br />
[[File: 5_3.png|600px]]<br />
<br />
===Conditional Generation of Birds===<br />
<br />
The last experiment compared OT-GAN with three popular GAN models on text-to-image generation, demonstrating its performance on conditional image synthesis. As shown in Table 2, OT-GAN received the highest inception score of the four models. <br />
<br />
[[File: 5_4.png|600px]]<br />
<br />
The algorithm used to obtain the results above generalizes '''Algorithm 1''' to conditional generation by including conditional information <math>s</math>, such as a text description of an image. The modified algorithm is outlined in '''Algorithm 2'''.<br />
<br />
[[File: paper23_alg2.png|600px]]<br />
<br />
==Conclusion==<br />
<br />
In this paper, the OT-GAN method was proposed based on optimal transport theory. A distance metric that combines the primal form of optimal transport and the energy distance was presented for realizing OT-GAN. One of the advantages of OT-GAN over other GAN models is that it can stay on the correct track with an unbiased gradient even if training of the critic is stopped or the critic presents a weak cost signal. The performance of OT-GAN is maintained as the batch size increases, though the computational cost has to be taken into consideration.<br />
<br />
==Critique==<br />
<br />
The paper presents a variant of GANs by defining a new distance metric based on the primal form of optimal transport and the mini-batch energy distance. Its stability was demonstrated in four experiments comparing OT-GAN with other popular methods. However, limitations in computational efficiency were not discussed much. Furthermore, in section 2, the paper lacks an explanation of why mini-batches rather than individual vectors are used as input when applying the Sinkhorn distance. The explanation in section 4 of how M is chosen to minimize <math> \mathcal{W}_c </math> is also confusing. Lastly, the paper lacks a parallel comparison with existing GAN variants. Readers may feel that it jumps from one algorithm to another without the necessary explanations.<br />
<br />
==Reference==<br />
Salimans, Tim, Han Zhang, Alec Radford, and Dimitris Metaxas. "Improving GANs using optimal transport." (2018).</div>X249wanghttp://wiki.math.uwaterloo.ca/statwiki/index.php?title=stat946w18/IMPROVING_GANS_USING_OPTIMAL_TRANSPORT&diff=35017stat946w18/IMPROVING GANS USING OPTIMAL TRANSPORT2018-03-21T16:05:43Z<p>X249wang: /* CIFAR-10 */</p>
<hr />
<div>== Introduction ==<br />
Generative Adversarial Networks (GANs) are powerful generative models. A GAN consists of a generator and a discriminator, or critic. The generator is a neural network trained to generate data whose distribution matches that of the real data. The critic is also a neural network, trained to separate the generated data from the real data. A loss function that measures the distance between the distributions of the generated and real data is important for training the generator.<br />
<br />
Optimal transport theory evaluates the distribution distance between the generated data and the training data based on a metric, which provides another method for generator training. The main advantage of optimal transport theory over the distance measurement in GAN is its closed form solution, which gives a tractable training process. However, it might also result in inconsistent statistical estimation due to biased gradients when mini-batches are used (Bellemare et al., 2017).<br />
<br />
This paper presents a GAN variant named OT-GAN, which incorporates a discriminative metric called 'Mini-batch Energy Distance' into its critic in order to overcome the issue of biased gradients.<br />
<br />
== GANs and Optimal Transport ==<br />
<br />
===Generative Adversarial Nets===<br />
The original GAN is reviewed first. The objective function of the GAN is: <br />
<br />
[[File:equation1.png|700px]]<br />
<br />
The goal of GANs is to train the generator g and the discriminator d to find a pair (g,d) that achieves a Nash equilibrium (such that neither of them can reduce its cost without changing the other's parameters). However, training may fail to converge, since the generator and the discriminator are trained with gradient descent techniques.<br />
<br />
===Wasserstein Distance (Earth-Mover Distance)===<br />
<br />
To address the problem of convergence failure, Arjovsky et al. (2017) suggested the Wasserstein distance (Earth-Mover distance), which is based on optimal transport theory.<br />
<br />
[[File:equation2.png|600px]]<br />
<br />
where <math> \Pi (p,g) </math> is the set of all joint distributions <math> \gamma (x,y) </math> with marginals <math> p(x) </math> (real data) and <math> g(y) </math> (generated data). <math> c(x,y) </math> is a cost function; Arjovsky et al. used the Euclidean distance. <br />
<br />
The Wasserstein distance can be interpreted as the minimum cost of moving probability mass so that the generator distribution <math> g(y) </math> matches the real data distribution <math> p(x) </math>.<br />
<br />
Computing the Wasserstein distance exactly is intractable. The Wasserstein GAN (W-GAN) provides an approximate solution by recasting the optimal transport problem in its Kantorovich-Rubinstein dual formulation over the set of 1-Lipschitz functions. A neural network can then be used to obtain an estimate.<br />
<br />
[[File:equation3.png|600px]]<br />
<br />
W-GAN helps to stabilize the training process of the original GAN and can solve the optimal transport problem approximately, but the exact problem remains intractable.<br />
<br />
===Sinkhorn Distance===<br />
Genevay et al. (2017) proposed using the primal formulation of optimal transport, instead of the dual formulation, for generative modeling. They introduced the Sinkhorn distance, which is a smoothed generalization of the Wasserstein distance.<br />
[[File: equation4.png|600px]]<br />
<br />
It introduces an entropy restriction (<math> \beta </math>) on the joint distribution <math> \Pi_{\beta} (p,g) </math>. This distance can be generalized to approximate the distance between mini-batches of data <math> X ,Y</math>, each consisting of <math> K </math> vectors <math> x, y</math>. The <math> (i, j) </math>-th entry of the cost matrix <math> C </math> can be interpreted as the cost of transporting <math> x_i </math> in mini-batch <math> X </math> to <math> y_j </math> in mini-batch <math>Y </math>. The resulting distance is:<br />
<br />
[[File: equation5.png|550px]]<br />
<br />
where <math> M </math> is a <math> K \times K </math> matrix with positive entries that plays the role of the joint distribution <math> \gamma (x,y) </math>. Each row and each column of <math> M </math> sums to 1. <br />
<br />
This mini-batch Sinkhorn distance is not only fully tractable but also capable of solving the instability problem of GANs. However, the expectation of <math> \mathcal{W}_{c} </math> over mini-batches is not a valid metric over probability distributions, and the gradients are biased when the mini-batch size is fixed.<br />
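<br />
The mini-batch computation described above can be sketched with plain Sinkhorn scaling. This is only an illustrative sketch, not the authors' implementation: the function name <code>sinkhorn_distance</code>, the regularization strength <code>reg</code> and the fixed iteration count are assumptions, and the marginals are set so that every row and column of the matching matrix sums to 1, as stated above.<br />

```python
import numpy as np

def sinkhorn_distance(C, reg=0.5, n_iters=500):
    """Entropy-regularized transport cost for a (K, K) cost matrix C.

    Hypothetical helper: returns (cost, M), where M is a positive matrix
    whose rows and columns each sum (approximately) to 1.
    """
    K = C.shape[0]
    kernel = np.exp(-C / reg)          # Gibbs kernel; smaller reg = sharper matching
    u = np.ones(K)
    v = np.ones(K)
    for _ in range(n_iters):
        u = 1.0 / (kernel @ v)         # rescale so that rows sum to 1
        v = 1.0 / (kernel.T @ u)       # rescale so that columns sum to 1
    M = u[:, None] * kernel * v[None, :]
    return float(np.sum(M * C)), M
```

A smaller <code>reg</code> brings the result closer to the unregularized transport cost, at the price of slower convergence and possible numerical underflow in the kernel.<br />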
<br />
===Energy Distance (Cramer Distance)===<br />
To solve the above problem, Bellemare et al. proposed the energy distance:<br />
<br />
[[File: equation6.png|700px]]<br />
<br />
where <math> x, x' </math> and <math> y, y'</math> are independent samples from the data distribution <math> p </math> and the generator distribution <math> g </math>, respectively. Based on the energy distance, Cramer GAN minimizes the energy distance metric when training the generator.<br />
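<br />
A sample-based estimate of the squared energy distance can be sketched as follows. Since the equation is only available as an image here, the sketch assumes the standard form <math> 2E\|x-y\| - E\|x-x'\| - E\|y-y'\| </math>; the function name and the row-wise pairing of samples are illustrative choices.<br />

```python
import numpy as np

def energy_distance_sq(x, x2, y, y2):
    """Monte Carlo estimate of the squared energy distance.

    x, x2: independent samples from p; y, y2: independent samples from g.
    All arrays have shape (n, dim) and are paired row by row.
    """
    def mean_dist(a, b):
        # Average Euclidean distance between paired rows
        return np.linalg.norm(a - b, axis=1).mean()

    return 2.0 * mean_dist(x, y) - mean_dist(x, x2) - mean_dist(y, y2)
```

The estimate is close to zero when both sample sets come from the same distribution and grows as the two distributions move apart.<br />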
<br />
==Mini-Batch Energy Distance==<br />
Salimans et al. (2016) observed that GANs become more powerful when they work with distributions over mini-batches <math> g(X), p(X) </math> rather than distributions over individual images. The distance measure is therefore defined over mini-batches.<br />
<br />
===Generalized Energy Distance===<br />
The generalized energy distance allows the use of non-Euclidean distance functions d. It is also valid for mini-batches, which is considered better than working with individual data points.<br />
<br />
[[File: equation7.png|670px]]<br />
<br />
As in the energy distance, <math> X, X' </math> and <math> Y, Y'</math> can be independent samples from the data distribution <math> p </math> and the generator distribution <math> g </math>, respectively. In the generalized energy distance, <math> X, X' </math> and <math> Y, Y'</math> can also be mini-batches. <math> D_{GED}(p,g) </math> is a metric whenever <math> d </math> is a metric. In particular, by the triangle inequality of <math> d </math>, <math> D(p,g) \geq 0</math>, and <math> D(p,g)=0 </math> when <math> p=g </math>.<br />
<br />
===Mini-Batch Energy Distance===<br />
Since <math> d </math> is free to choose, the authors proposed the Mini-batch Energy Distance, which uses the entropy-regularized Wasserstein distance as <math> d </math>. <br />
<br />
[[File: equation8.png|650px]]<br />
<br />
where <math> X, X' </math> and <math> Y, Y'</math> are independently sampled mini-batches from the data distribution <math> p </math> and the generator distribution <math> g </math>, respectively. This distance metric combines the energy distance with the primal form of optimal transport over the mini-batch distributions <math> g(Y) </math> and <math> p(X) </math>. Inside the generalized energy distance, the Sinkhorn distance is a valid metric between mini-batches. By adding the terms <math> - \mathcal{W}_c (Y,Y')</math> and <math> \mathcal{W}_c (X,Y)</math> to equation (5) and using the energy distance, the objective becomes statistically consistent and the mini-batch gradients are unbiased.<br />
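<br />
The way the Sinkhorn terms combine can be sketched for a single draw of four mini-batches. This sketch assumes the generalized energy distance has the form <math> 2E[\mathcal{W}_c(X,Y)] - E[\mathcal{W}_c(X,X')] - E[\mathcal{W}_c(Y,Y')] </math> and averages the cross term over the four available pairings. The helper names are hypothetical, and a plain Euclidean cost stands in for the learned cost that the paper uses.<br />

```python
import numpy as np

def _sinkhorn_cost(C, reg=0.5, n_iters=100):
    # Entropy-regularized transport cost with all marginals equal to 1
    kernel = np.exp(-C / reg)
    v = np.ones(C.shape[1])
    for _ in range(n_iters):
        u = 1.0 / (kernel @ v)
        v = 1.0 / (kernel.T @ u)
    M = u[:, None] * kernel * v[None, :]
    return float(np.sum(M * C))

def _pairwise_cost(A, B):
    # Placeholder cost (assumption): Euclidean distance between all rows of A and B
    return np.linalg.norm(A[:, None, :] - B[None, :, :], axis=2)

def minibatch_energy_distance_sq(X, X2, Y, Y2):
    """One-draw estimate from mini-batches X, X2 ~ p and Y, Y2 ~ g."""
    def w(A, B):
        return _sinkhorn_cost(_pairwise_cost(A, B))

    cross = 0.25 * (w(X, Y) + w(X, Y2) + w(X2, Y) + w(X2, Y2))
    return 2.0 * cross - w(X, X2) - w(Y, Y2)
```

The two subtracted terms are what distinguish this objective from the plain mini-batch Sinkhorn distance: they compare each distribution with an independent mini-batch of itself.<br />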
<br />
==Optimal Transport GAN (OT-GAN)==<br />
<br />
In order to secure statistical efficiency (i.e. being able to tell <math>p</math> and <math>g</math> apart without requiring an enormous sample size when their distance is close to zero), the authors suggested using the cosine distance between vectors <math> v_\eta (x) </math> and <math> v_\eta (y) </math>, where <math> v_\eta </math> is a deep neural network that maps the mini-batch data to a learned latent space. Euclidean distance is not used because it performs poorly in high-dimensional spaces. The transportation cost is:<br />
<br />
[[File: euqation9.png|370px]]<br />
<br />
where <math> v_\eta </math> is chosen to maximize the resulting mini-batch energy distance.<br />
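<br />
The cosine transport cost can be sketched as a cost matrix between two mini-batches of feature vectors. The inputs stand for the critic outputs <math> v_\eta (x) </math> and <math> v_\eta (y) </math>; the function name and the small <code>eps</code> guard against zero-length vectors are assumptions of this sketch.<br />

```python
import numpy as np

def cosine_cost_matrix(fx, fy, eps=1e-12):
    """Pairwise cosine distance 1 - cos(angle) between rows of fx and fy.

    fx: (K, d) features of mini-batch X; fy: (K, d) features of mini-batch Y.
    Entry (i, j) is the cost of transporting x_i to y_j.
    """
    fx_n = fx / (np.linalg.norm(fx, axis=1, keepdims=True) + eps)
    fy_n = fy / (np.linalg.norm(fy, axis=1, keepdims=True) + eps)
    return 1.0 - fx_n @ fy_n.T
```

Because the cost depends only on the angle between feature vectors, it stays within the range 0 to 2 regardless of the feature dimension.<br />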
<br />
Unlike common practice with the original GANs, the generator was trained more often than the critic, which keeps the cost function from degenerating. The resulting generator in OT-GAN has a well-defined and statistically consistent objective throughout the training process.<br />
<br />
The algorithm is defined below. By the envelope theorem, gradients do not need to be backpropagated through the computation of the matchings <math> M </math>. Stochastic gradient descent is used as the optimization method. <br />
<br />
[[File: al.png|600px]]<br />
<br />
<br />
[[File: al_figure.png|600px]]<br />
<br />
==Experiments==<br />
<br />
To demonstrate the superior performance of OT-GAN, the authors compared it with the original GAN and other popular models in four experiments: dataset recovery, a CIFAR-10 test, an ImageNet test, and a conditional image synthesis test.<br />
<br />
===Mixture of Gaussian Dataset===<br />
Compared with the original GAN (DC-GAN), OT-GAN has a statistically consistent objective, so the generator should not update in a wrong direction even if the signal provided by the cost function is poor. To demonstrate this advantage, the authors compared OT-GAN with the original GAN loss (DAN-S) on a simple task: recovering all 8 modes of a mixture of 8 Gaussians whose means are arranged in a circle. MLPs with ReLU activation functions were used in this task. The critic was only updated for 15K iterations, and the generator distribution was tracked for another 25K iterations. The results showed that the original GAN experiences mode collapse after the discriminator is fixed, while OT-GAN recovers all 8 modes of the Gaussian mixture.<br />
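<br />
A toy dataset of this kind can be generated as sketched below. The summary does not state the circle radius or the standard deviation of each component, so the values of <code>radius</code> and <code>std</code> here are assumptions chosen only for illustration.<br />

```python
import numpy as np

def sample_ring_mixture(n, n_modes=8, radius=2.0, std=0.02, seed=0):
    """Draw n points from a mixture of Gaussians with means on a circle.

    Returns (points, labels, centers); radius and std are illustrative values.
    """
    rng = np.random.default_rng(seed)
    angles = 2.0 * np.pi * np.arange(n_modes) / n_modes
    centers = radius * np.stack([np.cos(angles), np.sin(angles)], axis=1)
    labels = rng.integers(0, n_modes, size=n)       # pick a mode uniformly at random
    points = centers[labels] + std * rng.standard_normal((n, 2))
    return points, labels, centers
```

Counting how many of the 8 centers have generated samples nearby gives a simple check for mode collapse.<br />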
<br />
[[File: 5_1.png|600px]]<br />
<br />
===CIFAR-10===<br />
<br />
The CIFAR-10 dataset was then used to inspect the effect of batch size on the training process and the image quality. OT-GAN and four other methods were compared using the "inception score" as the criterion. Figure 3 shows how the inception score (y-axis) changes as the number of iterations increases. Scores for four different batch sizes (200, 800, 3200 and 8000) were compared. The results show that a larger batch size, which is more likely to cover more modes of the data distribution, leads to a more stable model with a higher inception score. However, a large batch size also requires a high-performance computational environment. The sample quality of all 5 methods, run with a batch size of 8000, is compared in Table 1, where OT-GAN has the best score.<br />
<br />
[[File: 5_2.png|600px]]<br />
<br />
===ImageNet Dogs===<br />
<br />
To investigate the performance of OT-GAN on higher-resolution images, the dog subset of ImageNet (128*128) was used to train the model. Figure 6 shows that OT-GAN produces fewer nonsensical images and has a higher inception score compared to DC-GAN. <br />
<br />
[[File: 5_3.png|600px]]<br />
<br />
===Conditional Generation of Birds===<br />
<br />
The last experiment compared OT-GAN with three popular GAN models on text-to-image generation, demonstrating its performance on conditional image synthesis. As shown in Table 2, OT-GAN received the highest inception score of the four models. <br />
<br />
[[File: 5_4.png|600px]]<br />
<br />
The algorithm used to obtain the results above generalizes '''Algorithm 1''' to conditional generation by including conditional information <math>s</math>, such as a text description of an image. The modified algorithm is outlined in '''Algorithm 2'''.<br />
<br />
[[File: paper23_alg2.png|600px]]<br />
<br />
==Conclusion==<br />
<br />
In this paper, the OT-GAN method was proposed based on optimal transport theory. A distance metric that combines the primal form of optimal transport and the energy distance was presented for realizing OT-GAN. One of the advantages of OT-GAN over other GAN models is that it can stay on the correct track with an unbiased gradient even if training of the critic is stopped or the critic presents a weak cost signal. The performance of OT-GAN is maintained as the batch size increases, though the computational cost has to be taken into consideration.<br />
<br />
==Critique==<br />
<br />
The paper presents a variant of GANs by defining a new distance metric based on the primal form of optimal transport and the mini-batch energy distance. Its stability was demonstrated in four experiments comparing OT-GAN with other popular methods. However, limitations in computational efficiency were not discussed much. Furthermore, in section 2, the paper lacks an explanation of why mini-batches rather than individual vectors are used as input when applying the Sinkhorn distance. The explanation in section 4 of how M is chosen to minimize <math> \mathcal{W}_c </math> is also confusing. Lastly, the paper lacks a parallel comparison with existing GAN variants. Readers may feel that it jumps from one algorithm to another without the necessary explanations.<br />
<br />
==Reference==<br />
Salimans, Tim, Han Zhang, Alec Radford, and Dimitris Metaxas. "Improving GANs using optimal transport." (2018).</div>X249wanghttp://wiki.math.uwaterloo.ca/statwiki/index.php?title=stat946w18/IMPROVING_GANS_USING_OPTIMAL_TRANSPORT&diff=35011stat946w18/IMPROVING GANS USING OPTIMAL TRANSPORT2018-03-21T15:45:07Z<p>X249wang: /* Optimal Transport GAN (OT-GAN) */</p>
<hr />
<div>== Introduction ==<br />
Generative Adversarial Networks (GANs) are powerful generative models. A GAN consists of a generator and a discriminator, or critic. The generator is a neural network trained to generate data whose distribution matches that of the real data. The critic is also a neural network, trained to separate the generated data from the real data. A loss function that measures the distance between the distributions of the generated and real data is important for training the generator.<br />
<br />
Optimal transport theory evaluates the distribution distance between the generated data and the training data based on a metric, which provides another method for generator training. The main advantage of optimal transport theory over the distance measurement in GAN is its closed form solution, which gives a tractable training process. However, it might also result in inconsistent statistical estimation due to biased gradients when mini-batches are used (Bellemare et al., 2017).<br />
<br />
This paper presents a GAN variant named OT-GAN, which incorporates a discriminative metric called 'Mini-batch Energy Distance' into its critic in order to overcome the issue of biased gradients.<br />
<br />
== GANs and Optimal Transport ==<br />
<br />
===Generative Adversarial Nets===<br />
The original GAN is reviewed first. The objective function of the GAN is: <br />
<br />
[[File:equation1.png|700px]]<br />
<br />
The goal of GANs is to train the generator g and the discriminator d to find a pair (g,d) that achieves a Nash equilibrium (such that neither of them can reduce its cost without changing the other's parameters). However, training may fail to converge, since the generator and the discriminator are trained with gradient descent techniques.<br />
<br />
===Wasserstein Distance (Earth-Mover Distance)===<br />
<br />
To address the problem of convergence failure, Arjovsky et al. (2017) suggested the Wasserstein distance (Earth-Mover distance), which is based on optimal transport theory.<br />
<br />
[[File:equation2.png|600px]]<br />
<br />
where <math> \Pi (p,g) </math> is the set of all joint distributions <math> \gamma (x,y) </math> with marginals <math> p(x) </math> (real data) and <math> g(y) </math> (generated data). <math> c(x,y) </math> is a cost function; Arjovsky et al. used the Euclidean distance. <br />
<br />
The Wasserstein distance can be interpreted as the minimum cost of moving probability mass so that the generator distribution <math> g(y) </math> matches the real data distribution <math> p(x) </math>.<br />
<br />
Computing the Wasserstein distance exactly is intractable. The Wasserstein GAN (W-GAN) provides an approximate solution by recasting the optimal transport problem in its Kantorovich-Rubinstein dual formulation over the set of 1-Lipschitz functions. A neural network can then be used to obtain an estimate.<br />
<br />
[[File:equation3.png|600px]]<br />
<br />
W-GAN helps to stabilize the training process of the original GAN and can solve the optimal transport problem approximately, but the exact problem remains intractable.<br />
<br />
===Sinkhorn Distance===<br />
Genevay et al. (2017) proposed using the primal formulation of optimal transport, instead of the dual formulation, for generative modeling. They introduced the Sinkhorn distance, which is a smoothed generalization of the Wasserstein distance.<br />
[[File: equation4.png|600px]]<br />
<br />
It introduces an entropy restriction (<math> \beta </math>) on the joint distribution <math> \Pi_{\beta} (p,g) </math>. This distance can be generalized to approximate the distance between mini-batches of data <math> X ,Y</math>, each consisting of <math> K </math> vectors <math> x, y</math>. The <math> (i, j) </math>-th entry of the cost matrix <math> C </math> can be interpreted as the cost of transporting <math> x_i </math> in mini-batch <math> X </math> to <math> y_j </math> in mini-batch <math>Y </math>. The resulting distance is:<br />
<br />
[[File: equation5.png|550px]]<br />
<br />
where <math> M </math> is a <math> K \times K </math> matrix with positive entries that plays the role of the joint distribution <math> \gamma (x,y) </math>. Each row and each column of <math> M </math> sums to 1. <br />
<br />
This mini-batch Sinkhorn distance is not only fully tractable but also capable of solving the instability problem of GANs. However, the expectation of <math> \mathcal{W}_{c} </math> over mini-batches is not a valid metric over probability distributions, and the gradients are biased when the mini-batch size is fixed.<br />
<br />
===Energy Distance (Cramer Distance)===<br />
To solve the above problem, Bellemare et al. proposed the energy distance:<br />
<br />
[[File: equation6.png|700px]]<br />
<br />
where <math> x, x' </math> and <math> y, y'</math> are independent samples from the data distribution <math> p </math> and the generator distribution <math> g </math>, respectively. Based on the energy distance, Cramer GAN minimizes the energy distance metric when training the generator.<br />
<br />
==Mini-Batch Energy Distance==<br />
Salimans et al. (2016) observed that GANs become more powerful when they work with distributions over mini-batches <math> g(X), p(X) </math> rather than distributions over individual images. The distance measure is therefore defined over mini-batches.<br />
<br />
===Generalized Energy Distance===<br />
The generalized energy distance allows the use of non-Euclidean distance functions d. It is also valid for mini-batches, which is considered better than working with individual data points.<br />
<br />
[[File: equation7.png|670px]]<br />
<br />
As in the energy distance, <math> X, X' </math> and <math> Y, Y'</math> can be independent samples from the data distribution <math> p </math> and the generator distribution <math> g </math>, respectively. In the generalized energy distance, <math> X, X' </math> and <math> Y, Y'</math> can also be mini-batches. <math> D_{GED}(p,g) </math> is a metric whenever <math> d </math> is a metric. In particular, by the triangle inequality of <math> d </math>, <math> D(p,g) \geq 0</math>, and <math> D(p,g)=0 </math> when <math> p=g </math>.<br />
<br />
===Mini-Batch Energy Distance===<br />
Since <math> d </math> is free to choose, the authors proposed the Mini-batch Energy Distance, which uses the entropy-regularized Wasserstein distance as <math> d </math>. <br />
<br />
[[File: equation8.png|650px]]<br />
<br />
where <math> X, X' </math> and <math> Y, Y'</math> are independently sampled mini-batches from the data distribution <math> p </math> and the generator distribution <math> g </math>, respectively. This distance metric combines the energy distance with the primal form of optimal transport over the mini-batch distributions <math> g(Y) </math> and <math> p(X) </math>. Inside the generalized energy distance, the Sinkhorn distance is a valid metric between mini-batches. By adding the terms <math> - \mathcal{W}_c (Y,Y')</math> and <math> \mathcal{W}_c (X,Y)</math> to equation (5) and using the energy distance, the objective becomes statistically consistent and the mini-batch gradients are unbiased.<br />
<br />
==Optimal Transport GAN (OT-GAN)==<br />
<br />
In order to secure statistical efficiency (i.e. being able to tell <math>p</math> and <math>g</math> apart without requiring an enormous sample size when their distance is close to zero), the authors suggested using the cosine distance between vectors <math> v_\eta (x) </math> and <math> v_\eta (y) </math>, where <math> v_\eta </math> is a deep neural network that maps the mini-batch data to a learned latent space. Euclidean distance is not used because it performs poorly in high-dimensional spaces. The transportation cost is:<br />
<br />
[[File: euqation9.png|370px]]<br />
<br />
where <math> v_\eta </math> is chosen to maximize the resulting mini-batch energy distance.<br />
<br />
Unlike common practice with the original GANs, the generator was trained more often than the critic, which keeps the cost function from degenerating. The resulting generator in OT-GAN has a well-defined and statistically consistent objective throughout the training process.<br />
<br />
The algorithm is defined below. By the envelope theorem, gradients do not need to be backpropagated through the computation of the matchings <math> M </math>. Stochastic gradient descent is used as the optimization method. <br />
<br />
[[File: al.png|600px]]<br />
<br />
<br />
[[File: al_figure.png|600px]]<br />
<br />
==Experiments==<br />
<br />
To demonstrate the superior performance of OT-GAN, the authors compared it with the original GAN and other popular models in four experiments: dataset recovery, a CIFAR-10 test, an ImageNet test, and a conditional image synthesis test.<br />
<br />
===Mixture of Gaussian Dataset===<br />
Compared with the original GAN (DC-GAN), OT-GAN has a statistically consistent objective, so the generator should not update in a wrong direction even if the signal provided by the cost function is poor. To demonstrate this advantage, the authors compared OT-GAN with the original GAN loss (DAN-S) on a simple task: recovering all 8 modes of a mixture of 8 Gaussians whose means are arranged in a circle. MLPs with ReLU activation functions were used in this task. The critic was only updated for 15K iterations, and the generator distribution was tracked for another 25K iterations. The results showed that the original GAN experiences mode collapse after the discriminator is fixed, while OT-GAN recovers all 8 modes of the Gaussian mixture.<br />
<br />
[[File: 5_1.png|600px]]<br />
<br />
===CIFAR-10===<br />
<br />
The CIFAR-10 dataset was then used to inspect the effect of batch size on the training process and the image quality. OT-GAN and four other methods were compared using the "inception score" as the criterion. Figure 3 shows how the inception score (y-axis) changes as the number of iterations increases. Scores for four different batch sizes (200, 800, 3200 and 8000) were compared. The results show that a larger batch size leads to a more stable model with a higher inception score. However, a large batch size also requires a high-performance computational environment. The sample quality of all 5 methods is compared in Table 1, where OT-GAN has the best score.<br />
<br />
[[File: 5_2.png|600px]]<br />
<br />
===ImageNet Dogs===<br />
<br />
To investigate the performance of OT-GAN on higher-resolution images, the dog subset of ImageNet (128*128) was used to train the model. Figure 6 shows that OT-GAN produces fewer nonsensical images and has a higher inception score compared to DC-GAN. <br />
<br />
[[File: 5_3.png|600px]]<br />
<br />
===Conditional Generation of Birds===<br />
<br />
The last experiment compared OT-GAN with three popular GAN models on text-to-image generation, demonstrating its performance on conditional image synthesis. As shown in Table 2, OT-GAN received the highest inception score of the four models. <br />
<br />
[[File: 5_4.png|600px]]<br />
<br />
The algorithm used to obtain the results above generalizes '''Algorithm 1''' to conditional generation by including conditional information <math>s</math>, such as a text description of an image. The modified algorithm is outlined in '''Algorithm 2'''.<br />
<br />
[[File: paper23_alg2.png|600px]]<br />
<br />
==Conclusion==<br />
<br />
In this paper, the OT-GAN method was proposed based on optimal transport theory. A distance metric that combines the primal form of optimal transport and the energy distance was presented for realizing OT-GAN. One of the advantages of OT-GAN over other GAN models is that it can stay on the correct track with an unbiased gradient even if training of the critic is stopped or the critic presents a weak cost signal. The performance of OT-GAN is maintained as the batch size increases, though the computational cost has to be taken into consideration.<br />
<br />
==Critique==<br />
<br />
The paper presents a variant of GANs by defining a new distance metric based on the primal form of optimal transport and the mini-batch energy distance. Its stability was demonstrated in four experiments comparing OT-GAN with other popular methods. However, limitations in computational efficiency were not discussed much. Furthermore, in section 2, the paper lacks an explanation of why mini-batches rather than individual vectors are used as input when applying the Sinkhorn distance. The explanation in section 4 of how M is chosen to minimize <math> \mathcal{W}_c </math> is also confusing. Lastly, the paper lacks a parallel comparison with existing GAN variants. Readers may feel that it jumps from one algorithm to another without the necessary explanations.<br />
<br />
==Reference==<br />
Salimans, Tim, Han Zhang, Alec Radford, and Dimitris Metaxas. "Improving GANs using optimal transport." (2018).</div>X249wanghttp://wiki.math.uwaterloo.ca/statwiki/index.php?title=stat946w18/Implicit_Causal_Models_for_Genome-wide_Association_Studies&diff=34790stat946w18/Implicit Causal Models for Genome-wide Association Studies2018-03-20T17:59:31Z<p>X249wang: /* Real-data Analysis */</p>
<hr />
<div>==Introduction and Motivation==<br />
Probabilistic models have progressed to the point where they can express rich generative models. They have been extended with neural networks and implicit densities, and scalable algorithms allow Bayesian inference on very large data. However, most of these models focus on capturing statistical relationships rather than causal relationships. Causal models give us a sense of how manipulating the generative process could change the final results. <br />
<br />
Genome-wide association studies (GWAS) are an example of a causal problem. Specifically, GWAS aim to figure out how genetic factors cause disease among humans. The genetic factors here are single nucleotide polymorphisms (SNPs), and having a particular disease is treated as a trait, i.e., the outcome. To understand why a disease develops and how to cure it, the causation between SNPs and diseases is of interest: first, predict which SNP or SNPs cause the disease; second, target the selected SNPs to cure the disease. <br />
<br />
The figure below depicts an example Manhattan plot for a GWAS. Each dot represents a SNP. x-axis is the chromosome location, and y-axis is the negative log of the association p-value between the SNP and the disease, so points with the largest values represent strongly associated risk loci.<br />
<br />
[[File:gwas-example.jpg|500px|center]]<br />
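<br />
The height of a point in a Manhattan plot can be computed directly from the association p-value. The sketch below assumes a base-10 logarithm (the text above only says "negative log"), and the SNP names and p-values are hypothetical.<br />

```python
import math


def manhattan_height(p_value: float) -> float:
    """Return -log10(p), the y-coordinate of a SNP in a Manhattan plot."""
    if not 0.0 < p_value <= 1.0:
        raise ValueError("p-value must lie in (0, 1]")
    return -math.log10(p_value)


# Hypothetical association p-values for three SNPs.
p_values = {"snp_a": 0.5, "snp_b": 1e-3, "snp_c": 5e-8}
heights = {name: manhattan_height(p) for name, p in p_values.items()}

# Smaller p-values give taller points, i.e. stronger association.
ranked = sorted(heights, key=heights.get, reverse=True)
```

Here <code>ranked</code> lists the SNPs from the most to the least strongly associated.<br />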
<br />
This paper deals with two questions. The first is how to build rich causal models that meet the specific needs of GWAS. In general, probabilistic causal models involve a function <math>f</math> and a noise variable <math>n</math>. For simplicity, <math>f</math> is usually assumed to be a linear model with Gaussian noise. However, evidence shows that in GWAS it is necessary to accommodate nonlinearity and interactions between multiple genes in the model.<br />
<br />
The second contribution of this paper is that it addresses the problem caused by latent confounders. Latent confounders are an issue when applying causal models because we can neither observe them nor know their underlying structure. In this paper, the authors develop implicit causal models that can adjust for confounders.<br />
<br />
There is a growing body of work on causal models that focuses on causal discovery and typically makes strong assumptions, such as Gaussian processes on the noise variable or nonlinearities for the main function.<br />
<br />
==Implicit Causal Models==<br />
Implicit causal models are an extension of probabilistic causal models. Probabilistic causal models will be introduced first.<br />
<br />
=== Probabilistic Causal Models ===<br />
Probabilistic causal models have two parts: deterministic functions of noise and other variables. Consider a global variable <math>\beta</math> and noise <math>\epsilon</math>, where<br />
<br />
[[File: eq1.1.png|800px|center]]<br />
<br />
Each <math>\beta</math> and <math>x</math> is a function of noise; <math>y</math> is a function of noise and <math>x</math>,<br />
<br />
[[File: eqt1.png|800px|center]]<br />
<br />
The target is the causal mechanism <math>f_y</math>, from which the causal effect <math>p(y|do(X=x),\beta)</math> can be calculated. <math>do(X=x)</math> means that we set <math>X</math> to a specified value under the fixed structure <math>\beta</math>. Following earlier work, it is assumed that <math>p(y|do(x),\beta) = p(y|x, \beta)</math>.<br />
<br />
[[File: f_1.png|650px|center|]]<br />
<br />
<br />
An example of a probabilistic causal model is the additive noise model. <br />
<br />
[[File: eq2.1.png|800px|center]]<br />
<br />
<math>f(\cdot)</math> is usually a linear function, or a spline function to capture nonlinearities. <math>\epsilon</math> is assumed to be standard normal, as is <math>y</math>. Thus the posterior <math>p(\theta | x, y, \beta)</math> can be represented as <br />
<br />
[[File: eqt2.png|800px|center]]<br />
<br />
where <math>p(\theta)</math> is the prior which is known. Then, variational inference or MCMC can be applied to calculate the posterior distribution. <br />
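<br />
As a rough illustration of the additive noise model, the sketch below draws data from a linear <math>f</math> with standard normal noise and recovers the parameters by least squares. The function name, the effect sizes, and the use of least squares in place of a full posterior are assumptions made for this example, not the paper's procedure.<br />

```python
import numpy as np


def sample_additive_noise_model(x, theta, rng):
    """Draw y = f(x; theta) + eps with a linear f and standard normal eps."""
    eps = rng.standard_normal(x.shape[0])
    return x @ theta + eps


rng = np.random.default_rng(0)
# Hypothetical inputs: 500 individuals, 4 SNPs coded as counts in {0, 1, 2}.
x = rng.integers(0, 3, size=(500, 4)).astype(float)
theta = np.array([0.5, 0.0, -1.0, 0.0])  # hypothetical effect sizes
y = sample_additive_noise_model(x, theta, rng)

# With a flat prior and Gaussian noise, the posterior mode of theta
# coincides with the ordinary least-squares estimate.
theta_hat, *_ = np.linalg.lstsq(x, y, rcond=None)
```

Because the noise enters only as an additive term, the estimate <code>theta_hat</code> lands close to <code>theta</code>; the implicit models in the next section drop exactly this restriction.<br />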
<br />
<br />
===Implicit Causal Models===<br />
The difference between implicit causal models and probabilistic causal models is the noise variable. Instead of an additive noise term, implicit causal models directly take noise <math>\epsilon</math> into a neural network and output <math>x</math>.<br />
<br />
The causal diagram has changed to:<br />
<br />
[[File: f_2.png|650px|center|]]<br />
<br />
<br />
They used a fully connected neural network with a fair number of hidden units to approximate each causal mechanism. Below is the formal description: <br />
<br />
[[File: theorem.png|650px|center|]]<br />
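<br />
To make the contrast with additive noise concrete, the following sketch feeds the noise into a small fully connected network together with the parent variables. The layer sizes, weight scale, and function names are hypothetical and chosen only for illustration.<br />

```python
import numpy as np


def mlp(inputs, weights, biases):
    """Fully connected network with ReLU hidden layers and a linear output."""
    h = inputs
    for w, b in zip(weights[:-1], biases[:-1]):
        h = np.maximum(h @ w + b, 0.0)
    return h @ weights[-1] + biases[-1]


def sample_implicit(parents, weights, biases, rng):
    """Draw g([parents, eps]; theta): noise is an input, not an additive term."""
    eps = rng.standard_normal((parents.shape[0], 1))
    return mlp(np.concatenate([parents, eps], axis=1), weights, biases).ravel()


rng = np.random.default_rng(0)
sizes = [4 + 1, 16, 16, 1]  # hypothetical: 4 parents plus 1 noise input
weights = [rng.standard_normal((m, n)) * 0.5 for m, n in zip(sizes[:-1], sizes[1:])]
biases = [np.zeros(n) for n in sizes[1:]]

parents = rng.standard_normal((8, 4))
draw_1 = sample_implicit(parents, weights, biases, rng)
draw_2 = sample_implicit(parents, weights, biases, rng)
```

Two draws with the same parents differ because fresh noise passes through the nonlinearity each time, so the model defines a distribution without an explicit density.<br />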
<br />
<br />
==Implicit Causal Models with Latent Confounders==<br />
Previously, the global structure was assumed to be observed. Next, the unobserved case is considered.<br />
<br />
===Causal Inference with a Latent Confounder===<br />
As before, the quantity of interest is the causal effect <math>p(y|do(x_m), x_{-m})</math>. Here, the SNPs other than <math>x_m</math> are also taken into consideration. However, the effect is confounded by the unobserved confounder <math>z_n</math>. As a result, the standard inference method cannot be used in this case.<br />
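<br />
A toy linear simulation shows why an unadjusted analysis fails under confounding. Everything here (the linear mechanisms, the coefficients, and the sample size) is a hypothetical setup for illustration and not the paper's model.<br />

```python
import numpy as np

rng = np.random.default_rng(0)
n = 20000

z = rng.standard_normal(n)               # confounder (unobserved in practice)
x = z + rng.standard_normal(n)           # cause, driven partly by z
y = 0.0 * x + 2.0 * z + rng.standard_normal(n)  # true causal effect of x is 0

# Naive regression of y on x alone picks up the path through z.
naive_slope = np.polyfit(x, y, 1)[0]

# Adjusting for z recovers the true (zero) effect of x.
design = np.column_stack([x, z])
adjusted_slope = np.linalg.lstsq(design, y, rcond=None)[0][0]
```

The naive slope comes out near 1 even though <code>x</code> has no effect on <code>y</code>, while the adjusted slope is near 0. In GWAS the confounder is not observed, which is why it has to be modeled as a latent variable.<br />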
<br />
The paper proposes a new method that includes the latent confounders. For each subject <math>n=1,…,N</math> and each SNP <math>m=1,…,M</math>,<br />
<br />
[[File: eqt4.png|800px|center]]<br />
<br />
<br />
The mechanism for latent confounder <math>z_n</math> is assumed to be known. SNPs depend on the confounders and the trait depends on all the SNPs and the confounders as well. <br />
<br />
The posterior of <math>\theta</math> must be calculated in order to estimate the mechanism <math>g_y</math> as well as the causal effect <math>p(y|do(x_m), x_{-m})</math>, which explains how changes to each SNP <math>X_m</math> cause changes to the trait <math>Y</math>.<br />
<br />
[[File: eqt5.png|800px|center]]<br />
<br />
Note that the latent structure <math>p(z|x, y)</math> is assumed known.<br />
<br />
<br />
===Implicit Causal Model with a Latent Confounder===<br />
This section describes the algorithm and functions used to implement an implicit causal model for GWAS.<br />
<br />
====Generative Process of Confounders <math>z_n</math>.====<br />
The distribution of the confounders is set to be standard normal. <math>z_n \in R^K</math>, where <math>K</math> is the dimension of <math>z_n</math>; <math>K</math> should be chosen so that the latent space is as close as possible to the true population structure. <br />
<br />
====Generative Process of SNPs <math>x_{nm}</math>.====<br />
Each SNP is coded as<br />
<br />
[[File: SNP.png|300px|center]]<br />
<br />
The authors defined a <math>Binomial(2,\pi_{nm})</math> distribution on <math>x_{nm}</math> and used logistic factor analysis to design the SNP matrix.<br />
<br />
[[File: gpx.png|800px|center]]<br />
<br />
A SNP matrix looks like this:<br />
[[File: SNP_matrix.png|200px|center]]<br />
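<br />
The binomial model for SNPs can be simulated as follows. The sketch assumes the usual logistic factor analysis form <math>\text{logit}(\pi_{nm}) = z_n^\top w_m</math>; the equation itself appears only as an image above, so this form and the chosen dimensions are assumptions.<br />

```python
import numpy as np


def sample_snps_lfa(z, w, rng):
    """Sample x[n, m] ~ Binomial(2, pi[n, m]) with logit(pi[n, m]) = z[n] . w[m]."""
    logits = z @ w.T                      # shape (N, M)
    pi = 1.0 / (1.0 + np.exp(-logits))    # allele frequency per individual and SNP
    return rng.binomial(2, pi), pi


rng = np.random.default_rng(0)
N, M, K = 50, 200, 3                      # hypothetical sizes
z = rng.standard_normal((N, K))           # confounders, standard normal
w = rng.standard_normal((M, K))           # per-SNP latent variables
x, pi = sample_snps_lfa(z, w, rng)
```

Each entry of <code>x</code> is 0, 1, or 2, matching the SNP coding above, and individuals with similar <code>z</code> have similar allele frequencies across SNPs.<br />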
<br />
<br />
Since logistic factor analysis makes strong assumptions, this paper suggests using a neural network to relax these assumptions,<br />
<br />
[[File: gpxnn.png|800px|center]]<br />
<br />
This makes the output a full <math>N \times M</math> matrix because of the variables <math>w_m</math>, which act like principal components in PCA. Here, <math>\phi</math> has a standard normal prior distribution. The weights <math>w</math> and biases <math>\phi</math> are shared over the <math>M</math> SNPs and <math>N</math> individuals, which makes it possible to learn nonlinear interactions between <math>z_n</math> and <math>w_m</math>.<br />
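<br />
A minimal sketch of the neural-network version is given below, assuming the network takes the concatenation of <math>z_n</math> and <math>w_m</math> and that a sigmoid maps its output to <math>\pi_{nm}</math>. A single hidden layer is used for brevity, and the parameter names in <code>phi</code> are hypothetical.<br />

```python
import numpy as np


def snp_probabilities(z, w, phi):
    """Return pi[n, m] = sigmoid(NN([z_n, w_m]; phi)) for every (n, m) pair."""
    n_ind, n_snp = z.shape[0], w.shape[0]
    # Pair every individual's confounder with every SNP's latent variable.
    z_rep = np.broadcast_to(z[:, None, :], (n_ind, n_snp, z.shape[1]))
    w_rep = np.broadcast_to(w[None, :, :], (n_ind, n_snp, w.shape[1]))
    pairs = np.concatenate([z_rep, w_rep], axis=-1)
    # The same parameters phi are applied to all pairs (shared network).
    hidden = np.maximum(pairs @ phi["W1"] + phi["b1"], 0.0)
    logits = (hidden @ phi["W2"] + phi["b2"])[..., 0]
    return 1.0 / (1.0 + np.exp(-logits))


rng = np.random.default_rng(0)
N, M, K, H = 5, 7, 3, 8                   # hypothetical sizes
z = rng.standard_normal((N, K))
w = rng.standard_normal((M, K))
phi = {
    "W1": rng.standard_normal((2 * K, H)), "b1": np.zeros(H),
    "W2": rng.standard_normal((H, 1)), "b2": np.zeros(1),
}
pi = snp_probabilities(z, w, phi)
x = rng.binomial(2, pi)                   # full N x M SNP matrix
```

Sharing <code>phi</code> across all pairs is what lets the network learn interactions between <math>z_n</math> and <math>w_m</math> that a plain inner product cannot express.<br />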
<br />
====Generative Process of Traits <math>y_n</math>.====<br />
Previously, each trait is modeled by a linear regression,<br />
<br />
[[File: gpy.png|800px|center]]<br />
<br />
This also makes very strong assumptions about SNPs, interactions, and additive noise. It can likewise be replaced by a neural network that outputs a single scalar,<br />
<br />
[[File: gpynn.png|800px|center]]<br />
<br />
<br />
==Likelihood-free Variational Inference==<br />
Calculating the posterior of <math>\theta</math> is the key to applying the implicit causal model with latent confounders.<br />
<br />
[[File: eqt5.png|800px|center]]<br />
<br />
can be reduced to <br />
<br />
[[File: lfvi1.png|800px|center]]<br />
<br />
However, with implicit models, integrating over a nonlinear function can be intractable. The authors therefore applied likelihood-free variational inference (LFVI). LFVI posits a family of distributions over the latent variables. Here the variables <math>w_m</math> and <math>z_n</math> are all assumed to be Normal,<br />
<br />
[[File: lfvi2.png|700px|center]]<br />
<br />
For LFVI applied to GWAS, an algorithm similar to the EM algorithm is used:<br />
[[File: em.png|800px|center]]<br />
<br />
==Empirical Study==<br />
The authors performed simulation on 100,000 SNPs, 940 to 5,000 individuals, and across 100 replications of 11 settings. <br />
Four methods were compared: <br />
<br />
* implicit causal model (ICM);<br />
* PCA with linear regression (PCA); <br />
* a linear mixed model (LMM); <br />
* logistic factor analysis with inverse regression (GCAT).<br />
<br />
The feedforward neural networks for traits and SNPs are fully connected with two hidden layers using ReLU activation function, and batch normalization. <br />
<br />
===Simulation Study===<br />
Based on real genomic data, a true model is applied to generate the SNPs and traits for each configuration. <br />
There are four datasets used in this simulation study: <br />
<br />
# HapMap [Balding-Nichols model]<br />
# 1000 Genomes Project (TGP) [PCA]<br />
#* Human Genome Diversity project (HGDP) [PCA]<br />
#* HGDP [Pritchard-Stephens-Donnelly model] <br />
# A latent spatial position of individuals for population structure [spatial]<br />
<br />
<br />
The table shows the prediction accuracy, calculated as the number of true positives divided by the number of true positives plus false positives (i.e., precision). True positives are SNPs that are correctly identified as having a causal relation with the trait. In contrast, false positives are SNPs reported as having a causal relation with the trait when they do not. The closer the rate is to 1, the better the model, since false positives are wrong predictions.<br />
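<br />
The rate described above can be written as a short function. The SNP identifiers are hypothetical and serve only to show the arithmetic.<br />

```python
def precision(predicted, causal):
    """Return TP / (TP + FP) for a set of predicted causal SNPs."""
    predicted, causal = set(predicted), set(causal)
    true_positives = len(predicted & causal)
    false_positives = len(predicted - causal)
    if true_positives + false_positives == 0:
        raise ValueError("no SNPs were predicted as causal")
    return true_positives / (true_positives + false_positives)


# Hypothetical example: 4 SNPs flagged by a method, 3 of them truly causal.
flagged = {"rs1", "rs2", "rs3", "rs4"}
truly_causal = {"rs1", "rs2", "rs3", "rs9"}
rate = precision(flagged, truly_causal)   # 3 / (3 + 1)
```

Note that a causal SNP the method misses (<code>rs9</code> here) does not lower this rate; only false positives do.<br />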
<br />
[[File: table_1.png|650px|center|]]<br />
<br />
The results above show that the implicit causal model has the best performance among the four models in every situation. In particular, the other models tend to do poorly on PSD and Spatial when <math>a</math> is small, whereas ICM achieves a substantially higher rate. The only method comparable to ICM is GCAT, and only on the simpler configurations.<br />
<br />
<br />
===Real-data Analysis===<br />
They also applied ICM to a real-world GWAS of the Northern Finland Birth Cohorts, which contains 324,160 SNPs and 5,027 individuals. Ten implicit causal models were fitted, and two neural networks, each with two hidden layers, were used for the SNPs and the trait. The dimension of the confounders (<math>K</math>) was set to six, the same as in the paper by Song et al., for comparability with the models in Table 2.<br />
<br />
[[File: table_2.png|650px|center|]]<br />
<br />
The numbers in the above table are the numbers of significant loci for each of the 10 traits. The numbers for the other methods, namely GCAT, LMM, PCA, and "uncorrected" (association tests without accounting for hidden relatedness of study samples), are taken from other papers. By comparison, ICM reached the level of the best previous model for each trait.<br />
<br />
==Conclusion==<br />
This paper introduced implicit causal models to account for complex nonlinear causal relationships, and applied the method to GWAS. The model can capture important interactions between genes within an individual and at the population level, and it can adjust for latent confounders by including latent variables in the model.<br />
<br />
In the simulation study, the authors showed that the implicit causal model beats other methods by 15-45.3% on a variety of datasets and parameter settings.<br />
<br />
The authors also believed this GWAS application is only a start of the usage of implicit causal models. It might also be used in physics or economics. <br />
<br />
==Critique==<br />
I think this paper is an interesting and novel work. The main contribution of this paper is to connect the statistical genetics and the machine learning methodology. The method is technically sound and does indeed generalize techniques currently used in statistical genetics.<br />
<br />
The neural network used in this paper is a very simple feedforward network with two hidden layers, but the idea of where to use the neural network is crucial and might be significant in GWAS.<br />
<br />
It has limitations as well. The empirical example in this paper is too easy and far from a realistic situation. Although the simulation study showed competitive results, the Northern Finland Birth Cohort application did not demonstrate that the implicit causal model is better than previous methods such as GCAT or LMM.<br />
<br />
Another limitation concerns linkage disequilibrium, as the authors also state. SNPs are not completely independent of each other; they are usually correlated when the alleles are at nearby loci. The authors did not consider this more complex case and only considered the simplest case, in which all SNPs are assumed to be independent.<br />
<br />
Furthermore, a single SNP may not have enough power to explain the causal relationship. Recent papers indicate that causation of a trait may involve multiple SNPs.<br />
This could be future work as well.<br />
<br />
==References==<br />
Tran D, Blei D M. Implicit Causal Models for Genome-wide Association Studies[J]. arXiv preprint arXiv:1710.10742, 2017.<br />
<br />
Patrik O Hoyer, Dominik Janzing, Joris M Mooij, Jonas Peters, and Bernhard Schölkopf. Nonlinear causal discovery with additive noise models. In Neural Information Processing Systems, 2009.<br />
<br />
Alkes L Price, Nick J Patterson, Robert M Plenge, Michael E Weinblatt, Nancy A Shadick, and David Reich. Principal components analysis corrects for stratification in genome-wide association studies. Nature Genetics, 38(8):904–909, 2006.<br />
<br />
Minsun Song, Wei Hao, and John D Storey. Testing for genetic associations in arbitrarily structured populations. Nature Genetics, 47(5):550–554, 2015.<br />
<br />
Dustin Tran, Rajesh Ranganath, and David M Blei. Hierarchical implicit models and likelihood-free variational inference. In Neural Information Processing Systems, 2017.</div>X249wanghttp://wiki.math.uwaterloo.ca/statwiki/index.php?title=stat946w18/Implicit_Causal_Models_for_Genome-wide_Association_Studies&diff=34781stat946w18/Implicit Causal Models for Genome-wide Association Studies2018-03-20T17:30:59Z<p>X249wang: /* Generative Process of SNPs x_{nm}. */</p>
<hr />
<div>==Introduction and Motivation==<br />
Probabilistic models have progressed to the point where they can express rich generative models. They have been extended with neural networks and implicit densities, and scalable algorithms allow Bayesian inference on very large data. However, most of these models focus on capturing statistical relationships rather than causal relationships. Causal models give us a sense of how manipulating the generative process could change the final results. <br />
<br />
Genome-wide association studies (GWAS) are an example of a causal problem. Specifically, GWAS aim to figure out how genetic factors cause disease among humans. The genetic factors here are single nucleotide polymorphisms (SNPs), and having a particular disease is treated as a trait, i.e., the outcome. To understand why a disease develops and how to cure it, the causation between SNPs and diseases is of interest: first, predict which SNP or SNPs cause the disease; second, target the selected SNPs to cure the disease. <br />
<br />
The figure below depicts an example Manhattan plot for a GWAS. Each dot represents a SNP. x-axis is the chromosome location, and y-axis is the negative log of the association p-value between the SNP and the disease, so points with the largest values represent strongly associated risk loci.<br />
<br />
[[File:gwas-example.jpg|500px|center]]<br />
<br />
This paper deals with two questions. The first is how to build rich causal models that meet the specific needs of GWAS. In general, probabilistic causal models involve a function <math>f</math> and a noise variable <math>n</math>. For simplicity, <math>f</math> is usually assumed to be a linear model with Gaussian noise. However, evidence shows that in GWAS it is necessary to accommodate nonlinearity and interactions between multiple genes in the model.<br />
<br />
The second contribution of this paper is that it addresses the problem caused by latent confounders. Latent confounders are an issue when applying causal models because we can neither observe them nor know their underlying structure. In this paper, the authors develop implicit causal models that can adjust for confounders.<br />
<br />
There is a growing body of work on causal models that focuses on causal discovery and typically makes strong assumptions, such as Gaussian processes on the noise variable or nonlinearities for the main function.<br />
<br />
==Implicit Causal Models==<br />
Implicit causal models are an extension of probabilistic causal models. Probabilistic causal models will be introduced first.<br />
<br />
=== Probabilistic Causal Models ===<br />
Probabilistic causal models have two parts: deterministic functions of noise and other variables. Consider a global variable <math>\beta</math> and noise <math>\epsilon</math>, where<br />
<br />
[[File: eq1.1.png|800px|center]]<br />
<br />
Each <math>\beta</math> and <math>x</math> is a function of noise; <math>y</math> is a function of noise and <math>x</math>,<br />
<br />
[[File: eqt1.png|800px|center]]<br />
<br />
The target is the causal mechanism <math>f_y</math>, from which the causal effect <math>p(y|do(X=x),\beta)</math> can be calculated. <math>do(X=x)</math> means that we set <math>X</math> to a specified value under the fixed structure <math>\beta</math>. Following earlier work, it is assumed that <math>p(y|do(x),\beta) = p(y|x, \beta)</math>.<br />
<br />
[[File: f_1.png|650px|center|]]<br />
<br />
<br />
An example of a probabilistic causal model is the additive noise model. <br />
<br />
[[File: eq2.1.png|800px|center]]<br />
<br />
<math>f(\cdot)</math> is usually a linear function, or a spline function to capture nonlinearities. <math>\epsilon</math> is assumed to be standard normal, as is <math>y</math>. Thus the posterior <math>p(\theta | x, y, \beta)</math> can be represented as <br />
<br />
[[File: eqt2.png|800px|center]]<br />
<br />
where <math>p(\theta)</math> is the prior which is known. Then, variational inference or MCMC can be applied to calculate the posterior distribution. <br />
<br />
<br />
===Implicit Causal Models===<br />
The difference between implicit causal models and probabilistic causal models is the noise variable. Instead of an additive noise term, implicit causal models directly take noise <math>\epsilon</math> into a neural network and output <math>x</math>.<br />
<br />
The causal diagram has changed to:<br />
<br />
[[File: f_2.png|650px|center|]]<br />
<br />
<br />
They used a fully connected neural network with a fair number of hidden units to approximate each causal mechanism. Below is the formal description: <br />
<br />
[[File: theorem.png|650px|center|]]<br />
<br />
<br />
==Implicit Causal Models with Latent Confounders==<br />
Previously, the global structure was assumed to be observed. Next, the unobserved case is considered.<br />
<br />
===Causal Inference with a Latent Confounder===<br />
As before, the quantity of interest is the causal effect <math>p(y|do(x_m), x_{-m})</math>. Here, the SNPs other than <math>x_m</math> are also taken into consideration. However, the effect is confounded by the unobserved confounder <math>z_n</math>. As a result, the standard inference method cannot be used in this case.<br />
<br />
The paper proposes a new method that includes the latent confounders. For each subject <math>n=1,…,N</math> and each SNP <math>m=1,…,M</math>,<br />
<br />
[[File: eqt4.png|800px|center]]<br />
<br />
<br />
The mechanism for latent confounder <math>z_n</math> is assumed to be known. SNPs depend on the confounders and the trait depends on all the SNPs and the confounders as well. <br />
<br />
The posterior of <math>\theta</math> must be calculated in order to estimate the mechanism <math>g_y</math> as well as the causal effect <math>p(y|do(x_m), x_{-m})</math>, which explains how changes to each SNP <math>X_m</math> cause changes to the trait <math>Y</math>.<br />
<br />
[[File: eqt5.png|800px|center]]<br />
<br />
Note that the latent structure <math>p(z|x, y)</math> is assumed known.<br />
<br />
<br />
===Implicit Causal Model with a Latent Confounder===<br />
This section describes the algorithm and functions used to implement an implicit causal model for GWAS.<br />
<br />
====Generative Process of Confounders <math>z_n</math>.====<br />
The distribution of the confounders is set to be standard normal. <math>z_n \in R^K</math>, where <math>K</math> is the dimension of <math>z_n</math>; <math>K</math> should be chosen so that the latent space is as close as possible to the true population structure. <br />
<br />
====Generative Process of SNPs <math>x_{nm}</math>.====<br />
Each SNP is coded as<br />
<br />
[[File: SNP.png|300px|center]]<br />
<br />
The authors defined a <math>Binomial(2,\pi_{nm})</math> distribution on <math>x_{nm}</math> and used logistic factor analysis to design the SNP matrix.<br />
<br />
[[File: gpx.png|800px|center]]<br />
<br />
A SNP matrix looks like this:<br />
[[File: SNP_matrix.png|200px|center]]<br />
<br />
<br />
Since logistic factor analysis makes strong assumptions, this paper suggests using a neural network to relax these assumptions,<br />
<br />
[[File: gpxnn.png|800px|center]]<br />
<br />
This makes the output a full <math>N \times M</math> matrix because of the variables <math>w_m</math>, which act like principal components in PCA. Here, <math>\phi</math> has a standard normal prior distribution. The weights <math>w</math> and biases <math>\phi</math> are shared over the <math>M</math> SNPs and <math>N</math> individuals, which makes it possible to learn nonlinear interactions between <math>z_n</math> and <math>w_m</math>.<br />
<br />
====Generative Process of Traits <math>y_n</math>.====<br />
Previously, each trait is modeled by a linear regression,<br />
<br />
[[File: gpy.png|800px|center]]<br />
<br />
This also makes very strong assumptions about SNPs, interactions, and additive noise. It can likewise be replaced by a neural network that outputs a single scalar,<br />
<br />
[[File: gpynn.png|800px|center]]<br />
<br />
<br />
==Likelihood-free Variational Inference==<br />
Calculating the posterior of <math>\theta</math> is the key to applying the implicit causal model with latent confounders.<br />
<br />
[[File: eqt5.png|800px|center]]<br />
<br />
can be reduced to <br />
<br />
[[File: lfvi1.png|800px|center]]<br />
<br />
However, with implicit models, integrating over a nonlinear function can be intractable. The authors therefore applied likelihood-free variational inference (LFVI). LFVI posits a family of distributions over the latent variables. Here the variables <math>w_m</math> and <math>z_n</math> are all assumed to be Normal,<br />
<br />
[[File: lfvi2.png|700px|center]]<br />
<br />
For LFVI applied to GWAS, an algorithm similar to the EM algorithm is used:<br />
[[File: em.png|800px|center]]<br />
<br />
==Empirical Study==<br />
The authors performed simulation on 100,000 SNPs, 940 to 5,000 individuals, and across 100 replications of 11 settings. <br />
Four methods were compared: <br />
<br />
* implicit causal model (ICM);<br />
* PCA with linear regression (PCA); <br />
* a linear mixed model (LMM); <br />
* logistic factor analysis with inverse regression (GCAT).<br />
<br />
The feedforward neural networks for traits and SNPs are fully connected with two hidden layers using ReLU activation function, and batch normalization. <br />
<br />
===Simulation Study===<br />
Based on real genomic data, a true model is applied to generate the SNPs and traits for each configuration. <br />
There are four datasets used in this simulation study: <br />
<br />
# HapMap [Balding-Nichols model]<br />
# 1000 Genomes Project (TGP) [PCA]<br />
#* Human Genome Diversity project (HGDP) [PCA]<br />
#* HGDP [Pritchard-Stephens-Donnelly model] <br />
# A latent spatial position of individuals for population structure [spatial]<br />
<br />
<br />
The table shows the prediction accuracy, calculated as the number of true positives divided by the number of true positives plus false positives (i.e., precision). True positives are SNPs that are correctly identified as having a causal relation with the trait. In contrast, false positives are SNPs reported as having a causal relation with the trait when they do not. The closer the rate is to 1, the better the model, since false positives are wrong predictions.<br />
<br />
[[File: table_1.png|650px|center|]]<br />
<br />
The results above show that the implicit causal model has the best performance among the four models in every situation. In particular, the other models tend to do poorly on PSD and Spatial when <math>a</math> is small, whereas ICM achieves a substantially higher rate. The only method comparable to ICM is GCAT, and only on the simpler configurations.<br />
<br />
<br />
===Real-data Analysis===<br />
They also applied ICM to a real-world GWAS of the Northern Finland Birth Cohorts, which contains 324,160 SNPs and 5,027 individuals. Ten implicit causal models were fitted, and two neural networks, each with two hidden layers, were used for the SNPs and the trait. <br />
<br />
[[File: table_2.png|650px|center|]]<br />
<br />
The numbers in the above table are the numbers of significant loci for each of the 10 traits. The numbers for the other methods, namely GCAT, LMM, PCA, and "uncorrected", are taken from other papers. By comparison, ICM reached the level of the best previous model for each trait. <br />
<br />
==Conclusion==<br />
This paper introduced implicit causal models to account for complex nonlinear causal relationships, and applied the method to GWAS. The model can capture important interactions between genes within an individual and at the population level, and it can adjust for latent confounders by including latent variables in the model.<br />
<br />
In the simulation study, the authors showed that the implicit causal model beats other methods by 15-45.3% on a variety of datasets and parameter settings.<br />
<br />
The authors also believed this GWAS application is only a start of the usage of implicit causal models. It might also be used in physics or economics. <br />
<br />
==Critique==<br />
I think this paper is an interesting and novel work. The main contribution of this paper is to connect the statistical genetics and the machine learning methodology. The method is technically sound and does indeed generalize techniques currently used in statistical genetics.<br />
<br />
The neural network used in this paper is a very simple feedforward network with two hidden layers, but the idea of where to use the neural network is crucial and might be significant in GWAS.<br />
<br />
It has limitations as well. The empirical example in this paper is too easy and far from a realistic situation. Although the simulation study showed competitive results, the Northern Finland Birth Cohort application did not demonstrate that the implicit causal model is better than previous methods such as GCAT or LMM.<br />
<br />
Another limitation concerns linkage disequilibrium, as the authors also state. SNPs are not completely independent of each other; they are usually correlated when the alleles are at nearby loci. The authors did not consider this more complex case and only considered the simplest case, in which all SNPs are assumed to be independent.<br />
<br />
Furthermore, a single SNP may not have enough power to explain the causal relationship. Recent papers indicate that causation of a trait may involve multiple SNPs.<br />
This could be future work as well.<br />
<br />
==References==<br />
Tran D, Blei D M. Implicit Causal Models for Genome-wide Association Studies[J]. arXiv preprint arXiv:1710.10742, 2017.<br />
<br />
Patrik O Hoyer, Dominik Janzing, Joris M Mooij, Jonas Peters, and Bernhard Schölkopf. Nonlinear causal discovery with additive noise models. In Neural Information Processing Systems, 2009.<br />
<br />
Alkes L Price, Nick J Patterson, Robert M Plenge, Michael E Weinblatt, Nancy A Shadick, and David Reich. Principal components analysis corrects for stratification in genome-wide association studies. Nature Genetics, 38(8):904–909, 2006.<br />
<br />
Minsun Song, Wei Hao, and John D Storey. Testing for genetic associations in arbitrarily structured populations. Nature Genetics, 47(5):550–554, 2015.<br />
<br />
Dustin Tran, Rajesh Ranganath, and David M Blei. Hierarchical implicit models and likelihood-free variational inference. In Neural Information Processing Systems, 2017.</div>X249wanghttp://wiki.math.uwaterloo.ca/statwiki/index.php?title=stat946w18/Implicit_Causal_Models_for_Genome-wide_Association_Studies&diff=34774stat946w18/Implicit Causal Models for Genome-wide Association Studies2018-03-20T16:09:35Z<p>X249wang: /* Explain figure */</p>
<hr />
<div>==Introduction and Motivation==<br />
Probabilistic models have progressed to the point where they can express rich generative models. They have been extended with neural networks and implicit densities, and scalable algorithms allow Bayesian inference on very large data. However, most of these models focus on capturing statistical relationships rather than causal relationships. Causal models give us a sense of how manipulating the generative process could change the final results. <br />
<br />
Genome-wide association studies (GWAS) are an example of a causal problem. Specifically, GWAS aim to figure out how genetic factors cause disease among humans. The genetic factors here are single nucleotide polymorphisms (SNPs), and having a particular disease is treated as a trait, i.e., the outcome. To understand why a disease develops and how to cure it, the causation between SNPs and diseases is of interest: first, predict which SNP or SNPs cause the disease; second, target the selected SNPs to cure the disease. <br />
<br />
The figure below depicts an example Manhattan plot for a GWAS. Each dot represents a SNP. x-axis is the chromosome location, and y-axis is the negative log of the association p-value between the SNP and the disease, so points with the largest values represent strongly associated risk loci.<br />
<br />
[[File:gwas-example.jpg|500px|center]]<br />
<br />
This paper deals with two questions. The first is how to build rich causal models that meet the specific needs of GWAS. In general, probabilistic causal models involve a function <math>f</math> and a noise <math>n</math>. For simplicity, <math>f</math> is usually assumed to be a linear model with Gaussian noise. However, there is evidence that models for GWAS need to accommodate nonlinearity and interactions between multiple genes.<br />
<br />
The second contribution of this paper is that it addresses the problem caused by latent confounders. Latent confounders are an issue when applying causal models because we can neither observe them nor know their underlying structure. The authors develop implicit causal models that can adjust for confounders.<br />
<br />
There is a growing body of work on causal models that focuses on causal discovery and typically makes strong assumptions, such as Gaussian processes on the noise variable or nonlinearities for the main function.<br />
<br />
==Implicit Causal Models==<br />
Implicit causal models are an extension of probabilistic causal models. Probabilistic causal models will be introduced first.<br />
<br />
=== Probabilistic Causal Models ===<br />
Probabilistic causal models have two parts: deterministic functions of noise and other variables. Consider a global variable <math>\beta</math> and noise <math>\epsilon</math>, where<br />
<br />
[[File: eq1.1.png|800px|center]]<br />
<br />
Each <math>\beta</math> and <math>x</math> is a function of noise; <math>y</math> is a function of noise and <math>x</math>:<br />
<br />
[[File: eqt1.png|800px|center]]<br />
<br />
The target is the causal mechanism <math>f_y</math>, from which the causal effect <math>p(y|do(X=x),\beta)</math> can be calculated. <math>do(X=x)</math> means that we set <math>X</math> to a specified value under the fixed structure <math>\beta</math>. Following prior work, it is assumed that <math>p(y|do(x),\beta) = p(y|x, \beta)</math>.<br />
<br />
[[File: f_1.png|650px|center|]]<br />
<br />
<br />
An example of a probabilistic causal model is the additive noise model.<br />
<br />
[[File: eq2.1.png|800px|center]]<br />
<br />
<math>f(.)</math> is usually a linear function, or a spline function when nonlinearities are needed. <math>\epsilon</math> is assumed to be standard normal, as well as <math>y</math>. Thus the posterior <math>p(\theta | x, y, \beta)</math> can be represented as<br />
<br />
[[File: eqt2.png|800px|center]]<br />
<br />
where <math>p(\theta)</math> is the prior which is known. Then, variational inference or MCMC can be applied to calculate the posterior distribution. <br />
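<br />
As a minimal sketch of what an additive noise model looks like in code, the example below assumes a linear mechanism <math>f</math> and standard normal noise, and leaves out the global variable <math>\beta</math>. The function name <code>sample_additive_noise</code> and the parameter values are illustrative assumptions, not taken from the paper.<br />

```python
import random

def sample_additive_noise(x, theta, rng):
    """Draw y = f(x | theta) + eps, with a linear f and standard normal eps."""
    signal = sum(xj * tj for xj, tj in zip(x, theta))  # linear mechanism f
    return signal + rng.gauss(0.0, 1.0)                # additive Gaussian noise

rng = random.Random(0)
theta = [0.5, -1.0, 2.0]  # illustrative parameters
xs = [[rng.gauss(0.0, 1.0) for _ in theta] for _ in range(2000)]
ys = [sample_additive_noise(x, theta, rng) for x in xs]

# Residuals y - f(x) recover the noise term, so they should look standard normal.
residuals = [y - sum(xj * tj for xj, tj in zip(x, theta)) for x, y in zip(xs, ys)]
mean_res = sum(residuals) / len(residuals)
var_res = sum((r - mean_res) ** 2 for r in residuals) / len(residuals)
```

Because the noise enters only as a final additive term, subtracting <math>f(x)</math> from <math>y</math> isolates it exactly; this is the structural restriction that implicit causal models drop.<br />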
<br />
<br />
===Implicit Causal Models===<br />
The difference between implicit causal models and probabilistic causal models is the noise variable. Instead of an additive noise term, implicit causal models directly take noise <math>\epsilon</math> into a neural network and output <math>x</math>.<br />
<br />
The causal diagram has changed to:<br />
<br />
[[File: f_2.png|650px|center|]]<br />
<br />
<br />
They used fully connected neural networks with a fair number of hidden units to approximate each causal mechanism. The formal statement is given below:<br />
<br />
[[File: theorem.png|650px|center|]]<br />
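<br />
The idea of feeding noise into the network, rather than adding it to the output, can be sketched as follows. This is a toy illustration under stated assumptions: a single hidden layer, randomly initialized weights, and the hypothetical names <code>mlp</code> and <code>sample_y</code>; it is not the architecture used by the authors.<br />

```python
import random

def mlp(inputs, w1, b1, w2, b2):
    """One-hidden-layer ReLU network that returns a scalar."""
    hidden = [max(0.0, sum(w * v for w, v in zip(row, inputs)) + b)
              for row, b in zip(w1, b1)]
    return sum(w * h for w, h in zip(w2, hidden)) + b2

def sample_y(x, params, rng):
    """Implicit mechanism: the noise is an input to the network, not an additive term."""
    eps = rng.gauss(0.0, 1.0)
    return mlp(list(x) + [eps], *params)

rng = random.Random(1)
n_in, n_hidden = 3, 8  # two parent values plus one noise input
params = (
    [[rng.gauss(0.0, 1.0) for _ in range(n_in)] for _ in range(n_hidden)],
    [0.0] * n_hidden,
    [rng.gauss(0.0, 1.0) for _ in range(n_hidden)],
    0.0,
)
x = [1.0, 0.0]
draws = [sample_y(x, params, rng) for _ in range(5)]  # samples from p(y | x)
```

Sampling is straightforward, but the density of <math>y</math> given <math>x</math> has no closed form, which is why a likelihood-free inference method is needed later.<br />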
<br />
<br />
==Implicit Causal Models with Latent Confounders==<br />
Previously, the global structure was assumed to be observed. Next, the case where it is unobserved is considered.<br />
<br />
===Causal Inference with a Latent Confounder===<br />
As before, the quantity of interest is the causal effect <math>p(y|do(x_m), x_{-m})</math>. Here, the SNPs other than <math>x_m</math> are also taken into consideration. However, the effect is confounded by the unobserved confounder <math>z_n</math>. As a result, standard inference methods cannot be used in this case.<br />
<br />
The paper proposed a new method which includes the latent confounders. For each subject <math>n=1,…,N</math> and each SNP <math>m=1,…,M</math>,<br />
<br />
[[File: eqt4.png|800px|center]]<br />
<br />
<br />
The mechanism for latent confounder <math>z_n</math> is assumed to be known. SNPs depend on the confounders and the trait depends on all the SNPs and the confounders as well. <br />
<br />
The posterior of <math>\theta</math> needs to be calculated in order to estimate the mechanism <math>g_y</math> as well as the causal effect <math>p(y|do(x_m), x_{-m})</math>, which explains how changes to each SNP <math>X_m</math> cause changes to the trait <math>Y</math>.<br />
<br />
[[File: eqt5.png|800px|center]]<br />
<br />
Note that the latent structure <math>p(z|x, y)</math> is assumed known.<br />
<br />
<br />
===Implicit Causal Model with a Latent Confounder===<br />
This section describes the algorithm and functions used to implement an implicit causal model for GWAS.<br />
<br />
====Generative Process of Confounders <math>z_n</math>.====<br />
The distribution of the confounders is set to be standard normal. <math>z_n \in R^K</math>, where <math>K</math> is the dimension of <math>z_n</math>, and <math>K</math> should be chosen so that the latent space is as close as possible to the true population structure.<br />
<br />
====Generative Process of SNPs <math>x_{nm}</math>.====<br />
Given that each SNP is coded as<br />
<br />
[[File: SNP.png|300px|center]]<br />
<br />
The authors defined a <math>Binomial(2,\pi_{nm})</math> distribution on <math>x_{nm}</math> and used logistic factor analysis to model the SNP matrix.<br />
<br />
[[File: gpx.png|800px|center]]<br />
<br />
A SNP matrix looks like this:<br />
[[File: SNP_matrix.png|200px|center]]<br />
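<br />
The logistic factor analysis step can be sketched in a few lines. The sketch assumes the usual form in which the logit of <math>\pi_{nm}</math> is the inner product of the confounder <math>z_n</math> and a per-SNP factor <math>w_m</math>, and it assumes a standard normal prior on <math>w_m</math>; the exact equation is in the image above and may differ in detail. The function name <code>sample_snps</code> is hypothetical.<br />

```python
import math
import random

def sample_snps(n_individuals, n_snps, k, rng):
    """Sample a genotype matrix with entries in {0, 1, 2} from a logistic factor model."""
    z = [[rng.gauss(0.0, 1.0) for _ in range(k)] for _ in range(n_individuals)]  # confounders z_n
    w = [[rng.gauss(0.0, 1.0) for _ in range(k)] for _ in range(n_snps)]         # per-SNP factors w_m
    x = []
    for zn in z:
        row = []
        for wm in w:
            logit = sum(a * b for a, b in zip(zn, wm))
            pi = 1.0 / (1.0 + math.exp(-logit))  # allele frequency pi_nm
            # Binomial(2, pi) written as the sum of two Bernoulli(pi) draws.
            row.append(sum(rng.random() < pi for _ in range(2)))
        x.append(row)
    return z, w, x

rng = random.Random(0)
z, w, x = sample_snps(n_individuals=5, n_snps=4, k=2, rng=rng)
```

Each row of <code>x</code> is one individual and each column is one SNP, matching the layout of the SNP matrix shown above.<br />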
<br />
<br />
Since logistic factor analysis makes strong assumptions, this paper suggests using a neural network to relax these assumptions,<br />
<br />
[[File: gpxnn.png|800px|center]]<br />
<br />
This renders the outputs to be a full <math>N \times M</math> matrix due to the variables <math>w_m</math>, which act like principal components in PCA.<br />
<br />
====Generative Process of Traits <math>y_n</math>.====<br />
In previous work, each trait was modeled by a linear regression,<br />
<br />
[[File: gpy.png|800px|center]]<br />
<br />
This also has very strong assumptions on SNPs, interactions, and additive noise. It can also be replaced by a neural network which only outputs a scalar,<br />
<br />
[[File: gpynn.png|800px|center]]<br />
<br />
<br />
==Likelihood-free Variational Inference==<br />
Calculating the posterior of <math>\theta</math> is the key to applying the implicit causal model with latent confounders.<br />
<br />
[[File: eqt5.png|800px|center]]<br />
<br />
can be reduced to<br />
<br />
[[File: lfvi1.png|800px|center]]<br />
<br />
However, with implicit models, integrating over a nonlinear function can be intractable. The authors therefore applied likelihood-free variational inference (LFVI). LFVI posits a family of distributions over the latent variables. Here the variables <math>w_m</math> and <math>z_n</math> are all assumed to be Normal,<br />
<br />
[[File: lfvi2.png|700px|center]]<br />
<br />
For LFVI applied to GWAS, an algorithm similar to the EM algorithm is used:<br />
[[File: em.png|800px|center]]<br />
<br />
==Empirical Study==<br />
The authors performed simulation on 100,000 SNPs, 940 to 5,000 individuals, and across 100 replications of 11 settings. <br />
Four methods were compared: <br />
<br />
* implicit causal model (ICM);<br />
* PCA with linear regression (PCA); <br />
* a linear mixed model (LMM); <br />
* logistic factor analysis with inverse regression (GCAT).<br />
<br />
The feedforward neural networks for traits and SNPs are fully connected, with two hidden layers, ReLU activation functions, and batch normalization.<br />
<br />
===Simulation Study===<br />
Based on real genomic data, a true model is applied to generate the SNPs and traits for each configuration. <br />
There are four datasets used in this simulation study: <br />
<br />
# HapMap [Balding-Nichols model]<br />
# 1000 Genomes Project (TGP) [PCA]<br />
#* Human Genome Diversity project (HGDP) [PCA]<br />
#* HGDP [Pritchard-Stephens-Donnelly model]<br />
# A latent spatial position of individuals for population structure [spatial]<br />
<br />
<br />
The table shows the prediction accuracy, measured as precision: the number of true positives divided by the number of true positives plus false positives. True positives are SNPs that are correctly identified as having a causal relation with the trait. In contrast, false positives are SNPs reported as having a causal relation with the trait when they do not. The closer the rate is to 1, the better the model, since false positives count as wrong predictions.<br />
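<br />
The rate described above can be computed as in the following sketch. The function name <code>precision</code> and the SNP indices are made up for illustration; they are not data from the paper.<br />

```python
def precision(predicted, causal):
    """True positives divided by (true positives + false positives)."""
    predicted, causal = set(predicted), set(causal)
    true_pos = len(predicted & causal)   # flagged SNPs that are truly causal
    false_pos = len(predicted - causal)  # flagged SNPs that are not causal
    if true_pos + false_pos == 0:
        return 0.0                       # nothing was flagged
    return true_pos / (true_pos + false_pos)

flagged = [1, 2, 3, 4]      # hypothetical SNP indices reported by a method
truly_causal = [2, 3, 5]    # hypothetical ground truth from a simulation
rate = precision(flagged, truly_causal)  # 2 true positives, 2 false positives
```

Note that this rate ignores causal SNPs that were missed (index 5 in the example), so it measures how trustworthy the reported SNPs are rather than how many causal SNPs were found.<br />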
<br />
[[File: table_1.png|650px|center|]]<br />
<br />
The results above show that the implicit causal model has the best performance among these four models in every situation. In particular, the other models tend to do poorly on PSD and Spatial when <math>a</math> is small, whereas the ICM achieves a substantially higher rate. The only method comparable to ICM is GCAT, and only on the simpler configurations.<br />
<br />
<br />
===Real-data Analysis===<br />
They also applied ICM to a real-world GWAS of the Northern Finland Birth Cohort, which contains 324,160 SNPs and 5,027 individuals. Ten implicit causal models were fitted, and two neural networks, each with two hidden layers, were used for the SNPs and the trait.<br />
<br />
[[File: table_2.png|650px|center|]]<br />
<br />
The numbers in the above table are the numbers of significant loci for each of the 10 traits. The numbers for the other methods, namely GCAT, LMM, PCA, and "uncorrected", are taken from other papers. By comparison, the ICM reached the level of the best previous model for each trait.<br />
<br />
==Conclusion==<br />
This paper introduced implicit causal models in order to account for nonlinear, complex causal relationships, and applied the method to GWAS. The model not only captures important interactions between genes within an individual and at the population level, but also adjusts for latent confounders by incorporating latent variables into the model.<br />
<br />
In the simulation study, the authors showed that the implicit causal model could outperform other methods by 15-45.3% on a variety of datasets with varying parameters.<br />
<br />
The authors also believe this GWAS application is only the start of the use of implicit causal models, which might also be applied in physics or economics.<br />
<br />
==Critique==<br />
I think this paper is an interesting and novel work. The main contribution of this paper is to connect statistical genetics with machine learning methodology. The method is technically sound and does indeed generalize techniques currently used in statistical genetics.<br />
<br />
The neural network used in this paper is a very simple feedforward network with two hidden layers, but the idea of where to use the neural network is crucial and might be significant in GWAS.<br />
<br />
It has limitations as well. The empirical example in this paper is too easy and far from a realistic situation. Although the simulation study showed some competitive results, the Northern Finland Birth Cohort application did not demonstrate that the implicit causal model is better than previous methods such as GCAT or LMM.<br />
<br />
Another limitation, which the authors also state, concerns linkage disequilibrium. SNPs are not completely independent of each other; they are usually correlated when the alleles are at nearby loci. The authors did not consider this more complex case and only considered the simplest case, in which all SNPs are assumed to be independent.<br />
<br />
Furthermore, a single SNP may not have enough power to explain the causal relationship. Recent papers indicate that causation of a trait may involve multiple SNPs.<br />
This could be future work as well.<br />
<br />
==References==<br />
Tran D, Blei D M. Implicit Causal Models for Genome-wide Association Studies[J]. arXiv preprint arXiv:1710.10742, 2017.<br />
<br />
Patrik O Hoyer, Dominik Janzing, Joris M Mooij, Jonas Peters, and Bernhard Schölkopf. Nonlinear causal discovery with additive noise models. In Neural Information Processing Systems, 2009.<br />
<br />
Alkes L Price, Nick J Patterson, Robert M Plenge, Michael E Weinblatt, Nancy A Shadick, and David Reich. Principal components analysis corrects for stratification in genome-wide association studies. Nature Genetics, 38(8):904–909, 2006.<br />
<br />
Minsun Song, Wei Hao, and John D Storey. Testing for genetic associations in arbitrarily structured populations. Nature Genetics, 47(5):550–554, 2015.<br />
<br />
Dustin Tran, Rajesh Ranganath, and David M Blei. Hierarchical implicit models and likelihood-free variational inference. In Neural Information Processing Systems, 2017.</div>X249wanghttp://wiki.math.uwaterloo.ca/statwiki/index.php?title=Continuous_Adaptation_via_Meta-Learning_in_Nonstationary_and_Competitive_Environments&diff=34600Continuous Adaptation via Meta-Learning in Nonstationary and Competitive Environments2018-03-18T03:55:33Z<p>X249wang: /* Elaborated on future work */</p>
<hr />
<div>= Introduction =<br />
<br />
Typically, the basic goal of machine learning is to train a model to perform a task. In Meta-learning, the goal is to train a model to perform the task of training a model to perform a task. Hence in this case the term "Meta-Learning" has the exact meaning you would expect; the word "Meta" has the precise function of introducing a layer of abstraction.<br />
<br />
The meta-learning task can be made more concrete by a simple example. Consider the CIFAR-100 classification task that we used for our data competition. We can alter this task from being a 100-class classification problem to a collection of 100 binary classification problems. The goal of Meta-Learning here is to design and train a single binary classifier that will perform well on a randomly sampled task given a limited amount of training data for that specific task. In other words, we would like to train a model to perform the following procedure:<br />
<br />
# A task is sampled. The task is "Is X a dog?".<br />
# A small set of labeled training data is provided to the model. The labels represent whether or not the image is a picture of a dog.<br />
# The model uses the training data to adjust itself to the specific task of checking whether or not an image is a picture of a dog.<br />
<br />
This example also highlights the intuition that the skill of sight is distinct and separable from the skill of knowing what a dog looks like.<br />
<br />
In this paper, a probabilistic framework for meta-learning is derived, then applied to tasks involving simulated robotic spiders. This framework generalizes the typical machine learning set up using Markov Decision Processes. This paper focuses on a multi-agent non-stationary environment which requires reinforcement learning (RL) agents to do continuous adaptation in such an environment. Non-stationarity breaks the standard assumptions and requires agents to continuously adapt, both at training and execution time, in order to earn more rewards, hence the approach is to break this into a sequence of stationary tasks and present it as a multi-task learning problem.<br />
<br />
[[File:paper19_fig1.png|600px|frame|none|alt=Alt text| '''Figure 1'''. a) Illustrates a probabilistic model for Model Agnostic Meta-Learning (MAML) in a multi-task RL setting, where the tasks <math>T</math>, policies <math>\pi</math>, and trajectories <math>\tau</math> are all random variables with dependencies encoded in the edges of a given graph. b) The proposed extension to MAML by the authors suitable for continuous adaptation to a task changing dynamically due to non-stationarity of the environment. The distribution of tasks is represented by a Markov chain, whereby policies from a previous step are used to construct a new policy for the current step. c) The computation graph for the meta-update from <math>\phi_i</math> to <math>\phi_{i+1}</math>. Boxes represent replicas of the policy graphs with the specified parameters. The model is optimized using truncated backpropagation through time starting from <math>L_{T_{i+1}}</math>.]]<br />
<br />
= Model Agnostic Meta-Learning =<br />
<br />
An initial framework for meta-learning is given in "Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks" (Finn et al, 2017):<br />
<br />
"In our approach, the parameters of<br />
the model are explicitly trained such that a small<br />
number of gradient steps with a small amount<br />
of training data from a new task will produce<br />
good generalization performance on that task" (Finn et al, 2017).<br />
<br />
[[File:MAML.png | 500px]]<br />
<br />
In this training algorithm, the parameter vector <math>\theta</math> belonging to the model <math>f_{\theta}</math> is trained such that the meta-objective function <math>\mathcal{L} (\theta) = \sum_{\tau_i \sim P(\tau)} \mathcal{L}_{\tau_i} (f_{\theta_i' }) </math> is minimized. The sum in the objective function is over a sampled batch of training tasks. <math>\mathcal{L}_{\tau_i} (f_{\theta_i'})</math> is the training loss function corresponding to the <math>i^{th}</math> task in the batch evaluated at the model <math>f_{\theta_i'}</math>. The parameter vector <math>\theta_i'</math> is obtained by updating the general parameter <math>\theta</math> using the loss function <math>\mathcal{L}_{\tau_i}</math> and set of K training examples specific to the <math>i^{th}</math> task. Note that in alternate versions of this algorithm, additional testing sets are sampled from <math>\tau_i</math> and used to update <math>\theta</math> using testing loss functions instead of training loss functions.<br />
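<br />
To make the inner update and the meta-objective concrete, here is a toy sketch in which everything can be differentiated by hand. It assumes one scalar parameter and quadratic task losses <math>0.5(\theta - c_i)^2</math>; the task targets, the step sizes <code>alpha</code> and <code>beta</code>, and the function names are illustrative assumptions, not the setup from the paper.<br />

```python
def inner_update(theta, c, alpha):
    """One gradient step on the task loss 0.5 * (theta - c)**2, giving theta_i'."""
    return theta - alpha * (theta - c)

def meta_gradient(theta, centers, alpha):
    """Gradient of sum_i 0.5 * (theta_i' - c_i)**2 with respect to theta.

    By the chain rule, d theta_i' / d theta = (1 - alpha); this factor is what
    it means to differentiate through the inner update.
    """
    return sum((inner_update(theta, c, alpha) - c) * (1 - alpha) for c in centers)

centers = [-1.0, 0.0, 4.0]  # one target per toy task
alpha, beta = 0.1, 0.1      # inner and meta step sizes (illustrative)
theta = 10.0
for _ in range(500):
    theta -= beta * meta_gradient(theta, centers, alpha)
```

In this toy case the meta-objective is minimized at the mean of the task targets, so <code>theta</code> settles at the initialization from which a single inner step helps most on average. With neural networks, the factor <code>(1 - alpha)</code> becomes a term involving second derivatives of the task loss.<br />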
<br />
One important difference between this algorithm and more typical fine-tuning methods is that <math>\theta</math> is explicitly trained to be easily adjusted to perform well on different tasks, rather than trained to perform well on a specific task and then fine-tuned as the environment changes (Sutton et al., 2007).<br />
<br />
= Probabilistic Framework for Meta-Learning =<br />
<br />
This paper puts the meta-learning problem into a Markov Decision Process (MDP) framework common to RL, see Figure 1a. Instead of training examples <math>\{(x, y)\}</math>, we have trajectories <math>\tau = (x_0, a_1, x_1, R_1, x_2, ... a_H, x_H, R_H)</math>. A trajectory is a sequence of states/observations <math>x_t</math>, actions <math>a_t</math> and rewards <math>R_t</math> that is sampled from a task <math> T </math> according to a policy <math>\pi_{\theta}</math>. Included with said task is a method for assigning loss values to trajectories <math>L_T(\tau)</math> which is typically the negative cumulative reward. A policy is a deterministic function that takes in a state and returns an action. Our goal here is to train a policy <math>\pi_{\theta}</math> with parameter vector <math>\theta</math>. This is analogous to training a function <math>f_{\theta}</math> that assigns labels <math>y</math> to feature vectors <math>x</math>. More precisely we have the following definitions:<br />
<br />
* <math>T :=(L_T, P_T(x), P_T(x_t | x_{t-1}, a_{t-1}), H )</math> (A Task)<br />
* <math>D(T)</math> : A distribution over tasks.<br />
* <math>L_T</math>: A loss function for the task T that assigns numeric loss values to trajectories.<br />
* <math>P_T(x), P_T(x_t | x_{t-1}, a_{t-1})</math>: Probability measures specifying the Markovian dynamics of the observations <math>x_t</math><br />
* <math>H</math>: The horizon of the MDP. This is a fixed natural number specifying the length of the task's trajectories.<br />
<br />
The paper goes further to define a Markov dynamic for sequences of tasks as shown in Figure 1b. Thus the policy that we would like to meta-learn <math>\pi_{\theta}</math>, after being exposed to a sample of K trajectories <math>\tau_\theta^{1:K}</math> from the task <math>T_i</math>, should produce a new policy <math>\pi_{\phi}</math> that will perform well on the next task <math>T_{i+1}</math>. Thus we seek to minimize the following expectation:<br />
<br />
<math>\mathrm{E}_{P(T_0), P(T_{i+1} | T_i)}\bigg(\sum_{i=1}^{l} \mathcal{L}_{T_i, T_{i+1}}(\theta)\bigg)</math>, <br />
<br />
where <math>\mathcal{L}_{T_i, T_{i + 1}}(\theta) := \mathrm{E}_{\tau_{i, \theta}^{1:K} } \bigg( \mathrm{E}_{\tau_{i+1, \phi}}\Big( L_{T_{i+1}}(\tau_{i+1, \phi} | \tau_{i, \theta}^{1:K}, \theta) \Big) \bigg) </math> and <math>l</math> is the number of tasks.<br />
<br />
The meta-policy <math>\pi_{\theta}</math> is trained and then adapted at test time using the following procedures. The computational graph is given in Figure 1c.<br />
<br />
[[File:MAML2.png | 800px]]<br />
<br />
The mathematics of calculating loss gradients is omitted.<br />
<br />
= Training Spiders to Run with Dynamic Handicaps (Robotic Locomotion in Non-Stationary Environments) =<br />
<br />
The authors used the MuJoCo physics simulator to create a simulated environment where robotic spiders with 6 legs are faced with the task of running due east as quickly as possible. The robotic spider observes the location and velocity of its body, and the angles and velocities of its legs. It interacts with the environment by exerting torque on the joints of its legs. Each leg has two joints, the joint closer to the body rotates horizontally while the joint farther from the body rotates vertically. The environment is made non-stationary by gradually paralyzing two legs of the spider across training and testing episodes.<br />
Putting this example into the above probabilistic framework yields:<br />
<br />
* <math>T_i</math>: The task of walking east with the torques of two legs scaled by <math> (i-1)/6 </math><br />
* <math>\{T_i\}_{i=1}^{7}</math>: A sequence of tasks with the same two legs handicapped in each task. Note there are 15 different ways to choose such legs resulting in 15 sequences of tasks. 12 are used for training and 3 for testing.<br />
* A Markov Decision Process composed of<br />
** Observations <math> x_t </math> containing information about the state of the spider.<br />
** Actions <math> a_t </math> containing information about the torques to apply to the spider's legs.<br />
** Rewards <math> R_t </math> corresponding to the speed at which the spider is moving east.<br />
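<br />
The count of 15 task sequences mentioned above is the number of unordered pairs of legs out of six, which can be checked directly. The leg numbering and the choice of which 12 pairs go to training are arbitrary assumptions made for this sketch; the source does not say which pairs were held out.<br />

```python
from itertools import combinations

legs = range(6)                          # legs numbered 0..5 (arbitrary labels)
leg_pairs = list(combinations(legs, 2))  # every unordered pair of handicapped legs

# Illustrative split only: 12 pairs for training, 3 for testing.
train_pairs, test_pairs = leg_pairs[:12], leg_pairs[12:]
```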
<br />
Three differently structured policy neural networks are trained in this set up using both meta-learning and three different previously developed adaptation methods.<br />
<br />
At testing time, the spiders following meta-learned policies initially perform worse than the spiders using non-adaptive policies. However, by the third episode (<math> i=3 </math>), the meta-learners perform on par. And by the sixth episode, when the selected legs are mostly immobile, the meta-learners significantly outperform them. These results can be seen in the graphs below.<br />
<br />
[[File:locomotion_results.png | 800px]]<br />
<br />
= Training Spiders to Fight Each Other (Adversarial Meta-Learning) =<br />
<br />
The authors created an adversarial environment called RoboSumo where pairs of agents with 4 legs (named Ants), 6 legs (named Bugs), or 8 legs (named Spiders) sumo wrestle. The agents observe the location and velocity of their bodies and the bodies of their opponent, the angles and velocities of their legs, and the forces being exerted on them by their opponent (equivalent of tactile sense). The game is organized into episodes and rounds. Episodes are single wrestling matches with 500 time steps and win/lose/draw outcomes. Agents win by pushing their opponent out of the ring or making their opponent's body touch the ground. An episode results in a draw when neither of these things happens within 500 time steps. Rounds are batches of episodes, with possible outcomes win, lose, and draw that are decided based on the majority of episodes won. K rounds will be fought. Both agents may update their policies between rounds. The agent that wins the majority of rounds is deemed the winner of the game.<br />
<br />
== Setup ==<br />
Similar to the Robotic locomotion example, this game can be phrased in terms of the RL MDP framework.<br />
<br />
* <math>T_i</math>: The task of fighting a round.<br />
* <math>\{T_i\}_{i=1}^{K}</math>: A sequence of rounds against the same opponent. Note that the opponent may update their policy between rounds but the anatomy of both wrestlers will be constant across rounds.<br />
* A Markov Decision Process composed of<br />
** A horizon <math>H = 500*n</math> where <math>n</math> is the number of episodes per round.<br />
** Observations <math> x_t </math> containing information about the state of the agent and its opponent.<br />
** Actions <math> a_t </math> containing information about the torques to apply to the agent's legs.<br />
** Rewards <math> R_t </math> given to the agent based on its wrestling performance. <math>R_{500*n} = </math> +2000 if win episode, -2000 if lose, and -1000 if draw.<br />
<br />
Note that the above reward set up is quite sparse, therefore in order to encourage fast training, rewards are introduced at every time step for the following.<br />
* For staying close to the center of the ring.<br />
* For exerting force on the opponent's body.<br />
* For moving towards the opponent.<br />
* For the distance of the opponent to the center of the ring.<br />
<br />
This makes sense intuitively as these are reasonable goals for agents to explore when they are learning to wrestle.<br />
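<br />
The terminal rewards and the majority rule for rounds can be sketched as below. The reward values come from the setup above; the tie-breaking rule (a round is a draw when wins and losses are equal) is an assumption, since the text only says that rounds are decided by the majority of episodes won. The names <code>TERMINAL_REWARD</code> and <code>round_outcome</code> are hypothetical.<br />

```python
from collections import Counter

# Terminal reward for each episode outcome, as listed in the setup.
TERMINAL_REWARD = {"win": 2000, "lose": -2000, "draw": -1000}

def round_outcome(episode_results):
    """Decide a round from its episodes by comparing wins against losses."""
    counts = Counter(episode_results)
    if counts["win"] > counts["lose"]:
        return "win"
    if counts["lose"] > counts["win"]:
        return "lose"
    return "draw"  # assumed tie rule

episodes = ["win", "draw", "lose", "win"]
outcome = round_outcome(episodes)
total_terminal_reward = sum(TERMINAL_REWARD[r] for r in episodes)
```

Since a draw is penalized, an agent that merely avoids losing still receives a negative terminal signal, which is consistent with the dense shaping rewards pushing agents to engage with the opponent.<br />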
<br />
== Training ==<br />
The same combinations of policy networks and adaptation methods that were used in the locomotion example are trained and tested here. A family of non-adaptive policies is first trained via self-play and saved at all stages. Self-play simply means the two agents in the training environment use the same policy. All policy versions are saved so that agents of various skill levels can be sampled when training meta-learners. The weights of the different insects were calibrated such that the test win rate between two insects of differing anatomy, who have been trained for the same number of epochs via self-play, is close to 50%.<br />
<br />
[[File:weight_cal.png | 800px]]<br />
<br />
We can see in the above figure that the weight of the spider had to be increased by almost four times in order for the agents to be evenly matched.<br />
<br />
[[File:robosumo_results.png | 800px]]<br />
<br />
The above figure shows testing results for various adaptation strategies. The agent and opponent both start with the self-trained policies. The opponent uses all of its testing experience to continue training. The agent uses only the last 75 episodes to adapt its policy network. This shows that meta-learners need only a limited amount of experience in order to hold their own against a constantly improving opponent.<br />
<br />
= Future Work =<br />
The authors noted that the meta-learning adaptation rule they proposed is similar to backpropagation through time with a unit time lag, so a potential area for future research would be to introduce fully-recurrent meta-updates based on the full interaction history with the environment. Secondly, the algorithm proposed involves computing second-order derivatives at training time (see Figure 1c), which resulted in much slower training processes compared to baseline models during experiments, so they suggested finding a method to utilize information from the loss function without explicit backpropagation to speed up computations. The authors also mention that their approach likely will not work well with sparse rewards. This is because the meta-updates, which use policy gradients, are very dependent on the reward signal. They mention that this is an issue they would like to address in the future. A potential solution they have outlined for this is to introduce auxiliary dense rewards which could enable meta-learning.<br />
<br />
= Sources =<br />
# Chelsea Finn, Pieter Abbeel, Sergey Levine. "Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks." arXiv preprint arXiv:1703.03400v3 (2017).<br />
# Richard S Sutton, Anna Koop, and David Silver. On the role of tracking in stationary environments. In Proceedings of the 24th international conference on Machine learning, pp. 871–878. ACM, 2007.</div>X249wanghttp://wiki.math.uwaterloo.ca/statwiki/index.php?title=Continuous_Adaptation_via_Meta-Learning_in_Nonstationary_and_Competitive_Environments&diff=34599Continuous Adaptation via Meta-Learning in Nonstationary and Competitive Environments2018-03-18T03:36:52Z<p>X249wang: /* Editorial */</p>
<hr />
<div>= Introduction =<br />
<br />
Typically, the basic goal of machine learning is to train a model to perform a task. In Meta-learning, the goal is to train a model to perform the task of training a model to perform a task. Hence in this case the term "Meta-Learning" has the exact meaning you would expect; the word "Meta" has the precise function of introducing a layer of abstraction.<br />
<br />
The meta-learning task can be made more concrete by a simple example. Consider the CIFAR-100 classification task that we used for our data competition. We can alter this task from being a 100-class classification problem to a collection of 100 binary classification problems. The goal of Meta-Learning here is to design and train a single binary classifier that will perform well on a randomly sampled task given a limited amount of training data for that specific task. In other words, we would like to train a model to perform the following procedure:<br />
<br />
# A task is sampled. The task is "Is X a dog?".<br />
# A small set of labeled training data is provided to the model. The labels represent whether or not the image is a picture of a dog.<br />
# The model uses the training data to adjust itself to the specific task of checking whether or not an image is a picture of a dog.<br />
<br />
This example also highlights the intuition that the skill of sight is distinct and separable from the skill of knowing what a dog looks like.<br />
<br />
In this paper, a probabilistic framework for meta-learning is derived, then applied to tasks involving simulated robotic spiders. This framework generalizes the typical machine learning set up using Markov Decision Processes. This paper focuses on a multi-agent non-stationary environment which requires reinforcement learning (RL) agents to do continuous adaptation in such an environment. Non-stationarity breaks the standard assumptions and requires agents to continuously adapt, both at training and execution time, in order to earn more rewards, hence the approach is to break this into a sequence of stationary tasks and present it as a multi-task learning problem.<br />
<br />
[[File:paper19_fig1.png|600px|frame|none|alt=Alt text| '''Figure 1'''. a) Illustrates a probabilistic model for Model Agnostic Meta-Learning (MAML) in a multi-task RL setting, where the tasks <math>T</math>, policies <math>\pi</math>, and trajectories <math>\tau</math> are all random variables with dependencies encoded in the edges of a given graph. b) The proposed extension to MAML by the authors suitable for continuous adaptation to a task changing dynamically due to non-stationarity of the environment. The distribution of tasks is represented by a Markov chain, whereby policies from a previous step are used to construct a new policy for the current step. c) The computation graph for the meta-update from <math>\phi_i</math> to <math>\phi_{i+1}</math>. Boxes represent replicas of the policy graphs with the specified parameters. The model is optimized using truncated backpropagation through time starting from <math>L_{T_{i+1}}</math>.]]<br />
<br />
= Model Agnostic Meta-Learning =<br />
<br />
An initial framework for meta-learning is given in "Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks" (Finn et al, 2017):<br />
<br />
"In our approach, the parameters of<br />
the model are explicitly trained such that a small<br />
number of gradient steps with a small amount<br />
of training data from a new task will produce<br />
good generalization performance on that task" (Finn et al, 2017).<br />
<br />
[[File:MAML.png | 500px]]<br />
<br />
In this training algorithm, the parameter vector <math>\theta</math> belonging to the model <math>f_{\theta}</math> is trained such that the meta-objective function <math>\mathcal{L} (\theta) = \sum_{\tau_i \sim P(\tau)} \mathcal{L}_{\tau_i} (f_{\theta_i' }) </math> is minimized. Here <math>\tau_i</math> denotes a task, following the notation of Finn et al., and not a trajectory as in later sections. The sum in the objective function is over a sampled batch of training tasks. <math>\mathcal{L}_{\tau_i} (f_{\theta_i'})</math> is the training loss function corresponding to the <math>i^{th}</math> task in the batch evaluated at the model <math>f_{\theta_i'}</math>. The parameter vector <math>\theta_i'</math> is obtained by updating the general parameter <math>\theta</math> using the loss function <math>\mathcal{L}_{\tau_i}</math> and a set of K training examples specific to the <math>i^{th}</math> task. Note that in alternate versions of this algorithm, additional testing sets are sampled from <math>\tau_i</math> and used to update <math>\theta</math> using testing loss functions instead of training loss functions.<br />
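<br />
The meta-objective can be illustrated with a minimal sketch. The scalar parameter, the quadratic task losses, the step sizes <code>alpha</code> and <code>beta</code>, and the helper names are all assumptions chosen for illustration and do not come from the paper. The gradients are written out analytically so that no automatic differentiation library is needed.<br />

```python
# Toy MAML sketch (illustrative assumptions, not the authors' code).
# Each task i has loss L_i(theta) = 0.5 * (theta - c_i)**2 for a scalar theta.

def inner_update(theta, c, alpha):
    """One gradient step on the task loss: theta' = theta - alpha * dL/dtheta."""
    return theta - alpha * (theta - c)


def meta_gradient(theta, c, alpha):
    """Gradient of L_i(theta') with respect to theta, through the inner step.

    dL_i/dtheta' = theta' - c and dtheta'/dtheta = 1 - alpha,
    so the chain rule gives (theta' - c) * (1 - alpha).
    """
    theta_prime = inner_update(theta, c, alpha)
    return (theta_prime - c) * (1.0 - alpha)


def maml_step(theta, task_centers, alpha=0.1, beta=0.5):
    """One meta-update: sum the post-adaptation gradients over the task batch."""
    grad = sum(meta_gradient(theta, c, alpha) for c in task_centers)
    return theta - beta * grad


theta = 5.0
tasks = [-1.0, 0.0, 1.0]  # hypothetical task optima c_i
for _ in range(200):
    theta = maml_step(theta, tasks)
```

In this toy problem the meta-learned <code>theta</code> settles at the point from which a single inner step does best on average across the tasks, which is the behaviour the meta-objective asks for.<br />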
<br />
One important difference between this algorithm and more typical fine-tuning methods is that <math>\theta</math> is explicitly trained to be easily adjusted to perform well on different tasks, rather than trained to perform well on one specific task and then fine-tuned as the environment changes (Sutton et al., 2007).<br />
<br />
= Probabilistic Framework for Meta-Learning =<br />
<br />
This paper puts the meta-learning problem into a Markov Decision Process (MDP) framework common to RL; see Figure 1a. Instead of training examples <math>\{(x, y)\}</math>, we have trajectories <math>\tau = (x_0, a_1, x_1, R_1, \ldots, a_H, x_H, R_H)</math>. A trajectory is a sequence of states/observations <math>x_t</math>, actions <math>a_t</math> and rewards <math>R_t</math> that is sampled from a task <math> T </math> according to a policy <math>\pi_{\theta}</math>. Included with the task is a method for assigning loss values to trajectories, <math>L_T(\tau)</math>, which is typically the negative cumulative reward. A policy is a deterministic function that takes in a state and returns an action. Our goal here is to train a policy <math>\pi_{\theta}</math> with parameter vector <math>\theta</math>. This is analogous to training a function <math>f_{\theta}</math> that assigns labels <math>y</math> to feature vectors <math>x</math>. More precisely, we have the following definitions:<br />
<br />
* <math>T :=(L_T, P_T(x), P_T(x_t | x_{t-1}, a_{t-1}), H )</math> (A Task)<br />
* <math>D(T)</math> : A distribution over tasks.<br />
* <math>L_T</math>: A loss function for the task T that assigns numeric loss values to trajectories.<br />
* <math>P_T(x), P_T(x_t | x_{t-1}, a_{t-1})</math>: Probability measures specifying the Markovian dynamics of the observations <math>x_t</math>.<br />
* <math>H</math>: The horizon of the MDP. This is a fixed natural number specifying the length of the task's trajectories.<br />
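<br />
The definitions above can be made concrete with a small data structure. This is only a sketch: the class name <code>Trajectory</code>, the use of scalar observations and actions, and the example values are assumptions for illustration, and the loss is taken to be the negative cumulative reward mentioned above.<br />

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class Trajectory:
    """A trajectory: initial observation x_0 plus H steps of (a_t, x_t, R_t)."""
    x0: float
    steps: List[Tuple[float, float, float]]  # (action, next observation, reward)

    @property
    def horizon(self) -> int:
        """H, the number of steps in the trajectory."""
        return len(self.steps)


def loss(traj: Trajectory) -> float:
    """L_T(tau), taken here as the negative cumulative reward."""
    return -sum(reward for _, _, reward in traj.steps)


# Hypothetical two-step trajectory with rewards 2.0 and 3.0.
tau = Trajectory(x0=0.0, steps=[(1.0, 0.5, 2.0), (-1.0, 0.2, 3.0)])
print(tau.horizon, loss(tau))
```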
<br />
The paper goes further and defines Markov dynamics for sequences of tasks, as shown in Figure 1b. The policy that we would like to meta-learn, <math>\pi_{\theta}</math>, after being exposed to a sample of K trajectories <math>\tau_\theta^{1:K}</math> from the task <math>T_i</math>, should produce a new policy <math>\pi_{\phi}</math> that will perform well on the next task <math>T_{i+1}</math>. Thus we seek to minimize the following expectation:<br />
<br />
<math>\mathrm{E}_{P(T_0), P(T_{i+1} | T_i)}\bigg(\sum_{i=1}^{l} \mathcal{L}_{T_i, T_{i+1}}(\theta)\bigg)</math>, <br />
<br />
where <math>\mathcal{L}_{T_i, T_{i + 1}}(\theta) := \mathrm{E}_{\tau_{i, \theta}^{1:K} } \bigg( \mathrm{E}_{\tau_{i+1, \phi}}\Big( L_{T_{i+1}}(\tau_{i+1, \phi} | \tau_{i, \theta}^{1:K}, \theta) \Big) \bigg) </math> and <math>l</math> is the number of tasks.<br />
<br />
The meta-policy <math>\pi_{\theta}</math> is trained and then adapted at test time using the following procedures. The computational graph is given in Figure 1c.<br />
<br />
[[File:MAML2.png | 800px]]<br />
<br />
The mathematics of calculating loss gradients is omitted.<br />
<br />
= Training Spiders to Run with Dynamic Handicaps (Robotic Locomotion in Non-Stationary Environments) =<br />
<br />
The authors used the MuJoCo physics simulator to create a simulated environment where robotic spiders with 6 legs are faced with the task of running due east as quickly as possible. The robotic spider observes the location and velocity of its body, and the angles and velocities of its legs. It interacts with the environment by exerting torque on the joints of its legs. Each leg has two joints: the joint closer to the body rotates horizontally, while the joint farther from the body rotates vertically. The environment is made non-stationary by gradually paralyzing two legs of the spider across training and testing episodes.<br />
Putting this example into the above probabilistic framework yields:<br />
<br />
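* <math>T_i</math>: The task of walking east with the torques of two legs multiplied by <math> 1 - (i-1)/6 </math>, so that the selected legs go from full strength at <math>i=1</math> to fully paralyzed at <math>i=7</math>.<br />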
* <math>\{T_i\}_{i=1}^{7}</math>: A sequence of tasks with the same two legs handicapped in each task. Note that there are <math>\binom{6}{2} = 15</math> ways to choose the two legs, resulting in 15 sequences of tasks; 12 are used for training and 3 for testing.<br />
* A Markov Decision Process composed of<br />
** Observations <math> x_t </math> containing information about the state of the spider.<br />
** Actions <math> a_t </math> containing information about the torques to apply to the spider's legs.<br />
** Rewards <math> R_t </math> corresponding to the speed at which the spider is moving east.<br />
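<br />
The counting in the list above can be checked with a short sketch. It assumes that the torque multiplier for the handicapped legs falls linearly from 1 in the first task to 0 in the seventh, which is how the gradual paralysis is read here. The names <code>leg_pairs</code> and <code>torque_scale</code> are hypothetical.<br />

```python
from fractions import Fraction
from itertools import combinations

NUM_LEGS = 6
NUM_TASKS = 7

# All unordered pairs of legs that could be handicapped: C(6, 2) = 15.
leg_pairs = list(combinations(range(NUM_LEGS), 2))


def torque_scale(i: int) -> Fraction:
    """Assumed torque multiplier for the handicapped legs in task T_i, i = 1..7."""
    return 1 - Fraction(i - 1, NUM_TASKS - 1)


scales = [torque_scale(i) for i in range(1, NUM_TASKS + 1)]
print(len(leg_pairs))
print([str(s) for s in scales])
```

Exact fractions are used so that the first and last multipliers are exactly 1 and 0.<br />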
<br />
Three differently structured policy neural networks are trained in this setup using both meta-learning and three different previously developed adaptation methods.<br />
<br />
At testing time, the spiders following meta-learned policies initially perform worse than the spiders using non-adaptive policies. However, by the third episode (<math> i=3 </math>), the meta-learners perform on par, and by the sixth episode, when the selected legs are mostly immobile, the meta-learners significantly outperform them. These results can be seen in the graphs below.<br />
<br />
[[File:locomotion_results.png | 800px]]<br />
<br />
= Training Spiders to Fight Each Other (Adversarial Meta-Learning) =<br />
<br />
The authors created an adversarial environment called RoboSumo where pairs of agents with 4 legs (named Ants), 6 legs (named Bugs), or 8 legs (named Spiders) sumo wrestle. The agents observe the location and velocity of their own bodies and of their opponent's body, the angles and velocities of their legs, and the forces being exerted on them by their opponent (the equivalent of a tactile sense). The game is organized into episodes and rounds. Episodes are single wrestling matches with 500 time steps and win/lose/draw outcomes. An agent wins an episode by pushing its opponent out of the ring or by making its opponent's body touch the ground; an episode is a draw when neither happens within 500 time steps. Rounds are batches of episodes, and each round is a win, loss, or draw decided by the majority of episodes won. K rounds are fought, and both agents may update their policies between rounds. The agent that wins the majority of rounds is deemed the winner of the game.<br />
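<br />
The way episode results roll up into a round result can be sketched as follows. The tie-breaking rule is an assumption, since the summary only says that rounds are decided by the majority of episodes won: here the round goes to whichever side won more episodes, and equal counts give a draw. The function name <code>round_outcome</code> is hypothetical.<br />

```python
from collections import Counter
from typing import Iterable


def round_outcome(episode_results: Iterable[str]) -> str:
    """Decide a round from episode results, seen from the agent's point of view.

    Assumption: the round goes to whoever won more episodes; equal counts
    (including a round of only draws) make the round a draw.
    """
    counts = Counter(episode_results)
    if counts["win"] > counts["lose"]:
        return "win"
    if counts["lose"] > counts["win"]:
        return "lose"
    return "draw"


print(round_outcome(["win", "draw", "lose", "win"]))
```

The same rule could be applied one level up to decide the game from the K round results.<br />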
<br />
== Setup ==<br />
Similar to the Robotic locomotion example, this game can be phrased in terms of the RL MDP framework.<br />
<br />
* <math>T_i</math>: The task of fighting a round.<br />
* <math>\{T_i\}_{i=1}^{K}</math>: A sequence of rounds against the same opponent. Note that the opponent may update their policy between rounds but the anatomy of both wrestlers will be constant across rounds.<br />
* A Markov Decision Process composed of<br />
** A horizon <math>H = 500*n</math> where <math>n</math> is the number of episodes per round.<br />
** Observations <math> x_t </math> containing information about the state of the agent and its opponent.<br />
** Actions <math> a_t </math> containing information about the torques to apply to the agent's legs.<br />
** Rewards <math> R_t </math> given to the agent based on its wrestling performance. <math>R_{500*n} = </math> +2000 if the agent wins the episode, -2000 if it loses, and -1000 if it draws.<br />
<br />
Note that the above reward setup is quite sparse. To encourage fast training, additional rewards are therefore given at every time step:<br />
* For staying close to the center of the ring.<br />
* For exerting force on the opponent's body.<br />
* For moving towards the opponent.<br />
* For the distance of the opponent to the center of the ring.<br />
<br />
This makes sense intuitively as these are reasonable goals for agents to explore when they are learning to wrestle.<br />
<br />
== Training ==<br />
The same combinations of policy networks and adaptation methods that were used in the locomotion example are trained and tested here. A family of non-adaptive policies is first trained via self-play and saved at all stages. Self-play simply means that the two agents in the training environment use the same policy. All policy versions are saved so that agents of various skill levels can be sampled when training meta-learners. The weights of the different insects were calibrated such that the test win rate between two insects of differing anatomy, which have been trained for the same number of epochs via self-play, is close to 50%.<br />
<br />
[[File:weight_cal.png | 800px]]<br />
<br />
We can see in the above figure that the weight of the spider had to be increased by almost four times in order for the agents to be evenly matched.<br />
<br />
[[File:robosumo_results.png | 800px]]<br />
<br />
The above figure shows testing results for various adaptation strategies. The agent and opponent both start with the self-trained policies. The opponent uses all of its testing experience to continue training. The agent uses only the last 75 episodes to adapt its policy network. This shows that meta-learners need only a limited amount of experience in order to hold their own against a constantly improving opponent.<br />
<br />
= Future Work =<br />
The authors mention that their approach will likely not work well with sparse rewards. This is because the meta-updates, which use policy gradients, are very dependent on the reward signal. They mention that this is an issue they would like to address in the future. A potential solution they have outlined for this is to introduce auxiliary dense rewards which could enable meta-learning.<br />
<br />
= Sources =<br />
# Chelsea Finn, Pieter Abbeel, Sergey Levine. "Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks." arXiv preprint arXiv:1703.03400v3 (2017).<br />
# Richard S Sutton, Anna Koop, and David Silver. On the role of tracking in stationary environments. In Proceedings of the 24th international conference on Machine learning, pp. 871–878. ACM, 2007.</div>X249wanghttp://wiki.math.uwaterloo.ca/statwiki/index.php?title=Continuous_Adaptation_via_Meta-Learning_in_Nonstationary_and_Competitive_Environments&diff=34598Continuous Adaptation via Meta-Learning in Nonstationary and Competitive Environments2018-03-18T03:08:51Z<p>X249wang: /* Editorial */</p>
<hr />
<div>= Introduction =<br />
<br />
Typically, the basic goal of machine learning is to train a model to perform a task. In meta-learning, the goal is to train a model to perform the task of training a model to perform a task. Hence in this case the term "meta-learning" has the exact meaning you would expect; the word "meta" has the precise function of introducing a layer of abstraction.<br />
<br />
The meta-learning task can be made more concrete by a simple example. Consider the CIFAR-100 classification task that we used for our data competition. We can alter this task from being a 100-class classification problem to a collection of 100 binary classification problems. The goal of meta-learning here is to design and train a single binary classifier that will perform well on a randomly sampled task given a limited amount of training data for that specific task. In other words, we would like to train a model to perform the following procedure:<br />
<br />
# A task is sampled. The task is "Is X a dog?".<br />
# A small set of labeled training data is provided to the model. The labels represent whether or not the image is a picture of a dog.<br />
# The model uses the training data to adjust itself to the specific task of checking whether or not an image is a picture of a dog.<br />
<br />
This example also highlights the intuition that the skill of sight is distinct and separable from the skill of knowing what a dog looks like.<br />
<br />
In this paper, a probabilistic framework for meta-learning is derived and then applied to tasks involving simulated robotic spiders. The framework generalizes the typical machine learning setup using Markov Decision Processes. The paper focuses on multi-agent, non-stationary environments, which require reinforcement learning (RL) agents to adapt continuously. Non-stationarity breaks the standard assumptions and forces agents to adapt both at training and at execution time in order to earn more reward. The approach is therefore to break the non-stationary problem into a sequence of stationary tasks and treat it as a multi-task learning problem.<br />
<br />
[[File:paper19_fig1.png|600px|frame|none|alt=Alt text| '''Figure 1'''. a) Illustrates a probabilistic model for Model Agnostic Meta-Learning (MAML) in a multi-task RL setting, where the tasks <math>T</math>, policies <math>\pi</math>, and trajectories <math>\tau</math> are all random variables with dependencies encoded in the edges of a given graph. b) The proposed extension to MAML by the authors suitable for continuous adaptation to a task changing dynamically due to non-stationarity of the environment. The distribution of tasks is represented by a Markov chain, whereby policies from a previous step are used to construct a new policy for the current step. c) The computation graph for the meta-update from <math>\phi_i</math> to <math>\phi_{i+1}</math>. Boxes represent replicas of the policy graphs with the specified parameters. The model is optimized using truncated backpropagation through time starting from <math>L_{T_{i+1}}</math>.]]<br />
<br />
= Model Agnostic Meta-Learning =<br />
<br />
An initial framework for meta-learning is given in "Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks" (Finn et al, 2017):<br />
<br />
"In our approach, the parameters of<br />
the model are explicitly trained such that a small<br />
number of gradient steps with a small amount<br />
of training data from a new task will produce<br />
good generalization performance on that task" (Finn et al, 2017).<br />
<br />
[[File:MAML.png | 500px]]<br />
<br />
In this training algorithm, the parameter vector <math>\theta</math> belonging to the model <math>f_{\theta}</math> is trained such that the meta-objective function <math>\mathcal{L} (\theta) = \sum_{\tau_i \sim P(\tau)} \mathcal{L}_{\tau_i} (f_{\theta_i' }) </math> is minimized. The sum in the objective function is over a sampled batch of training tasks. <math>\mathcal{L}_{\tau_i} (f_{\theta_i'})</math> is the training loss function corresponding to the <math>i^{th}</math> task in the batch evaluated at the model <math>f_{\theta_i'}</math>. The parameter vector <math>\theta_i'</math> is obtained by updating the general parameter <math>\theta</math> using the loss function <math>\mathcal{L}_{\tau_i}</math> and set of K training examples specific to the <math>i^{th}</math> task. Note that in alternate versions of this algorithm, additional testing sets are sampled from <math>\tau_i</math> and used to update <math>\theta</math> using testing loss functions instead of training loss functions.<br />
<br />
One important difference between this algorithm and more typical fine-tuning methods is that <math>\theta</math> is explicitly trained to be easily adjusted to perform well on different tasks, rather than trained to perform well on one specific task and then fine-tuned as the environment changes (Sutton et al., 2007).<br />
<br />
= Probabilistic Framework for Meta-Learning =<br />
<br />
This paper puts the meta-learning problem into a Markov Decision Process (MDP) framework common to RL; see Figure 1a. Instead of training examples <math>\{(x, y)\}</math>, we have trajectories <math>\tau = (x_0, a_1, x_1, R_1, \ldots, a_H, x_H, R_H)</math>. A trajectory is a sequence of states/observations <math>x_t</math>, actions <math>a_t</math> and rewards <math>R_t</math> that is sampled from a task <math> T </math> according to a policy <math>\pi_{\theta}</math>. Included with the task is a method for assigning loss values to trajectories, <math>L_T(\tau)</math>, which is typically the negative cumulative reward. A policy is a deterministic function that takes in a state and returns an action. Our goal here is to train a policy <math>\pi_{\theta}</math> with parameter vector <math>\theta</math>. This is analogous to training a function <math>f_{\theta}</math> that assigns labels <math>y</math> to feature vectors <math>x</math>. More precisely, we have the following definitions:<br />
<br />
* <math>T :=(L_T, P_T(x), P_T(x_t | x_{t-1}, a_{t-1}), H )</math> (A Task)<br />
* <math>D(T)</math> : A distribution over tasks.<br />
* <math>L_T</math>: A loss function for the task T that assigns numeric loss values to trajectories.<br />
* <math>P_T(x), P_T(x_t | x_{t-1}, a_{t-1})</math>: Probability measures specifying the Markovian dynamics of the observations <math>x_t</math>.<br />
* <math>H</math>: The horizon of the MDP. This is a fixed natural number specifying the length of the task's trajectories.<br />
<br />
The paper goes further and defines Markov dynamics for sequences of tasks, as shown in Figure 1b. The policy that we would like to meta-learn, <math>\pi_{\theta}</math>, after being exposed to a sample of K trajectories <math>\tau_\theta^{1:K}</math> from the task <math>T_i</math>, should produce a new policy <math>\pi_{\phi}</math> that will perform well on the next task <math>T_{i+1}</math>. Thus we seek to minimize the following expectation:<br />
<br />
<math>\mathrm{E}_{P(T_0), P(T_{i+1} | T_i)}\bigg(\sum_{i=1}^{l} \mathcal{L}_{T_i, T_{i+1}}(\theta)\bigg)</math>, <br />
<br />
where <math>\mathcal{L}_{T_i, T_{i + 1}}(\theta) := \mathrm{E}_{\tau_{i, \theta}^{1:K} } \bigg( \mathrm{E}_{\tau_{i+1, \phi}}\Big( L_{T_{i+1}}(\tau_{i+1, \phi} | \tau_{i, \theta}^{1:K}, \theta) \Big) \bigg) </math> and <math>l</math> is the number of tasks.<br />
<br />
The meta-policy <math>\pi_{\theta}</math> is trained and then adapted at test time using the following procedures. The computational graph is given in Figure 1c.<br />
<br />
[[File:MAML2.png | 800px]]<br />
<br />
The mathematics of calculating loss gradients is omitted.<br />
<br />
= Training Spiders to Run with Dynamic Handicaps (Robotic Locomotion in Non-Stationary Environments) =<br />
<br />
The authors used the MuJoCo physics simulator to create a simulated environment where robotic spiders with 6 legs are faced with the task of running due east as quickly as possible. The robotic spider observes the location and velocity of its body, and the angles and velocities of its legs. It interacts with the environment by exerting torque on the joints of its legs. Each leg has two joints: the joint closer to the body rotates horizontally, while the joint farther from the body rotates vertically. The environment is made non-stationary by gradually paralyzing two legs of the spider across training and testing episodes.<br />
Putting this example into the above probabilistic framework yields:<br />
<br />
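* <math>T_i</math>: The task of walking east with the torques of two legs multiplied by <math> 1 - (i-1)/6 </math>, so that the selected legs go from full strength at <math>i=1</math> to fully paralyzed at <math>i=7</math>.<br />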
* <math>\{T_i\}_{i=1}^{7}</math>: A sequence of tasks with the same two legs handicapped in each task. Note that there are <math>\binom{6}{2} = 15</math> ways to choose the two legs, resulting in 15 sequences of tasks; 12 are used for training and 3 for testing.<br />
* A Markov Decision Process composed of<br />
** Observations <math> x_t </math> containing information about the state of the spider.<br />
** Actions <math> a_t </math> containing information about the torques to apply to the spider's legs.<br />
** Rewards <math> R_t </math> corresponding to the speed at which the spider is moving east.<br />
<br />
Three differently structured policy neural networks are trained in this setup using both meta-learning and three different previously developed adaptation methods.<br />
<br />
At testing time, the spiders following meta-learned policies initially perform worse than the spiders using non-adaptive policies. However, by the third episode (<math> i=3 </math>), the meta-learners perform on par, and by the sixth episode, when the selected legs are mostly immobile, the meta-learners significantly outperform them. These results can be seen in the graphs below.<br />
<br />
[[File:locomotion_results.png | 800px]]<br />
<br />
= Training Spiders to Fight Each Other (Adversarial Meta-Learning) =<br />
<br />
The authors created an adversarial environment called RoboSumo where pairs of agents with 4 legs (named Ants), 6 legs (named Bugs), or 8 legs (named Spiders) sumo wrestle. The agents observe the location and velocity of their own bodies and of their opponent's body, the angles and velocities of their legs, and the forces being exerted on them by their opponent (the equivalent of a tactile sense). The game is organized into episodes and rounds. Episodes are single wrestling matches with 500 time steps and win/lose/draw outcomes. An agent wins an episode by pushing its opponent out of the ring or by making its opponent's body touch the ground; an episode is a draw when neither happens within 500 time steps. Rounds are batches of episodes, and each round is a win, loss, or draw decided by the majority of episodes won. K rounds are fought, and both agents may update their policies between rounds. The agent that wins the majority of rounds is deemed the winner of the game.<br />
<br />
== Setup ==<br />
Similar to the Robotic locomotion example, this game can be phrased in terms of the RL MDP framework.<br />
<br />
* <math>T_i</math>: The task of fighting a round.<br />
* <math>\{T_i\}_{i=1}^{K}</math>: A sequence of rounds against the same opponent. Note that the opponent may update their policy between rounds but the anatomy of both wrestlers will be constant across rounds.<br />
* A Markov Decision Process composed of<br />
** A horizon <math>H = 500*n</math> where <math>n</math> is the number of episodes per round.<br />
** Observations <math> x_t </math> containing information about the state of the agent and its opponent.<br />
** Actions <math> a_t </math> containing information about the torques to apply to the agent's legs.<br />
** Rewards <math> R_t </math> given to the agent based on its wrestling performance. <math>R_{500*n} = </math> +2000 if the agent wins the episode, -2000 if it loses, and -1000 if it draws.<br />
<br />
Note that the above reward setup is quite sparse. To encourage fast training, additional rewards are therefore given at every time step:<br />
* For staying close to the center of the ring.<br />
* For exerting force on the opponent's body.<br />
* For moving towards the opponent.<br />
* For the distance of the opponent to the center of the ring.<br />
<br />
This makes sense intuitively as these are reasonable goals for agents to explore when they are learning to wrestle.<br />
<br />
== Training ==<br />
The same combinations of policy networks and adaptation methods that were used in the locomotion example are trained and tested here. A family of non-adaptive policies is first trained via self-play and saved at all stages. Self-play simply means that the two agents in the training environment use the same policy. All policy versions are saved so that agents of various skill levels can be sampled when training meta-learners. The weights of the different insects were calibrated such that the test win rate between two insects of differing anatomy, which have been trained for the same number of epochs via self-play, is close to 50%.<br />
<br />
[[File:weight_cal.png | 800px]]<br />
<br />
We can see in the above figure that the weight of the spider had to be increased by almost four times in order for the agents to be evenly matched.<br />
<br />
[[File:robosumo_results.png | 800px]]<br />
<br />
The above figure shows testing results for various adaptation strategies. The agent and opponent both start with the self-trained policies. The opponent uses all of its testing experience to continue training. The agent uses only the last 75 episodes to adapt its policy network. This shows that meta-learners need only a limited amount of experience in order to hold their own against a constantly improving opponent.<br />
<br />
= Future Work =<br />
The authors mention that their approach will likely not work well with sparse rewards. This is because the meta-updates, which use policy gradients, are very dependent on the reward signal. They mention that this is an issue they would like to address in the future. A potential solution they have outlined for this is to introduce auxiliary dense rewards which could enable meta-learning.<br />
<br />
= Sources =<br />
# Chelsea Finn, Pieter Abbeel, Sergey Levine. "Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks." arXiv preprint arXiv:1703.03400v3 (2017).<br />
# Richard S Sutton, Anna Koop, and David Silver. On the role of tracking in stationary environments. In Proceedings of the 24th international conference on Machine learning, pp. 871–878. ACM, 2007.</div>X249wang