从 DQN 到机器人导航:用深度 Q 网络让小车学会自己走路(含 PyTorch 代码)

文章目录

    • [0. 先说个小故事:让小车自己找路](#0. 先说个小故事:让小车自己找路)
    • [1. 问题建模:机器人导航 = MDP](#1. 问题建模:机器人导航 = MDP)
    • [2. Q-Learning 与 Q 表的天花板](#2. Q-Learning 与 Q 表的天花板)
    • [3. DQN 是什么?两个关键技术](#3. DQN 是什么?两个关键技术)
      • [3.1 Experience Replay(经验回放)](#3.1 Experience Replay(经验回放))
      • [3.2 Target Network(目标网络)](#3.2 Target Network(目标网络))
    • [4. PyTorch 最小实现:CartPole 对比实验](#4. PyTorch 最小实现:CartPole 对比实验)
    • [5. 机器人导航实战:用 Gymnasium GridWorld](#5. 机器人导航实战:用 Gymnasium GridWorld)
    • [6. 超参数调优实战经验](#6. 超参数调优实战经验)
    • [7. 总结](#7. 总结)
    • 参考资料

0. 先说个小故事:让小车自己找路

想象你要教一个机器人从 A 点走到 B 点。传统做法是写一堆 if-else 规则------"前方有墙左转""前方没墙直走"。但现实环境复杂得多,你根本写不完所有规则。

强化学习(RL)的思路完全不同:不告诉机器人怎么做,只告诉它做得好不好

  • 每走一步,给它一个奖励或惩罚
  • 机器人自己摸索出一套策略(policy)
  • 最终学会:从 A 到 B 的最优路径

DQN(Deep Q-Network)就是 RL 中经典又效果拔群的算法之一。下面我们一步步把它用 PyTorch 跑起来,让一个小车学会导航。


1. 问题建模:机器人导航 = MDP

强化学习的核心模型是马尔可夫决策过程(MDP):((S, A, P, R, \gamma))。

在机器人导航场景中:

符号 含义 机器人导航例子
S(状态空间) 智能体当前的全部信息 小车周围 5×5 格子感知到的障碍/目标位置
A(动作空间) 智能体能做的所有动作 上、下、左、右
P(转移概率) (P(s' s,a))
R(奖励函数) 做得好不好 +10 到达目标,-1 每走一步,-100 撞墙
(\gamma)(折扣因子) 未来奖励的重要性 通常取 0.9~0.99

我们最终要学到一个策略(\pi(a|s)),最大化期望累积奖励:

G_t = R_t + \\gamma R_{t+1} + \\gamma\^2 R_{t+2} + \\cdots

💡 Q-Learning 的核心思想:不直接学策略,而是学一个"在状态 s 下做动作 a 有多好"的价值函数 (Q(s,a))。


2. Q-Learning 与 Q 表的天花板

经典 Q-Learning 用一张Q 表存储 (Q(s,a))。在离散状态有限时这很有效:

python 复制代码
import numpy as np

# 简化的 Q 表:状态数 × 动作数
n_states = 25   # 5×5 格子
n_actions = 4   # 上、下、左、右

Q = np.zeros((n_states, n_actions))

# Q-Learning 更新公式(TD 学习)
# Q(s,a) <- Q(s,a) + alpha * (r + gamma * max_a' Q(s',a') - Q(s,a))
def q_learning_update(Q, s, a, r, s_next, alpha=0.1, gamma=0.9):
    best_next_q = np.max(Q[s_next])
    Q[s, a] = Q[s, a] + alpha * (r + gamma * best_next_q - Q[s, a])
    return Q

状态空间连续时 Q 表就爆炸了------真实��器人的传感器读数是实数,状态数无限,Q 表根本存不下。

解决方案:用神经网络近似 Q 函数:这就是 DQN 的核心动机。


3. DQN 是什么?两个关键技术

DQN 用深度神经网络 (Q(s,a;\theta)) 近似 Q 函数,核心靠两个机制稳定训练:

3.1 Experience Replay(经验回放)

训练 RL 时相邻样本高度相关,直接顺序训练会震荡。DQN 把数据存入回放缓冲区,随机小批量采样打相关性:

python 复制代码
import numpy as np
import random

class ReplayBuffer:
    def __init__(self, capacity=10000):
        self.buffer = []
        self.capacity = capacity

    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))
        if len(self.buffer) > self.capacity:
            self.buffer.pop(0)

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        state, action, reward, next_state, done = zip(*batch)
        return np.array(state), action, reward, np.array(next_state), done

3.2 Target Network(目标网络)

直接用 TD target (r + \gamma \max_{a'} Q(s', a')) 做回归目标会导致训练不稳定------目标本身在被训练的网络参数影响下不断变化。

解法:单独维护一个"慢速"的目标网络(\theta^-),每隔若干步才同步一次:

y_i = r_i + \\gamma \\max_{a'} Q(s_i', a'; \\theta\^-)


4. PyTorch 最小实现:CartPole 对比实验

为了最快看到效果,先用经典 CartPole 环境(控制杆保持平衡)验证 DQN,再迁移到自定义导航环境:

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import gymnasium as gym
from collections import deque
import random

# ---------- 1. Q-Network ----------
class DQN(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=128):
        super(DQN, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim)
        )

    def forward(self, x):
        return self.net(x)


# ---------- 2. DQN Agent ----------
class DQNAgent:
    def __init__(self, state_dim, action_dim):
        self.action_dim = action_dim
        self.gamma = 0.99
        self.epsilon = 1.0
        self.epsilon_min = 0.01
        self.epsilon_decay = 0.995
        self.lr = 1e-3
        self.batch_size = 64

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.policy_net = DQN(state_dim, action_dim).to(self.device)
        self.target_net = DQN(state_dim, action_dim).to(self.device)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()

        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=self.lr)
        self.replay_buffer = ReplayBuffer(capacity=10000)
        self.update_target_every = 10
        self.step_count = 0

    def select_action(self, state):
        if random.random() < self.epsilon:
            return random.randint(0, self.action_dim - 1)
        with torch.no_grad():
            state_t = torch.FloatTensor(state).to(self.device)
            return int(self.policy_net(state_t).argmax().item())


    def train_step(self):
        if len(self.replay_buffer.buffer) < self.batch_size:
            return

        states, actions, rewards, next_states, dones = \
            self.replay_buffer.sample(self.batch_size)

        states_t = torch.FloatTensor(states).to(self.device)
        next_states_t = torch.FloatTensor(next_states).to(self.device)
        actions_t = torch.LongTensor(actions).to(self.device)
        rewards_t = torch.FloatTensor(rewards).to(self.device)
        dones_t = torch.FloatTensor(dones).to(self.device)

        # current Q
        q_vals = self.policy_net(states_t).gather(1, actions_t.unsqueeze(1)).squeeze(1)

        # target Q(用目标网络)
        with torch.no_grad():
            next_q_vals = self.target_net(next_states_t).max(1)[0]
            targets = rewards_t + self.gamma * (1 - dones_t) * next_q_vals

        loss = nn.MSELoss()(q_vals, targets)
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0)
        self.optimizer.step()


        # 每隔 update_target_every 步同步目标网络
        self.step_count += 1
        if self.step_count % self.update_target_every == 0:
            self.target_net.load_state_dict(self.policy_net.state_dict())

        # epsilon 衰减
        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)
        return loss.item()


# ---------- 3. 训练循环 ----------
def train_cartpole():
    env = gym.make("CartPole-v1", render_mode="rgb_array")
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n

    agent = DQNAgent(state_dim, action_dim)
    episode_rewards = deque(maxlen=20)

    for episode in range(300):
        state, _ = env.reset()
        total_reward = 0
        done = False

        while not done:
            action = agent.select_action(state)
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            agent.replay_buffer.push(state, action, reward, next_state, done)
            state = next_state
            total_reward += reward
            agent.train_step()

        episode_rewards.append(total_reward)
        avg_reward = np.mean(episode_rewards)

        if episode % 20 == 0:
            print(f"Episode {episode:03d} | Reward: {total_reward:5.1f} | "
                  f"Avg(20ep): {avg_reward:5.2f} | Epsilon: {agent.epsilon:.3f}")

        if avg_reward >= 475:
            print(f"\n[SUCCESS] 在第 {episode} 个 episode 达到目标平均奖励 {avg_reward:.1f}!")
            break

    env.close()
    return episode_rewards


if __name__ == "__main__":
    rewards = train_cartpole()

你会看到 reward 从初期几帧迅速攀升到接近 500(CartPole 满分),说明 DQN 学会了平衡杆。


5. 机器人导航实战:用 Gymnasium GridWorld

现在把同样的 DQN 框架迁移到自定义的网格世界导航环境。目标:从任意起点学会绕开障碍物走到目标格子。

python 复制代码
import numpy as np
import gymnasium as gym


class GridWorldEnv(gym.Env):
    metadata = {"render_modes": ["human"]}

    def __init__(self, size=5):
        super().__init__()
        self.size = size
        self.observation_space = gym.spaces.Box(0, size - 1, shape=(2,), dtype=np.float32)
        self.action_space = gym.spaces.Discrete(4)  # 0=up, 1=down, 2=left, 3=right
        self._agent_pos = None
        self._target_pos = None
        self.grid = self._generate_grid()
        self.max_steps = size * size * 2
        self.step_count = 0

    def _generate_grid(self):
        # 0=通路, 1=障碍, 2=目标
        grid = np.zeros((self.size, self.size), dtype=np.int32)
        grid[1, 2] = 1
        grid[2, 2] = 1
        grid[3, 1] = 1
        return grid

    def reset(self, seed=None):
        super().reset(seed=seed)
        while True:
            self._agent_pos = np.array([
                self.np_random.integers(0, self.size),
                self.np_random.integers(0, self.size)
            ])
            if self.grid[tuple(self._agent_pos)] == 0:
                break
        while True:
            self._target_pos = np.array([
                self.np_random.integers(0, self.size),
                self.np_random.integers(0, self.size)
            ])
            if (self.grid[tuple(self._target_pos)] == 0 and
                not np.array_equal(self._target_pos, self._agent_pos)):
                break
        self.step_count = 0
        return self._get_obs(), {}

    def step(self, action):
        moves = {0: [-1, 0], 1: [1, 0], 2: [0, -1], 3: [0, 1]}
        delta = np.array(moves[action])
        new_pos = np.clip(self._agent_pos + delta, 0, self.size - 1)

        if self.grid[tuple(new_pos)] == 1:  # 撞墙
            reward = -2
            new_pos = self._agent_pos.copy()
        elif np.array_equal(new_pos, self._target_pos):  # 到达目标
            reward = 10
        else:  # 正常移动
            reward = -0.1

        self._agent_pos = new_pos
        self.step_count += 1
        done = np.array_equal(self._agent_pos, self._target_pos)
        return self._get_obs(), reward, done, False, {}

    def _get_obs(self):
        return np.concatenate([self._agent_pos / self.size, self._target_pos / self.size]).astype(np.float32)

    def render(self):
        grid_vis = self.grid.copy().astype(str)
        grid_vis[tuple(self._agent_pos)] = "A"
        grid_vis[tuple(self._target_pos)] = "T"
        grid_vis[grid_vis == "0"] = "."
        grid_vis[grid_vis == "1"] = "#"
        print("\n" + "\n".join(" ".join(row) for row in grid_vis))



# 训练
from dqn_agent import DQNAgent  # 复用上面的 Agent

def train_gridworld():
    env = GridWorldEnv(size=5)
    state_dim = 4  # [agent_x, agent_y, target_x, target_y]
    action_dim = 4

    agent = DQNAgent(state_dim, action_dim)
    agent.gamma = 0.95
    agent.epsilon = 1.0

    recent = deque(maxlen=50)
    for ep in range(500):
        s, _ = env.reset()
        total_r = 0
        done = False
        while not done:
            a = agent.select_action(s)
            s2, r, done, _, _ = env.step(a)
            agent.replay_buffer.push(s, a, r, s2, done)
            s = s2
            total_r += r
            agent.train_step()
        recent.append(total_r)
        if ep % 50 == 0:
            print(f"Ep {ep:03d} | Avg reward: {np.mean(recent):5.2f} | eps: {agent.epsilon:.3f}")
        if np.mean(recent) > 8.5:
            print(f"[DONE] 学会了!平均奖励 {np.mean(recent):.2f}")
            break

if __name__ == "__main__":
    train_gridworld()

💡 为什么 DQN 能泛化? 即使输入是连续坐标,网络也能学到"障碍附近要绕开"的模式,因为网络权重对相似状态输出相似 Q 值。


6. 超参数调优实战经验

参数 默认值 调优建议
学习率 1e-3 从 1e-3 开始,DQN 通常用更小
Batch Size 64 太小梯度不稳定,太大收敛慢
Epsilon 衰减 0.995 太快学不到东西,太慢探索效率低
Target Network 更新频率 每 10 步 更新越频繁收敛越快但越不稳定
Replay Buffer 大小 10000 任务越复杂 Buffer 越大越好
Discount Factor γ 0.99 γ 越大越重视长期奖励

7. 总结

DQN 的精髓在三个组件:

  • Q-Network:用神经网络逼近 Q(s,a),突破 Q 表的天花板
  • Experience Replay:打乱样本相关性,稳定训练
  • Target Network:让 TD target 固定住,不再"自己追自己"

有了这三件套 + PyTorch,你可以在任何 Gymnasium 环境下快速验证 DQN------从 CartPole 到机器人导航,框架完全一样。


参考资料

相关推荐
龙文浩_2 小时前
AI深度学习核心机制解析
人工智能·pytorch·深度学习·神经网络
liliangcsdn2 小时前
LLM如何以ReAct Agent方式统计分析去重后数据
数据库·人工智能·全文检索
这张生成的图像能检测吗2 小时前
(论文速读)FD-LLM:将振动信号编码为文本表示来将振动信号与大型语言模型进行对齐
人工智能·深度学习·语言模型·智能制造·故障诊断
圣殿骑士-Khtangc2 小时前
Amazon CodeWhisperer 超详细使用教程:AWS 云原生 AI 编程助手上手指南
人工智能·ai编程·aws·编程助手·codewhisperer
花千树-0102 小时前
IndexTTS2 入门指南:从模型概念到 macOS 安装实战
人工智能·ai·chatgpt·aigc
阿钱真强道2 小时前
01 飞腾 S5000C 服务器环境搭建实战:PyTorch + CUDA + RTX 4090D 安装与验证
pytorch·cuda·aarch64·深度学习环境搭建·飞腾服务器·s5000c·rtx4090d
淡忘旧梦2 小时前
ChatGPT回答白屏
人工智能·chatgpt·代理模式
望百川归海2 小时前
FS-SAM2微调和推理加速
人工智能
lifallen2 小时前
Flink Agent:ActionTask 与可续跑状态机 (Coroutine/Continuation)
java·大数据·人工智能·语言模型·flink