
Fast Translation with Non-autoregressive Generation

by Jungo Kasai

[Figure: Parallel WaveNet]

Why Non-autoregressive Machine Translation?

A flurry of recent work has proposed approaches to non-autoregressive machine translation (NAT, Gu et al. 2018). NAT generates target words in parallel, in contrast to standard autoregressive translation (AT), which predicts each word conditioned on all previous ones. While AT generally outperforms NAT under a similar configuration, NAT speeds up inference through parallel computation. One very successful application of such non-autoregressive generation is Parallel WaveNet (van den Oord et al. 2017), which speeds up the original autoregressive WaveNet by more than 1000 times and is deployed in the Google Assistant. The faster inference of NAT could also allow bigger and deeper transformer models to be deployed in production under a given latency budget. In this blog post, I give an overview of recent research on non-autoregressive translation and discuss what I believe is missing or important for further development.

Fundamental Problem and Possible Remedies

Multimodality in generation presents a fundamental challenge for NAT. Language is highly multimodal. As a minimal example, "he is very good at Japanese" and "he speaks Japanese very well" are both valid translations of the Japanese sentence 彼は日本語が上手です. However, it would not make sense to say something like "he speaks very good at Japanese" or "he is very good at very well". The model needs to know which of the two possible translations (or, more generally, modes) it is committing to, but this cannot easily be achieved with conditionally independent decoding: parallel decoding drops the dependencies between output words and often leads to inconsistent outputs. Several works in the literature have developed ways to address this multimodality problem in NAT. Here I roughly summarize and categorize the proposed approaches.
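To see concretely why conditionally independent decoding struggles, here is a toy sketch under an idealized assumption: an NAT model has learned the exact per-position marginals of a 50/50 mixture of the two translations above. The shorter translation is padded with a hypothetical `<pad>` token so both have six positions, and the function names are illustrative only.

```python
from collections import Counter
from math import prod


def positionwise_marginals(modes):
    """Per-position token distributions of a mixture of equal-length sentences.

    modes: list of (tokens, weight) pairs whose weights sum to 1.
    """
    length = len(modes[0][0])
    marginals = [Counter() for _ in range(length)]
    for tokens, weight in modes:
        for i, tok in enumerate(tokens):
            marginals[i][tok] += weight
    return marginals


def factorized_prob(marginals, tokens):
    # Conditionally independent decoding scores a sequence as the product of
    # per-position probabilities; a Counter returns 0 for unseen tokens.
    return prod(m[tok] for m, tok in zip(marginals, tokens))


# Two valid translations of 彼は日本語が上手です, padded to the same length.
modes = [
    ("he is very good at Japanese".split(), 0.5),
    ("he speaks Japanese very well <pad>".split(), 0.5),
]
marginals = positionwise_marginals(modes)

valid_mass = sum(factorized_prob(marginals, tokens) for tokens, _ in modes)
mixed = "he speaks very good at Japanese".split()
print(valid_mass)                         # 0.0625
print(factorized_prob(marginals, mixed))  # 0.03125
```

Even with perfect marginals, the two valid translations together receive only 6.25% of the probability mass, and the inconsistent "he speaks very good at Japanese" scores exactly as high as either valid one.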

Iteration-based Methods

One way to remedy the problem of parallel decoding is to refine the model output iteratively (Lee et al., 2018; Ghazvininejad et al., 2019; Gu et al. 2019; Kasai et al. 2020). In this framework, we give up on fully parallel generation and instead refine the previously generated words in each iteration. Since we typically need far fewer iterations than there are words in the output sentence, an iterative method can still reduce latency compared to autoregressive models. Each of these papers takes a different approach to refinement, but you might find the conditional masked language models (CMLMs) of Ghazvininejad et al., 2019 easy to understand. CMLMs are trained with a BERT-style masked language modeling objective on the target side given the source text; at inference time, we mask the output tokens with low confidence and re-predict them in every iteration. In Kasai et al. 2020, we proposed the DisCo transformer, which provides an efficient alternative to this masked language modeling. In particular, the DisCo transformer can be trained to predict every output token given an arbitrary subset of the other reference tokens, which can be thought of as simulating multiple maskings at once. We showed that the DisCo transformer can reduce the number of required iterations (and thus decoding time) while maintaining translation quality.
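The CMLM inference loop described above (mask-predict) can be sketched in plain Python. This is a minimal sketch under stated assumptions: `predict` is a hypothetical stand-in for a trained CMLM that returns one (token, probability) pair per position, the target length is given rather than predicted, and the number of re-masked tokens decays linearly as N * (T - t) / T over T iterations.

```python
from typing import Callable, List, Tuple

MASK = "<mask>"

# Hypothetical interface: given a partially masked target, return one
# (token, probability) pair per position, predicted in parallel.
Predictor = Callable[[List[str]], List[Tuple[str, float]]]


def mask_predict(predict: Predictor, length: int, iterations: int) -> List[str]:
    # Iteration 0: every position is masked and predicted in parallel.
    preds = predict([MASK] * length)
    tokens = [tok for tok, _ in preds]
    probs = [p for _, p in preds]
    for t in range(1, iterations):
        # The number of re-masked tokens decays linearly: N * (T - t) / T.
        num_mask = length * (iterations - t) // iterations
        if num_mask == 0:
            break
        # Re-mask the least confident positions (ties broken by position).
        masked = sorted(range(length), key=lambda i: probs[i])[:num_mask]
        partial = list(tokens)
        for i in masked:
            partial[i] = MASK
        preds = predict(partial)
        # Only masked positions are updated; the rest keep their tokens and scores.
        for i in masked:
            tokens[i], probs[i] = preds[i]
    return tokens


# Toy stand-in for a trained CMLM (hypothetical): with no visible context it
# gets every other word wrong with low confidence; with context it is right.
TARGET = "he is very good at Japanese".split()


def toy_predict(partial: List[str]) -> List[Tuple[str, float]]:
    no_context = all(tok == MASK for tok in partial)
    out = []
    for i, tok in enumerate(partial):
        if tok != MASK:
            out.append((tok, 1.0))  # echoed back; mask_predict ignores these
        elif no_context and i % 2 == 1:
            out.append(("??", 0.2))
        elif no_context:
            out.append((TARGET[i], 0.9))
        else:
            out.append((TARGET[i], 0.8))
    return out


print(mask_predict(toy_predict, len(TARGET), iterations=1))
# ['he', '??', 'very', '??', 'at', '??']
print(mask_predict(toy_predict, len(TARGET), iterations=3))
# ['he', 'is', 'very', 'good', 'at', 'Japanese']
```

In this toy run, the first parallel pass gets half the words wrong with low confidence; re-masking and re-predicting the least confident positions fixes the sentence in three model calls, versus six sequential steps for left-to-right decoding of the same six words.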

Latency and quality comparison between the CMLM and DisCo transformer on the WMT'17 en→zh test data with respect to the standard autoregressive model. T denotes the max number of iterations.
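The DisCo training signal, predicting each token from an arbitrary subset of the other reference tokens, can be pictured as a per-token visibility matrix: row i marks which reference tokens are observed when predicting token i. The sketch below only generates such a masking pattern, not the DisCo architecture itself. The function name is hypothetical, and the sampling scheme (a context size drawn uniformly, then a random subset of the other positions) is an assumption for illustration.

```python
import random
from typing import List


def sample_context_masks(length: int, rng: random.Random) -> List[List[bool]]:
    """Row i marks which reference tokens are observed when predicting token i.

    Assumption: for each token, draw a context size uniformly from
    0..length-1, then choose that many of the *other* positions at random.
    """
    masks = []
    for i in range(length):
        others = [j for j in range(length) if j != i]
        size = rng.randint(0, len(others))  # inclusive on both ends
        visible = set(rng.sample(others, size))
        masks.append([j in visible for j in range(length)])
    return masks


rng = random.Random(0)
for row in sample_context_masks(6, rng):
    # 'x' = observed reference token, '.' = hidden (the diagonal is always hidden)
    print("".join("x" if seen else "." for seen in row))
```

A CMLM training example uses a single mask shared by the whole sentence, whereas each row here is a different masking of the same reference. If a model can condition each position on its own row in one pass, a single example covers several maskings at once.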

Training with Objectives Other Than (or in Addition to) NLL

Several works have proposed loss functions that replace or augment the negative log-likelihood (NLL). My intuition is that training with the vanilla NLL loss fails to induce representations that can capture highly multimodal distributions. This is somewhat reminiscent of the motivation for the adversarial loss in generative adversarial networks (GANs): in image generation, the vanilla L2 reconstruction loss would collapse modes and yield blurry images, and a similar situation may occur when an NAT model is trained with the NLL loss. Proposed alternatives include the distance between the hidden states of an NAT model and an autoregressive teacher (Li et al. 2019), the Bag-of-Ngrams difference (Shao et al. 2020), and auxiliary regularization (Wang et al. 2019). This strand of approaches can achieve one-shot parallel generation, but at the cost of much worse performance than iteration-based methods.
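For the Bag-of-Ngrams difference, the key quantity is the expected count of each n-gram under the model's per-position distributions. Below is a minimal plain-Python sketch, assuming every position holds a normalized distribution and the output length equals the reference length. Under that assumption, both bags of n-grams have total mass len - n + 1, so the L1 distance reduces to a sum over reference n-grams only. The function names are hypothetical; in actual training this would be computed on tensors so that gradients flow to the probabilities.

```python
from collections import Counter
from math import prod
from typing import Dict, List, Sequence, Tuple

Dist = Dict[str, float]  # one normalized token distribution per output position


def ngrams(tokens: Sequence[str], n: int) -> List[Tuple[str, ...]]:
    return [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]


def expected_count(dists: List[Dist], gram: Tuple[str, ...]) -> float:
    # Positions are independent, so the chance that `gram` starts at position t
    # is the product of the per-position probabilities of its tokens.
    n = len(gram)
    return sum(
        prod(dists[t + k].get(gram[k], 0.0) for k in range(n))
        for t in range(len(dists) - n + 1)
    )


def bon_l1(dists: List[Dist], reference: Sequence[str], n: int = 2) -> float:
    # Assumes len(dists) == len(reference): both bags then have total mass
    # len - n + 1, so |E - R|_1 = 2 * (total - sum_g min(E[g], R[g])) and
    # only n-grams that occur in the reference need to be scored.
    ref = Counter(ngrams(reference, n))
    total = len(reference) - n + 1
    overlap = sum(min(expected_count(dists, g), c) for g, c in ref.items())
    return 2 * (total - overlap)


dists = [{"a": 0.7, "b": 0.3}, {"a": 0.4, "b": 0.6}, {"a": 0.5, "b": 0.5}]
print(bon_l1(dists, ["a", "b", "a"], n=2))  # ≈ 1.92
```

Because the distance is computed over whole bags of n-grams rather than position by position, it rewards the right words in the right local order without penalizing a translation for committing to a different but valid word position.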

Lite/Partial Autoregressive Decoding

Prior work has also proposed incorporating a lite or partial autoregressive module into an NAT model. Kaiser et al. 2018 generated a shorter sequence of latent variables autoregressively and then predicted the words in parallel on top of them. Blockwise decoding and the Insertion Transformer generate a sentence in a partially autoregressive manner (Stern et al. 2018, 2019). Sun et al. 2019 introduced a factorized CRF layer on top of the transformer output vectors and ran fast autoregressive decoding with a beam approximation. Ran et al. 2019 introduced a lite autoregressive source-reordering module to facilitate parallel target decoding. Note that they also presented results with a non-autoregressive reordering module, but its performance was much worse.
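Blockwise decoding is a good example of the partially autoregressive idea: propose a block of k tokens in parallel, verify them against the base model's greedy predictions, and keep the longest agreeing prefix. The sketch below follows that outline in the spirit of Stern et al. 2018. Both model interfaces (`greedy_next`, `propose_block`) are hypothetical stand-ins, and the toy models at the end are made up for illustration.

```python
from typing import Callable, List, Tuple


def blockwise_parallel_decode(
    greedy_next: Callable[[List[str]], str],
    propose_block: Callable[[List[str]], List[str]],
    max_len: int,
    eos: str = "</s>",
) -> Tuple[List[str], int]:
    """Return the decoded tokens and the number of decoding steps taken.

    greedy_next(prefix): the base model's greedy next token (hypothetical).
    propose_block(prefix): k tokens guessed in parallel (hypothetical).
    """
    out: List[str] = []
    steps = 0
    while len(out) < max_len and (not out or out[-1] != eos):
        steps += 1
        block = propose_block(out)
        # Verify: keep the longest prefix of the block that the base model agrees
        # with. (In the real method these checks share one parallel forward pass.)
        accepted: List[str] = []
        for tok in block:
            if greedy_next(out + accepted) != tok:
                break
            accepted.append(tok)
            if tok == eos:
                break
        if not accepted:
            # Always make progress with the base model's own prediction.
            accepted = [greedy_next(out)]
        out.extend(accepted)
    return out[:max_len], steps


# Toy models (hypothetical): the proposer guesses two words correctly and then
# one wrong word, so two tokens are accepted per step.
TARGET = "he is very good at Japanese </s>".split()


def greedy_next(prefix: List[str]) -> str:
    return TARGET[len(prefix)]


def propose_block(prefix: List[str]) -> List[str]:
    i = len(prefix)
    return TARGET[i:i + 2] + ["??"]


output, steps = blockwise_parallel_decode(greedy_next, propose_block, max_len=20)
print(" ".join(output), steps)  # he is very good at Japanese </s> 4
```

Greedy autoregressive decoding would need seven steps for these seven tokens; here four suffice, and the output is identical to what the base model would produce greedily, since every accepted token was checked against it.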

Modeling with Latent Variables

Many of the models above can also be interpreted in this framework. For example, any NAT model that conditions on a predicted length can be seen as modeling with a latent variable. More explicitly, Ma et al. 2019 used generative flows to model complex distributions of target sentences, and Shu et al. 2020 developed a continuous latent-variable NAT model with deterministic inference.
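Schematically (a generic form, not the exact objective of any paper cited here), a latent-variable NAT model keeps output words conditionally independent given a latent z, so z has to carry the choice of mode. The first line below is the general form; the second shows why length-conditioned NAT fits the same view: the length L is a latent that the target determines, so the marginalization collapses to a single term.

```latex
\begin{aligned}
p_\theta(y \mid x) &= \int p_\theta(z \mid x) \prod_{t=1}^{|y|} p_\theta(y_t \mid z, x)\, dz, \\
p_\theta(y \mid x) &= p_\theta\big(L = |y| \mid x\big) \prod_{t=1}^{|y|} p_\theta\big(y_t \mid L = |y|, x\big).
\end{aligned}
```

For a continuous z the integral is typically intractable, so in practice such models are trained through a bound or approximation rather than the exact likelihood.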

Distillation from Autoregressive Models

As far as I know, almost all competitive NAT models are trained with sequence-level knowledge distillation (Kim & Rush 2016) from autoregressive models (e.g. Gu et al. 2018). Distillation from a larger transformer also helps autoregressive machine translation, especially with greedy decoding, but it generally benefits NAT models to a greater degree (Kasai et al. 2020). Zhou et al. 2019 examined the relationship between model capacity and distillation data, finding a correlation between model capacity and the complexity of the distillation data. This suggests that knowledge distillation can remove some of the modes in the raw data, making it easier to train an NAT model.
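The mode-removal intuition can be made concrete with a toy corpus. In this sketch, sequence-level distillation simply replaces each reference with a teacher's output. A hypothetical deterministic teacher always picks one translation, so the per-source target entropy (a crude multimodality proxy assumed here for illustration, not the complexity measure used by Zhou et al.) drops from one bit to zero.

```python
from collections import Counter, defaultdict
from math import log2
from typing import Callable, List, Tuple

Pair = Tuple[str, str]  # (source, target)


def distill(corpus: List[Pair], teacher: Callable[[str], str]) -> List[Pair]:
    # Sequence-level KD: replace each reference with the teacher's output
    # (e.g., from beam search); the student is then trained on these pairs.
    return [(src, teacher(src)) for src, _ in corpus]


def avg_target_entropy(corpus: List[Pair]) -> float:
    # Crude multimodality proxy (an assumption for illustration): entropy in
    # bits of the targets paired with each distinct source, averaged over sources.
    by_src = defaultdict(Counter)
    for src, tgt in corpus:
        by_src[src][tgt] += 1
    entropies = []
    for counts in by_src.values():
        total = sum(counts.values())
        entropies.append(sum(-(c / total) * log2(c / total) for c in counts.values()))
    return sum(entropies) / len(entropies)


def teacher(src: str) -> str:
    # Hypothetical deterministic AT teacher that always commits to one mode.
    return "he is very good at Japanese"


raw = [
    ("彼は日本語が上手です", "he is very good at Japanese"),
    ("彼は日本語が上手です", "he speaks Japanese very well"),
]
print(avg_target_entropy(raw))                    # 1.0
print(avg_target_entropy(distill(raw, teacher)))  # 0.0
```

After distillation, the student never sees the competing mode, so its per-position marginals are no longer split between "is very good at" and "speaks ... very well".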

Open Questions and Future Goals

Here I highlight open questions in non-autoregressive machine translation that I am personally curious about.

Further Resources