Difference between revisions of "conditional neural process"

Introduction

To train a model effectively, deep neural networks require large datasets. To mitigate this data efficiency problem, learning in two phases is one approach : the first phase learns the statistics of a generic domain without committing to a specific learning task; the second phase learns a function for a specific task, but does so using only a small number of data points by exploiting the domain-wide statistics already learned.

For example, consider a data set $\{x_i, y_i\}$ with evaluations $y_i = f(x_i)$ for some unknown function $f$. Assume $g$ is an approximating function of f. The aim is yo minimize the loss between $f$ and $g$ on the entire space $X$. In practice, the routine is evaluated on a finite set of observations.

In this work, they proposed a family of models that represent solutions to the supervised problem, and ab end-to-end training approach to learning them, that combine neural networks with features reminiscent if Gaussian Process. They call this family of models Conditional Neural Processes.

Model

Let training set be $O = \{x_i, y_i\}_{i = 0} ^ n-1$, and test set be $T = \{x_i, y_i\}_{i = n} ^ {n + m - 1}$.

P be a probability distribution over functions $F : X \to Y$, formally known as a stochastic process. Thus, P defines a joint distribution over the random variables ${f(x_i)}_{i = 0} ^{n + m - 1}$. Therefore, for $P(f(x)|O, T)$, our task is to predict the output values $f(x_i)$ for $x_i \in T$, given $O$,

Conditional Neural Process

Conditional Neural Process models directly parametrize conditional stochastic processes without imposing consistency with respect to some prior process. CNP parametrize distributions over $f(T)$ given a distributed representation of $O$ of fixed dimensionality. Thus, the mathematical guarantees associated with stochastic processes is traded off for functional flexibility and scalability.

CNP is a conditional stochastic process $Q_\theta$ defines distributions over $f(x_i)$ for $x_i \in T$. For stochastic processs, we assume $Q_theta$ is invariant to permutations, and in this work, we generally enforce permutation invariance with respect to $T$ be assuming a factored structure. That is, $Q_theta(f(T) | O, T) = \prod _{x \in T} Q_\theta(f(x) | O, x)$

In detail, we use the following archiecture

$r_i = h_\theta(x_i, y_i)$ for any $(x_i, y_i) \in O$, where $h_\theta : X \times Y \to \mathbb{R} ^ d$

$r = r_i * r_2 * ... * r_n$, where $*$ is a commutative operation that takes elements in $\mathbb{R}^d$ and maps them into a single element of $\mathbb{R} ^ d$

$\Phi_i = g_\theta$ for any $x_i \in T$, where $g_\theta : X \times \mathbb{R} ^ d \to \mathbb{R} ^ e$ and $\Phi_i$ are parameters for $Q_\theta$

Note that this architecture ensures permutation invariance and $O(n + m)$ scaling for conditional prediction. Also, $r = r_i * r_2 * ... * r_n$ can be computed in $O(n)$, this architecture supports streaming observation with minimal overhead.

We train $Q_\theta$ by asking it to predict $O$ conditioned on a randomly chosen subset of $O$. This gives the model a signal of the uncertainty over the space X inherent in the distribution P given a set of observations. Thus, the targets it scores $Q_\theta$ on include both the observed and unobserved values. In practice, we take Monte Carlo estimates of the gradient of this loss by sampling f and N. This approach shifts the burden of imposing prior knowledge

from an analytic prior to empirical data. This has the advantage of liberating a practitioner from having to specify an analytic form for the prior, which is ultimately intended to summarize their empirical experience. Still, we emphasize that the $Q_\theta$ are not necessarily a consistent set of conditionals for all observation sets, and the training routine does not guarantee that.

In summary,

1. A CNP is a conditional distribution over functions trained to model the empirical conditional distributions of functions f ∼ P.

2. A CNP is permutation invariant in O and T.

3. A CNP is scalable, achieving a running time complexity of O(n + m) for making m predictions with n observations.