Linear Attention
Let’s start by looking at the standard attention formula:
$$ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V $$
Where:
- $V$ is the value matrix representing meanings in the sequence, dimensions $n \times d_v$
- $K$ is the key matrix representing indexes/locations of the values, dimensions $n \times d_k$
- $Q$ is the query matrix representing information we currently need, dimensions $n \times d_k$
- $d_k$ is the dimension of the key vectors
- The output is a weighted sum of the value vectors $V$, where the weights are determined by the attention scores.
This famously has a time complexity of $O(n^2)$ in the number of tokens $n$, due to the $QK^T$ matrix multiplication, which is a real bottleneck for longer-context tasks. The main idea of linear attention is to remove the softmax so that, by associativity of matrix multiplication, the computation becomes linear in the number of tokens. Consider the following (ignoring the constant $\sqrt{d_k}$ for now):
$$ \begin{aligned} V^\prime &= QK^TV \\ &= Q(K^TV) \end{aligned} $$
It might seem like we haven’t done anything, but just by using associativity of matrix multiplication, we can now compute $K^TV$ first, which is in $O(d_k \cdot d_v \cdot n)$, and then multiply the result by $Q$, which is in $O(n \cdot d_k \cdot d_v)$ again. Most importantly, this is linear in $n$, since $d_k$ and $d_v$ are constants.
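To make the cost difference concrete, here is a minimal NumPy sketch (non-causal, no softmax, random data purely for illustration): both multiplication orders give the same result, but only one of them ever materializes an $n \times n$ matrix.

```python
import numpy as np

# Illustrative shapes following the text: Q, K are (n, d_k), V is (n, d_v).
n, d_k, d_v = 1024, 64, 64
Q = np.random.randn(n, d_k)
K = np.random.randn(n, d_k)
V = np.random.randn(n, d_v)

# Quadratic order: (Q K^T) V builds an n x n attention matrix first.
out_quadratic = (Q @ K.T) @ V      # O(n^2 * d) time, O(n^2) memory

# Linear order: K^T V is only d_k x d_v, independent of n.
out_linear = Q @ (K.T @ V)         # O(n * d_k * d_v) time, O(d_k * d_v) memory

assert np.allclose(out_quadratic, out_linear)   # same result, different cost
```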
Now, considering the auto-regressive goal of our model, we can write the output token at time $t$, which is the $t$-th row of the output matrix, as (writing $K_{1:t}$ and $V_{1:t}$ for the first $t$ rows of $K$ and $V$ to respect causality):
$$ \begin{aligned} V^\prime_t &= q_t(K_{1:t}^TV_{1:t}) \\ &= q_t \sum_{i=1}^t k_i^T \otimes v_i \\ \end{aligned} $$
Here we used the equivalence between the matrix product $K^TV$ and the sum of outer products of the rows of $K$ and the rows of $V$. For some reason unclear to me, this equation is usually written as $q_t \sum_{i=1}^t k_i^T v_i$ in the literature, without differentiating between the outer and inner products. Either way, we can see that the sum $\sum_{i=1}^t k_i^T \otimes v_i$ contains the $t-1$ outer products we already computed for the previous token. This means that the output for the new token $t$ is simply:
$$ V^\prime_t = q_t (\sum_{i=1}^{t-1} k_i^T \otimes v_i + k_t^T \otimes v_t) $$
Let $S_t = \sum_{i=1}^{t} k_i^T \otimes v_i = S_{t-1} + k_t^T \otimes v_t$, and we can write:
$$ V^\prime_t = q_t S_t $$
$S_t$ can now be thought of as the hidden state of the model, which is updated at each time step, just like in RNNs. The only thing we need to save is the hidden state $S_t$, which means constant memory usage in comparison to the $O(n)$ memory of standard attention.
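A minimal sketch of this recurrent view, assuming the shape conventions from the text (the dimension sizes, random test data, and helper name `linear_attention_recurrent` are mine): the causal output is produced token by token from a single $d_k \times d_v$ state.

```python
import numpy as np

def linear_attention_recurrent(Q, K, V):
    """Causal linear attention as an RNN: keep only the d_k x d_v state S_t.
    Follows the text's convention S_t = S_{t-1} + k_t^T (outer) v_t, o_t = q_t S_t."""
    n, d_k = Q.shape
    d_v = V.shape[1]
    S = np.zeros((d_k, d_v))          # constant-size hidden state
    out = np.empty((n, d_v))
    for t in range(n):
        S += np.outer(K[t], V[t])     # S_t = S_{t-1} + k_t^T ⊗ v_t
        out[t] = Q[t] @ S             # o_t = q_t S_t
    return out

# sanity check against the causal "parallel" form (masked Q K^T, no softmax)
n, d_k, d_v = 16, 4, 4
Q, K, V = (np.random.randn(n, d) for d in (d_k, d_k, d_v))
mask = np.tril(np.ones((n, n)))
assert np.allclose(linear_attention_recurrent(Q, K, V), (mask * (Q @ K.T)) @ V)
```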
One detail glossed over above is normalization: the original linear attention divides the output by $q_t \sum_{i=1}^t k_i^T$, the analogue of the softmax denominator. The current consensus is that this normalization term leads to worse performance and is therefore dropped in favor of different normalization techniques (Qin et al., 2022; Yang et al., 2025). Further, we can write any kernel similarity function as $sim(q, k) = \phi(q) \cdot \phi(k)$, where $\phi$ is a feature representation (map) (Katharopoulos et al., 2020). As for the choice of $\phi$, a number of works have shown that the identity works well enough (Sun et al., 2023), ergo the formulation we had above.
Generalized (Linear) Attention
In fact, we can go further in generalizing sequence models. If we consider attention as an associative memorization mechanism, where we have a set of key and value pairs $(k_1, v_1), (k_2, v_2), \ldots, (k_T, v_T) \in \mathbb{R}^{d_k} \times \mathbb{R}^{d_v}$ up to sequence length $T$, the associative memory system is a function $m$ such that $m(k_t) \approx v_t$ for all $t \in [1, T]$. The learning goal of our attention mechanism then becomes minimizing the reconstruction error between $m(k_t)$ and $v_t$. To generate the outputs $o_t$ we apply our memory function to the query $q_t$: $o_t = m(q_t)$. Note that this formulation is not limited to language models and applies to both causal and non-causal settings (Wang et al., 2025).
Further, we can assign weights to each association when minimizing $m$ depending on the problem we are modelling. For example, language tokens (or any time-series data) are not stationary, meaning that recent tokens are more important than older ones. This can be achieved by applying a set of decaying weights to the association error: $\gamma_i^{(t)} = \prod_{j=i+1}^{t} \gamma_j, \gamma_j \in [0, 1]$.
This reduces our attention mechanism to a choice of three components that we can use as a framework for understanding and working with linear attention models (Wang et al., 2025):
- Memory function $m$.
- Weighting function $\gamma$.
- Optimization algorithm of $m$.
Comparing this to our previous formulation of vanilla linear attention, we see that the memory function is simply the linear matrix of sums $S_t = \sum_{i=1}^{t} k_i^T \otimes v_i$, the weighting function is constant, and the optimization algorithm is the incremental update with the outer products. This gives us an interesting insight, however, into the limitation of linear attention. Let’s see what happens when we apply our $m$ to a key $k_j$, assuming the keys are normalized ($\|k_j\| = 1$):
$$ \begin{aligned} m(k_j) &= k_j \cdot \sum_{i=1}^{t} k_i^T \otimes v_i \\ &= \sum_{i=1}^{t} (k_j \cdot k_i^T)\, v_i \\ &= v_j + \sum_{i=1, i \neq j}^{t} (k_j \cdot k_i^T)\, v_i \end{aligned} $$
The last term is basically retrieval error, which is only zero if $k_j$ is orthogonal to all other keys. Orthogonality can only be guaranteed if the key dimension $d_k$ (i.e. the number of rows of $S_t$) is at least as large as the sequence length $n$, but, since our memory state $S_t$ has constant dimensions, unlike the expanding KV cache of transformers, this is practically impossible for long sequences. This is a hard limit on the expressiveness of linear attention, and ongoing work has been more or less about pushing or circumventing this limitation.
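A toy NumPy illustration of this retrieval-error argument (the sizes and data are arbitrary): with at most $d_k$ orthonormal keys the memory retrieves values exactly, while with more keys than dimensions the cross terms show up as error.

```python
import numpy as np

d_k, d_v = 8, 8
rng = np.random.default_rng(0)

def store_and_retrieve(keys, values):
    # memory S = sum_i k_i^T ⊗ v_i ; retrieval m(k_j) = k_j S
    S = sum(np.outer(k, v) for k, v in zip(keys, values))
    return np.stack([k @ S for k in keys])

# Case 1: at most d_k mutually orthonormal keys -> exact retrieval, ~zero error.
ortho_keys = np.linalg.qr(rng.standard_normal((d_k, d_k)))[0]
values = rng.standard_normal((d_k, d_v))
print(np.abs(store_and_retrieve(ortho_keys, values) - values).max())   # ~1e-15

# Case 2: more keys than d_k dimensions -> they cannot all be orthogonal,
# and the cross terms show up as retrieval error.
many_keys = rng.standard_normal((4 * d_k, d_k))
many_keys /= np.linalg.norm(many_keys, axis=1, keepdims=True)           # normalized keys
many_values = rng.standard_normal((4 * d_k, d_v))
print(np.abs(store_and_retrieve(many_keys, many_values) - many_values).max())  # large
```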
Linear Attention Variants
DeltaNets
The Delta Rule in DeltaNet “derives its name from the core principle of updating weights based on the ‘delta’ (difference) between the prediction $S_{t-1}k_t$ and the target $v_t$” (Yang et al., 2025). This is more or less online learning (SGD) on the value tokens during inference (test time), with a squared-error loss and a per-token learning rate $\beta_t$:
$$ \ell_t(S) = \frac{1}{2} \left \| Sk_t − v_t \right \|^2, \\ S_t = S_{t−1} − \beta_t \nabla_{S_{t−1}}\ell_t(S_{t−1}) = S_{t−1} − \beta_t(S_{t−1}k_t − v_t)k_t^T $$
Vanilla linear attention can also be regarded as doing online learning on the value tokens, with the loss $\ell_t(S) = -\langle Sk_t,v_t \rangle$. This has been shown to underperform DeltaNets, however, and one intuitive way to understand the advantage of DeltaNet’s update rule is that minimizing the vanilla objective can also be achieved by simply increasing the magnitude of $Sk_t$, which leads to numerical instability and is a less accurate representation of what we actually want to achieve.
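As a sketch, here are the two update rules side by side in NumPy (the function names and shapes are my own, following the $Sk_t \approx v_t$ convention above); note how the delta rule pulls the memory's prediction toward $v_t$ instead of blindly accumulating.

```python
import numpy as np

def deltanet_step(S, k, v, beta):
    """One DeltaNet update: an SGD step on 1/2 * ||S k - v||^2 with learning rate beta.
    Here S has shape (d_v, d_k), matching the S k_t ≈ v_t convention above."""
    pred = S @ k                              # current "prediction" v_old = S_{t-1} k_t
    return S - beta * np.outer(pred - v, k)   # S_t = S_{t-1} - beta (S_{t-1} k_t - v_t) k_t^T

def linear_attention_step(S, k, v, beta=1.0):
    """Vanilla linear attention as online learning on -<S k, v>: the gradient step
    just keeps adding beta * v k^T, with nothing pulling S k back toward v."""
    return S + beta * np.outer(v, k)

# with beta = 1 and a unit-norm key, one delta step makes the memory exact for that key
d_k, d_v = 4, 4
k = np.array([1.0, 0.0, 0.0, 0.0]); v = np.random.randn(d_v)
S = np.random.randn(d_v, d_k)
assert np.allclose(deltanet_step(S, k, v, beta=1.0) @ k, v)
```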


Gated DeltaNets
As time progresses and our model sees new tokens, we want to enable our state to dynamically learn new associations, or unlearn old ones. To unlearn old information, Mamba 2 (Dao & Gu, 2024) introduces a gated state update $S_t = \alpha_t S_{t-1} + v_tk_t^T, \alpha_t \in (0,1)$. This helps, but it is a blanket forget gate: if the model needs to forget only one key-value association, all associations are decayed equally, which is suboptimal. In contrast, DeltaNets update associations one at a time, which means they lack the ability to rapidly clear outdated or irrelevant information, especially during context switches where previous data needs to be erased.
Combining both update mechanisms we get the best of both worlds: a per-association update and a general forget gate:
$$ \begin{align*} S_t &= S_{t-1} (\alpha_t (I - \beta_t k_t k_t^T)) + \beta_t v_t k_t^T \\ &= \alpha_t S_{t-1} - \alpha_t \beta_t S_{t-1} k_t k_t^T + \beta_t v_t k_t^T \\ &= \alpha_t S_{t-1} - \alpha_t \beta_t v_t^{old} k_t^T + \beta_t v_t k_t^T \end{align*} $$
Now, from the perspective of online learning and established SGD algorithms, all we did was incorporate a weight-decay term $\alpha_t$. Such gates are among the most common and effective ways to improve attention models.
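A minimal sketch of the gated update above (again with a hypothetical function name and the $S \in \mathbb{R}^{d_v \times d_k}$ convention); setting $\alpha_t = 1$ recovers plain DeltaNet, and $\beta_t = 0$ gives pure decay.

```python
import numpy as np

def gated_deltanet_step(S, k, v, alpha, beta):
    """Gated DeltaNet update from the derivation above:
    S_t = alpha_t * S_{t-1} (I - beta_t k_t k_t^T) + beta_t v_t k_t^T.
    alpha_t in (0, 1) is the global forget gate, beta_t the per-token write strength."""
    v_old = S @ k                     # what the memory currently returns for k_t
    return alpha * S - alpha * beta * np.outer(v_old, k) + beta * np.outer(v, k)
```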
DeltaProduct
Since we are taking a gradient step using the keys and values of each token, it seems natural that taking more than one step per token should lead to better performance. From a linear-algebra perspective, our state has a hard limit on memorization based on its rank, and higher-rank transition matrices improve expressiveness by incorporating more dimensions of the input token into each update: the first linear attention models had diagonal transition matrices, DeltaNets have diagonal-plus-rank-1 ones, and DeltaProduct expands that to diagonal-plus-rank-$n_h$.
Starting from the DeltaNet update rule above, generalizing to $n_h$ gradient steps on the same token and unrolling the recurrence gives:
$$ S_t = S_{t-1}\prod_{j=1}^{n_h} \left(I-\beta_{t,j}k_{t,j}k_{t,j}^T\right) + \sum_{j=1}^{n_h}\beta_{t,j}v_{t,j}k_{t,j}^T\prod_{l=j+1}^{n_h} \left(I-\beta_{t,l}k_{t,l}k_{t,l}^T\right) $$
Note that we are using different keys and values for each gradient step! This is important, because a product of the same DeltaNet transition matrix is, well, the “same” transition matrix, just with potentially different scaling (Siems et al., 2025). The ability to generate multiple, possibly orthogonal, keys per token is very helpful in improving the expressiveness of the model. Moreover, by using a different scaling factor $\beta_{t,j}$ for each key, the model can learn to “skip” certain steps for certain tokens if deemed unnecessary.
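In code, DeltaProduct is just $n_h$ DeltaNet micro-steps per token, each with its own key, value and $\beta$; a sketch (my own naming, same shape conventions as before), whose unrolling is exactly the product-of-Householders formula above:

```python
import numpy as np

def deltaproduct_step(S, ks, vs, betas):
    """DeltaProduct: n_h DeltaNet micro-steps on the same token.
    ks: (n_h, d_k), vs: (n_h, d_v), betas: (n_h,), S: (d_v, d_k)."""
    for k, v, beta in zip(ks, vs, betas):
        S = S - beta * np.outer(S @ k - v, k)   # one generalized-Householder update
    return S
```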
Test-time Training (TTT)
So far, we have only seen models whose state is simply a matrix, but for context-retrieval tasks it might be useful to have a more dynamic and powerful state. What if we made the state itself a neural network? That is, the output $o_t$ at time $t$ for token $x_t$ is the output of a neural network $f$ with weights $W_t$:
$$ o_t = f(x_t; W_t) $$
Where $f$ is a linear layer, an MLP or any other neural network. During inference (i.e. test-time), we update the weights of the network using the gradient of the loss function with respect to the current token:
$$ W_t = W_{t-1} - \eta \nabla \ell(W_{t-1}; x_t) $$
Our state update rule is thus gradient descent, and its behavior depends on the loss function we choose. For example, if we want the weights to memorize the previous context, we define $\ell$ as:
$$ \ell(W_{t-1}; x_t) = \frac{1}{2} \left \| f(\tilde{x_t}; W_{t-1}) - x_t \right \|^2 $$
Where $\tilde{x_t}$ is a corrupted version of $x_t$ to make the task non-trivial. In the case of $f$ being a linear layer, this is not too dissimilar to the DeltaNet variants we had above. Things get interesting when we use a more complex $f$, such as an MLP. Another difference is in the initialization of the weights. In the case of TTT, every sequence starts with weights $W_0$ which are learned during training. Training the model itself is referred to as the outer loop, while adapting the weights during inference is the inner loop.
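A heavily simplified sketch of one TTT inner-loop step, assuming a linear $f$ and a simple additive-noise corruption (the actual TTT layer learns its corruption/reconstruction projections and $W_0$ in the outer loop; all names here are illustrative):

```python
import numpy as np

def ttt_linear_step(W, x, eta=0.1, noise=0.1, rng=np.random.default_rng(0)):
    """One inner-loop step of a TTT-style layer with a linear memory f(x; W) = x @ W.
    W: (d, d), x: (d,). The additive-noise corruption is a toy stand-in."""
    x_tilde = x + noise * rng.standard_normal(x.shape)   # corrupted input
    err = x_tilde @ W - x                                 # f(x_tilde; W_{t-1}) - x_t
    W = W - eta * np.outer(x_tilde, err)                  # gradient step on 1/2 ||x_tilde W - x||^2
    return W, x @ W                                       # updated weights W_t, output o_t = f(x_t; W_t)
```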
Titans
Titans takes the idea of TTT one step further by adding weight decay and momentum to the network update step. The memory network $M$ update rule now becomes:
$$ M_t = (1-\alpha_t)M_{t-1} + S_t \\ S_t = \eta_t S_{t-1} - \theta_t \nabla \ell(M_{t-1}; x_t) $$
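As a sketch, the update above for a linear memory $M$ and the associative loss $\ell = \frac{1}{2}\|Mk_t - v_t\|^2$ (the paper's memory is a deeper network; the function name and shapes are mine):

```python
import numpy as np

def titans_update(M, S, k, v, alpha, eta, theta):
    """One Titans memory step with a linear memory M of shape (d_v, d_k).
    S is the momentum ("past surprise") buffer with the same shape."""
    grad = np.outer(M @ k - v, k)      # gradient of the associative memory loss
    S = eta * S - theta * grad         # past surprise + momentary surprise
    M = (1.0 - alpha) * M + S          # weight decay (forgetting) plus the surprise step
    return M, S
```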
In the paper, this is motivated in a neuroscientific way, where the gradient of the loss is called “surprise”, in the sense that surprising events are more likely to be remembered and to catch our attention for a while afterwards. That being said, you could also say that this is more or less reinventing deep learning within attention updates. That is not all that Titans introduced, though. Continuing with the neuroscientific motivation, Titans models long-term or persistent memory with a set of learned weights $P = \{p_1, p_2, \ldots, p_{N_p}\}$, which are fixed during inference and prepended to each sequence. The paper then introduces three schemes for incorporating the different types of memory, the most effective of which is the “Memory as a Context” scheme:
- chunk the sequence into $N/C$ fixed length segments $S^{(i)}$.
- run each segment through the memory network $M_{t-1}$ while fixing the weights of $M$: $h_t = M_{t-1}^*(S^{(t)}W_Q)$. In other words, no gradient steps on each token (no inner loop).
- concatenate the output of each segment $h_t$ with the memory weights $P$ at the start of the sequence: $\tilde{S}^{(t)} = [P; h_t; S^{(t)}]$ then run that through the attention module $y_t = \text{Attn}(\tilde{S}^{(t)})$.
- now use that to do the usual update/output steps:
$$ M_t = M_{t-1}(y_t) \\ o_t = y_t \otimes M_t^*(y_t) $$
The intuition here is that $P$ is task-independent general knowledge, while $h_t$ is the running context of the current sequence. Both should help the model better judge the importance of the new context $S^{(t)}$ appended after them.
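To make the data flow of the list above concrete, here is a rough structural sketch with many simplifications that are not in the paper: a single linear map as the memory, plain unprojected softmax attention as the attention module, a toy self-reconstruction loss, and element-wise multiplication standing in for the gating $\otimes$.

```python
import numpy as np

def softmax_attention(x):
    # stand-in attention module: plain, unprojected softmax self-attention
    scores = x @ x.T / np.sqrt(x.shape[-1])
    w = np.exp(scores - scores.max(axis=-1, keepdims=True))
    return (w / w.sum(axis=-1, keepdims=True)) @ x

class LinearMemory:
    """Hypothetical stand-in for the neural memory M: a single linear map trained online
    on a toy self-reconstruction loss (the real Titans memory uses the gated,
    momentum-based surprise update above)."""
    def __init__(self, d, lr=0.01):
        self.W = np.zeros((d, d))
        self.lr = lr
    def retrieve(self, x):               # M^*(x): forward pass with frozen weights
        return x @ self.W
    def update(self, x):                 # inner-loop step on 1/2 ||x W - x||^2
        self.W -= self.lr * x.T @ (x @ self.W - x)

def memory_as_context(segments, P, M):
    """segments: list of (C, d) chunks, P: (n_p, d) persistent tokens, M: memory module."""
    outputs = []
    for seg in segments:
        h = M.retrieve(seg)                            # read history, no inner-loop step
        s_tilde = np.concatenate([P, h, seg], axis=0)  # [persistent ; retrieved ; chunk]
        y = softmax_attention(s_tilde)[-len(seg):]     # attend, keep the chunk's positions
        M.update(y)                                    # now take the inner-loop step on y
        outputs.append(y * M.retrieve(y))              # o_t = y_t ⊗ M_t^*(y_t), ⊗ ≈ elementwise
    return np.concatenate(outputs, axis=0)

# usage on random data
d, C, n_p = 16, 8, 4
rng = np.random.default_rng(0)
segments = [rng.standard_normal((C, d)) for _ in range(3)]
P = rng.standard_normal((n_p, d))
print(memory_as_context(segments, P, LinearMemory(d)).shape)   # (24, 16)
```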
MesaNet
As we established, the test-time objective of our models is learning the state $S$ (a linear operator) that minimizes the difference between $Sk_t$ and $v_t$. We can formulate this as a “simple” linear system:
$$ Sk_t \approx v_t, \quad \forall t \in \{1, \ldots, T\} $$
More concretely, we have the following optimization problem: $$ S = \underset{S}{\operatorname{arg\,min}} \sum_{t=1}^{T} \frac{1}{2} \left\| Sk_t - v_t \right\|^2 + \frac{\lambda}{2} \left\| S \right\|_F^2 $$
With the second term being a regularization term. Most other models use iterative approximations of the solution, but the Mesa layer (von Oswald et al., 2024) formulates the exact solution to this equation as:
$$ S = \left(\sum_{t=1}^{T}v_t k_t^T\right) \underbrace{\left(\sum_{t=1}^{T} k_t k_t^T + \lambda I\right)^{-1}}_{:=R_T} $$
The inverse factor $R_t$ can then be maintained with a rank-1 recursive least-squares (Sherman–Morrison) update:
$$ R_t = R_{t-1} - \frac{R_{t-1} k_t k_t^T R_{t-1}}{1+k_t^T R_{t-1}k_t}, \ \ \ \ \ R_0 = \lambda^{-1}I $$
We thus have an iterative state update at each timestep $t$, with $S_t = (\sum_{i=1}^t v_i k_i^T)R_t$ and output $o_t = S_t q_t = (\sum_{i=1}^t v_i k_i^T)R_t q_t$. This yields the optimal state $S$ given all previous tokens, instead of the approximate solutions of DeltaNets and the like. MesaNet parallelizes this insight, where the authors use an iterative linear-system solver that is amenable to parallelization (the conjugate gradient method).
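A recurrent NumPy sketch of this exact solution (the names and the closed-form check at the end are mine); as noted above, MesaNet itself avoids materializing $R_t$ and instead solves the linear system with conjugate gradients in a chunkwise-parallel fashion:

```python
import numpy as np

def mesa_layer_recurrent(Q, K, V, lam=1.0):
    """Maintain U_t = sum_i v_i k_i^T and R_t = (sum_i k_i k_i^T + lam*I)^(-1)
    via the rank-1 Sherman-Morrison recursion, then output o_t = U_t R_t q_t."""
    n, d_k = Q.shape
    d_v = V.shape[1]
    U = np.zeros((d_v, d_k))
    R = np.eye(d_k) / lam                        # R_0 = lambda^{-1} I
    out = np.empty((n, d_v))
    for t in range(n):
        k, v, q = K[t], V[t], Q[t]
        U += np.outer(v, k)                      # sum_i v_i k_i^T
        Rk = R @ k
        R -= np.outer(Rk, Rk) / (1.0 + k @ Rk)   # Sherman-Morrison rank-1 update
        out[t] = U @ (R @ q)                     # o_t = (sum v_i k_i^T) R_t q_t
    return out

# check the last output against the closed-form ridge-regression solution
n, d_k, d_v = 32, 8, 8
Q, K, V = (np.random.randn(n, d) for d in (d_k, d_k, d_v))
S_exact = (V.T @ K) @ np.linalg.inv(K.T @ K + 1.0 * np.eye(d_k))
assert np.allclose(mesa_layer_recurrent(Q, K, V)[-1], S_exact @ Q[-1])
```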
Beyond Linear Attention
So far we have only talked about the attention mechanism itself and nothing else. Obviously, there’s a lot more that goes into creating a model than just that. The two most important aspects not discussed here are training and architectures. Due to the recurrent nature of linear attention, the naive formulation cannot be parallelized over the sequence during training, and a model that can’t be trained efficiently is as good as useless. The most common solution to this is the chunkwise parallel formulation (sketched below), which is explained in this and this post, probably better than I ever could.
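For completeness, a minimal sketch of the chunkwise idea (the chunk size, names, and naive intra-chunk part are illustrative): the state only advances once per chunk, while everything inside a chunk is computed with dense, parallelizable matmuls.

```python
import numpy as np

def chunkwise_linear_attention(Q, K, V, C=64):
    """Chunkwise-parallel form of causal linear attention: the state is carried between
    chunks (n/C sequential steps), while the intra-chunk part uses dense matmuls."""
    n, d_k = Q.shape
    d_v = V.shape[1]
    S = np.zeros((d_k, d_v))                 # inter-chunk state, as in the recurrent form
    out = np.empty((n, d_v))
    mask = np.tril(np.ones((C, C)))          # causal mask inside a chunk
    for start in range(0, n, C):
        q, k, v = Q[start:start+C], K[start:start+C], V[start:start+C]
        m = mask[:len(q), :len(q)]
        out[start:start+C] = q @ S + (m * (q @ k.T)) @ v   # inter-chunk + intra-chunk parts
        S += k.T @ v                                        # advance the state by a whole chunk
    return out

# matches the token-by-token recurrent form
n, d_k, d_v = 200, 8, 8
Q, K, V = (np.random.randn(n, d) for d in (d_k, d_k, d_v))
full_mask = np.tril(np.ones((n, n)))
assert np.allclose(chunkwise_linear_attention(Q, K, V, C=64), (full_mask * (Q @ K.T)) @ V)
```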
As for architectures, there is a wild variety of choices in the literature, from relatively simple LLaMA-like models to MoE to even “Mixture of Memories”. This makes it hard to compare attention variants, since a big part of a model’s performance is due to the architecture and not the attention mechanism itself. A popular choice for evaluating attention models is the MAD benchmark, to name just one.
Bibliography
- Dao, T. & Gu, A. (2024). Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality. Retrieved from https://arxiv.org/abs/2405.21060
- Katharopoulos, A., Vyas, A., Pappas, N. & Fleuret, F. (2020). Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention. Retrieved from https://arxiv.org/abs/2006.16236
- Qin, Z., Han, X., Sun, W., Li, D., Kong, L., Barnes, N. & Zhong, Y. (2022). The Devil in Linear Transformer. Association for Computational Linguistics. https://doi.org/10.18653/v1/2022.emnlp-main.473
- Siems, J., Carstensen, T., Zela, A., Hutter, F., Pontil, M. & Grazzi, R. (2025). DeltaProduct: Improving State-Tracking in Linear RNNs via Householder Products. Retrieved from https://arxiv.org/abs/2502.10297
- Sun, Y., Dong, L., Huang, S., Ma, S., Xia, Y., Xue, J., Wang, J. & Wei, F. (2023). Retentive Network: A Successor to Transformer for Large Language Models. Retrieved from https://arxiv.org/abs/2307.08621
- von Oswald, J., Schlegel, M., Meulemans, A., Kobayashi, S., Niklasson, E., Zucchet, N., Scherrer, N., Miller, N., Sandler, M., Agüera y Arcas, B., Vladymyrov, M., Pascanu, R. & Sacramento, J. (2024). Uncovering mesa-optimization algorithms in Transformers. Retrieved from https://arxiv.org/abs/2309.05858
- Wang, K., Shi, J. & Fox, E. (2025). Test-time regression: a unifying framework for designing sequence models with associative memory. Retrieved from https://arxiv.org/abs/2501.12352
- Yang, S., Wang, B., Zhang, Y., Shen, Y. & Kim, Y. (2025). Parallelizing Linear Transformers with the Delta Rule over Sequence Length. Retrieved from https://arxiv.org/abs/2406.06484