stat946W25


Record your contributions here [1]

Use the following notations:

C: You have written a summary/critique on the paper.



Your feedback on presentations

Topic 12: State Space Models

Introduction

State Space Models (SSMs) have been introduced as powerful alternatives to traditional sequence modeling approaches. These models perform well across modalities, including time series analysis, audio generation, and image processing, and they can capture long-range dependencies more efficiently. However, SSMs initially struggled to match the performance of Transformers on language modeling tasks. To address these challenges, recent architectural advances such as the Structured State Space Model (S4) have been introduced; S4 succeeded on long-range reasoning tasks and allowed for more efficient computation while preserving theoretical strengths, but its implementation remains complex and computationally demanding. Further research led to simplified variants such as the Diagonal State Space Model (DSS), which achieves comparable performance with a more straightforward formulation. In parallel, hybrid approaches such as the H3 model integrate SSMs with attention mechanisms to bridge the remaining gap; in H3, for example, the authors replace almost all of the attention layers in a Transformer with SSMs. More recently, models like Mamba have pushed the boundaries of SSMs by selectively parameterizing the state matrices as functions of the input, allowing more flexible and adaptive information propagation. As research continues to resolve the remaining challenges, the potential to substitute attention-based architectures with SSMs grows stronger, and they will likely play a crucial role in the next generation of sequence modeling frameworks.

Core concepts

To better understand State Space Models, let's first look at the limitations of Transformers and Recurrent Neural Networks (RNNs) and how they relate to SSMs. In a Transformer, training builds a matrix that compares each token with every token that came before it, and the weight in each cell of this matrix indicates how strongly the two tokens are related. Computing these weights requires no sequential operations, so every token pair can be processed in parallel. During inference, however, generating the next token requires recomputing attention over the entire sequence, even for tokens that have already been generated. For a sequence of length [math]\displaystyle{ L }[/math], the computational cost is therefore on the order of [math]\displaystyle{ L^2 }[/math], which is a bottleneck for long sequences in Transformers. An RNN, in contrast, takes two inputs at each time step to predict the output [math]\displaystyle{ y_t }[/math] and produce the current hidden state ([math]\displaystyle{ h_t }[/math]): the hidden state of the previous time step ([math]\displaystyle{ h_{t-1} }[/math]) and the input at the current time step ([math]\displaystyle{ x_t }[/math]). This structure makes inference linear in the sequence length because, unlike a Transformer, the RNN does not recompute all previous hidden states. On the other hand, a well-known problem with RNNs is that they forget the earliest tokens as the sequence grows. A simple RNN updates its hidden state using the formulas below:

Figure 1: Simple RNN architecture


[math]\displaystyle{ h(t) = \sigma(W_h h(t-1) + W_x x(t)) }[/math]

[math]\displaystyle{ y(t) = W_y h(t) }[/math]


Where:

  • [math]\displaystyle{ h(t) }[/math] is the hidden state at time t
  • [math]\displaystyle{ x(t) }[/math] is the input
  • [math]\displaystyle{ y(t) }[/math] is the output
  • [math]\displaystyle{ W_h, W_x, }[/math] and [math]\displaystyle{ W_y }[/math] are weight matrices
  • [math]\displaystyle{ \sigma }[/math] is a non-linear activation function
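
To make the recurrence concrete, here is a minimal NumPy sketch of the update above; the dimensions, random weights, and the choice of tanh for [math]\displaystyle{ \sigma }[/math] are placeholders for illustration only.

<syntaxhighlight lang="python">
import numpy as np

def rnn_step(h_prev, x_t, W_h, W_x, W_y):
    """One step of the simple RNN above: h_t = sigma(W_h h_{t-1} + W_x x_t), y_t = W_y h_t."""
    h_t = np.tanh(W_h @ h_prev + W_x @ x_t)   # sigma chosen as tanh purely for illustration
    y_t = W_y @ h_t
    return h_t, y_t

# Toy dimensions and random weights (placeholders, not from any paper).
d_h, d_x, d_y, T = 4, 3, 2, 5
rng = np.random.default_rng(0)
W_h = rng.normal(size=(d_h, d_h))
W_x = rng.normal(size=(d_h, d_x))
W_y = rng.normal(size=(d_y, d_h))

h = np.zeros(d_h)
for t in range(T):                  # inference is sequential: each step reuses only h_{t-1}
    x_t = rng.normal(size=d_x)
    h, y = rnn_step(h, x_t, W_h, W_x, W_y)
</syntaxhighlight>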



State Space Models originate in control theory, where they provide a mathematical representation of a system and describe its possible states. They define a linear mapping from an input signal x(t) to an output signal y(t) through a latent state representation h(t), and are formulated as:


[math]\displaystyle{ h'(t) = \mathbf{A} h(t) + \mathbf{B} x(t) }[/math]

[math]\displaystyle{ y(t) = \mathbf{C} h(t) + \mathbf{D} x(t) }[/math]


Where:

  • [math]\displaystyle{ h(t) }[/math] represents the hidden state
  • [math]\displaystyle{ x(t) }[/math] is the input
  • [math]\displaystyle{ y(t) }[/math] is the output
  • [math]\displaystyle{ \mathbf{A}, \mathbf{B}, \mathbf{C}, }[/math] and [math]\displaystyle{ \mathbf{D} }[/math] are parameter matrices

We assume that by solving these equations we can predict the output sequence corresponding to a given input sequence and previous state, so our goal is to find [math]\displaystyle{ h(t) }[/math] such that the input sequence is mapped to the correct output sequence. However, this definition uses continuous functions, and finding h(t) analytically is challenging; moreover, the input sequences in our problems are usually discrete. We therefore discretize the equations (the details are explained in the next section), which finally gives:


[math]\displaystyle{ h_t = \bar{\mathbf{A}} h_{t-1} + \bar{\mathbf{B}} x_t }[/math]

[math]\displaystyle{ y_t = \mathbf{C} h_t }[/math]

Where:

[math]\displaystyle{ \mathbf{\bar A} }[/math] and [math]\displaystyle{ \mathbf{\bar B} }[/math] are the discrete equivalents of the [math]\displaystyle{ \mathbf{A} }[/math] and [math]\displaystyle{ \mathbf{B} }[/math] transforms, obtained by compounding the continuous dynamics over a timestep [math]\displaystyle{ \Delta }[/math] with some discretization method. For example, using the trapezoidal rule:

  • [math]\displaystyle{ \mathbf{\bar A} = (I - \frac{\Delta}{2}A)^{-1}(I + \frac{\Delta}{2}A) }[/math]
  • [math]\displaystyle{ \mathbf{\bar B} = (I - \frac{\Delta}{2}A)^{-1}\Delta B }[/math]

Look at #Discretization for further details.
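
For concreteness, here is a minimal sketch (with arbitrary placeholder matrices) of how the discrete recurrence unrolls, exactly like an RNN but without the non-linearity.

<syntaxhighlight lang="python">
import numpy as np

def ssm_scan(A_bar, B_bar, C, xs):
    """Discrete SSM recurrence: h_t = A_bar h_{t-1} + B_bar x_t,  y_t = C h_t."""
    h = np.zeros(A_bar.shape[0])
    ys = []
    for x_t in xs:                       # linear in the sequence length, like an RNN
        h = A_bar @ h + B_bar @ x_t
        ys.append(C @ h)
    return np.stack(ys)

# Placeholder parameters: state size 4, one input/output channel.
rng = np.random.default_rng(1)
A_bar = 0.9 * np.eye(4)                  # toy stable dynamics
B_bar = rng.normal(size=(4, 1))
C = rng.normal(size=(1, 4))
xs = rng.normal(size=(10, 1))            # input sequence of length 10
ys = ssm_scan(A_bar, B_bar, C, xs)       # output sequence, shape (10, 1)
</syntaxhighlight>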

Comparing these formulas with the ones we defined for RNNs shows that they are similar: an RNN is essentially a non-linear extension of a state space model. The main differences are:

  • SSMs are linear transformations between states, while RNNs apply non-linearity through the activation function
  • SSMs come from control theory and in control systems, the matrices are typically derived from physics equations, while in machine learning we learn these matrices from data
  • In SSMs, the second equation contains the term [math]\displaystyle{ \mathbf{D} x(t) }[/math], which is commonly left out in practice

Discretization

As mentioned, State Space Models come from ordinary differential equations, so the way we discretize the continuous equations to make them work with finite sequences is crucial. In fact we aim to learn [math]\displaystyle{ \mathbf{\bar A} }[/math] and [math]\displaystyle{ \mathbf{\bar B} }[/math] rather than A and B directly, so the discretization step is baked into the model in practice. Below we show a quick discretization example based on the trapezoidal rule.

The trapezoidal rule approximates one step of an ODE [math]\displaystyle{ x'(t) = f(t) }[/math] as: [math]\displaystyle{ x_{n+1} - x_{n} = \frac{\Delta}{2} (f(t_{n+1}) + f(t_{n})) }[/math]

We start from the ordinary differential equation [math]\displaystyle{ h'(t) = \mathbf{A} h(t) + \mathbf{B} x(t) }[/math].

Applying the trapezoidal rule to [math]\displaystyle{ h' }[/math] gives:

[math]\displaystyle{ h_{n+1} - h_{n} = \frac{\Delta}{2} (\mathbf{A}h_{n+1} + \mathbf{B}x_{n+1} + \mathbf{A}h_{n} + \mathbf{B}x_{n}) }[/math]

[math]\displaystyle{ h_{n+1} - \frac{\Delta}{2} \mathbf{A}h_{n+1} = h_n + \frac{\Delta}{2}\mathbf{A}h_{n} + \frac{\Delta}{2} \mathbf{B}(x_{n+1} + x_{n}) }[/math]

[math]\displaystyle{ (\mathbf{I} - \frac{\Delta}{2} \mathbf{A})h_{n+1} = (\mathbf{I} + \frac{\Delta}{2}\mathbf{A})h_{n} + \frac{\Delta}{2} \mathbf{B}(x_{n+1} + x_{n}) }[/math]

It is assumed that the control (input) sequence does not change over a small enough [math]\displaystyle{ \Delta }[/math], i.e. [math]\displaystyle{ x_{n+1}\approx x_n }[/math], so that [math]\displaystyle{ \frac{\Delta}{2} \mathbf{B}(x_{n+1} + x_{n}) \approx \Delta \mathbf{B} x_{n+1} }[/math].

[math]\displaystyle{ (\mathbf{I} - \frac{\Delta}{2} \mathbf{A})h_{n+1} = (\mathbf{I} + \frac{\Delta}{2}\mathbf{A})h_{n} + \Delta \mathbf{B}(x_{n+1}) }[/math]


[math]\displaystyle{ h_{n+1} = (\mathbf{I} - \frac{\Delta}{2} \mathbf{A})^{-1}(\mathbf{I} + \frac{\Delta}{2}\mathbf{A})h_{n} + (\mathbf{I} - \frac{\Delta}{2} \mathbf{A})^{-1}\Delta \mathbf{B}(x_{n+1}) }[/math]

Indeed [math]\displaystyle{ \mathbf{\bar A}=(\mathbf{I} - \frac{\Delta}{2} \mathbf{A})^{-1}(\mathbf{I} + \frac{\Delta}{2}\mathbf{A}) }[/math] and [math]\displaystyle{ \mathbf{\bar B}=(\mathbf{I} - \frac{\Delta}{2} \mathbf{A})^{-1}\Delta \mathbf{B} }[/math]
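
The derivation above maps directly to code; below is a minimal sketch of the trapezoidal (bilinear) discretization applied to placeholder continuous-time matrices.

<syntaxhighlight lang="python">
import numpy as np

def discretize_trapezoidal(A, B, delta):
    """A_bar = (I - delta/2 A)^{-1}(I + delta/2 A),  B_bar = (I - delta/2 A)^{-1} delta B."""
    I = np.eye(A.shape[0])
    inv = np.linalg.inv(I - (delta / 2) * A)
    return inv @ (I + (delta / 2) * A), inv @ (delta * B)

# Toy continuous-time system (placeholders): a damped oscillator driven by a scalar input.
A = np.array([[0.0, 1.0],
              [-1.0, -0.5]])
B = np.array([[0.0],
              [1.0]])
A_bar, B_bar = discretize_trapezoidal(A, B, delta=0.1)
</syntaxhighlight>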

Structured State Space (S4)

A Structured State Space (S4) model is a kind of state space model specialized for efficient sequence modelling, particularly when handling long-term dependencies is needed. The goal of S4 is to improve computational efficiency, reduce the number of parameters, and enhance learning ability by introducing structural constraints on the state transition matrices.
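
One way to see where the efficiency comes from: because the recurrence is linear and time-invariant, the whole output sequence can equivalently be computed as a convolution of the input with a kernel built from the discretized parameters, which parallelizes over the sequence during training. The sketch below builds this kernel naively for illustration only; S4 itself imposes structure on [math]\displaystyle{ \mathbf{A} }[/math] precisely so that this kernel can be computed far more cheaply.

<syntaxhighlight lang="python">
import numpy as np

def ssm_kernel(A_bar, B_bar, C, L):
    """Naive SSM convolution kernel K = [C B_bar, C A_bar B_bar, ..., C A_bar^{L-1} B_bar]."""
    K, M = [], B_bar
    for _ in range(L):
        K.append((C @ M).item())        # scalar input/output channel for simplicity
        M = A_bar @ M
    return np.array(K)

def causal_conv(K, x):
    """y_t = sum_k K[k] * x[t-k]; equivalent to running the recurrence with h_0 = 0."""
    y = np.zeros(len(x))
    for t in range(len(x)):
        y[t] = np.dot(K[: t + 1], x[t::-1])
    return y

# Placeholder discrete parameters (state size 2, scalar channels).
A_bar = np.array([[0.9, 0.1], [0.0, 0.8]])
B_bar = np.array([[1.0], [0.5]])
C = np.array([[1.0, -1.0]])
x = np.random.default_rng(2).normal(size=16)
y = causal_conv(ssm_kernel(A_bar, B_bar, C, len(x)), x)
</syntaxhighlight>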

Some applications of S4 are language and audio processing (text generation, speech recognition, music and voice synthesis), forecasting (predicting stock prices in financial markets, modelling climate data), and scientific computing (simulations in physics, chemistry, and astronomy).

Hungry Hungry Hippos (H3)

SSMs vs Attention

Research has shown that SSMs achieve state-of-the-art performance in domains such as speech recognition and audio generation. However, they underperform attention in language modelling for two main reasons:

  • Expressivity gap
  • Poor hardware utilization

Expressivity Gap

To understand the gap between SSMs and attention on language modelling, the authors examined two synthetic language modelling tasks:

The Induction Head task tests how well a model can recall content after a special token: at the end of the sequence, the model must recall the token that appeared immediately after the special token earlier in the sequence. For example, if the token f appeared right after the special symbol ⊢, the model must output f at the end.

Associative Recall is similar to the induction head task, but requires the model to remember multiple key-value pairs: at the end of the sequence, the model must recall the value belonging to a specific key. For example, if the key a was paired with the value 2 earlier in the sequence, the model must output 2 when queried with a.
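
A small illustrative sketch of what one instance of each task might look like (the vocabulary, formatting, and sequence lengths are placeholders, not the papers' exact setup):

<syntaxhighlight lang="python">
# Induction head: at the end of the sequence, output the token that followed the special token.
induction_example = {
    "sequence": ["c", "b", "⊢", "f", "g", "h", "d", "⊢"],  # "⊢" is the special token (placeholder data)
    "target": "f",                                          # the token that came right after the first "⊢"
}

# Associative recall: memorize key-value pairs, then return the value for the queried key.
associative_example = {
    "sequence": ["a", "2", "c", "4", "b", "3", "d", "1", "a"],  # pairs followed by the query key "a"
    "target": "2",                                              # value originally paired with "a"
}
</syntaxhighlight>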


On these two tasks, the authors compare S4D (an S4 variant specialized for language modelling), Gated State Spaces (GSS), and attention. Attention completes both synthetic tasks perfectly, achieving 100% accuracy and significantly outperforming S4D and GSS. The failure of SSMs can be attributed to two missing capabilities: (i) the ability to remember tokens that appear after a particular event (e.g., the special token in the induction head task), and (ii) the ability to compare tokens across the sequence (e.g., comparing keys to decide which value to recall). Attention has both of these capabilities: it can compare tokens by constructing the quadratic attention matrix [math]\displaystyle{ \mathbf{QK^T} }[/math], and it can recall tokens by direct copying (multiplying [math]\displaystyle{ softmax(\mathbf{QK^T}) }[/math] with [math]\displaystyle{ \mathbf{V} }[/math]).


H3 Design

H3 is designed to add these capabilities to SSMs. Its structure is as follows:

The H3 layer consists of two SSMs stacked together.

The first is a shift SSM, which can be thought of as performing a local lookup across the sequence; how it does this is described below.

The second is a diagonal SSM, which serves as a kind of global memory that keeps track of important information.

Multiplicative interactions (the elementwise multiplication operations, drawn in teal in the paper's architecture diagram) give H3 the ability to compare tokens across the sequence.


In the shift SSM, we constrain [math]\displaystyle{ \mathbf{A} \in \mathbb{R}^{m \times m} }[/math] to be a shift matrix, defined entrywise by [math]\displaystyle{ \mathbf{A}_{i,j} = \begin{cases} 1 & \text{for } i - 1 = j,\\ 0 & \text{otherwise}. \end{cases} }[/math] Drawn out, this is a square matrix whose 1s lie directly below the main diagonal, with all other entries 0. Following the H3 paper's notation (where [math]\displaystyle{ x_i }[/math] denotes the hidden state and [math]\displaystyle{ u_i }[/math] the input), the action of this matrix on the hidden state [math]\displaystyle{ x_i }[/math] is to shift each coordinate down by one, thereby creating a "memory" of the previous states. If [math]\displaystyle{ \mathbf{B} = \textit{e}_1 }[/math], the first basis vector, then [math]\displaystyle{ x_i = [u_i, u_{i-1}, \dots , u_{i-m+1}] }[/math] contains the inputs from the previous m time steps. Both [math]\displaystyle{ \mathbf{B} }[/math] and [math]\displaystyle{ \mathbf{C} }[/math] are learnable matrices, but [math]\displaystyle{ \mathbf{B} }[/math] is usually fixed to [math]\displaystyle{ \textit{e}_1 }[/math] for simplicity, in which case the output is a 1D convolution with kernel size m.


The diagonal SSM constrains A to be diagonal and initializes it from the diagonal version of HiPPO. This parameterization allows the model to remember state over the entire sequence. The shift SSM can detect when a particular event occurs, and the diagonal SSM can remember a token afterwards for the rest of the sequence.
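
A minimal sketch of the shift mechanism (sizes and inputs are placeholders): with [math]\displaystyle{ \mathbf{A} }[/math] as the shift matrix and [math]\displaystyle{ \mathbf{B} = \textit{e}_1 }[/math], each step pushes the newest input into the first coordinate of the state and shifts the older ones down, so the state holds the last m inputs.

<syntaxhighlight lang="python">
import numpy as np

m = 4                                    # memory length (placeholder)
A_shift = np.eye(m, k=-1)                # ones directly below the main diagonal: A[i, j] = 1 iff i - 1 = j
B = np.zeros(m)
B[0] = 1.0                               # B = e_1, the first basis vector

h = np.zeros(m)
for u in [1.0, 2.0, 3.0, 4.0, 5.0]:      # scalar inputs u_1, ..., u_5
    h = A_shift @ h + B * u              # state becomes [u_i, u_{i-1}, ..., u_{i-m+1}]
print(h)                                  # [5. 4. 3. 2.]
</syntaxhighlight>

The diagonal SSM would instead use a diagonal [math]\displaystyle{ \mathbf{A} }[/math] (e.g. np.diag(a) for a learned vector a), which decays each state coordinate independently and lets information persist across the whole sequence.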

Mamba

Mamba is a technique that builds on S4 models. It was introduced to increase efficiency on long sequences through a selection mechanism: the SSM parameters are made functions of the input, which lets the model save memory and computation by filtering out irrelevant information instead of propagating everything. Architecturally, it combines the H3 block traditionally used in SSM models with a gated multilayer perceptron (MLP) into a single block.
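
Below is a heavily simplified sketch of the selection idea only (scalar input channel, placeholder projections; Mamba's actual implementation also uses a per-channel zero-order-hold discretization and a hardware-aware parallel scan): the step size Δ and the matrices B and C are computed from the current input, so the update can decide how much of each token to write into the state.

<syntaxhighlight lang="python">
import numpy as np

def selective_scan(xs, A, w_B, w_C, w_delta):
    """Input-dependent (selective) SSM over a scalar input sequence xs.
    A: (N,) diagonal state matrix; w_B, w_C: (N,) projections; w_delta: scalar. All placeholders."""
    h = np.zeros(A.shape[0])
    ys = []
    for x_t in xs:
        delta = np.log1p(np.exp(w_delta * x_t))    # softplus: positive, input-dependent step size
        B_t, C_t = w_B * x_t, w_C * x_t            # B and C now depend on the current input
        A_bar = np.exp(delta * A)                  # elementwise discretization of the diagonal A
        h = A_bar * h + (delta * B_t) * x_t        # small delta ~ ignore the token, large delta ~ store it
        ys.append(C_t @ h)
    return np.array(ys)

# Placeholder parameters: state size N = 8.
rng = np.random.default_rng(3)
N = 8
A = -np.exp(rng.normal(size=N))                    # negative diagonal entries for stability
ys = selective_scan(rng.normal(size=32), A, rng.normal(size=N), rng.normal(size=N), w_delta=0.5)
</syntaxhighlight>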

Topic 8: Sparse Attention

Introduction

Vanilla attention is computationally expensive because of the large matrix multiplications it requires. This results in a complexity of [math]\displaystyle{ O(n^2) }[/math] in the sequence length n, which can limit scalability for long sequences. Sparse attention tries to address this problem by being more selective about which attention scores are computed. Intuitively, not all tokens need to attend to each other, because not every word provides semantic information about every other word in a sentence. How can we determine which token pairs are important? The goal of sparse attention is to answer this question and thereby improve the efficiency and scalability of vanilla attention, avoiding computing attention between all token pairs without sacrificing performance.

Four methods were presented:

  • Sparse Sinkhorn Attention
  • Big Bird Sparse Attention
  • Attention with Linear Biases (ALiBi)
  • SpAtten

Sparse Sinkhorn Attention

To address the computational cost of vanilla attention, especially for long input sequences, Sparse Sinkhorn Attention proposes three core ideas (a minimal code sketch follows the list):

  • [math]\displaystyle{ \mathbf{Block} }[/math] [math]\displaystyle{ \mathbf{Partitioning} }[/math]
    • The input sequence of length [math]\displaystyle{ l }[/math] is split into [math]\displaystyle{ N_b }[/math] blocks, each containing [math]\displaystyle{ b }[/math] tokens.
  • [math]\displaystyle{ \mathbf{Block} }[/math] [math]\displaystyle{ \mathbf{Sorting} }[/math]
    • Apply the Sinkhorn algorithm to generate a permutation matrix, ensuring that relevant blocks are adjacent.
  • [math]\displaystyle{ \mathbf{Sparse\ (local)} }[/math] [math]\displaystyle{ \mathbf{Attention} }[/math]
    • Instead of attending over the full sequence, each token only computes attention with the tokens within its own block (and, in some variants, tokens within its neighbouring blocks).
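
Here is a minimal sketch of the block partitioning and block-local attention steps (the learned Sinkhorn sorting that permutes the blocks is omitted, so the blocks stay in their original order; all sizes are placeholders):

<syntaxhighlight lang="python">
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def block_local_attention(Q, K, V, b):
    """Each token attends only to the tokens in its own block of size b: O(l*b) instead of O(l^2)."""
    l, d = Q.shape
    out = np.zeros_like(V)
    for start in range(0, l, b):
        blk = slice(start, start + b)
        scores = Q[blk] @ K[blk].T / np.sqrt(d)   # (b, b) attention within the block only
        out[blk] = softmax(scores) @ V[blk]
    return out

# Toy data: sequence length l = 16, model dimension 8, block size b = 4 (placeholders).
rng = np.random.default_rng(4)
Q, K, V = (rng.normal(size=(16, 8)) for _ in range(3))
out = block_local_attention(Q, K, V, b=4)
</syntaxhighlight>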

[math]\displaystyle{ \mathbf{Advantages\ over\ Vanilla\ Attention} }[/math]

  • Reduced memory
    • Sparse Sinkhorn Attention only needs [math]\displaystyle{ O\left(\left(\frac{l}{N_b}\right)^2 + N_b^2\right) }[/math], while Vanilla Attention requires [math]\displaystyle{ O(l^2) }[/math].
  • Reduced Computational Complexity
    • SORTCUT (a Sparse Sinkhorn Attention variant) only requires [math]\displaystyle{ O(lN_k) }[/math], where [math]\displaystyle{ N_k }[/math] is a hyperparameter which is much smaller than [math]\displaystyle{ l }[/math]. In contrast, Vanilla Attention requires [math]\displaystyle{ O(l^2) }[/math] time.
  • Ability to Capture Long Range Dependency
    • Through block sorting, relevant tokens are grouped adjacently even if they are originally far apart, and local attention then captures these long-range dependencies. By comparison, vanilla attention struggles with this: it calculates attention between every token and all other tokens, and softmax entropy collapse naturally reduces the importance of distant tokens, making it difficult to model long-range relationships both computationally and statistically.

Big Bird Sparse Attention

Intuitively, we picture the tokens as nodes in a directed graph, where an edge between two nodes means attention is calculated between the corresponding tokens. The adjacency matrix of this graph then acts as a mask: attention is not calculated between disconnected nodes.

In order to decide which nodes to 'connect,' the authors combined several methods:

  • Random Attention
    • Nodes are connected randomly (whether attention is calculated for a particular token pair is determined randomly)
  • Window Attention
    • Intuitively, tokens that are closer together probably provide more semantic information about each other. Therefore, attention is calculated between tokens that occur within a particular distance from each other.
  • Global Attention
    • Some tokens are 'celebrities', so attention between them and all other tokens is calculated. This results in a maximum distance of 2 between any two tokens (nodes in the graph).

Theoretically, the authors prove that the star graph (the graph representing global attention alone) provides a lower bound on expressive power. Empirically, combining all three of the above methods into what the authors call 'BIGBIRD' results in higher model performance.
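
A minimal sketch of how the three patterns above combine into a boolean attention mask (sequence length, window size, and counts are placeholders); attention scores would only be computed where the mask is True:

<syntaxhighlight lang="python">
import numpy as np

def bigbird_mask(l, window, n_random, n_global, seed=0):
    """Boolean (l, l) mask combining window, random, and global attention."""
    rng = np.random.default_rng(seed)
    mask = np.zeros((l, l), dtype=bool)
    for i in range(l):
        # window attention: neighbours within `window` positions on either side
        mask[i, max(0, i - window): min(l, i + window + 1)] = True
        # random attention: a few randomly chosen tokens per row
        mask[i, rng.choice(l, size=n_random, replace=False)] = True
    # global attention: the first n_global tokens attend to, and are attended by, every token
    mask[:n_global, :] = True
    mask[:, :n_global] = True
    return mask

mask = bigbird_mask(l=16, window=2, n_random=2, n_global=1)
</syntaxhighlight>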

Attention with Linear Biases (ALiBi)

Currently, models struggle at inference time to produce sequences longer than those they were trained on. For example, if a model is trained on a dataset where the longest sequence is 1024 tokens, it will perform poorly when asked to generate a sequence of 2048 tokens. If we can solve this problem then, in theory, training can become much more efficient, since we can train on shorter sequences without sacrificing performance.

ALiBi replaces the vanilla transformer [math]\displaystyle{ softmax(QK^T) }[/math] with [math]\displaystyle{ softmax(QK^T + bias) }[/math], adding a lower triangular matrix of biases in order to encode the relative positioning of tokens within the attention calculation.

The bias matrix is of the form [math]\displaystyle{ \begin{bmatrix} 0 & 0 & 0 & 0\\ -1 & 0 & 0 & 0\\ -2 & -1 & 0 & 0\\ -3 & -2 & -1 & 0\\ \end{bmatrix} }[/math]

The authors hypothesized that this bias could replace the positional encoding in transformers. For model architectures with multiple attention heads, the slope of the bias varies per head, defaulting to a geometric sequence (e.g. [math]\displaystyle{ \frac{1}{2^1}, \frac{1}{2^2}, \dots, \frac{1}{2^8} }[/math] for 8 heads).
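
A minimal sketch of constructing the bias for one head and the geometric slope schedule (the example reproduces the 4x4 matrix above; names and sizes are placeholders):

<syntaxhighlight lang="python">
import numpy as np

def alibi_bias(n, slope):
    """Lower-triangular bias with entry (i, j) = -slope * (i - j) for j <= i, added to Q K^T."""
    i = np.arange(n)[:, None]
    j = np.arange(n)[None, :]
    return np.tril(-slope * (i - j).astype(float))   # future positions are handled by the causal mask

def alibi_slopes(n_heads):
    """Per-head slopes forming a geometric sequence, e.g. 1/2, 1/4, ..., 1/256 for 8 heads."""
    return np.array([2.0 ** (-8.0 * (h + 1) / n_heads) for h in range(n_heads)])

print(alibi_bias(4, slope=1.0))    # reproduces the 4x4 bias matrix shown above
print(alibi_slopes(8))
</syntaxhighlight>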

SpAtten

Topic 10: Linear Attention