stat946W25
[https://forms.gle/SVzJpUXxQka11q83A Your feedback on presentations]
= Topic 12: State Space Models =
=== Introduction ===
State Space Models (SSMs) are introduced as powerful alternatives to traditional sequence modeling approaches. These models demonstrate strong performance across modalities, including time series analysis, audio generation, and image processing, and they can capture long-range dependencies more efficiently.
SSMs initially struggled to match the performance of Transformers in language modeling tasks, leaving a noticeable gap between the two families. To address these challenges, recent architectural advances such as the Structured State Space Model (S4) have been introduced; S4 succeeded in long-range reasoning tasks and allowed for more efficient computation while preserving theoretical strengths. However, its implementation remains complex and computationally demanding, so further research led to simplified variants such as the Diagonal State Space Model (DSS), which achieves comparable performance with a more straightforward formulation. In parallel, hybrid approaches such as the H3 model integrate SSMs with attention mechanisms to bridge this gap; for example, in H3 the authors replace almost all of the attention layers in a Transformer with SSMs. More recently, models like Mamba have pushed the boundaries of SSMs by selectively parameterizing the state matrices as functions of the input, allowing more flexible and adaptive information propagation.
Research in SSMs continues to resolve the remaining challenges, and the case for substituting attention-based architectures with SSMs grows stronger. SSMs will likely play a crucial role in the next generation of sequence modeling frameworks.
=== Advantages of SSMs ===
* Efficient Long-Range Dependency Handling: Unlike Transformers, which require <math>O(n^2)</math> complexity for self-attention, SSMs can process sequences in <math>O(n \log n)</math> time using efficient matrix-vector multiplications and the Fast Fourier Transform (FFT), or in <math>O(n)</math> in the case of Mamba-style architectures.
* Effective Long-Range Dependency Handling: By leveraging a state update mechanism (parametrized by <math>\lambda</math>) that preserves and accumulates signals over arbitrarily large time spans, SSMs effectively capture and retain information from distant points in a sequence, enabling robust long-range dependency handling.
* Lower Memory Requirements: Unlike Transformers, SSMs do not need to store the entire input in memory, leading to lower memory consumption.
* Parallelization: SSMs allow efficient parallelization while maintaining the benefits of recurrent computation, making them a powerful alternative to RNNs.
* Robustness to Irregularly Sampled Data: Unlike Transformers, which inherently assume regularly spaced inputs, State Space Models naturally handle irregularly sampled data, which is prevalent in real-world applications such as sensor data processing and clinical time series. By explicitly modeling the continuous-time evolution of latent states, SSMs provide robustness against missing or unevenly spaced observations, leading to improved performance when input data is incomplete or irregularly sampled.
* Interpretability and Diagnostic Insight: State Space Models offer a distinct interpretability advantage by representing system behavior explicitly through their state-transition dynamics. Unlike black-box models, the learned parameters in SSMs (matrices <math>\mathbf{A}</math>, <math>\mathbf{B}</math>, <math>\mathbf{C}</math>, and <math>\mathbf{D}</math>) can be analyzed to infer how inputs influence future states and outputs over time. This interpretability is especially valuable in fields where understanding model behavior is critical, such as financial risk modeling or biological systems analysis.
=== Core concepts ===
* <math>\mathbf{\bar A} = (I - \frac{\Delta}{2}A)^{-1}(I + \frac{\Delta}{2}A)</math>
* <math>\mathbf{\bar B} = (I - \frac{\Delta}{2}A)^{-1}\Delta B</math>
<math>\overline{A}</math>: the discretized state-transition matrix, an approximate representation of the continuous state matrix <math>A</math>.
<math>\overline{B}</math>: the discretized version of the input matrix <math>B</math>.
Look at [[#Discretization]] for further details.
===Discretization===
As mentioned, State Space Models come from ordinary differential equations, so the way we discretize those continuous equations so that they work with finite sequences is crucial. In fact, we aim to learn <math>\mathbf{\bar A}</math> and <math>\mathbf{\bar B}</math> instead of <math>A</math> and <math>B</math> directly, so the discretization step is baked into our model in practice. Below we show a quick discretization example based on the trapezoidal rule.
The trapezoidal rule assumes:
<math>x_{n+1} - x_{n} = \frac{\Delta}{2} \left(f(t_{n+1}) + f(t_{n})\right)</math>, where <math>x'(t) = f(t)</math>.
We start from the ordinary differential equation:
<math>
h'(t) = \mathbf{A} h(t) + \mathbf{B} x(t)
</math>
By applying the trapezoidal rule:
<math>
h_{n+1} - h_{n} = \frac{\Delta}{2} \left( \mathbf{A} h_{n+1} + \mathbf{B} x_{n+1} + \mathbf{A} h_{n} + \mathbf{B} x_{n} \right)
</math>
It is assumed that the control sequence does not change significantly over small <math>\Delta</math>, i.e., <math>x_{n+1} \approx x_n</math>, leading to:
<math>
h_{n+1} - h_{n} = \frac{\Delta}{2} \mathbf{A} \left( h_{n+1} + h_{n} \right) + \Delta \mathbf{B} x_{n+1}
</math>
Thus, solving for <math>h_{n+1}</math>:
<math>
h_{n+1} = \left(\mathbf{I} - \frac{\Delta}{2} \mathbf{A}\right)^{-1} \left(\mathbf{I} + \frac{\Delta}{2} \mathbf{A}\right) h_{n} + \left(\mathbf{I} - \frac{\Delta}{2} \mathbf{A}\right)^{-1} \Delta \mathbf{B}\, x_{n+1}
</math>
From this, the discretized state matrices are:
<math>
\mathbf{\bar A}=(\mathbf{I} - \frac{\Delta}{2} \mathbf{A})^{-1}(\mathbf{I} + \frac{\Delta}{2}\mathbf{A}), \quad
\mathbf{\bar B}=(\mathbf{I} - \frac{\Delta}{2} \mathbf{A})^{-1}\Delta \mathbf{B}.
</math>
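The bilinear update above is straightforward to implement. The following is a minimal NumPy sketch (not taken from any specific paper or library; matrix sizes, variable names, and the example system are illustrative assumptions) that computes <math>\mathbf{\bar A}</math> and <math>\mathbf{\bar B}</math> from a continuous pair <math>(A, B)</math> and a step size <math>\Delta</math>.
<syntaxhighlight lang="python">
import numpy as np

def discretize_bilinear(A, B, delta):
    """Bilinear (trapezoidal) discretization of a continuous SSM (A, B)."""
    I = np.eye(A.shape[0])
    inv = np.linalg.inv(I - (delta / 2.0) * A)   # (I - Δ/2 A)^{-1}
    A_bar = inv @ (I + (delta / 2.0) * A)        # discrete state-transition matrix
    B_bar = inv @ (delta * B)                    # discrete input matrix
    return A_bar, B_bar

# Example usage: a small random stable system with state size N = 4 and one input channel.
rng = np.random.default_rng(0)
A = -np.eye(4) + 0.1 * rng.standard_normal((4, 4))
B = rng.standard_normal((4, 1))
A_bar, B_bar = discretize_bilinear(A, B, delta=0.01)
</syntaxhighlight>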
An alternative and often-used discretization is based on the matrix exponential:
<math>
\mathbf{\bar{A}} = e^{A\Delta}, \quad
\mathbf{\bar{B}} = (A^{-1}(e^{A\Delta} - I)) B.
</math>
This formulation naturally arises in continuous-time state-space models and ensures numerical stability for stiff systems.
The matrix exponent <math>A\Delta</math> comes directly from the analytic solution of the continuous-time system:
<math>x(t) = e^{At} x(0) + \int_0^t e^{A(t-s)} B u(s)\, ds</math>
For discretization, we choose <math> t = k\Delta </math> as the sampling point to obtain:
<math>x_k = e^{A\Delta} x_{k-1} + (e^{A\Delta} - I) A^{-1} B u_k</math>
which is the source of the matrix-exponential method.
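For completeness, here is a hedged sketch of this matrix-exponential (zero-order-hold style) discretization, assuming <math>A</math> is invertible; it simply mirrors the formulas above and is not drawn from any specific implementation.
<syntaxhighlight lang="python">
import numpy as np
from scipy.linalg import expm

def discretize_expm(A, B, delta):
    """Matrix-exponential discretization: Ā = e^{AΔ}, B̄ = A^{-1}(e^{AΔ} - I) B."""
    A_bar = expm(A * delta)
    # Solve A X = (e^{AΔ} - I) B instead of forming A^{-1} explicitly.
    B_bar = np.linalg.solve(A, (A_bar - np.eye(A.shape[0])) @ B)
    return A_bar, B_bar
</syntaxhighlight>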
To express this in terms of a convolution kernel, the system can be reformulated as:
<math>
\bar{K} = ( C \bar{B}, C \bar{A} \bar{B}, ..., C \bar{A}^{L-1} \bar{B} ).
</math>
This can also be rewritten using the exponential formulation:
<math>
K = \left( C e^{A k \Delta} (e^{A\Delta} - I) A^{-1} B \right)_{0 \leq k < L}.
</math>
This convolution kernel representation helps efficiently compute long-range dependencies in sequences.
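As a concrete illustration, the sketch below (a toy version, assuming a single-input, single-output channel with <math>C</math> of shape (1, N) and <math>\bar{B}</math> of shape (N, 1); it is not the optimized S4 algorithm) materializes the kernel <math>\bar{K}</math> and applies it to an input sequence with an FFT-based causal convolution.
<syntaxhighlight lang="python">
import numpy as np

def ssm_kernel(A_bar, B_bar, C, L):
    """Materialize K̄ = (C B̄, C Ā B̄, ..., C Ā^{L-1} B̄) for a scalar output channel."""
    K = np.empty(L)
    v = B_bar.copy()              # holds Ā^k B̄
    for k in range(L):
        K[k] = (C @ v).item()     # C Ā^k B̄
        v = A_bar @ v
    return K

def ssm_convolve(K, u):
    """Causal convolution y = K̄ * u in O(L log L) via the FFT (zero-padded to avoid wrap-around)."""
    L = len(u)
    n = 2 * L
    y = np.fft.irfft(np.fft.rfft(K, n) * np.fft.rfft(u, n), n)
    return y[:L]
</syntaxhighlight>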
Numerical Stability Considerations:
* If <math>\|A\Delta\|</math> is large, naive exponentiation can introduce significant numerical errors.
* The trapezoidal (bilinear) rule provides a more stable alternative, since it maps eigenvalues in the left half-plane into the unit disk and thus keeps the discrete eigenvalues bounded.
* In practical DSS implementations, row-wise normalization is used to further stabilize large-sequence dynamics.
These methods ensure that state-space models remain numerically robust and efficient when handling long sequences.
===Structured State Space (S4)===
====Intuition Behind S4====
The main innovation of S4 is its ability to track long-range dependencies efficiently. Traditional models struggle with long sequences either because they require too much memory (like Transformers) or forget early information (like RNNs). S4 overcomes these limitations by using a mathematical reformulation of state space models, allowing it to process sequences much more efficiently. The authors argue that the different layers of the S4 model contribute differently to learning long-range dependencies: they hypothesize that the lower layers are predominantly responsible for learning local information (short-range dependencies), while higher layers aggregate global information, allowing the model to capture the bigger picture.
At its core, S4 is a hybrid between recurrent models and convolutions. Instead of updating each time step sequentially like an RNN, it uses a special structure called the Cauchy kernel, which allows it to compute all steps in parallel. This enables it to process sequences of tens of thousands of steps at a fraction of the computational cost of previous models. Thus, S4 is very advantageous for inference.
S4 has been particularly successful in domains that require continuous, structured data, such as time-series processing. However, its reliance on structured state matrices makes it less adaptable to unstructured data like natural language, where attention-based models still hold an advantage.
From Theorem 1, we can find that NPLR matrices can be conjugated into diagonal plus low-rank (DPLR) form.
The core idea of this form is to diagonalize the matrix <math>A</math> using a unitary transformation <math>V</math> and to use a low-rank decomposition to further compress the computation, which is important for efficient computation and storage optimization.
=====Theorem 2=====
Given any step size <math>\Delta</math>, computing one step of the [[#Recurrent Representation of SSM|recurrence]] can be done in <math>O(N)</math> operations where <math>N</math> is the state size.
* Interpretability
** The diagonal structure makes it clear how previous hidden states and inputs affect the current state, which helps in understanding the relationships between states. For example, it can be shown that under certain conditions, DSS can capture information from extremely distant positions in the sequence.
** To illustrate how DSS captures long-range dependencies, consider the diagonalization of the state space representation. Given a diagonalizable state matrix <math>A</math> with eigenvalues <math>\lambda_1, \dots, \lambda_N</math> and the diagonal matrix <math>\Lambda</math>, DSS reformulates the recurrence relation as <math> K = \bar{K}_{\Delta,L} \left( \Lambda, \left(e^{L \lambda_i \Delta} - 1 \right)^{-1} \right) </math>, where <math>K</math> encodes the sequence propagation in the transformed space. The diagonal form simplifies the recurrence relationship, making it clear how past states influence the current state. Under specific conditions, DSS captures long-range dependencies:
*** If <math>|\lambda_i| \Delta \approx 0</math>, the state evolution reduces to <math>x_{i,k} \approx x_{i,k-1}</math>, leading to persistent memory effects.
*** If <math>\text{Re}(\lambda_i) \Delta \ll 0</math>, information from past states is effectively forgotten, prioritizing local dependencies.
*** If <math>\text{Re}(\lambda_i) > 0</math>, DSS can propagate information from distant states due to the exponential weighting term <math>e^{\lambda \Delta}</math>. This property enables DSS to retain long-range dependencies more effectively than conventional models.
** A key advantage of DSS is that it provides interpretable mechanisms to regulate information retention and forgetting. The row-wise normalization in <math>\text{DSS}_{\text{softmax}}</math> further stabilizes the dynamics, mitigating numerical instabilities when handling large sequence lengths <math>L</math>. This simplicity enhances the transparency of the model, providing clear insights into system behavior and making it easier to interpret the underlying dynamics. By comparison, S4 lacks this degree of interpretability due to its more complex architecture.
'''Two Variants of DSS'''
The DSS model described above is one variant of DSS, called <math>\text{DSS}_{\text{softmax}}</math>. Softmax normalization ensures that the contribution of different eigenvalues remains bounded, avoiding numerical instability. Another variant is <math>\text{DSS}_{\exp}</math>. In this variant, the recurrence relation becomes <math> K = \bar{K}_{\Delta,L} \left( \Lambda, \left(1\right)\right) </math>. Instead of softmax normalization, this variant forces the real part of the eigenvalues to be negative, ensuring that the recurrence does not grow uncontrollably. Note that this restriction may fail in some tasks and is not as expressive as general state spaces.
'''Applications'''
These equations are discretized using an additional parameter <math>\Delta</math> as shown above, resulting in the updated formulae:
<math>
h_t = \mathbf{\bar A} h_{t-1} + \mathbf{\bar B} x_t
</math>
<math>
y_t = \mathbf{C} h_t
</math>
where the matrices <math>\mathbf{A}</math> and <math>\mathbf{B}</math> are discretized as follows:
<math display="block">
\begin{aligned}
& \overline{\mathbf{A}}=\operatorname{diag}\left(e^{-\Delta_t \omega}\right) \\
& \overline{\mathbf{B}}=\left(1-e^{-\Delta_t \omega}\right)
\end{aligned}
</math>
* <math>\mathbf{C}</math> interprets the hidden state in order to generate the output.
Mamba introduces an input-dependent gating mechanism <math>\Delta_t</math>, defined as:
<math>
\Delta_t=\tau_{\Delta}\left(W_{\Delta} x_t+b_{\Delta}\right)
</math>
where <math>\tau_{\Delta}</math> is typically a softplus or sigmoid activation function, and <math>W_{\Delta}, b_{\Delta}</math> are learned parameters. The discrete state update rule incorporating selectivity is:
<math>
h_t=\left(\mathbf{I}-\Delta_t\right) h_{t-1}+\Delta_t x_t
</math>
This gating function dynamically adjusts the timescale of information propagation, allowing the model to selectively retain or discard information based on the context provided by input data.
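The sketch below is a toy, per-channel reading of the update above (the names, shapes, and the choice of a sigmoid for <math>\tau_{\Delta}</math> are illustrative assumptions, and it omits the full Mamba parameterization in which <math>\Delta_t</math> also modulates <math>\mathbf{\bar A}</math> and <math>\mathbf{\bar B}</math>); it only illustrates how an input-dependent gate interpolates between retaining the previous state and overwriting it.
<syntaxhighlight lang="python">
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def selective_update(x, W_delta, b_delta):
    """Run h_t = (1 - Δ_t) h_{t-1} + Δ_t x_t with Δ_t = sigmoid(W_Δ x_t + b_Δ)."""
    T, d = x.shape
    h = np.zeros(d)
    out = np.empty((T, d))
    for t in range(T):
        delta_t = sigmoid(x[t] @ W_delta + b_delta)   # per-channel gate in (0, 1)
        h = (1.0 - delta_t) * h + delta_t * x[t]      # selective retention vs. overwrite
        out[t] = h
    return out

# Usage with illustrative shapes: sequence length 16, model width 8.
rng = np.random.default_rng(0)
x = rng.standard_normal((16, 8))
hs = selective_update(x, W_delta=0.1 * rng.standard_normal((8, 8)), b_delta=np.zeros(8))
</syntaxhighlight>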
==== Convolutional Kernel Representation ====
The Mamba model's dynamics can also be expressed through a convolutional kernel representation for computational efficiency, especially during training. Given an input sequence <math>x_{1: T}</math>, the output <math>y_t</math> can equivalently be represented as:
<math display="block">
y_t=\sum_{k=1}^t \mathbf{C} \overline{\mathbf{A}}^{t-k} \overline{\mathbf{B}} x_k
</math>
This representation highlights that Mamba's hidden states effectively implement a convolution over past inputs with a kernel shaped by the learned gating parameters and state transition matrices. This approach provides efficient parallel training similar to convolutional architectures while maintaining the ability to handle extremely long sequences with linear complexity during inference.
'''Advantages:'''
* '''Dynamic Timescales:''' Unlike traditional SSMs with fixed transition matrices, Mamba adjusts transition dynamics based on inputs.
* '''Improved Long-range Dependencies:''' Dynamic gating allows Mamba to selectively propagate crucial long-range information while ignoring irrelevant short-term fluctuations.
* '''Computational Efficiency:''' The convolutional representation significantly reduces computational overhead during parallel training phases.
====Selective State Space Models in Mamba====
In summary, Mamba's selective state space models marry the efficiency of classical SSMs with the flexibility of input-dependent gating. This combination allows the architecture to handle extremely long sequences—up to millions of tokens—while still capturing the nuanced patterns essential for tasks ranging from language modeling to audio generation.
===Mamba-2===
[[File:Mamba2-SSD.png|400px|General SSMs and SMA both possess linear and quadratic forms, with direct analogs in notation.]]
Other ways to view this:
* An SSM can be regarded as a controlled dynamical system, where at each time step <math>t</math> the state <math>x_t</math> is influenced by <math>A</math>, while <math>B</math> is responsible for introducing inputs and <math>C</math> is responsible for computing outputs.
* SMA uses a constrained attention pattern to propagate information; because of the structure of the mask, information can only flow within a constrained range.
* SSD reveals that the two are the same object in different representations: the SSM passes information recurrently, whereas SMA propagates it through a structured mask, but the matrix <math>M</math> they compute is structurally equivalent.
====Mamba-2 architecture====
6. Zigzag Mamba (ZigMa): A Mamba-powered diffusion model for image generation. Replaces attention in Diffusion Transformers (DiT), improving speed and memory efficiency for high-resolution image synthesis.
==== State Space Duality: State Space Models and Transformers ====
State Space Models (SSMs) offer a scalable alternative to Transformers in language modeling by using state space equations to manage and update hidden states when generating outputs from inputs. Since their inception, there have been several important model enhancements.
'''State Space Duality (SSD)''' formalizes the duality between State Space Models (SSMs) and Transformers, showing that both can be expressed in terms of structured semiseparable matrices. This duality opens the door to computationally efficient hybrid models such as '''Mamba-2''', which combine the expressiveness of Transformers with the computational efficiency of SSMs.
A discrete-time State Space Model (SSM) can be written as:
<math> h_{t+1} = A h_t + B x_t </math>
<math> y_t = C h_t + D x_t </math>
where <math> h_t </math> is the hidden state at time <math> t </math>, <math> A, B, C, D </math> are learnable state-space matrices, <math> x_t </math> is the input, and <math> y_t </math> is the output. Unrolling this recurrence over a length-<math>L</math> sequence yields the kernel
<math> K = (CB, CAB, CA^2B, ..., CA^{L-1}B) </math>
where <math> K </math> encodes the input dependencies.
SSD establishes that Transformers and SSMs share the same structure with respect to semiseparable matrices. A semiseparable matrix can be written as
<math> M = D + U^T L </math>
where <math> D </math> is a diagonal matrix for local dependencies, <math> L </math> is a lower triangular matrix for sequential state transitions, and <math> U </math> is an upper triangular matrix for long-range interactions.
The self-attention mechanism of Transformers computes
<math> M_{att} = Q K^T </math>
where <math> Q, K </math> are the query and key matrices that yield explicit pairwise token interactions. In contrast, the SSM's transition structure yields:
<math> M_{ssm} = C (I - A)^{-1} B </math>
where <math> A, B, C </math> define an implicit recurrence.
SSD offers an optimized block-decomposition method for computational efficiency. It splits the sequence into small blocks, within which intra-chunk dependencies are handled by structured SSMs while inter-chunk information is propagated by structured masked attention (SMA). This structured mechanism greatly reduces the computational cost: training complexity is <math> O(NT) </math> for SSMs compared to <math> O(N^2T) </math> for Transformers, and inference complexity is likewise reduced from <math> O(N^2T) </math> to <math> O(NT) </math>.
A semiseparable matrix <math> M </math> of order <math> N </math> is defined as
<math> M_{ji} = C_j^T A_{j:i} B_i </math>
where <math> A </math> is a structured transition matrix that facilitates effective sequence transformations. In the case of 1-semiseparable structured attention, the recurrence is
<math> M_{ji} = a_{j:i} (C_j^T B_i) </math>
where <math> a_{j:i} </math> is a generalized positional dependency factor in structured decay.
Structured Masked Attention (SMA) generalizes linear attention with the inclusion of a structured masking matrix <math> L </math> as follows:
<math> L_{ij} = a_{i:j} </math>
where <math> a_{i:j} </math> are the learned decay factors that control the way information flows through the sequence. This lets SSD attain attention-like flexibility with SSM-like efficiency.
One of the primary conclusions of SSD is that both quadratic (attention-like) and linear (SSM-like) computations are available for semiseparable matrices:
<math> y = M x </math>
can either be computed by a naive quadratic method, where <math> M </math> is used as an explicit attention-like matrix, or by an SSM recurrence method that computes <math> y </math> in <math> O(NT) </math> time. The dual formulation enables hardware-efficient implementations where structured decompositions are used to enhance computational speed while maintaining long-range dependencies.
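To make the duality concrete, here is a small self-contained check (all names, sizes, and the scalar decay factor are illustrative assumptions) that the quadratic "masked attention" form and the linear recurrence form of a scalar-decay (1-semiseparable) SSM produce identical outputs.
<syntaxhighlight lang="python">
import numpy as np

def duality_check(T=6, d=4, a=0.9, seed=0):
    """Compare y = M x (quadratic form) with the recurrence h_j = a h_{j-1} + B x_j, y_j = C h_j."""
    rng = np.random.default_rng(seed)
    B = rng.standard_normal((d, 1))
    C = rng.standard_normal((1, d))
    x = rng.standard_normal(T)

    # Quadratic form: materialize M_{ji} = a^{j-i} (C B) for j >= i, then apply it to x.
    M = np.zeros((T, T))
    for j in range(T):
        for i in range(j + 1):
            M[j, i] = (a ** (j - i)) * (C @ B).item()
    y_quadratic = M @ x

    # Linear form: never build M; run the recurrence in O(T).
    h = np.zeros((d, 1))
    y_linear = np.zeros(T)
    for j in range(T):
        h = a * h + B * x[j]
        y_linear[j] = (C @ h).item()

    assert np.allclose(y_quadratic, y_linear)
    return y_quadratic

duality_check()
</syntaxhighlight>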
Mamba-2 exploits SSD through parallel parameter projections that compute the SSM parameters dynamically. A gating mechanism selectively filters information, and SSD-based computation of the state transitions enables efficient state updates. Expressivity is boosted through structured masked attention (SMA), bridging the gap between attention and state-space models:
<math> y = \sum_{j=0}^{L-1} W_j (A^j B x_{t-j}) </math>
where <math> W_j </math> are learnable weights derived from structured attention masks. SSD permits hybrid architectures that retain the efficiency of SSMs while approaching Transformer expressiveness. Future directions include GPU/TPU acceleration, memory-efficient representations, and retrieval-augmented generation (RAG) for long-context applications. SSD has far-reaching implications for future large-scale models that aim to remain computationally tractable with extremely long (e.g., 100M-token) context windows.
===Comparative Analysis of SSM Variants===
To provide a clearer understanding of the strengths and weaknesses of State Space Model (SSM) variants—S4, DSS, H3, Mamba, and Mamba-2—this section presents a comparative analysis based on their performance in language modeling tasks, computational efficiency, and scalability. This comparison draws from their architectural designs and empirical results, offering insights into their suitability for various applications.
====Performance Metrics====
* Perplexity on Language Modeling Tasks
** '''S4''': Achieves competitive perplexity on datasets like WikiText-103 but struggles with extremely long sequences due to its structured parameterization.
** '''DSS''': Matches S4’s perplexity with a simpler diagonal structure, performing well on large-scale datasets with reduced overhead.
** '''H3''': Outperforms S4 and DSS on synthetic tasks like Induction Heads and Associative Recall, closing the expressivity gap with Transformers.
** '''Mamba''': Excels with very low perplexity on long-context tasks, leveraging selective state spaces to rival Transformer performance.
** '''Mamba-2''': Achieves the lowest perplexity among SSMs by enhancing state dimensions and optimization, surpassing Mamba on benchmarks like OpenWebText.
* Computational Efficiency
** '''S4''': Uses the Fast Fourier Transform (FFT) for <math>O(N \log N)</math> training complexity, though its implementation is intricate.
** '''DSS''': Simplifies to <math>O(N)</math> training and <math>O(1)</math> inference with a diagonal state matrix, enhancing efficiency.
** '''H3''': Combines two SSMs, resulting in <math>O(d^2 N + d N \log N)</math> complexity—more expressive but less efficient than others.
** '''Mamba''': Achieves <math>O(N)</math> training and <math>O(1)</math> inference with hardware-aware selective mechanisms, optimizing for long sequences.
** '''Mamba-2''': Maintains Mamba’s complexity while improving performance through state space duality and parallelism.
* Scalability
** '''S4''': Scales effectively for moderate sequence lengths but faces challenges with very long contexts.
** '''DSS''': Excels in scalability for large datasets and real-time systems due to its simplicity.
** '''H3''': Limited by higher computational costs, making it less scalable for large models.
** '''Mamba''': Designed for long sequences, scales efficiently to millions of tokens, ideal for extensive context tasks.
** '''Mamba-2''': Enhances scalability with tensor and sequence parallelism, enabling faster training and inference.
====Where to Use Each Model====
* '''S4''': Best for continuous data tasks (e.g., time series, audio) where moderate sequence lengths are sufficient.
* '''DSS''': Ideal for large-scale, real-time applications needing efficiency and interpretability (e.g., time-series forecasting).
* '''H3''': Suited for language modeling tasks requiring Transformer-like expressivity with moderate sequence lengths.
* '''Mamba''': Optimal for long-sequence tasks (e.g., document modeling, audio generation) needing high efficiency.
* '''Mamba-2''': The go-to choice for large-scale, long-context modeling with enhanced training speed and performance.
=== Empirical Scaling and Performance Trends ===
Recent experiments have demonstrated that state space models scale remarkably well as their capacity increases. For instance, when scaled to billions of parameters, models such as Mamba and Mamba-2 achieve language modeling perplexities that rival or even surpass those of similarly sized Transformers. Empirical studies on benchmarks like OpenWebText and the Long Range Arena indicate that as the state dimension grows and longer contexts are leveraged, SSM-based architectures not only maintain their linear computational scaling but also benefit from improved performance. These scaling trends suggest that, with proper architectural tuning and parameter initialization, SSMs are emerging as a viable alternative for large-scale, long-context sequence modeling.
=== Summary & Key Takeaways ===
State Space Models (SSMs) offer a scalable alternative to Transformers in language modeling by using state space equations to manage and update hidden states when generating outputs from inputs. Since their inception, there have been several important model enhancements. The following table summarizes their development, strengths, weaknesses, and time complexity:
{| class="wikitable" | {| class="wikitable" | ||
|+ SSM | |+ SSM Models Summary | ||
|- | |- | ||
! Year !! Model !! | ! Year !! Model !! Contribution !! Strengths !! Weaknesses !! Complexity | ||
|- | |- | ||
| 2022|| Structured State Space (S4)|| Leveraged Diagonal Plus Low-Rank parameterization to efficiently | | 2022|| Structured State Space (S4)|| Leveraged Diagonal Plus Low-Rank parameterization to improve computational efficiency || Captures long-range dependencies efficiently || Complex architecture and implementation || <math>O(N log N)</math> with <math>FFT</math> | ||
|- | |- | ||
| 2022|| Diagonal State Spaces (DSS)|| Simplified S4 by using diagonal matrices to achieve comparable performance | | 2022|| Diagonal State Spaces (DSS)|| Simplified S4 by using diagonal matrices to achieve comparable performance || Big data and real-time systems scalability || Performance depends on initialization and lacks capacity to handle information-dense data || For batch size <math>B</math>, sequence length <math>L</math>, and hidden size <math>H</math>: DSS layer requires <math>O(NHL)</math> time to compute kernels, <math>O(BHL logL)</math> time for discrete convolution and <math>O(BH^2L)</math> time for output projection. | ||
|- | |- | ||
| 2023|| Hungry Hungry Hippos (H3)|| Introduced new SSM layer consisting of two SSMs stacked together (a shift and diagonal SSM) to enable token recall and comparison across sequences | | 2023|| Hungry Hungry Hippos (H3)|| Introduced new SSM layer consisting of two SSMs stacked together (a shift and diagonal SSM) to enable token recall and comparison across sequences || Transformer competitive performance on language modeling tasks || Less computationally efficient than other SSMs|| For sequence length <math>N</math>, and <math>d</math> hidden dimension, then <math>H3</math> layer takes <math>O(d^2N + dN logN)</math> | ||
|- | |- | ||
| 2024|| Mamba|| Integrated selective state spaces to | | 2024|| Mamba|| Integrated selective state spaces to filter irrelevant information while being hardware-aware algorithm || Handles long sequences at faster inference and less memory requirements than Transformer || Scaling to compete with LLM requires complex engineering || Training is <math>O(N)</math>, while inference is <math>O(1)</math> | ||
|- | |- | ||
| 2024|| Mamba-2|| Used state space duality (SSD) to enable larger state dimensions and faster training | | 2024|| Mamba-2|| Used state space duality (SSD) to enable larger state dimensions || Tensor and sequence parallelism which allows for faster training and inference || Lack of inference optimization techniques like quantization and speculative decoding || Training is <math>O(N)</math>, while inference is <math>O(1)</math> | ||
|} | |} | ||
=== Future of SSMs ===
Ongoing research is focused on improving efficiency, adaptability, and scalability to position State Space Models (SSMs) as a viable alternative to Transformers. While SSMs excel at modeling long-range dependencies, they face hardware inefficiencies and an expressivity gap in tasks like language modeling. Key areas of research include:
1. Better hardware utilization: Optimizing SSM implementations to leverage GPU/TPU acceleration as efficiently as Transformers.
* Enhancing GPU/TPU acceleration through optimized memory access and tensor core utilization.
* Applying kernel fusion techniques to minimize redundant computations.
2. Adaptive SSMs: SSMs currently lack the expressivity of Transformers, particularly for tasks requiring complex reasoning. Developing architectures that can dynamically switch between SSM and attention-based processing depending on the task.
* Structured Masked Attention (SMA) for improved context retention and expressivity.
* Selective Copying & Induction Heads to enhance memory retention.
* State-passing algorithms to better preserve sequence dependencies.
3. Scaling laws for SSMs: SSMs scale efficiently for long-sequence tasks but require further study to match Transformers at large parameter counts. Understanding how these models perform at increasing levels of parameterization compared to standard deep learning architectures.
* Understanding scaling laws to optimize model depth and expressivity.
* Evaluating trade-offs between efficiency and generalization across tasks.
* Expanding SSM applications, particularly in speech, DNA, and document retrieval.
4. Hybrid Architectures with Dynamic Routing:
* Exploring deeper integrations between SSMs and Transformers, potentially using dynamic gating mechanisms or context-based routing, allowing models to dynamically switch processing modes depending on input characteristics or computational constraints.
5. Automated Hyperparameter and Architecture Search:
* Given the sensitivity of SSM performance to hyperparameter choices, research into automated hyperparameter optimization (such as Bayesian optimization or neural architecture search) could greatly reduce manual tuning, making SSMs more broadly applicable in practice.
6. Improved Visualization and Interpretability Techniques:
* Although SSMs are relatively interpretable, visualization and diagnostic tools tailored specifically to SSM hidden states and transitions remain limited. Developing intuitive visualization tools would significantly improve the practical interpretability of SSM models, especially in high-dimensional or complex data scenarios.
By addressing these key areas, SSMs hold substantial promise for becoming the foundation of next-generation sequence modeling, especially where computational efficiency, adaptability, and interpretability are crucial, such as speech recognition, biological sequence modeling, and real-time inference systems.
= Topic 8: Sparse Attention =
===Sparse Sinkhorn Attention===
==== Simplified Intuition ====
Imagine a crowded room where everyone is trying to talk at once—confusing and inefficient, right? That’s what happens in traditional attention: every token interacts with every other token, like a chaotic party where no one can really focus on the conversation. Sparse Sinkhorn Attention takes a different approach. It first acts like a clever host who quickly sorts the guests into smaller groups based on who’s most likely to have something in common. This sorting, done by a neat algorithm (the Sinkhorn normalization), gently reorders the tokens so that similar ideas end up together, allowing meaningful conversations in manageable clusters.
Once the guests are grouped, each token only needs to chat with its immediate neighbors, much like small, focused discussion groups. Even though tokens now converse locally, the smart grouping ensures that distant yet related ideas are brought close together. This not only cuts down on the chaos and computational cost but also preserves the essential flow of information—much like having productive, focused conversations instead of a noisy, all-out babble.
==== Overview ====
To address the computational intensity of vanilla attention, especially for long input sequences, Sparse Sinkhorn Attention proposes three core ideas:
== Key Approaches to Linear Attention ==
===Retentive Network (RetNet): A Successor to Transformer for Large Language Models===
[[File:retnet_impossible_triangle.png|thumb|200px|Figure 1: Impossible Triangle]]
The authors of this paper introduce a new architecture that achieves parallel training and fast inference without compromising performance, addressing the inherent complexity of the vanilla attention mechanism. They claim that the new architecture makes the impossible triangle (Figure 1) possible. RetNet combines the performance and parallel-training advantages of Transformers with the low-cost inference of RNNs by introducing a multi-scale retention mechanism that replaces multi-head attention and supports three main representations:
# Parallel: allows GPU utilization during training.
# Recurrent: enables low-cost inference with <math>O(1)</math> complexity in terms of computation and memory.
# Chunkwise recurrent: models long sequences efficiently.
Essentially, RetNet replaces the <math>\mathrm{softmax}</math> by adopting a linear representation of the attention mechanism in the form of retention.
====Recurrent Representation====
We can show that the retention mechanism has a dual form of recurrence and parallelism (Figure 2). First, let's look at RetNet in its recurrent form:
[[File:RetNet_dual.png|thumb|500px|Figure 2: Dual form of RetNet]]
<math>S_n = \gamma S_{n-1} + K_n^T V_n</math>
<math>\text{Retention}(X_n) = Q_n S_n</math>
In this formula, <math>\gamma</math> is a decay factor, and <math>Q</math>, <math>K</math>, <math>V</math> are learned projections of the input. During training, this can be computed in parallel form.
====Parallel Representation====
Now, we can show the parallel representation of retention:
<math>\text{Retention}(X) = (QK^T \odot D)V</math>
where <math>D</math> is a matrix that combines causal masking and exponential decay along relative distance. To better understand the role of <math>D</math>, consider its formula:
<math display="block">
D_{nm} =
\begin{cases}
\gamma^{n-m}, & n \geq m \\
0, & n < m
\end{cases}
</math>
This formulation ensures that tokens can only attend to previous tokens (which the authors call causal masking) and that the attention strength decays exponentially with distance (controlled by the parameter <math>\gamma</math>). In other words, as we move forward, less attention is paid to earlier tokens.
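The equivalence of the two forms is easy to verify numerically. Below is a minimal sketch (shapes, names, and the single-head setting are illustrative assumptions) that computes retention both ways and checks that the outputs match.
<syntaxhighlight lang="python">
import numpy as np

def retention_parallel(Q, K, V, gamma):
    """Parallel form: Retention(X) = (Q K^T ⊙ D) V with decay/causal mask D."""
    T = Q.shape[0]
    n = np.arange(T)
    D = np.where(n[:, None] >= n[None, :], gamma ** (n[:, None] - n[None, :]), 0.0)
    return (Q @ K.T * D) @ V

def retention_recurrent(Q, K, V, gamma):
    """Recurrent form: S_n = γ S_{n-1} + K_n^T V_n, Retention(X_n) = Q_n S_n."""
    T, d = Q.shape
    S = np.zeros((d, V.shape[1]))
    out = np.zeros((T, V.shape[1]))
    for t in range(T):
        S = gamma * S + np.outer(K[t], V[t])
        out[t] = Q[t] @ S
    return out

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((8, 4)) for _ in range(3))
assert np.allclose(retention_parallel(Q, K, V, 0.9), retention_recurrent(Q, K, V, 0.9))
</syntaxhighlight>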
====Chunkwise Recurrent Representation====
The third component, the chunkwise recurrent representation, is a hybrid of the recurrent and parallel representations aimed at accelerating training on long sequences. Simply put, the input sequence is divided into chunks; within each chunk, the parallel representation is used to process all tokens simultaneously. To pass information across chunks, the recurrent representation is used to create a summary of the chunk and pass it to the next one. The combination of inner-chunk (parallel) and cross-chunk (recurrent) information produces the retention of the chunk, which captures the details within the chunk and the context from previous chunks. To compute the retention of the <math>i</math>-th chunk, let <math>B</math> be the chunk length; then:
<math>Q_{[i]} = Q_{Bi:B(i+1)}, K_{[i]} = K_{Bi:B(i+1)}, V_{[i]} = V_{Bi:B(i+1)}</math>
<math>R_i = K_{[i]}^T (V_{[i]} \odot \zeta) + \gamma^B R_{i-1}, \quad \zeta_{ij} = \gamma^{B-i-1}</math>
<math>
\text{Retention}(X_{[i]}) =
\underbrace{(Q_{[i]} K_{[i]}^T \odot D) V_{[i]}}_{\text{Inner-Chunk}} +
\underbrace{(Q_{[i]} R_{i-1}) \odot \xi}_{\text{Cross-Chunk}}, \quad \xi_{i,j} = \gamma^{i+1}
</math>
where <math>R_i</math> is the state of the current chunk to be passed to the next.
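A compact sketch of this chunkwise computation is given below (single head, illustrative shapes); the intra-chunk term reuses the decay mask <math>D</math> from the parallel form, while the <math>\zeta</math> and <math>\xi</math> factors carry the decay across chunk boundaries. Its output matches the fully parallel retention from the previous sketch.
<syntaxhighlight lang="python">
import numpy as np

def chunkwise_retention(Q, K, V, gamma, B):
    """Chunkwise retention: inner-chunk parallel term plus a cross-chunk carried state R."""
    T, d = Q.shape
    i = np.arange(B)
    D = np.where(i[:, None] >= i[None, :], gamma ** (i[:, None] - i[None, :]), 0.0)
    zeta = gamma ** (B - i - 1)     # ζ_j = γ^{B-j-1}: decay of each key/value up to the chunk boundary
    xi = gamma ** (i + 1)           # ξ_i = γ^{i+1}: decay from the chunk boundary to query position i
    R = np.zeros((d, V.shape[1]))
    out = np.zeros((T, V.shape[1]))
    for n in range(T // B):
        sl = slice(n * B, (n + 1) * B)
        Qn, Kn, Vn = Q[sl], K[sl], V[sl]
        inner = (Qn @ Kn.T * D) @ Vn                          # parallel retention within the chunk
        cross = (Qn @ R) * xi[:, None]                        # contribution of all previous chunks
        out[sl] = inner + cross
        R = Kn.T @ (Vn * zeta[:, None]) + (gamma ** B) * R    # summarize this chunk for the next one
    return out
</syntaxhighlight>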
In summary, the overall architecture of RetNet consists of <math>L</math> stacked identical blocks, similar to the Transformer. Each block consists of two modules: a multi-scale retention (MSR) module and a feed-forward network (FFN) module. It takes the embeddings of the input sequence <math>X^0= [x_1, \dots, x_{|x|}] \in \mathbb R^{|x| \times d_{model}}</math> and computes the output <math>X^L</math>:
<math>Y^l = MSR(LN(X^l)) + X^l</math>
<math>X^{l+1} = FFN(LN(Y^l)) + Y^l</math>
where <math>LN(\cdot)</math> is LayerNorm.
[[File:retnet_comparison.png|thumb|800px|Table 1: Model Comparison]]
====Comparison, Limitations & Future Work====
Finally, comparing RetNet to other solutions, we can see that it provides a solid alternative to the Transformer, achieving parallel training, fast inference, linear long-sequence memory complexity, and good performance.
The limitations of this architecture can be summarized in the following points:
* Limited Cross-Modal Exploration
** The method is primarily validated on language modeling tasks. Its applicability to other modalities (e.g., vision, audio) remains unexplored, requiring further research for broader adoption.
=== Memory-Recall Tradeoff ===
A fundamental tradeoff exists between the size of a model's recurrent state and its ability to recall past tokens accurately. Some architectures, like BASED, combine linear attention with a sliding window of exact softmax attention to navigate this tradeoff. By adjusting hyperparameters such as the window size and feature dimensions, these models can traverse the Pareto frontier—achieving high recall with a small memory footprint while still benefiting from high throughput.
This paper focuses on the task of associative recall and discusses the memory-recall tradeoff, which states that higher recall requires higher memory consumption, as shown in Figure 3. This tradeoff manifests in different ways across model architectures:
* Vanilla attention models: achieve excellent recall but at the cost of quadratic computational complexity
* Linear attention models: offer better efficiency but often struggle with accurate information retrieval
* State-space models like Mamba: show that increasing recurrent state size generally improves recall accuracy but introduces computational overhead
[[File:memory-recall-tradeoff.png|thumb|300px|Figure 3: Empirical performance showing the memory-recall tradeoff]]
=== Implementations ===
The authors introduced the Based architecture to improve this tradeoff. They do so by stacking linear attention blocks with sliding-window attention blocks. Recall that one of the limitations of linear attention is its relatively poor performance compared to vanilla attention. By selecting this type of architecture, the authors compromise on global precision by letting linear attention capture global dependencies, while retaining high performance within a local window by using sliding-window attention to capture local dependencies. As seen in Figure 4, they achieve this goal.
==== Chunking Strategy ====
Chunking involves splitting the input sequence <math>X</math> into chunks of length <math>B</math>. For each chunk <math>i</math>, let
<math>Q_{[i]} = Q_{B i : B(i+1)}, \quad K_{[i]} = K_{B i : B(i+1)}, \quad V_{[i]} = V_{B i : B(i+1)}</math>
be the query, key, and value vectors (respectively) for that chunk. We maintain a recurrent “global” state <math>R_i</math> that summarizes information from all previous chunks and combine it with a “local” sliding-window attention over the current chunk.
==== Global Linear Attention (Recurrent “State”) ====
We use a feature map <math>\Phi(\cdot)</math> (e.g., a kernel or ELU-based mapping) to make attention linear in the sequence length. We update a global state <math>R_i</math> for each chunk:
<math>R_i = R_{i-1} + \Phi(K_{[i]})^T V_{[i]}</math>
Intuitively, <math>R_{i-1}</math> aggregates (in linear time) all key-value information from past chunks.
==== Taylor Linear Attention ====
The Based architecture specifically employs a second-order Taylor series expansion as the feature map for linear attention. This approximation offers a good balance between computational efficiency and performance:
<math>\Phi(x) = \begin{bmatrix} 1 \\ x \\ \frac{x^2}{2} \end{bmatrix}</math>
This feature map provides a reasonable approximation of the softmax function while maintaining the linear complexity benefits.
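The snippet below is a toy reading of this feature map and the resulting causal linear attention (it reads <math>x^2</math> elementwise, exactly as written above; the full Based feature map expands all pairwise second-order terms). Function names and the normalization constant are illustrative assumptions.
<syntaxhighlight lang="python">
import numpy as np

def taylor_feature_map(x):
    """Φ(x) = [1, x, x²/2] applied along the feature dimension (elementwise second-order term)."""
    ones = np.ones(x.shape[:-1] + (1,))
    return np.concatenate([ones, x, 0.5 * x**2], axis=-1)

def causal_linear_attention(Q, K, V, eps=1e-6):
    """out_t = Φ(q_t)·(Σ_{s≤t} Φ(k_s) v_s^T) / (Φ(q_t)·Σ_{s≤t} Φ(k_s))."""
    phi_q, phi_k = taylor_feature_map(Q), taylor_feature_map(K)
    T, d_v = V.shape
    S = np.zeros((phi_k.shape[-1], d_v))   # running Σ Φ(k_s) v_s^T
    z = np.zeros(phi_k.shape[-1])          # running Σ Φ(k_s), used for normalization
    out = np.zeros((T, d_v))
    for t in range(T):
        S += np.outer(phi_k[t], V[t])
        z += phi_k[t]
        out[t] = (phi_q[t] @ S) / (phi_q[t] @ z + eps)
    return out
</syntaxhighlight>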
==== Local Sliding-Window Attention (Exact) ====
Within the current chunk, we also apply standard (exact) attention to a small window of recent tokens. Denote the set of token indices in this window by <math>\mathcal{W}_i</math>. Then:
This term captures fine-grained, short-range dependencies at full (exact) attention fidelity, but only over a small neighborhood <math>\mathcal{W}_i</math>.
==== Final Output per Chunk ====
The total representation for chunk <math>i</math> is a sum of global (linear) and local (windowed) attention:
By combining a linear-time global update (<math>R_i</math>) with high-resolution local attention, the model balances throughput (via the efficient global linear component) and recall (via exact local attention). This addresses the memory-recall tradeoff: large contexts are captured without quadratically scaling memory usage, while local windows preserve high accuracy on short-range dependencies.
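Since the exact per-chunk formulas are not reproduced above, the following is only a plausible sketch, under the assumption that the global term reads the carried state through the feature map and the local term is ordinary softmax attention over the last few tokens; it shows how the two contributions are summed per position.
<syntaxhighlight lang="python">
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def hybrid_chunk_output(Q, K, V, phi, R_prev, window):
    """Global linear-attention read of R_prev plus exact attention over a sliding window."""
    T, d = Q.shape
    out_global = phi(Q) @ R_prev                 # linear-time use of the summary of past chunks
    out_local = np.zeros_like(V)
    for t in range(T):
        lo = max(0, t - window + 1)
        w = softmax(Q[t] @ K[lo:t + 1].T / np.sqrt(d))   # exact attention inside the window
        out_local[t] = w @ V[lo:t + 1]
    return out_global + out_local
</syntaxhighlight>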
=== Hardware Optimizations ===
Recent work has shown that linear attention can be made even more practical when its implementation is tailored to the underlying hardware. For example, methods like '''FLASHLINEARATTENTION''' integrate I/O-aware techniques to minimize data movement between high-bandwidth memory and faster on-chip memories, resulting in real-world speedups that can even outperform optimized softmax attention implementations on moderate sequence lengths.
The Based Architecture includes several hardware-aware optimizations:
# Memory-efficient linear attention: The implementation fuses the feature map and causal dot product computation in fast memory, reducing high-latency memory operations.
# Optimized sliding window: The window size is carefully selected to align with hardware constraints (typically 64×64 tiles), balancing computation and memory bandwidth.
# Register-level computation: Critical calculations are performed in registers whenever possible, minimizing data movement between different memory hierarchies.
==== FLASHLINEARATTENTION: Hardware-Efficient Linear Attention for Fast Training and Inference ====
FLASHLINEARATTENTION is an I/O-aware, hardware-efficient linear attention mechanism built around efficient data movement between shared memory (SRAM) and high-bandwidth memory (HBM). The goals are to alleviate memory bottlenecks, maximize GPU parallelism, and accelerate training and inference; the method significantly outperforms even FLASHATTENTION-2 at moderate sequence lengths (~1K tokens).

Softmax-based self-attention, the standard attention mechanism, is quadratic in both computation and memory, making it inefficient for long sequences and hard to scale. Linear attention attempts to reduce this complexity, but most implementations are not optimized for modern GPUs and do not deliver real-world speedups. FLASHLINEARATTENTION addresses this by splitting the input sequence into more manageable chunks, so that computation can be performed independently on each chunk before the global state is updated. This avoids redundant memory accesses, reduces the latency of GPU operations, and keeps tensor cores effectively utilized.
Mathematically, FLASHLINEARATTENTION builds upon linear attention, which rewrites standard softmax attention
<math>\text{Attention}(Q, K, V) = \text{softmax} \left(\frac{QK^T}{\sqrt{d}} \right) V</math>
so that, instead of explicitly computing the full <math>n \times n</math> attention matrix, it approximates it using a kernel function <math>\phi(x)</math> such that:
<math>\text{Attention}(Q, K, V) \approx \frac{\phi(Q) \left(\phi(K)^T V\right)}{\phi(Q) \left(\phi(K)^T \mathbf{1}\right)}</math>
where <math>\phi(x)</math> is a feature map transformation ensuring that the inner product approximates the softmax function. In standard parallel linear attention, the output is computed as
<math>O = (Q K^T) V</math>
which still has quadratic complexity <math>O(n^2 d)</math>. FLASHLINEARATTENTION, on the other hand, splits the input sequence into chunks and processes them separately while carrying a hidden state <math>S</math>. The hidden state and output are updated in a recurrent form:
<math>S[i+1] = S[i] + \sum_{j=1}^{C} K_j^T V_j</math>
<math>O[i+1] = Q[i+1] S[i] + (Q[i+1] K[i+1]^T \odot M) V[i+1]</math>
Here, <math>M</math> is a causal mask, ensuring attention is computed only over tokens in the past. Chunkwise computation, which splits the sequence into smaller chunks to minimize memory overhead, is the most crucial optimization in FLASHLINEARATTENTION: chunks can be processed largely in parallel. Minimizing HBM I/O cost is another key optimization; unnecessary data transfer between HBM and SRAM is avoided by reusing tensors already loaded on chip. When <math>Q[n]</math> is loaded into SRAM, both <math>Q[n]S</math> and <math>(Q[n]K[n]^T \odot M)V[n]</math> are computed without reloading <math>Q[n]</math>.

FLASHLINEARATTENTION has two implementations. In the non-materialization version, hidden states <math>S[i]</math> are kept in SRAM to enable memory-efficient computation. In the materialization version, all <math>S[i]</math> are stored in HBM to enable full sequence-level parallelism; this version is slightly slower per step but boosts training throughput by 10-20%. The computation is parallel within chunks but sequential between chunks, which keeps memory usage under control and makes the algorithm suitable for processing long sequences.
The FLASHLINEARATTENTION forward pass proceeds as follows. For input matrices <math>Q</math>, <math>K</math>, and <math>V</math> of size <math>L \times d</math> and chunk size <math>C</math>, the sequence is divided into <math>N = L / C</math> blocks:
<math>Q = \{ Q[1], Q[2], ..., Q[N] \}, \quad K = \{ K[1], K[2], ..., K[N] \}, \quad V = \{ V[1], V[2], ..., V[N] \}</math>
The hidden state is initialized as <math>S = 0</math> in SRAM. For each chunk, <math>S</math> is stored in HBM if materialization is enabled, and the corresponding <math>K[n], V[n]</math> values are loaded into SRAM. The output for each chunk is computed in parallel as
<math>O'[n] = Q[n] S + (Q[n] K[n]^T \odot M) V[n]</math>
and the hidden state is then updated with the chunk's keys and values:
<math>S = S + K[n]^T V[n]</math>
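The following NumPy sketch mirrors this chunkwise computation at the level of the math only; it deliberately ignores the SRAM/HBM placement, materialization choices, and tensor-core considerations that make the real kernel fast, and the chunk size and shapes are illustrative.
<syntaxhighlight lang="python">
import numpy as np

def chunkwise_linear_attention(Q, K, V, C):
    """Chunkwise forward pass: O'[n] = Q[n] S + (Q[n] K[n]^T ⊙ M) V[n], then S ← S + K[n]^T V[n]."""
    L, d = Q.shape
    assert L % C == 0, "sequence length must be a multiple of the chunk size"
    M = np.tril(np.ones((C, C)))        # causal mask within a chunk
    S = np.zeros((d, V.shape[1]))       # inter-chunk hidden state
    O = np.zeros((L, V.shape[1]))
    for n in range(L // C):
        sl = slice(n * C, (n + 1) * C)
        Qn, Kn, Vn = Q[sl], K[sl], V[sl]
        O[sl] = Qn @ S + (Qn @ Kn.T * M) @ Vn   # past-chunk contribution + masked intra-chunk attention
        S = S + Kn.T @ Vn                       # fold this chunk into the running state
    return O
</syntaxhighlight>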
The outputs are then stored in HBM and returned. The algorithm maintains a trade-off between memory usage and parallelization such that training is faster than with previous linear attention implementations.

FLASHLINEARATTENTION offers significant performance gains compared to several other attention models: it is faster than FLASHATTENTION-2 on sequences shorter than roughly 4K tokens and roughly doubles training speed compared to standard linear attention. Memory efficiency is improved by reducing HBM I/O cost and removing redundant data movement, with about 4x less memory usage than baseline softmax attention. Scalability is also a benefit, allowing sequences longer than 20K tokens to be processed without quadratic memory growth.

These improvements make FLASHLINEARATTENTION a solid building block for large-scale Transformers, improving efficiency in both training and inference. By integrating FLASHLINEARATTENTION into traditional Transformer designs, large language models can be significantly accelerated and large-scale deployment made more viable. Future work can explore deeper kernel optimizations, CUDA-specific workloads, and hardware-specific transformations to further enhance efficiency. FLASHLINEARATTENTION is a significant advancement in hardware-efficient deep learning, supporting memory-efficient training at higher speeds for large-scale Transformers. By optimizing memory access, chunking input sequences, and parallelizing intra-chunk computation, it achieves significant speedups with strong recall ability, making it a notable step toward efficient processing of long sequences.
=== Limitations ===
While the Based Architecture presents a promising direction, it still faces some challenges:
# Implementation complexity: The dual-attention mechanism and hardware optimizations add complexity to the implementation.
# Hyperparameter sensitivity: The optimal balance between global and local attention components may vary across different tasks and datasets.
# Performance gap: Despite improvements, there remains a gap between Based models and state-of-the-art vanilla attention models on some specific tasks.
== Gated Linear Attention (GLA) ==
RetNet shows us we can achieve parallel training, efficient inference, and more optimized performance. However, despite its innovations, RetNet still has significant limitations. The biggest problem with RetNet was how it handled sequential data. RetNet uses a constant decay factor (<math>\gamma</math>) that doesn't adapt to the content being processed. Imagine a conversation where you discount every fifth word by the same amount regardless of how significant it is; that is what RetNet does with its constant exponential decay factor. In other words, RetNet treats all sequential dependencies the same way, regardless of whether the current content requires more emphasis on recent tokens or on tokens much farther back. This can work, but not necessarily optimally.
===Limitations & Future Work===
* Lack of Large-Scale Validation
** Although the authors anticipate that the training efficiency of GLA becomes even more favorable compared to models like Mamba at larger scales, it is unclear how GLA would scale to even larger models and datasets.
** While GLA outperforms other subquadratic models in recall-intensive tasks, it still lags behind standard quadratic attention Transformers, indicating room for improvement in tasks requiring precise memory retrieval.
==TransNormerLLM: A Faster and Better Large Language Model with Improved TransNormer==
The objective of the paper is to address the quadratic time complexity and scalability limitations of conventional transformer-based LLMs by introducing a linear attention-based architecture for improved efficiency and performance. Modern linear attention models like '''TransNormerLLM''' push the envelope by integrating several modifications, including improved positional encodings (using LRPE with exponential decay), gating mechanisms, and tensor normalization, to not only match but even outperform conventional Transformer architectures in both accuracy and efficiency. These innovations help address the limitations of earlier linear attention approaches and demonstrate the potential for linear methods in large-scale language modeling.
===Key Contributions===
==Simple linear attention language models balance the recall-throughput tradeoff== | |||
Large language models (LLMs) using Transformers are fantastic at "recalling" details from their input—think of it as their ability to dig up a specific fact buried in a long conversation. However, attention-based language models suffer from high memory consumption during inference due to the growing key-value (KV) cache, which scales with sequence length. This reduces throughput and limits efficiency for long sequences, despite their strong recall ability. | |||
===Key Contributions=== | |||
* '''BASED''': A hybrid architecture blending linear attention and sliding window attention to optimize the recall-memory tradeoff. | |||
* '''Theoretical and empirical evidence''' of how memory usage impacts recall. | |||
* '''Optimized implementation''' achieving up to 24× higher throughput than FlashAttention-2 for long-sequence generation. | |||
===BASED=== | |||
BASED combines two tricks to strike a balance: | |||
# '''Linear Attention''': This approximates the softmax with a simpler operation. Instead of computing <math>\exp(q_i^\top k_j / \sqrt{d})</math>, it uses a feature map <math>\phi</math> so that <math>\phi(q_i)^\top \phi(k_j) \approx \exp(q_i^\top k_j / \sqrt{d})</math>. The paper opts for a 2nd-order Taylor series: <math>\phi(q_i)^\top \phi(k_j) = 1 + q_i^\top k_j + \frac{(q_i^\top k_j)^2}{2}</math>. This lets them rewrite attention as a recurrent process with a fixed-size state, say <math>s_i = \sum_{j=1}^i \phi(k_j)^\top v_j</math>, avoiding the KV-cache explosion. The state size depends on the feature dimension <math>d'</math>, not <math>N</math>, making memory predictable. | |||
# '''Sliding Window Attention''': This adds precision by letting each token attend to a small, local window of past tokens (e.g., 64 tokens) using exact softmax attention. It’s like giving the model a magnifying glass for nearby details, with a KV-cache capped at the window size <math>w</math>. | |||
Together, linear attention handles long-range context with a fixed memory footprint, while sliding window attention sharpens local recall. By tweaking <math>w</math> and <math>d'</math>, BASED can slide along the recall-memory tradeoff curve—crank up the window for better recall or shrink it for efficiency. | |||
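As a rough illustration of the linear-attention half of BASED, the sketch below implements the 2nd-order Taylor feature map and the fixed-size recurrent state described above. It is plain NumPy; the normalizer term, the scaling of the quadratic features, and the omission of the sliding-window branch are simplifications I have assumed, not the paper's exact implementation.

<syntaxhighlight lang="python">
import numpy as np

def taylor_feature_map(x):
    """phi(x) = [1, x, outer(x, x)/sqrt(2)], so phi(q)·phi(k) = 1 + q·k + (q·k)^2 / 2,
    a 2nd-order Taylor approximation of exp(q·k)."""
    d = x.shape[-1]
    quad = np.einsum('...i,...j->...ij', x, x).reshape(*x.shape[:-1], d * d) / np.sqrt(2)
    return np.concatenate([np.ones((*x.shape[:-1], 1)), x, quad], axis=-1)

def causal_linear_attention(Q, K, V):
    """Recurrent form: the state S = sum_j phi(k_j) v_j^T has a fixed size independent of L."""
    phiQ, phiK = taylor_feature_map(Q), taylor_feature_map(K)
    S = np.zeros((phiQ.shape[-1], V.shape[-1]))   # fixed-size state
    z = np.zeros(phiQ.shape[-1])                  # running normalizer (an assumed choice)
    out = np.zeros_like(V, dtype=np.float64)
    for t in range(Q.shape[0]):
        S += np.outer(phiK[t], V[t])
        z += phiK[t]
        out[t] = (phiQ[t] @ S) / (phiQ[t] @ z + 1e-6)
    return out

L, d = 128, 16
rng = np.random.default_rng(0)
Q, K, V = rng.standard_normal((3, L, d)) / np.sqrt(d)
y = causal_linear_attention(Q, K, V)   # memory grows with d', not with sequence length L
</syntaxhighlight>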
===Limitations and Future Work=== | |||
* '''Scaling''': May struggle with very large models or extremely long sequences. | |||
* '''Future Work''': Could explore improved approximations for linear attention or enhanced hardware optimizations. | |||
= Topic 11: Flash Attention=
Flash Attention tries to address the resource intensity of vanilla attention by utilizing hardware more effectively. Recent advances in computational speed (particularly in parallelization) have lessened computational intensity as a bottleneck. Consequently, memory usage became a larger bottleneck. Specifically, during training, attention requires data to be read and written multiple times, and computations are done in HBM (which is not the fastest memory in the computer). As a result, training takes a lot of time and resources. The goal of Flash Attention is thus to address these inefficiencies in order to reduce training time for LLMs.
== Flash Attention V1 == | |||
Flash Attention V1 rethinks the standard attention mechanism by minimizing expensive memory transfers. Modern GPUs have high-bandwidth memory (HBM) that is large but relatively slow compared to on‑chip SRAM, which is over 10× faster but limited to about 20 MB. By restructuring the attention computation to operate in small blocks within SRAM, Flash Attention V1 dramatically reduces the number of slow memory reads and writes, resulting in faster training for large language models. | |||
=== Overview === | |||
Traditional attention computes the full <math>QK^\top</math> matrix in HBM, causing a quadratic memory traffic bottleneck. Flash Attention V1 addresses this by processing the input matrices in smaller, manageable blocks (tiles) that fit in fast SRAM. In doing so, both the matrix multiplication and the softmax normalization are performed piecewise, with local results later merged to produce the exact, globally normalized output. | |||
==== Block‑Wise (Tiled) Computation ==== | |||
Imagine trying to paint a giant mural using a small canvas. Rather than handling the entire image at once, you divide it into sections that fit on your workbench. In Flash Attention V1, the large <math>Q</math>, <math>K</math>, and <math>V</math> matrices are partitioned into tiles. For example, with 1024 tokens, the matrices might be split into blocks of 256 tokens each. A block of <math>Q</math> (e.g. <math>256 \times d</math>) is loaded from HBM into SRAM along with the corresponding block of <math>K</math> (transposed as needed). The multiplication <math>QK^\top</math> is then computed within SRAM to yield a <math>256 \times 256</math> block of attention scores, eliminating the need to store or compute the entire matrix in slow memory. | |||
==== Softmax Normalization: Handling Global Dependencies in Small Pieces ==== | |||
The softmax function, defined as | |||
<math display="block">\text{softmax}(x_i) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}</math> | |||
requires a global normalization across each row of scores. This poses a challenge when the data is processed in tiles. | |||
Within each tile, a local maximum (<math>m_{\text{local}}</math>) is computed to stabilize the exponentiation, and a local sum is calculated: | |||
<math display="block">S_{\text{local}} = \sum_{i \in \text{tile}} \exp(x_i - m_{\text{local}})</math> | |||
Simultaneously, the tile computes a partial numerator by weighting each <math>\exp(x_i - m_{\text{local}})</math> with the corresponding <math>V</math> values. These local computations ensure that the softmax can be performed on a small subset of the data without needing the entire row at once. | |||
==== Algebraic Aggregation: Merging Tile Results for Global Softmax ==== | |||
Since each tile’s softmax is computed relative to its own local maximum, the results must be re‑aligned to a common global reference. Suppose one tile computes a local maximum <math>m_1</math> and sum | |||
<math>S_1 = \sum_{i \in \text{tile 1}} \exp(x_i - m_1)</math> | |||
and another tile computes <math>m_2</math> and | |||
<math>S_2 = \sum_{j \in \text{tile 2}} \exp(x_j - m_2)</math>. The global maximum is given by | |||
<math display="block">m = \max(m_1, m_2)</math> | |||
Each local sum is then adjusted to this common reference: | |||
<math display="block">S_{\text{global}} = \exp(m_1 - m) S_1 + \exp(m_2 - m) S_2</math> | |||
A similar adjustment is applied to the local numerators (the weighted sums of <math>V</math>). The final attention output is obtained by dividing the aggregated numerator by <math>S_{\text{global}}</math>. This algebraic aggregation process guarantees numerical stability and correctness while never requiring the full matrix to be processed at once. | |||
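A tiny numerical check of this aggregation rule, assuming two tiles of a single score row (NumPy; the tile sizes are arbitrary):

<syntaxhighlight lang="python">
import numpy as np

rng = np.random.default_rng(0)
scores = rng.standard_normal(8)          # one row of attention scores
tile1, tile2 = scores[:4], scores[4:]

m1, m2 = tile1.max(), tile2.max()        # local maxima
S1 = np.exp(tile1 - m1).sum()            # local sums
S2 = np.exp(tile2 - m2).sum()

m = max(m1, m2)                          # global maximum
S_global = np.exp(m1 - m) * S1 + np.exp(m2 - m) * S2

# matches the denominator computed over the whole row at once
assert np.isclose(S_global, np.exp(scores - scores.max()).sum())
</syntaxhighlight>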
==== Memory vs. Compute Trade-offs and Workflow Recap ==== | |||
Flash Attention V1 trades additional computations for a dramatic reduction in memory traffic. Although extra arithmetic is performed—such as recomputing local softmax values and adjusting them—these operations occur in fast SRAM. The overall workflow is as follows: | |||
Load small tiles of <math>Q</math>, <math>K</math>, and <math>V</math> from HBM into SRAM. Perform block‑wise matrix multiplication and softmax computations within SRAM, including computing local maximums, local sums, and partial numerators. Merge these local results through algebraic aggregation to produce the correct global softmax normalization. Finally, write the computed attention outputs back to HBM. | |||
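Putting the whole workflow together, here is a simplified single-head sketch of the tiled forward pass with running max/sum statistics. It is a NumPy stand-in for the SRAM-resident kernel; the block size, the absence of a causal mask, and the float64 accumulators are assumptions made for illustration.

<syntaxhighlight lang="python">
import numpy as np

def flash_attention_forward(Q, K, V, block=64):
    """Block-wise attention with running softmax statistics (never forms the full L x L matrix)."""
    L, d = Q.shape
    out = np.zeros_like(V, dtype=np.float64)
    m = np.full(L, -np.inf)                          # running row maxima
    l = np.zeros(L)                                  # running row denominators
    for start in range(0, L, block):
        Kb, Vb = K[start:start + block], V[start:start + block]
        scores = Q @ Kb.T / np.sqrt(d)               # one tile of QK^T
        m_new = np.maximum(m, scores.max(axis=1))
        p = np.exp(scores - m_new[:, None])          # tile-local exponentials
        scale = np.exp(m - m_new)                    # re-align previous statistics
        l = l * scale + p.sum(axis=1)
        out = out * scale[:, None] + p @ Vb
        m = m_new
    return out / l[:, None]

L, d = 256, 32
rng = np.random.default_rng(1)
Q, K, V = rng.standard_normal((3, L, d))
scores = Q @ K.T / np.sqrt(d)
ref = np.exp(scores - scores.max(axis=1, keepdims=True))
ref = (ref / ref.sum(axis=1, keepdims=True)) @ V
assert np.allclose(flash_attention_forward(Q, K, V), ref)
</syntaxhighlight>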
==== Optional: Block‑Sparse Extension ==== | |||
In some cases, not every tile contributes equally. Flash Attention V1 can be extended to a block‑sparse variant where tiles that fall below a predetermined importance threshold are skipped entirely. This further reduces computation and memory transfers, echoing ideas from sparse attention methods. | |||
=== Flash Attention V2 ===
* Attention bottleneck: quadratic time complexity, inefficient hardware utilization
* Flash Attention V1: still not nearly as fast as optimized matrix-multiply (GEMM) operations, reaching only 25-40% of the theoretical maximum FLOPs/s
* Suitability to Different Devices: Further research is needed to make Flash Attention V2 available to different devices, such as H100 GPUs and AMD GPUs. So far, the discussion has focused on HBM and SRAM (NVIDIA GPUs with specific memory hierarchies).
==== Main Idea ====
In V1, the idea is to use SRAM, compute more, and reduce reads/writes. However, as mentioned previously, only 25%-40% of the theoretical maximum FLOPs is reached. This inefficiency is due to suboptimal work partitioning between different thread blocks and warps on the GPU, causing either low occupancy or unnecessary shared memory reads/writes. In V2, the main idea is a better work partitioning algorithm. The core improvements are:
3. Efficient memory utilization: Optimize intra-thread communication, reducing memory overhead.
==== Empirical Results ====
[[File:Flash_Attention_V2_Attention_forward_+_backward_speed_on_A100_GPU.png|600px]]
The diagrams above show how Flash Attention V2 performs across different sequence lengths, compared to a standard attention implementation in PyTorch, FlashAttention, and FlashAttention in Triton. There is an improvement in efficiency across all sequence lengths: FlashAttention-2 is 1.7-3.0× faster than FlashAttention, 1.3-2.5× faster than FlashAttention in Triton, and 3-10× faster than a standard attention implementation. The throughput reached in V2 is about 70% of the theoretical TFLOPs achievable on A100 GPUs.
=== Flash Attention V3 ===
* Up to 75% GPU utilization with FP16 precision
=== Flash Fast Fourier Transform Convolution (FlashFFT Conv) === | |||
==== Introduction ==== | |||
Efficient sequence processing is a fundamental challenge in deep learning, especially for natural language processing (NLP). Although convolutional models provide state-of-the-art performance across various tasks, including NLP and time-series forecasting, their efficiency is often limited by suboptimal hardware utilization. A potential solution to this problem is the Fast Fourier Transform (FFT), which theoretically operates in <math>O(N \log N)</math> time complexity. However, conventional FFT implementations fail to fully leverage modern GPU architectures, leading to suboptimal efficiency. To address this, a highly optimized FFT-based convolution algorithm, namely Flash Fast Fourier Transform Convolution (FlashFFTConv), is introduced.
==== Innovation ==== | |||
A key innovation in Flash FFT Conv is the Monarch FFT Decomposition, which significantly improves the efficiency of Fast Fourier Transform (FFT)-based convolutions. Unlike traditional FFT algorithms, such as the Cooley-Tukey FFT, Monarch FFT Decomposition reformulates the FFT computation into a structured matrix representation that can be efficiently executed using matrix multiplications on tensor cores. | |||
===== Importance of Monarch FFT Decomposition ===== | |||
The objective of the Monarch FFT decomposition is to express the FFT as a series of matrix multiplications. Doing so enables efficient computation on specialized compute units (e.g., tensor cores of NVIDIA GPUs or matrix multiply units of TPUs). An order-<math>p</math> Monarch decomposition rewrites the FFT into <math>p</math> matrix-matrix multiply operations, which means they can be mapped onto these fast compute units efficiently. The value of <math>p</math> involves a tradeoff: a higher <math>p</math> implies smaller matrices to multiply and thus fewer FLOPs, but greater I/O communication overhead due to a larger number of intermediate results.
[[File:Screenshot 2025-03-14 182945.png|600px|thumb|right|Example illustration of Monarch FFT decomposition <span id="monarch"></span>]] | |||
An illustration of the Monarch FFT decomposition can be seen to the [[#monarch|right]]. As an example, consider a sequence length <math>N=N_1N_2</math>. An order-2 Monarch FFT decomposition expresses the FFT of length <math>N</math> as <math>\mathcal{F}_N=\mathbf{P}(\mathbf{I}_{N_2}\otimes\mathcal{F}_{N_1})\mathbf{D}\mathbf{P}^{-1}(\mathbf{I}_{N_1}\otimes\mathcal{F}_{N_2})\mathbf{P}</math>, where <math>\otimes</math> is the Kronecker product, <math>\mathbf{P}</math> is the permutation matrix that reshapes the input to <math>N_1\times N_2</math>, transposes the intermediate matrix, and then reshapes it back to length <math>N</math>, and <math>\mathbf{D}\in\mathbb{C}^{N\times N}</math> is a diagonal matrix containing the twiddle-factor corrections. Twiddle factors are roots of unity <math>W^k_N=\exp{\left(-j\frac{2\pi k}{N}\right)}</math>, where <math>N</math> is the number of points in the FFT, used to combine the results of smaller DFTs into larger DFTs (recall that the FFT is a divide-and-conquer algorithm for computing the DFT efficiently).
To execute higher-order Monarch FFT decompositions, one can recursively apply the order-2 decomposition to <math>\mathcal{F}_{N_1}</math> and <math>\mathcal{F}_{N_2}</math>.
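The order-2 idea can be sanity-checked in a few lines: reshape the length-<math>N</math> input, take small FFTs along one axis, multiply by twiddle factors, take small FFTs along the other axis, and read the result out transposed. The NumPy sketch below follows a standard Cooley-Tukey "four-step" factorization, which has the same reshape/multiply/reshape structure that the Monarch decomposition exploits; the indexing conventions here are my own, not FlashFFTConv's code.

<syntaxhighlight lang="python">
import numpy as np

N1, N2 = 8, 16
N = N1 * N2
x = np.random.randn(N) + 1j * np.random.randn(N)

A = x.reshape(N1, N2)                        # A[a, b] = x[a*N2 + b]
B = np.fft.fft(A, axis=0)                    # length-N1 DFT down each column
k1 = np.arange(N1)[:, None]
b = np.arange(N2)[None, :]
twiddle = np.exp(-2j * np.pi * k1 * b / N)   # W_N^{k1 * b}
C = B * twiddle                              # the diagonal matrix D in the formula
D = np.fft.fft(C, axis=1)                    # length-N2 DFT along each row
X = D.T.reshape(N)                           # X[k2*N1 + k1] = D[k1, k2] (the permutation P)

assert np.allclose(X, np.fft.fft(x))         # matches the direct length-N FFT
</syntaxhighlight>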
==== Benefits ==== | |||
By leveraging this structured decomposition, FlashFFTConv achieves highly parallelized execution across the input sequence. This maximizes hardware utilization, resulting in faster computation. Furthermore, this approach minimizes high-latency global memory (HBM) access by performing most computations within on-chip memory (SRAM or shared memory). This reduces the I/O bottleneck that often limits FFT performance on modern GPUs. This leads to significant speedups over conventional FFT-based convolution methods. Additionally, the use of sparse and low-rank matrix structures further enhances computational efficiency, eliminating unnecessary operations and improving scalability for long-sequence processing. | |||
Through efficient memory access and tensor core optimization, Flash FFT Conv becomes a highly effective solution for accelerating FFT-based convolutions tasks, such as NLP and audio generation. | |||
= Topic 5: KD / Pruning / Sharing =
Knowledge Distillation (KD), Pruning, and Sharing are three strategies to make language models smaller, faster, and more efficient without sacrificing their performance. KD works by having a large, powerful "teacher" model teach a smaller "student" model how to behave similarly, so the student model can do almost as well as the teacher but with less effort. Pruning can be thought of as trimming a tree: it removes unnecessary or less important parts of the model, making it more compact and faster while keeping its core abilities intact. Sharing is about reusing parts of the model across different tasks or layers, so there is no need to build everything from scratch, saving time and resources.
While KD focuses on teaching a smaller model to imitate a larger one, pruning cuts away the unnecessary parts to make the model cleaner, and sharing reuses parts to avoid unnecessary extra work. Together, these techniques create language models that are not only powerful but also efficient enough to run on devices with limited resources, like smartphones or laptops.
===Knowledge Distillation (KD)===
KD transfers knowledge from a large, well-trained teacher model to a smaller student model. The student learns not only from the hard labels but also by mimicking the teacher's soft output distributions or internal representations. In the work by Muralidharan et al., KD is a critical component in recovering the performance lost during the compression of a 15B-parameter model. By distilling knowledge during retraining, the authors were able to produce smaller models (MINITRON 8B and 4B) that not only match but, in some cases, outperform models trained from scratch, even while using a fraction of the training data.
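A minimal sketch of the standard soft-target distillation objective is shown below (PyTorch). The temperature <code>T</code> and mixing weight <code>alpha</code> are illustrative hyperparameters, and this is the generic Hinton-style loss rather than the exact recipe used for MINITRON.

<syntaxhighlight lang="python">
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, T=2.0, alpha=0.5):
    # Soft targets: the student mimics the teacher's temperature-softened distribution.
    soft = F.kl_div(
        F.log_softmax(student_logits / T, dim=-1),
        F.softmax(teacher_logits / T, dim=-1),
        reduction="batchmean",
    ) * (T * T)
    # Hard targets: the usual cross-entropy against the ground-truth labels.
    hard = F.cross_entropy(student_logits, labels)
    return alpha * soft + (1 - alpha) * hard

student_logits = torch.randn(4, 10, requires_grad=True)
teacher_logits = torch.randn(4, 10)
labels = torch.randint(0, 10, (4,))
loss = distillation_loss(student_logits, teacher_logits, labels)
loss.backward()
</syntaxhighlight>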
[[File:KD.png|500px|thumb|right|Example illustration of Sequence-level knowledge distillation <span id="KD"></span>]] | |||
===Improvements to Knowledge Distillation=== | |||
'''1. Born-Again Networks (BANs): ''' | |||
Repeated training iterations to simplify the dataset. | |||
Improves performance for lower-capacity NAT models. | |||
'''2. Mixture of Experts (MoE): ''' | |||
Reduces data diversity by training on simpler expert translations. | |||
'''3. Sequence-Level Interpolation: ''' | |||
Selects the best candidate translation based on BLEU score, improving high-capacity NAT models. | |||
===Metrics for Distilled Data=== | |||
* '''Data Complexity''' measures how much variability is present in the translations, calculated by fitting an alignment model and computing the average token-level conditional entropy.
* '''Data Faithfulness''' evaluates how closely the distilled data aligns with the original real-world data:
<math> | |||
F(d)=\frac{1}{\left|\mathcal{V}_x\right|} \sum_{x \in \mathcal{V}_x} \sum_{y \in \mathcal{V}_y} p_r(y \mid x) \log \frac{p_r(y \mid x)}{p_d(y \mid x)} | |||
</math> | |||
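A small sketch of this faithfulness computation, assuming the real and distilled conditional distributions <math>p_r(y \mid x)</math> and <math>p_d(y \mid x)</math> have already been estimated as row-stochastic matrices (the smoothing constant is an assumption to avoid division by zero):

<syntaxhighlight lang="python">
import numpy as np

def faithfulness(p_real, p_dist, eps=1e-12):
    """F(d): average over source tokens x of KL( p_r(.|x) || p_d(.|x) ).
    p_real, p_dist: arrays of shape (|V_x|, |V_y|) whose rows sum to 1."""
    kl_rows = np.sum(p_real * np.log((p_real + eps) / (p_dist + eps)), axis=1)
    return kl_rows.mean()
</syntaxhighlight>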
==Attention Is All You Need But You Don’t Need All Of It For Inference of Large Language Models== | |||
The computational costs of LLMs, particularly during inference, is a major bottleneck. This paper explores a key optimization: selectively dropping layers, specifically, deeper attention layers, during inference to speed up computations while maintaining performance. | |||
===The Problem: Inference Complexity=== | |||
Inference in LLMs is costly because of the self-attention mechanism, which scales quadratically with input length. This creates a challenge, especially for applications requiring real-time responses. Researchers have previously explored various methods to optimize inference, including pruning, quantization, and speculative decoding, but this paper focuses on a more direct approach: skipping unnecessary layers. | |||
===Key Idea: Skipping Layers Without Significant Performance Loss=== | |||
The core hypothesis of the paper is that not all layers contribute equally to the final output of an LLM. Specifically, the deeper layers (closer to the output) exhibit higher similarity with their preceding layers, meaning that dropping some of them may have minimal impact on performance. To test this, the authors conduct experiments on Llama-v2 (7B and 13B models) and compare different strategies for skipping layers. | |||
===Method: Selective Layer Skipping=== | |||
Consider a Transformer model with <math>L</math> layers, where each layer consists of an attention sub-layer and a multi-layer perceptron (MLP) sub-layer. We define each layer as: | |||
<math display="block">\text{Layer}_i = (\text{Attention}_i, \text{MLP}_i)</math> | |||
for <math>i \in \{1, 2, \dots, L\}</math>. | |||
The paper evaluates three techniques for skipping components of Transformer layers: | |||
1. Skipping MLP Layers: The feedforward (MLP) sub-layers from the last few layers are removed. | |||
If we skip the last <math>k</math> MLP layers while keeping the attention layers, the modified model is: | |||
<math display="block">M_{\text{skip MLP}} = \left\{ (\text{Attention}_i, \text{MLP}_i) \mid i \in [1, L-k] \right\} \cup \left\{ (\text{Attention}_i, \emptyset) \mid i \in [L-k+1, L] \right\}</math> | |||
2. Skipping Attention Layers: The self-attention sub-layers from the last few layers are removed. | |||
If we skip the last <math>k</math> attention layers while keeping the MLP layers, the modified model is represented as: | |||
<math display="block">M_{\text{skip attention}} = \left\{ (\text{Attention}_i, \text{MLP}_i) \mid i \in [1, L-k] \right\} \cup \left\{ (\emptyset, \text{MLP}_i) \mid i \in [L-k+1, L] \right\}</math> | |||
3. Skipping Entire Transformer Blocks: Both attention and MLP components are removed from some of the last layers. | |||
If we skip both the attention and MLP sub-layers in the last <math>k</math> layers, we obtain: | |||
<math display="block">M_{\text{skip block}} = \left\{ (\text{Attention}_i, \text{MLP}_i) \mid i \in [1, L-k] \right\} \cup \left\{ (\emptyset, \emptyset) \mid i \in [L-k+1, L] \right\}</math> | |||
The models were tested across four benchmarks: ARC, HellaSwag, TruthfulQA, and MMLU, which measure reasoning, common sense, truthfulness, and general knowledge. | |||
===Findings: Attention Layers Are Less Crucial Than MLP Layers=== | |||
[[File:Screenshot 2025-03-15.png|400px|thumb|Cosine similarity of Llama-v2 layers with the previous layer]] | |||
The results demonstrate a clear pattern: | |||
* Dropping entire Transformer blocks leads to a significant performance drop. | |||
* Dropping MLP layers leads to larger performance degradation compared to dropping attention layers. | |||
* Dropping attention layers results in the best trade-off between speed and performance, with only a 1.8% drop in accuracy when 33% of attention layers are removed. Removing 33% of attention layers led to an 18% speedup in Llama-2-13B. | |||
Interestingly, the TruthfulQA benchmark showed an increase in accuracy when some layers were skipped, suggesting that reducing model complexity might reduce hallucinations. | |||
The paper provides empirical evidence that deeper layers contribute less unique information compared to earlier layers. The figure of cosine similarity between successive layers highlights that deeper layers are more redundant.
Topic 12: State Space Models
Advantages of SSMs
- Efficient Long-Range Dependency Handling: Unlike Transformers, which require [math]\displaystyle{ O(n^2) }[/math] complexity for self-attention, SSMs can process sequences in [math]\displaystyle{ O(n \log n) }[/math] time complexity using efficient matrix-vector multiplications and the Fast Fourier Transform (FFT) or [math]\displaystyle{ O(n) }[/math] in case of Mamba architectures.
- Effective Long-Range Dependency Handling: By leveraging a state update mechanism (parametrized by λ) that preserves and accumulates signals over arbitrarily large time spans, SSMs effectively capture and retain information from distant points in a sequence, enabling robust long-range dependency handling.
- Lower memory requirements: Unlike transformers, SSMs don’t require storing the entire input in memory, leading to lower memory consumption.
- Parallelization: SSMs allow efficient parallelization while maintaining the benefits of recurrent computation, making them a powerful alternative to RNNs.
- Robustness to Irregularly Sampled Data: Unlike Transformers, which inherently assume regularly spaced inputs, State Space Models naturally handle irregularly sampled data, which is prevalent in real-world applications such as sensor data processing and clinical time-series. By explicitly modeling the continuous-time evolution of latent states, SSMs provide robustness against missing or unevenly spaced observations, leading to improved performance in scenarios where input data is incomplete or irregularly sampled.
- Interpretability and Diagnostic Insight: State Space Models offer a distinct interpretability advantage by representing system behavior through their state-transition dynamics explicitly. Unlike black-box models, the learned parameters in SSMs (matrices [math]\displaystyle{ \mathbf{A} }[/math], [math]\displaystyle{ \mathbf{B} }[/math], [math]\displaystyle{ \mathbf{C} }[/math], and [math]\displaystyle{ \mathbf{D} }[/math]) can be analyzed to infer how inputs influence future states and outputs over time. This interpretability is especially valuable in fields where understanding model behavior is critical, such as financial risk modeling or biological systems analysis.
Core concepts
To understand State Space Models better, let's first look at the problems with Transformers and Recurrent Neural Networks (RNNs) and how they relate to SSMs. In Transformers, during training we create a matrix comparing each token with every token that came before it, and the weight in each cell of this matrix shows how similar the two tokens are. Computing these weights requires no sequential operation: every pair of tokens can be compared in parallel. However, during inference, when we try to generate the next token, we need to recompute attention over the entire sequence, even for tokens we have already generated. This means that for a sequence of length [math]\displaystyle{ L }[/math] the computation cost is [math]\displaystyle{ L^2 }[/math], which is a bottleneck for long sequences in Transformers. An RNN takes two inputs at each time step to predict the output [math]\displaystyle{ y_t }[/math] and generate the current hidden state ([math]\displaystyle{ h_t }[/math]): the hidden state of the previous time step ([math]\displaystyle{ h_{t-1} }[/math]) and the input of the current time step ([math]\displaystyle{ x_t }[/math]). This structure allows inference with linear computation because, unlike Transformers, it does not recalculate all previous hidden states. On the other hand, the well-known problem with RNNs is that they forget the initial tokens as they move forward through the sequence. A simple RNN updates its hidden state using the formula below:

[math]\displaystyle{ h(t) = \sigma(W_h h(t-1) + W_x x(t)) }[/math]
[math]\displaystyle{ y(t) = W_y h(t) }[/math]
Where:
- [math]\displaystyle{ h(t) }[/math] is the hidden state at time t
- [math]\displaystyle{ x(t) }[/math] is the input
- [math]\displaystyle{ y(t) }[/math] is the output
- [math]\displaystyle{ W_h, W_x, }[/math] and [math]\displaystyle{ W_y }[/math] are weight matrices
- [math]\displaystyle{ \sigma }[/math] is a non-linear activation function
State Space Models (SSM) come from control theory for mathematical representation of a system and describing its possible states. They define a linear mapping from an input signal x(t) to an output signal y(t) through a latent state representation h(t). There are two main components to SSM:
1. State equation: Describes how the system's hidden state changes based on the current state and input
[math]\displaystyle{ h'(t) = \mathbf{A} h(t) + \mathbf{B} x(t) }[/math]
2. Output equation: Defines how the system produces output based on the current state and input
[math]\displaystyle{ y(t) = \mathbf{C} h(t) + \mathbf{D} x(t) }[/math]
Where:
- [math]\displaystyle{ h(t) }[/math] represents the hidden state
- [math]\displaystyle{ x(t) }[/math] is the input
- [math]\displaystyle{ y(t) }[/math] is the output
- [math]\displaystyle{ \mathbf{A}, \mathbf{B}, \mathbf{C}, }[/math] and [math]\displaystyle{ \mathbf{D} }[/math] are learnable parameter matrices
The intuition behind the neural network architecture stems from the state equation, which essentially says that the change in the latent variable is given by the sum of a linear transformation of the current latent state and a linear transformation of the input. Solving the system of differential equations is then equivalent to training the neural network to find the linear transformations [math]\displaystyle{ \mathbf{A},\mathbf{B},\mathbf{C},\mathbf{D} }[/math]. However, note that this definition uses continuous functions and is not directly applicable to sequential input data that are discrete, especially sentences. Therefore, we discretize the differential equation so that it accepts discrete input sequences. This yields
[math]\displaystyle{ h_t = \bar{\mathbf{A}} h_{t-1} + \bar{\mathbf{B}} x_t }[/math]
[math]\displaystyle{ y_t = \mathbf{\bar{C}} h_t }[/math]
where [math]\displaystyle{ h_t }[/math] is the discrete latent state at step t and [math]\displaystyle{ x_t }[/math] is the input at step t.
Note that we dropped [math]\displaystyle{ \mathbf{D}x_t }[/math] for computational simplicity and hope [math]\displaystyle{ \mathbf{\bar{C}} }[/math] can capture the effect of input [math]\displaystyle{ x_t }[/math], as [math]\displaystyle{ \mathbf{\bar{C}} h_t = \mathbf{\bar{C}} (\mathbf{\bar{A}}h_{t-1} + \mathbf{\bar{B}}x_t) }[/math] so [math]\displaystyle{ \mathbf{\bar{C}} }[/math] operates on [math]\displaystyle{ x_t }[/math] as well.
So how does the SSM differ from the RNN? In fact, the difference lies in the convolutional view of the SSM. To understand this, consider a sequence of inputs [math]\displaystyle{ x_{1:n} }[/math]; then we have
[math]\displaystyle{ h_0 = \mathbf{\bar{B}} x_0 }[/math]
[math]\displaystyle{ h_1 = \mathbf{\bar{A}}h_0 + \mathbf{\bar{B}} x_1 = \mathbf{\bar{A}}\mathbf{\bar{B}} x_0 + \mathbf{\bar{B}} x_1 }[/math]
[math]\displaystyle{ h_2 = \mathbf{\bar{A}}h_1 + \mathbf{\bar{B}} x_2 = \mathbf{\bar{A}}^2\mathbf{\bar{B}} x_0 + \mathbf{\bar{A}}\mathbf{\bar{B}} x_1 + \mathbf{\bar{B}} x_2 }[/math]
Note that since we have linear transformation, we can now compute all states in parallel and this could greatly improve training speed.
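This recurrent/convolutional equivalence is easy to verify numerically. The sketch below uses arbitrary (already discretized) matrices [math]\displaystyle{ \mathbf{\bar A}, \mathbf{\bar B}, \mathbf{\bar C} }[/math] and checks that unrolling the recurrence matches a causal 1-D convolution with kernel [math]\displaystyle{ (\mathbf{\bar C}\mathbf{\bar B}, \mathbf{\bar C}\mathbf{\bar A}\mathbf{\bar B}, \mathbf{\bar C}\mathbf{\bar A}^2\mathbf{\bar B}, \dots) }[/math]; the specific matrix values are random placeholders.

<syntaxhighlight lang="python">
import numpy as np

rng = np.random.default_rng(0)
N, L = 4, 16                                       # state size, sequence length
Ab = 0.9 * np.eye(N) + 0.05 * rng.standard_normal((N, N))
Bb = rng.standard_normal((N, 1))
Cb = rng.standard_normal((1, N))
x = rng.standard_normal(L)

# Recurrent view: h_t = Ab h_{t-1} + Bb x_t,  y_t = Cb h_t
h = np.zeros((N, 1))
y_rec = []
for t in range(L):
    h = Ab @ h + Bb * x[t]
    y_rec.append((Cb @ h).item())

# Convolutional view: y = K * x with K_k = Cb Ab^k Bb
K = np.array([(Cb @ np.linalg.matrix_power(Ab, k) @ Bb).item() for k in range(L)])
y_conv = [np.dot(K[: t + 1][::-1], x[: t + 1]) for t in range(L)]

assert np.allclose(y_rec, y_conv)
</syntaxhighlight>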
Note that we also put a bar on the matrices; this is a result of discretization, where [math]\displaystyle{ \mathbf{\bar A} }[/math] and [math]\displaystyle{ \mathbf{\bar B} }[/math] are discrete equivalents of the [math]\displaystyle{ \mathbf{A} }[/math] and [math]\displaystyle{ \mathbf{B} }[/math] transforms compounded over timestep [math]\displaystyle{ \Delta }[/math] using some discretization method, for example the trapezoidal rule:
- [math]\displaystyle{ \mathbf{\bar A} = (I - \frac{\Delta}{2}A)^{-1}(I + \frac{\Delta}{2}A) }[/math]
- [math]\displaystyle{ \mathbf{\bar B} = (I - \frac{\Delta}{2}A)^{-1}\Delta B }[/math]
[math]\displaystyle{ \overline{A} }[/math] : Discretized state transfer matrix, which is an approximate representation of the continuous state matrix A.
[math]\displaystyle{ \overline{B} }[/math] : Discrete version of the input matrix.
Look at #Discretization for further details.
Looking at these formulations and those we defined for RNNs shows us they are similar. We can see that an RNN is essentially a non-linear extension of a state space model. The main differences are:
- SSMs are linear transformations between states, while RNNs apply non-linearity through the activation function
- SSMs come from control theory and in control systems, the matrices are typically derived from physics equations, while in machine learning we learn these matrices from data
- In SSMs, the output equation contains the term [math]\displaystyle{ \mathbf{D} x(t) }[/math] (often written [math]\displaystyle{ \mathbf{D} u(t) }[/math] in control notation), which is commonly left out in control problems
HiPPO Matrix
Higher-Order-Polynomial Projection Operator (HiPPO) Matrix is a class of specially designed state transfer matrix A that enables the state space model to capture long-term dependencies in a continuous time setting. The most important matrix in this class is defined as:
[math]\displaystyle{ \left.\left(\textbf{HiPPO Matrix}\right)\right.\quad A_{nk}=-\left\{ \begin{aligned} (2n+1)^{1/2}(2k+1)^{1/2} & & \mathrm{if~}n\gt k \\ n+1 & & \mathrm{if~}n=k \\ 0 & & \mathrm{if~}n\lt k \end{aligned}\right.. }[/math]
This matrix is designed to compress past history into a state that contains enough information to roughly reconstruct history. The interpretation of this matrix is that it produces a hidden state that remembers its history.
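A direct transcription of this matrix is given below (NumPy; I assume 0-based indexing of [math]\displaystyle{ n, k }[/math], which only shifts the diagonal values relative to a 1-based convention).

<syntaxhighlight lang="python">
import numpy as np

def hippo_matrix(N: int) -> np.ndarray:
    """HiPPO matrix as written above, for 0-indexed n, k."""
    A = np.zeros((N, N))
    for n in range(N):
        for k in range(N):
            if n > k:
                A[n, k] = -np.sqrt(2 * n + 1) * np.sqrt(2 * k + 1)
            elif n == k:
                A[n, k] = -(n + 1)
            # n < k: stays 0
    return A

print(hippo_matrix(4))
</syntaxhighlight>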
Recurrent Representation of SSM
The input is a discrete sequence [math]\displaystyle{ (u_0,u_1,\ldots) }[/math] instead of continuous function. To discretize the continuous-time SSM, we follow prior work in using the bilinear method [2], which converts the state matrix [math]\displaystyle{ A }[/math] into an approximation [math]\displaystyle{ \overline{A} }[/math] . The discrete SSM is:
[math]\displaystyle{ \begin{aligned} x_k & =\overline{\boldsymbol{A}}x_{k-1}+\overline{\boldsymbol{B}}u_k & \overline{\boldsymbol{A}} & =(\boldsymbol{I}-\Delta/2\cdot\boldsymbol{A})^{-1}(\boldsymbol{I}+\Delta/2\cdot\boldsymbol{A}) \\ y_k & =\overline{\boldsymbol{C}}x_k & \overline{\boldsymbol{B}} & =(\boldsymbol{I}-\Delta/2\cdot\boldsymbol{A})^{-1}\Delta\boldsymbol{B} & \overline{\boldsymbol{C}}=\boldsymbol{C}. \end{aligned} }[/math]
This equation is now a sequence-to-sequence map [math]\displaystyle{ u_k\mapsto y_k }[/math] instead of function-to-function.
Discretization
As mentioned, State Space models come from ordinary differential equations, so the way we discretize those continuous equations so they work with finite sequences is crucial. In fact, we aim to learn [math]\displaystyle{ \mathbf{\bar A} }[/math] and [math]\displaystyle{ \mathbf{\bar B} }[/math] instead of [math]\displaystyle{ A }[/math] and [math]\displaystyle{ B }[/math] directly, so the discretization step is baked into our model in practice. Below we show a quick discretization example based on a trapezoidal rule.
Trapezoidal rule assumes: [math]\displaystyle{ x_{n+1} - x_{n} = \frac{\Delta}{2} (f(t_{n+1}) + f(t_{n})) }[/math]
We start from the ordinary differential equation: [math]\displaystyle{ h'(t) = \mathbf{A} h(t) + \mathbf{B} x(t) }[/math]
By applying the trapezoidal rule:
[math]\displaystyle{ h_{n+1} - h_{n} = \frac{\Delta}{2} (\mathbf{A}h_{n+1} + \mathbf{B}x_{n+1} + \mathbf{A}h_{n} + \mathbf{B}x_{n}) }[/math]
[math]\displaystyle{ h_{n+1} - \frac{\Delta}{2} \mathbf{A}h_{n+1} = h_n + \frac{\Delta}{2}\mathbf{A}h_{n} + \frac{\Delta}{2} \mathbf{B}(x_{n+1} + x_{n}) }[/math]
[math]\displaystyle{ (\mathbf{I} - \frac{\Delta}{2} \mathbf{A})h_{n+1} = (\mathbf{I} + \frac{\Delta}{2}\mathbf{A})h_{n} + \frac{\Delta}{2} \mathbf{B}(x_{n+1} + x_{n}) }[/math]
It is assumed that the control sequence does not change significantly over small [math]\displaystyle{ \Delta }[/math], i.e., [math]\displaystyle{ x_{n+1} \approx x_n }[/math], leading to:
[math]\displaystyle{ (\mathbf{I} - \frac{\Delta}{2} \mathbf{A})h_{n+1} = (\mathbf{I} + \frac{\Delta}{2}\mathbf{A})h_{n} + \Delta \mathbf{B}(x_{n+1}) }[/math]
Thus, solving for [math]\displaystyle{ h_{n+1} }[/math]:
[math]\displaystyle{ h_{n+1} = (\mathbf{I} - \frac{\Delta}{2} \mathbf{A})^{-1}(\mathbf{I} + \frac{\Delta}{2}\mathbf{A})h_{n} + (\mathbf{I} - \frac{\Delta}{2} \mathbf{A})^{-1}\Delta \mathbf{B}(x_{n+1}) }[/math]
From this, the discretized state matrices are:
[math]\displaystyle{ \mathbf{\bar A}=(\mathbf{I} - \frac{\Delta}{2} \mathbf{A})^{-1}(\mathbf{I} + \frac{\Delta}{2}\mathbf{A}), \quad \mathbf{\bar B}=(\mathbf{I} - \frac{\Delta}{2} \mathbf{A})^{-1}\Delta \mathbf{B}. }[/math]
An alternative and often-used discretization is based on the matrix exponential:
[math]\displaystyle{ \mathbf{\bar{A}} = e^{A\Delta}, \quad \mathbf{\bar{B}} = (A^{-1}(e^{A\Delta} - I)) B. }[/math]
This formulation naturally arises in continuous-time state-space models and ensures numerical stability for stiff systems.
The matrix exponent [math]\displaystyle{ {A\Delta} }[/math] comes directly from the analytic solution of the continuous-time system: [math]\displaystyle{ x(t) = e^{At} x(0) + \int_0^t e^{A(t-s)} B u(s) ds }[/math]
For discretization, we choose [math]\displaystyle{ t = k\Delta }[/math] as the sampling point to obtain:
[math]\displaystyle{ x_k = e^{A\Delta} x_{k-1} + (e^{A\Delta} - I) A^{-1} B u_k }[/math]
which is the origin of the matrix-exponential discretization above.
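Both discretizations can be implemented in a few lines and agree closely for small [math]\displaystyle{ \Delta }[/math]. The NumPy/SciPy sketch below uses an arbitrary but invertible example matrix [math]\displaystyle{ A }[/math], as required by the matrix-exponential formula.

<syntaxhighlight lang="python">
import numpy as np
from scipy.linalg import expm

def discretize_bilinear(A, B, dt):
    """Trapezoidal / bilinear rule, as derived above."""
    I = np.eye(A.shape[0])
    inv = np.linalg.inv(I - dt / 2 * A)
    return inv @ (I + dt / 2 * A), inv @ (dt * B)

def discretize_expm(A, B, dt):
    """Matrix-exponential rule: A_bar = e^{A dt}, B_bar = A^{-1}(e^{A dt} - I) B."""
    Ab = expm(A * dt)
    Bb = np.linalg.solve(A, Ab - np.eye(A.shape[0])) @ B
    return Ab, Bb

A = np.array([[-1.0, 0.5], [0.0, -2.0]])
B = np.array([[1.0], [1.0]])
for dt in (0.1, 0.01):
    Ab1, Bb1 = discretize_bilinear(A, B, dt)
    Ab2, Bb2 = discretize_expm(A, B, dt)
    print(dt, np.max(np.abs(Ab1 - Ab2)), np.max(np.abs(Bb1 - Bb2)))
</syntaxhighlight>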
To express this in terms of a convolution kernel, the system can be reformulated as:
[math]\displaystyle{ \bar{K} = ( C \bar{B}, C \bar{A} \bar{B}, ..., C \bar{A}^{L-1} \bar{B} ). }[/math]
This can also be rewritten using the exponential formulation:
[math]\displaystyle{ K = \left( C e^{A k \Delta} (e^{A\Delta} - I) A^{-1} B \right)_{0 \leq k \lt L}. }[/math]
This convolution kernel representation helps efficiently compute long-range dependencies in sequences.
Numerical Stability Considerations:
- If [math]\displaystyle{ \|A\Delta\| }[/math] is large, naive exponentiation can introduce significant numerical errors.
- The trapezoidal rule provides a more stable alternative by constraining eigenvalues.
- In practical DSS implementations, row-wise normalization is used to further stabilize large-sequence dynamics.
These methods ensure that state-space models remain numerically robust and efficient when handling long sequences.
Structured State Space (S4)
Objective
The structured state space (S4) model is specialized for conducting efficient sequence modelling, particularly in situations when handling long-term dependencies is needed. The goal of S4 is to improve computational efficiency, reduce the number of parameters, and enhance the ability of learning by introducing constraints on the state transition matrices.
Key Contributions
To efficiently handle long sequences, the Structured State Space (S4) model employs a structured parameterization of its state transition matrices. Specifically, it uses a Diagonal Plus Low-Rank (DPLR) parameterization, which combines the simplicity of diagonal matrices with the expressiveness of low-rank updates. This approach allows S4 to learn complex state spaces while maintaining computational efficiency, making it particularly effective for tasks requiring long-range dependencies.
Intuition Behind S4
The main innovation of S4 is its ability to track long-range dependencies efficiently. Traditional models struggle with long sequences either because they require too much memory (like Transformers) or forget early information (like RNNs). S4 overcomes these limitations by using a mathematical reformulation of state space models, allowing it to process sequences much more efficiently. The authors argue that the different layers of the S4 model contribute to learning long-range dependencies efficiently. They hypothesized that the earlier (lower) layers of S4 are predominantly responsible for learning local information (short-range dependencies), while the later (higher) layers aggregate global information, allowing the model to capture the bigger picture.
At its core, S4 is a hybrid between recurrent models and convolutions. Instead of updating each time step sequentially like an RNN, it uses a special structure called the Cauchy kernel, which allows it to compute all steps in parallel. This enables it to process sequences of tens of thousands of steps with a fraction of the computational cost of previous models. Thus, S4 is very advantageous for inference.
S4 has been particularly successful in domains that require continuous, structured data, such as time-series processing. However, its reliance on structured state matrices makes it less adaptable to unstructured data like natural language, where attention-based models still hold an advantage.
Theorem 1
All HiPPO matrices in the paper[3] have a NPLR representation:
[math]\displaystyle{ A=V\Lambda V^*-PQ^\top=V\left(\boldsymbol{\Lambda}-\left(\boldsymbol{V}^*\boldsymbol{P}\right)(\boldsymbol{V}^*\boldsymbol{Q})^*\right)\boldsymbol{V}^* }[/math]
for unitary [math]\displaystyle{ \boldsymbol{V}\in\mathbb{C}^{N\times N} }[/math], diagonal [math]\displaystyle{ Λ }[/math], and low-rank factorization [math]\displaystyle{ \boldsymbol{P},\boldsymbol{Q}\in\mathbb{R}^{N\times r} }[/math].
From Theorem 1, we can find that NPLR matrices can be conjugated into diagonal plus low-rank (DPLR) form.
The core idea of this form is to diagonalize the matrix A using the unitary-ary transformation V and use a low-rank decomposition to further compress the computation, which is important in efficient computation and storage optimization.
Theorem 2
Given any step size [math]\displaystyle{ ∆ }[/math], computing one step of the recurrence can be done in [math]\displaystyle{ O(N) }[/math] operations where [math]\displaystyle{ N }[/math] is the state size.
Theorem 3
Given any step size [math]\displaystyle{ ∆ }[/math], computing the SSM convolution filter [math]\displaystyle{ \overline{\boldsymbol{K}} }[/math] can be reduced to 4 Cauchy multiplies, requiring only [math]\displaystyle{ {\widetilde{O}}(N+L) }[/math] operations and [math]\displaystyle{ O(N + L) }[/math] space.
Theorems 2 and 3 describe the complexities of SSMs where A is in DPLR form.
{| class="wikitable"
! !! Convolution !! Recurrence !! Attention !! S4
|-
| Parameters || [math]\displaystyle{ LH }[/math] || [math]\displaystyle{ H^2 }[/math] || [math]\displaystyle{ H^2 }[/math] || [math]\displaystyle{ H^2 }[/math]
|-
| Training || [math]\displaystyle{ \tilde{L}H(B + H) }[/math] || [math]\displaystyle{ BLH^2 }[/math] || [math]\displaystyle{ B(L^2H + LH^2) }[/math] || [math]\displaystyle{ BH(\tilde{H} + \tilde{L}) + B\tilde{L}H }[/math]
|-
| Space || [math]\displaystyle{ BLH }[/math] || [math]\displaystyle{ BLH }[/math] || [math]\displaystyle{ B(L^2 + HL) }[/math] || [math]\displaystyle{ BLH }[/math]
|-
| Parallel || Yes || No || Yes || Yes
|-
| Inference || [math]\displaystyle{ LH^2 }[/math] || [math]\displaystyle{ H^2 }[/math] || [math]\displaystyle{ L^2H + H^2L }[/math] || [math]\displaystyle{ H^2 }[/math]
|}
Limitations
However, a significant critique of the S4 model is its complexity. The model relies on multiple reduction steps and advanced linear algebraic techniques to compute the state space efficiently. While these optimizations are crucial for performance, they also make S4 challenging to understand, implement, and analyze. This complexity can act as a barrier for researchers and practitioners seeking to adapt or build upon the model, highlighting a trade-off between efficiency and accessibility.
Some of the applications of S4 are language and audio processing (text generation, speech recognition, music and voice synthesis), forecasting (predicting stock prices in financial markets, modelling climate information), and scientific computing (simulations in physics, chemistry, astronomy).
Diagonal State Space Model (DSS)
As mentioned above, S4 relies on the Diagonal Plus Low Rank (DPLR) structure of state matrix [math]\displaystyle{ A }[/math], which improves efficiency but introduces significant complexity in both understanding and implementation. To simplify S4 while maintaining its performance, Diagonal State Space Model (DSS) is proposed. It simplifies the state matrix by replacing it with a purely diagonal form, reducing computational overhead and improving interpretability. It also maintains the ability to capture long-range dependencies.
[math]\displaystyle{ \mathbf{Improvements\ Compared\ with\ S4} }[/math]
- Computational Efficiency (and Model Complexity)
- The DSS model offers a significant computational advantage by simplifying the NPLR state matrix to a diagonal form. This diagonalization enables more efficient recurrence computations and reduces the number of parameters to be estimated, making it computationally efficient. This simplicity allows DSS to scale well to large datasets and real-time systems.
- Interpretability
- The diagonal structure makes it clear how previous hidden states and inputs affect the current state, which helps in understanding the relationships between states. For example, it can be shown that under certain situations, DSS can capture information from extreme distant position in the sequence.
- To illustrate how DSS captures long-range dependencies, consider the diagonalization of the state space representation. Given a diagonalizable state matrix [math]\displaystyle{ A }[/math] with eigenvalues [math]\displaystyle{ \lambda_1, \dots, \lambda_N }[/math] and the diagonal matrix [math]\displaystyle{ \Lambda }[/math], DSS reformulates the recurrence relation as: [math]\displaystyle{ K = \bar{K}_{\Delta,L} \left( \Lambda, \left(e^{L \lambda_i \Delta} - 1 \right)^{-1} \right) }[/math], where [math]\displaystyle{ K }[/math] encodes the sequence propagation in the transformed space. The diagonal form simplifies the recurrence relationship, making it clear how past states influence the current state. Under specific conditions, DSS captures long-range dependencies:
- If [math]\displaystyle{ |\lambda_i| \Delta \approx 0 }[/math], the state evolution reduces to [math]\displaystyle{ x_{i,k} \approx x_{i,k-1} }[/math], leading to persistent memory effects.
- If [math]\displaystyle{ \text{Re}(\lambda_i) \Delta \ll 0 }[/math], information from past states is effectively forgotten, prioritizing local dependencies.
- If [math]\displaystyle{ \text{Re}(\lambda_i) \gt 0 }[/math], DSS can propagate information from distant states due to the exponential weighting term [math]\displaystyle{ e^{\lambda \Delta} }[/math]. This property enables DSS to retain long-range dependencies more effectively than conventional models.
- A key advantage of DSS is that it provides interpretable mechanisms to regulate information retention and forgetting. The row-wise normalization in [math]\displaystyle{ \text{DSS}_{\text{softmax}} }[/math] further stabilizes the dynamics, mitigating numerical instabilities when handling large sequence lengths [math]\displaystyle{ L }[/math]. This simplicity enhances the transparency of the model, providing clear insights into system behavior and making it easier to interpret the underlying dynamics. By comparison, S4 lacks interpretability due to its more complex architecture.
[math]\displaystyle{ \mathbf{Two\ Variants\ of\ DSS} }[/math]
The DSS model mentioned above is one of the variants of DSS, called [math]\displaystyle{ DSS_{softmax} }[/math]. Softmax normalization ensures that the contribution of different eigenvalues remains bounded, avoiding numerical instability. Another variant of DSS is called [math]\displaystyle{ DSS_{exp} }[/math]. In this variant, the recurrence relation becomes [math]\displaystyle{ K = \bar{K}_{\Delta,L} \left( \Lambda, \left(1\right)\right) }[/math]. Instead of softmax normalization, this variant forces the real part of the eigenvalues to be negative, ensuring that the recurrence does not grow uncontrollably. Note that this restriction may fail in some tasks and is not as expressive as general state spaces.
[math]\displaystyle{ \mathbf{Applications} }[/math]
Despite its simplifications, DSS retains strong modeling capabilities for a variety of tasks, including time-series forecasting, raw speech classification, and natural language processing (NLP). Its ability to model temporal dependencies with minimal complexity makes it a highly efficient choice for applications where computational efficiency and interpretability are critical. While S4 excels in capturing intricate, long-range dependencies, DSS's performance is nearly as strong, and in certain cases, it even outperforms S4.
[math]\displaystyle{ \mathbf{Limitations} }[/math]
The initialization of parameters can significantly influence the performance of DSS, as improper initialization may lead to slower convergence or suboptimal solutions. Additionally, DSS tends to be less effective at modeling information-dense data, such as text, where complex patterns and intricate relationships between words or phrases are crucial. In these cases, more sophisticated models may be helpful. However, DSS still provides a viable alternative in scenarios where simplicity and efficiency are prioritized over capturing deep contextual relationships.
Hungry Hungry Hippos (H3)
SSMs vs Attention
Research has shown that SSMs demonstrate state-of-the-art performance in domains like speech recognition and audio generation. However, they underperform attention in language modelling due to two main reasons:
- Expressivity gap
- Poor hardware utilization
Expressivity Gap
To understand the gap between SSMs and attention on language modelling, the authors examined two synthetic language modelling tasks:
The Induction Head task tests how well a model can recall content after a special token. At the end of the sequence, the model must recall the token that appeared immediately after the special token earlier in the sequence. In the table above, we are trying to recall the first token after the ⊢ symbol, which is f.
Associative Recall is similar to the induction head task, but requires the model to remember multiple key-value pairs. At the end of the sequence, the model must recall a specific value belonging to a specific key. In the table above, we are trying to recall the value associated with character a, which is 2.
The table above shows performance of S4D (specialized S4 for language modelling), Gated State Spaces (GSS) and Attention. We can see that Attention can complete both synthetic tasks perfectly, achieving an accuracy of 100%, significantly outperforming S4D and GSS. Failure of SSMs can be attributed to two missing capabilities: (i) the ability to remember tokens that appear after a particular event (e.g., the special token in the induction head task), and (ii) the ability to compare tokens across the sequence (e.g., comparing keys to decide which value to recall). Attention has both these capabilities: it can compare tokens by constructing the quadratic attention matrix [math]\displaystyle{ \mathbf{QK^T} }[/math], and it can recall tokens by direct copying (multiplying [math]\displaystyle{ softmax(\mathbf{QK^T}) }[/math] with [math]\displaystyle{ \mathbf{V} }[/math]).
H3 Design
H3 is designed to endow SSMs with these capabilities. Its structure is as follows:
The H3 layer consists of two SSMs stacked together. The idea of stacking multiple SSM layers is to selectively incorporate attention-like mechanisms for key operations such as recalling past tokens and comparing sequence elements. This allows H3 to perform well on tasks where pure SSMs historically underperformed, making it competitive with Transformers in language modeling while maintaining high computational efficiency. Initially, the input sequence is projected into query, key, and value matrices, as in regular attention. Two SSM transformations are then applied to these matrices to generate the final output.
Shift SSM
The first SSM is the shift SSM, which you can think of as performing a local lookup across the sequence. In a sliding-window fashion, the shift SSM shifts the state vector by one position at each time step to store the most recent inputs. The purpose is to detect specific events and remember the surrounding tokens. As an analogy, this is similar to how we remember the most recent words we have read: our brain captures those last few words (tokens, for the model), and for specific words or events in a book, we associate them with the surrounding context.
In the shift SSM, we constrain [math]\displaystyle{ \mathbf{A} ∈ R^{m×m} }[/math] to be a shift matrix as defined by its entries: [math]\displaystyle{ \mathbf{A}_{i,j} = \begin{cases} 1 & \text{for } i - 1 = j,\\ 0 & \text{otherwise}. \end{cases} }[/math].
If you draw it out, it would be a square matrix where the 1s appear directly below the main diagonal with all other entries 0. The action of this matrix on the hidden state [math]\displaystyle{ x_i }[/math] is to shift each coordinate down by one—thereby creating a “memory” of the previous states. If [math]\displaystyle{ \mathbf{B} = \textit{e}_1 }[/math], the first basis vector, then [math]\displaystyle{ x_i = [u_i, u_{i-1}, . . . , u_{i-m+1}] }[/math] contains the inputs from the previous m time steps. Both [math]\displaystyle{ \mathbf{B} }[/math] and [math]\displaystyle{ \mathbf{C} }[/math] are learnable matrices, but [math]\displaystyle{ \mathbf{B} }[/math] is usually fixed to [math]\displaystyle{ \textit{e}_1 }[/math] for simplicity, in which case the output is a 1D convolution with kernel size m.
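As a sanity check on this construction, the following illustrative NumPy sketch builds the shift matrix, runs the recurrence with [math]\displaystyle{ \mathbf{B} = \textit{e}_1 }[/math], and verifies that the output equals a 1D convolution with kernel size m. The dimensions and the random read-out C are arbitrary:
<syntaxhighlight lang="python">
import numpy as np

m, T = 4, 8
A = np.eye(m, k=-1)            # shift matrix: ones directly below the main diagonal
B = np.zeros(m); B[0] = 1.0    # B fixed to the first basis vector e_1
C = np.random.randn(m)         # read-out vector (learnable in the real model)

u = np.random.randn(T)         # scalar input sequence
x = np.zeros(m)
ys = []
for t in range(T):
    x = A @ x + B * u[t]       # x now holds [u_t, u_{t-1}, ..., u_{t-m+1}]
    ys.append(C @ x)           # output is a 1D convolution of u with kernel C

# explicit convolution with kernel size m for comparison
y_conv = [sum(C[j] * (u[t - j] if t - j >= 0 else 0.0) for j in range(m)) for t in range(T)]
assert np.allclose(ys, y_conv)
</syntaxhighlight>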
Diag SSM
The diagonal SSM serves as a kind of global memory that keeps track of important information over the entire sequence. A diagonal matrix is used to summarize the information seen so far. As an analogy, the diagonal SSM acts like the notes we take in class: it is difficult to recall every detail the professor states during a lecture, so we take notes to retain the important concepts and refer to them when we review. The diagonal SSM serves as the notes that the input sequence can refer to when recalling a summary of past information.
The diagonal SSM constrains A to be diagonal and initializes it from the diagonal version of HiPPO. This parameterization allows the model to remember state over the entire sequence. The shift SSM can detect when a particular event occurs, and the diagonal SSM can remember a token afterwards for the rest of the sequence.
Multiplicative Interaction
Multiplicative interactions (the elementwise multiplications in the H3 block) give H3 the ability to compare tokens across the sequence. Specifically, they allow the model to compare and combine information from the shift SSM and the diagonal SSM, similar to the similarity score (dot product) in vanilla attention.
Holistically, the outputs of the shift SSM can be viewed as local, short-term information, while the outputs of the diagonal SSM serve as long-term information. The multiplicative interaction lets the model combine both types of information to extract what is important. This is critical for tasks like associative recall: for example, if the model needs to recall the value associated with a key, it can compare the current key with all previously seen keys using these interactions.
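The following is a rough, single-channel sketch of the H3 computation pattern, Q ⊙ diag-SSM(shift-SSM(K) ⊙ V), written in plain NumPy. The dimensions, random initializations, and helper names are illustrative assumptions and do not reproduce the authors' implementation:
<syntaxhighlight lang="python">
import numpy as np

def shift_ssm(u, m=4):
    """Shift SSM over a scalar sequence: the state holds the last m inputs; output is C·x."""
    C = np.random.randn(m)
    x, ys = np.zeros(m), []
    for ut in u:
        x = np.roll(x, 1); x[0] = ut      # shift the state and insert the newest input
        ys.append(C @ x)
    return np.array(ys)

def diag_ssm(u, m=4):
    """Diagonal SSM over a scalar sequence: x_t = a ⊙ x_{t-1} + b u_t, y_t = C·x_t."""
    a = np.exp(-np.random.rand(m))        # stable decays in (0, 1)
    b, C = np.random.randn(m), np.random.randn(m)
    x, ys = np.zeros(m), []
    for ut in u:
        x = a * x + b * ut
        ys.append(C @ x)
    return np.array(ys)

def h3_channel(q, k, v):
    """One H3 channel: Q ⊙ diag-SSM( shift-SSM(K) ⊙ V )."""
    return q * diag_ssm(shift_ssm(k) * v)

T = 16
q, k, v = (np.random.randn(T) for _ in range(3))
print(h3_channel(q, k, v).shape)          # (16,)
</syntaxhighlight>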
Mamba
Mamba is a technique that builds on S4 models. It was introduced to increase efficiency on long sequences by leveraging a selective state-space mechanism, which saves memory and computational cost by not propagating irrelevant information. It does so by combining the H3 block traditionally used in SSM models with a gated multilayer perceptron (MLP), as shown in the figure below.

Recall that SSM models calculate the output using the following formulae: [math]\displaystyle{ h'(t) = \mathbf{A} h(t) + \mathbf{B} x(t) }[/math]
[math]\displaystyle{ y(t) = \mathbf{C} h(t) + \mathbf{D} x(t) }[/math]
These equations are discretized using an additional step-size parameter [math]\displaystyle{ \Delta }[/math], resulting in the updated formulae:
[math]\displaystyle{ h(t) = \bar{\mathbf{A}} h(t-1)+\bar{\mathbf{B}}x(t) }[/math]
[math]\displaystyle{ y(t) = \mathbf{C} h(t) }[/math]
where the matrices [math]\displaystyle{ {\mathbf{A}} }[/math] and [math]\displaystyle{ {\mathbf{B}} }[/math] are discretized as follows:
[math]\displaystyle{ \begin{aligned} & \overline{\mathbf{A}}=\operatorname{diag}\left(e^{-\Delta_t \omega}\right) \\ & \overline{\mathbf{B}}=\left(1-e^{-\Delta_t \omega}\right) \end{aligned} }[/math]
The matrices may be understood as follows:
- [math]\displaystyle{ \bar{\mathbf{A}} }[/math] represents the transition matrix between discrete hidden states
- [math]\displaystyle{ \bar{\mathbf{B}} }[/math] processes the input data in order to update the hidden state
- [math]\displaystyle{ \mathbf{C} }[/math] interprets the hidden state in order to generate the output.
Mamba introduces an input-dependent gating mechanism [math]\displaystyle{ \Delta_t }[/math], defined as:
[math]\displaystyle{ \Delta_t=\tau_{\Delta}\left(W_{\Delta} x_t+b_{\Delta}\right) }[/math]
where [math]\displaystyle{ \tau_{\Delta} }[/math] is typically a softplus or sigmoid activation function, and [math]\displaystyle{ W_{\Delta}, b_{\Delta} }[/math] are learned parameters. The discrete state update rule incorporating selectivity is:
[math]\displaystyle{ h_t=\left(\mathbf{I}-\Delta_t\right) h_{t-1}+\Delta_t x_t }[/math]
This gating function dynamically adjusts the timescale of information propagation, allowing the model to selectively retain or discard information based on the context provided by input data.
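A minimal sketch of this input-dependent gating is given below, assuming a sigmoid activation so that the convex update [math]\displaystyle{ h_t=(1-\Delta_t)h_{t-1}+\Delta_t x_t }[/math] stays bounded (a softplus could be passed instead, treating [math]\displaystyle{ \Delta_t }[/math] as a step size). All shapes and parameter names are illustrative:
<syntaxhighlight lang="python">
import numpy as np

def softplus(z): return np.log1p(np.exp(z))
def sigmoid(z):  return 1.0 / (1.0 + np.exp(-z))

def selective_scan(x, W_delta, b_delta, gate=sigmoid):
    """Input-dependent gated recurrence h_t = (1 - Δ_t) h_{t-1} + Δ_t x_t,
    with Δ_t = gate(W_Δ x_t + b_Δ) computed per time step and per channel."""
    T, d = x.shape
    h = np.zeros(d)
    hs = []
    for t in range(T):
        delta = gate(x[t] @ W_delta + b_delta)   # Δ_t depends on the current input
        h = (1.0 - delta) * h + delta * x[t]     # retain vs overwrite, per channel
        hs.append(h)
    return np.stack(hs)

T, d = 32, 8
x = np.random.randn(T, d)
W, b = 0.1 * np.random.randn(d, d), np.zeros(d)
print(selective_scan(x, W, b).shape)             # (32, 8)
</syntaxhighlight>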
Convolutional Kernel Representation
The Mamba model's dynamics can also be expressed through a convolutional kernel representation for computational efficiency, especially during training. Given an input sequence [math]\displaystyle{ x_{1: T} }[/math], the output [math]\displaystyle{ y_t }[/math] can equivalently be represented as:
[math]\displaystyle{ y_t=\sum_{k=1}^t \mathbf{C} \overline{\mathbf{A}}^{t-k} \overline{\mathbf{B}} x_k }[/math]
This representation highlights that Mamba's hidden states effectively implement a convolution over past inputs with a kernel shaped by the learned gating parameters and state transition matrices. This approach provides efficient parallel training similar to convolutional architectures while maintaining the ability to handle extremely long sequences with linear complexity during inference.
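As a quick numerical check of this equivalence for a fixed (non-selective) [math]\displaystyle{ \overline{\mathbf{A}}, \overline{\mathbf{B}} }[/math], the sketch below computes the output both by the recurrence and by the convolutional sum and confirms they match; dimensions are arbitrary:
<syntaxhighlight lang="python">
import numpy as np

d, T = 3, 10
A_bar = np.diag(np.exp(-np.random.rand(d)))   # fixed discrete transition matrix
B_bar = np.random.randn(d, 1)
C = np.random.randn(1, d)
x = np.random.randn(T)

# recurrent form: h_t = A_bar h_{t-1} + B_bar x_t, y_t = C h_t
h, y_rec = np.zeros((d, 1)), []
for t in range(T):
    h = A_bar @ h + B_bar * x[t]
    y_rec.append((C @ h).item())

# convolutional form: y_t = sum_k C A_bar^{t-k} B_bar x_k
y_conv = []
for t in range(T):
    acc = 0.0
    for k in range(t + 1):
        acc += (C @ np.linalg.matrix_power(A_bar, t - k) @ B_bar).item() * x[k]
    y_conv.append(acc)

assert np.allclose(y_rec, y_conv)
</syntaxhighlight>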
Advantages:
- Dynamic Timescales: Unlike traditional SSMs with fixed transition matrices, Mamba adjusts transition dynamics based on inputs.
- Improved Long-range Dependencies: Dynamic gating allows Mamba to selectively propagate crucial long-range information while ignoring irrelevant short-term fluctuations.
- Computational Efficiency: The convolutional representation significantly reduces computational overhead during parallel training phases.
Selective State Space Models in Mamba
Mamba builds on the success of structured state space models (SSMs) — a family of models that, by design, efficiently capture long-range dependencies through a recurrent or convolutional mechanism. In traditional SSMs, a key benefit is their linear-time scaling, achieved by compressing the entire sequence context into a fixed hidden state. However, this very efficiency can be a double-edged sword: by forcing the state to be constant over time (a property known as linear time invariance or LTI), the model may miss crucial content-specific cues in the data.
Mamba addresses this limitation with a clever twist—it introduces a selection mechanism that makes certain SSM parameters depend on the input. In simple terms, while standard SSMs update their hidden state with the same fixed rules at every time step, the selective variant modulates these update rules based on the current input. This allows the model to "decide" when to ignore irrelevant information and when to focus on key inputs, similar to how gating works in recurrent neural networks.
The mechanism works by turning parameters that usually remain constant (like the discretization parameter [math]\displaystyle{ \Delta }[/math] and others controlling the transition between hidden states) into functions of the input. This extra flexibility means that rather than applying a one-size-fits-all transformation to every token, the model tailors its processing dynamically across the sequence. The result is a network that not only maintains the efficiency of linear scaling but also improves accuracy by being content-aware—even for very long sequences.
An important innovation in Mamba is that it leverages a hardware-aware algorithm to keep the additional computational cost low. Instead of materializing huge intermediate states, the algorithm uses optimized memory strategies. In practice, this means that Mamba can run several times faster than comparable Transformer-based models, especially on longer inputs.
In summary, Mamba's selective state space models marry the efficiency of classical SSMs with the flexibility of input-dependent gating. This combination allows the architecture to handle extremely long sequences—up to millions of tokens—while still capturing the nuanced patterns essential for tasks ranging from language modeling to audio generation.
Mamba-2
Semiseparable Matrices
Semiseparable Matrices are matrices that have this form:
[math]\displaystyle{ \begin{bmatrix} C_0^\top B_0 & \\ C_1^\top A_1 B_0 & C_1^\top B_1 & \\ C_2^\top A_2A_1 B_0 & C_2^\top A_2 B_1 & C_2^\top B_2 \\ \vdots & \vdots & \ddots & \ddots \\ C_\mathtt{T}^\top A_{\mathtt{T}-1}\dots A_1 B_0 & C_\mathtt{T}^\top A_{\mathtt{T}-1}\dots A_2 B_1 & \dots & C_\mathtt{T}^\top A_{\mathtt{T}-1} B_{\mathtt{T}-2} & C_\mathtt{T}^\top B_{\mathtt{T}-1} \\ \end{bmatrix} }[/math]
By observation, we can see that all algorithms for computing state space models can be viewed as structured matrix multiplication algorithms on semiseparable matrices. Such matrices have the property that every submatrix contained in the lower-triangular portion is low rank, so we can write them as:
[math]\displaystyle{ \begin{bmatrix} C_j^T A_{j,i}^X B_{i'} & \cdots & C_j^T A_{j,i-1}^X B_{i-1} \\ \vdots & \ddots & \vdots \\ C_{j'-1}^T A_{j'-1,i}^X B_{i'} & \cdots & C_{j'-1}^T A_{j'-1,i-1}^X B_{i-1} \end{bmatrix} = \begin{bmatrix} C_j^T A_{j,j}^X \\ \vdots \\ C_{j'-1}^T A_{j'-1,j}^X \end{bmatrix} A_{j,i-1}^X \begin{bmatrix} A_{i-1,i'}^X B_{i'} & \cdots & A_{i-1,i-1}^X B_{i-1} \end{bmatrix}. }[/math]
This leads to an efficient algorithm for computing with semiseparable matrices.
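The following illustrative sketch builds the semiseparable matrix entrywise from per-step [math]\displaystyle{ A_t, B_t, C_t }[/math] and verifies that multiplying it by the input sequence reproduces the sequential SSM scan. Diagonal transition matrices and small dimensions are chosen only for simplicity:
<syntaxhighlight lang="python">
import numpy as np

T, n = 6, 4                                   # sequence length, state dimension
A = [np.diag(np.exp(-np.random.rand(n))) for _ in range(T)]   # per-step transitions
B = [np.random.randn(n) for _ in range(T)]
C = [np.random.randn(n) for _ in range(T)]
x = np.random.randn(T)

# build M with M[j, i] = C_j^T (A_j ... A_{i+1}) B_i for j >= i, 0 above the diagonal
M = np.zeros((T, T))
for j in range(T):
    for i in range(j + 1):
        prod = B[i].copy()
        for t in range(i + 1, j + 1):
            prod = A[t] @ prod
        M[j, i] = C[j] @ prod

# sequential SSM scan: h_t = A_t h_{t-1} + B_t x_t, y_t = C_t^T h_t
h, y_scan = np.zeros(n), []
for t in range(T):
    h = A[t] @ h + B[t] * x[t]
    y_scan.append(C[t] @ h)

assert np.allclose(M @ x, y_scan)             # matrix view == recurrence view
</syntaxhighlight>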

SSD Algorithm: "Block Matrix Decomposition" and "Chunking and State Passing"
- We first partition the semiseparable matrix into blocks of size [math]\displaystyle{ \mathtt{Q} \times \mathtt{Q} }[/math].
- Each diagonal block can be computed with the quadratic (attention-like) form, but since [math]\displaystyle{ \mathtt{Q} }[/math] is small, this is hardware efficient (orange in the figure).
- Then, we use the properties of semiseparable matrices to factorize each off-diagonal block, which is low rank (blue, green, and yellow in the figure).
- Following the workflow shown, we obtain an efficient algorithm that can even incorporate systems-level optimization by running blocks on different GPUs.
Structured Masked Attention
In RetNet (covered under Linear Attention below), a decay factor is used that can be seen as a mask on [math]\displaystyle{ \mathtt{Q}^T \mathtt{K} }[/math]. Inspired by Linear Attention and RetNet, we can generalize the idea of using a mask and introduce Structured Masked Attention (SMA). Definition: structured masked attention (SMA) (or structured attention for short) is defined as a function on queries/keys/values Q, K, V as well as any structured matrix L, through the 4-way tensor contraction:
[math]\displaystyle{ Y = (L \circ Q K^\top) V }[/math]
We can use a special mask called a 1-semiseparable matrix (1SS), defined as follows.
[math]\displaystyle{ L = \begin{bmatrix} 1 & \\ a_1 & 1 & \\ a_2a_1 & a_2 & 1 \\ \vdots & \vdots & \ddots & \ddots \\ a_{\mathtt{T}-1}\dots a_1 & a_{\mathtt{T}-1}\dots a_2 & \dots & a_{\mathtt{T}-1} & 1 \\ \end{bmatrix} . }[/math]
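A small sketch of constructing this 1-semiseparable mask from decay factors [math]\displaystyle{ a_t }[/math] and applying it in the SMA contraction [math]\displaystyle{ Y = (L \circ Q K^\top) V }[/math]; shapes and random inputs are illustrative:
<syntaxhighlight lang="python">
import numpy as np

def one_ss_mask(a):
    """1-semiseparable mask: L[j, i] = a_j * a_{j-1} * ... * a_{i+1} for j >= i
    (an empty product gives 1 on the diagonal), and 0 above the diagonal."""
    T = len(a) + 1
    L = np.zeros((T, T))
    for j in range(T):
        for i in range(j + 1):
            L[j, i] = np.prod(a[i:j])
    return L

T, d = 5, 3
a = np.random.rand(T - 1)                      # decay factors a_1, ..., a_{T-1}
Q, K, V = (np.random.randn(T, d) for _ in range(3))

L = one_ss_mask(a)
Y = (L * (Q @ K.T)) @ V                        # structured masked attention
print(Y.shape)                                 # (5, 3)
</syntaxhighlight>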

State Space Duality (SSD)
We can recognize that if we use 1SS as the mask for SMA, and if we restrict our SSM's [math]\displaystyle{ A }[/math] to a scalar-times-identity structure (that is, [math]\displaystyle{ A }[/math] is a scalar at a given time, so [math]\displaystyle{ A_t = a_t }[/math] at time [math]\displaystyle{ t }[/math]), the two are the same thing. This leads to State Space Duality, which reveals a surprising connection between SSMs and Transformer-based attention mechanisms. The State Space Duality (SSD) framework suggests that certain structured matrices used in SSMs can be reformulated in ways that resemble attention computations. By leveraging this relationship, researchers have proposed efficient algorithms that blend the best of both worlds, retaining the efficiency of SSMs while adopting optimization techniques from Transformers.
We can view the SSD as:
1. A state space model, expressed through a structured (semiseparable) matrix
2. A generalized linear attention, expressed through a structured mask (the duality between SSMs and structured masked attention (SMA) allows us to reinterpret SSMs as a form of linear attention with a specific masking strategy)
Overall, SSD is an instance with dual quadratic and linear forms, each of which can be derived from either representation.
Having this dual form allows us to borrow extensions from linear attention, such as multi-head attention, grouped-value attention, and kernel approximations to softmax attention. To name a few specific examples:
- Efficient algorithmic improvements: SSD provides novel ways to compute state transitions, allowing SSMs to leverage techniques such as grouped-value attention and kernelized approximations from the Transformer literature.
- Scalability benefits: Since SSD offers both quadratic and linear formulations, it enables selective adaptation depending on sequence length and computational constraints.
- Bridging architectures: Insights from SSD can guide the development of hybrid architectures that merge SSM-style recurrence with attention-like expressivity.

Other ways to explain it:
- An SSM can be regarded as a controlled dynamical system, where at each time step [math]\displaystyle{ t }[/math] the state [math]\displaystyle{ x_t }[/math] is influenced by [math]\displaystyle{ A }[/math], while [math]\displaystyle{ B }[/math] introduces the inputs and [math]\displaystyle{ C }[/math] computes the outputs.
- SMA uses a constrained attention pattern to disseminate information; because of the structure of the mask, information can only flow within a constrained range.
- SSD reveals that they are the same thing in different representations: the SSM passes information recurrently, whereas SMA propagates it through a structured mask, but the matrix [math]\displaystyle{ M }[/math] they compute is structurally equivalent.
Mamba-2 architecture
Using SSD and a few small changes to Mamba's neural network architecture, we obtain the Mamba-2 architecture.
One change is that we have Parallel Parameter Projections. In Mamba-2, the SSD layer is viewed as a map. It therefore makes sense to produce [math]\displaystyle{ A, B, C, X }[/math] in parallel with a single projection at the beginning of the block.
Another change is an extra normalization layer at the end of the block; it can be any normalization, such as LayerNorm, GroupNorm, or RMSNorm.
Mamba-2 usage and implementation
Some examples of Mamba-2 usage include:
1. Codestral Mamba (Mistral AI): Mistral AI’s open-source code generation model, built on Mamba-2. Competes with Code Llama and Codex, offering longer context (256k tokens) and lower latency.
2. NeMo LLM (NVIDIA): NVIDIA’s large-language model framework now supports Mamba-2 layers. They specifically highlight the use of a hybrid architecture with both Mamba-2 and self-attention layers.
3. Mamba-ND: A multi-dimensional version of Mamba-2 for handling images, video, and climate data. Used in ImageNet-1K classification, HMDB-51 action recognition, and ERA5 weather forecasting.
4. Mambaformer: A hybrid Transformer-Mamba model for time-series forecasting. Achieves better accuracy than both standalone Transformers and Mamba on long-range prediction tasks like stock markets and weather data.
5. Zigzag Mamba (ZigMa): A Mamba-powered diffusion model for image generation. Replaces attention in Diffusion Transformers (DiT), improving speed and memory efficiency for high-resolution image synthesis.
State Space Duality: State Space Models and Transformers
State Space Duality (SSD) formalizes the duality between State Space Models (SSMs) and Transformers by showing that both can be defined in terms of structured semiseparable matrices. This duality offers the potential to develop computationally efficient hybrid models such as Mamba-2, which preserve the expressiveness of Transformers while retaining the computational efficiency of SSMs.
A discrete-time State Space Model (SSM) can be written as:
[math]\displaystyle{ h_{t+1} = A h_t + B x_t }[/math]
[math]\displaystyle{ y_t = C h_t + D x_t }[/math]
where [math]\displaystyle{ h_t }[/math] is the hidden state at time [math]\displaystyle{ t }[/math], [math]\displaystyle{ A, B, C, D }[/math] are learnable state-space matrices, [math]\displaystyle{ x_t }[/math] is the input, and [math]\displaystyle{ y_t }[/math] is the output.
Unrolling this recurrence gives the convolution kernel
[math]\displaystyle{ K = (CB, CAB, CA^2B, \dots, CA^{L-1}B) }[/math]
which encodes how past inputs contribute to the current output.
SSD establishes that the Transformers and the SSMs share the same structure with respect to semiseparable matrices. A semiseparable matrix is
[math]\displaystyle{ M = D + U^T L }[/math]
where [math]\displaystyle{ D }[/math] is the diagonal matrix for local dependencies, [math]\displaystyle{ L }[/math] is the lower triangular matrix for sequential state transitions, and [math]\displaystyle{ U }[/math] is the upper triangular matrix for long-range interactions.
The self-attention mechanism of Transformers computes
[math]\displaystyle{ M_{att} = Q K^T }[/math]
where [math]\displaystyle{ Q, K }[/math] are the query and key matrices that yield explicit pairwise token interactions. In contrast, SSMs use an implicit recurrence-based formulation:
[math]\displaystyle{ M_{ssm} = C (I - A)^{-1} B }[/math]
where [math]\displaystyle{ A, B, C }[/math] define an implicit recurrence.
SSD offers an optimized block decomposition method for computational efficiency. It splits the sequence into small blocks: intra-chunk dependencies are handled by structured SSMs, while inter-chunk information is propagated by structured masked attention (SMA). This structured mechanism greatly reduces computational complexity: training complexity is [math]\displaystyle{ O(NT) }[/math] for SSMs compared to [math]\displaystyle{ O(N^2T) }[/math] for Transformers, and inference complexity is likewise reduced from [math]\displaystyle{ O(N^2T) }[/math] to [math]\displaystyle{ O(NT) }[/math].
A semiseparable matrix [math]\displaystyle{ M }[/math] of order [math]\displaystyle{ N }[/math] is defined as
[math]\displaystyle{ M_{ji} = C_j^T A_{j:i} B_i }[/math]
where [math]\displaystyle{ A }[/math] is a structured transition matrix that facilitates effective sequence transformations. In the case of 1-semiseparable structured attention, the recurrence is
[math]\displaystyle{ M_{ji} = a_{j:i} (C_j^T B_i) }[/math]
where [math]\displaystyle{ a_{j:i} }[/math] is a generalized positional dependency factor in structured decay.
Structured Masked Attention (SMA) generalizes linear attention with the inclusion of a structured masking matrix [math]\displaystyle{ L }[/math] as follows:
[math]\displaystyle{ L_{ij} = a_{i:j} }[/math]
where [math]\displaystyle{ a_{i:j} }[/math] are the learned decay factors that control the way information flows through the sequence. This makes SSD attain flexibility like attention but with efficiency like SSM.
One of the primary conclusions of SSD is that quadratic (attention-like) as well as linear (SSM-like) computations are enabled for semiseparable matrices.
[math]\displaystyle{ y = M x }[/math]
It can either be computed by a naive quadratic method where [math]\displaystyle{ M }[/math] is used as an explicit attention-like matrix, or an SSM recurrence method that computes [math]\displaystyle{ y }[/math] in [math]\displaystyle{ O(NT) }[/math] time. The dual formulation enables hardware-efficient implementations where structured decompositions are used to enhance computational speed with maintained long-range dependencies.
Mamba-2 exploits SSD by using parallel parameter projections to dynamically compute the parameters of the SSM. A gating mechanism selectively filters information, and SSD-based computation of state transitions enables efficient state updates. Expressivity is boosted through structured masked attention (SMA), bridging the gap between attention-based and state-space models. The resulting sequence transformation can be written as
[math]\displaystyle{ y = \sum_{j=0}^{L-1} W_j (A^j B x_{t-j}) }[/math]
where [math]\displaystyle{ W_j }[/math] are learnable weights derived from structured attention masks. SSD permits hybrid architectures that retain the efficiency of SSMs with Transformer-level expressiveness. Future directions include GPU/TPU acceleration, memory-efficient representations, and retrieval-augmented generation (RAG) for long-context applications. SSD has far-reaching implications for large-scale models that remain computationally tractable even with 100M-token context windows.
Comparative Analysis of SSM Variants
To provide a clearer understanding of the strengths and weaknesses of State Space Model (SSM) variants—S4, DSS, H3, Mamba, and Mamba-2—this section presents a comparative analysis based on their performance in language modeling tasks, computational efficiency, and scalability. This comparison draws from their architectural designs and empirical results, offering insights into their suitability for various applications.
Performance Metrics
- Perplexity on Language Modeling Tasks
- S4: Achieves competitive perplexity on datasets like WikiText-103 but struggles with extremely long sequences due to its structured parameterization.
- DSS: Matches S4’s perplexity with a simpler diagonal structure, performing well on large-scale datasets with reduced overhead.
- H3: Outperforms S4 and DSS on synthetic tasks like Induction Heads and Associative Recall, closing the expressivity gap with Transformers.
- Mamba: Excels with very low perplexity on long-context tasks, leveraging selective state spaces to rival Transformer performance.
- Mamba-2: Achieves the lowest perplexity among SSMs by enhancing state dimensions and optimization, surpassing Mamba on benchmarks like OpenWebText.
- Computational Efficiency
- S4: Uses Fast Fourier Transform (FFT) for [math]\displaystyle{ O(N \log N) }[/math] training complexity, though its implementation is intricate.
- DSS: Simplifies to [math]\displaystyle{ O(N) }[/math] training and [math]\displaystyle{ O(1) }[/math] inference with a diagonal state matrix, enhancing efficiency.
- H3: Combines two SSMs, resulting in [math]\displaystyle{ O(d^2 N + d N \log N) }[/math] complexity—more expressive but less efficient than others.
- Mamba: Achieves [math]\displaystyle{ O(N) }[/math] training and [math]\displaystyle{ O(1) }[/math] inference with hardware-aware selective mechanisms, optimizing for long sequences.
- Mamba-2: Maintains Mamba’s complexity while improving performance through state space duality and parallelism.
- Scalability
- S4: Scales effectively for moderate sequence lengths but faces challenges with very long contexts.
- DSS: Excels in scalability for large datasets and real-time systems due to its simplicity.
- H3: Limited by higher computational costs, making it less scalable for large models.
- Mamba: Designed for long sequences, scales efficiently to millions of tokens, ideal for extensive context tasks.
- Mamba-2: Enhances scalability with tensor and sequence parallelism, enabling faster training and inference.
Where to Use Each Model
- S4: Best for continuous data tasks (e.g., time series, audio) where moderate sequence lengths are sufficient.
- DSS: Ideal for large-scale, real-time applications needing efficiency and interpretability (e.g., time-series forecasting).
- H3: Suited for language modeling tasks requiring Transformer-like expressivity with moderate sequence lengths.
- Mamba: Optimal for long-sequence tasks (e.g., document modeling, audio generation) needing high efficiency.
- Mamba-2: The go-to choice for large-scale, long-context modeling with enhanced training speed and performance.
Empirical Scaling and Performance Trends
Recent experiments have demonstrated that state space models scale remarkably well as their capacity increases. For instance, when scaled to billions of parameters, models such as Mamba and Mamba-2 achieve language modeling perplexities that rival or even surpass those of similarly sized Transformers. Empirical studies on benchmarks like OpenWebText and the Long Range Arena indicate that as the state dimension grows and longer contexts are leveraged, SSM-based architectures not only maintain their linear computational scaling but also benefit from improved performance. These scaling trends suggest that, with proper architectural tuning and parameter initialization, SSMs are emerging as a viable alternative for large-scale, long-context sequence modeling.
Summary & Key Takeaways
State Space Models (SSMs) offer a scalable alternative to Transformers in language modeling by using state space equations to manage and update hidden states for generating outputs from inputs. Since their inception, there have been several important model enhancements. The following table summarizes their development, strengths, weaknesses, and time complexity:
Year | Model | Contribution | Strengths | Weaknesses | Complexity |
---|---|---|---|---|---|
2022 | Structured State Space (S4) | Leveraged Diagonal Plus Low-Rank parameterization to improve computational efficiency | Captures long-range dependencies efficiently | Complex architecture and implementation | [math]\displaystyle{ O(N log N) }[/math] with [math]\displaystyle{ FFT }[/math] |
2022 | Diagonal State Spaces (DSS) | Simplified S4 by using diagonal matrices to achieve comparable performance | Big data and real-time systems scalability | Performance depends on initialization and lacks capacity to handle information-dense data | For batch size [math]\displaystyle{ B }[/math], sequence length [math]\displaystyle{ L }[/math], and hidden size [math]\displaystyle{ H }[/math]: DSS layer requires [math]\displaystyle{ O(NHL) }[/math] time to compute kernels, [math]\displaystyle{ O(BHL logL) }[/math] time for discrete convolution and [math]\displaystyle{ O(BH^2L) }[/math] time for output projection. |
2023 | Hungry Hungry Hippos (H3) | Introduced new SSM layer consisting of two SSMs stacked together (a shift and diagonal SSM) to enable token recall and comparison across sequences | Transformer competitive performance on language modeling tasks | Less computationally efficient than other SSMs | For sequence length [math]\displaystyle{ N }[/math], and [math]\displaystyle{ d }[/math] hidden dimension, then [math]\displaystyle{ H3 }[/math] layer takes [math]\displaystyle{ O(d^2N + dN logN) }[/math] |
2024 | Mamba | Integrated selective state spaces to filter irrelevant information while being hardware-aware algorithm | Handles long sequences at faster inference and less memory requirements than Transformer | Scaling to compete with LLM requires complex engineering | Training is [math]\displaystyle{ O(N) }[/math], while inference is [math]\displaystyle{ O(1) }[/math] |
2024 | Mamba-2 | Used state space duality (SSD) to enable larger state dimensions | Tensor and sequence parallelism which allows for faster training and inference | Lack of inference optimization techniques like quantization and speculative decoding | Training is [math]\displaystyle{ O(N) }[/math], while inference is [math]\displaystyle{ O(1) }[/math] |
Future of SSMs
Ongoing research is focused on improving efficiency, adaptability, and scalability to position State Space Models (SSMs) as a viable alternative to Transformers. While SSMs excel at modeling long-range dependencies, they face hardware inefficiencies and an expressivity gap in tasks like language modeling. Key areas of research include:
1. Better hardware utilization: Optimizing SSM implementations to leverage GPU/TPU acceleration as efficiently as Transformers.
- Enhancing GPU/TPU acceleration through optimized memory access and tensor core utilization.
- Applying kernel fusion techniques to minimize redundant computations.
2. Adaptive SSMs: SSMs currently lack the expressivity of Transformers, particularly for tasks requiring complex reasoning. Developing architectures that can dynamically switch between SSM and attention-based processing depending on the task.
- Structured Masked Attention (SMA) for improved context retention and expressivity.
- Selective Copying & Induction Heads to enhance memory retention.
- State-passing algorithms to better preserve sequence dependencies.
3. Scaling laws for SSMs: SSMs scale efficiently for long-sequence tasks but require further study to match Transformers at large parameter counts. Understanding how these models perform at increasing levels of parameterization compared to standard deep learning architectures.
- Understanding scaling laws to optimize model depth and expressivity.
- Evaluating trade-offs between efficiency and generalization across tasks.
- Expanding SSM applications, particularly in speech, DNA, and document retrieval.
4. Hybrid Architectures with Dynamic Routing:
- Exploring deeper integrations between SSMs and Transformers, potentially using dynamic gating mechanisms or context-based routing, allowing models to dynamically switch processing modes depending on input characteristics or computational constraints.
5. Automated Hyperparameter and Architecture Search:
- Given the sensitivity of SSM performance to hyperparameter choices, research into automated hyperparameter optimization (such as Bayesian optimization or neural architecture search) could greatly reduce manual tuning, making SSMs more broadly applicable in practice.
6. Improved Visualization and Interpretability Techniques:
- Although inherently interpretable, visualization and diagnostic tools tailored specifically to SSM hidden states and transitions remain limited. Developing intuitive visualization tools would significantly improve the practical interpretability of SSM models, especially in high-dimensional or complex data scenarios.
By addressing these key areas, SSMs hold substantial promise for becoming the foundation of next-generation sequence modeling, especially where computational efficiency, adaptability, and interpretability are crucial, such as speech recognition, biological sequence modeling, and real-time inference systems.
Topic 8: Sparse Attention
Introduction
Vanilla Attention is very computationally expensive due to the multiplication of very large matrices required. This results in a complexity of [math]\displaystyle{ O(n^2) }[/math] where n may be very large, potentially even limiting the scalability of this method for long sequences. Sparse attention tries to address this problem by being more selective about what attention is computed. Intuitively, not all tokens need to attend to each other because not all words provide semantic information about all other words in a sentence. How can we determine which token-pairs are important? The goal of Sparse Attention is to answer this question in order to improve the efficiency and scalability of vanilla attention by not computing attention between all tokens without sacrificing performance.
Four methods were presented:
- Sparse Sinkhorn Attention
- Big Bird Sparse Attention
- Attention with Linear Biases (ALiBi)
- SpAtten
Sparse Sinkhorn Attention
Simplified Intuition
Imagine a crowded room where everyone is trying to talk at once—confusing and inefficient, right? That’s what happens in traditional attention: every token interacts with every other token, like a chaotic party where no one can really focus on the conversation. Sparse Sinkhorn Attention takes a different approach. It first acts like a clever host who quickly sorts the guests into smaller groups based on who’s most likely to have something in common. This sorting, done by a neat algorithm (the Sinkhorn normalization), gently reorders the tokens so that similar ideas end up together, allowing meaningful conversations in manageable clusters.
Once the guests are grouped, each token only needs to chat with its immediate neighbors, much like small, focused discussion groups. Even though tokens now converse locally, the smart grouping ensures that distant yet related ideas are brought close together. This not only cuts down on the chaos and computational cost but also preserves the essential flow of information—much like having productive, focused conversations instead of a noisy, all-out babble.
Overview
To address the computational intensity of Vanilla Attention, especially for long input sequences, Sparse Sinkhorn Attention proposes three core ideas:
- [math]\displaystyle{ \mathbf{Block} }[/math] [math]\displaystyle{ \mathbf{Partitioning} }[/math]
- The input sequence with length [math]\displaystyle{ l }[/math] is split into [math]\displaystyle{ N_b }[/math] blocks, each containing [math]\displaystyle{ b }[/math] tokens.
- [math]\displaystyle{ \mathbf{Block} }[/math] [math]\displaystyle{ \mathbf{Sorting} }[/math]
- Input blocks are sorted based on their similarity to the query.
- The Sinkhorn algorithm is then applied to generate a permutation matrix which ensures that relevant blocks are adjacent, and this algorithm also normalizes the sorting matrix to prevent overemphasis on certain input blocks.
- [math]\displaystyle{ \mathbf{Sparse\ (local)} }[/math] [math]\displaystyle{ \mathbf{Attention} }[/math]
- Instead of attending to the full sequence, each token only computes attention with tokens within its block (and with tokens in its neighbouring blocks).

Sorting Network
The original idea of block-based local attention is to allow tokens to attend only to tokens within the same block. However, this restricts the global attention field and limits the ability of local attention models to capture long-range dependencies. To mitigate this problem, the blocks are sorted neurally: a Meta Sorting Network learns to sort sequences, producing a permutation matrix that enables efficient quasi-global local attention. The equation of the Sorting Network is:
[math]\displaystyle{ R = P(x) = \sigma(W_B \sigma(W_pX+b_p) +b_B) }[/math]
Here, [math]\displaystyle{ R }[/math] becomes the permutation matrix, [math]\displaystyle{ \sigma }[/math] is an activation function and [math]\displaystyle{ W }[/math], [math]\displaystyle{ b }[/math] are the weights and biases. We can see that this is just a two-layer fully-connected neural network.
Sinkhorn normalization with Gumbel noise
We normalize the rows and columns of the sorting matrix using the formulas:
[math]\displaystyle{ S^0(R) = \exp\left(\frac{R+\epsilon}{\tau}\right) }[/math]
[math]\displaystyle{ S^k(R) = F_c(F_r(S^{k-1}(R))) }[/math]
where [math]\displaystyle{ F_r }[/math], [math]\displaystyle{ F_c }[/math] are the row and column wise normalization function, [math]\displaystyle{ \epsilon }[/math] is the injected standard i.i.d Gumbel noise and [math]\displaystyle{ \tau }[/math] is the temperature hyper-parameter.
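Below is an illustrative NumPy sketch of the sorting network followed by Gumbel-Sinkhorn normalization in log space. The ReLU activations, the number of iterations, and all shapes are assumptions made for the sake of a runnable example, not the paper's exact configuration:
<syntaxhighlight lang="python">
import numpy as np

def gumbel_noise(shape, eps=1e-20):
    u = np.random.rand(*shape)
    return -np.log(-np.log(u + eps) + eps)

def sinkhorn(R, n_iters=10, tau=1.0):
    """Gumbel-Sinkhorn normalization: add Gumbel noise, divide by the temperature,
    then alternately normalize rows and columns in log space."""
    log_S = (R + gumbel_noise(R.shape)) / tau
    for _ in range(n_iters):
        log_S = log_S - np.logaddexp.reduce(log_S, axis=1, keepdims=True)  # row normalize
        log_S = log_S - np.logaddexp.reduce(log_S, axis=0, keepdims=True)  # column normalize
    return np.exp(log_S)                       # approximately doubly stochastic

def sorting_network(X_blocks, W_p, b_p, W_B, b_B):
    """Two-layer meta sorting network: R = sigma(W_B sigma(W_p X + b_p) + b_B)."""
    relu = lambda z: np.maximum(z, 0.0)
    H = relu(X_blocks @ W_p + b_p)
    return relu(H @ W_B + b_B)

Nb, d, h = 4, 8, 16                            # number of blocks, block feature dim, hidden dim
X_blocks = np.random.randn(Nb, d)              # one pooled representation per block
W_p, b_p = np.random.randn(d, h), np.zeros(h)
W_B, b_B = np.random.randn(h, Nb), np.zeros(Nb)

P = sinkhorn(sorting_network(X_blocks, W_p, b_p, W_B, b_B))
print(P.sum(axis=0), P.sum(axis=1))            # rows and columns each sum to ~1
</syntaxhighlight>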
Advantages over Vanilla Attention
- Reduced memory
- Sparse Sinkhorn Attention only needs [math]\displaystyle{ O\left(\left(\frac{l}{N_b}\right)^2 + N_b^2\right) }[/math], while Vanilla Attention requires [math]\displaystyle{ O(l^2) }[/math].
- Reduced Computational Complexity
- SORTCUT (a Sparse Sinkhorn Attention variant) only requires [math]\displaystyle{ O(lN_k) }[/math], where [math]\displaystyle{ N_k }[/math] is a hyperparameter which is much smaller than [math]\displaystyle{ l }[/math]. In contrast, Vanilla Attention requires [math]\displaystyle{ O(l^2) }[/math] time.
- Ability to Capture Long Range Dependency
- Through block sorting, relevant tokens are grouped adjacently even if they are originally far apart, and local attention then captures these long-range dependencies. By comparison, Vanilla Attention struggles with this: it computes attention between every token and all other tokens, and softmax entropy collapse naturally reduces the importance of distant tokens, making it difficult to model long-range relationships both computationally and statistically.
Big Bird Sparse Attention
Intuitively, we picture the tokens as nodes in a directed graph where attention is calculated between two tokens if an edge exists to connect their nodes. Then, using the adjacency matrix of this graph, attention is not calculated between disconnected nodes.
In order to decide which nodes to 'connect,' the authors combined several methods:
- Random Attention
- Nodes are connected randomly (whether attention is calculated for a particular token pair is determined randomly)
- Window Attention
- Intuitively, tokens that are closer together probably provide more semantic information about each other. Therefore, attention is calculated between tokens that occur within a particular distance from each other.
- Global Attention
- Some tokens are 'celebrities' so attention between them and all other tokens is calculated. This results in a maximum distance of 2 between any two tokens (nodes in the graph)
Theoretically, the authors prove that a star graph (the graph corresponding to global attention) already provides a lower bound on the expressive power of sparse attention. Empirically, combining all three of the above methods into what the authors call 'BIGBIRD' results in higher model performance.
We may interpret the naming as:
- "Big": the ability to handle long sequences and overcome higher-order complexity
- "Bird": a metaphor for a bird's-eye view, combining global tokens (the ability to observe the whole sequence), local sliding windows (the ability to focus on local context), and random attention (the ability to account for details that would otherwise be left out)
Limitations
Although BigBird achieves linear complexity, it does not come for free.
- Lower Bounds
- The authors exhibit a natural task that can be solved by an O(1)-layer full attention mechanism. However, any sparse attention mechanism with polynomially many inner-product evaluations and a shortest path of logarithmic length between any two nodes of its digraph (thus more than just BigBird) needs at least polynomially many layers to solve the same task.
Attention with Linear Biases (ALiBi)
Currently, models struggle to extrapolate. Extrapolation is the ability to produce sequence lengths at inference time that are longer than the sequences they were trained on. For example, if a model is trained on a dataset where the longest sequence is 1024 tokens, then it will perform poorly when asked to generate a sequence of 2048 tokens. If we can solve this problem then, in theory, training can become much more efficient as we can train on shorter sequences without sacrificing performance.
The authors hypothesized that the position representation method used in these models influenced extrapolation ability. The authors first tested the extrapolation performance of transformers that use absolute position encodings. This includes sinusoidal position embeddings (as seen in vanilla transformer) and rotary position embeddings (seen in open source GPT-3). The authors found that these transformer models did not extrapolate well.
Then, they tested the extrapolation performance of a transformer that employed the T5 bias as its position method. This is a relative position method: it does not add positional information to the word embeddings. Instead, after computing the query-key score in attention and before applying the softmax, it adds a learned, shared bias that depends on the distance between the query and key. This model produced improved extrapolation performance, and the authors concluded that a relative position method improves extrapolation. However, due to its high computation cost, the T5 bias is infeasible in practice.
ALiBi replaces the vanilla transformer [math]\displaystyle{ softmax(QK^T) }[/math] with [math]\displaystyle{ softmax(QK^T + bias) }[/math], adding a lower triangular matrix of biases in order to encode the relative positioning of tokens within the attention calculation.
The bias matrix is of the form [math]\displaystyle{ \begin{bmatrix} 0 & 0 & 0 & 0\\ -1 & 0 & 0 & 0\\ -2 & -1 & 0 & 0\\ -3 & -2 & -1 & 0\\ \end{bmatrix} }[/math]
The addition of this bias leads to a recency bias. A query-key pair that are far apart will have a large bias. The increasing penalty results in less information learned from distant query-key pairs.
The authors hypothesized that this could replace the positional encoding in transformers. For architectures with multiple attention heads, the slope of the bias varies per head (defaulting to a geometric sequence, with head [math]\displaystyle{ m }[/math] using slope [math]\displaystyle{ \frac{1}{2^m} }[/math]).
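The sketch below builds per-head ALiBi bias matrices (using the simple [math]\displaystyle{ \frac{1}{2^m} }[/math] slope pattern described above; the paper's slope schedule is slightly more general) and adds them to causal attention scores. All dimensions are illustrative:
<syntaxhighlight lang="python">
import numpy as np

def alibi_bias(seq_len, n_heads):
    """Per-head ALiBi bias: a lower-triangular distance penalty scaled by a
    head-specific slope from the geometric sequence 1/2, 1/4, ..., 1/2^n_heads."""
    pos = np.arange(seq_len)
    dist = pos[None, :] - pos[:, None]          # j - i (negative for past tokens)
    dist = np.minimum(dist, 0)                  # only penalize looking backwards
    slopes = np.array([2.0 ** -(h + 1) for h in range(n_heads)])
    return slopes[:, None, None] * dist[None, :, :]   # (n_heads, seq_len, seq_len)

def alibi_attention(Q, K, V, bias):
    scores = Q @ K.T / np.sqrt(Q.shape[-1]) + bias
    scores = np.where(np.tril(np.ones_like(scores)) == 1, scores, -np.inf)  # causal mask
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V

T, d, H = 5, 4, 2
bias = alibi_bias(T, H)
Q, K, V = (np.random.randn(T, d) for _ in range(3))
out = alibi_attention(Q, K, V, bias[0])         # attention for head 0
print(out.shape)                                # (5, 4)
</syntaxhighlight>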
So even though ALiBi is considered a sparse attention method, it does not enforce sparsity explicitly: as tokens get farther apart, the bias penalty grows, so the attention paid to distant tokens becomes negligible and sparsity emerges implicitly.
SpAtten
Introduction
A major limitation of LLMs is computational cost, driven largely by long input sequences and the large KV cache. The authors show that much of the overall latency in attention comes from computing the [math]\displaystyle{ Q\times K }[/math] matrix, which is [math]\displaystyle{ n\times n }[/math] where [math]\displaystyle{ n }[/math] is the sequence length. SpAtten reduces this cost by shrinking the sequence length [math]\displaystyle{ n }[/math]: tokens (words in a sequence) are removed during attention to reduce computation.
Sparse Attention
SpAtten is a method for pruning both tokens and attention heads in transformer models to improve efficiency. By removing tokens and heads with pruning, SpAtten reduces computational complexity and memory usage.
There are four key components to SpAtten:
1. Cascade Token Pruning
2. Cascade Head Pruning
3. Progressive Quantization
4. Specialized High Parallelism top-k Engine.
The first two are motivated by the inherent redundancy in human languages. Cascade token pruning identifies and removes unimportant tokens in a sentence, so that only semantically important tokens remain. For example, in English, syntactic words like prepositions and articles would be considered unimportant and thus pruned. The reduction in token sequence length reduces the computational cost of executing attention.
Similarly, cascade head pruning eliminates unnecessary attention heads. Recall that in multi-head attention, one executes attention with multiple heads to capture various token dependency relationships. However, some heads may be redundant to the output. These heads would be pruned.
Unlike traditional weight pruning techniques, cascade token and cascade head pruning operate dynamically, selecting tokens and heads to prune on the fly during inference. As an example, given an initial sentence of "As a visual treat, the film is almost perfect." with 11 tokens and 12 heads, cascade pruning results in a final output of "film perfect" with only 2 tokens and 8 heads. The final output retains the essence of the initial sentence but uses less memory and is less computationally expensive to process.
The latter two components play a pivotal role in ensuring the efficient implementation of SpAtten. Progressive quantization is designed to strike a balance between computational efficiency and accuracy by dynamically adjusting the precision of computations. The process begins by handling only the most significant bits (MSBs) of the input data, enabling faster computations and reduced memory access. If the confidence in the result is insufficient, the system progressively incorporates more bits of data to recompute the output with higher precision. This approach is further refined by assigning different bitwidths to different layers, optimizing performance across the model.
Similarly, SpAtten employs a specially designed top-k engine to enable real-time decision-making for pruning tokens and attention heads. This engine efficiently ranks token and head importance scores in linear time, avoiding the inefficiencies of traditional sorting methods. Instead of sorting an entire array, the top-k engine uses a quick-select module to identify the kth largest element and filters the array based on this threshold. This technique is also implemented in a highly parallelized manner, with multiple comparators working simultaneously to achieve linear-time ranking. This streamlined process ensures rapid and accurate pruning decisions, contributing to SpAtten's overall efficiency.
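As a rough illustration of cascade token pruning, the sketch below scores each token by the cumulative attention it receives and keeps only the top-k tokens, with NumPy's argpartition standing in for SpAtten's linear-time top-k engine; the importance measure and keep ratio are simplified assumptions:
<syntaxhighlight lang="python">
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def token_importance(Q, K):
    """Cumulative attention each token receives, summed over all queries:
    a proxy for how much the rest of the sequence relies on that token."""
    probs = softmax(Q @ K.T / np.sqrt(Q.shape[-1]))
    return probs.sum(axis=0)

def prune_tokens(tokens, Q, K, keep_ratio=0.5):
    """Keep the top-k most important tokens (k = keep_ratio * n), in original order."""
    n = len(tokens)
    k = max(1, int(n * keep_ratio))
    scores = token_importance(Q, K)
    keep = np.sort(np.argpartition(scores, -k)[-k:])   # linear-time top-k selection
    return [tokens[i] for i in keep], keep

tokens = "as a visual treat the film is almost perfect".split()
d = 8
Q, K = np.random.randn(len(tokens), d), np.random.randn(len(tokens), d)
print(prune_tokens(tokens, Q, K, keep_ratio=0.3))
</syntaxhighlight>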
Comparison
Method | Key Focus | How It Works | Primary Benefit | Primary Trade-off |
---|---|---|---|---|
Sparse Sinkhorn | Global interactions via learned sorting | Rearranges tokens to place important ones together in local blocks for efficient attention | Retains global attention while reducing compute | Sorting overhead adds computation; complex implementation |
BigBird | Local context with selective global/random links | Uses a fixed sparse pattern (local window + some global/random tokens) to approximate full attention | Scales to long sequences without needing per-sequence tuning | Fixed sparsity means some interactions might be missed |
ALiBi | Recency bias via linear attention penalty, enables extrapolation to longer sequence lengths | Applies a penalty to distant tokens in attention scores, encouraging focus on closer tokens | Extrapolates well to longer sequences without increasing training costs | Does not reduce compute/memory at inference; only helps training |
SpAtten | Dynamic token and head pruning for efficiency | Dynamically removes unimportant tokens and attention heads at inference for efficiency | Maximizes efficiency while preserving accuracy | Requires custom logic/hardware for full benefits |
Method | Memory Efficiency | Computational Speed | Accuracy vs Full Attention | Real-World Adoption |
---|---|---|---|---|
Sparse Sinkhorn | High (learned local attention reduces need for full attention) | Moderate (sorting overhead but reduced attention computation) | Near-parity (sometimes better, as learned sparsity is adaptive) | Limited (mainly research, complex to implement) |
BigBird | Very High (reduces O(n²) to O(n) via sparse pattern) | Fast (sparse pattern significantly reduces compute) | Near-parity (captures long-range dependencies well) | High (used in NLP for long-text tasks, available in HuggingFace) |
ALiBi | Same as full attention (no sparsity, just biasing) but can use smaller context length in training | Same as full attention (no change in complexity) but can use smaller context length in training | Same (performs as well as full attention with added extrapolation) | Very High (adopted in major LLMs like BLOOM, MPT) |
SpAtten | Extremely High (prunes tokens/heads dynamically) | Extremely Fast (up to 162× faster on specialized hardware) | Same (no accuracy loss, just efficiency gain) | Limited (research and specialized hardware, not common in open-source models) |
Topic 10: Linear Attention
Introduction
Linear attention tries to address the efficiency limitations of traditional softmax attention. Standard attention (also called vanilla attention) is calculated as:
[math]\displaystyle{ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V }[/math]
The computational complexity of this method is [math]\displaystyle{ O(n^2d) }[/math] for sequence length [math]\displaystyle{ n }[/math] and representation dimension [math]\displaystyle{ d }[/math]; the main reason is the multiplication of [math]\displaystyle{ Q }[/math] and [math]\displaystyle{ K^T }[/math], which produces a large [math]\displaystyle{ n \times n }[/math] attention matrix. If we remove the softmax function from this equation, then we can reduce the computational cost because the matrix multiplication can be reordered to [math]\displaystyle{ Q(K^T V) }[/math]. Multiplying [math]\displaystyle{ K^T V }[/math] first results in smaller matrix multiplications. However, this plain linear mapping is insufficient to capture complex relationships in the data. To reintroduce more expressiveness, we can replace the softmax function with a kernel that approximates it. Recall that for any kernel [math]\displaystyle{ K }[/math], [math]\displaystyle{ K(x, y)= \Phi(x)^T \Phi(y) }[/math] for some feature map [math]\displaystyle{ \Phi }[/math]. If we apply this feature map to [math]\displaystyle{ Q }[/math] and [math]\displaystyle{ K }[/math], then we can approximate vanilla attention with a much more efficient mechanism.
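A minimal, non-causal sketch of this reordering with an ELU+1 feature map standing in for [math]\displaystyle{ \Phi }[/math] (one common choice in the linear attention literature); it never forms the [math]\displaystyle{ n \times n }[/math] matrix, so the cost is [math]\displaystyle{ O(nd^2) }[/math] rather than [math]\displaystyle{ O(n^2d) }[/math]:
<syntaxhighlight lang="python">
import numpy as np

def feature_map(x):
    """A simple positive feature map (ELU + 1), a common stand-in for Φ."""
    return np.where(x > 0, x + 1.0, np.exp(x))

def linear_attention(Q, K, V, eps=1e-6):
    """Compute attention as Φ(Q) (Φ(K)^T V), avoiding the n x n matrix."""
    Qf, Kf = feature_map(Q), feature_map(K)
    KV = Kf.T @ V                                # (d, d_v): summarizes keys and values
    Z = Qf @ Kf.sum(axis=0)                      # per-query normalizer
    return (Qf @ KV) / (Z[:, None] + eps)

n, d = 1024, 64
Q, K, V = (np.random.randn(n, d) for _ in range(3))
print(linear_attention(Q, K, V).shape)           # (1024, 64)
</syntaxhighlight>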
Linear attention is particularly useful for large-scale models because it allows them to handle long sequences efficiently. Unlike traditional Transformers, which require large amounts of memory for key-value caching, linear attention methods enable faster inference with lower resource usage. This makes them well-suited for applications where memory and compute power are limited, such as real-time processing.
Key Approaches to Linear Attention
Retentive Network (RetNet): A Successor to Transformer for Large Language Models

The authors of this paper introduce a new architecture that achieves parallel training and fast inference without compromising performance, addressing the inherent complexity of the vanilla attention mechanism. They claim that the new architecture makes the "impossible triangle" (Figure 1) possible. RetNet combines the Transformer's performance and parallel training with the low-cost inference of RNNs by introducing a multi-scale retention mechanism to replace multi-head attention, which incorporates three main representations:
- Parallel: allows GPU utilization during training.
- Recurrent: enables low-cost inference with [math]\displaystyle{ O(1) }[/math] complexity in terms of computation and memory.
- Chunkwise recurrent: models long-sequences efficiently.
Essentially, RetNet replaces the [math]\displaystyle{ softmax }[/math] by adopting a linear representation of the attention mechanism in the form of retention.
Recurrent Representation
We can show that the retention mechanism has a dual form of recurrence and parallelism (Figure 2). First, let's show RetNet in its recurrent form:

[math]\displaystyle{ S_n = \gamma S_{n-1} + K_n^T V_n }[/math]
[math]\displaystyle{ \text{Retention}(X_n) = Q_n S_n }[/math]
In this formula, [math]\displaystyle{ \gamma }[/math] is a decay factor, and [math]\displaystyle{ Q }[/math], [math]\displaystyle{ K }[/math], [math]\displaystyle{ V }[/math] are learned projections of the input. During training, this can be computed in parallel form.
Parallel Representation
Now, we can show the parallel representation of retention:
[math]\displaystyle{ \text{Retention}(X) = (QK^T \odot D)V }[/math]
Where [math]\displaystyle{ D }[/math] is a matrix that combines causal masking and exponential decay along relative distance. For better understanding the role of matrix [math]\displaystyle{ D }[/math], let's take a look at its formula:
[math]\displaystyle{ D_{nm} = \begin{cases} \gamma^{n-m}, & n \geq m \\ 0, & n \lt m \end{cases} }[/math]
This formulation ensures that tokens can only attend to previous tokens (which the authors call causal masking) and that the attention strength decays exponentially with distance (controlled by the parameter [math]\displaystyle{ \gamma }[/math]). In other words, as we move forward, less attention is paid to earlier tokens.
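The sketch below implements both the parallel form [math]\displaystyle{ (QK^\top \odot D)V }[/math] and the recurrent form [math]\displaystyle{ S_n = \gamma S_{n-1} + K_n^\top V_n }[/math], [math]\displaystyle{ \text{Retention}(X_n) = Q_n S_n }[/math] for a single head and checks that they agree; the shapes, [math]\displaystyle{ \gamma }[/math], and inputs are arbitrary:
<syntaxhighlight lang="python">
import numpy as np

def decay_matrix(T, gamma):
    """D[n, m] = gamma^(n - m) for n >= m, 0 otherwise (causal mask + exponential decay)."""
    n, m = np.meshgrid(np.arange(T), np.arange(T), indexing="ij")
    return np.where(n >= m, gamma ** (n - m), 0.0)

def retention_parallel(Q, K, V, gamma):
    D = decay_matrix(Q.shape[0], gamma)
    return (Q @ K.T * D) @ V

def retention_recurrent(Q, K, V, gamma):
    S = np.zeros((K.shape[1], V.shape[1]))
    outs = []
    for n in range(Q.shape[0]):
        S = gamma * S + np.outer(K[n], V[n])     # S_n = gamma S_{n-1} + K_n^T V_n
        outs.append(Q[n] @ S)                    # Retention(X_n) = Q_n S_n
    return np.stack(outs)

T, d, gamma = 6, 4, 0.9
Q, K, V = (np.random.randn(T, d) for _ in range(3))
assert np.allclose(retention_parallel(Q, K, V, gamma), retention_recurrent(Q, K, V, gamma))
</syntaxhighlight>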
Chunkwise Recurrent Representation
The third component, the chunkwise recurrent representation, is a hybrid of the recurrent and parallel representations aimed at accelerating training on long sequences. The input sequence is divided into chunks; within each chunk, the parallel representation is used to process all tokens simultaneously. To pass information across chunks, the recurrent representation creates a summary of each chunk and passes it to the next. The combination of inner-chunk (parallel) and cross-chunk (recurrent) information produces the retention of the chunk, which captures both the details within the chunk and the context from previous chunks. To compute the retention of the i-th chunk, let [math]\displaystyle{ B }[/math] be the chunk length, then:
[math]\displaystyle{ Q_{[i]} = Q_{Bi:B(i+1)}, K_{[i]} = K_{Bi:B(i+1)}, V_{[i]} = V_{Bi:B(i+1)} }[/math]
[math]\displaystyle{ R_i = K_{[i]}^T (V_{[i]} \odot \zeta) + \gamma^B R_{i-1}, \quad \zeta_{ij} = \gamma^{B-i-1} }[/math]
[math]\displaystyle{ {Retention}(X_{[i]}) = \underbrace{(Q_{[i]} K_{[i]}^T \odot D) V_{[i]}}_{Inner-Chunk} + \underbrace{(Q_{[i]} R_{i-1}) \odot \xi}_{Cross-Chunk}, \quad \xi_{i,j} = \gamma^{i+1} }[/math]
where [math]\displaystyle{ R_i }[/math] is the state of the current chunk to be passed to the next.
In summary, the overall architecture of RetNet consists of [math]\displaystyle{ L }[/math] stacked identical blocks, similar to a Transformer. Each block consists of two modules: a multi-scale retention (MSR) module and a feed-forward network (FFN) module. It takes the embeddings of the input sequence [math]\displaystyle{ X^0= [x_1, \dots, x_{|x|}] \in \mathbb R^{|x| \times d_{model}} }[/math] and computes the output [math]\displaystyle{ X^L }[/math]:
[math]\displaystyle{ Y^l = MSR(LN(X^l)) + X^l }[/math]
[math]\displaystyle{ X^{l+1} = FFN(LN(Y^l)) + Y^l }[/math]
where [math]\displaystyle{ LN(.) }[/math] is LayerNorm.

Comparison, Limitations & Future Work
Finally, comparing RetNet to other solutions, we can see that it provides a solid alternative to the Transformer achieving parallel training, fast inference, linear long-sequence memory complexity, and good performance. The limitations of this architecture can be summarized in the following points:
- Limited Cross-Modal Exploration
- The method is primarily validated on language modeling tasks. Its applicability to other modalities (e.g., vision, audio) remains unexplored, requiring further research for broader adoption.
Memory-Recall Tradeoff
A fundamental tradeoff exists between the size of a model's recurrent state and its ability to recall past tokens accurately. Some architectures, like BASED, combine linear attention with a sliding window of exact softmax attention to navigate this tradeoff. By adjusting hyperparameters such as the window size and feature dimensions, these models can traverse the Pareto frontier—achieving high recall with a small memory footprint while still benefiting from high throughput.
This paper focuses on the task of associative recall and discusses the memory-recall tradeoff. The memory-recall tradeoff basically states that higher recall requires higher memory consumption, as shown in Figure 3. This tradeoff manifests in different ways across model architectures:
- Vanilla attention models: achieve excellent recall but at the cost of quadratic computational complexity
- Linear attention models: offer better efficiency but often struggle with accurate information retrieval
- State-space models like Mamba: show that increasing recurrent state size generally improves recall accuracy but introduces computational overhead

Implementations
The authors introduced the Based architecture to improve this tradeoff. They do so by stacking linear attention blocks with sliding-window attention blocks. Recall that one of the limitations of linear attention is its relatively poor recall compared to vanilla attention. With this architecture, the authors accept a small compromise in global fidelity by letting linear attention model global dependencies, while retaining high fidelity within a local window by using sliding-window attention to model local dependencies. As seen in Figure 4, they achieve this goal.
Chunking Strategy
Chunking involves splitting the input sequence [math]\displaystyle{ X }[/math] into chunks of length [math]\displaystyle{ B }[/math]. For each chunk [math]\displaystyle{ i }[/math], let
[math]\displaystyle{ Q_{[i]} = Q_{B i : B(i+1)}, \quad K_{[i]} = K_{B i : B(i+1)}, \quad V_{[i]} = V_{B i : B(i+1)} }[/math]
be the query, key, and value vectors (respectively) for that chunk. We maintain a recurrent “global” state [math]\displaystyle{ R_i }[/math] that summarizes information from all previous chunks and combine it with a “local” sliding-window attention over the current chunk.
Global Linear Attention (Recurrent “State”)
We use a feature map [math]\displaystyle{ \Phi(\cdot) }[/math] (e.g., a kernel or ELU-based mapping) to make attention linear in the sequence length. We update a global state [math]\displaystyle{ R_i }[/math] for each chunk:
[math]\displaystyle{ R_i = R_{i-1} + \Phi\bigl(K_{[i]}\bigr)^\top V_{[i]}, }[/math]
and compute the global attention contribution for the [math]\displaystyle{ i }[/math]-th chunk:
[math]\displaystyle{ \mathrm{GlobalAttn}(X_{[i]}) = \Phi\bigl(Q_{[i]}\bigr) R_{i-1}. }[/math]
Intuitively, [math]\displaystyle{ R_{i-1} }[/math] aggregates (in linear time) all key-value information from past chunks.
Taylor Linear Attention
The Based architecture specifically employs a second-order Taylor series expansion as the feature map for linear attention. This approximation offers a good balance between computational efficiency and performance: [math]\displaystyle{ \Phi(x) = \begin{bmatrix} 1 \\ x \\ \frac{x^2}{2} \end{bmatrix} }[/math]
This feature map provides a reasonable approximation of the softmax function while maintaining the linear complexity benefits.
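As a quick illustration of this feature map, here is a small NumPy sketch (not the paper's optimized kernel) in which the quadratic term is implemented as scaled pairwise products, so that the dot product of two mapped vectors reproduces the truncated Taylor expansion [math]\displaystyle{ 1 + q^\top k + (q^\top k)^2/2 }[/math]:
<pre>
import numpy as np

def taylor_feature_map(x):
    """2nd-order Taylor feature map sketch: phi(x) concatenates 1, x, and the
    scaled pairwise products of x, so that phi(q) @ phi(k) = 1 + q@k + (q@k)^2 / 2.
    Note the feature dimension grows to 1 + d + d^2, which is why small key
    dimensions are used in practice."""
    d = x.shape[-1]
    outer = np.einsum('...i,...j->...ij', x, x).reshape(*x.shape[:-1], d * d)
    ones = np.ones((*x.shape[:-1], 1))
    return np.concatenate([ones, x, outer / np.sqrt(2)], axis=-1)

# Sanity check against the closed-form expansion
q, k = np.random.randn(16), np.random.randn(16)
lhs = taylor_feature_map(q) @ taylor_feature_map(k)
rhs = 1 + q @ k + (q @ k) ** 2 / 2
assert np.allclose(lhs, rhs)
</pre>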
Local Sliding-Window Attention (Exact)
Within the current chunk, we also apply standard (exact) attention to a small window of recent tokens. Denote the set of token indices in this window by [math]\displaystyle{ \mathcal{W}_i }[/math]. Then:
[math]\displaystyle{ \mathrm{LocalAttn}(X_{[i]}) = \sum_{j \in \mathcal{W}_i} \mathrm{softmax}\Bigl(\frac{Q_{[i]} K_j^\top}{\sqrt{d}}\Bigr) V_j. }[/math]
This term captures fine-grained, short-range dependencies at full (exact) attention fidelity, but only over a small neighborhood [math]\displaystyle{ \mathcal{W}_i }[/math].
Final Output per Chunk
The total representation for chunk [math]\displaystyle{ i }[/math] is a sum of global (linear) and local (windowed) attention:
[math]\displaystyle{ h_{[i]} = \mathrm{GlobalAttn}(X_{[i]}) + \mathrm{LocalAttn}(X_{[i]}). }[/math]
By combining a linear-time global update ([math]\displaystyle{ R_i }[/math]) with high-resolution local attention, the model balances throughput (via the efficient global linear component) and recall (via exact local attention). This addresses the memory-recall tradeoff: large contexts are captured without quadratically scaling memory usage, while local windows preserve high accuracy on short-range dependencies.
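The following NumPy sketch shows one way this combination can be implemented token by token (chunk size 1): a running linear-attention state summarizes tokens that have left the local window, while exact softmax attention handles the window itself. The feature map phi and the window size are assumptions supplied by the caller, and the normalization of the linear term, the chunked computation, and the I/O-aware scheduling of the real implementation are all omitted.
<pre>
import numpy as np

def based_global_local_attention(Q, K, V, phi, window):
    """Sketch of combining a global linear-attention state R with exact
    softmax attention over a small causal window of recent tokens."""
    L, d = Q.shape
    fq, fk = phi(Q), phi(K)
    R = np.zeros((fk.shape[1], V.shape[1]))        # state summarizing far-away tokens
    out = np.zeros_like(V)
    for t in range(L):
        if t - window >= 0:                        # token leaving the window joins the global state
            R = R + np.outer(fk[t - window], V[t - window])
        global_part = fq[t] @ R                    # linear-time summary of the distant past
        lo = max(0, t - window + 1)
        scores = Q[t] @ K[lo:t+1].T / np.sqrt(d)   # exact softmax attention in the local window
        w = np.exp(scores - scores.max()); w /= w.sum()
        out[t] = global_part + w @ V[lo:t+1]
    return out
</pre>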
Hardware Optimizations
Recent work has shown that linear attention can be made even more practical when its implementation is tailored to the underlying hardware. For example, methods like FLASHLINEARATTENTION integrate I/O-aware techniques to minimize data movement between high-bandwidth memory and faster on-chip memories, resulting in real-world speedups that can even outperform optimized softmax attention implementations on moderate sequence lengths.
The Based Architecture includes several hardware-aware optimizations:
- Memory-efficient linear attention: The implementation fuses the feature map and causal dot product computation in fast memory, reducing high-latency memory operations.
- Optimized sliding window: The window size is carefully selected to align with hardware constraints (typically 64×64 tiles), balancing computation and memory bandwidth.
- Register-level computation: Critical calculations are performed in registers whenever possible, minimizing data movement between different memory hierarchies.
FLASHLINEARATTENTION: Hardware-Efficient Linear Attention for Fast Training and Inference
FLASHLINEARATTENTION is an I/O-aware, hardware-efficient linear attention mechanism that manages data movement between shared memory (SRAM) and high-bandwidth memory (HBM) efficiently. The goals are to alleviate memory bottlenecks, maximize GPU parallelism, and accelerate training and inference. The method significantly outperforms even FLASHATTENTION-2 at moderate sequence lengths (~1K tokens). Softmax-based self-attention, the standard attention mechanism, is quadratic in both computation and memory, making it inefficient for long sequences and hard to scale. Linear attention reduces this complexity in theory, but most implementations are not optimized for modern GPUs and therefore do not deliver real-world speedups. FLASHLINEARATTENTION addresses this by splitting the input sequence into manageable chunks, so computation can be performed independently on each chunk before the global state is updated. This reduces redundant memory accesses, lowers the latency of GPU operations, and keeps tensor cores effectively utilized.
Mathematically, FLASHLINEARATTENTION builds upon linear attention, which rewrites standard softmax attention as:
[math]\displaystyle{ \text{Attention}(Q, K, V) = \text{softmax} \left(\frac{QK^T}{\sqrt{d}} \right) V }[/math]
instead of explicitly computing the full [math]\displaystyle{ n \times n }[/math] attention matrix, linear attention approximates it using a kernel function [math]\displaystyle{ \phi(x) }[/math] such that:
[math]\displaystyle{ \text{Attention}(Q, K, V) \approx \frac{\phi(Q) \left(\phi(K)^T V\right)}{\phi(Q) \left(\phi(K)^T \mathbf{1}_n\right)} }[/math]
where [math]\displaystyle{ \phi(x) }[/math] is a feature map transformation ensuring that the inner product approximates the softmax function. In standard parallel linear attention, the output is computed as:
[math]\displaystyle{ O = (Q K^T) V }[/math]
which still has quadratic complexity [math]\displaystyle{ O(n^2 d) }[/math]. FLASHLINEARATTENTION, on the other hand, splits the input sequence into pieces and processes them separately while having a hidden state [math]\displaystyle{ S }[/math]. The rule to update the hidden state is in a recurrent form:
[math]\displaystyle{ S[i+1] = S[i] + \sum_{j=1}^{C} K_j^T V_j }[/math]
[math]\displaystyle{ O[i+1] = Q[i+1] S[i] + (Q[i+1] K[i+1]^T \odot M) V[i+1] }[/math]
Here, [math]\displaystyle{ M }[/math] is a causal mask, ensuring attention is calculated only over tokens in the past. Chunkwise computation, which minimizes memory overhead by splitting the sequence into smaller chunks, is the most crucial optimization in FLASHLINEARATTENTION: chunks can be processed independently, allowing parallel execution. Minimizing HBM I/O cost is another key optimization, avoiding unnecessary data transfer between HBM and SRAM by reusing tensors already loaded on chip. When [math]\displaystyle{ Q[n] }[/math] is loaded into SRAM, both [math]\displaystyle{ Q[n]S }[/math] and [math]\displaystyle{ (Q[n]K[n]^T \odot M)V[n] }[/math] are computed without reloading [math]\displaystyle{ Q[n] }[/math]. FLASHLINEARATTENTION has two implementations. In the non-materialization version, hidden states [math]\displaystyle{ S[i] }[/math] are kept in SRAM for memory-efficient computation. In the materialization version, all [math]\displaystyle{ S[i] }[/math] are stored in HBM to enable full sequence-level parallelism; this incurs extra memory traffic for the states but boosts training throughput by 10-20%. The computation is parallel within chunks but sequential between chunks, a design that balances efficiency and memory handling and makes the algorithm well suited to long sequences.
The FLASHLINEARATTENTION forward pass algorithm goes as follows. For input matrices [math]\displaystyle{ Q }[/math], [math]\displaystyle{ K }[/math], and [math]\displaystyle{ V }[/math] of size [math]\displaystyle{ L \times d }[/math] and chunk size [math]\displaystyle{ C }[/math], the sequence is divided into [math]\displaystyle{ N = L / C }[/math] blocks as: [math]\displaystyle{ Q = \{ Q[1], Q[2], ..., Q[N] \}, \quad K = \{ K[1], K[2], ..., K[N] \} }[/math]
The hidden state is initialized as [math]\displaystyle{ S = 0 }[/math] in SRAM. For each chunk, [math]\displaystyle{ S }[/math] is stored in HBM if materialization is enabled, and the corresponding [math]\displaystyle{ K[n], V[n] }[/math] values are loaded into SRAM. The hidden state update follows:
[math]\displaystyle{ S = S + K[n]^T V[n] }[/math]
and the output for each chunk is computed in parallel as:
[math]\displaystyle{ O'[n] = Q[n] S + (Q[n] K[n]^T \odot M) V[n] }[/math]
The chunk outputs are then written back to HBM and assembled into the final output. The algorithm maintains a trade-off between memory usage and parallelization such that training is faster than with previous linear attention implementations. FLASHLINEARATTENTION offers significant performance gains compared to several other attention models: it is faster than FLASHATTENTION-2 even on sequences shorter than 4K tokens and roughly doubles training speed compared to standard linear attention. Memory efficiency is improved by reducing HBM I/O cost and removing redundant data movement, with about 4x less memory usage than baseline softmax attention. Scalability is also a benefit, allowing sequences longer than 20K tokens to be processed without quadratic memory growth. These improvements make FLASHLINEARATTENTION a solid tool for large-scale transformers, improving efficiency in both training and inference; integrating it into traditional transformer designs can significantly accelerate large language models and make large-scale deployment more viable. Future work can explore deeper kernel optimizations, CUDA-specific workloads, and hardware-specific transformations to further enhance efficiency. Overall, FLASHLINEARATTENTION is a significant advance in hardware-efficient deep learning: by optimizing memory access, chunking input sequences, and parallelizing intra-chunk computation, it achieves substantial speedups while retaining strong recall ability, making it well suited to efficient processing of long sequences.
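Ignoring the HBM/SRAM scheduling that the real kernel is built around, the chunkwise arithmetic described above can be sketched in a few lines of NumPy (identity feature map, no normalization; names and shapes are illustrative):
<pre>
import numpy as np

def flash_linear_attention_forward(Q, K, V, C):
    """Chunkwise linear attention sketch: the state S accumulates K^T V across
    chunks, and each chunk's output combines the inter-chunk term Q[n] S with
    the causally masked intra-chunk term (Q[n] K[n]^T ⊙ M) V[n]."""
    L, d = Q.shape
    S = np.zeros((d, V.shape[1]))                 # running hidden state
    M = np.tril(np.ones((C, C)))                  # causal mask within a chunk
    O = np.zeros_like(V)
    for n in range(L // C):
        q, k, v = Q[n*C:(n+1)*C], K[n*C:(n+1)*C], V[n*C:(n+1)*C]
        O[n*C:(n+1)*C] = q @ S + (q @ k.T * M) @ v
        S = S + k.T @ v                           # update the state after producing the chunk output
    return O
</pre>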
Limitations
While the Based Architecture presents a promising direction, it still faces some challenges:
- Implementation complexity: The dual-attention mechanism and hardware optimizations add complexity to the implementation.
- Hyperparameter sensitivity: The optimal balance between global and local attention components may vary across different tasks and datasets.
- Performance gap: Despite improvements, there remains a gap between Based models and state-of-the-art vanilla attention models on some specific tasks.
Gated Linear Attention (GLA)
RetNet shows that we can achieve parallel training, efficient inference, and strong performance. However, despite its innovations, RetNet still has significant limitations. The biggest problem is how it handles sequential data: RetNet uses a constant decay factor ([math]\displaystyle{ \gamma }[/math]) that does not adapt to the content being processed. Imagine a conversation in which every earlier word is discounted by the same fixed amount regardless of how significant it is; that is effectively what RetNet's constant exponential decay does. RetNet therefore treats all sequential dependencies the same way, whether or not the current content calls for more emphasis on recent tokens or on tokens much farther back. This can work, but not necessarily optimally.
Furthermore, while RetNet is theoretically efficient, it is less effective in practice on real hardware. Its algorithms are not I/O-aware: they do not account for how data moves between levels of memory on modern GPUs, which leads to suboptimal performance. Gated Linear Attention (GLA) advances linear attention by addressing these key limitations.
The new idea of GLA is the use of data-dependent gates. Unlike RetNet, which uses fixed decay factors, these gates adjust dynamically based on the content being processed, making the model more expressive. They act like an intelligent filter that decides, at each position in the sequence, how much information should pass through depending on its relevance. For example, when processing the phrase "Prime Minister" in "The Prime Minister of Canada," the model might focus more on "Canada" than on "The"; this is something RetNet with fixed decay cannot do. GLA introduces a gating mechanism where the output is determined by the formula below:
[math]\displaystyle{ \text{Output} = \text{LinearAttention} \odot \text{ContentGate} }[/math]
The ContentGate is derived directly from the input itself and this modification enhances the model’s ability to capture complex dependencies in text.
The other important feature of GLA is its FLASHLINEARATTENTION algorithm, which is designed specifically for modern GPU architectures. Unlike purely theoretical improvements, this optimization delivers real-world speedups, outperforming even highly optimized methods like FLASHATTENTION-2. What makes FLASHLINEARATTENTION different is that it is I/O-aware, meaning it efficiently manages data movement between:
- High-bandwidth memory (HBM): Large but slower GPU memory.
- Shared memory (SRAM): Smaller but significantly faster memory.
By minimizing unnecessary data transfers and using specialized compute units (like tensor cores), this method achieves remarkable speed and efficiency. In benchmarks, it even surpasses FLASHATTENTION-2 on relatively short sequences of just 1,000 tokens.
FLASHLINEARATTENTION improves efficiency by using a method called chunkwise processing. Instead of processing the entire sequence at once, it divides it into smaller chunks. Within each chunk, computations are done in parallel to maximize GPU usage. Then, information is passed between chunks in a recurrent way to maintain dependencies. This balances speed and memory efficiency, making it faster than traditional softmax attention and even some optimized methods like FLASHATTENTION-2.
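A minimal recurrent sketch of the data-dependent gating idea is given below, assuming a simple sigmoid gate projection; the paper's exact parameterization, normalization, and chunkwise hardware-aware computation are omitted, and Wg is a hypothetical gate weight matrix introduced only for illustration.
<pre>
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gated_linear_attention(X, Wq, Wk, Wv, Wg):
    """Recurrent sketch: instead of RetNet's fixed decay gamma, a per-token
    gate g_t (a function of the input) controls how much of the running state
    is kept at each step."""
    L, _ = X.shape
    d_k, d_v = Wk.shape[1], Wv.shape[1]
    S = np.zeros((d_k, d_v))                     # recurrent state
    out = np.zeros((L, d_v))
    for t in range(L):
        q, k, v = X[t] @ Wq, X[t] @ Wk, X[t] @ Wv
        g = sigmoid(X[t] @ Wg)                   # data-dependent decay in (0, 1)
        S = g[:, None] * S + np.outer(k, v)      # gated state update
        out[t] = q @ S                           # linear-attention readout
    return out
</pre>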
Limitations & Future Work
- Lack of Large-Scale Validation
- Although the authors anticipate that the training efficiency of GLA becomes even more favorable compared to models like Mamba at larger scales, it is unclear how GLA would scale to even larger models/datasets.
- The experiments were conducted on moderate-scale language models (up to 1.3B parameters). The performance and efficiency of GLA at larger scales (e.g., >7B parameters) remain uncertain, particularly regarding tensor parallelism and memory constraints.
- Performance Gap in Recall-Intensive Tasks
- While GLA outperforms other subquadratic models in recall-intensive tasks, it still lags behind standard quadratic attention Transformers, indicating room for improvement in tasks requiring precise memory retrieval.
TransNormerLLM: A Faster and Better Large Language Model with Improved TransNormer
The objective of the paper is to address the quadratic time complexity and scalability limitations of conventional transformer-based LLMs by introducing a linear attention-based architecture for improved efficiency and performance. Modern linear attention models like TransNormerLLM push the envelope by integrating several modifications—including improved positional encodings (using LRPE with exponential decay), gating mechanisms, and tensor normalization—to not only match but even outperform conventional Transformer architectures in both accuracy and efficiency. These innovations help address the limitations of earlier linear attention approaches and demonstrate the potential for linear methods in large-scale language modeling.
Key Contributions
- Architectural Improvements
- Positional Encoding:
Combines LRPE (Linearized Relative Positional Encoding) with exponential decay to balance global interactions and avoid attention dilution. The expression of the positional encoding is:
[math]\displaystyle{ a_{st}=\mathbf{q}_s^\top\mathbf{k}_t\lambda^{s-t}\exp(i\theta(s-t)), }[/math]
where [math]\displaystyle{ a_{st} }[/math] is the attention score between tokens s and t, [math]\displaystyle{ \lambda^{s-t} }[/math] is the exponential decay factor, and [math]\displaystyle{ \exp(i\theta(s-t)) }[/math] is the encoding that captures relative positions.
- Gating Mechanisms:
Uses Gated Linear Attention (GLA) and Simplified Gated Linear Units (SGLU) to stabilize training and enhance performance.
Gating can enhance the model's performance and smooth the training process. The structure of Gated Linear Attention (GLA) is:
[math]\displaystyle{ O=\mathrm{Norm}(QK^{\top}V)\odot U, }[/math]
where:[math]\displaystyle{ \quad Q=\phi(XW_q),\quad K=\phi(XW_k),\quad V=XW_v,\quad U=XW_u. }[/math]
To further accelerate the model, the authors propose Simple GLU (SGLU), which removes the activation function from the original GLU structure since the gate itself can introduce non-linearity. Therefore, the channel mixing becomes:
[math]\displaystyle{ O=[V\odot U]W_o, }[/math]
where:[math]\displaystyle{ V=XW_v,\quad U=XW_u. }[/math]
- Tensor Normalization:
Replaces RMSNorm with SRMSNorm, a simpler normalization method that accelerates training without performance loss. The new simple normalization function is called SimpleRMSNorm, abbreviated as SRMSNorm:
[math]\displaystyle{ \mathrm{SRMSNorm}(x)=\frac{x}{\|x\|_2/\sqrt{d}} }[/math]
- Training Optimization
- Model Parallelism: Adapts Megatron-LM model parallelism for SGLU and GLA, enabling efficient scaling from 7B to 175B parameters.
- Robust Inference: A robust inference algorithm proposed in the paper ensures numerical stability and constant inference speed regardless of sequence length via a recurrent formulation with decay factors.
Model Parallelism on SGLU:
[math]\displaystyle{ O=\left((XW_v)\odot(XW_u)\right)W_o }[/math]
[math]\displaystyle{ \begin{bmatrix} O_1^{\prime},O_2^{\prime} \end{bmatrix}=X \begin{bmatrix} W_v^1,W_v^2 \end{bmatrix}\odot X \begin{bmatrix} W_u^1,W_u^2 \end{bmatrix}= \begin{bmatrix} XW_v^1,XW_v^2 \end{bmatrix}\odot \begin{bmatrix} XW_u^1,XW_u^2 \end{bmatrix} }[/math]
[math]\displaystyle{ O= \begin{bmatrix} O_1^{\prime},O_2^{\prime} \end{bmatrix} \begin{bmatrix} W_o^1,W_o^2 \end{bmatrix}^\top=O_1^{\prime}W_o^1+O_2^{\prime}W_o^2 }[/math]
Model Parallelism on GLA:
[math]\displaystyle{ [O_1,O_2]=\mathrm{SRMSNorm}(QK^\top V)\odot U, }[/math]
where: [math]\displaystyle{ Q = [\phi(X W_q^1), \phi(X W_q^2)],\ K = [\phi(X W_k^1), \phi(X W_k^2)],\ V = X[W_v^1, W_v^2],\ U = X[W_u^1, W_u^2]. }[/math]
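For reference, here is a small NumPy sketch of the SRMSNorm and SGLU formulas given above (single device, no model parallelism; all shapes are illustrative and chosen only for the toy example):
<pre>
import numpy as np

def srmsnorm(x):
    """SimpleRMSNorm from the formula above: x / (||x||_2 / sqrt(d))."""
    d = x.shape[-1]
    return x / (np.linalg.norm(x, axis=-1, keepdims=True) / np.sqrt(d))

def sglu(X, Wv, Wu, Wo):
    """Simple GLU channel mixing O = [(X Wv) ⊙ (X Wu)] Wo, with no activation:
    the elementwise gate itself supplies the non-linearity."""
    return ((X @ Wv) * (X @ Wu)) @ Wo

# Toy usage: d_model = 8, hidden size = 16, 4 tokens
X = np.random.randn(4, 8)
Wv, Wu, Wo = np.random.randn(8, 16), np.random.randn(8, 16), np.random.randn(16, 8)
Y = sglu(srmsnorm(X), Wv, Wu, Wo)
</pre>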

Simple linear attention language models balance the recall-throughput tradeoff
Large language models (LLMs) using Transformers are fantastic at "recalling" details from their input—think of it as their ability to dig up a specific fact buried in a long conversation. However, attention-based language models suffer from high memory consumption during inference due to the growing key-value (KV) cache, which scales with sequence length. This reduces throughput and limits efficiency for long sequences, despite their strong recall ability.
Key Contributions
- BASED: A hybrid architecture blending linear attention and sliding window attention to optimize the recall-memory tradeoff.
- Theoretical and empirical evidence of how memory usage impacts recall.
- Optimized implementation achieving up to 24× higher throughput than FlashAttention-2 for long-sequence generation.
BASED
BASED combines two tricks to strike a balance:
- Linear Attention: This approximates the softmax with a simpler operation. Instead of computing [math]\displaystyle{ \exp(q_i^\top k_j / \sqrt{d}) }[/math], it uses a feature map [math]\displaystyle{ \phi }[/math] so that [math]\displaystyle{ \phi(q_i)^\top \phi(k_j) \approx \exp(q_i^\top k_j / \sqrt{d}) }[/math]. The paper opts for a 2nd-order Taylor series: [math]\displaystyle{ \phi(q_i)^\top \phi(k_j) = 1 + q_i^\top k_j + \frac{(q_i^\top k_j)^2}{2} }[/math]. This lets them rewrite attention as a recurrent process with a fixed-size state, say [math]\displaystyle{ s_i = \sum_{j=1}^i \phi(k_j)^\top v_j }[/math], avoiding the KV-cache explosion. The state size depends on the feature dimension [math]\displaystyle{ d' }[/math], not [math]\displaystyle{ N }[/math], making memory predictable.
- Sliding Window Attention: This adds precision by letting each token attend to a small, local window of past tokens (e.g., 64 tokens) using exact softmax attention. It’s like giving the model a magnifying glass for nearby details, with a KV-cache capped at the window size [math]\displaystyle{ w }[/math].
Together, linear attention handles long-range context with a fixed memory footprint, while sliding window attention sharpens local recall. By tweaking [math]\displaystyle{ w }[/math] and [math]\displaystyle{ d' }[/math], BASED can slide along the recall-memory tradeoff curve—crank up the window for better recall or shrink it for efficiency.
Limitations and Future Work
- Scaling: May struggle with very large models or extremely long sequences.
- Future Work: Could explore improved approximations for linear attention or enhanced hardware optimizations.
Topic 11: Flash Attention
Flash Attention tries to address the resource intensity of vanilla attention by utilizing hardware more effectively. Recent advances in computational speed (particularly in parallelization) have lessened computational intensity as a bottleneck; consequently, memory usage has become the larger bottleneck. Specifically, during training, attention requires data to be read and written multiple times, with intermediate results kept in HBM (which is not the fastest memory on the device). As a result, training takes a lot of time and resources. The goal of Flash Attention is thus to address these inefficiencies in order to reduce training time for LLMs.
Flash Attention V1
Flash Attention V1 rethinks the standard attention mechanism by minimizing expensive memory transfers. Modern GPUs have high-bandwidth memory (HBM) that is large but relatively slow compared to on‑chip SRAM, which is over 10× faster but limited to about 20 MB. By restructuring the attention computation to operate in small blocks within SRAM, Flash Attention V1 dramatically reduces the number of slow memory reads and writes, resulting in faster training for large language models.
Overview
Traditional attention computes the full [math]\displaystyle{ QK^\top }[/math] matrix in HBM, causing a quadratic memory traffic bottleneck. Flash Attention V1 addresses this by processing the input matrices in smaller, manageable blocks (tiles) that fit in fast SRAM. In doing so, both the matrix multiplication and the softmax normalization are performed piecewise, with local results later merged to produce the exact, globally normalized output.
Block‑Wise (Tiled) Computation
Imagine trying to paint a giant mural using a small canvas. Rather than handling the entire image at once, you divide it into sections that fit on your workbench. In Flash Attention V1, the large [math]\displaystyle{ Q }[/math], [math]\displaystyle{ K }[/math], and [math]\displaystyle{ V }[/math] matrices are partitioned into tiles. For example, with 1024 tokens, the matrices might be split into blocks of 256 tokens each. A block of [math]\displaystyle{ Q }[/math] (e.g. [math]\displaystyle{ 256 \times d }[/math]) is loaded from HBM into SRAM along with the corresponding block of [math]\displaystyle{ K }[/math] (transposed as needed). The multiplication [math]\displaystyle{ QK^\top }[/math] is then computed within SRAM to yield a [math]\displaystyle{ 256 \times 256 }[/math] block of attention scores, eliminating the need to store or compute the entire matrix in slow memory.
Softmax Normalization: Handling Global Dependencies in Small Pieces
The softmax function, defined as [math]\displaystyle{ \text{softmax}(x_i) = \frac{\exp(x_i)}{\sum_j \exp(x_j)} }[/math] requires a global normalization across each row of scores. This poses a challenge when the data is processed in tiles.
Within each tile, a local maximum ([math]\displaystyle{ m_{\text{local}} }[/math]) is computed to stabilize the exponentiation, and a local sum is calculated: [math]\displaystyle{ S_{\text{local}} = \sum_{i \in \text{tile}} \exp(x_i - m_{\text{local}}) }[/math] Simultaneously, the tile computes a partial numerator by weighting each [math]\displaystyle{ \exp(x_i - m_{\text{local}}) }[/math] with the corresponding [math]\displaystyle{ V }[/math] values. These local computations ensure that the softmax can be performed on a small subset of the data without needing the entire row at once.
Algebraic Aggregation: Merging Tile Results for Global Softmax
Since each tile’s softmax is computed relative to its own local maximum, the results must be re‑aligned to a common global reference. Suppose one tile computes a local maximum [math]\displaystyle{ m_1 }[/math] and sum [math]\displaystyle{ S_1 = \sum_{i \in \text{tile 1}} \exp(x_i - m_1) }[/math] and another tile computes [math]\displaystyle{ m_2 }[/math] and [math]\displaystyle{ S_2 = \sum_{j \in \text{tile 2}} \exp(x_j - m_2) }[/math]. The global maximum is given by [math]\displaystyle{ m = \max(m_1, m_2) }[/math] Each local sum is then adjusted to this common reference: [math]\displaystyle{ S_{\text{global}} = \exp(m_1 - m) S_1 + \exp(m_2 - m) S_2 }[/math] A similar adjustment is applied to the local numerators (the weighted sums of [math]\displaystyle{ V }[/math]). The final attention output is obtained by dividing the aggregated numerator by [math]\displaystyle{ S_{\text{global}} }[/math]. This algebraic aggregation process guarantees numerical stability and correctness while never requiring the full matrix to be processed at once.
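This merge rule can be sketched directly in NumPy. The snippet below keeps, for a single query row, each tile's local maximum, local sum, and weighted numerator, merges two tiles with the rescaling described above, and checks the result against the full softmax (names and shapes are illustrative, and the usual [math]\displaystyle{ 1/\sqrt{d} }[/math] scaling of the scores is assumed to have been applied already):
<pre>
import numpy as np

def merge_softmax_stats(m1, s1, num1, m2, s2, num2):
    """Merge two tiles' stabilized softmax statistics into global ones.
    m*: local max, s*: local sum of exp(x - m*), num*: local weighted numerator
    (sum of exp(x - m*) * v)."""
    m = max(m1, m2)
    s = np.exp(m1 - m) * s1 + np.exp(m2 - m) * s2
    num = np.exp(m1 - m) * num1 + np.exp(m2 - m) * num2
    return m, s, num

# Toy check on one query row: 4 keys split into two tiles of 2
scores = np.array([0.3, 2.0, -1.0, 0.7]); V = np.random.randn(4, 8)
t1, t2 = scores[:2], scores[2:]
m1, m2 = t1.max(), t2.max()
s1, s2 = np.exp(t1 - m1).sum(), np.exp(t2 - m2).sum()
n1, n2 = np.exp(t1 - m1) @ V[:2], np.exp(t2 - m2) @ V[2:]
m, s, num = merge_softmax_stats(m1, s1, n1, m2, s2, n2)
ref = (np.exp(scores - scores.max()) / np.exp(scores - scores.max()).sum()) @ V
assert np.allclose(num / s, ref)   # merged tiles reproduce the global softmax output
</pre>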
Memory vs. Compute Trade-offs and Workflow Recap
Flash Attention V1 trades additional computations for a dramatic reduction in memory traffic. Although extra arithmetic is performed—such as recomputing local softmax values and adjusting them—these operations occur in fast SRAM. The overall workflow is as follows:
1. Load small tiles of [math]\displaystyle{ Q }[/math], [math]\displaystyle{ K }[/math], and [math]\displaystyle{ V }[/math] from HBM into SRAM.
2. Perform block‑wise matrix multiplication and softmax computations within SRAM, including computing local maximums, local sums, and partial numerators.
3. Merge these local results through algebraic aggregation to produce the correct global softmax normalization.
4. Write the computed attention outputs back to HBM.
Optional: Block‑Sparse Extension
In some cases, not every tile contributes equally. Flash Attention V1 can be extended to a block‑sparse variant where tiles that fall below a predetermined importance threshold are skipped entirely. This further reduces computation and memory transfers, echoing ideas from sparse attention methods.
Flash Attention V2
Background & Motivation
Language models have much longer context than before. Scaling transformers to long sequences is crucial:
- language modelling (GPT-4, Claude, MPT)
- high-resolution image understanding
- code, audio, and video generation
Challenges
- Attention bottleneck: quadratic time complexity, inefficient hardware utilization
- Flash Attention V1: still not nearly as fast as optimized matrix-multiply (GEMM) operations, reaching only 25-40% of the theoretical maximum FLOPs/s
- Suitability to Different Devices: Further research is needed to make Flash Attention V2 available to different devices, such as H100 GPUs, AMD GPUs. So far, the discussion has been focusing on the HBM and SRAM (NVIDIA GPUs with specific memory hierarchies).
Main Idea
In V1, the idea is to use SRAM, compute more, and reduce reads/writes. However, as mentioned previously, only 25-40% of the theoretical maximum FLOPs/s is reached. This inefficiency is due to suboptimal work partitioning between different thread blocks and warps on the GPU, causing either low occupancy or unnecessary shared-memory reads/writes. In V2, the main idea is a better work-partitioning algorithm. The core improvements are:
1. Reduce redundant operations: GPUs are optimized for matmul operations; although non-matmul FLOPs make up only a fraction of the total, each non-matmul FLOP is considerably more expensive than a matmul FLOP on this hardware. It is therefore important to reduce non-matmul FLOPs and spend as much time as possible doing matmul FLOPs.
2. Parallelism: Distribute attention computation efficiently across thread blocks and warps.
3. Efficient memory utilization: Optimize intra-thread communication, reducing memory overhead.
Empirical Results
The diagrams above show how Flash Attention V2 performs across different sequence lengths compared to a standard attention implementation in PyTorch, FlashAttention, and FlashAttention in Triton. We can see an improvement in efficiency across all sequence lengths: FlashAttention-2 is 1.7-3.0× faster than FlashAttention, 1.3-2.5× faster than FlashAttention in Triton, and 3-10× faster than a standard attention implementation. The throughput reached in V2 is about 70% of the theoretical maximum TFLOPs/s achievable on A100 GPUs.
Flash Attention V3
FlashAttention-2 was a significant improvement in speeding up attention mechanisms, but it only utilized 35% of the computational capacity of H100 GPUs. This left a lot of untapped potential, especially given the advanced features of Hopper GPUs, such as asynchronous execution and support for low-precision math. FlashAttention-3 was developed to fully exploit these capabilities, making attention computations faster, more efficient, and scalable for extremely long sequences.
There are two main ideas in FlashAttention-3:
1. Overlapping Tasks with Asynchronous Execution:
Hopper GPUs feature Tensor Cores, which are specialized units for performing matrix multiplications at very high speed. FlashAttention-3 takes advantage of these Tensor Cores by using warp specialization, where some GPU threads handle data movement (producers) while others focus on computation (consumers). These roles swap and run simultaneously in a "ping-pong" schedule, ensuring the GPU is always busy and no time is wasted.
2. Low-Precision Math (FP8):
FlashAttention-3 introduces support for FP8 precision, which reduces memory usage and speeds up calculations compared to FP16. To maintain accuracy, it uses techniques like block quantization (storing one scalar per block) and incoherent processing, which smooths out numerical outliers. This results in a 2.6x reduction in numerical error compared to baseline FP8 implementations.
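To illustrate the bookkeeping behind block quantization (one scale per block), here is a small NumPy sketch; the real FP8 cast happens in hardware on Hopper GPUs, and the E4M3 range constant and block size here are assumptions used only for illustration.
<pre>
import numpy as np

FP8_E4M3_MAX = 448.0   # largest finite value representable in FP8 E4M3

def block_quantize(x, block=64):
    """Per-block scaling sketch: each block of `block` values shares one scale
    so that values fit the FP8 dynamic range. Only the scale/unscale bookkeeping
    is mimicked here, in ordinary floats."""
    x = x.reshape(-1, block)
    scale = np.abs(x).max(axis=1, keepdims=True) / FP8_E4M3_MAX
    q = x / scale                      # would be cast to FP8 on device
    return q, scale

def block_dequantize(q, scale):
    return (q * scale).reshape(-1)

# Toy round trip (exact here because no actual low-precision rounding is applied)
x = np.random.randn(256) * 10
q, s = block_quantize(x)
assert np.allclose(block_dequantize(q, s), x)
</pre>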
Comparison to FlashAttention-2:
- 1.5–2x faster attention computations
- Up to 75% GPU utilization with FP16 precision
Flash Fast Fourier Transform Convolution (FlashFFT Conv)
Introduction
Efficient sequence processing has long been a fundamental challenge in deep learning, especially for natural language processing (NLP). Although convolutional models provide state-of-the-art performance across various tasks, including NLP and time-series forecasting, their efficiency is often limited by suboptimal hardware utilization. A potential solution to this problem is the Fast Fourier Transform (FFT), which theoretically operates in [math]\displaystyle{ O(N \log N) }[/math] time. However, conventional FFT implementations fail to fully leverage modern GPU architectures, leading to suboptimal efficiency. To address this, a highly optimized FFT-based convolution algorithm, namely Flash Fast Fourier Transform Convolution (FlashFFTConv), is introduced.
Innovation
A key innovation in Flash FFT Conv is the Monarch FFT Decomposition, which significantly improves the efficiency of Fast Fourier Transform (FFT)-based convolutions. Unlike traditional FFT algorithms, such as the Cooley-Tukey FFT, Monarch FFT Decomposition reformulates the FFT computation into a structured matrix representation that can be efficiently executed using matrix multiplications on tensor cores.
Importance of Monarch FFT Decomposition
The objective of the Monarch FFT decomposition is to express the FFT as a series of matrix multiplications. Doing so enables efficient computation on specialized GPU compute units (e.g., tensor cores of NVIDIA GPUs or matrix multiply units of TPUs). An order-[math]\displaystyle{ p }[/math] Monarch decomposition rewrites the FFT as [math]\displaystyle{ p }[/math] matrix-matrix multiply operations, which can then be mapped onto these fast compute units efficiently. The value of [math]\displaystyle{ p }[/math] involves a tradeoff: a higher [math]\displaystyle{ p }[/math] implies smaller matrices to multiply and thus fewer FLOPs, but greater I/O communication overhead due to the larger number of intermediate results.

An illustration of the Monarch FFT decomposition can be seen to the right. As an example, consider a sequence length [math]\displaystyle{ N=N_1N_2 }[/math]. An order-2 Monarch FFT decomposition expresses the FFT as [math]\displaystyle{ \mathcal{F}_N=\mathbf{P}(\mathbf{I}_{N_2}\otimes\mathcal{F}_{N_1})\mathbf{D}\mathbf{P}^{-1}(\mathbf{I}_{N_1}\otimes\mathcal{F}_{N_2})\mathbf{P} }[/math], where [math]\displaystyle{ \otimes }[/math] is the Kronecker product, [math]\displaystyle{ \mathbf{P} }[/math] is the permutation matrix that reshapes the input to [math]\displaystyle{ N_1\times N_2 }[/math], transposes the intermediate matrix, and then reshapes it back to length [math]\displaystyle{ N }[/math], and [math]\displaystyle{ \mathbf{D}\in\mathbb{C}^{N\times N} }[/math] is a diagonal matrix containing corrective values called twiddle factors. Twiddle factors are roots of unity [math]\displaystyle{ W^k_N=\exp{\left(-j\frac{2\pi k}{N}\right)} }[/math], where [math]\displaystyle{ N }[/math] is the number of points in the FFT, used to combine the results of smaller DFTs into larger DFTs (recall that the FFT is a divide-and-conquer algorithm for computing the DFT efficiently).
To execute higher-order Monarch FFT decompositions, one can recursively apply the order-2 decomposition to [math]\displaystyle{ \mathcal{F}_{N_1} }[/math] and [math]\displaystyle{ \mathcal{F}_{N_2} }[/math].
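The order-2 decomposition is essentially the classic Cooley-Tukey (four-step) factorization, so it can be checked in a few lines of NumPy: the two batches of small DFTs become matrix multiplications, with a diagonal twiddle correction and reshapes/transposes playing the roles of [math]\displaystyle{ \mathbf{D} }[/math] and [math]\displaystyle{ \mathbf{P} }[/math]. This is a verification sketch, not the tensor-core kernel.
<pre>
import numpy as np

def monarch_fft(x, N1, N2):
    """Order-2 (Cooley-Tukey / four-step) decomposition of an N-point DFT,
    N = N1 * N2, expressed as two batches of small DFTs (matrix multiplies)
    plus a diagonal twiddle correction and reshapes."""
    N = N1 * N2
    assert x.shape[-1] == N
    A = x.reshape(N2, N1).T                       # A[n1, n2] = x[n1 + N1*n2]
    F_N1 = np.exp(-2j * np.pi * np.outer(np.arange(N1), np.arange(N1)) / N1)
    F_N2 = np.exp(-2j * np.pi * np.outer(np.arange(N2), np.arange(N2)) / N2)
    Y = A @ F_N2                                  # length-N2 DFTs along n2 (a matmul)
    twiddle = np.exp(-2j * np.pi * np.outer(np.arange(N1), np.arange(N2)) / N)
    Z = twiddle * Y                               # diagonal correction (the D matrix)
    X = F_N1 @ Z                                  # length-N1 DFTs along n1 (a matmul)
    return X.reshape(N)                           # X[k1, k2] -> output index N2*k1 + k2

# Check against NumPy's FFT on a small example
x = np.random.randn(12) + 1j * np.random.randn(12)
assert np.allclose(monarch_fft(x, N1=3, N2=4), np.fft.fft(x))
</pre>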
Benefits
By leveraging this structured decomposition, FlashFFTConv achieves highly parallelized execution across the input sequence. This maximizes hardware utilization, resulting in faster computation. Furthermore, this approach minimizes high-latency global memory (HBM) access by performing most computations within on-chip memory (SRAM or shared memory). This reduces the I/O bottleneck that often limits FFT performance on modern GPUs. This leads to significant speedups over conventional FFT-based convolution methods. Additionally, the use of sparse and low-rank matrix structures further enhances computational efficiency, eliminating unnecessary operations and improving scalability for long-sequence processing.
Through efficient memory access and tensor core optimization, FlashFFTConv becomes a highly effective solution for accelerating FFT-based convolution tasks, such as those in NLP and audio generation.
Topic 5: KD / Pruning / Sharing
Knowledge Distillation (KD), Pruning, and Sharing are three strategies to make language models smaller, faster, and more efficient without sacrificing their performance. KD works by having a large, powerful "teacher" model teach a smaller "student" model to behave similarly, so the student can do almost as well as the teacher but with less effort. Pruning can be thought of as trimming a tree: it removes unnecessary or less important parts of the model, making it more compact and faster while keeping its core abilities intact. Sharing is about reusing parts of the model across different tasks or layers, so there is no need to build everything from scratch, saving time and resources.
While KD focuses on teaching a smaller model to imitate a larger one, pruning cuts away the unnecessary parts to make the model cleaner, and sharing reuses parts to avoid having unnecessary extra work. These techniques contribute together to create language models that are not only powerful but also efficient enough to run on devices with limited resources, like smartphones or laptops.
Knowledge Distillation (KD)
KD transfers knowledge from a large, well-trained teacher model to a smaller student model. The student learns not only from the hard labels but also by mimicking the teacher’s soft output distributions or internal representations. In the work by Muralidharan et al., KD is a critical component in recovering the performance lost during the compression of a 15B-parameter model. By distilling knowledge during retraining, the authors were able to produce smaller models (MINITRON 8B and 4B) that not only match but, in some cases, outperform models trained from scratch—even while using a fraction of the training data.

Improvements to Knowledge Distillation
1. Born-Again Networks (BANs):
Repeated training iterations to simplify the dataset.
Improves performance for lower-capacity NAT models.
2. Mixture of Experts (MoE):
Reduces data diversity by training on simpler expert translations.
3. Sequence-Level Interpolation:
Selects the best candidate translation based on BLEU score, improving high-capacity NAT models.
Metrics for Distilled Data
- Data Complexity measures how much variability is present in the translations; it is calculated by fitting an alignment model and computing the average token-level conditional entropy.
- Data faithfulness evaluates how closely the distilled data aligns with the original real-world data.
[math]\displaystyle{ F(d)=\frac{1}{\left|\mathcal{V}_x\right|} \sum_{x \in \mathcal{V}_x} \sum_{y \in \mathcal{V}_y} p_r(y \mid x) \log \frac{p_r(y \mid x)}{p_d(y \mid x)} }[/math]
Attention Is All You Need But You Don’t Need All Of It For Inference of Large Language Models
The computational costs of LLMs, particularly during inference, are a major bottleneck. This paper explores a key optimization: selectively dropping layers, specifically deeper attention layers, during inference to speed up computation while maintaining performance.
The Problem: Inference Complexity
Inference in LLMs is costly because of the self-attention mechanism, which scales quadratically with input length. This creates a challenge, especially for applications requiring real-time responses. Researchers have previously explored various methods to optimize inference, including pruning, quantization, and speculative decoding, but this paper focuses on a more direct approach: skipping unnecessary layers.
Key Idea: Skipping Layers Without Significant Performance Loss
The core hypothesis of the paper is that not all layers contribute equally to the final output of an LLM. Specifically, the deeper layers (closer to the output) exhibit higher similarity with their preceding layers, meaning that dropping some of them may have minimal impact on performance. To test this, the authors conduct experiments on Llama-v2 (7B and 13B models) and compare different strategies for skipping layers.
Method: Selective Layer Skipping
Consider a Transformer model with [math]\displaystyle{ L }[/math] layers, where each layer consists of an attention sub-layer and a multi-layer perceptron (MLP) sub-layer. We define each layer as:
[math]\displaystyle{ \text{Layer}_i = (\text{Attention}_i, \text{MLP}_i) }[/math]
for [math]\displaystyle{ i \in \{1, 2, \dots, L\} }[/math].
The paper evaluates three techniques for skipping components of Transformer layers:
1. Skipping MLP Layers: The feedforward (MLP) sub-layers from the last few layers are removed.
If we skip the last [math]\displaystyle{ k }[/math] MLP layers while keeping the attention layers, the modified model is:
[math]\displaystyle{ M_{\text{skip MLP}} = \left\{ (\text{Attention}_i, \text{MLP}_i) \mid i \in [1, L-k] \right\} \cup \left\{ (\text{Attention}_i, \emptyset) \mid i \in [L-k+1, L] \right\} }[/math]
2. Skipping Attention Layers: The self-attention sub-layers from the last few layers are removed.
If we skip the last [math]\displaystyle{ k }[/math] attention layers while keeping the MLP layers, the modified model is represented as:
[math]\displaystyle{ M_{\text{skip attention}} = \left\{ (\text{Attention}_i, \text{MLP}_i) \mid i \in [1, L-k] \right\} \cup \left\{ (\emptyset, \text{MLP}_i) \mid i \in [L-k+1, L] \right\} }[/math]
3. Skipping Entire Transformer Blocks: Both attention and MLP components are removed from some of the last layers.
If we skip both the attention and MLP sub-layers in the last [math]\displaystyle{ k }[/math] layers, we obtain:
[math]\displaystyle{ M_{\text{skip block}} = \left\{ (\text{Attention}_i, \text{MLP}_i) \mid i \in [1, L-k] \right\} \cup \left\{ (\emptyset, \emptyset) \mid i \in [L-k+1, L] \right\} }[/math]
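As a toy illustration of how these three strategies differ, the sketch below builds the modified layer lists for a hypothetical L-layer model whose layers are represented as (attention, MLP) pairs; the names are placeholders, not the paper's code.
<pre>
# Hypothetical layer lists illustrating the three skipping strategies applied
# to the last k layers of an L-layer model; each layer is an (attention, mlp) pair.
L, k = 12, 4
layers = [(f"attn_{i}", f"mlp_{i}") for i in range(1, L + 1)]   # stand-ins for sub-layers

skip_mlp       = layers[:L - k] + [(attn, None) for attn, _ in layers[L - k:]]   # drop last k MLPs
skip_attention = layers[:L - k] + [(None, mlp) for _, mlp in layers[L - k:]]     # drop last k attentions
skip_block     = layers[:L - k] + [(None, None)] * k                             # drop last k full blocks
</pre>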
The models were tested across four benchmarks: ARC, HellaSwag, TruthfulQA, and MMLU, which measure reasoning, common sense, truthfulness, and general knowledge.
Findings: Attention Layers Are Less Crucial Than MLP Layers

The results demonstrate a clear pattern:
- Dropping entire Transformer blocks leads to a significant performance drop.
- Dropping MLP layers leads to larger performance degradation compared to dropping attention layers.
- Dropping attention layers results in the best trade-off between speed and performance, with only a 1.8% drop in accuracy when 33% of attention layers are removed. Removing 33% of attention layers led to an 18% speedup in Llama-2-13B.
Interestingly, the TruthfulQA benchmark showed an increase in accuracy when some layers were skipped, suggesting that reducing model complexity might reduce hallucinations.
The paper provides empirical evidence that deeper layers contribute less unique information compared to earlier layers. The figure below shows cosine similarity between successive layers, highlighting that deeper layers are more redundant.