PPO (Proximal Policy Optimization) 算法模块详细拆解

让我把PPO算法像拆解一台精密机器一样,从整体架构到每个螺丝钉都详细解释。

🏗️ 整体架构图

text

复制代码
┌─────────────────────────────────────────────────────────┐
│                      PPO算法整体架构                      │
├─────────────────────────────────────────────────────────┤
│                                                         │
│   ┌──────────────┐          ┌──────────────────────┐   │
│   │  环境交互模块  │ ──────→ │     经验收集模块      │   │
│   │ (take_action) │          │ (transition_dict)   │   │
│   └──────────────┘          └──────────┬───────────┘   │
│                                        ↓                 │
│   ┌──────────────┐          ┌──────────────────────┐   │
│   │   模型保存    │ ←────── │       更新模块        │   │
│   │ (save_model) │          │   (update) 核心算法   │   │
│   └──────────────┘          └──────────────────────┘   │
│                                        ↑                 │
│   ┌──────────────┐          ┌──────────┴───────────┐   │
│   │   价值网络    │          │      策略网络         │   │
│   │ (ValueNet)   │          │    (PolicyNet)       │   │
│   └──────────────┘          └──────────────────────┘   │
│                                                         │
└─────────────────────────────────────────────────────────┘

📦 模块1:神经网络模块

1.1 策略网络(PolicyNet)- 演员

python

复制代码
class PolicyNet(torch.nn.Module):
    def __init__(self, state_dim, hidden_dim, action_dim):
        super(PolicyNet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)  # 输入层 → 隐藏层
        self.fc2 = torch.nn.Linear(hidden_dim, action_dim) # 隐藏层 → 输出层

    def forward(self, x):
        x = F.relu(self.fc1(x))  # 激活函数引入非线性
        return F.softmax(self.fc2(x), dim=1)  # 转换为概率分布

详细解释

  • 输入:状态(CartPole中是4个数值:位置、速度、角度、角速度)

  • 隐藏层:128个神经元,学习状态特征

  • 输出:每个动作的概率(和为1),比如[0.3, 0.7]表示30%概率向左,70%向右

  • 激活函数

    • ReLU:解决梯度消失问题,计算简单

    • Softmax:将输出转换为概率分布

类比:就像一个决策顾问,根据当前情况给出建议的概率分布。

1.2 价值网络(ValueNet)- 评论家

python

复制代码
class ValueNet(torch.nn.Module):
    def __init__(self, state_dim, hidden_dim):
        super(ValueNet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, 1)  # 输出一个数值

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return self.fc2(x)  # 直接输出数值,不需要激活函数

详细解释

  • 输入:同样4维状态

  • 输出:一个标量值,表示当前状态的价值(预期未来总奖励)

  • 没有Softmax:因为是回归问题,不是分类

类比:就像一个经验丰富的评估师,判断当前局势有多好。

🔧 模块2:动作选择模块

python

复制代码
def take_action(self, state):
    # 1. 状态转换:numpy数组 → tensor
    state = torch.tensor([state], dtype=torch.float).to(self.device)
    
    # 2. 获取概率分布
    probs = self.actor(state)  # 形状:[1, action_dim]
    
    # 3. 创建概率分布对象
    action_dist = torch.distributions.Categorical(probs)
    
    # 4. 采样动作(不是选最大的,而是按概率随机)
    action = action_dist.sample()
    
    return action.item()

为什么要采样而不是选最大的?

  • 探索与利用的平衡:即使某个动作概率低,也有机会被选中

  • 避免局部最优:保持探索性,发现更好的策略

💾 模块3:经验收集模块

python

复制代码
transition_dict = {
    'states': [],        # 当前状态
    'actions': [],       # 采取的动作
    'next_states': [],   # 下一个状态
    'rewards': [],       # 获得的奖励
    'dones': []          # 是否结束
}

数据流

text

复制代码
一个完整轨迹(episode):
s1 → a1 → (r1, s2) → a2 → (r2, s3) → ... → sT (done)

🧮 模块4:核心更新模块

这是PPO最复杂的部分,我把它分解成5个子步骤:

4.1 数据预处理

python

复制代码
def update(self, transition_dict):
    # 将列表转换为tensor,并移到指定设备
    states = torch.tensor(transition_dict['states']).to(self.device)
    actions = torch.tensor(transition_dict['actions']).view(-1, 1).to(self.device)
    rewards = torch.tensor(transition_dict['rewards']).view(-1, 1).to(self.device)
    next_states = torch.tensor(transition_dict['next_states']).to(self.device)
    dones = torch.tensor(transition_dict['dones']).view(-1, 1).to(self.device)

4.2 TD目标和TD误差计算

python

复制代码
# TD目标:r + γ * V(s') * (1-done)
td_target = rewards + self.gamma * self.critic(next_states) * (1 - dones)

# TD误差:δ = TD目标 - V(s)
td_delta = td_target - self.critic(states)

数学解释

  • TD目标:当前奖励 + 未来状态的折扣价值

  • TD误差:实际得到的比预期好多少(正数表示比预期好)

4.3 GAE优势函数计算

python

复制代码
def compute_advantage(gamma, lmbda, td_delta):
    td_delta = td_delta.detach().numpy()
    advantage_list = []
    advantage = 0.0
    # 反向计算优势(从最后一个时间步往前)
    for delta in td_delta[::-1]:
        advantage = gamma * lmbda * advantage + delta  # 关键公式
        advantage_list.append(advantage)
    advantage_list.reverse()
    return torch.tensor(advantage_list)

GAE的数学原理

text

复制代码
GAE(γ,λ) = δ₁ + (γλ)δ₂ + (γλ)²δ₃ + ...

其中 δ_t = r_t + γV(s_{t+1}) - V(s_t)  # TD误差

参数含义

  • γ (gamma):折扣因子,0.98,关注长期回报

  • λ (lambda):0.95,平衡方差和偏差

    • λ=0:只看一步TD误差(高偏差)

    • λ=1:看完整轨迹(高方差)

4.4 新旧策略比率计算

python

复制代码
# 计算旧策略的对数概率(detach阻止梯度传播)
old_log_probs = torch.log(self.actor(states).gather(1, actions)).detach()

for _ in range(self.epochs):  # 重复使用数据多次
    # 新策略的对数概率
    log_probs = torch.log(self.actor(states).gather(1, actions))
    
    # 比率 r(θ) = π_θ(a|s) / π_θ_old(a|s)
    ratio = torch.exp(log_probs - old_log_probs)

gather操作解释

python

复制代码
# 例子:probs = [[0.2, 0.5, 0.3]]  # 3个动作的概率
# actions = [[1]]  # 选择了动作1(索引从0开始)
# probs.gather(1, actions) = [[0.5]]  # 提取选中的动作的概率
# log(probs.gather()) = log(0.5)  # 对数概率

4.5 PPO裁剪目标

python

复制代码
# 未裁剪的目标
surr1 = ratio * advantage

# 裁剪后的目标(限制在[1-eps, 1+eps])
surr2 = torch.clamp(ratio, 1 - self.eps, 1 + self.eps) * advantage

# PPO目标:取两者的最小值
actor_loss = torch.mean(-torch.min(surr1, surr2))

裁剪机制可视化

text

复制代码
优势为正(好动作)时:
    ratio > 1+eps → 使用裁剪值1+eps
    ratio < 1+eps → 使用实际ratio

优势为负(坏动作)时:
    ratio < 1-eps → 使用裁剪值1-eps
    ratio > 1-eps → 使用实际ratio

4.6 价值网络损失

python

复制代码
critic_loss = torch.mean(
    F.mse_loss(self.critic(states), td_target.detach())
)

为什么detach TD目标?

  • TD目标包含critic自己的输出,detach防止梯度回流

  • 避免"自举"造成的循环梯度

4.7 梯度更新

python

复制代码
# 清空梯度
self.actor_optimizer.zero_grad()
self.critic_optimizer.zero_grad()

# 反向传播
actor_loss.backward()
critic_loss.backward()

# 更新参数
self.actor_optimizer.step()
self.critic_optimizer.step()

💿 模块5:模型保存与加载

python

复制代码
def save_model(self, filepath):
    torch.save({
        'actor_state_dict': self.actor.state_dict(),      # 策略网络权重
        'critic_state_dict': self.critic.state_dict(),    # 价值网络权重
        'actor_optimizer_state_dict': self.actor_optimizer.state_dict(),  # 优化器状态
        'critic_optimizer_state_dict': self.critic_optimizer.state_dict(),
    }, filepath)

def load_model(self, filepath):
    checkpoint = torch.load(filepath, map_location=self.device)
    self.actor.load_state_dict(checkpoint['actor_state_dict'])
    self.critic.load_state_dict(checkpoint['critic_state_dict'])
    self.actor_optimizer.load_state_dict(checkpoint['actor_optimizer_state_dict'])
    self.critic_optimizer.load_state_dict(checkpoint['critic_optimizer_state_dict'])

为什么要保存优化器状态?

  • 优化器(Adam)有动量等内部状态

  • 保存这些状态可以完全恢复训练过程

📊 模块6:训练流程控制

python

复制代码
def train_ppo(num_episodes=300):
    return_list = []
    best_reward = -float('inf')
    
    for i in range(10):  # 10个阶段
        for i_episode in range(30):  # 每个阶段30个episode
            # 1. 收集一个episode的数据
            # 2. 更新策略
            # 3. 记录奖励
            # 4. 保存最佳模型
            
            if episode_return > best_reward:
                best_reward = episode_return
                agent.save_model('best_model.pth')

🎯 关键超参数的作用

参数 作用 影响
actor_lr 1e-3 策略网络学习率 太大不稳定,太小学习慢
critic_lr 1e-2 价值网络学习率 价值网络可以学得快些
gamma 0.98 折扣因子 接近1考虑长远,接近0只看眼前
lmbda 0.95 GAE参数 平衡方差和偏差
eps 0.2 裁剪范围 控制策略更新幅度
epochs 10 数据复用次数 提高样本效率

🔄 完整数据流示例

假设一个episode有4步:

text

复制代码
Step 1: s1 → a1 → r1, s2
Step 2: s2 → a2 → r2, s3  
Step 3: s3 → a3 → r3, s4
Step 4: s4 → a4 → r4, done

收集的数据:
states:      [s1, s2, s3, s4]
actions:     [a1, a2, a3, a4]  
rewards:     [r1, r2, r3, r4]
next_states: [s2, s3, s4, s4]
dones:       [0, 0, 0, 1]

更新过程:
1. 计算V(s1), V(s2), V(s3), V(s4)
2. 计算TD目标:[r1+γV(s2), r2+γV(s3), r3+γV(s4), r4+γV(s4)*0]
3. 计算TD误差:δ1, δ2, δ3, δ4
4. 计算GAE优势:A1, A2, A3, A4(考虑整个序列)
5. 计算新旧策略比率
6. 计算裁剪损失
7. 更新网络

💡 算法创新点总结

  1. 重要性采样:用旧策略收集的数据更新新策略

  2. 裁剪目标:限制更新幅度,保证稳定性

  3. GAE优势估计:平衡偏差和方差

  4. 多epoch更新:提高样本效率

  5. Actor-Critic架构:同时学习策略和价值

相关推荐
仙女修炼史2 小时前
FCOS: Fully Convolutional One-Stage Object Detection
人工智能·目标检测·目标跟踪
阿Y加油吧2 小时前
力扣打卡day06——滑动窗口最大值、最小覆盖子串
数据结构·算法·leetcode
大傻^2 小时前
Spring AI Alibaba 多模态开发:集成视觉理解与视频分析能力
人工智能·spring·音视频·springai·springaialibaba·混合检索
沉鱼.442 小时前
日期题目集
数据结构·算法
前端摸鱼匠2 小时前
面试题3:自注意力机制(Self-Attention)的计算流程是什么?
人工智能·ai·面试·职场和发展
出门吃三碗饭2 小时前
CARLA: 如何在 CARLA 中回放自动驾驶场景
人工智能·机器学习·自动驾驶
Axis tech2 小时前
第二届人形机器人半程马拉松即将于4月开赛,对比去年技术进步有哪些?
人工智能·机器人
志栋智能2 小时前
超自动化巡检,如何成为业务稳定的“压舱石”?
大数据·运维·网络·人工智能·自动化
Book思议-2 小时前
【数据结构考研真题】链表题
c语言·数据结构·算法·链表·408·计算机考研