The key-value cache in transformer models is a memory bottleneck for mobile inference. A 7B parameter model with 4096 context length and 32 attention heads consumes roughly 2GB of KV cache alone at float16 precision. On devices with 4-6GB total RAM, this leaves little headroom for system processes, UI, or multiple concurrent tasks.

Traditional solutions—static quantization to int8 or reducing context windows—sacrifice either quality or utility. This article presents a hierarchical pruning strategy that exploits the natural attention patterns in autoregressive generation to selectively discard cached states while preserving semantic coherence.

Anatomy of the KV Cache

Each transformer layer maintains two tensors per attention head: keys (K) and values (V). For a model with L layers, H heads, hidden dimension D (so each head has dimension D/H), and sequence length S, total memory is:

Memory = 2 × L × H × S × (D/H) × bytes_per_element

For LLaMA-7B (32 layers, 32 heads, 4096 dimensions) at 500 tokens, that's 2 × 32 × 32 × 500 × 128 × 2 bytes = 262,144,000 bytes ≈ 262MB. At the full 4096-token context, it grows to 2 × 32 × 32 × 4096 × 128 × 2 bytes ≈ 2.1GB (exactly 2 GiB).
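A minimal C++ helper that evaluates the formula, as a sketch for sanity-checking cache budgets. `kv_cache_bytes` is a hypothetical name, not a library function. It assumes every head stores its own K and V; a model using grouped-query attention would use its number of KV heads instead.

```cpp
#include <cstdint>

// KV cache size from the formula above:
//   2 (K and V) x layers x heads x seq_len x (hidden_dim / heads) x bytes.
// Assumes standard multi-head attention (every head has its own K and V)
// and hidden_dim divisible by heads.
constexpr std::uint64_t kv_cache_bytes(std::uint64_t layers, std::uint64_t heads,
                                       std::uint64_t hidden_dim,
                                       std::uint64_t seq_len,
                                       std::uint64_t bytes_per_element) {
  return 2 * layers * heads * seq_len * (hidden_dim / heads) * bytes_per_element;
}

// LLaMA-7B shape (32 layers, 32 heads, hidden size 4096) at float16.
static_assert(kv_cache_bytes(32, 32, 4096, 500, 2) == 262144000ULL,
              "500 tokens: about 262 MB");
static_assert(kv_cache_bytes(32, 32, 4096, 4096, 2) == 2147483648ULL,
              "4096 tokens: 2 GiB");
```

Per token, this shape costs 2 × 32 × 4096 × 2 bytes = 512 KiB, which is why the cache grows so quickly during generation.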

The cache grows linearly with sequence length during generation. Each new token appends one KV pair per head per layer. Mobile implementations using llama.cpp or ONNX Runtime Mobile face hard limits: iOS terminates apps exceeding memory pressure thresholds, and Android's low-memory killer is aggressive on budget devices.

Attention Pattern Analysis

Not all cached tokens contribute equally to prediction quality. Empirical studies of attention weights during generation reveal three patterns:

  • Recency bias: Recent tokens (last 50-100) receive 60-70% of attention mass across most heads
  • Structural anchors: Special tokens (BOS, system prompts, section headers) maintain persistent high attention
  • Layer specialization: Early layers focus on local syntax; deeper layers attend to distant semantic context

This heterogeneity suggests that uniform eviction policies (FIFO, random) are suboptimal. A hierarchical approach can prune aggressively in shallow (early) layers while preserving more history in deeper layers.

Measuring Token Importance

We define a per-token importance score I(t) as an exponentially weighted moving average (EWMA) of the attention that token t receives over recent generation steps:

I(t) = α × attention_current(t) + (1 - α) × I_prev(t)

We found empirically that α = 0.3 works well. Tokens with I(t) below a threshold θ become eviction candidates. Crucially, we compute separate thresholds per layer group: shallow (layers 0-10), mid (11-21), and deep (22-31).
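One way to maintain I(t) is a small per-layer tracker that applies the EWMA update once per generation step. This is a sketch under stated assumptions: `ImportanceTracker` is a hypothetical class, the per-token attention is assumed to be already reduced across heads, newly appended tokens start with a score of 0, and eviction (which would require compacting the scores alongside the cache) is not modeled.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Per-token importance I(t) for one layer, updated once per generation step:
//   I(t) = alpha * attention_current(t) + (1 - alpha) * I_prev(t)
// Assumptions: attention is already one value per token (e.g. averaged over
// heads), new tokens start from I_prev = 0, and eviction is not modeled.
class ImportanceTracker {
 public:
  explicit ImportanceTracker(float alpha = 0.3f) : alpha_(alpha) {}

  // attention[t]: attention the current query assigns to cached token t.
  // Without eviction the cache only grows, so the vector never shrinks.
  void update(const std::vector<float>& attention) {
    assert(attention.size() >= scores_.size());
    scores_.resize(attention.size(), 0.0f);  // new tokens start at 0
    for (std::size_t t = 0; t < attention.size(); ++t)
      scores_[t] = alpha_ * attention[t] + (1.0f - alpha_) * scores_[t];
  }

  // Positions whose importance is below theta (eviction candidates).
  std::vector<std::size_t> candidates(float theta) const {
    std::vector<std::size_t> out;
    for (std::size_t t = 0; t < scores_.size(); ++t)
      if (scores_[t] < theta) out.push_back(t);
    return out;
  }

  float score(std::size_t t) const { return scores_.at(t); }

 private:
  float alpha_;
  std::vector<float> scores_;
};
```

With α = 0.3, a token that stops receiving attention sees its score shrink by a factor of 0.7 per step, so it crosses a fixed θ after a few steps rather than immediately.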

Three-Tier Eviction Policy

The pruning system operates in three tiers, triggered when cache memory exceeds 70% of the allocated budget:

Tier 1: Shallow Layer Aggressive Pruning

Layers 0-10 retain only the last 128 tokens plus structural anchors. These layers primarily handle syntactic dependencies and positional encoding. Testing on instruction-following tasks showed no measurable perplexity increase when pruning 75% of the shallow-layer cache beyond this window.
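The Tier 1 rule reduces to computing a keep set. In this sketch `tier1_keep` is a hypothetical helper, and anchor positions are passed in by the caller, since how structural anchors are detected is not specified here.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Tier 1 (layers 0-10): keep the last `window` tokens plus any structural
// anchors (e.g. BOS or system-prompt positions supplied by the caller).
// Returns the positions to keep, in ascending order; all others are evicted.
std::vector<std::size_t> tier1_keep(std::size_t seq_len,
                                    const std::vector<std::size_t>& anchors,
                                    std::size_t window = 128) {
  std::vector<bool> keep(seq_len, false);
  const std::size_t start = seq_len > window ? seq_len - window : 0;
  for (std::size_t t = start; t < seq_len; ++t) keep[t] = true;  // recency window
  for (std::size_t a : anchors) {
    assert(a < seq_len);
    keep[a] = true;  // anchors survive regardless of age
  }
  std::vector<std::size_t> out;
  for (std::size_t t = 0; t < seq_len; ++t)
    if (keep[t]) out.push_back(t);
  return out;
}
```

At 300 tokens with anchors at positions 0 and 5, this keeps 130 positions: the two anchors plus positions 172 through 299.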

Tier 2: Mid-Layer Selective Eviction

Layers 11-21 use importance scoring with θ = 0.15. Tokens in the middle 50% of the context (excluding the most recent 128 and the first 64) are evaluated every 32 generation steps. Low-importance tokens are evicted in batches of 16. This tier typically removes 40-50% of the mid-layer cache while preserving semantic anchors.
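A sketch of the Tier 2 selection, under assumptions: `Tier2Config`, `tier2_due`, and `tier2_evict` are hypothetical names; only the explicit exclusions (first 64 and last 128 tokens) are modeled, not a fraction of the context; and candidates that do not fill a whole batch of 16 are deferred to the next evaluation.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Tier 2 parameters from the text: theta 0.15, first 64 and last 128 tokens
// protected, evaluation every 32 steps, eviction in batches of 16.
struct Tier2Config {
  float theta = 0.15f;
  std::size_t head = 64, tail = 128, batch = 16, interval = 32;
};

// True on generation steps where tier 2 should re-evaluate the cache.
bool tier2_due(std::size_t step, const Tier2Config& c = {}) {
  return step > 0 && step % c.interval == 0;
}

// Returns ascending positions to evict from one mid-layer cache.
// Assumption: only whole batches are evicted; leftovers wait for next pass.
std::vector<std::size_t> tier2_evict(const std::vector<float>& importance,
                                     const Tier2Config& c = {}) {
  assert(c.batch > 0);
  std::vector<std::size_t> cand;
  const std::size_t n = importance.size();
  if (n <= c.head + c.tail) return cand;  // everything is protected
  for (std::size_t t = c.head; t < n - c.tail; ++t)
    if (importance[t] < c.theta) cand.push_back(t);
  // Lowest importance first, so deferred leftovers are the strongest ones.
  std::sort(cand.begin(), cand.end(), [&](std::size_t a, std::size_t b) {
    return importance[a] < importance[b];
  });
  cand.resize(cand.size() - cand.size() % c.batch);  // whole batches only
  std::sort(cand.begin(), cand.end());               // back to position order
  return cand;
}
```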

Tier 3: Deep Layer Conservative Retention

Layers 22-31 keep the full cache for the most recent 256 tokens and apply θ = 0.25 to older content. Deep layers perform long-range semantic reasoning, so aggressive pruning causes coherence drift. Eviction here is a last resort and is capped: it removes only the bottom 10% of older tokens by importance score.
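A corresponding sketch for Tier 3. Reading "bottom 10%" as 10% of the tokens older than the protected 256-token window is our assumption, as is the hypothetical name `tier3_evict`. The cap, rather than the threshold, is what keeps this tier conservative.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Tier 3 (layers 22-31): the most recent `recent` tokens are never evicted.
// Older tokens below theta are candidates, but at most `max_percent`% of the
// older tokens are removed per pass, lowest importance first.
std::vector<std::size_t> tier3_evict(const std::vector<float>& importance,
                                     float theta = 0.25f,
                                     std::size_t recent = 256,
                                     std::size_t max_percent = 10) {
  assert(max_percent <= 100);
  std::vector<std::size_t> cand;
  const std::size_t n = importance.size();
  if (n <= recent) return cand;  // everything is inside the protected window
  const std::size_t older = n - recent;
  for (std::size_t t = 0; t < older; ++t)
    if (importance[t] < theta) cand.push_back(t);
  std::sort(cand.begin(), cand.end(), [&](std::size_t a, std::size_t b) {
    return importance[a] < importance[b];
  });
  const std::size_t cap = older * max_percent / 100;  // integer floor
  if (cand.size() > cap) cand.resize(cap);
  std::sort(cand.begin(), cand.end());  // ascending positions
  return cand;
}
```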

Implementation in ONNX Runtime Mobile

ONNX Runtime's ExecutionProvider API allows custom memory management. We implement pruning as a callback invoked after each generation step for each layer's cache, which picks the tier from the layer index:

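// Sketch only: OnPostExecute is an illustrative per-layer hook, not a
// method of ONNX Runtime's real IExecutionProvider interface.
class HierarchicalKVPruner : public IExecutionProvider {
 public:
  void OnPostExecute(Tensor* kv_cache, int layer_idx) {
    // Prune only once cache memory exceeds 70% of the allocated budget.
    if (memory_pressure() > 0.7) {
      // Layers 0-10 -> tier 1, 11-21 -> tier 2, 22-31 -> tier 3.
      int tier = layer_idx < 11 ? 1 : (layer_idx < 22 ? 2 : 3);
      prune_by_tier(kv_cache, tier);
    }
  }
};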

The pruner maintains a circular buffer of attention weights (last 16 steps) and computes importance scores in a background thread to avoid blocking generation. Eviction operates in-place by shifting tensor slices, avoiding reallocation overhead.
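The in-place eviction can be sketched as a row compaction over one layer's K or V buffer. The layout is an assumption: token-major, one contiguous row of `row_width` floats per token. A cache stored per head would apply the same shift within each head's slice. `compact_rows_in_place` is a hypothetical name.

```cpp
#include <cassert>
#include <cstddef>
#include <cstring>
#include <vector>

// Compacts one layer's K (or V) cache in place. Assumed layout: row-major
// [seq_len x row_width] floats, one row per token. Rows with keep == false
// are overwritten by shifting later rows forward; nothing is reallocated.
// Returns the new sequence length.
std::size_t compact_rows_in_place(float* cache, std::size_t seq_len,
                                  std::size_t row_width,
                                  const std::vector<bool>& keep) {
  assert(keep.size() == seq_len);
  std::size_t write = 0;
  for (std::size_t read = 0; read < seq_len; ++read) {
    if (!keep[read]) continue;  // evicted row: leave it to be overwritten
    if (write != read) {
      // write < read, so whole rows never overlap; memmove is used defensively.
      std::memmove(cache + write * row_width, cache + read * row_width,
                   row_width * sizeof(float));
    }
    ++write;
  }
  return write;
}
```

A single forward pass suffices because every kept row moves to a lower or equal index, so no kept data is overwritten before it is read.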

Metal Performance Shaders Optimization

On iOS, we use Metal compute shaders for importance scoring. The shader processes 64 tokens in parallel, computing weighted averages via SIMD dot products. Batch eviction uses a prefix-sum scan to compact the tensor without branching:

kernel void compact_kv_cache(
  device float* cache [[buffer(0)]],
  device float* importance [[buffer(1)]],
  constant float& threshold [[buffer(2)]]
) {
  // Parallel prefix sum of keep mask
  // followed by gather operation
}

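This approach achieves 3-5ms pruning latency per layer on A15 Bionic. Pruning all 32 layers in one pass therefore costs about 96-160ms, comparable to a single 80-120ms token, but amortized over a 32-step pruning interval it adds only about 3-5ms per token.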
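For reference, the scan-then-gather compaction the shader performs can be written sequentially on the CPU, which is also handy for validating GPU output. This is not the Metal kernel itself: `compact_by_scan` is a hypothetical helper that writes to a new buffer, and on the GPU the scan and copy loops would run in parallel. Tokens with importance at or above the threshold are kept, matching the rule that tokens below θ are evicted.

```cpp
#include <cassert>
#include <cstddef>
#include <numeric>
#include <vector>

// Sequential reference for stream compaction:
// 1) keep[t] = 1 if importance[t] >= threshold, else 0,
// 2) exclusive prefix sum of keep gives each survivor its output slot,
// 3) each surviving row is copied to its slot (the gather/scatter step).
std::vector<float> compact_by_scan(const std::vector<float>& cache,
                                   const std::vector<float>& importance,
                                   std::size_t row_width, float threshold) {
  const std::size_t n = importance.size();
  assert(cache.size() == n * row_width);
  std::vector<std::size_t> keep(n), slot(n);
  for (std::size_t t = 0; t < n; ++t) keep[t] = importance[t] >= threshold ? 1 : 0;
  std::exclusive_scan(keep.begin(), keep.end(), slot.begin(), std::size_t{0});
  const std::size_t kept = n == 0 ? 0 : slot[n - 1] + keep[n - 1];
  std::vector<float> out(kept * row_width);
  for (std::size_t t = 0; t < n; ++t) {
    if (!keep[t]) continue;  // on a GPU this is a per-thread predicated write
    for (std::size_t d = 0; d < row_width; ++d)
      out[slot[t] * row_width + d] = cache[t * row_width + d];
  }
  return out;
}
```

The prefix sum is what removes data-dependent control flow from the scan itself: every thread computes its destination from the scan result without coordinating with its neighbors.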

Production Results

Deployed in a mobile chat application using LLaMA-7B quantized to 4-bit, the hierarchical pruner delivers:

  • Memory reduction: 42% lower peak RAM usage at 2048 context length (1.8GB vs 3.1GB baseline)
  • Coherence preservation: 98.5% BLEU score vs full cache on 500-sample instruction-following benchmark
  • Latency impact: +4% median per-token generation time (pruning overhead amortized over 32-step intervals)
  • Extended context: Enables 3072 context on 4GB devices where full cache OOMs at 2048

Perplexity on multi-turn conversations remains within 2% of the full-cache baseline. The primary failure mode is subtle coherence drift in conversations exceeding 1500 tokens, where deep-layer eviction begins removing semantic anchors.

Adaptive Threshold Tuning

Static thresholds (0.15, 0.25) work for general chat but fail for specialized tasks. Code generation benefits from higher retention (θ = 0.10 in mid-layers), while summarization tolerates aggressive pruning (θ = 0.30). We expose a task_profile parameter:

enum TaskProfile {
  CHAT,        // balanced
  CODE,        // high retention
  SUMMARIZE,   // aggressive pruning
  LONG_FORM    // deep-layer focus
};

Profile selection adjusts tier thresholds and eviction batch sizes. In practice, applications set this per session based on user intent or prompt classification.
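A sketch of how a profile might map to tier parameters. Only the CHAT defaults (θ = 0.15 mid, 0.25 deep, batches of 16) and the mid-layer values for CODE (0.10) and SUMMARIZE (0.30) come from the text; `TierParams`, `params_for`, the SUMMARIZE batch size, and the LONG_FORM values are hypothetical placeholders.

```cpp
#include <cstddef>

enum TaskProfile { CHAT, CODE, SUMMARIZE, LONG_FORM };

// Knobs a profile adjusts. Lower theta means fewer eviction candidates,
// i.e. higher retention.
struct TierParams {
  float mid_theta;         // tier 2 threshold
  float deep_theta;        // tier 3 threshold
  std::size_t batch_size;  // tier 2 eviction batch size
};

constexpr TierParams params_for(TaskProfile p) {
  switch (p) {
    case CODE:      return {0.10f, 0.25f, 16};  // mid 0.10: higher retention
    case SUMMARIZE: return {0.30f, 0.25f, 32};  // mid 0.30; batch 32 is a placeholder
    case LONG_FORM: return {0.15f, 0.20f, 16};  // placeholder: keep more deep history
    case CHAT:
    default:        return {0.15f, 0.25f, 16};  // balanced defaults
  }
}
```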

Tradeoffs and Future Work

Hierarchical pruning is not free. Importance tracking adds 8-12MB overhead for score buffers and attention history. Background scoring consumes 3-5% CPU on average. For latency-critical applications (voice assistants, real-time translation), the 4% generation slowdown may be unacceptable.

The current implementation assumes causal attention. Bidirectional models (BERT-style encoders) require different policies, as attention patterns lack recency bias. Sparse attention mechanisms (like Longformer or BigBird) already reduce KV cache size but remain uncommon in mobile-deployed models.

Future directions include learned eviction policies—training a small MLP to predict token importance based on attention statistics and hidden states—and integration with speculative decoding, where draft model KV caches are pruned more aggressively than verification passes.

Practical Deployment

For teams shipping on-device LLMs, hierarchical pruning offers a pragmatic middle ground between quality and resource constraints. Start with conservative thresholds (0.20/0.30/0.35 for shallow/mid/deep) and measure perplexity on task-specific benchmarks. Monitor memory pressure via OS APIs (iOS: os_proc_available_memory, Android: ActivityManager.MemoryInfo) and trigger pruning proactively rather than waiting for OOM crashes.
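A proactive trigger can combine the cache budget check with OS-reported available memory. In this hypothetical `should_prune` helper, the OS value is passed in as a parameter so the policy stays portable; the 256 MiB headroom reserve is an illustrative assumption, not a recommended value.

```cpp
#include <cstdint>

// Decides whether to prune now. `os_available_bytes` is meant to come from
// the platform: os_proc_available_memory() on iOS, or
// ActivityManager.MemoryInfo.availMem on Android (e.g. passed down via JNI).
// The 70% budget trigger follows the text; the 256 MiB headroom is assumed.
constexpr bool should_prune(std::uint64_t cache_bytes,
                            std::uint64_t cache_budget_bytes,
                            std::uint64_t os_available_bytes,
                            std::uint64_t headroom_bytes = 256ull << 20) {
  // cache/budget > 0.7, in integer arithmetic to avoid float rounding;
  // also prune early if the system as a whole is running low.
  return cache_bytes * 10 > cache_budget_bytes * 7 ||
         os_available_bytes < headroom_bytes;
}
```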

The technique pairs well with other optimizations: 4-bit quantization reduces base memory by 75%, and sliding-window attention caps absolute cache size. Together, these enable 7B models on mid-range devices—a threshold that unlocks genuinely useful applications beyond novelty demos.