强化学习笔记之【TD3算法】

强化学习笔记之【TD3算法】


前言:

本文为强化学习笔记第三篇,第一篇讲的是Q-learning和DQN,第二篇讲的是DDPG

TD3就是比DDPG多了两个网络用来防止过估计,然后引入了延迟更新机制,就没了,还挺简单的

本文初编辑于2024.10.6

CSDN主页:https://blog.csdn.net/rvdgdsva

博客园主页:https://www.cnblogs.com/hassle

博客园本文链接:

三渲二(Celshading/toon shading) = Two-dimensional 3D = TD3,这很合理


首先,我们需要明确,Q-learning 算法发展成DQN 算法,DQN 算法发展成为DDPG 算法,而DDPG 算法发展成TD3(Twin Delayed DDPG)算法

这里有两个问题

一、DQN算法中存在过估计

DQN的目标是优化Q值函数。DQN采用离线学习的方式,通过计算Q值的目标来更新其网络参数。DQN的目标公式可以表示为:
y = r + γ × m a x ( Q ( s ′ , a ′ ; θ ′ ) ) y = r + γ \times max(Q(s', a'; θ')) y=r+γ×max(Q(s′,a′;θ′))

在这里,r是当前动作的价值,**max(Q(s', a'; θ'))**是后续动作中价值最大的动作的价值。

而往往在实际运行中没有办法达成最大价值这种理想情况

  • r 是当前时刻的奖励。
  • γ 是折扣因子,用于平衡未来奖励的重要性。
  • s' 是下一个状态,a' 是下一个动作。
  • Q(s', a'; θ') 是目标网络的 Q 值。
  • θ' 是目标网络的参数。

于是TD3算法用了一个讨巧但是有效的方式,搞两个网络,分别运行,取相对小价值的作为输出
y = r + γ × m i n ( Q 1 ( s ′ , π ( s ′ ) ) , Q 2 ( s ′ , π ( s ′ ) ) ) y = r + γ \times min(Q₁(s', π(s')), Q₂(s', π(s'))) y=r+γ×min(Q1(s′,π(s′)),Q2(s′,π(s′)))

  • r 是奖励。
  • γ 是折扣因子。
  • Q₁(s', π(s'))Q₂(s', π(s')) 是两个目标 Q 网络的 Q 值。
  • π(s') 是目标策略网络生成的下一个动作。
  • θ₁'θ₂' 是两个目标 Q 网络的参数,π 是目标策略网络的参数。

TD3 的核心改进在于 使用两个Q网络取最小值 来计算目标Q值,以减少过高估计问题(overestimation bias)。


二、DDPG算法存在局部最优

TD3 中的Delay 机制,主要体现在两个方面:延迟更新策略网络延迟更新目标Q网络 。这种延迟机制是对经典 DDPG(Deep Deterministic Policy Gradient)算法的一项重要改进

在 DDPG 中,策略网络(Actor)和Q网络(Critic)是交替更新的,这意味着策略网络在每次迭代时都能快速学习新的动作。然而,频繁更新策略网络会使它容易陷入局部最优,难以找到全局最优策略,尤其是在Q值估计不准确的情况下。

2.1 策略网络的优化原理

在强化学习中,策略网络(Actor)决定智能体在每个状态下应该采取的动作,目的是最大化未来的累积奖励。策略网络通过不断调整,学会选择能带来更高回报的动作。

当策略网络更新频繁时,它会在每一轮训练中快速调整自己的权重,期望基于当前的 Q 值(即 Critic 网络评估的动作价值)尽可能找到最优动作。

2.2 局部最优

局部最优是指在当前的策略空间中,智能体找到了一种动作选择,这种选择看起来已经是最好的(最大化了当前的 Q 值),但从更大的全局视角来看,这其实不是最优解。也就是说,策略可能在某个小范围内找到了一个"局部最佳解",但离真正的全局最优解还有差距。

2.3 频繁更新为什么导致局部最优

当策略网络更新频繁时,可能会过于迅速地朝着当前 Q 网络评估出的"最佳方向"移动,然而,Q 网络本身的估值在早期阶段可能还不够准确或稳定。

  • Q值的不稳定性:在强化学习的过程中,Q值估计在训练早期或数据不足时,常常会有误差。如果策略网络过于依赖这些不稳定的 Q 值进行快速更新,它会基于这些错误的估计来选择看似"最优"的动作,这可能让策略陷入一个局部最优,而没有足够时间探索更好的全局解。
  • 策略更新速度过快:频繁更新策略网络,会使其迅速调整到一个"看似最优"的策略上。但由于更新速度太快,智能体可能还没有足够的时间探索整个策略空间,因此很容易错过更优的动作选择。
  • 缺乏探索:策略网络在频繁更新过程中,可能对当前评估为"好的"动作过度偏好,而忽略了其他动作的探索,这样可能导致陷入局部最优,而未能发现更优的全局策略。
2.4 TD3中的改进

为了避免这种现象,TD3引入了延迟更新策略网络的机制。也就是说,在 Q 网络(Critic)经过多次更新后,策略网络(Actor)才会更新。通过这种延迟更新,策略网络能够基于更准确、更稳定的 Q 值进行更新,从而减少策略过快收敛到局部最优的风险。


三、DDPG算法和TD3算法代码对比

下面对比 DDPG 和 TD3 的代码或伪代码,特别是延迟策略更新双Q网络的改进部分

另外,TD3target_actor的输出后面加了个噪声

3.1 DDPG算法
python 复制代码
# DDPG
for each iteration:
    # 采样经验数据 (state, action, reward, next_state) from replay buffer
    batch = replay_buffer.sample()
    
    # Critic 更新 (Q网络)
    next_action = target_actor(next_state)
    target_q_value = reward + gamma * target_critic(next_state, next_action)
    critic_loss = mse(critic(state, action), target_q_value)
    critic_optimizer.zero_grad()
    critic_loss.backward()
    critic_optimizer.step()
    
    # Actor 更新 (策略网络)
    actor_loss = -critic(state, actor(state)).mean()  
    actor_optimizer.zero_grad()
    actor_loss.backward()
    actor_optimizer.step()

    # 软更新目标网络参数
    for param, target_param in zip(critic.parameters(), target_critic.parameters()):
        target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
    for param, target_param in zip(actor.parameters(), target_actor.parameters()):
        target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
3.2 TD3算法
python 复制代码
# TD3
for each iteration:
    # 采样经验数据 (state, action, reward, next_state) from replay buffer
    batch = replay_buffer.sample()

    # Critic 更新 (双 Q 网络)
    next_action = target_actor(next_state) + clip(noise, -noise_clip, noise_clip)  # 引入噪声
    target_q_value1 = target_critic1(next_state, next_action)
    target_q_value2 = target_critic2(next_state, next_action)
    target_q_value = reward + gamma * min(target_q_value1, target_q_value2)  # 取最小值

    critic1_loss = mse(critic1(state, action), target_q_value)
    critic_optimizer1.zero_grad()
    critic1_loss.backward()
    critic_optimizer1.step()

    critic2_loss = mse(critic2(state, action), target_q_value)
    critic_optimizer2.zero_grad()
    critic2_loss.backward()
    critic_optimizer2.step()

    # 延迟更新 Actor 网络
    if iteration % policy_delay == 0:  # 延迟更新策略网络
        actor_loss = -critic1(state, actor(state)).mean()  
        actor_optimizer.zero_grad()
        actor_loss.backward()
        actor_optimizer.step()

        # 软更新目标网络参数
        for param, target_param in zip(critic1.parameters(), target_critic1.parameters()):
            target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
        for param, target_param in zip(critic2.parameters(), target_critic2.parameters()):
            target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
        for param, target_param in zip(actor.parameters(), target_actor.parameters()):
            target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)

四、总结:DDPG 与 TD3 的关键区别

4.1 双 Q 网络 (Critic)

在 DDPG 中只有一个 Q目标网络,而 TD3 中使用两个 Q 目标网络,主要的改进是通过计算两个 Q 网络的最小值来缓解 Q 值的过估计问题。

4.2 延迟更新策略网络 (Actor)

TD3 的策略网络并不在每次 Q 网络更新后都立即更新,而是隔几次才更新一次。policy_delay 决定了每几次更新 Critic 后更新 Actor 的频率。这样做的目的是为了让 Q 网络先稳定下来,防止 Actor 网络基于不准确的 Q 值进行优化。

4.3 目标策略平滑 (Target Policy Smoothing)

TD3 在生成目标动作时加入噪声,并且通过裁剪噪声来避免过大的扰动。这样可以增加策略的探索性,减少由于策略确定性导致的高估问题。


五、TD3和DDPG网络对比

5.1 DDPG 中的四个网络
  • Actor 网络(策略网络)
    • 作用:决定给定状态 ss 时,应该采取的动作 a=π(s)a=π(s),目标是找到最大化未来回报的策略。
    • 更新:基于 Critic 网络提供的 Q 值更新,以最大化 Critic 估计的 Q 值。
  • Target Actor 网络(目标策略网络)
    • 作用:为 Critic 网络提供更新目标,目的是让目标 Q 值的更新更为稳定。
    • 更新:使用软更新,缓慢向 Actor 网络靠近。
  • Critic 网络(Q 网络)
    • 作用:估计当前状态 ss 和动作 aa 的 Q 值,即 Q(s,a)Q(s,a),为 Actor 提供优化目标。
    • 更新:通过最小化与目标 Q 值的均方误差进行更新。
  • Target Critic 网络(目标 Q 网络)
    • 作用:生成 Q 值更新的目标,使得 Q 值更新更为稳定,减少振荡。
    • 更新:使用软更新,缓慢向 Critic 网络靠近。

DDPG 中的四个网络总结:

  • Actor 网络
  • Target Actor 网络
  • Critic 网络
  • Target Critic 网络
5.2 TD3 中的六个网络

TD3 相比 DDPG 增加了两个网络,使得总共有六个网络。多出的网络用于改进 Q 值估计的准确性。

  • Actor 网络(策略网络)
    • 与 DDPG 相同,决定状态 ss 时采取的动作 a=π(s)a=π(s)。
  • Target Actor 网络(目标策略网络)
    • 与 DDPG 相同,作为 Actor 网络的目标,更新更为平滑和稳定。
  • Critic 1 网络(第一个 Q 网络)
    • 估计给定状态 ss 和动作 aa 的 Q 值 Q1(s,a)Q1(s,a)。
  • Critic 2 网络(第二个 Q 网络)
    • 另一个 Q 网络,估计给定状态 ss 和动作 aa 的 Q 值 Q2(s,a)Q2(s,a)。目标是在 Q 值估计中避免过度高估。
  • Target Critic 1 网络(目标 Q 网络 1)
    • 作为 Critic 1 网络的目标,类似于 DDPG 中的 Target Critic 网络。
  • Target Critic 2 网络(目标 Q 网络 2)
    • 作为 Critic 2 网络的目标,用于为 Critic 2 网络生成更稳定的目标值。

TD3 中的六个网络总结:

  • Actor 网络
  • Target Actor 网络
  • Critic 1 网络
  • Critic 2 网络
  • Target Critic 1 网络
  • Target Critic 2 网络
5.3 对比总结
  • DDPG 里有四个网络:Actor、Target Actor、Critic、Target Critic。
  • TD3 里有六个网络:多了一个 Critic 2 和 Target Critic 2,用于减小 Q 值估计的偏差。
相关推荐
麻雀无能为力3 分钟前
python自学笔记14 NumPy 线性代数
笔记·python·numpy
Y|29 分钟前
GBDT(Gradient Boosting Decision Tree,梯度提升决策树)总结梳理
决策树·机器学习·集成学习·推荐算法·boosting
一枝小雨1 小时前
【数据结构】排序算法全解析
数据结构·算法·排序算法
略知java的景初1 小时前
深入解析十大经典排序算法原理与实现
数据结构·算法·排序算法
哈基鑫1 小时前
支持向量机(SVM)学习笔记
人工智能·机器学习·支持向量机
小白银子1 小时前
零基础从头教学Linux(Day 20)
linux·运维·服务器·php·国安工程师
岁忧1 小时前
(LeetCode 每日一题) 498. 对角线遍历 (矩阵、模拟)
java·c++·算法·leetcode·矩阵·go
kyle~2 小时前
C/C++---前缀和(Prefix Sum)
c语言·c++·算法
liweiweili1262 小时前
main栈帧和func栈帧的关系
数据结构·算法
古月-一个C++方向的小白2 小时前
Linux初始——基础指令篇
linux·运维·服务器