Problem Description
Implement Multi-Head Attention from scratch: the core building block of the Transformer.
$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h) W^O$$
$$\text{head}_i = \text{Attention}(Q W_i^Q,\; K W_i^K,\; V W_i^V)$$
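The per-head $\text{Attention}$ above is standard scaled dot-product attention, where $d_k$ is the per-head dimension:

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V$$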
Signature
class MultiHeadAttention:
    def __init__(self, d_model: int, num_heads: int): ...
    def forward(self, Q, K, V) -> torch.Tensor: ...
Requirements
• Use nn.Linear(d_model, d_model) for self.W_q, self.W_k, self.W_v, self.W_o
• d_k = d_model // num_heads per head
• forward(Q, K, V): Q is (B, seq_q, d_model), K/V are (B, seq_k, d_model)
• Must support cross-attention (seq_q != seq_k)
• Do NOT use torch.nn.MultiheadAttention
• You may use torch.softmax and torch.matmul
Steps
1. Project: q = self.W_q(Q), k = self.W_k(K), v = self.W_v(V)
2. Reshape to (B, num_heads, seq, d_k)
3. Scaled dot-product attention per head
4. Concat heads → (B, seq_q, d_model)
5. Output projection: self.W_o(concat)
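The reshape and scoring steps above can be sketched as a shapes-only walkthrough (sizes and variable names below are illustrative, not part of the required interface):

```python
import math

import torch
import torch.nn as nn

B, S_q, S_k, d_model, num_heads = 2, 3, 7, 32, 4
d_k = d_model // num_heads  # 8 dims per head

# Projection layers (step 1)
W_q = nn.Linear(d_model, d_model)
W_k = nn.Linear(d_model, d_model)

Q = torch.randn(B, S_q, d_model)
K = torch.randn(B, S_k, d_model)

# Steps 1-2: project, split d_model into (num_heads, d_k), move heads axis forward
q = W_q(Q).view(B, S_q, num_heads, d_k).transpose(1, 2)  # (B, num_heads, S_q, d_k)
k = W_k(K).view(B, S_k, num_heads, d_k).transpose(1, 2)  # (B, num_heads, S_k, d_k)

# Step 3: scaled dot-product scores, one (S_q, S_k) matrix per head
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
print(scores.shape)  # torch.Size([2, 4, 3, 7])
```

Note that `view` splits the last axis in place, while `transpose(1, 2)` brings `num_heads` next to the batch axis so each head's attention is computed independently by the batched `matmul`.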
Template
Implement the class below. Use only basic PyTorch operations.
# YOUR IMPLEMENTATION HERE
import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        # TODO: initialize W_q, W_k, W_v, W_o

    def forward(self, Q, K, V):
        pass  # TODO: implement multi-head attention
Test Your Implementation
Use this code to debug before submitting.
# Debug
torch.manual_seed(0)
mha = MultiHeadAttention(d_model=32, num_heads=4)
print("W_q type:", type(mha.W_q)) # should be nn.Linear
print("W_q.weight shape:", mha.W_q.weight.shape) # (32, 32)
x = torch.randn(2, 6, 32)
out = mha.forward(x, x, x)
print("Output shape:", out.shape) # (2, 6, 32)
# Cross-attention
Q = torch.randn(1, 3, 32)
K = torch.randn(1, 7, 32)
V = torch.randn(1, 7, 32)
out2 = mha.forward(Q, K, V)
print("Cross-attn shape:", out2.shape) # (1, 3, 32)
Reference Solution
Try solving it yourself first! Click below to reveal the solution.
# SOLUTION
import math

import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, Q, K, V):
        B, S_q, _ = Q.shape
        S_k = K.shape[1]
        # Project, split into heads, move heads axis forward: (B, num_heads, seq, d_k)
        q = self.W_q(Q).view(B, S_q, self.num_heads, self.d_k).transpose(1, 2)
        k = self.W_k(K).view(B, S_k, self.num_heads, self.d_k).transpose(1, 2)
        v = self.W_v(V).view(B, S_k, self.num_heads, self.d_k).transpose(1, 2)
        # Scaled dot-product attention per head
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        weights = torch.softmax(scores, dim=-1)
        attn = torch.matmul(weights, v)
        # Concatenate heads back to (B, S_q, d_model), then project
        out = attn.transpose(1, 2).contiguous().view(B, S_q, -1)
        return self.W_o(out)