
Speculative Decoding

Accept/reject, draft model acceleration

Hard Inference

Problem Description

Implement the acceptance/rejection step of speculative decoding, a technique for accelerating LLM inference.

Signature

def speculative_decode(target_probs, draft_probs, draft_tokens) -> list[int]:
    # target_probs: (K, V) from target (large) model
    # draft_probs: (K, V) from draft (small) model
    # draft_tokens: (K,) tokens sampled by draft model
    # Returns: list of accepted tokens (1 to K)

Algorithm

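For each position i = 0, ..., K-1, with token_i = draft_tokens[i]:

1. ratio = target_probs[i, token_i] / draft_probs[i, token_i]

2. Accept token_i with probability min(1, ratio). If accepted, append token_i and move on to the next position.

3. If rejected: sample a replacement token from normalize(max(0, target_probs[i] - draft_probs[i])), append it, and stop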
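To make the three steps concrete, here is a small worked example for a single position with a 3-token vocabulary. It uses plain Python lists rather than tensors so the arithmetic stays visible. The probability values and the names target_row, draft_row, and token are made up for illustration and are not part of the problem.

```python
# Illustrative numbers only: one position, vocabulary of 3 tokens.
target_row = [0.5, 0.3, 0.2]   # target model probabilities at this position
draft_row = [0.2, 0.6, 0.2]    # draft model probabilities at this position
token = 1                      # token the draft model proposed

# Steps 1-2: the draft token is kept with probability min(1, ratio).
ratio = target_row[token] / draft_row[token]
accept_prob = min(1.0, ratio)

# Step 3: on rejection, resample from the normalized positive part of (target - draft).
residual = [max(0.0, t - d) for t, d in zip(target_row, draft_row)]
total = sum(residual)
residual = [r / total for r in residual]

print(accept_prob)  # 0.5
print(residual)     # [1.0, 0.0, 0.0]
```

In this example the draft overestimates token 1, so the proposal survives only half the time. A rejection always resamples token 0, the only token the target favors more than the draft does.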

Template

Implement the function below. Use only basic PyTorch operations.

# ✍️ YOUR IMPLEMENTATION HERE
def speculative_decode(target_probs, draft_probs, draft_tokens):
    pass  # accept/reject loop

Test Your Implementation

Use this code to debug before submitting.

# 🧪 Debug
torch.manual_seed(0)
probs = torch.softmax(torch.randn(4, 10), dim=-1)
tokens = torch.tensor([2, 5, 1, 8])
print('Perfect draft:', speculative_decode(probs, probs, tokens))
target = torch.softmax(torch.randn(4, 10), dim=-1)
draft = torch.softmax(torch.randn(4, 10), dim=-1)
print('Random draft:', speculative_decode(target, draft, tokens))
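When reading the debug output, two things are worth checking. With a perfect draft the ratio is 1 at every position, so all K draft tokens should come back. With a random draft the result should be a prefix of the draft tokens followed by one final token. The sketch below encodes those structural rules in a hypothetical helper, check_output, which is not part of TorchCode. It works on plain lists, so with tensor inputs pass draft_tokens.tolist().

```python
def check_output(accepted, draft_tokens, vocab_size):
    """Return True if `accepted` has a shape the algorithm allows (hypothetical helper)."""
    k = len(draft_tokens)
    n = len(accepted)
    # The function returns between 1 and K tokens.
    if not 1 <= n <= k:
        return False
    # Every returned token must be a valid vocabulary index.
    if any(not 0 <= tok < vocab_size for tok in accepted):
        return False
    # Every token before the last one must be an accepted draft token.
    return list(accepted[:-1]) == list(draft_tokens[: n - 1])

draft = [2, 5, 1, 8]
print(check_output([2, 5, 1, 8], draft, 10))  # True
print(check_output([2, 3], draft, 10))        # True
print(check_output([3, 5], draft, 10))        # False
```

The last token is deliberately left unconstrained: it may be an accepted draft token or a resampled replacement.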

Reference Solution

Try solving it yourself first!

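# ✅ SOLUTION
import torch

def speculative_decode(target_probs, draft_probs, draft_tokens):
    K = len(draft_tokens)
    accepted = []
    for i in range(K):
        t = draft_tokens[i].item()
        # Guard against division by zero when the draft assigns ~0 probability.
        ratio = target_probs[i, t] / max(draft_probs[i, t].item(), 1e-10)
        if torch.rand(1).item() < min(1.0, ratio.item()):
            accepted.append(t)  # accepted: keep the draft token and continue
        else:
            # Rejected: resample from the normalized positive part of (target - draft).
            adjusted = torch.clamp(target_probs[i] - draft_probs[i], min=0)
            s = adjusted.sum()
            if s > 0:
                adjusted = adjusted / s
            else:
                # Degenerate case (no positive mass): fall back to uniform.
                adjusted = torch.ones_like(adjusted) / adjusted.shape[0]
            accepted.append(torch.multinomial(adjusted, 1).item())
            return accepted  # stop at the first rejection
    return accepted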
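The reason for resampling from the residual on rejection is that the emitted token then follows the target distribution, even though every proposal came from the draft. The simulation below is a rough empirical check of that property for a single position. It is a sketch in plain Python, not the torch solution: accept_or_resample is a hypothetical helper and the probabilities are made up.

```python
import random

def accept_or_resample(target_row, draft_row, token, rng):
    """One accept/reject step on plain lists (hypothetical helper)."""
    ratio = target_row[token] / max(draft_row[token], 1e-10)
    if rng.random() < min(1.0, ratio):
        return token
    residual = [max(0.0, t - d) for t, d in zip(target_row, draft_row)]
    if sum(residual) <= 0:
        residual = [1.0] * len(residual)  # same uniform fallback as the solution
    # random.choices accepts unnormalized weights.
    return rng.choices(range(len(residual)), weights=residual)[0]

rng = random.Random(0)
target_row = [0.5, 0.3, 0.2]
draft_row = [0.2, 0.6, 0.2]
n_trials = 100_000
counts = [0] * len(target_row)
for _ in range(n_trials):
    # The draft model proposes a token from its own distribution ...
    proposal = rng.choices(range(len(draft_row)), weights=draft_row)[0]
    # ... and the accept/reject step decides what is actually emitted.
    counts[accept_or_resample(target_row, draft_row, proposal, rng)] += 1

freqs = [c / n_trials for c in counts]
print(freqs)  # should be close to target_row, up to sampling noise
```

If the resampling step drew from the target directly instead of from the residual, the empirical frequencies would drift toward tokens the draft already over-proposes.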

Tips

Run Locally

For interactive practice with auto-grading, run TorchCode locally:
pip install torch-judge, then use check("speculative_decoding")

Key Concepts

Accept/reject, draft model acceleration
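How much acceleration a draft model buys depends on how often its proposals are accepted. For a single position, the probability of accepting a token sampled from the draft works out to the overlap between the two distributions. The helper below, acceptance_rate, is a hypothetical name and the numbers are illustrative only.

```python
def acceptance_rate(target_row, draft_row):
    """Probability that a token sampled from draft_row is accepted (hypothetical helper)."""
    # sum_x q(x) * min(1, p(x) / q(x)) simplifies to sum_x min(p(x), q(x)).
    return sum(min(t, d) for t, d in zip(target_row, draft_row))

# Identical distributions: every proposal is accepted.
same = acceptance_rate([0.5, 0.3, 0.2], [0.5, 0.3, 0.2])
# Mismatched distributions: only the overlapping mass is accepted.
mixed = acceptance_rate([0.5, 0.3, 0.2], [0.2, 0.6, 0.2])

print(same)   # approximately 1.0
print(mixed)  # approximately 0.7
```

A draft that closely tracks the target keeps this overlap near 1, so more draft tokens survive per call to the large model.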
