stat946W25
Record your contributions here [1]
Use the following notations:
C: You have written a summary/critique on the paper.
Your feedback on presentations
Topic 12: State Space Models
Introduction
State Space Models (SSMs) have been introduced as powerful alternatives to traditional sequence modeling approaches. They perform well across various modalities, including time series analysis, audio generation, and image processing, and they can capture long-range dependencies more efficiently. However, SSMs initially struggled to match the performance of Transformers on language modeling tasks. To close this gap, architectural advances such as the Structured State Space Model (S4) were introduced. S4 succeeded on long-range reasoning tasks and allowed more efficient computation while preserving the theoretical strengths of SSMs. However, its implementation remains complex and computationally demanding, so further research led to simplified variants such as the Diagonal State Space Model (DSS), which achieves comparable performance with a more straightforward formulation. In parallel, hybrid approaches such as the H3 model combine SSMs with attention mechanisms to bridge the remaining gap; for example, H3 replaces almost all of the attention layers in a Transformer with SSM-based layers. More recently, models like Mamba have pushed the boundaries of SSMs by making key SSM parameters functions of the input, which allows more flexible and adaptive information propagation. As research continues to resolve the remaining challenges, the case for replacing attention-based architectures with SSMs grows stronger, and SSMs will likely play a crucial role in the next generation of sequence modeling frameworks.
Core concepts
To understand State Space Models better, let's first look at the problems with Transformers and Recurrent Neural Networks (RNNs) and how both relate to SSMs. During training, a Transformer builds a matrix comparing each token with every token that came before it, where the weight in each cell indicates how similar the two tokens are. Computing these weights does not require a sequential operation: the relation between each pair of tokens can be computed in parallel. However, during inference, each newly generated token must attend to the entire preceding sequence, even though the earlier tokens have already been generated. As a result, for a sequence of length [math]\displaystyle{ L }[/math], the cost of attention grows as [math]\displaystyle{ O(L^2) }[/math], which is a bottleneck for long sequences in Transformers. An RNN, by contrast, takes two inputs at each time step: the hidden state of the previous time step ([math]\displaystyle{ h_{t-1} }[/math]) and the input at the current time step ([math]\displaystyle{ x_t }[/math]). From these it produces the current hidden state ([math]\displaystyle{ h_t }[/math]) and the output [math]\displaystyle{ y_t }[/math]. Because it does not need to revisit all previous tokens, inference runs in time linear in the sequence length. On the other hand, a well-known problem with RNNs is that they tend to forget early tokens as the sequence gets longer. A simple RNN updates its hidden state using the formula below:

[math]\displaystyle{
h(t) = \sigma(W_h h(t-1) + W_x x(t))
}[/math]
[math]\displaystyle{ y(t) = W_y h(t) }[/math]
Where:
- [math]\displaystyle{ h(t) }[/math] is the hidden state at time t
- [math]\displaystyle{ x(t) }[/math] is the input
- [math]\displaystyle{ y(t) }[/math] is the output
- [math]\displaystyle{ W_h }[/math], [math]\displaystyle{ W_x }[/math], and [math]\displaystyle{ W_y }[/math] are weight matrices
- [math]\displaystyle{ \sigma }[/math] is a non-linear activation function
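As a minimal sketch of this update rule (assuming NumPy, tanh as the activation [math]\displaystyle{ \sigma }[/math], and hypothetical random weights and dimensions), the loop below processes one token per step, which is why RNN inference cost grows linearly with the sequence length:

```python
import numpy as np

def rnn_forward(x_seq, W_h, W_x, W_y, h0=None):
    """Simple RNN: h(t) = tanh(W_h h(t-1) + W_x x(t)), y(t) = W_y h(t).

    x_seq has shape (L, input_dim). tanh stands in for sigma (an assumption).
    Returns the outputs (L, output_dim) and hidden states (L, hidden_dim).
    """
    h = np.zeros(W_h.shape[0]) if h0 is None else h0
    hs, ys = [], []
    for x_t in x_seq:                       # one step per token: O(L) work in total
        h = np.tanh(W_h @ h + W_x @ x_t)    # only the previous state is needed
        hs.append(h)
        ys.append(W_y @ h)
    return np.array(ys), np.array(hs)

# Hypothetical toy sizes and random weights
rng = np.random.default_rng(0)
L, d_in, d_h, d_out = 5, 3, 4, 2
W_h = rng.normal(scale=0.5, size=(d_h, d_h))
W_x = rng.normal(size=(d_h, d_in))
W_y = rng.normal(size=(d_out, d_h))
x_seq = rng.normal(size=(L, d_in))
ys, hs = rnn_forward(x_seq, W_h, W_x, W_y)
print(ys.shape)  # (5, 2)
```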
State Space Models come from control theory, where they are used to represent a system mathematically and describe its possible states. They define a linear mapping from an input signal x(t) to an output signal y(t) through a latent state representation h(t), and are formulated as:
[math]\displaystyle{
h'(t) = \mathbf{A} h(t) + \mathbf{B} x(t)
}[/math]
[math]\displaystyle{ y(t) = \mathbf{C} h(t) + \mathbf{D} x(t) }[/math]
Where:
- [math]\displaystyle{ h(t) }[/math] represents the hidden state
- [math]\displaystyle{ x(t) }[/math] is the input
- [math]\displaystyle{ y(t) }[/math] is the output
- [math]\displaystyle{ \mathbf{A}, \mathbf{B}, \mathbf{C}, }[/math] and [math]\displaystyle{ \mathbf{D} }[/math] are parameter matrices
By solving these equations, we can predict the output sequence corresponding to a given input sequence and previous state. Our goal is therefore to find [math]\displaystyle{ h(t) }[/math] such that we can map the input sequence to the correct output sequence. However, these equations are defined on continuous functions, and finding h(t) analytically is challenging. Moreover, the inputs in our problems are usually discrete sequences. We therefore discretize the equations (the details are explained in the next section), which finally gives:
[math]\displaystyle{
h_t = \bar{\mathbf{A}} h_{t-1} + \bar{\mathbf{B}} x_t
}[/math]
[math]\displaystyle{ y_t = \mathbf{C} h_t }[/math]
Where:
[math]\displaystyle{ \mathbf{\bar A} }[/math] and [math]\displaystyle{ \mathbf{\bar B} }[/math] are the discrete counterparts of [math]\displaystyle{ \mathbf{A} }[/math] and [math]\displaystyle{ \mathbf{B} }[/math] that advance the state over one timestep [math]\displaystyle{ \Delta }[/math], obtained with some discretization method. For example, using the trapezoidal rule:
- [math]\displaystyle{ \mathbf{\bar A} = (I - \frac{\Delta}{2}A)^{-1}(I + \frac{\Delta}{2}A) }[/math]
- [math]\displaystyle{ \mathbf{\bar B} = (I - \frac{\Delta}{2}A)^{-1}\Delta B }[/math]
See the Discretization section below for further details.
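A minimal NumPy sketch of the discretized SSM, assuming a single scalar input channel, a zero initial state, and the D term omitted as in the equations above. The function names `discretize_bilinear` and `ssm_scan` and the toy matrices are hypothetical, chosen only for illustration:

```python
import numpy as np

def discretize_bilinear(A, B, delta):
    """Trapezoidal (bilinear) discretization of h'(t) = A h(t) + B x(t).

    Returns A_bar = (I - delta/2 A)^{-1} (I + delta/2 A) and
            B_bar = (I - delta/2 A)^{-1} delta B.
    A has shape (N, N); B has shape (N,) for one scalar input channel.
    """
    I = np.eye(A.shape[0])
    left = I - (delta / 2) * A
    # Solve linear systems instead of forming the inverse explicitly
    A_bar = np.linalg.solve(left, I + (delta / 2) * A)
    B_bar = np.linalg.solve(left, delta * B)
    return A_bar, B_bar

def ssm_scan(A_bar, B_bar, C, x_seq):
    """Run h_t = A_bar h_{t-1} + B_bar x_t, y_t = C h_t (no D term), with h_0 = 0."""
    h = np.zeros(A_bar.shape[0])
    ys = []
    for x_t in x_seq:
        h = A_bar @ h + B_bar * x_t
        ys.append(C @ h)
    return np.array(ys)

# Hypothetical toy parameters: 2-dimensional state, scalar input and output
A = np.array([[-0.5, 1.0],
              [0.0, -1.0]])
B = np.array([1.0, 1.0])
C = np.array([1.0, 0.0])
A_bar, B_bar = discretize_bilinear(A, B, delta=0.1)
impulse = np.array([1.0, 0.0, 0.0, 0.0])
print(ssm_scan(A_bar, B_bar, C, impulse))
```

Feeding a unit impulse as input returns [math]\displaystyle{ C\bar B, C\bar A\bar B, C\bar A^2\bar B, \ldots }[/math], which follows directly from unrolling the recurrence.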
Comparing these formulations with those we defined for RNNs shows that they are similar: an RNN is essentially a non-linear extension of a state space model. The main differences are:
- SSMs are linear transformations between states, while RNNs apply non-linearity through the activation function
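- SSMs come from control theory, where the matrices are typically derived from physics equations, whereas in machine learning we learn these matrices from data
- In SSMs, the output equation has a term [math]\displaystyle{ \mathbf{D} x(t) }[/math] (written D u(t) in control notation, where u is the input), which is commonly left out in control problems; it is also dropped in the discretized equations above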
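The first difference (linearity) can be checked numerically. This sketch, using hypothetical random matrices, runs the same recurrence with and without a tanh non-linearity and tests superposition: for the linear SSM the response to x1 + x2 equals the sum of the individual responses, while the tanh RNN generally breaks this property.

```python
import numpy as np

rng = np.random.default_rng(1)
N, L = 3, 6
# Hypothetical, already-discretized parameters shared by both models
A_bar = 0.5 * rng.normal(size=(N, N))
B_bar = rng.normal(size=N)
C = rng.normal(size=N)

def run(x_seq, nonlinear):
    """Same recurrence for both models; tanh is applied only in the RNN case."""
    h = np.zeros(N)
    ys = []
    for x_t in x_seq:
        pre = A_bar @ h + B_bar * x_t
        h = np.tanh(pre) if nonlinear else pre
        ys.append(C @ h)
    return np.array(ys)

x1 = rng.normal(size=L)
x2 = rng.normal(size=L)
for nonlinear in (False, True):
    lhs = run(x1 + x2, nonlinear)
    rhs = run(x1, nonlinear) + run(x2, nonlinear)
    name = "RNN (tanh)" if nonlinear else "linear SSM"
    print(name, "superposition holds:", np.allclose(lhs, rhs))
```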
Discretization
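As mentioned, State Space Models come from ordinary differential equations, so the way we discretize these continuous equations to work with finite sequences is crucial. The recurrence itself uses [math]\displaystyle{ \mathbf{\bar A} }[/math] and [math]\displaystyle{ \mathbf{\bar B} }[/math] rather than A and B; in models such as S4, the continuous parameters together with the step size [math]\displaystyle{ \Delta }[/math] are learned, and [math]\displaystyle{ \mathbf{\bar A} }[/math] and [math]\displaystyle{ \mathbf{\bar B} }[/math] are computed from them in the forward pass, so the discretization step is baked into the model in practice. Below we show a quick discretization example based on the trapezoidal rule.
The trapezoidal rule approximates the change of a variable z over one step by the average of its derivative at the two ends of the step (here z is a generic variable, not the input x): [math]\displaystyle{ z_{n+1} - z_{n} \approx \frac{\Delta}{2} (z'(t_{n+1}) + z'(t_{n})) }[/math]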
We start from the ordinary differential equation. [math]\displaystyle{ h'(t) = \mathbf{A} h(t) + \mathbf{B} x(t) }[/math]
Applying the trapezoidal rule to [math]\displaystyle{ h' }[/math] gives:
[math]\displaystyle{ h_{n+1} - h_{n} = \frac{\Delta}{2} (\mathbf{A}h_{n+1} + \mathbf{B}x_{n+1} + \mathbf{A}h_{n} + \mathbf{B}x_{n}) }[/math]
[math]\displaystyle{ h_{n+1} - \frac{\Delta}{2} \mathbf{A}h_{n+1} = h_n + \frac{\Delta}{2}\mathbf{A}h_{n} + \frac{\Delta}{2} \mathbf{B}(x_{n+1} + x_{n}) }[/math]
[math]\displaystyle{ (\mathbf{I} - \frac{\Delta}{2} \mathbf{A})h_{n+1} = (\mathbf{I} + \frac{\Delta}{2}\mathbf{A})h_{n} + \frac{\Delta}{2} \mathbf{B}(x_{n+1} + x_{n}) }[/math]
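We assume that the input (control) sequence does not change much over a small enough [math]\displaystyle{ \Delta }[/math], i.e. [math]\displaystyle{ x_{n+1}\approx x_n }[/math], so that [math]\displaystyle{ \frac{\Delta}{2}\mathbf{B}(x_{n+1} + x_{n}) \approx \Delta \mathbf{B} x_{n+1} }[/math]: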
[math]\displaystyle{ (\mathbf{I} - \frac{\Delta}{2} \mathbf{A})h_{n+1} = (\mathbf{I} + \frac{\Delta}{2}\mathbf{A})h_{n} + \Delta \mathbf{B} x_{n+1} }[/math]
[math]\displaystyle{ h_{n+1} =
(\mathbf{I} - \frac{\Delta}{2} \mathbf{A})^{-1}(\mathbf{I} + \frac{\Delta}{2}\mathbf{A})h_{n} + (\mathbf{I} - \frac{\Delta}{2} \mathbf{A})^{-1}\Delta \mathbf{B} x_{n+1}
}[/math]
Therefore [math]\displaystyle{ \mathbf{\bar A}=(\mathbf{I} - \frac{\Delta}{2} \mathbf{A})^{-1}(\mathbf{I} + \frac{\Delta}{2}\mathbf{A}) }[/math] and [math]\displaystyle{ \mathbf{\bar B}=(\mathbf{I} - \frac{\Delta}{2} \mathbf{A})^{-1}\Delta \mathbf{B} }[/math], which are the expressions given earlier.
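As a quick sanity check of this approximation, consider the scalar case [math]\displaystyle{ \mathbf{A} = a }[/math], where the exact solution of [math]\displaystyle{ h' = ah }[/math] over one step is [math]\displaystyle{ e^{a\Delta} }[/math]. The sketch below (the value of a is an arbitrary assumption) compares the trapezoidal [math]\displaystyle{ \mathbf{\bar A} }[/math] with this exact value. Since the two agree up to second order in [math]\displaystyle{ a\Delta }[/math], the error should shrink roughly eightfold each time [math]\displaystyle{ \Delta }[/math] is halved.

```python
import numpy as np

def bilinear_scalar(a, delta):
    """Trapezoidal A_bar for a 1x1 state matrix: (1 + delta/2 a) / (1 - delta/2 a)."""
    return (1 + delta / 2 * a) / (1 - delta / 2 * a)

a = -1.5  # hypothetical (stable) scalar state matrix
deltas = [0.4, 0.2, 0.1, 0.05]
# Exact one-step solution of h' = a h is exp(a * delta)
errors = [abs(bilinear_scalar(a, d) - np.exp(a * d)) for d in deltas]
for d, e in zip(deltas, errors):
    print(f"delta={d}: |A_bar - exp(a*delta)| = {e:.2e}")
```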
Topic 8: Sparse Attention
Introduction
Vanilla attention is computationally expensive because it requires multiplying very large matrices: every token is compared with every other token, which results in a complexity of [math]\displaystyle{ O(n^2) }[/math] in the sequence length n. Since n may be very large, this can limit the scalability of the method for long sequences. Sparse attention addresses this problem by being selective about which attention scores are computed. Intuitively, not all tokens need to attend to each other, because not every word provides semantic information about every other word in a sentence. How can we determine which token pairs are important? Sparse attention methods aim to answer this question, improving the efficiency and scalability of vanilla attention by computing attention only for selected token pairs, without sacrificing performance.
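To make the idea concrete, here is a minimal NumPy sketch of attention restricted by a boolean mask; the function name and the local-window mask in the example are illustrative assumptions. Note that this dense version still computes the full [math]\displaystyle{ n \times n }[/math] score matrix, so it only shows which pairs contribute; efficient sparse implementations skip the masked pairs altogether.

```python
import numpy as np

def masked_attention(Q, K, V, mask):
    """Scaled dot-product attention using only pairs where mask[i, j] is True.

    Dense illustration: all scores are still formed, then disallowed pairs are
    given zero weight. Every row of `mask` must allow at least one token.
    """
    d = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d)
    scores = np.where(mask, scores, -np.inf)               # masked pairs -> weight 0
    scores = scores - scores.max(axis=-1, keepdims=True)   # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V, weights

rng = np.random.default_rng(0)
n, d = 6, 4
Q, K, V = (rng.normal(size=(n, d)) for _ in range(3))
idx = np.arange(n)
window_mask = np.abs(idx[:, None] - idx[None, :]) <= 1   # attend to immediate neighbours only
out, w = masked_attention(Q, K, V, window_mask)
print(np.count_nonzero(w), "of", n * n, "pairs used")  # 16 of 36 pairs used
```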
Four methods were presented:
- Sparse Sinkhorn Attention
- Big Bird Sparse Attention
- Attention with Linear Biases (ALiBi)
- SpAtten
Sparse Sinkhorn Attention
To address the computational cost of vanilla attention, Sparse Sinkhorn Attention proposes to partition the input
Big Bird Sparse Attention
Intuitively, we picture the tokens as nodes in a directed graph, where attention is calculated between two tokens only if an edge connects their nodes. The adjacency matrix of this graph then acts as a mask: attention is not calculated between disconnected nodes.
In order to decide which nodes to 'connect,' the authors combined several methods:
- Random Attention: nodes are connected randomly (whether attention is calculated for a particular token pair is determined at random).
- Window Attention: tokens that are closer together probably provide more semantic information about each other, so attention is calculated between tokens that occur within a particular distance of each other.
- Global Attention: some tokens are 'celebrities', so attention between them and all other tokens is calculated. This results in a maximum distance of 2 between any two tokens (nodes in the graph).
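Theoretically, the authors prove that any sparse attention graph containing a star graph (the graph formed by global attention) keeps the expressiveness of full attention, i.e. it is a universal approximator of sequence-to-sequence functions, so the star graph provides a lower bound on what the sparse model can represent. Empirically, combining all three of the above methods into what the authors call 'BIGBIRD' results in higher model performance.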
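A rough sketch of how the three patterns can be combined into one boolean adjacency mask (mask[i, j] is True when token i attends to token j). The parameters `window`, `global_idx`, and `num_random` and the sampling scheme are hypothetical simplifications, not the BigBird authors' block-sparse implementation. The final check confirms the distance-2 property that global tokens provide.

```python
import numpy as np

def bigbird_mask(n, window=1, global_idx=(0,), num_random=1, seed=0):
    """Boolean n x n mask combining window, global, and random attention (illustrative)."""
    rng = np.random.default_rng(seed)
    idx = np.arange(n)
    # Window attention: tokens within `window` positions of each other
    mask = np.abs(idx[:, None] - idx[None, :]) <= window
    # Global attention: global tokens attend to everything and are attended by everything
    g = list(global_idx)
    mask[g, :] = True
    mask[:, g] = True
    # Random attention: each token also attends to `num_random` randomly chosen tokens
    for i in range(n):
        mask[i, rng.choice(n, size=num_random, replace=False)] = True
    return mask

m = bigbird_mask(12)
# Pairs reachable by a path of length <= 2 (through a global token if needed)
mi = m.astype(int)
reach2 = m | ((mi @ mi) > 0)
print(reach2.all())  # True
```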
Attention with Linear Biases (ALiBi)
Currently, models often struggle at inference time with sequences longer than those they were trained on. For example, a model trained on a dataset whose longest sequence is 1024 tokens may perform poorly when asked to generate a sequence of 2048 tokens. If this problem can be solved, then, in theory, training can become much more efficient, since we can train on shorter sequences without sacrificing performance.
ALiBi replaces the vanilla transformer [math]\displaystyle{ softmax(QK^T) }[/math] with [math]\displaystyle{ softmax(QK^T + bias) }[/math], adding a lower triangular matrix of biases in order to encode the relative positioning of tokens within the attention calculation.
The bias matrix is of the form [math]\displaystyle{ \begin{bmatrix} 0 & 0 & 0 & 0\\ -1 & 0 & 0 & 0\\ -2 & -1 & 0 & 0\\ -3 & -2 & -1 & 0\\ \end{bmatrix} }[/math]
The authors hypothesized that this bias could replace the positional encoding in transformers. For model architectures with multiple attention heads, the bias matrix is multiplied by a head-specific slope: for n heads, the slopes form the geometric sequence [math]\displaystyle{ 2^{-8/n}, 2^{-16/n}, \ldots, 2^{-8} }[/math], so with 8 heads they are [math]\displaystyle{ \frac{1}{2}, \frac{1}{4}, \ldots, \frac{1}{256} }[/math].
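A small sketch of the bias computation, assuming NumPy and the slope schedule from the ALiBi paper ([math]\displaystyle{ 2^{-8i/n} }[/math] for head i of n). `alibi_slopes` and `alibi_bias` are hypothetical helper names. Future positions are left at 0, as in the matrix shown above, on the assumption that a causal mask removes them anyway.

```python
import numpy as np

def alibi_slopes(num_heads):
    """Head-specific slopes 2^(-8i/n) for i = 1..n (geometric sequence)."""
    i = np.arange(1, num_heads + 1)
    return 2.0 ** (-8.0 * i / num_heads)

def alibi_bias(seq_len, num_heads):
    """Bias of shape (num_heads, seq_len, seq_len), added to QK^T before the softmax.

    Entry (h, i, j) = slope_h * (j - i) for j <= i; entries with j > i are 0.
    """
    pos = np.arange(seq_len)
    rel = pos[None, :] - pos[:, None]   # j - i: 0 on the diagonal, negative below it
    rel = np.minimum(rel, 0)            # zero out the upper (future) triangle
    return alibi_slopes(num_heads)[:, None, None] * rel[None, :, :]

print(alibi_slopes(8))
print(alibi_bias(4, 8)[0])              # head 0: slope 1/2 times the matrix above
```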