
Linear Attention

Kernel trick, O(n*d^2)


Problem Description

Implement linear attention, which runs in O(S·D²) time instead of O(S²·D), enabling efficient long-sequence processing.

Replace softmax with a kernel feature map $\phi$:

$$\text{LinearAttn}(Q,K,V) = \frac{\phi(Q) \left(\phi(K)^\top V\right)}{\phi(Q) \left(\sum_j \phi(K_j)\right)^\top}$$

Feature map

Use $\phi(x) = \text{elu}(x) + 1$. Since $\text{elu}(x) > -1$ for all $x$, this guarantees strictly positive features.
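A quick sanity check (an illustrative sketch, not part of the problem) confirms the positivity claim: because elu never dips to −1, adding 1 keeps every feature above zero.

```python
import torch
import torch.nn.functional as F

x = torch.randn(4, 8)
phi_x = F.elu(x) + 1  # elu(x) > -1 for all x, so phi(x) > 0
print((phi_x > 0).all().item())  # True
```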

Signature

```python
def linear_attention(Q, K, V):
    # Q: (B, S, D_k), K: (B, S, D_k), V: (B, S, D_v)
    # Returns: (B, S, D_v)
    ...
```

Key insight

Instead of materializing the $S \times S$ attention matrix, compute $\phi(K)^\top V$ first (a $D_k \times D_v$ matrix), then multiply by $\phi(Q)$.
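This reordering is just matrix-multiplication associativity, so both evaluation orders produce the same result. A small check (illustrative only, using the unnormalized numerator) confirms they agree numerically:

```python
import torch
import torch.nn.functional as F

B, S, D_k, D_v = 1, 8, 16, 32
Q, K, V = torch.randn(B, S, D_k), torch.randn(B, S, D_k), torch.randn(B, S, D_v)
Qp, Kp = F.elu(Q) + 1, F.elu(K) + 1

# Quadratic order: materialize the (S, S) score matrix first
quadratic = torch.bmm(torch.bmm(Qp, Kp.transpose(1, 2)), V)

# Linear order: contract K and V first into a (D_k, D_v) matrix
linear = torch.bmm(Qp, torch.bmm(Kp.transpose(1, 2), V))

print(torch.allclose(quadratic, linear, atol=1e-4))  # True
```

The quadratic order costs O(S²·D); the linear order costs O(S·D²), which wins whenever S ≫ D.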

Rules

• Must use a feature map (NOT softmax)

• Must be O(S·D²) — should run fast on long sequences

• You may use F.elu
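For debugging, it can help to have a naive quadratic-time reference that materializes the full score matrix; both orders compute the same normalized output, so your fast version should match it. The function name below is hypothetical, not part of the required API:

```python
import torch
import torch.nn.functional as F

def linear_attention_naive(Q, K, V):
    """Quadratic-time reference: builds the (S, S) score matrix explicitly."""
    Qp, Kp = F.elu(Q) + 1, F.elu(K) + 1
    scores = torch.bmm(Qp, Kp.transpose(1, 2))          # (B, S, S)
    scores = scores / scores.sum(dim=-1, keepdim=True)  # row-normalize
    return torch.bmm(scores, V)                         # (B, S, D_v)
```

Your O(S·D²) implementation should agree with this reference up to floating-point tolerance.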

Template

Implement the function below. Use only basic PyTorch operations.

```python
# ✏️ YOUR IMPLEMENTATION HERE
def linear_attention(Q, K, V):
    pass  # Replace this
```

Test Your Implementation

Use this code to debug before submitting.

```python
# 🧪 Debug
import torch

Q = torch.randn(1, 8, 16)
K = torch.randn(1, 8, 16)
V = torch.randn(1, 8, 32)

out = linear_attention(Q, K, V)
print("Output shape:", out.shape)  # (1, 8, 32)
print("Has NaN?", torch.isnan(out).any().item())
```

Reference Solution

Try solving it yourself first! Click below to reveal the solution.

```python
# ✅ SOLUTION
import torch
import torch.nn.functional as F

def linear_attention(Q, K, V):
    Q_prime = F.elu(Q) + 1
    K_prime = F.elu(K) + 1
    KV = torch.bmm(K_prime.transpose(1, 2), V)      # (B, D_k, D_v)
    Z = K_prime.sum(dim=1, keepdim=True)            # (B, 1, D_k)
    num = torch.bmm(Q_prime, KV)                    # (B, S, D_v)
    den = torch.bmm(Q_prime, Z.transpose(1, 2))     # (B, S, 1)
    return num / (den + 1e-6)                       # epsilon guards against division by zero
```

Tips

Run Locally

For interactive practice with auto-grading, run TorchCode locally:
`pip install torch-judge`, then use `check("linear_attention")`.
