Escaping the OOM Trap with Liger-Kernel’s TiledMLP

I was staring at my terminal at 1:30 AM last Thursday, watching the same CUDA out-of-memory stack trace scroll by for the fifth time. DeepSpeed ZeRO-3 was already enabled, and gradient checkpointing was on. I had even dropped the micro-batch size to 1. It still crashed.

The attention mechanism usually gets all the blame for memory issues, but people overlook the intermediate activations in the MLP layers. When you push the sequence length to 128k tokens, those hidden states get ridiculously large: gigabytes of activations, just sitting there, waiting for the backward pass.

Then I remembered seeing a patch for Liger-Kernel in which Sangchun Ha ported the TiledMLP implementation from the Arctic Long Sequence Training (ALST) stack. I ran pip install liger-kernel==0.3.0, changed three lines of code, and memory usage dropped enough not only to fit the 128k sequence but also to let me bump the batch size back up.

Standard training breaks down at massive sequence lengths because the MLP's intermediate activations grow linearly with the number of tokens, and the MLP's intermediate dimension is typically several times the hidden size. TiledMLP fixes this by chunking the computation. Instead of running the MLP forward pass over the entire sequence at once, it splits the sequence into tiles and processes them one at a time, so only one tile's intermediate activations are in memory at any moment. You pay a penalty, specifically an extra forward recompute of the MLP block during the backward pass, but you save a massive amount of VRAM.
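To make the tiling idea concrete, here is a minimal PyTorch sketch of the technique described above. It is not Liger-Kernel's or ALST's actual code. The class names SwiGLUMLP and TiledMLPFunction are hypothetical, and the sketch assumes a Llama-style SwiGLU MLP whose input hidden states require grad, as they do during training. The forward pass runs each tile without building a graph and keeps only the input. The backward pass recomputes one tile at a time and accumulates gradients. For scale, assuming Llama-3-8B's intermediate size of 14,336 and bf16, a single intermediate tensor for a 128k-token sequence is 131,072 × 14,336 × 2 bytes = 3.5 GiB. With 4 tiles, each tile's copy is a quarter of that.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SwiGLUMLP(nn.Module):
    """Llama-style MLP: down(silu(gate(x)) * up(x)). Hypothetical stand-in."""

    def __init__(self, hidden: int, intermediate: int):
        super().__init__()
        self.gate_proj = nn.Linear(hidden, intermediate, bias=False)
        self.up_proj = nn.Linear(hidden, intermediate, bias=False)
        self.down_proj = nn.Linear(intermediate, hidden, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))


class TiledMLPFunction(torch.autograd.Function):
    """Run an MLP over sequence tiles, saving only the input for backward.

    Assumes x has shape (..., seq_len, hidden) and requires grad.
    """

    @staticmethod
    def forward(ctx, x, mlp, num_tiles):
        ctx.mlp = mlp
        ctx.num_tiles = num_tiles
        ctx.save_for_backward(x)  # the only activation kept for backward
        with torch.no_grad():
            # Only one tile's (tile_len x intermediate) activations exist at a time.
            outs = [mlp(chunk) for chunk in x.chunk(num_tiles, dim=-2)]
        return torch.cat(outs, dim=-2)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        grad_x = torch.zeros_like(x)
        tiles = zip(
            x.chunk(ctx.num_tiles, dim=-2),
            grad_out.chunk(ctx.num_tiles, dim=-2),
            grad_x.chunk(ctx.num_tiles, dim=-2),  # views into grad_x
        )
        for x_tile, g_tile, gx_tile in tiles:
            x_tile = x_tile.detach().requires_grad_(True)
            with torch.enable_grad():
                out = ctx.mlp(x_tile)  # recompute this tile's forward
            # Accumulates into the MLP parameters' .grad and fills x_tile.grad.
            torch.autograd.backward(out, g_tile)
            gx_tile.copy_(x_tile.grad)
        return grad_x, None, None  # no grads for the mlp and num_tiles args
```

Swapping it in is one call per MLP, for example out = TiledMLPFunction.apply(hidden_states, mlp, 4). The cost is the one the article describes: every tile's forward runs twice, once in the forward pass and once during backward.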
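```python
import deepspeed
from transformers import AutoModelForCausalLM
from liger_kernel.transformers import apply_liger_kernel_to_llama

# Patch the Llama modeling code before the model is built, then hand it to DeepSpeed
apply_liger_kernel_to_llama(
    rope=True,
    cross_entropy=False,              # cannot be True together with the fused variant
    fused_linear_cross_entropy=True,  # fuse lm_head + loss; no full logits tensor
    rms_norm=True,
    swiglu=True,  # patches the MLP; the option that enables TiledMLP may be named differently in your release
)

model = AutoModelForCausalLM.from_pretrained("path/to/llama-checkpoint")  # placeholder path
ds_config = "ds_config.json"  # placeholder: your ZeRO-3 config (file path or dict)

# Now wrap with DeepSpeed as usual
model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config=ds_config,
)
```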
At a 32k sequence length, individual step time increased by about 12%. But because the activation memory footprint shrank by nearly 40%, I could quadruple my micro-batch size from 2 to 8. Net result? My overall training throughput went up, and the wall-clock time to process 1,000 tokens dropped.

TiledMLP is the new fix for long sequences, but I also need to talk about FusedLinearCrossEntropy. If you aren't using it yet, you're burning memory for no reason. The fused version combines the final linear projection (the LM head) with the cross-entropy calculation and computes the gradients for that linear layer inside the forward pass. The full logits tensor, which holds one value per token per vocabulary entry, is never materialized or saved for backward, and there is no recompute cost. As described in the DeepSpeed documentation, this optimization can significantly reduce memory usage.

One gotcha I ran into: if you pass custom class weights to the fused version, it behaves slightly differently from native PyTorch when combined with DeepSpeed's gradient accumulation. My loss scale went completely out of whack on the first run. If you're migrating an existing training script mid-project, keep an eye on your loss curves for the first 100 steps.

I'm done fighting the default PyTorch memory allocator. Patch the model, take the compute hit, and get your batch size back.
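The FusedLinearCrossEntropy idea can be sketched in plain PyTorch. This is a minimal illustration under stated assumptions, not Liger-Kernel's Triton kernel. chunked_linear_cross_entropy is a hypothetical helper that computes a plain mean loss with no ignore_index, label smoothing, or class weights, and it returns the gradients directly instead of wiring them into autograd. For each chunk of tokens, it builds only that chunk's logits and adds the chunk's loss. It then uses the closed-form gradient of softmax cross-entropy (softmax minus one-hot) to accumulate the gradients for the hidden states and the LM head weight immediately.

```python
import torch
import torch.nn.functional as F


@torch.no_grad()
def chunked_linear_cross_entropy(hidden, weight, targets, chunk_size=1024):
    """Hypothetical helper: return (loss, grad_hidden, grad_weight).

    hidden:  (N, H) final hidden states
    weight:  (V, H) LM head weight
    targets: (N,)   token ids
    Loss is the mean over all N tokens. Full (N, V) logits are never built.
    """
    n = hidden.shape[0]
    total_loss = hidden.new_zeros(())
    grad_hidden = torch.empty_like(hidden)
    grad_weight = torch.zeros_like(weight)
    for start in range(0, n, chunk_size):
        h = hidden[start:start + chunk_size]
        t = targets[start:start + chunk_size]
        logits = h @ weight.T  # only (chunk, V) lives in memory
        total_loss += F.cross_entropy(logits, t, reduction="sum")
        # d(mean CE)/d(logits) = (softmax - one_hot) / n, computed right now
        grad_logits = torch.softmax(logits, dim=-1)
        grad_logits[torch.arange(h.shape[0], device=h.device), t] -= 1.0
        grad_logits /= n
        grad_hidden[start:start + chunk_size] = grad_logits @ weight
        grad_weight += grad_logits.T @ h
    return total_loss / n, grad_hidden, grad_weight
```

In a real training loop, a custom torch.autograd.Function would stash these gradients in its forward and return them, scaled by the incoming grad_output, from its backward. That is why nothing logit-sized has to survive until the backward pass.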