PyTorch强化学习实战:从零实现DQN玩转CartPole

前言

你有没有想过:让程序像人一样,从零开始学会玩游戏?

不需要告诉它规则,不需要写if-else逻辑,只需要告诉它"得分高了就是好,得分低了就是坏",它自己就能摸索出最优策略。

这就是强化学习。

今天,我们用PyTorch实现经典的DQN算法,让AI学会玩CartPole游戏(平衡小车上的杆子)。


一、强化学习是什么?

核心概念

```

观测 (observation)

┌─────────────────────┐

│ │

│ Agent │

│ (智能体) │

│ │

└─────────────────────┘

动作 (action)

┌─────────────────────┐

│ │

│ Environment │

│ (环境) │

│ │

└─────────────────────┘

奖励 (reward) + 新观测

```

· Agent:我们的AI程序

· Environment:游戏环境(CartPole)

· State:当前状态(杆子的角度、小车的速度等)

· Action:动作(向左推、向右推)

· Reward:奖励(杆子保持直立就给+1)

目标

让Agent学会:在什么状态下,采取什么动作,能让累计奖励最大。


二、CartPole 游戏介绍

CartPole是OpenAI Gym里的经典入门环境:

```

┌─────┐

│ │ ← 杆子(要让它保持直立)

└──┬──┘

┌──────┴──────┐

│ 小车 │ ← 我们控制左右移动

└─────────────┘

```

状态空间(4个值):

  1. 小车位置 (-2.4 到 2.4)

  2. 小车速度 (-∞ 到 ∞)

  3. 杆子角度 (-41.8° 到 41.8°)

  4. 杆子角速度 (-∞ 到 ∞)

动作空间(2个):

· 0:向左推

· 1:向右推

奖励:

· 每存活一帧 +1

· 杆子倒下或小车出界 → 游戏结束


三、DQN算法原理

  1. Q-Learning 核心思想

Q值:在状态s下,采取动作a,能获得的未来累计奖励。

```

Q(s, a) = 即时奖励 + 未来奖励的期望

```

最优策略:每次选择Q值最大的动作。

  1. 神经网络的作用

传统Q-Learning用表格存储Q值,但状态空间太大时不行。

我们用神经网络来近似Q函数:

```

输入:状态 (4个数字)

→ [隐藏层] → [隐藏层] →

输出:每个动作的Q值 (2个数字)

```

  1. DQN的两个关键技术

· 经验回放:把过去的经验存起来,随机抽样训练,打破数据相关性

· 目标网络:用另一个网络计算目标值,稳定训练

```

┌─────────────┐ ┌─────────────┐

│ Q网络 │ │ 目标网络 │

│ (实时更新) │ │ (定期同步) │

└──────┬──────┘ └──────┬──────┘

│ │

│ 预测Q值 │ 计算目标Q值

│ │

└───────────────────┘

比较

计算损失

反向传播

```


四、完整代码实现

  1. 安装依赖

```bash

pip install torch gymnasium matplotlib

```

  1. 导入库和超参数设置

```python

import torch

import torch.nn as nn

import torch.optim as optim

import torch.nn.functional as F

import numpy as np

import random

from collections import deque

import gymnasium as gym

import matplotlib.pyplot as plt

超参数

EPISODES = 500 # 训练回合数

BATCH_SIZE = 64 # 批量大小

GAMMA = 0.99 # 折扣因子

EPSILON_START = 1.0 # 初始探索率

EPSILON_END = 0.01 # 最终探索率

EPSILON_DECAY = 0.995 # 探索率衰减

LEARNING_RATE = 0.001 # 学习率

MEMORY_SIZE = 10000 # 经验池大小

TARGET_UPDATE = 10 # 目标网络更新频率

```

  1. 神经网络定义

```python

class DQN(nn.Module):

def init(self, state_size, action_size):

super(DQN, self).init()

self.fc1 = nn.Linear(state_size, 128)

self.fc2 = nn.Linear(128, 128)

self.fc3 = nn.Linear(128, action_size)

def forward(self, x):

x = F.relu(self.fc1(x))

x = F.relu(self.fc2(x))

return self.fc3(x)

```

  1. 经验回放缓冲区

```python

class ReplayBuffer:

def init(self, capacity):

self.buffer = deque(maxlen=capacity)

def push(self, state, action, reward, next_state, done):

self.buffer.append((state, action, reward, next_state, done))

def sample(self, batch_size):

batch = random.sample(self.buffer, batch_size)

state, action, reward, next_state, done = zip(*batch)

return (np.array(state), action, reward,

np.array(next_state), done)

def len(self):

return len(self.buffer)

```

  1. Agent 实现

```python

class DQNAgent:

def init(self, state_size, action_size):

self.state_size = state_size

self.action_size = action_size

self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Q网络和目标网络

self.q_network = DQN(state_size, action_size).to(self.device)

self.target_network = DQN(state_size, action_size).to(self.device)

self.optimizer = optim.Adam(self.q_network.parameters(), lr=LEARNING_RATE)

目标网络初始化为Q网络的副本

self.target_network.load_state_dict(self.q_network.state_dict())

经验回放

self.memory = ReplayBuffer(MEMORY_SIZE)

探索率

self.epsilon = EPSILON_START

def act(self, state):

"""根据状态选择动作(ε-贪心策略)"""

if random.random() < self.epsilon:

return random.randrange(self.action_size)

state = torch.FloatTensor(state).unsqueeze(0).to(self.device)

with torch.no_grad():

q_values = self.q_network(state)

return q_values.argmax().item()

def remember(self, state, action, reward, next_state, done):

"""存储经验"""

self.memory.push(state, action, reward, next_state, done)

def learn(self):

"""从经验池中学习"""

if len(self.memory) < BATCH_SIZE:

return

采样

states, actions, rewards, next_states, dones = self.memory.sample(BATCH_SIZE)

转换为tensor

states = torch.FloatTensor(states).to(self.device)

actions = torch.LongTensor(actions).unsqueeze(1).to(self.device)

rewards = torch.FloatTensor(rewards).unsqueeze(1).to(self.device)

next_states = torch.FloatTensor(next_states).to(self.device)

dones = torch.FloatTensor(dones).unsqueeze(1).to(self.device)

当前Q值

current_q = self.q_network(states).gather(1, actions)

目标Q值

with torch.no_grad():

next_q = self.target_network(next_states).max(1, keepdim=True)[0]

target_q = rewards + (GAMMA * next_q * (1 - dones))

计算损失并更新

loss = F.mse_loss(current_q, target_q)

self.optimizer.zero_grad()

loss.backward()

self.optimizer.step()

def update_epsilon(self):

"""衰减探索率"""

self.epsilon = max(EPSILON_END, self.epsilon * EPSILON_DECAY)

def update_target_network(self):

"""同步目标网络"""

self.target_network.load_state_dict(self.q_network.state_dict())

```

  1. 训练主循环

```python

def train():

创建环境

env = gym.make('CartPole-v1')

state_size = env.observation_space.shape[0] # 4

action_size = env.action_space.n # 2

agent = DQNAgent(state_size, action_size)

scores = []

for episode in range(EPISODES):

state, _ = env.reset()

total_reward = 0

done = False

while not done:

选择动作

action = agent.act(state)

执行动作

next_state, reward, terminated, truncated, _ = env.step(action)

done = terminated or truncated

存储经验

agent.remember(state, action, reward, next_state, done)

学习

agent.learn()

state = next_state

total_reward += reward

更新探索率

agent.update_epsilon()

定期更新目标网络

if episode % TARGET_UPDATE == 0:

agent.update_target_network()

scores.append(total_reward)

打印进度

avg_score = np.mean(scores[-100:]) if len(scores) >= 100 else np.mean(scores)

print(f"Episode {episode}, Score: {total_reward}, "

f"Avg: {avg_score:.2f}, Epsilon: {agent.epsilon:.3f}")

早停:连续100回合平均得分>195就认为解决了

if avg_score >= 195:

print(f"🎉 解决了!第{episode}回合达成")

break

保存模型

torch.save(agent.q_network.state_dict(), "dqn_cartpole.pth")

print("模型已保存")

绘制训练曲线

plt.plot(scores)

plt.xlabel('Episode')

plt.ylabel('Score')

plt.title('DQN Training on CartPole')

plt.show()

env.close()

return agent, scores

```

  1. 测试训练好的模型

```python

def test(agent, episodes=10):

env = gym.make('CartPole-v1', render_mode='human')

agent.epsilon = 0 # 测试时不探索,只利用

for episode in range(episodes):

state, _ = env.reset()

total_reward = 0

done = False

while not done:

action = agent.act(state)

next_state, reward, terminated, truncated, _ = env.step(action)

done = terminated or truncated

state = next_state

total_reward += reward

print(f"Test Episode {episode}, Score: {total_reward}")

env.close()

运行

if name == "main":

agent, scores = train()

test(agent) # 取消注释以观看AI表演

```


五、训练效果分析

训练曲线示例

```

Episode 0, Score: 12, Avg: 12.00, Epsilon: 1.000

Episode 50, Score: 32, Avg: 18.50, Epsilon: 0.700

Episode 100, Score: 78, Avg: 45.20, Epsilon: 0.490

Episode 200, Score: 156, Avg: 120.30, Epsilon: 0.241

Episode 300, Score: 198, Avg: 189.50, Epsilon: 0.118

Episode 350, Score: 500, Avg: 198.20, Epsilon: 0.085

🎉 解决了!第352回合达成

```

常见问题及解决

问题 原因 解决方案

学不会 学习率太大/太小 调整 LEARNING_RATE

不稳定 批量大小太小 增大 BATCH_SIZE

收敛慢 探索率衰减太快 调小 EPSILON_DECAY

不收敛 网络太简单/复杂 调整网络结构


六、改良版本:Double DQN

解决DQN过估计问题:

```python

class DoubleDQNAgent(DQNAgent):

def learn(self):

if len(self.memory) < BATCH_SIZE:

return

states, actions, rewards, next_states, dones = self.memory.sample(BATCH_SIZE)

states = torch.FloatTensor(states).to(self.device)

actions = torch.LongTensor(actions).unsqueeze(1).to(self.device)

rewards = torch.FloatTensor(rewards).unsqueeze(1).to(self.device)

next_states = torch.FloatTensor(next_states).to(self.device)

dones = torch.FloatTensor(dones).unsqueeze(1).to(self.device)

current_q = self.q_network(states).gather(1, actions)

Double DQN:用Q网络选动作,目标网络算Q值

with torch.no_grad():

next_actions = self.q_network(next_states).max(1, keepdim=True)[1]

next_q = self.target_network(next_states).gather(1, next_actions)

target_q = rewards + (GAMMA * next_q * (1 - dones))

loss = F.mse_loss(current_q, target_q)

self.optimizer.zero_grad()

loss.backward()

self.optimizer.step()

```


七、扩展到其他游戏

换了环境只需要改几个地方:

```python

LunarLander(登月着陆器)

env = gym.make('LunarLander-v2')

state_size = 8 # 8个观测值

action_size = 4 # 4个动作

Atari游戏(需要预处理)

env = gym.make('Breakout-v2', render_mode='human')

需要添加帧堆叠、灰度化、缩放等预处理

```


八、总结

通过这篇文章,你学会了:

· 强化学习的核心概念(Agent、Environment、State、Action、Reward)

· DQN算法的原理(神经网络近似Q值 + 经验回放 + 目标网络)

· 从头实现DQN训练CartPole

· 训练曲线分析和问题排查

· Double DQN的改良版本

强化学习是AI领域最有意思的方向之一。看着AI从完全随机乱动,到学会完美平衡杆子,那种感觉真的很奇妙。

下一篇预告:《策略梯度方法:让AI直接输出动作概率》


评论区分享一下你训练了多少回合才让AI学会~

相关推荐
三品吉他手会点灯2 小时前
C语言学习笔记 - 13.C语言简介 - 回顾本讲内容
c语言·笔记·学习
大大杰哥2 小时前
Spring AI 开发笔记:ChatClient 的创建、配置与工具函数注册
人工智能·笔记·spring
再玩一会儿看代码2 小时前
idea中快捷键详细总结整理
java·ide·经验分享·笔记·学习·intellij-idea
破阵子443282 小时前
Premiere(Pr) 下载安装教程(附安装包)
笔记
是上好佳佳佳呀2 小时前
【前端(九)】CSS Transform 2D/3D 变换笔记:分清两个原点,搞懂多重变换
前端·css·笔记
handler0112 小时前
从零实现自动化构建:Linux Makefile 完全指南
linux·c++·笔记·学习·自动化
Hello_Embed13 小时前
嵌入式上位机开发入门(二十六):将 MQTT 测试程序加入 APP 任务
网络·笔记·网络协议·tcp/ip·嵌入式
不会编程的懒洋洋14 小时前
C# Task async/await CancellationToken
笔记·c#·线程·面向对象·task·同步异步
zhangrelay17 小时前
蓝桥云课五分钟-通关自动控制-octave
笔记·学习