强化学习:DQN玩转CartPole游戏

强化学习:DQN玩转CartPole游戏

1. CartPole环境与强化学习基础

1.1 环境介绍

CartPole是OpenAI Gym中的经典控制问题,目标是通过左右移动小车保持杆子竖直:

graph TD A[状态空间] --> B[车位置 -2.4~2.4] A --> C[车速度 -∞~∞] A --> D[杆角度 -41.8°~41.8°] A --> E[杆角速度 -∞~∞] F[动作空间] --> G[向左推 0] F --> H[向右推 1] style A fill:#9f9,stroke:#333 style F fill:#f99,stroke:#333

1.2 强化学习基本概念

  • 状态(State) : <math xmlns="http://www.w3.org/1998/Math/MathML"> s t ∈ R 4 s_t \in \mathbb{R}^4 </math>st∈R4
  • 动作(Action) : <math xmlns="http://www.w3.org/1998/Math/MathML"> a t ∈ { 0 , 1 } a_t \in \{0, 1\} </math>at∈{0,1}
  • 奖励(Reward):每步存活获得+1
  • 目标 :最大化累积奖励 <math xmlns="http://www.w3.org/1998/Math/MathML"> ∑ t = 0 T γ t r t \sum_{t=0}^T \gamma^t r_t </math>∑t=0Tγtrt

2. DQN算法原理

2.1 Q-Learning更新公式

<math xmlns="http://www.w3.org/1998/Math/MathML"> Q ( s t , a t ) ← Q ( s t , a t ) + α [ r t + 1 + γ max ⁡ a Q ( s t + 1 , a ) − Q ( s t , a t ) ] Q(s_t,a_t) \leftarrow Q(s_t,a_t) + \alpha[r_{t+1} + \gamma \max_a Q(s_{t+1},a) - Q(s_t,a_t)] </math>Q(st,at)←Q(st,at)+α[rt+1+γmaxaQ(st+1,a)−Q(st,at)]

2.2 深度Q网络改进

graph LR A[传统Q-Learning] --> B[状态空间爆炸] B --> C[深度网络拟合] C --> D[经验回放] D --> E[目标网络] style C fill:#99f,stroke:#333 style E fill:#99f,stroke:#333
2.2.1 关键技术组件
  1. 经验回放(Experience Replay):打破数据相关性
  2. 目标网络(Target Network):稳定训练目标
  3. ε-贪婪策略:平衡探索与利用

3. PyTorch实现DQN

3.1 Q网络定义

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim

class DQN(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(DQN, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, output_dim)
    
    def forward(self, x):
        return self.fc(x)

3.2 经验回放缓冲区

python 复制代码
from collections import deque
import random

class ReplayBuffer:
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)
    
    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))
    
    def sample(self, batch_size):
        return random.sample(self.buffer, batch_size)
    
    def __len__(self):
        return len(self.buffer)

3.3 DQN智能体实现

python 复制代码
class DQNAgent:
    def __init__(self, env, gamma=0.99, lr=1e-3):
        self.env = env
        self.gamma = gamma
        self.epsilon = 1.0
        self.epsilon_min = 0.01
        self.epsilon_decay = 0.995
        
        self.policy_net = DQN(env.observation_space.shape[0], 
                             env.action_space.n)
        self.target_net = DQN(env.observation_space.shape[0],
                             env.action_space.n)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=lr)
        self.memory = ReplayBuffer(10000)
    
    def select_action(self, state):
        if random.random() < self.epsilon:
            return self.env.action_space.sample()
        with torch.no_grad():
            return self.policy_net(state).argmax().item()
    
    def update_model(self, batch_size):
        if len(self.memory) < batch_size:
            return
        
        # 从缓冲区采样
        transitions = self.memory.sample(batch_size)
        batch = list(zip(*transitions))
        
        # 转换为张量
        state_batch = torch.FloatTensor(batch[0])
        action_batch = torch.LongTensor(batch[1]).unsqueeze(1)
        reward_batch = torch.FloatTensor(batch[2])
        next_state_batch = torch.FloatTensor(batch[3])
        done_batch = torch.FloatTensor(batch[4])
        
        # 计算当前Q值
        q_values = self.policy_net(state_batch).gather(1, action_batch)
        
        # 计算目标Q值
        next_q_values = self.target_net(next_state_batch).max(1)[0].detach()
        expected_q = reward_batch + (1 - done_batch) * self.gamma * next_q_values
        
        # 计算损失
        loss = F.mse_loss(q_values, expected_q.unsqueeze(1))
        
        # 优化模型
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        
        # 更新ε
        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

4. 训练流程与结果

4.1 训练循环

python 复制代码
def train(env, agent, episodes=500, batch_size=64):
    rewards = []
    for ep in range(episodes):
        state = env.reset()
        total_reward = 0
        
        while True:
            state_tensor = torch.FloatTensor(state)
            action = agent.select_action(state_tensor)
            
            next_state, reward, done, _ = env.step(action)
            agent.memory.push(state, action, reward, next_state, done)
            
            agent.update_model(batch_size)
            
            state = next_state
            total_reward += reward
            
            if done:
                break
        
        # 更新目标网络
        if ep % 10 == 0:
            agent.target_net.load_state_dict(agent.policy_net.state_dict())
        
        rewards.append(total_reward)
        print(f"Episode {ep+1}/{episodes}, Reward: {total_reward}, Epsilon: {agent.epsilon:.2f}")
    
    return rewards

4.2 训练结果分析

python 复制代码
import matplotlib.pyplot as plt

env = gym.make('CartPole-v1')
agent = DQNAgent(env)
rewards = train(env, agent)

# 绘制学习曲线
plt.plot(rewards)
plt.xlabel('Episode')
plt.ylabel('Total Reward')
plt.title('DQN Training Progress')
plt.show()
4.2.1 典型训练曲线
graph LR A[初始随机探索] --> B[逐渐稳定] B --> C[达到最大奖励] style A fill:#f99,stroke:#333 style C fill:#9f9,stroke:#333

5. 高级改进技巧

5.1 Double DQN

修改目标Q值计算:

python 复制代码
next_actions = self.policy_net(next_state_batch).max(1)[1]
next_q_values = self.target_net(next_state_batch).gather(1, next_actions.unsqueeze(1)).squeeze(1)

5.2 Dueling DQN

修改网络结构:

python 复制代码
class DuelingDQN(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.feature = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU()
        )
        self.advantage = nn.Linear(128, output_dim)
        self.value = nn.Linear(128, 1)
    
    def forward(self, x):
        x = self.feature(x)
        advantage = self.advantage(x)
        value = self.value(x)
        return value + advantage - advantage.mean()

5.3 性能对比

方法 平均奖励 收敛速度 稳定性
原始DQN 195 300次
Double DQN 200 250次
Dueling DQN 200+ 200次 很高

6. 常见问题解答

Q: 为什么需要目标网络?

  • 防止Q值估计的快速变化导致训练不稳定
  • 提供相对固定的目标值进行学习

Q: 如何选择ε衰减速度?

  • 初始阶段保持较高探索(ε=1.0)
  • 逐步衰减到最小值(ε_min=0.01)
  • 典型衰减率:0.995~0.999

Q: 如何处理稀疏奖励问题?

  • 使用优先经验回放(Prioritized Experience Replay)
  • 引入内在好奇心模块
  • 调整奖励函数设计

附录:核心数学推导

Bellman最优方程

<math xmlns="http://www.w3.org/1998/Math/MathML"> Q ∗ ( s , a ) = E [ r + γ max ⁡ a ′ Q ∗ ( s ′ , a ′ ) ∣ s , a ] Q^*(s,a) = \mathbb{E}[r + \gamma \max_{a'} Q^*(s',a') | s,a] </math>Q∗(s,a)=E[r+γmaxa′Q∗(s′,a′)∣s,a]

损失函数推导

<math xmlns="http://www.w3.org/1998/Math/MathML"> L ( θ ) = E ( s , a , r , s ′ ) ∼ D [ ( r + γ max ⁡ a ′ Q target ( s ′ , a ′ ) − Q ( s , a ; θ ) ) 2 ] \mathcal{L}(\theta) = \mathbb{E}{(s,a,r,s') \sim D}[(r + \gamma \max{a'} Q_{\text{target}}(s',a') - Q(s,a;\theta))^2] </math>L(θ)=E(s,a,r,s′)∼D[(r+γmaxa′Qtarget(s′,a′)−Q(s,a;θ))2]

梯度更新公式

<math xmlns="http://www.w3.org/1998/Math/MathML"> θ ← θ − α ∇ θ L ( θ ) \theta \leftarrow \theta - \alpha \nabla_\theta \mathcal{L}(\theta) </math>θ←θ−α∇θL(θ)


最佳实践建议

  1. 使用Frame Stacking处理部分可观测问题
  2. 定期保存模型检查点
  3. 使用W&B或TensorBoard监控训练
  4. 尝试不同的网络架构(CNN、LSTM等)

完整代码示例可在GitHub仓库获取,包含可视化界面和进阶实现。通过调整超参数,可以轻松迁移到Atari等更复杂环境!

复制代码
相关推荐
Darach1 分钟前
坐姿检测Python实现
人工智能·python
xiaok2 分钟前
LangBot 和消息平台均运行在 Docker 容器中
人工智能
queeny10 分钟前
Datawhale AI夏令营 科大讯飞AI大赛(大模型技术) Task3 心得
人工智能
ToTensor11 分钟前
Paraformer实时语音识别中的碎碎念
人工智能·语音识别·xcode
陈佬昔没带相机17 分钟前
Mac Mini 玩大模型避坑指南
人工智能·mac
重启的码农17 分钟前
llama.cpp 分布式推理介绍(4) RPC 服务器 (rpc_server)
c++·人工智能·神经网络
柠檬味拥抱18 分钟前
不确定环境下AI Agent的贝叶斯信念更新策略研究
人工智能
Nona996121 分钟前
从零开始学AI——13
人工智能
重启的码农21 分钟前
llama.cpp 分布式推理介绍(3) 远程过程调用后端 (RPC Backend)
c++·人工智能·神经网络
顾道长生'21 分钟前
(Arxiv-2025)SkyReels-A2:在视频扩散变换器中组合任意内容
人工智能·计算机视觉·音视频·多模态