stat940W25-presentation
Notes on Presentations
Group 1 Presentation: Universal Physics-Informed Neural Networks: Symbolic Differential Operator Discovery with Sparse Data
Paper Citation
Podina, L., Eastman, B., & Kohandel, M. (2023). Universal Physics-Informed Neural Networks: Symbolic Differential Operator Discovery with Sparse Data. In Proceedings of the 40th International Conference on Machine Learning (Vol. 202). PMLR, Honolulu, Hawaii, USA.
Background
Differential equations
Examples of differential equations in physics include Newton's second law (which is an ordinary differential equation), the Navier-Stokes equations (which are partial differential equations), etc.
Existing methods of solving differential equations:
- Analytical methods, such as integration or separation of variables.
- Numerical methods, such as finite difference, finite volume, or finite elements.
- Data-driven approaches: these involve Universal Differential Equations (UDEs) and Physics-Informed Neural Networks (PINNs), which are the focus of this paper.
Introduction to PINNs
With (many) machine learning approaches, the goal is to approximate the solution to a DE using a feed-forward neural network, optimized with MSE loss. The key difference that makes it physics-informed is an extra term in the loss, which penalizes the model for deviating from the governing DE.
Introduction to UDEs
Here, the differential equation is expressed as a sum of two terms: the known physics-based model and an unknown neural network.
Paper Contributions
Universal Physics-Informed Neural Networks (UPINNs)
UPINNs combine PINNs and UDEs, addressing the limitations of each original method while sharing their benefits.
The model integrates three network components:
- Surrogate Solution Network U: links to the measurement loss
- Unknown Differential Operator Network F: works together with U inside the PINN loss
- Boundary Condition Network B: links to the boundary loss
The loss function contains three terms:
- MSE (measurement) loss on the observed data.
- Boundary loss, applied when boundary conditions are provided with the problem.
- PINN loss: ensures the model respects the governing differential equation. A minimal sketch combining the three terms follows below.
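To make this concrete, here is a minimal PyTorch-style sketch of how the three terms could be combined for a scalar ODE of the form du/dt = f_known(u) + F_unknown(u); the network sizes, the assumed known term, and the equal weighting of the terms are illustrative assumptions rather than the authors' exact implementation.
<pre>
import torch

# Hypothetical networks: u_net approximates the solution u(t),
# f_net approximates the unknown part of the differential operator.
u_net = torch.nn.Sequential(torch.nn.Linear(1, 32), torch.nn.Tanh(), torch.nn.Linear(32, 1))
f_net = torch.nn.Sequential(torch.nn.Linear(1, 32), torch.nn.Tanh(), torch.nn.Linear(32, 1))

def known_physics(u):
    return -0.5 * u  # assumed known term of the ODE (illustrative only)

def upinn_loss(t_data, u_data, t_colloc, t_bc=None, u_bc=None):
    # 1) Measurement (MSE) loss on the sparse observations
    mse = torch.mean((u_net(t_data) - u_data) ** 2)

    # 2) PINN residual loss at collocation points:
    #    du/dt should equal known_physics(u) + f_net(u)
    t_c = t_colloc.clone().requires_grad_(True)
    u_c = u_net(t_c)
    du_dt = torch.autograd.grad(u_c, t_c, torch.ones_like(u_c), create_graph=True)[0]
    pinn = torch.mean((du_dt - (known_physics(u_c) + f_net(u_c))) ** 2)

    # 3) Optional boundary/initial-condition loss
    bc = torch.tensor(0.0)
    if t_bc is not None:
        bc = torch.mean((u_net(t_bc) - u_bc) ** 2)

    return mse + pinn + bc  # relative weights are a tuning choice
</pre>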
Experimental Validation
1. Lotka-Volterra Model
They first experimented with the UPINN on the Lotka-Volterra system of differential equations, which are used to model predator-prey dynamics:
[math]\displaystyle{ \frac{dx}{dt} = \alpha x - \beta xy }[/math]
[math]\displaystyle{ \frac{dy}{dt} = -\delta y + \gamma xy }[/math]
The UDE and PINN were individually tested on two scenarios: sparse data (where there are very few input data points) and noisy data. Alone, each model did not do very well, especially when the data was very sparse or very noisy. When the UPINN was used, the solution was quite good, even with high sparsity or noise.
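For reference, sparse and noisy observations of this system can be simulated as below; the parameter values, time grid, and noise level are arbitrary illustrative choices, not the paper's settings.
<pre>
import numpy as np
from scipy.integrate import solve_ivp

# Illustrative parameter values for the Lotka-Volterra system
alpha, beta, delta, gamma = 1.3, 0.9, 1.8, 0.8

def lotka_volterra(t, z):
    x, y = z
    return [alpha * x - beta * x * y, -delta * y + gamma * x * y]

t_eval = np.linspace(0, 10, 20)                        # sparse: only 20 time points
sol = solve_ivp(lotka_volterra, (0, 10), [1.0, 1.0], t_eval=t_eval)
noisy = sol.y + 0.05 * np.random.randn(*sol.y.shape)   # add Gaussian measurement noise
</pre>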
2. Viscous Burgers’ Equation
Their next experiment used the viscous Burgers' equation, a PDE from fluid dynamics.
[math]\displaystyle{ \frac{\partial u}{\partial t} = -u \frac{\partial u}{\partial x} + \nu \frac{\partial^2 u}{\partial x^2} }[/math]
3. Cell Apoptosis Model
Summaries of key points
This paper introduces Universal Physics-Informed Neural Networks (UPINNs) for discovering unknown terms in differential equations (ODEs/PDEs) from sparse and possibly noisy data. It combines the strengths of standard Physics-Informed Neural Networks (PINNs)—which incorporate prior knowledge of the governing equations—while still allowing parts of the underlying model to remain unknown and be learned from the data. Unlike previous methods such as Universal Differential Equations (UDEs), which can falter in noisy and small-data regimes, UPINNs maintain good accuracy by:
1. Leveraging collocation points in the loss function to incorporate the differential equation constraints ("physics").
2. Adding a neural network component to represent the unknown terms of the operator.
3. Applying symbolic regression (e.g., AI Feynman) to convert the neural approximation of the hidden terms into interpretable, closed-form expressions.
Extensive experiments on the Lotka–Volterra system, a viscous Burgers’ PDE, and a cell apoptosis ODE show that UPINNs outperform UDEs in handling higher noise and fewer data points, while still recovering the hidden differential-operator terms accurately.
Furthermore, symbolic regression improves interpretability by converting neural outputs into explicit equations. This interpretability, combined with robustness to sparse and noisy data, makes UPINNs especially promising for scientific discovery. Potential applications include systems biology, fluid dynamics, and environmental modeling. Future research directions could address scalability to higher-dimensional PDEs and uncertainty quantification.
Related work
Nonlocal Physics-Informed Neural Networks (nPINNs): nPINNs introduce a universal nonlocal Laplace operator that encompasses classical and fractional Laplacians. This framework is utilized for parameter identification in nonlocal models, demonstrating consistency and accuracy in capturing operator behaviours.
Group 2 Presentation: EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty
Presented by:
Kareena Bhalla and Chelsea Huffman
Paper Citation
Li, Y., Wei, F., Zhang, C., Zhang, H. (2024). EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty. arXiv. doi.org/10.48550/arXiv.2401.15077
Background
This paper focuses on Large Language Models (LLMs). Autoregressive decoding in LLMs refers to generating text one token at a time, with the model basing each prediction on the tokens that came before it. This is inefficient and costly.
Speculative Sampling
Speculative sampling is a technique meant to reduce the computational cost and runtime of autoregressive decoding. The process consists of two main parts:
- Draft stage: A small, fast model suggests some tokens.
- Verification stage: In parallel, the large main LLM verifies these tokens and selects the best ones.
EAGLE also has both a draft stage and a verification stage.
How do we choose a draft model that behaves like the full LLM but runs faster? One approach is to use a reduced version of the main LLM, but this does not work when no small version is available; moreover, a model with fewer parameters often comes with high overhead and reduced accuracy.
Technical Contributions
Extrapolation Algorithm for Greater Language Model Efficiency (EAGLE)
Before predicting the next feature, the draft model also looks one token ahead in the sequence, i.e., it is given the token that has already been sampled for the next position.
One advantage of the EAGLE method is that only a single decoder layer needs to be trained rather than an entire draft model. This makes the training process extremely efficient in terms of runtime/computational cost.
The EAGLE method makes improvements to the process based on the following 2 observations:
1. Autoregression is simpler at the feature level than the token level.
2. The uncertainties from the sampling process negatively affect the performance.
Feature-Level Autoregression
Unlike traditional speculative sampling methods that operate at the token level, EAGLE performs autoregression at the feature level, specifically utilizing the second-to-top-layer features of the LLM. This approach simplifies the drafting process and enhances efficiency.
Experimental Results
- EAGLE is much faster than ordinary autoregressive decoding.
- EAGLE can be applied to various LLMs without reducing the model's generation quality. Furthermore, it is also compatible with other generation acceleration methods; for example, the paper goes through a quick case study of combining EAGLE with an optimization technique called gpt-fast.
- EAGLE can generate approximately 3.2–4.5 tokens per pass. Regular, vanilla decoding only generates one token per forward pass, so this is a notable speed-up in token generation.
- EAGLE provides the greatest computational speed-up in code generation tasks. This is likely because code contains fixed templates (e.g., repetitive structures, strict and specific rules, etc.), thereby making it easier to generate plausible and accurate drafts.
Summaries of key points
EAGLE (Extrapolation Algorithm for Greater Language-model Efficiency) is a framework designed to accelerate the text generation of large language models (LLMs) without sacrificing the original distribution of generated tokens. Traditionally, LLMs rely on autoregressive decoding, generating text one token at a time and incurring heavy computational costs. Speculative sampling, introduced in earlier works, seeks to reduce these costs by splitting text generation into a faster *“drafting”* step and a single-step “verification” step using the main LLM, thereby validating multiple drafted tokens in parallel. However, existing methods often struggle to balance high draft accuracy with low overhead, limiting their speed improvements.
EAGLE addresses these issues by predicting hidden-layer features rather than tokens. Specifically, it focuses on the second-to-top-layer feature vector of the original LLM. Because these features exhibit more consistent structure than raw language tokens, they can be modeled with fewer errors and simpler modules. Additionally, EAGLE includes a small Autoregression Head that factors in one-step-ahead token information—this is critical for mitigating the inherent uncertainty in feature prediction. When generating text, the main LLM determines which token is sampled at a given step, but from the feature-only perspective, it is impossible to know the token that will appear. By feeding the token as well as the past feature states into a lightweight single-layer decoder, EAGLE’s draft model can produce a highly accurate prediction of the next feature. After obtaining the predicted features, EAGLE applies the LLM’s original output layer (the LM head) to sample candidate tokens in parallel.
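A conceptual sketch of one drafting step is given below. All module names (embed, fc, decoder_layer, lm_head) and sizes are hypothetical placeholders; the sketch only illustrates the idea of fusing the previous feature with the embedding of the already-sampled token and passing the result through a single lightweight decoder layer before reusing the LM head.
<pre>
import torch

hidden, vocab = 4096, 32000          # illustrative sizes, not tied to a specific LLM

# Hypothetical draft-model components
embed = torch.nn.Embedding(vocab, hidden)                # token embeddings (shared with the target LLM in practice)
fc = torch.nn.Linear(2 * hidden, hidden)                 # fuses feature + token embedding
decoder_layer = torch.nn.TransformerDecoderLayer(hidden, nhead=32, batch_first=True)
lm_head = torch.nn.Linear(hidden, vocab)                 # would be the frozen LM head of the target LLM

def draft_step(prev_features, sampled_token_ids):
    """Predict the next second-to-top-layer feature, then candidate draft tokens."""
    tok = embed(sampled_token_ids)                       # (B, T, H): tokens sampled one step ahead
    x = fc(torch.cat([prev_features, tok], dim=-1))      # combine feature and token information
    next_feature = decoder_layer(x, x)                   # single-layer autoregression at the feature level
    logits = lm_head(next_feature)                       # map predicted features back to token space
    return next_feature, logits
</pre>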
Once a tree of drafted tokens is produced, the method performs a single verification pass in the main LLM. In one forward call, EAGLE obtains the LLM’s probabilities for all tokens in this tree and accepts or rejects each token based on a mathematically proven acceptance criterion. Crucially, this ensures the distribution of tokens matches what the LLM would have generated if it performed naive, step-by-step autoregressive decoding.
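The acceptance rule itself is the standard speculative-sampling test, which EAGLE inherits; a minimal sketch for a single drafted token follows, where p and q denote the target and draft distributions over the vocabulary.
<pre>
import torch

def accept_or_resample(token, p, q):
    """Standard speculative-sampling acceptance test for one drafted token.

    p: target-model probabilities over the vocabulary (1-D tensor)
    q: draft-model probabilities over the vocabulary (1-D tensor)
    """
    # Accept the drafted token with probability min(1, p[token] / q[token])
    if torch.rand(()) < torch.clamp(p[token] / q[token], max=1.0):
        return token, True
    # Otherwise resample from the residual distribution proportional to max(p - q, 0),
    # which preserves the target model's output distribution overall
    residual = torch.clamp(p - q, min=0.0)
    residual = residual / residual.sum()
    return torch.multinomial(residual, 1).item(), False
</pre>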
In evaluations across code (HumanEval), mathematical reasoning (GSM8K), instruction following (Alpaca), and multi-turn dialogue (MT-bench) datasets, EAGLE yields 2.7–3.5× speedups for models like LLaMA2-Chat 70B and 2.8–3.3× for smaller Vicuna variants on single-GPU setups. Even for more modest setups, or MoE-based models like Mixtral 8×7B Instruct, EAGLE still demonstrates strong acceleration. The approach does not require finetuning the main LLM itself; only a small draft network is trained on a fixed dataset (e.g., ShareGPT). Because the overhead of this training is low—typically 1–2 days even for 70B-parameter models—EAGLE proves highly practical.
Overall, it offers a simple but powerful way to generate tokens in parallel while retaining exact quality and distribution guarantees, making it a compelling choice for real-world systems seeking reduced latency in LLM services.
Related Works
There is now a new version of EAGLE, EAGLE-2. EAGLE-2 improves on this version by introducing a dynamically adjustable draft tree, which it adapts based on context and position, building on speculative sampling. In testing, EAGLE-2 can be 20%-40% faster than EAGLE-1. There is also EAGLE-3, "Scaling up Inference Acceleration of Large Language Models via Training-Time Test," which can be found here: https://arxiv.org/html/2503.01840v1. EAGLE-3 advances the acceleration of LLM inference by shifting from feature prediction to direct token prediction and employing multi-layer feature fusion through a technique called training-time test. These enhancements enable the draft model to fully benefit from scaled-up training data, resulting in notable speedup ratios.
In the draft stage, other related works have utilized different methods. Speculative sampling and Lookahead use tokens to predict tokens. Medusa uses features to independently predict tokens.
In the verification stage, other related works have modified acceptance probabilities (DistillSpec (Zhou et al., 2023)) or the acceptance method (BiLD (Kim et al., 2023) and Medusa (Cai et al., 2023)).
Other related methods that reduce the latency per forward pass include distillation, quantization, and pruning.
Group 2 Presentation: EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty
Presented by:
Kareena Bhalla and Chelsea Huffman
Paper Citation
Li, Y., Wei, F., Zhang, C., Zhang, H. (2024). EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty. arXiv. doi.org/10.48550/arXiv.2401.15077
Summaries of key points
Speculative sampling speeds up language model inference by having a fast draft model guess multiple tokens, which are then verified by a slower, accurate model. While effective, this breaks the model’s training assumption of strictly sequential input, introducing feature uncertainty—the hidden states are based on possibly incorrect guesses. The paper proposes EAGLE, an energy-based method that learns to assess and modulate this uncertainty at the feature level. Instead of discarding uncertain tokens, EAGLE uses a learned energy score to decide how much to trust them during decoding. This leads to better performance, especially on tasks requiring reasoning or longer context, and is more robust than previous speculative decoding approaches.
Problem Motivation: Speculative decoding (used to speed up LLM inference by parallelizing token generation) is great for efficiency but introduces a mismatch between training and inference: training assumes sequential decoding, but speculative sampling adds a “guess-and-check” step that breaks that assumption.
Key Idea: EAGLE (which stands for Extrapolation Algorithm for Greater Language-model Efficiency) proposes a new method to handle uncertainty that arises in speculative decoding. It adjusts the model’s internal feature representations to reflect this uncertainty, rather than just masking or ignoring it.
How It Works: Instead of relying on the regular transformer’s last hidden state, EAGLE builds an energy-based auxiliary model that learns to estimate whether a token guess is valid, using both the main model’s predictions and the speculative draft. This energy score helps modulate the influence of uncertain features during decoding.
Results: EAGLE shows better performance on downstream tasks compared to vanilla speculative decoding, especially on tasks that require reasoning or handling uncertainty — e.g., question answering or coding benchmarks.
Explanations to aid understanding
Speculative decoding in simpler terms: Imagine trying to write the next word in a sentence, but instead of just writing one word and waiting, you guess a bunch in parallel and then double-check them. This saves time, but makes it harder for the model to know what it should trust. EAGLE essentially adds a smart layer that acts like a “confidence gauge” for the guesses, using a learned energy function to decide how much to trust each speculative token and its underlying features.
Group 3 Presentation: Mamba: Linear-Time Sequence Modelling with Selective State Spaces
Presented by:
Liang Wu, Jingcheng Yu, Candace Ng
Paper Citation
Gu, A., & Dao, T. (2023). Mamba: Linear-Time Sequence Modeling with Selective State Spaces. arXiv. https://arxiv.org/abs/2312.00752.
Background
Technical Contributions
- Introduce a selection mechanism for state space models that enables input-dependent state updates
- Develop a hardware-aware recurrent scan algorithm for efficient computation
- Propose Mamba, an attention-free, linear-time architecture that achieves state-of-the-art performance across multiple modalities
Architecture
Mamba integrates selective SSMs into a single homogeneous block. Each block comprises: a linear projection, a selective SSM layer, and an MLP block. This architecture is attention-free and scales linearly in sequence length.
Related Work
- Builds on structured SSMs (S4, S5) and architectures such as H3 and Hyena.
- Demonstrates theoretical and empirical connections to classical RNN gating.
Future Directions:
- Scale to larger models and refine training recipes
- Extend Mamba to multimodal tasks (e.g. video)
- Explore additional downstream affordances (fine-tuning, in-context learning, quantization)
Summaries of Key Points
1. Motivation: Faster and More Flexible Sequence Modeling
Large-scale "foundation models" (FMs) largely rely on Transformers (with attention) to handle long-range context. However, attention has quadratic complexity in sequence length and requires storing the entire key–value cache. This leads to high memory usage and slow generation. Meanwhile, Structured State-Space Models (SSMs) offer linear-time processing through either convolution or recurrence, achieving excellent results in domains like audio and vision. Yet they typically assume constant, time-invariant parameters, which hinders their ability to handle discrete or "content-based" tasks such as language.
2. Proposed Approach: Selective State-Space Models (S6)
The paper introduces a selection mechanism to inject content-based (input-dependent) reasoning into SSMs. Specifically:
- Parameterizing SSM transitions by inputs: Conventional SSMs have fixed transition matrices (A, B, C), but here, B and C—and crucially, the step-size parameter Δ—become input-dependent. This small twist allows the model to "select" which inputs matter and how strongly to update its hidden state, like gating in RNNs.
- Efficient "Selective Scan" Implementation: Because parameters vary with the input, one cannot use global convolution for training (which was key to SSM efficiency). Instead, the authors design a hardware-aware parallel scan on GPUs, fusing operators so that large intermediate states reside only in on-chip memory. This approach retains O(L) time scaling in sequence length while dramatically cutting overhead.
3. Mamba Architecture
To showcase selective SSMs, they build Mamba, a simple "attention-free" architecture. Each layer is effectively an MLP with a parallel SSM path included. By training only these blocks end-to-end (without multi-head attention), Mamba delivers:
- Linear-time training because it uses convolution-like or scan-based methods.
- Constant-time inference per token (no caching of the entire past context), which is especially beneficial for extremely long sequences.
Additionally, Mamba integrates gating mechanisms within its SSM path, adaptively controlling information flow to better capture complex dependencies. Its streamlined layer design allows hardware-aware GPU optimizations, enhancing throughput and scalability compared to attention-based models.
4. Experiments Across Multiple Domains
1. Synthetic Tests (Selective Copying, Induction Heads): Mamba (or more precisely, its selective mechanism) solves tasks where ordinary time-invariant SSMs fail. It can ignore irrelevant tokens and preserve essential ones.
2. Language Modeling: On The Pile (multi-domain text), Mamba matches or exceeds standard Transformer perplexities when model size scales to ~1B parameters. It also outperforms other sub-quadratic methods (e.g., Hyena, RWKV). Notably, Mamba obtains 4–5× higher inference throughput than Transformers because it avoids large key–value caches.
3. Genomics: Mamba attains better perplexities than alternative long-sequence models (HyenaDNA) on a human genome corpus, especially at very long context (up to millions of tokens).
4. Audio Waveforms: The model outperforms prior state-of-the-art (e.g., SaShiMi) in generating raw speech signals, achieving lower FID scores and better sample quality.
5. Significance and Conclusions
Mamba demonstrates that:
1. Combining gating/selection with state-space recurrences, and
2. A carefully engineered GPU scan can outperform Transformers in certain tasks and match them in others, all with linear-time complexity. This opens a path for foundation models in domains needing extremely long context—like genomics, raw audio, and (possibly) next-generation language modeling—where Transformers become unwieldy. The authors highlight that future directions include scaling Mamba to billions of parameters (7B+), exploring interpretability, and combining it with other approaches (e.g., partial attention) for even broader performance gains.
Group 3 Presentation: Mamba: Linear-Time Sequence Modelling with Selective State Spaces
Presented by:
Liang Wu, Jingcheng Yu, Candace Ng
Paper Citation
Gu, A., & Dao, T. (2023). Mamba: Linear-Time Sequence Modeling with Selective State Spaces. arXiv. https://arxiv.org/abs/2312.00752.
Summaries of key points
Nowadays, transformers are extremely popular because they are great at finding relationships in data. However, they struggle with long sequences since their computation grows too fast as the input gets longer. In comparison, State Space Models (SSMs) are more efficient for long sequences, but they have a limitation: they can't easily decide what information to keep or forget. Mamba solves this problem by improving a standard SSM so that it can selectively adjust its parameters based on the input.
Standard SSM (S4): How It Works
SSMs process sequences differently from transformers. Instead of checking all relationships in the data, they store past information in a hidden state and update it step by step.
The basic rule for an SSM is:
[math]\displaystyle{ h_t = A h_{t-1} + B x_t }[/math], [math]\displaystyle{ y_t = C h_t }[/math]
Where:
- [math]\displaystyle{ h_t }[/math] is the hidden state (memory of past information).
- [math]\displaystyle{ x_t }[/math] is the current input.
- [math]\displaystyle{ A, B, C }[/math] are numbers that the model learns to control how information flows.
- [math]\displaystyle{ y_t }[/math] is the output.
The problem is that A, B, and C stay fixed, meaning the model always updates its memory in the same way, no matter what input it sees. This makes it less adaptable, especially for complex tasks like language modeling.
Selective SSM (S6): Mamba’s Improvement
Mamba introduces Selective State Space Models (S6), where the key values A, B, and C change dynamically based on the input. This allows the model to decide what to store and what to forget in real-time.
Instead of using fixed values, Mamba updates them dynamically:
[math]\displaystyle{ h_t = A_t h_{t-1} + B_t x_t }[/math], [math]\displaystyle{ y_t = C_t h_t }[/math]
Now, A, B, and C depend on the input [math]\displaystyle{ x_t }[/math], making the model much more flexible and allowing it to adjust its memory as needed.
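A minimal, purely sequential sketch of this selective recurrence is shown below (the paper uses a hardware-aware parallel scan instead of a Python loop); the projection layers producing the input-dependent step size, B, and C, as well as the tensor shapes, are illustrative simplifications.
<pre>
import torch
import torch.nn.functional as F

d_model, d_state = 16, 8                         # illustrative sizes

to_B = torch.nn.Linear(d_model, d_state)         # B_t depends on the input
to_C = torch.nn.Linear(d_model, d_state)         # C_t depends on the input
to_delta = torch.nn.Linear(d_model, d_model)     # per-channel step size depends on the input
A = -torch.rand(d_model, d_state)                # fixed negative values give a stable decay

def selective_scan(x):                           # x: (T, d_model), one sequence
    h = torch.zeros(d_model, d_state)
    ys = []
    for x_t in x:                                # sequential reference loop
        delta = F.softplus(to_delta(x_t))                         # step size > 0, shape (d_model,)
        A_bar = torch.exp(delta[:, None] * A)                     # discretized transition A_t
        B_t, C_t = to_B(x_t), to_C(x_t)                           # input-dependent B_t and C_t
        h = A_bar * h + (delta * x_t)[:, None] * B_t[None, :]     # h_t = A_t h_{t-1} + B_t x_t
        ys.append(h @ C_t)                                        # y_t = C_t h_t, shape (d_model,)
    return torch.stack(ys)                       # (T, d_model)
</pre>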
Also, Mamba uses a hardware-aware scan algorithm, which makes its computations 20-45x faster than older methods and 4-5x faster than transformers. This means it can handle longer sequences without slowing down.
Experiment Results
Mamba has been tested in different fields:
- Language modeling: Performs as well as transformers but is more efficient.
- DNA sequencing: Works well on long sequences, beating many existing models.
- Audio processing: Outperforms other SSM models for speech tasks.
Limitations of Mamba
Mamba is promising, but it has a few challenges:
1. Scaling up: It hasn’t been tested on extremely large models yet, so we don’t know if it can compete with transformers at that level.
2. Trade-offs in memory selection: Choosing what to remember and forget works well for things like text but might not be as useful for continuous data like speech.
3. Lack of a mature ecosystem: Transformers have been around longer, so they have more tools and techniques available. Mamba still has to catch up.
Even with these issues, Mamba is an exciting step forward for sequence modeling.
Group 4 Presentation: Learning spatiotemporal dynamics with a pretrained generative model
Presented by:
- Karolina Suszek
- Negin Amou
- Muhammad Azeem
Paper Citation
Z. Li et al., “Learning spatiotemporal dynamics with a pretrained generative model,” Nature Machine Intelligence, vol. 6, no. 12. Springer Science and Business Media LLC, pp. 1566–1579, Dec. 06, 2024. doi: 10.1038/s42256-024-00938-z.
Background
- Spatiotemporal dynamics: how the state of a physical system varies with space and time
- Real datasets often contain data with sparse measurements where there are a limited number of sensors available. There needs to be a way to convert the sparse measurement data into a full spatiotemporal field.
- Existing solutions learn to map the input to output and ignore missing data, but this reduces the model's ability to generalize.
- The paper proposes the Sparse-Sensor-Assisted Score-Based Generative Model (S3GM), which uses unlabeled data during training and can reconstruct incomplete data after training, making accurate predictions even when little information is available.
- Key Idea: Learn the probability distribution of spatiotemporal data using a score-based generative model and refine the samples via stochastic sampling
Technical Contributions
The main proposed model is the Sparse-Sensor-Assisted Score-Based Generative Model (S3GM). It learns patterns from a large amount of data beforehand. It is also unsupervised, so it does not require any labels during training; it tries to learn the significant features and natural patterns of the data. After training, the model can take incomplete data and reconstruct the missing parts to make predictions.
Core Components:
- Pre-training Stage: Learns the joint probability distribution of the data
- Generating Stage: Uses a stochastic differential equation to refine and generate full-field predictions
- Refinement Mechanism: Ensures alignment with observations and enforces sequence consistency
Some of the common applications of this model are Turbulent flow modeling, climate forecasting, and physics-based simulations.
Summaries of key points
- Challenge Addressed: Traditional end-to-end learning models often struggle with generalization in reconstructing spatiotemporal dynamics, particularly when data is sparse—a common scenario in real-world applications.
- S³GM Methodology: Pretraining Phase: An unconditioned generative model is pretrained in a self-supervised manner on a comprehensive dataset, capturing the joint distribution of the system's dynamics. Generation Phase: The pretrained model is conditioned on new, sparse measurements to reconstruct and predict the full-field spatiotemporal dynamics.
- Validation and Performance: S³GM's efficacy was tested across multiple dynamical systems using synthetic, real-world, and laboratory datasets, including applications in turbulent flow modelling and weather forecasting. The results demonstrated that S³GM achieves high accuracy, generalizability, and robustness, even when faced with significant data sparsity and noise.
S³GM offers a promising approach for modeling and predicting complex spatiotemporal dynamics in situations where data is limited, leveraging the strengths of pretrained generative models to enhance performance in small data regimes.
Related Works
Some of the related works in this area are GPT-ST: Generative Pre-Training of Spatio-Temporal Graph Neural Networks. This framework employs a spatio-temporal masked autoencoder designed to capture both intra- and inter-cluster region semantic relationships, which are often overlooked in existing approaches. Another is Spatio-Temporal Few-Shot Learning via Diffusive Neural Network Generation, a generative pre-training framework (GPD) that addresses data scarcity in spatiotemporal modeling. By performing generative pre-training on neural network parameters optimized with data from source cities, the framework enables the generation of tailored neural networks guided by prompts.
Group 4 Presentation: Learning spatiotemporal dynamics with a pretrained generative model
Presented by:
- Karolina Suszek
- Negin Amou
- Muhammad Azeem
Paper Citation
Z. Li et al., “Learning spatiotemporal dynamics with a pretrained generative model,” Nature Machine Intelligence, vol. 6, no. 12. Springer Science and Business Media LLC, pp. 1566–1579, Dec. 06, 2024. doi: 10.1038/s42256-024-00938-z.
Summaries of key points
Spatiotemporal Dynamics and S3GM
Spatiotemporal dynamics describe physical systems (e.g., climate forecasting, fluid dynamics) that evolve over space and time. However, in real-world situations, sensor data is usually incomplete: weather stations may cover only a few locations, and sensors might capture only the magnitude of velocity, missing its direction. Standard deep learning models (FNO, PINN, U-Net, LNO, DeepONet) often struggle to adapt if the sensor setup changes or if the system behaves unexpectedly. This lack of generalization means models often need to be retrained for each new situation. To overcome this, S3GM is proposed.
How Does S3GM Work?
Instead of learning the full probability distribution [math]\displaystyle{ p(x) }[/math] (which is complex), S3GM learns the gradient of the data distribution, called the score function: [math]\displaystyle{ s(x) = \nabla_x \log p(x) }[/math]. This tells the model which direction the data is most likely to change, making learning more efficient.
It should also be noted that real-world data is often messy—noisy, incomplete, or missing. S3GM handles this using denoising score matching (DSM):
- Adds a small amount of noise to the data and trains the model to remove it.
- Forces the model to focus on true underlying patterns rather than memorizing raw data.
By repeatedly removing noise, the model deeply understands the true data structure—even when parts are missing.
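A minimal sketch of a denoising score matching objective is given below; the network architecture, the flattened toy input size, and the single noise level are illustrative assumptions (score-based models typically train over a range of noise levels).
<pre>
import torch

# Hypothetical score network s_theta(x_noisy, sigma); sigma is appended as an extra input
score_net = torch.nn.Sequential(
    torch.nn.Linear(65, 128), torch.nn.SiLU(), torch.nn.Linear(128, 64)
)

def dsm_loss(x, sigma=0.1):
    """Denoising score matching on a toy batch x of shape (batch, 64)."""
    noise = torch.randn_like(x)
    x_noisy = x + sigma * noise
    target = -noise / sigma                                   # score of the Gaussian perturbation kernel
    inp = torch.cat([x_noisy, torch.full((x.shape[0], 1), sigma)], dim=-1)
    pred = score_net(inp)
    return torch.mean((sigma * pred - sigma * target) ** 2)   # sigma-weighted DSM loss
</pre>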
Once the model learns the data’s structure, it reconstructs missing information using a stochastic differential equation (SDE) whose update has three terms (a minimal sketch follows after this list):
- Drift Term: guides reconstruction toward likely states.
- Diffusion Term: adds controlled randomness to explore multiple solutions.
- Correction Term: uses real sensor data to ensure consistency.
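A highly simplified sketch of one reverse-SDE update with an observation-consistency correction appears below; the step size, the guidance weight, and the way sparse sensors enter through a mask are illustrative assumptions and not the paper's exact sampler.
<pre>
import torch

def reverse_sde_step(x, score_net, sigma, obs, mask, dt=0.01, guidance=1.0):
    """One Euler-Maruyama step of the reverse SDE with a sparse-observation correction.

    x: current estimate of the field; obs: sparse measurements; mask: 1 where a sensor exists.
    """
    drift = sigma ** 2 * score_net(x) * dt                                  # drift toward likely states
    diffusion = sigma * torch.sqrt(torch.tensor(dt)) * torch.randn_like(x)  # controlled randomness
    correction = guidance * mask * (obs - x) * dt                           # keep consistency with sensor data
    return x + drift + diffusion + correction
</pre>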
How Well Does S3GM Work?
S3GM is tested on four different systems to see how well it reconstructs and predicts missing data:
Experiment 1: Predicting Chaotic Behavior (Kuramoto-Sivashinsky Equation)
- Challenges Tested:
- Sparse Spatial Data (few sensor readings)
- Fourier Transform Domain (frequency-based measurements)
- Limited Initial Data (predict future states with few frames)
- Results:
- S3GM outperformed U-Net, FNO, and DeepONet, achieving lower errors.
- Stable even with limited input data.
Experiment 2: Reconstructing Turbulent Flow (Kolmogorov Flow)
- Challenges Tested:
- Trained on low-turbulence data.
- Tested on high-turbulence data.
- Results:
- Accurately reconstructed velocity fields and vorticity patterns.
Experiment 3: Climate Data Reconstruction (ERA5 Dataset)
- Challenges Tested:
- Extreme Data Sparsity (only 1% wind speed measurements available).
- Hidden Variables (missing temperature and pressure).
- Noisy Measurements (Gaussian noise added).
- Results:
- Successfully reconstructed missing climate variables.
- Performance improved with more sensor data.
Experiment 4: Flow Around a Cylinder
- Challenges Tested:
- Spatiotemporal Gaps (only specific cross-sections measured).
- Time-Averaged Data (some measurements were only available as averages).
- Results:
- Accurately reconstructed instantaneous and time-averaged flow fields.
- Outperformed physics-informed neural networks (PINNs).
Limitations of S3GM
While powerful, S3GM has limitations:
- Computational Cost: Pre-training is resource-intensive.
- Data Quality Dependence: Best performance with diverse, high-quality data.
- Generalization Issues: May struggle with entirely new dynamics.
- Processing Speed: Iterative reconstruction can be slower than traditional methods.
Despite these challenges, S3GM is a promising tool, and if these are improved, it could be even more powerful.
Group 5 Presentation: Griffin: Mixing Gated Linear Recurrences with Local Attention for Efficient Language Models
Presented by:
Guoqian Li, Xinlei Xu, Wenning Xu
Paper Citation
De, S., Smith, S., Fernando, A., Botev, A., Cristian-Muraru, G., Gu, A., Haroun, R., Berrada, L., Chen, Y., Srinivasan, S., Desjardins, G., Doucet, A., Budden, D., Teh, Y. W., Pascanu, R., De Freitas, N., Gulcehre, C. (2024). Griffin: Mixing Gated Linear Recurrences with Local Attention for Efficient Language Models. arXiv. arxiv.org/pdf/2402.19427
Background
RNNs and transformers
Recurrent neural networks (RNNs) are a basic architecture for handling sequential data; they are good for short sequences but can struggle with long ones.
Transformers generally perform better than RNNs and have become dominant in recent years. However, for long sequences they become computationally expensive for a couple of reasons: their global attention mechanism has quadratic complexity, and the key-value (multi-query attention) cache grows linearly with sequence length.
The paper proposes several innovations to improve upon existing attention mechanisms in neural networks:
- RG-LRU layer: a novel gated linear recurrent layer designed to replace Multi-Query Attention (MQA)
- Hawk: integrate multi-layer perceptrons (MLPs) with recurrent blocks
- Griffin: combine MLPs with a mix of recurrent blocks and local attention mechanisms
Technical Contributions
Model architecture
Their proposed models -- Hawk and Griffin -- use the following structures:
- Residual block: This is the main structure in the architecture. It starts with an RMSNorm layer applied to the hidden state, followed by a temporal mixing block and a residual connection. Next, there is another RMSNorm layer followed by an MLP (multi-layer perceptron) block.
- Gated MLP block: This is a gated feedforward block. There are 2 branches: one is linear and one uses a GeLU activation. This part of the structure is used for feature selection.
- Temporal mixing block: They used 3 different kinds of temporal mixing blocks in their models: global multi-query attention, local sliding window attention, and recurrent blocks (what they proposed). The recurrent block contains 2 parallel branches. One has a GeLU activation, while the other has a temporal 1D convolutional layer followed by a RG-LRU layer (a proposal in this paper).
Performance and Efficiency
- Hawk surpasses the reported performance of Mamba on downstream tasks.
- Griffin matches the performance of Llama-2, despite being trained on over six times fewer tokens.
- Both models demonstrate the ability to extrapolate to sequences significantly longer than those encountered during training.
Real-Gated Linear Recurrent Unit (RG-LRU)
RG-LRUs are inspired by regular Linear Recurrent Units (LRUs), but with a gating mechanism. They are more computationally efficient, as they avoid the quadratic complexity of transformers. Mathematically, the layer is defined as follows.
The recurrence gate:
[math]\displaystyle{ r_t = \sigma (W_a x_t + b_a) }[/math]
The input gate:
[math]\displaystyle{ i_t = \sigma (W_x x_t + b_x) }[/math]
The recurrence weight, where [math]\displaystyle{ a }[/math] is derived from a learnable parameter and [math]\displaystyle{ c }[/math] is a fixed scalar constant:
[math]\displaystyle{ a_t = a^{c r_t} }[/math]
The output:
[math]\displaystyle{ h_t = a_t \odot h_{t-1} + \sqrt{1-a_t^2} \odot (i_t \odot x_t) }[/math]
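These equations translate directly into a per-timestep update. A minimal sketch follows, assuming that a is obtained from a learnable parameter through a sigmoid and that c is a fixed scalar constant; the width and the value of c here are illustrative.
<pre>
import torch

d = 64                                       # illustrative model width
W_a = torch.nn.Linear(d, d)                  # recurrence-gate projection
W_x = torch.nn.Linear(d, d)                  # input-gate projection
Lambda = torch.nn.Parameter(torch.randn(d))  # learnable parameter behind the recurrent weight a
c = 8.0                                      # fixed scalar constant (illustrative value)

def rg_lru_step(x_t, h_prev):
    r_t = torch.sigmoid(W_a(x_t))                                  # recurrence gate
    i_t = torch.sigmoid(W_x(x_t))                                  # input gate
    a = torch.sigmoid(Lambda)                                      # 0 < a < 1, diagonal recurrent weight
    a_t = a ** (c * r_t)                                           # a_t = a^{c r_t}
    h_t = a_t * h_prev + torch.sqrt(1 - a_t ** 2) * (i_t * x_t)    # gated update
    return h_t
</pre>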
Summaries of Key Points
Context and Motivation
Transformers have dominated large language modeling, but their attention mechanism can be expensive for very long sequences, particularly because of the quadratic complexity of attention and a Key-Value (KV) cache that grows linearly in sequence length. Recurrent neural networks (RNNs) compress entire contexts into a fixed-size hidden state and avoid storing a cache that grows with length, which can reduce memory and speed up inference. Historically, though, RNNs have been difficult to train at scale and often underperform Transformers on complex tasks. This paper proposes new recurrent-based architectures that can match or exceed Transformer performance while retaining the training efficiency of standard deep-learning pipelines and gaining a significant advantage during inference.
Proposed Architectures
Two models are introduced: Hawk, a purely recurrent architecture, and Griffin, which mixes gated linear recurrences with local attention. Both rely on a new Real-Gated Linear Recurrent Unit (RG-LRU) layer, which extends a previously studied linear recurrence (the LRU) by adding two gating mechanisms. One gate modulates how much of the new input to incorporate each step, and another governs whether the layer updates in a standard LRU style or simply retains the previous hidden state.
RG-LRU and Recurrent Blocks
RG-LRU’s diagonal recurrence matrix dramatically lowers the computational cost for large sequence lengths. The authors also include a lightweight convolution for near-token interactions. For Griffin, they interleave RG-LRU-based blocks with a local sliding-window attention, allowing short-range local attention plus a global recurrent state. This combination helps the model handle both local details and long-distance dependencies in a memory-efficient way.
Empirical Results
The authors scale Hawk and Griffin to billions of parameters and compare them with:
1. A baseline Transformer using Multi-Query Attention (MQA),
2. The Mamba recurrent model (from prior work),
3. Llama-2.
Their main findings include:
- Both Hawk and Griffin follow standard power-law scaling in validation loss versus training compute, matching or beating Transformers.
- Hawk matches or exceeds previous recurrent models like Mamba on language modeling benchmarks, despite training on fewer tokens.
- Griffin matches Llama-2’s accuracy on multiple downstream NLP tasks while training on only about one-seventh the token count.
- During training, these models achieve speeds comparable to Transformers (thanks to parallelized or fused kernels), and in some configurations are faster at long sequence lengths (2048 tokens or more).
- At inference time, recurrent and local-attention blocks avoid the large KV caches that hamper Transformers. Hence, Hawk and Griffin show lower latency and can decode tokens at significantly higher throughput, especially for long sequences.
- Hawk and Griffin extrapolate to far longer sequences than they were trained on. They also efficiently learn synthetic copying/retrieval tasks if explicitly trained on them. However, their out-of-the-box retrieval or copying for tasks not seen at training time is still below that of Transformers.
Significance
By mixing local attention with a novel gated recurrence, Griffin retains the best of both recurrent and attention-based approaches: it can model local context with a sliding window while maintaining a small, fixed-size global state for the entire sequence. This achieves strong performance on large-scale language modeling benchmarks while offering major advantages in memory footprint and decoding speed. The results position Hawk and Griffin as compelling alternatives to purely Transformer-based architectures for scalable language models.
Related work and Future direction
Griffin builds on several key areas in efficient language modelling, particularly in recurrent architectures and hybrid approaches combining recurrence and attention: State Space Models (SSMs), Hybrid Recurrence-Attention Models, Efficient Transformer Alternatives.
Scaling Griffin to Larger Models
The paper discusses Griffin models up to 14B parameters but suggests further exploration into scaling beyond this size to compete with GPT-4 or Gemini models. Investigating how Griffin handles long-context tasks in ultra-large models would be an interesting future study.
Memory and Efficiency Optimization
Griffin already improves efficiency compared to transformers, but further research into sparsification, quantization, and hardware-specific optimizations could enhance its real-world applicability. Exploring mixed-precision training and inference optimizations for mobile and edge devices.
Extending Griffin to Multimodal Learning
While Griffin is designed for language modeling, incorporating it into multimodal tasks (e.g., vision-language models, audio processing) could expand its impact. Combining Griffin’s recurrence mechanism with diffusion models or video understanding models might be a promising direction.
Group 5 Presentation: Griffin: Mixing Gated Linear Recurrences with Local Attention for Efficient Language Models
Presented by:
Guoqian Li, Xinlei Xu, Wenning Xu
Paper Citation
De, S., Smith, S., Fernando, A., Botev, A., Cristian-Muraru, G., Gu, A., Haroun, R., Berrada, L., Chen, Y., Srinivasan, S., Desjardins, G., Doucet, A., Budden, D., Teh, Y. W., Pascanu, R., De Freitas, N., Gulcehre, C. (2024). Griffin: Mixing Gated Linear Recurrences with Local Attention for Efficient Language Models. arXiv. arxiv.org/pdf/2402.19427
Summaries of Key Points
Motivation and background
Dominance of Transformers and their drawbacks: Transformers (with global attention) have driven most breakthroughs in large-scale language modeling. Despite their strengths, transformers struggle with quadratic complexity on long sequences; the large Key-Value (KV) cache grows linearly in the sequence length during inference, slowing down decoding for long inputs.
Potential of Recurrent Models: Recurrent neural networks (RNNs) offer an appealing alternative because they summarize the sequence with a fixed-size hidden state that is updated token by token (iteratively), thus avoiding a large KV cache. However, classical RNNs can be slow to train in practice and have struggled to match the performance of Transformers at large scales.
Goal: To develop recurrent architectures (and hybrids that combine local attention and recurrences) which (1) match or exceed Transformers’ performance on language modeling. (2) train efficiently at scale (both in FLOPs and hardware utilization). (3) offer improved inference efficiency (lower latency, higher throughput).
Architecture
Hawk (MLPs with recurrent blocks)
(1) Recurrent Block: Replaces self-attention entirely with a diagonal recurrent design. It processes tokens sequentially with fast iteration at inference time. (2) RG-LRU Layer: The crucial piece inside the Hawk model is the Real-Gated Linear Recurrent Unit (RG-LRU), a novel RNN layer that is stable, uses simple diagonal recurrences, and includes input and recurrence gates: Recurrence Gate [math]\displaystyle{ r_t }[/math] and input gate [math]\displaystyle{ i_t }[/math] each depend on the input [math]\displaystyle{ x_t }[/math] but not the recurrent state [math]\displaystyle{ h_{t-1} }[/math]. This yields stable, memory-friendly computations. The update equation mixes the previous hidden state [math]\displaystyle{ h_{t-1} }[/math] and a transformed input [math]\displaystyle{ x_t }[/math] using a diagonal recurrent weight [math]\displaystyle{ a }[/math]. This ensures the update is purely elementwise, minimizing overhead during inference. (3) Achieves strong performance on language modeling while avoiding the large KV cache burden of Transformers. Provides super-exponential memory (depending on the gating mechanism) and can, in practice, capture long-range dependencies.
Griffin (MLPs with a mixture of recurrent blocks and local attention)
(1) Motivation: Although recurrence alone scales well to long sequences, local patterns might still be easier to capture via (local) attention. Mixing the two yields a model that can leverage local attention for shorter-range patterns while relying on a fixed-size recurrent state for longer context. (2) Local Attention: Replaces global attention with a sliding-window approach (no quadratic cost in sequence length). The local window ensures that each token attends only to a fixed number of past tokens. (3) Interleaving Blocks. Griffin’s architecture interleaves the same RG-LRU-based recurrent block and a local attention block. This combination preserves efficiency (small recurrence state, bounded cache size), and matches or outperforms Transformers at the same scale.
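The local-attention half of this design can be illustrated with a simple sliding-window mask, sketched below with an arbitrary window size; this is not Griffin's actual implementation, just the masking idea.
<pre>
import torch

def sliding_window_mask(seq_len, window=4):
    """Boolean mask: position i may attend only to positions i-window+1 .. i (causal, local)."""
    idx = torch.arange(seq_len)
    rel = idx[:, None] - idx[None, :]          # how far in the past each key position is
    return (rel >= 0) & (rel < window)

# Illustrative usage with PyTorch's built-in attention (shapes: batch, heads, seq, head_dim)
q = k = v = torch.randn(1, 8, 16, 32)
mask = sliding_window_mask(16)
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)
</pre>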
Empirical Results and Main Findings
Power-Law Scaling
Hawk, Griffin, and a baseline Transformer (using multi-query attention, MQA) all exhibit a power-law relationship between validation loss and training FLOPs, resembling established Transformer scaling laws. Griffin consistently has slightly better loss–compute tradeoffs than the baseline Transformer at all scales tested.
Performance vs. Mamba and Llama-2
Hawk surpasses Mamba (a prior recurrent state-space model) on downstream tasks at 3B parameters despite training on half as many tokens (300B vs. 600B). Griffin (7B, 14B) matches or slightly outperforms Llama-2–7B/13B despite being trained on about seven times fewer tokens. Conclusion: The new architectures remain competitive on well-known benchmarks such as MMLU, HellaSwag, ARC, PIQA, and WinoGrande, even against popular open-source Transformers.
Implementation Challenges
RG-LRU’s diagonal recurrence is memory-bound on TPUs (and GPUs), as it performs elementwise updates rather than large matrix multiplications. The paper introduces a custom kernel (using Pallas on TPU-v3) that reads and writes hidden states in a linear scan and minimizes memory transfers, greatly speeding up training.
Speed Gains for Long Sequences
For short sequence lengths (e.g. 2K tokens), Griffin trains at a speed on par with the MQA baseline. As sequence lengths grow (4K, 8K), the MQA baseline slows down due to attention complexity, whereas the recurrent models maintain similar runtime. Hence, Hawk/Griffin gain a training-speed advantage at longer contexts.
Throughput and Latency Gains
Transformers (even MQA) rely on a KV cache growing linearly in sequence length. The overhead can become larger than the model parameters at long decode lengths. Hawk and Griffin keep a small, fixed-size recurrence state. This yields: (1) Lower latency: less reading/writing of large caches for extended contexts. (2) Higher throughput: bigger batch sizes fit into memory because the small recurrence state frees device memory otherwise used by attention KV caches.
Additional Observations (Long Context and Copy/Retrieve Tasks)
Extrapolation to Long Sequences
Hawk and Griffin can successfully exploit contexts longer than those used during training, showing better extrapolation than Transformers with typical positional embeddings (e.g. RoPE). Training them explicitly on 8K sequences (rather than 2K) further improves performance at longer contexts.
Copying/Retrieval Skills
Hawk alone may learn copy tasks but can sometimes lag behind Transformers in speed of learning. This matches prior observations about slow adaptation of purely recurrent state-space layers on synthetic tasks. Griffin (mixing local attention + recurrence) can learn copy tasks more quickly, similar to Transformers, and can extrapolate better to sequences beyond the training length.
Significance and Contributions
Challenging the Status Quo
This work demonstrates that RNN-based architectures can be trained at large scale as efficiently as Transformers, with the added benefit of small fixed-size hidden states.
High Performance at Fewer Tokens
Griffin achieves or exceeds Llama-2 performance while using substantially fewer training tokens, highlighting that it is not necessary to rely purely on full-sequence attention.
Practical Inference Advantage
Significantly reduced inference latency and improved throughput on long sequences is a major practical benefit in real-world applications when generating long responses and streaming data.
Blueprint for Hybrid Models
By combining local attention (for short-range patterns) with a gated linear recurrence (for stable, long-range capacity), Griffin strikes a balance that may inform future research in large language models beyond classical Transformers.
Group 6 Presentation: Learning to (Learning at Test Time): RNNs with Expressive Hidden States
Presented by:
Zhiyang Cheng and Pingchu Zhang
Paper Citation
Sun, Y., Li, X., Dalal, K., Xu, J., Vikram, A., Zhang, G., Dubois, Y., Chen, X., Wang, X., Koyejo, S., Hashimoto, T., & Guestrin, C. (2024). Learning to (Learn at Test Time): RNNs with Expressive Hidden States. arXiv. https://doi.org/10.48550/arXiv.2407.04620.
Summaries of key points
This paper revisits the traditional role of RNN hidden states, proposing that they can do more than just store information—they can enable learning at test time. The authors introduce a method where a hypernetwork dynamically generates the RNN’s weights as a function of its hidden state, allowing the model to adapt its behavior on the fly. This gives rise to what they call expressive hidden states, which encode both memory and the capacity to steer the model’s future updates. The approach effectively blurs the line between training and inference, treating test-time prediction as a form of continual adaptation. This results in stronger performance in settings like few-shot learning and online learning, where flexibility and rapid adaptation are crucial. Rather than relying on explicit optimization at test time (as in typical meta-learning setups), the RNN itself becomes the learner, continuously reshaping its internal dynamics based on the sequence it's processing.
While innovative, the method introduces nontrivial architectural and computational overhead. The use of a hypernetwork to produce weights at every time step means the model must manage a more complex parameter space and could become less scalable for long sequences or larger models. There's also the risk of instability, since small changes in the hidden state can lead to large changes in the generated weights. Regularization and careful design are needed to prevent the model from diverging. Another limitation is that while the paper shows strong performance on synthetic and controlled few-shot learning tasks, it doesn’t extensively benchmark on more complex natural language or real-world sequential data, leaving questions about generalization and practicality open.
Clear explanations to aid understanding
In a standard RNN, the weights are fixed during inference—you feed in tokens or sequence elements, and the hidden state updates based on those fixed rules. What this paper suggests is: what if the hidden state itself could influence the rules? So instead of always using the same weights, the RNN can generate new ones depending on what it's seen so far. This is done using a hypernetwork—a small neural network that outputs the weights for the main RNN. So as the RNN processes a sequence, it effectively reshapes its own behavior to fit the task or data distribution it's encountering. It’s like the RNN is learning while it’s making predictions, adapting in real-time to maximize performance without needing gradient descent at test time.
Group 6 Presentation: Learning to (Learn at Test Time): RNNs with Expressive Hidden States
Presented by:
- Pingchu Zhang
- Zhiyang Cheng
Paper Citation
Sun, Y., Li, X., Dalal, K., Xu, J., Vikram, A., Zhang, G., Dubois, Y., Chen, X., Wang, X., Koyejo, S., Hashimoto, T., & Guestrin, C. (2024). Learning to (Learn at Test Time): RNNs with Expressive Hidden States. arXiv. https://doi.org/10.48550/arXiv.2407.04620
Background
For modern RNNs, performance in long contexts is limited by the expressive power of their fixed-size hidden state. Hence, the authors introduce test-time training (TTT).
Technical Contributions
- Introduce TTT layers, where the hidden state is a model and the update rule is self-supervised learning, offering a new research direction.
- TTT-Linear, a simple implementation of TTT layers, outperforms Transformers and Mamba in evaluations.
- Improve the hardware efficiency of TTT layers through mini-batch TTT and the dual form, making TTT-Linear already a practical building block for LLMs.
Methodology
The key idea is to make the hidden state itself a model with weights, and the update rule a gradient step on the self-supervised loss. Then updating the hidden state on a test sequence is equivalent to training the model at test time.
Training a network with TTT layers
- Training the larger network forms the outer loop, while training the weights within each TTT layer forms the inner loop.
- TTT layers can replace RNN or self-attention layers in any network architecture. Training a network with TTT layers also works the same way as training any other language model.
Learning a self-supervised task for TTT
Add some outer-loop parameters to make this task learnable.
The input [math]\displaystyle{ x_t }[/math] is transformed using a learnable matrix [math]\displaystyle{ \theta_K }[/math] to create a projection [math]\displaystyle{ \tilde x_t = \theta_K x_t }[/math].
The reconstruction label is another low-rank projection [math]\displaystyle{ \theta_V x_t }[/math], which can differ from the input. Then we can create a test view [math]\displaystyle{ \theta_Q x_t }[/math].
Now the new self-supervised loss is [math]\displaystyle{ \ell(W; x_t) = \|f(\theta_K x_t; W) - \theta_V x_t\|^2 }[/math], and the output rule is modified to [math]\displaystyle{ z_t = f(\theta_Q x_t; W_t) }[/math].
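Putting these pieces together, a naive per-token TTT-Linear update can be sketched as follows, with f taken to be a plain linear map; the dimensions and inner-loop learning rate are illustrative assumptions.
<pre>
import torch

d = 32                                            # illustrative feature dimension
theta_K = torch.nn.Linear(d, d, bias=False)       # training view
theta_V = torch.nn.Linear(d, d, bias=False)       # reconstruction-label view
theta_Q = torch.nn.Linear(d, d, bias=False)       # test view
W = torch.zeros(d, d, requires_grad=True)         # hidden state: the weights of the inner model f

def ttt_step(x_t, W, lr=0.1):
    """One token of TTT-Linear: gradient step on the self-supervised loss, then emit z_t."""
    # l(W; x_t) = || f(theta_K x_t; W) - theta_V x_t ||^2 with f(v; W) = v @ W
    loss = ((theta_K(x_t) @ W - theta_V(x_t)) ** 2).sum()
    grad_W, = torch.autograd.grad(loss, W)
    W = (W - lr * grad_W).detach().requires_grad_(True)   # inner-loop update rule (hidden-state update)
    z_t = theta_Q(x_t) @ W                                # output rule z_t = f(theta_Q x_t; W_t)
    return z_t, W
</pre>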
Summaries of Key Points
Motivation: Compressing Long Context Efficiently
Traditional Transformers handle large contexts by storing every token in a Key-Value cache, which grows linearly with sequence length and makes inference complexity scale quadratically. Modern recurrent neural networks (RNNs) like Mamba sidestep storing the entire context by having a fixed-size hidden state, which leads to linear time complexity. However, RNNs often struggle to exploit very long contexts because the fixed-size hidden state must compress a large amount of information. The authors propose a new design, Test-Time Training (TTT), that treats the hidden state as a small learnable model trained via a self-supervised loss on each incoming token—even at test time.
TTT Layers: Hidden State as a Learner
The paper reframes any sequence-modeling layer as “a hidden state plus an update rule.” For TTT layers, the hidden state is itself a small parametric or nonparametric model f, and the update rule is a step of gradient descent (or other training procedure) on each new token. Thus, at every token step, the hidden state is updated by “training” f on a self-supervised objective. Concretely, one might define a corruption or partial view of the token and train the parametric model to reconstruct the hidden or relevant aspects of the token.
Two Main Instantiations: TTT-Linear and TTT-MLP
The authors propose TTT-Linear, where the learner f is a simple linear mapping plus optional layer normalization and a residual connection. They also propose TTT-MLP, which uses a two-layer MLP as its learner, offering a more expressive hidden state. Both can be integrated into existing RNN-based or Transformer-based architectures in place of the usual self-attention or simple RNN blocks. Like other RNN layers, TTT layers compress all historical tokens into a fixed-size hidden state—but the learner can be updated more flexibly via gradient steps each time a new token arrives.
Efficiency Enhancements
Naively computing a gradient step per token would be too slow. Two key ideas improve hardware utilization:
1. **Mini-batch TTT** processes a batch of b tokens at once to parallelize the internal gradient steps. Smaller b yields more gradient steps (and better expressiveness) but can slow performance.
2. **A “dual form”** for TTT-Linear and TTT-MLP reworks the update and output computations into larger matrix multiplications, ensuring that modern accelerators (GPUs, TPUs) can exploit efficient batched operations.
Empirical Results
On language-modeling benchmarks (the Pile and Books), TTT-Linear and TTT-MLP match or exceed strong baselines (Transformer and the modern RNN Mamba) across model scales (125M to 1.3B parameters). TTT-Linear typically does as well as Mamba in short context (2k tokens) but outperforms Mamba substantially in longer contexts (8k or 32k), demonstrating that the extra expressiveness helps exploit more tokens. TTT-MLP can be even more expressive at very long contexts but can be more memory intensive. The authors also show that TTT-Linear can train and infer efficiently in wall-clock time using a specialized GPU kernel, yielding near-constant inference latency as context grows (unlike the linear growth in Transformers).
Significance and Future Work
TTT recasts the hidden-state update in an RNN-like layer as explicitly training a miniature model at test time—essentially “learning to learn” from each incoming token. With further improvements in tasks (beyond simple reconstruction), hardware kernels, and more expressive hidden states, TTT-based architectures may offer a new path toward efficient, high-performing sequence models for extremely long contexts.
Group 7 Presentation: Attention with Markov: A Framework for Principled Analysis of Transformers via Markov Chains
Presented by:
Jonathan Gallagher and Mariya Anashkina
Paper Citation
A. V. Makkuva et al., “Attention with Markov: A Framework for Principled Analysis of Transformers via Markov Chains,” 2024, arXiv. doi: 10.48550/ARXIV.2402.04161.
Summaries of key points
This paper offers a theoretical framework for analyzing transformers using Markov chains, providing a new way to understand how information flows and dependencies are formed across layers. The core idea is to view the self-attention mechanism as inducing a Markov process over the sequence positions, where each attention head defines a transition probability matrix. By interpreting attention through this lens, the authors derive principled explanations for phenomena like context aggregation, token mixing, and the effects of multiple layers and heads. They show that stacking layers corresponds to composing Markov transitions, meaning that deeper transformers can be understood as performing longer-range probabilistic walks over the input. This allows them to make quantitative predictions about attention spread and mixing time, and helps formalize intuitions about how depth and attention head structure influence model behavior. Importantly, the framework applies without modifying the architecture—it’s purely analytical, making it a useful tool for understanding model internals in a mathematically grounded way.
The abstraction into Markov chains may oversimplify certain aspects of how attention operates—particularly when attention scores are data-dependent and influenced by non-linear operations (e.g., softmax and masking). The analysis tends to assume relatively clean or idealized cases, and may not fully capture the nuances of attention in real-world settings like language modeling with positional encoding or context-specific weighting. While the framework is insightful, it’s primarily useful for interpretability and analysis, not for improving model performance directly. Finally, the notion of interpretability here is mathematical, but not necessarily human-friendly—it gives you equations, not explanations in natural language.
Clear explanations to aid understanding
Think of each attention head as deciding where to “move” next in the sequence—like a random walker choosing the next word to look at. If you treat those attention weights as transition probabilities, then attention turns into a kind of Markov process, where information flows step-by-step based on those probabilities. Stacking layers is like walking further in the sequence graph, and having multiple heads is like having multiple walkers with different preferences. This view lets you talk about things like mixing time—how fast information spreads—or how different tokens influence each other over multiple layers. It’s a powerful way to bridge probabilistic modeling and deep learning, especially when trying to reason about what attention is doing beyond just visualizing heatmaps.
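A small sketch of this viewpoint (our own illustration, not code from the paper): the rows of a causal attention matrix are treated as transition probabilities, and composing layers corresponds to taking multi-step walks over positions.

```python
import numpy as np

# Illustrative sketch: treat a row-stochastic causal attention matrix as a Markov
# transition matrix over sequence positions; stacking layers composes transitions.

rng = np.random.default_rng(1)
n = 6                                    # sequence length = number of "states"
logits = rng.normal(size=(n, n))
causal = np.tril(np.ones((n, n), dtype=bool))
logits = np.where(causal, logits, -np.inf)
A = np.exp(logits)
A /= A.sum(axis=1, keepdims=True)        # each row sums to 1: a valid transition matrix

two_layers = A @ A                       # "two layers" = a two-step walk over positions
print(np.allclose(two_layers.sum(axis=1), 1.0))   # still a transition matrix

walk = np.linalg.matrix_power(A, 8)      # mixing: longer walks spread probability mass
print(np.round(walk[-1], 3))             # 8-step distribution starting from the last token
```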
Group 7 Presentation: Attention with Markov: A Framework for Principled Analysis of Transformers via Markov Chains
Presented by:
- Jonathan Gallagher
- Mariya Anashkina
Paper Citation
A. V. Makkuva et al., “Attention with Markov: A Framework for Principled Analysis of Transformers via Markov Chains,” 2024, arXiv. doi: 10.48550/ARXIV.2402.04161.
Background and overview
Markov chains
Markov chains are probability models that predict the future state of a system based on its current (and possibly a few prior) states. Language can be modeled as a high-order Markov process, an idea popularized by Claude Shannon in his 1948 paper that later became a building block for modern natural language processing. Next-token prediction in transformers can therefore be analyzed through the lens of Markov chains.
In this paper, the authors use a simplified model: a two-state (binary) Markov chain that is in its stationary distribution, making the assumption that the Markov chain has been running for an extremely long time. The term stationary distribution refers to a distribution of states that a chain converges towards as the number of steps increases indefinitely; at that point, each subsequent distribution is the same as the previous one.
Objectives in this paper
The paper introduces a framework to systematically analyze (through the lens of a Markov chain) how transformers learn to model sequential data. The objective is to explore how a transformer succeeds or struggles on first order Markov chains, a distribution that only needs one step of memory. They also do a theoretical analysis of network architecture choices for transformers, and look at how this affects the loss landscape and learning dynamics.
Weight tying
Weight tying is one of the main transformer architectural choices studied in this paper: the practice of using the same weight matrix for the input (embedding) and output (unembedding) layers. Firstly, this means that whether the model is reading or generating text, the token representations are the same, which keeps them consistent. Secondly, it can act as a form of regularization, since it gives the model fewer parameters to learn.
Theoretical Results
Significance of [math]\displaystyle{ p+q }[/math]
Recall that we are working with binary, first-order Markov processes here. Let [math]\displaystyle{ p }[/math] be the probability that a state 0 will turn to 1, and let [math]\displaystyle{ q }[/math] be the probability that a state 1 will turn to 0. Consequently, probabilities [math]\displaystyle{ 1-p }[/math] and [math]\displaystyle{ 1-q }[/math] are the probabilities that states 0 and 1 will remain unchanged, respectively.
The quantity [math]\displaystyle{ p+q }[/math] is referred to as a switching factor, and is used to characterize the overall tendency of the chain to change state. When [math]\displaystyle{ p+q \lt 1 }[/math], the system is likely to stay in its current state. When [math]\displaystyle{ p+q \gt 1 }[/math], the system is likely to change its state, and therefore will exhibit oscillatory behaviour.
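A quick simulation (our own sketch, not from the paper) makes the switching factor tangible: it checks the stationary distribution [math]\displaystyle{ \pi = (q/(p+q),\, p/(p+q)) }[/math] empirically and shows that a chain with [math]\displaystyle{ p+q \gt 1 }[/math] switches state far more often than one with [math]\displaystyle{ p+q \lt 1 }[/math].

```python
import numpy as np

def simulate(p, q, steps=100_000, seed=0):
    """Simulate the binary chain with flip probabilities p (0 -> 1) and q (1 -> 0)."""
    rng = np.random.default_rng(seed)
    P = np.array([[1 - p, p],
                  [q, 1 - q]])
    state, counts, switches = 0, np.zeros(2), 0
    for _ in range(steps):
        nxt = rng.choice(2, p=P[state])
        switches += int(nxt != state)
        counts[nxt] += 1
        state = nxt
    return counts / steps, switches / steps

for p, q in [(0.2, 0.3), (0.8, 0.9)]:        # p+q < 1 (sticky) vs p+q > 1 (oscillatory)
    empirical, switch_rate = simulate(p, q)
    theory = np.array([q, p]) / (p + q)      # stationary distribution pi
    print(f"p+q={p + q:.1f}: empirical pi~{empirical.round(3)}, "
          f"theory pi={theory.round(3)}, switch rate={switch_rate:.3f}")
```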
The authors of the paper provide proof of the existence of the global minimum as well as bad local minima in the loss landscape. Furthermore, they look at the existence of global minima and saddle points in the cases with vs. without weight tying.
Theorem 1 (Global minimum). Let the input sequence be [math]\displaystyle{ \{x_n\}_{n=1}^N \sim \bigl(\pi(p,q), P(p,q)\bigr) }[/math] for some fixed [math]\displaystyle{ (p,q)\in(0,1)^2. }[/math] Then for all [math]\displaystyle{ (p,q), }[/math] there exists a [math]\displaystyle{ \theta_{\star}\in\mathbb{R}^{D-d} }[/math] with an explicit construction such that it is a global minimum for the population loss [math]\displaystyle{ L(\cdot) }[/math].
In other words, the authors are able to prove that for the transformer, there exists an optimal parameter configuration.
Theorem 2 (Bad local minimum). Let the input sequence be [math]\displaystyle{ \{x_n\}_{n=1}^N \sim \bigl(\pi(p,q), P(p,q)\bigr) }[/math] for some fixed [math]\displaystyle{ (p,q)\in(0,1)^2. }[/math] If [math]\displaystyle{ p+q\gt 1, }[/math] there exists an explicit [math]\displaystyle{ \theta_{\pi}\in\mathbb{R}^{D-d} }[/math] such that it is a bad local minimum for the loss [math]\displaystyle{ L(\cdot) }[/math]
Theorem 3 (Global minimum). Consider the same setting as in Thm.~1. Then for all [math]\displaystyle{ (p, q), }[/math] if [math]\displaystyle{ \theta_{\star} = (e_{\star} = a_{\star}, \dots, b_{\star}) \in \mathbb{R}^{D-d} }[/math] is a global minimum for the loss [math]\displaystyle{ L(\cdot) }[/math] in the weight-tied scenario, then its extension [math]\displaystyle{ \bar{\theta}_{\star} = (\bar{e}_{\star}, \bar{a}_{\star}) \in \mathbb{R}^D }[/math] is also a global minimum for [math]\displaystyle{ L(\cdot) }[/math] in [math]\displaystyle{ \mathbb{R}^D }[/math] in the non-weight-tied case. Further, [math]\displaystyle{ \bar{\theta}_{\star} }[/math] satisfies the same properties (ii)--(iv) as in Thm.~1.
Theorem 4 (Saddle point). Consider the same setting as in Thm.~3. For [math]\displaystyle{ p + q \gt 1, }[/math] let [math]\displaystyle{ \theta_{\pi} = (e_{\pi} = a_{\pi}, \dots, b_{\pi}) \in \mathbb{R}^{D-d} }[/math] be the corresponding bad local minimum for the loss [math]\displaystyle{ L(\cdot) }[/math] in the weight-tied scenario. Then its extension [math]\displaystyle{ \bar{\theta}_{\pi} = (\bar{e}_{\pi}, \bar{a}_{\pi}) \in \mathbb{R}^D }[/math] is a saddle point for [math]\displaystyle{ L(\cdot) }[/math] in [math]\displaystyle{ \mathbb{R}^D }[/math] in the non-weight-tied case. Further, [math]\displaystyle{ \bar{\theta}_{\pi} }[/math] satisfies the same properties (ii)--(iv) as in Thm.~2.
Main Contributions
Setup for the theoretical framework in this paper
- The dataset is first-order and binary.
- The model is a single-layer transformer, with a single attention head and no layer norm.
- The loss function is cross-entropy population loss. Note: the transformer model doesn't know that the data is Markovian; in other words, it is allowed to use the entire sequence's history, as it doesn't know that each step is dependent only on the previous one.
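To make the contrast in the theorems concrete, the following sketch (our own illustration, not the paper's code) compares the population cross-entropy of a predictor that uses the true transition probabilities (the global minimum) with one that always outputs the stationary distribution (the candidate bad local minimum).

```python
import numpy as np

def entropy(probs):
    probs = np.asarray(probs, dtype=float)
    return float(-(probs * np.log(probs)).sum())

def population_losses(p, q):
    pi = np.array([q, p]) / (p + q)                       # stationary distribution
    # Predictor that outputs the true transition probabilities: loss = entropy rate.
    loss_transitions = pi[0] * entropy([1 - p, p]) + pi[1] * entropy([q, 1 - q])
    # Predictor that always outputs pi: loss = entropy of the stationary marginal.
    loss_stationary = entropy(pi)
    return loss_transitions, loss_stationary

for p, q in [(0.3, 0.4), (0.9, 0.8)]:                     # p+q < 1 and p+q > 1
    lt, ls = population_losses(p, q)
    print(f"p={p}, q={q}: transition-aware loss={lt:.4f}, stationary-only loss={ls:.4f}")
```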
Empirical results
Summary of this paper
1. The authors introduce a theoretical framework that models the data source as a Markov process. This allows them to study how transformers learn sequential structure, contrasting it with other approaches that treat training data simply as i.i.d.
2. Focusing on single-layer transformers trained for next-token prediction on first-order Markov data, the paper gives a detailed analysis of the cross-entropy loss surface. There is always a set of parameters that perfectly recovers the true Markov transition probabilities (thus achieving the global minimum). However, when the sum of the Markov chain's flipping probabilities exceeds one, the model can converge to parameters that simply predict the stationary distribution rather than the true transition probabilities. This phenomenon does not arise (or becomes only a saddle point) when the weights are untied or when the transformer has multiple layers.
3. Through experiments, they show that the theoretical findings match empirical behaviors. In particular: When weights are tied, the model may learn a constant (stationary) probability and fail to leverage sequential context if the transition probabilities are above a certain threshold. Removing weight tying—or increasing the transformer's depth—helps avoid such bad local minima.
4. The authors extend the analysis to higher-order processes. Surprisingly, simply increasing the transformer depth does not guarantee learning of higher-order transitions. They find, however, that restricting the attention window (rather than letting the network attend to all past tokens) dramatically improves learning of higher-order Markov patterns.
Related work
1. Analyzing Transformer Architectures:
Vaswani et al. (2017) introduced the Transformer architecture, which has since been extensively studied to understand its theoretical foundations and practical applications.
Rogers et al. (2020) provided a comprehensive analysis of BERT, a model based on the Transformer architecture, discussing its linguistic capabilities and limitations.
2. Combining Attention Mechanisms with Probabilistic Models:
Bahdanau et al. (2015) introduced the attention mechanism in neural machine translation, allowing models to focus on relevant parts of the input sequence dynamically.
Graves et al. (2014) proposed the Neural Turing Machine, combining neural networks with external memory, enabling models to learn algorithmic tasks.
Future direction
1. Enhanced Interpretability of Transformer Models:
Applying the Markov chain framework to dissect and visualize the internal workings of Transformers could lead to more interpretable models, aiding in identifying and mitigating biases.
2. Development of Hybrid Models:
Integrating Markovian structures with attention mechanisms may result in hybrid models that leverage the strengths of both approaches, potentially improving performance in tasks requiring sequential dependencies.
3. Theoretical Analysis of Attention Dynamics:
Further exploration into the mathematical properties of attention modelled as Markov processes could provide deeper insights into the stability and convergence of Transformer-based models.
Group 7 Review: An Analysis of Transformers via Markov Chains
Presented by:
Jonathan Gallagher and Mariya Anashkina
Paper Citation
A. V. Makkuva et al., “Attention with Markov: A Framework for Principled Analysis of Transformers via Markov Chains,” 2024, arXiv. doi: 10.48550/ARXIV.2402.04161.
Review
I thought this presentation made it really easy to see the connection between Markov chains and transformers when tasked with predicting future tokens. The lunch menu analogy made it easy to see how a first-order Markov chain behaves, and was a great general idea to discuss before diving deeper into how Markov chains behave in the language setting. Theoretical results were also well explained before linking them to empirical results and discussing why the empirical results behave the way they do.
The empirical results were also well explained, showing how removing weight tying or increasing model depth can help escape bad local minima and ensure the transformer's parameters reach the global minimum for first-order Markov chains, and how smaller context windows are required to escape bad local minima for higher-order Markov chains, regardless of depth or weight tying.
I also thought this presentation was really well structured, and found that the slides were really easy to follow with great visual aids.
Group 8 Presentation: MEDUSA: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads
Presented by:
Nana Ye and Xingjian Zhou
Summaries of key points
This paper introduces MEDUSA, a lightweight yet effective method for accelerating inference in large language models by attaching multiple decoding heads on top of the model's final hidden states. The key idea is that, during inference, these heads predict several future tokens in parallel, reducing the number of sequential steps needed. Unlike speculative decoding (which relies on a separate draft model), MEDUSA uses the same model and architecture, just augmented with extra heads that predict future tokens from the backbone's hidden states. These predictions are verified by the base model in a final pass, similar in spirit to Draft & Verify, but with much lower overhead and implementation complexity. Despite its simplicity, MEDUSA achieves competitive speedups of roughly 2-3x on LLaMA-family models such as Vicuna, and integrates easily into existing transformer pipelines. It also preserves generation quality well, maintaining high accuracy across benchmarks without retraining the backbone.
One potential limitation of MEDUSA is that its performance gains depend on the quality of the heads' speculative predictions: if the extra heads are not accurate enough, the method may yield minimal speedup or introduce verification bottlenecks. Another concern is scalability: adding too many decoding heads could increase memory consumption or introduce architectural clutter. While the paper shows good results on standard benchmarks, it's less clear how MEDUSA performs in more complex decoding scenarios like beam search, sampling with temperature, or instruction-tuned models. Finally, although it's simple to implement, any modification to production LLM inference stacks still carries deployment costs, which MEDUSA somewhat underplays.
Clear explanations to aid understanding
Think of a transformer generating text one word at a time: slow, because each step waits on the previous one. MEDUSA asks, what if we could guess several tokens ahead in a single step? It adds small prediction heads on top of the model's last hidden states, each trying to guess a token further into the future than the usual next-token head. Once these guesses are made, the base model verifies them in one pass. If correct, we skip ahead; if not, we fall back. It's like speculative decoding, but self-contained: no second model, no complicated setup. You get the parallelism of speculative methods with the simplicity of just tweaking the model's architecture slightly.
Group 8 Presentation: MEDUSA: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads
Presented by:
- Nana Ye
- Xingjian Zhou
Paper Citation
T. Cai et al., “Medusa: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads,” 2024, arXiv. doi: 10.48550/ARXIV.2401.10774.
https://arxiv.org/abs/2401.10774
Background
- As the size of LLMs grows, the speed at which they can generate tokens decreases. The bottleneck is primarily the transfer of data to/from the GPU
- Speculative sampling is an existing solution that predicts multiple future tokens at once using smaller "draft" models
- Medusa instead solves this problem by adding multiple decoding heads and a tree-based attention mechanism to existing LLMs
- The paper discusses the implementations of Medusa 1 and Medusa 2
Main Idea
The idea is to replace the separate draft model with extra heads attached to the model itself, since the best representation of a model is the model itself. The multiple heads predict several future tokens at once to leverage parallelism, which makes decoding more efficient and provides more candidate tokens for the tree-based attention to choose from. The tree-based attention then simulates sequential generation by traversing a tree from top to bottom, where the top node is the initial word.
Methodology
During training, each Medusa head is optimized to predict a future token given the same input context. For example, head 3 is trained to predict token xt+3 using only the context up to xt. The objective function is a standard cross-entropy loss over the predicted token distributions.
To avoid error accumulation, the training data for deeper heads (e.g., t+4, t+5) is generated using gold sequences rather than model outputs.
The decoding tree is constructed with a fixed depth and branching factor corresponding to the number of Medusa heads. Unlike beam search, this tree is evaluated in parallel, and scoring is done using a lightweight attention mechanism applied to the hidden states of the candidate paths.
For candidate selection, a method called typical acceptance is introduced as a fast alternative to rejection sampling. It accepts candidates based on whether their token-level probabilities fall within a "typical" range, reducing the number of evaluations needed during decoding.
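A minimal sketch of what the Medusa heads look like and how Medusa-1-style training works (a simplification under the assumptions noted in the comments, not the authors' exact implementation):

```python
import torch
import torch.nn as nn

# Simplified sketch of Medusa-style heads: K extra heads sit on top of the backbone's
# last hidden states; head k is trained to predict the token k positions further ahead
# than the usual next-token target. Sizes and the head architecture are assumptions.

vocab, hidden, K = 1000, 256, 3

class MedusaHeads(nn.Module):
    def __init__(self, hidden, vocab, num_heads):
        super().__init__()
        self.heads = nn.ModuleList([
            nn.Sequential(nn.Linear(hidden, hidden), nn.SiLU(), nn.Linear(hidden, vocab))
            for _ in range(num_heads)
        ])

    def forward(self, h):                  # h: (batch, seq, hidden) from the frozen backbone
        return [head(h) for head in self.heads]

heads = MedusaHeads(hidden, vocab, K)
h = torch.randn(2, 10, hidden)             # stand-in for backbone hidden states
next_tokens = torch.randint(0, vocab, (2, 10))   # stand-in for gold next-token targets

# Medusa-1-style training: backbone frozen, each head trained with cross-entropy
# against the gold token shifted k extra steps into the future.
loss = torch.zeros(())
for k, logits in enumerate(heads(h), start=1):
    loss = loss + nn.functional.cross_entropy(
        logits[:, :-k].reshape(-1, vocab), next_tokens[:, k:].reshape(-1))
print(float(loss))
```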
Technical Contributions
Medusa 1:
- Uses a frozen pre-trained LLM and trains extra decoding heads on top
- Each additional decoding head predicts a token K time steps in the future
- Uses a probability loss function that scales based on the number of steps into the future
- Reduces memory usage because the backbone model is only used for hidden state extraction
- In simple terms, Medusa adds additional linear layers on top of the last hidden layer of the transformer output, which are trained to predict the tokens in future positions, rather than just the next token as a conventional transformer does in the typical auto-regressive manner.
Medusa 2:
- Fine tunes the LLM and trains the decoding heads at the same time.
- Encountered problems with high losses, switched to a two-stage training process:
- Stage 1: train only the Medusa heads (similar to Medusa 1)
- Stage2: Train both the backbone model and the medusa heads together
Tree Attention
- Tree attention is used to enable the heads predicting later tokens to include the additional context which may have been created by the earlier medusa heads in the pipeline
- This tree structure does not occur autoregressively, however
- The top predictions from each head are fed into the tree structure as candidate tokens
- An attention mask is used to ensure that each future-token prediction in the tree is based on prior tokens, not on tokens past the one being predicted
- Multiple future candidate tokens can be predicted with context-aware attention simultaneously
Self Distillation
- A dataset of prompts relevant to the desired model is created
- The full large language model generates outputs for these prompts in the usual autoregressive manner. These prompt-output pairs form a training dataset for the self-distillation step
- Medusa Heads are trained on the generated training dataset
Tree Construction
- Prune Less Promising Branches: Branches with a low probability of containing the next token are pruned from the tree of candidate tokens in tree attention, which reduces the computational cost of Medusa 2
- Select the Best Candidate: From the remaining typical candidates, the longest accepted prefix is chosen for the next decoding step
Empirical Evaluation
Experiments on various LLMs show consistent 2–3 times speedups in practice without harming output quality (assessed by GPT-4 and other metrics). The authors also include ablation studies on key design choices (number of heads, attention structure, sampling thresholds), confirming the effectiveness and generality of the proposed framework.
Constructive Critique
While Medusa demonstrates promising improvements in decoding speed and flexibility, a few limitations remain:
- Training Stability: Especially in Medusa-2, jointly fine-tuning heads and the base model requires a two-stage schedule with learning rate separation and warmup—indicating some instability.
- Model Complexity: Adding multiple Medusa heads and tree attention introduces architectural complexity, which may hinder adoption or reproducibility without careful engineering.
- No open-source code: As of the paper's release, there is no official implementation, which limits replication and community engagement.
Group 8 Presentation: MEDUSA: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads
Presented by:
- Nana Ye
- Xingjian Zhou
Paper Citation
T. Cai et al., “Medusa: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads,” 2024, arXiv. doi: 10.48550/ARXIV.2401.10774.
https://arxiv.org/abs/2401.10774
Summaries of key points
One of the main challenges with LLMs is their slow inference process. It's not that the GPU can't do the math fast enough; the bottleneck is memory bandwidth (the need to constantly move data between memory and the GPU). LLMs also generate text one token at a time using autoregressive decoding, where each token depends on the previous one, which leads to underutilization of the GPU.
One attempt to address this is Speculative Decoding, where a smaller "draft model" predicts multiple tokens at once, and the main model verifies them. While this speeds up generation, it requires maintaining a separate draft model, and it can lead to mismatches between the draft and main models, making integration into a system difficult.
So, Medusa is proposed to solve these issues.
How does Medusa work?
Core Idea of Medusa
Unlike speculative decoding, Medusa doesn’t need a separate draft model. Instead, it adds extra decoding heads to the same model, allowing it to predict multiple tokens at once (generating candidates). This reduces the dependency on previous tokens, making inference faster. Once multiple tokens are predicted, Medusa organizes these predictions into a tree structure, where each node represents a token. This structure helps Medusa evaluate several possible token sequences at once (processing candidates). This is called tree-based attention. After evaluating, Medusa picks the most likely token sequence and outputs the best one (accepting candidates).
Training Strategies of Medusa
Medusa has two training strategies to optimize its performance:
- Medusa 1: In this method, the original model (called the backbone) is frozen, meaning its parameters don’t change. Only the Medusa heads are trained. This saves computation and avoids the model forgetting what it learned originally, while improving inference speed by predicting multiple tokens at once.
- Medusa 2: In this approach, both the backbone and the Medusa heads are trained together. This boosts prediction accuracy, especially in larger models. The training starts with a two-stage process to prevent issues like high loss and large gradients from the new Medusa heads that could disrupt the backbone model. In Stage 1, only the Medusa heads are trained to specialize without affecting the main model. In Stage 2, both the Medusa heads and the backbone are trained together, with a warm-up strategy to gradually increase the learning rate for smoother adaptation. The tree attention mechanism also helps organize token continuations in a tree, allowing the model to evaluate multiple possibilities at once, speeding up inference.
Further enhancements of Medusa
The author further enhances Medusa's practical utility with three significant extensions:
1. Typical Acceptance Scheme: Instead of rejecting candidates based on strict probability thresholds, Medusa evaluates them based on how typical they are compared to the original model’s distribution. This speeds up the process without sacrificing quality.
2. Self-Distillation: Medusa can generate training data from its own output. It creates a seed dataset and then uses that data to train the Medusa heads, which helps improve the model’s efficiency in making predictions.
3. Optimized Tree Structure: Medusa improves how the model evaluates candidates by focusing on the most promising tokens, making the inference process faster and more efficient.
Benefits and Limitations
Medusa has shown great results in experiments with models like Vicuna-7B and Vicuna-13B, achieving up to 2.8 times faster performance than traditional methods, with even higher speedups for tasks like text extraction (3.62x) and coding (3.29x). It consistently outperformed speculative decoding, reaching 2.83x acceleration with Vicuna-7B, compared to 1.47x with speculative decoding. Despite the speed improvements, Medusa maintains high text quality, making it efficient without compromising accuracy. The tree-based attention mechanism further boosted speed by 1.9x. However, Medusa has some limitations, such as higher memory usage due to additional heads and the tree attention mechanism, and it can be challenging to scale for larger models and complex tasks.
Group 9 Presentation: Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality
Presented by:
- Kaiyue Ma
- Wenzhe Wang
Paper Citation
T. Dao and A. Gu, “Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality,” 2024, arXiv. doi: 10.48550/ARXIV.2405.21060.
Background
- Transformers are effective, but computationally expensive and suffer from quadratic complexity
- Structured state space models (SSMs) are an alternative that scales linearly instead and works for long range modeling
- SSMs have not received the same mainstream improvements as transformers, and lack comparable support for parallelization and hardware acceleration
- Structured state space duality (SSD) bridges the gaps between transformers and SSMs
Additional Background
SSMs (State Space Models) are traditionally used in control theory to model a dynamic system through state variables. A later paper (https://compneuro.uwaterloo.ca/files/publications/voelker.2018.pdf) found that SSMs are also well suited to describing time cells in the brain. A useful diagram of an SSM can be found at https://cdn-uploads.huggingface.co/production/uploads/613b0a62a14099d5afed7830/G7icfkYoxIqHZcJGHM7UD.png, where x denotes the n state variables, u the m inputs, and y the p outputs.
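For reference, a minimal discrete-time SSM of the kind shown in that diagram can be simulated as follows (sizes are arbitrary assumptions for illustration):

```python
import numpy as np

# Toy discrete-time state space model:
#   x_{t+1} = A x_t + B u_t,   y_t = C x_t + D u_t
# with n state variables, m inputs, and p outputs.

rng = np.random.default_rng(0)
n, m, p, T = 8, 2, 3, 50
A = 0.9 * np.eye(n) + 0.01 * rng.normal(size=(n, n))   # state transition
B = rng.normal(size=(n, m))                            # input matrix
C = rng.normal(size=(p, n))                            # output matrix
D = np.zeros((p, m))                                   # feedthrough (often zero)

x = np.zeros(n)
u = rng.normal(size=(T, m))
ys = []
for t in range(T):
    ys.append(C @ x + D @ u[t])
    x = A @ x + B @ u[t]
print(np.stack(ys).shape)   # (50, 3): the output sequence
```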
Technical Contributions
- Represents SSMs as semiseparable matrices and uses semiseparable matrices for efficient matrix operations.
- Uses generalized linear attention mechanism with structured masks.
- Refines the original Mamba model to yield Mamba-2, which incorporates the new structured-state-space-duality algorithms. Mamba-2 is easily parallelizable and scales better to large state sizes. Empirical results demonstrate strong performance on language modelling benchmarks, surpassing older SSM models and matching or outperforming Transformers at various model scales.
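The duality can be illustrated with a toy example (our own sketch, using a scalar per-step decay as a stand-in for the structured state matrices): the same layer computed as a masked quadratic "attention" with a 1-semiseparable decay mask and as a linear-time recurrence produces identical outputs.

```python
import numpy as np

# Toy check of the duality: scalar-decay linear attention computed two ways --
#   (a) quadratic form: (Q K^T) scores masked by a 1-semiseparable decay matrix L,
#   (b) recurrent form: a fixed-size state updated linearly in the sequence length.

rng = np.random.default_rng(0)
T, d = 16, 8
Q, K, V = rng.normal(size=(3, T, d))
a = rng.uniform(0.8, 1.0, size=T)             # per-step scalar decay (stand-in for A_t)

# (a) Quadratic form with semiseparable mask L[t, s] = a_{s+1} * ... * a_t for s <= t.
L = np.zeros((T, T))
for t in range(T):
    for s in range(t + 1):
        L[t, s] = np.prod(a[s + 1:t + 1])
y_quadratic = ((Q @ K.T) * L) @ V

# (b) Recurrent form: h_t = a_t * h_{t-1} + k_t v_t^T,  y_t = h_t^T q_t.
h = np.zeros((d, d))
y_recurrent = np.zeros((T, d))
for t in range(T):
    h = a[t] * h + np.outer(K[t], V[t])
    y_recurrent[t] = h.T @ Q[t]

print(np.allclose(y_quadratic, y_recurrent))   # True
```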
Summaries of Key Points
The paper explores the theoretical connections between Transformer architectures and Structured State Space Models (SSMs). The authors introduce the State Space Duality (SSD) framework, which bridges these two model families through the concept of structured semiseparable matrices. This framework reveals that certain attention mechanisms in Transformers can be interpreted as SSMs, providing a unified perspective on sequence modelling techniques.
Leveraging the SSD framework, the authors propose a new architecture called Mamba-2. This model refines the selective SSM approach used in the original Mamba model, resulting in a design that is 2-8 times faster while maintaining competitive performance in language modelling tasks. Mamba-2 achieves this efficiency by simplifying the SSM layer, enabling better scalability and computational speed.
The paper also introduces efficient algorithms based on block decompositions of semiseparable matrices, which enhance the computational efficiency of SSMs. These algorithms allow for larger recurrent state sizes and improve the practicality of SSMs in handling long-range dependencies within sequences.
Empirical evaluations demonstrate that Mamba-2 outperforms both its predecessor and Transformer models in terms of training efficiency and performance on language modelling benchmarks. The architecture also shows superior capabilities in associative recall tasks, highlighting its effectiveness in capturing and utilizing long-range dependencies.
In summary, the paper provides a theoretical foundation connecting Transformers and SSMs, introduces the Mamba-2 architecture as a practical application of this theory, and presents algorithms that enhance the efficiency and scalability of sequence modelling techniques.
Constructive Critique
While the paper introduces a powerful theoretical framework through State Space Duality (SSD) and demonstrates strong empirical performance with Mamba-2, several areas could be further clarified or improved:
- Theoretical accessibility: The concept of semiseparable matrices and duality between attention and SSMs is mathematically rich, but not easily accessible to a broader audience. Including more visual or intuitive explanations would improve its pedagogical impact.
- Benchmark diversity: Most experiments focus on language modeling tasks. It remains unclear how Mamba-2 performs on other domains such as vision, speech, or reinforcement learning. Since SSD is a general framework, cross-domain validation would help showcase its broader applicability.
- Scalability limitations: While Mamba-2 is more efficient than its predecessor, the paper doesn’t fully discuss how performance scales with increasing model depth or state size, especially under training constraints on real-world hardware.
- Lack of interpretability analysis: The paper does not explore how the SSD framework or Mamba-2 influences model interpretability (e.g., how information is propagated or stored over long sequences), which could be important for downstream applications.
Despite these limitations, the paper makes a substantial theoretical and practical contribution by unifying two dominant modeling paradigms and offering a concrete architecture (Mamba-2) that is both efficient and performant.
Group 10 Presentation: Accelerating Large Language Model Decoding with Speculative Sampling
Paper Citation
C. Chen, S. Borgeaud, G. Irving, J.-B. Lespiau, L. Sifre, and J. Jumper, ‘Accelerating Large Language Model Decoding with Speculative Sampling’, Feb. 02, 2023, arXiv: arXiv:2302.01318. doi: 10.48550/arXiv.2302.01318.
https://arxiv.org/abs/2302.01318
Background
- Traditional autoregressive decoding for large language model is computationally expensive, as the entire model has to run for each additional token which is generated
- Transformer neural networks are typically memory-bandwidth limited, and using quantization or distillation to make smaller models are solutions which have been used in the past to improve the performance of LLMs
Technical Contributions
- Speculative sampling was developed to increase the speed of LLM decoding without meaningfully reducing its performance on predicting future tokens compared to the base model
- Generates a draft of k future tokens using a smaller model
- Score the proposed tokens using the base model
- A modified rejection sampling scheme was developed by the authors in the paper
- The acceptance of a draft token is based on the minimum of 1 and the ratio of the target model's probability to the draft model's probability for that token.
Explanations to Aid Understanding
- Transformer Decoding: This is the process by which a large language model generates text. Given an input prompt, the model sequentially selects the most probable next token, then uses that token to predict the subsequent one, and so on, and this process is computationally intensive.
- Speculative Sampling: Unlike traditional decoding where one token is generated at a time by the target model, speculative sampling aims to generate multiple tokens in parallel by using a faster, potentially smaller "draft model" to propose a short sequence (the draft). The target model evaluates these drafted tokens, and a rejection sampling mechanism decides which ones to accept, ensuring the output remains consistent with the target model's capabilities.
- Parallel Scoring: Instead of computing the logits for each drafted token sequentially, the method computes the logits for all (K+1) tokens in the draft at the same time. The presentation notes that the computing time for this parallel process is similar to sampling a single token with the target model, which is a key factor in the potential speedup. The key insight is that the model inference pass is dominated by memory bandwidth and inter-device communication rather than purely by the token count. By handling several drafted tokens per pass, overall decoding time is greatly reduced.
Summaries of Key Points
- Decoding Challenges in LLMs: Traditional autoregressive sampling methods generate one token at a time, leading to inefficiencies, especially as model sizes grow
- Speculative Sampling Overview: SpS utilizes a draft model, which is a smaller and faster version of the target LLM, to propose a sequence of tokens. The target model then evaluates this proposed sequence, accepting or rejecting tokens based on a modified rejection sampling scheme that ensures the final output aligns with the target model's distribution.
- Algorithm Efficiency: The draft model generates a short sequence (e.g., K tokens) in parallel. The target model scores this sequence, and tokens are accepted or resampled as needed, allowing for the potential acceptance of up to K+1 tokens per iteration. This parallel approach contrasts with traditional methods that generate tokens sequentially, thereby reducing decoding latency.
- Empirical Results: Implementing SpS with Chinchilla, a 70-billion-parameter language model, resulted in a 2 to 2.5 times speedup in decoding across benchmark tasks such as XSum and HumanEval. These speed improvements were achieved without degrading sample quality or requiring changes to the target model's parameters or architecture.
- Advantages of Speculative Sampling: Maintains the integrity of the target model's output distribution. Requires no alterations to existing model architectures, facilitating easy integration into current systems. Demonstrates versatility across various tasks and decoding methods, making it broadly applicable in LLM deployments.
Group 10 Presentation: Accelerating Large Language Model Decoding with Speculative Sampling
Presented by:
Danyang Zhao
Speculative Sampling: An In-Depth Explanation
Speculative Sampling is a technique introduced to speed up text generation for large language models (LLMs) without significantly affecting output quality. It works by leveraging a smaller, faster “draft” model to propose multiple candidate tokens, which are then verified and accepted or rejected by the larger model. This method reduces latency during decoding while maintaining output quality close to the original model.
But what, specifically, is latency? Latency refers to the time delay between a user requesting text generation and the model producing the output. More formally, it is the time required to generate each token in an autoregressive model.
But why do we need speculative sampling? Because there are problems with standard autoregressive sampling, mainly:
- LLMs generate text token-by-token. Each step depends on the previous tokens, making sequential decoding slow.
- The computational bottleneck comes from generating the next token, which requires a full forward pass of the model.
Therefore the goal of speculative sampling is to reduce the number of forward passes of the large language model per generated token.
How Speculative Sampling Works:
1. There are two models, a Draft model and a Large Model. The smaller draft model, which is cheap and fast to compute, generates a set of k speculative candidate tokens. Then, the large model, which is expensive yet accurate, verifies these tokens and either accepts or rejects them.
2. The draft model proposes multiple tokens. That is, instead of sampling just one token at each step, the draft model generates a sequence of k candidate tokens [math]\displaystyle{ x_{t+1}, x_{t+2}, ..., x_{t+k} }[/math] using standard autoregressive decoding.
3. Now, the large language model verifies the proposal. This is done by calculating the probability of each proposed token:
[math]\displaystyle{ p(x_{t+1} | x_{\le t}), p(x_{t+2} | x_{\le t+1}), ..., p(x_{t+k} | x_{\le t+k-1}) }[/math]
For each candidate there are only two possibilities: (1) the token is accepted and added to the output, or (2) the token is rejected, in which case the large model takes over and directly generates a replacement token.
4. Since the draft model is smaller, it generates multiple speculative tokens quickly, and the large model then needs only a few forward passes to verify them, making the process more efficient. This allows for faster decoding through efficient verification.
Now, let us go over a bit more of the mathematics.
For each token [math]\displaystyle{ x_i }[/math], we check:
[math]\displaystyle{ p(x_i | x_{\lt i}) \ge q(x_i | x_{\lt i}) }[/math]
where p is the probability distribution of the large (target) model and q that of the draft model. If the inequality holds, the token is accepted outright; otherwise it is accepted only with the probability given below.
This is paired with a Metropolis-style acceptance probability, which ensures the final sampled tokens remain statistically valid while speeding up computation.
The acceptance probability can then be calculated as follows:
[math]\displaystyle{ \alpha_i = \min \Big( 1, \frac{p(x_i | x_{\lt i})} {q(x_i | x_{\lt i})} \Big) }[/math]
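The acceptance rule can be sketched end-to-end on a toy vocabulary (our own illustration; the draft and target distributions are given directly instead of coming from real models), which also shows why the final samples still follow the target distribution p:

```python
import numpy as np

# Toy sketch over a 3-token vocabulary; p and q stand in for the target / draft models.

rng = np.random.default_rng(0)

def accept_or_resample(x, p, q):
    """x: drafted token id; p, q: target / draft distributions over the vocabulary."""
    if rng.uniform() < min(1.0, p[x] / q[x]):
        return x                                   # accept the draft token
    residual = np.maximum(p - q, 0.0)              # modified rejection: resample from
    residual /= residual.sum()                     # the renormalized residual max(0, p - q)
    return rng.choice(len(p), p=residual)

p = np.array([0.5, 0.3, 0.2])                      # target model's next-token distribution
q = np.array([0.2, 0.5, 0.3])                      # draft model's next-token distribution

draws = [accept_or_resample(rng.choice(3, p=q), p, q) for _ in range(100_000)]
print(np.bincount(draws) / len(draws))             # empirically close to p = [0.5, 0.3, 0.2]
```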
Group 11 Presentation: Simple Linear Attention Language Models Balance the Recall-Throughput Tradeoff
Presented by:
Yiyuan Yang, Anar Kuatzhan, Chuan Zhang
Paper Citation
Arora, S., Eyuboglu, S., Zhang, M., Timalsina, A., Alberti, S., Zinsley, D., ... & Ré, C. (2024). Simple linear attention language models balance the recall-throughput tradeoff. arXiv preprint arXiv:2402.18668. https://arxiv.org/pdf/2402.18668
Background
Large language models still struggle with efficiency. In attention-based models, attention requires a huge number of calculations as the input gets longer, and it stores every previous token in memory, which makes it memory-intensive. New language models developed in recent years can generate text faster while maintaining low perplexity, but low perplexity doesn't necessarily mean good recall: gated convolutional models, for example, struggle with recall, whereas attention-based models excel at recall tasks. To address this tradeoff, the authors introduce the Based model.
Technical Contributions
The authors introduce a new architecture called BASED (Balanced Attention through Sliding + Exponential Decay), which is designed to address the recall–throughput tradeoff in language models. It combines two key components:
- Linear Attention (Global Context):
- Uses a second-order Taylor approximation of softmax attention.
- Enables constant-size memory state during autoregressive decoding.
- Sliding Window Attention (Local Context):
- Performs exact attention within a small local window (e.g., 64–128 tokens).
- Captures short-range dependencies with high precision.
Memory-Recall Tradeoff: Observed both within and across architecture classes.
Performance with Fixed Recurrent State: Not all architectures have the same recall capacity. Mamba optimally utilizes limited memory budgets while convolutional architectures underperform with memory constraints.
The Based model combines local fine-grained (sliding-window) attention with long-range linear attention based on a Taylor approximation of the softmax exponential; the combination is sub-quadratic during training and admits an efficient recurrent view at inference. Based outperforms prior sub-quadratic architectures in recall quality by up to 6.2 accuracy points.
Architecture
Softmax-approximating linear attention (applied globally) + exact softmax attention with sliding windows (applied locally)
This combination achieves 90.8% of full softmax attention's recall accuracy while reducing latency by a factor of 100,000.
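A small sketch of the second-order Taylor idea behind the linear-attention component (our own toy, ignoring normalization and the sliding-window part): the feature map [math]\displaystyle{ \phi(x) = [1, x, \operatorname{vec}(xx^\top)/\sqrt{2}] }[/math] makes [math]\displaystyle{ \phi(q)^\top\phi(k) }[/math] equal to the second-order approximation of [math]\displaystyle{ e^{q^\top k} }[/math], which is what allows attention to be computed with a constant-size running state.

```python
import numpy as np

# phi(q) . phi(k) = 1 + q.k + (q.k)^2 / 2, a second-order approximation of exp(q.k).
def phi(x):
    return np.concatenate([[1.0], x, np.outer(x, x).ravel() / np.sqrt(2.0)])

rng = np.random.default_rng(0)
q, k = rng.normal(scale=0.3, size=(2, 4))

approx = phi(q) @ phi(k)
print(approx, 1 + q @ k + (q @ k) ** 2 / 2, np.exp(q @ k))   # first two are equal; both near exp

# Because the score factorizes as phi(q) . phi(k), attention outputs can be computed
# from a running sum of phi(k_t) v_t^T, i.e. a constant-size state and linear time.
```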
Accomplishments
- Improved Recall:
- BASED outperforms Mamba by up to 10.36 accuracy points on recall-intensive tasks.
- Recovers over 90% of softmax attention’s recall performance while using significantly less memory.
- High Throughput:
- Achieves up to 24× higher generation throughput compared to FlashAttention-2.
- Competitive wall-clock time due to efficient CUDA kernel design.
- Strong Language Modeling:
- Matches or surpasses models like Mamba and Hyena in perplexity and downstream task accuracy.
- Theoretical Contributions:
- Demonstrates that recurrent models require [math]\displaystyle{ \Omega(N) }[/math]-bit memory to perform associative recall.
- Proves that BaseConv, a gated convolution model, cannot solve recall tasks in constant layers.
- Shows that the recall-throughput tradeoff is theoretically fundamental.
Group 12 Presentation: EAGLE-2: Faster Inference of Language Models with Dynamic Draft Trees
Presenters
Mutong Zhang, Hanqing Bi
Paper Citation
Y. Li, F. Wei, C. Zhang, and H. Zhang, ‘EAGLE-2: Faster Inference of Language Models with Dynamic Draft Trees’, Jun. 30, 2024, arXiv: arXiv:2406.16858. doi: 10.48550/arXiv.2406.16858.
https://arxiv.org/abs/2406.16858
Background
- LLMs to date have shown great performance, but they are slow and computationally expensive
- Speculative sampling - small model generates candidate tokens, whereas large model then evaluates those tokens, reducing the number of times the expensive computations of the large model have to occur
- EAGLE 1 used a draft tree to improve performance of speculative execution, and this builds on the author's prior work
- The authors observe that acceptance rates are also context-dependent, suggesting the need for adaptive drafting.
Paper Contributions
- EAGLE 1 doesn't directly predict tokens; instead it maps tokens to features, predicts the next feature, and then predicts tokens from those features
- This work was shown to improve upon previous speculative-execution methods
- EAGLE uses a tree structure to propose alternative tokens when a speculative draft token is rejected by the full model
- EAGLE 2 notes that token acceptance depends on context as well as position: the first draft token has a high acceptance rate, while later tokens have lower acceptance rates
- EAGLE 2 therefore uses a dynamic draft tree with "tree attention", incorporating context information when selecting the next candidate tokens; this increases their acceptance rate, since acceptance depends on context and not only on position
Summaries of Key Points
Addressing the Challenge of Slow LLM Inference
Large language models (LLMs) have revolutionized natural language processing, but their inference remains computationally expensive and slow. This bottleneck arises due to the vast number of model parameters and the sequential nature of token generation. Improving inference efficiency without sacrificing accuracy is a critical research direction in the field.
The Foundation: Speculative Sampling and Draft Models
EAGLE-2 builds on speculative sampling, a technique that leverages a smaller, lightweight draft model to propose candidate tokens, which are then verified by the full LLM. This approach speeds up inference by reducing the number of computations performed by the large model.
Evolution from EAGLE-1 to EAGLE-2
The original EAGLE-1 introduced a static draft tree, assuming that token acceptance rates depend solely on their position within the tree structure. However, this assumption overlooks the context-dependent nature of token acceptance. EAGLE-2 addresses this limitation by introducing a dynamic draft tree adjustment mechanism, which refines speculative sampling based on real-time confidence scores.
Key Mechanisms of EAGLE-2
1. Instead of directly predicting tokens, EAGLE-2’s draft model generates feature representations, which are then processed by the head of the LLM to produce token predictions. This method enhances accuracy while maintaining efficiency.
2. EAGLE-2 employs a tree-based verification mechanism, which is computationally more efficient than the standard chain-structured verification used in traditional speculative sampling. By verifying multiple candidate tokens in parallel, it accelerates the overall process.
3. A core innovation in EAGLE-2 is its ability to dynamically modify the draft tree based on confidence scores from the draft model. This ensures that speculative sampling adapts to varying contexts, improving token acceptance rates and overall inference speed.
Dynamic Expansion and Re-Ranking
1. Expansion Phase: EAGLE-2 introduces a novel expansion phase, where a tree-attention mechanism processes all tokens in a layer simultaneously. This significantly enhances efficiency compared to sequential processing. Additionally, selective expansion prioritizes only the top-k tokens with the highest estimated global acceptance probabilities, preventing unnecessary computational overhead.
2. Re-Ranking Phase: Following expansion, EAGLE-2 reranks candidate tokens by selecting the top-m tokens with the highest acceptance probabilities. In cases where multiple tokens have similar scores, shallower nodes in the draft tree are prioritized, further optimizing verification speed.
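The following toy sketch (our own simplification, with made-up confidence values) illustrates the value-guided growth described above: a node's estimated global acceptance is the product of draft confidences along its path, expansion keeps the top-k nodes of the current layer, and re-ranking keeps the top-m nodes overall, preferring shallower nodes on ties.

```python
import heapq

# Assumed per-node confidences from the draft model (purely illustrative values);
# keys are token paths through the draft tree.
draft_confidence = {
    "the": 0.9, "the cat": 0.8, "the dog": 0.5,
    "the cat sat": 0.7, "the cat ran": 0.2, "the dog barked": 0.6,
}

def path_value(path):
    """Estimated global acceptance: product of confidences along the path."""
    words = path.split()
    if not words:
        return 1.0
    return path_value(" ".join(words[:-1])) * draft_confidence[path]

# Expansion: keep only the top-k nodes of the current (deepest) layer.
deepest_layer = ["the cat sat", "the cat ran", "the dog barked"]
expanded = heapq.nlargest(2, deepest_layer, key=path_value)

# Re-ranking: keep the top-m nodes across the whole tree, preferring shallower nodes on ties.
top_m = sorted(draft_confidence, key=lambda p: (-path_value(p), len(p.split())))[:4]
print(expanded)   # ['the cat sat', 'the dog barked']
print(top_m)      # ['the', 'the cat', 'the cat sat', 'the dog']
```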
Experimental Results: Significant Performance Gains
EAGLE-2 achieved acceleration ratios of 3.05x - 4.26x across various tasks and large language model series such as Vicuna, LLaMA2-Chat, and LLaMA3-Instruct, making it 20% - 40% faster than EAGLE-1. It is also approximately 2 times faster than Medusa and 2.3 times faster than Lookahead. On token throughput, EAGLE-2 processes 4-5.5 tokens per verification cycle, about twice as many as traditional speculative sampling.
Key Advantages of EAGLE-2
1. Plug-and-Play Efficiency – EAGLE-2 requires no additional model training, as it seamlessly integrates the pre-trained draft model from EAGLE-1. It does not alter the original LLM, and maintains the exact same output distribution as greedy decoding (i.e., it is lossless).
2. Robust and Reliable – Unlike some acceleration techniques, EAGLE-2 does not modify the original model parameters or relax acceptance conditions, ensuring stable and consistent outputs.
3. Broad Generalization – The framework generalizes well across different tasks and architectures, demonstrating strong adaptability in diverse applications.
EAGLE-2 represents a significant advancement in accelerating LLM inference. By introducing a dynamic draft tree, efficient expansion strategies, and intelligent token re-ranking, it substantially reduces computational costs while maintaining accuracy. As large-scale models continue to grow, techniques like EAGLE-2 will be instrumental in making LLMs more practical and accessible for real-world applications.
Accomplishments
- State-of-the-art inference speedup:
- EAGLE-2 achieves up to 5× speedup in language model inference.
- Outperforms prior speculative decoding methods including EAGLE, Medusa, Lookahead, and others.
- Longest average acceptance length:
- EAGLE-2 generates longer sequences per accepted draft, reducing the number of calls to the target model.
- Wide applicability:
- Tested across 6 diverse tasks, including:
- Conversation
- Code generation
- Mathematical reasoning
- Question answering
- Summarization
- Instruction following
- Model-agnostic design:
- Compatible with popular LLMs such as:
- Vicuna
- LLaMA2-Chat
- LLaMA3-Instruct
Group 13 Presentation: Linear Attention Mechanism: An Efficient Attention for Semantic Segmentation
Presented By
Yuke Liu, Mei Si
Paper Citation
R. Li, J. Su, C. Duan, and S. Zheng, ‘Linear Attention Mechanism: An Efficient Attention for Semantic Segmentation’, Aug. 20, 2020, arXiv: arXiv:2007.14902. doi: 10.48550/arXiv.2007.14902.
https://arxiv.org/abs/2007.14902
Background
- Existing transformer models have [math]\displaystyle{ \mathcal{O} (n^2) }[/math] complexity in the input length, which becomes problematic as inputs grow
- This limits how far input and model sizes can be scaled under computational resource constraints
- This paper focused on an alternative method to conventional dot product attention that is more computationally efficient
- Standard attention requires computing [math]\displaystyle{ Q K^\top }[/math], which has [math]\displaystyle{ \mathcal{O} (n^2) }[/math] complexity
- The paper proposes a linear attention mechanism that solves the problem while keeping the performance.
Technical Contributions
Rather than doing the full computation for the softmax in the transformer architecture, the authors instead compute
[math]\displaystyle{ D(\mathbf{Q}, \mathbf{K}, \mathbf{V})_i = \frac{\sum_{j=1}^{N} e^{\mathbf{q}_i^{T} \mathbf{k}_j} \mathbf{v}_j}{\sum_{j=1}^{N} e^{\mathbf{q}_i^{T} \mathbf{k}_j}} = \frac{\sum_{j=1}^{N} \text{sim}(\mathbf{q}_i, \mathbf{k}_j) \mathbf{v}_j}{\sum_{j=1}^{N} \text{sim}(\mathbf{q}_i, \mathbf{k}_j)} }[/math]
and define the transformation function as
[math]\displaystyle{ \text{sim}(\mathbf{q}_i, \mathbf{k}_j) = \phi(\mathbf{q}_i)^{T} \phi(\mathbf{k}_j) }[/math]
The authors apply a first-order Taylor series expansion, and after some rearranging and substitution arrive at their final formula (full derivation not shown)
[math]\displaystyle{ D(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \frac{ \sum_j \mathbf{v}_{i,j} + \left( \frac{\mathbf{Q}}{\lVert \mathbf{Q} \rVert_2} \right) \left( \left( \frac{\mathbf{K}}{\lVert \mathbf{K} \rVert_2} \right)^{T} \mathbf{V} \right) }{ N + \left( \frac{\mathbf{Q}}{\lVert \mathbf{Q} \rVert_2} \right) \sum_j \left( \frac{\mathbf{K}}{\lVert \mathbf{K} \rVert_2} \right)_{i,j}^{T} } }[/math]
The authors' form of the attention mechanism can be computed in [math]\displaystyle{ \mathcal{O} (n) }[/math] complexity, removing the quadratic scaling of conventional transformers and enabling larger models built on the underlying attention mechanism.
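A short sketch of the computation implied by the formula above (our own illustration, assuming row-wise [math]\displaystyle{ L_2 }[/math] normalization of queries and keys and the first-order similarity [math]\displaystyle{ 1 + \hat{q}_i^\top \hat{k}_j }[/math]): the linear-time form shares two global sums across all queries and matches the explicit [math]\displaystyle{ N \times N }[/math] computation.

```python
import numpy as np

rng = np.random.default_rng(0)
N, dk, dv = 512, 32, 64
Q, K = rng.normal(size=(2, N, dk))
V = rng.normal(size=(N, dv))

Qn = Q / np.linalg.norm(Q, axis=1, keepdims=True)    # row-wise L2 normalization
Kn = K / np.linalg.norm(K, axis=1, keepdims=True)

# Quadratic reference: sim(q_i, k_j) = 1 + normalized dot product, then a weighted average of V.
sim = 1.0 + Qn @ Kn.T                                 # (N, N) matrix -- the expensive part
reference = (sim @ V) / sim.sum(axis=1, keepdims=True)

# Linear form: the two global sums below are shared by every query, so no N x N matrix is built.
kv = Kn.T @ V                                         # (dk, dv)
k_sum = Kn.sum(axis=0)                                # (dk,)
numerator = V.sum(axis=0) + Qn @ kv                   # (N, dv)
denominator = N + Qn @ k_sum                          # (N,)
linear_out = numerator / denominator[:, None]

print(np.allclose(linear_out, reference))             # True
```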
- Hybrid Attention Mechanisms for Selective Efficiency
Idea: Combine linear attention with local or sparse attention in a hybrid structure where the system dynamically chooses the best mechanism depending on context (e.g., object density, resolution, or modality). The rationale is that the dense areas of an image (e.g., urban environments) may benefit from sparse or local attention for spatial precision. Less complex regions (e.g., sky, ocean) can rely on linear attention for efficiency. It could be implemented with a gating module that selects or blends attention types.
Model Performance Evaluation
The model performance evaluation primarily focused on assessing how well the proposed linear attention mechanism enhances semantic segmentation performance while maintaining computational efficiency.
- Dataset: The experiments were conducted using the Fine Gaofen Image Dataset (GID), which consists of high-resolution satellite images. This dataset was selected due to its complex landscapes and varied environmental conditions, which present significant challenges.
- Data Preprocessing: Due to the large size of the images in the dataset, each image was divided into smaller patches of 256x256 pixels, resulting in a total of 7,280 patches.
- Data Split: These patches were partitioned into 60% for training, 20% for validation, and 20% for testing. This split ensured a rigorous evaluation of the model's performance across different scenarios.
- The standard dot-product attention is replaced by the linear attention mechanism in the baseline segmentation models such as U-Net, Res101, DeepLab, PSPNet, etc
- Implementation Details: The experiments were implemented using PyTorch framework and trained on an NVIDIA RTX 2080Ti GPU.
- Evaluation Metrics: OA, AA, K, mIoU, F1.
Comparative Analysis
The proposed linear attention mechanism achieves a similar level of accuracy in semantic segmentation tasks compared to traditional dot product attention.
Linear attention maintains comparable performance metrics while drastically reducing both memory footprint and processing time required.
The efficiency gains of linear attention become more apparent in large-scale segmentation tasks, making it suitable for high-resolution images and long text sequences.
Future work includes extending the linear attention mechanism to more complex network designs, integrating with state-of-the-art deep learning models, and optimizing for real-time processing in demanding applications.
Method | OA | AA | K | mIoU | F1 |
---|---|---|---|---|---|
U-Net | 86.378 | 74.532 | 83.357 | 64.516 | 75.532 |
U-Net LAM | 87.692 | 77.297 | 84.935 | 68.038 | 78.593 |
Res101 | 89.251 | 80.451 | 86.846 | 72.433 | 82.510 |
Res101 LAM | 90.178 | 82.757 | 88.041 | 74.085 | 83.105 |
RefineNet | 89.857 | 81.169 | 87.597 | 73.167 | 83.113 |
RefineNet LAM | 90.214 | 83.544 | 88.083 | 74.973 | 84.311 |
DeepLab | 89.388 | 80.905 | 87.079 | 71.809 | 81.077 |
DeepLab LAM | 89.576 | 81.692 | 87.287 | 72.827 | 82.702 |
DeepLabV3+ | 90.125 | 81.483 | 87.959 | 72.668 | 81.492 |
DeepLabV3+ LAM | 90.315 | 81.695 | 88.182 | 73.727 | 82.736 |
PSPNet | 90.573 | 82.211 | 88.485 | 74.797 | 83.761 |
PSPNet LAM | 90.725 | 83.088 | 88.677 | 75.695 | 84.480 |
FastFCN | 90.336 | 83.625 | 88.221 | 74.364 | 83.704 |
FastFCN LAM | 90.835 | 83.075 | 88.769 | 75.174 | 84.023 |
Group 13 Presentation: Linear Attention Mechanism: An Efficient Attention for Semantic Segmentation
Presented By
Yuke Liu, Mei Si
Paper Citation
R. Li, J. Su, C. Duan, and S. Zheng, ‘Linear Attention Mechanism: An Efficient Attention for Semantic Segmentation’, Aug. 20, 2020, arXiv: arXiv:2007.14902. doi: 10.48550/arXiv.2007.14902.
https://arxiv.org/abs/2007.14902
Introduction
(1) Attention Mechanism
The attention mechanism gained prominence for its capacity to refine feature maps and to capture long-range dependencies in both computer vision and natural language processing. By focusing computational resources on the most relevant portions of the input data, attention mechanisms have proven especially valuable in tasks that demand context from large input sizes—such as high-resolution images or lengthy text sequences.
(2) Dot Product Attention
A widely used form of attention is the dot product attention. It forms the core of many state-of-the-art models, including Transformers in NLP and non-local modules in computer vision. Despite its strengths, dot product attention scales poorly (quadratically) with input size in both memory and computational cost.
(3) Problem Statements
Because the dot product attention mechanism requires [math]\displaystyle{ O(N^2) }[/math] operations for inputs of length [math]\displaystyle{ N }[/math] , deploying it on large inputs—e.g., high-resolution images or long sequences—becomes infeasible. This computational bottleneck has motivated research into more efficient variants of attention, particularly approaches that reduce the quadratic cost.
Overview of Attention Mechanismn
(1) Scaling Attention
In contrast to dot product attention, “scaling” attention mechanisms (sometimes also called “channel” or “spatial” attention in the literature) aim to emphasize informative features while reducing redundant information. Examples include: Squeeze-and-Excitation (SE) modules, which learn channel-wise scaling factors. Convolutional Block Attention Module (CBAM), which extends SE by considering both channel and spatial attention. These approaches do not necessarily learn long-range dependencies across positions in the same way as dot product attention; instead, they focus on enhancing or diminishing particular channels or spatial locations.
(2) Linear Attention: Taylor Expansion → Linear Time
Recent research efforts address the [math]\displaystyle{ O(N^2) }[/math] complexity of dot product attention by re-formulating the similarity function (the key step in attention). One family of methods uses kernel-based approximations, while another employs mathematical expansions. By approximating the exponential term of the softmax function (commonly used in dot product attention) with a first-order Taylor expansion, it is possible to make attention computations linear with respect to [math]\displaystyle{ N }[/math]. This insight forms the basis for linear attention mechanisms.
(3) Semantic Segmentation
Semantic segmentation requires dense predictions over every pixel of an image, and capturing global context is crucial. Traditional convolution-based networks, whether they follow a DilatedFCN (e.g., DeepLab, PSPNet) or an Encoder-Decoder (e.g., U-Net) architecture, benefit significantly from attention modules that can encode relationships across spatial locations. However, if the resolution is high, dot product attention quickly becomes prohibitively expensive, motivating more efficient variants like linear attention.
Dot Product Attention Details
(1) Query, Key, and Value
Dot product attention transforms each input feature [math]\displaystyle{ x_i }[/math] into three different representations:
Query ([math]\displaystyle{ q_i }[/math]), Key ([math]\displaystyle{ k_i }[/math]), and Value ([math]\displaystyle{ v_i }[/math])
The core operation then measures similarity between every query–key pair to weight the contribution of all values in forming an output.
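To make the quadratic cost concrete, here is a minimal NumPy sketch of standard softmax dot-product attention; the shapes and sequence length are illustrative choices, not values from the paper. The explicit [math]\displaystyle{ N \times N }[/math] score matrix is what drives the [math]\displaystyle{ O(N^2) }[/math] memory and compute.

```python
import numpy as np

def dot_product_attention(Q, K, V):
    """Q, K, V: arrays of shape (N, d). Returns (N, d) outputs.
    The N x N score matrix is what makes memory and compute scale as O(N^2)."""
    d = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d)                               # (N, N), quadratic in N
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)               # row-wise softmax
    return weights @ V                                           # (N, d)

N, d = 1024, 64
rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((N, d)) for _ in range(3))
out = dot_product_attention(Q, K, V)
print(out.shape)  # (1024, 64)
```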
(2) Kernel-Based Approximation
A known strategy to reduce complexity is to view the exponential term in dot product attention as operating in a kernel space, thereby factorizing the softmax attention to achieve lower complexity. Various works replace the explicit softmax with kernel-based transformations, enabling a more efficient computation.
Linear Attention Mechanism
Linear attention leverages a Taylor expansion to approximate the exponential term. By applying additional constraints (such as [math]\displaystyle{ L_2 }[/math] normalization to keep the term non-negative), the resulting formulation scales as [math]\displaystyle{ O(N) }[/math] rather than [math]\displaystyle{ O(N^2) }[/math]. This makes the mechanism practical for high-resolution inputs or very long sequences, significantly broadening the usability of attention in semantic segmentation and beyond.
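As a rough illustration of how the first-order Taylor approximation removes the quadratic term, the sketch below implements the similarity [math]\displaystyle{ 1 + \left(\tfrac{q}{\lVert q \rVert_2}\right)\left(\tfrac{k}{\lVert k \rVert_2}\right)^\top }[/math] in matrix form; this is a simplified reading of the mechanism described above, not the authors' exact implementation.

```python
import numpy as np

def linear_attention(Q, K, V):
    """First-order Taylor / L2-normalized similarity: sim(q, k) = 1 + q_n . k_n.
    Rearranging the sums gives O(N * d^2) cost; no N x N matrix is formed."""
    N = Q.shape[0]
    Qn = Q / np.linalg.norm(Q, axis=-1, keepdims=True)
    Kn = K / np.linalg.norm(K, axis=-1, keepdims=True)
    kv = Kn.T @ V                                   # (d, d): aggregated key-value summary
    numerator = V.sum(axis=0) + Qn @ kv             # (N, d)
    denominator = N + Qn @ Kn.sum(axis=0)           # (N,)
    return numerator / denominator[:, None]

N, d = 4096, 64
rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((N, d)) for _ in range(3))
out = linear_attention(Q, K, V)
print(out.shape)  # (4096, 64)
```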
Experimental Settings
(1) Dataset
Many experiments on linear attention for segmentation are conducted using large, high-resolution datasets. One common benchmark highlighted is a satellite imagery dataset (e.g., Fine Gaofen Image Dataset, GID). These datasets typically comprise large aerial images that are partitioned into patches, with splits for training, validation, and testing.
(2) Model Implementations
Baseline segmentation networks (e.g., PSPNet, DeepLab series, U-Net variants, FastFCN, and RefineNet) integrate the proposed linear attention modules in place of, or alongside, standard attention mechanisms. The training setup often employs standard optimizers (such as Adam), cross-entropy loss, and hardware accelerators (like NVIDIA GPUs).
(3) Evaluation Metrics
Common segmentation metrics include: Overall Accuracy (OA), Average Accuracy (AA), Kappa Coefficient (K), mean Intersection over Union (mIoU), and F1 Score.
Results and Experimental Improvement
In benchmark experiments, linear attention typically boosts the performance of various baseline segmentation models while reducing memory and computational overhead. For example, when embedded into U-Net, PSPNet, or DeepLab, linear attention can achieve a higher mIoU and Kappa coefficient compared to the original dot product attention. These gains confirm that the approximation introduced by the Taylor expansion still captures sufficient global context for accurate segmentation.
Comparative Analysis
Limitations:
Approximation Error: The first-order Taylor expansion introduces approximation errors relative to exact dot product attention. Although experiments show minimal performance degradation, in certain tasks or extreme input scales, further refinements or higher-order terms might be necessary.
Architecture Constraints: Integrating linear attention can require modifications to the existing network design, including normalization steps and careful parameter initialization.
Conclusion and Future Work
Linear attention mechanisms substantially reduce both memory usage and computational cost in high-resolution vision tasks, making them a promising alternative to dot product attention. Future research may involve:
Extending linear attention to multi-modal domains where extremely large inputs (e.g., video or multi-spectral data) are common.
Investigating higher-order approximations that may yield even more accurate results while retaining near-linear scalability.
Combining linear attention with other efficient modules (e.g., lightweight convolutions or quantization techniques) to further push the boundaries of real-time segmentation on resource-constrained devices.
Group 14 Presentation: Scalable Watermarking for Identifying Large Language Model Outputs
Presented by:
Ryan Tymkow and Benjamin Schnapp
Paper Citation
Dathathri, S., See, A., Ghaisas, S. et al. Scalable watermarking for identifying large language model outputs. Nature 634, 818–823 (2024). https://doi.org/10.1038/s41586-024-08025-4.
Summaries of key points
This paper tackles the problem of watermarking—embedding a detectable signal into the outputs of large language models (LLMs) to distinguish generated text from human-written content. The challenge is doing this in a way that is robust, scalable, and minimally intrusive to the model's output quality. The authors propose a method based on statistical biasing of token selection during generation. Specifically, they partition the vocabulary into “greenlist” and “redlist” tokens at each step based on a seeded hash function, and then bias sampling toward the greenlist using a tunable parameter. The watermark is invisible in individual outputs but detectable over longer texts using hypothesis testing. Importantly, the approach is model-agnostic, doesn’t require retraining, and adds very little computational overhead. It also scales well to large models and can be applied in high-throughput or real-time settings. Overall, it’s a lightweight yet effective strategy for watermarking that balances detectability, scalability, and usability.
A key limitation of this approach is that it may still be vulnerable to paraphrasing or text transformation—simple rewriting could break the statistical signature. Another concern is adversarial robustness: since the watermarking method is relatively transparent (based on vocabulary partitioning), a knowledgeable attacker could design strategies to erase or spoof the signal. Additionally, while the method maintains fluency and quality in most cases, biasing token selection could subtly affect stylistic or semantic nuances, especially in creative writing or long-form tasks. The paper doesn’t deeply explore how this might influence user-facing applications like chatbots or summarizers. Lastly, while the watermark is statistically detectable, it’s not embedded in a cryptographic sense, so it may not offer strong guarantees in high-stakes verification contexts.
Clear explanations to aid understanding
Imagine if every time a language model generated text, it subtly preferred certain words over others—but in a way that's invisible to readers. That’s what this watermark does. At each generation step, the model hashes its current context to choose a “greenlist” of preferred tokens and then slightly boosts their probabilities. Over many words, these choices form a statistically detectable pattern. It's like nudging a roulette wheel so certain numbers come up just a bit more often—not enough to be obvious, but enough to spot if you know where to look. The method is efficient and easy to integrate, since it works at the sampling level and doesn't require modifying the model architecture.
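The roulette-wheel analogy can be made concrete with a toy sketch of greenlist biasing and detection. Everything below (vocabulary size, hash scheme, bias strength DELTA, and the z-score test) is an illustrative assumption rather than the paper's exact procedure.

```python
import hashlib
import numpy as np

VOCAB_SIZE = 50_000
GREEN_FRACTION = 0.5
DELTA = 2.0  # bias added to greenlist logits (assumed value)

def greenlist(prev_token: int, key: int = 42) -> np.ndarray:
    """Seed a PRNG from the previous token (the 'context hash') and
    pick a pseudorandom half of the vocabulary as the greenlist."""
    seed = int(hashlib.sha256(f"{key}-{prev_token}".encode()).hexdigest(), 16) % 2**32
    rng = np.random.default_rng(seed)
    return rng.random(VOCAB_SIZE) < GREEN_FRACTION        # boolean mask

def watermarked_sample(logits: np.ndarray, prev_token: int) -> int:
    mask = greenlist(prev_token)
    biased = logits + DELTA * mask                        # nudge greenlist tokens up
    p = np.exp(biased - biased.max()); p /= p.sum()
    return int(np.random.default_rng().choice(VOCAB_SIZE, p=p))

def detect(tokens: list[int]) -> float:
    """z-score of the greenlist hit rate; large values indicate a watermark."""
    hits = sum(greenlist(prev)[tok] for prev, tok in zip(tokens, tokens[1:]))
    n = len(tokens) - 1
    return (hits - GREEN_FRACTION * n) / np.sqrt(n * GREEN_FRACTION * (1 - GREEN_FRACTION))

logits = np.zeros(VOCAB_SIZE)       # stand-in for real model logits
tokens = [101]
for _ in range(200):
    tokens.append(watermarked_sample(logits, tokens[-1]))
print(detect(tokens))               # large z-score => likely watermarked
```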
Group 14 Presentation: Scalable Watermarking for Identifying Large Language Model Outputs
Presented by:
Ryan Tymkow, Benjamin Schnapp
Paper Citation
Dathathri, S., See, A., Ghaisas, S. et al. Scalable watermarking for identifying large language model outputs. Nature 634, 818–823 (2024). https://doi.org/10.1038/s41586-024-08025-4
Background
With the rise of LLMs and generative AI, there is a risk of spreading AI-generated misinformation as if it were written by a human, so it is important to distinguish AI-generated text from human writing. Existing solutions address this problem, but they all have limitations involving privacy and computational costs. For example, traditional watermarking can introduce unwanted artifacts into the text.
Previous Approaches
Some of the existing approaches are Retrieval Based Tracking, Post-Hoc Detection, and Traditional Watermarking.
Retrieval Based Tracking stores generated LLM responses in a reference database to determine whether a newly generated response is from the tracked LLM. The issue with this is scalability and privacy since you are storing generated responses.
Post-Hoc Detection uses statistical features of LLM-generated text. The issues with this are the high computational cost and the fact that, as LLM outputs become more human-like, detection becomes increasingly difficult.
Traditional watermarking embeds hidden features in the text, for example by replacing synonyms or inserting special Unicode characters, but tinkering with the output degrades the quality of the original LLM.
Technical Contributions
This paper introduced SynthID-Text, a watermarking method for large language models (LLMs). The method uses a Tournament sampling approach, which ensures that generated text contains a detectable watermark with minimal computational overhead. SynthID-Text incorporates a random seed generator and scoring functions to embed a watermark into the model's output. This technique enhances the ability to identify whether text originates from a specific LLM, while preserving the text's quality and minimizing distortion. SynthID-Text does not affect LLM training, and it can be configured as distortionary or non-distortionary.
Central to SynthID-Text is the novel Tournament sampling procedure. Rather than sampling each token directly from the LLM's distribution, multiple candidate tokens compete in a multi-layer "tournament" based on pseudorandom watermarking scores, embedding a statistical “signature” that can later be detected.
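For intuition only, the following is a heavily simplified, single-layer toy version of the tournament idea: candidates drawn from the model's distribution compete on a pseudorandom watermark score and the winner is emitted. The scoring function, number of candidates, and bracket structure here are assumptions; SynthID-Text's actual multi-layer tournament differs in detail.

```python
import numpy as np

def g_score(token: int, context: tuple[int, ...], key: int = 1234) -> float:
    """Pseudorandom watermarking score in [0, 1), seeded by key + context + token."""
    rng = np.random.default_rng(hash((key, context, token)) % 2**32)
    return rng.random()

def tournament_sample(probs: np.ndarray, context: tuple[int, ...], n_candidates: int = 4) -> int:
    """Sample several candidates from the LLM distribution, then keep the one
    with the highest watermark score via a single-elimination bracket."""
    rng = np.random.default_rng()
    candidates = rng.choice(len(probs), size=n_candidates, p=probs, replace=True)
    while len(candidates) > 1:
        winners = []
        for a, b in zip(candidates[0::2], candidates[1::2]):
            winners.append(a if g_score(a, context) >= g_score(b, context) else b)
        candidates = np.array(winners)
    return int(candidates[0])

probs = np.full(100, 1 / 100)                     # stand-in for real next-token probabilities
next_token = tournament_sample(probs, context=(5, 17, 42))
print(next_token)
```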
Results
SynthID-Text achieved a 95% true positive detection rate with 1% false positives in a 20-million-interaction test on Google's Gemini chatbot. It offers high detection accuracy with minimal computational cost and can be configured for non-distortionary or distortionary watermarking.
The benefits are as follows:
Minimal impact on large language model training:
- SynthID-Text can be applied to any large language model with minimal modifications, as it only alters the sampling step.
High detection accuracy with low computational cost:
-Outperforms retrieval-based tracking, post hoc detection, and traditional watermarking methods.
-Offers the best balance between computational cost, scalability, and accuracy.
-Can be integrated into production environments using speculative sampling, where smaller models suggest tokens and the main model verifies their validity.
Configurable distortion levels:
-Allows for distortionary or non-distortionary configurations, enabling better control over the quality of generated text versus detection accuracy.
-In non-distortionary watermarking, the average token distribution of the generated text matches the original model's token distribution.
Group 15 Presentation: DiGress: Discrete denoising diffusion for graph generation
Presented by:
Sean Tang, Buji Wong
Paper Citation
Vignac, C., Krawczuk, I., Siraudin, A., Wang, B., Cevher, V., & Frossard, P. (2022). Digress: Discrete denoising diffusion for graph generation. arXiv preprint arXiv:2209.14734.
Background
Graph generation
The goal of this project is to generate graphs, which are represented by node matrices and edge matrices. Edges and nodes can also have their own categories. One application of this is molecule generation: atoms would be represented by nodes and the chemical bonds would be represented by edges.
Graph generation is a complex task due to the unordered nature and sparsity of graphs. While denoising diffusion models have been successful in other domains like images, they struggle with graphs due to their structural properties. Existing approaches that use continuous noise models disrupt the sparsity and connectivity crucial for graphs.
Discrete diffusion
Regular diffusion methods involve a forward process (where noise is gradually added to the data) and a neural network that is trained to predict the backwards process (where random noise is gradually turned back into something that would plausibly fit in with the dataset). Most of the time, diffusion models use Gaussian noise.
This graph generation problem involves discrete data, therefore, a discrete noise model is preferred for the diffusion process. (Note: The authors actually previously tried to use Gaussian noise, but the continuous nature of the Gaussian function meant that the discreteness of the graphs was not respected.) In this discrete case, however, the diffusion is defined in terms of transition matrices, [math]\displaystyle{ Q^t }[/math]:
[math]\displaystyle{ q(x_t | x_{t-1}) = x_{t-1} Q^t }[/math]
Each transition matrix is applied to each node and edge; this allows the nodes and edges to remain in discrete space.
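A minimal sketch of this forward step is given below: each node's one-hot category is pushed through a transition matrix and resampled. The uniform-transition matrix, noise level, and graph size are illustrative assumptions, not the paper's exact settings.

```python
import numpy as np

def uniform_transition(num_classes: int, beta: float) -> np.ndarray:
    """With probability (1 - beta) keep the category, otherwise jump to a uniform category."""
    return (1 - beta) * np.eye(num_classes) + beta / num_classes * np.ones((num_classes, num_classes))

def noise_step(X_onehot: np.ndarray, Q: np.ndarray, rng) -> np.ndarray:
    """X_onehot: (n, K) one-hot node categories. Returns noisier one-hot categories."""
    probs = X_onehot @ Q                                   # each row is a categorical distribution
    new_cats = np.array([rng.choice(len(p), p=p / p.sum()) for p in probs])
    return np.eye(X_onehot.shape[1])[new_cats]

rng = np.random.default_rng(0)
X = np.eye(4)[rng.integers(0, 4, size=10)]                 # 10 nodes, 4 node types
Q = uniform_transition(4, beta=0.1)
X_noisy = noise_step(X, Q, rng)                            # categories stay discrete
```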
Technical Contributions
Overview of DiGress
The authors introduce DiGress, a discrete denoising diffusion model designed specifically for graph generation with categorical node and edge attributes. DiGress improves graph generation by using a discrete noise model that preserves graph sparsity and structural properties. The model involves progressively editing graphs through edge addition/removal and attribute changes. A graph transformer network is used to reverse this noisy process using cross-entropy loss, sampling from the trained model by iteratively updating the noise level and computing structural features.
Key enhancements include a noise model that maintains node and edge type distributions, a guidance procedure for conditioning on graph-level features, and the use of auxiliary graph-theoretic features. DiGress achieves state-of-the-art performance on both molecular and non-molecular graph datasets.
Denoising network architecture
- The input is a noisy graph, expressed as [math]\displaystyle{ X }[/math] (a tensor of one-hot encoded nodes) and [math]\displaystyle{ E }[/math] (a tensor of one-hot encoded edges).
- The structural and spectral features of these graphs are calculated.
- The noisy graph along with the structural and spectral features are taken as input into a multi-layer perceptron (MLP) structure.
- Next is a sequence of graph transformer layers.
- After another MLP layer, the output is a distribution over nodes and edges.
Summary of results
Results showed DiGress outperformed continuous diffusion methods on various metrics, including degree distribution, clustering, and novelty, and was more scalable for larger graphs. Moreover, for generating novel molecules, discrete diffusion aids scalability to larger graphs and molecules, making it more efficient than continuous diffusion. DiGress is the first one-shot graph-based model that feasibly trains on over a million molecules without fragment-based heuristics. Its performance on drug-like molecule benchmarks reaches or exceeds that of specialized autoregressive or SMILES-based baselines.
Accomplishments
- State-of-the-art one-shot graph generation:
- DiGress outperforms existing non-autoregressive models such as GDSS, GraphNVP, and SPECTRE.
- Achieves strong performance on benchmarks including Stochastic Block Models (SBM) and planar graphs.
- Scalable molecular graph generation:
- DiGress is the first diffusion-based model to scale to large molecular datasets such as:
- MOSES (small, drug-like molecules)
- GuacaMol (1.3M large molecules)
- Matches or exceeds autoregressive and fragment-based baselines on metrics such as validity, uniqueness, novelty, and scaffold diversity.
- Fast and efficient performance on QM9:
- Achieves near-perfect:
- Validity: 99%
- Uniqueness: ~96%
- Trains significantly faster than prior diffusion models (1 hour vs. 7.2 hours for ConGress).
- Effective conditional generation:
- Uses regressor-guided sampling to generate molecules with desired properties (e.g., HOMO energy, dipole moment).
- Outperforms unconditional models in property control and accuracy.
- Theoretical soundness:
- Proven permutation equivariance and exchangeability for correct graph generation.
- Provides an ELBO-based training objective for likelihood estimation and model comparison.
Group 15 Presentation: DiGress: Discrete denoising diffusion for graph generation
Presented by:
Sean Tang, Buji Wong
Paper Citation
Vignac, C., Krawczuk, I., Siraudin, A., Wang, B., Cevher, V., & Frossard, P. (2022). Digress: Discrete denoising diffusion for graph generation. arXiv preprint arXiv:2209.14734.
Summaries of key points
Motivation
Recall that in class we already covered diffusion models: they work by adding noise to data gradually (forward process) and training a model to reverse the process, restoring the clean data (reverse process). Traditional diffusion models are designed for continuous data, but graphs are generally discrete. This discrete structure makes it difficult to directly apply traditional diffusion models to graphs.
DiGress tackles this issue by introducing a method that applies diffusion to discrete graphs through the concept of discrete diffusion.
Forward Diffusion Process: How to Add Discrete Noise?
As mentioned above, DiGress uses discrete noise. In DiGress, noise is added by changing the categories of nodes and edges (such as atom types or bond types). This means the model randomly selects new categories for nodes and edges from a predefined set of valid categories. To guide this noise addition, DiGress uses a transition matrix. The transition matrix defines the probability of each node and edge transitioning between categories at each step. This ensures that while the graph becomes noisier over time, it still maintains its structure and validity.
Reverse Diffusion Process: How to Denoise a Graph?
In the reverse process, DiGress gradually removes the noise from the graph. Instead of randomly undoing the noise, the model uses a graph transformer network, designed for graph-structured data. This network helps the model recognize the structure of the graph and the relationships between nodes and edges. During each step, the model focuses on the most relevant parts of the graph, predicting the correct categories for nodes and edges. And the model’s predictions are guided by cross-entropy loss (applied to both nodes and edges), which measures how accurately the model predicts the node and edge categories after denoising. By minimizing this loss, the model becomes better at removing the noise, step by step, until it recovers a valid and meaningful graph.
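As a rough illustration of the training signal described here, the sketch below computes a cross-entropy loss over predicted node types and edge types; the edge-loss weight lambda_E and all shapes are assumptions for illustration, not values taken from the paper.

```python
import torch
import torch.nn.functional as F

def graph_denoising_loss(node_logits, edge_logits, node_targets, edge_targets, lambda_E=5.0):
    """node_logits: (n, Kx), edge_logits: (n, n, Ke); targets are integer class indices.
    Cross-entropy is applied to both node and edge predictions, with an assumed edge weight."""
    node_loss = F.cross_entropy(node_logits, node_targets)
    edge_loss = F.cross_entropy(edge_logits.reshape(-1, edge_logits.shape[-1]),
                                edge_targets.reshape(-1))
    return node_loss + lambda_E * edge_loss

n, Kx, Ke = 12, 4, 3
node_logits = torch.randn(n, Kx)
edge_logits = torch.randn(n, n, Ke)
node_targets = torch.randint(0, Kx, (n,))
edge_targets = torch.randint(0, Ke, (n, n))
loss = graph_denoising_loss(node_logits, edge_logits, node_targets, edge_targets)
print(loss.item())
```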
Conditioning Graph Generation on Desired Properties
One of the most powerful features of DiGress is its ability to condition the graph generation process on specific properties. For example, if you want the generated graph to have a certain number of oxygen atoms or satisfy some chemical property, DiGress can adjust the generation process to meet those requirements. This is done by checking the graph at each step of the sampling process and modifying it as needed to match the desired properties. This capability is particularly useful in areas like drug discovery, where the generated molecules must meet certain chemical and structural criteria.
Experimental Results
DiGress was tested on several datasets to evaluate its performance:
1. Graph Generation: On datasets like the Stochastic Block Model (SBM) and planar graphs, DiGress outperformed other methods, particularly in generating novel graphs that were not seen during training.
2. Molecule Generation: When applied to the MOSES dataset, DiGress produced more valid and novel molecules, even though it did not always surpass methods that check graph validity at every step.
3. Scalability: On larger graphs, such as those from the GuacaMol dataset, DiGress demonstrated strong scalability, making it a suitable option for generating larger and more complex graphs.
Group 16 Presentation: Machine Learning and Hamilton-Jacobi-Bellman Equation for Optimal Decumulation: a Comparison Study
Presented by:
Zeyu Zhang
Paper Citation
Chen M, Shirazi M, Forsyth PA, Li Y. Machine Learning and Hamilton-Jacobi-Bellman Equation for Optimal Decumulation: a Comparison Study. Published online 2023. doi:10.48550/arxiv.2306.10582
Background
The paper is based on computational finance, focusing on the optimization problem related to "defined benefit" and "defined contribution plans". The main focus is on the challenge of ensuring retirees have enough funds for their retirement. Two key plans were discussed:
"Defined benefit plans" guarantee fixed monthly payments based on factors like tenure and salary but are cost-prohibitive and risky.
"Contribution plans" shift the investment and withdrawal strategy burden to individual investors, but they struggle to balance maximizing withdrawals and minimizing risk.
This problem, often called "the nastiest, hardest problem in finance," highlights the complexity of balancing risk and reward in financial planning for retirement.
The 4% rule is a traditional method recommending a constant 4% withdrawal each year, adjusted for inflation, and investing in stocks and bonds.
Despite its popularity, the 4% rule is not globally optimal.
Peter Forsyth proposed the HJB PDE method to maximize expected withdrawals and minimize the risk of running out of savings.
The HJB PDE method uses scalarization techniques to achieve Pareto optimal points, but it has limitations.
Technical Contributions
1. Hamilton-Jacobi-Bellman (HJB):
- The problem formulation involves complex mathematical equations related to computational finance.
- The problem uses dynamic programming to break down the optimal control problem, leading to the HJB function that represents the value function.
- The paper assumes stock and bond prices follow a jump diffusion model.
- The investor's total wealth at time [math]\displaystyle{ t }[/math] is defined as the sum of the stock and bond holdings at that time.
- The investment horizon [math]\displaystyle{ T }[/math] is set to 30 years, and rebalancing times are defined with discrete withdrawal amounts and allocations between stocks and bonds.
2. Neural Network (NN): Control and Objective Function:
- The control at time [math]\displaystyle{ T_i }[/math] includes the withdrawal amount [math]\displaystyle{ Q_i }[/math] and allocation for the wealth at time [math]\displaystyle{ T_i^- }[/math].
- The admissible control set is defined, and the expected shortfall is introduced as a measure of risk.
- The expected total withdrawal is used as a measure of reward, aiming to maximize the expected total withdrawal while minimizing the expected shortfall.
- The pre-commitment in the expected shortfall problem is defined, focusing on maximizing the expected total withdrawal and minimizing the expected shortfall.
Neural Network (NN) Formulation
As an alternative to the HJB framework, the authors propose a Neural Network (NN) approach to solve the stochastic control problem. This framework has several advantages:
1. The NN approach is data-driven, meaning it avoids explicitly defining parametric models for stochastic processes. This provides flexibility and allows integration of auxiliary variables if needed.
2. It circumvents the computation of high-dimensional conditional expectations by solving a single, unconstrained optimization problem for control decisions. This avoids the curse of dimensionality often associated with dynamic programming.
3. If the optimal control is continuous in time and state, the NN reflects this property. If the control is discontinuous, the NN yields a smooth approximation, which is beneficial in practice for implementing investment policies.
4. The method is scalable, making it suitable for long horizons and high-frequency rebalancing without significantly increasing computational complexity.
The NN framework’s effectiveness lies in its ability to learn control policies directly from simulated state paths without requiring explicit knowledge of the underlying stochastic differential equations.
Instead of solving high-dimensional HJB PDEs, the neural network uses:
- Forward simulation to sample wealth trajectories under candidate policies.
- Backward evaluation to update network parameters based on performance (e.g., maximizing expected withdrawals, minimizing expected shortfall).
This model-free, data-driven method avoids dynamic programming and is especially useful in high dimensions, where solving PDEs becomes computationally infeasible.
Moreover, by designing appropriate activation functions (e.g., softmax for portfolio weights and sigmoid for withdrawal rates), the NN ensures that stochastic constraints are naturally respected throughout training and inference.
NN Approximation Setup
- The control policy [math]\displaystyle{ \mathcal{P} }[/math] is approximated using two feed-forward neural networks, with parameters [math]\displaystyle{ \boldsymbol{\theta}_q }[/math] and [math]\displaystyle{ \boldsymbol{\theta}_p }[/math], representing withdrawal and allocation strategies respectively.
- These networks take as inputs the Brownian motion path [math]\displaystyle{ W(t_i) }[/math] and time [math]\displaystyle{ t_i }[/math] to approximate control decisions:
[math]\displaystyle{ \hat{q}(W_i^-, t_i^-, \boldsymbol{\theta}_q) \approx q_i(W_i^-), \quad \hat{p}(W_i^+, t_i^+, \boldsymbol{\theta}_p) \approx p_i(W_i^+) }[/math]
- The final control policy is given by: [math]\displaystyle{ \hat{\mathcal{P}} = \{ (\hat{q}(\cdot), \hat{p}(\cdot)) \} \approx \mathcal{P} }[/math]
- The functions [math]\displaystyle{ \hat{p} }[/math] and [math]\displaystyle{ \hat{q} }[/math] use time as one of the inputs, allowing a single NN to handle decisions across all rebalancing points, rather than training separate models for each time step.
- The paper also discusses how the architecture includes activation functions that enforce stochastic constraints naturally.
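To make the setup above concrete, here is a minimal PyTorch sketch of the two policy networks: a withdrawal network whose sigmoid output is mapped into a bounded withdrawal range, and an allocation network whose softmax output gives non-negative asset weights summing to one. The hidden sizes, withdrawal bounds, and the choice of (wealth, time) inputs are illustrative assumptions, not the paper's exact architecture.

```python
import torch
import torch.nn as nn

class WithdrawalNet(nn.Module):
    """Sigmoid output rescaled into an assumed withdrawal range [q_min, q_max]."""
    def __init__(self, q_min=35.0, q_max=60.0, hidden=32):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(2, hidden), nn.Tanh(), nn.Linear(hidden, 1))
        self.q_min, self.q_max = q_min, q_max

    def forward(self, wealth, t):
        x = torch.stack([wealth, t], dim=-1)
        s = torch.sigmoid(self.net(x)).squeeze(-1)            # in (0, 1)
        return self.q_min + (self.q_max - self.q_min) * s     # bounded withdrawal

class AllocationNet(nn.Module):
    """Softmax output gives long-only portfolio weights that sum to 1."""
    def __init__(self, n_assets=2, hidden=32):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(2, hidden), nn.Tanh(), nn.Linear(hidden, n_assets))

    def forward(self, wealth, t):
        x = torch.stack([wealth, t], dim=-1)
        return torch.softmax(self.net(x), dim=-1)

wealth = torch.tensor([500.0, 800.0])    # illustrative wealth levels
t = torch.tensor([0.0, 10.0])            # time inputs let one network cover all rebalancing dates
q = WithdrawalNet()(wealth, t)
p = AllocationNet()(wealth, t)
```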
Summaries of Key Notes
- Neural Network Framework for Pension Decumulation: A novel framework using neural networks (NNs) is proposed to optimize asset allocation and cash withdrawal strategies for defined contribution (DC) pension plans. Unlike traditional methods, it solves constraints efficiently via unconstrained optimization.
- Comparison with HJB Method: The NN approach achieves comparable accuracy to the Hamilton-Jacobi-Bellman (HJB) PDE method while being scalable to higher-dimensional problems and avoiding dynamic programming errors.
- Efficient Withdrawals: The NN framework closely approximates optimal "bang-bang" controls, effectively alternating between minimum and maximum withdrawals based on wealth, ensuring reliable pension decumulation.
- Robustness: Tested extensively on synthetic and historical market data, the NN solution adapts well and demonstrates strong out-of-sample and out-of-distribution performance.
Constructive Critique
While the NN method replicates HJB-derived control policies with high accuracy, a few limitations and caveats exist:
- The training relies on synthetic data simulated from assumed models (e.g., geometric Brownian motion, jump diffusion), which may limit generalizability under real-world dynamics.
- The scalarization of the reward-risk tradeoff assumes a linear weighting of Expected Withdrawals and Expected Shortfall. In practice, retiree preferences might reflect more complex, utility-based behaviors.
- The interpretability of the learned policy is less transparent compared to explicit, closed-form control strategies derived via HJB.
- No formal convergence analysis or approximation bounds are provided for the NN solution.
Despite these challenges, the method is empirically robust and scalable, making it an appealing alternative for large-scale or real-time applications.
Related Works
The work is also related to "Spending Retirement on Planet Vulcan: The Impact of Longevity Risk Aversion on Optimal Withdrawal Rates", which introduces utility-based models for withdrawal strategies under longevity risk aversion. The models focus on retirees' behavioral preferences — particularly longevity risk aversion and the need for smooth consumption throughout retirement.
While not PDE-based, it contextualizes the trade-offs between consumption smoothing and risk of depletion, complementing the (EW, ES) approach by addressing behavioral and utility-driven objectives.
On the other hand, the NN method in this paper takes a risk-sensitive stochastic control approach, optimizing a scalarized objective that balances:
Expected Withdrawals (EW) = proxy for consumption
Expected Shortfall (ES) = proxy for downside risk / depletion probability
Group 18 Presentation: HiGen: Hierarchical Graph Generative Networks
Presented by:
- Shiyu Zhu
- Jesse Xue
Paper Citation
M. Karami, “HiGen: Hierarchical Graph Generative Networks,” 2023, arXiv. doi: 10.48550/ARXIV.2305.19337.
Background
- Hierarchical or Multi-Scale Structure: Captures high level interactions or relationships between objects or groups, while also representing the lower level structures. One example is a company org chart.
- Existing graph generating models include: Variational Autoencoders, Generative Adversarial Networks, Autoregressive Models (GNN, Graph RNN, GRAN) and Diffusion Models
- The paper introduces HiGen (Hierarchical Graph Generative Networks) to address problems with existing models.
- Experiments were conducted on 5 datasets, each with increasing size and scale, with GraphRNN, GRAN, DiGress, GDSS, and SPECTRE used as baselines.
Technical Contributions
- Related Hierarchical Methods: The presentation discusses several recent hierarchical methods in specific domains like chemistry, highlighting HiGen's broader applicability, multi-level approach, and parallel generation as advantages over these more specialized techniques. These include a multi-based generation for molecular graphs (2020) relying on domain knowledge, a hierarchical normalizing flow model (2021) based on local neighborhoods, and a tree decomposition framework (2022) limited to a single abstraction level and medium-sized graphs.
Definition: Hierarchical Graph
A Hierarchical Graph is a multi-level representation of a graph [math]\displaystyle{ \mathcal{G} = (\mathcal{V}, \mathcal{E}) }[/math] where:
- [math]\displaystyle{ \mathcal{V} }[/math] is the set of nodes (vertices), and [math]\displaystyle{ \mathcal{E} }[/math] is the set of edges, with sizes [math]\displaystyle{ n = |\mathcal{V}| }[/math] and [math]\displaystyle{ m = |\mathcal{E}| }[/math].
- A node partition function [math]\displaystyle{ \mathcal{F}: \mathcal{V} \rightarrow \{1, ..., c\} }[/math] groups nodes into [math]\displaystyle{ c }[/math] communities or clusters.
- Each cluster forms a subgraph [math]\displaystyle{ \mathcal{C}_i = (\mathcal{V}(\mathcal{C}_i), \mathcal{E}(\mathcal{C}_i)) }[/math] with adjacency matrix [math]\displaystyle{ A_i }[/math].
- Cross-links between communities form bipartite graphs [math]\displaystyle{ \mathcal{B}_{ij} = (\mathcal{V}(\mathcal{C}_i), \mathcal{V}(\mathcal{C}_j), \mathcal{E}(\mathcal{B}_{ij})) }[/math].
- Each cluster is aggregated into a super-node and each bipartite into a super-edge at the next higher level. This forms a coarser graph at the parent level.
Methodology
HiGen (Hierarchical Graph Generative Networks) adopts a coarse-to-fine approach for generating graphs with complex hierarchical structures. The method consists of the following components:
1. Hierarchical Graph Construction
The input graph is recursively partitioned into communities using the Louvain clustering algorithm, forming a multi-level hierarchy. Each cluster (community) is abstracted as a "super-node" at the next level, and cross-community connections become "super-edges."
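The coarsening step can be sketched with off-the-shelf tools: below, one level of the hierarchy is built by running Louvain clustering (via networkx, assumed here purely for illustration) and collapsing each community into a super-node and cross-community edges into weighted super-edges.

```python
# Requires networkx >= 2.8 for louvain_communities.
import networkx as nx
from networkx.algorithms.community import louvain_communities

def coarsen(G: nx.Graph, seed: int = 0):
    """One level of coarsening: communities become super-nodes, cross-community
    edges are aggregated into weighted super-edges of the parent graph."""
    communities = louvain_communities(G, seed=seed)            # list of node sets
    node_to_comm = {v: i for i, comm in enumerate(communities) for v in comm}
    parent = nx.Graph()
    parent.add_nodes_from(range(len(communities)))
    for u, v in G.edges():
        cu, cv = node_to_comm[u], node_to_comm[v]
        if cu != cv:                                           # cross-community edge
            w = parent.get_edge_data(cu, cv, {"weight": 0})["weight"]
            parent.add_edge(cu, cv, weight=w + 1)
    return communities, parent

G = nx.karate_club_graph()                                     # toy example graph
communities, parent_graph = coarsen(G)
print(len(communities), parent_graph.number_of_edges())
```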
2. Community and Bipartite Generation
- Community Generation: Each community is generated independently using a GNN-based autoregressive model that factorizes the edge distribution into multinomial components. A mixture of multinomials is used to model intra-community edge structures.
- Bipartite Generation: Once communities are generated, bipartite graphs (cross-community edges) are created using a separate neural network. These inter-cluster edges are predicted with a parallel GNN model using a similar factorization strategy.
3. Autoregressive Probabilistic Modeling
HiGen decomposes the joint probability of the graph into a product of conditional multinomial distributions. Both community and bipartite graphs are generated step-by-step using a recursive generation process guided by node and cluster embeddings.
4. Parallelism and Invariance
The hierarchical structure allows parallel generation of communities and bipartite edges at the same level, improving efficiency. The model is also invariant to node ordering within clusters, which improves robustness and scalability.
This design enables HiGen to generate large and complex graphs while maintaining global structure and fine-grained local details. It supports realistic synthesis across diverse graph types including molecules, biological networks, and social systems.
Summaries of Key Notes
- HiGen: Hierarchical Graph Generative Networks (HiGen) is a novel graph generation model that addresses the limitations of existing methods by leveraging hierarchical structures in graphs. It generates substructures in a coarse-to-fine manner, modeling the interactions between communities and cross-edges at multiple levels of abstraction.
- Community and Bipartite Generation: HiGen utilizes modular generation, creating communities in parallel followed by predictions of cross-edges with separate neural networks, ensuring scalability and efficiency. It employs multinomial distributions with recursive factorization, enabling autoregressive generation of integer-valued edge weights within communities.
- State-of-the-Art Performance: HiGen demonstrates superior graph quality across benchmark datasets compared to competing models like GRAN, GraphRNN, GDSS, and diffusion-based methods. It captures both local and global graph statistics effectively, achieving state-of-the-art metrics for graph degree, clustering coefficients, and Laplacian spectra.
- Scalability: The hierarchical structure of HiGen facilitates parallel graph generation, reduces sensitivity to node ordering, and enables block-wise processing of adjacency matrices, making it adaptable to large and complex graphs. Applications: HiGen excels in generating realistic graphs for applications in molecular modeling, protein structure analysis, data network synthesis, and more, addressing diverse domains where hierarchical graph structures are pivotal.
Results
- Experiments are conducted on 5 datasets (e.g., SBM, Protein, Enzyme, Ego) with increasing size and complexity.
- Evaluation metrics cover local (degree distribution, clustering coefficient distribution, orbit counts) and global (graph spectra) graph statistics.
- Observation: Compared to baselines like GraphRNN, GRAN, and SPECTRE, HiGen consistently achieves state-of-the-art performance, especially in capturing graph motifs and global structure.
Group 20 Presentation: Gated Linear Attention Transformers with Hardware-Efficient Training
Reference
arXiv:2312.06635
Background
The paper discusses Gated Linear Attention (GLA) Transformers, addressing the computational inefficiency of traditional transformers with softmax attention. Regular transformers have quadratic computational complexity in sequence length, which becomes extremely expensive for long sequences, while plain linear attention mechanisms tend to underperform.
Technical Contributions
It proposes using a linear kernel as an alternative to the softmax function, which allows attention to be formulated as a linear RNN with 2D hidden states. The key innovations include:
1. Introducing a data-dependent gating mechanism to improve model performance, which allows the model to forget past information adaptively
2. Developing a linear attention approach that reduces computational complexity
3. Creating a hardware-efficient training method (FLASHLINEARATTENTION Algorithm) that can handle long sequences more effectively
The main goal was to create a more efficient transformer model that can:
- Reduce computational expenses
- Maintain competitive performance across different tasks
- Handle long sequences more effectively
- Leverage modern GPU architectures for improved training and inference
The approach addresses the fundamental challenge of making transformer models more scalable and computationally efficient, particularly for tasks involving long sequences like processing books, dialogues, or complex scientific texts.
Results
The results and conclusions of the paper showed:
Performance Results:
- For the 340 million parameter model:
- Achieved competitive performance, close to Transformer performance
- Slightly better than or comparable to RetNet
- Slightly below Mamba on some tasks
- For the 1.3 billion parameter model:
- Beat most benchmarks in average accuracy
- Slightly behind Transformer++ in perplexity
- Showed impressive accuracy across tasks
Key Findings:
1. Gating mechanism is crucial for model performance
- Removing it significantly increased perplexity
- Data-dependent scalar decay improved results
2. Recall-intensive tasks:
- Smaller model: Transformer still led
- Larger model: GLA closed the performance gap considerably
- Competitive with Mamba and RetNet
3. Computational Efficiency:
- Higher training throughput for larger batch sizes
- Slight increase in GPU memory usage
- More efficient for bigger batches
Conclusions:
- GLA is highly effective for handling long sequences
- Hardware-efficient design reduces computational costs
- The gating mechanism significantly enhances model performance
- A promising approach for making transformers more accessible and efficient
- An efficient replacement for softmax attention in Transformers
The paper suggests future research should focus on optimizing the balance between performance and efficiency.
Linear Attention
Transformers traditionally use softmax attention, which scales poorly with sequence length due to quadratic complexity. Linear attention approximates softmax with kernel-based attention mechanisms, reducing this cost.
Parallel and Recurrent Forms
- Parallel form: computes full attention using
[math]\displaystyle{ \mathbf{O} = \text{softmax}((\mathbf{QK}^\top) \odot \mathbf{M}) \mathbf{V} }[/math]
which enables efficient training with full-sequence inputs.
- Recurrent form: used during inference, processes token-by-token with
[math]\displaystyle{ \mathbf{o}_t = \frac{\sum_{i=1}^{t} \phi(\mathbf{q}_t) \phi(\mathbf{k}_i)^\top \mathbf{v}_i}{\sum_{i=1}^{t} \phi(\mathbf{q}_t) \phi(\mathbf{k}_i)^\top} }[/math]
- Using [math]\displaystyle{ \phi(x) = x }[/math] and removing normalization yields the simplified linear attention update:
[math]\displaystyle{ \mathbf{S}_t = \mathbf{S}_{t-1} + \mathbf{k}_t^\top \mathbf{v}_t, \quad \mathbf{o}_t = \mathbf{q}_t \mathbf{S}_t }[/math]
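The recurrence above is easy to state in code. The sketch below implements the ungated update [math]\displaystyle{ \mathbf{S}_t = \mathbf{S}_{t-1} + \mathbf{k}_t^\top \mathbf{v}_t }[/math], [math]\displaystyle{ \mathbf{o}_t = \mathbf{q}_t \mathbf{S}_t }[/math] and, as a hedged illustration only, a data-dependent forget gate of the kind GLA adds; the gate's parameterization here is a simplification, not the paper's exact form.

```python
import numpy as np

def recurrent_linear_attention(Q, K, V, gates=None):
    """Q, K, V: (L, d). gates: optional (L, d) values in (0, 1) applied to the 2-D state."""
    L, d = Q.shape
    S = np.zeros((d, d))                       # 2-D hidden state
    outputs = np.zeros((L, d))
    for t in range(L):
        if gates is not None:
            S = gates[t][:, None] * S          # forget part of the past, per key dimension
        S = S + np.outer(K[t], V[t])           # k_t^T v_t
        outputs[t] = Q[t] @ S                  # o_t = q_t S_t
    return outputs

rng = np.random.default_rng(0)
L, d = 16, 8
Q, K, V = (rng.standard_normal((L, d)) for _ in range(3))
alpha = 1 / (1 + np.exp(-rng.standard_normal((L, d))))   # sigmoid-valued gates in (0, 1)
out_plain = recurrent_linear_attention(Q, K, V)
out_gated = recurrent_linear_attention(Q, K, V, gates=alpha)
```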
Chunkwise Parallel Linear Attention
The chunkwise parallel form balances between full parallelism and full recurrence, enabling faster training on long sequences.
- Splits input [math]\displaystyle{ \mathbf{X} }[/math] into chunks of length [math]\displaystyle{ C }[/math].
- Inter-chunk state update:
[math]\displaystyle{ \mathbf{S}_{[i+1]} = \mathbf{S}_{[i]} + \sum_{j=iC+1}^{(i+1)C} \mathbf{k}_j^\top \mathbf{v}_j }[/math]
- Intra-chunk output:
[math]\displaystyle{ \mathbf{O}_{[i+1]} = \mathbf{Q}_{[i+1]} \mathbf{S}_{[i]} + \left((\mathbf{Q}_{[i+1]} \mathbf{K}_{[i+1]}^\top) \odot \mathbf{M}\right) \mathbf{V}_{[i+1]} }[/math]
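A minimal NumPy sketch of the chunkwise form follows: the state is updated once per chunk, while a causal mask handles attention within each chunk. Ungated linear attention and the chosen sizes are simplifying assumptions.

```python
import numpy as np

def chunkwise_linear_attention(Q, K, V, C):
    """Q, K, V: (L, d), chunk length C dividing L. Matches the recurrent form exactly."""
    L, d = Q.shape
    assert L % C == 0
    S = np.zeros((d, d))
    mask = np.tril(np.ones((C, C)))            # causal mask M within a chunk
    out = np.zeros((L, d))
    for i in range(L // C):
        sl = slice(i * C, (i + 1) * C)
        Qc, Kc, Vc = Q[sl], K[sl], V[sl]
        inter = Qc @ S                          # contribution of previous chunks
        intra = ((Qc @ Kc.T) * mask) @ Vc       # causal attention within the chunk
        out[sl] = inter + intra
        S = S + Kc.T @ Vc                       # inter-chunk state update
    return out

rng = np.random.default_rng(0)
L, d, C = 32, 8, 8
Q, K, V = (rng.standard_normal((L, d)) for _ in range(3))
out = chunkwise_linear_attention(Q, K, V, C)    # C = 1 recovers the recurrent form
```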
Swish Activation function (SwiGLU)
One notable component of the GLA model’s design is its use of the Swish activation function (and the derived SwiGLU gating unit) in key parts of the network. Swish, defined as [math]\displaystyle{ \text{Swish}(x) = x \cdot \sigma(x) }[/math] (where [math]\displaystyle{ \sigma }[/math] is the sigmoid), is a smooth, non-monotonic activation known to often outperform ReLU/GELU in deep networks. In this paper, Swish is employed in two main places: (1) the feed-forward network (FFN) layers, where the authors adopt the SwiGLU formulation, and (2) the computation of the data-dependent gates in the attention mechanism.
1) FFN
The function's smooth gradient and ability to yield non-zero outputs even for negative inputs help with optimization and expressiveness. In summary, using SwiGLU in each Transformer block’s FFN is an architectural choice that boosts performance per parameter, making the overall model more competitive with standard Transformers.
2)Gating mechanism
Swish is self-gating: when [math]\displaystyle{ x_t W_r }[/math] is a large positive value, Swish outputs a large value (roughly linear in its input), but when [math]\displaystyle{ x_t W_r }[/math] is around zero or negative, Swish outputs a small value (tending toward zero for large negative inputs). This means [math]\displaystyle{ r_t }[/math] will tend to selectively suppress tokens that the model deems less important (yielding near-zero values for those features) while allowing strong signals to pass through (near-linear for large activations). A standard sigmoid gate could also suppress features (outputting values between 0 and 1), but it would saturate at 1 for any sufficiently large input, effectively capping the influence of very important features. Swish, by contrast, does not saturate to a hard limit: for important inputs it can output values greater than 1 (since [math]\displaystyle{ r_t \approx x_t }[/math] if [math]\displaystyle{ x_t W_r }[/math] is large), thereby allowing an amplification effect. This gives the model more flexibility than a sigmoid-gated GLU: small signals are squashed (multiplied by a small fraction), while strong signals are propagated in full or even amplified (almost identity for large positive inputs). This property can be crucial for modeling, for example, rare key tokens that should strongly influence the attention: Swish gating lets those contributions through largely unattenuated, whereas a sigmoid gate might bottleneck at 1.
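For reference, here is a minimal PyTorch sketch of a SwiGLU feed-forward block of the kind described above; the hidden width and the bias-free linear layers are conventional choices assumed for illustration, not details taken from the paper.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLUFFN(nn.Module):
    """Feed-forward block where a Swish-activated gate multiplies a linear branch."""
    def __init__(self, d_model: int, d_hidden: int):
        super().__init__()
        self.w_gate = nn.Linear(d_model, d_hidden, bias=False)
        self.w_up = nn.Linear(d_model, d_hidden, bias=False)
        self.w_down = nn.Linear(d_hidden, d_model, bias=False)

    def forward(self, x):
        gate = F.silu(self.w_gate(x))          # Swish(x) = x * sigmoid(x), a.k.a. SiLU
        return self.w_down(gate * self.w_up(x))

ffn = SwiGLUFFN(d_model=64, d_hidden=256)
y = ffn(torch.randn(2, 10, 64))                # (batch, seq, d_model)
```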
Benefits
- Time complexity: [math]\displaystyle{ \mathcal{O}(LCd + Ld^2) }[/math], which is sub-quadratic.
- [math]\displaystyle{ C = 1 }[/math] recovers the recurrent form; [math]\displaystyle{ C = L }[/math] recovers the parallel form.
- Efficient and scalable to long sequences with minimal performance loss.
Future Work
1. Future hardware-aware optimization: balance between efficiency and performance.
2. Application to other data: the potential of applying GLA to image, video, or scientific data.
3. Test how GLA performs on larger models: due to computational limitations, the experiments were conducted at moderate scale.
Summaries of Key Points
- Gated Linear Attention (GLA) Transformer is a novel architecture that combines the efficiency of linear attention with data-dependent gating mechanisms to improve performance in sequence modeling tasks.
- FLASHLINEARATTENTION is introduced as a hardware-efficient implementation of linear attention, outperforming FLASHATTENTION-2 in speed, even for short sequences (e.g., 1K tokens).
- GLA Transformer enhances length generalization, allowing models trained on 2K sequences to generalize to sequences longer than 20K without significant performance degradation.
- The model is competitive with state-of-the-art architectures, including LLaMA Transformers and linear-time inference models like RetNet and Mamba, particularly in moderate-scale language modeling tasks.
- GLA Transformer achieves higher training throughput compared to similarly sized Mamba models while maintaining efficient long-context processing.
Group 23 Presentation: Discrete Diffusion Modelling By Estimating the Ratios of the Data Distribution
Presented By
Chenxin Lyu, Yixuan Zeng
Paper Citation
A. Lou, C. Meng, and S. Ermon, ‘Discrete Diffusion Modeling by Estimating the Ratios of the Data Distribution’, Jun. 06, 2024, arXiv: arXiv:2310.16834. doi: 10.48550/arXiv.2310.16834.
https://arxiv.org/abs/2310.16834
Background
- Diffusion models have shown great performance in generative artificial intelligence when applied to domains with continuous data
- Diffusion models are more difficult to implement for data in the discrete domain, such as tokenized texts
- Prior attempts at applying diffusion to text generation have performed worse than autoregressive models
Paper Contributions
- Developed a method called Score Entropy Discrete Diffusion (SEDD)
- Parameterizes the diffusion process for discrete data using data distribution ratios, rather than dealing with the tokenized data directly
- SEDD
SEDD is a framework for discrete diffusion modeling that learns to generate data by estimating probability ratios between neighboring discrete states. In the paper, SEDD forms the core modeling strategy that enables diffusion-based generation for discrete data, achieving competitive perplexities with autoregressive baselines.
- Implicit Score Entropy Loss
[math]\displaystyle{ L_{ISE} = \mathbb{E}_{x \sim p} \sum_{y \neq x} \left( w_{xy}\, s_\theta(x)_y - w_{yx} \log s_\theta(y)_x \right) }[/math]
Implicit Score Entropy Loss is a novel training objective designed to learn the ratio function [math]\displaystyle{ s_\theta(x)_y \approx p_t(y)/p_t(x) }[/math] without requiring access to the true data distribution. In the paper, it makes the score entropy computationally more efficient, as it avoids explicit dependence on the [math]\displaystyle{ p(y)/p(x) }[/math] ratio.
It allows one to not compute partition functions or normalize over large vocabularies. The implicit score entropy loss makes the method scalable and practical for high-dimensional or categorical data (like text). It also connects naturally to energy-based modeling where exact densities are intractable but ratios or unnormalized scores can be learned.
It’s worth noting that the score entropy loss also connects to maximum likelihood. The authors show it can be used to derive an evidence lower bound (ELBO) for likelihood training
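As a purely numerical illustration of the objective displayed above, the toy sketch below evaluates [math]\displaystyle{ L_{ISE} }[/math] on a small state space with random transition rates [math]\displaystyle{ w_{xy} }[/math] and model ratios [math]\displaystyle{ s_\theta(x)_y }[/math]; all values are made up and carry no meaning beyond showing how the terms combine.

```python
import numpy as np

rng = np.random.default_rng(0)
N = 4
w = rng.uniform(0.1, 1.0, size=(N, N)); np.fill_diagonal(w, 0.0)   # transition rates w_{xy}
s = rng.uniform(0.5, 2.0, size=(N, N))                              # model ratios s_theta(x)_y
p = np.array([0.4, 0.3, 0.2, 0.1])                                  # data distribution over states

# Expectation over x ~ p of the sum over y != x of (w_xy * s(x)_y - w_yx * log s(y)_x)
L_ISE = 0.0
for x in range(N):
    for y in range(N):
        if y != x:
            L_ISE += p[x] * (w[x, y] * s[x, y] - w[y, x] * np.log(s[y, x]))
print(L_ISE)
```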
Discrete Diffusion Processes
- Models probability distributions over a finite discrete space [math]\displaystyle{ \mathcal{X} = \{1, \ldots, N\} }[/math], using probability mass vectors [math]\displaystyle{ p_t \in \mathbb{R}^N }[/math].
- Evolution of [math]\displaystyle{ p_t }[/math] follows a linear ODE:
[math]\displaystyle{ \frac{dp_t}{dt} = Q_t p_t,\quad p_0 \approx p_{\text{data}} }[/math]
- [math]\displaystyle{ Q_t }[/math] is a diffusion matrix with non-negative off-diagonal entries and column sums equal to 0 (mass is preserved).
- Often simplified as [math]\displaystyle{ Q_t = \sigma(t) Q }[/math], driving [math]\displaystyle{ p_t }[/math] toward a base distribution as [math]\displaystyle{ t \to \infty }[/math].
- Simulated using Euler steps with small [math]\displaystyle{ \Delta t }[/math]. Transition probability:
[math]\displaystyle{ p(x_{t+\Delta t} = y \mid x_t = x) = \delta_{xy} + Q_t(y, x) \Delta t + O(\Delta t^2) }[/math]
- Time reversal: the reverse process uses another matrix [math]\displaystyle{ \overline{Q}_t }[/math] with
[math]\displaystyle{ \overline{Q}_t(y, x) = \frac{p_t(y)}{p_t(x)} Q_t(x, y) }[/math]
and reverse ODE [math]\displaystyle{ \frac{dp_{T-t}}{dt} = \overline{Q}_{T-t} p_{T-t} }[/math].
- This connects to the concrete score, generalizing the score function [math]\displaystyle{ \nabla_x \log p_t }[/math].
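The forward process above can be simulated directly. The sketch below evolves a probability mass vector under [math]\displaystyle{ \frac{dp_t}{dt} = Q_t p_t }[/math] with Euler steps, using a uniform-transition generator and an assumed noise schedule [math]\displaystyle{ \sigma(t) }[/math]; these specific choices are illustrative, not the paper's.

```python
import numpy as np

N = 5                                            # size of the discrete state space
Q = np.ones((N, N)) / N - np.eye(N)              # uniform-transition generator, columns sum to 0
sigma = lambda t: 1.0 + t                        # noise schedule (assumed)

p = np.zeros(N); p[2] = 1.0                      # p_0 concentrated on one state
dt, T = 1e-3, 2.0
for step in range(int(T / dt)):
    t = step * dt
    p = p + dt * sigma(t) * (Q @ p)              # Euler step of the linear ODE
    p = np.clip(p, 0, None); p /= p.sum()        # guard against numerical drift

print(p)   # close to uniform for large T, as expected for this base distribution
```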
Summaries of Key Points
- SEDD (Score Entropy Discrete Diffusion models) is a novel approach to discrete diffusion modeling that bridges the gap between diffusion models and autoregressive language models. It introduces score entropy, a new loss function that extends score matching to discrete spaces, improving performance in discrete generative modeling tasks.
- SEDD significantly improves performance in language modeling, reducing perplexity by 25-75% compared to previous discrete diffusion models and outperforming GPT-2 in certain tasks.
- Key advantages of SEDD over traditional autoregressive models:
- Higher generative quality without requiring distribution annealing techniques such as temperature scaling.
- Computational efficiency, enabling similar output quality with up to 32× fewer network evaluations.
- Enhanced controllability, allowing for flexible text infilling beyond left-to-right prompting, while maintaining quality comparable to nucleus sampling.
- SEDD models discrete data by parameterizing a reverse discrete diffusion process using ratios of the data distribution, making them more effective in capturing language structure.
- The method challenges the dominance of autoregressive transformers by offering an alternative with better trade-offs between efficiency, control, and generation quality in discrete text modeling.
Results
- Perplexity Evaluation: SEDD outperforms previous diffusion models in terms of perplexity across multiple language modeling benchmarks. This suggests that SEDD models the underlying data distribution more accurately, improving likelihood estimation.
- Unconditional Generation: SEDD generates high-quality samples without any input conditioning, achieving performance comparable to GPT-2. It does so with 32× fewer network evaluations, indicating higher sampling efficiency and reduced computational cost.
- Conditional Generation (Infill Tasks): SEDD is tested on tasks where the model must generate missing parts of text conditioned on surrounding context. Despite lacking the autoregressive biases that normally boost performance in such tasks, SEDD remains competitive with strong baselines like GPT-2 and SSD-LM. This highlights SEDD’s ability to generalize well to complex discrete generation tasks without relying on token-by-token prediction.
Group 24 Presentation: Mitigating the Missing Fragmentation Problem in De Novo Peptide Sequencing With A Two-Stage Graph-Based Deep Learning Model
Presenters
Zi Hua Xu, Zehao Zhang
Paper Citation
Mao, Z., Zhang, R., Xin, L. et al. Mitigating the missing-fragmentation problem in de novo peptide sequencing with a two-stage graph-based deep learning model. Nat Mach Intell 5, 1250–1260 (2023). https://doi.org/10.1038/s42256-023-00738-x
https://www.nature.com/articles/s42256-023-00738-x#citeas
Background
- Proteins are crucial for biological functions
- Proteins are formed from peptides which are sequences of amino acids
- Mass spectrometry is used to analyze peptide sequences
- De Novo sequencing is used to piece together peptide sequences when the sequences are missing from existing established protein databases
- Deep learning has become commonly used to solve the problem of de novo peptide sequencing
- When a peptide fails to fragment in the expected manner, it can make protein reconstruction difficult due to missing data
- One error in the peptide can propagate to errors throughout the entire sequence
Paper Contributions
- Graph Novo was developed to handle incomplete segments
- GraphNovo-PathSearcher, instead of directly predicting amino acids, uses a path-search method to predict the optimal path through a graph built from the spectrum
- A graph neural network is used to find the best path from the graph generated from the mass spectrometry input
- GraphNovo-SeqFiller then fills in the detailed amino acid composition along the path found by PathSearcher.
- It is expected that some amino acids may have been missed; SeqFiller uses a transformer to add in amino acids that were missed by PathSearcher
- Input is mass spectrum from mass spectrometry
- Graph construction is performed where nodes represent possible fragment masses, and edges represent possible amino acid mass differences (PathSearcher module)
- PathSearcher uses machine learning to find the optimal path on the generated graph
- SeqFiller fills in missing amino acids that may have not been included in the PathSearcher module due to lacking data from the mass spectrometry inputs
Peptide Sequencing in AlphaPeptDeep
- Peptide sequencing determines amino acid sequences of peptides, crucial for understanding proteins.
- Mass spectrometry (MS) is used to fragment peptides and analyze the resulting spectra.
Methods referenced in the presentation:
- Database Search & Spectral Library Search: AlphaPeptDeep improves prediction of MS spectra and retention time, boosting accuracy of both methods.
- de novo Sequencing: Enhanced spectral prediction from AlphaPeptDeep supports building peptide sequences without prior knowledge.
- AlphaPeptDeep predicts peptide properties (e.g., fragmentation patterns) to improve spectrum matching and sequence inference.
Contributions
- GraphNovo outperforms prior models such as DeepNovo, PointNovo, and Casanovo, achieving:
- 9.3–11.5% higher peptide recall,
- 6.1–8.9% higher amino acid recall,
- 5.8–8.7% higher amino acid precision.
- Substantially improves predictions in and beyond missing-fragmentation regions.
- Improves amino acid recall by up to 24.5% after missing fragmentation when guiding DeepNovo with GraphNovo’s predicted path.
- Maintains high accuracy across varying:
- Sequence lengths,
- Noise signal ratios,
- Degrees of missing fragmentation.
- Open-source availability:
- Code: GitHub – GraphNovo
- Data: Zenodo Repository
Constructive Critiques
- Struggles with long missing-fragmentation regions: The model has difficulty accurately predicting peptide sequences when large segments of fragmentation data are absent. These missing regions create gaps in the spectrum, which may impair the model's ability to reconstruct the full peptide chain.
- Requires predefined post-translational modifications: The system depends on a predefined list of possible post-translational modifications. This constraint limits the model’s ability to generalize to peptides with unexpected or novel PTMs, reducing its adaptability in complex biological samples.
- Computationally expensive: Due to its two-stage graph-based deep learning architecture, the model demands significant computational resources.
Reviews
I thought that this presentation clearly introduced the concepts of peptide sequences and how mass spectrometry is used to analyze and reconstruct peptide sequences. Also, it was nice that you mentioned the different ways sequences are matched with observed spectra besides deep learning methods (e.g., matching observed spectra with known sequences and using pre-existing spectral databases). The problem of how mistakes in fragmentation can result in missing data in the observed spectrum was also clearly explained, ultimately making sequence reconstruction difficult. As someone who plans to do project 3 for the final project, I found this presentation particularly helpful in understanding the type of data used in de novo peptide sequence generation. One comment I have is that I found it a little unclear how the mass spectrum input is turned into a graph with fragments as nodes, and what exactly the optimal "path" is in the graph (i.e., how the edges are defined), although that may also be because I am not too familiar with this area.
Group 24 Presentation: Mitigating the Missing Fragmentation Problem in De Novo Peptide Sequencing With A Two-Stage Graph-Based Deep Learning Model
Presenters
Zi Hua Xu, Zehao Zhang
Paper Citation
Mao, Z., Zhang, R., Xin, L. et al. Mitigating the missing-fragmentation problem in de novo peptide sequencing with a two-stage graph-based deep learning model. Nat Mach Intell 5, 1250–1260 (2023). https://doi.org/10.1038/s42256-023-00738-x
https://www.nature.com/articles/s42256-023-00738-x#citeas
Background
Peptide sequencing is a critical step in proteomics for determining the amino acid composition of proteins using tandem mass spectrometry (MS). Traditional approaches fall into three main categories:
(1) Database Search: Compares observed tandem mass spectra with peptides from a known database.
(2) Spectral Library Search: Uses curated spectral libraries containing experimentally acquired spectra to identify peptide–spectrum matches.
(3) De Novo Sequencing: Infers the peptide sequence directly from the MS data without relying on existing sequence databases, making it essential for novel protein discovery and cases where databases are incomplete or impractical.
Limitations in de novo peptide sequencing
Two major challenges persist in de novo peptide sequencing:
(1) Missing Fragmentation: Some peptide bonds do not break during MS fragmentation, leaving “gaps” in the spectrum that make certain regions difficult to reconstruct.
(2) Error Accumulation: A mistake in one poorly fragmented region often propagates, causing further downstream errors in the predicted peptide sequence.
GraphNovo
GraphNovo is a two-stage de novo sequencing algorithm designed to mitigate these issues:
(1) GraphNovo-PathSearcher: Constructs a directed acyclic graph (the “spectrum graph”) where each node represents a possible partial mass, and edges represent allowable amino acid mass differences. Predicts the “optimal path” from the start node to the end node, effectively capturing the correct arrangement of fragment ions and bypassing regions of missing fragmentation by labeling them as “mass tags.”
(2) GraphNovo-SeqFiller: Fills in the “mass tags” from the optimal path with their specific amino acid composition. Uses a transformer-based decoder guided by the path constraints, reducing errors that could otherwise accumulate if one low-confidence region led to incorrect subsequent predictions.
How GraphNovo Functions
(1) Graph Construction: Translates raw MS peaks into a spectrum graph. Each node corresponds to a potential prefix mass; edges form between nodes if the mass difference matches one or more amino acid masses (within tolerance). A minimal sketch of this step follows the list.
(2) PathSearcher: Trained on known spectra to find the correct route through this graph, placing constraints on fragment ions (including missing-fragmentation regions represented as nodes without direct evidence).
(3) SeqFiller: Given the path—and hence each mass gap—SeqFiller “zooms in” on each mass tag to determine the exact amino acids. This two-stage strategy tackles missing fragmentation more directly than single-stage approaches.
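As a minimal sketch of the graph-construction step (1) above, the toy code below builds a spectrum graph from candidate prefix masses, connecting nodes whose mass difference matches one or more residue masses within a tolerance. The residue mass table, tolerance value, and helper names are illustrative assumptions, not GraphNovo's implementation.
<syntaxhighlight lang="python">
# Illustrative spectrum-graph construction (not GraphNovo's code).
# Nodes are candidate prefix masses; an edge is added when the mass difference
# between two nodes matches one or a pair of amino-acid residue masses.

# Monoisotopic residue masses (a small subset, for illustration only).
AA_MASS = {
    "G": 57.02146, "A": 71.03711, "S": 87.03203, "P": 97.05276,
    "V": 99.06841, "T": 101.04768, "L": 113.08406, "N": 114.04293,
}

def build_spectrum_graph(node_masses, tolerance=0.02, max_gap_residues=2):
    """Return (sorted nodes, adjacency list: node index -> [(neighbor, label), ...])."""
    # Single residues plus short residue combinations act as allowed "mass tags".
    allowed = dict(AA_MASS)
    if max_gap_residues >= 2:
        for a, ma in AA_MASS.items():
            for b, mb in AA_MASS.items():
                allowed[a + b] = ma + mb  # two-residue mass tag

    nodes = sorted(node_masses)
    graph = {i: [] for i in range(len(nodes))}
    for i, mi in enumerate(nodes):
        for j in range(i + 1, len(nodes)):
            diff = nodes[j] - mi
            for label, mass in allowed.items():
                if abs(diff - mass) <= tolerance:
                    graph[i].append((j, label))
    return nodes, graph

# Toy usage: prefix masses consistent with the fragment "G-A-S".
nodes, graph = build_spectrum_graph([0.0, 57.02, 128.06, 215.09])
print(graph)  # edges labeled with single residues or two-residue mass tags
</syntaxhighlight>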
Data and Preprocessing
Training Data: Includes high-confidence peptide identifications for Homo sapiens and Mus musculus from public proteomics datasets (e.g., plasma, HeLa, brain tissues). Only peptides with reliable annotations (1% false-discovery rate) and precursor masses below a certain threshold are included.
Test Data: Drawn from species not in the training set (e.g., Arabidopsis thaliana, C. elegans, E. coli), ensuring that evaluation measures generalizability.
Preprocessing Steps:
(1) Convert raw MS peaks into feature vectors (e.g., normalized m/z, relative intensity); a small example follows this list.
(2) Generate “node spectrum” (b, y, a, y2+ ions, among others) while discarding infeasible peaks.
(3) Build the graph by connecting nodes if their mass difference matches valid amino acid combinations.
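As a small example of preprocessing step (1), the sketch below turns raw peaks into simple per-peak feature vectors. The exact features and normalization used by GraphNovo may differ, so the choices here are assumptions for illustration.
<syntaxhighlight lang="python">
import numpy as np

def peaks_to_features(mz, intensity, precursor_mass):
    """Turn raw (m/z, intensity) peaks into simple per-peak feature vectors.

    Illustrative features: m/z normalized by the precursor mass, and intensity
    relative to the base (most intense) peak.
    """
    mz = np.asarray(mz, dtype=float)
    intensity = np.asarray(intensity, dtype=float)
    rel_intensity = intensity / intensity.max()
    norm_mz = mz / precursor_mass
    return np.stack([norm_mz, rel_intensity], axis=1)  # shape (n_peaks, 2)

# Toy usage
features = peaks_to_features(mz=[175.1, 304.2, 432.3],
                             intensity=[1200.0, 800.0, 300.0],
                             precursor_mass=900.5)
print(features.shape)  # (3, 2)
</syntaxhighlight>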
Model Architecture
Graph Construction: Creates a directed graph with edges corresponding to possible amino acid masses.
Loss Functions (a minimal sketch of both objectives follows this section):
(1) GraphNovo-PathSearcher: Uses Kullback–Leibler divergence to guide node (path) predictions.
(2) GraphNovo-SeqFiller: Uses cross-entropy to predict the exact amino acid sequence that fills each mass tag.
Hyperparameter Tuning:
Optimizer: AdamW (a variant of Adam) with a fixed learning rate in the reported experiments.
Both stages employ a transformer-based architecture, incorporating a specialized graph encoder (relation attention) to capture node and edge features.
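To ground the two training objectives noted under Loss Functions, here is a minimal PyTorch sketch of a KL-divergence loss over candidate-node (path) distributions and a cross-entropy loss over amino-acid tokens. The tensor shapes and names are illustrative assumptions, not the paper's exact implementation, and in GraphNovo the two stages are separate models trained separately.
<syntaxhighlight lang="python">
import torch
import torch.nn.functional as F

# PathSearcher-style objective: KL divergence between the predicted distribution
# over candidate next nodes in the spectrum graph and a (soft) target distribution.
logits_nodes = torch.randn(4, 50)                             # (batch, num_candidate_nodes)
target_node_dist = torch.softmax(torch.randn(4, 50), dim=-1)  # illustrative soft targets
path_loss = F.kl_div(F.log_softmax(logits_nodes, dim=-1),
                     target_node_dist, reduction="batchmean")

# SeqFiller-style objective: cross-entropy over the amino-acid vocabulary for each
# position that fills a mass tag along the predicted path.
logits_aa = torch.randn(4, 12, 25)           # (batch, seq_len, vocab_size)
target_aa = torch.randint(0, 25, (4, 12))    # ground-truth residue indices
seq_loss = F.cross_entropy(logits_aa.reshape(-1, 25), target_aa.reshape(-1))

print(float(path_loss), float(seq_loss))
</syntaxhighlight>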
Performance Comparison
Peptide Recall: GraphNovo shows a 9–12% improvement over the next-best approach (e.g., PointNovo, Casanovo) in correctly reconstructing entire peptide sequences.
Amino Acid Recall and Precision: Yields a 5–9% improvement across different test species, indicating more accurate individual residue identifications.
Robust to Missing Fragmentation and Noise: Maintains relatively high recall/precision for longer peptides, higher noise levels, and spectra with multiple missing-fragmentation sites, thereby mitigating error accumulation.
Constructive Critiques
Long Missing-Fragmentation Regions: While GraphNovo substantially reduces error propagation, very long or continuous gaps remain challenging.
Predefined Modifications: Must specify possible post-translational modifications (PTMs) in the graph construction step, which becomes computationally costly if many PTMs are considered at once.
Computational Overhead: Two-stage approach and large-scale graph construction require significant memory and processing time.
Future improvements
Enhanced Sequence Prediction: Integrate more MS features (e.g., retention time) to improve accuracy within large missing-fragmentation regions.
Expanded Applicability: Adapt the two-stage approach for top-down or middle-down proteomics and more extensive sets of PTMs.
Computational Efficiency: Explore faster graph-building algorithms, reduce the number of nodes through refined filtering, and potentially incorporate few-shot learning for user-specified PTMs.
Comment
Overall, GraphNovo demonstrates that a carefully designed, graph-based deep learning model can significantly mitigate the missing-fragmentation problem in de novo peptide sequencing, outperforming traditional and newer transformer-based methods by providing a stable framework for both path and sequence prediction.
Group 47 Presentation: Jamba: A Hybrid Transformer - Mamba Language Model
Presenter
Chentao Jin
Paper Citation
Lieber, O., Lenz, B., Bata, H., Cohen, G., Osin, J., Dalmedigos, I., Safahi, E., Meirom, S., Belinkov, Y., Shalev-Shwartz, S., Abend, O., Alon, R., Asida, T., Bergman, A., Glozman, R., Gokhman, M., Manevich, A., Ratner, N., Rozen, N., Shwartz, E., Zusman, M., Shoham, Y. (2024). Jamba: A Hybrid Transformer-Mamba Language Model. arXiv. https://arxiv.org/abs/2403.19887
https://doi.org/10.48550/arXiv.2403.19887
Background
- Large language models (LLMs) have become essential for various natural language processing tasks.
- Transformers have been the dominant architecture for LLMs due to their effectiveness in handling sequential data.
- However, Transformers suffer from high memory and compute costs, limiting their efficiency in long-context processing.
- Mamba, a recent state-space model, has emerged as an alternative to Transformers, offering improved efficiency in handling long sequences.
- A hybrid approach combining Transformers and Mamba layers can leverage the strengths of both architectures.
- Mixture-of-Experts (MoE) techniques can further enhance model capacity while managing active parameter usage.
- Efficient model architectures are crucial for balancing performance, computational efficiency, and memory footprint in large-scale AI applications.
Summaries of Key Points
Bridging Transformer Expressiveness and Mamba Efficiency
Large language models (LLMs) have made remarkable advances, but scaling them to handle long-context processing remains a significant challenge. Jamba introduces a novel hybrid architecture that combines the strengths of Transformers and Mamba, enhanced with a Mixture of Experts (MoE) module. This integration enables efficient memory usage, high throughput, and scalability for sequences up to 256,000 tokens on a single GPU.
Key Architectural Innovations
1. Transformer Layers – Self-attention mechanisms allow the model to capture complex token relationships, crucial for tasks involving deep contextual reasoning. However, they come with high memory and compute costs when processing long sequences.
2. Mamba Layers – Derived from state-space models (SSMs), Mamba efficiently processes long sequences without storing extensive key-value caches. Instead, it maintains a hidden state to summarize prior information, significantly reducing memory overhead.
3. Mixture of Experts (MoE) – Jamba integrates sparse expert selection, where each token activates only a small subset of experts. This technique increases capacity while controlling computational costs. Specifically, Jamba uses 16 experts, with only two active per token, optimizing efficiency.
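To make the sparse expert selection in point 3 concrete, here is a minimal top-2 routing sketch over 16 experts in PyTorch. It shows the general technique under those sizes and is not Jamba's actual MoE implementation.
<syntaxhighlight lang="python">
import torch
import torch.nn as nn
import torch.nn.functional as F

class Top2MoE(nn.Module):
    """Minimal top-2 Mixture-of-Experts layer (illustrative sketch)."""

    def __init__(self, d_model=512, d_ff=2048, n_experts=16, top_k=2):
        super().__init__()
        self.router = nn.Linear(d_model, n_experts)
        self.experts = nn.ModuleList(
            nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model))
            for _ in range(n_experts)
        )
        self.top_k = top_k

    def forward(self, x):                        # x: (num_tokens, d_model)
        scores = self.router(x)                  # (num_tokens, n_experts)
        weights, idx = scores.topk(self.top_k, dim=-1)
        weights = F.softmax(weights, dim=-1)     # renormalize over the chosen experts
        out = torch.zeros_like(x)
        for k in range(self.top_k):              # only the selected experts run per token
            for e in idx[:, k].unique().tolist():
                mask = idx[:, k] == e
                out[mask] += weights[mask, k:k + 1] * self.experts[e](x[mask])
        return out

tokens = torch.randn(10, 512)
print(Top2MoE()(tokens).shape)                   # torch.Size([10, 512])
</syntaxhighlight>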
The Jamba Block: A Structured Hybrid Design
The architecture of a Jamba block follows a structured sequence of Transformer, Mamba, and MoE layers:
- Transformer-to-Mamba ratio (1:7): the model uses one Transformer layer for every seven Mamba layers, balancing expressiveness and efficiency.
- MoE placement: rather than applying MoE to every layer, Jamba replaces every second multi-layer perceptron (MLP) layer with an MoE module, increasing model capacity without significantly raising the number of parameters active per token.
By blending self-attention, state-space models, and sparsity techniques, Jamba pushes the boundaries of long-context processing in language models. Its ability to handle extremely long sequences while maintaining efficiency and scalability makes it a compelling innovation in the next generation of LLM architectures.
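As a rough illustration of the layer ordering just described, the following sketch lays out an 8-layer Jamba-style block with one attention layer per seven Mamba layers and MoE replacing the MLP in every second layer. The exact offsets and names are assumptions for illustration, not the released model code.
<syntaxhighlight lang="python">
# Illustrative layout of one 8-layer Jamba-style block:
# one attention layer for every seven Mamba layers, with an MoE module replacing
# the MLP in every second layer. Offsets and names are assumptions.

ATTN_PERIOD = 8   # 1:7 attention-to-Mamba ratio -> one attention layer per 8 layers
MOE_PERIOD = 2    # MoE replaces the MLP in every second layer

def jamba_block_layout(n_layers=8):
    layout = []
    for i in range(n_layers):
        mixer = "attention" if i % ATTN_PERIOD == 0 else "mamba"
        ffn = "moe" if i % MOE_PERIOD == 1 else "mlp"
        layout.append((mixer, ffn))
    return layout

for mixer, ffn in jamba_block_layout():
    print(mixer, "+", ffn)
# attention + mlp, mamba + moe, mamba + mlp, mamba + moe, ... (8 layers total)
</syntaxhighlight>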
Performance and Benefits
Jamba achieves state-of-the-art efficiency and performance across academic benchmarks, long-context tasks, and throughput.
- Matches or outperforms larger models such as Mixtral-8x7B and LLaMA-2 70B on:
- Reasoning: HellaSwag, ARC-Challenge, WinoGrande, PIQA, TruthfulQA
- Comprehension: BoolQ, QuAC
- Math and Code: GSM8K, HumanEval
- Aggregated tasks: MMLU, BBH
- Outperforms Mixtral on:
- Needle-in-a-Haystack retrieval
- Few-shot classification: Banking77, TREC-Fine
- Long-context QA: NarrativeQA, CUAD, NaturalQuestions
- Up to 3× faster than Mixtral at [math]\displaystyle{ 128K }[/math] context length.
- Efficient inference on a single [math]\displaystyle{ 80\,\text{GB} }[/math] GPU (int8 quantization).
- KV-cache memory usage is 8× smaller than Transformers (e.g., [math]\displaystyle{ 4\,\text{GB} }[/math] vs. [math]\displaystyle{ 32\,\text{GB} }[/math] at [math]\displaystyle{ 256K }[/math] tokens).
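To see where the KV-cache savings come from, here is a back-of-the-envelope estimate of KV-cache memory as a function of the number of attention layers. The hidden sizes, head counts, and precision below are illustrative assumptions (they happen to land near the reported 32 GB vs. 4 GB figures), not Jamba's published configuration.
<syntaxhighlight lang="python">
def kv_cache_gb(seq_len, n_attn_layers, n_kv_heads=8, head_dim=128, bytes_per_elem=2):
    """Approximate KV-cache size in GB: keys + values stored for every attention layer."""
    elems = 2 * n_attn_layers * n_kv_heads * head_dim * seq_len  # 2 = key + value
    return elems * bytes_per_elem / 1e9

# A pure-Transformer stack with 32 attention layers vs. a hybrid that keeps only
# 4 attention layers (1:7 ratio), both at a 256K-token context (illustrative numbers).
print(kv_cache_gb(256_000, n_attn_layers=32))  # ~33.6 GB
print(kv_cache_gb(256_000, n_attn_layers=4))   # ~4.2 GB
</syntaxhighlight>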
Constructive Critique
While Jamba achieves impressive performance and efficiency through its hybrid architecture, several limitations and open questions remain:
- Lack of ablation analysis: The paper adopts a 1:7 ratio of Transformer to Mamba layers and places MoE modules every other layer, but provides little justification for these hyperparameters. A more thorough ablation would help clarify the contribution of each component to final performance.
- Interpretability trade-off: Combining multiple architectural modules (SSM, attention, MoE) increases complexity. While effective, this may make it harder to interpret model behavior or debug errors, especially compared to simpler Transformer-only baselines.
- ICL limitations: The authors mention that Mamba layers alone struggle with in-context learning (ICL). Although interleaving attention layers helps, this still suggests limitations in how well SSMs handle token-by-token reasoning or structure-sensitive tasks.
- Lack of fine-tuning or alignment: The released model is a base pretrained model without instruction tuning or safety alignment. This limits its immediate use in downstream applications without additional supervised or RLHF-based training.
Despite these challenges, Jamba represents a promising direction for scaling efficient, long-context language models and offers a practical blueprint for hybrid architectures in the LLM space.
Group 47 Presentation: Jamba: A Hybrid Transformer - Mamba Language Model
Presenter
Chentao Jin
Paper Citation
Lieber, O., Lenz, B., Bata, H., Cohen, G., Osin, J., Dalmedigos, I., Safahi, E., Meirom, S., Belinkov, Y., Shalev-Shwartz, S., Abend, O., Alon, R., Asida, T., Bergman, A., Glozman, R., Gokhman, M., Manevich, A., Ratner, N., Rozen, N., Shwartz, E., Zusman, M., Shoham, Y. (2024). Jamba: A Hybrid Transformer-Mamba Language Model. arXiv. https://arxiv.org/abs/2403.19887
https://doi.org/10.48550/arXiv.2403.19887
Clear explanations to aid understanding
Jamba uses a unique combination of transformer layers, memory layers, and a Mixture of Experts (MoE) layer to improve its efficiency, scalability, and ability to process long sequences of text. This hybrid design optimizes both memory management and computational power while minimizing resource use. Here’s an overview of the components and their roles in the architecture:
Jamba's Architecture
Jamba's architecture is built on a structured sequence of layers: each transformer layer is followed by seven Mamba-based memory layers, and a Mixture of Experts (MoE) module replaces the feed-forward (MLP) sub-layer in every other layer. This design creates an efficient, hybrid system that maximizes both memory handling and computational power while minimizing resource usage.
The roles of these layers are as follows:
1. Transformer Layers: The transformer layers in Jamba are responsible for capturing the relationships between different tokens in the sequence, no matter how far apart they are. This is done using self-attention, which helps the model understand complex relationships and context within the text. However, traditional transformers can struggle with very long sequences because the self-attention mechanism becomes very memory and computation-heavy as the sequence length grows.
2. Memory Layers: To tackle this, Jamba introduces memory layers based on a different model called the state space model (SSM). These memory layers don’t rely on self-attention. Instead, they maintain a hidden state that keeps track of important information from earlier in the sequence. This makes memory layers far more efficient when it comes to handling long sequences, as they don’t need to store as much data in memory, unlike the transformer layers (a small state-space sketch follows this list).
3. Mixture of Experts (MoE): The MoE component is where Jamba gets its flexibility. Instead of using all model parameters for each token, MoE selectively activates a small subset of "experts" for each token. An "expert" is a specialized set of parameters that focuses on solving a specific part of the problem. The model dynamically selects the experts that are most relevant to each token's context, allowing it to handle different parts of the problem more efficiently. Since only a few experts are activated per token, the model can scale its capacity to handle more complex tasks or longer sequences without significantly increasing computational costs.
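To make the memory-layer idea in point 2 concrete, here is a minimal state-space recurrence sketch: a fixed-size hidden state is updated token by token, so memory does not grow with sequence length. Real Mamba layers use input-dependent (selective) parameters and a hardware-efficient scan; the matrices and sizes below are illustrative assumptions only.
<syntaxhighlight lang="python">
import torch

def ssm_scan(x, A, B, C):
    """Run a simple linear state-space recurrence over a sequence.

    x: (seq_len, d_in) inputs. The hidden state h has a fixed size d_state,
    so memory cost is constant in sequence length (unlike a growing KV cache).
    """
    d_state = A.shape[0]
    h = torch.zeros(d_state)
    outputs = []
    for x_t in x:                  # h_t = A h_{t-1} + B x_t ;  y_t = C h_t
        h = A @ h + B @ x_t
        outputs.append(C @ h)
    return torch.stack(outputs)

seq_len, d_in, d_state, d_out = 16, 8, 4, 8
x = torch.randn(seq_len, d_in)
A = 0.9 * torch.eye(d_state)       # decaying state transition (illustrative)
B = torch.randn(d_state, d_in) * 0.1
C = torch.randn(d_out, d_state) * 0.1
print(ssm_scan(x, A, B, C).shape)  # torch.Size([16, 8])
</syntaxhighlight>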
Additionally, Jamba stabilizes its training process with RMSNorm, which normalizes activations and ensures that training remains stable, even when the model is scaled up to very large sizes.
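Since RMSNorm is credited with stabilizing training, here is a standard RMSNorm sketch; this shows the general technique, not necessarily the exact variant used in Jamba.
<syntaxhighlight lang="python">
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    """Root-mean-square layer normalization: rescale activations by their RMS,
    with a learned gain and no mean subtraction (unlike LayerNorm)."""

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.weight * x / rms

x = torch.randn(2, 5, 512)
print(RMSNorm(512)(x).shape)  # torch.Size([2, 5, 512])
</syntaxhighlight>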
Performance and Benefits of Jamba
Jamba's hybrid approach provides several advantages:
- Efficient Handling of Long Contexts: Jamba is able to process long sequences of text effectively, overcoming the limitations of traditional transformer models.
- Balance Between Performance and Efficiency: Jamba achieves strong performance while reducing memory usage and computational costs thanks to its combination of transformer, memory, and MoE layers.
- Scalability: By using fewer active parameters than other models, Jamba is scalable and can handle tasks that require understanding large amounts of text without compromising efficiency.
Limitations of Jamba
This hybrid approach has limitations as well: training can become unstable, which is why RMSNorm is needed to keep activations in check; handling long contexts still requires substantial GPU memory (on the order of 80 GB); and the Mixture-of-Experts (MoE) component still needs further optimization to improve performance.
Further Directions
- Optimize MoE further for even better efficiency
- Investigate how hybrid architectures scale to even longer contexts and larger models