Group relative policy optimization, RLAIF
Hard / Advanced

Implement the Group Relative Policy Optimization (GRPO) loss: a group-wise, baseline-subtracted REINFORCE objective commonly used in RLAIF (reinforcement learning from AI feedback).
Given a batch of log-probabilities, scalar rewards, and group ids (one group per prompt), define the within-group normalized advantages:

\[A_i = \frac{r_i - \bar r_{g(i)}}{\text{std}_{g(i)}}\]
where \(\bar r_{g(i)}\) and \(\text{std}_{g(i)}\) are the mean and standard deviation of rewards in the group of example \(i\).
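As a concrete sketch of this per-group normalization (the tensor names here are illustrative, not part of the exercise): each reward is centered by its group mean and scaled by its group standard deviation.

```python
import torch

# Toy batch: group 0 has rewards [1, 2, 3]; group 1 has [10, 20].
rewards = torch.tensor([1.0, 2.0, 3.0, 10.0, 20.0])
group_ids = torch.tensor([0, 0, 0, 1, 1])

adv = torch.empty_like(rewards)
for g in group_ids.unique():
    m = group_ids == g
    # Center by the group mean, scale by the group std (PyTorch default: unbiased).
    adv[m] = (rewards[m] - rewards[m].mean()) / rewards[m].std()
# adv for group 0 is [-1, 0, 1]; for group 1, roughly [-0.707, 0.707].
```

Note that normalization happens independently per group, so rewards on very different scales (as in the two groups above) yield advantages on a comparable scale.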
The GRPO loss is then the negative advantage-weighted log-probability:

\[\mathcal{L}_{\text{GRPO}} = -\frac{1}{N}\sum_{i=1}^{N} A_i \,\log \pi_\theta(y_i \mid x_i)\]
Implement the function below. Use only basic PyTorch operations.
Use this code to debug before submitting.
Try solving it yourself first! Click below to reveal the solution.
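If you want a reference point before revealing the official answer, here is a minimal sketch in basic PyTorch. The signature `grpo_loss(logprobs, rewards, group_ids)` and the `eps` term are assumptions, not part of the exercise statement; the advantages are detached so gradients flow only through the log-probabilities.

```python
import torch

def grpo_loss(logprobs: torch.Tensor,
              rewards: torch.Tensor,
              group_ids: torch.Tensor,
              eps: float = 1e-8) -> torch.Tensor:
    # Assumed layout: one scalar reward and one summed log-prob per
    # completion, with group_ids marking which prompt each belongs to.
    advantages = torch.empty_like(rewards)
    for g in group_ids.unique():
        m = group_ids == g
        r = rewards[m]
        # Within-group normalization; eps guards against a zero std.
        advantages[m] = (r - r.mean()) / (r.std() + eps)
    # Negative advantage-weighted log-probability, averaged over the batch.
    # detach() keeps the advantages out of the gradient path.
    return -(advantages.detach() * logprobs).mean()
```

A quick sanity check: since the advantages are detached, the gradient of the loss with respect to each log-probability is simply \(-A_i / N\).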
For interactive practice with auto-grading, run TorchCode locally: `pip install torch-judge`, then use `check("grpo_loss")`.