Language models typically have to solve two tough problems: mapping sentence prefixes to fixed-size representations, and using those representations to predict the next word in a text. In a recent paper, researchers at Facebook AI Research argue that the first problem, the mapping problem, might be easier than the prediction problem. They build on that hypothesis to augment language models with a “nearest neighbors” retrieval mechanism. They say the approach allows rare patterns to be memorized and achieves a state-of-the-art perplexity score (a measure of how well a model predicts held-out text, where lower is better) with no additional training.
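Perplexity is the standard benchmark metric behind state-of-the-art claims for language models. It is the exponential of the average negative log-probability that a model assigns to each actual next token in held-out text. The minimal sketch below computes it from a list of per-token probabilities. The function name `perplexity` is illustrative and is not taken from the paper.

```python
import math

def perplexity(token_probs):
    """exp(mean negative log-probability of the true next tokens).
    Lower is better: a perfect model (probability 1 everywhere) scores 1."""
    if not token_probs:
        raise ValueError("need at least one token probability")
    nll = -sum(math.log(p) for p in token_probs) / len(token_probs)
    return math.exp(nll)

# A model that gives every correct token probability 1/4 is, on average,
# as uncertain as a uniform choice among 4 words.
print(perplexity([0.25, 0.25, 0.25]))  # roughly 4.0
```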
As the researchers explain, language models assign probabilities to sequences of words. Given a context sequence of tokens (e.g., words), they estimate a probability distribution over the possible next, or target, token. The proposed approach, kNN-LM, maps a context to a fixed-length mathematical representation computed by a pre-trained language model. For each example in the training data, it then stores a key-value pair, where the key is the representation of the context and the value is the target word that follows it.
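The key-value construction can be sketched in a few lines. Every training position yields one pair: the key is a vector for the prefix and the value is the token that comes next. The `encode_context` function is a hypothetical stand-in for the pre-trained model's context representation. It embeds only the last token, so the sketch runs without a real model. `build_datastore` is likewise an illustrative name.

```python
import numpy as np

def encode_context(context_ids, dim=8):
    """HYPOTHETICAL stand-in for a pre-trained LM's fixed-size context vector.
    It derives a deterministic pseudo-random vector from the last token only;
    a real system would use the language model's hidden state instead."""
    rng = np.random.default_rng(context_ids[-1])
    return rng.standard_normal(dim).astype(np.float32)

def build_datastore(token_ids, encode=encode_context):
    """One (key, value) pair per training position i >= 1:
    key = representation of the prefix token_ids[:i], value = token_ids[i]."""
    keys, values = [], []
    for i in range(1, len(token_ids)):
        keys.append(encode(token_ids[:i]))
        values.append(token_ids[i])
    return np.stack(keys), np.array(values, dtype=np.int64)

train = [5, 3, 9, 3, 9, 7]          # toy token ids
keys, values = build_datastore(train)
print(keys.shape, values.tolist())  # (5, 8) [3, 9, 3, 9, 7]
```

Keys are stacked into a single matrix so that a later nearest-neighbor search can compare a query against all of them at once.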
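At test time, kNN-LM feeds an input context to the language model, which produces both a distribution over possible next words and a representation of the context. Using that representation as a query, kNN-LM retrieves the nearest stored keys according to a distance function. It then computes a distribution over the neighbors' target words, summing the probability for each vocabulary item across all of its occurrences among the retrieved targets. Finally, this nearest-neighbor distribution is interpolated with the language model's own next-word distribution to produce the prediction.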
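The test-time steps can be sketched as follows. This version assumes squared Euclidean distance, a softmax over negative distances to weight neighbors, and a fixed interpolation weight `lam`. The exact choices and values in the paper may differ. The function names, the toy datastore, and the uniform stand-in for the language model's distribution are all hypothetical.

```python
import numpy as np

def knn_distribution(query, keys, values, vocab_size, k=4):
    """Retrieve the k nearest keys to `query` and turn them into a
    next-token distribution over the vocabulary."""
    # Squared Euclidean distance from the query to every stored key.
    dists = np.sum((keys - query) ** 2, axis=1)
    nn = np.argsort(dists)[:k]                # indices of the k closest keys
    # Softmax over negative distances: closer neighbors get more weight.
    logits = -dists[nn]
    weights = np.exp(logits - logits.max())
    weights /= weights.sum()
    # Aggregate: a token that is the target of several neighbors receives
    # the sum of those neighbors' weights (np.add.at handles repeats).
    p_knn = np.zeros(vocab_size)
    np.add.at(p_knn, values[nn], weights)
    return p_knn

def knn_lm(p_lm, p_knn, lam=0.25):
    """Interpolate the kNN distribution with the base LM's distribution."""
    return lam * p_knn + (1.0 - lam) * p_lm

# Toy datastore: 2-D keys whose values are token ids in a vocabulary of 4.
keys = np.array([[0.0, 0.0], [0.1, 0.0], [5.0, 5.0], [0.0, 0.2]])
values = np.array([2, 2, 1, 3])
query = np.array([0.0, 0.05])
p_knn = knn_distribution(query, keys, values, vocab_size=4, k=3)
p_lm = np.full(4, 0.25)                       # uninformative stand-in LM
p = knn_lm(p_lm, p_knn)
print(p_knn.round(3))  # most kNN mass goes to token 2 (two of three neighbors)
```

Because identical targets pool their weight, a word seen after many similar contexts in training can receive a high probability even if the base model alone would rate it as unlikely.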
The researchers note that kNN-LM is compatible with any language model that produces fixed-size context representations. In the study, they applied it to a Transformer-based language model trained on a data set of 103 million tokens from Wikipedia articles, with 250,000 tokens reserved for development and testing.
In experiments, kNN-LM “significantly” outperformed the baselines at test time. The team attributes this to the underlying language model having learned a representation function for contexts with an implicit notion of similarity. The approach does add computational overhead: building a cache of 103 million entries took roughly two hours on a single processor, and running the validation set took approximately 25 minutes. But the team points out that this work is “trivial” to parallelize and requires no GPU-based training.
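One reason nearest-neighbor retrieval parallelizes easily can be shown with a minimal sketch. If the stored keys are split into shards, each shard can be searched independently. Merging the per-shard top-k lists gives the same answer as searching everything at once, because every global nearest neighbor is also among the nearest neighbors in its own shard. This sketch assumes exact brute-force search and a thread pool. The shard layout and function names are hypothetical, and a production system would more likely use an approximate nearest-neighbor library.

```python
import numpy as np
from concurrent.futures import ThreadPoolExecutor

def search_shard(args):
    """Exact k-NN within one shard; returns (distances, global indices)."""
    query, shard_keys, offset, k = args
    d = np.sum((shard_keys - query) ** 2, axis=1)
    local = np.argsort(d)[:k]
    return d[local], local + offset

def sharded_knn(query, keys, k=4, n_shards=4):
    """Split the keys into shards, search them independently, then merge.
    No shard's search depends on another, so shards can run on separate workers."""
    bounds = np.linspace(0, len(keys), n_shards + 1, dtype=int)
    jobs = [(query, keys[lo:hi], lo, k) for lo, hi in zip(bounds[:-1], bounds[1:])]
    with ThreadPoolExecutor(max_workers=n_shards) as pool:
        results = list(pool.map(search_shard, jobs))
    dists = np.concatenate([d for d, _ in results])
    idx = np.concatenate([i for _, i in results])
    best = np.argsort(dists)[:k]  # global top-k from the per-shard candidates
    return idx[best]

rng = np.random.default_rng(0)
keys = rng.standard_normal((1000, 16))
query = rng.standard_normal(16)
brute = np.argsort(np.sum((keys - query) ** 2, axis=1))[:5]
print(np.array_equal(sharded_knn(query, keys, k=5), brute))  # True
```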
“In general, we find that examples where kNN-LM is most helpful typically contain rare patterns,” the coauthors of the study wrote. “Examples include factual knowledge, names, and near-duplicate sentences from the training set. In these cases, assigning train and test instances similar representations … appears to be an easier problem than implicitly memorizing the next word in model parameters.”