orthogonal gradient descent for continual learning
Mehrdad Farajtabar, Navid Azizan, Alex Mott, Ang Li
Neural Networks suffer from catastrophic forgetting: forgetting previously learned tasks when trained to do new ones. Most neural networks can’t learn tasks sequentially despite having capacity to learn them simultaneously. For example, training a CNN to look at only one label of CIFAR10 at a time results in poor performance for the initially trained labels (catastrophic forgetting). But that same CNN will perform really well if all the labels are trained simultaneously (as is standard). The ability to learn tasks sequentially is called continual learning, and it is crucially important for real world applications of machine learning. For example, a medical imaging classifier might be able to classify a set of base diseases very well, but its utility is limited if it cannot be adapted to learn novel diseases - like local/rare/or new diseases (like Covid-19).
This work introduces a new learning algorithm called Orthogonal Gradient Descent (OGD) that replaces Stochastic Gradient Descent (SGD). In standard SGD, the optimization takes no care to retain performance on any previously learned tasks, which works well when the task is presented all at once and iid. However, in a continual learning setting, when tasks/labels are presented sequentially, SGD does not perform well - as will be shown in the results. OGD considers previously learned tasks by maintaining a space of previous gradients, such that incoming gradients can be projected onto an orthogonal basis of that space - minimally impacting previously attained performance.
Previous work in continual learning can be summarized into three broad categories. There are expansion based techniques, which add neurons/modules to an existing model to accommodate incoming tasks while leveraging previously learned representations. One of the downsides of this method is the growing size of the model with increasing number of tasks. There are also regularization based methods, which constraints weight updates according to some importance measure for previous tasks. Finally, there are the repetition based methods. These models attempt to artificially interlace data from previous tasks into the training scheme of incoming tasks, mimicking traditional simultaneous learning. This can be done by training memory modules or generative networks.
 First Reference