Article review: Overcoming Catastrophic Forgetting in Neural Networks

Catastrophic forgetting is the tendency of a neural network that has been trained on one task to lose its performance on that task when it is subsequently trained on a new, related task. A remedy is proposed in the paper “Overcoming catastrophic forgetting in neural networks” by Kirkpatrick et al. The authors begin by describing synaptic consolidation in mammalian learning: a proportion of the synapses that encode a previously learned task are rendered less plastic, so that long-term memory of that task persists even as a new, related task is learned. Inspired by this, they developed an analogous machine learning technique called elastic weight consolidation (EWC), which effectively mitigates catastrophic forgetting, as they demonstrate using permuted MNIST tasks.

EWC works by assigning varying degrees of rigidity to the weights learned on an initial task when the network learns a new, similar task. When a network trained on task A is asked to learn a modified version of it (task B), the previously learned weights exert an “elastic force” on the new weights, constraining each one in proportion to how “important” it was for task A. In the paper, this importance is measured by the diagonal of the Fisher information matrix evaluated at the weights learned on task A.
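To make the “elastic force” concrete, here is a minimal NumPy sketch of the quadratic EWC penalty, (λ/2) Σ_i F_i (θ_i − θ*_A,i)², where F_i is the diagonal Fisher information of weight i after task A and λ sets how strongly the old task is protected. The function names (`ewc_penalty`, `ewc_penalty_grad`), the variable names, and the two-weight toy problem are my own illustrative assumptions, not code from the paper.

```python
import numpy as np


def ewc_penalty(theta, theta_a, fisher_diag, lam):
    """Quadratic EWC penalty: (lam / 2) * sum_i F_i * (theta_i - theta_A,i)^2."""
    return 0.5 * lam * np.sum(fisher_diag * (theta - theta_a) ** 2)


def ewc_penalty_grad(theta, theta_a, fisher_diag, lam):
    """Gradient of the penalty: an 'elastic force' pulling each weight back toward
    its task-A value, with a stiffness proportional to its importance F_i."""
    return lam * fisher_diag * (theta - theta_a)


# Hypothetical two-weight example.
theta_a = np.array([1.0, 1.0])        # weights learned on task A
fisher_diag = np.array([10.0, 0.01])  # weight 0 mattered for task A; weight 1 barely did
theta_b = np.array([3.0, 3.0])        # where task B alone would put the weights
lam = 1.0

# With a toy task-B loss 0.5 * ||theta - theta_b||^2, setting the gradient of
# (task-B loss + EWC penalty) to zero gives a per-weight weighted average.
theta_new = (theta_b + lam * fisher_diag * theta_a) / (1.0 + lam * fisher_diag)
print(theta_new)  # weight 0 stays near 1.18; weight 1 moves to about 2.98
```

The important weight barely moves away from its task-A value, while the unimportant one travels almost all the way to the task-B optimum. In a real network the same penalty would simply be added to the task-B loss and minimized by gradient descent rather than solved in closed form.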

EWC is an application of Bayes’ theorem: the posterior over the weights given both tasks is proportional to the likelihood of the task B data times the posterior obtained from task A, so the loss function for task B can be expressed approximately in terms of prior information learned in task A. The main drawback of EWC is that this loss depends on a Laplace approximation, which treats the task A posterior as a Gaussian centred on the learned weights with a diagonal covariance matrix, and so may not hold in all situations.
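Spelled out, the argument runs roughly as follows, taking D = D_A ∪ D_B with the two datasets assumed independent given θ. Here L_B is the task-B loss (its negative log-likelihood), θ*_A the weights learned on task A, F the diagonal of the Fisher information at θ*_A, and λ a hyperparameter weighting the old task against the new one. This is a sketch of the reasoning rather than a verbatim reproduction of the paper's equations.

```latex
\begin{align}
\log p(\theta \mid \mathcal{D})
  &= \log p(\mathcal{D}_B \mid \theta) + \log p(\theta \mid \mathcal{D}_A) - \log p(\mathcal{D}_B) \\
p(\theta \mid \mathcal{D}_A)
  &\approx \mathcal{N}\!\left(\theta_A^{*},\ \operatorname{diag}(F)^{-1}\right) \\
\mathcal{L}(\theta)
  &= \mathcal{L}_B(\theta) + \sum_i \frac{\lambda}{2}\, F_i \left(\theta_i - \theta_{A,i}^{*}\right)^2
\end{align}
```

The diagonal precision in the second line is where the drawback comes from: any correlations between weights in the task-A posterior are discarded.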

I think that EWC may be applicable to time series problems (such as the Kaggle Rossmann Store Sales competition) where we want to model time-varying trends. A time series can be broken into contiguous segments A, B, C, … A model is first trained on segment A, then updated on segment B using the EWC loss function, so that it can predict the trend in the period following B. Continuing in this way, the model can be successively updated on segments C, D, … up to the present, using the EWC loss function at each update. The end product is a predictive model that should capture the (possibly time-varying) trends in the entire time series.
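As a rough sketch of this segment-by-segment scheme, the code below swaps the neural network for a linear autoregressive model so that each EWC-regularized fit can be solved exactly. Everything here is a hypothetical illustration: the AR(2) features, the helper names (`make_lagged`, `diag_fisher`, `fit_segment`, `train_sequentially`), the unit-noise Gaussian likelihood used for the Fisher diagonal, and the synthetic series all stand in for real Rossmann-style data. Each finished segment contributes its own quadratic penalty, so later fits are pulled toward the weights learned on all earlier segments.

```python
import numpy as np


def make_lagged(series, n_lags):
    """Build (X, y) for an AR(n_lags) model: bias column plus the previous n_lags values."""
    rows = [series[t - n_lags:t][::-1] for t in range(n_lags, len(series))]
    X = np.column_stack([np.ones(len(rows)), np.array(rows)])
    y = series[n_lags:]
    return X, y


def diag_fisher(X, noise_var=1.0):
    """Per-sample diagonal Fisher information of a Gaussian linear model y ~ N(Xw, noise_var)."""
    return np.sum(X ** 2, axis=0) / (len(X) * noise_var)


def fit_segment(X, y, penalties, lam):
    """Exactly minimize 0.5 * mean((Xw - y)^2) + sum_k (lam / 2) * sum_i F_k,i * (w_i - theta_k,i)^2."""
    n = len(X)
    A = X.T @ X / n
    b = X.T @ y / n
    for fisher, theta in penalties:
        # Each past segment adds its own elastic pull toward the weights it produced.
        A = A + lam * np.diag(fisher)
        b = b + lam * fisher * theta
    return np.linalg.solve(A, b)


def train_sequentially(series, segment_len, n_lags=2, lam=1.0):
    """Fit on contiguous segments A, B, C, ... in order, adding one EWC penalty per finished segment."""
    penalties = []  # list of (fisher_diag, weights) pairs, one per past segment
    w = None
    for start in range(0, len(series) - segment_len + 1, segment_len):
        X, y = make_lagged(series[start:start + segment_len], n_lags)
        w = fit_segment(X, y, penalties, lam)
        penalties.append((diag_fisher(X), w.copy()))
    return w, penalties


# Synthetic series (not real sales data): slow upward drift plus a seasonal wave and noise.
rng = np.random.default_rng(0)
t = np.arange(300)
series = 0.01 * t + np.sin(t / 10.0) + 0.1 * rng.standard_normal(300)

w, penalties = train_sequentially(series, segment_len=100)
print(len(penalties), w)  # three segments were processed; w holds the final AR(2) weights
```

With a neural network, `fit_segment` would become a few epochs of gradient descent on the same penalized loss, while the bookkeeping of one (Fisher, weights) pair per segment would stay the same. Lagged features are rebuilt inside each segment here, so the first `n_lags` points of every segment are not used as targets.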

We can keep evolving the model by applying EWC at regular intervals, updating it with each new batch of contiguous time series data as it comes in. In this way, we could develop an adaptive predictive model capable of continual learning, combining long-term memory with the ability to generalize to new data.
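For the open-ended, streaming version, storing one penalty per past batch would grow without bound. Because a sum of quadratic penalties is itself a quadratic penalty (up to a constant), all past penalties can be folded into a single accumulated Fisher diagonal and a Fisher-weighted anchor, keeping memory constant. The class below is a hypothetical sketch of that idea for a linear model, assuming each incoming batch has already been turned into a feature matrix `X` (for example lagged sales values) and targets `y`; the class name, its parameters, and the unit-noise Fisher estimate are assumptions made for illustration.

```python
import numpy as np


class OnlineEWCRegressor:
    """Hypothetical linear model updated batch by batch under a single merged EWC penalty."""

    def __init__(self, n_features, lam=1.0, noise_var=1.0):
        self.lam = lam
        self.noise_var = noise_var
        self.w = np.zeros(n_features)
        self.fisher = np.zeros(n_features)  # accumulated diagonal Fisher of all past batches
        self.anchor = np.zeros(n_features)  # Fisher-weighted mean of past optima

    def update(self, X, y):
        """Minimize 0.5 * mean((Xw - y)^2) + (lam / 2) * sum_i F_i * (w_i - anchor_i)^2 exactly."""
        n = len(X)
        A = X.T @ X / n + self.lam * np.diag(self.fisher)
        b = X.T @ y / n + self.lam * self.fisher * self.anchor
        self.w = np.linalg.solve(A, b)
        # Fold this batch's penalty into the running one, using
        # sum_k F_k (w - theta_k)^2 = F_tot (w - anchor)^2 + const.
        f_new = np.sum(X ** 2, axis=0) / (n * self.noise_var)
        f_tot = self.fisher + f_new
        self.anchor = np.divide(self.fisher * self.anchor + f_new * self.w, f_tot,
                                out=np.zeros_like(f_tot), where=f_tot > 0)
        self.fisher = f_tot
        return self.w

    def predict(self, X):
        return X @ self.w
```

Usage would be `model = OnlineEWCRegressor(n_features=3, lam=1.0)`, then `model.update(X_batch, y_batch)` each time a new batch arrives and `model.predict(X_next)` to forecast. Raising `lam` favors long-term memory; lowering it lets the model track recent trends more quickly.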