强化学习:DQN玩转CartPole游戏

强化学习:DQN玩转CartPole游戏

1. CartPole环境与强化学习基础

1.1 环境介绍

CartPole是OpenAI Gym中的经典控制问题,目标是通过左右移动小车保持杆子竖直:

graph TD A[状态空间] --> B[车位置 -2.4~2.4] A --> C[车速度 -∞~∞] A --> D[杆角度 -41.8°~41.8°] A --> E[杆角速度 -∞~∞] F[动作空间] --> G[向左推 0] F --> H[向右推 1] style A fill:#9f9,stroke:#333 style F fill:#f99,stroke:#333

1.2 强化学习基本概念

  • 状态(State) s t ∈ R 4 s_t \in \mathbb{R}^4 st∈R4
  • 动作(Action) a t ∈ { 0 , 1 } a_t \in \{0, 1\} at∈{0,1}
  • 奖励(Reward):每步存活获得+1
  • 目标 :最大化累积奖励 ∑ t = 0 T γ t r t \sum_{t=0}^T \gamma^t r_t ∑t=0Tγtrt

2. DQN算法原理

2.1 Q-Learning更新公式

Q ( s t , a t ) ← Q ( s t , a t ) + α r t + 1 + γ max ⁡ a Q ( s t + 1 , a ) − Q ( s t , a t ) Q(s_t,a_t) \leftarrow Q(s_t,a_t) + \alphar_{t+1} + \\gamma \\max_a Q(s_{t+1},a) - Q(s_t,a_t) Q(st,at)←Q(st,at)+αrt+1+γmaxaQ(st+1,a)−Q(st,at)

2.2 深度Q网络改进

graph LR A[传统Q-Learning] --> B[状态空间爆炸] B --> C[深度网络拟合] C --> D[经验回放] D --> E[目标网络] style C fill:#99f,stroke:#333 style E fill:#99f,stroke:#333
2.2.1 关键技术组件
  1. 经验回放(Experience Replay):打破数据相关性
  2. 目标网络(Target Network):稳定训练目标
  3. ε-贪婪策略:平衡探索与利用

3. PyTorch实现DQN

3.1 Q网络定义

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim

class DQN(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(DQN, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, output_dim)
    
    def forward(self, x):
        return self.fc(x)

3.2 经验回放缓冲区

python 复制代码
from collections import deque
import random

class ReplayBuffer:
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)
    
    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))
    
    def sample(self, batch_size):
        return random.sample(self.buffer, batch_size)
    
    def __len__(self):
        return len(self.buffer)

3.3 DQN智能体实现

python 复制代码
class DQNAgent:
    def __init__(self, env, gamma=0.99, lr=1e-3):
        self.env = env
        self.gamma = gamma
        self.epsilon = 1.0
        self.epsilon_min = 0.01
        self.epsilon_decay = 0.995
        
        self.policy_net = DQN(env.observation_space.shape[0], 
                             env.action_space.n)
        self.target_net = DQN(env.observation_space.shape[0],
                             env.action_space.n)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=lr)
        self.memory = ReplayBuffer(10000)
    
    def select_action(self, state):
        if random.random() < self.epsilon:
            return self.env.action_space.sample()
        with torch.no_grad():
            return self.policy_net(state).argmax().item()
    
    def update_model(self, batch_size):
        if len(self.memory) < batch_size:
            return
        
        # 从缓冲区采样
        transitions = self.memory.sample(batch_size)
        batch = list(zip(*transitions))
        
        # 转换为张量
        state_batch = torch.FloatTensor(batch[0])
        action_batch = torch.LongTensor(batch[1]).unsqueeze(1)
        reward_batch = torch.FloatTensor(batch[2])
        next_state_batch = torch.FloatTensor(batch[3])
        done_batch = torch.FloatTensor(batch[4])
        
        # 计算当前Q值
        q_values = self.policy_net(state_batch).gather(1, action_batch)
        
        # 计算目标Q值
        next_q_values = self.target_net(next_state_batch).max(1)[0].detach()
        expected_q = reward_batch + (1 - done_batch) * self.gamma * next_q_values
        
        # 计算损失
        loss = F.mse_loss(q_values, expected_q.unsqueeze(1))
        
        # 优化模型
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        
        # 更新ε
        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

4. 训练流程与结果

4.1 训练循环

python 复制代码
def train(env, agent, episodes=500, batch_size=64):
    rewards = []
    for ep in range(episodes):
        state = env.reset()
        total_reward = 0
        
        while True:
            state_tensor = torch.FloatTensor(state)
            action = agent.select_action(state_tensor)
            
            next_state, reward, done, _ = env.step(action)
            agent.memory.push(state, action, reward, next_state, done)
            
            agent.update_model(batch_size)
            
            state = next_state
            total_reward += reward
            
            if done:
                break
        
        # 更新目标网络
        if ep % 10 == 0:
            agent.target_net.load_state_dict(agent.policy_net.state_dict())
        
        rewards.append(total_reward)
        print(f"Episode {ep+1}/{episodes}, Reward: {total_reward}, Epsilon: {agent.epsilon:.2f}")
    
    return rewards

4.2 训练结果分析

python 复制代码
import matplotlib.pyplot as plt

env = gym.make('CartPole-v1')
agent = DQNAgent(env)
rewards = train(env, agent)

# 绘制学习曲线
plt.plot(rewards)
plt.xlabel('Episode')
plt.ylabel('Total Reward')
plt.title('DQN Training Progress')
plt.show()
4.2.1 典型训练曲线
graph LR A[初始随机探索] --> B[逐渐稳定] B --> C[达到最大奖励] style A fill:#f99,stroke:#333 style C fill:#9f9,stroke:#333

5. 高级改进技巧

5.1 Double DQN

修改目标Q值计算:

python 复制代码
next_actions = self.policy_net(next_state_batch).max(1)[1]
next_q_values = self.target_net(next_state_batch).gather(1, next_actions.unsqueeze(1)).squeeze(1)

5.2 Dueling DQN

修改网络结构:

python 复制代码
class DuelingDQN(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.feature = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU()
        )
        self.advantage = nn.Linear(128, output_dim)
        self.value = nn.Linear(128, 1)
    
    def forward(self, x):
        x = self.feature(x)
        advantage = self.advantage(x)
        value = self.value(x)
        return value + advantage - advantage.mean()

5.3 性能对比

方法 平均奖励 收敛速度 稳定性
原始DQN 195 300次
Double DQN 200 250次
Dueling DQN 200+ 200次 很高

6. 常见问题解答

Q: 为什么需要目标网络?

  • 防止Q值估计的快速变化导致训练不稳定
  • 提供相对固定的目标值进行学习

Q: 如何选择ε衰减速度?

  • 初始阶段保持较高探索(ε=1.0)
  • 逐步衰减到最小值(ε_min=0.01)
  • 典型衰减率:0.995~0.999

Q: 如何处理稀疏奖励问题?

  • 使用优先经验回放(Prioritized Experience Replay)
  • 引入内在好奇心模块
  • 调整奖励函数设计

附录:核心数学推导

Bellman最优方程

Q ∗ ( s , a ) = E r + γ max ⁡ a ′ Q ∗ ( s ′ , a ′ ) ∣ s , a Q^*(s,a) = \mathbb{E}r + \\gamma \\max_{a'} Q\^\*(s',a') \| s,a Q∗(s,a)=Er+γmaxa′Q∗(s′,a′)∣s,a

损失函数推导

L ( θ ) = E ( s , a , r , s ′ ) ∼ D ( r + γ max ⁡ a ′ Q target ( s ′ , a ′ ) − Q ( s , a ; θ ) ) 2 \mathcal{L}(\theta) = \mathbb{E}_{(s,a,r,s') \sim D}(r + \\gamma \\max_{a'} Q_{\\text{target}}(s',a') - Q(s,a;\\theta))\^2 L(θ)=E(s,a,r,s′)∼D(r+γmaxa′Qtarget(s′,a′)−Q(s,a;θ))2

梯度更新公式

θ ← θ − α ∇ θ L ( θ ) \theta \leftarrow \theta - \alpha \nabla_\theta \mathcal{L}(\theta) θ←θ−α∇θL(θ)


最佳实践建议

  1. 使用Frame Stacking处理部分可观测问题
  2. 定期保存模型检查点
  3. 使用W&B或TensorBoard监控训练
  4. 尝试不同的网络架构(CNN、LSTM等)

完整代码示例可在GitHub仓库获取,包含可视化界面和进阶实现。通过调整超参数,可以轻松迁移到Atari等更复杂环境!

复制代码
相关推荐
火山引擎开发者社区6 小时前
技术速递|使用 GitHub Copilot CLI 构建 Emoji 列表生成器
人工智能
codefan※7 小时前
干掉“幻觉“实战:如何构建企业级知识图谱增强 RAG
人工智能·知识图谱
wukangjupingbb7 小时前
传统基于药物 SMILES 序列和蛋白质氨基酸序列的 DTI(Drug-Target Interaction)预测方法的缺陷
人工智能
沪漂阿龙7 小时前
Codex 额度重置周期变化:AI 编程免费试玩时代正在结束
人工智能
TickDB7 小时前
美股行情 API 接入避坑:REST 快照、WebSocket 推送、盘前盘后数据的边界
人工智能·python·websocket·行情数据 api
装不满的克莱因瓶7 小时前
深入理解卷积神经网络(CNN)——从原理到代码实践
人工智能·神经网络·cnn
完成大叔8 小时前
模块二,Agent知识图谱的工具链思考
人工智能
lauo8 小时前
ibbot手机发布:搭载poplang技术 + token节点经济,革新AI手机体验
人工智能·智能手机
咖啡星人k8 小时前
云端开发环境技术架构深度解析:从容器隔离到AI Agent集成
人工智能·架构
袋鼠云数栈8 小时前
从前端到基础设施,ACOS 如何打通企业全链路可观测
运维·前端·人工智能·数据治理·数据智能