Neural Speed Reading via Skim-RNN

From statwiki
Revision as of 19:29, 21 October 2020 by J354huan (talk | contribs)

Group

Leyan Cheng, Mingyan Dai, Jerry Huang, Daniel Jiang

Introduction

In Natural Language Processing, recurrent neural networks (RNNs) are a common architecture used to sequentially ‘read’ input tokens and output a distributed representation for each token. Because it recurrently updates the hidden state with the same cell, an RNN inherently incurs the same computational cost at every time step. However, some input tokens are usually less important than others to the overall representation of a piece of text or a query. In question answering in particular, the network often encounters parts of a passage that are irrelevant to the query being asked.

Model

In this paper, the authors introduce a model called 'skim-RNN', which ‘skims’ less important tokens or pieces of text rather than ‘skipping’ them entirely. This models the human ability to skim through passages, or to spend less time reading parts that do not affect the reader’s main objective. While this leads to a loss in the comprehension rate of the text [1], it greatly reduces the amount of time spent reading by not focusing on areas that do not significantly affect the reader's objective.

'Skim-RNN' works by rapidly determining the significance of each input token and spending less time on unimportant ones: for these, a smaller RNN updates only a fraction of the hidden state. When the decision is to ‘fully read’ (that is, not to skim), Skim-RNN updates the entire hidden state with the default RNN cell. Since the hard decision function (‘skim’ or ‘read’) is non-differentiable, the authors use the Gumbel-softmax [2] to estimate its gradient, rather than traditional methods such as REINFORCE (policy gradient) [3]. Switching between the two RNN cells lets Skim-RNN reduce the total number of operations performed when the skimming rate is high. This often leads to faster inference on CPUs, which makes the model useful for large-scale products and small devices.
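The read/skim update described above can be sketched in a few lines of plain Python. This is only an illustration under stated assumptions, not the authors' implementation: it uses a simple tanh RNN cell instead of an LSTM, the helper names (`make_matrix`, `rnn_cell`, `skim_step`) are hypothetical, the small cell is assumed to read the full previous hidden state, and the read/skim decision is passed in by hand rather than predicted.

```python
import math
import random

def make_matrix(rows, cols, rng):
    """Small random weight matrix stored as a list of rows (illustrative only)."""
    return [[rng.uniform(-0.1, 0.1) for _ in range(cols)] for _ in range(rows)]

def rnn_cell(x, h, w_x, w_h):
    """Plain tanh RNN cell; the output size equals the number of rows in w_x."""
    return [
        math.tanh(sum(a * b for a, b in zip(row_x, x))
                  + sum(a * b for a, b in zip(row_h, h)))
        for row_x, row_h in zip(w_x, w_h)
    ]

def skim_step(x, h, big, small, read):
    """One step of the sketch.

    read=True:  the big cell rewrites the whole hidden state.
    read=False: the small cell rewrites only the first d_small entries;
                the remaining entries are copied over unchanged.
    """
    if read:
        return rnn_cell(x, h, *big)
    d_small = len(small[0])  # number of rows in the small cell's input matrix
    return rnn_cell(x, h, *small) + h[d_small:]

rng = random.Random(0)
d_in, d, d_small = 3, 6, 2  # toy sizes, chosen only for the demo
big = (make_matrix(d, d_in, rng), make_matrix(d, d, rng))
small = (make_matrix(d_small, d_in, rng), make_matrix(d_small, d, rng))

h0 = [0.5] * d
x = [1.0, -1.0, 0.5]
h_skim = skim_step(x, h0, big, small, read=False)
h_read = skim_step(x, h0, big, small, read=True)
```

After a skim step the last `d - d_small` entries of the hidden state are identical to the previous state, which is where the saving comes from: the small cell has far fewer output rows to compute than the big one.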

The Skim-RNN has the same input and output interfaces as standard RNNs, so it can be conveniently used to speed up RNNs in existing models. In addition, the speed of Skim-RNN can be dynamically controlled at inference time by adjusting the threshold for the ‘skim’ decision.
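As a rough illustration of the inference-time control, the sketch below applies different thresholds to the same per-token skim probabilities. The exact thresholding rule is not given in this summary, so the rule used here (skim when the predicted skim probability exceeds the threshold) is an assumption, and both the function names and the probabilities are made up for the example.

```python
def decide(p_skim, threshold=0.5):
    """Return 'skim' when the skim probability exceeds the threshold, else 'read'."""
    return "skim" if p_skim > threshold else "read"

def skim_rate(p_skims, threshold):
    """Fraction of tokens that would be skimmed at the given threshold."""
    decisions = [decide(p, threshold) for p in p_skims]
    return decisions.count("skim") / len(decisions)

# Hypothetical per-token skim probabilities from a trained model.
p_skims = [0.95, 0.80, 0.55, 0.40, 0.10]

low = skim_rate(p_skims, threshold=0.3)      # lenient threshold: more skimming
default = skim_rate(p_skims, threshold=0.5)
high = skim_rate(p_skims, threshold=0.9)     # strict threshold: mostly full reads
```

Under this rule, lowering the threshold trades accuracy for speed and raising it does the opposite, without retraining the model.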

Implementation

Experiment

The effectiveness of Skim-RNN was measured in terms of accuracy and reduction in floating-point operations on four classification tasks and a question answering task. These tasks were chosen because they do not require one’s full attention to every detail of the text, but rather ask for capturing the high-level information (classification) or focusing on a specific portion of the text (QA), which is a common context for speed reading. The tasks themselves are listed in the table below.

Classification Tasks

In a language classification task, the input was a sequence of words and the output was a vector of categorical probabilities. Each word was embedded into a [math]d[/math]-dimensional vector. The vectors were initialized with GloVe [4] and used as the inputs to a long short-term memory (LSTM) architecture. A linear transformation followed by a softmax function was applied to the last hidden state of the LSTM to obtain the classification probabilities. Adam [5] was used for optimization, with an initial learning rate of 0.0001. For Skim-LSTM, the Gumbel-softmax temperature was set to [math]\tau = \max(0.5, \exp(-rn))[/math], where [math]r = 10^{-4}[/math] and [math]n[/math] is the global training step, following [2]. The experiments covered different sizes of the big LSTM ([math]d \in \{100, 200\}[/math]) and the small LSTM ([math]d_0 \in \{5, 10, 20\}[/math]), as well as different ratios between the model loss and the skim loss ([math]\gamma\in \{0.01, 0.02\}[/math]) for Skim-LSTM. The batch sizes used were 32 for SST and Rotten Tomatoes, and 128 for the others. For all models, training was stopped early when the validation accuracy did not increase for 3000 global steps.
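The temperature schedule and the Gumbel-softmax relaxation it feeds into can be sketched as follows. The schedule follows the formula above; the `gumbel_softmax` function is a generic textbook version of the relaxation in [2], written here with hypothetical names and made-up decision probabilities, and is not taken from the authors' code.

```python
import math
import random

def temperature(step, r=1e-4, floor=0.5):
    """tau = max(floor, exp(-r * step)); starts at 1 and decays to the floor."""
    return max(floor, math.exp(-r * step))

def gumbel_softmax(log_probs, tau, rng):
    """Relaxed one-hot sample: softmax((log p_i + g_i) / tau), g_i ~ Gumbel(0, 1)."""
    gumbels = []
    for _ in log_probs:
        u = max(rng.random(), 1e-12)  # guard against log(0)
        gumbels.append(-math.log(-math.log(u)))
    scores = [(lp + g) / tau for lp, g in zip(log_probs, gumbels)]
    top = max(scores)                 # subtract the max for numerical stability
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

rng = random.Random(0)
# Hypothetical decision distribution over (read, skim) for one token.
log_probs = [math.log(0.3), math.log(0.7)]

tau_start = temperature(0)
tau_late = temperature(1_000_000)
sample = gumbel_softmax(log_probs, temperature(5000), rng)
```

With [math]r = 10^{-4}[/math] the schedule reaches its floor of 0.5 after roughly 6,900 steps, since [math]\exp(-rn) = 0.5[/math] at [math]n = \ln 2 / r[/math]. Lower temperatures push the relaxed sample closer to a hard one-hot decision.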

Results

Table 2 shows the accuracy and the computational cost of the Skim-RNN model compared with other standard models. The Skim-RNN model reduces the computational cost of the task while maintaining a high degree of accuracy. Figure 2 demonstrates the effect of varying the size of the small hidden state, as well as the parameter [math]\gamma[/math], on the accuracy and computational cost.

Table 3 shows an example of a classification task on the IMDb dataset, where Skim-RNN with [math]d = 200[/math], [math]d_0 = 10[/math], and [math]\gamma = 0.01[/math] correctly classifies a review with a high skimming rate (92%). The goal was to classify the review as either positive or negative. The black words are skimmed, and the blue words are fully read. The skimmed words are clearly irrelevant, and the model learns to carefully read only the important words, such as ‘liked’, ‘dreadful’, and ‘tiresome’.
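To get a feel for what a 92% skimming rate could mean for recurrent cost, here is a back-of-the-envelope estimate. It rests on simplifying assumptions that are not taken from the paper's tables: each LSTM step is charged [math]4 \cdot d_{out} \cdot (d_{in} + d)[/math] multiply-adds, the small cell is assumed to read the full previous hidden state, the input size of 100 is hypothetical, and the cost of the skim decision, the embeddings, and the output layer is ignored. The result should not be read as a reproduction of the reported numbers.

```python
def lstm_step_cost(d_out, d_in, d_hidden):
    """Approximate multiply-adds for one LSTM step: 4 gates, each d_out x (d_in + d_hidden)."""
    return 4 * d_out * (d_in + d_hidden)

def estimated_speedup(d, d_small, d_in, skim_rate):
    """Ratio of full-read cost to the expected per-step cost under the given skim rate."""
    full = lstm_step_cost(d, d_in, d)
    small = lstm_step_cost(d_small, d_in, d)  # small cell reads the full hidden state
    expected = (1.0 - skim_rate) * full + skim_rate * small
    return full / expected

# d, d_small and the skim rate come from the example above; d_in = 100 is an assumption.
speedup = estimated_speedup(d=200, d_small=10, d_in=100, skim_rate=0.92)
```

Under this cost model the estimate depends only on the skim rate and the ratio [math]d_0 / d[/math], because both cells read the same inputs and differ only in output size.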

Question Answering Task

Runtime Benchmark

Results

The results indicate that the Skim-RNN model provides features that are suitable for general reading tasks, including classification and question answering. While the tables show that minor losses in accuracy occasionally occurred at specific parameter settings, these losses were small and acceptable given the improvement in runtime.

Conclusion

A Skim-RNN can offer better latency results on a CPU compared to a standard RNN on a GPU, as demonstrated through the results of this study. Future work (as stated by the authors) involves using Skim-RNN for applications that require much higher hidden state size, such as video understanding, and using multiple small RNN cells for varying degrees of skimming.

References

[1] Marcel Adam Just and Patricia Anderson Carpenter. The Psychology of Reading and Language Comprehension. 1987.

[2] Eric Jang, Shixiang Gu, and Ben Poole. Categorical reparameterization with Gumbel-softmax. In ICLR, 2017.

[3] Ronald J Williams. Simple statistical gradient-following algorithms for connectionist reinforcement learning. Machine learning, 8(3-4):229–256, 1992.

[4] Jeffrey Pennington, Richard Socher, and Christopher D Manning. GloVe: Global vectors for word representation. In EMNLP, 2014.

[5] Diederik Kingma and Jimmy Ba. Adam: A method for stochastic optimization. In ICLR, 2015.