Fast Translation with Non-autoregressive Generation
by Jungo Kasai
Why Non-autoregressive Machine Translation?
A flurry of recent work has proposed approaches to non-autoregressive machine translation (NAT, Gu et al. 2018). NAT generates target words in parallel, in contrast to standard autoregressive translation (AT), which predicts every word conditioned on all previous ones. While AT generally performs better than NAT with a similar configuration, NAT speeds up inference through parallel computation. One very successful application of such non-autoregressive generation is Parallel WaveNet (Oord et al. 2017), which speeds up the original autoregressive WaveNet by more than 1000 times and is deployed in the Google Assistant. The faster inference of NAT could also allow us to deploy bigger and deeper transformer models under a given latency budget in production. In this blog post, I will give an overview of recent research on non-autoregressive translation and discuss what I believe is missing or important for further development.
Fundamental Problem and Possible Remedies
Multimodality in generation presents a fundamental challenge in NAT. We all know that language is highly multimodal. To see this with a minimal example, "he is very good at Japanese" and "he speaks Japanese very well" are both valid translations of the Japanese sentence 彼は日本語が上手です. However, it would not make sense to output a mixture of the two, such as "he speaks very good at Japanese" or "he is very good at very well."
We need to know which of the two possible translations (or, more generally, which mode) the model is committing to, but this cannot easily be achieved with conditionally independent decoding: parallel decoding breaks the conditional dependence between output words and often leads to inconsistent outputs.
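To make the problem concrete, here is a small sketch (my own illustration, not taken from any of the cited papers). It assumes an idealized parallel decoder that knows the exact per-position marginals of a 50/50 mixture of two translations, and computes how often sampling every position independently yields one of the two valid sentences. To give both modes the same length, the first translation is shortened to "he is good at Japanese"; the function names `marginals` and `prob_valid` are hypothetical.

```python
from itertools import product
from typing import Dict, List, Sequence


def marginals(modes: Sequence[Sequence[str]], weights: Sequence[float]) -> List[Dict[str, float]]:
    """Per-position token distributions of a mixture of equal-length sentences."""
    dists: List[Dict[str, float]] = []
    for t in range(len(modes[0])):
        dist: Dict[str, float] = {}
        for mode, w in zip(modes, weights):
            dist[mode[t]] = dist.get(mode[t], 0.0) + w
        dists.append(dist)
    return dists


def prob_valid(modes: Sequence[Sequence[str]], weights: Sequence[float]) -> float:
    """Probability that sampling each position independently reproduces one of the modes."""
    valid = {tuple(m) for m in modes}
    total = 0.0
    # Enumerate every output a conditionally independent decoder can produce.
    for combo in product(*(d.items() for d in marginals(modes, weights))):
        tokens = tuple(tok for tok, _ in combo)
        p = 1.0
        for _, q in combo:
            p *= q
        if tokens in valid:
            total += p
    return total


MODES = ["he is good at Japanese".split(), "he speaks Japanese very well".split()]
print(prob_valid(MODES, [0.5, 0.5]))  # 0.125
```

Even with perfect marginals, seven out of eight samples mix the two modes (e.g., "he is Japanese very well"): the problem is the independence assumption, not the accuracy of each position's prediction.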
Several works in the literature have developed ways to address this multimodality problem in NAT.
Here I roughly summarize and categorize the proposed approaches.
Iterative Refinement
One way to remedy the issue in parallel decoding is to refine the model output iteratively (Lee et al. 2018; Ghazvininejad et al. 2019; Gu et al. 2019; Kasai et al. 2020). In this framework, we give up on fully parallel generation and instead refine the previously generated words in each iteration. Since we typically need far fewer iterations than there are words in the output sentence, an iterative method can still improve latency over autoregressive models. Each of these papers takes a different approach to refinement, but you might find the conditional masked language models (CMLMs) of Ghazvininejad et al. 2019 the easiest to understand. CMLMs are trained with a BERT-style masked language modeling objective on the target side given the source text; at inference time, we mask the output tokens with low confidence and re-predict them in every iteration. In Kasai et al. 2020, we proposed the DisCo transformer, which computes an efficient alternative to this masked language modeling. In particular, the DisCo transformer can be trained to predict every output token given an arbitrary subset of the other reference tokens, which can be thought of as simulating multiple maskings at a time. We showed that the DisCo transformer can reduce the number of required iterations (and thus decoding time) while maintaining translation quality.
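The CMLM inference loop (mask-predict) can be sketched in a few lines. This is a minimal sketch, assuming a hypothetical `predict` function that maps a partially masked target of fixed length to one `(token, probability)` pair per position; a real CMLM would also condition on the source sentence and a predicted length. The linearly decaying number of re-masked tokens follows my reading of Ghazvininejad et al. 2019, and `toy_predict` is an invented stand-in model that only knows two translations of the example sentence (the first shortened to "he is good at Japanese" so that both have five tokens).

```python
from typing import Callable, List, Sequence, Tuple

MASK = "<mask>"
# Hypothetical model interface: partially masked target -> (token, prob) per position.
Predictor = Callable[[Sequence[str]], List[Tuple[str, float]]]


def mask_predict(predict: Predictor, length: int, iterations: int) -> List[str]:
    # Iteration 0: predict every position in parallel from a fully masked target.
    preds = predict([MASK] * length)
    tokens = [tok for tok, _ in preds]
    probs = [p for _, p in preds]
    for t in range(1, iterations):
        # Re-mask fewer tokens each iteration (linear decay).
        n = length * (iterations - t) // iterations
        if n == 0:
            break
        # Choose the n positions the model is least confident about.
        worst = sorted(range(length), key=lambda i: probs[i])[:n]
        masked = list(tokens)
        for i in worst:
            masked[i] = MASK
        preds = predict(masked)
        # Only re-masked positions are updated; the rest keep their tokens and scores.
        for i in worst:
            tokens[i], probs[i] = preds[i]
    return tokens


# Toy "model" (invented for illustration) that knows two 5-token translations.
MODES = ["he is good at Japanese".split(), "he speaks Japanese very well".split()]
# Per-position preference for the first mode; the second mode gets the rest.
PRIOR = [0.5, 0.6, 0.4, 0.6, 0.4]


def toy_predict(tokens: Sequence[str]) -> List[Tuple[str, float]]:
    # Every observed token that agrees with a mode multiplies that mode's weight by 4.
    weights = [4 ** sum(t == m for t, m in zip(tokens, mode)) for mode in MODES]
    out = []
    for i in range(len(tokens)):
        scores = {}
        for k, mode in enumerate(MODES):
            prior = PRIOR[i] if k == 0 else 1 - PRIOR[i]
            scores[mode[i]] = scores.get(mode[i], 0.0) + prior * weights[k]
        best = max(scores, key=scores.get)
        out.append((best, scores[best] / sum(scores.values())))
    return out


print(" ".join(mask_predict(toy_predict, length=5, iterations=1)))  # he is Japanese at well
print(" ".join(mask_predict(toy_predict, length=5, iterations=3)))  # he speaks Japanese very well
```

With a single iteration (fully parallel decoding), the toy model mixes the two modes; re-masking its least confident tokens lets later iterations condition on the observed "well" and settle on one translation.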
Training with Different (Augmented) Objectives than NLL
Several works have proposed alternatives to the negative log-likelihood (NLL) loss. My intuition is that training with the vanilla NLL loss fails to induce representations that can capture highly multimodal distributions. This is somewhat reminiscent of the motivation for the adversarial loss in generative adversarial networks (GANs): in image generation, the vanilla L2 reconstruction loss averages over modes and yields blurry images, and a similar situation may occur when an NAT model is trained with the NLL loss. Proposed alternative loss functions include the distance between the hidden states of an NAT model and those of an autoregressive teacher (Li et al. 2019), the bag-of-n-grams difference (Shao et al. 2020), and auxiliary regularization (Wang et al. 2019). This strand of approaches can achieve one-shot parallel generation, but at the cost of much worse performance than iteration-based methods.
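As one concrete example of such an objective, the bag-of-n-grams (BoN) idea of Shao et al. 2020 compares expected n-gram counts under the model with the reference n-gram counts. The sketch below is my simplified, non-differentiable reading of it in plain Python. It assumes conditionally independent per-position distributions (hypothetical dicts `probs[t]` mapping tokens to probabilities) and a hypothesis length equal to the reference length, in which case the L1 distance reduces to 2 × (number of n-grams − overlap). An actual training loss would compute the same quantity on model tensors so that it can be backpropagated.

```python
from collections import Counter
from typing import Dict, List, Sequence, Tuple


def expected_ngram_count(probs: Sequence[Dict[str, float]], ngram: Tuple[str, ...]) -> float:
    """Expected occurrences of `ngram` when each position is sampled independently."""
    n = len(ngram)
    total = 0.0
    for start in range(len(probs) - n + 1):
        p = 1.0
        for offset, tok in enumerate(ngram):
            p *= probs[start + offset].get(tok, 0.0)
        total += p
    return total


def bon_l1(probs: Sequence[Dict[str, float]], reference: Sequence[str], n: int = 2) -> float:
    """L1 distance between expected and reference bags of n-grams (equal lengths assumed)."""
    assert len(probs) == len(reference), "sketch assumes hypothesis length == reference length"
    ref_counts = Counter(tuple(reference[i:i + n]) for i in range(len(reference) - n + 1))
    num_ngrams = len(reference) - n + 1
    # Both bags hold num_ngrams in total, so sum|a - b| = 2 * (num_ngrams - sum min(a, b)),
    # and min(a, b) is nonzero only for n-grams that occur in the reference.
    overlap = sum(min(expected_ngram_count(probs, g), c) for g, c in ref_counts.items())
    return 2 * (num_ngrams - overlap)


def one_hot(tokens: Sequence[str]) -> List[Dict[str, float]]:
    return [{tok: 1.0} for tok in tokens]


ref = "he speaks Japanese very well".split()
print(bon_l1(one_hot(ref), ref))                               # 0.0
print(bon_l1(one_hot("he is Japanese at well".split()), ref))  # 8.0
```

Because every term involves adjacent positions jointly, the score rewards locally consistent outputs rather than judging each position in isolation, which is the intuition behind using it alongside or instead of token-level NLL.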
Lite/Partial Autoregressive Decoding
Prior work has also proposed incorporating a lite or partial autoregressive module into an NAT model. Kaiser et al. 2018 generated a shorter sequence of latent variables autoregressively and ran parallel word predictions on top. Blockwise decoding and the Insertion Transformer generate a sentence in a partially autoregressive manner (Stern et al. 2018, 2019). Sun et al. 2019 introduced a factorized CRF layer on top of the transformer output vectors and ran fast autoregressive decoding with a beam approximation. Ran et al. 2019 introduced a lite autoregressive source reordering module to facilitate parallel target decoding. Note that they also presented results with a non-autoregressive reordering module, but its performance was much worse.
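To illustrate the partially autoregressive idea, here is a minimal sketch of the predict-verify-accept loop of blockwise parallel decoding, as I understand it from Stern et al. 2018. It assumes two hypothetical functions: `greedy_next`, the base autoregressive model's greedy choice given a prefix, and `propose`, which guesses the next k tokens from the same prefix. The toy versions below are invented for illustration; in practice, all k guesses are verified in a single parallel forward pass rather than the Python loop used here.

```python
from typing import Callable, List, Sequence, Tuple

EOS = "</s>"


def blockwise_decode(
    greedy_next: Callable[[Sequence[str]], str],
    propose: Callable[[Sequence[str]], List[str]],
    max_len: int = 50,
) -> Tuple[List[str], int]:
    """Return the decoded tokens and the number of predict/verify rounds used."""
    out: List[str] = []
    rounds = 0
    while len(out) < max_len and (not out or out[-1] != EOS):
        rounds += 1
        proposal = propose(out)  # k guesses made from the same prefix
        # Verify: keep the longest prefix of the guesses that greedy decoding agrees with.
        accepted: List[str] = []
        for tok in proposal:
            if tok != greedy_next(out + accepted):
                break
            accepted.append(tok)
        if not accepted:
            # In Stern et al. the first guess comes from the base model itself,
            # so at least one token is always accepted; we emulate that here.
            accepted = [greedy_next(out)]
        if EOS in accepted:
            accepted = accepted[: accepted.index(EOS) + 1]
        out.extend(accepted)
    return out, rounds


# Toy models (invented): the base model's greedy output is TARGET, and the
# proposer guesses 3 tokens ahead but always gets "very" wrong.
TARGET = "he speaks Japanese very well </s>".split()


def toy_greedy_next(prefix: Sequence[str]) -> str:
    return TARGET[min(len(prefix), len(TARGET) - 1)]


def toy_propose(prefix: Sequence[str], k: int = 3) -> List[str]:
    guess = TARGET[len(prefix): len(prefix) + k]
    return ["good" if tok == "very" else tok for tok in guess]


tokens, rounds = blockwise_decode(toy_greedy_next, toy_propose)
print(" ".join(tokens), rounds)  # he speaks Japanese very well </s> 3
```

The output is identical to greedy decoding with the base model (six tokens), but it takes three rounds instead of six sequential steps; the speedup depends entirely on how often the parallel guesses are accepted.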
Modeling with Latent Variables
We can interpret many NAT models in this framework. For example, all NAT models that condition on a predicted length can be seen as modeling with a latent variable. More specifically, Ma et al. 2019 used generative flow to model complex distributions of target sentences, and Shu et al. 2020 developed a continuous latent-variable NAT model with deterministic inference.
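One way to see why latent variables help with multimodality: output tokens are conditionally independent given the latent variable z, but not once z is marginalized out. The following generic form uses my own notation (not taken from a specific paper), with x the source, y = (y_1, ..., y_T) the target, and z the latent variable; the second line is the special case where the predicted length plays the role of z.

```latex
\begin{align}
  p(y \mid x) &= \int p(z \mid x) \prod_{t=1}^{T} p(y_t \mid z, x) \, dz
    && \text{(general latent variable } z\text{)} \\
  p(y \mid x) &= p(T = |y| \mid x) \prod_{t=1}^{|y|} p(y_t \mid T = |y|, x)
    && \text{(predicted length as } z\text{)}
\end{align}
```

If z identifies a single mode (e.g., which of the two translations above is being produced), parallel decoding given z stays consistent. Length alone is a weak latent variable in this sense: since T is determined by y, only the term with T = |y| survives the sum, and both example translations could share the same length.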
Distillation from Autoregressive Models
As far as I know, almost all competitive NAT models are trained with sequence-level knowledge distillation (Kim & Rush 2016) from autoregressive models (e.g., Gu et al. 2018). While distillation from a larger transformer helps autoregressive machine translation as well, especially with greedy decoding, it generally benefits NAT models to a greater degree (Kasai et al. 2020). Zhou et al. 2019 examined the relationship between model capacity and distillation data, finding a correlation between the capacity of an NAT model and the optimal complexity of the distillation data. This suggests that knowledge distillation can remove some of the modes in the raw data so that an NAT model can be trained more easily.
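The intuition that distillation removes modes can be illustrated with a crude sentence-level measure: the empirical conditional entropy H(Y | X) of a parallel corpus. This is a hypothetical stand-in, not the complexity measure used by Zhou et al. 2019, and `teacher` below is an invented deterministic function standing in for an autoregressive model decoded with beam search.

```python
import math
from collections import Counter, defaultdict
from typing import Callable, Dict, List, Tuple

Pair = Tuple[str, str]


def conditional_entropy(pairs: List[Pair]) -> float:
    """Empirical H(Y | X) in bits, treating whole sentences as symbols."""
    by_source: Dict[str, Counter] = defaultdict(Counter)
    for src, tgt in pairs:
        by_source[src][tgt] += 1
    total = len(pairs)
    h = 0.0
    for counts in by_source.values():
        n_src = sum(counts.values())
        for c in counts.values():
            h -= (c / total) * math.log2(c / n_src)
    return h


def distill(pairs: List[Pair], teacher: Callable[[str], str]) -> List[Pair]:
    """Sequence-level distillation: replace each reference with the teacher's output."""
    return [(src, teacher(src)) for src, _ in pairs]


def teacher(src: str) -> str:
    # Hypothetical deterministic teacher: always produces the same translation.
    return "he is very good at Japanese"


raw = [
    ("彼は日本語が上手です", "he is very good at Japanese"),
    ("彼は日本語が上手です", "he speaks Japanese very well"),
]
print(conditional_entropy(raw))                    # 1.0
print(conditional_entropy(distill(raw, teacher)))  # 0.0
```

The raw corpus carries one bit of uncertainty about which translation to produce; the distilled corpus carries none, so a conditionally independent student no longer has to choose between modes.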
Open Questions and Future Goals
Here I highlight open questions in non-autoregressive machine translation that I am personally curious about.
- Do we need distillation? Distillation is admittedly a one-time training cost, but it can become expensive if we have to redo it every time we change the training data or language pair. It would be ideal if we could achieve reasonable performance with raw data.
- Do we need to predict the target length? I still find target length prediction strange. Many of the current NAT methods require target length prediction and condition on the predicted length. While length prediction gives us an opportunity to search in the latent variable space (length beam), length prediction undermines flexibility in generation.
- Can NAT outperform AT? We have seen that AT generally outperforms NAT under the same configuration. But can NAT do better? Or, more practically, can NAT significantly outperform AT under the same latency budget, given that NAT can afford a bigger configuration?
- Pretraining and NAT. It might be easier to use a large-scale pretrained masked language model in non-autoregressive machine translation than in autoregressive translation. Decoders in NAT, such as conditional masked language models, look more like BERT.
- Bridge training and inference. A gap often emerges between training and inference in the iterative NAT framework. For example, a conditional masked language model (CMLM) is trained to predict masked tokens given the other, gold tokens as observed context, whereas at inference time it conditions on its own, possibly incorrect, predictions. One very recent, successful attempt to close this gap is SMART training for CMLMs (Ghazvininejad et al. 2020), which trains the model to recover from its previous prediction errors. This kind of approach could be applicable to iteration-based NAT in general.
- Learn from Structured Prediction. Much effort has been put into structured prediction such as syntactic and semantic parsing in NLP. Can we learn from methods in structured prediction to better deal with conditional dependence in generation? The aforementioned gap between training and inference is a problem studied in syntactic parsing (e.g. dynamic oracle, Goldberg & Nivre 2012). I suspect there will be more lessons to be learned from structured prediction.
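The length beam mentioned in the question about target length can be sketched as follows: decode one candidate for each of the top-`beam` predicted lengths (in parallel, in practice) and keep the candidate with the highest average log-probability, which is how I understand the length beam in Ghazvininejad et al. 2019. The length distribution, the `decode` function, and the candidates below are hypothetical placeholders for a real NAT model.

```python
import math
from typing import Callable, Dict, List, Tuple

# A decoder maps a target length to (tokens, per-token probabilities).
Decoder = Callable[[int], Tuple[List[str], List[float]]]


def length_beam(length_probs: Dict[int, float], decode: Decoder, beam: int) -> List[str]:
    # Keep the `beam` most probable lengths.
    lengths = sorted(length_probs, key=length_probs.get, reverse=True)[:beam]
    best_score, best_tokens = -math.inf, []
    for length in lengths:
        tokens, probs = decode(length)
        # Average log-probability, so candidates are not ranked by length alone.
        score = sum(math.log(p) for p in probs) / length
        if score > best_score:
            best_score, best_tokens = score, tokens
    return best_tokens


# Invented candidates, one per length.
CANDIDATES = {
    4: ("he speaks Japanese well".split(), [0.9, 0.5, 0.6, 0.4]),
    5: ("he is Japanese at well".split(), [0.9, 0.6, 0.6, 0.6, 0.6]),
    6: ("he is very good at Japanese".split(), [0.9, 0.8, 0.9, 0.9, 0.8, 0.9]),
}
LENGTH_PROBS = {4: 0.25, 5: 0.4, 6: 0.35}

print(" ".join(length_beam(LENGTH_PROBS, lambda n: CANDIDATES[n], beam=1)))  # he is Japanese at well
print(" ".join(length_beam(LENGTH_PROBS, lambda n: CANDIDATES[n], beam=2)))  # he is very good at Japanese
```

Searching over a second length recovers a more confident, consistent translation, but the candidate set is still fixed by the length predictor, which is the loss of flexibility the open question points at.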