
Reinforcement Learning: Deep RL & Modern Methods

Advanced reinforcement learning with Deep Q-Networks, policy gradients, PPO, actor-critic methods, and RLHF for LLM alignment

From Q-Learning to Deep RL

While basic Q-learning uses a table to store values for each state-action pair, Deep Q-Networks (DQN) use neural networks to approximate the Q-function, enabling RL to work in environments with large or continuous state spaces like video games and robotics.

The DQN Breakthrough (2015):

DeepMind's DQN learned to play dozens of Atari games from raw pixels alone, reaching human-level or better performance on many of them. Key innovations: an experience replay buffer and a separate, periodically updated target network.

Deep Q-Network (DQN) Implementation

import torch
import torch.nn as nn
import numpy as np
from collections import deque
import random

class DQN(nn.Module):
    """Deep Q-Network: approximates Q(s, a) with a neural net."""
    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(state_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, action_dim)
        )

    def forward(self, state):
        return self.network(state)

class ReplayBuffer:
    """Store and sample past experiences for stable training."""
    def __init__(self, capacity=10000):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        return (np.array(states), np.array(actions), np.array(rewards),
                np.array(next_states), np.array(dones))

    def __len__(self):
        return len(self.buffer)

class DQNAgent:
    def __init__(self, state_dim, action_dim, lr=1e-3, gamma=0.99):
        self.action_dim = action_dim
        self.gamma = gamma
        self.epsilon = 1.0      # exploration rate
        self.epsilon_min = 0.01
        self.epsilon_decay = 0.995

        # Two networks for stable training; the target network is synced to the Q-network periodically during training
        self.q_network = DQN(state_dim, action_dim)
        self.target_network = DQN(state_dim, action_dim)
        self.target_network.load_state_dict(self.q_network.state_dict())

        self.optimizer = torch.optim.Adam(self.q_network.parameters(), lr=lr)
        self.buffer = ReplayBuffer()

    def select_action(self, state):
        """Epsilon-greedy: explore randomly or exploit learned knowledge."""
        if random.random() < self.epsilon:
            return random.randint(0, self.action_dim - 1)
        state_t = torch.FloatTensor(state).unsqueeze(0)
        with torch.no_grad():
            q_values = self.q_network(state_t)
        return q_values.argmax().item()

    def train_step(self, batch_size=64):
        if len(self.buffer) < batch_size:
            return

        states, actions, rewards, next_states, dones = self.buffer.sample(batch_size)

        states_t = torch.FloatTensor(states)
        actions_t = torch.LongTensor(actions)
        rewards_t = torch.FloatTensor(rewards)
        next_states_t = torch.FloatTensor(next_states)
        dones_t = torch.FloatTensor(dones)

        # Current Q values
        current_q = self.q_network(states_t).gather(1, actions_t.unsqueeze(1))

        # Target Q values (from target network for stability)
        with torch.no_grad():
            max_next_q = self.target_network(next_states_t).max(1)[0]
            target_q = rewards_t + (1 - dones_t) * self.gamma * max_next_q

        # Update Q-network
        loss = nn.MSELoss()(current_q.squeeze(), target_q)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Decay exploration
        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

agent = DQNAgent(state_dim=4, action_dim=2)
print(f"DQN agent created with {sum(p.numel() for p in agent.q_network.parameters())} parameters")

Policy Gradient Methods

Instead of learning a value function and acting greedily on it (as DQN does), policy gradient methods learn the policy directly: a mapping from states to action probabilities. This makes them a natural fit for continuous action spaces and inherently stochastic policies.

import torch
import torch.nn as nn
import torch.nn.functional as F

class PolicyNetwork(nn.Module):
    """REINFORCE: A simple policy gradient network."""
    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.fc1 = nn.Linear(state_dim, 128)
        self.fc2 = nn.Linear(128, action_dim)
        self.saved_log_probs = []
        self.rewards = []

    def forward(self, x):
        x = F.relu(self.fc1(x))
        action_probs = F.softmax(self.fc2(x), dim=-1)
        return action_probs

    def select_action(self, state):
        state_t = torch.FloatTensor(state).unsqueeze(0)
        probs = self.forward(state_t)
        dist = torch.distributions.Categorical(probs)
        action = dist.sample()
        self.saved_log_probs.append(dist.log_prob(action))
        return action.item()

def reinforce_update(policy, optimizer, gamma=0.99):
    """REINFORCE update: policy gradient with returns."""
    R = 0
    returns = []

    # Calculate discounted returns (backwards)
    for r in reversed(policy.rewards):
        R = r + gamma * R
        returns.insert(0, R)

    returns = torch.tensor(returns)
    returns = (returns - returns.mean()) / (returns.std() + 1e-8)  # normalize

    # Policy gradient: increase probability of actions that led to high returns
    policy_loss = []
    for log_prob, G in zip(policy.saved_log_probs, returns):
        policy_loss.append(-log_prob * G)  # negative for gradient ascent

    optimizer.zero_grad()
    loss = torch.stack(policy_loss).sum()
    loss.backward()
    optimizer.step()

    policy.saved_log_probs = []
    policy.rewards = []
    return loss.item()

policy = PolicyNetwork(state_dim=4, action_dim=2)
optimizer = torch.optim.Adam(policy.parameters(), lr=1e-3)
print("REINFORCE: directly optimize the policy using returns")

PPO: Proximal Policy Optimization

PPO is one of the most widely used RL algorithms in practice (it was the algorithm behind ChatGPT's RLHF). It improves on REINFORCE by clipping the policy update ratio so that no single update can move the policy too far.

import torch
import torch.nn as nn

class ActorCritic(nn.Module):
    """Actor-Critic network for PPO."""
    def __init__(self, state_dim, action_dim):
        super().__init__()
        # Shared feature extractor
        self.shared = nn.Sequential(
            nn.Linear(state_dim, 256),
            nn.ReLU(),
        )
        # Actor head (policy): outputs action probabilities
        self.actor = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, action_dim),
            nn.Softmax(dim=-1)
        )
        # Critic head (value): estimates state value
        self.critic = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 1)
        )

    def forward(self, state):
        features = self.shared(state)
        action_probs = self.actor(features)
        state_value = self.critic(features)
        return action_probs, state_value

def ppo_loss(old_probs, new_probs, advantages, clip_epsilon=0.2):
    """
    PPO Clipped Objective:
    L = min(r * A, clip(r, 1-eps, 1+eps) * A)

    Prevents the policy from changing too much in a single update.
    """
    ratio = new_probs / (old_probs + 1e-8)

    # Unclipped objective
    surr1 = ratio * advantages

    # Clipped objective (prevents too-large updates)
    surr2 = torch.clamp(ratio, 1 - clip_epsilon, 1 + clip_epsilon) * advantages

    # Take the minimum (conservative update)
    return -torch.min(surr1, surr2).mean()

# PPO is used in:
# - ChatGPT/Claude RLHF training (align LLMs with human preferences)
# - OpenAI Five (Dota 2)
# - Robot locomotion and manipulation
# - Game AI agents
model = ActorCritic(state_dim=8, action_dim=4)
print(f"Actor-Critic model: {sum(p.numel() for p in model.parameters())} parameters")
print("PPO clips the update ratio to prevent instability")

RLHF: Connecting RL to LLMs

How RLHF Aligns LLMs

Step 1: Collect human comparisons

Show humans two model responses to the same prompt. They pick which is better. Collect thousands of these preference pairs.

Step 2: Train a reward model

Train a model to predict which response a human would prefer. This becomes our automated "human judge" that can score any response.
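
A minimal sketch of the pairwise (Bradley-Terry style) loss commonly used to train such a reward model. The scores below are made-up placeholders for the scalar outputs the reward model would assign to the preferred and rejected responses.

import torch
import torch.nn.functional as F

def reward_model_loss(score_chosen, score_rejected):
    """Push the preferred response's score above the rejected one's."""
    return -F.logsigmoid(score_chosen - score_rejected).mean()

# Hypothetical batch of scores for (preferred, rejected) response pairs
score_chosen = torch.tensor([1.2, 0.3])
score_rejected = torch.tensor([0.4, 0.9])
print(reward_model_loss(score_chosen, score_rejected))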

Step 3: Optimize with PPO

Use PPO to update the LLM policy to maximize the reward model's score while staying close to the original (reference) model; a KL penalty keeps the policy from drifting too far and gaming the reward model.
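
A sketch of the KL-shaped reward used in this step. The coefficient beta and the sequence-level log-probabilities are illustrative placeholders; real implementations usually apply the penalty per token.

def rlhf_reward(reward_model_score, policy_logprob, ref_logprob, beta=0.02):
    """Reward model score minus a penalty for drifting from the frozen reference model."""
    kl_estimate = policy_logprob - ref_logprob    # simple estimate of the KL term
    return reward_model_score - beta * kl_estimate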

Alternative: DPO (Direct Preference Optimization)

DPO skips the reward model entirely and optimizes the LLM directly on preference pairs. It is simpler and typically more stable to train, and it is increasingly used in place of PPO-based RLHF (Llama 3's post-training uses it, for example).
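
A minimal sketch of the DPO loss on a batch of preference pairs. The inputs are batched tensors of sequence-level log-probabilities under the trained policy and the frozen reference model; beta is an illustrative temperature.

import torch
import torch.nn.functional as F

def dpo_loss(policy_chosen_logp, policy_rejected_logp,
             ref_chosen_logp, ref_rejected_logp, beta=0.1):
    """Prefer the chosen response more than the reference model does, without a reward model."""
    chosen_margin = policy_chosen_logp - ref_chosen_logp
    rejected_margin = policy_rejected_logp - ref_rejected_logp
    return -F.logsigmoid(beta * (chosen_margin - rejected_margin)).mean()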

Exploration vs Exploitation

Exploration

Try new, uncertain actions to discover potentially better strategies.

  • Epsilon-greedy: Random action with probability epsilon (used in the DQN agent above)
  • Boltzmann: Sample from a softmax over Q-values (sketched below)
  • UCB: Exploration bonus for rarely tried actions
  • Curiosity-driven: Intrinsic reward for visiting novel states
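
Epsilon-greedy already appears in the DQN agent above; here is a minimal sketch of the Boltzmann strategy, assuming a vector of Q-values and an illustrative temperature:

import torch

def boltzmann_action(q_values, temperature=1.0):
    """Sample an action with probability proportional to exp(Q / temperature)."""
    probs = torch.softmax(torch.as_tensor(q_values, dtype=torch.float32) / temperature, dim=-1)
    return torch.distributions.Categorical(probs).sample().item()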

Exploitation

Use the best-known action to maximize immediate reward.

  • Greedy: Always pick highest Q-value action
  • Risk: Gets stuck in local optima
  • When: After sufficient exploration
  • Decay: Gradually shift from explore to exploit

Key Takeaways

  • DQN uses neural networks to approximate Q-values for large state spaces
  • Policy gradients directly optimize the policy; PPO clips updates for stability
  • Actor-Critic combines value estimation (critic) with policy learning (actor)
  • RLHF uses PPO to align LLMs with human preferences via a learned reward model
  • DPO is a simpler alternative to RLHF that skips the reward model training step
