Pre-Training Tasks For Embedding-Based Large-Scale Retrieval
Contributors
Pierre McWhannel wrote this summary with editorial and technical contributions from STAT 946 fall 2020 classmates. The summary is based on the paper "Pre-Training Tasks for Embedding-Based Large-Scale Retrieval" which was presented at ICLR 2020. The authors of this paper are Wei-Cheng Chang, Felix X. Yu, Yin-Wen Chang, Yiming Yang, Sanjiv Kumar [1].
Introduction
The problem domain in which the paper is positioned is large-scale query-document retrieval: given a query or question, collect the relevant documents that contain the answer(s) and identify the span of words that contains the answer(s). This problem is often separated into two steps: 1) a retrieval phase, which reduces a large corpus to a much smaller subset, and 2) a reader, which reads the shortened list of documents, scores them based on their ability to answer the query, and identifies the spans containing the answer; these spans can then be ranked according to some scoring function. In open-domain question answering, the query [math]\displaystyle{ q }[/math] is a question and a document [math]\displaystyle{ d }[/math] is a passage that may contain the answer(s). The focus of this paper is the retriever stage of question answering (QA). Here a scoring function [math]\displaystyle{ f: \mathcal{X} \times \mathcal{Y} \rightarrow \mathbb{R} }[/math] maps each pair [math]\displaystyle{ (q,d) \in \mathcal{X} \times \mathcal{Y} }[/math] to a real value representing how relevant the passage is to the query. The job of the retriever is to assign high scores to relevant passages and low scores otherwise.
Two characteristics are desired of the retriever. The first is high recall: the reader is only applied to the retrieved subset, so it has an opportunity to discard false positives but cannot recover false negatives. The second is low latency, meaning the retriever is computationally efficient. The popular approach meeting these requirements has been BM-25 [2], which relies on token-based matching between two high-dimensional, sparse vectors representing [math]\displaystyle{ q }[/math] and [math]\displaystyle{ d }[/math]. This approach performs well, and retrieval can be performed in time sublinear in the number of passages using an inverted index. Its limitation is that it cannot be optimized for a specific task. An alternative is a BERT- or transformer-based model with cross-attention between query and passage pairs, which can be optimized for a specific task. However, such models suffer from high latency, as the retriever needs to process every pairwise combination of the query with all passages. A third type of algorithm is an embedding-based model that jointly embeds the query and passage in the same embedding space and then uses a measure such as cosine similarity, inner product, or Euclidean distance in this space to get a score. This model is referred to as the "two-tower retrieval model", since it has two transformer-based encoders, one embedding the query, [math]\displaystyle{ \phi{(\cdot)} }[/math], and the other embedding the passage, [math]\displaystyle{ \psi{(\cdot)} }[/math]; the embeddings are then scored by [math]\displaystyle{ f(q,d) = \langle \phi{(q)},\psi{(d)} \rangle \in \mathbb{R} }[/math]. The authors suggest that two-tower models can capture deeper semantic relationships in addition to being optimizable for specific tasks.
This model can avoid the latency problems of a cross-attention model by pre-computing [math]\displaystyle{ \psi{(d)} }[/math] beforehand and then using efficient approximate nearest neighbor search algorithms in the embedding space to find the nearest documents.
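The two-tower scoring and the precomputation of [math]\displaystyle{ \psi{(d)} }[/math] can be sketched in plain Python as follows. This is a minimal illustration, not the authors' implementation: the bag-of-words `toy_encoder` and its tiny vocabulary are hypothetical stand-ins for the BERT towers, both towers share it only for brevity, and the exhaustive scan stands in for an approximate nearest neighbor (MIPS) index.

```python
import math
from typing import Callable, Dict, List, Sequence, Tuple

Vector = List[float]

# Hypothetical toy encoder standing in for a BERT tower: counts over a tiny vocabulary, L2-normalised.
VOCAB = ["bert", "retrieval", "paris", "france", "capital", "embedding"]


def toy_encoder(text: str) -> Vector:
    tokens = text.lower().split()
    vec = [float(tokens.count(word)) for word in VOCAB]
    norm = math.sqrt(sum(x * x for x in vec)) or 1.0
    return [x / norm for x in vec]


def inner_product(u: Sequence[float], v: Sequence[float]) -> float:
    return sum(a * b for a, b in zip(u, v))


class TwoTowerRetriever:
    """Scores f(q, d) = <phi(q), psi(d)>, with every psi(d) computed once up front."""

    def __init__(self, phi: Callable[[str], Vector], psi: Callable[[str], Vector]) -> None:
        self.phi = phi
        self.psi = psi
        self.doc_ids: List[str] = []
        self.doc_vecs: List[Vector] = []

    def index(self, docs: Dict[str, str]) -> None:
        # Offline: encode each document once, independently of any query.
        for doc_id, text in docs.items():
            self.doc_ids.append(doc_id)
            self.doc_vecs.append(self.psi(text))

    def search(self, query: str, k: int) -> List[Tuple[str, float]]:
        # Online: one query encoding, then inner products against the cached vectors.
        # This is an exhaustive scan; a large-scale system would use an approximate MIPS index.
        q_vec = self.phi(query)
        scored = [(doc_id, inner_product(q_vec, d_vec))
                  for doc_id, d_vec in zip(self.doc_ids, self.doc_vecs)]
        scored.sort(key=lambda pair: pair[1], reverse=True)
        return scored[:k]


retriever = TwoTowerRetriever(phi=toy_encoder, psi=toy_encoder)
retriever.index({
    "d1": "paris is the capital of france",
    "d2": "bert is used for embedding based retrieval",
})
top = retriever.search("capital of france", k=1)  # d1 shares "capital" and "france"
```

The key design point is that `index` never sees a query, so all document-side work can be done ahead of time; only `phi` and the inner products run per query.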
These two-tower retrieval models often use BERT with pre-trained weights; BERT's pre-training tasks are masked-LM (MLM) and next sentence prediction (NSP). The authors note that finding pre-training tasks that improve the performance of two-tower retrieval models is an unsolved research problem, and this is the problem the paper addresses: developing pre-training tasks for the two-tower model based on heuristics and validating them with experimental results. The pre-training tasks suggested are Inverse Cloze Task (ICT), Body First Selection (BFS), Wiki Link Prediction (WLP), and the combination of all three. The authors used the Retrieval Question-Answering (ReQA) benchmark with two datasets, SQuAD and Natural Questions, for training and evaluating their models.
Contributions of this paper
- The two-tower transformer model with proper pre-training can significantly outperform the widely used BM-25 algorithm.
- Paragraph-level pre-training tasks such as ICT, BFS, and WLP hugely improve the retrieval quality, whereas the most widely used pre-training task (the token-level masked-LM) gives only marginal gains.
- Two-tower models with deep transformer encoders benefit more from paragraph-level pre-training than their shallow bag-of-words counterpart (BoW-MLP).
Background on Two-Tower Retrieval Models
This section presents the two-tower model in more detail. As before, [math]\displaystyle{ q \in \mathcal{X} }[/math] and [math]\displaystyle{ d \in \mathcal{Y} }[/math] are the query and the document (or passage), respectively. The two-tower retrieval model consists of two encoders: [math]\displaystyle{ \phi: \mathcal{X} \rightarrow \mathbb{R}^k }[/math] and [math]\displaystyle{ \psi: \mathcal{Y} \rightarrow \mathbb{R}^k }[/math]. The elements of [math]\displaystyle{ \mathcal{X},\mathcal{Y} }[/math] are sequences of tokens representing the query and document, respectively. The scoring function is [math]\displaystyle{ f(q,d) = \langle \phi{(q)},\psi{(d)} \rangle \in \mathbb{R} }[/math]. The cross-attention BERT-style model instead uses the scoring function [math]\displaystyle{ f_{\theta,w}(q,d) = \psi_{\theta}(q \oplus d)^{T} w }[/math], where [math]\displaystyle{ \oplus }[/math] denotes the concatenation of the sequences [math]\displaystyle{ q }[/math] and [math]\displaystyle{ d }[/math], [math]\displaystyle{ \theta }[/math] are the parameters of the cross-attention model, and [math]\displaystyle{ w }[/math] is an additional learned parameter vector that projects the joint embedding to a scalar score. The architectures of these two models can be seen below in figure 1.
Comparisons
Inference
In terms of inference, the two-tower architecture has a distinct advantage, as can easily be seen from the architecture. Since the query and document are embedded independently, [math]\displaystyle{ \psi{(d)} }[/math] can be precomputed. This is an important advantage, as it is what makes the two-tower method feasible at scale. As an example, the authors state that the inner products between 100 query embeddings and 1 million document embeddings take only hundreds of milliseconds on CPUs, whereas the equivalent calculation for a cross-attention model takes hours on GPUs. In addition, they remark that retrieval of the relevant documents can be done in time sublinear in [math]\displaystyle{ |\mathcal{Y}| }[/math] using a maximum inner product search (MIPS) algorithm, with little loss in recall.
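The gap between the two architectures comes down to how many transformer forward passes each one needs. The back-of-the-envelope count below uses the authors' example of 100 queries and 1 million documents; it is a simplified sketch that counts encoder passes and inner products only, ignoring sequence length, batching, and hardware.

```python
def encoder_cost(num_queries: int, num_docs: int) -> dict:
    """Count the work needed to score every (query, document) pair."""
    return {
        # Two-tower: each document is encoded once offline and each query once online ...
        "two_tower_doc_passes_offline": num_docs,
        "two_tower_query_passes_online": num_queries,
        # ... leaving only cheap k-dimensional inner products to compute online.
        "two_tower_inner_products_online": num_queries * num_docs,
        # Cross-attention: every concatenated pair q ⊕ d needs its own full forward pass.
        "cross_attention_passes_online": num_queries * num_docs,
    }


costs = encoder_cost(num_queries=100, num_docs=1_000_000)
for name, count in costs.items():
    print(f"{name}: {count:,}")
```

Both designs touch 10^8 (query, document) pairs, but the two-tower model runs a full transformer pass online for only the 100 queries; the remaining online work is 10^8 dot products, which is consistent with the authors' CPU timing.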
Learning
In terms of learning, the two-tower model has a distinct advantage over BM-25: it can be fine-tuned for downstream tasks. This problem falls within the paradigm of metric learning, as the scoring function can be used inside the log-likelihood objective [math]\displaystyle{ \max_{\theta} \log p_{\theta}(d|q) }[/math]. The conditional probability is modelled by a softmax:
[math]\displaystyle{ p_{\theta}(d|q) = \frac{\exp(f_{\theta}(q,d))}{\sum_{d' \in \mathcal{D}} \exp(f_{\theta}(q,d'))} }[/math]
Here [math]\displaystyle{ \mathcal{D} }[/math] is the set of all possible documents, so the denominator is a sum over all of [math]\displaystyle{ \mathcal{D} }[/math]. As is customary, the authors use a sampled softmax to approximate the full softmax: the sum is taken over a small subset of documents from the current training batch, with a proper correction term to ensure the unbiasedness of the partition function.
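A minimal sketch of the in-batch sampled softmax in plain Python. Each query's positive document is the one paired with it in the batch, and the other documents in the batch act as sampled negatives. The optional `log_q` argument subtracts each document's log sampling probability from its logit, which is one common form of correction term (often called logQ correction); treating it as the authors' exact correction is an assumption.

```python
import math
from typing import List, Optional, Sequence


def log_sum_exp(xs: Sequence[float]) -> float:
    # Numerically stable log(sum(exp(x))).
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def in_batch_softmax_loss(
    q_vecs: List[List[float]],
    d_vecs: List[List[float]],
    log_q: Optional[List[float]] = None,
) -> float:
    """Mean of -log p(d_i | q_i), with the softmax restricted to the batch's documents.

    q_vecs[i] and d_vecs[i] form the i-th positive pair; every other d_vecs[j]
    serves as a negative for query i. If given, log_q[j] (assumed log sampling
    probability of document j) is subtracted from each logit.
    """
    batch = len(q_vecs)
    total = 0.0
    for i in range(batch):
        # f(q_i, d_j) = <phi(q_i), psi(d_j)> for every document in the batch.
        logits = [sum(a * b for a, b in zip(q_vecs[i], d)) for d in d_vecs]
        if log_q is not None:
            logits = [logit - lq for logit, lq in zip(logits, log_q)]
        # -log softmax at the positive index = logsumexp(logits) - logits[i].
        total += log_sum_exp(logits) - logits[i]
    return total / batch


q_vecs = [[1.0, 0.0], [0.0, 1.0]]
d_vecs = [[1.0, 0.0], [0.0, 1.0]]
loss = in_batch_softmax_loss(q_vecs, d_vecs)
```

Subtracting the same constant from every logit (uniform sampling) leaves the loss unchanged; a correction of this kind only matters when some documents are sampled more often than others.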
Why do we need pre-training then?
Although the two-tower model can be fine-tuned for a downstream task, getting sufficient labelled data is not always possible. This creates an opportunity to look for better pre-training tasks, as the authors have done. The authors suggest constructing a set of pre-training tasks [math]\displaystyle{ \mathcal{T} }[/math] with labelled positive (query, document/passage) pairs, pre-training on them, and afterward fine-tuning on the limited amount of data for the downstream task.
Pre-Training Tasks
The authors suggest three pre-training tasks for the encoders of the two-tower retriever: ICT, BFS, and WLP. Note that ICT is not a newly proposed task. These tasks are paragraph-level, in contrast to the more common sentence-level or word-level tasks, and are designed according to two heuristics:
- The tasks should capture different resolutions of semantics between query and document. The semantics can be the local context within a paragraph, global consistency within a document, or even the semantic relation between two documents.
- Collecting the data for the pre-training tasks should require as little human intervention as possible and be cost-efficient.
All three of the following tasks can be constructed cost-efficiently and with minimal human intervention from Wikipedia articles. In figure 2 below, the block of text surrounded by the dashed line is regarded as the document [math]\displaystyle{ d }[/math], and the other circled portions of text are the three queries [math]\displaystyle{ q_1, q_2, q_3 }[/math], one selected for each task.
Inverse Cloze Task (ICT)
Given a passage [math]\displaystyle{ p }[/math] consisting of [math]\displaystyle{ n }[/math] sentences [math]\displaystyle{ s_{i}, i=1,\dots,n }[/math], a sentence [math]\displaystyle{ s_{i} }[/math] is randomly sampled and used as the query [math]\displaystyle{ q }[/math], and the document is constructed from the remaining [math]\displaystyle{ n-1 }[/math] sentences. This targets the local context within a paragraph from the first heuristic. In figure 2, [math]\displaystyle{ q_1 }[/math] is an example of a potential ICT query for the selected document [math]\displaystyle{ d }[/math].
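Constructing an ICT pair can be sketched as follows. The function name and the sentence-list input format are hypothetical choices for illustration; the optional `title` argument prepends the article title with a [SEP] separator, as in the authors' ICT setup.

```python
import random
from typing import List, Optional, Tuple


def make_ict_pair(sentences: List[str], rng: random.Random,
                  title: Optional[str] = None) -> Tuple[str, str]:
    """Inverse Cloze Task: a random sentence s_i is the query; the other n-1 sentences form the document."""
    if len(sentences) < 2:
        raise ValueError("ICT needs a passage with at least two sentences")
    i = rng.randrange(len(sentences))
    query = sentences[i]
    # The query sentence is removed, so the encoder must rely on surrounding context.
    document = " ".join(sentences[:i] + sentences[i + 1:])
    if title is not None:
        document = f"{title} [SEP] {document}"
    return query, document


passage = [
    "Machine learning is a field of study.",
    "It builds models from data.",
    "Such models are used for prediction.",
]
q, d = make_ict_pair(passage, random.Random(0), title="Machine learning")
```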
Body First Selection (BFS)
In BFS, [math]\displaystyle{ (q,d) }[/math] pairs are created by randomly selecting a sentence [math]\displaystyle{ q_2 }[/math] from the first section of the Wikipedia page containing the selected document [math]\displaystyle{ d }[/math] and using it as the query. This targets global consistency within a document, since the first section of a Wikipedia page is generally an overview of the whole page and is expected to contain information central to the topic. In figure 2, [math]\displaystyle{ q_2 }[/math] is an example of a potential BFS query for the selected document [math]\displaystyle{ d }[/math].
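A BFS pair can be sketched in the same style. The `WikiPage` container is a hypothetical representation of a parsed article, and drawing the document from the body passages (rather than from the first section itself) is an assumption made here so that the query is not copied verbatim from the document.

```python
import random
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class WikiPage:
    """Hypothetical parsed Wikipedia article."""
    title: str
    first_section: List[str]  # sentences of the overview (first) section
    passages: List[str]       # body passages after the first section


def make_bfs_pair(page: WikiPage, rng: random.Random) -> Tuple[str, str]:
    """BFS: the query is a first-section sentence; the document is a passage of the same page."""
    if not page.first_section or not page.passages:
        raise ValueError("page needs a first section and at least one body passage")
    query = rng.choice(page.first_section)
    document = rng.choice(page.passages)
    return query, document


page = WikiPage(
    title="Machine learning",
    first_section=["Machine learning studies algorithms that learn from data."],
    passages=["Supervised learning uses labelled examples.", "Clustering groups similar items."],
)
q, d = make_bfs_pair(page, random.Random(0))
```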
Wiki Link Prediction (WLP)
The third task selects a hyperlink within the selected document [math]\displaystyle{ d }[/math]; as can be observed in figure 2, "machine learning" is a valid option. The query [math]\displaystyle{ q_3 }[/math] is then a random sentence from the first section of the page that the hyperlink points to. This targets the semantic relation between documents.
Masked LM (MLM), one of the two pre-training tasks used in BERT, is additionally considered in the experiments.
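A WLP pair can be sketched as follows. The `WikiPage` container, its `links` mapping from passage index to linked page titles, and the toy pages are all hypothetical; only the rule (document contains a hyperlink, query comes from the first section of the linked page) comes from the description above.

```python
import random
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple


@dataclass
class WikiPage:
    """Hypothetical parsed Wikipedia article."""
    title: str
    first_section: List[str]  # sentences of the overview (first) section
    passages: List[str]       # body passages
    links: Dict[int, List[str]] = field(default_factory=dict)  # passage index -> linked titles


def make_wlp_pair(pages: Dict[str, WikiPage], source_title: str,
                  rng: random.Random) -> Optional[Tuple[str, str]]:
    """WLP: d contains a hyperlink; q is a first-section sentence of the linked page.

    Returns None when the source page has no link to a known page.
    """
    source = pages[source_title]
    candidates = [(idx, target)
                  for idx, targets in source.links.items()
                  for target in targets
                  if target in pages and pages[target].first_section]
    if not candidates:
        return None
    idx, target = rng.choice(candidates)
    query = rng.choice(pages[target].first_section)
    return query, source.passages[idx]


pages = {
    "Deep learning": WikiPage(
        title="Deep learning",
        first_section=["Deep learning uses multi-layer neural networks."],
        passages=["Deep learning is part of a broader family of machine learning methods."],
        links={0: ["Machine learning"]},
    ),
    "Machine learning": WikiPage(
        title="Machine learning",
        first_section=["Machine learning studies algorithms that learn from data."],
        passages=["Supervised learning uses labelled examples."],
    ),
}
pair = make_wlp_pair(pages, "Deep learning", random.Random(0))
```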
Experiments and Results
Training Specifications
- Each tower is a 12-layer BERT-base model.
- The embedding is obtained by applying a linear layer to the [CLS] token output, giving an embedding dimension of 512.
- Sequence lengths are 64 for the query encoder and 288 for the document encoder.
- Pre-training runs on 32 TPU v3 chips for 100k steps with the Adam optimizer, a learning rate of [math]\displaystyle{ 1 \times 10^{-4} }[/math] with a 0.1 warm-up ratio followed by linear learning-rate decay, and a batch size of 8192 (about 2.5 days to complete).
- Fine-tuning on the downstream task uses a learning rate of [math]\displaystyle{ 5 \times 10^{-5} }[/math], 2000 training steps, and a batch size of 512.
- The tokenizer is WordPiece.
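The pre-training learning-rate schedule can be written out as a small function. The hyperparameter values come from the training specifications; the field names, and the assumption that the linear decay reaches zero exactly at the final step, are illustrative choices rather than details from the paper.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class PretrainConfig:
    # Values from the training specifications; the field names are illustrative.
    num_layers: int = 12
    embedding_dim: int = 512
    max_query_len: int = 64
    max_doc_len: int = 288
    total_steps: int = 100_000
    peak_lr: float = 1e-4
    warmup_ratio: float = 0.1
    batch_size: int = 8192


def learning_rate(step: int, cfg: PretrainConfig) -> float:
    """Linear warm-up to peak_lr, then linear decay (assumed to reach zero at total_steps)."""
    warmup_steps = round(cfg.warmup_ratio * cfg.total_steps)
    if step < warmup_steps:
        return cfg.peak_lr * step / warmup_steps
    remaining = max(cfg.total_steps - step, 0)
    return cfg.peak_lr * remaining / (cfg.total_steps - warmup_steps)


cfg = PretrainConfig()
schedule = [learning_rate(s, cfg) for s in (0, 5_000, 10_000, 55_000, 100_000)]
```

With these values the warm-up lasts 10,000 steps, so the rate peaks at 1e-4 at step 10,000 and falls linearly to zero by step 100,000.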
Pre-Training Tasks Setup
- MLM, ICT, BFS, and WLP are used.
- Various combinations of the aforementioned tasks are tested.
- Each task defines how the [math]\displaystyle{ (q,d) }[/math] pairs are constructed.
- For ICT, the document [math]\displaystyle{ d }[/math] given to the document encoder is the article title and the passage, separated by a [SEP] token.
Table 1 below shows statistics for the datasets constructed for the ICT, BFS, and WLP tasks. Token counts are computed after the tokenizer is applied.
Downstream Tasks Setup
- Benchmark: Retrieval Question-Answering (ReQA).
- Datasets: SQuAD and Natural Questions.
- Each data entry is a triple [math]\displaystyle{ (q,a,p) }[/math] consisting of a query [math]\displaystyle{ q }[/math], an answer [math]\displaystyle{ a }[/math], and a passage [math]\displaystyle{ p }[/math] containing the answer.
- The authors split each passage into sentences [math]\displaystyle{ s_i }[/math], where [math]\displaystyle{ i=1,\dots,n }[/math] and [math]\displaystyle{ n }[/math] is the number of sentences in the passage.
- The problem is then recast as: given a query, retrieve the correct (sentence, passage) pair [math]\displaystyle{ (s,p) }[/math] from all candidates obtained by splitting every passage in this fashion.
- The authors remark that this reformulation makes the problem more difficult.
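The sentence-level reformulation of ReQA can be sketched as follows. The regular-expression sentence splitter is a naive stand-in (the summary does not say which splitter the authors used), and treating the first sentence that contains the answer string as the positive candidate is an assumption for illustration.

```python
import re
from typing import List, Optional, Tuple

Candidate = Tuple[str, str]  # (sentence s, passage p)


def split_sentences(passage: str) -> List[str]:
    # Naive split after '.', '!' or '?' followed by whitespace.
    return [s for s in re.split(r"(?<=[.!?])\s+", passage.strip()) if s]


def build_candidates(passages: List[str]) -> List[Candidate]:
    """Every (sentence, passage) pair from every passage becomes one retrieval candidate."""
    return [(s, p) for p in passages for s in split_sentences(p)]


def positive_candidate(answer: str, passage: str) -> Optional[Candidate]:
    """Assumed labelling rule: the first sentence of p that contains the answer string a."""
    for s in split_sentences(passage):
        if answer in s:
            return (s, passage)
    return None


passages = ["Paris is in France. It is the capital.", "BERT uses WordPiece!"]
candidates = build_candidates(passages)
gold = positive_candidate("capital", passages[0])
```

Because every sentence of every passage becomes a separate candidate, the candidate pool is much larger than the number of passages, which is one way to see why the reformulated problem is harder.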
Note: Experiments on the ReQA benchmarks are not fully open-domain QA retrieval, as the candidate [math]\displaystyle{ (s,p) }[/math] pairs only cover the training set of the QA dataset rather than entire Wikipedia articles. The authors also run experiments with an augmented candidate set, shown in table 6.
Table 2 below shows statistics of the datasets used for the downstream QA task.
Evaluation
- Three train/test splits: 1%/99%, 5%/95%, and 80%/20%.
- 10% of the training set is used as a validation set for hyper-parameter tuning.
- Test queries in the [math]\displaystyle{ (q,d) }[/math] test pairs are never seen during training.
- The evaluation metric is recall@k, since the goal is to capture the positives in the top-k results.
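Recall@k can be sketched as follows: for each test query, the fraction of its relevant candidates that appear in the top k retrieved results, averaged over queries. With a single correct [math]\displaystyle{ (s,p) }[/math] per query, this is simply the fraction of queries whose answer appears in the top k. The dictionary-based input format and the candidate IDs are hypothetical.

```python
from typing import Dict, List, Set


def recall_at_k(ranked: Dict[str, List[str]], relevant: Dict[str, Set[str]], k: int) -> float:
    """Mean over queries of |relevant ∩ top-k| / |relevant|."""
    per_query = []
    for qid, gold in relevant.items():
        top_k = set(ranked.get(qid, [])[:k])
        per_query.append(len(gold & top_k) / len(gold))
    return sum(per_query) / len(per_query)


ranked = {"q1": ["c3", "c1", "c2"], "q2": ["c2", "c4", "c5"]}
relevant = {"q1": {"c1"}, "q2": {"c5"}}
scores = {k: recall_at_k(ranked, relevant, k) for k in (1, 2, 3)}
```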
Results
Table 3 shows the results on the SQuAD dataset. Combining all 3 pre-training tasks gives the best performance, with the exception of R@1 for the 1%/99% train/test split. BoW-MLP is included as a baseline to justify using a transformer, which is a more complex encoder: BoW-MLP looks up unigrams from an embedding table, aggregates the embeddings with average pooling, and passes them through a shallow two-layer MLP with [math]\displaystyle{ tanh }[/math] activation to generate the final 512-dimensional query/document embeddings.
Table 4 shows the results on the Natural Questions dataset, where combining all 3 pre-training tasks gives the best performance in all testing scenarios.
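The BoW-MLP baseline can be sketched as a single tower in plain Python. The vocabulary, the embedding and hidden sizes, the random initialisation, skipping out-of-vocabulary tokens, and applying tanh only on the hidden layer are all assumptions; only the overall shape (unigram lookup, average pooling, two-layer network with tanh, 512-dimensional output) comes from the description above. A two-tower model would use two separate instances, one for queries and one for documents.

```python
import math
import random
from typing import Dict, List


class BowMLP:
    """One tower of a bag-of-words baseline: unigram lookup -> average pooling -> 2-layer MLP."""

    def __init__(self, vocab: List[str], embed_dim: int = 64, hidden_dim: int = 256,
                 out_dim: int = 512, seed: int = 0) -> None:
        rng = random.Random(seed)

        def matrix(rows: int, cols: int) -> List[List[float]]:
            scale = 1.0 / math.sqrt(cols)
            return [[rng.gauss(0.0, scale) for _ in range(cols)] for _ in range(rows)]

        self.embed_dim = embed_dim
        self.table: Dict[str, List[float]] = {
            word: [rng.gauss(0.0, 0.1) for _ in range(embed_dim)] for word in vocab
        }
        self.w1, self.b1 = matrix(hidden_dim, embed_dim), [0.0] * hidden_dim
        self.w2, self.b2 = matrix(out_dim, hidden_dim), [0.0] * out_dim

    @staticmethod
    def _affine(w: List[List[float]], b: List[float], x: List[float]) -> List[float]:
        return [sum(wij * xj for wij, xj in zip(row, x)) + bi for row, bi in zip(w, b)]

    def encode(self, text: str) -> List[float]:
        # Unigram lookup; out-of-vocabulary tokens are skipped (assumption).
        vecs = [self.table[t] for t in text.lower().split() if t in self.table]
        # Average pooling over the looked-up embeddings (zeros if nothing matched).
        if vecs:
            pooled = [sum(col) / len(vecs) for col in zip(*vecs)]
        else:
            pooled = [0.0] * self.embed_dim
        hidden = [math.tanh(h) for h in self._affine(self.w1, self.b1, pooled)]
        return self._affine(self.w2, self.b2, hidden)


query_tower = BowMLP(["machine", "learning", "retrieval"], seed=1)
embedding = query_tower.encode("Machine learning")
```

Because average pooling discards word order, this encoder gives identical embeddings for any reordering of the same words, one of the limitations a transformer encoder does not share.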
Ablation Study
The embedding dimension, the number of BERT layers, and the pre-training tasks are varied in this ablation study, as seen in table 5.
Interestingly, ICT+BFS+WLP is only 1.5% (absolute) better than ICT alone in the low-data regime. The authors infer that when there is insufficient downstream training data, more global pre-training tasks, such as those offered by BFS and WLP, become increasingly beneficial.
Evaluation of Open-Domain Retrieval
The authors augment the ReQA benchmark with large-scale (sentence, evidence passage) pairs extracted from general Wikipedia, adding one million external candidate pairs to the existing retrieval candidate set of the ReQA benchmark. The results here are fairly consistent with the previous results.
Conclusion
The authors performed a comprehensive study and concluded that, with properly designed paragraph-level pre-training tasks including ICT, BFS, and WLP, a two-tower transformer model can significantly outperform BM-25, an unsupervised method that uses weighted token matching as its scoring function. The findings suggest that a two-tower transformer model with proper pre-training tasks can replace BM-25 in the retrieval stage to further improve end-to-end system performance. As future work, the authors plan to test the pre-training tasks with a larger variety of encoder architectures, to try corpora other than Wikipedia, and to compare pre-training against different regularization methods.
Critique
The authors of this paper suggest from their experiments that the combination of the three pre-training tasks ICT, BFS, and WLP yields the best results in general. However, they do not examine whether a combination of only two of the three pre-training tasks could produce equally strong results. The ablation study shows the three tasks outperforming ICT alone by only a small margin (1.5%) in the low-data regime, and table 6 shows ICT outperforming the combination of all three in certain scenarios. Given this, it seems plausible that a subset of two tasks could reach similar performance or even outperform all three. I believe this comparison should be addressed before concluding that three pre-training tasks are needed rather than two; including it would make the paper more complete.
References
[1] Wei-Cheng Chang, Felix X Yu, Yin-Wen Chang, Yiming Yang, and Sanjiv Kumar. Pre-training tasks for embedding-based large-scale retrieval. arXiv preprint arXiv:2002.03932, 2020.
[2] Stephen Robertson, Hugo Zaragoza, et al. The probabilistic relevance framework: BM25 and beyond. Foundations and Trends in Information Retrieval, 3(4):333–389, 2009.