RL中GAE的计算过程详解

GAE,Generalized Advantage Estimator,是在强化学习中用于估计优势函数的重要技术。

GAE结合TD误差的多步估计,通过调节λ参数可以在偏差和方差之间取得平衡。

在强化学习算法如PPO、TRPO、A3C等,GAE是稳定训练的关键技术之一。

这里所用示例、代码参考和修改自网络资料。

1. 基础概念

1.1 优势函数

优势函数定义为:

其中:

Q(s,a) 是动作价值函数,在状态为s时,执行动作a时的价值。

V(s) 是状态价值函数,在状态为s时的平均价值。

优势函数,通过Q和S的差,衡量了在状态s下采取动作a相对于平均情况的好坏程度。

1.2 TD误差

时序差分(TD),结合了动态规划(DP)和蒙特卡洛(MC)方法。

引入MC方法,直接从经验中学习,不需要环境的动态模型。

引入DP思路,使用当前估计的值函数来更新值函数,而不需要等到一个完整的回合结束。

单步TD误差定义为:

其中:

  • 是时间步t的奖励

  • 是折扣因子

  • 是状态价值函数的估计

TD误差表示当前值函数的估计与更准确的估计之间的差异。

这个更准确的估计称为TD目标。

在TD学习中,使用TD误差来更新值函数:

其中α是学习率。

1.3 TD误差解释

时间步t时的价值会对时间步t+1的价值有影响,然而时间步t的价值不好估计。

对于T0, T1, ..., Tn,只有到达最后一步tn,才可以根据Tn结果估计价值。

所以RL中的价值采用时序相反方向估计,也就是说,的估计依赖于

越接近开始的状态,由于后续步骤的action不同,所以其价值也是不固定的。后续步骤t+1对当前步骤t的价值的影响,取决于后续一系列action所产生的影响。

比如,如果后续步骤选择得当,则可能导致任务成功,则当前状态t的价值就大。

否则如果后续步骤action选择不当导致任务失败,则当天状态t的价值就相对较小。

在迭代过程中,为简化分析,采用马尔可夫式的链式估计,即t步的价值只受t+1步的价值和t步时采用action获得的奖励有关系。

结合上述公式,t步时V估计的平均价值为,t+1步平均价值为

t步在状态时,执行action导致产生奖励,此时t步在状态为s时的实际价值为

,此时步价值估计的误差为

该误差的物理意义理解为由于action导致t步价值的变化量。

1.4 TD与优势函数

在策略梯度方法中,通常需要估计优势函数A(s,a) = Q(s,a) - V(s)。

然而,通常只学习V(s)函数,而Q(s,a)可以通过TD误差来估计。

实际上,TD误差可以被看作是对优势函数的一个估计。

所以有

但这个估计只使用了一步的奖励,因此是带有偏差的。

为了减少偏差,可以使用多步的奖励,这就是GAE的基本思想。

从另外一个角度看,对于时间步t,怎么用这些要素表示t的优势函数呢。

合理设想是将t开始后续时间步的action导致的价值变化量进行累加,这就是GAE的雏形。

2 GAE

GAE利用TD误差的多步加权和来估计优势函数,从而在策略梯度方法中提供了一个在偏差和方差之间平衡的估计。通过调整λ,我们可以控制这个平衡点。

2.1 GAE核心

GAE是将优势函数表示为TD误差的加权和,通过引入参数λ来调节偏差和方差的权衡:

其中:

  • γ 是折扣因子(通常0<γ≤1)

  • λ 是GAE参数(通常0≤λ≤1)

  • 是第(t+l)步的TD误差

这个式子可以这样理解,我们考虑从时刻t开始,往后每一步的TD误差。

但是给这些TD误差加上一个衰减权重

当λ=0时,GAE就退化为单步TD误差;

当λ=1时,GAE就等价于蒙特卡洛方法,即使用整个轨迹的累积奖励减去基线

GAE通过引入λ,在偏差和方差之间做了一个权衡:

当λ较小时,估计主要依赖于近期的TD误差,因此偏差较大,因为单步TD误差是有偏的,但方差较小。

当λ较大时,估计依赖于更多的步数,偏差减小(因为多步回报更接近真实的回报),但方差增大。

通常,λ是一个超参数,需要根据具体问题进行调整。

2.2 前置条件

假设我们有以下数据:

一条轨迹:

状态价值估计:

参数:折扣因子,GAE参数

2.3 TD误差

首先计算每个时间步的TD误差:

(对于t=0到T-1)

对于终止状态T,如果轨迹结束,通常设V(s_{T+1})=0

2.4 递归形式

参考之前内容,GAE优势函数一般采用马尔可夫链相反方向,从后向前递归估计。

# 终止状态的GAE为0

, 对于t=T-1到0

2.5 计算示例

以下时GAE计算的代码示例。

复制代码
def compute_gae(rewards, values, gamma=0.99, lambda_=0.95):
    """
    计算GAE
    
    参数:
    rewards: 奖励序列 [r_0, r_1, ..., r_{T-1}]
    values: 状态价值估计 [V(s_0), V(s_1), ..., V(s_T)]
    gamma: 折扣因子
    lambda_: GAE参数
    
    返回:
    advantages: GAE优势估计 [Â_0, Â_1, ..., Â_{T-1}]
    returns: 回报估计
    """
    T = len(rewards)
    advantages = np.zeros(T)
    returns = np.zeros(T)
    
    # 从后向前计算
    gae = 0
    for t in reversed(range(T)):
        if t == T - 1:
            # 最后一个时间步
            next_value = 0  # 假设终止后价值为0
        else:
            next_value = values[t + 1]
        
        # 计算TD误差
        delta = rewards[t] + gamma * next_value - values[t]
        
        # 更新GAE
        gae = delta + gamma * lambda_ * gae
        advantages[t] = gae
        
        # 计算回报(可选)
        returns[t] = advantages[t] + values[t]
    
    return advantages, returns

2.6 递归形式

在实际计算中,通常使用递归形式从后往前计算GAE,可以高效地计算整个轨迹的GAE。

递归公式为:

来源如下,可以从定义中推导出来。

因此,可以从轨迹的最后一个时间步开始,逐步向前计算。

对于GPT模型,处理对象可以看作一个token序列。

对于么个token步,在选择token时需要进一步考虑后续计划输出什么。

对于后续不同的输出token序列,当前位置token的最优选择均步一样。

所以,GPT可能需要一个类似于GAE的计算去估计当前决策。

这可能也是当前LLM大量使用RL进行调优训练的原因之一。

reference


partial advantage estimator for proximal policy optimization

https://ar5iv.labs.arxiv.org/html/2301.10920

论文解读之优势函数GAE(GENERALIZED ADVANTAGE ESTIMATION)

https://blog.csdn.net/m0_72806612/article/details/146227542

Partial Advantage Estimator for Proximal Policy Optimization

https://qmro.qmul.ac.uk/xmlui/handle/123456789/98015?show=full

通用优势估计函数(GAE,Generalized Advantage Estimation)详解

https://blog.csdn.net/qq_38769809/article/details/148383621

相关推荐
Hgfdsaqwr2 小时前
内存泄漏检测与防范
开发语言·c++·算法
yhyvc2 小时前
人形具身机器人国产/进口快速选型优先级清单
人工智能·机器人
C雨后彩虹2 小时前
优雅子数组
java·数据结构·算法·华为·面试
wangmengxxw2 小时前
SpringAI-mysql
java·数据库·人工智能·mysql·springai
漫随流水2 小时前
leetcode回溯算法(46.全排列)
数据结构·算法·leetcode·回溯算法
考證寶題庫網2 小时前
Designing and Implementing a Microsoft Azure AI Solution 微軟Azure AI-102 認證全攻略
人工智能·microsoft·azure
We་ct2 小时前
LeetCode 68. 文本左右对齐:贪心算法的两种实现与深度解析
前端·算法·leetcode·typescript
努力学算法的蒟蒻2 小时前
day67(1.26)——leetcode面试经典150
算法·leetcode·面试
iAkuya2 小时前
(leetcode) 力扣100 52腐烂的橘子(BFS)
算法·leetcode·宽度优先