Tiled attention, online softmax, memory-efficient
Difficulty: Hard · Topic: Attention

Implement tiled attention with online softmax — the core idea behind Flash Attention.
Instead of materializing the full S×S attention matrix, process in blocks:
1. For each Q-block, iterate over K/V blocks
2. Use online softmax: track running max and sum
3. Rescale accumulator when max changes: acc *= exp(old_max - new_max)
4. Final: output = acc / row_sum
Your implementation must match standard softmax attention exactly (up to floating-point tolerance).
Implement the function below. Use only basic PyTorch operations.
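The four steps above can be sketched as follows. This is a minimal single-head reference in plain PyTorch; the function name, the `block_size` parameter, and the `(seq_len, d)` tensor layout are illustrative assumptions, not the exercise's required signature.

```python
import torch

def flash_attention(q, k, v, block_size=64):
    # q, k, v: (seq_len, d) tensors; single head, no masking.
    # block_size is a tuning knob for the tile size (assumed, not required).
    seq_len, d = q.shape
    scale = d ** -0.5
    out = torch.zeros_like(q)
    for i in range(0, seq_len, block_size):           # step 1: loop over Q-blocks
        qi = q[i:i + block_size] * scale              # (Bq, d)
        acc = torch.zeros(qi.shape[0], d)             # unnormalized output accumulator
        row_max = torch.full((qi.shape[0],), float("-inf"))
        row_sum = torch.zeros(qi.shape[0])
        for j in range(0, seq_len, block_size):       # step 1: inner loop over K/V blocks
            kj = k[j:j + block_size]
            vj = v[j:j + block_size]
            s = qi @ kj.T                             # (Bq, Bk) score tile
            # step 2: online softmax — update the running max
            new_max = torch.maximum(row_max, s.max(dim=-1).values)
            # step 3: rescale accumulator and sum by exp(old_max - new_max)
            correction = torch.exp(row_max - new_max)
            p = torch.exp(s - new_max[:, None])
            row_sum = row_sum * correction + p.sum(dim=-1)
            acc = acc * correction[:, None] + p @ vj
            row_max = new_max
        # step 4: normalize by the row sum
        out[i:i + block_size] = acc / row_sum[:, None]
    return out
```

Note that the S×S score matrix never exists in full: only one (Bq, Bk) tile of scores is live at a time, while `row_max` and `row_sum` carry enough state to make the final normalization exact.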
For interactive practice with auto-grading, run TorchCode locally: pip install torch-judge, then use check("flash_attention").