

Your feedback on presentations

Topic 12: State Space Models

Introduction

State Space Models (SSMs) have been introduced as powerful alternatives to traditional sequence modeling approaches. These models perform well across modalities, including time series analysis, audio generation, and image processing, and they can capture long-range dependencies efficiently. However, SSMs initially struggled to match the performance of Transformers in language modeling tasks. To close this gap, recent architectural advances such as the Structured State Space Model (S4) were introduced; S4 succeeded in long-range reasoning tasks and allowed for more efficient computation while preserving the theoretical strengths of SSMs. Its implementation nevertheless remains complex and computationally demanding, so further research led to simplified variants such as the Diagonal State Space Model (DSS), which achieves comparable performance with a more straightforward formulation. In parallel, hybrid approaches such as the H3 model integrate SSMs with attention mechanisms to bridge the remaining gap; for example, the H3 authors replace almost all of the attention layers in a Transformer with SSMs. More recently, models like Mamba have pushed the boundaries of SSMs by selectively parameterizing the state matrices as functions of the input, allowing more flexible and adaptive information propagation. Research in SSMs continues to resolve the remaining challenges, and the case for substituting attention-based architectures with SSMs grows stronger; they will likely play an important role in the next generation of sequence modeling frameworks.

Core concepts

To understand State Space Models better, let's first look at the problems with Transformers and Recurrent Neural Networks (RNNs) and how SSMs relate to them. In a Transformer, during training we build a matrix comparing each token with every token that came before it, and the weight in each cell of this matrix indicates how strongly the two tokens attend to each other. Computing these weights requires no sequential operations, so every token pair can be processed in parallel. However, during inference, to generate the next token we must recompute attention over the entire sequence, even for tokens we have already generated. This means that for a sequence of length [math]\displaystyle{ L }[/math], the computation cost grows as [math]\displaystyle{ O(L^2) }[/math], which becomes a bottleneck for long sequences. An RNN, by contrast, takes two inputs at each time step to predict the output [math]\displaystyle{ y_t }[/math] and produce the current hidden state ([math]\displaystyle{ h_t }[/math]): the hidden state of the previous time step ([math]\displaystyle{ h_{t-1} }[/math]) and the input at the current time step ([math]\displaystyle{ x_t }[/math]). This structure allows inference with linear computation because, unlike a Transformer, an RNN does not recalculate all previous hidden states. On the other hand, a well-known problem is that RNNs tend to forget the earliest tokens as the sequence grows. A simple RNN updates its hidden state using the formula below:

Figure 1: Simple RNN architecture


[math]\displaystyle{ h(t) = \sigma(W_h h(t-1) + W_x x(t)) }[/math]

[math]\displaystyle{ y(t) = W_y h(t) }[/math]


Where:

  • [math]\displaystyle{ h(t) }[/math] is the hidden state at time t
  • [math]\displaystyle{ x(t) }[/math] is the input
  • [math]\displaystyle{ y(t) }[/math] is the output
  • [math]\displaystyle{ W_h, W_x, }[/math] and [math]\displaystyle{ W_y }[/math] are weight matrices
  • [math]\displaystyle{ \sigma }[/math] is a non-linear activation function
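
To make the recurrence concrete, below is a minimal NumPy sketch of the update above; the tanh activation, the toy dimensions, and the function name rnn_forward are illustrative choices rather than part of any specific model.

<syntaxhighlight lang="python">
import numpy as np

def rnn_forward(x_seq, W_h, W_x, W_y, h0):
    """Simple RNN recurrence: h_t = sigma(W_h h_{t-1} + W_x x_t), y_t = W_y h_t."""
    h = h0
    outputs = []
    for x_t in x_seq:                      # process tokens sequentially
        h = np.tanh(W_h @ h + W_x @ x_t)   # tanh plays the role of sigma
        outputs.append(W_y @ h)            # read out the prediction
    return np.stack(outputs), h

# Toy dimensions: input dim 4, hidden dim 8, output dim 3, sequence length 5.
rng = np.random.default_rng(0)
W_h, W_x, W_y = rng.normal(size=(8, 8)), rng.normal(size=(8, 4)), rng.normal(size=(3, 8))
ys, h_final = rnn_forward(rng.normal(size=(5, 4)), W_h, W_x, W_y, np.zeros(8))
print(ys.shape)  # (5, 3)
</syntaxhighlight>

Note that generating the next output only needs the latest hidden state h_final, which is why the cost per generated token stays constant.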



State Space Models (SSMs) come from control theory, where they provide a mathematical representation of a system and describe its possible states. They define a linear mapping from an input signal x(t) to an output signal y(t) through a latent state representation h(t). There are two main components to an SSM:


1. State equation: Describes how the system's hidden state changes based on the current state and input

[math]\displaystyle{ h'(t) = \mathbf{A} h(t) + \mathbf{B} x(t) }[/math]

2. Output equation: Defines how the system produces output based on the current state and input

[math]\displaystyle{ y(t) = \mathbf{C} h(t) + \mathbf{D} x(t) }[/math]


Where:

  • [math]\displaystyle{ h(t) }[/math] represents the hidden state
  • [math]\displaystyle{ x(t) }[/math] is the input
  • [math]\displaystyle{ y(t) }[/math] is the output
  • [math]\displaystyle{ \mathbf{A}, \mathbf{B}, \mathbf{C}, }[/math] and [math]\displaystyle{ \mathbf{D} }[/math] are learnable parameter matrices

We assume that by solving these equations we can predict the output sequence corresponding to a given input sequence and previous state, so the goal is to find [math]\displaystyle{ h(t) }[/math] such that we can map the input sequence to the correct output sequence. However, this definition uses continuous functions, and finding [math]\displaystyle{ h(t) }[/math] analytically is challenging; moreover, in practice we usually work with discrete input sequences. We therefore discretize the equations (the details are explained in the Discretization section below), which finally yields:


[math]\displaystyle{ h_t = \bar{\mathbf{A}} h_{t-1} + \bar{\mathbf{B}} x_t }[/math]

[math]\displaystyle{ y_t = \mathbf{C} h_t }[/math]

Where:

[math]\displaystyle{ \mathbf{\bar A} }[/math] and [math]\displaystyle{ \mathbf{\bar B} }[/math] are discrete equivalents of the [math]\displaystyle{ \mathbf{A} }[/math] and [math]\displaystyle{ \mathbf{B} }[/math] transforms compounded over a timestep [math]\displaystyle{ \Delta }[/math] using some discretization method, for example the trapezoidal rule:

  • [math]\displaystyle{ \mathbf{\bar A} = (I - \frac{\Delta}{2}A)^{-1}(I + \frac{\Delta}{2}A) }[/math]
  • [math]\displaystyle{ \mathbf{\bar B} = (I - \frac{\Delta}{2}A)^{-1}\Delta B }[/math]

Look at #Discretization for further details.

Comparing these formulations with the ones we defined for RNNs shows that they are similar: an RNN is essentially a non-linear extension of a state space model. The main differences are:

  • SSMs apply linear transformations between states, while RNNs introduce non-linearity through the activation function
  • SSMs come from control theory, where the matrices are typically derived from physical equations, while in machine learning these matrices are learned from data
  • The output equation of an SSM includes the feedthrough term [math]\displaystyle{ \mathbf{D} x(t) }[/math], which is commonly left out (or treated as a simple skip connection)

Discretization

As mentioned, state space models come from ordinary differential equations, so the way we discretize the continuous equations so that they work with finite sequences is crucial. In fact, we aim to learn [math]\displaystyle{ \mathbf{\bar A} }[/math] and [math]\displaystyle{ \mathbf{\bar B} }[/math] rather than A and B directly, so the discretization step is baked into the model in practice. Below is a quick discretization example based on the trapezoidal rule.

The trapezoidal rule assumes: [math]\displaystyle{ x_{n+1} - x_{n} = \frac{\Delta}{2} (f(t_{n+1}) + f(t_{n})) }[/math]

We start from the ordinary differential equation. [math]\displaystyle{ h'(t) = \mathbf{A} h(t) + \mathbf{B} x(t) }[/math]

Applying the trapezoidal rule to [math]\displaystyle{ h' }[/math]:

[math]\displaystyle{ h_{n+1} - h_{n} = \frac{\Delta}{2} (\mathbf{A}h_{n+1} + \mathbf{B}x_{n+1} + \mathbf{A}h_{n} + \mathbf{B}x_{n}) }[/math]

[math]\displaystyle{ h_{n+1} - \frac{\Delta}{2} \mathbf{A}h_{n+1} = h_n + \frac{\Delta}{2}\mathbf{A}h_{n} + \frac{\Delta}{2} \mathbf{B}(x_{n+1} + x_{n}) }[/math]

[math]\displaystyle{ (\mathbf{I} - \frac{\Delta}{2} \mathbf{A})h_{n+1} = (\mathbf{I} + \frac{\Delta}{2}\mathbf{A})h_{n} + \frac{\Delta}{2} \mathbf{B}(x_{n+1} + x_{n}) }[/math]

It is assumed that the input sequence does not change over a small enough [math]\displaystyle{ \Delta }[/math], i.e. [math]\displaystyle{ x_{n+1}\approx x_n }[/math]

[math]\displaystyle{ (\mathbf{I} - \frac{\Delta}{2} \mathbf{A})h_{n+1} = (\mathbf{I} + \frac{\Delta}{2}\mathbf{A})h_{n} + \Delta \mathbf{B}(x_{n+1}) }[/math]


[math]\displaystyle{ h_{n+1} = (\mathbf{I} - \frac{\Delta}{2} \mathbf{A})^{-1}(\mathbf{I} + \frac{\Delta}{2}\mathbf{A})h_{n} + (\mathbf{I} - \frac{\Delta}{2} \mathbf{A})^{-1}\Delta \mathbf{B}(x_{n+1}) }[/math]

Indeed [math]\displaystyle{ \mathbf{\bar A}=(\mathbf{I} - \frac{\Delta}{2} \mathbf{A})^{-1}(\mathbf{I} + \frac{\Delta}{2}\mathbf{A}) }[/math] and [math]\displaystyle{ \mathbf{\bar B}=(\mathbf{I} - \frac{\Delta}{2} \mathbf{A})^{-1}\Delta \mathbf{B} }[/math]
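
To make the derivation concrete, the following NumPy sketch applies the trapezoidal (bilinear) rule above and then runs the resulting discrete recurrence [math]\displaystyle{ h_t = \bar{\mathbf{A}} h_{t-1} + \bar{\mathbf{B}} x_t }[/math], [math]\displaystyle{ y_t = \mathbf{C} h_t }[/math]; the toy matrices, step size, and helper names (discretize_bilinear, ssm_scan) are arbitrary choices used only for illustration.

<syntaxhighlight lang="python">
import numpy as np

def discretize_bilinear(A, B, dt):
    """Trapezoidal (bilinear) discretization: returns A_bar and B_bar as derived above."""
    n = A.shape[0]
    inv = np.linalg.inv(np.eye(n) - (dt / 2) * A)
    A_bar = inv @ (np.eye(n) + (dt / 2) * A)
    B_bar = inv @ (dt * B)
    return A_bar, B_bar

def ssm_scan(A_bar, B_bar, C, x_seq):
    """Discrete recurrence h_t = A_bar h_{t-1} + B_bar x_t, y_t = C h_t."""
    h = np.zeros(A_bar.shape[0])
    ys = []
    for x_t in x_seq:
        h = A_bar @ h + B_bar * x_t     # scalar input per step (single input channel)
        ys.append(C @ h)
    return np.array(ys)

# Toy single-input SSM with a 4-dimensional hidden state.
rng = np.random.default_rng(0)
A = -np.diag(np.arange(1.0, 5.0))       # stable continuous-time dynamics
B = rng.normal(size=4)
C = rng.normal(size=4)
A_bar, B_bar = discretize_bilinear(A, B, dt=0.1)
y = ssm_scan(A_bar, B_bar, C, rng.normal(size=16))
print(y.shape)  # (16,)
</syntaxhighlight>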

Structured State Space (S4)

The Structured State Space (S4) model is specialized for efficient sequence modelling, particularly in situations where handling long-term dependencies is needed. The goal of S4 is to improve computational efficiency, reduce the number of parameters, and enhance learning by introducing constraints on the state transition matrices.

To efficiently handle long sequences, the Structured State Space (S4) model employs a structured parameterization of its state transition matrices. Specifically, it uses a Diagonal Plus Low-Rank (DPLR) parameterization, which combines the simplicity of diagonal matrices with the expressiveness of low-rank updates. This approach allows S4 to learn complex state spaces while maintaining computational efficiency, making it particularly effective for tasks requiring long-range dependencies.

However, a significant critique of the S4 model is its complexity. The model relies on multiple reduction steps and advanced linear algebraic techniques to compute the state space efficiently. While these optimizations are crucial for performance, they also make S4 challenging to understand, implement, and analyze. This complexity can act as a barrier for researchers and practitioners seeking to adapt or build upon the model, highlighting a trade-off between efficiency and accessibility.

Some of the applications of S4 are language and audio processing (text generation, speech recognition, music and voice synthesis), forecasting (predicting stock prices in financial markets, modelling climate information), and scientific computing (simulations in physics, chemistry, astronomy).

Diagonal State Space Model (DSS)

As mentioned above, S4 relies on the Diagonal Plus Low-Rank (DPLR) structure of the state matrix [math]\displaystyle{ A }[/math], which improves efficiency but introduces significant complexity in both understanding and implementation. To simplify S4 while maintaining its performance, the Diagonal State Space Model (DSS) was proposed. It simplifies the state matrix by replacing it with a purely diagonal form, reducing computational overhead and improving interpretability, while maintaining the ability to capture long-range dependencies.

[math]\displaystyle{ \mathbf{Improvements\ Compared\ with\ S4} }[/math]

  • Computational Efficiency (and Model Complexity)
    • The DSS model offers a significant computational advantage by simplifying the DPLR state matrix to a purely diagonal form. This diagonalization enables more efficient recurrence computations and reduces the number of parameters to be estimated. This simplicity allows DSS to scale well to large datasets and real-time systems.
  • Interpretability
    • The diagonal structure makes it clear how previous hidden states and inputs affect the current state, which helps in understanding the relationships between states. For example, it can be shown that under certain conditions DSS can capture information from extremely distant positions in the sequence. This simplicity enhances the transparency of the model, providing clear insights into system behavior and making the underlying dynamics easier to interpret. By comparison, S4 is harder to interpret due to its more complex parameterization.

[math]\displaystyle{ \mathbf{Applications} }[/math]

Despite its simplifications, DSS retains strong modeling capabilities for a variety of tasks, including time-series forecasting, raw speech classification, and natural language processing (NLP). Its ability to model temporal dependencies with minimal complexity makes it a highly efficient choice for applications where computational efficiency and interpretability are critical. While S4 excels in capturing intricate, long-range dependencies, DSS's performance is nearly as strong, and in certain cases, it even outperforms S4.

[math]\displaystyle{ \mathbf{Limitations} }[/math]

The initialization of parameters can significantly influence the performance of DSS, as improper initialization may lead to slower convergence or suboptimal solutions. Additionally, DSS tends to be less effective at modeling information-dense data, such as text, where complex patterns and intricate relationships between words or phrases are crucial. In these cases, more sophisticated models may be helpful. However, DSS still provides a viable alternative in scenarios where simplicity and efficiency are prioritized over capturing deep contextual relationships.

Hungry Hungry Hippos (H3)

SSMs vs Attention

Research has shown that SSMs achieve state-of-the-art performance in domains like speech recognition and audio generation. However, they underperform attention in language modelling, for two main reasons:

  • Expressivity gap
  • Poor hardware utilization

Expressivity Gap

To understand the gap between SSMs and attention on language modelling, the authors examined two synthetic language modelling tasks:

The Induction Head task tests how well a model can recall content after a special token. At the end of the sequence, the model must recall the token that appeared immediately after the special token earlier in the sequence. In the table above, we are trying to recall the first token after the ⊢ symbol, which is f.

Associative Recall is similar to the induction head task, but requires the model to remember multiple key-value pairs. At the end of the sequence, the model must recall a specific value belonging to a specific key. In the table above, we are trying to recall the value associated with character a, which is 2.


The table above shows the performance of S4D (an S4 variant specialized for language modelling), Gated State Spaces (GSS), and attention. We can see that attention completes both synthetic tasks perfectly, achieving an accuracy of 100% and significantly outperforming S4D and GSS. The failure of SSMs can be attributed to two missing capabilities: (i) the ability to remember tokens that appear after a particular event (e.g., the special token in the induction head task), and (ii) the ability to compare tokens across the sequence (e.g., comparing keys to decide which value to recall). Attention has both of these capabilities: it can compare tokens by constructing the quadratic attention matrix [math]\displaystyle{ \mathbf{QK^T} }[/math], and it can recall tokens by direct copying (multiplying [math]\displaystyle{ softmax(\mathbf{QK^T}) }[/math] with [math]\displaystyle{ \mathbf{V} }[/math]).


H3 Design

H3 is designed to add these capabilities to SSMs; its structure is as follows:

The H3 layer consists of two SSMs stacked together.

The first SSM is the shift SSM, which you can think of as performing a local lookup across the sequence; its role is made precise below.

The diagonal SSM serves as a kind of global memory that keeps track of important information.

Multiplicative interactions (the elementwise multiplications between the branches) give H3 the ability to compare tokens across the sequence.


In the shift SSM, we constrain [math]\displaystyle{ \mathbf{A} ∈ R^{m×m} }[/math] to be a shift matrix as defined by its entries: [math]\displaystyle{ \mathbf{A}_{i,j} = \begin{cases} 1 & \text{for } i - 1 = j,\\ 0 & \text{otherwise}. \end{cases} }[/math]. If you draw it out, it would be a square matrix where the 1s appear directly below the main diagonal with all other entries 0. The action of this matrix on the hidden state [math]\displaystyle{ x_i }[/math] is to shift each coordinate down by one—thereby creating a “memory” of the previous states. If [math]\displaystyle{ \mathbf{B} = \textit{e}_1 }[/math], the first basis vector, then [math]\displaystyle{ x_i = [u_i, u_{i-1}, . . . , u_{i-m+1}] }[/math] contains the inputs from the previous m time steps. Both [math]\displaystyle{ \mathbf{B} }[/math] and [math]\displaystyle{ \mathbf{C} }[/math] are learnable matrices, but [math]\displaystyle{ \mathbf{B} }[/math] is usually fixed to [math]\displaystyle{ \textit{e}_1 }[/math] for simplicity, in which case the output is a 1D convolution with kernel size m.
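
To see the "local lookup" behaviour concretely, here is a small NumPy sketch of the shift SSM state update with [math]\displaystyle{ \mathbf{B} = e_1 }[/math]; the buffer length and the inputs are arbitrary.

<syntaxhighlight lang="python">
import numpy as np

m = 5
A = np.eye(m, k=-1)          # shift matrix: ones directly below the main diagonal
B = np.zeros(m); B[0] = 1.0  # B fixed to the first basis vector e_1

x = np.zeros(m)              # hidden state of the shift SSM
for u in [1.0, 2.0, 3.0, 4.0]:
    x = A @ x + B * u        # each step shifts the buffer down and writes u on top
print(x)                     # [4. 3. 2. 1. 0.] -> the last m inputs, newest first
</syntaxhighlight>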


The diagonal SSM constrains A to be diagonal and initializes it from the diagonal version of HiPPO. This parameterization allows the model to remember state over the entire sequence. The shift SSM can detect when a particular event occurs, and the diagonal SSM can remember a token afterwards for the rest of the sequence.

Mamba

Mamba is a technique that builds on S4 models. It was introduced to increase efficiency for long sequences by using a selection mechanism over the state space, allowing the model to save memory and computational cost by not focusing on irrelevant information. It does so by combining the H3 block traditionally used in SSM models with a gated multilayer perceptron (MLP), as shown in the figure below.

Figure 3: Mamba architecture

Recall that SSM models calculate the output using the following formulae: [math]\displaystyle{ h'(t) = \mathbf{A} h(t) + \mathbf{B} x(t) }[/math]

[math]\displaystyle{ y(t) = \mathbf{C} h(t) + \mathbf{D} x(t) }[/math]

These equations are discretized using an additional step-size parameter [math]\displaystyle{ \Delta }[/math] as shown above, resulting in the updated formulae:

[math]\displaystyle{ h_t = \bar{\mathbf{A}} h_{t-1}+\bar{\mathbf{B}}x_t }[/math]

[math]\displaystyle{ y_t = \mathbf{C} h_t }[/math]

The matrices may be understood as follows:

  • [math]\displaystyle{ \bar{\mathbf{A}} }[/math] represents the transition matrix between discrete hidden states
  • [math]\displaystyle{ \bar{\mathbf{B}} }[/math] processes the input data in order to update the hidden state
  • [math]\displaystyle{ \mathbf{C} }[/math] interprets the hidden state in order to generate the output.

Mamba introduces a gating function [math]\displaystyle{ \Delta }[/math] which is learnable, and the hidden states are updated according to [math]\displaystyle{ h(t) = (1-\Delta)h(t-1) + \Delta x(t) }[/math]
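
The toy sketch below illustrates this input-dependent gating idea; the sigmoid squashing and the linear projection used to produce [math]\displaystyle{ \Delta }[/math] are illustrative simplifications, not the exact parameterization used in Mamba.

<syntaxhighlight lang="python">
import numpy as np

def selective_gated_scan(x_seq, w_delta, b_delta):
    """Toy selective recurrence: Delta_t depends on the input x_t, and
    h_t = (1 - Delta_t) * h_{t-1} + Delta_t * x_t (elementwise per channel)."""
    def sigmoid(z):
        return 1.0 / (1.0 + np.exp(-z))

    h = np.zeros_like(x_seq[0])
    states = []
    for x_t in x_seq:
        delta_t = sigmoid(w_delta @ x_t + b_delta)  # input-dependent gate in (0, 1)
        h = (1.0 - delta_t) * h + delta_t * x_t     # small gate keeps memory, large gate overwrites
        states.append(h)
    return np.stack(states)

rng = np.random.default_rng(0)
d = 6
x_seq = rng.normal(size=(10, d))
states = selective_gated_scan(x_seq, w_delta=rng.normal(size=(d, d)), b_delta=np.zeros(d))
print(states.shape)  # (10, 6)
</syntaxhighlight>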


Selective State Space Models in Mamba

Mamba builds on the success of structured state space models (SSMs) — a family of models that, by design, efficiently capture long-range dependencies through a recurrent or convolutional mechanism. In traditional SSMs, a key benefit is their linear-time scaling, achieved by compressing the entire sequence context into a fixed hidden state. However, this very efficiency can be a double-edged sword: by forcing the state to be constant over time (a property known as linear time invariance or LTI), the model may miss crucial content-specific cues in the data.

Mamba addresses this limitation with a clever twist—it introduces a selection mechanism that makes certain SSM parameters depend on the input. In simple terms, while standard SSMs update their hidden state with the same fixed rules at every time step, the selective variant modulates these update rules based on the current input. This allows the model to "decide" when to ignore irrelevant information and when to focus on key inputs, similar to how gating works in recurrent neural networks.

The mechanism works by turning parameters that usually remain constant (like the discretization parameter [math]\displaystyle{ \Delta }[/math] and others controlling the transition between hidden states) into functions of the input. This extra flexibility means that rather than applying a one-size-fits-all transformation to every token, the model tailors its processing dynamically across the sequence. The result is a network that not only maintains the efficiency of linear scaling but also improves accuracy by being content-aware—even for very long sequences.

An important innovation in Mamba is that it leverages a hardware-aware algorithm to keep the additional computational cost low. Instead of materializing huge intermediate states, the algorithm uses optimized memory strategies. In practice, this means that Mamba can run several times faster than comparable Transformer-based models, especially on longer inputs.

In summary, Mamba's selective state space models marry the efficiency of classical SSMs with the flexibility of input-dependent gating. This combination allows the architecture to handle extremely long sequences—up to millions of tokens—while still capturing the nuanced patterns essential for tasks ranging from language modeling to audio generation.


Mamba-2

Semiseparable Matrices

Semiseparable Matrices are matrices that have this form:

[math]\displaystyle{ \begin{bmatrix} C_0^\top B_0 & \\ C_1^\top A_1 B_0 & C_1^\top B_1 & \\ C_2^\top A_2A_1 B_0 & C_2^\top A_2 B_1 & C_2^\top B_2 \\ \vdots & \vdots & \ddots & \ddots \\ C_\mathtt{T}^\top A_{\mathtt{T}-1}\dots A_1 B_0 & C_\mathtt{T}^\top A_{\mathtt{T}-1}\dots A_2 B_1 & \dots & C_\mathtt{T}^\top A_{\mathtt{T}-1} B_{\mathtt{T}-2} & C_\mathtt{T}^\top B_{\mathtt{T}-1} \\ \end{bmatrix} }[/math]

By observation, we can immediately see that all algorithms for computing state space models can be viewed as structured matrix multiplication algorithms on semiseparable matrices. Such matrices have the property that every submatrix contained in the lower-triangular portion is low rank, so we can write them as:

[math]\displaystyle{ \begin{bmatrix} C_j^T A_{j,i}^X B_{i'} & \cdots & C_j^T A_{j,i-1}^X B_{i-1} \\ \vdots & \ddots & \vdots \\ C_{j'-1}^T A_{j'-1,i}^X B_{i'} & \cdots & C_{j'-1}^T A_{j'-1,i-1}^X B_{i-1} \end{bmatrix} = \begin{bmatrix} C_j^T A_{j,j}^X \\ \vdots \\ C_{j'-1}^T A_{j'-1,j}^X \end{bmatrix} A_{j,i-1}^X \begin{bmatrix} A_{i-1,i'}^X B_{i'} & \cdots & A_{i-1,i-1}^X B_{i-1} \end{bmatrix}. }[/math]

This leads to an efficient algorithm for computing with semiseparable matrices.

The matrix multiplication has an interpretation as a state space model, where blocks represent chunking of the input and output sequences. Diagonal blocks represent intra-chunk computations and the off-diagonal blocks represent inter-chunk computations, factored through the SSM's hidden state.
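
As a sanity check of this matrix view, the sketch below builds the semiseparable matrix for the scalar-[math]\displaystyle{ A_t }[/math] (SSD) case and verifies that multiplying it by the input reproduces the recurrent SSM output; the dimensions and random parameters are arbitrary.

<syntaxhighlight lang="python">
import numpy as np

def ssm_recurrent(a, B, C, x):
    """Scalar-A SSM: h_t = a_t * h_{t-1} + B_t * x_t, y_t = C_t^T h_t."""
    T, d = B.shape
    h = np.zeros(d)
    y = np.zeros(T)
    for t in range(T):
        h = a[t] * h + B[t] * x[t]
        y[t] = C[t] @ h
    return y

def semiseparable_matrix(a, B, C):
    """M[t, s] = C_t^T (a_t * ... * a_{s+1}) B_s for t >= s, and 0 otherwise."""
    T = len(a)
    M = np.zeros((T, T))
    for t in range(T):
        for s in range(t + 1):
            decay = np.prod(a[s + 1:t + 1])   # product a_{s+1} ... a_t (empty product = 1)
            M[t, s] = decay * (C[t] @ B[s])
    return M

rng = np.random.default_rng(0)
T, d = 8, 4
a = rng.uniform(0.5, 1.0, size=T)
B, C = rng.normal(size=(T, d)), rng.normal(size=(T, d))
x = rng.normal(size=T)
print(np.allclose(ssm_recurrent(a, B, C, x), semiseparable_matrix(a, B, C) @ x))  # True
</syntaxhighlight>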

SSD Algorithm: "Block Matrix Decomposition" and "Chunking and State Passing"

  1. We first partition the semiseparable matrix into blocks of size [math]\displaystyle{ \mathtt{Q} \times \mathtt{Q} }[/math].
  2. Each diagonal block can be computed with the quadratic (attention-like) form, but since [math]\displaystyle{ \mathtt{Q} }[/math] is small, this is hardware efficient (orange in the figure).
  3. Then, we use the properties of semiseparable matrices to factorize each off-diagonal block, which is low rank (blue, green and yellow in the figure).
  4. Following the workflow shown, we obtain an efficient algorithm that can even incorporate system-level optimization by running the pieces on different GPUs.

Structured Masked Attention

In RetNet, we used a decay factor, which can be seen as a mask on [math]\displaystyle{ \mathtt{Q} \mathtt{K}^\top }[/math]. Inspired by linear attention and RetNet, we can generalize the idea of using a mask and introduce Structured Masked Attention (SMA). Definition: structured masked attention (SMA) (or structured attention for short) is defined as a function on queries/keys/values Q, K, V as well as any structured matrix L, through the 4-way tensor contraction:

[math]\displaystyle{ Y = (L \circ Q K^\top) V }[/math]

We can use a special mask called a 1-semiseparable (1-SS) matrix, defined as follows:

[math]\displaystyle{ L = \begin{bmatrix} 1 & \\ a_1 & 1 & \\ a_2a_1 & a_2 & 1 \\ \vdots & \vdots & \ddots & \ddots \\ a_{\mathtt{T}-1}\dots a_1 & a_{\mathtt{T}-1}\dots a_2 & \dots & a_{\mathtt{T}-1} & 1 \\ \end{bmatrix} . }[/math]
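
The sketch below builds this 1-semiseparable mask and applies the quadratic SMA form [math]\displaystyle{ Y = (L \circ Q K^\top) V }[/math]; choosing a constant [math]\displaystyle{ a_t = \gamma }[/math] recovers RetNet's decay mask. The shapes and random inputs are purely illustrative.

<syntaxhighlight lang="python">
import numpy as np

def one_ss_mask(a):
    """1-semiseparable mask: L[t, s] = a_t * ... * a_{s+1} for t >= s (1 on the diagonal), else 0."""
    T = len(a)
    L = np.zeros((T, T))
    for t in range(T):
        for s in range(t + 1):
            L[t, s] = np.prod(a[s + 1:t + 1])   # empty product on the diagonal gives 1
    return L

def structured_masked_attention(Q, K, V, L):
    """Quadratic form of SMA: Y = (L o Q K^T) V."""
    return (L * (Q @ K.T)) @ V

rng = np.random.default_rng(0)
T, d = 8, 4
Q, K, V = rng.normal(size=(T, d)), rng.normal(size=(T, d)), rng.normal(size=(T, d))
a = np.full(T, 0.9)                             # constant a_t = gamma gives RetNet's decay mask
Y = structured_masked_attention(Q, K, V, one_ss_mask(a))
print(Y.shape)  # (8, 4)
</syntaxhighlight>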

SSMs and SMA intersect at a large class of state space dual models (SSD) which capture many sequence models as special cases.

State Space Duality (SSD)

We can recognize that if we use a 1-SS matrix as the mask for SMA, and if we restrict the SSM's [math]\displaystyle{ A }[/math] to a scalar-times-identity structure (that is, [math]\displaystyle{ A_t }[/math] acts as a scalar [math]\displaystyle{ a_t }[/math] at time [math]\displaystyle{ t }[/math]), then the two are the same thing. This leads to State Space Duality.

We can view the SSD as:

1. A state space model computed through structured matrix multiplication

2. A generalized linear attention mechanism computed through masking

Overall, it is an instance that has dual quadratic and linear forms that can be derived from either representation.

This dual form allows us to borrow extensions from linear attention, such as multi-head attention, grouped-value attention, kernel approximations to softmax attention, and so on.

Mamba-2 architecture (Right) compared with Mamba (Left).

General SSMs and SMA both possess linear and quadratic forms, with direct analogs in notation.

Mamba-2 architecture

Using SSD and a few small changes to Mamba's neural network architecture, we obtain the Mamba-2 architecture.

One change is the use of parallel parameter projections. In Mamba-2, the SSD layer is viewed as a single map from its inputs to its output, so it makes sense to produce [math]\displaystyle{ A, B, C, X }[/math] in parallel with a single projection at the beginning of the block.

Another change is an extra normalization layer at the end of the block; it could be any normalization, such as LayerNorm, GroupNorm, or RMSNorm.


Key Takeaway

State Space Models (SSMs) offer a scalable alternative to transformers in language modeling by using state space equations to manage and update hidden states for generating outputs from inputs. Since their inception, there have been several important model enhancements:

SSM model evolution over time:

  • 2022, Structured State Space (S4): Leveraged Diagonal Plus Low-Rank (DPLR) parameterization to efficiently handle long sequences, but is challenging to understand and implement
  • 2022, Diagonal State Spaces (DSS): Simplified S4 by using diagonal state matrices while achieving comparable performance
  • 2023, Hungry Hungry Hippos (H3): Introduced a new SSM layer consisting of two SSMs stacked together (a shift SSM and a diagonal SSM) to enable token recall and comparison across sequences
  • 2024, Mamba: Integrated selective state spaces to enable dynamic, input-dependent parameter adjustment with linear-time complexity
  • 2024, Mamba-2: Used state space duality (SSD) to enable larger state dimensions and faster training

Topic 8: Sparse Attention

Introduction

Vanilla attention is computationally expensive due to the multiplication of very large matrices, resulting in a complexity of [math]\displaystyle{ O(n^2) }[/math] in the sequence length n, which may be very large; this limits the scalability of the method for long sequences. Sparse attention tries to address this problem by being more selective about which attention scores are computed. Intuitively, not all tokens need to attend to each other, because not all words provide semantic information about all other words in a sentence. How can we determine which token pairs are important? The goal of sparse attention is to answer this question in order to improve the efficiency and scalability of vanilla attention, by not computing attention between all token pairs, without sacrificing performance.

Four methods were presented:

  • Sparse Sinkhorn Attention
  • Big Bird Sparse Attention
  • Attention with Linear Biases (ALiBi)
  • SpAtten

Sparse Sinkhorn Attention

To address the computational intensity of vanilla attention, especially for long input sequences, Sparse Sinkhorn Attention proposes three core ideas:

  • [math]\displaystyle{ \mathbf{Block} }[/math] [math]\displaystyle{ \mathbf{Partitioning} }[/math]
    • The input sequence with length [math]\displaystyle{ l }[/math] is split into [math]\displaystyle{ N_b }[/math] blocks, each containing [math]\displaystyle{ b }[/math] tokens.
  • [math]\displaystyle{ \mathbf{Block} }[/math] [math]\displaystyle{ \mathbf{Sorting} }[/math]
    • Input blocks are sorted based on their similarity to the query.
    • The Sinkhorn algorithm is then applied to generate a permutation matrix which ensures that relevant blocks are adjacent, and this algorithm also normalizes the sorting matrix to prevent overemphasis on certain input blocks.
  • [math]\displaystyle{ \mathbf{Sparse\ (local)} }[/math] [math]\displaystyle{ \mathbf{Attention} }[/math]
    • Instead of attending over the full sequence, each token only computes attention with the tokens within its block (and the tokens within its neighbouring blocks).
Overview of Sparse Sinkhorn Attention.

Sorting Network

The original idea of block-based local attention is to allow tokens to attend only to tokens within the same block. However, this restricts the global attention field and limits the ability of local attention models to model long-range dependencies. We mitigate this problem by neural sorting of blocks: the sorting is learned by a neural network that produces a permutation matrix. Concretely, a meta sorting network learns to sort sequences so as to enable efficient quasi-global local attention. The equation of the sorting network is:

[math]\displaystyle{ R = P(x) = \sigma(W_B \sigma(W_pX+b_p) +b_B) }[/math]

Here, [math]\displaystyle{ R }[/math] becomes the permutation matrix, [math]\displaystyle{ \sigma }[/math] is an activation function and [math]\displaystyle{ W }[/math], [math]\displaystyle{ b }[/math] are the weights and biases. We can see that this is just a two-layer fully-connected neural network.

Sinkhorn normalization with Gumbel noise

We normalize the rows and columns of the sorting matrix using the following iterations:

[math]\displaystyle{ S^0(R) = \exp\left(\frac{R + \epsilon}{\tau}\right) }[/math]

[math]\displaystyle{ S^k(R) = F_c(F_r(S^{k-1}(R))) }[/math]

where [math]\displaystyle{ F_r }[/math] and [math]\displaystyle{ F_c }[/math] are the row-wise and column-wise normalization functions, [math]\displaystyle{ \epsilon }[/math] is injected standard i.i.d. Gumbel noise, and [math]\displaystyle{ \tau }[/math] is the temperature hyperparameter.
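
A minimal sketch of this Gumbel-Sinkhorn step is shown below; the number of iterations, the temperature, and the helper name gumbel_sinkhorn are illustrative choices rather than the paper's exact implementation.

<syntaxhighlight lang="python">
import numpy as np

def gumbel_sinkhorn(R, tau=1.0, n_iters=20, rng=None):
    """Approximately project block-sorting scores R onto a doubly-stochastic (soft permutation) matrix."""
    rng = rng or np.random.default_rng()
    gumbel = rng.gumbel(size=R.shape)             # standard i.i.d. Gumbel noise (epsilon)
    S = np.exp((R + gumbel) / tau)                # S^0
    for _ in range(n_iters):
        S = S / S.sum(axis=1, keepdims=True)      # row-wise normalization F_r
        S = S / S.sum(axis=0, keepdims=True)      # column-wise normalization F_c
    return S

rng = np.random.default_rng(0)
scores = rng.normal(size=(4, 4))                  # one score per (block, block) pair
P = gumbel_sinkhorn(scores, tau=0.5, rng=rng)
print(P.sum(axis=0), P.sum(axis=1))               # both close to all-ones: approximately doubly stochastic
</syntaxhighlight>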

Advantages over Vanilla Attention

  • Reduced memory
    • Sparse Sinkhorn Attention only needs [math]\displaystyle{ O\left(\left(\frac{l}{N_b}\right)^2 + N_b^2\right) }[/math], while Vanilla Attention requires [math]\displaystyle{ O(l^2) }[/math].
  • Reduced Computational Complexity
    • SORTCUT (a Sparse Sinkhorn Attention variant) only requires [math]\displaystyle{ O(lN_k) }[/math], where [math]\displaystyle{ N_k }[/math] is a hyperparameter which is much smaller than [math]\displaystyle{ l }[/math]. In contrast, Vanilla Attention requires [math]\displaystyle{ O(l^2) }[/math] time.
  • Ability to Capture Long Range Dependency
    • Through block sorting, relevant tokens are grouped adjacently even if they are originally far apart; then, through local attention, long-range dependencies are captured. By comparison, vanilla attention struggles with this since it calculates attention between every token and all other tokens, and softmax entropy collapse naturally reduces the importance of distant tokens, making long-range relationships difficult to model both computationally and statistically.

Big Bird Sparse Attention

Intuitively, we picture the tokens as nodes in a directed graph where attention is calculated between two tokens if an edge exists to connect their nodes. Then, using the adjacency matrix of this graph, attention is not calculated between disconnected nodes.

In order to decide which nodes to 'connect,' the authors combined several methods:

  • Random Attention
    • Nodes are connected randomly (whether attention is calculated for a particular token pair is determined randomly)
  • Window Attention
    • Intuitively, tokens that are closer together probably provide more semantic information about each other. Therefore, attention is calculated between tokens that occur within a particular distance from each other.
  • Global Attention
    • Some tokens are 'celebrities' so attention between them and all other tokens is calculated. This results in a maximum distance of 2 between any two tokens (nodes in the graph)

Theoretically, the authors prove that any star-graph (graph which represents global attention) provides a lower bound for performance. Empirically, combining all three of the above methods into what the authors call 'BIGBIRD' results in higher model performance.
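
A simplified sketch of how the three connection patterns can be combined into a single boolean attention mask is shown below; the window size, the number of global tokens, and the number of random connections are arbitrary hyperparameters used only for illustration.

<syntaxhighlight lang="python">
import numpy as np

def bigbird_mask(n, window=3, n_global=2, n_random=2, rng=None):
    """Boolean attention mask combining window, global, and random connections (a simplified sketch)."""
    rng = rng or np.random.default_rng()
    mask = np.zeros((n, n), dtype=bool)

    # Window attention: each token attends to neighbours within the window.
    for i in range(n):
        lo, hi = max(0, i - window), min(n, i + window + 1)
        mask[i, lo:hi] = True

    # Global attention: the first n_global tokens attend to, and are attended by, every token.
    mask[:n_global, :] = True
    mask[:, :n_global] = True

    # Random attention: each token additionally attends to a few random positions.
    for i in range(n):
        mask[i, rng.choice(n, size=n_random, replace=False)] = True
    return mask

mask = bigbird_mask(16)
print(mask.sum(), "of", 16 * 16, "score entries computed")  # far fewer than the full n^2
</syntaxhighlight>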

We may interpret the naming as:

  • "Big": The ability of handling long sequences and overcome complexity of higher orders
  • "Bird": A metaphor for a bird's-eye view, combining global tokens (the ability to observe the whole sequence), local sliding windows (the ability to focus on local context), and random attention (the ability to account for whatever connections are randomly left over)


Limitations

Although BigBird achieves linear complexity, it does not come for free.

  • Lower Bounds
    • The authors exhibit a natural task that a full attention mechanism can solve with O(1) layers. However, any sparse attention mechanism that uses only polynomially many inner-product evaluations and whose digraph has logarithmic shortest-path length between any two nodes (thus more than just BigBird) needs at least a polynomial number of layers to solve the task.

Attention with Linear Biases (ALiBi)

Currently, models struggle to extrapolate. Extrapolation is the ability to produce sequence lengths at inference time that are longer than the sequences they were trained on. For example, if a model is trained on a dataset where the longest sequence is 1024 tokens, then it will perform poorly when asked to generate a sequence of 2048 tokens. If we can solve this problem then, in theory, training can become much more efficient as we can train on shorter sequences without sacrificing performance.

The authors hypothesized that the position representation method used in these models influences extrapolation ability. They first tested the extrapolation performance of transformers that use absolute position encodings, including sinusoidal position embeddings (as in the vanilla transformer) and rotary position embeddings (as used in open-source GPT-3-style models). They found that these transformer models did not extrapolate well.

Then, they tested the extrapolation performance of a transformer that employed the T5 bias as its position method. This is a relative position method; it does not add positional information to the word embeddings. Instead, after computing the query-key scores in attention and before applying the softmax, it adds a learned, shared bias that depends on the distance between the query and the key. This model extrapolated noticeably better, and the authors concluded that a relative position method improves extrapolation. However, due to its high computational cost, the T5 bias is impractical.

ALiBi replaces the vanilla transformer [math]\displaystyle{ softmax(QK^T) }[/math] with [math]\displaystyle{ softmax(QK^T + bias) }[/math], adding a lower triangular matrix of biases in order to encode the relative positioning of tokens within the attention calculation.

The bias matrix is of the form [math]\displaystyle{ \begin{bmatrix} 0 & 0 & 0 & 0\\ -1 & 0 & 0 & 0\\ -2 & -1 & 0 & 0\\ -3 & -2 & -1 & 0\\ \end{bmatrix} }[/math]

The addition of this bias creates a recency bias: a query-key pair that is far apart receives a large negative bias (penalty). The increasing penalty means that less information is learned from distant query-key pairs.

The authors hypothesized that this could replace the positional encoding in transformers. For model architectures with multiple attention heads, the slope of these biases varies per head (defaulting to a geometric sequence such as [math]\displaystyle{ \frac{1}{2^m} }[/math], where [math]\displaystyle{ m }[/math] indexes the head).

So even though ALiBi is grouped with sparse attention methods, it does not impose sparsity explicitly: as tokens get farther apart, the penalty grows and their attention weights become negligible, so sparsity arises implicitly.
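
A small sketch of the ALiBi computation is shown below: a per-head linear penalty is added to the causal attention scores before the softmax. The particular slope schedule and the helper names are illustrative, not the reference implementation.

<syntaxhighlight lang="python">
import numpy as np

def alibi_bias(seq_len, n_heads):
    """Per-head ALiBi bias: -slope_m * (query position - key position) for earlier keys."""
    slopes = np.array([2.0 ** (-(m + 1)) for m in range(n_heads)])    # 1/2, 1/4, ... (illustrative schedule)
    distance = np.arange(seq_len)[:, None] - np.arange(seq_len)[None, :]
    distance = np.tril(distance)                                      # keep only causal (query >= key) distances
    return -slopes[:, None, None] * distance[None, :, :]              # shape (heads, seq, seq)

def alibi_attention(Q, K, V, bias):
    scores = Q @ K.swapaxes(-1, -2) / np.sqrt(Q.shape[-1]) + bias
    causal = np.tril(np.ones(bias.shape[-2:], dtype=bool))
    scores = np.where(causal, scores, -np.inf)                        # mask out future positions
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V

rng = np.random.default_rng(0)
H, T, d = 4, 6, 8
Q, K, V = (rng.normal(size=(H, T, d)) for _ in range(3))
out = alibi_attention(Q, K, V, alibi_bias(T, H))
print(out.shape)  # (4, 6, 8)
</syntaxhighlight>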

SpAtten

SpAtten is a method for pruning both tokens and attention heads in transformer models to improve efficiency. By removing tokens and heads with pruning, SpAtten reduces computational complexity and memory usage.

There are four key components to SpAtten:

1. Cascade Token Pruning

2. Cascade Head Pruning

3. Progressive Quantization

4. Specialized High Parallelism top-k Engine.

The first two are motivated by the inherent redundancy in human languages. Cascade token pruning identifies and removes unimportant tokens in a sentence. Similarly, cascade head pruning eliminates unnecessary attention heads. Unlike traditional weight pruning techniques, cascade token and cascade head pruning operate dynamically, selecting tokens and heads to prune on the fly during inference. As an example, given an initial sentence "As a visual treat, the film is almost perfect." with 11 tokens and 12 heads, cascade pruning results in a final output of "film perfect" with only 2 tokens and 8 heads. The final output retains the essence of the initial sentence but uses less memory and is less computationally expensive to process.

The latter two components play a pivotal role in ensuring the efficient implementation of SpAtten. Progressive quantization is designed to strike a balance between computational efficiency and accuracy by dynamically adjusting the precision of computations. The process begins by handling only the most significant bits (MSBs) of the input data, enabling faster computations and reduced memory access. If the confidence in the result is insufficient, the system progressively incorporates more bits of data to recompute the output with higher precision. This approach is further refined by assigning different bitwidths to different layers, optimizing performance across the model.

Similarly, SpAtten employs a specially designed top-k engine to enable real-time decision-making for pruning tokens and attention heads. This engine efficiently ranks token and head importance scores in linear time, avoiding the inefficiencies of traditional sorting methods. Instead of sorting an entire array, the top-k engine uses a quick-select module to identify the kth largest element and filters the array based on this threshold. This technique is also implemented in a highly parallelized manner, with multiple comparators working simultaneously to achieve linear-time ranking. This streamlined process ensures rapid and accurate pruning decisions, contributing to SpAtten's overall efficiency.
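
SpAtten itself is a hardware accelerator, but the scoring logic behind cascade token pruning can be sketched in software: tokens are ranked by the cumulative attention they receive, and only the top-k survive to later layers. The keep ratio and the function name below are illustrative assumptions.

<syntaxhighlight lang="python">
import numpy as np

def prune_tokens(attention_probs, cumulative_scores, keep_ratio=0.5):
    """Rank tokens by cumulative attention received and keep only the top-k (a software sketch)."""
    # attention_probs: (heads, seq, seq) softmax outputs from the current layer.
    cumulative_scores = cumulative_scores + attention_probs.sum(axis=(0, 1))  # importance = attention received
    k = max(1, int(keep_ratio * len(cumulative_scores)))
    kept = np.sort(np.argpartition(cumulative_scores, -k)[-k:])               # top-k indices, in original order
    return kept, cumulative_scores

rng = np.random.default_rng(0)
H, T = 4, 10
probs = rng.dirichlet(np.ones(T), size=(H, T))           # stand-in per-head attention distributions
kept, scores = prune_tokens(probs, np.zeros(T), keep_ratio=0.6)
print(kept)                                              # indices of the 6 tokens retained for later layers
</syntaxhighlight>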

Topic 10: Linear Attention

Linear attention tries to address the efficiency limitations of traditional softmax attention. Standard attention (also called vanilla attention) is calculated as:

[math]\displaystyle{ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V }[/math]

The computational complexity of this method is [math]\displaystyle{ O(n^2d) }[/math] for sequence length [math]\displaystyle{ n }[/math] and representation dimension [math]\displaystyle{ d }[/math]; the main reason is that the multiplication of [math]\displaystyle{ Q }[/math] and [math]\displaystyle{ K^T }[/math] produces a large [math]\displaystyle{ n \times n }[/math] attention matrix. If we remove the softmax function from this equation, we can reduce the computational cost because the matrix multiplication can be reordered as [math]\displaystyle{ Q(K^T V) }[/math]: multiplying [math]\displaystyle{ K^T V }[/math] first results in much smaller matrix multiplications. However, this purely linear mapping is insufficient to capture the complex relationships in the data. To reintroduce expressiveness, we can replace the softmax function with a kernel that approximates it. Recall that any kernel [math]\displaystyle{ K }[/math] satisfies [math]\displaystyle{ K(x, y)= \Phi(x)^T \Phi(y) }[/math] for some feature map [math]\displaystyle{ \Phi }[/math]. If we apply such a feature map to Q and K, we can approximate vanilla attention with a much more efficient mechanism.
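
A minimal sketch of this kernelized reordering is shown below, using the common elu(x)+1 feature map as a stand-in for [math]\displaystyle{ \Phi }[/math]; the choice of feature map and the non-causal formulation are illustrative assumptions, not a specific paper's implementation.

<syntaxhighlight lang="python">
import numpy as np

def feature_map(x):
    """A simple positive feature map phi(x) = elu(x) + 1, one common choice for approximating softmax."""
    return np.where(x > 0, x + 1.0, np.exp(x))

def linear_attention(Q, K, V):
    """Compute phi(Q) (phi(K)^T V) with a per-query normalizer -- O(n d^2) instead of O(n^2 d)."""
    Qf, Kf = feature_map(Q), feature_map(K)
    KV = Kf.T @ V                              # (d, d_v): built once, no n x n matrix is formed
    Z = Qf @ Kf.sum(axis=0)                    # (n,): per-query normalizer replacing the softmax denominator
    return (Qf @ KV) / Z[:, None]

rng = np.random.default_rng(0)
n, d = 128, 16
Q, K, V = rng.normal(size=(n, d)), rng.normal(size=(n, d)), rng.normal(size=(n, d))
print(linear_attention(Q, K, V).shape)         # (128, 16)
</syntaxhighlight>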

Figure 1: Impossible Triangle

Key Approaches to Linear Attention

Retentive Network (RetNet): A Successor to Transformer for Large Language Models

The authors of this paper introduce a new architecture that can achieve parallel training and fast inference without compromising performance, addressing the inherent complexity of the vanilla attention mechanism. They claim that the new architecture makes the "impossible triangle" possible. RetNet combines the performance and parallel-training ability of transformers with the lower-cost inference of RNNs by introducing a multi-scale retention mechanism, replacing multi-head attention, that incorporates three main representations:

  1. Parallel: allows GPU utilization during training.
  2. Recurrent: enables low-cost inference with [math]\displaystyle{ O(1) }[/math] complexity in terms of computation and memory.
  3. Chunkwise recurrent: models long-sequences efficiently.

We can show that the retention mechanism has a dual form of recurrence and parallelism. First, let's write RetNet in its recurrent form:


[math]\displaystyle{ S_n = \gamma S_{n-1} + K_n^T V_n }[/math]

[math]\displaystyle{ \text{Retention}(X_n) = Q_n S_n }[/math]


In this formula, [math]\displaystyle{ \gamma }[/math] is a decay factor, and [math]\displaystyle{ Q }[/math], [math]\displaystyle{ K }[/math], [math]\displaystyle{ V }[/math] are learned projections of the input. During training, this can be computed in parallel form. Now, we can show the parallel representation of retention:


[math]\displaystyle{ \text{Retention}(X) = (QK^T \odot D)V }[/math]


Where [math]\displaystyle{ D }[/math] is a matrix that combines causal masking with exponential decay over relative distance. To better understand the role of the matrix [math]\displaystyle{ D }[/math], let's look at its formula:


[math]\displaystyle{ D_{nm} = \begin{cases} \gamma^{n-m}, & n \geq m \\ 0, & n \lt m \end{cases} }[/math]


This formulation ensures that tokens can only attend to previous tokens (causal masking) and that the attention strength decays exponentially with distance (controlled by the parameter [math]\displaystyle{ \gamma }[/math]).

RetNet consists of L identical stacked blocks, similar to the Transformer. Each block consists of two modules: a multi-scale retention (MSR) module and a feed-forward network (FFN) module. It takes the embeddings of the input sequence [math]\displaystyle{ X^0= [x_1, \dots, x_{|x|}] \in \mathbb R^{|x| \times d_{model}} }[/math] and computes the output [math]\displaystyle{ X^L }[/math]:

[math]\displaystyle{ Y^l = MSR(LN(X^l)) + X^l }[/math], [math]\displaystyle{ X^{l+1} = FFN(LN(Y^l)) + Y^l }[/math] where [math]\displaystyle{ LN(.) }[/math] is LayerNorm.
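
To make the duality concrete, the sketch below implements both the parallel form [math]\displaystyle{ (QK^\top \odot D)V }[/math] and the recurrent form [math]\displaystyle{ S_n = \gamma S_{n-1} + K_n^\top V_n }[/math], [math]\displaystyle{ \text{Retention}(X_n) = Q_n S_n }[/math], and checks that they agree; the learned projections, per-head decay values, and normalization used in the actual model are omitted for simplicity.

<syntaxhighlight lang="python">
import numpy as np

def retention_parallel(Q, K, V, gamma):
    """Parallel form: (Q K^T o D) V with D[n, m] = gamma^(n-m) for n >= m, else 0."""
    T = Q.shape[0]
    n, m = np.arange(T)[:, None], np.arange(T)[None, :]
    D = np.where(n >= m, float(gamma) ** (n - m), 0.0)
    return (Q @ K.T * D) @ V

def retention_recurrent(Q, K, V, gamma):
    """Recurrent form: S_n = gamma * S_{n-1} + K_n^T V_n, output_n = Q_n S_n."""
    S = np.zeros((Q.shape[1], V.shape[1]))
    out = []
    for q, k, v in zip(Q, K, V):
        S = gamma * S + np.outer(k, v)    # constant-size state: O(1) memory per generated token
        out.append(q @ S)
    return np.stack(out)

rng = np.random.default_rng(0)
T, d = 6, 4
Q, K, V = rng.normal(size=(T, d)), rng.normal(size=(T, d)), rng.normal(size=(T, d))
print(np.allclose(retention_parallel(Q, K, V, 0.9), retention_recurrent(Q, K, V, 0.9)))  # True
</syntaxhighlight>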

Simple linear attention language models balance the recall-throughput tradeoff

This paper focuses on the task of associative recall and discusses the memory-recall tradeoff, which states that higher recall requires higher memory consumption, as shown in Figure 2.

Figure 2: Empirical performance showing the memory-recall tradeoff

The authors' goal was to develop a method that improves this tradeoff. They do so by stacking linear attention blocks with sliding-window attention blocks. Recall that one of the limitations of linear attention is its relatively poor recall performance compared to vanilla attention. With this architecture, the authors accept some loss in global precision by letting linear attention capture global dependencies, while retaining high performance within a local window by using sliding-window attention to capture local dependencies. As seen in Figure 4, they achieve this goal.