Nucleus sampling, temperature scaling
Medium · Inference
Implement sampling with top-k and top-p filtering, a standard decoding strategy for LLMs.
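1. Scale by temperature: logits /= temperature (T < 1 sharpens the distribution, T > 1 flattens it).
2. Top-k: keep only the k largest logits and set the rest to -inf.
3. Top-p: sort tokens by probability in descending order and mask every token after the cumulative probability first exceeds p. The token that crosses p is kept, so at least one token always survives.
4. Apply softmax to the filtered logits and sample from the resulting distribution.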
Implement the function below. Use only basic PyTorch operations.
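The four steps above can be sketched as a single PyTorch function. This is a minimal sketch, not the TorchCode reference solution. The name `sample_top_k_top_p`, its parameters, and the convention that `top_k <= 0` or `top_p >= 1.0` disables a filter are all hypothetical. For top-p it follows the common nucleus convention of keeping the token that first pushes the cumulative probability past p.

```python
from typing import Optional

import torch


def sample_top_k_top_p(
    logits: torch.Tensor,
    temperature: float = 1.0,
    top_k: int = 0,
    top_p: float = 1.0,
    generator: Optional[torch.Generator] = None,
) -> torch.Tensor:
    """Sample one token id per row from logits of shape (batch, vocab).

    Hypothetical signature: top_k <= 0 disables top-k, top_p >= 1.0 disables top-p.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive")

    # 1. Temperature scaling (out-of-place, so the caller's tensor is not modified).
    logits = logits / temperature

    # 2. Top-k: every logit below the k-th largest becomes -inf (ties with it are kept).
    if top_k > 0:
        k = min(top_k, logits.size(-1))
        kth_value = torch.topk(logits, k, dim=-1).values[..., -1:]
        logits = logits.masked_fill(logits < kth_value, float("-inf"))

    # 3. Top-p: sort descending and drop tokens once the mass *before* them exceeds p.
    if top_p < 1.0:
        sorted_logits, sorted_idx = torch.sort(logits, dim=-1, descending=True)
        sorted_probs = torch.softmax(sorted_logits, dim=-1)
        cum_probs = torch.cumsum(sorted_probs, dim=-1)
        # Subtracting each token's own probability shifts the cumsum right by one,
        # so the token that crosses p is kept and the top token is never removed.
        remove_sorted = (cum_probs - sorted_probs) > top_p
        # Map the mask from sorted order back to vocabulary order.
        remove = torch.zeros_like(remove_sorted).scatter(-1, sorted_idx, remove_sorted)
        logits = logits.masked_fill(remove, float("-inf"))

    # 4. Softmax over the surviving logits (masked ones get probability 0) and sample.
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1, generator=generator).squeeze(-1)
```

For example, `sample_top_k_top_p(torch.randn(2, 50), temperature=0.8, top_k=40, top_p=0.9)` returns a tensor of two token ids. Because top-k runs first, the top-p nucleus is computed over the renormalized top-k distribution rather than the full vocabulary.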
For interactive practice with auto-grading, run TorchCode locally: pip install torch-judge, then use check("topk_sampling").