stat946W25
Topic 12: State Space Models
Introduction
State Space Models (SSMs) have been introduced as powerful alternatives to traditional sequence modeling approaches. They perform well across modalities, including time series analysis, audio generation, and image processing, and they can capture long-range dependencies more efficiently than attention-based models. Early SSMs, however, struggled to match Transformers on language modeling. To close this gap, architectural advances such as the Structured State Space model (S4) were introduced; S4 succeeded on long-range reasoning tasks and allowed more efficient computation while preserving the theoretical strengths of SSMs. However, its implementation remains complex and computationally demanding, so further research led to simplified variants such as the Diagonal State Space model (DSS), which achieves comparable performance with a more straightforward formulation. In parallel, hybrid approaches such as the H3 model combine SSMs with attention mechanisms to bridge the remaining gap; for example, H3 replaces almost all of the attention layers in a Transformer with SSM layers. More recently, models like Mamba have pushed the boundaries of SSMs by making the state-space parameters functions of the input, which allows more flexible and adaptive information propagation. As research continues to resolve the remaining challenges, the case for replacing attention-based architectures with SSMs grows stronger, and SSMs will likely play a crucial role in the next generation of sequence modeling frameworks.
Core concepts
To understand State Space Models better, let's first look at the problems with Transformers and Recurrent Neural Networks (RNNs) and how they relate to SSMs. During Transformer training, we build a matrix comparing each token with every token that came before it, where each entry (attention weight) reflects how similar the two tokens are. These weights are not computed sequentially: the relation between every pair of tokens can be computed in parallel. During inference, however, each newly generated token must attend to the entire preceding sequence, even though earlier tokens have already been processed. For a sequence of length [math]\displaystyle{ L }[/math], the total cost therefore scales as [math]\displaystyle{ O(L^2) }[/math], which is a bottleneck for long sequences.
An RNN takes two inputs at each time step: the hidden state of the previous time step ([math]\displaystyle{ h_{t-1} }[/math]) and the input at the current time step ([math]\displaystyle{ x_t }[/math]). From these it produces the current hidden state ([math]\displaystyle{ h_t }[/math]) and the output [math]\displaystyle{ y_t }[/math]. Because it does not recompute all previous hidden states, each generated token costs a constant amount of work, so inference is linear in the sequence length. On the other hand, RNNs are known to forget early tokens as the sequence goes on, and their training cannot be parallelized across time steps because each state depends on the previous one. A simple RNN updates its hidden state using the following formulas:

[math]\displaystyle{
h(t) = \sigma(W_h h(t-1) + W_x x(t))
}[/math]
[math]\displaystyle{ y(t) = W_y h(t) }[/math]
Where:
- [math]\displaystyle{ h(t) }[/math] is the hidden state at time t
- [math]\displaystyle{ x(t) }[/math] is the input
- [math]\displaystyle{ y(t) }[/math] is the output
- [math]\displaystyle{ W_h, W_x, }[/math] and [math]\displaystyle{ W_y }[/math] are weight matrices
- [math]\displaystyle{ \sigma }[/math] is a non-linear activation function
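The following is a minimal NumPy sketch of the simple RNN above. It assumes [math]\displaystyle{ \sigma = \tanh }[/math] and random weights, both purely illustrative choices. It makes the linear-time inference explicit: each step touches only the previous hidden state and the current input.

```python
import numpy as np

def rnn_forward(xs, W_h, W_x, W_y, h0):
    """Run h(t) = tanh(W_h h(t-1) + W_x x(t)), y(t) = W_y h(t) over a whole sequence."""
    h = h0
    ys = []
    for x in xs:                         # constant work per token -> linear in sequence length
        h = np.tanh(W_h @ h + W_x @ x)   # non-linear state update
        ys.append(W_y @ h)               # linear read-out
    return np.stack(ys), h

rng = np.random.default_rng(0)
d_in, d_h, d_out, L = 3, 4, 2, 5
W_h = 0.5 * rng.normal(size=(d_h, d_h))
W_x = rng.normal(size=(d_h, d_in))
W_y = rng.normal(size=(d_out, d_h))
xs = rng.normal(size=(L, d_in))
ys, h_last = rnn_forward(xs, W_h, W_x, W_y, np.zeros(d_h))
print(ys.shape)  # (5, 2)
```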
State Space Models (SSMs) come from control theory, where they are used to mathematically represent a system and describe its possible states. They define a linear mapping from an input signal x(t) to an output signal y(t) through a latent state representation h(t). An SSM has two main components:
1. State equation: Describes how the system's hidden state changes based on the current state and input
[math]\displaystyle{ h'(t) = \mathbf{A} h(t) + \mathbf{B} x(t) }[/math]
2. Output equation: Defines how the system produces output based on the current state and input
[math]\displaystyle{ y(t) = \mathbf{C} h(t) + \mathbf{D} x(t) }[/math]
Where:
- [math]\displaystyle{ h(t) }[/math] represents the hidden state
- [math]\displaystyle{ x(t) }[/math] is the input
- [math]\displaystyle{ y(t) }[/math] is the output
- [math]\displaystyle{ \mathbf{A}, \mathbf{B}, \mathbf{C}, }[/math] and [math]\displaystyle{ \mathbf{D} }[/math] are learnable parameter matrices
The intuition behind the neural network architecture comes from the state equation, which says that the change in the latent state is the sum of a linear transformation of the current latent state and a linear transformation of the input. Solving this system of differential equations then amounts to training the neural network to find the linear transformations [math]\displaystyle{ \mathbf{A},\mathbf{B},\mathbf{C},\mathbf{D} }[/math]. Note, however, that this definition uses continuous functions, so it does not directly apply to the many kinds of sequential data that are discrete, especially sentences. We therefore discretize the differential equation so that it accepts discrete input sequences. This yields
[math]\displaystyle{ h_t = \bar{\mathbf{A}} h_{t-1} + \bar{\mathbf{B}} x_t }[/math]
[math]\displaystyle{ y_t = \mathbf{\bar{C}} h_t }[/math]
where [math]\displaystyle{ h_t }[/math] is the discrete latent state at step t and [math]\displaystyle{ x_t }[/math] is the input at step t.
Note that we dropped the [math]\displaystyle{ \mathbf{D}x_t }[/math] term for computational simplicity and rely on [math]\displaystyle{ \mathbf{\bar{C}} }[/math] to capture the effect of the input [math]\displaystyle{ x_t }[/math]: since [math]\displaystyle{ \mathbf{\bar{C}} h_t = \mathbf{\bar{C}} (\mathbf{\bar{A}}h_{t-1} + \mathbf{\bar{B}}x_t) }[/math], [math]\displaystyle{ \mathbf{\bar{C}} }[/math] acts on [math]\displaystyle{ x_t }[/math] as well.
So how does an SSM differ from an RNN? The difference lies in the convolutional view of the SSM. To see this, consider a sequence of inputs [math]\displaystyle{ x_0, x_1, \dots, x_n }[/math] and an initial state of zero. Then we have
[math]\displaystyle{ h_0 = \mathbf{\bar{B}} x_0 }[/math]
[math]\displaystyle{ h_1 = \mathbf{\bar{A}}h_0 + \mathbf{\bar{B}} x_1 = \mathbf{\bar{A}}\mathbf{\bar{B}} x_0 + \mathbf{\bar{B}} x_1 }[/math]
[math]\displaystyle{ h_2 = \mathbf{\bar{A}}h_1 + \mathbf{\bar{B}} x_2 = \mathbf{\bar{A}}^2\mathbf{\bar{B}} x_0 + \mathbf{\bar{A}}\mathbf{\bar{B}} x_1 + \mathbf{\bar{B}} x_2 }[/math]
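Since every step is a linear transformation, unrolling the recurrence gives [math]\displaystyle{ y_k = \sum_{j=0}^{k} \mathbf{\bar{C}}\mathbf{\bar{A}}^{k-j}\mathbf{\bar{B}} x_j }[/math], i.e., a convolution of the input with the kernel [math]\displaystyle{ \mathbf{\bar{K}} = (\mathbf{\bar{C}}\mathbf{\bar{B}}, \mathbf{\bar{C}}\mathbf{\bar{A}}\mathbf{\bar{B}}, \dots, \mathbf{\bar{C}}\mathbf{\bar{A}}^{n}\mathbf{\bar{B}}) }[/math]. All outputs can therefore be computed in parallel during training, which can greatly improve training speed; an RNN cannot do this because of the non-linearity between its steps.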
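To make the convolutional view concrete, here is a small NumPy sketch that computes the same outputs twice: once with the step-by-step recurrence [math]\displaystyle{ h_t = \mathbf{\bar{A}}h_{t-1} + \mathbf{\bar{B}}x_t }[/math] and once as a causal convolution with the kernel [math]\displaystyle{ \mathbf{\bar{K}}_k = \mathbf{\bar{C}}\mathbf{\bar{A}}^k\mathbf{\bar{B}} }[/math], using [math]\displaystyle{ y_k = \sum_{j \le k} \mathbf{\bar{K}}_{k-j} x_j }[/math]. The random matrices, the scalar input/output, and the zero initial state are illustrative assumptions.

```python
import numpy as np

def ssm_recurrent(A_bar, B_bar, C_bar, x):
    """Recurrent view: h_t = A_bar h_{t-1} + B_bar x_t, y_t = C_bar h_t, starting from h = 0."""
    h = np.zeros(A_bar.shape[0])
    y = np.empty(len(x))
    for t, x_t in enumerate(x):
        h = A_bar @ h + B_bar * x_t
        y[t] = C_bar @ h
    return y

def ssm_kernel(A_bar, B_bar, C_bar, L):
    """Kernel K = (C B, C A B, ..., C A^{L-1} B) for scalar input and output."""
    K = np.empty(L)
    v = B_bar.copy()
    for k in range(L):
        K[k] = C_bar @ v
        v = A_bar @ v
    return K

def ssm_convolutional(A_bar, B_bar, C_bar, x):
    """Convolutional view: y_k = sum_{j<=k} K_{k-j} x_j (causal convolution)."""
    K = ssm_kernel(A_bar, B_bar, C_bar, len(x))
    return np.convolve(x, K)[: len(x)]

rng = np.random.default_rng(0)
N, L = 4, 16
A_bar = 0.3 * rng.normal(size=(N, N))
B_bar = rng.normal(size=N)
C_bar = rng.normal(size=N)
x = rng.normal(size=L)
print(np.allclose(ssm_recurrent(A_bar, B_bar, C_bar, x),
                  ssm_convolutional(A_bar, B_bar, C_bar, x)))  # True
```

Here `np.convolve` computes the convolution directly. For long sequences, the kernel would typically be applied with an FFT in [math]\displaystyle{ O(L \log L) }[/math] time, and every output position is computed independently of the others.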
Note also the bars on the matrices: they result from discretization, where [math]\displaystyle{ \mathbf{\bar A} }[/math] and [math]\displaystyle{ \mathbf{\bar B} }[/math] are the discrete equivalents of [math]\displaystyle{ \mathbf{A} }[/math] and [math]\displaystyle{ \mathbf{B} }[/math] over a time step [math]\displaystyle{ \Delta }[/math], obtained with some discretization method. For example, using the trapezoidal rule (also called the bilinear transform):
- [math]\displaystyle{ \mathbf{\bar A} = (I - \frac{\Delta}{2}A)^{-1}(I + \frac{\Delta}{2}A) }[/math]
- [math]\displaystyle{ \mathbf{\bar B} = (I - \frac{\Delta}{2}A)^{-1}\Delta B }[/math]
See the Discretization section below for further details.
Comparing these formulations with those we defined for RNNs shows that they are similar: an RNN is essentially a non-linear extension of a state space model. The main differences are:
- SSMs are linear transformations between states, while RNNs apply non-linearity through the activation function
- SSMs come from control theory and in control systems, the matrices are typically derived from physics equations, while in machine learning we learn these matrices from data
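- SSMs include a direct feedthrough term [math]\displaystyle{ \mathbf{D}x(t) }[/math] (often written [math]\displaystyle{ \mathbf{D}u(t) }[/math] in control theory, where the input is called [math]\displaystyle{ u }[/math]) in the output equation, although it is commonly left out in practice, as in the discretized form above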
Discretization
As mentioned, State Space Models come from ordinary differential equations, so how we discretize these continuous equations to work with finite sequences is crucial. In practice, the model keeps the continuous parameters [math]\displaystyle{ \mathbf{A} }[/math] and [math]\displaystyle{ \mathbf{B} }[/math] (together with the step size [math]\displaystyle{ \Delta }[/math]) and computes [math]\displaystyle{ \mathbf{\bar A} }[/math] and [math]\displaystyle{ \mathbf{\bar B} }[/math] from them in the forward pass, so the discretization step is baked into the model. Below we show a quick discretization example based on the trapezoidal rule.
For an ODE [math]\displaystyle{ z'(t) = f(t) }[/math], the trapezoidal rule approximates [math]\displaystyle{ z_{n+1} - z_{n} \approx \frac{\Delta}{2} (f(t_{n+1}) + f(t_{n})) }[/math].
We start from the ordinary differential equation. [math]\displaystyle{ h'(t) = \mathbf{A} h(t) + \mathbf{B} x(t) }[/math]
Applying the trapezoidal rule to [math]\displaystyle{ h' }[/math]:
[math]\displaystyle{ h_{n+1} - h_{n} = \frac{\Delta}{2} (\mathbf{A}h_{n+1} + \mathbf{B}x_{n+1} + \mathbf{A}h_{n} + \mathbf{B}x_{n}) }[/math]
[math]\displaystyle{ h_{n+1} - \frac{\Delta}{2} \mathbf{A}h_{n+1} = h_n + \frac{\Delta}{2}\mathbf{A}h_{n} + \frac{\Delta}{2} \mathbf{B}(x_{n+1} + x_{n}) }[/math]
[math]\displaystyle{ (\mathbf{I} - \frac{\Delta}{2} \mathbf{A})h_{n+1} = (\mathbf{I} + \frac{\Delta}{2}\mathbf{A})h_{n} + \frac{\Delta}{2} \mathbf{B}(x_{n+1} + x_{n}) }[/math]
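We assume that the input does not change over a sufficiently small step [math]\displaystyle{ \Delta }[/math], i.e., [math]\displaystyle{ x_{n+1}\approx x_n }[/math], so that [math]\displaystyle{ \frac{\Delta}{2} \mathbf{B}(x_{n+1} + x_{n}) \approx \Delta \mathbf{B} x_{n+1} }[/math]: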
[math]\displaystyle{ (\mathbf{I} - \frac{\Delta}{2} \mathbf{A})h_{n+1} = (\mathbf{I} + \frac{\Delta}{2}\mathbf{A})h_{n} + \Delta \mathbf{B}(x_{n+1}) }[/math]
[math]\displaystyle{ h_{n+1} =
(\mathbf{I} - \frac{\Delta}{2} \mathbf{A})^{-1}(\mathbf{I} + \frac{\Delta}{2}\mathbf{A})h_{n} + (\mathbf{I} - \frac{\Delta}{2} \mathbf{A})^{-1}\Delta \mathbf{B}(x_{n+1})
}[/math]
Hence, [math]\displaystyle{ \mathbf{\bar A}=(\mathbf{I} - \frac{\Delta}{2} \mathbf{A})^{-1}(\mathbf{I} + \frac{\Delta}{2}\mathbf{A}) }[/math] and [math]\displaystyle{ \mathbf{\bar B}=(\mathbf{I} - \frac{\Delta}{2} \mathbf{A})^{-1}\Delta \mathbf{B} }[/math].
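The two formulas can be sketched in NumPy as follows. This is a minimal illustration, not any particular library's implementation: it solves linear systems instead of forming the inverse explicitly, and the example system [math]\displaystyle{ h' = -h + x }[/math] is made up for the demonstration.

```python
import numpy as np

def discretize_bilinear(A, B, delta):
    """Trapezoidal (bilinear) rule:
    A_bar = (I - delta/2 A)^{-1} (I + delta/2 A),  B_bar = (I - delta/2 A)^{-1} delta B."""
    I = np.eye(A.shape[0])
    left = I - (delta / 2.0) * A
    A_bar = np.linalg.solve(left, I + (delta / 2.0) * A)  # avoids an explicit inverse
    B_bar = np.linalg.solve(left, delta * B)
    return A_bar, B_bar

A = np.array([[-1.0]])   # a stable scalar system h' = -h + x
B = np.array([[1.0]])
A_bar, B_bar = discretize_bilinear(A, B, delta=0.01)
print(A_bar[0, 0], np.exp(-0.01))  # close: the bilinear map approximates exp(delta * A)
```

One property worth noting: if every eigenvalue of [math]\displaystyle{ \mathbf{A} }[/math] has a negative real part, every eigenvalue of [math]\displaystyle{ \mathbf{\bar A} }[/math] lies inside the unit circle, so a stable continuous system stays stable after discretization.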
Structured State Space (S4)
The Structured State Space (S4) model is designed for efficient sequence modelling, particularly when long-term dependencies must be handled. S4 aims to improve computational efficiency, reduce the number of parameters, and improve learning by imposing structural constraints on the state transition matrix.
To efficiently handle long sequences, the Structured State Space (S4) model employs a structured parameterization of its state transition matrices. Specifically, it uses a Diagonal Plus Low-Rank (DPLR) parameterization, which combines the simplicity of diagonal matrices with the expressiveness of low-rank updates. This approach allows S4 to learn complex state spaces while maintaining computational efficiency, making it particularly effective for tasks requiring long-range dependencies.
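The sketch below illustrates why a Diagonal Plus Low-Rank matrix [math]\displaystyle{ \mathbf{A} = \Lambda - PQ^* }[/math] is cheap to work with: applying it to a vector costs [math]\displaystyle{ O(Nr) }[/math] instead of [math]\displaystyle{ O(N^2) }[/math]. The sizes and random values are illustrative assumptions. This only shows the structural advantage; S4's actual kernel computation, which relies on the Woodbury identity and Cauchy kernels, is considerably more involved.

```python
import numpy as np

def dplr_matvec(Lam, P, Q, v):
    """Apply A = diag(Lam) - P Q^* to v in O(N r) time, never forming the N x N matrix."""
    return Lam * v - P @ (Q.conj().T @ v)

rng = np.random.default_rng(0)
N, r = 64, 1                                   # state size and rank of the low-rank term
Lam = -0.5 + 1j * rng.normal(size=N)           # diagonal part (illustrative values)
P = rng.normal(size=(N, r)) + 0j
Q = rng.normal(size=(N, r)) + 0j
v = rng.normal(size=N) + 0j

A_dense = np.diag(Lam) - P @ Q.conj().T        # O(N^2) reference, for checking only
print(np.allclose(A_dense @ v, dplr_matvec(Lam, P, Q, v)))  # True
```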
However, a significant critique of the S4 model is its complexity. The model relies on multiple reduction steps and advanced linear algebraic techniques to compute the state space efficiently. While these optimizations are crucial for performance, they also make S4 challenging to understand, implement, and analyze. This complexity can act as a barrier for researchers and practitioners seeking to adapt or build upon the model, highlighting a trade-off between efficiency and accessibility.
Some of the applications of S4 are language and audio processing (text generation, speech recognition, music and voice synthesis), forecasting (predicting stock prices in financial markets, modelling climate information), and scientific computing (simulations in physics, chemistry, astronomy).
Diagonal State Space Model (DSS)
As mentioned above, S4 relies on the Diagonal Plus Low-Rank (DPLR) structure of the state matrix [math]\displaystyle{ A }[/math], which improves efficiency but makes the model substantially harder to understand and implement. To simplify S4 while maintaining its performance, the Diagonal State Space (DSS) model was proposed. It replaces the state matrix with a purely diagonal one, which reduces computational overhead and improves interpretability while retaining the ability to capture long-range dependencies.
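With a diagonal state matrix, the convolution kernel [math]\displaystyle{ \mathbf{\bar{K}}_k = \sum_n C_n \bar{A}_n^k \bar{B}_n }[/math] reduces to a Vandermonde-matrix product, and the recurrence becomes an elementwise update. The sketch below checks the two views against each other. It assumes zero-order-hold discretization ([math]\displaystyle{ \bar{A}_n = e^{\Delta\lambda_n} }[/math], [math]\displaystyle{ \bar{B}_n = (\bar{A}_n - 1)B_n/\lambda_n }[/math]), takes the real part of the output, and uses a made-up initialization. DSS's own kernel parameterization differs in detail.

```python
import numpy as np

def zoh(lam, B, delta):
    """Zero-order-hold discretization of a diagonal SSM (an assumption of this sketch)."""
    A_bar = np.exp(delta * lam)
    B_bar = (A_bar - 1.0) / lam * B
    return A_bar, B_bar

def diagonal_ssm_kernel(lam, B, C, delta, L):
    """K_k = Re(sum_n C_n A_bar_n^k B_bar_n), computed with a Vandermonde matrix."""
    A_bar, B_bar = zoh(lam, B, delta)
    powers = A_bar[None, :] ** np.arange(L)[:, None]   # (L, N): powers[k, n] = A_bar_n^k
    return (powers @ (C * B_bar)).real

def diagonal_ssm_recurrent(lam, B, C, delta, x):
    A_bar, B_bar = zoh(lam, B, delta)
    h = np.zeros_like(A_bar)
    y = np.empty(len(x))
    for t, x_t in enumerate(x):
        h = A_bar * h + B_bar * x_t        # elementwise update: no matrix product needed
        y[t] = (C @ h).real
    return y

rng = np.random.default_rng(0)
N, L = 8, 32
lam = -0.5 + 1j * np.pi * np.arange(N)     # illustrative diagonal initialization
B = np.ones(N, dtype=complex)
C = rng.normal(size=N) + 1j * rng.normal(size=N)
delta = 0.1
K = diagonal_ssm_kernel(lam, B, C, delta, L)
x = rng.normal(size=L)
print(np.allclose(np.convolve(x, K)[:L], diagonal_ssm_recurrent(lam, B, C, delta, x)))  # True
```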
[math]\displaystyle{ \mathbf{Improvements\ Compared\ with\ S4} }[/math]
- Computational Efficiency (and Model Complexity)
- The DSS model offers a significant computational advantage by simplifying the DPLR state matrix to a diagonal form. A diagonal state matrix turns the recurrence into an elementwise update and reduces the number of parameters to be estimated. This simplicity allows DSS to scale well to large datasets and real-time systems.
- Interpretability
- The diagonal structure makes it clear how previous hidden states and inputs affect the current state, since each state dimension evolves independently. For example, it can be shown that under certain conditions DSS can capture information from extremely distant positions in the sequence. This transparency provides clear insight into system behavior and makes the underlying dynamics easier to interpret. By comparison, S4 is harder to interpret because of its more complex architecture.
[math]\displaystyle{ \mathbf{Applications} }[/math]
Despite its simplifications, DSS retains strong modeling capabilities for a variety of tasks, including time-series forecasting, raw speech classification, and natural language processing (NLP). Its ability to model temporal dependencies with minimal complexity makes it a highly efficient choice for applications where computational efficiency and interpretability are critical. While S4 excels in capturing intricate, long-range dependencies, DSS's performance is nearly as strong, and in certain cases, it even outperforms S4.
[math]\displaystyle{ \mathbf{Limitations} }[/math]
The initialization of parameters can significantly influence the performance of DSS, as improper initialization may lead to slower convergence or suboptimal solutions. Additionally, DSS tends to be less effective at modeling information-dense data, such as text, where complex patterns and intricate relationships between words or phrases are crucial. In these cases, more sophisticated models may be helpful. However, DSS still provides a viable alternative in scenarios where simplicity and efficiency are prioritized over capturing deep contextual relationships.
Hungry Hungry Hippos (H3)
SSMs vs Attention
Research has shown that SSMs achieve state-of-the-art performance in domains such as speech recognition and audio generation. However, they underperform attention in language modelling for two main reasons:
- Expressivity gap
- Poor hardware utilization
Expressivity Gap
To understand the gap between SSMs and attention on language modelling, the authors examined two synthetic language modelling tasks:
The Induction Head task tests how well a model can recall content after a special token. At the end of the sequence, the model must recall the token that appeared immediately after the special token earlier in the sequence. In the table above, we are trying to recall the first token after the ⊢ symbol, which is f.
Associative Recall is similar to the induction head task, but requires the model to remember multiple key-value pairs. At the end of the sequence, the model must recall a specific value belonging to a specific key. In the table above, we are trying to recall the value associated with character a, which is 2.
The table above shows the performance of S4D (a diagonal variant of S4), Gated State Spaces (GSS), and attention. Attention completes both synthetic tasks perfectly, with 100% accuracy, significantly outperforming S4D and GSS. The failure of the SSMs can be attributed to two missing capabilities: (i) the ability to remember tokens that appear after a particular event (e.g., the special token in the induction head task), and (ii) the ability to compare tokens across the sequence (e.g., comparing keys to decide which value to recall). Attention has both capabilities: it can compare tokens by constructing the quadratic attention matrix [math]\displaystyle{ \mathbf{QK^T} }[/math], and it can recall tokens by direct copying (multiplying [math]\displaystyle{ softmax(\mathbf{QK^T}) }[/math] by [math]\displaystyle{ \mathbf{V} }[/math]).
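As a toy illustration of both capabilities, the sketch below solves one associative recall query with a single attention step: the query is compared with every key (one row of [math]\displaystyle{ \mathbf{QK^T} }[/math]), and the softmax weights then copy the matching value out of [math]\displaystyle{ \mathbf{V} }[/math]. The example sequence, the one-hot embeddings, and the fixed sharpening scale are all made-up assumptions that stand in for learned projections.

```python
import numpy as np

# Made-up associative recall example: "a 2 c 4 b 3 d 1", then query "a"
keys = ["a", "c", "b", "d"]
values = [2, 4, 3, 1]
query = "a"

vocab_k = sorted(set(keys))
vocab_v = sorted(set(values))
K = np.array([[1.0 if k == w else 0.0 for w in vocab_k] for k in keys])      # one-hot keys
V = np.array([[1.0 if val == w else 0.0 for w in vocab_v] for val in values])  # one-hot values
q = np.array([1.0 if query == w else 0.0 for w in vocab_k])

scale = 10.0                          # sharpens the softmax; stands in for learned projections
scores = scale * (K @ q)              # compare the query with every key: one row of QK^T
weights = np.exp(scores - scores.max())
weights /= weights.sum()              # softmax
retrieved = weights @ V               # copy: softmax(QK^T) V
print(vocab_v[int(np.argmax(retrieved))])  # 2
```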
H3 Design
H3 is designed to give SSMs these capabilities. Its structure is as follows:
The H3 layer consists of two SSMs stacked together.
The first is a shift SSM, which can be thought of as performing a local lookup across the sequence; its construction is described below.
The second, a diagonal SSM, serves as a global memory that keeps track of important information.
Multiplicative interactions, represented by the teal elementwise-multiplication symbol in the figure, give H3 the ability to compare tokens across the sequence.
In the shift SSM, we constrain [math]\displaystyle{ \mathbf{A} \in \mathbb{R}^{m\times m} }[/math] to be a shift matrix, defined by its entries: [math]\displaystyle{ \mathbf{A}_{i,j} = \begin{cases} 1 & \text{for } i - 1 = j,\\ 0 & \text{otherwise}. \end{cases} }[/math] Drawn out, it is a square matrix with 1s directly below the main diagonal and 0s everywhere else. Following the H3 paper's notation, [math]\displaystyle{ x_i }[/math] here denotes the hidden state and [math]\displaystyle{ u_i }[/math] the input. Multiplying the hidden state by this matrix shifts each coordinate down by one, thereby creating a “memory” of the previous inputs. If [math]\displaystyle{ \mathbf{B} = \textit{e}_1 }[/math], the first basis vector, then [math]\displaystyle{ x_i = [u_i, u_{i-1}, \dots, u_{i-m+1}] }[/math] contains the inputs from the previous m time steps. Both [math]\displaystyle{ \mathbf{B} }[/math] and [math]\displaystyle{ \mathbf{C} }[/math] can be learned, but [math]\displaystyle{ \mathbf{B} }[/math] is usually fixed to [math]\displaystyle{ \textit{e}_1 }[/math] for simplicity, in which case the output is a 1D convolution with kernel size m.
The diagonal SSM constrains A to be diagonal and initializes it from the diagonal version of HiPPO. This parameterization allows the model to remember state over the entire sequence. The shift SSM can detect when a particular event occurs, and the diagonal SSM can remember a token afterwards for the rest of the sequence.
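Below is a minimal NumPy sketch of the two components and how H3 composes them. The composition follows the H3 paper's form [math]\displaystyle{ \mathbf{Q} \odot \mathrm{SSM}_{\mathrm{diag}}(\mathrm{SSM}_{\mathrm{shift}}(\mathbf{K}) \odot \mathbf{V}) }[/math], but simplified to a single channel with head dimension 1. The diagonal SSM here uses plain random or fixed values rather than a HiPPO initialization, and all function names are hypothetical.

```python
import numpy as np

def shift_ssm(u, C):
    """Shift SSM: A has ones directly below the diagonal and B = e_1, so the state holds the
    last m inputs and the output is a causal 1D convolution with kernel C (size m)."""
    m = len(C)
    A = np.eye(m, k=-1)          # A[i, j] = 1 exactly when i - 1 == j
    B = np.zeros(m)
    B[0] = 1.0                   # B = e_1
    x = np.zeros(m)              # hidden state (H3 notation)
    y = np.empty(len(u))
    for i, u_i in enumerate(u):
        x = A @ x + B * u_i      # x_i = [u_i, u_{i-1}, ..., u_{i-m+1}]
        y[i] = C @ x
    return y

def diag_ssm(u, a, b, c):
    """Diagonal SSM with A = diag(a): a long-range memory of its input."""
    x = np.zeros(len(a))
    y = np.empty(len(u))
    for i, u_i in enumerate(u):
        x = a * x + b * u_i
        y[i] = c @ x
    return y

def h3_channel(q, k, v, C_shift, a, b, c):
    """One channel of Q * SSM_diag(SSM_shift(K) * V), with head dimension 1."""
    return q * diag_ssm(shift_ssm(k, C_shift) * v, a, b, c)

rng = np.random.default_rng(0)
L = 10
q, k, v = rng.normal(size=(3, L))
# C_shift = [0, 1] makes the shift SSM output the previous token's key
y = h3_channel(q, k, v, C_shift=np.array([0.0, 1.0]),
               a=np.array([0.9]), b=np.array([1.0]), c=np.array([1.0]))
print(y.shape)  # (10,)
```

With `C_shift = [0, 1]`, the product `shift_ssm(k) * v` pairs each value with the previous token's key, and the diagonal SSM accumulates those pairs over the rest of the sequence, which mirrors the "detect an event, then remember what follows" behaviour described above.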
Mamba
Mamba builds on S4-style models. It was introduced to increase efficiency on long sequences through a selection mechanism that lets the model ignore irrelevant information, saving memory and computation. Architecturally, it merges the H3 block traditionally used in SSM models with a gated multilayer perceptron (MLP) into a single block, as shown in the figure below.

Recall that SSM models calculate the output using the following formulae: [math]\displaystyle{ h'(t) = \mathbf{A} h(t) + \mathbf{B} x(t) }[/math]
[math]\displaystyle{ y(t) = \mathbf{C} h(t) + \mathbf{D} x(t) }[/math]
These equations are discretized using an additional parameter, the step size [math]\displaystyle{ \Delta }[/math], as shown above, resulting in the updated formulae:
[math]\displaystyle{ h(t) = \bar{\mathbf{A}} h(t-1)+\bar{\mathbf{B}}x(t) }[/math]
[math]\displaystyle{ y(t) = \mathbf{C} h(t) }[/math]
The matrices may be understood as follows:
- [math]\displaystyle{ \bar{\mathbf{A}} }[/math] represents the transition matrix between discrete hidden states
- [math]\displaystyle{ \bar{\mathbf{B}} }[/math] processes the input data in order to update the hidden state
- [math]\displaystyle{ \mathbf{C} }[/math] interprets the hidden state in order to generate the output.
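Mamba makes the step size [math]\displaystyle{ \Delta }[/math] learnable and input-dependent, which lets it act as a gate. In the special case of a one-dimensional state with [math]\displaystyle{ \mathbf{A} = -1 }[/math] and [math]\displaystyle{ \mathbf{B} = 1 }[/math], the Mamba paper shows that the selective update reduces to the gated recurrence [math]\displaystyle{ h(t) = (1-g(t))h(t-1) + g(t)x(t) }[/math] with [math]\displaystyle{ g(t) = \sigma(\mathrm{Linear}(x(t))) }[/math]. A large [math]\displaystyle{ \Delta }[/math] resets the state and focuses on the current input, while a small [math]\displaystyle{ \Delta }[/math] keeps the state and ignores the input.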
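The Mamba paper relates [math]\displaystyle{ \Delta }[/math] to RNN gating through a special case: a one-dimensional state with [math]\displaystyle{ \mathbf{A} = -1 }[/math], [math]\displaystyle{ \mathbf{B} = 1 }[/math], [math]\displaystyle{ \Delta = \mathrm{softplus}(\mathrm{Linear}(x)) }[/math], and zero-order-hold discretization. The short numerical check below shows that in this case [math]\displaystyle{ \bar{\mathbf{A}} = 1 - g }[/math] and [math]\displaystyle{ \bar{\mathbf{B}} = g }[/math] with [math]\displaystyle{ g = \sigma(\mathrm{Linear}(x)) }[/math], so [math]\displaystyle{ h(t) = \bar{\mathbf{A}}h(t-1) + \bar{\mathbf{B}}x(t) }[/math] is exactly a gated update. The grid `z` is an assumption that stands in for the output of the linear layer.

```python
import numpy as np

def softplus(z):
    return np.logaddexp(0.0, z)      # log(1 + e^z), numerically stable

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

z = np.linspace(-5, 5, 11)          # stands in for Linear(x_t), the pre-activation of Delta
delta = softplus(z)                 # Delta_t = softplus(Linear(x_t))
A, B = -1.0, 1.0                    # the special case N = 1, A = -1, B = 1
A_bar = np.exp(delta * A)           # zero-order hold: exp(Delta A)
B_bar = (A_bar - 1.0) / A * B       # zero-order hold: A^{-1} (exp(Delta A) - 1) B
g = sigmoid(z)
print(np.allclose(A_bar, 1 - g), np.allclose(B_bar, g))  # True True
```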
Selective State Space Models in Mamba
Mamba builds on the success of structured state space models (SSMs), a family of models that, by design, efficiently capture long-range dependencies through a recurrent or convolutional mechanism. In traditional SSMs, a key benefit is linear-time scaling, achieved by compressing the entire sequence context into a fixed-size hidden state. However, this efficiency can be a double-edged sword: because the dynamics (the parameters [math]\displaystyle{ \mathbf{A}, \mathbf{B}, \mathbf{C} }[/math] and [math]\displaystyle{ \Delta }[/math]) are the same at every time step (a property known as linear time invariance, or LTI), the model may miss crucial content-specific cues in the data.
Mamba addresses this limitation with a clever twist: it introduces a selection mechanism that makes certain SSM parameters depend on the input. In simple terms, while standard SSMs update their hidden state with the same fixed rules at every time step, the selective variant modulates these update rules based on the current input. This allows the model to "decide" when to ignore irrelevant information and when to focus on key inputs, similar to how gating works in recurrent neural networks.
The mechanism works by turning parameters that usually remain constant, namely the step size [math]\displaystyle{ \Delta }[/math] and the matrices [math]\displaystyle{ \mathbf{B} }[/math] and [math]\displaystyle{ \mathbf{C} }[/math], into functions of the input ([math]\displaystyle{ \mathbf{A} }[/math] itself stays input-independent, but its discretized form [math]\displaystyle{ \bar{\mathbf{A}} }[/math] varies through [math]\displaystyle{ \Delta }[/math]). This extra flexibility means that rather than applying a one-size-fits-all transformation to every token, the model tailors its processing dynamically across the sequence. The result is a network that keeps linear scaling while improving accuracy by being content-aware, even on very long sequences.
An important innovation in Mamba is a hardware-aware algorithm that keeps the additional computational cost low. Because the parameters are input-dependent, the convolutional view no longer applies, so Mamba computes the recurrence with a parallel scan. Instead of materializing the large expanded hidden states in slow GPU main memory (HBM), it fuses the discretization, scan, and output steps into one kernel that keeps the states in fast on-chip SRAM, and it recomputes intermediate states in the backward pass rather than storing them. In practice, this gives Mamba several times higher inference throughput than Transformers of similar size, especially on long inputs.
In summary, Mamba's selective state space models combine the efficiency of classical SSMs with the flexibility of input-dependent gating. This combination allows the architecture to handle extremely long sequences (up to millions of tokens) while still capturing the nuanced patterns essential for tasks ranging from language modeling to audio generation.
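The following sequential NumPy reference sketches a selective SSM: at every step, [math]\displaystyle{ \Delta }[/math], [math]\displaystyle{ \mathbf{B} }[/math], and [math]\displaystyle{ \mathbf{C} }[/math] are computed from the current input, while the diagonal [math]\displaystyle{ \mathbf{A} }[/math] is fixed. Several simplifications and assumptions are worth stating. The projection names (`W_delta`, `W_B`, `W_C`) are hypothetical. [math]\displaystyle{ \Delta }[/math] uses a full projection instead of a low-rank one. [math]\displaystyle{ \bar{\mathbf{A}} }[/math] uses zero-order hold, while [math]\displaystyle{ \bar{\mathbf{B}} }[/math] uses the simpler approximation [math]\displaystyle{ \Delta\mathbf{B} }[/math]. Finally, it is a plain Python loop, not the fused, hardware-aware parallel scan.

```python
import numpy as np

def softplus(z):
    return np.logaddexp(0.0, z)

def selective_scan(x, A, W_delta, b_delta, W_B, W_C):
    """Sequential reference for a selective SSM (hypothetical parameter names).
    x: (L, D) inputs; A: (D, N) negative entries of a diagonal state matrix per channel;
    W_delta: (D, D), b_delta: (D,), W_B: (D, N), W_C: (D, N)."""
    L, D = x.shape
    h = np.zeros((D, A.shape[1]))
    y = np.empty((L, D))
    for t in range(L):
        delta = softplus(x[t] @ W_delta + b_delta)   # (D,)   input-dependent step size
        B_t = x[t] @ W_B                             # (N,)   input-dependent B
        C_t = x[t] @ W_C                             # (N,)   input-dependent C
        A_bar = np.exp(delta[:, None] * A)           # (D, N) zero-order hold for A
        B_bar = delta[:, None] * B_t[None, :]        # (D, N) simplified B_bar = Delta * B
        h = A_bar * h + B_bar * x[t][:, None]        # per-channel diagonal recurrence
        y[t] = h @ C_t                               # (D,)   read out with input-dependent C
    return y

rng = np.random.default_rng(0)
L_, D, N = 12, 4, 3
x = rng.normal(size=(L_, D))
A = -np.exp(rng.normal(size=(D, N)))                 # negative real diagonal entries
W_delta, b_delta = 0.5 * rng.normal(size=(D, D)), np.zeros(D)
W_B, W_C = 0.5 * rng.normal(size=(D, N)), 0.5 * rng.normal(size=(D, N))
y = selective_scan(x, A, W_delta, b_delta, W_B, W_C)
print(y.shape)  # (12, 4)
```

Because `delta`, `B_t`, and `C_t` change with every token, the kernel is no longer fixed, which is why this computation cannot be rewritten as a single convolution the way the earlier LTI SSM could.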
Mamba-2
Semiseparable Matrices
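Semiseparable matrices are lower-triangular matrices of the following form, with entries [math]\displaystyle{ M_{ji} = C_j^\top A_j \cdots A_{i+1} B_i }[/math] for [math]\displaystyle{ j \ge i }[/math]: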
[math]\displaystyle{ \begin{bmatrix} C_0^\top B_0 & \\ C_1^\top A_1 B_0 & C_1^\top B_1 & \\ C_2^\top A_2A_1 B_0 & C_2^\top A_2 B_1 & C_2^\top B_2 \\ \vdots & \vdots & \ddots & \ddots \\ C_{\mathtt{T}-1}^\top A_{\mathtt{T}-1}\dots A_1 B_0 & C_{\mathtt{T}-1}^\top A_{\mathtt{T}-1}\dots A_2 B_1 & \dots & C_{\mathtt{T}-1}^\top A_{\mathtt{T}-1} B_{\mathtt{T}-2} & C_{\mathtt{T}-1}^\top B_{\mathtt{T}-1} \\ \end{bmatrix} }[/math]
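Unrolling the recurrence [math]\displaystyle{ h_t = A_t h_{t-1} + B_t x_t }[/math], [math]\displaystyle{ y_t = C_t^\top h_t }[/math] shows that [math]\displaystyle{ y = Mx }[/math] for this matrix [math]\displaystyle{ M }[/math], so every algorithm for computing a state space model can be viewed as a structured matrix multiplication algorithm on a semiseparable matrix. Such matrices have the property that every submatrix contained in the lower-triangular portion has rank at most [math]\displaystyle{ N }[/math] (the state size). Writing [math]\displaystyle{ A^{\times}_{a,b} = A_a A_{a-1} \cdots A_{b+1} }[/math] (with [math]\displaystyle{ A^{\times}_{a,a} = I }[/math]), a block with rows [math]\displaystyle{ j, \dots, j'-1 }[/math] and columns [math]\displaystyle{ i', \dots, i-1 }[/math], where [math]\displaystyle{ j \ge i }[/math], factors as:
[math]\displaystyle{ \begin{bmatrix} C_j^\top A_{j,i'}^{\times} B_{i'} & \cdots & C_j^\top A_{j,i-1}^{\times} B_{i-1} \\ \vdots & \ddots & \vdots \\ C_{j'-1}^\top A_{j'-1,i'}^{\times} B_{i'} & \cdots & C_{j'-1}^\top A_{j'-1,i-1}^{\times} B_{i-1} \end{bmatrix} = \begin{bmatrix} C_j^\top A_{j,j}^{\times} \\ \vdots \\ C_{j'-1}^\top A_{j'-1,j}^{\times} \end{bmatrix} A_{j,i-1}^{\times} \begin{bmatrix} A_{i-1,i'}^{\times} B_{i'} & \cdots & A_{i-1,i-1}^{\times} B_{i-1} \end{bmatrix}. }[/math]
This low-rank structure leads to an efficient algorithm for multiplying by semiseparable matrices.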
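The small NumPy check below materializes the semiseparable matrix with entries [math]\displaystyle{ M_{ji} = C_j^\top A_j \cdots A_{i+1} B_i }[/math]. It confirms that [math]\displaystyle{ y = Mx }[/math] matches the recurrence [math]\displaystyle{ h_t = A_t h_{t-1} + B_t x_t }[/math], [math]\displaystyle{ y_t = C_t^\top h_t }[/math], and it inspects the rank of an off-diagonal block. The sizes and random parameters are illustrative assumptions, and materializing [math]\displaystyle{ M }[/math] is only for checking, since efficient algorithms never form it.

```python
import numpy as np

def ssm_matrix(A, B, C):
    """Materialize the T x T semiseparable matrix M[j, i] = C_j^T A_j ... A_{i+1} B_i (j >= i)."""
    T = A.shape[0]
    M = np.zeros((T, T))
    for i in range(T):
        v = B[i].copy()            # contribution of input i to the state at time i
        M[i, i] = C[i] @ v
        for j in range(i + 1, T):
            v = A[j] @ v           # propagate that contribution through A_j
            M[j, i] = C[j] @ v
    return M

def ssm_scan(A, B, C, x):
    """Recurrent reference: h_t = A_t h_{t-1} + B_t x_t, y_t = C_t^T h_t."""
    h = np.zeros(B.shape[1])
    y = np.empty(len(x))
    for t in range(len(x)):
        h = A[t] @ h + B[t] * x[t]
        y[t] = C[t] @ h
    return y

rng = np.random.default_rng(0)
T, N = 8, 2
A = 0.5 * rng.normal(size=(T, N, N))
B = rng.normal(size=(T, N))
C = rng.normal(size=(T, N))
x = rng.normal(size=T)
M = ssm_matrix(A, B, C)
print(np.allclose(M @ x, ssm_scan(A, B, C, x)))   # True: y = M x
print(np.linalg.matrix_rank(M[4:, :4]))           # at most N = 2
```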

SSD Algorithm: "Block Matrix Decomposition" and "Chunking and State Passing"
- We first partition the semiseparable matrix into blocks of size [math]\displaystyle{ \mathtt{Q} \times \mathtt{Q} }[/math].
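- Each diagonal block is computed with the quadratic (attention-like) form; since [math]\displaystyle{ \mathtt{Q} }[/math] is small, this is hardware efficient (orange in the figure).
- Then we use the properties of semiseparable matrices to factorize each off-diagonal block, which is low rank (blue, green, and yellow in the figure).
- Following the workflow shown in the figure, the low-rank factors are handled in three steps: compress each chunk's inputs into its final state (right factors), pass states across chunks with a short recurrence (center factors), and convert each chunk's incoming state into outputs (left factors). The resulting algorithm is efficient and can even be split across multiple GPUs.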
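The decomposition can be checked with a small NumPy sketch. It assumes the scalar-times-identity state matrix [math]\displaystyle{ A_t = a_t I }[/math] used by SSD, a single input channel, and chunk size `Q`. The diagonal blocks use the quadratic masked form, the state entering each chunk supplies the off-diagonal (low-rank) contribution, and a short recurrence passes states between chunks. The actual SSD algorithm batches these steps as matrix multiplications across all chunks; this loop version only verifies that the decomposition reproduces the recurrence.

```python
import numpy as np

def scalar_ssm_recurrent(a, B, C, x):
    """Reference: h_t = a_t h_{t-1} + B_t x_t, y_t = C_t . h_t, with scalar a_t."""
    h = np.zeros(B.shape[1])
    y = np.empty(len(x))
    for t in range(len(x)):
        h = a[t] * h + B[t] * x[t]
        y[t] = C[t] @ h
    return y

def ssd_chunked(a, B, C, x, Q):
    """Chunked computation with chunk size Q (single channel, state size N)."""
    T, N = B.shape
    y = np.empty(T)
    h = np.zeros(N)                            # state entering the current chunk
    for s in range(0, T, Q):
        e = min(s + Q, T)
        a_c, B_c, C_c, x_c = a[s:e], B[s:e], C[s:e], x[s:e]
        n = e - s
        # 1-SS mask inside the chunk: Lmask[j, i] = a_{i+1} ... a_j, with 1 on the diagonal
        Lmask = np.zeros((n, n))
        for j in range(n):
            for i in range(j + 1):
                Lmask[j, i] = np.prod(a_c[i + 1 : j + 1])
        y_diag = (Lmask * (C_c @ B_c.T)) @ x_c  # diagonal block: quadratic form
        decay = np.cumprod(a_c)                 # a_s ... a_{s+j} for each position j
        y_off = (C_c @ h) * decay               # off-diagonal blocks via the incoming state
        y[s:e] = y_diag + y_off
        h = decay[-1] * h + (Lmask[-1] * x_c) @ B_c  # state passing to the next chunk
    return y

rng = np.random.default_rng(0)
T, N = 10, 3
a = rng.uniform(0.5, 1.0, size=T)
B = rng.normal(size=(T, N))
C = rng.normal(size=(T, N))
x = rng.normal(size=T)
print(np.allclose(ssd_chunked(a, B, C, x, Q=4), scalar_ssm_recurrent(a, B, C, x)))  # True
```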
Structured Masked Attention
RetNet applies a decay factor that can be seen as a mask on [math]\displaystyle{ \mathtt{Q}\mathtt{K}^T }[/math]. Inspired by linear attention and RetNet, we can generalize the idea of applying a mask and introduce Structured Masked Attention (SMA). Definition: structured masked attention (SMA), or structured attention for short, is a function of queries, keys, and values Q, K, V together with any structured matrix L, defined through the 4-way tensor contraction
[math]\displaystyle{ Y = (L \circ Q K^\top) V }[/math]
We can use a special mask, the 1-semiseparable (1-SS) matrix, defined as follows:
[math]\displaystyle{ L = \begin{bmatrix} 1 & \\ a_1 & 1 & \\ a_2a_1 & a_2 & 1 \\ \vdots & \vdots & \ddots & \ddots \\ a_{\mathtt{T}-1}\dots a_1 & a_{\mathtt{T}-1}\dots a_2 & \dots & a_{\mathtt{T}-1} & 1 \\ \end{bmatrix} . }[/math]

State Space Duality (SSD)
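If we use the 1-SS matrix as the mask for SMA, and restrict the SSM's [math]\displaystyle{ A }[/math] to a scalar-times-identity structure (so that at each time [math]\displaystyle{ t }[/math], [math]\displaystyle{ A_t = a_t I }[/math] is determined by a single scalar [math]\displaystyle{ a_t }[/math]), then the two are the same computation, with the SSM's [math]\displaystyle{ C }[/math], [math]\displaystyle{ B }[/math], and input [math]\displaystyle{ X }[/math] playing the roles of [math]\displaystyle{ Q }[/math], [math]\displaystyle{ K }[/math], and [math]\displaystyle{ V }[/math]. This equivalence is called State Space Duality.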
We can view the SSD as:
1. A state space model with a structured (scalar-times-identity) state matrix
2. A generalized linear attention with a structured (1-SS) mask
Overall, it is an instance that has dual quadratic and linear forms that can be derived from either representation.
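The duality can be checked numerically with a short sketch. It computes the quadratic form [math]\displaystyle{ Y = (L \circ CB^\top)X }[/math] with the 1-SS mask [math]\displaystyle{ L }[/math], and compares it with the linear-time recurrence of an SSM whose state matrix is [math]\displaystyle{ A_t = a_t I }[/math]. The mapping [math]\displaystyle{ Q = C }[/math], [math]\displaystyle{ K = B }[/math], [math]\displaystyle{ V = X }[/math], the sizes, and the random values are assumptions of this illustration.

```python
import numpy as np

def one_ss_mask(a):
    """1-SS mask: L[j, i] = a_j a_{j-1} ... a_{i+1} for j >= i (1 on the diagonal), else 0."""
    T = len(a)
    L = np.zeros((T, T))
    for j in range(T):
        L[j, j] = 1.0
        for i in range(j - 1, -1, -1):
            L[j, i] = L[j, i + 1] * a[i + 1]
    return L

def structured_masked_attention(Q, K, V, L):
    """Quadratic (attention) form: Y = (L o Q K^T) V."""
    return (L * (Q @ K.T)) @ V

def scalar_identity_ssm(a, B, C, X):
    """Linear (recurrent) form with A_t = a_t I: H_t = a_t H_{t-1} + B_t x_t^T, y_t = H_t^T C_t."""
    T, N = B.shape
    H = np.zeros((N, X.shape[1]))      # state: one N-dim state per channel
    Y = np.empty_like(X)
    for t in range(T):
        H = a[t] * H + np.outer(B[t], X[t])
        Y[t] = C[t] @ H
    return Y

rng = np.random.default_rng(0)
T, N, P = 7, 3, 2
a = rng.uniform(0.5, 1.0, size=T)
B = rng.normal(size=(T, N))
C = rng.normal(size=(T, N))
X = rng.normal(size=(T, P))
Y_quadratic = structured_masked_attention(C, B, X, one_ss_mask(a))  # Q = C, K = B, V = X
Y_linear = scalar_identity_ssm(a, B, C, X)
print(np.allclose(Y_quadratic, Y_linear))  # True
```

The quadratic form costs [math]\displaystyle{ O(T^2) }[/math] but consists of matrix multiplications, while the recurrence costs [math]\displaystyle{ O(T) }[/math]. SSD's chunked algorithm mixes the two.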
This dual form lets SSD borrow extensions developed for attention, such as multi-head attention, grouped-value attention, and kernel approximations to softmax attention.

Mamba-2 architecture
Using SSD and a few small changes to Mamba's neural network architecture, we obtain the Mamba-2 architecture.
One change is parallel parameter projections. In Mamba-2, the SSD layer is viewed as a map from [math]\displaystyle{ A, X, B, C }[/math] to the output [math]\displaystyle{ Y }[/math], so it makes sense to produce [math]\displaystyle{ A, B, C, X }[/math] in parallel with a single projection at the beginning of the block.
Another change is an extra normalization layer at the end of the block, which can be any normalization such as LayerNorm, GroupNorm, or RMSNorm.
Key Takeaway
State Space Models (SSMs) offer a scalable alternative to Transformers in language modeling: they use state space equations to maintain and update a hidden state from which outputs are generated. Since their inception, there have been several important model enhancements:
Year | Model | Description |
---|---|---|
2022 | Structured State Space (S4) | Leveraged Diagonal Plus Low-Rank parameterization to efficiently handle long sequences but challenging to understand |
2022 | Diagonal State Spaces (DSS) | Simplified S4 by using diagonal matrices to achieve comparable performance |
2023 | Hungry Hungry Hippos (H3) | Introduced new SSM layer consisting of two SSMs stacked together (a shift and diagonal SSM) to enable token recall and comparison across sequences |
2024 | Mamba | Integrated selective state spaces to enable dynamic parameter adjustment based on input with linear-time complexity |
2024 | Mamba-2 | Used state space duality (SSD) to enable larger state dimensions and faster training |
Topic 8: Sparse Attention
Introduction
Vanilla attention is computationally expensive because it computes a score for every pair of tokens, which requires multiplying large matrices to produce an n×n score matrix. This results in a complexity of [math]\displaystyle{ O(n^2) }[/math], where the sequence length n may be very large, which can limit the scalability of this method for long sequences. Sparse attention addresses this problem by being more selective about which attention scores are computed. Intuitively, not all tokens need to attend to each other, because not every word provides semantic information about every other word in a sentence. How can we determine which token pairs are important? Sparse attention aims to answer this question, improving the efficiency and scalability of vanilla attention by skipping attention between some token pairs without sacrificing performance.
Four methods were presented:
- Sparse Sinkhorn Attention
- Big Bird Sparse Attention
- Attention with Linear Biases (ALiBi)
- SpAtten
Sparse Sinkhorn Attention
To address the computational cost of vanilla attention, especially for long input sequences, Sparse Sinkhorn Attention proposes three core ideas:
- [math]\displaystyle{ \mathbf{Block} }[/math] [math]\displaystyle{ \mathbf{Partitioning} }[/math]
- The input sequence with length [math]\displaystyle{ l }[/math] is split into [math]\displaystyle{ N_b }[/math] blocks, each containing [math]\displaystyle{ b }[/math] tokens.
- [math]\displaystyle{ \mathbf{Block} }[/math] [math]\displaystyle{ \mathbf{Sorting} }[/math]
- Input blocks are sorted based on their similarity to the query.
- The Sinkhorn algorithm is then applied to generate a permutation matrix which ensures that relevant blocks are adjacent, and this algorithm also normalizes the sorting matrix to prevent overemphasis on certain input blocks.
- [math]\displaystyle{ \mathbf{Sparse\ (local)} }[/math] [math]\displaystyle{ \mathbf{Attention} }[/math]
- Instead of attending over the full sequence, each token computes attention only with tokens within its own block (and within its neighbouring blocks after sorting).
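The attention step can be sketched in NumPy under simplifying assumptions. The block permutation `perm` is taken as a hard permutation given in advance (in the method it comes from the sorting network and Sinkhorn normalization). Each query block then attends to its own keys concatenated with the keys of the block that `perm` sorts into its position. The function names are hypothetical. Each query sees only `2b` keys, so the number of scores is [math]\displaystyle{ O(l \cdot b) }[/math] rather than [math]\displaystyle{ O(l^2) }[/math].

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    w = np.exp(z)
    return w / w.sum(axis=axis, keepdims=True)

def sorted_block_attention(Q, K, V, b, perm):
    """Block i attends to its own keys plus the keys of block perm[i].
    Q, K, V: (l, d) with l divisible by the block size b; perm: (N_b,) block permutation."""
    l, d = Q.shape
    n_b = l // b
    Qb, Kb, Vb = (M.reshape(n_b, b, d) for M in (Q, K, V))
    K_cat = np.concatenate([Kb, Kb[perm]], axis=1)          # (n_b, 2b, d)
    V_cat = np.concatenate([Vb, Vb[perm]], axis=1)
    scores = Qb @ K_cat.transpose(0, 2, 1) / np.sqrt(d)     # (n_b, b, 2b): O(l * b) scores
    return (softmax(scores) @ V_cat).reshape(l, d)

rng = np.random.default_rng(0)
l, d, b = 12, 4, 3
Q, K, V = (rng.normal(size=(l, d)) for _ in range(3))
out = sorted_block_attention(Q, K, V, b, perm=np.array([3, 2, 1, 0]))
print(out.shape)  # (12, 4)
```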

Sorting Network
The original idea of block-based local attention is to let tokens attend only to tokens within the same block. However, this restricts the global receptive field and limits the ability of local attention models to capture long-range dependencies. To mitigate this problem, the blocks are sorted by a neural network that outputs a permutation matrix. This Meta Sorting Network learns to sort sequences, enabling efficient quasi-global local attention. The Sorting Network is:
[math]\displaystyle{ R = P(X) = \sigma(W_B \sigma(W_pX+b_p) +b_B) }[/math]
Here, [math]\displaystyle{ R }[/math] is the sorting matrix, which Sinkhorn normalization turns into an approximate permutation matrix; [math]\displaystyle{ \sigma }[/math] is an activation function, and [math]\displaystyle{ W }[/math], [math]\displaystyle{ b }[/math] are the weights and biases. This is simply a two-layer fully-connected neural network.
Sinkhorn normalization with Gumbel noise
We normalize the rows and columns of the sorting matrix iteratively:
[math]\displaystyle{ S^0(R) = \exp\left(\frac{R+\epsilon}{\tau}\right) }[/math]
[math]\displaystyle{ S^k(R) = F_c(F_r(S^{k-1}(R))) }[/math]
where [math]\displaystyle{ F_r }[/math] and [math]\displaystyle{ F_c }[/math] are the row-wise and column-wise normalization functions, [math]\displaystyle{ \epsilon }[/math] is injected i.i.d. standard Gumbel noise, and [math]\displaystyle{ \tau }[/math] is the temperature hyperparameter.
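The following is a minimal NumPy sketch of Gumbel-Sinkhorn normalization. It adds Gumbel noise, divides by the temperature, and then alternates row and column normalization, working in log space for numerical stability. The random input stands in for the sorting network's output. The function name and the choice to add noise only when a random generator is passed are assumptions of this illustration.

```python
import numpy as np

def gumbel_sinkhorn(R, tau=1.0, n_iters=20, rng=None):
    """S^0 = exp((R + eps) / tau), then S^k = F_c(F_r(S^{k-1})), computed in log space.
    Gumbel noise eps is added only when an rng is given."""
    if rng is not None:
        u = rng.uniform(1e-9, 1.0, size=R.shape)
        R = R - np.log(-np.log(u))                  # eps ~ i.i.d. standard Gumbel
    log_S = R / tau
    for _ in range(n_iters):
        log_S = log_S - np.logaddexp.reduce(log_S, axis=1, keepdims=True)  # F_r: rows sum to 1
        log_S = log_S - np.logaddexp.reduce(log_S, axis=0, keepdims=True)  # F_c: columns sum to 1
    return np.exp(log_S)

rng = np.random.default_rng(0)
R = rng.normal(size=(5, 5))          # stands in for the sorting network output for 5 blocks
S = gumbel_sinkhorn(R, tau=0.5, n_iters=50, rng=rng)
print(S.sum(axis=0), S.sum(axis=1))  # both close to 1: approximately doubly stochastic
```

Lowering `tau` pushes the result toward a hard permutation matrix, while the noise makes the sorting stochastic during training.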
Advantages over Vanilla Attention
- Reduced memory
- Sparse Sinkhorn Attention only needs [math]\displaystyle{ O\left(\left(\frac{l}{N_b}\right)^2 + N_b^2\right) }[/math], while Vanilla Attention requires [math]\displaystyle{ O(l^2) }[/math].
- Reduced Computational Complexity
- SORTCUT (a Sparse Sinkhorn Attention variant) only requires [math]\displaystyle{ O(lN_k) }[/math], where [math]\displaystyle{ N_k }[/math] is a hyperparameter which is much smaller than [math]\displaystyle{ l }[/math]. In contrast, Vanilla Attention requires [math]\displaystyle{ O(l^2) }[/math] time.
- Ability to Capture Long Range Dependency
- Through block sorting, relevant tokens are grouped into the same or neighbouring blocks even if they are originally far apart, so local attention can then capture long-range dependencies. By comparison, vanilla attention must compute scores between every pair of tokens, which is expensive for long sequences, and because the softmax spreads its weight over all tokens, the signal from a few relevant distant tokens can be diluted, making long-range relationships harder to model both computationally and statistically.
Big Bird Sparse Attention
Intuitively, we picture the tokens as nodes in a directed graph, where attention between two tokens is calculated only if an edge connects their nodes. The adjacency matrix of this graph therefore acts as the attention mask: no attention is calculated between disconnected nodes.
In order to decide which nodes to 'connect,' the authors combined several methods:
- Random Attention
- Nodes are connected randomly (whether attention is calculated for a particular token pair is determined randomly)
- Window Attention
- Intuitively, tokens that are closer together probably provide more semantic information about each other. Therefore, attention is calculated between tokens that occur within a particular distance from each other.
- Global Attention
- Some tokens are 'celebrities': attention between them and all other tokens is calculated. This results in a maximum distance of 2 between any two tokens (nodes in the graph).
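Theoretically, the authors prove that any sparse attention graph containing a star graph (the graph formed by global attention) is a universal approximator of sequence-to-sequence functions, so global tokens preserve the expressiveness of full attention in this sense. Empirically, combining all three of the above methods into what the authors call 'BIGBIRD' results in higher model performance.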
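A minimal sketch that builds the combined adjacency matrix (attention mask) for a toy sequence, assuming the first `n_global` tokens are the global 'celebrity' tokens and each token gets `n_random` random neighbours. The helper name and parameters are hypothetical, and the real BigBird implementation works on blocks of tokens rather than single tokens:

```python
import random


def bigbird_mask(n, window=1, n_global=1, n_random=1, seed=0):
    """Boolean n x n mask: mask[i][j] is True if token i attends to token j."""
    rng = random.Random(seed)
    mask = [[False] * n for _ in range(n)]
    for i in range(n):
        # Window attention: neighbours within distance `window`.
        for j in range(max(0, i - window), min(n, i + window + 1)):
            mask[i][j] = True
        # Random attention: a few randomly chosen tokens per row.
        for j in rng.sample(range(n), n_random):
            mask[i][j] = True
    # Global attention: global tokens attend to, and are attended by, every token.
    for g in range(n_global):
        for j in range(n):
            mask[g][j] = True
            mask[j][g] = True
    return mask


mask = bigbird_mask(8)
density = sum(map(sum, mask)) / 64  # fraction of the 8 x 8 pairs that are computed
```

Every pair of tokens is connected through token 0 in at most two hops, which is the "maximum distance of 2" property described above.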
We may interpret the naming as:
- "Big": the ability to handle long sequences and overcome the quadratic complexity of full attention
- "Bird": a metaphor for a bird's-eye view, combining global tokens (the ability to observe the whole sequence), local sliding windows (the ability to focus on local context), and random attention (the ability to capture remaining details that the other two patterns miss)
Limitations
Although BigBird achieves linear complexity, it does not come for free.
- Lower Bounds
- The authors present a natural task that a full attention mechanism can solve with O(1) layers. However, any sparse attention mechanism with a polynomial number of inner product evaluations and a logarithmic shortest-path length between any two nodes of its directed graph (a class that includes more than just BigBird) needs at least a polynomial number of layers to solve the same task.
Attention with Linear Biases (ALiBi)
Transformer models currently struggle to extrapolate. Extrapolation is the ability to handle sequences at inference time that are longer than the sequences seen during training. For example, a model trained on a dataset whose longest sequence is 1024 tokens may perform poorly when asked to generate a sequence of 2048 tokens. If we can solve this problem, training can, in theory, become much more efficient, because we can train on shorter sequences without sacrificing performance on longer ones.
The authors hypothesized that the position representation method used in these models influences extrapolation ability. They first tested the extrapolation performance of transformers that use absolute position encodings, including sinusoidal position embeddings (as in the vanilla transformer) and rotary position embeddings (used in GPT-J, an open-source GPT-3-style model). They found that these models did not extrapolate well.
Then, they tested the extrapolation performance of a transformer that uses T5 bias as its position method. This is a relative position method: it does not add positional information to the word embeddings. Instead, after computing the query-key scores and before applying softmax, it adds a learned, shared bias that depends on the distance between the query and the key. This model extrapolated better, and the authors concluded that a relative position method improves extrapolation. However, T5 bias has a high computational cost, which makes it impractical.
ALiBi replaces the vanilla transformer [math]\displaystyle{ softmax(QK^T) }[/math] with [math]\displaystyle{ softmax(QK^T + bias) }[/math], adding a lower triangular matrix of biases in order to encode the relative positioning of tokens within the attention calculation.
The bias matrix is of the form [math]\displaystyle{ \begin{bmatrix} 0 & 0 & 0 & 0\\ -1 & 0 & 0 & 0\\ -2 & -1 & 0 & 0\\ -3 & -2 & -1 & 0\\ \end{bmatrix} }[/math]
The addition of this bias induces a recency bias: a query-key pair that is far apart receives a large negative bias, and this increasing penalty means less information is drawn from distant query-key pairs.
The authors hypothesized that this bias could replace the positional encoding in transformers. In models with multiple attention heads, each head multiplies the bias matrix by its own fixed slope [math]\displaystyle{ m }[/math]. For [math]\displaystyle{ n }[/math] heads, the slopes form a geometric sequence that starts at [math]\displaystyle{ 2^{-8/n} }[/math] and uses that same value as its ratio; for 8 heads this gives [math]\displaystyle{ \frac{1}{2}, \frac{1}{4}, \dots, \frac{1}{256} }[/math].
So even though ALiBi is considered a sparse attention method, it does not specify sparsity explicitly. As tokens get farther from one another, the bias becomes more negative and their attention weights shrink toward zero, so sparsity emerges implicitly.
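A minimal sketch of the bias matrix and the per-head slopes, assuming the number of heads is a power of two (the function names are illustrative). Positions above the diagonal are left at 0, as in the matrix above, because the causal mask removes them separately:

```python
def alibi_slopes(n_heads):
    """Geometric sequence of head slopes: start at 2^(-8/n) and multiply by the same value."""
    start = 2 ** (-8 / n_heads)
    return [start ** (h + 1) for h in range(n_heads)]


def alibi_bias(seq_len, slope):
    """Bias added to q_i . k_j: -slope * (i - j) for j <= i, and 0 above the diagonal."""
    return [
        [-slope * (i - j) if j <= i else 0.0 for j in range(seq_len)]
        for i in range(seq_len)
    ]
```

For head `h`, the entry `alibi_bias(L, alibi_slopes(n)[h])[i][j]` is added to the score of query `i` and key `j` before the softmax, so no positional embedding is added to the token embeddings.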
SpAtten
SpAtten is a method for pruning both tokens and attention heads in transformer models to improve efficiency. By removing tokens and heads with pruning, SpAtten reduces computational complexity and memory usage.
There are four key components to SpAtten:
1. Cascade Token Pruning
2. Cascade Head Pruning
3. Progressive Quantization
4. Specialized High Parallelism top-k Engine
The first two components are motivated by the inherent redundancy in human languages. Cascade token pruning identifies and removes unimportant tokens in a sentence; similarly, cascade head pruning eliminates unnecessary attention heads. Unlike traditional weight pruning techniques, cascade token and head pruning operate dynamically, selecting tokens and heads to prune on the fly during inference. For example, given the sentence "As a visual treat, the film is almost perfect." with 11 tokens and 12 heads, cascade pruning reduces it to "film perfect" with only 2 tokens and 8 heads. The result retains the essence of the original sentence while using less memory and computation to process.
The latter two components play a pivotal role in ensuring the efficient implementation of SpAtten. Progressive quantization is designed to strike a balance between computational efficiency and accuracy by dynamically adjusting the precision of computations. The process begins by handling only the most significant bits (MSBs) of the input data, enabling faster computations and reduced memory access. If the confidence in the result is insufficient, the system progressively incorporates more bits of data to recompute the output with higher precision. This approach is further refined by assigning different bitwidths to different layers, optimizing performance across the model.
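A toy sketch of the progressive quantization idea, assuming a uniform symmetric quantizer, a 4-bit "MSB" pass, an 8-bit fallback, and "confidence" measured as the largest attention probability. All of these choices, and the function names, are illustrative assumptions rather than SpAtten's hardware design:

```python
import math


def quantize(x, n_bits, x_max):
    """Uniform symmetric quantization of x (clipped to [-x_max, x_max]) to n_bits."""
    levels = 2 ** (n_bits - 1) - 1
    step = x_max / levels
    return round(max(-x_max, min(x_max, x)) / step) * step


def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def progressive_softmax(scores, msb_bits=4, full_bits=8, threshold=0.1, x_max=8.0):
    """Try low precision first; recompute with more bits if the result is not confident.

    Returns (probabilities, bits_used).
    """
    probs = softmax([quantize(s, msb_bits, x_max) for s in scores])
    if max(probs) >= threshold:
        return probs, msb_bits
    # A flat distribution is more sensitive to quantization error: fetch more bits.
    probs = softmax([quantize(s, full_bits, x_max) for s in scores])
    return probs, full_bits
```

A peaked score distribution is accepted after the low-precision pass, while a flat one triggers the higher-precision recomputation.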
Similarly, SpAtten employs a specially designed top-k engine to enable real-time decision-making for pruning tokens and attention heads. This engine efficiently ranks token and head importance scores in linear time, avoiding the inefficiencies of traditional sorting methods. Instead of sorting an entire array, the top-k engine uses a quick-select module to identify the kth largest element and filters the array based on this threshold. This technique is also implemented in a highly parallelized manner, with multiple comparators working simultaneously to achieve linear-time ranking. This streamlined process ensures rapid and accurate pruning decisions, contributing to SpAtten's overall efficiency.
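A sequential software sketch of the quick-select-then-filter idea: find the k-th largest score in expected linear time, then keep every element at or above it without sorting. The hardware engine performs the comparisons in parallel; the function names and the tie handling here are assumptions:

```python
import random


def quickselect(values, k, rng):
    """Return the k-th largest element (1-indexed) in expected linear time."""
    pivot = rng.choice(values)
    larger = [v for v in values if v > pivot]
    equal = [v for v in values if v == pivot]
    if k <= len(larger):
        return quickselect(larger, k, rng)
    if k <= len(larger) + len(equal):
        return pivot
    smaller = [v for v in values if v < pivot]
    return quickselect(smaller, k - len(larger) - len(equal), rng)


def top_k_indices(scores, k, seed=0):
    """Indices of the k highest scores, in original order; ties at the threshold go by position."""
    threshold = quickselect(list(scores), k, random.Random(seed))
    above = [i for i, s in enumerate(scores) if s > threshold]
    ties = [i for i, s in enumerate(scores) if s == threshold]
    keep = set(above + ties[: k - len(above)])
    return [i for i in range(len(scores)) if i in keep]
```

For token pruning, `scores` would hold per-token importance values and the returned indices are the tokens that survive to the next layer.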
Topic 10: Linear Attention
Introduction
Linear attention tries to address the efficiency limitations of traditional softmax attention. Standard attention (also called vanilla attention) is calculated as:
[math]\displaystyle{ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V }[/math]
The computational complexity of this method is [math]\displaystyle{ O(n^2d) }[/math] for sequence length [math]\displaystyle{ n }[/math] and representation dimension [math]\displaystyle{ d }[/math], mainly because multiplying [math]\displaystyle{ Q }[/math] by [math]\displaystyle{ K^T }[/math] produces a large [math]\displaystyle{ n \times n }[/math] attention matrix. If we remove the softmax function from this equation, the matrix multiplication can be reordered as [math]\displaystyle{ Q(K^T V) }[/math]. Computing [math]\displaystyle{ K^T V }[/math] first produces only a [math]\displaystyle{ d \times d }[/math] matrix, so the cost drops to [math]\displaystyle{ O(nd^2) }[/math], which is much smaller than [math]\displaystyle{ O(n^2d) }[/math] when [math]\displaystyle{ n \gg d }[/math]. However, this purely linear mapping is insufficient to capture the complex relationships in the data. To reintroduce expressiveness, we can replace the softmax function with a kernel that approximates it. Recall that a (positive-definite) kernel [math]\displaystyle{ \kappa }[/math] can be written as [math]\displaystyle{ \kappa(x, y)= \Phi(x)^T \Phi(y) }[/math] for some feature map [math]\displaystyle{ \Phi }[/math]. If we apply the feature map to the rows of Q and K, we obtain [math]\displaystyle{ \Phi(Q)(\Phi(K)^T V) }[/math], which approximates vanilla attention with a much more efficient mechanism.
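To make the reordering concrete, here is a pure-Python sketch comparing quadratic kernel attention with its linear-time reordering. It assumes the feature map Φ(x) = elu(x) + 1 (one common choice in the linear attention literature, not specified above) and includes the row normalization that replaces the softmax denominator; all function names are illustrative:

```python
import math


def feature_map(x):
    """Phi(x) = elu(x) + 1, which keeps every feature positive."""
    return [xi + 1.0 if xi > 0 else math.exp(xi) for xi in x]


def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))


def kernel_attention_quadratic(Q, K, V):
    """O(n^2 d): build every Phi(q_i) . Phi(k_j) weight explicitly, then normalize."""
    out = []
    for q in Q:
        weights = [dot(feature_map(q), feature_map(k)) for k in K]
        total = sum(weights)
        out.append([sum(w * v[c] for w, v in zip(weights, V)) / total for c in range(len(V[0]))])
    return out


def kernel_attention_linear(Q, K, V):
    """O(n d^2): compute Phi(K)^T V and sum_j Phi(k_j) once, then reuse them for every query."""
    d_k, d_v = len(K[0]), len(V[0])
    kv = [[0.0] * d_v for _ in range(d_k)]  # Phi(K)^T V, a d_k x d_v matrix
    k_sum = [0.0] * d_k                      # sum_j Phi(k_j), used for normalization
    for k, v in zip(K, V):
        fk = feature_map(k)
        for a in range(d_k):
            k_sum[a] += fk[a]
            for c in range(d_v):
                kv[a][c] += fk[a] * v[c]
    out = []
    for q in Q:
        fq = feature_map(q)
        denom = dot(fq, k_sum)
        out.append([sum(fq[a] * kv[a][c] for a in range(d_k)) / denom for c in range(d_v)])
    return out
```

Both functions return the same result up to floating-point error; only the order of multiplication, and therefore the cost, differs.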
Key Approaches to Linear Attention
Retentive Network (RetNet): A Successor to Transformer for Large Language Models

The authors of this paper introduce a new architecture that achieves parallel training and fast inference without compromising performance, addressing the inherent complexity of the vanilla attention mechanism. They claim that it makes the "impossible triangle" of training parallelism, low-cost inference, and strong performance possible. RetNet combines the performance and parallel training of Transformers with the low-cost inference of RNNs by introducing a multi-scale retention mechanism that replaces multi-head attention and supports three representations:
- Parallel: allows full use of GPU parallelism during training.
- Recurrent: enables low-cost inference with [math]\displaystyle{ O(1) }[/math] computation and memory per generated token.
- Chunkwise recurrent: models long-sequences efficiently.
Essentially, RetNet replaces the [math]\displaystyle{ softmax }[/math] by adopting a linear representation of the attention mechanism in the form of retention. We can show that the retention mechanism has a dual form of recurrence and parallelism (Figure 2). First, let's show RetNet in its recurrent form:

[math]\displaystyle{ S_n = \gamma S_{n-1} + K_n^T V_n }[/math]
[math]\displaystyle{ \text{Retention}(X_n) = Q_n S_n }[/math]
In this formula, [math]\displaystyle{ \gamma }[/math] is a decay factor, and [math]\displaystyle{ Q }[/math], [math]\displaystyle{ K }[/math], [math]\displaystyle{ V }[/math] are learned projections of the input. During training, this can be computed in parallel form. Now, we can show the parallel representation of retention:
[math]\displaystyle{ \text{Retention}(X) = (QK^T \odot D)V }[/math]
where [math]\displaystyle{ D }[/math] is a matrix that combines causal masking and exponential decay along relative distance. To better understand the role of [math]\displaystyle{ D }[/math], let's look at its definition:
[math]\displaystyle{ D_{nm} = \begin{cases} \gamma^{n-m}, & n \geq m \\ 0, & n \lt m \end{cases} }[/math]
This formulation ensures that tokens can only attend to previous tokens (causal masking) and that the attention strength decays exponentially with distance (controlled by the parameter [math]\displaystyle{ \gamma }[/math]). In other words, as we move forward, less attention is paid to earlier tokens.
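The recurrent and parallel forms above should give identical outputs. A minimal single-head sketch in pure Python checks this, assuming Q, K, V are already-projected lists of vectors. The function names are illustrative, and the full model's multi-scale heads, normalization, and positional rotation are omitted:

```python
def retention_parallel(Q, K, V, gamma):
    """(Q K^T ⊙ D) V with D[n][m] = gamma^(n - m) for n >= m, else 0."""
    d_v = len(V[0])
    out = []
    for n in range(len(Q)):
        row = [0.0] * d_v
        for m in range(n + 1):  # D is zero for m > n (causal mask)
            w = gamma ** (n - m) * sum(q * k for q, k in zip(Q[n], K[m]))
            for c in range(d_v):
                row[c] += w * V[m][c]
        out.append(row)
    return out


def retention_recurrent(Q, K, V, gamma):
    """S_n = gamma * S_{n-1} + K_n^T V_n and output_n = Q_n S_n, with a fixed-size state."""
    d_k, d_v = len(K[0]), len(V[0])
    S = [[0.0] * d_v for _ in range(d_k)]
    out = []
    for q, k, v in zip(Q, K, V):
        # Decay the old state, then add the outer product K_n^T V_n.
        S = [[gamma * S[a][c] + k[a] * v[c] for c in range(d_v)] for a in range(d_k)]
        out.append([sum(q[a] * S[a][c] for a in range(d_k)) for c in range(d_v)])
    return out
```

The recurrent version keeps only the d_k x d_v state `S`, which is why each inference step costs the same regardless of how many tokens came before.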
The third representation, chunkwise recurrent, is a hybrid of the recurrent and parallel representations aimed at accelerating training on long sequences. The input sequence is divided into chunks; within each chunk, the parallel representation processes all tokens simultaneously. To pass information across chunks, the recurrent representation summarizes each chunk into a state that is passed to the next. Combining the inner-chunk (parallel) and cross-chunk (recurrent) information produces the retention of the chunk, which captures both the details within the chunk and the context from previous chunks. To compute the retention of the i-th chunk, let [math]\displaystyle{ B }[/math] be the chunk length; then:
[math]\displaystyle{ Q_{[i]} = Q_{Bi:B(i+1)}, K_{[i]} = K_{Bi:B(i+1)}, V_{[i]} = V_{Bi:B(i+1)} }[/math]
[math]\displaystyle{ R_i = K_{[i]}^T (V_{[i]} \odot \zeta) + \gamma^B R_{i-1}, \quad \zeta_{ij} = \gamma^{B-i-1} }[/math]
[math]\displaystyle{ {Retention}(X_{[i]}) = \underbrace{(Q_{[i]} K_{[i]}^T \odot D) V_{[i]}}_{Inner-Chunk} + \underbrace{(Q_{[i]} R_{i-1}) \odot \xi}_{Cross-Chunk}, \quad \xi_{i,j} = \gamma^{i+1} }[/math]
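where [math]\displaystyle{ R_i }[/math] is the state of the current chunk to be passed to the next (with [math]\displaystyle{ R_{-1} = 0 }[/math] for the first chunk). In [math]\displaystyle{ \zeta_{ij} }[/math] and [math]\displaystyle{ \xi_{ij} }[/math], the subscript [math]\displaystyle{ i }[/math] indexes the row (token position) within the chunk rather than the chunk itself.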
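A minimal single-head sketch of the chunkwise formulas above in pure Python, assuming the sequence length is a multiple of B and starting from a zero state (the name `retention_chunkwise` is hypothetical). For any chunk length, its output should match the parallel form:

```python
def retention_chunkwise(Q, K, V, gamma, B):
    """Chunkwise retention; assumes len(Q) is a multiple of the chunk length B."""
    assert len(Q) % B == 0
    d_k, d_v = len(K[0]), len(V[0])
    R = [[0.0] * d_v for _ in range(d_k)]  # R_{i-1}: state carried in from earlier chunks
    out = []
    for start in range(0, len(Q), B):
        Qc, Kc, Vc = Q[start:start + B], K[start:start + B], V[start:start + B]
        for r in range(B):
            row = [0.0] * d_v
            # Inner-chunk term: (Q K^T ⊙ D) V restricted to this chunk.
            for m in range(r + 1):
                w = gamma ** (r - m) * sum(q * k for q, k in zip(Qc[r], Kc[m]))
                for c in range(d_v):
                    row[c] += w * Vc[m][c]
            # Cross-chunk term: (Q R_{i-1}) ⊙ xi, with xi = gamma^(r + 1) for row r.
            for c in range(d_v):
                row[c] += gamma ** (r + 1) * sum(Qc[r][a] * R[a][c] for a in range(d_k))
            out.append(row)
        # R_i = K^T (V ⊙ zeta) + gamma^B R_{i-1}, with zeta = gamma^(B - j - 1) for row j.
        R = [
            [
                sum(Kc[j][a] * Vc[j][c] * gamma ** (B - j - 1) for j in range(B))
                + gamma ** B * R[a][c]
                for c in range(d_v)
            ]
            for a in range(d_k)
        ]
    return out
```

With B equal to the full length this reduces to the parallel form, and with B = 1 it behaves like the recurrent form, which is the trade-off the chunk length controls.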
In summary, the overall architecture of RetNet consists of a stack of [math]\displaystyle{ L }[/math] identical blocks, similar to a Transformer. Each block consists of two modules: a multi-scale retention (MSR) module and a feed-forward network (FFN) module. It takes the embeddings of the input sequence [math]\displaystyle{ X^0= [x_1, \dots, x_{|x|}] \in \mathbb R^{|x| \times d_{model}} }[/math] and computes the output [math]\displaystyle{ X^L }[/math]:
[math]\displaystyle{ Y^l = MSR(LN(X^l)) + X^l }[/math]
[math]\displaystyle{ X^{l+1} = FFN(LN(Y^l)) + Y^l }[/math]
where [math]\displaystyle{ LN(.) }[/math] is LayerNorm.
Finally, comparing RetNet to other solutions, we can see that it provides a solid alternative to the Transformer achieving parallel training, fast inference, linear long-sequence memory complexity, and good performance.
Limitations & Future Work
- Limited Cross-Modal Exploration
- The method is primarily validated on language modeling tasks. Its applicability to other modalities (e.g., vision, audio) remains unexplored, requiring further research for broader adoption.
Simple linear attention language models balance the recall-throughput tradeoff
This paper focuses on the task of associative recall and discusses the memory-recall tradeoff: higher recall requires higher memory consumption, as shown in Figure 2.

The authors' goal was to develop a method that improves this tradeoff. They do so by stacking linear attention blocks with sliding window attention blocks. Recall that one of the limitations of linear attention is its relatively poor performance compared to vanilla attention. With this architecture, linear attention models global dependencies (at some cost in quality), while sliding window attention retains high performance on local dependencies within a fixed window. As seen in Figure 4, they achieve this goal.
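The local half of this design can be sketched as causal softmax attention restricted to the last `window` tokens, which costs O(n · window · d) instead of O(n^2 d). A minimal pure-Python sketch; the function name and the 1/√d scaling are assumptions, and the global linear attention blocks are not shown:

```python
import math


def sliding_window_attention(Q, K, V, window):
    """Causal softmax attention in which token i only sees tokens i - window + 1 .. i."""
    out = []
    for i, q in enumerate(Q):
        lo = max(0, i - window + 1)
        scale = 1.0 / math.sqrt(len(q))
        scores = [scale * sum(a * b for a, b in zip(q, K[j])) for j in range(lo, i + 1)]
        m = max(scores)  # subtract the max for numerical stability
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        out.append([
            sum(e * V[j][c] for e, j in zip(exps, range(lo, i + 1))) / total
            for c in range(len(V[0]))
        ])
    return out
```

The memory needed per token is bounded by `window`, so the global context has to come from the linear attention blocks stacked alongside it.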
Gated Linear Attention (GLA)
RetNet shows that we can achieve parallel training, efficient inference, and strong performance. However, despite its innovations, RetNet still has significant limitations. The biggest problem is how it handles sequential data: RetNet uses a constant decay factor ([math]\displaystyle{ \gamma }[/math]) that does not adapt to the content being processed. Imagine a conversation in which every word five positions back receives the same weight, regardless of how significant it is. That is what RetNet does with its constant exponential decay factor. In other words, RetNet treats all sequential dependencies the same way, regardless of whether the current content requires more emphasis on recent tokens or on tokens much farther back. This can work, but not necessarily optimally.
Furthermore, while RetNet is theoretically efficient, it is not as efficient in practice on real hardware. Its algorithms are not I/O-aware: they do not account for how data moves between levels of memory in modern GPUs, which leads to suboptimal performance. Gated Linear Attention (GLA) advances linear attention by addressing these key limitations.
The key idea of GLA is the use of data-dependent gates. Unlike RetNet, which uses a fixed decay factor, these gates adjust dynamically based on the content being processed, which makes the model more expressive. A gate acts like an intelligent filter that decides, at each position in a sequence, how much information should pass through depending on its relevance. For example, when processing the phrase "Prime Minister" in "The Prime Minister of Canada," the model might focus more on "Canada" than on "The". RetNet's fixed decay cannot make this kind of distinction. GLA introduces a gating mechanism in which the output is determined by the formula below:
[math]\displaystyle{ \text{Output} = \text{LinearAttention} \odot \text{ContentGate} }[/math]
The ContentGate is derived directly from the input itself and this modification enhances the model’s ability to capture complex dependencies in text.
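The formula above gates the output. In GLA, data-dependent gates also take the place of RetNet's constant γ inside the recurrence, so each step can decide how much of the old state to keep. A minimal sketch of that recurrent form, assuming one decay value per key dimension produced by a sigmoid of the current input (`gates_from_input` and `W_alpha` are hypothetical; the paper's exact gate parameterization is not reproduced here):

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def gated_linear_attention(Q, K, V, alphas):
    """S_t = diag(alpha_t) S_{t-1} + k_t^T v_t and o_t = q_t S_t.

    alphas[t] holds one decay value in (0, 1) per key dimension for step t.
    """
    d_k, d_v = len(K[0]), len(V[0])
    S = [[0.0] * d_v for _ in range(d_k)]
    out = []
    for q, k, v, alpha in zip(Q, K, V, alphas):
        # Each row of the state decays by its own data-dependent gate value.
        S = [[alpha[a] * S[a][c] + k[a] * v[c] for c in range(d_v)] for a in range(d_k)]
        out.append([sum(q[a] * S[a][c] for a in range(d_k)) for c in range(d_v)])
    return out


def gates_from_input(X, W_alpha):
    """Hypothetical gate: alpha_t = sigmoid(x_t W_alpha), computed separately for each token."""
    return [
        [sigmoid(sum(x[i] * W_alpha[i][a] for i in range(len(x)))) for a in range(len(W_alpha[0]))]
        for x in X
    ]
```

With every gate fixed at the same γ, this reduces to the RetNet recurrence; with gates computed from the input, the decay varies from token to token.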
The other important feature of GLA is its FLASHLINEARATTENTION algorithm, which is designed specifically for modern GPU architectures. Beyond theoretical improvements, this optimization delivers real-world speedups, outperforming even highly optimized methods like FLASHATTENTION-2. What makes FLASHLINEARATTENTION different is that it is I/O-aware, meaning it efficiently manages data movement between:
- High-bandwidth memory (HBM): Large but slower GPU memory.
- Shared memory (SRAM): Smaller but significantly faster memory.
By minimizing unnecessary data transfers and using specialized compute units (like tensor cores), this method achieves remarkable speed and efficiency. In benchmarks, it even surpasses FLASHATTENTION-2 on relatively short sequences of just 1,000 tokens.
Limitations & Future Work
- Lack of Large-Scale Validation
- Although the authors anticipate that the training efficiency of GLA becomes even more favorable compared to models like Mamba at larger scales, it is unclear how GLA would scale to even larger models and datasets.
- The experiments were conducted on moderate-scale language models (up to 1.3B parameters). The performance and efficiency of GLA at larger scales (e.g., >7B parameters) remain uncertain, particularly regarding tensor parallelism and memory constraints.
- Performance Gap in Recall-Intensive Tasks
- While GLA outperforms other subquadratic models in recall-intensive tasks, it still lags behind standard quadratic attention Transformers, indicating room for improvement in tasks requiring precise memory retrieval.
TRANSNORMERLLM: A Faster and Better Large Language Model with Improved Transformer
The objective of the paper is to address the quadratic time complexity and scalability limitations of conventional transformer-based LLMs by introducing a linear attention-based architecture for improved efficiency and performance.
Key Contributions
- Linear Attention Architecture
- Replaces softmax attention with linear attention, reducing computational complexity from quadratic to linear in sequence length.
- Introduces Lightning Attention, a block-wise computation technique that accelerates training (2× faster runtime, 4× memory reduction) by minimizing I/O operations.
- Architectural Improvements
- Positional Encoding: Combines LRPE (Linearized Relative Positional Encoding) with exponential decay to balance global interactions and avoid attention dilution.
- Gating Mechanisms: Uses Gated Linear Attention (GLA) and Simplified Gated Linear Units (SGLU) to stabilize training and enhance performance.
- Tensor Normalization: Replaces RMSNorm with SRMSNorm, a simpler normalization method that accelerates training without performance loss.
- Training Optimization
Topic 11: Flash Attention
Flash Attention addresses the resource intensity of vanilla attention by using hardware more effectively. Recent advances in computational speed (particularly in parallelization) have made arithmetic less of a bottleneck, so memory access has become the larger one. Specifically, during training, standard attention reads and writes large intermediate matrices (such as the n × n attention scores) to and from HBM multiple times, and HBM is much slower than on-chip SRAM. As a result, training takes a lot of time and resources. The goal of Flash Attention is thus to remove these inefficiencies in order to reduce training time for LLMs.
Flash Attention 1
Flash Attention 1 takes advantage of SRAM, a form of on-chip memory that is more than 10 times faster than HBM but can only hold about 20 MB of data (on an A100 GPU). Therefore, the authors propose block-wise computation, i.e., dividing the matrices into smaller blocks and performing computations one block at a time in SRAM. This allows the GPU to load all the data needed for a single computation block into SRAM, perform the computation, and write the output back to HBM. The combination of faster memory and fewer reads/writes to HBM speeds up training significantly.
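Block-wise computation needs a way to apply softmax without ever holding a full row of scores. FlashAttention does this by keeping a running maximum and running sum and rescaling earlier partial results whenever the maximum grows (online softmax). A minimal single-query pure-Python sketch of that rescaling; Python lists stand in for SRAM tiles, so this shows the math rather than the speedup, and the 1/√d scaling is omitted:

```python
import math


def attention_reference(q, K, V):
    """Standard softmax attention for a single query vector."""
    scores = [sum(a * b for a, b in zip(q, k)) for k in K]
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    return [sum(wi * v[c] for wi, v in zip(w, V)) / sum(w) for c in range(len(V[0]))]


def attention_tiled(q, K, V, block_size):
    """Visit K/V one block at a time, keeping only running statistics (online softmax)."""
    d_v = len(V[0])
    running_max = float("-inf")  # largest score seen so far
    running_sum = 0.0            # sum of exp(score - running_max) so far
    acc = [0.0] * d_v            # unnormalized weighted sum of values so far
    for start in range(0, len(K), block_size):
        Kb, Vb = K[start:start + block_size], V[start:start + block_size]
        scores = [sum(a * b for a, b in zip(q, k)) for k in Kb]
        new_max = max(running_max, max(scores))
        # Rescale earlier statistics to the new maximum before adding this block.
        correction = math.exp(running_max - new_max)
        running_sum *= correction
        acc = [x * correction for x in acc]
        for s, v in zip(scores, Vb):
            p = math.exp(s - new_max)
            running_sum += p
            for c in range(d_v):
                acc[c] += p * v[c]
        running_max = new_max
    return [x / running_sum for x in acc]
```

The tiled result matches the reference for any block size, which is what lets each block be processed entirely in fast memory before moving on.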
Additionally, the authors propose Block-Sparse Flash Attention where some blocks are fully dropped (thus not computed) which speeds up training even more, due to the reduction in amount of computation required. This method is reminiscent of Sparse Attention methods.