Preface

Unitree RL GYM is an open-source project of reinforcement-learning (RL) control examples for Unitree robots, used to train, test, and deploy legged-robot control policies. The repository supports several Unitree robot models, including Go2, H1, H1_2, and G1. (Repository link)

This series walks through the repository's core code, the algorithm implementations, and the training workflow. It assumes some background in reinforcement learning and programming, so elementary theory and basic code logic are not re-explained. Readers who want an RL primer can start with my introductory series:
- Part 1: [RL Made Easy] (1) Q-Learning is just table lookup (CSDN blog)
- Part 2: [RL Made Easy] (2) Sarsa, a win for the conservatives (CSDN blog)
- Part 3: [RL Made Easy] (3) DQN: table lookup gets a brain (CSDN blog)
- Part 4: [RL Made Easy] (4) Policy Gradients and policy sampling (CSDN blog)
- Part 5: [RL Made Easy] (5) Actor-Critic and A3C, the triumph of multithreading (CSDN blog)
- Part 6: [RL Made Easy] (6) DDPG and TD3, the best of all schools (CSDN blog)
- Part 7: [RL Made Easy] (7) PPO, the safety valve for policy updates (CSDN blog)
Prerequisites for this series:
- Python syntax, including object-oriented encapsulation
- Basic PyTorch usage and neural-network fundamentals
- Reinforcement-learning fundamentals: at least Policy Gradient, Actor-Critic, and PPO
This installment is the last in the rsl_rl source-code series. To set up the transition to UNITREE_RL_GYM, it covers the remaining pieces in one go: OnPolicyRunner, VecEnv, and RolloutStorage.
0 Background: Observations and Privileged Observations
0-1 Overview
- In reinforcement learning, the observation is the input to the Actor network; it determines what the policy can see about the environment.
- The privileged observation is typically fed only to the Critic network. It may contain more, and more precise, state information, which yields better value-function estimates and therefore more stable and faster training.
- Concretely, in this project:
0-2 Observation
- For a legged-robot environment such as Go2 or H1, the observation usually contains the following components:

| Type | Description | Example dim |
|---|---|---|
| Joint positions | Current joint angles | 12 |
| Joint velocities | Joint angular velocities | 12 |
| Base linear velocity | Base velocity along x, y, z | 3 |
| Base angular velocity | Base angular velocity about x, y, z | 3 |
| IMU acceleration | Accelerometer readings | 3 |
| IMU angular velocity | Gyroscope readings | 3 |
| Foot contacts | Whether each foot touches the ground (contact sensors) | 4 |
| Command | Target position or velocity command | 3 |
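To make the layout concrete, here is a small sketch (plain Python, no torch) that assembles the components from the table above into one flat observation vector. The component names and dimensions are illustrative, taken from the example table, not from the actual Unitree RL GYM config:

```python
# Illustrative only: assemble a flat observation vector from the
# components in the table above (dimensions follow the example column).
components = {
    "joint_pos": [0.0] * 12,    # joint angles
    "joint_vel": [0.0] * 12,    # joint angular velocities
    "base_lin_vel": [0.0] * 3,  # base linear velocity (x, y, z)
    "base_ang_vel": [0.0] * 3,  # base angular velocity
    "imu_acc": [0.0] * 3,       # accelerometer readings
    "imu_gyro": [0.0] * 3,      # gyroscope readings
    "foot_contact": [0.0] * 4,  # per-foot contact flags
    "command": [0.0] * 3,       # velocity command / target
}

def flatten_obs(parts):
    """Concatenate the per-sensor pieces into one flat vector."""
    out = []
    for v in parts.values():
        out.extend(v)
    return out

obs = flatten_obs(components)  # length 43 with these example dimensions
```

In the real environment the same concatenation happens on GPU tensors, one row per parallel environment.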
0-3 Privileged Observation
- The privileged observation is exclusive to the Critic: it can include "global state" that the Actor has no way to obtain. Common contents:

| Type | Description | Example dim |
|---|---|---|
| Full robot state | Base pose (position + orientation) | 6 |
| Joint torques | Actual torque at each joint | 12 |
| Terrain height map | Terrain under and around the robot | 12~20 |
| Global physics state | Global position/velocity of every joint and link | N/A |
| Environment dynamics | e.g. obstacle positions and velocities | optional |
0-4 How They Relate
- The picture should now be clear:

Environment ──> observation ────────────> Actor ──> action ──> Environment
     └────────> privileged_observation ──> Critic ──> value

- The Actor sees only limited information, so the learned policy can be executed in the real environment.
- The Critic sees more information, which helps train a more accurate value function.
0-5 Asymmetric Actor-Critic
- "Asymmetric" means that the Actor and the Critic receive different observations:
  - Actor: sees only the regular observation obs (e.g. the robot's sensor data)
  - Critic: also sees the privileged observation privileged_obs (e.g. the full environment state, terrain information, etc.)
- Purpose:
  - the trained Actor policy can be deployed using only the limited information
  - the Critic exploits the extra information to compute a more accurate value function, which accelerates training
- In essence it is still Actor-Critic; the Critic simply gets a richer input than the Actor, which suits sim-to-real transfer.
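The asymmetric pattern can be sketched in a few lines. This is a toy stand-in (plain Python functions instead of the rsl_rl network modules) whose only point is the asymmetry of inputs: the actor never touches the privileged state, the critic may use both:

```python
# Toy sketch of asymmetric actor-critic inputs. The "networks" are
# stand-in functions, not the rsl_rl ActorCritic modules.
def actor(obs):
    # policy: maps the deployable observation to an action
    return [o * 0.1 for o in obs]

def critic(obs, privileged_obs):
    # value function: may also use information unavailable at deployment
    return sum(obs) + sum(privileged_obs)

obs = [1.0, 2.0]             # sensor-only observation
privileged_obs = [0.5, 0.5]  # extra simulator-only state
action = actor(obs)          # no privileged info used here
value = critic(obs, privileged_obs)
```

At deployment time only `actor` survives, which is exactly why it must be trained on deployable observations alone.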
1 RolloutStorage
1-1 Overview
- As covered in the first installment, storage/rollout_storage.py stores sampled trajectories (rollouts).
- During training, the policy network first interacts with the environment to collect a stretch of trajectory data, and only then uses that data to update the policy.
- A dedicated data buffer is therefore needed to hold this information.
- In rsl_rl, that role is played by RolloutStorage.
- It has three main responsibilities:
  - store the sampled trajectories
  - compute the Returns and Advantages that PPO needs
  - generate mini-batches
- Key fields:

```python
self.observations              # Actor input obs
self.privileged_observations   # Critic input (privileged obs)
self.actions                   # actions
self.rewards                   # rewards
self.values                    # Critic value estimates
self.returns                   # discounted returns
self.advantages                # advantages (GAE)
self.mu, self.sigma            # Gaussian policy parameters
self.saved_hidden_states_a/c   # RNN hidden states
```

- Core methods:

```python
compute_returns()       # compute returns and advantages via the GAE recursion
mini_batch_generator()  # turn rollouts into training mini-batches
```
1-2 Full Source Listing
```python
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# Copyright (c) 2021 ETH Zurich, Nikita Rudin

import torch
import numpy as np

from rsl_rl.utils import split_and_pad_trajectories

class RolloutStorage:
    class Transition:
        def __init__(self):
            self.observations = None
            self.critic_observations = None
            self.actions = None
            self.rewards = None
            self.dones = None
            self.values = None
            self.actions_log_prob = None
            self.action_mean = None
            self.action_sigma = None
            self.hidden_states = None

        def clear(self):
            self.__init__()

    def __init__(self, num_envs, num_transitions_per_env, obs_shape, privileged_obs_shape, actions_shape, device='cpu'):
        self.device = device
        self.obs_shape = obs_shape
        self.privileged_obs_shape = privileged_obs_shape
        self.actions_shape = actions_shape

        # Core
        self.observations = torch.zeros(num_transitions_per_env, num_envs, *obs_shape, device=self.device)
        if privileged_obs_shape[0] is not None:
            self.privileged_observations = torch.zeros(num_transitions_per_env, num_envs, *privileged_obs_shape, device=self.device)
        else:
            self.privileged_observations = None
        self.rewards = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device)
        self.actions = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=self.device)
        self.dones = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device).byte()

        # For PPO
        self.actions_log_prob = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device)
        self.values = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device)
        self.returns = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device)
        self.advantages = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device)
        self.mu = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=self.device)
        self.sigma = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=self.device)

        self.num_transitions_per_env = num_transitions_per_env
        self.num_envs = num_envs

        # rnn
        self.saved_hidden_states_a = None
        self.saved_hidden_states_c = None

        self.step = 0

    def add_transitions(self, transition: Transition):
        if self.step >= self.num_transitions_per_env:
            raise AssertionError("Rollout buffer overflow")
        self.observations[self.step].copy_(transition.observations)
        if self.privileged_observations is not None:
            self.privileged_observations[self.step].copy_(transition.critic_observations)
        self.actions[self.step].copy_(transition.actions)
        self.rewards[self.step].copy_(transition.rewards.view(-1, 1))
        self.dones[self.step].copy_(transition.dones.view(-1, 1))
        self.values[self.step].copy_(transition.values)
        self.actions_log_prob[self.step].copy_(transition.actions_log_prob.view(-1, 1))
        self.mu[self.step].copy_(transition.action_mean)
        self.sigma[self.step].copy_(transition.action_sigma)
        self._save_hidden_states(transition.hidden_states)
        self.step += 1

    def _save_hidden_states(self, hidden_states):
        if hidden_states is None or hidden_states == (None, None):
            return
        # make a tuple out of GRU hidden states to match the LSTM format
        hid_a = hidden_states[0] if isinstance(hidden_states[0], tuple) else (hidden_states[0],)
        hid_c = hidden_states[1] if isinstance(hidden_states[1], tuple) else (hidden_states[1],)
        # initialize if needed
        if self.saved_hidden_states_a is None:
            self.saved_hidden_states_a = [torch.zeros(self.observations.shape[0], *hid_a[i].shape, device=self.device) for i in range(len(hid_a))]
            self.saved_hidden_states_c = [torch.zeros(self.observations.shape[0], *hid_c[i].shape, device=self.device) for i in range(len(hid_c))]
        # copy the states
        for i in range(len(hid_a)):
            self.saved_hidden_states_a[i][self.step].copy_(hid_a[i])
            self.saved_hidden_states_c[i][self.step].copy_(hid_c[i])

    def clear(self):
        self.step = 0

    def compute_returns(self, last_values, gamma, lam):
        advantage = 0
        for step in reversed(range(self.num_transitions_per_env)):
            if step == self.num_transitions_per_env - 1:
                next_values = last_values
            else:
                next_values = self.values[step + 1]
            next_is_not_terminal = 1.0 - self.dones[step].float()
            delta = self.rewards[step] + next_is_not_terminal * gamma * next_values - self.values[step]
            advantage = delta + next_is_not_terminal * gamma * lam * advantage
            self.returns[step] = advantage + self.values[step]

        # Compute and normalize the advantages
        self.advantages = self.returns - self.values
        self.advantages = (self.advantages - self.advantages.mean()) / (self.advantages.std() + 1e-8)

    def get_statistics(self):
        done = self.dones
        done[-1] = 1
        flat_dones = done.permute(1, 0, 2).reshape(-1, 1)
        done_indices = torch.cat((flat_dones.new_tensor([-1], dtype=torch.int64), flat_dones.nonzero(as_tuple=False)[:, 0]))
        trajectory_lengths = (done_indices[1:] - done_indices[:-1])
        return trajectory_lengths.float().mean(), self.rewards.mean()

    def mini_batch_generator(self, num_mini_batches, num_epochs=8):
        batch_size = self.num_envs * self.num_transitions_per_env
        mini_batch_size = batch_size // num_mini_batches
        indices = torch.randperm(num_mini_batches * mini_batch_size, requires_grad=False, device=self.device)

        observations = self.observations.flatten(0, 1)
        if self.privileged_observations is not None:
            critic_observations = self.privileged_observations.flatten(0, 1)
        else:
            critic_observations = observations

        actions = self.actions.flatten(0, 1)
        values = self.values.flatten(0, 1)
        returns = self.returns.flatten(0, 1)
        old_actions_log_prob = self.actions_log_prob.flatten(0, 1)
        advantages = self.advantages.flatten(0, 1)
        old_mu = self.mu.flatten(0, 1)
        old_sigma = self.sigma.flatten(0, 1)

        for epoch in range(num_epochs):
            for i in range(num_mini_batches):
                start = i * mini_batch_size
                end = (i + 1) * mini_batch_size
                batch_idx = indices[start:end]

                obs_batch = observations[batch_idx]
                critic_observations_batch = critic_observations[batch_idx]
                actions_batch = actions[batch_idx]
                target_values_batch = values[batch_idx]
                returns_batch = returns[batch_idx]
                old_actions_log_prob_batch = old_actions_log_prob[batch_idx]
                advantages_batch = advantages[batch_idx]
                old_mu_batch = old_mu[batch_idx]
                old_sigma_batch = old_sigma[batch_idx]
                yield obs_batch, critic_observations_batch, actions_batch, target_values_batch, advantages_batch, returns_batch, \
                    old_actions_log_prob_batch, old_mu_batch, old_sigma_batch, (None, None), None

    # for RNNs only
    def reccurent_mini_batch_generator(self, num_mini_batches, num_epochs=8):
        padded_obs_trajectories, trajectory_masks = split_and_pad_trajectories(self.observations, self.dones)
        if self.privileged_observations is not None:
            padded_critic_obs_trajectories, _ = split_and_pad_trajectories(self.privileged_observations, self.dones)
        else:
            padded_critic_obs_trajectories = padded_obs_trajectories

        mini_batch_size = self.num_envs // num_mini_batches
        for ep in range(num_epochs):
            first_traj = 0
            for i in range(num_mini_batches):
                start = i * mini_batch_size
                stop = (i + 1) * mini_batch_size

                dones = self.dones.squeeze(-1)
                last_was_done = torch.zeros_like(dones, dtype=torch.bool)
                last_was_done[1:] = dones[:-1]
                last_was_done[0] = True
                trajectories_batch_size = torch.sum(last_was_done[:, start:stop])
                last_traj = first_traj + trajectories_batch_size

                masks_batch = trajectory_masks[:, first_traj:last_traj]
                obs_batch = padded_obs_trajectories[:, first_traj:last_traj]
                critic_obs_batch = padded_critic_obs_trajectories[:, first_traj:last_traj]

                actions_batch = self.actions[:, start:stop]
                old_mu_batch = self.mu[:, start:stop]
                old_sigma_batch = self.sigma[:, start:stop]
                returns_batch = self.returns[:, start:stop]
                advantages_batch = self.advantages[:, start:stop]
                values_batch = self.values[:, start:stop]
                old_actions_log_prob_batch = self.actions_log_prob[:, start:stop]

                # reshape to [num_envs, time, num layers, hidden dim] (original shape: [time, num_layers, num_envs, hidden_dim])
                # then take only time steps after dones (flattens num envs and time dimensions),
                # take a batch of trajectories and finally reshape back to [num_layers, batch, hidden_dim]
                last_was_done = last_was_done.permute(1, 0)
                hid_a_batch = [saved_hidden_states.permute(2, 0, 1, 3)[last_was_done][first_traj:last_traj].transpose(1, 0).contiguous()
                               for saved_hidden_states in self.saved_hidden_states_a]
                hid_c_batch = [saved_hidden_states.permute(2, 0, 1, 3)[last_was_done][first_traj:last_traj].transpose(1, 0).contiguous()
                               for saved_hidden_states in self.saved_hidden_states_c]
                # remove the tuple for GRU
                hid_a_batch = hid_a_batch[0] if len(hid_a_batch) == 1 else hid_a_batch
                hid_c_batch = hid_c_batch[0] if len(hid_c_batch) == 1 else hid_c_batch

                yield obs_batch, critic_obs_batch, actions_batch, values_batch, advantages_batch, returns_batch, \
                    old_actions_log_prob_batch, old_mu_batch, old_sigma_batch, (hid_a_batch, hid_c_batch), masks_batch

                first_traj = last_traj
```
1-3 The Constructor
```python
def __init__(self, num_envs, num_transitions_per_env, obs_shape, privileged_obs_shape, actions_shape, device='cpu'):
    self.device = device
    self.obs_shape = obs_shape
    self.privileged_obs_shape = privileged_obs_shape
    self.actions_shape = actions_shape

    # Core
    self.observations = torch.zeros(num_transitions_per_env, num_envs, *obs_shape, device=self.device)
    if privileged_obs_shape[0] is not None:
        self.privileged_observations = torch.zeros(num_transitions_per_env, num_envs, *privileged_obs_shape, device=self.device)
    else:
        self.privileged_observations = None
    self.rewards = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device)
    self.actions = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=self.device)
    self.dones = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device).byte()

    # For PPO
    self.actions_log_prob = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device)
    self.values = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device)
    self.returns = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device)
    self.advantages = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device)
    self.mu = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=self.device)
    self.sigma = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=self.device)

    self.num_transitions_per_env = num_transitions_per_env
    self.num_envs = num_envs

    # rnn
    self.saved_hidden_states_a = None
    self.saved_hidden_states_c = None

    self.step = 0
```
- The constructor preallocates all the buffers a rollout needs:
- Core fields
  - self.obs_shape: observation dimension (the Actor's input size)
  - self.privileged_obs_shape: privileged-observation dimension (extra information the Critic may use)
  - self.actions_shape: action dimension
  - self.observations: per-step observations $s_t$
  - self.privileged_observations: privileged observations used by the Critic
  - self.rewards: rewards $r_t$ returned by the environment
  - self.actions: actions $a_t$ output by the policy
  - self.dones: whether each step ends an episode (environment reset)
- PPO fields
  - self.actions_log_prob: log-probability of the action under the old policy, $\log \pi_{\mathrm{old}}(a|s)$
  - self.values: Critic value estimates $V(s)$
  - self.returns: discounted returns $R_t$, filled in by compute_returns()
  - self.advantages: advantages $A_t$, used by the PPO policy-gradient update
  - self.mu: Gaussian policy action means $\mu$
  - self.sigma: Gaussian policy action standard deviations $\sigma$
  - self.num_transitions_per_env: number of steps sampled per environment (rollout length)
  - self.num_envs: number of parallel environments
- RNN fields
  - self.saved_hidden_states_a: Actor hidden states
  - self.saved_hidden_states_c: Critic hidden states
  - self.step: how many steps of the current rollout have been written; controls the write position
- Note that self.dones has shape $[T, N, 1]$, where $T$ is the rollout length and $N$ the number of parallel environments.
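The $[T, N, 1]$ layout can be mimicked with plain nested lists to make the indexing tangible (a stand-in for the torch tensors, with illustrative sizes $T=4$, $N=3$):

```python
# Minimal stand-in for the [T, N, 1] buffer layout used above
# (T = rollout length, N = parallel envs), nested lists instead of tensors.
T, N = 4, 3
dones = [[[0.0] for _ in range(N)] for _ in range(T)]

def shape(buf):
    """Return the (T, N, 1)-style shape of a nested-list buffer."""
    return (len(buf), len(buf[0]), len(buf[0][0]))

# writing time step t = 2 for all environments at once:
dones[2] = [[1.0], [0.0], [0.0]]  # env 0 terminated at t = 2
```

Indexing `buffer[t]` therefore always yields one row per environment, which is exactly how add_transitions() writes a whole batch of environments per step.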
1-4 add_transitions()
- Purpose: write the data produced by one environment interaction (a Transition) into the RolloutStorage.

```python
def add_transitions(self, transition: Transition):
    if self.step >= self.num_transitions_per_env:
        raise AssertionError("Rollout buffer overflow")
    self.observations[self.step].copy_(transition.observations)
    if self.privileged_observations is not None:
        self.privileged_observations[self.step].copy_(transition.critic_observations)
    self.actions[self.step].copy_(transition.actions)
    self.rewards[self.step].copy_(transition.rewards.view(-1, 1))
    self.dones[self.step].copy_(transition.dones.view(-1, 1))
    self.values[self.step].copy_(transition.values)
    self.actions_log_prob[self.step].copy_(transition.actions_log_prob.view(-1, 1))
    self.mu[self.step].copy_(transition.action_mean)
    self.sigma[self.step].copy_(transition.action_sigma)
    self._save_hidden_states(transition.hidden_states)
    self.step += 1
```

- The stored quantities include obs (shape [T, N, obs_dim]), action, reward, value, log_prob, done, and hidden_state.
- All of these are required by the PPO update.
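The write pattern is a plain sequential cursor with an overflow guard; here is the same mechanism in miniature (a hypothetical TinyStorage with a single reward buffer, not the real class):

```python
# Sequential-write sketch of add_transitions(): one slot per time step,
# with the same overflow guard and cursor reset as the original.
class TinyStorage:
    def __init__(self, num_transitions_per_env):
        self.num_transitions_per_env = num_transitions_per_env
        self.rewards = [None] * num_transitions_per_env
        self.step = 0

    def add(self, reward):
        if self.step >= self.num_transitions_per_env:
            raise AssertionError("Rollout buffer overflow")
        self.rewards[self.step] = reward
        self.step += 1

    def clear(self):
        # like RolloutStorage.clear(): just rewind the write cursor,
        # old data is overwritten by the next rollout
        self.step = 0

buf = TinyStorage(2)
buf.add(1.0)
buf.add(0.5)
```

Note that clear() does not zero the buffers; it only rewinds `step`, because the next rollout overwrites every slot anyway.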
1-5 _save_hidden_states()
- Purpose: store the RNN hidden state of the current time step into the RolloutStorage.
- This works together with the ActorCriticRecurrent module from the previous installment: when a recurrent policy is used, the RNN hidden states must be saved as well.

```python
def _save_hidden_states(self, hidden_states):
    if hidden_states is None or hidden_states == (None, None):
        return
    # make a tuple out of GRU hidden states to match the LSTM format
    hid_a = hidden_states[0] if isinstance(hidden_states[0], tuple) else (hidden_states[0],)
    hid_c = hidden_states[1] if isinstance(hidden_states[1], tuple) else (hidden_states[1],)
    # initialize if needed
    if self.saved_hidden_states_a is None:
        self.saved_hidden_states_a = [torch.zeros(self.observations.shape[0], *hid_a[i].shape, device=self.device) for i in range(len(hid_a))]
        self.saved_hidden_states_c = [torch.zeros(self.observations.shape[0], *hid_c[i].shape, device=self.device) for i in range(len(hid_c))]
    # copy the states
    for i in range(len(hid_a)):
        self.saved_hidden_states_a[i][self.step].copy_(hid_a[i])
        self.saved_hidden_states_c[i][self.step].copy_(hid_c[i])
```
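The only subtle line here is the GRU/LSTM normalization: an LSTM hidden state is a tuple `(h, c)`, a GRU state is a single tensor, so a bare state is wrapped into a 1-tuple to unify the copy loop. Isolated (with strings standing in for tensors):

```python
# The normalization trick from _save_hidden_states(): wrap a bare
# GRU state into a 1-tuple so it matches the LSTM (h, c) format.
def as_tuple(state):
    return state if isinstance(state, tuple) else (state,)

lstm_state = ("h", "c")  # LSTM: already a tuple of (hidden, cell)
gru_state = "h"          # GRU: a single hidden tensor
```

After this, `for i in range(len(hid))` works identically for both architectures.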
1-6 Core function: compute_returns()
- This function is at the heart of the PPO pipeline. It does two things:
  - compute the Returns
  - compute the Advantages using GAE

```python
def compute_returns(self, last_values, gamma, lam):
    advantage = 0
    for step in reversed(range(self.num_transitions_per_env)):
        if step == self.num_transitions_per_env - 1:
            next_values = last_values
        else:
            next_values = self.values[step + 1]
        next_is_not_terminal = 1.0 - self.dones[step].float()
        delta = self.rewards[step] + next_is_not_terminal * gamma * next_values - self.values[step]
        advantage = delta + next_is_not_terminal * gamma * lam * advantage
        self.returns[step] = advantage + self.values[step]

    # Compute and normalize the advantages
    self.advantages = self.returns - self.values
    self.advantages = (self.advantages - self.advantages.mean()) / (self.advantages.std() + 1e-8)
```

- Let us walk through it step by step.
1-6-1 Iterate over the rollout in reverse

```python
for step in reversed(range(self.num_transitions_per_env)):
```

- The loop runs backwards (reversed) because the GAE recursion is $A_t = \delta_t + \gamma\lambda A_{t+1}$, so the computation must start from the last step.
1-6-2 Get the next-step value

```python
if step == self.num_transitions_per_env - 1:
    next_values = last_values
else:
    next_values = self.values[step + 1]
```

- This provides $V(s_{t+1})$.
- The last step of the rollout has no stored next value, so the bootstrap value last_values (the Critic's estimate for the state after the final step) is used instead.
1-6-3 Check whether the episode ended

```python
next_is_not_terminal = 1.0 - self.dones[step].float()
```

- If done == 1 the episode has ended, so next_is_not_terminal becomes 0 and the bootstrap term is cut off.
1-6-4 Core: compute the TD error

```python
delta = self.rewards[step] + next_is_not_terminal * gamma * next_values - self.values[step]
```

- This corresponds to $\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)$.
- To account for episode termination, the next_is_not_terminal factor computed above is inserted: $\delta_t = r_t + (1-d_t)\,\gamma V(s_{t+1}) - V(s_t)$.
1-6-5 Core: compute the GAE advantage

```python
advantage = delta + next_is_not_terminal * gamma * lam * advantage
```

- This corresponds to $A_t = \delta_t + (1-d_t)\,\gamma\lambda A_{t+1}$, where:
  - $\gamma$ controls the weight of long-term rewards
  - $\lambda$ controls the bias-variance tradeoff
1-6-6 Core: compute the Return and the Advantage, then normalize the Advantage

```python
self.returns[step] = advantage + self.values[step]

# Compute and normalize the advantages
self.advantages = self.returns - self.values
self.advantages = (self.advantages - self.advantages.mean()) / (self.advantages.std() + 1e-8)
```

- The return is $R_t = A_t + V(s_t)$, so equivalently $A_t = R_t - V(s_t)$.
- Advantage normalization corresponds to $A_t \leftarrow \frac{A_t - \mu_A}{\sigma_A + \epsilon}$.
- Its purpose is to:
  - stabilize training
  - avoid oversized gradients
  - reduce variance
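The recursion above is easy to check by hand in pure Python. This is a scalar re-implementation of compute_returns() (lists instead of $[N, 1]$ tensors, normalization omitted):

```python
# Pure-Python version of compute_returns(), so the GAE recursion can
# be verified with small hand-computable numbers.
def compute_gae(rewards, values, dones, last_value, gamma, lam):
    T = len(rewards)
    returns = [0.0] * T
    advantage = 0.0
    for t in reversed(range(T)):
        next_value = last_value if t == T - 1 else values[t + 1]
        not_terminal = 1.0 - dones[t]
        delta = rewards[t] + not_terminal * gamma * next_value - values[t]
        advantage = delta + not_terminal * gamma * lam * advantage
        returns[t] = advantage + values[t]
    advantages = [r - v for r, v in zip(returns, values)]
    return returns, advantages

# With gamma = lam = 1 and zero values on an unterminated rollout, the
# return at t is simply the sum of all future rewards.
rets, advs = compute_gae(
    rewards=[1.0, 1.0], values=[0.0, 0.0], dones=[0.0, 0.0],
    last_value=0.0, gamma=1.0, lam=1.0)
```

With a `done` at step 0, the recursion stops accumulating across the boundary, which is exactly what the `next_is_not_terminal` factor is for.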
1-7 Statistics function: get_statistics()
- This function computes the mean episode length and the mean reward.

```python
def get_statistics(self):
    done = self.dones
    done[-1] = 1
    flat_dones = done.permute(1, 0, 2).reshape(-1, 1)
    done_indices = torch.cat((flat_dones.new_tensor([-1], dtype=torch.int64), flat_dones.nonzero(as_tuple=False)[:, 0]))
    trajectory_lengths = (done_indices[1:] - done_indices[:-1])
    return trajectory_lengths.float().mean(), self.rewards.mean()
```

- Recall that done has shape [T, N, 1].
  - done[-1] = 1 forces every trajectory to end, i.e. the end of the rollout is treated as an episode boundary.
  - flat_dones permutes the dimensions and flattens to [N×T, 1], so all done flags form one long sequence.
  - done_indices prepends a -1 to the indices of the done flags, which makes trajectory lengths easy to compute.
  - trajectory_lengths: each length is $L_i = d_i - d_{i-1}$.
  - trajectory_lengths.float().mean(): the mean length $\bar{L} = \frac{1}{K}\sum_{i=1}^{K} L_i$.
1-8 Core function: mini_batch_generator()
- Purpose: take the large batch collected in the rollout, shuffle it, cut it into mini-batches, and iterate over it for several epochs.
  - num_mini_batches: how many mini-batches one batch is split into
  - num_epochs: how many epochs the same data is reused for

```python
def mini_batch_generator(self, num_mini_batches, num_epochs=8):
    batch_size = self.num_envs * self.num_transitions_per_env
    mini_batch_size = batch_size // num_mini_batches
    indices = torch.randperm(num_mini_batches * mini_batch_size, requires_grad=False, device=self.device)

    observations = self.observations.flatten(0, 1)
    if self.privileged_observations is not None:
        critic_observations = self.privileged_observations.flatten(0, 1)
    else:
        critic_observations = observations

    actions = self.actions.flatten(0, 1)
    values = self.values.flatten(0, 1)
    returns = self.returns.flatten(0, 1)
    old_actions_log_prob = self.actions_log_prob.flatten(0, 1)
    advantages = self.advantages.flatten(0, 1)
    old_mu = self.mu.flatten(0, 1)
    old_sigma = self.sigma.flatten(0, 1)

    for epoch in range(num_epochs):
        for i in range(num_mini_batches):
            start = i * mini_batch_size
            end = (i + 1) * mini_batch_size
            batch_idx = indices[start:end]

            obs_batch = observations[batch_idx]
            critic_observations_batch = critic_observations[batch_idx]
            actions_batch = actions[batch_idx]
            target_values_batch = values[batch_idx]
            returns_batch = returns[batch_idx]
            old_actions_log_prob_batch = old_actions_log_prob[batch_idx]
            advantages_batch = advantages[batch_idx]
            old_mu_batch = old_mu[batch_idx]
            old_sigma_batch = old_sigma[batch_idx]
            yield obs_batch, critic_observations_batch, actions_batch, target_values_batch, advantages_batch, returns_batch, \
                old_actions_log_prob_batch, old_mu_batch, old_sigma_batch, (None, None), None
```
- Each yield returns:
  - obs_batch, critic_obs_batch, actions_batch
  - target_values_batch, advantages_batch, returns_batch
  - old_log_prob, old_mu, old_sigma
- Together these are the complete set of inputs to the PPO loss.
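The index bookkeeping is worth isolating: the [T, N] data is flattened into one batch, one random permutation is drawn, and the same permutation is sliced for every epoch. A stdlib-only miniature (random.shuffle standing in for torch.randperm):

```python
# Index bookkeeping of mini_batch_generator() in miniature: one shuffle
# up front, then the same permutation is sliced for every epoch.
import random

def mini_batches(batch_size, num_mini_batches, num_epochs, seed=0):
    mini_batch_size = batch_size // num_mini_batches
    rng = random.Random(seed)
    indices = list(range(num_mini_batches * mini_batch_size))
    rng.shuffle(indices)  # shuffled once, before the epoch loop
    for _ in range(num_epochs):
        for i in range(num_mini_batches):
            yield indices[i * mini_batch_size:(i + 1) * mini_batch_size]

batches = list(mini_batches(batch_size=8, num_mini_batches=2, num_epochs=2))
```

Note the truncation in `num_mini_batches * mini_batch_size`: if the batch size does not divide evenly, the leftover samples are simply dropped, exactly as in the original.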
1-9 Summary
- In ppo.py you can see how this storage is used:

```python
def compute_returns(self, last_critic_obs):
    last_values = self.actor_critic.evaluate(last_critic_obs).detach()
    self.storage.compute_returns(last_values, self.gamma, self.lam)
```

and

```python
generator = self.storage.mini_batch_generator(self.num_mini_batches, self.num_learning_epochs)
```

- The overall flow is roughly:

environment interaction
        ↓
  RolloutStorage
        ↓
 compute_returns()
        ↓
mini_batch_generator()
        ↓
    PPO update
        ↓
  loss computation
2 VecEnv
2-1 Overview
- VecEnv (Vectorized Environment) is the interface for running many environments in parallel.
- Features:
  - step many environments at once, for high GPU utilization
  - buffer obs, privileged_obs, rewards, dones, and other per-step information
- RolloutStorage records the data that VecEnv outputs, forming a complete rollout.
2-2 Full Source Listing
```python
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# (full BSD-3-Clause license text omitted; identical to the header in the RolloutStorage listing above)
# Copyright (c) 2021 ETH Zurich, Nikita Rudin

from abc import ABC, abstractmethod
import torch
from typing import Tuple, Union

# minimal interface of the environment
class VecEnv(ABC):
    num_envs: int
    num_obs: int
    num_privileged_obs: int
    num_actions: int
    max_episode_length: int
    privileged_obs_buf: torch.Tensor
    obs_buf: torch.Tensor
    rew_buf: torch.Tensor
    reset_buf: torch.Tensor
    episode_length_buf: torch.Tensor  # current episode duration
    extras: dict
    device: torch.device

    @abstractmethod
    def step(self, actions: torch.Tensor) -> Tuple[torch.Tensor, Union[torch.Tensor, None], torch.Tensor, torch.Tensor, dict]:
        pass

    @abstractmethod
    def reset(self, env_ids: Union[list, torch.Tensor]):
        pass

    @abstractmethod
    def get_observations(self) -> torch.Tensor:
        pass

    @abstractmethod
    def get_privileged_observations(self) -> Union[torch.Tensor, None]:
        pass
```
- The class is short, so let us go straight through the attributes:
  - num_envs: number of parallel environments
  - num_obs: observation dimension
  - num_privileged_obs: privileged-observation dimension
  - num_actions: action dimension
  - max_episode_length: maximum episode length, i.e. how many steps an episode may run
- The following attributes are GPU tensor buffers:
  - obs_buf: observation buffer
  - privileged_obs_buf: privileged-observation buffer
  - rew_buf: reward buffer
  - reset_buf: reset-flag buffer
  - episode_length_buf: current episode lengths
  - extras: dictionary for extra information
  - device: the device everything runs on
- Finally, the abstract methods:
  - step(): advance every environment by one step, i.e. $s_{t+1}, r_t = \mathrm{Env}(s_t, a_t)$
  - reset(): reset the given environments
  - get_observations() / get_privileged_observations(): return the current (privileged) observation buffers
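To see what a concrete subclass is expected to return, here is a torch-free dummy that mirrors the VecEnv contract with nested lists standing in for [num_envs, dim] tensors. This is a didactic sketch, not one of the repository's environments:

```python
# Torch-free dummy mirroring the VecEnv contract: step() returns
# (obs, privileged_obs, rewards, dones, infos), batched over num_envs.
class DummyVecEnv:
    def __init__(self, num_envs, num_obs, num_actions):
        self.num_envs = num_envs
        self.num_obs = num_obs
        self.num_actions = num_actions
        self.obs_buf = [[0.0] * num_obs for _ in range(num_envs)]

    def step(self, actions):
        # one synchronized step for all environments
        assert len(actions) == self.num_envs
        rewards = [0.0] * self.num_envs
        dones = [0] * self.num_envs
        return self.obs_buf, None, rewards, dones, {}

    def get_observations(self):
        return self.obs_buf

    def get_privileged_observations(self):
        return None  # this dummy has no privileged state

env = DummyVecEnv(num_envs=4, num_obs=3, num_actions=2)
obs, priv_obs, rew, dones, infos = env.step([[0.0, 0.0]] * 4)
```

When `get_privileged_observations()` returns None, the runner simply reuses `obs` as the Critic input, i.e. the setup degenerates to a symmetric Actor-Critic.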
3 OnPolicyRunner
3-1 Overview
- OnPolicyRunner is a general-purpose training driver: it trains PPO (or another Actor-Critic algorithm) in a vectorized environment (VecEnv).
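Before reading the full listing, it helps to see the shape of one training iteration in isolation. This skeleton follows the structure of learn() below, but `env` and `alg` are minimal stand-ins (simple stub classes defined here), not the real VecEnv/PPO objects:

```python
# Skeleton of one OnPolicyRunner iteration: collect num_steps_per_env
# transitions, bootstrap returns from the final observation, then update.
def train_iteration(env, alg, storage, num_steps_per_env):
    obs = env.get_observations()
    for _ in range(num_steps_per_env):   # rollout (collection) phase
        action = alg.act(obs)
        obs, reward, done = env.step(action)
        storage.append((obs, action, reward, done))
    alg.compute_returns(obs)             # bootstrap from the last obs
    return alg.update()                  # learning phase

class _StubEnv:
    """Trivial scalar 'environment' used only to exercise the loop."""
    def get_observations(self):
        return 0.0
    def step(self, action):
        return action + 1.0, 1.0, False  # next obs, reward, done

class _StubAlg:
    """Records calls in place of a real PPO object."""
    def act(self, obs):
        return obs
    def compute_returns(self, last_obs):
        self.last_obs = last_obs
    def update(self):
        return "value_loss", "surrogate_loss"

storage = []
losses = train_iteration(_StubEnv(), _StubAlg(), storage, num_steps_per_env=3)
```

The real learn() adds device transfers, inference_mode around collection, logging, and checkpointing around this same collect-then-update core.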
3-2 Full Source Listing
python
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# (full BSD-3-Clause license text omitted; identical to the header in the RolloutStorage listing above)
# Copyright (c) 2021 ETH Zurich, Nikita Rudin

import time
import os
from collections import deque
import statistics

from torch.utils.tensorboard import SummaryWriter
import torch

from rsl_rl.algorithms import PPO
from rsl_rl.modules import ActorCritic, ActorCriticRecurrent
from rsl_rl.env import VecEnv

class OnPolicyRunner:

    def __init__(self,
                 env: VecEnv,
                 train_cfg,
                 log_dir=None,
                 device='cpu'):

        self.cfg = train_cfg["runner"]
        self.alg_cfg = train_cfg["algorithm"]
        self.policy_cfg = train_cfg["policy"]
        self.device = device
        self.env = env
        if self.env.num_privileged_obs is not None:
            num_critic_obs = self.env.num_privileged_obs
        else:
            num_critic_obs = self.env.num_obs
        actor_critic_class = eval(self.cfg["policy_class_name"])  # ActorCritic
        actor_critic: ActorCritic = actor_critic_class(self.env.num_obs,
                                                       num_critic_obs,
                                                       self.env.num_actions,
                                                       **self.policy_cfg).to(self.device)
        alg_class = eval(self.cfg["algorithm_class_name"])  # PPO
        self.alg: PPO = alg_class(actor_critic, device=self.device, **self.alg_cfg)
        self.num_steps_per_env = self.cfg["num_steps_per_env"]
        self.save_interval = self.cfg["save_interval"]

        # init storage and model
        self.alg.init_storage(self.env.num_envs, self.num_steps_per_env, [self.env.num_obs], [self.env.num_privileged_obs], [self.env.num_actions])

        # Log
        self.log_dir = log_dir
        self.writer = None
        self.tot_timesteps = 0
        self.tot_time = 0
        self.current_learning_iteration = 0

        _, _ = self.env.reset()

    def learn(self, num_learning_iterations, init_at_random_ep_len=False):
        # initialize writer
        if self.log_dir is not None and self.writer is None:
            self.writer = SummaryWriter(log_dir=self.log_dir, flush_secs=10)
        if init_at_random_ep_len:
            self.env.episode_length_buf = torch.randint_like(self.env.episode_length_buf, high=int(self.env.max_episode_length))
        obs = self.env.get_observations()
        privileged_obs = self.env.get_privileged_observations()
        critic_obs = privileged_obs if privileged_obs is not None else obs
        obs, critic_obs = obs.to(self.device), critic_obs.to(self.device)
        self.alg.actor_critic.train()  # switch to train mode (for dropout for example)

        ep_infos = []
        rewbuffer = deque(maxlen=100)
        lenbuffer = deque(maxlen=100)
        cur_reward_sum = torch.zeros(self.env.num_envs, dtype=torch.float, device=self.device)
        cur_episode_length = torch.zeros(self.env.num_envs, dtype=torch.float, device=self.device)

        tot_iter = self.current_learning_iteration + num_learning_iterations
        for it in range(self.current_learning_iteration, tot_iter):
            start = time.time()
            # Rollout
            with torch.inference_mode():
                for i in range(self.num_steps_per_env):
                    actions = self.alg.act(obs, critic_obs)
                    obs, privileged_obs, rewards, dones, infos = self.env.step(actions)
                    critic_obs = privileged_obs if privileged_obs is not None else obs
                    obs, critic_obs, rewards, dones = obs.to(self.device), critic_obs.to(self.device), rewards.to(self.device), dones.to(self.device)
                    self.alg.process_env_step(rewards, dones, infos)

                    if self.log_dir is not None:
                        # Book keeping
                        if 'episode' in infos:
                            ep_infos.append(infos['episode'])
                        cur_reward_sum += rewards
                        cur_episode_length += 1
                        new_ids = (dones > 0).nonzero(as_tuple=False)
                        rewbuffer.extend(cur_reward_sum[new_ids][:, 0].cpu().numpy().tolist())
                        lenbuffer.extend(cur_episode_length[new_ids][:, 0].cpu().numpy().tolist())
                        cur_reward_sum[new_ids] = 0
                        cur_episode_length[new_ids] = 0

                stop = time.time()
                collection_time = stop - start

                # Learning step
                start = stop
                self.alg.compute_returns(critic_obs)

            mean_value_loss, mean_surrogate_loss = self.alg.update()
            stop = time.time()
            learn_time = stop - start
            if self.log_dir is not None:
                self.log(locals())
            if it % self.save_interval == 0:
                self.save(os.path.join(self.log_dir, 'model_{}.pt'.format(it)))
            ep_infos.clear()

        self.current_learning_iteration += num_learning_iterations
        self.save(os.path.join(self.log_dir, 'model_{}.pt'.format(self.current_learning_iteration)))

    def log(self, locs, width=80, pad=35):
        self.tot_timesteps += self.num_steps_per_env * self.env.num_envs
        self.tot_time += locs['collection_time'] + locs['learn_time']
        iteration_time = locs['collection_time'] + locs['learn_time']

        ep_string = f''
        if locs['ep_infos']:
            for key in locs['ep_infos'][0]:
                infotensor = torch.tensor([], device=self.device)
                for ep_info in locs['ep_infos']:
                    # handle scalar and zero dimensional tensor infos
                    if not isinstance(ep_info[key], torch.Tensor):
                        ep_info[key] = torch.Tensor([ep_info[key]])
                    if len(ep_info[key].shape) == 0:
                        ep_info[key] = ep_info[key].unsqueeze(0)
                    infotensor = torch.cat((infotensor, ep_info[key].to(self.device)))
                value = torch.mean(infotensor)
                self.writer.add_scalar('Episode/' + key, value, locs['it'])
                ep_string += f"""{f'Mean episode {key}:':>{pad}} {value:.4f}\n"""
        mean_std = self.alg.actor_critic.std.mean()
        fps = int(self.num_steps_per_env * self.env.num_envs / (locs['collection_time'] + locs['learn_time']))

        self.writer.add_scalar('Loss/value_function', locs['mean_value_loss'], locs['it'])
        self.writer.add_scalar('Loss/surrogate', locs['mean_surrogate_loss'], locs['it'])
        self.writer.add_scalar('Loss/learning_rate', self.alg.learning_rate, locs['it'])
        self.writer.add_scalar('Policy/mean_noise_std', mean_std.item(), locs['it'])
        self.writer.add_scalar('Perf/total_fps', fps, locs['it'])
        self.writer.add_scalar('Perf/collection time', locs['collection_time'], locs['it'])
        self.writer.add_scalar('Perf/learning_time', locs['learn_time'], locs['it'])
        if len(locs['rewbuffer']) > 0:
            self.writer.add_scalar('Train/mean_reward', statistics.mean(locs['rewbuffer']), locs['it'])
            self.writer.add_scalar('Train/mean_episode_length', statistics.mean(locs['lenbuffer']), locs['it'])
            self.writer.add_scalar('Train/mean_reward/time', statistics.mean(locs['rewbuffer']), self.tot_time)
            self.writer.add_scalar('Train/mean_episode_length/time', statistics.mean(locs['lenbuffer']), self.tot_time)

        str = f" \033[1m Learning iteration {locs['it']}/{self.current_learning_iteration + locs['num_learning_iterations']} \033[0m "

        if len(locs['rewbuffer']) > 0:
            log_string = (f"""{'#' * width}\n"""
                          f"""{str.center(width, ' ')}\n\n"""
                          f"""{'Computation:':>{pad}} {fps:.0f} steps/s (collection: {locs[
                            'collection_time']:.3f}s, learning {locs['learn_time']:.3f}s)\n"""
                          f"""{'Value function loss:':>{pad}} {locs['mean_value_loss']:.4f}\n"""
                          f"""{'Surrogate loss:':>{pad}} {locs['mean_surrogate_loss']:.4f}\n"""
f"""{'Mean action noise std:':>{pad}} {mean_std.item():.2f}\n"""
f"""{'Mean reward:':>{pad}} {statistics.mean(locs['rewbuffer']):.2f}\n"""
f"""{'Mean episode length:':>{pad}} {statistics.mean(locs['lenbuffer']):.2f}\n""")
# f"""{'Mean reward/step:':>{pad}} {locs['mean_reward']:.2f}\n"""
# f"""{'Mean episode length/episode:':>{pad}} {locs['mean_trajectory_length']:.2f}\n""")
else:
log_string = (f"""{'#' * width}\n"""
f"""{str.center(width, ' ')}\n\n"""
f"""{'Computation:':>{pad}} {fps:.0f} steps/s (collection: {locs[
'collection_time']:.3f}s, learning {locs['learn_time']:.3f}s)\n"""
f"""{'Value function loss:':>{pad}} {locs['mean_value_loss']:.4f}\n"""
f"""{'Surrogate loss:':>{pad}} {locs['mean_surrogate_loss']:.4f}\n"""
f"""{'Mean action noise std:':>{pad}} {mean_std.item():.2f}\n""")
# f"""{'Mean reward/step:':>{pad}} {locs['mean_reward']:.2f}\n"""
# f"""{'Mean episode length/episode:':>{pad}} {locs['mean_trajectory_length']:.2f}\n""")
log_string += ep_string
log_string += (f"""{'-' * width}\n"""
f"""{'Total timesteps:':>{pad}} {self.tot_timesteps}\n"""
f"""{'Iteration time:':>{pad}} {iteration_time:.2f}s\n"""
f"""{'Total time:':>{pad}} {self.tot_time:.2f}s\n"""
f"""{'ETA:':>{pad}} {self.tot_time / (locs['it'] + 1) * (
locs['num_learning_iterations'] - locs['it']):.1f}s\n""")
print(log_string)
def save(self, path, infos=None):
torch.save({
'model_state_dict': self.alg.actor_critic.state_dict(),
'optimizer_state_dict': self.alg.optimizer.state_dict(),
'iter': self.current_learning_iteration,
'infos': infos,
}, path)
def load(self, path, load_optimizer=True):
loaded_dict = torch.load(path)
self.alg.actor_critic.load_state_dict(loaded_dict['model_state_dict'])
if load_optimizer:
self.alg.optimizer.load_state_dict(loaded_dict['optimizer_state_dict'])
self.current_learning_iteration = loaded_dict['iter']
return loaded_dict['infos']
def get_inference_policy(self, device=None):
self.alg.actor_critic.eval() # switch to evaluation mode (dropout for example)
if device is not None:
self.alg.actor_critic.to(device)
return self.alg.actor_critic.act_inference
3-3 The initialization function `__init__()`
```python
def __init__(self,
             env: VecEnv,
             train_cfg,
             log_dir=None,
             device='cpu'):

    self.cfg = train_cfg["runner"]
    self.alg_cfg = train_cfg["algorithm"]
    self.policy_cfg = train_cfg["policy"]
    self.device = device
    self.env = env
    if self.env.num_privileged_obs is not None:
        num_critic_obs = self.env.num_privileged_obs
    else:
        num_critic_obs = self.env.num_obs
    actor_critic_class = eval(self.cfg["policy_class_name"])  # ActorCritic
    actor_critic: ActorCritic = actor_critic_class(self.env.num_obs,
                                                   num_critic_obs,
                                                   self.env.num_actions,
                                                   **self.policy_cfg).to(self.device)
    alg_class = eval(self.cfg["algorithm_class_name"])  # PPO
    self.alg: PPO = alg_class(actor_critic, device=self.device, **self.alg_cfg)
    self.num_steps_per_env = self.cfg["num_steps_per_env"]
    self.save_interval = self.cfg["save_interval"]

    # init storage and model
    self.alg.init_storage(self.env.num_envs, self.num_steps_per_env, [self.env.num_obs], [self.env.num_privileged_obs], [self.env.num_actions])

    # Log
    self.log_dir = log_dir
    self.writer = None
    self.tot_timesteps = 0
    self.tot_time = 0
    self.current_learning_iteration = 0

    _, _ = self.env.reset()
```
- Let's look at the constructor arguments:
  - `env`: the training environment (a `VecEnv` running many environments in parallel)
  - `train_cfg`: the configuration dictionary, holding the runner, algorithm, and policy configs
  - `log_dir`: the log directory, used for TensorBoard visualization
  - `device`: the PyTorch device (CPU/GPU)
- The constructor performs the basic setup:
  - Reads the runner, algorithm, and policy configs.
  - Determines the Critic's input dimension:
    - the Actor only sees the regular observations (`num_obs`);
    - the Critic uses the privileged observations (`num_privileged_obs`) when available, otherwise falls back to the regular observations.
  - Instantiates the Actor-Critic network.
  - Instantiates the PPO algorithm.
  - Initializes the trajectory (rollout) storage.
  - Initializes logging and the counters.
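To make the constructor concrete, here is an illustrative `train_cfg` skeleton containing the keys the code above actually reads, plus a few typical entries forwarded to `PPO` and `ActorCritic` as `**kwargs`. The values are placeholders, not the defaults shipped with any particular robot config; note that `eval(...)` resolves the two `*_class_name` strings to the actual classes:

```python
# Illustrative train_cfg skeleton -- only the keys the runner itself reads,
# plus a few typical entries forwarded to PPO / ActorCritic as **kwargs.
train_cfg = {
    "runner": {
        "policy_class_name": "ActorCritic",   # resolved via eval(...)
        "algorithm_class_name": "PPO",        # resolved via eval(...)
        "num_steps_per_env": 24,              # rollout horizon per iteration
        "save_interval": 50,                  # checkpoint every N iterations
    },
    "algorithm": {                            # forwarded to PPO(**alg_cfg)
        "learning_rate": 1.0e-3,
        "gamma": 0.99,
        "lam": 0.95,
    },
    "policy": {                               # forwarded to ActorCritic(**policy_cfg)
        "actor_hidden_dims": [512, 256, 128],
        "critic_hidden_dims": [512, 256, 128],
    },
}
```

Resolving class names with `eval` keeps the config flat, but it also means the class must be importable in the runner's namespace; a registry dict would be the safer design.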
3-4 The core training function `learn()`
```python
def learn(self, num_learning_iterations, init_at_random_ep_len=False):
    # initialize writer
    if self.log_dir is not None and self.writer is None:
        self.writer = SummaryWriter(log_dir=self.log_dir, flush_secs=10)
    if init_at_random_ep_len:
        self.env.episode_length_buf = torch.randint_like(self.env.episode_length_buf, high=int(self.env.max_episode_length))
    obs = self.env.get_observations()
    privileged_obs = self.env.get_privileged_observations()
    critic_obs = privileged_obs if privileged_obs is not None else obs
    obs, critic_obs = obs.to(self.device), critic_obs.to(self.device)
    self.alg.actor_critic.train()  # switch to train mode (for dropout for example)

    ep_infos = []
    rewbuffer = deque(maxlen=100)
    lenbuffer = deque(maxlen=100)
    cur_reward_sum = torch.zeros(self.env.num_envs, dtype=torch.float, device=self.device)
    cur_episode_length = torch.zeros(self.env.num_envs, dtype=torch.float, device=self.device)

    tot_iter = self.current_learning_iteration + num_learning_iterations
    for it in range(self.current_learning_iteration, tot_iter):
        start = time.time()
        # Rollout
        with torch.inference_mode():
            for i in range(self.num_steps_per_env):
                actions = self.alg.act(obs, critic_obs)
                obs, privileged_obs, rewards, dones, infos = self.env.step(actions)
                critic_obs = privileged_obs if privileged_obs is not None else obs
                obs, critic_obs, rewards, dones = obs.to(self.device), critic_obs.to(self.device), rewards.to(self.device), dones.to(self.device)
                self.alg.process_env_step(rewards, dones, infos)

                if self.log_dir is not None:
                    # Book keeping
                    if 'episode' in infos:
                        ep_infos.append(infos['episode'])
                    cur_reward_sum += rewards
                    cur_episode_length += 1
                    new_ids = (dones > 0).nonzero(as_tuple=False)
                    rewbuffer.extend(cur_reward_sum[new_ids][:, 0].cpu().numpy().tolist())
                    lenbuffer.extend(cur_episode_length[new_ids][:, 0].cpu().numpy().tolist())
                    cur_reward_sum[new_ids] = 0
                    cur_episode_length[new_ids] = 0

            stop = time.time()
            collection_time = stop - start

            # Learning step
            start = stop
            self.alg.compute_returns(critic_obs)

        mean_value_loss, mean_surrogate_loss = self.alg.update()
        stop = time.time()
        learn_time = stop - start
        if self.log_dir is not None:
            self.log(locals())
        if it % self.save_interval == 0:
            self.save(os.path.join(self.log_dir, 'model_{}.pt'.format(it)))
        ep_infos.clear()

    self.current_learning_iteration += num_learning_iterations
    self.save(os.path.join(self.log_dir, 'model_{}.pt'.format(self.current_learning_iteration)))
```
3-4-1 Function inputs
```python
def learn(self, num_learning_iterations, init_at_random_ep_len=False):
```
- `num_learning_iterations`: the total number of training iterations
- `init_at_random_ep_len`: whether to randomize the initial episode lengths (adds diversity to the training data)
3-4-2 Fetching the initial observations
```python
obs = self.env.get_observations()
privileged_obs = self.env.get_privileged_observations()
critic_obs = privileged_obs if privileged_obs is not None else obs
obs, critic_obs = obs.to(self.device), critic_obs.to(self.device)
```
- The Actor uses `obs`.
- The Critic uses `critic_obs`, preferring the privileged observations when they exist.
- This is the classic Asymmetric Actor-Critic setup (introduced in section 0-4).
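The fallback on the third line can be isolated into a tiny helper to make its behavior explicit (illustrative only, not part of the rsl_rl API):

```python
def pick_critic_obs(obs, privileged_obs):
    # mirrors: critic_obs = privileged_obs if privileged_obs is not None else obs
    return privileged_obs if privileged_obs is not None else obs

regular = [0.1, 0.2]           # what the Actor always sees
privileged = [0.1, 0.2, 0.9]   # extra state only the Critic may see

print(pick_critic_obs(regular, privileged))  # asymmetric: the Critic gets the richer input
print(pick_critic_obs(regular, None))        # symmetric case: both networks share `regular`
```

This is why the same runner code works for both symmetric and asymmetric environments: an environment that provides no privileged observations simply returns `None`, and both networks end up consuming the same tensor.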
3-4-3 Rollout (sampling actions and interacting with the environment)
```python
for i in range(self.num_steps_per_env):
    actions = self.alg.act(obs, critic_obs)
    obs, privileged_obs, rewards, dones, infos = self.env.step(actions)
    critic_obs = privileged_obs if privileged_obs is not None else obs
    obs, critic_obs, rewards, dones = obs.to(self.device), critic_obs.to(self.device), rewards.to(self.device), dones.to(self.device)
    self.alg.process_env_step(rewards, dones, infos)
```
- The Actor generates actions from the current `obs`.
- The environment returns the new `obs`, `rewards`, and `dones`.
- `critic_obs` is refreshed from the new privileged observations.
- `process_env_step` stores the step data in the rollout storage for the later PPO update.
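The loop above can be sketched end-to-end with a toy stand-in for `VecEnv` and a plain list playing the role of the rollout storage. All names here are illustrative; in rsl_rl the environment resets itself internally when an episode ends, which is why the loop never calls `reset()`:

```python
class DummyVecEnv:
    """Toy stand-in for VecEnv: 2 envs, scalar observations, no privileged obs."""
    def __init__(self):
        self.num_envs = 2
        self.t = 0
    def step(self, actions):
        self.t += 1
        obs = [float(self.t)] * self.num_envs
        rewards = [1.0] * self.num_envs
        dones = [self.t % 3 == 0] * self.num_envs   # episode ends every 3 steps
        return obs, None, rewards, dones, {}

storage = []                 # plays the role of RolloutStorage
env = DummyVecEnv()
obs = [0.0] * env.num_envs
for _ in range(6):                                    # num_steps_per_env
    actions = [o * 0.1 for o in obs]                  # stand-in for self.alg.act(...)
    next_obs, priv, rewards, dones, infos = env.step(actions)
    critic_obs = priv if priv is not None else next_obs  # asymmetric fallback
    storage.append((obs, actions, rewards, dones))       # process_env_step
    obs = next_obs                                       # carry obs into the next step
print(len(storage))  # 6 transitions collected
```

The key invariant is that each stored transition pairs the observation the Actor acted on with the reward and done flag that resulted, which is exactly what the PPO update consumes later.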
3-4-4 The PPO learning step
```python
self.alg.compute_returns(critic_obs)
mean_value_loss, mean_surrogate_loss = self.alg.update()
```
- `compute_returns`: computes the advantages (GAE) and discounted returns.
- `update`: runs the PPO gradient update on the Actor-Critic network.
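As a refresher, the GAE recursion that `compute_returns` ultimately performs can be written over plain Python lists. This is a from-scratch sketch, not the rsl_rl implementation, which runs the same backward recursion on batched tensors inside `RolloutStorage`:

```python
def compute_gae(rewards, values, dones, last_value, gamma=0.99, lam=0.95):
    """Generalized Advantage Estimation over one trajectory (plain lists)."""
    advantages = [0.0] * len(rewards)
    next_value = last_value          # bootstrap value V(s_T) from the Critic
    next_adv = 0.0
    for t in reversed(range(len(rewards))):
        mask = 1.0 - dones[t]        # cut the bootstrap across episode ends
        delta = rewards[t] + gamma * next_value * mask - values[t]  # TD error
        next_adv = delta + gamma * lam * mask * next_adv
        advantages[t] = next_adv
        next_value = values[t]
    returns = [a + v for a, v in zip(advantages, values)]  # targets for the Critic
    return advantages, returns

adv, ret = compute_gae([1.0, 1.0], [0.0, 0.0], [0.0, 0.0], last_value=0.0)
print([round(a, 4) for a in adv])  # [1.9405, 1.0]: the earlier step also credits the later reward
```

Because the Critic producing `values` sees the privileged observations, the TD errors here are less noisy than they would be with the Actor's partial view, which is the whole point of the asymmetric setup.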
3-5 `save()`, `load()`, and the inference helper
```python
def save(self, path, infos=None):
    torch.save({
        'model_state_dict': self.alg.actor_critic.state_dict(),
        'optimizer_state_dict': self.alg.optimizer.state_dict(),
        'iter': self.current_learning_iteration,
        'infos': infos,
        }, path)

def load(self, path, load_optimizer=True):
    loaded_dict = torch.load(path)
    self.alg.actor_critic.load_state_dict(loaded_dict['model_state_dict'])
    if load_optimizer:
        self.alg.optimizer.load_state_dict(loaded_dict['optimizer_state_dict'])
    self.current_learning_iteration = loaded_dict['iter']
    return loaded_dict['infos']

def get_inference_policy(self, device=None):
    self.alg.actor_critic.eval()  # switch to evaluation mode (e.g. for dropout)
    if device is not None:
        self.alg.actor_critic.to(device)
    return self.alg.actor_critic.act_inference
```
- As the names suggest, `save()` and `load()` write and restore the checkpoint (network weights, optimizer state, and iteration counter), while `get_inference_policy()` switches the network to eval mode and returns the deterministic inference function for deployment.
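The checkpoint is just a flat dictionary, so the save/load round-trip is easy to sketch. Below, `pickle` stands in for `torch.save`/`torch.load` so the example runs without PyTorch; the helper names are illustrative, but the dict keys mirror the checkpoint layout above:

```python
import os
import pickle
import tempfile

def save_ckpt(path, model_state, opt_state, it, infos=None):
    # same dict layout as OnPolicyRunner.save(), pickle in place of torch.save
    with open(path, "wb") as f:
        pickle.dump({
            "model_state_dict": model_state,
            "optimizer_state_dict": opt_state,
            "iter": it,
            "infos": infos,
        }, f)

def load_ckpt(path):
    with open(path, "rb") as f:
        return pickle.load(f)

path = os.path.join(tempfile.mkdtemp(), "model_100.pt")
save_ckpt(path, model_state={"w": [0.1]}, opt_state={"lr": 1e-3}, it=100)
ckpt = load_ckpt(path)
print(ckpt["iter"])  # 100
```

Saving the optimizer state alongside the weights is what makes resumed training continue smoothly: Adam's moment estimates are restored instead of restarting cold, and `iter` lets `learn()` pick up the iteration counter where it left off.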
3-6 Overall pipeline summary
```
VecEnv (multiple parallel environments)
        │
        ▼
collect samples (obs / privileged_obs / action / reward / done)
        │
        ▼
RolloutStorage.add_transitions()
        │
        ▼
compute_returns()            # GAE advantages
        │
        ▼
mini_batch_generator()
        │
        ▼
PPO update (Actor-Critic)
        │
        ▼
updated policy (Actor) and Critic networks
```
Summary
- In this installment we covered the idea behind Asymmetric Actor-Critic, the data storage and GAE advantage computation in `RolloutStorage`, and the parallel environment management in `VecEnv`, tying together data collection and mini-batch generation in the full PPO training pipeline.
- With this, the first three installments have covered all of the core `rsl_rl` source code. Next up is a source walkthrough of `unitree_rl_gym`, which builds on `rsl_rl` and is where the actual training setup and reward functions are written.
- If you spot any mistakes, please point them out. Thanks for reading!