
KV Cache Attention

Incremental decoding, cache K/V, prefill vs decode

Difficulty: Hard · Topic: Attention

Problem Description

Implement multi-head attention with KV caching for efficient autoregressive generation.

During LLM inference, recomputing all key/value projections at every step is wasteful.

A KV cache stores previously computed K and V tensors so only the new token(s) need projection.
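To see the scale of the savings, here is a rough back-of-the-envelope count of K/V projection operations with and without a cache. The prompt length `P` and generation length `T` are illustrative values, not part of the problem.

```python
# Rough count of per-token K/V projections when generating T tokens
# after a prompt of length P, with and without a KV cache.
P, T = 100, 50

# Without cache: every decode step re-projects the entire sequence so far.
no_cache = sum(P + t for t in range(1, T + 1))

# With cache: the prompt is projected once (prefill), then one token per step.
with_cache = P + T

print(no_cache, with_cache)  # 6275 vs. 150
```

The gap grows quadratically in sequence length without a cache, but only linearly with one.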

Signature

```python
class KVCacheAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int): ...

    def forward(self, x: torch.Tensor, cache=None) -> tuple[torch.Tensor, tuple]:
        # x: (B, S_new, D) — new tokens
        # cache: None or (K_past, V_past), each (B, num_heads, S_past, d_k)
        # Returns: (output, (K_all, V_all))
        ...
```

Requirements

• Inherit from nn.Module

• self.W_q, self.W_k, self.W_v, self.W_o: nn.Linear projections

• When cache=None (prefill): apply causal mask, return all K/V as cache

• When cache is provided (decode): concatenate the new K/V with the cached K/V; no causal mask is needed for single-token decode

• Incremental decoding must produce results numerically identical to a full forward pass (up to floating-point tolerance)
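One subtlety behind the identical-results requirement: during prefill with a non-empty cache, the causal mask must be offset by the number of cached positions, since query row i corresponds to absolute position S_past + i. A minimal sketch of that offset, using illustrative sizes:

```python
import torch

# Query row i (a new token at absolute position S_past + i) may attend to
# columns 0 .. S_past + i, so the upper-triangular mask starts at
# diagonal = S_past + 1.
S_past, S_new = 3, 2
S_total = S_past + S_new

mask = torch.triu(
    torch.ones(S_new, S_total, dtype=torch.bool),
    diagonal=S_past + 1,
)
print(mask)
# Row 0 (position 3) blocks only column 4; row 1 (position 4) blocks nothing.
```

With `diagonal=1` (the no-cache case, S_past = 0) this reduces to the standard causal mask.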

Key Idea

Prefill: [t0 t1 t2 t3] → full causal attention → cache = (K_{0:3}, V_{0:3})
Decode:  [t4] → Q = t4, K/V = cache + t4 → cache = (K_{0:4}, V_{0:4})
Decode:  [t5] → Q = t5, K/V = cache + t5 → cache = (K_{0:5}, V_{0:5})
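The cache growth above can be sketched directly with tensor shapes. This toy loop uses random tensors in place of real projections, with the problem's shape convention (B, num_heads, S, d_k):

```python
import torch

B, H, d_k = 1, 4, 16

# Prefill over t0..t3: cache holds 4 positions.
cache_k = torch.randn(B, H, 4, d_k)

# Decode t4, then t5: each step appends one position along the sequence dim.
for step in range(2):
    new_k = torch.randn(B, H, 1, d_k)  # stand-in for W_k applied to one token
    cache_k = torch.cat([cache_k, new_k], dim=2)
    print(cache_k.shape)
```

The sequence dimension grows from 4 to 5 to 6; V grows identically in the real module.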

Template

Implement the class below. Use only basic PyTorch operations.

```python
# ✏️ YOUR IMPLEMENTATION HERE
import math

import torch
import torch.nn as nn


class KVCacheAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        pass  # Initialize W_q, W_k, W_v, W_o

    def forward(self, x, cache=None):
        # 1. Project Q, K, V from x
        # 2. Reshape to multi-head: (B, num_heads, S, d_k)
        # 3. If cache exists, concat new K/V with cached K/V
        # 4. Compute attention (causal mask needed during prefill)
        # 5. Return (output, (K_all, V_all))
        pass
```

Test Your Implementation

Use this code to debug before submitting.

```python
# 🧪 Debug
torch.manual_seed(0)
attn = KVCacheAttention(d_model=64, num_heads=4)
x = torch.randn(1, 6, 64)

# Full forward
full_out, _ = attn(x)
print("Full output shape:", full_out.shape)  # (1, 6, 64)

# Incremental: prefill 4, decode 1, decode 1
out1, cache = attn(x[:, :4])
print("Cache K shape:", cache[0].shape)  # (1, 4, 4, 16)
out2, cache = attn(x[:, 4:5], cache=cache)
out3, cache = attn(x[:, 5:6], cache=cache)
inc_out = torch.cat([out1, out2, out3], dim=1)
print("Match:", torch.allclose(full_out, inc_out, atol=1e-5))
```

Reference Solution

Try solving it yourself first! Click below to reveal the solution.

```python
# ✅ SOLUTION
import math

import torch
import torch.nn as nn


class KVCacheAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x, cache=None):
        B, S_new, _ = x.shape
        q = self.W_q(x).view(B, S_new, self.num_heads, self.d_k).transpose(1, 2)
        k = self.W_k(x).view(B, S_new, self.num_heads, self.d_k).transpose(1, 2)
        v = self.W_v(x).view(B, S_new, self.num_heads, self.d_k).transpose(1, 2)

        if cache is not None:
            k = torch.cat([cache[0], k], dim=2)
            v = torch.cat([cache[1], v], dim=2)
        new_cache = (k, v)

        S_total = k.shape[2]
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        if S_new > 1:
            # Causal mask for prefill: each query position can only attend to
            # positions up to itself in the full sequence
            S_past = S_total - S_new
            mask = torch.triu(
                torch.ones(S_new, S_total, device=x.device, dtype=torch.bool),
                diagonal=S_past + 1,
            )
            scores = scores.masked_fill(mask, float('-inf'))

        weights = torch.softmax(scores, dim=-1)
        attn = torch.matmul(weights, v)
        out = self.W_o(attn.transpose(1, 2).contiguous().view(B, S_new, -1))
        return out, new_cache
```

Tips

Run Locally

For interactive practice with auto-grading, run TorchCode locally:
`pip install torch-judge`, then use `check("kv_cache")`
