PPO GAE优势函数演化和推导

文章目录

📚 GAE公式推导:从蒙特卡洛到TD的优雅折中

让我用小步慢走的方式,带你理解GAE的推导过程。我们从最基础的概念开始,每一步都确保你能跟上。


第一步:什么是优势函数?

核心问题

在某个状态 s t s_t st 下,我采取动作 a t a_t at 到底好不好

优势函数定义

A ( s t , a t ) = Q ( s t , a t ) − V ( s t ) A(s_t, a_t) = Q(s_t, a_t) - V(s_t) A(st,at)=Q(st,at)−V(st)

符号 含义 通俗理解
Q ( s t , a t ) Q(s_t, a_t) Q(st,at) 动作价值函数 采取动作 a t a_t at 能拿到的期望回报
V ( s t ) V(s_t) V(st) 状态价值函数 状态 s t s_t st 的平均期望回报
A ( s t , a t ) A(s_t, a_t) A(st,at) 优势函数 这个动作比平均水平好多少

直观理解

复制代码
状态 s_t 下:
├── 所有动作的平均回报 = V(s_t) = 10分
├── 动作 a_t 的回报     = Q(s_t, a_t) = 15分
└── 优势 A(s_t, a_t)   = 15 - 10 = +5分  ✅ 好动作!

优势函数的意义

  • A > 0 A > 0 A>0:这个动作比平均水平好
  • A = 0 A = 0 A=0:这个动作就是平均水平
  • A < 0 A < 0 A<0:这个动作比平均水平差

第二步:如何估计优势函数?

问题:我们不知道真实的 Q Q Q 和 V V V

我们需要用采样数据来估计优势函数。有两种经典方法:


第三步:方法一 ------ 蒙特卡洛(MC)估计

核心思想

实际拿到的完整回报 来估计 Q ( s t , a t ) Q(s_t, a_t) Q(st,at)

MC的优势估计公式

A ^ t M C = ∑ k = 0 T − t γ k r t + k ⏟ 实际回报 G t − V ( s t ) \hat{A}^{MC}t = \underbrace{\sum{k=0}^{T-t} \gamma^k r_{t+k}}_{\text{实际回报 } G_t} - V(s_t) A^tMC=实际回报 Gt k=0∑T−tγkrt+k−V(st)

图解MC

复制代码
时间线:t ──→ t+1 ──→ t+2 ──→ ... ──→ T(终点)
         │       │       │            │
         r_t     r_{t+1} r_{t+2}      r_T
         ↓       ↓       ↓            ↓
        全部加起来 = G_t (实际回报)
        
优势估计 = G_t - V(s_t)

MC的优缺点

优点 ✅ 缺点 ❌
无偏估计:期望等于真实值 高方差:不同轨迹差异大
概念简单 必须等到 episode 结束才能计算
不需要价值函数准确 样本效率低

为什么方差高?

想象掷骰子游戏:

  • 轨迹1:🎲 6, 6, 6, 1 → 回报 19
  • 轨迹2:🎲 1, 1, 1, 6 → 回报 9
  • 轨迹3:🎲 3, 3, 3, 3 → 回报 12

真实期望可能是 13,但单次采样可能偏差很大!


第四步:方法二 ------ TD(时序差分)估计

核心思想

不用等到终点,用一步 lookahead + 价值函数估计

TD的优势估计公式(1步TD)

A ^ t T D ( 1 ) = r t + γ V ( s t + 1 ) ⏟ TD目标 − V ( s t ) \hat{A}^{TD(1)}t = \underbrace{r_t + \gamma V(s{t+1})}_{\text{TD目标}} - V(s_t) A^tTD(1)=TD目标 rt+γV(st+1)−V(st)

图解TD(1)

复制代码
时间线:t ──→ t+1
         │       │
         r_t     V(s_{t+1})  ← 用价值函数估计未来
         ↓       ↓
        r_t + γ·V(s_{t+1}) = TD目标
        
优势估计 = TD目标 - V(s_t)

引入TD误差(TD Error)

定义 TD误差 δ t \delta_t δt:

δ t = r t + γ V ( s t + 1 ) − V ( s t ) \delta_t = r_t + \gamma V(s_{t+1}) - V(s_t) δt=rt+γV(st+1)−V(st)

重要发现:1步TD的优势估计就是TD误差!

A ^ t T D ( 1 ) = δ t \hat{A}^{TD(1)}_t = \delta_t A^tTD(1)=δt

TD的优缺点

优点 ✅ 缺点 ❌
低方差:不用等完整轨迹 有偏估计:依赖价值函数准确性
可以在线学习 如果 V ( s ) V(s) V(s) 不准,估计就偏了
样本效率高 偏差可能累积

第五步:MC vs TD 的核心矛盾

偏差-方差权衡

复制代码
                    偏差-方差光谱
                    
    蒙特卡洛(MC) ←───────────→ TD(1)
         │                        │
         ▼                        ▼
      偏差 = 0                  偏差 高
      方差 高                   方差 低

直观对比

场景 MC表现 TD表现
价值函数 V ( s ) V(s) V(s) 很准 浪费信息 ✅ 很好
价值函数 V ( s ) V(s) V(s) 不准 ✅ 不受影响 估计偏差大
轨迹很长 方差爆炸 ✅ 稳定
需要快速学习 太慢 ✅ 快速

核心问题

有没有一种方法,能在MC和TD之间灵活调节?


第六步:方法三 ------ n步TD估计

核心思想

n 步 实际回报,然后用价值函数估计剩余部分

n步TD的优势估计公式

A ^ t T D ( n ) = r t + γ r t + 1 + . . . + γ n − 1 r t + n − 1 ⏟ n步实际回报 + γ n V ( s t + n ) ⏟ n步后估计 − V ( s t ) \hat{A}^{TD(n)}t = \underbrace{r_t + \gamma r{t+1} + ... + \gamma^{n-1} r_{t+n-1}}{\text{n步实际回报}} + \underbrace{\gamma^n V(s{t+n})}_{\text{n步后估计}} - V(s_t) A^tTD(n)=n步实际回报 rt+γrt+1+...+γn−1rt+n−1+n步后估计 γnV(st+n)−V(st)

图解n步TD

复制代码
时间线:t ──→ t+1 ──→ ... ──→ t+n-1 ──→ t+n
         │       │                     │       │
         r_t     r_{t+1}               r_{t+n-1} V(s_{t+n})
         ↓       ↓                     ↓       ↓
        └─────────── 实际回报 ────────────┘   ↓
                                        γ^n·V(s_{t+n})
                                        
优势估计 = (n步回报 + γ^n·V) - V(s_t)

n步TD的TD误差形式

可以证明(后面推导):

A ^ t T D ( n ) = ∑ k = 0 n − 1 γ k δ t + k \hat{A}^{TD(n)}t = \sum{k=0}^{n-1} \gamma^k \delta_{t+k} A^tTD(n)=k=0∑n−1γkδt+k

其中 δ t + k = r t + k + γ V ( s t + k + 1 ) − V ( s t + k ) \delta_{t+k} = r_{t+k} + \gamma V(s_{t+k+1}) - V(s_{t+k}) δt+k=rt+k+γV(st+k+1)−V(st+k)

n的变化趋势

n值 方法 偏差 方差
n=1 TD(1)
n=5 TD(5)
n=∞ MC 0

第七步:关键洞察 ------ 多个n步TD的平均

核心问题

选哪个n最好?n=1?n=5?n=10?

聪明人的想法

为什么要选一个?为什么不全部用上?

加权平均思想

A ^ t 混合 = w 1 A ^ t T D ( 1 ) + w 2 A ^ t T D ( 2 ) + w 3 A ^ t T D ( 3 ) + . . . \hat{A}^{混合}_t = w_1 \hat{A}^{TD(1)}_t + w_2 \hat{A}^{TD(2)}_t + w_3 \hat{A}^{TD(3)}_t + ... A^t混合=w1A^tTD(1)+w2A^tTD(2)+w3A^tTD(3)+...

权重怎么设?

自然想到指数衰减:越远的n步,权重越小

w n ∝ ( 1 − λ ) λ n − 1 w_n \propto (1-\lambda) \lambda^{n-1} wn∝(1−λ)λn−1

其中 λ ∈ [ 0 , 1 ] \lambda \in [0, 1] λ∈[0,1] 是调节参数


第八步:GAE的正式推导

推导起点:n步TD的TD误差形式

首先记住这个关键等式(可以证明):

A ^ t T D ( n ) = ∑ k = 0 n − 1 γ k δ t + k \hat{A}^{TD(n)}t = \sum{k=0}^{n-1} \gamma^k \delta_{t+k} A^tTD(n)=k=0∑n−1γkδt+k

GAE定义:指数加权的n步TD平均

A ^ t G A E = ( 1 − λ ) ∑ n = 1 ∞ λ n − 1 A ^ t T D ( n ) \hat{A}^{GAE}t = (1-\lambda) \sum{n=1}^{\infty} \lambda^{n-1} \hat{A}^{TD(n)}_t A^tGAE=(1−λ)n=1∑∞λn−1A^tTD(n)

展开推导(关键步骤!)

让我一步一步展开:

第1步 :代入 A ^ t T D ( n ) \hat{A}^{TD(n)}_t A^tTD(n)

A ^ t G A E = ( 1 − λ ) ∑ n = 1 ∞ λ n − 1 ( ∑ k = 0 n − 1 γ k δ t + k ) \hat{A}^{GAE}t = (1-\lambda) \sum{n=1}^{\infty} \lambda^{n-1} \left( \sum_{k=0}^{n-1} \gamma^k \delta_{t+k} \right) A^tGAE=(1−λ)n=1∑∞λn−1(k=0∑n−1γkδt+k)

第2步 :交换求和顺序(把 δ \delta δ 提到外层)

A ^ t G A E = ( 1 − λ ) ∑ k = 0 ∞ γ k δ t + k ( ∑ n = k + 1 ∞ λ n − 1 ) \hat{A}^{GAE}t = (1-\lambda) \sum{k=0}^{\infty} \gamma^k \delta_{t+k} \left( \sum_{n=k+1}^{\infty} \lambda^{n-1} \right) A^tGAE=(1−λ)k=0∑∞γkδt+k(n=k+1∑∞λn−1)

第3步:计算内层几何级数

∑ n = k + 1 ∞ λ n − 1 = λ k + λ k + 1 + λ k + 2 + . . . = λ k 1 − λ \sum_{n=k+1}^{\infty} \lambda^{n-1} = \lambda^k + \lambda^{k+1} + \lambda^{k+2} + ... = \frac{\lambda^k}{1-\lambda} n=k+1∑∞λn−1=λk+λk+1+λk+2+...=1−λλk

第4步:代入并化简

A ^ t G A E = ( 1 − λ ) ∑ k = 0 ∞ γ k δ t + k ⋅ λ k 1 − λ \hat{A}^{GAE}t = (1-\lambda) \sum{k=0}^{\infty} \gamma^k \delta_{t+k} \cdot \frac{\lambda^k}{1-\lambda} A^tGAE=(1−λ)k=0∑∞γkδt+k⋅1−λλk

( 1 − λ ) 和 1 1 − λ 抵消! (1-\lambda) \text{ 和 } \frac{1}{1-\lambda} \text{ 抵消!} (1−λ) 和 1−λ1 抵消!

A ^ t G A E = ∑ k = 0 ∞ γ k λ k δ t + k \hat{A}^{GAE}t = \sum{k=0}^{\infty} \gamma^k \lambda^k \delta_{t+k} A^tGAE=k=0∑∞γkλkδt+k

第5步 :合并 γ \gamma γ 和 λ \lambda λ

A ^ t G A E = ∑ k = 0 ∞ ( γ λ ) k δ t + k \boxed{\hat{A}^{GAE}t = \sum{k=0}^{\infty} (\gamma \lambda)^k \delta_{t+k}} A^tGAE=k=0∑∞(γλ)kδt+k


第九步:GAE最终公式

完整形式

A ^ t G A E ( γ , λ ) = ∑ l = 0 T − t − 1 ( γ λ ) l δ t + l \hat{A}^{GAE(\gamma, \lambda)}t = \sum{l=0}^{T-t-1} (\gamma \lambda)^l \delta_{t+l} A^tGAE(γ,λ)=l=0∑T−t−1(γλ)lδt+l

其中:

  • δ t + l = r t + l + γ V ( s t + l + 1 ) − V ( s t + l ) \delta_{t+l} = r_{t+l} + \gamma V(s_{t+l+1}) - V(s_{t+l}) δt+l=rt+l+γV(st+l+1)−V(st+l) 是TD误差
  • γ \gamma γ 是折扣因子(通常0.99)
  • λ \lambda λ 是GAE参数(通常0.95)

展开写出来

A ^ t G A E = δ t + ( γ λ ) δ t + 1 + ( γ λ ) 2 δ t + 2 + ( γ λ ) 3 δ t + 3 + . . . \hat{A}^{GAE}t = \delta_t + (\gamma\lambda)\delta{t+1} + (\gamma\lambda)^2\delta_{t+2} + (\gamma\lambda)^3\delta_{t+3} + ... A^tGAE=δt+(γλ)δt+1+(γλ)2δt+2+(γλ)3δt+3+...

图解GAE

复制代码
时间线:t ──→ t+1 ──→ t+2 ──→ t+3 ──→ ...
         │       │       │       │
         δ_t     δ_{t+1} δ_{t+2} δ_{t+3}
         │       │       │       │
         ×1     ×γλ    ×(γλ)²  ×(γλ)³
         ↓       ↓       ↓       ↓
         └───────┴───────┴───────┴──→ 加权求和 = GAE

第十步:λ参数的神奇作用

λ的调节效果

λ值 公式简化 接近方法 偏差 方差
λ = 0 A ^ t G A E = δ t \hat{A}^{GAE}_t = \delta_t A^tGAE=δt TD(1)
λ = 0.5 中等权重 折中
λ = 1 A ^ t G A E = ∑ γ k δ t + k \hat{A}^{GAE}t = \sum \gamma^k \delta{t+k} A^tGAE=∑γkδt+k MC 0

可视化

复制代码
λ = 0          λ = 0.5        λ = 1
  │              │              │
  ▼              ▼              ▼
┌─────┐      ┌─────────┐    ┌─────────┐
│ TD  │ ←──→ │  GAE    │ ←→ │   MC    │
│(1步)│      │ (折中)  │    │ (完整)  │
└─────┘      └─────────┘    └─────────┘
  │              │              │
方差低          方差中         方差高
偏差高          偏差中         偏差低

实际推荐值

参数 推荐值 说明
γ \gamma γ 0.99 折扣因子
λ \lambda λ 0.95 GAE参数

这个组合在大多数任务中表现良好!


第十一步:为什么GAE是"最佳折中"?

数学直觉

GAE本质上是在做:

GAE = 指数加权平均 ( TD ( 1 ) , TD ( 2 ) , TD ( 3 ) , . . . ) \text{GAE} = \text{指数加权平均}(\text{TD}(1), \text{TD}(2), \text{TD}(3), ...) GAE=指数加权平均(TD(1),TD(2),TD(3),...)

三大优势

优势 解释
🎯 可调的偏差-方差权衡 通过λ灵活控制
🎯 计算高效 只需一次向后遍历
🎯 理论保证 基于TD(λ)的坚实理论基础

实际计算(向后遍历)

python 复制代码
# 伪代码:GAE的实际计算
gae = 0
advantages = []

for t in reversed(range(T)):  # 从后往前
    delta = r_t + gamma * V(s_{t+1}) - V(s_t)
    gae = delta + gamma * lambda * gae  # 递归计算
    advantages.insert(0, gae)

第十二步:完整对比总结

三种方法的终极对比

特性 MC TD(1) GAE
公式 ∑ γ k r − V \sum \gamma^k r - V ∑γkr−V δ t \delta_t δt ∑ ( γ λ ) k δ \sum (\gamma\lambda)^k \delta ∑(γλ)kδ
偏差 0 可调
方差 可调
需要episode结束
依赖V函数准确性 部分
实际效果 不稳定 可能偏差大 最佳平衡

演化路线图

复制代码
                    偏差-方差权衡的演化
                    
    1992              2000s              2016
      │                 │                  │
      ▼                 ▼                  ▼
   ┌──────┐         ┌──────┐          ┌──────┐
   │  MC  │  ────→  │ TD   │  ────→   │ GAE  │
   │高方差│         │高偏差│          │可调  │
   └──────┘         └──────┘          └──────┘
      │                 │                  │
      └─────────────────┴──────────────────┘
                        │
                        ▼
              🎯 GAE = MC + TD 的最佳折中

📝 核心要点回顾

  1. 优势函数 = 动作价值 - 状态价值(衡量动作好坏)

  2. MC估计 = 用完整回报,无偏但高方差

  3. TD估计 = 用1步 lookahead,低方差但有偏

  4. n步TD = 走n步实际 + 剩余估计,介于两者之间

  5. GAE = 所有n步TD的指数加权平均

  6. λ参数 = 调节MC和TD的权重(λ=0偏TD,λ=1偏MC)

  7. 最终公式 : A ^ t G A E = ∑ l = 0 ∞ ( γ λ ) l δ t + l \hat{A}^{GAE}t = \sum{l=0}^{\infty} (\gamma \lambda)^l \delta_{t+l} A^tGAE=∑l=0∞(γλ)lδt+l


希望这个小步慢走的讲解能帮你理解GAE的推导!有任何一步不清楚,欢迎继续提问 🎯

相关推荐
Jasmine_llq1 小时前
《P3572 [POI 2014] PTA-Little Bird》
算法·滑动窗口·单调队列·动态规划(dp)·多组查询处理·循环优化(宏定义 rep)
tankeven2 小时前
HJ101 排序
c++·算法
流云鹤2 小时前
动态规划02
算法·动态规划
小白菜又菜2 小时前
Leetcode 236. Lowest Common Ancestor of a Binary Tree
python·算法·leetcode
不想看见4042 小时前
01 Matrix 基本动态规划:二维--力扣101算法题解笔记
c++·算法·leetcode
多恩Stone2 小时前
【3D-AICG 系列-12】Trellis 2 的 Shape VAE 的设计细节 Sparse Residual Autoencoding Layer
人工智能·python·算法·3d·aigc
踢足球09292 小时前
寒假打卡:2026-2-23
数据结构·算法
田里的水稻3 小时前
FA_建图和定位(ML)-超宽带(UWB)定位
人工智能·算法·数学建模·机器人·自动驾驶
Navigator_Z3 小时前
LeetCode //C - 964. Least Operators to Express Number
c语言·算法·leetcode