Notes on Presentations
Group 1 Presentation: Universal Physics-Informed Neural Networks: Symbolic Differential Operator Discovery with Sparse Data
Paper Citation
Podina, L., Eastman, B., & Kohandel, M. (2023). Universal Physics-Informed Neural Networks: Symbolic Differential Operator Discovery with Sparse Data. In Proceedings of the 40th International Conference on Machine Learning (Vol. 202). PMLR, Honolulu, Hawaii, USA.
Background
Differential equations
Examples of differential equations in physics include Newton's second law (which is an ordinary differential equation), the Navier-Stokes equations (which are partial differential equations), etc.
Existing methods of solving differential equations:
- Analytical methods, such as integration or separation of variables.
- Numerical methods, such as finite difference, finite volume, or finite elements.
- Data-driven approaches: these involve Universal Differential Equations (UDEs) and Physics-Informed Neural Networks (PINNs), which are the focus of this paper.
Introduction to PINNs
With (many) machine learning approaches, the goal is to approximate the solution to a DE using a feed-forward neural network, optimized with MSE loss. The key difference that makes it physics-informed is an extra term in the loss, which penalizes the model for deviating from the governing DE.
Introduction to UDEs
Here, the differential equation is expressed as a sum of two terms: the known physics-based model and an unknown neural network.
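Schematically (our notation for a generic first-order system, not taken from the paper), a UDE writes the dynamics as a trusted physics term plus a learned correction:

[math]\displaystyle{ \frac{du}{dt} = f_{\text{known}}(u, t) + \mathcal{N}_{\theta}(u, t), }[/math]

where [math]\displaystyle{ f_{\text{known}} }[/math] encodes the physics we already trust and [math]\displaystyle{ \mathcal{N}_{\theta} }[/math] is a neural network that captures the unknown dynamics.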
Paper Contributions
Universal Physics-Informed Neural Networks (UPINNs)
PINNs and UDEs are combined, addressing the limitations of the original methods, while sharing their benefits.
The model integrates three network components:
- Surrogate Solution Network U: links to the measurement loss
- Unknown Differential Operator Network F: combined with U within the PINN loss
- Boundary Condition Network B: links to the boundary loss
The loss function contains three terms (a combined form is sketched after the list):
- MSE
- Boundary loss, if boundary conditions are provided with the problem.
- PINN loss: ensures the model respects the governing differential equation.
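Putting the pieces together, the total training objective has the schematic form below (the weights [math]\displaystyle{ \lambda_B, \lambda_F }[/math] are illustrative notation, not values specified by the paper):

[math]\displaystyle{ \mathcal{L} = \mathcal{L}_{\text{MSE}} + \lambda_B \, \mathcal{L}_{\text{boundary}} + \lambda_F \, \mathcal{L}_{\text{PINN}} }[/math]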
Summaries highlighting key points of the paper
UPINN combines PINNs and UDEs, bridging the limitations of both approaches. UPINN consumes less computational power than a PINN, is robust to noise, and can perform decently in low-data settings. However, UPINN still requires notable computational resources and is sensitive to the choice of hyperparameters. Moreover, UPINN has low interpretability.
Experimental Validation
1. Lotka-Volterra Model
They first experimented with the UPINN on the Lotka-Volterra system of differential equations, which are used to model predator-prey dynamics:
[math]\displaystyle{ \frac{dx}{dt} = \alpha x - \beta xy }[/math]
[math]\displaystyle{ \frac{dy}{dt} = -\delta y + \gamma xy }[/math]
The UDE and PINN were individually tested on two scenarios: sparse data (where there are very few input data points) and noisy data. Alone, each model performed poorly, especially when the data was very sparse or very noisy. When the UPINN was used, the solution was quite accurate, even with high sparsity or noise.
2. Viscous Burgers’ Equation
Their next experiment used the viscous Burgers' equation, a canonical system in fluid dynamics.
[math]\displaystyle{ \frac{\partial u}{\partial t} = -u \frac{\partial u}{\partial x} + \nu \frac{\partial^2 u}{\partial x^2} }[/math]
3. Cell Apoptosis Model
Summaries of key points
This paper introduces Universal Physics-Informed Neural Networks (UPINNs) for discovering unknown terms in differential equations (ODEs/PDEs) from sparse and possibly noisy data. It combines the strengths of standard Physics-Informed Neural Networks (PINNs)—which incorporate prior knowledge of the governing equations—while still allowing parts of the underlying model to remain unknown and be learned from the data. Unlike previous methods such as Universal Differential Equations (UDEs), which can falter in noisy and small-data regimes, UPINNs maintain good accuracy by:
1. Leveraging collocation points in the loss function to incorporate the differential equation constraints ("physics").
2. Adding a neural network component to represent the unknown terms of the operator.
3. Applying symbolic regression (e.g., AI Feynman) to convert the neural approximation of the hidden terms into interpretable, closed-form expressions.
Extensive experiments on the Lotka–Volterra system, a viscous Burgers’ PDE, and a cell apoptosis ODE show that UPINNs outperform UDEs in handling higher noise and fewer data points, while still recovering the hidden differential-operator terms accurately.
Furthermore, symbolic regression improves interpretability by converting neural outputs into explicit equations. This interpretability, combined with robustness to sparse and noisy data, makes UPINNs especially promising for scientific discovery. Potential applications include systems biology, fluid dynamics, and environmental modeling. Future research directions could address scalability to higher-dimensional PDEs and uncertainty quantification.
Related work
Nonlocal Physics-Informed Neural Networks (nPINNs): nPINNs introduce a universal nonlocal Laplace operator that encompasses classical and fractional Laplacians. This framework is utilized for parameter identification in nonlocal models, demonstrating consistency and accuracy in capturing operator behaviours.
Group 1 Presentation: Universal Physics-Informed Neural Networks: Symbolic Differential Operator Discovery with Sparse Data
Paper Citation
Podina, L., Eastman, B., & Kohandel, M. (2023). Universal Physics-Informed Neural Networks: Symbolic Differential Operator Discovery with Sparse Data. In Proceedings of the 40th International Conference on Machine Learning (Vol. 202). PMLR, Honolulu, Hawaii, USA.
Background
In many scientific problems, we model systems using differential equations. But in practice, we often don’t know the full form of these equations, and we rarely have clean, sufficient data to work with. This makes it hard to apply standard data-driven approaches or even physics-informed models that assume the structure is already known. The goal of this paper is to develop a method that can discover the unknown parts of a differential equation directly from data, even when the data is sparse and noisy, and return an interpretable symbolic expression.
The authors introduce a method called Universal Physics-Informed Neural Networks (UPINNs). It combines the strengths of two existing approaches:
1. PINNs, which integrate known physical laws into neural network training by including differential equations in the loss function.
2. UDEs, which use neural networks to model unknown terms in a differential equation.
UPINNs aim to do both: they use physical constraints to guide the training process (like PINNs), but they also allow for the discovery of unknown components of the equation (like UDEs). Importantly, once a neural network learns those unknown components, the method uses symbolic regression (via the AI Feynman tool) to extract a readable formula—something scientists can actually interpret.
Main Idea
The model uses three neural networks: one approximates the solution to the differential equation, one learns the unknown part of the equation (i.e., the missing dynamics), and the other one (optional) models unknown boundary conditions if needed.
Training is guided by a loss function with three parts:
1. Fit to observed data,
2. Match the known physical dynamics,
3. Satisfy boundary conditions.
To help compensate for the limited data, they add “collocation points”, additional locations in the domain where the model must follow the known physics. These points don’t require real data and can be sampled freely, so they’re a cheap way to strengthen training.
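A minimal PyTorch-style sketch of how collocation points contribute a physics residual loss for an ODE of the form du/dt = f_known(u) + NN(u); all network names, sizes, and the toy known-physics term are our own assumptions, not the authors' code:

```python
import torch

# Hypothetical networks: u_net approximates the solution u(t);
# f_net approximates the unknown part of the right-hand side.
u_net = torch.nn.Sequential(torch.nn.Linear(1, 64), torch.nn.Tanh(), torch.nn.Linear(64, 1))
f_net = torch.nn.Sequential(torch.nn.Linear(1, 64), torch.nn.Tanh(), torch.nn.Linear(64, 1))

def known_physics(u):
    # Stand-in for the known part of the dynamics (e.g., linear decay); not from the paper.
    return -0.5 * u

def collocation_loss(n_points=128, t_max=10.0):
    # Collocation points are sampled freely in the domain; no measured data is needed here.
    t = (t_max * torch.rand(n_points, 1)).requires_grad_(True)
    u = u_net(t)
    # du/dt via automatic differentiation.
    du_dt = torch.autograd.grad(u, t, grad_outputs=torch.ones_like(u), create_graph=True)[0]
    # Residual of du/dt = f_known(u) + NN(u); the physics loss is its mean square.
    residual = du_dt - (known_physics(u) + f_net(u))
    return (residual ** 2).mean()
```

This residual term would simply be added to the data and boundary losses during training.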
Experimental & Result
The paper tests UPINNs on three systems:
(a) Lotka-Volterra Predator-Prey Model (ODE) The model successfully recovers the hidden interaction terms, even with very sparse data.
It outperforms UDEs especially when noise is present or data is limited.
(b) Viscous Burgers’ Equation (PDE) Even with data from only two time points, UPINNs can reconstruct the solution and recover the nonlinear transport term (−u ∂u/∂x) with reasonable accuracy.
(c) Apoptosis (Cell Death) Model (ODE) The method learns complex nonlinear terms involving protein concentrations.
It performs well despite flat dynamics late in the simulation, which normally makes learning harder.
In all three cases, symbolic regression is applied to the learned neural network and is often able to recover the correct functional form of the hidden terms. When comparing against UDEs, UPINNs are more robust to noise and return more accurate symbolic expressions.
UPINNs are useful when you:
1. Only have limited, noisy measurements of a system.
2. Know part of the physical model but not all of it.
3. Want interpretable results, not just predictions.
In short, it’s a flexible way to discover unknown dynamics from data, while still respecting the physical structure you already know. This is particularly helpful in scientific domains where experimentation is expensive or data is inherently sparse.
Group 2 Presentation: EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty
Presented by:
Kareena Bhalla and Chelsea Huffman
Paper Citation
Li, Y., Wei, F., Zhang, C., Zhang, H. (2024). EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty. arXiv. doi.org/10.48550/arXiv.2401.15077
Background
In this paper, we are looking at Large Language Models (LLMs). Autoregressive decoding in LLMs refers to generating text one token at a time, with the model basing each prediction on the tokens that came before it. This is inefficient and costly.
Speculative Sampling
Speculative sampling is a technique meant to reduce the computational cost and runtime of autoregressive decoding. The process consists of two main parts:
- Draft stage: A small, fast model suggests some tokens.
- Verification stage: In parallel, the large main LLM verifies these tokens and selects the best ones.
EAGLE also has both a draft stage and a verification stage.
How do we choose a draft model that functions like the big LLM, but faster? One approach is to use a reduced version of the main LLM, but this doesn't work when there is no small version available; moreover, using a model with fewer parameters often comes with high overhead and reduced accuracy.
Technical Contributions
Extrapolation Algorithm for Greater Language-model Efficiency (EAGLE)
Before predicting the next feature, the draft model is given the token sampled one step ahead in the sequence, which resolves the ambiguity introduced by sampling.
One advantage of the EAGLE method is that only a single decoder layer needs to be trained rather than an entire draft model. This makes the training process extremely efficient in terms of runtime/computational cost.
The EAGLE method makes improvements to the process based on the following 2 observations:
1. Autoregression is simpler at the feature level than the token level.
2. The uncertainties from the sampling process negatively affect the performance.
Feature-Level Autoregression
Unlike traditional speculative sampling methods that operate at the token level, EAGLE performs autoregression at the feature level, specifically utilizing the second-to-top-layer features of the LLM. This approach simplifies the drafting process and enhances efficiency.
Experimental Results
- EAGLE is much faster than ordinary autoregressive decoding.
- EAGLE can be applied to various LLMs without reducing the model's generation quality. Furthermore, it is also compatible with other generation acceleration methods; for example, the paper goes through a quick case study of combining EAGLE with an optimization technique called gpt-fast.
- EAGLE can generate approximately 3.2–4.5 tokens per pass. Regular, vanilla decoding only generates one token per forward pass, so this is a notable speed-up in token generation.
- EAGLE provides the greatest computational speed-up in code generation tasks. This is likely because code contains fixed templates (e.g., repetitive structures, strict and specific rules, etc.), thereby making it easier to generate plausible and accurate drafts.
Summaries of key points
EAGLE (Extrapolation Algorithm for Greater Language-model Efficiency) is a framework designed to accelerate the text generation of large language models (LLMs) without sacrificing the original distribution of generated tokens. Traditionally, LLMs rely on autoregressive decoding, generating text one token at a time and incurring heavy computational costs. Speculative sampling, introduced in earlier works, seeks to reduce these costs by splitting text generation into a faster “drafting” step and a single-step “verification” step using the main LLM, thereby validating multiple drafted tokens in parallel. However, existing methods often struggle to balance high draft accuracy with low overhead, limiting their speed improvements.
EAGLE addresses these issues by predicting hidden-layer features rather than tokens. Specifically, it focuses on the second-to-top-layer feature vector of the original LLM. Because these features exhibit more consistent structure than raw language tokens, they can be modeled with fewer errors and simpler modules. Additionally, EAGLE includes a small Autoregression Head that factors in one-step-ahead token information—this is critical for mitigating the inherent uncertainty in feature prediction. When generating text, the main LLM determines which token is sampled at a given step, but from the feature-only perspective, it is impossible to know the token that will appear. By feeding the token as well as the past feature states into a lightweight single-layer decoder, EAGLE’s draft model can produce a highly accurate prediction of the next feature. After obtaining the predicted features, EAGLE applies the LLM’s original output layer (the LM head) to sample candidate tokens in parallel.
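As an illustration only, a single EAGLE-style draft step might look like the sketch below: fuse the current second-to-top-layer feature with the embedding of the token sampled one step ahead, run a lightweight single-layer head to predict the next feature, and reuse the LLM's frozen output head to obtain candidate tokens. All module names, sizes, and the choice of a generic Transformer layer are our assumptions, not the paper's implementation.

```python
import torch
import torch.nn as nn

d_model, vocab = 512, 32000   # hypothetical sizes for illustration

embed = nn.Embedding(vocab, d_model)        # stand-in for the LLM's token embedding
fuse = nn.Linear(2 * d_model, d_model)      # merges feature and shifted-token embedding
draft_layer = nn.TransformerEncoderLayer(d_model, nhead=8, batch_first=True)  # lightweight autoregression head
lm_head = nn.Linear(d_model, vocab)         # stand-in for the LLM's frozen output head

def draft_step(features, shifted_tokens):
    """features: (batch, seq, d_model) second-to-top-layer features from the main LLM.
    shifted_tokens: (batch, seq) tokens advanced by one position, so the draft head
    knows which token was actually sampled (this resolves the feature uncertainty).
    Causal masking is omitted here for brevity."""
    x = fuse(torch.cat([features, embed(shifted_tokens)], dim=-1))
    next_feature = draft_layer(x)[:, -1:]          # predicted next feature vector
    probs = lm_head(next_feature).softmax(dim=-1)  # candidate-token distribution via the LM head
    return next_feature, probs
```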
Once a tree of drafted tokens is produced, the method performs a single verification pass in the main LLM. In one forward call, EAGLE obtains the LLM’s probabilities for all tokens in this tree and accepts or rejects each token based on a mathematically proven acceptance criterion. Crucially, this ensures the distribution of tokens matches what the LLM would have generated if it performed naive, step-by-step autoregressive decoding.
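For context, the acceptance rule inherited from standard speculative sampling (which EAGLE's tree verification extends across branches) accepts a drafted token [math]\displaystyle{ \tilde{x} }[/math] with probability

[math]\displaystyle{ \min\!\left(1, \frac{p(\tilde{x})}{q(\tilde{x})}\right), }[/math]

where [math]\displaystyle{ p }[/math] is the main LLM's distribution and [math]\displaystyle{ q }[/math] is the draft distribution; on rejection, a replacement token is sampled from the normalized residual [math]\displaystyle{ \max(0, p - q) }[/math]. This rule is what preserves the original output distribution.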
In evaluations across code (HumanEval), mathematical reasoning (GSM8K), instruction following (Alpaca), and multi-turn dialogue (MT-bench) datasets, EAGLE yields 2.7–3.5× speedups for models like LLaMA2-Chat 70B and 2.8–3.3× for smaller Vicuna variants on single-GPU setups. Even for more modest setups, or MoE-based models like Mixtral 8×7B Instruct, EAGLE still demonstrates strong acceleration. The approach does not require finetuning the main LLM itself; only a small draft network is trained on a fixed dataset (e.g., ShareGPT). Because the overhead of this training is low—typically 1–2 days even for 70B-parameter models—EAGLE proves highly practical.
Overall, it offers a simple but powerful way to generate tokens in parallel while retaining exact quality and distribution guarantees, making it a compelling choice for real-world systems seeking reduced latency in LLM services.
Related Works
There is now a new version of EAGLE, EAGLE-2, which improves on this version by introducing a dynamically adjustable draft tree: the draft tree is adjusted based on context and position, building on speculative sampling. In testing, EAGLE-2 can be 20%–40% faster than EAGLE-1. There is also EAGLE-3 ("Scaling up Inference Acceleration of Large Language Models via Training-Time Test"), which can be found here: https://arxiv.org/html/2503.01840v1. EAGLE-3 advances the acceleration of LLM inference by shifting from feature prediction to direct token prediction and employing multi-layer feature fusion through a technique called training-time test. These enhancements enable the draft model to fully benefit from scaled-up training data, resulting in notable speedup ratios.
In the draft stage, other related works have utilized different methods. Speculative sampling and Lookahead use tokens to predict tokens. Medusa uses features to independently predict tokens.
In the verifying stage, other related works have modified acceptance probabilities (DistillSpec (Zhou et al., 2023)) or the acceptance method (BiLD (Kim et al., 2023) and Medusa (Cai et al., 2023)).
Other related methods that reduce the latency per forward pass include distillation, quantization, and pruning.
Group 2 Presentation: EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty
Presented by:
Kareena Bhalla and Chelsea Huffman
Paper Citation
Li, Y., Wei, F., Zhang, C., Zhang, H. (2024). EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty. arXiv. doi.org/10.48550/arXiv.2401.15077
Summaries of key points
Speculative sampling speeds up language model inference by having a fast draft model guess multiple tokens, which are then verified by a slower, accurate model. While effective, this breaks the model’s training assumption of strictly sequential input, introducing feature uncertainty: the hidden states are based on possibly incorrect guesses. The paper proposes EAGLE, an energy-based method that learns to assess and modulate this uncertainty at the feature level. Instead of discarding uncertain tokens, EAGLE uses a learned energy score to decide how much to trust them during decoding. This leads to better performance, especially on tasks requiring reasoning or longer context, and is more robust than previous speculative decoding approaches.
Problem Motivation: Speculative decoding (used to speed up LLM inference by parallelizing token generation) is great for efficiency but introduces a mismatch between training and inference: training assumes sequential decoding, but speculative sampling adds a “guess-and-check” step that breaks that assumption.
Key Idea: EAGLE (Extrapolation Algorithm for Greater Language-model Efficiency) proposes a new method to handle uncertainty that arises in speculative decoding. It adjusts the model’s internal feature representations to reflect this uncertainty, rather than just masking or ignoring it.
How It Works: Instead of relying on the regular transformer’s last hidden state, EAGLE builds an energy-based auxiliary model that learns to estimate whether a token guess is valid, using both the main model’s predictions and the speculative draft. This energy score helps modulate the influence of uncertain features during decoding.
Results: EAGLE shows better performance on downstream tasks compared to vanilla speculative decoding, especially on tasks that require reasoning or handling uncertainty — e.g., question answering or coding benchmarks.
Explanations to aid understanding
Speculative decoding in simpler terms: Imagine trying to write the next word in a sentence, but instead of just writing one word and waiting, you guess a bunch in parallel and then double-check them. This saves time, but makes it harder for the model to know what it should trust. EAGLE essentially adds a smart layer that acts like a “confidence gauge” for the guesses, using a learned energy function to decide how much to trust each speculative token and its underlying features.
Group 2 Presentation: EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty
Presented by:
Kareena Bhalla and Chelsea Huffman
Paper Citation
Li, Y., Wei, F., Zhang, C., Zhang, H. (2024). EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty. arXiv. doi.org/10.48550/arXiv.2401.15077
Background
Large Language Models (LLMs) like LLaMA and Vicuna are powerful but notoriously slow during inference, especially because they generate one token at a time using autoregressive decoding. This sequential nature makes real-time applications difficult and expensive. Speculative sampling has emerged as a solution: it uses a smaller model (a “draft model”) to propose multiple tokens ahead of time, which are then verified in parallel by the original, larger model. This can lead to big speedups—but only if the drafts are accurate. The problem is, for many models (especially smaller ones like 7B), finding a good draft model is hard or inefficient, and prediction at the token level is noisy.
The paper introduces EAGLE – a speculative sampling method that takes a different approach. Instead of generating tokens directly, it works at the feature level (i.e., the hidden state just before the final output layer). It also addresses a key challenge: uncertainty in the feature sequence, caused by the randomness in token sampling. To fix this, EAGLE feeds the token sampled at the next time step (i.e., a “shifted” token sequence) into the draft model, giving it a clearer signal of what to predict next.
This method is designed to:
Be fast — achieving 2x to 3.7x speedups over vanilla decoding.
Be accurate — preserving the original LLM's output distribution.
Be plug-and-play — requiring no fine-tuning of the LLM and using a lightweight add-on model.
Main Idea
EAGLE consists of two main parts:
a. Drafting Phase Instead of predicting the next token, EAGLE predicts the next feature vector (from the LLM’s penultimate layer).
Then, the actual token is generated using the original LLM’s output layer.
The key idea: use both the current features and the next token to reduce ambiguity in what feature to predict.
b. Verification Phase A standard speculative sampling step: the full model (LLM) runs a single forward pass to verify the draft.
If accepted, the tokens are kept. If rejected, the process restarts from the failed point.
EAGLE supports tree-structured drafts, where multiple possible sequences are explored in parallel, boosting acceptance rates and reducing the number of passes.
Experimental & Result
EAGLE is tested on Vicuna and LLaMA2-Chat models (7B, 13B, 33B, and 70B), plus Mixtral 8x7B, across tasks like: Dialogue (MT-bench), Code generation (HumanEval), Math problems (GSM8K), Instruction following (Alpaca)
Key numbers:
For LLaMA2-Chat 70B: speedup of 2.7x to 3.5x
For Vicuna 13B: up to 3.76x on code generation
Compared to Lookahead and Medusa, EAGLE is consistently faster by 1.5x–2.1x
EAGLE also works well with gpt-fast (a quantization and compilation tool), achieving up to 160 tokens/sec on a single RTX 3090 — a strong result for a 7B model.
Training is efficient: even for 70B models, the draft module (just under 1B parameters) can be trained in 1–2 days on 4×A100 GPUs.
It is a very useful approach because:
1. No need to fine-tune the full LLM – only the draft model is trained.
2. Preserves output distribution – unlike some other fast decoding methods, EAGLE guarantees the same output distribution as the original model.
3. Compatible with other speedup tools – works in combination with quantization or compilation.
Group 3 Presentation: Mamba: Linear-Time Sequence Modelling with Selective State Spaces
Presented by:
Liang Wu, Jingcheng Yu, Candace Ng
Paper Citation
Gu, A., & Dao, T. (2023). Mamba: Linear-Time Sequence Modeling with Selective State Spaces. arXiv. https://arxiv.org/abs/2312.00752.
Background
Technical Contributions
- Introduce a selection mechanism for state space models that enables input-dependent state updates
- Develop a hardware-aware recurrent scan algorithm for efficient computation
- Propose Mamba, an attention-free, linear-time architecture that achieves state-of-the-art performance across multiple modalities
Architecture
Mamba integrates selective SSMs into a single homogeneous block. Each block comprises: a linear projection, a selective SSM layer, and an MLP block. This architecture is attention-free and scales linearly in sequence length.
Related Work
- Builds on structured SSMs (S4, S5) and architectures such as H3 and Hyena.
- Demonstrates theoretical and empirical connections to classical RNN gating.
Future Directions:
- Scale to larger models and refine training recipes
- Extend Mamba to multimodal tasks (e.g. video)
- Explore additional downstream affordances (fine-tuning, in-context learning, quantization)
Summaries of Key Points
1. Motivation: Faster and More Flexible Sequence Modeling
Large-scale "foundation models" (FMs) largely rely on Transformers (with attention) to handle long-range context. However, attention has quadratic complexity in sequence length and requires storing the entire key–value cache. This leads to high memory usage and slow generation. Meanwhile, Structured State-Space Models (SSMs) offer linear-time processing through either convolution or recurrence, achieving excellent results in domains like audio and vision. Yet they typically assume constant, time-invariant parameters, which hinders their ability to handle discrete or "content-based" tasks such as language.
2. Proposed Approach: Selective State-Space Models (S6)
The paper introduces a selection mechanism to inject content-based (input-dependent) reasoning into SSMs. Specifically:
- Parameterizing SSM transitions by inputs: Conventional SSMs have fixed transition matrices (A, B, C), but here, B and C—and crucially, the step-size parameter Δ—become input-dependent. This small twist allows the model to "select" which inputs matter and how strongly to update its hidden state, like gating in RNNs.
- Efficient "Selective Scan" Implementation: Because parameters vary with the input, one cannot use global convolution for training (which was key to SSM efficiency). Instead, the authors design a hardware-aware parallel scan on GPUs, fusing operators so that large intermediate states reside only in on-chip memory. This approach retains O(L) time scaling in sequence length while dramatically cutting overhead.
3. Mamba Architecture
To showcase selective SSMs, they build Mamba, a simple "attention-free" architecture. Each layer is effectively an MLP with a parallel SSM path included. By training only these blocks end-to-end (without multi-head attention), Mamba delivers:
- Linear-time training because it uses convolution-like or scan-based methods.
- Constant-time inference per token (no caching of the entire past context), which is especially beneficial for extremely long sequences.
Additionally, Mamba integrates gating mechanisms within its SSM path, adaptively controlling information flow to better capture complex dependencies. Its streamlined layer design allows hardware-aware GPU optimizations, enhancing throughput and scalability compared to attention-based models.
4. Experiments Across Multiple Domains
1. Synthetic Tests (Selective Copying, Induction Heads): Mamba (or more precisely, its selective mechanism) solves tasks where ordinary time-invariant SSMs fail. It can ignore irrelevant tokens and preserve essential ones.
2. Language Modeling: On The Pile (multi-domain text), Mamba matches or exceeds standard Transformer perplexities when model size scales to ~1B parameters. It also outperforms other sub-quadratic methods (e.g., Hyena, RWKV). Notably, Mamba obtains 4–5× higher inference throughput than Transformers because it avoids large key–value caches.
3. Genomics: Mamba attains better perplexities than alternative long-sequence models (HyenaDNA) on a human genome corpus, especially at very long context (up to millions of tokens).
4. Audio Waveforms: The model outperforms prior state-of-the-art (e.g., SaShiMi) in generating raw speech signals, achieving lower FID scores and better sample quality.
5. Significance and Conclusions
Mamba demonstrates that:
1. Combining gating/selection with state-space recurrences, and
2. A carefully engineered GPU scan can outperform Transformers in certain tasks and match them in others, all with linear-time complexity. This opens a path for foundation models in domains needing extremely long context—like genomics, raw audio, and (possibly) next-generation language modeling—where Transformers become unwieldy. The authors highlight that future directions include scaling Mamba to billions of parameters (7B+), exploring interpretability, and combining it with other approaches (e.g., partial attention) for even broader performance gains.
Group 3 Presentation: Mamba: Linear-Time Sequence Modelling with Selective State Spaces
Presented by:
Liang Wu, Jingcheng Yu, Candace Ng
Paper Citation
Gu, A., & Dao, T. (2023). Mamba: Linear-Time Sequence Modeling with Selective State Spaces. arXiv. https://arxiv.org/abs/2312.00752.
Summaries of key points
Nowadays, transformers are extremely popular because they are great at finding relationships in data. However, they struggle with long sequences since their calculations grow too fast as the input gets longer. In comparison, State Space Models (SSMs) are more efficient for long sequences, but they have a limitation: they can't easily decide what information to keep or forget. Mamba solves this problem by improving a standard SSM model so that it can selectively adjust its parameters based on the input.
Standard SSM (S4): How It Works
SSMs process sequences differently from transformers. Instead of checking all relationships in the data, they store past information in a hidden state and update it step by step.
The basic rule for an SSM is:
[math]\displaystyle{ h_t = A h_{t-1} + B x_t }[/math], [math]\displaystyle{ y_t = C h_t }[/math]
Where:
- [math]\displaystyle{ h_t }[/math] is the hidden state (memory of past information).
- [math]\displaystyle{ x_t }[/math] is the current input.
- [math]\displaystyle{ A, B, C }[/math] are learned parameters that control how information flows.
- [math]\displaystyle{ y_t }[/math] is the output.
The problem: A, B, and C stay fixed, meaning the model always updates its memory in the same way, no matter what input it sees. This makes it less adaptable, especially for complex tasks like language modeling.
Selective SSM (S6): Mamba’s Improvement
Mamba introduces Selective State Space Models (S6), where the key values A, B, and C change dynamically based on the input. This allows the model to decide what to store and what to forget in real-time.
Instead of using fixed values, Mamba updates them dynamically:
[math]\displaystyle{ h_t = A_t h_{t-1} + B_t x_t }[/math], [math]\displaystyle{ y_t = C_t h_t }[/math]
Now, A, B, and C depend on the input [math]\displaystyle{ x_t }[/math], making the model much more flexible and allowing it to adjust its memory as needed.
Also, Mamba uses a hardware-aware scan algorithm, which makes its computations 20-45x faster than older methods and 4-5x faster than transformers. This means it can handle longer sequences without slowing down.
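To make the recurrence concrete, here is a naive reference loop for a selective SSM, written only for illustration; the projection shapes, softplus step size, and simplified discretization are our assumptions, and the real model replaces this loop with the fused, hardware-aware parallel scan.

```python
import numpy as np

def selective_ssm(x, W_delta, W_B, W_C, A):
    """x: (L, d) input sequence.  A: (d, N) fixed (diagonal, per-channel) state matrix.
    W_delta: (d, d), W_B: (d, N), W_C: (d, N) make the step size, B and C input-dependent."""
    L, d = x.shape
    N = A.shape[1]
    h = np.zeros((d, N))
    ys = []
    for t in range(L):
        delta = np.log1p(np.exp(x[t] @ W_delta))[:, None]  # softplus step size, one per channel
        B_t = x[t] @ W_B                                   # input-dependent input projection, shape (N,)
        C_t = x[t] @ W_C                                   # input-dependent output projection, shape (N,)
        A_bar = np.exp(delta * A)                          # zero-order-hold style discretization
        B_bar = delta * B_t[None, :]                       # simplified discretization of B
        h = A_bar * h + B_bar * x[t][:, None]              # state update now depends on the input
        ys.append(h @ C_t)                                 # per-channel output, shape (d,)
    return np.stack(ys)                                    # (L, d)
```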
Experiment Results
Mamba has been tested in different fields:
- Language modeling: Performs as well as transformers but is more efficient.
- DNA sequencing: Works well on long sequences, beating many existing models.
- Audio processing: Outperforms other SSM models for speech tasks.
Limitations of Mamba
Mamba is promising, but it has a few challenges:
1. Scaling up: It hasn’t been tested on extremely large models yet, so we don’t know if it can compete with transformers at that level.
2. Trade-offs in memory selection: Choosing what to remember and forget works well for things like text but might not be as useful for continuous data like speech.
3. Lack of a mature ecosystem: Transformers have been around longer, so they have more tools and techniques available. Mamba still has to catch up.
Even with these issues, Mamba is an exciting step forward for sequence modeling.
Group 3 Presentation: Mamba: Linear-Time Sequence Modelling with Selective State Spaces
Presented by:
Liang Wu, Jingcheng Yu, Candace Ng
Paper Citation
Gu, A., & Dao, T. (2023). Mamba: Linear-Time Sequence Modeling with Selective State Spaces. arXiv. https://arxiv.org/abs/2312.00752.
Background
Transformers are the backbone of most foundation models today, especially in language tasks. Their strength lies in the self-attention mechanism, which allows for flexible information routing across tokens. However, this comes with a computational cost: both time and memory scale quadratically with sequence length. This makes training and inference inefficient, especially for long sequences.
There’s been growing interest in finding efficient alternatives. Structured state space models (SSMs) offer a different route. Inspired by control theory, they compute with linear or near-linear complexity and have worked well in domains like audio. But they’ve consistently struggled with tasks involving discrete and information-dense inputs, like text. One major issue is that existing SSMs apply the same operations at each time step. This time-invariant design makes them fast but limits their ability to reason based on content.
Main Idea
Mamba introduces selective state space models, which allow the model to change how it processes each input based on the input itself. In earlier SSMs, parameters like B and C (which control how inputs are added to the state and how the state is turned into output) were fixed across time. In Mamba, these parameters vary with the input token.
This makes the model capable of:
1. Retaining relevant tokens while forgetting unimportant ones.
2. Filtering out noise or filler content.
3. Adapting its internal state in a context-aware manner.
Of course, this also breaks the efficient convolution trick that earlier SSMs used, since convolutions require fixed kernels. To deal with this, the authors implement a custom recurrent scan that is hardware-friendly and memory-efficient, especially on GPUs. This scan computes the model step-by-step, but in a way that avoids the usual memory bottlenecks of recurrent models.
Rather than using the traditional Transformer layout of attention followed by an MLP block, Mamba builds a simpler block:
a. It merges the sequence transformation (via the selective SSM) with the MLP into a single unit.
b. This block is stacked repeatedly, with residual connections and normalization in between.
c. The resulting model is easier to scale and implement than Transformer-based designs.
The paper also shows that the model works with real-valued parameters (as opposed to the complex-valued versions used in some previous SSMs), which improves compatibility with common deep learning hardware.
Experimental & Result
1. Synthetic Tasks Mamba is tested on synthetic problems like selective copying and induction heads, which are designed to measure a model’s ability to selectively remember and use earlier parts of a sequence. Mamba solves both tasks and generalizes well to sequences much longer than it saw during training—up to a million tokens.
2. Language Modeling The authors train Mamba on The Pile and evaluate it on standard zero-shot benchmarks (LAMBADA, HellaSwag, ARC, etc.). Key findings:
Mamba-1.4B outperforms Pythia models of similar and larger size.
It matches the performance of “Transformer++” variants while using fewer parameters.
It runs significantly faster at inference time because it does not rely on key-value caching.
3. Genomics (DNA modeling) On the HG38 genome dataset, Mamba shows better scaling than baselines as both model size and context length increase. Even on a challenging species classification task with closely related species (human, chimp, gorilla, etc.), Mamba performs well with long-context inputs.
4. Audio Modeling Mamba improves on S4-based baselines in waveform modeling. On the SC09 speech generation dataset, it beats previous models, including GANs and diffusion models on several automated metrics.
5. Efficiency The selective scan implementation is highly optimized:
Faster than FlashAttention-2 for long sequences.
4–5x faster inference throughput than Transformers of similar size.
Uses less memory, since it doesn’t need to cache key/value pairs during generation.
Group 3 Presentation: Mamba: Linear-Time Sequence Modelling with Selective State Spaces
Presented by:
Liang Wu, Jingcheng Yu, Candace Ng
Paper Citation
Gu, A., & Dao, T. (2023). Mamba: Linear-Time Sequence Modeling with Selective State Spaces. arXiv. https://arxiv.org/abs/2312.00752.
Motivation and Problem Statement
(1) Attention bottleneck Transformers have demonstrated strong modeling capacity thanks to self-attention, but attention layers are known to scale quadratically with sequence length. This becomes problematic for very long sequences due to high compute and memory requirements.
(2) Subquadratic approaches Alternative models (e.g. linear attention, recurrent cells, and structured state space models, or SSMs) achieve subquadratic time complexity. However, in practice, they often lag behind Transformers—especially for discrete, “information-dense” domains like language.
(3) Key challenge Balancing the efficiency of subquadratic approaches with the “context-compression” power typical of full attention. In particular, standard linear time-invariant (LTI) SSMs struggle to handle tasks that require input-dependent selection or “content-based” reasoning.
Contribution
(1) Selective Mechanism The paper introduces a selective variant of state space models—“Selective SSM” or S6—whose parameters can dynamically depend on the current input token. This makes the model selectively propagate or ignore information, overcoming the rigidity of time-invariant recurrences.
(2) Hardware-Aware Recurrent Scan To handle the new time-varying SSM parameters, the authors propose a “scan” algorithm specialized for modern GPUs that leverages efficient memory management (fusing operations and reducing data movement). Despite the recurrent nature, it matches or exceeds the speed of FFT-based or other convolution-based methods for large sequence lengths.
(3) Mamba Architecture Built on top of this selective SSM layer, “Mamba” is a purely recurrent neural network that omits attention altogether, yet achieves competitive or better performance than Transformers across various domains (language, audio, genomics) while scaling linearly in sequence length.
Algorithm 1: Standard SSM (S4)
Structured State Space Models (S4) were initially designed to combine RNN-like recurrences with global convolutions:
(1) Core idea SSMs are defined by continuous-time parameters. These are discretized (using, for example, a Zero-Order Hold; the standard formulas are given after this list) so that the SSM can operate on discrete sequences.
(2) LTI property All prior S4-type models are time-invariant—the parameters stay constant for all time steps. This allows S4 to be computed as either (i) a single global convolution or (ii) a linear recurrence.
(3) Limitation Because the transition dynamics do not depend on the input, S4 cannot do content-based selection of the tokens it should store or ignore.
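As a concrete reference (standard zero-order-hold formulas, consistent with the S4/Mamba line of work), the continuous parameters [math]\displaystyle{ (\Delta, A, B) }[/math] are discretized as

[math]\displaystyle{ \bar{A} = \exp(\Delta A), \qquad \bar{B} = (\Delta A)^{-1}\left(\exp(\Delta A) - I\right)\Delta B, }[/math]

after which the recurrence [math]\displaystyle{ h_t = \bar{A} h_{t-1} + \bar{B} x_t,\; y_t = C h_t }[/math] operates directly on the discrete sequence.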
Computation of SSMs
(1) Convolutional mode Time-invariant S4 can exploit global convolutions (via the expanded kernel, sketched after this list) to compute outputs in [math]\displaystyle{ O(L \log L) }[/math] or near-linear time. This approach avoids explicitly storing the large hidden state.
(2) Recurrent mode Alternatively, one can compute the same sequence mapping in a step-by-step fashion with [math]\displaystyle{ O(L) }[/math] steps but multiplied by the state dimension N. Typical parallel RNN implementations are memory-heavy because the hidden state dimension can be large.
(3) Trade-off Convolution mode is highly parallel (suitable for training) but struggles with time-varying parameters. Hence, prior S4-based models remain LTI to preserve convolution-based efficiency.
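For reference, the global convolution kernel used in the time-invariant case (notation following the standard S4 formulation) is

[math]\displaystyle{ \bar{K} = \left(C\bar{B},\; C\bar{A}\bar{B},\; C\bar{A}^2\bar{B},\; \dots,\; C\bar{A}^{L-1}\bar{B}\right), \qquad y = x * \bar{K}, }[/math]

which is precisely why fixed (time-invariant) parameters are required: the same kernel must apply at every position.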
Limitations of Linear Time-Invariant SSMs
(1) Static transitions Standard SSMs (S4) cannot adapt or filter out unimportant inputs on the fly.
(2) Discrete data handling In discrete, information-dense tasks (like language), one often must selectively attend to critical tokens. Purely LTI models do not have a built-in mechanism for such content-based selection.
Algorithm 2: Selective SSMs (S6)
(1) Key idea Make parts of the SSM parameters become functions of the current input token, hence “selective.” At each time step t, the model can decide whether to store or forget the information from [math]\displaystyle{ x_t }[/math].
(2) Effect on recurrence The system is no longer linear time-invariant. However, it gains the ability to gate hidden states based on content—similar to an RNN gating mechanism but integrated within the SSM formulation.
Efficient Implementations of Selective SSMs
(1) Challenge With time-varying parameters, the global convolution trick no longer applies. A naive RNN-like approach would be slow (or memory-heavy) when N is large.
(2) Hardware-aware parallel scan The authors design a “selective scan” that operates recurrently but fuses memory reads and writes on GPUs, storing the full hidden state only in fast on-chip memory (SRAM). This avoids the usual overhead of a standard step-by-step approach.
(3) Performance Benchmarks indicate the proposed selective scan can be faster than attention beyond certain sequence lengths and avoids the large memory overhead of an attention KV-cache.
Mamba Architecture
(1) Simplified design Mamba blocks combine a selective SSM layer (the new S6 variant) with a gated MLP pathway in the same layer. This merges what used to be a multi-layer approach (SSM + MLP) into a single homogeneous block.
(2) Purely recurrent Mamba is completely attention-free. Each layer processes the input in linear time with the selective scan.
(3) Competitive performance Despite omitting self-attention, Mamba consistently achieves Transformer-quality or better results on diverse tasks, with far lower memory/time overhead at long context lengths.
Interpretations of Selection Mechanisms
(1) Variable spacing The selective update effectively allows the model to “jump over” irrelevant tokens, addressing tasks like “selective copying” where the positions of relevant tokens vary.
(2) Filtering context S6 can decide which tokens to integrate or forget. If the input at time t is unimportant, the update gate can suppress it, preventing noise accumulation in the hidden state.
(3) Boundary resetting When sequences are concatenated (e.g., different segments back-to-back), selective SSMs can “reset” the hidden state if the new segment is unrelated, mimicking the attention mask for different documents.
Overview of Experiments
(1) Synthetic tasks Selective Copying / Induction Heads: Demonstrates that selective SSMs learn to focus on relevant tokens. LTI SSMs fail these tasks, but S6-based Mamba solves them and even extrapolates correctly to much longer sequences.
(2) Language modeling
Scaling laws: Mamba shows strong scaling, matching or surpassing Transformers on the Pile dataset when model sizes go up to 1B+ parameters.
Zero-shot downstream tasks: Mamba outperforms or matches similarly sized Transformer baselines on tasks like LAMBADA, HellaSwag, and ARC.
(3) DNA sequences
Extremely long contexts: Mamba uses million-length context and still improves perplexity, while LTI SSMs degrade at such scales.
Classification tasks: Fine-tuning Mamba at lengths up to 1M tokens surpasses prior approaches on synthetic species classification.
(4) Audio generation
Long-range modeling: Mamba outperforms convolution-based S4 layers for autoregressive waveforms, especially beyond tens of thousands of time steps.
Speech quality: On a speech benchmark, Mamba cuts the previous state-of-the-art FID roughly in half, achieving more realistic outputs.
Speed and Memory Benchmarks
(1) Selective scan Achieves high training speed and memory efficiency on modern GPUs. Outperforms naive recurrent approaches by a large margin.
(2) Inference As a recurrent model with no need to store a growing KV-cache, Mamba obtains up to 5× higher inference throughput than Transformers of comparable size, especially at batch sizes beyond 1.
Related Work and Future Directions
(1) Transformer adaptations Many recent efforts approximate or modify attention (linear attention, kernel methods, etc.) to achieve subquadratic complexity—yet none consistently matched Transformers across modalities.
(2) Structured State Spaces Previous S4 variants excelled at continuous signals; discrete tasks were less successful due to the inability to filter input tokens selectively.
Future
(1) Scaled training: Investigating even larger Mamba models, or specialized domain tasks (vision, speech).
(2) Low-level optimization: The fused scan approach might be combined with novel GPU/TPU kernels.
(3) Formal interpretability: Mechanistically verifying how the model “chooses” tokens would improve transparency.
Limitations
(1) Discrete–Continuous tradeoff While the selective approach helps with discrete data, the authors note that certain initializations (real or complex) may still matter for stable training on continuous signals.
(2) Complex parameterization Tuning the selection parameters for each domain can be non-trivial, particularly if one wants to combine multiple forms of gating or advanced expansions.
(3) Residual data dependence Unlike attention, which explicitly attends to tokens by index, selective SSMs rely on gating from learned projections. Certain tasks might still favor explicit attention or local convolutions.
Group 3 Presentation: Mamba: Linear-Time Sequence Modelling with Selective State Spaces
Presented by:
Liang Wu, Jingcheng Yu, Candace Ng
Paper Citation
Gu, A., & Dao, T. (2023). Mamba: Linear-Time Sequence Modeling with Selective State Spaces. arXiv. https://arxiv.org/abs/2312.00752.
Summaries of key points
Goal: achieve content-aware selection through input-dependent parameters while preserving efficiency and flexibility.
Background: although the traditional SSM is linear and efficient, it is weak at dynamic content selection.
Methodology: by making the SSM parameters vary with the input, information can be selectively remembered. A parallel scan algorithm preserves the linear time complexity. The architecture is completely attention-free, and each module (selective SSM + MLP) is a unified, stackable block.
Result: performance on the Pile dataset is comparable to Transformers, and performance on long text is more stable. Longer DNA sequence contexts yield better classification accuracy. It surpasses baselines such as S4 on YouTube Mix and SC09.
Constructive critiques or reviews
The structure is clear, the transitions are natural, and the explanations are thorough.
More figures could be added for a more intuitive presentation.
Providing more detailed examples would help the audience understand better.
Clear explanations to aid understanding
Parallel Scan: avoids the slow, memory-heavy inference of traditional RNNs; its efficiency approaches that of the attention mechanism while saving computing resources.
Selective SSM: updates its state when it sees key tokens and skips irrelevant information.
S6 can degenerate into the classic RNN gating mechanism (it is a generalization of RNN gating).
Group 4 Presentation: Learning spatiotemporal dynamics with a pretrained generative model
Presented by:
- Karolina Suszek
- Negin Amou
- Muhammad Azeem
Paper Citation
Z. Li et al., “Learning spatiotemporal dynamics with a pretrained generative model,” Nature Machine Intelligence, vol. 6, no. 12. Springer Science and Business Media LLC, pp. 1566–1579, Dec. 06, 2024. doi: 10.1038/s42256-024-00938-z.
Background
- Spatiotemporal dynamics: how the state of a physical system varies with space and time
- Real datasets often contain data with sparse measurements where there are a limited number of sensors available. There needs to be a way to convert the sparse measurement data into a full spatiotemporal field.
- Existing solutions learn to map inputs to outputs and ignore missing data, but this reduces the model's ability to generalize.
- The paper proposes the Sparse-Sensor-Assisted Score-Based Generative Model (S3GM), which uses unlabeled data during training and can reconstruct incomplete data after training to make accurate predictions even when there isn't much information available.
- Key Idea: Learn the probability distribution of spatiotemporal data using a score-based generative model and refine the samples via stochastic sampling.
Technical Contributions
The main proposed model is the Sparse-Sensor-Assisted Score-Based Generative Model. It learns patterns from a large amount of data beforehand. It is also unsupervised, so it does not require any labels during training. It tries to learn the significant features and natural patterns of the data. After training, the model can take incomplete data and reconstruct the missing parts to make predictions.
Core Components:
- Pre-training stage: learns the joint probability distribution of the data
- Generating stage: uses a stochastic differential equation to refine and generate full-field predictions
- Refinement mechanism: ensures alignment with observations and enforces sequence consistency
Some of the common applications of this model are Turbulent flow modeling, climate forecasting, and physics-based simulations.
Summaries of key points
- Challenge Addressed: Traditional end-to-end learning models often struggle with generalization in reconstructing spatiotemporal dynamics, particularly when data is sparse—a common scenario in real-world applications.
- S³GM Methodology: Pretraining Phase: An unconditioned generative model is pretrained in a self-supervised manner on a comprehensive dataset, capturing the joint distribution of the system's dynamics. Generation Phase: The pretrained model is conditioned on new, sparse measurements to reconstruct and predict the full-field spatiotemporal dynamics.
- Validation and Performance: S³GM's efficacy was tested across multiple dynamical systems using synthetic, real-world, and laboratory datasets, including applications in turbulent flow modelling and weather forecasting. The results demonstrated that S³GM achieves high accuracy, generalizability, and robustness, even when faced with significant data sparsity and noise.
S³GM offers a promising approach for modeling and predicting complex spatiotemporal dynamics in situations where data is limited, leveraging the strengths of pretrained generative models to enhance performance in small data regimes.
Related Works
Some of the related works in this area are GPT-ST: Generative Pre-Training of Spatio-Temporal Graph Neural Networks. This framework employs a spatio-temporal masked autoencoder designed to capture both intra- and inter-cluster region semantic relationships, which are often overlooked in existing approaches. Another one is Spatio-Temporal Few-Shot Learning via Diffusive Neural Network Generation, where a generative pre-training framework (GPD) that addresses data scarcity in spatiotemporal modeling. By performing generative pre-training on neural network parameters optimized with data from source cities, the framework enables the generation of tailored neural networks guided by prompts.
Many other methods that map the sparse measurements (input) to the full reconstructed spatiotemporal field include the following:
- Using Fourier or Laplace transforms to learn mappings between function spaces. The Fourier transform moves the sparse input data into the frequency domain, where reconstruction techniques can be applied more easily.
- Using CNNs to learn latent representations of full spatiotemporal fields and reconstruct missing regions through an encoder and decoder.
- Using PINNs to incorporate physical laws (differential equations) into the loss function. This is useful when data is sparse or noisy, as they enforce physical consistency in the absence of complete ground-truth data.
Group 4 Presentation: Learning spatiotemporal dynamics with a pretrained generative model
Presented by:
- Karolina Suszek
- Negin Amou
- Muhammad Azeem
Paper Citation
Z. Li et al., “Learning spatiotemporal dynamics with a pretrained generative model,” Nature Machine Intelligence, vol. 6, no. 12. Springer Science and Business Media LLC, pp. 1566–1579, Dec. 06, 2024. doi: 10.1038/s42256-024-00938-z.
Summaries of key points
Spatiotemporal Dynamics and S3GM
Spatiotemporal dynamics describe physical systems (e.g., climate forecasting, fluid dynamics) that evolve over space and time. However, in real-world situations, sensor data is usually incomplete: for example, weather stations may cover only a few locations, or sensors might capture only the magnitude of velocity, missing its direction. Standard deep learning models (FNO, PINN, U-Net, LNO, DeepONet) often struggle to adapt if the sensor setup changes or if the system behaves unexpectedly. This lack of generalization means models often need to be retrained for each new situation. To overcome this, S3GM is proposed.
How Does S3GM Work?
Instead of learning the full probability distribution [math]\displaystyle{ p(x) }[/math] (which is complex), S3GM learns the gradient of the data distribution, called the score function: [math]\displaystyle{ s(x) = \nabla_x \log p(x) }[/math]. This tells the model which direction the data is most likely to change, making learning more efficient.
It should also be noted that real-world data is often messy—noisy, incomplete, or missing. S3GM handles this using denoising score matching (DSM):
- Adds a small amount of noise to the data and trains the model to remove it.
- Forces the model to focus on true underlying patterns rather than memorizing raw data.
By repeatedly removing noise, the model deeply understands the true data structure—even when parts are missing.
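Concretely, the standard denoising score matching objective (with [math]\displaystyle{ \sigma }[/math] the noise level and [math]\displaystyle{ \epsilon \sim \mathcal{N}(0, I) }[/math]) trains the score network [math]\displaystyle{ s_\theta }[/math] to undo the added noise:

[math]\displaystyle{ \mathcal{L}_{\text{DSM}} = \mathbb{E}_{x,\,\epsilon}\left\| s_\theta(x + \sigma\epsilon) + \frac{\epsilon}{\sigma} \right\|^2 }[/math]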
Once the model learns the data’s structure, it reconstructs missing information using stochastic differential equations (SDEs); the sampling dynamics have three terms (a generic form is given after the list):
- Drift Term: guides reconstruction toward likely states.
- Diffusion Term: adds controlled randomness to explore multiple solutions.
- Correction Term: uses real sensor data to ensure consistency.
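For reference, sampling in score-based generative modeling follows a reverse-time SDE of the standard form (which the drift and diffusion terms above correspond to):

[math]\displaystyle{ \mathrm{d}x = \left[f(x, t) - g(t)^2\, \nabla_x \log p_t(x)\right]\mathrm{d}t + g(t)\,\mathrm{d}\bar{w}, }[/math]

with the learned score [math]\displaystyle{ s_\theta(x, t) \approx \nabla_x \log p_t(x) }[/math]; S3GM additionally corrects the trajectory so that generated fields stay consistent with the sparse sensor observations.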
How Well Does S3GM Work?
S3GM is tested on four different systems to see how well it reconstructs and predicts missing data:
Experiment 1: Predicting Chaotic Behavior (Kuramoto-Sivashinsky Equation)
- Challenges Tested:
- Sparse Spatial Data (few sensor readings)
- Fourier Transform Domain (frequency-based measurements)
- Limited Initial Data (predict future states with few frames)
- Results:
- S3GM outperformed U-Net, FNO, and DeepONet, achieving lower errors.
- Stable even with limited input data.
Experiment 2: Reconstructing Turbulent Flow (Kolmogorov Flow)
- Challenges Tested:
- Trained on low-turbulence data.
- Tested on high-turbulence data.
- Results:
- Accurately reconstructed velocity fields and vorticity patterns.
Experiment 3: Climate Data Reconstruction (ERA5 Dataset)
- Challenges Tested:
- Extreme Data Sparsity (only 1% wind speed measurements available).
- Hidden Variables (missing temperature and pressure).
- Noisy Measurements (Gaussian noise added).
- Results:
- Successfully reconstructed missing climate variables.
- Performance improved with more sensor data.
Experiment 4: Flow Around a Cylinder
- Challenges Tested:
- Spatiotemporal Gaps (only specific cross-sections measured).
- Time-Averaged Data (some measurements were only available as averages).
- Results:
- Accurately reconstructed instantaneous and time-averaged flow fields.
- Outperformed physics-informed neural networks (PINNs).
Limitations of S3GM
While powerful, S3GM has limitations:
- Computational Cost: Pre-training is resource-intensive.
- Data Quality Dependence: Best performance with diverse, high-quality data.
- Generalization Issues: May struggle with entirely new dynamics.
- Processing Speed: Iterative reconstruction can be slower than traditional methods.
Despite these challenges, S3GM is a promising tool, and if these are improved, it could be even more powerful.
Group 3 Presentation: Mamba: Linear-Time Sequence Modelling with Selective State Spaces
Presented by:
Liang Wu, Jingcheng Yu, Candace Ng
Paper Citation
Gu, A., & Dao, T. (2023). Mamba: Linear-Time Sequence Modeling with Selective State Spaces. arXiv. https://arxiv.org/abs/2312.00752.
Background
This paper proposes Mamba, a new type of sequence model designed to match the modeling quality of Transformers while improving computational efficiency. The key innovation is a selective state space model (SSM) that can reason based on content and scale linearly with sequence length. While providing 4–5x quicker inference than Transformers of comparable size, Mamba shows strong performance across several domains—language, music, and genomics—positioning itself as a general-purpose backbone for foundation models.
Most large models today rely on Transformers, which are powerful but inefficient, especially for long sequences. Both training and inference are bottlenecked by the quadratic scaling of the self-attention mechanism with sequence length. State space models (SSMs) have drawn growing interest as efficient substitutes: they are recurrent and scale linearly, though earlier versions have fallen short on tasks such as language modelling. A major drawback the authors point out is that conventional SSMs are time-invariant, applying the same dynamics at every time step regardless of the input. This limits their capacity on tasks that require content-based reasoning.
Main Idea
The central idea of this paper is to improve state space models by making them selective. Traditional structured state space models (SSMs) apply the same linear operations at every time step, which works well for smooth or continuous data like audio, but not for discrete tasks like language modeling. The authors argue that this is because these models cannot adapt their behavior based on the content of the input.
Mamba addresses this by allowing some of the internal dynamics of the SSM to depend on the current input token. Specifically, the model modifies the SSM parameters (like Δ, B, and C) so that they are no longer fixed, but vary depending on what the model sees at each step. This makes it possible for the model to filter, retain, or discard information in a context-aware way.
This design sacrifices the ability to use fast convolutional implementations, but the authors introduce an alternative they call a selective scan—a custom, hardware-friendly way of computing the state updates efficiently on GPU. This allows the model to maintain linear computational complexity while being much more flexible than previous SSMs.
Mamba’s architecture is also deliberately kept simple. It does not rely on attention, nor does it use the usual Transformer-style MLP blocks. Instead, it stacks blocks based on this new selective SSM design, each combining sequence modeling and nonlinear transformation in one place.
Experimental & Result
The authors test Mamba on a wide range of tasks to show both its performance and its scalability.
On synthetic tasks like selective copying and induction heads, Mamba succeeds in learning long-range dependencies that other models fail to capture. It generalizes well even when the test sequences are far longer than the ones it was trained on, reaching up to a million tokens.
In language modeling, they train Mamba on The Pile and compare it to Transformer baselines like Pythia and RWKV. Despite being smaller in size, Mamba-1.4B performs better than Pythia-2.8B on several zero-shot benchmarks. It also matches the performance of more carefully tuned Transformer setups. One major advantage is that Mamba runs faster at inference time—achieving 4 to 5 times the throughput of Transformer models—because it avoids key-value caching.
For genomics, Mamba is trained on the human genome (HG38). Its perplexity improves as the sequence length increases, which is unusual—most models perform worse on longer contexts. On a classification task involving DNA from closely related species (humans, chimps, gorillas, etc.), Mamba significantly outperforms other models, especially at longer input lengths.
In audio modeling, Mamba is plugged into the SaShiMi framework and outperforms it on waveform prediction and speech generation. On the SC09 dataset, it scores better than WaveNet and DiffWave, despite having fewer parameters.
Finally, in terms of efficiency, the new scan implementation is fast. It’s faster than both a naive PyTorch loop and FlashAttention-2 for long sequences. Mamba’s speed and memory use scale linearly with sequence length, making it practical for real-world applications with long inputs or limited compute.
Group 5 Presentation: Griffin: Mixing Gated Linear Recurrences with Local Attention for Efficient Language Models
Presented by:
Guoqian Li, Xinlei Xu, Wenning Xu
Paper Citation
De, S., Smith, S., Fernando, A., Botev, A., Cristian-Muraru, G., Gu, A., Haroun, R., Berrada, L., Chen, Y., Srinivasan, S., Desjardins, G., Doucet, A., Budden, D., Teh, Y. W., Pascanu, R., De Freitas, N., Gulcehre, C. (2024). Griffin: Mixing Gated Linear Recurrences with Local Attention for Efficient Language Models. arXiv. arxiv.org/pdf/2402.19427
Background
RNNs and transformers
Recurrent neural networks (RNNs) are a basic architecture for handling sequential data; they are good for short sequences but can struggle with long ones.
Transformers generally perform better than RNNs and have become dominant in recent years. However, for long sequences they become computationally expensive for a couple of reasons: the global attention mechanism has quadratic complexity, and the key-value cache (even with multi-query attention) grows linearly with sequence length.
The paper proposes several innovations to improve upon existing attention mechanisms in neural networks:
- RG-LRU layer: a novel gated linear recurrent layer designed to replace Multi-Query Attention (MQA)
- Hawk: integrate multi-layer perceptrons (MLPs) with recurrent blocks
- Griffin: combine MLPs with a mix of recurrent blocks and local attention mechanisms
Technical Contributions
Model architecture
Their proposed models -- Hawk and Griffin -- use these following structures:
- Residual block: This is the main structure in the architecture. It starts with an RMSNorm layer being applied to the hidden state, followed by a temporal mixing block. There's a residual connection. Next, there is another RMSNorm layer followed by an MLP (multi-layer perceptron) block.
- Gated MLP block: This is a gated feedforward block. There are 2 branches: one is linear and one uses a GeLU activation. This part of the structure is used for feature selection.
- Temporal mixing block: They used 3 different kinds of temporal mixing blocks in their models: global multi-query attention, local sliding window attention, and recurrent blocks (what they proposed). The recurrent block contains 2 parallel branches. One has a GeLU activation, while the other has a temporal 1D convolutional layer followed by a RG-LRU layer (a proposal in this paper).
Performance and Efficiency
- Hawk surpasses the reported performance of Mamba on downstream tasks.
- Griffin matches the performance of Llama-2, despite being trained on over six times fewer tokens.
- Both models demonstrate the ability to extrapolate to sequences significantly longer than those encountered during training.
Real-Gated Linear Recurrent Unit (RG-LRU)
RG-LRUs are inspired by regular Linear Recurrent Units (LRUs), but add a gating mechanism. They are more computationally efficient, as they avoid the quadratic complexity of transformers. Mathematically, the layer is defined as follows.
The recurrence gate:
[math]\displaystyle{ r_t = \sigma (W_a x_t + b_a) }[/math]
The input gate:
[math]\displaystyle{ i_t = \sigma (W_x x_t + b_x) }[/math]
[math]\displaystyle{ a_t = a^{cr_t} }[/math]
The output:
[math]\displaystyle{ h_t = a_t \odot h_{t-1} + \sqrt{1-a_t^2} \odot (i_t \odot x_t) }[/math]
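To make the gating concrete, here is a minimal NumPy sketch of this recurrence; the parameter names (Wa, ba, Wx, bx, Lambda) and the fixed constant c are assumptions chosen to mirror the equations above, not the paper's implementation.

```python
import numpy as np

def rg_lru(x, Wa, ba, Wx, bx, Lambda, c=8.0):
    # x: (T, d) inputs; Wa, Wx: (d, d); ba, bx: (d,)
    # Lambda: (d,) diagonal recurrent weights a in (0, 1); c: a fixed scalar.
    sigmoid = lambda z: 1.0 / (1.0 + np.exp(-z))
    T, d = x.shape
    h = np.zeros(d)
    outputs = []
    for t in range(T):
        r_t = sigmoid(x[t] @ Wa.T + ba)      # recurrence gate
        i_t = sigmoid(x[t] @ Wx.T + bx)      # input gate
        a_t = Lambda ** (c * r_t)            # a_t = a^{c r_t}, elementwise
        h = a_t * h + np.sqrt(1.0 - a_t**2) * (i_t * x[t])
        outputs.append(h.copy())
    return np.stack(outputs)
```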
Summaries of Key Points
Context and Motivation
Transformers have dominated large language modeling, but their attention mechanism can be expensive for very long sequences, particularly because of the quadratic complexity of attention and a Key-Value (KV) cache that grows linearly in sequence length. Recurrent neural networks (RNNs) compress entire contexts into a fixed-size hidden state and avoid storing a cache that grows with length, which can reduce memory and speed up inference. Historically, though, RNNs have been difficult to train at scale and often underperform Transformers on complex tasks. This paper proposes new recurrent-based architectures that can match or exceed Transformer performance while retaining the training efficiency of standard deep-learning pipelines and gaining a significant advantage during inference.
Proposed Architectures
Two models are introduced: Hawk, a purely recurrent architecture, and Griffin, which mixes gated linear recurrences with local attention. Both rely on a new Real-Gated Linear Recurrent Unit (RG-LRU) layer, which extends a previously studied linear recurrency (the LRU) by adding two gating mechanisms. One gate modulates how much of the new input to incorporate each step, and another governs whether the layer updates in a standard LRU style or simply retains the previous hidden state.
RG-LRU and Recurrent Blocks
RG-LRU’s diagonal recurrence matrix dramatically lowers the computational cost for large sequence lengths. The authors also include a lightweight convolution for near-token interactions. For Griffin, they interleave RG-LRU-based blocks with a local sliding-window attention, allowing short-range local attention plus a global recurrent state. This combination helps the model handle both local details and long-distance dependencies in a memory-efficient way.
Empirical Results
The authors scale Hawk and Griffin to billions of parameters and compare them with:
1. A baseline Transformer using Multi-Query Attention (MQA),
2. The Mamba recurrent model (from prior work),
3. Llama-2.
Their main findings include:
- Both Hawk and Griffin follow standard power-law scaling in validation loss versus training compute, matching or beating Transformers.
- Hawk matches or exceeds previous recurrent models like Mamba on language modeling benchmarks, despite training on fewer tokens.
- Griffin matches Llama-2’s accuracy on multiple downstream NLP tasks while training on only about one-seventh the token count.
- During training, these models achieve speeds comparable to Transformers (thanks to parallelized or fused kernels), and in some configurations are faster at long sequence lengths (2048 tokens or more).
- At inference time, recurrent and local-attention blocks avoid the large KV caches that hamper Transformers. Hence, Hawk and Griffin show lower latency and can decode tokens at significantly higher throughput, especially for long sequences.
- Hawk and Griffin extrapolate to far longer sequences than they were trained on. They also efficiently learn synthetic copying/retrieval tasks if explicitly trained on them. However, their out-of-the-box retrieval or copying for tasks not seen at training time is still below that of Transformers.
Significance
By mixing local attention with a novel gated recurrence, Griffin retains the best of both recurrent and attention-based approaches: it can model local context with a sliding window while maintaining a small, fixed-size global state for the entire sequence. This achieves strong performance on large-scale language modeling benchmarks while offering major advantages in memory footprint and decoding speed. The results position Hawk and Griffin as compelling alternatives to purely Transformer-based architectures for scalable language models.
Related work and Future direction
Griffin builds on several key areas in efficient language modelling, particularly in recurrent architectures and hybrid approaches combining recurrence and attention: State Space Models (SSMs), Hybrid Recurrence-Attention Models, Efficient Transformer Alternatives.
Scaling Griffin to Larger Models
The paper discusses Griffin models up to 14B parameters but suggests further exploration into scaling beyond this size to compete with GPT-4 or Gemini models. Investigating how Griffin handles long-context tasks in ultra-large models would be an interesting future study.
Memory and Efficiency Optimization
Griffin already improves efficiency compared to transformers, but further research into sparsification, quantization, and hardware-specific optimizations could enhance its real-world applicability. Exploring mixed-precision training and inference optimizations for mobile and edge devices.
Extending Griffin to Multimodal Learning
While Griffin is designed for language modeling, incorporating it into multimodal tasks (e.g., vision-language models, audio processing) could expand its impact. Combining Griffin’s recurrence mechanism with diffusion models or video understanding models might be a promising direction.
Group 5 Presentation: Griffin: Mixing Gated Linear Recurrences with Local Attention for Efficient Language Models
Presented by:
Guoqian Li, Xinlei Xu, Wenning Xu
Paper Citation
De, S., Smith, S., Fernando, A., Botev, A., Cristian-Muraru, G., Gu, A., Haroun, R., Berrada, L., Chen, Y., Srinivasan, S., Desjardins, G., Doucet, A., Budden, D., Teh, Y. W., Pascanu, R., De Freitas, N., Gulcehre, C. (2024). Griffin: Mixing Gated Linear Recurrences with Local Attention for Efficient Language Models. arXiv. arxiv.org/pdf/2402.19427
Summaries of Key Points
Motivation and background
Dominance of Transformers and their drawbacks: Transformers (with global attention) have driven most breakthroughs in large-scale language modeling. Despite their strengths, transformers struggle with quadratic complexity on long sequences; the large Key-Value (KV) cache grows linearly in the sequence length during inference, slowing down decoding for long inputs.
Potential of Recurrent Models: Recurrent neural networks (RNNs) offer an appealing alternative because they summarize the sequence with a fixed-size hidden state that is updated token by token (iteratively), thus avoiding a large KV cache. However, classical RNNs can be slow to train in practice and have struggled to match the performance of Transformers at large scales.
Goal: To develop recurrent architectures (and hybrids that combine local attention and recurrences) which (1) match or exceed Transformers’ performance on language modeling, (2) train efficiently at scale (both in FLOPs and hardware utilization), and (3) offer improved inference efficiency (lower latency, higher throughput).
Architecture
Hawk (MLPs with recurrent blocks)
(1) Recurrent Block: Replaces self-attention entirely with a diagonal recurrent design. It processes tokens sequentially with fast iteration at inference time. (2) RG-LRU Layer: The crucial piece inside the Hawk model is the Real-Gated Linear Recurrent Unit (RG-LRU), a novel RNN layer that is stable, uses simple diagonal recurrences, and includes input and recurrence gates: the recurrence gate [math]\displaystyle{ r_t }[/math] and input gate [math]\displaystyle{ i_t }[/math] each depend on the input [math]\displaystyle{ x_t }[/math] but not the recurrent state [math]\displaystyle{ h_{t-1} }[/math]. This yields stable, memory-friendly computations. The update equation mixes the previous hidden state [math]\displaystyle{ h_{t-1} }[/math] and a transformed input [math]\displaystyle{ x_t }[/math] using a diagonal recurrent weight [math]\displaystyle{ a }[/math]. This ensures the update is purely elementwise, minimizing overhead during inference. (3) Achieves strong performance on language modeling while avoiding the large KV cache burden of Transformers. Provides super-exponential memory (depending on the gating mechanism) and can, in practice, capture long-range dependencies.
Griffin (MLPs with a mixture of recurrent blocks and local attention)
(1) Motivation: Although recurrence alone scales well to long sequences, local patterns might still be easier to capture via (local) attention. Mixing the two yields a model that can leverage local attention for shorter-range patterns while relying on a fixed-size recurrent state for longer context. (2) Local Attention: Replaces global attention with a sliding-window approach (no quadratic cost in sequence length). The local window ensures that each token attends only to a fixed number of past tokens. (3) Interleaving Blocks. Griffin’s architecture interleaves the same RG-LRU-based recurrent block and a local attention block. This combination preserves efficiency (small recurrence state, bounded cache size), and matches or outperforms Transformers at the same scale.
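For intuition about the local-attention half, the sketch below constructs a causal sliding-window mask; it only illustrates the masking idea and is not Griffin's implementation.

```python
import torch

def sliding_window_mask(seq_len, window):
    # Token t may attend to tokens in [t - window + 1, t], so the per-token
    # cost is bounded by the window size rather than the full sequence.
    idx = torch.arange(seq_len)
    causal = idx[None, :] <= idx[:, None]
    local = idx[:, None] - idx[None, :] < window
    return causal & local      # (seq_len, seq_len) boolean keep-mask

print(sliding_window_mask(6, 3).int())
```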
Empirical Results and Main Findings
Power-Law Scaling
Hawk, Griffin, and a baseline Transformer (using multi-query attention, MQA) all exhibit a power-law relationship between validation loss and training FLOPs, resembling established Transformer scaling laws. Griffin consistently has slightly better loss–compute tradeoffs than the baseline Transformer at all scales tested.
Performance vs. Mamba and Llama-2
Hawk surpasses Mamba (a prior recurrent state-space model) on downstream tasks at 3B parameters despite training on half as many tokens (300B vs. 600B). Griffin (7B, 14B) matches or slightly outperforms Llama-2–7B/13B despite being trained on about seven times fewer tokens. Conclusion: The new architectures remain competitive on well-known benchmarks such as MMLU, HellaSwag, ARC, PIQA, and WinoGrande, even against popular open-source Transformers.
Implementation Challenges
RG-LRU’s diagonal recurrence is memory-bound on TPUs (and GPUs), as it performs elementwise updates rather than large matrix multiplications. The paper introduces a custom kernel (using Pallas on TPU-v3) that reads and writes hidden states in a linear scan and minimizes memory transfers, greatly speeding up training.
Speed Gains for Long Sequences
For short sequence lengths (e.g. 2K tokens), Griffin trains at a speed on par with the MQA baseline. As sequence lengths grow (4K, 8K), the MQA baseline slows down due to attention complexity, whereas the recurrent models maintain similar runtime. Hence, Hawk/Griffin gain a training-speed advantage at longer contexts.
Throughput and Latency Gains
Transformers (even MQA) rely on a KV cache growing linearly in sequence length. The overhead can become larger than the model parameters at long decode lengths. Hawk and Griffin keep a small, fixed-size recurrence state. This yields: (1) Lower latency: less reading/writing of large caches for extended contexts. (2) Higher throughput: bigger batch sizes fit into memory because the small recurrence state frees device memory otherwise used by attention KV caches.
Additional Observations (Long Context and Copy/Retrieve Tasks)
Extrapolation to Long Sequences
Hawk and Griffin can successfully exploit contexts longer than those used during training, showing better extrapolation than Transformers with typical positional embeddings (e.g. RoPE). Training them explicitly on 8K sequences (rather than 2K) further improves performance at longer contexts.
Copying/Retrieval Skills
Hawk alone may learn copy tasks but can sometimes lag behind Transformers in speed of learning. This matches prior observations about slow adaptation of purely recurrent state-space layers on synthetic tasks. Griffin (mixing local attention + recurrence) can learn copy tasks more quickly, similar to Transformers, and can extrapolate better to sequences beyond the training length.
Significance and Contributions
Challenging the Status Quo
This work demonstrates that RNN-based architectures can be trained at large scale as efficiently as Transformers, with the added benefit of small fixed-size hidden states.
High Performance at Fewer Tokens
Griffin achieves or exceeds Llama-2 performance while using substantially fewer training tokens, highlighting that it is not necessary to rely purely on full-sequence attention.
Practical Inference Advantage
Significantly reduced inference latency and improved throughput on long sequences is a major practical benefit in real-world applications when generating long responses and streaming data.
Blueprint for Hybrid Models
By combining local attention (for short-range patterns) with a gated linear recurrence (for stable, long-range capacity), Griffin strikes a balance that may inform future research in large language models beyond classical Transformers.
Group 6 Presentation: Learning to (Learning at Test Time): RNNs with Expressive Hidden States
Presented by:
Zhiyang Cheng and Pingchu Zhang
Paper Citation
Sun, Y., Li, X., Dalal, K., Xu, J., Vikram, A., Zhang, G., Dubois, Y., Chen, X., Wang, X., Koyejo, S., Hashimoto, T., & Guestrin, C. (2024). Learning to (Learn at Test Time): RNNs with Expressive Hidden States. arXiv. https://doi.org/10.48550/arXiv.2407.04620.
Summaries of key points
This paper revisits the traditional role of RNN hidden states, proposing that they can do more than just store information—they can enable learning at test time. The authors introduce a method where a hypernetwork dynamically generates the RNN’s weights as a function of its hidden state, allowing the model to adapt its behavior on the fly. This gives rise to what they call expressive hidden states, which encode both memory and the capacity to steer the model’s future updates. The approach effectively blurs the line between training and inference, treating test-time prediction as a form of continual adaptation. This results in stronger performance in settings like few-shot learning and online learning, where flexibility and rapid adaptation are crucial. Rather than relying on explicit optimization at test time (as in typical meta-learning setups), the RNN itself becomes the learner, continuously reshaping its internal dynamics based on the sequence it's processing.
While innovative, the method introduces nontrivial architectural and computational overhead. The use of a hypernetwork to produce weights at every time step means the model must manage a more complex parameter space and could become less scalable for long sequences or larger models. There's also the risk of instability, since small changes in the hidden state can lead to large changes in the generated weights. Regularization and careful design are needed to prevent the model from diverging. Another limitation is that while the paper shows strong performance on synthetic and controlled few-shot learning tasks, it doesn’t extensively benchmark on more complex natural language or real-world sequential data, leaving questions about generalization and practicality open.
Clear explanations to aid understanding
In a standard RNN, the weights are fixed during inference—you feed in tokens or sequence elements, and the hidden state updates based on those fixed rules. What this paper suggests is: what if the hidden state itself could influence the rules? So instead of always using the same weights, the RNN can generate new ones depending on what it's seen so far. This is done using a hypernetwork—a small neural network that outputs the weights for the main RNN. So as the RNN processes a sequence, it effectively reshapes its own behavior to fit the task or data distribution it's encountering. It’s like the RNN is learning while it’s making predictions, adapting in real-time to maximize performance without needing gradient descent at test time.
Group 6 Presentation: Learning to (Learning at Test Time): RNNs with Expressive Hidden States
Presented by:
Zhiyang Cheng and Pingchu Zhang
Paper Citation
Sun, Y., Li, X., Dalal, K., Xu, J., Vikram, A., Zhang, G., Dubois, Y., Chen, X., Wang, X., Koyejo, S., Hashimoto, T., & Guestrin, C. (2024). Learning to (Learn at Test Time): RNNs with Expressive Hidden States. arXiv. https://doi.org/10.48550/arXiv.2407.04620.
Summaries of key points
Goal: In Test-Time Training, make the hidden state into a small model that can be updated to improve the sequence modeling ability.
Background: The hidden state of RNNs usually has a fixed dimension, which limits their expressiveness.
Methodology: At each step, the weights W of the hidden-state model are updated with a gradient step on a self-supervised task. The dual form turns multi-step updates into a single matrix operation.
Result: On short sequences, the TTT models behave similarly to existing methods. On long sequences, TTT-Linear and TTT-MLP are significantly better than Transformer and Mamba. TTT-Linear's inference speed is close to Mamba's and faster than the Transformer's.
Constructive critiques or reviews
The presentation is clearly structured, and the slides include pictures and diagrams to help listeners understand better.
The presenters turned on their cameras, which made the talk easier to follow.
Speaking fluency could be improved slightly.
Clear explanations to aid understanding
TTT layer: learns directly on the test sequence, with the update process implemented through self-supervised learning.
Efficiency optimization: Improve computing efficiency with mini-batch and dual-form
Mamba: Mamba uses a state-space model to capture long-range dependencies.
Group 6 Presentation: Learning to (Learn at Test Time): RNNs with Expressive Hidden States
Presented by:
- Pingchu Zhang
- Zhiyang Cheng
Paper Citation
Sun, Y., Li, X., Dalal, K., Xu, J., Vikram, A., Zhang, G., Dubois, Y., Chen, X., Wang, X., Koyejo, S., Hashimoto, T., & Guestrin, C. (2024). Learning to (Learn at Test Time): RNNs with Expressive Hidden States. arXiv. https://doi.org/10.48550/arXiv.2407.04620
Background
For modern RNNs, performance on long contexts is limited by the expressive power of their fixed-size hidden state. Hence, the authors introduce test-time training (TTT).
Technical Contributions
- Introduce TTT layers, where the hidden state is a model and the update rule is self-supervised learning, offering a new research direction.
- TTT-Linear, a simple implementation of TTT layers, outperforms Transformers and Mamba in evaluations.
- Improve the hardware efficiency of TTT layers through mini-batch TTT and the dual form, making TTT-Linear already a practical building block for LLMs.
Methodology
The key idea is to make the hidden state itself a model with weights, and the update rule a gradient step on the self-supervised loss. Then updating the hidden state on a test sequence is equivalent to training the model at test time.
Training a network with TTT layers
- Training the larger network as the outer loop and training weights within each TTT layer as the inner loop is preferred.
- TTT layers can replace RNN or self-attention layers in any network architecture. Training a network with TTT layers also works the same way as training any other language model.
Learning a self-supervised task for TTT
Some outer-loop parameters are added to make this task learnable.
The input [math]\displaystyle{ x_t }[/math] is transformed using a learnable matrix [math]\displaystyle{ \theta_K }[/math] to create a training view [math]\displaystyle{ \tilde x_t = \theta_K x_t }[/math].
The reconstruction label is another low-rank projection [math]\displaystyle{ \theta_V x_t }[/math], which can differ from the input, and a test view [math]\displaystyle{ \theta_Q x_t }[/math] is created in the same way.
The new self-supervised loss is [math]\displaystyle{ \ell(W; x_t) = \|f(\theta_K x_t; W) - \theta_V x_t\|^2 }[/math], and the output rule is modified to [math]\displaystyle{ z_t = f(\theta_Q x_t; W_t) }[/math].
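A minimal sketch of one TTT-Linear step following these equations, assuming f(x; W) = W x and a plain gradient step with learning rate eta; the closed-form gradient and variable names are illustrative.

```python
import torch

def ttt_step(W, x_t, theta_K, theta_V, theta_Q, eta=0.1):
    # Views of the current token: training view, label view, test view
    # (theta_K and theta_Q are assumed to project to the same dimension).
    k, v, q = theta_K @ x_t, theta_V @ x_t, theta_Q @ x_t
    # Inner-loop gradient step on l(W; x_t) = ||W k - v||^2.
    residual = W @ k - v
    grad = 2.0 * torch.outer(residual, k)
    W_new = W - eta * grad
    # Output rule: z_t = f(theta_Q x_t; W_t).
    z_t = W_new @ q
    return z_t, W_new
```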
Summaries of Key Points
Motivation: Compressing Long Context Efficiently
Traditional Transformers handle large contexts by storing every token in a Key-Value cache, which grows linearly with sequence length and makes inference complexity scale quadratically. Modern recurrent neural networks (RNNs) like Mamba sidestep storing the entire context by having a fixed-size hidden state, which leads to linear time complexity. However, RNNs often struggle to exploit very long contexts because the fixed-size hidden state must compress a large amount of information. The authors propose a new design, Test-Time Training (TTT), that treats the hidden state as a small learnable model trained via a self-supervised loss on each incoming token—even at test time.
TTT Layers: Hidden State as a Learner
The paper reframes any sequence-modeling layer as “a hidden state plus an update rule.” For TTT layers, the hidden state is itself a small parametric or nonparametric model f, and the update rule is a step of gradient descent (or other training procedure) on each new token. Thus, at every token step, the hidden state is updated by “training” f on a self-supervised objective. Concretely, one might define a corruption or partial view of the token and train the parametric model to reconstruct the hidden or relevant aspects of the token.
Two Main Instantiations: TTT-Linear and TTT-MLP
The authors propose TTT-Linear, where the learner f is a simple linear mapping plus optional layer normalization and a residual connection. They also propose TTT-MLP, which uses a two-layer MLP as its learner, offering a more expressive hidden state. Both can be integrated into existing RNN-based or Transformer-based architectures in place of the usual self-attention or simple RNN blocks. Like other RNN layers, TTT layers compress all historical tokens into a fixed-size hidden state—but the learner can be updated more flexibly via gradient steps each time a new token arrives.
Efficiency Enhancements
Naively computing a gradient step per token would be too slow. Two key ideas improve hardware utilization:
1. Mini-batch TTT processes a batch of b tokens at once to parallelize the internal gradient steps. Smaller b yields more gradient steps (and better expressiveness) but can slow performance.
2. A “dual form” for TTT-Linear and TTT-MLP reworks the update and output computations into larger matrix multiplications, ensuring that modern accelerators (GPUs, TPUs) can exploit efficient batched operations.
Empirical Results
On language-modeling benchmarks (the Pile and Books), TTT-Linear and TTT-MLP match or exceed strong baselines (Transformer and the modern RNN Mamba) across model scales (125M to 1.3B parameters). TTT-Linear typically does as well as Mamba in short context (2k tokens) but outperforms Mamba substantially in longer contexts (8k or 32k), demonstrating that the extra expressiveness helps exploit more tokens. TTT-MLP can be even more expressive at very long contexts but can be more memory intensive. The authors also show that TTT-Linear can train and infer efficiently in wall-clock time using a specialized GPU kernel, yielding near-constant inference latency as context grows (unlike the linear growth in Transformers).
Significance and Future Work
TTT recasts the hidden-state update in an RNN-like layer as explicitly training a miniature model at test time—essentially “learning to learn” from each incoming token. With further improvements in tasks (beyond simple reconstruction), hardware kernels, and more expressive hidden states, TTT-based architectures may offer a new path toward efficient, high-performing sequence models for extremely long contexts.
Group 6 Presentation: Learning to (Learn at Test Time): RNNs with Expressive Hidden States
Presented by:
Pingchu Zhang, Zhiyang Cheng
Paper Citation
Sun, Y., Li, X., Dalal, K., Xu, J., Vikram, A., Zhang, G., Dubois, Y., Chen, X., Wang, X., Koyejo, S., Hashimoto, T., & Guestrin, C. (2024). Learning to (Learn at Test Time): RNNs with Expressive Hidden States. arXiv. https://doi.org/10.48550/arXiv.2407.04620
Background
Recurrent Neural Networks (RNNs) are attractive for their linear time complexity, which makes them efficient, especially for long-context inputs. However, they’ve historically struggled to match the performance of Transformers on tasks like language modeling. One key limitation is the fixed-size hidden state of RNNs, which forces them to compress all past context into a compact representation. This compression becomes increasingly difficult as the context grows longer.
Recent RNN variants like Mamba have closed the gap in scaling performance, but they still hit a ceiling: their ability to improve predictions plateaus at long context lengths (e.g., beyond 16k tokens). Transformers, in contrast, continue to benefit from more context, although at a higher computational cost due to their quadratic scaling.
The authors suggest that this limitation is tied to the expressive capacity of the hidden state. Inspired by how large language models compress vast datasets into their weights through training, they explore whether a hidden state can itself be a learnable model, updated online, even during inference.
Main Idea
The core proposal is the Test-Time Training (TTT) layer, a new kind of sequence modeling layer where the hidden state is a model, and the update rule is a self-supervised learning step. Instead of simply storing a vector or matrix, the hidden state consists of the weights of a small model (like a linear function or a 2-layer MLP). These weights are updated at each time step using gradient descent based on a self-supervised loss.
Key points:
The update happens at test time, not just during training—hence “test-time training.”
The layer sees each input token as a new self-supervised learning opportunity, updating its internal model to better predict the next token.
This approach allows the hidden state to grow in complexity without growing in size—it gains depth by learning, not by storing.
Two instantiations are tested:
TTT-Linear, where the hidden state is a linear model.
TTT-MLP, where the hidden state is a 2-layer MLP.
This method can be used in place of RNN or attention layers, and is compatible with existing architectures. Despite its novel structure, it can be trained end-to-end like other language models.
To make this practical on hardware, the authors also design efficient mini-batch updates and a dual form of the forward pass that enables good GPU utilization. These tricks allow them to run TTT layers efficiently, even faster than Transformers in some regimes.
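To give a feel for why these tricks help, here is a sketch of a TTT-Linear chunk in which all inner-loop gradients are taken at the chunk's starting weights W0, so the per-token loop collapses into a few batched tensor operations; the variable names and fixed learning rate are assumptions, not the paper's exact dual-form kernel.

```python
import torch

def ttt_linear_chunk(W0, K, V, Q, eta=0.1):
    # W0: (d, d) weights at the start of the chunk.
    # K, V, Q: (b, d) train, label, and test views for the b chunk tokens.
    residuals = K @ W0.T - V                               # rows: W0 k_t - v_t
    grads = 2.0 * residuals.unsqueeze(2) * K.unsqueeze(1)  # (b, d, d) per-token grads
    W_all = W0 - eta * torch.cumsum(grads, dim=0)          # W_t for every token
    Z = torch.einsum('bij,bj->bi', W_all, Q)               # z_t = W_t q_t
    return Z, W_all[-1]
```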
Experimental & Result
The authors evaluate TTT-Linear and TTT-MLP against two baselines: a strong Transformer and Mamba (a recent high-performing RNN). They focus on both performance and efficiency, testing across different model sizes and context lengths.
1. Short Context (2k and 8k tokens)
At 2k tokens, TTT-Linear, Mamba, and Transformer perform similarly.
At 8k tokens, TTT models outperform Mamba. This shows that as context grows, the test-time learning approach starts to shine.
TTT-MLP generally has better perplexity than TTT-Linear, but is slower due to its more complex hidden state.
2. Long Context (up to 32k tokens)
Experiments on the Books3 subset of The Pile show that Mamba's performance plateaus after 16k tokens.
In contrast, TTT models (especially TTT-MLP) continue to improve, similar to how Transformers behave.
TTT-MLP performs best at long context, consistent with its higher expressivity.
3. Latency and Efficiency
In terms of wall-clock time, TTT-Linear is already faster than Transformers at 8k tokens and matches Mamba.
For token generation (decode time), TTT-Linear and Mamba have much lower latency than Transformers.
These efficiency gains are achieved thanks to GPU-aware design, including the use of mini-batch updates and matrix-optimized dual formulations.
4. Scaling and FLOPs
TTT-Linear uses fewer FLOPs than both baselines at equivalent perplexity.
TTT models perform well under the same training compute budgets (following the Chinchilla recipe).
They also maintain quality under increasing model sizes—from 125M to 1.3B parameters.
Group 7 Presentation: Attention with Markov: A Framework for Principled Analysis of Transformers via Markov Chains
Presented by:
Jonathan Gallagher and Mariya Anashkina
Paper Citation
A. V. Makkuva et al., “Attention with Markov: A Framework for Principled Analysis of Transformers via Markov Chains,” 2024, arXiv. doi: 10.48550/ARXIV.2402.04161.
Summaries of key points
This paper offers a theoretical framework for analyzing transformers using Markov chains, providing a new way to understand how information flows and dependencies are formed across layers. The core idea is to view the self-attention mechanism as inducing a Markov process over the sequence positions, where each attention head defines a transition probability matrix. By interpreting attention through this lens, the authors derive principled explanations for phenomena like context aggregation, token mixing, and the effects of multiple layers and heads. They show that stacking layers corresponds to composing Markov transitions, meaning that deeper transformers can be understood as performing longer-range probabilistic walks over the input. This allows them to make quantitative predictions about attention spread and mixing time, and helps formalize intuitions about how depth and attention head structure influence model behavior. Importantly, the framework applies without modifying the architecture—it’s purely analytical, making it a useful tool for understanding model internals in a mathematically grounded way.
The abstraction into Markov chains may oversimplify certain aspects of how attention operates—particularly when attention scores are data-dependent and influenced by non-linear operations (e.g., softmax and masking). The analysis tends to assume relatively clean or idealized cases, and may not fully capture the nuances of attention in real-world settings like language modeling with positional encoding or context-specific weighting. While the framework is insightful, it’s primarily useful for interpretability and analysis, not for improving model performance directly. Finally, the notion of interpretability here is mathematical, but not necessarily human-friendly—it gives you equations, not explanations in natural language.
Clear explanations to aid understanding
Think of each attention head as deciding where to “move” next in the sequence—like a random walker choosing the next word to look at. If you treat those attention weights as transition probabilities, then attention turns into a kind of Markov process, where information flows step-by-step based on those probabilities. Stacking layers is like walking further in the sequence graph, and having multiple heads is like having multiple walkers with different preferences. This view lets you talk about things like mixing time—how fast information spreads—or how different tokens influence each other over multiple layers. It’s a powerful way to bridge probabilistic modeling and deep learning, especially when trying to reason about what attention is doing beyond just visualizing heatmaps.
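As a small illustration of this view, the sketch below treats each layer's row-stochastic attention matrix as a Markov transition matrix and composes them across layers; averaging heads into one matrix per layer is an assumption made for simplicity.

```python
import numpy as np

def compose_attention_as_markov(attn_layers):
    # attn_layers: list of (seq_len, seq_len) row-stochastic attention matrices.
    walk = np.eye(attn_layers[0].shape[0])
    for A in attn_layers:
        assert np.allclose(A.sum(axis=1), 1.0)   # rows must be probabilities
        walk = walk @ A                          # one more "step" of the walker
    return walk   # entry [i, j]: weight position i places on position j
```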
Group 7 Presentation: Attention with Markov: A Framework for Principled Analysis of Transformers via Markov Chains
Presented by:
- Jonathan Gallagher
- Mariya Anashkina
Paper Citation
A. V. Makkuva et al., “Attention with Markov: A Framework for Principled Analysis of Transformers via Markov Chains,” 2024, arXiv. doi: 10.48550/ARXIV.2402.04161.
Background and overview
Markov chains
Markov chains are probability models that predict the future state of a system based on its current and prior states. Language can be modeled as a high-order Markov process, an idea popularized by Claude Shannon in a 1948 paper that later became a building block for modern natural language processing. Next-token prediction in transformers closely parallels this Markovian view.
In this paper, the authors use a simplified model: a two-state (binary) Markov chain that is in its stationary distribution, making the assumption that the Markov chain has been running for an extremely long time. The term stationary distribution refers to a distribution of states that a chain converges towards as the number of steps increases indefinitely; at that point, each subsequent distribution is the same as the previous one.
Objectives in this paper
The paper introduces a framework to systematically analyze (through the lens of a Markov chain) how transformers learn to model sequential data. The objective is to explore how a transformer succeeds or struggles on first order Markov chains, a distribution that only needs one step of memory. They also do a theoretical analysis of network architecture choices for transformers, and look at how this affects the loss landscape and learning dynamics.
Weight tying
This is one of the main architectural choices in transformers that are studied in this paper. This is the practice of using the same weights in the input and output layers. Firstly, this means that whether the model is trying to read or generate the text, the tokens' representations are the same -- this is done for consistency. Secondly, this can act as a form of regularization as it gives the model fewer parameters to learn.
Theoretical Results
Significance of [math]\displaystyle{ p+q }[/math]
Recall that we are working with binary, first-order Markov processes here. Let [math]\displaystyle{ p }[/math] be the probability that a state 0 will turn to 1, and let [math]\displaystyle{ q }[/math] be the probability that a state 1 will turn to 0. Consequently, probabilities [math]\displaystyle{ 1-p }[/math] and [math]\displaystyle{ 1-q }[/math] are the probabilities that states 0 and 1 will remain unchanged, respectively.
The quantity [math]\displaystyle{ p+q }[/math] is referred to as a switching factor, and is used to characterize the overall tendency of the chain to change state. When [math]\displaystyle{ p+q \lt 1 }[/math], the system is likely to stay in its current state. When [math]\displaystyle{ p+q \gt 1 }[/math], the system is likely to change its state, and therefore will exhibit oscillatory behaviour.
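As a quick illustration of this setup (not part of the paper), the sketch below builds the two-state transition matrix for a given (p, q), computes the stationary distribution pi = (q/(p+q), p/(p+q)), and samples a short sequence starting from it.

```python
import numpy as np

def sample_binary_chain(p, q, n_steps=20, seed=0):
    rng = np.random.default_rng(seed)
    P = np.array([[1 - p, p],      # P(0 -> 0), P(0 -> 1)
                  [q, 1 - q]])     # P(1 -> 0), P(1 -> 1)
    pi = np.array([q, p]) / (p + q)          # stationary distribution
    x = rng.choice(2, p=pi)                  # start in the stationary distribution
    states = [x]
    for _ in range(n_steps - 1):
        x = rng.choice(2, p=P[x])
        states.append(x)
    return pi, np.array(states)

print(sample_binary_chain(p=0.8, q=0.7))     # p + q > 1: switch-happy chain
```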
The authors of the paper provide proof of the existence of the global minimum as well as bad local minima in the loss landscape. Furthermore, they look at the existence of global minima and saddle points in the cases with vs. without weight tying.
Theorem 1 (Global minimum). Let the input sequence be [math]\displaystyle{ \{x_n\}_{n=1}^N \sim \bigl(\pi(p,q), P(p,q)\bigr) }[/math] for some fixed [math]\displaystyle{ (p,q)\in(0,1)^2. }[/math] Then for all [math]\displaystyle{ (p,q), }[/math] there exists a [math]\displaystyle{ \theta_{\star}\in\mathbb{R}^{D-d} }[/math] with an explicit construction such that it is a global minimum for the population loss [math]\displaystyle{ L(\cdot) }[/math].
In other words, the authors are able to prove that for the transformer, there exists an optimal parameter configuration.
Theorem 2 (Bad local minimum). Let the input sequence be [math]\displaystyle{ \{x_n\}_{n=1}^N \sim \bigl(\pi(p,q), P(p,q)\bigr) }[/math] for some fixed [math]\displaystyle{ (p,q)\in(0,1)^2. }[/math] If [math]\displaystyle{ p+q\gt 1, }[/math] there exists an explicit [math]\displaystyle{ \theta_{\pi}\in\mathbb{R}^{D-d} }[/math] such that it is a bad local minimum for the loss [math]\displaystyle{ L(\cdot) }[/math].
Theorem 3 (Global minimum). Consider the same setting as in Thm.~1. Then for all [math]\displaystyle{ (p, q), }[/math] if [math]\displaystyle{ \theta_{\star} = (e_{\star} = a_{\star}, \dots, b_{\star}) \in \mathbb{R}^{D-d} }[/math] is a global minimum for the loss [math]\displaystyle{ L(\cdot) }[/math] in the weight-tied scenario, then its extension [math]\displaystyle{ \bar{\theta}_{\star} = (\bar{e}_{\star}, \bar{a}_{\star}) \in \mathbb{R}^D }[/math] is also a global minimum for [math]\displaystyle{ L(\cdot) }[/math] in [math]\displaystyle{ \mathbb{R}^D }[/math] in the non-weight-tied case. Further, [math]\displaystyle{ \bar{\theta}_{\star} }[/math] satisfies the same properties (ii)--(iv) as in Thm.~1.
Theorem 4 (Saddle point). Consider the same setting as in Thm.~3. For [math]\displaystyle{ p + q \gt 1, }[/math] let [math]\displaystyle{ \theta_{\pi} = (e_{\pi} = a_{\pi}, \dots, b_{\pi}) \in \mathbb{R}^{D-d} }[/math] be the corresponding bad local minimum for the loss [math]\displaystyle{ L(\cdot) }[/math] in the weight-tied scenario. Then its extension [math]\displaystyle{ \bar{\theta}_{\pi} = (\bar{e}_{\pi}, \bar{a}_{\pi}) \in \mathbb{R}^D }[/math] is a saddle point for [math]\displaystyle{ L(\cdot) }[/math] in [math]\displaystyle{ \mathbb{R}^D }[/math] in the non-weight-tied case. Further, [math]\displaystyle{ \bar{\theta}_{\pi} }[/math] satisfies the same properties (ii)--(iv) as in Thm.~2.
Main Contributions
Setup for the theoretical framework in this paper
- The dataset is first-order and binary.
- The model is a single-layer transformer, with a single attention head and no layer norm.
- The loss function is cross-entropy population loss. Note: the transformer model doesn't know that the data is Markovian; in other words, it is allowed to use the entire sequence's history, as it doesn't know that each step is dependent only on the previous one.
Empirical results
Summary of this paper
1. The authors introduce a theoretical framework that models the data source as a Markov process. This allows them to study how transformers learn sequential structure, contrasting it with other approaches that treat training data simply as i.i.d.
2. Focusing on single-layer transformers trained for next-token prediction on first-order Markov data, the paper gives a detailed analysis of the cross-entropy loss surface: There is always a set of parameters that perfectly recover the true Markov transition probabilities (thus achieving the global minimum). When the sum of the Markov chain's flipping probabilities exceeds one, the model can converge to parameters that simply predict the stationary distribution rather than the true transition probabilities. This phenomenon does not arise (or becomes only a saddle) when weights are untied or when the transformer has multiple layers.
3. Through experiments, they show that the theoretical findings match empirical behaviors. In particular: When weights are tied, the model may learn a constant (stationary) probability and fail to leverage sequential context if the transition probabilities are above a certain threshold. Removing weight tying—or increasing the transformer's depth—helps avoid such bad local minima.
4. The authors extend the analysis to higher-order processes. Surprisingly, simply increasing the transformer depth does not guarantee learning of higher-order transitions. They find, however, that restricting the attention window (rather than letting the network attend to all past tokens) dramatically improves learning of higher-order Markov patterns.
Related work
1. Analyzing Transformer Architectures:
Vaswani et al. (2017) introduced the Transformer architecture, which has since been extensively studied to understand its theoretical foundations and practical applications.
Rogers et al. (2020) provided a comprehensive analysis of BERT, a model based on the Transformer architecture, discussing its linguistic capabilities and limitations.
2. Combining Attention Mechanisms with Probabilistic Models:
Bahdanau et al. (2015) introduced the attention mechanism in neural machine translation, allowing models to focus on relevant parts of the input sequence dynamically.
Graves et al. (2014) proposed the Neural Turing Machine, combining neural networks with external memory, enabling models to learn algorithmic tasks.
Future direction
1. Enhanced Interpretability of Transformer Models:
Applying the Markov chain framework to dissect and visualize the internal workings of Transformers could lead to more interpretable models, aiding in identifying and mitigating biases.
2. Development of Hybrid Models:
Integrating Markovian structures with attention mechanisms may result in hybrid models that leverage the strengths of both approaches, potentially improving performance in tasks requiring sequential dependencies.
3. Theoretical Analysis of Attention Dynamics:
Further exploration into the mathematical properties of attention modelled as Markov processes could provide deeper insights into the stability and convergence of Transformer-based models.
Group 7 Review: An Analysis of Transformers via Markov Chains
Presented by:
Jonathan Gallagher and Mariya Anashkina
Paper Citation
A. V. Makkuva et al., “Attention with Markov: A Framework for Principled Analysis of Transformers via Markov Chains,” 2024, arXiv. doi: 10.48550/ARXIV.2402.04161.
Review
I thought this presentation made it really easy to see the connection between Markov chains and Transformers when tasked with predicting future tokens. The lunch menu analogy made it really easy to see how a first-order Markov chain behaves, and was a great general idea to discuss before diving deeper into how Markov chains behave in the language world. Theoretical results were also well explained before linking them to empirical results and discussing why the empirical results behave the way they do.
The empirical results were also well explained, showing how removing weight tying or increasing model depth can help escape bad local minima and ensure the transformer's parameters reach the global minimum for first-order Markov chains, and how smaller context windows are required to escape bad local minima for higher-order Markov chains, even for deeper models (regardless of depth or weight tying).
I also thought this presentation was really well structured, and found that the slides were really easy to follow with great visual aids.
Group 8 Presentation: MEDUSA: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads
Presented by:
Nana Ye and Xingjian Zhou
Summaries of key points
This paper introduces MEDUSA, a lightweight yet effective method for accelerating inference in large language models by attaching multiple decoding heads at intermediate layers. The key idea is that, during inference, you can use these heads to predict several future tokens in parallel, reducing the number of sequential steps needed. Unlike speculative decoding (which relies on a separate draft model), MEDUSA uses the same model and architecture, just augmented with extra linear heads that predict future tokens based on intermediate hidden states. These predictions are verified by the base model in a final pass, similar in spirit to Draft & Verify, but with much lower overhead and implementation complexity. Despite its simplicity, MEDUSA achieves competitive speedups—up to 2×—on models like LLaMA, and integrates easily into existing transformer pipelines. It also preserves generation quality well, maintaining high accuracy across benchmarks without requiring retraining.
One potential limitation of MEDUSA is that its performance gains depend on the quality of intermediate predictions—if the early layers aren't predictive enough, the method may yield minimal speedup or introduce verification bottlenecks. Another concern is scalability: adding too many decoding heads could increase memory consumption or introduce architectural clutter. While the paper shows good results on standard benchmarks, it's less clear how MEDUSA performs in more complex decoding scenarios like beam search, sampling with temperature, or instruction-tuned models. Finally, although it's simple to implement, any modification to production LLM inference stacks still carries deployment costs, which MEDUSA somewhat underplays.
Clear explanations to aid understanding
Think of a transformer generating text one word at a time—slow, because each step waits on the previous. MEDUSA says: what if we could guess ahead a few tokens using partial information? It adds small prediction heads at different layers in the model, each trying to guess future tokens before the final layer finishes computing. Once these guesses are made, the base model verifies them. If correct, we skip ahead; if not, we fall back. It’s like speculative decoding, but self-contained—no second model, no complicated setup. You get the parallelism of speculative methods with the simplicity of just tweaking the model's architecture slightly.
Group 8 Presentation: MEDUSA: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads
Presented by:
Nana Ye and Xingjian Zhou
Summaries of key points
Goal: By adding multiple "Medusa heads" to the main model, predict multiple tokens at once without relying on an external draft model.
Background: The core problem of slow inference in large models is not compute but the memory-bandwidth bottleneck. Autoregressive decoding generates tokens one by one and has low GPU utilization. Draft-model acceleration brings additional model overhead and can suffer from distribution mismatch.
Methodology: Multiple Medusa heads predict future tokens in parallel. Candidate tokens are organized with tree attention to validate multiple sequences simultaneously. The longest prefix with acceptable probability is accepted using a typicality-based acceptance strategy.
Result: For the Qwen7B and Zephyr7B models on the ChatGPT dataset, Medusa 1 gives roughly a 2.2x speedup and Medusa 2 roughly 2.8x, with some tasks reaching 3.6x, while output quality remains almost lossless.
Constructive critiques or reviews
The in-depth, detailed explanations helped the audience understand the material more deeply.
Turning on the camera could make the presentation feel more engaging.
Clear explanations to aid understanding
Medusa 1: trains only the Medusa heads, to save resources.
Medusa 2: trains the main model and the Medusa heads together; a two-stage strategy avoids degrading the main model's performance.
Group 8 Presentation: MEDUSA: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads
Presented by:
- Nana Ye
- Xingjian Zhou
Paper Citation
T. Cai et al., “Medusa: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads,” 2024, arXiv. doi: 10.48550/ARXIV.2401.10774.
https://arxiv.org/abs/2401.10774
Background
- As the size of LLMs grows, the speed at which they can generate tokens decreases. The bottleneck is primarily the transfer of data to/from the GPU
- Speculative Sampling is an existing solution that predicts multiple tokens in the future at once using smaller "draft" models
- Medusa instead solves this problem by adding multiple decoding heads and a tree-based attention mechanism to existing LLMs
- Paper discusses the implementations of Medusa1 and Medusa2
Main Idea
The idea is to replace the draft model with heads attached to the model itself, since the best representation of a model is the model itself. The multiple heads predict several future tokens at once to exploit parallelism, which makes decoding more efficient and gives the tree-based attention more candidate tokens to choose from. The tree-based attention then simulates sequential generation by traversing a tree from top to bottom, where the root is the initial token.
Methodology
During training, each Medusa head is optimized to predict a future token given the same input context. For example, head 3 is trained to predict token xt+3 using only the context up to xt. The objective function is a standard cross-entropy loss over the predicted token distributions.
To avoid error accumulation, the training data for deeper heads (e.g., t+4, t+5) is generated using gold sequences rather than model outputs.
The decoding tree is constructed with a fixed depth and branching factor corresponding to the number of Medusa heads. Unlike beam search, this tree is evaluated in parallel, and scoring is done using a lightweight attention mechanism applied to the hidden states of the candidate paths.
For candidate selection, a method called typical acceptance is introduced as a fast alternative to rejection sampling. It accepts candidates based on whether their token-level probabilities fall within a "typical" range, reducing the number of evaluations needed during decoding.
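A sketch of what a typicality-based acceptance rule could look like; the entropy-based threshold mirrors the description above, and eps and delta are illustrative hyperparameters rather than the paper's exact values.

```python
import torch

def typical_accept(probs, candidate_ids, eps=0.09, delta=0.3):
    # probs: (steps, vocab) base-model probabilities at each speculative position.
    # candidate_ids: (steps,) tokens proposed by the Medusa heads.
    entropy = -(probs * probs.clamp_min(1e-9).log()).sum(-1)
    threshold = torch.minimum(torch.full_like(entropy, eps),
                              delta * torch.exp(-entropy))
    token_probs = probs.gather(-1, candidate_ids.unsqueeze(-1)).squeeze(-1)
    accepted = (token_probs >= threshold).tolist()
    n = 0
    while n < len(accepted) and accepted[n]:   # longest accepted prefix
        n += 1
    return candidate_ids[:n]
```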
Technical Contributions
Medusa 1:
- Uses a frozen pre-trained LLM and trains extra decoding heads on top
- Each additional decoding head predicts a token K time steps in the future
- Uses a probability loss function that scales based on the number of steps into the future
- Reduces memory usage because the backbone model is only used for hidden state extraction
- In simple terms, Medusa adds additional linear layers on top of the transformer's last hidden layer, which are trained to predict tokens at future positions, rather than just the next token as a conventional transformer does in a typical autoregressive manner (see the sketch below).
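A minimal sketch of what such heads could look like in PyTorch, assuming one plain linear layer per head as described above; the module and parameter names are illustrative, not the paper's implementation.

```python
import torch.nn as nn

class MedusaHeads(nn.Module):
    # Head k maps the backbone's last hidden state at position t to logits
    # for the token at position t + k + 1, so K speculative tokens are
    # proposed in a single forward pass while the backbone stays frozen.
    def __init__(self, hidden_size, vocab_size, num_heads=4):
        super().__init__()
        self.heads = nn.ModuleList(
            [nn.Linear(hidden_size, vocab_size) for _ in range(num_heads)]
        )

    def forward(self, last_hidden):              # (batch, seq, hidden)
        return [head(last_hidden) for head in self.heads]
```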
Medusa 2:
- Fine tunes the LLM and trains the decoding heads at the same time.
- Encountered problems with high losses, switched to a two-stage training process:
- Stage 1: train only the Medusa heads (similar to Medusa 1)
- Stage 2: train both the backbone model and the Medusa heads together
Tree Attention
- Tree attention is used to enable the heads predicting later tokens to include the additional context which may have been created by the earlier medusa heads in the pipeline
- This tree structure does not occur autoregressively, however
- The top predictions from each head are fed into the tree structure as candidate tokens
- An attention mask is used to ensure that each future-token prediction in the tree is based on prior tokens, not on tokens past the one being decoded (a mask-construction sketch follows this list)
- Multiple future candidate tokens can be predicted with context-aware attention simultaneously
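As a concrete illustration of the masking idea, here is a small sketch (not the authors' implementation) that builds an attention mask for a tree of candidate tokens so that each node attends only to itself and its ancestors:

```python
import torch

def tree_attention_mask(parents):
    """Build a boolean attention mask for a tree of candidate tokens.

    parents[i] is the index of node i's parent (-1 for the root). Each
    candidate may only attend to itself and its ancestors, so scoring all
    tree nodes in parallel behaves as if each root-to-node path had been
    generated sequentially.
    """
    n = len(parents)
    mask = torch.zeros(n, n, dtype=torch.bool)
    for i in range(n):
        j = i
        while j != -1:          # walk up to the root, enabling ancestors
            mask[i, j] = True
            j = parents[j]
    return mask

# Example: parents = [-1, 0, 0, 1] -> node 3 may attend to nodes {3, 1, 0}
```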
Self Distillation
- A dataset of prompts relevant to the desired model is created
- The full large language model generates outputs for these prompts in the usual autoregressive manner; these prompt-output pairs form the training dataset for the self-distillation step
- Medusa Heads are trained on the generated training dataset
Tree Construction
- Prune Less Promising Branches: Branches with a low probability of containing the next token are pruned from the tree of candidate tokens in tree attention, which reduces the computational cost of Medusa 2
- Select the Best Candidate: From the remaining typical candidates, the longest accepted prefix is chosen for the next decoding step
Empirical Evaluation
Experiments on various LLMs show consistent 2–3 times speedups in practice without harming output quality (assessed by GPT-4 and other metrics). The authors also include ablation studies on key design choices (number of heads, attention structure, sampling thresholds), confirming the effectiveness and generality of the proposed framework.
Constructive Critique
While Medusa demonstrates promising improvements in decoding speed and flexibility, a few limitations remain:
- Training Stability: Especially in Medusa-2, jointly fine-tuning heads and the base model requires a two-stage schedule with learning rate separation and warmup—indicating some instability.
- Model Complexity: Adding multiple Medusa heads and tree attention introduces architectural complexity, which may hinder adoption or reproducibility without careful engineering.
- No open-source code: As of the paper's release, there is no official implementation, which limits replication and community engagement.
Group 8 Presentation: MEDUSA: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads
Presented by:
- Nana Ye
- Xingjian Zhou
Paper Citation
T. Cai et al., “Medusa: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads,” 2024, arXiv. doi: 10.48550/ARXIV.2401.10774.
https://arxiv.org/abs/2401.10774
Summaries of key points
One of the main challenges with LLMs is their slow inference process. The bottleneck is not that the GPU cannot do the math fast enough, but memory bandwidth (the need to constantly move data between memory and the GPU). LLMs also generate text one token at a time using autoregressive decoding, where each token depends on the previous one, which leads to underutilization of the GPU.
One attempt to address this is Speculative Decoding, where a smaller "draft model" predicts multiple tokens at once, and the main model verifies them. While this speeds up generation, it requires maintaining a separate draft model, and it can lead to mismatches between the draft and main models, making integration into a system difficult.
So, Medusa is proposed to solve these issues.
How does Medusa work?
Core Idea of Medusa
Unlike speculative decoding, Medusa doesn’t need a separate draft model. Instead, it adds extra decoding heads to the same model, allowing it to predict multiple tokens at once (generating candidates). This reduces the dependency on previous tokens, making inference faster. Once multiple tokens are predicted, Medusa organizes these predictions into a tree structure, where each node represents a token. This structure helps Medusa evaluate several possible token sequences at once (processing candidates). This is called tree-based attention. After evaluating, Medusa picks the most likely token sequence and outputs the best one (accepting candidates).
Training Strategies of Medusa
Medusa has two training strategies to optimize its performance:
- Medusa 1: In this method, the original model (called the backbone) is frozen, meaning its parameters don’t change. Only the Medusa heads are trained. This saves computation and avoids the model forgetting what it learned originally, while improving inference speed by predicting multiple tokens at once.
- Medusa 2: In this approach, both the backbone and the Medusa heads are trained together. This boosts prediction accuracy, especially in larger models. The training starts with a two-stage process to prevent issues like high loss and large gradients from the new Medusa heads that could disrupt the backbone model. In Stage 1, only the Medusa heads are trained to specialize without affecting the main model. In Stage 2, both the Medusa heads and the backbone are trained together, with a warm-up strategy to gradually increase the learning rate for smoother adaptation. The tree attention mechanism also helps organize token continuations in a tree, allowing the model to evaluate multiple possibilities at once, speeding up inference.
Further enhancements of Medusa
The author further enhances Medusa's practical utility with three significant extensions:
1. Typical Acceptance Scheme: Instead of rejecting candidates based on strict probability thresholds, Medusa evaluates them based on how typical they are compared to the original model’s distribution. This speeds up the process without sacrificing quality.
2. Self-Distillation: Medusa can generate training data from its own output. It creates a seed dataset and then uses that data to train the Medusa heads, which helps improve the model’s efficiency in making predictions.
3. Optimized Tree Structure: Medusa improves how the model evaluates candidates by focusing on the most promising tokens, making the inference process faster and more efficient.
Benefits and Limitations
Medusa has shown great results in experiments with models like Vicuna-7B and Vicuna-13B, achieving up to 2.8 times faster performance than traditional methods, with even higher speedups for tasks like text extraction (3.62x) and coding (3.29x). It consistently outperformed speculative decoding, reaching 2.83x acceleration with Vicuna-7B, compared to 1.47x with speculative decoding. Despite the speed improvements, Medusa maintains high text quality, making it efficient without compromising accuracy. The tree-based attention mechanism further boosted speed by 1.9x. However, Medusa has some limitations, such as higher memory usage due to the additional heads and the tree attention mechanism, and it can be challenging to scale to larger models and more complex tasks.
Group 9 Presentation: Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality
Presented by:
- Kaiyue Ma
- Wenzhe Wang
Paper Citation
T. Dao and A. Gu, “Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality,” 2024, arXiv. doi: 10.48550/ARXIV.2405.21060.
Background
- Transformers are effective, but computationally expensive and suffer from quadratic complexity
- Structured state space models (SSMs) are an alternative that scales linearly instead and works for long range modeling
- SSMs have not received the same mainstream improvements as transformers, and lack support for parallelization and hardware acceleration
- Structured state space duality (SSD) bridges the gaps between transformers and SSMs
Additional Background
SSMs (State Space Models) are traditionally used in control theory to model a dynamic system via state variables. The paper https://compneuro.uwaterloo.ca/files/publications/voelker.2018.pdf then found that SSMs are also well suited to describing time cells in the brain. A useful diagram of an SSM can be found here https://cdn-uploads.huggingface.co/production/uploads/613b0a62a14099d5afed7830/G7icfkYoxIqHZcJGHM7UD.png, where x denotes the n state variables, u the m inputs, and y the p outputs.
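For intuition, a minimal numpy simulation of such a discrete-time state space model (the matrices below are random placeholders, not taken from any paper):

```python
import numpy as np

# Discrete-time linear state space model:
#   x_{k+1} = A x_k + B u_k        (state update)
#   y_k     = C x_k + D u_k        (readout)
# with n state variables, m inputs and p outputs, matching the diagram above.
n, m, p, T = 4, 1, 1, 50
rng = np.random.default_rng(0)
A = 0.9 * np.eye(n) + 0.01 * rng.standard_normal((n, n))
B = rng.standard_normal((n, m))
C = rng.standard_normal((p, n))
D = np.zeros((p, m))

x = np.zeros(n)
u = rng.standard_normal((T, m))
ys = []
for k in range(T):
    ys.append(C @ x + D @ u[k])
    x = A @ x + B @ u[k]
ys = np.array(ys)   # shape (T, p): the output sequence produced by the SSM
```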
Technical Contributions
- Represents SSMs as semiseparable matrices and exploits this structure for efficient matrix operations.
- Uses generalized linear attention mechanism with structured masks.
- Refines the original Mamba model to yield Mamba-2, which incorporates the new structured-state-space-duality algorithms. Mamba-2 is easily parallelizable and scales better to large state sizes. Empirical results demonstrate strong performance on language modelling benchmarks, surpassing older SSM models and matching or outperforming Transformers at various model scales.
Summaries of Key Points
The paper explores the theoretical connections between Transformer architectures and Structured State Space Models (SSMs). The authors introduce the State Space Duality (SSD) framework, which bridges these two model families through the concept of structured semiseparable matrices. This framework reveals that certain attention mechanisms in Transformers can be interpreted as SSMs, providing a unified perspective on sequence modelling techniques.
Leveraging the SSD framework, the authors propose a new architecture called Mamba-2. This model refines the selective SSM approach used in the original Mamba model, resulting in a design that is 2-8 times faster while maintaining competitive performance in language modelling tasks. Mamba-2 achieves this efficiency by simplifying the SSM layer, enabling better scalability and computational speed.
The paper also introduces efficient algorithms based on block decompositions of semiseparable matrices, which enhance the computational efficiency of SSMs. These algorithms allow for larger recurrent state sizes and improve the practicality of SSMs in handling long-range dependencies within sequences.
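To make the duality concrete, here is a toy numpy sketch (my own, simplified to scalar per-step decays and no chunking) showing that the same sequence transformation can be computed either as a quadratic masked-attention product with a 1-semiseparable mask or as a linear-time recurrence:

```python
import numpy as np

rng = np.random.default_rng(1)
T, dk, dv = 6, 3, 2
Q = rng.standard_normal((T, dk))
K = rng.standard_normal((T, dk))
V = rng.standard_normal((T, dv))
a = rng.uniform(0.5, 1.0, size=T)        # scalar per-step decay (the SSM's A_t)

# Quadratic "attention" form: Y = (L * (Q K^T)) V with a 1-semiseparable mask
# L[t, s] = a_{s+1} * ... * a_t for s <= t, and 0 otherwise.
L = np.zeros((T, T))
for t in range(T):
    for s in range(t + 1):
        L[t, s] = np.prod(a[s + 1:t + 1])
Y_quadratic = (L * (Q @ K.T)) @ V

# Linear "recurrent" (SSM) form: carry a (dk x dv) state and update it in O(T).
H = np.zeros((dk, dv))
Y_recurrent = np.zeros((T, dv))
for t in range(T):
    H = a[t] * H + np.outer(K[t], V[t])
    Y_recurrent[t] = Q[t] @ H

assert np.allclose(Y_quadratic, Y_recurrent)   # both forms give the same output
```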
Empirical evaluations demonstrate that Mamba-2 outperforms both its predecessor and Transformer models in terms of training efficiency and performance on language modelling benchmarks. The architecture also shows superior capabilities in associative recall tasks, highlighting its effectiveness in capturing and utilizing long-range dependencies.
In summary, the paper provides a theoretical foundation connecting Transformers and SSMs, introduces the Mamba-2 architecture as a practical application of this theory, and presents algorithms that enhance the efficiency and scalability of sequence modelling techniques.
Constructive Critique
While the paper introduces a powerful theoretical framework through State Space Duality (SSD) and demonstrates strong empirical performance with Mamba-2, several areas could be further clarified or improved:
- Theoretical accessibility: The concept of semiseparable matrices and duality between attention and SSMs is mathematically rich, but not easily accessible to a broader audience. Including more visual or intuitive explanations would improve its pedagogical impact.
- Benchmark diversity: Most experiments focus on language modeling tasks. It remains unclear how Mamba-2 performs on other domains such as vision, speech, or reinforcement learning. Since SSD is a general framework, cross-domain validation would help showcase its broader applicability.
- Scalability limitations: While Mamba-2 is more efficient than its predecessor, the paper doesn’t fully discuss how performance scales with increasing model depth or state size, especially under training constraints on real-world hardware.
- Lack of interpretability analysis: The paper does not explore how the SSD framework or Mamba-2 influences model interpretability (e.g., how information is propagated or stored over long sequences), which could be important for downstream applications.
Despite these limitations, the paper makes a substantial theoretical and practical contribution by unifying two dominant modeling paradigms and offering a concrete architecture (Mamba-2) that is both efficient and performant.
Group 9 Presentation: Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality
Presented by:
- Kaiyue Ma
- Wenzhe Wang
Paper Citation
T. Dao and A. Gu, “Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality,” 2024, arXiv. doi: 10.48550/ARXIV.2405.21060.
Background & Motivation
The paper aims to unify state-space models (SSMs) and attention mechanisms through structured state-space duality (SSD), demonstrating that SSMs can be interpreted as a form of masked attention with semiseparable matrices. This approach enables the utilization of attention's hardware efficiency (e.g., optimized matrix multiplications) while preserving the linear scaling property of SSMs. Although Mamba's selective SSM is powerful, it is slower than optimized attention due to its reliance on sequential scans rather than direct matrix operations. The authors propose methods to accelerate SSMs by 2–8× without compromising performance or even enhancing it. By reformulating SSMs as matrix transformations, the paper offers novel theoretical insights, such as their equivalence to semiseparable matrices, along with practical algorithms like block decomposition for efficient computation. These contributions pave the way for hybrid architectures (e.g., Mamba-2 augmented with attention layers) and improved system-level support (e.g., tensor and sequence parallelism).
Key Points
The paper titled "Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality" investigates the theoretical and practical connections between Transformers and State Space Models (SSMs), with a particular emphasis on structured state-space duality (SSD). The key contributions of the paper include:
1. Duality Framework: The authors introduce Structured State-Space Duality (SSD), a framework that establishes a connection between State Space Models (SSMs) and attention mechanisms via structured matrices, particularly semiseparable matrices. This duality enables SSMs to be interpreted as matrix transformations, thereby uncovering novel algorithmic and architectural insights.
2. Efficiency Improvements: The paper introduces the Mamba-2 architecture, which enhances the selective SSM of Mamba to achieve 2–8× faster computation while preserving competitive performance in language modeling tasks. This improvement is realized through the utilization of hardware-efficient matrix multiplications and block decompositions of semiseparable matrices.
3. Structured Masked Attention (SMA): A generalization of linear attention is introduced, in which the attention mask is replaced by a structured matrix (e.g., semiseparable matrices). This substitution enables subquadratic computational complexity and facilitates efficient autoregressive inference.
4. Hybrid Models: The paper demonstrates that combining SSMs with attention layers (e.g., 10% attention layers in Mamba-2) can improve performance, suggesting complementary strengths between the two paradigms.
Contributions
- Theoretical Connections: It establishes a rigorous equivalence between SSMs and semiseparable matrices, unifying recurrent, convolutional, and attention-based sequence models under a single framework.
- Algorithmic Innovations: The SSD algorithm optimizes SSM computation by blending linear (recurrent) and quadratic (attention-like) forms, achieving linear complexity while leveraging modern hardware.
- Mamba-2 Architecture: This new architecture improves upon Mamba by simplifying projections, enabling tensor parallelism, and incorporating larger state dimensions, resulting in better scalability and efficiency.
- Empirical Validation: The authors validate Mamba-2 on synthetic tasks (e.g., associative recall) and language modeling, showing it outperforms Mamba and matches or exceeds Transformer++ in perplexity and downstream tasks.
Constructive Critiques
- Expressivity Trade-offs: The adoption of scalar-identity structure for A matrices in Structured State-Space Duality (SSD) may constrain model expressivity relative to general diagonal State Space Models (SSMs). The paper could provide a more in-depth analysis of the trade-offs between hardware efficiency and model flexibility.
- Attention Approximation: The negative results observed for kernel approximations (e.g., Performer, cosFormer) in Mamba-2 indicate that the advantages of Structured State-Space Duality (SSD) may not fully translate from linear attention mechanisms. A more in-depth investigation into the reasons for the underperformance of these methods could further enhance the study.
- Broader Applicability: The focus is heavily on language modeling. Evaluating SSD on other domains (e.g., vision, reinforcement learning) could demonstrate its generalizability.
- Implementation Complexity: Although the SSD algorithm is simpler than Mamba's selective scan, its block decomposition may still present implementation challenges for practical adoption. Conducting additional ablation studies on parameters such as chunk size and parallelism levels could provide valuable guidance for practitioners.
Relationships to Other Works
The paper extends research in efficient sequence modeling, connecting various approaches through a unified framework. It builds on recent progress in State Space Models (SSMs), particularly from S4 to Mamba. S4 and S4D introduced diagonal-plus-low-rank matrices for long-range modeling, while Mamba's selective SSMs improved performance on dense data like language. The SSD framework generalizes these models using semiseparable matrices and introduces hardware-aware optimizations, making Mamba-2 significantly faster.
Connections to linear attention methods form another key thread. The paper generalizes Katharopoulos et al.'s linear attention with structured masked attention via semiseparable matrices. This links SSD to models like RetNet (fixed exponential decay) and GateLoop (input-dependent gating). GLA's chunkwise computation resembles SSD's block decomposition but lacks SSD's theoretical unification.
The work also intersects with efforts in efficient recurrent architectures. RWKV's attention-like gating shares similarities with SSD's matrix-based approach, though SSD offers a more rigorous mathematical foundation. Griffin's combination of SSMs with local attention and xLSTM's expanded state dimensions align with SSD's themes, suggesting SSD provides a unifying perspective.
On the systems side, the paper complements hardware-efficient Transformer implementations. While FlashAttention optimizes attention kernels, SSD advances SSM-based models. Monarch Mixer's structured matrices share some ideas with SSD but apply them differently. These connections highlight SSD's contribution to efficient deep learning architectures.
Theoretically, SSD bridges modern sequence modeling with classical numerical linear algebra. Semiseparable matrices connect to structured matrix computations, offering new insights into model representation. This grounding may inspire future algorithmic improvements.
Group 9 Presentation: Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality
Presented by:
- Kaiyue Ma
- Wenzhe Wang
Paper Citation
T. Dao and A. Gu, “Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality,” 2024, arXiv. doi: 10.48550/ARXIV.2405.21060.
Introduction
The paper "Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality", by Sri Dao and Albert Gu establishes a theoretical framework connecting Transformers and State Space Models - SSMs.
This framework, termed Structured State Space Duality - SSD - brings these two prominent sequence architectures leading to the development of Mamba-2. This model enhances efficiency while painting competitive performance in LLMs.
Key Contributions
(1) Establishing Structured State Space Duality
The authors demonstrate that structured SSMs and attention mechanisms are closely related through structured matrices, specifically semiseparable matrices. This insight reveals that various sequence models can be interpreted as different parametrizations of these matrices, providing a unified understanding.
Two perspectives are introduced to implement this duality:
Matrix representation: viewing sequence models as matrix transformations, highlighting how SSMs can be represented using semiseparable matrices, which admit a sub-quadratic parameterization.
Tensor Contraction Representation: This illustrates how the computations in Attention mechanisms can be reformulated in terms of tensor contractions; aligning them with SSM operations
By framing SSMs as well as attention mechanisms within the SSD framework, the paper enables the transfer of algorithmic and system optimizations between these models, fostering advancements in efficiency and stability.
(2) Development of Mamba-2 Architecture
Leveraging the SSD framework, Mamba-2 is an improved version of the Mamba-1 architecture. Mamba-2 refines the selective SSM layer, resulting in significantly faster computation.
Mamba-2 therefore achieves 2-8 times faster performance compared to its predecessor while maintaining competitiveness with transformers in language modelling tasks. This demonstrates the practical benefits of applying the SSD design.
(3) Efficient Algorithms Through SSD
The paper presents efficient algorithms derived from the SSD framework that optimize the computation. These algorithms reduce the complexity often associated with traditional sequential models.
SMA, a novel attention variant, is introduced which exploits the structured properties of SSD. This leads to more efficient attention computation.
Applications and Impact
The SSD framework offers a new paradigm for designing sequence models, allowing practitioners to harness the strength of both SSMs and transformers. This leads to models that are both computationally efficient and effective in capturing long-range dependencies.
By reducing computational complexity, the insights from this paper facilitate the development of models that can handle longer sequences and larger datasets; addressing a common limitation in sequence modelling, thereby allowing scalability.
Finally, the theoretical connections established within the paper enable the application of optimization techniques across different model architectures, laying the foundation for more unified and efficient approaches to sequence modelling.
This work not only enhances theoretical understanding but also leads to practical advancements, exemplified by the development of the Mamba-2 Architecture.
Group 9 Presentation: Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality
Presented by:
- Kaiyue Ma
- Wenzhe Wang
Paper Citation
T. Dao and A. Gu, “Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality,” 2024, arXiv. doi: 10.48550/ARXIV.2405.21060.
Summaries of key points
Goal: SSD connects SSMs and Transformers, combining the best of both.
Background: Transformers can handle long-distance dependencies but are computationally expensive. SSMs have linear complexity, but are difficult to fully parallelize and accelerate.
Methodology: The SSD framework uses a semiseparable matrix structure to relate SSMs to the attention mechanism. The SSD algorithm splits the matrix into diagonal blocks and off-diagonal blocks, and adds multi-head sequence transformations, parallel projections and kernel methods.
Result: Mamba-2 matches Transformer performance while training 2-8 times faster.
Constructive critiques or reviews
The presentation has clear illustrations, so the audience can understand the ideas more intuitively.
It is well structured; the slides could include a little more detail in places.
Clear explanations to aid understanding
Semi-separable matrix: Compress a large matrix into structured blocks that are easier to compute.
Block decomposition in SSD: diagonal blocks are computed with attention-like (quadratic) matrix operations, while off-diagonal blocks are handled with fast recurrence.
Group 10 Presentation: Accelerating Large Language Model Decoding with Speculative Sampling
Paper Citation
C. Chen, S. Borgeaud, G. Irving, J.-B. Lespiau, L. Sifre, and J. Jumper, ‘Accelerating Large Language Model Decoding with Speculative Sampling’, Feb. 02, 2023, arXiv: arXiv:2302.01318. doi: 10.48550/arXiv.2302.01318.
https://arxiv.org/abs/2302.01318
Background
- Traditional autoregressive decoding for large language model is computationally expensive, as the entire model has to run for each additional token which is generated
- Transformer neural networks are typically memory-bandwidth limited; quantization and distillation into smaller models are solutions that have been used in the past to improve the performance of LLMs
Technical Contributions
- Speculative sampling was developed to increase the speed of LLM decoding without meaningfully reducing its performance on predicting future tokens compared to the base model
- Generates a draft of k future tokens using a smaller model
- Score the proposed tokens using the base model
- A modified rejection sampling scheme was developed by the authors in the paper
- The acceptance of a draft token is based on the minimum of 1 and the ratio of the target model's probability to the draft model's probability for that token.
Explanations to Aid Understanding
- Transformer Decoding: This is the process by which a large language model generates text. Given an input prompt, the model sequentially selects the most probable next token, then uses that token to predict the subsequent one, and so on, and this process is computationally intensive.
- Speculative Sampling: Unlike traditional decoding where one token is generated at a time by the target model, speculative sampling aims to generate multiple tokens in parallel by using a faster, potentially smaller "draft model" to propose a short sequence (the draft). The target model evaluates these drafted tokens, and a rejection sampling mechanism decides which ones to accept, ensuring the output remains consistent with the target model's capabilities.
- Parallel Scoring: Instead of computing the logits for each drafted token sequentially, the method computes the logits for all (K+1) tokens in the draft at the same time. The presentation notes that the computing time for this parallel process is similar to sampling a single token with the target model, which is a key factor in the potential speedup. The key insight is that the model inference pass is dominated by memory bandwidth and inter-device communication rather than purely by the token count. By handling several drafted tokens per pass, overall decoding time is greatly reduced.
Summaries of Key Points
- Decoding Challenges in LLMs: Traditional autoregressive sampling methods generate one token at a time, leading to inefficiencies, especially as model sizes grow.
- Speculative Sampling Overview: SpS utilizes a draft model, which is a smaller and faster version of the target LLM, to propose a sequence of tokens. The target model then evaluates this proposed sequence, accepting or rejecting tokens based on a modified rejection sampling scheme that ensures the final output aligns with the target model's distribution.
- Algorithm Efficiency: The draft model generates a short sequence (e.g., K tokens) in parallel. The target model scores this sequence, and tokens are accepted or resampled as needed, allowing for the potential acceptance of up to K+1 tokens per iteration. This parallel approach contrasts with traditional methods that generate tokens sequentially, thereby reducing decoding latency.
- Empirical Results: Implementing SpS with Chinchilla, a 70-billion-parameter language model, resulted in a 2 to 2.5 times speedup in decoding across benchmark tasks such as XSum and HumanEval. These speed improvements were achieved without degrading sample quality or requiring changes to the target model's parameters or architecture.
- Advantages of Speculative Sampling: Maintains the integrity of the target model's output distribution. Requires no alterations to existing model architectures, facilitating easy integration into current systems. Demonstrates versatility across various tasks and decoding methods, making it broadly applicable in LLM deployments.
Group 10 Presentation: Accelerating Large Language Model Decoding with Speculative Sampling
Presented by:
Danyang Zhao
Speculative Sampling: An In-Depth Explanation
Speculative Sampling is a technique introduced to speed up text generation for large language models (LLMs) without significantly affecting output quality. It works by leveraging a smaller, faster “draft” model to propose multiple candidate tokens, which are then verified and accepted or rejected by the larger model. This method reduces latency during decoding while maintaining output quality close to the original model.
But what specifically is latency? Latency refers to the time delay between a user requesting text generation and the model producing the output. More formally, it is the time required to generate each token in an autoregressive model.
But why do we need speculative sampling? Because there are problems with standard autoregressive sampling. Mainly:
- LLMs generate text token-by-token. Each step depends on the previous tokens, making sequential decoding slow.
- The computational bottleneck comes from generating the next token, which requires a full forward pass of the model.
Therefore the goal of speculative sampling is to reduce the number of forward passes of the large language model per generated token.
How Speculative Sampling Works:
1. There are two models, a Draft model and a Large Model. The smaller draft model, which is cheap and fast to compute, generates a set of k speculative candidate tokens. Then, the large model, which is expensive yet accurate, verifies these tokens and either accepts or rejects them.
2. The draft model proposes multiple tokens. That is, instead of sampling just one token at each step, the draft model generates a sequence of k candidate tokens [math]\displaystyle{ x_{t+1}, x_{t+2}, ..., x_{t+k} }[/math] using standard autoregressive decoding.
3. Now, the large language model verifies the proposal. This is done by calculating the probability of each proposed token:
[math]\displaystyle{ p(x_{t+1} | x_{\le t}), p(x_{t+2} | x_{\le t+1}), ..., p(x_{t+k} | x_{\le t+k-1}) }[/math]
For each candidate there are only two possibilities: 1) the token is accepted and added to the output, or 2) the token is rejected and the large model takes over and directly generates a new token.
4. Since the draft model is smaller, it generates multiple speculative tokens quickly, and the target model only needs a few forward passes to verify them. This allows for faster decoding through efficient verification.
Now, let us go over a bit more of the mathematics.
For each token [math]\displaystyle{ x_i }[/math], we check:
[math]\displaystyle{ p(x_i | x_{\lt i}) \ge q(x_i | x_{\lt i}) }[/math]
Where p is the probability distribution of the large model and q that of the draft model. If this holds, the token is accepted outright; otherwise it is accepted only with probability [math]\displaystyle{ p/q }[/math], and a replacement token is resampled if it is rejected.
This is paired with a Metropolis-style acceptance probability, which ensures the final sampled tokens remain statistically valid while speeding up computation.
The acceptance probability can then be calculated as follows:
[math]\displaystyle{ \alpha_i = \min \Big( 1, \frac{p(x_i | x_{\lt i})} {q(x_i | x_{\lt i})} \Big) }[/math]
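Putting the pieces together, a minimal sketch of one speculative sampling round (illustrative Python; the array layout and function names are my own, but the accept/resample rule follows the formulas above):

```python
import numpy as np

def speculative_step(p, q, draft_tokens, rng=None):
    """One round of speculative sampling (illustrative sketch).

    q[i], p[i]: draft- and target-model distributions over the vocabulary at
    the position of draft_tokens[i]; p has one extra row used for the bonus
    token when every draft token is accepted.
    """
    rng = rng or np.random.default_rng()
    out = []
    for i, x in enumerate(draft_tokens):
        if rng.random() < min(1.0, p[i][x] / q[i][x]):
            out.append(x)                            # accept the draft token
        else:
            residual = np.maximum(p[i] - q[i], 0.0)  # resample from the adjusted distribution
            out.append(rng.choice(len(residual), p=residual / residual.sum()))
            return out                               # stop at the first rejection
    out.append(rng.choice(len(p[-1]), p=p[-1]))      # all accepted: one bonus token from the target
    return out
```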
Group 10 Presentation: Accelerating Large Language Model Decoding with Speculative Sampling
Presented by:
Danyang Zhao
Summaries of key points
Goal: A speculative sampling algorithm is proposed to accelerate the decoding process of large language models.
Background: Traditional Transformer is slow and costly, and existing methods cannot effectively improve the generation speed.
Methodology: A small draft model is used to generate a token sequence of length k, and the logits of k+1 tokens are computed in parallel with the target large model. The modified rejection sampling method is used to decide whether to accept the draft token.
Result: On Chinchilla, the output quality is almost unaffected, and the generation speed is significantly improved.
Constructive critiques or reviews
The presentation could be more detailed and provide more examples to aid understanding.
Increase fluency and reduce pause time.
Clear explanations to aid understanding
Compare the probabilities assigned by the target model and the draft model to decide whether to accept a token.
The output does not deviate from the target model's distribution.
Compared with distillation, speculative sampling leaves the target model unchanged, so the acceleration is direct.
Group 11 Presentation: Simple Linear Attention Language Models Balance the Recall-Throughput Tradeoff
Presented by:
Yiyuan Yang, Anar Kuatzhan, Chuan Zhang
Paper Citation
Arora, S., Eyuboglu, S., Zhang, M., Timalsina, A., Alberti, S., Zinsley, D., ... & Ré, C. (2024). Simple linear attention language models balance the recall-throughput tradeoff. arXiv preprint arXiv:2402.18668. https://arxiv.org/pdf/2402.18668
Background
Nowadays large language models still struggle with efficiency. In attention-based models, attention requires a huge number of calculations as the input gets longer. Attention also stores every previous word in memory, which makes it memory-intensive. New language models developed in recent years can generate text faster while maintaining low perplexity. Low perplexity doesn't necessarily mean good recall. Gated convolutional models also struggle with recall. Attention-based models excels in recall tasks. To address these problems, the authors introduced the Based model.
Technical Contributions
The authors introduce a new architecture called BASED (Balanced Attention through Sliding + Exponential Decay), which is designed to address the recall–throughput tradeoff in language models. It combines two key components:
- Linear Attention (Global Context):
- Uses a second-order Taylor approximation of softmax attention.
- Enables constant-size memory state during autoregressive decoding.
- Sliding Window Attention (Local Context):
- Performs exact attention within a small local window (e.g., 64–128 tokens).
- Captures short-range dependencies with high precision.
Memory-Recall Tradeoff: Observed both within and across architecture classes.
Performance with Fixed Recurrent State: Not all architectures have the same recall capacity. Mamba optimally utilizes limited memory budgets while convolutional architectures underperform with memory constraints.
The Based model combines local fine-grained attention + long-range linear attention via Taylor approximation of softmax exponential function that are sub-quadratic complexity during training and permit an efficient recurrent inference view. Based outperforms prior sub-quadratic architectures in recall quality by up to 6.2 accuracy points.
Architecture
Softmax-approximating linear attention (applied globally) + exact softmax attention with sliding windows (applied locally)
This combination achieves 90.8% of full softmax attention's recall accuracy while reducing latency by a factor of 100,000.
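For intuition, a rough numpy sketch of the second-order Taylor feature map behind the linear-attention half (my own simplification: single head, no causal masking or chunking). Expanding [math]\displaystyle{ e^{q \cdot k} \approx 1 + q \cdot k + (q \cdot k)^2/2 }[/math] turns the softmax kernel into an inner product of finite feature maps, so key-value statistics can be summarized in a fixed-size state:

```python
import numpy as np

def taylor_features(x):
    """phi(x) such that phi(q) . phi(k) = 1 + q.k + (q.k)^2 / 2."""
    second = np.einsum('td,te->tde', x, x).reshape(x.shape[0], -1) / np.sqrt(2)
    return np.concatenate([np.ones((x.shape[0], 1)), x, second], axis=-1)

rng = np.random.default_rng(0)
T, d = 8, 4
Q, K, V = rng.standard_normal((3, T, d))

phi_q, phi_k = taylor_features(Q), taylor_features(K)
num = phi_q @ (phi_k.T @ V)          # (T, d): numerator, linear in sequence length
den = phi_q @ phi_k.sum(axis=0)      # (T,):  normalizer
approx = num / den[:, None]

# Compare against exact (quadratic-cost) softmax-style attention:
scores = np.exp(Q @ K.T)
exact = (scores @ V) / scores.sum(axis=1, keepdims=True)
```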
Accomplishments
- Improved Recall:
- BASED outperforms Mamba by up to 10.36 accuracy points on recall-intensive tasks.
- Recovers over 90% of softmax attention’s recall performance while using significantly less memory.
- High Throughput:
- Achieves up to 24× higher generation throughput compared to FlashAttention-2.
- Competitive wall-clock time due to efficient CUDA kernel design.
- Strong Language Modeling:
- Matches or surpasses models like Mamba and Hyena in perplexity and downstream task accuracy.
- Theoretical Contributions:
- Demonstrates that recurrent models require [math]\displaystyle{ \Omega(N) }[/math]-bit memory to perform associative recall.
- Proves that BaseConv, a gated convolution model, cannot solve recall tasks in constant layers.
- Shows that the recall-throughput tradeoff is theoretically fundamental.
Group 11 Presentation: Simple Linear Attention Language Models Balance the Recall-Throughput Tradeoff
Presented by:
Yiyuan Yang, Anar Kuatzhan, Chuan Zhang
Paper Citation
Arora, S., Eyuboglu, S., Zhang, M., Timalsina, A., Alberti, S., Zinsley, D., ... & Ré, C. (2024). Simple linear attention language models balance the recall-throughput tradeoff. arXiv preprint arXiv:2402.18668. https://arxiv.org/pdf/2402.18668
Introduction
"Simple Linear Attention Language Models Balance the Recall-Throughput Tradeoff" is a peer which introduces BASED, an architecture designed to enhance the efficiency of language models by balancing memory consumption and recall abilities. This approach combines Linear Attention with Sliding Window Attention to navigate the tradeoff between state size and recall.
Methodology
The researchers analyzed various architectures to understand the tradeoff between a model's state size and its recall ability. They observed that efficient alternatives to attention, such as H3, Mamba, and RWKV, maintain a fixed-size recurrent state but exhibit limitations in recall performance. To address this, they proposed the BASED architecture, which combines linear attention with sliding window attention. By adjusting the window size and the feature dimension of the linear attention, BASED can navigate the Pareto frontier of the recall-memory tradeoff, effectively balancing recall quality and state size.
Empirical results
The study trained language models with up to 1.3 billion parameters and found that BASED matches the perplexity of leading sub-quadratic models like Mamba. Furthermore, BASED outperformed these models on real-world recall-intensive tasks by 6.22 accuracy points. Additionally, the implementation of input/output-aware algorithms enabled BASED to achieve 24 times higher throughput in language generation compared to FlashAttention-2 when generating 1,024 tokens using 1.3 billion parameter models.
Mathematical Explanation of BASED Architecture
(1) Sliding Window Attention -- SWA
SWA computes attention over a fixed-size window of previous tokens, capturing local dependencies. For a window size of [math]\displaystyle{ w }[/math], the attention for a token at position [math]\displaystyle{ t }[/math] considers only the tokens in [math]\displaystyle{ [t-w, t-1] }[/math].
Given queries, keys, and values [math]\displaystyle{ Q \in \mathbb{R}^{n \times d}, \ K \in \mathbb{R}^{n \times d}, \ V \in \mathbb{R}^{n \times d} }[/math] for a sequence of length [math]\displaystyle{ n }[/math] and hidden dimension [math]\displaystyle{ d }[/math], the attention output [math]\displaystyle{ A }[/math] is computed as follows:
[math]\displaystyle{ A_t = \text{softmax} \Bigg( \frac{Q_t K_{t-w : t-1}^{T}}{\sqrt{d}} \Bigg) V_{t-w : t-1} }[/math]
Where [math]\displaystyle{ A_t }[/math] is the attention output at position [math]\displaystyle{ t }[/math].
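A direct (unoptimized) Python sketch of this windowed attention, assuming for simplicity that the window also includes the current token; names are my own:

```python
import numpy as np

def sliding_window_attention(Q, K, V, w):
    """Exact softmax attention restricted to a local window of w previous
    tokens (plus the current one); causal, single head."""
    T, d = Q.shape
    out = np.zeros_like(V)
    for t in range(T):
        lo = max(0, t - w)
        scores = Q[t] @ K[lo:t + 1].T / np.sqrt(d)   # local window only
        weights = np.exp(scores - scores.max())
        weights /= weights.sum()
        out[t] = weights @ V[lo:t + 1]
    return out
```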
(2) Linear Attention
Linear attention approximates standard attention mechanism to capture global dependencies with reduced computational complexity. It redefines the attention operation to be linear in the sequence length using feature maps [math]\displaystyle{ \phi }[/math] to project queries and keys.
The linear attention output is computed as:
[math]\displaystyle{ A = \phi(Q) \big( \phi(K)^T V \big) }[/math]
Where [math]\displaystyle{ \phi }[/math] is a feature map function applied to the queries and keys. This formulation allows the attention computation to be rearranged and optimized, reducing the complexity to [math]\displaystyle{ O(n) }[/math].
(3) Combining Sliding Window Attention and Linear Attention
BASED integrates SWA and linear attention to leverage the strength of both methods. SWA captures fine-grained local dependencies while Linear Attention models long-range dependencies.
By adjusting the sliding window size [math]\displaystyle{ w }[/math] and the feature dimension [math]\displaystyle{ d }[/math] in linear attention, BASED can navigate the trade-off between memory consumption and the ability to recall information. A larger [math]\displaystyle{ w }[/math] enhances local context capture but increases memory usage, whereas a higher [math]\displaystyle{ d }[/math] improves global context understanding with minimal memory overhead.
Application and Performance (brief overview)
In this paper, BASED models were trained on up to 1.3 billion parameters and evaluated on tasks requiring high recall. This included tasks in information extraction and reading comprehension. The architecture demonstrated performance matching or in some instances surpassing other sub-quadratic models such as MAMBA. Notably, it excelled in recall-intensive situations.
Implementations of linear attention often lag behind optimized standard attention in efficiency. To address this, the authors developed I/O aware algorithms, enabling BASED to achieve 24x higher throughput in language generation compared to other methods such as Flash-Attention-2.
Conclusion
The BASED architecture offers a pragmatic solution to the recall and throughput tradeoff in language models by combining sliding window and linear attention mechanisms.
This integration has allowed for efficient handling of both local and global dependencies. Subsequently this has resulted in models that are both memory efficient and able to perform high recall tasks, thereby advancing the development of more efficient LLM techniques.
BASED is the first linear attention model shown to match or beat Mamba on:
- Perplexity (language modeling quality)
- Real-world recall benchmarks (copying, in-context learning)
This challenges the growing belief that attention-free models like Mamba are the most scalable path forward.
Group 11 Presentation: Simple Linear Attention Language Models Balance the Recall-Throughput Tradeoff
Presented by:
Yiyuan Yang, Anar Kuatzhan, Chuan Zhang
Paper Citation
Arora, S., Eyuboglu, S., Zhang, M., Timalsina, A., Alberti, S., Zinsley, D., ... & Ré, C. (2024). Simple linear attention language models balance the recall-throughput tradeoff. arXiv preprint arXiv:2402.18668. https://arxiv.org/pdf/2402.18668
Background & Motivation
Transformer-based language models rely on attention mechanisms that require storing increasing amounts of key-value pairs (KV-cache) during inference. This makes them memory-intensive and less suitable for real-time or resource-constrained applications. The paper investigates whether it's possible to reduce memory usage while maintaining strong contextual recall capabilities—hence the "recall-throughput tradeoff."
Methodology
1. Based Architecture:
- Linear Attention: Uses a second-order Taylor approximation of softmax to maintain global token interactions with a fixed-size recurrent state.
- Sliding Window Attention (SWA): Applies exact softmax attention locally in small windows (64-128 tokens) to handle precise local shifts. This combination allows Based to navigate the recall-throughput tradeoff effectively.
- IO-Aware Optimizations: Custom CUDA kernels reduce memory movement and improve hardware efficiency, enabling 24× higher throughput than FlashAttention-2 during generation.
- The BASED Model: BASED (Bidirectional Attention with Stable Expansion and Delay) is proposed as a simple and efficient linear attention architecture. Its defining traits include:
- Linear complexity with respect to sequence length.
- No KV-cache required during inference, unlike transformers.
- Introduces a memory state updated recurrently across tokens.
- Achieves bidirectional context modeling using a fixed-size memory block.
This makes BASED models more efficient for both training and inference, especially in streaming or real-time settings.
2. Theoretical and Empirical Analysis:
- Lower Bounds: The paper proves that any recurrent model requires Ω(N)-bits in state size to solve associative recall, highlighting the fundamental tradeoff.
- Achieves up to 24× higher generation throughput compared to FlashAttention-2.
- Empirical Results: Experiments on synthetic and real-world tasks (e.g., MQAR, Pile perplexity, information extraction) show Based outperforms Mamba by 10.36 accuracy points on recall-intensive tasks while matching its perplexity.
Experimental Results
- BASED models match or outperform standard transformer models on various language modeling benchmarks such as WikiText-103 and PG-19.
- They show strong performance on long-context tasks, including copy and retrieval tasks, indicating good memory recall.
- BASED demonstrates superior throughput, especially in inference without KV caching.
Key Findings
- Efficient memory usage and fast inference are achievable without sacrificing much performance.
- Linear attention models like BASED can serve as a viable alternative to transformers in memory-constrained or latency-sensitive applications.
- There exists a tradeoff surface between recall and throughput, and BASED models lie on an efficient frontier of that tradeoff.
Conclusion
Based expands the Pareto frontier of the recall-throughput tradeoff by combining simple, well-known techniques (linear and sliding window attention) with hardware-aware optimizations. The results suggest that efficient models can achieve high recall without sacrificing throughput, offering a promising direction for future language model architectures.
Group 12 Presentation: EAGLE-2: Faster Inference of Language Models with Dynamic Draft Trees
Presenters
Mutong Zhang, Hanqing Bi
Paper Citation
Y. Li, F. Wei, C. Zhang, and H. Zhang, ‘EAGLE-2: Faster Inference of Language Models with Dynamic Draft Trees’, Jun. 30, 2024, arXiv: arXiv:2406.16858. doi: 10.48550/arXiv.2406.16858.
https://arxiv.org/abs/2406.16858
Background
- LLMs to date have shown great performance, but they are slow and computationally expensive
- Speculative sampling - a small model generates candidate tokens, and the large model then evaluates those tokens, reducing the number of times the expensive computations of the large model have to occur
- EAGLE 1 used a draft tree to improve performance of speculative execution, and this builds on the author's prior work
- The authors observe that acceptance rates are also context-dependent, suggesting the need for adaptive drafting.
Paper Contributions
- EAGLE 1 doesn't directly predict tokens; instead, it maps tokens to features, predicts the next feature, and then predicts tokens from those features
- This work was shown to improve upon previous speculative execution methods
- EAGLE uses a tree structure to propose alternative tokens when a speculative sampling draft token is rejected by the full model
- EAGLE 2 noted that token acceptance is dependent on context and position. The first token has a high acceptance rate, and later tokens have lower acceptance rates
- EAGLE 2 uses a dynamic draft tree with "tree attention" to incorporate context information when selecting the next candidate tokens, increasing their acceptance rate, since acceptance depends on context and not only on position
Summaries of Key Points
Addressing the Challenge of Slow LLM Inference
Large language models (LLMs) have revolutionized natural language processing, but their inference remains computationally expensive and slow. This bottleneck arises due to the vast number of model parameters and the sequential nature of token generation. Improving inference efficiency without sacrificing accuracy is a critical research direction in the field.
The Foundation: Speculative Sampling and Draft Models
EAGLE-2 builds on speculative sampling, a technique that leverages a smaller, lightweight draft model to propose candidate tokens, which are then verified by the full LLM. This approach speeds up inference by reducing the number of computations performed by the large model.
Evolution from EAGLE-1 to EAGLE-2
The original EAGLE-1 introduced a static draft tree, assuming that token acceptance rates depend solely on their position within the tree structure. However, this assumption overlooks the context-dependent nature of token acceptance. EAGLE-2 addresses this limitation by introducing a dynamic draft tree adjustment mechanism, which refines speculative sampling based on real-time confidence scores.
Key Mechanisms of EAGLE-2
1. Instead of directly predicting tokens, EAGLE-2’s draft model generates feature representations, which are then processed by the head of the LLM to produce token predictions. This method enhances accuracy while maintaining efficiency.
2. EAGLE-2 employs a tree-based verification mechanism, which is computationally more efficient than the standard chain-structured verification used in traditional speculative sampling. By verifying multiple candidate tokens in parallel, it accelerates the overall process.
3. A core innovation in EAGLE-2 is its ability to dynamically modify the draft tree based on confidence scores from the draft model. This ensures that speculative sampling adapts to varying contexts, improving token acceptance rates and overall inference speed.
Dynamic Expansion and Re-Ranking
1. Expansion Phase: EAGLE-2 introduces a novel expansion phase, where a tree-attention mechanism processes all tokens in a layer simultaneously. This significantly enhances efficiency compared to sequential processing. Additionally, selective expansion prioritizes only the top-k tokens with the highest estimated global acceptance probabilities, preventing unnecessary computational overhead.
2. Re-Ranking Phase: Following expansion, EAGLE-2 reranks candidate tokens by selecting the top-m tokens with the highest acceptance probabilities. In cases where multiple tokens have similar scores, shallower nodes in the draft tree are prioritized, further optimizing verification speed.
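A simplified Python sketch of these two phases (illustrative only; the node fields and helper names are my own, and the real implementation batches these steps with tree attention). The value of a node is the product of draft confidences along its root-to-node path, used both to choose which nodes to expand and to pick the final top-m candidates:

```python
from dataclasses import dataclass, field

@dataclass
class Node:
    token: int
    conf: float                       # draft model's confidence for this token
    parent: "Node" = None
    children: list = field(default_factory=list)

def value(node):
    """Estimated global acceptance probability: product of draft confidences
    along the root-to-node path."""
    v, cur = 1.0, node
    while cur is not None:
        v *= cur.conf
        cur = cur.parent
    return v

def all_nodes(root):
    stack, out = [root], []
    while stack:
        n = stack.pop()
        out.append(n)
        stack.extend(n.children)
    return out

def select_for_expansion(root, top_k=8):
    """Expansion phase: only the top-k most promising leaves are expanded
    with new draft tokens."""
    leaves = [n for n in all_nodes(root) if not n.children]
    return sorted(leaves, key=value, reverse=True)[:top_k]

def rerank(root, top_m=6):
    """Re-ranking phase: keep the top-m nodes by value, preferring shallower
    nodes on ties; their tokens are then verified by the target model in one
    parallel pass."""
    def depth(n):
        return 0 if n.parent is None else 1 + depth(n.parent)
    ranked = sorted(all_nodes(root), key=lambda n: (value(n), -depth(n)), reverse=True)
    return ranked[:top_m]
```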
Experimental Results: Significant Performance Gains
EAGLE-2 achieved acceleration ratios of 3.05x - 4.26x across various tasks and large language model series such as Vicuna, LLaMA2-Chat and LLaMA3-Instruct, making it 20% - 40% faster than EAGLE-1. It is also approximately 2 times faster than Medusa and 2.3 times faster than Lookahead. In terms of token throughput, EAGLE-2 processes 4-5.5 tokens per verification cycle, about twice as many as traditional speculative sampling.
Key Advantages of EAGLE-2
1. Plug-and-Play Efficiency – EAGLE-2 requires no additional model training, as it seamlessly integrates the pre-trained draft model from EAGLE-1. It does not alter the original LLM, and maintains the exact same output distribution as greedy decoding (i.e., it is lossless).
2. Robust and Reliable – Unlike some acceleration techniques, EAGLE-2 does not modify the original model parameters or relax acceptance conditions, ensuring stable and consistent outputs.
3. Broad Generalization – The framework generalizes well across different tasks and architectures, demonstrating strong adaptability in diverse applications.
EAGLE-2 represents a significant advancement in accelerating LLM inference. By introducing a dynamic draft tree, efficient expansion strategies, and intelligent token re-ranking, it substantially reduces computational costs while maintaining accuracy. As large-scale models continue to grow, techniques like EAGLE-2 will be instrumental in making LLMs more practical and accessible for real-world applications.
Accomplishments
- State-of-the-art inference speedup:
- EAGLE-2 achieves up to 5× speedup in language model inference.
- Outperforms prior speculative decoding methods including EAGLE, Medusa, Lookahead, and others.
- Longest average acceptance length:
- EAGLE-2 generates longer sequences per accepted draft, reducing the number of calls to the target model.
- Wide applicability:
- Tested across 6 diverse tasks, including:
- Conversation
- Code generation
- Mathematical reasoning
- Question answering
- Summarization
- Instruction following
- Model-agnostic design:
- Compatible with popular LLMs such as:
- Vicuna
- LLaMA2-Chat
- LLaMA3-Instruct
Group 12 Presentation: EAGLE-2: Faster Inference of Language Models with Dynamic Draft Trees
Presenters
Mutong Zhang, Hanqing Bi
Paper Citation
Y. Li, F. Wei, C. Zhang, and H. Zhang, ‘EAGLE-2: Faster Inference of Language Models with Dynamic Draft Trees’, Jun. 30, 2024, arXiv: arXiv:2406.16858. doi: 10.48550/arXiv.2406.16858.
https://arxiv.org/abs/2406.16858
Constructive Critique and Review
The paper “EAGLE-2: Faster Inference of Language Models with Dynamic Draft Trees” introduces an innovative approach to enhancing the efficiency of Large Language Model (LLM) inference through the implementation of a context-aware dynamic draft tree.
This paper makes two main contributions to the field:
1). Dynamic Draft Tree Structure
Building upon the original EAGLE framework, EAGLE-2 replaces the static draft tree with a dynamic architecture that adapts to context. This adjustment acknowledges that the acceptance rate of draft tokens is influenced not only by their position but also by the surrounding context, which leads to more efficient token generation.
2). Utilization of Well Calibrated Draft Models
The paper also reveals that the draft model's confidence scores closely approximate the acceptance rates of draft tokens. By leveraging this calibration, EAGLE-2 effectively predicts which tokens are more likely to be accepted, optimizing the drafting process and token generation.
Performance Outcomes
Extensive evaluations have been conducted across three series of LLMs (Vicuna, LLaMA2-Chat and LLaMA3-Instruct) and on six diverse tasks, including multi-turn conversation, code generation and mathematical reasoning.
The results reveal that EAGLE-2 achieves speedup ratios of 3.05x to 4.26x, nearly a 20% to 40% improvement over EAGLE-1. Notably, this acceleration is achieved without altering the distribution of the generated text, ensuring the fidelity of the model's outputs.
Advancements within the field
EAGLE-2 makes significant advancements in LLM inference optimization. By introducing a context-aware dynamic draft tree, the paper addresses the limitations of the previous speculative sampling architecture of EAGLE-1, whose draft tree is static in nature.
This innovation enhances the acceptance rate of draft tokens, thereby reducing inference latency and computational costs. Additionally, the approach maintains the integrity of the generated text distribution, which distinguishes it from other acceleration methods that compromise output quality.
Conclusion
The methodologies and findings presented in this paper offer a substantial contribution to the field of ML and most notably in optimizing the efficiency of the generative LLMs. The introduction of dynamic, context-aware drafting mechanisms sets a new benchmark for speculative sampling techniques paving the way for more responsive, fast, and cost-effective LLM applications.
Group 13 Presentation: Linear Attention Mechanism: An Efficient Attention for Semantic Segmentation
Presented By
Yuke Liu, Mei Si
Paper Citation
R. Li, J. Su, C. Duan, and S. Zheng, ‘Linear Attention Mechanism: An Efficient Attention for Semantic Segmentation’, Aug. 20, 2020, arXiv: arXiv:2007.14902. doi: 10.48550/arXiv.2007.14902.
https://arxiv.org/abs/2007.14902
Background
- Existing transformer models have [math]\displaystyle{ \mathcal{O} (n^2) }[/math] complexity, which is problematic as model size grows
- This limits the growth of model sizes due to computational resources constraints
- This paper focused on an alternative method to conventional dot product attention that is more computationally efficient
- Standard attention requires the computation of [math]\displaystyle{ Q K^\top }[/math], which has [math]\displaystyle{ \mathcal{O} (n^2) }[/math] complexity
- The paper proposes a linear attention mechanism that solves the problem while keeping the performance.
Technical Contributions
Rather than doing the full computation for the softmax in the transformer architecture, the authors instead compute
[math]\displaystyle{ D(\mathbf{Q}, \mathbf{K}, \mathbf{V})_i = \frac{\sum_{j=1}^{N} e^{\mathbf{q}_i^{T} \mathbf{k}_j} \mathbf{v}_j}{\sum_{j=1}^{N} e^{\mathbf{q}_i^{T} \mathbf{k}_j}} = \frac{\sum_{j=1}^{N} \text{sim}(\mathbf{q}_i, \mathbf{k}_j) \mathbf{v}_j}{\sum_{j=1}^{N} \text{sim}(\mathbf{q}_i, \mathbf{k}_j)} }[/math]
and define the transformation function as
[math]\displaystyle{ \text{sim}(\mathbf{q}_i, \mathbf{k}_j) = \phi(\mathbf{q}_i)^{T} \phi(\mathbf{k}_j) }[/math]
The authors apply a first-order Taylor series expansion, and after some rearranging and substitution arrive at their final formula (full derivation not shown):
[math]\displaystyle{ D(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \frac{ \sum_j \mathbf{v}_{i,j} + \left( \frac{\mathbf{Q}}{\lVert \mathbf{Q} \rVert_2} \right) \left( \left( \frac{\mathbf{K}}{\lVert \mathbf{K} \rVert_2} \right)^{T} \mathbf{V} \right) }{ N + \left( \frac{\mathbf{Q}}{\lVert \mathbf{Q} \rVert_2} \right) \sum_j \left( \frac{\mathbf{K}}{\lVert \mathbf{K} \rVert_2} \right)_{i,j}^{T} } }[/math]
The authors' form of the attention mechanism can be computed with [math]\displaystyle{ \mathcal{O} (n) }[/math] complexity, removing the quadratic scaling of conventional transformers and enabling larger models built on the underlying attention mechanism.
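A small numpy sketch of this linearized attention (my own, non-causal and single-head, without numerical-stability tweaks), computed row-wise exactly as in the formula above:

```python
import numpy as np

def linear_attention(Q, K, V):
    """Linear attention via a first-order Taylor expansion of the softmax
    kernel with L2-normalized queries and keys (sketch)."""
    N = Q.shape[0]
    Qn = Q / np.linalg.norm(Q, axis=-1, keepdims=True)
    Kn = K / np.linalg.norm(K, axis=-1, keepdims=True)
    num = V.sum(axis=0) + Qn @ (Kn.T @ V)      # (N, d_v); K^T V computed once -> O(N)
    den = N + Qn @ Kn.sum(axis=0)              # (N,)
    return num / den[:, None]

rng = np.random.default_rng(0)
Q, K, V = rng.standard_normal((3, 256, 64))    # e.g. 256 pixels, 64 channels
out = linear_attention(Q, K, V)                # cost grows linearly with the number of pixels
```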
- Hybrid Attention Mechanisms for Selective Efficiency
Idea: Combine linear attention with local or sparse attention in a hybrid structure where the system dynamically chooses the best mechanism depending on context (e.g., object density, resolution, or modality). The rationale is that the dense areas of an image (e.g., urban environments) may benefit from sparse or local attention for spatial precision. Less complex regions (e.g., sky, ocean) can rely on linear attention for efficiency. It could be implemented with a gating module that selects or blends attention types.
Model Performance Evaluation
The model performance evaluation primarily focused on assessing how well the proposed linear attention mechanism enhances semantic segmentation performance while maintaining computational efficiency.
- Dataset: The experiments were conducted on the Fine Gaofen Image Dataset (GID), which consists of high-resolution satellite images. This dataset was selected because its complex landscapes and varied environmental conditions present significant segmentation challenges.
- Data Preprocessing: Due to the large size of the images in the dataset, each image was divided into smaller patches of 256x256 pixels, resulting in a total of 7,280 patches.
- Data Split: These patches were partitioned into 60% for training, 20% for validation, and 20% for testing. This split ensured a rigorous evaluation of the model's performance across different scenarios.
- The standard dot-product attention is replaced by the linear attention mechanism in baseline segmentation models such as U-Net, Res101, RefineNet, DeepLab, PSPNet, and FastFCN.
- Implementation Details: The experiments were implemented in the PyTorch framework and trained on an NVIDIA RTX 2080Ti GPU.
- Evaluation Metrics: OA, AA, Kappa (K), mIoU, and F1.
Comparative Analysis
The proposed linear attention mechanism achieves a similar level of accuracy in semantic segmentation tasks compared to traditional dot product attention.
Linear attention maintains comparable performance metrics while drastically reducing both memory footprint and processing time required.
The efficiency gains of linear attention become more apparent in large-scale segmentation tasks, making it suitable for high-resolution images and long text sequences.
Future work includes extending the linear attention mechanism to more complex network designs, integrating with state-of-the-art deep learning models, and optimizing for real-time processing in demanding applications.
Method | OA | AA | K | mIoU | F1 |
---|---|---|---|---|---|
U-Net | 86.378 | 74.532 | 83.357 | 64.516 | 75.532 |
U-Net LAM | 87.692 | 77.297 | 84.935 | 68.038 | 78.593 |
Res101 | 89.251 | 80.451 | 86.846 | 72.433 | 82.510 |
Res101 LAM | 90.178 | 82.757 | 88.041 | 74.085 | 83.105 |
RefineNet | 89.857 | 81.169 | 87.597 | 73.167 | 83.113 |
RefineNet LAM | 90.214 | 83.544 | 88.083 | 74.973 | 84.311 |
DeepLab | 89.388 | 80.905 | 87.079 | 71.809 | 81.077 |
DeepLab LAM | 89.576 | 81.692 | 87.287 | 72.827 | 82.702 |
DeepLabV3+ | 90.125 | 81.483 | 87.959 | 72.668 | 81.492 |
DeepLabV3+ LAM | 90.315 | 81.695 | 88.182 | 73.727 | 82.736 |
PSPNet | 90.573 | 82.211 | 88.485 | 74.797 | 83.761 |
PSPNet LAM | 90.725 | 83.088 | 88.677 | 75.695 | 84.480 |
FastFCN | 90.336 | 83.625 | 88.221 | 74.364 | 83.704 |
FastFCN LAM | 90.835 | 83.075 | 88.769 | 75.174 | 84.023 |
Group 13 Presentation: Linear Attention Mechanism: An Efficient Attention for Semantic Segmentation
Presented By
Yuke Liu, Mei Si
Paper Citation
R. Li, J. Su, C. Duan, and S. Zheng, ‘Linear Attention Mechanism: An Efficient Attention for Semantic Segmentation’, Aug. 20, 2020, arXiv: arXiv:2007.14902. doi: 10.48550/arXiv.2007.14902.
https://arxiv.org/abs/2007.14902
Introduction
(1) Attention Mechanism
The attention mechanism gained prominence for its capacity to refine feature maps and to capture long-range dependencies in both computer vision and natural language processing. By focusing computational resources on the most relevant portions of the input data, attention mechanisms have proven especially valuable in tasks that demand context from large input sizes—such as high-resolution images or lengthy text sequences.
(2) Dot Product Attention
A widely used form of attention is the dot product attention. It forms the core of many state-of-the-art models, including Transformers in NLP and non-local modules in computer vision. Despite its strengths, dot product attention scales poorly (quadratically) with input size in both memory and computational cost.
(3) Problem Statements
Because the dot product attention mechanism requires [math]\displaystyle{ O(N^2) }[/math] operations for inputs of length [math]\displaystyle{ N }[/math] , deploying it on large inputs—e.g., high-resolution images or long sequences—becomes infeasible. This computational bottleneck has motivated research into more efficient variants of attention, particularly approaches that reduce the quadratic cost.
Overview of Attention Mechanism
(1) Scaling Attention
In contrast to dot product attention, “scaling” attention mechanisms (sometimes also called “channel” or “spatial” attention in the literature) aim to emphasize informative features while reducing redundant information. Examples include: Squeeze-and-Excitation (SE) modules, which learn channel-wise scaling factors. Convolutional Block Attention Module (CBAM), which extends SE by considering both channel and spatial attention. These approaches do not necessarily learn long-range dependencies across positions in the same way as dot product attention; instead, they focus on enhancing or diminishing particular channels or spatial locations.
(2) Linear Attention: Taylor Expansion → Linear Time
Recent research efforts address the [math]\displaystyle{ O(N^2) }[/math] complexity of dot product attention by re-formulating the similarity function (the key step in attention). One family of methods uses kernel-based approximations, while another employs mathematical expansions. By approximating the exponential term of the softmax function (commonly used in dot product attention) with a first-order Taylor expansion, it is possible to make attention computations linear with respect to [math]\displaystyle{ N }[/math]. This insight forms the basis for linear attention mechanisms.
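Concretely, the first-order Taylor expansion approximates the exponential similarity as
[math]\displaystyle{ e^{\mathbf{q}_i^{T} \mathbf{k}_j} \approx 1 + \mathbf{q}_i^{T} \mathbf{k}_j, }[/math]
and, after [math]\displaystyle{ l_2 }[/math]-normalizing the queries and keys to keep the term non-negative, the similarity used in place of the softmax becomes
[math]\displaystyle{ \text{sim}(\mathbf{q}_i, \mathbf{k}_j) = 1 + \left( \frac{\mathbf{q}_i}{\lVert \mathbf{q}_i \rVert_2} \right)^{T} \left( \frac{\mathbf{k}_j}{\lVert \mathbf{k}_j \rVert_2} \right). }[/math]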
(3) Semantic Segmentation
Semantic segmentation requires dense predictions over every pixel of an image, and capturing global context is crucial. Traditional convolution-based networks, whether they follow a DilatedFCN (e.g., DeepLab, PSPNet) or an Encoder-Decoder (e.g., U-Net) architecture, benefit significantly from attention modules that can encode relationships across spatial locations. However, if the resolution is high, dot product attention quickly becomes prohibitively expensive, motivating more efficient variants like linear attention.
Dot Product Attention Details
(1) Query, Key, and Value
Dot product attention transforms each input feature [math]\displaystyle{ x_i }[/math] into three different representations:
Query([math]\displaystyle{ q_i }[/math]) Key([math]\displaystyle{ k_i }[/math]) Value([math]\displaystyle{ v_i }[/math])
The core operation then measures similarity between every query–key pair to weight the contribution of all values in forming an output.
(2) Kernel-Based Approximation
A known strategy to reduce complexity is to view the exponential term in dot product attention as operating in a kernel space, thereby factorizing the softmax attention to achieve lower complexity. Various works replace the explicit softmax with kernel-based transformations, enabling a more efficient computation.
Linear Attention Mechanism
Linear attention leverages a Taylor expansion to approximate the exponential term. By applying additional constraints (such as [math]\displaystyle{ L_2 }[/math] normalization to keep the term non-negative), the resulting formulation scales as [math]\displaystyle{ O(N) }[/math] rather than [math]\displaystyle{ O(N^2) }[/math]. This makes the mechanism practical for high-resolution inputs or very long sequences, significantly broadening the usability of attention in semantic segmentation and beyond.
Experimental Settings
(1) Dataset
Many experiments on linear attention for segmentation are conducted using large, high-resolution datasets. One common benchmark highlighted is a satellite imagery dataset (e.g., Fine Gaofen Image Dataset, GID). These datasets typically comprise large aerial images that are partitioned into patches, with splits for training, validation, and testing.
(2) Model Implementations
Baseline segmentation networks (e.g., PSPNet, DeepLab series, U-Net variants, FastFCN, and RefineNet) integrate the proposed linear attention modules in place of, or alongside, standard attention mechanisms. The training setup often employs standard optimizers (such as Adam), cross-entropy loss, and hardware accelerators (like NVIDIA GPUs).
(3) Evaluation Metrics
Common segmentation metrics include: Overall Accuracy (OA), Average Accuracy (AA), Kappa Coefficient (K), mean Intersection over Union (mIoU), and F1 Score.
Evaluation metrics
Mean Intersection over Union (mIoU)
[math]\displaystyle{ IoU = \frac{TP}{TP+FP+FN} }[/math]
[math]\displaystyle{ mIoU = \frac{1}{N} \sum_{i=1}^{N} IoU_i }[/math]
- Measures the average overlap between predicted segmentation and ground truth across all classes.
- Commonly used in semantic segmentation benchmarks.
Why It’s Important in This Paper:
- Serves as a primary benchmark for assessing spatial accuracy of LAM.
- Especially meaningful in pixel-level classification, where precise boundaries matter.
- Used to compare LAM-enhanced networks against traditional attention-based and baseline CNN models.
- LAM achieves competitive or improved mIoU across multiple medical image segmentation datasets (DRIVE, STARE, CHASE_DB1), validating its contextual understanding with reduced computational cost.
Kappa coefficient [math]\displaystyle{ \kappa = \frac{p_0 - p_e}{1-p_e} }[/math]
- Assesses statistical agreement between predicted segmentation and ground truth, adjusting for agreement by chance
Why It’s Important in This Paper:
- Medical segmentation tasks often suffer from class imbalance (e.g., small vessels vs large background).
- Kappa offers a robust metric that accounts for imbalanced distributions, unlike plain accuracy.
- The paper reports high Kappa values for LAM-based models, showing that they make meaningful predictions beyond chance, even when foreground (vessel) pixels are rare.
- This supports that LAM is not just "overfitting" to majority classes but learns class-relevant structure effectively.
Results and Experimental Improvement
In benchmark experiments, linear attention typically boosts the performance of various baseline segmentation models while reducing memory and computational overhead. For example, when embedded into U-Net, PSPNet, or DeepLab, linear attention can achieve a higher mIoU and Kappa coefficient compared to the original dot product attention. These gains confirm that the approximation introduced by the Taylor expansion still captures sufficient global context for accurate segmentation.
Comparative Analysis
Limitations:
(1) Approximation Error: The first-order Taylor expansion introduces approximation errors relative to exact dot product attention. Although experiments show minimal performance degradation, in certain tasks or extreme input scales, further refinements or higher-order terms might be necessary.
(2) Architecture Constraints: Integrating linear attention can require modifications to the existing network design, including normalization steps and careful parameter initialization.
Conclusion and Future Work
Linear attention mechanisms substantially reduce both memory usage and computational cost in high-resolution vision tasks, making them a promising alternative to dot product attention. Future research may involve:
Extending linear attention to multi-modal domains where extremely large inputs (e.g., video or multi-spectral data) are common.
Investigating higher-order approximations that may yield even more accurate results while retaining near-linear scalability.
Combining linear attention with other efficient modules (e.g., lightweight convolutions or quantization techniques) to further push the boundaries of real-time segmentation on resource-constrained devices.
Group 14 Presentation: Scalable Watermarking for Identifying Large Language Model Outputs
Presented by:
Ryan Tymkow and Benjamin Schnapp
Paper Citation
Dathathri, S., See, A., Ghaisas, S. et al. Scalable watermarking for identifying large language model outputs. Nature 634, 818–823 (2024). https://doi.org/10.1038/s41586-024-08025-4.
Summaries of key points
This paper tackles the problem of watermarking—embedding a detectable signal into the outputs of large language models (LLMs) to distinguish generated text from human-written content. The challenge is doing this in a way that is robust, scalable, and minimally intrusive to the model's output quality. The authors propose a method based on statistical biasing of token selection during generation. Specifically, they partition the vocabulary into “greenlist” and “redlist” tokens at each step based on a seeded hash function, and then bias sampling toward the greenlist using a tunable parameter. The watermark is invisible in individual outputs but detectable over longer texts using hypothesis testing. Importantly, the approach is model-agnostic, doesn’t require retraining, and adds very little computational overhead. It also scales well to large models and can be applied in high-throughput or real-time settings. Overall, it’s a lightweight yet effective strategy for watermarking that balances detectability, scalability, and usability.
A key limitation of this approach is that it may still be vulnerable to paraphrasing or text transformation—simple rewriting could break the statistical signature. Another concern is adversarial robustness: since the watermarking method is relatively transparent (based on vocabulary partitioning), a knowledgeable attacker could design strategies to erase or spoof the signal. Additionally, while the method maintains fluency and quality in most cases, biasing token selection could subtly affect stylistic or semantic nuances, especially in creative writing or long-form tasks. The paper doesn’t deeply explore how this might influence user-facing applications like chatbots or summarizers. Lastly, while the watermark is statistically detectable, it’s not embedded in a cryptographic sense, so it may not offer strong guarantees in high-stakes verification contexts.
Clear explanations to aid understanding
Imagine if every time a language model generated text, it subtly preferred certain words over others—but in a way that's invisible to readers. That’s what this watermark does. At each generation step, the model hashes its current context to choose a “greenlist” of preferred tokens and then slightly boosts their probabilities. Over many words, these choices form a statistically detectable pattern. It's like nudging a roulette wheel so certain numbers come up just a bit more often—not enough to be obvious, but enough to spot if you know where to look. The method is efficient and easy to integrate, since it works at the sampling level and doesn't require modifying the model architecture.
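To make the idea above concrete, here is a minimal sketch of the greenlist-biasing scheme as described in this summary (it illustrates the general idea only, not SynthID-Text's actual Tournament sampling procedure; the 4-token context window, the bias `delta`, and the hashing scheme are all assumptions):

```python
import hashlib
import numpy as np

def greenlist(context_ids, vocab_size, fraction=0.5):
    """Pseudorandomly split the vocabulary, seeded by a hash of the recent context."""
    seed = int.from_bytes(
        hashlib.sha256(str(tuple(context_ids[-4:])).encode()).digest()[:8], "big")
    perm = np.random.default_rng(seed).permutation(vocab_size)
    return set(perm[: int(fraction * vocab_size)])

def watermarked_sample(logits, context_ids, delta=2.0):
    """Boost greenlist logits by delta, then sample from the softmax as usual."""
    green = greenlist(context_ids, len(logits))
    boosted = np.array([l + delta if i in green else l for i, l in enumerate(logits)])
    probs = np.exp(boosted - boosted.max())
    probs /= probs.sum()
    return int(np.random.default_rng().choice(len(logits), p=probs))

def detection_z_score(token_ids, vocab_size, fraction=0.5):
    """z-test: does the observed greenlist hit rate exceed the chance rate `fraction`?"""
    hits = sum(token_ids[i] in greenlist(token_ids[:i], vocab_size, fraction)
               for i in range(4, len(token_ids)))
    n = max(len(token_ids) - 4, 1)
    return (hits - fraction * n) / np.sqrt(n * fraction * (1 - fraction))
```

Over a sufficiently long text, the z-score grows when the text was produced with `watermarked_sample`, while human-written text stays near zero.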
Review
I thought that the visual aids in this presentation greatly helped explain the process of how synthid alters the sampling process and how it embeds watermarks during the process, helping distinguish outputs created by the LLM vs outputs not created by the LLM. The graphic showing the three main parts of the process, the random seed generator, the sampling algorithm, and the scoring function, made it simpler to understand the whole process. The examples with the fruits, first generating random watermarking functions, then going through the tournament to sample the output token, also made it really easy to follow along on what exactly is going on.
Group 14 Presentation: Scalable Watermarking for Identifying Large Language Model Outputs
Presented by:
Ryan Tymkow, Benjamin Schnapp
Paper Citation
Dathathri, S., See, A., Ghaisas, S. et al. Scalable watermarking for identifying large language model outputs. Nature 634, 818–823 (2024). https://doi.org/10.1038/s41586-024-08025-4
Background
With the rise of LLMs and generative AI, there is a risk of AI-generated misinformation spreading as if it came from a human, so it is important to be able to distinguish AI-generated text from human writing. Existing solutions address this problem, but they all have limitations involving privacy and computational cost. For example, traditional watermarking can introduce unwanted artifacts into the text.
Previous Approaches
Some of the existing approaches are Retrieval Based Tracking, Post-Hoc Detection, and Traditional Watermarking.
Retrieval-based tracking stores generated LLM responses in a reference database to determine whether a newly seen text was produced by the tracked LLM. The issues with this are scalability and privacy, since generated responses must be stored.
Post-hoc detection uses statistical features of LLM-generated text. The issues with this are high computational cost, and the approach becomes more difficult as LLM outputs grow more human-like.
Traditional watermarking embeds hidden features in the text, for example by replacing synonyms or Unicode characters. But tinkering with the output in this way degrades the quality of the original LLM text.
Technical Contributions
This paper introduced SynthID-Text, a watermarking method for large language models (LLMs). The method uses a Tournament sampling approach, which ensures that generated text contains a detectable watermark with minimal computational overhead. SynthID-Text incorporates a random seed generator and scoring functions to embed a watermark into the model's output. This technique enhances the ability to identify whether text originates from a specific LLM, while preserving the text's quality and minimizing distortion. SynthID-Text does not affect LLM training. It can be configured as distortionary or non-distortionary.
Central to SynthID-Text is the novel Tournament sampling procedure. Rather than sampling each token directly from the LLM's distribution, multiple candidate tokens compete in a multi-layer "tournament" based on pseudorandom watermarking scores, embedding a statistical “signature” that can later be detected.
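A highly simplified sketch of the tournament idea follows: candidates drawn from the LLM's distribution compete pairwise over several layers, and the candidate with the higher pseudorandom watermarking score advances. The scoring function `g`, the seeding, and the number of layers are illustrative assumptions rather than the exact SynthID-Text procedure.

```python
import hashlib
import random

def g(token_id, seed, layer):
    """Pseudorandom watermarking score in {0, 1}, keyed by the seed and layer."""
    digest = hashlib.sha256(f"{seed}:{layer}:{token_id}".encode()).digest()
    return digest[0] & 1

def tournament_sample(candidate_ids, seed, num_layers=3):
    """Pick one token from 2**num_layers candidates via a watermarking tournament."""
    assert len(candidate_ids) == 2 ** num_layers
    pool = list(candidate_ids)
    for layer in range(num_layers):
        winners = []
        for a, b in zip(pool[0::2], pool[1::2]):
            ga, gb = g(a, seed, layer), g(b, seed, layer)
            # the higher watermark score wins; ties are broken at random
            winners.append(a if ga > gb else b if gb > ga else random.choice([a, b]))
        pool = winners
    return pool[0]

# Example: 8 candidates sampled (repetition allowed) from the LLM's distribution
chosen = tournament_sample([101, 7, 7, 42, 101, 3, 42, 7], seed=123456)
```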
Results
SynthID-Text achieved a 95% true positive detection rate at a 1% false positive rate in a 20-million-interaction test on Google's Gemini chatbot. It offers high detection accuracy with minimal computational cost and can be configured for non-distortionary or distortionary watermarking.
The benefits are as follows:
Minimal impact on large language model training:
- SynthID-Text can be applied to any large language model with minimal modifications, as it only alters the sampling step.
High detection accuracy with low computational cost:
- Outperforms retrieval-based tracking, post-hoc detection, and traditional watermarking methods.
- Offers the best balance between computational cost, scalability, and accuracy.
- Can be integrated into production environments using speculative sampling, where smaller models suggest tokens and the main model verifies their validity.
Configurable distortion levels:
- Allows for distortionary or non-distortionary configurations, enabling better control over the trade-off between generated text quality and detection accuracy.
- In non-distortionary watermarking, the average token distribution of the generated text matches the original model's token distribution.
Group 15 Presentation: DiGress: Discrete denoising diffusion for graph generation
Presented by:
Sean Tang, Buji Wong
Paper Citation
Vignac, C., Krawczuk, I., Siraudin, A., Wang, B., Cevher, V., & Frossard, P. (2022). Digress: Discrete denoising diffusion for graph generation. arXiv preprint arXiv:2209.14734.
Background
Graph generation
The goal of this project is to generate graphs, which are represented by node matrices and edge matrices. Edges and nodes can also have their own categories. One application of this is molecule generation: atoms would be represented by nodes and the chemical bonds would be represented by edges.
The challenge of graph generation is a complex task due to the unordered nature and sparsity of graphs. While denoising diffusion models have been successful in other domains like images, they struggle with graphs due to their structural properties. Existing approaches that use continuous noise models disrupt the sparsity and connectivity crucial for graphs.
Discrete diffusion
Regular diffusion methods involve a forward process (where noise is gradually added to the data) and a neural network that is trained to predict the backwards process (where random noise is gradually turned back into something that would plausibly fit in with the dataset). Most of the time, diffusion models use Gaussian noise.
This graph generation problem involves discrete data, therefore, a discrete noise model is preferred for the diffusion process. (Note: The authors actually previously tried to use Gaussian noise, but the continuous nature of the Gaussian function meant that the discreteness of the graphs was not respected.) In this discrete case, however, the diffusion is defined in terms of transition matrices, [math]\displaystyle{ Q^t }[/math]:
[math]\displaystyle{ q(x_t | x_{t-1}) = x_{t-1} Q^t }[/math]
Each transition matrix is applied to each node and edge; this allows the nodes and edges to remain in discrete space.
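A minimal sketch of one forward noising step under this formulation is shown below; the number of categories, the specific transition matrix, and the uniform-noise schedule are illustrative assumptions rather than the paper's exact noise model.

```python
import numpy as np

def forward_step(X, Q, rng):
    """One discrete diffusion step: rows of X are one-hot node categories.

    q(x_t | x_{t-1}) = x_{t-1} Q^t, then sample a new category per node.
    """
    probs = X @ Q                                   # (num_nodes, num_categories)
    new_cats = [rng.choice(len(p), p=p / p.sum()) for p in probs]
    return np.eye(Q.shape[0])[new_cats]             # back to one-hot

# Toy example: 4 nodes, 3 categories; the transition matrix mostly keeps the
# current category but leaks some probability uniformly to the others.
alpha, K = 0.9, 3
Q = alpha * np.eye(K) + (1 - alpha) * np.ones((K, K)) / K
X0 = np.eye(K)[[0, 1, 2, 0]]                        # one-hot categories of 4 nodes
X1 = forward_step(X0, Q, np.random.default_rng(0))
```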
Technical Contributions
Overview of DiGress
The authors introduce DiGress, a discrete denoising diffusion model designed specifically for graph generation with categorical node and edge attributes. DiGress improves graph generation by using a discrete noise model that preserves graph sparsity and structural properties. The model involves progressively editing graphs through edge addition/removal and attribute changes. A graph transformer network is used to reverse this noisy process using cross-entropy loss, sampling from the trained model by iteratively updating the noise level and computing structural features.
Key enhancements include a noise model that maintains node and edge type distributions, a guidance procedure for conditioning on graph-level features, and the use of auxiliary graph-theoretic features. DiGress achieves state-of-the-art performance on both molecular and non-molecular graph datasets.
Denoising network architecture
- The input is a noisy graph, expressed as [math]\displaystyle{ X }[/math] (a tensor of one-hot encoded nodes) and [math]\displaystyle{ E }[/math] (a tensor of one-hot encoded edges).
- The structural and spectral features of these graphs are calculated.
- The noisy graph along with the structural and spectral features are taken as input into a multi-layer perceptron (MLP) structure.
- Next is a sequence of graph transformer layers.
- After another MLP layer, the output is a distribution over nodes and edges.
Summary of results
Results showed DiGress outperformed continuous diffusion methods on various metrics, including degree distribution, clustering, and novelty, and was more scalable for larger graphs. Moreover, for the creation of novel molecules, discrete diffusion aids scalability to larger graphs and molecules, making it more efficient than continuous diffusion. DiGress is the first one-shot graph-based model that feasibly trains on over a million molecules without fragment-based heuristics. Its performance on drug-like molecule benchmarks reaches or exceeds that of specialized autoregressive or SMILES-based baselines.
Accomplishments
- State-of-the-art one-shot graph generation:
- DiGress outperforms existing non-autoregressive models such as GDSS, GraphNVP, and SPECTRE.
- Achieves strong performance on benchmarks including Stochastic Block Models (SBM) and planar graphs.
- Scalable molecular graph generation:
- DiGress is the first diffusion-based model to scale to large molecular datasets such as:
- MOSES (small, drug-like molecules)
- GuacaMol (1.3M large molecules)
- Matches or exceeds autoregressive and fragment-based baselines on metrics such as validity, uniqueness, novelty, and scaffold diversity.
- Fast and efficient performance on QM9:
- Achieves near-perfect:
- Validity: 99%
- Uniqueness: ~96%
- Trains significantly faster than prior diffusion models (1 hour vs. 7.2 hours for ConGress).
- Effective conditional generation:
- Uses regressor-guided sampling to generate molecules with desired properties (e.g., HOMO energy, dipole moment).
- Outperforms unconditional models in property control and accuracy.
- Theoretical soundness:
- Proven permutation equivariance and exchangeability for correct graph generation.
- Provides an ELBO-based training objective for likelihood estimation and model comparison.
Group 15 Presentation: DiGress: Discrete denoising diffusion for graph generation
Presented by:
Sean Tang, Buji Wong
Paper Citation
Vignac, C., Krawczuk, I., Siraudin, A., Wang, B., Cevher, V., & Frossard, P. (2022). Digress: Discrete denoising diffusion for graph generation. arXiv preprint arXiv:2209.14734.
Summaries of key points
Motivation
Recall that in class we have already learned about diffusion models: they work by gradually adding noise to data (forward process) and training a model to reverse the process, restoring the clean data (reverse process). Traditional diffusion models are designed for continuous data, but graphs are generally discrete. This discrete structure makes it difficult to directly apply traditional diffusion models to graphs.
DiGress tackles this issue by introducing a method that applies diffusion to discrete graphs through the concept of discrete diffusion.
Forward Diffusion Process: How to Add Discrete Noise?
As mentioned above, DiGress uses discrete noise. In DiGress, noise is added by changing the categories of nodes and edges (such as atom types or bond types). This means the model randomly selects new categories for nodes and edges from a predefined set of valid categories. To guide this noise addition, DiGress uses a transition matrix. The transition matrix defines the probability of each node and edge transitioning between categories at each step. This ensures that while the graph becomes noisier over time, it still maintains its structure and validity.
Reverse Diffusion Process: How to Denoise a Graph?
In the reverse process, DiGress gradually removes the noise from the graph. Instead of randomly undoing the noise, the model uses a graph transformer network designed for graph-structured data. This network helps the model recognize the structure of the graph and the relationships between nodes and edges. During each step, the model focuses on the most relevant parts of the graph, predicting the correct categories for nodes and edges. The model's predictions are guided by a cross-entropy loss (applied to both nodes and edges), which measures how accurately the model predicts the node and edge categories after denoising. By minimizing this loss, the model becomes better at removing the noise, step by step, until it recovers a valid and meaningful graph.
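A minimal sketch of the node-and-edge cross-entropy objective described above; tensor shapes and the relative weighting `lam` between the two terms are illustrative assumptions.

```python
import torch
import torch.nn.functional as F

def denoising_loss(node_logits, edge_logits, node_targets, edge_targets, lam=5.0):
    """Cross-entropy on predicted node and edge categories.

    node_logits: (n, num_node_types),    node_targets: (n,)
    edge_logits: (n, n, num_edge_types), edge_targets: (n, n)
    lam weights the edge term relative to the node term (an assumption here).
    """
    node_loss = F.cross_entropy(node_logits, node_targets)
    edge_loss = F.cross_entropy(edge_logits.reshape(-1, edge_logits.shape[-1]),
                                edge_targets.reshape(-1))
    return node_loss + lam * edge_loss
```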
Conditioning Graph Generation on Desired Properties
One of the most powerful features of DiGress is its ability to condition the graph generation process on specific properties. For example, if you want the generated graph to have a certain number of oxygen atoms or satisfy some chemical property, DiGress can adjust the generation process to meet those requirements. This is done by checking the graph at each step of the sampling process and modifying it as needed to match the desired properties. This capability is particularly useful in areas like drug discovery, where the generated molecules must meet certain chemical and structural criteria.
Experimental Results
DiGress was tested on several datasets to evaluate its performance:
1. Graph Generation: On datasets like the Stochastic Block Model (SBM) and planar graphs, DiGress outperformed other methods, particularly in generating novel graphs that were not seen during training.
2. Molecule Generation: When applied to the MOSES dataset, DiGress produced more valid and novel molecules, even though it did not always surpass methods that check graph validity at every step.
3. Scalability: On larger graphs, such as those from the GuacaMol dataset, DiGress demonstrated strong scalability, making it a suitable option for generating larger and more complex graphs.
Comparison with Existing Approaches
- Versus Autoregressive Models: These models (like GraphAF, GraphRNN) generate graphs node-by-node or edge-by-edge and often rely on ordering, making them slower and harder to parallelize.
- Versus Continuous Diffusion Models for Graphs: E.g., GDSS uses Gaussian noise and struggles with categorical data. DiGress handles discrete data directly, making it more suitable for molecular and symbolic domains.
- DiGress Advantage: Fully parallel sampling, no need to learn generation order, works better for discrete structured data.
Group 16 Presentation: Machine Learning and Hamilton-Jacobi-Bellman Equation for Optimal Decumulation: a Comparison Study
Presented by:
Zeyu Zhang
Paper Citation
Chen M, Shirazi M, Forsyth PA, Li Y. Machine Learning and Hamilton-Jacobi-Bellman Equation for Optimal Decumulation: a Comparison Study. Published online 2023. doi:10.48550/arxiv.2306.10582
Background
The paper is based on computational finance, focusing on the optimization problem related to "defined benefit" and "defined contribution plans". The main focus is on the challenge of ensuring retirees have enough funds for their retirement. Two key plans were discussed:
"Defined benefit plans" guarantee fixed monthly payments based on factors like tenure and salary but are cost-prohibitive and risky.
"Contribution plans" shift the investment and withdrawal strategy burden to individual investors, but they struggle to balance maximizing withdrawals and minimizing risk.
This problem, often called the "Nazi's hardest problem in finance," highlights the complexity of balancing risk and reward in financial planning for retirement.
The 4% rule is a traditional method recommending a constant 4% withdrawal each year, adjusted for inflation, and investing in stocks and bonds.
Despite its popularity, the 4% rule is suboptimal: it is not globally optimal.
Peter Forsyth proposed an HJB PDE method to maximize expected withdrawals while minimizing the risk of running out of savings.
The HJB PDE method uses scalarization techniques to achieve Pareto optimal points, but it has limitations.
Technical Contributions
1. Hamilton-Jacobi-Bellman (HJB):
- The problem formulation involves complex mathematical equations related to computational finance.
- The problem uses dynamic programming to break down the optimal control problem, leading to the HJB function that represents the value function.
- The paper assumes stock and bond prices follow a jump diffusion model.
- The investors' total wealth at time [math]\displaystyle{ t }[/math] is defined as the sum of stock price and bond price at that time.
- The terminal time [math]\displaystyle{ T }[/math] is set to 30 years, and rebalancing times are defined with discrete withdrawal amounts and allocations between stocks and bonds.
2. Neural Network (NN): Control and Objective Function:
- The control at time [math]\displaystyle{ T_i }[/math] includes the withdrawal amount [math]\displaystyle{ Q_i }[/math] and allocation for the wealth at time [math]\displaystyle{ T_i^- }[/math].
- The admissible control set is defined, and the expected shortfall is introduced as a measure of risk.
- The expected total withdrawal is used as a measure of reward, aiming to maximize the expected total withdrawal while minimizing the expected shortfall.
- The pre-commitment in the expected shortfall problem is defined, focusing on maximizing the expected total withdrawal and minimizing the expected shortfall.
Neural Network (NN) Formulation
As an alternative to the HJB framework, the authors propose a Neural Network (NN) approach to solve the stochastic control problem. This framework has several advantages:
1. The NN approach is data-driven, meaning it avoids explicitly defining parametric models for stochastic processes. This provides flexibility and allows integration of auxiliary variables if needed.
2. It circumvents the computation of high-dimensional conditional expectations by solving a single, unconstrained optimization problem for control decisions. This avoids the curse of dimensionality often associated with dynamic programming.
3. If the optimal control is continuous in time and state, the NN reflects this property. If the control is discontinuous, the NN yields a smooth approximation, which is beneficial in practice for implementing investment policies.
4. The method is scalable, making it suitable for long horizons and high-frequency rebalancing without significantly increasing computational complexity.
The NN framework’s effectiveness lies in its ability to learn control policies directly from simulated state paths without requiring explicit knowledge of the underlying stochastic differential equations.
Instead of solving high-dimensional HJB PDEs, the neural network uses:
- Forward simulation to sample wealth trajectories under candidate policies.
- Backward evaluation to update network parameters based on performance (e.g., maximizing expected withdrawals, minimizing expected shortfall).
This model-free, data-driven method avoids dynamic programming and is especially useful in high dimensions, where solving PDEs becomes computationally infeasible.
Moreover, by designing appropriate activation functions (e.g., softmax for portfolio weights and sigmoid for withdrawal rates), the NN ensures that stochastic constraints are naturally respected throughout training and inference.
NN Approximation Setup
- The control policy [math]\displaystyle{ \mathcal{P} }[/math] is approximated using two feed-forward neural networks, with parameters [math]\displaystyle{ \boldsymbol{\theta}_q }[/math] and [math]\displaystyle{ \boldsymbol{\theta}_p }[/math], representing withdrawal and allocation strategies respectively.
- These networks take as inputs the Brownian motion path [math]\displaystyle{ W(t_i) }[/math] and time [math]\displaystyle{ t_i }[/math] to approximate control decisions:
[math]\displaystyle{ \hat{q}(W_i^-, t_i^-, \boldsymbol{\theta}_q) \approx q_i(W_i^-), \quad \hat{p}(W_i^+, t_i^+, \boldsymbol{\theta}_p) \approx p_i(W_i^+) }[/math]
- The final control policy is given by: [math]\displaystyle{ \hat{\mathcal{P}} = \{ (\hat{q}(\cdot), \hat{p}(\cdot)) \} \approx \mathcal{P} }[/math]
- The functions [math]\displaystyle{ \hat{p} }[/math] and [math]\displaystyle{ \hat{q} }[/math] use time as one of the inputs, allowing a single NN to handle decisions across all rebalancing points, rather than training separate models for each time step.
- The paper also discusses how the architecture includes activation functions that enforce stochastic constraints naturally.
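A minimal PyTorch-style sketch of this two-network setup is given below; the layer sizes, the withdrawal bounds, and the scalarized loss are illustrative assumptions rather than the authors' exact implementation, but the constraint-enforcing activations (sigmoid for the withdrawal, softmax for the allocation) follow the idea described above.

```python
import torch
import torch.nn as nn

class WithdrawalNet(nn.Module):
    """q-hat(W, t): maps (wealth, time) to a withdrawal in [q_min, q_max]."""
    def __init__(self, q_min=35.0, q_max=60.0, hidden=16):
        super().__init__()
        self.q_min, self.q_max = q_min, q_max
        self.net = nn.Sequential(nn.Linear(2, hidden), nn.Tanh(), nn.Linear(hidden, 1))

    def forward(self, wealth, t):
        x = torch.stack([wealth, t], dim=-1)
        # sigmoid keeps the withdrawal inside the admissible interval
        return self.q_min + (self.q_max - self.q_min) * torch.sigmoid(self.net(x)).squeeze(-1)

class AllocationNet(nn.Module):
    """p-hat(W, t): maps (wealth, time) to long-only portfolio weights summing to 1."""
    def __init__(self, n_assets=2, hidden=16):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(2, hidden), nn.Tanh(), nn.Linear(hidden, n_assets))

    def forward(self, wealth, t):
        x = torch.stack([wealth, t], dim=-1)
        # softmax enforces no-shorting and no-leverage (weights >= 0, sum to 1)
        return torch.softmax(self.net(x), dim=-1)

def scalarized_loss(total_withdrawals, W_T, kappa=1.0, alpha=0.05):
    """Negative of EW + kappa * ES on simulated paths (ES here is the mean of the
    worst alpha-fraction of terminal wealths; an illustrative proxy)."""
    k = max(1, int(alpha * W_T.numel()))
    worst = torch.topk(W_T, k, largest=False).values
    return -(total_withdrawals.mean() + kappa * worst.mean())
```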
Summaries of Key Notes
- Neural Network Framework for Pension Decumulation: A novel framework using neural networks (NNs) is proposed to optimize asset allocation and cash withdrawal strategies for defined contribution (DC) pension plans. Unlike traditional methods, it solves constraints efficiently via unconstrained optimization.
- Comparison with HJB Method: The NN approach achieves comparable accuracy to the Hamilton-Jacobi-Bellman (HJB) PDE method while being scalable to higher-dimensional problems and avoiding dynamic programming errors.
- Efficient Withdrawals: The NN framework closely approximates optimal "bang-bang" controls, effectively alternating between minimum and maximum withdrawals based on wealth, ensuring reliable pension decumulation.
- Robustness: Tested extensively on synthetic and historical market data, the NN solution adapts well and demonstrates strong out-of-sample and out-of-distribution performance.
Constructive Critique
While the NN method replicates HJB-derived control policies with high accuracy, a few limitations and caveats exist:
- The training relies on synthetic data simulated from assumed models (e.g., geometric Brownian motion, jump diffusion), which may limit generalizability under real-world dynamics.
- The scalarization of the reward-risk tradeoff assumes a linear weighting of Expected Withdrawals and Expected Shortfall. In practice, retiree preferences might reflect more complex, utility-based behaviors.
- The interpretability of the learned policy is less transparent compared to explicit, closed-form control strategies derived via HJB.
- No formal convergence analysis or approximation bounds are provided for the NN solution.
Despite these challenges, the method is empirically robust and scalable, making it an appealing alternative for large-scale or real-time applications.
Related Works
The work is also related to "Spending Retirement on Planet Vulcan: The Impact of Longevity Risk Aversion on Optimal Withdrawal Rates", which introduces utility-based models for withdrawal strategies under longevity risk aversion. The models focus on retirees' behavioral preferences — particularly longevity risk aversion and the need for smooth consumption throughout retirement.
While not PDE-based, it contextualizes the trade-offs between consumption smoothing and risk of depletion, complementing the (EW, ES) approach by addressing behavioral and utility-driven objectives.
On the other hand, this paper's NN benchmark takes a risk-sensitive stochastic control approach, optimizing a scalarized objective that balances:
Expected Withdrawals (EW) = proxy for consumption
Expected Shortfall (ES) = proxy for downside risk / depletion probability
Group 16 Presentation: Machine Learning and Hamilton-Jacobi-Bellman Equation for Optimal Decumulation: a Comparison Study
Presented by:
Zeyu Zhang
Paper Citation
Chen M, Shirazi M, Forsyth PA, Li Y. Machine Learning and Hamilton-Jacobi-Bellman Equation for Optimal Decumulation: a Comparison Study. Published online 2023. doi:10.48550/arxiv.2306.10582
Background & Motivation
The paper focuses on addressing a stochastic optimal control problem in retirement decumulation and asset allocation, which is a critical issue in financial planning. Specifically, it investigates how retirees can optimally withdraw funds from their savings (decumulation) while simultaneously managing asset allocation under uncertain market conditions.
Traditionally, rules of thumb such as the Bengen 4% Rule have been widely adopted in the financial industry to guide withdrawal strategies. However, these approaches are increasingly viewed as suboptimal, particularly in an environment characterized by volatile markets and evolving mortality patterns. Recent academic studies, such as Forsyth (2022), propose partial differential equation (PDE)-based methods that are provably convergent and optimal under specific assumptions. Nevertheless, PDE methods face significant limitations in scalability due to the curse of dimensionality, often performing well only in low-dimensional settings.
The motivation for this paper is to overcome the limitations of PDE-based approaches by leveraging neural networks (NNs) to solve the decumulation and asset allocation control problem. The authors aim to evaluate whether deep learning can accurately and robustly approximate the solution to this high-dimensional stochastic control problem, as well as whether it provides computational advantages.
Key Points
1. Problem Formulation: The paper formulates the decumulation problem as a stochastic optimal control problem, aiming to optimize a weighted sum of expected withdrawals (EW) and expected shortfall (ES) to effectively manage tail risk. Key constraints include minimum and maximum withdrawal limits, as well as no-shorting and no-leverage rules.
2. HJB Framework: The Hamilton-Jacobi-Bellman (HJB) approach employs dynamic programming to solve the problem numerically, providing a ground-truth benchmark for comparison. However, this method is computationally limited to low-dimensional problems and relies on parametric models for asset returns, which may not capture real-world complexities.
3. NN Framework: The proposed neural network (NN) framework directly approximates the control functions (withdrawal and allocation) using feed-forward networks with customized activation functions designed to enforce the specified constraints. This data-driven approach bypasses the need for dynamic programming and demonstrates scalability to higher-dimensional problems.
4. Comparative Results: On synthetic data, the NN solution achieves performance nearly identical to that of the HJB method, showcasing its high accuracy in approximating the optimal control policy, including complex "bang-bang" withdrawal strategies.
5. Robustness: The NN framework exhibits strong performance in out-of-sample and out-of-distribution tests, such as bootstrap-resampled historical data, thereby demonstrating its generalizability beyond the training distribution.
Contributions
- Demonstration of neural networks as reliable solvers for constrained stochastic control problems, which were previously addressable only through partial differential equations (PDEs).
- Quantitative benchmark comparisons between NN-based and PDE-based methods reveal near-equivalent accuracy, particularly in replicating key features such as the efficient frontier and optimal withdrawal behavior.
- The proposed approach is scalable to higher dimensions, unlike PDE-based methods, making it potentially transformative for real-world retirement planning problems that involve multiple assets or stochastic factors.
- The authors demonstrate that regularization within the NN framework helps mitigate instability in regions of the state space where the control problem becomes ill-posed (e.g., high wealth levels or near terminal time).
- The method provides continuous-time control outputs by explicitly incorporating time as an input to the network, ensuring smooth solutions when required.
Constructive Critiques
- Ill-Posed Regions: The NN and HJB solutions diverge in high-wealth regions near the terminal time due to the problem's ill-posedness. While the authors argue this has negligible impact on objectives, further analysis of how this affects real-world implementation would strengthen the paper.
- Training Complexity: The NN requires transfer learning for high κ values (weighting ES more heavily), suggesting potential instability in risk-averse scenarios. A deeper exploration of training challenges and solutions would be valuable.
- Historical Data Limitations: The bootstrap resampling tests rely on U.S. market data (1926–2019). Including non-U.S. data or stress-testing during extreme market conditions (e.g., hyperinflation) could enhance robustness claims.
- Computational Costs: While the NN avoids dynamic programming, the computational expense of training large networks is not quantified. A comparison of runtime between HJB and NN methods would clarify trade-offs.
Relationships to Other Works
This work builds on the stochastic control literature, particularly the decumulation problem studied in Forsyth (2022), which employs PDE-based methods. The current paper extends this research by providing a data-driven and high-dimensional alternative. It conceptually aligns with deep FBSDE methods, Deep Galerkin methods used for solving HJB equations, as well as reinforcement learning (RL)-based approaches to optimal control, such as the Deep Deterministic Policy Gradient. Compared to prior studies, such as Han and E (2016), Buehler et al. (2019), and Laurière et al. (2021), the current paper places emphasis on benchmarking against a well-established numerical method (PDEs), an aspect often overlooked in other NN-based control studies. The proposed method falls within the Policy Function Approximation (PFA) framework outlined in Powell (2021), providing a robust example of utilizing fixed neural networks to approximate control functions across time and state dimensions.
Group 18 Presentation: HIGEN: HIERARCHICAL GRAPH GENERATIVE NETWORKS
Presented by:
- Shiyu Zhu
- Jesse Xue
Paper Citation
M. Karami, “HiGen: Hierarchical Graph Generative Networks,” 2023, arXiv. doi: 10.48550/ARXIV.2305.19337.
Background
- Hierarchical or Multi-Scale Structure: Captures high level interactions or relationships between objects or groups, while also representing the lower level structures. One example is a company org chart.
- Existing graph generating models include: Variational Autoencoders, Generative Adversarial Networks, Autoregressive Models (GNN, Graph RNN, GRAN) and Diffusion Models
- Paper introduces HIGEN, Hierarchical Graph Generative Networks to address problems with existing models.
- Experiments were conducted on 5 datasets, each with increasing size and scale, using GraphRNN, GRAN, DiGress, GDSS, and SPECTRE as baselines.
Technical Contributions
- Related Hierarchical Methods: The presentation discusses several recent hierarchical methods in specific domains like chemistry, highlighting HiGen's broader applicability, multi-level approach, and parallel generation as advantages over these more specialized techniques. These include a multi-based generation for molecular graphs (2020) relying on domain knowledge, a hierarchical normalizing flow model (2021) based on local neighborhoods, and a tree decomposition framework (2022) limited to a single abstraction level and medium-sized graphs.
Definition: Hierarchical Graph
A Hierarchical Graph is a multi-level representation of a graph [math]\displaystyle{ \mathcal{G} = (\mathcal{V}, \mathcal{E}) }[/math] where:
- [math]\displaystyle{ \mathcal{V} }[/math] is the set of nodes (vertices), and [math]\displaystyle{ \mathcal{E} }[/math] is the set of edges, with sizes [math]\displaystyle{ n = |\mathcal{V}| }[/math] and [math]\displaystyle{ m = |\mathcal{E}| }[/math].
- A node partition function [math]\displaystyle{ \mathcal{F}: \mathcal{V} \rightarrow \{1, ..., c\} }[/math] groups nodes into [math]\displaystyle{ c }[/math] communities or clusters.
- Each cluster forms a subgraph [math]\displaystyle{ \mathcal{C}_i = (\mathcal{V}(\mathcal{C}_i), \mathcal{E}(\mathcal{C}_i)) }[/math] with adjacency matrix [math]\displaystyle{ A_i }[/math].
- Cross-links between communities form bipartite graphs [math]\displaystyle{ \mathcal{B}_{ij} = (\mathcal{V}(\mathcal{C}_i), \mathcal{V}(\mathcal{C}_j), \mathcal{E}(\mathcal{B}_{ij})) }[/math].
- Each cluster is aggregated into a super-node and each bipartite into a super-edge at the next higher level. This forms a coarser graph at the parent level.
Methodology
HiGen (Hierarchical Graph Generative Networks) adopts a coarse-to-fine approach for generating graphs with complex hierarchical structures. The method consists of the following components:
1. Hierarchical Graph Construction
The input graph is recursively partitioned into communities using the Louvain clustering algorithm, forming a multi-level hierarchy. Each cluster (community) is abstracted as a "super-node" at the next level, and cross-community connections become "super-edges."
2. Community and Bipartite Generation
* Community Generation: Each community is generated independently using a GNN-based autoregressive model that factorizes the edge distribution into multinomial components. A mixture of multinomials is used to model intra-community edge structures.
* Bipartite Generation: Once communities are generated, bipartite graphs (cross-community edges) are created using a separate neural network. These inter-cluster edges are predicted with a parallel GNN model using a similar factorization strategy.
3. Autoregressive Probabilistic Modeling
HiGen decomposes the joint probability of the graph into a product of conditional multinomial distributions. Both community and bipartite graphs are generated step-by-step using a recursive generation process guided by node and cluster embeddings.
4. Parallelism and Invariance
The hierarchical structure allows parallel generation of communities and bipartite edges at the same level, improving efficiency. The model is also invariant to node ordering within clusters, which improves robustness and scalability.
This design enables HiGen to generate large and complex graphs while maintaining global structure and fine-grained local details. It supports realistic synthesis across diverse graph types including molecules, biological networks, and social systems.
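A minimal sketch of the hierarchy-construction step (Louvain partitioning followed by aggregation into super-nodes and super-edges) using NetworkX; the aggregation rule shown here is a simplified illustration, not the authors' code.

```python
import networkx as nx
from networkx.algorithms.community import louvain_communities

def coarsen(G):
    """Partition G into communities and build the parent-level graph.

    Each community becomes a super-node; the number of cross-community
    edges becomes the weight of the corresponding super-edge.
    """
    communities = louvain_communities(G, seed=0)           # list of node sets
    label = {v: i for i, comm in enumerate(communities) for v in comm}

    parent = nx.Graph()
    parent.add_nodes_from(range(len(communities)))
    for u, v in G.edges():
        cu, cv = label[u], label[v]
        if cu != cv:
            w = parent.get_edge_data(cu, cv, {"weight": 0})["weight"]
            parent.add_edge(cu, cv, weight=w + 1)
    return communities, parent

# Example: build a two-level hierarchy of a random graph
G = nx.erdos_renyi_graph(200, 0.05, seed=0)
communities, parent = coarsen(G)
```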
Summaries of Key Notes
- HiGen: Hierarchical Graph Generative Networks (HiGen) is a novel graph generation model that addresses the limitations of existing methods by leveraging hierarchical structures in graphs. It generates substructures in a coarse-to-fine manner, modeling the interactions between communities and cross-edges at multiple levels of abstraction.
- Community and Bipartite Generation: HiGen utilizes modular generation, creating communities in parallel followed by predictions of cross-edges with separate neural networks, ensuring scalability and efficiency. It employs multinomial distributions with recursive factorization, enabling autoregressive generation of integer-valued edge weights within communities.
- State-of-the-Art Performance: HiGen demonstrates superior graph quality across benchmark datasets compared to competing models like GRAN, GraphRNN, GDSS, and diffusion-based methods. It captures both local and global graph statistics effectively, achieving state-of-the-art metrics for graph degree, clustering coefficients, and Laplacian spectra.
- Scalability: The hierarchical structure of HiGen facilitates parallel graph generation, reduces sensitivity to node ordering, and enables block-wise processing of adjacency matrices, making it adaptable to large and complex graphs. Applications: HiGen excels in generating realistic graphs for applications in molecular modeling, protein structure analysis, data network synthesis, and more, addressing diverse domains where hierarchical graph structures are pivotal.
Results
- Experiments are conducted on 5 datasets (e.g., SBM, Protein, Enzyme, Ego) with increasing size and complexity.
- Evaluation metrics cover local (degree distribution, clustering coefficient distribution, orbit counts) and global (graph spectra) graph statistics.
- Observation: Compared to baselines like GraphRNN, GRAN, and SPECTRE, HiGen consistently achieves state-of-the-art performance, especially in capturing graph motifs and global structure.
Group 18 Presentation: HIGEN: HIERARCHICAL GRAPH GENERATIVE NETWORKS
Presented by:
- Shiyu Zhu
- Jesse Xue
Paper Citation
M. Karami, “HiGen: Hierarchical Graph Generative Networks,” 2023, arXiv. doi: 10.48550/ARXIV.2305.19337.
Introduction
"HiGen: Hierarchical Graph Generative Networks" by Mahdi Karma introduces a novel approach to graph generation that places emphasis on the hierarchal structures in many inherent real-world graphs.
Key Contributions
1). Hierarchical Graph Generation
HiGen employs a coarse-to-fine strategy to generate graphs, effectively capturing their hierarchical nature. This method involves generating sub-structures at multiple levels, which enhances the model's ability to reflect the inherent organization of complex graphs.
2). Parallel Community Generation
At every hierarchical level, HiGen generates communities in parallel. This is followed by the prediction of cross-edges between these communities using separate neural networks. Such a modular approach enables scalable graph generation for large and complicated graphs.
3). Multinomial Edge Distribution Modelling
The model utilizes a multinomial distribution to represent the output distribution of edges within the graph. A recursive factorization of this distribution is employed, allowing HiGen to autoregressively generate community graphs with integer-valued edge weights. This improves the realism and accuracy of the generated graphs.
Performance Outcomes
Studies demonstrate the effectiveness and scalability of HiGen, allowing it to achieve state-of-the-art graph quality across a variety of benchmark datasets. The modular design and hierarchical generation process contribute to its ability to generate large and complex graphs efficiently.
Advancements in the Field
HiGen advances graph generative modelling by explicitly incorporating hierarchical structures into the generation process. This addresses a limitation of existing methods, which often overlook the multi-level organization of real-world graphs, and thereby enhances the fidelity and applicability of generated graphs across various domains.
Conclusion
The methodologies and findings presented in the HiGen: Hierarchical Graph Generative Networks paper offer contributions to the field of graph generation. By introducing a hierarchical and modular approach, HiGen sets a new benchmark for generating complex graphs that accurately reflect structures observed in the real world.
Group 20 Gated linear attention transformers with hardware efficient training
Reference
arXiv:2312.06635
Background
The paper discusses Gated Linear Attention (GLA) Transformers, addressing the computational inefficiency of traditional transformers with softmax attention. Regular transformers have quadratic computational complexity in sequence length, which becomes extremely expensive for long sequences, while naive linear attention mechanisms lead to underperforming models.
Technical Contributions
It proposes using a linear kernel as an alternative to the softmax function, which allows attention to be formulated as a linear RNN with 2D hidden states. The key innovations include:
1. Introducing a data-dependent gating mechanism to improve model performance, which allows the model to forget past information adaptively
2. Developing a linear attention approach that reduces computational complexity
3. Creating a hardware-efficient training method (FLASHLINEARATTENTION Algorithm) that can handle long sequences more effectively
The main goal was to create a more efficient transformer model that can:
- Reduce computational expenses
- Maintain competitive performance across different tasks
- Handle long sequences more effectively
- Leverage modern GPU architectures for improved training and inference
The approach addresses the fundamental challenge of making transformer models more scalable and computationally efficient, particularly for tasks involving long sequences like processing books, dialogues, or complex scientific texts.
Results
The results and conclusions of the paper showed:
Performance Results:
- For the 340 million parameter model:
  - Achieved competitive performance
  - Close to Transformer performance
  - Slightly better than or comparable to RetNet
  - Slightly below Mamba on some tasks
- For the 1.3 billion parameter model:
  - Beat most benchmarks in average accuracy
  - Slightly behind Transformer++ in perplexity
  - Showed impressive accuracy across tasks
Key Findings:
1. Gating mechanism is crucial for model performance
- Removing it significantly increased perplexity
- Data-dependent scalar decay improved results
2. Recall-intensive tasks:
- Smaller model: the Transformer still led
- Larger model: GLA closed the performance gap considerably
- Competitive with Mamba and RetNet
3. Computational Efficiency:
- Higher training throughput for larger batch sizes
- Slight increase in GPU memory usage
- More efficient for bigger batches
Conclusions:
- GLA is highly effective for handling long sequences
- Hardware-efficient design reduces computational costs
- The gating mechanism significantly enhances model performance
- A promising approach for making transformers more accessible and efficient
- An efficient replacement for softmax attention in Transformers
The paper suggests future research should focus on optimizing the balance between performance and efficiency.
Linear Attention
Transformers traditionally use softmax attention, which scales poorly with sequence length due to quadratic complexity. Linear attention approximates softmax with kernel-based attention mechanisms, reducing this cost.
Parallel and Recurrent Forms
- Parallel Form: computes full attention using
[math]\displaystyle{ \mathbf{O} = \text{softmax}((\mathbf{QK}^\top) \odot \mathbf{M}) \mathbf{V} }[/math]
This enables efficient training with full-sequence inputs.
- Recurrent Form: used during inference; processes token-by-token with
[math]\displaystyle{ \mathbf{o}_t = \frac{\sum_{i=1}^{t} \phi(\mathbf{q}_t) \phi(\mathbf{k}_i)^\top \mathbf{v}_i}{\sum_{i=1}^{t} \phi(\mathbf{q}_t) \phi(\mathbf{k}_i)^\top} }[/math]
- Using [math]\displaystyle{ \phi(x) = x }[/math] and removing normalization yields the simplified linear attention update:
[math]\displaystyle{ \mathbf{S}_t = \mathbf{S}_{t-1} + \mathbf{k}_t^\top \mathbf{v}_t, \quad \mathbf{o}_t = \mathbf{q}_t \mathbf{S}_t }[/math]
Chunkwise Parallel Linear Attention
The chunkwise parallel form balances between full parallelism and full recurrence, enabling faster training on long sequences.
- Splits input [math]\displaystyle{ \mathbf{X} }[/math] into chunks of length [math]\displaystyle{ C }[/math].
- Inter-chunk state update:
[math]\displaystyle{ \mathbf{S}_{[i+1]} = \mathbf{S}_{[i]} + \sum_{j=iC+1}^{(i+1)C} \mathbf{k}_j^\top \mathbf{v}_j }[/math]
- Intra-chunk output:
[math]\displaystyle{ \mathbf{O}_{[i+1]} = \mathbf{Q}_{[i+1]} \mathbf{S}_{[i]} + \left((\mathbf{Q}_{[i+1]} \mathbf{K}_{[i+1]}^\top) \odot \mathbf{M}\right) \mathbf{V}_{[i+1]} }[/math]
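To make the chunkwise form concrete, here is a minimal NumPy sketch of unnormalized, ungated chunkwise linear attention. Function and variable names are illustrative assumptions and do not correspond to the paper's released FLASHLINEARATTENTION kernels.
<pre>
import numpy as np

def chunkwise_linear_attention(Q, K, V, C):
    """Unnormalized causal linear attention computed chunk by chunk.

    Q, K: (L, d_k) arrays; V: (L, d_v); C must divide L.
    Equivalent to o_t = q_t @ sum_{i <= t} k_i^T v_i, split into an
    inter-chunk state update and an intra-chunk masked product.
    """
    L, d_k = Q.shape
    d_v = V.shape[1]
    S = np.zeros((d_k, d_v))              # inter-chunk state S_[i]
    O = np.zeros((L, d_v))
    for s in range(0, L, C):
        q, k, v = Q[s:s+C], K[s:s+C], V[s:s+C]
        M = np.tril(np.ones((C, C)))      # causal mask within the chunk
        intra = ((q @ k.T) * M) @ v       # intra-chunk contribution
        O[s:s+C] = q @ S + intra          # inter-chunk + intra-chunk output
        S = S + k.T @ v                   # inter-chunk state update
    return O
</pre>
Setting C = 1 recovers the recurrent form and C = L the parallel form, matching the complexity discussion in the Benefits section below.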
Swish Activation function (SwiGLU)
One notable component of the GLA model’s design is its use of the Swish activation function (and the derived SwiGLU gating unit) in key parts of the network. Swish, defined as [math]\displaystyle{ \text{Swish}(x) = x \cdot \sigma(x) }[/math] (where [math]\displaystyle{ \sigma }[/math] is the sigmoid), is a smooth, non-monotonic activation known to often outperform ReLU/GELU in deep networks. In this paper, Swish is employed in two main places: (1) the feed-forward network (FFN) layers, where the authors adopt the SwiGLU formulation, and (2) the computation of the data-dependent gates in the attention mechanism.
1) FFN
The function's smooth gradient and ability to yield non-zero outputs even for negative inputs help with optimization and expressiveness. In summary, using SwiGLU in each Transformer block’s FFN is an architectural choice that boosts performance per parameter, making the overall model more competitive with standard Transformers.
2) Gating mechanism
Swish is self-gating. When [math]\displaystyle{ x_t W_r }[/math] is large and positive, Swish outputs a large value (roughly linear in its input), but when [math]\displaystyle{ x_t W_r }[/math] is around zero or negative, Swish outputs a small value (tending toward zero for large negative inputs). This means [math]\displaystyle{ r_t }[/math] tends to suppress tokens the model deems less important (yielding near-zero values for those features) while allowing strong signals to pass through. A standard sigmoid gate could also suppress features (outputting values in 0-1), but it saturates at 1 for any sufficiently large input, effectively capping the influence of very important features. Swish, by contrast, does not saturate to a hard limit: for important inputs it can output values greater than 1 (since [math]\displaystyle{ r_t \approx x_t W_r }[/math] when [math]\displaystyle{ x_t W_r }[/math] is large), thereby allowing an amplification effect. This gives the model more flexibility than a sigmoid-gated GLU: small signals are squashed (multiplied by a small fraction), while strong signals can be propagated in full or even amplified (almost identity for large positive inputs). This property can be crucial for modeling, for example, rare key tokens that should strongly influence attention: Swish gating lets those contributions through largely unattenuated, whereas a sigmoid gate would bottleneck at 1.
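As a rough illustration of the two uses of Swish described above, the PyTorch sketch below implements a SwiGLU feed-forward block and a Swish output gate. Layer names and shapes are assumptions for illustration rather than the exact configuration used in the paper.
<pre>
import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLUFFN(nn.Module):
    """Feed-forward block using the SwiGLU formulation: (Swish(x W_g) * x W_u) W_d."""
    def __init__(self, d_model, d_hidden):
        super().__init__()
        self.w_gate = nn.Linear(d_model, d_hidden, bias=False)
        self.w_up = nn.Linear(d_model, d_hidden, bias=False)
        self.w_down = nn.Linear(d_hidden, d_model, bias=False)

    def forward(self, x):
        return self.w_down(F.silu(self.w_gate(x)) * self.w_up(x))  # F.silu == Swish

class SwishOutputGate(nn.Module):
    """Swish gate r_t = Swish(x_t W_r) applied multiplicatively to the attention output."""
    def __init__(self, d_model):
        super().__init__()
        self.w_r = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x, attn_out):
        # small inputs are squashed toward zero, large ones pass through (or are amplified)
        return F.silu(self.w_r(x)) * attn_out
</pre>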
Benefits
- Time complexity: [math]\displaystyle{ \mathcal{O}(LCd + Ld^2) }[/math], which is sub-quadratic.
- [math]\displaystyle{ C = 1 }[/math] recovers the recurrent form; [math]\displaystyle{ C = L }[/math] recovers the parallel form.
- Efficient and scalable to long sequences with minimal performance loss.
Future Work
1. Future hardware-aware optimization: balance between efficiency and performance.
2. Application to other data: the potential of applying GLA to image, video, or scientific data.
3. Test how GLA performs on larger models: due to computational limitations, the experiments were conducted on moderate-scale models.
Summaries of Key Points
- Gated Linear Attention (GLA) Transformer is a novel architecture that combines the efficiency of linear attention with data-dependent gating mechanisms to improve performance in sequence modeling tasks.
- FLASHLINEARATTENTION is introduced as a hardware-efficient implementation of linear attention, outperforming FLASHATTENTION-2 in speed, even for short sequences (e.g., 1K tokens).
- GLA Transformer enhances length generalization, allowing models trained on 2K sequences to generalize to sequences longer than 20K without significant performance degradation.
- The model is competitive with state-of-the-art architectures, including LLaMA Transformers and linear-time inference models like RetNet and Mamba, particularly in moderate-scale language modeling tasks.
- GLA Transformer achieves higher training throughput compared to similarly sized Mamba models while maintaining efficient long-context processing.
Group 20 Gated linear attention transformers with hardware efficient training
Presented by:
- Felix Jean
- Maxime Bouthilier
- Thomas Hudon
Paper Citation
S. Yang, B. Wang, Y. Shen, R. Panda & Y. Kim, “Gated linear attention transformers with hardware efficient training,” 2024, arXiv:2312.06635
Background & Motivation
The paper tackles the limitations of traditional softmax attention in Transformers, which, despite enabling efficient parallel training, exhibits quadratic complexity with respect to sequence length, rendering it impractical for long sequences. Linear attention has emerged as a promising alternative, providing linear-time inference by reformulating attention as a recurrent neural network (RNN) with 2D hidden states. However, in practice, linear attention often underperforms compared to softmax attention, and existing implementations lack I/O-awareness, leading to slower speeds relative to optimized softmax attention implementations such as FlashAttention-2. The authors identify two critical gaps: (1) the absence of hardware-efficient algorithms for linear attention that effectively balance memory movement and parallelizability, and (2) the lack of data-dependent gating mechanisms in linear attention, which are essential for achieving high performance in RNNs. These gaps motivate the development of FlashLinearAttention and the gated linear attention (GLA) Transformer.
Key Points
The paper introduces FlashLinearAttention, an I/O-aware and hardware-efficient algorithm for linear attention that optimizes memory movement and parallelizability. It achieves faster speeds than FlashAttention-2, even on short sequences (e.g., 1K tokens). The authors further extend this algorithm to Gated Linear Attention (GLA), which incorporates data-dependent gates to enhance model expressiveness. GLA preserves the linear-time inference property while improving performance across a range of tasks. Additionally, the paper proposes a chunkwise parallel formulation for GLA, enabling efficient training by dividing sequences into chunks and balancing inter-chunk and intra-chunk computations. Experimental results demonstrate that the GLA Transformer performs competitively against LLaMA-architecture Transformers and recent linear-time models such as RetNet and Mamba, particularly excelling in length generalization and recall-intensive tasks.
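To illustrate the role of the data-dependent gate, here is a minimal PyTorch sketch of a single recurrent inference step of gated linear attention. The diagonal gate parameterization and the projection names (W_q, W_k, W_v, W_alpha) are simplifying assumptions; the paper uses a more hardware-friendly low-rank parameterization and a chunkwise form for training.
<pre>
import torch

def gla_recurrent_step(S, x_t, W_q, W_k, W_v, W_alpha):
    """One token-by-token inference step of (simplified) gated linear attention.

    S:   (d_k, d_v) hidden state carried across time steps
    x_t: (d_model,) current input token representation
    The gate alpha_t in (0, 1)^{d_k} decays each row of the state, letting the
    model forget past information in a data-dependent way.
    """
    q_t = x_t @ W_q                          # (d_k,)
    k_t = x_t @ W_k                          # (d_k,)
    v_t = x_t @ W_v                          # (d_v,)
    alpha_t = torch.sigmoid(x_t @ W_alpha)   # (d_k,) data-dependent forget gate
    S = alpha_t.unsqueeze(1) * S + torch.outer(k_t, v_t)  # S_t = diag(alpha_t) S_{t-1} + k_t^T v_t
    o_t = q_t @ S                            # (d_v,) output for the current token
    return o_t, S
</pre>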
Contributions
- FlashLinearAttention: A hardware-efficient algorithm for linear attention that outperforms FlashAttention-2 in speed and memory efficiency.
- Gated Linear Attention (GLA): A novel linear attention variant with data-dependent gates, offering better performance and stability.
- Chunkwise Parallel Form: A training-friendly formulation of GLA that enables efficient parallelization and scalability.
- Empirical Validation: Demonstrates competitive performance against strong baselines, including LLaMA, RetNet, and Mamba, with notable strengths in length generalization and recall tasks.
- Open-source Implementation: The release of FlashLinearAttention as a practical tool for the community.
Constructive Critiques
- Scalability: Although the experiments are conducted at moderate scales (up to 1.3B parameters), it remains unclear how GLA would perform at larger scales (e.g., 7B+ parameters). The authors hypothesize that GLA’s efficiency would further improve at such scales, but this claim requires empirical validation.
- Generalization to Other Modalities: The current focus is on language modeling; however, extending GLA to other domains (e.g., vision or audio) could potentially broaden its applicability and impact.
- Complexity of Implementation: The secondary-level chunking and materialization strategies introduce additional complexity. Providing a more streamlined implementation or conducting ablation studies could help users better understand the associated trade-offs.
- Comparison to Hybrid Models: While the paper compares GLA to pure linear-time models (e.g., Mamba) and softmax attention, hybrid approaches that combine linear and sparse attention are not explored. Such comparisons could provide deeper insights into GLA's relative strengths and limitations.
Relationships to Other Works
Linear Attention extends prior work by Katharopoulos et al. (2020) and Sun et al. (2023a) by introducing data-dependent gates and hardware optimizations. Hardware-Efficient Attention follows the spirit of FlashAttention (Dao et al., 2022b) but adapts it for linear attention, addressing unique challenges such as chunkwise parallelism. Gated RNNs draws inspiration from gated RNNs (e.g., LSTMs, Mamba) but adapts the gating mechanism for linear attention’s 2D hidden states. Length Generalization complements recent efforts like RetNet and Mamba-2, offering a new solution for extrapolating beyond training lengths.
Group 23 Presentation: Discrete Diffusion Modelling By Estimating the Ratios of the Data Distribution
Presented By
Chenxin Lyu, Yixuan Zeng
Paper Citation
A. Lou, C. Meng, and S. Ermon, ‘Discrete Diffusion Modeling by Estimating the Ratios of the Data Distribution’, Jun. 06, 2024, arXiv: arXiv:2310.16834. doi: 10.48550/arXiv.2310.16834.
https://arxiv.org/abs/2310.16834
Background
- Diffusion models have shown great performance in generative artificial intelligence when applied to domains with continuous data
- Diffusion models are more difficult to implement for data in the discrete domain, such as tokenized text
- Prior attempts at applying diffusion to text generation have performed worse than autoregressive models
Paper Contributions
- Developed a method called Score Entropy Discrete Diffusion (SEDD)
- Parameterizes the diffusion process for discrete data using data distribution ratios, rather than dealing with the tokenized data directly
- SEDD
SEDD is a framework for discrete diffusion modeling that learns to generate data by estimating probability ratios between neighboring discrete states. In the paper, SEDD forms the core modeling strategy that enables diffusion-based generation for discrete data, achieving perplexities competitive with autoregressive baselines.
- Implicit Score Entropy Loss
[math]\displaystyle{ L_{ISE} = \mathbb{E}_{x \sim p} \sum_{y \neq x} \left( w_{xy}\, s_\theta(x)_y - w_{yx} \log s_\theta(y)_x \right) }[/math]
Implicit Score Entropy Loss is a novel training objective designed to learn the ratio function [math]\displaystyle{ s_\theta(x, t) = p_t(y)/p_t(x) }[/math] without requiring access to the true data distribution. In the paper, it makes the score entropy computationally more efficient, as it avoids the explicit dependence on [math]\displaystyle{ p(y)/p(x) }[/math] ratio.
It allows one to not compute partition functions or normalize over large vocabularies. The implicit score entropy loss makes the method scalable and practical for high-dimensional or categorical data (like text). It also connects naturally to energy-based modeling where exact densities are intractable but ratios or unnormalized scores can be learned.
It’s worth noting that the score entropy loss also connects to maximum likelihood. The authors show it can be used to derive an evidence lower bound (ELBO) for likelihood training
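For intuition, the sketch below is a naive implementation of the displayed implicit score entropy loss on a tiny discrete state space. It is purely illustrative: the callable s_theta, the weight matrix w, and the brute-force double loop are assumptions, and the paper's actual training uses the denoising variant computed far more efficiently.
<pre>
import torch

def implicit_score_entropy(s_theta, x_batch, w):
    """Monte Carlo estimate of L_ISE on a toy discrete space.

    s_theta: callable mapping a batch of states (B,) to positive ratio estimates (B, N),
             where s_theta(x)[y] approximates p(y) / p(x).
    x_batch: (B,) long tensor of samples drawn from the data distribution p.
    w:       (N, N) tensor of non-negative transition weights w[x, y].
    """
    s_x = s_theta(x_batch)                 # (B, N): ratios from each x to every y
    loss = torch.tensor(0.0)
    N = w.shape[0]
    for b in range(len(x_batch)):
        x = int(x_batch[b])
        for y in range(N):
            if y == x:
                continue
            s_y_x = s_theta(torch.tensor([y]))[0, x]   # ratio estimate from y back to x
            loss = loss + w[x, y] * s_x[b, y] - w[y, x] * torch.log(s_y_x)
    return loss / len(x_batch)
</pre>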
Discrete Diffusion Processes
- Models probability distributions over a finite discrete space [math]\displaystyle{ \mathcal{X} = \{1, \ldots, N\} }[/math], using probability mass vectors [math]\displaystyle{ p_t \in \mathbb{R}^N }[/math].
- Evolution of [math]\displaystyle{ p_t }[/math] follows a linear ODE:
[math]\displaystyle{ \frac{dp_t}{dt} = Q_t p_t,\quad p_0 \approx p_{\text{data}} }[/math]
- [math]\displaystyle{ Q_t }[/math] is a diffusion matrix with non-negative off-diagonal entries and column sums equal to 0 (mass is preserved).
- Often simplified as [math]\displaystyle{ Q_t = \sigma(t) Q }[/math], driving [math]\displaystyle{ p_t }[/math] toward a base distribution as [math]\displaystyle{ t \to \infty }[/math].
- Simulated using Euler steps with small [math]\displaystyle{ \Delta t }[/math]. Transition probability:
[math]\displaystyle{ p(x_{t+\Delta t} = y \mid x_t = x) = \delta_{xy} + Q_t(y, x) \Delta t + O(\Delta t^2) }[/math]
- Time Reversal: the reverse process uses another matrix [math]\displaystyle{ \overline{Q}_t }[/math] with:
[math]\displaystyle{ \overline{Q}_t(y, x) = \frac{p_t(y)}{p_t(x)} Q_t(x, y) }[/math]
Reverse ODE: [math]\displaystyle{ \frac{dp_{T-t}}{dt} = \overline{Q}_{T-t} p_{T-t} }[/math]
- This connects to the concrete score, generalizing the score function [math]\displaystyle{ \nabla_x \log p_t }[/math].
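To make the forward process concrete, here is a small NumPy sketch that evolves a probability mass vector under [math]\displaystyle{ \frac{dp_t}{dt} = \sigma(t) Q p_t }[/math] with Euler steps, using a uniform transition matrix as the example. The setup and names are illustrative only; the reverse-time process would additionally require the learned ratios [math]\displaystyle{ p_t(y)/p_t(x) }[/math].
<pre>
import numpy as np

def simulate_discrete_diffusion(p0, Q, sigma, T, n_steps):
    """Evolve a probability mass vector p_t under dp_t/dt = sigma(t) Q p_t via Euler steps.

    p0:    (N,) initial probability mass vector (approx. the data distribution)
    Q:     (N, N) rate matrix with non-negative off-diagonals and zero column sums
    sigma: callable noise schedule, sigma(t) >= 0
    """
    dt = T / n_steps
    p = p0.copy()
    for step in range(n_steps):
        t = step * dt
        p = p + dt * sigma(t) * (Q @ p)   # forward Euler update
        p = np.clip(p, 0, None)
        p = p / p.sum()                   # re-normalize to counter numerical drift
    return p

# Example: uniform transition matrix on N states, Q = (1/N) * ones - I
N = 5
Q = np.ones((N, N)) / N - np.eye(N)
p0 = np.eye(N)[0]                         # point mass on state 0
p_T = simulate_discrete_diffusion(p0, Q, sigma=lambda t: 1.0, T=5.0, n_steps=500)
# p_T is close to the uniform base distribution
</pre>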
Summaries of Key Points
- SEDD (Score Entropy Discrete Diffusion models) is a novel approach to discrete diffusion modeling that bridges the gap between diffusion models and autoregressive language models. It introduces score entropy, a new loss function that extends score matching to discrete spaces, improving performance in discrete generative modeling tasks.
- SEDD significantly improves performance in language modeling, reducing perplexity by 25-75% compared to previous discrete diffusion models and outperforming GPT-2 in certain tasks.
- Key advantages of SEDD over traditional autoregressive models:
- Higher generative quality without requiring distribution annealing techniques such as temperature scaling.
- Computational efficiency, enabling similar output quality with up to 32× fewer network evaluations.
- Enhanced controllability, allowing for flexible text infilling beyond left-to-right prompting, while maintaining quality comparable to nucleus sampling.
- SEDD models discrete data by parameterizing a reverse discrete diffusion process using ratios of the data distribution, making them more effective in capturing language structure.
- The method challenges the dominance of autoregressive transformers by offering an alternative with better trade-offs between efficiency, control, and generation quality in discrete text modeling.
Results
- Perplexity Evaluation: SEDD outperforms previous diffusion models in terms of perplexity across multiple language modeling benchmarks. This suggests that SEDD models the underlying data distribution more accurately, improving likelihood estimation.
- Unconditional Generation: SEDD generates high-quality samples without any input conditioning, achieving performance comparable to GPT-2. It does so with 32× fewer network evaluations, indicating higher sampling efficiency and reduced computational cost.
- Conditional Generation (Infill Tasks): SEDD is tested on tasks where the model must generate missing parts of text conditioned on surrounding context. Despite lacking the autoregressive biases that normally boost performance in such tasks, SEDD remains competitive with strong baselines like GPT-2 and SSD-LM. This highlights SEDD’s ability to generalize well to complex discrete generation tasks without relying on token-by-token prediction.
Group 23 Presentation: Discrete Diffusion Modelling By Estimating the Ratios of the Data Distribution
Presented By
Chenxin Lyu, Yixuan Zeng
Paper Citation
A. Lou, C. Meng, and S. Ermon, ‘Discrete Diffusion Modeling by Estimating the Ratios of the Data Distribution’, Jun. 06, 2024, arXiv: arXiv:2310.16834. doi: 10.48550/arXiv.2310.16834.
https://arxiv.org/abs/2310.16834
Background & Motivation
The paper tackles a critical gap in generative modeling for discrete data, particularly within the domain of natural language processing (NLP). While diffusion models have achieved remarkable success in continuous domains such as image generation, their performance on discrete data (e.g., text) has fallen short compared to autoregressive models, which currently dominate the field. The authors pinpoint the root cause as the absence of a principled and scalable framework for discrete score matching—the foundational theory underlying continuous diffusion models. Existing approaches, such as mean prediction and ratio matching, exhibit theoretical and empirical limitations, including instability, inefficiency, and suboptimal performance. Motivated by these challenges, the paper introduces Score Entropy Discrete Diffusion (SEDD), a novel method that extends score matching to discrete spaces by estimating probability ratios of the data distribution. This approach seeks to close the performance gap between autoregressive and diffusion-based language models while addressing key challenges associated with slow sampling, limited controllability, and stringent annealing requirements in autoregressive models.
Key Points
1. Score Entropy Loss: The key innovation lies in a novel loss function, score entropy, which extends score matching to discrete spaces by modeling the ratios of the data distribution. This ensures positivity, scalability, and theoretical consistency, thereby addressing the limitations of prior methods such as concrete score matching (which inadequately penalizes negative values).
2. Discrete Diffusion Framework: SEDD parameterizes the reverse diffusion process using learned probability ratios, enabling efficient sampling and likelihood-based training. The framework supports token-level transitions via structured matrices (e.g., uniform or absorbing transitions), facilitating the handling of high-dimensional sequences.
3. Empirical Superiority: SEDD outperforms existing discrete and continuous diffusion models on language tasks, reducing perplexity by 25–75% and matching or surpassing GPT-2 in zero-shot perplexity. It also achieves significantly higher-quality unconditional generation (6–8× better generative perplexity than un-annealed GPT-2) and flexible conditional generation (e.g., infilling).
4. Practical Benefits: The model provides a favorable compute-quality trade-off (e.g., achieving GPT-2 quality with 32× fewer steps), eliminates the need for annealing techniques like temperature scaling, and enables controllable infilling without specialized training.
Contributions
- Theoretical: Introduces score entropy, a loss function that generalizes score matching to discrete spaces while ensuring positivity and scalability, with rigorous proofs of consistency and tractability (e.g., the denoising score entropy variant).
- Methodological: Develops SEDD, a discrete diffusion framework that integrates score entropy with token-level transitions via structured matrices, enabling efficient training and sampling. The Tweedie-inspired τ-leaping sampling strategy further enhances performance in practical scenarios.
- Empirical: Demonstrates state-of-the-art results on language modeling benchmarks (e.g., text8, One Billion Words) and generation tasks (both unconditional and conditional), outperforming autoregressive baselines in key metrics such as perplexity. The model’s flexibility in infilling and its favorable compute-quality trade-offs represent significant advancements in the field.
Constructive Critiques
- Complexity: The reliance on matrix exponentials (e.g., for token transitions) may limit scalability to larger vocabularies or more complex structures (e.g., graphs).
- Generalization: While SEDD excels in language, its applicability to other discrete domains (e.g., molecules, code) remains untested.
- Training Cost: The paper notes SEDD’s parameter count is slightly higher than GPT-2, but the computational overhead of diffusion training versus autoregressive training is not thoroughly compared.
Relationships to Other Works
The SEDD model advances key areas of generative modeling by improving upon prior discrete diffusion approaches such as D3PM (Austin et al., 2021) and the continuous-time framework of Campbell et al. (2022). It replaces their mean prediction objectives with ratio estimation, addressing limitations in stability and continuous-time approximation. Compared to continuous diffusion models like Diffusion-LM (Li et al., 2022) and PLAID (Gulrajani & Hashimoto, 2023), SEDD achieves better performance in likelihood estimation and generation quality without requiring heuristic annealing techniques. The work also generalizes score matching methods, extending Hyvarinen's original score matching (2005) and concrete score matching (Meng et al., 2022) to discrete domains through its score entropy formulation. While not yet reaching the scale of modern autoregressive models, SEDD competes effectively with autoregressive baselines like GPT-2 in flexible generation tasks (e.g., infilling) and computational efficiency. Its success highlights the potential of combining SEDD with recent advances such as self-conditioning (Strudel et al., 2022) to further close the gap with autoregressive models.
Group 24 Presentation: Mitigating the Missing Fragmentation Problem in De Novo Peptide Sequencing With A Two-Stage Graph-Based Deep Learning Model
Presenters
Zi Hua Xu, Zehao Zhang
Paper Citation
Mao, Z., Zhang, R., Xin, L. et al. Mitigating the missing-fragmentation problem in de novo peptide sequencing with a two-stage graph-based deep learning model. Nat Mach Intell 5, 1250–1260 (2023). https://doi.org/10.1038/s42256-023-00738-x
https://www.nature.com/articles/s42256-023-00738-x#citeas
Background
- Proteins are crucial for biological functions
- Proteins are formed from peptides which are sequences of amino acids
- Mass spectrometry is used to analyze peptide sequences
- De Novo sequencing is used to piece together peptide sequences when the sequences are missing from existing established protein databases
- Deep learning has become commonly implemented to solve the problem of de novo peptide sequencing
- When a peptide fails to fragment in the expected manner, it can make protein reconstruction difficult due to missing data
- One error in the reconstruction can propagate to errors throughout the entire sequence
Paper Contributions
- GraphNovo was developed to handle incomplete fragmentation
- GraphNovo-PathSearcher, instead of directly predicting amino acids, uses a path-search approach to predict the next fragment node in the sequence
- A graph neural network is used to find the best path through the graph generated from the mass spectrometry input
- GraphNovo-SeqFiller then fills in the exact amino acid composition of each mass gap along the predicted path
- Since some fragment peaks may be missing, SeqFiller uses a transformer to add in amino acids that could not be resolved by PathSearcher
- Input is mass spectrum from mass spectrometry
- Graph construction is done where nodes represent possible fragment (prefix) masses, and edges represent possible amino acids or amino acid combinations connecting them (PathSearcher module)
- PathSearcher uses machine learning to find the optimal path on the generated graph
- SeqFiller fills in missing amino acids that may have not been included in the PathSearcher module due to lacking data from the mass spectrometry inputs
Peptide Sequencing in AlphaPeptDeep
- Peptide sequencing determines amino acid sequences of peptides, crucial for understanding proteins.
- Mass spectrometry (MS) is used to fragment peptides and analyze the resulting spectra.
Methods referenced in the presentation:
- Database Search & Spectral Library Search: AlphaPeptDeep improves prediction of MS spectra and retention time, boosting accuracy of both methods.
- de novo Sequencing: Enhanced spectral prediction from AlphaPeptDeep supports building peptide sequences without prior knowledge.
- AlphaPeptDeep predicts peptide properties (e.g., fragmentation patterns) to improve spectrum matching and sequence inference.
Contributions
- GraphNovo outperforms prior models such as DeepNovo, PointNovo, and Casanovo, achieving:
- 9.3–11.5% higher peptide recall,
- 6.1–8.9% higher amino acid recall,
- 5.8–8.7% higher amino acid precision.
- Substantially improves predictions in and beyond missing-fragmentation regions.
- Improves amino acid recall by up to 24.5% after missing fragmentation when guiding DeepNovo with GraphNovo’s predicted path.
- Maintains high accuracy across varying:
- Sequence lengths,
- Noise signal ratios,
- Degrees of missing fragmentation.
- Open-source availability:
- Code: GitHub – GraphNovo
- Data: Zenodo Repository
Constructive Critiques
- Struggles with long missing-fragmentation regions: The model has difficulty accurately predicting peptide sequences when large segments of fragmentation data are absent. These missing regions create gaps in the spectrum, which may impair the model's ability to reconstruct the full peptide chain.
- Requires predefined post-translational modifications: The system depends on a predefined list of possible post-translational modifications. This constraint limits the model’s ability to generalize to peptides with unexpected or novel PTMs, reducing its adaptability in complex biological samples.
- Computationally expensive: Due to its two-stage graph-based deep learning architecture, the model demands significant computational resources.
Reviews
I thought that this presentation clearly introduced the concepts of peptide sequences and how mass spectrometry is used to analyze/reconstruct peptide sequences. Also it was nice that you mentioned the different ways how sequences are matched with observed spectra besides deep learning methods (e.g, matching observed spectra with known sequences and using pre-existing spectral databases). The problem of how mistakes in the fragmentation can result in missing data in the spectrometry observed was also clearly explained, ultimately making sequence reconstruction difficult. As someone who plans to do project 3 for the final project, I found this presentation particularly helpful in understanding the type of data used in de novo peptide sequence generation. One comment that I have is that I found it a little unclear on how the mass spectra input is turned into a graph with fragments as nodes, and what exactly the optimal "path" is in the graph (aka how the edges are defined), although it may also because I am not too familiar in this area.
Group 24 Presentation: Mitigating the Missing Fragmentation Problem in De Novo Peptide Sequencing With A Two-Stage Graph-Based Deep Learning Model
Presenters
Zi Hua Xu, Zehao Zhang
Paper Citation
Mao, Z., Zhang, R., Xin, L. et al. Mitigating the missing-fragmentation problem in de novo peptide sequencing with a two-stage graph-based deep learning model. Nat Mach Intell 5, 1250–1260 (2023). https://doi.org/10.1038/s42256-023-00738-x
https://www.nature.com/articles/s42256-023-00738-x#citeas
Background
Peptide sequencing is a critical step in proteomics for determining the amino acid composition of proteins using tandem mass spectrometry (MS). Traditional approaches fall into three main categories:
(1) Database Search: Compares observed tandem mass spectra with peptides from a known database.
(2) Spectral Library Search: Uses curated spectral libraries containing experimentally acquired spectra to identify peptide–spectrum matches.
(3) De Novo Sequencing: Infers the peptide sequence directly from the MS data without relying on existing sequence databases, making it essential for novel protein discovery and cases where databases are incomplete or impractical.
Limitations in de novo peptide sequencing
Two major challenges persist in de novo peptide sequencing:
(1) Missing Fragmentation: Some peptide bonds do not break during MS fragmentation, leaving “gaps” in the spectrum that make certain regions difficult to reconstruct.
(2) Error Accumulation: A mistake in one poorly fragmented region often propagates, causing further downstream errors in the predicted peptide sequence.
GraphNovo
GraphNovo is a two-stage de novo sequencing algorithm designed to mitigate these issues:
(1) GraphNovo-PathSearcher: Constructs a directed acyclic graph (the “spectrum graph”) where each node represents a possible partial mass, and edges represent allowable amino acid mass differences. Predicts the “optimal path” from the start node to the end node, effectively capturing the correct arrangement of fragment ions and bypassing regions of missing fragmentation by labeling them as “mass tags.”
(2) GraphNovo-SeqFiller: Fills in the “mass tags” from the optimal path with their specific amino acid composition. Uses a transformer-based decoder guided by the path constraints, reducing errors that could otherwise accumulate if one low-confidence region led to incorrect subsequent predictions.
How GraphNovo Functions
(1) Graph Construction: Translates raw MS peaks into a spectrum graph. Each node corresponds to a potential prefix mass; edges form between nodes if the mass difference matches one or more amino acid masses (within tolerance).
(2) PathSearcher: Trained on known spectra to find the correct route through this graph, placing constraints on fragment ions (including missing-fragmentation regions represented as nodes without direct evidence).
(3) SeqFiller: Given the path—and hence each mass gap—SeqFiller “zooms in” on each mass tag to determine the exact amino acids. This two-stage strategy tackles missing fragmentation more directly than single-stage approaches.
Data and Preprocessing
Training Data: Includes high-confidence peptide identifications for Homo sapiens and Mus musculus from public proteomics datasets (e.g., plasma, HeLa, brain tissues). Only peptides with reliable annotations (1% false-discovery rate) and precursor masses below a certain threshold are included.
Test Data: Drawn from species not in the training set (e.g., Arabidopsis thaliana, C. elegans, E. coli), ensuring that evaluation measures generalizability.
Preprocessing Steps:
(1) Convert raw MS peaks into feature vectors (e.g., normalized m/z, relative intensity).
(2) Generate “node spectrum” (b, y, a, y2+ ions, among others) while discarding infeasible peaks.
(3) Build the graph by connecting nodes if their mass difference matches valid amino acid combinations.
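A rough sketch of this graph-building step is shown below, using a reduced amino-acid mass table and a simple tolerance check. Names and values are illustrative assumptions; the real pipeline also considers multi-residue mass gaps (mass tags), charge states, ion types, and post-translational modifications.
<pre>
from itertools import combinations

# Approximate monoisotopic residue masses (Da) for a few amino acids; illustrative subset only.
AA_MASS = {"G": 57.02146, "A": 71.03711, "S": 87.03203, "P": 97.05276, "V": 99.06841}

def build_spectrum_graph(node_masses, tol=0.02):
    """Connect candidate prefix-mass nodes whose difference matches a residue mass.

    node_masses: sorted list of candidate prefix masses inferred from the MS peaks.
    Returns a dict of edges: (i, j) -> list of amino acids consistent with the mass gap.
    """
    edges = {}
    for i, j in combinations(range(len(node_masses)), 2):
        gap = node_masses[j] - node_masses[i]
        labels = [aa for aa, m in AA_MASS.items() if abs(gap - m) <= tol]
        if labels:
            edges[(i, j)] = labels
    return edges

# Toy example: prefix masses 0.0 (start node), then gaps matching G and A
nodes = [0.0, 57.02, 128.06]
print(build_spectrum_graph(nodes))   # {(0, 1): ['G'], (1, 2): ['A']}
</pre>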
Model Architecture
Graph Construction: Creates a directed graph with edges corresponding to possible amino acid masses.
Loss Functions:
(1) GraphNovo-PathSearcher: Uses Kullback–Leibler divergence to guide node (path) predictions.
(2) GraphNovo-SeqFiller: Uses cross-entropy to predict the exact amino acid sequence that fills each mass tag.
Hyperparameter Tuning:
Optimizer: AdamW (a variant of Adam) with a fixed learning rate in the reported experiments.
Both stages employ a transformer-based architecture, incorporating a specialized graph encoder (relation attention) to capture node and edge features.
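As a rough illustration (not the authors' code), the two training objectives could be instantiated as follows with toy tensor shapes: a KL divergence between the predicted and target next-node distributions for PathSearcher, and a cross-entropy over the amino-acid vocabulary for SeqFiller.
<pre>
import torch
import torch.nn.functional as F

# PathSearcher: KL divergence between predicted next-node distribution and the target
pred_log_probs = F.log_softmax(torch.randn(1, 50), dim=-1)   # 50 candidate nodes (toy)
target_probs = F.softmax(torch.randn(1, 50), dim=-1)
path_loss = F.kl_div(pred_log_probs, target_probs, reduction="batchmean")

# SeqFiller: cross-entropy over the amino-acid vocabulary at each filled position
logits = torch.randn(7, 20)          # 7 positions, 20 standard amino acids (toy)
targets = torch.randint(0, 20, (7,))
seq_loss = F.cross_entropy(logits, targets)
</pre>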
Performance Comparison
Peptide Recall: GraphNovo shows a 9–12% improvement over the next-best approach (e.g., PointNovo, Casanovo) in correctly reconstructing entire peptide sequences.
Amino Acid Recall and Precision: Yields a 5–9% improvement across different test species, indicating more accurate individual residue identifications.
Robust to Missing Fragmentation and Noise: Maintains relatively high recall/precision for longer peptides, higher noise levels, and spectra with multiple missing-fragmentation sites, thereby mitigating error accumulation.
Constructive Critiques
Long Missing-Fragmentation Regions: While GraphNovo substantially reduces error propagation, very long or continuous gaps remain challenging.
Predefined Modifications: Must specify possible post-translational modifications (PTMs) in the graph construction step, which becomes computationally costly if many PTMs are considered at once.
Computational Overhead: Two-stage approach and large-scale graph construction require significant memory and processing time.
Future improvements
Enhanced Sequence Prediction: Integrate more MS features (e.g., retention time) to improve accuracy within large missing-fragmentation regions.
Expanded Applicability: Adapting the two-stage approach for top-down or middle-down proteomics and more extensive sets of PTMs.
Computational Efficiency: Explore faster graph-building algorithms, reduce the number of nodes through refined filtering, and potentially incorporate few-shot learning for user-specified PTMs.
Comment
Overall, GraphNovo demonstrates that a carefully designed, graph-based deep learning model can significantly mitigate the missing-fragmentation problem in de novo peptide sequencing, outperforming traditional and newer transformer-based methods by providing a stable framework for both path and sequence prediction.
Group 47 Presentation: Jamba: A Hybrid Transformer - Mamba Language Model
Presenter
Chentao Jin
Paper Citation
Lieber, O., Lenz, B., Bata, H., Cohen, G., Osin, J., Dalmedigos, I., Safahi, E., Meirom, S., Belinkov, Y., Shalev-Shwartz, S., Abend, O., Alon, R., Asida, T., Bergman, A., Glozman, R., Gokhman, M., Manevich, A., Ratner, N., Rozen, N., Shwartz, E., Zusman, M., Shoham, Y. (2024). Jamba: A Hybrid Transformer-Mamba Language Model. arXiv. https://arxiv.org/abs/2403.19887
https://doi.org/10.48550/arXiv.2403.19887
Background
- Large language models (LLMs) have become essential for various natural language processing tasks.
- Transformers have been the dominant architecture for LLMs due to their effectiveness in handling sequential data.
- However, Transformers suffer from high memory and compute costs, limiting their efficiency in long-context processing.
- Mamba, a recent state-space model, has emerged as an alternative to Transformers, offering improved efficiency in handling long sequences.
- A hybrid approach combining Transformers and Mamba layers can leverage the strengths of both architectures.
- Mixture-of-Experts (MoE) techniques can further enhance model capacity while managing active parameter usage.
- Efficient model architectures are crucial for balancing performance, computational efficiency, and memory footprint in large-scale AI applications.
Summaries of Key Points
Bridging Transformer Expressiveness and Mamba Efficiency
Large language models (LLMs) have made remarkable advances, but scaling them to handle long-context processing remains a significant challenge. Jamba introduces a novel hybrid architecture that combines the strengths of Transformers and Mamba, enhanced with a Mixture of Experts (MoE) module. This integration enables efficient memory usage, high throughput, and scalability for sequences up to 256,000 tokens on a single GPU.
Key Architectural Innovations
1. Transformer Layers – Self-attention mechanisms allow the model to capture complex token relationships, crucial for tasks involving deep contextual reasoning. However, they come with high memory and compute costs when processing long sequences.
2. Mamba Layers – Derived from state-space models (SSMs), Mamba efficiently processes long sequences without storing extensive key-value caches. Instead, it maintains a hidden state to summarize prior information, significantly reducing memory overhead.
3. Mixture of Experts (MoE) – Jamba integrates sparse expert selection, where each token activates only a small subset of experts. This technique increases capacity while controlling computational costs. Specifically, Jamba uses 16 experts, with only two active per token, optimizing efficiency.
The Jamba Block: A Structured Hybrid Design
The architecture of a Jamba block follows a structured sequence of Transformer layers, Mamba layers, and MoE layers:
- Transformer-to-Mamba ratio (1:7): the model incorporates one Transformer layer for every seven Mamba layers, balancing expressiveness and efficiency.
- MoE placement: instead of applying MoE to every layer, Jamba replaces every second multi-layer perceptron (MLP) layer with an MoE module. This increases total model capacity without significantly raising the number of parameters that are active per token.
By blending self-attention, state-space models, and sparsity techniques, Jamba pushes the boundaries of long-context processing in language models. Its ability to handle extremely long sequences while maintaining efficiency and scalability makes it a compelling innovation in the next generation of LLM architectures.
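A schematic sketch of how layers might be laid out within a single Jamba block under the 1:7 ratio and alternate-layer MoE placement is shown below. The block length and the exact position of the attention layer are assumptions for illustration and may differ from the released configuration.
<pre>
def jamba_block_layout(n_layers=8, attn_every=8, moe_every=2):
    """Illustrative layer layout for one Jamba block: 1 attention layer per 7 Mamba
    layers, with the MLP of every second layer replaced by an MoE module."""
    layout = []
    for i in range(n_layers):
        mixer = "attention" if (i % attn_every) == attn_every - 1 else "mamba"
        mlp = "moe" if (i % moe_every) == moe_every - 1 else "mlp"
        layout.append((mixer, mlp))
    return layout

for layer in jamba_block_layout():
    print(layer)
# ('mamba', 'mlp'), ('mamba', 'moe'), ..., ('attention', 'moe')
</pre>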
Performance and Benefits
Jamba achieves state-of-the-art efficiency and performance across academic benchmarks, long-context tasks, and throughput.
- Matches or outperforms larger models such as Mixtral-8x7B and LLaMA-2 70B on:
- Reasoning: HellaSwag, ARC-Challenge, WinoGrande, PIQA, TruthfulQA
- Comprehension: BoolQ, QuAC
- Math and Code: GSM8K, HumanEval
- Aggregated tasks: MMLU, BBH
- Outperforms Mixtral on:
- Needle-in-a-Haystack retrieval
- Few-shot classification: Banking77, TREC-Fine
- Long-context QA: NarrativeQA, CUAD, NaturalQuestions
- Up to 3× faster than Mixtral at [math]\displaystyle{ 128K }[/math] context length.
- Efficient inference on a single [math]\displaystyle{ 80\,\text{GB} }[/math] GPU (int8 quantization).
- KV-cache memory usage is 8× smaller than Transformers (e.g., [math]\displaystyle{ 4\,\text{GB} }[/math] vs. [math]\displaystyle{ 32\,\text{GB} }[/math] at [math]\displaystyle{ 256K }[/math] tokens).
Constructive Critique
While Jamba achieves impressive performance and efficiency through its hybrid architecture, several limitations and open questions remain:
- Lack of ablation analysis: The paper adopts a 1:7 ratio of Transformer to Mamba layers and places MoE modules every other layer, but provides little justification for these hyperparameters. A more thorough ablation would help clarify the contribution of each component to final performance.
- Interpretability trade-off: Combining multiple architectural modules (SSM, attention, MoE) increases complexity. While effective, this may make it harder to interpret model behavior or debug errors, especially compared to simpler Transformer-only baselines.
- ICL limitations: The authors mention that Mamba layers alone struggle with in-context learning (ICL). Although interleaving attention layers helps, this still suggests limitations in how well SSMs handle token-by-token reasoning or structure-sensitive tasks.
- Lack of fine-tuning or alignment: The released model is a base pretrained model without instruction tuning or safety alignment. This limits its immediate use in downstream applications without additional supervised or RLHF-based training.
Despite these challenges, Jamba represents a promising direction for scaling efficient, long-context language models and offers a practical blueprint for hybrid architectures in the LLM space.
Group 47 Presentation: Jamba: A Hybrid Transformer - Mamba Language Model
Presenter
Chentao Jin
Paper Citation
Lieber, O., Lenz, B., Bata, H., Cohen, G., Osin, J., Dalmedigos, I., Safahi, E., Meirom, S., Belinkov, Y., Shalev-Shwartz, S., Abend, O., Alon, R., Asida, T., Bergman, A., Glozman, R., Gokhman, M., Manevich, A., Ratner, N., Rozen, N., Shwartz, E., Zusman, M., Shoham, Y. (2024). Jamba: A Hybrid Transformer-Mamba Language Model. arXiv. https://arxiv.org/abs/2403.19887
https://doi.org/10.48550/arXiv.2403.19887
Jamba Architecture
Main Features
Jamba is a hybrid large language model that interleaves:
Transformer layers with self-attention (standard decoder blocks).
Mamba layers (a state-space model, SSM) introduced by Gu & Dao (2023).
Mixture-of-experts (MoE) modules integrated into some of the MLP layers.
By combining these three components, Jamba can balance efficiency, long-context capabilities, and model capacity without incurring a prohibitive computational or memory cost.
(1) Transformer Layers
Jamba uses standard decoder-only Transformer blocks, but crucially, they appear in a reduced proportion (e.g., 1 attention layer for every 7 Mamba layers).
The attention mechanism is still important for in-context learning and tasks that benefit from explicit token-to-token interactions.
(2) Mamba Layers
Mamba layers replace much of the attention with an SSM-based mechanism, scaling linearly with sequence length.
They significantly reduce key–value cache size for long contexts because each Mamba layer does not require storing extensive attention activations.
Unlike prior SSMs, Mamba is stabilized at large scale with carefully chosen RMSNorm inside the state-space modules.
The authors find no explicit positional encoding is required in the Mamba blocks—Mamba inherently captures positional information.
(3) Mixture-of-Experts (MoE)
Jamba integrates MoE in some MLP layers to increase total capacity without increasing the active parameters used per token.
MoE involves having multiple “expert” sub-MLPs, with only the top K experts selected for each token.
This leads to a sparse model design: total parameters can be large (e.g., 50B+), but only ~12B parameters are “active” at any forward pass.
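The sketch below shows top-2 routing over 16 experts in the style described above. It is a simplified, illustrative implementation rather than Jamba's released MoE code, which would also require load balancing and efficient batched expert dispatch.
<pre>
import torch
import torch.nn as nn
import torch.nn.functional as F

class Top2MoE(nn.Module):
    """Sparse MoE MLP: route each token to its top-2 of 16 experts and mix their outputs."""
    def __init__(self, d_model, d_hidden, n_experts=16, top_k=2):
        super().__init__()
        self.router = nn.Linear(d_model, n_experts, bias=False)
        self.experts = nn.ModuleList(
            nn.Sequential(nn.Linear(d_model, d_hidden), nn.GELU(), nn.Linear(d_hidden, d_model))
            for _ in range(n_experts)
        )
        self.top_k = top_k

    def forward(self, x):                      # x: (n_tokens, d_model)
        scores = self.router(x)                # (n_tokens, n_experts)
        weights, idx = scores.topk(self.top_k, dim=-1)
        weights = F.softmax(weights, dim=-1)   # normalize over the selected experts
        out = torch.zeros_like(x)
        for k in range(self.top_k):
            for e in range(len(self.experts)):
                mask = idx[:, k] == e
                if mask.any():
                    out[mask] += weights[mask, k:k+1] * self.experts[e](x[mask])
        return out
</pre>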
Performance and Benefits of Jamba
(1) High throughput: Compared to a pure-Transformer model of similar size, Jamba achieves up to 3× higher inference throughput at very long context lengths. This is because Mamba’s linear-time scan avoids the quadratic cost and large key–value cache of attention.
(2) Memory efficiency: Jamba’s key–value cache can be 8× smaller than that of a similarly sized Transformer, which makes it possible to handle up to 256K tokens of context (or even more) on a single 80GB GPU.
(3) Competitive quality: On standard LM benchmarks (ARC, HellaSwag, WinoGrande, etc.), Jamba performs on par with or better than similarly sized Transformer or MoE-Transformer models. It also demonstrates strong capabilities in long-context tasks (e.g., “needle in a haystack” retrieval).
Key Design and Insights
(1) Hybrid Architecture
The mixed ratio of attention layers to Mamba layers (often 1:7) is crucial. Even a small fraction of attention layers confers strong in-context learning (format adherence, induction-like patterns), while the Mamba layers bring speed and memory savings.
Pure Mamba, though fast, sometimes struggles with emergent in-context learning behaviors (e.g., properly following few-shot prompts). The hybrid design preserves these Transformer-like capabilities.
(2) MoE Effectiveness
Using MoE on top of the hybrid model further improves perplexity and downstream performance, allowing the total parameter count to go up to 50B+ while keeping active parameter usage around ~12B.
Balancing the number of experts, top-K selection, and how frequently MoE is used (e.g., every other MLP layer) is key for controlling compute costs and memory.
(3) Training Stability and Design Choices
RMSNorm: Large-scale Mamba layers exhibit occasional large activation spikes. RMSNorm on internal activations stabilizes the training, preventing loss spikes.
No explicit positional encoding needed: Unlike typical Transformers (which use rotary, ALiBi, or other embeddings), the authors found that Mamba captures positional cues inherently. Adding RoPE gave no notable improvement.
Conclusion
(1) Uniqueness of Jamba: high efficiency and strong design
Jamba’s combination of attention, Mamba, and MoE layers yields excellent throughput and long-sequence modeling.
(2) Handling long context better
Jamba’s memory footprint for KV caching is drastically smaller. It can handle contexts of up to 256K tokens on a single 80GB GPU—significantly exceeding typical Transformer-based LLMs of similar size.
(3) Open-source release
The model is released under an Apache 2.0 license, encouraging research on this hybrid approach. Pretrained checkpoints and ablation runs will also be provided.
Future Directions
(1) Optimize MoE Further
Investigating more sophisticated MoE routing strategies, expert balance, or hierarchical gating to push quality and efficiency further.
(2) Hybrid Scaling in Even Larger Models
Extending beyond ~7B–12B active parameters to tens of billions or more, exploring how the attention–Mamba ratio and MoE design scale at even larger training runs.
Group 47 Presentation: Jamba: A Hybrid Transformer - Mamba Language Model
Presenter
Chentao Jin
Paper Citation
Lieber, O., Lenz, B., Bata, H., Cohen, G., Osin, J., Dalmedigos, I., Safahi, E., Meirom, S., Belinkov, Y., Shalev-Shwartz, S., Abend, O., Alon, R., Asida, T., Bergman, A., Glozman, R., Gokhman, M., Manevich, A., Ratner, N., Rozen, N., Shwartz, E., Zusman, M., Shoham, Y. (2024). Jamba: A Hybrid Transformer-Mamba Language Model. arXiv. https://arxiv.org/abs/2403.19887
https://doi.org/10.48550/arXiv.2403.19887
Clear explanations to aid understanding
Jamba uses a unique combination of transformer layers, memory layers, and a Mixture of Experts (MoE) layer to improve its efficiency, scalability, and ability to process long sequences of text. This hybrid design optimizes both memory management and computational power while minimizing resource use. Here’s an overview of the components and their roles in the architecture:
Jamba's Architecture
Jamba’s architecture is built on a structured sequence of layers: Mamba-based memory layers greatly outnumber transformer layers (roughly seven memory layers for every transformer layer), and a Mixture of Experts (MoE) module replaces the MLP in every second layer. This design creates an efficient, hybrid system that maximizes both memory handling and computational power while minimizing resource usage.
Each of these layers:
1. Transformer Layers: The transformer layers in Jamba are responsible for capturing the relationships between different tokens in the sequence, no matter how far apart they are. This is done using self-attention, which helps the model understand complex relationships and context within the text. However, traditional transformers can struggle with very long sequences because the self-attention mechanism becomes very memory and computation-heavy as the sequence length grows.
2. Memory Layers: To tackle this, Jamba introduces memory layers based on a different model called the state space model (SSM). These memory layers don’t rely on self-attention. Instead, they maintain a hidden state that keeps track of important information from earlier in the sequence. This makes memory layers far more efficient when it comes to handling long sequences, as they don’t need to store as much data in memory, unlike the transformer layers.
3. Mixture of Experts (MoE): The MoE component is where Jamba gets its flexibility. Instead of using all model parameters for each token, MoE selectively activates a small subset of "experts" for each token. An "expert" is a specialized set of parameters that focuses on solving a specific part of the problem. The model dynamically selects the experts that are most relevant to each token's context, allowing it to handle different parts of the problem more efficiently. Since only a few experts are activated per token, the model can scale its capacity to handle more complex tasks or longer sequences without significantly increasing computational costs.
Additionally, Jamba stabilizes its training process with RMSNorm, which normalizes activations and ensures that training remains stable, even when the model is scaled up to very large sizes.
Performance and Benefits of Jamba
Jamba's hybrid approach provides several advantages:
- Efficient Handling of Long Contexts: Jamba is able to process long sequences of text effectively, overcoming the limitations of traditional transformer models.
- Balance Between Performance and Efficiency: Jamba achieves strong performance while reducing memory usage and computational costs thanks to its combination of transformer, memory, and MoE layers.
- Scalability: By using fewer active parameters than other models, Jamba is scalable and can handle tasks that require understanding large amounts of text without compromising efficiency.
Limitations of Jamba
This hybrid approach has limitations as well: training can become unstable and relies on RMSNorm to keep activations in check, handling very long contexts still requires substantial GPU memory (around 80 GB), and the Mixture of Experts (MoE) component needs further optimization to improve performance.
Further Directions
- Optimize MoE further for even better efficiency
- Investigate how hybrid architectures scale in even larger contexts and models