Wasserstein Auto-Encoders

From statwiki

Introduction

Recent years have seen a convergence of two previously distinct approaches: representation learning from high-dimensional data, and unsupervised generative modeling. In the field that formed at their intersection, Variational Auto-Encoders (VAEs) and Generative Adversarial Networks (GANs) have become well established. VAEs are theoretically elegant but tend to generate blurry samples when applied to natural images. GANs, on the other hand, produce samples of better visual quality, but they come without an encoder, are harder to train, and suffer from mode collapse, where the trained model fails to capture all the variability of the true data distribution. Recent work has studied encoder-decoder GANs, in which an encoder is trained in parallel with the generator, on the intuition that this lets the GAN learn a meaningful mapping between the compressed representation and the original image; however, these models also suffer from mode collapse and perform comparably to vanilla GANs. There has thus been a push to find the best way to combine the two approaches, but a principled unifying framework has yet to be discovered.

This work proposes a new family of regularized auto-encoders called Wasserstein Auto-Encoders (WAEs). The method offers a novel theoretical view of the auto-encoder objective from the point of view of optimal transport (OT). This formulation leads the authors to examine adversarial and maximum mean discrepancy (MMD) based regularizers for matching a prior and the distribution of encoded data points in the latent space. An empirical evaluation on the MNIST and CelebA datasets shows that WAE generates samples of better quality than VAE while preserving training stability, the encoder-decoder structure and a well-behaved latent manifold.

The main contribution of the paper is to provide theoretical foundations for using an optimal transport cost as the auto-encoder objective, while blending auto-encoders and GANs in a principled way. It also explores, theoretically and experimentally, the relationships between WAEs, VAEs and adversarial auto-encoders.

Proposed Approach

Theory of Optimal Transport and Wasserstein Distance

The Wasserstein distance is a measure of the distance between two probability distributions. It is also called the Earth Mover's (EM) distance because it can informally be interpreted as the minimum cost of moving piles of dirt shaped like one probability distribution so that they take the shape of the other. The cost is quantified as the amount of dirt moved times the distance it is moved. A simple case with a discrete probability domain is presented below.


Step-by-step plan of moving dirt between piles in P and Q to make them match (W = 5).
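For piles placed on equally spaced 1-D bins one unit apart, the cheapest plan simply carries the running surplus of dirt from each bin to the next, so the EM distance is the sum of the absolute cumulative differences. The sketch below assumes that setting (equal total mass, unit spacing); the pile sizes are illustrative values chosen here, not read off the figure, and `earth_mover_1d` is a hypothetical name.

```python
from itertools import accumulate

def earth_mover_1d(p, q):
    """EM distance between two piles of equal total mass on unit-spaced 1-D bins.

    The running surplus after bin i is the dirt that must cross the gap
    between bin i and bin i + 1; moving one unit across one gap costs 1.
    """
    if len(p) != len(q):
        raise ValueError("p and q must have the same number of bins")
    if abs(sum(p) - sum(q)) > 1e-9:
        raise ValueError("p and q must have the same total mass")
    surplus = accumulate(pi - qi for pi, qi in zip(p, q))
    return sum(abs(s) for s in surplus)

# Illustrative pile sizes (assumed, not taken from the figure).
P = [3, 2, 1, 4]
Q = [1, 2, 4, 3]
print(earth_mover_1d(P, Q))  # 5
```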


For continuous distributions, the EM distance is the minimum cost over all transport plans: \begin{align} \small W(p_r, p_g) = \underset{\gamma\in\Pi(p_r, p_g)} {\inf}\pmb{\mathbb{E}}_{(x,y)\sim\gamma}[\parallel x-y\parallel] \end{align}

where [math]\displaystyle{ \Pi(p_r, p_g) }[/math] is the set of all joint probability distributions with marginals [math]\displaystyle{ p_r }[/math] and [math]\displaystyle{ p_g }[/math]. The joint distribution [math]\displaystyle{ \gamma }[/math] is called a transport plan: [math]\displaystyle{ \gamma(x, y) }[/math] can be read as the amount of probability mass moved from [math]\displaystyle{ x }[/math] to [math]\displaystyle{ y }[/math]. This intuition is supported by its two marginal constraints, given by the following equations.

\begin{align} \int\gamma(x, y)dx = p_g(y) \end{align} which means that the total amount of dirt moved to point [math]\displaystyle{ y }[/math] is [math]\displaystyle{ p_g(y) }[/math]. Similarly, we have:

\begin{align} \int\gamma(x, y)dy = p_r(x) \end{align} which means that the total amount of dirt moved out of point [math]\displaystyle{ x }[/math] is [math]\displaystyle{ p_r(x) }[/math].

The Wasserstein distance, i.e. the cost of optimal transport (OT), induces a much weaker topology than f-divergences, which informally means that it is easier for a sequence of distributions to converge under it. This is particularly important in applications where data is supported on low-dimensional manifolds in the input space. There, the model and data distributions often have (nearly) disjoint supports, so stronger notions of distance such as the KL divergence saturate or become infinite and provide no useful gradients for training. OT, in contrast, behaves much more gracefully: for two point masses a distance [math]\displaystyle{ \theta }[/math] apart, the Wasserstein distance is [math]\displaystyle{ |\theta| }[/math], while the Jensen-Shannon divergence is stuck at [math]\displaystyle{ \log 2 }[/math] for every [math]\displaystyle{ \theta \neq 0 }[/math]. Arjovsky et al. (2017) show that, under mild assumptions on the generator, the Wasserstein distance is continuous everywhere and differentiable almost everywhere. They also argue that its magnitude is a meaningful measure of how far apart two distributions are: it shrinks as the distributions move closer together, and vice versa.

Problem Formulation and Notation

In this paper, calligraphic letters, e.g. [math]\displaystyle{ \small {\mathcal{X}} }[/math], are used for sets, capital letters, e.g. [math]\displaystyle{ \small X }[/math], for random variables and lower case letters, e.g. [math]\displaystyle{ \small x }[/math], for their values. Probability distributions are denoted with capital letters, e.g. [math]\displaystyle{ \small P(X) }[/math], and the corresponding densities with lower case letters, e.g. [math]\displaystyle{ \small p(x) }[/math].

This work aims to minimize the OT cost [math]\displaystyle{ \small W_c(P_X, P_G) }[/math] between the true (but unknown) data distribution [math]\displaystyle{ \small P_X }[/math] and a latent variable model [math]\displaystyle{ \small P_G }[/math], specified by a prior distribution [math]\displaystyle{ \small P_Z }[/math] over latent codes [math]\displaystyle{ \small Z \in \mathcal{Z} }[/math] and a generative model [math]\displaystyle{ \small P_G(X|Z) }[/math] of the data points [math]\displaystyle{ \small X \in \mathcal{X} }[/math] given [math]\displaystyle{ \small Z }[/math].

Kantorovich's formulation of the OT problem is given by: \begin{align} \small W_c(P_X, P_G) := \underset{\Gamma\in {\mathcal{P}}(X \sim P_X, Y \sim P_G)}{\inf} {\pmb{\mathbb{E}}_{(X,Y)\sim\Gamma}[c(X,Y)]} \end{align} where [math]\displaystyle{ \small c(x,y) }[/math] is any measurable cost function and [math]\displaystyle{ \small {\mathcal{P}(X \sim P_X,Y \sim P_G)} }[/math] is the set of all joint distributions of [math]\displaystyle{ \small (X,Y) }[/math] with marginals [math]\displaystyle{ \small P_X }[/math] and [math]\displaystyle{ \small P_G }[/math]. When [math]\displaystyle{ \small c(x,y)=d^p(x,y) }[/math] for a metric [math]\displaystyle{ \small d }[/math], the [math]\displaystyle{ \small p^{th} }[/math] root of [math]\displaystyle{ \small W_c }[/math] is the p-Wasserstein distance [math]\displaystyle{ \small W_p }[/math]. In the case [math]\displaystyle{ \small c(x,y)=d(x,y) }[/math], i.e. [math]\displaystyle{ \small p=1 }[/math], the following Kantorovich-Rubinstein duality holds: \begin{align} \small W_1(P_X, P_G) = \underset{f \in {\mathcal{F_L}}} {\sup} \left( {\pmb{\mathbb{E}}_{X \sim P_X}[f(X)]} -{\pmb{\mathbb{E}}_{Y \sim P_G}[f(Y)]} \right) \end{align} where [math]\displaystyle{ \small {\mathcal{F_L}} }[/math] is the class of all bounded 1-Lipschitz functions on [math]\displaystyle{ \small (\mathcal{X}, d) }[/math]. An intuitive explanation of how the Kantorovich-Rubinstein duality is applied in this setting is given in the blog post "From GAN to WGAN" (Sources, item 5).
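To make the primal (Kantorovich) problem concrete: when both distributions are replaced by n equally weighted samples, an optimal coupling can be taken to be a one-to-one matching of the points (by the Birkhoff-von Neumann theorem), so for tiny n the infimum can be found by brute force. In one dimension, and for costs such as |x - y| or (x - y)^2, pairing the sorted samples is optimal. The function names and sample values below are illustrative assumptions; brute force is only practical for a handful of points.

```python
from itertools import permutations

def ot_cost_bruteforce(xs, ys, cost):
    """Kantorovich OT cost between two uniform empirical measures with n points each.

    With equal weights an optimal plan can be chosen as a permutation
    (Birkhoff-von Neumann), so we search all n! matchings.
    """
    n = len(xs)
    assert n == len(ys), "both samples must have the same size"
    best = min(sum(cost(xs[i], ys[j]) for i, j in enumerate(perm))
               for perm in permutations(range(n)))
    return best / n

def ot_cost_sorted_1d(xs, ys, cost):
    """1-D shortcut: pairing sorted samples is optimal for convex costs of x - y."""
    return sum(cost(a, b) for a, b in zip(sorted(xs), sorted(ys))) / len(xs)

def abs_cost(x, y):
    return abs(x - y)          # c = d: the result is W_1 of the two samples

def sq_cost(x, y):
    return (x - y) ** 2        # c = d^2: the result is W_2 squared

xs = [0.0, 1.0, 2.0]
ys = [5.0, 3.0, 4.0]
print(ot_cost_bruteforce(xs, ys, abs_cost), ot_cost_sorted_1d(xs, ys, abs_cost))  # 3.0 3.0
```

With c = d the value is W_1 of the two samples; with c = d^2 it is the squared W_2, whose square root is W_2.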

Wasserstein Auto-Encoders

The proposed method focuses on latent variable models [math]\displaystyle{ \small P_G }[/math] defined by a two-step procedure: first a code [math]\displaystyle{ \small Z }[/math] is sampled from a fixed prior distribution [math]\displaystyle{ \small P_Z }[/math] on a latent space [math]\displaystyle{ \small {\mathcal{Z}} }[/math], and then [math]\displaystyle{ \small Z }[/math] is mapped to the image [math]\displaystyle{ \small X \in {\mathcal{X}} }[/math] with a (possibly random) transformation. This results in a density of the form \begin{align} \small p_G(x) := \int_{{\mathcal{Z}}} p_G(x|z)p_Z(z)dz, \forall x\in{\mathcal{X}} \end{align} assuming all the densities are properly defined. If attention is restricted to generative models that deterministically map [math]\displaystyle{ \small Z }[/math] to [math]\displaystyle{ \small X = G(Z) }[/math], the OT cost takes a much simpler form, stated below as Theorem 1.

Theorem 1 For [math]\displaystyle{ \small P_G }[/math] as defined above with a deterministic [math]\displaystyle{ \small P_G(X|Z) }[/math] and any function [math]\displaystyle{ \small G:{\mathcal{Z}} \rightarrow {\mathcal{X}} }[/math], \begin{align} \small \underset{\Gamma\in {\mathcal{P}}(X \sim P_X, Y \sim P_G)}{\inf} {\pmb{\mathbb{E}}_{(X,Y)\sim\Gamma}[c(X,Y)]} = \underset{Q : Q_Z=P_Z}{\inf} {{\pmb{\mathbb{E}}_{P_X}}{\pmb{\mathbb{E}}_{Q(Z|X)}}[c(X,G(Z))]} \end{align} where [math]\displaystyle{ \small Q_Z }[/math] is the marginal distribution of [math]\displaystyle{ \small Z }[/math] when [math]\displaystyle{ \small X \sim P_X }[/math] and [math]\displaystyle{ \small Z \sim Q(Z|X) }[/math]. This means that instead of finding a coupling [math]\displaystyle{ \small \Gamma }[/math] between two random variables living in the space [math]\displaystyle{ \small {\mathcal{X}} }[/math], one distributed according to [math]\displaystyle{ \small P_X }[/math] and the other according to [math]\displaystyle{ \small P_G }[/math], it is sufficient to find a conditional distribution [math]\displaystyle{ \small Q(Z|X) }[/math] whose [math]\displaystyle{ \small Z }[/math]-marginal [math]\displaystyle{ \small Q_Z(Z) := {\pmb{\mathbb{E}}_{X \sim P_X}[Q(Z|X)]} }[/math] is identical to the prior distribution [math]\displaystyle{ \small P_Z }[/math]. To make Theorem 1 amenable to numerical optimization, the constraint [math]\displaystyle{ \small Q_Z = P_Z }[/math] is relaxed and replaced by a penalty term added to the objective, leading to the WAE objective function:

\begin{align} \small D_{WAE}(P_X, P_G):= \underset{Q(Z|X) \in \mathcal{Q}}{\inf} {{\pmb{\mathbb{E}}_{P_X}}{\pmb{\mathbb{E}}_{Q(Z|X)}}[c(X,G(Z))]} + {\lambda} {{\mathcal{D}}_Z(Q_Z,P_Z)} \end{align} where [math]\displaystyle{ \small \mathcal{Q} }[/math] is any nonparametric set of probabilistic encoders, [math]\displaystyle{ \small {\mathcal{D}}_Z }[/math] is an arbitrary divergence between [math]\displaystyle{ \small Q_Z }[/math] and [math]\displaystyle{ \small P_Z }[/math], and [math]\displaystyle{ \small \lambda > 0 }[/math] is a hyperparameter. The authors propose two different penalties [math]\displaystyle{ \small {\mathcal{D}}_Z(Q_Z,P_Z) }[/math], based on adversarial training (GANs) and on the maximum mean discrepancy (MMD). They also note that earlier work has tackled the dual formulation of the OT problem numerically, either by clipping the weights of the critic network to enforce the Lipschitz condition (as in W-GAN, Arjovsky et al., 2017) or by adding the gradient penalty [math]\displaystyle{ \small \lambda \mathbb{E}(\parallel \nabla f(X) \parallel - 1)^2 }[/math] to the objective.
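A minimal sketch of how a minibatch estimate of the relaxed objective is assembled, assuming a deterministic encoder, the squared Euclidean cost, and a divergence supplied as a function of two batches of codes. The helper `moment_gap` is a hypothetical, deliberately crude stand-in for the divergence that only compares means and covariances (so it is not a true divergence); it is not one of the paper's penalties.

```python
import numpy as np

def squared_l2_cost(x, x_rec):
    """c(x, y) = ||x - y||_2^2, evaluated row by row for a batch."""
    return np.sum((x - x_rec) ** 2, axis=1)

def wae_objective(x, encoder, decoder, z_prior, divergence, lam):
    """Minibatch estimate: mean reconstruction cost + lam * D_Z(Q_Z, P_Z)."""
    z_enc = encoder(x)                       # one code per data point: a sample of Q_Z
    recon = np.mean(squared_l2_cost(x, decoder(z_enc)))
    penalty = divergence(z_enc, z_prior)     # compares encoded codes with prior draws
    return recon + lam * penalty

def moment_gap(z_q, z_p):
    """Hypothetical crude stand-in for D_Z: squared gaps in means and covariances."""
    mean_gap = np.sum((z_q.mean(axis=0) - z_p.mean(axis=0)) ** 2)
    cov_gap = np.sum((np.cov(z_q, rowvar=False) - np.cov(z_p, rowvar=False)) ** 2)
    return float(mean_gap + cov_gap)

def identity(v):
    """Placeholder encoder/decoder so the example runs."""
    return v

rng = np.random.default_rng(0)
x = rng.standard_normal((256, 2))
z_prior = rng.standard_normal((256, 2))
loss = wae_objective(x, identity, identity, z_prior, moment_gap, lam=10.0)
print(loss)
```

With identity maps the reconstruction term vanishes, so the value is simply λ times the penalty; in general λ sets the trade-off between reconstructing x and matching Q_Z to P_Z.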

WAE-GAN: GAN-based

The first option is to choose [math]\displaystyle{ \small {\mathcal{D}}_Z(Q_Z,P_Z) = D_{JS}(Q_Z,P_Z) }[/math], where [math]\displaystyle{ \small D_{JS} }[/math] is the Jensen-Shannon divergence, and to estimate it with adversarial training. Specifically, a discriminator is introduced in the latent space [math]\displaystyle{ \small {\mathcal{Z}} }[/math] that tries to separate "true" points sampled from [math]\displaystyle{ \small P_Z }[/math] from "fake" ones sampled from [math]\displaystyle{ \small Q_Z }[/math]. This results in Algorithm 1. Notably, the min-max problem is moved from the input (pixel) space to the latent space.
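The two losses of one WAE-GAN step can be sketched as follows. This is based on our reading of Algorithm 1 and assumes that the discriminator outputs probabilities, that the cost is squared Euclidean, and that inputs are plain NumPy arrays; computing gradients, which would need an autodiff framework, is omitted. The discriminator ascends λ·mean[log D(z) + log(1 - D(z̃))], while the encoder and decoder descend the reconstruction cost minus λ·mean[log D(z̃)]. `wae_gan_losses` is a hypothetical name.

```python
import numpy as np

def wae_gan_losses(x, x_rec, d_prior, d_enc, lam, eps=1e-12):
    """Loss values for one WAE-GAN step (squared Euclidean cost assumed).

    x, x_rec : (n, dim) data batch and its reconstructions G(z_tilde)
    d_prior  : (n,) discriminator probabilities D(z_i) for z_i ~ P_Z
    d_enc    : (n,) discriminator probabilities D(z_tilde_i) for z_tilde_i ~ Q(Z|x_i)
    Returns (discriminator_loss, autoencoder_loss), both to be minimized.
    """
    d_prior = np.clip(d_prior, eps, 1.0 - eps)   # avoid log(0)
    d_enc = np.clip(d_enc, eps, 1.0 - eps)
    # The discriminator ascends this objective, so its loss is the negative.
    disc_objective = lam * np.mean(np.log(d_prior) + np.log(1.0 - d_enc))
    # Encoder and decoder: reconstruction cost minus lam * log D(z_tilde).
    recon = np.mean(np.sum((x - x_rec) ** 2, axis=1))
    ae_loss = recon - lam * np.mean(np.log(d_enc))
    return -disc_objective, ae_loss

x = np.zeros((4, 3))
half = np.full(4, 0.5)            # a discriminator that cannot tell codes apart
disc_loss, ae_loss = wae_gan_losses(x, x, half, half, lam=1.0)
print(disc_loss, ae_loss)         # 2 log 2 and log 2
```

The encoder's adversarial term rewards encoded codes that the discriminator mistakes for prior samples, which is how Q_Z is pushed towards P_Z without any min-max game in pixel space.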


WAE-MMD: MMD-based

For a positive definite kernel [math]\displaystyle{ \small k: {\mathcal{Z}} \times {\mathcal{Z}} \rightarrow \mathbb{R} }[/math], the following expression is called the maximum mean discrepancy: \begin{align} \small {MMD}_k(P_Z,Q_Z) = \parallel \int_{{\mathcal{Z}}} k(z,\cdot)dP_Z(z) - \int_{{\mathcal{Z}}} k(z,\cdot)dQ_Z(z) \parallel_{\mathcal{H}_k}, \end{align}

where [math]\displaystyle{ \mathcal{H}_k }[/math] is the reproducing kernel Hilbert space (RKHS) of real-valued functions mapping [math]\displaystyle{ \mathcal{Z} }[/math] to [math]\displaystyle{ \mathbb{R} }[/math]. If [math]\displaystyle{ k }[/math] is characteristic, [math]\displaystyle{ MMD_k }[/math] is a metric on probability distributions and can therefore be used as a divergence measure. The authors propose to use [math]\displaystyle{ \small {\mathcal{D}}_Z(Q_Z,P_Z) = MMD_k(P_Z,Q_Z) }[/math], which leads to Algorithm 2.
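A sketch of a sample-based estimate, assuming equally weighted samples: the standard unbiased estimator of the squared MMD averages kernel values within each sample (excluding the i = j terms) and subtracts twice the average cross-sample value. That Algorithm 2 uses exactly this form is our reading and should be treated as an assumption. The inverse multiquadratics kernel used in the experiments is included; `imq_kernel` and `mmd_unbiased` are hypothetical names.

```python
from functools import partial

import numpy as np

def imq_kernel(a, b, c):
    """Inverse multiquadratics kernel C / (C + ||a_i - b_j||^2) for all pairs (i, j)."""
    sq_dists = np.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
    return c / (c + sq_dists)

def mmd_unbiased(z_prior, z_enc, kernel):
    """Unbiased estimate of MMD_k^2 between a prior sample and an encoded sample.

    Both inputs have shape (num_points, d_z) with at least 2 points each.
    """
    n, m = z_prior.shape[0], z_enc.shape[0]
    k_pp = kernel(z_prior, z_prior)
    k_qq = kernel(z_enc, z_enc)
    k_pq = kernel(z_prior, z_enc)
    # Within-sample averages exclude the diagonal terms k(z_i, z_i).
    term_pp = (k_pp.sum() - np.trace(k_pp)) / (n * (n - 1))
    term_qq = (k_qq.sum() - np.trace(k_qq)) / (m * (m - 1))
    # The cross-sample average uses all n * m pairs.
    term_pq = 2.0 * k_pq.sum() / (n * m)
    return term_pp + term_qq - term_pq

# Tiny worked case: identical two-point samples, C = 1.
tiny = np.array([[0.0], [1.0]])
print(mmd_unbiased(tiny, tiny, partial(imq_kernel, c=1.0)))  # -0.5

# Prior sample vs. a same-distribution sample and vs. a shifted one (d_z = 2).
rng = np.random.default_rng(0)
z_p = rng.standard_normal((200, 2))
z_same = rng.standard_normal((200, 2))
kernel = partial(imq_kernel, c=4.0)
print(mmd_unbiased(z_p, z_same, kernel), mmd_unbiased(z_p, z_same + 3.0, kernel))
```

Because the estimator is unbiased rather than non-negative, it can fall below zero when the two samples are close, as in the two-point case; for well-separated samples it is clearly positive.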


Comparison with Related Work

Auto-Encoders, VAEs and WAEs

Classical unregularized auto-encoders minimized only the reconstruction cost. As a result, training points ended up chaotically scattered across the latent space, with holes in between where the decoder had never been trained; such models were hard to sample from and did not provide useful representations. VAEs circumvent this problem by maximizing a variational lower bound comprising a reconstruction term and a KL-divergence term that measures, for each training example [math]\displaystyle{ \small x }[/math], how far its encoded distribution [math]\displaystyle{ \small Q(Z|X=x) }[/math] is from the prior [math]\displaystyle{ \small P_Z }[/math]. However, this does not guarantee that the overall encoded distribution [math]\displaystyle{ \small {{\pmb{\mathbb{E}}_{P_X}}}[Q(Z|X)] }[/math] matches [math]\displaystyle{ \small P_Z }[/math]. WAE, in contrast, directly penalizes the mismatch between this aggregate distribution [math]\displaystyle{ \small Q_Z }[/math] and [math]\displaystyle{ \small P_Z }[/math], as a direct consequence of its objective derived from Theorem 1; the difference is illustrated in the figure below. Interestingly, this also allows WAE to use deterministic encoder-decoder pairs.


WAE and VAE regularization
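The gap between the two regularizers can be seen in a toy 1-D example, with all settings assumed for illustration: a Gaussian encoder Q(Z|x_i) = N(μ_i, σ²) with small σ, whose means are themselves spread like the prior N(0, 1). The aggregate Q_Z is then essentially the prior, which is what WAE's penalty looks at, yet the per-example KL term of the VAE bound stays large because each individual Q(Z|x_i) is far narrower than P_Z. The code uses the closed form KL(N(μ, σ²) ‖ N(0, 1)) = ½(μ² + σ² − 1 − log σ²).

```python
import numpy as np

rng = np.random.default_rng(0)
n, sigma = 10_000, 0.01        # toy settings (assumed): 1-D codes, narrow encoder

# Hypothetical Gaussian encoder Q(Z|x_i) = N(mu_i, sigma^2); the means mu_i are
# spread like the prior P_Z = N(0, 1).
mu = rng.standard_normal(n)

# VAE-style term: closed-form KL(N(mu_i, sigma^2) || N(0, 1)) for every example.
kl_per_example = 0.5 * (mu ** 2 + sigma ** 2 - 1.0 - np.log(sigma ** 2))

# WAE-style view: pool one code per example to sample the aggregate Q_Z.
z = mu + sigma * rng.standard_normal(n)

print(f"mean per-example KL: {kl_per_example.mean():.2f}")             # roughly 4.6
print(f"aggregate Q_Z mean {z.mean():.3f}, variance {z.var():.3f}")  # near 0 and 1
```

Shrinking σ towards zero (a deterministic encoder) sends the per-example KL to infinity while leaving Q_Z unchanged, which is why WAE, but not the VAE bound, can work with deterministic encoders.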


It is also shown that if [math]\displaystyle{ \small c(x,y)={\parallel x-y \parallel}_2^2 }[/math], WAE-GAN is equivalent to adversarial auto-encoders (AAEs). The theory thus suggests that AAEs minimize (a relaxed form of) the 2-Wasserstein distance between [math]\displaystyle{ \small P_X }[/math] and [math]\displaystyle{ \small P_G }[/math].

OT, W-GAN and WAE

The Wasserstein GAN (W-GAN) minimizes the 1-Wasserstein distance [math]\displaystyle{ \small W_1(P_X,P_G) }[/math] for generative modeling. W-GAN approaches the problem from the dual form and thus cannot be applied to any other cost [math]\displaystyle{ \small W_c }[/math], since the neat form of the Kantorovich-Rubinstein duality holds only for [math]\displaystyle{ \small W_1 }[/math]. WAE approaches the same problem from the primal form, can be applied to any cost function [math]\displaystyle{ \small c }[/math], and comes naturally with an encoder. The constraint in Theorem 1 is relaxed in line with the theory of unbalanced optimal transport, by adding a penalty or additional divergences to the objective.

GANs and WAEs

Many GAN variants, including f-GAN and W-GAN, come without an encoder, so they are not applicable when one needs to reconstruct latent codes or use the learned manifold. In works that try to blend auto-encoder structures into adversarial training, the encoder and decoder have no incentive to be reciprocal (i.e., to invert each other). WAE, by contrast, does not necessarily lead to a min-max game, and it has a clear theoretical foundation for using penalties for regularization.

Experimental Results

The authors empirically evaluate the proposed WAE generative model by specifically testing if data points are accurately reconstructed, if the latent manifold has reasonable geometry, and if random samples of good visual quality are generated.

Experimental setup: A Gaussian prior [math]\displaystyle{ \small P_Z }[/math] and the squared Euclidean cost [math]\displaystyle{ \small c(x,y)={\parallel x-y \parallel}_2^2 }[/math] are used for data points. The encoder-decoder pairs are deterministic. The convolutional encoder and decoder architectures are similar to DC-GAN, with batch normalization. The real-world datasets MNIST (70k images) and CelebA (roughly 203k images) were used for training and testing. For interpolations, a pair of held-out images [math]\displaystyle{ (x,y) }[/math] from the test set is auto-encoded separately to produce codes [math]\displaystyle{ (z_x, z_y) }[/math] in the latent space. These codes are linearly interpolated and the intermediate points are decoded to produce the images below.
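The interpolation procedure can be sketched as below, assuming an `encoder` that maps one input to one code and a `decoder` that maps a batch of codes to a batch of outputs. Both are placeholders here (the experiments use deterministic convolutional networks, not shown), and `interpolate_latent` is a hypothetical name.

```python
import numpy as np

def interpolate_latent(encoder, decoder, x, y, steps):
    """Decode `steps` evenly spaced points on the segment between the codes of x and y."""
    z_x, z_y = encoder(x), encoder(y)          # auto-encode the two images separately
    ts = np.linspace(0.0, 1.0, steps)          # interpolation weights, endpoints included
    codes = np.stack([(1.0 - t) * z_x + t * z_y for t in ts])
    return decoder(codes)                      # one decoded output per interpolation step

def identity(v):
    """Placeholder encoder/decoder used only to make the example runnable."""
    return v

frames = interpolate_latent(identity, identity, np.array([0.0, 0.0]), np.array([2.0, 4.0]), steps=3)
print(frames)
```

The first and last frames are simply the reconstructions of x and y; how sensible the frames in between look is what the interpolation figures probe about the geometry of the latent manifold.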

WAE-GAN and WAE-MMD: In WAE-GAN, the discriminator [math]\displaystyle{ \small D }[/math] is composed of several fully connected layers with ReLU activations. For WAE-MMD, the RBF kernel failed to penalize outliers of [math]\displaystyle{ \small Q_Z }[/math] because of its fast-decaying tails, so the authors resorted to the inverse multiquadratics kernel [math]\displaystyle{ \small k(x,y)=C/(C+\parallel x-y \parallel_2^2) }[/math], which has much heavier tails. Trained models are presented in the figure below. In terms of random samples, WAE-GAN training seems to be highly unstable, but it leads to the best matching scores among WAE-GAN, WAE-MMD and VAE. WAE-MMD, on the other hand, trains much more stably and produces samples of fairly good quality.
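A quick numerical comparison of tail decay illustrates why an RBF kernel barely notices far-away codes. Several choices here are assumptions: the latent dimension d_z = 8 is arbitrary, the scale is C = 2·d_z·σ_z² (the expected squared distance between two independent draws from N(0, σ_z² I)), and the RBF form exp(−‖x − y‖²/C) uses the same scale.

```python
import numpy as np

d_z, sigma_z = 8, 1.0
C = 2 * d_z * sigma_z ** 2     # assumed scale: E||z - z'||^2 for z, z' ~ N(0, sigma_z^2 I)

def rbf(sq_dist, c=C):
    """Gaussian (RBF) kernel as a function of the squared distance (assumed bandwidth)."""
    return np.exp(-sq_dist / c)

def imq(sq_dist, c=C):
    """Inverse multiquadratics kernel C / (C + ||x - y||^2)."""
    return c / (c + sq_dist)

for r2 in [0.0, C, 10 * C, 100 * C]:
    print(f"||x-y||^2 = {r2:6.1f}   RBF = {rbf(r2):.2e}   IMQ = {imq(r2):.2e}")
```

At ten times the typical squared distance the RBF kernel has decayed to e^(-10) ≈ 4.5 × 10^-5, while the IMQ kernel is still 1/11 ≈ 0.09, so outlying codes continue to contribute to an IMQ-based MMD penalty.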

Quantitative assessment: To quantitatively assess the quality of the generated images, the authors use the Fréchet Inception Distance (FID) and report results on CelebA. The FID measures the similarity between two sets of images by computing the Fréchet distance between multivariate Gaussian distributions fitted to their feature representations. In more detail, let [math]\displaystyle{ (m,C) }[/math] denote the mean vector and covariance matrix of the features of the Inception network (Szegedy et al. 2017) applied to model samples, and let [math]\displaystyle{ (m_w,C_w) }[/math] denote the same statistics for real data. The FID between the model samples and the real data is then [math]\displaystyle{ ||m-m_w||^2 +\mathrm{tr}(C+C_w-2(CC_w)^{\frac{1}{2}} ) }[/math] (Heusel et al. 2017); lower values are better. The results confirm that samples from WAE are of better quality than those from VAE (score: 82), and that WAE-GAN obtains a better score (42) than WAE-MMD (55), which correlates with visual inspection of the images.
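The FID formula can be evaluated from the two sets of feature statistics as sketched here; extracting Inception features is not shown, and `frechet_distance` is a hypothetical helper. Rather than forming (C C_w)^{1/2} directly, the sketch uses the fact that C C_w and C^{1/2} C_w C^{1/2} have the same non-negative eigenvalues, so the trace term is the sum of their square roots; this is our implementation choice, not necessarily that of the reference FID code.

```python
import numpy as np

def _sqrt_psd(a):
    """Symmetric square root of a positive semi-definite matrix via eigendecomposition."""
    w, v = np.linalg.eigh(a)
    return (v * np.sqrt(np.clip(w, 0.0, None))) @ v.T

def frechet_distance(m, c, m_w, c_w):
    """||m - m_w||^2 + tr(C + C_w - 2 (C C_w)^{1/2}) for two fitted Gaussians."""
    s = _sqrt_psd(c)
    # Same eigenvalues as C C_w, but symmetric, so eigvalsh applies.
    eig = np.linalg.eigvalsh(s @ c_w @ s)
    tr_sqrt = np.sum(np.sqrt(np.clip(eig, 0.0, None)))
    return float(np.sum((m - m_w) ** 2) + np.trace(c) + np.trace(c_w) - 2.0 * tr_sqrt)

# Toy 2-D statistics (illustrative only, not real Inception features).
m, c = np.array([0.0, 0.0]), np.diag([1.0, 4.0])
m_w, c_w = np.array([1.0, 1.0]), np.diag([4.0, 9.0])
print(frechet_distance(m, c, m_w, c_w))
```

For these diagonal statistics the formula gives 2 + (1 + 4 + 4 + 9) − 2(2 + 6) = 4, and identical statistics give 0.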

Results on the MNIST and CelebA datasets. In "test reconstructions" (middle row of images), odd rows correspond to the real test points.


The authors also heuristically evaluate the sharpness of generated samples using the Laplace filter. The numbers, summarized in Table 1, show that WAE-MMD produces samples of slightly better quality than VAE, while WAE-GAN achieves the best results overall.
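One common form of this heuristic is sketched below: filter a grayscale image with the 3×3 Laplace kernel and report the variance of the response, which is larger for images with more high-frequency detail. Whether this matches the paper's exact protocol (kernel, statistic, preprocessing) is an assumption, and the function names are hypothetical.

```python
import numpy as np

LAPLACE = np.array([[0.0, 1.0, 0.0],
                    [1.0, -4.0, 1.0],
                    [0.0, 1.0, 0.0]])

def laplace_response(img):
    """'Valid' 3x3 filtering of a 2-D grayscale image with the (symmetric) Laplace kernel."""
    h, w = img.shape
    out = np.zeros((h - 2, w - 2))
    for i in range(3):
        for j in range(3):
            out += LAPLACE[i, j] * img[i:i + h - 2, j:j + w - 2]
    return out

def sharpness(img):
    """Heuristic sharpness score: variance of the Laplace filter response."""
    return float(np.var(laplace_response(img)))

checker = (np.indices((8, 8)).sum(axis=0) % 2).astype(float)   # sharp pattern
soft = 0.25 + 0.5 * checker                                    # same pattern, half the contrast
print(sharpness(checker), sharpness(soft))  # 16.0 4.0
```

Because the score also depends on contrast, comparisons such as those in Table 1 are most meaningful between image sets with similar intensity ranges.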

Qualitative Assessment of Images

Network structures:

The Encoder, Decoder, and Adversary architectures used for the MNIST and CelebA datasets are shown in the following two images:

Network architectures used to evaluate on the MNIST dataset.
Network architectures used to evaluate on the CelebA dataset.

Commentary and Conclusion

This paper presents an interesting theoretical justification for a new family of auto-encoders called Wasserstein Auto-Encoders (WAEs). The objective minimizes the optimal transport cost, e.g. a Wasserstein distance, but relaxes the constraint of Theorem 1 so that the objective splits into a reconstruction cost and a regularization penalty. The regularization penalizes divergences between a prior and the distribution of encoded training data in the latent space, and is estimated either by adversarial training (WAE-GAN) or by kernel-based techniques (WAE-MMD). The authors show that WAEs produce samples of better visual quality than VAEs while keeping training stable. They also show theoretically that WAEs generalize adversarial auto-encoders (AAEs).

Although the paper mentions that encoder-decoder pairs can be deterministic, it does not show the geometry of the resulting latent space. The effect of encoder randomness on sample quality also deserves study. While the method is evaluated on MNIST and CelebA, it would be important to see how it performs on other real-world data distributions. The authors do not provide a comprehensive evaluation of WAE-GAN regularization, which makes it hard to judge whether moving the adversarial problem to the latent space reduces instability. The reasons why WAE-GAN produces better samples than WAE-MMD also need to be investigated. In the future, it would be interesting to explore other ways of computing the divergence between the encoded distribution and the prior.

Open Source Code

1. https://github.com/tolstikhin/wae

2. https://github.com/maitek/waae-pytorch

Sources

1. M. Arjovsky, S. Chintala, and L. Bottou. "Wasserstein GAN." 2017.

2. M. Heusel et al. "GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium." Advances in Neural Information Processing Systems, 2017.

3. C. Szegedy et al. "Inception-v4, Inception-ResNet and the Impact of Residual Connections on Learning." AAAI, Vol. 4, 2017.

4. I. Tolstikhin, O. Bousquet, S. Gelly, and B. Schölkopf. "Wasserstein Auto-Encoders." 2017.

5. L. Weng. "From GAN to WGAN." Lil'Log, 2017. https://lilianweng.github.io/lil-log/2017/08/20/from-GAN-to-WGAN.html