Tree Attention: Topology-Aware Decoding for Long-Context Attention on GPU Clusters

paper review

Published: Friday, September 6, 2024

In (Shyam et al. 2024), the authors write:

Self-attention is the core mathematical operation of modern transformer architectures and is also a significant computational bottleneck due to its quadratic complexity in the sequence length. In this work, we derive the scalar energy function whose gradient computes the self-attention block, thus elucidating the theoretical underpinnings of self-attention, providing a Bayesian interpretation of the operation, and linking it closely with energy-based models such as Hopfield Networks. Our formulation reveals that the reduction across the sequence axis can be efficiently computed in parallel through a tree reduction. Our algorithm for parallelizing attention computation across multiple GPUs enables cross-device decoding to be performed asymptotically faster (up to 8× faster in our experiments) than alternative approaches such as Ring Attention, while also requiring significantly less communication volume and using 2× less peak memory.
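To make the abstract's central claim concrete, here is a minimal NumPy sketch of attention computed as a gradient. It assumes a single query, omits the usual 1/√d scaling, and takes the energy to be the generating function F(ζ) = log Σ_a exp(q·k_a + ζ·v_a) with an added source vector ζ; this particular form is our assumption for illustration, not a quotation from the excerpt. Differentiating F with respect to ζ at ζ = 0 gives Σ_a softmax(q·k)_a v_a, which is ordinary softmax attention, and the code checks this numerically with central finite differences. The names `energy`, `softmax_attention`, and `grad_energy_fd` are hypothetical.

```python
import numpy as np


def energy(q, K, V, zeta):
    """Assumed generating function F(zeta) = log sum_a exp(q.k_a + zeta.v_a).

    q: (d,) query; K: (n, d) keys; V: (n, d_v) values; zeta: (d_v,) source term.
    """
    scores = K @ q + V @ zeta          # (n,) logits shifted by the source term
    m = scores.max()                   # subtract the max for numerical stability
    return m + np.log(np.exp(scores - m).sum())


def softmax_attention(q, K, V):
    """Reference single-query softmax attention: softmax(K q) V (no 1/sqrt(d) scaling)."""
    scores = K @ q
    w = np.exp(scores - scores.max())
    return (w / w.sum()) @ V


def grad_energy_fd(q, K, V, eps=1e-6):
    """Central finite-difference gradient of F with respect to zeta, evaluated at zeta = 0."""
    d_v = V.shape[1]
    g = np.zeros(d_v)
    for j in range(d_v):
        e = np.zeros(d_v)
        e[j] = eps
        g[j] = (energy(q, K, V, e) - energy(q, K, V, -e)) / (2 * eps)
    return g


rng = np.random.default_rng(0)
n, d, d_v = 16, 8, 4
q = rng.normal(size=d)
K = rng.normal(size=(n, d))
V = rng.normal(size=(n, d_v))

print(np.allclose(grad_energy_fd(q, K, V), softmax_attention(q, K, V), atol=1e-6))  # True
```

Because F is a log-sum-exp over the sequence axis, it can be computed piecewise over chunks of keys and values and recombined, which is the property a tree reduction can exploit.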


  1. Introduction

The self-attention operation is the core computational building block of the transformer architecture [1, 2], which has become a ubiquitous and highly effective workhorse architecture currently applied at scale to language [3–7], vision [8], audio [9], and decision-making [10, 11]. Nonetheless, the quadratic time complexity of self-attention means that significant resources are required to train and generate from transformer-based Large Language Models (LLMs), especially for models with large context lengths.

During inference, the attention block largely determines the computational and memory requirements, which become more demanding as the input sequence length increases. Although LLMs generate one token at a time, the entire sequence of past tokens must still be stored in memory and used to compute attention scores during generation. Since attention matches every token representation against every other, its cost in FLOPs grows quadratically with the sequence length.
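The cost argument in the paragraph above can be seen in a toy single-head decoding loop. This is an illustrative sketch with hypothetical helper names (`decode_step`, `count_score_flops`) and no scaling, batching, or multi-head logic: each new token attends to every cached key, so step t costs O(t·d), and generating T tokens costs O(T²·d) in total while the whole cache stays in memory.

```python
import numpy as np


def decode_step(q, K_cache, V_cache):
    """Attend one new query to every cached key/value (single head, no scaling)."""
    scores = K_cache @ q                      # t dot products: cost grows with cache length
    w = np.exp(scores - scores.max())
    return (w / w.sum()) @ V_cache


def count_score_flops(T, d):
    """Multiply-adds spent on q.k scores when decoding T tokens one at a time."""
    return sum(t * d for t in range(1, T + 1))  # step t attends to t cached keys


rng = np.random.default_rng(0)
d = 8
K_cache = np.empty((0, d))
V_cache = np.empty((0, d))
for t in range(5):
    k, v, q = rng.normal(size=(3, d))
    K_cache = np.vstack([K_cache, k])         # the whole past must stay in memory
    V_cache = np.vstack([V_cache, v])
    out = decode_step(q, K_cache, V_cache)

# Doubling the generated length roughly quadruples the score computation.
print(count_score_flops(1024, 64) / count_score_flops(512, 64))
```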

There have been recent advances in training LLMs to handle extremely long contexts (up to 1M tokens) [12–14]. Such models attain qualitatively new capabilities, such as extremely large-scale in-context learning of entire small datasets held in the prompt [15–17]. They can also avoid putting multi-modal continuous data through a lossy tokenization scheme [15, 18] by operating directly at the byte level [19, 20]. The issue, however, is that performing inference on such long contexts is very expensive.

To speed up inference and alleviate memory requirements, recent works have attempted to alter the attention mechanism itself, either by linearizing it [21] or by approximating it with a kernel map [22–24]; both reduce the complexity to linear at the cost of reduced expressiveness.
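For contrast, here is a minimal sketch of kernelized (linear) attention in its causal, recurrent form. The feature map φ(x) = elu(x) + 1 is an assumed, commonly used choice and is not taken from the cited works; `feature_map`, `linear_attention_decode`, and `linear_attention_reference` are hypothetical names. Instead of a growing KV cache, the recurrence keeps a fixed-size state (S, z), so each decoding step costs the same regardless of context length. The trade-off is that the similarity exp(q·k) is replaced by φ(q)·φ(k), which is the loss of expressiveness the text refers to.

```python
import numpy as np


def feature_map(x):
    """Positive feature map phi(x) = elu(x) + 1 (an assumed choice for illustration)."""
    return np.where(x > 0, x + 1.0, np.exp(x))


def linear_attention_decode(Q, K, V):
    """Causal linear attention with a constant-size state (S, z) instead of a KV cache."""
    d_k, d_v = K.shape[1], V.shape[1]
    S = np.zeros((d_k, d_v))   # running sum of phi(k_a) v_a^T
    z = np.zeros(d_k)          # running sum of phi(k_a), used for normalization
    outs = []
    for q, k, v in zip(Q, K, V):
        fk = feature_map(k)
        S += np.outer(fk, v)
        z += fk
        fq = feature_map(q)
        outs.append(fq @ S / (fq @ z))  # O(d_k * d_v) per token, independent of position
    return np.array(outs)


def linear_attention_reference(Q, K, V):
    """Same outputs computed explicitly with O(t) weights per step, for checking."""
    fQ, fK = feature_map(Q), feature_map(K)
    outs = []
    for t in range(len(Q)):
        w = fK[: t + 1] @ fQ[t]         # unnormalized kernel similarities to past keys
        outs.append(w @ V[: t + 1] / w.sum())
    return np.array(outs)


rng = np.random.default_rng(0)
T, d = 10, 4
Q, K, V = rng.normal(size=(3, T, d))
print(np.allclose(linear_attention_decode(Q, K, V), linear_attention_reference(Q, K, V)))  # True
```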

Correspondence to: vasu@zyphra.com, jonathan@zyphra.com

Others have invented alternative sequence-mixing architectures, such as state-space models, which are designed to be efficiently computable in linear time and constant memory [25–29]. Other approaches use efficient algorithms to reduce the computational burden of attention while keeping the core computation the same. These include memory-efficient attention [30], Flash Attention [31], and Flash Decoding [32], which provide a set of IO-aware kernels that map the attention operation onto GPU hardware resources extremely efficiently, significantly reducing the memory overhead required. Further works [33–36] explore compressing or otherwise reducing the KV cache required in generation. Finally, Ring Attention [37] proposes a way to parallelize the attention computation across the sequence axis between GPUs, thus enabling significantly longer contexts than can be served on a single GPU. This is the regime of primary interest in this paper. By leveraging the exact energy function for the self-attention block, we develop a method to speed up inference for long-context use cases when keys and values are sharded across multiple GPUs along the sequence axis.
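When keys and values are sharded along the sequence axis, each device can compute a partial result over its own shard, and the partials can be merged exactly. The sketch below uses the standard log-sum-exp rescaling, where each partial is a running max, a sum of exponentials, and an unnormalized weighted sum of values. It is an illustrative assumption, not the authors' kernel or the Flash Decoding implementation, and `shard_partial` and `merge` are hypothetical names; the shards are simulated on one machine.

```python
import numpy as np


def shard_partial(q, K_shard, V_shard):
    """Per-device partial result: (local max, sum of exps, unnormalized numerator)."""
    scores = K_shard @ q
    m = scores.max()
    e = np.exp(scores - m)
    return m, e.sum(), e @ V_shard


def merge(a, b):
    """Combine two partials exactly by rescaling both to a common max (associative)."""
    m_a, s_a, o_a = a
    m_b, s_b, o_b = b
    m = max(m_a, m_b)
    ca, cb = np.exp(m_a - m), np.exp(m_b - m)
    return m, s_a * ca + s_b * cb, o_a * ca + o_b * cb


rng = np.random.default_rng(0)
n, d, P = 64, 8, 4                      # sequence length, head dim, simulated devices
q = rng.normal(size=d)
K = rng.normal(size=(n, d))
V = rng.normal(size=(n, d))

partials = [shard_partial(q, Ks, Vs)
            for Ks, Vs in zip(np.array_split(K, P), np.array_split(V, P))]
acc = partials[0]
for p in partials[1:]:
    acc = merge(acc, p)
m, s, o = acc
sharded_out = o / s                     # normalize once, after all shards are merged

scores = K @ q                          # unsharded reference on a single "device"
w = np.exp(scores - scores.max())
full_out = (w / w.sum()) @ V
print(np.allclose(sharded_out, full_out))  # True
```

Because `merge` is associative, the order in which partials are combined does not change the result, which is what allows a tree-shaped rather than strictly sequential reduction.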

Our proposed algorithm for computing attention via the gradient of the energy function is built on top of an efficient parallel computation and a tree-reduction communication strategy. In particular, this formulation lets us devise an asymptotically faster decoding algorithm in which the number of communication steps scales logarithmically with the number of devices, rather than linearly as in alternatives such as Ring Attention [37]. Our topology-aware approach, illustrated in Fig. 1, significantly outperforms leading attention parallelization methods such as Ring Attention on multiple devices. In this work, we make three core contributions:

• We provide a mathematical form for the energy function of self-attention.

• From this theory, we develop an algorithm for parallelizing the attention computation across devices, leveraging tree-reduction topology.

arXiv:2408.04093v3 [cs.LG] 14 Aug 2024
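The communication-step claim in the introduction (logarithmic versus linear in the number of devices) can be illustrated by counting rounds in a toy simulation. This sketch reuses the log-sum-exp merge rule, models a tree reduction as pairwise rounds, and models the ring as a sequential chain of P-1 hops, which is a simplification of Ring Attention's actual key/value rotation schedule. All names are hypothetical and no real inter-GPU communication is performed.

```python
import math
import numpy as np


def merge(a, b):
    """Exact, associative combine of (max, sum-of-exps, numerator) partials."""
    m = max(a[0], b[0])
    ca, cb = np.exp(a[0] - m), np.exp(b[0] - m)
    return m, a[1] * ca + b[1] * cb, a[2] * ca + b[2] * cb


def tree_reduce(partials):
    """Pairwise reduction; each round roughly halves the number of live partials."""
    rounds = 0
    while len(partials) > 1:
        nxt = [merge(partials[i], partials[i + 1])
               for i in range(0, len(partials) - 1, 2)]
        if len(partials) % 2:           # the odd partial out is carried to the next round
            nxt.append(partials[-1])
        partials = nxt
        rounds += 1
    return partials[0], rounds


def ring_reduce(partials):
    """Simplified ring: one sequential hop per additional device."""
    acc, hops = partials[0], 0
    for p in partials[1:]:
        acc = merge(acc, p)
        hops += 1
    return acc, hops


rng = np.random.default_rng(0)
P, n, d = 16, 256, 8                    # simulated devices, sequence length, head dim
q = rng.normal(size=d)
K, V = rng.normal(size=(n, d)), rng.normal(size=(n, d))
partials = []
for Ks, Vs in zip(np.array_split(K, P), np.array_split(V, P)):
    s = Ks @ q
    e = np.exp(s - s.max())
    partials.append((s.max(), e.sum(), e @ Vs))

(_, s_t, o_t), rounds = tree_reduce(partials)
(_, s_r, o_r), hops = ring_reduce(partials)
print(rounds, hops, math.ceil(math.log2(P)))  # 4 15 4
print(np.allclose(o_t / s_t, o_r / s_r))      # True
```

With 16 simulated devices the tree needs 4 rounds against 15 ring hops, while both produce the same attention output. On real clusters the achieved speedup also depends on the interconnect topology, which this toy model ignores.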

References

Shyam, Vasudev, Jonathan Pilault, Emily Shepperd, Quentin Anthony, and Beren Millidge. 2024. “Tree Attention: Topology-Aware Decoding for Long-Context Attention on GPU Clusters.” https://arxiv.org/abs/2408.04093.