Large Transformer Model Inference Optimization
Large transformer models are mainstream nowadays, creating SoTA results for a variety of tasks. They are powerful but very expensive to train and use. The extremely high inference cost, in both time and memory, is a big bottleneck for adopting a powerful transformer for solving real-world tasks at scale.
Why is it hard to run inference for large transformer models? Besides the increasing size of SoTA models, there are two main factors contributing to the inference challenge (Pope et al. 2022):
- Large memory footprint . Both model parameters and intermediate states are needed in memory at inference time. For example,
- The KV cache should be stored in memory during decoding time; E.g. For a batch size of 512 and context length of 2048, the KV cache totals 3TB, that is 3x the model size (!).
- Inference cost from the attention mechanism scales quadratically with input sequence length.
- Low parallelizability. Inference generation is executed in an autoregressive fashion, making the decoding process hard to parallel.
In this post, we will look into several approaches for making transformer inference more efficient. Some are general network compression methods, while others are specific to transformer architecture.
Methods Overview
We in general consider the following as goals for model inference optimization:
- Reduce the memory footprint of the model by using fewer GPU devices and less GPU memory;
- Reduce the desired computation complexity by lowering the number of FLOPs needed;
- Reduce the inference latency and make things run faster.
Several methods can be used to make inference cheaper in memory or/and faster in time.
- Apply various parallelism to scale up the model across a large number of GPUs. Smart parallelism of model components and data makes it possible to run a model of trillions of parameters.
- Memory offloading to offload temporarily unused data to the CPU and read them back when needed later. This helps with memory usage but causes higher latency.
- Smart batching strategy; E.g. EffectiveTransformer packs consecutive sequences together to remove padding within one batch.
- Network compression techniques, such as pruning, quantization, distillation. A model of smaller size, in terms of parameter count or bitwidth, should demand less memory and run faster.
- Improvement specific to a target model architecture. Many architectural changes, especially those for attention layers, help with transformer decoding speed.
Check the previous post on large model training on different types of training parallelism and memory saving designs including CPU memory offloading. This post focuses on network compression techniques and architecture-specific improvement for transformer models.
Distillation
Knowledge Distillation (KD ; Hinton et al. 2015, Gou et al. 2020) is a straightforward way to build a smaller, cheaper model ("student model" ) to speed up inference by transferring skills from a pre-trained expensive model ("teacher model") into the student. There is no much restriction on how the student architecture should be constructed, except for a matched output space with the teacher in order to construct a proper learning objective.
Fig. 1. The generic framework of teacher-student knowledge distillation training. (Image source: Gou et al. 2020)
Given a dataset, a student model is trained to mimic outputs of a teacher via distillation loss. Usually a neural network has a softmax layer; For example, a LLM outputs a probability distribution over tokens. Let's denote the logits layer right before softmax as 𝑧𝑡 and 𝑧𝑠 for teacher and student models, respectively. The distillation loss minimizes the difference between two softmax outputs with a high temperature 𝑇. When ground truth labels 𝑦 are known, we can combine it with a supervised learning objective between ground truth and the student's soft logits using e.g. cross-entropy.
𝐿KD=𝐿distll(softmax(𝑧𝑡,𝑇),softmax(𝑧𝑠,𝑇))+𝜆𝐿CE(𝑦,𝑧𝑠)
where 𝜆 is a hyperparameter to balance between soft and hard learning objectives. A common choice for 𝐿distll is KL divergence / cross entropy.
A successful early trial is DistilBERT (Sanh et al. 2019) that is able to reduce the parameters of a BERT by 40% while maintaining 97% performance of BERT on fine-tuned downstream tasks and running 71% faster. The loss of pre-training DistilBERT is a combination of soft distillation loss, supervised training loss (i.e. Masked language modeling loss 𝐿MLM in the case of BERT) and a special cosine embedding loss to align the hidden state vectors between teacher and student.
Distillation can be easily combined with quantization, pruning or sparsification techniques, where the teacher model is the original full-precision, dense model and the student is quantized, pruned, or trimmed to have higher sparsity level.
Quantization
There are two common approaches for applying quantization on a deep neural network:
- Post-Training Quantization (PTQ): A model is first trained to convergence and then we convert its weights to lower precision without more training. It is usually quite cheap to implement, in comparison to training.
- Quantization-Aware Training (QAT): Quantization is applied during pre-training or further fine-tuning. QAT is able to attain better performance but requires extra computation resources and access to representative training data.
We should be aware of the gap between theoretical optimal quantization strategy and the hardware kernel support. Due to the lack of GPU kernel support for certain types of matrix multiplication (e.g. INT4 x FP16), not all the methods below result in speedup for the actual inference.
Challenges for Transformer Quantization
Many studies on Transformer model quantization have the same observation: A simple low-precision (e.g. 8-bit) post-training quantization leads to significant performance drop mainly due to the high dynamic ranges of activation and a naive activation quantization strategy fails to maintain the capacity.
Fig. 2. Only quantizing model weights to 8-bit while keeping activation at full precision (`W8A32`) achieves much better results when activations are quantized to 8-bit irrespective of whether weights are in lower precision (`W8A8` and `W32A8`). (Image source: Bondarenko et al. 2021)
Bondarenko et al. (2021) observed in a small BERT model that FFN's input and output have very different dynamic ranges due to strong outliers in the output tensor. Therefore per-tensor quantization for the FFN's residual sum is likely to cause a notable error.
As the model size continues to grow to billions of parameters, outlier features of high magnitude start to emerge in all transformer layers, causing failure of simple low-bit quantization. Dettmers et al. (2022) observed such a phenomenon for OPT models larger than 6.7B parameters. Larger models have more layers with extreme outliers and these outlier features have a significant impact on the model performance. The scale of activation outliers in a few dimensions can be ~100× larger than most of the other values.
Fig. 3. The mean zero-shot accuracy over a set of language tasks (WinoGrande, HellaSwag, PIQA, LAMBADA) of OPT models of increasing sizes. (Image source: Dettmers et al. 2022)
Post-training quantization (PTQ)
Mixed-precision quantization
The most straightforward approach for resolving the above quantization challenge is to implement quantization at different precision for weights vs activation.
GOBO (Zadeh et al. 2020) is one of the first models to apply post-training quantization on transformers (i.e. a small BERT model). It assumes that model weights of each layer follow a Gaussian distribution and therefore detects outliers by tracking mean and standard deviation per layer. Outlier features remain in original form, while other values are split into multiple bins and only corresponding bin indices of weights and the centroid values are stored.
Based on the observation that only certain activation layers (e.g. residual connections after FFN) in BERT cause big performance drop, Bondarenko et al. (2021) adopted mixed-precision quantization by using 16-bit quantization on problematic activations but 8-bit on others.
Mixed-precision quantization in LLM.int8()
(Dettmers et al. 2022) is implemented via two mixed-precision decompositions:
- Because matrix multiplication contains a set of independent inner products between row and column vectors, we can impose independent quantization per inner product: Each row and column are scaled by the absolution maximum values and then quantized to INT8.
- Outlier activation features (e.g. 20x larger than other dimensions) remain in FP16 but they represent only a tiny fraction of total weights. How to identify outliers is empirical.
Fig. 4. Two mixed-precision decompositions of `LLM.int8()`. (Image source: Dettmers et al. 2022)
Quantization at fine-grained granularity
Fig. 5. Comparison of quantization at different granularity. 𝑑 is the model size / hidden state dimension and ℎ is the number of heads in one MHSA (multi-head self-attention) component.
Naively quantizing the entire weight matrix in one layer ("per-tensor" or "per-layer" quantization) is easiest to implement but does not lead to good granularity of quantization.
Q-BERT (Shen, Dong & Ye, et al. 2020) applied group-wise quantization to a fine-tuned BERT model, treating an individual matrix 𝑊 with respect to each head in MHSA (multi-head self-attention) as one group and then applies Hessian based mixed precision quantization.
Per-embedding group (PEG) activation quantization was motivated by the observation that outlier values only appear in a few out of 𝑑 (hidden state / model size) dimensions (Bondarenko et al. 2021). Per-embedding is pretty computationally expensive. In comparison, PEG quantization splits the activation tensor into several evenly sized groups along the embedding dimension where elements in the same group share quantization parameters. To ensure all outliers are grouped together, they apply a deterministic range-based permutation of embedding dimensions, where dimensions are sorted by their value ranges.
ZeroQuant (Yao et al. 2022) uses group-wise quantization for weights, same as in Q-BERT, and token-wise quantization for activation. To avoid expensive quantization and de-quantization computation, ZeroQuant built customized kernel to fuse quantization operation with its previous operator.
Second order information for quantization
Q-BERT (Shen, Dong & Ye, et al. 2020) developed Hessian AWare Quantization (HAWQ) for its mixed-precision quantization. The motivation is that parameters with higher Hessian spectrum (i.e., larger top eigenvalues) are more sensitive to quantization and thus require higher precision. It is essentially a way to identify outliers.
In another viewpoint, the problem of quantization is an optimization problem. Given a weight matrix 𝑊 and an input matrix 𝑋 , we want to find a quantized weight matrix 𝑊^ to minimize the MSE:
𝑊^∗=argmin𝑊^|𝑊𝑋−𝑊^𝑋|
GPTQ (Frantar et al. 2022) treats the weight matrix 𝑊 as a collection of row vectors 𝑤 and applies quantization to each row independently. GPTQ iteratively quantizes more weights that are selected greedily to minimize the quantization error. The update on selected weights has a closed-form formula, utilizing Hessian matrices. Read more details in the paper and the OBQ (Optimal Brain Quantization; Frantar & Alistarh 2022) method if interested. GPTQ can reduce the bitwidth of weights in OPT-175B down to 3 or 4 bits without much performance loss, but it only applies to model weights not activation.
Outlier smoothing
It is known that activations are harder to quantize than weights in transformer models. SmoothQuant (Xiao & Lin 2022) proposed a smart solution to smooth outlier features from activations to weights via mathematically equivalent transformation and then enable quantization on both weights and activations (W8A8
). Because of this, SmoothQuant has better hardware efficiency than mixed-precision quantization.
Fig. 6. SmoothQuant migrates the scale variance from activations to weights offline to reduce the difficulty of activation quantization. Both the resulting new weight and activation matrices are easy to quantize. (Image source: Xiao & Lin 2022)
Considering a per-channel smooth factor 𝑠, SmoothQuant scales the weights according to:
𝑌=(𝑋diag(𝑠)−1)⋅(diag(𝑠)𝑊)=𝑋^𝑊^
The smoothing factor can be easily fused into previous layers' parameters offline. A hyperparameter 𝛼 controls how much we migrate the quantization difficulty from activations to weights: 𝑠=max(|𝑋𝑗|)𝛼/max(|𝑊𝑗|)1−𝛼. The paper found that 𝛼=0.5 is a sweet spot for many LLMs in the experiments. For models with more significant outliers in activation, 𝛼 can be adjusted to be larger.
Quantization-aware training (QAT)
Quantization-aware training fuses the quantization operation into the pre-training or fine-tuning process. It learns model weights in low-bit representation directly and leads to better performance at the cost of additional training time and computation.
The most straightforward approach is to fine-tune the model after quantization on a training dataset that is the same as or representative of the pre-training dataset. The training objective can be the same as the one for pre-training (e.g. NLL/MLM in general language model training) or specific to a downstream task that we care about (e.g. Cross entropy for classification).
Another approach is to consider the full-precision model as the teacher and the lower-precision model as the student, and then optimize the low-precision model with distillation loss. Distillation usually doesn't need to use the original dataset; E.g. Wikipedia dataset is a good choice and even random tokens can give decent performance gain. The Layer-by-layer Knowledge Distillation (LKD ; Yao et al. 2022) method quantizes the network layer by layer and uses its original, unquantized version as the teacher model. Given the same inputs, LKD minimizes the MSE between the multiplication with layer weights and the multiplication of quantized layer weights.
Pruning
Network pruning is to reduce the model size by trimming unimportant model weights or connections while the model capacity remains. It may or may not require re-training. Pruning can be unstructured or structured.
- Unstructured pruning is allowed to drop any weight or connection, so it does not retain the original network architecture. Unstructured pruning often does not work well with modern hardware and doesn't lead to actual inference speedup.
- Structured pruning aims to maintain the dense matrix multiplication form where some elements are zeros. They may need to follow certain pattern restrictions to work with what hardware kernel supports. Here we focus on structured pruning to achieve high sparsity in transformer models.
A routine workflow to construct a pruned network has three steps:
- Train a dense network until convergence;
- Prune the network to remove unwanted structure;
- Optionally retrain the network to recover the performance with new weights.
The idea of discovering a sparse structure within a dense model via network pruning while the sparse network can still maintain similar performance is motivated by Lottery Ticket Hypothesis (LTH ): A randomly initialized, dense, feed-forward network contains a pool of subnetworks and among them only a subset (a sparse network) are "winning tickets" which can achieve the optimal performance when trained in isolation.
How to prune?
Magnitude pruning is simplest yet quite effective pruning method - weights with smallest absolute values are trimmed. In fact, some studies (Gale et al. 2019) found that simple magnitude pruning approaches can achieve comparable or better results than complicated pruning methods , such as variational dropout (Molchanov et al. 2017) and 𝑙0 regularization (Louizos et al. 2017). Magnitude pruning is simple to apply to large models and achieves reasonably consistent performance across a wide range of hyperparameters.
Zhu & Gupta (2017) found that large sparse models were able to achieve better performance than their small but dense counterparts . They proposed Gradual Magnitude Pruning (GMP) algorithm that increases the sparsity of a network gradually over the course of training. At each training step, weights with smallest absolute values are masked to be zeros to achieve a desired sparsity level 𝑠 and masked weights do not get gradient update during back-propagation. The desired sparsity level 𝑠 goes up with more training steps. The process of GMP is sensitive to the learning rate schedule, which should be higher than what's used in dense network training, but not too high to prevent convergence.
Iterative pruning (Renda et al. 2020) iterates step 2 (prune) & step 3 (retrain) multiple times: Only a small fraction of weights are pruned and the model is retrained in each iteration. The process repeats until a desired sparsity level is reached.
How to retrain?
The retraining step can be simple fine-tuning using the same pre-training data or other task-specific datasets.
Lottery Ticket Hypothesis proposed a weight rewinding retraining technique: After pruning, the unpruned weights are reinitialized back to original values earlier in the training and then retrain with the same learning rate schedule.
Learning rate rewinding (Renda et al. 2020) only resets the learning rate back to its early value, while the unpruned weights stay unchanged since the end of the last train stage. They observed that (1) retraining with weight rewinding outperforms retraining with fine-tuning across networks and datasets and (2) learning rate rewinding matches or outperforms weight rewinding in all tested scenarios.
Sparsity
Sparsity is an effective way to scale up model capacity while keeping model inference computationally efficient. Here we consider two types of sparsity for transformers:
- Sparsified dense layers, including both self-attention and FFN layers.
- Sparse model architecture; i.e. via incorporating the Mixture-of-Experts (MoE) component.
N:M Sparsity via Pruning
N:M sparsity is a structured sparsity pattern that works well with modern GPU hardware optimization, in which 𝑁 out of every 𝑀 consecutive elements are zeros. For example, the sparse tensor core of Nvidia A100 GPU has support for 2:4 sparsity for faster inference (Nvidia 2020).
Fig. 7. A matrix of 2:4 structured sparsity and its compressed representation. (Image source: Nvidia blog)
To sparsify a dense neural network to follow a N:M structured sparsity pattern, Nvidia (2020) suggested using the three-step routine workflow for training a pruned network: train --> prune to satisfy 2:4 sparsity --> retrain.
Permuting columns can provide more options in the pruning process to maintain parameters of large magnitude or to satisfy a special restriction like N:M sparsity (Pool & Yu 2021). As long as paired axes of two matrices are permuted in the same order, the results of matrix multiplication would not change. For example,
(1) Within the self-attention module, if the same permutation order is applied on the axis 1 of query embedding matrix 𝑄 and the axis 0 of key embedding matrix 𝐾⊤, the final result of matrix multiplication of 𝑄𝐾⊤ would stay the same.
Fig. 8. Illustration of same permutation on 𝑄 (axis 1) and 𝐾⊤ (axis 0) to keep the results of a self-attention module unchanged.
(2) Within the FFN layer that contains two MLP layers and one ReLU non-linear layer, we can permute the first linear weight matrix 𝑊1 along the axis 1 and the second linear weight matrix 𝑊2 along the axis 0 in the same order.
Fig. 9. Illustration of the same permutation on 𝑊1 (axis 1) and 𝑊2 (axis 0) to keep the FFN layer's output unchanged. For simplicity, the bias terms are skipped but the same permutation should be applied on them too.
To enforce N:M structured sparsity, let's split the columns of one matrix into multiple slides of 𝑀 columns (named "stripe") and we can easily observe that both the order of columns within each stripe and the order of stripes have no effect on the N:M sparsity restriction.
Pool & Yu (2021) proposed an iterative greedy algorithm to find optimal permutation that maximizes the weight magnitude for N:M sparsity. All pairs of channels are speculatively swapped and only the swap that leads to the greatest increase in magnitude is adopted, generating a new permutation and concluding a single iteration. Greedy algorithm may only find local minima, so they introduced two techniques to escape local minima:
- Bounded regressions: In practice two random channels are swapped, up to a fixed number of times. The solution search is limited to a depth of only one channel swap to keep the search space broad and shallow.
- Narrow, deep search: Choose multiple stripes and optimize them at the same time.
Fig. 10. Algorithm of finding the best permutation for N:M sparsity greedily and iteratively. (Image source: Pool & Yu 2021)
The network can achieve better performance if it was permuted before pruning, compared to pruning the network in its default channel order.
To train a model with N:M sparsity from scratch, Zhou & Ma, et al. (2021) extended STE (Straight-Through Estimator; Bengio et al. 2013), which is commonly used for back-propagation update in model quantization, to work for magnitude pruning and sparse parameter update.
STE computes the gradients of dense parameters wrt the pruned network 𝑊~, 𝜕𝐿/𝜕𝑊~, and applies that to the dense network 𝑊 as an approximation:
𝑊𝑡+1←𝑊𝑡−𝛾𝜕𝐿𝜕𝑊~
The extended version, SR-STE (Sparse-refined STE), updates the dense weights 𝑊 by:
𝑊𝑡+1←𝑊𝑡−𝛾𝜕𝐿𝜕𝑊~+𝜆𝑊(𝐸¯⊙𝑊𝑡)where 𝐸¯ is the mask matrix for 𝑊~ and ⊙ is element-wise multiplication. SR-STE is proposed to prevent large change in the binary mask by (1) restricting the values of weights pruned in 𝑊~𝑡, and (2) promoting the non-pruned weights in 𝑊~𝑡.
Fig. 11. Comparison of STE and SR-STE. ⊙ is element-wise product; ⊗ is matrix multiplication. (Image source: Zhou & Ma, et al. 2021)
Different from STE or SR-STE, the Top-KAST (Jayakumar et al. 2021) method can preserve constant sparsity throughout training in both the forward and backward-passes but does not require forward passes with dense parameters or dense gradients.
At one training step 𝑡, Top-KAST processes as follows:
- Sparse forward pass: Select a subset of parameters 𝐴𝑡⊂Θ, containing top-𝐾 parameters by magnitude by each layer, restricted to top 𝐷-proportion of weights. The parameterization 𝛼𝑡 at time 𝑡 has parameters zeroed out if it is not in 𝐴𝑡 (active weights).
𝛼𝑖𝑡={𝜃𝑖𝑡 if 𝑖∈𝐴𝑡={𝑖∣𝜃𝑖𝑡∈TopK(𝜃𝑡,𝐷)}0 otherwise
where TopK(𝜃,𝑥) selected top 𝑥 proportion of weights from 𝜃 based on magnitude.
- Sparse backward pass: Then apply gradients to a larger parameter subset 𝐵⊂Θ where 𝐵 contains (𝐷+𝑀)-proportion of weights and 𝐴⊂𝐵. Updating a larger proportion of weights enables more effective exploration of different pruning masks, making it more likely to cause permutations in the top 𝐷-proportion active weights.
Δ𝜃𝑖𝑡={−𝜂∇𝛼𝑡𝐿(𝑦,𝑥,𝛼𝑡)𝑖 if 𝑖∈𝐵𝑡={𝑖∣𝜃𝑖𝑡∈TopK(𝜃𝑡,𝐷+𝑀)}0 otherwise
Training is split into two stages and the additional coordinates in the set 𝐵∖𝐴 controls how much exploration is brought in. The amount of exploration is expected to diminish gradually through the training process and the mask eventually stabilizes.
Fig. 12. The pruning mask of Top-KAST stabilizes in time. (Image source: Jayakumar et al. 2021)
To prevent rich-get-richer phenomenon, Top-KAST penalizes the magnitude of active weights via a L2 regularization loss to encourage more exploration of new items. Parameters in 𝐵∖𝐴 are penalized more than 𝐴 for a higher selection bar during updates to stabilize the mask.
𝐿penalty(𝛼𝑖𝑡)={|𝜃𝑖𝑡| if 𝑖∈𝐴𝑡|𝜃𝑖𝑡|/𝐷 if 𝑖∈𝐵𝑡∖𝐴𝑡0 otherwise
Sparsified Transformer
Scaling Transformer (Jaszczur et al. 2021) sparsifies both self-attention and FFN layers in transformer architecture, achieving 37x speedup for single-example inference.
Fig. 13. The speed of decoding a single token (unbatched inference) by a transformer model when sparsification is applied on different layers. (Image source: Jaszczur et al. 2021)
Sparse FFN layer: Each FFN layer contains 2 MLP and one ReLU in-between. Because ReLU will introduce a lot of zeros, they implement a fixed structure on activations to enforce only 1 non-zero value in one block of 𝑁 elements. The sparsity pattern is dynamic, different for each token.
𝑌sparse=max(0,𝑥𝑊1+𝑏1)⊙Controller(𝑥)SparseFFN(𝑥)=𝑌sparse𝑊2+𝑏2Controller(𝑥)=argmax(Reshape(𝑥𝐶1𝐶2,(−1,𝑁)))
where each activation in 𝑌sparse corresponds to one column in 𝑊1 and one row in 𝑊2. The controller is implemented as a low-rank bottleneck dense layer, 𝐶1∈𝑅𝑑model×𝑑lowrank,𝐶2∈𝑅𝑑lowrank×𝑑ff and 𝑑lowrank=𝑑model/𝑁. It uses argmax for inference to select which columns should be non-zero and Gumbel-softmax trick (Jang et al. 2016) during training. Because we can compute Controller(𝑥) before loading FFN weight matrices, we know which columns will be zeroed out and thus choose not to load them into memory for inference speedup.
Fig. 14. (a) Sparse FFN layer; columns in red are not loaded in memory for faster inference. (b) Sparse FFN controller for 1:4 sparsity. (Image source: Jaszczur et al. 2021) *Lilian's side note*: Fig (a) in the illustration from the paper is actually 𝑌sparse=max(0,(𝑥𝑊1+𝑏1)⊙Controller(𝑥)), but it doesn't change the results.
Sparse QKV (attention) layer: In the attention layer, the dimensionality 𝑑model is divided into 𝑆 modules, each of size 𝑀=𝑑model/𝑆. To make sure each subdivision can access any part of the embedding, Scaling Transformer introduces a multiplicative layer (i.e., a multiplication layer multiplies inputs from multiple neural network layers element-wise) which can represent arbitrary permutation but contains fewer parameters than a dense layer.
Given an input vector 𝑥∈𝑅𝑑model, the multiplicative layer outputs 𝑦∈𝑅𝑆×𝑀:
𝑦𝑠,𝑚=∑𝑖𝑥𝑖𝐷𝑖,𝑠𝐸𝑖,𝑚where 𝐷∈𝑅𝑑model×𝑆,𝐷∈𝑅𝑑model×𝑀
The output of the multiplicative layer is a tensor of size ∈𝑅batch size×length×𝑆×𝑀. It then gets processed by a two-dimensional convolutional layer, where length and 𝑆 are treated as the height and width of an image. Such a convolution layer further reduces the parameter count and computation time of attention layer.
Fig. 15. (a) A multiplicative layer is introduced to enable partitions to access any part of an embedding. (b) Combination of multiplicative dense layer and 2-D convolutional layer reduces the number of parameters and computation time of the attention layer. (Image source: Jaszczur et al. 2021)
To better work with long sequences, Scaling Transformer is further equipped with LSH (locality-sensitive hashing) attention from Reformer (Kitaev, et al. 2020) and FFN block recurrence, resulting in Terraformer.
Mixture-of-Experts
Mixture-of-experts (MoE) models depend on a collection of "expert" networks and each example only activates a subset of networks to get predictions. The idea originated back to the 1990s (Jacobs et al. 1991) and is strongly related to ensemble methods. For details on how to incorporate MoE module into transformer, please check my previous post on large model training techniques and a survey paper on MoE by Fedus et al. 2022.
With MoE architecture, only partial parameters are utilized at decoding time and therefore it saves inference cost. The capacity of each expert can be adjusted with a hyperparameter, capacity factor 𝐶, and the expert capacity is defined as:
Expert capacity=round(𝐶⋅𝑘⋅total # tokens in one batch# experts)
where top-𝑘 experts are selected per token. Larger 𝐶 leads to higher expert capacity and improved performance but more expensive computationally. When 𝐶>1, a slack capacity is added; otherwise, when 𝐶<1, the routing network needs to ignore some tokens.
Routing Strategy Improvement
MoE layer has a routing network to assign a subset of experts for each input token. The routing strategy in vanilla MoE models is to route each token toward preferred experts differently as they come up in the natural order. If a token is routed to experts that have reached their capacity, the token would be marked "overflowed" and skipped.
V-MoE (Vision MoE; Riquelme et al. 2021) adds MoE layers into ViT (Vision Transformer). It matches the performance of previous SoTA but only requires half of inference compute. V-MoE can be scaled up to 15B parameters. Their experiments used 𝑘=2, 32 experts and every-2 expert placement (meaning that MoEs are placed in every other layer).
Since each expert has a limited capacity, some important and informative tokens may have to be discarded if they come up too late in the predefined sequence order (e.g. the order of words in a sentence, or the order of image patches). To avoid such a drawback in the vanilla routing scheme, V-MoE adopts BPR (Batch Priority Routing) to assign experts to tokens with a high priority score first. BPR computes a priority score (max or sum of top-𝑘 router scores) per token before expert assignment and alters the order of tokens accordingly. This guarantees that the expert capacity buffer would be fulfilled with key tokens first.
Fig. 16. How image patches are discarded according to priority scores when 𝐶<1. (Image source: Riquelme et al. 2021)
BPR works much better than vanilla routing when 𝐶≤0.5, where the model starts dropping a significant amount of tokens. It capacitates the model to be competitive with the dense network even at quite low capacities.
When looking into how to interpret image class-expert association, they observed that early MoE layers are more general, while later MoE layers could be specialized for a few image classes.
Task MoE (Task-level Mixture-of-Experts; Kudugunta et al. 2021 ) takes the task information into consideration and routes tokens at the task level instead of the word or token level for machine translation. They used MNMT (multilingual neural machine translation) as an example and group translation tasks based on the target language or language pairs.
Token level routing is dynamic and the routing decision for each token is made disjointly. Hence, at inference time, the server needs to preload all the experts. In comparison, task level routing is static given a fixed task, so the inference server for one task only needs to preload 𝑘 experts (assuming top-𝑘 routing). According to their experiments, Task MoE can achieve similar performance gain as token MoE compared to dense model baseline with 2.6x higher peak throughput and 1.6% of the decoder size.
Task level MoE is essentially to categorize a distribution of tasks according to predefined heuristics and incorporate such human knowledge into the router. When such heuristics do not exist (e.g. consider a general sentence continuation task), it would not be straightforward how to utilize Task MoE.
PR-MoE (Pyramid residual MoE; Rajbhandari et al. 2022) has each token pass one fixed MLP and one chosen expert. Due to the observation that MoE at later layers is more beneficial, PR-MoE adopts more exports at later layers. DeepSpeed library implements a flexible multi-expert, multi-data parallelism to enable training PR-MoE with different numbers of experts across layers.
Fig. 17. Illustration of PR-MoE architecture in comparison with a standard MoE. (Image source: Rajbhandari et al. 2022)
Kernel Improvement
Expert networks can be hosted on different devices. However, when the number of GPUs increases, the number of experts per GPU decreases and the communication between experts ("All-to-all") grows to be more expensive. All-to-all communication between experts across a number of GPUs relies on P2P APIs of NCCL, which cannot saturate the bandwidth of high-speed links (e.g. NVLink, HDR InfiniBand) at a large scale, as individual chunk gets smaller with more nodes used. The existing all-to-all algorithm performs poorly at large scale with a small workload. There are a variety of kernel improvements to enable more efficient MoE computation, such as making all-to-all communication cheaper/faster.
Both the DeepSpeed library (Rajbhandari et al. 2022) and TUTEL (Hwang et al. 2022) implemented a tree-based hierarchical all-to-all algorithm, which runs an intra-node all-to-all followed by an inter-node all-to-all. It reduces the communication hops from 𝑂(𝐺) to 𝑂(𝐺node+𝐺/𝐺node), where 𝐺 is the total number of GPU nodes and 𝐺node is the number of GPU cores per node. Although the communication volume is doubled in such implementation, it enables better scaling with small batches at large scale as the bottleneck is on latency instead of communication bandwidth when the batch size is small.
DynaMoE (Kossmann et al. 2022) uses dynamic recompilation to adapt the computational resources to dynamic workloads among experts. The RECOMPILE
mechanism compiles the computation graph from scratch and only reallocates resources when needed. It measures how many samples are assigned to each expert and adjusts their capacity factors 𝐶 dynamically, in order to reduce the memory and computation requirements at run time. Based on the observation that sample-expert assignments converge early in training, sample assignment caching is introduced after convergence and then RECOMPILE
is used to eliminate the dependency between the gating network and experts.
Architectural Optimization
The survey paper on Efficient Transformers (Tay et al. 2020) reviewed a collection of new transformer architectures with improvement for better computational and memory efficiency . Strongly recommend a read. You can also check out my post "The Transformer Family Version 2.0" for introduction to a diverse set of transformer archiecture improvements in depth, including changes to make the model cheaper to run.
Fig. 18. Categorization of efficient transformer models.
(Image source: Tay et al. 2020)
Since the self-attention mechanism has quadratic time and memory complexity and that is the main bottleneck for better transformer decoding efficiency, all the efficient transformer models have applied some form of sparsity to the otherwise dense attention layer. Here only lists a high-level overview, several derived from Tay et al. 2020.
Sparse Attention Patterns
-
Fixed Patterns limit the field of view for the attention matrix, using predefined, fixed patterns.
- Chunk input sequences into fixed blocks, such as Blockwise Attention;
- Image Transformer uses local attention;
- Sparse Transformer uses strided attention patterns.
-
Combined Patterns learn to sort/cluster the input tokens - enabling a more optimal global view of the sequence while maintaining the efficiency benefits of fixed patterns.
- Sparse Transformer combines strided and local attention;
- Given a high dimensional input tensor, instead of applying attention to the flattened version of the input, Axial Transformer applies multiple attentions, each along a single axis of the input tensor.
- ETC, Longformer and Big Bird combines local and global context, as well as strided or random attention.
-
Learnable Patterns identify the optimal attention pattern via learning.
- Reformer clusters tokens into clusters based on hash-based similarity (LSH);
- Routing Transformer runs 𝑘-means clustering on tokens;
- Sinkhorn Sorting Network learns to sort blocks of input sequence.
Recurrence
Recurrence mechanism connects multiple blocks/segments via recurrence.
- Transformer-XL makes use of longer context by reusing hidden states between segments.
- Universal Transformer combines self-attention with the recurrent mechanism in RNN.
- Compressive Transformer is an extension of Transformer-XL with additional memory, containing a set of memory slots for past activiations and compressive memory slots for compressed activations. Whenever the model accepts a new input segment, the oldest activations in the primary memory are moved to the compressed memory where a compression function is applied.
Memory Saving Designs
Memory saving designs refer to changes of the architecture to use less memory.
- Linformer projects the length dimension of keys and values to a lower-dimensional representation (𝑁→𝑘) and thus the memory complexity is reduced from 𝑁×𝑁 to 𝑁×𝑘.
- Shazeer (2019) proposed multi-query attention which has the keys and values shared across different attention "heads", greatly reducing the size of these tensors and the memory cost.
- Random feature attention and Performer use kernel methods to achieve a cheaper mathematical format of the self-attention mechanism.
Adaptive Attention
Adaptive attention enables the model to learn the optimal attention span or decide on when to do early exiting for different input tokens.
- Adaptive Attention Span trains the model to learn the optimal attention span per token per head via a soft mask between the token and other keys.
- Universal Transformer incorporates recurrent mechanism and uses ACT (Adaptive computation time) to dynamically decide the number of recurrent steps.
- Depth-Adaptive Transformer and CALM learns when to early exit the computation layers per token using some confidence measures to achieve good performance-efficiency tradeoffs.
Citation
Cited as:
Weng, Lilian. (Jan 2023). Large Transformer Model Inference Optimization. Lil'Log. https://lilianweng.github.io/posts/2023-01-10-inference-optimization/.
Or
@article{weng2023inference,
title = "Large Transformer Model Inference Optimization",
author = "Weng, Lilian",
journal = "Lil'Log",
year = "2023",
month = "Jan",
url = "https://lilianweng.github.io/posts/2023-01-10-inference-optimization/"
}
References
[1] Bondarenko et al. "Understanding and overcoming the challenges of efficient transformer quantization" ACL 2021.
[2] Dettmers et al. "LLM.int8(): 8-bit Matrix Multiplication for Transformers at Scale" NeuriPS 2022
[3] Zadeh et al. "Gobo: Quantizing attention-based NLP models for low latency and energy efficient inference." MICRO 2020
[4] Shen, Dong & Ye, et al. "Q-BERT: Hessian based ultra low precision quantization of BERT" AAAI 2020.
[5] Yao et al. "ZeroQuant: Efficient and affordable post-training quantization for large-scale transformers" arXiv preprint arXiv:2206.01861 (2022).
[6] Frantar et al. "GPTQ: Accurate Quantization for Generative Pre-trained Transformers" arXiv preprint arXiv:2210.17323 (2022).
[7] Xiao & Lin "SmoothQuant: Accelerated sparse neural training: A provable and efficient method to find N:M transposable masks." arXiv preprint arXiv:2211.10438 (2022). | code
[8] Pool & Yu. "Channel Permutations for N:M Sparsity." NeuriPS 2021. | code
[9] Zhou & Ma, et al. "Learning N:M fine-grained structured sparse neural networks from scratch." arXiv preprint arXiv:2102.04010 (2021).
[10] Jayakumar et al. "Top-KAST: Top-K Always Sparse Training." NeuriPS 2020.
[11] Nvidia. "Nvidia A100 tensor core GPU architecture." 2020.
[12] Gale, Elsen & Hooker "The State of Sparsity in Deep Neural Networks." arXiv preprint arXiv:1902.09574 (2019).
[13] Zhu & Gupta. "To Prune, or Not to Prune: Exploring the Efficacy of Pruning for Model Compression." arXiv preprint arXiv:1710.01878 (2017).
[14] Renda et al. "Comparing rewinding and fine-tuning in neural network pruning." arXiv preprint arXiv:2003.02389 (2020).
[15] Zhou & Ma, et al. "Learning N:M fine-grained structured sparse neural networks from scratch." arXiv preprint arXiv:2102.04010 (2021).
[16] Pool & Yu. "Channel Permutations for N:M Sparsity." NeuriPS 2021. | code
[17] Jaszczur et al. "Sparse is Enough in Scaling Transformers." NeuriPS 2021.
[18] Mishra et al. "An Survey of Neural Network Compression." arXiv preprint arXiv:1710.09282 (2017).
[19] Fedus et al. "A Review of Sparse Expert Models in Deep Learning." arXiv preprint arXiv:2209.01667 (2022)..
[20] Riquelme et al. "Scaling vision with sparse mixture of experts." NeuriPS 2021.
[21] Kudugunta et al. "Beyond Distillation: Task-level Mixture-of-Experts for Efficient Inference." arXiv preprint arXiv:2110.03742 (2021).
[22] Rajbhandari et al. "DeepSpeed-MoE: Advancing mixture-of-experts inference and training to power next-generation ai scale." arXiv preprint arXiv:2201.05596 (2022).
[23] Kossmann et al. "Optimizing mixture of experts using dynamic recompilations." arXiv preprint arXiv:2205.01848 (2022).
[24] Hwang et al. "Tutel: Adaptive mixture-of-experts at scale." arXiv preprint arXiv:2206.03382 (2022). | code
[25] Noam Shazeer. "Fast Transformer Decoding: One Write-Head is All You Need." arXiv preprint arXiv:1911.02150 (2019).
[26] Tay et al. "Efficient Transformers: A Survey." ACM Computing Surveys 55.6 (2022): 1-28.
[27] Pope et al. "Efficiently Scaling Transformer Inference." arXiv preprint arXiv:2211.05102 (2022).
[28] Frankle & Carbin. "The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks" ICLR 2019.
[29] Elabyad et al. "Depth-Adaptive Transformer" ICLR 2020.
[30] Schuster et al. "Confident Adaptive Language Modeling" arXiv preprint arXiv:2207.07061 (2022).
[31] Gou et al. "https://arxiv.org/abs/2006.05525" arXiv preprint arXiv:2006.05525 (2020).
[32] Hinton et al. "Distilling the Knowledge in a Neural Network" NIPS 2014.
[33] Sanh et al. "DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter" Workshop on Energy Efficient Machine Learning and Cognitive Computing @ NeuriPS 2019.
中文翻译
大型 Transformer 模型如今已经成为主流,为各种任务创造了 SOTA 结果。诚然这些模型很强大,但训练和使用起来代价非常昂贵。在时间和内存方面存在有极高的推理成本。概括来说,使用大型 Transformer 模型进行推理的难点,除了模型的规模不断扩大外,还有两个不可忽略的地方:
- 内存消耗大:推理时,需要把模型参数和中间状态都保存到内存中。例如:KV 存储机制下的缓存中的内容在解码期间需要存储在内存中,举例来说,对于 batch size 为 512,上下文长度为 2048 的设置来说,KV 缓存里需要的空间规模为 3TB,这是模型大小的 3 倍;注意力机制的推理成本和输入序列的长度呈正相关;
- 低并行性:推理生成过程以自回归的方式执行,使解码过程难以并行。
在这篇文章中,领导 OpenAI 应用研究的 Lilian Weng 写了一篇博客,文中介绍了几种提高 transformer 推理效率的方法。一些是通用的网络压缩方法,而另一些则应用于特定的 transformer 架构。
Lilian Weng 现为 OpenAI 应用人工智能研究负责人,主要从事机器学习、深度学习等研究 。她本科毕业于香港大学,硕士毕业于北京大学信息系统与计算机科学系,之后前往印度安纳大学布鲁顿分校攻读博士。
模型综述
通常将以下内容视为模型推理优化的目标:
- 使用更少的 GPU 设备和更少的 GPU 内存,减少模型的内存占用;
- 减少所需的 FLOP,降低计算复杂度;
- 减少推理延迟,运行得更快。
可以使用几种方法来降低推理过程在内存中的成本,并且加快速度。
- 在多 GPU 上应用各种并行机制来实现对模型的扩展。模型组件和数据的智能并行使得运行具有万亿级参数的大模型成为可能;
- 将暂时未使用的数据卸载到 CPU,并在以后需要时读回。这样做对内存比较友好,但会导致更高的延迟;
- 智能批处理策略;例如 EffectiveTransformer 将连续的序列打包在一起,以删除单个批次中的 padding;
- 神经网络压缩技术,例如剪枝、量化、蒸馏。就参数数量或位宽而言,小尺寸的模型应该需要少量的内存,也就运行得更快;
- 特定于目标模型架构的改进。许多架构上的变化,尤其是注意力层的变化,有助于提高 transformer 的解码速度。
本篇文章的重点是网络压缩技术和 transformer 模型在特定体系结构下的改进。
量化策略
在深度神经网络上应用量化策略有两种常见的方法:
- 训练后量化(PTQ):首先需要模型训练至收敛,然后将其权重的精度降低。与训练过程相比,量化操作起来往往代价小得多;
- 量化感知训练 (QAT):在预训练或进一步微调期间应用量化。QAT 能够获得更好的性能,但需要额外的计算资源,还需要使用具有代表性的训练数据。
值得注意的是,理论上的最优量化策略与实际在硬件内核上的表现存在着客观的差距。由于 GPU 内核对某些类型的矩阵乘法(例如 INT4 x FP16)缺乏支持,并非下面所有的方法都会加速实际的推理过程。
Transformer 量化挑战
许多关于 Transformer 模型量化的研究都有相同的观察结果:训练后将参数简单地量化为低精度(例如 8 位)会导致性能显着下降,这主要是由于普通的激活函数量化策略无法覆盖全部的取值区间。
图 1. 只将模型权重量化为 8 位,激活函数使用完整的精度的时候能取得较好的效果(W8A32);激活函数量化为 8 位时,无论权重是否为低精度(W8A8 和 W32A8)效果都不如 W8A32。
Bondarenko 等人在一个小型 BERT 模型中观察到,由于输出张量中存在强异常值,FFN 的输入和输出具有非常不同的取值区间。因此,FFN 残差和的逐个张量的量化可能会导致显著的误差。
随着模型参数规模继续增长到数十亿的级别,**高量级的离群特征开始在所有 transformer 层中出现,导致简单的低位量化效果不佳。**Dettmers 等人观察到大于 6.7B 参数的 OPT 模型就会出现这种现象。模型大了,有极端离群值的网络层也会变多,这些离群值特征对模型的性能有很大的影响。在几个维度上的激活函数异常值的规模就可以比其他大部分数值大 100 倍左右。
图 2. 不同规模的 OPT 模型在四个语言任务(WinoGrande、HellaSwag、PIQA、LAMBADA)上的平均零样本准确率。
混合精度量化
解决上述量化挑战的最直接方法是以不同的精度对权重和激活函数进行量化。
GOBO 模型是首批将训练后量化应用于 transformer 的模型之一(即小型 BERT 模型)。GOBO 假设每一层的模型权重服从高斯分布,因此可以通过跟踪每层的均值和标准差来检测异常值。异常值特征保持原始形式,而其他值被分到多个 bin 中,并且仅存储相应的权重索引和质心值。
基于对 BERT 中只有某些激活层(例如 FFN 之后的残差连接)导致性能大幅下降现象的观察,Bondarenko 等人通过在有问题的激活函数上使用 16 位量化而在其他激活函数上使用 8 位来采用混合精度量化。
LLM.int8 () 中的混合精度量化是通过两个混合精度分解实现的:
- 因为矩阵乘法包含一组行和列向量之间的独立内积,所以可以对每个内积进行独立量化。每一行和每一列都按最大值进行缩放,然后量化为 INT8;
- 异常值激活特征(例如比其他维度大 20 倍)仍保留在 FP16 中,但它们只占总权重的极小部分,不过需要经验性地识别离群值。
图 3.LLM.int8()两种混合精度分解方法。
细粒度量化
图 4. 不同粒度量化对比。d 是模型大小 / 隐空间维度,h 是一个 MHSA(多头自注意力)组件中的头数。
简单地量化一层中的整个权重矩阵(逐个张量或逐个层量化)是最容易实现的,但量化粒度往往不尽如人意。
Q-BERT 将分组量化应用于微调的 BERT 模型,将 MHSA(多头自注意力)中每个头的单个矩阵 W 视为一个组,然后应用基于 Hessian 矩阵的混合精度量化。
Per-embedding group (PEG) 激活函数量化的设计动机是观察到离群值仅出现在少数几个维度中。对每个嵌入层都量化的代价非常昂贵,相比之下,PEG 量化将激活张量沿嵌入维度分成几个大小均匀的组,其中同一组中的元素共享量化参数。为确保所有异常值都分组在一起,PEG 应用了一种基于取值范围的嵌入维度排列算法,其中维度按其取值范围排序。
ZeroQuant 与 Q-BERT 一样都对权重使用分组量化,然后还对激活函数使用了 token-wise 量化策略。为了避免代价昂贵的量化和反量化计算,ZeroQuant 构建了独特的内核来将量化操作与其之前的运算符融合。
使用二阶信息量化
Q-BERT 针对混合精度量化开发了 Hessian AWare 量化 (HAWQ)。其动机是,具有更高 Hessian 谱的参数对量化更敏感,因此需要更高的精度。这种方法本质上是一种识别异常值的方法。
从另一个角度来看,量化问题是一个优化问题。给定一个权重矩阵 W 和一个输入矩阵 X ,想要找到一个量化的权重矩阵 W^ 来最小化如下所示的 MSE 损失:
GPTQ 将权重矩阵 W 视为行向量 w 的集合,并对每一行独立量化。GPTQ 使用贪心策略来选择需要量化的权重,并迭代地进行量化,来最小化量化误差。更新被选定的权重会生成 Hessian 矩阵形式的闭合解。GPTQ 可以将 OPT-175B 中的权重位宽减少到 3 或 4 位,还不会造成太大的性能损失,但它仅适用于模型权重而不适用于激活函数。
异常值平滑
众所周知,Transformer 模型中激活函数比权重更难量化。SmoothQuant 提出了一种智能解决方案,通过数学等效变换将异常值特征从激活函数平滑到权重,然后对权重和激活函数进行量化 (W8A8)。正因为如此,SmoothQuant 具有比混合精度量化更好的硬件效率。
图 5. SmoothQuant 将尺度方差从激活函数迁移到离线权重,以降低激活函数量化的难度。由此产生的新权重和激活矩阵都易于量化。
基于每个通道的平滑因子 s,SmoothQuant 根据以下公式缩放权重:
根据平滑因子
可以很容易地在离线状态下融合到前一层的参数中。超参数 α 控制从激活函数迁移到权重的程度。该研究发现 α=0.5 是实验中许多 LLM 的最佳取值。对于激活异常值较大的模型,可以将 α 调大。
量化感知训练 (QAT)
量化感知训练将量化操作融合到预训练或微调过程中。这种方法会直接学习低位表示的模型权重,并以额外的训练时间和计算为代价获得更好的性能。
最直接的方法是在与预训练数据集相同或代表预训练数据集的训练数据集上量化后微调模型。训练目标可以与预训练目标相同(例如通用语言模型训练中的 NLL/MLM)或特定于的下游任务(例如用于分类的交叉熵)。
另一种方法是将全精度模型视为教师模型,将低精度模型视为学生模型,然后使用蒸馏损失优化低精度模型。蒸馏通常不需要使用原始数据集。
剪枝
网络剪枝是在保留模型容量的情况下,通过修剪不重要的模型权重或连接来减小模型大小。剪枝可能需要也可能不需要重新训练。剪枝可以是非结构化的也可以是结构化的。
- 非结构化剪枝允许丢弃任何权重或连接,因此它不保留原始网络架构。非结构化剪枝通常对硬件要求比较苛刻,并且不会加速实际的推理过程;
- 结构化剪枝不改变权重矩阵本身的稀疏程度,可能需要遵循某些模式限制才能使用硬件内核支持的内容。本文专注于那些能实现 transformer 模型的高稀疏性的结构化剪枝。
构建剪枝网络的常规工作流程包含三个步骤:
-
训练密集型的神经网络直到收敛;
-
修剪网络以去除不需要的结构;
-
(可选择)重新训练网络,让新权重保持之前的训练效果。
通过剪枝在密集模型中发现稀疏结构,同时稀疏网络仍然可以保持相似性能的灵感是由彩票假设激发的:这是一个随机初始化的密集前馈网络,它包含一个子网络池。其中只有一个子集(稀疏网络)是中奖彩票(winning tickets),这个中奖彩票在独立训练时可以达到最佳性能。
如何剪枝
Magnitude pruning 是最简单但同时又非常有效的剪枝方法 - 只裁剪那些绝对值最小的权重。事实上,一些研究发现,简单的量级剪枝方法可以获得与复杂剪枝方法相当或更好的结果,例如变分 dropout 和 l_0 正则化。Magnitude pruning 很容易应用于大型模型,并在相当大的超参数范围内实现相当一致的性能。
Zhu & Gupta 发现,大型稀疏模型能够比小型但密集的模型获得更好的性能。他们提出了 Gradual Magnitude Pruning (GMP) 算法,该算法在训练过程中逐渐增加网络的稀疏性。在每个训练步骤中,具有最小绝对值的权重被屏蔽为零以达到所需的稀疏度并且屏蔽的权重在反向传播期间不会得到梯度更新。所需的稀疏度随着训练步骤的增加而增加。GMP 过程对学习率步长策略很敏感,学习率步长应高于密集网络训练中所使用的,但不能太高以防止收敛。
迭代剪枝多次迭代上述三个步骤中的第 2 步(剪枝)和第 3 步(重新训练),每次只有一小部分权重被剪枝,并且在每次迭代中重新训练模型。不断重复该过程,直到达到所需的稀疏度级别。
如何再训练
再训练可以通过使用相同的预训练数据或其他特定于任务的数据集进行简单的微调来实现。
Lottery Ticket Hypothesis 提出了一种权重 rewinding 再训练方法:剪枝后,将未剪枝的权重重新初始化回训练初期的原始值,然后以相同的学习率时间表进行再训练。
学习率 rewinding 仅将学习率重置回其早期值,而保持未剪枝的权重自最后一个训练阶段结束以来不变。研究者观察到 (1) 使用权重 rewinding 的再训练结果优于通过跨网络和数据集进行微调的再训练,以及 (2) 在所有测试场景中学习率 rewinding 与权重 rewinding 的效果持平甚至更优。
稀疏化
稀疏化是扩大模型容量同时保持模型推理计算效率的有效方法。本文考虑两种类型的 transformer 稀疏性:
- 稀疏化的全连接层,包括自注意力层和 FFN 层;
- 稀疏模型架构,即 MoE 组件的合并操作。
通过剪枝实现的 N:M 稀疏化
N:M 稀疏化是一种结构化的稀疏化模式,适用于现代 GPU 硬件优化,其中每 M 个连续元素中的 N 个元素为零。例如,英伟达 A100 GPU 的稀疏张量核心支持 2:4 稀疏度以加快推理速度。
图 6. 2:4 结构化稀疏矩阵及其压缩表示。
为了使密集型神经网络的稀疏化遵循 N:M 结构化稀疏模式,英伟达建议使用三步操作来训练剪枝后的网络:训练 --> 剪枝以满足 2:4 稀疏性 --> 重新训练。
(1) 对矩阵中的列进行排列可以在剪枝过程中提供更多可能,以保持参数的数量或满足特殊限制,如 N:M 稀疏性。只要两个矩阵对应的轴按相同的顺序排列,矩阵乘法的结果就不会改变。例如,(1) 在自注意力模块中,如果 query 的嵌入矩阵 Q 的轴 1 和 key 嵌入矩阵 K^⊤的轴 0 采用相同的排列顺序,则 QK^⊤的矩阵乘法最终结果保持不变。
图 7. Q(轴 1)和 K^⊤(轴 0)上相同排列,自注意力模块的结果不变。
(2) 在包含两个 MLP 层和一个 ReLU 非线性层的 FFN 层内,可以将第一个线性权重矩阵 W_1 沿轴 1 排列,然后第二个线性权重矩阵 W_2 沿轴 0 按相同顺序排列。
图 8. W_1(轴 1)和 W_2(轴 0)上有着相同的排列,可以保持 FFN 层的输出不变。为简单起见,图示省略了偏差项,但也应对它们应用相同的排列。
为了推动 N:M 结构稀疏化,需要将一个矩阵的列拆分为 M 列的多个 slide(也称为 stripe),这样可以很容易地观察到每个 stripe 中的列顺序和 stripe 的顺序对 N:M 稀疏化产生的限制。
Pool 和 Yu 提出了一种迭代式的贪心算法来寻找最优排列,使 N:M 稀疏化的权重幅度最大化。所有通道对都被推测性地交换,并且只采用幅度增加最大的交换,然后生成新的排列并结束单次迭代。贪心算法可能只会找到局部极小值,因此他们引入了两种技术来逃避局部极小值:
-
有界回归:在实践中,两个随机通道的最大交换次数是固定的。每次搜索只有一个通道可以进行交换,以保持搜索空间宽而浅;
-
窄且深的搜索:选择多个 stripe 并同时优化它们。
图 9. 贪心算法实现迭代地寻找 N:M 稀疏化最佳排列的算法。
与按默认通道顺序对网络进行剪枝相比,如果在剪枝之前对网络进行置换,可以获得更好的性能。
为了从头开始训练具有 N:M 稀疏化的模型,Zhou & Ma 扩展了常用于模型量化中的反向传播更新的 STE,用于幅度剪枝和稀疏参数更新。
STE 计算剪枝后的网络
的密集参数的梯度
,并将其作为近似值应用于稠密网络 W:
STE 的扩展版本 SR-STE(稀疏精化 STE)通过以下方式更新稠密权重 W:
其中
是
的掩码矩阵,⊙是元素对应位置相乘。SR-STE 通过(1)限制
中对权重的剪枝,以及(2)维持
中未被剪枝的权重,来防止二进制掩码剧烈变化。
图 10. STE 和 SR-STE 的对比。⊙的比较是元素乘积;⊗是矩阵乘法。
与 STE 或 SR-STE 不同,Top-KAST 方法可以在前向和反向传播的整个训练过程中保持恒定的稀疏性,还不需要使用具有稠密参数或梯度的前向传播。
在训练到第 t 步时,Top-KAST 过程如下:
稀疏前向传递:选择参数
的一个子集,包含每层按大小排列的前 K 个参数,限制为权重的前 D 比例。如果时间 t 的参数化 α^t 不在 A^t(活动权重)中,则参数化为零。
其中 TopK (θ,x) 是根据大小排序后从 θ 中的前 x 个权重。
稀疏向后传递:然后将梯度应用于更大的参数子集
, 其中 B 包含 (D+M), A⊂B。扩大需要更新的权重比例可以更有效地探索不同的剪枝掩码,从而更有可能将前 D% 的激活权重排列好。
训练分为两个阶段,集合 B∖A 中的附加坐标控制引入的探索量。探索量会在训练过程中逐渐减少,最终掩码会稳定下来。
图 11. Top-KAST 的剪枝掩码会随时间稳定下来。
为了防止马太效应,Top-KAST 通过 L2 正则化损失来惩罚激活权重,以鼓励产生更多新的探索。在更新期间,B∖A 中的参数比 A 受到更多的惩罚以稳定掩码。
稀疏 Transformer
稀疏 Transformer 将 Transformer 架构中的自注意力层和 FFN 层稀疏化,使单个样本推理的速度提高了 37 倍。
图 12. 当在不同网络层上应用稀疏化时,Transformer 模型解码单个 token(非批量推理)的速度。
稀疏 FFN 层:每个 FFN 层包含 2 个 MLP 和中间的一个 ReLU。因为 ReLU 会引入很多零值,所以该方法在激活函数上设计了一个固定结构,来强制要求在一个包含 N 个元素的块中只包含 1 个非零值。稀疏模式是动态的,每个 token 都不同。
其中 Y_(sparse ) 中的每个激活函数结果对应于 W_1 中的一列和 W_2 中的一行。控制器是一个低秩的 bottleneck 全连接层,其中
、
在训练期间使用 argmax 进行推理以选择哪些列应为非零和,以及 Gumbel-softmax 技巧 。因为可以在加载 FFN 权重矩阵之前计算 Controller (x),所以可以知道哪些列将被清零,因此选择不将它们加载到内存中以加快推理速度。
图 13. (a) 稀疏 FFN 层;红色列未加载到内存中以进行更快的推理。(b) 1:4 稀疏度的稀疏 FFN 控制器。
稀疏注意力层:在注意力层中,维度 d_(model) 被划分为 S 个模块,每个模块的大小为 M=d_(model)/S。为了确保每个细分都可以访问嵌入的任何部分,Scaling Transformer 引入了一个乘法层(即,一个乘法层将来自多个神经网络层的输入按元素相乘),它可以表示任意排列,但包含的参数少于全连接层。
给定输入向量
,乘法层输出
:
乘法层的输出是一个大小为
的张量。然后由二维卷积层对其进行处理,其中 length 和 S 被视为图像的高度和宽度。这样的卷积层进一步减少了注意力层的参数数量和计算时间。
图 14. (a) 引入乘法层以使分区能够访问嵌入的任何部分。(b) 乘法全连接层和二维卷积层的结合减少了注意力层的参数数量和计算时间。
为了更好地处理长序列数据,Scaling Transformer 进一步配备了来自 Reformer 的 LSH(局部敏感哈希)注意力和 FFN 块循环,从而产生了 Terraformer 模型。
混合专家系统 MoE
专家混合系统 (MoE) 模型是一种专家网络的集合,每个样本仅激活网络的一个子集来获得预测结果。这个想法起源于上世纪九十年代并且与集成方法密切相关。有关如何将 MoE 模块合并到 Transformer 的详细信息,可以查看本文作者之前写的关于大型模型训练技术的帖子和 Fedus 等人关于 MoE 的论文。
使用 MoE 架构,在解码时仅使用部分参数,因此节省了推理成本。每个专家的容量可以通过超参数容量因子 C 进行调整,专家容量定义为:
每个 token 需要选择前 k 个专家。较大的 C 会扩大专家容量,提高性能,但这样做计算成本更高。当 C>1 时,需要增加一个松弛容量;当 C<1 时,路由网络需要忽略一些 token。
路由策略改进
MoE 层有一个路由网络来为每个输入 token 分配一个专家子集。原生 MoE 模型中的路由策略是将每个 token 以不同的方式路由到按自然顺序出现的首选专家。如果路由到的专家已经没有多余的空间,token 将被标记为溢出并被跳过。
V-MoE 将 MoE 层添加到 ViT (Vision Transformer) 中。它与以前的 SOTA 模型的性能相匹配,但只需要一半的推理计算。V-MoE 可以扩展成一千五百万个参数。有研究者在实验中将 k=2、专家需要 32 位,每 2 位专家间放置一层 MoE。
由于每个专家的能力有限,如果某些重要且信息丰富的 token 在预定义的序列顺序(例如句子中的单词顺序或图像中 patch 的顺序)中出现得太晚,则可能不得不丢弃它们。为了避免原生路由方案中的这种缺陷,V-MoE 采用 BPR(批量优先路由)首先将专家分配给具有高优先级分数的 token。BPR 在专家分配之前计算每个 token 的优先级分数(前 k 名路由器得分的最大值或总和),并相应地更改 token 的顺序。这保证了核心的 token 能优先使用专家容量的缓冲区。
图 15. 当 C<1 时,根据优先级分数丢弃图像 patch 的方式。
当 C≤0.5 时,BPR 比普通路由效果更好,此时模型开始丢弃大量 token。这使模型即使在非常低的容量下也能与稠密网络一较高低。
在研究如何解释图像的类别与专家之间的关系时,研究者观察到早期的 MoE 层更通用,而后期的 MoE 层可以专门用于某类图像。
任务级 MoE 将任务信息考虑在内,并且将路由 token 在任务级的视角来处理。研究者以 MNMT(多语言神经机器翻译)为例,根据目标语言或语言对进行翻译任务分组。
Token 级路由是动态的,每个 token 的路由决策是不相交的。因此,在推理时,服务器需要预加载所有专家。相比之下,任务级路由是静态的,甚至是固定的任务,因此一个任务的推理服务器只需要预加载 k 个专家(假设 top-k 才有路由)。根据研究者的实验,与稠密模型的 baseline 相比,任务级 MoE 可以实现与 token MoE 类似的性能增益,峰值吞吐量高 2.6 倍,解码器小 1.6%。
任务级 MoE 本质上是根据预定义的启发式方法对任务分布进行分类,并将此类人类知识纳入路由器。当这种启发式不存在时,任务级 MoE 就难以使用了。
PR MoE 让每个 token 通过一个固定的 MLP 和一个选定的专家。由于靠后的 MoE 更有价值,PR MoE 在靠后的层上设计了更多的出口。DeepSpeed 库实现了灵活的多专家、多数据并行,以支持使用不同数量的专家来训练 PR MoE。
图 16。PR MoE 架构与标准 MoE 的对比图。
内核方面的改进措施
专家网络可以托管在不同的设备上。然而,当 GPU 数量增加时,每个 GPU 上的专家数量就会减少,专家之间的通信成本变得更加昂贵。跨多个 GPU 的专家之间的多对多通信依赖于 NCCL 的 P2P API,这个接口不能占据高速链路所有的带宽,这是因为使用的节点越多,单个 chunk 越小。现有的多对多算法在大规模问题上性能较差,单个 GPU 的工作量不能提升。针对这种情况,有多种内核改进措施来实现更高效的 MoE 计算,例如使多对多通信更便宜 / 更快。
DeepSpeed 库和 TUTEL 都实现了基于树的分层多对多算法,该算法在节点内使用多对多算法处理,然后再在节点间实现多对多。这种算法将通信跳数从 O(G)减少到
,其中 G 是 GPU 节点的总数,G_(node) 是每个节点的 GPU 内核数。尽管在这样的实现中通信量增加了一倍,但当批大小较小时 1×1 卷积层存在延迟,因此可以更好地扩展 batch 的规模。
DynaMoE 使用动态再编译使计算资源适应专家之间的动态工作负载。再编译机制需要从头开始编译计算图,并且只在需要时重新分配资源。它会琢磨分配给每个专家的样本数量,并动态调整其容量因子 C,以减少运行时的内存和计算需求。这种方法基于在训练早期对专家和样本的分配关系的观察,在模型收敛后引入样本分配缓存,然后使用再编译算法消除门控网络和专家之间的依赖性。
架构优化
论文《Efficient Transformers: A Survey》回顾了一系列新的 Transformer 架构,并针对提高计算和内存效率进行了一些改进,除此以外,大家还可以阅读这篇文章《The Transformer Family》,以深入了解几种类型的 Transformer 改进。
图 17. 高效 transformer 模型的分类
自注意力机制的二次时间复杂度和内存复杂性问题是提高 transformer 解码效率的主要瓶颈,因此所有高效 transformer 模型都对原本稠密的注意力层应用了某种形式的稀疏化措施。
- 固定模式:使用预定义的固定模式限制注意力矩阵的感受野:
- 可以将输入序列分成固定的块;
- 图像 transformer 使用了局部注意力;
- 稀疏 transformer 使用了跨线注意力模式;
- Longformer 使用了 dilated 注意力窗口;
- 可以使用 strided 卷积压缩注意力来减少序列长度。
- 组合模式:对输入的 token 进行排序 / 聚类以实现更优化的序列全局视图,同时保持固定模式的效率优势
- 稀疏 transformer 结合了跨步和局部注意力;
- 给定高维输入张量,axial transformer 不会将输入 flattened 后再使用注意力机制,而是使用多注意力机制,一个注意力对应着输入张量的一个轴;
- Big Bird 模型设计了一些关键组件,即(1)全局 token,(2)随机注意力(query 向量随机绑定 key 向量)和(3)固定模式(局部滑动窗口)。
- 可学习模式:通过学习确定最佳注意力模式:
- Reformer 使用局部敏感哈希将 token 聚类;
- 路由 transformer 用 k-means 将 token 聚类;
- Sinkhorn 排序网络会对输入序列块的排序算法进行学习。
- 递归:通过递归连接多个 block/segment:
- Transformer-XL 通过在 segment 之间重用隐藏状态来获取更长的上下文;
- 通用 transformer 将自注意力与 RNN 中的循环机制相结合;
- Compressive transformer 是 Transformer-XL 的扩展,具有额外的内存,具有 n_m 个内存槽和 n_(cm) 个压缩内存槽。每当有新的输入段被输入到模型当中时,主内存中最久未更新的前 n_s 个激活函数都会被转移到压缩内存中。
5.Side Memory:使用可以一次访问多个 token 的 Side Memory 模块
- Set Transformer 设计了一种受归纳点方法启发的新注意力;
- ETC(Extended transformer construction)是 Sparse Transformer 的变体,具有新的全局 - 局部注意力机制;
- Longformer 也是 Sparse Transformer 的变体,使用 dilated 滑动窗口。随着模型网络的深入,感受野也会逐渐增加。
- 节省内存:更改架构以使用更少的内存:
- Linformer 将 key 和 value 的代表长度的维度投影到低维表示(N→k),因此内存复杂度从 N×N 降低到 N×k;
- Shazeer 等人提出了多 query 注意力,在不同注意力头之间共享 key 和 value,大大减少了这些张量的大小和内存成本。
-
使用内核:使用内核可以让自注意力机制的公式书写起来更简单。需要注意的使,这里的内核是指内核方法中的内核,而不是 GPU 操作程序。
-
自适应注意力:让模型学习最佳注意力广度,或决定何时按每个 token 提前退出:
- 自适应注意力广度训练模型,通过 token 和其他 key 之间的 soft mask 机制,为每个 token、每个注意力头学习最佳的注意力广度;
- 通用 transformer 结合了循环机制,并使用 ACT(自适应计算时间)来动态决定循环几次;
- 深度自适应 transformer 和 CALM 使用一些置信度度量方法来学习何时提前退出每个 token 的计算层,这样可以在性能和效率之间找到一种平衡。