Table of Contents

- 0. A little story: let the cart find its own way
- 1. Problem formulation: robot navigation = MDP
- 2. Q-Learning and the ceiling of the Q-table
- 3. What is DQN? Two key techniques
  - 3.1 Experience Replay
  - 3.2 Target Network
- 4. A minimal PyTorch implementation: CartPole as a baseline
- 5. Robot navigation in practice: a Gymnasium GridWorld
- 6. Practical hyperparameter tuning tips
- 7. Summary
- References
0. A little story: let the cart find its own way
Imagine you want to teach a robot to move from point A to point B. The traditional approach is to write a pile of if-else rules: "turn left if there is a wall ahead", "go straight if there is no wall". But real environments are far more complex, and you can never write down every rule.
Reinforcement learning (RL) takes a completely different approach: instead of telling the robot how to act, you only tell it how well it did.
- Each step, it receives a reward or a penalty
- The robot works out a policy on its own
- Eventually it learns the optimal path from A to B
DQN (Deep Q-Network) is one of the classic and remarkably effective RL algorithms. Below we build it step by step in PyTorch and teach a cart to navigate.
1. Problem formulation: robot navigation = MDP
The core model behind reinforcement learning is the Markov decision process (MDP): $(S, A, P, R, \gamma)$.
In the robot navigation setting:
| Symbol | Meaning | Robot navigation example |
|---|---|---|
| $S$ (state space) | Everything the agent currently knows about the world | Obstacle and goal positions sensed in the 5×5 grid around the cart |
| $A$ (action space) | All actions the agent can take | Up, down, left, right |
| $P$ (transition probability) | $P(s' \mid s, a)$: how likely the environment is to move to $s'$ after action $a$ in state $s$ | Which cell the cart ends up in after each move (deterministic here) |
| $R$ (reward function) | How good or bad the behavior was | +10 for reaching the goal, -1 per step, -100 for hitting a wall |
| $\gamma$ (discount factor) | How much future rewards matter | Usually 0.9 ~ 0.99 |
Ultimately we want to learn a policy $\pi(a \mid s)$ that maximizes the expected cumulative reward:

$$G_t = R_t + \gamma R_{t+1} + \gamma^2 R_{t+2} + \cdots$$
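As a quick sanity check of the formula, here is a tiny numerical example (the trajectory and its rewards are made up purely for illustration):

```python
# Illustrative only: a 3-step trajectory with rewards [-1, -1, +10] and gamma = 0.9
gamma = 0.9
rewards = [-1, -1, 10]
G = sum((gamma ** k) * r for k, r in enumerate(rewards))
print(f"{G:.2f}")  # -1 + 0.9*(-1) + 0.81*10 = 6.20
```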
💡 The core idea of Q-Learning: instead of learning the policy directly, learn a value function $Q(s,a)$ that measures "how good it is to take action $a$ in state $s$".
2. Q-Learning and the ceiling of the Q-table
Classic Q-Learning stores $Q(s,a)$ in a Q-table. This works very well when there are finitely many discrete states:
```python
import numpy as np

# Simplified Q-table: number of states x number of actions
n_states = 25   # 5x5 grid
n_actions = 4   # up, down, left, right
Q = np.zeros((n_states, n_actions))

# Q-Learning update rule (TD learning):
# Q(s,a) <- Q(s,a) + alpha * (r + gamma * max_a' Q(s',a') - Q(s,a))
def q_learning_update(Q, s, a, r, s_next, alpha=0.1, gamma=0.9):
    best_next_q = np.max(Q[s_next])
    Q[s, a] = Q[s, a] + alpha * (r + gamma * best_next_q - Q[s, a])
    return Q
```
But once the state space is continuous, the Q-table explodes: a real robot's sensor readings are real numbers, so there are infinitely many states and no table could ever hold them.
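A rough back-of-the-envelope calculation makes the blow-up concrete (the sensor setup below is hypothetical, not taken from the text):

```python
# Hypothetical robot: 8 range sensors, each reading discretized into 100 bins,
# and 4 actions. Even this coarse discretization is already hopeless for a table.
n_sensors, bins_per_sensor, n_actions = 8, 100, 4
table_entries = bins_per_sensor ** n_sensors * n_actions
print(f"{table_entries:.1e} Q-table entries")  # ~4.0e+16
```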
The solution: approximate the Q function with a neural network. That is exactly the core motivation behind DQN.
3. What is DQN? Two key techniques
DQN approximates the Q function with a deep neural network $Q(s,a;\theta)$, and relies on two mechanisms to keep training stable:
3.1 Experience Replay
During RL training, consecutive samples are highly correlated, and training on them in order makes learning oscillate. DQN therefore stores transitions in a replay buffer and samples random mini-batches to break the correlation:
```python
import numpy as np
import random

class ReplayBuffer:
    def __init__(self, capacity=10000):
        self.buffer = []
        self.capacity = capacity

    def push(self, state, action, reward, next_state, done):
        # Store one transition; drop the oldest once capacity is exceeded
        self.buffer.append((state, action, reward, next_state, done))
        if len(self.buffer) > self.capacity:
            self.buffer.pop(0)

    def sample(self, batch_size):
        # Uniformly sample a mini-batch to break temporal correlation
        batch = random.sample(self.buffer, batch_size)
        state, action, reward, next_state, done = zip(*batch)
        return np.array(state), action, reward, np.array(next_state), done
```
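A minimal sketch of how the buffer is used inside a training loop (the transition values below are placeholders):

```python
buffer = ReplayBuffer(capacity=10000)

# Interaction phase: store one transition per environment step
buffer.push(state=[0.1, 0.2], action=1, reward=-0.1,
            next_state=[0.1, 0.4], done=False)

# Learning phase: once enough transitions are stored, sample a decorrelated mini-batch
if len(buffer.buffer) >= 32:
    states, actions, rewards, next_states, dones = buffer.sample(32)
```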
3.2 Target Network
Using the TD target $r + \gamma \max_{a'} Q(s', a')$ directly as the regression target makes training unstable: the target itself keeps shifting because it is produced by the very network that is being trained.
The fix: maintain a separate, "slow" target network $\theta^-$ and sync it only every few steps:
$$y_i = r_i + \gamma \max_{a'} Q(s_i', a'; \theta^-)$$
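A minimal sketch of what this looks like in code, using the `DQN` network class defined in the next section; the full agent in Section 4 does exactly this inside `train_step`:

```python
import copy
import torch

policy_net = DQN(state_dim=4, action_dim=2)   # updated every training step
target_net = copy.deepcopy(policy_net)        # frozen copy used only for TD targets

def compute_td_targets(rewards, next_states, dones, gamma=0.99):
    # Targets come from the target network, so they stay fixed between syncs
    # instead of chasing the network that is currently being trained.
    with torch.no_grad():
        next_q = target_net(next_states).max(dim=1).values
    return rewards + gamma * (1 - dones) * next_q

# Every N training steps, hard-sync the target network to the policy network
target_net.load_state_dict(policy_net.state_dict())
```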
4. A minimal PyTorch implementation: CartPole as a baseline
To see results as quickly as possible, we first validate DQN on the classic CartPole environment (keep the pole balanced), and then transfer it to a custom navigation environment:
```python
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import gymnasium as gym
from collections import deque
import random

# ---------- 1. Q-Network ----------
class DQN(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=128):
        super(DQN, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim)
        )

    def forward(self, x):
        return self.net(x)

# ---------- 2. DQN Agent ----------
class DQNAgent:
    def __init__(self, state_dim, action_dim):
        self.action_dim = action_dim
        self.gamma = 0.99
        self.epsilon = 1.0
        self.epsilon_min = 0.01
        self.epsilon_decay = 0.995
        self.lr = 1e-3
        self.batch_size = 64
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.policy_net = DQN(state_dim, action_dim).to(self.device)
        self.target_net = DQN(state_dim, action_dim).to(self.device)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=self.lr)
        self.replay_buffer = ReplayBuffer(capacity=10000)
        self.update_target_every = 10
        self.step_count = 0

    def select_action(self, state):
        # Epsilon-greedy: explore with probability epsilon, otherwise act greedily
        if random.random() < self.epsilon:
            return random.randint(0, self.action_dim - 1)
        with torch.no_grad():
            state_t = torch.FloatTensor(state).to(self.device)
            return int(self.policy_net(state_t).argmax().item())

    def train_step(self):
        if len(self.replay_buffer.buffer) < self.batch_size:
            return
        states, actions, rewards, next_states, dones = \
            self.replay_buffer.sample(self.batch_size)
        states_t = torch.FloatTensor(states).to(self.device)
        next_states_t = torch.FloatTensor(next_states).to(self.device)
        actions_t = torch.LongTensor(actions).to(self.device)
        rewards_t = torch.FloatTensor(rewards).to(self.device)
        dones_t = torch.FloatTensor(dones).to(self.device)
        # Current Q values for the actions actually taken
        q_vals = self.policy_net(states_t).gather(1, actions_t.unsqueeze(1)).squeeze(1)
        # TD targets (computed with the target network)
        with torch.no_grad():
            next_q_vals = self.target_net(next_states_t).max(1)[0]
            targets = rewards_t + self.gamma * (1 - dones_t) * next_q_vals
        loss = nn.MSELoss()(q_vals, targets)
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0)
        self.optimizer.step()
        # Sync the target network every update_target_every steps
        self.step_count += 1
        if self.step_count % self.update_target_every == 0:
            self.target_net.load_state_dict(self.policy_net.state_dict())
        # Decay epsilon
        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)
        return loss.item()

# ---------- 3. Training loop ----------
def train_cartpole():
    env = gym.make("CartPole-v1", render_mode="rgb_array")
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n
    agent = DQNAgent(state_dim, action_dim)
    episode_rewards = deque(maxlen=20)
    for episode in range(300):
        state, _ = env.reset()
        total_reward = 0
        done = False
        while not done:
            action = agent.select_action(state)
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            agent.replay_buffer.push(state, action, reward, next_state, done)
            state = next_state
            total_reward += reward
            agent.train_step()
        episode_rewards.append(total_reward)
        avg_reward = np.mean(episode_rewards)
        if episode % 20 == 0:
            print(f"Episode {episode:03d} | Reward: {total_reward:5.1f} | "
                  f"Avg(20ep): {avg_reward:5.2f} | Epsilon: {agent.epsilon:.3f}")
        if avg_reward >= 475:
            print(f"\n[SUCCESS] Reached target average reward {avg_reward:.1f} at episode {episode}!")
            break
    env.close()
    return episode_rewards

if __name__ == "__main__":
    rewards = train_cartpole()
```
You should see the reward climb quickly from just a few frames at the start to close to 500 (the maximum for CartPole-v1), which means the DQN has learned to balance the pole.
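If you want to watch the learned policy without exploration noise, a small evaluation loop like the one below works; it is not part of the training code above and simply sets epsilon to 0 so that every action is greedy:

```python
def evaluate(agent, n_episodes=5):
    env = gym.make("CartPole-v1", render_mode="human")
    saved_eps, agent.epsilon = agent.epsilon, 0.0  # act purely greedily during evaluation
    for ep in range(n_episodes):
        state, _ = env.reset()
        total, done = 0.0, False
        while not done:
            action = agent.select_action(state)
            state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            total += reward
        print(f"Eval episode {ep}: reward = {total:.0f}")
    agent.epsilon = saved_eps  # restore exploration for further training
    env.close()
```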
5. Robot navigation in practice: a Gymnasium GridWorld
Now we transfer the same DQN framework to a custom grid-world navigation environment. The goal: starting from an arbitrary cell, learn to walk around the obstacles and reach the target cell.
```python
import numpy as np
import gymnasium as gym

class GridWorldEnv(gym.Env):
    metadata = {"render_modes": ["human"]}

    def __init__(self, size=5):
        super().__init__()
        self.size = size
        # Observation: [agent_x, agent_y, target_x, target_y], normalized to [0, 1)
        self.observation_space = gym.spaces.Box(0.0, 1.0, shape=(4,), dtype=np.float32)
        self.action_space = gym.spaces.Discrete(4)  # 0=up, 1=down, 2=left, 3=right
        self._agent_pos = None
        self._target_pos = None
        self.grid = self._generate_grid()
        self.max_steps = size * size * 2
        self.step_count = 0

    def _generate_grid(self):
        # 0 = free cell, 1 = obstacle
        grid = np.zeros((self.size, self.size), dtype=np.int32)
        grid[1, 2] = 1
        grid[2, 2] = 1
        grid[3, 1] = 1
        return grid

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        # Sample a random free cell for the agent ...
        while True:
            self._agent_pos = np.array([
                self.np_random.integers(0, self.size),
                self.np_random.integers(0, self.size)
            ])
            if self.grid[tuple(self._agent_pos)] == 0:
                break
        # ... and a different free cell for the target
        while True:
            self._target_pos = np.array([
                self.np_random.integers(0, self.size),
                self.np_random.integers(0, self.size)
            ])
            if (self.grid[tuple(self._target_pos)] == 0 and
                    not np.array_equal(self._target_pos, self._agent_pos)):
                break
        self.step_count = 0
        return self._get_obs(), {}

    def step(self, action):
        moves = {0: [-1, 0], 1: [1, 0], 2: [0, -1], 3: [0, 1]}
        delta = np.array(moves[action])
        new_pos = np.clip(self._agent_pos + delta, 0, self.size - 1)
        if self.grid[tuple(new_pos)] == 1:  # hit an obstacle
            reward = -2
            new_pos = self._agent_pos.copy()
        elif np.array_equal(new_pos, self._target_pos):  # reached the target
            reward = 10
        else:  # ordinary move
            reward = -0.1
        self._agent_pos = new_pos
        self.step_count += 1
        terminated = np.array_equal(self._agent_pos, self._target_pos)
        truncated = self.step_count >= self.max_steps  # give up after too many steps
        return self._get_obs(), reward, terminated, truncated, {}

    def _get_obs(self):
        return np.concatenate([self._agent_pos / self.size,
                               self._target_pos / self.size]).astype(np.float32)

    def render(self):
        grid_vis = self.grid.copy().astype(str)
        grid_vis[tuple(self._agent_pos)] = "A"
        grid_vis[tuple(self._target_pos)] = "T"
        grid_vis[grid_vis == "0"] = "."
        grid_vis[grid_vis == "1"] = "#"
        print("\n" + "\n".join(" ".join(row) for row in grid_vis))

# Training
from collections import deque
from dqn_agent import DQNAgent  # reuse the Agent defined above (saved as dqn_agent.py)

def train_gridworld():
    env = GridWorldEnv(size=5)
    state_dim = 4   # [agent_x, agent_y, target_x, target_y]
    action_dim = 4
    agent = DQNAgent(state_dim, action_dim)
    agent.gamma = 0.95
    agent.epsilon = 1.0
    recent = deque(maxlen=50)
    for ep in range(500):
        s, _ = env.reset()
        total_r = 0
        done = False
        while not done:
            a = agent.select_action(s)
            s2, r, terminated, truncated, _ = env.step(a)
            done = terminated or truncated
            agent.replay_buffer.push(s, a, r, s2, done)
            s = s2
            total_r += r
            agent.train_step()
        recent.append(total_r)
        if ep % 50 == 0:
            print(f"Ep {ep:03d} | Avg reward: {np.mean(recent):5.2f} | eps: {agent.epsilon:.3f}")
        if np.mean(recent) > 8.5:
            print(f"[DONE] Learned it! Average reward {np.mean(recent):.2f}")
            break

if __name__ == "__main__":
    train_gridworld()
```
💡 Why does DQN generalize? Even though the inputs are continuous coordinates, the network can still learn patterns like "steer around cells next to obstacles", because similar states pass through the same weights and therefore get similar Q values.
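One way to see this generalization is to freeze a target cell, query the trained network at every free cell, and print the greedy action. A sketch (it assumes a trained `agent` from `train_gridworld` and the `GridWorldEnv` above, and reuses the same observation encoding as `_get_obs`):

```python
import numpy as np
import torch

def print_greedy_policy(agent, env, target=(4, 4)):
    arrows = {0: "^", 1: "v", 2: "<", 3: ">"}  # same action encoding as step()
    target = np.array(target, dtype=np.float32)
    for x in range(env.size):
        row = []
        for y in range(env.size):
            if env.grid[x, y] == 1:
                row.append("#")  # obstacle cell
                continue
            obs = np.concatenate([np.array([x, y]) / env.size,
                                  target / env.size]).astype(np.float32)
            with torch.no_grad():
                q = agent.policy_net(torch.FloatTensor(obs).to(agent.device))
            row.append(arrows[int(q.argmax().item())])
        print(" ".join(row))
```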
6. Practical hyperparameter tuning tips

| Parameter | Default | Tuning advice |
|---|---|---|
| Learning rate | 1e-3 | Start at 1e-3; DQN often works better with something smaller |
| Batch size | 64 | Too small and the gradients are noisy; too large and convergence slows down |
| Epsilon decay | 0.995 | Decay too fast and the agent never learns anything; too slow and exploration is wasted |
| Target network update interval | Every 10 steps | More frequent updates converge faster but are less stable |
| Replay buffer size | 10000 | The more complex the task, the larger the buffer should be |
| Discount factor γ | 0.99 | The larger γ is, the more weight long-term rewards get |
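To make the epsilon-decay trade-off concrete: since `train_step` multiplies epsilon by the decay factor once per call, you can compute how many training steps it takes to reach `epsilon_min = 0.01` for a few decay rates:

```python
import math

for decay in (0.99, 0.995, 0.999):
    steps = math.log(0.01) / math.log(decay)
    print(f"decay={decay}: ~{steps:.0f} train steps until epsilon_min")
# decay=0.99  -> ~458 steps  (exploration ends very early)
# decay=0.995 -> ~919 steps
# decay=0.999 -> ~4603 steps (still exploring long after the task may be solved)
```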
7. Summary
The essence of DQN lies in three components:
- Q-Network: approximate Q(s,a) with a neural network, breaking through the ceiling of the Q-table
- Experience Replay: break up sample correlation to stabilize training
- Target Network: pin the TD target down so the network is no longer "chasing its own tail"
With these three pieces plus PyTorch, you can quickly validate DQN in any Gymnasium environment: from CartPole to robot navigation, the framework stays exactly the same.
References
- Mnih et al. (2015). Human-level control through deep reinforcement learning. Nature.
- Sutton & Barto. Reinforcement Learning: An Introduction, Chapters 6 and 7.
- Official PyTorch DQN tutorial: https://pytorch.org/tutorials/intermediate/reinforcement_q_learning.html