## Introduction
DeepSeek-R1 shocked the AI world by reaching reasoning performance comparable to OpenAI's o1 on math and coding benchmarks through large-scale reinforcement learning. Its precursor, DeepSeek-R1-Zero, was trained with reinforcement learning alone, without supervised fine-tuning. At the core of this work is GRPO (Group Relative Policy Optimization), a reinforcement learning algorithm first introduced in the DeepSeekMath paper. GRPO eliminates the critic (value) network and instead optimizes the policy using group-relative rewards.
GRPO addresses three core problems of PPO (Proximal Policy Optimization): complexity, instability, and high memory consumption. By sampling a group of responses per prompt and using the group's own reward statistics as the baseline, GRPO makes training simpler and more memory-efficient, and it was central to DeepSeek-R1’s reasoning breakthrough.
## Problems with PPO
### Traditional Actor-Critic Architecture
PPO belongs to the Actor-Critic family of reinforcement learning algorithms:
```python
import torch
import torch.nn.functional as F

class PPOArchitecture:
    """
    Traditional PPO for LLM fine-tuning (RLHF) keeps several models in memory.
    ActorNetwork, CriticNetwork and RewardModel are placeholder classes.
    """
    def __init__(self, state_dim, action_dim):
        # Actor: learns the policy (what to do)
        self.actor = ActorNetwork(state_dim, action_dim)
        # Critic: estimates expected future reward (value function)
        self.critic = CriticNetwork(state_dim)
        # Reward model: scores complete responses
        self.reward_model = RewardModel(state_dim)
        # Frozen reference model for the KL constraint
        self.ref_model = ActorNetwork(state_dim, action_dim)

    def ppo_loss(self, states, actions, old_log_probs, advantages, returns, clip_eps=0.2):
        """
        PPO clipped objective:
        L(θ) = E[min(r(θ) * A, clip(r(θ), 1-ε, 1+ε) * A)]
        where r(θ) = π_θ(a|s) / π_θ_old(a|s)
        """
        # Log probabilities under the current policy
        new_log_probs = self.actor.get_log_prob(states, actions)
        # Probability ratio, computed in log space for numerical stability
        ratio = torch.exp(new_log_probs - old_log_probs)
        # Clipped surrogate objective
        surr1 = ratio * advantages
        surr2 = torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * advantages
        # Take the minimum (pessimistic bound); negate because optimizers minimize
        policy_loss = -torch.min(surr1, surr2).mean()
        # Value loss: regress the critic onto the returns (advantages + old values),
        # not onto the advantages themselves
        values = self.critic(states)
        value_loss = F.mse_loss(values, returns)
        return policy_loss + 0.5 * value_loss
```
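To make the "pessimistic bound" concrete, the sketch below evaluates the clipped term for a single sample with ε = 0.2 using plain Python floats. `clipped_surrogate` is an illustrative helper, not part of any library.

```python
def clipped_surrogate(ratio, advantage, eps=0.2):
    """Per-sample PPO objective term: min(r * A, clip(r, 1 - eps, 1 + eps) * A)."""
    clipped_ratio = min(max(ratio, 1 - eps), 1 + eps)
    return min(ratio * advantage, clipped_ratio * advantage)

# Positive advantage: gains from pushing the ratio above 1 + eps are capped
print(clipped_surrogate(1.5, 1.0))   # 1.2
# Negative advantage: once the ratio is below 1 - eps, the clipped value is used
print(clipped_surrogate(0.5, -1.0))  # -0.8
# Inside the trust region, the objective is just r * A
print(clipped_surrogate(1.1, 1.0))   # 1.1
```

Because the minimum is taken, the policy never gains from moving the ratio outside [1 - ε, 1 + ε] in the direction the advantage favors, which keeps each update small.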
### Five Major Challenges of PPO
```python
ppo_problems = {
    'multiple_models': '4 models needed: actor (policy), critic (value), reward model, reference model',
    'hyperparameters': 'Requires careful tuning: clip epsilon, GAE lambda, value loss weight',
    'instability': 'Value estimates and gradients can be unstable; needs gradient clipping and careful tuning',
    'memory': '40GB+ GPU memory for a 7B model',
    'complexity': 'GAE (Generalized Advantage Estimation) needs per-token value estimates and a backward recursion',
    # Code complexity comparison
    'code_comparison': '''
    PPO requires:
    - advantage = compute_gae(rewards, values, gamma=0.99, lam=0.95)
    - ratio = (new_log_probs - old_log_probs).exp()
    - clipped_ratio = ratio.clamp(1 - eps, 1 + eps)
    - loss = -min(ratio * advantage, clipped_ratio * advantage)
    - loss += 0.5 * value_loss - 0.01 * entropy
    '''
}
```
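For contrast with GRPO's group baseline, here is a minimal sketch of what `compute_gae` in the snippet above has to do, assuming one episode of per-step rewards and critic values, with the episode terminating after the last step (value 0 beyond it). `compute_gae` is a hypothetical helper here, not a library call.

```python
def compute_gae(rewards, values, gamma=0.99, lam=0.95):
    """
    Generalized Advantage Estimation for one episode.
    rewards[t]: reward at step t; values[t]: critic estimate V(s_t).
    Assumes the episode terminates after the last step (V = 0 beyond it).
    Returns (advantages, returns) as lists of floats.
    """
    advantages = [0.0] * len(rewards)
    gae = 0.0
    # Backward recursion: A_t = δ_t + γλ A_{t+1}
    for t in reversed(range(len(rewards))):
        next_value = values[t + 1] if t + 1 < len(values) else 0.0
        # TD error δ_t = r_t + γ V(s_{t+1}) - V(s_t)
        delta = rewards[t] + gamma * next_value - values[t]
        gae = delta + gamma * lam * gae
        advantages[t] = gae
    # Regression targets for the critic
    returns = [a + v for a, v in zip(advantages, values)]
    return advantages, returns

# Reward only at the final step, as with a reward-model score on the last token
print(compute_gae([0.0, 0.0, 1.0], [0.2, 0.4, 0.7]))
```

Every term depends on the critic's `values`, so PPO must train and store that critic; GRPO's group statistics remove this dependency.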
## GRPO Core Principles
### Key Insight
The core insight of GRPO is that, for a given question, we can generate multiple responses and compare their relative quality instead of learning an absolute value function.
```python
def grpo_key_insight():
    """
    GRPO key insight:
    For each prompt q, we sample G responses {o_1, o_2, ..., o_G}
    from the old policy π_θ_old
    Then compute each response's reward r(o_i)
    Use within-group statistics as the baseline:
    - mean: group average reward
    - std: group reward standard deviation
    Advantage function: A_i = (r(o_i) - mean) / std
    This eliminates the need for a value network!
    """
    pass
```
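As a worked example of the advantage formula in the docstring above, the sketch below normalizes one group of binary correctness rewards with Python's `statistics` module. It uses the sample standard deviation (`statistics.stdev`, matching PyTorch's default `Tensor.std`); the small `eps` guard is an assumption that avoids division by zero when every reward in the group is identical.

```python
import statistics

def group_advantages(rewards, eps=1e-8):
    """Group-relative advantages A_i = (r_i - mean) / (std + eps) for one prompt's group."""
    mean = statistics.mean(rewards)
    # Sample standard deviation (n - 1 denominator), like torch.Tensor.std by default
    std = statistics.stdev(rewards)
    return [(r - mean) / (std + eps) for r in rewards]

# Four sampled answers to one question: two correct (1.0), two wrong (0.0)
print(group_advantages([1.0, 0.0, 1.0, 0.0]))  # ≈ [0.866, -0.866, 0.866, -0.866]
# All responses correct: every advantage is 0, so this group gives no learning signal
print(group_advantages([1.0, 1.0, 1.0, 1.0]))  # [0.0, 0.0, 0.0, 0.0]
```

Correct responses are pushed up and wrong ones down by the same amount, with no learned value function involved.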
### GRPO Loss Function
```python
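import torch
import torch.nn.functional as F

def grpo_loss(
    policy_logits,       # policy model logits: [batch, group_size, seq_len, vocab]
    ref_logits,          # reference model logits, same shape (computed without gradient)
    rewards,             # reward values: [batch, group_size]
    token_ids,           # sampled tokens that the logits predict: [batch, group_size, seq_len]
    mask,                # 1 for response tokens, 0 for prompt/padding: [batch, group_size, seq_len]
    old_log_probs=None,  # per-token log probs under π_θ_old; None = one update per batch
    beta: float = 0.1,
    epsilon: float = 0.2,
):
    """
    GRPO loss function
    Args:
        policy_logits: policy model output
        ref_logits: reference (SFT) model output
        rewards: reward for each response [batch, group_size]
        token_ids: sampled token ids aligned with the logits (already shifted by one)
        mask: which token positions contribute to the loss
        old_log_probs: log probs of the sampled tokens under the sampling policy
        beta: KL penalty coefficient
        epsilon: clipping parameter
    Returns:
        loss: GRPO loss value (to minimize)
    """
    # Log probability of each sampled token (gather, not a sum over the vocabulary)
    index = token_ids.unsqueeze(-1)
    log_probs = F.log_softmax(policy_logits, dim=-1).gather(-1, index).squeeze(-1)    # [batch, group_size, seq_len]
    ref_log_probs = F.log_softmax(ref_logits, dim=-1).gather(-1, index).squeeze(-1)   # [batch, group_size, seq_len]
    if old_log_probs is None:
        # With one update per sampled batch, π_θ_old is the current policy without gradient
        old_log_probs = log_probs.detach()

    # Group-relative advantages: normalize each prompt's rewards within its group
    mean_reward = rewards.mean(dim=1, keepdim=True)       # [batch, 1]
    std_reward = rewards.std(dim=1, keepdim=True) + 1e-8  # [batch, 1]
    advantages = (rewards - mean_reward) / std_reward     # [batch, group_size]
    # Every token of a response shares that response's advantage
    advantages = advantages.unsqueeze(-1)                 # [batch, group_size, 1]

    # Clipped surrogate with ratio π_θ / π_θ_old (same form as PPO)
    ratio = torch.exp(log_probs - old_log_probs)
    surr1 = ratio * advantages
    surr2 = torch.clamp(ratio, 1 - epsilon, 1 + epsilon) * advantages
    surrogate = torch.min(surr1, surr2)

    # KL penalty toward the reference: π_ref/π_θ - log(π_ref/π_θ) - 1 (non-negative)
    log_ratio_ref = ref_log_probs - log_probs
    kl = torch.exp(log_ratio_ref) - log_ratio_ref - 1

    # Maximize surrogate minus β·KL; average over each response's tokens, then over responses
    per_token_loss = -(surrogate - beta * kl)
    mask = mask.to(per_token_loss.dtype)
    per_response_loss = (per_token_loss * mask).sum(dim=-1) / mask.sum(dim=-1).clamp(min=1)
    return per_response_loss.mean()
```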
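The KL term is estimated per token from the sampled tokens rather than by summing over the vocabulary. A minimal sketch, assuming per-token log probabilities are already available as tensors, of the estimator π_ref/π_θ - log(π_ref/π_θ) - 1 (the form used in the DeepSeekMath GRPO objective). It shows the estimate is never negative, is zero when both policies agree on a token, and averages to the true KL(π_θ || π_ref). `k3_kl` is an illustrative name.

```python
import torch

def k3_kl(log_probs, ref_log_probs):
    """
    Per-token KL estimate: π_ref/π_θ - log(π_ref/π_θ) - 1.
    Inputs are log probabilities of the *sampled* tokens; output has the same shape.
    Its expectation over tokens sampled from π_θ equals KL(π_θ || π_ref).
    """
    log_ratio = ref_log_probs - log_probs
    return torch.exp(log_ratio) - log_ratio - 1

logp = torch.log(torch.tensor([0.5, 0.2, 0.9]))
ref_logp = torch.log(torch.tensor([0.5, 0.4, 0.3]))
# Zero for the first token (policies agree), positive for the others
print(k3_kl(logp, ref_logp))

# Exact expectation over a small categorical distribution matches KL(p || q)
p = torch.tensor([0.5, 0.3, 0.2])
q = torch.tensor([0.2, 0.3, 0.5])
print((p * k3_kl(p.log(), q.log())).sum().item(), (p * (p / q).log()).sum().item())
```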
### Complete GRPO Implementation
```python
import torch

class GRPOTrainer:
    """
    Complete GRPO training loop (minimal sketch).
    Assumes Hugging Face-style causal LMs (model.generate, model(...).logits),
    a tokenizer configured with left padding and a pad token, and the
    grpo_loss function defined above.
    """
    def __init__(
        self,
        policy_model,     # policy model to train
        ref_model,        # reference model (frozen SFT model)
        tokenizer,        # tokenizer shared by both models
        reward_fn,        # reward_fn(prompt, response_text) -> float
        optimizer,        # e.g. torch.optim.AdamW(policy_model.parameters(), lr=1e-6)
        beta: float = 0.1,
        group_size: int = 4,
        max_new_tokens: int = 512,
        temperature: float = 0.7,
    ):
        self.policy_model = policy_model
        self.ref_model = ref_model
        self.tokenizer = tokenizer
        self.reward_fn = reward_fn
        self.optimizer = optimizer
        self.beta = beta
        self.group_size = group_size
        self.max_new_tokens = max_new_tokens
        self.temperature = temperature
        # Freeze reference model
        self.ref_model.eval()
        for param in self.ref_model.parameters():
            param.requires_grad = False

    @torch.no_grad()
    def sample_responses(self, prompts):
        """
        Sample group_size responses per prompt from the current (old) policy
        """
        device = next(self.policy_model.parameters()).device
        enc = self.tokenizer(prompts, return_tensors='pt', padding=True).to(device)
        # num_return_sequences keeps each prompt's samples contiguous:
        # [batch * group_size, prompt_len + new_tokens]
        sequences = self.policy_model.generate(
            **enc,
            do_sample=True,
            temperature=self.temperature,
            max_new_tokens=self.max_new_tokens,
            num_return_sequences=self.group_size,
            pad_token_id=self.tokenizer.pad_token_id,
        )
        prompt_len = enc['input_ids'].shape[1]
        texts = self.tokenizer.batch_decode(sequences[:, prompt_len:], skip_special_tokens=True)
        # Regroup decoded texts as [batch][group_size]
        grouped = [texts[i:i + self.group_size] for i in range(0, len(texts), self.group_size)]
        return sequences, prompt_len, grouped

    def compute_rewards(self, prompts, responses):
        """
        Compute a scalar reward for each response: [batch, group_size]
        """
        all_rewards = [
            [self.reward_fn(prompt, response) for response in group]
            for prompt, group in zip(prompts, responses)
        ]
        return torch.tensor(all_rewards, dtype=torch.float32)

    def forward_batch(self, prompts, sequences, prompt_len, responses):
        """
        Forward pass to compute the GRPO loss
        """
        batch_size = len(prompts)
        pad_id = self.tokenizer.pad_token_id
        # Note: if pad == eos, EOS tokens are masked too (acceptable for a sketch)
        attention_mask = (sequences != pad_id).long()
        # Logits at position t predict token t + 1, so shift targets by one step
        targets = sequences[:, 1:]
        positions = torch.arange(targets.shape[1], device=sequences.device)
        # Only generated, non-padding tokens contribute to the loss
        mask = ((positions + 1) >= prompt_len).unsqueeze(0) & (targets != pad_id)

        policy_logits = self.policy_model(
            input_ids=sequences, attention_mask=attention_mask
        ).logits[:, :-1]
        # Reference model forward (no gradient)
        with torch.no_grad():
            ref_logits = self.ref_model(
                input_ids=sequences, attention_mask=attention_mask
            ).logits[:, :-1]

        # Reshape to [batch, group_size, seq_len(, vocab)]
        shape = (batch_size, self.group_size, targets.shape[1])
        rewards = self.compute_rewards(prompts, responses).to(sequences.device)
        return grpo_loss(
            policy_logits.reshape(*shape, -1),
            ref_logits.reshape(*shape, -1),
            rewards,
            targets.reshape(shape),
            mask.reshape(shape),
            beta=self.beta,
        )

    def train_step(self, prompts):
        """
        Single training step (one policy update per sampled batch, so π_θ_old = π_θ)
        """
        # 1. Sample responses
        sequences, prompt_len, responses = self.sample_responses(prompts)
        # 2. Forward pass and loss computation
        loss = self.forward_batch(prompts, sequences, prompt_len, responses)
        # 3. Backward pass with gradient clipping
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.policy_model.parameters(), 1.0)
        self.optimizer.step()
        return loss.item()
```
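The most error-prone part of a per-token loss is aligning logits with their target tokens and masking out prompt and padding positions. A small, self-contained sketch of that alignment on a toy batch, assuming left-padded prompts of common padded length `prompt_len` and a pad id of 0 (both illustrative); `completion_targets_and_mask` is a hypothetical helper.

```python
import torch

def completion_targets_and_mask(sequences, prompt_len, pad_id):
    """
    sequences: [N, L] token ids = left-padded prompt followed by the generated response.
    Logits at position t predict token t + 1, so targets are sequences[:, 1:].
    Returns (targets, mask), each [N, L - 1]; mask is True only for response tokens.
    """
    targets = sequences[:, 1:]
    positions = torch.arange(targets.shape[1], device=sequences.device)
    # Target index t is original token t + 1: keep it if it was generated and is not padding
    is_response = (positions + 1) >= prompt_len
    mask = is_response.unsqueeze(0) & (targets != pad_id)
    return targets, mask

# Two sequences: prompts padded to length 3, responses padded with 0 after they end
seqs = torch.tensor([[0, 5, 6, 7, 8, 0],
                     [4, 5, 6, 9, 0, 0]])
targets, mask = completion_targets_and_mask(seqs, prompt_len=3, pad_id=0)
print(mask.int())
```

In the first row only the generated tokens 7 and 8 are kept; in the second, only token 9.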
## DeepSeek-R1 Application
### GRPO's Role in DeepSeek-R1
```python
class DeepSeekR1Training:
    """
    DeepSeek-R1 uses GRPO for reasoning capability training
    (models, tokenizer and optimizer are placeholders supplied by the caller)
    """
    def __init__(self, base_model, sft_model, tokenizer, optimizer):
        self.base_model = base_model  # policy initialization
        self.sft_model = sft_model    # frozen reference model
        self.tokenizer = tokenizer
        self.optimizer = optimizer
        self.reward_functions = []

    def setup_rewards(self):
        """
        Combine several reward functions (the R1 report describes rule-based
        accuracy and format rewards, plus a language-consistency reward)
        """
        # 1. Accuracy reward: check if the final answer is correct
        self.reward_functions.append(AccuracyReward())
        # 2. Format reward: require the <think>/<answer> output format
        self.reward_functions.append(FormatReward())
        # 3. Reasoning-length reward: an illustrative extra, not described in the R1 report
        self.reward_functions.append(ReasoningReward())

    def compute_composite_reward(self, prompt, response):
        """
        Sum the individual rewards
        """
        return sum(reward_fn(prompt, response) for reward_fn in self.reward_functions)

    def train(self, prompt_batches, num_steps=10000):
        """
        Train using GRPO, cycling over batches of prompts
        """
        self.setup_rewards()
        trainer = GRPOTrainer(
            policy_model=self.base_model,
            ref_model=self.sft_model,
            tokenizer=self.tokenizer,
            reward_fn=self.compute_composite_reward,
            optimizer=self.optimizer,
            group_size=16,  # DeepSeek uses a larger group
            beta=0.04,      # Smaller beta
        )
        for step in range(num_steps):
            prompts = prompt_batches[step % len(prompt_batches)]
            loss = trainer.train_step(prompts)
            if step % 100 == 0:
                print(f"Step {step}: Loss = {loss:.4f}")
```
### Reward Function Design
```python
import re

class AccuracyReward:
    """
    Accuracy reward: checks if the final answer is correct
    """
    def __init__(self, answer_key=None):
        self.weight = 1.0
        # Hypothetical lookup table: prompt -> ground-truth answer string
        self.answer_key = answer_key or {}

    def __call__(self, prompt, response):
        # Extract the answer and check correctness
        extracted_answer = self.extract_answer(response)
        ground_truth = self.get_ground_truth(prompt)
        # A missing answer must never count as correct (None == None would)
        if extracted_answer is None or ground_truth is None:
            return 0.0
        return self.weight if extracted_answer == ground_truth else 0.0

    def extract_answer(self, response):
        # Take the text inside <answer>...</answer>, or None if absent
        match = re.search(r'<answer>(.*?)</answer>', response, re.DOTALL)
        return match.group(1).strip() if match else None

    def get_ground_truth(self, prompt):
        # Look up the correct answer for this question
        return self.answer_key.get(prompt)


class FormatReward:
    """
    Format reward: requires model output to include its thought process
    """
    def __init__(self):
        self.weight = 0.1

    def __call__(self, prompt, response):
        # Check whether the response contains <think> and <answer> tag pairs
        has_think = '<think>' in response and '</think>' in response
        has_answer = '<answer>' in response and '</answer>' in response
        if has_think and has_answer:
            return self.weight
        elif has_think or has_answer:
            return self.weight * 0.5
        else:
            return 0.0


class ReasoningReward:
    """
    Reasoning reward: encourages longer reasoning chains (capped)
    """
    def __init__(self):
        self.weight = 0.01

    def extract_think(self, response):
        # Text inside <think>...</think>, or None if the tags are missing
        match = re.search(r'<think>(.*?)</think>', response, re.DOTALL)
        return match.group(1) if match else None

    def __call__(self, prompt, response):
        # Reward longer reasoning, but only when the format is correct
        think_content = self.extract_think(response)
        if think_content is None:
            return 0.0
        # Normalize: longer reasoning gets a higher reward, capped at 1000 characters
        normalized_reward = min(len(think_content) / 1000, 1.0)
        return self.weight * normalized_reward
```
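Exact string comparison in `AccuracyReward` is brittle: `0.5` and `1/2` are the same answer. A minimal sketch of a normalizing comparison built on Python's `fractions.Fraction`; `normalize_answer` and `answers_match` are hypothetical helpers, and treating commas as thousands separators is an assumption. Non-numeric answers fall back to a whitespace-insensitive string comparison.

```python
from fractions import Fraction

def normalize_answer(answer):
    """Return a Fraction for numeric answers, else the answer with whitespace removed."""
    # Assumption: commas are thousands separators ("1,000" -> "1000")
    text = answer.strip().replace(',', '')
    try:
        return Fraction(text)  # handles "42", "0.5", "1/2", "1e3"
    except (ValueError, ZeroDivisionError):
        return ''.join(text.split())

def answers_match(predicted, ground_truth):
    """Compare answers after normalization; a missing answer never matches."""
    if predicted is None or ground_truth is None:
        return False
    return normalize_answer(predicted) == normalize_answer(ground_truth)

print(answers_match('1/2', '0.5'))     # True
print(answers_match('x + 1', 'x+1'))   # True
print(answers_match('3', '4'))         # False
```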
## Performance Analysis
### GRPO vs PPO
```python
# Performance comparison (rough, illustrative estimates; actual numbers depend on
# hardware, sequence length, batch size and implementation)
performance_comparison = {
    'memory_usage': {
        'PPO': '40GB+ for 7B model',
        'GRPO': '20GB for 7B model',  # ~50% reduction from dropping the critic
    },
    'training_speed': {
        'PPO': '3 days on 8x A100',
        'GRPO': '1.5 days on 8x A100',  # ~2x faster
    },
    'sample_efficiency': {
        'PPO': 'Learned value baseline, can be biased',
        'GRPO': 'Empirical group baseline, no learned-value bias (noisier with small groups)',
    },
    'stability': {
        'PPO': 'Clipped objective plus a value loss that must be weighted and tuned',
        'GRPO': 'Same clipped objective, no value loss',
    },
    'hyperparameters': {
        'PPO': '10+ hyperparameters',
        'GRPO': 'A few key hyperparameters (beta, group_size, clip epsilon)',
    }
}
```
### Math Reasoning Results
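```python
# Math reasoning comparison (source of these figures not given; treat as illustrative).
# For reference, the DeepSeek-R1 report gives DeepSeek-R1 a pass@1 of 79.8% on
# AIME 2024 and 97.3% on MATH-500.
math_results = {
    'GSM8K': {
        'base_model': '15.6%',
        'PPO_trained': '52.3%',
        'GRPO_trained': '89.3%',  # Significantly higher
    },
    'MATH': {
        'base_model': '10.2%',
        'PPO_trained': '28.5%',
        'GRPO_trained': '47.1%',
    }
}
```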
## Implementation Details
### Group Size Selection
```python
def optimal_group_size(task_complexity):
    """
    Select group size based on task complexity (heuristic)
    Args:
        task_complexity: task complexity score from 1 to 10
    Returns:
        group size G
    """
    if task_complexity <= 3:
        # Simple tasks: a smaller group works
        return 4
    elif task_complexity <= 6:
        # Medium complexity
        return 8
    else:
        # High-complexity reasoning tasks
        return 16  # DeepSeek-R1 uses 16

# General rules:
# - Larger groups give a more accurate baseline (mean/std) estimate
# - But each prompt costs G generations, so for a fixed sampling budget,
#   larger groups mean fewer distinct prompts per update
# - In practice, 4-16 is a common range
```
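One concrete reason group size matters with binary correctness rewards: if every response in a group gets the same reward, the group-normalized advantages are all zero and that prompt contributes no policy-gradient signal (only the KL term remains). Assuming each sample is independently correct with probability p (a simplifying assumption), the chance that a group of G samples is mixed is 1 - p^G - (1 - p)^G. A small sketch with an illustrative helper name:

```python
def prob_informative_group(p, group_size):
    """
    Probability that a group of binary rewards is not all-equal, assuming each
    response is independently correct with probability p.
    """
    all_correct = p ** group_size
    all_wrong = (1 - p) ** group_size
    return 1 - all_correct - all_wrong

for p in (0.05, 0.5):
    for g in (4, 8, 16):
        print(f"p={p:.2f}  G={g:2d}  P(informative)={prob_informative_group(p, g):.3f}")
```

For hard prompts the effect is large: with p = 0.05, roughly 0.19 of groups carry signal at G = 4 versus about 0.56 at G = 16, which is one argument for larger groups on difficult reasoning tasks.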
### Beta Scheduling
```python
import math

def cosine_beta_schedule(total_steps, start_beta=0.1, end_beta=0.04):
    """
    Beta scheduling: gradually reduce the KL penalty to allow larger policy changes
    """
    def get_beta(step):
        hold_end = total_steps * 0.1
        decay_end = total_steps * 0.8
        if step < hold_end:
            # Early stage: high beta, stay close to the reference model
            return start_beta
        elif step > decay_end:
            # Late stage: low beta, allow more exploration
            return end_beta
        else:
            # Middle stage: cosine annealing from start_beta down to end_beta
            progress = (step - hold_end) / (decay_end - hold_end)
            return end_beta + (start_beta - end_beta) * (1 + math.cos(math.pi * progress)) / 2
    return get_beta
```
### Advanced Variant: GRPO with Self-Consistency
```python
from collections import Counter

import torch

def grpo_with_self_consistency(sample_fn, extract_answer, prompts, group_size=8):
    """
    GRPO combined with self-consistency
    1. Generate group_size responses for each prompt
    2. Use majority voting to select each prompt's most consistent answer
    3. Give higher reward to responses matching that majority answer
    Args:
        sample_fn: callable(prompt) -> response text (e.g. wraps policy_model.generate)
        extract_answer: callable(response) -> answer string, or None if missing
        prompts: list of prompts
    Returns:
        responses: [batch][group_size] response texts
        rewards: tensor [batch, group_size], used as the rewards in grpo_loss
    """
    all_responses = [[sample_fn(prompt) for _ in range(group_size)] for prompt in prompts]
    all_rewards = []
    for group in all_responses:
        answers = [extract_answer(r) for r in group]
        # Majority vote over this prompt's answers (missing answers don't vote)
        counts = Counter(a for a in answers if a is not None)
        majority = counts.most_common(1)[0][0] if counts else None
        # Reward: 1.0 for agreeing with the majority, -0.1 otherwise
        group_rewards = [1.0 if a is not None and a == majority else -0.1 for a in answers]
        all_rewards.append(group_rewards)
    # These rewards then feed the standard GRPO loss (see GRPOTrainer)
    return all_responses, torch.tensor(all_rewards, dtype=torch.float32)
```
### Comparison with DPO
```python
# GRPO vs DPO comparison
comparison = {
    'training_signal': {
        'DPO': 'Pairwise preference: chosen vs rejected',
        'GRPO': 'Group-relative rewards: relative ranking of multiple responses',
    },
    'reference_model': {
        'DPO': 'Required (implicit KL via log-probability ratios)',
        'GRPO': 'Required (explicit KL penalty)',
    },
    'sampling': {
        'DPO': 'Typically offline: a fixed chosen/rejected pair per prompt',
        'GRPO': 'Online: G fresh responses per prompt (G >= 4)',
    },
    'reward_type': {
        'DPO': 'Binary preference (no reward model needed)',
        'GRPO': 'Scalar reward (rule-based or learned)',
    },
    'use_case': {
        'DPO': 'General preference alignment',
        'GRPO': 'Reasoning capability enhancement',
    }
}
```
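To make the contrast in training signal concrete, here is a minimal sketch of the standard DPO loss for a batch of preference pairs. It assumes summed per-response log probabilities under the policy and the frozen reference are already computed; the function name and the `beta=0.1` default are illustrative.

```python
import torch
import torch.nn.functional as F

def dpo_loss(policy_chosen, policy_rejected, ref_chosen, ref_rejected, beta=0.1):
    """
    DPO loss: -log σ(β[(log π(y_w) - log π_ref(y_w)) - (log π(y_l) - log π_ref(y_l))]).
    All inputs are [batch] tensors of summed response log probabilities.
    """
    chosen_logratio = policy_chosen - ref_chosen
    rejected_logratio = policy_rejected - ref_rejected
    logits = beta * (chosen_logratio - rejected_logratio)
    return -F.logsigmoid(logits).mean()

# Policy has moved toward the chosen response relative to the reference: loss < log 2
loss = dpo_loss(torch.tensor([-10.0]), torch.tensor([-14.0]),
                torch.tensor([-11.0]), torch.tensor([-13.0]))
print(loss.item())  # ≈ 0.598
```

Unlike GRPO, nothing is sampled during training and there is no scalar reward: the ordering within each pair is the whole signal.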
## Practical Advice
### When to Use GRPO
```python
# GRPO ideal use cases
grpo_ideal_cases = {
    'reasoning_tasks': 'Math, code, logical reasoning',
    'self_verification': 'Outputs can be checked automatically (e.g. unit tests, exact answers)',
    'rule_based_rewards': 'Clear correctness criteria available',
    'limited_memory': 'Cannot afford the memory of a PPO critic',
    'quick_iteration': 'Need fast experimentation and iteration',
}

# Less suitable cases
grpo_less_suitable = {
    'subjective_preferences': 'Subjective preferences have no clear standard',
    'complex_environments': 'Tasks that need multi-step interaction with complex environments',
}
```
### Common Pitfalls
```python
# GRPO common issues and solutions
common_issues = {
    'issue1': {
        'problem': 'Reward variance too high',
        'solution': 'Increase group size or use reward normalization'
    },
    'issue2': {
        'problem': 'Model starts repeating responses',
        'solution': 'Add repetition penalty reward term'
    },
    'issue3': {
        'problem': 'KL divergence too large',
        'solution': 'Increase beta value'
    },
    'issue4': {
        'problem': 'Unstable training',
        'solution': 'Use gradient clipping, reduce learning rate'
    }
}
```
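The repetition fix in `issue2` can be as simple as a negative reward proportional to how many n-grams in a response repeat an earlier one. A minimal sketch, assuming whitespace tokenization, trigrams, and a penalty weight of 0.1 (all illustrative choices); `repetition_penalty` is a hypothetical helper whose result would be added to the composite reward.

```python
def repetition_penalty(response, n=3, weight=0.1):
    """
    Non-positive reward: -weight * (fraction of n-grams that repeat an earlier n-gram).
    Uses whitespace tokenization; responses shorter than n tokens get no penalty.
    """
    tokens = response.split()
    ngrams = [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]
    if not ngrams:
        return 0.0
    unique_fraction = len(set(ngrams)) / len(ngrams)
    # Equals -weight * repeated_fraction, since repeated_fraction = 1 - unique_fraction
    return weight * (unique_fraction - 1)

print(repetition_penalty("the answer is 4 because 2 + 2 = 4"))  # no repeated trigrams
print(repetition_penalty("a b c a b c a b c"))                  # heavy repetition
```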
## Conclusion
GRPO represents a major breakthrough in reinforcement learning optimization:
- Lower memory: no critic network, about 50% less memory in the estimates above
- Faster training: roughly 2x in the same estimates
- Simpler and more stable: fewer hyperparameters and no value loss to balance
- Reasoning breakthrough: central to DeepSeek-R1’s math reasoning capabilities
By replacing the value network with group-relative rewards, GRPO greatly simplifies the reinforcement learning pipeline while maintaining, and in some reports improving, training effectiveness. It has quickly become a widely adopted choice for training reasoning models.