强化学习笔记之【TD3算法】

强化学习笔记之【TD3算法】


前言:

本文为强化学习笔记第三篇,第一篇讲的是Q-learning和DQN,第二篇讲的是DDPG

TD3就是比DDPG多了两个网络用来防止过估计,然后引入了延迟更新机制,就没了,还挺简单的

本文初编辑于2024.10.6

CSDN主页:https://blog.csdn.net/rvdgdsva

博客园主页:https://www.cnblogs.com/hassle

博客园本文链接:

三渲二(Celshading/toon shading) = Two-dimensional 3D = TD3,这很合理


首先,我们需要明确,Q-learning 算法发展成DQN 算法,DQN 算法发展成为DDPG 算法,而DDPG 算法发展成TD3(Twin Delayed DDPG)算法

这里有两个问题

一、DQN算法中存在过估计

DQN的目标是优化Q值函数。DQN采用离线学习的方式,通过计算Q值的目标来更新其网络参数。DQN的目标公式可以表示为:
y = r + γ × m a x ( Q ( s ′ , a ′ ; θ ′ ) ) y = r + γ \times max(Q(s', a'; θ')) y=r+γ×max(Q(s′,a′;θ′))

在这里,r是当前动作的价值,**max(Q(s', a'; θ'))**是后续动作中价值最大的动作的价值。

而往往在实际运行中没有办法达成最大价值这种理想情况

  • r 是当前时刻的奖励。
  • γ 是折扣因子,用于平衡未来奖励的重要性。
  • s' 是下一个状态,a' 是下一个动作。
  • Q(s', a'; θ') 是目标网络的 Q 值。
  • θ' 是目标网络的参数。

于是TD3算法用了一个讨巧但是有效的方式,搞两个网络,分别运行,取相对小价值的作为输出
y = r + γ × m i n ( Q 1 ( s ′ , π ( s ′ ) ) , Q 2 ( s ′ , π ( s ′ ) ) ) y = r + γ \times min(Q₁(s', π(s')), Q₂(s', π(s'))) y=r+γ×min(Q1(s′,π(s′)),Q2(s′,π(s′)))

  • r 是奖励。
  • γ 是折扣因子。
  • Q₁(s', π(s'))Q₂(s', π(s')) 是两个目标 Q 网络的 Q 值。
  • π(s') 是目标策略网络生成的下一个动作。
  • θ₁'θ₂' 是两个目标 Q 网络的参数,π 是目标策略网络的参数。

TD3 的核心改进在于 使用两个Q网络取最小值 来计算目标Q值,以减少过高估计问题(overestimation bias)。


二、DDPG算法存在局部最优

TD3 中的Delay 机制,主要体现在两个方面:延迟更新策略网络延迟更新目标Q网络 。这种延迟机制是对经典 DDPG(Deep Deterministic Policy Gradient)算法的一项重要改进

在 DDPG 中,策略网络(Actor)和Q网络(Critic)是交替更新的,这意味着策略网络在每次迭代时都能快速学习新的动作。然而,频繁更新策略网络会使它容易陷入局部最优,难以找到全局最优策略,尤其是在Q值估计不准确的情况下。

2.1 策略网络的优化原理

在强化学习中,策略网络(Actor)决定智能体在每个状态下应该采取的动作,目的是最大化未来的累积奖励。策略网络通过不断调整,学会选择能带来更高回报的动作。

当策略网络更新频繁时,它会在每一轮训练中快速调整自己的权重,期望基于当前的 Q 值(即 Critic 网络评估的动作价值)尽可能找到最优动作。

2.2 局部最优

局部最优是指在当前的策略空间中,智能体找到了一种动作选择,这种选择看起来已经是最好的(最大化了当前的 Q 值),但从更大的全局视角来看,这其实不是最优解。也就是说,策略可能在某个小范围内找到了一个"局部最佳解",但离真正的全局最优解还有差距。

2.3 频繁更新为什么导致局部最优

当策略网络更新频繁时,可能会过于迅速地朝着当前 Q 网络评估出的"最佳方向"移动,然而,Q 网络本身的估值在早期阶段可能还不够准确或稳定。

  • Q值的不稳定性:在强化学习的过程中,Q值估计在训练早期或数据不足时,常常会有误差。如果策略网络过于依赖这些不稳定的 Q 值进行快速更新,它会基于这些错误的估计来选择看似"最优"的动作,这可能让策略陷入一个局部最优,而没有足够时间探索更好的全局解。
  • 策略更新速度过快:频繁更新策略网络,会使其迅速调整到一个"看似最优"的策略上。但由于更新速度太快,智能体可能还没有足够的时间探索整个策略空间,因此很容易错过更优的动作选择。
  • 缺乏探索:策略网络在频繁更新过程中,可能对当前评估为"好的"动作过度偏好,而忽略了其他动作的探索,这样可能导致陷入局部最优,而未能发现更优的全局策略。
2.4 TD3中的改进

为了避免这种现象,TD3引入了延迟更新策略网络的机制。也就是说,在 Q 网络(Critic)经过多次更新后,策略网络(Actor)才会更新。通过这种延迟更新,策略网络能够基于更准确、更稳定的 Q 值进行更新,从而减少策略过快收敛到局部最优的风险。


三、DDPG算法和TD3算法代码对比

下面对比 DDPG 和 TD3 的代码或伪代码,特别是延迟策略更新双Q网络的改进部分

另外,TD3target_actor的输出后面加了个噪声

3.1 DDPG算法
python 复制代码
# DDPG
for each iteration:
    # 采样经验数据 (state, action, reward, next_state) from replay buffer
    batch = replay_buffer.sample()
    
    # Critic 更新 (Q网络)
    next_action = target_actor(next_state)
    target_q_value = reward + gamma * target_critic(next_state, next_action)
    critic_loss = mse(critic(state, action), target_q_value)
    critic_optimizer.zero_grad()
    critic_loss.backward()
    critic_optimizer.step()
    
    # Actor 更新 (策略网络)
    actor_loss = -critic(state, actor(state)).mean()  
    actor_optimizer.zero_grad()
    actor_loss.backward()
    actor_optimizer.step()

    # 软更新目标网络参数
    for param, target_param in zip(critic.parameters(), target_critic.parameters()):
        target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
    for param, target_param in zip(actor.parameters(), target_actor.parameters()):
        target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
3.2 TD3算法
python 复制代码
# TD3
for each iteration:
    # 采样经验数据 (state, action, reward, next_state) from replay buffer
    batch = replay_buffer.sample()

    # Critic 更新 (双 Q 网络)
    next_action = target_actor(next_state) + clip(noise, -noise_clip, noise_clip)  # 引入噪声
    target_q_value1 = target_critic1(next_state, next_action)
    target_q_value2 = target_critic2(next_state, next_action)
    target_q_value = reward + gamma * min(target_q_value1, target_q_value2)  # 取最小值

    critic1_loss = mse(critic1(state, action), target_q_value)
    critic_optimizer1.zero_grad()
    critic1_loss.backward()
    critic_optimizer1.step()

    critic2_loss = mse(critic2(state, action), target_q_value)
    critic_optimizer2.zero_grad()
    critic2_loss.backward()
    critic_optimizer2.step()

    # 延迟更新 Actor 网络
    if iteration % policy_delay == 0:  # 延迟更新策略网络
        actor_loss = -critic1(state, actor(state)).mean()  
        actor_optimizer.zero_grad()
        actor_loss.backward()
        actor_optimizer.step()

        # 软更新目标网络参数
        for param, target_param in zip(critic1.parameters(), target_critic1.parameters()):
            target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
        for param, target_param in zip(critic2.parameters(), target_critic2.parameters()):
            target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
        for param, target_param in zip(actor.parameters(), target_actor.parameters()):
            target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)

四、总结:DDPG 与 TD3 的关键区别

4.1 双 Q 网络 (Critic)

在 DDPG 中只有一个 Q目标网络,而 TD3 中使用两个 Q 目标网络,主要的改进是通过计算两个 Q 网络的最小值来缓解 Q 值的过估计问题。

4.2 延迟更新策略网络 (Actor)

TD3 的策略网络并不在每次 Q 网络更新后都立即更新,而是隔几次才更新一次。policy_delay 决定了每几次更新 Critic 后更新 Actor 的频率。这样做的目的是为了让 Q 网络先稳定下来,防止 Actor 网络基于不准确的 Q 值进行优化。

4.3 目标策略平滑 (Target Policy Smoothing)

TD3 在生成目标动作时加入噪声,并且通过裁剪噪声来避免过大的扰动。这样可以增加策略的探索性,减少由于策略确定性导致的高估问题。


五、TD3和DDPG网络对比

5.1 DDPG 中的四个网络
  • Actor 网络(策略网络)
    • 作用:决定给定状态 ss 时,应该采取的动作 a=π(s)a=π(s),目标是找到最大化未来回报的策略。
    • 更新:基于 Critic 网络提供的 Q 值更新,以最大化 Critic 估计的 Q 值。
  • Target Actor 网络(目标策略网络)
    • 作用:为 Critic 网络提供更新目标,目的是让目标 Q 值的更新更为稳定。
    • 更新:使用软更新,缓慢向 Actor 网络靠近。
  • Critic 网络(Q 网络)
    • 作用:估计当前状态 ss 和动作 aa 的 Q 值,即 Q(s,a)Q(s,a),为 Actor 提供优化目标。
    • 更新:通过最小化与目标 Q 值的均方误差进行更新。
  • Target Critic 网络(目标 Q 网络)
    • 作用:生成 Q 值更新的目标,使得 Q 值更新更为稳定,减少振荡。
    • 更新:使用软更新,缓慢向 Critic 网络靠近。

DDPG 中的四个网络总结:

  • Actor 网络
  • Target Actor 网络
  • Critic 网络
  • Target Critic 网络
5.2 TD3 中的六个网络

TD3 相比 DDPG 增加了两个网络,使得总共有六个网络。多出的网络用于改进 Q 值估计的准确性。

  • Actor 网络(策略网络)
    • 与 DDPG 相同,决定状态 ss 时采取的动作 a=π(s)a=π(s)。
  • Target Actor 网络(目标策略网络)
    • 与 DDPG 相同,作为 Actor 网络的目标,更新更为平滑和稳定。
  • Critic 1 网络(第一个 Q 网络)
    • 估计给定状态 ss 和动作 aa 的 Q 值 Q1(s,a)Q1(s,a)。
  • Critic 2 网络(第二个 Q 网络)
    • 另一个 Q 网络,估计给定状态 ss 和动作 aa 的 Q 值 Q2(s,a)Q2(s,a)。目标是在 Q 值估计中避免过度高估。
  • Target Critic 1 网络(目标 Q 网络 1)
    • 作为 Critic 1 网络的目标,类似于 DDPG 中的 Target Critic 网络。
  • Target Critic 2 网络(目标 Q 网络 2)
    • 作为 Critic 2 网络的目标,用于为 Critic 2 网络生成更稳定的目标值。

TD3 中的六个网络总结:

  • Actor 网络
  • Target Actor 网络
  • Critic 1 网络
  • Critic 2 网络
  • Target Critic 1 网络
  • Target Critic 2 网络
5.3 对比总结
  • DDPG 里有四个网络:Actor、Target Actor、Critic、Target Critic。
  • TD3 里有六个网络:多了一个 Critic 2 和 Target Critic 2,用于减小 Q 值估计的偏差。
相关推荐
冷白白11 分钟前
【C++】C++对象初探及友元
c语言·开发语言·c++·算法
活跃的煤矿打工人12 分钟前
【星海saul随笔】Ubuntu基础知识
linux·运维·ubuntu
鹤上听雷19 分钟前
【AGC005D】~K Perm Counting(计数抽象成图)
算法
一叶祇秋31 分钟前
Leetcode - 周赛417
算法·leetcode·职场和发展
武昌库里写JAVA36 分钟前
【Java】Java面试题笔试
c语言·开发语言·数据结构·算法·二维数组
ya888g37 分钟前
GESP C++四级样题卷
java·c++·算法
fasewer1 小时前
第五章 linux实战-挖矿 二
linux·运维·服务器
Funny_AI_LAB1 小时前
MetaAI最新开源Llama3.2亮点及使用指南
算法·计算机视觉·语言模型·llama·facebook
NuyoahC1 小时前
算法笔记(十一)——优先级队列(堆)
c++·笔记·算法·优先级队列
jk_1011 小时前
MATLAB中decomposition函数用法
开发语言·算法·matlab