[Unitree Robot Reinforcement Learning] (3): Python Implementation and Analysis of OnPolicyRunner, VecEnv, and RolloutStorage

Preface


0 Prerequisites: Observations and Privileged Observations

0-1 Introduction
  • In reinforcement learning, the observation is the input to the Actor network; it determines which parts of the environment the policy can see.
  • The privileged observation is typically fed to the Critic network. It has access to richer and more precise state information, allowing a better estimate of the value function and thereby improving training stability and convergence speed.
  • Taking this project as an example:

0-2 Observation
  • For a legged-robot environment such as Go2 or H1, the observation usually contains the following:

| Type | Description | Example dim |
| --- | --- | --- |
| Joint positions | Current angle of each joint | 12 |
| Joint velocities | Angular velocity of each joint | 12 |
| Base linear velocity | Base velocity along x, y, z | 3 |
| Base angular velocity | Base angular velocity about x, y, z | 3 |
| IMU acceleration | Accelerometer reading | 3 |
| IMU angular velocity | Gyroscope reading | 3 |
| Foot contacts | Whether each foot touches the ground (contact sensors) | 4 |
| Command / target | Target position or motion command | 3 |
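As a concrete sketch, the actor observation is usually just the concatenation of these blocks. The dimensions below mirror the table but are illustrative, not the exact configuration of any particular robot:

```python
import torch

# Hypothetical per-modality tensors for a single environment; the
# dimensions follow the table above.
num_envs = 1
joint_pos = torch.zeros(num_envs, 12)    # joint angles
joint_vel = torch.zeros(num_envs, 12)    # joint angular velocities
base_lin_vel = torch.zeros(num_envs, 3)  # base linear velocity (x, y, z)
base_ang_vel = torch.zeros(num_envs, 3)  # base angular velocity
imu_acc = torch.zeros(num_envs, 3)       # accelerometer reading
imu_gyro = torch.zeros(num_envs, 3)      # gyroscope reading
foot_contact = torch.zeros(num_envs, 4)  # per-foot contact flags
command = torch.zeros(num_envs, 3)       # target / velocity command

# The actor observation is the concatenation of all modalities.
obs = torch.cat([joint_pos, joint_vel, base_lin_vel, base_ang_vel,
                 imu_acc, imu_gyro, foot_contact, command], dim=-1)
print(obs.shape)  # torch.Size([1, 43])
```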

0-3 Privileged Observation
  • The privileged observation is exclusive to the Critic; it can include "global state" that the Actor has no way to access. Common contents:

| Type | Description | Example dim |
| --- | --- | --- |
| Full robot state | Base pose (position + orientation) | 6 |
| Joint torques | Actual torque at each joint | 12 |
| Terrain height map | Terrain under or around the robot | 12~20 |
| Global physics state | Global position/velocity of every joint and link | N/A |
| Environment dynamics | e.g. obstacle positions and velocities | optional |

0-4 Relationship
  • Putting the two together:

    Environment ──> observation ──> Actor ──> action ──> Environment
        └──> privileged_observation ──> Critic ──> value

  • The Actor sees only limited information, guaranteeing the policy can run in the real environment

  • The Critic sees richer information, helping train a more accurate value function


0-5 Asymmetric Actor-Critic
  • "Asymmetric" means that the Actor and the Critic receive different observations
    • Actor: only sees the regular observation obs (e.g. the robot's sensor data)
    • Critic: also sees the privileged observation privileged_obs (e.g. full environment state, terrain information)
  • Purpose
    • The policy learned by the Actor can be deployed relying on limited information only
    • The Critic exploits the extra information to compute a more accurate value function, which accelerates training
  • In essence it is still Actor-Critic; the Critic simply receives more input than the Actor, which makes it well suited to sim-to-real transfer
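A minimal sketch of the asymmetry, with plain nn.Sequential stand-ins rather than rsl_rl's actual ActorCritic class; all dimensions are made up:

```python
import torch
import torch.nn as nn

# The actor sees only the sensor observation; the critic additionally
# receives privileged state (e.g. terrain heights). Dimensions are illustrative.
num_obs, num_priv_obs, num_actions = 43, 60, 12

actor = nn.Sequential(nn.Linear(num_obs, 128), nn.ELU(), nn.Linear(128, num_actions))
critic = nn.Sequential(nn.Linear(num_priv_obs, 128), nn.ELU(), nn.Linear(128, 1))

obs = torch.randn(4096, num_obs)            # what the robot can measure
priv_obs = torch.randn(4096, num_priv_obs)  # simulator-only ground truth

action_mean = actor(obs)   # used both in training and at deployment
value = critic(priv_obs)   # used only during training
print(action_mean.shape, value.shape)  # torch.Size([4096, 12]) torch.Size([4096, 1])
```

At deployment time only `actor` is exported; `critic` (and the privileged input it depends on) is discarded.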

1 RolloutStorage

1-1 Introduction
  • As mentioned in part one, storage/rollout_storage.py stores the sampled trajectories (rollouts)
  • During RL training, the policy network first interacts with the environment to sample a stretch of trajectory (a rollout), and only then uses that data to update the policy.
  • A dedicated data buffer is therefore needed to hold this information.
    • In rsl_rl, that role is played by RolloutStorage.
  • It has three main responsibilities:
    1. Store the sampled trajectories
    2. Compute the Return and Advantage needed by PPO
    3. Generate mini-batches
  • Key fields:
python
self.observations            # Actor input obs
self.privileged_observations # Critic input priv obs
self.actions                 # actions
self.rewards                 # rewards
self.values                  # Critic value estimates
self.returns                 # discounted returns
self.advantages              # advantages (GAE)
self.mu, self.sigma          # Gaussian policy parameters
self.saved_hidden_states_a/c # RNN hidden states
  • Core methods
python
compute_returns()      # compute returns and advantages via the GAE recursion
mini_batch_generator() # turn the rollout into training mini-batches

1-2 Complete Code Listing
python
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# 
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# Copyright (c) 2021 ETH Zurich, Nikita Rudin

import torch
import numpy as np

from rsl_rl.utils import split_and_pad_trajectories

class RolloutStorage:
    class Transition:
        def __init__(self):
            self.observations = None
            self.critic_observations = None
            self.actions = None
            self.rewards = None
            self.dones = None
            self.values = None
            self.actions_log_prob = None
            self.action_mean = None
            self.action_sigma = None
            self.hidden_states = None
        
        def clear(self):
            self.__init__()

    def __init__(self, num_envs, num_transitions_per_env, obs_shape, privileged_obs_shape, actions_shape, device='cpu'):

        self.device = device

        self.obs_shape = obs_shape
        self.privileged_obs_shape = privileged_obs_shape
        self.actions_shape = actions_shape

        # Core
        self.observations = torch.zeros(num_transitions_per_env, num_envs, *obs_shape, device=self.device)
        if privileged_obs_shape[0] is not None:
            self.privileged_observations = torch.zeros(num_transitions_per_env, num_envs, *privileged_obs_shape, device=self.device)
        else:
            self.privileged_observations = None
        self.rewards = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device)
        self.actions = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=self.device)
        self.dones = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device).byte()

        # For PPO
        self.actions_log_prob = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device)
        self.values = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device)
        self.returns = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device)
        self.advantages = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device)
        self.mu = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=self.device)
        self.sigma = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=self.device)

        self.num_transitions_per_env = num_transitions_per_env
        self.num_envs = num_envs

        # rnn
        self.saved_hidden_states_a = None
        self.saved_hidden_states_c = None

        self.step = 0

    def add_transitions(self, transition: Transition):
        if self.step >= self.num_transitions_per_env:
            raise AssertionError("Rollout buffer overflow")
        self.observations[self.step].copy_(transition.observations)
        if self.privileged_observations is not None: self.privileged_observations[self.step].copy_(transition.critic_observations)
        self.actions[self.step].copy_(transition.actions)
        self.rewards[self.step].copy_(transition.rewards.view(-1, 1))
        self.dones[self.step].copy_(transition.dones.view(-1, 1))
        self.values[self.step].copy_(transition.values)
        self.actions_log_prob[self.step].copy_(transition.actions_log_prob.view(-1, 1))
        self.mu[self.step].copy_(transition.action_mean)
        self.sigma[self.step].copy_(transition.action_sigma)
        self._save_hidden_states(transition.hidden_states)
        self.step += 1

    def _save_hidden_states(self, hidden_states):
        if hidden_states is None or hidden_states==(None, None):
            return
        # make a tuple out of GRU hidden states to match the LSTM format
        hid_a = hidden_states[0] if isinstance(hidden_states[0], tuple) else (hidden_states[0],)
        hid_c = hidden_states[1] if isinstance(hidden_states[1], tuple) else (hidden_states[1],)

        # initialize if needed 
        if self.saved_hidden_states_a is None:
            self.saved_hidden_states_a = [torch.zeros(self.observations.shape[0], *hid_a[i].shape, device=self.device) for i in range(len(hid_a))]
            self.saved_hidden_states_c = [torch.zeros(self.observations.shape[0], *hid_c[i].shape, device=self.device) for i in range(len(hid_c))]
        # copy the states
        for i in range(len(hid_a)):
            self.saved_hidden_states_a[i][self.step].copy_(hid_a[i])
            self.saved_hidden_states_c[i][self.step].copy_(hid_c[i])


    def clear(self):
        self.step = 0

    def compute_returns(self, last_values, gamma, lam):
        advantage = 0
        for step in reversed(range(self.num_transitions_per_env)):
            if step == self.num_transitions_per_env - 1:
                next_values = last_values
            else:
                next_values = self.values[step + 1]
            next_is_not_terminal = 1.0 - self.dones[step].float()
            delta = self.rewards[step] + next_is_not_terminal * gamma * next_values - self.values[step]
            advantage = delta + next_is_not_terminal * gamma * lam * advantage
            self.returns[step] = advantage + self.values[step]

        # Compute and normalize the advantages
        self.advantages = self.returns - self.values
        self.advantages = (self.advantages - self.advantages.mean()) / (self.advantages.std() + 1e-8)

    def get_statistics(self):
        done = self.dones
        done[-1] = 1
        flat_dones = done.permute(1, 0, 2).reshape(-1, 1)
        done_indices = torch.cat((flat_dones.new_tensor([-1], dtype=torch.int64), flat_dones.nonzero(as_tuple=False)[:, 0]))
        trajectory_lengths = (done_indices[1:] - done_indices[:-1])
        return trajectory_lengths.float().mean(), self.rewards.mean()

    def mini_batch_generator(self, num_mini_batches, num_epochs=8):
        batch_size = self.num_envs * self.num_transitions_per_env
        mini_batch_size = batch_size // num_mini_batches
        indices = torch.randperm(num_mini_batches*mini_batch_size, requires_grad=False, device=self.device)

        observations = self.observations.flatten(0, 1)
        if self.privileged_observations is not None:
            critic_observations = self.privileged_observations.flatten(0, 1)
        else:
            critic_observations = observations

        actions = self.actions.flatten(0, 1)
        values = self.values.flatten(0, 1)
        returns = self.returns.flatten(0, 1)
        old_actions_log_prob = self.actions_log_prob.flatten(0, 1)
        advantages = self.advantages.flatten(0, 1)
        old_mu = self.mu.flatten(0, 1)
        old_sigma = self.sigma.flatten(0, 1)

        for epoch in range(num_epochs):
            for i in range(num_mini_batches):

                start = i*mini_batch_size
                end = (i+1)*mini_batch_size
                batch_idx = indices[start:end]

                obs_batch = observations[batch_idx]
                critic_observations_batch = critic_observations[batch_idx]
                actions_batch = actions[batch_idx]
                target_values_batch = values[batch_idx]
                returns_batch = returns[batch_idx]
                old_actions_log_prob_batch = old_actions_log_prob[batch_idx]
                advantages_batch = advantages[batch_idx]
                old_mu_batch = old_mu[batch_idx]
                old_sigma_batch = old_sigma[batch_idx]
                yield obs_batch, critic_observations_batch, actions_batch, target_values_batch, advantages_batch, returns_batch, \
                       old_actions_log_prob_batch, old_mu_batch, old_sigma_batch, (None, None), None

    # for RNNs only
    def reccurent_mini_batch_generator(self, num_mini_batches, num_epochs=8):

        padded_obs_trajectories, trajectory_masks = split_and_pad_trajectories(self.observations, self.dones)
        if self.privileged_observations is not None: 
            padded_critic_obs_trajectories, _ = split_and_pad_trajectories(self.privileged_observations, self.dones)
        else: 
            padded_critic_obs_trajectories = padded_obs_trajectories

        mini_batch_size = self.num_envs // num_mini_batches
        for ep in range(num_epochs):
            first_traj = 0
            for i in range(num_mini_batches):
                start = i*mini_batch_size
                stop = (i+1)*mini_batch_size

                dones = self.dones.squeeze(-1)
                last_was_done = torch.zeros_like(dones, dtype=torch.bool)
                last_was_done[1:] = dones[:-1]
                last_was_done[0] = True
                trajectories_batch_size = torch.sum(last_was_done[:, start:stop])
                last_traj = first_traj + trajectories_batch_size
                
                masks_batch = trajectory_masks[:, first_traj:last_traj]
                obs_batch = padded_obs_trajectories[:, first_traj:last_traj]
                critic_obs_batch = padded_critic_obs_trajectories[:, first_traj:last_traj]

                actions_batch = self.actions[:, start:stop]
                old_mu_batch = self.mu[:, start:stop]
                old_sigma_batch = self.sigma[:, start:stop]
                returns_batch = self.returns[:, start:stop]
                advantages_batch = self.advantages[:, start:stop]
                values_batch = self.values[:, start:stop]
                old_actions_log_prob_batch = self.actions_log_prob[:, start:stop]

                # reshape to [num_envs, time, num layers, hidden dim] (original shape: [time, num_layers, num_envs, hidden_dim])
                # then take only time steps after dones (flattens num envs and time dimensions),
                # take a batch of trajectories and finally reshape back to [num_layers, batch, hidden_dim]
                last_was_done = last_was_done.permute(1, 0)
                hid_a_batch = [ saved_hidden_states.permute(2, 0, 1, 3)[last_was_done][first_traj:last_traj].transpose(1, 0).contiguous()
                                for saved_hidden_states in self.saved_hidden_states_a ] 
                hid_c_batch = [ saved_hidden_states.permute(2, 0, 1, 3)[last_was_done][first_traj:last_traj].transpose(1, 0).contiguous()
                                for saved_hidden_states in self.saved_hidden_states_c ]
                # remove the tuple for GRU
                hid_a_batch = hid_a_batch[0] if len(hid_a_batch)==1 else hid_a_batch
                hid_c_batch = hid_c_batch[0] if len(hid_c_batch)==1 else hid_c_batch

                yield obs_batch, critic_obs_batch, actions_batch, values_batch, advantages_batch, returns_batch, \
                       old_actions_log_prob_batch, old_mu_batch, old_sigma_batch, (hid_a_batch, hid_c_batch), masks_batch
                
                first_traj = last_traj

1-3 The Initialization Function
python
def __init__(self, num_envs, num_transitions_per_env, obs_shape, privileged_obs_shape, actions_shape, device='cpu'):

	self.device = device

	self.obs_shape = obs_shape
	self.privileged_obs_shape = privileged_obs_shape
	self.actions_shape = actions_shape

	# Core
	self.observations = torch.zeros(num_transitions_per_env, num_envs, *obs_shape, device=self.device)
	if privileged_obs_shape[0] is not None:
		self.privileged_observations = torch.zeros(num_transitions_per_env, num_envs, *privileged_obs_shape, device=self.device)
	else:
		self.privileged_observations = None
	self.rewards = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device)
	self.actions = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=self.device)
	self.dones = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device).byte()

	# For PPO
	self.actions_log_prob = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device)
	self.values = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device)
	self.returns = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device)
	self.advantages = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device)
	self.mu = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=self.device)
	self.sigma = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=self.device)

	self.num_transitions_per_env = num_transitions_per_env
	self.num_envs = num_envs

	# rnn
	self.saved_hidden_states_a = None
	self.saved_hidden_states_c = None

	self.step = 0
  • The initializer allocates every field a sampled rollout needs to record, including:
  • Core fields
    • self.obs_shape: observation dimension (the Actor's input state dimension)
    • self.privileged_obs_shape: privileged observation dimension (extra information the Critic may use)
    • self.actions_shape: action space dimension
    • self.observations: stores the per-step environment observation $s_t$
    • self.privileged_observations: stores the privileged observation used by the Critic
    • self.rewards: stores the environment reward $r_t$
    • self.actions: stores the action $a_t$ output by the policy network
    • self.dones: records whether the episode ended at each step (whether the environment reset)
  • PPO fields
    • self.actions_log_prob: log-probability of the action under the old policy, $\log \pi_{\text{old}}(a|s)$
    • self.values: the Critic's state-value estimate $V(s)$
    • self.returns: discounted return $R_t$, filled in by compute_returns()
    • self.advantages: advantage $A_t$, used in PPO's policy-gradient update
    • self.mu: mean $\mu$ of the Gaussian policy
    • self.sigma: standard deviation $\sigma$ of the Gaussian policy
    • self.num_transitions_per_env: number of sampled time steps per environment (rollout length)
    • self.num_envs: number of parallel environments
  • RNN fields
    • self.saved_hidden_states_a: hidden states of the Actor network
    • self.saved_hidden_states_c: hidden states of the Critic network
    • self.step: number of steps already written in the current rollout, used as the write index
  • Note that self.dones has shape $[T, N, 1]$, where
    • T: number of time steps (rollout length)
    • N: number of parallel environments
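To make the $[T, N, \text{dim}]$ layout concrete, here is a stripped-down version of the allocation in __init__, with illustrative numbers and no rsl_rl dependency:

```python
import torch

# Toy buffer allocation mirroring RolloutStorage.__init__.
T, N = 24, 16            # num_transitions_per_env, num_envs (illustrative)
obs_dim, act_dim = 48, 12

observations = torch.zeros(T, N, obs_dim)
actions = torch.zeros(T, N, act_dim)
rewards = torch.zeros(T, N, 1)
dones = torch.zeros(T, N, 1).byte()  # uint8 flags, as in the source

# The data for time step t of environment n lives at buffer[t, n].
print(observations[3, 5].shape)  # torch.Size([48])
```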

1-4 add_transitions()
  • Purpose: write the data produced by one environment interaction (a Transition) into the RolloutStorage.
python
def add_transitions(self, transition: Transition):
	if self.step >= self.num_transitions_per_env:
		raise AssertionError("Rollout buffer overflow")
	self.observations[self.step].copy_(transition.observations)
	if self.privileged_observations is not None: self.privileged_observations[self.step].copy_(transition.critic_observations)
	self.actions[self.step].copy_(transition.actions)
	self.rewards[self.step].copy_(transition.rewards.view(-1, 1))
	self.dones[self.step].copy_(transition.dones.view(-1, 1))
	self.values[self.step].copy_(transition.values)
	self.actions_log_prob[self.step].copy_(transition.actions_log_prob.view(-1, 1))
	self.mu[self.step].copy_(transition.action_mean)
	self.sigma[self.step].copy_(transition.action_sigma)
	self._save_hidden_states(transition.hidden_states)
	self.step += 1
  • The stored fields include:
python
obs          # [T, N, obs_dim]
action
reward
value
log_prob
done
hidden_state
  • All of these are required by the PPO update
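The write pattern itself can be illustrated with a toy buffer; this is not the real RolloutStorage, just the same copy_-at-self.step logic in isolation:

```python
import torch

# Minimal illustration of the write-at-step pattern used by add_transitions().
T, N, obs_dim = 8, 2, 3
observations = torch.zeros(T, N, obs_dim)
rewards = torch.zeros(T, N, 1)
step = 0

for t in range(T):
    if step >= T:
        raise AssertionError("Rollout buffer overflow")
    obs_t = torch.full((N, obs_dim), float(t))  # fake env output for step t
    rew_t = torch.ones(N)
    observations[step].copy_(obs_t)
    rewards[step].copy_(rew_t.view(-1, 1))      # [N] -> [N, 1], as in the source
    step += 1

print(step, observations[5, 0, 0].item())  # 8 5.0
```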

1-5 _save_hidden_states()
  • Purpose: store the RNN hidden state of the current time step into the RolloutStorage
  • This works together with the ActorCriticRecurrent class from the previous part: when it is enabled, the RNN hidden states need to be saved
python
def _save_hidden_states(self, hidden_states):
        if hidden_states is None or hidden_states==(None, None):
            return
        # make a tuple out of GRU hidden states to match the LSTM format
        hid_a = hidden_states[0] if isinstance(hidden_states[0], tuple) else (hidden_states[0],)
        hid_c = hidden_states[1] if isinstance(hidden_states[1], tuple) else (hidden_states[1],)

        # initialize if needed 
        if self.saved_hidden_states_a is None:
            self.saved_hidden_states_a = [torch.zeros(self.observations.shape[0], *hid_a[i].shape, device=self.device) for i in range(len(hid_a))]
            self.saved_hidden_states_c = [torch.zeros(self.observations.shape[0], *hid_c[i].shape, device=self.device) for i in range(len(hid_c))]
        # copy the states
        for i in range(len(hid_a)):
            self.saved_hidden_states_a[i][self.step].copy_(hid_a[i])
            self.saved_hidden_states_c[i][self.step].copy_(hid_c[i])

1-6 Core Function compute_returns()
  • This function is at the heart of the PPO pipeline; it does two things:
    • compute the Return
    • compute the Advantage using GAE
python
def compute_returns(self, last_values, gamma, lam):
	advantage = 0
	for step in reversed(range(self.num_transitions_per_env)):
		if step == self.num_transitions_per_env - 1:
			next_values = last_values
		else:
			next_values = self.values[step + 1]
		next_is_not_terminal = 1.0 - self.dones[step].float()
		delta = self.rewards[step] + next_is_not_terminal * gamma * next_values - self.values[step]
		advantage = delta + next_is_not_terminal * gamma * lam * advantage
		self.returns[step] = advantage + self.values[step]

	# Compute and normalize the advantages
	self.advantages = self.returns - self.values
	self.advantages = (self.advantages - self.advantages.mean()) / (self.advantages.std() + 1e-8)
  • Let's look at it step by step

1-6-1 Traverse the Rollout in Reverse
python
for step in reversed(range(self.num_transitions_per_env)):
  • We iterate over the rollout with reversed because GAE is a backwards recursion, $A_t = \delta_t + \gamma\lambda A_{t+1}$, so the computation must start from the last step

1-6-2 Get the Next-Step Value
python
if step == self.num_transitions_per_env - 1:
    next_values = last_values
else:
    next_values = self.values[step + 1]
  • This selects $V(s_{t+1})$
  • The last step of the rollout has no stored successor value, so the bootstrap value last_values (the Critic's estimate for the state reached after the rollout) is used instead

1-6-3 Check Whether the Episode Ended
python
next_is_not_terminal = 1.0 - self.dones[step].float()
  • If done == 1 the episode has ended, so next_is_not_terminal becomes 0 and bootstrapping across the reset is cut off

1-6-4 Core: Compute the TD Error
python
delta = self.rewards[step] + next_is_not_terminal * gamma * next_values - self.values[step]
  • This corresponds to the formula $\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)$
  • To handle episode termination, the next_is_not_terminal factor computed above is inserted: $\delta_t = r_t + (1-d_t)\gamma V(s_{t+1}) - V(s_t)$

1-6-5 Core: Compute the GAE Advantage
python
advantage = delta + next_is_not_terminal * gamma * lam * advantage
  • This corresponds to $A_t = \delta_t + (1-d_t)\gamma\lambda A_{t+1}$, where:
    • $\gamma$: weights long-term rewards
    • $\lambda$: controls the bias-variance tradeoff

1-6-6 Core: Compute the Return and Advantage, then Normalize the Advantage
python
self.returns[step] = advantage + self.values[step]
# Compute and normalize the advantages
self.advantages = self.returns - self.values
self.advantages = (self.advantages - self.advantages.mean()) / (self.advantages.std() + 1e-8)
  • This corresponds to $A_t = R_t - V(s_t)$
  • Advantage normalization corresponds to $A_t = \frac{A_t - \mu_A}{\sigma_A + \epsilon}$
  • Its purpose is to
    1. stabilize training
    2. avoid oversized gradients
    3. reduce variance

1-7 The Statistics Function get_statistics()
  • This function computes the mean episode length and the mean reward
python
def get_statistics(self):
	done = self.dones
	done[-1] = 1
	flat_dones = done.permute(1, 0, 2).reshape(-1, 1)
	done_indices = torch.cat((flat_dones.new_tensor([-1], dtype=torch.int64), flat_dones.nonzero(as_tuple=False)[:, 0]))
	trajectory_lengths = (done_indices[1:] - done_indices[:-1])
	return trajectory_lengths.float().mean(), self.rewards.mean()
  • Recall that dones has shape $[T, N, 1]$
  • done[-1] = 1 guarantees that every trajectory ends, i.e. the end of the rollout is treated as an episode boundary (note that done = self.dones is only an alias, so this write also mutates self.dones in place)
  • flat_dones permutes and flattens the tensor to $[N \times T, 1]$, so all done flags form a single sequence
  • done_indices prepends a -1 to the done positions, which makes the trajectory lengths easy to compute
  • trajectory_lengths: trajectory lengths $L_i = d_i - d_{i-1}$
  • trajectory_lengths.float().mean(): mean trajectory length $\bar{L} = \frac{1}{K}\sum_{i=1}^{K} L_i$
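Running the same index arithmetic on a tiny dones tensor makes the bookkeeping visible (T=6 steps, N=1 env, one done at t=2):

```python
import torch

# Reproducing get_statistics()'s trajectory-length computation on toy data.
dones = torch.zeros(6, 1, 1)
dones[2] = 1
done = dones.clone()
done[-1] = 1  # treat the end of the rollout as an episode boundary

flat_dones = done.permute(1, 0, 2).reshape(-1, 1)
done_indices = torch.cat((flat_dones.new_tensor([-1], dtype=torch.int64),
                          flat_dones.nonzero(as_tuple=False)[:, 0]))
trajectory_lengths = done_indices[1:] - done_indices[:-1]
print(trajectory_lengths)  # tensor([3, 3]): steps 0-2 and steps 3-5
```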

1-8 Core Function mini_batch_generator()
  • Its job: take the large batch of data collected during the rollout → shuffle → split into mini-batches → train for multiple epochs.
    • num_mini_batches: how many mini-batches one batch is split into
    • num_epochs: how many epochs to train on the same data
python
def mini_batch_generator(self, num_mini_batches, num_epochs=8):
	batch_size = self.num_envs * self.num_transitions_per_env
	mini_batch_size = batch_size // num_mini_batches
	indices = torch.randperm(num_mini_batches*mini_batch_size, requires_grad=False, device=self.device)

	observations = self.observations.flatten(0, 1)
	if self.privileged_observations is not None:
		critic_observations = self.privileged_observations.flatten(0, 1)
	else:
		critic_observations = observations

	actions = self.actions.flatten(0, 1)
	values = self.values.flatten(0, 1)
	returns = self.returns.flatten(0, 1)
	old_actions_log_prob = self.actions_log_prob.flatten(0, 1)
	advantages = self.advantages.flatten(0, 1)
	old_mu = self.mu.flatten(0, 1)
	old_sigma = self.sigma.flatten(0, 1)

	for epoch in range(num_epochs):
		for i in range(num_mini_batches):

			start = i*mini_batch_size
			end = (i+1)*mini_batch_size
			batch_idx = indices[start:end]

			obs_batch = observations[batch_idx]
			critic_observations_batch = critic_observations[batch_idx]
			actions_batch = actions[batch_idx]
			target_values_batch = values[batch_idx]
			returns_batch = returns[batch_idx]
			old_actions_log_prob_batch = old_actions_log_prob[batch_idx]
			advantages_batch = advantages[batch_idx]
			old_mu_batch = old_mu[batch_idx]
			old_sigma_batch = old_sigma[batch_idx]
			yield obs_batch, critic_observations_batch, actions_batch, target_values_batch, advantages_batch, returns_batch, \
				   old_actions_log_prob_batch, old_mu_batch, old_sigma_batch, (None, None), None
  • The data yielded:
python
obs_batch  
critic_obs_batch  
actions_batch  
target_values_batch  
advantages_batch  
returns_batch  
old_log_prob  
old_mu  
old_sigma
  • These are all the inputs needed to compute the PPO loss
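The flatten-shuffle-slice pattern can be isolated like this (toy shapes, same indexing logic as the generator above):

```python
import torch

# Core shuffling pattern of mini_batch_generator():
# flatten [T, N, dim] -> [T*N, dim], shuffle once, then slice per mini-batch.
T, N, obs_dim = 4, 6, 3
observations = torch.arange(T * N * obs_dim, dtype=torch.float).view(T, N, obs_dim)

num_mini_batches, num_epochs = 3, 2
batch_size = T * N
mini_batch_size = batch_size // num_mini_batches
indices = torch.randperm(num_mini_batches * mini_batch_size)

flat_obs = observations.flatten(0, 1)  # [T*N, obs_dim]
count = 0
for epoch in range(num_epochs):
    for i in range(num_mini_batches):
        batch_idx = indices[i * mini_batch_size:(i + 1) * mini_batch_size]
        obs_batch = flat_obs[batch_idx]  # [mini_batch_size, obs_dim]
        count += 1
print(count, obs_batch.shape)  # 6 torch.Size([8, 3])
```

Note that the permutation is drawn once and reused across epochs, exactly as in the source.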

1-9 Summary
  • In ppo.py we can see
python
def compute_returns(self, last_critic_obs):
	last_values = self.actor_critic.evaluate(last_critic_obs).detach()
	self.storage.compute_returns(last_values, self.gamma, self.lam)

as well as

python
generator = self.storage.mini_batch_generator(self.num_mini_batches, self.num_learning_epochs)
  • The overall flow is roughly:

    environment interaction

    RolloutStorage

    compute_returns()

    mini_batch_generator()

    PPO update

    loss computation


2 VecEnv

2-1 Introduction
  • VecEnv (Vectorized Environment) is the interface for running many environments in parallel
  • Responsibilities:
    1. Step all environments at once → high GPU utilization
    2. Buffer obs, privileged_obs, rewards, dones, and related information
  • RolloutStorage stores the data output by VecEnv, forming a complete rollout
2-2 Complete Code Listing
python
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2021 ETH Zurich, Nikita Rudin
# (full BSD-3-Clause header identical to the one in rollout_storage.py above)

from abc import ABC, abstractmethod
import torch
from typing import Tuple, Union

# minimal interface of the environment
class VecEnv(ABC):
    num_envs: int
    num_obs: int
    num_privileged_obs: int
    num_actions: int
    max_episode_length: int
    privileged_obs_buf: torch.Tensor
    obs_buf: torch.Tensor 
    rew_buf: torch.Tensor
    reset_buf: torch.Tensor
    episode_length_buf: torch.Tensor # current episode duration
    extras: dict
    device: torch.device
    @abstractmethod
    def step(self, actions: torch.Tensor) -> Tuple[torch.Tensor, Union[torch.Tensor, None], torch.Tensor, torch.Tensor, dict]:
        pass
    @abstractmethod
    def reset(self, env_ids: Union[list, torch.Tensor]):
        pass
    @abstractmethod
    def get_observations(self) -> torch.Tensor:
        pass
    @abstractmethod
    def get_privileged_observations(self) -> Union[torch.Tensor, None]:
        pass
  • The class is short, so let's go straight to the attribute definitions:
    • num_envs: number of parallel environments
    • num_obs: observation dimension
    • num_privileged_obs: privileged observation dimension
    • num_actions: action dimension
    • max_episode_length: maximum episode length, i.e. how many steps an episode may run at most
    • The following attributes are GPU tensor buffers
      • obs_buf: observation buffer
      • privileged_obs_buf: privileged observation buffer
      • rew_buf: reward buffer
      • reset_buf: reset-flag buffer
      • episode_length_buf: current episode duration buffer
      • extras: dictionary for extra information
      • device: the device everything runs on
  • Then come the abstract methods
    • step(): advances the environment by one step, corresponding to $s_{t+1}, r_t = \mathrm{Env}(s_t, a_t)$
    • reset(): resets the environment
    • get_observations() / get_privileged_observations(): return the current (privileged) observations
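As an illustration of implementing this interface, here is a toy vectorized environment. The ABC is re-declared locally so the sketch runs standalone; a real project would subclass rsl_rl.env.VecEnv and return simulator tensors:

```python
from abc import ABC, abstractmethod
import torch

class VecEnv(ABC):  # re-declared locally; mirrors the abstract methods above
    @abstractmethod
    def step(self, actions): ...
    @abstractmethod
    def reset(self, env_ids): ...
    @abstractmethod
    def get_observations(self): ...
    @abstractmethod
    def get_privileged_observations(self): ...

class DummyVecEnv(VecEnv):
    """Toy batched environment: reward is the negative action norm."""
    def __init__(self, num_envs=4, num_obs=3, num_actions=2):
        self.num_envs, self.num_obs, self.num_actions = num_envs, num_obs, num_actions
        self.num_privileged_obs = None          # no privileged observations here
        self.obs_buf = torch.zeros(num_envs, num_obs)

    def step(self, actions):
        rew = -actions.norm(dim=-1)             # [num_envs]
        dones = torch.zeros(self.num_envs, dtype=torch.bool)
        return self.obs_buf, None, rew, dones, {}

    def reset(self, env_ids):
        self.obs_buf[env_ids] = 0.0

    def get_observations(self):
        return self.obs_buf

    def get_privileged_observations(self):
        return None

env = DummyVecEnv()
obs, priv, rew, dones, infos = env.step(torch.zeros(4, 2))
print(obs.shape, rew.shape)  # torch.Size([4, 3]) torch.Size([4])
```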

3 OnPolicyRunner
3-1 Introduction
  • OnPolicyRunner is a general-purpose reinforcement learning trainer, used to train PPO (or other Actor-Critic algorithms) in a vectorized environment (VecEnv).
3-2 Complete Code Listing
python
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2021 ETH Zurich, Nikita Rudin
# (full BSD-3-Clause header identical to the one in rollout_storage.py above)

import time
import os
from collections import deque
import statistics

from torch.utils.tensorboard import SummaryWriter
import torch

from rsl_rl.algorithms import PPO
from rsl_rl.modules import ActorCritic, ActorCriticRecurrent
from rsl_rl.env import VecEnv


class OnPolicyRunner:

    def __init__(self,
                 env: VecEnv,
                 train_cfg,
                 log_dir=None,
                 device='cpu'):

        self.cfg=train_cfg["runner"]
        self.alg_cfg = train_cfg["algorithm"]
        self.policy_cfg = train_cfg["policy"]
        self.device = device
        self.env = env
        if self.env.num_privileged_obs is not None:
            num_critic_obs = self.env.num_privileged_obs 
        else:
            num_critic_obs = self.env.num_obs
        actor_critic_class = eval(self.cfg["policy_class_name"]) # ActorCritic
        actor_critic: ActorCritic = actor_critic_class( self.env.num_obs,
                                                        num_critic_obs,
                                                        self.env.num_actions,
                                                        **self.policy_cfg).to(self.device)
        alg_class = eval(self.cfg["algorithm_class_name"]) # PPO
        self.alg: PPO = alg_class(actor_critic, device=self.device, **self.alg_cfg)
        self.num_steps_per_env = self.cfg["num_steps_per_env"]
        self.save_interval = self.cfg["save_interval"]

        # init storage and model
        self.alg.init_storage(self.env.num_envs, self.num_steps_per_env, [self.env.num_obs], [self.env.num_privileged_obs], [self.env.num_actions])

        # Log
        self.log_dir = log_dir
        self.writer = None
        self.tot_timesteps = 0
        self.tot_time = 0
        self.current_learning_iteration = 0

        _, _ = self.env.reset()
    
    def learn(self, num_learning_iterations, init_at_random_ep_len=False):
        # initialize writer
        if self.log_dir is not None and self.writer is None:
            self.writer = SummaryWriter(log_dir=self.log_dir, flush_secs=10)
        if init_at_random_ep_len:
            self.env.episode_length_buf = torch.randint_like(self.env.episode_length_buf, high=int(self.env.max_episode_length))
        obs = self.env.get_observations()
        privileged_obs = self.env.get_privileged_observations()
        critic_obs = privileged_obs if privileged_obs is not None else obs
        obs, critic_obs = obs.to(self.device), critic_obs.to(self.device)
        self.alg.actor_critic.train() # switch to train mode (for dropout for example)

        ep_infos = []
        rewbuffer = deque(maxlen=100)
        lenbuffer = deque(maxlen=100)
        cur_reward_sum = torch.zeros(self.env.num_envs, dtype=torch.float, device=self.device)
        cur_episode_length = torch.zeros(self.env.num_envs, dtype=torch.float, device=self.device)

        tot_iter = self.current_learning_iteration + num_learning_iterations
        for it in range(self.current_learning_iteration, tot_iter):
            start = time.time()
            # Rollout
            with torch.inference_mode():
                for i in range(self.num_steps_per_env):
                    actions = self.alg.act(obs, critic_obs)
                    obs, privileged_obs, rewards, dones, infos = self.env.step(actions)
                    critic_obs = privileged_obs if privileged_obs is not None else obs
                    obs, critic_obs, rewards, dones = obs.to(self.device), critic_obs.to(self.device), rewards.to(self.device), dones.to(self.device)
                    self.alg.process_env_step(rewards, dones, infos)
                    
                    if self.log_dir is not None:
                        # Book keeping
                        if 'episode' in infos:
                            ep_infos.append(infos['episode'])
                        cur_reward_sum += rewards
                        cur_episode_length += 1
                        new_ids = (dones > 0).nonzero(as_tuple=False)
                        rewbuffer.extend(cur_reward_sum[new_ids][:, 0].cpu().numpy().tolist())
                        lenbuffer.extend(cur_episode_length[new_ids][:, 0].cpu().numpy().tolist())
                        cur_reward_sum[new_ids] = 0
                        cur_episode_length[new_ids] = 0

                stop = time.time()
                collection_time = stop - start

                # Learning step
                start = stop
                self.alg.compute_returns(critic_obs)
            
            mean_value_loss, mean_surrogate_loss = self.alg.update()
            stop = time.time()
            learn_time = stop - start
            if self.log_dir is not None:
                self.log(locals())
            if it % self.save_interval == 0:
                self.save(os.path.join(self.log_dir, 'model_{}.pt'.format(it)))
            ep_infos.clear()
        
        self.current_learning_iteration += num_learning_iterations
        self.save(os.path.join(self.log_dir, 'model_{}.pt'.format(self.current_learning_iteration)))

    def log(self, locs, width=80, pad=35):
        self.tot_timesteps += self.num_steps_per_env * self.env.num_envs
        self.tot_time += locs['collection_time'] + locs['learn_time']
        iteration_time = locs['collection_time'] + locs['learn_time']

        ep_string = f''
        if locs['ep_infos']:
            for key in locs['ep_infos'][0]:
                infotensor = torch.tensor([], device=self.device)
                for ep_info in locs['ep_infos']:
                    # handle scalar and zero dimensional tensor infos
                    if not isinstance(ep_info[key], torch.Tensor):
                        ep_info[key] = torch.Tensor([ep_info[key]])
                    if len(ep_info[key].shape) == 0:
                        ep_info[key] = ep_info[key].unsqueeze(0)
                    infotensor = torch.cat((infotensor, ep_info[key].to(self.device)))
                value = torch.mean(infotensor)
                self.writer.add_scalar('Episode/' + key, value, locs['it'])
                ep_string += f"""{f'Mean episode {key}:':>{pad}} {value:.4f}\n"""
        mean_std = self.alg.actor_critic.std.mean()
        fps = int(self.num_steps_per_env * self.env.num_envs / (locs['collection_time'] + locs['learn_time']))

        self.writer.add_scalar('Loss/value_function', locs['mean_value_loss'], locs['it'])
        self.writer.add_scalar('Loss/surrogate', locs['mean_surrogate_loss'], locs['it'])
        self.writer.add_scalar('Loss/learning_rate', self.alg.learning_rate, locs['it'])
        self.writer.add_scalar('Policy/mean_noise_std', mean_std.item(), locs['it'])
        self.writer.add_scalar('Perf/total_fps', fps, locs['it'])
        self.writer.add_scalar('Perf/collection time', locs['collection_time'], locs['it'])
        self.writer.add_scalar('Perf/learning_time', locs['learn_time'], locs['it'])
        if len(locs['rewbuffer']) > 0:
            self.writer.add_scalar('Train/mean_reward', statistics.mean(locs['rewbuffer']), locs['it'])
            self.writer.add_scalar('Train/mean_episode_length', statistics.mean(locs['lenbuffer']), locs['it'])
            self.writer.add_scalar('Train/mean_reward/time', statistics.mean(locs['rewbuffer']), self.tot_time)
            self.writer.add_scalar('Train/mean_episode_length/time', statistics.mean(locs['lenbuffer']), self.tot_time)

        str = f" \033[1m Learning iteration {locs['it']}/{self.current_learning_iteration + locs['num_learning_iterations']} \033[0m "

        if len(locs['rewbuffer']) > 0:
            log_string = (f"""{'#' * width}\n"""
                          f"""{str.center(width, ' ')}\n\n"""
                          f"""{'Computation:':>{pad}} {fps:.0f} steps/s (collection: {locs[
                            'collection_time']:.3f}s, learning {locs['learn_time']:.3f}s)\n"""
                          f"""{'Value function loss:':>{pad}} {locs['mean_value_loss']:.4f}\n"""
                          f"""{'Surrogate loss:':>{pad}} {locs['mean_surrogate_loss']:.4f}\n"""
                          f"""{'Mean action noise std:':>{pad}} {mean_std.item():.2f}\n"""
                          f"""{'Mean reward:':>{pad}} {statistics.mean(locs['rewbuffer']):.2f}\n"""
                          f"""{'Mean episode length:':>{pad}} {statistics.mean(locs['lenbuffer']):.2f}\n""")
                        #   f"""{'Mean reward/step:':>{pad}} {locs['mean_reward']:.2f}\n"""
                        #   f"""{'Mean episode length/episode:':>{pad}} {locs['mean_trajectory_length']:.2f}\n""")
        else:
            log_string = (f"""{'#' * width}\n"""
                          f"""{str.center(width, ' ')}\n\n"""
                          f"""{'Computation:':>{pad}} {fps:.0f} steps/s (collection: {locs[
                            'collection_time']:.3f}s, learning {locs['learn_time']:.3f}s)\n"""
                          f"""{'Value function loss:':>{pad}} {locs['mean_value_loss']:.4f}\n"""
                          f"""{'Surrogate loss:':>{pad}} {locs['mean_surrogate_loss']:.4f}\n"""
                          f"""{'Mean action noise std:':>{pad}} {mean_std.item():.2f}\n""")
                        #   f"""{'Mean reward/step:':>{pad}} {locs['mean_reward']:.2f}\n"""
                        #   f"""{'Mean episode length/episode:':>{pad}} {locs['mean_trajectory_length']:.2f}\n""")

        log_string += ep_string
        log_string += (f"""{'-' * width}\n"""
                       f"""{'Total timesteps:':>{pad}} {self.tot_timesteps}\n"""
                       f"""{'Iteration time:':>{pad}} {iteration_time:.2f}s\n"""
                       f"""{'Total time:':>{pad}} {self.tot_time:.2f}s\n"""
                       f"""{'ETA:':>{pad}} {self.tot_time / (locs['it'] + 1) * (
                               locs['num_learning_iterations'] - locs['it']):.1f}s\n""")
        print(log_string)

    def save(self, path, infos=None):
        torch.save({
            'model_state_dict': self.alg.actor_critic.state_dict(),
            'optimizer_state_dict': self.alg.optimizer.state_dict(),
            'iter': self.current_learning_iteration,
            'infos': infos,
            }, path)

    def load(self, path, load_optimizer=True):
        loaded_dict = torch.load(path)
        self.alg.actor_critic.load_state_dict(loaded_dict['model_state_dict'])
        if load_optimizer:
            self.alg.optimizer.load_state_dict(loaded_dict['optimizer_state_dict'])
        self.current_learning_iteration = loaded_dict['iter']
        return loaded_dict['infos']

    def get_inference_policy(self, device=None):
        self.alg.actor_critic.eval() # switch to evaluation mode (dropout for example)
        if device is not None:
            self.alg.actor_critic.to(device)
        return self.alg.actor_critic.act_inference

3-3 The initialization function
python
def __init__(self,
			 env: VecEnv,
			 train_cfg,
			 log_dir=None,
			 device='cpu'):

	self.cfg=train_cfg["runner"]
	self.alg_cfg = train_cfg["algorithm"]
	self.policy_cfg = train_cfg["policy"]
	self.device = device
	self.env = env
	if self.env.num_privileged_obs is not None:
		num_critic_obs = self.env.num_privileged_obs 
	else:
		num_critic_obs = self.env.num_obs
	actor_critic_class = eval(self.cfg["policy_class_name"]) # ActorCritic
	actor_critic: ActorCritic = actor_critic_class( self.env.num_obs,
													num_critic_obs,
													self.env.num_actions,
													**self.policy_cfg).to(self.device)
	alg_class = eval(self.cfg["algorithm_class_name"]) # PPO
	self.alg: PPO = alg_class(actor_critic, device=self.device, **self.alg_cfg)
	self.num_steps_per_env = self.cfg["num_steps_per_env"]
	self.save_interval = self.cfg["save_interval"]

	# init storage and model
	self.alg.init_storage(self.env.num_envs, self.num_steps_per_env, [self.env.num_obs], [self.env.num_privileged_obs], [self.env.num_actions])

	# Log
	self.log_dir = log_dir
	self.writer = None
	self.tot_timesteps = 0
	self.tot_time = 0
	self.current_learning_iteration = 0

	_, _ = self.env.reset()
  • The constructor arguments:
    • env: the training environment (a VecEnv that runs many environments in parallel)
    • train_cfg: a configuration dict holding the runner, algorithm and policy configs
    • log_dir: log directory for TensorBoard visualization
    • device: the PyTorch device (CPU/GPU)
  • The constructor performs the basic setup:
  1. Read the runner, algorithm and policy configs

  2. Decide the Critic's input dimension

    • The Actor only sees the ordinary observation num_obs
    • The Critic uses the privileged observation num_privileged_obs when the env provides one, and falls back to num_obs otherwise
  3. Build the Actor-Critic network

  4. Build the PPO algorithm object

  5. Initialize the rollout storage (trajectory buffer)

  6. Initialize logging and the iteration counters
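To make the setup above concrete, here is a minimal sketch of the config shape __init__ expects and of the Critic-input branch. The three top-level keys mirror the reads in the code above; the concrete values and the helper pick_num_critic_obs are illustrative assumptions for demonstration, not the project's defaults.

```python
# Sketch of the train_cfg layout read by __init__ (values are assumptions).
train_cfg = {
    "runner": {
        "policy_class_name": "ActorCritic",   # resolved via eval(...)
        "algorithm_class_name": "PPO",
        "num_steps_per_env": 24,
        "save_interval": 50,
    },
    "algorithm": {},  # forwarded as PPO(actor_critic, **alg_cfg)
    "policy": {},     # forwarded as ActorCritic(..., **policy_cfg)
}

def pick_num_critic_obs(num_obs, num_privileged_obs):
    """Mirrors the branch in __init__: the Critic gets the privileged
    observation size when available, else the Actor's observation size."""
    return num_privileged_obs if num_privileged_obs is not None else num_obs

print(pick_num_critic_obs(48, 235))   # privileged obs available -> 235
print(pick_num_critic_obs(48, None))  # no privileged obs        -> 48
```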


3-4 The core training loop: learn()
python
def learn(self, num_learning_iterations, init_at_random_ep_len=False):
        # initialize writer
        if self.log_dir is not None and self.writer is None:
            self.writer = SummaryWriter(log_dir=self.log_dir, flush_secs=10)
        if init_at_random_ep_len:
            self.env.episode_length_buf = torch.randint_like(self.env.episode_length_buf, high=int(self.env.max_episode_length))
        obs = self.env.get_observations()
        privileged_obs = self.env.get_privileged_observations()
        critic_obs = privileged_obs if privileged_obs is not None else obs
        obs, critic_obs = obs.to(self.device), critic_obs.to(self.device)
        self.alg.actor_critic.train() # switch to train mode (for dropout for example)

        ep_infos = []
        rewbuffer = deque(maxlen=100)
        lenbuffer = deque(maxlen=100)
        cur_reward_sum = torch.zeros(self.env.num_envs, dtype=torch.float, device=self.device)
        cur_episode_length = torch.zeros(self.env.num_envs, dtype=torch.float, device=self.device)

        tot_iter = self.current_learning_iteration + num_learning_iterations
        for it in range(self.current_learning_iteration, tot_iter):
            start = time.time()
            # Rollout
            with torch.inference_mode():
                for i in range(self.num_steps_per_env):
                    actions = self.alg.act(obs, critic_obs)
                    obs, privileged_obs, rewards, dones, infos = self.env.step(actions)
                    critic_obs = privileged_obs if privileged_obs is not None else obs
                    obs, critic_obs, rewards, dones = obs.to(self.device), critic_obs.to(self.device), rewards.to(self.device), dones.to(self.device)
                    self.alg.process_env_step(rewards, dones, infos)
                    
                    if self.log_dir is not None:
                        # Book keeping
                        if 'episode' in infos:
                            ep_infos.append(infos['episode'])
                        cur_reward_sum += rewards
                        cur_episode_length += 1
                        new_ids = (dones > 0).nonzero(as_tuple=False)
                        rewbuffer.extend(cur_reward_sum[new_ids][:, 0].cpu().numpy().tolist())
                        lenbuffer.extend(cur_episode_length[new_ids][:, 0].cpu().numpy().tolist())
                        cur_reward_sum[new_ids] = 0
                        cur_episode_length[new_ids] = 0

                stop = time.time()
                collection_time = stop - start

                # Learning step
                start = stop
                self.alg.compute_returns(critic_obs)
            
            mean_value_loss, mean_surrogate_loss = self.alg.update()
            stop = time.time()
            learn_time = stop - start
            if self.log_dir is not None:
                self.log(locals())
            if it % self.save_interval == 0:
                self.save(os.path.join(self.log_dir, 'model_{}.pt'.format(it)))
            ep_infos.clear()
        
        self.current_learning_iteration += num_learning_iterations
        self.save(os.path.join(self.log_dir, 'model_{}.pt'.format(self.current_learning_iteration)))

3-4-1 Arguments
python
def learn(self, num_learning_iterations, init_at_random_ep_len=False):
  • num_learning_iterations: total number of training iterations
  • init_at_random_ep_len: whether to randomize the initial episode lengths (diversifies the training data)

3-4-2 Fetching the initial observations
python
obs = self.env.get_observations()
privileged_obs = self.env.get_privileged_observations()
critic_obs = privileged_obs if privileged_obs is not None else obs
obs, critic_obs = obs.to(self.device), critic_obs.to(self.device)
  • The Actor consumes obs
  • The Critic consumes critic_obs, preferring the privileged observation when one exists
  • This is the Asymmetric Actor-Critic pattern introduced in 0-4

3-4-3 Rollout: sampling actions and stepping the environment
python
for i in range(self.num_steps_per_env):
    actions = self.alg.act(obs, critic_obs)
    obs, privileged_obs, rewards, dones, infos = self.env.step(actions)
    critic_obs = privileged_obs if privileged_obs is not None else obs
    obs, critic_obs, rewards, dones = obs.to(self.device), critic_obs.to(self.device), rewards.to(self.device), dones.to(self.device)
    self.alg.process_env_step(rewards, dones, infos)
  • The Actor samples an action from the current obs
  • The environment returns the new obs, reward and done flags
  • critic_obs is refreshed from the new privileged observation
  • process_env_step stores the step data in the rollout storage for the later PPO update
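The loop above can be exercised against a toy stand-in whose step() returns the same 5-tuple shape as VecEnv.step(). Everything below (ToyVecEnv, its dimensions) is a simplified assumption for illustration, not the real VecEnv API surface.

```python
import torch

num_envs, num_obs, num_actions = 4, 8, 2

class ToyVecEnv:
    """Toy env returning (obs, privileged_obs, rewards, dones, infos),
    matching the tuple unpacked in the rollout loop above."""
    def step(self, actions):
        obs = torch.randn(num_envs, num_obs)
        privileged_obs = None              # this toy env has no privileged obs
        rewards = torch.ones(num_envs)
        dones = torch.zeros(num_envs)
        return obs, privileged_obs, rewards, dones, {}

env = ToyVecEnv()
obs, priv, rew, dones, infos = env.step(torch.zeros(num_envs, num_actions))
critic_obs = priv if priv is not None else obs  # same fallback as in learn()
print(tuple(critic_obs.shape))  # (4, 8)
```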

3-4-4 The PPO update
python
self.alg.compute_returns(critic_obs)
mean_value_loss, mean_surrogate_loss = self.alg.update()
  • compute_returns: computes the discounted returns and the GAE advantages
  • update: performs the PPO gradient update on the Actor-Critic network
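The backward GAE recursion behind compute_returns can be sketched as follows. This is a self-contained illustration of the formula (delta_t = r_t + gamma * V_{t+1} - V_t, accumulated with factor gamma * lam and reset at episode ends), not the library's implementation.

```python
import torch

def compute_gae(rewards, values, dones, last_value, gamma=0.99, lam=0.95):
    """Backward GAE recursion over a rollout of shape [T, num_envs].
    last_value is the Critic's bootstrap value for the state after step T-1."""
    T = rewards.shape[0]
    advantages = torch.zeros_like(rewards)
    gae = torch.zeros_like(last_value)
    next_value = last_value
    for t in reversed(range(T)):
        not_done = 1.0 - dones[t]                       # cut bootstrap at episode ends
        delta = rewards[t] + gamma * next_value * not_done - values[t]
        gae = delta + gamma * lam * not_done * gae      # accumulate the advantage
        advantages[t] = gae
        next_value = values[t]
    returns = advantages + values                       # targets for the value loss
    return advantages, returns

# With gamma = lam = 1 and no terminations, returns are just the reward-to-go:
adv, ret = compute_gae(torch.ones(2, 1), torch.zeros(2, 1),
                       torch.zeros(2, 1), torch.zeros(1), gamma=1.0, lam=1.0)
print(ret.tolist())  # [[2.0], [1.0]]
```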

3-5 save(), load() and the inference policy
python
def save(self, path, infos=None):
    torch.save({
        'model_state_dict': self.alg.actor_critic.state_dict(),
        'optimizer_state_dict': self.alg.optimizer.state_dict(),
        'iter': self.current_learning_iteration,
        'infos': infos,
    }, path)
def load(self, path, load_optimizer=True):
	loaded_dict = torch.load(path)
	self.alg.actor_critic.load_state_dict(loaded_dict['model_state_dict'])
	if load_optimizer:
		self.alg.optimizer.load_state_dict(loaded_dict['optimizer_state_dict'])
	self.current_learning_iteration = loaded_dict['iter']
	return loaded_dict['infos'] 
def get_inference_policy(self, device=None):
    self.alg.actor_critic.eval()
    if device is not None:
        self.alg.actor_critic.to(device)
    return self.alg.actor_critic.act_inference
  • save()/load() checkpoint and restore the model weights, optimizer state and iteration counter; get_inference_policy() switches the network to eval mode and returns act_inference, the deterministic policy function used for deployment.
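The checkpoint is a plain dict with four keys. The round-trip below shows the layout save()/load() agree on, using a stand-in linear model rather than the real Actor-Critic; the model and paths here are illustrative only.

```python
import os
import tempfile
import torch

# Stand-in for alg.actor_critic / alg.optimizer.
model = torch.nn.Linear(4, 2)
opt = torch.optim.Adam(model.parameters())
path = os.path.join(tempfile.mkdtemp(), "model_0.pt")

# Same dict layout as OnPolicyRunner.save().
torch.save({
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": opt.state_dict(),
    "iter": 0,
    "infos": None,
}, path)

# Same keys that OnPolicyRunner.load() reads back.
ckpt = torch.load(path)
print(sorted(ckpt.keys()))
```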

3-6 End-to-end flow summary
VecEnv(多个并行环境)
       │
       ▼
采样数据 (obs / privileged_obs / action / reward / done)
       │
       ▼
RolloutStorage.add_transitions()
       │
       ▼
compute_returns()  # GAE Advantage
       │
       ▼
mini_batch_generator()
       │
       ▼
PPO update (Actor-Critic)
       │
       ▼
更新策略网络和 Critic 网络


Summary

  • This installment covered the Asymmetric Actor-Critic idea, RolloutStorage's data buffering and GAE advantage computation, and VecEnv's parallel-environment management, walking through data collection and mini-batch generation in the PPO training loop.
  • With these first three installments the core rsl_rl source is fully covered; next up is unitree_rl_gym, which builds on rsl_rl and is where the training setup and reward functions are actually written.
  • Corrections are welcome — thanks for reading!