Reinforcement Learning: Solving CartPole with DQN
1. The CartPole Environment and Reinforcement Learning Basics
1.1 Environment Overview
CartPole is a classic control problem from OpenAI Gym: the goal is to keep a pole balanced upright by pushing the cart left or right. The observation is a 4-dimensional vector (cart position, cart velocity, pole angle, pole angular velocity), and an episode ends when the pole tilts too far, the cart leaves the track, or the step limit (500 in CartPole-v1) is reached.
1.2 Core Reinforcement Learning Concepts
- State: $s_t \in \mathbb{R}^4$
- Action: $a_t \in \{0, 1\}$ (push left or push right)
- Reward: +1 for every step the pole stays up
- Objective: maximize the discounted return $\sum_{t=0}^{T} \gamma^t r_t$ (the snippet below walks through one random-policy episode)
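To make these concepts concrete, here is a minimal interaction with the environment under a random policy, accumulating the discounted return (a sketch assuming the classic Gym API, where `reset()` returns the observation and `step()` returns a 4-tuple):

```python
import gym

env = gym.make('CartPole-v1')
state = env.reset()                             # s_0: 4-dimensional state
gamma, ret, t, done = 0.99, 0.0, 0, False
while not done:
    action = env.action_space.sample()          # random a_t in {0, 1}
    state, reward, done, _ = env.step(action)   # r_{t+1} = +1 per surviving step
    ret += (gamma ** t) * reward                # accumulate the discounted return
    t += 1
print(f"Episode length: {t}, discounted return: {ret:.2f}")
```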
2. The DQN Algorithm
2.1 Q-Learning Update Rule

$$Q(s_t, a_t) \leftarrow Q(s_t, a_t) + \alpha\left[r_{t+1} + \gamma \max_{a} Q(s_{t+1}, a) - Q(s_t, a_t)\right]$$
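Before moving to neural networks, the update rule can be illustrated with a plain dictionary-based Q-table over hand-discretized states (a hypothetical toy example; CartPole's continuous state space is precisely what motivates the neural network in the next section):

```python
from collections import defaultdict

Q = defaultdict(float)          # Q-table: (state, action) -> value, initialized to 0
alpha, gamma = 0.1, 0.99

def q_learning_update(s, a, r, s_next, actions=(0, 1)):
    # One tabular Q-learning step following the update rule above
    best_next = max(Q[(s_next, a_next)] for a_next in actions)
    Q[(s, a)] += alpha * (r + gamma * best_next - Q[(s, a)])

# Toy transition: from discretized state 'left-tilt', taking action 1, reward +1
q_learning_update('left-tilt', 1, 1.0, 'upright')
print(Q[('left-tilt', 1)])      # 0.1 after one update (all values started at 0)
```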
2.2 From Q-Learning to Deep Q-Networks
2.2.1 Key Components
- Experience replay: breaks the correlation between consecutive samples
- Target network: keeps the learning target stable
- ε-greedy policy: balances exploration and exploitation
3. Implementing DQN in PyTorch
3.1 Defining the Q-Network
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class DQN(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(DQN, self).__init__()
        # Two hidden layers of 128 units map the state to one Q-value per action
        self.fc = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, output_dim)
        )

    def forward(self, x):
        return self.fc(x)
```
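A quick sanity check of the output shape, assuming CartPole's 4-dimensional state and 2 actions:

```python
# Forward a batch of two random states through the untrained network
net = DQN(input_dim=4, output_dim=2)
dummy_states = torch.randn(2, 4)
print(net(dummy_states).shape)  # torch.Size([2, 2]): one Q-value per action
```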
3.2 Experience Replay Buffer
```python
from collections import deque
import random

class ReplayBuffer:
    def __init__(self, capacity):
        # Once capacity is reached, the oldest transitions are dropped automatically
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        return random.sample(self.buffer, batch_size)

    def __len__(self):
        return len(self.buffer)
```
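Minimal usage example (the dummy transitions below are purely illustrative):

```python
import numpy as np

buffer = ReplayBuffer(capacity=1000)
for _ in range(5):
    s = np.random.randn(4).astype(np.float32)
    buffer.push(s, 0, 1.0, s, False)      # (state, action, reward, next_state, done)
batch = buffer.sample(batch_size=3)
print(len(buffer), len(batch))            # 5 3
```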
3.3 The DQN Agent
```python
class DQNAgent:
    def __init__(self, env, gamma=0.99, lr=1e-3):
        self.env = env
        self.gamma = gamma
        self.epsilon = 1.0
        self.epsilon_min = 0.01
        self.epsilon_decay = 0.995
        self.policy_net = DQN(env.observation_space.shape[0],
                              env.action_space.n)
        self.target_net = DQN(env.observation_space.shape[0],
                              env.action_space.n)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=lr)
        self.memory = ReplayBuffer(10000)

    def select_action(self, state):
        # ε-greedy: explore with probability ε, otherwise act greedily on Q
        if random.random() < self.epsilon:
            return self.env.action_space.sample()
        with torch.no_grad():
            return self.policy_net(state).argmax().item()

    def update_model(self, batch_size):
        if len(self.memory) < batch_size:
            return
        # Sample a mini-batch from the replay buffer
        transitions = self.memory.sample(batch_size)
        batch = list(zip(*transitions))
        # Convert to tensors
        state_batch = torch.FloatTensor(batch[0])
        action_batch = torch.LongTensor(batch[1]).unsqueeze(1)
        reward_batch = torch.FloatTensor(batch[2])
        next_state_batch = torch.FloatTensor(batch[3])
        done_batch = torch.FloatTensor(batch[4])
        # Current Q-values for the actions actually taken
        q_values = self.policy_net(state_batch).gather(1, action_batch)
        # Target Q-values from the (periodically synced) target network
        next_q_values = self.target_net(next_state_batch).max(1)[0].detach()
        expected_q = reward_batch + (1 - done_batch) * self.gamma * next_q_values
        # TD error as a mean-squared-error loss
        loss = F.mse_loss(q_values, expected_q.unsqueeze(1))
        # Gradient step on the policy network only
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        # Decay ε toward its floor
        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)
```
4. Training Loop and Results
4.1 Training Loop
```python
def train(env, agent, episodes=500, batch_size=64):
    rewards = []
    for ep in range(episodes):
        # Classic Gym API: reset() returns the observation and step() returns
        # a 4-tuple; Gymnasium instead returns (obs, info) and a 5-tuple
        state = env.reset()
        total_reward = 0
        while True:
            state_tensor = torch.FloatTensor(state)
            action = agent.select_action(state_tensor)
            next_state, reward, done, _ = env.step(action)
            agent.memory.push(state, action, reward, next_state, done)
            agent.update_model(batch_size)
            state = next_state
            total_reward += reward
            if done:
                break
        # Periodically sync the target network with the policy network
        if ep % 10 == 0:
            agent.target_net.load_state_dict(agent.policy_net.state_dict())
        rewards.append(total_reward)
        print(f"Episode {ep+1}/{episodes}, Reward: {total_reward}, Epsilon: {agent.epsilon:.2f}")
    return rewards
```
4.2 Analyzing the Training Results
```python
import gym
import matplotlib.pyplot as plt

env = gym.make('CartPole-v1')
agent = DQNAgent(env)
rewards = train(env, agent)

# Plot the learning curve
plt.plot(rewards)
plt.xlabel('Episode')
plt.ylabel('Total Reward')
plt.title('DQN Training Progress')
plt.show()
```
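To see what the trained policy does without exploration, one option is to roll out a few greedy episodes. The `evaluate` helper below is a minimal sketch (not part of the code above) and assumes the same classic Gym API:

```python
def evaluate(env, agent, episodes=10):
    # Greedy rollouts: disable exploration so the agent always takes argmax Q
    agent.epsilon = 0.0
    scores = []
    for _ in range(episodes):
        state = env.reset()
        total, done = 0, False
        while not done:
            action = agent.select_action(torch.FloatTensor(state))
            state, reward, done, _ = env.step(action)
            total += reward
        scores.append(total)
    return sum(scores) / len(scores)

print("Average greedy reward:", evaluate(env, agent))
```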
4.2.1 Typical Training Curve
A typical run starts near the random-policy baseline (roughly 20 steps per episode) and, as ε decays and the replay buffer fills, climbs toward the CartPole-v1 episode cap of 500.
5. Advanced Improvements
5.1 Double DQN
Double DQN decouples action selection from action evaluation to reduce the overestimation bias of the max operator: the policy network picks the next action, while the target network scores it. Replace the target computation in `update_model` with:
```python
# Policy network selects the action, target network evaluates it
next_actions = self.policy_net(next_state_batch).max(1)[1]
next_q_values = self.target_net(next_state_batch).gather(
    1, next_actions.unsqueeze(1)).squeeze(1).detach()
```
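For context, this is how the target section of `update_model` reads with the change applied (a sketch; the rest of the method stays the same):

```python
# Inside update_model, replacing the original target computation
with torch.no_grad():
    next_actions = self.policy_net(next_state_batch).max(1)[1]   # selection
    next_q_values = self.target_net(next_state_batch).gather(
        1, next_actions.unsqueeze(1)).squeeze(1)                 # evaluation
expected_q = reward_batch + (1 - done_batch) * self.gamma * next_q_values
```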
5.2 Dueling DQN
Dueling DQN changes the network architecture, splitting it into a state-value stream and an advantage stream:
```python
class DuelingDQN(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.feature = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU()
        )
        # Separate heads for the state value V(s) and the advantages A(s, a)
        self.advantage = nn.Linear(128, output_dim)
        self.value = nn.Linear(128, 1)

    def forward(self, x):
        x = self.feature(x)
        advantage = self.advantage(x)
        value = self.value(x)
        # Q(s, a) = V(s) + A(s, a) - mean_a A(s, a), averaged over the action dimension
        return value + advantage - advantage.mean(dim=-1, keepdim=True)
```
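One simple way to try it is to overwrite the agent's networks after construction (a sketch; a cleaner option is to pass the network class into `DQNAgent.__init__`):

```python
agent = DQNAgent(env)
obs_dim, n_actions = env.observation_space.shape[0], env.action_space.n
agent.policy_net = DuelingDQN(obs_dim, n_actions)
agent.target_net = DuelingDQN(obs_dim, n_actions)
agent.target_net.load_state_dict(agent.policy_net.state_dict())
# Re-create the optimizer so it tracks the new network's parameters
agent.optimizer = optim.Adam(agent.policy_net.parameters(), lr=1e-3)
```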
5.3 Performance Comparison

| Method | Average Reward | Episodes to Converge | Stability |
|---|---|---|---|
| Vanilla DQN | 195 | ~300 | Medium |
| Double DQN | 200 | ~250 | High |
| Dueling DQN | 200+ | ~200 | Very high |
6. FAQ
Q: Why is a target network needed?
- It prevents rapidly shifting Q-value estimates from destabilizing training
- It provides a relatively fixed target to learn against
Q: How should the ε decay rate be chosen?
- Keep exploration high at the start (ε = 1.0)
- Decay gradually down to a floor (ε_min = 0.01)
- Typical decay rates: 0.995 to 0.999 (see the quick calculation below)
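To get a feel for these values: with a multiplicative decay d applied once per model update, ε first reaches the floor after about log(ε_min/ε_0)/log(d) updates (a quick check, assuming the per-update decay used by the agent above):

```python
import math

def updates_to_floor(eps_start=1.0, eps_min=0.01, decay=0.995):
    # Number of multiplicative decay steps until epsilon first hits eps_min
    return math.ceil(math.log(eps_min / eps_start) / math.log(decay))

print(updates_to_floor(decay=0.995))  # ~919 updates
print(updates_to_floor(decay=0.999))  # ~4603 updates
```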
Q: How can sparse rewards be handled?
- Use Prioritized Experience Replay
- Add an intrinsic curiosity (exploration bonus) module
- Redesign or shape the reward function
Appendix: Core Mathematical Derivations
Bellman Optimality Equation
$$Q^*(s,a) = \mathbb{E}\left[r + \gamma \max_{a'} Q^*(s',a') \,\middle|\, s,a\right]$$
Loss Function Derivation
$$\mathcal{L}(\theta) = \mathbb{E}_{(s,a,r,s') \sim D}\left[\left(r + \gamma \max_{a'} Q_{\text{target}}(s',a') - Q(s,a;\theta)\right)^2\right]$$
Gradient Update Rule
$$\theta \leftarrow \theta - \alpha \nabla_\theta \mathcal{L}(\theta)$$
Best-practice tips:
- Use frame stacking to handle partial observability (see the sketch below)
- Save model checkpoints regularly
- Monitor training with W&B or TensorBoard
- Experiment with different network architectures (CNN, LSTM, etc.)
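As a concrete illustration of the first tip, a minimal frame-stacking helper might look like this (a sketch, not part of the original code; it concatenates the last k observations into one state vector):

```python
from collections import deque
import numpy as np

class FrameStack:
    """Concatenate the most recent k observations into a single state vector."""
    def __init__(self, k=4):
        self.k = k
        self.frames = deque(maxlen=k)

    def reset(self, obs):
        # Seed the stack by repeating the first observation
        for _ in range(self.k):
            self.frames.append(obs)
        return np.concatenate(self.frames)

    def step(self, obs):
        self.frames.append(obs)
        return np.concatenate(self.frames)
```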
The complete code, including a visualization front end and the advanced variants, is available in the accompanying GitHub repository. With some hyperparameter tuning, the same approach transfers to more complex environments such as Atari!