强化学习:DQN玩转CartPole游戏

强化学习:DQN玩转CartPole游戏

1. CartPole环境与强化学习基础

1.1 环境介绍

CartPole是OpenAI Gym中的经典控制问题,目标是通过左右移动小车保持杆子竖直:

graph TD A[状态空间] --> B[车位置 -2.4~2.4] A --> C[车速度 -∞~∞] A --> D[杆角度 -41.8°~41.8°] A --> E[杆角速度 -∞~∞] F[动作空间] --> G[向左推 0] F --> H[向右推 1] style A fill:#9f9,stroke:#333 style F fill:#f99,stroke:#333

1.2 强化学习基本概念

  • 状态(State) : <math xmlns="http://www.w3.org/1998/Math/MathML"> s t ∈ R 4 s_t \in \mathbb{R}^4 </math>st∈R4
  • 动作(Action) : <math xmlns="http://www.w3.org/1998/Math/MathML"> a t ∈ { 0 , 1 } a_t \in \{0, 1\} </math>at∈{0,1}
  • 奖励(Reward):每步存活获得+1
  • 目标 :最大化累积奖励 <math xmlns="http://www.w3.org/1998/Math/MathML"> ∑ t = 0 T γ t r t \sum_{t=0}^T \gamma^t r_t </math>∑t=0Tγtrt

2. DQN算法原理

2.1 Q-Learning更新公式

<math xmlns="http://www.w3.org/1998/Math/MathML"> Q ( s t , a t ) ← Q ( s t , a t ) + α [ r t + 1 + γ max ⁡ a Q ( s t + 1 , a ) − Q ( s t , a t ) ] Q(s_t,a_t) \leftarrow Q(s_t,a_t) + \alpha[r_{t+1} + \gamma \max_a Q(s_{t+1},a) - Q(s_t,a_t)] </math>Q(st,at)←Q(st,at)+α[rt+1+γmaxaQ(st+1,a)−Q(st,at)]

2.2 深度Q网络改进

graph LR A[传统Q-Learning] --> B[状态空间爆炸] B --> C[深度网络拟合] C --> D[经验回放] D --> E[目标网络] style C fill:#99f,stroke:#333 style E fill:#99f,stroke:#333
2.2.1 关键技术组件
  1. 经验回放(Experience Replay):打破数据相关性
  2. 目标网络(Target Network):稳定训练目标
  3. ε-贪婪策略:平衡探索与利用

3. PyTorch实现DQN

3.1 Q网络定义

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim

class DQN(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(DQN, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, output_dim)
    
    def forward(self, x):
        return self.fc(x)

3.2 经验回放缓冲区

python 复制代码
from collections import deque
import random

class ReplayBuffer:
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)
    
    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))
    
    def sample(self, batch_size):
        return random.sample(self.buffer, batch_size)
    
    def __len__(self):
        return len(self.buffer)

3.3 DQN智能体实现

python 复制代码
class DQNAgent:
    def __init__(self, env, gamma=0.99, lr=1e-3):
        self.env = env
        self.gamma = gamma
        self.epsilon = 1.0
        self.epsilon_min = 0.01
        self.epsilon_decay = 0.995
        
        self.policy_net = DQN(env.observation_space.shape[0], 
                             env.action_space.n)
        self.target_net = DQN(env.observation_space.shape[0],
                             env.action_space.n)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=lr)
        self.memory = ReplayBuffer(10000)
    
    def select_action(self, state):
        if random.random() < self.epsilon:
            return self.env.action_space.sample()
        with torch.no_grad():
            return self.policy_net(state).argmax().item()
    
    def update_model(self, batch_size):
        if len(self.memory) < batch_size:
            return
        
        # 从缓冲区采样
        transitions = self.memory.sample(batch_size)
        batch = list(zip(*transitions))
        
        # 转换为张量
        state_batch = torch.FloatTensor(batch[0])
        action_batch = torch.LongTensor(batch[1]).unsqueeze(1)
        reward_batch = torch.FloatTensor(batch[2])
        next_state_batch = torch.FloatTensor(batch[3])
        done_batch = torch.FloatTensor(batch[4])
        
        # 计算当前Q值
        q_values = self.policy_net(state_batch).gather(1, action_batch)
        
        # 计算目标Q值
        next_q_values = self.target_net(next_state_batch).max(1)[0].detach()
        expected_q = reward_batch + (1 - done_batch) * self.gamma * next_q_values
        
        # 计算损失
        loss = F.mse_loss(q_values, expected_q.unsqueeze(1))
        
        # 优化模型
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        
        # 更新ε
        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

4. 训练流程与结果

4.1 训练循环

python 复制代码
def train(env, agent, episodes=500, batch_size=64):
    rewards = []
    for ep in range(episodes):
        state = env.reset()
        total_reward = 0
        
        while True:
            state_tensor = torch.FloatTensor(state)
            action = agent.select_action(state_tensor)
            
            next_state, reward, done, _ = env.step(action)
            agent.memory.push(state, action, reward, next_state, done)
            
            agent.update_model(batch_size)
            
            state = next_state
            total_reward += reward
            
            if done:
                break
        
        # 更新目标网络
        if ep % 10 == 0:
            agent.target_net.load_state_dict(agent.policy_net.state_dict())
        
        rewards.append(total_reward)
        print(f"Episode {ep+1}/{episodes}, Reward: {total_reward}, Epsilon: {agent.epsilon:.2f}")
    
    return rewards

4.2 训练结果分析

python 复制代码
import matplotlib.pyplot as plt

env = gym.make('CartPole-v1')
agent = DQNAgent(env)
rewards = train(env, agent)

# 绘制学习曲线
plt.plot(rewards)
plt.xlabel('Episode')
plt.ylabel('Total Reward')
plt.title('DQN Training Progress')
plt.show()
4.2.1 典型训练曲线
graph LR A[初始随机探索] --> B[逐渐稳定] B --> C[达到最大奖励] style A fill:#f99,stroke:#333 style C fill:#9f9,stroke:#333

5. 高级改进技巧

5.1 Double DQN

修改目标Q值计算:

python 复制代码
next_actions = self.policy_net(next_state_batch).max(1)[1]
next_q_values = self.target_net(next_state_batch).gather(1, next_actions.unsqueeze(1)).squeeze(1)

5.2 Dueling DQN

修改网络结构:

python 复制代码
class DuelingDQN(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.feature = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU()
        )
        self.advantage = nn.Linear(128, output_dim)
        self.value = nn.Linear(128, 1)
    
    def forward(self, x):
        x = self.feature(x)
        advantage = self.advantage(x)
        value = self.value(x)
        return value + advantage - advantage.mean()

5.3 性能对比

方法 平均奖励 收敛速度 稳定性
原始DQN 195 300次
Double DQN 200 250次
Dueling DQN 200+ 200次 很高

6. 常见问题解答

Q: 为什么需要目标网络?

  • 防止Q值估计的快速变化导致训练不稳定
  • 提供相对固定的目标值进行学习

Q: 如何选择ε衰减速度?

  • 初始阶段保持较高探索(ε=1.0)
  • 逐步衰减到最小值(ε_min=0.01)
  • 典型衰减率:0.995~0.999

Q: 如何处理稀疏奖励问题?

  • 使用优先经验回放(Prioritized Experience Replay)
  • 引入内在好奇心模块
  • 调整奖励函数设计

附录:核心数学推导

Bellman最优方程

<math xmlns="http://www.w3.org/1998/Math/MathML"> Q ∗ ( s , a ) = E [ r + γ max ⁡ a ′ Q ∗ ( s ′ , a ′ ) ∣ s , a ] Q^*(s,a) = \mathbb{E}[r + \gamma \max_{a'} Q^*(s',a') | s,a] </math>Q∗(s,a)=E[r+γmaxa′Q∗(s′,a′)∣s,a]

损失函数推导

<math xmlns="http://www.w3.org/1998/Math/MathML"> L ( θ ) = E ( s , a , r , s ′ ) ∼ D [ ( r + γ max ⁡ a ′ Q target ( s ′ , a ′ ) − Q ( s , a ; θ ) ) 2 ] \mathcal{L}(\theta) = \mathbb{E}{(s,a,r,s') \sim D}[(r + \gamma \max{a'} Q_{\text{target}}(s',a') - Q(s,a;\theta))^2] </math>L(θ)=E(s,a,r,s′)∼D[(r+γmaxa′Qtarget(s′,a′)−Q(s,a;θ))2]

梯度更新公式

<math xmlns="http://www.w3.org/1998/Math/MathML"> θ ← θ − α ∇ θ L ( θ ) \theta \leftarrow \theta - \alpha \nabla_\theta \mathcal{L}(\theta) </math>θ←θ−α∇θL(θ)


最佳实践建议

  1. 使用Frame Stacking处理部分可观测问题
  2. 定期保存模型检查点
  3. 使用W&B或TensorBoard监控训练
  4. 尝试不同的网络架构(CNN、LSTM等)

完整代码示例可在GitHub仓库获取,包含可视化界面和进阶实现。通过调整超参数,可以轻松迁移到Atari等更复杂环境!

复制代码
相关推荐
LitchiCheng17 分钟前
DQN 玩 2048 实战|第二期!设计 ε 贪心策略神经网络,简单训练一下吧!
人工智能·深度学习·神经网络
tortorish26 分钟前
PyTorch中Batch Normalization1d的实现与手动验证
人工智能·pytorch·batch
wwwzhouhui33 分钟前
dify案例分享-儿童故事绘本语音播报视频工作流
人工智能·音视频·语音识别
南太湖小蚂蚁1 小时前
自然语言处理入门4——RNN
人工智能·rnn·深度学习·自然语言处理
Ronin-Lotus1 小时前
深度学习篇---分类任务图像预处理&模型训练
人工智能·python·深度学习·机器学习·分类·模型训练·分类任务
四口鲸鱼爱吃盐1 小时前
CVPR2025 | TAPT:用于视觉语言模型鲁棒推理的测试时对抗提示调整
网络·人工智能·深度学习·机器学习·语言模型·自然语言处理·对抗样本
沈二到不行1 小时前
多头注意力&位置编码:完型填空任务
人工智能·后端·deepseek
朱剑君2 小时前
机器学习概要
人工智能·机器学习
千亿的星空2 小时前
部队仓储信息化手段建设:基于RFID、IWMS、RCS三大技术的仓储物流全链路效能优化方案
大数据·人工智能·信息可视化·信息与通信·数据库开发·可信计算技术
猫先生Mr.Mao3 小时前
2025年2月AGI技术月评|重构创作边界:从视频生成革命到多模态生态的全面爆发
人工智能·大模型·aigc·agi·多模态·行业洞察