Problem Description
Implement the PPO (Proximal Policy Optimization) clipped surrogate loss.
Given:
• new_logps: current policy log-probs (B,)
• old_logps: old policy log-probs (B,)
• advantages: advantage estimates (B,)
Define the ratio
$$ r_i = \exp(\text{new\_logps}_i - \text{old\_logps}_i). $$
Then compute
• $L^{\text{unclipped}}_i = r_i A_i$
• $L^{\text{clipped}}_i = \operatorname{clip}(r_i,\, 1-\epsilon,\, 1+\epsilon)\, A_i$
where $A_i = \text{advantages}_i$ and $\epsilon$ is `clip_ratio`.
The loss is the negative batch mean of the elementwise minimum:
$$\mathcal{L}_\text{PPO} = -\mathbb{E}_i\big[\min(L^{\text{unclipped}}_i, L^{\text{clipped}}_i)\big].$$
Implementation notes: detach old_logps and advantages so gradients only flow through new_logps.
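To make the clipping concrete, here is a hypothetical single-sample walkthrough (the numbers are illustrative, not part of the problem): with a positive advantage, a ratio above $1+\epsilon$ is capped, so the policy cannot profit from moving too far in a single update.

```python
import math

# Hypothetical single sample with eps = 0.2: the new policy made this
# action ~35% more likely than the old policy did
new_logp, old_logp, adv, eps = 0.3, 0.0, 1.0, 0.2

r = math.exp(new_logp - old_logp)          # ~1.35, outside [0.8, 1.2]
r_clipped = min(max(r, 1 - eps), 1 + eps)  # 1.2
objective = min(r * adv, r_clipped * adv)  # 1.2: the gain is capped at 1 + eps
loss_contrib = -objective                  # this sample contributes -1.2
```

With a negative advantage the roles flip: the `min` keeps the *larger-magnitude* penalty, so the clip never hides a bad update.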
Signature
from torch import Tensor
def ppo_loss(new_logps: Tensor, old_logps: Tensor, advantages: Tensor,
             clip_ratio: float = 0.2) -> Tensor:
    """PPO clipped surrogate loss over a batch."""
Template
Implement the function below. Use only basic PyTorch operations.
# ✏️ YOUR IMPLEMENTATION HERE
def ppo_loss(new_logps: Tensor, old_logps: Tensor, advantages: Tensor,
             clip_ratio: float = 0.2) -> Tensor:
    pass  # -mean(min(r * adv, clamp(r, 1-clip, 1+clip) * adv)) with gradients only through new_logps
Test Your Implementation
Use this code to debug before submitting.
# 🧪 Debug
import torch

new_logps = torch.tensor([0.0, -0.2, -0.4, -0.6])
old_logps = torch.tensor([0.0, -0.1, -0.5, -0.5])
advantages = torch.tensor([1.0, -1.0, 0.5, -0.5])
print('Loss:', ppo_loss(new_logps, old_logps, advantages, clip_ratio=0.2))
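Beyond the loss value, it is worth checking that gradients flow only through `new_logps`, as the implementation notes require. The sketch below assumes a `ppo_loss` that follows the reference formula; a minimal version is inlined so the snippet runs on its own.

```python
import torch
from torch import Tensor

# Minimal ppo_loss following the formula above, inlined so this check is standalone
def ppo_loss(new_logps: Tensor, old_logps: Tensor, advantages: Tensor,
             clip_ratio: float = 0.2) -> Tensor:
    ratios = torch.exp(new_logps - old_logps.detach())
    adv = advantages.detach()
    clipped = torch.clamp(ratios, 1.0 - clip_ratio, 1.0 + clip_ratio) * adv
    return -torch.min(ratios * adv, clipped).mean()

# All three inputs ask for gradients, but only new_logps should receive any
new_logps = torch.tensor([0.0, -0.2], requires_grad=True)
old_logps = torch.tensor([0.0, -0.1], requires_grad=True)
advantages = torch.tensor([1.0, -1.0], requires_grad=True)

ppo_loss(new_logps, old_logps, advantages).backward()
print(new_logps.grad)   # populated
print(old_logps.grad)   # None: detached inside the loss
print(advantages.grad)  # None: detached inside the loss
```

If either of the last two prints anything other than `None`, a `.detach()` is missing and the old policy or the advantages would be (incorrectly) updated by backprop.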
Reference Solution
Try solving it yourself first! Click below to reveal the solution.
# ✅ SOLUTION
import torch
from torch import Tensor

def ppo_loss(new_logps: Tensor, old_logps: Tensor, advantages: Tensor,
             clip_ratio: float = 0.2) -> Tensor:
    """PPO clipped surrogate loss.

    new_logps:  (B,) current policy log-probs
    old_logps:  (B,) old policy log-probs (treated as constant)
    advantages: (B,) advantage estimates (treated as constant)
    returns:    scalar loss (Tensor)
    """
    # Detach old_logps and advantages so gradients only flow through new_logps
    old_logps_detached = old_logps.detach()
    adv_detached = advantages.detach()
    # Importance sampling ratio r = pi_new / pi_old, computed in log-space
    ratios = torch.exp(new_logps - old_logps_detached)
    # Unclipped and clipped surrogate objectives
    unclipped = ratios * adv_detached
    clipped = torch.clamp(ratios, 1.0 - clip_ratio, 1.0 + clip_ratio) * adv_detached
    # PPO loss: negative mean of the more conservative (smaller) objective
    return -torch.min(unclipped, clipped).mean()
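One subtlety of the debug inputs above: every ratio lands inside $[0.8, 1.2]$, so the clip is inactive and the loss reduces to the plain negative mean of $r_i A_i$ (about $-0.0488$). The sketch below walks through that case step by step and then adds one extra (illustrative) sample where the clip actually binds.

```python
import torch

eps = 0.2
new_logps = torch.tensor([0.0, -0.2, -0.4, -0.6])
old_logps = torch.tensor([0.0, -0.1, -0.5, -0.5])
advantages = torch.tensor([1.0, -1.0, 0.5, -0.5])

ratios = torch.exp(new_logps - old_logps)
print(ratios)  # all in [0.8, 1.2], so clipping changes nothing here

unclipped = ratios * advantages
clipped = torch.clamp(ratios, 1 - eps, 1 + eps) * advantages
loss = -torch.min(unclipped, clipped).mean()
print(loss)    # ~ -0.0488, same as -(ratios * advantages).mean()

# An illustrative extra sample: a log-prob gap of 0.5 gives r = exp(0.5) ~ 1.65,
# which exceeds 1 + eps, so with a positive advantage the objective is capped at 1.2
r = torch.exp(torch.tensor(0.5))
capped = torch.min(r * 1.0, torch.clamp(r, 1 - eps, 1 + eps) * 1.0)
print(capped)  # tensor(1.2000)
```

This is why PPO is stable: once the ratio leaves the trust region $[1-\epsilon, 1+\epsilon]$ in the direction that would increase the objective, the gradient through the clipped branch is zero and the sample stops pushing the policy further.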