conditional neural process

From statwiki
Revision as of 13:02, 19 November 2018 by S366chen


Deep neural networks typically require large datasets to train effectively. One approach to mitigating this data-efficiency problem is to learn in two phases: the first phase learns the statistics of a generic domain without committing to a specific learning task; the second phase learns a function for a specific task, but does so using only a small number of data points by exploiting the domain-wide statistics already learned.

In this work, the authors propose a family of models that represent solutions to the supervised problem, together with an end-to-end training approach for learning them. These models combine neural networks with features reminiscent of Gaussian Processes (GPs). They call this family of models Conditional Neural Processes (CNPs).


Let the set of observations (the training set) be [math]O = \{(x_i, y_i)\}_{i = 0}^{n-1}[/math], and let the set of targets (the test set) be the unlabelled inputs [math]T = \{x_i\}_{i = n}^{n + m - 1}[/math].

Let P be a probability distribution over functions [math]f : X \to Y[/math], formally known as a stochastic process. Then, for [math]f \sim P[/math] and [math]y_i = f(x_i)[/math], P defines a joint distribution over the random variables [math]\{f(x_i)\}_{i = 0}^{n + m - 1}[/math], and therefore a conditional distribution [math]P(f(T) | O, T)[/math]. Our task is to predict the output values [math]f(x_i)[/math] for every [math]x_i \in T[/math], given [math]O[/math].
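When P is itself a Gaussian process, the conditional [math]P(f(T) | O, T)[/math] has a closed form; this is what the GP baseline with the correct hyperparameters in the regression experiment computes. The sketch below is illustrative only: it assumes a zero-mean GP with a squared-exponential kernel, hypothetical hyperparameters (`length_scale`, `sigma_f`, `noise`) and hypothetical helper names. It samples one f ∼ P on n + m inputs, splits the result into O and T, and returns the predictive mean and variance at each target.

```python
import numpy as np

def sq_exp_kernel(a, b, length_scale=0.4, sigma_f=1.0):
    """Squared-exponential kernel matrix for 1D inputs (assumed kernel form)."""
    d = a[:, None] - b[None, :]
    return sigma_f**2 * np.exp(-0.5 * (d / length_scale) ** 2)

def gp_posterior(x_obs, y_obs, x_tgt, noise=1e-4):
    """Mean and variance of P(f(x) | O, x) for each target x under a zero-mean GP."""
    K_oo = sq_exp_kernel(x_obs, x_obs) + noise * np.eye(len(x_obs))
    K_to = sq_exp_kernel(x_tgt, x_obs)
    L = np.linalg.cholesky(K_oo)
    alpha = np.linalg.solve(L.T, np.linalg.solve(L, y_obs))
    mean = K_to @ alpha
    v = np.linalg.solve(L, K_to.T)                      # shape (n, m)
    var = np.diag(sq_exp_kernel(x_tgt, x_tgt)) - np.sum(v**2, axis=0)
    return mean, var

rng = np.random.default_rng(0)
n, m = 10, 50
x = rng.uniform(-2.0, 2.0, size=n + m)

# Draw f ~ P jointly on all n + m inputs (small jitter for numerical stability).
K = sq_exp_kernel(x, x) + 1e-6 * np.eye(n + m)
f = np.linalg.cholesky(K) @ rng.standard_normal(n + m)

# O = {(x_i, y_i)}_{i=0}^{n-1} with y_i = f(x_i); T = {x_i}_{i=n}^{n+m-1}.
x_obs, y_obs = x[:n], f[:n]
x_tgt = x[n:]
mean, var = gp_posterior(x_obs, y_obs, x_tgt)
```

The posterior variance is never larger than the prior variance `sigma_f**2` and shrinks towards the noise level at the observed inputs, which is the behaviour the CNP is trained to approximate.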

Conditional Neural Process

Conditional Neural Processes directly parametrize conditional stochastic processes without imposing consistency with respect to some prior process. A CNP parametrizes distributions over [math]f(T)[/math] given a distributed representation of [math]O[/math] of fixed dimensionality. In doing so, the mathematical guarantees associated with stochastic processes are traded off for functional flexibility and scalability.

A CNP is a conditional stochastic process [math]Q_\theta[/math] that defines distributions over [math]f(x_i)[/math] for [math]x_i \in T[/math]. As for stochastic processes in general, we assume [math]Q_\theta[/math] is invariant to permutations of [math]O[/math] and [math]T[/math]. In this work, permutation invariance with respect to [math]T[/math] is enforced by assuming a factored structure, that is, [math]Q_\theta(f(T) | O, T) = \prod_{x \in T} Q_\theta(f(x) | O, x)[/math].
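Under the factored structure, the log-probability of all targets is a sum of per-target terms, so reordering T cannot change it. A minimal sketch, assuming each factor [math]Q_\theta(f(x) | O, x)[/math] is Gaussian with a mean and variance already produced by the model (the Gaussian choice, the example numbers and the helper names are assumptions for illustration):

```python
import numpy as np

def gaussian_log_prob(y, mean, var):
    """Elementwise log N(y; mean, var)."""
    return -0.5 * (np.log(2.0 * np.pi * var) + (y - mean) ** 2 / var)

def factored_log_q(y_tgt, means, variances):
    """log Q(f(T) | O, T) = sum over x in T of log Q(f(x) | O, x)."""
    return float(np.sum(gaussian_log_prob(y_tgt, means, variances)))

# Hypothetical per-target outputs for three targets.
y = np.array([0.3, -1.2, 0.8])
mu = np.array([0.1, -1.0, 0.5])
var = np.array([0.5, 0.2, 1.0])

# Reordering the targets (together with their predictions) leaves the value unchanged.
perm = np.array([2, 0, 1])
print(np.isclose(factored_log_q(y, mu, var),
                 factored_log_q(y[perm], mu[perm], var[perm])))  # True
```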

In detail, we use the following architecture:

[math]r_i = h_\theta(x_i, y_i)[/math] for any [math](x_i, y_i) \in O[/math], where [math]h_\theta : X \times Y \to \mathbb{R} ^ d[/math]

[math]r = r_0 * r_1 * \dots * r_{n-1}[/math], where [math]*[/math] is a commutative operation that maps elements of [math]\mathbb{R}^d[/math] into a single element of [math]\mathbb{R}^d[/math]

[math]\Phi_i = g_\theta(x_i, r)[/math] for any [math]x_i \in T[/math], where [math]g_\theta : X \times \mathbb{R} ^ d \to \mathbb{R} ^ e[/math] and [math]\Phi_i[/math] are the parameters of [math]Q_\theta(f(x_i) | O, x_i)[/math]
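The three steps above can be sketched as a single forward pass. This is an untrained, illustrative NumPy version under several assumptions not fixed by the text: 1D inputs and outputs, small ReLU MLPs with random weights for [math]h_\theta[/math] and [math]g_\theta[/math], the mean as the commutative operation [math]*[/math], and [math]\Phi_i[/math] taken to be the mean and variance of a Gaussian (so e = 2). All helper names are hypothetical.

```python
import numpy as np

def init_mlp(sizes, rng):
    """Random (untrained) weights and zero biases for an MLP with the given layer sizes."""
    return [(rng.normal(0.0, 1.0 / np.sqrt(a), size=(a, b)), np.zeros(b))
            for a, b in zip(sizes[:-1], sizes[1:])]

def mlp(params, z):
    """ReLU on hidden layers, linear output layer."""
    for W, b in params[:-1]:
        z = np.maximum(z @ W + b, 0.0)
    W, b = params[-1]
    return z @ W + b

def cnp_predict(h_params, g_params, x_obs, y_obs, x_tgt):
    """One CNP forward pass: returns a Gaussian mean and variance per target."""
    r_i = mlp(h_params, np.stack([x_obs, y_obs], axis=1))   # r_i = h(x_i, y_i), shape (n, d)
    r = r_i.mean(axis=0)                                    # r = mean of the r_i, shape (d,)
    z = np.concatenate([x_tgt[:, None], np.tile(r, (len(x_tgt), 1))], axis=1)
    phi = mlp(g_params, z)                                  # Phi_i = g(x_i, r), shape (m, 2)
    mean = phi[:, 0]
    var = np.logaddexp(0.0, phi[:, 1]) + 1e-6               # softplus keeps variance positive
    return mean, var

rng = np.random.default_rng(0)
d = 8
h_params = init_mlp([2, 16, 16, d], rng)      # h: X x Y -> R^d
g_params = init_mlp([1 + d, 16, 16, 2], rng)  # g: X x R^d -> R^2
x_obs, y_obs = rng.uniform(-2.0, 2.0, 5), rng.normal(size=5)
x_tgt = np.linspace(-2.0, 2.0, 7)
mean, var = cnp_predict(h_params, g_params, x_obs, y_obs, x_tgt)
```

Because r is a mean over the r_i, shuffling O leaves the predictions unchanged, and because each target is decoded independently, shuffling T only shuffles the outputs.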

Note that this architecture ensures permutation invariance and [math]O(n + m)[/math] scaling for conditional prediction. Moreover, since [math]r_0 * \dots * r_{n-1}[/math] can be computed in [math]O(1)[/math] from [math]r_0 * \dots * r_{n-2}[/math], the architecture supports streaming observations with minimal overhead.
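The streaming property follows from storing the aggregate rather than the individual r_i. As an illustrative sketch, assuming [math]*[/math] is the mean (the aggregation used in the regression experiment below), a running sum and a count let each new observation update r in time independent of n; the class name is hypothetical.

```python
import numpy as np

class StreamingMeanAggregator:
    """Running mean of representations r_i; each update costs O(d), independent of n."""

    def __init__(self, d):
        self.total = np.zeros(d)
        self.count = 0

    def add(self, r_i):
        # Fold one new representation into the aggregate.
        self.total += r_i
        self.count += 1

    @property
    def r(self):
        # Current aggregate representation r = mean of all r_i seen so far.
        return self.total / self.count

rng = np.random.default_rng(1)
reps = rng.normal(size=(100, 8))   # stand-ins for r_i = h(x_i, y_i)
agg = StreamingMeanAggregator(8)
for r_i in reps:
    agg.add(r_i)
```

After the loop, `agg.r` matches the batch mean `reps.mean(axis=0)` up to floating-point rounding, so predictions can be refreshed after every new observation without re-encoding the earlier ones.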

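We train [math]Q_\theta[/math] by asking it to predict [math]O[/math] conditioned on a randomly chosen subset of [math]O[/math]. This gives the model a signal of the uncertainty over the space X inherent in the distribution P given a set of observations. More precisely, let [math]f \sim P[/math], let [math]O = \{(x_i, y_i)\}_{i = 0}^{n-1}[/math] be a set of observations, and let [math]N \sim \mathrm{uniform}[0, \dots, n - 1][/math]. We condition on the subset [math]O_N = \{(x_i, y_i)\}_{i = 0}^{N} \subseteq O[/math], the first N + 1 elements of O, and minimize the negative conditional log probability

[math]\mathcal{L}(\theta) = -\mathbb{E}_{f \sim P}\left[\mathbb{E}_N\left[\log Q_\theta\left(\{y_i\}_{i = 0}^{n-1} \mid O_N, \{x_i\}_{i = 0}^{n-1}\right)\right]\right][/math]

Thus, the targets on which [math]Q_\theta[/math] is scored include both the observed and the unobserved values. In practice, we take Monte Carlo estimates of the gradient of this loss by sampling f and N. This approach shifts the burden of imposing prior knowledge from an analytic prior to empirical data. This has the advantage of liberating a practitioner from having to specify an analytic form for the prior, which is ultimately intended to summarize their empirical experience. Still, we emphasize that the [math]Q_\theta[/math] are not necessarily a consistent set of conditionals for all observation sets, and the training routine does not guarantee that.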
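A single Monte Carlo sample of this loss for one sampled function can be sketched as follows. This NumPy illustration only evaluates the loss; in practice the gradient would come from an automatic differentiation framework. It assumes Gaussian factors, assumes the points of O are already in random order (so the first N + 1 of them form a random subset), and uses a hypothetical `predict(x_ctx, y_ctx, x_tgt)` interface with a toy stand-in predictor in place of a trained CNP.

```python
import numpy as np

def gaussian_log_prob(y, mean, var):
    """Elementwise log N(y; mean, var)."""
    return -0.5 * (np.log(2.0 * np.pi * var) + (y - mean) ** 2 / var)

def cnp_loss_sample(predict, x, y, rng):
    """One Monte Carlo sample of -log Q({y_i} | O_N, {x_i}) for a single function f."""
    n = len(x)
    N = rng.integers(0, n)                  # N ~ uniform[0, ..., n-1]
    x_ctx, y_ctx = x[: N + 1], y[: N + 1]   # O_N = {(x_i, y_i)}_{i=0}^{N}
    mean, var = predict(x_ctx, y_ctx, x)    # scored on all n points, observed or not
    return -np.sum(gaussian_log_prob(y, mean, var))

def context_mean_predict(x_ctx, y_ctx, x_tgt):
    """Hypothetical stand-in for a CNP: predicts the context mean with unit variance."""
    return np.full_like(x_tgt, y_ctx.mean()), np.ones_like(x_tgt)

rng = np.random.default_rng(0)
x = rng.uniform(-2.0, 2.0, 12)
y = np.sin(2.0 * x)
# Averaging over draws of N gives a Monte Carlo estimate of E_N[-log Q].
mc_loss = np.mean([cnp_loss_sample(context_mean_predict, x, y, rng) for _ in range(100)])
```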

In summary,

1. A CNP is a conditional distribution over functions trained to model the empirical conditional distributions of functions f ∼ P.

2. A CNP is permutation invariant in O and T.

3. A CNP is scalable, achieving a running time complexity of O(n + m) for making m predictions with n observations.

Experimental Result I: Function Regression

The first example is the classical 1D regression task, which is commonly used as a baseline for GPs. The authors generated two datasets consisting of functions sampled from a GP with an exponential kernel. In the first dataset they used a kernel with fixed parameters; in the second, the function switched at some random point on the real line between two functions, each sampled with different kernel parameters. At every training step they sampled a curve from the GP and selected a subset of n points as observations and a subset of t points as targets. The observed points are encoded by a three-layer MLP encoder h with a 128-dimensional output representation. The representations are aggregated into a single representation [math]r = \frac{1}{n} \sum_i r_i[/math], which is concatenated to [math]x_t[/math] and passed to a decoder g consisting of a five-layer MLP. The decoder outputs a Gaussian mean and variance for the target output.
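The switching-kernel data can be sketched as follows. Everything specific here is an assumption for illustration: a squared-exponential form is used for the kernel, and the input range, both sets of kernel parameters, the number of points per curve and the helper names are hypothetical. The fixed-kernel dataset corresponds to using a single GP draw for the whole curve.

```python
import numpy as np

def sample_gp(x, length_scale, sigma_f, rng, jitter=1e-6):
    """Draw one function from a zero-mean GP with a squared-exponential kernel."""
    d = x[:, None] - x[None, :]
    K = sigma_f**2 * np.exp(-0.5 * (d / length_scale) ** 2) + jitter * np.eye(len(x))
    return np.linalg.cholesky(K) @ rng.standard_normal(len(x))

def sample_switching_task(rng, num_points=100, n=10, t=20):
    """Curve that switches at a random point between two GP draws with different
    (hypothetical) kernel parameters, split into n observations and t targets."""
    x = np.sort(rng.uniform(-2.0, 2.0, num_points))
    f1 = sample_gp(x, length_scale=0.2, sigma_f=1.0, rng=rng)
    f2 = sample_gp(x, length_scale=0.8, sigma_f=0.5, rng=rng)
    switch = rng.uniform(-2.0, 2.0)          # random switch point on the input range
    y = np.where(x < switch, f1, f2)
    idx = rng.permutation(num_points)        # disjoint random subsets
    obs, tgt = idx[:n], idx[n:n + t]
    return (x[obs], y[obs]), (x[tgt], y[tgt])

(x_obs, y_obs), (x_tgt, y_tgt) = sample_switching_task(np.random.default_rng(0))
```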

Two examples of the regression results obtained for each of the datasets are shown in Figure 2. The authors compared the model with the predictions of a GP with the correct hyperparameters, which constitutes an upper bound on the CNP's performance. Although the GP's predictions are smoother than the CNP's for both the mean and the variance, the CNP learns to regress from only a few context points for both the fixed and the switching kernels. As the number of context points grows, the accuracy of the model improves and its approximated uncertainty decreases. Crucially, the model learns to estimate its own uncertainty given the observations very accurately. It achieves similarly good performance on the switching-kernel task, a type of regression that is not trivial for GPs, whereas for the CNP only the dataset used for training has to change.

Experimental Result II: Image Completion

This section is not finished yet.