
ViT Patch Embedding

Image to patches to linear projection

Medium · Architecture

Problem Description

Implement the patch embedding layer from Vision Transformer (ViT).

Signature

class PatchEmbedding(nn.Module):
    def __init__(self, img_size, patch_size, in_channels, embed_dim): ...

    def forward(self, x: Tensor) -> Tensor:
        # x: (B, C, H, W)
        # Returns: (B, num_patches, embed_dim)
        ...

Algorithm

1. Reshape image into non-overlapping patches: (B, C, H, W) → (B, N, C*P*P)

2. Project each patch: nn.Linear(C*P*P, embed_dim)

3. num_patches = (img_size // patch_size) ** 2
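The three steps above can be sketched in plain NumPy (an illustration only; the actual solution uses PyTorch, and the toy sizes and the random projection matrix here are assumptions for demonstration):

```python
import numpy as np

# Toy sizes (assumptions for illustration): 4x4 image, 2x2 patches, 3 channels
B, C, H, W, P = 2, 3, 4, 4, 2
n_h, n_w = H // P, W // P      # patches per side
num_patches = n_h * n_w        # step 3: (img_size // patch_size) ** 2 = 4

x = np.random.randn(B, C, H, W)

# Step 1: carve the image into non-overlapping P x P patches
patches = x.reshape(B, C, n_h, P, n_w, P)       # split H and W into (n, P)
patches = patches.transpose(0, 2, 4, 1, 3, 5)   # (B, n_h, n_w, C, P, P)
patches = patches.reshape(B, num_patches, C * P * P)

# Step 2: project each flattened patch with a (C*P*P, embed_dim) matrix
embed_dim = 8
W_proj = np.random.randn(C * P * P, embed_dim)
out = patches @ W_proj
print(out.shape)  # (2, 4, 8)
```

Note that the transpose puts the patch-grid axes (n_h, n_w) before the per-patch axes (C, P, P), so each row of the result is one fully flattened patch.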

Template

Implement the class below. Use only basic PyTorch operations.

# โœ๏ธ YOUR IMPLEMENTATION HERE class PatchEmbedding(nn.Module): def __init__(self, img_size, patch_size, in_channels, embed_dim): super().__init__() pass # self.num_patches, self.proj def forward(self, x): pass # reshape to patches, project

Test Your Implementation

Use this code to debug before submitting.

# 🧪 Debug
pe = PatchEmbedding(32, 8, 3, 64)
x = torch.randn(2, 3, 32, 32)
print('Output:', pe(x).shape)
print('Patches:', pe.num_patches)

Reference Solution

Try solving it yourself first! Click below to reveal the solution.

# ✅ SOLUTION
class PatchEmbedding(nn.Module):
    def __init__(self, img_size, patch_size, in_channels, embed_dim):
        super().__init__()
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        self.proj = nn.Linear(in_channels * patch_size * patch_size, embed_dim)

    def forward(self, x):
        B, C, H, W = x.shape
        p = self.patch_size
        n_h, n_w = H // p, W // p
        x = x.reshape(B, C, n_h, p, n_w, p)
        x = x.permute(0, 2, 4, 1, 3, 5).reshape(B, n_h * n_w, C * p * p)
        return self.proj(x)
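One way to sanity-check the reshape/permute order in the solution is to compare it against explicit per-patch slicing. The NumPy sketch below follows the same shape conventions; the helper names `patchify` and `patchify_naive` are illustrative, not part of the problem:

```python
import numpy as np

def patchify(x, p):
    """Vectorized patch extraction, mirroring the solution's reshape/permute."""
    B, C, H, W = x.shape
    n_h, n_w = H // p, W // p
    x = x.reshape(B, C, n_h, p, n_w, p)
    x = x.transpose(0, 2, 4, 1, 3, 5)   # (B, n_h, n_w, C, p, p)
    return x.reshape(B, n_h * n_w, C * p * p)

def patchify_naive(x, p):
    """Reference: slice each patch explicitly, row-major over the patch grid."""
    B, C, H, W = x.shape
    out = []
    for i in range(0, H, p):
        for j in range(0, W, p):
            out.append(x[:, :, i:i+p, j:j+p].reshape(B, -1))
    return np.stack(out, axis=1)

x = np.random.randn(2, 3, 32, 32)
assert np.allclose(patchify(x, 8), patchify_naive(x, 8))
print("permute order matches explicit slicing")
```

If the permute axes were in the wrong order (e.g. channel axes before the grid axes), this check would fail, which makes it a quick way to debug a wrong-but-shape-compatible implementation.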

Tips

Run Locally

For interactive practice with auto-grading, run TorchCode locally:
pip install torch-judge, then call check("vit_patch")

