Explanation of FlexAttention by the team that developed it.
This is a quick note about FlexAttention. I always thought of attention as a design pattern for neural networks. But the idea took on many forms as different attention mechanisms were developed in the wake of the transformer revolution in deep learning.
One problem that I felt kept attention mechanisms from being a design pattern is that they are too tightly coupled to the rest of the implementation. This is both a practical and a conceptual problem. Design patterns work best when they are modular and can be combined in different ways. So far I could see just two ways to use attention: in an RNN or in a transformer.
I wanted the power of attention mechanisms in other models too, such as probabilistic programming and hierarchical Bayesian models, and also as something that can be used in a more modular way.
We see a number of moving parts in the attention mechanism design.
- First is the choice of attention mechanism.
- Secondly we need to handle subtleties like masking and padding to get these to work.
- We may want to weight tokens based on the distance between a context token and the current token.
- We may want to use logit soft-capping, as used in models like Grok-1 and Gemma 2, to keep the logits from exploding.
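The last two items in this list can both be read as small functions applied to each raw attention logit. Here is a minimal sketch in plain Python, not the FlexAttention API itself. The names `soft_cap` and `distance_bias` and the values `cap=30.0` and `slope=0.5` are my own illustrative assumptions, not values taken from any particular model.

```python
import math


def soft_cap(score: float, cap: float = 30.0) -> float:
    """Squash a raw logit smoothly into the range (-cap, cap) using tanh."""
    return cap * math.tanh(score / cap)


def distance_bias(score: float, q_idx: int, kv_idx: int, slope: float = 0.5) -> float:
    """Penalize a logit linearly in the distance between query and key positions."""
    return score - slope * abs(q_idx - kv_idx)


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


# Raw logits of one query (at position 3) against keys at positions 0..3.
# The value 500.0 stands in for an "exploding" logit.
raw = [2.0, 500.0, -1.0, 4.0]
q_idx = 3

# First cap each logit, then apply the distance penalty, then normalize.
modified = [distance_bias(soft_cap(s), q_idx, kv_idx) for kv_idx, s in enumerate(raw)]
weights = softmax(modified)
```

For small logits `soft_cap` is close to the identity, so only the outliers are squashed. The order in which the two modifications are composed is a design choice I picked arbitrarily here.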
When FlashAttention came out, I saw that there was a big push to make attention faster and aware of the hardware it runs on.
So we need a design that handles all these moving parts and is optimized for different hardware.
This is a level of complexity that should be relegated to libraries and compilers and handled by AI engineers, not by poor data scientists who are trying to make models work with their data.
Here at last is a flexible implementation of attention for PyTorch that allows data scientists to mix and match attention mechanisms to suit their needs.
Explanation of How Chat GPT 3 and 4 work.
Watching Andrej Karpathy’s videos, e.g. Let’s build GPT: from scratch, in code, spelled out, I came to realize that many of the big-name papers and architectures in deep learning amount to one or two lines of code in our favorite framework, and that even if we write them from scratch we are down to adding a small function or two.
This gave me hope that reading these famous papers with the goal of actually implementing them boils down to a manageable task.
One paper that caught my attention was Mixture-of-Recursions. Reading about it, I saw that one of the improvements the authors hoped to implement was FlexAttention. They also talked about document masking, and I wasn’t sure what that was. It turns out that we may pack many documents into one sequence, and we then want to restrict attention so that each token looks only at tokens from the same document.
I have spent a lot of time on attention mechanisms. I think that in many ways they are the core of the Transformer architecture and, to some degree, of the earlier RNNs that used them.
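Document masking can be sketched in a few lines of plain Python. The function name `document_causal_mask` and the `doc_ids` layout are hypothetical names of mine, and the extra causal restriction (no looking at future tokens) is my assumption about how this is typically combined with language-model training. It is not taken from the paper.

```python
def document_causal_mask(doc_ids):
    """Return mask[q][kv] = True when query position q may attend to key position kv.

    A query may look at a key only if both tokens belong to the same document
    and the key is not in the future (kv <= q).
    """
    n = len(doc_ids)
    return [
        [doc_ids[q] == doc_ids[kv] and kv <= q for kv in range(n)]
        for q in range(n)
    ]


# Two documents packed into one sequence: tokens 0-2 belong to document 0,
# tokens 3-4 belong to document 1.
doc_ids = [0, 0, 0, 1, 1]
mask = document_causal_mask(doc_ids)

# Print the mask as a grid: "x" = may attend, "." = blocked.
for row in mask:
    print("".join("x" if allowed else "." for allowed in row))
```

The printed grid is block-diagonal: the first token of the second document sees only itself, even though three earlier tokens sit in the same sequence.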
In fact attention is something that would probably be useful in other ML and statistical models.
FlexAttention aims to provide a drop-in replacement for the attention mechanism in Transformers. It is based on the idea that many attention variants, each with its own strengths and weaknesses, can be expressed as small modifications to the same core computation, so we can mix and match them to improve the overall performance of the model.
We can use it to quickly try out different attention mechanisms and see which one works best for our specific task.
This qualifies it for the “bag of tricks” as it is a small addition to our code that can have a big impact on performance.
It also serves to summarize in practical terms many ideas about attention mechanisms in one place.
I can’t do a deep dive into FlexAttention here, but the first video does a great job of explaining it. We get an API to define the different moving parts, and FlexAttention then combines them efficiently in a way that is optimized for the hardware we are using.
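To make the shape of that API concrete, here is a slow reference sketch in plain Python. It is not the PyTorch implementation. In PyTorch the real entry points live in `torch.nn.attention.flex_attention`, where `score_mod` and `mask_mod` also receive batch and head indices. I drop those here for brevity, so the signatures below are a simplification of mine.

```python
import math


def attention(q, k, v, score_mod=None, mask_mod=None):
    """Single-head attention over lists of vectors, with two optional hooks.

    score_mod(score, q_idx, kv_idx) -> float  rewrites one logit.
    mask_mod(q_idx, kv_idx) -> bool           True means "may attend".
    Assumes every query is allowed to attend to at least one key.
    """
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for qi, q_vec in enumerate(q):
        scores = []
        for ki, k_vec in enumerate(k):
            s = scale * sum(a * b for a, b in zip(q_vec, k_vec))
            if score_mod is not None:
                s = score_mod(s, qi, ki)
            if mask_mod is not None and not mask_mod(qi, ki):
                s = -math.inf  # blocked positions get zero weight after softmax
            scores.append(s)
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        weights = [e / total for e in exps]
        out.append([sum(w * v_vec[d] for w, v_vec in zip(weights, v))
                    for d in range(len(v[0]))])
    return out


def causal(q_idx, kv_idx):
    """Mask hook: a query may not look at future keys."""
    return kv_idx <= q_idx


def rel_bias(score, q_idx, kv_idx):
    """Score hook: linear penalty on distance (the slope 0.5 is arbitrary)."""
    return score - 0.5 * abs(q_idx - kv_idx)


q = k = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
v = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]

plain = attention(q, k, v)
modified = attention(q, k, v, score_mod=rel_bias, mask_mod=causal)
```

Swapping in a different attention variant means passing different hook functions. The loop itself never changes, and that loop is the part a library can compile into a fused, hardware-aware kernel.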
I do want to provide a quick overview of what I want to use it for so I can come back to it later.
- I came across an early paper in which there was just one attention head, and all it did was count so that the output sequence had coverage of the input sequence. (CN?)
- I always thought that we should be able to give a model a few examples of what we think it should pay attention to and let it train a head on that.
- Stated from a more general Bayesian point of view, I want the attention heads to learn certain features. To do this we should be able to equip them with some priors. This would bootstrap the learning process by defining a divide-and-conquer strategy.
- Competitive heads: Heads should be competitive. They should try to avoid features covered by other heads. This idea can be formalized using game theory, and there are similar ideas in papers by Hinton from before attention existed. He talks about this in his course; the main thrust is that experts can be made either to compete and specialize or to work as a group by correcting each other. I think that in more modern terms we are talking about bagging vs. boosting. He suggests that the second approach leads to problems: some agents can be poor, and then the rest work extra hard to accommodate them. Trying to track down the actual paper showed that he looked into this for a long time and that there are many related papers on the topic. This work is of extra interest here because we can use these architectures (mixtures of experts) with attention heads and, more generally, use Bayesian ideas to combine heads and exploit parallelism.
- We should have the option of having some heads look at the “residual”, i.e. find features not covered by other heads.
- We should also have the option to do this in an automated fashion and see what the heads are learning. This would not only be a form of unsupervised learning but also a very powerful EDA tool.
- N-level features: While we said the heads should be competitive, we recognize that many problems are multilevel and that we may need to build features on top of features learned by other heads.
- Here the ideas from Mixture-of-Recursions come in: ideally we want to control the hierarchy of heads and the features they learn.
Citation
@online{bochman2025,
author = {Bochman, Oren},
title = {FlexAttention: {A} {Flexible} {Approach} to {Attention}
{Mechanisms}},
date = {2025-09-20},
url = {https://orenbochman.github.io/posts/2025/2025-09-20_FlexAttention/},
langid = {en}
}