
Kaiming Init

std = sqrt(2/fan_in), variance scaling

Easy Fundamentals

Problem Description

Implement Kaiming (He) normal initialization for weight tensors.

$$W \sim \mathcal{N}(0, \text{std}^2) \quad \text{where} \quad \text{std} = \sqrt{\frac{2}{\text{fan\_in}}}$$
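As a quick worked example of the formula, the sketch below evaluates the target standard deviation for a few fan-in values using only the standard library. The helper name `kaiming_std` is hypothetical and introduced here purely for illustration.

```python
import math


def kaiming_std(fan_in: int) -> float:
    """Target std for Kaiming normal init: sqrt(2 / fan_in).

    `kaiming_std` is an illustrative helper, not part of the problem's API.
    """
    if fan_in <= 0:
        raise ValueError("fan_in must be positive")
    return math.sqrt(2.0 / fan_in)


# Larger fan_in means a smaller std, so the summed input to each unit stays bounded.
for fan_in in (128, 512, 2048):
    print(f"fan_in={fan_in:5d}  std={kaiming_std(fan_in):.4f}")
```

For fan_in = 512, 2/512 = 1/256, so the std is 1/16 = 0.0625, which is the value the debug snippet under Test Your Implementation expects.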

Signature

```python
def kaiming_init(weight: Tensor) -> Tensor:
    # Initialize weight in-place with Kaiming normal
    # fan_in = weight.shape[1]
    # Returns the weight tensor
    ...
```

Template

Implement the function below. Use only basic PyTorch operations.

```python
# ✏️ YOUR IMPLEMENTATION HERE
def kaiming_init(weight):
    pass  # fill with normal(0, sqrt(2/fan_in))
```

Test Your Implementation

Use this code to debug before submitting.

```python
# 🧪 Debug
import math

import torch

w = torch.empty(256, 512)
kaiming_init(w)
print(f'Mean: {w.mean():.4f} (expect ~0)')
print(f'Std: {w.std():.4f} (expect {math.sqrt(2/512):.4f})')
```

Reference Solution

Try solving it yourself first, then compare with the solution below.

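```python
# ✅ SOLUTION
import math

import torch


def kaiming_init(weight):
    # fan_in is shape[1] for tensors with 2+ dims; fall back to shape[0] for 1-D tensors
    fan_in = weight.shape[1] if weight.dim() >= 2 else weight.shape[0]
    std = math.sqrt(2.0 / fan_in)
    # Fill in-place without recording the operation in autograd
    with torch.no_grad():
        weight.normal_(0, std)
    return weight
```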
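One way to sanity-check a hand-written version is to compare its target std with PyTorch's built-in `torch.nn.init.kaiming_normal_`. The sketch below assumes a 2-D weight, where the built-in's fan-in is also `weight.shape[1]`; the seed, shape, and variable names are arbitrary choices made for this illustration.

```python
import math

import torch

torch.manual_seed(0)  # arbitrary seed, only to make the run repeatable

w = torch.empty(256, 512)
# mode="fan_in" with nonlinearity="relu" uses gain sqrt(2), i.e. std = sqrt(2 / fan_in)
torch.nn.init.kaiming_normal_(w, mode="fan_in", nonlinearity="relu")

expected_std = math.sqrt(2.0 / w.shape[1])
measured_std = w.std().item()
measured_mean = w.mean().item()

print(f"expected std: {expected_std:.4f}")
print(f"measured std: {measured_std:.4f}")
print(f"measured mean: {measured_mean:.4f}")
```

Note that for convolution weights the built-in multiplies the input-channel count by the kernel size when computing fan-in, whereas this problem defines fan_in as `weight.shape[1]` only, so the two may differ for tensors with more than two dimensions.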

Tips

Run Locally

For interactive practice with auto-grading, run TorchCode locally:
install it with pip install torch-judge, then use check("weight_init").

Key Concepts

std = sqrt(2/fan_in), variance scaling
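To illustrate what variance scaling buys, the sketch below pushes random inputs through a stack of bias-free linear + ReLU layers and reports the root-mean-square activation at the end. It is a minimal experiment under stated assumptions, not part of the problem: the helper `forward_rms`, the depth, width, batch size, and seed are all hypothetical choices. Because ReLU zeroes roughly half of a symmetric pre-activation, a weight variance of 2/fan_in keeps the activation scale roughly constant, while 1/fan_in halves the second moment at every layer.

```python
import math

import torch


def forward_rms(weight_std: float, depth: int = 10, width: int = 512,
                batch: int = 256, seed: int = 0) -> float:
    """RMS activation after `depth` bias-free linear + ReLU layers.

    Illustrative helper; every layer is square (width x width), so fan_in == width.
    """
    gen = torch.Generator().manual_seed(seed)
    x = torch.randn(batch, width, generator=gen)  # unit-variance inputs
    for _ in range(depth):
        w = torch.empty(width, width).normal_(0.0, weight_std, generator=gen)
        x = torch.relu(x @ w.T)
    return x.pow(2).mean().sqrt().item()


width = 512
kaiming_rms = forward_rms(math.sqrt(2.0 / width), width=width)  # std = sqrt(2/fan_in)
smaller_rms = forward_rms(math.sqrt(1.0 / width), width=width)  # std = sqrt(1/fan_in)

print(f"std = sqrt(2/fan_in): final RMS activation = {kaiming_rms:.4f}")
print(f"std = sqrt(1/fan_in): final RMS activation = {smaller_rms:.4f}")
```

With the Kaiming std the final RMS should stay on the order of 1, while the smaller std shrinks it by roughly a factor of sqrt(2) per layer.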
