Kernel trick, O(n*d^2)
Implement Linear Attention — O(S·D²) instead of O(S²·D), enabling efficient long-sequence processing.
Replace softmax with a kernel feature map \phi:
Use \phi(x) = \text{elu}(x) + 1 (ensures non-negative features).
Instead of computing the S \times S attention matrix, compute \phi(K)^T V first (a D_k \times D_v matrix), then multiply by \phi(Q). Normalize each output row by \phi(Q)^T \sum_s \phi(K_s), the row sum of the implicit attention matrix.
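A minimal sketch of the reordered computation described above (batched shapes and einsum layout are assumptions, not fixed by the problem statement):

```python
import torch
import torch.nn.functional as F

def linear_attention(Q, K, V):
    """Linear attention via the kernel trick: O(S·D²) instead of O(S²·D).

    Assumed shapes: Q, K are (B, S, D_k); V is (B, S, D_v).
    """
    # Feature map phi(x) = elu(x) + 1 keeps features strictly positive,
    # so the normalizer below is always > 0.
    phi_Q = F.elu(Q) + 1
    phi_K = F.elu(K) + 1
    # Kernel trick: build phi(K)^T V first — a (B, D_k, D_v) matrix —
    # instead of the (B, S, S) attention matrix.
    KV = torch.einsum('bsd,bse->bde', phi_K, V)
    # Normalizer: phi(Q) · sum_s phi(K_s), one scalar per query position.
    Z = torch.einsum('bsd,bd->bs', phi_Q, phi_K.sum(dim=1)).unsqueeze(-1)
    # Multiply by phi(Q) and normalize: (B, S, D_v) output.
    return torch.einsum('bsd,bde->bse', phi_Q, KV) / Z
```

The reordering is exact, not an approximation: multiplying out `phi_Q @ (phi_K^T V)` gives the same result as forming the full S×S kernel attention matrix and applying it to V.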
• Must use a feature map (NOT softmax)
• Must be O(S·D²) — should run fast on long sequences
• You may use F.elu
Implement the function below. Use only basic PyTorch operations.
Use this code to debug before submitting.
Try solving it yourself before checking the solution!
For interactive practice with auto-grading, run TorchCode locally: `pip install torch-judge`, then use `check("linear_attention")`.