
Multi-Head Attention

Difficulty: Hard · Topic: Attention · Key concepts: parallel heads, split/concat, projection matrices

Problem Description

Implement Multi-Head Attention, the core building block of the Transformer, from scratch.

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h) W^O$$
$$\text{head}_i = \text{Attention}(Q W_i^Q,\; K W_i^K,\; V W_i^V)$$
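
Here each head applies standard scaled dot-product attention, where $d_k$ is the per-head dimension:

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V$$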

Signature

class MultiHeadAttention:
    def __init__(self, d_model: int, num_heads: int): ...
    def forward(self, Q, K, V) -> torch.Tensor: ...

Requirements

• Use nn.Linear(d_model, d_model) for self.W_q, self.W_k, self.W_v, self.W_o (one fused projection covers all heads; see the note after this list)

• Each head works in d_k = d_model // num_heads dimensions

• forward(Q, K, V): Q is (B, seq_q, d_model), K and V are (B, seq_k, d_model)

• Must support cross-attention (seq_q != seq_k)

• Do NOT use torch.nn.MultiheadAttention

• You may use torch.softmax and torch.matmul
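
A useful fact behind the first requirement: a single nn.Linear(d_model, d_model) computes all per-head projections at once, because its weight matrix is just the h per-head matrices W_i^Q stacked along the output dimension. A minimal sketch (sizes are illustrative; bias omitted to keep the comparison exact):

import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, num_heads = 32, 4
d_k = d_model // num_heads

W_q = nn.Linear(d_model, d_model, bias=False)  # one fused projection for all heads
x = torch.randn(2, 6, d_model)
fused = W_q(x).view(2, 6, num_heads, d_k)      # slice the fused output per head

# Equivalent: apply each head's slice of the weight matrix independently
for i in range(num_heads):
    head_i = x @ W_q.weight[i * d_k:(i + 1) * d_k].T  # (2, 6, d_k)
    assert torch.allclose(fused[..., i, :], head_i, atol=1e-6)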

Steps

1. Project: q = self.W_q(Q), k = self.W_k(K), v = self.W_v(V)

2. Reshape to (B, num_heads, seq, d_k) to split into heads (see the shape walkthrough after these steps)

3. Scaled dot-product attention per head

4. Concat heads → (B, seq_q, d_model)

5. Output projection: self.W_o(concat)
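
The shape bookkeeping in steps 2 and 4 is where most bugs hide. A minimal sketch of the round trip, using example sizes (B=2, seq=6, d_model=32, num_heads=4) chosen purely for illustration:

import torch

B, S, d_model, num_heads = 2, 6, 32, 4
d_k = d_model // num_heads  # 8

q = torch.randn(B, S, d_model)                          # after W_q: (2, 6, 32)
q = q.view(B, S, num_heads, d_k)                        # split last dim: (2, 6, 4, 8)
q = q.transpose(1, 2)                                   # heads to the front: (2, 4, 6, 8)
# ... attention happens here, keeping shape (B, num_heads, S, d_k) ...
q = q.transpose(1, 2).contiguous().view(B, S, d_model)  # concat heads: (2, 6, 32)
print(q.shape)  # torch.Size([2, 6, 32])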

Template

Implement the class below. Use only basic PyTorch operations.

# ✏️ YOUR IMPLEMENTATION HERE
import torch
import torch.nn as nn

class MultiHeadAttention:
    def __init__(self, d_model: int, num_heads: int):
        pass  # Initialize W_q, W_k, W_v, W_o

    def forward(self, Q, K, V):
        pass  # Implement multi-head attention

Test Your Implementation

Use this code to debug before submitting.

# 🧪 Debug
torch.manual_seed(0)
mha = MultiHeadAttention(d_model=32, num_heads=4)
print("W_q type:", type(mha.W_q))                 # should be nn.Linear
print("W_q.weight shape:", mha.W_q.weight.shape)  # (32, 32)

x = torch.randn(2, 6, 32)
out = mha.forward(x, x, x)
print("Output shape:", out.shape)  # (2, 6, 32)

# Cross-attention
Q = torch.randn(1, 3, 32)
K = torch.randn(1, 7, 32)
V = torch.randn(1, 7, 32)
out2 = mha.forward(Q, K, V)
print("Cross-attn shape:", out2.shape)  # (1, 3, 32)

Reference Solution

Try solving it yourself first! Click below to reveal the solution.

# ✅ SOLUTION
import math
import torch
import torch.nn as nn

class MultiHeadAttention:
    def __init__(self, d_model: int, num_heads: int):
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, Q, K, V):
        B, S_q, _ = Q.shape
        S_k = K.shape[1]
        # Project and split into heads: (B, seq, d_model) -> (B, num_heads, seq, d_k)
        q = self.W_q(Q).view(B, S_q, self.num_heads, self.d_k).transpose(1, 2)
        k = self.W_k(K).view(B, S_k, self.num_heads, self.d_k).transpose(1, 2)
        v = self.W_v(V).view(B, S_k, self.num_heads, self.d_k).transpose(1, 2)
        # Scaled dot-product attention per head
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        weights = torch.softmax(scores, dim=-1)
        attn = torch.matmul(weights, v)
        # Concat heads and apply the output projection: (B, seq_q, d_model)
        out = attn.transpose(1, 2).contiguous().view(B, S_q, -1)
        return self.W_o(out)
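
If you are on PyTorch 2.0 or newer, you can sanity-check the attention math (the part between the projections) against torch.nn.functional.scaled_dot_product_attention, whose defaults match this solution: scale 1/sqrt(d_k), no mask, no dropout. A quick check with illustrative sizes:

import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
q = torch.randn(2, 4, 6, 8)  # (B, num_heads, seq_q, d_k)
k = torch.randn(2, 4, 7, 8)  # (B, num_heads, seq_k, d_k)
v = torch.randn(2, 4, 7, 8)

# Hand-rolled, as in the solution above
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.shape[-1])
manual = torch.matmul(torch.softmax(scores, dim=-1), v)

# Fused reference (PyTorch 2.0+)
fused = F.scaled_dot_product_attention(q, k, v)
print(torch.allclose(manual, fused, atol=1e-5))  # True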

Tips

Run Locally

For interactive practice with auto-grading, run TorchCode locally:
pip install torch-judge then use check("multihead_attention")

Key Concepts

Parallel heads, split/concat, projection matrices
