
Mixture of Experts

Mixtral-style, top-k routing, expert MLPs

Difficulty: Hard · Architecture

Problem Description

Implement a Mixture of Experts layer (Mixtral / Switch Transformer style).

Signature

```python
class MixtureOfExperts(nn.Module):
    def __init__(self, d_model, d_ff, num_experts, top_k=2): ...
    def forward(self, x: Tensor) -> Tensor: ...  # x: (B, S, D) -> (B, S, D)
```

Architecture

• self.router: nn.Linear(d_model, num_experts), the gating network

• self.experts: nn.ModuleList of MLPs (Linear → ReLU → Linear)

• For each token: select the top-k experts and compute a weighted sum of their outputs
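The routing step above can be sketched on its own, before wiring in any experts. This is a minimal illustration (shapes and seed are arbitrary): the router scores each token against every expert, topk keeps the best two, and a softmax over just those two scores renormalizes the weights so they sum to 1 per token.

```python
import torch

torch.manual_seed(0)
num_tokens, d_model, num_experts, top_k = 5, 8, 4, 2

router = torch.nn.Linear(d_model, num_experts)
x = torch.randn(num_tokens, d_model)

logits = router(x)                              # (5, 4): one score per expert per token
top_vals, top_idx = logits.topk(top_k, dim=-1)  # (5, 2): best two experts per token
weights = torch.softmax(top_vals, dim=-1)       # renormalize over the selected experts

print(top_idx)          # which experts each token is routed to
print(weights.sum(-1))  # each row sums to 1.0
```

Note that the softmax is taken over the top-k logits only, not over all experts, so the kept weights form a proper convex combination.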

Template

Implement the class below. Use only basic PyTorch operations.

```python
import torch
import torch.nn as nn

# ✏️ YOUR IMPLEMENTATION HERE
class MixtureOfExperts(nn.Module):
    def __init__(self, d_model, d_ff, num_experts, top_k=2):
        super().__init__()
        pass  # router + experts

    def forward(self, x):
        pass  # route tokens to top-k experts
```

Test Your Implementation

Use this code to debug before submitting.

```python
# 🧪 Debug
import torch

moe = MixtureOfExperts(32, 64, num_experts=4, top_k=2)
x = torch.randn(2, 8, 32)
print('Output:', moe(x).shape)
print('Params:', sum(p.numel() for p in moe.parameters()))
```

Reference Solution

Try solving it yourself first! Click below to reveal the solution.

```python
# ✅ SOLUTION
import torch
import torch.nn as nn

class MixtureOfExperts(nn.Module):
    def __init__(self, d_model, d_ff, num_experts, top_k=2):
        super().__init__()
        self.top_k = top_k
        self.router = nn.Linear(d_model, num_experts)
        self.experts = nn.ModuleList([
            nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))
            for _ in range(num_experts)
        ])

    def forward(self, x):
        orig_shape = x.shape
        if x.dim() == 3:
            B, S, D = x.shape
            x_flat = x.reshape(-1, D)          # flatten tokens: (B*S, D)
        else:
            x_flat = x
        logits = self.router(x_flat)                         # (N, num_experts)
        top_vals, top_idx = logits.topk(self.top_k, dim=-1)  # (N, top_k)
        weights = torch.softmax(top_vals, dim=-1)            # renormalize over top-k
        output = torch.zeros_like(x_flat)
        # Dispatch: for each (slot k, expert e), run the expert only on the
        # tokens whose k-th choice is e, and accumulate the weighted result.
        for k in range(self.top_k):
            for e in range(len(self.experts)):
                mask = (top_idx[:, k] == e)
                if mask.any():
                    output[mask] += weights[mask, k:k+1] * self.experts[e](x_flat[mask])
        return output.reshape(orig_shape)
```
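Not part of the graded exercise, but since the description mentions Switch Transformer style routing: in practice such routers are trained with an auxiliary load-balancing loss, which discourages all tokens from collapsing onto a few experts. A minimal sketch of the Switch-style loss (num_experts times the sum over experts of the dispatch fraction times the mean router probability; the function name is ours):

```python
import torch

def load_balancing_loss(logits: torch.Tensor) -> torch.Tensor:
    """Switch Transformer style auxiliary loss; logits: (num_tokens, num_experts)."""
    num_experts = logits.size(-1)
    probs = torch.softmax(logits, dim=-1)      # (T, E) full router distribution
    top1 = probs.argmax(dim=-1)                # expert chosen per token (top-1)
    # f_i: fraction of tokens dispatched to expert i
    f = torch.bincount(top1, minlength=num_experts).float() / logits.size(0)
    # P_i: mean router probability assigned to expert i
    P = probs.mean(dim=0)
    return num_experts * (f * P).sum()         # minimized when routing is uniform

loss = load_balancing_loss(torch.randn(64, 4))
print(loss)
```

During training this scalar is typically added to the task loss with a small coefficient so the router spreads load without dominating the gradient.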

Tips

Run Locally

For interactive practice with auto-grading, run TorchCode locally:
pip install torch-judge, then call check("moe")

Key Concepts

Mixtral-style, top-k routing, expert MLPs
