
RMSNorm

LLaMA-style norm, simpler than LayerNorm

Medium Fundamentals

Problem Description

Implement Root Mean Square Layer Normalization, the normalization used in LLaMA, Gemma, etc.

$$\text{RMSNorm}(x) = \frac{x}{\text{RMS}(x)} \cdot w, \quad \text{RMS}(x) = \sqrt{\frac{1}{d}\sum_{i=1}^{d} x_i^2 + \epsilon}$$
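
Worked through by hand on a tiny vector (values chosen here purely for easy arithmetic), the formula looks like this in plain Python:

```python
import math

# Apply the RMSNorm formula to x = [3, 4] with unit weights.
x = [3.0, 4.0]
w = [1.0, 1.0]
eps = 1e-6
d = len(x)

rms = math.sqrt(sum(v * v for v in x) / d + eps)  # sqrt((9 + 16) / 2) ~= 3.5355
out = [v / rms * wi for v, wi in zip(x, w)]       # ~= [0.8485, 1.1314]

# The normalized output itself has RMS ~= 1 (up to eps).
out_rms = math.sqrt(sum(v * v for v in out) / d)
print(rms, out, out_rms)
```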

Signature

```python
def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Normalize over the last dimension. No mean subtraction (unlike LayerNorm).
    ...
```

Rules

• Do NOT use any built-in norm layers

• Normalize over dim=-1

• Must support autograd
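
The autograd rule can be checked with a quick sketch. The inline `rms_norm` here is only a stand-in so the snippet runs on its own; substitute your implementation:

```python
import torch

# Stand-in implementation so the snippet is self-contained; replace with yours.
def rms_norm(x, weight, eps=1e-6):
    rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x / rms * weight

x = torch.randn(2, 8, requires_grad=True)
w = torch.ones(8, requires_grad=True)
rms_norm(x, w).sum().backward()

# Gradients reach both inputs only if the forward pass used differentiable ops.
assert x.grad is not None and w.grad is not None
```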

Template

Implement the function below. Use only basic PyTorch operations.

# โœ๏ธ YOUR IMPLEMENTATION HERE def rms_norm(x, weight, eps=1e-6): pass # Replace this

Test Your Implementation

Use this code to debug before submitting.

```python
# 🧪 Debug
import torch

x = torch.randn(2, 8)
w = torch.ones(8)
out = rms_norm(x, w)
print("Output shape:", out.shape)
print("RMS of output:", out.pow(2).mean(dim=-1).sqrt())  # should be ~1
```

Reference Solution

Try solving it yourself before reading the solution below.

```python
# ✅ SOLUTION
import torch

def rms_norm(x, weight, eps=1e-6):
    rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x / rms * weight
```
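
As a sanity check, the same computation can be compared against PyTorch's built-in `torch.nn.functional.rms_norm`. Note the built-in only exists in newer PyTorch releases (roughly 2.4+), hence the `hasattr` guard:

```python
import torch
import torch.nn.functional as F

def rms_norm(x, weight, eps=1e-6):
    rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x / rms * weight

x = torch.randn(4, 16)
w = torch.randn(16)
ours = rms_norm(x, w)

# F.rms_norm was added in recent PyTorch versions; guard so the check
# degrades gracefully on older installs.
if hasattr(F, "rms_norm"):
    ref = F.rms_norm(x, normalized_shape=(16,), weight=w, eps=1e-6)
    assert torch.allclose(ours, ref, atol=1e-5)
```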

Tips

Run Locally

For interactive practice with auto-grading, run TorchCode locally:
`pip install torch-judge`, then call `check("rmsnorm")`.
