Neural Speed Reading via Skim-RNN

Group

Leyan Cheng, Mingyan Dai, Jerry Huang, Daniel Jiang

Introduction

In Natural Language Processing, recurrent neural networks (RNNs) are a common architecture used to sequentially ‘read’ input tokens and output a distributed representation for each token. Because the hidden state is updated recurrently at every step, an RNN inherently incurs the same computational cost at each time step. However, some input tokens are usually less important to the overall representation of a piece of text or a query than others. In particular, in question answering, the network will often encounter parts of a passage that are irrelevant to the query being asked.

Model

In this paper, the authors introduce a model called ‘skim-RNN’, which takes advantage of ‘skimming’ less important tokens or pieces of text rather than ‘skipping’ them entirely. This models the human ability to skim through passages, spending less time on parts that do not affect the reader’s main objective. While skimming leads to some loss in comprehension of the text [1], it greatly reduces reading time by not dwelling on portions that do not significantly matter for the reader’s objective.

Skim-RNN works by rapidly determining the significance of each input and spending less time processing unimportant input tokens, using a smaller RNN to update only a fraction of the hidden state. When the decision is to ‘fully read’, that is, not to skim the text, Skim-RNN updates the entire hidden state with the default RNN cell. Since the hard decision function (‘skim’ or ‘read’) is non-differentiable, the authors use the Gumbel-softmax [2] to estimate the gradient of the function, rather than traditional methods such as REINFORCE (policy gradient) [3]. Switching between the two RNN cells enables Skim-RNN to reduce the total number of operations performed when the skimming rate is high, which often leads to faster inference on CPUs and makes it attractive for large-scale products and small devices.
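
Concretely, let [math]\displaystyle{ d }[/math] be the hidden size of the default (big) RNN cell and [math]\displaystyle{ d' }[/math] (with [math]\displaystyle{ d' }[/math] much smaller than [math]\displaystyle{ d }[/math]) the hidden size of the small cell. A sketch of the state update described above, at each time step [math]\displaystyle{ t }[/math], is [math]\displaystyle{ h_t = f(x_t, h_{t-1}) }[/math] if the decision is ‘read’, and [math]\displaystyle{ h_t = [f'(x_t, h_{t-1}); \, h_{t-1}[d'+1:d]] }[/math] if the decision is ‘skim’, where [math]\displaystyle{ f }[/math] is the full [math]\displaystyle{ d }[/math]-dimensional RNN cell and [math]\displaystyle{ f' }[/math] is the small cell that updates only the first [math]\displaystyle{ d' }[/math] entries of the hidden state; the remaining [math]\displaystyle{ d - d' }[/math] entries are copied from the previous step.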

The Skim-RNN has the same input and output interfaces as a standard RNN, so it can conveniently replace RNNs in existing models to speed them up. In addition, the speed of Skim-RNN can be dynamically controlled at inference time by adjusting the threshold for the ‘skim’ decision.

Implementation
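
The authors’ implementation is not reproduced here. Below is a minimal sketch, in PyTorch, of a Skim-LSTM cell following the mechanism described above, assuming LSTM cells for both the ‘read’ and ‘skim’ paths, a batch size of one, and a hard threshold on the skim probability at inference time; the class and parameter names (e.g. SkimLSTMCell, skim_threshold) are illustrative, not the authors’ API.

<syntaxhighlight lang="python">
import torch
import torch.nn as nn

class SkimLSTMCell(nn.Module):
    """Illustrative Skim-RNN cell: a big LSTM for 'read' steps and a small
    LSTM that updates only the first d_small hidden dimensions on 'skim' steps."""

    def __init__(self, input_size, d_big, d_small):
        super().__init__()
        self.d_small = d_small
        self.read_cell = nn.LSTMCell(input_size, d_big)    # full 'read' update
        self.skim_cell = nn.LSTMCell(input_size, d_small)  # cheap 'skim' update
        self.decision = nn.Linear(input_size + d_big, 2)   # read/skim logits

    def forward(self, x, state, skim_threshold=0.5):
        # x: (1, input_size); state: (h, c), each of shape (1, d_big).
        h, c = state
        p = torch.softmax(self.decision(torch.cat([x, h], dim=-1)), dim=-1)
        if p[0, 1] > skim_threshold:
            # Skim: update only the first d_small dimensions, copy the rest.
            h_s, c_s = self.skim_cell(x, (h[:, :self.d_small], c[:, :self.d_small]))
            h = torch.cat([h_s, h[:, self.d_small:]], dim=-1)
            c = torch.cat([c_s, c[:, self.d_small:]], dim=-1)
        else:
            # Read: full update with the default LSTM cell.
            h, c = self.read_cell(x, (h, c))
        return h, c
</syntaxhighlight>

At training time, the hard if/else above is replaced by the Gumbel-softmax relaxation described in the Model section so that gradients can flow through the decision; raising skim_threshold at inference time increases the skimming rate and hence the speed-up.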

Experiment

The effectiveness of Skim-RNN was measured in terms of accuracy and reduction in float operations on four classification tasks and a question answering task. These tasks were chosen because they do not require full attention to every detail of the text, but rather require capturing high-level information (classification) or focusing on a specific portion of the text (QA), which is a common context for speed reading. The tasks themselves are listed in the table below.

Classification Tasks

In a language classification task, the input is a sequence of words and the output is a vector of categorical probabilities. Each word is embedded into a [math]\displaystyle{ d }[/math]-dimensional vector, initialized with GloVe [4], and these representations are used as inputs to a long short-term memory (LSTM) network. A linear transformation followed by a softmax is applied to the last hidden state of the LSTM to obtain the classification probabilities. Adam [5] was used for optimization, with an initial learning rate of 0.0001. For Skim-LSTM, [math]\displaystyle{ \tau = \max(0.5, \exp(-rn)) }[/math], where [math]\displaystyle{ r = 10^{-4} }[/math] and [math]\displaystyle{ n }[/math] is the global training step, following [2]. Different sizes of the big LSTM ([math]\displaystyle{ d \in \{100, 200\} }[/math]) and the small LSTM ([math]\displaystyle{ d_0 \in \{5, 10, 20\} }[/math]) were tried, as well as different values of the ratio between the model loss and the skim loss ([math]\displaystyle{ \gamma \in \{0.01, 0.02\} }[/math]) for Skim-LSTM. The batch size was 32 for SST and Rotten Tomatoes, and 128 for the others. For all models, early stopping was applied when the validation accuracy did not increase for 3000 global steps.
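
As a concrete illustration, the following is a hedged sketch (not the authors’ code) of the temperature schedule quoted above and the standard Gumbel-softmax sample from [2] used to relax the hard read/skim decision during training; the function names are illustrative.

<syntaxhighlight lang="python">
import math
import torch

def skim_temperature(global_step, r=1e-4, floor=0.5):
    """tau = max(0.5, exp(-r * n)), where n is the global training step."""
    return max(floor, math.exp(-r * global_step))

def gumbel_softmax_sample(logits, tau):
    """Relaxed sample y = softmax((logits + g) / tau) with g ~ Gumbel(0, 1)."""
    u = torch.rand_like(logits).clamp_(1e-10, 1.0 - 1e-10)  # avoid log(0)
    g = -torch.log(-torch.log(u))
    return torch.softmax((logits + g) / tau, dim=-1)
</syntaxhighlight>

During training, this relaxed sample takes the place of the hard decision so that the decision network receives gradients; as the temperature anneals toward its floor of 0.5, the sample becomes closer to a one-hot choice.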

Results

Table 2 shows the accuracy and the computational cost of the Skim-RNN model compared with other standard models. It is evident that the Skim-RNN model reduces the computational cost of the task while maintaining a high degree of accuracy. Figure 2, meanwhile, demonstrates the effect of varying the size of the small hidden state as well as the parameter [math]\displaystyle{ \gamma }[/math] on the accuracy and computational cost.

Table 3 shows an example of a classification task on the IMDb dataset, where Skim-RNN with [math]\displaystyle{ d = 200 }[/math], [math]\displaystyle{ d_0 = 10 }[/math], and [math]\displaystyle{ \gamma = 0.01 }[/math] correctly classifies a review with a high skimming rate (92%). The goal is to classify the review as either positive or negative. The black words are skimmed and the blue words are fully read. The skimmed words are clearly irrelevant, and the model learns to carefully read only the important words, such as ‘liked’, ‘dreadful’, and ‘tiresome’.

Question Answering Task

For the Stanford Question Answering Dataset (SQuAD), the task is to locate the answer span for a given question in a context paragraph. The effectiveness of Skim-RNN on SQuAD was evaluated using two different models: LSTM+Attention and BiDAF (Seo et al., 2017a). The first model is inspired by most QA systems of the time, consisting of multiple LSTM layers and an attention mechanism; it is complex enough to reach reasonable accuracy on the dataset, yet simple enough to allow well-controlled analyses of the Skim-RNN. The second model is an open-source model designed for SQuAD, used primarily to show that Skim-RNN can replace the RNNs in existing complex systems.

Training

Adam was used with an initial learning rate of 0.0005. For stable training, the model was pretrained with a standard LSTM for the first 5k steps and then fine-tuned with Skim-LSTM.
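
A minimal sketch of this two-phase setup, under the same illustrative names as in the Implementation section (the model and data pipeline are placeholders, not the authors’ code):

<syntaxhighlight lang="python">
import torch

PRETRAIN_STEPS = 5000  # standard LSTM for the first 5k steps

def make_optimizer(params):
    # Adam with the initial learning rate quoted above.
    return torch.optim.Adam(params, lr=5e-4)

def skim_enabled(global_step):
    # Switch from the standard LSTM to the Skim-LSTM after pretraining.
    return global_step >= PRETRAIN_STEPS
</syntaxhighlight>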

Results

Table 4 shows the accuracy (F1 and EM) of the LSTM+Attention and Skim-LSTM+Attention models, as well as VCRNN (Jernite et al., 2017). It can be observed from the table that the skimming models achieve higher or similar accuracy compared to the non-skimming models while reducing the computational cost by more than 1.4 times. In addition, decreasing the number of layers (1 layer) or the hidden size ([math]\displaystyle{ d = 5 }[/math]) reduces the computational cost but significantly decreases accuracy compared to skimming. The table also shows that replacing LSTM with Skim-LSTM in an existing complex model (BiDAF) stably reduces the computational cost without losing much accuracy (only a 0.2% drop, from 77.3% for BiDAF to 77.1% for Sk-BiDAF with [math]\displaystyle{ \gamma = 0.001 }[/math]). An explanation given for this trend is that the model is more confident about which tokens are important at the second layer. In addition, higher [math]\displaystyle{ \gamma }[/math] values lead to a higher skimming rate, which agrees with the parameter’s intended functionality.

Figure 4 shows the F1 score of the LSTM+Attention model using standard LSTM and Skim-LSTM, sorted in ascending order by Flop-R (computational cost). While models tend to perform better with larger computational cost, Skim-LSTM (red) outperforms standard LSTM (blue) at comparable computational cost. It can also be seen that the accuracy of Skim-LSTM is more stable across different configurations and computational costs. Moreover, increasing the value of [math]\displaystyle{ \gamma }[/math] for Skim-LSTM gradually increases the skimming rate and Flop-R, while also reducing accuracy.

Runtime Benchmark

Results

The results indicate that the Skim-RNN model is well suited to general reading tasks, including classification and question answering. While the tables show that minor losses in accuracy occasionally occurred for specific parameter settings, these losses were small and acceptable given the improvement in runtime.

Conclusion

A Skim-RNN can offer better latency results on a CPU compared to a standard RNN on a GPU, as demonstrated through the results of this study. Future work (as stated by the authors) involves using Skim-RNN for applications that require much higher hidden state size, such as video understanding, and using multiple small RNN cells for varying degrees of skimming.

References

[1] Marcel Adam Just and Patricia Anderson Carpenter. The Psychology of Reading and Language Comprehension. 1987.

[2] Eric Jang, Shixiang Gu, and Ben Poole. Categorical reparameterization with gumbel-softmax. In ICLR, 2017.

[3] Ronald J Williams. Simple statistical gradient-following algorithms for connectionist reinforcement learning. Machine learning, 8(3-4):229–256, 1992.

[4] Jeffrey Pennington, Richard Socher, and Christopher D Manning. Glove: Global vectors for word representation. In EMNLP, 2014.

[5] Diederik Kingma and Jimmy Ba. Adam: A method for stochastic optimization. In ICLR, 2015.