Gradient Episodic Memory for Continual Learning
Presented by
- Yu Xuan Lee
- Tsen Yee Heng
Background and Introduction
In supervised learning, we are given a training set [math]D_{tr}=\{(x_i,y_i)\}_{i=1}^n[/math], where [math]x_i \in \mathcal{X} [/math] and [math]y_i \in \mathcal{Y} [/math]. Empirical Risk Minimization (ERM) is a common supervised learning method that minimizes the average loss over the training set (the empirical risk), typically by making multiple passes over the training set:
- [math] \frac{1}{|D_{tr}|}\textstyle \sum_{(x_i,y_i) \in D_{tr}} \ell (f(x_i),y_i) [/math]
where [math]\ell :\mathcal {Y} \times \mathcal {Y} \to [0, \infty)[/math] is a loss function that penalizes prediction errors.
Unlike this machine learning setting, humans observe data as an ordered sequence, seldom observe the same example twice, and can memorize only a limited amount of it. Therefore, the iid assumption does not hold. Applying ERM naively to such a sequence leads to "catastrophic forgetting": the problem of losing past knowledge upon acquiring new knowledge. To overcome this problem, Gradient Episodic Memory (GEM) is introduced; it alleviates forgetting of previously acquired knowledge while solving new problems more efficiently.
Framework for Continual Learning
A continuum of data consists of triplets with three components: a feature vector [math]x_i \in \mathcal{X}_t[/math], a task descriptor [math]t_i \in \mathcal{T} [/math], and a target vector [math]y_i \in \mathcal{Y}_t [/math]. The continuum is locally iid: for every triplet [math](x_i, t_i, y_i)[/math],
- [math] (x_i,y_i) \overset{iid}{\sim} P_{t_i}(X,Y) [/math]
The goal of continual learning is to learn a predictor [math]f: \mathcal{X} \times \mathcal{T} \to \mathcal{Y}[/math] that, given a test pair [math](x,t)[/math], predicts the associated target vector [math]y[/math], where [math](x,y) \sim P_t[/math].
Task Descriptor
In the most general case, task descriptors are structured objects, such as a natural-language paragraph describing how to solve the [math]i[/math]-th task. Here the focus is on the simple case where task descriptors are integers [math]t_i=i \in \mathbb{Z} [/math], with [math]t_1,\ldots,t_n \in \mathcal{T}[/math]. Importantly, task descriptors let the learner distinguish identical inputs [math]x_i[/math] that have different targets in different tasks. In short, task descriptors carry information about which task an example belongs to, and separate the learning environments of similar examples.
Training Protocol
The target setting for continual learning is as follows:
- A large number of tasks
- A small number of training examples per task
- Each example is observed only once
- Metrics that measure both transfer and forgetting are reported
To achieve this, the examples are given to the learner one at a time, in sequence, and each example is seen only once. The learner therefore receives a stream of triplets [math](x_i,t_i,y_i)[/math] with no duplicates.
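The single-pass protocol can be sketched as a generator that emits the continuum task by task. This is a minimal sketch under stated assumptions: the helper name `continuum` is hypothetical, tasks are indexed from 0, and examples within a task are shuffled to mimic the locally iid assumption.

```python
import random


def continuum(task_datasets, seed=0):
    """Hypothetical helper: yield (x, t, y) triplets for a continual-learning stream.

    task_datasets[t] is a list of (x, y) pairs for task t (0-based index).
    Tasks arrive in order, and every example is emitted exactly once.
    """
    rng = random.Random(seed)
    for t, data in enumerate(task_datasets):
        order = list(range(len(data)))
        rng.shuffle(order)  # random order within a task (locally iid)
        for i in order:
            x, y = data[i]
            yield x, t, y  # no example is ever repeated
```

A learner would consume this stream in a single `for x, t, y in continuum(...)` loop, updating after each triplet and never revisiting it.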
Evaluation Metrics
In addition to the performance on each task, the ability to transfer knowledge across tasks is very important. Two kinds of transfer are distinguished:
- Backward transfer (BWT): the influence that learning a new task [math]t[/math] has on the performance on a previously encountered task [math]k \prec t[/math]. Backward transfer can be positive or negative. Positive backward transfer means that learning task [math]t[/math] improves performance on task [math]k[/math]; negative backward transfer means that it degrades it. Catastrophic forgetting corresponds to large negative backward transfer.
- Forward transfer (FWT): the influence that learning a task [math]t[/math] has on the performance on a future task [math]k \succ t[/math]. Positive forward transfer is possible when the model can perform "zero-shot" learning, for example by exploiting structure in the task descriptors.
Each of the [math]T[/math] tasks has its own test set. After the model finishes learning task [math]t_i[/math], its test performance is evaluated on all [math]T[/math] tasks. This yields a matrix [math]R \in \mathbb{R}^{T \times T}[/math], where [math]R_{i,j}[/math] is the test classification accuracy of the model on task [math]t_j[/math] after observing the last sample from task [math]t_i[/math]. Let [math]\bar b[/math] be the vector of test accuracies for each task at random initialization. Average Accuracy (ACC), Backward Transfer (BWT) and Forward Transfer (FWT) are then defined as:
- [math] ACC = \frac{1}{T} \sum_{i=1}^T R_{T,i} [/math]
- [math] BWT = \frac{1}{T-1} \sum_{i=1}^{T-1} \left( R_{T,i} - R_{i,i} \right) [/math]
- [math] FWT = \frac{1}{T-1} \sum_{i=2}^{T} \left( R_{i-1,i} - \bar b_i \right) [/math]
The larger these metrics, the better the model. If two models have similar ACC, the one with larger BWT and FWT is preferable.
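The three metrics can be computed directly from the accuracy matrix. This is a minimal sketch using NumPy; note that the code uses 0-based indices, so the paper's [math]R_{T,i}[/math] becomes `R[T - 1, i - 1]`. The function name `continual_metrics` is hypothetical.

```python
import numpy as np


def continual_metrics(R, b):
    """Compute ACC, BWT and FWT from a T x T accuracy matrix (0-based indices).

    R[i, j]: test accuracy on task j after observing the last sample of task i.
    b[j]:    test accuracy on task j at random initialization.
    """
    R = np.asarray(R, dtype=float)
    b = np.asarray(b, dtype=float)
    T = R.shape[0]
    acc = R[T - 1].mean()  # final accuracy averaged over all tasks
    # Change on each earlier task between learning it and the end of training.
    bwt = np.mean([R[T - 1, i] - R[i, i] for i in range(T - 1)])
    # Accuracy on task i just before learning it, relative to random init.
    fwt = np.mean([R[i - 1, i] - b[i] for i in range(1, T)])
    return acc, bwt, fwt
```

For example, with `R = [[0.9, 0.2, 0.1], [0.8, 0.9, 0.3], [0.7, 0.85, 0.95]]` and `b = [0.1, 0.1, 0.1]`, this gives ACC ≈ 0.833, BWT = -0.125 (some forgetting) and FWT = 0.15.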
Gradient Episodic Memory (GEM)
The main feature of GEM is an episodic memory [math]\mathcal{M}_t[/math], which stores a subset of the observed examples from task [math]t[/math]; the integer task descriptors are used to index the memory. Since integer task descriptors carry no structure, significant positive forward transfer cannot be expected, so GEM focuses on minimizing negative backward transfer (catastrophic forgetting) by using the episodic memory efficiently. The learner has a total budget of [math]M[/math] memory locations; when the total number of tasks [math]T[/math] is known, each task is allocated [math]m=\frac{M}{T}[/math] locations. For predictors [math]f_\theta[/math] parameterized by [math]\theta \in \mathbb{R}^p[/math], the loss on the memories of the [math]k[/math]-th task is:
- [math] \ell (f_\theta, \mathcal{M}_k)=\frac{1}{|\mathcal{M}_k|} \sum_{(x_i,k,y_i) \in \mathcal{M}_k} \ell(f_ \theta (x_i,k),y_i) [/math]
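A per-task memory and its loss can be sketched as follows. This is an illustrative sketch, not the authors' implementation: the class `EpisodicMemory` is hypothetical, and it assumes, as a simple choice, that each task's memory keeps its most recent [math]m[/math] examples.

```python
from collections import defaultdict, deque


class EpisodicMemory:
    """Hypothetical sketch: stores the most recent m examples of each task."""

    def __init__(self, total_budget, num_tasks):
        self.m = total_budget // num_tasks  # m = M / T locations per task
        # A deque with maxlen evicts the oldest example once it is full.
        self.slots = defaultdict(lambda: deque(maxlen=self.m))

    def add(self, x, t, y):
        self.slots[t].append((x, y))

    def loss(self, f, k, loss_fn):
        """Average loss of predictor f(x, k) over the memory of task k."""
        mem = self.slots[k]
        return sum(loss_fn(f(x, k), y) for x, y in mem) / len(mem)
```

With a budget of 6 and 3 tasks, each task keeps 2 examples; adding a third example to task 0 evicts its oldest one.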
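Minimizing this memory loss directly, together with the current loss, would overfit to the few stored examples. Instead, GEM uses it as an inequality constraint: the losses on previous tasks are allowed to decrease but not to increase, which also makes positive backward transfer possible. Upon observing the triplet [math](x,t,y)[/math], GEM solves:
- [math] minimize_\theta \space \space \ell(f_\theta(x,t),y) [/math]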
- [math] subject\space to \space \space \ell (f_\theta,\mathcal{M}_k) \le \ell(f_\theta^{t-1},\mathcal{M}_k) \space\space for \space all \space k\lt t [/math]
where [math]f_\theta^{t-1}[/math] is the predictor state at the end of learning of task [math]t-1[/math].
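To solve this problem efficiently, three observations are used:
- It is unnecessary to store the old predictors [math]f_\theta^{t-1}[/math], as long as the loss on previous tasks does not increase after each parameter update [math]g[/math].
- The loss functions are assumed to be locally linear, which is reasonable for small optimization steps.
- Under this assumption, an increase in the loss of a previous task can be diagnosed by computing the angle between its loss gradient vector and the proposed update.
With these observations, the constraints can be rephrased as: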
- [math] \langle g,g_k \rangle := \langle \frac{\partial \ell(f_\theta(x,t),y)}{\partial \theta}, \frac{\partial \ell(f_\theta,\mathcal{M}_k)}{\partial \theta} \rangle \ge 0, \space for \space all \space k\lt t. [/math]
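To see why the sign of the inner product matters, note that for a gradient step [math]\theta \leftarrow \theta - \eta g[/math], local linearity gives [math]\ell_k(\theta - \eta g) \approx \ell_k(\theta) - \eta \langle g, g_k \rangle[/math], so a negative inner product signals that the step would increase the loss on task [math]k[/math]. The sketch below checks the constraints; the helper `violated_constraints` and the toy quadratic losses are hypothetical, chosen only to illustrate the effect.

```python
import numpy as np


def violated_constraints(g, G):
    """Indices k whose constraint <g, g_k> >= 0 is violated.

    g: (p,) gradient of the current loss; G: (t-1, p), row k is g_k.
    """
    return np.flatnonzero(G @ g < 0).tolist()


def quad_loss(theta, c):
    """Toy loss 0.5 * ||theta - c||^2, whose gradient is theta - c."""
    return 0.5 * np.sum((theta - c) ** 2)


theta = np.zeros(2)
c_new, c_old = np.array([1.0, -1.0]), np.array([0.0, 1.0])
g = theta - c_new       # current-task gradient: [-1, 1]
g_old = theta - c_old   # previous-task gradient: [0, -1]
step = theta - 0.1 * g  # plain gradient step on the current task
```

Here `violated_constraints(g, g_old[None, :])` returns `[0]`, and indeed the previous-task loss rises from 0.5 to 0.61 after the step.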
If all these constraints are satisfied, the proposed update [math]g[/math] is unlikely to increase the loss on previous tasks. If at least one of the inequality constraints is violated, [math]g[/math] is projected to the closest gradient [math]\tilde{g}[/math] (in squared [math]\ell_2[/math] norm) that satisfies all the constraints. The optimization problem becomes:
- [math] minimize_{ \tilde{g} } \space \space \frac{1}{2}\parallel g - \tilde{g} \parallel _2^2 [/math]
- [math] subject \space to \space \space \langle \tilde{g},g_k \rangle \ge 0 \space \space for \space all \space k\lt t [/math]
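Writing [math]z[/math] for [math]\tilde{g}[/math] and expanding the objective, the primal GEM Quadratic Program (QP) is:
- [math] minimize_z \space \space \frac{1}{2}z^Tz - g^Tz+\frac{1}{2}g^Tg [/math]
- [math] subject \space to \space \space Gz \ge 0, [/math]
where [math]G = (g_1, \ldots, g_{t-1})^T[/math] is the matrix whose rows are the memory-loss gradients of the previous tasks. The constant term [math]\frac{1}{2}g^Tg[/math] does not affect the solution and can be discarded.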
The dual of the GEM QP is:
- [math] minimize_v \space \space \frac{1}{2}v^TGG^Tv + g^TG^Tv [/math]
- [math] subject \space to \space \space v\ge 0 [/math]
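The dual is a QP over only [math]t-1[/math] variables, one per previous task, instead of the [math]p[/math] parameters of the model. After solving for [math]v^*[/math], the projected gradient update is recovered as [math]\tilde{g}=G^Tv^* + g[/math]. The full training procedure is given as an algorithm in the original paper [1].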
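The projection step can be sketched with NumPy alone. Any QP solver could be used for the dual; to keep the sketch dependency-free, it substitutes simple projected gradient descent on [math]v \ge 0[/math] with step size [math]1/L[/math], where [math]L[/math] is the largest eigenvalue of [math]GG^T[/math]. The function name and the iteration/tolerance parameters are hypothetical.

```python
import numpy as np


def project_gradient(g, G, n_iter=1000, tol=1e-10):
    """Sketch of GEM's projection via its dual QP.

    g: (p,) proposed gradient; G: (t-1, p), row k is the gradient g_k.
    Solves  min_v 0.5 v^T G G^T v + g^T G^T v  s.t. v >= 0,
    then returns g_tilde = G^T v* + g.
    """
    A = G @ G.T  # (t-1, t-1) Gram matrix of previous-task gradients
    b = G @ g    # linear term: g^T G^T v = (G g)^T v
    if np.all(b >= 0):
        return g.copy()  # no constraint violated: keep the proposed update
    L = np.linalg.eigvalsh(A).max()  # Lipschitz constant of the dual gradient
    v = np.zeros(len(b))
    for _ in range(n_iter):
        # Gradient step on the dual objective, then projection onto v >= 0.
        v_new = np.maximum(0.0, v - (A @ v + b) / L)
        if np.max(np.abs(v_new - v)) < tol:
            v = v_new
            break
        v = v_new
    return G.T @ v + g
```

For instance, with `g = [1, -2]` and a single previous-task gradient `[1, 1]`, the constraint is violated and the result is `[1.5, -1.5]`, the orthogonal projection of `g` onto the half-space where the inner product is non-negative.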
Experiment
We perform a variety of experiments to assess the performance of GEM in continual learning.
Dataset
We consider the following datasets:
- MNIST Permutations: a variant of the MNIST dataset of handwritten digits, where each task applies a fixed permutation of the pixels to the inputs.
- MNIST Rotations: a variant of MNIST where the digits of each task are rotated by a fixed angle between 0 and 180 degrees.
- CIFAR100: the CIFAR object recognition dataset with 100 classes, where each task introduces a new set of classes. For a total of T tasks, each new task concerns examples from a disjoint subset of 100/T classes. The input distribution is similar across tasks, but different tasks require different output distributions.
For all datasets, we consider T = 20 tasks. On the MNIST datasets, each task has 1000 examples from 10 different classes; on CIFAR100, each task has 2500 examples from 5 different classes (100/T = 5). The tasks are observed in sequence, and each example is seen only once.
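The MNIST Permutations construction can be sketched as follows. This is a hypothetical helper, assuming the images are already flattened into an `(n, d)` array; it draws one fixed random pixel permutation per task and applies it to every image. It does not load MNIST itself, and it is not necessarily how the authors generated their tasks.

```python
import numpy as np


def make_permuted_tasks(images, num_tasks, seed=0):
    """Hypothetical helper: build MNIST-Permutations-style tasks.

    images: array of shape (n, d) with flattened images.
    Returns one permuted copy of the images per task, plus the permutations.
    """
    rng = np.random.default_rng(seed)
    d = images.shape[1]
    perms = [rng.permutation(d) for _ in range(num_tasks)]  # fixed per task
    return [images[:, p] for p in perms], perms
```

Because each task reorders the pixels in an unrelated way, the input distributions of different tasks share little structure, even though the labels stay the same.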
Architectures
On the MNIST tasks, a fully-connected neural network with two hidden layers of 100 ReLU units is used. On CIFAR100, a smaller version of ResNet18 is used, with three times fewer feature maps across all layers. Also on CIFAR100, the network has a final linear classifier per task. This is one simple way to leverage the task descriptor, in order to adapt the output distribution to the subset of classes for each task. The networks are trained with plain SGD on mini-batches of 10 samples. All hyper-parameters are optimized through grid search, and the best results for each model are reported.
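A closely related way to restrict predictions to the current task's classes is to mask a shared output layer using the task descriptor. The sketch below is an assumption for illustration, not necessarily the authors' implementation: it assumes task [math]t[/math] owns a contiguous block of `classes_per_task` classes, and the helper names are hypothetical.

```python
import numpy as np


def task_logits(logits, t, classes_per_task):
    """Keep only the logits of task t's classes (0-based); mask the rest."""
    logits = np.asarray(logits, dtype=float)
    masked = np.full_like(logits, -np.inf)
    lo, hi = t * classes_per_task, (t + 1) * classes_per_task
    masked[..., lo:hi] = logits[..., lo:hi]
    return masked


def predict(logits, t, classes_per_task):
    """Predicted class, chosen only among the classes of task t."""
    return np.argmax(task_logits(logits, t, classes_per_task), axis=-1)
```

With 100 classes and T = 20, this corresponds to `classes_per_task=5`; the same input can then receive different predictions under different task descriptors.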
Methods
We compare GEM to five alternatives:
- single: a single predictor trained across all tasks.
- independent: one independent predictor per task. Each has the same architecture as "single" but with T times fewer hidden units.
- multimodal: a predictor with a dedicated input layer per task (only for MNIST).
- EWC (Elastic Weight Consolidation), where the loss is regularized to avoid catastrophic forgetting.
- iCARL, a class-incremental learner that classifies using a nearest-exemplar algorithm and uses an episodic memory to prevent catastrophic forgetting. iCARL is only applied to the CIFAR100 experiment.
Note: GEM, EWC and iCARL have the same architecture as "single", plus episodic memory.
Results
Figure 1 (left) shows the average accuracy, BWT and FWT. Overall, GEM performs similarly to or better than the multimodal model. GEM minimizes negative backward transfer, while exhibiting negligible or positive forward transfer.
Figure 1 (right) shows the evolution of the test accuracy of the first task throughout the continuum of data. GEM exhibits minimal forgetting, and achieves positive BWT on CIFAR100.
Overall, GEM performs slightly better than other continual learning methods such as EWC, while requiring less computation. GEM's efficiency comes from optimizing over a number of variables equal to the number of tasks (T = 20 in these experiments), instead of a number of variables equal to the number of model parameters.
Table 2 shows the final ACC of GEM and iCARL on CIFAR100 as a function of their episodic memory size. As Table 2 shows, GEM outperforms iCARL for a wide range of memory sizes.
Table 3 shows the importance of memory when making more than one pass through the data on the MNIST Rotations experiment. GEM and EWC (memory-based) reach higher ACC as the number of passes through the data increases, whereas memory-less methods suffer larger negative BWT and therefore lower accuracy. Comparing the first and last rows of Table 3, GEM matches the "oracle performance upper-bound" ACC given by iid learning, and minimizes negative BWT.
Related Work
Continual learning aims to retain knowledge from past tasks and apply it to acquire new skills across a sequence of tasks. This work focuses on the realistic setting where examples are seen only once. We introduced GEM, which outperforms the compared models in limiting forgetting.
One way to avoid catastrophic forgetting is to freeze the early layers of a neural network and duplicate the later layers for each new task. However, these methods are hard to scale, since the number of modules grows with the number of tasks. Instead, we use a single model and modify the learning objective to avoid catastrophic forgetting. Other methods use synaptic memory, which reduces the learning rate of the parameters that are important for previous tasks, while yet others use an episodic memory that stores and replays examples from previous tasks to prevent forgetting and allow positive backward transfer.
Conclusion
The scenario of continual learning is formalized. First, training and evaluation protocols are defined to assess the quality of models in terms of their accuracy and their ability to transfer knowledge (BWT and FWT). Second, GEM is introduced as a model that leverages episodic memory to avoid forgetting and allow positive BWT, while remaining computationally efficient. The effectiveness of GEM is demonstrated by comparing it with other models such as EWC and multimodal.
However, GEM still has room for improvement. First, GEM does not leverage structured task descriptors, which could be exploited to obtain positive FWT. Second, advanced memory management techniques were not investigated. Third, each GEM iteration requires one backward pass per task, which increases computation time.
Reference
- [1] David Lopez-Paz and Marc'Aurelio Ranzato. Gradient Episodic Memory for Continual Learning. Advances in Neural Information Processing Systems (NIPS), 2017.