
BatchNorm

Batch vs layer statistics, train/eval behavior

Medium Fundamentals

Problem Description

Implement Batch Normalization with both training and inference behavior.

In training mode, use batch statistics and update running estimates:

$$\text{BN}(x) = \gamma \cdot \frac{x - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}} + \beta$$

where $\mu_B$ and $\sigma_B^2$ are the mean and variance computed across the batch (dim=0).

In inference mode, use the provided running mean/var instead of current batch stats.
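To make the train/inference distinction concrete, the sketch below normalizes the same batch both ways using plain tensor ops ($\gamma = 1$, $\beta = 0$ omitted for brevity). With freshly initialized running stats, the two outputs differ; only the training path is guaranteed to produce per-feature mean ~0 and variance ~1:

```python
import torch

torch.manual_seed(0)
x = torch.randn(8, 4)
eps = 1e-5

# Training: normalize with the current batch's statistics
mu_b = x.mean(dim=0)
var_b = x.var(dim=0, unbiased=False)
out_train = (x - mu_b) / torch.sqrt(var_b + eps)

# Inference: normalize with the (here freshly initialized) running statistics
running_mean = torch.zeros(4)
running_var = torch.ones(4)
out_eval = (x - running_mean) / torch.sqrt(running_var + eps)

print(torch.allclose(out_train, out_eval))  # False: batch stats != running stats
```

As training progresses the running estimates converge toward the data's true statistics, which is what makes the inference path a sensible stand-in for batch statistics at test time.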

Signature

```python
def my_batch_norm(
    x: torch.Tensor,
    gamma: torch.Tensor,
    beta: torch.Tensor,
    running_mean: torch.Tensor,
    running_var: torch.Tensor,
    eps: float = 1e-5,
    momentum: float = 0.1,
    training: bool = True,
) -> torch.Tensor:
    # x: (N, D) — normalize each feature across all samples in the batch
    # running_mean, running_var: updated in-place during training; used as-is during inference
```

Rules

• Do NOT use F.batch_norm, nn.BatchNorm1d, etc.

• Compute batch mean and variance over dim=0 with unbiased=False

• Update running stats like PyTorch: running = (1 - momentum) * running + momentum * batch_stat

• Use running_mean / running_var for inference when training=False

• Must support autograd w.r.t. x, gamma, beta (running statistics should be treated as buffers, not as parameters that require gradients)
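While F.batch_norm is off-limits inside your solution, it works well as a test oracle for the normalization rule above. The sketch below confirms that manual normalization with biased (unbiased=False) batch statistics matches F.batch_norm in training mode:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(8, 4)
gamma, beta = torch.ones(4), torch.zeros(4)
eps = 1e-5

# Manual normalization with biased (unbiased=False) batch statistics
mean = x.mean(dim=0)
var = x.var(dim=0, unbiased=False)
manual = gamma * (x - mean) / torch.sqrt(var + eps) + beta

# F.batch_norm in training mode normalizes with the same biased batch stats
ref = F.batch_norm(x, torch.zeros(4), torch.ones(4), gamma, beta,
                   training=True, momentum=0.1, eps=eps)
print(torch.allclose(manual, ref, atol=1e-6))  # True
```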

Template

Implement the function below. Use only basic PyTorch operations.

```python
# ✏️ YOUR IMPLEMENTATION HERE
def my_batch_norm(
    x, gamma, beta,
    running_mean, running_var,
    eps=1e-5, momentum=0.1, training=True,
):
    pass  # Replace this
```

Test Your Implementation

Use this code to debug before submitting.

```python
# 🧪 Debug
x = torch.randn(8, 4)
gamma = torch.ones(4)
beta = torch.zeros(4)

# Running stats typically live on the same device and shape as features
running_mean = torch.zeros(4)
running_var = torch.ones(4)

# Training mode: uses batch stats and updates running_mean / running_var
out_train = my_batch_norm(x, gamma, beta, running_mean, running_var, training=True)
print("[Train] Output shape:", out_train.shape)
print("[Train] Column means:", out_train.mean(dim=0))  # should be ~0
print("[Train] Column stds: ", out_train.std(dim=0))   # should be ~1
print("Updated running_mean:", running_mean)
print("Updated running_var:", running_var)

# Inference mode: uses running_mean / running_var only
out_eval = my_batch_norm(x, gamma, beta, running_mean, running_var, training=False)
print("[Eval] Output shape:", out_eval.shape)
```

Reference Solution

Try solving it yourself first!

```python
# ✅ SOLUTION
import torch

def my_batch_norm(
    x, gamma, beta,
    running_mean, running_var,
    eps=1e-5, momentum=0.1, training=True,
):
    """BatchNorm with train/eval behavior and running stats.

    - Training: use batch stats, update running_mean / running_var in-place.
    - Inference: use running_mean / running_var as-is.
    """
    if training:
        batch_mean = x.mean(dim=0)
        batch_var = x.var(dim=0, unbiased=False)
        # Update running statistics in-place. Detach to avoid tracking gradients.
        running_mean.mul_(1 - momentum).add_(momentum * batch_mean.detach())
        running_var.mul_(1 - momentum).add_(momentum * batch_var.detach())
        mean = batch_mean
        var = batch_var
    else:
        mean = running_mean
        var = running_var

    x_norm = (x - mean) / torch.sqrt(var + eps)
    return gamma * x_norm + beta
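As a quick sanity check that gradients flow to x, gamma, and beta while the running buffers stay outside the autograd graph, the sketch below runs backward through a condensed copy of the solution (repeated here so the snippet runs standalone):

```python
import torch

def my_batch_norm(x, gamma, beta, running_mean, running_var,
                  eps=1e-5, momentum=0.1, training=True):
    # Condensed copy of the reference solution so this snippet runs standalone.
    if training:
        mean = x.mean(dim=0)
        var = x.var(dim=0, unbiased=False)
        running_mean.mul_(1 - momentum).add_(momentum * mean.detach())
        running_var.mul_(1 - momentum).add_(momentum * var.detach())
    else:
        mean, var = running_mean, running_var
    return gamma * (x - mean) / torch.sqrt(var + eps) + beta

x = torch.randn(8, 4, requires_grad=True)
gamma = torch.ones(4, requires_grad=True)
beta = torch.zeros(4, requires_grad=True)
rm, rv = torch.zeros(4), torch.ones(4)  # buffers: no requires_grad

out = my_batch_norm(x, gamma, beta, rm, rv, training=True)
out.sum().backward()

print(x.grad.shape, gamma.grad.shape, beta.grad.shape)
print(rm.requires_grad, rv.requires_grad)  # buffers stay outside the graph
```

Without the .detach() in the running-stat updates, the in-place ops would try to pull the batch statistics (which depend on x) into the graph, which is exactly the buffer-vs-parameter distinction the rules call out.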

Tips

Run Locally

For interactive practice with auto-grading, run TorchCode locally:
`pip install torch-judge`, then use `check("batchnorm")`

