
Gradient Clipping

Norm-based clipping, direction preservation

Easy Fundamentals

Problem Description

Implement gradient norm clipping, a training stability technique.

Signature

```python
def clip_grad_norm(parameters, max_norm: float) -> float:
    # Clip gradients in-place so total norm <= max_norm
    # Returns the original (unclipped) total norm
```

Algorithm

1. Compute total norm: sqrt(sum(p.grad.norm()^2 for p in parameters))

2. If total > max_norm: scale all grads by max_norm / total

3. Return original total norm
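
The three steps above can be walked through in plain Python, independent of PyTorch. This is a sketch of the arithmetic only, using a hypothetical two-element gradient:

```python
import math

# Hypothetical gradient vector and clipping threshold
grad = [3.0, 4.0]
max_norm = 1.0

# Step 1: total norm = sqrt(sum of squared entries) = sqrt(9 + 16) = 5
total = math.sqrt(sum(g * g for g in grad))

# Step 2: norm exceeds max_norm, so scale every entry by max_norm / total
if total > max_norm:
    grad = [g * max_norm / total for g in grad]

# Step 3: report the original (pre-clip) norm
print(grad)   # [0.6, 0.8] -- the clipped gradient's norm is exactly max_norm
print(total)  # 5.0
```

Note that when the norm is already below `max_norm`, step 2 does nothing: gradients are only ever scaled down, never up.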

Template

Implement the function below. Use only basic PyTorch operations.

# โœ๏ธ YOUR IMPLEMENTATION HERE def clip_grad_norm(parameters, max_norm): pass # compute total norm, clip if needed, return original norm

Test Your Implementation

Use this code to debug before submitting.

```python
# 🧪 Debug
import torch

p = torch.randn(100, requires_grad=True)
(p * 10).sum().backward()
print('Before:', p.grad.norm().item())
orig = clip_grad_norm([p], max_norm=1.0)
print('After: ', p.grad.norm().item())
print('Original norm:', orig)
```

Reference Solution

Try solving it yourself before reading the reference solution.

```python
# ✅ SOLUTION
import torch

def clip_grad_norm(parameters, max_norm):
    # Skip parameters that have no gradient yet
    parameters = [p for p in parameters if p.grad is not None]
    # Total norm across all parameters: sqrt of summed squared per-tensor norms
    total_norm = torch.sqrt(sum(p.grad.norm() ** 2 for p in parameters))
    # Epsilon guards against division by zero when the norm is ~0
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1:
        for p in parameters:
            p.grad.mul_(clip_coef)
    return total_norm.item()
```

This mirrors PyTorch's built-in `torch.nn.utils.clip_grad_norm_`.

Tips

Run Locally

For interactive practice with auto-grading, run TorchCode locally:
`pip install torch-judge`, then use `check("gradient_clipping")`.

Key Concepts

Norm-based clipping, direction preservation
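
Direction preservation is the point of norm-based clipping: every component is multiplied by the same scalar, so the gradient's direction is unchanged, only its length shrinks. A small sketch with a hypothetical vector and no PyTorch dependency:

```python
import math

def cosine(a, b):
    # Cosine similarity between two vectors
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

grad = [3.0, 4.0, 12.0]           # norm = 13
scale = 1.0 / 13.0                # clip to max_norm = 1
clipped = [g * scale for g in grad]

print(cosine(grad, clipped))      # close to 1.0: direction is unchanged
```

Contrast this with element-wise value clipping, which clamps each component independently and can rotate the gradient.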
