conditional neural process

From statwiki
Revision as of 02:43, 19 November 2018 by S366chen (Conditional Neural Process)

Introduction

Deep neural networks typically require large datasets to train effectively. One approach to mitigating this data-efficiency problem is to learn in two phases: the first phase learns the statistics of a generic domain without committing to a specific learning task; the second phase learns a function for a specific task using only a small number of data points, by exploiting the domain-wide statistics already learned.

For example, consider a data set [math] \{(x_i, y_i)\} [/math] with evaluations [math]y_i = f(x_i) [/math] for some unknown function [math]f[/math], and let [math]g[/math] be a function that approximates [math]f[/math]. The aim is to minimize a loss between [math]f[/math] and [math]g[/math] over the entire space [math]X[/math], but in practice the routine is only evaluated on a finite set of observations.
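As a toy illustration of evaluating the loss on a finite set of observations instead of the whole space, the sketch below compares a target f with an approximation g on 100 random points drawn from X = [-1, 1]. The choices f = sin, g = its cubic Taylor polynomial, the squared-error loss, and the sampling range are all assumptions made for illustration; none of them come from the paper.

```python
import math
import random

def f(x):
    # The "unknown" target function; sin is a stand-in chosen for illustration.
    return math.sin(x)

def g(x):
    # A hypothetical approximation of f: the cubic Taylor polynomial of sin.
    return x - x ** 3 / 6.0

def empirical_loss(xs):
    # Mean squared error between f and g over a finite set of points,
    # used in place of the loss over the whole input space X.
    return sum((f(x) - g(x)) ** 2 for x in xs) / len(xs)

rng = random.Random(0)
X_sample = [rng.uniform(-1.0, 1.0) for _ in range(100)]
print(f"empirical loss on 100 points: {empirical_loss(X_sample):.2e}")
```

With more sample points the estimate approaches the loss over all of X = [-1, 1], but only the finite sample is ever used.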

In this work, the authors propose a family of models that represent solutions to this supervised problem, together with an end-to-end approach to training them, combining neural networks with features reminiscent of Gaussian Processes. They call this family of models Conditional Neural Processes (CNPs).


Model

Let the set of observations be [math] O = \{(x_i, y_i)\}_{i = 0}^{n-1}[/math] and the set of target inputs be [math] T = \{x_i\}_{i = n}^{n + m - 1}[/math].

Let [math]P[/math] be a probability distribution over functions [math] f : X \to Y[/math], formally known as a stochastic process. Then, for [math]f \sim P[/math], [math]P[/math] defines a joint distribution over the random variables [math] \{f(x_i)\}_{i = 0}^{n + m - 1}[/math]. Our task is to predict the output values [math]f(x_i)[/math] for [math] x_i \in T[/math] given [math] O[/math], that is, to model the conditional distribution [math] P(f(T) \mid O, T)[/math].
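To make the notation concrete, the following sketch draws one function from a stochastic process and splits its evaluations into O and T. The process used here (random sinusoids with uniformly sampled amplitude and phase) and the helper names `sample_function` and `make_task` are hypothetical choices for illustration only.

```python
import math
import random

def sample_function(rng):
    # Hypothetical choice of P: random sinusoids f(x) = a * sin(x + b),
    # with amplitude a ~ U(0.5, 2) and phase b ~ U(0, pi).
    a = rng.uniform(0.5, 2.0)
    b = rng.uniform(0.0, math.pi)
    return lambda x: a * math.sin(x + b)

def make_task(rng, n, m):
    # Draw f ~ P, then evaluate it at n + m input locations.
    f = sample_function(rng)
    xs = [rng.uniform(-3.0, 3.0) for _ in range(n + m)]
    # O: the first n inputs together with their labels y_i = f(x_i).
    O = [(x, f(x)) for x in xs[:n]]
    # T: the remaining m inputs; f(x) for x in T is what must be predicted.
    T = xs[n:]
    targets = [f(x) for x in T]
    return O, T, targets

rng = random.Random(1)
O, T, targets = make_task(rng, n=5, m=3)
print(len(O), len(T))  # 5 3
```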

Conditional Neural Process

Conditional Neural Process models directly parametrize conditional stochastic processes without imposing consistency with respect to some prior process. CNPs parametrize distributions over [math]f(T)[/math] given a distributed representation of [math]O[/math] of fixed dimensionality. In doing so, the mathematical guarantees associated with stochastic processes are traded off for functional flexibility and scalability.

A CNP is a conditional stochastic process [math]Q_\theta[/math] that defines distributions over [math]f(x_i)[/math] for [math]x_i \in T[/math]. As for stochastic processes, we assume that [math]Q_\theta[/math] is invariant to permutations of [math]O[/math] and [math]T[/math]. In this work, we generally enforce permutation invariance with respect to [math]T[/math] by assuming a factored structure, that is, [math]Q_\theta(f(T) \mid O, T) = \prod_{x \in T} Q_\theta(f(x) \mid O, x)[/math].
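A minimal sketch of why the factored structure gives permutation invariance in T: the log-probability of all targets becomes a sum of per-target terms, so reordering T (together with its values) cannot change it. The per-point predictor `predict` below is a hypothetical stand-in (a distance-weighted Gaussian), not a trained CNP.

```python
import math
import random

def predict(O, x):
    # Hypothetical stand-in for Q_theta(f(x) | O, x): a Gaussian whose mean is a
    # distance-weighted average of the observed y values.
    weights = [math.exp(-(x - xi) ** 2) for xi, _ in O]
    total = sum(weights)
    mu = sum(w * yi for w, (_, yi) in zip(weights, O)) / total
    # The variance grows when x is far from every observation (small total weight).
    var = 0.01 + 1.0 / (1.0 + total)
    return mu, var

def gaussian_log_pdf(y, mu, var):
    return -0.5 * (math.log(2.0 * math.pi * var) + (y - mu) ** 2 / var)

def factored_log_prob(O, T, ys):
    # log Q(f(T) | O, T) = sum over x in T of log Q(f(x) | O, x)
    return sum(gaussian_log_pdf(y, *predict(O, x)) for x, y in zip(T, ys))

O = [(-1.0, -0.8), (0.0, 0.1), (1.0, 0.9)]
T = [-0.5, 0.5, 2.0]
ys = [-0.4, 0.5, 1.0]

# Shuffle the targets together with their values; only the order of the sum changes.
pairs = list(zip(T, ys))
random.Random(0).shuffle(pairs)
T_perm, ys_perm = zip(*pairs)
print(math.isclose(factored_log_prob(O, T, ys), factored_log_prob(O, T_perm, ys_perm)))  # True
```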

In detail, we use the following architecture:

[math]r_i = h_\theta(x_i, y_i)[/math] for any [math](x_i, y_i) \in O[/math]

[math]r = r_0 \oplus r_1 \oplus \cdots \oplus r_{n-1}[/math]

[math]\phi_i = g_\theta(x_i, r)[/math] for any [math]x_i \in T[/math]

Here, [math]h_\theta : X \times Y \to \mathbb{R}^d[/math] and [math]g_\theta : X \times \mathbb{R}^d \to \mathbb{R}^e[/math] are neural networks, [math]\oplus[/math] is a commutative operation that takes elements in [math]\mathbb{R}^d[/math] and maps them into a single element of [math]\mathbb{R}^d[/math], and [math]\phi_i[/math] are the parameters for [math]Q_\theta(f(x_i) \mid O, x_i) = Q(f(x_i) \mid \phi_i)[/math]. Depending on the task, the model learns to parametrize a different output distribution. This architecture ensures permutation invariance and [math]O(n + m)[/math] scaling for conditional prediction. Moreover, since [math]r_0 \oplus \cdots \oplus r_{n-1}[/math] can be computed in [math]O(1)[/math] from [math]r_0 \oplus \cdots \oplus r_{n-2}[/math], the architecture supports streaming observations with minimal overhead.

For regression tasks, we use [math]\phi_i[/math] to parametrize the mean and variance [math]\phi_i = (\mu_i, \sigma_i^2)[/math] of a Gaussian distribution [math]\mathcal{N}(\mu_i, \sigma_i^2)[/math] for every [math]x_i \in T[/math]. For classification tasks, [math]\phi_i[/math] parametrizes the logits of the class probabilities [math]p_c[/math] over the [math]c[/math] classes of a categorical distribution. In most of our experiments, we take [math]a_1 \oplus \cdots \oplus a_n[/math] to be the mean operation [math](a_1 + \cdots + a_n)/n[/math].

Training CNPs

We train [math]Q_\theta[/math] by asking it to predict [math]O[/math] conditioned on a randomly chosen subset of [math]O[/math]. This gives the model a signal of the uncertainty over the space [math]X[/math] inherent in the distribution [math]P[/math] given a set of observations. More precisely, let [math]f \sim P[/math], let [math]O = \{(x_i, y_i)\}_{i=0}^{n-1}[/math] be a set of observations, and let [math]N \sim \text{uniform}[0, \ldots, n-1][/math]. We condition on the subset [math]O_N = \{(x_i, y_i)\}_{i=0}^{N} \subset O[/math], a prefix of [math]O[/math], and minimize the negative conditional log probability

[math]\mathcal{L}(\theta) = -\mathbb{E}_{f \sim P}\left[\mathbb{E}_N\left[\log Q_\theta\left(\{y_i\}_{i=0}^{n-1} \mid O_N, \{x_i\}_{i=0}^{n-1}\right)\right]\right][/math]

Thus, the targets on which [math]Q_\theta[/math] is scored include both the observed and the unobserved values. In practice, we take Monte Carlo estimates of the gradient of this loss by sampling [math]f[/math] and [math]N[/math].

This approach shifts the burden of imposing prior knowledge from an analytic prior to empirical data. It has the advantage of freeing the practitioner from having to specify an analytic form for the prior, which is ultimately intended to summarize their empirical experience.
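The architecture and training objective above can be sketched in plain Python as follows. This is a hedged, untrained illustration, not the authors' implementation: the layer sizes, the single hidden layer, the softplus variance transform, and the names `TinyCNP` and `cnp_loss` are all assumptions. It uses the mean as the commutative operation and a Gaussian output for regression, and computes one Monte Carlo sample of the loss for a single function; actually minimizing it would require gradients from an autodiff framework.

```python
import math
import random

D = 8  # size d of the representation r (hypothetical choice)

def init_linear(rng, n_in, n_out):
    # Random (untrained) weight matrix of shape n_out x n_in and zero bias.
    W = [[rng.gauss(0.0, 0.5) for _ in range(n_in)] for _ in range(n_out)]
    return W, [0.0] * n_out

def linear(params, v):
    W, b = params
    return [sum(w * x for w, x in zip(row, v)) + bi for row, bi in zip(W, b)]

def relu(v):
    return [max(0.0, x) for x in v]

def softplus(z):
    # Numerically stable log(1 + exp(z)).
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))

class TinyCNP:
    def __init__(self, seed=0):
        rng = random.Random(seed)
        # Encoder h_theta: X x Y -> R^d, an MLP applied to each pair (x_i, y_i).
        self.h1, self.h2 = init_linear(rng, 2, 16), init_linear(rng, 16, D)
        # Decoder g_theta: X x R^d -> R^2, producing (mean, pre-softplus variance).
        self.g1, self.g2 = init_linear(rng, 1 + D, 16), init_linear(rng, 16, 2)

    def encode(self, context):
        # r_i = h_theta(x_i, y_i); r is their mean, a commutative aggregation.
        reps = [linear(self.h2, relu(linear(self.h1, [x, y]))) for x, y in context]
        return [sum(col) / len(reps) for col in zip(*reps)]

    def decode(self, x, r):
        # phi = g_theta(x, r) = (mu, sigma^2); a small floor keeps sigma^2 away from 0.
        mu, raw = linear(self.g2, relu(linear(self.g1, [x] + r)))
        return mu, 1e-3 + softplus(raw)

    def log_prob(self, context, xs, ys):
        # Factored Gaussian log-likelihood of ys at xs given the context.
        r = self.encode(context)
        total = 0.0
        for x, y in zip(xs, ys):
            mu, var = self.decode(x, r)
            total += -0.5 * (math.log(2.0 * math.pi * var) + (y - mu) ** 2 / var)
        return total

def cnp_loss(model, O, rng):
    # One Monte Carlo sample of the loss for a single f: draw N uniformly from
    # {0, ..., n-1}, condition on O_N = {(x_i, y_i)}_{i=0}^{N}, score all of O.
    N = rng.randint(0, len(O) - 1)
    O_N = O[: N + 1]
    xs = [x for x, _ in O]
    ys = [y for _, y in O]
    return -model.log_prob(O_N, xs, ys)

# Usage: O holds evaluations of one function f (here sin, for illustration).
model = TinyCNP(seed=0)
O = [(k / 4.0, math.sin(k / 4.0)) for k in range(-8, 9)]
print(cnp_loss(model, O, random.Random(0)))
```

Because r is a mean over the r_i, reordering the context pairs changes r only up to floating-point rounding, which is the permutation invariance in O described above.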
Still, we emphasize that the [math]Q_\theta[/math] are not necessarily a consistent set of conditionals for all observation sets, and the training routine does not guarantee that they are. In summary:

1. A CNP is a conditional distribution over functions trained to model the empirical conditional distributions of functions [math]f \sim P[/math].
2. A CNP is permutation invariant in [math]O[/math] and [math]T[/math].
3. A CNP is scalable, achieving a running time complexity of [math]O(n + m)[/math] for making [math]m[/math] predictions with [math]n[/math] observations.

Within this specification of the model, some aspects can still be modified to suit specific requirements. The exact implementation of [math]h[/math], for example, can be adapted to the data type: for low-dimensional data the encoder can be implemented as an MLP, whereas for inputs with larger dimensions and spatial correlations it can also include convolutions. Finally, in the setup described, the model cannot produce coherent samples, because it learns to model only a factored prediction of the means and variances, disregarding the covariance between target points. This is a consequence of this particular implementation of the model. One way to obtain coherent samples is to introduce a latent variable that can be sampled from; the paper carries out some proof-of-concept experiments on such a model in Section 4.2.3.
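The streaming and scalability properties follow from choosing the commutative operation to be the mean. A minimal sketch, assuming that choice, keeps a running sum and a count, so folding in a new representation r_i costs work that does not depend on how many observations came before. The class name `RunningMean` is hypothetical.

```python
class RunningMean:
    # Keeps r = (r_0 + ... + r_{k-1}) / k using only a running sum and a count,
    # so each new observation is folded in without revisiting earlier ones.
    def __init__(self, d):
        self.total = [0.0] * d
        self.count = 0

    def add(self, r_i):
        self.total = [t + v for t, v in zip(self.total, r_i)]
        self.count += 1

    def value(self):
        return [t / self.count for t in self.total]

reps = [[1.0, 2.0], [3.0, 0.0], [2.0, 4.0]]
agg = RunningMean(d=2)
for r_i in reps:
    agg.add(r_i)
    print(agg.value())
# [1.0, 2.0]
# [2.0, 1.0]
# [2.0, 2.0]
```

After the last update the running value equals the batch mean of all three representations, so predictions for new targets can use the updated r immediately.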