让我把PPO算法像拆解一台精密机器一样,从整体架构到每个螺丝钉都详细解释。
🏗️ 整体架构图
text
┌─────────────────────────────────────────────────────────┐
│ PPO算法整体架构 │
├─────────────────────────────────────────────────────────┤
│ │
│ ┌──────────────┐ ┌──────────────────────┐ │
│ │ 环境交互模块 │ ──────→ │ 经验收集模块 │ │
│ │ (take_action) │ │ (transition_dict) │ │
│ └──────────────┘ └──────────┬───────────┘ │
│ ↓ │
│ ┌──────────────┐ ┌──────────────────────┐ │
│ │ 模型保存 │ ←────── │ 更新模块 │ │
│ │ (save_model) │ │ (update) 核心算法 │ │
│ └──────────────┘ └──────────────────────┘ │
│ ↑ │
│ ┌──────────────┐ ┌──────────┴───────────┐ │
│ │ 价值网络 │ │ 策略网络 │ │
│ │ (ValueNet) │ │ (PolicyNet) │ │
│ └──────────────┘ └──────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────┘
📦 模块1:神经网络模块
1.1 策略网络(PolicyNet)- 演员
python
class PolicyNet(torch.nn.Module):
def __init__(self, state_dim, hidden_dim, action_dim):
super(PolicyNet, self).__init__()
self.fc1 = torch.nn.Linear(state_dim, hidden_dim) # 输入层 → 隐藏层
self.fc2 = torch.nn.Linear(hidden_dim, action_dim) # 隐藏层 → 输出层
def forward(self, x):
x = F.relu(self.fc1(x)) # 激活函数引入非线性
return F.softmax(self.fc2(x), dim=1) # 转换为概率分布
详细解释:
-
输入:状态(CartPole中是4个数值:位置、速度、角度、角速度)
-
隐藏层:128个神经元,学习状态特征
-
输出:每个动作的概率(和为1),比如[0.3, 0.7]表示30%概率向左,70%向右
-
激活函数:
-
ReLU:解决梯度消失问题,计算简单
-
Softmax:将输出转换为概率分布
-
类比:就像一个决策顾问,根据当前情况给出建议的概率分布。
1.2 价值网络(ValueNet)- 评论家
python
class ValueNet(torch.nn.Module):
def __init__(self, state_dim, hidden_dim):
super(ValueNet, self).__init__()
self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
self.fc2 = torch.nn.Linear(hidden_dim, 1) # 输出一个数值
def forward(self, x):
x = F.relu(self.fc1(x))
return self.fc2(x) # 直接输出数值,不需要激活函数
详细解释:
-
输入:同样4维状态
-
输出:一个标量值,表示当前状态的价值(预期未来总奖励)
-
没有Softmax:因为是回归问题,不是分类
类比:就像一个经验丰富的评估师,判断当前局势有多好。
🔧 模块2:动作选择模块
python
def take_action(self, state):
# 1. 状态转换:numpy数组 → tensor
state = torch.tensor([state], dtype=torch.float).to(self.device)
# 2. 获取概率分布
probs = self.actor(state) # 形状:[1, action_dim]
# 3. 创建概率分布对象
action_dist = torch.distributions.Categorical(probs)
# 4. 采样动作(不是选最大的,而是按概率随机)
action = action_dist.sample()
return action.item()
为什么要采样而不是选最大的?
-
探索与利用的平衡:即使某个动作概率低,也有机会被选中
-
避免局部最优:保持探索性,发现更好的策略
💾 模块3:经验收集模块
python
transition_dict = {
'states': [], # 当前状态
'actions': [], # 采取的动作
'next_states': [], # 下一个状态
'rewards': [], # 获得的奖励
'dones': [] # 是否结束
}
数据流:
text
一个完整轨迹(episode):
s1 → a1 → (r1, s2) → a2 → (r2, s3) → ... → sT (done)
🧮 模块4:核心更新模块
这是PPO最复杂的部分,我把它分解成5个子步骤:
4.1 数据预处理
python
def update(self, transition_dict):
# 将列表转换为tensor,并移到指定设备
states = torch.tensor(transition_dict['states']).to(self.device)
actions = torch.tensor(transition_dict['actions']).view(-1, 1).to(self.device)
rewards = torch.tensor(transition_dict['rewards']).view(-1, 1).to(self.device)
next_states = torch.tensor(transition_dict['next_states']).to(self.device)
dones = torch.tensor(transition_dict['dones']).view(-1, 1).to(self.device)
4.2 TD目标和TD误差计算
python
# TD目标:r + γ * V(s') * (1-done)
td_target = rewards + self.gamma * self.critic(next_states) * (1 - dones)
# TD误差:δ = TD目标 - V(s)
td_delta = td_target - self.critic(states)
数学解释:
-
TD目标:当前奖励 + 未来状态的折扣价值
-
TD误差:实际得到的比预期好多少(正数表示比预期好)
4.3 GAE优势函数计算
python
def compute_advantage(gamma, lmbda, td_delta):
td_delta = td_delta.detach().numpy()
advantage_list = []
advantage = 0.0
# 反向计算优势(从最后一个时间步往前)
for delta in td_delta[::-1]:
advantage = gamma * lmbda * advantage + delta # 关键公式
advantage_list.append(advantage)
advantage_list.reverse()
return torch.tensor(advantage_list)
GAE的数学原理:
text
GAE(γ,λ) = δ₁ + (γλ)δ₂ + (γλ)²δ₃ + ...
其中 δ_t = r_t + γV(s_{t+1}) - V(s_t) # TD误差
参数含义:
-
γ (gamma):折扣因子,0.98,关注长期回报
-
λ (lambda):0.95,平衡方差和偏差
-
λ=0:只看一步TD误差(高偏差)
-
λ=1:看完整轨迹(高方差)
-
4.4 新旧策略比率计算
python
# 计算旧策略的对数概率(detach阻止梯度传播)
old_log_probs = torch.log(self.actor(states).gather(1, actions)).detach()
for _ in range(self.epochs): # 重复使用数据多次
# 新策略的对数概率
log_probs = torch.log(self.actor(states).gather(1, actions))
# 比率 r(θ) = π_θ(a|s) / π_θ_old(a|s)
ratio = torch.exp(log_probs - old_log_probs)
gather操作解释:
python
# 例子:probs = [[0.2, 0.5, 0.3]] # 3个动作的概率
# actions = [[1]] # 选择了动作1(索引从0开始)
# probs.gather(1, actions) = [[0.5]] # 提取选中的动作的概率
# log(probs.gather()) = log(0.5) # 对数概率
4.5 PPO裁剪目标
python
# 未裁剪的目标
surr1 = ratio * advantage
# 裁剪后的目标(限制在[1-eps, 1+eps])
surr2 = torch.clamp(ratio, 1 - self.eps, 1 + self.eps) * advantage
# PPO目标:取两者的最小值
actor_loss = torch.mean(-torch.min(surr1, surr2))
裁剪机制可视化:
text
优势为正(好动作)时:
ratio > 1+eps → 使用裁剪值1+eps
ratio < 1+eps → 使用实际ratio
优势为负(坏动作)时:
ratio < 1-eps → 使用裁剪值1-eps
ratio > 1-eps → 使用实际ratio
4.6 价值网络损失
python
critic_loss = torch.mean(
F.mse_loss(self.critic(states), td_target.detach())
)
为什么detach TD目标?
-
TD目标包含critic自己的输出,detach防止梯度回流
-
避免"自举"造成的循环梯度
4.7 梯度更新
python
# 清空梯度
self.actor_optimizer.zero_grad()
self.critic_optimizer.zero_grad()
# 反向传播
actor_loss.backward()
critic_loss.backward()
# 更新参数
self.actor_optimizer.step()
self.critic_optimizer.step()
💿 模块5:模型保存与加载
python
def save_model(self, filepath):
torch.save({
'actor_state_dict': self.actor.state_dict(), # 策略网络权重
'critic_state_dict': self.critic.state_dict(), # 价值网络权重
'actor_optimizer_state_dict': self.actor_optimizer.state_dict(), # 优化器状态
'critic_optimizer_state_dict': self.critic_optimizer.state_dict(),
}, filepath)
def load_model(self, filepath):
checkpoint = torch.load(filepath, map_location=self.device)
self.actor.load_state_dict(checkpoint['actor_state_dict'])
self.critic.load_state_dict(checkpoint['critic_state_dict'])
self.actor_optimizer.load_state_dict(checkpoint['actor_optimizer_state_dict'])
self.critic_optimizer.load_state_dict(checkpoint['critic_optimizer_state_dict'])
为什么要保存优化器状态?
-
优化器(Adam)有动量等内部状态
-
保存这些状态可以完全恢复训练过程
📊 模块6:训练流程控制
python
def train_ppo(num_episodes=300):
return_list = []
best_reward = -float('inf')
for i in range(10): # 10个阶段
for i_episode in range(30): # 每个阶段30个episode
# 1. 收集一个episode的数据
# 2. 更新策略
# 3. 记录奖励
# 4. 保存最佳模型
if episode_return > best_reward:
best_reward = episode_return
agent.save_model('best_model.pth')
🎯 关键超参数的作用
| 参数 | 值 | 作用 | 影响 |
|---|---|---|---|
actor_lr |
1e-3 | 策略网络学习率 | 太大不稳定,太小学习慢 |
critic_lr |
1e-2 | 价值网络学习率 | 价值网络可以学得快些 |
gamma |
0.98 | 折扣因子 | 接近1考虑长远,接近0只看眼前 |
lmbda |
0.95 | GAE参数 | 平衡方差和偏差 |
eps |
0.2 | 裁剪范围 | 控制策略更新幅度 |
epochs |
10 | 数据复用次数 | 提高样本效率 |
🔄 完整数据流示例
假设一个episode有4步:
text
Step 1: s1 → a1 → r1, s2
Step 2: s2 → a2 → r2, s3
Step 3: s3 → a3 → r3, s4
Step 4: s4 → a4 → r4, done
收集的数据:
states: [s1, s2, s3, s4]
actions: [a1, a2, a3, a4]
rewards: [r1, r2, r3, r4]
next_states: [s2, s3, s4, s4]
dones: [0, 0, 0, 1]
更新过程:
1. 计算V(s1), V(s2), V(s3), V(s4)
2. 计算TD目标:[r1+γV(s2), r2+γV(s3), r3+γV(s4), r4+γV(s4)*0]
3. 计算TD误差:δ1, δ2, δ3, δ4
4. 计算GAE优势:A1, A2, A3, A4(考虑整个序列)
5. 计算新旧策略比率
6. 计算裁剪损失
7. 更新网络
💡 算法创新点总结
-
重要性采样:用旧策略收集的数据更新新策略
-
裁剪目标:限制更新幅度,保证稳定性
-
GAE优势估计:平衡偏差和方差
-
多epoch更新:提高样本效率
-
Actor-Critic架构:同时学习策略和价值