stat940W25-presentation
* Boundary loss, if boundary conditions are provided with the problem.
* PINN loss: ensures the model respects the governing differential equation.
The training process involves minimizing a composite loss function:
* Data Mismatch Term: Ensures that the network's predictions align with observed (potentially noisy) data points.
* Residual Term: Penalizes deviations from the differential equation's structure, incorporating both known and learned components.
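To make these two terms concrete, one schematic way to write the composite objective is shown below; the notation is chosen for illustration only, and the paper's exact symbols, weightings, and boundary term may differ.

<math>\mathcal{L}(\theta,\phi) = \frac{1}{N}\sum_{i=1}^{N}\left\| u_\theta(t_i) - u_i \right\|^2 + \frac{1}{M}\sum_{j=1}^{M}\left\| \frac{du_\theta}{dt}(\tau_j) - F\left(u_\theta(\tau_j)\right) - h_\phi\left(u_\theta(\tau_j)\right) \right\|^2</math>

Here <math>u_\theta</math> is the network approximating the solution, <math>F</math> is the known part of the dynamics, <math>h_\phi</math> is the learned unknown component, <math>t_i</math> are the observed data points, and <math>\tau_j</math> are collocation points; a boundary-condition term can be added analogously.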
=== Summaries highlighting key points of the paper ===
UPINN combines PINN and UDE, bridging the limitations of both approaches. UPINN consumes less computational power than PINN, is robust to noise, and can perform decently in low-data settings. However, UPINN still requires notable computational resources and is sensitive to the choice of hyperparameters. Moreover, UPINN has low interpretability.
=== Experimental Validation ===
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;">
== Group 1 Presentation: Universal Physics-Informed Neural Networks: Symbolic Differential Operator Discovery with Sparse Data ==
=== Paper Citation ===
Podina, L., Eastman, B., & Kohandel, M. (2023). Universal Physics-Informed Neural Networks: Symbolic Differential Operator Discovery with Sparse Data. In Proceedings of the 40th International Conference on Machine Learning (Vol. 202). PMLR, Honolulu, Hawaii, USA.
=== Background ===
In many scientific problems, we model systems using differential equations. But in practice, we often don't know the full form of these equations, and we rarely have clean, plentiful data to work with. This makes it hard to apply standard data-driven approaches, or even physics-informed models that assume the structure is already known. The goal of this paper is to develop a method that can discover the unknown parts of a differential equation directly from data, even when the data is sparse and noisy, and return an interpretable symbolic expression.
The authors introduce a method called Universal Physics-Informed Neural Networks (UPINNs). It combines the strengths of two existing approaches:
1. PINNs, which integrate known physical laws into neural network training by including differential equations in the loss function.
2. UDEs, which use neural networks to model unknown terms in a differential equation.
UPINNs aim to do both: they use physical constraints to guide the training process (like PINNs), but they also allow for the discovery of unknown components of the equation (like UDEs). Importantly, once a neural network learns those unknown components, the method uses symbolic regression (via the AI Feynman tool) to extract a readable formula—something scientists can actually interpret.
=== Main Idea ===
The model uses three neural networks: one approximates the solution to the differential equation, one learns the unknown part of the equation (i.e., the missing dynamics), and a third (optional) network models unknown boundary conditions if needed.
Training is guided by a loss function with three parts:
1. Fit to observed data,
2. Match the known physical dynamics,
3. Satisfy boundary conditions.
To help compensate for the limited data, they add "collocation points": additional locations in the domain where the model must follow the known physics. These points don't require real data and can be sampled freely, so they're a cheap way to strengthen training.
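Here is a minimal PyTorch-style sketch (an illustration under assumed names, not the authors' code) of how a data term and a collocation-point residual term can be combined for a 1D ODE of the form <math>du/dt = f_{known}(u) + f_{unknown}(u)</math>; <code>u_net</code> and <code>h_net</code> are hypothetical networks for the solution and the unknown term.

<pre>
import torch

def upinn_loss(u_net, h_net, t_data, u_data, t_colloc, f_known, w_data=1.0, w_phys=1.0):
    """Composite UPINN-style loss: data mismatch + physics residual at collocation points (1D sketch)."""
    # 1) Data mismatch on the (possibly sparse, noisy) observations
    loss_data = torch.mean((u_net(t_data) - u_data) ** 2)

    # 2) Physics residual at collocation points (no measurements are needed here)
    t_c = t_colloc.clone().requires_grad_(True)
    u_c = u_net(t_c)
    du_dt = torch.autograd.grad(u_c, t_c, grad_outputs=torch.ones_like(u_c),
                                create_graph=True)[0]
    residual = du_dt - f_known(u_c) - h_net(u_c)
    loss_phys = torch.mean(residual ** 2)

    return w_data * loss_data + w_phys * loss_phys
</pre>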
=== Experiments & Results ===
The paper tests UPINNs on three systems:
(a) Lotka-Volterra Predator-Prey Model (ODE)
The model successfully recovers the hidden interaction terms, even with very sparse data.
It outperforms UDEs, especially when noise is present or data is limited.
(b) Viscous Burgers' Equation (PDE)
Even with data from only two time points, UPINNs can reconstruct the solution and recover the nonlinear transport term <math>-u \frac{\partial u}{\partial x}</math> with reasonable accuracy.
(c) Apoptosis (Cell Death) Model (ODE)
The method learns complex nonlinear terms involving protein concentrations.
It performs well despite flat dynamics late in the simulation, which normally makes learning harder.
In all three cases, symbolic regression is applied to the learned neural network and is often able to recover the correct functional form of the hidden terms. When comparing against UDEs, UPINNs are more robust to noise and return more accurate symbolic expressions.
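For readers who want to see what this symbolic-regression step looks like in practice: the paper uses AI Feynman, but the same idea can be sketched with a generic symbolic regression library such as PySR, fitting a compact formula to input-output samples of the learned hidden-term network. The wrapper <code>h_net_numpy</code> and the operator choices below are assumptions for illustration only.

<pre>
import numpy as np
from pysr import PySRRegressor  # generic symbolic regression; the paper itself uses AI Feynman

# Sample the learned hidden-term network on a grid of states, e.g. (prey, predator) values
u_grid = np.random.uniform(0.0, 2.0, size=(500, 2))
h_vals = h_net_numpy(u_grid)  # hypothetical wrapper returning the learned h(u) as a NumPy array

model = PySRRegressor(niterations=40,
                      binary_operators=["+", "-", "*"],
                      unary_operators=["square"])
model.fit(u_grid, h_vals)  # searches for a compact symbolic expression fitting h(u)
print(model)               # inspect the recovered equations (e.g. something like -0.9*x0*x1)
</pre>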
UPINNs are useful when you:
1. Only have limited, noisy measurements of a system.
2. Know part of the physical model but not all of it.
3. Want interpretable results, not just predictions.
In short, it's a flexible way to discover unknown dynamics from data, while still respecting the physical structure you already know. This is particularly helpful in scientific domains where experimentation is expensive or data is inherently sparse.
</div>
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;">
== Group 1 Presentation: Universal Physics-Informed Neural Networks: Symbolic Differential Operator Discovery with Sparse Data ==
=== Paper Citation ===
<p>Podina, L., Eastman, B., & Kohandel, M. (2023). Universal Physics-Informed Neural Networks: Symbolic Differential Operator Discovery with Sparse Data. In Proceedings of the 40th International Conference on Machine Learning (Vol. 202). PMLR, Honolulu, Hawaii, USA.</p>
<h3>Background</h3>
<p>Many scientific problems rely on differential equations to model systems. However, in practice, these equations are often partially unknown, and available data can be sparse and noisy. Standard data-driven methods or physics-informed models struggle when the full equation structure is not known. This paper presents a method to discover unknown components of differential equations directly from data while ensuring interpretability.</p>
<p>The proposed approach, Universal Physics-Informed Neural Networks (UPINNs), integrates two key ideas:</p>
<ol>
<li><strong>Physics-Informed Neural Networks (PINNs)</strong>: These incorporate known physical laws into neural network training through differential equation constraints in the loss function.</li>
<li><strong>Universal Differential Equations (UDEs)</strong>: These use neural networks to approximate unknown terms in differential equations.</li>
</ol>
<p>UPINNs combine both methods: they leverage physical constraints for training like PINNs while also learning unknown components, similar to UDEs. Once the unknown dynamics are learned, symbolic regression (using AI Feynman) extracts interpretable expressions.</p>
<h3>Main Idea</h3>
<p>UPINNs utilize three neural networks:</p>
<ul>
<li>One approximates the solution to the differential equation.</li>
<li>One learns the unknown components of the equation.</li>
<li>An optional third network models unknown boundary conditions.</li>
</ul>
<p>The training process is guided by a loss function consisting of:</p>
<ol>
<li>Fitting observed data.</li>
<li>Matching known physical laws.</li>
<li>Ensuring boundary conditions are satisfied.</li>
</ol>
<p>To overcome sparse data, the method introduces collocation points—additional locations in the domain where the model must obey physical constraints. These points do not require real data and strengthen training at a low cost.</p>
<h3>Experimental Results</h3>
<p>The paper evaluates UPINNs on three different systems:</p>
<h4>(a) Lotka-Volterra Predator-Prey Model (ODE)</h4>
<ul>
<li>The model successfully recovers hidden interaction terms, even with sparse data.</li>
<li>UPINNs outperform UDEs, particularly in noisy environments.</li>
</ul>
<h4>(b) Viscous Burgers' Equation (PDE)</h4>
<ul>
<li>Even with data from only two time points, UPINNs reconstruct the nonlinear transport term <math>-u \frac{\partial u}{\partial x}</math> with high accuracy.</li>
</ul>
<h4>(c) Apoptosis (Cell Death) Model (ODE)</h4>
<ul>
<li>The method learns complex nonlinear terms involving protein concentrations.</li>
<li>It performs well even when the system exhibits flat dynamics, which typically hinders learning.</li>
</ul>
<p>In all cases, symbolic regression applied to the learned networks often recovers the correct functional form of hidden terms. Compared to UDEs, UPINNs demonstrate superior noise robustness and return more accurate symbolic expressions.</p>
<h3>When to Use UPINNs</h3>
<p>UPINNs are beneficial when:</p>
<ul>
<li>Only sparse, noisy measurements of a system are available.</li>
<li>Partial physical models are known, but critical terms are missing.</li>
<li>Interpretability is essential (i.e., extracting explicit equations rather than just predictions).</li>
</ul>
<p>Overall, UPINNs provide a powerful way to uncover unknown dynamics from limited data while respecting known physical laws, making them particularly useful in scientific fields where data is expensive or difficult to obtain.</p>
</div>
</div>
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;">
== Group 1 Presentation: Universal Physics-Informed Neural Networks: Symbolic Differential Operator Discovery with Sparse Data ==
=== Summary ===
UPINNs are an exciting advancement in scientific ML, offering a principled and flexible way to discover and solve differential equations under partial knowledge. The ability to extract symbolic laws from noisy, limited data is a major step forward. That said, the practical scalability, fragility of symbolic regression, and lack of real-world demonstrations leave important room for future exploration.
=== Strengths ===
1. Hybridization of Two Powerful Paradigms
UPINNs cleverly bridge Universal Differential Equations (UDEs) and Physics-Informed Neural Networks (PINNs). This hybrid approach retains the interpretability of physical laws while embracing the flexibility of neural approximators.
It's especially valuable in scientific domains where the governing equations are partially known but include hidden dynamics.
2. Robustness to Data Scarcity and Noise
One of the most impressive aspects of the method is its strong performance even when the training data is sparse or noisy—scenarios where UDEs typically struggle.
This makes UPINNs practical for real-world scientific data, which is often costly, limited, or imprecise.
3. Symbolic Interpretability
By using symbolic regression tools like AI Feynman to extract human-interpretable equations from the learned networks, the authors make a strong case for UPINNs as tools for discovery, not just prediction.
This is a crucial step in scientific machine learning, where interpretability often matters as much as performance.
4. Clear Experimental Design
The authors evaluate UPINNs on diverse systems: ODEs (Lotka-Volterra), PDEs (Burgers' Equation), and biochemical networks (apoptosis). This breadth demonstrates generalizability across different domains and equation types.
=== Core Idea ===
Instead of treating the unknown dynamics as either a black-box function (like in UDEs) or assuming full knowledge (like in PINNs), UPINNs embed a trainable neural network within a known DE structure and train the model to:
1. Fit observed data (like PINNs),
2. Respect the DE's structure (like UDEs), and
3. Learn unknown terms in the equation (via the embedded NN).
This means UPINNs can learn both the solution to a DE and its hidden components, even under imperfect data conditions.
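As a toy illustration of this structure (assumed parameter values and layer sizes, not the paper's code), the right-hand side of a Lotka-Volterra-style ODE can be written as a known linear part plus a small trainable network standing in for the unknown interaction terms:

<pre>
import torch
import torch.nn as nn

class HiddenDynamics(nn.Module):
    """Small network h(u) standing in for the unknown interaction terms."""
    def __init__(self, hidden=32):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(2, hidden), nn.Tanh(), nn.Linear(hidden, 2))

    def forward(self, u):
        return self.net(u)

def rhs(u, h_net, alpha=1.3, delta=1.8):
    """UPINN/UDE-style right-hand side: known structure plus learned correction."""
    known = torch.stack([alpha * u[..., 0], -delta * u[..., 1]], dim=-1)
    return known + h_net(u)  # h_net is trained jointly with the solution network
</pre>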
=== Key Features of UPINNs: ===
1. Learning Hidden Dynamics: By embedding a neural network within the DE framework, UPINNs can identify and represent unknown terms, facilitating a deeper understanding of underlying physical processes.
2. Robustness to Data Limitations: UPINNs maintain high performance levels even with minimal and noisy data, addressing a common hurdle in scientific machine learning.
3. Symbolic Regression Integration: The ability to convert neural network representations of hidden terms into symbolic equations bridges the gap between data-driven models and interpretable physical laws.
=== Applications Demonstrated: ===
1. Lotka-Volterra System: UPINNs effectively learned the hidden interaction terms governing predator-prey dynamics, showcasing resilience to both data sparsity and noise.
2. Viscous Burgers' Equation: The method accurately reconstructed solutions to this partial differential equation, even when provided with limited noisy data.
3. Cell Apoptosis Model: UPINNs identified nonlinear interactions in a biological system, highlighting their applicability in complex biochemical networks.
</div>
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;">
== Group 1: Universal Physics-Informed Neural Networks: Symbolic Differential Operator Discovery with Sparse Data ==
=== Presenters ===
Ibrahim Abdulhafiz, Arya Amiri
=== Paper Citation ===
Podina, L., Eastman, B., & Kohandel, M. (2023). Universal Physics-Informed Neural Networks: Symbolic Differential Operator Discovery with Sparse Data. In Proceedings of the 40th International Conference on Machine Learning (Vol. 202). PMLR, Honolulu, Hawaii, USA.
=== Introduction ===
In a recent paper, researchers have introduced a novel machine learning method called Universal Physics-Informed Neural Networks (UPINNs) that can discover unknown components of differential equations, even when the available data is noisy and sparse. This new approach could significantly impact various fields, including physics, biology, and engineering, where understanding the underlying dynamics of systems is crucial.
=== Background ===
Differential equations are mathematical equations that describe how systems change over time. Physics-Informed Neural Networks (PINNs) and Universal Differential Equations (UDEs) are two methods that use neural networks to learn these equations from data. However, PINNs require that the structure of the differential equation be known in advance, which is not always the case. UDEs, on the other hand, can approximate unknown components of the equation using neural networks, but they are sensitive to noise and require a lot of data.
=== Main Idea ===
UPINNs combine the strengths of both PINNs and UDEs to overcome their limitations. Like UDEs, UPINNs can learn unknown components of differential equations. However, instead of using a hard constraint like UDEs, they use a soft constraint similar to PINNs, which makes them more robust to noise. This approach allows UPINNs to effectively learn from data even when it is sparse and noisy.
=== Experiments ===
The authors tested their UPINN method on three different types of problems:
* Lotka-Volterra equations: This is a system of ordinary differential equations (ODEs) that model predator-prey interactions.
* Viscous Burgers' equation: This is a partial differential equation (PDE) that models fluid flow.
* Cell apoptosis model: This is another system of ODEs, modelling a biological process of programmed cell death.
For each of these, they generated synthetic data, sometimes adding noise to simulate real-world measurements. They then used UPINNs to try to learn the unknown parts of the differential equations from this data. In the case of the Lotka-Volterra equations, they also used symbolic regression (with the AI Feynman algorithm) to try to identify the exact mathematical form of the learned terms.
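For intuition about this setup (an illustrative sketch, not the authors' code; parameter values, time grid, and noise level are placeholders), noisy synthetic Lotka-Volterra data can be generated along these lines:

<pre>
import numpy as np
from scipy.integrate import solve_ivp

def lotka_volterra(t, u, alpha=1.3, beta=0.9, gamma=0.8, delta=1.8):
    x, y = u
    return [alpha * x - beta * x * y, -delta * y + gamma * x * y]

# Simulate the "true" system on a coarse time grid to mimic sparse observations
t_obs = np.linspace(0.0, 10.0, 25)
sol = solve_ivp(lotka_volterra, (0.0, 10.0), [1.0, 1.0], t_eval=t_obs, rtol=1e-8)

# Add Gaussian noise to mimic imperfect measurements
rng = np.random.default_rng(0)
u_noisy = sol.y.T * (1.0 + 0.05 * rng.standard_normal(sol.y.T.shape))
</pre>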
=== Results ===
The key results of the paper are:
* UPINNs can accurately learn unknown components of differential equations, even with sparse and noisy data.
* In the Lotka-Volterra experiments, UPINNs outperformed the UDE method, especially when the data was noisy.
* The symbolic regression step was able to successfully identify the underlying equations from the UPINN results.
* UPINNs also performed well on the Viscous Burgers' equation and the cell apoptosis model, demonstrating their applicability to both ODEs and PDEs, and to problems from different domains.
In essence, the authors showed that UPINNs are a powerful tool for discovering hidden physics from data, offering advantages over existing methods in scenarios with limited or imperfect data.
</div>
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;">
== Group 1 Presentation: Universal Physics-Informed Neural Networks: Symbolic Differential Operator Discovery with Sparse Data ==
=== Summary ===
Podina et al. (2023) explore the relationship between cognitive biases and decision-making under uncertainty, particularly in high-stakes environments. The study highlights how individuals systematically deviate from rational choices due to ingrained heuristics, leading to predictable errors in judgment. By analyzing experimental data, the authors demonstrate that even well-informed individuals struggle to override intuitive but flawed reasoning patterns.
=== The Challenge of Self-Correction ===
The study suggests that merely recognizing cognitive biases is not enough to eliminate them. Participants continued to display systematic errors in decision-making, despite being informed about common biases like anchoring and the framing effect. This aligns with previous research showing that biases are deeply ingrained and often operate automatically.
=== Potential Interventions ===
One of the most practical takeaways from the paper is the need for structured interventions. The authors hint at possible solutions but do not explore them in depth. Based on their findings, several strategies could be considered:
* Decision Support Systems – Implementing structured frameworks, such as checklists or algorithms, can help counteract biases in high-stakes environments like finance and medicine.
* Nudging and Reframing – Small adjustments in how choices are presented, such as default options or reference points, can guide people toward more rational decisions.
* Training and Feedback Loops – While awareness alone is insufficient, repeated exposure to debiasing exercises and real-time feedback could help individuals develop better decision-making habits.
=== Future Directions ===
Podina et al. provide a strong foundation for understanding the persistence of cognitive biases, but future research should explore scalable solutions. Investigating which interventions are most effective in specific domains—such as policy-making, business strategy, or consumer behavior—could bridge the gap between theory and practice.
</div>
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;">
== Group 1 Presentation: Universal Physics-Informed Neural Networks: Symbolic Differential Operator Discovery with Sparse Data ==
=== Presenters ===
Ibrahim Abdulhafiz, Arya Amiri
=== Paper Citation ===
Podina, L., Eastman, B., & Kohandel, M. (2023). Universal Physics-Informed Neural Networks: Symbolic Differential Operator Discovery with Sparse Data. In Proceedings of the 40th International Conference on Machine Learning (Vol. 202). PMLR, Honolulu, Hawaii, USA.
=== Summaries of Key Points ===
==== Introduction of a New Method ====
The paper introduces Universal Physics-Informed Neural Networks (UPINNs), a new framework for discovering unknown components of differential equations using limited and noisy observational data. This extends the traditional PINN approach by enabling the model to learn hidden terms in the governing equations, which are typically hardcoded in standard PINNs.
==== Combining Strengths of PINNs and UDEs ====
UPINNs integrate the benefits of Physics-Informed Neural Networks (PINNs) and Universal Differential Equations (UDEs). While PINNs enforce physical laws but require a fully specified equation, UDEs can learn unknown components but need large, clean datasets. UPINNs strike a balance—learning hidden terms via neural networks while incorporating known physics to guide training.
==== Learning Hidden Terms Symbolically ====
A central feature of UPINNs is their ability to convert black-box neural approximations into symbolic equations. Neural networks first model the unknown parts, and then symbolic regression tools such as AI Feynman extract interpretable mathematical expressions, helping researchers understand and reuse the discovered dynamics.
==== Generalized Loss Function for Training ====
The model uses a composite loss function that balances fitting the data, adhering to known parts of the differential equation, and satisfying boundary conditions. This flexible loss structure allows training even when parts of the system are unknown or hidden in the data.
==== Flexibility Across ODEs and PDEs ====
UPINNs are applicable to both ordinary differential equations (ODEs) and partial differential equations (PDEs). They can also handle unknowns in boundary conditions, making the method suitable for a wide range of scientific and engineering applications.
==== Successful Results Across Diverse Test Cases ====
UPINNs were validated on three systems:
* Lotka-Volterra predator-prey model – UPINNs recovered interaction terms accurately, outperforming UDEs, especially under sparse and noisy conditions.
* Viscous Burgers' equation – The method reconstructed the solution and discovered a nonlinear convection term using just two noisy time snapshots.
* Cell apoptosis model – Despite biological complexity and low observability, UPINNs identified nonlinear interaction terms with high precision.
==== High Robustness to Data Limitations ====
UPINNs perform well with sparse and noisy data by using a large number of synthetic collocation points—artificial domain points where known physics must hold. This stabilizes training without requiring extra experimental data.
==== Interpretability Through Symbolic Regression ====
After approximating unknown terms with neural networks, symbolic regression recovers interpretable expressions. These symbolic outputs are more accurate and complete than those from UDEs, especially under noisy or incomplete data scenarios.
==== Computational Considerations ====
While UPINNs require more computational effort due to the use of multiple neural networks and collocation points, this cost is offset by reduced data needs and increased interpretability—benefits not offered by purely black-box approaches.
==== Limitations and Future Directions ====
UPINNs assume some prior knowledge about the inputs influencing the unknown terms, which may not always be available. They also share PINNs' limitations with stiff differential equations. Addressing these issues is a suggested direction for future research.
=== Constructive critiques ===
The paper presents a novel approach—Universal Physics-Informed Neural Networks (UPINNs)—for discovering unknown terms in differential equations using sparse and noisy data. This is a timely and relevant contribution, particularly for fields like biology and physics where data collection is often expensive or limited. By integrating the interpretability of symbolic regression with the flexibility of neural differential modeling, the authors bridge a crucial gap between traditional PINNs and Universal Differential Equations (UDEs).
One of the method's key strengths is its ability to model hidden components of both the differential operator and the boundary conditions through separate neural networks. This design allows UPINNs to adapt to partially known systems and still recover interpretable models, especially when combined with tools like AI Feynman. The experiments across ODEs and PDEs are well-chosen and demonstrate UPINNs' superior robustness compared to UDEs, particularly under conditions of data sparsity or noise.
However, the method assumes that the user can pre-select relevant inputs for the hidden components, such as derivatives or functional terms. This reliance on prior knowledge could limit usability in less structured or more exploratory applications. Including a brief ablation or sensitivity analysis would help clarify this issue. The comparison to baselines could also be extended to include classical methods like SINDy for a broader perspective.
Overall, the paper is well-structured and clearly written, though a schematic of the full model pipeline would improve accessibility. The method is computationally intensive but effective, trading data demands for processing power—a reasonable compromise in many scientific settings. With minor improvements in clarity and broader benchmarking, this work offers a powerful and interpretable tool for discovering governing equations in real-world dynamical systems.
</div>
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;">
== Group 2 Presentation: EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty ==
=== Presented by: ===
Kareena Bhalla and Chelsea Huffman
=== Paper Citation ===
Li, Y., Wei, F., Zhang, C., Zhang, H. (2024). EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty. ''arXiv.'' doi.org/10.48550/arXiv.2401.15077
=== Background ===
In this paper, we are looking at Large Language Models (LLMs). ''Autoregressive'' decoding in LLMs refers to generating text one token at a time, with the model basing its predictions on the tokens that came before. This is inefficient and costly.
'''Speculative Sampling'''
Speculative sampling is a technique meant to reduce the computational cost and runtime of autoregressive decoding. The process consists of two main parts:
* '''Draft stage:''' A small, fast model suggests some tokens.
* '''Verification stage:''' In parallel, the large main LLM verifies these tokens and selects the best ones.
EAGLE also has both a draft stage and a verification stage.
How do we choose a draft model that functions like the big LLM, but faster? One approach is to use a reduced version of the main LLM, but this doesn't work when no small version is available; moreover, using a model with fewer parameters often comes with high overhead and reduced accuracy.
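To make the two stages concrete, here is a simplified, framework-agnostic sketch of a single draft-then-verify round using the standard speculative-sampling accept/reject rule (the function and variable names are assumptions; EAGLE's own drafting differs, as described below):

<pre>
import numpy as np

def speculative_round(draft_probs, target_probs, draft_tokens, rng=None):
    """One draft-then-verify round of speculative sampling.

    draft_probs[i]  : draft model's distribution over the vocabulary at drafted position i
    target_probs[i] : target model's distribution at the same position (from one parallel pass)
    draft_tokens[i] : token id proposed by the draft model at position i
    Returns the accepted prefix, resampling a corrected token at the first rejection.
    """
    rng = rng or np.random.default_rng()
    accepted = []
    for i, tok in enumerate(draft_tokens):
        p, q = target_probs[i][tok], draft_probs[i][tok]
        if rng.random() < min(1.0, p / max(q, 1e-12)):   # accept with probability min(1, p/q)
            accepted.append(tok)
        else:
            residual = np.maximum(target_probs[i] - draft_probs[i], 0.0)
            residual /= residual.sum()                   # renormalize the leftover probability mass
            accepted.append(int(rng.choice(len(residual), p=residual)))
            break                                        # stop at the first rejection
    return accepted
</pre>

This accept/reject rule is what keeps the final output distribution identical to ordinary autoregressive decoding from the target model.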
=== Technical Contributions ===
'''Extrapolation Algorithm for Greater Language-model Efficiency (EAGLE)'''
Before making a prediction, the model looks ahead by one token/word in the sequence.
One advantage of the EAGLE method is that only a single decoder layer needs to be trained rather than an entire draft model. This makes the training process extremely efficient in terms of runtime and computational cost.
'''Feature-Level Autoregression'''
Unlike traditional speculative sampling methods that operate at the token level, EAGLE performs autoregression at the feature level, specifically utilizing the second-to-top-layer features of the LLM. This approach simplifies the drafting process and enhances efficiency.
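A rough sketch of what such a feature-level draft head could look like is shown below; the class name, fusion layer, and sizes are assumptions for illustration rather than the paper's implementation (causal masking and the tree structure are omitted for brevity).

<pre>
import torch
import torch.nn as nn

class DraftHead(nn.Module):
    """Illustrative single-layer draft head in the spirit of EAGLE (names and sizes are assumptions).

    It predicts the next second-to-top-layer feature from the current feature plus the embedding
    of the token that was actually sampled one step ahead, then reuses the target LLM's frozen
    LM head to turn predicted features into draft token logits.
    """
    def __init__(self, d_model, embed, lm_head):
        super().__init__()
        self.embed = embed          # frozen embedding matrix taken from the target LLM
        self.lm_head = lm_head      # frozen output projection taken from the target LLM
        self.fuse = nn.Linear(2 * d_model, d_model)
        self.decoder = nn.TransformerDecoderLayer(d_model, nhead=8, batch_first=True)

    def forward(self, prev_features, next_tokens):
        tok = self.embed(next_tokens)                           # [B, T, d_model]
        x = self.fuse(torch.cat([prev_features, tok], dim=-1))  # fuse feature and token information
        pred_features = self.decoder(x, memory=x)               # lightweight autoregression head
        logits = self.lm_head(pred_features)                    # draft distribution over next tokens
        return pred_features, logits
</pre>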
=== Experimental Results ===
* EAGLE is much faster than ordinary autoregressive decoding.
* EAGLE can be applied to various LLMs without reducing the model's generation quality. Furthermore, it is also compatible with other generation acceleration methods; for example, the paper goes through a quick case study of combining EAGLE with an optimization technique called gpt-fast.
* EAGLE can generate approximately 3.2–4.5 tokens per pass. Regular, vanilla decoding only generates ''one'' token per forward pass, so this is a notable speed-up in token generation.
* EAGLE provides the greatest computational speed-up in code generation tasks. This is likely because code contains fixed templates (e.g., repetitive structures, strict and specific rules, etc.), thereby making it easier to generate plausible and accurate drafts.
=== Summaries of key points ===
EAGLE (Extrapolation Algorithm for Greater Language-model Efficiency) is a framework designed to accelerate the text generation of large language models (LLMs) without sacrificing the original distribution of generated tokens. Traditionally, LLMs rely on autoregressive decoding, generating text one token at a time and incurring heavy computational costs. Speculative sampling, introduced in earlier works, seeks to reduce these costs by splitting text generation into a faster "drafting" step and a single-step "verification" step using the main LLM, thereby validating multiple drafted tokens in parallel. However, existing methods often struggle to balance high draft accuracy with low overhead, limiting their speed improvements.
EAGLE addresses these issues by predicting hidden-layer features rather than tokens. Specifically, it focuses on the second-to-top-layer feature vector of the original LLM. Because these features exhibit more consistent structure than raw language tokens, they can be modeled with fewer errors and simpler modules. Additionally, EAGLE includes a small Autoregression Head that factors in one-step-ahead token information—this is critical for mitigating the inherent uncertainty in feature prediction. When generating text, the main LLM determines which token is sampled at a given step, but from the feature-only perspective, it is impossible to know the token that will appear. By feeding the token as well as the past feature states into a lightweight single-layer decoder, EAGLE's draft model can produce a highly accurate prediction of the next feature. After obtaining the predicted features, EAGLE applies the LLM's original output layer (the LM head) to sample candidate tokens in parallel.
Once a tree of drafted tokens is produced, the method performs a single verification pass in the main LLM. In one forward call, EAGLE obtains the LLM's probabilities for all tokens in this tree and accepts or rejects each token based on a mathematically proven acceptance criterion. Crucially, this ensures the distribution of tokens matches what the LLM would have generated if it performed naive, step-by-step autoregressive decoding.
In evaluations across code (HumanEval), mathematical reasoning (GSM8K), instruction following (Alpaca), and multi-turn dialogue (MT-bench) datasets, EAGLE yields 2.7–3.5× speedups for models like LLaMA2-Chat 70B and 2.8–3.3× for smaller Vicuna variants on single-GPU setups. Even for more modest setups, or MoE-based models like Mixtral 8×7B Instruct, EAGLE still demonstrates strong acceleration. The approach does not require finetuning the main LLM itself; only a small draft network is trained on a fixed dataset (e.g., ShareGPT). Because the overhead of this training is low—typically 1–2 days even for 70B-parameter models—EAGLE proves highly practical.
Overall, it offers a simple but powerful way to generate tokens in parallel while retaining exact quality and distribution guarantees, making it a compelling choice for real-world systems seeking reduced latency in LLM services.
=== Related Works ===
There is now a new version of EAGLE, EAGLE-2. EAGLE-2 improves on this version by introducing a dynamically adjustable draft tree: it adjusts the draft tree based on context and position, and it is built on top of speculative sampling. In testing, EAGLE-2 can be 20%–40% faster than EAGLE-1. There is also EAGLE-3, "Scaling up Inference Acceleration of Large Language Models via Training-Time Test", which can be found here: https://arxiv.org/html/2503.01840v1. EAGLE-3 advances the acceleration of LLM inference by shifting from feature prediction to direct token prediction and employing multi-layer feature fusion through a technique called training-time test. These enhancements enable the draft model to fully benefit from scaled-up training data, resulting in notable speedup ratios.
</div>
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;">
== Group 2 Presentation: EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty ==
=== Presented by: ===
Kareena Bhalla and Chelsea Huffman
=== Paper Citation ===
Li, Y., Wei, F., Zhang, C., Zhang, H. (2024). EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty. ''arXiv.'' doi.org/10.48550/arXiv.2401.15077
=== Summaries of Key Points ===
==== Introduction of a New Method ====
The paper proposes EAGLE (Extrapolation Algorithm for Greater Language-model Efficiency), a speculative sampling method designed to significantly accelerate inference in large language models (LLMs). Unlike prior methods that predict future tokens, EAGLE performs autoregression on the second-to-top-layer feature space of the LLM. This level of abstraction is more predictable and regular, simplifying the drafting process and improving speed.
==== Addressing Uncertainty in Feature Prediction ====
A central innovation of EAGLE is the mitigation of uncertainty in predicting features by incorporating token sequences shifted forward by one time step. This adjustment helps the draft model resolve ambiguities that arise due to the randomness of token sampling, leading to more stable and accurate feature predictions.
==== Efficient Architecture Without Modifying the Base LLM ====
EAGLE achieves acceleration by adding a lightweight autoregression head and reusing the embedding and LM head components from the target LLM. This ensures that the original model remains untouched and the output distribution stays consistent. The method employs both regression and classification objectives to train the draft model using standard datasets like ShareGPT, without generating new data from the target LLM.
==== Tree-Structured Draft Generation ====
To further enhance speed, EAGLE introduces tree attention to build a tree-structured token draft, allowing it to generate multiple tokens per pass. This increases average acceptance length without needing extra forward passes, resulting in higher throughput.
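As a rough illustration of the idea (not the paper's implementation), a draft tree can be verified in one forward pass by giving each node an attention mask that covers only itself and its ancestors; the helper below assumes the tree is encoded as a parent-index array.

<pre>
import torch

def tree_attention_mask(parents):
    """Boolean attention mask for a token tree.

    parents[i] is the index of node i's parent (-1 for the root).
    mask[i, j] is True iff node j is node i itself or one of its ancestors,
    so each drafted token only attends to its own branch.
    """
    n = len(parents)
    mask = torch.zeros(n, n, dtype=torch.bool)
    for i in range(n):
        j = i
        while j != -1:
            mask[i, j] = True
            j = parents[j]
    return mask

# Example: root -> {A, B}, A -> {C}
print(tree_attention_mask([-1, 0, 0, 1]))
</pre>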
==== Strong Performance Across Tasks and Models ====
EAGLE demonstrates superior speedups—ranging from 2.1x to 3.8x—on tasks such as dialogue, code generation, and math reasoning using Vicuna, LLaMA2-Chat, and Mixtral models. It consistently outperforms other speculative sampling methods like Lookahead and Medusa, while maintaining output accuracy.
==== Generalizability and Practical Deployment ====
The method is compatible with a wide range of LLMs, does not require fine-tuning of the base model, and can be integrated with other optimization techniques like quantization and compilation. Training is low-cost and can be completed in one to two days. EAGLE proves effective even with a fixed training dataset and demonstrates robustness to noise in feature representations.
==== Robustness and Efficiency in Deployment ====
EAGLE performs well even when trained on fixed datasets rather than samples from the target LLM. Its robustness to feature noise and strong results under memory constraints make it well-suited for production environments. Additionally, the method scales efficiently across batch sizes, with throughput nearly doubling when paired with tools like gpt-fast.
==== Conclusion ====
EAGLE is a general, efficient, and robust solution for speculative sampling. By rethinking how features are predicted and leveraging minimal architectural changes, it delivers high-speed inference without compromising model integrity. It represents a practical path forward for optimizing LLM deployment in latency-sensitive applications.
=== Constructive Critique and Review ===
The paper presents a thoughtful and technically sophisticated approach to speculative sampling for large language models. EAGLE introduces a notable shift in how draft tokens are generated by operating at the feature level rather than the token level, thereby reducing the unpredictability of autoregressive outputs and significantly improving decoding speed. The architecture is lightweight, efficient, and cleverly designed to preserve compatibility with existing LLMs without requiring retraining of the base model. These strengths make EAGLE a compelling option for deployment in real-world inference systems.
However, while the idea of performing autoregression in the second-to-top-layer feature space is well motivated, the paper does not fully explore its limitations. For example, although EAGLE achieves strong performance on conversational and coding tasks, it remains unclear how it would fare on more structured or domain-specific generation tasks, such as formal theorem proving or medical text generation, where feature representations might exhibit greater variability. Additional benchmarks from such domains would help assess the generalizability of the method.
The assumption that a fixed dataset like ShareGPT is sufficient for training the draft model raises questions about adaptability. While the results are promising, the training data may introduce bias, and the method's robustness under significant domain shifts is not evaluated. Furthermore, although the tree-structured decoding strategy provides efficiency gains, its implementation complexity and potential hardware bottlenecks during real-time inference are not discussed in detail.
Lastly, while the paper claims EAGLE is broadly compatible with quantization and compilation tools, these claims would benefit from empirical validation. Including direct comparisons on hardware resource consumption, memory usage, and inference latency under constrained conditions would provide a more complete picture of practical deployment trade-offs.
Overall, EAGLE is an innovative and valuable contribution to accelerating LLM inference, though further evaluation across diverse conditions and more transparency in deployment challenges would enhance its impact.
</div>
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;">
== Group 2 Presentation: EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty ==
=== Presented by: ===
Kareena Bhalla and Chelsea Huffman
=== Paper Citation ===
Li, Y., Wei, F., Zhang, C., Zhang, H. (2024). EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty. ''arXiv.'' doi.org/10.48550/arXiv.2401.15077
=== Summaries of key points ===
Speculative sampling speeds up language model inference by having a fast draft model guess multiple tokens, which are then verified by a slower, accurate model. While effective, this breaks the model's training assumption of strictly sequential input, introducing feature uncertainty—the hidden states are based on possibly incorrect guesses. The paper proposes EAGLE, an energy-based method that learns to assess and modulate this uncertainty at the feature level. Instead of discarding uncertain tokens, EAGLE uses a learned energy score to decide how much to trust them during decoding. This leads to better performance, especially on tasks requiring reasoning or longer context, and is more robust than previous speculative decoding approaches.
Problem Motivation: Speculative decoding (used to speed up LLM inference by parallelizing token generation) is great for efficiency but introduces a mismatch between training and inference: training assumes sequential decoding, but speculative sampling adds a "guess-and-check" step that breaks that assumption.
Key Idea: EAGLE (which stands for Energy-based Adaptive Guidance with Latent Evidence) proposes a new method to handle uncertainty that arises in speculative decoding. It adjusts the model's internal feature representations to reflect this uncertainty, rather than just masking or ignoring it.
How It Works: Instead of relying on the regular transformer's last hidden state, EAGLE builds an energy-based auxiliary model that learns to estimate whether a token guess is valid, using both the main model's predictions and the speculative draft. This energy score helps modulate the influence of uncertain features during decoding.
Results: EAGLE shows better performance on downstream tasks compared to vanilla speculative decoding, especially on tasks that require reasoning or handling uncertainty, e.g., question answering or coding benchmarks.
=== Explanations to aid understanding ===
Speculative decoding in simpler terms: Imagine trying to write the next word in a sentence, but instead of just writing one word and waiting, you guess a bunch in parallel and then double-check them. This saves time, but makes it harder for the model to know what it should trust. EAGLE essentially adds a smart layer that acts like a "confidence gauge" for the guesses, using a learned energy function to decide how much to trust each speculative token and its underlying features.
</div>
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;">
== Group 2 Presentation: EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty ==
=== Presented by: ===
Kareena Bhalla and Chelsea Huffman
=== Paper Citation ===
Li, Y., Wei, F., Zhang, C., Zhang, H. (2024). EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty. ''arXiv.'' doi.org/10.48550/arXiv.2401.15077
=== Background ===
Large Language Models (LLMs) like LLaMA and Vicuna are powerful but notoriously slow during inference, especially because they generate one token at a time using autoregressive decoding. This sequential nature makes real-time applications difficult and expensive. Speculative sampling has emerged as a solution: it uses a smaller model (a "draft model") to propose multiple tokens ahead of time, which are then verified in parallel by the original, larger model. This can lead to big speedups—but only if the drafts are accurate. The problem is, for many models (especially smaller ones like 7B), finding a good draft model is hard or inefficient, and prediction at the token level is noisy.
The paper introduces EAGLE, a speculative sampling method that takes a different approach. Instead of generating tokens directly, it works at the feature level (i.e., the hidden state just before the final output layer). It also addresses a key challenge: uncertainty in the feature sequence, caused by the randomness in token sampling. To fix this, EAGLE feeds the token sampled at the next time step (i.e., a "shifted" token sequence) into the draft model, giving it a clearer signal of what to predict next.
This method is designed to:
* Be fast — achieving 2x to 3.7x speedups over vanilla decoding.
* Be accurate — preserving the original LLM's output distribution.
* Be plug-and-play — requiring no fine-tuning of the LLM and using a lightweight add-on model.
=== Main Idea ===
EAGLE consists of two main parts:
a. Drafting Phase
Instead of predicting the next token, EAGLE predicts the next feature vector (from the LLM's penultimate layer).
Then, the actual token is generated using the original LLM's output layer.
The key idea: use both the current features and the next token to reduce ambiguity in what feature to predict.
b. Verification Phase
A standard speculative sampling step: the full model (LLM) runs a single forward pass to verify the draft.
If accepted, the tokens are kept. If rejected, the process restarts from the failed point.
EAGLE supports tree-structured drafts, where multiple possible sequences are explored in parallel, boosting acceptance rates and reducing the number of passes.
=== Short Summary of the Paper ===
The paper introduces EAGLE, a novel speculative sampling method to accelerate Large Language Model inference. EAGLE performs autoregression at the more structured ''feature level'' instead of the token level and addresses sampling-induced uncertainty by incorporating shifted tokens. It achieves significant speedups (2.7x–3.5x latency improvement) while maintaining original output distribution accuracy, outperforming methods like Medusa and Lookahead. It is computationally efficient and broadly applicable across various tasks and model sizes.
=== Experiments & Results ===
EAGLE is tested on Vicuna and LLaMA2-Chat models (7B, 13B, 33B, and 70B), plus Mixtral 8x7B, across tasks like: Dialogue (MT-bench), Code generation (HumanEval), Math problems (GSM8K), Instruction following (Alpaca).
Key numbers:
* For LLaMA2-Chat 70B: speedup of 2.7x to 3.5x
* For Vicuna 13B: up to 3.76x on code generation
* Compared to Lookahead and Medusa, EAGLE is consistently faster by 1.5x–2.1x
EAGLE also works well with gpt-fast (a quantization and compilation tool), achieving up to 160 tokens/sec on a single RTX 3090 — a strong result for a 7B model.
Training is efficient: even for 70B models, the draft module (just under 1B parameters) can be trained in 1–2 days on 4×A100 GPUs.
It is a very useful approach because:
1. No need to fine-tune the full LLM – only the draft model is trained.
2. Preserves output distribution – unlike some other fast decoding methods, EAGLE guarantees the same output distribution as the original model.
3. Compatible with other speedup tools – works in combination with quantization or compilation.
=== Limitations and future improvements ===
'''Feature Prediction Constraints:'''
EAGLE operates by performing autoregression at the feature level, specifically targeting the second-to-top-layer features. This approach introduces inherent uncertainty in feature prediction, which can limit the model's performance gains. The accuracy of these feature predictions is crucial, as any deviation can impact the overall efficiency and reliability of the speculative sampling process.
'''Dependency on Advanced Token Sequences:'''
To mitigate feature prediction uncertainty, EAGLE incorporates a token sequence advanced by one time step. While this strategy effectively resolves some uncertainties, it adds complexity to the model's architecture and may introduce additional computational overhead during the drafting phase.
'''Scalability Concerns:'''
Although EAGLE achieves notable speedups (e.g., a latency speedup ratio of 2.7x–3.5x for LLaMA2-Chat 70B), its performance gains may vary across different model sizes and architectures. The framework's efficiency is influenced by factors such as the acceptance rate of drafted tokens and the computational cost associated with the drafting process.
Future Improvements:
* Make Feature Prediction More Accurate: Improving how EAGLE guesses the internal features could reduce errors and make the whole process more reliable and efficient.
* Smarter Drafting Methods: They could try new ways of guessing future tokens that are simpler, faster, or more accurate, without needing the lookahead trick.
* Make It Work for All Kinds of Models: Future work could focus on making EAGLE more flexible so it works well across different model sizes and architectures, not just the large ones.
</div>
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;">
== Group 2 Presentation: EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty ==
=== Presented by: ===
Kareena Bhalla and Chelsea Huffman
=== Paper Citation ===
Li, Y., Wei, F., Zhang, C., Zhang, H. (2024). EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty. ''arXiv.'' doi.org/10.48550/arXiv.2401.15077
=== Summaries ===
EAGLE (Extrapolation Algorithm for Greater Language-model Efficiency) is an algorithm that enhances speculative sampling while increasing the accuracy of the output compared with existing works. To reduce costs, EAGLE predicts features (second-to-top-layer) instead of tokens, since autoregression performs better at the feature level, and it handles the uncertainties in the sampling process by using the token sequence from one time step ahead together with the predicted feature as the input for the next step. EAGLE can be easily applied to any autoregressive LLM and preserves accuracy because it does not change the original target LLM.
=== Key Contributions ===
The EAGLE method makes improvements to the process based on the following two observations:
1. Autoregression is simpler at the feature level than at the token level.
2. The uncertainties from the sampling process negatively affect the performance.
In the drafting phase, EAGLE handles the uncertainties by using the predicted feature sequence together with the token sequence that is one time step in advance as input to predict the next feature and sample the next token. In this phase, there are three modules: the embedding layer, which converts tokens and features to the desired shapes and structures; the LM head, which samples the next token; and the autoregression head, which predicts the next feature.
To train the draft model, the combined loss function <math>L=L_{reg}+w_{cls}L_{cls}</math> is minimized, where <math>L_{reg}</math> is the Smooth L1 loss, <math>L_{cls}</math> is the classification loss, and <math>w_{cls}</math> is set to 0.1. As EAGLE is insensitive to training data, a fixed dataset can be used with some noise added. The calculations are as follows:
<math>L_{reg}=\text{Smooth L1}\left(f_{i+1}, \text{Draft Model}(T_{2:i+1},F_{1:i})\right)</math>

<math>p_{i+2}=\text{Softmax}\left(\text{LM Head}(f_{i+1})\right)</math>

<math>\hat p_{i+2}=\text{Softmax}\left(\text{LM Head}(\hat f_{i+1})\right)</math>

<math>L_{cls}=\text{Cross Entropy}\left(p_{i+2},\hat p_{i+2}\right)</math>
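A minimal PyTorch-style sketch of this combined objective is given below (tensor names are assumptions; the actual training pipeline in the paper may differ):

<pre>
import torch.nn.functional as F

def draft_training_loss(pred_feature, true_feature, lm_head, w_cls=0.1):
    """Combined draft-model loss L = L_reg + w_cls * L_cls (sketch).

    pred_feature: feature predicted by the draft model for the next position
    true_feature: the target LLM's actual second-to-top-layer feature
    lm_head:      the target LLM's frozen output projection
    """
    # Regression loss on the feature itself (Smooth L1)
    l_reg = F.smooth_l1_loss(pred_feature, true_feature)

    # Classification loss: match the token distributions induced by the two features
    p_true = F.softmax(lm_head(true_feature), dim=-1)
    log_p_pred = F.log_softmax(lm_head(pred_feature), dim=-1)
    l_cls = -(p_true * log_p_pred).sum(dim=-1).mean()   # cross-entropy with soft targets

    return l_reg + w_cls * l_cls
</pre>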
EAGLE guarantees that the output distribution matches that of the original target LLM because it does not modify the original LLM; this holds for both greedy and non-greedy selection. Greedy selection picks the tokens with the highest probabilities, while non-greedy selection samples tokens.
EAGLE uses a tree structure as the draft model. The acceptance rate is not calculated as a metric for evaluating the model because, for each node, multiple tokens are generated and only one is accepted.
=== Constructive Critiques or Reviews ===
EAGLE uses a tree structure as the draft model, but the tree is not constructed based on context. The tree might be unbalanced and may negatively affect performance when there are more batches or input prompts.
=== Related works ===
In the draft stage, other related works have utilized different methods. Speculative sampling and Lookahead use tokens to predict tokens. Medusa uses features to independently predict tokens.
In the verifying stage, other related works have modified acceptance probabilities (DistillSpec (Zhou et al., 2023)) or the acceptance method (BiLD (Kim et al., 2023) and Medusa (Cai et al., 2023)).
Other related methods that reduce the latency per forward pass include distillation, quantization and pruning.
</div>
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;">
== Group 2 Presentation: EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty ==
=== Presented by: ===
Kareena Bhalla and Chelsea Huffman
=== Paper Citation ===
Li, Y., Wei, F., Zhang, C., Zhang, H. (2024). EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty. ''arXiv.'' doi.org/10.48550/arXiv.2401.15077
=== Summaries ===
EAGLE (Energy-based Adaptive Guidance with Latent Evidence) is a framework designed to improve speculative decoding for LLMs by addressing the feature uncertainty introduced when a draft model generates multiple tokens in parallel. Traditional speculative decoding disrupts the sequential assumption LLMs are trained under, since some input tokens are guesses rather than true outputs. Instead of discarding uncertain tokens or masking them out, EAGLE introduces a learned energy-based score that determines the reliability of each speculative token, thereby allowing more informed decoding decisions.
EAGLE adapts the model's internal representations to reflect the uncertainty in the speculative process, enabling better performance on tasks requiring multi-step reasoning or longer contextual dependencies.
=== Key Contributions ===
EAGLE builds on the following insights:
* Speculative decoding introduces a mismatch between training and inference, especially due to incorrect token guesses from the draft model.
* Uncertainty should be handled at the feature level, rather than ignored or masked during decoding.
EAGLE modifies speculative decoding in the following way:
* Introduces an energy-based auxiliary model that estimates the likelihood of each speculative token being correct.
* This energy score is learned jointly and used during decoding to modulate how much influence uncertain tokens and their hidden states have on the model's output.
* Unlike traditional methods that apply rigid accept/reject rules, EAGLE adjusts the model's behavior dynamically, offering a soft, learned mechanism to guide decoding.
* The method does not discard or restart upon uncertain drafts but integrates uncertainty into the decoding process itself, improving fluency and consistency.
=== Constructive Critiques or Reviews ===
* While EAGLE introduces a novel energy-based uncertainty estimator, the energy model itself adds complexity and must be co-trained carefully to avoid overfitting or poor generalization.
* The paper focuses on performance improvement, but the computational overhead introduced by the energy estimation is not deeply discussed.
* Although EAGLE improves over traditional speculative sampling, its gains may vary depending on the task, especially in domains where draft model guesses are frequently wrong.
=== | === Related works === | ||
Other speculative decoding techniques have addressed draft reliability in different ways:
Speculative Sampling and Lookahead: use token-level drafting with parallel verification.
Medusa: like EAGLE, works from the target model's internal features, but predicts several future tokens independently rather than autoregressively and does not model feature-level uncertainty.
DistillSpec (Zhou et al., 2023): distills the draft model toward the target model to raise the token acceptance rate.
BiLD (Kim et al., 2023): pairs a small and a large model with fallback and rollback policies, relaxing strict acceptance to gain speed.
Other strategies for speeding up LLM inference include distillation, quantization, and model pruning, but unlike lossless speculative decoding they generally change the model's output distribution.
</div>
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 2 Presentation: EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty == | |||
=== Presented by: ===
Kareena Bhalla and Chelsea Huffman | |||
=== Paper Citation ===
Li, Y., Wei, F., Zhang, C., Zhang, H. (2024). EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty. ''arXiv.'' doi.org/10.48550/arXiv.2401.15077 | |||
=== Introduction ===
Podina et al. (2023) introduce Universal Physics-Informed Neural Networks (UPINNs), a novel framework that extends Physics-Informed Neural Networks (PINNs) and Universal Differential Equations (UDEs) to learn unknown components of differential equations from sparse and noisy data. The paper makes significant contributions to the fields of machine learning, computational physics, and system identification by addressing key limitations in existing methods. | |||
=== Advancing Symbolic Discovery of Differential Equations === | |||
A major contribution of this work is its ability to symbolically discover unknown differential operators from data, even when measurements are limited. Unlike traditional PINNs, which require a fully known differential equation structure, UPINNs incorporate an additional neural network to approximate hidden terms. This neural network can then be converted into an explicit mathematical formula using symbolic regression techniques like AI Feynman. This innovation bridges the gap between black-box deep learning models and interpretable mathematical formulations. | |||
=== Handling Sparse and Noisy Data in Scientific Machine Learning ===
One of the main challenges in scientific machine learning is dealing with limited and noisy datasets, which often arise in experimental and real-world scenarios. UPINNs demonstrate strong performance even when provided with very few and noisy measurements by leveraging prior knowledge in the form of known differential operators and PINN-based regularization. This makes the method particularly valuable for applications where data collection is expensive or difficult, such as biological systems modeling and geophysical simulations. | |||
=== Improving Robustness Over Existing Methods ===
The paper highlights key shortcomings of existing approaches: | |||
Universal Differential Equations (UDEs) require large datasets and are sensitive to noise, often failing to recover true mechanistic models. | |||
Physics-Informed Neural Networks (PINNs) assume a known equation structure, making them unsuitable for discovering unknown dynamics. | |||
By integrating PINN loss functions into the UDE framework, UPINNs outperform UDEs in noisy conditions while retaining the flexibility to discover unknown terms—something that standard PINNs cannot do. This hybrid approach significantly improves robustness in practical settings. | |||
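As a rough illustration of how such a hybrid loss can be assembled, here is a minimal PyTorch sketch for an ODE of the form du/dt = f_known(u) + f_unknown(u), with one network approximating the solution and one approximating the unknown term. The network sizes, the placeholder known term, and the equal weighting of the two loss terms are assumptions for the example, not the paper's exact configuration.
<pre>
import torch
import torch.nn as nn

# u_net approximates the solution u(t); h_net approximates the unknown term f_unknown(u).
u_net = nn.Sequential(nn.Linear(1, 64), nn.Tanh(), nn.Linear(64, 1))
h_net = nn.Sequential(nn.Linear(1, 64), nn.Tanh(), nn.Linear(64, 1))

def f_known(u):
    return -0.5 * u  # placeholder for the known part of the dynamics

def upinn_loss(t_data, u_data, t_colloc):
    # Data mismatch: predictions should match the (possibly sparse, noisy) measurements.
    loss_data = ((u_net(t_data) - u_data) ** 2).mean()

    # Physics residual at collocation points: du/dt - f_known(u) - h_net(u) should vanish.
    t = t_colloc.clone().requires_grad_(True)
    u = u_net(t)
    du_dt = torch.autograd.grad(u, t, grad_outputs=torch.ones_like(u), create_graph=True)[0]
    loss_phys = ((du_dt - f_known(u) - h_net(u)) ** 2).mean()

    return loss_data + loss_phys
</pre>
After training, h_net can be sampled on a grid and handed to a symbolic regression tool to recover a closed-form expression for the unknown term.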
=== Applications in Complex Systems === | |||
The study demonstrates UPINN's effectiveness through diverse case studies, including: | |||
The Lotka-Volterra system (a classic predator-prey model), where UPINNs successfully recover hidden interaction terms even with sparse data. | |||
The viscous Burgers' equation, showcasing its applicability to partial differential equations (PDEs). | |||
A biological apoptosis model, illustrating its utility in real-world scientific problems, particularly in modeling complex biochemical interactions. | |||
=== Enabling Symbolic Interpretability in Neural Networks ===
Many deep learning models function as black boxes, making it difficult to extract interpretable insights. UPINNs, however, enable explicit mathematical discovery of hidden dynamics, making the learned models not only accurate but also scientifically meaningful. This ability to recover human-readable equations aligns with the broader goal of making AI-driven scientific discoveries more transparent and explainable. | |||
=== Final Thoughts ===
Podina et al. (2023) present a significant step forward in data-driven discovery of differential equations, particularly in low-data and high-noise environments. By blending the strengths of PINNs and UDEs while addressing their weaknesses, UPINNs offer a more robust, interpretable, and scalable approach to learning unknown physical laws. This work has the potential to impact various scientific fields, including physics, biology, and engineering, where understanding the underlying mathematical structure of complex systems is crucial. | |||
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;">
== Group 3 Presentation: Mamba: Linear-Time Sequence Modelling with Selective State Spaces == | |||
=== Presented by: ===
Liang Wu, Jingcheng Yu, Candace Ng | |||
=== Paper Citation ===
Gu, A., & Dao, T. (2023). Mamba: Linear-Time Sequence Modeling with Selective State Spaces. arXiv. https://arxiv.org/abs/2312.00752. | |||
<h3>Background</h3> | |||
<p>Transformers dominate modern sequence modeling, especially in NLP. They leverage self-attention for flexible token interactions but suffer from quadratic complexity in sequence length, making them inefficient for long sequences.</p> | |||
<p>Structured State Space Models (SSMs) offer a more efficient alternative, leveraging control theory for near-linear complexity. However, traditional SSMs struggle with discrete and information-dense inputs like text due to their time-invariant nature.</p> | |||
<h3>Main Idea</h3> | |||
<p>Mamba introduces Selective State Space Models, where key parameters (such as <math>B</math> and <math>C</math>) dynamically adjust based on input tokens, allowing the model to:</p> | |||
<ul> | |||
<li>Retain relevant information while filtering out noise.</li> | |||
<li>Process inputs in a context-aware manner.</li> | |||
<li>Adapt dynamically rather than using fixed transformations.</li> | |||
</ul> | |||
<p>Since this approach disrupts efficient convolution-based operations, the authors implement a hardware-friendly selective scan method, optimized for GPUs.</p> | |||
<p>The Mamba architecture diverges from traditional Transformers by merging sequence transformations with MLP layers into a streamlined, stackable block.</p> | |||
<h3>Experimental Results</h3>
<h4>1. Synthetic Tasks</h4> | |||
<p>Mamba successfully solves sequence modeling benchmarks such as selective copying and induction heads, demonstrating generalization to sequences up to a million tokens long.</p> | |||
<h4>2. Language Modeling</h4>
<p>Trained on The Pile and evaluated on LAMBADA, HellaSwag, and ARC, Mamba-1.4B:</p> | |||
<ul> | |||
<li>Outperforms Pythia models of comparable or larger size.</li> | |||
<li>Matches "Transformer++" models with fewer parameters.</li> | |||
<li>Achieves faster inference without relying on key-value caching.</li> | |||
</ul> | |||
<h4>3. Genomics (DNA Modeling)</h4> | |||
<p>Mamba scales efficiently on the HG38 genome dataset and excels at species classification (e.g., distinguishing human, chimp, and gorilla DNA).</p> | |||
<h4>4. Audio Modeling</h4> | |||
<p>Mamba surpasses S4-based models in waveform generation and beats GANs and diffusion models on speech datasets like SC09.</p> | |||
<h4>5. Efficiency</h4>
<p>The selective scan mechanism offers significant efficiency gains:</p> | |||
<ul> | |||
<li>Faster than FlashAttention-2 for long sequences.</li> | |||
<li>4–5x higher inference throughput than Transformers of similar size.</li> | |||
<li>Reduced memory usage by eliminating key-value caching.</li> | |||
</ul> | |||
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;">
== Group 3 Presentation: Mamba: Linear-Time Sequence Modelling with Selective State Spaces == | |||
=== Presented by: === | |||
Liang Wu, Jingcheng Yu, Candace Ng | |||
=== Paper Citation === | |||
Gu, A., & Dao, T. (2023). Mamba: Linear-Time Sequence Modeling with Selective State Spaces. arXiv. https://arxiv.org/abs/2312.00752. | |||
=== Background === | |||
=== Technical Contributions === | |||
- Introduce a selection mechanism for state space models that enables input-dependent state updates | |||
- Develop a hardware-aware recurrent scan algorithm for efficient computation | |||
- Propose Mamba, an attention-free, linear-time architecture that achieves state-of-the-art performance across multiple modalities | |||
=== Architecture === | |||
Mamba integrates selective SSMs into a single homogeneous block. Each block comprises: a linear projection, a selective SSM layer, and an MLP block. This architecture is attention-free and scales linearly in sequence length. | |||
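A heavily simplified PyTorch sketch of such a block is shown below: a normalization and linear projection, a toy selective scan whose <math>B</math>, <math>C</math> and step size depend on the input, and a gated output path with a residual connection. The dimensions, the gating choice, and the sequential (non-fused) scan loop are illustrative only; the actual Mamba implementation uses a fused hardware-aware kernel and a more elaborate block design.
<pre>
import torch
import torch.nn as nn

class SelectiveSSM(nn.Module):
    """Toy selective scan: B, C and the step size delta depend on the input.
    A didactic stand-in for the S6 layer, not the optimized kernel."""
    def __init__(self, d_model, d_state=16):
        super().__init__()
        self.A = nn.Parameter(-torch.rand(d_model, d_state))      # negative for stability
        self.proj_B = nn.Linear(d_model, d_state)
        self.proj_C = nn.Linear(d_model, d_state)
        self.proj_delta = nn.Linear(d_model, d_model)

    def forward(self, x):                       # x: (batch, length, d_model)
        b, L, d = x.shape
        delta = torch.nn.functional.softplus(self.proj_delta(x))   # (b, L, d)
        B, C = self.proj_B(x), self.proj_C(x)                      # (b, L, N)
        h = torch.zeros(b, d, self.A.shape[1], device=x.device)
        ys = []
        for t in range(L):                      # sequential scan, written out for clarity
            dA = torch.exp(delta[:, t].unsqueeze(-1) * self.A)     # (b, d, N)
            dB = delta[:, t].unsqueeze(-1) * B[:, t].unsqueeze(1)  # (b, d, N)
            h = dA * h + dB * x[:, t].unsqueeze(-1)
            ys.append((h * C[:, t].unsqueeze(1)).sum(-1))          # (b, d)
        return torch.stack(ys, dim=1)

class MambaBlock(nn.Module):
    """Simplified block: projection -> selective SSM -> gated output, with a residual."""
    def __init__(self, d_model):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.in_proj = nn.Linear(d_model, d_model)
        self.ssm = SelectiveSSM(d_model)
        self.gate = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x):
        z = self.in_proj(self.norm(x))
        y = self.ssm(z) * torch.nn.functional.silu(self.gate(self.norm(x)))
        return x + self.out_proj(y)
</pre>
Stacking such blocks with residual connections gives an attention-free model whose compute grows linearly with sequence length.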
=== Related Work === | |||
- Builds on structured SSMs (S4, S5) and architectures such as H3 and Hyena. | |||
- Demonstrates theoretical and empirical connections to classical RNN gating. | |||
<b> Future Directions:</b> | |||
- Scale to larger models and refine training recipes | |||
- Extend Mamba to multimodal tasks (e.g. video) | |||
- Explore additional downstream affordances (fine-tuning, in-context learning, quantization) | |||
=== Summaries of Key Points === | |||
==== 1. Motivation: Faster and More Flexible Sequence Modeling ==== | |||
Large-scale "foundation models" (FMs) largely rely on Transformers (with attention) to handle long-range context. However, attention has quadratic complexity in sequence length and requires storing the entire key–value cache. This leads to high memory usage and slow generation. Meanwhile, Structured State-Space Models (SSMs) offer linear-time processing through either convolution or recurrence, achieving excellent results in domains like audio and vision. Yet they typically assume constant, time-invariant parameters, which hinders their ability to handle discrete or "content-based" tasks such as language. | |||
==== 2. Proposed Approach: Selective State-Space Models (S6) ==== | |||
The paper introduces a selection mechanism to inject content-based (input-dependent) reasoning into SSMs. Specifically: | |||
- Parameterizing SSM transitions by inputs: Conventional SSMs have fixed transition matrices (A, B, C), but here, B and C—and crucially, the step-size parameter Δ—become input-dependent. This small twist allows the model to "select" which inputs matter and how strongly to update its hidden state, like gating in RNNs. | |||
- Efficient "Selective Scan" Implementation: Because parameters vary with the input, one cannot use global convolution for training (which was key to SSM efficiency). Instead, the authors design a hardware-aware parallel scan on GPUs, fusing operators so that large intermediate states reside only in on-chip memory. This approach retains O(L) time scaling in sequence length while dramatically cutting overhead. | |||
==== 3. Mamba Architecture ==== | |||
To showcase selective SSMs, they build Mamba, a simple "attention-free" architecture. Each layer is effectively an MLP with a parallel SSM path included. By training only these blocks end-to-end (without multi-head attention), Mamba delivers: | |||
- Linear-time training because it uses convolution-like or scan-based methods. | |||
- Constant-time inference per token (no caching entire past context), which is especially beneficial for extremely long sequences. | |||
Additionally, Mamba integrates gating mechanisms within its SSM path, adaptively controlling information flow to better capture complex dependencies. Its streamlined layer design allows hardware-aware GPU optimizations, enhancing throughput and scalability compared to attention-based models. | |||
==== 4. Experiments Across Multiple Domains ==== | |||
1. Synthetic Tests (Selective Copying, Induction Heads): Mamba (or more precisely, its selective mechanism) solves tasks where ordinary time-invariant SSMs fail. It can ignore irrelevant tokens and preserve essential ones. | |||
2. Language Modeling: On The Pile (multi-domain text), Mamba matches or exceeds standard Transformer perplexities when model size scales to ~1B parameters. It also outperforms other sub-quadratic methods (e.g., Hyena, RWKV). Notably, Mamba obtains 4–5× higher inference throughput than Transformers because it avoids large key–value caches.
3. Genomics: Mamba attains better perplexities than alternative long-sequence models (HyenaDNA) on a human genome corpus, especially at very long context (up to millions of tokens). | |||
4. Audio Waveforms: The model outperforms prior state-of-the-art (e.g., SaShiMi) in generating raw speech signals, achieving lower FID scores and better sample quality. | |||
=== | ==== 5. Significance and Conclusions ==== | ||
Mamba demonstrates that: | |||
1. Combining gating/selection with state-space recurrences, and <br>
2. A carefully engineered GPU scan can outperform Transformers in certain tasks and match them in others, all with linear-time complexity. This opens a path for foundation models in domains needing extremely long context—like genomics, raw audio, and (possibly) next-generation language modeling—where Transformers become unwieldy. The authors highlight that future directions include scaling Mamba to billions of parameters (7B+), exploring interpretability, and combining it with other approaches (e.g., partial attention) for even broader performance gains. | |||
</div> | </div> | ||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;">
== Group 3 Presentation: Mamba: Linear-Time Sequence Modeling with Selective State Spaces ==
=== Presented by: ===
Liang Wu, Jingcheng Yu, Candace Ng | |||
=== Paper Citation ===
Gu, A., & Dao, T. (2023). Mamba: Linear-Time Sequence Modeling with Selective State Spaces. arXiv. https://arxiv.org/abs/2312.00752.
=== Summaries of Key Points === | |||
==== Introduction ==== | |||
The paper presents a comprehensive review of Graph Neural Networks (GNNs), a class of deep learning models designed to perform inference on data described by graphs. The authors highlight the growing importance of GNNs due to the prevalence of graph-structured data in numerous real-world domains, including social networks, biological systems, and knowledge graphs. | |||
==== Taxonomy of Graph Neural Networks ==== | |||
The authors classify GNN models into several broad categories based on their architectural design and learning paradigms: | |||
• Recurrent Graph Neural Networks (RecGNNs) – The earliest GNNs that iteratively update node representations using recurrent architectures. | |||
• Convolutional Graph Neural Networks (ConvGNNs) – Inspired by CNNs, these models generalize convolution operations to graph domains. Sub-categories include: | |||
Spectral-based methods (e.g., ChebNet, GCN) | |||
Spatial-based methods (e.g., GraphSAGE, GAT) | |||
• Graph Autoencoders (GAEs) – Unsupervised models that encode nodes into embeddings for reconstruction or downstream tasks. | |||
• Spatial-Temporal GNNs (STGNNs) – Models designed to handle dynamic graphs, integrating both spatial and temporal information. | |||
==== Model Strengths and Design Elements ==== | |||
Permutation invariance and locality are core inductive biases that make GNNs powerful. The use of attention mechanisms in models like GAT improves flexibility by learning different importance weights for neighbors. Sampling strategies and pooling operations are key to scaling GNNs to large graphs and improving expressiveness. | |||
==== Applications ==== | |||
The review outlines a wide array of applications where GNNs have shown strong performance: | |||
• Node Classification – Predicting user attributes in social networks. | |||
• Link Prediction – Knowledge graph completion. | |||
• Graph Classification – Molecular property prediction in chemistry. | |||
• Recommendation Systems – Leveraging user-item interaction graphs.
• Traffic and Time-Series Forecasting – Using STGNNs for spatio-temporal modeling.
==== Challenges and Open Problems ==== | |||
The paper identifies several pressing challenges in the field: | |||
• Scalability – Efficient training on large graphs remains difficult. | |||
• Over-smoothing – Deep GNNs tend to produce indistinguishable node embeddings. | |||
• Dynamic Graphs – Many GNNs struggle with real-time updates and evolving structures. | |||
• Theoretical Understanding – There is limited theoretical analysis compared to other deep learning models. | |||
=== | === Constructive Critique and Review === | ||
This paper offers a thorough and timely survey of Graph Neural Networks (GNNs), providing readers with a structured understanding of the landscape, key architectures, and diverse applications. The taxonomy of models—ranging from recurrent and convolutional GNNs to autoencoders and spatio-temporal variants—is well-organized and helps demystify the evolution and variety of approaches in the field. The paper’s value is particularly notable for newcomers and practitioners who seek a foundational overview of GNN concepts and developments. | |||
One of the major strengths of the paper lies in its clear exposition of the differences between spectral and spatial methods, which are often a source of confusion. Additionally, by covering both theoretical concepts and practical use cases, the review bridges the gap between academic research and real-world implementation. The inclusion of application domains such as recommender systems, molecular biology, and traffic forecasting shows the breadth of GNN utility and makes the review relevant across disciplines. | |||
However, while comprehensive, the paper could be improved by deepening its critical analysis of the methods it surveys. For instance, while many models are described, their comparative advantages and trade-offs are not always fully explored. A clearer discussion on which architectures perform best under what circumstances—e.g., in sparse vs. dense graphs, or static vs. dynamic environments—would be valuable for practitioners making model selection decisions.
Moreover, although challenges like over-smoothing and scalability are mentioned, the discussion remains somewhat high-level. Providing more concrete examples of how recent works attempt to mitigate these issues would enhance the review's depth. Theoretical gaps in GNN research are also acknowledged but not elaborated upon in a way that guides future investigation.
Overall, the paper serves as an essential entry point into the field of GNNs. With added critical perspective and more technical comparison among models, it could serve not only as an introduction but also as a practical reference for advanced researchers. | |||
</div>
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 3 Presentation: Mamba: Linear-Time Sequence Modelling with Selective State Spaces == | |||
=== Presented by: === | |||
Liang Wu, Jingcheng Yu, Candace Ng | |||
=== Paper Citation === | |||
- | Gu, A., & Dao, T. (2023). Mamba: Linear-Time Sequence Modeling with Selective State Spaces. arXiv. https://arxiv.org/abs/2312.00752. | ||
=== Summaries of key points === | |||
Nowadays, transformers are extremely popular because they are great at finding relationships in data. However, they struggle with long sequences since their calculations grow too fast as the input gets longer. In comparison, State Space Models (SSMs) are more efficient for long sequences, but they have a limitation: they can't easily decide what information to keep or forget. Mamba solves this problem by improving a standard SSM model so that it can selectively adjust its parameters based on the input.
==== Standard SSM (S4): How It Works ==== | |||
SSMs process sequences differently from transformers. Instead of checking all relationships in the data, they store past information in a hidden state and update it step by step. | |||
The basic rule for an SSM is: | |||
<math> h_t = A h_{t-1} + B x_t </math>, <math> y_t = C h_t </math>
Where: | |||
- <math>h_t</math> is the hidden state (memory of past information).
- <math>x_t</math> is the current input.
- <math>A, B, C</math> are numbers that the model learns to control how information flows. | |||
- <math>y_t</math> is the output. | |||
This type of model can represent long-range dependencies, and previous models like S4, DSS, and S5 have explored this direction.
The problem? The issue is that A, B, and C stay fixed, meaning the model always updates its memory in the same way, no matter what input it sees. This makes it less adaptable, especially for complex tasks like language modeling. | |||
==== Selective SSM (S6): Mamba’s Improvement ==== | |||
Mamba introduces Selective State Space Models (S6), where the key values A, B, and C change dynamically based on the input. This allows the model to decide what to store and what to forget in real-time. | |||
Instead of using fixed values, Mamba updates them dynamically: | |||
<math> h_t = A_t h_{t-1} + B_t x_t </math>, <math> y_t = C_t h_t </math> | |||
Now, A, B, and C depend on the input <math>x_t</math>, making the model much more flexible and allowing it to adjust its memory as needed. | |||
Also, Mamba uses a hardware-aware scan algorithm, which makes its computations 20-45x faster than older methods and 4-5x faster than transformers. This means it can handle longer sequences without slowing down.
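The difference between the two recurrences can be seen in a few lines of NumPy. The sketch below uses a single input channel and an illustrative parameterization (softplus step size, linear projections of <math>x_t</math>); it is meant only to show where the input dependence enters, not to reproduce Mamba's exact formulation.
<pre>
import numpy as np

rng = np.random.default_rng(0)
L, N = 10, 4                      # sequence length, state size (one input/output channel)
x = rng.standard_normal(L)

# --- Fixed-parameter SSM (S4-style): A, B, C never change ---
A, B, C = 0.9 * np.eye(N), rng.standard_normal(N), rng.standard_normal(N)
h, y_s4 = np.zeros(N), []
for t in range(L):
    h = A @ h + B * x[t]          # h_t = A h_{t-1} + B x_t
    y_s4.append(C @ h)            # y_t = C h_t

# --- Selective SSM (S6-style): the parameters depend on x_t ---
W_B, W_C, w_d = rng.standard_normal(N), rng.standard_normal(N), rng.standard_normal()
h, y_s6 = np.zeros(N), []
for t in range(L):
    delta_t = np.log1p(np.exp(w_d * x[t]))     # softplus keeps the step size positive
    A_t = np.exp(-delta_t) * np.eye(N)         # larger delta -> forget more of the past
    B_t, C_t = W_B * x[t], W_C * x[t]          # input-dependent projections
    h = A_t @ h + delta_t * B_t * x[t]
    y_s6.append(C_t @ h)
</pre>
The only structural change is that the second loop recomputes its transition quantities at every step from the current input, which is exactly what allows selective remembering and forgetting.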
==== Experiment Results ==== | |||
Mamba has been tested in different fields: | |||
- Language modeling: Performs as well as transformers but is more efficient. | |||
- DNA sequencing: Works well on long sequences, beating many existing models. | |||
- Audio processing: Outperforms other SSM models for speech tasks. | |||
==== Limitations of Mamba ==== | |||
Mamba is promising, but it has a few challenges: | |||
1. Scaling up: It hasn’t been tested on extremely large models yet, so we don’t know if it can compete with transformers at that level. | |||
2. Trade-offs in memory selection: Choosing what to remember and forget works well for things like text but might not be as useful for continuous data like speech. | |||
3. Lack of a mature ecosystem: Transformers have been around longer, so they have more tools and techniques available. Mamba still has to catch up. | |||
Even with these issues, Mamba is an exciting step forward for sequence modeling.
</div> | </div> | ||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;">
== Group 3 Presentation: Mamba: Linear-Time Sequence Modelling with Selective State Spaces == | |||
=== Presented by: === | === Presented by: === | ||
Liang Wu, Jingcheng Yu, Candace Ng | |||
=== Paper Citation === | === Paper Citation === | ||
Gu, A., & Dao, T. (2023). Mamba: Linear-Time Sequence Modeling with Selective State Spaces. arXiv. https://arxiv.org/abs/2312.00752.
=== Background === | |||
Transformers are the backbone of most foundation models today, especially in language tasks. Their strength lies in the self-attention mechanism, which allows for flexible information routing across tokens. However, this comes with a computational cost: both time and memory scale quadratically with sequence length. This makes training and inference inefficient, especially for long sequences. | |||
There’s been growing interest in finding efficient alternatives. Structured state space models (SSMs) offer a different route. Inspired by control theory, they compute with linear or near-linear complexity and have worked well in domains like audio. But they’ve consistently struggled with tasks involving discrete and information-dense inputs, like text. One major issue is that existing SSMs apply the same operations at each time step. This time-invariant design makes them fast but limits their ability to reason based on content. | |||
=== Main Idea === | |||
Mamba introduces selective state space models, which allow the model to change how it processes each input based on the input itself. In earlier SSMs, parameters like B and C (which control how inputs are added to the state and how the state is turned into output) were fixed across time. In Mamba, these parameters vary with the input token. | |||
This makes the model capable of: | |||
1. Retaining relevant tokens while forgetting unimportant ones. | |||
2. Filtering out noise or filler content. | |||
3. Adapting its internal state in a context-aware manner. | |||
Of course, this also breaks the efficient convolution trick that earlier SSMs used, since convolutions require fixed kernels. To deal with this, the authors implement a custom recurrent scan that is hardware-friendly and memory-efficient, especially on GPUs. This scan computes the model step-by-step, but in a way that avoids the usual memory bottlenecks of recurrent models. | |||
Rather than using the traditional Transformer layout of attention followed by an MLP block, Mamba builds a simpler block: | |||
a. It merges the sequence transformation (via the selective SSM) with the MLP into a single unit. | |||
b. This block is stacked repeatedly, with residual connections and normalization in between. | |||
c. The resulting model is easier to scale and implement than Transformer-based designs. | |||
The paper also shows that the model works with real-valued parameters (as opposed to the complex-valued versions used in some previous SSMs), which improves compatibility with common deep learning hardware. | |||
=== Experiments & Results ===
1. Synthetic Tasks | |||
Mamba is tested on synthetic problems like selective copying and induction heads, which are designed to measure a model’s ability to selectively remember and use earlier parts of a sequence. Mamba solves both tasks and generalizes well to sequences much longer than it saw during training—up to a million tokens. | |||
2. Language Modeling | |||
The authors train Mamba on The Pile and evaluate it on standard zero-shot benchmarks (LAMBADA, HellaSwag, ARC, etc.). | |||
Key findings: | |||
Mamba-1.4B outperforms Pythia models of similar and larger size. | |||
It matches the performance of “Transformer++” variants while using fewer parameters. | |||
It runs significantly faster at inference time because it does not rely on key-value caching.
3. Genomics (DNA modeling) | |||
On the HG38 genome dataset, Mamba shows better scaling than baselines as both model size and context length increase. | |||
Even on a challenging species classification task with closely related species (human, chimp, gorilla, etc.), Mamba performs well with long-context inputs. | |||
4. Audio Modeling
Mamba improves on S4-based baselines in waveform modeling. On the SC09 speech generation dataset, it beats previous models, including GANs and diffusion models on several automated metrics. | |||
5. Efficiency | |||
The selective scan implementation is highly optimized: | |||
Faster than FlashAttention-2 for long sequences. | |||
4–5x faster inference throughput than Transformers of similar size. | |||
Uses less memory, since it doesn’t need to cache key/value pairs during generation. | |||
</div>
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 3 Presentation: Mamba: Linear-Time Sequence Modelling with Selective State Spaces == | |||
=== Presented by: === | |||
Liang Wu, Jingcheng Yu, Candace Ng | |||
=== | === Paper Citation === | ||
Gu, A., & Dao, T. (2023). Mamba: Linear-Time Sequence Modeling with Selective State Spaces. arXiv. https://arxiv.org/abs/2312.00752. | |||
=== Motivation and Problem Statement === | |||
(1) Attention bottleneck | |||
Transformers have demonstrated strong modeling capacity thanks to self-attention, but attention layers are known to scale quadratically with sequence length. This becomes problematic for very long sequences due to high compute and memory requirements. | |||
(2) Subquadratic approaches | |||
Alternative models (e.g. linear attention, recurrent cells, and structured state space models, or SSMs) achieve subquadratic time complexity. However, in practice, they often lag behind Transformers—especially for discrete, “information-dense” domains like language. | |||
(3) Key challenge | |||
Balancing the efficiency of subquadratic approaches with the “context-compression” power typical of full attention. In particular, standard linear time-invariant (LTI) SSMs struggle to handle tasks that require input-dependent selection or “content-based” reasoning. | |||
=== Contribution === | |||
(1) Selective Mechanism | |||
The paper introduces a selective variant of state space models—“Selective SSM” or S6—whose parameters can dynamically depend on the current input token. This makes the model selectively propagate or ignore information, overcoming the rigidity of time-invariant recurrences. | |||
(2) Hardware-Aware Recurrent Scan | |||
To handle the new time-varying SSM parameters, the authors propose a “scan” algorithm specialized for modern GPUs that leverages efficient memory management (fusing operations and reducing data movement). Despite the recurrent nature, it matches or exceeds the speed of FFT-based or other convolution-based methods for large sequence lengths. | |||
(3) Mamba Architecture | |||
Built on top of this selective SSM layer, “Mamba” is a purely recurrent neural network that omits attention altogether, yet achieves competitive or better performance than Transformers across various domains (language, audio, genomics) while scaling linearly in sequence length. | |||
=== | === Algorithm 1: Standard SSM (S4) === | ||
Structured State Space Models (S4) were initially designed to combine RNN-like recurrences with global convolutions: | |||
(1) Core idea | |||
SSMs are defined by continuous-time parameters. These are discretized (using, for example, a Zero-Order Hold) so that the SSM can operate on discrete sequences; the standard zero-order-hold formulas are written out after this list.
(2) LTI property | |||
All prior S4-type models are time-invariant—the parameters stay constant for all time steps. This allows S4 to be computed as either (i) a single global convolution or (ii) a linear recurrence. | |||
- | (3) Limitation | ||
Because the transition dynamics do not depend on the input, S4 cannot do content-based selection of the tokens it should store or ignore. | |||
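For reference, the zero-order-hold discretization mentioned in point (1) maps the continuous-time parameters <math>(\Delta, A, B)</math> to their discrete counterparts via the standard formulas
<math> \bar{A} = \exp(\Delta A), \qquad \bar{B} = (\Delta A)^{-1}\left(\exp(\Delta A) - I\right)\Delta B, \qquad h_t = \bar{A}\, h_{t-1} + \bar{B}\, x_t. </math>
These are the generic S4-style discretization formulas rather than anything specific to this summary; once discretized, the model runs as an ordinary discrete recurrence (or, in the LTI case, a global convolution).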
=== Computation of SSMs === | |||
(1) Convolutional mode | |||
Time-invariant S4 can exploit global convolutions (via the expanded kernel) to compute outputs in <math> O(L \log L) </math> or near-linear time. This approach avoids explicitly storing the large hidden state.
- | (2) Recurrent mode | ||
Alternatively, one can compute the same sequence mapping in a step-by-step fashion with <math> O(L) </math> steps but multiplied by the state dimension N. Typical parallel RNN implementations are memory-heavy because the hidden state dimension can be large. | |||
- | (3) Trade-off | ||
Convolution mode is highly parallel (suitable for training) but struggles with time-varying parameters. Hence, prior S4-based models remain LTI to preserve convolution-based efficiency. | |||
- | === Limitations of Linear Time-Invariant SSMs === | ||
(1) Static transitions | |||
Standard SSMs (S4) cannot adapt or filter out unimportant inputs on the fly. | |||
- | (2) Discrete data handling | ||
In discrete, information-dense tasks (like language), one often must selectively attend to critical tokens. Purely LTI models do not have a built-in mechanism for such content-based selection. | |||
=== | === Algorithm 2: Selective SSMs (S6) === | ||
(1) Key idea | |||
Make parts of the SSM parameters become functions of the current input token, hence “selective.” At each time step t, the model can decide whether to store or forget the information from <math> x_t </math> | |||
(2) Effect on recurrence | |||
The system is no longer linear time-invariant. However, it gains the ability to gate hidden states based on content—similar to an RNN gating mechanism but integrated within the SSM formulation. | |||
=== Efficient Implementations of Selective SSMs === | |||
(1) Challenge | |||
With time-varying parameters, the global convolution trick no longer applies. A naive RNN-like approach would be slow (or memory-heavy) when N is large. | |||
- | (2) Hardware-aware parallel scan | ||
The authors design a “selective scan” that operates recurrently but fuses memory reads and writes on GPUs, storing the full hidden state only in fast on-chip memory (SRAM). This avoids the usual overhead of a standard step-by-step approach. A toy version of the underlying scan idea is sketched after this list.
(3) Performance | |||
Benchmarks indicate the proposed selective scan can be faster than attention beyond certain sequence lengths and avoids the large memory overhead of an attention KV-cache. | |||
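The scan mentioned in point (2) works because the time-varying recurrence <math>h_t = a_t h_{t-1} + b_t</math> is associative when each step is treated as an affine map. Below is a small, purely didactic Python version that checks a recursive prefix scan against the sequential loop; the real kernel fuses these steps in on-chip memory rather than recursing in Python, and the operator names here are invented for the example.
<pre>
import numpy as np

def sequential_scan(a, b):
    """h_t = a_t * h_{t-1} + b_t, computed step by step (initial state h_0 = 0)."""
    h, out = 0.0, []
    for at, bt in zip(a, b):
        h = at * h + bt
        out.append(h)
    return np.array(out)

def combine(left, right):
    """Associative composition of two affine steps h -> a*h + b (left applied first)."""
    (a1, b1), (a2, b2) = left, right
    return (a1 * a2, a2 * b1 + b2)

def parallel_scan(a, b):
    """Inclusive prefix scan over the combine operator. Shown recursively for clarity;
    real implementations (e.g. Blelloch-style GPU scans) do this with O(L) work and
    O(log L) depth while keeping the state in fast on-chip memory."""
    pairs = list(zip(a, b))
    if len(pairs) == 1:
        return pairs
    # Pairwise-combine neighbours, scan the half-length problem, then fill in the gaps.
    reduced = [combine(pairs[i], pairs[i + 1]) for i in range(0, len(pairs) - 1, 2)]
    if len(pairs) % 2 == 1:
        reduced.append(pairs[-1])
    scanned = parallel_scan([p[0] for p in reduced], [p[1] for p in reduced])
    out = []
    for i, p in enumerate(pairs):
        if i == 0:
            out.append(p)
        elif i % 2 == 1:
            out.append(scanned[i // 2])
        else:
            out.append(combine(out[i - 1], p))
    return out

a = np.random.rand(8) * 0.9
b = np.random.randn(8)
assert np.allclose(sequential_scan(a, b), [p[1] for p in parallel_scan(list(a), list(b))])
</pre>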
=== Mamba Architecture === | |||
(1) Simplified design | |||
Mamba blocks combine: | |||
A selective SSM layer (the new S6 variant). | |||
A Gated MLP pathway (or “Gated MLP”) in the same layer. | |||
This merges what used to be a multi-layer approach (SSM + MLP) into a single homogeneous block. | |||
- | (2) Purely recurrent | ||
Mamba is completely attention-free. Each layer processes the input in linear time with the selective scan. | |||
(3) Competitive performance | |||
Despite omitting self-attention, Mamba consistently achieves Transformer-quality or better results on diverse tasks, with far lower memory/time overhead at long context lengths. | |||
=== Interpretations of Selection Mechanisms === | |||
(1) Variable spacing | |||
The selective update effectively allows the model to “jump over” irrelevant tokens, addressing tasks like “selective copying” where the positions of relevant tokens vary. | |||
(2) Filtering context | |||
S6 can decide which tokens to integrate or forget. If the input at time t is unimportant, the update gate can suppress it, preventing noise accumulation in the hidden state. | |||
(3) Boundary resetting | |||
When sequences are concatenated (e.g., different segments back-to-back), selective SSMs can “reset” the hidden state if the new segment is unrelated, mimicking the attention mask for different documents. | |||
=== Overview of Experiments === | |||
(1) Synthetic tasks | |||
Selective Copying / Induction Heads: | |||
Demonstrates that selective SSMs learn to focus on relevant tokens. LTI SSMs fail these tasks, but S6-based Mamba solves them and even extrapolates correctly to much longer sequences. | |||
(2) Language modeling | |||
Scaling laws: | |||
Mamba shows strong scaling, matching or surpassing Transformers on the Pile dataset when model sizes go up to 1B+ parameters. | |||
Zero-shot downstream tasks: | |||
Mamba outperforms or matches similarly sized Transformer baselines on tasks like LAMBADA, HellaSwag, and ARC. | |||
(3) DNA sequences | |||
- | Extremely long contexts: | ||
Mamba uses million-length context and still improves perplexity, while LTI SSMs degrade at such scales. | |||
Classification tasks: | |||
Fine-tuning Mamba at lengths up to 1M tokens surpasses prior approaches on synthetic species classification. | |||
(4) Audio generation | |||
Long-range modeling: | |||
Mamba outperforms convolution-based S4 layers for autoregressive waveforms, especially beyond tens of thousands of time steps. | |||
Speech quality: | |||
On a speech benchmark, Mamba cuts the previous state-of-the-art FID roughly in half, achieving more realistic outputs. | |||
=== Speed and Memory Benchmarks === | |||
1 | (1) Selective scan | ||
Achieves high training speed and memory efficiency on modern GPUs. Outperforms naive recurrent approaches by a large margin. | |||
(2) Inference | |||
As a recurrent model with no need to store a growing KV-cache, Mamba obtains up to 5× higher inference throughput than Transformers of comparable size, especially at batch sizes beyond 1. | |||
=== Related Work and Future Directions === | |||
(1) Transformer adaptations | |||
Many recent efforts approximate or modify attention (linear attention, kernel methods, etc.) to achieve subquadratic complexity—yet none consistently matched Transformers across modalities. | |||
(2) Structured State Spaces | |||
Previous S4 variants excelled at continuous signals; discrete tasks were less successful due to the inability to filter input tokens selectively. | |||
Future | |||
(1) Scaled training: | |||
Investigating even larger Mamba models, or specialized domain tasks (vision, speech). | |||
(2) Low-level optimization: | |||
The fused scan approach might be combined with novel GPU/TPU kernels. | |||
(3) Formal interpretability: | |||
Mechanistically verifying how the model “chooses” tokens would improve transparency. | |||
=== Limitations === | |||
(1) Discrete–Continuous tradeoff | |||
While the selective approach helps with discrete data, the authors note that certain initializations (real or complex) may still matter for stable training on continuous signals. | |||
(2) Complex parameterization | |||
Tuning the selection parameters for each domain can be non-trivial, particularly if one wants to combine multiple forms of gating or advanced expansions. | |||
(3) Residual data dependence | |||
Unlike attention, which explicitly attends to tokens by index, selective SSMs rely on gating from learned projections. Certain tasks might still favor explicit attention or local convolutions. | |||
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 3 Presentation: Mamba: Linear-Time Sequence Modelling with Selective State Spaces == | |||
=== Presented by: === | |||
Liang Wu, Jingcheng Yu, Candace Ng | |||
=== Paper Citation === | |||
Gu, A., & Dao, T. (2023). Mamba: Linear-Time Sequence Modeling with Selective State Spaces. arXiv. https://arxiv.org/abs/2312.00752. | |||
=== Summaries of key points === | |||
Goal: achieve content-aware selection by making the model's parameters depend on the input, while preserving efficiency and flexibility.
Background: although traditional SSMs are linear and efficient, they are weak at dynamic, content-based selection.
Methodology: making the SSM parameters vary with the input lets the model selectively remember information; a parallel scan algorithm preserves the linear time complexity; and the architecture is completely attention-free, with each stackable block unifying a selective SSM and an MLP.
Results: performance on The Pile is comparable to Transformers, with more stable behavior on long text; longer DNA contexts yield better classification accuracy; and Mamba surpasses baselines such as S4 on YouTubeMix and SC09.
=== Constructive critiques or reviews === | |||
The presentation is clearly structured, with natural transitions and thorough explanations.
More figures could be added to make the descriptions more intuitive.
More detailed examples would help the audience understand the material better.
=== Clear explanations to aid understanding === | |||
Parallel Scan: avoids the slow, memory-heavy sequential inference of traditional RNNs; its efficiency is close to that of the attention mechanism while using far fewer computing resources.
Selective SSM: updates its state when it sees relevant tokens and skips over irrelevant information.
=== Connections to related works === | |||
S6 can be reduced to the classic RNN gating mechanism, i.e., it is a generalization of gated RNNs (a short derivation is given below).
T. Dao and A. Gu bridge SSMs and Transformers in their follow-up work “Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality”, which makes it easier to transfer existing Transformer training techniques to SSM training; the same work also proposes Mamba-2 to illustrate the theory.
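To make the RNN-gating connection above concrete: in the scalar special case <math>N = 1</math>, <math>A = -1</math>, <math>B = 1</math> with an input-dependent step size <math>\Delta_t = \operatorname{softplus}(w x_t)</math> (essentially the special case the Mamba paper uses to state this connection), the zero-order-hold discretization gives
<math> h_t = e^{-\Delta_t} h_{t-1} + \left(1 - e^{-\Delta_t}\right) x_t = (1 - g_t)\, h_{t-1} + g_t\, x_t, \qquad g_t = \sigma(w x_t), </math>
since <math>1 - e^{-\operatorname{softplus}(z)} = 1 - \tfrac{1}{1 + e^{z}} = \sigma(z)</math>. The selective step size therefore plays exactly the role of the gate in a classical gated RNN.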
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | <div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | ||
== Group | == Group 3 Presentation: Mamba: Linear-Time Sequence Modelling with Selective State Spaces == | ||
=== | === What Problem is Mamba Solving? === | ||
Attention (the core mechanism in Transformers) compares every token to every other token in the sequence, which gives O(n²) complexity: if you double the input size, the compute goes up fourfold. That's fine for 512 tokens, but what about 100,000 tokens? Or a million? You would need enormous amounts of memory and compute. Enter Mamba: a new architecture designed to process long sequences efficiently, with linear time and memory, which means it scales much better.
=== | === Key Innovations === | ||
1. Selective State Space Models (SSMs): | |||
Traditional SSMs have shown promise in sequence modeling but often fall short in tasks requiring content-based reasoning, particularly with discrete data like text. Mamba enhances SSMs by making their parameters dynamic functions of the input. This adaptability allows the model to selectively retain or discard information based on the context of each token, improving performance in language modeling tasks. | |||
2. Efficient Computation with Hardware-Aware Algorithms: | |||
Incorporating input-dependent parameters in SSMs introduces computational challenges, as it disrupts time-invariant properties that facilitate efficient computation. Mamba addresses this by implementing a hardware-aware parallel algorithm that operates in a recurrent manner. This design ensures that the model maintains linear scalability with sequence length while optimizing for modern hardware architectures. | |||
3. Streamlined Architecture: | |||
Mamba departs from the conventional Transformer structure by eliminating attention mechanisms and even multi-layer perceptron (MLP) blocks. This simplification results in a model that is not only computationally efficient but also achieves faster inference speeds—reportedly five times higher throughput than Transformers—while effectively handling sequences up to a million tokens in length | |||
=== More Explanation === | |||
Transformers are amazing—but they struggle with long sequences. | |||
To grasp Mamba's contributions more thoroughly, it's helpful to contextualize them within the broader landscape of sequence modeling: | |||
State Space Models (SSMs): SSMs are mathematical frameworks used to model time-series data by representing systems with hidden states that evolve over time based on inputs. They have been foundational in various applications, including control systems and signal processing. | |||
Transition from Transformers to Mamba: While Transformers rely on self-attention mechanisms to capture relationships between all tokens in a sequence, Mamba leverages the structured approach of SSMs, enhanced with input-dependent dynamics, to achieve similar or superior performance with improved efficiency. | |||
=== Constructive Critiques === | |||
While Mamba presents significant advancements, certain aspects warrant further exploration: | |||
1. Expressiveness vs. Efficiency Trade-off: | |||
By simplifying the architecture and removing attention mechanisms, there might be concerns regarding the model's ability to capture intricate dependencies within data. It's essential to assess whether this streamlined approach compromises the expressiveness necessary for certain complex tasks. | |||
2. Empirical Validation Across Tasks: | |||
Although Mamba shows promise in several domains, comprehensive evaluations across a broader range of tasks and datasets are crucial to fully establish its versatility and generalizability. | |||
3. Implementation Complexity: | |||
The introduction of hardware-aware algorithms, while beneficial for efficiency, may introduce complexities in implementation. Ensuring that these optimizations are accessible and reproducible for practitioners is vital for widespread adoption. | |||
=== | === A Few Things to Think About === | ||
1. No Attention? | |||
Attention is known for its explicit ability to focus on relevant parts of input. Removing it might hurt tasks that need pinpoint accuracy—like reasoning over multiple steps or focusing on specific tokens far away. | |||
2. Training Stability and Tuning | |||
Complex architectures like Mamba sometimes require careful hyperparameter tuning. That might limit plug-and-play usability, at least initially. | |||
3. Interpretability | |||
Attention maps (in Transformers) can sometimes be visualized to explain model behavior. Mamba’s internal state dynamics may be less interpretable. | |||
</div> | |||
- | <div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | ||
== | == Group 3 Presentation: Mamba: Linear-Time Sequence Modelling with Selective State Spaces == | ||
=== Presented by: === | |||
Liang Wu, Jingcheng Yu, Candace Ng | |||
- | === Paper Citation === | ||
Gu, A., & Dao, T. (2023). Mamba: Linear-Time Sequence Modeling with Selective State Spaces. arXiv. https://arxiv.org/abs/2312.00752 | |||
=== Motivation and Problem Statement === | |||
(1) Attention overhead | |||
Transformers achieve great performance with attention, but their O(n²) complexity in sequence length limits them for very long sequences due to high memory and compute costs. | |||
(2) Subquadratic alternatives | |||
While linear attention and SSMs offer better scaling, they often fall short on tasks like language modeling, where discrete tokens require flexible, input-dependent representation. | |||
(3) Core challenge | |||
Can we match the modeling power of attention while retaining linear-time efficiency? Standard linear time-invariant (LTI) SSMs lack the flexibility to perform content-based selection. | |||
=== | === Contribution === | ||
(1) Selective SSM (S6) | |||
Mamba introduces a state-space model where parameters (A, B, C) are functions of the current input. This dynamic mechanism lets the model choose what to store or ignore based on context. | |||
(2) Hardware-Aware Recurrent Scan | |||
A GPU-optimized algorithm processes time-varying recurrence efficiently by using on-chip memory and fused operations. It handles long sequences with high speed and low memory use. | |||
(3) Mamba Architecture | |||
Mamba replaces attention entirely with a purely recurrent, stackable block composed of a selective SSM and a gated MLP. Despite its simplicity, Mamba matches or outperforms Transformers. | |||
=== Algorithm 1: Standard SSM (S4) === | |||
- Based on linear time-invariant recurrence: | |||
<math> h_t = A h_{t-1} + B x_t, \qquad y_t = C h_t </math>
- Allows for global convolution or step-wise recurrence. | |||
- Limitation: Parameters are fixed, unable to select information based on input. | |||
=== Computation of SSMs === | |||
(1) Convolution mode | |||
Allows parallelism but struggles with input-dependent (time-varying) parameters. | |||
(2) Recurrent mode | |||
Linear in sequence length but costly in memory when hidden size is large. | |||
(3) Trade-off | |||
Convolution is fast but static; recurrence is flexible but slower—especially for varying inputs. | |||
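The equivalence of the two computation modes for a time-invariant SSM can be checked numerically: unrolling the recurrence gives the convolution kernel <math>K_k = C A^k B</math>. The NumPy snippet below (illustrative sizes, explicit matrix powers rather than an FFT) verifies that both modes produce the same outputs.
<pre>
import numpy as np

rng = np.random.default_rng(1)
L, N = 16, 3
A = 0.8 * np.eye(N) + 0.05 * rng.standard_normal((N, N))   # fixed (time-invariant) parameters
B, C = rng.standard_normal(N), rng.standard_normal(N)
x = rng.standard_normal(L)

# Recurrent mode: O(L) steps, carrying the hidden state explicitly.
h, y_rec = np.zeros(N), []
for t in range(L):
    h = A @ h + B * x[t]
    y_rec.append(C @ h)

# Convolutional mode: the same map expressed as y = K * x with kernel K_k = C A^k B.
K = np.array([C @ np.linalg.matrix_power(A, k) @ B for k in range(L)])
y_conv = [np.dot(K[:t + 1][::-1], x[:t + 1]) for t in range(L)]

assert np.allclose(y_rec, y_conv)   # identical outputs; only the computation pattern differs
</pre>
In practice the kernel is applied with FFT-based convolution to reach the near-linear cost quoted earlier; the dense matrix powers here are only for clarity, and the trick breaks down once the parameters vary with the input.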
1 | === Algorithm 2: Selective SSMs (S6) === | ||
- Dynamic parameters: | |||
<math> h_t = A(x_t)\, h_{t-1} + B(x_t)\, x_t, \qquad y_t = C(x_t)\, h_t </math>
- Input-aware gating allows filtering irrelevant tokens, akin to RNN gating but within an SSM structure. | |||
=== Efficient Implementations of Selective SSMs === | |||
(1) Bottleneck | |||
Naively processing each step slows down training or bloats memory usage. | |||
(2) Selective scan | |||
Custom GPU-friendly scan algorithm keeps memory local and throughput high—matching or beating attention at long sequence lengths. | |||
(3) Advantage | |||
No KV-cache required; scales efficiently with input size and hardware. | |||
1. | === Mamba Architecture === | ||
(1) Unified block | |||
Each layer contains a selective SSM layer and a gated MLP, merged into one module. | |||
(2) Fully recurrent | |||
Linear-time processing—no attention heads, no multi-head complexity. | |||
(3) Performance | |||
Outperforms Transformers on language, audio, and genomics with better efficiency. | |||
=== Interpretations of Selection Mechanisms === | |||
(1) Skipping irrelevant tokens | |||
S6 can “jump” across unimportant inputs, ideal for tasks like selective copy. | |||
(2) Dynamic context filtering | |||
Tokens at each step are filtered or integrated based on input relevance. | |||
(3) Reset capability | |||
When inputs shift (e.g., document boundaries), S6 can reset its hidden state—like segment-aware attention masks. | |||
=== Overview of Experiments === | |||
(1) Synthetic tasks | |||
Mamba solves selective copy and induction head benchmarks where LTI SSMs fail. | |||
(2) Language modeling | |||
Matches or surpasses Transformers on The Pile dataset up to 1B+ parameters. Strong in zero-shot tasks like LAMBADA, HellaSwag, ARC. | |||
(3) Genomics | |||
Performs well on million-token contexts, beating existing baselines on classification tasks. | |||
(4) Audio | |||
Achieves better FID scores and quality in speech generation. Outperforms S4 on long-range audio modeling. | |||
=== Speed and Memory Benchmarks === | |||
(1) Training | |||
Selective scan enables high throughput with minimal memory pressure. | |||
(2) Inference | |||
Mamba gets 4–5× faster inference throughput than Transformers by avoiding large key–value caches. | |||
=== Related Work and Future Directions === | |||
(1) Related efforts | |||
Builds on S4, Hyena, and other structured SSMs but adds input-dependent dynamics missing in prior work. | |||
(2) Future work | |||
- Scaling Mamba to 10B+ parameters | |||
- Combining with attention mechanisms | |||
- Formalizing interpretability of selection gates | |||
- Exploring new domains like vision and multi-modal data | |||
=== Limitations === | |||
(1) Initialization sensitivity | |||
Real/complex state initializations affect stability on continuous data. | |||
(2) Parameter complexity | |||
Domain-specific tuning of gates and parameter schedules may be needed. | |||
(3) Interpretability | |||
No attention weights means token-level decisions are harder to visualize. | |||
</div> | </div> | ||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | <div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | ||
== Group | == Group 3 Presentation: Mamba: Linear-Time Sequence Modelling with Selective State Spaces == | ||
=== Presented | === Presented by: === | ||
Liang Wu, Jingcheng Yu, Candace Ng | |||
=== Paper Citation === | === Paper Citation === | ||
Gu, A., & Dao, T. (2023). Mamba: Linear-Time Sequence Modeling with Selective State Spaces. arXiv. https://arxiv.org/abs/2312.00752 | |||
=== Research Motivation and Challenge === | |||
Modern foundation models use Transformers with powerful self-attention mechanisms, but the attention mechanism scales quadratically and is limited to fixed-sized windows. While subquadratic alternatives have been proposed, they often struggle with discrete, information-dense data. | |||
The main challenge lies in finding a balance between computational efficiency and effective context compression that adapts to the content. | |||
=== Contributions === | |||
1. A mechanism is introduced that enables state updates to be dynamically adjusted based on the input, improving the model's ability to adapt to varying inputs. | |||
2. A new algorithm is developed that optimizes computation by considering hardware characteristics, especially memory access patterns and parallelism, to enhance computational efficiency. | |||
3. Mamba is proposed as a new architecture that operates without traditional attention mechanisms and processes in linear time, achieving state-of-the-art performance across various modalities. | |||
=== Mamba === | |||
Architecture: Mamba integrates selective SSMs into a single homogeneous block. Each block comprises a linear projection, a selective SSM layer and an MLP block.
Experimental performance: Mamba performs well on synthetic tasks, language modeling, DNA modeling, and audio modeling and generation.
=== Future Directions === | |||
1. Scale to larger models and refine training recipes | |||
2. Extend Mamba to multimodal tasks | |||
3. Explore additional downstream affordances | |||
</div> | |||
< | <div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | ||
== Group 3 Presentation: Mamba: Linear-Time Sequence Modelling with Selective State Spaces == | |||
=== Presented by: === | |||
Liang Wu, Jingcheng Yu, Candace Ng | |||
=== Paper Citation === | |||
Gu, A., & Dao, T. (2023). Mamba: Linear-Time Sequence Modeling with Selective State Spaces. arXiv. https://arxiv.org/abs/2312.00752. | |||
=== | === Summary === | ||
The | 1. Introduction of Mamba: The paper introduces Mamba, a novel sequence modeling architecture based on selective state space models (SSMs). Mamba addresses the limitations of traditional Transformers, such as quadratic scaling with sequence length and inefficiency in handling long sequences, by leveraging selective SSMs that enable linear-time computation and improved performance. | ||
- | 2. Selective State Space Models (SSMs): | ||
Selection Mechanism: Mamba introduces input-dependent SSM parameters, allowing the model to selectively propagate or forget information based on the current token. This addresses the inability of prior SSMs to perform content-based reasoning. | |||
Hardware-Aware Algorithm: Despite losing the efficiency of convolutions due to selectivity, Mamba employs a parallel scan algorithm optimized for modern hardware (GPUs), achieving faster computation and linear scaling in sequence length. | |||
3. Simplified Architecture: Mamba combines the design of prior SSM architectures with MLP blocks into a single, homogeneous block, eliminating the need for attention or even traditional MLP layers. This simplification enhances efficiency and scalability. | |||
- | 4. Empirical Performance: | ||
Language Modeling: Mamba matches or outperforms Transformers of similar or larger sizes in both pretraining perplexity and downstream tasks. For example, Mamba-3B outperforms Transformers twice its size on common-sense reasoning tasks. | |||
DNA and Audio Modeling: Mamba excels in modeling long sequences in genomics and audio, showing improved performance with context lengths up to 1 million tokens. | |||
Synthetic Tasks: Mamba solves tasks like Selective Copying and Induction Heads, demonstrating its ability to handle content-aware and context-aware reasoning. | |||
5. Efficiency: Mamba achieves 5× higher inference throughput than Transformers due to its recurrent nature, which avoids the need for a KV cache. It also scales linearly with sequence length, making it suitable for long-context applications. | |||
6. Ablations and Insights: | |||
The selection mechanism (especially input-dependent Δ) is critical for performance. | |||
Real-valued SSMs perform comparably to complex-valued ones in most settings, except for continuous modalities like audio. | |||
Increasing the state dimension (N) significantly improves performance with minimal parameter overhead. | |||
=== Constructive Critique === | |||
1. Strengths: | |||
The | Innovative Approach: The selective SSM mechanism is a novel solution to the limitations of LTI models, enabling content-aware reasoning without sacrificing efficiency. | ||
Comprehensive Evaluation: The paper validates Mamba across diverse domains (language, DNA, audio) and demonstrates scalability to extremely long sequences. | |||
Practical Impact: The linear-time inference and training scalability make Mamba a strong candidate for real-world applications requiring long-context modeling. | |||
2. Potential Limitations: | |||
Generalization to Larger Models: While Mamba performs well at scales up to 3B parameters, its performance at larger scales (e.g., 7B+ parameters) remains to be verified, especially compared to models like RWKV or RetNet. | |||
Continuous vs. Discrete Modalities: The trade-off between selective and LTI SSMs suggests that Mamba may not universally outperform LTI models (e.g., in audio tasks). Further exploration of hybrid approaches could be beneficial. | |||
Complexity of Implementation: The hardware-aware algorithm, while efficient, may require specialized optimization for different hardware setups, potentially limiting accessibility. | |||
=== Connections to Related Work === | |||
1. SSM Variants: Mamba builds on structured SSMs (S4, S5) but introduces selectivity, distinguishing it from LTI models like Hyena or RetNet. The connection to RNN gating bridges classical and modern sequence modeling. | |||
2. Efficient Attention: Mamba’s linear-time scaling contrasts with subquadratic attention variants (e.g., Linear Attention, Performer). The paper positions Mamba as the first attention-free model to match Transformer quality. | |||
3. Long-Context Models: Mamba’s million-length scalability aligns with recent efforts like LongNet and HyenaDNA but demonstrates superior empirical gains with controlled experiments (e.g., performance improves monotonically with context length). | |||
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 4 Presentation: Learning spatiotemporal dynamics with a pretrained generative model == | |||
=== Presented by: === | |||
- Karolina Suszek | |||
- Negin Amou | |||
- Muhammad Azeem | |||
=== Paper Citation === | |||
Z. Li et al., “Learning spatiotemporal dynamics with a pretrained generative model,” Nature Machine Intelligence, vol. 6, no. 12. Springer Science and Business Media LLC, pp. 1566–1579, Dec. 06, 2024. doi: 10.1038/s42256-024-00938-z. | |||
=== Background === | |||
- Spatiotemporal dynamics: how the state of a physical system varies with space and time | |||
- Real datasets often contain data with sparse measurements where there are a limited number of sensors available. There needs to be a way to convert the sparse measurement data into a full spatiotemporal field. | |||
- Existing solutions learn to map the input to the output and ignore missing data, but this reduces the model's ability to generalize.
- The paper proposes the use of a Sparse-Sensor-Assisted Score-Based Generative Model (S3GM), which uses unlabeled data during training and can reconstruct incomplete data after training, making accurate predictions even when there isn't much information available.
- Key Idea: learn the probability distribution of spatiotemporal data using a score-based generative model and refine the samples via stochastic sampling.
=== Technical Contributions === | |||
The main proposed model is the Sparse-Sensor-Assisted Score-Based Generative Model. It learns patterns from a large amount of data beforehand. It is also unsupervised, so it does not require any labels during training. It tries to learn the significant features and natural patterns of the data. After training, the model can take incomplete data and reconstruct the missing parts to make predictions.
Core Components: | |||
- Pre Training Stage: Learns the joint probability distribution of the data | |||
- Generating Stage: Use a stochastic differential equation to refine and generate full field predictions | |||
- Refinement Mechanism: Ensures alignment with observations and enforces sequence consistency
Some of the common applications of this model are Turbulent flow modeling, climate forecasting, and physics-based simulations. | |||
=== Summaries of key points === | |||
- Challenge Addressed: Traditional end-to-end learning models often struggle with generalization in reconstructing spatiotemporal dynamics, particularly when data is sparse—a common scenario in real-world applications. | |||
- S³GM Methodology: | |||
Pretraining Phase: An unconditioned generative model is pretrained in a self-supervised manner on a comprehensive dataset, capturing the joint distribution of the system's dynamics. | |||
Generation Phase: The pretrained model is conditioned on new, sparse measurements to reconstruct and predict the full-field spatiotemporal dynamics. | |||
- Validation and Performance: S³GM's efficacy was tested across multiple dynamical systems using synthetic, real-world, and laboratory datasets, including applications in turbulent flow modelling and weather forecasting. The results demonstrated that S³GM achieves high accuracy, generalizability, and robustness, even when faced with significant data sparsity and noise. | |||
S³GM offers a promising approach for modeling and predicting complex spatiotemporal dynamics in situations where data is limited, leveraging the strengths of pretrained generative models to enhance performance in small data regimes. | |||
=== Related Works === | |||
Related works in this area include GPT-ST: Generative Pre-Training of Spatio-Temporal Graph Neural Networks. This framework employs a spatio-temporal masked autoencoder designed to capture both intra- and inter-cluster region semantic relationships, which are often overlooked in existing approaches. Another is Spatio-Temporal Few-Shot Learning via Diffusive Neural Network Generation, which introduces a generative pre-training framework (GPD) that addresses data scarcity in spatiotemporal modeling. By performing generative pre-training on neural network parameters optimized with data from source cities, the framework enables the generation of tailored neural networks guided by prompts.
Many other methods that map the sparse measurements (input) to the full reconstructed spatiotemporal field include the following:
<ul> | |||
<li>Using Fourier or Laplace transforms to learn mappings between function spaces. The Fourier transform moves the sparse input data into the frequency domain, where reconstruction techniques can be applied more easily (a minimal sketch of this Fourier-layer idea follows this list).</li>
<li>Using CNNs to learn latent representations of full spatial-temporal fields and reconstruct missing regions through an encoder and decoder</li>
<li>Using PINNs to incorporate physics laws (differential equations) into the loss function. This can be useful when data is sparse or noisy, as they enforce physical consistency in the absence of complete ground-truth data.</li>
</ul> | |||
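To make the first item above concrete, here is a minimal, illustrative Python sketch of the Fourier-layer idea used by operator-learning methods such as FNO. The class name, weight shapes, and number of retained modes are my own assumptions, not code from any of the cited papers: the signal is transformed to the frequency domain, learned weights are applied to a few low-frequency modes, and the result is transformed back.

<pre>
import torch
import torch.nn as nn

class SpectralConv1dSketch(nn.Module):
    """FNO-style spectral layer sketch: learn a linear map on the lowest
    Fourier modes of the input field. Assumes modes <= grid // 2 + 1."""
    def __init__(self, channels: int, modes: int):
        super().__init__()
        self.modes = modes
        scale = 1.0 / channels
        self.weight = nn.Parameter(scale * torch.randn(channels, channels, modes, dtype=torch.cfloat))

    def forward(self, x):                       # x: (batch, channels, grid)
        x_ft = torch.fft.rfft(x)                # to the frequency domain
        out_ft = torch.zeros_like(x_ft)
        # Mix channels on the retained low-frequency modes only.
        out_ft[:, :, :self.modes] = torch.einsum("bim,iom->bom", x_ft[:, :, :self.modes], self.weight)
        return torch.fft.irfft(out_ft, n=x.shape[-1])   # back to physical space
</pre>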
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 4 Presentation: Learning spatiotemporal dynamics with a pretrained generative model == | |||
=== Presented by: === | |||
- Karolina Suszek | |||
- Negin Amou | |||
- Muhammad Azeem | |||
=== Paper Citation === | |||
Z. Li et al., “Learning spatiotemporal dynamics with a pretrained generative model,” Nature Machine Intelligence, vol. 6, no. 12. Springer Science and Business Media LLC, pp. 1566–1579, Dec. 06, 2024. doi: 10.1038/s42256-024-00938-z. | |||
=== Summaries of key points === | |||
==== Spatiotemporal Dynamics and S3GM ==== | |||
Spatiotemporal dynamics describe physical systems (e.g., climate forecasting, fluid dynamics) that evolve over space and time. However, in real-world situations, sensor data is usually incomplete; for example, weather stations may cover only a few locations, and sensors might capture only the magnitude of velocity, missing its direction. Standard deep learning models (FNO, PINN, U-Net, LNO, DeepONet) often struggle to adapt if the sensor setup changes or if the system behaves unexpectedly. This lack of generalization means models often need to be retrained for each new situation. To overcome this, S3GM is proposed.
==== How Does S3GM Work?==== | |||
Instead of learning the full probability distribution <math> p(x) </math> (which is complex), S3GM learns the gradient of the data distribution, called the score function: <math>s(x) = \nabla_x \log p(x)</math>. This tells the model in which direction the data is most likely to change, making learning more efficient.
It should also be noted that real-world data is often messy: noisy, incomplete, or missing. S3GM handles this using denoising score matching (DSM):
- Adds a small amount of noise to the data and trains the model to remove it. | |||
- Forces the model to focus on true underlying patterns rather than memorizing raw data. | |||
By repeatedly removing noise, the model deeply understands the true data structure—even when parts are missing. | |||
Once the model learns the data's structure, it reconstructs missing information using stochastic differential equations (SDEs), whose update combines three terms (a minimal sketch follows this list):
- Drift Term: guides reconstruction toward likely states. | |||
- Diffusion Term: adds controlled randomness to explore multiple solutions. | |||
- Correction Term: uses real sensor data to ensure consistency. | |||
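A rough Python sketch of the two ingredients described above, written for illustration only. The function names, the single Langevin-style update, and the guidance weighting are assumptions of mine rather than the paper's exact noise schedule or predictor-corrector sampler.

<pre>
import torch

def dsm_loss(score_net, x, sigma=0.1):
    """Denoising score matching: perturb x with Gaussian noise and train the
    network to predict the score of the noised distribution."""
    noise = torch.randn_like(x)
    x_noisy = x + sigma * noise
    target = -noise / sigma            # score of N(x, sigma^2 I) evaluated at x_noisy
    return ((score_net(x_noisy) - target) ** 2).mean()

def guided_sampling_step(x, score_net, observed, mask, step=1e-3, guidance=1.0):
    """One Langevin-style refinement step: drift along the learned score,
    add diffusion noise, and correct toward the sparse observations
    (mask = 1 where a sensor measurement exists)."""
    drift = score_net(x)                                       # drift term
    diffusion = (2.0 * step) ** 0.5 * torch.randn_like(x)      # diffusion term
    correction = guidance * mask * (observed - x)              # observation-consistency term
    return x + step * drift + diffusion + step * correction
</pre>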
==== How Well Does S3GM Work? ==== | |||
S3GM is tested on four different systems to see how well it reconstructs and predicts missing data: | |||
'''Experiment 1: Predicting Chaotic Behavior (Kuramoto-Sivashinsky Equation)''' | |||
* '''Challenges Tested:''' | |||
- Sparse Spatial Data (few sensor readings) | |||
- Fourier Transform Domain (frequency-based measurements)
- Limited Initial Data (predict future states with few frames) | |||
* '''Results:''' | |||
- S3GM outperformed U-Net, FNO, and DeepONet, achieving lower errors. | |||
- Stable even with limited input data. | |||
'''Experiment 2: Reconstructing Turbulent Flow (Kolmogorov Flow)''' | |||
* '''Challenges Tested:''' | |||
- Trained on low-turbulence data. | |||
- Tested on high-turbulence data. | |||
* '''Results:''' | |||
- Accurately reconstructed velocity fields and vorticity patterns. | |||
'''Experiment 3: Climate Data Reconstruction (ERA5 Dataset)''' | |||
* '''Challenges Tested:''' | |||
- Extreme Data Sparsity (only 1% wind speed measurements available). | |||
- Hidden Variables (missing temperature and pressure). | |||
- Noisy Measurements (Gaussian noise added). | |||
* '''Results:''' | |||
- Successfully reconstructed missing climate variables. | |||
- Performance improved with more sensor data. | |||
'''Experiment 4: Flow Around a Cylinder''' | |||
* '''Challenges Tested:''' | |||
- Spatiotemporal Gaps (only specific cross-sections measured). | |||
- Time-Averaged Data (some measurements were only available as averages). | |||
* '''Results:''' | |||
- Accurately reconstructed instantaneous and time-averaged flow fields. | |||
- Outperformed physics-informed neural networks (PINNs). | |||
==== Limitations of S3GM ==== | |||
While powerful, S3GM has limitations: | |||
- Computational Cost: Pre-training is resource-intensive. | |||
- Data Quality Dependence: Best performance with diverse, high-quality data. | |||
- Generalization Issues: May struggle with entirely new dynamics. | |||
- Processing Speed: Iterative reconstruction can be slower than traditional methods. | |||
Despite these challenges, S3GM is a promising tool, and if these are improved, it could be even more powerful. | |||
</div>
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;">
== Group 4 Presentation: Learning spatiotemporal dynamics with a pretrained generative model == | |||
== Group | |||
=== Presented by: ===
Karolina Suszek, Negin Amou, and Muhammad Azeem
=== Paper Citation ===
Z. Li et al., “Learning spatiotemporal dynamics with a pretrained generative model,” Nature Machine Intelligence, vol. 6, no. 12. Springer Science and Business Media LLC, pp. 1566–1579, Dec. 06, 2024. doi: 10.1038/s42256-024-00938-z. | |||
=== | === Background === | ||
This paper proposes Mamba, a new type of sequence model designed to match the modeling quality of Transformers while improving computational efficiency. The key innovation is a selective state space model (SSM) that can reason based on content and scale linearly with sequence length. While providing 4–5x quicker inference than Transformers of comparable size, Mamba shows strong performance across several domains—language, music, and genomics—positioning itself as a general-purpose backbone for foundation models. | |||
Most large models today rely on Transformers, which are powerful but inefficient, especially for long sequences. Both training and inference are bottlenecked by the quadratic scaling of the self-attention mechanism with sequence length. State space models (SSMs), which are recurrent and scale linearly, have drawn growing interest as efficient substitutes, though current versions have fallen short on tasks like language modelling. A major drawback the authors point out is that conventional SSMs are time-invariant, applying the same dynamics at every time step regardless of the input. This limits their capacity to complete tasks needing content-based reasoning.
=== | === Main Idea === | ||
The central idea of this paper is to improve state space models by making them selective. Traditional structured state space models (SSMs) apply the same linear operations at every time step, which works well for smooth or continuous data like audio, but not for discrete tasks like language modeling. The authors argue that this is because these models cannot adapt their behavior based on the content of the input. | |||
Mamba addresses this by allowing some of the internal dynamics of the SSM to depend on the current input token. Specifically, the model modifies the SSM parameters (like Δ, B, and C) so that they are no longer fixed, but vary depending on what the model sees at each step. This makes it possible for the model to filter, retain, or discard information in a context-aware way. | |||
This design sacrifices the ability to use fast convolutional implementations, but the authors introduce an alternative they call a selective scan—a custom, hardware-friendly way of computing the state updates efficiently on GPU. This allows the model to maintain linear computational complexity while being much more flexible than previous SSMs. | |||
Mamba’s architecture is also deliberately kept simple. It does not rely on attention, nor does it use the usual Transformer-style MLP blocks. Instead, it stacks blocks based on this new selective SSM design, each combining sequence modeling and nonlinear transformation in one place. | |||
=== Experiments and Results ===
The authors test Mamba on a wide range of tasks to show both its performance and its scalability. | |||
On synthetic tasks like selective copying and induction heads, Mamba succeeds in learning long-range dependencies that other models fail to capture. It generalizes well even when the test sequences are far longer than the ones it was trained on, reaching up to a million tokens. | |||
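For readers unfamiliar with the selective copying task, the toy generator below illustrates its structure; the token conventions are an assumption of mine, not the paper's exact setup. A few content tokens are scattered among pad/noise positions, and the model must output the content tokens in order, which requires content-aware filtering rather than a fixed time-invariant kernel.

<pre>
import torch

def selective_copying_batch(batch: int, seq_len: int, n_tokens: int, vocab: int = 16):
    """Toy selective-copying data. Token 0 is used as the pad/noise symbol
    here (an assumption); targets are the scattered content tokens in order."""
    x = torch.zeros(batch, seq_len, dtype=torch.long)      # pad everywhere
    y = torch.zeros(batch, n_tokens, dtype=torch.long)
    for b in range(batch):
        pos = torch.sort(torch.randperm(seq_len)[:n_tokens]).values  # random, ordered positions
        tok = torch.randint(1, vocab, (n_tokens,))                   # content tokens
        x[b, pos] = tok
        y[b] = tok
    return x, y
</pre>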
In language modeling, they train Mamba on The Pile and compare it to Transformer baselines like Pythia and RWKV. Despite being smaller in size, Mamba-1.4B performs better than Pythia-2.8B on several zero-shot benchmarks. It also matches the performance of more carefully tuned Transformer setups. One major advantage is that Mamba runs faster at inference time—achieving 4 to 5 times the throughput of Transformer models—because it avoids key-value caching. | |||
For genomics, Mamba is trained on the human genome (HG38). Its perplexity improves as the sequence length increases, which is unusual—most models perform worse on longer contexts. On a classification task involving DNA from closely related species (humans, chimps, gorillas, etc.), Mamba significantly outperforms other models, especially at longer input lengths. | |||
In audio modeling, Mamba is plugged into the SaShiMi framework and outperforms it on waveform prediction and speech generation. On the SC09 dataset, it scores better than WaveNet and DiffWave, despite having fewer parameters. | |||
Finally, in terms of efficiency, the new scan implementation is fast. It’s faster than both a naive PyTorch loop and FlashAttention-2 for long sequences. Mamba’s speed and memory use scale linearly with sequence length, making it practical for real-world applications with long inputs or limited compute. | |||
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 4 Presentation: Learning spatiotemporal dynamics with a pretrained generative model == | |||
=== Presented by: === | |||
Karolina Suszek, Negin Amou and Muhammad Azeem | |||
=== Paper Citation === | |||
Z. Li et al., “Learning spatiotemporal dynamics with a pretrained generative model,” Nature Machine Intelligence, vol. 6, no. 12. Springer Science and Business Media LLC, pp. 1566–1579, Dec. 06, 2024. doi: 10.1038/s42256-024-00938-z. | |||
=== Summary === | |||
The article "Learning spatiotemporal dynamics with a pretrained generative model" introduces a novel approach to reconstructing and predicting full-field spatiotemporal dynamics from sparse sensor measurements using a sparse-sensor-assisted score-based generative model (S³GM). The key points of the paper are as follows: | |||
1. Problem Addressed: Reconstructing spatiotemporal dynamics (e.g., velocity, temperature, pressure fields) from sparse and heterogeneous sensor data is a significant challenge in fields like fluid dynamics, geophysics, and atmospheric physics. Traditional end-to-end learning models struggle with generalization, especially under sparse data conditions common in real-world scenarios. | |||
2. Proposed Solution - S³GM: The authors propose S³GM, which leverages a pretrained generative model to capture the joint distribution of pretraining data in a self-supervised manner. This model is then conditioned on sparse measurements to reconstruct and predict full-field dynamics. Unlike conventional methods that directly map inputs to outputs, S³GM uses a two-step process: pretraining on vast datasets followed by conditional sampling.
3. Performance and Validation: The efficacy of S³GM is demonstrated across multiple dynamical systems, including turbulent flow modeling (e.g., Kolmogorov flow), weather/climate forecasting (using ERA5 data), and laboratory cylinder flow experiments (via PIV measurements). The model excels in zero-shot reconstruction and future-state forecasting, even with high data sparsity (e.g., 8× downsampling) and noise, outperforming baseline methods like U-Net, FNO, DeepONets, and DMD.
=== Key Features === | |||
Accuracy and Robustness: S³GM accurately reconstructs fields and maintains statistical fidelity (e.g., kinetic energy spectra) under varying sparsity levels. | |||
Generalizability: It performs well on unseen data, a significant improvement over traditional models. | |||
Stability: The model shows numerical stability in long-term forecasting, as evidenced by low error accumulation compared to baselines. | |||
Applications and Datasets: The approach is tested on synthetic (e.g., Kuramoto-Sivashinsky equation, Kolmogorov flow), real-world (ERA5 reanalysis), and experimental (cylinder flow at Reynolds numbers 100–250) datasets, with all data and code made publicly available. | |||
Comparison with Baselines: S³GM is benchmarked against seven methods, including neural network-based (U-Net, FNO, PINN) and linear (DMD, piDMD) approaches. It consistently delivers superior performance, particularly in handling complex, sparse, and noisy data. | |||
Implications: The method offers a transformative tool for scientific and engineering applications where sensor data is limited, enhancing our ability to understand and control complex dynamical systems. | |||
</div> | </div> | ||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | <div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | ||
== Group 4: Learning spatiotemporal dynamics with a pretrained generative model. == | |||
== | === As presented by: === | ||
- Karolina Suszek | |||
- Negin Amou | |||
- Muhammad Azeem | |||
=== | === Overview === | ||
Reconstruction of the spatiotemporal dynamics of dynamic systems is a canonically challenging task in engineering and science (may be more broadly referred to as "inverse problems"). An interesting heuristic for time dependent state dynamics problems is that they can be conceptualized as an "image to image" problem, or in some cases an "image" reconstruction problem. With this heuristic in mind, of course it makes sense to consider using generative models - which otherwise excel at image problems - for this use case. The authors introduce a specific framework which they title a "sparse sensor assisted score based generative model" (Abbreviated, <math> S^3GM </math>), leveraging a pretrained generative model and demonstrate its efficacy in recreating spatiotemporal system dynamics given only access to sparse sensor measurements. | |||
=== Operating Principles === | |||
The model framework requires a few key components. It involves the use of an embedding network, consisting of spatial and temporal convolutions, which creates a prior that is used to inform, or steer, the generator toward a physically plausible solution. At generation time, the generator works by denoising a tensor of random Gaussians, whose trajectory is governed by the embedding provided by the prior network and penalized for discrepancy with the known sensor measurements. The result is a generated sample which is both physically plausible and in agreement with the measurements at the known locations.
=== | === Implementation Details === | ||
The authors employ a "Video-U-Net" as a prior network, effectively characterizing the joint distribution of the spatiotemporal data. The generator is a "score-based-generative model" (SGM) which employs a denoising process governed by Stochastic differential equations (SDE). The generator is guided through the two aforementioned mechanisms, namely the output of the prior network guides the generation toward physically plausible solutions, and an observation consistency term, which penalizes proposed solutions which may very well come from the space of physically plausible trajectories, but which differ from the experimental (sensor) observations. | |||
=== Discussion / Conclusion === | |||
Overall, the authors demonstrate the efficacy of the S3GM model on a series of canonical dynamical systems problems. They show its ability to reconstruct dynamics for Kuramoto-Sivashinsky dynamics, Kolmogorov turbulent flow, climate observations and cylinder flow, achieving very low error rates.
=== Related Works === | |||
An interesting paper which employs a similar methodology is "Learning to Solve PDE-constrained Inverse Problems with Graph Networks" by Zhao, Lindell, Wetzstein (published in ICML 2023). Specifically, they employ a learned prior network to reconstruct the initial condition only, and use this in combination with a GNN to predict the forward dynamics. This coupled model could potentially take better advantage of generative models by using them only to create physically plausible initial conditions (when measured against real sensor locations) while using a more suitable architecture for the forward propagation.
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 5 Presentation: Griffin: Mixing Gated Linear Recurrences with Local Attention for Efficient Language Models== | |||
=== Presented by: === | |||
Guoqian Li, Xinlei Xu, Wenning Xu | |||
=== Paper Citation === | |||
De, S., Smith, S., Fernando, A., Botev, A., Cristian-Muraru, G., Gu, A., Haroun, R., Berrada, L., Chen, Y., Srinivasan, S., Desjardins, G., Doucet, A., Budden, D., Teh, Y. W., Pascanu, R., De Freitas, N., Gulcehre, C. (2024). Griffin: Mixing Gated Linear Recurrences with Local Attention for Efficient Language Models. ''arXiv''. arxiv.org/pdf/2402.19427 | |||
=== Background === | |||
'''RNNs and transformers''' | |||
Recurrent neural networks (RNNs) are a basic architecture for handling sequential data; they are good for short sequences but can struggle with long ones. | |||
Transformers generally perform better than RNNs and have become dominant in recent years. However, for large sequences, they become computationally expensive for a couple of reasons. Their global attention mechanism has a quadratic complexity. Furthermore, the key-value cache and the multi-query attention cache grows linearly with sequence length. | |||
The paper proposes several innovations to improve upon existing attention mechanisms in neural networks: | |||
- RG-LRU layer: a novel gated linear recurrent layer designed to replace Multi-Query Attention (MQA) | |||
- Hawk: integrate multi-layer perceptrons (MLPs) with recurrent blocks | |||
- Griffin: combine MLPs with a mix of recurrent blocks and local attention mechanisms | |||
=== | === Technical Contributions === | ||
'''Model architecture''' | |||
Their proposed models -- Hawk and Griffin -- use these following structures: | |||
* '''Residual block:''' This is the main structure in the architecture. It starts with an RMSNorm layer being applied to the hidden state, followed by a temporal mixing block. There's a residual connection. Next, there is another RMSNorm layer followed by an MLP (multi-layer perceptron) block. A minimal sketch of this layout appears after the performance notes below.
* | * '''Gated MLP block:''' This is a gated feedforward block. There are 2 branches: one is linear and one uses a GeLU activation. This part of the structure is used for feature selection. | ||
* '''Temporal mixing block:''' They used 3 different kinds of temporal mixing blocks in their models: global multi-query attention, local sliding window attention, and recurrent blocks (what they proposed). The recurrent block contains 2 parallel branches. One has a GeLU activation, while the other has a temporal 1D convolutional layer followed by a RG-LRU layer (a proposal in this paper). | |||
* | |||
''' Performance and Efficiency '''
* Hawk surpasses the reported performance of Mamba on downstream tasks.
* Griffin matches the performance of Llama-2, despite being trained on over six times fewer tokens.
* Both models demonstrate the ability to extrapolate to sequences significantly longer than those encountered during training.
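A minimal Python sketch of the residual-block layout described above (RMSNorm, temporal mixing, residual add, then RMSNorm, gated MLP, residual add). The class names, the expansion factor, and the idea of passing in any temporal-mixing module (recurrent block, local or global attention) are illustrative assumptions, not the paper's implementation.

<pre>
import torch
import torch.nn as nn
import torch.nn.functional as F

class RMSNormSketch(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps
    def forward(self, x):
        return self.weight * x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

class GatedMLPSketch(nn.Module):
    """Two parallel branches (one linear, one GeLU-activated), multiplied and projected back."""
    def __init__(self, dim: int, expansion: int = 3):
        super().__init__()
        self.gate = nn.Linear(dim, expansion * dim)
        self.up = nn.Linear(dim, expansion * dim)
        self.down = nn.Linear(expansion * dim, dim)
    def forward(self, x):
        return self.down(F.gelu(self.gate(x)) * self.up(x))

class ResidualBlockSketch(nn.Module):
    """Norm -> temporal mixing -> residual, then norm -> gated MLP -> residual."""
    def __init__(self, dim: int, temporal_mixing: nn.Module):
        super().__init__()
        self.norm1, self.norm2 = RMSNormSketch(dim), RMSNormSketch(dim)
        self.mix, self.mlp = temporal_mixing, GatedMLPSketch(dim)
    def forward(self, x):                 # x: (batch, length, dim)
        x = x + self.mix(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x
</pre>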
'''Real-Gated Linear Recurrent Unit (RG-LRU)''' | |||
RG-LRUs are inspired by regular Linear Recurrent Units (LRUs), but add a gating mechanism. They are more computationally efficient, as they avoid the quadratic complexity that transformers have. Mathematically, the layer is defined as follows.
The recurrence gate: | |||
= | <math> r_t = \sigma (W_a x_t + b_a) </math> | ||
The input gate: | |||
= | <math> i_t = \sigma (W_x x_t + b_x) </math> | ||
<math> a_t = a^{c r_t} </math>, where <math>a \in (0,1)</math> is a learnable per-channel recurrence weight and <math>c</math> is a fixed scalar constant.
The output: | |||
= | <math> h_t = a_t \odot h_{t-1} + \sqrt{1-a_t^2} \odot (i_t \odot x_t) </math> | ||
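Below is a minimal, illustrative PyTorch transcription of these equations. The sigmoid parameterization of the per-channel weight <math>a</math>, the value of the scalar constant (set here to 8), and the sequential Python loop are simplifications of mine; the actual model implements this recurrence with a fused, hardware-aware linear scan.

<pre>
import torch
import torch.nn as nn

class RGLRUSketch(nn.Module):
    """Sketch of the RG-LRU recurrence defined above. Weight shapes are
    illustrative; `a` is kept in (0, 1) via a sigmoid over a learnable parameter."""
    def __init__(self, dim: int, c: float = 8.0):
        super().__init__()
        self.W_a = nn.Linear(dim, dim)                 # recurrence gate
        self.W_x = nn.Linear(dim, dim)                 # input gate
        self.a_param = nn.Parameter(torch.randn(dim))  # per-channel recurrence parameter
        self.c = c

    def forward(self, x):                # x: (batch, length, dim)
        b, L, d = x.shape
        h = x.new_zeros(b, d)
        a = torch.sigmoid(self.a_param)  # per-channel recurrence weight in (0, 1)
        outputs = []
        for t in range(L):
            xt = x[:, t]
            r_t = torch.sigmoid(self.W_a(xt))          # recurrence gate
            i_t = torch.sigmoid(self.W_x(xt))          # input gate
            a_t = a.pow(self.c * r_t)                  # a_t = a^(c * r_t), elementwise
            h = a_t * h + torch.sqrt(1 - a_t ** 2) * (i_t * xt)
            outputs.append(h)
        return torch.stack(outputs, dim=1)
</pre>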
=== Summaries of Key Points === | |||
==== Context and Motivation ==== | |||
Transformers have dominated large language modeling, but their attention mechanism can be expensive for very long sequences, particularly because of the quadratic complexity of attention and a Key-Value (KV) cache that grows linearly in sequence length. Recurrent neural networks (RNNs) compress entire contexts into a fixed-size hidden state and avoid storing a cache that grows with length, which can reduce memory and speed up inference. Historically, though, RNNs have been difficult to train at scale and often underperform Transformers on complex tasks. This paper proposes new recurrent-based architectures that can match or exceed Transformer performance while retaining the training efficiency of standard deep-learning pipelines and gaining a significant advantage during inference. | |||
==== Proposed Architectures ==== | |||
Two models are introduced: Hawk, a purely recurrent architecture, and Griffin, which mixes gated linear recurrences with local attention. Both rely on a new Real-Gated Linear Recurrent Unit (RG-LRU) layer, which extends a previously studied linear recurrency (the LRU) by adding two gating mechanisms. One gate modulates how much of the new input to incorporate each step, and another governs whether the layer updates in a standard LRU style or simply retains the previous hidden state. | |||
==== RG-LRU and Recurrent Blocks ==== | |||
RG-LRU’s diagonal recurrence matrix dramatically lowers the computational cost for large sequence lengths. The authors also include a lightweight convolution for near-token interactions. For Griffin, they interleave RG-LRU-based blocks with a local sliding-window attention, allowing short-range local attention plus a global recurrent state. This combination helps the model handle both local details and long-distance dependencies in a memory-efficient way. | |||
==== Empirical Results ==== | |||
The authors scale Hawk and Griffin to billions of parameters and compare them with: | |||
1. A baseline Transformer using Multi-Query Attention (MQA), | |||
2. The Mamba recurrent model (from prior work), | |||
3. Llama-2.
Their main findings include: | |||
* Both Hawk and Griffin follow standard power-law scaling in validation loss versus training compute, matching or beating Transformers. | |||
* Hawk matches or exceeds previous recurrent models like Mamba on language modeling benchmarks, despite training on fewer tokens. | |||
* Griffin matches Llama-2’s accuracy on multiple downstream NLP tasks while training on only about one-seventh the token count. | |||
* During training, these models achieve speeds comparable to Transformers (thanks to parallelized or fused kernels), and in some configurations are faster at long sequence lengths (2048 tokens or more). | |||
* At inference time, recurrent and local-attention blocks avoid the large KV caches that hamper Transformers. Hence, Hawk and Griffin show lower latency and can decode tokens at significantly higher throughput, especially for long sequences. | |||
* Hawk and Griffin extrapolate to far longer sequences than they were trained on. They also efficiently learn synthetic copying/retrieval tasks if explicitly trained on them. However, their out-of-the-box retrieval or copying for tasks not seen at training time is still below that of Transformers. | |||
== | ==== Significance ==== | ||
= | |||
By mixing local attention with a novel gated recurrence, Griffin retains the best of both recurrent and attention-based approaches: it can model local context with a sliding window while maintaining a small, fixed-size global state for the entire sequence. This achieves strong performance on large-scale language modeling benchmarks while offering major advantages in memory footprint and decoding speed. The results position Hawk and Griffin as compelling alternatives to purely Transformer-based architectures for scalable language models. | |||
=== | === Related work and Future direction === | ||
Griffin builds on several key areas in efficient language modelling, particularly in recurrent architectures and hybrid approaches combining recurrence and attention: State Space Models (SSMs), Hybrid Recurrence-Attention Models, Efficient Transformer Alternatives. | |||
=== Scaling Griffin to Larger Models === | |||
The paper discusses Griffin models up to 14B parameters but suggests further exploration into scaling beyond this size to compete with GPT-4 or Gemini models. Investigating how Griffin handles long-context tasks in ultra-large models would be an interesting future study. | |||
=== Memory and Efficiency Optimization === | |||
Griffin already improves efficiency compared to transformers, but further research into sparsification, quantization, and hardware-specific optimizations could enhance its real-world applicability. Exploring mixed-precision training and inference optimizations for mobile and edge devices. | |||
=== Extending Griffin to Multimodal Learning === | |||
While Griffin is designed for language modeling, incorporating it into multimodal tasks (e.g., vision-language models, audio processing) could expand its impact. Combining Griffin’s recurrence mechanism with diffusion models or video understanding models might be a promising direction. | |||
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
=== | == Group 5 Presentation: Griffin: Mixing Gated Linear Recurrences with Local Attention for Efficient Language Models== | ||
=== Presented by: === | |||
Guoqian Li, Xinlei Xu, Wenning Xu | |||
=== Paper Citation === | |||
- | De, S., Smith, S., Fernando, A., Botev, A., Cristian-Muraru, G., Gu, A., Haroun, R., Berrada, L., Chen, Y., Srinivasan, S., Desjardins, G., Doucet, A., Budden, D., Teh, Y. W., Pascanu, R., De Freitas, N., Gulcehre, C. (2024). Griffin: Mixing Gated Linear Recurrences with Local Attention for Efficient Language Models. ''arXiv''. arxiv.org/pdf/2402.19427 | ||
=== Summaries of Key Points === | |||
==== Introduction and Motivation ==== | |||
This paper introduces the concept of Neural Collapse, a surprising and consistent geometric phenomenon observed in the terminal phase of training deep neural networks for classification. The authors highlight how neural networks, regardless of architecture or dataset, tend to collapse into a highly symmetrical state in which class features and classifiers align with remarkable regularity. The term terminal phase refers to the late stage of training after zero training error has already been achieved. | |||
==== Description of Neural Collapse ==== | |||
Neural Collapse (NC) involves four key empirical properties: | |||
• NC1 – Variability Collapse: Within-class variability in feature representations collapses to zero.
• NC2 – Convergence to Class Means: The feature vectors of each class converge to their class mean.
• NC3 – Equiangular Tight Frame (ETF) Structure: The class means become equidistant and symmetrically arranged, forming a simplex ETF.
• NC4 – Alignment of Classifier and Features: The last-layer classifier weights align with the class means.
- | These properties are observed across various settings, models, and datasets (e.g., MNIST, CIFAR-10, CIFAR-100, ImageNet). | ||
=== | ==== Methodological Approach ==== | ||
The authors provide both empirical and theoretical support for neural collapse. They measure within-class variability, angles between class means, and alignment between classifiers and features to verify the four NC properties. They also propose a simplified analytical model (a deep linear network trained under certain assumptions) to theoretically demonstrate why neural collapse can emerge. | |||
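As a concrete illustration of what these measurements look like in practice, the snippet below sketches rough diagnostics for NC1-NC4 on a batch of last-layer features. The tensor shapes, names, and specific ratios are my own choices for illustration, not the paper's exact metrics.

<pre>
import torch
import torch.nn.functional as F

def neural_collapse_metrics(features, labels, classifier_W):
    """Rough NC1-NC4 diagnostics. features: (N, d), labels: (N,),
    classifier_W: (C, d) with rows assumed ordered like labels.unique()."""
    classes = labels.unique()
    global_mean = features.mean(0)
    means = torch.stack([features[labels == c].mean(0) for c in classes])   # class means (C, d)
    # NC1: within-class variability relative to between-class spread.
    within = torch.stack([features[labels == c].var(0, unbiased=False).mean() for c in classes]).mean()
    between = (means - global_mean).pow(2).sum(1).mean()
    # NC2/NC3: centered class means should be equinorm and equiangular
    # (for a simplex ETF, pairwise cosines approach -1/(C-1)).
    centered = means - global_mean
    norms = centered.norm(dim=1)
    cosines = (centered @ centered.T) / (norms[:, None] * norms[None, :])
    offdiag = cosines[~torch.eye(len(classes), dtype=torch.bool)]
    # NC4: classifier rows should align with the centered class means.
    align = F.cosine_similarity(classifier_W, centered, dim=1).mean()
    return {
        "NC1 within/between": (within / between).item(),
        "NC2 norm std/mean": (norms.std() / norms.mean()).item(),
        "NC3 mean off-diagonal cosine": offdiag.mean().item(),
        "NC4 classifier-mean alignment": align.item(),
    }
</pre>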
==== Theoretical Explanation ==== | |||
A key insight is the identification of minimization of cross-entropy loss with weight decay as an implicit regularizer driving networks toward this highly symmetric configuration. The authors prove that under simplified conditions, the ETF structure is an optimizer of the regularized loss. This aligns theoretical predictions with observed empirical behavior. | |||
==== Experiments and Findings ==== | |||
Across multiple experiments (shown through figures and plots in the paper), the authors demonstrate that: | |||
• Neural collapse becomes prominent after training accuracy hits 100%. | |||
• Even with non-linear architectures and real-world data, the ETF configuration emerges. | |||
• This behavior is observed even when networks are over-parameterized, suggesting it’s not due to constraints but rather a preference encoded by gradient descent. | |||
==== Implications and Broader Impact ==== | |||
This | Neural collapse reveals that trained neural networks inherently develop geometric simplicity and symmetry in their representations. This insight could lead to better theoretical understanding of deep learning and inspire new architectures or training methods that explicitly promote these properties. It also connects to classical ideas in signal processing and geometry, such as tight frames and simplex structures. | ||
=== | === Constructive Critique and Review === | ||
This paper offers a compelling contribution to the theoretical understanding of deep learning by identifying and rigorously analyzing the phenomenon of Neural Collapse. The authors present clear empirical evidence supported by a strong theoretical foundation, illustrating how trained deep networks tend to converge toward highly structured geometric configurations in their final training phase. This finding bridges a gap between practical neural network behavior and abstract geometric regularities, making it highly relevant to both practitioners and theorists in the field. | |||
One of the most commendable aspects of the paper is its clear articulation of four key characteristics of neural collapse, each of which is supported by intuitive visualizations and consistent experimental evidence. The authors do an excellent job demonstrating the robustness of these phenomena across a wide range of datasets and architectures. The simplified theoretical model and analytical derivations further strengthen the paper’s foundation, offering an elegant and accessible explanation of a previously undocumented training behavior. | |||
Despite these strengths, the work does leave room for further exploration. While the paper presents strong empirical results, most of the theoretical analysis is limited to simplified, idealized conditions such as deep linear networks or networks trained with weight decay. It is not yet clear how well these theoretical insights extend to more complex or non-convex training dynamics commonly used in real-world applications. Additionally, the paper focuses exclusively on classification tasks. It would be valuable to explore whether similar collapse behaviors occur in regression settings or in models trained on multi-label or sequence-based tasks. | |||
Moreover, the practical implications of neural collapse remain largely speculative. While the geometric symmetry is intellectually intriguing, the paper does not provide concrete evidence that exploiting this phenomenon leads to better performance or generalization. Future work could explore whether enforcing or encouraging neural collapse during training could yield benefits in model robustness or efficiency. | |||
Overall, the paper is well-executed and offers a fresh theoretical lens on deep learning. With additional investigation into practical applications and broader model types, this line of research could offer foundational insights into why deep networks generalize so effectively. | |||
</div> | |||
= | <div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | ||
== Group 5 Presentation: Griffin: Mixing Gated Linear Recurrences with Local Attention for Efficient Language Models== | |||
=== Presented by: === | |||
Guoqian Li, Xinlei Xu, Wenning Xu | |||
=== Paper Citation === | |||
- | De, S., Smith, S., Fernando, A., Botev, A., Cristian-Muraru, G., Gu, A., Haroun, R., Berrada, L., Chen, Y., Srinivasan, S., Desjardins, G., Doucet, A., Budden, D., Teh, Y. W., Pascanu, R., De Freitas, N., Gulcehre, C. (2024). Griffin: Mixing Gated Linear Recurrences with Local Attention for Efficient Language Models. ''arXiv''. arxiv.org/pdf/2402.19427 | ||
=== Summaries of Key Points === | |||
=== | ==== Motivation and background ==== | ||
Dominance of Transformers and their drawbacks: Transformers (with global attention) have driven most breakthroughs in large-scale language modeling. Despite their strengths, transformers struggle with quadratic complexity on long sequences; the large Key-Value (KV) cache grows linearly in the sequence length during inference, slowing down decoding for long inputs. | |||
Potential of Recurrent Models: Recurrent neural networks (RNNs) offer an appealing alternative because they summarize the sequence with a fixed-size hidden state that is updated token by token (iteratively), thus avoiding a large KV cache. However, classical RNNs can be slow to train in practice and have struggled to match the performance of Transformers at large scales. | |||
Goal: To develop recurrent architectures (and hybrids that combine local attention and recurrences) which | |||
(1) match or exceed Transformers’ performance on language modeling. | |||
(2) train efficiently at scale (both in FLOPs and hardware utilization). | |||
(3) offer improved inference efficiency (lower latency, higher throughput). | |||
=== | ==== Architecture ==== | ||
==== Hawk (MLPs with recurrent blocks) ==== | |||
(1) Recurrent Block: Replaces self-attention entirely with a diagonal recurrent design. It processes tokens sequentially with fast iteration at inference time. | |||
(2) RG-LRU Layer: The crucial piece inside the Hawk model is the Real-Gated Linear Recurrent Unit (RG-LRU), a novel RNN layer that is stable, uses simple diagonal recurrences, and includes input and recurrence gates: the recurrence gate <math> r_t </math> and input gate <math> i_t </math> each depend on the input <math> x_t </math> but not the recurrent state <math> h_{t-1} </math>. This yields stable, memory-friendly computations. The update equation mixes the previous hidden state <math> h_{t-1} </math> and a transformed input <math> x_t </math> using a diagonal recurrent weight <math> a </math>. This ensures the update is purely elementwise, minimizing overhead during inference.
(3) Achieves strong performance on language modeling while avoiding the large KV cache burden of Transformers. | |||
Provides super-exponential memory (depending on the gating mechanism) and can, in practice, capture long-range dependencies. | |||
==== Griffin (MLPs with a mixture of recurrent blocks and local attention) ==== | |||
(1) Motivation: Although recurrence alone scales well to long sequences, local patterns might still be easier to capture via (local) attention. Mixing the two yields a model that can leverage local attention for shorter-range patterns while relying on a fixed-size recurrent state for longer context. | |||
(2) Local Attention: Replaces global attention with a sliding-window approach (no quadratic cost in sequence length). The local window ensures that each token attends only to a fixed number of past tokens (see the sketch after this list).
(3) Interleaving Blocks. Griffin’s architecture interleaves the same RG-LRU-based recurrent block and a local attention block. This combination preserves efficiency (small recurrence state, bounded cache size), and matches or outperforms Transformers at the same scale. | |||
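The following is a small, self-contained Python sketch of the sliding-window (local) attention idea referenced in point (2). Function names and the masking convention are illustrative assumptions, not the paper's implementation, which also uses multi-query heads and optimized kernels.

<pre>
import torch

def sliding_window_causal_mask(seq_len: int, window: int) -> torch.Tensor:
    """Boolean mask where position i may attend only to positions j with
    i - window < j <= i (causal, fixed-size window)."""
    i = torch.arange(seq_len).unsqueeze(1)
    j = torch.arange(seq_len).unsqueeze(0)
    return (j <= i) & (j > i - window)

def local_attention(q, k, v, window: int):
    """Standard scaled dot-product attention restricted to a sliding window.
    q, k, v: (batch, heads, seq_len, head_dim)."""
    scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
    mask = sliding_window_causal_mask(q.shape[-2], window).to(scores.device)
    scores = scores.masked_fill(~mask, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v
</pre>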
== | ==== Empirical Results and Main Findings ==== | ||
=== | ==== Power-Law Scaling ==== | ||
- | Hawk, Griffin, and a baseline Transformer (using multi-query attention, MQA) all exhibit a power-law relationship between validation loss and training FLOPs, resembling established Transformer scaling laws. Griffin consistently has slightly better loss–compute tradeoffs than the baseline Transformer at all scales tested. | ||
- | ==== Performance vs. Mamba and Llama-2 ==== | ||
Hawk surpasses Mamba (a prior recurrent state-space model) on downstream tasks at 3B parameters despite training on half as many tokens (300B vs. 600B). | |||
Griffin (7B, 14B) matches or slightly outperforms Llama-2–7B/13B despite being trained on about seven times fewer tokens. | |||
Conclusion: The new architectures remain competitive on well-known benchmarks such as MMLU, HellaSwag, ARC, PIQA, and WinoGrande, even against popular open-source Transformers. | |||
==== Implementation Challenges ==== | |||
RG-LRU’s diagonal recurrence is memory-bound on TPUs (and GPUs), as it performs elementwise updates rather than large matrix multiplications.The paper introduces a custom kernel (using Pallas on TPU-v3) that reads and writes hidden states in a linear scan and minimizes memory transfers, greatly speeding up training. | |||
==== Speed Gains for Long Sequences ==== | |||
- | For short sequence lengths (e.g. 2K tokens), Griffin trains at a speed on par with the MQA baseline.As sequence lengths grow (4K, 8K), the MQA baseline slows down due to attention complexity, whereas the recurrent models maintain similar runtime. Hence, Hawk/Griffin gain a training-speed advantage at longer contexts. | ||
==== Throughput and Latency Gains ==== | |||
Transformers (even MQA) rely on a KV cache growing linearly in sequence length. The overhead can become larger than the model parameters at long decode lengths. | |||
Hawk and Griffin keep a small, fixed-size recurrence state. This yields: | |||
(1) Lower latency: less reading/writing of large caches for extended contexts. | |||
(2) Higher throughput: bigger batch sizes fit into memory because the small recurrence state frees device memory otherwise used by attention KV caches. | |||
==== Additional Observations (Long Context and Copy/Retrieve Tasks) ==== | |||
==== Extrapolation to Long Sequences ==== | |||
Hawk and Griffin can successfully exploit contexts longer than those used during training, showing better extrapolation than Transformers with typical positional embeddings (e.g. RoPE). Training them explicitly on 8K sequences (rather than 2K) further improves performance at longer contexts. | |||
==== Copying/Retrieval Skills ==== | |||
Hawk alone may learn copy tasks but can sometimes lag behind Transformers in speed of learning. This matches prior observations about slow adaptation of purely recurrent state-space layers on synthetic tasks. Griffin (mixing local attention + recurrence) can learn copy tasks more quickly, similar to Transformers, and can extrapolate better to sequences beyond the training length. | |||
==== Significance and Contributions ==== | |||
==== Challenging the Status Quo ==== | |||
- | This work demonstrates that RNN-based architectures can be trained at large scale as efficiently as Transformers, with the added benefit of small fixed-size hidden states. | ||
=== | ==== High Performance at Fewer Tokens ==== | ||
- | Griffin achieves or exceeds Llama-2 performance while using substantially fewer training tokens, highlighting that it is not necessary to rely purely on full-sequence attention. | ||
=== | ==== Practical Inference Advantage ==== | ||
Significantly reduced inference latency and improved throughput on long sequences is a major practical benefit in real-world applications when generating long responses and streaming data. | |||
=== | ==== Blueprint for Hybrid Models ==== | ||
By combining local attention (for short-range patterns) with a gated linear recurrence (for stable, long-range capacity), Griffin strikes a balance that may inform future research in large language models beyond classical Transformers. | |||
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 5: Griffin: Mixing Gated Linear Recurrences with Local Attention for Efficient Language Models == | |||
=== Presenters === | |||
Guoqian Li, Xinlei Xu, Wenning Xu | |||
=== | === Paper Citation === | ||
- | De, S., Smith, S., Fernando, A., Botev, A., Cristian-Muraru, G., Gu, A., Haroun, R., Berrada, L., Chen, Y., Srinivasan, S., Desjardins, G., Doucet, A., Budden, D., Teh, Y. W., Pascanu, R., De Freitas, N., Gulcehre, C. (2024). Griffin: Mixing Gated Linear Recurrences with Local Attention for Efficient Language Models. arXiv. arxiv.org/pdf/2402.19427 | ||
=== Introduction === | |||
Researchers at Google DeepMind have introduced a new approach to language modelling that combines the strengths of recurrent neural networks (RNNs) and Transformers. Their work, titled "Griffin: Mixing Gated Linear Recurrences with Local Attention for Efficient Language Models," presents a novel architecture that aims to overcome the limitations of traditional methods. | |||
=== | === Background === | ||
RNNs were foundational in the early days of deep learning and natural language processing, demonstrating success in various applications, including machine translation. However, the Transformer architecture has since become dominant, achieving superior performance and hardware efficiency. Despite their success, Transformers face challenges in scaling to long sequences due to the computational demands of global attention and the increasing memory required for the Key-Value (KV) cache during inference. | |||
=== Main Idea === | |||
The authors propose a hybrid model called "Griffin" that mixes gated linear recurrences with local attention mechanisms. This design aims to achieve the efficiency of RNNs in handling long sequences while maintaining the performance of Transformers. The core component of their recurrent architecture is a novel gated linear recurrent layer called the Real-Gated Linear Recurrent Unit (RG-LRU). | |||
=== Experiments === | |||
The researchers conducted several experiments to evaluate their models: | |||
• They compared the scaling behaviour of their models (Hawk and Griffin) against a Multi-Query Attention (MQA) Transformer baseline, examining the relationship between held-out loss and training FLOPs. | |||
• They assessed the models' performance on downstream tasks, comparing them to Mamba-3B and Llama-2. | |||
• They measured training speeds on TPU-v3 devices. | |||
• They evaluated inference speed, considering latency and throughput. | |||
• They tested the models' ability to handle long contexts and perform copying and retrieval tasks. | |||
=== Results === | |||
The key findings of the paper include: | |||
• Griffin demonstrates comparable scaling performance to Transformers. | |||
• Griffin matches the performance of Llama-2 while being trained on significantly less data. | |||
• Griffin and Hawk exhibit comparable training efficiency to Transformers on TPU-v3s. | |||
• Griffin achieves higher throughput and lower latency during inference, especially with longer sequences. | |||
• Griffin demonstrates strong extrapolation capabilities on long sequences and performs well on copying and retrieval tasks. | |||
- | In conclusion, the Griffin architecture presents a promising direction for language models, offering a balance between performance, efficiency, and the ability to handle long-range dependencies in sequences. | ||
- | </div> | ||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 5: Griffin: Mixing Gated Linear Recurrences with Local Attention for Efficient Language Models == | |||
=== | === Presenters === | ||
Guoqian Li, Xinlei Xu, Wenning Xu | |||
=== Paper Citation === | |||
- | De, S., Smith, S., Fernando, A., Botev, A., Cristian-Muraru, G., Gu, A., Haroun, R., Berrada, L., Chen, Y., Srinivasan, S., Desjardins, G., Doucet, A., Budden, D., Teh, Y. W., Pascanu, R., De Freitas, N., Gulcehre, C. (2024). Griffin: Mixing Gated Linear Recurrences with Local Attention for Efficient Language Models. arXiv. arxiv.org/pdf/2402.19427 | ||
- | === Research Motivation === | ||
Recurrent Neural Networks (RNNs) compress the entire sequence into a fixed-size hidden state, which is updated through iterations. | |||
Transformers outperform RNNs by employing multi-layer perceptrons (MLPs) and multi-head attention (MHA). The complexity of global attention in Transformers is quadratic, while the growth of the Key-Value (KV) cache increases linearly. With Multi-Query Attention (MQA), the cache continues to grow linearly with the sequence length.
=== Contribution === | |||
RG-LRU layer: a novel gated linear recurrent layer to replace MQA | |||
Hawk: MLPs with recurrent blocks | |||
Griffin: MLPs with a mixture of recurrent blocks and local attention | |||
=== Key Findings === | |||
1. The held-out loss decreases as more training FLOPs are used. Griffin achieves slightly lower held-out loss across all model sizes. | |||
2. Improved performance. Hawk-3B outperforms Mamba-3B on downstream tasks. Griffin-7B and Griffin-14B perform similarly to Llama-2 but were trained on approximately seven times fewer tokens. | |||
3. Comparable training efficiency to Transformers on TPU-v3. | |||
4. Griffin achieves significantly higher throughput than MQA transformers. | |||
5. Griffin performs better when evaluated on sequences longer than those seen during training. | |||
6. Griffin performs less effectively than Transformers on copy and exact retrieval tasks without fine-tuning. | |||
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;">
== | == Group 5: Griffin: Mixing Gated Linear Recurrences with Local Attention for Efficient Language Models == | ||
=== Presenters === | |||
Guoqian Li, Xinlei Xu, Wenning Xu | |||
=== Paper Citation === | |||
De, S., Smith, S., Fernando, A., Botev, A., Cristian-Muraru, G., Gu, A., Haroun, R., Berrada, L., Chen, Y., Srinivasan, S., Desjardins, G., Doucet, A., Budden, D., Teh, Y. W., Pascanu, R., De Freitas, N., Gulcehre, C. (2024). Griffin: Mixing Gated Linear Recurrences with Local Attention for Efficient Language Models. arXiv. arxiv.org/pdf/2402.19427 | |||
=== Summary ===
1. Introduction of RG-LRU Layer
The paper proposes the Real-Gated Linear Recurrent Unit (RG-LRU), a novel gated linear recurrent layer. This layer combines the stability of linear recurrences with input-dependent gating mechanisms inspired by LSTMs and GRUs, enhancing the model's ability to handle long sequences efficiently. | |||
2. Hybrid Architecture (Griffin) | |||
The authors introduce Griffin, a hybrid model that integrates RG-LRU layers with local attention mechanisms. This combination leverages the strengths of both recurrent neural networks (efficient long-sequence handling) and attention mechanisms (local context modeling), achieving superior performance compared to pure RNNs or Transformers. | |||
3. Scalability and Efficiency | |||
The paper demonstrates that Griffin and Hawk (a pure RNN variant) scale efficiently with model size, following power-law scaling similar to Transformers. Griffin matches or exceeds the performance of larger models like Llama-2 while being trained on significantly fewer tokens (6 times fewer).
4. Hardware Optimization | |||
The authors address the challenge of efficiently training diagonal RNNs on TPUs by developing a custom Pallas kernel for the RG-LRU layer. This optimization minimizes memory transfers, achieving near 3x speedup over naive implementations and ensuring competitive training speeds with Transformers. | |||
5. Inference Advantages | |||
Griffin and Hawk exhibit lower latency and higher throughput during inference, especially for long sequences, due to their fixed-size state (unlike the linearly growing KV cache in Transformers). This makes them practical for real-time applications and large-scale deployments. | |||
6. Long-Context Extrapolation | |||
The models show remarkable ability to extrapolate to sequences much longer than those seen during training. Griffin, in particular, maintains strong performance on long-context tasks, outperforming Transformers in such scenarios. | |||
7. Copying and Retrieval Capabilities | |||
The paper explores the models' performance on synthetic tasks like selective copying and induction heads. Griffin, with its hybrid design, matches Transformer performance on these tasks, while pure RNNs (Hawk) lag behind. This highlights the importance of combining recurrence and attention for such capabilities. | |||
8. Empirical Validation | |||
Extensive experiments validate the models' performance across various benchmarks (MMLU, HellaSwag, etc.), demonstrating that Griffin achieves competitive or better results than state-of-the-art baselines, even with reduced computational budgets.
- | === Impact === | ||
The work advances the field by offering a viable alternative to Transformers, particularly for long-sequence tasks, with significant improvements in training and inference efficiency. The hybrid design of Griffin sets a new benchmark for balancing performance and computational cost in large language models. | |||
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;">
== Group 6 Presentation: Learning to (Learn at Test Time): RNNs with Expressive Hidden States ==
=== Presented by: ===
Zhiyang Cheng and Pingchu Zhang | |||
=== Paper Citation ===
Sun, Y., Li, X., Dalal, K., Xu, J., Vikram, A., Zhang, G., Dubois, Y., Chen, X., Wang, X., Koyejo, S., Hashimoto, T., & Guestrin, C. (2024). Learning to (Learn at Test Time): RNNs with Expressive Hidden States. arXiv. https://doi.org/10.48550/arXiv.2407.04620. | |||
=== Summaries of key points === | |||
This paper revisits the traditional role of RNN hidden states, proposing that they can do more than just store information—they can enable learning at test time. The authors introduce a method where a hypernetwork dynamically generates the RNN’s weights as a function of its hidden state, allowing the model to adapt its behavior on the fly. This gives rise to what they call expressive hidden states, which encode both memory and the capacity to steer the model’s future updates. The approach effectively blurs the line between training and inference, treating test-time prediction as a form of continual adaptation. This results in stronger performance in settings like few-shot learning and online learning, where flexibility and rapid adaptation are crucial. Rather than relying on explicit optimization at test time (as in typical meta-learning setups), the RNN itself becomes the learner, continuously reshaping its internal dynamics based on the sequence it's processing. | |||
While innovative, the method introduces nontrivial architectural and computational overhead. The use of a hypernetwork to produce weights at every time step means the model must manage a more complex parameter space and could become less scalable for long sequences or larger models. There's also the risk of instability, since small changes in the hidden state can lead to large changes in the generated weights. Regularization and careful design are needed to prevent the model from diverging. Another limitation is that while the paper shows strong performance on synthetic and controlled few-shot learning tasks, it doesn’t extensively benchmark on more complex natural language or real-world sequential data, leaving questions about generalization and practicality open. | |||
=== Clear explanations to aid understanding ===
In a standard RNN, the weights are fixed during inference—you feed in tokens or sequence elements, and the hidden state updates based on those fixed rules. What this paper suggests is: what if the hidden state itself could influence the rules? So instead of always using the same weights, the RNN can generate new ones depending on what it's seen so far. This is done using a hypernetwork—a small neural network that outputs the weights for the main RNN. So as the RNN processes a sequence, it effectively reshapes its own behavior to fit the task or data distribution it's encountering. It’s like the RNN is learning while it’s making predictions, adapting in real-time to maximize performance without needing gradient descent at test time. | |||
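
A conceptual toy of this hypernetwork framing is sketched below. The names (H, step) and shapes are purely illustrative assumptions, not the authors' implementation; the point is only that the weights applied at each step are themselves a function of the hidden state.

<pre>
import numpy as np

rng = np.random.default_rng(0)
d_in, d_h = 4, 6
# "Hypernetwork" H: maps the hidden state to a flattened weight matrix for the cell.
H = rng.normal(scale=0.1, size=(d_h * (d_in + d_h), d_h))

def step(h, x):
    W = (H @ h).reshape(d_h, d_in + d_h)         # weights generated from the current hidden state
    return np.tanh(W @ np.concatenate([x, h]))   # ordinary RNN-style update with the generated weights

h = rng.normal(scale=0.1, size=d_h)              # small random initial state so W is non-trivial
for x_t in rng.normal(size=(10, d_in)):          # toy input sequence
    h = step(h, x_t)
</pre>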
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;">
== Group 6 Presentation: Learning to (Learn at Test Time): RNNs with Expressive Hidden States ==
=== Presented by: === | |||
Zhiyang Cheng and Pingchu Zhang | |||
=== Paper Citation === | |||
Sun, Y., Li, X., Dalal, K., Xu, J., Vikram, A., Zhang, G., Dubois, Y., Chen, X., Wang, X., Koyejo, S., Hashimoto, T., & Guestrin, C. (2024). Learning to (Learn at Test Time): RNNs with Expressive Hidden States. arXiv. https://doi.org/10.48550/arXiv.2407.04620. | |||
=== Summaries of Key Points === | |||
==== Introduction of a New Research Question ==== | |||
This paper investigates whether incorporating sentence-level context—information from surrounding sentences—can improve the quality of translations produced by Statistical Machine Translation (SMT) systems. Traditional SMT models typically translate each sentence independently, ignoring the potential benefits of broader contextual information. The authors aim to quantify how much context matters and to what extent it can enhance translation fluency, coherence, and accuracy. | |||
==== Design of a Context-Aware Re-Ranking Model ==== | |||
To evaluate the usefulness of sentence-level context, the authors develop a re-ranking model that operates on the n-best translation outputs of a baseline SMT system. This model uses a discriminative classifier trained to distinguish between better and worse hypotheses using features that include lexical overlap, syntactic consistency, and contextual similarity with adjacent sentences. By integrating these features, the system is able to promote translations that are more coherent within the broader discourse. | |||
==== Use of Diverse and Realistic Datasets ==== | |||
The authors test their model on two datasets with naturally occurring multi-sentence structures: Europarl (parliamentary proceedings) and OpenSubtitles (dialogue-driven subtitles). These corpora represent both formal and conversational genres, providing a comprehensive testbed for evaluating the effectiveness of context. The subtitle data, in particular, presents challenges such as short, ambiguous sentences that strongly benefit from contextual cues. | |||
==== Evaluation Through Both Automated and Human Measures ==== | |||
The proposed system shows consistent, though modest, improvements in BLEU scores compared to the baseline. However, human evaluation reveals clearer gains in fluency, referential consistency, and discourse-level cohesion. These results suggest that standard metrics may underestimate the value of context and highlight the importance of human judgment in translation assessment. | |||
==== Contributions to Future Directions in MT ==== | |||
While the overall performance boost is not dramatic, this paper plays an important role in shifting attention toward discourse-aware translation. It lays the groundwork for future research in context modeling, which later becomes central to neural machine translation approaches. The authors also advocate for more nuanced evaluation techniques that capture translation quality beyond sentence-level accuracy.
=== Constructive Critique and Review === | |||
This paper provides an insightful early investigation into the role of sentence-level context in improving the output of statistical machine translation systems. The authors tackle an important problem: the lack of inter-sentential coherence in SMT, which treats each sentence independently. Their proposed method—context-aware reranking of translation hypotheses using a discriminative classifier—is both innovative and practical, as it builds on existing SMT outputs rather than requiring system retraining. | |||
A major strength of the paper lies in its thoughtful experimental design. By selecting two distinct corpora, Europarl and OpenSubtitles, the authors ensure that the method is evaluated in both formal and conversational settings. This choice highlights the stronger impact of context in domains with high ambiguity and short utterances, such as subtitles. The integration of automatic and human evaluation adds further depth to the analysis, revealing that improvements in fluency and coherence may be underrepresented by standard metrics like BLEU. | |||
However, there are limitations that reduce the generalizability and interpretability of the findings. The classifier’s performance is difficult to isolate due to limited ablation or feature-wise analysis. While the feature set is described, the contribution of individual context-related features remains unclear. A clearer breakdown of which contextual signals are most influential would have strengthened the practical implications of the work. | |||
Furthermore, the improvements reported are modest in terms of BLEU scores, raising questions about the tradeoff between additional model complexity and measurable gains. The paper also predates neural machine translation, and while it was forward-thinking at the time, some of its techniques may appear limited by today’s standards. Nonetheless, the core insight—that context contributes meaningfully to translation quality—is validated and influential. | |||
Overall, this is a well-motivated and carefully executed study that helped shift attention in the MT community toward discourse-aware modeling. Its methodological clarity and early focus on contextual coherence paved the way for future advances in both evaluation and translation architecture. | |||
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;">
== Group 6 Presentation: Learning to (Learn at Test Time): RNNs with Expressive Hidden States ==
=== Presented by: === | |||
Zhiyang Cheng and Pingchu Zhang | |||
=== Paper Citation === | |||
Sun, Y., Li, X., Dalal, K., Xu, J., Vikram, A., Zhang, G., Dubois, Y., Chen, X., Wang, X., Koyejo, S., Hashimoto, T., & Guestrin, C. (2024). Learning to (Learn at Test Time): RNNs with Expressive Hidden States. arXiv. https://doi.org/10.48550/arXiv.2407.04620. | |||
=== Summaries of key points === | |||
Goal: In Test-Time Training, make the hidden state itself a small model that can be updated at test time to improve sequence-modeling ability.
Background: The hidden state of RNNs usually has a fixed dimension, which limits their expressiveness.
Methodology: Each step updates <math>W</math> with a gradient step on a self-supervised task. The dual form turns multi-step updates into a single matrix operation.
Result: In short sequences, the TTT model behaves similar to existing methods. In long sequences, TTT-Linear and TTT-MLP are significantly superior to Transformer and Mamba. TTT-Linear inference speed is closer to Mamba and faster than Transformer. | |||
=== Constructive critiques or reviews === | |||
The presentation is clearly structured, and the slides include pictures and diagrams to help listeners understand better. | |||
The presenters turned on their cameras, which made the talk easier to follow.
The delivery could be slightly more fluent.
=== Clear explanations to aid understanding === | |||
TTT layer: Learns directly on the test sequence, with the update process implemented through self-supervised learning.
Efficiency optimization: Improve computing efficiency with mini-batch TTT and the dual form.
Benefits of using dual form (a small numerical check follows this list):
- Reduces memory consumption by not storing intermediate gradient matrices explicitly. | |||
- Maximizes GPU/TPU hardware utilization by using matrix multiplications instead of sequential outer products. | |||
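
The second point can be checked numerically: a sum of per-token outer products is exactly one matrix multiplication, which accelerators execute far more efficiently than a sequential loop. The arrays E and X below are random stand-ins for per-token errors and inputs.

<pre>
import numpy as np

rng = np.random.default_rng(0)
b, d = 4, 8
E = rng.normal(size=(b, d))                                 # toy per-token errors
X = rng.normal(size=(b, d))                                 # toy per-token inputs

sequential = sum(np.outer(E[s], X[s]) for s in range(b))    # b small outer products, one by one
batched = E.T @ X                                           # a single (d, b) x (b, d) matmul
print(np.allclose(sequential, batched))                     # True
</pre>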
=== Connections to related works === | |||
Mamba: Mamba uses a state-space model to capture long-range dependencies.
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;">
== Group 6 Presentation: Learning to (Learn at Test Time): RNNs with Expressive Hidden States== | |||
=== Presented by: === | |||
- Pingchu Zhang
- Zhiyang Cheng
=== Paper Citation === | |||
Sun, Y., Li, X., Dalal, K., Xu, J., Vikram, A., Zhang, G., Dubois, Y., Chen, X., Wang, X., Koyejo, S., Hashimoto, T., & Guestrin, C. (2024). Learning to (Learn at Test Time): RNNs with Expressive Hidden States. arXiv. https://doi.org/10.48550/arXiv.2407.04620 | |||
=== Background ===
For modern RNNs, performance in long contexts is limited by the expressive power of their fixed-size hidden state. Hence the authors introduce test-time training (TTT) layers.
=== Technical Contributions ===
- Introduce TTT layers, where the hidden state is a model and the update rule is self-supervised learning, offering a new research direction. | |||
- TTT-Linear, a simple implementation of TTT layers, outperforms Transformers and Mamba in evaluations.
- Improve the hardware efficiency of TTT layers through mini-batch TTT and the dual form, making TTT-Linear already a practical building block for LLMs.
=== Methodology === | |||
The key idea is to make the hidden state itself a model with weights, and the update rule a gradient step on the self-supervised loss. Then updating the hidden state on a test sequence is equivalent to training the model at test time. | |||
==== TTT as updating a hidden state ====
The hidden state is a set of weights <math>W_t</math>: processing token <math>x_t</math> amounts to taking a gradient step on a self-supervised loss, <math>W_t = W_{t-1} - \eta \nabla l(W_{t-1}; x_t)</math>, and emitting <math>z_t = f(x_t; W_t)</math>.
==== Training a network with TTT layers ====
- Training the larger network is viewed as the outer loop, and training the weights within each TTT layer as the inner loop.
- TTT layers can replace RNN or self-attention layers in any network architecture. Training a network with TTT layers also works the same way as training any other language model.
==== Learning a self-supervised task for TTT ====
The authors add some outer-loop parameters to make this task learnable.
The input <math>x_t</math> is transformed using a learnable matrix <math>\theta_K</math> to create a projection <math>\tilde x_t = \theta_K x_t</math>.
The reconstruction label is another low-rank projection <math>\theta_V x_t</math>, which can differ from the input. Then we can create a test view <math>\theta_Q x_t</math>.
Now the new self-supervised loss is <math>l(W; x_t) = \|f(\theta_K x_t; W)-\theta_V x_t\|^2</math>, and the output rule is modified to <math>z_t = f(\theta_Q x_t; W_t)</math>.
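
Putting the pieces above together, a minimal sketch of the TTT-Linear inner loop could look as follows. Assumptions: plain online SGD with a fixed inner learning rate, the simple choice <math>f(x; W) = Wx</math>, and randomly initialized views <math>\theta_K, \theta_V, \theta_Q</math>; mini-batching and the dual form are omitted.

<pre>
import numpy as np

rng = np.random.default_rng(0)
d, eta = 8, 0.1                                  # token dimension, inner-loop learning rate
theta_K, theta_V, theta_Q = (rng.normal(scale=0.1, size=(d, d)) for _ in range(3))  # outer-loop views
W = np.zeros((d, d))                             # hidden state = weights of the inner model f(x; W) = W x

def inner_grad(W, x_t):
    # Gradient of l(W; x_t) = ||W (theta_K x_t) - theta_V x_t||^2 with respect to W.
    train_view, label = theta_K @ x_t, theta_V @ x_t
    err = W @ train_view - label
    return 2.0 * np.outer(err, train_view)

outputs = []
for x_t in rng.normal(size=(16, d)):             # toy token sequence
    W = W - eta * inner_grad(W, x_t)             # update rule: one gradient step per token
    outputs.append(W @ (theta_Q @ x_t))          # output rule: z_t = f(theta_Q x_t; W_t)
</pre>

In a full model the views <math>\theta_K, \theta_V, \theta_Q</math> are trained in the outer loop together with the rest of the network.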
=== Summaries of Key Points ===
==== Motivation: Compressing Long Context Efficiently ==== | |||
Traditional Transformers handle large contexts by storing every token in a Key-Value cache, which grows linearly with sequence length and makes inference complexity scale quadratically. Modern recurrent neural networks (RNNs) like Mamba sidestep storing the entire context by having a fixed-size hidden state, which leads to linear time complexity. However, RNNs often struggle to exploit very long contexts because the fixed-size hidden state must compress a large amount of information. The authors propose a new design, Test-Time Training (TTT), that treats the hidden state as a small learnable model trained via a self-supervised loss on each incoming token—even at test time. | |||
==== TTT Layers: Hidden State as a Learner ==== | |||
The paper reframes any sequence-modeling layer as “a hidden state plus an update rule.” For TTT layers, the hidden state is itself a small parametric or nonparametric model f, and the update rule is a step of gradient descent (or other training procedure) on each new token. Thus, at every token step, the hidden state is updated by “training” f on a self-supervised objective. Concretely, one might define a corruption or partial view of the token and train the parametric model to reconstruct the hidden or relevant aspects of the token. | |||
==== Two Main Instantiations: TTT-Linear and TTT-MLP ====
The authors propose TTT-Linear, where the learner f is a simple linear mapping plus optional layer normalization and a residual connection. They also propose TTT-MLP, which uses a two-layer MLP as its learner, offering a more expressive hidden state. Both can be integrated into existing RNN-based or Transformer-based architectures in place of the usual self-attention or simple RNN blocks. Like other RNN layers, TTT layers compress all historical tokens into a fixed-size hidden state—but the learner can be updated more flexibly via gradient steps each time a new token arrives. | |||
==== Efficiency Enhancements ====
Naively computing a gradient step per token would be too slow. Two key ideas improve hardware utilization (a toy sketch of the first idea follows the list):
1. **Mini-batch TTT** processes a batch of b tokens at once to parallelize the internal gradient steps. Smaller b yields more gradient steps (and better expressiveness) but can slow performance. | |||
2. **A “dual form”** for TTT-Linear and TTT-MLP reworks the update and output computations into larger matrix multiplications, ensuring that modern accelerators (GPUs, TPUs) can exploit efficient batched operations. | |||
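
As a toy illustration of the first idea, the sketch below takes every gradient within a mini-batch at the weights carried over from the end of the previous batch, so the b gradients can be formed with one batched operation instead of b sequential steps. Only the end-of-batch weights are produced here; per-token outputs and the full dual form are omitted, and the learnable views are dropped in favour of plain reconstruction.

<pre>
import numpy as np

rng = np.random.default_rng(1)
d, b, eta = 8, 4, 0.1                            # width, mini-batch size, inner learning rate
W = np.zeros((d, d))                             # inner-model weights
X = rng.normal(size=(16, d))                     # toy token sequence

for start in range(0, len(X), b):
    batch = X[start:start + b]                   # (b, d) tokens in this mini-batch
    # All b gradients are taken at the same base weights W, so they parallelize.
    errs = batch @ W.T - batch                   # (b, d) reconstruction errors
    grads = 2.0 * errs[:, :, None] * batch[:, None, :]   # (b, d, d) per-token gradients
    W = W - eta * grads.sum(axis=0)              # apply the accumulated update once per batch
</pre>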
==== Empirical Results ==== | |||
On language-modeling benchmarks (the Pile and Books), TTT-Linear and TTT-MLP match or exceed strong baselines (Transformer and the modern RNN Mamba) across model scales (125M to 1.3B parameters). TTT-Linear typically does as well as Mamba in short context (2k tokens) but outperforms Mamba substantially in longer contexts (8k or 32k), demonstrating that the extra expressiveness helps exploit more tokens. TTT-MLP can be even more expressive at very long contexts but can be more memory intensive. The authors also show that TTT-Linear can train and infer efficiently in wall-clock time using a specialized GPU kernel, yielding near-constant inference latency as context grows (unlike the linear growth in Transformers). | |||
==== Significance and Future Work ==== | |||
TTT recasts the hidden-state update in an RNN-like layer as explicitly training a miniature model at test time—essentially “learning to learn” from each incoming token. With further improvements in tasks (beyond simple reconstruction), hardware kernels, and more expressive hidden states, TTT-based architectures may offer a new path toward efficient, high-performing sequence models for extremely long contexts. | |||
</div>
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;">
== Group 6 Presentation: Learning to (Learn at Test Time): RNNs with Expressive Hidden States== | |||
=== Presented by: === | |||
Pingchu Zhang, Zhiyang Cheng | |||
=== Paper Citation ===
Sun, Y., Li, X., Dalal, K., Xu, J., Vikram, A., Zhang, G., Dubois, Y., Chen, X., Wang, X., Koyejo, S., Hashimoto, T., & Guestrin, C. (2024). Learning to (Learn at Test Time): RNNs with Expressive Hidden States. arXiv. https://doi.org/10.48550/arXiv.2407.04620 | |||
=== Background ===
Recurrent Neural Networks (RNNs) are attractive for their linear time complexity, which makes them efficient, especially for long-context inputs. However, they’ve historically struggled to match the performance of Transformers on tasks like language modeling. One key limitation is the fixed-size hidden state of RNNs, which forces them to compress all past context into a compact representation. This compression becomes increasingly difficult as the context grows longer. | |||
Recent RNN variants like Mamba have closed the gap in scaling performance, but they still hit a ceiling: their ability to improve predictions plateaus at long context lengths (e.g., beyond 16k tokens). Transformers, in contrast, continue to benefit from more context, although at a higher computational cost due to their quadratic scaling. | |||
The authors suggest that this limitation is tied to the expressive capacity of the hidden state. Inspired by how large language models compress vast datasets into their weights through training, they explore whether a hidden state can itself be a learnable model, updated online, even during inference. | |||
=== Main Idea === | |||
The core proposal is the Test-Time Training (TTT) layer, a new kind of sequence modeling layer where the hidden state is a model, and the update rule is a self-supervised learning step. Instead of simply storing a vector or matrix, the hidden state consists of the weights of a small model (like a linear function or a 2-layer MLP). These weights are updated at each time step using gradient descent based on a self-supervised loss.
Key points: | |||
The update happens at test time, not just during training—hence “test-time training.” | |||
The layer sees each input token as a new self-supervised learning opportunity, updating its internal model to better predict the next token. | |||
This approach allows the hidden state to grow in complexity without growing in size—it gains depth by learning, not by storing. | |||
Two instantiations are tested: | |||
TTT-Linear, where the hidden state is a linear model. | |||
TTT-MLP, where the hidden state is a 2-layer MLP. | |||
This method can be used in place of RNN or attention layers, and is compatible with existing architectures. Despite its novel structure, it can be trained end-to-end like other language models. | |||
To make this practical on hardware, the authors also design efficient mini-batch updates and a dual form of the forward pass that enables good GPU utilization. These tricks allow them to run TTT layers efficiently, even faster than Transformers in some regimes. | |||
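
Because a TTT layer is described as a drop-in replacement for attention or recurrent layers, one way it might be wired into a pre-norm residual block is sketched below (PyTorch-style). The ttt_layer argument is a hypothetical stand-in module, not the authors' implementation; a plain nn.Linear is passed in just to show the wiring.

<pre>
import torch
import torch.nn as nn

class BlockWithTTT(nn.Module):
    """Sketch: pre-norm residual block with a TTT layer in place of self-attention."""
    def __init__(self, d_model: int, ttt_layer: nn.Module):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.ttt = ttt_layer                     # stand-in for TTT-Linear / TTT-MLP
        self.mlp = nn.Sequential(nn.Linear(d_model, 4 * d_model), nn.GELU(),
                                 nn.Linear(4 * d_model, d_model))

    def forward(self, x: torch.Tensor) -> torch.Tensor:   # x: (batch, seq_len, d_model)
        x = x + self.ttt(self.norm1(x))          # sequence mixing done by the TTT layer
        x = x + self.mlp(self.norm2(x))          # channel mixing, as in a Transformer block
        return x

block = BlockWithTTT(d_model=64, ttt_layer=nn.Linear(64, 64))
out = block(torch.randn(2, 10, 64))              # (2, 10, 64)
</pre>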
=== Experiments & Results ===
The authors evaluate TTT-Linear and TTT-MLP against two baselines: a strong Transformer and Mamba (a recent high-performing RNN). They focus on both performance and efficiency, testing across different model sizes and context lengths. | |||
1. Short Context (2k and 8k tokens) | |||
At 2k tokens, TTT-Linear, Mamba, and Transformer perform similarly. | |||
At 8k tokens, TTT models outperform Mamba. This shows that as context grows, the test-time learning approach starts to shine. | |||
TTT-MLP generally has better perplexity than TTT-Linear, but is slower due to its more complex hidden state. | |||
2. Long Context (up to 32k tokens) | |||
Experiments on the Books3 subset of The Pile show that Mamba's performance plateaus after 16k tokens. | |||
In contrast, TTT models (especially TTT-MLP) continue to improve, similar to how Transformers behave. | |||
TTT-MLP performs best at long context, consistent with its higher expressivity. | |||
3. Latency and Efficiency | |||
In terms of wall-clock time, TTT-Linear is already faster than Transformers at 8k tokens and matches Mamba. | |||
For token generation (decode time), TTT-Linear and Mamba have much lower latency than Transformers. | |||
These efficiency gains are achieved thanks to GPU-aware design, including the use of mini-batch updates and matrix-optimized dual formulations. | |||
4. Scaling and FLOPs | |||
TTT-Linear uses fewer FLOPs than both baselines at equivalent perplexity. | |||
TTT models perform well under the same training compute budgets (following the Chinchilla recipe). | |||
They also maintain quality under increasing model sizes—from 125M to 1.3B parameters. | |||
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;">
== Group 6 Presentation: Learning to (Learn at Test Time): RNNs with Expressive Hidden States ==
=== Presented by: === | |||
Zhiyang Cheng and Pingchu Zhang | |||
=== Paper Citation === | |||
Sun, Y., Li, X., Dalal, K., Xu, J., Vikram, A., Zhang, G., Dubois, Y., Chen, X., Wang, X., Koyejo, S., Hashimoto, T., & Guestrin, C. (2024). Learning to (Learn at Test Time): RNNs with Expressive Hidden States. arXiv: https://doi.org/10.48550/arXiv.2407.04620 | |||
=== What’s the Big Idea? ===
Most RNNs have fixed-size hidden states—they crunch all previous context into a small box and just hope for the best. But what if the hidden state could *learn* and *adapt* as it reads? That’s what this paper is about: making the hidden state itself into a **tiny learnable model** that updates itself on the fly. Think of it like an RNN that’s learning while it's predicting. Cool, right? | |||
They call this approach **Test-Time Training (TTT)**—the model gets smarter with every token it sees, using a self-supervised loss. It’s not just training before deployment; it's adapting in real-time. | |||
=== Key Concepts === | |||
- **TTT Layer**: The hidden state isn’t just a vector anymore—it’s a model (like a mini linear function or MLP) that gets updated at each step. | |||
- **TTT-Linear** and **TTT-MLP**: Two variations—one simple, one more expressive. | |||
- **Dual-form + Mini-batch TTT**: Optimized versions that make training/inference more efficient on GPUs. | |||
=== Why This Matters === | |||
RNNs are fast and efficient—but they've always been limited by their fixed memory. TTT gives them a much-needed upgrade: they can now change behavior on the fly without relying on massive attention layers or key–value caches. | |||
This method: | |||
- Helps models generalize better in few-shot and online learning scenarios. | |||
- Beats Mamba and Transformers in long-context tasks (up to 32k tokens). | |||
- Runs faster at inference time with lower memory use. | |||
=== How It Works === | |||
1. Each token triggers a **self-supervised update** to the hidden state model. | |||
2. The update is like a mini training step: it tweaks the model to better predict the next token. | |||
3. This makes the RNN act like it’s doing meta-learning in real time. | |||
=== Experimental Highlights === | |||
- **Short context (2k–8k tokens)**: TTT-Linear/MLP performs as well as or better than Transformers and Mamba. | |||
- **Long context (up to 32k tokens)**: TTT significantly outperforms Mamba, especially in perplexity. | |||
- **Latency & FLOPs**: TTT models are lean—faster inference, lower computational cost. | |||
- **Scales well**: Works across model sizes from 125M to 1.3B parameters. | |||
=== Strengths of the Paper === | |||
- Fresh perspective—blurs the line between training and inference. | |||
- Strong empirical results—especially in long-context tasks. | |||
- Hardware-friendly—designed to run efficiently on modern GPUs. | |||
- Compatible with existing RNN/Transformer architectures. | |||
=== Some Critiques & Considerations ===
- Could get unstable—small changes in input = big changes in model weights. | |||
- Needs more testing on real-world, noisy datasets (beyond synthetic ones). | |||
- Might need careful tuning for best performance (learning rate, loss scaling, etc.). | |||
- The idea is deep—might take a few reads to fully grasp. | |||
=== How It Connects to Other Work === | |||
- TTT builds on meta-learning, but goes a step further: it adapts during inference, not just training. | |||
- It shares some goals with Mamba—efficient, long-context modeling—but does it differently. | |||
- Could inspire future models that combine learning and inference in more seamless ways. | |||
</div> | </div> | ||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;">
== Group 6 Presentation: Learning to (Learn at Test Time): RNNs with Expressive Hidden States ==
=== Presented by: === | |||
Pingchu Zhang and Zhiyang Cheng | |||
=== Paper Citation: ===
Sun, Y., Li, X., Dalal, K., Xu, J., Vikram, A., Zhang, G., Dubois, Y., Chen, X., Wang, X., Koyejo, S., Hashimoto, T., & Guestrin, C. (2024). Learning to (Learn at Test Time): RNNs with Expressive Hidden States. arXiv. https://doi.org/10.48550/arXiv.2407.04620 | |||
=== Summaries === | |||
This paper proposes a novel approach to enhance the expressiveness and adaptability of Recurrent Neural Networks (RNNs) by allowing the hidden state itself to learn at test time. The authors introduce the concept of Test-Time Training (TTT), where the hidden state is no longer a fixed vector but the parameters of a small model (e.g., a linear layer or MLP) that gets updated dynamically through self-supervised learning. | |||
Traditional RNNs update hidden states using fixed rules during inference, but TTT enables each time step to adapt based on what the model sees, making the RNN itself the learner. This results in better performance for long sequences and tasks requiring rapid, online adaptation—blurring the line between training and inference. | |||
Two versions of the TTT layer are implemented: | |||
TTT-Linear: hidden state as a linear model | |||
TTT-MLP: hidden state as a 2-layer MLP
The method is end-to-end trainable, computationally efficient due to optimizations like dual-form updates, and shows competitive or superior performance compared to Mamba and Transformers, especially at long context lengths. | |||
=== Key Contributions === | |||
Expressive Hidden States: Redefines RNN hidden states as small learnable models updated at test time. | |||
Self-Supervised Test-Time Learning: Treats each input token as a new learning opportunity to improve next-token prediction. | |||
Dual-Form Optimization: Reformulates weight updates into a matrix-based approach, improving computational efficiency and reducing memory usage. | |||
Efficiency with Performance: TTT-Linear runs faster than Transformers and matches Mamba in speed while outperforming both at longer context lengths (up to 32k tokens). | |||
Scalability: TTT models scale well across different sizes (from 125M to 1.3B parameters) and training budgets, following the Chinchilla efficiency principles. | |||
=== Constructive Critiques or Reviews === | |||
While the method introduces exciting adaptability, architectural complexity increases due to the use of hypernetworks and per-step updates. | |||
Stability concerns arise, as small changes in the hidden state could lead to unpredictable behaviors. Careful regularization is needed. | |||
Most benchmarks are conducted on controlled or synthetic tasks. Further validation on real-world NLP datasets would enhance the paper’s practical impact.
Although efficient, TTT-MLP introduces latency overhead compared to TTT-Linear, limiting its practicality for latency-sensitive applications.
=== Related Works === | |||
Mamba: A high-performance RNN that uses state-space models for long-range dependency, but lacks test-time adaptability.
Meta-Learning Approaches: TTT shares the spirit of meta-learning but avoids explicit test-time optimization.
HyperNetworks: The idea of dynamically generated weights draws on prior work in hypernetworks, but TTT applies it in an online, token-wise setting. | |||
Gradient-Based Test-Time Adaptation (TTT++): Prior methods apply gradient steps at inference but often require task-specific objectives—TTT generalizes this by embedding learning into the RNN dynamics. | |||
Efficient Transformer Variants: While methods like FlashAttention and Longformer improve Transformer scalability, they do not adapt during inference the way TTT does. | |||
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 6 Presentation: Learning to (Learn at Test Time): RNNs with Expressive Hidden States ==
=== Presented by: === | |||
Pingchu Zhang and Zhiyang Cheng | |||
=== Paper Citation: === | |||
Sun, Y., Li, X., Dalal, K., Xu, J., Vikram, A., Zhang, G., Dubois, Y., Chen, X., Wang, X., Koyejo, S., Hashimoto, T., & Guestrin, C. (2024). Learning to (Learn at Test Time): RNNs with Expressive Hidden States. arXiv. https://doi.org/10.48550/arXiv.2407.04620 | |||
=== Summaries === | |||
To address the performance limits of modern RNNs in long contexts, which stem from the limited expressive power of their fixed-size hidden state, this paper proposes a new class of sequence modelling layers with linear complexity and an expressive hidden state.
=== Key Contributions === | |||
This paper introduces TTT layers, where the hidden state is a model and the update rule is self-supervised learning. The paper's implementation of TTT layers, called TTT-Linear, outperforms Transformers and Mamba in evaluations. Moreover, it also improves the hardware efficiency of TTT layers through mini-batch TTT and the dual form, making TTT-Linear a practical building block for LLMs.
=== Constructive Critiques or Reviews === | |||
Open challenges remain, including memory I/O bottlenecks for TTT-MLP and scaling to billion-parameter models and million-token contexts.
While this approach allows the models to adapt spontaneously using only hidden-state dynamics, this limits learning capacity compared to models that update weights or use attention mechanisms. In comparison, Transformers model long-range dependencies more effectively through self-attention. This can leave TTT-based RNNs underperforming on complex tasks requiring long memory or a global context. Nonetheless, the efficiency gains of this approach may be a step towards efficient sequence prediction.
</div>
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 6: Learning to (Learn at Test Time) ==
=== Presented by: ===
Pingchu Zhang and Zhiyang Cheng | |||
=== Overview === | |||
One of the main issues with the transformer architecture is the | |||
quadratic nature of attention mechanism. Specifically, when computing | |||
<math> \text{Softmax}(QK^T) </math>. The size of QK grows | |||
quadratically with the length of the text. As such we may look to | |||
other architectures with the hope that they may alleviate some of | |||
these concerns. For example, RNNs do not suffer from this issue | |||
(scaling with O(n) complexity), but they do suffer from a unique issue | |||
of their own - still related to sequence length. Specifically, RNNs | |||
struggle to recall the contributions of earlier tokens as the sequence | |||
length grows. This is due to the nature of RNN layers, which require
that context be compressed into a hidden layer of a fixed size. This | |||
fact is explored in and attributed to OpenAI's scaling law paper | |||
(Kaplan, et al, 2020) where they showed that LSTMs (RNN subtype) do | |||
not scale like transformers or make use of exceptionally long context. | |||
The authors introduce a solution to this compression issue in RNNs, by | |||
introducing "TTT" layers, which they demonstrate as outperforming both | |||
Transformers and Mamba models for long context problem settings. | |||
=== Governing Principles === | |||
The authors design TTT layers (Test-time training), which update the | |||
hidden state at test time, which they assert as being equivalent with | |||
test time training, hence the nomenclature. They introduce two such | |||
layers, the TTT-Linear and TTT-MLP (multilayer perceptron). Where in | |||
the former the hidden state consists of a linear model and the latter | |||
a two layer MLP. The key difference is that in the naive RNN | |||
implementation, the system state is a vector, whereas in the TTT, it | |||
could be the weights (W) of a linear layer or MLP layer, allowing for | |||
far more information to be stored and retained. Further, by | |||
representing the hidden state with a parameterized model, it can be | |||
refined at test time with a self supervised update. | |||
=== Implementation Details === | |||
For the proposed TTT model, the output of a given layer is
<math> z_t = f(x_t; W_t) </math>, where <math> z_t </math> is the
output, <math>x_t</math> the input token, and <math>W_t</math> the learned model
parameters. The key innovation is the addition of self-supervised /
test-time learning. That is, the model is able to assess its own
ability to retrieve information about the sequence given its internal
state, and update these internal states to better retain more | |||
information. This allows the model to take full advantage of the | |||
learnable parameters it replaced the standard embedding with, using | |||
them as an efficient means of compressing information. | |||
=== Discussion / Conclusions === | |||
With TTT layers the authors prove that RNNs can compete with SOTA | |||
transformer models (at the time of publishing). Specifically, they
demonstrate through means of a perplexity metric measured against | |||
sequence length that TTT layers are able to stave off the usual | |||
pitfalls of RNN/LSTMs, effectively capturing longer-range token | |||
dependencies. They further show that TTT-MLP performs slightly better | |||
than its linear counterpart, but this difference is less relevant | |||
outside of exceptionally large contexts. | |||
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 6 Presentation: Learning to (Learn at Test Time): RNNs with Expressive Hidden States == | |||
=== Key Innovations and Highlights: === | |||
1. Test-Time Training Layers: | |||
The hidden state isn't a static vector anymore but rather a mini-model itself—like a linear layer or a small neural network. | |||
Updating this hidden state involves a gradient descent step on a self-supervised learning objective, effectively training the model incrementally on every new token it encounters—even during inference. | |||
2. Two Variants of TTT: | |||
TTT-Linear: Uses a linear model as the hidden state. | |||
TTT-MLP: Employs a two-layer MLP, offering more expressive power. | |||
3. Linear Complexity with Transformer-level Expressivity: | |||
Both TTT models maintain linear complexity, crucial for efficiency in handling long sequences. | |||
Unlike traditional RNNs, which plateau in performance beyond certain sequence lengths, TTT layers continue to reduce perplexity effectively as the context grows (demonstrated clearly at 32k tokens). | |||
4. Efficiency and Performance: | |||
TTT-Linear outperforms Transformers and matches/exceeds Mamba, especially at long context (8k and beyond), while significantly reducing the number of computational operations (FLOPs). | |||
With optimizations like mini-batch TTT and a dual-form computation, TTT-Linear matches the computational efficiency of the most advanced RNNs like Mamba. | |||
=== Enhancing Understanding: What Makes TTT Different? === | |||
Think of traditional RNNs as short-term memory devices that quickly become overwhelmed with too much information. Transformers, on the other hand, carry around an ever-growing notebook (Key-Value caches) that's comprehensive but computationally expensive to flip through. TTT layers are like having a dynamic notebook whose pages continually rewrite themselves, efficiently storing key insights learned from recent information and discarding less relevant details. | |||
By treating inference as continuous "micro-training," the model consistently refines its internal understanding, maintaining richer context representations without the typical constraints of fixed-size hidden states. | |||
=== Constructive Critiques and Suggestions: === | |||
1. Memory and I/O Bottlenecks in TTT-MLP: Although highly promising in terms of performance, TTT-MLP suffers from increased memory overhead. Addressing this could further unlock its potential. Possible future work could include more optimized implementations or exploring more compact yet expressive intermediate states. | |||
2. Hyperparameter Sensitivity: Mini-batch size in TTT training greatly impacts performance and computational efficiency. Further research might systematically explore adaptive strategies for selecting this hyperparameter dynamically based on context length or sequence complexity.
3. Backbone Compatibility: The authors show better performance using the Mamba backbone architecture, which involves temporal convolutions. It raises a question: Would TTT layers achieve similar gains when integrated with alternative backbones or hybrid approaches? | |||
=== Connections to Related Work: === | |||
Fast Weight Programmers and Hebbian Networks: The concept of updating internal model parameters dynamically has been explored before (e.g., Fast Weight Networks, linear attention models). However, explicitly integrating gradient descent steps as part of inference significantly expands the practical and theoretical possibilities. | |||
Modern RNN Architectures (e.g., Mamba, RWKV): TTT can be viewed as the next evolution of modern RNNs, building upon recent innovations in structured state-space models but overcoming their inherent limitations in handling long contexts. | |||
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 7 Presentation: Attention with Markov: A Framework for Principled Analysis of Transformers via Markov Chains == | |||
=== Presented by: === | |||
Jonathan Gallagher and Mariya Anashkina | |||
=== Paper Citation === | |||
A. V. Makkuva et al., “Attention with Markov: A Framework for Principled Analysis of Transformers via Markov Chains,” 2024, arXiv. doi: 10.48550/ARXIV.2402.04161. | |||
=== Summaries of key points === | |||
This paper offers a theoretical framework for analyzing transformers using Markov chains, providing a new way to understand how information flows and dependencies are formed across layers. The core idea is to view the self-attention mechanism as inducing a Markov process over the sequence positions, where each attention head defines a transition probability matrix. By interpreting attention through this lens, the authors derive principled explanations for phenomena like context aggregation, token mixing, and the effects of multiple layers and heads. They show that stacking layers corresponds to composing Markov transitions, meaning that deeper transformers can be understood as performing longer-range probabilistic walks over the input. This allows them to make quantitative predictions about attention spread and mixing time, and helps formalize intuitions about how depth and attention head structure influence model behavior. Importantly, the framework applies without modifying the architecture—it’s purely analytical, making it a useful tool for understanding model internals in a mathematically grounded way. | |||
The abstraction into Markov chains may oversimplify certain aspects of how attention operates—particularly when attention scores are data-dependent and influenced by non-linear operations (e.g., softmax and masking). The analysis tends to assume relatively clean or idealized cases, and may not fully capture the nuances of attention in real-world settings like language modeling with positional encoding or context-specific weighting. While the framework is insightful, it’s primarily useful for interpretability and analysis, not for improving model performance directly. Finally, the notion of interpretability here is mathematical, but not necessarily human-friendly—it gives you equations, not explanations in natural language. | |||
=== Clear explanations to aid understanding === | |||
Think of each attention head as deciding where to “move” next in the sequence—like a random walker choosing the next word to look at. If you treat those attention weights as transition probabilities, then attention turns into a kind of Markov process, where information flows step-by-step based on those probabilities. Stacking layers is like walking further in the sequence graph, and having multiple heads is like having multiple walkers with different preferences. This view lets you talk about things like mixing time—how fast information spreads—or how different tokens influence each other over multiple layers. It’s a powerful way to bridge probabilistic modeling and deep learning, especially when trying to reason about what attention is doing beyond just visualizing heatmaps. | |||
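
A small numerical illustration of this view is given below, with randomly generated attention scores and causal masking omitted. Each softmax row is a probability distribution over positions, so an attention map is a row-stochastic (Markov transition) matrix, and stacking layers corresponds to multiplying these matrices.

<pre>
import numpy as np

rng = np.random.default_rng(0)
n = 5                                            # sequence length

def attention_matrix(scores):
    # Row-wise softmax: every row becomes a probability distribution over positions.
    e = np.exp(scores - scores.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

A1 = attention_matrix(rng.normal(size=(n, n)))   # "layer 1" transition matrix
A2 = attention_matrix(rng.normal(size=(n, n)))   # "layer 2" transition matrix
two_step = A1 @ A2                               # stacking layers ~ composing Markov transitions
print(two_step.sum(axis=1))                      # each row still sums to 1 (stochastic)
</pre>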
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 7 Presentation: Attention with Markov: A Framework for Principled Analysis of Transformers via Markov Chains == | |||
=== Presented by: === | |||
Jonathan Gallagher and Mariya Anashkina | |||
=== Paper Citation === | |||
A. V. Makkuva et al., “Attention with Markov: A Framework for Principled Analysis of Transformers via Markov Chains,” 2024, arXiv. doi: 10.48550/ARXIV.2402.04161. | |||
=== Summary of Key Points ===
This paper introduces a theoretical framework that interprets transformers through the lens of Markov chains. The central idea is to model the self-attention mechanism as inducing a Markov process over token positions, where each attention head is seen as defining a transition probability matrix.

This interpretation allows researchers to:
* Quantify how information propagates across layers (context aggregation).
* Understand how attention heads mix token information.
* Relate model depth to longer-range probabilistic walks over inputs.

Layer stacking is interpreted as composing Markov transitions, and deeper networks thus perform longer walks over input sequences. This allows for formal predictions about mixing times and how tokens influence one another across layers. The framework is analytical and does not alter the transformer architecture, making it useful for interpretability.

However, it assumes idealized settings (e.g., fixed softmax structure, no masking or complex context-dependence), which may limit its real-world applicability. It’s mainly an interpretability tool and doesn't improve transformer performance directly. Moreover, while it provides a mathematically grounded view, its interpretability is theoretical rather than intuitive or human-readable.

=== Clear Explanation for Better Understanding ===
Imagine each attention head as a random walker choosing which token to move to next. The attention scores then define transition probabilities, turning attention into a kind of Markov process. As you stack layers, it’s like letting the walker take multiple steps.

Multiple attention heads mean different walkers, each with unique preferences. This model helps quantify how quickly information spreads (mixing time) and how tokens influence each other, going beyond mere attention heatmaps.

This probabilistic view bridges deep learning with classical stochastic processes, giving insights into how and why transformers work the way they do.
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 7 Presentation: Attention with Markov: A Framework for Principled Analysis of Transformers via Markov Chains== | |||
=== Presented by: === | |||
- Jonathan Gallagher | |||
- Mariya Anashkina | |||
=== Paper Citation === | |||
A. V. Makkuva et al., “Attention with Markov: A Framework for Principled Analysis of Transformers via Markov Chains,” 2024, arXiv. doi: 10.48550/ARXIV.2402.04161. | |||
=== Background and overview === | |||
'''Markov chains''' | |||
Markov chains are probability models that predict the future state based on the current and prior states of a system. Language can be modeled as a high order Markov process, an idea that was popularized by Claude Shannon in a 1948 paper, and later became a building block for natural language processing tasks in the modern day. In transformers, predictions follow closely to Markov chains. | |||
In this paper, the authors use a simplified model: a two-state (binary) Markov chain that is in its ''stationary distribution'', making the assumption that the Markov chain has been running for an extremely long time. The term stationary distribution refers to a distribution of states that a chain converges towards as the number of steps increases indefinitely; at that point, each subsequent distribution is the same as the previous one. | |||
'''Objectives in this paper''' | |||
The paper introduces a framework to systematically analyze (through the lens of a Markov chain) how transformers learn to model sequential data. The objective is to explore how a transformer succeeds or struggles on first order Markov chains, a distribution that only needs one step of memory. They also do a theoretical analysis of network architecture choices for transformers, and look at how this affects the loss landscape and learning dynamics. | |||
'''Weight tying''' | |||
This is one of the main architectural choices in transformers that are studied in this paper. This is the practice of using the same weights in the input and output layers. Firstly, this means that whether the model is trying to ''read'' or ''generate'' the text, the tokens' representations are the same -- this is done for consistency. Secondly, this can act as a form of regularization as it gives the model fewer parameters to learn. | |||
=== Theoretical Results === | |||
'''Significance of <math> p+q </math>''' | |||
Recall that we are working with binary, first-order Markov processes here. Let <math> p </math> be the probability that a state 0 will turn to 1, and let <math> q </math> be the probability that a state 1 will turn to 0. Consequently, probabilities <math> 1-p </math> and <math> 1-q </math> are the probabilities that states 0 and 1 will remain unchanged, respectively. | |||
The quantity <math> p+q </math> is referred to as a ''switching factor'', and is used to characterize the overall tendency of the chain to change state. When <math> p+q < 1</math>, the system is likely to stay in its current state. When <math> p+q > 1</math>, the system is likely to change its state, and therefore will exhibit oscillatory behaviour. | |||
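
A quick simulation makes this concrete. For the two-state chain above, the stationary distribution is <math>\pi = \left(\frac{q}{p+q}, \frac{p}{p+q}\right)</math>; the sketch below contrasts a sticky chain (<math>p+q<1</math>) with an oscillating one (<math>p+q>1</math>).

<pre>
import numpy as np

def simulate(p, q, steps=20, seed=0):
    # Two-state chain: flip probability is p when in state 0 and q when in state 1.
    rng = np.random.default_rng(seed)
    state, path = 0, []
    for _ in range(steps):
        flip = p if state == 0 else q
        if rng.random() < flip:
            state = 1 - state
        path.append(state)
    return path

print("sticky      (p+q<1):", simulate(0.1, 0.1))   # long runs of the same symbol
print("oscillating (p+q>1):", simulate(0.9, 0.9))   # frequent switching between 0 and 1
print("stationary pi for p=q=0.9:", (0.9 / 1.8, 0.9 / 1.8))   # (0.5, 0.5)
</pre>
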
The authors of the paper provide proof of the existence of the global minimum as well as bad local minima in the loss landscape. Furthermore, they look at the existence of global minima and saddle points in the cases with vs. without weight tying. | |||
'''Theorem 1 (Global minimum).''' Let the input sequence be <math>\{x_n\}_{n=1}^N \sim \bigl(\pi(p,q), P(p,q)\bigr)</math> for some fixed <math>(p,q)\in(0,1)^2.</math> Then for all <math>(p,q),</math> there exists a <math>\theta_{\star}\in\mathbb{R}^{D-d}</math> with an explicit construction such that it is a global minimum for the population loss <math>L(\cdot)</math>. | |||
In other words, the authors are able to prove that for the transformer, there exists an optimal parameter configuration. | |||
'''Theorem 2 (Bad local minimum).''' Let the input sequence be <math>\{x_n\}_{n=1}^N \sim \bigl(\pi(p,q), P(p,q)\bigr)</math> for some fixed <math>(p,q)\in(0,1)^2.</math> If <math>p+q>1,</math> there exists an explicit <math>\theta_{\pi}\in\mathbb{R}^{D-d}</math> such that it is a bad local minimum for the loss <math>L(\cdot)</math>.
'''Theorem 3 (Global minimum).''' Consider the same setting as in Thm. 1. Then for all <math>(p, q),</math> if <math>\theta_{\star} = (e_{\star} = a_{\star}, \dots, b_{\star}) \in \mathbb{R}^{D-d}</math> is a global minimum for the loss <math>L(\cdot)</math> in the weight-tied scenario, then its extension <math>\bar{\theta}_{\star} = (\bar{e}_{\star}, \bar{a}_{\star}) \in \mathbb{R}^D</math> is also a global minimum for <math>L(\cdot)</math> in <math>\mathbb{R}^D</math> in the non-weight-tied case. Further, <math>\bar{\theta}_{\star}</math> satisfies the same properties (ii)–(iv) as in Thm. 1.
'''Theorem 4 (Saddle point).''' Consider the same setting as in Thm. 3. For <math>p + q > 1,</math> let <math>\theta_{\pi} = (e_{\pi} = a_{\pi}, \dots, b_{\pi}) \in \mathbb{R}^{D-d}</math> be the corresponding bad local minimum for the loss <math>L(\cdot)</math> in the weight-tied scenario. Then its extension <math>\bar{\theta}_{\pi} = (\bar{e}_{\pi}, \bar{a}_{\pi}) \in \mathbb{R}^D</math> is a saddle point for <math>L(\cdot)</math> in <math>\mathbb{R}^D</math> in the non-weight-tied case. Further, <math>\bar{\theta}_{\pi}</math> satisfies the same properties (ii)–(iv) as in Thm. 2.
=== Main Contributions === | |||
'''Setup for the theoretical framework in this paper''' | |||
* The dataset is first-order and binary. | |||
* The model is a single-layer transformer, with a single attention head and no layer norm. | |||
* The loss function is cross-entropy population loss. Note: the transformer model doesn't know that the data is Markovian; in other words, it is allowed to use the entire sequence's history, as it doesn't know that each step is dependent only on the previous one. | |||
'''Empirical results''' | |||
'''Summary of this paper''' | |||
1. The authors introduce a theoretical framework that models the data source as a Markov process. This allows them to study how transformers learn sequential structure, contrasting it with other approaches that treat training data simply as i.i.d. | |||
2. Focusing on single-layer transformers trained for next-token prediction on first-order Markov data, the paper gives a detailed analysis of the cross-entropy loss surface: There is always a set of parameters that perfectly recover the true Markov transition probabilities (thus achieving the global minimum). When the sum of the Markov chain's flipping probabilities exceeds one, the model can converge to parameters that simply predict the stationary distribution rather than the true transition probabilities. This phenomenon does not arise (or becomes only a saddle) when weights are untied or when the transformer has multiple layers.
3. Through experiments, they show that the theoretical findings match empirical behaviors. In particular: When weights are tied, the model may learn a constant (stationary) probability and fail to leverage sequential context if the transition probabilities are above a certain threshold. Removing weight tying—or increasing the transformer's depth—helps avoid such bad local minima. | |||
4. The authors extend the analysis to higher-order processes. Surprisingly, simply increasing the transformer depth does not guarantee learning of higher-order transitions. They find, however, that restricting the attention window (rather than letting the network attend to all past tokens) dramatically improves learning of higher-order Markov patterns. | |||
=== Related work === | |||
1. Analyzing Transformer Architectures: | |||
Vaswani et al. (2017) introduced the Transformer architecture, which has since been extensively studied to understand its theoretical foundations and practical applications. | |||
Rogers et al. (2020) provided a comprehensive analysis of BERT, a model based on the Transformer architecture, discussing its linguistic capabilities and limitations. | |||
2. Combining Attention Mechanisms with Probabilistic Models: | |||
Bahdanau et al. (2015) introduced the attention mechanism in neural machine translation, allowing models to focus on relevant parts of the input sequence dynamically. | |||
Graves et al. (2014) proposed the Neural Turing Machine, combining neural networks with external memory, enabling models to learn algorithmic tasks. | |||
=== Future direction === | |||
1. Enhanced Interpretability of Transformer Models:
Applying the Markov chain framework to dissect and visualize the internal workings of Transformers could lead to more interpretable models, aiding in identifying and mitigating biases.
2. Development of Hybrid Models:
Integrating Markovian structures with attention mechanisms may result in hybrid models that leverage the strengths of both approaches, potentially improving performance in tasks requiring sequential dependencies.
3. Theoretical Analysis of Attention Dynamics:
Further exploration into the mathematical properties of attention modelled as Markov processes could provide deeper insights into the stability and convergence of Transformer-based models. | |||
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 7 Review: An Analysis of Transformers via Markov Chains == | |||
=== Presented by: === | |||
Jonathan Gallagher and Mariya Anashkina | |||
=== Paper Citation === | |||
A. V. Makkuva et al., “Attention with Markov: A Framework for Principled Analysis of Transformers via Markov Chains,” 2024, arXiv. doi: 10.48550/ARXIV.2402.04161. | |||
=== Critiques Towards Weight Tying === | |||
Potential drawbacks of weight tying includes: | |||
- Reduced Flexibility | |||
- Challenges in optimization: it can sometimes introduce gradient conflicts or hinder training stability, especially if tied layers have significantly different roles.
- Task-specific limitations: performance depends on the nature of the task.
=== Review === | |||
I thought this presentation made it really easy to see the connection between Markov chains and Transformers when tasked with predicting future tokens. The lunch menu analogy made it really easy to see how a first-order Markov chain behaves, and was a great general idea to discuss before diving deeper into how Markov chains behave in the language world. Theoretical results were also well explained before linking them to empirical results and discussing why the empirical results behave the way they do.
The empirical results were also well explained, showing how removing weight tying or increasing model depth can help escape bad local minima and ensure the transformer's parameters reach a global minimum for first-order Markov chains, and how smaller context windows are required to escape bad local minima for higher-order Markov chains, even for deeper models (regardless of depth or weight tying).
I also thought this presentation was really well structured, and found that the slides were really easy to follow with great visual aids. | |||
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 7 Review: An Analysis of Transformers via Markov Chains == | |||
=== Presented by: === | |||
Jonathan Gallagher and Mariya Anashkina | |||
=== Paper Citation === | |||
A. V. Makkuva et al., “Attention with Markov: A Framework for Principled Analysis of Transformers via Markov Chains,” 2024, arXiv. doi: 10.48550/ARXIV.2402.04161. | |||
=== Key Contributions === | |||
* A new Markov-chain framework for analyzing transformers: the sequential input is modelled as a Markov process.
* Proposed theorems, with proofs, characterizing the global minimum and bad local minima.
* For single-layer transformers, characterized the loss landscape.
* Applied these findings to higher-order Markov chains.
Key findings and observations:
* For first-order Markov chains, weight tying and the transition probabilities significantly affect the loss landscape. Weight tying means the embedding and final linear (unembedding) layers share weights.
* For single-layer transformers, weight tying may introduce bad local minima. To avoid this, the depth of the transformer can be increased.
* For higher-order Markov chains, masking the attention window is necessary to correctly predict the probabilities. Increasing the depth for higher-order Markov chains does not significantly affect the performance.
=== Explanations for details === | |||
'''Single-layer Transformer''': The single-layer transformer uses single-head attention with binary input.
'''Property of Markov chains''': Future steps in a Markov chain depend only on the most recent m steps. The transition probabilities are independent of position. The Markov chain reaches a steady state when it attains a stationary distribution <math>\pi</math> and keeps that distribution at all future steps.
'''First-order Markov Chains''': The next step depends only on the single previous step, and is independent of all earlier steps.
'''Masking''': This limits the scope of the attention layer. The attention is changed from <math>y_n=x_n+W_O \sum_{i \in [n]} att_{n,i} \cdot W_V x_i \in \mathbb{R}^d</math> to <math>y_n=x_n+W_O \sum_{i=n-W+1}^{n} att_{n,i} \cdot W_V x_i</math>, where W is the window size, i.e., the number of past symbols the model attends to. Reducing W has been found to improve the performance.
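To make the windowed attention concrete, below is a minimal NumPy sketch (my own illustration, not the paper's code) of single-head attention in which position n attends only to the last W positions; the random weight matrices and the row-vector convention are assumptions for the example.
<pre>
import numpy as np

def windowed_attention(x, W_Q, W_K, W_V, W_O, window):
    """Single-head attention where position n attends only to the last
    `window` positions i = n-window+1, ..., n (the masked variant above)."""
    n, d = x.shape
    scores = (x @ W_Q) @ (x @ W_K).T / np.sqrt(d)   # (n, n) attention logits
    mask = np.full((n, n), -np.inf)
    for i in range(n):
        lo = max(0, i - window + 1)
        mask[i, lo:i + 1] = 0.0                     # keep only the last `window` positions
    att = np.exp(scores + mask)
    att /= att.sum(axis=1, keepdims=True)           # row-wise softmax
    return x + (att @ (x @ W_V)) @ W_O              # residual form y_n = x_n + ...

# toy usage with random weights
rng = np.random.default_rng(0)
d = 8
x = rng.normal(size=(16, d))
W_Q, W_K, W_V, W_O = (rng.normal(size=(d, d)) / np.sqrt(d) for _ in range(4))
y = windowed_attention(x, W_Q, W_K, W_V, W_O, window=3)
print(y.shape)   # (16, 8)
</pre>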
=== Related Works === | |||
Many existing works have tried to better understand transformer models:
* Some works examine transformer components mechanistically (Nanda et al., Progress measures for grokking via mechanistic interpretability. In <math>\textit{The Eleventh International Conference on Learning Representations,}</math> 2023) but lack theoretical guarantees.
* There is work that focuses on how transformer models learn semantic structure (Li et al., How do transformers learn topic structure: towards a mechanistic understanding, In <math>\textit{Proceedings of the 40th International Conference on Machine Learning}</math>, 2023).
* There is work that uses an optimization lens to understand the training dynamics and implicit biases of transformers trained with gradient descent (Tarzanagh et al., Transformers as support vector machines. In <math>\textit{NeurIPS 2023 Workshop on Mathematics of Modern Machine Learning}</math>, 2023 and Tarzanagh et al., Max-margin token selection in attention mechanism. In <math>\textit{Thirty-seventh Conference on Neural Information Processing Systems}</math>, 2023).
* There is work proposing that weight tying does not positively affect encoder-only transformer models (Chung et al., Rethinking embedding coupling in pre-trained language models. In <math>\textit{International Conference on Learning Representations,}</math> 2021).
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 7 Review: An Analysis of Transformers via Markov Chains == | |||
=== Presented by: === | |||
Jonathan Gallagher and Mariya Anashkina | |||
=== Paper Citation === | |||
A. V. Makkuva et al., “Attention with Markov: A Framework for Principled Analysis of Transformers via Markov Chains,” 2024, arXiv. doi: 10.48550/ARXIV.2402.04161. | |||
=== Theoretical foundation === | |||
In 1948, Claude Shannon demonstrated that human communication can be reasonably approximated as a higher-order Markov process.
=== Research Objective === | |||
The authors propose a framework to systematically analyze how transformers learn to model sequential data using the perspective of Markov chains. | |||
By generating synthetic data from controlled Markov processes, they can precisely define the relationship between data characteristics, model architecture, and learning performance. | |||
They also aim to theoretically describe the loss landscape of transformers and determine how specific architectural decisions impact learning. | |||
Additionally, they explore how model complexity and the order of Markov chains influence the model’s ability to capture sequential dependencies. | |||
=== Key Methodology === | |||
1. Weight tying | |||
2. Attention masking | |||
=== Conclusions === | |||
For first-order Markov chains: | |||
1. Transformers can readily learn the transition dynamics.
2. When p + q > 1 with weight tying, models may get stuck predicting the stationary distribution (see the sketch after this list).
3. Removing weight tying or increasing model depth helps escape bad local minima.
For higher-order Markov chains:
1. Standard transformers struggle regardless of depth or weight tying.
2. Limiting the context window (masking) dramatically improves learning.
3. Surprisingly, deeper models require even smaller context windows.
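For item 2 above, here is a small illustrative script (my own sketch, not from the paper) that simulates the binary first-order chain with switching probabilities p = P(1|0) and q = P(0|1) and compares the empirical frequency of 1s with the stationary value <math>\pi_1 = p/(p+q)</math>, i.e., the constant prediction a tied-weight model can get stuck on.
<pre>
import numpy as np

def sample_chain(p, q, n, rng):
    """Binary first-order Markov chain with P(next=1 | cur=0) = p and
    P(next=0 | cur=1) = q."""
    x = np.empty(n, dtype=int)
    x[0] = rng.integers(2)
    for t in range(1, n):
        flip_prob = p if x[t - 1] == 0 else q
        x[t] = 1 - x[t - 1] if rng.random() < flip_prob else x[t - 1]
    return x

rng = np.random.default_rng(0)
p, q = 0.8, 0.6                      # p + q > 1: the regime where tied weights can get stuck
x = sample_chain(p, q, 200_000, rng)
print("empirical P(X=1):", x.mean())
print("stationary pi_1 :", p / (p + q))
</pre>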
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 7 Review: An Analysis of Transformers via Markov Chains == | |||
=== Presented by: === | |||
Jonathan Gallagher and Mariya Anashkina | |||
=== Paper Citation === | |||
A. V. Makkuva et al., “Attention with Markov: A Framework for Principled Analysis of Transformers via Markov Chains,” 2024, arXiv. doi: 10.48550/ARXIV.2402.04161. | |||
=== Research Objective === | |||
The authors characterize the relationship between data properties, model architecture, and learning performance by generating synthetic data from controlled Markov processes. They characterize the loss landscape of transformers and identify how specific architectural choices affect learning.
=== Key Methodology === | |||
- Weight tying: use the same weights for both the input embedding and the final output layer. This ensures that a token's representation stays consistent across the model, leading to more coherent and efficient learning (a minimal sketch is given after this list).
- Study the loss landscape of transformers when dealing with a first-order Markov chain (one step of memory).
- Then study a special class of second-order chains, where <math>X_{n+1}</math> is influenced only by <math>X_{n-1}</math>.
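As a concrete illustration of the weight-tying bullet above, here is a minimal PyTorch sketch (my own toy model, not the paper's code; the GRU body is just a stand-in for the transformer block):
<pre>
import torch
import torch.nn as nn

class TiedLM(nn.Module):
    """Toy next-token model whose output projection shares (ties) its
    weights with the input embedding."""
    def __init__(self, vocab_size=2, d_model=32):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.body = nn.GRU(d_model, d_model, batch_first=True)   # stand-in for the transformer
        self.unembed = nn.Linear(d_model, vocab_size, bias=False)
        self.unembed.weight = self.embed.weight                  # weight tying: one shared tensor

    def forward(self, tokens):            # tokens: (batch, seq) of {0, 1}
        h, _ = self.body(self.embed(tokens))
        return self.unembed(h)            # logits: (batch, seq, vocab)

model = TiedLM()
assert model.unembed.weight.data_ptr() == model.embed.weight.data_ptr()
logits = model(torch.randint(0, 2, (4, 10)))
print(logits.shape)                       # torch.Size([4, 10, 2])
</pre>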
=== Conclusions === | |||
Architectural choices such as weight tying and attention masking significantly impact transformers' ability to learn Markovian patterns. For first-order Markov chains, transformers can readily learn the transition dynamics, and removing weight tying or increasing model depth can help escape bad local minima.
For higher-order Markov chains, standard transformers struggle regardless of depth or weight tying, but limiting the attention window via masking improves learning. Moreover, deeper models require smaller context windows. This suggests that unlimited context isn't always beneficial.
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 8 Presentation: MEDUSA: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads == | |||
=== Presented by: === | |||
Nana Ye and Xingjian Zhou | |||
=== Summaries of key points === | |||
This paper introduces MEDUSA, a lightweight yet effective method for accelerating inference in large language models by attaching multiple decoding heads on top of the backbone's final hidden states. The key idea is that, during inference, you can use these heads to predict several future tokens in parallel, reducing the number of sequential steps needed. Unlike speculative decoding (which relies on a separate draft model), MEDUSA uses the same model and architecture, just augmented with extra heads that predict future tokens from the backbone's hidden states. These predictions are verified by the base model in a final pass, similar in spirit to Draft & Verify, but with much lower overhead and implementation complexity. Despite its simplicity, MEDUSA achieves competitive speedups, up to 2x, on LLaMA-based models, and integrates easily into existing transformer pipelines. It also preserves generation quality well, maintaining high accuracy across benchmarks without requiring retraining of the backbone.
One potential limitation of MEDUSA is that its performance gains depend on the quality of the extra heads' predictions; if those heads aren't accurate enough, the method may yield minimal speedup or introduce verification bottlenecks. Another concern is scalability: adding too many decoding heads could increase memory consumption or introduce architectural clutter. While the paper shows good results on standard benchmarks, it's less clear how MEDUSA performs in more complex decoding scenarios like beam search, sampling with temperature, or instruction-tuned models. Finally, although it's simple to implement, any modification to production LLM inference stacks still carries deployment costs, which MEDUSA somewhat underplays.
=== Clear explanations to aid understanding === | |||
Think of a transformer generating text one word at a time, slow because each step waits on the previous one. MEDUSA says: what if we could guess ahead a few tokens at once? It adds small prediction heads on top of the model's final hidden state, each trying to guess a token a few positions ahead. Once these guesses are made, the base model verifies them. If correct, we skip ahead; if not, we fall back. It's like speculative decoding, but self-contained: no second model, no complicated setup. You get the parallelism of speculative methods with the simplicity of just tweaking the model's architecture slightly.
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 8 Presentation: MEDUSA: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads == | |||
=== Presented by: === | |||
Nana Ye and Xingjian Zhou | |||
=== Summaries of key points === | |||
Goal: by adding multiple "Medusa heads" to the base model, multiple future tokens can be predicted at once without relying on an external draft model.
Background: the core problem behind slow inference in large models is the memory-bandwidth bottleneck. Autoregressive decoding generates tokens one by one and has low GPU utilization. Draft-model acceleration suffers from additional model overhead and distribution mismatch between the draft and main models.
Methodology: use multiple Medusa heads to predict future tokens in parallel. Candidate tokens are organized into a tree and verified with tree attention, so multiple sequences can be validated simultaneously. The longest prefix with reasonable probability is accepted using the typicality-based acceptance strategy.
Result: on the Vicuna-7B and Zephyr-7B models (evaluated on MT-Bench), Medusa 1 gives roughly a 2.2x speedup and Medusa 2 roughly 2.8x, with some tasks reaching up to 3.6x, while quality remains almost lossless.
=== Constructive critiques or reviews === | |||
The in-depth, detailed explanation let the audience understand the material more deeply.
Turning on the camera could make the presentation feel more engaging.
=== Clear explanations to aid understanding === | |||
Medusa 1: trains only the Medusa heads, to save resources.
Medusa 2: trains the main model and the Medusa heads together. Performance degradation of the main model is avoided through a two-stage strategy.
=== Connections to related works === | |||
Mamba: Mamba uses a state-space model to capture long-range dependencies.
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 8 Presentation: MEDUSA: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads== | |||
=== Presented by: === | |||
- Nana Ye | |||
- Xingjian Zhou | |||
=== Paper Citation === | |||
T. Cai et al., “Medusa: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads,” 2024, arXiv. doi: 10.48550/ARXIV.2401.10774. | |||
https://arxiv.org/abs/2401.10774 | |||
=== Background === | |||
- As the size of LLMs grows, the speed at which they can generate tokens decreases. The bottleneck is primarily the transfer of data to/from the GPU.
- Speculative sampling is an existing solution that predicts multiple future tokens at once using smaller "draft" models.
- Medusa instead solves this problem by adding multiple decoding heads and a tree-based attention mechanism to existing LLMs.
- The paper discusses the implementations of Medusa 1 and Medusa 2.
=== Main Idea === | |||
The idea is to replace the draft model with heads inside the model itself, since the best representation of a model is the model itself. The multiple heads predict multiple tokens at once to leverage parallelism. This makes decoding more efficient and provides more candidate tokens for the tree-based attention to choose from. The tree-based attention simulates the tokens being generated sequentially by traversing a tree from top to bottom, where the top is the initial word.
=== Methodology === | |||
During training, each Medusa head is optimized to predict a future token given the same input context. For example, head 3 is trained to predict token x<sub>t+3</sub> using only the context up to x<sub>t</sub>. The objective function is a standard cross-entropy loss over the predicted token distributions. | |||
To avoid error accumulation, the training data for deeper heads (e.g., t+4, t+5) is generated using gold sequences rather than model outputs. | |||
The decoding tree is constructed with a fixed depth and branching factor corresponding to the number of Medusa heads. Unlike beam search, this tree is evaluated in parallel, and scoring is done using a lightweight attention mechanism applied to the hidden states of the candidate paths. | |||
For candidate selection, a method called typical acceptance is introduced as a fast alternative to rejection sampling. It accepts candidates based on whether their token-level probabilities fall within a "typical" range, reducing the number of evaluations needed during decoding. | |||
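The following is a hedged sketch of one plausible form of the typicality test described above (my own illustration; the exact threshold shape and the constants eps and delta are assumptions, not necessarily the paper's values):
<pre>
import numpy as np

def typical_accept(probs, candidate_ids, eps=0.09, delta=0.3):
    """Accept the longest prefix of `candidate_ids` whose probability under the
    base model stays above a threshold that loosens when the base distribution
    has high entropy.  probs[k] is the base model's distribution at candidate
    position k."""
    accepted = 0
    for k, token in enumerate(candidate_ids):
        p = probs[k]
        entropy = -np.sum(p * np.log(p + 1e-12))
        threshold = min(eps, delta * np.exp(-entropy))
        if p[token] >= threshold:
            accepted += 1
        else:
            break
    return candidate_ids[:accepted]

# toy usage: two candidate positions over a 5-token vocabulary
rng = np.random.default_rng(1)
probs = rng.dirichlet(np.ones(5), size=2)
print(typical_accept(probs, candidate_ids=[int(probs[0].argmax()), 3]))
</pre>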
=== Technical Contributions === | |||
'''Medusa 1:''' | |||
- Uses a frozen pre-trained LLM and trains extra decoding heads on top
- Each additional decoding head predicts a token K time steps in the future
- Uses a cross-entropy loss whose weight decays with the number of steps into the future
- Reduces memory usage because the backbone model is only used for hidden-state extraction
- In simple terms, Medusa adds additional linear layers on top of the last hidden layer of the transformer, which are trained to predict tokens at future positions, rather than just the next token as a conventional transformer does in the usual autoregressive manner (a simplified sketch follows below).
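A simplified sketch of the heads described above (my own illustration with plain linear heads on the frozen backbone's last hidden state; the paper's heads are somewhat richer, but the interface is the same):
<pre>
import torch
import torch.nn as nn

class MedusaHeads(nn.Module):
    """K extra heads mapping the backbone's last hidden state to logits for
    tokens at offsets +1 ... +K (one linear layer per head in this sketch)."""
    def __init__(self, d_model, vocab_size, num_heads=4):
        super().__init__()
        self.heads = nn.ModuleList(
            nn.Linear(d_model, vocab_size) for _ in range(num_heads)
        )

    def forward(self, hidden):                 # hidden: (batch, seq, d_model)
        return [head(hidden) for head in self.heads]   # one logit tensor per offset

# toy usage with pretend backbone hidden states
hidden = torch.randn(1, 10, 512)
heads = MedusaHeads(d_model=512, vocab_size=32000)
logits_per_offset = heads(hidden)
print(len(logits_per_offset), logits_per_offset[0].shape)   # 4 torch.Size([1, 10, 32000])
</pre>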
'''Medusa 2:''' | |||
- Fine tunes the LLM and trains the decoding heads at the same time. | |||
- Encountered problems with high losses, so the authors switched to a two-stage training process:
- Stage 1: train only the Medusa heads (similar to Medusa 1)
- Stage 2: train both the backbone model and the Medusa heads together
'''Tree Attention''' | |||
- Tree attention is used so that heads predicting later tokens can incorporate the additional context produced by the earlier Medusa heads in the pipeline
- This tree structure is not built autoregressively, however
- The top predictions from each head are fed into the tree structure as candidate tokens
- An attention mask is used to ensure that each future-token prediction in the tree is based only on its prior tokens, not on tokens past the one being decoded
- Multiple future candidate tokens can be predicted with context-aware attention simultaneously (a minimal sketch of the candidate tree and mask construction follows below)
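Below is a minimal sketch of the candidate tree and ancestor mask (my own illustration; the paper's tree layout and pruning details may differ). Each head contributes its top predictions, candidates are their combinations, and the mask lets each node attend only to its own prefix:
<pre>
import itertools
import numpy as np

def build_candidate_tree(topk_per_head):
    """topk_per_head: e.g. [[5, 9], [2, 7]] = top-2 tokens from heads 1 and 2.
    Returns the tree nodes (token prefixes) and a mask where mask[a, b] = 1
    iff node b lies on node a's path from the root."""
    paths = itertools.product(*topk_per_head)
    nodes = []
    for path in paths:
        for depth in range(1, len(path) + 1):
            prefix = path[:depth]
            if prefix not in nodes:              # shared prefixes become a single node
                nodes.append(prefix)
    n = len(nodes)
    mask = np.zeros((n, n), dtype=int)
    for a, node_a in enumerate(nodes):
        for b, node_b in enumerate(nodes):
            if node_a[:len(node_b)] == node_b:   # node_b is an ancestor of (or equal to) node_a
                mask[a, b] = 1
    return nodes, mask

nodes, mask = build_candidate_tree([[5, 9], [2, 7]])
print(nodes)    # [(5,), (5, 2), (5, 7), (9,), (9, 2), (9, 7)]
print(mask)
</pre>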
'''Self Distillation''' | |||
- A dataset with prompts relevant to the desired model is created
- The full large language model generates outputs for these prompts in the usual autoregressive manner. These prompt/response pairs form a training dataset for the self-distillation step
- Medusa Heads are trained on the generated training dataset | |||
'''Tree Construction''' | |||
- Prune Less Promising Branches: Branches with a low probability of containing the next token are pruned from the candidate tree used by tree attention; this reduces the computational cost of MEDUSA 2
- Select the Best Candidate: From the remaining typical candidates, the longest accepted prefix is chosen for the next decoding step | |||
'''Empirical Evaluation''' | |||
Experiments on various LLMs show consistent 2–3 times speedups in practice without harming output quality (assessed by GPT-4 and other metrics). The authors also include ablation studies on key design choices (number of heads, attention structure, sampling thresholds), confirming the effectiveness and generality of the proposed framework. | |||
=== Constructive Critique === | |||
While Medusa demonstrates promising improvements in decoding speed and flexibility, a few limitations remain: | |||
* Training Stability: Especially in Medusa-2, jointly fine-tuning heads and the base model requires a two-stage schedule with learning rate separation and warmup—indicating some instability. | |||
* Model Complexity: Adding multiple Medusa heads and tree attention introduces architectural complexity, which may hinder adoption or reproducibility without careful engineering. | |||
* No open-source code: As of the paper's release, there is no official implementation, which limits replication and community engagement. | |||
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 8 Presentation: MEDUSA: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads == | |||
=== Presented by: === | |||
- Nana Ye | |||
- Xingjian Zhou | |||
=== Paper Citation === | |||
T. Cai et al., “Medusa: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads,” 2024, arXiv. doi: 10.48550/ARXIV.2401.10774. | |||
https://arxiv.org/abs/2401.10774 | |||
=== Summaries of key points === | |||
One of the main challenges with LLMs is their slow inference process. It's not that the GPU can't do the math fast enough; the issue is the memory bottleneck (the need to constantly move data between memory and the GPU). LLMs also generate text one token at a time using autoregressive decoding, where each token depends on the previous one, which leads to underutilization of the GPU.
One attempt to address this is '''Speculative Decoding''', where a smaller "draft model" predicts multiple tokens at once, and the main model verifies them. While this speeds up generation, it requires maintaining a separate draft model, and it can lead to mismatches between the draft and main models, making integration into a system difficult. | |||
So, '''Medusa''' is proposed to solve these issues. | |||
==== How does Medusa work? ==== | |||
===== Core Idea of Medusa ===== | |||
Unlike speculative decoding, Medusa doesn’t need a separate draft model. Instead, it adds extra decoding heads to the same model, allowing it to predict multiple tokens at once (generating candidates). This reduces the dependency on previous tokens, making inference faster. Once multiple tokens are predicted, Medusa organizes these predictions into a tree structure, where each node represents a token. This structure helps Medusa evaluate several possible token sequences at once (processing candidates). This is called tree-based attention. After evaluating, Medusa picks the most likely token sequence and outputs the best one (accepting candidates). | |||
===== Training Strategies of Medusa ===== | |||
Medusa has two training strategies to optimize its performance: | |||
*'''Medusa 1''': In this method, the original model (called the backbone) is frozen, meaning its parameters don’t change. Only the Medusa heads are trained. This saves computation and avoids the model forgetting what it learned originally, while improving inference speed by predicting multiple tokens at once. | |||
*'''Medusa 2''': In this approach, both the backbone and the Medusa heads are trained together. This boosts prediction accuracy, especially in larger models. The training starts with a two-stage process to prevent issues like high loss and large gradients from the new Medusa heads that could disrupt the backbone model. In Stage 1, only the Medusa heads are trained to specialize without affecting the main model. In Stage 2, both the Medusa heads and the backbone are trained together, with a warm-up strategy to gradually increase the learning rate for smoother adaptation. The tree attention mechanism also helps organize token continuations in a tree, allowing the model to evaluate multiple possibilities at once, speeding up inference. | |||
===== Further enhancements of Medusa ===== | |||
The author further enhances Medusa's practical utility with three significant extensions: | |||
1. Typical Acceptance Scheme: Instead of rejecting candidates based on strict probability thresholds, Medusa evaluates them based on how typical they are compared to the original model’s distribution. This speeds up the process without sacrificing quality. | |||
2. Self-Distillation: Medusa can generate training data from its own output. It creates a seed dataset and then uses that data to train the Medusa heads, which helps improve the model’s efficiency in making predictions. | |||
3. Optimized Tree Structure: Medusa improves how the model evaluates candidates by focusing on the most promising tokens, making the inference process faster and more efficient.
==== Benefits and Limitations ==== | |||
Medusa has shown great results in experiments with models like Vicuna-7B and Vicuna-13B, achieving up to 2.8 times faster performance than traditional methods, with even higher speedups for tasks like text extraction (3.62x) and coding (3.29x). It consistently outperformed speculative decoding, reaching 2.83x acceleration with Vicuna-7B, compared to 1.47x with speculative decoding. Despite the speed improvements, Medusa maintains high text quality, making it efficient without compromising accuracy. The tree-based attention mechanism further boosted speed by 1.9x. However, Medusa has some limitations, such as higher memory usage due to the additional heads and the tree attention mechanism, and it can be challenging to scale for larger models and complex tasks.
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 8 Presentation: MEDUSA: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads == | |||
=== Presented by: === | |||
Nana Ye and Xingjian Zhou | |||
=== Motivation === | |||
Large language models have become increasingly prolific, but anyone who has used one, or paid to use one, knows that the inference cost is not insignificant. If one is familiar with how next-token prediction works, they may ask: what if I predict multiple tokens in a single forward pass instead of a single token? That is the question the authors of the Medusa paper asked, and what they tried to do.
=== Operating principle === | |||
In simple terms, the Medusa model manages to predict multiple probability distributions over possible next tokens by learning multiple decoding heads. That is, given a string of input tokens and the corresponding embeddings, multiple decoding heads are applied to the same final hidden state, with each head predicting a different distribution over the dictionary: <math> p_i, p_{i+1}, \dots, p_{i+3} </math>. In the paper they use a total of 4 heads. The result is that each forward pass predicts four tokens instead of one.
=== Training Strategies === | |||
The authors introduce two training strategies, with the goal that one will be able to train Medusa without needing access to an HPC; that is, they want to demonstrate that you can download an open-source model like Meta's LLaMA and modify it to use Medusa with limited computing resources.
The training strategies are: | |||
Medusa 1: Frozen Backbone | |||
This is the simplest possible training setup. The weights of the pretrained model are frozen, and only the weights of the multiple MEDUSA heads are trained, by computing the cross-entropy loss between the "n" token predictions from the MEDUSA heads and the next "n" tokens in the ground truth.
Specifically, if the ground-truth token at the relevant future position is
<math> y_{t+k+1} </math>, then the loss associated with the <math> k^{\text{th}} </math> head is simply <math> L_k = - \log p_t^{(k)}(y_{t+k+1}) </math>, where <math> p_t^{(k)}(y) </math> is the probability of token y being predicted by the <math> k^{\text{th}} </math> head. The authors note that the loss grows with k, which makes intuitive sense: the further beyond the last token you are trying to predict, the more uncertainty in the forecast. This is one of the limitations of Medusa, but not one anybody should be surprised by. The total Medusa loss is just the sum of these individual head losses, each weighted by some <math> \lambda_k </math>, typically set to something like the k-th power of a constant smaller than 1.0, so that the model treats the tokens closest to the last known token as more important.
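Writing the weighted sum described above out explicitly (the geometric decay <math>\lambda_k = \lambda_0^{k}</math> with <math>\lambda_0 < 1</math> is the illustrative choice mentioned in the text, not a value taken from the paper):
<math> L_{\text{MEDUSA-1}} = \sum_{k=1}^{K} \lambda_k L_k = - \sum_{k=1}^{K} \lambda_0^{k} \, \log p_t^{(k)}(y_{t+k+1}) </math>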
Medusa 2: Joint Training | |||
Joint training is a slightly more sophisticated approach than frozen backbone. When using joint training we consider two sources of training loss, that is the standard LLM loss, coupled with the medusa loss. This method can be beneficial as it provides some degree of fine tuning to the foundation model so it can learn embeddings that are more complementary to the multi head approach. | |||
<math>L_{total} = L_{LM} + \alpha \sum_{k=1}^{K} \lambda_k L_k</math> | |||
We balance the losses to account for the fact that the backbone loss is likely to be very small, and we do not want the Medusa loss to dominate early on. Tuning these parameters is important so as to maintain learnability in the Medusa heads while not losing the foundation model's functionality.
=== Results / Conclusion === | |||
Overall, Medusa models prove themselves an interesting option for those wishing to improve the efficiency of foundation models, offering flexibility for users based on available computational resources. Medusa's predictive accuracy naturally falls for tokens further along in the sequence, but for instances where one wants a model to perform well while respecting hardware limitations, such as a locally hosted LLM or a distilled LLM on a mobile device, Medusa models could still prove a very efficient option.
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 8 Presentation: MEDUSA: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads == | |||
=== Presented by: === | |||
- Nana Ye | |||
- Xingjian Zhou | |||
=== Paper Citation === | |||
T. Cai et al., “Medusa: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads,” 2024, arXiv. doi: 10.48550/ARXIV.2401.10774. | |||
https://arxiv.org/abs/2401.10774 | |||
=== Motivation === | |||
Slow inference with LLMs: memory-bound issues and latency bottlenecks cause underutilization of the GPU's computational potential.
Existing methods accelerate inference but introduce complexity and integration challenges. | |||
=== Main idea === | |||
MEDUSA replaces the complexity of speculative decoding by adding multiple lightweight decoding heads to an existing large language model.
Key components: No draft model; Decoding Heads; Tree-based attention | |||
=== Extension === | |||
1. Typical Acceptance | |||
Problem: Rejection sampling is inefficient at high temperatures. | |||
Solution: | |||
A candidate is accepted if its probability is above a certain threshold, which is adjusted based on the entropy, since higher entropy indicates more uncertainty in the model's predictions, allowing a broader range of candidates to be considered typical.
2. Self-Distillation | |||
Select seed dataset - Generate Responses - Create Training dataset - Training with self-distillation | |||
3. Optimized Tree Construction | |||
=== Practical Advantages === | |||
1. Integration-friendly | |||
2. Scalable and efficient | |||
3. Resource-efficient | |||
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 9 Presentation: Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality== | |||
=== Presented by: === | |||
- Kaiyue Ma | |||
- Wenzhe Wang | |||
=== Paper Citation === | |||
T. Dao and A. Gu, “Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality,” 2024, arXiv. doi: 10.48550/ARXIV.2405.21060. | |||
=== Background === | |||
- Transformers are effective, but computationally expensive and suffer from quadratic complexity | |||
- Structured state space models (SSMs) are an alternative that scales linearly instead and works for long range modeling | |||
- SSMs have not received the same mainstream improvements as transformers, and lack support for parallelization and hardware acceleration
- Structured state space duality (SSD) bridges the gaps between transformers and SSMs | |||
=== Additional Background === | |||
SSMs (state space models) are traditionally used in control theory to model a dynamic system via state variables. The paper
https://compneuro.uwaterloo.ca/files/publications/voelker.2018.pdf found that SSMs are also well suited to describing the time cells
in the brain. A useful diagram of an SSM can be found here https://cdn-uploads.huggingface.co/production/uploads/613b0a62a14099d5afed7830/G7icfkYoxIqHZcJGHM7UD.png, with n state variables, m state inputs u, and p outputs y.
=== Technical Contributions === | |||
- Represents SSMs as semiseparable matrices, enabling efficient matrix operations (see the sketch below).
- Uses a generalized linear-attention mechanism with structured masks.
- Refines the original Mamba model to yield Mamba-2, which incorporates the new structured-state-space-duality algorithms. Mamba-2 is easily parallelizable and scales better to large state sizes. Empirical results demonstrate strong performance on language modelling benchmarks, surpassing older SSM models and matching or outperforming Transformers at various model scales.
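To illustrate the first bullet, here is a toy NumPy sketch (my own, using a scalar hidden state for simplicity) showing that running the SSM recurrence and multiplying by the corresponding lower-triangular (semiseparable) matrix give the same output:
<pre>
import numpy as np

def ssm_scan(a, b, c, x):
    """Scalar SSM as a recurrence: h_t = a_t h_{t-1} + b_t x_t,  y_t = c_t h_t."""
    h, y = 0.0, np.zeros_like(x)
    for t in range(len(x)):
        h = a[t] * h + b[t] * x[t]
        y[t] = c[t] * h
    return y

def ssm_matrix(a, b, c):
    """Same map as a lower-triangular 1-semiseparable matrix:
    M[t, j] = c_t * (a_{j+1} * ... * a_t) * b_j for t >= j."""
    T = len(a)
    M = np.zeros((T, T))
    for t in range(T):
        for j in range(t + 1):
            M[t, j] = c[t] * np.prod(a[j + 1:t + 1]) * b[j]
    return M

rng = np.random.default_rng(0)
T = 6
a, b, c, x = (rng.uniform(0.1, 0.9, size=T) for _ in range(4))
assert np.allclose(ssm_scan(a, b, c, x), ssm_matrix(a, b, c) @ x)
print("recurrence and matrix form agree")
</pre>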
=== Summaries of Key Points === | |||
The paper explores the theoretical connections between Transformer architectures and Structured State Space Models (SSMs). The authors introduce the State Space Duality (SSD) framework, which bridges these two model families through the concept of structured semiseparable matrices. This framework reveals that certain attention mechanisms in Transformers can be interpreted as SSMs, providing a unified perspective on sequence modelling techniques. | |||
Leveraging the SSD framework, the authors propose a new architecture called Mamba-2. This model refines the selective SSM approach used in the original Mamba model, resulting in a design that is 2-8 times faster while maintaining competitive performance in language modelling tasks. Mamba-2 achieves this efficiency by simplifying the SSM layer, enabling better scalability and computational speed. | |||
The paper also introduces efficient algorithms based on block decompositions of semiseparable matrices, which enhance the computational efficiency of SSMs. These algorithms allow for larger recurrent state sizes and improve the practicality of SSMs in handling long-range dependencies within sequences. | |||
Empirical evaluations demonstrate that Mamba-2 outperforms both its predecessor and Transformer models in terms of training efficiency and performance on language modelling benchmarks. The architecture also shows superior capabilities in associative recall tasks, highlighting its effectiveness in capturing and utilizing long-range dependencies. | |||
In summary, the paper provides a theoretical foundation connecting Transformers and SSMs, introduces the Mamba-2 architecture as a practical application of this theory, and presents algorithms that enhance the efficiency and scalability of sequence modelling techniques. | |||
=== Constructive Critique === | |||
While the paper introduces a powerful theoretical framework through State Space Duality (SSD) and demonstrates strong empirical performance with Mamba-2, several areas could be further clarified or improved: | |||
* Theoretical accessibility: The concept of semiseparable matrices and duality between attention and SSMs is mathematically rich, but not easily accessible to a broader audience. Including more visual or intuitive explanations would improve its pedagogical impact. | |||
* Benchmark diversity: Most experiments focus on language modeling tasks. It remains unclear how Mamba-2 performs on other domains such as vision, speech, or reinforcement learning. Since SSD is a general framework, cross-domain validation would help showcase its broader applicability. | |||
* Scalability limitations: While Mamba-2 is more efficient than its predecessor, the paper doesn’t fully discuss how performance scales with increasing model depth or state size, especially under training constraints on real-world hardware. | |||
* Lack of interpretability analysis: The paper does not explore how the SSD framework or Mamba-2 influences model interpretability (e.g., how information is propagated or stored over long sequences), which could be important for downstream applications. | |||
Despite these limitations, the paper makes a substantial theoretical and practical contribution by unifying two dominant modeling paradigms and offering a concrete architecture (Mamba-2) that is both efficient and performant. | |||
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 9 Presentation: Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality== | |||
=== Presented by: === | |||
- Kaiyue Ma | |||
- Wenzhe Wang | |||
=== Paper Citation === | |||
T. Dao and A. Gu, “Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality,” 2024, arXiv. doi: 10.48550/ARXIV.2405.21060. | |||
=== Background & Motivation=== | |||
The paper aims to unify state-space models (SSMs) and attention mechanisms through structured state-space duality (SSD), demonstrating that SSMs can be interpreted as a form of masked attention with semiseparable matrices. This approach enables the utilization of attention's hardware efficiency (e.g., optimized matrix multiplications) while preserving the linear scaling property of SSMs. Although Mamba's selective SSM is powerful, it is slower than optimized attention due to its reliance on sequential scans rather than direct matrix operations. The authors propose methods to accelerate SSMs by 2–8× without compromising performance or even enhancing it. By reformulating SSMs as matrix transformations, the paper offers novel theoretical insights, such as their equivalence to semiseparable matrices, along with practical algorithms like block decomposition for efficient computation. These contributions pave the way for hybrid architectures (e.g., Mamba-2 augmented with attention layers) and improved system-level support (e.g., tensor and sequence parallelism). | |||
=== Key Points=== | |||
The paper titled "Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality" investigates the theoretical and practical connections between Transformers and State Space Models (SSMs), with a particular emphasis on structured state-space duality (SSD). The key contributions of the paper include: | |||
1. '''Duality Framework''': The authors introduce Structured State-Space Duality (SSD), a framework that establishes a connection between State Space Models (SSMs) and attention mechanisms via structured matrices, particularly semiseparable matrices. This duality enables SSMs to be interpreted as matrix transformations, thereby uncovering novel algorithmic and architectural insights. | |||
2. '''Efficiency Improvements''': The paper introduces the Mamba-2 architecture, which enhances the selective SSM of Mamba to achieve 2–8× faster computation while preserving competitive performance in language modeling tasks. This improvement is realized through the utilization of hardware-efficient matrix multiplications and block decompositions of semiseparable matrices. | |||
3. '''Structured Masked Attention (SMA)''': A generalization of linear attention is introduced, in which the attention mask is replaced by a structured matrix (e.g., semiseparable matrices). This substitution enables subquadratic computational complexity and facilitates efficient autoregressive inference.
4. '''Hybrid Models''': The paper demonstrates that combining SSMs with attention layers (e.g., 10% attention layers in Mamba-2) can improve performance, suggesting complementary strengths between the two paradigms. | |||
=== Contributions ===
* '''Theoretical Connections''': It establishes a rigorous equivalence between SSMs and semiseparable matrices, unifying recurrent, convolutional, and attention-based sequence models under a single framework. | |||
* '''Algorithmic Innovations''': The SSD algorithm optimizes SSM computation by blending linear (recurrent) and quadratic (attention-like) forms, achieving linear complexity while leveraging modern hardware. | |||
* '''Mamba-2 Architecture''': This new architecture improves upon Mamba by simplifying projections, enabling tensor parallelism, and incorporating larger state dimensions, resulting in better scalability and efficiency. | |||
* '''Empirical Validation''': The authors validate Mamba-2 on synthetic tasks (e.g., associative recall) and language modeling, showing it outperforms Mamba and matches or exceeds Transformer++ in perplexity and downstream tasks. | |||
=== Constructive Critiques=== | |||
* '''Expressivity Trade-offs''': The adoption of scalar-identity structure for A matrices in Structured State-Space Duality (SSD) may constrain model expressivity relative to general diagonal State Space Models (SSMs). The paper could provide a more in-depth analysis of the trade-offs between hardware efficiency and model flexibility. | |||
* '''Attention Approximation''': The negative results observed for kernel approximations (e.g., Performer, cosFormer) in Mamba-2 indicate that the advantages of Structured State-Space Duality (SSD) may not fully translate from linear attention mechanisms. A more in-depth investigation into the reasons for the underperformance of these methods could further enhance the study. | |||
* '''Broader Applicability''': The focus is heavily on language modeling. Evaluating SSD on other domains (e.g., vision, reinforcement learning) could demonstrate its generalizability. | |||
* '''Implementation Complexity''': Although the SSD algorithm is simpler than Mamba's selective scan, its block decomposition may still present implementation challenges for practical adoption. Conducting additional ablation studies on parameters such as chunk size and parallelism levels could provide valuable guidance for practitioners. | |||
=== Relationships to Other Works=== | |||
The paper extends research in efficient sequence modeling, connecting various approaches through a unified framework. It builds on recent progress in State Space Models (SSMs), particularly from S4 to Mamba. S4 and S4D introduced diagonal-plus-low-rank matrices for long-range modeling, while Mamba's selective SSMs improved performance on dense data like language. The SSD framework generalizes these models using semiseparable matrices and introduces hardware-aware optimizations, making Mamba-2 significantly faster. | |||
Connections to linear attention methods form another key thread. The paper generalizes Katharopoulos et al.'s linear attention with structured masked attention via semiseparable matrices. This links SSD to models like RetNet (fixed exponential decay) and GateLoop (input-dependent gating). GLA's chunkwise computation resembles SSD's block decomposition but lacks SSD's theoretical unification. | |||
The work also intersects with efforts in efficient recurrent architectures. RWKV's attention-like gating shares similarities with SSD's matrix-based approach, though SSD offers a more rigorous mathematical foundation. Griffin's combination of SSMs with local attention and xLSTM's expanded state dimensions align with SSD's themes, suggesting SSD provides a unifying perspective. | |||
On the systems side, the paper complements hardware-efficient Transformer implementations. While FlashAttention optimizes attention kernels, SSD advances SSM-based models. Monarch Mixer's structured matrices share some ideas with SSD but apply them differently. These connections highlight SSD's contribution to efficient deep learning architectures. | |||
Theoretically, SSD bridges modern sequence modeling with classical numerical linear algebra. Semiseparable matrices connect to structured matrix computations, offering new insights into model representation. This grounding may inspire future algorithmic improvements. | |||
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 9 Presentation: Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality== | |||
=== Presented by: === | |||
- Kaiyue Ma | |||
- Wenzhe Wang | |||
=== Paper Citation === | |||
T. Dao and A. Gu, “Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality,” 2024, arXiv. doi: 10.48550/ARXIV.2405.21060. | |||
=== Introduction === | |||
The paper "Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality", by Sri Dao and Albert Gu establishes a theoretical framework connecting Transformers and State Space Models - SSMs. | |||
This framework, termed Structured State Space Duality - SSD - brings these two prominent sequence architectures leading to the development of Mamba-2. This model enhances efficiency while painting competitive performance in LLMs. | |||
=== Key Contributions === | |||
==== (1) Establishing Structured State Space Duality ==== | |||
The authors demonstrate that structured SSMs and attention mechanisms are closely related through structured matrices, specifically semiseparable matrices. This insight reveals that various sequence models can be interpreted as different parametrizations of these matrices, providing a unified understanding.
Two perspectives are introduced to implement this duality: | |||
Matrix representation: viewing sequence models as matrix transformations, highlighting how SSMs can be represented using semiseparable matrices, which admit sub-quadratic computation.
Tensor Contraction Representation: This illustrates how the computations in attention mechanisms can be reformulated in terms of tensor contractions, aligning them with SSM operations.
By framing SSMs as well as attention mechanisms within the SSD framework, the paper enables the transfer of algorithmic and system optimizations between these models, fostering advancements in efficiency and stability. | |||
==== (2) Development of Mamba-2 Architecture ==== | |||
Leveraging the SSD framework, Mamba-2 is an improved version of the Mamba-1 architecture. Mamba-2 refines the selective SSM layer, resulting in significantly faster computation.
Mamba-2 therefore achieves 2-8 times faster performance compared to its predecessor while maintaining competitiveness with Transformers in language modelling tasks. This demonstrates the practical benefits of applying the SSD design.
==== (3) Efficient Algorithms Through SSD ==== | |||
The paper presents efficient algorithms derived from the SSD framework that optimize the computation. These algorithms reduce the complexity often associated with traditional sequential models. | |||
SMA, a novel attention variant, is introduced, which exploits the structured properties of SSD. This leads to more efficient attention computation.
=== Applications and Impact === | |||
The SSD framework offers a new paradigm for designing sequence models, allowing practitioners to harness the strength of both SSMs and transformers. This leads to models that are both computationally efficient and effective in capturing long-range dependencies. | |||
By reducing computational complexity, the insights from this paper facilitate the development of models that can handle longer sequences and larger datasets; addressing a common limitation in sequence modelling, thereby allowing scalability. | |||
Finally, the theoretical connections established within the paper enable the application of optimization techniques across different model architectures. This lays the foundation for more unified and efficient approaches to sequence modelling.
This work not only enhances theoretical understanding but also leads to practical advancements, exemplified by the development of the Mamba-2 Architecture. | |||
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 9 Presentation: Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality== | |||
=== Presented by: === | |||
- Kaiyue Ma | |||
- Wenzhe Wang | |||
=== Paper Citation === | |||
T. Dao and A. Gu, “Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality,” 2024, arXiv. doi: 10.48550/ARXIV.2405.21060. | |||
=== Summaries of key points === | |||
Goal: SSD connects SSMs and Transformers, combining the best of both.
Background: Transformers can handle long-distance dependencies but are computationally expensive. SSMs have linear complexity, but they are difficult to fully parallelize and accelerate.
Methodology: the SSD framework, through a semiseparable matrix structure, relates the SSM to the attention mechanism. The SSD algorithm splits the matrix into diagonal blocks and off-diagonal blocks, and adds multi-head sequence transformations, parallel projections, and kernel methods.
Result: Mamba-2 matches Transformer performance while training 2-8 times faster than the original Mamba.
=== Constructive critiques or reviews === | |||
The clear illustrations let the audience understand the material more concisely and intuitively.
The presentation is well structured; a little more detail could be worked in, and more detail could be included in the slides.
=== Clear explanations to aid understanding === | |||
Semiseparable matrix: compresses a large matrix into structured blocks that are easier to compute with.
Block decomposition in SSD: diagonal blocks are handled with attention-like computation, while off-diagonal blocks are handled with fast recurrence.
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 9 Presentation: Transformers are SSMs – Generalized Models and Efficient Algorithms Through Structured State Space Duality == | |||
=== Presented by: === | |||
Kaiyue Ma and Wenzhe Wang | |||
=== Paper Citation === | |||
T. Dao and A. Gu, “Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality,” arXiv, 2024. doi: 10.48550/ARXIV.2405.21060 | |||
=== Summaries === | |||
This paper proposes Structured State Space Duality (SSD), a novel framework that unifies Transformers and State Space Models (SSMs). By leveraging semiseparable matrices, SSD connects the two models both theoretically and computationally. | |||
The SSD framework enables SSMs to be interpreted as a form of structured attention, allowing them to adopt hardware-efficient matrix operations typically used in attention mechanisms. Based on this insight, the paper develops the Mamba-2 architecture—an SSM-based model that is 2–8× faster than previous designs, while maintaining or even improving performance on language modeling tasks. | |||
The paper offers both theoretical contributions (matrix/tensor formulations of sequence models) and practical improvements (block decomposition, structured masked attention), presenting a new paradigm for designing fast and expressive sequence models. | |||
=== Key Contributions === | |||
* '''Structured State Space Duality (SSD)'''
** Establishes a bridge between SSMs and attention via semiseparable matrices.
** Two perspectives:
*** Matrix Representation: Shows how SSMs can be implemented as structured matrix transformations with sub-quadratic properties.
*** Tensor Contraction Representation: Reformulates attention as tensor operations, aligning attention and SSMs under a shared computational lens.
* '''Mamba-2 Architecture'''
** Improves on Mamba by simplifying selective SSM layers.
** Supports tensor and sequence parallelism, enabling scalability and hardware optimization.
** Maintains high accuracy while achieving significant speed-ups.
* '''Efficient Algorithms via SSD'''
** Proposes block decomposition for semiseparable matrices to allow fast, parallel computation.
** Introduces Structured Masked Attention (SMA), which generalizes linear attention using structured matrices, reducing inference complexity while preserving expressiveness.
* '''Hybrid Sequence Models'''
** Demonstrates that models like Mamba-2 can benefit from a small percentage of attention layers (e.g., 10%), leading to performance boosts without major efficiency trade-offs.
=== Constructive Critiques or Reviews === | |||
Expressivity Trade-off: SSD uses scalar-identity matrix A for simplicity, but this may limit modeling flexibility compared to more general SSMs. | |||
Domain Specificity: Experiments focus heavily on NLP tasks. Broader validation across vision or RL tasks would strengthen the case for SSD. | |||
Implementation Overhead: Although simpler than Mamba-1, SSD’s block decomposition still adds engineering complexity. Ablation studies on chunk size and parallelism could help guide practical usage. | |||
Kernel Approximation Limitations: Some approximations (like Performer or cosFormer) underperform when used within SSD, suggesting limitations in directly porting techniques from linear attention to SSMs. | |||
=== Related Works === | |||
SSM Lineage: Builds on S4, S4D, and Mamba, continuing the trend of enhancing long-range modeling with structured matrices. | |||
Linear Attention Connections: Extends ideas from Katharopoulos et al., RetNet (exponential decay), and GateLoop (input-dependent gating) through the use of Structured Masked Attention. | |||
Recurrent Model Efficiency: Shares similarities with RWKV’s attention-like gating and Griffin’s hybrid SSM-attention strategy. SSD offers a theoretically grounded unification of these approaches. | |||
System-Level Optimization: Complements tools like FlashAttention by advancing SSM-based models. Monarch Mixer shares matrix structuring ideas but differs in architecture. | |||
Mathematical Foundations: Draws on classical numerical linear algebra via semiseparable matrices, opening new directions in sequence model efficiency and design. | |||
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 9 Presentation: Transformers are SSMs – Generalized Models and Efficient Algorithms Through Structured State Space Duality == | |||
=== Presented by: === | |||
Kaiyue Ma | |||
Wenzhe Wang | |||
=== Paper Citation === | |||
T. Dao and A. Gu, *Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality*, 2024, arXiv: [10.48550/ARXIV.2405.21060] | |||
=== So, What’s This All About? === | |||
This paper tackles a big idea: Transformers and Structured State Space Models (SSMs) aren’t as different as they seem. What if we told you they’re actually two sides of the same mathematical coin? | |||
Using something called '''Structured State Space Duality (SSD)''', the authors show that attention mechanisms in Transformers can be seen as a special case of SSMs, specifically semiseparable matrix operations. This realization lets them create models that combine the efficiency of SSMs with the expressiveness of attention.
And yes, they take it a step further by introducing '''Mamba-2''', a faster, more parallel version of the original Mamba.
=== What’s New and Cool? === | |||
Here are the main contributions from the paper: | |||
'''(1) Structured State Space Duality'''
- SSD shows how to map attention to SSMs using semiseparable matrices. | |||
- You can view both models through matrix and tensor representations, revealing a deep mathematical connection. | |||
'''(2) Mamba-2 Architecture'''
- Mamba-2 is a refined version of Mamba, now optimized using SSD principles.
- It's '''2–8× faster''' than the original Mamba, with the same or better performance on language modeling benchmarks.
- Plus, it scales to large state sizes and runs smoothly on modern hardware. | |||
'''(3) Efficient Algorithms for Long Sequences'''
- By leveraging SSD, they optimize matrix multiplications and block decompositions. | |||
- The result? Efficient, parallel-friendly inference—even for very long input sequences. | |||
'''(4) SMA (Structured Masked Attention)'''
- A new take on linear attention, replacing it with a structured matrix formulation. | |||
- Maintains accuracy, reduces complexity, and plays well with modern accelerators. | |||
=== Real-World Impact & Experiments === | |||
- '''Language Modeling''': Mamba-2 beats or matches Transformers on perplexity across standard datasets.
- '''Speed''': Achieves faster inference with lower computational costs.
- '''Scalability''': Easily extends to larger models and longer contexts thanks to block-structured matrices.
- '''Hybrid Potential''': Shows that combining attention + SSMs (hybrid models) may give the best of both worlds.
=== Let’s Talk Strengths === | |||
- The paper builds a solid theoretical bridge between Transformers and SSMs | |||
- Offers practical improvements—Mamba-2 isn’t just math; it works! | |||
- Hardware-friendly: optimized for real-world use, not just academic theory | |||
- Great generalization to other architectures (like FlashAttention, Monarch Mixer, xLSTMs) | |||
=== Any Caveats or Challenges? === | |||
- '''Math-heavy''': SSD is powerful, but not easy for everyone to digest. More visuals/examples would help.
- '''Narrow evaluation''': Most experiments focus on language modeling. Other areas (like vision, speech) need testing.
- '''Interpretability''': With all these matrix tricks, it's harder to see how information flows through the model.
- '''Implementation complexity''': Setting up SSD-based models isn't as plug-and-play as Transformers yet.
=== Helpful Analogies & Clarifications === | |||
- **Semiseparable matrices**: Think of compressing a huge attention matrix into smarter, structured blocks. | |||
- **Block decomposition**: Imagine breaking attention into diagonal + off-diagonal parts—each optimized for a different type of sequence interaction. | |||
- **SSD = a unifying lens**: It’s like saying Transformers and SSMs are just different “views” of the same underlying engine. | |||
=== Final Thoughts === | |||
This paper doesn’t just tweak existing models—it proposes a unified theory that explains and improves both Transformers and SSMs. With SSD, we can create hybrid models that are faster, more efficient, and still powerful. And with Mamba-2, we get real performance gains without sacrificing quality. | |||
It’s an exciting step forward that’s both theoretical and practical—bridging two major model families and opening up new paths for scalable, efficient sequence modeling. | |||
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 9 Presentation: Transformers are SSMs – Generalized Models and Efficient Algorithms Through Structured State Space Duality == | |||
=== Presented by: === | |||
Kaiyue Ma and Wenzhe Wang | |||
=== Paper Citation === | |||
T. Dao and A. Gu, “Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality,” arXiv, 2024. doi: 10.48550/ARXIV.2405.21060 | |||
=== Objective === | |||
Develop connections between Transformers & SSMs | |||
Transformers: Effective but computationally expensive. SSMs: Linear complexity, efficient long-range modeling.
=== Contributions === | |||
1. Introduce SSD framework | |||
SSD bridges attention and SSMs by structured matrices. | |||
2. Present Mamba-2 architecture with improved performance and speed. | |||
=== Experimental Results === | |||
Mamba-2 trains 2–8 times faster than the original Mamba.
Improved long-sequence modeling.
Matches or surpasses standard Transformers on language-modeling benchmarks.
=== Future Directions === | |||
Future work includes structured matrices, efficiency enhancements, and broader applications.
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 9 Presentation: Transformers are SSMs – Generalized Models and Efficient Algorithms Through Structured State Space Duality == | |||
=== Presented by: === | |||
Kaiyue Ma and Wenzhe Wang | |||
=== Paper Citation === | |||
T. Dao and A. Gu, “Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality,” arXiv, 2024. doi:10.48550/ARXIV.2405.21060 | |||
=== Objective === | |||
Explore the theoretical and practical connection between Transformers and State Space Models (SSMs). | |||
Transformers: Powerful sequence models, but suffer from quadratic computational costs. | |||
SSMs: Linear time complexity, well-suited for long-range dependencies. | |||
=== Key Contributions === | |||
Structured State Space Duality (SSD) Framework | |||
Establishes a formal link between attention mechanisms and SSMs via structured matrices. | |||
Unifies seemingly distinct modeling approaches under a common mathematical foundation. | |||
Introduction of Mamba-2 Architecture | |||
An enhanced SSM-based model demonstrating significant improvements in training speed and sequence modeling capabilities. | |||
Efficient and scalable, designed to outperform traditional Transformers. | |||
=== Experimental Highlights === | |||
Training Speed: Mamba-2 trains 2× to 8× faster than the original Mamba. | |||
Performance: Achieves state-of-the-art results on several long-sequence benchmarks. | |||
Efficiency: Maintains low computational overhead while improving expressivity. | |||
=== Future Directions === | |||
Further exploration of structured matrix representations. | |||
Optimization techniques to boost training and inference efficiency. | |||
Broader integration into diverse domains such as vision, language, and time-series analysis. | |||
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 10 Presentation: Accelerating Large Language Model Decoding with Speculative Sampling== | |||
=== Paper Citation === | |||
C. Chen, S. Borgeaud, G. Irving, J.-B. Lespiau, L. Sifre, and J. Jumper, ‘Accelerating Large Language Model Decoding with Speculative Sampling’, Feb. 02, 2023, arXiv: arXiv:2302.01318. doi: 10.48550/arXiv.2302.01318. | |||
https://arxiv.org/abs/2302.01318 | |||
=== Background === | |||
- Traditional autoregressive decoding for large language models is computationally expensive, as the entire model has to run for each additional token that is generated
- Transformer neural networks are typically memory-bandwidth limited; quantization and distillation into smaller models are solutions that have been used in the past to improve LLM performance
=== Technical Contributions === | |||
- Speculative sampling was developed to increase the speed of LLM decoding without meaningfully reducing its performance at predicting future tokens compared to the base model
- Generates a draft of k future tokens using a smaller model | |||
- Score the proposed tokens using the base model | |||
- A modified rejection sampling scheme was developed by the authors in the paper | |||
- The acceptance of a draft token is based on the minimum of 1 and the ratio of the target model's probability to the draft model's probability for that token. | |||
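As a rough illustration of this accept/reject loop, the sketch below (our own toy code; draft_prob and target_prob are assumed helpers returning each model's next-token distribution) shows the core logic. Up to K+1 tokens can be emitted per call, which is where the speedup comes from.

<pre>
import numpy as np

def speculative_step(prefix, draft_prob, target_prob, K, rng):
    # draft_prob(prefix) / target_prob(prefix): next-token distributions q and p
    # (toy callables; in practice the K+1 target distributions come from one
    # parallel forward pass over the drafted tokens). rng = np.random.default_rng().
    new_tokens = []
    for _ in range(K):
        q = draft_prob(prefix + new_tokens)
        p = target_prob(prefix + new_tokens)
        x = rng.choice(len(q), p=q)                    # draft model proposes a token
        if rng.random() < min(1.0, p[x] / q[x]):       # accept with prob min(1, p/q)
            new_tokens.append(x)
        else:                                          # reject: resample from (p - q)+
            residual = np.maximum(p - q, 0.0)
            new_tokens.append(rng.choice(len(p), p=residual / residual.sum()))
            return new_tokens
    # all K drafts accepted: one extra token comes "for free" from the target model
    p = target_prob(prefix + new_tokens)
    new_tokens.append(rng.choice(len(p), p=p))
    return new_tokens
</pre>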
=== Explanations to Aid Understanding === | |||
- Transformer Decoding: This is the process by which a large language model generates text. Given an input prompt, the model sequentially selects the most probable next token, then uses that token to predict the subsequent one, and so on, and this process is computationally intensive. | |||
- Speculative Sampling: Unlike traditional decoding where one token is generated at a time by the target model, speculative sampling aims to generate multiple tokens in parallel by using a faster, potentially smaller "draft model" to propose a short sequence (the draft). The target model evaluates these drafted tokens, and a rejection sampling mechanism decides which ones to accept, ensuring the output remains consistent with the target model's capabilities. | |||
- Parallel Scoring: Instead of computing the logits for each drafted token sequentially, the method computes the logits for all (K+1) tokens in the draft at the same time. The presentation notes that the computing time for this parallel process is similar to sampling a single token with the target model, which is a key factor in the potential speedup. The key insight is that the model inference pass is dominated by memory bandwidth and inter-device communication rather than purely by the token count. By handling several drafted tokens per pass, overall decoding time is greatly reduced. | |||
=== Summaries of Key Points === | |||
- Decoding Challenges in LLMs:
Traditional autoregressive sampling methods generate one token at a time, leading to inefficiencies, especially as models grow larger.
- Speculative Sampling Overview: | |||
SpS utilizes a draft model, which is a smaller and faster version of the target LLM, to propose a sequence of tokens. | |||
The target model then evaluates this proposed sequence, accepting or rejecting tokens based on a modified rejection sampling scheme that ensures the final output aligns with the target model's distribution. | |||
- Algorithm Efficiency: | |||
The draft model generates a short sequence (e.g., K tokens) in parallel. | |||
The target model scores this sequence, and tokens are accepted or resampled as needed, allowing for the potential acceptance of up to K+1 tokens per iteration. | |||
This parallel approach contrasts with traditional methods that generate tokens sequentially, thereby reducing decoding latency. | |||
- Empirical Results: | |||
Implementing SpS with Chinchilla, a 70-billion-parameter language model, resulted in a 2 to 2.5 times speedup in decoding across benchmark tasks such as XSum and HumanEval. | |||
These speed improvements were achieved without degrading sample quality or requiring changes to the target model's parameters or architecture. | |||
- Advantages of Speculative Sampling: | |||
Maintains the integrity of the target model's output distribution. | |||
Requires no alterations to existing model architectures, facilitating easy integration into current systems. | |||
Demonstrates versatility across various tasks and decoding methods, making it broadly applicable in LLM deployments. | |||
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 10 Presentation: Accelerating Large Language Model Decoding with Speculative Sampling == | |||
=== Presented by: === | |||
Danyang Zhao | |||
=== Speculative Sampling: An In-Depth Explanation === | |||
Speculative Sampling is a technique introduced to speed up text generation for large language models (LLMs) without significantly affecting output quality. It works by leveraging a smaller, faster “draft” model to propose multiple candidate tokens, which are then verified and accepted or rejected by the larger model. This method reduces latency during decoding while maintaining output quality close to the original model. | |||
But what, specifically, is latency? Latency refers to the time delay between a user requesting text generation and the model producing the output. More formally, it is the time required to generate each token in an autoregressive model.
But why do we need speculative sampling? Because there are problems with standard autoregressive sampling. Mainly:
- LLMs generate text token-by-token. Each step depends on the previous tokens, making sequential decoding slow. | |||
- The computational bottleneck comes from generating the next token, which requires a full forward pass of the model. | |||
Therefore the goal of speculative sampling is to reduce the number of forward passes of the large language model per generated token. | |||
How Speculative Sampling Works: | |||
1. There are two models, a Draft model and a Large Model. The smaller draft model, which is cheap and fast to compute, generates a set of k speculative candidate tokens. Then, the large model, which is expensive yet accurate, verifies these tokens and either accepts or rejects them. | |||
2. The draft model proposes multiple tokens. That is, instead of sampling just one token at each step, the draft model generates a sequence of k candidate tokens <math> x_{t+1}, x_{t+2}, \ldots, x_{t+k} </math> using standard autoregressive decoding.
3. Now, the large language model verifies the proposal. This is done by calculating the probability of each proposed token:
<math> p(x_{t+1} | x_{\le t}), p(x_{t+2} | x_{\le t+1}), \ldots, p(x_{t+k} | x_{\le t+k-1}) </math>
For each candidate, there are only two possibilities: (1) the token is accepted and added to the output, or (2) the token is rejected, in which case the large model takes over and directly samples a replacement token.
4. Since the draft model is smaller, it generates multiple speculative tokens quickly, and the target model only needs a few forward passes to verify them, making the process more efficient. This allows for faster decoding through efficient verification.
Now, let us go over a bit more of the mathematics. | |||
For each token <math> x_i </math>, we compare the two models' probabilities:
<math> p(x_i | x_{< i}) \quad \text{vs.} \quad q(x_i | x_{< i}) </math>
where <math> p </math> is the probability distribution of the large model and <math> q </math> that of the draft model. If <math> p(x_i | x_{< i}) \ge q(x_i | x_{< i}) </math>, the token is always accepted; otherwise it is accepted only with probability <math> p(x_i | x_{< i}) / q(x_i | x_{< i}) </math>.
This is paired with a Metropolis-style acceptance probability, which ensures the final sampled tokens remain statistically valid while speeding up computation. | |||
The acceptance probability can then be calculated as follows: | |||
<math> \alpha_i = \min \Big( 1, \frac{p(x_i | x_{< i})} {q(x_i | x_{< i})} \Big) </math> | |||
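For a concrete (made-up) example: if the draft model assigns <math> q(x_i | x_{< i}) = 0.5 </math> to its proposed token but the target model only assigns <math> p(x_i | x_{< i}) = 0.3 </math>, then <math> \alpha_i = \min(1, 0.3/0.5) = 0.6 </math>, so the token is kept with probability 0.6. If it is rejected, a replacement token is drawn from the normalized residual distribution <math> \propto \max(p - q, 0) </math>, which is what keeps the overall output distribution identical to that of the target model.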
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 10 Presentation: Accelerating Large Language Model Decoding with Speculative Sampling == | |||
=== Presented by: === | |||
Danyang Zhao | |||
=== Summaries of key points === | |||
Goal: A speculative sampling algorithm is proposed to accelerate the decoding process of the large prediction model. | |||
Background: Traditional Transformer is slow and costly, and existing methods cannot effectively improve the generation speed. | |||
Methodology: A small draft model is used to generate a token sequence of length k, and the logits of the k+1 tokens are computed in parallel using the target (large) model. A modified rejection sampling method is used to decide whether to accept each draft token.
Result: On Chinchilla, the output quality is almost unaffected, and the generation speed is significantly improved. | |||
=== Constructive critiques or reviews === | |||
The presentation could be more detailed and provide more examples to aid understanding.
The delivery could be more fluent, with fewer pauses.
While the presenter explained the concept verbally, graphical models and diagrams would help viewers better understand the model. For instance, in addition to providing the algorithm pseudocode, a diagram of the transformer model with latency and bottleneck highlighted would show why the proposed parallel sampling approach improves speed.
=== Clear explanations to aid understanding === | |||
The probabilities that the target model and the draft model assign to a proposed token are compared to decide whether to accept it.
The output should not deviate from the target model's distribution.
=== Connections to related works === | |||
Compared with distillation, speculative sampling leaves the target model unchanged, so the acceleration is obtained directly at inference time.
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 10 Presentation: Accelerating Large Language Model Decoding with Speculative Sampling == | |||
=== Presented by === | |||
Danyang Zhao | |||
=== Paper Citation === | |||
C. Chen, S. Borgeaud, G. Irving, J.-B. Lespiau, L. Sifre, and J. Jumper, ‘Accelerating Large Language Model Decoding with Speculative Sampling’, Feb. 02, 2023, arXiv: arXiv:2302.01318. doi: 10.48550/arXiv.2302.01318. | |||
https://arxiv.org/abs/2302.01318 | |||
=== Background === | |||
- In Transformer-based language models, decoding is typically done autoregressively, and one token is selected at a time based on previously generated tokens. This is computationally expensive and memory bandwidth bound, meaning the process is limited by the speed at which the data can be transferred rather than compute itself. | |||
- The goal of the paper is to reduce the latency of decoding without modifying the original language model, and they introduce speculative sampling. | |||
- Speculative sampling is a method that is used to generate multiple tokens in parallel, rather than one at a time, to speed up inference. | |||
=== Summaries of Key Points === | |||
- The speculative sampling process involves 3 steps. The first step is draft generation. This is when a lightweight draft model generates a sequence of k tokens in parallel. Then, the target model, which is typically larger than the draft model, is used to score the draft tokens. After these tokens are scored, the method accepts or rejects each draft token using a modified rejection sampling algorithm, ensuring the final output matches the distribution of the target model. This sampling scheme is provably correct since the output distribution does match that of the target model and is not just an approximation. | |||
- Speculative sampling achieves up to 2-3x speedups in decoding times without loss in generation quality, especially when using efficient draft models like small LMs or knowledge distillation. | |||
- Chinchilla was used for target decoding and compared speculative sampling with standard autoregressive sampling, showing significant latency reduction at little to no cost in output quality. | |||
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 10 Presentation: Accelerating Large Language Model Decoding with Speculative Sampling == | |||
=== Presented by === | |||
Danyang Zhao | |||
=== Paper Citation === | |||
C. Chen, S. Borgeaud, G. Irving, J.-B. Lespiau, L. Sifre, and J. Jumper, ‘Accelerating Large Language Model Decoding with Speculative Sampling’, Feb. 02, 2023, arXiv: arXiv:2302.01318. doi: 10.48550/arXiv.2302.01318. | |||
https://arxiv.org/abs/2302.01318 | |||
=== Objective === | |||
Transformer sampling is typically memory-bandwidth bound: the time needed is proportional to the number of parameters and the size of the memory transferred, and a cache of keys and values is maintained for every attention layer. Prior studies have proposed quantisation, distillation, and smaller models to address this issue, but did not focus on speeding up decoding itself. This paper focuses on reducing the latency of transformer decoding without altering the capabilities of the original model.
=== Summaries of Key Points === | |||
Speculative sampling works by generating a short draft of length K, then scoring the draft using the target model, which is the model that we wish to sample from. Finally, a modified rejection sampling scheme is used to recover the distribution of the target model. Applied to Chinchilla, speculative sampling greatly reduces decoding time.
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 10 Presentation: Accelerating Large Language Model Decoding with Speculative Sampling == | |||
=== How Does Speculative Sampling Work? === | |||
1. The Basic Idea: "Drafting" Tokens | |||
Imagine you're completing someone's sentence. Usually, you might correctly guess a few words ahead because they're predictable. The idea behind speculative sampling is similar. A smaller, quicker "draft" model predicts multiple tokens (words) rapidly, guessing what the large model would likely say next. | |||
2. Confirming with the "Big Model" | |||
After the smaller model drafts several tokens, the bigger, smarter (but slower) model checks if these guesses align with its own predictions. If the draft guesses match closely enough, the large model accepts them immediately. | |||
3. Efficiently Handling Mistakes | |||
If the draft is wrong—maybe it guesses something improbable or off-topic—the large model rejects those tokens. But rather than starting from scratch, it quickly generates a replacement token that accurately reflects its original distribution. This clever mechanism ensures accuracy stays high, and no incorrect information slips through. | |||
=== Real-World Results: Surprisingly Fast, Equally Accurate === | |||
The DeepMind team tested speculative sampling with Chinchilla, a popular 70-billion-parameter language model. They used two tasks: | |||
Text summarization (XSum): Condensing lengthy articles into short summaries. | |||
Code generation (HumanEval): Writing accurate Python code from natural-language descriptions. | |||
The results? Speculative sampling doubled the speed of Chinchilla’s token generation, and in some cases, it was up to 2.5 times faster—all without losing quality. | |||
The secret sauce here isn't magic; it's that the smaller model can swiftly handle simple predictions while the larger model just verifies and fills in any blanks. | |||
=== Why Does This Matter? === | |||
Real-Time AI Experiences: | |||
If AI assistants could respond faster, interactions would feel more natural, allowing seamless conversations without frustrating delays. | |||
Cost Efficiency: | |||
Faster token generation means saving computation time, directly reducing the costs associated with running large models in commercial and research contexts. | |||
No Need for Retraining: | |||
Speculative sampling doesn’t require adjusting or retraining the big, expensive model, making it practical to implement quickly with existing setups. | |||
=== Critique and Considerations === | |||
While speculative sampling seems like a powerful approach, there are a few things worth noting: | |||
1. Domain Dependence: The method shines brightest when token predictions are straightforward or repetitive—like structured code. However, for less predictable or more creative text, the speedup might be smaller, as the smaller draft model might guess incorrectly more frequently, increasing overhead. | |||
2. Choosing the Right Draft Model: Selecting a suitable draft model is critical. The draft needs to be good enough to ensure high accuracy but small enough to be quick. Picking the optimal size and architecture for this secondary model can be nuanced. | |||
3. Variance and Latency: As more speculative tokens are generated at once, the total latency per step increases, potentially adding variability in response times. This may be problematic for applications sensitive to latency variations. | |||
=== Connections to Other Techniques === | |||
Speculative sampling complements existing optimization strategies: | |||
Quantization and Distillation: These techniques compress large models into smaller, faster versions. Unlike speculative sampling, they require retraining or altering the model itself. Combining speculative sampling with quantization could further amplify performance improvements. | |||
Parallel Decoding Techniques: Previous approaches like "blockwise parallel decoding" similarly attempt to generate tokens in groups rather than one-by-one. However, they often require substantial architectural changes. Speculative sampling’s elegance lies in its simplicity—no changes to existing models required. | |||
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 11 Presentation: Simple Linear Attention Language Models Balance the Recall-Throughput Tradeoff== | |||
=== Presented by: === | |||
Yiyuan Yang, Anar Kuatzhan, Chuan Zhang | |||
=== Paper Citation === | |||
Arora, S., Eyuboglu, S., Zhang, M., Timalsina, A., Alberti, S., Zinsley, D., ... & Ré, C. (2024). Simple linear attention language models balance the recall-throughput tradeoff. arXiv preprint arXiv:2402.18668. https://arxiv.org/pdf/2402.18668 | |||
=== Background === | |||
Large language models still struggle with efficiency. In attention-based models, the attention computation grows quadratically as the input gets longer, and attention stores every previous token in memory, which makes it memory-intensive. Newer language models developed in recent years can generate text faster while maintaining low perplexity, but low perplexity doesn't necessarily mean good recall: gated convolutional models, for example, struggle with recall, whereas attention-based models excel at recall tasks. To address this tradeoff, the authors introduced the Based model.
=== Technical Contributions === | |||
The authors introduce a new architecture called BASED (Balanced Attention through Sliding + Exponential Decay), which is designed to address the recall–throughput tradeoff in language models. It combines two key components: | |||
* '''Linear Attention (Global Context)''': | |||
** Uses a second-order Taylor approximation of softmax attention. | |||
** Enables constant-size memory state during autoregressive decoding. | |||
* '''Sliding Window Attention (Local Context)''': | |||
** Performs exact attention within a small local window (e.g., 64–128 tokens). | |||
** Captures short-range dependencies with high precision. | |||
<b>Memory-Recall Tradeoff:</b> Observed both within and across architecture classes. | |||
<b>Performance with Fixed Recurrent State:</b> Not all architectures have the same recall capacity. Mamba optimally utilizes limited memory budgets while convolutional architectures underperform with memory constraints. | |||
<b>The Based model</b> combines local fine-grained attention + long-range linear attention via Taylor approximation of softmax exponential function that are sub-quadratic complexity during training and permit an efficient recurrent inference view. Based outperforms prior sub-quadratic architectures in recall quality by up to 6.2 accuracy points. | |||
<b>Hybrid Layer Composition for Enhanced Recall and Efficiency:</b> BASED is constructed as a hybrid architecture composed of approximately 20% linear attention layers, 20% sliding window attention layers, and 60% gated convolution layers. This combination leverages the precision of local token comparison (via sliding window and short convolutions) with the global context capabilities of linear attention. The inclusion of short gated convolution layers (e.g., filter size 3) helps model local dependencies that small sliding windows might miss, improving the architecture’s overall recall ability. This hybrid design, detailed in Appendix E.1 of the paper, enables BASED to outperform other sub-quadratic models like Mamba in both recall-intensive tasks and generation throughput benchmarks.
=== Architecture === | |||
Softmax-approximating linear attention (applied globally) + exact softmax attention with sliding windows (applied locally) | |||
This combination achieves 90.8% of full softmax attention's recall accuracy while reducing latency by a factor of 100,000. | |||
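The "softmax-approximating" part can be illustrated with a short sketch (our own toy code, using the second-order Taylor idea mentioned above): expanding <math> e^{q \cdot k} \approx 1 + q \cdot k + \tfrac{(q \cdot k)^2}{2} </math> lets the similarity factor into a feature map applied separately to queries and keys.

<pre>
import numpy as np

def taylor_feature_map(x):
    # phi(x) = [1, x, (x outer x) / sqrt(2)], flattened, so that
    # phi(q) . phi(k) = 1 + q.k + (q.k)^2 / 2  (truncated Taylor series of exp).
    d = x.shape[-1]
    quad = np.einsum('...i,...j->...ij', x, x).reshape(*x.shape[:-1], d * d)
    ones = np.ones_like(x[..., :1])
    return np.concatenate([ones, x, quad / np.sqrt(2)], axis=-1)

# sanity check: phi(q) . phi(k) matches the truncated Taylor series of exp(q . k)
q, k = np.random.randn(4), np.random.randn(4)
s = q @ k
print(np.isclose(taylor_feature_map(q) @ taylor_feature_map(k), 1 + s + s**2 / 2))  # True
</pre>

Because the feature map is applied independently to queries and keys, <math>\phi(K)^T V</math> can be computed once and reused for every query, which is what gives the linear-time, fixed-state behaviour during decoding.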
=== Accomplishments === | |||
* '''Improved Recall''': | |||
** BASED outperforms Mamba by up to 10.36 accuracy points on recall-intensive tasks. | |||
** Recovers over 90% of softmax attention’s recall performance while using significantly less memory. | |||
* '''High Throughput''': | |||
** Achieves up to 24× higher generation throughput compared to FlashAttention-2. | |||
** Competitive wall-clock time due to efficient CUDA kernel design. | |||
* '''Strong Language Modeling''': | |||
** Matches or surpasses models like Mamba and Hyena in perplexity and downstream task accuracy. | |||
* '''Theoretical Contributions''': | |||
** Demonstrates that recurrent models require <math>\Omega(N)</math>-bit memory to perform associative recall. | |||
** Proves that BaseConv, a gated convolution model, cannot solve recall tasks in constant layers. | |||
** Shows that the recall-throughput tradeoff is theoretically fundamental. | |||
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 11 Presentation: Simple Linear Attention Language Models Balance the Recall-Throughput Tradeoff== | |||
=== Presented by: === | |||
Yiyuan Yang, Anar Kuatzhan, Chuan Zhang | |||
=== Paper Citation === | |||
Arora, S., Eyuboglu, S., Zhang, M., Timalsina, A., Alberti, S., Zinsley, D., ... & Ré, C. (2024). Simple linear attention language models balance the recall-throughput tradeoff. arXiv preprint arXiv:2402.18668. https://arxiv.org/pdf/2402.18668 | |||
=== Introduction === | |||
"Simple Linear Attention Language Models Balance the Recall-Throughput Tradeoff" is a peer which introduces BASED, an architecture designed to enhance the efficiency of language models by balancing memory consumption and recall abilities. This approach combines Linear Attention with Sliding Window Attention to navigate the tradeoff between state size and recall. | |||
=== Methodology === | |||
The researchers analyzed various architectures to understand the tradeoff between a model's state size and its recall ability. They observed that efficient alternatives to attention, such as H3, Mamba, and RWKV, maintain a fixed-size recurrent state but exhibit limitations in recall performance. To address this, they proposed the BASED architecture, which combines linear attention with sliding window attention. By adjusting the window size and the feature dimension of the linear attention, BASED can navigate the Pareto frontier of the recall-memory tradeoff, effectively balancing recall quality and state size. | |||
=== Empirical results === | |||
The study trained language models with up to 1.3 billion parameters and found that BASED matches the perplexity of leading sub-quadratic models like Mamba. Furthermore, BASED outperformed these models on real-world recall-intensive tasks by 6.22 accuracy points. Additionally, the implementation of input/output-aware algorithms enabled BASED to achieve 24 times higher throughput in language generation compared to FlashAttention-2 when generating 1,024 tokens using 1.3 billion parameter models. | |||
=== Mathematical Explanation of BASED Architecture ===
==== (1) Sliding Window Attention -- SWA ==== | |||
SWA computes attention over a fixed-size window of previous tokens, capturing local dependencies. For a window of size <math> w </math>, the attention for a token <math> t </math> considers only the tokens in <math> [t-w, t-1] </math>.
Given queries, keys and values <math> Q, K, V \in \mathbb{R}^{n \times d} </math> for a sequence of length <math> n </math> and hidden dimension <math> d </math>, the attention output <math> A </math> is computed as follows:
<math> A_t = \text{softmax} \Bigg( \frac{Q_t K_{t-w : t-1}^{T}}{\sqrt{d}} \Bigg) V_{t-w : t-1} </math>
Where <math> A_t </math> is the attention output at position <math> t </math>. | |||
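A minimal sketch of this local attention (our own toy code, assuming the window convention <math>[t-w, t-1]</math> above; the first token, whose window is empty, is simply left as zeros):

<pre>
import numpy as np

def sliding_window_attention(Q, K, V, w):
    # Q, K, V: (n, d). Token t attends only to the previous w tokens.
    n, d = Q.shape
    out = np.zeros_like(V)
    for t in range(1, n):
        lo = max(0, t - w)
        scores = Q[t] @ K[lo:t].T / np.sqrt(d)     # exact softmax scores in the window
        weights = np.exp(scores - scores.max())
        weights /= weights.sum()
        out[t] = weights @ V[lo:t]
    return out
</pre>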
==== (2) Linear Attention ==== | |||
Linear attention approximates standard attention mechanism to capture global dependencies with reduced computational complexity. It redefines the attention operation to be linear in the sequence length using feature maps <math> \phi </math> to project queries and keys. | |||
The linear attention output is computed as: | |||
<math> A = \phi(Q) \big( \phi(K)^T V \big) </math> | |||
Where <math> \phi </math> is a feature map function applied to the queries and keys. This formulation allows the attention computation to be rearranged and optimized, reducing the complexity to <math> O(n) </math>. | |||
==== (3) Combining Sliding Window Attention and Linear Attention ==== | |||
BASED integrates SWA and linear attention to leverage the strength of both methods. SWA captures fine-grained local dependencies while Linear Attention models long-range dependencies. | |||
By simply adjusting the sliding window size <math> w </math> and the feature dimension <math> d </math> in linear attention, BASED can navigate the tradeoff between memory consumption and the ability to recall information. A larger <math> w </math> enhances local context capture but increases memory usage, whereas a higher <math> d </math> improves global context understanding with minimal memory overhead.
=== Application and Performance (brief overview) === | |||
In this paper, BASED models were trained on up to 1.3 billion parameters and evaluated on tasks requiring high recall. This included tasks in information extraction and reading comprehension. The architecture demonstrated performance matching or in some instances surpassing other sub-quadratic models such as MAMBA. Notably, it excelled in recall-intensive situations. | |||
Implementations of linear attention often lag behind optimized standard attention in efficiency. To address this, the authors developed I/O-aware algorithms, enabling BASED to achieve 24x higher throughput in language generation compared to other methods such as FlashAttention-2.
=== Conclusion === | |||
The BASED architecture offers a pragmatic solution to the recall and throughput tradeoff in language models by combining sliding window and linear attention mechanisms. | |||
This integration has allowed for efficient handling of both local and global dependencies. Subsequently this has resulted in models that are both memory efficient and able to perform high recall tasks, thereby advancing the development of more efficient LLM techniques. | |||
BASED is the first linear attention model shown to match or beat Mamba on: | |||
*Perplexity (language modeling quality) | |||
*Real-world recall benchmarks (copying, in-context learning) | |||
This challenges the growing belief that attention-free models like Mamba are the most scalable path forward. | |||
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 11 Presentation: Simple Linear Attention Language Models Balance the Recall-Throughput Tradeoff== | |||
=== Presented by: === | |||
Yiyuan Yang, Anar Kuatzhan, Chuan Zhang | |||
=== Paper Citation === | |||
Arora, S., Eyuboglu, S., Zhang, M., Timalsina, A., Alberti, S., Zinsley, D., ... & Ré, C. (2024). Simple linear attention language models balance the recall-throughput tradeoff. arXiv preprint arXiv:2402.18668. https://arxiv.org/pdf/2402.18668 | |||
=== Background & Motivation === | |||
Transformer-based language models rely on attention mechanisms that require storing increasing amounts of key-value pairs (KV-cache) during inference. This makes them memory-intensive and less suitable for real-time or resource-constrained applications. The paper investigates whether it's possible to reduce memory usage while maintaining strong contextual recall capabilities—hence the "recall-throughput tradeoff." | |||
=== Methodology === | |||
1. '''Based Architecture''': | |||
* '''Linear Attention''': Uses a second-order Taylor approximation of softmax to maintain global token interactions with a fixed-size recurrent state. | |||
* '''Sliding Window Attention (SWA)''': Applies exact softmax attention locally in small windows (64-128 tokens) to handle precise local shifts. This combination allows Based to navigate the recall-throughput tradeoff effectively. | |||
* '''IO-Aware Optimizations''': Custom CUDA kernels reduce memory movement and improve hardware efficiency, enabling 24× higher throughput than FlashAttention-2 during generation. | |||
* '''The BASED Model''': BASED (Bidirectional Attention with Stable Expansion and Delay) is proposed as a simple and efficient linear attention architecture. Its defining traits include: | |||
** Linear complexity with respect to sequence length. | |||
** No KV-cache required during inference, unlike transformers. | |||
** Introduces a memory state updated recurrently across tokens. | |||
** Achieves bidirectional context modeling using a fixed-size memory block. | |||
This makes BASED models more efficient for both training and inference, especially in streaming or real-time settings. | |||
2. '''Theoretical and Empirical Analysis''': | |||
* '''Lower Bounds''': The paper proves that any recurrent model requires Ω(N)-bits in state size to solve associative recall, highlighting the fundamental tradeoff. | |||
* '''Empirical Results''': Experiments on synthetic and real-world tasks (e.g., MQAR, Pile perplexity, information extraction) show Based outperforms Mamba by 10.36 accuracy points on recall-intensive tasks while matching its perplexity. | |||
=== Experimental Results === | |||
* BASED models match or outperform standard transformer models on various language modeling benchmarks such as WikiText-103 and PG-19. | |||
* They show strong performance on long-context tasks, including copy and retrieval tasks, indicating good memory recall. | |||
* BASED demonstrates superior throughput, especially in inference without KV caching. | |||
=== Key Findings === | |||
* Efficient memory usage and fast inference are achievable without sacrificing much performance. | |||
* Linear attention models like BASED can serve as a viable alternative to transformers in memory-constrained or latency-sensitive applications. | |||
* There exists a tradeoff surface between recall and throughput, and BASED models lie on an efficient frontier of that tradeoff. | |||
=== Conclusion === | |||
Based expands the Pareto frontier of the recall-throughput tradeoff by combining simple, well-known techniques (linear and sliding window attention) with hardware-aware optimizations. The results suggest that efficient models can achieve high recall without sacrificing throughput, offering a promising direction for future language model architectures. | |||
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 12 Presentation: EAGLE-2: Faster Inference of Language Models with Dynamic Draft Trees == | |||
=== Presenters === | |||
Mutong Zhang, Hanqing Bi | |||
=== Paper Citation === | |||
Y. Li, F. Wei, C. Zhang, and H. Zhang, ‘EAGLE-2: Faster Inference of Language Models with Dynamic Draft Trees’, Jun. 30, 2024, arXiv: arXiv:2406.16858. doi: 10.48550/arXiv.2406.16858. | |||
https://arxiv.org/abs/2406.16858 | |||
=== Background === | |||
- LLMs to date have shown great performance, but they are slow and computationally expensive | |||
- Speculative sampling - small model generates candidate tokens, whereas large model then evaluates those tokens, reducing the number of times the expensive computations of the large model have to occur | |||
- However, speculative decoding methods suffered from low acceptance rates of draft tokens, as they lacked efficient ways to predict which draft tokens would be accepted, wasting computation on rejected drafts
- EAGLE 1 used a draft tree to improve the performance of speculative execution, and this work builds on the authors' prior work
- The authors observe that acceptance rates are also context-dependent, suggesting the need for adaptive drafting. | |||
=== Paper Contributions=== | |||
- EAGLE 1 doesn't directly predict tokens; rather, it maps tokens to features, predicts the next feature, and then predicts tokens from those features
- This work was shown to improve upon previous speculative execution methods
- EAGLE uses a tree structure to propose alternative tokens when a speculative sampling draft token is rejected by the full model
- EAGLE 2 noted that token acceptance depends on context and position. The first token has a high acceptance rate, and later tokens have lower acceptance rates
- EAGLE 2 uses a dynamic draft tree which incorporates "tree attention" to bring context information into selecting the next candidate tokens, increasing their acceptance rate, since acceptance depends on context and not only on position
=== Summaries of Key Points=== | |||
'''Addressing the Challenge of Slow LLM Inference''' | |||
Large language models (LLMs) have revolutionized natural language processing, but their inference remains computationally expensive and slow. This bottleneck arises due to the vast number of model parameters and the sequential nature of token generation. Improving inference efficiency without sacrificing accuracy is a critical research direction in the field. | |||
'''The Foundation: Speculative Sampling and Draft Models''' | |||
EAGLE-2 builds on speculative sampling, a technique that leverages a smaller, lightweight draft model to propose candidate tokens, which are then verified by the full LLM. This approach speeds up inference by reducing the number of computations performed by the large model. | |||
'''Evolution from EAGLE-1 to EAGLE-2''' | |||
The original EAGLE-1 introduced a static draft tree, assuming that token acceptance rates depend solely on their position within the tree structure. However, this assumption overlooks the context-dependent nature of token acceptance. EAGLE-2 addresses this limitation by introducing a dynamic draft tree adjustment mechanism, which refines speculative sampling based on real-time confidence scores. | |||
'''Key Mechanisms of EAGLE-2''' | |||
1. Instead of directly predicting tokens, EAGLE-2’s draft model generates feature representations, which are then processed by the head of the LLM to produce token predictions. This method enhances accuracy while maintaining efficiency. | |||
2. EAGLE-2 employs a tree-based verification mechanism, which is computationally more efficient than the standard chain-structured verification used in traditional speculative sampling. By verifying multiple candidate tokens in parallel, it accelerates the overall process. | |||
3. A core innovation in EAGLE-2 is its ability to dynamically modify the draft tree based on confidence scores from the draft model. This ensures that speculative sampling adapts to varying contexts, improving token acceptance rates and overall inference speed. | |||
'''Dynamic Expansion and Re-Ranking''' | |||
1. Expansion Phase: EAGLE-2 introduces a novel expansion phase, where a tree-attention mechanism processes all tokens in a layer simultaneously. This significantly enhances efficiency compared to sequential processing. Additionally, selective expansion prioritizes only the top-k tokens with the highest estimated global acceptance probabilities, preventing unnecessary computational overhead. | |||
2. Re-Ranking Phase: Following expansion, EAGLE-2 reranks candidate tokens by selecting the top-m tokens with the highest acceptance probabilities. In cases where multiple tokens have similar scores, shallower nodes in the draft tree are prioritized, further optimizing verification speed. | |||
During the reranking phase, EAGLE-2 ensures that the most promising tokens are selected not only based on their depth in the draft tree but also on their overall likelihood of being accepted. Since deeper nodes in the draft tree tend to have lower values due to the multiplication of multiple acceptance probabilities, some shallow nodes—though not expanded in the previous phase—may have higher values. To optimize performance, EAGLE-2 reranks all candidate tokens (including both shallow and deep nodes) and selects the top m tokens with the highest values. Importantly, when multiple tokens share the same value, shallower nodes are prioritized to preserve tree connectivity and maximize verification efficiency. This strategy improves both the average acceptance length and the speedup ratio, as confirmed by the ablation studies. | |||
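A rough sketch of these two phases (our own simplification with a toy draft_topk interface, not the released implementation): each node's "value" is the product of confidence scores along its path, used as a proxy for the probability that the whole branch is accepted.

<pre>
import heapq

def build_draft_tree(draft_topk, top_k=4, top_m=8, depth=3):
    # draft_topk(path) -> list of (token, confidence) pairs proposed for that prefix
    # (a toy stand-in for one forward pass of the draft model).
    all_nodes = []                            # every drafted (path, value) pair
    frontier = [((), 1.0)]                    # start from the root with value 1
    for _ in range(depth):                    # --- expansion phase ---
        children = [(path + (tok,), value * conf)
                    for path, value in frontier
                    for tok, conf in draft_topk(path)]
        all_nodes.extend(children)
        frontier = heapq.nlargest(top_k, children, key=lambda c: c[1])
    # --- re-ranking phase: keep the top-m tokens overall, shallower nodes on ties ---
    return heapq.nlargest(top_m, all_nodes, key=lambda c: (c[1], -len(c[0])))
</pre>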
'''Experimental Results: Significant Performance Gains''' | |||
EAGLE-2 achieved acceleration ratios of 3.05x - 4.26x across various tasks and large language model series such as Vicuna, LLaMA 2 and LLaMA 3, making it 20% - 40% faster than EAGLE-1. It is also approximately 2 times faster than Medusa and 2.3 times faster than Lookahead. On token throughput, EAGLE-2 processes 4-5.5 tokens per verification cycle, about twice as many as traditional speculative sampling.
'''Key Advantages of EAGLE-2''' | |||
1. Plug-and-Play Efficiency – EAGLE-2 requires no additional model training, as it seamlessly integrates the pre-trained draft model from EAGLE-1. It does not alter the original LLM, and maintains the exact same output distribution as greedy decoding (i.e., it is lossless). | |||
2. Robust and Reliable – Unlike some acceleration techniques, EAGLE-2 does not modify the original model parameters or relax acceptance conditions, ensuring stable and consistent outputs. | |||
3. Broad Generalization – The framework generalizes well across different tasks and architectures, demonstrating strong adaptability in diverse applications. | |||
EAGLE-2 represents a significant advancement in accelerating LLM inference. By introducing a dynamic draft tree, efficient expansion strategies, and intelligent token re-ranking, it substantially reduces computational costs while maintaining accuracy. As large-scale models continue to grow, techniques like EAGLE-2 will be instrumental in making LLMs more practical and accessible for real-world applications. | |||
=== Accomplishments === | |||
* '''State-of-the-art inference speedup''': | |||
** EAGLE-2 achieves up to '''5×''' speedup in language model inference. | |||
** Outperforms prior speculative decoding methods including EAGLE, Medusa, Lookahead, and others. | |||
* '''Longest average acceptance length''': | |||
** EAGLE-2 generates longer sequences per accepted draft, reducing the number of calls to the target model. | |||
* '''Wide applicability''': | |||
** Tested across '''6 diverse tasks''', including: | |||
*** Conversation | |||
*** Code generation | |||
*** Mathematical reasoning | |||
*** Question answering | |||
*** Summarization | |||
*** Instruction following | |||
* '''Model-agnostic design''': | |||
** Compatible with popular LLMs such as: | |||
*** Vicuna | |||
*** LLaMA2-Chat | |||
*** LLaMA3-Instruct | |||
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 12 Presentation: EAGLE-2: Faster Inference of Language Models with Dynamic Draft Trees == | |||
=== Presenters === | |||
Mutong Zhang, Hanqing Bi | |||
=== Paper Citation === | |||
Y. Li, F. Wei, C. Zhang, and H. Zhang, ‘EAGLE-2: Faster Inference of Language Models with Dynamic Draft Trees’, Jun. 30, 2024, arXiv: arXiv:2406.16858. doi: 10.48550/arXiv.2406.16858. | |||
https://arxiv.org/abs/2406.16858 | |||
=== Constructive Critique and Review === | |||
The paper “EAGLE-2: Faster Inference of Language Models with Dynamic Draft Trees” introduces an innovative approach to enhancing the efficiency of Large Language Model (LLM) inference through the implementation of a context-aware dynamic draft tree. | |||
This paper makes two main contributions to the field: | |||
==== 1). Dynamic Draft Tree Structure ==== | |||
Building upon the original EAGLE framework, EAGLE-2 replaces the static draft tree with a dynamic architecture that adapts to context. This adjustment acknowledges that the acceptance rate of draft tokens is influenced not only by their position but also by the surrounding context, which leads to more efficient token generation.
==== 2). Utilization of Well Calibrated Draft Models ==== | |||
The paper also reveals that the draft model's confidence scores closely approximate the acceptance rates of draft tokens. By leveraging this calibration, EAGLE-2 effectively predicts which tokens are more likely to be accepted, optimizing the entire drafting process and token generation.
=== Performance Outcomes === | |||
Extensive evaluations have been conducted across three series of LLMs, namely Vicuna, LLaMA2-Chat and LLaMA3-Instruct, as well as on six diverse tasks, including multi-turn conversation, code generation and mathematical reasoning.
The results reveal that EAGLE-2 achieves speedup ratios of 3.05x to 4.26x, nearly a 20% to 40% improvement over EAGLE-1. Notably, this acceleration is achieved without altering the distribution of the generated text, ensuring the fidelity of the model's outputs. | |||
=== Advancements within the field === | |||
EAGLE-2 makes significant advancements in LLM inference optimization. By introducing a context-aware dynamic draft tree, the paper addresses the limitations of the previous speculative sampling architecture of EAGLE-1, which is static in nature.
This innovation enhances the acceptance rate of draft tokens, thereby reducing inference latency and computational costs. Additionally, the approach maintains the integrity of the generated text distribution, which distinguishes it from other acceleration methods that compromise output quality.
=== Conclusion === | |||
The methodologies and findings presented in this paper offer a substantial contribution to the field of ML and most notably in optimizing the efficiency of the generative LLMs. The introduction of dynamic, context-aware drafting mechanisms sets a new benchmark for speculative sampling techniques paving the way for more responsive, fast, and cost-effective LLM applications. | |||
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 12 Presentation: EAGLE-2: Faster Inference of Language Models with Dynamic Draft Trees == | |||
=== Presenters === | |||
Mutong Zhang, Hanqing Bi | |||
=== Paper Citation === | |||
Y. Li, F. Wei, C. Zhang, and H. Zhang, ‘EAGLE-2: Faster Inference of Language Models with Dynamic Draft Trees’, Jun. 30, 2024, arXiv: arXiv:2406.16858. doi: 10.48550/arXiv.2406.16858. | |||
https://arxiv.org/abs/2406.16858 | |||
=== Summaries === | |||
EAGLE-2 improves speculative sampling by building on EAGLE. It constructs a dynamic tree as the draft structure based on context (acceptance rates). To reduce the cost of computing acceptance rates, the draft model's confidence score, which EAGLE already produces, is used to approximate the acceptance rate. EAGLE-2 does not require extra training, and there is no change in the distribution of generated text compared with the original LLM, since it does not modify the original LLM.
=== Key Contributions === | |||
EAGLE-2 follows the same procedure as EAGLE. EAGLE-2 constructs a tree draft model based on context (acceptance rate), where the acceptance rate is approximated by the confidence score from EAGLE. | |||
'''Expansion Phase''': The best tokens (top-k tokens) from the last layer are used to predict tokens in the next layer. | |||
'''Re-ranking''': Tokens with the highest acceptance rates (top-m tokens) are selected and used in the verifying phase. | |||
'''Metrics''': Speedup ratio and average acceptance length are used to evaluate the model. | |||
=== Explanations of details === | |||
'''Why a dynamic tree is used''' | |||
The acceptance rates of tokens at different positions were tested. It was noted that the acceptance rates of tokens were higher in the upper-left part of the tree and lower in the lower-right part. Also, the acceptance rates at the same position varied significantly across contexts. Therefore, the authors concluded that it is worthwhile to build a dynamic tree based on context.
'''Why use an approximation of acceptance rates''' | |||
To calculate the real acceptance rates of tokens, we need to do a forward pass in the original LLM, which is costly. The authors found that there was a positive relationship between the confidence score and the acceptance rate. | |||
=== Related Works === | |||
For LLMs, the existing works that try to accelerate LLM inference include: low-bit quantization (Hubara et al., 2018; Shen et al., 2020; Kim et al., 2021; Zadeh et al., 2020; Zafrir et al., 2019), pruning (Gale et al., 2019; Sanh et al., 2020), distillation (Hinton et al., 2015). These methods often decrease the quality of the output. | |||
Speculative sampling (Leviathan et al., 2023; Chen et al., 2023a) tries to reduce the costs in the decoding process while preserving the output quality. It involves two phases: the generating phase that generates multiple tokens and the verifying phase. It uses a chain-structured draft model. | |||
EAGLE has improved the speculative sampling by autoregressively predicting features and using both features and tokens to reduce uncertainties. EAGLE uses a static tree structure as the draft model. Medusa (Cai et al., 2024) also uses a tree structure. They all pick candidates at specific step. | |||
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 12 Presentation: EAGLE-2: Faster Inference of Language Models with Dynamic Draft Trees == | |||
=== Presenters === | |||
Mutong Zhang, Hanqing Bi | |||
=== Paper Citation === | |||
Y. Li, F. Wei, C. Zhang, and H. Zhang, ‘EAGLE-2: Faster Inference of Language Models with Dynamic Draft Trees’, Jun. 30, 2024, arXiv: arXiv:2406.16858. doi: 10.48550/arXiv.2406.16858. | |||
=== Background === | |||
Big language models (LLMs) are changing the game in AI, but they come with a huge downside: they take a lot of time and computing power to run. Making them faster and cheaper without losing quality is a major challenge. One popular way to speed things up is speculative decoding, where a smaller, faster model makes guesses to help the bigger model generate text more efficiently. While this approach works, it still struggles with getting the right balance between speed, accuracy, and ease of use. | |||
=== Main Idea === | |||
EAGLE 2 is a new and improved speculative decoding method that makes LLMs way faster without sacrificing quality. The trick is using a small, efficient model to suggest multiple possible next words at the same time, then having the big model quickly check and accept as many as possible. Unlike older speculative decoding techniques that focus on just one guess at a time, EAGLE 2 smartly boosts the number of accepted tokens using a better verification process. It also fine-tunes how guesses are made and rejected, cutting down on wasted computations and speeding things up even more. | |||
=== Experiments === | |||
The researchers ran a bunch of tests to see how well EAGLE 2 performs across different LLMs and datasets. They looked at: | |||
• How it compares to standard autoregressive decoding and older speculative decoding methods. | |||
• How well it works on general NLP tasks like text generation and summarization, as well as more specific datasets. | |||
• Key metrics like speed improvements, how often the model accepts the suggested words, and how good the final output is. | |||
=== Results === | |||
EAGLE 2 delivered big improvements in speed without lowering output quality. Key takeaways include: | |||
• Higher Acceptance Rates: More of the small model’s suggested words were accepted, meaning fewer wasted computations. | |||
• Faster Performance: Compared to the usual autoregressive decoding method, EAGLE 2 sped things up by 2 to 4 times. | |||
• Same Great Quality: Even with the increased efficiency, the text output was just as good as before, based on standard evaluation methods. | |||
=== Conclusion === | |||
EAGLE 2 makes large language models run much faster while keeping their responses accurate and high-quality. This makes it a strong candidate for real-world applications where speed and cost matter. With further improvements, it could become even more efficient and be applied to even larger-scale AI systems. | |||
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 12 Presentation: EAGLE-2: Faster Inference of Language Models With Dynamic Draft Trees == | |||
=== Presented by === | |||
Mutong Zhang and Hanqing Bi | |||
=== Paper Citation === | |||
Y. Li, F. Wei, C. Zhang, and H. Zhang, ‘EAGLE-2: Faster Inference of Language Models with Dynamic Draft Trees’, Jun. 30, 2024, arXiv: arXiv:2406.16858. doi: 10.48550/arXiv.2406.16858. | |||
https://arxiv.org/abs/2406.16858 | |||
=== Background === | |||
- Large Language Models (LLMs) are computationally intensive during inference due to their substantial parameter sizes, leading to high latency and resource consumption. | |||
- Speculative sampling is a technique aimed at accelerating LLM inference by generating multiple tokens in parallel. This technique was employed in EAGLE-1, a model that uses a static draft tree, where token acceptance is predetermined by position. However, this approach overlooks the context-dependent nature of token acceptance rates.
- EAGLE-2 addresses these limitations by introducing a dynamic adjustment mechanism for the draft tree, and it enhances efficiency by adapting to contextual variations | |||
=== Summaries of Key Points === | |||
- EAGLE-2 recognizes that token acceptance rates are influenced by context, not just position, and it performs context-aware adaptations in order to dynamically adjust the draft tree. By leveraging the draft model’s confidence scores, which closely approximate actual acceptance probabilities, the draft tree dynamically adjusts, focusing computational resources where they are most effective.
- The dynamic draft tree is constructed through expansion and re-ranking phases that selectively generate and prioritize token branches based on their likelihood of acceptance, improving efficiency and reducing redundant computation. | |||
- In terms of performance, EAGLE-2 achieves significant speedups without compromising the integrity of the generated text, maintaining the original output distribution. Additionally, it results in an increase in throughput, processing almost twice as many tokens per cycle compared to traditional methods. Ablation studies demonstrate that EAGLE-2 offers higher speedups and better acceptance rates than EAGLE-1, showing the benefit of the dynamic draft tree.
- The advantages of EAGLE-2 are that no additional training is required and there is consistency in text generation even when using accelerated inference. EAGLE-2 also exhibits consistent performance improvements across a diverse range of tasks. | |||
=== Explanation of Expansion and Re-ranking phases === | |||
- To construct the dynamic draft tree efficiently, the EAGLE-2 model employs an expansion phase and a re-ranking phase. | |||
- In the expansion phase, a tree attention mechanism guides the selective growth of the tree by identifying and expanding the most promising nodes. It uses confidence scores to prioritize branches with higher token acceptance probabilities. This stage ensures computational resources are focused on likely candidates, reducing unnecessary expansion. | |||
- In the re-ranking phase, draft tokens are re-evaluated using the target model, and the top-m tokens with the highest probabilities are selected to ensure optimal token generation sequences. This preserves generation quality while maintaining consistency in token ordering and ensuring child nodes follow their parent nodes in the final output sequence. | |||
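To make the two phases concrete, below is a minimal, self-contained sketch (an illustration of the idea described above, not the authors' implementation) of confidence-guided expansion followed by re-ranking; the draft-model interface and all numeric values are hypothetical.
<pre>
# Toy sketch of confidence-guided draft-tree expansion and re-ranking (illustrative only).
import heapq

def expand(root_candidates, draft_confidence, depth=3, branch=2, beam=4):
    """Grow a draft tree by repeatedly expanding the most confident partial drafts.
    root_candidates: list of (token, confidence) pairs for the first position.
    draft_confidence: callable(prefix) -> list of (token, confidence) continuations
                      (a hypothetical stand-in for the draft model)."""
    beams = [(-c, [t]) for t, c in root_candidates]   # negative cumulative confidence
    drafts = list(beams)
    for _ in range(depth):
        next_beams = []
        # Expansion phase: only the `beam` most promising partial drafts are grown further.
        for neg_c, seq in heapq.nsmallest(beam, beams):
            for tok, conf in draft_confidence(seq)[:branch]:
                # Cumulative confidence approximates the joint acceptance probability.
                next_beams.append((neg_c * conf, seq + [tok]))
        beams = next_beams or beams
        drafts.extend(next_beams)
    return [(-c, seq) for c, seq in drafts]

def rerank(drafts, top_m=8):
    """Re-ranking phase: keep the top-m drafts by cumulative confidence."""
    return sorted(drafts, key=lambda d: d[0], reverse=True)[:top_m]

# Example with a dummy draft model that always proposes two continuations.
dummy = lambda prefix: [(len(prefix) + 1, 0.9), (len(prefix) + 2, 0.4)]
print(rerank(expand([(1, 0.8), (2, 0.3)], dummy)))
</pre>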
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 13 Presentation: Linear Attention Mechanism: An Efficient Attention for Semantic Segmentation == | |||
=== Presented By === | |||
Yuke Liu, Mei Si | |||
=== Paper Citation === | |||
R. Li, J. Su, C. Duan, and S. Zheng, ‘Linear Attention Mechanism: An Efficient Attention for Semantic Segmentation’, Aug. 20, 2020, arXiv: arXiv:2007.14902. doi: 10.48550/arXiv.2007.14902. | |||
https://arxiv.org/abs/2007.14902 | |||
=== Background === | |||
- Existing transformer models have <math>\mathcal{O} (n^2) </math> complexity, which is problematic as model size grows | |||
- This limits the growth of model sizes due to computational resource constraints
- This paper focused on an alternative method to conventional dot product attention that is more computationally efficient | |||
- Standard attention requires the computation of <math> Q K^\top </math>, which has <math>\mathcal{O} (n^2) </math> complexity
- The paper proposes a linear attention mechanism that solves the problem while keeping the performance. | |||
=== Technical Contributions === | |||
Rather than doing the full computation for the softmax in the transformer architecture, the authors instead compute | |||
<math> D(\mathbf{Q}, \mathbf{K}, \mathbf{V})_i = | |||
\frac{\sum_{j=1}^{N} e^{\mathbf{q}_i^{T} \mathbf{k}_j} \mathbf{v}_j}{\sum_{j=1}^{N} e^{\mathbf{q}_i^{T} \mathbf{k}_j}} = \frac{\sum_{j=1}^{N} \text{sim}(\mathbf{q}_i, \mathbf{k}_j) \mathbf{v}_j}{\sum_{j=1}^{N} \text{sim}(\mathbf{q}_i, \mathbf{k}_j)} </math> | |||
and define the transformation function as | |||
<math>\text{sim}(\mathbf{q}_i, \mathbf{k}_j) = \phi(\mathbf{q}_i)^{T} \phi(\mathbf{k}_j) </math> | |||
The authors apply a first-order Taylor series expansion, and after some rearranging and substitution arrive at their final formula (full derivation not shown)
<math>D(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = | |||
\frac{ | |||
\sum_j \mathbf{v}_{i,j} + \left( \frac{\mathbf{Q}}{\lVert \mathbf{Q} \rVert_2} \right) \left( \left( \frac{\mathbf{K}}{\lVert \mathbf{K} \rVert_2} \right)^{T} \mathbf{V} \right) }{ N + \left( \frac{\mathbf{Q}}{\lVert \mathbf{Q} \rVert_2} \right) \sum_j \left( \frac{\mathbf{K}}{\lVert \mathbf{K} \rVert_2} \right)_{i,j}^{T} } </math> | |||
The authors' form of the attention mechanism can be computed in <math>\mathcal{O} (n) </math> complexity, removing the quadratic computational scaling of conventional transformers and enabling the creation of larger models built on the underlying attention mechanism. <br>
*Hybrid Attention Mechanisms for Selective Efficiency <br> | |||
Idea: Combine linear attention with local or sparse attention in a hybrid structure where the system dynamically chooses the best mechanism depending on context (e.g., object density, resolution, or modality). The rationale is that the dense areas of an image (e.g., urban environments) may benefit from sparse or local attention for spatial precision. Less complex regions (e.g., sky, ocean) can rely on linear attention for efficiency. It could be implemented with a gating module that selects or blends attention types. | |||
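As a concrete illustration of how the linearized form above is computed in <math>\mathcal{O}(n)</math> time, here is a minimal NumPy sketch (not the authors' code); row-wise L2 normalization plays the role of the feature map, and the key–value summary is computed once and reused for every query.
<pre>
import numpy as np

def linear_attention(Q, K, V, eps=1e-6):
    """O(N) attention following the Taylor-expansion formulation above.
    Q, K, V: arrays of shape (N, d). Returns an (N, d) array."""
    N = Q.shape[0]
    # Row-wise L2 normalization keeps the first-order similarity non-negative.
    Qn = Q / (np.linalg.norm(Q, axis=1, keepdims=True) + eps)
    Kn = K / (np.linalg.norm(K, axis=1, keepdims=True) + eps)
    KV = Kn.T @ V                                         # (d, d) summary instead of an (N, N) matrix
    numerator = V.sum(axis=0, keepdims=True) + Qn @ KV    # (N, d)
    denominator = N + Qn @ Kn.sum(axis=0)                 # (N,)
    return numerator / denominator[:, None]

# Example: 1,000 tokens with 64-dimensional queries, keys, and values.
rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((1000, 64)) for _ in range(3))
out = linear_attention(Q, K, V)                           # shape (1000, 64)
</pre>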
=== Model Performance Evaluation === | |||
The model performance evaluation primarily focused on assessing how well the proposed linear attention mechanism enhances semantic segmentation performance while maintaining computational efficiency. | |||
- Dataset: The experiments were conducted using the Fine Grained Image Dataset, which consists of high-resolution satellite images. This dataset was selected because its complex landscapes and varied environmental conditions present significant segmentation challenges.
- Data Preprocessing: Due to the large size of the images in the dataset, each image was divided into smaller patches of 256x256 pixels, resulting in a total of 7,280 patches. | |||
- Data Split: These patches were partitioned into 60% for training, 20% for validation, and 20% for testing. This split ensured a rigorous evaluation of the model's performance across different scenarios.
- The standard dot-product attention is replaced by the linear attention mechanism in baseline segmentation models such as U-Net, Res101, DeepLab, and PSPNet.
- Implementation Details: The experiments were implemented using PyTorch framework and trained on an NVIDIA RTX 2080Ti GPU. | |||
- Evaluation Metrics: OA, AA, K, mIoU, F1.
=== Comparative Analysis === | |||
The proposed linear attention mechanism achieves a similar level of accuracy in semantic segmentation tasks compared to traditional dot product attention. | |||
Linear attention maintains comparable performance metrics while drastically reducing both memory footprint and processing time required. | |||
The efficiency gains of linear attention become more apparent in large-scale segmentation tasks, making it suitable for high-resolution images and long text sequences. | |||
Future work includes extending the linear attention mechanism to more complex network designs, integrating with state-of-the-art deep learning models, and optimizing for real-time processing in demanding applications. | |||
<table border="1" cellspacing="0" cellpadding="6">
<tr>
<th>Method</th>
<th>OA</th> | |||
<th>AA</th> | |||
<th>K</th> | |||
<th>mIoU</th> | |||
<th>F1</th> | |||
</tr> | |||
<tr> | |||
<td>U-Net</td> | |||
<td>86.378</td> | |||
<td>74.532</td> | |||
<td>83.357</td> | |||
<td>64.516</td> | |||
<td>75.532</td> | |||
</tr> | |||
<tr> | |||
<td>U-Net LAM</td> | |||
<td>87.692</td> | |||
<td>77.297</td> | |||
<td>84.935</td> | |||
<td>68.038</td> | |||
<td>78.593</td> | |||
</tr> | |||
<tr> | |||
<td>Res101</td> | |||
<td>89.251</td> | |||
<td>80.451</td> | |||
<td>86.846</td> | |||
<td>72.433</td> | |||
<td>82.510</td> | |||
</tr> | |||
<tr> | |||
<td>Res101 LAM</td> | |||
<td>90.178</td> | |||
<td>82.757</td> | |||
<td>88.041</td> | |||
<td>74.085</td> | |||
<td>83.105</td> | |||
</tr> | |||
<tr> | |||
<td>RefineNet</td> | |||
<td>89.857</td> | |||
<td>81.169</td> | |||
<td>87.597</td> | |||
<td>73.167</td> | |||
<td>83.113</td> | |||
</tr> | |||
<tr> | |||
<td>RefineNet LAM</td> | |||
<td>90.214</td> | |||
<td>83.544</td> | |||
<td>88.083</td> | |||
<td>74.973</td> | |||
<td>84.311</td> | |||
</tr> | |||
<tr> | |||
<td>DeepLab</td> | |||
<td>89.388</td> | |||
<td>80.905</td> | |||
<td>87.079</td> | |||
<td>71.809</td> | |||
<td>81.077</td> | |||
</tr> | |||
<tr> | |||
<td>DeepLab LAM</td> | |||
<td>89.576</td> | |||
<td>81.692</td> | |||
<td>87.287</td> | |||
<td>72.827</td> | |||
<td>82.702</td> | |||
</tr> | |||
<tr> | |||
<td>DeepLabV3+</td> | |||
<td>90.125</td> | |||
<td>81.483</td> | |||
<td>87.959</td> | |||
<td>72.668</td> | |||
<td>81.492</td> | |||
</tr> | |||
<tr> | |||
<td>DeepLabV3+ LAM</td> | |||
<td>90.315</td> | |||
<td>81.695</td> | |||
<td>88.182</td> | |||
<td>73.727</td> | |||
<td>82.736</td> | |||
</tr> | |||
<tr> | |||
<td>PSPNet</td> | |||
<td>90.573</td> | |||
<td>82.211</td> | |||
<td>88.485</td> | |||
<td>74.797</td> | |||
<td>83.761</td> | |||
</tr> | |||
<tr> | |||
<td>PSPNet LAM</td> | |||
<td>90.725</td> | |||
<td>83.088</td> | |||
<td>88.677</td> | |||
<td>75.695</td> | |||
<td>84.480</td> | |||
</tr> | |||
<tr> | |||
<td>FastFCN</td> | |||
<td>90.336</td> | |||
<td>83.625</td> | |||
<td>88.221</td> | |||
<td>74.364</td> | |||
<td>83.704</td> | |||
</tr> | |||
<tr> | |||
<td>FastFCN LAM</td> | |||
<td>90.835</td> | |||
<td>83.075</td> | |||
<td>88.769</td> | |||
<td>75.174</td> | |||
<td>84.023</td> | |||
</tr> | |||
</table> | |||
=== Constructive Critiques and Reviews === | |||
The presenters thoroughly explained the precursor to the proposed Linear Attention Mechanism and its potential in computer vision applications. However, it would be helpful to show individual samples as well, showcasing any difference in visual accuracy (if any) and where this approach may struggle in other computer vision tasks. While the proposed solution clearly brings advantages, it is often the case that certain computer vision sub-tasks benefit whereas others show inferior results.
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 13 Presentation: Linear Attention Mechanism: An Efficient Attention for Semantic Segmentation == | |||
=== Presented By === | |||
Yuke Liu, Mei Si | |||
=== Paper Citation === | |||
R. Li, J. Su, C. Duan, and S. Zheng, ‘Linear Attention Mechanism: An Efficient Attention for Semantic Segmentation’, Aug. 20, 2020, arXiv: arXiv:2007.14902. doi: 10.48550/arXiv.2007.14902. | |||
https://arxiv.org/abs/2007.14902 | |||
=== Introduction === | |||
(1) Attention Mechanism | |||
The attention mechanism gained prominence for its capacity to refine feature maps and to capture long-range dependencies in both computer vision and natural language processing. By focusing computational resources on the most relevant portions of the input data, attention mechanisms have proven especially valuable in tasks that demand context from large input sizes—such as high-resolution images or lengthy text sequences. | |||
(2) Dot Product Attention | |||
A widely used form of attention is the dot product attention. It forms the core of many state-of-the-art models, including Transformers in NLP and non-local modules in computer vision. Despite its strengths, dot product attention scales poorly (quadratically) with input size in both memory and computational cost. | |||
(3) Problem Statements | |||
Because the dot product attention mechanism requires <math> O(N^2) </math> operations for inputs of length <math> N </math> , deploying it on large inputs—e.g., high-resolution images or long sequences—becomes infeasible. This computational bottleneck has motivated research into more efficient variants of attention, particularly approaches that reduce the quadratic cost. | |||
=== Overview of Attention Mechanisms ===
(1) Scaling Attention | |||
In contrast to dot product attention, “scaling” attention mechanisms (sometimes also called “channel” or “spatial” attention in the literature) aim to emphasize informative features while reducing redundant information. Examples include: | |||
Squeeze-and-Excitation (SE) modules, which learn channel-wise scaling factors. | |||
Convolutional Block Attention Module (CBAM), which extends SE by considering both channel and spatial attention. | |||
These approaches do not necessarily learn long-range dependencies across positions in the same way as dot product attention; instead, they focus on enhancing or diminishing particular channels or spatial locations. | |||
(2) Linear Attention: Taylor Expansion → Linear Time | |||
Recent research efforts address the <math> O(N^2) </math> complexity of dot product attention by re-formulating the similarity function (the key step in attention). One family of methods uses kernel-based approximations, while another employs mathematical expansions. By approximating the exponential term of the softmax function (commonly used in dot product attention) with a first-order Taylor expansion, it is possible to make attention computations linear with respect to <math> N </math>. This insight forms the basis for linear attention mechanisms. | |||
(3) Semantic Segmentation | |||
Semantic segmentation requires dense predictions over every pixel of an image, and capturing global context is crucial. Traditional convolution-based networks, whether they follow a DilatedFCN (e.g., DeepLab, PSPNet) or an Encoder-Decoder (e.g., U-Net) architecture, benefit significantly from attention modules that can encode relationships across spatial locations. However, if the resolution is high, dot product attention quickly becomes prohibitively expensive, motivating more efficient variants like linear attention. | |||
=== Dot Product Attention Details === | |||
(1) Query, Key, and Value | |||
Dot product attention transforms each input feature <math> x_i </math> into three different representations: | |||
Query(<math> q_i </math>) | |||
Key(<math> k_i </math>) | |||
Value(<math> v_i </math>) | |||
The core operation then measures similarity between every query–key pair to weight the contribution of all values in forming an output. | |||
(2) Kernel-Based Approximation | |||
A known strategy to reduce complexity is to view the exponential term in dot product attention as operating in a kernel space, thereby factorizing the softmax attention to achieve lower complexity. Various works replace the explicit softmax with kernel-based transformations, enabling a more efficient computation. | |||
=== Linear Attention Mechanism === | |||
Linear attention leverages a Taylor expansion to approximate the exponential term. By applying additional constraints (such as <math> L_2 </math> normalization to keep the term non-negative), the resulting formulation scales as <math> O(N) </math> rather than <math> O(N^2) </math>. This makes the mechanism practical for high-resolution inputs or very long sequences, significantly broadening the usability of attention in semantic segmentation and beyond. | |||
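To make the linear-time rearrangement concrete, a short derivation sketch: with L2-normalized queries <math>\hat{\mathbf{q}}_i</math> and keys <math>\hat{\mathbf{k}}_j</math>, the first-order approximation <math>e^{\hat{\mathbf{q}}_i^{T}\hat{\mathbf{k}}_j} \approx 1 + \hat{\mathbf{q}}_i^{T}\hat{\mathbf{k}}_j</math> turns the numerator of the attention output into
<math>
\sum_{j=1}^{N}\left(1 + \hat{\mathbf{q}}_i^{T}\hat{\mathbf{k}}_j\right)\mathbf{v}_j
= \sum_{j=1}^{N}\mathbf{v}_j + \left(\sum_{j=1}^{N}\mathbf{v}_j\hat{\mathbf{k}}_j^{T}\right)\hat{\mathbf{q}}_i,
</math>
where both bracketed sums are independent of <math>i</math>, so they are computed once in <math>O(N)</math> and reused for every query; the denominator is treated the same way.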
=== Experimental Settings === | |||
(1) Dataset | |||
Many experiments on linear attention for segmentation are conducted using large, high-resolution datasets. One common benchmark highlighted is a satellite imagery dataset (e.g., Fine Gaofen Image Dataset, GID). These datasets typically comprise large aerial images that are partitioned into patches, with splits for training, validation, and testing. | |||
(2) Model Implementations | |||
Baseline segmentation networks (e.g., PSPNet, DeepLab series, U-Net variants, FastFCN, and RefineNet) integrate the proposed linear attention modules in place of, or alongside, standard attention mechanisms. The training setup often employs standard optimizers (such as Adam), cross-entropy loss, and hardware accelerators (like NVIDIA GPUs). | |||
(3) Evaluation Metrics | |||
Common segmentation metrics include: | |||
Overall Accuracy (OA) | |||
Average Accuracy (AA) | |||
Kappa Coefficient (K) | |||
mean Intersection over Union (mIoU) | |||
F1 Score | |||
===Evaluation metrics === | |||
'''Mean Intersection over Union (mIoU)''' | |||
<math> IoU = \frac{TP}{TP+FP+FN} </math>
<math> mIoU = \frac{1}{N} \sum_{i=1}^{N} IoU_i </math>
*Measures the average overlap between predicted segmentation and ground truth across all classes. | |||
*Commonly used in semantic segmentation benchmarks. | |||
Why It’s Important in This Paper: | |||
*Serves as a primary benchmark for assessing spatial accuracy of LAM. | |||
*Especially meaningful in pixel-level classification, where precise boundaries matter. | |||
*Used to compare LAM-enhanced networks against traditional attention-based and baseline CNN models. | |||
*LAM achieves competitive or improved mIoU across multiple medical image segmentation datasets (DRIVE, STARE, CHASE_DB1), validating its contextual understanding with reduced computational cost. | |||
'''Kappa coefficient''' | |||
<math> \kappa = \frac{p_0 - p_e}{1-p_e} </math> | |||
*Assesses statistical agreement between predicted segmentation and ground truth, adjusting for agreement by chance | |||
Why It’s Important in This Paper: | |||
*Medical segmentation tasks often suffer from class imbalance (e.g., small vessels vs large background). | |||
*Kappa offers a robust metric that accounts for imbalanced distributions, unlike plain accuracy. | |||
*The paper reports high Kappa values for LAM-based models, showing that they make meaningful predictions beyond chance, even when foreground (vessel) pixels are rare. | |||
*This supports that LAM is not just "overfitting" to majority classes but learns class-relevant structure effectively. | |||
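As a concrete illustration of how both metrics are obtained from a pixel-level confusion matrix (a generic sketch, not tied to the paper's evaluation code):
<pre>
import numpy as np

def miou_and_kappa(conf):
    """Mean IoU and Cohen's kappa from a (C, C) confusion matrix whose entry
    [i, j] counts pixels of true class i predicted as class j."""
    conf = conf.astype(float)
    total = conf.sum()
    tp = np.diag(conf)
    fp = conf.sum(axis=0) - tp            # predicted as class c, but actually another class
    fn = conf.sum(axis=1) - tp            # pixels of class c that were missed
    miou = np.mean(tp / (tp + fp + fn))
    p0 = tp.sum() / total                                          # observed agreement
    pe = (conf.sum(axis=0) * conf.sum(axis=1)).sum() / total ** 2  # chance agreement
    kappa = (p0 - pe) / (1 - pe)
    return miou, kappa

# Example with a 3-class confusion matrix.
cm = np.array([[50, 2, 3], [4, 40, 6], [1, 5, 60]])
print(miou_and_kappa(cm))
</pre>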
=== Results and Experimental Improvement === | |||
In benchmark experiments, linear attention typically boosts the performance of various baseline segmentation models while reducing memory and computational overhead. For example, when embedded into U-Net, PSPNet, or DeepLab, linear attention can achieve a higher mIoU and Kappa coefficient compared to the original dot product attention. These gains confirm that the approximation introduced by the Taylor expansion still captures sufficient global context for accurate segmentation. | |||
=== Comparative Analysis === | |||
Limitations: | |||
(1) Approximation Error: The first-order Taylor expansion introduces approximation errors relative to exact dot product attention. Although experiments show minimal performance degradation, in certain tasks or extreme input scales, further refinements or higher-order terms might be necessary. | |||
(2) Architecture Constraints: Integrating linear attention can require modifications to the existing network design, including normalization steps and careful parameter initialization. | |||
=== Conclusion and Future Work === | |||
Linear attention mechanisms substantially reduce both memory usage and computational cost in high-resolution vision tasks, making them a promising alternative to dot product attention. Future research may involve: | |||
Extending linear attention to multi-modal domains where extremely large inputs (e.g., video or multi-spectral data) are common. | |||
Investigating higher-order approximations that may yield even more accurate results while retaining near-linear scalability. | |||
Combining linear attention with other efficient modules (e.g., lightweight convolutions or quantization techniques) to further push the boundaries of real-time segmentation on resource-constrained devices. | |||
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 13 Presentation: Linear Attention Mechanism – An Efficient Attention for Semantic Segmentation == | |||
=== Presented By === | |||
Yuke Liu, Mei Si | |||
=== Paper Citation === | |||
R. Li, J. Su, C. Duan, and S. Zheng, ‘Linear Attention Mechanism: An Efficient Attention for Semantic Segmentation’, Aug. 20, 2020, arXiv: arXiv:2007.14902. https://arxiv.org/abs/2007.14902
--- | |||
=== What's the Problem? === | |||
Transformers are powerful but '''computationally expensive''', especially when working with large inputs like high-resolution images or long sequences. The root issue is their '''quadratic complexity''' in sequence length due to dot product attention. This makes scaling hard and slows things down.
This paper tackles that bottleneck by proposing a more efficient method: '''linear attention''', which keeps performance strong but dramatically cuts the cost.
--- | |||
=== Key Idea: Linear Attention Instead of Dot Product === | |||
Dot product attention works, but it’s expensive. This paper introduces a smarter way by: | |||
- Approximating attention using a '''first-order Taylor expansion''' of the softmax function.
- Rewriting the attention calculation to '''scale linearly''' with input size.
- Enabling '''O(N)''' performance instead of '''O(N²)'''.
The authors also suggest a '''hybrid attention system''' that can flexibly switch between local and global attention depending on context (like object density or resolution). For example, city scenes may need precise local attention; ocean images might not.
--- | |||
=== Technical Details (In a Nutshell) === | |||
- Instead of computing the full softmax in standard attention, the authors reformulate the attention matrix using linear operations. | |||
- They apply kernel-based tricks to restructure the attention, reducing compute without sacrificing performance. | |||
- The result? Faster attention that still captures essential relationships. | |||
Final formula (simplified):
<math>
D(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \frac{\sum_j \mathbf{v}_{i,j} + \left( \frac{\mathbf{Q}}{\lVert \mathbf{Q} \rVert_2} \right) \left( \left( \frac{\mathbf{K}}{\lVert \mathbf{K} \rVert_2} \right)^{T} \mathbf{V} \right)}{N + \left( \frac{\mathbf{Q}}{\lVert \mathbf{Q} \rVert_2} \right) \sum_j \left( \frac{\mathbf{K}}{\lVert \mathbf{K} \rVert_2} \right)_{i,j}^{T}}
</math>
Don't worry if it looks complicated; the takeaway is that it is faster and scales better.
--- | |||
=== Dataset & Experimental Setup === | |||
- Dataset: Fine Grained Image Dataset (GID) with large aerial satellite images. | |||
- The dataset was split 60/20/20 into train/val/test sets. | |||
- Patch size: 256x256, total of 7,280 patches. | |||
- Used standard segmentation models like U-Net, Res101, PSPNet, DeepLab, etc. | |||
- Training on NVIDIA RTX 2080Ti using PyTorch. | |||
--- | |||
=== Metrics Used === | |||
- '''OA (Overall Accuracy)'''
- '''AA (Average Accuracy)'''
- '''Kappa Coefficient''' (accounts for chance agreement)
- '''mIoU (mean Intersection over Union)''' – most important for segmentation
- '''F1 Score'''
--- | |||
=== How Well Does It Perform? === | |||
Here’s what the authors found: | |||
- '''Competitive Accuracy''': Linear attention performs on par with dot product attention in almost all models.
- '''Better Efficiency''': Requires less time and memory, especially useful for high-res inputs.
- '''Scalability''': Works better in larger models or longer sequences.
- '''Flexible Integration''': Easily replaces attention in many architectures with minimal tweaks.
--- | |||
=== Pros of Linear Attention === | |||
- Efficient: Great for long inputs or high-res images. | |||
- Accurate: Keeps precision high in segmentation tasks. | |||
- Flexible: Works across many architectures. | |||
- Practical: Easy to implement in PyTorch. | |||
--- | |||
=== Limitations & Trade-offs === | |||
- '''Approximation error''': Taylor expansions are not exact; there is some performance loss in rare cases.
- '''Architecture tweaks''': Some models may require re-tuning (normalization, initialization, etc.) to use linear attention smoothly.
--- | |||
=== Final Thoughts & What’s Next === | |||
This paper shows that linear attention can be a drop-in, compute-efficient alternative to traditional attention for semantic segmentation. It’s especially promising for real-time systems or edge devices where every GPU cycle counts. | |||
'''Future directions include:'''
- Extending to multi-modal inputs like video or multi-spectral images. | |||
- Testing higher-order Taylor approximations for even better accuracy. | |||
- Combining linear attention with other modules like lightweight convolutions or quantization. | |||
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 14 Presentation: Scalable Watermarking for Identifying Large Language Model Outputs == | |||
=== Presented by: === | |||
Ryan Tymkow and Benjamin Schnapp | |||
=== Paper Citation === | |||
Dathathri, S., See, A., Ghaisas, S. et al. Scalable watermarking for identifying large language model outputs. Nature 634, 818–823 (2024). https://doi.org/10.1038/s41586-024-08025-4. | |||
=== Summaries of key points === | |||
This paper tackles the problem of watermarking—embedding a detectable signal into the outputs of large language models (LLMs) to distinguish generated text from human-written content. The challenge is doing this in a way that is robust, scalable, and minimally intrusive to the model's output quality. The authors propose a method based on statistical biasing of token selection during generation. Specifically, they partition the vocabulary into “greenlist” and “redlist” tokens at each step based on a seeded hash function, and then bias sampling toward the greenlist using a tunable parameter. The watermark is invisible in individual outputs but detectable over longer texts using hypothesis testing. Importantly, the approach is model-agnostic, doesn’t require retraining, and adds very little computational overhead. It also scales well to large models and can be applied in high-throughput or real-time settings. Overall, it’s a lightweight yet effective strategy for watermarking that balances detectability, scalability, and usability. | |||
A key limitation of this approach is that it may still be vulnerable to paraphrasing or text transformation—simple rewriting could break the statistical signature. Another concern is adversarial robustness: since the watermarking method is relatively transparent (based on vocabulary partitioning), a knowledgeable attacker could design strategies to erase or spoof the signal. Additionally, while the method maintains fluency and quality in most cases, biasing token selection could subtly affect stylistic or semantic nuances, especially in creative writing or long-form tasks. The paper doesn’t deeply explore how this might influence user-facing applications like chatbots or summarizers. Lastly, while the watermark is statistically detectable, it’s not embedded in a cryptographic sense, so it may not offer strong guarantees in high-stakes verification contexts. | |||
=== Clear explanations to aid understanding === | |||
Imagine if every time a language model generated text, it subtly preferred certain words over others—but in a way that's invisible to readers. That’s what this watermark does. At each generation step, the model hashes its current context to choose a “greenlist” of preferred tokens and then slightly boosts their probabilities. Over many words, these choices form a statistically detectable pattern. It's like nudging a roulette wheel so certain numbers come up just a bit more often—not enough to be obvious, but enough to spot if you know where to look. The method is efficient and easy to integrate, since it works at the sampling level and doesn't require modifying the model architecture. | |||
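Below is a toy sketch of the greenlist-biasing idea as described in this summary (not the exact SynthID-Text procedure); the vocabulary size, bias strength, and hashing scheme are illustrative assumptions.
<pre>
import hashlib
import numpy as np

def greenlist_biased_sample(logits, context_ids, key="secret-key", gamma=0.5, delta=2.0):
    """Sample the next token with a context-seeded bias toward a 'greenlist'.
    logits: (V,) next-token scores; gamma: greenlist fraction; delta: logit bonus."""
    V = logits.shape[0]
    # Deterministic seed from the secret key and the recent context.
    digest = hashlib.sha256((key + "|" + ",".join(map(str, context_ids[-4:]))).encode()).digest()
    rng = np.random.default_rng(int.from_bytes(digest[:8], "big"))
    green = rng.permutation(V)[: int(gamma * V)]   # context-dependent greenlist
    biased = logits.astype(float).copy()
    biased[green] += delta                         # nudge sampling toward the greenlist
    probs = np.exp(biased - biased.max())
    probs /= probs.sum()
    return int(np.random.default_rng().choice(V, p=probs)), set(green.tolist())

# Example: a 10-token vocabulary with uniform logits.
token, greenlist = greenlist_biased_sample(np.zeros(10), context_ids=[3, 7, 1])
</pre>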
=== Review === | |||
I thought that the visual aids in this presentation greatly helped explain the process of how synthid alters the sampling process and how it embeds watermarks during the process, helping distinguish outputs created by the LLM vs outputs not created by the LLM. The graphic showing the three main parts of the process, the random seed generator, the sampling algorithm, and the scoring function, made it simpler to understand the whole process. The examples with the fruits, first generating random watermarking functions, then going through the tournament to sample the output token, also made it really easy to follow along on what exactly is going on. | |||
Regarding some suggestions, perhaps including multiple examples contrasting watermarked outputs with normal sampling outputs would help viewers better picture the elegance of the process. With regards to the limitations of this approach, encoding over a longer horizon (multiple words to sentences) might make it more robust, although, given the fundamental nature of written language, it could never be entirely foolproof.
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 14 Presentation: Scalable Watermarking for Identifying Large Language Model Outputs == | |||
=== Presented by: === | |||
Ryan Tymkow and Benjamin Schnapp | |||
=== Paper Citation === | |||
Dathathri, S., See, A., Ghaisas, S., et al. Scalable watermarking for identifying large language model outputs. Nature, 634, 818–823 (2024). https://doi.org/10.1038/s41586-024-08025-4 | |||
=== Key Summary === | |||
This paper addresses the challenge of watermarking outputs from large language models (LLMs)—embedding subtle, detectable signals to distinguish machine-generated text from human-written content. The goal is to design a method that is robust, scalable, and minimally intrusive to output quality. | |||
The authors propose a token biasing technique that works as follows: | |||
At each generation step, a seeded hash function is used to partition the vocabulary into a “greenlist” (preferred tokens) and a “redlist.” | |||
The model then slightly biases sampling towards the greenlist using a tunable parameter. | |||
The watermark is invisible in short outputs, but detectable across longer texts using statistical hypothesis testing. | |||
Key strengths: | |||
Model-agnostic: No retraining required. | |||
Low overhead: Minimal impact on inference speed or quality. | |||
Scalable: Suitable for large models and real-time systems. | |||
Practical: Balances watermark strength with natural text generation. | |||
=== Limitations and Concerns === | |||
Vulnerability to paraphrasing: Rewriting could disrupt the statistical signature. | |||
Adversarial robustness: Knowledgeable attackers could potentially remove or spoof the watermark. | |||
Stylistic influence: May subtly affect semantic or creative output, especially in long-form or artistic applications. | |||
Not cryptographically secure: Lacks strong guarantees for high-stakes verification or forensic scenarios. | |||
=== Clear Explanation for Intuition === | |||
Think of the watermark as a gentle push in the model’s word choice—like favoring certain dice rolls in a way that’s imperceptible up close but noticeable when you observe many outcomes. | |||
Each time the model generates a word, it computes a hash of the current context to decide which words get a slight boost. | |||
Over time, this creates a hidden statistical pattern, detectable if you know what to test for. | |||
It’s efficient, subtle, and works entirely during token sampling—no changes to the model itself. | |||
=== Presentation Review === | |||
The presentation effectively demystified the watermarking process. Visual aids were especially helpful: | |||
Diagrams showing the random seed generator, sampling mechanism, and scoring function clarified the full pipeline. | |||
The fruit-based examples for token selection and tournament-style sampling made the technical process intuitive and engaging. | |||
Suggestions for Improvement: | |||
Include side-by-side examples of outputs with and without watermarking to better illustrate the subtlety and strength of the method. | |||
Explore watermarking across longer horizons (e.g., multi-token or sentence-level biasing) to improve robustness—though this tradeoff remains an open challenge due to the fluid nature of language. | |||
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 14 Presentation: Scalable Watermarking for Identifying Large Language Model Outputs == | |||
=== Presented by: === | |||
Ryan Tymkow, Benjamin Schnapp | |||
=== Paper Citation === | |||
Dathathri, S., See, A., Ghaisas, S. et al. Scalable watermarking for identifying large language model outputs. Nature 634, 818–823 (2024). https://doi.org/10.1038/s41586-024-08025-4 | |||
=== Background === | |||
With the rise of LLMs and generative AI, there is a risk of spreading AI-generated misinformation as if it were written by a human, so it is important to distinguish AI-generated text from human writing. Solutions exist to address this problem, but they all have limitations involving privacy and computational cost. For example, traditional watermarking can introduce unwanted artifacts into the text.
=== Previous Approaches === | |||
Some of the existing approaches are Retrieval Based Tracking, Post-Hoc Detection, and Traditional Watermarking. | |||
Retrieval Based Tracking stores generated LLM responses in a reference database to determine whether a newly generated response | |||
is from the tracked LLM. The issue with this is scalability and privacy since you are storing generated responses. | |||
Post-Hoc Detection uses statistical features of LLM-generated text. The issues with this approach are its high computational cost and the fact that, as LLM output becomes more human-like, detection becomes increasingly difficult.
Traditional watermarking embeds hidden features in the text, for example by replacing synonyms or Unicode characters, but tinkering with the output can degrade the quality of the original LLM.
=== Technical Contributions === | |||
This paper introduced SynthID-Text, a watermarking method for large language models (LLMs). The method uses a Tournament sampling approach, which ensures that generated text contains a detectable watermark with minimal computational overhead. SynthID-Text incorporates a random seed generator and scoring functions to embed a watermark into the model's output. This technique enhances the ability to identify whether text originates from a specific LLM, while preserving the text's quality and minimizing distortion. SynthID-Text does not affect LLM training, and it can be configured as distortionary or non-distortionary.
Central to SynthID-Text is the novel Tournament sampling procedure. Rather than sampling each token directly from the LLM's distribution, multiple candidate tokens compete in a multi-layer "tournament" based on pseudorandom watermarking scores, embedding a statistical “signature” that can later be detected. | |||
=== Results === | |||
Synth ID achieved a 95% true positive detection rate with 1% false positives in a 20 million interaction test on Google's Gemini chatbot. It offers high detection accuracy with minimal computational cost and can be configured for non-distortionary or distortionary watermarking. | |||
The benefits are as follows: | |||
Minimal impact on large language model training: | |||
-Synth ID text can be applied to any large language model with minimal modifications, as it only alters the sampling step. | |||
High detection accuracy with low computational cost: | |||
-Outperforms retrieval-based tracking, post hoc detection, and traditional watermarking methods. | |||
-Offers the best balance between computational cost, scalability, and accuracy. | |||
-Can be integrated into production environments using speculative sampling, where smaller models suggest tokens and the main model verifies their validity. | |||
Configurable distortion levels: | |||
-Allows for distortionary or non-distortionary configurations, enabling better control over the quality of generated text versus detection accuracy. | |||
-In non-distortionary watermarking, the average token distribution of the generated text matches the original model's token distribution. | |||
Combining Watermarking with Speculative Sampling: | |||
-SynthID-Text supports integration with speculative sampling to accelerate LLM inference. | |||
-It introduces two configurations: high-detectability and fast watermarked speculative sampling. | |||
-High-detectability configuration preserves watermark strength but may reduce generation speed. | |||
-Fast configuration maintains speed and requires non-distortionary watermarking. | |||
-The fast version uses a learned Bayesian scoring function to improve detectability. | |||
-This enables efficient deployment of watermarking in real-world production systems. | |||
-The approach balances performance, speed, and watermark traceability. | |||
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 14 Presentation: Scalable watermarking for identifying large language model outputs == | |||
=== Presenters === | |||
Ben Schnapp, Ryan Tymkow | |||
=== Paper Citation === | |||
Dathathri, S., See, A., Ghaisas, S. et al. Scalable watermarking for identifying large language model outputs. Nature 634, 818–823 (2024). https://doi.org/10.1038/s41586-024-08025-4 | |||
=== Introduction === | |||
Researchers at Google DeepMind have introduced a new method for watermarking the text generated by large language models (LLMs). Their work, published in Nature, details "SynthID-Text," a watermarking scheme designed to be practical for use in production systems. | |||
=== Background === | |||
LLMs are now capable of generating high-quality text that can be difficult to distinguish from human-written content. This raises concerns about the potential misuse of LLMs and the need for methods to identify AI-generated text. Existing methods for detecting AI-generated text have limitations, including computational cost, privacy concerns, and inconsistent performance. Text watermarking offers a potential solution by embedding a signal within the generated text that can be used for identification. | |||
=== Main Idea === | |||
The authors developed SynthID-Text, a generative watermarking scheme that modifies the LLM's sampling procedure to embed a watermark in the generated text. This approach allows for efficient watermark detection without requiring access to the underlying LLM. SynthID-Text introduces a novel "Tournament sampling" algorithm to achieve this. | |||
=== Experiments === | |||
The researchers evaluated SynthID-Text across multiple LLMs and conducted a live experiment with nearly 20 million Gemini responses to assess its performance in a real-world setting. They compared SynthID-Text with existing watermarking methods, focusing on text quality, detectability, and computational efficiency. | |||
=== Results === | |||
The key findings of the paper are: | |||
- SynthID-Text preserves text quality, as confirmed by human evaluations and automated metrics. | |||
- SynthID-Text provides improved watermark detectability compared to other methods. | |||
- SynthID-Text has minimal computational overhead, making it suitable for large-scale production systems. | |||
- The authors also developed an algorithm to integrate watermarking with speculative sampling, a technique used to speed up LLM text generation.
In conclusion, the paper presents SynthID-Text as a practical and effective solution for watermarking LLM-generated text, addressing key challenges in the responsible use of this technology. | |||
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 14 Presentation: Scalable Watermarking for Identifying LLM Outputs ==
=== Presented by: === | |||
Ben Schnapp & Ryan Tymkow
=== Overview === | |||
With AI, specifically language models, becoming increasingly prolific, there are strong incentives to develop technologies capable of discerning when text has been fully or partly generated by a machine. At present, two methodologies dominate this space: post-hoc detection and watermarking. Post-hoc detection attempts to train a classifier to detect text output by an LLM. While sometimes successful, it is difficult to determine with absolute certainty whether the accusations made by post-hoc detectors are accurate; further, they often suffer when used on a language outside their training distribution or on text from an LLM that was not part of the training set. These limitations are some of the reasons that engineers, developers, and scientists are increasingly considering methods to watermark text at the time of generation, that is, to include some kind of artifact which readily identifies the text as LLM-generated without negatively impacting the quality of the outputs. The authors of this paper propose a unique method for this task, specifically by means of what they dub their "Tournament sampling" algorithm.
=== Governing Principles === | |||
LLMs typically take in a series of tokens as input, producing a probability distribution over the space of possible next tokens, which is sampled and appended to the input sequence before calling the forward pass again. The goal in watermarking is to imbue this sampling process with some signature unique to the generating model, such that it can be identified in a straightforward manner without the need for a second adversary model. Typically this is accomplished by combining a random seed generator, a sampling algorithm, and a scoring function. This works by combining the random seed with a known proprietary key, then using this seed/key combination to influence sampling. The result is not deterministic, but rather pseudo-random, so that the behaviour of the LLM remains useful. The scoring function is able to determine, at evaluation time, the probability that the tokens were drawn from the intentionally biased distribution. The key development in this paper is the introduction of many scoring functions, so that a "tournament" among scoring functions ultimately decides which token the model selects. This is because designing a good scoring function is inherently difficult, so the authors want to bias the model toward selecting the token which is "most" identifiable among the viable tokens from the pseudorandom set.
=== Implementation Details === | |||
The key implementation detail worth mentioning is how the scoring function works after tournament sampling. Specifically, tournament sampling ensures that the token most likely to score highly under the random watermarking functions is chosen. To detect watermarks, we measure how highly a piece of text scores with respect to all watermarking functions. Specifically, we compute:
<math> \text{Score}(x) = \frac{1}{mT} \sum_{t=1}^T \sum_{\ell=1}^m g_\ell (x_t,r_t) </math> | |||
where <math> x_t </math> is the <math> t </math>-th token of the sequence and <math> r_t </math> is the corresponding random seed combined with the key.
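A small sketch of how this score could be computed; the per-layer watermarking functions <math>g_\ell</math> are left as simple hash-based placeholders returning values in [0, 1], which is an illustrative assumption rather than the paper's actual functions.
<pre>
import hashlib

def g(layer, token, seed):
    """Placeholder watermarking function g_l(x_t, r_t) in [0, 1] (illustrative only)."""
    h = hashlib.sha256(f"{layer}|{token}|{seed}".encode()).digest()
    return int.from_bytes(h[:4], "big") / 2 ** 32

def watermark_score(tokens, seeds, m):
    """Average of g_l(x_t, r_t) over m layers and T tokens, matching Score(x) above."""
    T = len(tokens)
    return sum(g(l, x, r) for x, r in zip(tokens, seeds) for l in range(m)) / (m * T)

# Example: score a short token sequence against its per-position seeds.
print(watermark_score(tokens=[12, 7, 99], seeds=["s1", "s2", "s3"], m=4))
</pre>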
=== Discussion / Conclusions === | |||
The key insight from this paper is that tournament sampling does not negatively impact performance from the user's perspective. Because this was a DeepMind paper, the authors had access to production models and were able to run user evaluations of output quality, A/B testing with and without watermarking. Ultimately, they determined that users did not have a preference for non-watermarked text; that is, their watermarking method preserves the quality of model outputs.
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 14 Presentation: Scalable Watermarking For Identifying Large Language Model Outputs == | |||
=== Presented by === | |||
Ryan Tymkow and Benjamin Schnapp | |||
=== Paper Citation === | |||
Dathathri, S., See, A., Ghaisas, S. et al. Scalable watermarking for identifying large language model outputs. Nature 634, 818–823 (2024). https://doi.org/10.1038/s41586-024-08025-4 | |||
=== Background === | |||
- <b>Motivation:</b> With the increasing integration of AI in content creation, distinguishing AI-generated text from human-authored content has become crucial to prevent misinformation and ensure content authenticity. | |||
- <b>Existing Solutions:</b> Retrieval-based tracking involves storing all LLM outputs in a database for comparison. However, this method raises privacy concerns and faces scalability challenges. Post-hoc detection utilizes classifiers to detect AI-generated text after production but often suffers from false positives and high computational costs. Finally, traditional watermarking embeds hidden markers by altering the text, but this can degrade output quality. | |||
- SynthID-Text modifies the token selection process during text generation to embed watermarks while preserving the original text’s quality. Unlike traditional methods, it integrates watermarking seamlessly into the generation process without necessitating changes to the LLM’s training. | |||
- The standard LLM generation process starts off with tokenization, where the input text is divided into tokens. Then, the LLM predicts probabilities for the next token based on the input. Last, a token is chosen based on the predicted probabilities, and the process repeats. SynthID-Text embeds watermarks by altering the sampling step, ensuring the watermark is integrated without affecting the text’s coherence. | |||
=== Summaries of Key Points === | |||
- SynthID-Text has 3 key components: A random seed generator, a sampling algorithm, and a scoring function. The random seed generator generates a deterministic seed based on preceding tokens using a hash function, ensuring consistency in the watermarking process without altering the probability distribution. The sampling algorithm then modifies the token selection process to embed the watermark and incorporates a scoring function to prioritize tokens that align with the watermarking scheme. The scoring function evaluates how likely a token is to be a part of the watermarked sequence. It helps to facilitate the detection mechanism by assigning higher scores to watermarked tokens. | |||
- The tournament sampling step involves generating multiple candidate tokens and selecting the most suitable one based on the scoring function, which ensures the watermark is embedded without compromising the naturalness of the text.
- The detection mechanism is used to determine if a given text is AI generated by checking if the given text contains the embedded watermark by analyzing token sequences and their associated scores. Factors such as text length and the entropy of the LLM distribution can influence the detection accuracy. | |||
- A benefit of SynthID-Text is that it does not require modifications to the LLM’s training process. Additionally, it has high detection accuracy, so it can effectively identify watermarked text with minimal false positives. It is also easily configured and can adjust to be distortionary or non-distortionary based on requirements. Another benefit is that it integrates seamlessly into the text generation process without a significant amount of overhead. | |||
=== Explanation of Tournament Sampling Step === | |||
- Watermarking functions guide the selection of tokens during text generation. | |||
- Although the process of selecting tokens is called “random”, these functions are deterministic, ensuring consistency in watermark embedding. | |||
- The process involves evaluating multiple token candidates and selecting the one that best aligns with the watermarking criteria. | |||
- During the watermarking process with SynthID, the LLM generates probabilities for the next token based on preceding tokens. Then, a set of candidate tokens is sampled and evaluated through the tournament sampling process. A watermarking seed is generated using a hash function that’s defined by a random key and the preceding tokens. This ensures the watermark is deterministically embedded into the generated text while preserving its fluency and coherence. | |||
- The tournament sampling step is crucial because it balances the trade-off between preserving the original probability distribution of the model and embedding a detectable watermark, allowing high-quality text generation without compromising detection reliability. | |||
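The following is a toy illustration of the tournament idea described above, where candidate tokens drawn from the model's distribution compete pairwise on pseudorandom watermarking scores; the number of layers, the candidate count, and the scoring function are simplified assumptions, not the paper's exact algorithm.
<pre>
import hashlib
import random

def wm_score(layer, token, seed):
    """Pseudorandom per-layer watermarking score in [0, 1] (illustrative stand-in)."""
    h = hashlib.sha256(f"{layer}|{token}|{seed}".encode()).digest()
    return int.from_bytes(h[:4], "big") / 2 ** 32

def tournament_sample(candidates, seed, layers=3):
    """Pick a token via pairwise 'matches' across several watermarking layers.
    candidates: 2**layers tokens sampled from the LLM's next-token distribution,
    so the winner is still a plausible continuation."""
    assert len(candidates) == 2 ** layers
    pool = list(candidates)
    for layer in range(layers):
        winners = []
        for a, b in zip(pool[0::2], pool[1::2]):
            sa, sb = wm_score(layer, a, seed), wm_score(layer, b, seed)
            # The token with the higher watermark score advances; ties are broken randomly.
            winners.append(a if sa > sb else b if sb > sa else random.choice([a, b]))
        pool = winners
    return pool[0]

# Example: eight candidate tokens drawn from the model, one winner is emitted.
print(tournament_sample([5, 17, 3, 42, 8, 23, 11, 30], seed="context-hash"))
</pre>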
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 15 Presentation: DiGress: Discrete denoising diffusion for graph generation== | |||
=== Presented by: === | |||
Sean Tang, Buji Wong | |||
=== Paper Citation === | |||
Vignac, C., Krawczuk, I., Siraudin, A., Wang, B., Cevher, V., & Frossard, P. (2022). Digress: Discrete denoising diffusion for graph generation. arXiv preprint arXiv:2209.14734. | |||
=== Background === | |||
'''Graph generation''' | |||
The goal of this project is to generate graphs, which are represented by node matrices and edge matrices. Edges and nodes can also have their own categories. One application of this is molecule generation: atoms would be represented by nodes and the chemical bonds would be represented by edges. | |||
Graph generation is a challenging task due to the unordered nature and sparsity of graphs. While denoising diffusion models have been successful in other domains like images, they struggle with graphs because of these structural properties. Existing approaches that use continuous noise models disrupt the sparsity and connectivity crucial for graphs.
'''Discrete diffusion''' | |||
Regular diffusion methods involve a forward process (where noise is gradually added to the data) and a neural network that is trained to predict the backwards process (where random noise is gradually turned back into something that would plausibly fit in with the dataset). Most of the time, diffusion models use Gaussian noise. | |||
This graph generation problem involves ''discrete'' data, therefore, a ''discrete'' noise model is preferred for the diffusion process. (Note: The authors actually previously tried to use Gaussian noise, but the continuous nature of the Gaussian function meant that the discreteness of the graphs was not respected.) In this discrete case, however, the diffusion is defined in terms of ''transition matrices'', <math> Q^t </math>: | |||
<math> q(x_t | x_{t-1}) = x_{t-1} Q^t </math> | |||
Each transition matrix is applied to each node and edge; this allows the nodes and edges to remain in discrete space. | |||
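A small numerical sketch of one forward noising step for node categories, where each one-hot row is pushed through a transition matrix <math>Q^t</math> (the matrix values below are illustrative; edge types are handled with their own transition matrix in the same way):
<pre>
import numpy as np

rng = np.random.default_rng(0)

# Three node categories; a node keeps its category with probability 0.9
# and switches to either other category with probability 0.05 (illustrative values).
Q_t = np.array([[0.90, 0.05, 0.05],
                [0.05, 0.90, 0.05],
                [0.05, 0.05, 0.90]])

# One-hot node categories for a 4-node graph.
X = np.eye(3)[[0, 1, 2, 0]]                       # shape (4, 3)

# q(x_t | x_{t-1}) = x_{t-1} Q^t : each row becomes a categorical distribution.
probs = X @ Q_t                                   # shape (4, 3)

# Sample the noised categories and re-encode them one-hot, so the graph stays discrete.
X_t = np.eye(3)[[rng.choice(3, p=p) for p in probs]]
</pre>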
=== Technical Contributions === | |||
'''Overview of DiGress''' | |||
The authors introduce DiGress, a discrete denoising diffusion model designed specifically for graph generation with categorical node and edge attributes. DiGress improves graph generation by using a discrete noise model that preserves graph sparsity and structural properties. The model involves progressively editing graphs through edge addition/removal and attribute changes. A graph transformer network is used to reverse this noisy process using cross-entropy loss, sampling from the trained model by iteratively updating the noise level and computing structural features. | |||
Key enhancements include a noise model that maintains node and edge type distributions, a guidance procedure for conditioning on graph-level features, and the use of auxiliary graph-theoretic features. DiGress achieves state-of-the-art performance on both molecular and non-molecular graph datasets. | |||
'''Denoising network architecture''' | |||
* The input is a noisy graph, expressed as <math> X </math> (a tensor of one-hot encoded nodes) and <math> E </math> (a tensor of one-hot encoded edges). | |||
* The structural and spectral features of these graphs are calculated. | |||
* The noisy graph along with the structural and spectral features are taken as input into a multi-layer perceptron (MLP) structure. | |||
* Next is a sequence of graph transformer layers. | |||
* After another MLP layer, the output is a distribution over nodes and edges. | |||
'''Summary of results''' | |||
Results showed DiGress outperformed continuous diffusion methods on various metrics, including degree distribution, clustering, and novelty, and was more scalable for larger graphs. Moreover, for the creation of novel molecules, discrete diffusion aids scalability to larger graphs and molecules, making it more efficient than continuous diffusion. DiGress is the first one-shot graph-based model that feasibly trains on over a million molecules without fragment-based heuristics. Its performance on drug-like molecule benchmarks reaches or exceeds that of specialized autoregressive or SMILES-based baselines.
=== Accomplishments === | |||
* '''State-of-the-art one-shot graph generation''': | |||
** DiGress outperforms existing non-autoregressive models such as GDSS, GraphNVP, and SPECTRE. | |||
** Achieves strong performance on benchmarks including Stochastic Block Models (SBM) and planar graphs. | |||
* '''Scalable molecular graph generation''': | |||
** DiGress is the first diffusion-based model to scale to large molecular datasets such as: | |||
*** MOSES (small, drug-like molecules) | |||
*** GuacaMol (1.3M large molecules) | |||
** Matches or exceeds autoregressive and fragment-based baselines on metrics such as validity, uniqueness, novelty, and scaffold diversity. | |||
* '''Fast and efficient performance on QM9''': | |||
** Achieves near-perfect: | |||
*** Validity: 99% | |||
*** Uniqueness: ~96% | |||
** Trains significantly faster than prior diffusion models (1 hour vs. 7.2 hours for ConGress). | |||
* '''Effective conditional generation''': | |||
** Uses regressor-guided sampling to generate molecules with desired properties (e.g., HOMO energy, dipole moment). | |||
** Outperforms unconditional models in property control and accuracy. | |||
* '''Theoretical soundness''': | |||
** Proven permutation equivariance and exchangeability for correct graph generation. | |||
** Provides an ELBO-based training objective for likelihood estimation and model comparison. | |||
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 14 Presentation: Scalable Watermarking for Identifying Large Language Model Outputs == | |||
=== Presented by: === | |||
Ryan Tymkow and Benjamin Schnapp | |||
=== Paper Citation === | |||
Dathathri, S., See, A., Ghaisas, S. et al. (2024). Scalable watermarking for identifying large language model outputs. Nature, 634, 818–823. https://doi.org/10.1038/s41586-024-08025-4 | |||
=== Summaries === | |||
This paper addresses the need to reliably detect AI-generated text, particularly from large language models (LLMs), in the context of growing concerns over misinformation, authorship verification, and content authenticity. Existing approaches—such as retrieval-based tracking, post-hoc statistical detection, or traditional watermarking—face challenges related to scalability, privacy, and output degradation. | |||
The authors propose a lightweight and scalable watermarking method that subtly biases the LLM’s sampling process without modifying its architecture or requiring retraining. The method partitions the vocabulary at each generation step into two lists using a seeded hash function: | |||
Greenlist: tokens favored for selection | |||
Redlist: tokens slightly penalized | |||
By biasing the LLM to favor greenlist tokens, a statistically detectable pattern is formed over longer texts. This watermark is invisible in short sequences but detectable through hypothesis testing in larger samples. The approach is model-agnostic and adds minimal overhead, making it suitable for real-time or high-throughput deployment. | |||
=== Key Contributions === | |||
Token-Level Watermarking via Sampling Bias: | |||
Rather than embedding hidden characters or syntactic artifacts, the method adjusts token selection probabilities via a seeded hash function. | |||
Statistical Detection Over Long Outputs: | |||
The watermark is not intended to be seen in any one sentence, but accumulates over time and is verified through statistical analysis. | |||
Model-Agnostic and Training-Free: | |||
This method requires no retraining and can be applied to any autoregressive LLM. | |||
Low Computational Cost: | |||
Because it operates only at sampling time, the watermarking process adds negligible runtime cost. | |||
Configurable Distortion: | |||
The bias strength is adjustable, enabling trade-offs between output quality and detection strength. In non-distortionary mode, the output token distribution closely resembles the unwatermarked model. | |||
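Following up on the statistical-detection point above, here is a generic sketch of the kind of hypothesis test such a scheme enables (a one-sided z-test on the greenlist hit count; the specific statistic is an illustrative assumption, not taken from the paper):
<pre>
import math

def greenlist_z_score(green_hits, total_tokens, gamma=0.5):
    """One-sided z-test: under the null hypothesis (unwatermarked text), each token
    lands on the greenlist with probability gamma, so hits ~ Binomial(T, gamma)."""
    expected = gamma * total_tokens
    std = math.sqrt(total_tokens * gamma * (1 - gamma))
    return (green_hits - expected) / std

# Example: 140 of 200 tokens fell on the greenlist; z is roughly 5.7,
# strong evidence that the text carries the watermark.
print(greenlist_z_score(140, 200))
</pre>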
=== Constructive Critiques or Reviews === | |||
The watermark can be vulnerable to paraphrasing or text manipulation, which may erase or weaken the statistical signal. | |||
Since the watermarking logic (vocabulary partitioning via hashing) is relatively transparent, adversarial users might reverse-engineer or spoof the watermark. | |||
There is minor concern about stylistic or semantic drift, especially in long-form generation or creative tasks. | |||
The method lacks cryptographic strength, limiting its application in high-assurance or legal verification scenarios. | |||
=== Related Works === | |||
Retrieval-Based Tracking: | |||
Stores model outputs in a reference database. Issues: privacy concerns, poor scalability. | |||
Post-Hoc Detection: | |||
Classifies AI vs. human text using statistical features. Becomes less effective as LLMs improve and mimic human writing.
Traditional Watermarking: | |||
Inserts visible or invisible tokens or character-level perturbations (e.g., synonyms, Unicode tweaks). Problematic due to quality degradation and easy circumvention. | |||
In contrast, this paper’s statistical watermarking strikes a better balance among robustness, usability, and deployment feasibility—particularly suitable for integration into production systems like Google’s Gemini, which reported 95% true positive rate with only 1% false positives over 20 million interactions using this method. | |||
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 14 Presentation: Scalable Watermarking for Identifying Large Language Model Outputs == | |||
=== Presented by: === | |||
Ryan Tymkow and Benjamin Schnapp | |||
=== Paper Citation === | |||
Dathathri, S., See, A., Ghaisas, S. et al. (2024). Scalable watermarking for identifying large language model outputs. Nature, 634, 818–823. https://doi.org/10.1038/s41586-024-08025-4 | |||
=== Summaries === | |||
The paper presents SynthID-Text, a watermarking scheme designed to detect synthetic text generated by large language models (LLMs). The authors aim to address the challenge of distinguishing between human-written and AI-generated content, which has become increasingly difficult with the advancement of LLMs. | |||
=== Key Contributions === | |||
The authors introduce SynthID-Text, a watermarking method that integrates with the sampling process of LLMs to embed identifiable markers into generated text. This approach allows for the detection of AI-generated content without compromising text quality. The paper reports high detection accuracy of the watermarked text across multiple LLMs. Standard benchmarks and human evaluations indicate that SynthID-Text maintains the original capabilities and quality of the LLM outputs. | |||
=== Constructive Critiques or Reviews === | |||
The study primarily focuses on general text generation. Evaluating SynthID-Text's performance across various text genres and domains, such as technical writing or creative literature, would offer insights into its versatility and potential limitations. | |||
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 15 Presentation: DiGress: Discrete denoising diffusion for graph generation== | |||
=== Presented by: === | |||
Sean Tang, Buji Wong | |||
=== Paper Citation === | |||
Vignac, C., Krawczuk, I., Siraudin, A., Wang, B., Cevher, V., & Frossard, P. (2022). Digress: Discrete denoising diffusion for graph generation. arXiv preprint arXiv:2209.14734. | |||
=== Summaries of key points === | |||
==== Motivation ==== | |||
Recall that, as covered in class, diffusion models work by gradually adding noise to data (forward process) and training a model to reverse that corruption, restoring the clean data (reverse process). Traditional diffusion models are designed for continuous data, but graphs are generally discrete. This discrete structure makes it difficult to apply traditional diffusion models to graphs directly.
DiGress tackles this issue by introducing a method that applies diffusion to discrete graphs through the concept of discrete diffusion. | |||
==== Forward Diffusion Process: How to Add Discrete Noise? ==== | |||
As mentioned above, DiGress uses discrete noise. In DiGress, noise is added by changing the categories of nodes and edges (such as atom types or bond types). This means the model randomly selects new categories for nodes and edges from a predefined set of valid categories. To guide this noise addition, DiGress uses a transition matrix. The transition matrix defines the probability of each node and edge transitioning between categories at each step. This ensures that while the graph becomes noisier over time, it still maintains its structure and validity. | |||
To improve the quality of the noisy graphs during diffusion, DiGress introduces a Markovian noise model that preserves the marginal distributions of node and edge types observed in the training data. Rather than using uniform transitions, which can create unrealistic dense graphs, this model defines transition matrices where the probability of changing to a new category is proportional to its frequency in the dataset. This approach helps maintain realistic sparsity and structure throughout the diffusion steps, making the reverse denoising process easier and more effective. | |||
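A minimal sketch of this marginal-based noising step is given below. The alpha_t schedule and the exact matrix form are assumptions consistent with the description above rather than the paper's code.

<pre>
import numpy as np

def transition_matrix(alpha_t, marginals):
    # Q_t = alpha_t * I + (1 - alpha_t) * 1 m^T: keep the current category with
    # probability alpha_t, otherwise resample it from the training-set marginals.
    K = len(marginals)
    return alpha_t * np.eye(K) + (1.0 - alpha_t) * np.outer(np.ones(K), marginals)

def noise_one_step(categories, alpha_t, marginals, rng=None):
    # Apply one forward step to an array of integer node (or edge) categories.
    rng = rng or np.random.default_rng(0)
    Q = transition_matrix(alpha_t, marginals)
    return np.array([rng.choice(len(marginals), p=Q[c]) for c in categories])
</pre>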
==== Reverse Diffusion Process: How to Denoise a Graph? ==== | |||
In the reverse process, DiGress gradually removes the noise from the graph. Instead of randomly undoing the noise, the model uses a graph transformer network, designed for graph-structured data. This network helps the model recognize the structure of the graph and the relationships between nodes and edges. During each step, the model focuses on the most relevant parts of the graph, predicting the correct categories for nodes and edges. And the model’s predictions are guided by cross-entropy loss (applied to both nodes and edges), which measures how accurately the model predicts the node and edge categories after denoising. By minimizing this loss, the model becomes better at removing the noise, step by step, until it recovers a valid and meaningful graph. | |||
==== Conditioning Graph Generation on Desired Properties ==== | |||
One of the most powerful features of DiGress is its ability to condition the graph generation process on specific properties. For example, if you want the generated graph to have a certain number of oxygen atoms or satisfy some chemical property, DiGress can adjust the generation process to meet those requirements. This is done by checking the graph at each step of the sampling process and modifying it as needed to match the desired properties. This capability is particularly useful in areas like drug discovery, where the generated molecules must meet certain chemical and structural criteria. | |||
==== Experimental Results ==== | |||
DiGress was tested on several datasets to evaluate its performance: | |||
1. Graph Generation: On datasets like the Stochastic Block Model (SBM) and planar graphs, DiGress outperformed other methods, particularly in generating novel graphs that were not seen during training. | |||
2. Molecule Generation: When applied to the MOSES dataset, DiGress produced more valid and novel molecules, even though it did not always surpass methods that check graph validity at every step.
3. Scalability: On larger graphs, such as those from the GuacaMol dataset, DiGress demonstrated strong scalability, making it a suitable option for generating larger and more complex graphs.
=== Comparison with Existing Approaches === | |||
- Versus Autoregressive Models: These models (like GraphAF, GraphRNN) generate graphs node-by-node or edge-by-edge and often rely on ordering, making them slower and harder to parallelize. | |||
- Versus Continuous Diffusion Models for Graphs: E.g., GDSS uses Gaussian noise and struggles with categorical data. DiGress handles discrete data directly, making it more suitable for molecular and symbolic domains. | |||
- DiGress Advantage: Fully parallel sampling, no need to learn generation order, works better for discrete structured data. | |||
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 15 Presentation: DIGRESS: DISCRETE DENOISING DIFFUSION FOR GRAPH GENERATION == | |||
=== Motivation === | |||
If one is familiar with graph theory, you will not need convincing that graphs are among the most useful mathematical objects. Consisting of vertices (nodes) and edges, graphs can represent a wide variety of relational data types. Examples include:
Molecules, where atoms are nodes and bonds are edges | |||
Social networks, where people are nodes and connections are edges | |||
Traffic systems, where intersections are nodes and roads are edges
and so on.
As such, one could imagine why being able to generate such data types would be useful for research purposes. One could imagine training a generative model to generate all kinds of plausible graphs, for purposes such as material science or drug discovery. | |||
Traditional generative models struggle with graph data because of its non-Euclidean, discrete structure, so there is a clear use case for a model designed specifically for graphs.
=== Operating principle === | |||
The key idea behind DiGress is to adapt denoising diffusion probabilistic models (DDPMs) to the graph domain. For the forward process, the DiGress authors define a Markov chain of noisy graphs, with each step gradually adding noise to the node features and the graph structure. At each step <math> t </math> the graph <math> G_t </math> is a noised version of the graph <math> G_{t-1} </math>; at time <math> T </math>, the graph <math> G_T </math> can be thought of as a real graph turned into pure noise. The goal of DiGress is to learn how to reverse this noising process, i.e., to generate a feasible graph (relative to the training data) from a noisy input. Generation begins by sampling noise from a given distribution, after which the denoising network predicts the noise that was added at that step. This is repeated, sometimes with a small amount of fresh noise added, until the sample is no longer noisy and a plausible graph remains.
=== Technical Details === | |||
Similar to image diffusion models, which apply noise to each pixel independently, DiGress applies noise to each node and edge independently. As a result, the state space is that of individual node and edge categories rather than all possible graphs, which keeps the transition matrices from exploding in size. The denoising network is trained to minimize a cross-entropy loss between predicted and true node/edge distributions according to:
<math> l(\hat{p}^G, G) = \sum_{1 \leq i \leq n} \text{CE}(x_i, \hat{p}_X^i) + \lambda \sum_{1 \leq i,j \leq n} \text{CE}(e_{ij}, \hat{p}_E^{ij}) </math> | |||
The model learns a reverse diffusion distribution <math> p_\theta(G^{t-1} | G^t) </math> for sampling clean graphs | |||
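A PyTorch-style sketch of this node/edge cross-entropy objective is shown below; the tensor shapes and the value of <math>\lambda</math> are illustrative assumptions consistent with the formula above.

<pre>
import torch
import torch.nn.functional as F

def digress_loss(node_logits, node_targets, edge_logits, edge_targets, lam=5.0):
    # node_logits: (n, K_x) with node_targets: (n,) class ids
    # edge_logits: (n, n, K_e) with edge_targets: (n, n) class ids
    node_term = F.cross_entropy(node_logits, node_targets)
    edge_term = F.cross_entropy(edge_logits.reshape(-1, edge_logits.size(-1)),
                                edge_targets.reshape(-1))
    return node_term + lam * edge_term
</pre>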
=== Results / Conclusion === | |||
DiGress has thus far proven to be state of the art for both molecular and non-molecular datasets. In particular, it achieved state-of-the-art performance on the GuacaMol dataset, which contains over 1.3 million drug-like compounds.
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 15 Presentation: DiGress: Discrete denoising diffusion for graph generation== | |||
=== Presented by: === | |||
Sean Tang, Buji Wong | |||
=== Paper Citation === | |||
Vignac, C., Krawczuk, I., Siraudin, A., Wang, B., Cevher, V., & Frossard, P. (2022). Digress: Discrete denoising diffusion for graph generation. arXiv preprint arXiv:2209.14734. | |||
=== Summaries === | |||
The paper proposes a Discrete Graph Denoising Diffusion model (DiGress) for graph generation. It overcomes a limitation of existing graph generation models that operate in continuous space and thereby destroy the graph's sparsity and connectivity. DiGress handles graphs with categorical node and edge attributes and has two steps: a diffusion process that adds noise to the original graph, followed by a transformer network that inverts the noise. The model satisfies the efficiency properties listed below and tackles the challenge that graphs are insensitive to node ordering. The paper also discusses the choice of the noise model, whose distribution should be close to the true data distribution. Finally, DiGress can generate graphs conditioned on graph-level properties, which is beneficial for real-life applications.
=== Key Contributions === | |||
'''Diffusion process''': The diffusion process has two steps: adding noise to the original graph by the noise model, then invert the noises by the denoising neural network. The diffusion is done separately for each node and edge attribute. | |||
'''Noise model''': The noise model creates a sequence of noisy data points that has a Markovian structure with the transition probability <math>q(z^1,...z^T | x)=q(z^1 | x) \prod_{t=2}^T {q(z^t | z^{t-1})}.</math> | |||
'''Denoising neural network''': This process predicts <math>z^{t-1}</math> from <math>z^t</math>. It is trained by minimizing the cross-entropy loss over each node and edge attribute: <math>l(\hat{p}^G, G) = \sum_{1 \le i \le n} \text{CE}(x_i, \hat{p}_i^X) + \lambda \sum_{1 \le i, j \le n} \text{CE}(e_{ij}, \hat{p}_{ij}^E)</math>, where <math>X</math> is the matrix of one-hot node encodings, <math>E</math> is the tensor of one-hot edge encodings, <math>\hat{p}</math> is the predicted probability, and <math>i, j</math> index nodes. The parameter <math>\lambda</math> controls the relative importance of nodes versus edges. The distribution of the next (less noisy) graph is then computed, and the next graph is sampled from this distribution and used as input for the following step.
'''Properties of Efficiency''': These properties are satisfied by existing continuous graph generation methods and should also be met by discrete graph generation methods.
* The distribution of the noisy data has a closed-form formula. This achieves efficiency because the noises are not added recursively. | |||
* The posterior distribution <math>q(z_{t-1} | z_t, x)</math> has a closed-form formula. With this condition, the original data points can be used as the target of the denoising network. | |||
* The limit distribution of the noise model <math>q_{\infty} = \lim_{T \rightarrow \infty} {q(z^T | x)}</math> does not depend on the original data points. With this condition, this distribution can be used as the prior distribution to be more efficient. | |||
=== Related Works === | |||
Existing works have used Gaussian noise in a continuous setting to perturb the node features and the graph adjacency matrix (Niu et al., 2020; Jo et al., 2022). However, this approach does not preserve the most important characteristics of graphs: sparsity, insensitivity to node ordering, and other structural properties.
Other existing works that explored discrete diffusion models are applied to images, text, and audio, and have not addressed the unique challenges of generating graphs.
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 15 Presentation: DiGress: Discrete Denoising Diffusion for Graph Generation == | |||
=== Presented by: === | |||
Sean Tang, Buji Wong | |||
=== Paper Citation === | |||
Vignac, C., Krawczuk, I., Siraudin, A., Wang, B., Cevher, V., & Frossard, P. (2022). Digress: Discrete denoising diffusion for graph generation. arXiv preprint arXiv:2209.14734. | |||
=== Introduction === | |||
Researchers have introduced a new method for generating graphs with categorical node and edge attributes called DiGress. This work, published in the proceedings of the ICLR 2023 conference, presents a discrete denoising diffusion model designed for graph generation. | |||
=== Background === | |||
Diffusion models have shown remarkable success in various domains, particularly in image and video generation. The ability of diffusion models to outperform other generative methods in these areas has motivated researchers to explore their potential for graph generation. However, generating graphs presents unique challenges due to their unordered nature and sparsity. Previous approaches that applied continuous diffusion models with Gaussian noise to graphs have struggled to preserve the inherent structural properties of graphs, such as sparsity and connectivity. | |||
=== Main Idea === | |||
DiGress addresses these challenges by employing a discrete denoising diffusion model. The model progressively edits graphs through a Markov process that involves adding or removing edges and changing node or edge categories. A graph transformer network is then trained to reverse this process, effectively learning to denoise the graph. This approach simplifies the complex task of graph distribution learning into a series of node and edge classification tasks. | |||
=== Experiments === | |||
The authors evaluated DiGress on both molecular and non-molecular datasets, comparing its performance against state-of-the-art graph generation methods. Their experiments included: | |||
* Unconditional generation on stochastic block model (SBM) and planar graph datasets. | |||
* Conditional generation on the QM9 dataset to assess the model's ability to generate graphs with specific properties. | |||
* Large-scale molecule generation on the MOSES and GuacaMol datasets. | |||
=== Results === | |||
The results of the experiments demonstrate that DiGress achieves state-of-the-art performance in graph generation. Key highlights include: | |||
* DiGress exhibits strong performance on both molecular and non-molecular datasets. | |||
* The model shows significant improvements in validity, particularly on the planar graph dataset.
* DiGress is the first model to successfully scale to the large GuacaMol dataset, achieving performance comparable to autoregressive models. | |||
In summary, DiGress introduces a novel and effective approach to graph generation by leveraging discrete denoising diffusion models. The model's ability to handle discrete graph structures and scale to large datasets represents a significant advancement in the field of graph generation. | |||
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 16 Presentation: Machine Learning and Hamilton-Jacobi-Bellman Equation for Optimal Decumulation: a Comparison Study == | |||
=== Presented by: === | |||
Zeyu Zhang | |||
=== Paper Citation === | |||
Chen M, Shirazi M, Forsyth PA, Li Y. Machine Learning and Hamilton-Jacobi-Bellman Equation for Optimal Decumulation: a Comparison Study. Published online 2023. doi:10.48550/arxiv.2306.10582 | |||
=== Background === | |||
The paper is based on computational finance, focusing on the optimization problem related to "defined benefit" and "defined contribution plans". | |||
The main focus is on the challenge of ensuring retirees have enough funds for their retirement. | |||
Two key plans were discussed: | |||
"Defined benefit plans" guarantee fixed monthly payments based on factors like tenure and salary but are cost-prohibitive and risky. | |||
"Contribution plans" shift the investment and withdrawal strategy burden to individual investors, but they struggle to balance maximizing withdrawals and minimizing risk. | |||
This problem, often called the "nastiest, hardest problem in finance," highlights the complexity of balancing risk and reward in financial planning for retirement.
The 4% rule is a traditional method recommending a constant 4% withdrawal each year, adjusted for inflation, and investing in stocks and bonds. | |||
Despite its popularity, the 4% rule is suboptimal and, in particular, not globally optimal.
Peter Forsyth proposed the HJB PDE method to maximize expected withdrawals and minimize the risk of running out of savings.
The HJB PDE method uses scalarization techniques to achieve Pareto optimal points, but it has limitations. | |||
=== Technical Contributions === | |||
1. Hamilton-Jacobi-Bellman (HJB): | |||
- The problem formulation involves complex mathematical equations related to computational finance. | |||
- The problem uses dynamic programming to break down the optimal control problem, leading to the HJB function that represents the value function. | |||
- The paper assumes stock and bond prices follow a jump diffusion model. | |||
- The investor's total wealth at time <math> t </math> is defined as the sum of the amounts held in stocks and bonds at that time.
- The horizon <math> T </math> is set to 30 years, and rebalancing times are defined with discrete withdrawal amounts and allocations for stocks and bonds.
2. Neural Network (NN): | |||
Control and Objective Function: | |||
- The control at time <math>T_i </math> includes the withdrawal amount <math> Q_i </math> and allocation for the wealth at time <math> T_i^- </math>. | |||
- The admissible control set is defined, and the expected shortfall is introduced as a measure of risk. | |||
- The expected total withdrawal is used as a measure of reward, aiming to maximize the expected total withdrawal while minimizing the expected shortfall. | |||
- The pre-commitment in the expected shortfall problem is defined, focusing on maximizing the expected total withdrawal and minimizing the expected shortfall. | |||
=== Neural Network (NN) Formulation === | |||
As an alternative to the HJB framework, the authors propose a Neural Network (NN) approach to solve the stochastic control problem. This framework has several advantages: | |||
1. The NN approach is data-driven, meaning it avoids explicitly defining parametric models for stochastic processes. This provides flexibility and allows integration of auxiliary variables if needed. | |||
2. It circumvents the computation of high-dimensional conditional expectations by solving a single, unconstrained optimization problem for control decisions. This avoids the curse of dimensionality often associated with dynamic programming. | |||
3. If the optimal control is continuous in time and state, the NN reflects this property. If the control is discontinuous, the NN yields a smooth approximation, which is beneficial in practice for implementing investment policies. | |||
4. The method is scalable, making it suitable for long horizons and high-frequency rebalancing without significantly increasing computational complexity. | |||
The NN framework’s effectiveness lies in its ability to learn control policies directly from simulated state paths without requiring explicit knowledge of the underlying stochastic differential equations. | |||
Instead of solving high-dimensional HJB PDEs, the neural network uses: | |||
* Forward simulation to sample wealth trajectories under candidate policies. | |||
* Backward evaluation to update network parameters based on performance (e.g., maximizing expected withdrawals, minimizing expected shortfall). | |||
This model-free, data-driven method avoids dynamic programming and is especially useful in high dimensions, where solving PDEs becomes computationally infeasible. | |||
Moreover, by designing appropriate activation functions (e.g., softmax for portfolio weights and sigmoid for withdrawal rates), the NN ensures that stochastic constraints are naturally respected throughout training and inference. | |||
==== NN Approximation Setup ==== | |||
* The control policy <math>\mathcal{P}</math> is approximated using two feed-forward neural networks, with parameters <math>\boldsymbol{\theta}_q</math> and <math>\boldsymbol{\theta}_p</math>, representing withdrawal and allocation strategies respectively. | |||
* These networks take as inputs the Brownian motion path <math>W(t_i)</math> and time <math>t_i</math> to approximate control decisions: | |||
<math> | |||
\hat{q}(W_i^-, t_i^-, \boldsymbol{\theta}_q) \approx q_i(W_i^-), \quad | |||
\hat{p}(W_i^+, t_i^+, \boldsymbol{\theta}_p) \approx p_i(W_i^+) | |||
</math> | |||
- The final control policy is given by: | |||
<math>\hat{\mathcal{P}} = \{ (\hat{q}(\cdot), \hat{p}(\cdot)) \} \approx \mathcal{P}</math> | |||
* The functions <math>\hat{p}</math> and <math>\hat{q}</math> use time as one of the inputs, allowing a single NN to handle decisions across all rebalancing points, rather than training separate models for each time step. | |||
* The paper also discusses how the architecture includes activation functions that enforce stochastic constraints naturally. | |||
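A hedged sketch of such a policy network follows. The layer widths, the withdrawal bounds, and feeding (wealth, time) directly as inputs are illustrative assumptions, not the authors' exact architecture; the point is that the sigmoid and softmax output heads enforce the withdrawal and no-shorting/no-leverage constraints by construction.

<pre>
import torch
import torch.nn as nn

class DecumulationPolicy(nn.Module):
    # Maps (wealth, time) to a withdrawal in [q_min, q_max] and to long-only
    # portfolio weights summing to one, so the constraints hold by construction.
    def __init__(self, n_assets=2, q_min=35.0, q_max=60.0, hidden=32):
        super().__init__()
        self.q_min, self.q_max = q_min, q_max
        self.trunk = nn.Sequential(nn.Linear(2, hidden), nn.Tanh(),
                                   nn.Linear(hidden, hidden), nn.Tanh())
        self.withdraw_head = nn.Linear(hidden, 1)
        self.alloc_head = nn.Linear(hidden, n_assets)

    def forward(self, wealth, t):
        h = self.trunk(torch.stack([wealth, t], dim=-1))
        q = self.q_min + (self.q_max - self.q_min) * torch.sigmoid(self.withdraw_head(h)).squeeze(-1)
        p = torch.softmax(self.alloc_head(h), dim=-1)  # no shorting, no leverage
        return q, p

# One policy network is shared across all rebalancing dates: time is just another input.
policy = DecumulationPolicy()
q, p = policy(torch.tensor([1000.0]), torch.tensor([0.5]))
</pre>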
=== Summaries of Key Notes === | |||
- Neural Network Framework for Pension Decumulation: A novel framework using neural networks (NNs) is proposed to optimize asset allocation and cash withdrawal strategies for defined contribution (DC) pension plans. Unlike traditional methods, it solves constraints efficiently via unconstrained optimization. | |||
- Comparison with HJB Method: The NN approach achieves comparable accuracy to the Hamilton-Jacobi-Bellman (HJB) PDE method while being scalable to higher-dimensional problems and avoiding dynamic programming errors. | |||
- Efficient Withdrawals: The NN framework closely approximates optimal "bang-bang" controls, effectively alternating between minimum and maximum withdrawals based on wealth, ensuring reliable pension decumulation. | |||
- Robustness: Tested extensively on synthetic and historical market data, the NN solution adapts well and demonstrates strong out-of-sample and out-of-distribution performance. | |||
=== Constructive Critique === | |||
While the NN method replicates HJB-derived control policies with high accuracy, a few limitations and caveats exist: | |||
* The training relies on synthetic data simulated from assumed models (e.g., geometric Brownian motion, jump diffusion), which may limit generalizability under real-world dynamics. | |||
* The scalarization of the reward-risk tradeoff assumes a linear weighting of Expected Withdrawals and Expected Shortfall. In practice, retiree preferences might reflect more complex, utility-based behaviors. | |||
* The interpretability of the learned policy is less transparent compared to explicit, closed-form control strategies derived via HJB. | |||
* No formal convergence analysis or approximation bounds are provided for the NN solution. | |||
Despite these challenges, the method is empirically robust and scalable, making it an appealing alternative for large-scale or real-time applications. | |||
=== Related Works === | |||
The work is also related to "Spending Retirement on Planet Vulcan: The Impact of Longevity Risk Aversion on Optimal Withdrawal Rates", which introduces utility-based models for withdrawal strategies under longevity risk aversion. The models focus on retirees' behavioral preferences — particularly longevity risk aversion and the need for smooth consumption throughout retirement.<br> | |||
While not PDE-based, it contextualizes the trade-offs between consumption smoothing and risk of depletion, complementing the (EW, ES) approach by addressing behavioral and utility-driven objectives. <br> | |||
On the other hand, the benchmark NN paper takes a risk-sensitive stochastic control approach, optimizing a scalarized objective that balances:
Expected Withdrawals (EW) = proxy for consumption <br> | |||
Expected Shortfall (ES) = proxy for downside risk / depletion probability | |||
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 16 Presentation: Machine Learning and Hamilton-Jacobi-Bellman Equation for Optimal Decumulation: a Comparison Study == | |||
=== Presented by: === | |||
Zeyu Zhang | |||
=== Paper Citation === | |||
Chen M, Shirazi M, Forsyth PA, Li Y. Machine Learning and Hamilton-Jacobi-Bellman Equation for Optimal Decumulation: a Comparison Study. Published online 2023. doi:10.48550/arxiv.2306.10582 | |||
=== Background & Motivation=== | |||
The paper focuses on addressing a stochastic optimal control problem in retirement decumulation and asset allocation, which is a critical issue in financial planning. Specifically, it investigates how retirees can optimally withdraw funds from their savings (decumulation) while simultaneously managing asset allocation under uncertain market conditions. | |||
Traditionally, rules of thumb such as the Bengen 4% Rule have been widely adopted in the financial industry to guide withdrawal strategies. However, these approaches are increasingly viewed as suboptimal, particularly in an environment characterized by volatile markets and evolving mortality patterns. Recent academic studies, such as Forsyth (2022), propose partial differential equation (PDE)-based methods that are provably convergent and optimal under specific assumptions. Nevertheless, PDE methods face significant limitations in scalability due to the curse of dimensionality, often performing well only in low-dimensional settings. | |||
The motivation for this paper is to overcome the limitations of PDE-based approaches by leveraging neural networks (NNs) to solve the decumulation and asset allocation control problem. The authors aim to evaluate whether deep learning can accurately and robustly approximate the solution to this high-dimensional stochastic control problem, as well as whether it provides computational advantages. | |||
=== Key Points=== | |||
1. '''Problem Formulation''': The paper formulates the decumulation problem as a stochastic optimal control problem, aiming to optimize a weighted sum of expected withdrawals (EW) and expected shortfall (ES) to effectively manage tail risk. Key constraints include minimum and maximum withdrawal limits, as well as no-shorting and no-leverage rules. | |||
2. '''HJB Framework''': The Hamilton-Jacobi-Bellman (HJB) approach employs dynamic programming to solve the problem numerically, providing a ground-truth benchmark for comparison. However, this method is computationally limited to low-dimensional problems and relies on parametric models for asset returns, which may not capture real-world complexities. | |||
3. '''NN Framework''': The proposed neural network (NN) framework directly approximates the control functions (withdrawal and allocation) using feed-forward networks with customized activation functions designed to enforce the specified constraints. This data-driven approach bypasses the need for dynamic programming and demonstrates scalability to higher-dimensional problems. | |||
4. '''Comparative Results''': On synthetic data, the NN solution achieves performance nearly identical to that of the HJB method, showcasing its high accuracy in approximating the optimal control policy, including complex "bang-bang" withdrawal strategies. | |||
5. '''Robustness''': The NN framework exhibits strong performance in out-of-sample and out-of-distribution tests, such as bootstrap-resampled historical data, thereby demonstrating its generalizability beyond the training distribution. | |||
=== Contributions=== | |||
*Demonstration of neural networks as reliable solvers for constrained stochastic control problems, which were previously addressable only through partial differential equations (PDEs). | |||
*Quantitative benchmark comparisons between NN-based and PDE-based methods reveal near-equivalent accuracy, particularly in replicating key features such as the efficient frontier and optimal withdrawal behavior. | |||
*The proposed approach is scalable to higher dimensions, unlike PDE-based methods, making it potentially transformative for real-world retirement planning problems that involve multiple assets or stochastic factors. | |||
*The authors demonstrate that regularization within the NN framework helps mitigate instability in regions of the state space where the control problem becomes ill-posed (e.g., high wealth levels or near terminal time). | |||
*The method provides continuous-time control outputs by explicitly incorporating time as an input to the network, ensuring smooth solutions when required. | |||
A key innovation in the neural network (NN) formulation is the use of customized activation functions to enforce feasibility of withdrawal and allocation controls. Instead of training under constrained optimization, the authors design activation functions that inherently respect control bounds (e.g., no shorting, no leverage, and withdrawal limits). The withdrawal control uses a modified sigmoid function that maps outputs to the valid withdrawal range, which depends on wealth. The allocation control uses a softmax activation to ensure portfolio weights are non-negative and sum to one. This allows training via standard unconstrained optimization, significantly simplifying the optimization process while ensuring all control outputs remain feasible throughout the training and inference phases. | |||
=== Constructive Critiques=== | |||
* '''Ill-Posed Regions''': The NN and HJB solutions diverge in high-wealth regions near the terminal time due to the problem's ill-posedness. While the authors argue this has negligible impact on objectives, further analysis of how this affects real-world implementation would strengthen the paper. | |||
* '''Training Complexity''': The NN requires transfer learning for high κ values (weighting ES more heavily), suggesting potential instability in risk-averse scenarios. A deeper exploration of training challenges and solutions would be valuable. | |||
* '''Historical Data Limitations''': The bootstrap resampling tests rely on U.S. market data (1926–2019). Including non-U.S. data or stress-testing during extreme market conditions (e.g., hyperinflation) could enhance robustness claims. | |||
* '''Computational Costs''': While the NN avoids dynamic programming, the computational expense of training large networks is not quantified. A comparison of runtime between HJB and NN methods would clarify trade-offs. | |||
=== Relationships to Other Works=== | |||
This work builds on the stochastic control literature, particularly the decumulation problem studied in Forsyth (2022), which employs PDE-based methods. The current paper extends this research by providing a data-driven and high-dimensional alternative. It conceptually aligns with deep FBSDE methods, Deep Galerkin methods used for solving HJB equations, as well as reinforcement learning (RL)-based approaches to optimal control, such as the Deep Deterministic Policy Gradient. Compared to prior studies, such as Han and E (2016), Buehler et al. (2019), and Laurière et al. (2021), the current paper places emphasis on benchmarking against a well-established numerical method (PDEs), an aspect often overlooked in other NN-based control studies. The proposed method falls within the Policy Function Approximation (PFA) framework outlined in Powell (2021), providing a robust example of using fixed neural networks to approximate control functions across time and state dimensions.
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 16 Presentation: Machine Learning and HJB Equation for Optimal Decumulation – A Comparison Study == | |||
=== Presented by: === | |||
Zeyu Zhang | |||
=== Paper Citation === | |||
Chen M., Shirazi M., Forsyth P.A., Li Y. *Machine Learning and Hamilton-Jacobi-Bellman Equation for Optimal Decumulation: A Comparison Study*. arXiv, 2023. doi: [10.48550/arXiv.2306.10582](https://doi.org/10.48550/arXiv.2306.10582) | |||
=== What’s the Study About? === | |||
Retirement planning isn’t just about saving—it’s also about *withdrawing* wisely. This paper looks at how retirees can manage their savings smartly over time while investing in both stocks and bonds. The focus? Finding the **best way to withdraw** without running out of money or taking on too much risk. | |||
It compares two methods: | |||
1. **Traditional HJB PDE-based optimization** (fancy math from control theory). | |||
2. **Modern neural network (NN)-based learning** (using machine learning to approximate good decisions). | |||
The challenge is often called the *“nastiest, hardest problem in finance”*: how to balance withdrawal needs, inflation, investment returns, and the fear of running out of money.
=== Core Contributions === | |||
**1. HJB Approach – Classic but Computationally Heavy** | |||
- Uses dynamic programming and mathematical modeling to find the value function. | |||
- Models include risky asset returns, expected withdrawal needs, and allocation strategies. | |||
- Precise but often limited to low-dimensional cases due to complexity. | |||
**2. Neural Network (NN) Control – New, Scalable, and Flexible** | |||
- Treats the control strategy (how much to withdraw and how to allocate funds) as a learning problem. | |||
- Learns from simulated trajectories without solving complex equations. | |||
- Naturally handles constraints like no-short-selling or withdrawal limits. | |||
- More flexible in high dimensions and generalizes well to different market conditions. | |||
=== Big Idea === | |||
They set up a reward function balancing two goals: | |||
- Maximize total withdrawals (you want money to spend). | |||
- Minimize shortfall (you don’t want to run out). | |||
The NN learns how to make these decisions across time, adapting its strategy as the market evolves. | |||
=== Experimental Setup === | |||
- Simulates investment scenarios using real historical market data (e.g., bootstrap sampling from 1926–2019). | |||
- Uses feedforward networks and backward training for better accuracy and constraint satisfaction. | |||
- Approximates complex decisions (like "bang-bang" control where you go all-in or all-out) very effectively. | |||
=== Key Results & Takeaways === | |||
- **Accuracy**: The NN performs nearly as well as the HJB method—even for tricky withdrawal rules. | |||
- **Efficiency**: No need to solve high-dimensional PDEs; the NN handles complexity better. | |||
- **Scalability**: Can handle more assets, longer time horizons, and stochastic market patterns. | |||
- **Robustness**: Generalizes well to unseen data, even under distribution shifts. | |||
=== Pros of the NN Approach === | |||
- Avoids the "curse of dimensionality" that haunts HJB methods. | |||
- Works well with noisy and historical data. | |||
- Automatically learns smooth solutions—even when traditional solutions are discontinuous. | |||
- Better suited for real-world policies where exact math assumptions don’t always hold. | |||
=== Critiques & Limitations === | |||
- **Ill-posed regions**: Control can become unstable near wealth boundaries or end-of-life planning. | |||
- **Training complexity**: Transfer learning may be needed for extreme cases (like very risk-averse clients). | |||
- **Historical data bias**: Based on U.S. market data—might miss global scenarios or crises. | |||
- **Compute cost**: Large NNs still require serious GPU time, even without PDEs. | |||
=== Related Work === | |||
- Builds on Forsyth (2022), which used PDEs for optimal retirement strategies. | |||
- Closely linked to reinforcement learning (RL), especially policy gradient methods. | |||
- Part of the broader field of stochastic control and financial engineering. | |||
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 18 Presentation: HIGEN: HIERARCHICAL GRAPH GENERATIVE NETWORKS == | |||
=== Presented by: === | |||
- Shiyu Zhu | |||
- Jesse Xue | |||
=== Paper Citation === | |||
M. Karami, “HiGen: Hierarchical Graph Generative Networks,” 2023, arXiv. doi: 10.48550/ARXIV.2305.19337. | |||
=== Background === | |||
- Hierarchical or Multi-Scale Structure: Captures high level interactions or relationships between objects or groups, while also representing the lower level structures. One example is a company org chart. | |||
- Existing graph generating models include: Variational Autoencoders, Generative Adversarial Networks, Autoregressive Models (GNN, Graph RNN, GRAN) and Diffusion Models | |||
- Paper introduces HIGEN, Hierarchical Graph Generative Networks to address problems with existing models. | |||
- Experiments were conducted on 5 datasets of increasing size and scale, with GraphRNN, GRAN, DiGress, GDSS, and SPECTRE used as baseline models.
=== Technical Contributions=== | |||
- Related Hierarchical Methods: The presentation discusses several recent hierarchical methods in specific domains like chemistry, highlighting HiGen's broader applicability, multi-level approach, and parallel generation as advantages over these more specialized techniques. These include a motif-based generation method for molecular graphs (2020) relying on domain knowledge, a hierarchical normalizing flow model (2021) based on local neighborhoods, and a tree decomposition framework (2022) limited to a single abstraction level and medium-sized graphs.
=== Definition: Hierarchical Graph === | |||
<p> | |||
A <strong>Hierarchical Graph</strong> is a multi-level representation of a graph | |||
<math>\mathcal{G} = (\mathcal{V}, \mathcal{E})</math> where: | |||
<ul> | |||
<li> | |||
<math>\mathcal{V}</math> is the set of <em>nodes (vertices)</em>, and | |||
<math>\mathcal{E}</math> is the set of <em>edges</em>, with sizes | |||
<math>n = |\mathcal{V}|</math> and <math>m = |\mathcal{E}|</math>. | |||
</li> | |||
<li> | |||
A node partition function <math>\mathcal{F}: \mathcal{V} \rightarrow \{1, ..., c\}</math> | |||
groups nodes into <math>c</math> communities or clusters. | |||
</li> | |||
<li> | |||
Each cluster forms a subgraph <math>\mathcal{C}_i = (\mathcal{V}(\mathcal{C}_i), \mathcal{E}(\mathcal{C}_i))</math> | |||
with adjacency matrix <math>A_i</math>. | |||
</li> | |||
<li> | |||
Cross-links between communities form <em>bipartite graphs</em> | |||
<math>\mathcal{B}_{ij} = (\mathcal{V}(\mathcal{C}_i), \mathcal{V}(\mathcal{C}_j), \mathcal{E}(\mathcal{B}_{ij}))</math>. | |||
</li> | |||
<li> | |||
Each cluster is aggregated into a super-node and each bipartite into a super-edge at the next higher level. | |||
This forms a <strong>coarser graph</strong> at the parent level. | |||
</li> | |||
</ul> | |||
</p> | |||
=== Methodology === | |||
HiGen (Hierarchical Graph Generative Networks) adopts a coarse-to-fine approach for generating graphs with complex hierarchical structures. The method consists of the following components: | |||
1. Hierarchical Graph Construction | |||
The input graph is recursively partitioned into communities using the Louvain clustering algorithm, forming a multi-level hierarchy. Each cluster (community) is abstracted as a "super-node" at the next level, and cross-community connections become "super-edges." | |||
2. Community and Bipartite Generation | |||
* Community Generation: Each community is generated independently using a GNN-based autoregressive model that factorizes the edge distribution into multinomial components. A mixture of multinomials is used to model intra-community edge structures. | |||
* Bipartite Generation: Once communities are generated, bipartite graphs (cross-community edges) are created using a separate neural network. These inter-cluster edges are predicted with a parallel GNN model using a similar factorization strategy. | |||
3. Autoregressive Probabilistic Modeling | |||
HiGen decomposes the joint probability of the graph into a product of conditional multinomial distributions. Both community and bipartite graphs are generated step-by-step using a recursive generation process guided by node and cluster embeddings. | |||
4. Parallelism and Invariance | |||
The hierarchical structure allows parallel generation of communities and bipartite edges at the same level, improving efficiency. The model is also invariant to node ordering within clusters, which improves robustness and scalability. | |||
This design enables HiGen to generate large and complex graphs while maintaining global structure and fine-grained local details. It supports realistic synthesis across diverse graph types including molecules, biological networks, and social systems. | |||
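A small sketch of the hierarchy-construction step is given below, using networkx's built-in Louvain routine (available in networkx >= 2.8). The super-edge weighting here, counting cross edges, is an illustrative choice consistent with the description above rather than HiGen's exact implementation.

<pre>
import networkx as nx

def coarsen(G, seed=0):
    # Partition G into communities, then build the parent-level graph:
    # one super-node per community, super-edge weight = number of cross edges.
    communities = nx.community.louvain_communities(G, seed=seed)
    label = {v: i for i, comm in enumerate(communities) for v in comm}
    parent = nx.Graph()
    parent.add_nodes_from(range(len(communities)))
    for u, v in G.edges():
        cu, cv = label[u], label[v]
        if cu != cv:
            w = parent.edges[cu, cv]["weight"] if parent.has_edge(cu, cv) else 0
            parent.add_edge(cu, cv, weight=w + 1)
    return parent, communities

# Repeated application of coarsen() on its own output yields the multi-level hierarchy.
</pre>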
=== Summaries of Key Notes === | |||
- HiGen: Hierarchical Graph Generative Networks (HiGen) is a novel graph generation model that addresses the limitations of existing methods by leveraging hierarchical structures in graphs. It generates substructures in a coarse-to-fine manner, modeling the interactions between communities and cross-edges at multiple levels of abstraction. | |||
- Community and Bipartite Generation: HiGen utilizes modular generation, creating communities in parallel followed by predictions of cross-edges with separate neural networks, ensuring scalability and efficiency. It employs multinomial distributions with recursive factorization, enabling autoregressive generation of integer-valued edge weights within communities. | |||
- State-of-the-Art Performance: HiGen demonstrates superior graph quality across benchmark datasets compared to competing models like GRAN, GraphRNN, GDSS, and diffusion-based methods. It captures both local and global graph statistics effectively, achieving state-of-the-art metrics for graph degree, clustering coefficients, and Laplacian spectra. | |||
- Scalability: The hierarchical structure of HiGen facilitates parallel graph generation, reduces sensitivity to node ordering, and enables block-wise processing of adjacency matrices, making it adaptable to large and complex graphs. | |||
- Applications: HiGen excels in generating realistic graphs for applications in molecular modeling, protein structure analysis, data network synthesis, and more, addressing diverse domains where hierarchical graph structures are pivotal.
=== Results === | |||
* Experiments are conducted on 5 datasets (e.g., SBM, Protein, Enzyme, Ego) with increasing size and complexity. | |||
* Evaluation metrics cover local (degree distribution, clustering coefficient distribution, orbit counts) and global (graph spectra) graph statistics. | |||
* Observation: Compared to baselines like GraphRNN, GRAN, and SPECTRE, HiGen consistently achieves state-of-the-art performance, especially in capturing graph motifs and global structure. | |||
=== Related works === | |||
Classical graph models are typically heuristic, rule-based, and statistical, relying on predefined structures rather than learning distributions from data. They are typically computationally efficient, but can't capture complex graphs/structures beyond predefined rules. They are also not data-driven, so they are unable to learn patterns from observed graphs. | |||
Additional deep learning methods for graph generation include the following below, along with some of the things they struggle with | |||
<ul> | |||
<li>Variational auto encoders: Limited to smaller graphs due to scalability limitations</li> | |||
<li>GANs: Hard to train, cannot capture complex dependencies</li> | |||
<li>Autoregressive models such as graph neural networks and graph RNNs: Have high complexity</li> | |||
<li>Diffusion models: Struggles with sampling speed</li> | |||
</ul> | |||
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 18 Presentation: HIGEN: HIERARCHICAL GRAPH GENERATIVE NETWORKS == | |||
=== Presented by: === | |||
- Shiyu Zhu | |||
- Jesse Xue | |||
=== Paper Citation === | |||
M. Karami, “HiGen: Hierarchical Graph Generative Networks,” 2023, arXiv. doi: 10.48550/ARXIV.2305.19337. | |||
=== Introduction === | |||
"HiGen: Hierarchical Graph Generative Networks" by Mahdi Karma introduces a novel approach to graph generation that places emphasis on the hierarchal structures in many inherent real-world graphs. | |||
=== Key Contributions === | |||
==== 1). Hierarchical Graph Generation ==== | |||
HiGen employs a coarse-to-fine strategy to generate graphs, effectively capturing their hierarchical nature. This method involves generating sub-structures at multiple levels, which enhances the model's ability to reflect the inherent organization of complex graphs.
==== 2). Parallel Community Generation ==== | |||
At every hierarchical level, HiGen generates communities in parallel, followed by the prediction of cross-edges between these communities using separate neural networks. Such a modular approach enables scalable generation of large and complicated graphs.
==== 3). Multinomial Edge Distribution Modelling ==== | |||
The model utilizes a multinomial distribution to represent the output distribution of edges within the graph. A recursive factorization of this distribution is employed, enabling HiGen to autoregressively generate community graphs with integer-valued edge weights. This improves the realism and accuracy of the generated graphs.
=== Performance Outcomes === | |||
Studies demonstrate the effectiveness and scalability of HiGen, allowing it to achieve state-of-the-art graph quality across a variety of benchmark datasets. The modular design and hierarchical generation process contribute to its ability to generate large and complex graphs efficiently.
=== Advancements in the Field === | |||
HiGen has advanced the capabilities of machine learning in graph generative modelling by explicitly incorporating hierarchical structures into the generation process. This approach addresses limitations in existing methods that often overlook the multi-level organization of real-world graphs, thereby enhancing the fidelity and applicability of generated graphs across various domains and applications.
=== Conclusion === | |||
The methodologies and findings presented in the HiGen paper offer valuable contributions to the field of graph generation. By introducing a hierarchical and modular approach, HiGen sets a new benchmark for generating complex graphs that accurately reflect structures observed in the real world.
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 18 Presentation: HiGen: Hierarchical Graph Generative Networks == | |||
=== Presented by === | |||
- Shiyu Zhu | |||
- Jesse Xue | |||
=== Paper Citation === | |||
M. Karami, “HiGen: Hierarchical Graph Generative Networks,” 2023, arXiv. doi: 10.48550/ARXIV.2305.19337. | |||
=== Background === | |||
- <b>Problem</b> Graphs are widely used to represent complex relationships in various domains (e.g., social networks, molecular structures, and knowledge graphs). However, generating realistic graphs is challenging due to the need to capture both local and global structures. | |||
- Hierarchical structures provide a natural way to model these interactions. Lower levels capture dense local structures, and higher levels capture global properties. | |||
- Existing graph generative models had limitations. Variational Autoencoders struggle with scalability, limiting them to small graphs. Autoregressive models captured hierarchical dependencies well, but suffered from high computational complexity due to sequential node/edge modeling. Diffusion models generated high-quality graphs but were computationally expensive due to long denoising steps. | |||
- <b>Motivation:</b> HiGen introduces hierarchical graph generation with a structured decomposition approach to improve scalability, efficiency, and parallelization. | |||
=== Summaries of Key Points === | |||
- <b>Goal:</b> HiGen aims to learn a generative model from training graphs and efficiently generating new graphs while capturing both local and global structures. | |||
- Uses generative probability and decomposed probability. Generative probability directly models the joint distribution of the entire graph and ensures a holistic view of the graph structure. Decomposed probability breaks the graph into smaller, independent components, and it allows for parallelization, which makes training and generation more scalable and efficient. | |||
- Hierarchical graph generative networks decompose graphs into communities and their bipartite interactions, leveraging conditional independence for efficiency. Essentially, community structures are conditionally independent of bipartite interactions, and this decomposition allows parallel generation of communities and their bipartite edges, reducing training time. | |||
- Inspired by Graph Recurrent Attention Network (GRAN). GRAN generates groups of nodes in parallel to reduce complexity, and HiGen extends this by introducing a k-mixture model. The k-mixture model learns more complex structural patterns and adapts to diverse hierarchical structures that are found in real-world graphs. It is implemented with Multi-Layer Perceptrons (MLPs), each with different inputs and activation functions to capture hierarchical relationships. | |||
- HiGen reduces the number of required adjacency matrices, which improves computational efficiency. It also encourages parallelizable training by organizing graphs into structured blocks. Those blocks represent strongly connected communities, effectively capturing intercommunity and global relationships. | |||
=== Related works === | |||
- Unlike previous models, HiGen explicitly incorporates hierarchical structure to improve scalability, making it more efficient than other methods in handling large graphs. | |||
- <b>Potential Applications:</b> Application in chemistry for molecular graph generation, where the hierarchical structures correspond to molecular backbones and functional groups. It is also helpful for simulating hierarchical community structures in large-scale networks, in terms of social networks. HiGen is applicable to any field that requires scalable graph generation with groups that have hierarchical dependencies. | |||
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 18 Presentation: HiGen: Hierarchical Graph Generative Networks == | |||
=== Presented by === | |||
- Shiyu Zhu and Jesse Xue | |||
=== Paper Citation === | |||
M. Karami, “HiGen: Hierarchical Graph Generative Networks,” 2023, arXiv. doi: 10.48550/ARXIV.2305.19337. | |||
=== Summaries === | |||
This paper proposes a new graph generative network, the Hierarchical Graph Generative Network (HiGen), to make graph generation usable in real-life applications without requiring domain-specific expert knowledge. It addresses a shortcoming of existing works: they capture neither the hierarchical structure of graphs nor their unordered nature. The new network generates a graph from its partition into communities and the cross-links (bipartite graphs) between neighboring communities. These components are conditioned on their parent node and are independent of each other, so they can be generated in parallel, which reduces computational cost. Partitioning the graph also resolves the autoregressive models' requirement that inputs be ordered. With this more efficient network, large graphs can be handled, making the method more practical for real-life applications.
=== Key Contributions === | |||
* The generative process runs iteratively from the root node to the leaf nodes. A distribution is computed for the partitioned graph at each level, conditional on its parent level. The paper identifies that the joint distribution of the weights of all edges follows a multinomial distribution. For each community and bipartite sub-graph, the conditional distributions are independent of each other, and each follows a multinomial distribution. This finding allows for parallel generation, which reduces computational costs.
* The generation of a community can be broken down into an autoregressive, edge-by-edge prediction process. This is because the joint multinomial distribution of the communities at a given level, conditioned on the partition graph at its parent level, factorizes into a sequence of binomial distributions over individual edges (see the identity sketched after this list).
* The edges of bipartite graphs (cross-links between neighboring communities) can also be generated in parallel. | |||
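The factorization referred to above is the standard identity that a multinomial over edge weights can be sampled as a chain of binomials; it is written here in generic notation, not necessarily the paper's exact symbols:

<math>
p(x_1,\dots,x_k \mid N, \pi_1,\dots,\pi_k) = \prod_{i=1}^{k} \mathrm{Bin}\left(x_i \mid N - \sum_{j<i} x_j,\ \frac{\pi_i}{1 - \sum_{j<i} \pi_j}\right)
</math>

so each edge weight can be drawn one at a time from a binomial whose trial count is the remaining edge budget, which is exactly what makes edge-by-edge autoregressive generation possible.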
=== Explanations of details === | |||
'''Communities''': Communities are clusters of nodes or modules that can be grouped together. | |||
'''Partition process''': The graph is partitioned using a graph partitioning function. The partition process is done from the leaf nodes to the root node. | |||
=== Related Works === | |||
Some existing works are specific to particular expert domains and are therefore limited to a single type of application.
The autoregressive process used by existing works requires an appropriate ordering of inputs, which ignores a unique property of graphs: they are insensitive to the ordering of nodes.
Existing works also incur significant computational costs and thus cannot handle large graphs.
In addition, no prior work explores the hierarchical structure of graphs, which limits the performance of existing models.
Among diffusion models, continuous diffusion destroys the sparsity and structural properties of graphs, while discrete diffusion requires an extra denoising step that may increase computational cost.
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 18 Presentation: HiGen: Hierarchical Graph Generative Networks == | |||
=== Presented by === | |||
- Shiyu Zhu and Jesse Xue | |||
=== Paper Citation === | |||
M. Karami, “HiGen: Hierarchical Graph Generative Networks,” 2023, arXiv. doi: 10.48550/ARXIV.2305.19337. | |||
=== Summaries === | |||
The paper introduces HiGen, a novel graph generative model that explicitly captures the hierarchical structure of real-world graphs. HiGen generates graphs in a coarse-to-fine manner, successively building communities and their interactions across multiple levels. At each level, the model generates subgraphs (communities) in parallel and then predicts cross-community (bipartite) edges using dedicated neural networks. This decomposition allows HiGen to scale efficiently and capture both local and global graph properties. | |||
=== Key Contributions === | |||
Proposes a multi-resolution generative framework where graphs are generated level-by-level, capturing hierarchical community structures naturally present in many real-world graphs. Decomposes each graph into communities that can be generated in parallel, increasing efficiency and scalability. Models edge weights using a multinomial distribution and introduces a recursive factorization (via stick-breaking) to enable efficient autoregressive generation. | |||
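The recursive stick-breaking factorization mentioned above can be illustrated with a small sketch. The following is a hypothetical Python example (not the authors' implementation) showing how a multinomial over edge weights can be sampled as a sequence of binomial draws, which is what makes edge-by-edge autoregressive generation possible:
<pre>
import numpy as np

def sample_multinomial_stick_breaking(total_weight, probs, rng=None):
    """Sample counts ~ Multinomial(total_weight, probs) one component at a time.

    The count for component i is drawn from a Binomial whose success
    probability is probs[i] renormalized over the components not yet
    visited, conditioned on the weight not yet assigned (stick-breaking).
    """
    rng = rng or np.random.default_rng()
    probs = np.asarray(probs, dtype=float)
    counts = np.zeros(len(probs), dtype=int)
    remaining_weight = int(total_weight)
    remaining_prob = 1.0
    for i, p in enumerate(probs[:-1]):
        if remaining_weight == 0:
            break
        counts[i] = rng.binomial(remaining_weight, min(1.0, p / remaining_prob))
        remaining_weight -= counts[i]
        remaining_prob -= p
    counts[-1] += remaining_weight  # the last component takes the leftover weight
    return counts

# Toy usage: distribute a total edge weight of 10 over 4 candidate edges.
print(sample_multinomial_stick_breaking(10, [0.4, 0.3, 0.2, 0.1]))
</pre>
Because each binomial draw conditions only on the weight not yet assigned, the per-edge predictions can be chained autoregressively, mirroring the recursive factorization described in the paper.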
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 20 Gated linear attention transformers with hardware efficient training == | |||
=== Reference === | |||
arXiv:2312.06635 | |||
=== Background === | |||
The paper discusses Gated Linear Attention (GLA) Transformers, addressing the computational inefficiency of traditional transformers with softmax attention. Regular transformers have quadratic computational complexity in sequence length, which becomes extremely expensive for long sequences, while standard linear attention mechanisms tend to underperform.
=== Technical Contributions === | |||
It proposes using a linear kernel as an alternative to the softmax function, which allows attention to be formulated as a linear RNN with 2D hidden states. The key innovations include: | |||
1. Introducing a data-dependent gating mechanism to improve model performance, which allows the model to forget past information adaptively | |||
2. Developing a linear attention approach that reduces computational complexity | |||
3. Creating a hardware-efficient training method (FLASHLINEARATTENTION Algorithm) that can handle long sequences more effectively | |||
The main goal was to create a more efficient transformer model that can: | |||
- Reduce computational expenses | |||
- Maintain competitive performance across different tasks | |||
- Handle long sequences more effectively | |||
- Leverage modern GPU architectures for improved training and inference | |||
The approach addresses the fundamental challenge of making transformer models more scalable and computationally efficient, particularly for tasks involving long sequences like processing books, dialogues, or complex scientific texts. | |||
=== Results === | |||
The results and conclusions of the paper showed: | |||
Performance Results: | |||
- For the 340 million parameter model: | |||
- Achieved competitive performance | |||
- Close to transformer performance | |||
  - Slightly better than or comparable to RetNet
- Slightly below Mamba on some tasks | |||
- For the 1.3 billion parameter model: | |||
- Beat most benchmarks in average accuracy | |||
- Slightly behind transformer++ in perplexity | |||
- Showed impressive accuracy across tasks | |||
Key Findings: | |||
1. Gating mechanism is crucial for model performance | |||
- Removing it significantly increased perplexity | |||
- Data-dependent scalar decay improved results | |||
2. Recall-intensive tasks: | |||
- Smaller model: Transformer still led | |||
- Larger model: GLA closed performance gap considerably | |||
  - Competitive with Mamba and RetNet
3. Computational Efficiency: | |||
- Higher training throughput for larger batch sizes | |||
- Slight increase in GPU memory usage | |||
- More efficient for bigger batches | |||
Conclusions: | |||
- GLA is highly effective for handling long sequences | |||
- Hardware-efficient design reduces computational costs | |||
- Gating mechanism significantly enhances model performance | |||
- Promising approach for making transformers more accessible and efficient | |||
- An efficient replacement for softmax attention in Transformers
The paper suggests future research should focus on optimizing the balance between performance and efficiency. | |||
=== Linear Attention === | |||
<p> | |||
Transformers traditionally use softmax attention, which scales poorly with sequence length due to quadratic complexity. Linear attention approximates softmax with kernel-based attention mechanisms, reducing this cost. | |||
</p> | |||
==== Parallel and Recurrent Forms ==== | |||
<ul> | |||
<li> | |||
<strong>Parallel Form:</strong> Computes full attention using:<br> | |||
<math> | |||
\mathbf{O} = \text{softmax}((\mathbf{QK}^\top) \odot \mathbf{M}) \mathbf{V} | |||
</math><br> | |||
Enables efficient training with full-sequence inputs. | |||
</li> | |||
<li> | |||
<strong>Recurrent Form:</strong> Used during inference, processes token-by-token with:<br> | |||
<math> | |||
\mathbf{o}_t = \frac{\sum_{i=1}^{t} \phi(\mathbf{q}_t) \phi(\mathbf{k}_i)^\top \mathbf{v}_i}{\sum_{i=1}^{t} \phi(\mathbf{q}_t) \phi(\mathbf{k}_i)^\top} | |||
</math> | |||
</li> | |||
<li> | |||
Using <math>\phi(x) = x</math> and removing normalization yields the simplified linear attention update:<br> | |||
<math>\mathbf{S}_t = \mathbf{S}_{t-1} + \mathbf{k}_t^\top \mathbf{v}_t</math>, | |||
<math>\quad \mathbf{o}_t = \mathbf{q}_t \mathbf{S}_t</math> | |||
</li> | |||
</ul> | |||
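To make the simplified recurrent update above concrete, here is a minimal NumPy sketch (an illustration under the stated simplifications <math>\phi(x) = x</math> and no normalization, not the paper's optimized kernel):
<pre>
import numpy as np

def linear_attention_recurrent(Q, K, V):
    """Causal linear attention with phi(x) = x and no normalization.

    Q, K: (L, d_k), V: (L, d_v). Maintains the 2D state
    S_t = S_{t-1} + k_t^T v_t and outputs o_t = q_t S_t.
    """
    L, d_k = Q.shape
    d_v = V.shape[1]
    S = np.zeros((d_k, d_v))
    outputs = np.zeros((L, d_v))
    for t in range(L):
        S = S + np.outer(K[t], V[t])   # rank-1 state update
        outputs[t] = Q[t] @ S          # read out with the current query
    return outputs

# Toy check against masked (causal) quadratic attention without softmax.
L, d = 8, 4
rng = np.random.default_rng(0)
Q, K, V = rng.normal(size=(3, L, d))
mask = np.tril(np.ones((L, L)))
assert np.allclose(linear_attention_recurrent(Q, K, V), (Q @ K.T * mask) @ V)
</pre>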
==== Chunkwise Parallel Linear Attention ==== | |||
<p> | |||
The <strong>chunkwise parallel</strong> form balances between full parallelism and full recurrence, enabling faster training on long sequences. | |||
</p> | |||
<ul> | |||
<li> | |||
Splits input <math>\mathbf{X}</math> into chunks of length <math>C</math>. | |||
</li> | |||
<li> | |||
<strong>Inter-chunk state update:</strong><br> | |||
<math> | |||
\mathbf{S}_{[i+1]} = \mathbf{S}_{[i]} + \sum_{j=iC+1}^{(i+1)C} \mathbf{k}_j^\top \mathbf{v}_j | |||
</math> | |||
</li> | |||
<li> | |||
<strong>Intra-chunk output:</strong><br> | |||
<math> | |||
\mathbf{O}_{[i+1]} = \mathbf{Q}_{[i+1]} \mathbf{S}_{[i]} + \left((\mathbf{Q}_{[i+1]} \mathbf{K}_{[i+1]}^\top) \odot \mathbf{M}\right) \mathbf{V}_{[i+1]} | |||
</math> | |||
</li> | |||
</ul> | |||
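Below is a hedged sketch of the chunkwise form in plain NumPy (illustrative only; it omits the data-dependent gates and the I/O-aware tiling of FLASHLINEARATTENTION):
<pre>
import numpy as np

def linear_attention_chunkwise(Q, K, V, C):
    """Chunkwise-parallel causal linear attention (phi = identity, no gates).

    Splits the length-L sequence into chunks of size C. The inter-chunk
    contribution uses the running state S; the intra-chunk contribution is
    a small masked quadratic attention inside each chunk.
    """
    L, d_k = Q.shape
    d_v = V.shape[1]
    assert L % C == 0, "for simplicity, assume L is a multiple of C"
    S = np.zeros((d_k, d_v))
    mask = np.tril(np.ones((C, C)))
    out = np.zeros((L, d_v))
    for start in range(0, L, C):
        q, k, v = Q[start:start+C], K[start:start+C], V[start:start+C]
        inter = q @ S                     # contribution from earlier chunks
        intra = ((q @ k.T) * mask) @ v    # causal attention within the chunk
        out[start:start+C] = inter + intra
        S = S + k.T @ v                   # inter-chunk state update
    return out
</pre>
Setting <math>C = 1</math> recovers the recurrent loop, while <math>C = L</math> recovers the fully parallel form, matching the note in the Benefits list below.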
===== Swish Activation function (SwiGLU) ===== | |||
One notable component of the GLA model’s design is its use of the Swish activation function (and the derived SwiGLU gating unit) in key parts of the network. | |||
Swish, defined as <math> \text{Swish}(x) = x \cdot \sigma(x) </math> (where <math> \sigma </math> is the sigmoid), is a smooth, non-monotonic activation known to often outperform ReLU/GELU in deep networks. In this paper, Swish is employed in two main places: (1) the feed-forward network (FFN) layers, where the authors adopt the SwiGLU formulation, and (2) the computation of the data-dependent gates in the attention mechanism. | |||
1) FFN | |||
The function's smooth gradient and ability to yield non-zero outputs even for negative inputs help with optimization and expressiveness. In summary, using SwiGLU in each Transformer block’s FFN is an architectural choice that boosts performance per parameter, making the overall model more competitive with standard Transformers. | |||
2) Gating mechanism
Swish is self-gating.
When <math>x_t W_r</math> is a large positive value, Swish outputs a large value (roughly linear in its input), but when <math>x_t W_r</math> is around zero or negative, Swish outputs a small value (tending toward zero for large negative inputs). This means <math>r_t</math> will tend to selectively suppress tokens the model deems less important (yielding near-zero values for those features) while allowing strong signals to pass through (near-linearly for large activations). A standard sigmoid gate could also suppress features (outputting values between 0 and 1), but it saturates at 1 for any sufficiently large input, effectively capping the influence of very important features. Swish, by contrast, does not saturate to a hard limit: for important inputs it can output values greater than 1 (since <math>r_t \approx x_t W_r</math> when <math>x_t W_r</math> is large), allowing an amplification effect. This gives the model more flexibility than a sigmoid-gated GLU: small signals are squashed (multiplied by a small fraction), while strong signals are propagated in full or even amplified (almost identity for large positive inputs). This property can be crucial for modeling, for example, rare key tokens that should strongly influence the attention: Swish gating lets those contributions through largely unattenuated, whereas a sigmoid gate would bottleneck at 1.
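For reference, here is a minimal PyTorch-style sketch of Swish and a SwiGLU feed-forward block; the module name and dimensions are illustrative assumptions rather than the paper's exact configuration:
<pre>
import torch
import torch.nn as nn

def swish(x):
    # Swish(x) = x * sigmoid(x); also available as torch.nn.functional.silu
    return x * torch.sigmoid(x)

class SwiGLUFFN(nn.Module):
    """Feed-forward block with SwiGLU gating: (Swish(x W_g) * x W_u) W_d."""
    def __init__(self, d_model, d_hidden):
        super().__init__()
        self.w_gate = nn.Linear(d_model, d_hidden, bias=False)
        self.w_up = nn.Linear(d_model, d_hidden, bias=False)
        self.w_down = nn.Linear(d_hidden, d_model, bias=False)

    def forward(self, x):
        return self.w_down(swish(self.w_gate(x)) * self.w_up(x))

# Toy usage with assumed dimensions.
ffn = SwiGLUFFN(d_model=64, d_hidden=128)
y = ffn(torch.randn(2, 10, 64))   # (batch, sequence, d_model)
</pre>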
===== Benefits ===== | |||
<ul> | |||
<li>Time complexity: <math>\mathcal{O}(LCd + Ld^2)</math>, which is sub-quadratic.</li> | |||
<li><math>C = 1</math> recovers the recurrent form; <math>C = L</math> recovers the parallel form.</li> | |||
<li>Efficient and scalable to long sequences with minimal performance loss.</li> | |||
</ul> | |||
=== Future Work === | |||
1. Future hardware-aware optimization: balance between efficiency and performance. | |||
2. Application to other data: the potential of applying GLA to image, video, or scientific data. | |||
3. Test how GLA performs on larger models: due to computational limitations, the experiments were conducted on moderate-scale models.
=== Summaries of Key Points === | |||
- Gated Linear Attention (GLA) Transformer is a novel architecture that combines the efficiency of linear attention with data-dependent gating mechanisms to improve performance in sequence modeling tasks. | |||
- FLASHLINEARATTENTION is introduced as a hardware-efficient implementation of linear attention, outperforming FLASHATTENTION-2 in speed, even for short sequences (e.g., 1K tokens). | |||
- GLA Transformer enhances length generalization, allowing models trained on 2K sequences to generalize to sequences longer than 20K without significant performance degradation. | |||
- The model is competitive with state-of-the-art architectures, including LLaMA Transformers and linear-time inference models like RetNet and Mamba, particularly in moderate-scale language modeling tasks. | |||
- GLA Transformer achieves higher training throughput compared to similarly sized Mamba models while maintaining efficient long-context processing. | |||
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 20 Gated linear attention transformers with hardware efficient training == | |||
=== Presented by: === | |||
- Felix Jean | |||
- Maxime Bouthilier | |||
- Thomas Hudon | |||
=== Paper Citation === | |||
S. Yang, B. Wang, Y. Shen, R. Panda & Y. Kim, “Gated linear attention transformers with hardware efficient training,” 2024, arXiv:2312.06635 | |||
=== Background & Motivation=== | |||
The paper tackles the limitations of traditional softmax attention in Transformers, which, despite enabling efficient parallel training, exhibits quadratic complexity with respect to sequence length, rendering it impractical for long sequences. Linear attention has emerged as a promising alternative, providing linear-time inference by reformulating attention as a recurrent neural network (RNN) with 2D hidden states. However, in practice, linear attention often underperforms compared to softmax attention, and existing implementations lack I/O-awareness, leading to slower speeds relative to optimized softmax attention implementations such as FlashAttention-2. The authors identify two critical gaps: (1) the absence of hardware-efficient algorithms for linear attention that effectively balance memory movement and parallelizability, and (2) the lack of data-dependent gating mechanisms in linear attention, which are essential for achieving high performance in RNNs. These gaps motivate the development of FlashLinearAttention and the gated linear attention (GLA) Transformer. | |||
=== Key Points=== | |||
The paper introduces FlashLinearAttention, an I/O-aware and hardware-efficient algorithm for linear attention that optimizes memory movement and parallelizability. It achieves faster speeds than FlashAttention-2, even on short sequences (e.g., 1K tokens). The authors further extend this algorithm to Gated Linear Attention (GLA), which incorporates data-dependent gates to enhance model expressiveness. GLA preserves the linear-time inference property while improving performance across a range of tasks. Additionally, the paper proposes a chunkwise parallel formulation for GLA, enabling efficient training by dividing sequences into chunks and balancing inter-chunk and intra-chunk computations. Experimental results demonstrate that the GLA Transformer performs competitively against LLaMA-architecture Transformers and recent linear-time models such as RetNet and Mamba, particularly excelling in length generalization and recall-intensive tasks. | |||
=== Contributions=== | |||
*'''FlashLinearAttention''': A hardware-efficient algorithm for linear attention that outperforms FlashAttention-2 in speed and memory efficiency. | |||
*'''Gated Linear Attention (GLA)''': A novel linear attention variant with data-dependent gates, offering better performance and stability. | |||
*'''Chunkwise Parallel Form''': A training-friendly formulation of GLA that enables efficient parallelization and scalability. | |||
*'''Empirical Validation''': Demonstrates competitive performance against strong baselines, including LLaMA, RetNet, and Mamba, with notable strengths in length generalization and recall tasks. | |||
*'''Open-source Implementation''': The release of FlashLinearAttention as a practical tool for the community. | |||
=== Constructive Critiques=== | |||
*'''Scalability''': Although the experiments are conducted at moderate scales (up to 1.3B parameters), it remains unclear how GLA would perform at larger scales (e.g., 7B+ parameters). The authors hypothesize that GLA’s efficiency would further improve at such scales, but this claim requires empirical validation. | |||
*'''Generalization to Other Modalities''': The current focus is on language modeling; however, extending GLA to other domains (e.g., vision or audio) could potentially broaden its applicability and impact. | |||
*'''Complexity of Implementation''': The secondary-level chunking and materialization strategies introduce additional complexity. Providing a more streamlined implementation or conducting ablation studies could help users better understand the associated trade-offs. | |||
*'''Comparison to Hybrid Models''': While the paper compares GLA to pure linear-time models (e.g., Mamba) and softmax attention, hybrid approaches that combine linear and sparse attention are not explored. Such comparisons could provide deeper insights into GLA's relative strengths and limitations. | |||
=== Relationships to Other Works=== | |||
The linear attention component extends prior work by Katharopoulos et al. (2020) and Sun et al. (2023a) by introducing data-dependent gates and hardware optimizations. The hardware-efficient algorithm follows the spirit of FlashAttention (Dao et al., 2022b) but adapts it to linear attention, addressing unique challenges such as chunkwise parallelism. The gating mechanism draws inspiration from gated RNNs (e.g., LSTMs, Mamba) but adapts it to linear attention's 2D hidden states. On length generalization, GLA complements recent efforts such as RetNet and Mamba-2, offering a new solution for extrapolating beyond training lengths.
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 23 Presentation: Discrete Diffusion Modelling By Estimating the Ratios of the Data Distribution == | |||
=== Presented By === | |||
Chenxin Lyu, Yixuan Zeng | |||
=== Paper Citation === | |||
A. Lou, C. Meng, and S. Ermon, ‘Discrete Diffusion Modeling by Estimating the Ratios of the Data Distribution’, Jun. 06, 2024, arXiv: arXiv:2310.16834. doi: 10.48550/arXiv.2310.16834. | |||
https://arxiv.org/abs/2310.16834 | |||
=== Background === | |||
- Diffusion models have shown great performance in generative artificial intelligence when applied to domains with continuous data
- Diffusion models are more difficult to implement for data in the discrete domain, such as tokenized texts | |||
- Prior attempts at applying diffusion to text generations have performed worse than autoregressive models | |||
=== Paper Contributions=== | |||
- Developed a method called Score Entropy Discrete Diffusion (SEDD) | |||
- Parameterizes the diffusion process for discrete data using data distribution ratios, rather than dealing with the tokenized data directly | |||
* '''SEDD''': SEDD is a framework for discrete diffusion modeling that learns to generate data by estimating probability ratios between neighboring discrete states. In the paper, SEDD forms the core modeling strategy that enables diffusion-based generation for discrete data, achieving perplexities competitive with autoregressive baselines.
* '''Implicit Score Entropy Loss''':
<math> L_{ISE} = \mathbb{E}_{x \sim p} \sum_{y \neq x} \left( w_{xy}\, s_\theta(x)_y - w_{yx} \log s_\theta(y)_x \right) </math>
The implicit score entropy loss is a novel training objective designed to learn the ratio function <math> s_\theta(x)_y \approx p_t(y)/p_t(x)</math> without requiring access to the true data distribution. In the paper, it makes score entropy computationally more efficient, as it avoids explicit dependence on the true ratio <math> p(y)/p(x) </math>.
It avoids computing partition functions or normalizing over large vocabularies, which makes the method scalable and practical for high-dimensional categorical data (like text). It also connects naturally to energy-based modeling, where exact densities are intractable but ratios or unnormalized scores can be learned.
It is worth noting that the score entropy loss also connects to maximum likelihood: the authors show it can be used to derive an evidence lower bound (ELBO) for likelihood training.
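As a rough illustration of the implicit score entropy objective above, the following hypothetical sketch (toy state space, arbitrary weights, not the authors' code) evaluates a Monte Carlo estimate of the loss for a batch of observed states:
<pre>
import torch

def implicit_score_entropy_loss(scores, x_batch, weights):
    """Monte Carlo estimate of L_ISE = E_x sum_{y != x} (w_xy * s(x)_y - w_yx * log s(y)_x).

    scores:  (N, N) tensor, scores[x, y] = s_theta(x)_y > 0 (estimated ratio p(y)/p(x))
    x_batch: (B,) long tensor of observed states sampled from the data
    weights: (N, N) tensor of transition weights w_xy (e.g., from the diffusion matrix)
    """
    N = scores.shape[0]
    loss = 0.0
    for x in x_batch.tolist():
        mask = torch.ones(N, dtype=torch.bool)
        mask[x] = False                                    # sum over y != x only
        term1 = weights[x, mask] * scores[x, mask]
        term2 = weights[mask, x] * torch.log(scores[mask, x])
        loss = loss + (term1 - term2).sum()
    return loss / len(x_batch)

# Toy usage with a 5-state space and strictly positive ratio estimates.
N, B = 5, 3
scores = torch.rand(N, N).exp()
weights = torch.rand(N, N)
x_batch = torch.randint(0, N, (B,))
print(implicit_score_entropy_loss(scores, x_batch, weights))
</pre>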
=== Discrete Diffusion Processes === | |||
<ul> | |||
<li> | |||
Models probability distributions over a <strong>finite discrete space</strong> | |||
<math> \mathcal{X} = \{1, \ldots, N\}</math>, using <strong>probability mass vectors</strong> | |||
<math>p_t \in \mathbb{R}^N</math>. | |||
</li> | |||
<li> | |||
Evolution of <math>p_t</math> follows a <strong>linear ODE</strong>:<br> | |||
<math>\frac{dp_t}{dt} = Q_t p_t,\quad p_0 \approx p_{\text{data}}</math> | |||
</li> | |||
<li> | |||
<math>Q_t</math> is a <strong>diffusion matrix</strong> with non-negative off-diagonal entries | |||
and column sums equal to 0 (mass is preserved). | |||
</li> | |||
<li> | |||
Often simplified as <math>Q_t = \sigma(t) Q</math>, driving | |||
<math>p_t</math> toward a base distribution as <math>t \to \infty</math>. | |||
</li> | |||
<li> | |||
Simulated using <strong>Euler steps</strong> with small <math>\Delta t</math>. Transition probability:<br> | |||
<math>p(x_{t+\Delta t} = y \mid x_t = x) = \delta_{xy} + Q_t(y, x) \Delta t + O(\Delta t^2)</math> | |||
</li> | |||
<li> | |||
<strong>Time Reversal:</strong> Reverse process uses another matrix <math>\overline{Q}_t</math> with:<br> | |||
<math>\overline{Q}_t(y, x) = \frac{p_t(y)}{p_t(x)} Q_t(x, y)</math><br> | |||
Reverse ODE: <math>\frac{dp_{T-t}}{dt} = \overline{Q}_{T-t} p_{T-t}</math> | |||
</li> | |||
<li> | |||
This connects to the <strong>concrete score</strong>, generalizing the score function | |||
<math>\nabla_x \log p_t</math>. | |||
</li> | |||
</ul> | |||
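The forward process described in this list can be simulated directly. Below is a minimal sketch (toy state space, uniform-transition <math>Q</math>, assumed constant <math>\sigma(t)</math>) of Euler-stepping the discrete diffusion:
<pre>
import numpy as np

def simulate_forward(x0, Q, sigma, T=1.0, n_steps=100, rng=None):
    """Simulate one trajectory of the discrete diffusion x_t with Q_t = sigma(t) * Q.

    At each Euler step, the transition probability to state y is approximately
    delta_{xy} + Q_t(y, x) * dt (clipped and renormalized in this sketch).
    """
    rng = rng or np.random.default_rng()
    N = Q.shape[0]
    dt = T / n_steps
    x = x0
    for step in range(n_steps):
        t = step * dt
        rates = sigma(t) * Q[:, x]               # column x: rates from x into each y
        probs = np.clip(rates * dt, 0, None)
        probs[x] = 0.0
        probs[x] = max(0.0, 1.0 - probs.sum())   # stay-put probability
        x = rng.choice(N, p=probs / probs.sum())
    return x

# Toy usage: 4-state uniform diffusion matrix (columns sum to zero).
N = 4
Q = np.ones((N, N)) / (N - 1)
np.fill_diagonal(Q, -1.0)
print(simulate_forward(x0=0, Q=Q, sigma=lambda t: 1.0))
</pre>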
=== Summaries of Key Points=== | |||
- SEDD (Score Entropy Discrete Diffusion models) is a novel approach to discrete diffusion modeling that bridges the gap between diffusion models and autoregressive language models. It introduces score entropy, a new loss function that extends score matching to discrete spaces, improving performance in discrete generative modeling tasks. | |||
- SEDD significantly improves performance in language modeling, reducing perplexity by 25-75% compared to previous discrete diffusion models and outperforming GPT-2 in certain tasks. | |||
- Key advantages of SEDD over traditional autoregressive models: | |||
- Higher generative quality without requiring distribution annealing techniques such as temperature scaling. | |||
- Computational efficiency, enabling similar output quality with up to 32× fewer network evaluations. | |||
- Enhanced controllability, allowing for flexible text infilling beyond left-to-right prompting, while maintaining quality comparable to nucleus sampling. | |||
- SEDD models discrete data by parameterizing a reverse discrete diffusion process using ratios of the data distribution, making them more effective in capturing language structure. | |||
- The method challenges the dominance of autoregressive transformers by offering an alternative with better trade-offs between efficiency, control, and generation quality in discrete text modeling. | |||
==== Results ==== | |||
* Perplexity Evaluation: SEDD outperforms previous diffusion models in terms of perplexity across multiple language modeling benchmarks. This suggests that SEDD models the underlying data distribution more accurately, improving likelihood estimation. | |||
* Unconditional Generation: SEDD generates high-quality samples without any input conditioning, achieving performance comparable to GPT-2. It does so with 32× fewer network evaluations, indicating higher sampling efficiency and reduced computational cost. | |||
* Conditional Generation (Infill Tasks): SEDD is tested on tasks where the model must generate missing parts of text conditioned on surrounding context. Despite lacking the autoregressive biases that normally boost performance in such tasks, SEDD remains competitive with strong baselines like GPT-2 and SSD-LM. This highlights SEDD’s ability to generalize well to complex discrete generation tasks without relying on token-by-token prediction. | |||
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 23 Presentation: Discrete Diffusion Modelling By Estimating the Ratios of the Data Distribution == | |||
=== Presented By === | |||
Chenxin Lyu, Yixuan Zeng | |||
=== Paper Citation === | |||
A. Lou, C. Meng, and S. Ermon, ‘Discrete Diffusion Modeling by Estimating the Ratios of the Data Distribution’, Jun. 06, 2024, arXiv: arXiv:2310.16834. doi: 10.48550/arXiv.2310.16834. | |||
https://arxiv.org/abs/2310.16834 | |||
=== Background & Motivation=== | |||
The paper tackles a critical gap in generative modeling for discrete data, particularly within the domain of natural language processing (NLP). While diffusion models have achieved remarkable success in continuous domains such as image generation, their performance on discrete data (e.g., text) has fallen short compared to autoregressive models, which currently dominate the field. The authors pinpoint the root cause as the absence of a principled and scalable framework for discrete score matching—the foundational theory underlying continuous diffusion models. Existing approaches, such as mean prediction and ratio matching, exhibit theoretical and empirical limitations, including instability, inefficiency, and suboptimal performance. Motivated by these challenges, the paper introduces Score Entropy Discrete Diffusion (SEDD), a novel method that extends score matching to discrete spaces by estimating probability ratios of the data distribution. This approach seeks to close the performance gap between autoregressive and diffusion-based language models while addressing key challenges associated with slow sampling, limited controllability, and stringent annealing requirements in autoregressive models. | |||
=== Key Points=== | |||
1. '''Score Entropy Loss''': The key innovation lies in a novel loss function, score entropy, which extends score matching to discrete spaces by modeling the ratios of the data distribution. This ensures positivity, scalability, and theoretical consistency, thereby addressing the limitations of prior methods such as concrete score matching (which inadequately penalizes negative values). | |||
2. '''Discrete Diffusion Framework''': SEDD parameterizes the reverse diffusion process using learned probability ratios, enabling efficient sampling and likelihood-based training. The framework supports token-level transitions via structured matrices (e.g., uniform or absorbing transitions), facilitating the handling of high-dimensional sequences. | |||
3. '''Empirical Superiority''': SEDD outperforms existing discrete and continuous diffusion models on language tasks, reducing perplexity by 25–75% and matching or surpassing GPT-2 in zero-shot perplexity. It also achieves significantly higher-quality unconditional generation (6–8× better generative perplexity than un-annealed GPT-2) and flexible conditional generation (e.g., infilling). | |||
4. '''Practical Benefits''': The model provides a favorable compute-quality trade-off (e.g., achieving GPT-2 quality with 32× fewer steps), eliminates the need for annealing techniques like temperature scaling, and enables controllable infilling without specialized training. | |||
=== Contributions=== | |||
* '''Theoretical''': Introduces score entropy, a loss function that generalizes score matching to discrete spaces while ensuring positivity and scalability, with rigorous proofs of consistency and tractability (e.g., the denoising score entropy variant).
* '''Methodological''': Develops SEDD, a discrete diffusion framework that integrates score entropy with token-level transitions via structured matrices, enabling efficient training and sampling. The Tweedie-inspired τ-leaping sampling strategy further enhances performance in practical scenarios. | |||
* '''Empirical''': Demonstrates state-of-the-art results on language modeling benchmarks (e.g., text8, One Billion Words) and generation tasks (both unconditional and conditional), outperforming autoregressive baselines in key metrics such as perplexity. The model’s flexibility in infilling and its favorable compute-quality trade-offs represent significant advancements in the field. | |||
=== Constructive Critiques=== | |||
* '''Complexity''': The reliance on matrix exponentials (e.g., for token transitions) may limit scalability to larger vocabularies or more complex structures (e.g., graphs). | |||
* '''Generalization''': While SEDD excels in language, its applicability to other discrete domains (e.g., molecules, code) remains untested. | |||
* '''Training Cost''': The paper notes SEDD’s parameter count is slightly higher than GPT-2, but the computational overhead of diffusion training versus autoregressive training is not thoroughly compared. | |||
=== Relationships to Other Works=== | |||
The SEDD model advances key areas of generative modeling by improving upon prior discrete diffusion approaches such as D3PM (Austin et al., 2021) and the continuous-time framework of Campbell et al. (2022). It replaces their mean prediction objectives with ratio estimation, addressing limitations in stability and continuous-time approximation. Compared to continuous diffusion models like Diffusion-LM (Li et al., 2022) and PLAID (Gulrajani & Hashimoto, 2023), SEDD achieves better performance in likelihood estimation and generation quality without requiring heuristic annealing techniques. The work also generalizes score matching methods, extending Hyvarinen's original score matching (2005) and concrete score matching (Meng et al., 2022) to discrete domains through its score entropy formulation. While not yet reaching the scale of modern autoregressive models, SEDD competes effectively with autoregressive baselines like GPT-2 in flexible generation tasks (e.g., infilling) and computational efficiency. Its success highlights the potential of combining SEDD with recent advances such as self-conditioning (Strudel et al., 2022) to further close the gap with autoregressive models. | |||
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 23 Presentation: Discrete Diffusion Modeling by Estimating the Ratios of the Data Distribution == | |||
=== Presented by: === | |||
Chenxin Lyu and Yixuan Zeng | |||
=== Paper Citation === | |||
Lou, A., Meng, C., & Ermon, S. (2024). Discrete Diffusion Modeling by Estimating the Ratios of the Data Distribution. arXiv. https://doi.org/10.48550/arXiv.2310.16834 | |||
=== Summaries === | |||
Diffusion models have achieved remarkable performance in continuous domains like image generation, but extending them to discrete domains—such as natural language—has proven challenging. Previous efforts to adapt diffusion to text suffer from instability, inefficiency, and inferior performance when compared to autoregressive models, which currently dominate NLP. | |||
This paper proposes Score Entropy Discrete Diffusion (SEDD), a novel generative modeling framework for discrete data. Rather than modeling tokens directly, SEDD estimates probability ratios between discrete states, bridging the gap between diffusion models and autoregressive transformers. | |||
SEDD introduces a new training loss, score entropy, which generalizes score matching to discrete spaces. This loss function ensures theoretical consistency, avoids the need to compute partition functions, and maintains scalability for high-dimensional data. The reverse diffusion process is parameterized using these ratios, enabling efficient sampling, improved generation quality, and competitive perplexity scores—sometimes even outperforming GPT-2. | |||
=== Key Contributions === | |||
* '''Score Entropy Loss''': A new loss function for discrete score matching. Unlike previous approaches, it avoids negative values, ensures scalability, and naturally connects to maximum likelihood estimation through an evidence lower bound (ELBO).
* '''Discrete Diffusion Framework''': SEDD models a reverse diffusion process using structured transition matrices and estimated data distribution ratios. This supports token-level transitions with practical benefits for text generation tasks.
* '''Efficient Sampling & High-Quality Output''': The model significantly reduces perplexity (by 25–75%) and achieves GPT-2–level generation quality while requiring 32× fewer sampling steps. It supports unconditional generation and conditional infilling, and eliminates the need for annealing strategies (e.g., temperature scaling).
* '''Theoretical Rigor & Practical Flexibility''': SEDD is grounded in a discrete formulation of diffusion ODEs and uses a Tweedie-inspired τ-leaping strategy for improved sample efficiency.
=== Constructive Critiques or Reviews === | |||
* '''Scalability Concerns''': The use of matrix exponentials in token transitions might limit performance when applied to large vocabularies or graph-structured data.
* '''Domain Generalization''': While SEDD shows strong results in language tasks, its application to other discrete domains like molecules or source code remains untested.
* '''Training Cost''': The model has slightly more parameters than GPT-2, but the full computational trade-off between SEDD and autoregressive training isn't thoroughly explored.
=== Related Works === | |||
* '''Discrete Diffusion Models''': Builds upon earlier works like D3PM (Austin et al., 2021) and continuous diffusion adaptations such as Diffusion-LM and PLAID.
* '''Score Matching Foundations''': Extends Hyvärinen's (2005) original score matching and Meng et al.'s (2022) concrete score matching to discrete spaces using the score entropy formulation.
* '''Comparisons with Autoregressive Models''': SEDD matches or outperforms GPT-2 in zero-shot and infilling tasks, offering better trade-offs between generation quality, efficiency, and controllability.
* '''Alternative Sampling Techniques''': The τ-leaping sampler resembles approaches in energy-based modeling and may be enhanced further with self-conditioning (Strudel et al., 2022).
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 23 Presentation: Discrete Diffusion Modeling by Estimating the Ratios of the Data Distribution == | |||
=== Presented by === | |||
Chenxin Lyu and Yixuan Zeng
=== Paper Citation === | |||
A. Lou, C. Meng, and S. Ermon, ‘Discrete Diffusion Modeling by Estimating the Ratios of the Data Distribution’, Jun. 06, 2024, arXiv: arXiv:2310.16834. doi: 10.48550/arXiv.2310.16834. | |||
https://arxiv.org/abs/2310.16834 | |||
=== Background === | |||
- Generative modeling aims to create models capable of learning and reproducing real-world data distributions. | |||
- Existing autoregressive models generate sequences token-by-token, making them computationally slow, and they rely heavily on heuristic measures and hyperparameter tuning for optimal performance.
- Diffusion models have emerged as an alternative, but existing approaches face challenges in discrete data modeling. | |||
- This paper introduces Score Entropy Discrete Diffusion (SEDD), which learns the ratio between different states in discrete data distributions rather than estimating explicit probability densities. It parameterizes the reverse diffusion process, making it computationally efficient and scalable for high-dimensional discrete data. | |||
=== Summaries of Key Points === | |||
- A diffusion process models the transformation of data from a structured state to pure noise over time and then reverses this process to generate new samples. The discrete diffusion process’s purpose is to create a probabilistic framework that can efficiently model discrete data. The time-reversed formulation of the diffusion process ensures that the generative process follows a learned reverse trajectory back to the original data distribution. | |||
- An ideal diffusion model is consistent, meaning the learned score function aligns with the true probability distribution. Score entropy loss is a loss function that is used to improve stability in discrete diffusion models. It ensures consistency by minimizing errors in the estimated probability ratios, leading to a more reliable generative process. | |||
- Implicit score entropy loss only depends on observed examples and learned scores, which is useful for practical optimization since it allows training without requiring knowledge of all true probabilities. It also encourages scalability since it works for high-dimensional discrete tasks, where exact probability distributions are infeasible to compute. | |||
- These discrete diffusion models address the issue of computational efficiency as well. Because computing all transition probabilities directly requires large matrix-matrix multiplications, which is impractical in memory-intensive settings, it uses two structured matrices to compute transition ratios effectively. | |||
- The model also enables unconditional generation by leveraging the learned diffusion process without requiring additional conditioning variables. | |||
=== Explanation of Time-Reversal Strategies === | |||
- The discrete Tweedie's theorem says that if <math>p_t</math> satisfies the diffusion ODE <math>\frac{dp_t}{dt} = Q p_t</math>, then the exact denoiser can be expressed as:
<math>p_{0|t}(x_0 \mid x_t) = \left(\exp(-tQ)\left[\frac{p_t(i)}{p_t(x_t)}\right]_{i=1}^{N}\right)_{x_0} \exp(tQ)(x_t, x_0)</math>
- This theorem shows that given full knowledge of probability ratios, the optimal denoiser can be expressed in closed form. | |||
- This guides practical denoiser construction because the model learns data distribution ratios rather than explicit probabilities. It also allows the design of an efficient denoising process that reconstructs original states from noise. | |||
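To illustrate the closed-form denoiser implied by the discrete Tweedie result, here is a toy sketch that assumes the true ratios are known (in practice they are replaced by the learned ratio model); the state space, <math>Q</math>, and ratio values are illustrative assumptions:
<pre>
import numpy as np
from scipy.linalg import expm

def tweedie_denoiser(Q, t, x_t, ratios):
    """Closed-form p_{0|t}(. | x_t) given ratios[i] = p_t(i) / p_t(x_t).

    Follows the discrete Tweedie formula: apply exp(-tQ) to the ratio vector,
    then weight component x_0 by exp(tQ)(x_t, x_0).
    """
    exp_neg = expm(-t * Q)
    exp_pos = expm(t * Q)
    scores = exp_neg @ ratios          # (exp(-tQ) [p_t(i)/p_t(x_t)]_i)
    probs = scores * exp_pos[x_t, :]   # multiply by exp(tQ)(x_t, x_0)
    return probs / probs.sum()         # renormalize for numerical safety

# Toy usage: uniform 4-state diffusion, assumed ratio vector (ratios[x_t] = 1).
N = 4
Q = np.ones((N, N)) / (N - 1)
np.fill_diagonal(Q, -1.0)
ratios = np.array([1.0, 0.5, 2.0, 1.0])
print(tweedie_denoiser(Q, t=0.3, x_t=0, ratios=ratios))
</pre>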
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 24 Presentation: Mitigating the Missing Fragmentation Problem in De Novo Peptide Sequencing With A Two-Stage Graph-Based Deep Learning Model == | |||
=== Presenters === | |||
Zi Hua Xu, Zehao Zhang | |||
=== Paper Citation === | |||
Mao, Z., Zhang, R., Xin, L. et al. Mitigating the missing-fragmentation problem in de novo peptide sequencing with a two-stage graph-based deep learning model. Nat Mach Intell 5, 1250–1260 (2023). https://doi.org/10.1038/s42256-023-00738-x | |||
https://www.nature.com/articles/s42256-023-00738-x#citeas | |||
=== Background === | |||
- Proteins are crucial for biological functions | |||
- Proteins are formed from peptides which are sequences of amino acids | |||
- Mass spectrometry is used to analyze peptide sequences | |||
- De Novo sequencing is used to piece together peptide sequences when the sequences are missing from existing established protein databases | |||
- Deep learning has become commonly implemented to solve the problem of de novo peptide sequencing
- When a peptide fails to fragment in the expected manner, it can make protein reconstruction difficult due to missing data
- One error in the sequence can propagate to errors throughout the entire predicted sequence
=== Paper Contributions=== | |||
- Graph Novo was developed to handle incomplete segments | |||
- '''GraphNovo-PathSearcher''': instead of directly predicting the sequence, performs a path search to predict the next node along the optimal path through the spectrum graph
- A graph neural network is used to find the best path through the graph generated from the mass spectrometry input
- '''GraphNovo-SeqFiller''': fills in the amino acid composition of the mass gaps left along the predicted path
- Since some fragments or amino acids may have been missed, SeqFiller uses a transformer to add in amino acids that PathSearcher could not resolve
- Input is mass spectrum from mass spectrometry | |||
- Graph construction is performed where nodes represent possible fragment (prefix) masses and edges represent possible amino acids connecting them (PathSearcher module)
- PathSearcher uses machine learning to find the optimal path on the generated graph | |||
- SeqFiller fills in missing amino acids that may have not been included in the PathSearcher module due to lacking data from the mass spectrometry inputs | |||
=== Peptide Sequencing in AlphaPeptDeep === | |||
- Peptide sequencing determines amino acid sequences of peptides, crucial for understanding proteins. | |||
- Mass spectrometry (MS) is used to fragment peptides and analyze the resulting spectra. | |||
==== Methods referenced in the presentation: ==== | |||
* Database Search & Spectral Library Search: AlphaPeptDeep improves prediction of MS spectra and retention time, boosting accuracy of both methods. | |||
* de novo Sequencing: Enhanced spectral prediction from AlphaPeptDeep supports building peptide sequences without prior knowledge. | |||
* AlphaPeptDeep predicts peptide properties (e.g., fragmentation patterns) to improve spectrum matching and sequence inference. | |||
=== Contributions === | |||
*GraphNovo outperforms prior models such as DeepNovo, PointNovo, and Casanovo, achieving:
** 9.3–11.5% higher peptide recall, | |||
** 6.1–8.9% higher amino acid recall, | |||
** 5.8–8.7% higher amino acid precision. | |||
* Substantially improves predictions in and beyond missing-fragmentation regions. | |||
* Improves amino acid recall by up to 24.5% after missing fragmentation when guiding DeepNovo with GraphNovo’s predicted path. | |||
* Maintains high accuracy across varying: | |||
** Sequence lengths, | |||
** Noise signal ratios, | |||
** Degrees of missing fragmentation. | |||
* '''Open-source availability''': | |||
** Code: [https://github.com/AmadeusloveIris/Graphnovo GitHub – GraphNovo] | |||
** Data: [https://doi.org/10.5281/zenodo.8000316 Zenodo Repository] | |||
==== Constructive Critiques ==== | |||
* Struggles with long missing-fragmentation regions: The model has difficulty accurately predicting peptide sequences when large segments of fragmentation data are absent. These missing regions create gaps in the spectrum, which may impair the model's ability to reconstruct the full peptide chain. | |||
* Requires predefined post-translational modifications: The system depends on a predefined list of possible post-translational modifications. This constraint limits the model’s ability to generalize to peptides with unexpected or novel PTMs, reducing its adaptability in complex biological samples. | |||
* Computationally expensive: Due to its two-stage graph-based deep learning architecture, the model demands significant computational resources. | |||
==== Reviews ==== | |||
I thought that this presentation clearly introduced the concepts of peptide sequences and how mass spectrometry is used to analyze and reconstruct them. It was also nice that you mentioned the different ways sequences are matched with observed spectra besides deep learning methods (e.g., matching observed spectra with known sequences and using pre-existing spectral databases). The problem of how mistakes in fragmentation result in missing data in the observed spectrum, ultimately making sequence reconstruction difficult, was also clearly explained. As someone who plans to do project 3 for the final project, I found this presentation particularly helpful in understanding the type of data used in de novo peptide sequence generation. One comment is that it was a little unclear how the mass spectrum input is turned into a graph with fragments as nodes, and what exactly the optimal "path" in the graph is (i.e., how the edges are defined), although that may also be because I am not too familiar with this area.
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 24 Presentation: Mitigating the Missing Fragmentation Problem in De Novo Peptide Sequencing With A Two-Stage Graph-Based Deep Learning Model == | |||
=== Presenters === | |||
Zi Hua Xu, Zehao Zhang | |||
=== Paper Citation === | |||
Mao, Z., Zhang, R., Xin, L. et al. Mitigating the missing-fragmentation problem in de novo peptide sequencing with a two-stage graph-based deep learning model. Nat Mach Intell 5, 1250–1260 (2023). https://doi.org/10.1038/s42256-023-00738-x | |||
https://www.nature.com/articles/s42256-023-00738-x#citeas | |||
=== Background === | |||
Peptide sequencing is a critical step in proteomics for determining the amino acid composition of proteins using tandem mass spectrometry (MS). Traditional approaches fall into three main categories: | |||
(1) Database Search: Compares observed tandem mass spectra with peptides from a known database. | |||
(2) Spectral Library Search: Uses curated spectral libraries containing experimentally acquired spectra to identify peptide–spectrum matches. | |||
(3) De Novo Sequencing: Infers the peptide sequence directly from the MS data without relying on existing sequence databases, making it essential for novel protein discovery and cases where databases are incomplete or impractical. | |||
=== Limitations in de novo peptide sequencing === | |||
Two major challenges persist in de novo peptide sequencing: | |||
(1) Missing Fragmentation: Some peptide bonds do not break during MS fragmentation, leaving “gaps” in the spectrum that make certain regions difficult to reconstruct. | |||
(2) Error Accumulation: A mistake in one poorly fragmented region often propagates, causing further downstream errors in the predicted peptide sequence. | |||
=== GraphNovo === | |||
GraphNovo is a two-stage de novo sequencing algorithm designed to mitigate these issues: | |||
(1) GraphNovo-PathSearcher: | |||
Constructs a directed acyclic graph (the “spectrum graph”) where each node represents a possible partial mass, and edges represent allowable amino acid mass differences. | |||
Predicts the “optimal path” from the start node to the end node, effectively capturing the correct arrangement of fragment ions and bypassing regions of missing fragmentation by labeling them as “mass tags.” | |||
(2) GraphNovo-SeqFiller: | |||
Fills in the “mass tags” from the optimal path with their specific amino acid composition. | |||
Uses a transformer-based decoder guided by the path constraints, reducing errors that could otherwise accumulate if one low-confidence region led to incorrect subsequent predictions. | |||
=== How GraphNovo Functions === | |||
(1) Graph Construction: Translates raw MS peaks into a spectrum graph. Each node corresponds to a potential prefix mass; edges form between nodes if the mass difference matches one or more amino acid masses (within tolerance). | |||
(2) PathSearcher: Trained on known spectra to find the correct route through this graph, placing constraints on fragment ions (including missing-fragmentation regions represented as nodes without direct evidence). | |||
(3) SeqFiller: Given the path—and hence each mass gap—SeqFiller “zooms in” on each mass tag to determine the exact amino acids. This two-stage strategy tackles missing fragmentation more directly than single-stage approaches. | |||
=== Data and Preprocessing === | |||
Training Data: Includes high-confidence peptide identifications for Homo sapiens and Mus musculus from public proteomics datasets (e.g., plasma, HeLa, brain tissues). Only peptides with reliable annotations (1% false-discovery rate) and precursor masses below a certain threshold are included. | |||
Test Data: Drawn from species not in the training set (e.g., Arabidopsis thaliana, C. elegans, E. coli), ensuring that evaluation measures generalizability. | |||
Preprocessing Steps: | |||
(1) Convert raw MS peaks into feature vectors (e.g., normalized m/z, relative intensity). | |||
(2) Generate “node spectrum” (b, y, a, y2+ ions, among others) while discarding infeasible peaks. | |||
(3) Build the graph by connecting nodes if their mass difference matches valid amino acid combinations. | |||
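The graph-building step can be sketched as follows. This is a simplified, hypothetical illustration (monoisotopic masses for a handful of residues, a fixed tolerance, and single-residue edges only), not the paper's full construction, which also handles modifications and multi-residue mass tags:
<pre>
import itertools

# Monoisotopic residue masses (Da) for a few amino acids; illustrative subset only.
AA_MASS = {"G": 57.02146, "A": 71.03711, "S": 87.03203, "P": 97.05276,
           "V": 99.06841, "L": 113.08406}
TOL = 0.02  # mass tolerance in Da (assumed value)

def build_spectrum_graph(node_masses):
    """Connect prefix-mass nodes whose mass difference matches a residue mass.

    node_masses: sorted list of candidate prefix masses (including 0 and the
    precursor mass). Returns a list of directed edges (i, j, residue).
    """
    edges = []
    for i, j in itertools.combinations(range(len(node_masses)), 2):
        diff = node_masses[j] - node_masses[i]
        for aa, mass in AA_MASS.items():
            if abs(diff - mass) <= TOL:
                edges.append((i, j, aa))
    return edges

# Toy usage: prefix masses consistent with the (hypothetical) peptide "GAS".
nodes = [0.0, 57.021, 128.058, 215.090]
print(build_spectrum_graph(nodes))
</pre>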
=== Model Architecture === | |||
Graph Construction: Creates a directed graph with edges corresponding to possible amino acid masses. | |||
Loss Functions: | |||
(1) GraphNovo-PathSearcher: Uses Kullback–Leibler divergence to guide node (path) predictions. | |||
(2) GraphNovo-SeqFiller: Uses cross-entropy to predict the exact amino acid sequence that fills each mass tag. | |||
Hyperparameter Tuning: | |||
Optimizer: AdamW (a variant of Adam) with a fixed learning rate in the reported experiments. | |||
Both stages employ a transformer-based architecture, incorporating a specialized graph encoder (relation attention) to capture node and edge features. | |||
=== Performance Comparison === | |||
Peptide Recall: GraphNovo shows a 9–12% improvement over the next-best approach (e.g., PointNovo, Casanovo) in correctly reconstructing entire peptide sequences. | |||
Amino Acid Recall and Precision: Yields a 5–9% improvement across different test species, indicating more accurate individual residue identifications. | |||
Robust to Missing Fragmentation and Noise: Maintains relatively high recall/precision for longer peptides, higher noise levels, and spectra with multiple missing-fragmentation sites, thereby mitigating error accumulation. | |||
=== Constructive Critiques === | |||
Long Missing-Fragmentation Regions: While GraphNovo substantially reduces error propagation, very long or continuous gaps remain challenging. | |||
Predefined Modifications: Must specify possible post-translational modifications (PTMs) in the graph construction step, which becomes computationally costly if many PTMs are considered at once. | |||
Computational Overhead: Two-stage approach and large-scale graph construction require significant memory and processing time. | |||
=== Future improvements === | |||
Enhanced Sequence Prediction: Integrate more MS features (e.g., retention time) to improve accuracy within large missing-fragmentation regions. | |||
Expanded Applicability: Adapting the two-stage approach for top-down or middle-down proteomics and more extensive sets of PTMs. | |||
Computational Efficiency: Explore faster graph-building algorithms, reduce the number of nodes through refined filtering, and potentially incorporate few-shot learning for user-specified PTMs. | |||
=== Comment === | |||
Overall, GraphNovo demonstrates that a carefully designed, graph-based deep learning model can significantly mitigate the missing-fragmentation problem in de novo peptide sequencing, outperforming traditional and newer transformer-based methods by providing a stable framework for both path and sequence prediction. | |||
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 47 Presentation: Jamba: A Hybrid Transformer - Mamba Language Model == | |||
=== Presenter === | |||
Chentao Jin | |||
=== Paper Citation === | |||
Lieber, O., Lenz, B., Bata, H., Cohen, G., Osin, J., Dalmedigos, I., Safahi, E., Meirom, S., Belinkov, Y., Shalev-Shwartz, S., Abend, O., Alon, R., Asida, T., Bergman, A., Glozman, R., Gokhman, M., Manevich, A., Ratner, N., Rozen, N., Shwartz, E., Zusman, M., Shoham, Y. (2024). Jamba: A Hybrid Transformer-Mamba Language Model. arXiv. https://arxiv.org/abs/2403.19887 | |||
https://doi.org/10.48550/arXiv.2403.19887 | |||
=== Background=== | |||
- Large language models (LLMs) have become essential for various natural language processing tasks. | |||
- Transformers have been the dominant architecture for LLMs due to their effectiveness in handling sequential data. | |||
- However, Transformers suffer from high memory and compute costs, limiting their efficiency in long-context processing. | |||
- Mamba, a recent state-space model, has emerged as an alternative to Transformers, offering improved efficiency in handling long sequences. | |||
- A hybrid approach combining Transformers and Mamba layers can leverage the strengths of both architectures. | |||
- Mixture-of-Experts (MoE) techniques can further enhance model capacity while managing active parameter usage. | |||
- Efficient model architectures are crucial for balancing performance, computational efficiency, and memory footprint in large-scale AI applications. | |||
=== Summaries of Key Points=== | |||
'''Bridging Transformer Expressiveness and Mamba Efficiency''' | |||
Large language models (LLMs) have made remarkable advances, but scaling them to handle long-context processing remains a significant challenge. Jamba introduces a novel hybrid architecture that combines the strengths of Transformers and Mamba, enhanced with a Mixture of Experts (MoE) module. This integration enables efficient memory usage, high throughput, and scalability for sequences up to 256,000 tokens on a single GPU. | |||
'''Key Architectural Innovations''' | |||
1. Transformer Layers – Self-attention mechanisms allow the model to capture complex token relationships, crucial for tasks involving deep contextual reasoning. However, they come with high memory and compute costs when processing long sequences. | |||
2. Mamba Layers – Derived from state-space models (SSMs), Mamba efficiently processes long sequences without storing extensive key-value caches. Instead, it maintains a hidden state to summarize prior information, significantly reducing memory overhead. | |||
3. Mixture of Experts (MoE) – Jamba integrates sparse expert selection, where each token activates only a small subset of experts. This technique increases capacity while controlling computational costs. Specifically, Jamba uses 16 experts, with only two active per token, optimizing efficiency. | |||
'''The Jamba Block: A Structured Hybrid Design''' | |||
The architecture of a Jamba block follows a structured sequence of Transformer layers, Mamba layers, and MoE layers: Transformer-to-Mamba Ratio (1:7). The model incorporates one Transformer layer for every seven Mamba layers, balancing expressiveness and efficiency. MoE Placement – Instead of applying MoE to every layer, Jamba replaces every second multi-layer perceptron (MLP) layer with an MoE module. This approach increases model capacity without significantly raising parameter count. | |||
By blending self-attention, state-space models, and sparsity techniques, Jamba pushes the boundaries of long-context processing in language models. Its ability to handle extremely long sequences while maintaining efficiency and scalability makes it a compelling innovation in the next generation of LLM architectures. | |||
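A pseudocode-level schematic of how one Jamba block might interleave the three layer types under the stated 1:7 attention-to-Mamba ratio and every-other-layer MoE placement is shown below; which position the attention layer occupies within the block is an assumption of this sketch, not taken from the paper:
<pre>
# Hypothetical sketch of one Jamba block's layer schedule (8 layers total):
# one attention layer for every seven Mamba layers, and an MoE module
# replacing the MLP in every second layer.

def jamba_block_schedule(n_layers=8, attn_every=8, moe_every=2):
    layers = []
    for i in range(n_layers):
        mixer = "attention" if i % attn_every == 0 else "mamba"
        mlp = "moe" if i % moe_every == 1 else "dense_mlp"
        layers.append((mixer, mlp))
    return layers

for idx, (mixer, mlp) in enumerate(jamba_block_schedule()):
    print(f"layer {idx}: {mixer:9s} + {mlp}")
</pre>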
=== Performance and Benefits === | |||
'''Jamba''' achieves state-of-the-art efficiency and performance across academic benchmarks, long-context tasks, and throughput. | |||
* Matches or outperforms larger models such as '''Mixtral-8x7B''' and '''LLaMA-2 70B''' on: | |||
** Reasoning: HellaSwag, ARC-Challenge, WinoGrande, PIQA, TruthfulQA | |||
** Comprehension: BoolQ, QuAC | |||
** Math and Code: GSM8K, HumanEval | |||
** Aggregated tasks: MMLU, BBH | |||
* Outperforms Mixtral on: | |||
** Needle-in-a-Haystack retrieval | |||
** Few-shot classification: Banking77, TREC-Fine | |||
** Long-context QA: NarrativeQA, CUAD, NaturalQuestions | ** Long-context QA: NarrativeQA, CUAD, NaturalQuestions | ||
** Up to 3× faster than Mixtral at <math>128K</math> context length. | ** Up to 3× faster than Mixtral at <math>128K</math> context length. | ||
* Efficient inference on a single <math>80\,\text{GB}</math> GPU (int8 quantization). | |||
* KV-cache memory usage is 8× smaller than Transformers (e.g., <math>4\,\text{GB}</math> vs. <math>32\,\text{GB}</math> at <math>256K</math> tokens). | |||
=== Constructive Critique === | |||
While Jamba achieves impressive performance and efficiency through its hybrid architecture, several limitations and open questions remain: | |||
* Lack of ablation analysis: The paper adopts a 1:7 ratio of Transformer to Mamba layers and places MoE modules every other layer, but provides little justification for these hyperparameters. A more thorough ablation would help clarify the contribution of each component to final performance. | |||
* Interpretability trade-off: Combining multiple architectural modules (SSM, attention, MoE) increases complexity. While effective, this may make it harder to interpret model behavior or debug errors, especially compared to simpler Transformer-only baselines. | |||
* ICL limitations: The authors mention that Mamba layers alone struggle with in-context learning (ICL). Although interleaving attention layers helps, this still suggests limitations in how well SSMs handle token-by-token reasoning or structure-sensitive tasks. | |||
* Lack of fine-tuning or alignment: The released model is a base pretrained model without instruction tuning or safety alignment. This limits its immediate use in downstream applications without additional supervised or RLHF-based training. | |||
Despite these challenges, Jamba represents a promising direction for scaling efficient, long-context language models and offers a practical blueprint for hybrid architectures in the LLM space. | |||
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 47 Presentation: Jamba: A Hybrid Transformer - Mamba Language Model == | |||
=== Presenter === | |||
Chentao Jin | |||
=== Paper Citation === | |||
Lieber, O., Lenz, B., Bata, H., Cohen, G., Osin, J., Dalmedigos, I., Safahi, E., Meirom, S., Belinkov, Y., Shalev-Shwartz, S., Abend, O., Alon, R., Asida, T., Bergman, A., Glozman, R., Gokhman, M., Manevich, A., Ratner, N., Rozen, N., Shwartz, E., Zusman, M., Shoham, Y. (2024). Jamba: A Hybrid Transformer-Mamba Language Model. arXiv. https://arxiv.org/abs/2403.19887 | |||
https://doi.org/10.48550/arXiv.2403.19887 | |||
=== Jamba Architecture === | |||
=== Main Features === | |||
Jamba is a hybrid large language model that interleaves: | |||
Transformer layers with self-attention (standard decoder blocks). | |||
Mamba layers (a state-space model, SSM) introduced by Gu & Dao (2023). | |||
Mixture-of-experts (MoE) modules integrated into some of the MLP layers. | |||
By combining these three components, Jamba can balance efficiency, long-context capabilities, and model capacity without incurring a prohibitive computational or memory cost. | |||
(1) Transformer Layers | |||
Jamba uses standard decoder-only Transformer blocks, but crucially, they appear in a reduced proportion (e.g., 1 attention layer for every 7 Mamba layers). | |||
The attention mechanism is still important for in-context learning and tasks that benefit from explicit token-to-token interactions. | |||
(2) Mamba Layers | |||
Mamba layers replace much of the attention with an SSM-based mechanism, scaling linearly with sequence length. | |||
They significantly reduce key–value cache size for long contexts because each Mamba layer does not require storing extensive attention activations. | |||
At large scale, the Mamba layers are stabilized by adding RMSNorm to the internal activations of the state-space modules, which prevents the occasional large activation spikes seen in prior SSM training.
The authors find no explicit positional encoding is required in the Mamba blocks—Mamba inherently captures positional information. | |||
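To make the linear-time, constant-memory behaviour described in this subsection concrete, the sketch below shows a toy (non-selective) state-space recurrence in plain Python. The real Mamba layer uses input-dependent, selective parameters and a hardware-aware parallel scan, so this is only an illustration of why no growing key–value cache is needed.
<pre>
import numpy as np

def ssm_scan(x, A, B, C):
    """Toy linear state-space recurrence: h_t = A h_{t-1} + B x_t, y_t = C h_t.
    Memory stays O(state_dim) no matter how long the sequence is."""
    h = np.zeros(A.shape[0])
    ys = []
    for x_t in x:                # single pass: linear in sequence length
        h = A @ h + B * x_t      # update the fixed-size hidden state
        ys.append(C @ h)         # read out the current output
    return np.array(ys)

y = ssm_scan(np.random.randn(1000), A=0.9 * np.eye(4),
             B=np.ones(4), C=np.ones(4) / 4)
print(y.shape)                   # (1000,)
</pre>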
(3) Mixture-of-Experts (MoE) | |||
Jamba integrates MoE in some MLP layers to increase total capacity without increasing the active parameters used per token. | |||
MoE involves having multiple “expert” sub-MLPs, with only the top K experts selected for each token. | |||
This leads to a sparse model design: total parameters can be large (e.g., 50B+), but only ~12B parameters are “active” at any forward pass. | |||
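The top-K routing can be sketched as follows for a single token. The expert networks, gate weights, and sizes here are made-up placeholders; only the idea that just the top-K experts are evaluated per token comes from the paper.
<pre>
import numpy as np

def moe_forward(x, experts, gate_w, top_k=2):
    """Minimal top-k MoE routing for one token vector x: only the selected
    experts run, so active parameters are far fewer than total parameters."""
    logits = gate_w @ x                       # one router score per expert
    top = np.argsort(logits)[-top_k:]         # indices of the chosen experts
    weights = np.exp(logits[top])
    weights /= weights.sum()                  # softmax over the chosen experts only
    return sum(w * experts[i](x) for w, i in zip(weights, top))

d, n_experts = 8, 16
rng = np.random.default_rng(0)
experts = [lambda v, W=rng.standard_normal((d, d)): W @ v for _ in range(n_experts)]
gate_w = rng.standard_normal((n_experts, d))
print(moe_forward(rng.standard_normal(d), experts, gate_w).shape)   # (8,)
</pre>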
=== Performance and Benefits of Jamba === | |||
(1) High throughput | |||
Compared to a pure-Transformer of similar size, Jamba achieves up to 3× higher inference throughput at very long context lengths. This is because Mamba’s linear-time scan avoids the quadratic cost and large key–value cache of attention. | |||
(2) Memory efficiency | |||
Jamba’s key–value cache can be 8× smaller than a similarly sized Transformer, which makes it possible to handle up to 256K tokens of context (or even more) on a single 80GB GPU. | |||
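A rough back-of-the-envelope calculation shows why the cache shrinks: only the attention layers store keys and values, so replacing seven of every eight layers with Mamba cuts the cache by roughly 8×. The layer and head sizes below are illustrative placeholders, not Jamba's published configuration.
<pre>
def kv_cache_gb(n_attn_layers, seq_len, n_kv_heads=8, head_dim=128, bytes_per=2):
    # two tensors (K and V) per attention layer, per token, per KV head
    return 2 * n_attn_layers * seq_len * n_kv_heads * head_dim * bytes_per / 1e9

seq_len = 256_000
print(kv_cache_gb(n_attn_layers=32, seq_len=seq_len))  # ~33.6 GB for an all-attention stack
print(kv_cache_gb(n_attn_layers=4,  seq_len=seq_len))  # ~4.2 GB with 1-in-8 attention layers
</pre>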
(3) Competitive quality | |||
On standard LM benchmarks (ARC, HellaSwag, WinoGrande, etc.), Jamba performs on par with or better than similarly sized Transformer or MoE-Transformer models. It also demonstrates strong capabilities in long-context tasks (e.g., “needle in a haystack” retrieval). | |||
=== Key Design and Insights === | |||
(1) Hybrid Architecture | |||
The mixed ratio of attention layers to Mamba layers (often 1:7) is crucial. Even a small fraction of attention layers confers strong in-context learning (format adherence, induction-like patterns), while the Mamba layers bring speed and memory savings. | |||
Pure Mamba, though fast, sometimes struggles with emergent in-context learning behaviors (e.g., properly following few-shot prompts). The hybrid design preserves these Transformer-like capabilities. | |||
(2) MoE Effectiveness | |||
Using MoE on top of the hybrid model further improves perplexity and downstream performance, allowing the total parameter count to go up to 50B+ while keeping active parameter usage around ~12B. | |||
Balancing the number of experts, top-K selection, and how frequently MoE is used (e.g., every other MLP layer) is key for controlling compute costs and memory. | |||
(3) Training Stability and Design Choices | |||
RMSNorm: Large-scale Mamba layers exhibit occasional large activation spikes. RMSNorm on internal activations stabilizes the training, preventing loss spikes. | |||
No explicit positional encoding needed: Unlike typical Transformers (which use rotary, ALiBi, or other embeddings), the authors found that Mamba captures positional cues inherently. Adding RoPE gave no notable improvement. | |||
=== Conclusion === | |||
(1) Uniqueness of Jamba | |||
High efficiency and strong design | |||
Jamba’s combination of attention, Mamba, and MoE layers yields excellent throughput and long-sequence modeling. | |||
(2) Handling long context better | |||
Jamba’s memory footprint for KV caching is drastically smaller. It can handle contexts of up to 256K tokens on a single 80GB GPU—significantly exceeding typical Transformer-based LLMs of similar size. | |||
(3) Open-source release | |||
The model is released under an Apache 2.0 license, encouraging research on this hybrid approach. Pretrained checkpoints and ablation runs will also be provided. | |||
=== Future Directions === | |||
(1) Optimize MoE Further | |||
Investigating more sophisticated MoE routing strategies, expert balance, or hierarchical gating to push quality and efficiency further. | |||
(2) Hybrid Scaling in Even Larger Models | |||
Extending beyond ~7B–12B active parameters to tens of billions or more, exploring how the attention–Mamba ratio and MoE design scale at even larger training runs. | |||
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 47 Presentation: Jamba: A Hybrid Transformer - Mamba Language Model == | |||
=== Presenter === | |||
Chentao Jin | |||
=== Paper Citation === | |||
Lieber, O., Lenz, B., Bata, H., Cohen, G., Osin, J., Dalmedigos, I., Safahi, E., Meirom, S., Belinkov, Y., Shalev-Shwartz, S., Abend, O., Alon, R., Asida, T., Bergman, A., Glozman, R., Gokhman, M., Manevich, A., Ratner, N., Rozen, N., Shwartz, E., Zusman, M., Shoham, Y. (2024). Jamba: A Hybrid Transformer-Mamba Language Model. arXiv. https://arxiv.org/abs/2403.19887 | |||
https://doi.org/10.48550/arXiv.2403.19887 | |||
=== Clear explanations to aid understanding === | |||
Jamba uses a unique combination of transformer layers, memory layers, and a Mixture of Experts (MoE) layer to improve its efficiency, scalability, and ability to process long sequences of text. This hybrid design optimizes both memory management and computational power while minimizing resource use. Here’s an overview of the components and their roles in the architecture: | |||
==== Jamba's Architecture ==== | |||
Jamba’s architecture is built on a structured sequence of layers: each transformer layer is followed by seven Mamba-based memory layers, and a Mixture of Experts (MoE) module replaces the MLP in every second layer. This design creates an efficient, hybrid system that maximizes both memory handling and computational power while minimizing resource usage.
Each of these layer types plays a distinct role:
1. '''Transformer Layers''': The transformer layers in Jamba are responsible for capturing the relationships between different tokens in the sequence, no matter how far apart they are. This is done using self-attention, which helps the model understand complex relationships and context within the text. However, traditional transformers can struggle with very long sequences because the self-attention mechanism becomes very memory and computation-heavy as the sequence length grows. | |||
2. '''Memory Layers''': To tackle this, Jamba introduces memory layers based on a different model called the state space model (SSM). These memory layers don’t rely on self-attention. Instead, they maintain a hidden state that keeps track of important information from earlier in the sequence. This makes memory layers far more efficient when it comes to handling long sequences, as they don’t need to store as much data in memory, unlike the transformer layers.
3. '''Mixture of Experts (MoE)''': The MoE component is where Jamba gets its flexibility. Instead of using all model parameters for each token, MoE selectively activates a small subset of "experts" for each token. An "expert" is a specialized set of parameters that focuses on solving a specific part of the problem. The model dynamically selects the experts that are most relevant to each token's context, allowing it to handle different parts of the problem more efficiently. Since only a few experts are activated per token, the model can scale its capacity to handle more complex tasks or longer sequences without significantly increasing computational costs.
Additionally, Jamba stabilizes its training process with RMSNorm, which normalizes activations and ensures that training remains stable, even when the model is scaled up to very large sizes. | |||
==== Performance and Benefits of Jamba ==== | |||
Jamba's hybrid approach provides several advantages: | |||
* Efficient Handling of Long Contexts: Jamba is able to process long sequences of text effectively, overcoming the limitations of traditional transformer models. | |||
* Balance Between Performance and Efficiency: Jamba achieves strong performance while reducing memory usage and computational costs thanks to its combination of transformer, memory, and MoE layers. | |||
* Scalability: By using fewer active parameters than other models, Jamba is scalable and can handle tasks that require understanding large amounts of text without compromising efficiency. | |||
==== Limitations of Jamba ==== | |||
This hybrid approach has limitations as well: training can become unstable at scale and relies on RMSNorm to keep activations in check; handling very long contexts still demands substantial GPU memory (on the order of 80 GB); and the Mixture of Experts (MoE) routing still needs further optimization to improve performance.
==== Further Directions ==== | |||
* Optimize MoE further for even better efficiency | |||
* Investigate how hybrid architectures scale in even larger contexts and models | |||
</div> | |||
<div style="border: 2px solid #0073e6; background-color: #f0f8ff; padding: 10px; margin: 10px 0; border-radius: 5px;"> | |||
== Group 47 Presentation: Jamba: A Hybrid Transformer - Mamba Language Model == | |||
=== A New Approach to Long-Context Language Modeling === | |||
Jamba is a new language model that seeks to overcome some of the classic limitations of Transformers—namely, the high memory overhead of the key-value cache and the computational inefficiency when processing long sequences. The paper introduces a hybrid architecture that blends traditional Transformer layers with a newer type of state-space layer called Mamba. This fusion is further enhanced by incorporating Mixture-of-Experts (MoE) modules, creating a model that is both memory efficient and highly performant on long-context tasks. | |||
=== Key Architectural Innovations === | |||
1. Hybrid Transformer-Mamba Design: | |||
Transformers vs. Mamba: Traditional Transformers excel at in-context learning and have become the de facto architecture for language models. However, their self-attention mechanism results in quadratic memory usage relative to the context length, making them less efficient for very long texts. | |||
Mamba Layers: These layers, based on state-space models, offer a more efficient alternative by reducing memory requirements and improving throughput, especially when the input sequence becomes very long. | |||
Interleaving Strategy: Jamba interleaves a small number of attention layers with a larger number of Mamba layers. In one specific configuration, a ratio of 1 attention layer to 7 Mamba layers was found to be both effective and compute-efficient. This interleaving allows the model to benefit from the strengths of both components. | |||
2. Mixture-of-Experts (MoE): | |||
Expanding Capacity Without Extra Cost: MoE modules are used to boost the total parameter count (and thus the model’s capacity) without a proportional increase in compute costs. This is achieved by activating only a small subset of experts (e.g., top-2 out of 16) for any given token. | |||
Flexibility: The MoE integration is flexible and can be adjusted (by varying the frequency of MoE layers or the number of experts) to trade off between memory usage, throughput, and performance. | |||
3. Resource Efficiency: | |||
KV Cache Savings: One of the standout features of Jamba is its dramatic reduction in key-value cache memory requirements. For instance, while some comparable models might require tens of gigabytes to store the cache for long contexts, Jamba can process up to 256K tokens with only a few gigabytes. | |||
Single-GPU Feasibility: Despite having a total available parameter count of 52B (with only 12B active at any time), the model is engineered to fit on a single 80GB GPU, which is an impressive engineering feat. | |||
=== Constructive Critiques and Discussion === | |||
Ablation Insights: | |||
The paper includes a thorough ablation study showing that neither pure Transformer nor pure Mamba models perform as robustly across all tasks as the hybrid does. However, it raises questions about the precise roles of each component—especially how much reliance there is on the Transformer layers for in-context learning. Future work could explore whether even fewer attention layers might suffice or if alternative mechanisms might further enhance the balance. | |||
Scaling and Adaptability: | |||
While the current design is optimized to fit on a single high-end GPU, it remains to be seen how the architecture scales when further pushed in size or applied to more diverse downstream tasks. Additionally, the robustness of the MoE routing (i.e., ensuring the right experts are chosen consistently) could benefit from further investigation and refinement. | |||
Positional Encoding: | |||
An interesting observation is that Jamba performs similarly with and without explicit positional embeddings. The paper suggests that the Mamba layers might be implicitly capturing positional information. This finding challenges the conventional wisdom in Transformer-based architectures and could inspire further research into whether explicit positional mechanisms are always necessary. | |||
=== Enhancing Understanding: Why Does This Matter? ===
At its core, Jamba represents a promising step toward models that can handle extremely long texts without the usual memory and computational burdens. This is crucial for real-world applications—such as document summarization, legal analysis, or even processing large codebases—where context length can be a major bottleneck. By cleverly combining different architectural paradigms, Jamba offers a new tool in the quest for more scalable and efficient language models. | |||
Moreover, the integration of MoE allows researchers and practitioners to scale model capacity in a cost-effective manner. The hybrid design not only improves throughput but also opens the door to further explorations in architectural combinations that might harness the best of both worlds: the rich representational power of Transformers and the efficiency of state-space models. | |||
</div>
Notes on Presentations
Group 1 Presentation: Universal Physics-Informed Neural Networks: Symbolic Differential Operator Discovery with Sparse Data
Paper Citation
Podina, L., Eastman, B., & Kohandel, M. (2023). Universal Physics-Informed Neural Networks: Symbolic Differential Operator Discovery with Sparse Data. In Proceedings of the 40th International Conference on Machine Learning (Vol. 202). PMLR, Honolulu, Hawaii, USA.
Background
Differential equations
Examples of differential equations in physics include Newton's second law (which is an ordinary differential equation), the Navier-Stokes equations (which are partial differential equations), etc.
Existing methods of solving differential equations:
- Analytical methods, such as integration or separation of variables.
- Numerical methods, such as finite difference, finite volume, or finite elements.
- Data-driven approaches: these involve Universal Differential Equations (UDEs) and Physics-Informed Neural Networks (PINNs), which are the focus of this paper.
Introduction to PINNs
With (many) machine learning approaches, the goal is to approximate the solution to a DE using a feed-forward neural network, optimized with MSE loss. The key difference that makes it physics-informed is an extra term in the loss, which penalizes the model for deviating from the governing DE.
Introduction to UDEs
Here, the differential equation is expressed as a sum of two terms: the known physics-based model and an unknown neural network.
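Written out in generic notation (not taken verbatim from the paper), a UDE keeps the known mechanistic terms and lets a neural network with parameters [math]\displaystyle{ \theta }[/math] absorb whatever is missing:
[math]\displaystyle{ \frac{du}{dt} = f_{\text{known}}(u, t) + NN_\theta(u, t) }[/math]
Training then fits [math]\displaystyle{ \theta }[/math] so that the combined right-hand side reproduces the observed trajectories.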
Paper Contributions
Universal Physics-Informed Neural Networks (UPINNs)
PINNs and UDEs are combined, addressing the limitations of the original methods, while sharing their benefits.
The model integrates three network components:
- Surrogate Solution Network U: links to the measurement loss
- Unknown Differential Operator Network F: combined with U inside the PINN loss
- Boundary Condition Network B: links to the boundary loss
The loss function contains three terms:
- MSE
- Boundary loss, if you were to provide boundary conditions with the problem.
- PINN loss: ensures the model respects the governing differential equation.
The training process involves minimizing a composite loss function:
- Data Mismatch Term: Ensures that the network's predictions align with observed (potentially noisy) data points.
- Residual Term: Penalizes deviations from the differential equation's structure, incorporating both known and learned components.
Summaries highlighting key points of the paper
UPINN combines PINN and UDE, bridging the limitations of both approaches: it is more robust to noise than a UDE and can perform decently in low-data settings. However, UPINN still requires notable computational resources and is sensitive to the choice of hyperparameters. Moreover, the learned neural components have low interpretability on their own, which is why symbolic regression is applied afterwards.
Experimental Validation
1. Lotka-Volterra Model
They first experimented with the UPINN on the Lotka-Volterra system of differential equations, which are used to model predator-prey dynamics:
[math]\displaystyle{ \frac{dx}{dt} = \alpha x - \beta xy }[/math]
[math]\displaystyle{ \frac{dy}{dt} = -\delta y + \gamma xy }[/math]
The UDE and PINN were individually tested on two scenarios: sparse data (where there are very few input data points) and noisy data. Alone, each model did not do very well, especially when the data was very sparse or very noisy. When the UPINN was used, the solution was quite good, even with high sparsity or noise.
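As a rough illustration of how such sparse, noisy training data can be produced (the paper's exact parameter values, grids, and noise levels are not reproduced here), one could simulate the system and subsample it along these lines:
<pre>
import numpy as np
from scipy.integrate import solve_ivp

alpha, beta, delta, gamma = 1.3, 0.9, 1.8, 0.8      # illustrative parameters only

def lotka_volterra(t, z):
    x, y = z
    return [alpha * x - beta * x * y, -delta * y + gamma * x * y]

t_obs = np.linspace(0, 10, 25)                       # sparse observation grid
sol = solve_ivp(lotka_volterra, (0, 10), [1.0, 1.0], t_eval=t_obs)
noisy_obs = sol.y + 0.05 * np.random.randn(*sol.y.shape)   # additive observation noise
</pre>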
2. Viscous Burgers’ Equation
Their next experiment used the viscous Burgers' equation, a model from fluid dynamics.
[math]\displaystyle{ \frac{\partial u}{\partial t} = -u \frac{\partial u}{\partial x} + \nu \frac{\partial^2 u}{\partial x^2} }[/math]
3. Cell Apoptosis Model
Summaries of key points
This paper introduces Universal Physics-Informed Neural Networks (UPINNs) for discovering unknown terms in differential equations (ODEs/PDEs) from sparse and possibly noisy data. It combines the strengths of standard Physics-Informed Neural Networks (PINNs)—which incorporate prior knowledge of the governing equations—while still allowing parts of the underlying model to remain unknown and be learned from the data. Unlike previous methods such as Universal Differential Equations (UDEs), which can falter in noisy and small-data regimes, UPINNs maintain good accuracy by:
1. Leveraging collocation points in the loss function to incorporate the differential equation constraints ("physics").
2. Adding a neural network component to represent the unknown terms of the operator.
3. Applying symbolic regression (e.g., AI Feynman) to convert the neural approximation of the hidden terms into interpretable, closed-form expressions.
Extensive experiments on the Lotka–Volterra system, a viscous Burgers’ PDE, and a cell apoptosis ODE show that UPINNs outperform UDEs in handling higher noise and fewer data points, while still recovering the hidden differential-operator terms accurately.
Furthermore, symbolic regression improves interpretability by converting neural outputs into explicit equations. This interpretability, combined with robustness to sparse and noisy data, makes UPINNs especially promising for scientific discovery. Potential applications include systems biology, fluid dynamics, and environmental modeling. Future research directions could address scalability to higher-dimensional PDEs and uncertainty quantification.
Related work
Nonlocal Physics-Informed Neural Networks (nPINNs): nPINNs introduce a universal nonlocal Laplace operator that encompasses classical and fractional Laplacians. This framework is utilized for parameter identification in nonlocal models, demonstrating consistency and accuracy in capturing operator behaviours.
Group 1 Presentation: Universal Physics-Informed Neural Networks: Symbolic Differential Operator Discovery with Sparse Data
Paper Citation
Podina, L., Eastman, B., & Kohandel, M. (2023). Universal Physics-Informed Neural Networks: Symbolic Differential Operator Discovery with Sparse Data. In Proceedings of the 40th International Conference on Machine Learning (Vol. 202). PMLR, Honolulu, Hawaii, USA.
Background
In many scientific problems, we model systems using differential equations. But in practice, we often don’t know the full form of these equations, and we rarely have clean, sufficient data to work with. This makes it hard to apply standard data-driven approaches or even physics-informed models that assume the structure is already known. The goal of this paper is to develop a method that can discover the unknown parts of a differential equation directly from data, even when the data is sparse and noisy, and return an interpretable symbolic expression.
The authors introduce a method called Universal Physics-Informed Neural Networks (UPINNs). It combines the strengths of two existing approaches:
1. PINNs, which integrate known physical laws into neural network training by including differential equations in the loss function.
2. UDEs, which use neural networks to model unknown terms in a differential equation.
UPINNs aim to do both: they use physical constraints to guide the training process (like PINNs), but they also allow for the discovery of unknown components of the equation (like UDEs). Importantly, once a neural network learns those unknown components, the method uses symbolic regression (via the AI Feynman tool) to extract a readable formula—something scientists can actually interpret.
Main Idea
The model uses three neural networks: one approximates the solution to the differential equation, one learns the unknown part of the equation (i.e., the missing dynamics), and the other one (optional) models unknown boundary conditions if needed.
Training is guided by a loss function with three parts:
1. Fit to observed data,
2. Match the known physical dynamics,
3. Satisfy boundary conditions.
To help compensate for the limited data, they add “collocation points”, additional locations in the domain where the model must follow the known physics. These points don’t require real data and can be sampled freely, so they’re a cheap way to strengthen training.
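A minimal PyTorch-style sketch of this three-part loss for a generic scalar ODE is given below; the network sizes, the known term, and the loss weights are placeholder assumptions, not the paper's configuration.
<pre>
import torch

u_net = torch.nn.Sequential(torch.nn.Linear(1, 32), torch.nn.Tanh(), torch.nn.Linear(32, 1))  # surrogate solution
f_net = torch.nn.Sequential(torch.nn.Linear(1, 32), torch.nn.Tanh(), torch.nn.Linear(32, 1))  # unknown operator term

def f_known(u):                       # illustrative known part of the dynamics
    return -0.5 * u

def upinn_loss(t_data, u_data, t_coll, t0, u0, w=(1.0, 1.0, 1.0)):
    # (1) data mismatch on the (sparse, noisy) observations
    data_loss = torch.mean((u_net(t_data) - u_data) ** 2)
    # (2) physics residual at collocation points: du/dt - f_known(u) - F(u) should vanish
    t_c = t_coll.clone().requires_grad_(True)
    u = u_net(t_c)
    du_dt = torch.autograd.grad(u, t_c, torch.ones_like(u), create_graph=True)[0]
    phys_loss = torch.mean((du_dt - f_known(u) - f_net(u)) ** 2)
    # (3) boundary / initial-condition term
    bc_loss = torch.mean((u_net(t0) - u0) ** 2)
    return w[0] * data_loss + w[1] * phys_loss + w[2] * bc_loss
</pre>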
Experimental & Result
The paper tests UPINNs on three systems:
(a) Lotka-Volterra Predator-Prey Model (ODE) The model successfully recovers the hidden interaction terms, even with very sparse data.
It outperforms UDEs especially when noise is present or data is limited.
(b) Viscous Burgers’ Equation (PDE) Even with data from only two time points, UPINNs can reconstruct the solution and recover the nonlinear transport term (−u ∂u/∂x) with reasonable accuracy.
(c) Apoptosis (Cell Death) Model (ODE) The method learns complex nonlinear terms involving protein concentrations.
It performs well despite flat dynamics late in the simulation, which normally makes learning harder.
In all three cases, symbolic regression is applied to the learned neural network and is often able to recover the correct functional form of the hidden terms. When comparing against UDEs, UPINNs are more robust to noise and return more accurate symbolic expressions.
UPINNs are useful when you:
1. Only have limited, noisy measurements of a system.
2. Know part of the physical model but not all of it.
3. Want interpretable results, not just predictions.
In short, it’s a flexible way to discover unknown dynamics from data, while still respecting the physical structure you already know. This is particularly helpful in scientific domains where experimentation is expensive or data is inherently sparse.
Group 1 Presentation: Universal Physics-Informed Neural Networks: Symbolic Differential Operator Discovery with Sparse Data
Paper Citation
Podina, L., Eastman, B., & Kohandel, M. (2023). Universal Physics-Informed Neural Networks: Symbolic Differential Operator Discovery with Sparse Data. In Proceedings of the 40th International Conference on Machine Learning (Vol. 202). PMLR, Honolulu, Hawaii, USA.
Background
Many scientific problems rely on differential equations to model systems. However, in practice, these equations are often partially unknown, and available data can be sparse and noisy. Standard data-driven methods or physics-informed models struggle when the full equation structure is not known. This paper presents a method to discover unknown components of differential equations directly from data while ensuring interpretability.
The proposed approach, Universal Physics-Informed Neural Networks (UPINNs), integrates two key ideas:
- Physics-Informed Neural Networks (PINNs): These incorporate known physical laws into neural network training through differential equation constraints in the loss function.
- Universal Differential Equations (UDEs): These use neural networks to approximate unknown terms in differential equations.
UPINNs combine both methods: they leverage physical constraints for training like PINNs while also learning unknown components, similar to UDEs. Once the unknown dynamics are learned, symbolic regression (using AI Feynman) extracts interpretable expressions.
Main Idea
UPINNs utilize three neural networks:
- One approximates the solution to the differential equation.
- One learns the unknown components of the equation.
- An optional third network models unknown boundary conditions.
The training process is guided by a loss function consisting of:
- Fitting observed data.
- Matching known physical laws.
- Ensuring boundary conditions are satisfied.
To overcome sparse data, the method introduces collocation points—additional locations in the domain where the model must obey physical constraints. These points do not require real data and strengthen training at a low cost.
Experimental Results
The paper evaluates UPINNs on three different systems:
(a) Lotka-Volterra Predator-Prey Model (ODE)
- The model successfully recovers hidden interaction terms, even with sparse data.
- UPINNs outperform UDEs, particularly in noisy environments.
(b) Viscous Burgers’ Equation (PDE)
- Even with data from only two time points, UPINNs reconstruct the nonlinear transport term [math]\displaystyle{ -u \frac{\partial u}{\partial x} }[/math] with high accuracy.
(c) Apoptosis (Cell Death) Model (ODE)
- The method learns complex nonlinear terms involving protein concentrations.
- It performs well even when the system exhibits flat dynamics, which typically hinders learning.
In all cases, symbolic regression applied to the learned networks often recovers the correct functional form of hidden terms. Compared to UDEs, UPINNs demonstrate superior noise robustness and return more accurate symbolic expressions.
When to Use UPINNs
UPINNs are beneficial when:
- Only sparse, noisy measurements of a system are available.
- Partial physical models are known, but critical terms are missing.
- Interpretability is essential (i.e., extracting explicit equations rather than just predictions).
Overall, UPINNs provide a powerful way to uncover unknown dynamics from limited data while respecting known physical laws, making them particularly useful in scientific fields where data is expensive or difficult to obtain.
Group 1 Presentation: Universal Physics-Informed Neural Networks: Symbolic Differential Operator Discovery with Sparse Data
Summary
UPINNs are an exciting advancement in scientific ML, offering a principled and flexible way to discover and solve differential equations under partial knowledge. The ability to extract symbolic laws from noisy, limited data is a major step forward. That said, the practical scalability, fragility of symbolic regression, and lack of real-world demonstrations leave important room for future exploration.
Strength
1. Hybridization of Two Powerful Paradigms UPINNs cleverly bridge Universal Differential Equations (UDEs) and Physics-Informed Neural Networks (PINNs). This hybrid approach retains the interpretability of physical laws while embracing the flexibility of neural approximators. It's especially valuable in scientific domains where the governing equations are partially known but include hidden dynamics.
2. Robustness to Data Scarcity and Noise One of the most impressive aspects of the method is its strong performance even when the training data is sparse or noisy—scenarios where UDEs typically struggle. This makes UPINNs practical for real-world scientific data, which is often costly, limited, or imprecise.
3. Symbolic Interpretability By using symbolic regression tools like AI Feynman to extract human-interpretable equations from the learned networks, the authors make a strong case for UPINNs as tools for discovery, not just prediction. This is a crucial step in scientific machine learning, where interpretability often matters as much as performance.
4. Clear Experimental Design The authors evaluate UPINNs on diverse systems: ODEs (Lotka-Volterra), PDEs (Burgers’ Equation), and biochemical networks (apoptosis). This breadth demonstrates generalizability across different domains and equation types.
Core Idea
Instead of treating the unknown dynamics as either a black-box function (like in UDEs) or assuming full knowledge (like in PINNs), UPINNs embed a trainable neural network within a known DE structure and train the model to:
1. Fit observed data (like PINNs),
2. Respect the DE's structure (like UDEs), and
3. Learn unknown terms in the equation (via the embedded NN).
This means UPINNs can learn both the solution to a DE and its hidden components, even under imperfect data conditions.
Key Features of UPINNs:
1. Learning Hidden Dynamics: By embedding a neural network within the DE framework, UPINNs can identify and represent unknown terms, facilitating a deeper understanding of underlying physical processes.
2. Robustness to Data Limitations: UPINNs maintain high performance levels even with minimal and noisy data, addressing a common hurdle in scientific machine learning.
3. Symbolic Regression Integration: The ability to convert neural network representations of hidden terms into symbolic equations bridges the gap between data-driven models and interpretable physical laws.
Applications Demonstrated:
1. Lotka-Volterra System: UPINNs effectively learned the hidden interaction terms governing predator-prey dynamics, showcasing resilience to both data sparsity and noise.
2. Viscous Burgers' Equation: The method accurately reconstructed solutions to this partial differential equation, even when provided with limited noisy data.
3. Cell Apoptosis Model: UPINNs identified nonlinear interactions in a biological system, highlighting their applicability in complex biochemical networks.
Group 1 : Universal Physics-Informed Neural Networks: Symbolic Differential Operator Discovery with Sparse Data
Presenters
Ibrahim Abdulhafiz, Arya Amiri
Paper Citation
Podina, L., Eastman, B., & Kohandel, M. (2023). Universal Physics-Informed Neural Networks: Symbolic Differential Operator Discovery with Sparse Data. In Proceedings of the 40th International Conference on Machine Learning (Vol. 202). PMLR, Honolulu, Hawaii, USA.
Introduction
In a recent paper, researchers have introduced a novel machine learning method called Universal Physics-Informed Neural Networks (UPINNs) that can discover unknown components of differential equations, even when the available data is noisy and sparse. This new approach could significantly impact various fields, including physics, biology, and engineering, where understanding the underlying dynamics of systems is crucial.
Background
Differential equations are mathematical equations that describe how systems change over time. Physics-Informed Neural Networks (PINNs) and Universal Differential Equations (UDEs) are two methods that use neural networks to learn these equations from data. However, PINNs require that the structure of the differential equation be known in advance, which is not always the case. UDEs, on the other hand, can approximate unknown components of the equation using neural networks, but they are sensitive to noise and require a lot of data.
Main Idea
UPINNs combine the strengths of both PINNs and UDEs to overcome their limitations. Like UDEs, UPINNs can learn unknown components of differential equations. However, instead of using a hard constraint like UDEs, they use a soft constraint similar to PINNs, which makes them more robust to noise. This approach allows UPINNs to effectively learn from data even when it is sparse and noisy.
Experiments
The authors tested their UPINN method on three different types of problems:
• Lotka-Volterra equations: This is a system of ordinary differential equations (ODEs) that model predator-prey interactions.
• Viscous Burgers' equation: This is a partial differential equation (PDE) that models fluid flow.
• Cell apoptosis model: This is another system of ODEs, modelling a biological process of programmed cell death.
For each of these, they generated synthetic data, sometimes adding noise to simulate real-world measurements. They then used UPINNs to try to learn the unknown parts of the differential equations from this data. In the case of the Lotka-Volterra equations, they also used symbolic regression (with the AI Feynman algorithm) to try to identify the exact mathematical form of the learned terms.
Results
The key results of the paper are:
• UPINNs can accurately learn unknown components of differential equations, even with sparse and noisy data.
• In the Lotka-Volterra experiments, UPINNs outperformed the UDE method, especially when the data was noisy.
• The symbolic regression step was able to successfully identify the underlying equations from the UPINN results.
• UPINNs also performed well on the Viscous Burgers' equation and the cell apoptosis model, demonstrating its applicability to both ODEs and PDEs, and to problems from different domains.
In essence, the authors showed that UPINNs are a powerful tool for discovering hidden physics from data, offering advantages over existing methods in scenarios with limited or imperfect data.
Group 1 Presentation: Universal Physics-Informed Neural Networks: Symbolic Differential Operator Discovery with Sparse Data
Summary
Podina et al. (2023) explore the relationship between cognitive biases and decision-making under uncertainty, particularly in high-stakes environments. The study highlights how individuals systematically deviate from rational choices due to ingrained heuristics, leading to predictable errors in judgment. By analyzing experimental data, the authors demonstrate that even well-informed individuals struggle to override intuitive but flawed reasoning patterns.
The Challenge of Self-Correction
The study suggests that merely recognizing cognitive biases is not enough to eliminate them. Participants continued to display systematic errors in decision-making, despite being informed about common biases like anchoring and the framing effect. This aligns with previous research showing that biases are deeply ingrained and often operate automatically.
Potential Interventions
One of the most practical takeaways from the paper is the need for structured interventions. The authors hint at possible solutions but do not explore them in depth. Based on their findings, several strategies could be considered:
Decision Support Systems – Implementing structured frameworks, such as checklists or algorithms, can help counteract biases in high-stakes environments like finance and medicine.
Nudging and Reframing – Small adjustments in how choices are presented, such as default options or reference points, can guide people toward more rational decisions.
Training and Feedback Loops – While awareness alone is insufficient, repeated exposure to debiasing exercises and real-time feedback could help individuals develop better decision-making habits.
Future Directions
Podina et al. provide a strong foundation for understanding the persistence of cognitive biases, but future research should explore scalable solutions. Investigating which interventions are most effective in specific domains—such as policy-making, business strategy, or consumer behavior—could bridge the gap between theory and practice.
Group 1 Presentation: Universal Physics-Informed Neural Networks: Symbolic Differential Operator Discovery with Sparse Data
Presenters
Ibrahim Abdulhafiz, Arya Amiri
Paper Citation
Podina, L., Eastman, B., & Kohandel, M. (2023). Universal Physics-Informed Neural Networks: Symbolic Differential Operator Discovery with Sparse Data. In Proceedings of the 40th International Conference on Machine Learning (Vol. 202). PMLR, Honolulu, Hawaii, USA.
Summaries of Key Points
Introduction of a New Method
The paper introduces Universal Physics-Informed Neural Networks (UPINNs), a new framework for discovering unknown components of differential equations using limited and noisy observational data. This extends the traditional PINN approach by enabling the model to learn hidden terms in the governing equations, which are typically hardcoded in standard PINNs.
Combining Strengths of PINNs and UDEs
UPINNs integrate the benefits of Physics-Informed Neural Networks (PINNs) and Universal Differential Equations (UDEs). While PINNs enforce physical laws but require a fully specified equation, UDEs can learn unknown components but need large, clean datasets. UPINNs strike a balance—learning hidden terms via neural networks while incorporating known physics to guide training.
Learning Hidden Terms Symbolically
A central feature of UPINNs is their ability to convert black-box neural approximations into symbolic equations. Neural networks first model the unknown parts, and then symbolic regression tools such as AI Feynman extract interpretable mathematical expressions, helping researchers understand and reuse the discovered dynamics.
Generalized Loss Function for Training
The model uses a composite loss function that balances fitting the data, adhering to known parts of the differential equation, and satisfying boundary conditions. This flexible loss structure allows training even when parts of the system are unknown or hidden in the data.
Flexibility Across ODEs and PDEs
UPINNs are applicable to both ordinary differential equations (ODEs) and partial differential equations (PDEs). They can also handle unknowns in boundary conditions, making the method suitable for a wide range of scientific and engineering applications.
Successful Results Across Diverse Test Cases
UPINNs were validated on three systems:
• Lotka-Volterra predator-prey model – UPINNs recovered interaction terms accurately, outperforming UDEs, especially under sparse and noisy conditions.
• Viscous Burgers’ equation – The method reconstructed the solution and discovered a nonlinear convection term using just two noisy time snapshots.
• Cell apoptosis model – Despite biological complexity and low observability, UPINNs identified nonlinear interaction terms with high precision.
High Robustness to Data Limitations
UPINNs perform well with sparse and noisy data by using a large number of synthetic collocation points—artificial domain points where known physics must hold. This stabilizes training without requiring extra experimental data.
Interpretability Through Symbolic Regression
After approximating unknown terms with neural networks, symbolic regression recovers interpretable expressions. These symbolic outputs are more accurate and complete than those from UDEs, especially under noisy or incomplete data scenarios.
Computational Considerations
While UPINNs require more computational effort due to the use of multiple neural networks and collocation points, this cost is offset by reduced data needs and increased interpretability—benefits not offered by purely black-box approaches.
Limitations and Future Directions
UPINNs assume some prior knowledge about the inputs influencing the unknown terms, which may not always be available. They also share PINNs' limitations with stiff differential equations. Addressing these issues is a suggested direction for future research.
Constructive critiques
The paper presents a novel approach—Universal Physics-Informed Neural Networks (UPINNs)—for discovering unknown terms in differential equations using sparse and noisy data. This is a timely and relevant contribution, particularly for fields like biology and physics where data collection is often expensive or limited. By integrating the interpretability of symbolic regression with the flexibility of neural differential modeling, the authors bridge a crucial gap between traditional PINNs and Universal Differential Equations (UDEs).
One of the method’s key strengths is its ability to model hidden components of both the differential operator and the boundary conditions through separate neural networks. This design allows UPINNs to adapt to partially known systems and still recover interpretable models, especially when combined with tools like AI Feynman. The experiments across ODEs and PDEs are well-chosen and demonstrate UPINNs' superior robustness compared to UDEs, particularly under conditions of data sparsity or noise.
However, the method assumes that the user can pre-select relevant inputs for the hidden components, such as derivatives or functional terms. This reliance on prior knowledge could limit usability in less structured or more exploratory applications. Including a brief ablation or sensitivity analysis would help clarify this issue. The comparison to baselines could also be extended to include classical methods like SINDy for a broader perspective.
Overall, the paper is well-structured and clearly written, though a schematic of the full model pipeline would improve accessibility. The method is computationally intensive but effective, trading data demands for processing power—a reasonable compromise in many scientific settings. With minor improvements in clarity and broader benchmarking, this work offers a powerful and interpretable tool for discovering governing equations in real-world dynamical systems.
Group 2 Presentation: EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty
Presented by:
Kareena Bhalla and Chelsea Huffman
Paper Citation
Li, Y., Wei, F., Zhang, C., Zhang, H. (2024). EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty. arXiv. doi.org/10.48550/arXiv.2401.15077
Background
In this paper, we are looking at Large Language Models (LLMs). Autoregressive decoding in LLMs refers to generating text one token at a time, with the model basing each prediction on the tokens that came before it. This sequential process is slow and computationally costly.
Speculative Sampling
Speculative sampling is a technique meant to reduce the computational cost and runtime of autoregressive decoding. The process consists of two main parts:
- Draft stage: A small, fast model suggests some tokens.
- Verification stage: In parallel, the large main LLM verifies these tokens and selects the best ones.
Eagle also has both a draft stage and a verification stage.
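For orientation, the general draft-then-verify loop shared by speculative sampling methods (including EAGLE's verification stage) looks roughly like the sketch below. The draft_model and target_model callables are hypothetical placeholders that return next-token probability distributions; this is not EAGLE's actual code.
<pre>
import numpy as np

def speculative_step(prefix, draft_model, target_model, k=4, rng=np.random.default_rng()):
    """Draft k tokens cheaply, then verify them with one pass of the large model."""
    drafted, q_dists, ctx = [], [], list(prefix)
    for _ in range(k):                               # draft stage: small model proposes tokens
        q = draft_model(ctx)
        tok = rng.choice(len(q), p=q)
        drafted.append(tok); q_dists.append(q); ctx.append(tok)
    p_dists = target_model(list(prefix), drafted)    # verification: one distribution per drafted position
    accepted = []
    for tok, q, p in zip(drafted, q_dists, p_dists):
        if rng.random() < min(1.0, p[tok] / q[tok]): # standard acceptance rule
            accepted.append(tok)
        else:                                        # resample from the residual and stop
            residual = np.maximum(p - q, 0.0)
            accepted.append(rng.choice(len(p), p=residual / residual.sum()))
            break
    return accepted
</pre>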
How do we choose a draft model that behaves like the large LLM but runs faster? One approach is to use a reduced version of the main LLM, but this doesn't work when no small version is available; moreover, a draft model with fewer parameters often comes with high overhead and reduced draft accuracy.
Technical Contributions
Extrapolation Algorithm for Greater Language Model Efficiency (EAGLE)
Before predicting the next feature, the draft model is given the token sampled one step ahead in the sequence, which removes the ambiguity about which token was actually chosen.
One advantage of the EAGLE method is that only a single decoder layer needs to be trained rather than an entire draft model. This makes the training process extremely efficient in terms of runtime/computational cost.
Feature-Level Autoregression
Unlike traditional speculative sampling methods that operate at the token level, EAGLE performs autoregression at the feature level, specifically utilizing the second-to-top-layer features of the LLM. This approach simplifies the drafting process and enhances efficiency.
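Schematically, the draft step predicts the next second-to-top-layer feature from the current feature plus the embedding of the token that was just sampled, then reuses the frozen LM head to turn that feature into a token distribution. The module names and sizes below are made up for illustration and are not the released EAGLE implementation.
<pre>
import torch

hidden = 4096                                                  # illustrative hidden size
autoregression_head = torch.nn.Linear(2 * hidden, hidden)      # the small trainable draft module

def draft_next(feature, next_token_emb, lm_head):
    """feature: current second-to-top-layer feature of the target LLM;
    next_token_emb: embedding of the token sampled one step ahead."""
    pred_feature = autoregression_head(torch.cat([feature, next_token_emb], dim=-1))
    logits = lm_head(pred_feature)                             # frozen LM head of the target model
    return pred_feature, torch.softmax(logits, dim=-1)
</pre>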
Experimental Results
- EAGLE is much faster than ordinary autoregressive decoding.
- EAGLE can be applied to various LLMs without reducing the model's generation quality. Furthermore, it is also compatible with other generation acceleration methods; for example, the paper goes through a quick case study of combining EAGLE with an optimization technique called gpt-fast.
- EAGLE can generate approximately 3.2–4.5 tokens per pass. Regular, vanilla decoding only generates one token per forward pass, so this is a notable speed-up in token generation.
- EAGLE provides the greatest computational speed-up in code generation tasks. This is likely because code contains fixed templates (e.g., repetitive structures, strict and specific rules, etc.), thereby making it easier to generate plausible and accurate drafts.
Summaries of key points
EAGLE (Extrapolation Algorithm for Greater Language-model Efficiency) is a framework designed to accelerate the text generation of large language models (LLMs) without sacrificing the original distribution of generated tokens. Traditionally, LLMs rely on autoregressive decoding, generating text one token at a time and incurring heavy computational costs. Speculative sampling, introduced in earlier works, seeks to reduce these costs by splitting text generation into a faster “drafting” step and a single-step “verification” step using the main LLM, thereby validating multiple drafted tokens in parallel. However, existing methods often struggle to balance high draft accuracy with low overhead, limiting their speed improvements.
EAGLE addresses these issues by predicting hidden-layer features rather than tokens. Specifically, it focuses on the second-to-top-layer feature vector of the original LLM. Because these features exhibit more consistent structure than raw language tokens, they can be modeled with fewer errors and simpler modules. Additionally, EAGLE includes a small Autoregression Head that factors in one-step-ahead token information—this is critical for mitigating the inherent uncertainty in feature prediction. When generating text, the main LLM determines which token is sampled at a given step, but from the feature-only perspective, it is impossible to know the token that will appear. By feeding the token as well as the past feature states into a lightweight single-layer decoder, EAGLE’s draft model can produce a highly accurate prediction of the next feature. After obtaining the predicted features, EAGLE applies the LLM’s original output layer (the LM head) to sample candidate tokens in parallel.
Once a tree of drafted tokens is produced, the method performs a single verification pass in the main LLM. In one forward call, EAGLE obtains the LLM’s probabilities for all tokens in this tree and accepts or rejects each token based on a mathematically proven acceptance criterion. Crucially, this ensures the distribution of tokens matches what the LLM would have generated if it performed naive, step-by-step autoregressive decoding.
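The acceptance criterion is the standard speculative-sampling rule, stated here in generic notation rather than the paper's exact symbols: a drafted token [math]\displaystyle{ x }[/math] with draft probability [math]\displaystyle{ q(x) }[/math] and target probability [math]\displaystyle{ p(x) }[/math] is accepted with probability
[math]\displaystyle{ \min\!\left(1, \frac{p(x)}{q(x)}\right), }[/math]
and on rejection a replacement token is drawn from the normalized residual [math]\displaystyle{ \max(p(x) - q(x), 0) }[/math]. This is what guarantees that the accepted tokens follow the target model's distribution exactly.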
In evaluations across code (HumanEval), mathematical reasoning (GSM8K), instruction following (Alpaca), and multi-turn dialogue (MT-bench) datasets, EAGLE yields 2.7–3.5× speedups for models like LLaMA2-Chat 70B and 2.8–3.3× for smaller Vicuna variants on single-GPU setups. Even for more modest setups, or MoE-based models like Mixtral 8×7B Instruct, EAGLE still demonstrates strong acceleration. The approach does not require finetuning the main LLM itself; only a small draft network is trained on a fixed dataset (e.g., ShareGPT). Because the overhead of this training is low—typically 1–2 days even for 70B-parameter models—EAGLE proves highly practical.
Overall, it offers a simple but powerful way to generate tokens in parallel while retaining exact quality and distribution guarantees, making it a compelling choice for real-world systems seeking reduced latency in LLM services.
Related Works
There is now a new version of EAGLE, EAGLE-2, which improves on this version by introducing a dynamically adjustable draft tree: the draft tree is adjusted based on context and position, building on speculative sampling. In testing, EAGLE-2 can be 20%–40% faster than EAGLE-1. There is also EAGLE-3, described in "Scaling up Inference Acceleration of Large Language Models via Training-Time Test" (https://arxiv.org/html/2503.01840v1). EAGLE-3 advances the acceleration of LLM inference by shifting from feature prediction to direct token prediction and employing multi-layer feature fusion through a technique called training-time test. These enhancements enable the draft model to fully benefit from scaled-up training data, resulting in notable speedup ratios.
Group 2 Presentation: EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty
Presented by:
Kareena Bhalla and Chelsea Huffman
Paper Citation
Li, Y., Wei, F., Zhang, C., Zhang, H. (2024). EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty. arXiv. doi.org/10.48550/arXiv.2401.15077
Summaries of Key Points
Introduction of a New Method
The paper proposes EAGLE (Extrapolation Algorithm for Greater Language-model Efficiency), a speculative sampling method designed to significantly accelerate inference in large language models (LLMs). Unlike prior methods that predict future tokens, EAGLE performs autoregression on the second-to-top-layer feature space of the LLM. This level of abstraction is more predictable and regular, simplifying the drafting process and improving speed.
Addressing Uncertainty in Feature Prediction
A central innovation of EAGLE is the mitigation of uncertainty in predicting features by incorporating token sequences shifted forward by one time step. This adjustment helps the draft model resolve ambiguities that arise due to the randomness of token sampling, leading to more stable and accurate feature predictions.
Efficient Architecture Without Modifying the Base LLM
EAGLE achieves acceleration by adding a lightweight autoregression head and reusing the embedding and LM head components from the target LLM. This ensures that the original model remains untouched and the output distribution stays consistent. The method employs both regression and classification objectives to train the draft model using standard datasets like ShareGPT, without generating new data from the target LLM.
Tree-Structured Draft Generation
To further enhance speed, EAGLE introduces tree attention to build a tree-structured token draft, allowing it to generate multiple tokens per pass. This increases average acceptance length without needing extra forward passes, resulting in higher throughput.
Strong Performance Across Tasks and Models
EAGLE demonstrates superior speedups—ranging from 2.1x to 3.8x—on tasks such as dialogue, code generation, and math reasoning using Vicuna, LLaMA2-Chat, and Mixtral models. It consistently outperforms other speculative sampling methods like Lookahead and Medusa, while maintaining output accuracy.
Generalizability and Practical Deployment
The method is compatible with a wide range of LLMs, does not require fine-tuning of the base model, and can be integrated with other optimization techniques like quantization and compilation. Training is low-cost and can be completed in one to two days. EAGLE proves effective even with a fixed training dataset and demonstrates robustness to noise in feature representations.
Robustness and Efficiency in Deployment
EAGLE performs well even when trained on fixed datasets rather than samples from the target LLM. Its robustness to feature noise and strong results under memory constraints make it well-suited for production environments. Additionally, the method scales efficiently across batch sizes, with throughput nearly doubling when paired with tools like gpt-fast.
Conclusion
EAGLE is a general, efficient, and robust solution for speculative sampling. By rethinking how features are predicted and leveraging minimal architectural changes, it delivers high-speed inference without compromising model integrity. It represents a practical path forward for optimizing LLM deployment in latency-sensitive applications.
Constructive Critique and Review
The paper presents a thoughtful and technically sophisticated approach to speculative sampling for large language models. EAGLE introduces a notable shift in how draft tokens are generated by operating at the feature level rather than the token level, thereby reducing the unpredictability of autoregressive outputs and significantly improving decoding speed. The architecture is lightweight, efficient, and cleverly designed to preserve compatibility with existing LLMs without requiring retraining of the base model. These strengths make EAGLE a compelling option for deployment in real-world inference systems.
However, while the idea of performing autoregression in the second-to-top-layer feature space is well motivated, the paper does not fully explore its limitations. For example, although EAGLE achieves strong performance on conversational and coding tasks, it remains unclear how it would fare on more structured or domain-specific generation tasks, such as formal theorem proving or medical text generation, where feature representations might exhibit greater variability. Additional benchmarks from such domains would help assess the generalizability of the method.
The assumption that a fixed dataset like ShareGPT is sufficient for training the draft model raises questions about adaptability. While the results are promising, the training data may introduce bias, and the method’s robustness under significant domain shifts is not evaluated. Furthermore, although the tree-structured decoding strategy provides efficiency gains, its implementation complexity and potential hardware bottlenecks during real-time inference are not discussed in detail.
Lastly, while the paper claims EAGLE is broadly compatible with quantization and compilation tools, these claims would benefit from empirical validation. Including direct comparisons on hardware resource consumption, memory usage, and inference latency under constrained conditions would provide a more complete picture of practical deployment trade-offs.
Overall, EAGLE is an innovative and valuable contribution to accelerating LLM inference, though further evaluation across diverse conditions and more transparency in deployment challenges would enhance its impact.
Group 2 Presentation: EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty
Presented by:
Kareena Bhalla and Chelsea Huffman
Paper Citation
Li, Y., Wei, F., Zhang, C., Zhang, H. (2024). EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty. arXiv. doi.org/10.48550/arXiv.2401.15077
Summaries of key points
Speculative sampling speeds up language model inference by having a fast draft model guess multiple tokens, which are then verified in parallel by the slower, accurate target model. The difficulty is that drafting is inherently uncertain: the draft model cannot know in advance which token the target model will actually sample at each step, so its predictions of the hidden features rest on possibly incorrect guesses. The paper proposes EAGLE, which tackles this uncertainty at the feature level rather than the token level. Instead of drafting tokens directly, EAGLE predicts the target model's second-to-top-layer features and feeds in the token sampled one step ahead, so that feature prediction becomes well determined. This leads to higher acceptance rates and larger speedups, especially on reasoning and long-context tasks, and it is more robust than previous speculative decoding approaches.
Problem Motivation: Speculative decoding (used to speed up LLM inference by parallelizing token verification) is great for efficiency, but drafting at the token level is noisy and hard to do accurately with a small model, which limits the achievable speedup.
Key Idea: EAGLE (Extrapolation Algorithm for Greater Language-model Efficiency) performs autoregression on the second-to-top-layer features of the target LLM instead of on tokens, and resolves the sampling-induced uncertainty in those features by also conditioning on the token sequence advanced by one time step.
How It Works: Instead of relying on a separate small language model, EAGLE trains a lightweight autoregression head that predicts the next feature vector from the current features and the embedding of the already-sampled next token; the target LLM's frozen LM head then turns the predicted feature into a token distribution for drafting.
Results: EAGLE shows larger speedups than vanilla speculative decoding and other methods such as Lookahead and Medusa, while provably preserving the target model's output distribution, for example on dialogue, question answering, math reasoning, and coding benchmarks.
Explanations to aid understanding
Speculative decoding in simpler terms: Imagine trying to write the next word in a sentence, but instead of just writing one word and waiting, you guess a bunch in parallel and then double-check them. This saves time, but makes it harder for the model to know what it should trust. EAGLE essentially adds a smart layer that acts like a “confidence gauge” for the guesses, using a learned energy function to decide how much to trust each speculative token and its underlying features.
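To make the “guess-and-check” idea concrete, here is a minimal sketch (not from the paper) of the standard speculative-sampling verification rule that underlies speculative decoding approaches: each drafted token is accepted with probability min(1, p/q), where p and q are the target and draft probabilities, and the first rejected position is resampled from the residual distribution. All names and shapes are illustrative assumptions.

<syntaxhighlight lang="python">
import numpy as np

def verify_draft(draft_tokens, q_probs, p_probs, rng=np.random.default_rng()):
    """Standard speculative-sampling verification (illustrative sketch).

    draft_tokens: list of token ids proposed by the draft model
    q_probs[i]:   draft model's distribution at step i (1-D array over vocab)
    p_probs[i]:   target model's distribution at step i; p_probs is assumed to
                  have one more entry than draft_tokens (the distribution after
                  the last drafted token, used for the bonus token)
    Returns the accepted prefix plus one corrected/bonus token.
    """
    accepted = []
    for i, tok in enumerate(draft_tokens):
        p, q = p_probs[i][tok], q_probs[i][tok]
        if rng.random() < min(1.0, p / max(q, 1e-12)):
            accepted.append(tok)          # keep the drafted token
        else:
            # Reject: resample from the residual distribution max(p - q, 0), renormalized.
            residual = np.maximum(p_probs[i] - q_probs[i], 0.0)
            residual /= residual.sum()
            accepted.append(int(rng.choice(len(residual), p=residual)))
            return accepted               # stop at the first rejection
    # All drafts accepted: sample one bonus token from the target's next distribution.
    accepted.append(int(rng.choice(len(p_probs[-1]), p=p_probs[-1])))
    return accepted
</syntaxhighlight>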
Group 2 Presentation: EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty
Presented by:
Kareena Bhalla and Chelsea Huffman
Paper Citation
Li, Y., Wei, F., Zhang, C., Zhang, H. (2024). EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty. arXiv. doi.org/10.48550/arXiv.2401.15077
Background
Large Language Models (LLMs) like LLaMA and Vicuna are powerful but notoriously slow during inference, especially because they generate one token at a time using autoregressive decoding. This sequential nature makes real-time applications difficult and expensive. Speculative sampling has emerged as a solution: it uses a smaller model (a “draft model”) to propose multiple tokens ahead of time, which are then verified in parallel by the original, larger model. This can lead to big speedups—but only if the drafts are accurate. The problem is, for many models (especially smaller ones like 7B), finding a good draft model is hard or inefficient, and prediction at the token level is noisy.
The paper introduces EAGLE – a speculative sampling method that takes a different approach. Instead of generating tokens directly, it works at the feature level (i.e., the hidden state just before the final output layer). It also addresses a key challenge: uncertainty in the feature sequence, caused by the randomness in token sampling. To fix this, EAGLE feeds the token sampled at the next time step (i.e., a “shifted” token sequence) into the draft model, giving it a clearer signal of what to predict next.
This method is designed to:
Be fast — achieving 2x to 3.7x speedups over vanilla decoding.
Be accurate — preserving the original LLM's output distribution.
Be plug-and-play — requiring no fine-tuning of the LLM and using a lightweight add-on model.
Main Idea
EAGLE consists of two main parts:
a. Drafting Phase: Instead of predicting the next token, EAGLE predicts the next feature vector (from the LLM’s penultimate layer).
Then, the actual token is generated using the original LLM’s output layer.
The key idea: use both the current features and the next token to reduce ambiguity in what feature to predict.
b. Verification Phase: A standard speculative sampling step in which the full model (LLM) runs a single forward pass to verify the draft.
If accepted, the tokens are kept. If rejected, the process restarts from the failed point.
EAGLE supports tree-structured drafts, where multiple possible sequences are explored in parallel, boosting acceptance rates and reducing the number of passes.
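Below is a schematic sketch of what a feature-level drafting loop of this kind could look like: the draft module consumes the current feature and the token already sampled from it, predicts the next feature, and obtains the next token from the frozen LM head of the target model. The module names (<code>FeatureDrafter</code>, <code>fuse</code>, <code>autoregression_head</code>) and the GRU-cell stand-in are illustrative assumptions, not the paper's actual architecture.

<syntaxhighlight lang="python">
import torch
import torch.nn as nn

class FeatureDrafter(nn.Module):
    """Schematic EAGLE-style draft module (names and sizes are illustrative)."""
    def __init__(self, hidden, vocab):
        super().__init__()
        self.embed = nn.Embedding(vocab, hidden)                # embeds the shifted token
        self.fuse = nn.Linear(2 * hidden, hidden)               # combines feature + token embedding
        self.autoregression_head = nn.GRUCell(hidden, hidden)   # small stand-in for the draft network

    def draft(self, f_t, next_token, lm_head, steps=4):
        """Greedy drafting of `steps` tokens.

        f_t:        (batch, hidden) current second-to-top-layer feature
        next_token: (batch,) LongTensor, token already sampled from f_t
        lm_head:    frozen output layer of the target LLM (features -> vocab logits)
        """
        feats, tokens = [], []
        h = torch.zeros_like(f_t)
        for _ in range(steps):
            x = self.fuse(torch.cat([f_t, self.embed(next_token)], dim=-1))
            h = self.autoregression_head(x, h)        # predicted next feature
            next_token = lm_head(h).argmax(dim=-1)    # token from the target LLM's own head
            feats.append(h)
            tokens.append(next_token)
            f_t = h                                   # feed the predicted feature back in
        return feats, tokens
</syntaxhighlight>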
=== Short Summary of the Paper ===
The paper introduces EAGLE, a novel speculative sampling method to accelerate Large Language Model inference. EAGLE performs autoregression at the more structured *feature level* instead of token level and addresses sampling-induced uncertainty by incorporating shifted tokens. It achieves significant speedups (2.7x–3.5x latency improvement) while maintaining original output distribution accuracy, outperforming methods like Medusa and Lookahead. It is computationally efficient and broadly applicable across various tasks and model sizes.
Experiments & Results
EAGLE is tested on Vicuna and LLaMA2-Chat models (7B, 13B, 33B, and 70B), plus Mixtral 8x7B, across tasks like: Dialogue (MT-bench), Code generation (HumanEval), Math problems (GSM8K), Instruction following (Alpaca)
Key numbers:
For LLaMA2-Chat 70B: speedup of 2.7x to 3.5x
For Vicuna 13B: up to 3.76x on code generation
Compared to Lookahead and Medusa, EAGLE is consistently faster by 1.5x–2.1x
EAGLE also works well with gpt-fast (a quantization and compilation tool), achieving up to 160 tokens/sec on a single RTX 3090 — a strong result for a 7B model.
Training is efficient: even for 70B models, the draft module (just under 1B parameters) can be trained in 1–2 days on 4×A100 GPUs.
It is a very useful approach because:
1. No need to fine-tune the full LLM – only the draft model is trained.
2. Preserves output distribution – unlike some other fast decoding methods, EAGLE guarantees the same output distribution as the original model.
3. Compatible with other speedup tools – works in combination with quantization or compilation.
Limitations and future improvements
Feature Prediction Constraints:
EAGLE operates by performing autoregression at the feature level, specifically targeting the second-to-top-layer features. This approach introduces inherent uncertainty in feature prediction, which can limit the model's performance gains. The accuracy of these feature predictions is crucial, as any deviation can impact the overall efficiency and reliability of the speculative sampling process.
Dependency on Advanced Token Sequences:
To mitigate feature prediction uncertainty, EAGLE incorporates an advanced token sequence by one time step. While this strategy effectively resolves some uncertainties, it adds complexity to the model's architecture and may introduce additional computational overhead during the drafting phase.
Scalability Concerns:
Although EAGLE achieves notable speedups (e.g., a latency speedup ratio of 2.7x-3.5x for LLaMA2-Chat 70B), its performance gains may vary across different model sizes and architectures. The framework's efficiency is influenced by factors such as the acceptance rate of drafted tokens and the computational cost associated with the drafting process.
Future Improvements:
- Make Feature Prediction More Accurate
Improving how EAGLE guesses the internal features could reduce errors and make the whole process more reliable and efficient.
- Smarter Drafting Methods
They could try new ways of guessing future tokens that are simpler, faster, or more accurate — without needing the lookahead trick.
- Make It Work for All Kinds of Models
Future work could focus on making EAGLE more flexible so it works well across different model sizes and architectures, not just the large ones.
Group 2 Presentation: EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty
Presented by:
Kareena Bhalla and Chelsea Huffman
Paper Citation
Li, Y., Wei, F., Zhang, C., Zhang, H. (2024). EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty. arXiv. doi.org/10.48550/arXiv.2401.15077
Summaries
EAGLE (Extrapolation Algorithm for Greater Language-model Efficiency) is an algorithm that enhances speculative sampling while improving the accuracy of the output compared with existing works. To reduce cost, EAGLE predicts features (from the second-to-top layer) instead of tokens, since autoregression performs better at the feature level, and it handles the uncertainty in the sampling process by using the token sequence from one time step ahead together with the predicted feature as the input for the next step. EAGLE can easily be applied to any autoregressive LLM and, since it does not change the original target LLM, it preserves the accuracy of the output.
Key Contributions
The EAGLE method makes improvements to the process based on the following 2 observations:
1. Autoregression is simpler at the feature level than the token level.
2. The uncertainties from the sampling process negatively affect the performance.
In the drafting phase, EAGLE handles the uncertainties by using the predicted feature sequence together with the token sequence advanced by one time step as input to predict the next feature and sample the next token. In this phase, there are 3 modules: the embedding layer, which converts tokens and features to the desired shapes and structures; the LM Head, which samples the next token; and the Autoregression Head, which predicts the next feature.
To train the draft model, the combined loss function [math]\displaystyle{ L=L_{reg}+w_{cls}L_{cls} }[/math] is minimized. [math]\displaystyle{ L_{reg} }[/math] is the Smooth L1 loss, [math]\displaystyle{ L_{cls} }[/math] is the classification loss, and [math]\displaystyle{ w_{cls} }[/math] is set to 0.1. As EAGLE is insensitive to the training data, a fixed dataset can be used with some noise added. The calculations are as follows:
[math]\displaystyle{ L_{reg}=\text{Smooth L1}\big(f_{i+1}, \text{Draft Model}(T_{2:i+1},F_{1:i})\big) }[/math]
[math]\displaystyle{ p_{i+2}=\text{Softmax}\big(\text{LM Head}(f_{i+1})\big) }[/math]
[math]\displaystyle{ \hat p_{i+2}=\text{Softmax}\big(\text{LM Head}(\hat f_{i+1})\big) }[/math]
[math]\displaystyle{ L_{cls}=\text{Cross Entropy}(p_{i+2},\hat p_{i+2}) }[/math]
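A minimal sketch of how this combined loss could be computed in PyTorch is shown below, assuming access to the frozen LM head of the target model; the function name and tensor shapes are illustrative assumptions, not the authors' implementation.

<syntaxhighlight lang="python">
import torch
import torch.nn.functional as F

def eagle_draft_loss(f_true, f_pred, lm_head, w_cls=0.1):
    """Combined draft-model loss L = L_reg + w_cls * L_cls, following the formulas above.

    f_true:  target features f_{i+1} from the frozen LLM           (batch, hidden)
    f_pred:  features predicted by the draft model                  (batch, hidden)
    lm_head: the target LLM's frozen output layer (features -> vocab logits)
    """
    l_reg = F.smooth_l1_loss(f_pred, f_true)                 # feature regression term
    with torch.no_grad():
        p_true = F.softmax(lm_head(f_true), dim=-1)          # p_{i+2} from true features
    log_p_pred = F.log_softmax(lm_head(f_pred), dim=-1)      # log of \hat p_{i+2}
    l_cls = -(p_true * log_p_pred).sum(dim=-1).mean()        # cross-entropy between the two
    return l_reg + w_cls * l_cls
</syntaxhighlight>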
EAGLE guarantees that the output distribution matches that of the original target LLM under both greedy and non-greedy selection, since it does not modify the original LLM. Greedy selection picks the token with the highest probability, while non-greedy selection samples tokens from the distribution.
EAGLE uses a tree structure for the draft model. The acceptance rate is not used as a metric for evaluating the model because, for each node, multiple tokens are generated and only one is accepted.
Constructive Critiques or Reviews
EAGLE uses a tree structure for the draft model, but the tree is not constructed based on context. The tree might be unbalanced, which may hurt performance as the number of batches or input prompts grows.
Related works
In the draft stage, other related works have utilized different methods. Speculative sampling and Lookahead use tokens to predict tokens. Medusa uses features to independently predict tokens.
In the verifying stage, other related works have modified the acceptance probabilities (DistillSpec (Zhou et al., 2023)) or the acceptance method (BiLD (Kim et al., 2023) and Medusa (Cai et al., 2023)).
Other related methods that reduce the latency per forward pass include distillation, quantization, and pruning.
Group 2 Presentation: EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty
Presented by:
Kareena Bhalla and Chelsea Huffman
Paper Citation
Li, Y., Wei, F., Zhang, C., Zhang, H. (2024). EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty. arXiv. doi.org/10.48550/arXiv.2401.15077
Summaries
EAGLE (Energy-based Adaptive Guidance with Latent Evidence) is a framework designed to improve speculative decoding for LLMs by addressing the feature uncertainty introduced when a draft model generates multiple tokens in parallel. Traditional speculative decoding disrupts the sequential assumption LLMs are trained under, since some input tokens are guesses rather than true outputs. Instead of discarding uncertain tokens or masking them out, EAGLE introduces a learned energy-based score that determines the reliability of each speculative token, thereby allowing more informed decoding decisions.
EAGLE adapts the model’s internal representations to reflect the uncertainty in the speculative process, enabling better performance on tasks requiring multi-step reasoning or longer contextual dependencies.
Key Contributions
EAGLE builds on the following insights:
Speculative decoding introduces a mismatch between training and inference, especially due to incorrect token guesses from the draft model.
Uncertainty should be handled at the feature level, rather than ignored or masked during decoding.
EAGLE modifies speculative decoding in the following way:
Introduces an energy-based auxiliary model that estimates the likelihood of each speculative token being correct.
This energy score is learned jointly and used during decoding to modulate how much influence uncertain tokens and their hidden states have on the model’s output.
Unlike traditional methods that apply rigid accept/reject rules, EAGLE adjusts the model’s behavior dynamically, offering a soft, learned mechanism to guide decoding.
The method does not discard or restart upon uncertain drafts but integrates uncertainty into the decoding process itself, improving fluency and consistency.
Constructive Critiques or Reviews
While EAGLE introduces a novel energy-based uncertainty estimator, the energy model itself adds complexity and must be co-trained carefully to avoid overfitting or poor generalization.
The paper focuses on performance improvement, but computational overhead introduced by the energy estimation is not deeply discussed.
Although EAGLE improves over traditional speculative sampling, its gains may vary depending on the task, especially in domains where draft model guesses are frequently wrong.
Related works
Other speculative decoding techniques have addressed draft model reliability in different ways:
Speculative Sampling and Lookahead: Use token-level drafting and parallel verification.
Medusa: Like EAGLE, leverages internal features for token prediction but lacks a feature-level uncertainty modeling mechanism.
DistillSpec (Zhou et al., 2023): Modifies token acceptance probabilities to increase match with the base model.
BiLD (Kim et al., 2023): Reuses past information to improve speculative decoding reliability.
Other strategies for speeding up LLM inference include distillation, quantization, and model pruning, but they do not preserve exact output distribution or feature reliability in the same way as EAGLE.
Group 2 Presentation: EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty
Presented by:
Kareena Bhalla and Chelsea Huffman
Paper Citation
Li, Y., Wei, F., Zhang, C., Zhang, H. (2024). EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty. arXiv. doi.org/10.48550/arXiv.2401.15077
Introduction
Podina et al. (2023) introduce Universal Physics-Informed Neural Networks (UPINNs), a novel framework that extends Physics-Informed Neural Networks (PINNs) and Universal Differential Equations (UDEs) to learn unknown components of differential equations from sparse and noisy data. The paper makes significant contributions to the fields of machine learning, computational physics, and system identification by addressing key limitations in existing methods.
Advancing Symbolic Discovery of Differential Equations
A major contribution of this work is its ability to symbolically discover unknown differential operators from data, even when measurements are limited. Unlike traditional PINNs, which require a fully known differential equation structure, UPINNs incorporate an additional neural network to approximate hidden terms. This neural network can then be converted into an explicit mathematical formula using symbolic regression techniques like AI Feynman. This innovation bridges the gap between black-box deep learning models and interpretable mathematical formulations.
Handling Sparse and Noisy Data in Scientific Machine Learning
One of the main challenges in scientific machine learning is dealing with limited and noisy datasets, which often arise in experimental and real-world scenarios. UPINNs demonstrate strong performance even when provided with very few and noisy measurements by leveraging prior knowledge in the form of known differential operators and PINN-based regularization. This makes the method particularly valuable for applications where data collection is expensive or difficult, such as biological systems modeling and geophysical simulations.
Improving Robustness Over Existing Methods
The paper highlights key shortcomings of existing approaches:
Universal Differential Equations (UDEs) require large datasets and are sensitive to noise, often failing to recover true mechanistic models.
Physics-Informed Neural Networks (PINNs) assume a known equation structure, making them unsuitable for discovering unknown dynamics.
By integrating PINN loss functions into the UDE framework, UPINNs outperform UDEs in noisy conditions while retaining the flexibility to discover unknown terms—something that standard PINNs cannot do. This hybrid approach significantly improves robustness in practical settings.
Applications in Complex Systems
The study demonstrates UPINN's effectiveness through diverse case studies, including:
The Lotka-Volterra system (a classic predator-prey model), where UPINNs successfully recover hidden interaction terms even with sparse data.
The viscous Burgers' equation, showcasing its applicability to partial differential equations (PDEs).
A biological apoptosis model, illustrating its utility in real-world scientific problems, particularly in modeling complex biochemical interactions.
Enabling Symbolic Interpretability in Neural Networks
Many deep learning models function as black boxes, making it difficult to extract interpretable insights. UPINNs, however, enable explicit mathematical discovery of hidden dynamics, making the learned models not only accurate but also scientifically meaningful. This ability to recover human-readable equations aligns with the broader goal of making AI-driven scientific discoveries more transparent and explainable.
Final Thoughts
Podina et al. (2023) present a significant step forward in data-driven discovery of differential equations, particularly in low-data and high-noise environments. By blending the strengths of PINNs and UDEs while addressing their weaknesses, UPINNs offer a more robust, interpretable, and scalable approach to learning unknown physical laws. This work has the potential to impact various scientific fields, including physics, biology, and engineering, where understanding the underlying mathematical structure of complex systems is crucial.
Group 3 Presentation: Mamba: Linear-Time Sequence Modelling with Selective State Spaces
Presented by:
Liang Wu, Jingcheng Yu, Candace Ng
Paper Citation
Gu, A., & Dao, T. (2023). Mamba: Linear-Time Sequence Modeling with Selective State Spaces. arXiv. https://arxiv.org/abs/2312.00752.
Background
Transformers dominate modern sequence modeling, especially in NLP. They leverage self-attention for flexible token interactions but suffer from quadratic complexity in sequence length, making them inefficient for long sequences.
Structured State Space Models (SSMs) offer a more efficient alternative, leveraging control theory for near-linear complexity. However, traditional SSMs struggle with discrete and information-dense inputs like text due to their time-invariant nature.
Main Idea
Mamba introduces Selective State Space Models, where key parameters (such as [math]\displaystyle{ B }[/math] and [math]\displaystyle{ C }[/math]) dynamically adjust based on input tokens, allowing the model to:
- Retain relevant information while filtering out noise.
- Process inputs in a context-aware manner.
- Adapt dynamically rather than using fixed transformations.
Since this approach disrupts efficient convolution-based operations, the authors implement a hardware-friendly selective scan method, optimized for GPUs.
The Mamba architecture diverges from traditional Transformers by merging sequence transformations with MLP layers into a streamlined, stackable block.
Experimental Results
1. Synthetic Tasks
Mamba successfully solves sequence modeling benchmarks such as selective copying and induction heads, demonstrating generalization to sequences up to a million tokens long.
2. Language Modeling
Trained on The Pile and evaluated on LAMBADA, HellaSwag, and ARC, Mamba-1.4B:
- Outperforms Pythia models of comparable or larger size.
- Matches "Transformer++" models with fewer parameters.
- Achieves faster inference without relying on key-value caching.
3. Genomics (DNA Modeling)
Mamba scales efficiently on the HG38 genome dataset and excels at species classification (e.g., distinguishing human, chimp, and gorilla DNA).
4. Audio Modeling
Mamba surpasses S4-based models in waveform generation and beats GANs and diffusion models on speech datasets like SC09.
5. Efficiency
The selective scan mechanism offers significant efficiency gains:
- Faster than FlashAttention-2 for long sequences.
- 4–5x higher inference throughput than Transformers of similar size.
- Reduced memory usage by eliminating key-value caching.
Group 3 Presentation: Mamba: Linear-Time Sequence Modelling with Selective State Spaces
Presented by:
Liang Wu, Jingcheng Yu, Candace Ng
Paper Citation
Gu, A., & Dao, T. (2023). Mamba: Linear-Time Sequence Modeling with Selective State Spaces. arXiv. https://arxiv.org/abs/2312.00752.
Background
Technical Contributions
- Introduce a selection mechanism for state space models that enables input-dependent state updates
- Develop a hardware-aware recurrent scan algorithm for efficient computation
- Propose Mamba, an attention-free, linear-time architecture that achieves state-of-the-art performance across multiple modalities
Architecture
Mamba integrates selective SSMs into a single homogeneous block. Each block comprises: a linear projection, a selective SSM layer, and an MLP block. This architecture is attention-free and scales linearly in sequence length.
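As a rough illustration of this homogeneous block structure, the sketch below wires a linear input projection, a causal depthwise convolution, a selective-SSM slot, and a gating branch into one module. The selective SSM itself is stubbed out with <code>nn.Identity()</code>, and all dimensions and layer choices are illustrative assumptions rather than the reference implementation.

<syntaxhighlight lang="python">
import torch
import torch.nn as nn
import torch.nn.functional as F

class MambaBlockSketch(nn.Module):
    """Simplified sketch of a Mamba-style block (illustrative, not the official code)."""
    def __init__(self, d_model, expand=2, d_conv=4):
        super().__init__()
        d_inner = expand * d_model
        self.in_proj = nn.Linear(d_model, 2 * d_inner)   # produces SSM branch and gate branch
        self.conv = nn.Conv1d(d_inner, d_inner, d_conv, groups=d_inner, padding=d_conv - 1)
        self.ssm = nn.Identity()                         # placeholder for the selective SSM layer
        self.out_proj = nn.Linear(d_inner, d_model)

    def forward(self, x):                                # x: (batch, length, d_model)
        u, gate = self.in_proj(x).chunk(2, dim=-1)
        # Causal depthwise conv: pad left-heavy, then truncate back to the original length.
        u = self.conv(u.transpose(1, 2))[..., : x.size(1)].transpose(1, 2)
        y = self.ssm(F.silu(u))                          # selective state-space transform would go here
        y = y * F.silu(gate)                             # gating branch plays the role of the MLP
        return self.out_proj(y)
</syntaxhighlight>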
Related Work
- Builds on structured SSMs (S4, S5) and architectures such as H3 and Hyena.
- Demonstrates theoretical and empirical connections to classical RNN gating.
Future Directions:
- Scale to larger models and refine training recipes
- Extend Mamba to multimodal tasks (e.g. video)
- Explore additional downstream affordances (fine-tuning, in-context learning, quantization)
Summaries of Key Points
1. Motivation: Faster and More Flexible Sequence Modeling
Large-scale "foundation models" (FMs) largely rely on Transformers (with attention) to handle long-range context. However, attention has quadratic complexity in sequence length and requires storing the entire key–value cache. This leads to high memory usage and slow generation. Meanwhile, Structured State-Space Models (SSMs) offer linear-time processing through either convolution or recurrence, achieving excellent results in domains like audio and vision. Yet they typically assume constant, time-invariant parameters, which hinders their ability to handle discrete or "content-based" tasks such as language.
2. Proposed Approach: Selective State-Space Models (S6)
The paper introduces a selection mechanism to inject content-based (input-dependent) reasoning into SSMs. Specifically:
- Parameterizing SSM transitions by inputs: Conventional SSMs have fixed transition matrices (A, B, C), but here, B and C—and crucially, the step-size parameter Δ—become input-dependent. This small twist allows the model to "select" which inputs matter and how strongly to update its hidden state, like gating in RNNs.
- Efficient "Selective Scan" Implementation: Because parameters vary with the input, one cannot use global convolution for training (which was key to SSM efficiency). Instead, the authors design a hardware-aware parallel scan on GPUs, fusing operators so that large intermediate states reside only in on-chip memory. This approach retains O(L) time scaling in sequence length while dramatically cutting overhead.
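The sketch below illustrates, under stated assumptions rather than as the official implementation, how such input-dependent parameters can be produced: B_t and C_t come from small linear projections of the token, the step size Δ_t passes through a softplus to stay positive, and A remains a shared learned parameter. Class and attribute names are hypothetical.

<syntaxhighlight lang="python">
import torch
import torch.nn as nn
import torch.nn.functional as F

class SelectiveParams(nn.Module):
    """Sketch of how an S6-style layer makes SSM parameters input-dependent (shapes illustrative)."""
    def __init__(self, d_model, d_state=16):
        super().__init__()
        self.to_B = nn.Linear(d_model, d_state)
        self.to_C = nn.Linear(d_model, d_state)
        self.to_delta = nn.Linear(d_model, d_model)
        # A is shared across time; stored in log space so its sign stays fixed.
        self.A_log = nn.Parameter(torch.log(torch.arange(1, d_state + 1).float()))

    def forward(self, x):                       # x: (batch, length, d_model)
        B = self.to_B(x)                        # (batch, length, d_state)  input-dependent
        C = self.to_C(x)                        # (batch, length, d_state)  input-dependent
        delta = F.softplus(self.to_delta(x))    # (batch, length, d_model)  positive step size
        A = -torch.exp(self.A_log)              # (d_state,)                negative real poles
        return delta, A, B, C
</syntaxhighlight>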
3. Mamba Architecture
To showcase selective SSMs, they build Mamba, a simple "attention-free" architecture. Each layer is effectively an MLP with a parallel SSM path included. By training only these blocks end-to-end (without multi-head attention), Mamba delivers:
- Linear-time training because it uses convolution-like or scan-based methods.
- Constant-time inference per token (no caching of the entire past context), which is especially beneficial for extremely long sequences.
Additionally, Mamba integrates gating mechanisms within its SSM path, adaptively controlling information flow to better capture complex dependencies. Its streamlined layer design allows hardware-aware GPU optimizations, enhancing throughput and scalability compared to attention-based models.
4. Experiments Across Multiple Domains
1. Synthetic Tests (Selective Copying, Induction Heads): Mamba (or more precisely, its selective mechanism) solves tasks where ordinary time-invariant SSMs fail. It can ignore irrelevant tokens and preserve essential ones.
2. Language Modeling: On The Pile (multi-domain text), Mamba matches or exceeds standard Transformer perplexities when model size scales to ~1B parameters. It also outperforms other sub-quadratic methods (e.g., Hyena, RWKV). Notably, Mamba obtains 4–5× higher inference throughput than Transformers because it avoids large key–value caches.
3. Genomics: Mamba attains better perplexities than alternative long-sequence models (HyenaDNA) on a human genome corpus, especially at very long context (up to millions of tokens).
4. Audio Waveforms: The model outperforms prior state-of-the-art (e.g., SaShiMi) in generating raw speech signals, achieving lower FID scores and better sample quality.
5. Significance and Conclusions
Mamba demonstrates that:
1. Combining gating/selection with state-space recurrences, and
2. A carefully engineered GPU scan can outperform Transformers in certain tasks and match them in others, all with linear-time complexity. This opens a path for foundation models in domains needing extremely long context—like genomics, raw audio, and (possibly) next-generation language modeling—where Transformers become unwieldy. The authors highlight that future directions include scaling Mamba to billions of parameters (7B+), exploring interpretability, and combining it with other approaches (e.g., partial attention) for even broader performance gains.
Group 3 Presentation: Mamba: Linear-Time Sequence Modeling with Selective State Spaces
Presented by:
Liang Wu, Jingcheng Yu, Candace Ng
Paper Citation
Gu, A., & Dao, T. (2023). Mamba: Linear-Time Sequence Modeling with Selective State Spaces. arXiv. https://arxiv.org/abs/2312.00752.
Summaries of Key Points
Introduction
The paper presents a comprehensive review of Graph Neural Networks (GNNs), a class of deep learning models designed to perform inference on data described by graphs. The authors highlight the growing importance of GNNs due to the prevalence of graph-structured data in numerous real-world domains, including social networks, biological systems, and knowledge graphs.
Taxonomy of Graph Neural Networks
The authors classify GNN models into several broad categories based on their architectural design and learning paradigms:
• Recurrent Graph Neural Networks (RecGNNs) – The earliest GNNs that iteratively update node representations using recurrent architectures.
• Convolutional Graph Neural Networks (ConvGNNs) – Inspired by CNNs, these models generalize convolution operations to graph domains. Sub-categories include:
Spectral-based methods (e.g., ChebNet, GCN)
Spatial-based methods (e.g., GraphSAGE, GAT)
• Graph Autoencoders (GAEs) – Unsupervised models that encode nodes into embeddings for reconstruction or downstream tasks.
• Spatial-Temporal GNNs (STGNNs) – Models designed to handle dynamic graphs, integrating both spatial and temporal information.
Model Strengths and Design Elements
Permutation invariance and locality are core inductive biases that make GNNs powerful. The use of attention mechanisms in models like GAT improves flexibility by learning different importance weights for neighbors. Sampling strategies and pooling operations are key to scaling GNNs to large graphs and improving expressiveness.
Applications
The review outlines a wide array of applications where GNNs have shown strong performance:
• Node Classification – Predicting user attributes in social networks.
• Link Prediction – Knowledge graph completion.
• Graph Classification – Molecular property prediction in chemistry.
• Recommendation Systems – Leveraging user-item interaction graphs.
• Traffic and Time-Series Forecasting – Using STGNNs for spatio-temporal modeling.
Challenges and Open Problems
The paper identifies several pressing challenges in the field:
• Scalability – Efficient training on large graphs remains difficult.
• Over-smoothing – Deep GNNs tend to produce indistinguishable node embeddings.
• Dynamic Graphs – Many GNNs struggle with real-time updates and evolving structures.
• Theoretical Understanding – There is limited theoretical analysis compared to other deep learning models.
Constructive Critique and Review
This paper offers a thorough and timely survey of Graph Neural Networks (GNNs), providing readers with a structured understanding of the landscape, key architectures, and diverse applications. The taxonomy of models—ranging from recurrent and convolutional GNNs to autoencoders and spatio-temporal variants—is well-organized and helps demystify the evolution and variety of approaches in the field. The paper’s value is particularly notable for newcomers and practitioners who seek a foundational overview of GNN concepts and developments.
One of the major strengths of the paper lies in its clear exposition of the differences between spectral and spatial methods, which are often a source of confusion. Additionally, by covering both theoretical concepts and practical use cases, the review bridges the gap between academic research and real-world implementation. The inclusion of application domains such as recommender systems, molecular biology, and traffic forecasting shows the breadth of GNN utility and makes the review relevant across disciplines.
However, while comprehensive, the paper could be improved by deepening its critical analysis of the methods it surveys. For instance, while many models are described, their comparative advantages and trade-offs are not always fully explored. A clearer discussion on which architectures perform best under what circumstances—e.g., in sparse vs. dense graphs, or static vs. dynamic environments—would be valuable for practitioners making model selection decisions.
Moreover, although challenges like over-smoothing and scalability are mentioned, the discussion remains somewhat high-level. Providing more concrete examples of how recent works attempt to mitigate these issues would enhance the review’s depth. Theoretical gaps in GNN research are also acknowledged but not elaborated upon in a way that guides future investigation.
Overall, the paper serves as an essential entry point into the field of GNNs. With added critical perspective and more technical comparison among models, it could serve not only as an introduction but also as a practical reference for advanced researchers.
Group 3 Presentation: Mamba: Linear-Time Sequence Modelling with Selective State Spaces
Presented by:
Liang Wu, Jingcheng Yu, Candace Ng
Paper Citation
Gu, A., & Dao, T. (2023). Mamba: Linear-Time Sequence Modeling with Selective State Spaces. arXiv. https://arxiv.org/abs/2312.00752.
Summaries of key points
Nowadays, transformers are extremely popular because they are great at finding relationships in data. However, they struggle with long sequences, since their computation grows quadratically as the input gets longer. In comparison, State Space Models (SSMs) are more efficient for long sequences, but they have a limitation: they can't easily decide what information to keep or forget. Mamba solves this problem by improving a standard SSM model so that it can selectively adjust its parameters based on the input.
Standard SSM (S4): How It Works
SSMs process sequences differently from transformers. Instead of checking all relationships in the data, they store past information in a hidden state and update it step by step.
The basic rule for an SSM is:
[math]\displaystyle{ h_t = A h_{t-1} + B x_t }[/math], [math]\displaystyle{ y_t = C h_t }[/math]
Where:
- [math]\displaystyle{ h_t }[/math] is the hidden state (memory of past information).
- [math]\displaystyle{ x_t }[/math] is the current input.
- [math]\displaystyle{ A, B, C }[/math] are numbers that the model learns to control how information flows.
- [math]\displaystyle{ y_t }[/math] is the output.
This type of model can represent long-range dependencies, and previous models like S4, DSS, and S5 have explored this direction. The problem is that A, B, and C stay fixed, meaning the model always updates its memory in the same way, no matter what input it sees. This makes it less adaptable, especially for complex tasks like language modeling.
Selective SSM (S6): Mamba’s Improvement
Mamba introduces Selective State Space Models (S6), where the key values A, B, and C change dynamically based on the input. This allows the model to decide what to store and what to forget in real-time.
Instead of using fixed values, Mamba updates them dynamically:
[math]\displaystyle{ h_t = A_t h_{t-1} + B_t x_t }[/math], [math]\displaystyle{ y_t = C_t h_t }[/math]
Now, A, B, and C depend on the input [math]\displaystyle{ x_t }[/math], making the model much more flexible and allowing it to adjust its memory as needed.
Also, Mamba uses a hardware-aware scan algorithm, which makes its computations 20-45x faster than older methods and 4-5x faster than transformers. This means it can handle longer sequences without slowing down.
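For clarity, here is a naive reference loop for this input-dependent recurrence; the actual Mamba implementation replaces the Python loop with a fused, hardware-aware parallel scan. The shapes, the diagonal A, and the simplified zero-order-hold step are illustrative assumptions.

<syntaxhighlight lang="python">
import torch

def selective_ssm_reference(x, delta, A, B, C):
    """Naive reference loop for a selective SSM (illustrative shapes).

    x:     (batch, length, d)   input sequence
    delta: (batch, length, d)   input-dependent step size
    A:     (d, n)               continuous-time state matrix (diagonal, negative real part)
    B, C:  (batch, length, n)   input-dependent projections
    Returns y: (batch, length, d)
    """
    bsz, length, d = x.shape
    n = A.shape[-1]
    h = torch.zeros(bsz, d, n, device=x.device)
    ys = []
    for t in range(length):
        # Zero-order-hold discretization with the per-token step size delta_t (simplified).
        A_bar = torch.exp(delta[:, t].unsqueeze(-1) * A)              # (batch, d, n)
        B_bar = delta[:, t].unsqueeze(-1) * B[:, t].unsqueeze(1)      # (batch, d, n)
        h = A_bar * h + B_bar * x[:, t].unsqueeze(-1)                 # h_t = A_bar_t h_{t-1} + B_bar_t x_t
        ys.append((h * C[:, t].unsqueeze(1)).sum(-1))                 # y_t = C_t h_t
    return torch.stack(ys, dim=1)
</syntaxhighlight>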
Experiment Results
Mamba has been tested in different fields:
- Language modeling: Performs as well as transformers but is more efficient.
- DNA sequencing: Works well on long sequences, beating many existing models.
- Audio processing: Outperforms other SSM models for speech tasks.
Limitations of Mamba
Mamba is promising, but it has a few challenges:
1. Scaling up: It hasn’t been tested on extremely large models yet, so we don’t know if it can compete with transformers at that level.
2. Trade-offs in memory selection: Choosing what to remember and forget works well for things like text but might not be as useful for continuous data like speech.
3. Lack of a mature ecosystem: Transformers have been around longer, so they have more tools and techniques available. Mamba still has to catch up.
Even with these issues, Mamba is an exciting step forward for sequence modeling.
Group 3 Presentation: Mamba: Linear-Time Sequence Modelling with Selective State Spaces
Presented by:
Liang Wu, Jingcheng Yu, Candace Ng
Paper Citation
Gu, A., & Dao, T. (2023). Mamba: Linear-Time Sequence Modeling with Selective State Spaces. arXiv. https://arxiv.org/abs/2312.00752.
Background
Transformers are the backbone of most foundation models today, especially in language tasks. Their strength lies in the self-attention mechanism, which allows for flexible information routing across tokens. However, this comes with a computational cost: both time and memory scale quadratically with sequence length. This makes training and inference inefficient, especially for long sequences.
There’s been growing interest in finding efficient alternatives. Structured state space models (SSMs) offer a different route. Inspired by control theory, they compute with linear or near-linear complexity and have worked well in domains like audio. But they’ve consistently struggled with tasks involving discrete and information-dense inputs, like text. One major issue is that existing SSMs apply the same operations at each time step. This time-invariant design makes them fast but limits their ability to reason based on content.
Main Idea
Mamba introduces selective state space models, which allow the model to change how it processes each input based on the input itself. In earlier SSMs, parameters like B and C (which control how inputs are added to the state and how the state is turned into output) were fixed across time. In Mamba, these parameters vary with the input token.
This makes the model capable of:
1. Retaining relevant tokens while forgetting unimportant ones.
2. Filtering out noise or filler content.
3. Adapting its internal state in a context-aware manner.
Of course, this also breaks the efficient convolution trick that earlier SSMs used, since convolutions require fixed kernels. To deal with this, the authors implement a custom recurrent scan that is hardware-friendly and memory-efficient, especially on GPUs. This scan computes the model step-by-step, but in a way that avoids the usual memory bottlenecks of recurrent models.
Rather than using the traditional Transformer layout of attention followed by an MLP block, Mamba builds a simpler block:
a. It merges the sequence transformation (via the selective SSM) with the MLP into a single unit.
b. This block is stacked repeatedly, with residual connections and normalization in between.
c. The resulting model is easier to scale and implement than Transformer-based designs.
The paper also shows that the model works with real-valued parameters (as opposed to the complex-valued versions used in some previous SSMs), which improves compatibility with common deep learning hardware.
Experiments & Results
1. Synthetic Tasks Mamba is tested on synthetic problems like selective copying and induction heads, which are designed to measure a model’s ability to selectively remember and use earlier parts of a sequence. Mamba solves both tasks and generalizes well to sequences much longer than it saw during training—up to a million tokens.
2. Language Modeling The authors train Mamba on The Pile and evaluate it on standard zero-shot benchmarks (LAMBADA, HellaSwag, ARC, etc.). Key findings:
Mamba-1.4B outperforms Pythia models of similar and larger size.
It matches the performance of “Transformer++” variants while using fewer parameters.
It runs significantly faster at inference time because it does not rely on key-value caching.
3. Genomics (DNA modeling) On the HG38 genome dataset, Mamba shows better scaling than baselines as both model size and context length increase. Even on a challenging species classification task with closely related species (human, chimp, gorilla, etc.), Mamba performs well with long-context inputs.
4. Audio Modeling Mamba improves on S4-based baselines in waveform modeling. On the SC09 speech generation dataset, it beats previous models, including GANs and diffusion models on several automated metrics.
5. Efficiency The selective scan implementation is highly optimized:
Faster than FlashAttention-2 for long sequences.
4–5x faster inference throughput than Transformers of similar size.
Uses less memory, since it doesn’t need to cache key/value pairs during generation.
Group 3 Presentation: Mamba: Linear-Time Sequence Modelling with Selective State Spaces
Presented by:
Liang Wu, Jingcheng Yu, Candace Ng
Paper Citation
Gu, A., & Dao, T. (2023). Mamba: Linear-Time Sequence Modeling with Selective State Spaces. arXiv. https://arxiv.org/abs/2312.00752.
Motivation and Problem Statement
(1) Attention bottleneck Transformers have demonstrated strong modeling capacity thanks to self-attention, but attention layers are known to scale quadratically with sequence length. This becomes problematic for very long sequences due to high compute and memory requirements.
(2) Subquadratic approaches Alternative models (e.g. linear attention, recurrent cells, and structured state space models, or SSMs) achieve subquadratic time complexity. However, in practice, they often lag behind Transformers—especially for discrete, “information-dense” domains like language.
(3) Key challenge Balancing the efficiency of subquadratic approaches with the “context-compression” power typical of full attention. In particular, standard linear time-invariant (LTI) SSMs struggle to handle tasks that require input-dependent selection or “content-based” reasoning.
Contribution
(1) Selective Mechanism The paper introduces a selective variant of state space models—“Selective SSM” or S6—whose parameters can dynamically depend on the current input token. This makes the model selectively propagate or ignore information, overcoming the rigidity of time-invariant recurrences.
(2) Hardware-Aware Recurrent Scan To handle the new time-varying SSM parameters, the authors propose a “scan” algorithm specialized for modern GPUs that leverages efficient memory management (fusing operations and reducing data movement). Despite the recurrent nature, it matches or exceeds the speed of FFT-based or other convolution-based methods for large sequence lengths.
(3) Mamba Architecture Built on top of this selective SSM layer, “Mamba” is a purely recurrent neural network that omits attention altogether, yet achieves competitive or better performance than Transformers across various domains (language, audio, genomics) while scaling linearly in sequence length.
Algorithm 1: Standard SSM (S4)
Structured State Space Models (S4) were initially designed to combine RNN-like recurrences with global convolutions:
(1) Core idea SSMs are defined by continuous-time parameters. These are discretized (using, for example, a Zero-Order Hold) so that the SSM can operate on discrete sequences.
(2) LTI property All prior S4-type models are time-invariant—the parameters stay constant for all time steps. This allows S4 to be computed as either (i) a single global convolution or (ii) a linear recurrence.
(3) Limitation Because the transition dynamics do not depend on the input, S4 cannot do content-based selection of the tokens it should store or ignore.
Computation of SSMs
(1) Convolutional mode Time-invariant S4 can exploit global convolutions (via the expanded kernel) to compute outputs in [math]\displaystyle{ O(L\log L) }[/math] or near-linear time. This approach avoids explicitly storing the large hidden state.
(2) Recurrent mode Alternatively, one can compute the same sequence mapping in a step-by-step fashion with [math]\displaystyle{ O(L) }[/math] steps but multiplied by the state dimension N. Typical parallel RNN implementations are memory-heavy because the hidden state dimension can be large.
(3) Trade-off Convolution mode is highly parallel (suitable for training) but struggles with time-varying parameters. Hence, prior S4-based models remain LTI to preserve convolution-based efficiency.
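The tiny NumPy demonstration below makes the trade-off concrete: for a time-invariant SSM, the step-by-step recurrence and a causal convolution with kernel [math]\displaystyle{ K_l = C \bar{A}^l \bar{B} }[/math] give identical outputs, which is exactly the equivalence that breaks once the parameters depend on the input. The example uses arbitrary random parameters and is purely illustrative.

<syntaxhighlight lang="python">
import numpy as np

# Demonstration that a time-invariant SSM recurrence equals a causal convolution
# with kernel K_l = C A_bar^l B_bar (scalar input/output, state size n).
rng = np.random.default_rng(0)
n, L = 4, 10
A_bar = np.diag(rng.uniform(0.1, 0.9, n))     # discretized state matrix (stable)
B_bar = rng.normal(size=(n, 1))
C = rng.normal(size=(1, n))
x = rng.normal(size=L)

# Recurrent mode: h_t = A_bar h_{t-1} + B_bar x_t,  y_t = C h_t
h, y_rec = np.zeros((n, 1)), []
for t in range(L):
    h = A_bar @ h + B_bar * x[t]
    y_rec.append(float(C @ h))

# Convolution mode: y = K * x with K_l = C A_bar^l B_bar
K = np.array([float(C @ np.linalg.matrix_power(A_bar, l) @ B_bar) for l in range(L)])
y_conv = [sum(K[l] * x[t - l] for l in range(t + 1)) for t in range(L)]

assert np.allclose(y_rec, y_conv)             # both modes produce the same output
</syntaxhighlight>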
Limitations of Linear Time-Invariant SSMs
(1) Static transitions Standard SSMs (S4) cannot adapt or filter out unimportant inputs on the fly.
(2) Discrete data handling In discrete, information-dense tasks (like language), one often must selectively attend to critical tokens. Purely LTI models do not have a built-in mechanism for such content-based selection.
Algorithm 2: Selective SSMs (S6)
(1) Key idea Make parts of the SSM parameters functions of the current input token, hence “selective.” At each time step t, the model can decide whether to store or forget the information from [math]\displaystyle{ x_t }[/math].
(2) Effect on recurrence The system is no longer linear time-invariant. However, it gains the ability to gate hidden states based on content—similar to an RNN gating mechanism but integrated within the SSM formulation.
Efficient Implementations of Selective SSMs
(1) Challenge With time-varying parameters, the global convolution trick no longer applies. A naive RNN-like approach would be slow (or memory-heavy) when N is large.
(2) Hardware-aware parallel scan The authors design a “selective scan” that operates recurrently but fuses memory reads and writes on GPUs, storing the full hidden state only in fast on-chip memory (SRAM). This avoids the usual overhead of a standard step-by-step approach.
(3) Performance Benchmarks indicate the proposed selective scan can be faster than attention beyond certain sequence lengths and avoids the large memory overhead of an attention KV-cache.
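As an illustration of why such a scan is possible at all, the sketch below expresses a scalar, per-channel recurrence [math]\displaystyle{ h_t = a_t h_{t-1} + b_t }[/math] through an associative combine operator; because the operator is associative, the same result can be computed with a logarithmic-depth tree scan, which is what a hardware-aware GPU kernel parallelizes while keeping the state in on-chip memory. This code is a simplified stand-in, not the authors' fused kernel.

<syntaxhighlight lang="python">
import numpy as np

def combine(e1, e2):
    """Associative operator for first-order recurrences h_t = a_t * h_{t-1} + b_t.
    Composing (a1, b1) then (a2, b2) gives (a2*a1, a2*b1 + b2)."""
    a1, b1 = e1
    a2, b2 = e2
    return a2 * a1, a2 * b1 + b2

def scan(a, b):
    """Inclusive scan over (a_t, b_t) pairs, written sequentially for clarity.
    Associativity of `combine` is what allows a parallel tree-structured evaluation."""
    out, acc = [], (1.0, 0.0)          # identity element: h stays unchanged
    for t in range(len(a)):
        acc = combine(acc, (a[t], b[t]))
        out.append(acc[1])             # acc[1] is h_t
    return np.array(out)

# Sanity check against the plain recurrence.
rng = np.random.default_rng(1)
a, b = rng.uniform(0.5, 1.0, 8), rng.normal(size=8)
h, ref = 0.0, []
for t in range(8):
    h = a[t] * h + b[t]
    ref.append(h)
assert np.allclose(scan(a, b), ref)
</syntaxhighlight>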
Mamba Architecture
(1) Simplified design Mamba blocks combine a selective SSM layer (the new S6 variant) with a gated MLP pathway in the same layer. This merges what used to be a multi-layer approach (SSM + MLP) into a single homogeneous block.
(2) Purely recurrent Mamba is completely attention-free. Each layer processes the input in linear time with the selective scan.
(3) Competitive performance Despite omitting self-attention, Mamba consistently achieves Transformer-quality or better results on diverse tasks, with far lower memory/time overhead at long context lengths.
Interpretations of Selection Mechanisms
(1) Variable spacing The selective update effectively allows the model to “jump over” irrelevant tokens, addressing tasks like “selective copying” where the positions of relevant tokens vary.
(2) Filtering context S6 can decide which tokens to integrate or forget. If the input at time t is unimportant, the update gate can suppress it, preventing noise accumulation in the hidden state.
(3) Boundary resetting When sequences are concatenated (e.g., different segments back-to-back), selective SSMs can “reset” the hidden state if the new segment is unrelated, mimicking the attention mask for different documents.
Overview of Experiments
(1) Synthetic tasks Selective Copying / Induction Heads: Demonstrates that selective SSMs learn to focus on relevant tokens. LTI SSMs fail these tasks, but S6-based Mamba solves them and even extrapolates correctly to much longer sequences.
(2) Language modeling
Scaling laws: Mamba shows strong scaling, matching or surpassing Transformers on the Pile dataset when model sizes go up to 1B+ parameters.
Zero-shot downstream tasks: Mamba outperforms or matches similarly sized Transformer baselines on tasks like LAMBADA, HellaSwag, and ARC.
(3) DNA sequences
Extremely long contexts: Mamba uses million-token contexts and still improves perplexity, while LTI SSMs degrade at such scales.
Classification tasks: Fine-tuning Mamba at lengths up to 1M tokens surpasses prior approaches on synthetic species classification.
(4) Audio generation
Long-range modeling: Mamba outperforms convolution-based S4 layers for autoregressive waveforms, especially beyond tens of thousands of time steps.
Speech quality: On a speech benchmark, Mamba cuts the previous state-of-the-art FID roughly in half, achieving more realistic outputs.
Speed and Memory Benchmarks
(1) Selective scan Achieves high training speed and memory efficiency on modern GPUs. Outperforms naive recurrent approaches by a large margin.
(2) Inference As a recurrent model with no need to store a growing KV-cache, Mamba obtains up to 5× higher inference throughput than Transformers of comparable size, especially at batch sizes beyond 1.
Related Work and Future Directions
(1) Transformer adaptations Many recent efforts approximate or modify attention (linear attention, kernel methods, etc.) to achieve subquadratic complexity—yet none consistently matched Transformers across modalities.
(2) Structured State Spaces Previous S4 variants excelled at continuous signals; discrete tasks were less successful due to the inability to filter input tokens selectively.
Future
(1) Scaled training: Investigating even larger Mamba models, or specialized domain tasks (vision, speech).
(2) Low-level optimization: The fused scan approach might be combined with novel GPU/TPU kernels.
(3) Formal interpretability: Mechanistically verifying how the model “chooses” tokens would improve transparency.
Limitations
(1) Discrete–Continuous tradeoff While the selective approach helps with discrete data, the authors note that certain initializations (real or complex) may still matter for stable training on continuous signals.
(2) Complex parameterization Tuning the selection parameters for each domain can be non-trivial, particularly if one wants to combine multiple forms of gating or advanced expansions.
(3) Residual data dependence Unlike attention, which explicitly attends to tokens by index, selective SSMs rely on gating from learned projections. Certain tasks might still favor explicit attention or local convolutions.
Group 3 Presentation: Mamba: Linear-Time Sequence Modelling with Selective State Spaces
Presented by:
Liang Wu, Jingcheng Yu, Candace Ng
Paper Citation
Gu, A., & Dao, T. (2023). Mamba: Linear-Time Sequence Modeling with Selective State Spaces. arXiv. https://arxiv.org/abs/2312.00752.
Summaries of key points
Goal: achieve content-aware selection through input-dependent parameters while preserving efficiency and flexibility.
Background: Although the traditional SSM is linear and efficient, it is weak at dynamic content selection.
Methodology: By making the SSM parameters vary with the input, information can be selectively remembered. A parallel scan algorithm preserves the linear time complexity. The architecture is completely attention-free; each block (selective SSM + MLP) is a unified, stackable module.
Result: Performance on The Pile is comparable to Transformers, and performance on long text is more stable. Longer DNA sequence contexts yield better classification accuracy. Mamba surpasses baselines such as S4 on YouTubeMix and SC09.
Constructive critiques or reviews
The structure is clear, the transitions are natural, and the explanations are thorough.
More figures could be added for more intuitive descriptions.
More detailed examples would help the audience understand better.
Clear explanations to aid understanding
Parallel Scan: avoids the slow, memory-heavy inference of traditional RNNs; its efficiency is close to the attention mechanism while using fewer computing resources.
Selective SSM: updates its state when it sees key tokens and skips irrelevant information.
S6 can degenerate into the classic RNN gating mechanism (it is a generalization of RNN gating).
T. Dao and A. Gu bridge SSMs and Transformers in their work “Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality”. This work makes it easier to port existing Transformer training techniques to SSM training. They also propose a Mamba-2 model to illustrate their theory.
Group 3 Presentation: Mamba: Linear-Time Sequence Modelling with Selective State Spaces
What Problem is Mamba Solving?
Because attention (the core mechanism in Transformers) compares every token to every other token in the sequence. That’s O(n²) complexity. So if you double the input size, the compute goes up fourfold. That’s fine for 512 tokens—but what about 100,000 tokens? Or a million? You’ll need insane amounts of memory and compute. Enter Mamba: A new architecture designed to process long sequences efficiently—with linear time and memory. That means it scales much better.
Key Innovations
1. Selective State Space Models (SSMs):
Traditional SSMs have shown promise in sequence modeling but often fall short in tasks requiring content-based reasoning, particularly with discrete data like text. Mamba enhances SSMs by making their parameters dynamic functions of the input. This adaptability allows the model to selectively retain or discard information based on the context of each token, improving performance in language modeling tasks.
2. Efficient Computation with Hardware-Aware Algorithms:
Incorporating input-dependent parameters in SSMs introduces computational challenges, as it disrupts time-invariant properties that facilitate efficient computation. Mamba addresses this by implementing a hardware-aware parallel algorithm that operates in a recurrent manner. This design ensures that the model maintains linear scalability with sequence length while optimizing for modern hardware architectures.
3. Streamlined Architecture:
Mamba departs from the conventional Transformer structure by eliminating attention mechanisms and even multi-layer perceptron (MLP) blocks. This simplification results in a model that is not only computationally efficient but also achieves faster inference speeds—reportedly five times higher throughput than Transformers—while effectively handling sequences up to a million tokens in length.
More Explanation
Transformers are amazing—but they struggle with long sequences.
To grasp Mamba's contributions more thoroughly, it's helpful to contextualize them within the broader landscape of sequence modeling:
State Space Models (SSMs): SSMs are mathematical frameworks used to model time-series data by representing systems with hidden states that evolve over time based on inputs. They have been foundational in various applications, including control systems and signal processing.
Transition from Transformers to Mamba: While Transformers rely on self-attention mechanisms to capture relationships between all tokens in a sequence, Mamba leverages the structured approach of SSMs, enhanced with input-dependent dynamics, to achieve similar or superior performance with improved efficiency.
Constructive Critiques
While Mamba presents significant advancements, certain aspects warrant further exploration:
1. Expressiveness vs. Efficiency Trade-off:
By simplifying the architecture and removing attention mechanisms, there might be concerns regarding the model's ability to capture intricate dependencies within data. It's essential to assess whether this streamlined approach compromises the expressiveness necessary for certain complex tasks.
2. Empirical Validation Across Tasks:
Although Mamba shows promise in several domains, comprehensive evaluations across a broader range of tasks and datasets are crucial to fully establish its versatility and generalizability.
3. Implementation Complexity:
The introduction of hardware-aware algorithms, while beneficial for efficiency, may introduce complexities in implementation. Ensuring that these optimizations are accessible and reproducible for practitioners is vital for widespread adoption.
A Few Things to Think About
1. No Attention?
Attention is known for its explicit ability to focus on relevant parts of input. Removing it might hurt tasks that need pinpoint accuracy—like reasoning over multiple steps or focusing on specific tokens far away.
2. Training Stability and Tuning
Complex architectures like Mamba sometimes require careful hyperparameter tuning. That might limit plug-and-play usability, at least initially.
3. Interpretability
Attention maps (in Transformers) can sometimes be visualized to explain model behavior. Mamba’s internal state dynamics may be less interpretable.
Group 3 Presentation: Mamba: Linear-Time Sequence Modelling with Selective State Spaces
Presented by:
Liang Wu, Jingcheng Yu, Candace Ng
Paper Citation
Gu, A., & Dao, T. (2023). Mamba: Linear-Time Sequence Modeling with Selective State Spaces. arXiv. https://arxiv.org/abs/2312.00752
Motivation and Problem Statement
(1) Attention overhead Transformers achieve great performance with attention, but their O(n²) complexity in sequence length limits them for very long sequences due to high memory and compute costs.
(2) Subquadratic alternatives While linear attention and SSMs offer better scaling, they often fall short on tasks like language modeling, where discrete tokens require flexible, input-dependent representation.
(3) Core challenge Can we match the modeling power of attention while retaining linear-time efficiency? Standard linear time-invariant (LTI) SSMs lack the flexibility to perform content-based selection.
Contribution
(1) Selective SSM (S6) Mamba introduces a state-space model where parameters (A, B, C) are functions of the current input. This dynamic mechanism lets the model choose what to store or ignore based on context.
(2) Hardware-Aware Recurrent Scan A GPU-optimized algorithm processes time-varying recurrence efficiently by using on-chip memory and fused operations. It handles long sequences with high speed and low memory use.
(3) Mamba Architecture Mamba replaces attention entirely with a purely recurrent, stackable block composed of a selective SSM and a gated MLP. Despite its simplicity, Mamba matches or outperforms Transformers.
Algorithm 1: Standard SSM (S4)
- Based on the linear time-invariant recurrence: [math]\displaystyle{ h_t = A h_{t-1} + B x_t, \quad y_t = C h_t }[/math]
- Allows for global convolution or step-wise recurrence.
- Limitation: Parameters are fixed, unable to select information based on input.
Computation of SSMs
(1) Convolution mode Allows parallelism but struggles with input-dependent (time-varying) parameters.
(2) Recurrent mode Linear in sequence length but costly in memory when hidden size is large.
(3) Trade-off Convolution is fast but static; recurrence is flexible but slower—especially for varying inputs.
Algorithm 2: Selective SSMs (S6)
- Dynamic parameters: [math]\displaystyle{ h_t = A(x_t) h_{t-1} + B(x_t) x_t, \quad y_t = C(x_t) h_t }[/math]
- Input-aware gating allows filtering irrelevant tokens, akin to RNN gating but within an SSM structure.
Efficient Implementations of Selective SSMs
(1) Bottleneck: Naively processing each step slows down training or bloats memory usage.
(2) Selective scan: A custom GPU-friendly scan algorithm keeps memory local and throughput high—matching or beating attention at long sequence lengths.
(3) Advantage: No KV cache required; scales efficiently with input size and hardware.
Mamba Architecture
(1) Unified block: Each layer contains a selective SSM layer and a gated MLP, merged into one module.
(2) Fully recurrent: Linear-time processing—no attention heads, no multi-head complexity.
(3) Performance: Outperforms Transformers on language, audio, and genomics with better efficiency.
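To make the "unified block" description concrete, below is a rough PyTorch sketch of a Mamba-style block: input projection, causal depthwise convolution, a stubbed selective SSM, gating, and a residual connection. The module names, dimensions, and the LayerNorm choice are placeholders, not the authors' code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MambaStyleBlock(nn.Module):
    """Illustrative block: expand, mix along time, gate, project back."""
    def __init__(self, d_model, d_inner, d_conv=4):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.in_proj = nn.Linear(d_model, 2 * d_inner)              # main + gate branches
        self.conv = nn.Conv1d(d_inner, d_inner, d_conv,
                              padding=d_conv - 1, groups=d_inner)   # causal local mixing
        self.out_proj = nn.Linear(d_inner, d_model)

    def selective_ssm(self, u):
        # Placeholder for the input-dependent recurrence sketched earlier.
        return u

    def forward(self, x):                                           # x: (B, T, d_model)
        residual = x
        u, gate = self.in_proj(self.norm(x)).chunk(2, dim=-1)
        u = self.conv(u.transpose(1, 2))[..., : x.shape[1]].transpose(1, 2)
        u = self.selective_ssm(F.silu(u))
        y = u * F.silu(gate)                                        # gated combination
        return self.out_proj(y) + residual                          # residual connection
```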
Interpretations of Selection Mechanisms
(1) Skipping irrelevant tokens: S6 can “jump” across unimportant inputs, ideal for tasks like selective copy.
(2) Dynamic context filtering: Tokens at each step are filtered or integrated based on input relevance.
(3) Reset capability: When inputs shift (e.g., document boundaries), S6 can reset its hidden state—like segment-aware attention masks.
Overview of Experiments
(1) Synthetic tasks: Mamba solves the selective copy and induction head benchmarks where LTI SSMs fail (a toy version of the selective-copy setup is sketched after this list).
(2) Language modeling: Matches or surpasses Transformers on The Pile dataset up to 1B+ parameters. Strong in zero-shot tasks like LAMBADA, HellaSwag, ARC.
(3) Genomics: Performs well on million-token contexts, beating existing baselines on classification tasks.
(4) Audio: Achieves better FID scores and quality in speech generation. Outperforms S4 on long-range audio modeling.
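For reference, a toy generator for the selective-copying task mentioned in point (1) might look like the following; the sequence length, token counts, and vocabulary size are arbitrary choices.

```python
import numpy as np

def make_selective_copy_batch(batch, seq_len=64, n_tokens=8, vocab=16, noise_id=0):
    """Toy selective-copying task: a few content tokens are scattered among
    noise tokens, and the model must output the content tokens, in order.
    All sizes are illustrative."""
    x = np.full((batch, seq_len), noise_id, dtype=np.int64)
    y = np.zeros((batch, n_tokens), dtype=np.int64)
    for b in range(batch):
        positions = np.sort(np.random.choice(seq_len, n_tokens, replace=False))
        tokens = np.random.randint(1, vocab, size=n_tokens)
        x[b, positions] = tokens
        y[b] = tokens               # targets: copied content, noise filtered out
    return x, y
```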
Speed and Memory Benchmarks
(1) Training: The selective scan enables high throughput with minimal memory pressure.
(2) Inference: Mamba achieves 4–5× faster inference throughput than Transformers by avoiding large key–value caches.
Related Work and Future Directions
(1) Related efforts: Builds on S4, Hyena, and other structured SSMs but adds input-dependent dynamics missing in prior work.
(2) Future work: scaling Mamba to 10B+ parameters; combining with attention mechanisms; formalizing the interpretability of selection gates; exploring new domains like vision and multi-modal data.
Limitations
(1) Initialization sensitivity: Real/complex state initializations affect stability on continuous data.
(2) Parameter complexity: Domain-specific tuning of gates and parameter schedules may be needed.
(3) Interpretability: The absence of attention weights means token-level decisions are harder to visualize.
Group 3 Presentation: Mamba: Linear-Time Sequence Modelling with Selective State Spaces
Presented by:
Liang Wu, Jingcheng Yu, Candace Ng
Paper Citation
Gu, A., & Dao, T. (2023). Mamba: Linear-Time Sequence Modeling with Selective State Spaces. arXiv. https://arxiv.org/abs/2312.00752
Research Motivation and Challenge
Modern foundation models use Transformers with powerful self-attention mechanisms, but the attention mechanism scales quadratically and is limited to fixed-sized windows. While subquadratic alternatives have been proposed, they often struggle with discrete, information-dense data.
The main challenge lies in finding a balance between computational efficiency and effective context compression that adapts to the content.
Contributions
1. A mechanism is introduced that enables state updates to be dynamically adjusted based on the input, improving the model's ability to adapt to varying inputs.
2. A new algorithm is developed that optimizes computation by considering hardware characteristics, especially memory access patterns and parallelism, to enhance computational efficiency.
3. Mamba is proposed as a new architecture that operates without traditional attention mechanisms and processes in linear time, achieving state-of-the-art performance across various modalities.
Mamba
Architecture: Mamba integrates selective SSMs into a single homogeneous block. Each block comprises a linear projection, a selective SSM layer, and an MLP block.
Experimental performance: Mamba performs well on synthetic tasks, language modeling, DNA modeling, and audio modeling and generation.
Future Directions
1. Scale to larger models and refine training recipes
2. Extend Mamba to multimodal tasks
3. Explore additional downstream affordances
Group 3 Presentation: Mamba: Linear-Time Sequence Modelling with Selective State Spaces
Presented by:
Liang Wu, Jingcheng Yu, Candace Ng
Paper Citation
Gu, A., & Dao, T. (2023). Mamba: Linear-Time Sequence Modeling with Selective State Spaces. arXiv. https://arxiv.org/abs/2312.00752.
Summary
1. Introduction of Mamba: The paper introduces Mamba, a novel sequence modeling architecture based on selective state space models (SSMs). Mamba addresses the limitations of traditional Transformers, such as quadratic scaling with sequence length and inefficiency in handling long sequences, by leveraging selective SSMs that enable linear-time computation and improved performance.
2. Selective State Space Models (SSMs): Selection Mechanism: Mamba introduces input-dependent SSM parameters, allowing the model to selectively propagate or forget information based on the current token. This addresses the inability of prior SSMs to perform content-based reasoning. Hardware-Aware Algorithm: Despite losing the efficiency of convolutions due to selectivity, Mamba employs a parallel scan algorithm optimized for modern hardware (GPUs), achieving faster computation and linear scaling in sequence length.
3. Simplified Architecture: Mamba combines the design of prior SSM architectures with MLP blocks into a single, homogeneous block, eliminating the need for attention or even traditional MLP layers. This simplification enhances efficiency and scalability.
4. Empirical Performance: Language Modeling: Mamba matches or outperforms Transformers of similar or larger sizes in both pretraining perplexity and downstream tasks. For example, Mamba-3B outperforms Transformers twice its size on common-sense reasoning tasks. DNA and Audio Modeling: Mamba excels in modeling long sequences in genomics and audio, showing improved performance with context lengths up to 1 million tokens. Synthetic Tasks: Mamba solves tasks like Selective Copying and Induction Heads, demonstrating its ability to handle content-aware and context-aware reasoning.
5. Efficiency: Mamba achieves 5× higher inference throughput than Transformers due to its recurrent nature, which avoids the need for a KV cache. It also scales linearly with sequence length, making it suitable for long-context applications.
6. Ablations and Insights: The selection mechanism (especially input-dependent Δ) is critical for performance. Real-valued SSMs perform comparably to complex-valued ones in most settings, except for continuous modalities like audio. Increasing the state dimension (N) significantly improves performance with minimal parameter overhead.
Constructive Critique
1. Strengths: Innovative Approach: The selective SSM mechanism is a novel solution to the limitations of LTI models, enabling content-aware reasoning without sacrificing efficiency. Comprehensive Evaluation: The paper validates Mamba across diverse domains (language, DNA, audio) and demonstrates scalability to extremely long sequences. Practical Impact: The linear-time inference and training scalability make Mamba a strong candidate for real-world applications requiring long-context modeling.
2. Potential Limitations: Generalization to Larger Models: While Mamba performs well at scales up to 3B parameters, its performance at larger scales (e.g., 7B+ parameters) remains to be verified, especially compared to models like RWKV or RetNet. Continuous vs. Discrete Modalities: The trade-off between selective and LTI SSMs suggests that Mamba may not universally outperform LTI models (e.g., in audio tasks). Further exploration of hybrid approaches could be beneficial. Complexity of Implementation: The hardware-aware algorithm, while efficient, may require specialized optimization for different hardware setups, potentially limiting accessibility.
Connections to Related Work
1. SSM Variants: Mamba builds on structured SSMs (S4, S5) but introduces selectivity, distinguishing it from LTI models like Hyena or RetNet. The connection to RNN gating bridges classical and modern sequence modeling.
2. Efficient Attention: Mamba’s linear-time scaling contrasts with subquadratic attention variants (e.g., Linear Attention, Performer). The paper positions Mamba as the first attention-free model to match Transformer quality.
3. Long-Context Models: Mamba’s million-length scalability aligns with recent efforts like LongNet and HyenaDNA but demonstrates superior empirical gains with controlled experiments (e.g., performance improves monotonically with context length).
Group 4 Presentation: Learning spatiotemporal dynamics with a pretrained generative model
Presented by:
- Karolina Suszek
- Negin Amou
- Muhammad Azeem
Paper Citation
Z. Li et al., “Learning spatiotemporal dynamics with a pretrained generative model,” Nature Machine Intelligence, vol. 6, no. 12. Springer Science and Business Media LLC, pp. 1566–1579, Dec. 06, 2024. doi: 10.1038/s42256-024-00938-z.
Background
- Spatiotemporal dynamics: how the state of a physical system varies with space and time
- Real datasets often contain data with sparse measurements where there are a limited number of sensors available. There needs to be a way to convert the sparse measurement data into a full spatiotemporal field.
- Existing solutions learn to map the input to the output and ignore missing data, but this reduces the model's ability to generalize.
- The paper proposes the Sparse-Sensor-Assisted Score-Based Generative Model (S3GM), which uses unlabeled data during training and can reconstruct incomplete data after training, making accurate predictions even when little information is available.
- Key Idea: Learn the probability distribution of spatiotemporal data using a score-based generative model and refine the samples via stochastic sampling.
Technical Contributions
The main proposed model is the Sparse-Sensor-Assisted Score-Based Generative Model. It learns patterns from a large amount of data beforehand, and because training is unsupervised it does not require any labels. It aims to capture the significant features and natural patterns of the data. After training, the model can take incomplete data, reconstruct the missing parts, and make predictions.
Core Components:
- Pre Training Stage: Learns the joint probability distribution of the data
- Generating Stage: Uses a stochastic differential equation to refine and generate full-field predictions
- Refinement Mechanism: Ensures alignment with observations and enforces sequence consistency
Common applications of this model include turbulent flow modeling, climate forecasting, and physics-based simulations.
Summaries of key points
- Challenge Addressed: Traditional end-to-end learning models often struggle with generalization in reconstructing spatiotemporal dynamics, particularly when data is sparse—a common scenario in real-world applications.
- S³GM Methodology: Pretraining Phase: An unconditioned generative model is pretrained in a self-supervised manner on a comprehensive dataset, capturing the joint distribution of the system's dynamics. Generation Phase: The pretrained model is conditioned on new, sparse measurements to reconstruct and predict the full-field spatiotemporal dynamics.
- Validation and Performance: S³GM's efficacy was tested across multiple dynamical systems using synthetic, real-world, and laboratory datasets, including applications in turbulent flow modelling and weather forecasting. The results demonstrated that S³GM achieves high accuracy, generalizability, and robustness, even when faced with significant data sparsity and noise.
S³GM offers a promising approach for modeling and predicting complex spatiotemporal dynamics in situations where data is limited, leveraging the strengths of pretrained generative models to enhance performance in small data regimes.
Related Works
Some of the related works in this area are GPT-ST: Generative Pre-Training of Spatio-Temporal Graph Neural Networks. This framework employs a spatio-temporal masked autoencoder designed to capture both intra- and inter-cluster region semantic relationships, which are often overlooked in existing approaches. Another one is Spatio-Temporal Few-Shot Learning via Diffusive Neural Network Generation, where a generative pre-training framework (GPD) that addresses data scarcity in spatiotemporal modeling. By performing generative pre-training on neural network parameters optimized with data from source cities, the framework enables the generation of tailored neural networks guided by prompts.
Many other methods that map the sparse measurements (input) to the full reconstructed spatiotemporal field include the following:
- Using Fourier or Laplace transforms to learn mappings between function spaces. The Fourier transform maps the sparse input data into the frequency domain, where reconstruction techniques can be applied more easily.
- Using CNNs to learn latent representations of full spatiotemporal fields and reconstruct missing regions through an encoder and decoder.
- Using PINNs to incorporate physical laws (differential equations) into the loss function. This can be useful when data is sparse or noisy, as they enforce physical consistency in the absence of complete ground-truth data.
Group 4 Presentation: Learning spatiotemporal dynamics with a pretrained generative model
Presented by:
- Karolina Suszek
- Negin Amou
- Muhammad Azeem
Paper Citation
Z. Li et al., “Learning spatiotemporal dynamics with a pretrained generative model,” Nature Machine Intelligence, vol. 6, no. 12. Springer Science and Business Media LLC, pp. 1566–1579, Dec. 06, 2024. doi: 10.1038/s42256-024-00938-z.
Summaries of key points
Spatiotemporal Dynamics and S3GM
Spatiotemporal dynamics describe physical systems (e.g., climate forecasting, fluid dynamics) that evolve over space and time. However, in real-world situations, sensor data is usually incomplete; for example, weather stations may cover only a few locations, or sensors might capture only the magnitude of velocity and miss its direction. Standard deep learning models (FNO, PINN, U-Net, LNO, DeepONet) often struggle to adapt if the sensor setup changes or if the system behaves unexpectedly. This lack of generalization means models often need to be retrained for each new situation. To overcome this, S3GM is proposed.
How Does S3GM Work?
Instead of learning the full probability distribution [math]\displaystyle{ p(x) }[/math] (which is complex), S3GM learns the gradient of the data distribution, called the score function: [math]\displaystyle{ s(x) = \nabla_x \log p(x) }[/math]. This tells the model which direction the data is most likely to change, making learning more efficient.
It should also be noted that real-world data is often messy—noisy, incomplete, or missing. S3GM handles this using denoising score matching (DSM):
- Adds a small amount of noise to the data and trains the model to remove it.
- Forces the model to focus on true underlying patterns rather than memorizing raw data.
By repeatedly removing noise, the model deeply understands the true data structure—even when parts are missing.
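A minimal sketch of a single-noise-level denoising score matching loss, assuming a `score_net` that maps a noisy field to an estimated score; actual score-based models such as S3GM train over a schedule of noise levels, so this is only an illustration of the principle.

```python
import torch

def dsm_loss(score_net, x, sigma=0.1):
    """Single-noise-level denoising score matching: perturb the data with
    Gaussian noise, then train the network to point back toward the clean
    sample. `score_net` and `sigma` are illustrative assumptions."""
    noise = torch.randn_like(x)
    x_noisy = x + sigma * noise
    target = -noise / sigma                 # score of the Gaussian perturbation kernel
    pred = score_net(x_noisy)
    return ((pred - target) ** 2).mean()
```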
Once the model learns the data’s structure, it reconstructs missing information using stochastic differential equations (SDEs), whose updates involve three terms:
- Drift Term: guides reconstruction toward likely states.
- Diffusion Term: adds controlled randomness to explore multiple solutions.
- Correction Term: uses real sensor data to ensure consistency.
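One way to picture how these three terms combine is the illustrative update below: a generic Langevin-style step with an observation-consistency term, not the paper's exact sampler. The names `score_net`, `obs`, `mask`, and the step sizes are all assumptions.

```python
import torch

def conditional_update(x, score_net, obs, mask, step=1e-3, noise_scale=1.0, guidance=1.0):
    """One illustrative reverse-diffusion update combining the three terms
    listed above: a drift toward likely states (learned score), injected
    randomness (diffusion), and a correction pulling generated values toward
    the sparse observations where `mask == 1`."""
    drift = score_net(x)                                   # drift term: follow the learned score
    diffusion = noise_scale * torch.randn_like(x)          # diffusion term: controlled randomness
    correction = guidance * mask * (obs - x)               # correction term: match known sensors
    return x + step * (drift + correction) + (2 * step) ** 0.5 * diffusion
```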
How Well Does S3GM Work?
S3GM is tested on four different systems to see how well it reconstructs and predicts missing data:
Experiment 1: Predicting Chaotic Behavior (Kuramoto-Sivashinsky Equation)
- Challenges Tested:
- Sparse Spatial Data (few sensor readings)
- Fourier Transform Domain (frequency-based measurements)
- Limited Initial Data (predict future states with few frames)
- Results:
- S3GM outperformed U-Net, FNO, and DeepONet, achieving lower errors.
- Stable even with limited input data.
Experiment 2: Reconstructing Turbulent Flow (Kolmogorov Flow)
- Challenges Tested:
- Trained on low-turbulence data.
- Tested on high-turbulence data.
- Results:
- Accurately reconstructed velocity fields and vorticity patterns.
Experiment 3: Climate Data Reconstruction (ERA5 Dataset)
- Challenges Tested:
- Extreme Data Sparsity (only 1% wind speed measurements available).
- Hidden Variables (missing temperature and pressure).
- Noisy Measurements (Gaussian noise added).
- Results:
- Successfully reconstructed missing climate variables.
- Performance improved with more sensor data.
Experiment 4: Flow Around a Cylinder
- Challenges Tested:
- Spatiotemporal Gaps (only specific cross-sections measured).
- Time-Averaged Data (some measurements were only available as averages).
- Results:
- Accurately reconstructed instantaneous and time-averaged flow fields.
- Outperformed physics-informed neural networks (PINNs).
Limitations of S3GM
While powerful, S3GM has limitations:
- Computational Cost: Pre-training is resource-intensive.
- Data Quality Dependence: Best performance with diverse, high-quality data.
- Generalization Issues: May struggle with entirely new dynamics.
- Processing Speed: Iterative reconstruction can be slower than traditional methods.
Despite these challenges, S3GM is a promising tool, and if these are improved, it could be even more powerful.
Group 4 Presentation: Learning spatiotemporal dynamics with a pretrained generative model
Presented by:
Karolina Suszek, Negin Amou, and Muhammad Azeem
Paper Citation
Z. Li et al., “Learning spatiotemporal dynamics with a pretrained generative model,” Nature Machine Intelligence, vol. 6, no. 12. Springer Science and Business Media LLC, pp. 1566–1579, Dec. 06, 2024. doi: 10.1038/s42256-024-00938-z.
Background
This paper proposes Mamba, a new type of sequence model designed to match the modeling quality of Transformers while improving computational efficiency. The key innovation is a selective state space model (SSM) that can reason based on content and scale linearly with sequence length. While providing 4–5x quicker inference than Transformers of comparable size, Mamba shows strong performance across several domains—language, audio, and genomics—positioning itself as a general-purpose backbone for foundation models.
Most large models today rely on Transformers, which are powerful but inefficient, especially for long sequences. Both training and inference are bottlenecked by the quadratic scaling of the self-attention mechanism with sequence length. State space models (SSMs) have drawn growing interest as efficient substitutes: they are recurrent and scale linearly, though earlier versions have fallen short on tasks like language modelling. A major drawback the authors point out is that conventional SSMs are time-invariant, applying the same dynamics at every time step regardless of the input. This limits their capacity to complete tasks needing content-based reasoning.
Main Idea
The central idea of this paper is to improve state space models by making them selective. Traditional structured state space models (SSMs) apply the same linear operations at every time step, which works well for smooth or continuous data like audio, but not for discrete tasks like language modeling. The authors argue that this is because these models cannot adapt their behavior based on the content of the input.
Mamba addresses this by allowing some of the internal dynamics of the SSM to depend on the current input token. Specifically, the model modifies the SSM parameters (like Δ, B, and C) so that they are no longer fixed, but vary depending on what the model sees at each step. This makes it possible for the model to filter, retain, or discard information in a context-aware way.
This design sacrifices the ability to use fast convolutional implementations, but the authors introduce an alternative they call a selective scan—a custom, hardware-friendly way of computing the state updates efficiently on GPU. This allows the model to maintain linear computational complexity while being much more flexible than previous SSMs.
Mamba’s architecture is also deliberately kept simple. It does not rely on attention, nor does it use the usual Transformer-style MLP blocks. Instead, it stacks blocks based on this new selective SSM design, each combining sequence modeling and nonlinear transformation in one place.
Experimental & Result
The authors test Mamba on a wide range of tasks to show both its performance and its scalability.
On synthetic tasks like selective copying and induction heads, Mamba succeeds in learning long-range dependencies that other models fail to capture. It generalizes well even when the test sequences are far longer than the ones it was trained on, reaching up to a million tokens.
In language modeling, they train Mamba on The Pile and compare it to Transformer baselines like Pythia and RWKV. Despite being smaller in size, Mamba-1.4B performs better than Pythia-2.8B on several zero-shot benchmarks. It also matches the performance of more carefully tuned Transformer setups. One major advantage is that Mamba runs faster at inference time—achieving 4 to 5 times the throughput of Transformer models—because it avoids key-value caching.
For genomics, Mamba is trained on the human genome (HG38). Its perplexity improves as the sequence length increases, which is unusual—most models perform worse on longer contexts. On a classification task involving DNA from closely related species (humans, chimps, gorillas, etc.), Mamba significantly outperforms other models, especially at longer input lengths.
In audio modeling, Mamba is plugged into the SaShiMi framework and outperforms it on waveform prediction and speech generation. On the SC09 dataset, it scores better than WaveNet and DiffWave, despite having fewer parameters.
Finally, in terms of efficiency, the new scan implementation is fast. It’s faster than both a naive PyTorch loop and FlashAttention-2 for long sequences. Mamba’s speed and memory use scale linearly with sequence length, making it practical for real-world applications with long inputs or limited compute.
Group 4 Presentation: Learning spatiotemporal dynamics with a pretrained generative model
Presented by:
Karolina Suszek, Negin Amou and Muhammad Azeem
Paper Citation
Z. Li et al., “Learning spatiotemporal dynamics with a pretrained generative model,” Nature Machine Intelligence, vol. 6, no. 12. Springer Science and Business Media LLC, pp. 1566–1579, Dec. 06, 2024. doi: 10.1038/s42256-024-00938-z.
Summary
The article "Learning spatiotemporal dynamics with a pretrained generative model" introduces a novel approach to reconstructing and predicting full-field spatiotemporal dynamics from sparse sensor measurements using a sparse-sensor-assisted score-based generative model (S³GM). The key points of the paper are as follows:
1. Problem Addressed: Reconstructing spatiotemporal dynamics (e.g., velocity, temperature, pressure fields) from sparse and heterogeneous sensor data is a significant challenge in fields like fluid dynamics, geophysics, and atmospheric physics. Traditional end-to-end learning models struggle with generalization, especially under sparse data conditions common in real-world scenarios.
2. Proposed Solution - S³GM: The authors propose S³GM, which leverages a pretrained generative model to capture the joint distribution of pretraining data in a self-supervised manner. This model is then conditioned on sparse measurements to reconstruct and predict full-field dynamics. Unlike conventional methods that directly map inputs to outputs, S³GM uses a two-step process: pretraining on vast datasets followed by conditional sampling.
3. Performance and Validation: The efficacy of S³GM is demonstrated across multiple dynamical systems, including turbulent flow modeling (e.g., Kolmogorov flow), weather/climate forecasting (using ERA5 data), and laboratory cylinder flow experiments (via PIV measurements). The model excels in zero-shot reconstruction and future-state forecasting, even with high data sparsity (e.g., 8× downsampling) and noise, outperforming baseline methods like U-Net, FNO, DeepONets, and DMD.
Key Features
Accuracy and Robustness: S³GM accurately reconstructs fields and maintains statistical fidelity (e.g., kinetic energy spectra) under varying sparsity levels. Generalizability: It performs well on unseen data, a significant improvement over traditional models. Stability: The model shows numerical stability in long-term forecasting, as evidenced by low error accumulation compared to baselines.
Applications and Datasets: The approach is tested on synthetic (e.g., Kuramoto-Sivashinsky equation, Kolmogorov flow), real-world (ERA5 reanalysis), and experimental (cylinder flow at Reynolds numbers 100–250) datasets, with all data and code made publicly available.
Comparison with Baselines: S³GM is benchmarked against seven methods, including neural network-based (U-Net, FNO, PINN) and linear (DMD, piDMD) approaches. It consistently delivers superior performance, particularly in handling complex, sparse, and noisy data.
Implications: The method offers a transformative tool for scientific and engineering applications where sensor data is limited, enhancing our ability to understand and control complex dynamical systems.
Group 4: Learning spatiotemporal dynamics with a pretrained generative model.
As presented by:
- Karolina Suszek
- Negin Amou
- Muhammad Azeem
Overview
Reconstruction of the spatiotemporal dynamics of dynamic systems is a canonically challenging task in engineering and science (may be more broadly referred to as "inverse problems"). An interesting heuristic for time dependent state dynamics problems is that they can be conceptualized as an "image to image" problem, or in some cases an "image" reconstruction problem. With this heuristic in mind, of course it makes sense to consider using generative models - which otherwise excel at image problems - for this use case. The authors introduce a specific framework which they title a "sparse sensor assisted score based generative model" (Abbreviated, [math]\displaystyle{ S^3GM }[/math]), leveraging a pretrained generative model and demonstrate its efficacy in recreating spatiotemporal system dynamics given only access to sparse sensor measurements.
Operating Principles
The model framework requires a few key components. It involves the use of an embedding network, consisting of spatial and temporal convolutions, which creates a prior used to inform, or steer, the generator toward a physically plausible solution. At generation time, the generator works by denoising a tensor of random Gaussian noise, whose trajectory is governed by the embedding provided by the prior network and penalized for discrepancy with the known sensor measurements. The result is a generated sample which is both physically plausible and in agreement with the measurements at the known locations.
Implementation Details
The authors employ a "Video-U-Net" as a prior network, effectively characterizing the joint distribution of the spatiotemporal data. The generator is a "score-based-generative model" (SGM) which employs a denoising process governed by Stochastic differential equations (SDE). The generator is guided through the two aforementioned mechanisms, namely the output of the prior network guides the generation toward physically plausible solutions, and an observation consistency term, which penalizes proposed solutions which may very well come from the space of physically plausible trajectories, but which differ from the experimental (sensor) observations.
Discussion / Conclusion
Overall the authors demonstrate the efficacy of the S3GM model on a series of canonical dynamical systems problems. They show its ability to reconstruct dynamics for Kuramoto-Sivashinsky dynamics, Kolmogorov turbulent flow, climate observations, and cylinder flow, achieving very low error rates.
Related Works
An interesting paper which employs a similar methodology is "Learning to Solve PDE-constrained Inverse Problems with Graph Networks" by Zhao, Lindell, and Wetzstein (published in ICML 2023). Specifically, they employ a learned prior network to reconstruct the initial condition only, and use this in combination with a GNN to predict the forward dynamics. This coupled model could potentially take better advantage of generative models by using them only to create physically plausible initial conditions (when measured against real sensor locations) while using a more suitable architecture for the forward propagation.
Group 5 Presentation: Griffin: Mixing Gated Linear Recurrences with Local Attention for Efficient Language Models
Presented by:
Guoqian Li, Xinlei Xu, Wenning Xu
Paper Citation
De, S., Smith, S., Fernando, A., Botev, A., Cristian-Muraru, G., Gu, A., Haroun, R., Berrada, L., Chen, Y., Srinivasan, S., Desjardins, G., Doucet, A., Budden, D., Teh, Y. W., Pascanu, R., De Freitas, N., Gulcehre, C. (2024). Griffin: Mixing Gated Linear Recurrences with Local Attention for Efficient Language Models. arXiv. arxiv.org/pdf/2402.19427
Background
RNNs and transformers
Recurrent neural networks (RNNs) are a basic architecture for handling sequential data; they are good for short sequences but can struggle with long ones.
Transformers generally perform better than RNNs and have become dominant in recent years. However, for long sequences they become computationally expensive for a couple of reasons: their global attention mechanism has quadratic complexity, and the key-value cache (including the multi-query attention cache) grows linearly with sequence length.
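A quick back-of-the-envelope calculation shows why the KV cache becomes a bottleneck; the model dimensions below are hypothetical examples, not figures from the paper.

```python
def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len, batch=1, bytes_per_val=2):
    """Approximate KV-cache size (keys + values) for a Transformer decoder,
    illustrating the linear growth in sequence length described above."""
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * batch * bytes_per_val

# Hypothetical 7B-scale model: 32 layers, 32 KV heads of dim 128, fp16 values.
print(kv_cache_bytes(32, 32, 128, seq_len=8192) / 2**30, "GiB")  # ≈ 4 GiB at 8K tokens
```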
The paper proposes several innovations to improve upon existing attention mechanisms in neural networks:
- RG-LRU layer: a novel gated linear recurrent layer designed to replace Multi-Query Attention (MQA)
- Hawk: integrate multi-layer perceptrons (MLPs) with recurrent blocks
- Griffin: combine MLPs with a mix of recurrent blocks and local attention mechanisms
Technical Contributions
Model architecture
Their proposed models, Hawk and Griffin, use the following structures (a rough sketch of the residual block appears after this list):
- Residual block: This is the main structure in the architecture. An RMSNorm layer is applied to the hidden state, followed by a temporal mixing block and a residual connection. Next, there is another RMSNorm layer followed by an MLP (multi-layer perceptron) block.
- Gated MLP block: This is a gated feedforward block. There are 2 branches: one is linear and one uses a GeLU activation. This part of the structure is used for feature selection.
- Temporal mixing block: They used 3 different kinds of temporal mixing blocks in their models: global multi-query attention, local sliding window attention, and recurrent blocks (what they proposed). The recurrent block contains 2 parallel branches. One has a GeLU activation, while the other has a temporal 1D convolutional layer followed by a RG-LRU layer (a proposal in this paper).
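A rough PyTorch sketch of the residual block described in the first bullet above; LayerNorm is used here in place of RMSNorm for simplicity, and the module names, dimensions, and the injected temporal-mixing module are assumptions rather than the paper's implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ResidualBlock(nn.Module):
    """Sketch: norm -> temporal mixing -> residual add, then norm -> gated MLP
    -> residual add. The temporal mixing module is passed in (recurrent block,
    local attention, or global MQA)."""
    def __init__(self, d_model, d_mlp, temporal_mixing: nn.Module):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)   # RMSNorm in the paper; LayerNorm for simplicity
        self.mixing = temporal_mixing
        self.norm2 = nn.LayerNorm(d_model)
        # Gated MLP: one linear branch, one GeLU branch, multiplied together.
        self.w_gate = nn.Linear(d_model, d_mlp)
        self.w_lin = nn.Linear(d_model, d_mlp)
        self.w_out = nn.Linear(d_mlp, d_model)

    def forward(self, x):
        x = x + self.mixing(self.norm1(x))                       # temporal mixing + residual
        h = self.norm2(x)
        x = x + self.w_out(F.gelu(self.w_gate(h)) * self.w_lin(h))
        return x
```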
Performance and Efficiency
- Hawk surpasses the reported performance of Mamba on downstream tasks.
- Griffin matches the performance of Llama-2, despite being trained on over six times fewer tokens.
- Both models demonstrate the ability to extrapolate to sequences significantly longer than those encountered during training.
Real-Gated Linear Recurrent Unit (RG-LRU)
RG-LRUs are inspired by regular Linear Recurrent Units (LRUs), but with a gating mechanism. They are more computationally efficient, as they avoid the quadratic complexity of transformers. Mathematically, the layer is defined as follows.
The recurrence gate:
[math]\displaystyle{ r_t = \sigma (W_a x_t + b_a) }[/math]
The input gate:
[math]\displaystyle{ i_t = \sigma (W_x x_t + b_x) }[/math]
The recurrence weight, where [math]\displaystyle{ a \in (0,1) }[/math] is a learnable recurrent weight and [math]\displaystyle{ c }[/math] a fixed constant:
[math]\displaystyle{ a_t = a^{c r_t} }[/math]
The output:
[math]\displaystyle{ h_t = a_t \odot h_{t-1} + \sqrt{1-a_t^2} \odot (i_t \odot x_t) }[/math]
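Putting the equations above together, a direct NumPy transcription for a single sequence might look like this; the parameterization of the recurrent weight through a sigmoid and the constant [math]\displaystyle{ c }[/math] follow the paper's description, while all variable names and shapes are illustrative.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def rg_lru(x, W_a, b_a, W_x, b_x, Lambda, c=8.0):
    """RG-LRU recurrence for one sequence. x: (T, D); all gates and states are
    D-dimensional and elementwise, so there is no dense recurrent matrix.
    The recurrent weight a = sigmoid(Lambda) lies in (0, 1)."""
    T, D = x.shape
    a = sigmoid(Lambda)                       # learnable recurrent weight, per channel
    h = np.zeros(D)
    out = np.zeros((T, D))
    for t in range(T):
        r = sigmoid(W_a @ x[t] + b_a)         # recurrence gate
        i = sigmoid(W_x @ x[t] + b_x)         # input gate
        a_t = a ** (c * r)                    # gated, input-dependent decay
        h = a_t * h + np.sqrt(1.0 - a_t**2) * (i * x[t])
        out[t] = h
    return out
```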
Summaries of Key Points
Context and Motivation
Transformers have dominated large language modeling, but their attention mechanism can be expensive for very long sequences, particularly because of the quadratic complexity of attention and a Key-Value (KV) cache that grows linearly in sequence length. Recurrent neural networks (RNNs) compress entire contexts into a fixed-size hidden state and avoid storing a cache that grows with length, which can reduce memory and speed up inference. Historically, though, RNNs have been difficult to train at scale and often underperform Transformers on complex tasks. This paper proposes new recurrent-based architectures that can match or exceed Transformer performance while retaining the training efficiency of standard deep-learning pipelines and gaining a significant advantage during inference.
Proposed Architectures
Two models are introduced: Hawk, a purely recurrent architecture, and Griffin, which mixes gated linear recurrences with local attention. Both rely on a new Real-Gated Linear Recurrent Unit (RG-LRU) layer, which extends a previously studied linear recurrence (the LRU) by adding two gating mechanisms. One gate modulates how much of the new input to incorporate each step, and another governs whether the layer updates in a standard LRU style or simply retains the previous hidden state.
RG-LRU and Recurrent Blocks
RG-LRU’s diagonal recurrence matrix dramatically lowers the computational cost for large sequence lengths. The authors also include a lightweight convolution for near-token interactions. For Griffin, they interleave RG-LRU-based blocks with a local sliding-window attention, allowing short-range local attention plus a global recurrent state. This combination helps the model handle both local details and long-distance dependencies in a memory-efficient way.
Empirical Results
The authors scale Hawk and Griffin to billions of parameters and compare them with:
1. A baseline Transformer using Multi-Query Attention (MQA),
2. The Mamba recurrent model (from prior work),
3. Llama-2.
Their main findings include:
- Both Hawk and Griffin follow standard power-law scaling in validation loss versus training compute, matching or beating Transformers.
- Hawk matches or exceeds previous recurrent models like Mamba on language modeling benchmarks, despite training on fewer tokens.
- Griffin matches Llama-2’s accuracy on multiple downstream NLP tasks while training on only about one-seventh the token count.
- During training, these models achieve speeds comparable to Transformers (thanks to parallelized or fused kernels), and in some configurations are faster at long sequence lengths (2048 tokens or more).
- At inference time, recurrent and local-attention blocks avoid the large KV caches that hamper Transformers. Hence, Hawk and Griffin show lower latency and can decode tokens at significantly higher throughput, especially for long sequences.
- Hawk and Griffin extrapolate to far longer sequences than they were trained on. They also efficiently learn synthetic copying/retrieval tasks if explicitly trained on them. However, their out-of-the-box retrieval or copying for tasks not seen at training time is still below that of Transformers.
Significance
By mixing local attention with a novel gated recurrence, Griffin retains the best of both recurrent and attention-based approaches: it can model local context with a sliding window while maintaining a small, fixed-size global state for the entire sequence. This achieves strong performance on large-scale language modeling benchmarks while offering major advantages in memory footprint and decoding speed. The results position Hawk and Griffin as compelling alternatives to purely Transformer-based architectures for scalable language models.
Related work and Future direction
Griffin builds on several key areas in efficient language modelling, particularly in recurrent architectures and hybrid approaches combining recurrence and attention: State Space Models (SSMs), Hybrid Recurrence-Attention Models, Efficient Transformer Alternatives.
Scaling Griffin to Larger Models
The paper discusses Griffin models up to 14B parameters but suggests further exploration into scaling beyond this size to compete with GPT-4 or Gemini models. Investigating how Griffin handles long-context tasks in ultra-large models would be an interesting future study.
Memory and Efficiency Optimization
Griffin already improves efficiency compared to transformers, but further research into sparsification, quantization, and hardware-specific optimizations could enhance its real-world applicability. Exploring mixed-precision training and inference optimizations for mobile and edge devices.
Extending Griffin to Multimodal Learning
While Griffin is designed for language modeling, incorporating it into multimodal tasks (e.g., vision-language models, audio processing) could expand its impact. Combining Griffin’s recurrence mechanism with diffusion models or video understanding models might be a promising direction.
Group 5 Presentation: Griffin: Mixing Gated Linear Recurrences with Local Attention for Efficient Language Models
Presented by:
Guoqian Li, Xinlei Xu, Wenning Xu
Paper Citation
De, S., Smith, S., Fernando, A., Botev, A., Cristian-Muraru, G., Gu, A., Haroun, R., Berrada, L., Chen, Y., Srinivasan, S., Desjardins, G., Doucet, A., Budden, D., Teh, Y. W., Pascanu, R., De Freitas, N., Gulcehre, C. (2024). Griffin: Mixing Gated Linear Recurrences with Local Attention for Efficient Language Models. arXiv. arxiv.org/pdf/2402.19427
Summaries of Key Points
Introduction and Motivation
This paper introduces the concept of Neural Collapse, a surprising and consistent geometric phenomenon observed in the terminal phase of training deep neural networks for classification. The authors highlight how neural networks, regardless of architecture or dataset, tend to collapse into a highly symmetrical state in which class features and classifiers align with remarkable regularity. The term terminal phase refers to the late stage of training after zero training error has already been achieved.
Description of Neural Collapse
Neural Collapse (NC) involves four key empirical properties:
• NC1 – Variability Collapse: Within-class variability in feature representations collapses to zero.
• NC2 – Convergence to Class Means: The feature vectors of each class converge to their class mean.
• NC3 – Equiangular Tight Frame (ETF) Structure: The class means become equidistant and symmetrically arranged, forming a simplex ETF.
• NC4 – Alignment of Classifier and Features: The last-layer classifier weights align with the class means.
These properties are observed across various settings, models, and datasets (e.g., MNIST, CIFAR-10, CIFAR-100, ImageNet).
Methodological Approach
The authors provide both empirical and theoretical support for neural collapse. They measure within-class variability, angles between class means, and alignment between classifiers and features to verify the four NC properties. They also propose a simplified analytical model (a deep linear network trained under certain assumptions) to theoretically demonstrate why neural collapse can emerge.
Theoretical Explanation
A key insight is the identification of minimization of cross-entropy loss with weight decay as an implicit regularizer driving networks toward this highly symmetric configuration. The authors prove that under simplified conditions, the ETF structure is an optimizer of the regularized loss. This aligns theoretical predictions with observed empirical behavior.
Experiments and Findings
Across multiple experiments (shown through figures and plots in the paper), the authors demonstrate that:
• Neural collapse becomes prominent after training accuracy hits 100%.
• Even with non-linear architectures and real-world data, the ETF configuration emerges.
• This behavior is observed even when networks are over-parameterized, suggesting it’s not due to constraints but rather a preference encoded by gradient descent.
Implications and Broader Impact
Neural collapse reveals that trained neural networks inherently develop geometric simplicity and symmetry in their representations. This insight could lead to better theoretical understanding of deep learning and inspire new architectures or training methods that explicitly promote these properties. It also connects to classical ideas in signal processing and geometry, such as tight frames and simplex structures.
Constructive Critique and Review
This paper offers a compelling contribution to the theoretical understanding of deep learning by identifying and rigorously analyzing the phenomenon of Neural Collapse. The authors present clear empirical evidence supported by a strong theoretical foundation, illustrating how trained deep networks tend to converge toward highly structured geometric configurations in their final training phase. This finding bridges a gap between practical neural network behavior and abstract geometric regularities, making it highly relevant to both practitioners and theorists in the field.
One of the most commendable aspects of the paper is its clear articulation of four key characteristics of neural collapse, each of which is supported by intuitive visualizations and consistent experimental evidence. The authors do an excellent job demonstrating the robustness of these phenomena across a wide range of datasets and architectures. The simplified theoretical model and analytical derivations further strengthen the paper’s foundation, offering an elegant and accessible explanation of a previously undocumented training behavior.
Despite these strengths, the work does leave room for further exploration. While the paper presents strong empirical results, most of the theoretical analysis is limited to simplified, idealized conditions such as deep linear networks or networks trained with weight decay. It is not yet clear how well these theoretical insights extend to more complex or non-convex training dynamics commonly used in real-world applications. Additionally, the paper focuses exclusively on classification tasks. It would be valuable to explore whether similar collapse behaviors occur in regression settings or in models trained on multi-label or sequence-based tasks.
Moreover, the practical implications of neural collapse remain largely speculative. While the geometric symmetry is intellectually intriguing, the paper does not provide concrete evidence that exploiting this phenomenon leads to better performance or generalization. Future work could explore whether enforcing or encouraging neural collapse during training could yield benefits in model robustness or efficiency.
Overall, the paper is well-executed and offers a fresh theoretical lens on deep learning. With additional investigation into practical applications and broader model types, this line of research could offer foundational insights into why deep networks generalize so effectively.
Group 5 Presentation: Griffin: Mixing Gated Linear Recurrences with Local Attention for Efficient Language Models
Presented by:
Guoqian Li, Xinlei Xu, Wenning Xu
Paper Citation
De, S., Smith, S., Fernando, A., Botev, A., Cristian-Muraru, G., Gu, A., Haroun, R., Berrada, L., Chen, Y., Srinivasan, S., Desjardins, G., Doucet, A., Budden, D., Teh, Y. W., Pascanu, R., De Freitas, N., Gulcehre, C. (2024). Griffin: Mixing Gated Linear Recurrences with Local Attention for Efficient Language Models. arXiv. arxiv.org/pdf/2402.19427
Summaries of Key Points
Motivation and background
Dominance of Transformers and their drawbacks: Transformers (with global attention) have driven most breakthroughs in large-scale language modeling. Despite their strengths, transformers struggle with quadratic complexity on long sequences; the large Key-Value (KV) cache grows linearly in the sequence length during inference, slowing down decoding for long inputs.
Potential of Recurrent Models: Recurrent neural networks (RNNs) offer an appealing alternative because they summarize the sequence with a fixed-size hidden state that is updated token by token (iteratively), thus avoiding a large KV cache. However, classical RNNs can be slow to train in practice and have struggled to match the performance of Transformers at large scales.
Goal: To develop recurrent architectures (and hybrids that combine local attention and recurrences) which (1) match or exceed Transformers’ performance on language modeling, (2) train efficiently at scale (both in FLOPs and hardware utilization), and (3) offer improved inference efficiency (lower latency, higher throughput).
Architecture
Hawk (MLPs with recurrent blocks)
(1) Recurrent Block: Replaces self-attention entirely with a diagonal recurrent design. It processes tokens sequentially with fast iteration at inference time.
(2) RG-LRU Layer: The crucial piece inside the Hawk model is the Real-Gated Linear Recurrent Unit (RG-LRU), a novel RNN layer that is stable, uses simple diagonal recurrences, and includes input and recurrence gates. The recurrence gate [math]\displaystyle{ r_t }[/math] and input gate [math]\displaystyle{ i_t }[/math] each depend on the input [math]\displaystyle{ x_t }[/math] but not the recurrent state [math]\displaystyle{ h_{t-1} }[/math]. This yields stable, memory-friendly computations. The update equation mixes the previous hidden state [math]\displaystyle{ h_{t-1} }[/math] and a transformed input [math]\displaystyle{ x_t }[/math] using a diagonal recurrent weight [math]\displaystyle{ a }[/math]. This ensures the update is purely elementwise, minimizing overhead during inference.
(3) Achieves strong performance on language modeling while avoiding the large KV-cache burden of Transformers. Provides super-exponential memory (depending on the gating mechanism) and can, in practice, capture long-range dependencies.
Griffin (MLPs with a mixture of recurrent blocks and local attention)
(1) Motivation: Although recurrence alone scales well to long sequences, local patterns might still be easier to capture via (local) attention. Mixing the two yields a model that can leverage local attention for shorter-range patterns while relying on a fixed-size recurrent state for longer context.
(2) Local Attention: Replaces global attention with a sliding-window approach (no quadratic cost in sequence length). The local window ensures that each token attends only to a fixed number of past tokens.
(3) Interleaving Blocks: Griffin’s architecture interleaves the same RG-LRU-based recurrent block and a local attention block. This combination preserves efficiency (small recurrence state, bounded cache size) and matches or outperforms Transformers at the same scale.
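As a small illustration of the sliding-window idea in point (2), the helper below builds a causal local-attention mask; it is an assumed utility, not code from the paper.

```python
import torch

def local_attention_mask(seq_len, window):
    """Boolean mask for causal sliding-window attention: position t may attend
    only to positions in [t - window + 1, t]."""
    idx = torch.arange(seq_len)
    rel = idx[:, None] - idx[None, :]          # how far back each key position is
    return (rel >= 0) & (rel < window)         # causal and within the local window

mask = local_attention_mask(seq_len=8, window=3)
print(mask.int())
```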
Empirical Results and Main Findings
Power-Law Scaling
Hawk, Griffin, and a baseline Transformer (using multi-query attention, MQA) all exhibit a power-law relationship between validation loss and training FLOPs, resembling established Transformer scaling laws. Griffin consistently has slightly better loss–compute tradeoffs than the baseline Transformer at all scales tested.
Performance vs. Mamba and Llama-2
Hawk surpasses Mamba (a prior recurrent state-space model) on downstream tasks at 3B parameters despite training on half as many tokens (300B vs. 600B). Griffin (7B, 14B) matches or slightly outperforms Llama-2–7B/13B despite being trained on about seven times fewer tokens. Conclusion: The new architectures remain competitive on well-known benchmarks such as MMLU, HellaSwag, ARC, PIQA, and WinoGrande, even against popular open-source Transformers.
Implementation Challenges
RG-LRU’s diagonal recurrence is memory-bound on TPUs (and GPUs), as it performs elementwise updates rather than large matrix multiplications. The paper introduces a custom kernel (using Pallas on TPU-v3) that reads and writes hidden states in a linear scan and minimizes memory transfers, greatly speeding up training.
Speed Gains for Long Sequences
For short sequence lengths (e.g. 2K tokens), Griffin trains at a speed on par with the MQA baseline. As sequence lengths grow (4K, 8K), the MQA baseline slows down due to attention complexity, whereas the recurrent models maintain similar runtime. Hence, Hawk/Griffin gain a training-speed advantage at longer contexts.
Throughput and Latency Gains
Transformers (even MQA) rely on a KV cache growing linearly in sequence length. The overhead can become larger than the model parameters at long decode lengths. Hawk and Griffin keep a small, fixed-size recurrence state. This yields: (1) Lower latency: less reading/writing of large caches for extended contexts. (2) Higher throughput: bigger batch sizes fit into memory because the small recurrence state frees device memory otherwise used by attention KV caches.
Additional Observations (Long Context and Copy/Retrieve Tasks)
Extrapolation to Long Sequences
Hawk and Griffin can successfully exploit contexts longer than those used during training, showing better extrapolation than Transformers with typical positional embeddings (e.g. RoPE). Training them explicitly on 8K sequences (rather than 2K) further improves performance at longer contexts.
Copying/Retrieval Skills
Hawk alone may learn copy tasks but can sometimes lag behind Transformers in speed of learning. This matches prior observations about slow adaptation of purely recurrent state-space layers on synthetic tasks. Griffin (mixing local attention + recurrence) can learn copy tasks more quickly, similar to Transformers, and can extrapolate better to sequences beyond the training length.
Significance and Contributions
Challenging the Status Quo
This work demonstrates that RNN-based architectures can be trained at large scale as efficiently as Transformers, with the added benefit of small fixed-size hidden states.
High Performance at Fewer Tokens
Griffin achieves or exceeds Llama-2 performance while using substantially fewer training tokens, highlighting that it is not necessary to rely purely on full-sequence attention.
Practical Inference Advantage
Significantly reduced inference latency and improved throughput on long sequences is a major practical benefit in real-world applications when generating long responses and streaming data.
Blueprint for Hybrid Models
By combining local attention (for short-range patterns) with a gated linear recurrence (for stable, long-range capacity), Griffin strikes a balance that may inform future research in large language models beyond classical Transformers.
Group 5: Griffin: Mixing Gated Linear Recurrences with Local Attention for Efficient Language Models
Presenters
Guoqian Li, Xinlei Xu, Wenning Xu
Paper Citation
De, S., Smith, S., Fernando, A., Botev, A., Cristian-Muraru, G., Gu, A., Haroun, R., Berrada, L., Chen, Y., Srinivasan, S., Desjardins, G., Doucet, A., Budden, D., Teh, Y. W., Pascanu, R., De Freitas, N., Gulcehre, C. (2024). Griffin: Mixing Gated Linear Recurrences with Local Attention for Efficient Language Models. arXiv. arxiv.org/pdf/2402.19427
Introduction
Researchers at Google DeepMind have introduced a new approach to language modelling that combines the strengths of recurrent neural networks (RNNs) and Transformers. Their work, titled "Griffin: Mixing Gated Linear Recurrences with Local Attention for Efficient Language Models," presents a novel architecture that aims to overcome the limitations of traditional methods.
Background
RNNs were foundational in the early days of deep learning and natural language processing, demonstrating success in various applications, including machine translation. However, the Transformer architecture has since become dominant, achieving superior performance and hardware efficiency. Despite their success, Transformers face challenges in scaling to long sequences due to the computational demands of global attention and the increasing memory required for the Key-Value (KV) cache during inference.
Main Idea
The authors propose a hybrid model called "Griffin" that mixes gated linear recurrences with local attention mechanisms. This design aims to achieve the efficiency of RNNs in handling long sequences while maintaining the performance of Transformers. The core component of their recurrent architecture is a novel gated linear recurrent layer called the Real-Gated Linear Recurrent Unit (RG-LRU).
Experiments
The researchers conducted several experiments to evaluate their models:
• They compared the scaling behaviour of their models (Hawk and Griffin) against a Multi-Query Attention (MQA) Transformer baseline, examining the relationship between held-out loss and training FLOPs.
• They assessed the models' performance on downstream tasks, comparing them to Mamba-3B and Llama-2.
• They measured training speeds on TPU-v3 devices.
• They evaluated inference speed, considering latency and throughput.
• They tested the models' ability to handle long contexts and perform copying and retrieval tasks.
Results
The key findings of the paper include:
• Griffin demonstrates comparable scaling performance to Transformers.
• Griffin matches the performance of Llama-2 while being trained on significantly less data.
• Griffin and Hawk exhibit comparable training efficiency to Transformers on TPU-v3s.
• Griffin achieves higher throughput and lower latency during inference, especially with longer sequences.
• Griffin demonstrates strong extrapolation capabilities on long sequences and performs well on copying and retrieval tasks.
In conclusion, the Griffin architecture presents a promising direction for language models, offering a balance between performance, efficiency, and the ability to handle long-range dependencies in sequences.
Group 5: Griffin: Mixing Gated Linear Recurrences with Local Attention for Efficient Language Models
Presenters
Guoqian Li, Xinlei Xu, Wenning Xu
Paper Citation
De, S., Smith, S., Fernando, A., Botev, A., Cristian-Muraru, G., Gu, A., Haroun, R., Berrada, L., Chen, Y., Srinivasan, S., Desjardins, G., Doucet, A., Budden, D., Teh, Y. W., Pascanu, R., De Freitas, N., Gulcehre, C. (2024). Griffin: Mixing Gated Linear Recurrences with Local Attention for Efficient Language Models. arXiv. arxiv.org/pdf/2402.19427
Research Motivation
Recurrent Neural Networks (RNNs) compress the entire sequence into a fixed-size hidden state, which is updated through iterations.
Transformers outperform RNNs by employing multi-layer perceptrons (MLPs) and multi-head attention (MHA). The complexity of global attention in Transformers is quadratic, while the Key-Value (KV) cache grows linearly with sequence length. Even with Multi-Query Attention (MQA), the cache continues to grow linearly with the sequence length.
Contribution
RG-LRU layer: a novel gated linear recurrent layer to replace MQA
Hawk: MLPs with recurrent blocks
Griffin: MLPs with a mixture of recurrent blocks and local attention
Key Findings
1. The held-out loss decreases as more training FLOPs are used. Griffin achieves slightly lower held-out loss across all model sizes.
2. Improved performance. Hawk-3B outperforms Mamba-3B on downstream tasks. Griffin-7B and Griffin-14B perform similarly to Llama-2 but were trained on approximately seven times fewer tokens.
3. Comparable training efficiency to Transformers on TPU-v3.
4. Griffin achieves significantly higher throughput than MQA transformers.
5. Griffin performs better when evaluated on sequences longer than those seen during training.
6. Griffin performs less effectively than Transformers on copy and exact retrieval tasks without fine-tuning.
Group 5: Griffin: Mixing Gated Linear Recurrences with Local Attention for Efficient Language Models
Presenters
Guoqian Li, Xinlei Xu, Wenning Xu
Paper Citation
De, S., Smith, S., Fernando, A., Botev, A., Cristian-Muraru, G., Gu, A., Haroun, R., Berrada, L., Chen, Y., Srinivasan, S., Desjardins, G., Doucet, A., Budden, D., Teh, Y. W., Pascanu, R., De Freitas, N., Gulcehre, C. (2024). Griffin: Mixing Gated Linear Recurrences with Local Attention for Efficient Language Models. arXiv. arxiv.org/pdf/2402.19427
Summary
1. Introduction of RG-LRU Layer
The paper proposes the Real-Gated Linear Recurrent Unit (RG-LRU), a novel gated linear recurrent layer. This layer combines the stability of linear recurrences with input-dependent gating mechanisms inspired by LSTMs and GRUs, enhancing the model's ability to handle long sequences efficiently.
2. Hybrid Architecture (Griffin)
The authors introduce Griffin, a hybrid model that integrates RG-LRU layers with local attention mechanisms. This combination leverages the strengths of both recurrent neural networks (efficient long-sequence handling) and attention mechanisms (local context modeling), achieving superior performance compared to pure RNNs or Transformers.
3. Scalability and Efficiency
The paper demonstrates that Griffin and Hawk (a pure RNN variant) scale efficiently with model size, following power-law scaling similar to Transformers. Griffin matches or exceeds the performance of larger models like Llama-2 while being trained on significantly fewer tokens (6 times fewer).
4. Hardware Optimization
The authors address the challenge of efficiently training diagonal RNNs on TPUs by developing a custom Pallas kernel for the RG-LRU layer. This optimization minimizes memory transfers, achieving near 3x speedup over naive implementations and ensuring competitive training speeds with Transformers.
5. Inference Advantages
Griffin and Hawk exhibit lower latency and higher throughput during inference, especially for long sequences, due to their fixed-size state (unlike the linearly growing KV cache in Transformers). This makes them practical for real-time applications and large-scale deployments.
6. Long-Context Extrapolation
The models show remarkable ability to extrapolate to sequences much longer than those seen during training. Griffin, in particular, maintains strong performance on long-context tasks, outperforming Transformers in such scenarios.
7. Copying and Retrieval Capabilities
The paper explores the models' performance on synthetic tasks like selective copying and induction heads. Griffin, with its hybrid design, matches Transformer performance on these tasks, while pure RNNs (Hawk) lag behind. This highlights the importance of combining recurrence and attention for such capabilities.
8. Empirical Validation
Extensive experiments validate the models' performance across various benchmarks (MMLU, HellaSwag, etc.), demonstrating that Griffin achieves competitive or better results than state-of-the-art baselines, even with reduced computational budgets.
Impact
The work advances the field by offering a viable alternative to Transformers, particularly for long-sequence tasks, with significant improvements in training and inference efficiency. The hybrid design of Griffin sets a new benchmark for balancing performance and computational cost in large language models.
Group 6 Presentation: Learning to (Learn at Test Time): RNNs with Expressive Hidden States
Presented by:
Zhiyang Cheng and Pingchu Zhang
Paper Citation
Sun, Y., Li, X., Dalal, K., Xu, J., Vikram, A., Zhang, G., Dubois, Y., Chen, X., Wang, X., Koyejo, S., Hashimoto, T., & Guestrin, C. (2024). Learning to (Learn at Test Time): RNNs with Expressive Hidden States. arXiv. https://doi.org/10.48550/arXiv.2407.04620.
Summaries of key points
This paper revisits the traditional role of RNN hidden states, proposing that they can do more than just store information—they can enable learning at test time. The authors introduce a method where a hypernetwork dynamically generates the RNN’s weights as a function of its hidden state, allowing the model to adapt its behavior on the fly. This gives rise to what they call expressive hidden states, which encode both memory and the capacity to steer the model’s future updates. The approach effectively blurs the line between training and inference, treating test-time prediction as a form of continual adaptation. This results in stronger performance in settings like few-shot learning and online learning, where flexibility and rapid adaptation are crucial. Rather than relying on explicit optimization at test time (as in typical meta-learning setups), the RNN itself becomes the learner, continuously reshaping its internal dynamics based on the sequence it's processing.
While innovative, the method introduces nontrivial architectural and computational overhead. The use of a hypernetwork to produce weights at every time step means the model must manage a more complex parameter space and could become less scalable for long sequences or larger models. There's also the risk of instability, since small changes in the hidden state can lead to large changes in the generated weights. Regularization and careful design are needed to prevent the model from diverging. Another limitation is that while the paper shows strong performance on synthetic and controlled few-shot learning tasks, it doesn’t extensively benchmark on more complex natural language or real-world sequential data, leaving questions about generalization and practicality open.
Clear explanations to aid understanding
In a standard RNN, the weights are fixed during inference—you feed in tokens or sequence elements, and the hidden state updates based on those fixed rules. What this paper suggests is: what if the hidden state itself could influence the rules? So instead of always using the same weights, the RNN can generate new ones depending on what it's seen so far. This is done using a hypernetwork—a small neural network that outputs the weights for the main RNN. So as the RNN processes a sequence, it effectively reshapes its own behavior to fit the task or data distribution it's encountering. It’s like the RNN is learning while it’s making predictions, adapting in real-time to maximize performance without needing gradient descent at test time.
Group 6 Presentation: Learning to (Learn at Test Time): RNNs with Expressive Hidden States
Presented by:
Zhiyang Cheng and Pingchu Zhang
Paper Citation
Sun, Y., Li, X., Dalal, K., Xu, J., Vikram, A., Zhang, G., Dubois, Y., Chen, X., Wang, X., Koyejo, S., Hashimoto, T., & Guestrin, C. (2024). Learning to (Learn at Test Time): RNNs with Expressive Hidden States. arXiv. https://doi.org/10.48550/arXiv.2407.04620.
Summaries of Key Points
Introduction of a New Research Question
This paper investigates whether incorporating sentence-level context—information from surrounding sentences—can improve the quality of translations produced by Statistical Machine Translation (SMT) systems. Traditional SMT models typically translate each sentence independently, ignoring the potential benefits of broader contextual information. The authors aim to quantify how much context matters and to what extent it can enhance translation fluency, coherence, and accuracy.
Design of a Context-Aware Re-Ranking Model
To evaluate the usefulness of sentence-level context, the authors develop a re-ranking model that operates on the n-best translation outputs of a baseline SMT system. This model uses a discriminative classifier trained to distinguish between better and worse hypotheses using features that include lexical overlap, syntactic consistency, and contextual similarity with adjacent sentences. By integrating these features, the system is able to promote translations that are more coherent within the broader discourse.
Use of Diverse and Realistic Datasets
The authors test their model on two datasets with naturally occurring multi-sentence structures: Europarl (parliamentary proceedings) and OpenSubtitles (dialogue-driven subtitles). These corpora represent both formal and conversational genres, providing a comprehensive testbed for evaluating the effectiveness of context. The subtitle data, in particular, presents challenges such as short, ambiguous sentences that strongly benefit from contextual cues.
Evaluation Through Both Automated and Human Measures
The proposed system shows consistent, though modest, improvements in BLEU scores compared to the baseline. However, human evaluation reveals clearer gains in fluency, referential consistency, and discourse-level cohesion. These results suggest that standard metrics may underestimate the value of context and highlight the importance of human judgment in translation assessment.
Contributions to Future Directions in MT
While the overall performance boost is not dramatic, this paper plays an important role in shifting attention toward discourse-aware translation. It lays the groundwork for future research in context modeling, which later becomes central to neural machine translation approaches. The authors also advocate for more nuanced evaluation techniques that capture translation quality beyond sentence-level accuracy.
Constructive Critique and Review
This paper provides an insightful early investigation into the role of sentence-level context in improving the output of statistical machine translation systems. The authors tackle an important problem: the lack of inter-sentential coherence in SMT, which treats each sentence independently. Their proposed method—context-aware reranking of translation hypotheses using a discriminative classifier—is both innovative and practical, as it builds on existing SMT outputs rather than requiring system retraining.
A major strength of the paper lies in its thoughtful experimental design. By selecting two distinct corpora, Europarl and OpenSubtitles, the authors ensure that the method is evaluated in both formal and conversational settings. This choice highlights the stronger impact of context in domains with high ambiguity and short utterances, such as subtitles. The integration of automatic and human evaluation adds further depth to the analysis, revealing that improvements in fluency and coherence may be underrepresented by standard metrics like BLEU.
However, there are limitations that reduce the generalizability and interpretability of the findings. The classifier’s performance is difficult to isolate due to limited ablation or feature-wise analysis. While the feature set is described, the contribution of individual context-related features remains unclear. A clearer breakdown of which contextual signals are most influential would have strengthened the practical implications of the work.
Furthermore, the improvements reported are modest in terms of BLEU scores, raising questions about the tradeoff between additional model complexity and measurable gains. The paper also predates neural machine translation, and while it was forward-thinking at the time, some of its techniques may appear limited by today’s standards. Nonetheless, the core insight—that context contributes meaningfully to translation quality—is validated and influential.
Overall, this is a well-motivated and carefully executed study that helped shift attention in the MT community toward discourse-aware modeling. Its methodological clarity and early focus on contextual coherence paved the way for future advances in both evaluation and translation architecture.
Group 6 Presentation: Learning to (Learn at Test Time): RNNs with Expressive Hidden States
Presented by:
Zhiyang Cheng and Pingchu Zhang
Paper Citation
Sun, Y., Li, X., Dalal, K., Xu, J., Vikram, A., Zhang, G., Dubois, Y., Chen, X., Wang, X., Koyejo, S., Hashimoto, T., & Guestrin, C. (2024). Learning to (Learn at Test Time): RNNs with Expressive Hidden States. arXiv. https://doi.org/10.48550/arXiv.2407.04620.
Summaries of key points
Goal: In Test-Time Training, make the hidden state into a small model that can be updated to improve the sequence modeling ability.
Background: The hidden state of RNNs usually has a fixed dimension, which limits their expressiveness.
Methodology: At each step, the weights W are updated with a gradient step on a self-supervised task. The dual form turns the multi-step updates into a single matrix operation.
Result: On short sequences, the TTT models behave similarly to existing methods. On long sequences, TTT-Linear and TTT-MLP are significantly better than Transformer and Mamba. TTT-Linear's inference speed is close to Mamba's and faster than the Transformer's.
Constructive critiques or reviews
The presentation is clearly structured, and the slides include pictures and diagrams that help listeners understand the material.
Turning on the camera made the presentation easier to follow.
The delivery could be somewhat more fluent.
Clear explanations to aid understanding
TTT layer: learns directly on the test sequence, with the update process implemented through self-supervised learning.
Efficiency optimization: computational efficiency is improved with mini-batch TTT and the dual form.
Benefits of using dual form:
- Reduces memory consumption by not storing intermediate gradient matrices explicitly.
- Maximizes GPU/TPU hardware utilization by using matrix multiplications instead of sequential outer products.
Mamba: uses a state-space model to capture long-range dependencies.
Group 6 Presentation: Learning to (Learn at Test Time): RNNs with Expressive Hidden States
Presented by:
- Pingchu Zhang
- Zhiyang Cheng
Paper Citation
Sun, Y., Li, X., Dalal, K., Xu, J., Vikram, A., Zhang, G., Dubois, Y., Chen, X., Wang, X., Koyejo, S., Hashimoto, T., & Guestrin, C. (2024). Learning to (Learn at Test Time): RNNs with Expressive Hidden States. arXiv. https://doi.org/10.48550/arXiv.2407.04620
Background
For modern RNNs, performance in long context is limited by the expressive power of their fixed-size hidden state. Hence, the authors introduce test-time training (TTT).
Technical Contributions
- Introduce TTT layers, where the hidden state is a model and the update rule is self-supervised learning, offering a new research direction.
- TTT-Linear, a simple implementation of TTT layers, outperforms Transformers and Mamba in evaluations.
- Improve the hardware efficiency of TTT layers through mini-batch TTT and the dual form, making TTT-Linear already a practical building block for LLMs.
Methodology
The key idea is to make the hidden state itself a model with weights, and the update rule a gradient step on the self-supervised loss. Then updating the hidden state on a test sequence is equivalent to training the model at test time.
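A minimal sketch of this idea is given below, assuming the simplest possible setting: the inner model f is linear (as in TTT-Linear), the self-supervised task is plain reconstruction of the token itself, and the step size eta is an arbitrary illustrative value. The paper's actual formulation uses learnable projections for the training, label, and test views, which are covered further down.
<pre>
import numpy as np

rng = np.random.default_rng(1)
d, T, eta = 8, 20, 0.1          # token dim, sequence length, inner-loop step size (assumed)

W = np.zeros((d, d))            # the hidden state IS the weights of a small model f(x) = W x
outputs = []
for t in range(T):
    x_t = rng.normal(size=d)
    # Simplified self-supervised inner loss: reconstruct x_t from x_t itself,
    # l(W; x_t) = ||W x_t - x_t||^2, with gradient 2 (W x_t - x_t) x_t^T.
    grad = 2.0 * np.outer(W @ x_t - x_t, x_t)
    W = W - eta * grad          # "training at test time": one gradient step per token
    outputs.append(W @ x_t)     # output rule z_t = f(x_t; W_t)
print(np.round(outputs[-1][:4], 3))
</pre>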
Training a network with TTT layers
- The larger network is trained as the outer loop, while the weights within each TTT layer are trained as the inner loop.
- TTT layers can replace RNN or self-attention layers in any network architecture. Training a network with TTT layers also works the same way as training any other language model.
Learning a self-supervised task for TTT
Some outer-loop parameters are added to make this task learnable.
The input [math]\displaystyle{ x_t }[/math] is transformed using a learnable matrix [math]\displaystyle{ \theta_K }[/math] to create a low-rank projection (training view) [math]\displaystyle{ \tilde{x}_t = \theta_K x_t }[/math].
The reconstruction label is another low-rank projection [math]\displaystyle{ \theta_V x_t }[/math], which can differ from the input. A test view [math]\displaystyle{ \theta_Q x_t }[/math] is also created.
The new self-supervised loss is [math]\displaystyle{ \ell(W; x_t) = \|f(\theta_K x_t; W)-\theta_V x_t\|^2 }[/math], and the output rule is modified to [math]\displaystyle{ z_t = f(\theta_Q x_t; W_t) }[/math].
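A small sketch of this multi-view formulation, assuming a linear inner model, randomly initialized low-rank projections (in practice the projections [math]\displaystyle{ \theta_K, \theta_V, \theta_Q }[/math] are learned end-to-end in the outer loop), and an arbitrary inner step size:
<pre>
import numpy as np

rng = np.random.default_rng(2)
d, k, T, eta = 16, 4, 10, 0.1   # token dim, projection rank, length, step size (assumed)

theta_K = rng.normal(scale=0.3, size=(k, d))   # training view
theta_V = rng.normal(scale=0.3, size=(k, d))   # reconstruction-label view
theta_Q = rng.normal(scale=0.3, size=(k, d))   # test view
W = np.zeros((k, k))                            # inner model f(u) = W u

for t in range(T):
    x_t = rng.normal(size=d)
    u, v, q = theta_K @ x_t, theta_V @ x_t, theta_Q @ x_t
    # Inner loss l(W; x_t) = ||W (theta_K x_t) - theta_V x_t||^2
    grad = 2.0 * np.outer(W @ u - v, u)
    W = W - eta * grad
    z_t = W @ q                                 # output rule z_t = f(theta_Q x_t; W_t)
print(np.round(z_t, 3))
</pre>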
Summaries of Key Points
Motivation: Compressing Long Context Efficiently
Traditional Transformers handle large contexts by storing every token in a Key-Value cache, which grows linearly with sequence length and makes inference complexity scale quadratically. Modern recurrent neural networks (RNNs) like Mamba sidestep storing the entire context by having a fixed-size hidden state, which leads to linear time complexity. However, RNNs often struggle to exploit very long contexts because the fixed-size hidden state must compress a large amount of information. The authors propose a new design, Test-Time Training (TTT), that treats the hidden state as a small learnable model trained via a self-supervised loss on each incoming token—even at test time.
TTT Layers: Hidden State as a Learner
The paper reframes any sequence-modeling layer as “a hidden state plus an update rule.” For TTT layers, the hidden state is itself a small parametric or nonparametric model f, and the update rule is a step of gradient descent (or other training procedure) on each new token. Thus, at every token step, the hidden state is updated by “training” f on a self-supervised objective. Concretely, one might define a corruption or partial view of the token and train the parametric model to reconstruct the hidden or relevant aspects of the token.
Two Main Instantiations: TTT-Linear and TTT-MLP
The authors propose TTT-Linear, where the learner f is a simple linear mapping plus optional layer normalization and a residual connection. They also propose TTT-MLP, which uses a two-layer MLP as its learner, offering a more expressive hidden state. Both can be integrated into existing RNN-based or Transformer-based architectures in place of the usual self-attention or simple RNN blocks. Like other RNN layers, TTT layers compress all historical tokens into a fixed-size hidden state—but the learner can be updated more flexibly via gradient steps each time a new token arrives.
Efficiency Enhancements
Naively computing a gradient step per token would be too slow. Two key ideas improve hardware utilization:
1. **Mini-batch TTT** processes a batch of b tokens at once to parallelize the internal gradient steps. Smaller b yields more gradient steps (and better expressiveness) but can slow performance.
2. **A “dual form”** for TTT-Linear and TTT-MLP reworks the update and output computations into larger matrix multiplications, ensuring that modern accelerators (GPUs, TPUs) can exploit efficient batched operations.
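The rough sketch below shows why the dual form helps, under the simplifying assumption (used by mini-batch TTT) that all gradients within a mini-batch are taken at the weights from the start of that mini-batch, and with a linear inner model; it illustrates the algebra only and is not the paper's kernel.
<pre>
import numpy as np

rng = np.random.default_rng(3)
k, b, eta = 8, 16, 0.05                 # inner dim, mini-batch size, step size (assumed)
W0 = np.zeros((k, k))
U = rng.normal(size=(k, b))             # b training views theta_K x_t, stacked as columns
V = rng.normal(size=(k, b))             # b reconstruction labels theta_V x_t

# Primal view: b sequential outer-product updates, each using W0 (start-of-batch weights).
W_primal = W0.copy()
for t in range(b):
    W_primal -= eta * 2.0 * np.outer(W0 @ U[:, t] - V[:, t], U[:, t])

# Dual view: the same sum of outer products collapses into one matrix multiplication,
# which is far friendlier to GPUs/TPUs than b small sequential updates.
W_dual = W0 - eta * 2.0 * (W0 @ U - V) @ U.T

print(np.allclose(W_primal, W_dual))    # True
</pre>
Both paths give the same update, but the second is a single matrix multiplication over the whole mini-batch, which is what keeps accelerator utilization high.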
Empirical Results
On language-modeling benchmarks (the Pile and Books), TTT-Linear and TTT-MLP match or exceed strong baselines (Transformer and the modern RNN Mamba) across model scales (125M to 1.3B parameters). TTT-Linear typically does as well as Mamba in short context (2k tokens) but outperforms Mamba substantially in longer contexts (8k or 32k), demonstrating that the extra expressiveness helps exploit more tokens. TTT-MLP can be even more expressive at very long contexts but can be more memory intensive. The authors also show that TTT-Linear can train and infer efficiently in wall-clock time using a specialized GPU kernel, yielding near-constant inference latency as context grows (unlike the linear growth in Transformers).
Significance and Future Work
TTT recasts the hidden-state update in an RNN-like layer as explicitly training a miniature model at test time—essentially “learning to learn” from each incoming token. With further improvements in tasks (beyond simple reconstruction), hardware kernels, and more expressive hidden states, TTT-based architectures may offer a new path toward efficient, high-performing sequence models for extremely long contexts.
Group 6 Presentation: Learning to (Learn at Test Time): RNNs with Expressive Hidden States
Presented by:
Pingchu Zhang, Zhiyang Cheng
Paper Citation
Sun, Y., Li, X., Dalal, K., Xu, J., Vikram, A., Zhang, G., Dubois, Y., Chen, X., Wang, X., Koyejo, S., Hashimoto, T., & Guestrin, C. (2024). Learning to (Learn at Test Time): RNNs with Expressive Hidden States. arXiv. https://doi.org/10.48550/arXiv.2407.04620
Background
Recurrent Neural Networks (RNNs) are attractive for their linear time complexity, which makes them efficient, especially for long-context inputs. However, they’ve historically struggled to match the performance of Transformers on tasks like language modeling. One key limitation is the fixed-size hidden state of RNNs, which forces them to compress all past context into a compact representation. This compression becomes increasingly difficult as the context grows longer.
Recent RNN variants like Mamba have closed the gap in scaling performance, but they still hit a ceiling: their ability to improve predictions plateaus at long context lengths (e.g., beyond 16k tokens). Transformers, in contrast, continue to benefit from more context, although at a higher computational cost due to their quadratic scaling.
The authors suggest that this limitation is tied to the expressive capacity of the hidden state. Inspired by how large language models compress vast datasets into their weights through training, they explore whether a hidden state can itself be a learnable model, updated online, even during inference.
Main Idea
The core proposal is the Test-Time Training (TTT) layer, a new kind of sequence modeling layer where the hidden state is a model, and the update rule is a self-supervised learning step. Instead of simply storing a vector or matrix, the hidden state consists of the weights of a small model (like a linear function or a 2-layer MLP). These weights are updated at each time step using gradient descent based on a self-supervised loss.
Key points:
The update happens at test time, not just during training—hence “test-time training.”
The layer sees each input token as a new self-supervised learning opportunity, updating its internal model to better predict the next token.
This approach allows the hidden state to grow in complexity without growing in size—it gains depth by learning, not by storing.
Two instantiations are tested:
TTT-Linear, where the hidden state is a linear model.
TTT-MLP, where the hidden state is a 2-layer MLP.
This method can be used in place of RNN or attention layers, and is compatible with existing architectures. Despite its novel structure, it can be trained end-to-end like other language models.
To make this practical on hardware, the authors also design efficient mini-batch updates and a dual form of the forward pass that enables good GPU utilization. These tricks allow them to run TTT layers efficiently, even faster than Transformers in some regimes.
Experimental & Result
The authors evaluate TTT-Linear and TTT-MLP against two baselines: a strong Transformer and Mamba (a recent high-performing RNN). They focus on both performance and efficiency, testing across different model sizes and context lengths.
1. Short Context (2k and 8k tokens)
At 2k tokens, TTT-Linear, Mamba, and Transformer perform similarly.
At 8k tokens, TTT models outperform Mamba. This shows that as context grows, the test-time learning approach starts to shine.
TTT-MLP generally has better perplexity than TTT-Linear, but is slower due to its more complex hidden state.
2. Long Context (up to 32k tokens)
Experiments on the Books3 subset of The Pile show that Mamba's performance plateaus after 16k tokens.
In contrast, TTT models (especially TTT-MLP) continue to improve, similar to how Transformers behave.
TTT-MLP performs best at long context, consistent with its higher expressivity.
3. Latency and Efficiency
In terms of wall-clock time, TTT-Linear is already faster than Transformers at 8k tokens and matches Mamba.
For token generation (decode time), TTT-Linear and Mamba have much lower latency than Transformers.
These efficiency gains are achieved thanks to GPU-aware design, including the use of mini-batch updates and matrix-optimized dual formulations.
4. Scaling and FLOPs
TTT-Linear uses fewer FLOPs than both baselines at equivalent perplexity.
TTT models perform well under the same training compute budgets (following the Chinchilla recipe).
They also maintain quality under increasing model sizes—from 125M to 1.3B parameters.
Group 6 Presentation: Learning to (Learn at Test Time): RNNs with Expressive Hidden States
Presented by:
Zhiyang Cheng and Pingchu Zhang
Paper Citation
Sun, Y., Li, X., Dalal, K., Xu, J., Vikram, A., Zhang, G., Dubois, Y., Chen, X., Wang, X., Koyejo, S., Hashimoto, T., & Guestrin, C. (2024). Learning to (Learn at Test Time): RNNs with Expressive Hidden States. arXiv: https://doi.org/10.48550/arXiv.2407.04620
What’s the Big Idea?
Most RNNs have fixed-size hidden states—they crunch all previous context into a small box and just hope for the best. But what if the hidden state could *learn* and *adapt* as it reads? That’s what this paper is about: making the hidden state itself into a **tiny learnable model** that updates itself on the fly. Think of it like an RNN that’s learning while it's predicting. Cool, right?
They call this approach **Test-Time Training (TTT)**—the model gets smarter with every token it sees, using a self-supervised loss. It’s not just training before deployment; it's adapting in real-time.
Key Concepts
- **TTT Layer**: The hidden state isn’t just a vector anymore—it’s a model (like a mini linear function or MLP) that gets updated at each step.
- **TTT-Linear** and **TTT-MLP**: Two variations—one simple, one more expressive.
- **Dual-form + Mini-batch TTT**: Optimized versions that make training/inference more efficient on GPUs.
Why This Matters
RNNs are fast and efficient—but they've always been limited by their fixed memory. TTT gives them a much-needed upgrade: they can now change behavior on the fly without relying on massive attention layers or key–value caches.
This method:
- Helps models generalize better in few-shot and online learning scenarios.
- Beats Mamba and Transformers in long-context tasks (up to 32k tokens).
- Runs faster at inference time with lower memory use.
How It Works
1. Each token triggers a **self-supervised update** to the hidden state model.
2. The update is like a mini training step: it tweaks the model to better predict the next token.
3. This makes the RNN act like it’s doing meta-learning in real time.
Experimental Highlights
- **Short context (2k–8k tokens)**: TTT-Linear/MLP performs as well as or better than Transformers and Mamba.
- **Long context (up to 32k tokens)**: TTT significantly outperforms Mamba, especially in perplexity.
- **Latency & FLOPs**: TTT models are lean—faster inference, lower computational cost.
- **Scales well**: Works across model sizes from 125M to 1.3B parameters.
Strengths of the Paper
- Fresh perspective—blurs the line between training and inference.
- Strong empirical results—especially in long-context tasks.
- Hardware-friendly—designed to run efficiently on modern GPUs.
- Compatible with existing RNN/Transformer architectures.
Some Critiques & Considerations
- Could get unstable—small changes in input = big changes in model weights.
- Needs more testing on real-world, noisy datasets (beyond synthetic ones).
- Might need careful tuning for best performance (learning rate, loss scaling, etc.).
- The idea is deep—might take a few reads to fully grasp.
How It Connects to Other Work
- TTT builds on meta-learning, but goes a step further: it adapts during inference, not just training.
- It shares some goals with Mamba—efficient, long-context modeling—but does it differently.
- Could inspire future models that combine learning and inference in more seamless ways.
Group 6 Presentation: Learning to (Learn at Test Time): RNNs with Expressive Hidden States
Presented by:
Pingchu Zhang and Zhiyang Cheng
Paper Citation:
Sun, Y., Li, X., Dalal, K., Xu, J., Vikram, A., Zhang, G., Dubois, Y., Chen, X., Wang, X., Koyejo, S., Hashimoto, T., & Guestrin, C. (2024). Learning to (Learn at Test Time): RNNs with Expressive Hidden States. arXiv. https://doi.org/10.48550/arXiv.2407.04620
Summaries
This paper proposes a novel approach to enhance the expressiveness and adaptability of Recurrent Neural Networks (RNNs) by allowing the hidden state itself to learn at test time. The authors introduce the concept of Test-Time Training (TTT), where the hidden state is no longer a fixed vector but the parameters of a small model (e.g., a linear layer or MLP) that gets updated dynamically through self-supervised learning.
Traditional RNNs update hidden states using fixed rules during inference, but TTT enables each time step to adapt based on what the model sees, making the RNN itself the learner. This results in better performance for long sequences and tasks requiring rapid, online adaptation—blurring the line between training and inference.
Two versions of the TTT layer are implemented:
TTT-Linear: hidden state as a linear model
TTT-MLP: hidden state as a 2-layer MLP
The method is end-to-end trainable, computationally efficient due to optimizations like dual-form updates, and shows competitive or superior performance compared to Mamba and Transformers, especially at long context lengths.
Key Contributions
Expressive Hidden States: Redefines RNN hidden states as small learnable models updated at test time.
Self-Supervised Test-Time Learning: Treats each input token as a new learning opportunity to improve next-token prediction.
Dual-Form Optimization: Reformulates weight updates into a matrix-based approach, improving computational efficiency and reducing memory usage.
Efficiency with Performance: TTT-Linear runs faster than Transformers and matches Mamba in speed while outperforming both at longer context lengths (up to 32k tokens).
Scalability: TTT models scale well across different sizes (from 125M to 1.3B parameters) and training budgets, following the Chinchilla efficiency principles.
Constructive Critiques or Reviews
While the method introduces exciting adaptability, architectural complexity increases due to the use of hypernetworks and per-step updates.
Stability concerns arise, as small changes in the hidden state could lead to unpredictable behaviors. Careful regularization is needed.
Most benchmarks are conducted on controlled or synthetic tasks. Further validation on real-world NLP datasets would enhance the paper’s practical impact.
Although efficient, TTT-MLP introduces latency overhead compared to TTT-Linear, limiting its practicality for latency-sensitive applications.
Related Works
Mamba: A high-performance RNN that uses state-space models for long-range dependency, but lacks test-time adaptability.
Meta-Learning Approaches: TTT shares the spirit of meta-learning but avoids explicit test-time optimization.
HyperNetworks: The idea of dynamically generated weights draws on prior work in hypernetworks, but TTT applies it in an online, token-wise setting.
Gradient-Based Test-Time Adaptation (TTT++): Prior methods apply gradient steps at inference but often require task-specific objectives—TTT generalizes this by embedding learning into the RNN dynamics.
Efficient Transformer Variants: While methods like FlashAttention and Longformer improve Transformer scalability, they do not adapt during inference the way TTT does.
Group 6 Presentation: Learning to (Learn at Test Time): RNNs with Expressive Hidden States
Presented by:
Pingchu Zhang and Zhiyang Cheng
Paper Citation:
Sun, Y., Li, X., Dalal, K., Xu, J., Vikram, A., Zhang, G., Dubois, Y., Chen, X., Wang, X., Koyejo, S., Hashimoto, T., & Guestrin, C. (2024). Learning to (Learn at Test Time): RNNs with Expressive Hidden States. arXiv. https://doi.org/10.48550/arXiv.2407.04620
Summaries
To address the limited long-context performance of modern RNNs, which stems from the fixed size (and hence limited expressive power) of their hidden state, this paper proposes a new class of sequence modelling layers with linear complexity and an expressive hidden state.
Key Contributions
The paper introduces TTT layers, where the hidden state is a model and the update rule is self-supervised learning. Its simplest implementation, TTT-Linear, outperforms Transformers and Mamba in the paper's evaluations. The authors also improve the hardware efficiency of TTT layers through mini-batch TTT and the dual form, making TTT-Linear a practical building block for LLMs.
Constructive Critiques or Reviews
Open challenges remain in the memory I/O bottlenecks of TTT-MLP and in scaling to billion-parameter models and million-token contexts.
While this approach allows the models to adapt on the fly using only hidden-state dynamics, it may limit learning capacity compared to models that update weights or use attention mechanisms. Transformers, in comparison, model long-range dependencies more directly through self-attention, so TTT-style RNNs can still underperform on complex tasks requiring very long memory or global context. Nonetheless, the efficiency gains of this approach may be a step towards efficient sequence prediction.
Group 6 : Learning to (learn at test time)
Presented by :
Pingchu Zhang and Zhiyang Cheng
Overview
One of the main issues with the transformer architecture is the quadratic nature of the attention mechanism. Specifically, when computing [math]\displaystyle{ \text{Softmax}(QK^T) }[/math], the size of [math]\displaystyle{ QK^T }[/math] grows quadratically with the length of the text. As such, we may look to other architectures in the hope that they alleviate some of these concerns. For example, RNNs do not suffer from this issue (scaling with O(n) complexity), but they have a unique issue of their own, also related to sequence length: RNNs struggle to recall the contributions of earlier tokens as the sequence length grows. This is due to the nature of RNN layers, which require that context be compressed into a hidden state of fixed size. This limitation is explored in OpenAI's scaling-law paper (Kaplan et al., 2020), which showed that LSTMs (an RNN subtype) do not scale like transformers or make use of exceptionally long context. The authors introduce a solution to this compression issue in RNNs by introducing "TTT" layers, which they demonstrate outperform both Transformer and Mamba models in long-context settings.
Governing Principles
The authors design TTT (test-time training) layers, which update the hidden state at test time; they argue this is equivalent to training at test time, hence the name. They introduce two such layers, TTT-Linear and TTT-MLP (multilayer perceptron): in the former the hidden state is a linear model, and in the latter a two-layer MLP. The key difference from a naive RNN is that the system state is no longer a vector; instead it consists of the weights W of a linear or MLP layer, allowing far more information to be stored and retained. Further, because the hidden state is a parameterized model, it can be refined at test time with a self-supervised update.
Implementation Details
For the proposed TTT model, the output of a given layer is:
[math]\displaystyle{ z_t = f(x_t; W_t) }[/math], where [math]\displaystyle{ z_t }[/math] is the output, [math]\displaystyle{ x_t }[/math] the input token, and [math]\displaystyle{ W_t }[/math] the learned model parameters. The key innovation is the addition of self-supervised, test-time learning: the model assesses its own ability to retrieve information about the sequence given its internal state, and updates that internal state to retain more information. This allows the model to take full advantage of the learnable parameters that replace the standard hidden vector, using them as an efficient means of compressing information.
Discussion / Conclusions
With TTT layers the authors show that RNNs can compete with SOTA transformer models (at the time of publishing). Specifically, they demonstrate, using perplexity measured against sequence length, that TTT layers stave off the usual pitfalls of RNNs/LSTMs and effectively capture longer-range token dependencies. They further show that TTT-MLP performs slightly better than its linear counterpart, but this difference is less relevant outside of exceptionally large contexts.
Group 6 Presentation: Learning to (Learn at Test Time): RNNs with Expressive Hidden States
Key Innovations and Highlights:
1. Test-Time Training Layers:
The hidden state isn't a static vector anymore but rather a mini-model itself—like a linear layer or a small neural network.
Updating this hidden state involves a gradient descent step on a self-supervised learning objective, effectively training the model incrementally on every new token it encounters—even during inference.
2. Two Variants of TTT:
TTT-Linear: Uses a linear model as the hidden state.
TTT-MLP: Employs a two-layer MLP, offering more expressive power.
3. Linear Complexity with Transformer-level Expressivity:
Both TTT models maintain linear complexity, crucial for efficiency in handling long sequences.
Unlike traditional RNNs, which plateau in performance beyond certain sequence lengths, TTT layers continue to reduce perplexity effectively as the context grows (demonstrated clearly at 32k tokens).
4. Efficiency and Performance:
TTT-Linear outperforms Transformers and matches/exceeds Mamba, especially at long context (8k and beyond), while significantly reducing the number of computational operations (FLOPs).
With optimizations like mini-batch TTT and a dual-form computation, TTT-Linear matches the computational efficiency of the most advanced RNNs like Mamba.
Enhancing Understanding: What Makes TTT Different?
Think of traditional RNNs as short-term memory devices that quickly become overwhelmed with too much information. Transformers, on the other hand, carry around an ever-growing notebook (Key-Value caches) that's comprehensive but computationally expensive to flip through. TTT layers are like having a dynamic notebook whose pages continually rewrite themselves, efficiently storing key insights learned from recent information and discarding less relevant details.
By treating inference as continuous "micro-training," the model consistently refines its internal understanding, maintaining richer context representations without the typical constraints of fixed-size hidden states.
Constructive Critiques and Suggestions:
1. Memory and I/O Bottlenecks in TTT-MLP: Although highly promising in terms of performance, TTT-MLP suffers from increased memory overhead. Addressing this could further unlock its potential. Possible future work could include more optimized implementations or exploring more compact yet expressive intermediate states.
2. Hyperparameter Sensitivity: Mini-batch size in TTT training greatly impacts performance and computational efficiency. Further research might systematically explore adaptive strategies for selecting this hyperparameter dynamically based on context length or sequence complexity.
3. Backbone Compatibility: The authors show better performance using the Mamba backbone architecture, which involves temporal convolutions. It raises a question: Would TTT layers achieve similar gains when integrated with alternative backbones or hybrid approaches?
Connections to Related Work:
Fast Weight Programmers and Hebbian Networks: The concept of updating internal model parameters dynamically has been explored before (e.g., Fast Weight Networks, linear attention models). However, explicitly integrating gradient descent steps as part of inference significantly expands the practical and theoretical possibilities.
Modern RNN Architectures (e.g., Mamba, RWKV): TTT can be viewed as the next evolution of modern RNNs, building upon recent innovations in structured state-space models but overcoming their inherent limitations in handling long contexts.
Group 7 Presentation: Attention with Markov: A Framework for Principled Analysis of Transformers via Markov Chains
Presented by:
Jonathan Gallagher and Mariya Anashkina
Paper Citation
A. V. Makkuva et al., “Attention with Markov: A Framework for Principled Analysis of Transformers via Markov Chains,” 2024, arXiv. doi: 10.48550/ARXIV.2402.04161.
Summaries of key points
This paper offers a theoretical framework for analyzing transformers using Markov chains, providing a new way to understand how information flows and dependencies are formed across layers. The core idea is to view the self-attention mechanism as inducing a Markov process over the sequence positions, where each attention head defines a transition probability matrix. By interpreting attention through this lens, the authors derive principled explanations for phenomena like context aggregation, token mixing, and the effects of multiple layers and heads. They show that stacking layers corresponds to composing Markov transitions, meaning that deeper transformers can be understood as performing longer-range probabilistic walks over the input. This allows them to make quantitative predictions about attention spread and mixing time, and helps formalize intuitions about how depth and attention head structure influence model behavior. Importantly, the framework applies without modifying the architecture—it’s purely analytical, making it a useful tool for understanding model internals in a mathematically grounded way.
The abstraction into Markov chains may oversimplify certain aspects of how attention operates—particularly when attention scores are data-dependent and influenced by non-linear operations (e.g., softmax and masking). The analysis tends to assume relatively clean or idealized cases, and may not fully capture the nuances of attention in real-world settings like language modeling with positional encoding or context-specific weighting. While the framework is insightful, it’s primarily useful for interpretability and analysis, not for improving model performance directly. Finally, the notion of interpretability here is mathematical, but not necessarily human-friendly—it gives you equations, not explanations in natural language.
Clear explanations to aid understanding
Think of each attention head as deciding where to “move” next in the sequence—like a random walker choosing the next word to look at. If you treat those attention weights as transition probabilities, then attention turns into a kind of Markov process, where information flows step-by-step based on those probabilities. Stacking layers is like walking further in the sequence graph, and having multiple heads is like having multiple walkers with different preferences. This view lets you talk about things like mixing time—how fast information spreads—or how different tokens influence each other over multiple layers. It’s a powerful way to bridge probabilistic modeling and deep learning, especially when trying to reason about what attention is doing beyond just visualizing heatmaps.
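A toy sketch of this picture, using a random causal attention matrix purely for illustration: each row of the softmaxed score matrix sums to one, so it can be read as a Markov transition matrix over positions, and stacking layers corresponds to taking matrix powers.
<pre>
import numpy as np

rng = np.random.default_rng(4)
n = 6                                   # toy sequence length

scores = rng.normal(size=(n, n))
scores = np.tril(scores) + np.triu(np.full((n, n), -1e9), k=1)   # causal mask
A = np.exp(scores) / np.exp(scores).sum(axis=1, keepdims=True)   # row softmax: row-stochastic

# One layer of attention ~ one step of a Markov chain over positions;
# stacking L layers ~ composing L transitions, i.e. the matrix power A^L.
for L in (1, 2, 4):
    walk = np.linalg.matrix_power(A, L)
    print(f"L={L}: influence of token 0 on the last position = {walk[-1, 0]:.3f}")
</pre>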
Group 7 Presentation: Attention with Markov: A Framework for Principled Analysis of Transformers via Markov Chains
Presented by:
Jonathan Gallagher and Mariya Anashkina
Paper Citation
A. V. Makkuva et al., “Attention with Markov: A Framework for Principled Analysis of Transformers via Markov Chains,” 2024, arXiv. doi: 10.48550/ARXIV.2402.04161.
Summary of Key Points
This paper introduces a theoretical framework that interprets transformers through the lens of Markov chains. The central idea is to model the self-attention mechanism as inducing a Markov process over token positions, where each attention head is seen as defining a transition probability matrix.
This interpretation allows researchers to:
- Quantify how information propagates across layers (context aggregation).
- Understand how attention heads mix token information.
- Relate model depth to longer-range probabilistic walks over inputs.
Layer stacking is interpreted as composing Markov transitions, and deeper networks thus perform longer walks over input sequences. This allows for formal predictions about mixing times and how tokens influence one another across layers. The framework is analytical and does not alter the transformer architecture, making it useful for interpretability.
However, it assumes idealized settings (e.g., fixed softmax structure, no masking or complex context-dependence), which may limit its real-world applicability. It’s mainly an interpretability tool and doesn't improve transformer performance directly. Moreover, while it provides a mathematically grounded view, its interpretability is theoretical rather than intuitive or human-readable.
Clear Explanation for Better Understanding
Imagine each attention head as a random walker choosing which token to move to next. The attention scores then define transition probabilities, turning attention into a kind of Markov process. As you stack layers, it’s like letting the walker take multiple steps.
Multiple attention heads mean different walkers, each with unique preferences. This model helps quantify how quickly information spreads (mixing time) and how tokens influence each other, going beyond mere attention heatmaps.
This probabilistic view bridges deep learning with classical stochastic processes, giving insights into how and why transformers work the way they do.
Group 7 Presentation: Attention with Markov: A Framework for Principled Analysis of Transformers via Markov Chains
Presented by:
- Jonathan Gallagher
- Mariya Anashkina
Paper Citation
A. V. Makkuva et al., “Attention with Markov: A Framework for Principled Analysis of Transformers via Markov Chains,” 2024, arXiv. doi: 10.48550/ARXIV.2402.04161.
Background and overview
Markov chains
Markov chains are probability models that predict the future state based on the current and prior states of a system. Language can be modeled as a high-order Markov process, an idea that was popularized by Claude Shannon in a 1948 paper and later became a building block for modern natural language processing. Transformer next-token predictions can be analyzed through this Markov lens.
In this paper, the authors use a simplified model: a two-state (binary) Markov chain that is in its stationary distribution, making the assumption that the Markov chain has been running for an extremely long time. The term stationary distribution refers to a distribution of states that a chain converges towards as the number of steps increases indefinitely; at that point, each subsequent distribution is the same as the previous one.
Objectives in this paper
The paper introduces a framework to systematically analyze (through the lens of a Markov chain) how transformers learn to model sequential data. The objective is to explore how a transformer succeeds or struggles on first order Markov chains, a distribution that only needs one step of memory. They also do a theoretical analysis of network architecture choices for transformers, and look at how this affects the loss landscape and learning dynamics.
Weight tying
This is one of the main architectural choices in transformers that are studied in this paper. This is the practice of using the same weights in the input and output layers. Firstly, this means that whether the model is trying to read or generate the text, the tokens' representations are the same -- this is done for consistency. Secondly, this can act as a form of regularization as it gives the model fewer parameters to learn.
Theoretical Results
Significance of [math]\displaystyle{ p+q }[/math]
Recall that we are working with binary, first-order Markov processes here. Let [math]\displaystyle{ p }[/math] be the probability that a state 0 will turn to 1, and let [math]\displaystyle{ q }[/math] be the probability that a state 1 will turn to 0. Consequently, probabilities [math]\displaystyle{ 1-p }[/math] and [math]\displaystyle{ 1-q }[/math] are the probabilities that states 0 and 1 will remain unchanged, respectively.
The quantity [math]\displaystyle{ p+q }[/math] is referred to as a switching factor, and is used to characterize the overall tendency of the chain to change state. When [math]\displaystyle{ p+q \lt 1 }[/math], the system is likely to stay in its current state. When [math]\displaystyle{ p+q \gt 1 }[/math], the system is likely to change its state, and therefore will exhibit oscillatory behaviour.
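A quick simulation sketch (with arbitrary example values of p and q) that checks the closed-form stationary distribution [math]\displaystyle{ \pi = \left(\tfrac{q}{p+q}, \tfrac{p}{p+q}\right) }[/math] and shows how the switching factor controls how often the chain changes state:
<pre>
import numpy as np

def simulate(p, q, steps=100_000, seed=0):
    # Simulate a binary chain with P(0 -> 1) = p and P(1 -> 0) = q.
    rng = np.random.default_rng(seed)
    x, switches, ones = 0, 0, 0
    for _ in range(steps):
        flip = rng.random() < (p if x == 0 else q)
        switches += flip
        x = 1 - x if flip else x
        ones += x
    return ones / steps, switches / steps

for p, q in [(0.1, 0.2), (0.8, 0.9)]:          # p+q < 1 (sticky) vs p+q > 1 (oscillatory)
    freq1, switch_rate = simulate(p, q)
    pi1 = p / (p + q)                          # stationary P(state = 1)
    print(f"p+q={p+q:.1f}: empirical P(1)={freq1:.3f} vs pi_1={pi1:.3f}, "
          f"switch rate={switch_rate:.3f}")
</pre>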
The authors of the paper provide proof of the existence of the global minimum as well as bad local minima in the loss landscape. Furthermore, they look at the existence of global minima and saddle points in the cases with vs. without weight tying.
Theorem 1 (Global minimum). Let the input sequence be [math]\displaystyle{ \{x_n\}_{n=1}^N \sim \bigl(\pi(p,q), P(p,q)\bigr) }[/math] for some fixed [math]\displaystyle{ (p,q)\in(0,1)^2. }[/math] Then for all [math]\displaystyle{ (p,q), }[/math] there exists a [math]\displaystyle{ \theta_{\star}\in\mathbb{R}^{D-d} }[/math] with an explicit construction such that it is a global minimum for the population loss [math]\displaystyle{ L(\cdot) }[/math].
In other words, the authors are able to prove that for the transformer, there exists an optimal parameter configuration.
Theorem 2 (Bad local minimum). Let the input sequence be [math]\displaystyle{ \{x_n\}_{n=1}^N \sim \bigl(\pi(p,q), P(p,q)\bigr) }[/math] for some fixed [math]\displaystyle{ (p,q)\in(0,1)^2. }[/math] If [math]\displaystyle{ p+q\gt 1, }[/math] there exists an explicit [math]\displaystyle{ \theta_{\pi}\in\mathbb{R}^{D-d} }[/math] such that it is a bad local minimum for the loss [math]\displaystyle{ L(\cdot) }[/math]
Theorem 3 (Global minimum). Consider the same setting as in Thm. 1. Then for all [math]\displaystyle{ (p, q), }[/math] if [math]\displaystyle{ \theta_{\star} = (e_{\star} = a_{\star}, \dots, b_{\star}) \in \mathbb{R}^{D-d} }[/math] is a global minimum for the loss [math]\displaystyle{ L(\cdot) }[/math] in the weight-tied scenario, then its extension [math]\displaystyle{ \bar{\theta}_{\star} = (\bar{e}_{\star}, \bar{a}_{\star}) \in \mathbb{R}^D }[/math] is also a global minimum for [math]\displaystyle{ L(\cdot) }[/math] in [math]\displaystyle{ \mathbb{R}^D }[/math] in the non-weight-tied case. Further, [math]\displaystyle{ \bar{\theta}_{\star} }[/math] satisfies the same properties (ii)-(iv) as in Thm. 1.
Theorem 4 (Saddle point). Consider the same setting as in Thm. 3. For [math]\displaystyle{ p + q \gt 1, }[/math] let [math]\displaystyle{ \theta_{\pi} = (e_{\pi} = a_{\pi}, \dots, b_{\pi}) \in \mathbb{R}^{D-d} }[/math] be the corresponding bad local minimum for the loss [math]\displaystyle{ L(\cdot) }[/math] in the weight-tied scenario. Then its extension [math]\displaystyle{ \bar{\theta}_{\pi} = (\bar{e}_{\pi}, \bar{a}_{\pi}) \in \mathbb{R}^D }[/math] is a saddle point for [math]\displaystyle{ L(\cdot) }[/math] in [math]\displaystyle{ \mathbb{R}^D }[/math] in the non-weight-tied case. Further, [math]\displaystyle{ \bar{\theta}_{\pi} }[/math] satisfies the same properties (ii)-(iv) as in Thm. 2.
Main Contributions
Setup for the theoretical framework in this paper
- The dataset is first-order and binary.
- The model is a single-layer transformer, with a single attention head and no layer norm.
- The loss function is the cross-entropy population loss (a small sketch of this setup follows the list). Note: the transformer model doesn't know that the data is Markovian; it is allowed to use the entire sequence's history, since it doesn't know that each step depends only on the previous one.
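As referenced in the setup above, here is a small sketch of that data-generation process for an illustrative choice of (p, q); it also prints the best achievable next-token cross-entropy, which is the conditional entropy [math]\displaystyle{ H(x_{n+1} \mid x_n) }[/math] that a model attaining the global minimum would reach.
<pre>
import numpy as np

def H(r):                                     # binary entropy in nats
    return -(r * np.log(r) + (1 - r) * np.log(1 - r))

def sample_chain(p, q, N, seed=0):
    # Binary first-order Markov chain started from its stationary distribution.
    rng = np.random.default_rng(seed)
    pi1 = p / (p + q)
    x = [int(rng.random() < pi1)]
    for _ in range(N - 1):
        flip = rng.random() < (p if x[-1] == 0 else q)
        x.append(1 - x[-1] if flip else x[-1])
    return np.array(x)

p, q = 0.3, 0.6                               # illustrative transition probabilities
seq = sample_chain(p, q, N=10_000)
pi0, pi1 = q / (p + q), p / (p + q)
optimal_ce = pi0 * H(p) + pi1 * H(q)          # loss of a model that recovers (p, q) exactly
print(f"sample mean={seq.mean():.3f} (pi_1={pi1:.3f}), "
      f"optimal cross-entropy={optimal_ce:.3f} nats")
</pre>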
Empirical results
Summary of this paper
1. The authors introduce a theoretical framework that models the data source as a Markov process. This allows them to study how transformers learn sequential structure, contrasting it with other approaches that treat training data simply as i.i.d.
2. Focusing on single-layer transformers trained for next-token prediction on first-order Markov data, the paper gives a detailed analysis of the cross-entropy loss surface. There is always a set of parameters that perfectly recovers the true Markov transition probabilities (thus achieving the global minimum). However, when the sum of the Markov chain's flipping probabilities exceeds one, the model can converge to parameters that simply predict the stationary distribution rather than the true transition probabilities (a numeric illustration follows this list). This phenomenon does not arise (or becomes only a saddle) when weights are untied or when the transformer has multiple layers.
3. Through experiments, they show that the theoretical findings match empirical behaviors. In particular: When weights are tied, the model may learn a constant (stationary) probability and fail to leverage sequential context if the transition probabilities are above a certain threshold. Removing weight tying—or increasing the transformer's depth—helps avoid such bad local minima.
4. The authors extend the analysis to higher-order processes. Surprisingly, simply increasing the transformer depth does not guarantee learning of higher-order transitions. They find, however, that restricting the attention window (rather than letting the network attend to all past tokens) dramatically improves learning of higher-order Markov patterns.
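As a numeric illustration of point 2, for an assumed switching chain with p + q > 1, the sketch below compares the expected cross-entropy of a model stuck at the stationary-distribution predictor with that of a model recovering the true transition probabilities:
<pre>
import numpy as np

def H(r):
    return -(r * np.log(r) + (1 - r) * np.log(1 - r))

def cross_entropy_const(c, p, q):
    # Expected next-token cross-entropy of a model that always predicts P(next = 1) = c,
    # averaged over the stationary distribution of the chain.
    pi0, pi1 = q / (p + q), p / (p + q)
    ce_from_0 = -(p * np.log(c) + (1 - p) * np.log(1 - c))
    ce_from_1 = -((1 - q) * np.log(c) + q * np.log(1 - c))
    return pi0 * ce_from_0 + pi1 * ce_from_1

p, q = 0.8, 0.8                               # p + q = 1.6 > 1 (oscillatory regime)
pi1 = p / (p + q)
bad_minimum = cross_entropy_const(pi1, p, q)  # predict the stationary distribution
global_minimum = (q / (p + q)) * H(p) + (p / (p + q)) * H(q)   # predict true transitions
print(f"stationary predictor: {bad_minimum:.3f} nats, true transitions: {global_minimum:.3f} nats")
</pre>
For p = q = 0.8 the stationary predictor pays about log 2 (roughly 0.693 nats) per token, while the true-transition predictor pays only about 0.500 nats; this gap is the difference between the bad local minimum and the global minimum.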
Related work
1. Analyzing Transformer Architectures:
Vaswani et al. (2017) introduced the Transformer architecture, which has since been extensively studied to understand its theoretical foundations and practical applications.
Rogers et al. (2020) provided a comprehensive analysis of BERT, a model based on the Transformer architecture, discussing its linguistic capabilities and limitations.
2. Combining Attention Mechanisms with Probabilistic Models:
Bahdanau et al. (2015) introduced the attention mechanism in neural machine translation, allowing models to focus on relevant parts of the input sequence dynamically.
Graves et al. (2014) proposed the Neural Turing Machine, combining neural networks with external memory, enabling models to learn algorithmic tasks.
Future direction
1. Enhanced Interpretability of Transformer Models:
Applying the Markov chain framework to dissect and visualize the internal workings of Transformers could lead to more interpretable models, aiding in identifying and mitigating biases.
2. Development of Hybrid Models:
Integrating Markovian structures with attention mechanisms may result in hybrid models that leverage the strengths of both approaches, potentially improving performance in tasks requiring sequential dependencies.
3. Theoretical Analysis of Attention Dynamics:
Further exploration into the mathematical properties of attention modelled as Markov processes could provide deeper insights into the stability and convergence of Transformer-based models.
Group 7 Review: An Analysis of Transformers via Markov Chains
Presented by:
Jonathan Gallagher and Mariya Anashkina
Paper Citation
A. V. Makkuva et al., “Attention with Markov: A Framework for Principled Analysis of Transformers via Markov Chains,” 2024, arXiv. doi: 10.48550/ARXIV.2402.04161.
Critiques Towards Weight Tying
Potential drawbacks of weight tying includes:
- Reduced Flexibility
- Challenges in optimization: weight tying sometimes introduces gradient conflicts or hinders training stability, especially if the tied layers have significantly different roles.
- Task-specific limitations: performance depends on the nature of the task.
Review
I thought this presentation made it really easy to see the connection between Markov chains and transformers when they are tasked with predicting future tokens. The lunch menu analogy made it really easy to see how a first-order Markov chain behaves, and was a great general idea to discuss before diving deeper into how Markov chains behave in the language world. Theoretical results were also well explained before linking them to empirical results and discussing why the empirical results behave the way they do.
The empirical results were also well explained, showing how removing weight tying or increasing model depth can help escape bad local minima and ensure the transformer's parameters reach the global minimum for first-order Markov chains, and how smaller context windows are required to escape bad local minima for higher-order Markov chains, even for deeper models (regardless of depth or weight tying).
I also thought this presentation was really well structured, and found that the slides were really easy to follow with great visual aids.
Group 7 Review: An Analysis of Transformers via Markov Chains
Presented by:
Jonathan Gallagher and Mariya Anashkina
Paper Citation
A. V. Makkuva et al., “Attention with Markov: A Framework for Principled Analysis of Transformers via Markov Chains,” 2024, arXiv. doi: 10.48550/ARXIV.2402.04161.
Key Contributions
- A new Markov chain framework for analyzing transformers: the sequential input is modelled as a Markov process.
- Theorems, with proofs, characterizing global minima and bad local minima.
- For single-layer transformers, characterized the loss landscape.
- Applied these findings to higher order Markov chains.
Key findings and observations:
- For first order Markov chains, weight tying and transition probabilities significantly affect the loss landscape. Weight tying is when the embedding and linear layers are tied.
- For single-layer transformers, weight tying may introduce bad local minima. To avoid this, they increase the depth of the transformer.
- For higher-order Markov chains, masking is necessary to correctly predict the probabilities. Increasing the depth for higher-order Markov chains does not significantly affect the performance.
Explanations for details
Single-layer Transformer: The single-layer transformer has single-head attention, with binary input.
Property of Markov chains: Future steps in a Markov chain depend only on the most recent m steps, and the transition probabilities are independent of position. The Markov chain reaches a steady state when it attains a stationary distribution [math]\displaystyle{ \pi }[/math] and keeps the same distribution at all future steps.
First-order Markov Chains: The next step depends only on the single previous step and is independent of all other past steps.
Masking: This limits the scope of the attention layer. The attention is changed from [math]\displaystyle{ y_n=x_n+W_O \sum_{i \in [n]} att_{n,i} \cdot W_V x_i \in \mathbb{R}^d }[/math] to [math]\displaystyle{ y_n=x_n+W_O \sum_{i=n-W+1}^{n} att_{n,i} \cdot W_V x_i }[/math], where W is the window size, i.e. the number of symbols the model attends to. Reducing W has been found to improve performance.
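A minimal sketch of this windowed attention, assuming standard single-head softmax attention with randomly initialized matrices; only positions [math]\displaystyle{ i = n-W+1, \dots, n }[/math] enter the sum, matching the masked formula above.
<pre>
import numpy as np

def windowed_attention(x, W_Q, W_K, W_V, W_O, window):
    # x: (n, d) token embeddings. Each position attends only to the last `window`
    # positions (itself included), i.e. the sum runs over i = t - window + 1, ..., t.
    n, d = x.shape
    Q, K, V = x @ W_Q.T, x @ W_K.T, x @ W_V.T
    y = np.empty_like(x)
    for t in range(n):
        lo = max(0, t - window + 1)
        scores = Q[t] @ K[lo:t + 1].T / np.sqrt(d)
        att = np.exp(scores - scores.max())
        att /= att.sum()
        y[t] = x[t] + W_O @ (att @ V[lo:t + 1])      # y_t = x_t + W_O sum_i att_{t,i} W_V x_i
    return y

rng = np.random.default_rng(5)
n, d = 12, 4
x = rng.normal(size=(n, d))
W_Q, W_K, W_V, W_O = (rng.normal(scale=0.5, size=(d, d)) for _ in range(4))
print(np.round(windowed_attention(x, W_Q, W_K, W_V, W_O, window=3)[-1], 3))
</pre>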
Related Works
There are many existing works that tried to understand more about transformer models:
- There are works that try to understand the transformer components (Nanda et al., Progress measures for grokking via mechanistic interpretability. In [math]\displaystyle{ \textit{The Eleventh International Conference on Learning Representations,} }[/math] 2023), but they lack theoretical guarantees.
- There is work focused on how transformer models learn semantic structure (Li et al., How do transformers learn topic structure: towards a mechanistic understanding, In [math]\displaystyle{ \textit{Proceedings of the 40th International Conference on Machine Learning} }[/math], 2023).
- There is work that uses optimization to understand the training dynamics and implicit biases of transformers trained with gradient descent (Tarzanagh et al., Transformers as support vector machines. In [math]\displaystyle{ \textit{NeurIPS 2023 Workshop on Mathematics of Modern Machine Learning} }[/math], 2023 and Tarzanagh et al., Max-margin token selection in attention mechanism. In [math]\displaystyle{ \textit{Thirty-seventh Conference on Neural Information Processing Systems} }[/math], 2023).
- There is work proposing that weight tying does not positively affect encoder-only transformer models (Chung et al., Rethinking embedding coupling in pre-trained language models. In [math]\displaystyle{ \textit{International Conference on Learning Representations,} }[/math] 2021).
Group 7 Review: An Analysis of Transformers via Markov Chains
Presented by:
Jonathan Gallagher and Mariya Anashkina
Paper Citation
A. V. Makkuva et al., “Attention with Markov: A Framework for Principled Analysis of Transformers via Markov Chains,” 2024, arXiv. doi: 10.48550/ARXIV.2402.04161.
Theoretical foundation
In 1948, Claude Shannon demonstrated that human communication can be reasonably approximated as a higher-order Markov process.
Research Objective
The authors propose a framework to systematically analyze how transformers learn to model sequential data using the perspective of Markov chains.
By generating synthetic data from controlled Markov processes, they can precisely define the relationship between data characteristics, model architecture, and learning performance.
They also aim to theoretically describe the loss landscape of transformers and determine how specific architectural decisions impact learning.
Additionally, they explore how model complexity and the order of Markov chains influence the model’s ability to capture sequential dependencies.
Key Methodology
1. Weight tying
2. Attention masking
Conclusions
For first-order Markov chains:
1. Transformers can readily learn the transition dynamics.
2. When p+q > 1 with weight tying, models may get stuck predicting the stationary distribution.
3. Removing weight tying or increasing model depth helps escape bad local minima.
For higher-order Markov chains:
1. Standard transformers struggle regardless of depth or weight tying.
2. Limiting the context window (masking) dramatically improves learning.
3. Surprisingly, deeper models require even smaller context windows.
Group 7 Review: An Analysis of Transformers via Markov Chains
Presented by:
Jonathan Gallagher and Mariya Anashkina
Paper Citation
A. V. Makkuva et al., “Attention with Markov: A Framework for Principled Analysis of Transformers via Markov Chains,” 2024, arXiv. doi: 10.48550/ARXIV.2402.04161.
Research Objective
The authors characterize the relationship between data properties, model architecture, and learning performance by generating synthetic data from controlled Markov processes. They characterize the loss landscape of transformers and identify how specific architectural choices affect learning.
Key Methodology
- Weight tying: use the same weights for the input embedding and the final output layer. This keeps a token's representation consistent across the model, leading to more coherent and efficient learning (a minimal sketch follows after this list).
- Study the loss landscape of transformers when dealing with first-order Markov chains (one step of memory)
- Then study a special class of second-order chains, where [math]\displaystyle{ X_{n+1} }[/math] is influenced only by [math]\displaystyle{ X_{n-1} }[/math]
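Weight tying, as used above, can be sketched in PyTorch by letting the output projection share the embedding matrix (a minimal illustrative snippet, not the paper's code):
<syntaxhighlight lang="python">
import torch.nn as nn

vocab_size, d_model = 2, 32   # binary alphabet, illustrative embedding width

embedding = nn.Embedding(vocab_size, d_model)
lm_head = nn.Linear(d_model, vocab_size, bias=False)

# Weight tying: the output (unembedding) layer shares the same parameter tensor
# as the input embedding, so each token has a single consistent representation.
lm_head.weight = embedding.weight
</syntaxhighlight>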
Conclusions
Architectural choices such as weight tying and attention masking significantly impact transformers' ability to learn Markovian patterns. For first-order Markov chains, transformers can readily learn the transition dynamics, and removing weight tying or increasing model depth can help escape bad local minima.
For higher-order Markov chains, standard transformers struggle regardless of depth or weight tying, but limiting the context window via masking improves learning. Moreover, deeper models require smaller context windows. This suggests that unlimited context isn't always beneficial.
Group 8 Presentation: MEDUSA: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads
Presented by:
Nana Ye and Xingjian Zhou
Summaries of key points
This paper introduces MEDUSA, a lightweight yet effective method for accelerating inference in large language models by attaching multiple decoding heads on top of the model's final hidden states. The key idea is that, during inference, these heads predict several future tokens in parallel, reducing the number of sequential decoding steps. Unlike speculative decoding (which relies on a separate draft model), MEDUSA uses the same model and architecture, just augmented with extra lightweight heads that predict future tokens from the backbone's hidden states. These predictions are verified by the base model in a final pass, similar in spirit to Draft & Verify, but with much lower overhead and implementation complexity. Despite its simplicity, MEDUSA achieves competitive speedups (up to 2x) on models like LLaMA, and integrates easily into existing transformer pipelines. It also preserves generation quality well, maintaining high accuracy across benchmarks without requiring retraining of the backbone.
One potential limitation of MEDUSA is that its performance gains depend on the quality of intermediate predictions—if the early layers aren't predictive enough, the method may yield minimal speedup or introduce verification bottlenecks. Another concern is scalability: adding too many decoding heads could increase memory consumption or introduce architectural clutter. While the paper shows good results on standard benchmarks, it's less clear how MEDUSA performs in more complex decoding scenarios like beam search, sampling with temperature, or instruction-tuned models. Finally, although it's simple to implement, any modification to production LLM inference stacks still carries deployment costs, which MEDUSA somewhat underplays.
Clear explanations to aid understanding
Think of a transformer generating text one word at a time: slow, because each step waits on the previous one. MEDUSA asks: what if we could guess ahead a few tokens? It adds small prediction heads on top of the model's final hidden state, each trying to guess a token further ahead in the sequence. Once these guesses are made, the base model verifies them. If correct, we skip ahead; if not, we fall back. It's like speculative decoding, but self-contained: no second model, no complicated setup. You get the parallelism of speculative methods with the simplicity of only slightly tweaking the model's architecture.
Group 8 Presentation: MEDUSA: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads
Presented by:
Nana Ye and Xingjian Zhou
Summaries of key points
Goal: By adding multiple "Medusa heads" to the base model, multiple future tokens can be predicted at once without relying on an external draft model.
Background: The core bottleneck behind slow inference in large models is memory bandwidth rather than raw compute. Autoregressive decoding generates tokens one by one and leaves the GPU underutilized. Draft-model acceleration has the problems of additional model overhead and distribution mismatch.
Methodology: Use multiple Medusa heads to predict future tokens in parallel. Candidate tokens are organized with tree attention so that multiple sequences can be validated simultaneously. The longest prefix with reasonable probability is accepted using the typicality-based acceptance strategy.
Result: On 7B chat models (the paper evaluates Vicuna-7B and Zephyr-7B), Medusa 1 accelerates decoding by about 2.2x and Medusa 2 by about 2.8x, with some tasks reaching up to 3.6x, while output quality remains almost lossless.
Constructive critiques or reviews
The in-depth, detailed explanations helped the audience understand the material more deeply.
Turning on the camera could make the presentation feel more personal and engaging.
Clear explanations to aid understanding
Medusa 1: Trains only the Medusa heads, to save resources.
Medusa 2: Trains the main model and the Medusa heads together. Performance degradation of the main model is avoided with a two-stage strategy.
Mamba: Mamba uses a state-space model to capture long-range dependencies.
Group 8 Presentation: MEDUSA: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads
Presented by:
- Nana Ye
- Xingjian Zhou
Paper Citation
T. Cai et al., “Medusa: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads,” 2024, arXiv. doi: 10.48550/ARXIV.2401.10774.
https://arxiv.org/abs/2401.10774
Background
- As the size of LLMs grows, the speed at which they can generate tokens decreases. The bottleneck is primarily the transfer of data to/from the GPU
- Speculative Sampling is an existing solution that predicts multiple future tokens at once using smaller "draft" models
- Medusa instead solves this problem by adding multiple decoding heads and a tree-based attention mechanism to existing LLMs
- The paper discusses the implementations of Medusa 1 and Medusa 2
Main Idea
The idea is to replace the draft model and instead use heads within the model itself, since the best representation of a model is the model itself. The multiple heads predict multiple tokens at once to leverage parallelism. This makes decoding more efficient and provides more candidate tokens for the tree-based attention to choose from. The tree-based attention simulates sequential generation by traversing a tree from top to bottom, where the top is the initial token.
Methodology
During training, each Medusa head is optimized to predict a future token given the same input context. For example, head 3 is trained to predict token xt+3 using only the context up to xt. The objective function is a standard cross-entropy loss over the predicted token distributions.
To avoid error accumulation, the training data for deeper heads (e.g., t+4, t+5) is generated using gold sequences rather than model outputs.
The decoding tree is constructed with a fixed depth and branching factor corresponding to the number of Medusa heads. Unlike beam search, this tree is evaluated in parallel, and scoring is done using a lightweight attention mechanism applied to the hidden states of the candidate paths.
For candidate selection, a method called typical acceptance is introduced as a fast alternative to rejection sampling. It accepts candidates based on whether their token-level probabilities fall within a "typical" range, reducing the number of evaluations needed during decoding.
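As a rough illustration of the typical-acceptance idea (the threshold form follows the description above; the values of eps and delta are illustrative, not the paper's exact hyperparameters):
<syntaxhighlight lang="python">
import torch

def typical_accept(probs: torch.Tensor, token: int,
                   eps: float = 0.1, delta: float = 0.3) -> bool:
    """Accept a candidate token if its probability under the backbone model exceeds
    a threshold that loosens when the predictive distribution is high-entropy
    (so more candidates count as "typical" when the model is uncertain)."""
    entropy = -(probs * torch.log(probs.clamp_min(1e-9))).sum()
    threshold = min(eps, delta * torch.exp(-entropy).item())
    return probs[token].item() > threshold

# Usage: probs is the backbone's next-token distribution at the candidate's position.
probs = torch.softmax(torch.randn(32000), dim=-1)   # illustrative vocabulary size
accepted = typical_accept(probs, token=123)
</syntaxhighlight>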
Technical Contributions
Medusa 1:
- Uses a frozen pre-trained LLM and trains extra decoding heads on top
- Each additional decoding head predicts a token K time steps in the future
- Uses a probability loss function that scales based on the number of steps into the future
- Reduces memory usage because the backbone model is only used for hidden state extraction
- In simple terms, Medusa adds additional linear layers on top of the last hidden state of the transformer, which are trained to predict tokens at future positions, rather than just the next token as a conventional transformer does in a typical auto-regressive manner (a simplified sketch follows after this list).
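A simplified sketch of such heads (sizes are illustrative; the exact head design and initialization follow the paper):
<syntaxhighlight lang="python">
import torch
import torch.nn as nn
import torch.nn.functional as F

class MedusaHead(nn.Module):
    """One extra decoding head: a small residual block on top of the backbone's
    final hidden state, followed by a projection to the vocabulary."""
    def __init__(self, d_model: int, vocab_size: int):
        super().__init__()
        self.proj = nn.Linear(d_model, d_model)
        self.out = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, hidden: torch.Tensor) -> torch.Tensor:
        return self.out(hidden + F.silu(self.proj(hidden)))

# K heads: head k predicts the token k positions beyond the backbone's own next-token head.
d_model, vocab_size, K = 4096, 32000, 4
heads = nn.ModuleList([MedusaHead(d_model, vocab_size) for _ in range(K)])
hidden = torch.randn(1, d_model)                   # last hidden state from the backbone
future_logits = [head(hidden) for head in heads]   # one distribution per future offset
</syntaxhighlight>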
Medusa 2:
- Fine tunes the LLM and trains the decoding heads at the same time.
- Encountered problems with high losses, switched to a two-stage training process:
- Stage 1: train only the Medusa heads (similar to Medusa 1)
- Stage 2: train both the backbone model and the Medusa heads together
Tree Attention
- Tree attention is used to enable the heads predicting later tokens to include the additional context that may have been created by the earlier Medusa heads in the pipeline
- This tree structure does not occur autoregressively, however
- The top predictions from each head are fed into the tree structure as candidate tokens
- An attention mask is used to ensure that each future-token prediction in the tree is based only on prior tokens, not on tokens past the one being decoded (a toy mask construction is sketched after this list)
- Multiple future candidate tokens can be predicted with context-aware attention simultaneously
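A toy sketch of how such a tree mask could be constructed from candidate tokens (a hypothetical helper, not the authors' implementation):
<syntaxhighlight lang="python">
import torch

def tree_attention_mask(parents: list) -> torch.Tensor:
    """Boolean attention mask for a tree of candidate tokens.
    parents[i] is the index of node i's parent (-1 for the root). Each node may
    attend to itself and its ancestors, so every root-to-node path behaves as if
    those tokens had been generated sequentially."""
    n = len(parents)
    mask = torch.zeros(n, n, dtype=torch.bool)
    for i in range(n):
        j = i
        while j != -1:
            mask[i, j] = True
            j = parents[j]
    return mask

# Toy tree: root token 0 with two children (1, 2); node 1 has one child (3).
print(tree_attention_mask([-1, 0, 0, 1]))
</syntaxhighlight>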
Self Distillation
- A dataset with prompts relevant to the desired model is created
- The full large language model generates outputs for these prompts in a typical auto-regressive manner. The resulting prompt/response pairs form a training dataset for the self-distillation step
- Medusa Heads are trained on the generated training dataset
Tree Construction
- Prune Less Promising Branches: Branches with a low probability of containing the next token are pruned from the tree of candidate tokens in tree attention; this reduces the computational cost of MEDUSA 2
- Select the Best Candidate: From the remaining typical candidates, the longest accepted prefix is chosen for the next decoding step
Empirical Evaluation
Experiments on various LLMs show consistent 2–3 times speedups in practice without harming output quality (assessed by GPT-4 and other metrics). The authors also include ablation studies on key design choices (number of heads, attention structure, sampling thresholds), confirming the effectiveness and generality of the proposed framework.
Constructive Critique
While Medusa demonstrates promising improvements in decoding speed and flexibility, a few limitations remain:
- Training Stability: Especially in Medusa-2, jointly fine-tuning heads and the base model requires a two-stage schedule with learning rate separation and warmup—indicating some instability.
- Model Complexity: Adding multiple Medusa heads and tree attention introduces architectural complexity, which may hinder adoption or reproducibility without careful engineering.
- No open-source code: As of the paper's release, there is no official implementation, which limits replication and community engagement.
Group 8 Presentation: MEDUSA: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads
Presented by:
- Nana Ye
- Xingjian Zhou
Paper Citation
T. Cai et al., “Medusa: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads,” 2024, arXiv. doi: 10.48550/ARXIV.2401.10774.
https://arxiv.org/abs/2401.10774
Summaries of key points
One of the main challenges with LLMs is their slow inference process. It's not because the GPU can't do the math faster, but because of a memory bottleneck (the need to constantly move data between memory and the GPU). LLMs also generate text one token at a time using autoregressive decoding, where each token depends on the previous one, which leads to underutilization of the GPU.
One attempt to address this is Speculative Decoding, where a smaller "draft model" predicts multiple tokens at once, and the main model verifies them. While this speeds up generation, it requires maintaining a separate draft model, and it can lead to mismatches between the draft and main models, making integration into a system difficult.
So, Medusa is proposed to solve these issues.
How does Medusa work?
Core Idea of Medusa
Unlike speculative decoding, Medusa doesn’t need a separate draft model. Instead, it adds extra decoding heads to the same model, allowing it to predict multiple tokens at once (generating candidates). This reduces the dependency on previous tokens, making inference faster. Once multiple tokens are predicted, Medusa organizes these predictions into a tree structure, where each node represents a token. This structure helps Medusa evaluate several possible token sequences at once (processing candidates). This is called tree-based attention. After evaluating, Medusa picks the most likely token sequence and outputs the best one (accepting candidates).
Training Strategies of Medusa
Medusa has two training strategies to optimize its performance:
- Medusa 1: In this method, the original model (called the backbone) is frozen, meaning its parameters don’t change. Only the Medusa heads are trained. This saves computation and avoids the model forgetting what it learned originally, while improving inference speed by predicting multiple tokens at once.
- Medusa 2: In this approach, both the backbone and the Medusa heads are trained together. This boosts prediction accuracy, especially in larger models. The training starts with a two-stage process to prevent issues like high loss and large gradients from the new Medusa heads that could disrupt the backbone model. In Stage 1, only the Medusa heads are trained to specialize without affecting the main model. In Stage 2, both the Medusa heads and the backbone are trained together, with a warm-up strategy to gradually increase the learning rate for smoother adaptation. The tree attention mechanism also helps organize token continuations in a tree, allowing the model to evaluate multiple possibilities at once, speeding up inference.
Further enhancements of Medusa
The author further enhances Medusa's practical utility with three significant extensions:
1. Typical Acceptance Scheme: Instead of rejecting candidates based on strict probability thresholds, Medusa evaluates them based on how typical they are compared to the original model’s distribution. This speeds up the process without sacrificing quality.
2. Self-Distillation: Medusa can generate training data from its own output. It creates a seed dataset and then uses that data to train the Medusa heads, which helps improve the model’s efficiency in making predictions.
3. Optimized Tree Structure: Medusa improves how the model evaluates candidates by focusing on the most promising tokens, making the inference process faster and more efficient.
Benefits and Limitations
Medusa has shown great results in experiments with models like Vicuna-7B and Vicuna-13B, achieving up to 2.8 times faster decoding than traditional methods, with even higher speedups for tasks like text extraction (3.62x) and coding (3.29x). It consistently outperformed speculative decoding, reaching 2.83x acceleration with Vicuna-7B, compared to 1.47x with speculative decoding. Despite the speed improvements, Medusa maintains high text quality, making it efficient without compromising accuracy. The tree-based attention mechanism further boosted speed by 1.9x. However, Medusa has some limitations, such as higher memory usage due to the additional heads and the tree attention mechanism, and it can be challenging to scale to larger models and more complex tasks.
Group 8 Presentation: MEDUSA: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads
Presented by:
Nana Ye and Xingjian Zhou
Motivation
Large language models have become increasingly prolific, but anyone who has used one, or paid to use one, knows that the inference cost is not insignificant. If one is familiar with how next-token prediction works, they may ask: what if I predict multiple tokens in a single forward pass instead of a single token? That is the question the authors of the Medusa paper asked, and what they set out to do.
Operating principle
In simple terms, the Medusa model predicts multiple probability distributions over possible next tokens by learning multiple decoding heads. That is, given a string of input tokens and the corresponding embeddings, multiple decoding heads operate on the same hidden state, with each head predicting a distinct distribution over the dictionary: [math]\displaystyle{ p_i, p_{i+1},...,p_{i+3} }[/math]. In the paper they use a total of 4 heads, so each forward pass predicts four tokens instead of one.
Training Strategies
The authors introduce two training strategies, with the goal that one can train Medusa without needing access to an HPC; that is, they want to demonstrate that you can download an open-source model like Meta's LLaMA and modify it to use Medusa with limited computing resources.
The training strategies are:
Medusa 1: Frozen Backbone
This is the simplest possible training scheme. The user essentially freezes the weights of the pretrained model and trains only the weights of the multiple MEDUSA heads, by computing the cross-entropy loss between the "n" token predictions from the MEDUSA heads and the next "n" tokens in the ground truth. Specifically, if the ground-truth token at the relevant position is [math]\displaystyle{ y_{t+k+1} }[/math], the loss associated with the [math]\displaystyle{ k^{th} }[/math] head is simply [math]\displaystyle{ L_k = - \log p_t^{(k)}(y_{t+k+1}) }[/math], where [math]\displaystyle{ p_t^{(k)}(y) }[/math] is the probability of token y being predicted by the [math]\displaystyle{ k^{th} }[/math] head. The authors note that the loss grows with k, which makes intuitive sense: the further you try to predict beyond the last known token, the more uncertainty in the forecast. This is one of the limitations of Medusa, but not a surprising one. The total Medusa loss is the sum of these individual head losses, each weighted by some [math]\displaystyle{ \lambda_k }[/math], typically set to the k-th power of a constant less than 1.0, so that the model treats the tokens closest to the last known token as more important.
Medusa 2: Joint Training
Joint training is a slightly more sophisticated approach than the frozen backbone. When using joint training we consider two sources of training loss: the standard LLM loss coupled with the Medusa loss. This method can be beneficial as it provides some degree of fine-tuning to the foundation model, so it can learn embeddings that are more complementary to the multi-head approach.
[math]\displaystyle{ L_{total} = L_{LM} + \alpha \sum_{k=1}^{K} \lambda_k L_k }[/math]
We balance the losses to account for the fact that the backbone loss is likely to be very small, and we do not want the Medusa loss to dominate early on. Tuning these parameters is important so as to maintain learnability in the Medusa heads without losing the foundation model's functionality.
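A compact sketch of this combined objective (the offsets follow the [math]\displaystyle{ y_{t+k+1} }[/math] convention above; alpha and decay are illustrative values):
<syntaxhighlight lang="python">
import torch
import torch.nn.functional as F

def medusa2_loss(lm_logits, head_logits, targets, alpha=0.2, decay=0.8):
    """Joint (Medusa-2 style) loss: L_total = L_LM + alpha * sum_k decay**k * L_k.
    lm_logits: (B, T, V) backbone logits; head_logits[k-1]: (B, T, V) logits of head k;
    targets: (B, T) ground-truth token ids."""
    # Standard next-token loss for the backbone language model.
    loss = F.cross_entropy(lm_logits[:, :-1].transpose(1, 2), targets[:, 1:])
    for k, logits in enumerate(head_logits, start=1):
        # Head k at position t is supervised with the ground-truth token at t+k+1,
        # so drop the trailing positions that have no target.
        l_k = F.cross_entropy(logits[:, :-(k + 1)].transpose(1, 2), targets[:, k + 1:])
        loss = loss + alpha * (decay ** k) * l_k
    return loss
</syntaxhighlight>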
Results / Conclusion
Overall, Medusa models prove themselves an interesting option for those wishing to improve the efficiency of foundation models, offering flexibility for users based on available computational resources. Medusa's predictive accuracy naturally falls for tokens further along in the sequence, but in settings where one wants a model to perform well while respecting hardware limitations, such as a locally hosted LLM or a distilled LLM on a mobile device, Medusa models could prove a very efficient option.
Group 8 Presentation: MEDUSA: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads
Presented by:
- Nana Ye
- Xingjian Zhou
Paper Citation
T. Cai et al., “Medusa: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads,” 2024, arXiv. doi: 10.48550/ARXIV.2401.10774.
https://arxiv.org/abs/2401.10774
Motivation
Slow inference with LLMs: memory-bound behavior and latency bottlenecks leave the GPU's computational potential underutilized.
Existing methods accelerate inference but introduce complexity and integration challenges.
Main idea
MEDUSA replaces the complexity of speculative decoding by adding multiple lightweight decoding heads to an existing Large Language Model.
Key components: No draft model; Decoding Heads; Tree-based attention
Extension
1. Typical Acceptance
Problem: Rejection sampling is inefficient at high temperatures.
Solution:
A candidate is accepted if its probability is above a certain threshold, which is adjusted based on entropy: higher entropy indicates more uncertainty in the model's predictions, allowing a broader range of candidates to be considered typical.
2. Self-Distillation
Select seed dataset → Generate responses → Create training dataset → Train with self-distillation
3. Optimized Tree Construction
Practical Advantages
1. Integration-friendly
2. Scalable and efficient
3. Resource-efficient
Group 9 Presentation: Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality
Presented by:
- Kaiyue Ma
- Wenzhe Wang
Paper Citation
T. Dao and A. Gu, “Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality,” 2024, arXiv. doi: 10.48550/ARXIV.2405.21060.
Background
- Transformers are effective, but computationally expensive and suffer from quadratic complexity
- Structured state space models (SSMs) are an alternative that scales linearly instead and works for long range modeling
- SSMs have not received the same mainstream improvements as transformers, and lack support for parallelization and hardware acceleration
- Structured state space duality (SSD) bridges the gaps between transformers and SSMs
Additional Background
SSMs (State Space Models) are traditionally used in control theory to model a dynamic system via state variables. The paper https://compneuro.uwaterloo.ca/files/publications/voelker.2018.pdf then found that SSMs are also well suited to describing time cells in the brain. A useful diagram of an SSM can be found here: https://cdn-uploads.huggingface.co/production/uploads/613b0a62a14099d5afed7830/G7icfkYoxIqHZcJGHM7UD.png, with n state variables, m inputs u, and p outputs y.
Technical Contributions
- Represents SSMs as semiseparable matrices and uses semiseparable matrices for efficient matrix operations (illustrated in the sketch after this list).
- Uses a generalized linear attention mechanism with structured masks.
- Refines the original Mamba model to yield Mamba-2, which incorporates the new structured-state-space-duality algorithms. Mamba-2 is easily parallelizable and scales better to large state sizes. Empirical results demonstrate strong performance on language modelling benchmarks, surpassing older SSM models and matching or outperforming Transformers at various model scales.
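To illustrate the duality concretely, here is a small NumPy sketch (illustrative scalar-gated case with arbitrary random parameters) showing that the recurrent SSM computation and the materialized semiseparable-matrix, attention-like computation produce the same outputs:
<syntaxhighlight lang="python">
import numpy as np

rng = np.random.default_rng(0)
T, N = 6, 4                      # sequence length, state dimension
a = rng.uniform(0.5, 1.0, T)     # scalar gates (scalar-identity A_t, as in SSD)
B = rng.standard_normal((T, N))
C = rng.standard_normal((T, N))
x = rng.standard_normal(T)

# Linear (recurrent) form: sequential state updates.
h = np.zeros(N)
y_rec = np.zeros(T)
for t in range(T):
    h = a[t] * h + B[t] * x[t]
    y_rec[t] = C[t] @ h

# Quadratic (attention-like) form: materialize the lower-triangular
# semiseparable matrix M with M[t, s] = (a[s+1] * ... * a[t]) * (C[t] . B[s]).
M = np.zeros((T, T))
for t in range(T):
    for s in range(t + 1):
        M[t, s] = np.prod(a[s + 1:t + 1]) * (C[t] @ B[s])
y_mat = M @ x

assert np.allclose(y_rec, y_mat)   # the two "dual" computations agree
</syntaxhighlight>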
Summaries of Key Points
The paper explores the theoretical connections between Transformer architectures and Structured State Space Models (SSMs). The authors introduce the State Space Duality (SSD) framework, which bridges these two model families through the concept of structured semiseparable matrices. This framework reveals that certain attention mechanisms in Transformers can be interpreted as SSMs, providing a unified perspective on sequence modelling techniques.
Leveraging the SSD framework, the authors propose a new architecture called Mamba-2. This model refines the selective SSM approach used in the original Mamba model, resulting in a design that is 2-8 times faster while maintaining competitive performance in language modelling tasks. Mamba-2 achieves this efficiency by simplifying the SSM layer, enabling better scalability and computational speed.
The paper also introduces efficient algorithms based on block decompositions of semiseparable matrices, which enhance the computational efficiency of SSMs. These algorithms allow for larger recurrent state sizes and improve the practicality of SSMs in handling long-range dependencies within sequences.
Empirical evaluations demonstrate that Mamba-2 outperforms both its predecessor and Transformer models in terms of training efficiency and performance on language modelling benchmarks. The architecture also shows superior capabilities in associative recall tasks, highlighting its effectiveness in capturing and utilizing long-range dependencies.
In summary, the paper provides a theoretical foundation connecting Transformers and SSMs, introduces the Mamba-2 architecture as a practical application of this theory, and presents algorithms that enhance the efficiency and scalability of sequence modelling techniques.
Constructive Critique
While the paper introduces a powerful theoretical framework through State Space Duality (SSD) and demonstrates strong empirical performance with Mamba-2, several areas could be further clarified or improved:
- Theoretical accessibility: The concept of semiseparable matrices and duality between attention and SSMs is mathematically rich, but not easily accessible to a broader audience. Including more visual or intuitive explanations would improve its pedagogical impact.
- Benchmark diversity: Most experiments focus on language modeling tasks. It remains unclear how Mamba-2 performs on other domains such as vision, speech, or reinforcement learning. Since SSD is a general framework, cross-domain validation would help showcase its broader applicability.
- Scalability limitations: While Mamba-2 is more efficient than its predecessor, the paper doesn’t fully discuss how performance scales with increasing model depth or state size, especially under training constraints on real-world hardware.
- Lack of interpretability analysis: The paper does not explore how the SSD framework or Mamba-2 influences model interpretability (e.g., how information is propagated or stored over long sequences), which could be important for downstream applications.
Despite these limitations, the paper makes a substantial theoretical and practical contribution by unifying two dominant modeling paradigms and offering a concrete architecture (Mamba-2) that is both efficient and performant.
Group 9 Presentation: Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality
Presented by:
- Kaiyue Ma
- Wenzhe Wang
Paper Citation
T. Dao and A. Gu, “Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality,” 2024, arXiv. doi: 10.48550/ARXIV.2405.21060.
Background & Motivation
The paper aims to unify state-space models (SSMs) and attention mechanisms through structured state-space duality (SSD), demonstrating that SSMs can be interpreted as a form of masked attention with semiseparable matrices. This approach enables the utilization of attention's hardware efficiency (e.g., optimized matrix multiplications) while preserving the linear scaling property of SSMs. Although Mamba's selective SSM is powerful, it is slower than optimized attention due to its reliance on sequential scans rather than direct matrix operations. The authors propose methods to accelerate SSMs by 2–8× without compromising performance or even enhancing it. By reformulating SSMs as matrix transformations, the paper offers novel theoretical insights, such as their equivalence to semiseparable matrices, along with practical algorithms like block decomposition for efficient computation. These contributions pave the way for hybrid architectures (e.g., Mamba-2 augmented with attention layers) and improved system-level support (e.g., tensor and sequence parallelism).
Key Points
The paper titled "Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality" investigates the theoretical and practical connections between Transformers and State Space Models (SSMs), with a particular emphasis on structured state-space duality (SSD). The key contributions of the paper include:
1. Duality Framework: The authors introduce Structured State-Space Duality (SSD), a framework that establishes a connection between State Space Models (SSMs) and attention mechanisms via structured matrices, particularly semiseparable matrices. This duality enables SSMs to be interpreted as matrix transformations, thereby uncovering novel algorithmic and architectural insights.
2. Efficiency Improvements: The paper introduces the Mamba-2 architecture, which enhances the selective SSM of Mamba to achieve 2–8× faster computation while preserving competitive performance in language modeling tasks. This improvement is realized through the utilization of hardware-efficient matrix multiplications and block decompositions of semiseparable matrices.
3. Structured Masked Attention (SMA): A generalization of linear attention is introduced, in which the attention mask is replaced by a structured matrix (e.g., semiseparable matrices). This substitution enables subquadratic computational complexity and facilitates efficient autoregressive inference.
4. Hybrid Models: The paper demonstrates that combining SSMs with attention layers (e.g., 10% attention layers in Mamba-2) can improve performance, suggesting complementary strengths between the two paradigms.
Contributions
- Theoretical Connections: It establishes a rigorous equivalence between SSMs and semiseparable matrices, unifying recurrent, convolutional, and attention-based sequence models under a single framework.
- Algorithmic Innovations: The SSD algorithm optimizes SSM computation by blending linear (recurrent) and quadratic (attention-like) forms, achieving linear complexity while leveraging modern hardware.
- Mamba-2 Architecture: This new architecture improves upon Mamba by simplifying projections, enabling tensor parallelism, and incorporating larger state dimensions, resulting in better scalability and efficiency.
- Empirical Validation: The authors validate Mamba-2 on synthetic tasks (e.g., associative recall) and language modeling, showing it outperforms Mamba and matches or exceeds Transformer++ in perplexity and downstream tasks.
Constructive Critiques
- Expressivity Trade-offs: The adoption of scalar-identity structure for A matrices in Structured State-Space Duality (SSD) may constrain model expressivity relative to general diagonal State Space Models (SSMs). The paper could provide a more in-depth analysis of the trade-offs between hardware efficiency and model flexibility.
- Attention Approximation: The negative results observed for kernel approximations (e.g., Performer, cosFormer) in Mamba-2 indicate that the advantages of Structured State-Space Duality (SSD) may not fully translate from linear attention mechanisms. A more in-depth investigation into the reasons for the underperformance of these methods could further enhance the study.
- Broader Applicability: The focus is heavily on language modeling. Evaluating SSD on other domains (e.g., vision, reinforcement learning) could demonstrate its generalizability.
- Implementation Complexity: Although the SSD algorithm is simpler than Mamba's selective scan, its block decomposition may still present implementation challenges for practical adoption. Conducting additional ablation studies on parameters such as chunk size and parallelism levels could provide valuable guidance for practitioners.
Relationships to Other Works
The paper extends research in efficient sequence modeling, connecting various approaches through a unified framework. It builds on recent progress in State Space Models (SSMs), particularly from S4 to Mamba. S4 and S4D introduced diagonal-plus-low-rank matrices for long-range modeling, while Mamba's selective SSMs improved performance on dense data like language. The SSD framework generalizes these models using semiseparable matrices and introduces hardware-aware optimizations, making Mamba-2 significantly faster.
Connections to linear attention methods form another key thread. The paper generalizes Katharopoulos et al.'s linear attention with structured masked attention via semiseparable matrices. This links SSD to models like RetNet (fixed exponential decay) and GateLoop (input-dependent gating). GLA's chunkwise computation resembles SSD's block decomposition but lacks SSD's theoretical unification.
The work also intersects with efforts in efficient recurrent architectures. RWKV's attention-like gating shares similarities with SSD's matrix-based approach, though SSD offers a more rigorous mathematical foundation. Griffin's combination of SSMs with local attention and xLSTM's expanded state dimensions align with SSD's themes, suggesting SSD provides a unifying perspective.
On the systems side, the paper complements hardware-efficient Transformer implementations. While FlashAttention optimizes attention kernels, SSD advances SSM-based models. Monarch Mixer's structured matrices share some ideas with SSD but apply them differently. These connections highlight SSD's contribution to efficient deep learning architectures.
Theoretically, SSD bridges modern sequence modeling with classical numerical linear algebra. Semiseparable matrices connect to structured matrix computations, offering new insights into model representation. This grounding may inspire future algorithmic improvements.
Group 9 Presentation: Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality
Presented by:
- Kaiyue Ma
- Wenzhe Wang
Paper Citation
T. Dao and A. Gu, “Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality,” 2024, arXiv. doi: 10.48550/ARXIV.2405.21060.
Introduction
The paper "Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality", by Sri Dao and Albert Gu establishes a theoretical framework connecting Transformers and State Space Models - SSMs.
This framework, termed Structured State Space Duality - SSD - brings these two prominent sequence architectures leading to the development of Mamba-2. This model enhances efficiency while painting competitive performance in LLMs.
Key Contributions
(1) Establishing Structured State Space Duality
The authors demonstrate that structured SSMs and attention mechanisms are closely related through structured matrices, specifically semiseparable matrices. This insight reveals that various sequence models can be interpreted as different parametrizations of these matrices, providing a unified understanding.
Two perspectives are introduced to implement this duality:
Matrix representation: viewing sequence models as matrix transformations, highlighting how SSMs can be represented using semiseparable matrices, which have sub-quadratic parameter counts.
Tensor contraction representation: this illustrates how the computations in attention mechanisms can be reformulated in terms of tensor contractions, aligning them with SSM operations.
By framing SSMs as well as attention mechanisms within the SSD framework, the paper enables the transfer of algorithmic and system optimizations between these models, fostering advancements in efficiency and stability.
(2) Development of Mamba-2 Architecture
Leveraging the SSD framework, Mamba-2 is an improved version of the Mamba-1 architecture. Mamba-2 refines the selective SSM layer, resulting in a significantly faster model.
Mamba-2 therefore achieves 2-8 times faster training compared to its predecessor while maintaining competitiveness with transformers on language modelling tasks. This demonstrates the practical benefits of applying the SSD design.
(3) Efficient Algorithms Through SSD
The paper presents efficient algorithms derived from the SSD framework that optimize the computation. These algorithms reduce the complexity often associated with traditional sequential models.
SMA, a novel attention variant, is introduced, which benefits from the structured properties of SSD. This leads to more efficient attention computation.
Applications and Impact
The SSD framework offers a new paradigm for designing sequence models, allowing practitioners to harness the strengths of both SSMs and transformers. This leads to models that are both computationally efficient and effective in capturing long-range dependencies.
By reducing computational complexity, the insights from this paper facilitate the development of models that can handle longer sequences and larger datasets, addressing a common limitation in sequence modelling and thereby enabling scalability.
Finally, the theoretical connections established in the paper enable the application of optimization techniques across different model architectures. This lays the foundation for more unified and efficient approaches to sequence modelling.
This work not only enhances theoretical understanding but also leads to practical advancements, exemplified by the development of the Mamba-2 Architecture.
Group 9 Presentation: Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality
Presented by:
- Kaiyue Ma
- Wenzhe Wang
Paper Citation
T. Dao and A. Gu, “Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality,” 2024, arXiv. doi: 10.48550/ARXIV.2405.21060.
Summaries of key points
Goal: SSD connects SSMs and Transformers, combining the best of both.
Background: Transformers can handle long-distance dependencies but are computationally complex. SSMs have linear complexity, but they are difficult to fully parallelize and accelerate.
Methodology: The SSD framework relates SSMs to the attention mechanism through a semiseparable matrix structure. The SSD algorithm splits the matrix into diagonal blocks and off-diagonal blocks, and adds multi-head sequence transformation, parallel projection, and kernel methods.
Result: Mamba-2 matches Transformer performance while training 2-8 times faster.
Constructive critiques or reviews
The clear illustrations help the audience reach a concise and intuitive understanding.
The talk is well structured; the slides could include a little more detail in places.
Clear explanations to aid understanding
Semiseparable matrix: compresses a large matrix into structured blocks that are easier to compute with.
Block decomposition in SSD: diagonal blocks are handled with attention-like (quadratic) computation, while off-diagonal blocks are handled with fast recurrence (a toy sketch follows below).
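A toy sketch of this block decomposition for the scalar-gated case (illustrative only; the real SSD algorithm is batched, multi-head, and hardware-aware):
<syntaxhighlight lang="python">
import numpy as np

def ssd_chunked(a, B, C, x, L=3):
    """Chunked scan: diagonal blocks use the quadratic attention-like form,
    while information crosses chunk boundaries through a small recurrent state."""
    T, N = B.shape
    y = np.zeros(T)
    h_prev = np.zeros(N)                      # state carried across chunks
    for start in range(0, T, L):
        idx = range(start, min(start + L, T))
        # Off-diagonal block: contribution of everything before this chunk.
        for t in idx:
            y[t] = np.prod(a[start:t + 1]) * (C[t] @ h_prev)
        # Diagonal block: attention-like computation within the chunk.
        for t in idx:
            for s in range(start, t + 1):
                y[t] += np.prod(a[s + 1:t + 1]) * (C[t] @ B[s]) * x[s]
        # Advance the carried state to the end of this chunk.
        for t in idx:
            h_prev = a[t] * h_prev + B[t] * x[t]
    return y

# Sanity check against a plain sequential recurrence.
rng = np.random.default_rng(1)
T, N = 7, 4
a = rng.uniform(0.5, 1.0, T); B = rng.standard_normal((T, N))
C = rng.standard_normal((T, N)); x = rng.standard_normal(T)
h = np.zeros(N); y_ref = np.zeros(T)
for t in range(T):
    h = a[t] * h + B[t] * x[t]; y_ref[t] = C[t] @ h
assert np.allclose(ssd_chunked(a, B, C, x), y_ref)
</syntaxhighlight>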
Group 9 Presentation: Transformers are SSMs – Generalized Models and Efficient Algorithms Through Structured State Space Duality
Presented by:
Kaiyue Ma and Wenzhe Wang
Paper Citation
T. Dao and A. Gu, “Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality,” arXiv, 2024. doi: 10.48550/ARXIV.2405.21060
Summaries
This paper proposes Structured State Space Duality (SSD), a novel framework that unifies Transformers and State Space Models (SSMs). By leveraging semiseparable matrices, SSD connects the two models both theoretically and computationally.
The SSD framework enables SSMs to be interpreted as a form of structured attention, allowing them to adopt hardware-efficient matrix operations typically used in attention mechanisms. Based on this insight, the paper develops the Mamba-2 architecture—an SSM-based model that is 2–8× faster than previous designs, while maintaining or even improving performance on language modeling tasks.
The paper offers both theoretical contributions (matrix/tensor formulations of sequence models) and practical improvements (block decomposition, structured masked attention), presenting a new paradigm for designing fast and expressive sequence models.
Key Contributions
Structured State Space Duality (SSD)
Establishes a bridge between SSMs and attention via semiseparable matrices.
Two perspectives:
Matrix Representation: Shows how SSMs can be implemented as structured matrix transformations with sub-quadratic properties.
Tensor Contraction Representation: Reformulates attention as tensor operations, aligning attention and SSMs under a shared computational lens.
Mamba-2 Architecture
Improves on Mamba by simplifying selective SSM layers.
Supports tensor and sequence parallelism, enabling scalability and hardware optimization.
Maintains high accuracy while achieving significant speed-ups.
Efficient Algorithms via SSD
Proposes block decomposition for semiseparable matrices to allow fast, parallel computation.
Introduces Structured Masked Attention (SMA), which generalizes linear attention using structured matrices, reducing inference complexity while preserving expressiveness.
Hybrid Sequence Models
Demonstrates that models like Mamba-2 can benefit from a small percentage of attention layers (e.g., 10%), leading to performance boosts without major efficiency trade-offs.
Constructive Critiques or Reviews
Expressivity Trade-off: SSD uses scalar-identity matrix A for simplicity, but this may limit modeling flexibility compared to more general SSMs.
Domain Specificity: Experiments focus heavily on NLP tasks. Broader validation across vision or RL tasks would strengthen the case for SSD.
Implementation Overhead: Although simpler than Mamba-1, SSD’s block decomposition still adds engineering complexity. Ablation studies on chunk size and parallelism could help guide practical usage.
Kernel Approximation Limitations: Some approximations (like Performer or cosFormer) underperform when used within SSD, suggesting limitations in directly porting techniques from linear attention to SSMs.
Related Works
SSM Lineage: Builds on S4, S4D, and Mamba, continuing the trend of enhancing long-range modeling with structured matrices.
Linear Attention Connections: Extends ideas from Katharopoulos et al., RetNet (exponential decay), and GateLoop (input-dependent gating) through the use of Structured Masked Attention.
Recurrent Model Efficiency: Shares similarities with RWKV’s attention-like gating and Griffin’s hybrid SSM-attention strategy. SSD offers a theoretically grounded unification of these approaches.
System-Level Optimization: Complements tools like FlashAttention by advancing SSM-based models. Monarch Mixer shares matrix structuring ideas but differs in architecture.
Mathematical Foundations: Draws on classical numerical linear algebra via semiseparable matrices, opening new directions in sequence model efficiency and design.
Group 9 Presentation: Transformers are SSMs – Generalized Models and Efficient Algorithms Through Structured State Space Duality
Presented by:
Kaiyue Ma and Wenzhe Wang
Paper Citation
T. Dao and A. Gu, *Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality*, 2024, arXiv: [10.48550/ARXIV.2405.21060]
So, What’s This All About?
This paper tackles a big idea: Transformers and Structured State Space Models (SSMs) aren’t as different as they seem. What if we told you they’re actually two sides of the same mathematical coin?
Using something called **Structured State Space Duality (SSD)**, the authors show that attention mechanisms in Transformers can be seen as a special case of SSMs—specifically, semiseparable matrix operations. This realization lets them create models that combine the efficiency of SSMs with the expressiveness of attention.
And yes, they take it a step further by introducing **Mamba-2**, a faster, more parallel version of the original Mamba.
What’s New and Cool?
Here are the main contributions from the paper:
**(1) Structured State Space Duality**
- SSD shows how to map attention to SSMs using semiseparable matrices.
- You can view both models through matrix and tensor representations, revealing a deep mathematical connection.
**(2) Mamba-2 Architecture**
- Mamba-2 is a refined version of Mamba, now optimized using SSD principles.
- It’s **2–8× faster** than the original Mamba, with the same or better performance on language modeling benchmarks.
- Plus, it scales to large state sizes and runs smoothly on modern hardware.
**(3) Efficient Algorithms for Long Sequences**
- By leveraging SSD, they optimize matrix multiplications and block decompositions.
- The result? Efficient, parallel-friendly inference, even for very long input sequences.
**(4) SMA (Structured Masked Attention)**
- A new take on linear attention, replacing it with a structured matrix formulation.
- Maintains accuracy, reduces complexity, and plays well with modern accelerators.
Real-World Impact & Experiments
- **Language Modeling**: Mamba-2 beats or matches Transformers on perplexity across standard datasets.
- **Speed**: Achieves faster inference with lower computational costs.
- **Scalability**: Easily extends to larger models and longer contexts thanks to block-structured matrices.
- **Hybrid Potential**: Shows that combining attention + SSMs (hybrid models) may give the best of both worlds.
Let’s Talk Strengths
- The paper builds a solid theoretical bridge between Transformers and SSMs
- Offers practical improvements: Mamba-2 isn’t just math; it works!
- Hardware-friendly: optimized for real-world use, not just academic theory
- Great generalization to other architectures (like FlashAttention, Monarch Mixer, xLSTMs)
Any Caveats or Challenges?
- **Math-heavy**: SSD is powerful, but not easy for everyone to digest. More visuals/examples would help.
- **Narrow evaluation**: Most experiments focus on language modeling. Other areas (like vision, speech) need testing.
- **Interpretability**: With all these matrix tricks, it's harder to see how info flows through the model.
- **Implementation complexity**: Setting up SSD-based models isn’t as plug-and-play as Transformers yet.
Helpful Analogies & Clarifications
- **Semiseparable matrices**: Think of compressing a huge attention matrix into smarter, structured blocks.
- **Block decomposition**: Imagine breaking attention into diagonal + off-diagonal parts, each optimized for a different type of sequence interaction.
- **SSD = a unifying lens**: It’s like saying Transformers and SSMs are just different “views” of the same underlying engine.
Final Thoughts
This paper doesn’t just tweak existing models—it proposes a unified theory that explains and improves both Transformers and SSMs. With SSD, we can create hybrid models that are faster, more efficient, and still powerful. And with Mamba-2, we get real performance gains without sacrificing quality.
It’s an exciting step forward that’s both theoretical and practical—bridging two major model families and opening up new paths for scalable, efficient sequence modeling.
Group 9 Presentation: Transformers are SSMs – Generalized Models and Efficient Algorithms Through Structured State Space Duality
Presented by:
Kaiyue Ma and Wenzhe Wang
Paper Citation
T. Dao and A. Gu, “Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality,” arXiv, 2024. doi: 10.48550/ARXIV.2405.21060
Objective
Develop connections between Transformers & SSMs
Transformers: Effective but computationally expensive. SSMs: Linear complexity, efficient long-range modeling.
Contributions
1. Introduce SSD framework
SSD bridges attention and SSMs by structured matrices.
2. Present Mamba-2 architecture with improved performance and speed.
Experimental Results
Mamba-2 trains 2-8 times faster than original Mamba.
Improved long-sequence modeling.
Surpasses standard Transformer benchmarks
Future Directions
structured matrices, efficiency enhancements, broader applications
Group 9 Presentation: Transformers are SSMs – Generalized Models and Efficient Algorithms Through Structured State Space Duality
Presented by:
Kaiyue Ma and Wenzhe Wang
Paper Citation
T. Dao and A. Gu, “Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality,” arXiv, 2024. doi:10.48550/ARXIV.2405.21060
Objective
Explore the theoretical and practical connection between Transformers and State Space Models (SSMs).
Transformers: Powerful sequence models, but suffer from quadratic computational costs.
SSMs: Linear time complexity, well-suited for long-range dependencies.
Key Contributions
Structured State Space Duality (SSD) Framework
Establishes a formal link between attention mechanisms and SSMs via structured matrices.
Unifies seemingly distinct modeling approaches under a common mathematical foundation.
Introduction of Mamba-2 Architecture
An enhanced SSM-based model demonstrating significant improvements in training speed and sequence modeling capabilities.
Efficient and scalable, designed to outperform traditional Transformers.
Experimental Highlights
Training Speed: Mamba-2 trains 2× to 8× faster than the original Mamba.
Performance: Achieves state-of-the-art results on several long-sequence benchmarks.
Efficiency: Maintains low computational overhead while improving expressivity.
Future Directions
Further exploration of structured matrix representations.
Optimization techniques to boost training and inference efficiency.
Broader integration into diverse domains such as vision, language, and time-series analysis.
Group 10 Presentation: Accelerating Large Language Model Decoding with Speculative Sampling
Paper Citation
C. Chen, S. Borgeaud, G. Irving, J.-B. Lespiau, L. Sifre, and J. Jumper, ‘Accelerating Large Language Model Decoding with Speculative Sampling’, Feb. 02, 2023, arXiv: arXiv:2302.01318. doi: 10.48550/arXiv.2302.01318.
https://arxiv.org/abs/2302.01318
Background
- Traditional autoregressive decoding for large language model is computationally expensive, as the entire model has to run for each additional token which is generated
- Transformer neural networks are typically memory-bandwidth limited, and using quantization or distillation to make smaller models are solutions that have been used in the past to improve the performance of LLMs
Technical Contributions
- Speculative sampling was developed to increase the speed of LLM decoding without meaningfully reducing its performance at predicting future tokens compared to the base model
- Generates a draft of k future tokens using a smaller model
- Score the proposed tokens using the base model
- A modified rejection sampling scheme was developed by the authors in the paper
- The acceptance of a draft token is based on the minimum of 1 and the ratio of the target model's probability to the draft model's probability for that token.
Explanations to Aid Understanding
- Transformer Decoding: This is the process by which a large language model generates text. Given an input prompt, the model sequentially selects the most probable next token, then uses that token to predict the subsequent one, and so on, and this process is computationally intensive.
- Speculative Sampling: Unlike traditional decoding where one token is generated at a time by the target model, speculative sampling aims to generate multiple tokens in parallel by using a faster, potentially smaller "draft model" to propose a short sequence (the draft). The target model evaluates these drafted tokens, and a rejection sampling mechanism decides which ones to accept, ensuring the output remains consistent with the target model's capabilities.
- Parallel Scoring: Instead of computing the logits for each drafted token sequentially, the method computes the logits for all (K+1) tokens in the draft at the same time. The presentation notes that the computing time for this parallel process is similar to sampling a single token with the target model, which is a key factor in the potential speedup. The key insight is that the model inference pass is dominated by memory bandwidth and inter-device communication rather than purely by the token count. By handling several drafted tokens per pass, overall decoding time is greatly reduced.
Summaries of Key Points
- Decoding Challenges in LLMs: Traditional autoregressive sampling methods generate one token at a time, leading to inefficiencies, especially as model sizes grow
- Speculative Sampling Overview: SpS utilizes a draft model, which is a smaller and faster version of the target LLM, to propose a sequence of tokens. The target model then evaluates this proposed sequence, accepting or rejecting tokens based on a modified rejection sampling scheme that ensures the final output aligns with the target model's distribution.
- Algorithm Efficiency: The draft model generates a short sequence (e.g., K tokens) in parallel. The target model scores this sequence, and tokens are accepted or resampled as needed, allowing for the potential acceptance of up to K+1 tokens per iteration. This parallel approach contrasts with traditional methods that generate tokens sequentially, thereby reducing decoding latency.
- Empirical Results: Implementing SpS with Chinchilla, a 70-billion-parameter language model, resulted in a 2 to 2.5 times speedup in decoding across benchmark tasks such as XSum and HumanEval. These speed improvements were achieved without degrading sample quality or requiring changes to the target model's parameters or architecture.
- Advantages of Speculative Sampling: Maintains the integrity of the target model's output distribution. Requires no alterations to existing model architectures, facilitating easy integration into current systems. Demonstrates versatility across various tasks and decoding methods, making it broadly applicable in LLM deployments.
Group 10 Presentation: Accelerating Large Language Model Decoding with Speculative Sampling
Presented by:
Danyang Zhao
Speculative Sampling: An In-Depth Explanation
Speculative Sampling is a technique introduced to speed up text generation for large language models (LLMs) without significantly affecting output quality. It works by leveraging a smaller, faster “draft” model to propose multiple candidate tokens, which are then verified and accepted or rejected by the larger model. This method reduces latency during decoding while maintaining output quality close to the original model.
But what specifically is latency? Latency refers to the time delay between a user requesting text generation and the model producing the output. More formally, it is the time required to generate each token in an autoregressive model.
But why do we need speculative sampling? Because there are problems with standard autoregressive sampling. Mainly:
- LLMs generate text token-by-token. Each step depends on the previous tokens, making sequential decoding slow.
- The computational bottleneck comes from generating the next token, which requires a full forward pass of the model.
Therefore the goal of speculative sampling is to reduce the number of forward passes of the large language model per generated token.
How Speculative Sampling Works:
1. There are two models, a Draft model and a Large Model. The smaller draft model, which is cheap and fast to compute, generates a set of k speculative candidate tokens. Then, the large model, which is expensive yet accurate, verifies these tokens and either accepts or rejects them.
2. The draft model proposes multiple tokens. That is, instead of sampling just one token at each step, the draft model generates a sequence of k candidate tokens [math]\displaystyle{ x_{t+1}, x_{t+2}, ..., x_{t+k} }[/math] using standard autoregressive decoding.
3. Now, the large language model verifies the proposal. This is done by calculating the probability of each proposed token:
[math]\displaystyle{ p(x_{t+1} | x_{\le t}), p(x_{t+2} | x_{\le t+1}), ..., p(x_{t+k} | x_{\le t+k-1}) }[/math]
For each candidate, there are only two possibilities: (1) the token is accepted and added to the output, or (2) the token is rejected, in which case the large model takes over and directly generates a new token.
4. Since the draft model is smaller, it generates multiple speculative tokens quickly, and the target model then needs only a few forward passes, making the process more efficient. This allows for faster decoding through efficient verification.
Now, let us go over a bit more of the mathematics.
For each token [math]\displaystyle{ x_i }[/math], we check:
[math]\displaystyle{ p(x_i | x_{\lt i}) \ge q(x_i | x_{\lt i}) }[/math]
Where p is the probability distribution of the large model and q that of the draft model. If the inequality holds, the token is accepted; otherwise, it is only accepted with probability p/q.
This is paired with a Metropolis-style acceptance probability, which ensures the final sampled tokens remain statistically valid while speeding up computation.
The acceptance probability can then be calculated as follows:
[math]\displaystyle{ \alpha_i = \min \Big( 1, \frac{p(x_i | x_{\lt i})} {q(x_i | x_{\lt i})} \Big) }[/math]
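To make the accept-and-resample rule concrete, here is a minimal NumPy sketch of one verification step. It is illustrative only, not the paper's implementation: the array shapes and the names p_target, q_draft, and draft_tokens are assumptions, and the residual-resampling convention (sampling from the normalized max(0, p - q) on rejection) follows the description above.

```python
import numpy as np

def speculative_step(p_target, q_draft, draft_tokens, rng):
    """One speculative-sampling verification step (illustrative sketch).

    p_target: (K+1, vocab) target-model distributions from one parallel scoring pass
    q_draft:  (K, vocab) draft-model distributions used to sample draft_tokens
    draft_tokens: the K token ids proposed by the draft model
    Returns between 1 and K+1 accepted token ids.
    """
    accepted = []
    for i, x in enumerate(draft_tokens):
        p, q = p_target[i][x], q_draft[i][x]
        if rng.random() < min(1.0, p / q):            # accept with probability min(1, p/q)
            accepted.append(x)
        else:
            # on rejection, resample from the residual distribution max(0, p - q)
            residual = np.maximum(p_target[i] - q_draft[i], 0.0)
            residual /= residual.sum()
            accepted.append(int(rng.choice(len(residual), p=residual)))
            return accepted                           # stop at the first rejection
    # all K drafts accepted: sample one bonus token from the target model
    accepted.append(int(rng.choice(p_target.shape[1], p=p_target[-1])))
    return accepted
```

A call such as speculative_step(p_target, q_draft, draft_tokens, np.random.default_rng()) would therefore return at least one token per target-model pass, which is where the speedup comes from.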
Group 10 Presentation: Accelerating Large Language Model Decoding with Speculative Sampling
Presented by:
Danyang Zhao
Summaries of key points
Goal: A speculative sampling algorithm is proposed to accelerate the decoding process of large language models.
Background: Traditional Transformer decoding is slow and costly, and existing methods cannot effectively improve generation speed.
Methodology: A small draft model is used to generate a token sequence of length k, and the logits of k+1 tokens are computed in parallel with the target large model. The modified rejection sampling method is used to decide whether to accept the draft token.
Result: On Chinchilla, the output quality is almost unaffected, and the generation speed is significantly improved.
Constructive critiques or reviews
The presentation can be more detailed and provide more examples to help you understand.
Increase fluency and reduce pause time.
While the presenter explained the concept verbally, graphical models and diagrams would help viewers better understand the model. For instance, in addition to providing the algorithm pseudocode, a diagram of the transformer model with the latency bottleneck highlighted would show why the proposed parallel sampling approach improves speed.
Clear explanations to aid understanding
The probabilities assigned by the target model and the draft model are compared to decide whether to accept a token.
The output should not deviate from the target model distribution.
Compared with distillation, speculative sampling leaves the target model unchanged, so the acceleration is direct.
Group 10 Presentation: Accelerating Large Language Model Decoding with Speculative Sampling
Presented by
Danyang Zhao
Paper Citation
C. Chen, S. Borgeaud, G. Irving, J.-B. Lespiau, L. Sifre, and J. Jumper, ‘Accelerating Large Language Model Decoding with Speculative Sampling’, Feb. 02, 2023, arXiv: arXiv:2302.01318. doi: 10.48550/arXiv.2302.01318. https://arxiv.org/abs/2302.01318
Background
- In Transformer-based language models, decoding is typically done autoregressively, and one token is selected at a time based on previously generated tokens. This is computationally expensive and memory bandwidth bound, meaning the process is limited by the speed at which the data can be transferred rather than compute itself.
- The goal of the paper is to reduce the latency of decoding without modifying the original language model, and they introduce speculative sampling.
- Speculative sampling is a method that is used to generate multiple tokens in parallel, rather than one at a time, to speed up inference.
Summaries of Key Points
- The speculative sampling process involves 3 steps. The first step is draft generation. This is when a lightweight draft model generates a sequence of k tokens in parallel. Then, the target model, which is typically larger than the draft model, is used to score the draft tokens. After these tokens are scored, the method accepts or rejects each draft token using a modified rejection sampling algorithm, ensuring the final output matches the distribution of the target model. This sampling scheme is provably correct since the output distribution does match that of the target model and is not just an approximation.
- Speculative sampling achieves up to 2-3x speedups in decoding times without loss in generation quality, especially when using efficient draft models like small LMs or knowledge distillation.
- Chinchilla was used as the target model, and speculative sampling was compared with standard autoregressive sampling, showing significant latency reduction at little to no cost in output quality.
Group 10 Presentation: Accelerating Large Language Model Decoding with Speculative Sampling
Presented by
Danyang Zhao
Paper Citation
C. Chen, S. Borgeaud, G. Irving, J.-B. Lespiau, L. Sifre, and J. Jumper, ‘Accelerating Large Language Model Decoding with Speculative Sampling’, Feb. 02, 2023, arXiv: arXiv:2302.01318. doi: 10.48550/arXiv.2302.01318. https://arxiv.org/abs/2302.01318
Objective
Transformer sampling is typically memory-bandwidth bound: the time needed is proportional to the number of parameters and the size of the memory. Prior studies have proposed quantisation, distillation, and smaller models to address this issue, and a cache of keys and values is maintained for every attention layer. However, these works did not focus on speeding up decoding itself. This paper focuses on reducing the latency of transformer decoding without altering the capabilities of the original model.
Summaries of Key Points
Speculative sampling works by generating a short draft of length K, then scoring the draft with the target model, which is the model we wish to sample from. Finally, a modified rejection sampling scheme is used to recover the distribution of the target model. When applied to Chinchilla, speculative sampling greatly reduces decoding time compared with standard autoregressive sampling.
Group 10 Presentation: Accelerating Large Language Model Decoding with Speculative Sampling
How Does Speculative Sampling Work?
1. The Basic Idea: "Drafting" Tokens
Imagine you're completing someone's sentence. Usually, you might correctly guess a few words ahead because they're predictable. The idea behind speculative sampling is similar. A smaller, quicker "draft" model predicts multiple tokens (words) rapidly, guessing what the large model would likely say next.
2. Confirming with the "Big Model"
After the smaller model drafts several tokens, the bigger, smarter (but slower) model checks if these guesses align with its own predictions. If the draft guesses match closely enough, the large model accepts them immediately.
3. Efficiently Handling Mistakes
If the draft is wrong—maybe it guesses something improbable or off-topic—the large model rejects those tokens. But rather than starting from scratch, it quickly generates a replacement token that accurately reflects its original distribution. This clever mechanism ensures accuracy stays high, and no incorrect information slips through.
Real-World Results: Surprisingly Fast, Equally Accurate
The DeepMind team tested speculative sampling with Chinchilla, a popular 70-billion-parameter language model. They used two tasks:
Text summarization (XSum): Condensing lengthy articles into short summaries.
Code generation (HumanEval): Writing accurate Python code from natural-language descriptions.
The results? Speculative sampling doubled the speed of Chinchilla’s token generation, and in some cases, it was up to 2.5 times faster—all without losing quality.
The secret sauce here isn't magic; it's that the smaller model can swiftly handle simple predictions while the larger model just verifies and fills in any blanks.
Why Does This Matter?
Real-Time AI Experiences: If AI assistants could respond faster, interactions would feel more natural, allowing seamless conversations without frustrating delays.
Cost Efficiency: Faster token generation means saving computation time, directly reducing the costs associated with running large models in commercial and research contexts.
No Need for Retraining: Speculative sampling doesn’t require adjusting or retraining the big, expensive model, making it practical to implement quickly with existing setups.
Critique and Considerations
While speculative sampling seems like a powerful approach, there are a few things worth noting:
1. Domain Dependence: The method shines brightest when token predictions are straightforward or repetitive—like structured code. However, for less predictable or more creative text, the speedup might be smaller, as the smaller draft model might guess incorrectly more frequently, increasing overhead.
2. Choosing the Right Draft Model: Selecting a suitable draft model is critical. The draft needs to be good enough to ensure high accuracy but small enough to be quick. Picking the optimal size and architecture for this secondary model can be nuanced.
3. Variance and Latency: As more speculative tokens are generated at once, the total latency per step increases, potentially adding variability in response times. This may be problematic for applications sensitive to latency variations.
Connections to Other Techniques
Speculative sampling complements existing optimization strategies:
Quantization and Distillation: These techniques compress large models into smaller, faster versions. Unlike speculative sampling, they require retraining or altering the model itself. Combining speculative sampling with quantization could further amplify performance improvements.
Parallel Decoding Techniques: Previous approaches like "blockwise parallel decoding" similarly attempt to generate tokens in groups rather than one-by-one. However, they often require substantial architectural changes. Speculative sampling’s elegance lies in its simplicity—no changes to existing models required.
Group 11 Presentation: Simple Linear Attention Language Models Balance the Recall-Throughput Tradeoff
Presented by:
Yiyuan Yang, Anar Kuatzhan, Chuan Zhang
Paper Citation
Arora, S., Eyuboglu, S., Zhang, M., Timalsina, A., Alberti, S., Zinsley, D., ... & Ré, C. (2024). Simple linear attention language models balance the recall-throughput tradeoff. arXiv preprint arXiv:2402.18668. https://arxiv.org/pdf/2402.18668
Background
Large language models still struggle with efficiency. In attention-based models, attention requires a rapidly growing number of calculations as the input gets longer, and it stores every previous token in memory, which makes it memory-intensive. Newer language models can generate text faster while maintaining low perplexity, but low perplexity does not necessarily mean good recall: gated convolutional models, for example, struggle with recall, whereas attention-based models excel at recall tasks. To balance these concerns, the authors introduce the Based model.
Technical Contributions
The authors introduce a new architecture called BASED, which is designed to address the recall–throughput tradeoff in language models. It combines two key components:
- Linear Attention (Global Context):
- Uses a second-order Taylor approximation of softmax attention.
- Enables constant-size memory state during autoregressive decoding.
- Sliding Window Attention (Local Context):
- Performs exact attention within a small local window (e.g., 64–128 tokens).
- Captures short-range dependencies with high precision.
Memory-Recall Tradeoff: Observed both within and across architecture classes.
Performance with Fixed Recurrent State: Not all architectures have the same recall capacity. Mamba optimally utilizes limited memory budgets while convolutional architectures underperform with memory constraints.
The Based model combines local fine-grained attention + long-range linear attention via Taylor approximation of softmax exponential function that are sub-quadratic complexity during training and permit an efficient recurrent inference view. Based outperforms prior sub-quadratic architectures in recall quality by up to 6.2 accuracy points.
Hybrid Layer Composition for Enhanced Recall and Efficiency BASED is constructed as a hybrid architecture composed of approximately 20% linear attention layers, 20% sliding window attention layers, and 60% gated convolution layers. This combination leverages the precision of local token comparison (via sliding window and short convolutions) with the global context capabilities of linear attention. The inclusion of short gated convolution layers (e.g., filter size 3) helps model local dependencies that small sliding windows might miss, improving the architecture’s overall recall ability. This hybrid design, detailed in Appendix E.1 of the paper, enables BASED to outperform other sub-quadratic models like Mamba in both recall-intensive tasks and generation throughput benchmarks.
Architecture
Softmax-approximating linear attention (applied globally) + exact softmax attention with sliding windows (applied locally)
This combination achieves 90.8% of full softmax attention's recall accuracy while reducing latency by a factor of 100,000.
Accomplishments
- Improved Recall:
- BASED outperforms Mamba by up to 10.36 accuracy points on recall-intensive tasks.
- Recovers over 90% of softmax attention’s recall performance while using significantly less memory.
- High Throughput:
- Achieves up to 24× higher generation throughput compared to FlashAttention-2.
- Competitive wall-clock time due to efficient CUDA kernel design.
- Strong Language Modeling:
- Matches or surpasses models like Mamba and Hyena in perplexity and downstream task accuracy.
- Theoretical Contributions:
- Demonstrates that recurrent models require [math]\displaystyle{ \Omega(N) }[/math]-bit memory to perform associative recall.
- Proves that BaseConv, a gated convolution model, cannot solve recall tasks in constant layers.
- Shows that the recall-throughput tradeoff is theoretically fundamental.
Group 11 Presentation: Simple Linear Attention Language Models Balance the Recall-Throughput Tradeoff
Presented by:
Yiyuan Yang, Anar Kuatzhan, Chuan Zhang
Paper Citation
Arora, S., Eyuboglu, S., Zhang, M., Timalsina, A., Alberti, S., Zinsley, D., ... & Ré, C. (2024). Simple linear attention language models balance the recall-throughput tradeoff. arXiv preprint arXiv:2402.18668. https://arxiv.org/pdf/2402.18668
Introduction
"Simple Linear Attention Language Models Balance the Recall-Throughput Tradeoff" is a peer which introduces BASED, an architecture designed to enhance the efficiency of language models by balancing memory consumption and recall abilities. This approach combines Linear Attention with Sliding Window Attention to navigate the tradeoff between state size and recall.
Methodology
The researchers analyzed various architectures to understand the tradeoff between a model's state size and its recall ability. They observed that efficient alternatives to attention, such as H3, Mamba, and RWKV, maintain a fixed-size recurrent state but exhibit limitations in recall performance. To address this, they proposed the BASED architecture, which combines linear attention with sliding window attention. By adjusting the window size and the feature dimension of the linear attention, BASED can navigate the Pareto frontier of the recall-memory tradeoff, effectively balancing recall quality and state size.
Empirical results
The study trained language models with up to 1.3 billion parameters and found that BASED matches the perplexity of leading sub-quadratic models like Mamba. Furthermore, BASED outperformed these models on real-world recall-intensive tasks by 6.22 accuracy points. Additionally, the implementation of input/output-aware algorithms enabled BASED to achieve 24 times higher throughput in language generation compared to FlashAttention-2 when generating 1,024 tokens using 1.3 billion parameter models.
Mathematical Explanation of BASED Architecture
(1) Sliding Window Attention -- SWA
SWA computes attention over a fixed-size window of previous tokens, capturing local dependencies. For a window of size [math]\displaystyle{ w }[/math], the attention for a token [math]\displaystyle{ t }[/math] considers only the tokens in [math]\displaystyle{ [t-w, t-1] }[/math].
Given queries, keys, and values [math]\displaystyle{ Q \in \mathbb{R}^{n \times d} }[/math], [math]\displaystyle{ K \in \mathbb{R}^{n \times d} }[/math], [math]\displaystyle{ V \in \mathbb{R}^{n \times d} }[/math] for a sequence of length [math]\displaystyle{ n }[/math] and hidden dimension [math]\displaystyle{ d }[/math], the attention output [math]\displaystyle{ A }[/math] is computed as follows:
[math]\displaystyle{ A_t = \text{softmax} \Bigg( \frac{Q_t K_{t-w : t-1}^{T}}{\sqrt{d}} \Bigg) V_{t-w : t-1} }[/math]
Where [math]\displaystyle{ A_t }[/math] is the attention output at position [math]\displaystyle{ t }[/math].
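As a quick illustration of the formula above, the following NumPy sketch computes causal sliding-window attention with a naive reference loop. Shapes are assumed, the current token is included in the window for simplicity, and no attempt is made at the blocked, hardware-efficient implementation used in practice.

```python
import numpy as np

def sliding_window_attention(Q, K, V, w):
    """Naive causal sliding-window attention (reference loop, not optimized).

    Q, K, V: (n, d) arrays. Position t attends only to the previous w tokens
    (the current token is included here as a simplifying convention).
    """
    n, d = Q.shape
    out = np.zeros_like(V, dtype=float)
    for t in range(n):
        lo = max(0, t - w)
        scores = Q[t] @ K[lo:t + 1].T / np.sqrt(d)    # local attention scores
        weights = np.exp(scores - scores.max())       # numerically stable softmax
        weights /= weights.sum()
        out[t] = weights @ V[lo:t + 1]
    return out
```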
(2) Linear Attention
Linear attention approximates standard attention mechanism to capture global dependencies with reduced computational complexity. It redefines the attention operation to be linear in the sequence length using feature maps [math]\displaystyle{ \phi }[/math] to project queries and keys.
The linear attention output is computed as:
[math]\displaystyle{ A = \phi(Q) \big( \phi(K)^T V \big) }[/math]
Where [math]\displaystyle{ \phi }[/math] is a feature map function applied to the queries and keys. This formulation allows the attention computation to be rearranged and optimized, reducing the complexity to [math]\displaystyle{ O(n) }[/math].
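The sketch below illustrates this factorization in NumPy. The feature map follows the second-order Taylor approximation of the softmax exponential described for BASED, with constant scaling factors omitted (an assumption); the key point is that [math]\displaystyle{ \phi(K)^T V }[/math] is a fixed-size summary reused by every query, so the cost is linear in sequence length. The non-causal form is shown for clarity; the causal, recurrent version instead maintains running sums of the same quantities.

```python
import numpy as np

def taylor2_feature_map(X):
    """phi(x) = [1, x, vec(x x^T)/sqrt(2)], so phi(q).phi(k) = 1 + q.k + (q.k)^2 / 2."""
    n, d = X.shape
    outer = (X[:, :, None] * X[:, None, :]).reshape(n, d * d) / np.sqrt(2.0)
    return np.concatenate([np.ones((n, 1)), X, outer], axis=-1)

def linear_attention(Q, K, V, phi=taylor2_feature_map):
    """Non-causal linear attention A = phi(Q) (phi(K)^T V), normalized per query."""
    Qf, Kf = phi(Q), phi(K)
    kv = Kf.T @ V                    # (D, d) global summary shared by all queries
    z = Kf.sum(axis=0)               # (D,) normalizer term
    return (Qf @ kv) / (Qf @ z)[:, None]
```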
(3) Combining Sliding Window Attention and Linear Attention
BASED integrates SWA and linear attention to leverage the strength of both methods. SWA captures fine-grained local dependencies while Linear Attention models long-range dependencies.
By simply adjusting the sliding window size [math]\displaystyle{ w }[/math] and the feature dimension [math]\displaystyle{ d }[/math] in linear attention, BASED can navigate the tradeoff between memory consumption and the ability to recall information. A larger [math]\displaystyle{ w }[/math] enhances local context capture but increases memory usage, whereas a higher [math]\displaystyle{ d }[/math] improves global context understanding with minimal memory overhead.
Application and Performance (brief overview)
In this paper, BASED models were trained on up to 1.3 billion parameters and evaluated on tasks requiring high recall. This included tasks in information extraction and reading comprehension. The architecture demonstrated performance matching or in some instances surpassing other sub-quadratic models such as MAMBA. Notably, it excelled in recall-intensive situations.
Implementations of linear attention often lag behind optimized standard attention in efficiency. To address this, the authors developed I/O aware algorithms, enabling BASED to achieve 24x higher throughput in language generation compared to other methods such as Flash-Attention-2.
Conclusion
The BASED architecture offers a pragmatic solution to the recall and throughput tradeoff in language models by combining sliding window and linear attention mechanisms.
This integration has allowed for efficient handling of both local and global dependencies. Subsequently this has resulted in models that are both memory efficient and able to perform high recall tasks, thereby advancing the development of more efficient LLM techniques.
BASED is the first linear attention model shown to match or beat Mamba on:
- Perplexity (language modeling quality)
- Real-world recall benchmarks (copying, in-context learning)
This challenges the growing belief that attention-free models like Mamba are the most scalable path forward.
Group 11 Presentation: Simple Linear Attention Language Models Balance the Recall-Throughput Tradeoff
Presented by:
Yiyuan Yang, Anar Kuatzhan, Chuan Zhang
Paper Citation
Arora, S., Eyuboglu, S., Zhang, M., Timalsina, A., Alberti, S., Zinsley, D., ... & Ré, C. (2024). Simple linear attention language models balance the recall-throughput tradeoff. arXiv preprint arXiv:2402.18668. https://arxiv.org/pdf/2402.18668
Background & Motivation
Transformer-based language models rely on attention mechanisms that require storing increasing amounts of key-value pairs (KV-cache) during inference. This makes them memory-intensive and less suitable for real-time or resource-constrained applications. The paper investigates whether it's possible to reduce memory usage while maintaining strong contextual recall capabilities—hence the "recall-throughput tradeoff."
Methodology
1. Based Architecture:
- Linear Attention: Uses a second-order Taylor approximation of softmax to maintain global token interactions with a fixed-size recurrent state.
- Sliding Window Attention (SWA): Applies exact softmax attention locally in small windows (64-128 tokens) to handle precise local shifts. This combination allows Based to navigate the recall-throughput tradeoff effectively.
- IO-Aware Optimizations: Custom CUDA kernels reduce memory movement and improve hardware efficiency, enabling 24× higher throughput than FlashAttention-2 during generation.
- The BASED Model: BASED is proposed as a simple and efficient linear attention architecture. Its defining traits include:
- Linear complexity with respect to sequence length.
- No KV-cache required during inference, unlike transformers.
- Introduces a memory state updated recurrently across tokens.
- Achieves long-range context modeling using a fixed-size memory state.
This makes BASED models more efficient for both training and inference, especially in streaming or real-time settings.
2. Theoretical and Empirical Analysis:
- Lower Bounds: The paper proves that any recurrent model requires Ω(N)-bits in state size to solve associative recall, highlighting the fundamental tradeoff.
- Achieves up to 24× higher generation throughput compared to FlashAttention-2.
- Empirical Results: Experiments on synthetic and real-world tasks (e.g., MQAR, Pile perplexity, information extraction) show Based outperforms Mamba by 10.36 accuracy points on recall-intensive tasks while matching its perplexity.
Experimental Results
- BASED models match or outperform standard transformer models on various language modeling benchmarks such as WikiText-103 and PG-19.
- They show strong performance on long-context tasks, including copy and retrieval tasks, indicating good memory recall.
- BASED demonstrates superior throughput, especially in inference without KV caching.
Key Findings
- Efficient memory usage and fast inference are achievable without sacrificing much performance.
- Linear attention models like BASED can serve as a viable alternative to transformers in memory-constrained or latency-sensitive applications.
- There exists a tradeoff surface between recall and throughput, and BASED models lie on an efficient frontier of that tradeoff.
Conclusion
Based expands the Pareto frontier of the recall-throughput tradeoff by combining simple, well-known techniques (linear and sliding window attention) with hardware-aware optimizations. The results suggest that efficient models can achieve high recall without sacrificing throughput, offering a promising direction for future language model architectures.
Group 12 Presentation: EAGLE-2: Faster Inference of Language Models with Dynamic Draft Trees
Presenters
Mutong Zhang, Hanqing Bi
Paper Citation
Y. Li, F. Wei, C. Zhang, and H. Zhang, ‘EAGLE-2: Faster Inference of Language Models with Dynamic Draft Trees’, Jun. 30, 2024, arXiv: arXiv:2406.16858. doi: 10.48550/arXiv.2406.16858.
https://arxiv.org/abs/2406.16858
Background
- LLMs to date have shown great performance, but they are slow and computationally expensive
- Speculative sampling - small model generates candidate tokens, whereas large model then evaluates those tokens, reducing the number of times the expensive computations of the large model have to occur
- However, speculative decoding methods suffered from low acceptance rates of draft tokens as they lacked efficient ways to predict which draft tokens would be accepted (non-efficient computation).
- EAGLE 1 used a draft tree to improve performance of speculative execution, and this builds on the author's prior work
- The authors observe that acceptance rates are also context-dependent, suggesting the need for adaptive drafting.
Paper Contributions
- EAGLE 1 doesn't directly predict tokens; rather, it maps tokens to features, predicts the next features, and then predicts tokens from those features
- This work was shown to improve upon previous speculative execution methods
- EAGLE uses a tree structure to propose alternative tokens when a speculative sampling draft token is rejected by the full model
- EAGLE 2 noted that token acceptance is dependent on context and position. The first token has a high acceptance rate, while later tokens have lower acceptance rates
- EAGLE 2 uses a dynamic draft tree that incorporates "tree attention" to bring context information into selecting the next candidate tokens, increasing their acceptance rate, since acceptance depends on context and not only on position
Summaries of Key Points
Addressing the Challenge of Slow LLM Inference
Large language models (LLMs) have revolutionized natural language processing, but their inference remains computationally expensive and slow. This bottleneck arises due to the vast number of model parameters and the sequential nature of token generation. Improving inference efficiency without sacrificing accuracy is a critical research direction in the field.
The Foundation: Speculative Sampling and Draft Models
EAGLE-2 builds on speculative sampling, a technique that leverages a smaller, lightweight draft model to propose candidate tokens, which are then verified by the full LLM. This approach speeds up inference by reducing the number of computations performed by the large model.
Evolution from EAGLE-1 to EAGLE-2
The original EAGLE-1 introduced a static draft tree, assuming that token acceptance rates depend solely on their position within the tree structure. However, this assumption overlooks the context-dependent nature of token acceptance. EAGLE-2 addresses this limitation by introducing a dynamic draft tree adjustment mechanism, which refines speculative sampling based on real-time confidence scores.
Key Mechanisms of EAGLE-2
1. Instead of directly predicting tokens, EAGLE-2’s draft model generates feature representations, which are then processed by the head of the LLM to produce token predictions. This method enhances accuracy while maintaining efficiency.
2. EAGLE-2 employs a tree-based verification mechanism, which is computationally more efficient than the standard chain-structured verification used in traditional speculative sampling. By verifying multiple candidate tokens in parallel, it accelerates the overall process.
3. A core innovation in EAGLE-2 is its ability to dynamically modify the draft tree based on confidence scores from the draft model. This ensures that speculative sampling adapts to varying contexts, improving token acceptance rates and overall inference speed.
Dynamic Expansion and Re-Ranking
1. Expansion Phase: EAGLE-2 introduces a novel expansion phase, where a tree-attention mechanism processes all tokens in a layer simultaneously. This significantly enhances efficiency compared to sequential processing. Additionally, selective expansion prioritizes only the top-k tokens with the highest estimated global acceptance probabilities, preventing unnecessary computational overhead.
2. Re-Ranking Phase: Following expansion, EAGLE-2 reranks candidate tokens by selecting the top-m tokens with the highest acceptance probabilities. In cases where multiple tokens have similar scores, shallower nodes in the draft tree are prioritized, further optimizing verification speed.
During the reranking phase, EAGLE-2 ensures that the most promising tokens are selected not only based on their depth in the draft tree but also on their overall likelihood of being accepted. Since deeper nodes in the draft tree tend to have lower values due to the multiplication of multiple acceptance probabilities, some shallow nodes—though not expanded in the previous phase—may have higher values. To optimize performance, EAGLE-2 reranks all candidate tokens (including both shallow and deep nodes) and selects the top m tokens with the highest values. Importantly, when multiple tokens share the same value, shallower nodes are prioritized to preserve tree connectivity and maximize verification efficiency. This strategy improves both the average acceptance length and the speedup ratio, as confirmed by the ablation studies.
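A rough sketch of the expansion and re-ranking logic is given below. The draft_model.propose interface and the node bookkeeping are hypothetical, and a node's value is approximated by the product of draft-model confidences along its path, as described above; this is a sketch of the idea, not the authors' code.

```python
import heapq

def build_dynamic_draft(draft_model, context, depth, top_k, top_m):
    """Rough sketch of EAGLE-2-style expansion and re-ranking (assumed interfaces).

    draft_model.propose(context, path) is a hypothetical call returning candidate
    (token, confidence) pairs for the next position. A node's value is the product
    of confidences along its path, a cheap proxy for its global acceptance probability.
    """
    nodes = []
    frontier = [(1.0, 0, [])]                    # each node: (value, depth, token path)
    for _ in range(depth):
        # expansion: only the top_k most promising nodes of this layer are grown
        next_frontier = []
        for value, d, path in heapq.nlargest(top_k, frontier):
            for tok, conf in draft_model.propose(context, path):
                child = (value * conf, d + 1, path + [tok])
                nodes.append(child)
                next_frontier.append(child)
        frontier = next_frontier
    # re-ranking: keep the top_m nodes overall; ties prefer shallower nodes
    draft = sorted(nodes, key=lambda n: (-n[0], n[1]))[:top_m]
    return draft                                 # these drafts are then verified by the target model
```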
Experimental Results: Significant Performance Gains
EAGLE-2 achieved acceleration ratios of 3.05x - 4.26x across various tasks and large language model series such as Vicuna, Llama 2 and Llama 3, making it 20% - 40% faster than EAGLE-1. It is also approximately 2 times faster than Medusa and 2.3 times faster than Lookahead. On token throughput, EAGLE-2 processes 4-5.5 tokens per verification cycle, about twice as many as traditional speculative sampling.
Key Advantages of EAGLE-2
1. Plug-and-Play Efficiency – EAGLE-2 requires no additional model training, as it seamlessly integrates the pre-trained draft model from EAGLE-1. It does not alter the original LLM, and maintains the exact same output distribution as greedy decoding (i.e., it is lossless).
2. Robust and Reliable – Unlike some acceleration techniques, EAGLE-2 does not modify the original model parameters or relax acceptance conditions, ensuring stable and consistent outputs.
3. Broad Generalization – The framework generalizes well across different tasks and architectures, demonstrating strong adaptability in diverse applications.
EAGLE-2 represents a significant advancement in accelerating LLM inference. By introducing a dynamic draft tree, efficient expansion strategies, and intelligent token re-ranking, it substantially reduces computational costs while maintaining accuracy. As large-scale models continue to grow, techniques like EAGLE-2 will be instrumental in making LLMs more practical and accessible for real-world applications.
Accomplishments
- State-of-the-art inference speedup:
- EAGLE-2 achieves up to 5× speedup in language model inference.
- Outperforms prior speculative decoding methods including EAGLE, Medusa, Lookahead, and others.
- Longest average acceptance length:
- EAGLE-2 generates longer sequences per accepted draft, reducing the number of calls to the target model.
- Wide applicability:
- Tested across 6 diverse tasks, including:
- Conversation
- Code generation
- Mathematical reasoning
- Question answering
- Summarization
- Instruction following
- Model-agnostic design:
- Compatible with popular LLMs such as:
- Vicuna
- LLaMA2-Chat
- LLaMA3-Instruct
Group 12 Presentation: EAGLE-2: Faster Inference of Language Models with Dynamic Draft Trees
Presenters
Mutong Zhang, Hanqing Bi
Paper Citation
Y. Li, F. Wei, C. Zhang, and H. Zhang, ‘EAGLE-2: Faster Inference of Language Models with Dynamic Draft Trees’, Jun. 30, 2024, arXiv: arXiv:2406.16858. doi: 10.48550/arXiv.2406.16858.
https://arxiv.org/abs/2406.16858
Constructive Critique and Review
The paper “EAGLE-2: Faster Inference of Language Models with Dynamic Draft Trees” introduces an innovative approach to enhancing the efficiency of Large Language Model (LLM) inference through the implementation of a context-aware dynamic draft tree.
This paper makes two main contributions to the field:
1). Dynamic Draft Tree Structure
Building upon the original EAGLE framework, EAGLE-2 replaces the static draft tree with a dynamic architecture that adapts to context. This adjustment acknowledges that the acceptance rate of draft tokens is influenced not only by their position, but also by the surrounding context, which leads to more efficient token generation.
2). Utilization of Well Calibrated Draft Models
The paper also reveals that the draft model's confidence scores closely approximate the acceptance rates of draft tokens. By leveraging this calibration, EAGLE-2 effectively predicts which tokens are more likely to be accepted, optimizing the drafting process and token generation.
Performance Outcomes
Extensive evaluations have been conducted across three series of LLMs (Vicuna, LLaMA2-Chat, and LLaMA3-Instruct), as well as on six diverse tasks, including multi-turn conversation, code generation, and mathematical reasoning.
The results reveal that EAGLE-2 achieves speedup ratios of 3.05x to 4.26x, nearly a 20% to 40% improvement over EAGLE-1. Notably, this acceleration is achieved without altering the distribution of the generated text, ensuring the fidelity of the model's outputs.
Advancements within the field
EAGLE-2 makes significant advancements in the realm of LLM inference optimization. By introducing a context-aware dynamic draft tree, the paper addresses the limitations of EAGLE-1's speculative sampling architecture, whose draft tree is static in nature.
This innovation enhances the acceptance rate of draft tokens, thereby reducing inference latency and computational costs. Additionally, the approach maintains the integrity of the generated text distributions which distinguishes itself from other acceleration models that have compromised outputs.
Conclusion
The methodologies and findings presented in this paper offer a substantial contribution to the field of ML and most notably in optimizing the efficiency of the generative LLMs. The introduction of dynamic, context-aware drafting mechanisms sets a new benchmark for speculative sampling techniques paving the way for more responsive, fast, and cost-effective LLM applications.
Group 12 Presentation: EAGLE-2: Faster Inference of Language Models with Dynamic Draft Trees
Presenters
Mutong Zhang, Hanqing Bi
Paper Citation
Y. Li, F. Wei, C. Zhang, and H. Zhang, ‘EAGLE-2: Faster Inference of Language Models with Dynamic Draft Trees’, Jun. 30, 2024, arXiv: arXiv:2406.16858. doi: 10.48550/arXiv.2406.16858.
https://arxiv.org/abs/2406.16858
Summaries
EAGLE-2 improves speculative sampling by building on EAGLE. It constructs a dynamic tree as the draft structure based on context (acceptance rates). To reduce the cost of computing acceptance rates, the draft model's confidence score, which EAGLE already produces, is used to approximate the acceptance rate. EAGLE-2 requires no extra training, and there is no loss in the distribution of generated text compared with the original LLM, since the original LLM is not modified.
Key Contributions
EAGLE-2 follows the same procedure as EAGLE. EAGLE-2 constructs a tree draft model based on context (acceptance rate), where the acceptance rate is approximated by the confidence score from EAGLE.
Expansion Phase: The best tokens (top-k tokens) from the last layer are used to predict tokens in the next layer.
Re-ranking: Tokens with the highest acceptance rates (top-m tokens) are selected and used in the verifying phase.
Metrics: Speedup ratio and average acceptance length are used to evaluate the model.
Explanations of details
Why a dynamic tree is used
The acceptance rates of tokens at different positions were tested. It was noted that acceptance rates were higher in the upper-left part of the tree and lower in the lower-right part. Also, acceptance rates at the same position varied significantly across contexts. Therefore, the authors concluded that it is worthwhile to build a dynamic, context-based tree.
Why use an approximation of acceptance rates
To calculate the real acceptance rates of tokens, a forward pass through the original LLM is needed, which is costly. The authors found a positive relationship between the draft model's confidence score and the acceptance rate.
Related Works
For LLMs, the existing works that try to accelerate LLM inference include: low-bit quantization (Hubara et al., 2018; Shen et al., 2020; Kim et al., 2021; Zadeh et al., 2020; Zafrir et al., 2019), pruning (Gale et al., 2019; Sanh et al., 2020), distillation (Hinton et al., 2015). These methods often decrease the quality of the output.
Speculative sampling (Leviathan et al., 2023; Chen et al., 2023a) tries to reduce the costs in the decoding process while preserving the output quality. It involves two phases: the generating phase that generates multiple tokens and the verifying phase. It uses a chain-structured draft model.
EAGLE improved speculative sampling by autoregressively predicting features and using both features and tokens to reduce uncertainty. EAGLE uses a static tree structure as the draft model. Medusa (Cai et al., 2024) also uses a tree structure. They all pick candidates at specific steps.
Group 12 Presentation: EAGLE-2: Faster Inference of Language Models with Dynamic Draft Trees
Presenters
Mutong Zhang, Hanqing Bi
Paper Citation
Y. Li, F. Wei, C. Zhang, and H. Zhang, ‘EAGLE-2: Faster Inference of Language Models with Dynamic Draft Trees’, Jun. 30, 2024, arXiv: arXiv:2406.16858. doi: 10.48550/arXiv.2406.16858.
Background
Large language models (LLMs) are changing the game in AI, but they come with a huge downside: they take a lot of time and computing power to run. Making them faster and cheaper without losing quality is a major challenge. One popular way to speed things up is speculative decoding, where a smaller, faster model makes guesses to help the bigger model generate text more efficiently. While this approach works, it still struggles with getting the right balance between speed, accuracy, and ease of use.
Main Idea
EAGLE 2 is a new and improved speculative decoding method that makes LLMs way faster without sacrificing quality. The trick is using a small, efficient model to suggest multiple possible next words at the same time, then having the big model quickly check and accept as many as possible. Unlike older speculative decoding techniques that focus on just one guess at a time, EAGLE 2 smartly boosts the number of accepted tokens using a better verification process. It also fine-tunes how guesses are made and rejected, cutting down on wasted computations and speeding things up even more.
Experiments
The researchers ran a bunch of tests to see how well EAGLE 2 performs across different LLMs and datasets. They looked at:
• How it compares to standard autoregressive decoding and older speculative decoding methods.
• How well it works on general NLP tasks like text generation and summarization, as well as more specific datasets.
• Key metrics like speed improvements, how often the model accepts the suggested words, and how good the final output is.
Results
EAGLE 2 delivered big improvements in speed without lowering output quality. Key takeaways include:
• Higher Acceptance Rates: More of the small model’s suggested words were accepted, meaning fewer wasted computations.
• Faster Performance: Compared to the usual autoregressive decoding method, EAGLE 2 sped things up by 2 to 4 times.
• Same Great Quality: Even with the increased efficiency, the text output was just as good as before, based on standard evaluation methods.
Conclusion
EAGLE 2 makes large language models run much faster while keeping their responses accurate and high-quality. This makes it a strong candidate for real-world applications where speed and cost matter. With further improvements, it could become even more efficient and be applied to even larger-scale AI systems.
Group 12 Presentation: EAGLE-2: Faster Inference of Language Models With Dynamic Draft Trees
Presented by
Mutong Zhang and Hanqing Bi
Paper Citation
Y. Li, F. Wei, C. Zhang, and H. Zhang, ‘EAGLE-2: Faster Inference of Language Models with Dynamic Draft Trees’, Jun. 30, 2024, arXiv: arXiv:2406.16858. doi: 10.48550/arXiv.2406.16858. https://arxiv.org/abs/2406.16858
Background
- Large Language Models (LLMs) are computationally intensive during inference due to their substantial parameter sizes, leading to high latency and resource consumption.
- Speculative sampling is a technique aimed at accelerating LLM inference by generating multiple tokens in parallel. This technique was employed in EAGLE-1, a model that uses a static draft tree, where token acceptance is predetermined by position. However, this approach overlooks the context-dependent nature of token acceptance rates.
- EAGLE-2 addresses these limitations by introducing a dynamic adjustment mechanism for the draft tree, and it enhances efficiency by adapting to contextual variations
Summaries of Key Points
- EAGLE-2 recognizes that token acceptance rates are influenced by context, not just position, and it performs context-aware adaptation to dynamically adjust the draft tree. By leveraging the draft model’s confidence scores, which closely approximate actual acceptance probabilities, the draft tree is adjusted to focus computational resources where they are most effective.
- The dynamic draft tree is constructed through expansion and re-ranking phases that selectively generate and prioritize token branches based on their likelihood of acceptance, improving efficiency and reducing redundant computation.
- In terms of performance, EAGLE-2 achieves significant speedups without compromising the integrity of the generated text, maintaining the original output distribution. Additionally, it resulted in an increase in throughput, processing almost twice as many tokens per cycle compared to traditional methods. Ablation studies demonstrate that EAGLE-2 offers higher speedups and better acceptance rates than EAGLE-1, showing improvement by use of the dynamic nature of the draft tree.
- The advantages of EAGLE-2 are that no additional training is required and there is consistency in text generation even when using accelerated inference. EAGLE-2 also exhibits consistent performance improvements across a diverse range of tasks.
Explanation of Expansion and Re-ranking phases
- To construct the dynamic draft tree efficiently, the EAGLE-2 model employs an expansion phase and a re-ranking phase.
- In the expansion phase, a tree attention mechanism guides the selective growth of the tree by identifying and expanding the most promising nodes. It uses confidence scores to prioritize branches with higher token acceptance probabilities. This stage ensures computational resources are focused on likely candidates, reducing unnecessary expansion.
- In the re-ranking phase, draft tokens are re-evaluated using the target model, and the top-m tokens with the highest probabilities are selected to ensure optimal token generation sequences. This preserves generation quality while maintaining consistency in token ordering and ensuring child nodes follow their parent nodes in the final output sequence.
Group 13 Presentation: Linear Attention Mechanism: An Efficient Attention for Semantic Segmentation
Presented By
Yuke Liu, Mei Si
Paper Citation
R. Li, J. Su, C. Duan, and S. Zheng, ‘Linear Attention Mechanism: An Efficient Attention for Semantic Segmentation’, Aug. 20, 2020, arXiv: arXiv:2007.14902. doi: 10.48550/arXiv.2007.14902.
https://arxiv.org/abs/2007.14902
Background
- Existing transformer models have [math]\displaystyle{ \mathcal{O} (n^2) }[/math] complexity in the input length, which is problematic as inputs grow
- This limits the growth of model and input sizes due to computational resource constraints
- This paper focused on an alternative method to conventional dot product attention that is more computationally efficient
- Standard attention requires the computation of [math]\displaystyle{ Q K^\top }[/math], which has [math]\displaystyle{ \mathcal{O} (n^2) }[/math] complexity
- The paper proposes a linear attention mechanism that solves the problem while keeping the performance.
Technical Contributions
Rather than doing the full computation for the softmax in the transformer architecture, the authors instead compute
[math]\displaystyle{ D(\mathbf{Q}, \mathbf{K}, \mathbf{V})_i = \frac{\sum_{j=1}^{N} e^{\mathbf{q}_i^{T} \mathbf{k}_j} \mathbf{v}_j}{\sum_{j=1}^{N} e^{\mathbf{q}_i^{T} \mathbf{k}_j}} = \frac{\sum_{j=1}^{N} \text{sim}(\mathbf{q}_i, \mathbf{k}_j) \mathbf{v}_j}{\sum_{j=1}^{N} \text{sim}(\mathbf{q}_i, \mathbf{k}_j)} }[/math]
and define the transformation function as
[math]\displaystyle{ \text{sim}(\mathbf{q}_i, \mathbf{k}_j) = \phi(\mathbf{q}_i)^{T} \phi(\mathbf{k}_j) }[/math]
The authors apply a first-order Taylor series expansion, and after some rearranging and substitution arrive at their final formula (full derivation not shown):
[math]\displaystyle{ D(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \frac{ \sum_j \mathbf{v}_{i,j} + \left( \frac{\mathbf{Q}}{\lVert \mathbf{Q} \rVert_2} \right) \left( \left( \frac{\mathbf{K}}{\lVert \mathbf{K} \rVert_2} \right)^{T} \mathbf{V} \right) }{ N + \left( \frac{\mathbf{Q}}{\lVert \mathbf{Q} \rVert_2} \right) \sum_j \left( \frac{\mathbf{K}}{\lVert \mathbf{K} \rVert_2} \right)_{i,j}^{T} } }[/math]
The authors' form of the attention mechanism can be computed in [math]\displaystyle{ \mathcal{O} (n) }[/math] complexity, removing the quadratic scaling of conventional transformers and enabling larger models built on the underlying attention mechanism.
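A minimal NumPy sketch of this final formula is shown below (non-causal case, assumed (n, d) shapes): row-wise L2 normalization keeps the first-order similarity sim(q, k) = 1 + (q/||q||)·(k/||k||) non-negative, and reordering the sums yields cost linear in the sequence length.

```python
import numpy as np

def taylor_linear_attention(Q, K, V):
    """First-order-Taylor linear attention as in the formula above (sketch only)."""
    n = Q.shape[0]
    Qn = Q / np.linalg.norm(Q, axis=-1, keepdims=True)   # row-wise L2 normalization
    Kn = K / np.linalg.norm(K, axis=-1, keepdims=True)
    numerator = V.sum(axis=0) + Qn @ (Kn.T @ V)          # (n, d): constant term + interaction term
    denominator = n + Qn @ Kn.sum(axis=0)                # (n,): matching normalizer
    return numerator / denominator[:, None]
```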
- Hybrid Attention Mechanisms for Selective Efficiency
Idea: Combine linear attention with local or sparse attention in a hybrid structure where the system dynamically chooses the best mechanism depending on context (e.g., object density, resolution, or modality). The rationale is that the dense areas of an image (e.g., urban environments) may benefit from sparse or local attention for spatial precision. Less complex regions (e.g., sky, ocean) can rely on linear attention for efficiency. It could be implemented with a gating module that selects or blends attention types.
Model Performance Evaluation
The model performance evaluation primarily focused on assessing how well the proposed linear attention mechanism enhances semantic segmentation performance while maintaining computational efficiency.
- Dataset: The experiments were conducted using the Fine Gaofen Image Dataset (GID), which consists of high-resolution satellite images. This dataset was selected for its complex landscapes and varied environmental conditions, which present significant segmentation challenges.
- Data Preprocessing: Due to the large size of the images in the dataset, each image was divided into smaller patches of 256x256 pixels, resulting in a total of 7,280 patches (a minimal patch-extraction sketch follows this list).
- Data Split: These patches were methodically partitioned into 60% for training, 20% for validation, and 20% for testing. This split ensured a rigorous evaluation of the model's performance across different scenarios.
- The standard dot-product attention is replaced by the linear attention mechanism in the baseline segmentation models such as U-Net, Res101, DeepLab, PSPNet, etc
- Implementation Details: The experiments were implemented using PyTorch framework and trained on an NVIDIA RTX 2080Ti GPU.
- Evaluation Metrics: OA, AA, K, mIoU, F1.
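For illustration, the patch-extraction step described above (non-overlapping 256x256 tiles) might look like the following sketch; the exact stride and border handling used in the paper's experiments are not specified here and may differ.

```python
import numpy as np

def extract_patches(image, patch=256):
    """Split a large (H, W, C) image into non-overlapping patch x patch tiles (illustrative only)."""
    h, w = image.shape[:2]
    return [image[i:i + patch, j:j + patch]
            for i in range(0, h - patch + 1, patch)
            for j in range(0, w - patch + 1, patch)]
```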
Comparative Analysis
The proposed linear attention mechanism achieves a similar level of accuracy in semantic segmentation tasks compared to traditional dot product attention.
Linear attention maintains comparable performance metrics while drastically reducing both memory footprint and processing time required.
The efficiency gains of linear attention become more apparent in large-scale segmentation tasks, making it suitable for high-resolution images and long text sequences.
Future work includes extending the linear attention mechanism to more complex network designs, integrating with state-of-the-art deep learning models, and optimizing for real-time processing in demanding applications.
Method | OA | AA | K | mIoU | F1 |
---|---|---|---|---|---|
U-Net | 86.378 | 74.532 | 83.357 | 64.516 | 75.532 |
U-Net LAM | 87.692 | 77.297 | 84.935 | 68.038 | 78.593 |
Res101 | 89.251 | 80.451 | 86.846 | 72.433 | 82.510 |
Res101 LAM | 90.178 | 82.757 | 88.041 | 74.085 | 83.105 |
RefineNet | 89.857 | 81.169 | 87.597 | 73.167 | 83.113 |
RefineNet LAM | 90.214 | 83.544 | 88.083 | 74.973 | 84.311 |
DeepLab | 89.388 | 80.905 | 87.079 | 71.809 | 81.077 |
DeepLab LAM | 89.576 | 81.692 | 87.287 | 72.827 | 82.702 |
DeepLabV3+ | 90.125 | 81.483 | 87.959 | 72.668 | 81.492 |
DeepLabV3+ LAM | 90.315 | 81.695 | 88.182 | 73.727 | 82.736 |
PSPNet | 90.573 | 82.211 | 88.485 | 74.797 | 83.761 |
PSPNet LAM | 90.725 | 83.088 | 88.677 | 75.695 | 84.480 |
FastFCN | 90.336 | 83.625 | 88.221 | 74.364 | 83.704 |
FastFCN LAM | 90.835 | 83.075 | 88.769 | 75.174 | 84.023 |
Constructive Critiques and Reviews
The presenters thoroughly explained the precursor to the proposed Linear Attention Mechanism and its potential in computer vision applications. However, it would be helpful to show individual samples as well, showcasing any differences in visual accuracy (if any) and where this approach may struggle in other computer vision tasks. While the proposed solution clearly brings advantages, it is often the case that certain sub-tasks in computer vision benefit while others show inferior results.
Group 13 Presentation: Linear Attention Mechanism: An Efficient Attention for Semantic Segmentation
Presented By
Yuke Liu, Mei Si
Paper Citation
R. Li, J. Su, C. Duan, and S. Zheng, ‘Linear Attention Mechanism: An Efficient Attention for Semantic Segmentation’, Aug. 20, 2020, arXiv: arXiv:2007.14902. doi: 10.48550/arXiv.2007.14902.
https://arxiv.org/abs/2007.14902
Introduction
(1) Attention Mechanism
The attention mechanism gained prominence for its capacity to refine feature maps and to capture long-range dependencies in both computer vision and natural language processing. By focusing computational resources on the most relevant portions of the input data, attention mechanisms have proven especially valuable in tasks that demand context from large input sizes—such as high-resolution images or lengthy text sequences.
(2) Dot Product Attention
A widely used form of attention is the dot product attention. It forms the core of many state-of-the-art models, including Transformers in NLP and non-local modules in computer vision. Despite its strengths, dot product attention scales poorly (quadratically) with input size in both memory and computational cost.
(3) Problem Statements
Because the dot product attention mechanism requires [math]\displaystyle{ O(N^2) }[/math] operations for inputs of length [math]\displaystyle{ N }[/math] , deploying it on large inputs—e.g., high-resolution images or long sequences—becomes infeasible. This computational bottleneck has motivated research into more efficient variants of attention, particularly approaches that reduce the quadratic cost.
Overview of Attention Mechanismn
(1) Scaling Attention
In contrast to dot product attention, “scaling” attention mechanisms (sometimes also called “channel” or “spatial” attention in the literature) aim to emphasize informative features while reducing redundant information. Examples include: Squeeze-and-Excitation (SE) modules, which learn channel-wise scaling factors. Convolutional Block Attention Module (CBAM), which extends SE by considering both channel and spatial attention. These approaches do not necessarily learn long-range dependencies across positions in the same way as dot product attention; instead, they focus on enhancing or diminishing particular channels or spatial locations.
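As a point of contrast with dot product attention, a minimal PyTorch-style sketch of an SE-style channel attention block is shown below; it is illustrative only, not code from the paper or the original SE work. It learns per-channel scaling factors rather than position-to-position dependencies.

```python
import torch
import torch.nn as nn

class SEBlock(nn.Module):
    """Minimal squeeze-and-excitation style block (illustrative sketch)."""

    def __init__(self, channels, reduction=16):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels),
            nn.Sigmoid(),
        )

    def forward(self, x):                    # x: (batch, channels, H, W)
        scale = x.mean(dim=(2, 3))           # squeeze: per-channel global statistics
        scale = self.fc(scale)               # excitation: learned channel gates in (0, 1)
        return x * scale[:, :, None, None]   # reweight channels, leave positions untouched
```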
(2) Linear Attention: Taylor Expansion → Linear Time
Recent research efforts address the [math]\displaystyle{ O(N^2) }[/math] complexity of dot product attention by re-formulating the similarity function (the key step in attention). One family of methods uses kernel-based approximations, while another employs mathematical expansions. By approximating the exponential term of the softmax function (commonly used in dot product attention) with a first-order Taylor expansion, it is possible to make attention computations linear with respect to [math]\displaystyle{ N }[/math]. This insight forms the basis for linear attention mechanisms.
(3) Semantic Segmentation
Semantic segmentation requires dense predictions over every pixel of an image, and capturing global context is crucial. Traditional convolution-based networks, whether they follow a DilatedFCN (e.g., DeepLab, PSPNet) or an Encoder-Decoder (e.g., U-Net) architecture, benefit significantly from attention modules that can encode relationships across spatial locations. However, if the resolution is high, dot product attention quickly becomes prohibitively expensive, motivating more efficient variants like linear attention.
Dot Product Attention Details
(1) Query, Key, and Value
Dot product attention transforms each input feature [math]\displaystyle{ x_i }[/math] into three different representations:
Query([math]\displaystyle{ q_i }[/math]) Key([math]\displaystyle{ k_i }[/math]) Value([math]\displaystyle{ v_i }[/math])
The core operation then measures similarity between every query–key pair to weight the contribution of all values in forming an output.
(2) Kernel-Based Approximation
A known strategy to reduce complexity is to view the exponential term in dot product attention as operating in a kernel space, thereby factorizing the softmax attention to achieve lower complexity. Various works replace the explicit softmax with kernel-based transformations, enabling a more efficient computation.
Linear Attention Mechanism
Linear attention leverages a Taylor expansion to approximate the exponential term. By applying additional constraints (such as [math]\displaystyle{ L_2 }[/math] normalization to keep the term non-negative), the resulting formulation scales as [math]\displaystyle{ O(N) }[/math] rather than [math]\displaystyle{ O(N^2) }[/math]. This makes the mechanism practical for high-resolution inputs or very long sequences, significantly broadening the usability of attention in semantic segmentation and beyond.
Experimental Settings
(1) Dataset
Many experiments on linear attention for segmentation are conducted using large, high-resolution datasets. One common benchmark highlighted is a satellite imagery dataset (e.g., Fine Gaofen Image Dataset, GID). These datasets typically comprise large aerial images that are partitioned into patches, with splits for training, validation, and testing.
(2) Model Implementations
Baseline segmentation networks (e.g., PSPNet, DeepLab series, U-Net variants, FastFCN, and RefineNet) integrate the proposed linear attention modules in place of, or alongside, standard attention mechanisms. The training setup often employs standard optimizers (such as Adam), cross-entropy loss, and hardware accelerators (like NVIDIA GPUs).
(3) Evaluation Metrics
Common segmentation metrics include: Overall Accuracy (OA), Average Accuracy (AA), Kappa Coefficient (K), mean Intersection over Union (mIoU), and F1 Score.
Evaluation metrics
Mean Intersection over Union (mIoU)
[math]\displaystyle{ IOU = \frac{TP}{TP+FP+FN} }[/math]
[math]\displaystyle{ mIoU = 1/N \sum_{i=1}^{N} IoU_i }[/math]
- Measures the average overlap between predicted segmentation and ground truth across all classes.
- Commonly used in semantic segmentation benchmarks.
Why It’s Important in This Paper:
- Serves as a primary benchmark for assessing spatial accuracy of LAM.
- Especially meaningful in pixel-level classification, where precise boundaries matter.
- Used to compare LAM-enhanced networks against traditional attention-based and baseline CNN models.
- LAM achieves competitive or improved mIoU across multiple medical image segmentation datasets (DRIVE, STARE, CHASE_DB1), validating its contextual understanding with reduced computational cost.
Kappa Coefficient
[math]\displaystyle{ \kappa = \frac{p_0 - p_e}{1-p_e} }[/math]
- Assesses statistical agreement between predicted segmentation and ground truth, adjusting for agreement by chance
Why It’s Important in This Paper:
- Medical segmentation tasks often suffer from class imbalance (e.g., small vessels vs large background).
- Kappa offers a robust metric that accounts for imbalanced distributions, unlike plain accuracy.
- The paper reports high Kappa values for LAM-based models, showing that they make meaningful predictions beyond chance, even when foreground (vessel) pixels are rare.
- This supports that LAM is not just "overfitting" to majority classes but learns class-relevant structure effectively.
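To make the two metrics above concrete, here is a small NumPy sketch that computes mIoU and Cohen's kappa from flattened label arrays; the function name and interface are illustrative assumptions.

```python
import numpy as np

def miou_and_kappa(pred, target, num_classes):
    """Compute mean IoU and Cohen's kappa from flat integer label arrays."""
    # Confusion matrix: rows = ground truth, columns = prediction.
    cm = np.zeros((num_classes, num_classes), dtype=np.int64)
    for t, p in zip(target.ravel(), pred.ravel()):
        cm[t, p] += 1

    tp = np.diag(cm).astype(float)
    fp = cm.sum(axis=0) - tp
    fn = cm.sum(axis=1) - tp
    iou = tp / np.maximum(tp + fp + fn, 1)   # guard against empty classes
    miou = iou.mean()

    total = cm.sum()
    p0 = tp.sum() / total                                      # observed agreement
    pe = (cm.sum(axis=0) * cm.sum(axis=1)).sum() / total**2    # chance agreement
    kappa = (p0 - pe) / (1 - pe)
    return miou, kappa
```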
Results and Experimental Improvement
In benchmark experiments, linear attention typically boosts the performance of various baseline segmentation models while reducing memory and computational overhead. For example, when embedded into U-Net, PSPNet, or DeepLab, linear attention can achieve a higher mIoU and Kappa coefficient compared to the original dot product attention. These gains confirm that the approximation introduced by the Taylor expansion still captures sufficient global context for accurate segmentation.
Comparative Analysis
Limitations:
(1) Approximation Error: The first-order Taylor expansion introduces approximation errors relative to exact dot product attention. Although experiments show minimal performance degradation, in certain tasks or extreme input scales, further refinements or higher-order terms might be necessary.
(2) Architecture Constraints: Integrating linear attention can require modifications to the existing network design, including normalization steps and careful parameter initialization.
Conclusion and Future Work
Linear attention mechanisms substantially reduce both memory usage and computational cost in high-resolution vision tasks, making them a promising alternative to dot product attention. Future research may involve:
Extending linear attention to multi-modal domains where extremely large inputs (e.g., video or multi-spectral data) are common.
Investigating higher-order approximations that may yield even more accurate results while retaining near-linear scalability.
Combining linear attention with other efficient modules (e.g., lightweight convolutions or quantization techniques) to further push the boundaries of real-time segmentation on resource-constrained devices.
Group 13 Presentation: Linear Attention Mechanism – An Efficient Attention for Semantic Segmentation
Presented By
Yuke Liu, Mei Si
Paper Citation
R. Li, J. Su, C. Duan, and S. Zheng, *Linear Attention Mechanism: An Efficient Attention for Semantic Segmentation*, arXiv: [2007.14902](https://arxiv.org/abs/2007.14902)
---
What's the Problem?
Transformers are powerful but **computationally expensive**—especially when working with large inputs like high-resolution images or long sequences. The root issue is their **quadratic complexity** in sequence length due to dot product attention. This makes scaling hard and slows things down.
This paper tackles that bottleneck by proposing a more efficient method: **linear attention**, which keeps performance strong but dramatically cuts the cost.
---
Key Idea: Linear Attention Instead of Dot Product
Dot product attention works, but it’s expensive. This paper introduces a smarter way by:
- Approximating attention using a **first-order Taylor expansion** of the softmax function.
- Rewriting the attention calculation to **scale linearly** with input size.
- Enabling **O(N)** performance instead of **O(N²)**.
The authors also suggest a **hybrid attention system** that can flexibly switch between local and global attention depending on context (like object density or resolution). For example, city scenes may need precise local attention; ocean images might not.
---
Technical Details (In a Nutshell)
- Instead of computing the full softmax in standard attention, the authors reformulate the attention matrix using linear operations.
- They apply kernel-based tricks to restructure the attention, reducing compute without sacrificing performance.
- The result? Faster attention that still captures essential relationships.
Final formula (simplified):

[math]\displaystyle{ D(Q, K, V) = \frac{\sum_j v_j + \left( \frac{K_i}{\sum_l K_l} \right)^{T} V}{N + \left( \frac{K_i}{\sum_l K_l} \right)^{T} K} }[/math]
Don’t worry if it looks complicated—just know it’s faster and scales better.
---
Dataset & Experimental Setup
- Dataset: Gaofen Image Dataset (GID) with large aerial satellite images.
- The dataset was split 60/20/20 into train/val/test sets.
- Patch size: 256x256, total of 7,280 patches.
- Used standard segmentation models like U-Net, Res101, PSPNet, DeepLab, etc.
- Training on NVIDIA RTX 2080Ti using PyTorch.
---
Metrics Used
- **OA (Overall Accuracy)**
- **AA (Average Accuracy)**
- **Kappa Coefficient** (accounts for chance agreement)
- **mIoU (mean Intersection over Union)** – most important for segmentation
- **F1 Score**
---
How Well Does It Perform?
Here’s what the authors found:
- **Competitive Accuracy**: Linear attention performs on par with dot product attention in almost all models.
- **Better Efficiency**: Requires less time and memory, especially useful for high-res inputs.
- **Scalability**: Works better in larger models or longer sequences.
- **Flexible Integration**: Easily replaces attention in many architectures with minimal tweaks.
---
Pros of Linear Attention
- Efficient: Great for long inputs or high-res images.
- Accurate: Keeps precision high in segmentation tasks.
- Flexible: Works across many architectures.
- Practical: Easy to implement in PyTorch.
---
Limitations & Trade-offs
- **Approximation error**: Taylor expansions are not exact, so there is some performance loss in rare cases.
- **Architecture tweaks**: Some models may require re-tuning (normalization, initialization, etc.) to use linear attention smoothly.
---
Final Thoughts & What’s Next
This paper shows that linear attention can be a drop-in, compute-efficient alternative to traditional attention for semantic segmentation. It’s especially promising for real-time systems or edge devices where every GPU cycle counts.
Future directions include:
- Extending to multi-modal inputs like video or multi-spectral images.
- Testing higher-order Taylor approximations for even better accuracy.
- Combining linear attention with other modules like lightweight convolutions or quantization.
Group 14 Presentation: Scalable Watermarking for Identifying Large Language Model Outputs
Presented by:
Ryan Tymkow and Benjamin Schnapp
Paper Citation
Dathathri, S., See, A., Ghaisas, S. et al. Scalable watermarking for identifying large language model outputs. Nature 634, 818–823 (2024). https://doi.org/10.1038/s41586-024-08025-4.
Summaries of key points
This paper tackles the problem of watermarking—embedding a detectable signal into the outputs of large language models (LLMs) to distinguish generated text from human-written content. The challenge is doing this in a way that is robust, scalable, and minimally intrusive to the model's output quality. The authors propose a method based on statistical biasing of token selection during generation. Specifically, they partition the vocabulary into “greenlist” and “redlist” tokens at each step based on a seeded hash function, and then bias sampling toward the greenlist using a tunable parameter. The watermark is invisible in individual outputs but detectable over longer texts using hypothesis testing. Importantly, the approach is model-agnostic, doesn’t require retraining, and adds very little computational overhead. It also scales well to large models and can be applied in high-throughput or real-time settings. Overall, it’s a lightweight yet effective strategy for watermarking that balances detectability, scalability, and usability.
A key limitation of this approach is that it may still be vulnerable to paraphrasing or text transformation—simple rewriting could break the statistical signature. Another concern is adversarial robustness: since the watermarking method is relatively transparent (based on vocabulary partitioning), a knowledgeable attacker could design strategies to erase or spoof the signal. Additionally, while the method maintains fluency and quality in most cases, biasing token selection could subtly affect stylistic or semantic nuances, especially in creative writing or long-form tasks. The paper doesn’t deeply explore how this might influence user-facing applications like chatbots or summarizers. Lastly, while the watermark is statistically detectable, it’s not embedded in a cryptographic sense, so it may not offer strong guarantees in high-stakes verification contexts.
Clear explanations to aid understanding
Imagine if every time a language model generated text, it subtly preferred certain words over others—but in a way that's invisible to readers. That’s what this watermark does. At each generation step, the model hashes its current context to choose a “greenlist” of preferred tokens and then slightly boosts their probabilities. Over many words, these choices form a statistically detectable pattern. It's like nudging a roulette wheel so certain numbers come up just a bit more often—not enough to be obvious, but enough to spot if you know where to look. The method is efficient and easy to integrate, since it works at the sampling level and doesn't require modifying the model architecture.
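As a rough illustration of the greenlist-biasing idea described in this summary (not the production SynthID implementation), a sketch of one biased sampling step could look like the following; the hash construction and the parameters gamma, delta, and key are illustrative assumptions.

```python
import hashlib
import torch

def greenlist_biased_sampling(logits, context_ids, vocab_size,
                              gamma=0.5, delta=2.0, key=42):
    """Sketch of greenlist-style watermarking: hash the context to seed a
    vocabulary split, then boost 'green' tokens by delta before sampling."""
    seed_material = f"{key}-{context_ids[-1]}".encode()
    seed = int(hashlib.sha256(seed_material).hexdigest(), 16) % (2**31)

    gen = torch.Generator().manual_seed(seed)
    perm = torch.randperm(vocab_size, generator=gen)
    greenlist = perm[: int(gamma * vocab_size)]      # preferred tokens

    biased = logits.clone()
    biased[greenlist] += delta                        # nudge toward the greenlist
    probs = torch.softmax(biased, dim=-1)
    return torch.multinomial(probs, num_samples=1).item(), greenlist
```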
Review
I thought that the visual aids in this presentation greatly helped explain the process of how synthid alters the sampling process and how it embeds watermarks during the process, helping distinguish outputs created by the LLM vs outputs not created by the LLM. The graphic showing the three main parts of the process, the random seed generator, the sampling algorithm, and the scoring function, made it simpler to understand the whole process. The examples with the fruits, first generating random watermarking functions, then going through the tournament to sample the output token, also made it really easy to follow along on what exactly is going on.
Regarding some suggestions, perhaps including multiple examples contrasting the watermarked output with the normal sampling output would help viewers better picture the elegance of the process. With regards to the limitations of this approach, perhaps encoding over a longer horizon (multiple words to sentences) would make it more robust, although in the end it could never be foolproof, given the fundamental nature of written language.
Group 14 Presentation: Scalable Watermarking for Identifying Large Language Model Outputs
Presented by:
Ryan Tymkow and Benjamin Schnapp
Paper Citation
Dathathri, S., See, A., Ghaisas, S., et al. Scalable watermarking for identifying large language model outputs. Nature, 634, 818–823 (2024). https://doi.org/10.1038/s41586-024-08025-4
Key Summary
This paper addresses the challenge of watermarking outputs from large language models (LLMs)—embedding subtle, detectable signals to distinguish machine-generated text from human-written content. The goal is to design a method that is robust, scalable, and minimally intrusive to output quality.
The authors propose a token biasing technique that works as follows:
At each generation step, a seeded hash function is used to partition the vocabulary into a “greenlist” (preferred tokens) and a “redlist.”
The model then slightly biases sampling towards the greenlist using a tunable parameter.
The watermark is invisible in short outputs, but detectable across longer texts using statistical hypothesis testing.
Key strengths:
Model-agnostic: No retraining required.
Low overhead: Minimal impact on inference speed or quality.
Scalable: Suitable for large models and real-time systems.
Practical: Balances watermark strength with natural text generation.
Limitations and Concerns
Vulnerability to paraphrasing: Rewriting could disrupt the statistical signature.
Adversarial robustness: Knowledgeable attackers could potentially remove or spoof the watermark.
Stylistic influence: May subtly affect semantic or creative output, especially in long-form or artistic applications.
Not cryptographically secure: Lacks strong guarantees for high-stakes verification or forensic scenarios.
Clear Explanation for Intuition
Think of the watermark as a gentle push in the model’s word choice—like favoring certain dice rolls in a way that’s imperceptible up close but noticeable when you observe many outcomes.
Each time the model generates a word, it computes a hash of the current context to decide which words get a slight boost.
Over time, this creates a hidden statistical pattern, detectable if you know what to test for.
It’s efficient, subtle, and works entirely during token sampling—no changes to the model itself.
Presentation Review
The presentation effectively demystified the watermarking process. Visual aids were especially helpful:
Diagrams showing the random seed generator, sampling mechanism, and scoring function clarified the full pipeline.
The fruit-based examples for token selection and tournament-style sampling made the technical process intuitive and engaging.
Suggestions for Improvement:
Include side-by-side examples of outputs with and without watermarking to better illustrate the subtlety and strength of the method.
Explore watermarking across longer horizons (e.g., multi-token or sentence-level biasing) to improve robustness—though this tradeoff remains an open challenge due to the fluid nature of language.
Group 14 Presentation: Scalable Watermarking for Identifying Large Language Model Outputs
Presented by:
Ryan Tymkow, Benjamin Schnapp
Paper Citation
Dathathri, S., See, A., Ghaisas, S. et al. Scalable watermarking for identifying large language model outputs. Nature 634, 818–823 (2024). https://doi.org/10.1038/s41586-024-08025-4
Background
With the rise of LLMs and generative AI, there is a risk of AI-generated misinformation spreading as if it were written by a human, so it is important to be able to distinguish AI-generated text from human writing. Solutions exist to address this problem, but they all have limitations involving privacy and computational cost. For example, traditional watermarking can introduce unwanted artifacts into the text.
Previous Approaches
Some of the existing approaches are Retrieval Based Tracking, Post-Hoc Detection, and Traditional Watermarking.
Retrieval Based Tracking stores generated LLM responses in a reference database to determine whether a newly generated response is from the tracked LLM. The issue with this is scalability and privacy since you are storing generated responses.
Post-Hoc Detection uses statistical features of LLM-generated text. The issues with this are the high computational cost and the fact that, as LLM outputs become increasingly human-like, this approach becomes more and more difficult.
Traditional watermarking embeds hidden features in the text, for example by replacing words with synonyms or inserting special Unicode characters. However, tinkering with the output in this way can degrade the quality of the original LLM's text.
Technical Contributions
This paper introduced SynthID-Text, a watermarking method for large language models (LLMs). The method uses a Tournament sampling approach, which ensures that generated text contains a detectable watermark with minimal computational overhead. SynthID-Text incorporates a random seed generator and scoring functions to embed a watermark into the model's output. This technique enhances the ability to identify whether text originates from a specific LLM, while preserving the text's quality and minimizing distortion. SynthID-Text does not affect LLM training. It can be configured as distortionary or non-distortionary.
Central to SynthID-Text is the novel Tournament sampling procedure. Rather than sampling each token directly from the LLM's distribution, multiple candidate tokens compete in a multi-layer "tournament" based on pseudorandom watermarking scores, embedding a statistical “signature” that can later be detected.
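As a rough sketch of what such a multi-layer tournament could look like (the interfaces candidate_sampler, seed_fns, and g_functions are assumptions for illustration, not the paper's exact algorithm):

```python
import random

def tournament_sample(candidate_sampler, g_functions, seed_fns, context):
    """Illustrative tournament sampling: draw 2^m candidates from the model's
    distribution, then run m knockout layers where the token with the higher
    watermarking score g_l advances."""
    m = len(g_functions)
    candidates = [candidate_sampler(context) for _ in range(2 ** m)]
    for layer, g in enumerate(g_functions):
        r = seed_fns[layer](context)              # pseudorandom seed per layer
        winners = []
        for a, b in zip(candidates[0::2], candidates[1::2]):
            # Higher watermark score wins; ties are broken at random.
            if g(a, r) > g(b, r):
                winners.append(a)
            elif g(b, r) > g(a, r):
                winners.append(b)
            else:
                winners.append(random.choice([a, b]))
        candidates = winners
    return candidates[0]
```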
Results
SynthID-Text achieved a 95% true positive detection rate with a 1% false positive rate in a 20 million interaction test on Google's Gemini chatbot. It offers high detection accuracy with minimal computational cost and can be configured for non-distortionary or distortionary watermarking.
The benefits are as follows:
Minimal impact on large language model training:
-SynthID-Text can be applied to any large language model with minimal modifications, as it only alters the sampling step.
High detection accuracy with low computational cost:
-Outperforms retrieval-based tracking, post hoc detection, and traditional watermarking methods.
-Offers the best balance between computational cost, scalability, and accuracy.
-Can be integrated into production environments using speculative sampling, where smaller models suggest tokens and the main model verifies their validity.
Configurable distortion levels:
-Allows for distortionary or non-distortionary configurations, enabling better control over the quality of generated text versus detection accuracy.
-In non-distortionary watermarking, the average token distribution of the generated text matches the original model's token distribution.
Combining Watermarking with Speculative Sampling:
-SynthID-Text supports integration with speculative sampling to accelerate LLM inference.
-It introduces two configurations: high-detectability and fast watermarked speculative sampling.
-High-detectability configuration preserves watermark strength but may reduce generation speed.
-Fast configuration maintains speed and requires non-distortionary watermarking.
-The fast version uses a learned Bayesian scoring function to improve detectability.
-This enables efficient deployment of watermarking in real-world production systems.
-The approach balances performance, speed, and watermark traceability.
Group 14 Presentation: Scalable watermarking for identifying large language model outputs
Presenters
Ben Schnapp, Ryan Tymkow
Paper Citation
Dathathri, S., See, A., Ghaisas, S. et al. Scalable watermarking for identifying large language model outputs. Nature 634, 818–823 (2024). https://doi.org/10.1038/s41586-024-08025-4
Introduction
Researchers at Google DeepMind have introduced a new method for watermarking the text generated by large language models (LLMs). Their work, published in Nature, details "SynthID-Text," a watermarking scheme designed to be practical for use in production systems.
Background
LLMs are now capable of generating high-quality text that can be difficult to distinguish from human-written content. This raises concerns about the potential misuse of LLMs and the need for methods to identify AI-generated text. Existing methods for detecting AI-generated text have limitations, including computational cost, privacy concerns, and inconsistent performance. Text watermarking offers a potential solution by embedding a signal within the generated text that can be used for identification.
Main Idea
The authors developed SynthID-Text, a generative watermarking scheme that modifies the LLM's sampling procedure to embed a watermark in the generated text. This approach allows for efficient watermark detection without requiring access to the underlying LLM. SynthID-Text introduces a novel "Tournament sampling" algorithm to achieve this.
Experiments
The researchers evaluated SynthID-Text across multiple LLMs and conducted a live experiment with nearly 20 million Gemini responses to assess its performance in a real-world setting. They compared SynthID-Text with existing watermarking methods, focusing on text quality, detectability, and computational efficiency.
Results
The key findings of the paper are:
- SynthID-Text preserves text quality, as confirmed by human evaluations and automated metrics.
- SynthID-Text provides improved watermark detectability compared to other methods.
- SynthID-Text has minimal computational overhead, making it suitable for large-scale production systems.
- The authors also developed an algorithm to integrate watermarking with speculative sampling, a technique used to speed up LLM text generation.
In conclusion, the paper presents SynthID-Text as a practical and effective solution for watermarking LLM-generated text, addressing key challenges in the responsible use of this technology.
Group 14 Presentation: Scalable watermarking for identifying LLM outputs:
Presented by:
Ben Schnapp,& Ryan Tymkow
Overview
With AI, specifically language models, becoming increasingly prolific, there are strong incentives driving the development of technologies capable of discerning when text has been either fully or partly generated by a machine. At present, two methodologies dominate this space: post-hoc detection and watermarking. Post-hoc detection describes an attempt to train a classifier with the intent of detecting text outputted by an LLM. While sometimes successful, it is difficult to determine with absolute certainty whether the accusations made by post-hoc detectors are accurate; further, they often suffer when used on a language outside of their training distribution, or on text from an LLM that was not part of the training set. These limitations are some of the reasons that engineers, developers, and scientists are increasingly considering methods to watermark text at the time of generation, that is, to include some kind of text artifact which readily identifies the text as LLM-generated without negatively impacting the quality of the outputs. The authors of this paper propose a unique method for this task, specifically by means of what they dub their "Tournament sampling" algorithm.
Governing Principles
LLMs typically take in a series of tokens as input, producing a probability distribution over the space of possible next tokens, which is sampled and appended to the input sequence before calling the forward pass again. The goal in watermarking is to imbue this sampling process with some signature unique to the generating model, such that it can be identified in a straightforward manner without the need for a second adversary model. Typically this is accomplished by combining a random seed generator, a sampling algorithm, and a scoring function. This works by combining the random seed with a known proprietary key, and then using this seed/key combination to influence sampling. The result is not deterministic but pseudo-random, so that the behaviour of the LLM remains useful. The scoring function is able to determine, at evaluation time, the probability that the observed tokens were drawn from the intentionally biased distribution. The key development in this paper is the introduction of many scoring functions, so that a "tournament" judged by these scoring functions ultimately decides which token the model selects. Designing a single good scoring function is inherently difficult, so the authors bias the model toward selecting the token which is most identifiable among the viable tokens from the pseudorandom set.
Implementation Details
The key implementation detail worth mentioning is how the scoring function works after tournament sampling. Specifically, tournament sampling ensures that the token which is most likely to score highly under the random watermarking functions is chosen. To detect watermarks, we measure how highly a piece of text scores with respect to all watermarking functions. Specifically, we compute:
[math]\displaystyle{ \text{Score}(x) = \frac{1}{mT} \sum_{t=1}^T \sum_{\ell=1}^m g_\ell (x_t,r_t) }[/math]
where [math]\displaystyle{ x_t }[/math] is the [math]\displaystyle{ t }[/math]-th token and [math]\displaystyle{ r_t }[/math] the corresponding seed derived from the preceding tokens and the watermarking key.
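A direct Python translation of this score might look like the sketch below; the exact form of the scoring functions [math]\displaystyle{ g_\ell }[/math] is left abstract, since it is not specified here.

```python
def watermark_score(tokens, seeds, g_functions):
    """Mean watermarking score over T tokens and m scoring functions,
    mirroring Score(x) = (1 / (m T)) * sum_t sum_l g_l(x_t, r_t).
    g_functions is a list of callables g(token, seed) -> score."""
    T, m = len(tokens), len(g_functions)
    total = sum(g(x_t, r_t) for x_t, r_t in zip(tokens, seeds)
                for g in g_functions)
    return total / (m * T)
```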
Discussion / Conclusions
The key insight from this paper is that tournament sampling does not negatively impact performance from the user's perspective. Because this was a DeepMind paper, the authors had access to production models and were able to run user evaluations of output quality, A/B testing with and without watermarking. Ultimately, they determined that users did not have a preference for non-watermarked text; that is, their watermarking method preserves the quality of model outputs.
Group 14 Presentation: Scalable Watermarking For Identifying Large Language Model Outputs
Presented by
Ryan Tymkow and Benjamin Schnapp
Paper Citation
Dathathri, S., See, A., Ghaisas, S. et al. Scalable watermarking for identifying large language model outputs. Nature 634, 818–823 (2024). https://doi.org/10.1038/s41586-024-08025-4
Background
- Motivation: With the increasing integration of AI in content creation, distinguishing AI-generated text from human-authored content has become crucial to prevent misinformation and ensure content authenticity.
- Existing Solutions: Retrieval-based tracking involves storing all LLM outputs in a database for comparison. However, this method raises privacy concerns and faces scalability challenges. Post-hoc detection utilizes classifiers to detect AI-generated text after production but often suffers from false positives and high computational costs. Finally, traditional watermarking embeds hidden markers by altering the text, but this can degrade output quality.
- SynthID-Text modifies the token selection process during text generation to embed watermarks while preserving the original text’s quality. Unlike traditional methods, it integrates watermarking seamlessly into the generation process without necessitating changes to the LLM’s training.
- The standard LLM generation process starts off with tokenization, where the input text is divided into tokens. Then, the LLM predicts probabilities for the next token based on the input. Last, a token is chosen based on the predicted probabilities, and the process repeats. SynthID-Text embeds watermarks by altering the sampling step, ensuring the watermark is integrated without affecting the text’s coherence.
Summaries of Key Points
- SynthID-Text has 3 key components: A random seed generator, a sampling algorithm, and a scoring function. The random seed generator generates a deterministic seed based on preceding tokens using a hash function, ensuring consistency in the watermarking process without altering the probability distribution. The sampling algorithm then modifies the token selection process to embed the watermark and incorporates a scoring function to prioritize tokens that align with the watermarking scheme. The scoring function evaluates how likely a token is to be a part of the watermarked sequence. It helps to facilitate the detection mechanism by assigning higher scores to watermarked tokens.
- The tournament sampling step involves generating multiple candidate tokens and selecting the most suitable one based on the scoring function, which ensures the watermark is embedded without compromising the naturalness of the text.
- The detection mechanism is used to determine if a given text is AI generated by checking if the given text contains the embedded watermark by analyzing token sequences and their associated scores. Factors such as text length and the entropy of the LLM distribution can influence the detection accuracy.
- A benefit of SynthID-Text is that it does not require modifications to the LLM’s training process. Additionally, it has high detection accuracy, so it can effectively identify watermarked text with minimal false positives. It is also easily configured and can adjust to be distortionary or non-distortionary based on requirements. Another benefit is that it integrates seamlessly into the text generation process without a significant amount of overhead.
Explanation of Tournament Sampling Step
- Watermarking functions guide the selection of tokens during text generation.
- Although the process of selecting tokens is called “random”, these functions are deterministic, ensuring consistency in watermark embedding.
- The process involves evaluating multiple token candidates and selecting the one that best aligns with the watermarking criteria.
- During the watermarking process with SynthID, the LLM generates probabilities for the next token based on preceding tokens. Then, a set of candidate tokens is sampled and evaluated through the tournament sampling process. A watermarking seed is generated using a hash function that’s defined by a random key and the preceding tokens. This ensures the watermark is deterministically embedded into the generated text while preserving its fluency and coherence.
- The tournament sampling step is crucial because it balances the trade-off between preserving the original probability distribution of the model and embedding a detectable watermark, allowing high-quality text generation without compromising detection reliability.
Group 15 Presentation: DiGress: Discrete denoising diffusion for graph generation
Presented by:
Sean Tang, Buji Wong
Paper Citation
Vignac, C., Krawczuk, I., Siraudin, A., Wang, B., Cevher, V., & Frossard, P. (2022). Digress: Discrete denoising diffusion for graph generation. arXiv preprint arXiv:2209.14734.
Background
Graph generation
The goal of this project is to generate graphs, which are represented by node matrices and edge matrices. Edges and nodes can also have their own categories. One application of this is molecule generation: atoms would be represented by nodes and the chemical bonds would be represented by edges.
The challenge of graph generation is a complex task due to the unordered nature and sparsity of graphs. While denoising diffusion models have been successful in other domains like images, they struggle with graphs due to their structural properties. Existing approaches that use continuous noise models disrupt the sparsity and connectivity crucial for graphs.
Discrete diffusion
Regular diffusion methods involve a forward process (where noise is gradually added to the data) and a neural network that is trained to predict the backwards process (where random noise is gradually turned back into something that would plausibly fit in with the dataset). Most of the time, diffusion models use Gaussian noise.
This graph generation problem involves discrete data, therefore, a discrete noise model is preferred for the diffusion process. (Note: The authors actually previously tried to use Gaussian noise, but the continuous nature of the Gaussian function meant that the discreteness of the graphs was not respected.) In this discrete case, however, the diffusion is defined in terms of transition matrices, [math]\displaystyle{ Q^t }[/math]:
[math]\displaystyle{ q(x_t | x_{t-1}) = x_{t-1} Q^t }[/math]
Each transition matrix is applied to each node and edge; this allows the nodes and edges to remain in discrete space.
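As a sketch of how one forward noising step could be implemented (shapes and helper names are illustrative assumptions, not the authors' code):

```python
import torch

def apply_discrete_noise(X, E, Qx, Qe):
    """One forward noising step of a discrete diffusion on a graph.
    X: (n, a) one-hot node types, E: (n, n, b) one-hot edge types,
    Qx: (a, a) and Qe: (b, b) are the step's transition matrices, so that
    x_{t-1} Q^t gives a categorical distribution per node/edge."""
    probs_X = X @ Qx                   # (n, a) categorical distributions
    probs_E = E @ Qe                   # (n, n, b)

    noisy_X = torch.distributions.Categorical(probs=probs_X).sample()   # (n,)
    noisy_E = torch.distributions.Categorical(probs=probs_E).sample()   # (n, n)

    # Return one-hot encodings again so the step can be iterated.
    # (A full implementation would also symmetrize E for undirected graphs.)
    return (torch.nn.functional.one_hot(noisy_X, X.shape[-1]).float(),
            torch.nn.functional.one_hot(noisy_E, E.shape[-1]).float())
```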
Technical Contributions
Overview of DiGress
The authors introduce DiGress, a discrete denoising diffusion model designed specifically for graph generation with categorical node and edge attributes. DiGress improves graph generation by using a discrete noise model that preserves graph sparsity and structural properties. The model involves progressively editing graphs through edge addition/removal and attribute changes. A graph transformer network is used to reverse this noisy process using cross-entropy loss, sampling from the trained model by iteratively updating the noise level and computing structural features.
Key enhancements include a noise model that maintains node and edge type distributions, a guidance procedure for conditioning on graph-level features, and the use of auxiliary graph-theoretic features. DiGress achieves state-of-the-art performance on both molecular and non-molecular graph datasets.
Denoising network architecture
- The input is a noisy graph, expressed as [math]\displaystyle{ X }[/math] (a tensor of one-hot encoded nodes) and [math]\displaystyle{ E }[/math] (a tensor of one-hot encoded edges).
- The structural and spectral features of these graphs are calculated.
- The noisy graph along with the structural and spectral features are taken as input into a multi-layer perceptron (MLP) structure.
- Next is a sequence of graph transformer layers.
- After another MLP layer, the output is a distribution over nodes and edges.
Summary of results
Results showed Digress outperformed continuous diffusion methods on various metrics, including degree distribution, clustering, and novelty, and was more scalable for larger graphs. Moreover, in the creation of novel molecules, discrete diffusion aids scalability for larger graphs and molecules, making it more efficient compared to continuous diffusion. DiGress is the first one-shot graph-based model that feasibly trains on over a million molecules without fragment-based heuristics. Its performance on drug-like molecule benchmarks reaches or exceeds that of specialized autoregressive or SMILES-based baselines.
Accomplishments
- State-of-the-art one-shot graph generation:
- DiGress outperforms existing non-autoregressive models such as GDSS, GraphNVP, and SPECTRE.
- Achieves strong performance on benchmarks including Stochastic Block Models (SBM) and planar graphs.
- Scalable molecular graph generation:
- DiGress is the first diffusion-based model to scale to large molecular datasets such as:
- MOSES (small, drug-like molecules)
- GuacaMol (1.3M large molecules)
- Matches or exceeds autoregressive and fragment-based baselines on metrics such as validity, uniqueness, novelty, and scaffold diversity.
- DiGress is the first diffusion-based model to scale to large molecular datasets such as:
- Fast and efficient performance on QM9:
- Achieves near-perfect:
- Validity: 99%
- Uniqueness: ~96%
- Trains significantly faster than prior diffusion models (1 hour vs. 7.2 hours for ConGress).
- Achieves near-perfect:
- Effective conditional generation:
- Uses regressor-guided sampling to generate molecules with desired properties (e.g., HOMO energy, dipole moment).
- Outperforms unconditional models in property control and accuracy.
- Theoretical soundness:
- Proven permutation equivariance and exchangeability for correct graph generation.
- Provides an ELBO-based training objective for likelihood estimation and model comparison.
Group 14 Presentation: Scalable Watermarking for Identifying Large Language Model Outputs
Presented by:
Ryan Tymkow and Benjamin Schnapp
Paper Citation
Dathathri, S., See, A., Ghaisas, S. et al. (2024). Scalable watermarking for identifying large language model outputs. Nature, 634, 818–823. https://doi.org/10.1038/s41586-024-08025-4
Summaries
This paper addresses the need to reliably detect AI-generated text, particularly from large language models (LLMs), in the context of growing concerns over misinformation, authorship verification, and content authenticity. Existing approaches—such as retrieval-based tracking, post-hoc statistical detection, or traditional watermarking—face challenges related to scalability, privacy, and output degradation.
The authors propose a lightweight and scalable watermarking method that subtly biases the LLM’s sampling process without modifying its architecture or requiring retraining. The method partitions the vocabulary at each generation step into two lists using a seeded hash function:
Greenlist: tokens favored for selection
Redlist: tokens slightly penalized
By biasing the LLM to favor greenlist tokens, a statistically detectable pattern is formed over longer texts. This watermark is invisible in short sequences but detectable through hypothesis testing in larger samples. The approach is model-agnostic and adds minimal overhead, making it suitable for real-time or high-throughput deployment.
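One simple way to make the hypothesis test concrete is a one-sided z-test on the number of greenlist tokens, sketched below under the assumption that unwatermarked text hits the greenlist with probability gamma; the function name and default are illustrative.

```python
from math import sqrt

def greenlist_z_score(num_green, num_tokens, gamma=0.5):
    """z-test sketch for the greenlist watermark: under unwatermarked text,
    each token lands in the greenlist with probability gamma, so a large
    positive z suggests the text is watermarked."""
    expected = gamma * num_tokens
    std = sqrt(num_tokens * gamma * (1 - gamma))
    return (num_green - expected) / std
```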
Key Contributions
Token-Level Watermarking via Sampling Bias: Rather than embedding hidden characters or syntactic artifacts, the method adjusts token selection probabilities via a seeded hash function.
Statistical Detection Over Long Outputs: The watermark is not intended to be seen in any one sentence, but accumulates over time and is verified through statistical analysis.
Model-Agnostic and Training-Free: This method requires no retraining and can be applied to any autoregressive LLM.
Low Computational Cost: Because it operates only at sampling time, the watermarking process adds negligible runtime cost.
Configurable Distortion: The bias strength is adjustable, enabling trade-offs between output quality and detection strength. In non-distortionary mode, the output token distribution closely resembles the unwatermarked model.
Constructive Critiques or Reviews
The watermark can be vulnerable to paraphrasing or text manipulation, which may erase or weaken the statistical signal.
Since the watermarking logic (vocabulary partitioning via hashing) is relatively transparent, adversarial users might reverse-engineer or spoof the watermark.
There is minor concern about stylistic or semantic drift, especially in long-form generation or creative tasks.
The method lacks cryptographic strength, limiting its application in high-assurance or legal verification scenarios.
Related Works
Retrieval-Based Tracking: Stores model outputs in a reference database. Issues: privacy concerns, poor scalability.
Post-Hoc Detection: Classifies AI vs. human text using statistical features. Becomes less effective as LLMs improve and mimic human writing.
Traditional Watermarking: Inserts visible or invisible tokens or character-level perturbations (e.g., synonyms, Unicode tweaks). Problematic due to quality degradation and easy circumvention.
In contrast, this paper’s statistical watermarking strikes a better balance among robustness, usability, and deployment feasibility—particularly suitable for integration into production systems like Google’s Gemini, which reported 95% true positive rate with only 1% false positives over 20 million interactions using this method.
Group 14 Presentation: Scalable Watermarking for Identifying Large Language Model Outputs
Presented by:
Ryan Tymkow and Benjamin Schnapp
Paper Citation
Dathathri, S., See, A., Ghaisas, S. et al. (2024). Scalable watermarking for identifying large language model outputs. Nature, 634, 818–823. https://doi.org/10.1038/s41586-024-08025-4
Summaries
The paper presents SynthID-Text, a watermarking scheme designed to detect synthetic text generated by large language models (LLMs). The authors aim to address the challenge of distinguishing between human-written and AI-generated content, which has become increasingly difficult with the advancement of LLMs.
Key Contributions
The authors introduce SynthID-Text, a watermarking method that integrates with the sampling process of LLMs to embed identifiable markers into generated text. This approach allows for the detection of AI-generated content without compromising text quality. The paper reports high detection accuracy of the watermarked text across multiple LLMs. Standard benchmarks and human evaluations indicate that SynthID-Text maintains the original capabilities and quality of the LLM outputs.
Constructive Critiques or Reviews
The study primarily focuses on general text generation. Evaluating SynthID-Text's performance across various text genres and domains, such as technical writing or creative literature, would offer insights into its versatility and potential limitations.
Group 15 Presentation: DiGress: Discrete denoising diffusion for graph generation
Presented by:
Sean Tang, Buji Wong
Paper Citation
Vignac, C., Krawczuk, I., Siraudin, A., Wang, B., Cevher, V., & Frossard, P. (2022). Digress: Discrete denoising diffusion for graph generation. arXiv preprint arXiv:2209.14734.
Summaries of key points
Motivation
Recall that in class we already learned about diffusion models: they work by gradually adding noise to data (forward process) and training a model to reverse the process, restoring the clean data (reverse process). Traditional diffusion models are designed for continuous data, but graphs are generally discrete. This discrete structure makes it difficult to directly apply traditional diffusion models to graphs.
DiGress tackles this issue by introducing a method that applies diffusion to discrete graphs through the concept of discrete diffusion.
Forward Diffusion Process: How to Add Discrete Noise?
As mentioned above, DiGress uses discrete noise. In DiGress, noise is added by changing the categories of nodes and edges (such as atom types or bond types). This means the model randomly selects new categories for nodes and edges from a predefined set of valid categories. To guide this noise addition, DiGress uses a transition matrix. The transition matrix defines the probability of each node and edge transitioning between categories at each step. This ensures that while the graph becomes noisier over time, it still maintains its structure and validity.
To improve the quality of the noisy graphs during diffusion, DiGress introduces a Markovian noise model that preserves the marginal distributions of node and edge types observed in the training data. Rather than using uniform transitions, which can create unrealistic dense graphs, this model defines transition matrices where the probability of changing to a new category is proportional to its frequency in the dataset. This approach helps maintain realistic sparsity and structure throughout the diffusion steps, making the reverse denoising process easier and more effective.
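A small sketch of such a marginal-preserving transition matrix, assuming the construction [math]\displaystyle{ Q^t = \alpha^t I + (1-\alpha^t) \mathbf{1} m^T }[/math] where [math]\displaystyle{ m }[/math] holds the dataset marginals (the function name is illustrative):

```python
import torch

def marginal_transition_matrix(marginals, alpha_t):
    """Marginal-preserving noise step: with probability alpha_t keep the
    current category, otherwise resample it from the dataset marginals m."""
    a = marginals.shape[0]
    return alpha_t * torch.eye(a) + (1 - alpha_t) * torch.ones(a, 1) * marginals.unsqueeze(0)
```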
Reverse Diffusion Process: How to Denoise a Graph?
In the reverse process, DiGress gradually removes the noise from the graph. Instead of randomly undoing the noise, the model uses a graph transformer network, designed for graph-structured data. This network helps the model recognize the structure of the graph and the relationships between nodes and edges. During each step, the model focuses on the most relevant parts of the graph, predicting the correct categories for nodes and edges. And the model’s predictions are guided by cross-entropy loss (applied to both nodes and edges), which measures how accurately the model predicts the node and edge categories after denoising. By minimizing this loss, the model becomes better at removing the noise, step by step, until it recovers a valid and meaningful graph.
Conditioning Graph Generation on Desired Properties
One of the most powerful features of DiGress is its ability to condition the graph generation process on specific properties. For example, if you want the generated graph to have a certain number of oxygen atoms or satisfy some chemical property, DiGress can adjust the generation process to meet those requirements. This is done by checking the graph at each step of the sampling process and modifying it as needed to match the desired properties. This capability is particularly useful in areas like drug discovery, where the generated molecules must meet certain chemical and structural criteria.
Experimental Results
DiGress was tested on several datasets to evaluate its performance:
1. Graph Generation: On datasets like the Stochastic Block Model (SBM) and planar graphs, DiGress outperformed other methods, particularly in generating novel graphs that were not seen during training.
2. Molecule Generation: When applied to the MOSES dataset, DiGress produced more valid and novel molecules, even though it did not always surpass methods that check graph validity at every step.
3. Scalability: On larger graphs, such as those from the GuacaMol dataset, DiGress demonstrated strong scalability, making it a suitable option for generating larger and more complex graphs.
Comparison with Existing Approaches
- Versus Autoregressive Models: These models (like GraphAF, GraphRNN) generate graphs node-by-node or edge-by-edge and often rely on ordering, making them slower and harder to parallelize.
- Versus Continuous Diffusion Models for Graphs: E.g., GDSS uses Gaussian noise and struggles with categorical data. DiGress handles discrete data directly, making it more suitable for molecular and symbolic domains.
- DiGress Advantage: Fully parallel sampling, no need to learn generation order, works better for discrete structured data.
Group 15 Presentation: DIGRESS: DISCRETE DENOISING DIFFUSION FOR GRAPH GENERATION
Motivation
If one is familiar with graph theory, you will not need convincing that they are some of the most useful mathematical objects. Consisting of vertices (nodes) and edges, graphs can be used to represent a wide variety of relational data types. Examples include:
Molecules, where atoms are nodes and bonds are edges
Social networks, where people are nodes and connections are edges
Traffic systems, where intersections are nodes and roads are edges
and so on.
As such, one could imagine why being able to generate such data types would be useful for research purposes: a generative model could be trained to produce all kinds of plausible graphs for applications such as materials science or drug discovery. Traditional generative models struggle with graph data because of its non-Euclidean structure, so there is a clear use case for a model tuned specifically to graphs.
Operating principle
The key idea behind DiGress is to adapt denoising diffusion probabilistic models (DDPMs) to the graph domain. For the forward process, the DiGress authors define a Markov chain of noisy graphs, with each step gradually adding noise to the node features and the graph structure. At each step [math]\displaystyle{ t }[/math], the graph [math]\displaystyle{ G_t }[/math] is a noised version of the graph [math]\displaystyle{ G_{t-1} }[/math]. At time [math]\displaystyle{ T }[/math], the graph [math]\displaystyle{ G_T }[/math] can be thought of as a real graph that has been turned into pure noise. The goal of DiGress is to learn how to reverse this noising process, i.e. to generate a plausible graph (with respect to the training data) from a noisy input signal. Generation begins by sampling noise from a given distribution, after which the denoising network predicts the noise that was added at that step. This process is repeated, sometimes with the addition of a small amount of fresh noise, until the sample is no longer noisy and we are left with a believable graph.
Technical Details
Similar to image diffusion models, which apply noise to each pixel independently, DiGress applies noise to each node and edge independently. As a result, the state space is that of individual node and edge categories rather than of all possible graphs, which would be far too large because the transition matrices would explode in size. The denoising network is trained to minimize a cross-entropy loss between predicted and true node/edge distributions according to:
[math]\displaystyle{ l(\hat{p}^G, G) = \sum_{1 \leq i \leq n} \text{CE}(x_i, \hat{p}_X^i) + \lambda \sum_{1 \leq i,j \leq n} \text{CE}(e_{ij}, \hat{p}_E^{ij}) }[/math]

The model learns a reverse diffusion distribution [math]\displaystyle{ p_\theta(G^{t-1} | G^t) }[/math] for sampling clean graphs.
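A minimal PyTorch version of this loss might look as follows; the edge-weight value is an illustrative choice, not necessarily the paper's setting.

```python
import torch
import torch.nn.functional as F

def digress_loss(pred_X, pred_E, true_X, true_E, lam=5.0):
    """Node/edge cross-entropy of the denoising network.
    pred_X: (n, a) logits, pred_E: (n, n, b) logits; true_X, true_E hold the
    integer class of each node/edge. lam weighs edges relative to nodes."""
    node_loss = F.cross_entropy(pred_X, true_X)
    edge_loss = F.cross_entropy(pred_E.reshape(-1, pred_E.shape[-1]),
                                true_E.reshape(-1))
    return node_loss + lam * edge_loss
```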
Results / Conclusion
DiGress has thus far proven to be state-of-the-art for both molecular and non-molecular datasets. In particular, it achieved SOTA performance on the GuacaMol dataset, which contains over 1.3 million drug-like compounds.
Group 15 Presentation: DiGress: Discrete denoising diffusion for graph generation
Presented by:
Sean Tang, Buji Wong
Paper Citation
Vignac, C., Krawczuk, I., Siraudin, A., Wang, B., Cevher, V., & Frossard, P. (2022). Digress: Discrete denoising diffusion for graph generation. arXiv preprint arXiv:2209.14734.
Summaries
The paper proposed a Discrete Graph Denoising Diffusion model (DiGress) to generate graphs. This model overcomes a limitation of existing graph generation models that operate in a continuous space and thereby destroy the graph's sparsity and connectivity. The model can handle graphs with categorical node and edge attributes. It has two steps: a diffusion process that adds noise to the original graph, followed by a graph transformer network that inverts the noise. The model satisfies the efficiency properties listed below and specifically tackles the challenge that graphs are insensitive to the ordering of their nodes. The paper also discusses the choice of the noise model, whose distribution should be close to the true data distribution. The model is additionally able to generate graphs conditioned on graph-level properties, which is beneficial for real-life applications.
Key Contributions
Diffusion process: The diffusion process has two steps: adding noise to the original graph by the noise model, then invert the noises by the denoising neural network. The diffusion is done separately for each node and edge attribute.
Noise model: The noise model creates a sequence of noisy data points with a Markovian structure and transition probability [math]\displaystyle{ q(z^1, \ldots, z^T \mid x) = q(z^1 \mid x) \prod_{t=2}^T q(z^t \mid z^{t-1}) }[/math].
Denoising neural network: This process predicts [math]\displaystyle{ z^{t-1} }[/math] from [math]\displaystyle{ z^t }[/math]. This is done by minimizing the cross-entropy loss over node and edge attributes: [math]\displaystyle{ l(\hat{p}^G, G) = \sum_{1 \le i \le n} \text{CE}(x_i, \hat{p}_i^X) + \lambda \sum_{1 \le i, j \le n} \text{CE}(e_{ij}, \hat{p}_{ij}^E) }[/math], where [math]\displaystyle{ X }[/math] is the matrix of one-hot encoded node types, [math]\displaystyle{ E }[/math] is the tensor of one-hot encoded edge types, [math]\displaystyle{ \hat{p} }[/math] is the predicted probability, and [math]\displaystyle{ i, j }[/math] index nodes (so [math]\displaystyle{ e_{ij} }[/math] is the edge between nodes [math]\displaystyle{ i }[/math] and [math]\displaystyle{ j }[/math]). The formula weighs the relative importance of nodes and edges through the parameter [math]\displaystyle{ \lambda }[/math]. The distribution of the next noisy graph is then calculated, and the next noisy graph is sampled from this distribution and used as the input for the next step.
Properties of Efficiency: These properties are satisfied by existing continuous graph generation methods and should also be met by discrete graph generation methods.
- The distribution of the noisy data has a closed-form formula. This achieves efficiency because the noises are not added recursively.
- The posterior distribution [math]\displaystyle{ q(z^{t-1} | z^t, x) }[/math] has a closed-form formula. With this condition, the original data points can be used as the target of the denoising network.
- The limit distribution of the noise model [math]\displaystyle{ q_{\infty} = \lim_{T \rightarrow \infty} {q(z^T | x)} }[/math] does not depend on the original data points. With this condition, this distribution can be used as the prior distribution to be more efficient.
Related Works
Existing works have utilized Gaussian noise in a continuous setting to add noise to the node features and the graph adjacency matrix (Niu et al.; Jo et al., 2022). However, this approach does not preserve the most significant characteristics of graphs: they are sparse, insensitive to node ordering, and carry other structural properties.
Other existing works exploring discrete diffusion models have been applied to images, text, and audio, and have not addressed the unique challenges of generating graphs.
Group 15 Presentation: DiGress: Discrete Denoising Diffusion for Graph Generation
Presented by:
Sean Tang, Buji Wong
Paper Citation
Vignac, C., Krawczuk, I., Siraudin, A., Wang, B., Cevher, V., & Frossard, P. (2022). Digress: Discrete denoising diffusion for graph generation. arXiv preprint arXiv:2209.14734.
Introduction
Researchers have introduced a new method for generating graphs with categorical node and edge attributes called DiGress. This work, published in the proceedings of the ICLR 2023 conference, presents a discrete denoising diffusion model designed for graph generation.
Background
Diffusion models have shown remarkable success in various domains, particularly in image and video generation. The ability of diffusion models to outperform other generative methods in these areas has motivated researchers to explore their potential for graph generation. However, generating graphs presents unique challenges due to their unordered nature and sparsity. Previous approaches that applied continuous diffusion models with Gaussian noise to graphs have struggled to preserve the inherent structural properties of graphs, such as sparsity and connectivity.
Main Idea
DiGress addresses these challenges by employing a discrete denoising diffusion model. The model progressively edits graphs through a Markov process that involves adding or removing edges and changing node or edge categories. A graph transformer network is then trained to reverse this process, effectively learning to denoise the graph. This approach simplifies the complex task of graph distribution learning into a series of node and edge classification tasks.
Experiments
The authors evaluated DiGress on both molecular and non-molecular datasets, comparing its performance against state-of-the-art graph generation methods. Their experiments included:
- Unconditional generation on stochastic block model (SBM) and planar graph datasets.
- Conditional generation on the QM9 dataset to assess the model's ability to generate graphs with specific properties.
- Large-scale molecule generation on the MOSES and GuacaMol datasets.
Results
The results of the experiments demonstrate that DiGress achieves state-of-the-art performance in graph generation. Key highlights include:
- DiGress exhibits strong performance on both molecular and non-molecular datasets.
- The model shows significant improvements in validity, particularly on the planar graph dataset.
- DiGress is the first model to successfully scale to the large GuacaMol dataset, achieving performance comparable to autoregressive models.
In summary, DiGress introduces a novel and effective approach to graph generation by leveraging discrete denoising diffusion models. The model's ability to handle discrete graph structures and scale to large datasets represents a significant advancement in the field of graph generation.
Group 16 Presentation: Machine Learning and Hamilton-Jacobi-Bellman Equation for Optimal Decumulation: a Comparison Study
Presented by:
Zeyu Zhang
Paper Citation
Chen M, Shirazi M, Forsyth PA, Li Y. Machine Learning and Hamilton-Jacobi-Bellman Equation for Optimal Decumulation: a Comparison Study. Published online 2023. doi:10.48550/arxiv.2306.10582
Background
The paper is based on computational finance, focusing on the optimization problem related to "defined benefit" and "defined contribution plans". The main focus is on the challenge of ensuring retirees have enough funds for their retirement. Two key plans were discussed:
"Defined benefit plans" guarantee fixed monthly payments based on factors like tenure and salary but are cost-prohibitive and risky.
"Contribution plans" shift the investment and withdrawal strategy burden to individual investors, but they struggle to balance maximizing withdrawals and minimizing risk.
This problem, often called "the nastiest, hardest problem in finance," highlights the complexity of balancing risk and reward in financial planning for retirement.
The 4% rule is a traditional method recommending a constant 4% withdrawal each year, adjusted for inflation, and investing in stocks and bonds.
Despite its popularity, the 4% rule is not globally optimal.
Peter Forsyth proposed the HJB PDE method to maximize expected withdrawals and minimize the risk of running out of savings.
The HJB PDE method uses scalarization techniques to achieve Pareto optimal points, but it has limitations.
Technical Contributions
1. Hamilton-Jacobi-Bellman (HJB):
- The problem formulation involves complex mathematical equations related to computational finance.
- The problem uses dynamic programming to break down the optimal control problem, leading to the HJB function that represents the value function.
- The paper assumes stock and bond prices follow a jump diffusion model.
- The investor's total wealth at time [math]\displaystyle{ t }[/math] is defined as the sum of the values of the stock and bond holdings at that time.
- The terminal time [math]\displaystyle{ T }[/math] is set to 30 years, and rebalancing times are defined with discrete withdrawal amounts and allocations between stocks and bonds.
2. Neural Network (NN): Control and Objective Function:
- The control at time [math]\displaystyle{ T_i }[/math] includes the withdrawal amount [math]\displaystyle{ Q_i }[/math] and allocation for the wealth at time [math]\displaystyle{ T_i^- }[/math].
- The admissible control set is defined, and the expected shortfall is introduced as a measure of risk.
- The expected total withdrawal is used as a measure of reward, aiming to maximize the expected total withdrawal while minimizing the expected shortfall.
- The pre-commitment expected shortfall problem is then defined: maximize the expected total withdrawals while controlling the expected shortfall (a minimal sketch of this scalarized objective follows this list).
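To make the reward-risk trade-off concrete, the following is a minimal Monte Carlo sketch of a scalarized objective combining expected total withdrawals (EW) with the expected shortfall (ES) of terminal wealth, estimated from simulated paths. It is an illustration rather than the authors' implementation; the weight <code>kappa</code>, the 5% tail level, and the toy numbers are assumptions.
<syntaxhighlight lang="python">
import numpy as np

def scalarized_objective(withdrawals, terminal_wealth, kappa=1.0, alpha=0.05):
    """Estimate EW + kappa * ES from simulated paths.

    withdrawals:     (n_paths, n_rebalance_times) withdrawal amounts
    terminal_wealth: (n_paths,) wealth at the horizon (can be negative)
    kappa:           scalarization weight on the risk term (illustrative)
    alpha:           tail level used for the expected shortfall (illustrative)
    """
    # Reward: expected total withdrawals per path.
    ew = withdrawals.sum(axis=1).mean()

    # Risk: expected shortfall = mean of the worst alpha-fraction of terminal wealth.
    sorted_wealth = np.sort(terminal_wealth)
    n_tail = max(1, int(alpha * len(sorted_wealth)))
    es = sorted_wealth[:n_tail].mean()

    # The control problem maximizes EW + kappa * ES: a larger kappa puts more
    # weight on avoiding depletion in the worst scenarios.
    return ew + kappa * es

# Toy usage with synthetic paths.
rng = np.random.default_rng(0)
withdrawals = rng.uniform(35_000, 60_000, size=(10_000, 30))   # yearly withdrawals
terminal_wealth = rng.normal(200_000, 300_000, size=10_000)    # can go negative (shortfall)
print(scalarized_objective(withdrawals, terminal_wealth, kappa=1.0))
</syntaxhighlight>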
Neural Network (NN) Formulation
As an alternative to the HJB framework, the authors propose a Neural Network (NN) approach to solve the stochastic control problem. This framework has several advantages:
1. The NN approach is data-driven, meaning it avoids explicitly defining parametric models for stochastic processes. This provides flexibility and allows integration of auxiliary variables if needed.
2. It circumvents the computation of high-dimensional conditional expectations by solving a single, unconstrained optimization problem for control decisions. This avoids the curse of dimensionality often associated with dynamic programming.
3. If the optimal control is continuous in time and state, the NN reflects this property. If the control is discontinuous, the NN yields a smooth approximation, which is beneficial in practice for implementing investment policies.
4. The method is scalable, making it suitable for long horizons and high-frequency rebalancing without significantly increasing computational complexity.
The NN framework’s effectiveness lies in its ability to learn control policies directly from simulated state paths without requiring explicit knowledge of the underlying stochastic differential equations.
Instead of solving high-dimensional HJB PDEs, the neural network uses:
- Forward simulation to sample wealth trajectories under candidate policies.
- Backward evaluation to update network parameters based on performance (e.g., maximizing expected withdrawals, minimizing expected shortfall).
This model-free, data-driven method avoids dynamic programming and is especially useful in high dimensions, where solving PDEs becomes computationally infeasible.
Moreover, by designing appropriate activation functions (e.g., softmax for portfolio weights and sigmoid for withdrawal rates), the NN ensures that stochastic constraints are naturally respected throughout training and inference.
NN Approximation Setup
- The control policy [math]\displaystyle{ \mathcal{P} }[/math] is approximated using two feed-forward neural networks, with parameters [math]\displaystyle{ \boldsymbol{\theta}_q }[/math] and [math]\displaystyle{ \boldsymbol{\theta}_p }[/math], representing withdrawal and allocation strategies respectively.
- These networks take as inputs the wealth [math]\displaystyle{ W(t_i) }[/math] and time [math]\displaystyle{ t_i }[/math] to approximate the control decisions:
[math]\displaystyle{ \hat{q}(W_i^-, t_i^-, \boldsymbol{\theta}_q) \approx q_i(W_i^-), \quad \hat{p}(W_i^+, t_i^+, \boldsymbol{\theta}_p) \approx p_i(W_i^+) }[/math]
- The final control policy is given by: [math]\displaystyle{ \hat{\mathcal{P}} = \{ (\hat{q}(\cdot), \hat{p}(\cdot)) \} \approx \mathcal{P} }[/math]
- The functions [math]\displaystyle{ \hat{p} }[/math] and [math]\displaystyle{ \hat{q} }[/math] use time as one of the inputs, allowing a single NN to handle decisions across all rebalancing points, rather than training separate models for each time step.
- The architecture also includes activation functions that enforce the stochastic constraints naturally (a minimal sketch of such a policy network is given below).
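The following PyTorch sketch illustrates how constraint-enforcing activations can be wired into the two control networks: a sigmoid maps the withdrawal output into a bounded range, and a softmax produces long-only allocation weights that sum to one. The architecture, hidden sizes, and withdrawal bounds are illustrative assumptions, not the authors' code.
<syntaxhighlight lang="python">
import torch
import torch.nn as nn

class ControlNet(nn.Module):
    """Small feed-forward network; input is (wealth, time), output is unconstrained."""
    def __init__(self, out_dim, hidden=32):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(2, hidden), nn.SiLU(),
            nn.Linear(hidden, hidden), nn.SiLU(),
            nn.Linear(hidden, out_dim),
        )

    def forward(self, wealth, t):
        return self.net(torch.stack([wealth, t], dim=-1))

class DecumulationPolicy(nn.Module):
    """Withdrawal network (theta_q) and allocation network (theta_p)."""
    def __init__(self, n_assets=2, q_min=35_000.0, q_max=60_000.0):
        super().__init__()
        self.q_net = ControlNet(out_dim=1)         # withdrawal amount
        self.p_net = ControlNet(out_dim=n_assets)  # portfolio weights
        self.q_min, self.q_max = q_min, q_max

    def forward(self, wealth, t):
        # Sigmoid maps the raw output into [q_min, q_max]; in the paper the
        # admissible range can also depend on wealth.
        raw_q = self.q_net(wealth, t).squeeze(-1)
        q = self.q_min + (self.q_max - self.q_min) * torch.sigmoid(raw_q)
        # Softmax enforces no shorting / no leverage: weights >= 0 and sum to 1.
        p = torch.softmax(self.p_net(wealth, t), dim=-1)
        return q, p

policy = DecumulationPolicy()
q, p = policy(torch.tensor([500_000.0]), torch.tensor([0.5]))
</syntaxhighlight>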
Summaries of Key Notes
- Neural Network Framework for Pension Decumulation: A novel framework using neural networks (NNs) is proposed to optimize asset allocation and cash withdrawal strategies for defined contribution (DC) pension plans. Unlike traditional methods, it solves constraints efficiently via unconstrained optimization.
- Comparison with HJB Method: The NN approach achieves comparable accuracy to the Hamilton-Jacobi-Bellman (HJB) PDE method while being scalable to higher-dimensional problems and avoiding dynamic programming errors.
- Efficient Withdrawals: The NN framework closely approximates optimal "bang-bang" controls, effectively alternating between minimum and maximum withdrawals based on wealth, ensuring reliable pension decumulation.
- Robustness: Tested extensively on synthetic and historical market data, the NN solution adapts well and demonstrates strong out-of-sample and out-of-distribution performance.
Constructive Critique
While the NN method replicates HJB-derived control policies with high accuracy, a few limitations and caveats exist:
- The training relies on synthetic data simulated from assumed models (e.g., geometric Brownian motion, jump diffusion), which may limit generalizability under real-world dynamics.
- The scalarization of the reward-risk tradeoff assumes a linear weighting of Expected Withdrawals and Expected Shortfall. In practice, retiree preferences might reflect more complex, utility-based behaviors.
- The interpretability of the learned policy is less transparent compared to explicit, closed-form control strategies derived via HJB.
- No formal convergence analysis or approximation bounds are provided for the NN solution.
Despite these challenges, the method is empirically robust and scalable, making it an appealing alternative for large-scale or real-time applications.
Related Works
The work is also related to "Spending Retirement on Planet Vulcan: The Impact of Longevity Risk Aversion on Optimal Withdrawal Rates", which introduces utility-based models for withdrawal strategies under longevity risk aversion. The models focus on retirees' behavioral preferences — particularly longevity risk aversion and the need for smooth consumption throughout retirement.
While not PDE-based, it contextualizes the trade-offs between consumption smoothing and risk of depletion, complementing the (EW, ES) approach by addressing behavioral and utility-driven objectives.
On the other hand, the benchmark NN paper takes a risk-sensitive stochastic control approach, optimizing a scalarized objective that balances:
Expected Withdrawals (EW) = proxy for consumption
Expected Shortfall (ES) = proxy for downside risk / depletion probability
Group 16 Presentation: Machine Learning and Hamilton-Jacobi-Bellman Equation for Optimal Decumulation: a Comparison Study
Presented by:
Zeyu Zhang
Paper Citation
Chen M, Shirazi M, Forsyth PA, Li Y. Machine Learning and Hamilton-Jacobi-Bellman Equation for Optimal Decumulation: a Comparison Study. Published online 2023. doi:10.48550/arxiv.2306.10582
Background & Motivation
The paper focuses on addressing a stochastic optimal control problem in retirement decumulation and asset allocation, which is a critical issue in financial planning. Specifically, it investigates how retirees can optimally withdraw funds from their savings (decumulation) while simultaneously managing asset allocation under uncertain market conditions.
Traditionally, rules of thumb such as the Bengen 4% Rule have been widely adopted in the financial industry to guide withdrawal strategies. However, these approaches are increasingly viewed as suboptimal, particularly in an environment characterized by volatile markets and evolving mortality patterns. Recent academic studies, such as Forsyth (2022), propose partial differential equation (PDE)-based methods that are provably convergent and optimal under specific assumptions. Nevertheless, PDE methods face significant limitations in scalability due to the curse of dimensionality, often performing well only in low-dimensional settings.
The motivation for this paper is to overcome the limitations of PDE-based approaches by leveraging neural networks (NNs) to solve the decumulation and asset allocation control problem. The authors aim to evaluate whether deep learning can accurately and robustly approximate the solution to this high-dimensional stochastic control problem, as well as whether it provides computational advantages.
Key Points
1. Problem Formulation: The paper formulates the decumulation problem as a stochastic optimal control problem, aiming to optimize a weighted sum of expected withdrawals (EW) and expected shortfall (ES) to effectively manage tail risk. Key constraints include minimum and maximum withdrawal limits, as well as no-shorting and no-leverage rules.
2. HJB Framework: The Hamilton-Jacobi-Bellman (HJB) approach employs dynamic programming to solve the problem numerically, providing a ground-truth benchmark for comparison. However, this method is computationally limited to low-dimensional problems and relies on parametric models for asset returns, which may not capture real-world complexities.
3. NN Framework: The proposed neural network (NN) framework directly approximates the control functions (withdrawal and allocation) using feed-forward networks with customized activation functions designed to enforce the specified constraints. This data-driven approach bypasses the need for dynamic programming and demonstrates scalability to higher-dimensional problems.
4. Comparative Results: On synthetic data, the NN solution achieves performance nearly identical to that of the HJB method, showcasing its high accuracy in approximating the optimal control policy, including complex "bang-bang" withdrawal strategies.
5. Robustness: The NN framework exhibits strong performance in out-of-sample and out-of-distribution tests, such as bootstrap-resampled historical data, thereby demonstrating its generalizability beyond the training distribution.
Contributions
- Demonstration of neural networks as reliable solvers for constrained stochastic control problems, which were previously addressable only through partial differential equations (PDEs).
- Quantitative benchmark comparisons between NN-based and PDE-based methods reveal near-equivalent accuracy, particularly in replicating key features such as the efficient frontier and optimal withdrawal behavior.
- The proposed approach is scalable to higher dimensions, unlike PDE-based methods, making it potentially transformative for real-world retirement planning problems that involve multiple assets or stochastic factors.
- The authors demonstrate that regularization within the NN framework helps mitigate instability in regions of the state space where the control problem becomes ill-posed (e.g., high wealth levels or near terminal time).
- The method provides continuous-time control outputs by explicitly incorporating time as an input to the network, ensuring smooth solutions when required.
A key innovation in the neural network (NN) formulation is the use of customized activation functions to enforce feasibility of withdrawal and allocation controls. Instead of training under constrained optimization, the authors design activation functions that inherently respect control bounds (e.g., no shorting, no leverage, and withdrawal limits). The withdrawal control uses a modified sigmoid function that maps outputs to the valid withdrawal range, which depends on wealth. The allocation control uses a softmax activation to ensure portfolio weights are non-negative and sum to one. This allows training via standard unconstrained optimization, significantly simplifying the optimization process while ensuring all control outputs remain feasible throughout the training and inference phases.
Constructive Critiques
- Ill-Posed Regions: The NN and HJB solutions diverge in high-wealth regions near the terminal time due to the problem's ill-posedness. While the authors argue this has negligible impact on objectives, further analysis of how this affects real-world implementation would strengthen the paper.
- Training Complexity: The NN requires transfer learning for high κ values (weighting ES more heavily), suggesting potential instability in risk-averse scenarios. A deeper exploration of training challenges and solutions would be valuable.
- Historical Data Limitations: The bootstrap resampling tests rely on U.S. market data (1926–2019). Including non-U.S. data or stress-testing during extreme market conditions (e.g., hyperinflation) could enhance robustness claims.
- Computational Costs: While the NN avoids dynamic programming, the computational expense of training large networks is not quantified. A comparison of runtime between HJB and NN methods would clarify trade-offs.
Relationships to Other Works
This work builds on the stochastic control literature, particularly the decumulation problem studied in Forsyth (2022), which employs PDE-based methods; the current paper extends this research by providing a data-driven and high-dimensional alternative. It conceptually aligns with deep FBSDE methods, Deep Galerkin methods used for solving HJB equations, as well as reinforcement learning (RL)-based approaches to optimal control, such as the Deep Deterministic Policy Gradient. Compared to prior studies, such as Han and E (2016), Buehler et al. (2019), and Laurière et al. (2021), the current paper places emphasis on benchmarking against a well-established numerical method (PDEs), an aspect often overlooked in other NN-based control studies. The proposed method falls within the Policy Function Approximation (PFA) framework outlined in Powell (2021), providing a robust example of utilizing fixed neural networks to approximate control functions across time and state dimensions.
Group 16 Presentation: Machine Learning and HJB Equation for Optimal Decumulation – A Comparison Study
Presented by:
Zeyu Zhang
Paper Citation
Chen M., Shirazi M., Forsyth P.A., Li Y. *Machine Learning and Hamilton-Jacobi-Bellman Equation for Optimal Decumulation: A Comparison Study*. arXiv, 2023. doi: [10.48550/arXiv.2306.10582](https://doi.org/10.48550/arXiv.2306.10582)
What’s the Study About?
Retirement planning isn’t just about saving—it’s also about *withdrawing* wisely. This paper looks at how retirees can manage their savings smartly over time while investing in both stocks and bonds. The focus? Finding the **best way to withdraw** without running out of money or taking on too much risk.
It compares two methods:
1. **Traditional HJB PDE-based optimization** (fancy math from control theory).
2. **Modern neural network (NN)-based learning** (using machine learning to approximate good decisions).
The challenge is often called the *"nastiest, hardest problem in finance"*: how to balance withdrawal needs, inflation, investment returns, and the fear of running out of money.
Core Contributions
**1. HJB Approach – Classic but Computationally Heavy**
- Uses dynamic programming and mathematical modeling to find the value function.
- Models include risky asset returns, expected withdrawal needs, and allocation strategies.
- Precise, but often limited to low-dimensional cases due to complexity.
**2. Neural Network (NN) Control – New, Scalable, and Flexible**
- Treats the control strategy (how much to withdraw and how to allocate funds) as a learning problem.
- Learns from simulated trajectories without solving complex equations.
- Naturally handles constraints like no-short-selling or withdrawal limits.
- More flexible in high dimensions and generalizes well to different market conditions.
Big Idea
They set up a reward function balancing two goals:
- Maximize total withdrawals (you want money to spend).
- Minimize shortfall (you don't want to run out).
The NN learns how to make these decisions across time, adapting its strategy as the market evolves.
Experimental Setup
- Simulates investment scenarios using real historical market data (e.g., bootstrap sampling from 1926–2019).
- Uses feedforward networks and backward training for better accuracy and constraint satisfaction.
- Approximates complex decisions (like "bang-bang" control where you go all-in or all-out) very effectively.
Key Results & Takeaways
- **Accuracy**: The NN performs nearly as well as the HJB method, even for tricky withdrawal rules.
- **Efficiency**: No need to solve high-dimensional PDEs; the NN handles complexity better.
- **Scalability**: Can handle more assets, longer time horizons, and stochastic market patterns.
- **Robustness**: Generalizes well to unseen data, even under distribution shifts.
Pros of the NN Approach
- Avoids the "curse of dimensionality" that haunts HJB methods. - Works well with noisy and historical data. - Automatically learns smooth solutions—even when traditional solutions are discontinuous. - Better suited for real-world policies where exact math assumptions don’t always hold.
Critiques & Limitations
- **Ill-posed regions**: Control can become unstable near wealth boundaries or end-of-life planning.
- **Training complexity**: Transfer learning may be needed for extreme cases (like very risk-averse clients).
- **Historical data bias**: Based on U.S. market data, so it might miss global scenarios or crises.
- **Compute cost**: Large NNs still require serious GPU time, even without PDEs.
Related Work
- Builds on Forsyth (2022), which used PDEs for optimal retirement strategies.
- Closely linked to reinforcement learning (RL), especially policy gradient methods.
- Part of the broader field of stochastic control and financial engineering.
Group 18 Presentation: HIGEN: HIERARCHICAL GRAPH GENERATIVE NETWORKS
Presented by:
- Shiyu Zhu
- Jesse Xue
Paper Citation
M. Karami, “HiGen: Hierarchical Graph Generative Networks,” 2023, arXiv. doi: 10.48550/ARXIV.2305.19337.
Background
- Hierarchical or Multi-Scale Structure: Captures high level interactions or relationships between objects or groups, while also representing the lower level structures. One example is a company org chart.
- Existing graph generating models include: Variational Autoencoders, Generative Adversarial Networks, Autoregressive Models (GNN, Graph RNN, GRAN) and Diffusion Models
- Paper introduces HIGEN, Hierarchical Graph Generative Networks to address problems with existing models.
- Experiments were conducted on 5 datasets of increasing size and scale, with GraphRNN, GRAN, DiGress, GDSS, and SPECTRE used as baselines.
Technical Contributions
- Related Hierarchical Methods: The presentation discusses several recent hierarchical methods in specific domains like chemistry, highlighting HiGen's broader applicability, multi-level approach, and parallel generation as advantages over these more specialized techniques. These include a motif-based generation method for molecular graphs (2020) relying on domain knowledge, a hierarchical normalizing flow model (2021) based on local neighborhoods, and a tree decomposition framework (2022) limited to a single abstraction level and medium-sized graphs.
Definition: Hierarchical Graph
A Hierarchical Graph is a multi-level representation of a graph [math]\displaystyle{ \mathcal{G} = (\mathcal{V}, \mathcal{E}) }[/math] where:
- [math]\displaystyle{ \mathcal{V} }[/math] is the set of nodes (vertices), and [math]\displaystyle{ \mathcal{E} }[/math] is the set of edges, with sizes [math]\displaystyle{ n = |\mathcal{V}| }[/math] and [math]\displaystyle{ m = |\mathcal{E}| }[/math].
- A node partition function [math]\displaystyle{ \mathcal{F}: \mathcal{V} \rightarrow \{1, ..., c\} }[/math] groups nodes into [math]\displaystyle{ c }[/math] communities or clusters.
- Each cluster forms a subgraph [math]\displaystyle{ \mathcal{C}_i = (\mathcal{V}(\mathcal{C}_i), \mathcal{E}(\mathcal{C}_i)) }[/math] with adjacency matrix [math]\displaystyle{ A_i }[/math].
- Cross-links between communities form bipartite graphs [math]\displaystyle{ \mathcal{B}_{ij} = (\mathcal{V}(\mathcal{C}_i), \mathcal{V}(\mathcal{C}_j), \mathcal{E}(\mathcal{B}_{ij})) }[/math].
- Each cluster is aggregated into a super-node and each bipartite into a super-edge at the next higher level. This forms a coarser graph at the parent level.
Methodology
HiGen (Hierarchical Graph Generative Networks) adopts a coarse-to-fine approach for generating graphs with complex hierarchical structures. The method consists of the following components:
1. Hierarchical Graph Construction
The input graph is recursively partitioned into communities using the Louvain clustering algorithm, forming a multi-level hierarchy. Each cluster (community) is abstracted as a "super-node" at the next level, and cross-community connections become "super-edges."
2. Community and Bipartite Generation
* Community Generation: Each community is generated independently using a GNN-based autoregressive model that factorizes the edge distribution into multinomial components. A mixture of multinomials is used to model intra-community edge structures.
* Bipartite Generation: Once communities are generated, bipartite graphs (cross-community edges) are created using a separate neural network. These inter-cluster edges are predicted with a parallel GNN model using a similar factorization strategy.
3. Autoregressive Probabilistic Modeling
HiGen decomposes the joint probability of the graph into a product of conditional multinomial distributions. Both community and bipartite graphs are generated step-by-step using a recursive generation process guided by node and cluster embeddings.
4. Parallelism and Invariance
The hierarchical structure allows parallel generation of communities and bipartite edges at the same level, improving efficiency. The model is also invariant to node ordering within clusters, which improves robustness and scalability.
This design enables HiGen to generate large and complex graphs while maintaining global structure and fine-grained local details. It supports realistic synthesis across diverse graph types including molecules, biological networks, and social systems.
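As a rough illustration of step 1 above (hierarchical graph construction), the sketch below uses NetworkX's Louvain community detection to partition a graph and then collapses each community into a super-node, with super-edge weights counting cross-community links. The NetworkX calls are real, but the overall pipeline is a simplified assumption rather than the paper's implementation.
<syntaxhighlight lang="python">
import networkx as nx
from networkx.algorithms.community import louvain_communities

def coarsen(graph, seed=0):
    """One level of HiGen-style coarsening: cluster the graph, then build the parent graph."""
    # Partition nodes into communities (clusters) with Louvain.
    communities = louvain_communities(graph, seed=seed)
    node_to_cluster = {v: i for i, comm in enumerate(communities) for v in comm}

    # Parent level: one super-node per community; super-edges are weighted by the
    # number of cross-community edges (self-loops count intra-community edges).
    parent = nx.Graph()
    parent.add_nodes_from(range(len(communities)))
    for u, v in graph.edges():
        cu, cv = node_to_cluster[u], node_to_cluster[v]
        weight = parent[cu][cv]["weight"] + 1 if parent.has_edge(cu, cv) else 1
        parent.add_edge(cu, cv, weight=weight)
    return communities, parent

# Example: build one level of the hierarchy for a small random graph.
g = nx.erdos_renyi_graph(60, 0.08, seed=1)
communities, parent = coarsen(g)
print(len(communities), parent.number_of_edges())
</syntaxhighlight>
Repeating <code>coarsen</code> on the parent graph yields the multi-level hierarchy that HiGen then generates in reverse, from the coarsest level down to the original graph.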
Summaries of Key Notes
- HiGen: Hierarchical Graph Generative Networks (HiGen) is a novel graph generation model that addresses the limitations of existing methods by leveraging hierarchical structures in graphs. It generates substructures in a coarse-to-fine manner, modeling the interactions between communities and cross-edges at multiple levels of abstraction.
- Community and Bipartite Generation: HiGen utilizes modular generation, creating communities in parallel followed by predictions of cross-edges with separate neural networks, ensuring scalability and efficiency. It employs multinomial distributions with recursive factorization, enabling autoregressive generation of integer-valued edge weights within communities.
- State-of-the-Art Performance: HiGen demonstrates superior graph quality across benchmark datasets compared to competing models like GRAN, GraphRNN, GDSS, and diffusion-based methods. It captures both local and global graph statistics effectively, achieving state-of-the-art metrics for graph degree, clustering coefficients, and Laplacian spectra.
- Scalability: The hierarchical structure of HiGen facilitates parallel graph generation, reduces sensitivity to node ordering, and enables block-wise processing of adjacency matrices, making it adaptable to large and complex graphs.
- Applications: HiGen excels in generating realistic graphs for applications in molecular modeling, protein structure analysis, data network synthesis, and more, addressing diverse domains where hierarchical graph structures are pivotal.
Results
- Experiments are conducted on 5 datasets (e.g., SBM, Protein, Enzyme, Ego) with increasing size and complexity.
- Evaluation metrics cover local (degree distribution, clustering coefficient distribution, orbit counts) and global (graph spectra) graph statistics.
- Observation: Compared to baselines like GraphRNN, GRAN, and SPECTRE, HiGen consistently achieves state-of-the-art performance, especially in capturing graph motifs and global structure.
Related works
Classical graph models are typically heuristic, rule-based, and statistical, relying on predefined structures rather than learning distributions from data. They are typically computationally efficient, but can't capture complex graphs/structures beyond predefined rules. They are also not data-driven, so they are unable to learn patterns from observed graphs.
Additional deep learning methods for graph generation include the following below, along with some of the things they struggle with
- Variational autoencoders: Limited to smaller graphs due to scalability limitations
- GANs: Hard to train, cannot capture complex dependencies
- Autoregressive models such as graph neural networks and graph RNNs: High computational complexity
- Diffusion models: Struggle with slow sampling speed
Group 18 Presentation: HIGEN: HIERARCHICAL GRAPH GENERATIVE NETWORKS
Presented by:
- Shiyu Zhu
- Jesse Xue
Paper Citation
M. Karami, “HiGen: Hierarchical Graph Generative Networks,” 2023, arXiv. doi: 10.48550/ARXIV.2305.19337.
Introduction
"HiGen: Hierarchical Graph Generative Networks" by Mahdi Karma introduces a novel approach to graph generation that places emphasis on the hierarchal structures in many inherent real-world graphs.
Key Contributions
1). Hierarchical Graph Generation
HiGen employs a coarse-to-fine strategy to generate graphs, effectively capturing their hierarchical nature. This method involves generating sub-structures at multiple levels, which enhances the model's ability to reflect the inherent organization of complex graphs.
2). Parallel Community Generation
At every hierarchical level, HiGen generates communities in parallel, followed by the prediction of cross-edges between these communities using separate neural networks. This modular approach enables scalable generation of large and complicated graphs.
3). Multinomial Edge Distribution Modelling
The model utilizes a multinomial distribution to represent the output distribution of edges within the graph. A recursive factorization of this distribution is employed, allowing HiGen to generate community graphs with integer-valued edge weights autoregressively. This improves the realism and accuracy of the generated graphs.
Performance Outcomes
Experiments demonstrate the effectiveness and scalability of HiGen, which achieves state-of-the-art graph quality across a variety of benchmark datasets. The modular design and hierarchical generation process contribute to its ability to generate large and complex graphs efficiently.
Advancements in the Field
HiGen advances graph generative modeling by explicitly incorporating hierarchical structures into the generation process. This addresses a limitation of existing methods, which often overlook the multi-level organization of real-world graphs, and thereby enhances the fidelity and applicability of generated graphs across various domains and applications.
Conclusion
The methodologies and findings presented in the HiGen: Hierarchical Graph Generative Networks paper offer meaningful contributions to the field of graph generation. By introducing a hierarchical and modular approach, HiGen sets a new benchmark for generating complex graphs that accurately reflect real observed structures.
Group 18 Presentation: HiGen: Hierarchical Graph Generative Networks
Presented by
- Shiyu Zhu
- Jesse Xue
Paper Citation
M. Karami, “HiGen: Hierarchical Graph Generative Networks,” 2023, arXiv. doi: 10.48550/ARXIV.2305.19337.
Background
- Problem: Graphs are widely used to represent complex relationships in various domains (e.g., social networks, molecular structures, and knowledge graphs). However, generating realistic graphs is challenging due to the need to capture both local and global structures.
- Hierarchical structures provide a natural way to model these interactions. Lower levels capture dense local structures, and higher levels capture global properties.
- Existing graph generative models had limitations. Variational Autoencoders struggle with scalability, limiting them to small graphs. Autoregressive models captured hierarchical dependencies well, but suffered from high computational complexity due to sequential node/edge modeling. Diffusion models generated high-quality graphs but were computationally expensive due to long denoising steps.
- Motivation: HiGen introduces hierarchical graph generation with a structured decomposition approach to improve scalability, efficiency, and parallelization.
Summaries of Key Points
- Goal: HiGen aims to learn a generative model from training graphs and efficiently generating new graphs while capturing both local and global structures.
- Uses generative probability and decomposed probability. Generative probability directly models the joint distribution of the entire graph and ensures a holistic view of the graph structure. Decomposed probability breaks the graph into smaller, independent components, and it allows for parallelization, which makes training and generation more scalable and efficient.
- Hierarchical graph generative networks decompose graphs into communities and their bipartite interactions, leveraging conditional independence for efficiency. Essentially, community structures are conditionally independent of bipartite interactions, and this decomposition allows parallel generation of communities and their bipartite edges, reducing training time.
- Inspired by Graph Recurrent Attention Network (GRAN). GRAN generates groups of nodes in parallel to reduce complexity, and HiGen extends this by introducing a k-mixture model. The k-mixture model learns more complex structural patterns and adapts to diverse hierarchical structures that are found in real-world graphs. It is implemented with Multi-Layer Perceptrons (MLPs), each with different inputs and activation functions to capture hierarchical relationships.
- HiGen reduces the number of required adjacency matrices, which improves computational efficiency. It also encourages parallelizable training by organizing graphs into structured blocks. Those blocks represent strongly connected communities, effectively capturing intercommunity and global relationships.
Related works
- Unlike previous models, HiGen explicitly incorporates hierarchical structure to improve scalability, making it more efficient than other methods in handling large graphs.
- Potential Applications: Application in chemistry for molecular graph generation, where the hierarchical structures correspond to molecular backbones and functional groups. It is also helpful for simulating hierarchical community structures in large-scale networks, in terms of social networks. HiGen is applicable to any field that requires scalable graph generation with groups that have hierarchical dependencies.
Group 18 Presentation: HiGen: Hierarchical Graph Generative Networks
Presented by
- Shiyu Zhu and Jesse Xue
Paper Citation
M. Karami, “HiGen: Hierarchical Graph Generative Networks,” 2023, arXiv. doi: 10.48550/ARXIV.2305.19337.
Summaries
This paper proposes a new graph generative network, the Hierarchical Graph Generative Network (HiGen), to make graph generation usable in real-life applications without requiring domain-specific expertise. It addresses a shortcoming of existing work: most models capture neither the hierarchical structure of graphs nor their unordered nature. The new network generates a graph from its partitions: communities and the cross-links (bipartite graphs) between neighboring communities. Each of these components is conditioned on its parent node and is independent of the others, so they can be generated in parallel, which reduces computational cost. Partitioning the graph also avoids the requirement of autoregressive models that inputs be given in a fixed order. With this more efficient network, large graphs can be handled, making the model better suited to real-life applications.
Key Contributions
- The generative process runs iteratively from the root node down to the leaf nodes. At each level, a distribution over the partitioned graph is computed conditional on its parent level. The paper shows that the joint distribution of all edge weights follows a multinomial distribution, and that for each community and bipartite sub-graph the conditional distributions are mutually independent, each again multinomial. This finding allows parallel generation, which reduces computational cost.
- The generation of each community can be broken down into an autoregressive, edge-by-edge prediction process, because the joint multinomial distribution of the communities at a given level, conditioned on the partition graph at the parent level, factorizes into a sequence of binomial distributions over the edges (see the sketch after this list).
- The edges of bipartite graphs (cross-links between neighboring communities) can also be generated in parallel.
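The recursive factorization behind this edge-by-edge generation can be illustrated with a small sketch (an illustrative toy, not the paper's code): sampling a multinomial over edge weights with total count n is equivalent to sampling each edge's weight in turn from a binomial whose remaining count and renormalized probability are updated after every step.
<syntaxhighlight lang="python">
import numpy as np

def sample_multinomial_via_binomials(n, probs, rng):
    """Sample Multinomial(n, probs) one component at a time.

    Equivalent to rng.multinomial(n, probs), but makes the autoregressive
    (stick-breaking) factorization into binomials explicit.
    """
    remaining_n, remaining_p = n, 1.0
    counts = np.zeros(len(probs), dtype=int)
    for i, p in enumerate(probs[:-1]):
        if remaining_n == 0 or remaining_p <= 0:
            break
        # Conditional law of component i: Binomial(remaining_n, p / remaining_p).
        counts[i] = rng.binomial(remaining_n, min(1.0, p / remaining_p))
        remaining_n -= counts[i]
        remaining_p -= p
    counts[-1] = remaining_n  # the last component takes whatever count is left
    return counts

rng = np.random.default_rng(0)
edge_probs = np.array([0.5, 0.2, 0.2, 0.1])  # toy per-edge probabilities in a community
print(sample_multinomial_via_binomials(20, edge_probs, rng))
</syntaxhighlight>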
Explanations of details
Communities: Communities are clusters of nodes or modules that can be grouped together.
Partition process: The graph is partitioned using a graph partitioning function. The partition process is done from the leaf nodes to the root node.
Related Works
Some existing works are specific to particular expert domains and therefore only work for one type of application.
The autoregressive process used in existing works requires an ordering of the inputs, which ignores a key property of graphs: they are insensitive to the ordering of nodes.
Existing works also incur significant computational costs and thus cannot handle large graphs.
In addition, no previous work explores the hierarchical structure of graphs, which limits the performance of existing models.
Among diffusion models, continuous diffusion destroys the sparsity and structural properties of graphs, while discrete diffusion requires extra denoising steps that can increase computational cost.
Group 18 Presentation: HiGen: Hierarchical Graph Generative Networks
Presented by
- Shiyu Zhu and Jesse Xue
Paper Citation
M. Karami, “HiGen: Hierarchical Graph Generative Networks,” 2023, arXiv. doi: 10.48550/ARXIV.2305.19337.
Summaries
The paper introduces HiGen, a novel graph generative model that explicitly captures the hierarchical structure of real-world graphs. HiGen generates graphs in a coarse-to-fine manner, successively building communities and their interactions across multiple levels. At each level, the model generates subgraphs (communities) in parallel and then predicts cross-community (bipartite) edges using dedicated neural networks. This decomposition allows HiGen to scale efficiently and capture both local and global graph properties.
Key Contributions
Proposes a multi-resolution generative framework where graphs are generated level-by-level, capturing hierarchical community structures naturally present in many real-world graphs. Decomposes each graph into communities that can be generated in parallel, increasing efficiency and scalability. Models edge weights using a multinomial distribution and introduces a recursive factorization (via stick-breaking) to enable efficient autoregressive generation.
Group 20 Gated linear attention transformers with hardware efficient training
Reference
arXiv:2312.06635
Background
The paper discusses Gated Linear Attention (GLA) Transformers, addressing the computational inefficiency of traditional transformers with softmax attention. Regular transformers have quadratic computational complexity in the sequence length, which becomes extremely expensive for long sequences, while plain linear attention typically yields an underperforming model.
Technical Contributions
It proposes using a linear kernel as an alternative to the softmax function, which allows attention to be formulated as a linear RNN with 2D hidden states. The key innovations include:
1. Introducing a data-dependent gating mechanism to improve model performance, which allows the model to forget past information adaptively
2. Developing a linear attention approach that reduces computational complexity
3. Creating a hardware-efficient training method (FLASHLINEARATTENTION Algorithm) that can handle long sequences more effectively
The main goal was to create a more efficient transformer model that can:
- Reduce computational expenses
- Maintain competitive performance across different tasks
- Handle long sequences more effectively
- Leverage modern GPU architectures for improved training and inference
The approach addresses the fundamental challenge of making transformer models more scalable and computationally efficient, particularly for tasks involving long sequences like processing books, dialogues, or complex scientific texts.
Results
The results and conclusions of the paper showed:
Performance Results:
- For the 340 million parameter model:
- Achieved competitive performance, close to Transformer performance
- Slightly better than or comparable to RetNet
- Slightly below Mamba on some tasks
- For the 1.3 billion parameter model:
- Beat most benchmarks in average accuracy
- Slightly behind Transformer++ in perplexity
- Showed impressive accuracy across tasks
Key Findings:
1. Gating mechanism is crucial for model performance
- Removing it significantly increased perplexity
- Data-dependent scalar decay improved results
2. Recall-intensive tasks:
- Smaller model: Transformer still led
- Larger model: GLA closed the performance gap considerably
- Competitive with Mamba and RetNet
3. Computational Efficiency:
- Higher training throughput for larger batch sizes
- Slight increase in GPU memory usage
- More efficient for bigger batches
Conclusions:
- GLA is highly effective for handling long sequences
- Hardware-efficient design reduces computational costs
- Gating mechanism significantly enhances model performance
- Promising approach for making transformers more accessible and efficient
- An efficient replacement for softmax attention in Transformers
The paper suggests future research should focus on optimizing the balance between performance and efficiency.
Linear Attention
Transformers traditionally use softmax attention, which scales poorly with sequence length due to quadratic complexity. Linear attention approximates softmax with kernel-based attention mechanisms, reducing this cost.
Parallel and Recurrent Forms
- Parallel Form: computes full attention using
[math]\displaystyle{ \mathbf{O} = \text{softmax}((\mathbf{QK}^\top) \odot \mathbf{M}) \mathbf{V} }[/math]
which enables efficient training with full-sequence inputs.
- Recurrent Form: used during inference; processes token-by-token with
[math]\displaystyle{ \mathbf{o}_t = \frac{\sum_{i=1}^{t} \phi(\mathbf{q}_t) \phi(\mathbf{k}_i)^\top \mathbf{v}_i}{\sum_{i=1}^{t} \phi(\mathbf{q}_t) \phi(\mathbf{k}_i)^\top} }[/math]
- Using [math]\displaystyle{ \phi(x) = x }[/math] and removing the normalization yields the simplified linear attention update
[math]\displaystyle{ \mathbf{S}_t = \mathbf{S}_{t-1} + \mathbf{k}_t^\top \mathbf{v}_t, \quad \mathbf{o}_t = \mathbf{q}_t \mathbf{S}_t }[/math]
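A minimal NumPy sketch of this simplified recurrent update (illustrative, not the paper's hardware-efficient kernel): the 2D state accumulates the outer products [math]\displaystyle{ \mathbf{k}_t^\top \mathbf{v}_t }[/math], and each output is the current query applied to that state.
<syntaxhighlight lang="python">
import numpy as np

def linear_attention_recurrent(Q, K, V):
    """Unnormalized linear attention, processed token by token.

    Q, K: arrays of shape (L, d_k); V: (L, d_v). Returns O of shape (L, d_v).
    """
    L, d_k = Q.shape
    d_v = V.shape[1]
    S = np.zeros((d_k, d_v))           # 2D hidden state
    O = np.zeros((L, d_v))
    for t in range(L):
        S = S + np.outer(K[t], V[t])   # S_t = S_{t-1} + k_t^T v_t
        O[t] = Q[t] @ S                # o_t = q_t S_t
    return O

rng = np.random.default_rng(0)
Q, K, V = rng.standard_normal((3, 8, 4))   # L = 8 tokens, head dimension 4
out = linear_attention_recurrent(Q, K, V)
</syntaxhighlight>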
Chunkwise Parallel Linear Attention
The chunkwise parallel form balances between full parallelism and full recurrence, enabling faster training on long sequences.
- Splits the input [math]\displaystyle{ \mathbf{X} }[/math] into chunks of length [math]\displaystyle{ C }[/math].
- Inter-chunk state update:
[math]\displaystyle{ \mathbf{S}_{[i+1]} = \mathbf{S}_{[i]} + \sum_{j=iC+1}^{(i+1)C} \mathbf{k}_j^\top \mathbf{v}_j }[/math]
- Intra-chunk output:
[math]\displaystyle{ \mathbf{O}_{[i+1]} = \mathbf{Q}_{[i+1]} \mathbf{S}_{[i]} + \left((\mathbf{Q}_{[i+1]} \mathbf{K}_{[i+1]}^\top) \odot \mathbf{M}\right) \mathbf{V}_{[i+1]} }[/math]
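The sketch below (illustrative, ungated case) implements this chunkwise form and checks that it matches the fully causal quadratic computation; the chunk size C = 4 is an arbitrary choice.
<syntaxhighlight lang="python">
import numpy as np

def linear_attention_chunkwise(Q, K, V, C=4):
    """Chunkwise parallel form of (ungated, unnormalized) linear attention."""
    L, d_k = Q.shape
    d_v = V.shape[1]
    assert L % C == 0
    S = np.zeros((d_k, d_v))           # state carried across chunks
    M = np.tril(np.ones((C, C)))       # causal mask inside a chunk
    O = np.zeros((L, d_v))
    for start in range(0, L, C):
        q, k, v = Q[start:start+C], K[start:start+C], V[start:start+C]
        intra = ((q @ k.T) * M) @ v        # intra-chunk causal attention
        O[start:start+C] = q @ S + intra   # inter-chunk contribution uses the state
        S = S + k.T @ v                    # inter-chunk state update for the next chunk
    return O

rng = np.random.default_rng(0)
Q, K, V = rng.standard_normal((3, 8, 4))
reference = ((Q @ K.T) * np.tril(np.ones((8, 8)))) @ V   # full causal (quadratic) form
assert np.allclose(linear_attention_chunkwise(Q, K, V, C=4), reference)
</syntaxhighlight>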
Swish Activation function (SwiGLU)
One notable component of the GLA model’s design is its use of the Swish activation function (and the derived SwiGLU gating unit) in key parts of the network. Swish, defined as [math]\displaystyle{ \text{Swish}(x) = x \cdot \sigma(x) }[/math] (where [math]\displaystyle{ \sigma }[/math] is the sigmoid), is a smooth, non-monotonic activation known to often outperform ReLU/GELU in deep networks. In this paper, Swish is employed in two main places: (1) the feed-forward network (FFN) layers, where the authors adopt the SwiGLU formulation, and (2) the computation of the data-dependent gates in the attention mechanism.
1) FFN
The function's smooth gradient and ability to yield non-zero outputs even for negative inputs help with optimization and expressiveness. In summary, using SwiGLU in each Transformer block’s FFN is an architectural choice that boosts performance per parameter, making the overall model more competitive with standard Transformers.
2) Gating mechanism
Swish is self-gating: when [math]\displaystyle{ x_t W_r }[/math] is large and positive, Swish outputs a large value (roughly linear in its input), but when [math]\displaystyle{ x_t W_r }[/math] is around zero or negative, Swish outputs a small value (tending toward zero for large negative inputs). This means [math]\displaystyle{ r_t }[/math] tends to selectively suppress tokens that the model deems less important (yielding near-zero values for those features) while allowing strong signals to pass through (near-linearly for large activations). A standard sigmoid gate could also suppress features (outputting values in 0-1), but it would saturate at 1 for any sufficiently large input, effectively capping the influence of very important features. Swish, by contrast, does not saturate to a hard limit: for important inputs it can output values greater than 1 (since [math]\displaystyle{ r_t \approx x_t }[/math] when [math]\displaystyle{ x_t W_r }[/math] is large), thereby allowing an amplification effect. This gives the model more flexibility than a sigmoid-gated GLU: small signals are squashed (multiplied by a small fraction), while strong signals are propagated in full or even amplified (almost the identity for large positive inputs). This property can be crucial for modeling, for example, rare key tokens that should strongly influence the attention; Swish gating lets those contributions through largely unattenuated, whereas a sigmoid gate might bottleneck at 1.
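A small PyTorch sketch of the two uses described above (shapes and dimensions are illustrative assumptions): Swish as the gate activation and the SwiGLU feed-forward block.
<syntaxhighlight lang="python">
import torch
import torch.nn as nn

def swish(x):
    # Swish(x) = x * sigmoid(x); identical to torch.nn.functional.silu(x).
    return x * torch.sigmoid(x)

class SwiGLUFFN(nn.Module):
    """Feed-forward block: (Swish(x W1) * (x W2)) W3."""
    def __init__(self, d_model=256, d_hidden=684):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_hidden, bias=False)
        self.w2 = nn.Linear(d_model, d_hidden, bias=False)
        self.w3 = nn.Linear(d_hidden, d_model, bias=False)

    def forward(self, x):
        return self.w3(swish(self.w1(x)) * self.w2(x))

x = torch.randn(2, 10, 256)            # (batch, sequence, d_model)
W_r = nn.Linear(256, 256, bias=False)
r = swish(W_r(x))                      # gate values r_t: small inputs are squashed,
                                       # large ones pass through almost linearly (can exceed 1)
ffn_out = SwiGLUFFN()(x)
</syntaxhighlight>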
Benefits
- Time complexity: [math]\displaystyle{ \mathcal{O}(LCd + Ld^2) }[/math], which is sub-quadratic.
- [math]\displaystyle{ C = 1 }[/math] recovers the recurrent form; [math]\displaystyle{ C = L }[/math] recovers the parallel form.
- Efficient and scalable to long sequences with minimal performance loss.
Future Work
1. Future hardware-aware optimization: balance between efficiency and performance.
2. Application to other data: the potential of applying GLA to image, video, or scientific data.
3. Test how GLA performs on larger models: due to computational limitations, the experiments were run on moderate-scale models.
Summaries of Key Points
- Gated Linear Attention (GLA) Transformer is a novel architecture that combines the efficiency of linear attention with data-dependent gating mechanisms to improve performance in sequence modeling tasks.
- FLASHLINEARATTENTION is introduced as a hardware-efficient implementation of linear attention, outperforming FLASHATTENTION-2 in speed, even for short sequences (e.g., 1K tokens).
- GLA Transformer enhances length generalization, allowing models trained on 2K sequences to generalize to sequences longer than 20K without significant performance degradation.
- The model is competitive with state-of-the-art architectures, including LLaMA Transformers and linear-time inference models like RetNet and Mamba, particularly in moderate-scale language modeling tasks.
- GLA Transformer achieves higher training throughput compared to similarly sized Mamba models while maintaining efficient long-context processing.
Group 20 Gated linear attention transformers with hardware efficient training
Presented by:
- Felix Jean
- Maxime Bouthilier
- Thomas Hudon
Paper Citation
S. Yang, B. Wang, Y. Shen, R. Panda & Y. Kim, “Gated linear attention transformers with hardware efficient training,” 2024, arXiv:2312.06635
Background & Motivation
The paper tackles the limitations of traditional softmax attention in Transformers, which, despite enabling efficient parallel training, exhibits quadratic complexity with respect to sequence length, rendering it impractical for long sequences. Linear attention has emerged as a promising alternative, providing linear-time inference by reformulating attention as a recurrent neural network (RNN) with 2D hidden states. However, in practice, linear attention often underperforms compared to softmax attention, and existing implementations lack I/O-awareness, leading to slower speeds relative to optimized softmax attention implementations such as FlashAttention-2. The authors identify two critical gaps: (1) the absence of hardware-efficient algorithms for linear attention that effectively balance memory movement and parallelizability, and (2) the lack of data-dependent gating mechanisms in linear attention, which are essential for achieving high performance in RNNs. These gaps motivate the development of FlashLinearAttention and the gated linear attention (GLA) Transformer.
Key Points
The paper introduces FlashLinearAttention, an I/O-aware and hardware-efficient algorithm for linear attention that optimizes memory movement and parallelizability. It achieves faster speeds than FlashAttention-2, even on short sequences (e.g., 1K tokens). The authors further extend this algorithm to Gated Linear Attention (GLA), which incorporates data-dependent gates to enhance model expressiveness. GLA preserves the linear-time inference property while improving performance across a range of tasks. Additionally, the paper proposes a chunkwise parallel formulation for GLA, enabling efficient training by dividing sequences into chunks and balancing inter-chunk and intra-chunk computations. Experimental results demonstrate that the GLA Transformer performs competitively against LLaMA-architecture Transformers and recent linear-time models such as RetNet and Mamba, particularly excelling in length generalization and recall-intensive tasks.
Contributions
- FlashLinearAttention: A hardware-efficient algorithm for linear attention that outperforms FlashAttention-2 in speed and memory efficiency.
- Gated Linear Attention (GLA): A novel linear attention variant with data-dependent gates, offering better performance and stability.
- Chunkwise Parallel Form: A training-friendly formulation of GLA that enables efficient parallelization and scalability.
- Empirical Validation: Demonstrates competitive performance against strong baselines, including LLaMA, RetNet, and Mamba, with notable strengths in length generalization and recall tasks.
- Open-source Implementation: The release of FlashLinearAttention as a practical tool for the community.
Constructive Critiques
- Scalability: Although the experiments are conducted at moderate scales (up to 1.3B parameters), it remains unclear how GLA would perform at larger scales (e.g., 7B+ parameters). The authors hypothesize that GLA’s efficiency would further improve at such scales, but this claim requires empirical validation.
- Generalization to Other Modalities: The current focus is on language modeling; however, extending GLA to other domains (e.g., vision or audio) could potentially broaden its applicability and impact.
- Complexity of Implementation: The secondary-level chunking and materialization strategies introduce additional complexity. Providing a more streamlined implementation or conducting ablation studies could help users better understand the associated trade-offs.
- Comparison to Hybrid Models: While the paper compares GLA to pure linear-time models (e.g., Mamba) and softmax attention, hybrid approaches that combine linear and sparse attention are not explored. Such comparisons could provide deeper insights into GLA's relative strengths and limitations.
Relationships to Other Works
- Linear Attention: extends prior work by Katharopoulos et al. (2020) and Sun et al. (2023a) by introducing data-dependent gates and hardware optimizations.
- Hardware-Efficient Attention: follows the spirit of FlashAttention (Dao et al., 2022b) but adapts it for linear attention, addressing unique challenges such as chunkwise parallelism.
- Gated RNNs: draws inspiration from gated RNNs (e.g., LSTMs, Mamba) but adapts the gating mechanism for linear attention's 2D hidden states.
- Length Generalization: complements recent efforts like RetNet and Mamba-2, offering a new solution for extrapolating beyond training lengths.
Group 23 Presentation: Discrete Diffusion Modelling By Estimating the Ratios of the Data Distribution
Presented By
Chenxin Lyu, Yixuan Zeng
Paper Citation
A. Lou, C. Meng, and S. Ermon, ‘Discrete Diffusion Modeling by Estimating the Ratios of the Data Distribution’, Jun. 06, 2024, arXiv: arXiv:2310.16834. doi: 10.48550/arXiv.2310.16834.
https://arxiv.org/abs/2310.16834
Background
- Diffusion models have shown great performance in generative artificial intelligence when applied to domains with continuous data
- Diffusion models are more difficult to apply to data in the discrete domain, such as tokenized text
- Prior attempts at applying diffusion to text generation have performed worse than autoregressive models
Paper Contributions
- Developed a method called Score Entropy Discrete Diffusion (SEDD)
- Parameterizes the diffusion process for discrete data using data distribution ratios, rather than dealing with the tokenized data directly
- SEDD: a framework for discrete diffusion modeling that learns to generate data by estimating probability ratios between neighboring discrete states. In the paper, SEDD forms the core modeling strategy that enables diffusion-based generation for discrete data, achieving perplexities competitive with autoregressive baselines.
- Implicit Score Entropy Loss
[math]\displaystyle{ L_{\text{ISE}} = \mathbb{E}_{x \sim p} \sum_{y \neq x} \left( w_{xy}\, s_\theta(x)_y - w_{yx} \log s_\theta(y)_x \right) }[/math]
The implicit score entropy loss is a novel training objective designed to learn the ratio function [math]\displaystyle{ s_\theta(x, t)_y \approx p_t(y)/p_t(x) }[/math] without requiring access to the true data distribution. In the paper, it makes the score entropy computationally more efficient, as it avoids explicit dependence on the ratio [math]\displaystyle{ p(y)/p(x) }[/math].
It avoids computing partition functions or normalizing over large vocabularies, which makes the method scalable and practical for high-dimensional categorical data (like text). It also connects naturally to energy-based modeling, where exact densities are intractable but ratios or unnormalized scores can be learned.
It is worth noting that the score entropy loss also connects to maximum likelihood: the authors show it can be used to derive an evidence lower bound (ELBO) for likelihood training.
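To make the loss concrete, here is a toy, tabular sketch of the implicit score entropy above for a single-token discrete space: <code>s_theta[x, y]</code> plays the role of [math]\displaystyle{ s_\theta(x)_y }[/math] and <code>w</code> holds the weights [math]\displaystyle{ w_{xy} }[/math]. The tabular parameterization, uniform weights, and tiny state space are illustrative assumptions, not the paper's transformer-based setup.
<syntaxhighlight lang="python">
import torch

def implicit_score_entropy(s_theta, x_batch, w):
    """Toy implicit score entropy over a discrete space of size N.

    s_theta: (N, N) positive ratio estimates, s_theta[x, y] ~ p(y)/p(x)
    x_batch: (B,) observed states sampled from the data distribution
    w:       (N, N) non-negative weights w[x, y]
    """
    N = s_theta.shape[0]
    loss = 0.0
    for x in x_batch.tolist():
        for y in range(N):
            if y == x:
                continue
            # w_{xy} * s_theta(x)_y  -  w_{yx} * log s_theta(y)_x
            loss = loss + w[x, y] * s_theta[x, y] - w[y, x] * torch.log(s_theta[y, x])
    return loss / x_batch.shape[0]

N = 5
raw = torch.randn(N, N, requires_grad=True)
s_theta = raw.exp()                      # positivity via exp
w = torch.ones(N, N)                     # e.g. rates taken from the diffusion matrix
x_batch = torch.randint(0, N, (16,))
loss = implicit_score_entropy(s_theta, x_batch, w)
loss.backward()                          # gradients flow into the ratio parameters
</syntaxhighlight>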
Discrete Diffusion Processes
- Models probability distributions over a finite discrete space [math]\displaystyle{ \mathcal{X} = \{1, \ldots, N\} }[/math], using probability mass vectors [math]\displaystyle{ p_t \in \mathbb{R}^N }[/math].
- Evolution of [math]\displaystyle{ p_t }[/math] follows a linear ODE:
[math]\displaystyle{ \frac{dp_t}{dt} = Q_t p_t,\quad p_0 \approx p_{\text{data}} }[/math]
- [math]\displaystyle{ Q_t }[/math] is a diffusion matrix with non-negative off-diagonal entries and columns summing to 0 (mass is preserved).
- Often simplified as [math]\displaystyle{ Q_t = \sigma(t) Q }[/math], driving [math]\displaystyle{ p_t }[/math] toward a base distribution as [math]\displaystyle{ t \to \infty }[/math].
- Simulated using Euler steps with small [math]\displaystyle{ \Delta t }[/math]; transition probability:
[math]\displaystyle{ p(x_{t+\Delta t} = y \mid x_t = x) = \delta_{xy} + Q_t(y, x) \Delta t + O(\Delta t^2) }[/math]
- Time Reversal: the reverse process uses another matrix [math]\displaystyle{ \overline{Q}_t }[/math] with
[math]\displaystyle{ \overline{Q}_t(y, x) = \frac{p_t(y)}{p_t(x)} Q_t(x, y) }[/math]
and reverse ODE [math]\displaystyle{ \frac{dp_{T-t}}{dt} = \overline{Q}_{T-t} p_{T-t} }[/math].
- This connects to the concrete score, which generalizes the score function [math]\displaystyle{ \nabla_x \log p_t }[/math] to discrete spaces.
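A minimal sketch of the forward process above (illustrative assumptions: a uniform-transition rate matrix Q, a constant noise level, and a small state space), simulated with Euler steps using the transition probability given above.
<syntaxhighlight lang="python">
import numpy as np

def uniform_rate_matrix(N):
    """Rate matrix Q with columns summing to zero: uniform jumps to any other state."""
    Q = np.ones((N, N)) / N
    np.fill_diagonal(Q, 1.0 / N - 1.0)
    return Q

def forward_simulate(x0, Q, T=1.0, n_steps=200, sigma=1.0, seed=0):
    """Euler simulation of the forward discrete diffusion for one sample."""
    rng = np.random.default_rng(seed)
    N = Q.shape[0]
    dt = T / n_steps
    x = x0
    for _ in range(n_steps):
        # p(x_{t+dt} = y | x_t = x) ~ delta_{xy} + sigma * Q(y, x) * dt
        probs = sigma * Q[:, x] * dt
        probs[x] += 1.0
        probs = np.clip(probs, 0.0, None)
        probs /= probs.sum()
        x = rng.choice(N, p=probs)
    return x

N = 10
Q = uniform_rate_matrix(N)
samples = [forward_simulate(0, Q, seed=s) for s in range(5)]
print(samples)   # as t grows, states spread toward the uniform base distribution
</syntaxhighlight>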
Summaries of Key Points
- SEDD (Score Entropy Discrete Diffusion models) is a novel approach to discrete diffusion modeling that bridges the gap between diffusion models and autoregressive language models. It introduces score entropy, a new loss function that extends score matching to discrete spaces, improving performance in discrete generative modeling tasks.
- SEDD significantly improves performance in language modeling, reducing perplexity by 25-75% compared to previous discrete diffusion models and outperforming GPT-2 in certain tasks.
- Key advantages of SEDD over traditional autoregressive models:
- Higher generative quality without requiring distribution annealing techniques such as temperature scaling.
- Computational efficiency, enabling similar output quality with up to 32× fewer network evaluations.
- Enhanced controllability, allowing for flexible text infilling beyond left-to-right prompting, while maintaining quality comparable to nucleus sampling.
- SEDD models discrete data by parameterizing a reverse discrete diffusion process using ratios of the data distribution, making them more effective in capturing language structure.
- The method challenges the dominance of autoregressive transformers by offering an alternative with better trade-offs between efficiency, control, and generation quality in discrete text modeling.
Results
- Perplexity Evaluation: SEDD outperforms previous diffusion models in terms of perplexity across multiple language modeling benchmarks. This suggests that SEDD models the underlying data distribution more accurately, improving likelihood estimation.
- Unconditional Generation: SEDD generates high-quality samples without any input conditioning, achieving performance comparable to GPT-2. It does so with 32× fewer network evaluations, indicating higher sampling efficiency and reduced computational cost.
- Conditional Generation (Infill Tasks): SEDD is tested on tasks where the model must generate missing parts of text conditioned on surrounding context. Despite lacking the autoregressive biases that normally boost performance in such tasks, SEDD remains competitive with strong baselines like GPT-2 and SSD-LM. This highlights SEDD’s ability to generalize well to complex discrete generation tasks without relying on token-by-token prediction.
Group 23 Presentation: Discrete Diffusion Modelling By Estimating the Ratios of the Data Distribution
Presented By
Chenxin Lyu, Yixuan Zeng
Paper Citation
A. Lou, C. Meng, and S. Ermon, ‘Discrete Diffusion Modeling by Estimating the Ratios of the Data Distribution’, Jun. 06, 2024, arXiv: arXiv:2310.16834. doi: 10.48550/arXiv.2310.16834.
https://arxiv.org/abs/2310.16834
Background & Motivation
The paper tackles a critical gap in generative modeling for discrete data, particularly within the domain of natural language processing (NLP). While diffusion models have achieved remarkable success in continuous domains such as image generation, their performance on discrete data (e.g., text) has fallen short compared to autoregressive models, which currently dominate the field. The authors pinpoint the root cause as the absence of a principled and scalable framework for discrete score matching—the foundational theory underlying continuous diffusion models. Existing approaches, such as mean prediction and ratio matching, exhibit theoretical and empirical limitations, including instability, inefficiency, and suboptimal performance. Motivated by these challenges, the paper introduces Score Entropy Discrete Diffusion (SEDD), a novel method that extends score matching to discrete spaces by estimating probability ratios of the data distribution. This approach seeks to close the performance gap between autoregressive and diffusion-based language models while addressing key challenges associated with slow sampling, limited controllability, and stringent annealing requirements in autoregressive models.
Key Points
1. Score Entropy Loss: The key innovation lies in a novel loss function, score entropy, which extends score matching to discrete spaces by modeling the ratios of the data distribution. This ensures positivity, scalability, and theoretical consistency, thereby addressing the limitations of prior methods such as concrete score matching (which inadequately penalizes negative values).
2. Discrete Diffusion Framework: SEDD parameterizes the reverse diffusion process using learned probability ratios, enabling efficient sampling and likelihood-based training. The framework supports token-level transitions via structured matrices (e.g., uniform or absorbing transitions), facilitating the handling of high-dimensional sequences.
3. Empirical Superiority: SEDD outperforms existing discrete and continuous diffusion models on language tasks, reducing perplexity by 25–75% and matching or surpassing GPT-2 in zero-shot perplexity. It also achieves significantly higher-quality unconditional generation (6–8× better generative perplexity than un-annealed GPT-2) and flexible conditional generation (e.g., infilling).
4. Practical Benefits: The model provides a favorable compute-quality trade-off (e.g., achieving GPT-2 quality with 32× fewer steps), eliminates the need for annealing techniques like temperature scaling, and enables controllable infilling without specialized training.
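To make the score-entropy idea concrete, here is a minimal sketch of a denoising score-entropy objective for a toy categorical setting, assuming PyTorch and a single-step forward-noise kernel. The forward-rate weights and the score-independent constant term of the full loss are omitted, and all names and shapes are illustrative rather than the paper's implementation.
<pre>
import torch

def denoising_score_entropy(score, xt, x0, transition_probs):
    """Toy denoising score-entropy sketch (not the official SEDD code).
    score:            (B, N) positive model outputs s_theta(x_t)_y for all states y
    xt, x0:           (B,) noised and clean token indices (long tensors)
    transition_probs: (N, N) forward kernel, row x0 giving p_t(y | x0)
    """
    # ratios p_t(y | x0) / p_t(x_t | x0) that the learned score should approximate
    p_y = transition_probs[x0]                          # (B, N)
    p_xt = p_y.gather(1, xt.unsqueeze(1))               # (B, 1)
    ratio = p_y / p_xt.clamp_min(1e-12)
    # score-entropy integrand: s - ratio * log s, summed over y != x_t
    integrand = score - ratio * torch.log(score.clamp_min(1e-12))
    mask = torch.ones_like(integrand).scatter_(1, xt.unsqueeze(1), 0.0)
    return (integrand * mask).sum(dim=1).mean()
</pre>
Minimizing this quantity pushes the (positive) score outputs toward the true probability ratios, which is the sense in which the loss "estimates the ratios of the data distribution."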
Contributions
- Theoretical: Introduces score entropy, a loss function that generalizes score matching to discrete spaces while ensuring positivity and scalability, with rigorous proofs of consistency and tractability (e.g., the denoising score entropy variant).
- Methodological: Develops SEDD, a discrete diffusion framework that integrates score entropy with token-level transitions via structured matrices, enabling efficient training and sampling. The Tweedie-inspired τ-leaping sampling strategy further enhances performance in practical scenarios.
- Empirical: Demonstrates state-of-the-art results on language modeling benchmarks (e.g., text8, One Billion Words) and generation tasks (both unconditional and conditional), outperforming autoregressive baselines in key metrics such as perplexity. The model’s flexibility in infilling and its favorable compute-quality trade-offs represent significant advancements in the field.
Constructive Critiques
- Complexity: The reliance on matrix exponentials (e.g., for token transitions) may limit scalability to larger vocabularies or more complex structures (e.g., graphs).
- Generalization: While SEDD excels in language, its applicability to other discrete domains (e.g., molecules, code) remains untested.
- Training Cost: The paper notes SEDD’s parameter count is slightly higher than GPT-2, but the computational overhead of diffusion training versus autoregressive training is not thoroughly compared.
Relationships to Other Works
The SEDD model advances key areas of generative modeling by improving upon prior discrete diffusion approaches such as D3PM (Austin et al., 2021) and the continuous-time framework of Campbell et al. (2022). It replaces their mean prediction objectives with ratio estimation, addressing limitations in stability and continuous-time approximation. Compared to continuous diffusion models like Diffusion-LM (Li et al., 2022) and PLAID (Gulrajani & Hashimoto, 2023), SEDD achieves better performance in likelihood estimation and generation quality without requiring heuristic annealing techniques. The work also generalizes score matching methods, extending Hyvarinen's original score matching (2005) and concrete score matching (Meng et al., 2022) to discrete domains through its score entropy formulation. While not yet reaching the scale of modern autoregressive models, SEDD competes effectively with autoregressive baselines like GPT-2 in flexible generation tasks (e.g., infilling) and computational efficiency. Its success highlights the potential of combining SEDD with recent advances such as self-conditioning (Strudel et al., 2022) to further close the gap with autoregressive models.
Group 23 Presentation: Discrete Diffusion Modeling by Estimating the Ratios of the Data Distribution
Presented by:
Chenxin Lyu and Yixuan Zeng
Paper Citation
Lou, A., Meng, C., & Ermon, S. (2024). Discrete Diffusion Modeling by Estimating the Ratios of the Data Distribution. arXiv. https://doi.org/10.48550/arXiv.2310.16834
Summaries
Diffusion models have achieved remarkable performance in continuous domains like image generation, but extending them to discrete domains—such as natural language—has proven challenging. Previous efforts to adapt diffusion to text suffer from instability, inefficiency, and inferior performance when compared to autoregressive models, which currently dominate NLP.
This paper proposes Score Entropy Discrete Diffusion (SEDD), a novel generative modeling framework for discrete data. Rather than modeling tokens directly, SEDD estimates probability ratios between discrete states, bridging the gap between diffusion models and autoregressive transformers.
SEDD introduces a new training loss, score entropy, which generalizes score matching to discrete spaces. This loss function ensures theoretical consistency, avoids the need to compute partition functions, and maintains scalability for high-dimensional data. The reverse diffusion process is parameterized using these ratios, enabling efficient sampling, improved generation quality, and competitive perplexity scores—sometimes even outperforming GPT-2.
Key Contributions
Score Entropy Loss: A new loss function for discrete score matching. Unlike previous approaches, it avoids negative values, ensures scalability, and naturally connects to maximum likelihood estimation through an evidence lower bound (ELBO).
Discrete Diffusion Framework: SEDD models a reverse diffusion process using structured transition matrices and estimated data-distribution ratios (a toy construction of these matrices is sketched after this list). This supports token-level transitions with practical benefits for text generation tasks.
Efficient Sampling & High-Quality Output: The model significantly reduces perplexity (by 25–75%) and achieves GPT-2–level generation quality while requiring 32× fewer sampling steps. It supports unconditional generation and conditional infilling, and eliminates the need for annealing strategies (e.g., temperature scaling).
Theoretical Rigor & Practical Flexibility: SEDD is grounded in a discrete formulation of diffusion ODEs and uses a Tweedie-inspired τ-leaping strategy for improved sample efficiency.
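As a rough illustration of the structured transition matrices mentioned under "Discrete Diffusion Framework" above, the sketch below builds the two standard token-level rate matrices (uniform and absorbing) under the convention dp_t/dt = Q p_t; the exact parameterization and normalization used in the paper may differ.
<pre>
import numpy as np

def uniform_rate_matrix(N):
    # every state flows uniformly toward every other state
    Q = np.full((N, N), 1.0 / N)
    np.fill_diagonal(Q, 1.0 / N - 1.0)    # columns (and rows) sum to zero
    return Q

def absorbing_rate_matrix(N):
    # an extra MASK state (index N) absorbs probability mass from all tokens
    Q = np.zeros((N + 1, N + 1))
    Q[N, :N] = 1.0                        # token -> MASK
    Q[np.arange(N), np.arange(N)] = -1.0  # columns sum to zero; MASK column stays 0
    return Q
</pre>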
Constructive Critiques or Reviews
Scalability Concerns: The use of matrix exponentials in token transitions might limit performance when applied to large vocabularies or graph-structured data.
Domain Generalization: While SEDD shows strong results in language tasks, its application to other discrete domains like molecules or source code remains untested.
Training Cost: The model has slightly more parameters than GPT-2, but the full computational trade-off between SEDD and autoregressive training isn’t thoroughly explored.
Related Works
Discrete Diffusion Models: Builds upon earlier works like D3PM (Austin et al., 2021) and continuous diffusion adaptations such as Diffusion-LM and PLAID.
Score Matching Foundations: Extends Hyvärinen’s (2005) original score matching and Meng et al.’s (2022) concrete score matching to discrete spaces using the score entropy formulation.
Comparisons with Autoregressive Models: SEDD matches or outperforms GPT-2 in zero-shot and infilling tasks, offering better trade-offs between generation quality, efficiency, and controllability.
Alternative Sampling Techniques: The τ-leaping sampler resembles approaches in energy-based modeling and may be enhanced further with self-conditioning (Strudel et al., 2022).
Group 23 Presentation: Discrete Diffusion Modeling by Estimating the Ratios of the Data Distribution
Presented by
Chenxin Lyu and Yixuan Zeng
Paper Citation
A. Lou, C. Meng, and S. Ermon, ‘Discrete Diffusion Modeling by Estimating the Ratios of the Data Distribution’, Jun. 06, 2024, arXiv: arXiv:2310.16834. doi: 10.48550/arXiv.2310.16834. https://arxiv.org/abs/2310.16834
Background
- Generative modeling aims to create models capable of learning and reproducing real-world data distributions.
- Existing autoregressive models generate sequences token-by-token, making them computationally slow, and they rely heavily on heuristic measures and hyperparameter tuning for optimal performance.
- Diffusion models have emerged as an alternative, but existing approaches face challenges in discrete data modeling.
- This paper introduces Score Entropy Discrete Diffusion (SEDD), which learns the ratio between different states in discrete data distributions rather than estimating explicit probability densities. It parameterizes the reverse diffusion process, making it computationally efficient and scalable for high-dimensional discrete data.
Summaries of Key Points
- A diffusion process models the transformation of data from a structured state to pure noise over time and then reverses this process to generate new samples. The discrete diffusion process’s purpose is to create a probabilistic framework that can efficiently model discrete data. The time-reversed formulation of the diffusion process ensures that the generative process follows a learned reverse trajectory back to the original data distribution.
- An ideal diffusion model is consistent, meaning the learned score function aligns with the true probability distribution. Score entropy loss is a loss function that is used to improve stability in discrete diffusion models. It ensures consistency by minimizing errors in the estimated probability ratios, leading to a more reliable generative process.
- Implicit score entropy loss only depends on observed examples and learned scores, which is useful for practical optimization since it allows training without requiring knowledge of all true probabilities. It also encourages scalability since it works for high-dimensional discrete tasks, where exact probability distributions are infeasible to compute.
- These discrete diffusion models address computational efficiency as well. Because computing all transition probabilities directly would require large matrix–matrix multiplications, which is impractical for large vocabularies and memory-intensive settings, the model uses two structured transition matrices (uniform and absorbing) to compute transition ratios efficiently.
- The model also enables unconditional generation by leveraging the learned diffusion process without requiring additional conditioning variables.
Explanation of Time-Reversal Strategies
- Discrete Tweedie’s theorem says that if [math]\displaystyle{ p_t }[/math] satisfies the diffusion ODE [math]\displaystyle{ dp_t = Qp_t }[/math], then the exact denoiser can be expressed as:
[math]\displaystyle{ p_{0|t}(x_0 \mid x_t) = \left(\exp(-tQ)\left[\frac{p_t(i)}{p_t(x_t)}\right]_{i=1}^{N}\right)_{x_0} \exp(tQ)(x_t, x_0) }[/math]
- This theorem shows that given full knowledge of probability ratios, the optimal denoiser can be expressed in closed form.
- This guides practical denoiser construction because the model learns data distribution ratios rather than explicit probabilities. It also allows the design of an efficient denoising process that reconstructs original states from noise.
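A minimal numeric sketch of the closed-form denoiser above, assuming the same column convention dp_t/dt = Q p_t and access to the exact marginal p_t; the final renormalization is only a numerical safeguard, and this function is an illustration rather than part of the paper's code.
<pre>
import numpy as np
from scipy.linalg import expm

def tweedie_denoiser(Q, p_t, x_t, t):
    # Q: (N, N) rate matrix with dp_t/dt = Q p_t; p_t: (N,) marginal at time t; x_t: int
    ratios = p_t / p_t[x_t]                  # [p_t(i) / p_t(x_t)]_{i=1}^{N}
    left = expm(-t * Q) @ ratios             # (exp(-tQ)[ratios])_{x_0}, a vector over x_0
    right = expm(t * Q)[x_t, :]              # exp(tQ)(x_t, x_0) as a function of x_0
    p0 = left * right
    return p0 / p0.sum()                     # renormalize for numerical safety
</pre>
In practice the exact ratios are unknown, so the learned score network is substituted for them, which is exactly what makes this theorem a guide for constructing the practical denoiser.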
Group 24 Presentation: Mitigating the Missing Fragmentation Problem in De Novo Peptide Sequencing With A Two-Stage Graph-Based Deep Learning Model
Presenters
Zi Hua Xu, Zehao Zhang
Paper Citation
Mao, Z., Zhang, R., Xin, L. et al. Mitigating the missing-fragmentation problem in de novo peptide sequencing with a two-stage graph-based deep learning model. Nat Mach Intell 5, 1250–1260 (2023). https://doi.org/10.1038/s42256-023-00738-x
https://www.nature.com/articles/s42256-023-00738-x#citeas
Background
- Proteins are crucial for biological functions
- Proteins are formed from peptides which are sequences of amino acids
- Mass spectrometry is used to analyze peptide sequences
- De Novo sequencing is used to piece together peptide sequences when the sequences are missing from existing established protein databases
- Deep learning is now commonly applied to the problem of de novo peptide sequencing
- When a peptide fails to fragment in the expected manner, it can make protein reconstruction difficult due to missing data
- One error in the reconstruction can propagate into errors throughout the entire predicted sequence
Paper Contributions
- GraphNovo was developed to handle spectra with incomplete fragmentation
- GraphNovo-PathSearcher, instead of predicting amino acids directly, searches for an optimal path through a graph built from the spectrum
- A graph neural network is used to find the best path through the graph generated from the mass spectrometry input
- GraphNovo-SeqFiller then converts that path into a complete amino acid sequence
- Since some fragment ions are expected to be missing, SeqFiller uses a transformer to fill in amino acids that PathSearcher could only represent as mass gaps
- Input is mass spectrum from mass spectrometry
- Graph construction is done where nodes represent possible fragment masses and edges represent possible amino acid mass differences (input to the PathSearcher module)
- PathSearcher uses machine learning to find the optimal path on the generated graph
- SeqFiller fills in missing amino acids that may have not been included in the PathSearcher module due to lacking data from the mass spectrometry inputs
Peptide Sequencing in AlphaPeptDeep
- Peptide sequencing determines amino acid sequences of peptides, crucial for understanding proteins.
- Mass spectrometry (MS) is used to fragment peptides and analyze the resulting spectra.
Methods referenced in the presentation:
- Database Search & Spectral Library Search: AlphaPeptDeep improves prediction of MS spectra and retention time, boosting accuracy of both methods.
- de novo Sequencing: Enhanced spectral prediction from AlphaPeptDeep supports building peptide sequences without prior knowledge.
- AlphaPeptDeep predicts peptide properties (e.g., fragmentation patterns) to improve spectrum matching and sequence inference.
Contributions
- GraphNovo outperforms prior models such as DeepNovo, PointNovo, and Casanovo, achieving:
- 9.3–11.5% higher peptide recall,
- 6.1–8.9% higher amino acid recall,
- 5.8–8.7% higher amino acid precision.
- Substantially improves predictions in and beyond missing-fragmentation regions.
- Improves amino acid recall by up to 24.5% after missing fragmentation when guiding DeepNovo with GraphNovo’s predicted path.
- Maintains high accuracy across varying:
- Sequence lengths,
- Noise signal ratios,
- Degrees of missing fragmentation.
- Open-source availability:
- Code: GitHub – GraphNovo
- Data: Zenodo Repository
Constructive Critiques
- Struggles with long missing-fragmentation regions: The model has difficulty accurately predicting peptide sequences when large segments of fragmentation data are absent. These missing regions create gaps in the spectrum, which may impair the model's ability to reconstruct the full peptide chain.
- Requires predefined post-translational modifications: The system depends on a predefined list of possible post-translational modifications. This constraint limits the model’s ability to generalize to peptides with unexpected or novel PTMs, reducing its adaptability in complex biological samples.
- Computationally expensive: Due to its two-stage graph-based deep learning architecture, the model demands significant computational resources.
Reviews
I thought that this presentation clearly introduced the concepts of peptide sequences and how mass spectrometry is used to analyze and reconstruct them. It was also nice that you mentioned the other ways sequences can be matched with observed spectra besides deep learning methods (e.g., matching observed spectra with known sequences and using pre-existing spectral libraries). The presentation also clearly explained how fragmentation failures lead to missing data in the observed spectrum, which ultimately makes sequence reconstruction difficult. As someone who plans to do project 3 for the final project, I found this presentation particularly helpful in understanding the type of data used in de novo peptide sequence generation. One comment I have is that it was a little unclear how the mass spectra input is turned into a graph with fragments as nodes, and what exactly the optimal "path" in the graph is (i.e., how the edges are defined), although that may also be because I am not too familiar with this area.
Group 24 Presentation: Mitigating the Missing Fragmentation Problem in De Novo Peptide Sequencing With A Two-Stage Graph-Based Deep Learning Model
Presenters
Zi Hua Xu, Zehao Zhang
Paper Citation
Mao, Z., Zhang, R., Xin, L. et al. Mitigating the missing-fragmentation problem in de novo peptide sequencing with a two-stage graph-based deep learning model. Nat Mach Intell 5, 1250–1260 (2023). https://doi.org/10.1038/s42256-023-00738-x
https://www.nature.com/articles/s42256-023-00738-x#citeas
Background
Peptide sequencing is a critical step in proteomics for determining the amino acid composition of proteins using tandem mass spectrometry (MS). Traditional approaches fall into three main categories:
(1) Database Search: Compares observed tandem mass spectra with peptides from a known database.
(2) Spectral Library Search: Uses curated spectral libraries containing experimentally acquired spectra to identify peptide–spectrum matches.
(3) De Novo Sequencing: Infers the peptide sequence directly from the MS data without relying on existing sequence databases, making it essential for novel protein discovery and cases where databases are incomplete or impractical.
Limitations in de novo peptide sequencing
Two major challenges persist in de novo peptide sequencing:
(1) Missing Fragmentation: Some peptide bonds do not break during MS fragmentation, leaving “gaps” in the spectrum that make certain regions difficult to reconstruct.
(2) Error Accumulation: A mistake in one poorly fragmented region often propagates, causing further downstream errors in the predicted peptide sequence.
GraphNovo
GraphNovo is a two-stage de novo sequencing algorithm designed to mitigate these issues:
(1) GraphNovo-PathSearcher: Constructs a directed acyclic graph (the “spectrum graph”) where each node represents a possible partial mass, and edges represent allowable amino acid mass differences. Predicts the “optimal path” from the start node to the end node, effectively capturing the correct arrangement of fragment ions and bypassing regions of missing fragmentation by labeling them as “mass tags.”
(2) GraphNovo-SeqFiller: Fills in the “mass tags” from the optimal path with their specific amino acid composition. Uses a transformer-based decoder guided by the path constraints, reducing errors that could otherwise accumulate if one low-confidence region led to incorrect subsequent predictions.
How GraphNovo Functions
(1) Graph Construction: Translates raw MS peaks into a spectrum graph. Each node corresponds to a potential prefix mass; edges form between nodes if the mass difference matches one or more amino acid masses (within tolerance).
(2) PathSearcher: Trained on known spectra to find the correct route through this graph, placing constraints on fragment ions (including missing-fragmentation regions represented as nodes without direct evidence).
(3) SeqFiller: Given the path—and hence each mass gap—SeqFiller “zooms in” on each mass tag to determine the exact amino acids. This two-stage strategy tackles missing fragmentation more directly than single-stage approaches.
Data and Preprocessing
Training Data: Includes high-confidence peptide identifications for Homo sapiens and Mus musculus from public proteomics datasets (e.g., plasma, HeLa, brain tissues). Only peptides with reliable annotations (1% false-discovery rate) and precursor masses below a certain threshold are included.
Test Data: Drawn from species not in the training set (e.g., Arabidopsis thaliana, C. elegans, E. coli), ensuring that evaluation measures generalizability.
Preprocessing Steps:
(1) Convert raw MS peaks into feature vectors (e.g., normalized m/z, relative intensity).
(2) Generate “node spectrum” (b, y, a, y2+ ions, among others) while discarding infeasible peaks.
(3) Build the graph by connecting nodes if their mass difference matches valid amino acid combinations.
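As a rough sketch of step (3), the snippet below connects candidate prefix-mass nodes whose mass difference matches a single amino-acid residue mass within a tolerance; the mass table, the tolerance value, and the omission of short amino-acid combinations are simplifying assumptions rather than GraphNovo's actual rules.
<pre>
# Illustrative spectrum-graph construction following the description above.
AA_MASS = {"G": 57.02146, "A": 71.03711, "S": 87.03203, "P": 97.05276,
           "V": 99.06841, "T": 101.04768, "L": 113.08406}

def build_spectrum_graph(node_masses, tol=0.02):
    """node_masses: sorted candidate prefix masses (floats). Returns a list of edges (i, j)."""
    edges = []
    for i, mi in enumerate(node_masses):
        for j in range(i + 1, len(node_masses)):
            diff = node_masses[j] - mi
            # connect nodes whose mass difference matches one amino acid residue;
            # the real GraphNovo also allows combinations of several residues
            if any(abs(diff - m) < tol for m in AA_MASS.values()):
                edges.append((i, j))
    return edges
</pre>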
Model Architecture
Graph Construction: Creates a directed graph with edges corresponding to possible amino acid masses.
Loss Functions:
(1) GraphNovo-PathSearcher: Uses Kullback–Leibler divergence to guide node (path) predictions.
(2) GraphNovo-SeqFiller: Uses cross-entropy to predict the exact amino acid sequence that fills each mass tag.
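The two objectives can be sketched as follows; the function names and tensor shapes are assumptions for illustration, not GraphNovo's actual code.
<pre>
import torch.nn.functional as F

def pathsearcher_loss(pred_node_logits, target_node_dist):
    # KL( target || predicted ) over candidate next nodes in the spectrum graph
    log_pred = F.log_softmax(pred_node_logits, dim=-1)
    return F.kl_div(log_pred, target_node_dist, reduction="batchmean")

def seqfiller_loss(pred_aa_logits, target_aa_ids):
    # cross-entropy over the amino-acid vocabulary filling each mass tag
    return F.cross_entropy(pred_aa_logits, target_aa_ids)
</pre>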
Hyperparameter Tuning:
Optimizer: AdamW (a variant of Adam) with a fixed learning rate in the reported experiments.
Both stages employ a transformer-based architecture, incorporating a specialized graph encoder (relation attention) to capture node and edge features.
Performance Comparison
Peptide Recall: GraphNovo shows a 9–12% improvement over the next-best approach (e.g., PointNovo, Casanovo) in correctly reconstructing entire peptide sequences.
Amino Acid Recall and Precision: Yields a 5–9% improvement across different test species, indicating more accurate individual residue identifications.
Robust to Missing Fragmentation and Noise: Maintains relatively high recall/precision for longer peptides, higher noise levels, and spectra with multiple missing-fragmentation sites, thereby mitigating error accumulation.
Constructive Critiques
Long Missing-Fragmentation Regions: While GraphNovo substantially reduces error propagation, very long or continuous gaps remain challenging.
Predefined Modifications: Must specify possible post-translational modifications (PTMs) in the graph construction step, which becomes computationally costly if many PTMs are considered at once.
Computational Overhead: Two-stage approach and large-scale graph construction require significant memory and processing time.
Future improvements
Enhanced Sequence Prediction: Integrate more MS features (e.g., retention time) to improve accuracy within large missing-fragmentation regions.
Expanded Applicability: Adapting the two-stage approach for top-down or middle-down proteomics and more extensive sets of PTMs.
Computational Efficiency: Explore faster graph-building algorithms, reduce the number of nodes through refined filtering, and potentially incorporate few-shot learning for user-specified PTMs.
Comment
Overall, GraphNovo demonstrates that a carefully designed, graph-based deep learning model can significantly mitigate the missing-fragmentation problem in de novo peptide sequencing, outperforming traditional and newer transformer-based methods by providing a stable framework for both path and sequence prediction.
Group 47 Presentation: Jamba: A Hybrid Transformer - Mamba Language Model
Presenter
Chentao Jin
Paper Citation
Lieber, O., Lenz, B., Bata, H., Cohen, G., Osin, J., Dalmedigos, I., Safahi, E., Meirom, S., Belinkov, Y., Shalev-Shwartz, S., Abend, O., Alon, R., Asida, T., Bergman, A., Glozman, R., Gokhman, M., Manevich, A., Ratner, N., Rozen, N., Shwartz, E., Zusman, M., Shoham, Y. (2024). Jamba: A Hybrid Transformer-Mamba Language Model. arXiv. https://arxiv.org/abs/2403.19887
https://doi.org/10.48550/arXiv.2403.19887
Background
- Large language models (LLMs) have become essential for various natural language processing tasks.
- Transformers have been the dominant architecture for LLMs due to their effectiveness in handling sequential data.
- However, Transformers suffer from high memory and compute costs, limiting their efficiency in long-context processing.
- Mamba, a recent state-space model, has emerged as an alternative to Transformers, offering improved efficiency in handling long sequences.
- A hybrid approach combining Transformers and Mamba layers can leverage the strengths of both architectures.
- Mixture-of-Experts (MoE) techniques can further enhance model capacity while managing active parameter usage.
- Efficient model architectures are crucial for balancing performance, computational efficiency, and memory footprint in large-scale AI applications.
Summaries of Key Points
Bridging Transformer Expressiveness and Mamba Efficiency
Large language models (LLMs) have made remarkable advances, but scaling them to handle long-context processing remains a significant challenge. Jamba introduces a novel hybrid architecture that combines the strengths of Transformers and Mamba, enhanced with a Mixture of Experts (MoE) module. This integration enables efficient memory usage, high throughput, and scalability for sequences up to 256,000 tokens on a single GPU.
Key Architectural Innovations
1. Transformer Layers – Self-attention mechanisms allow the model to capture complex token relationships, crucial for tasks involving deep contextual reasoning. However, they come with high memory and compute costs when processing long sequences.
2. Mamba Layers – Derived from state-space models (SSMs), Mamba efficiently processes long sequences without storing extensive key-value caches. Instead, it maintains a hidden state to summarize prior information, significantly reducing memory overhead.
3. Mixture of Experts (MoE) – Jamba integrates sparse expert selection, where each token activates only a small subset of experts. This technique increases capacity while controlling computational costs. Specifically, Jamba uses 16 experts, with only two active per token, optimizing efficiency.
The Jamba Block: A Structured Hybrid Design
The architecture of a Jamba block follows a structured sequence of Transformer layers, Mamba layers, and MoE layers:
- Transformer-to-Mamba ratio (1:7): the model incorporates one Transformer layer for every seven Mamba layers, balancing expressiveness and efficiency.
- MoE placement: instead of applying MoE to every layer, Jamba replaces every second multi-layer perceptron (MLP) layer with an MoE module. This approach increases total model capacity without significantly raising the number of active parameters per token.
By blending self-attention, state-space models, and sparsity techniques, Jamba pushes the boundaries of long-context processing in language models. Its ability to handle extremely long sequences while maintaining efficiency and scalability makes it a compelling innovation in the next generation of LLM architectures.
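A schematic of the layer schedule described above (one attention layer per block of eight, with MoE replacing every second MLP); which specific layers receive the MoE modules is an assumption made for illustration.
<pre>
# Purely illustrative layer schedule for one Jamba-style block.
def jamba_block_schedule(n_layers=8, attn_every=8, moe_every=2):
    layers = []
    for i in range(n_layers):
        mixer = "attention" if i % attn_every == 0 else "mamba"   # 1:7 attention-to-Mamba ratio
        ffn = "moe" if i % moe_every == 1 else "mlp"              # MoE on every second MLP
        layers.append((mixer, ffn))
    return layers

print(jamba_block_schedule())
# [('attention', 'mlp'), ('mamba', 'moe'), ('mamba', 'mlp'), ('mamba', 'moe'), ...]
</pre>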
Performance and Benefits
Jamba achieves state-of-the-art efficiency and performance across academic benchmarks, long-context tasks, and throughput.
- Matches or outperforms larger models such as Mixtral-8x7B and LLaMA-2 70B on:
- Reasoning: HellaSwag, ARC-Challenge, WinoGrande, PIQA, TruthfulQA
- Comprehension: BoolQ, QuAC
- Math and Code: GSM8K, HumanEval
- Aggregated tasks: MMLU, BBH
- Outperforms Mixtral on:
- Needle-in-a-Haystack retrieval
- Few-shot classification: Banking77, TREC-Fine
- Long-context QA: NarrativeQA, CUAD, NaturalQuestions
- Up to 3× faster than Mixtral at [math]\displaystyle{ 128K }[/math] context length.
- Efficient inference on a single [math]\displaystyle{ 80\,\text{GB} }[/math] GPU (int8 quantization).
- KV-cache memory usage is 8× smaller than Transformers (e.g., [math]\displaystyle{ 4\,\text{GB} }[/math] vs. [math]\displaystyle{ 32\,\text{GB} }[/math] at [math]\displaystyle{ 256K }[/math] tokens).
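A back-of-the-envelope calculation consistent with the 8× KV-cache figure above; the layer counts, head dimensions, and 16-bit cache precision below are illustrative assumptions, not Jamba's published configuration.
<pre>
# Rough KV-cache size estimate: 2 tensors (K and V) per attention layer.
def kv_cache_gb(attn_layers, kv_heads, head_dim, context, bytes_per=2):
    return 2 * attn_layers * kv_heads * head_dim * context * bytes_per / 1e9

ctx = 256_000
print(kv_cache_gb(attn_layers=32, kv_heads=8, head_dim=128, context=ctx))  # ~33.6 GB, all-attention stack
print(kv_cache_gb(attn_layers=4,  kv_heads=8, head_dim=128, context=ctx))  # ~4.2 GB, keeping ~1/8 attention layers
</pre>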
Constructive Critique
While Jamba achieves impressive performance and efficiency through its hybrid architecture, several limitations and open questions remain:
- Lack of ablation analysis: The paper adopts a 1:7 ratio of Transformer to Mamba layers and places MoE modules every other layer, but provides little justification for these hyperparameters. A more thorough ablation would help clarify the contribution of each component to final performance.
- Interpretability trade-off: Combining multiple architectural modules (SSM, attention, MoE) increases complexity. While effective, this may make it harder to interpret model behavior or debug errors, especially compared to simpler Transformer-only baselines.
- ICL limitations: The authors mention that Mamba layers alone struggle with in-context learning (ICL). Although interleaving attention layers helps, this still suggests limitations in how well SSMs handle token-by-token reasoning or structure-sensitive tasks.
- Lack of fine-tuning or alignment: The released model is a base pretrained model without instruction tuning or safety alignment. This limits its immediate use in downstream applications without additional supervised or RLHF-based training.
Despite these challenges, Jamba represents a promising direction for scaling efficient, long-context language models and offers a practical blueprint for hybrid architectures in the LLM space.
Group 47 Presentation: Jamba: A Hybrid Transformer - Mamba Language Model
Presenter
Chentao Jin
Paper Citation
Lieber, O., Lenz, B., Bata, H., Cohen, G., Osin, J., Dalmedigos, I., Safahi, E., Meirom, S., Belinkov, Y., Shalev-Shwartz, S., Abend, O., Alon, R., Asida, T., Bergman, A., Glozman, R., Gokhman, M., Manevich, A., Ratner, N., Rozen, N., Shwartz, E., Zusman, M., Shoham, Y. (2024). Jamba: A Hybrid Transformer-Mamba Language Model. arXiv. https://arxiv.org/abs/2403.19887
https://doi.org/10.48550/arXiv.2403.19887
Jamba Architecture
Main Features
Jamba is a hybrid large language model that interleaves:
Transformer layers with self-attention (standard decoder blocks).
Mamba layers (a state-space model, SSM) introduced by Gu & Dao (2023).
Mixture-of-experts (MoE) modules integrated into some of the MLP layers.
By combining these three components, Jamba can balance efficiency, long-context capabilities, and model capacity without incurring a prohibitive computational or memory cost.
(1) Transformer Layers
Jamba uses standard decoder-only Transformer blocks, but crucially, they appear in a reduced proportion (e.g., 1 attention layer for every 7 Mamba layers).
The attention mechanism is still important for in-context learning and tasks that benefit from explicit token-to-token interactions.
(2) Mamba Layers
Mamba layers replace much of the attention with an SSM-based mechanism, scaling linearly with sequence length.
They significantly reduce key–value cache size for long contexts because each Mamba layer does not require storing extensive attention activations.
Unlike prior SSMs, Mamba is stabilized at large scale with carefully chosen RMSNorm inside the state-space modules.
The authors find no explicit positional encoding is required in the Mamba blocks—Mamba inherently captures positional information.
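To illustrate why an SSM-style layer avoids a key–value cache, here is a minimal discretized state-space recurrence; Mamba's selective (input-dependent) parameters and hardware-aware parallel scan are omitted, so this is a conceptual sketch only.
<pre>
import torch

def ssm_scan(x, A, B, C):
    # x: (T, d_in); A: (d_state, d_state); B: (d_state, d_in); C: (d_out, d_state)
    h = torch.zeros(A.shape[0])
    ys = []
    for t in range(x.shape[0]):        # one pass over the sequence, O(T)
        h = A @ h + B @ x[t]           # fixed-size hidden state summarizes the prefix
        ys.append(C @ h)               # no cache over past tokens is needed
    return torch.stack(ys)
</pre>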
(3) Mixture-of-Experts (MoE)
Jamba integrates MoE in some MLP layers to increase total capacity without increasing the active parameters used per token.
MoE involves having multiple “expert” sub-MLPs, with only the top K experts selected for each token.
This leads to a sparse model design: total parameters can be large (e.g., 50B+), but only ~12B parameters are “active” at any forward pass.
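A minimal sketch of the sparse top-2-of-16 routing described above; the router, the expert definitions, and the absence of load-balancing terms are simplified assumptions, not Jamba's implementation.
<pre>
import torch

def moe_forward(x, experts, router, k=2):
    # x: (tokens, d_model); experts: list of sub-MLPs mapping d_model -> d_model;
    # router: torch.nn.Linear(d_model, len(experts))
    gate_logits = router(x)                              # (tokens, n_experts)
    weights, idx = torch.topk(gate_logits, k, dim=-1)    # pick top-k experts per token
    weights = torch.softmax(weights, dim=-1)
    out = torch.zeros_like(x)
    for slot in range(k):
        for e, expert in enumerate(experts):
            mask = idx[:, slot] == e
            if mask.any():                               # only selected tokens visit expert e
                out[mask] += weights[mask, slot].unsqueeze(-1) * expert(x[mask])
    return out
</pre>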
Performance and Benefits of Jamba
(1) High throughput: Compared to a pure-Transformer model of similar size, Jamba achieves up to 3× higher inference throughput at very long context lengths, because Mamba's linear-time scan avoids the quadratic cost and large key–value cache of attention.
(2) Memory efficiency: Jamba's key–value cache can be 8× smaller than that of a similarly sized Transformer, which makes it possible to handle up to 256K tokens of context (or even more) on a single 80GB GPU.
(3) Competitive quality: On standard LM benchmarks (ARC, HellaSwag, WinoGrande, etc.), Jamba performs on par with or better than similarly sized Transformer or MoE-Transformer models. It also demonstrates strong capabilities in long-context tasks (e.g., "needle in a haystack" retrieval).
Key Design and Insights
(1) Hybrid Architecture
The mixed ratio of attention layers to Mamba layers (often 1:7) is crucial. Even a small fraction of attention layers confers strong in-context learning (format adherence, induction-like patterns), while the Mamba layers bring speed and memory savings.
Pure Mamba, though fast, sometimes struggles with emergent in-context learning behaviors (e.g., properly following few-shot prompts). The hybrid design preserves these Transformer-like capabilities.
(2) MoE Effectiveness
Using MoE on top of the hybrid model further improves perplexity and downstream performance, allowing the total parameter count to go up to 50B+ while keeping active parameter usage around ~12B.
Balancing the number of experts, top-K selection, and how frequently MoE is used (e.g., every other MLP layer) is key for controlling compute costs and memory.
(3) Training Stability and Design Choices
RMSNorm: Large-scale Mamba layers exhibit occasional large activation spikes. RMSNorm on internal activations stabilizes the training, preventing loss spikes.
No explicit positional encoding needed: Unlike typical Transformers (which use rotary, ALiBi, or other embeddings), the authors found that Mamba captures positional cues inherently. Adding RoPE gave no notable improvement.
Conclusion
(1) Uniqueness of Jamba: high efficiency and strong design
Jamba’s combination of attention, Mamba, and MoE layers yields excellent throughput and long-sequence modeling.
(2) Handling long context better
Jamba’s memory footprint for KV caching is drastically smaller. It can handle contexts of up to 256K tokens on a single 80GB GPU—significantly exceeding typical Transformer-based LLMs of similar size.
(3) Open-source release
The model is released under an Apache 2.0 license, encouraging research on this hybrid approach. Pretrained checkpoints and ablation runs will also be provided.
Future Directions
(1) Optimize MoE Further
Investigating more sophisticated MoE routing strategies, expert balance, or hierarchical gating to push quality and efficiency further.
(2) Hybrid Scaling in Even Larger Models
Extending beyond ~7B–12B active parameters to tens of billions or more, exploring how the attention–Mamba ratio and MoE design scale at even larger training runs.
Group 47 Presentation: Jamba: A Hybrid Transformer - Mamba Language Model
Presenter
Chentao Jin
Paper Citation
Lieber, O., Lenz, B., Bata, H., Cohen, G., Osin, J., Dalmedigos, I., Safahi, E., Meirom, S., Belinkov, Y., Shalev-Shwartz, S., Abend, O., Alon, R., Asida, T., Bergman, A., Glozman, R., Gokhman, M., Manevich, A., Ratner, N., Rozen, N., Shwartz, E., Zusman, M., Shoham, Y. (2024). Jamba: A Hybrid Transformer-Mamba Language Model. arXiv. https://arxiv.org/abs/2403.19887
https://doi.org/10.48550/arXiv.2403.19887
Clear explanations to aid understanding
Jamba uses a unique combination of transformer layers, Mamba-based memory layers (state-space layers), and Mixture of Experts (MoE) layers to improve its efficiency, scalability, and ability to process long sequences of text. This hybrid design optimizes both memory management and computational power while minimizing resource use. Here’s an overview of the components and their roles in the architecture:
Jamba's Architecture
Jamba’s architecture is built on a structured sequence of layers: each transformer layer is followed by seven Mamba (memory) layers, and a Mixture of Experts (MoE) module replaces every second MLP layer. This design creates an efficient, hybrid system that maximizes both memory handling and computational power while minimizing resource usage.
Each of these layer types plays a distinct role:
1. Transformer Layers: The transformer layers in Jamba are responsible for capturing the relationships between different tokens in the sequence, no matter how far apart they are. This is done using self-attention, which helps the model understand complex relationships and context within the text. However, traditional transformers can struggle with very long sequences because the self-attention mechanism becomes very memory and computation-heavy as the sequence length grows.
2. Memory Layers: To tackle this, Jamba introduces memory layers based on a different model called the state space model (SSM). These memory layers don’t rely on self-attention. Instead, they maintain a hidden state that keeps track of important information from earlier in the sequence. This makes memory layers far more efficient when it comes to handling long sequences, as they don’t need to store as much data in memory, unlike the transformer layers.
3. Mixture of Experts (MoE): The MoE component is where Jamba gets its flexibility. Instead of using all model parameters for each token, MoE selectively activates a small subset of "experts" for each token. An "expert" is a specialized set of parameters that focuses on solving a specific part of the problem. The model dynamically selects the experts that are most relevant to each token's context, allowing it to handle different parts of the problem more efficiently. Since only a few experts are activated per token, the model can scale its capacity to handle more complex tasks or longer sequences without significantly increasing computational costs.
Additionally, Jamba stabilizes its training process with RMSNorm, which normalizes activations and ensures that training remains stable, even when the model is scaled up to very large sizes.
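For reference, a minimal RMSNorm in its commonly used form (scale activations by their root-mean-square and a learned gain); this illustrates the normalization mentioned above but is not Jamba's exact placement or implementation.
<pre>
import torch

class RMSNorm(torch.nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # divide by the root-mean-square of the last dimension, then rescale
        rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).sqrt()
        return self.weight * x / rms
</pre>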
Performance and Benefits of Jamba
Jamba's hybrid approach provides several advantages:
- Efficient Handling of Long Contexts: Jamba is able to process long sequences of text effectively, overcoming the limitations of traditional transformer models.
- Balance Between Performance and Efficiency: Jamba achieves strong performance while reducing memory usage and computational costs thanks to its combination of transformer, memory, and MoE layers.
- Scalability: By using fewer active parameters than other models, Jamba is scalable and can handle tasks that require understanding large amounts of text without compromising efficiency.
Limitations of Jamba
This hybrid approach has limitations as well: training can become unstable at scale and relies on RMSNorm to keep activations in check; handling very long contexts still requires substantial GPU memory (around 80 GB); and the Mixture of Experts (MoE) routing still needs further optimization to improve performance.
Further Directions
- Optimize MoE further for even better efficiency
- Investigate how hybrid architectures scale in even larger contexts and models
Group 47 Presentation: Jamba: A Hybrid Transformer - Mamba Language Model
A New Approach to Long-Context Language Modeling
Jamba is a new language model that seeks to overcome some of the classic limitations of Transformers—namely, the high memory overhead of the key-value cache and the computational inefficiency when processing long sequences. The paper introduces a hybrid architecture that blends traditional Transformer layers with a newer type of state-space layer called Mamba. This fusion is further enhanced by incorporating Mixture-of-Experts (MoE) modules, creating a model that is both memory efficient and highly performant on long-context tasks.
Key Architectural Innovations
1. Hybrid Transformer-Mamba Design:
Transformers vs. Mamba: Traditional Transformers excel at in-context learning and have become the de facto architecture for language models. However, their self-attention mechanism results in quadratic memory usage relative to the context length, making them less efficient for very long texts.
Mamba Layers: These layers, based on state-space models, offer a more efficient alternative by reducing memory requirements and improving throughput, especially when the input sequence becomes very long.
Interleaving Strategy: Jamba interleaves a small number of attention layers with a larger number of Mamba layers. In one specific configuration, a ratio of 1 attention layer to 7 Mamba layers was found to be both effective and compute-efficient. This interleaving allows the model to benefit from the strengths of both components.
2. Mixture-of-Experts (MoE):
Expanding Capacity Without Extra Cost: MoE modules are used to boost the total parameter count (and thus the model’s capacity) without a proportional increase in compute costs. This is achieved by activating only a small subset of experts (e.g., top-2 out of 16) for any given token.
Flexibility: The MoE integration is flexible and can be adjusted (by varying the frequency of MoE layers or the number of experts) to trade off between memory usage, throughput, and performance.
3. Resource Efficiency:
KV Cache Savings: One of the standout features of Jamba is its dramatic reduction in key-value cache memory requirements. For instance, while some comparable models might require tens of gigabytes to store the cache for long contexts, Jamba can process up to 256K tokens with only a few gigabytes.
Single-GPU Feasibility: Despite having a total available parameter count of 52B (with only 12B active at any time), the model is engineered to fit on a single 80GB GPU, which is an impressive engineering feat.
Constructive Critiques and Discussion
Ablation Insights:
The paper includes a thorough ablation study showing that neither pure Transformer nor pure Mamba models perform as robustly across all tasks as the hybrid does. However, it raises questions about the precise roles of each component—especially how much reliance there is on the Transformer layers for in-context learning. Future work could explore whether even fewer attention layers might suffice or if alternative mechanisms might further enhance the balance.
Scaling and Adaptability:
While the current design is optimized to fit on a single high-end GPU, it remains to be seen how the architecture scales when further pushed in size or applied to more diverse downstream tasks. Additionally, the robustness of the MoE routing (i.e., ensuring the right experts are chosen consistently) could benefit from further investigation and refinement.
Positional Encoding:
An interesting observation is that Jamba performs similarly with and without explicit positional embeddings. The paper suggests that the Mamba layers might be implicitly capturing positional information. This finding challenges the conventional wisdom in Transformer-based architectures and could inspire further research into whether explicit positional mechanisms are always necessary.
Enhancing Understanding: Why Does This Matter?
At its core, Jamba represents a promising step toward models that can handle extremely long texts without the usual memory and computational burdens. This is crucial for real-world applications—such as document summarization, legal analysis, or even processing large codebases—where context length can be a major bottleneck. By cleverly combining different architectural paradigms, Jamba offers a new tool in the quest for more scalable and efficient language models.
Moreover, the integration of MoE allows researchers and practitioners to scale model capacity in a cost-effective manner. The hybrid design not only improves throughput but also opens the door to further explorations in architectural combinations that might harness the best of both worlds: the rich representational power of Transformers and the efficiency of state-space models.