Tree Attention: Topology-Aware Decoding for Long-Context Attention on GPU Clusters

paper review

Published

Tuesday, September 17, 2024

TL;DR

In (Shyam et al. 2024) the authors propose a new algorithm for parallelizing attention computation across multiple GPUs. This enables cross-device decoding to be performed asymptotically faster (up to 8× faster in their experiments) than alternative approaches such as Ring Attention, while also requiring significantly less communication volume and incurring 2× less peak memory.

Abstract

Self-attention is the core mathematical operation of modern transformer architectures and is also a significant computational bottleneck due to its quadratic complexity in the sequence length. In this work, we derive the scalar energy function whose gradient computes the self-attention block, thus elucidating the theoretical underpinnings of self-attention, providing a Bayesian interpretation of the operation and linking it closely with energy-based models such as Hopfield Networks. Our formulation reveals that the reduction across the sequence axis can be efficiently computed in parallel through a tree reduction. Our algorithm, for parallelizing attention computation across multiple GPUs, enables cross-device decoding to be performed asymptotically faster (up to 8× faster in our experiments) than alternative approaches such as Ring Attention, while also requiring significantly less communication volume and incurring 2× less peak memory.

(Shyam et al. 2024)
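The abstract's key claim is that the reduction along the sequence axis is associative and can therefore be computed as a tree reduction. Below is a minimal single-device JAX sketch of that idea, my own illustration rather than the paper's released implementation: each "chunk" stands in for the keys and values held by one GPU, `chunk_partial` and `combine` are hypothetical helper names, the pairwise loop stands in for the cross-device allreduce, and the usual 1/√d score scaling is omitted for brevity.

```python
import jax
import jax.numpy as jnp

def chunk_partial(q, K, V):
    """Per-chunk partials for one query: (running max, sum of exps, weighted numerator)."""
    scores = K @ q                        # (chunk_len,)
    m = scores.max()
    w = jnp.exp(scores - m)               # numerically stabilised exponentials
    return m, w.sum(), w @ V

def combine(a, b):
    """Associative combine of two partials -- the operator a tree reduction / allreduce applies."""
    (ma, sa, na), (mb, sb, nb) = a, b
    m = jnp.maximum(ma, mb)
    return (m,
            sa * jnp.exp(ma - m) + sb * jnp.exp(mb - m),
            na * jnp.exp(ma - m) + nb * jnp.exp(mb - m))

key = jax.random.PRNGKey(0)
d, n_chunks, chunk_len = 16, 8, 32
kq, kk, kv = jax.random.split(key, 3)
q = jax.random.normal(kq, (d,))
Ks = jax.random.normal(kk, (n_chunks, chunk_len, d))
Vs = jax.random.normal(kv, (n_chunks, chunk_len, d))

# Pairwise (tree-shaped) reduction over the per-chunk partials,
# standing in for the cross-GPU allreduce described in the paper.
partials = [chunk_partial(q, K, V) for K, V in zip(Ks, Vs)]
while len(partials) > 1:
    partials = [combine(partials[i], partials[i + 1]) for i in range(0, len(partials), 2)]
_, denom, numer = partials[0]
out_tree = numer / denom

# Reference: ordinary softmax attention over the full concatenated sequence.
K_full, V_full = Ks.reshape(-1, d), Vs.reshape(-1, d)
out_ref = jax.nn.softmax(K_full @ q) @ V_full
print(jnp.allclose(out_tree, out_ref, atol=1e-5))  # True
```

Because `combine` is associative, the same reduction can be scheduled as a topology-aware allreduce across GPUs, which is what gives the method its logarithmic reduction depth in the number of devices.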

Resources

code

Softmax

The softmax operation can be derived as the gradient of the log-sum-exp scalar function:

\frac{\partial}{\partial z_j} \log \sum_{a=1}^{n} e^{z_a} = \frac{e^{z_j}}{\sum_{a=1}^{n} e^{z_a}} = \text{softmax}(z)_j
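As a quick numerical check, the following JAX snippet (my own sketch, not taken from the paper's code) verifies that differentiating log-sum-exp reproduces the softmax:

```python
import jax
import jax.numpy as jnp

def log_sum_exp(z):
    # Scalar function: log of the sum of exponentials.
    return jnp.log(jnp.sum(jnp.exp(z)))

z = jnp.array([0.5, -1.2, 3.0, 0.1])

# Gradient of the scalar log-sum-exp w.r.t. z ...
grad_lse = jax.grad(log_sum_exp)(z)

# ... matches the softmax of z.
print(jnp.allclose(grad_lse, jax.nn.softmax(z)))  # True
```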

References

Shyam, Vasudev, Jonathan Pilault, Emily Shepperd, Quentin Anthony, and Beren Millidge. 2024. “Tree Attention: Topology-Aware Decoding for Long-Context Attention on GPU Clusters.” https://arxiv.org/abs/2408.04093.