学习算法精读:RolloutStorage、PPO 更新与 TorchScript 导出
这一篇我们把视角从环境切到训练系统本身,主读这三份代码:
rollout_storage.pyppo.py__init__.py
如果说环境解决的是"动作进去以后物理世界发生什么",那么这套训练代码解决的就是:
这些轨迹怎么存、怎么算 advantage、怎么做 PPO 更新,以及 adaptation module 的监督信号怎么插进同一个训练闭环。
从整体结构上看,这个 ppo_cse 版本可以概括成四层:
Runner负责 rollout 收集、日志、保存与训练调度RolloutStorage负责把一整段轨迹缓存成 PPO 可消费的 tensorPPO负责 GAE、PPO loss、梯度更新和 adaptation supervisionActorCritic提供 student policy、critic 和 adaptation module
一、训练主循环:Runner.learn() 先收轨迹,再做一次更新
先看 ppo_cse/__init__.py 里的 Runner。
初始化时,Runner 会先创建 ActorCritic,再创建 PPO,最后初始化 storage:
python
actor_critic = ActorCritic(
self.env.num_obs,
self.env.num_privileged_obs,
self.env.num_obs_history,
self.env.num_actions,
).to(self.device)
self.alg = PPO(actor_critic, device=self.device)
self.alg.init_storage(
self.env.num_train_envs,
self.num_steps_per_env,
[self.env.num_obs],
[self.env.num_privileged_obs],
[self.env.num_obs_history],
[self.env.num_actions]
)
# 这里只给 train envs 分配 rollout storage
# 说明并行环境里虽然可能同时跑 train 和 eval,但真正进入 PPO 更新的只有训练环境部分
这一点很关键:
eval envs 会并行跑,但不进入优化器;只有 train envs 的 rollout 会被存下来训练。
在 learn() 里,整体训练节奏非常清楚:
python
for it in range(...):
# 1. rollout collection
for i in range(self.num_steps_per_env):
actions_train = self.alg.act(...)
actions_eval = self.alg.actor_critic.act_student(...) or act_teacher(...)
ret = self.env.step(torch.cat((actions_train, actions_eval), dim=0))
self.alg.process_env_step(rewards[:num_train_envs], dones[:num_train_envs], infos)
# 2. compute returns
self.alg.compute_returns(...)
# 3. PPO update
self.alg.update()
也就是说,一个 iteration 的基本单位是:
先在环境里收 num_steps_per_env 步轨迹,再做一次 PPO 学习。
二、RolloutStorage:轨迹是怎么存下来的
RolloutStorage 的设计很标准,但这份代码有一个很值得注意的点:
它不仅存常规 PPO 所需的 (obs, action, reward, done, value, log_prob),还额外存了:
privileged_observationsobservation_historiesenv_bins
看初始化:
python
self.observations = torch.zeros(T, N, *obs_shape, device=self.device)
self.privileged_observations = torch.zeros(T, N, *privileged_obs_shape, device=self.device)
self.observation_histories = torch.zeros(T, N, *obs_history_shape, device=self.device)
self.rewards = torch.zeros(T, N, 1, device=self.device)
self.actions = torch.zeros(T, N, *actions_shape, device=self.device)
self.dones = torch.zeros(T, N, 1, device=self.device).byte()
# T = num_steps_per_env
# N = num_train_envs
# 每一项都是一个 [时间, 环境, 特征] 的张量,所以每个张量都可以理解成:第 0 维:时间步 第 1 维:并行环境编号 后面维度:特征本身
这里的物理/算法含义是:
observations是当前单帧观测observation_histories是 policy 真正依赖的历史窗口输入privileged_observations是 critic 和 adaptation supervision 的真值env_bins是命令课程相关的环境分桶信息,主要用于统计和分析
真正写入轨迹发生在 add_transitions():
python
self.observations[self.step].copy_(transition.observations)
self.privileged_observations[self.step].copy_(transition.privileged_observations)
self.observation_histories[self.step].copy_(transition.observation_histories)
self.actions[self.step].copy_(transition.actions)
self.rewards[self.step].copy_(transition.rewards.view(-1, 1))
self.dones[self.step].copy_(transition.dones.view(-1, 1))
self.values[self.step].copy_(transition.values)
self.actions_log_prob[self.step].copy_(transition.actions_log_prob.view(-1, 1))
# 每一步 rollout 收到的新 transition 都按时间顺序写到 storage 里
所以 RolloutStorage 本质上就是:
把"在线环境交互产生的一步一步 transition",重新组织成一个定长的 [T, N, ...] 批量轨迹张量。
三、GAE 是怎么计算的
GAE 的实现位于 compute_returns(),写法非常标准:
python
advantage = 0
for step in reversed(range(self.num_transitions_per_env)):
if step == self.num_transitions_per_env - 1:
next_values = last_values
else:
next_values = self.values[step + 1]
next_is_not_terminal = 1.0 - self.dones[step].float()
delta = self.rewards[step] + next_is_not_terminal * gamma * next_values - self.values[step]
advantage = delta + next_is_not_terminal * gamma * lam * advantage
self.returns[step] = advantage + self.values[step]
这段代码可以拆成三层理解。
1. TD residual
python
delta = r_t + gamma * V(s_{t+1}) - V(s_t)
# 这是一步 temporal-difference 误差
# 表示"当前 value 估计和 Bellman 目标之间差了多少"
2. GAE 递推
python
advantage = delta + gamma * lam * advantage
# 不是只看一步 TD,而是把未来一串 TD 残差按 gamma*lambda 递减累加
# lambda 越大,估计越接近 Monte Carlo;lambda 越小,越接近 TD(0)
3. return 重建
python
self.returns[step] = advantage + self.values[step]
# PPO 的 value target 不是直接 reward-to-go,而是 A + V
最后它会做 advantage 标准化:
python
self.advantages = self.returns - self.values
self.advantages = (self.advantages - self.advantages.mean()) / (self.advantages.std() + 1e-8)
# 标准化 advantage 有助于稳定 PPO 更新,避免某些 batch 的尺度特别大
所以这份实现里的 GAE 逻辑可以总结成一句话:
先反向递推 generalized advantage,再把它转成 value learning 所需的 return target。
四、mini-batch 是怎么切出来的
PPO 更新不直接在 [T, N, ...] 的三维轨迹上做,而是先 flatten:
python
observations = self.observations.flatten(0, 1)
privileged_obs = self.privileged_observations.flatten(0, 1)
obs_history = self.observation_histories.flatten(0, 1)
actions = self.actions.flatten(0, 1)
returns = self.returns.flatten(0, 1)
advantages = self.advantages.flatten(0, 1)
# 把 [T, N, ...] 展平成 [T*N, ...]
# 这样就能按普通监督学习的方式随机抽 mini-batch
然后 mini_batch_generator() 里用 torch.randperm 打乱索引,切出 num_mini_batches 个子 batch。
因此这里的 PPO 不是 recurrent PPO,而是一个基于展平时间-环境维度的普通 feed-forward PPO。
这和当前网络结构也一致:
虽然输入里有 obs_history,但它已经被环境 wrapper 预先拼成固定长度向量了,所以对网络来说它仍然是普通 MLP 输入,而不是显式 RNN 序列。
五、PPO loss 是怎么组成的
ppo.py 的 update() 分成两条优化线:
- PPO 主损失
- adaptation supervision 损失
先看主损失。
1. 策略分布重算
python
self.actor_critic.act(obs_history_batch, masks=masks_batch)
actions_log_prob_batch = self.actor_critic.get_actions_log_prob(actions_batch)
value_batch = self.actor_critic.evaluate(obs_history_batch, privileged_obs_batch, masks=masks_batch)
mu_batch = self.actor_critic.action_mean
sigma_batch = self.actor_critic.action_std
entropy_batch = self.actor_critic.entropy
# 用当前网络重新计算该 batch 上的 action log prob、value、均值、方差和熵
这里 actor 只吃 obs_history_batch,critic 吃 obs_history_batch + privileged_obs_batch,和前一篇分析一致。
2. 自适应 KL 学习率
python
kl = ...
if kl_mean > PPO_Args.desired_kl * 2.0:
self.learning_rate = max(1e-5, self.learning_rate / 1.5)
elif kl_mean < PPO_Args.desired_kl / 2.0 and kl_mean > 0.0:
self.learning_rate = min(1e-2, self.learning_rate * 1.5)
# 如果新旧策略差太远,就减小学习率
# 如果更新太保守,就增大学习率
这属于一个很常见但很实用的 PPO 工程增强:
不用固定 learning rate,而是根据 KL 漂移自动调。
3. clipped surrogate loss
python
ratio = torch.exp(actions_log_prob_batch - torch.squeeze(old_actions_log_prob_batch))
surrogate = -torch.squeeze(advantages_batch) * ratio
surrogate_clipped = -torch.squeeze(advantages_batch) * torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param)
surrogate_loss = torch.max(surrogate, surrogate_clipped).mean()
# PPO 的核心思想就是限制 policy ratio 不能偏离太远
# 防止一次更新把策略推崩
4. clipped value loss
python
value_clipped = target_values_batch + (value_batch - target_values_batch).clamp(-clip_param, clip_param)
value_losses = (value_batch - returns_batch).pow(2)
value_losses_clipped = (value_clipped - returns_batch).pow(2)
value_loss = torch.max(value_losses, value_losses_clipped).mean()
# value network 也做 clipping,避免 critic 更新过猛
5. 总 PPO loss
python
loss = surrogate_loss + value_loss_coef * value_loss - entropy_coef * entropy_batch.mean()
# 策略优化项 + 价值回归项 - 熵正则
# 熵项越大,策略越鼓励保持探索
然后就是常规梯度更新:
python
self.optimizer.zero_grad()
loss.backward()
nn.utils.clip_grad_norm_(self.actor_critic.parameters(), max_grad_norm)
self.optimizer.step()
# 做梯度裁剪,避免偶发大梯度把训练炸掉
所以,主 PPO 部分的损失结构非常经典,就是:
clipped policy loss + clipped value loss + entropy regularization
六、adaptation supervision 是怎么加进去的
这一版 ppo_cse 最有特色的地方,就在于它不是只做 PPO,还会额外训练 adaptation module。
看 update() 后半段:
python
adaptation_pred = self.actor_critic.adaptation_module(obs_history_batch)
with torch.no_grad():
adaptation_target = privileged_obs_batch
# adaptation module 只看 obs_history
# 监督目标是 privileged_obs 真值
然后做 MSE:
python
adaptation_loss = F.mse_loss(
adaptation_pred[:num_train, selection_indices],
adaptation_target[:num_train, selection_indices]
)
adaptation_test_loss = F.mse_loss(
adaptation_pred[num_train:, selection_indices],
adaptation_target[num_train:, selection_indices]
)
# 当前 batch 会再切成 4/5 train 和 1/5 test
# 训练时顺手统计一个 adaptation test loss,方便观察是否过拟合
这段代码有两个很值得点出来的细节。
1. adaptation supervision 是单独的优化步骤
python
self.adaptation_module_optimizer.zero_grad()
adaptation_loss.backward()
self.adaptation_module_optimizer.step()
# adaptation 不是并进 PPO 的总 loss 一起反传
# 而是在 PPO 主更新之后,再做一个额外的监督学习步骤
也就是说,算法结构不是:
PPO loss + alpha * adaptation loss
而是:
先做 PPO 更新,再做 adaptation module 的监督更新。
2. 这个 optimizer 实际上仍然挂在 self.actor_critic.parameters()
python
self.adaptation_module_optimizer = optim.Adam(self.actor_critic.parameters(), lr=...)
# 从代码上看,optimizer 的参数集并没有只筛 adaptation_module
# 但因为 loss 只从 adaptation_module 前向图产生,所以主要梯度会流向它相关的参数
从实现意图上,作者显然想表达的是"给 adaptation module 一条独立学习率的监督分支"。
博客里可以如实说明这一点,但不必展开成代码审查。
所以这一部分的整体逻辑可以总结成:
PPO 负责学控制,adaptation supervision 负责学环境辨识。
七、rollout 里 student / teacher 是怎么协作的
在 Runner.learn() 中,train env 和 eval env 的动作来源是不同的:
python
actions_train = self.alg.act(obs[:num_train_envs], privileged_obs[:num_train_envs], obs_history[:num_train_envs])
# 训练环境里,动作来自 student policy 路径
if eval_expert:
actions_eval = self.alg.actor_critic.act_teacher(obs_history[num_train_envs:], privileged_obs[num_train_envs:])
else:
actions_eval = self.alg.actor_critic.act_student(obs_history[num_train_envs:])
# 评估环境可以切换成 teacher 或 student,用于对比效果
这说明几个关键点:
- rollout 训练时,真正进入 PPO 优化的是 student 行为
- teacher 主要用在 eval 或对比分析
- critic 仍然使用 privileged_obs 来估值
- adaptation module 则通过额外监督学习 privileged_obs
所以,这个训练系统不是"先 teacher 再 distill",而是:
student 主导 rollout,critic 提供特权 value,adaptation module 单独学环境因子估计。
八、checkpoint 和 .jit 是在什么时机导出的
保存逻辑在 Runner.learn() 的后半段。
周期性保存
python
if it % RunnerArgs.save_interval == 0:
with logger.Sync():
logger.torch_save(self.alg.actor_critic.state_dict(), f"checkpoints/ac_weights_{it:06d}.pt")
logger.duplicate(..., "checkpoints/ac_weights_last.pt")
# 每隔 save_interval 个 iteration 保存一次完整权重
这保存的是完整 ActorCritic 参数,也就是训练恢复用的 checkpoint。
紧接着还会导出两个 TorchScript:
python
adaptation_module = copy.deepcopy(self.alg.actor_critic.adaptation_module).to('cpu')
traced_script_adaptation_module = torch.jit.script(adaptation_module)
traced_script_adaptation_module.save(adaptation_module_path)
body_model = copy.deepcopy(self.alg.actor_critic.actor_body).to('cpu')
traced_script_body_module = torch.jit.script(body_model)
traced_script_body_module.save(body_path)
# 同步导出 adaptation_module_latest.jit 和 body_latest.jit
# 这两个文件对应部署时真正需要的 student 推理路径
训练结束时再保存一次
在 learn() 末尾,又会无条件再导出一遍:
python
with logger.Sync():
logger.torch_save(...)
...
traced_script_adaptation_module.save(adaptation_module_path)
traced_script_body_module.save(body_path)
# 即使最后一次 iteration 不是 save_interval 的整数倍,训练结束也会补一份最终模型
所以导出时机很清楚:
- 周期性 checkpoint:用于训练过程中的恢复和回滚
- 周期性
.jit:用于随时拿最新 student policy 去部署或测试 - 最终再导出一次:保证训练结束时一定有最新成品
九、为什么导出的不是完整模型,而是两个 .jit
这和前一篇的策略结构完全对应。
训练完整体后,真正部署只需要:
adaptation_module_latest.jitbody_latest.jit
原因是推理链路本来就是:
obs_history -> adaptation_module -> privileged_hat -> actor_body -> action
critic 不参与部署,PPO optimizer 和 storage 更不参与部署。
所以 .jit 导出的目的不是"保存所有训练结构",而是抽取出实际机器人运行时必须保留的最小推理图。
这也是为什么训练系统会同时保存两类文件:
.pt:给训练恢复用.jit:给部署执行用
十、把整个训练链路串起来看
如果把这三份代码连起来,整套学习算法的节奏可以概括成下面这条主线:
Runner从环境取到obs / privileged_obs / obs_history- student policy 基于
obs_history产生动作 - critic 基于
obs_history + privileged_obs产生 value Transition被塞进RolloutStorage- 收满
num_steps_per_env步之后:RolloutStorage.compute_returns()反向算 GAE 和 returnmini_batch_generator()打乱并切出 mini-batch
PPO.update()对每个 mini-batch:- 计算 PPO surrogate loss
- 计算 value loss
- 加 entropy regularization
- 做一次 PPO 梯度更新
- 再做一次 adaptation supervision 的 MSE 更新
Runner记录日志、保存 checkpoint、导出.jit
所以,这个系统并不是一个"纯 PPO",而是一个:
PPO 主优化 + privileged critic + adaptation supervision + deployment-oriented TorchScript export
的完整训练流水线。
结语
这套 ppo_cse 实现最值得记住的,不是某个具体超参数,而是它把三件事揉进了一条一致的训练链路里:
- PPO 负责学动作策略
- GAE 和 value loss 负责稳定优化
- adaptation supervision 负责把历史观测映射成环境因子估计
.pt和.jit双轨保存则把"训练恢复"和"部署执行"明确分离开来
所以如果用一句话概括这篇:
它训练的不是一个单独的 policy,而是一套"可训练、可恢复、可部署"的 student 控制系统。