
Beam Search

Hypothesis expansion, pruning, eos handling

Difficulty: Medium | Topic: Inference

Problem Description

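Implement beam search, the classic decoding algorithm for sequence generation. Greedy decoding commits to the single most likely token at each step. Beam search instead keeps the beam_width highest-scoring partial sequences (hypotheses) and extends them in parallel.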
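Formally, beam search tries to maximize the total log-probability of the generated tokens given the start token $y_0$ (start_token). Scoring all $V^T$ sequences of length $T$ is intractable, so beam search approximates the maximization. It keeps only the beam_width best partial sequences at each step, so it is not guaranteed to find the true maximizer:

$$
\hat{y} = \arg\max_{y_1, \dots, y_T} \sum_{t=1}^{T} \log p\left(y_t \mid y_0, \dots, y_{t-1}\right)
$$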

Signature

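def beam_search(log_prob_fn, start_token, max_len, beam_width, eos_token) -> list[int]:
    # log_prob_fn: takes the current token sequence (the reference solution passes a 1-D tensor)
    #              and returns a (V,) tensor of next-token log-probabilities
    # Returns: the best sequence found, as a list of ints (starting with start_token)
    ...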
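To make the log_prob_fn contract concrete, here is a minimal sketch of a toy model that satisfies it. It is a bigram model whose next-token distribution depends only on the last token. The helper name make_bigram_log_prob_fn and the random logits are illustrative assumptions; they are not part of the problem or of torch-judge.

```python
import torch

def make_bigram_log_prob_fn(logits: torch.Tensor):
    """Hypothetical toy model: the next-token distribution depends only on the last token.

    logits: (V, V) tensor; row i holds unnormalized scores for the token after token i.
    """
    def log_prob_fn(tokens):
        last = int(tokens[-1])  # works for a Python list or a 1-D tensor
        return logits[last].log_softmax(dim=-1)  # (V,) log-probabilities
    return log_prob_fn

torch.manual_seed(0)
V = 5
fn = make_bigram_log_prob_fn(torch.randn(V, V))
lp = fn(torch.tensor([0, 3]))
print(lp.shape)  # torch.Size([5])
```

Because the output is a log_softmax, exponentiating lp gives a distribution that sums to 1. The function accepts either a list or a tensor, so it works with implementations that pass the sequence in either form.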

Algorithm

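1. Start with a single beam, [(0.0, [start_token])]: a score of 0.0 (log 1) and a sequence holding only start_token.

2. At each step, expand every unfinished beam with its top beam_width next tokens, adding each token's log-probability to the beam's score.

3. Keep the top beam_width candidates by total (summed) log-probability.

4. Set aside beams that end with eos_token as finished. Stop when no unfinished beams remain or after max_len steps, then return the highest-scoring sequence among the finished and unfinished beams.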
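The following worked example illustrates steps 2 and 3 with beam_width = 2. The numbers are made up for illustration. Two beams, a and b, have scores -0.4 and -1.1, and each is expanded with its top-2 next-token log-probabilities:

$$
\begin{aligned}
\text{beam } a\ (-0.4):&\quad -0.4 + (-1.6) = -2.0, \qquad -0.4 + (-2.0) = -2.4 \\
\text{beam } b\ (-1.1):&\quad -1.1 + (-0.1) = -1.2, \qquad -1.1 + (-2.3) = -3.4
\end{aligned}
$$

Keeping the top 2 of the four candidates retains -1.2 (from b) and -2.0 (from a). The beam that was behind after the previous step is now in front. Greedy decoding would have discarded b at the previous step, so it cannot recover this.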

Template

Implement the function below. Use only basic PyTorch operations.

# โœ๏ธ YOUR IMPLEMENTATION HERE def beam_search(log_prob_fn, start_token, max_len, beam_width, eos_token): pass # maintain beams, expand, prune, return best

Test Your Implementation

Use this code to debug before submitting.

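# 🧪 Debug
import torch

def simple_fn(tokens):
    # Toy model: log-prob 0.0 for token min(len(tokens), 4), -10.0 for every other token
    lp = torch.full((5,), -10.0)
    lp[min(len(tokens), 4)] = 0.0
    return lp

seq = beam_search(simple_fn, start_token=0, max_len=5, beam_width=2, eos_token=4)
print('Sequence:', seq)  # the reference solution prints: Sequence: [0, 1, 2, 3, 4]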
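The debug model above is deterministic, so greedy decoding would pass it too. A sketch of a harder check follows. It uses a hypothetical 4-token model (0 = start, 1 = "A", 2 = "B", 3 = eos) where greedy decoding is suboptimal. Greedy picks A (0.6) and then eos (0.5), for a total probability of 0.30. B followed by eos has 0.4 × 0.9 = 0.36. A correct beam_search(toy_log_prob_fn, 0, 5, 2, 3) should therefore return [0, 2, 3], while beam_width=1 reduces to greedy and returns [0, 1, 3]. The names TABLE, toy_log_prob_fn, greedy_decode, and sequence_score are illustrative, not part of the problem.

```python
import math
import torch

# Hypothetical toy model over V = 4 tokens: 0 = start, 1 = "A", 2 = "B", 3 = eos.
# Keys are prefixes; values are next-token probabilities.
TABLE = {
    (0,): [0.0, 0.6, 0.4, 0.0],
    (0, 1): [0.0, 0.25, 0.25, 0.5],
    (0, 2): [0.0, 0.05, 0.05, 0.9],
}
EOS = 3

def toy_log_prob_fn(tokens):
    prefix = tuple(int(t) for t in tokens)
    # Unknown prefixes put all their mass on eos
    probs = TABLE.get(prefix, [0.0, 0.0, 0.0, 1.0])
    # Clamp so that log(0) becomes a large negative number instead of -inf
    return torch.tensor(probs).clamp_min(1e-9).log()

def greedy_decode(log_prob_fn, start_token, max_len, eos_token):
    seq, score = [start_token], 0.0
    for _ in range(max_len):
        lp = log_prob_fn(torch.tensor(seq))
        tok = int(lp.argmax())  # commit to the single most likely token
        score += float(lp[tok])
        seq.append(tok)
        if tok == eos_token:
            break
    return score, seq

def sequence_score(log_prob_fn, seq):
    # Sum of log p(y_t | y_<t) over every token after the start token
    return sum(float(log_prob_fn(torch.tensor(seq[:t]))[seq[t]]) for t in range(1, len(seq)))

greedy_score, greedy_seq = greedy_decode(toy_log_prob_fn, 0, 5, EOS)
print(greedy_seq, round(math.exp(greedy_score), 2))  # [0, 1, 3] 0.3
print([0, 2, 3], round(math.exp(sequence_score(toy_log_prob_fn, [0, 2, 3])), 2))  # [0, 2, 3] 0.36
```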

Reference Solution

Try solving it yourself first!

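# ✅ SOLUTION
import torch

def beam_search(log_prob_fn, start_token, max_len, beam_width, eos_token):
    beams = [(0.0, [start_token])]  # (cumulative log-prob, token list)
    completed = []                  # beams that have emitted eos_token
    for _ in range(max_len):
        candidates = []
        for score, seq in beams:
            # Finished beams are set aside instead of being expanded
            if seq[-1] == eos_token:
                completed.append((score, seq))
                continue
            log_probs = log_prob_fn(torch.tensor(seq))
            # Expand with the top-k next tokens (k capped at the vocabulary size)
            k = min(beam_width, log_probs.numel())
            topk_lp, topk_idx = log_probs.topk(k)
            for j in range(k):
                candidates.append((score + topk_lp[j].item(), seq + [topk_idx[j].item()]))
        if not candidates:  # every beam has finished
            break
        # Prune: keep the beam_width highest-scoring candidates
        candidates.sort(key=lambda x: x[0], reverse=True)
        beams = candidates[:beam_width]
    # Unfinished beams still compete with finished ones if max_len was reached
    all_seqs = completed + beams
    all_seqs.sort(key=lambda x: x[0], reverse=True)
    return all_seqs[0][1]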
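The reference solution loops over beams in Python and then sorts all candidates. Every global top-beam_width candidate is necessarily among its own beam's top beam_width tokens. A single topk over the flattened (B, V) matrix of candidate scores therefore keeps the same survivors, up to ties. The following is a sketch of that vectorized expand-and-prune step using only basic PyTorch operations. vectorized_step is a hypothetical helper, and it does not handle eos: finished beams would still need to be set aside as in the reference solution.

```python
import torch

def vectorized_step(beam_scores, step_log_probs, beam_width):
    """Hypothetical helper: expand and prune all beams with a single topk.

    beam_scores:    (B,) cumulative log-probabilities of the current beams
    step_log_probs: (B, V) next-token log-probabilities, one row per beam
    Returns (new_scores, parent_idx, token_idx), each of shape (k,).
    """
    B, V = step_log_probs.shape
    # Broadcast-add: score of every (beam, next-token) candidate
    cand = beam_scores.unsqueeze(1) + step_log_probs  # (B, V)
    k = min(beam_width, B * V)
    new_scores, flat_idx = cand.flatten().topk(k)
    # Recover which beam each survivor extends and which token it appends
    parent_idx = torch.div(flat_idx, V, rounding_mode="floor")
    token_idx = flat_idx % V
    return new_scores, parent_idx, token_idx

beam_scores = torch.tensor([0.0, -0.5])
step_log_probs = torch.log(torch.tensor([[0.1, 0.2, 0.7],
                                          [0.5, 0.3, 0.2]]))
new_scores, parents, tokens = vectorized_step(beam_scores, step_log_probs, beam_width=2)
print(parents.tolist(), tokens.tolist())  # [0, 1] [2, 0]
```

The new token lists would then be built as [beams[p] + [t] for p, t in zip(parents.tolist(), tokens.tolist())]. Here the second survivor extends beam 1, even though beam 1 started behind beam 0.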

Tips

Run Locally

For interactive practice with auto-grading, run TorchCode locally:
pip install torch-judge, then use check("beam_search") to grade your implementation.

Key Concepts

Hypothesis expansion, pruning, eos handling
