【强化学习】PPO(Proximal Policy Optimization,近端策略优化)算法

文章目录

  • [1. PPO算法简介](#1. PPO算法简介)
  • [2. 相关原理](#2. 相关原理)
    • [2.1. 策略梯度定理(Policy Gradient Theorem)](#2.1. 策略梯度定理(Policy Gradient Theorem))
      • [2.1.1. 核心思想](#2.1.1. 核心思想)
      • [2.1.2. 伪代码流程](#2.1.2. 伪代码流程)
    • [2.2. 重要性采样](#2.2. 重要性采样)
      • [2.2.1. 核心思想](#2.2.1. 核心思想)
      • [2.2.2. 在强化学习中的应用](#2.2.2. 在强化学习中的应用)
    • [2.3. KL散度](#2.3. KL散度)
      • [2.3.1. KL散度的定义](#2.3.1. KL散度的定义)
      • [2.3.2. 简单示例](#2.3.2. 简单示例)
    • [2.4. 优势函数 Advantage Function](#2.4. 优势函数 Advantage Function)
      • [2.4.1 优势函数的定义](#2.4.1 优势函数的定义)
      • [2.4.2. 优势函数的作用](#2.4.2. 优势函数的作用)
      • [2.4.3. 如何计算优势函数(实践中)](#2.4.3. 如何计算优势函数(实践中))
        • [方法一:1-step Advantage](#方法一:1-step Advantage)
        • [方法二:k-step Advantage](#方法二:k-step Advantage)
        • [方法三:GAE(广义优势估计,Generalized Advantage Estimation)](#方法三:GAE(广义优势估计,Generalized Advantage Estimation))
  • [3. PPO算法原理](#3. PPO算法原理)
    • [3.1. 核心思想](#3.1. 核心思想)
    • [3.2. 基本组成](#3.2. 基本组成)
      • [3.2.1. 使用策略梯度](#3.2.1. 使用策略梯度)
      • [3.2.2. GAE 提供优势估计](#3.2.2. GAE 提供优势估计)
      • [3.2.3. 策略目标函数](#3.2.3. 策略目标函数)
        • [(1)PPO-Clip 目标函数](#(1)PPO-Clip 目标函数)
        • [(2)KL Penalty目标函数(次要版本)](#(2)KL Penalty目标函数(次要版本))
    • [3.3. 算法流程](#3.3. 算法流程)
    • [3.4. PPO 的优点](#3.4. PPO 的优点)

1. PPO算法简介

PPO(Proximal Policy Optimization,近端策略优化)是强化学习中一种高效、稳定、易于实现的策略梯度方法,属于基于策略的方法。它由 OpenAI 在 2017 年提出,目的是在保持性能的同时简化实现复杂度。

2. 相关原理

2.1. 策略梯度定理(Policy Gradient Theorem)

2.1.1. 核心思想

策略梯度的核心是以下公式:

∇ θ J ( θ ) = E π θ [ ∇ θ log ⁡ π θ ( a ∣ s ) ⋅ Q π θ ( s , a ) ] \nabla_\theta J(\theta) = \mathbb{E}{\pi\theta} \left[ \nabla_\theta \log \pi_\theta(a|s) \cdot Q^{\pi_\theta}(s, a) \right] ∇θJ(θ)=Eπθ[∇θlogπθ(a∣s)⋅Qπθ(s,a)]

解释:

  • 用动作概率的对数梯度来表示策略的导数;
  • 乘上 Q Q Q 值(当前策略下执行该动作得到的预期回报)。
    基本算法流程(REINFORCE)

REINFORCE 是最早的策略梯度方法(Williams, 1992),它使用 Monte Carlo 方法估计 Q ( s , a ) Q(s, a) Q(s,a)。

2.1.2. 伪代码流程

  1. 初始化策略参数 θ \theta θ

  2. 重复直到收敛:

    • 用策略 π θ \pi_\theta πθ 与环境交互,收集一条或多条轨迹 τ \tau τ
    • 对于每个时间步 t t t,计算累计回报 G t = ∑ k = 0 ∞ γ k r t + k G_t = \sum_{k=0}^\infty \gamma^k r_{t+k} Gt=∑k=0∞γkrt+k
    • 用以下梯度上升方式更新参数:

θ ← θ + α ⋅ ∇ θ log ⁡ π θ ( a t ∣ s t ) ⋅ G t \theta \leftarrow \theta + \alpha \cdot \nabla_\theta \log \pi_\theta(a_t | s_t) \cdot G_t θ←θ+α⋅∇θlogπθ(at∣st)⋅Gt

2.2. 重要性采样

2.2.1. 核心思想

  • p ( x ) p(x) p(x):目标分布(我们想估计其期望)
  • q ( x ) q(x) q(x):实际采样的分布(行为策略)
  • w ( x ) = p ( x ) q ( x ) w(x) = \frac{p(x)}{q(x)} w(x)=q(x)p(x):重要性权重

我们用这些权重修正来自 q ( x ) q(x) q(x) 的样本,使得估计结果近似于 p ( x ) p(x) p(x) 下的期望。

2.2.2. 在强化学习中的应用

在强化学习中,重要性采样通常用于离策略学习(off-policy learning) :我们用一个行为策略 μ \mu μ 采集样本,但希望评估或优化一个目标策略 π \pi π。

例如,在策略梯度中:

E τ ∼ π [ ∇ θ log ⁡ π ( a ∣ s ) R ] ⇒ E τ ∼ μ [ π ( a ∣ s ) μ ( a ∣ s ) ∇ θ log ⁡ π ( a ∣ s ) R ] \mathbb{E}{\tau \sim \pi} [ \nabla\theta \log \pi(a|s) R ] \Rightarrow \mathbb{E}{\tau \sim \mu} \left[ \frac{\pi(a|s)}{\mu(a|s)} \nabla\theta \log \pi(a|s) R \right] Eτ∼π[∇θlogπ(a∣s)R]⇒Eτ∼μ[μ(a∣s)π(a∣s)∇θlogπ(a∣s)R]

这里:

  • μ ( a ∣ s ) \mu(a|s) μ(a∣s):行为策略(你生成样本用的策略)
  • π ( a ∣ s ) \pi(a|s) π(a∣s):目标策略(你要优化的策略)
  • π μ \frac{\pi}{\mu} μπ:重要性比率

2.3. KL散度

2.3.1. KL散度的定义

给定两个概率分布 P ( x ) P(x) P(x) 和 Q ( x ) Q(x) Q(x),KL 散度定义为:

D KL ( P ∥ Q ) = ∑ x P ( x ) log ⁡ P ( x ) Q ( x ) (离散) D_{\text{KL}}(P \| Q) = \sum_x P(x) \log \frac{P(x)}{Q(x)} \quad \text{(离散)} DKL(P∥Q)=x∑P(x)logQ(x)P(x)(离散)

D KL ( P ∥ Q ) = ∫ P ( x ) log ⁡ P ( x ) Q ( x ) d x (连续) D_{\text{KL}}(P \| Q) = \int P(x) \log \frac{P(x)}{Q(x)} dx \quad \text{(连续)} DKL(P∥Q)=∫P(x)logQ(x)P(x)dx(连续)

  • 它衡量:用 Q 来近似 P 时所造成的信息损失
  • 如果 P = Q P = Q P=Q,则 D K L ( P ∥ Q ) = 0 D_{KL}(P \| Q) = 0 DKL(P∥Q)=0;
  • D K L ( P ∥ Q ) ≥ 0 D_{KL}(P \| Q) \geq 0 DKL(P∥Q)≥0,但不对称:即 D K L ( P ∥ Q ) ≠ D K L ( Q ∥ P ) D_{KL}(P \| Q) \ne D_{KL}(Q \| P) DKL(P∥Q)=DKL(Q∥P)。

2.3.2. 简单示例

设:

  • P = [ 0.1 , 0.4 , 0.5 ] P = [0.1, 0.4, 0.5] P=[0.1,0.4,0.5]
  • Q = [ 0.3 , 0.4 , 0.3 ] Q = [0.3, 0.4, 0.3] Q=[0.3,0.4,0.3]

计算 KL 散度:

D K L ( P ∥ Q ) = 0.1 log ⁡ 0.1 0.3 + 0.4 log ⁡ 0.4 0.4 + 0.5 log ⁡ 0.5 0.3 D_{KL}(P \| Q) = 0.1 \log \frac{0.1}{0.3} + 0.4 \log \frac{0.4}{0.4} + 0.5 \log \frac{0.5}{0.3} DKL(P∥Q)=0.1log0.30.1+0.4log0.40.4+0.5log0.30.5

= 0.1 ⋅ ( − 1.585 ) + 0.4 ⋅ 0 + 0.5 ⋅ 0.737 = − 0.1585 + 0 + 0.3685 = 0.21 = 0.1 \cdot (-1.585) + 0.4 \cdot 0 + 0.5 \cdot 0.737 = -0.1585 + 0 + 0.3685 = 0.21 =0.1⋅(−1.585)+0.4⋅0+0.5⋅0.737=−0.1585+0+0.3685=0.21

2.4. 优势函数 Advantage Function

2.4.1 优势函数的定义

优势函数(Advantage Function)定义为:

A π ( s , a ) = Q π ( s , a ) − V π ( s ) A^{\pi}(s, a) = Q^{\pi}(s, a) - V^{\pi}(s) Aπ(s,a)=Qπ(s,a)−Vπ(s)

含义解释:

  • Q π ( s , a ) Q^{\pi}(s, a) Qπ(s,a):在状态 s s s 执行动作 a a a 后,后续回报的期望(行为价值);
  • V π ( s ) V^{\pi}(s) Vπ(s):在状态 s s s 下的期望回报(状态价值);
  • 所以 A ( s , a ) A(s,a) A(s,a):衡量在 s s s 下选择 a a a 是否比平均水平 V ( s ) V(s) V(s) 好。

2.4.2. 优势函数的作用

直接使用 Q ( s , a ) Q(s,a) Q(s,a) 更新策略时噪声大。使用 A ( s , a ) A(s,a) A(s,a) 可以:

  • 减少方差;
  • 聚焦于"优于平均"的动作;
  • 在 PPO 中可与旧策略值共享估计器。

2.4.3. 如何计算优势函数(实践中)

通常我们不能准确知道 Q Q Q 和 V V V,所以我们使用采样 + 蒙特卡洛或时序差分的方式估计优势函数。

常用方法如下:


方法一:1-step Advantage

A ^ t = r t + γ V ( s t + 1 ) − V ( s t ) \hat{A}t = r_t + \gamma V(s{t+1}) - V(s_t) A^t=rt+γV(st+1)−V(st)

适用于TD 近似,偏差小但方差大。


方法二:k-step Advantage

A ^ t = ∑ l = 0 k − 1 γ l r t + l + γ k V ( s t + k ) − V ( s t ) \hat{A}t = \sum{l=0}^{k-1} \gamma^l r_{t+l} + \gamma^k V(s_{t+k}) - V(s_t) A^t=l=0∑k−1γlrt+l+γkV(st+k)−V(st)

适用于中等长度回报估计


方法三:GAE(广义优势估计,Generalized Advantage Estimation)

广义优势估计(GAE ,Generalized Advantage Estimation)是一种在策略优化中高效、稳定地估计优势函数的方法

δ t = r t + γ V ( s t + 1 ) − V ( s t ) \delta_t = r_t + \gamma V(s_{t+1}) - V(s_t) δt=rt+γV(st+1)−V(st)

A ^ t G A E ( γ , λ ) = ∑ l = 0 ∞ ( γ λ ) l δ t + l \hat{A}t^{GAE(\gamma, \lambda)} = \sum{l=0}^{\infty} (\gamma \lambda)^l \delta_{t+l} A^tGAE(γ,λ)=l=0∑∞(γλ)lδt+l

GAE 带来了两个重要超参数:

  • γ \gamma γ:折扣因子(通常为 0.99);
  • λ \lambda λ:控制 bias-variance 权衡(通常为 0.95)。

实现上可使用递归:

python 复制代码
def compute_gae(rewards, values, gamma=0.99, lam=0.95):
    advs = np.zeros_like(rewards)
    advantage = 0
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + gamma * values[t + 1] - values[t]
        advantage = delta + gamma * lam * advantage
        advs[t] = advantage
    return advs

其中:

  • rewards[t]: 当前时刻获得的奖励;
  • values[t]: 当前状态的价值估计;
  • values[t+1]: 下一状态的价值估计。

3. PPO算法原理

3.1. 核心思想

PPO 是一种基于策略梯度的方法,其目标是改进旧策略时不过度更新(保持策略变化"适度"),以稳定训练。

PPO 有两种常见形式:

  1. PPO-Clip(常用):通过裁剪概率比来限制策略更新;
  2. PPO-KL:通过添加 KL 散度惩罚项。

3.2. 基本组成

3.2.1. 使用策略梯度

3.2.2. GAE 提供优势估计

3.2.3. 策略目标函数

(1)PPO-Clip 目标函数

策略更新目标函数如下:

L CLIP ( θ ) = E t [ min ⁡ ( r t ( θ ) A ^ t , clip ( r t ( θ ) , 1 − ϵ , 1 + ϵ ) ⋅ A ^ t ) ] L^{\text{CLIP}}(\theta) = \mathbb{E}_t \left[ \min\left( r_t(\theta) \hat{A}_t, \text{clip}(r_t(\theta), 1 - \epsilon, 1 + \epsilon) \cdot \hat{A}_t \right) \right] LCLIP(θ)=Et[min(rt(θ)A^t,clip(rt(θ),1−ϵ,1+ϵ)⋅A^t)]

其中:

  • r t ( θ ) = π θ ( a t ∣ s t ) π θ old ( a t ∣ s t ) r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{\text{old}}}(a_t|s_t)} rt(θ)=πθold(at∣st)πθ(at∣st):当前策略与旧策略的概率比;
  • A ^ t \hat{A}_t A^t:优势函数(通常由 GAE 计算);
  • ϵ \epsilon ϵ:限制策略变化的范围(如 0.2);
  • clip 限制 :防止 r t r_t rt 偏离 1 太多,从而稳定训练。
    • 如果策略变化小( r ≈ 1 r \approx 1 r≈1),使用正常的策略梯度;
    • 如果策略变化太大,clip 限制更新,防止崩溃。
(2)KL Penalty目标函数(次要版本)

另一种方法是将 KL 散度作为正则项加入目标:

L K L ( θ ) = E t [ r t ( θ ) A ^ t − β ⋅ D K L ( π θ old ∥ π θ ) ] L^{KL}(\theta) = \mathbb{E}t \left[ r_t(\theta) \hat{A}t - \beta \cdot D{KL}(\pi{\theta_{\text{old}}} \| \pi_\theta) \right] LKL(θ)=Et[rt(θ)A^t−β⋅DKL(πθold∥πθ)]

这种方法也可行,但不如 Clipped 简洁稳定,调参难度较大。

3.3. 算法流程

  1. 采样数据

    使用当前策略 π θ \pi_\theta πθ 与环境交互,得到轨迹 ( s t , a t , r t ) (s_t, a_t, r_t) (st,at,rt)。

  2. 计算优势值

    用 GAE(Generalized Advantage Estimation)或其他方法估算 A ^ t \hat{A}_t A^t。

  3. 构建目标函数

    使用 Clipped PPO 的目标函数。

  4. 多次优化

    使用 mini-batch 和多轮 epoch 更新策略网络和价值网络。

  5. 更新旧策略

    把当前策略复制为旧策略,进行下一轮训练。

3.4. PPO 的优点

  • 稳定性强:通过限制策略更新幅度,训练稳定。
  • 易于实现:无须精细调节 trust region(相较于 TRPO)。
  • 样本效率高:适合 batch 训练、多次利用旧样本。
  • 广泛应用:在游戏、机器人控制等场景中表现优异。
相关推荐
Baihai IDP3 分钟前
深度解析 Cursor(逐行解析系统提示词、分享高效制定 Cursor Rules 的技巧...)
人工智能·ai编程·cursor·genai·智能体·llms
神经星星7 分钟前
MIT 团队利用大模型筛选 25 类水泥熟料替代材料,相当于减排 12 亿吨温室气体
人工智能·深度学习·机器学习
lifallen11 分钟前
Java BitSet类解析:高效位向量实现
java·开发语言·后端·算法
学不好python的小猫22 分钟前
7-4 身份证号处理
开发语言·python·算法
Jamence23 分钟前
多模态大语言模型arxiv论文略读(125)
论文阅读·人工智能·语言模型·自然语言处理·论文笔记
AI浩27 分钟前
TradingAgents:基于多智能体的大型语言模型(LLM)金融交易框架
人工智能·语言模型·自然语言处理
澳鹏Appen30 分钟前
对抗性提示:进阶守护大语言模型
人工智能·语言模型·自然语言处理
源图客1 小时前
大语言模型指令集全解析
人工智能·语言模型·自然语言处理
wenzhangli71 小时前
筑牢安全防线:电子文件元数据驱动的 AI 知识库可控管理方案
大数据·人工智能
北京地铁1号线1 小时前
OCRBench:评估多模态大模型的OCR能力
人工智能