MLP Repurposing

Summary: A technique that treats existing MLP blocks in neural networks as adaptable memory components rather than adding new architectural elements. This allows the final projection matrices of MLPs to serve as "fast weights" that can be dynamically updated during inference to store contextual information.

Overview

MLP Repurposing uses the multi-layer perceptron (MLP) blocks already present in transformer architectures as adaptable memory. Instead of requiring additional parameters or architectural modifications, it treats the final projection matrix of each MLP block as Fast Weights that are updated in place during inference.
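
As a rough illustration, the sketch below models one MLP block in plain Python with a frozen input projection (`w_in`) and a writable output projection (`w_o`). The class and method names, the ReLU activation, and the update rule (a single gradient step on the squared error between the block output and a target vector) are assumptions chosen for clarity. The source does not specify its targets or step sizes in this form.

```python
def matvec(m, v):
    """Multiply matrix m (a list of rows) by vector v."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]


class FastWeightMLP:
    """Toy MLP block (hypothetical): frozen input projection, writable output projection W_o."""

    def __init__(self, w_in, w_o):
        self.w_in = w_in  # slow weights: never modified at inference time
        self.w_o = w_o    # fast weights: modified in place by update()

    def hidden(self, x):
        return [max(0.0, z) for z in matvec(self.w_in, x)]  # ReLU activation (assumed)

    def forward(self, x):
        return matvec(self.w_o, self.hidden(x))

    def update(self, x, target, lr):
        """One gradient step on 0.5 * ||W_o h - target||^2, applied to W_o only (assumed rule)."""
        h = self.hidden(x)
        err = [y - t for y, t in zip(matvec(self.w_o, h), target)]
        for i, e in enumerate(err):
            for j, hj in enumerate(h):
                self.w_o[i][j] -= lr * e * hj


def sq_error(y, t):
    return sum((a - b) ** 2 for a, b in zip(y, t))


mlp = FastWeightMLP(w_in=[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
                    w_o=[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
x, target = [1.0, 2.0], [1.0, -1.0]
before = sq_error(mlp.forward(x), target)
mlp.update(x, target, lr=0.05)
after = sq_error(mlp.forward(x), target)
print(before, after)  # error drops from 2.0 to roughly 0.18
```

Only `w_o` changes during the update; `w_in` is left as it was, mirroring the idea that the rest of the pre-trained block stays untouched. For this input the hidden vector is [1, 2, 3] with squared norm 14, so one step with lr=0.05 moves the output 70% of the way toward the target.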

The technique enables Test-Time Training by letting the model adapt some of its parameters to new information encountered during inference. This is particularly valuable for Long-Context Modeling, where standard approaches struggle with sequences longer than their context window. The repurposed MLP blocks act as a writable memory: as the model processes the input in sequential chunks, weight updates write contextual information into them, and the ordinary forward pass reads it back.

Unlike memory-augmented approaches that require adding new components, MLP Repurposing works as a drop-in enhancement to pre-trained models. The existing MLP architecture remains unchanged, but its final projection layer gains the ability to store and retrieve contextual information through parameter updates aligned with Next-Token Prediction objectives.

Key Details

  • Architecture Agnostic: Works with existing transformer models from 500M to 14B parameters without requiring retraining
  • Computational Efficiency: Uses Chunk-wise Updates instead of per-token updates, maintaining compatibility with Context Parallelism
  • Memory Mechanism: The final projection matrix W_o in MLP blocks serves as the adaptable memory component
  • Update Objective: Uses targets aligned with the language-modeling (Next-Token Prediction) objective, which provably increase the logit of the correct token while leaving the logits of incorrect tokens unchanged
  • Performance Gains: Achieves superior performance on contexts of up to 128k tokens, with 4B-parameter models showing consistent improvements
  • Minimal Overhead: Adds negligible computational cost compared to standard inference
  • Scalability: Tested across multiple model families with consistent gains at various scales
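
The efficiency argument for chunk-wise updates comes down to counting sequential steps. In the hypothetical helper functions below, a sequence of `seq_len` tokens is split into chunks of `chunk_size` tokens. Every token in a chunk reads the same fast weights, so a chunk can be processed in parallel, and only the updates between chunks must run one after another. The chunk sizes are illustrative rather than values from the source, and 128k is taken here as 131,072 tokens.

```python
import math


def chunk_bounds(seq_len, chunk_size):
    """Return (start, end) index pairs that partition [0, seq_len) into chunks."""
    return [(s, min(s + chunk_size, seq_len)) for s in range(0, seq_len, chunk_size)]


def sequential_updates(seq_len, chunk_size):
    """Number of dependent fast-weight updates; chunk_size=1 is the per-token case."""
    return math.ceil(seq_len / chunk_size)


seq_len = 131072  # 128k tokens
for chunk_size in (1, 1024, 4096):
    print(chunk_size, sequential_updates(seq_len, chunk_size))
# Per-token updates (chunk_size=1) need 131072 sequential steps;
# chunks of 1024 or 4096 tokens need 128 or 32.
```

A final partial chunk is handled by `chunk_bounds`, for example a 10-token sequence with `chunk_size=4` yields chunks (0, 4), (4, 8) and (8, 10).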

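The approach processes a sequence in chunks. For each chunk, the model first applies the current fast weights to generate predictions, then updates those weights based on the chunk's observed tokens. Because the fast weights used to predict a chunk contain information only from earlier chunks, the update does not leak a chunk's own tokens into its predictions. This apply-then-update cycle lets the model accumulate contextual knowledge that persists across the entire sequence.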
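
The apply-then-update cycle can be sketched as follows. Here `w` stands in for a fast-weight matrix, `hiddens` for the per-token inputs to the final projection, and `targets` for per-token update targets. All three, the function name, and the squared-error gradient step are hypothetical placeholders, since the LM-aligned targets are not spelled out in this section. Each chunk's gradient is accumulated and written into `w` once, after that chunk's outputs have been produced.

```python
def matvec(m, v):
    """Multiply matrix m (a list of rows) by vector v."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def apply_then_update(w, hiddens, targets, chunk_size, lr):
    """Predict each chunk with the current w, then update w in place from that chunk.

    Returns the per-token outputs. Update rule (an assumption): one gradient step
    on the chunk's summed loss 0.5 * ||w h - t||^2.
    """
    outputs = []
    for start in range(0, len(hiddens), chunk_size):
        chunk_h = hiddens[start:start + chunk_size]
        chunk_t = targets[start:start + chunk_size]
        # 1) Apply: every token in the chunk reads the same, not-yet-updated w.
        preds = [matvec(w, h) for h in chunk_h]
        outputs.extend(preds)
        # 2) Update: sum the chunk's gradients, then write them into w in one step.
        grad = [[0.0] * len(row) for row in w]
        for h, t, y in zip(chunk_h, chunk_t, preds):
            for i, (yi, ti) in enumerate(zip(y, t)):
                for j, hj in enumerate(h):
                    grad[i][j] += (yi - ti) * hj
        for i, row in enumerate(grad):
            for j, g in enumerate(row):
                w[i][j] -= lr * g
    return outputs


# Toy usage: 1-dimensional hiddens and targets, two chunks of two tokens each.
hiddens = [[1.0], [1.0], [1.0], [1.0]]
w = [[0.0]]
outputs = apply_then_update(w, hiddens, [[1.0], [1.0], [1.0], [1.0]], chunk_size=2, lr=0.25)
print(outputs, w)  # [[0.0], [0.0], [0.5], [0.5]] [[0.75]]

# Changing only the second chunk's targets leaves every output unchanged.
w2 = [[0.0]]
outputs2 = apply_then_update(w2, hiddens, [[1.0], [1.0], [100.0], [100.0]], chunk_size=2, lr=0.25)
print(outputs2 == outputs)  # True
```

The second call changes only the targets of the second chunk, yet produces identical outputs: those targets affect `w` only after the second chunk has been predicted, which is the causality property the apply-then-update order is meant to preserve.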

Relationships

Sources