
Flash Attention

Tiled attention, online softmax, memory-efficient

Difficulty: Hard · Topic: Attention

Problem Description

Implement tiled attention with online softmax — the core idea behind Flash Attention.

Signature

```python
def flash_attention(Q, K, V, block_size=32) -> Tensor:
    # Q, K, V: (B, S, D)
    # Returns: (B, S, D), same shape as standard attention output
```
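For contrast, here is a minimal sketch of the standard (naive) attention your implementation must match. It materializes the full S×S score matrix, which is exactly what the tiled version avoids (the function name `naive_attention` is illustrative, not part of the exercise):

```python
import math
import torch

def naive_attention(Q, K, V):
    # Materializes the full (B, S, S) attention matrix in memory
    scores = torch.bmm(Q, K.transpose(1, 2)) / math.sqrt(Q.shape[-1])
    return torch.bmm(torch.softmax(scores, dim=-1), V)
```

For sequence length S and head dimension D, the scores tensor costs O(S²) memory per batch element, versus O(S·D) for the tensors themselves; tiling keeps only one block of scores live at a time.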

Key Insight

Instead of materializing the full S×S attention matrix, process in blocks:

1. For each Q-block, iterate over K/V blocks

2. Use online softmax: track running max and sum

3. Rescale accumulator when max changes: acc *= exp(old_max - new_max)

4. Final: output = acc / row_sum

Your implementation must match standard softmax attention up to floating-point tolerance.
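The rescaling in steps 2-3 can be sketched for a single row of scores, independent of PyTorch. This is a minimal illustration of online softmax; the helper name `online_softmax_denominator` is made up for this example:

```python
import math

def online_softmax_denominator(scores):
    # One-pass computation of the softmax max and denominator for a stream of scores
    running_max = float('-inf')
    running_sum = 0.0
    for s in scores:
        new_max = max(running_max, s)
        # Rescale the accumulated sum whenever the running max changes;
        # exp(-inf - new_max) == 0.0 handles the first iteration cleanly
        running_sum = running_sum * math.exp(running_max - new_max) + math.exp(s - new_max)
        running_max = new_max
    return running_max, running_sum

scores = [1.0, 3.0, 2.0]
m, d = online_softmax_denominator(scores)
# Two-pass reference: subtract the global max, then sum the exponentials
ref = sum(math.exp(s - max(scores)) for s in scores)
print(abs(d - ref) < 1e-12)
```

In the full algorithm the same correction factor `exp(old_max - new_max)` is applied to the value accumulator `acc`, not just the denominator.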

Template

Implement the function below. Use only basic PyTorch operations.

```python
# ✏️ YOUR IMPLEMENTATION HERE
def flash_attention(Q, K, V, block_size=32):
    # Process Q in blocks, iterate K/V blocks with online softmax
    pass
```

Test Your Implementation

Use this code to debug before submitting.

```python
# 🧪 Debug
import math
import torch

Q, K, V = torch.randn(1, 8, 4), torch.randn(1, 8, 4), torch.randn(1, 8, 4)
out = flash_attention(Q, K, V, block_size=4)
scores = torch.bmm(Q, K.transpose(1, 2)) / math.sqrt(4)
ref = torch.bmm(torch.softmax(scores, dim=-1), V)
print('Match:', torch.allclose(out, ref, atol=1e-4))
```

Reference Solution

Try solving it yourself before reading the solution below.

```python
# ✅ SOLUTION
import math
import torch

def flash_attention(Q, K, V, block_size=32):
    B, S, D = Q.shape
    output = torch.zeros_like(Q)
    for i in range(0, S, block_size):
        qi = Q[:, i:i+block_size]
        bs_q = qi.shape[1]
        # Running statistics for the online softmax over this Q-block
        row_max = torch.full((B, bs_q, 1), float('-inf'), device=Q.device)
        row_sum = torch.zeros(B, bs_q, 1, device=Q.device)
        acc = torch.zeros(B, bs_q, D, device=Q.device)
        for j in range(0, S, block_size):
            kj = K[:, j:j+block_size]
            vj = V[:, j:j+block_size]
            scores = torch.bmm(qi, kj.transpose(1, 2)) / math.sqrt(D)
            block_max = scores.max(dim=-1, keepdim=True).values
            new_max = torch.maximum(row_max, block_max)
            # Rescale existing accumulator and sum when the running max changes
            correction = torch.exp(row_max - new_max)
            exp_scores = torch.exp(scores - new_max)
            acc = acc * correction + torch.bmm(exp_scores, vj)
            row_sum = row_sum * correction + exp_scores.sum(dim=-1, keepdim=True)
            row_max = new_max
        output[:, i:i+block_size] = acc / row_sum
    return output
```

Tips

Run Locally

For interactive practice with auto-grading, run TorchCode locally:
`pip install torch-judge`, then use `check("flash_attention")`.

Key Concepts

Tiled attention, online softmax, memory-efficient
