
Top-k / Top-p Sampling

Nucleus sampling, temperature scaling

Medium Inference

Problem Description

Implement sampling with top-k and top-p filtering, a standard LLM decoding strategy.

Signature

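def sample_top_k_top_p(logits, top_k=0, top_p=1.0, temperature=1.0) -> int:
    # logits: (V,) unnormalized log-probabilities
    # top_k: number of highest logits to keep (0 disables top-k filtering)
    # top_p: cumulative probability threshold (1.0 disables top-p filtering)
    # Returns: sampled token index
    ...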

Algorithm

1. Scale by temperature: logits /= temperature

2. Top-k: keep only top-k logits, set rest to -inf

3. Top-p: sort by probability (descending) and mask each token whose preceding cumulative probability exceeds p, so the token that crosses p is still kept

4. Sample from filtered distribution
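The four steps can be traced on a small example. The sketch below is only an illustration, assuming a 1-D logits tensor; the helper name exclusive_cumsum_mask and the toy probabilities are hypothetical and not part of the problem's API. It shows how temperature reshapes the distribution and why the top-p mask uses the cumulative probability before each token rather than the one including it.

```python
import torch

def exclusive_cumsum_mask(sorted_probs: torch.Tensor, top_p: float) -> torch.Tensor:
    """Hypothetical helper: True marks tokens to drop in step 3.

    sorted_probs must be sorted in descending order. A token is dropped when
    the probability mass of the tokens ranked before it already exceeds top_p.
    """
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    return (cumulative - sorted_probs) > top_p

logits = torch.tensor([1.0, 5.0, 2.0, 0.5])

# Step 1: a temperature below 1 sharpens the distribution, above 1 flattens it.
for temperature in (0.5, 1.0, 2.0):
    probs = torch.softmax(logits / temperature, dim=-1)
    print(f"T={temperature}: {probs.tolist()}")

# Step 2: top-k with k=2 keeps the two largest logits and sets the rest to -inf.
kth_value = logits.topk(2).values[-1]
top_k_logits = logits.masked_fill(logits < kth_value, float("-inf"))
print("top-k logits:", top_k_logits.tolist())

# Step 3: top-p on an already sorted toy distribution.
sorted_probs = torch.tensor([0.5, 0.3, 0.15, 0.05])
drop = exclusive_cumsum_mask(sorted_probs, top_p=0.7)
print("dropped:", drop.tolist())  # the first two tokens are kept

# Step 4: sample from whatever survived the filters (here, the top-k result).
sample = torch.multinomial(torch.softmax(top_k_logits, dim=-1), 1).item()
print("sample:", sample)
```

With top_p=0.7 the second token is kept even though the running total reaches 0.8 once it is included: only the mass ranked ahead of it (0.5) is compared against the threshold.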

Template

Implement the function below. Use only basic PyTorch operations.

# ✏️ YOUR IMPLEMENTATION HERE
def sample_top_k_top_p(logits, top_k=0, top_p=1.0, temperature=1.0):
    pass  # temperature, top-k filter, top-p filter, sample

Test Your Implementation

Use this code to debug before submitting.

# 🧪 Debug
import torch

logits = torch.tensor([1.0, 5.0, 2.0, 0.5])
print('top_k=1:', sample_top_k_top_p(logits.clone(), top_k=1))
print('top_p=0.5:', sample_top_k_top_p(logits.clone(), top_p=0.5))
print('temp=0.01:', sample_top_k_top_p(logits.clone(), temperature=0.01))
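For these debug logits the result of each call is predictable: the softmax of [1.0, 5.0, 2.0, 0.5] puts roughly 0.93 of the probability mass on index 1, so top_k=1, top_p=0.5, and temperature=0.01 should each return 1. A minimal sketch of a repeated check follows, assuming the function under test uses the signature given above; check_debug_cases is a hypothetical helper and not part of torch-judge.

```python
import torch

def check_debug_cases(sample_fn, n_draws: int = 200) -> bool:
    """Hypothetical helper: run the three debug settings repeatedly.

    sample_fn is expected to follow the sample_top_k_top_p signature.
    Returns True if every draw for every setting is index 1.
    """
    logits = torch.tensor([1.0, 5.0, 2.0, 0.5])
    settings = [
        {"top_k": 1},
        {"top_p": 0.5},
        {"temperature": 0.01},
    ]
    for kwargs in settings:
        # Collect the distinct token indices seen across all draws.
        draws = {sample_fn(logits.clone(), **kwargs) for _ in range(n_draws)}
        if draws != {1}:
            print(f"unexpected samples for {kwargs}: {sorted(draws)}")
            return False
    return True
```

Once an implementation is defined, calling check_debug_cases(sample_top_k_top_p) should return True.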

Reference Solution

Try solving it yourself first! The solution follows.

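# ✅ SOLUTION
import torch

def sample_top_k_top_p(logits, top_k=0, top_p=1.0, temperature=1.0):
    # 1. Temperature scaling (clamped to avoid division by zero)
    logits = logits / max(temperature, 1e-8)
    # 2. Top-k: mask everything below the k-th largest logit
    if top_k > 0:
        top_k = min(top_k, logits.size(-1))  # topk raises if k exceeds the vocabulary size
        top_k_val = logits.topk(top_k).values[-1]
        logits[logits < top_k_val] = float('-inf')
    # 3. Top-p: mask tokens whose preceding cumulative probability exceeds top_p
    if top_p < 1.0:
        sorted_logits, sorted_idx = torch.sort(logits, descending=True)
        probs = torch.softmax(sorted_logits, dim=-1)
        cumsum = torch.cumsum(probs, dim=-1)
        mask = (cumsum - probs) > top_p  # exclusive cumulative sum, so the top token is always kept
        sorted_logits[mask] = float('-inf')
        # Undo the sort so positions match the original vocabulary order
        logits = torch.empty_like(logits).scatter_(0, sorted_idx, sorted_logits)
    # 4. Sample from the renormalized distribution
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, 1).item()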
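One design detail of thresholding on the k-th largest value: every logit tied with that value is kept, so more than k tokens can survive when values repeat. The sketch below contrasts this with an index-based variant that keeps exactly k entries. Both function names are hypothetical and used only for illustration; which behavior is preferable depends on the use case.

```python
import torch

def top_k_by_value(logits: torch.Tensor, k: int) -> torch.Tensor:
    """Threshold on the k-th largest value (entries tied with it are all kept)."""
    kth_value = logits.topk(k).values[-1]
    return logits.masked_fill(logits < kth_value, float("-inf"))

def top_k_by_index(logits: torch.Tensor, k: int) -> torch.Tensor:
    """Hypothetical variant: keep exactly k entries, chosen by topk's indices."""
    values, indices = logits.topk(k)
    filtered = torch.full_like(logits, float("-inf"))
    return filtered.scatter(0, indices, values)

tied = torch.tensor([2.0, 2.0, 1.0])
print(top_k_by_value(tied, 1).tolist())  # both tied entries stay finite
print(top_k_by_index(tied, 1).tolist())  # exactly one entry stays finite
```

For real-valued model logits exact ties are uncommon, so the two variants usually agree; the difference mainly shows up with quantized or hand-written inputs.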

Tips

Run Locally

For interactive practice with auto-grading, run TorchCode locally:
install it with pip install torch-judge, then use check("topk_sampling").

Key Concepts

Nucleus sampling, temperature scaling
