Research papers improving performance of LLMs [2/3]

Research papers published between January 16 and February 15, 2025, proposing context-length and architectural changes to LLMs


What’s in it today?

  • ReLearn makes LLMs forget unwanted knowledge and remember how to speak good

  • Coupled Adam fixes Adam so language model embeddings aren't too "extra"

  • TransMLA converts GQA models to MLA ones for better LLM expression, because apparently size does matter

  • LASP-2 makes linear attention training zoom by decluttering communication

Don’t have time to read the entire newsletter? Well, listen to this fun and engaging podcast covering these research papers in detail.

ReLearn: Unlearning via Learning for Large Language Models

Why is this research important?

LLMs are trained on ever more data, and their use is growing rapidly. Providers rely on large-scale training datasets that often contain unauthorized private and copyrighted information, and recent legal actions, such as The New York Times lawsuit against OpenAI, underscore the urgency of the issue. Current unlearning methods mostly rely on reverse optimization: they lower the probabilities of the target tokens, but this degrades the model's fundamental language-generation ability and produces repetitive or incoherent output. The paper also argues that current evaluation metrics are inadequate, because they focus on contextual forgetting and fail to capture broader losses in fluency and relevance. This research therefore matters for building unlearning techniques that comply with privacy and copyright regulations without compromising the model's linguistic coherence.

Approach:

The paper proposes an unlearning pipeline based on data augmentation and positive optimization. Instead of suppressing token probabilities (reverse optimization), ReLearn overwrites sensitive information with new, authorized knowledge by training the model on augmented data. This preserves the model's linguistic ability while it forgets the target knowledge, much like the way humans update their memories. To make the pipeline easier to follow, I divide it into the following steps:

  1. Unlearning data synthesis: The pipeline synthesizes non-sensitive training data by augmenting the forget set (the data to be unlearned) with diverse variations, so that the knowledge to be forgotten is covered comprehensively. The whole process is carried out by an LLM using specific prompts and has two parts. For question augmentation, the method iterates over each question-answer pair in the forget set and synthesizes four types of question variations: (1) Simple Variants, to prevent overfitting to specific phrasings; (2) Contextual Variants, to ensure forgetting across contexts; (3) Noise Variants, to improve robustness to noisy inputs; and (4) Logical Variants, which alter the logic of the questions to cover different forms of the knowledge. For answer augmentation, it iterates over each augmented question and synthesizes new pairs with relevant but deliberately vague answers. These answers must be (1) Unlearned, containing none of the original sensitive content; (2) Relevant, aligning with the question context; and (3) No-risk, introducing no new sensitive content.

  2. Content verification: To ensure the augmented data is safe, the method verifies every synthesized answer. An LLM performs Chain-of-Thought (CoT) analysis on each augmented answer and evaluates it against predefined safety criteria. If verification fails, indicating a potential risk in the augmented data, the process returns to the answer augmentation step.

  3. Data diversification: To prevent overfitting to the QA format and catastrophic forgetting, the method uses two strategies: sentence completion and a generic dataset. For sentence completion, each answer is split into a prompt and a completion label, and these pairs are added to the augmented data. For example, "Isabella Marquez can be reached through conventional electronic communication channels." is split into the prompt "Isabella Marquez can be reached through" and the label "conventional electronic communication channels." For the generic dataset, ReLearn mixes in questions randomly sampled from the WikiQA and Chatbot Instruction datasets.

  4. Unlearning via learning: The unlearning objective is formulated over three datasets: the augmented forget set, the sentence completion set, and the generic dataset. The vanilla model is fine-tuned on a combination of these datasets, using a loss that combines a generative loss, a masked language model loss, and a knowledge loss.
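The data-diversification step (step 3) is the most mechanical part of the pipeline. Below is a minimal Python sketch of it. It assumes answers are split at a word boundary chosen by a hypothetical `ratio` parameter, and that generic QA pairs are mixed in at a hypothetical `generic_fraction`. The paper's actual split rule and mixing proportions are not reproduced here, and the LLM-based augmentation and verification steps are not modeled.

```python
import random


def split_for_completion(answer: str, ratio: float = 0.5) -> tuple[str, str]:
    """Split an answer at a word boundary into (prompt prefix, completion label).

    `ratio` is a hypothetical knob: the fraction of words kept in the prefix.
    """
    words = answer.split()
    if len(words) < 2:
        return answer, ""  # too short to split into two non-empty parts
    cut = max(1, min(len(words) - 1, int(len(words) * ratio)))
    return " ".join(words[:cut]), " ".join(words[cut:])


def build_training_mix(augmented_qa, generic_qa, generic_fraction=0.3, seed=0):
    """Combine augmented forget-set QA pairs, sentence-completion pairs derived
    from their answers, and a random sample of generic QA pairs (hypothetical mix)."""
    rng = random.Random(seed)
    completion_pairs = [split_for_completion(answer) for _, answer in augmented_qa]
    completion_pairs = [(p, label) for p, label in completion_pairs if label]
    n_generic = min(
        len(generic_qa),
        int(generic_fraction * (len(augmented_qa) + len(completion_pairs))),
    )
    mix = list(augmented_qa) + completion_pairs + rng.sample(generic_qa, n_generic)
    rng.shuffle(mix)
    return mix


prompt, label = split_for_completion(
    "Isabella Marquez can be reached through conventional electronic communication channels.",
    ratio=0.6,
)
print(prompt)  # Isabella Marquez can be reached through
print(label)   # conventional electronic communication channels.
```

With `ratio=0.6`, the sketch reproduces the split from the example above. In practice, the ratio could be varied per answer to diversify the completion prompts.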

Results:

The paper identifies limitations in existing unlearning metrics and proposes three new ones: Knowledge Forgetting Rate (KFR), Knowledge Retention Rate (KRR), and Linguistic Score (LS). KFR measures how much of the target knowledge is forgotten, and KRR measures how much of the knowledge that should be kept is retained. LS evaluates the linguistic quality of the unlearned model, capturing degradation patterns such as reduced vocabulary diversity, simplified syntax, and diminished lexical richness. ReLearn outperforms state-of-the-art methods on many benchmarks; one of them is shown above.
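As a rough illustration of what the first two metrics measure, here is a minimal Python sketch that treats KFR and KRR as simple rates over per-example judgments. How an answer is judged "forgotten" or "retained" is left to an external judge. That judge is an assumption here; the paper defines its own criteria. On the linguistic side, the sketch uses a type-token ratio as a deliberately crude stand-in for vocabulary diversity. It is not the paper's LS.

```python
def knowledge_forgetting_rate(forgotten: list[bool]) -> float:
    """KFR-style rate: share of forget-set questions judged as forgotten."""
    return sum(forgotten) / len(forgotten)


def knowledge_retention_rate(retained: list[bool]) -> float:
    """KRR-style rate: share of retain-set questions still answered correctly."""
    return sum(retained) / len(retained)


def type_token_ratio(text: str) -> float:
    """Distinct words / total words: a crude vocabulary-diversity proxy, not the paper's LS."""
    tokens = text.lower().split()
    return len(set(tokens)) / len(tokens) if tokens else 0.0


# Hypothetical judgments for four forget-set and four retain-set questions.
print(knowledge_forgetting_rate([True, True, False, True]))  # 0.75
print(knowledge_retention_rate([True, True, True, False]))   # 0.75
# Repetitive output (a typical failure of reverse optimization) scores lower.
print(type_token_ratio("the the the the"), type_token_ratio("the model answers politely"))
```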

Sponsored: Special offer for you: claim $50 in free credits!

Accelerate your AI projects with Prolific. Claim $50 free credits and get quality human data in minutes from 200,000+ taskers. No setup cost, no subscription, no delay—get started, top up your account to claim your free credit, and test Prolific for yourself now. Use code: LLM-RESEARCH-50

Better Embeddings with Coupled Adam

We implemented this research paper. Please support our work by sharing it and giving us feedback, so we can make our hard work truly helpful to you!

Team LLMs Research

This paper tackles the problem of anisotropic embeddings. But what is that? Well, word representations learned by LLMs tend to cluster in a small subspace, away from the origin of the vector space. This clustering limits the semantic expressiveness of the embeddings and, consequently, the model's overall performance. The paper notes that while various attempts have been made to explain and alleviate this issue, the role of the optimization algorithm itself has been largely overlooked. It argues that the second-moment estimate in Adam, which adapts the learning rate for each parameter, causes embedding vectors to shift collectively away from the origin.

To understand this, let’s first understand how Adam works. Adam maintains an exponentially decaying average of past gradients (first moment) and squared gradients (second moment) for each parameter. The second moment is meant to scale the learning rate adaptively, giving larger updates to parameters with smaller gradients and smaller updates to parameters with larger gradients. This is particularly useful for sparse data, like word frequencies in LLM training, where some words occur far more often than others.
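For reference, a single Adam step for a vector of parameters can be sketched in NumPy as below. It uses the common default hyperparameters (lr = 1e-3, β1 = 0.9, β2 = 0.999, ε = 1e-8), which are assumptions rather than the paper's settings. The usage lines show the adaptive scaling at work. Two parameters whose gradients differ by a factor of 1000 receive steps of almost the same size on the first update, so the small-gradient parameter effectively gets a much larger learning rate.

```python
import numpy as np


def adam_step(param, grad, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One standard Adam update; m and v hold per-parameter moment estimates, t >= 1."""
    m = beta1 * m + (1 - beta1) * grad          # first moment: EMA of gradients
    v = beta2 * v + (1 - beta2) * grad ** 2     # second moment: EMA of squared gradients
    m_hat = m / (1 - beta1 ** t)                # bias correction for the zero initialization
    v_hat = v / (1 - beta2 ** t)
    # Element-wise effective learning rate: lr / (sqrt(v_hat) + eps).
    param = param - lr * m_hat / (np.sqrt(v_hat) + eps)
    return param, m, v


grad = np.array([0.01, 10.0])                   # one small and one large gradient
param, m, v = adam_step(np.zeros(2), grad, np.zeros(2), np.zeros(2), t=1)
print(param)                                    # both entries are close to -1e-3
```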

However, the paper shows that this adaptive scaling, specifically the $i$-dependence of the second moment (it differs from one embedding vector $e_i$ to the next), leads to the anisotropy problem. The authors show mathematically that the gradients of all embedding vectors sum to zero. With SGD, the updates are proportional to the gradients, so they also sum to zero. With Adam, however, each gradient is weighted by its own adaptive learning rate, and this weighted sum does not vanish. The non-vanishing sum causes a collective shift of the embedding vectors away from the origin. The authors also provide experimental evidence that the expected value of the second moment is proportional to the unigram probability of the corresponding word, confirming the link between word frequency and the anisotropy.
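The argument can be checked numerically in a toy setting. The sketch below assumes output (unembedding) vectors trained with softmax cross-entropy. In that setting, the gradient for row $i$ is $(p_i - y_i)\,h$, so the gradients sum to zero over the vocabulary. Token targets follow a Zipf-like distribution, and the vocabulary size, dimension, and Adam hyperparameters are made up. The embeddings are held fixed so that only the accumulated updates are compared. The summed SGD updates stay at round-off level, while the summed Adam updates do not.

```python
import numpy as np

rng = np.random.default_rng(0)
V, d = 50, 8                                   # toy vocabulary size and embedding dim
E = rng.normal(size=(V, d))                    # output embedding matrix, one row per token
freq = 1.0 / np.arange(1, V + 1)
freq /= freq.sum()                             # Zipf-like unigram distribution


def embedding_grad(E, h, target):
    """Cross-entropy gradient w.r.t. every embedding row: (p_i - y_i) * h."""
    logits = E @ h
    p = np.exp(logits - logits.max())
    p /= p.sum()
    p[target] -= 1.0                           # p - one_hot(target); entries sum to ~0
    return np.outer(p, h)


lr, b1, b2, eps = 1e-3, 0.9, 0.999, 1e-8
m, v = np.zeros_like(E), np.zeros_like(E)
sgd_sum, adam_sum = np.zeros(d), np.zeros(d)
for t in range(1, 201):
    h = rng.normal(size=d)                     # stand-in for the final hidden state
    g = embedding_grad(E, h, rng.choice(V, p=freq))
    sgd_sum += (-lr * g).sum(axis=0)           # SGD: update proportional to gradient
    m = b1 * m + (1 - b1) * g
    v = b2 * v + (1 - b2) * g ** 2
    update = -lr * (m / (1 - b1 ** t)) / (np.sqrt(v / (1 - b2 ** t)) + eps)
    adam_sum += update.sum(axis=0)             # Adam: each row rescaled by its own v_i

print(np.linalg.norm(sgd_sum), np.linalg.norm(adam_sum))
```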

Proposed approach

The paper proposes a modified version of Adam called Coupled Adam. The main idea is to force the second moments to be the same for all embedding vectors. To do this, the authors replace each embedding vector's individual second-moment estimate with the average of the second moments over all embedding vectors.

Mathematically, the original Adam update rule uses an $i$-dependent effective learning rate $\eta_i$ that depends on the second-moment estimate of the $i$-th embedding vector. Coupled Adam replaces this $i$-dependent rate with an $i$-independent rate based on the average second moment.
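The two update rules can be written in standard Adam notation (an assumption; the paper's symbols may differ). Let $\hat m_i$ and $\hat\nu_i$ be the bias-corrected first and second moments of embedding vector $e_i$, $V$ the vocabulary size, $\eta$ the base learning rate, and $\epsilon$ a small constant, with all operations on vectors taken element-wise. Then:

$$
\text{Adam:}\quad \Delta e_i = -\eta_i\,\hat m_i, \qquad \eta_i = \frac{\eta}{\sqrt{\hat\nu_i} + \epsilon}
$$

$$
\text{Coupled Adam:}\quad \bar\nu = \frac{1}{V}\sum_{j=1}^{V}\hat\nu_j, \qquad \bar\eta = \frac{\eta}{\sqrt{\bar\nu} + \epsilon}, \qquad \Delta e_i = -\bar\eta\,\hat m_i
$$

Because $\bar\eta$ no longer depends on $i$, $\sum_i \Delta e_i = -\bar\eta \sum_i \hat m_i$. This vanishes whenever the embedding gradients, and therefore their moving averages $\hat m_i$, sum to zero.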

This coupling of the second moments ensures that the sum of embedding updates vanishes, similar to SGD, preventing the collective shift of embeddings away from the origin. At the same time, Coupled Adam retains the benefits of Adam, using a second moment to normalize the embedding update vectors (albeit a global one).

Implementation

The implementation of Coupled Adam is straightforward. The paper provides pseudocode in its Algorithm 1, shown above. The key modification, in lines 8-12 of Algorithm 1, is that the second moments are averaged across all vocabulary items before the update vector is calculated:

  1. Compute the standard Adam first- and second-moment estimates for all embedding vectors $e_i$.

  2. Compute the average second moment (ν) over all embedding vectors.

  3. Replace the individual second moment estimates for each embedding vector with this average value (ν).

  4. Apply the standard Adam update rule using this shared second moment.
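The four steps translate into a short NumPy function. This is a minimal sketch, not the authors' code. It assumes the embedding matrix has shape (vocabulary, dimension), averages the second moment over the vocabulary axis separately for each dimension, and uses common Adam defaults as hypothetical hyperparameters. The usage lines check the property discussed earlier: when the gradient rows sum to zero, the updates of all embedding vectors also sum to zero, so there is no collective shift.

```python
import numpy as np


def coupled_adam_step(E, grad, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Coupled Adam step for an embedding matrix E of shape (vocab, dim); t >= 1."""
    # Step 1: standard Adam moment estimates for every embedding vector e_i.
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad ** 2
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    # Steps 2-3: average the second moment over all embedding vectors and share it.
    v_shared = v_hat.mean(axis=0, keepdims=True)   # shape (1, dim), broadcast over rows
    # Step 4: the usual Adam update, now with an i-independent denominator.
    E = E - lr * m_hat / (np.sqrt(v_shared) + eps)
    return E, m, v


rng = np.random.default_rng(1)
E0 = rng.normal(size=(5, 3))
g = rng.normal(size=(5, 3))
g -= g.mean(axis=0)                                # make gradient rows sum to zero
E1, m1, v1 = coupled_adam_step(E0, g, np.zeros_like(E0), np.zeros_like(E0), t=1)
print((E1 - E0).sum(axis=0))                       # approximately [0, 0, 0]
```

Keeping the per-row `v` in the optimizer state and averaging it only at update time gives the same result as storing the average. Both the exponential moving average and the mean over rows are linear, so the order does not matter.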

The paper emphasizes that Coupled Adam can be integrated into existing training pipelines with minimal code changes. It is applied only to the embedding parameters, while standard Adam is used for all non-embedding parameters.

Setup and Results:

The paper presents multiple experiments to evaluate Coupled Adam, divided into small-scale and large-scale settings that involve different datasets, model architectures, and training frameworks.

  • Small-Scale Experiments: The small-scale experiments utilized the OpenWebText Corpus and the GPT-2 tokenizer. The model architecture also followed GPT-2, with hyperparameter settings derived from GPT-3. Model sizes ranged from 125M to 760M parameters, and dataset sizes ranged from 5B to 20B tokens. Each experiment was repeated three times with different random seeds to assess statistical significance.

  • Large-Scale Experiments: The large-scale experiments employed the SlimPajama dataset and the GPT-2 tokenizer. The model architecture was a state-of-the-art dense transformer similar to those used in recent LLMs, including RoPE embeddings and SwiGLU activation functions. Model sizes were 1.3B and 2.6B parameters.

Across all experiments, the authors trained two models: one using standard Adam and one using Coupled Adam for the embedding parameters. They then evaluated the models using various metrics to assess both general performance and embedding quality. Here are the findings:

  • Coupled Adam yielded lower perplexity at both small and large scale. For example, the authors report a perplexity of 14.69 with standard Adam versus 14.45 with Coupled Adam.

  • Coupled Adam outperformed standard Adam across different training-data sizes. For example, a model trained on 10B tokens reached a perplexity of 15.51 with standard Adam, while the same setup with Coupled Adam reached 15.37.

  • Coupled Adam produced a significantly lower anisotropy score (lower is better), indicating that the method improves the quality of the embeddings. For example, the anisotropy score was 0.874 with standard Adam compared to 0.817 with Coupled Adam.
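One common way to quantify anisotropy, used here purely as an assumption and not necessarily the paper's metric, is the average cosine similarity between embedding vectors. Vectors spread around the origin score near 0, while vectors that share a large common offset score near 1. The values produced below are therefore not comparable to the 0.874 and 0.817 reported in the paper.

```python
import numpy as np


def mean_cosine_similarity(E: np.ndarray) -> float:
    """Average cosine similarity over all distinct pairs of rows of E (vocab, dim)."""
    X = E / np.linalg.norm(E, axis=1, keepdims=True)      # unit-normalize each embedding
    S = X @ X.T                                           # pairwise cosine similarities
    n = len(E)
    return float((S.sum() - np.trace(S)) / (n * (n - 1)))  # drop the diagonal (self-pairs)


rng = np.random.default_rng(0)
E = rng.normal(size=(1000, 16))                           # roughly isotropic, centred at the origin
print(mean_cosine_similarity(E))                          # close to 0
print(mean_cosine_similarity(E + 5.0))                    # common offset: close to 1
```

The offset case mimics the collective shift away from the origin described above. Every vector picks up the same component, so all pairs look similar regardless of their individual directions.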

TransMLA: Multi-head Latent Attention Is All You Need

As models and sequence lengths grow, the KV cache, which stores the keys and values of previous tokens, becomes a significant bottleneck that consumes substantial memory and bandwidth. Methods like Grouped-Query Attention (GQA) have been adopted to reduce the KV cache size. The authors, however, propose that Multi-head Latent Attention (MLA), used in DeepSeek models, offers a more theoretically sound and practically effective solution. The key problem this paper tackles is how to efficiently convert widely used GQA-based models into MLA-based models, which are more expressive for the same KV cache overhead.
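To see why the KV cache becomes a bottleneck, its size can be estimated as 2 (keys and values) × layers × KV heads × head dimension × sequence length × batch size × bytes per element. The configuration below is hypothetical: 32 layers, 32 query heads of dimension 128, a 32K-token context, and 16-bit values. It is only meant to show how GQA shrinks the cache by sharing each KV head across a group of query heads.

```python
def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len, batch=1, bytes_per_elem=2):
    """Bytes needed to cache keys and values (2 tensors) for every layer and token."""
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * batch * bytes_per_elem


GiB = 2 ** 30
mha = kv_cache_bytes(n_layers=32, n_kv_heads=32, head_dim=128, seq_len=32_768)
gqa = kv_cache_bytes(n_layers=32, n_kv_heads=8, head_dim=128, seq_len=32_768)  # 4 query heads per KV head
print(mha / GiB, gqa / GiB)  # 16.0 4.0
```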

Core argument:

The main argument is that MLA provides greater expressive power than GQA for the same KV cache overhead, and the paper gives a theoretical proof of this claim. It then introduces TransMLA, a post-training method that converts widely used GQA-based pre-trained models (such as LLaMA, Qwen, and Mixtral) into equivalent MLA-based models. The goal is to let existing models benefit from MLA's greater expressiveness with minimal changes to the architecture and without increasing the KV cache size.

Methodology

We divide the proposed method into three main parts:

  1. Theoretical Proof of GQA to MLA Conversion: The paper proves that any GQA configuration can be transformed into an equivalent MLA configuration with the same KV cache size. The proof shows that the key transformation in GQA can be written as a low-rank factorization, which is the core idea behind MLA. It replicates the keys in GQA (so that all heads within a group are identical) and moves that replication to the parameter side (replicating weights instead of activations). It then shows that the replicated weight matrix is low-rank and can therefore be factorized in the same way as in MLA. The factorization uses the singular value decomposition (SVD) $W_K = U_K S_K V_K^\top$, keeping only the top-$r$ singular values to obtain a compressed latent for the KV cache. Since replication does not increase the rank, choosing $r$ equal to the GQA key dimension makes the factorization exact.

  2. MLA is Not Representable in GQA: The paper also proves that the reverse does not hold: MLA cannot always be represented by GQA. The authors consider a case where the vectors in the MLA transformation matrix are orthogonal. This produces diversity across head outputs that GQA cannot replicate, because its grouped heads are necessarily identical within each group. This asymmetry is the theoretical justification for expecting performance improvements when converting GQA models to MLA.

  3. Practical Conversion and Post-Training: The core of the TransMLA method lies in converting the weights of a pre-trained GQA model to equivalent MLA weights. This involves several steps:

    • Weight Decomposition: The authors take the existing key projection matrix of the GQA model and decompose it into two smaller matrices, corresponding to the $W^a_K$ and $W^b_K$ matrices in the MLA formulation. The decomposition uses SVD or other low-rank factorization techniques.

    • Weight Recombination: The new $W^a_K$ and $W^b_K$ matrices are then integrated into the model as the new key projection weights, effectively replacing the GQA attention mechanism with MLA.

    • Post-Conversion Training: After the conversion, the model is fine-tuned on a dataset to allow it to adapt to the new MLA architecture and realize the benefits of its enhanced expressiveness. This fine-tuning step is crucial for recovering any potential performance loss from the weight conversion and for fully exploiting the potential of MLA. This training is done without increasing the KV cache size.
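The conversion of the key projection can be sketched in NumPy with toy sizes. This is a minimal sketch under stated assumptions, not the TransMLA code. The sizes are hypothetical, only the key projection is handled, and positional encodings such as RoPE are ignored. The sketch replicates each GQA key head's weights for every query head in its group and factorizes the replicated matrix with SVD. It keeps the top $r$ singular values, with $r$ equal to the original GQA key width, so the cached latent `x @ W_a` is exactly as wide as the GQA key cache.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, n_heads, n_kv_heads, d_head = 64, 8, 2, 8     # toy sizes (hypothetical)
group = n_heads // n_kv_heads                          # query heads sharing one KV head

W_k = rng.normal(size=(d_model, n_kv_heads * d_head))  # GQA key projection

# Move the replication to the weight side: copy each KV head's columns once per
# query head in its group, giving an equivalent per-query-head key projection.
W_k_rep = np.concatenate(
    [np.tile(W_k[:, j * d_head:(j + 1) * d_head], (1, group)) for j in range(n_kv_heads)],
    axis=1,
)                                                      # shape (d_model, n_heads * d_head)

# Replication does not raise the rank, so a rank-r SVD with r = GQA key width is exact.
r = n_kv_heads * d_head
U, S, Vt = np.linalg.svd(W_k_rep, full_matrices=False)
W_a = U[:, :r] * S[:r]   # (d_model, r): down-projection; x @ W_a is what gets cached
W_b = Vt[:r]             # (r, n_heads * d_head): up-projection to per-head keys

x = rng.normal(size=(5, d_model))                      # five toy token representations
keys_mla = (x @ W_a) @ W_b
print(np.allclose(keys_mla, x @ W_k_rep))              # True: same keys as the GQA model
```

Right after conversion, `W_b` still produces identical keys within each group. The post-conversion training step is what lets the per-head up-projections diverge and use the extra expressiveness.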

LASP-2: Rethinking Sequence Parallelism for Linear Attention and Its Hybrid

The computation of standard attention grows quadratically with sequence length, which is what linear attention avoids. This paper improves the efficiency of linear attention through a combination of theoretical insights and practical implementation. It begins by analyzing the limitations of existing linear attention methods, which typically rely on fixed parallelism strategies that do not fully exploit modern hardware. It then proposes a new approach that allows dynamic sequence parallelism, enabling better use of computational resources.

The paper introduces a hybrid architecture that combines local and global attention mechanisms. Local attention operates within a limited context window, allowing efficient computation over smaller segments of the input sequence, while global attention captures long-range dependencies across the entire sequence. Integrating the two balances computational efficiency and expressive power.

The paper details how LASP-2 computes attention in two stages. In the first stage, local attention is computed in parallel across segments of the input sequence, which processes individual segments quickly with low memory overhead. In the second stage, global attention aggregates information from these segments so that long-range dependencies are preserved. The paper also introduces a mechanism for adaptive segment sizing: LASP-2 adjusts the size of local segments to the complexity of the input data, further optimizing performance.
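For linear attention, the segment-then-aggregate pattern can be made concrete. The sketch below is simulated on one process and uses unnormalized linear attention without a feature map. It is not LASP-2's actual implementation or communication schedule. Each chunk stands in for one device and computes two things locally: its intra-chunk causal output and a d × d memory state $K_c^\top V_c$. The only cross-chunk information needed is the running sum of earlier chunks' states, and its size does not depend on sequence length. The usage line checks that the chunked result matches the sequential computation.

```python
import numpy as np


def causal_linear_attention(Q, K, V):
    """Reference: o_t = q_t @ sum_{s<=t} k_s^T v_s (unnormalized, no feature map)."""
    out = np.zeros((Q.shape[0], V.shape[1]))
    state = np.zeros((K.shape[1], V.shape[1]))
    for t in range(Q.shape[0]):
        state += np.outer(K[t], V[t])
        out[t] = Q[t] @ state
    return out


def chunked_linear_attention(Q, K, V, n_chunks):
    """Split the sequence into chunks (e.g. one per device) and combine them.

    Local stage: each chunk computes its d x d state K_c^T V_c and its intra-chunk
    causal output. Global stage: states are prefix-summed, so each chunk adds the
    contribution of all earlier chunks.
    """
    chunks = np.array_split(np.arange(Q.shape[0]), n_chunks)
    states = [K[idx].T @ V[idx] for idx in chunks]     # the only data chunks must share
    out = np.zeros((Q.shape[0], V.shape[1]))
    prefix = np.zeros((K.shape[1], V.shape[1]))
    for idx, state in zip(chunks, states):
        q, k, v = Q[idx], K[idx], V[idx]
        mask = np.tril(np.ones((len(idx), len(idx))))
        intra = ((q @ k.T) * mask) @ v                 # causal attention inside the chunk
        out[idx] = intra + q @ prefix                  # plus all earlier chunks
        prefix = prefix + state
    return out


rng = np.random.default_rng(0)
Q, K, V = (rng.normal(size=(12, 4)) for _ in range(3))
same = np.allclose(chunked_linear_attention(Q, K, V, n_chunks=3), causal_linear_attention(Q, K, V))
print(same)  # True
```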

Result

The paper reports up to a 30% reduction in training time while maintaining comparable accuracy on large datasets. LASP-2 kept stable performance, with no significant degradation in processing speed or accuracy, for sequences of up to 10,000 tokens.
