PyTorch强化学习实战(12)——Double DQN(DDQN)

PyTorch强化学习实战(12)------Double DQN(DDQN)

    • [0. 前言](#0. 前言)
    • [1. Double DQN 详解](#1. Double DQN 详解)
    • [2. Double DQN 实现](#2. Double DQN 实现)
    • [3. 运行结果](#3. 运行结果)
    • [4. 超参数调优](#4. 超参数调优)
    • 小结
    • 系列链接

0. 前言

自从 DeepMind2015 年提出深度Q网络 (Deep Q-Network, DQN)模型以来,研究人员已经提出了诸多改进方案,通过对基础架构的调整显著提升了原始 DQN 的收敛性、稳定性和样本效率。

2017DeepMindHessel 等人发表了名为 Rainbow: Combining improvements in deep reinforcement learning 的论文,系统性地整合了 DQN 的六大核心改进。仅通过这六种方法的组合,便在 Atari 游戏测试集上达到了当时的最高水平。尽管 2017 年后又涌现出许多新研究不断刷新记录,但 Rainbow 论文中的方法至今仍具有实用价值。本节将深入探讨 Double DQN (DDQN),解决 DQN 对动作价值的高估问题。本节将解析 DDQN 方法的原理、实现方案,并与经典 DQN 进行性能对比。

1. Double DQN 详解

深度Q网络 (Deep Q-Network, DQN)通过结合深度学习与Q-Learning,让智能体能够直接从高维输入(如原始像素)中学习策略,实现了在Atari游戏上超越人类水平的表现。然而,随着研究的深入,发现经典 DQN 存在一个致命的缺陷------价值估计过高。由于 DQN 在更新时使用同一个网络同时进行"选择动作"和"评估价值",这导致了系统性的正向偏差,使得智能体过于乐观地估计某些动作的价值,从而影响学习稳定性和最终策略。

经典 DQN 倾向于高估Q值,这种偏差不仅影响训练效果,还可能导致策略陷入局部最优。其根本原因是贝尔曼方程中的 max 操作,为了解决这一问题,研究者提出对贝尔曼更新公式进行改进。

在经典 DQN 中,目标Q值的计算方式为:

Q ( s t , a t ) = r t + γ m a x a Q ′ ( s t + 1 , a ) Q(s_t,a_t)=r_t+\gamma \underset{a}{max}Q'(s_{t+1},a) Q(st,at)=rt+γamaxQ′(st+1,a)

Q ′ ( s t + 1 , a ) Q'(s_{t+1},a) Q′(st+1,a) 是使用目标网络计算的Q值,该网络的权重每隔 n 步从训练网络同步一次。Double DQN (DDQN) 改用训练网络选择下一状态的动作,而仅从目标网络提取对应动作的Q值。因此改良后的目标Q值公式变为:

Q ( s t , a t ) = r t + γ m a x a Q ′ ( s t + 1 , m a x a Q ( s t + 1 , a ) ) Q(s_t,a_t)=r_t+\gamma \underset{a}{max}Q'(s_{t+1},\underset{a}{max}Q(s_{t+1},a)) Q(st,at)=rt+γamaxQ′(st+1,amaxQ(st+1,a))

研究者通过严格论证表明,这一巧妙调整能彻底解决Q值高估问题,在保持计算效率的同时,显著提升了算法的稳定性和策略质量。

2. Double DQN 实现

Double DQN 的改进核心实现非常简洁,仅需对损失函数进行微调。但我们更进一步,通过实验对比经典 DQNDouble DQN 生成的动作价值。根据论文结论,在相同状态下,经典 DQN 应该对相同的状态预测的价值高于 Double DQN。为此,我们随机选取一组保留状态集,并定期计算评估集中每个状态最优动作的平均价值。

(1)dqn _double.py 中,首先实现损失函数:

python 复制代码
def calc_loss_double_dqn(
        batch: tt.List[lib.experience.ExperienceFirstLast],
        net: nn.Module, tgt_net: nn.Module, gamma: float, device: torch.device):
    states, actions, rewards, dones, next_states = common.unpack_batch(batch)

    states_v = torch.as_tensor(states).to(device)
    actions_v = torch.tensor(actions).to(device)
    rewards_v = torch.tensor(rewards).to(device)
    done_mask = torch.BoolTensor(dones).to(device)

(2) 我们将用此函数替代原有的 common.calc_loss_dqn,二者共享大部分代码,但关键区别在于下一状态Q值的估算逻辑:

python 复制代码
    actions_v = actions_v.unsqueeze(-1)
    state_action_vals = net(states_v).gather(1, actions_v)
    state_action_vals = state_action_vals.squeeze(-1)
    with torch.no_grad():
        next_states_v = torch.as_tensor(next_states).to(device)
        next_state_acts = net(next_states_v).max(1)[1]
        next_state_acts = next_state_acts.unsqueeze(-1)
        next_state_vals = tgt_net(next_states_v).gather(1, next_state_acts).squeeze(-1)
        next_state_vals[done_mask] = 0.0
        exp_sa_vals = next_state_vals.detach() * gamma + rewards_v
    return nn.MSELoss()(state_action_vals, exp_sa_vals)

上述代码片段以略有不同的方式计算损失值。在 Double DQN 中,使用主网络计算下一状态应采取的最佳动作,但该动作对应的Q值则取自目标网络。

这一部分可以通过将 next_states_vstates_v 合并,仅调用一次主网络来实现更快的运算,但会降低代码可读性。

函数其余部分保持不变:屏蔽已完成的回合,并计算网络预测Q值与近似Q值之间的均方误差 (Mean Squared Error, MSE) 损失。

(3) 定义 calc_values_of_states() 函数用于计算保留状态集的价值:

python 复制代码
@torch.no_grad()
def calc_values_of_states(states: np.ndarray, net: nn.Module, device: torch.device):
    mean_vals = []
    for batch in np.array_split(states, 64):
        states_v = torch.tensor(batch).to(device)
        action_values_v = net(states_v)
        best_action_values_v = action_values_v.max(1)[0]
        mean_vals.append(best_action_values_v.mean().item())
    return np.mean(mean_vals)

(4) 将预留的状态数组分割成等份的块,并将每个块输入网络以获取动作价值。然后从这些价值中,我们为每个状态选择价值最大的动作,并计算这些价值的平均值。由于我们在整个训练过程中使用的状态数组是固定的,且这个数组足够大(在本节代码中,存储了 1000 个状态),因此我们可以比较这两个 DQN 变体中该平均值的动态变化。dqn_double.py 文件的其余部分几乎相同,区别在于:一是使用调整后的损失函数,二是保留随机采样的 1000 个状态用于定期评估。这些操作都在 process_batch 函数中完成:

python 复制代码
        if engine.state.iteration % EVAL_EVERY_FRAME == 0:
            eval_states = getattr(engine.state, "eval_states", None)
            if eval_states is None:
                eval_states = buffer.sample(STATES_TO_EVALUATE)
                eval_states = [
                    np.asarray(transition.state)
                    for transition in eval_states
                ]
                eval_states = np.asarray(eval_states)
                engine.state.eval_states = eval_states
            engine.state.metrics["values"] = \
                common.calc_values_of_states(eval_states, net, device)

3. 运行结果

实验数据表明,在使用常规超参数时,Double DQN 对奖励增长动态产生了负面影响。虽然 Double DQN 有时能带来更好的初期学习效果------智能体更快掌握获胜策略,但达到最终奖励阈值所需时间反而延长。我们可以在其他游戏环境中进行验证,或尝试论文中的原始参数配置。

下图展示了 Double DQN 表现略优于经典 DQN 的奖励曲线:

除标准指标外,程序还会输出保留状态集的Q值均值变化,如下图所示:

经典 DQN 存在明显的价值高估现象,其Q值在达到一定水平后开始下降。相比之下,Double DQN 的增长曲线则更为稳定。在本节实验中,Double DQN 对训练时间的影响很小,但这并不意味着 Double DQN 缺乏价值,因为 Pong 是一个简单环境。在更复杂的游戏中,Double DQN 可能展现出更显著的优势。

4. 超参数调优

实验表明,DDQN 的超参数调优不是很成功。经过 30 次试验,最佳学习率和折扣因子组合仍需 412 局游戏才能解决 Pong 游戏(比经典 DQN 的表现更差)。

小结

本文系统介绍了 Double DQN (DDQN) 对经典 DQN 的改进。DDQN 旨在解决 DQN 高估动作价值的缺陷。其核心改进在于将动作选择与价值评估解耦:使用主网络选择下一状态的最优动作,再用目标网络评估该动作的Q值,从而消除正向偏差。实验结果显示,DDQN 的价值估计更稳定,未出现明显高估。

系列链接

PyTorch强化学习实战(1)------强化学习(Reinforcement Learning,RL)详解

PyTorch强化学习实战(2)------强化学习环境库Gymnasium

PyTorch强化学习实战(3)------Gymnasium API扩展功能

PyTorch强化学习实战(4)------PyTorch基础

PyTorch强化学习实战(5)------PyTorch Ignite 事件驱动机制与实践

PyTorch强化学习实战(6)------交叉熵方法详解与实现

PyTorch强化学习实战(7)------表格学习与贝尔曼方程

PyTorch强化学习实战(8)------Q学习详解与实现

PyTorch强化学习实战(9)------深度Q学习

PyTorch强化学习实战(10)------强化学习高级组件

PyTorch强化学习实战(11)------N步DQN(N-step DQN)

相关推荐
winlife_1 小时前
全程用 AI 做一款商业级手游 · EP7 表现层与手感:从“能跑“到“摸起来爽“
java·开发语言·人工智能·unity·ai编程·游戏开发·mcp
一条泥憨鱼1 小时前
Harness Engineering(驾驭工程)零基础入门
网络·人工智能·harness·驾驭工程
AIHR数智引擎1 小时前
AI组织进化论:拆解微软、英伟达、Anthropic与Open AI如何重写组织
人工智能·经验分享·microsoft·职场和发展·aihr
2601_955767421 小时前
2026年iPhone17护眼钢化膜推荐:悟赫德测评
网络·人工智能·iphone·#观复盾护景贴·scinique双护技术
weisian1511 小时前
基础篇--概念原理-27-基座模型是什么?怎么理解?——从原理到实战,一篇讲透
人工智能·深度学习·基座模型
科技侃谈1 小时前
从协议打通到RAG工程化:北泰智能全栈自研智慧档案系统架构深度拆解
人工智能
Geek_Vison1 小时前
政务一网通APP如何引入AI能力,通过一个AI助手就能够调用所有的功能,实现对话即办事
人工智能·ai·小程序·uni-app·小程序容器
LaughingZhu1 小时前
Product Hunt 每日热榜 | 2026-06-08
人工智能·经验分享·深度学习·神经网络·产品运营
ar01231 小时前
AR远程协助在机务维修中的应用
人工智能·ar
星川皆无恙1 小时前
Python豆瓣电影数据分析可视化系统:爬虫采集+数据清洗+可视化大屏完整项目
人工智能·爬虫·python·数据分析