前言
-
Unitree RL GYM 是一个开源的 基于 Unitree 机器人强化学习(Reinforcement Learning, RL)控制示例项目 ,用于训练、测试和部署四足机器人控制策略。该仓库支持多种 Unitree 机器人型号,包括 Go2、H1、H1_2 和 G1 。仓库地址

-
本系列将着手解析整个仓库的核心代码与算法实现和训练教程。此系列默认读者拥有一定的强化学习基础和代码基础,故在部分原理和基础代码逻辑不做解释,对强化学习基础感兴趣的读者可以阅读我的入门系列:
- 第一期: 【浅显易懂理解强化学习】(一)Q-Learning原来是查表法-CSDN博客
- 第二期: 【浅显易懂理解强化学习】(二):Sarsa,保守派的胜利-CSDN博客
- 第三期:【浅显易懂理解强化学习】(三):DQN:当查表法装上大脑-CSDN博客
- 第四期:【浅显易懂理解强化学习】(四):Policy Gradients玩转策略采样-CSDN博客
- 第五期:【浅显易懂理解强化学习】(五):Actor-Critic与A3C,多线程的完全胜利-CSDN博客
- 第六期:【浅显易懂理解强化学习】(六):DDPG与TD3集百家之长-CSDN博客
- 第七期:【浅显易懂理解强化学习】(七):PPO,策略更新的安全阀-CSDN博客
-
阅读本系列的前置知识:
python语法,明白面向对象的封装pytorch基础使用- 神经网络基础知识
- 强化学习基础知识,至少了解
Policy Gradient、Actor-Critic和PPO
-
本系列:
-
本期将讲解
rsl_rl仓库的ActorCritic网络和ActorCriticRecurrent网络的python实现。
1 ActorCritic类
1-1 ActorCritic 网络回顾
- 我们在浅显易懂强化学习入门的第五期提到过
ActorCritic网络,这里快速的回顾一下核心的内容: - 在强化学习中,Actor-Critic 是一种同时学习策略(Policy)和价值函数(Value Function)的框架。公式上可以表达为: π θ ( a ∣ s ) ( 策略网络 ) \pi_\theta(a|s)(策略网络) πθ(a∣s)(策略网络) V ϕ ( s ) ( 价值网络 ) V_\phi(s)(价值网络) Vϕ(s)(价值网络)其中:
- Actor(策略网络) :负责告诉你"在当前状态 s s s 下应该做什么动作 a a a" a t ∼ π θ ( a t ∣ s t ) a_t \sim \pi_\theta(a_t|s_t) at∼πθ(at∣st)
- 输出动作分布(连续动作通常是高斯分布 N ( μ , σ 2 ) \mathcal{N}(\mu, \sigma^2) N(μ,σ2)
- 在训练中采样动作增加探索,推理时取均值动作减少随机性
- Critic(价值网络) :负责告诉你"当前状态 s s s 的好坏",即状态价值 V ϕ ( s t ) ≈ E [ R t ∣ s t ] V_\phi(s_t) \approx \mathbb{E}[R_t|s_t] Vϕ(st)≈E[Rt∣st]
- 用于计算优势函数 A t = R t − V ϕ ( s t ) A_t = R_t - V_\phi(s_t) At=Rt−Vϕ(st)
- 优势函数衡量某个动作比平均水平好多少,是 Actor 更新策略的参考
- Actor(策略网络) :负责告诉你"在当前状态 s s s 下应该做什么动作 a a a" a t ∼ π θ ( a t ∣ s t ) a_t \sim \pi_\theta(a_t|s_t) at∼πθ(at∣st)
- 策略更新依赖价值
- Actor 的梯度来自策略梯度定理: ∇ θ J ( θ ) = E t [ A t ∇ θ log π θ ( a t ∣ s t ) ] \nabla_\theta J(\theta) = \mathbb{E}t \big[ A_t \nabla\theta \log \pi_\theta(a_t|s_t) \big] ∇θJ(θ)=Et[At∇θlogπθ(at∣st)]
- Critic 的目标是最小化 均方误差 : L critic = E t [ ( V ϕ ( s t ) − R t ) 2 ] L_\text{critic} = \mathbb{E}t \big[(V\phi(s_t) - R_t)^2\big] Lcritic=Et[(Vϕ(st)−Rt)2]
- 优势函数的作用
- 提高训练稳定性:仅更新比平均水平更好的动作
- 减少策略梯度的方差,使训练收敛更快
- 总结:
- Actor 决策,Critic 评估,优势函数桥接二者,使策略在训练中既有方向又有参照。
1-2 完整代码一览
- 我们打开
rsl_rl仓库
shell
git clone https://github.com/leggedrobotics/rsl_rl.git
cd rsl_rl
git checkout v1.0.2
- 在项目根目录下的
modules文件夹下可以找到actor_critic.py
python
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# Copyright (c) 2021 ETH Zurich, Nikita Rudin
import numpy as np
import torch
import torch.nn as nn
from torch.distributions import Normal
from torch.nn.modules import rnn
class ActorCritic(nn.Module):
is_recurrent = False
def __init__(self, num_actor_obs,
num_critic_obs,
num_actions,
actor_hidden_dims=[256, 256, 256],
critic_hidden_dims=[256, 256, 256],
activation='elu',
init_noise_std=1.0,
**kwargs):
if kwargs:
print("ActorCritic.__init__ got unexpected arguments, which will be ignored: " + str([key for key in kwargs.keys()]))
super(ActorCritic, self).__init__()
activation = get_activation(activation)
mlp_input_dim_a = num_actor_obs
mlp_input_dim_c = num_critic_obs
# Policy
actor_layers = []
actor_layers.append(nn.Linear(mlp_input_dim_a, actor_hidden_dims[0]))
actor_layers.append(activation)
for l in range(len(actor_hidden_dims)):
if l == len(actor_hidden_dims) - 1:
actor_layers.append(nn.Linear(actor_hidden_dims[l], num_actions))
else:
actor_layers.append(nn.Linear(actor_hidden_dims[l], actor_hidden_dims[l + 1]))
actor_layers.append(activation)
self.actor = nn.Sequential(*actor_layers)
# Value function
critic_layers = []
critic_layers.append(nn.Linear(mlp_input_dim_c, critic_hidden_dims[0]))
critic_layers.append(activation)
for l in range(len(critic_hidden_dims)):
if l == len(critic_hidden_dims) - 1:
critic_layers.append(nn.Linear(critic_hidden_dims[l], 1))
else:
critic_layers.append(nn.Linear(critic_hidden_dims[l], critic_hidden_dims[l + 1]))
critic_layers.append(activation)
self.critic = nn.Sequential(*critic_layers)
print(f"Actor MLP: {self.actor}")
print(f"Critic MLP: {self.critic}")
# Action noise
self.std = nn.Parameter(init_noise_std * torch.ones(num_actions))
self.distribution = None
# disable args validation for speedup
Normal.set_default_validate_args = False
# seems that we get better performance without init
# self.init_memory_weights(self.memory_a, 0.001, 0.)
# self.init_memory_weights(self.memory_c, 0.001, 0.)
@staticmethod
# not used at the moment
def init_weights(sequential, scales):
[torch.nn.init.orthogonal_(module.weight, gain=scales[idx]) for idx, module in
enumerate(mod for mod in sequential if isinstance(mod, nn.Linear))]
def reset(self, dones=None):
pass
def forward(self):
raise NotImplementedError
@property
def action_mean(self):
return self.distribution.mean
@property
def action_std(self):
return self.distribution.stddev
@property
def entropy(self):
return self.distribution.entropy().sum(dim=-1)
def update_distribution(self, observations):
mean = self.actor(observations)
self.distribution = Normal(mean, mean*0. + self.std)
def act(self, observations, **kwargs):
self.update_distribution(observations)
return self.distribution.sample()
def get_actions_log_prob(self, actions):
return self.distribution.log_prob(actions).sum(dim=-1)
def act_inference(self, observations):
actions_mean = self.actor(observations)
return actions_mean
def evaluate(self, critic_observations, **kwargs):
value = self.critic(critic_observations)
return value
def get_activation(act_name):
if act_name == "elu":
return nn.ELU()
elif act_name == "selu":
return nn.SELU()
elif act_name == "relu":
return nn.ReLU()
elif act_name == "crelu":
return nn.ReLU()
elif act_name == "lrelu":
return nn.LeakyReLU()
elif act_name == "tanh":
return nn.Tanh()
elif act_name == "sigmoid":
return nn.Sigmoid()
else:
print("invalid activation function!")
return None
- 我们接下来看每一个函数分别实现了什么
1-3 初始化函数
python
class ActorCritic(nn.Module):
is_recurrent = False
def __init__(self, num_actor_obs,
num_critic_obs,
num_actions,
actor_hidden_dims=[256, 256, 256],
critic_hidden_dims=[256, 256, 256],
activation='elu',
init_noise_std=1.0,
**kwargs):
if kwargs:
print("ActorCritic.__init__ got unexpected arguments, which will be ignored: " + str([key for key in kwargs.keys()]))
super(ActorCritic, self).__init__()
activation = get_activation(activation)
mlp_input_dim_a = num_actor_obs
mlp_input_dim_c = num_critic_obs
# Policy
actor_layers = []
actor_layers.append(nn.Linear(mlp_input_dim_a, actor_hidden_dims[0]))
actor_layers.append(activation)
for l in range(len(actor_hidden_dims)):
if l == len(actor_hidden_dims) - 1:
actor_layers.append(nn.Linear(actor_hidden_dims[l], num_actions))
else:
actor_layers.append(nn.Linear(actor_hidden_dims[l], actor_hidden_dims[l + 1]))
actor_layers.append(activation)
self.actor = nn.Sequential(*actor_layers)
# Value function
critic_layers = []
critic_layers.append(nn.Linear(mlp_input_dim_c, critic_hidden_dims[0]))
critic_layers.append(activation)
for l in range(len(critic_hidden_dims)):
if l == len(critic_hidden_dims) - 1:
critic_layers.append(nn.Linear(critic_hidden_dims[l], 1))
else:
critic_layers.append(nn.Linear(critic_hidden_dims[l], critic_hidden_dims[l + 1]))
critic_layers.append(activation)
self.critic = nn.Sequential(*critic_layers)
print(f"Actor MLP: {self.actor}")
print(f"Critic MLP: {self.critic}")
# Action noise
self.std = nn.Parameter(init_noise_std * torch.ones(num_actions))
self.distribution = None
# disable args validation for speedup
Normal.set_default_validate_args = False
# seems that we get better performance without init
# self.init_memory_weights(self.memory_a, 0.001, 0.)
# self.init_memory_weights(self.memory_c, 0.001, 0.)
1-3-1 超参数
python
def __init__(self,
num_actor_obs,
num_critic_obs,
num_actions,
actor_hidden_dims=[256, 256, 256],
critic_hidden_dims=[256, 256, 256],
activation='elu',
init_noise_std=1.0,
**kwargs):
- 我们来看看超参数
num_actor_obs:Actor 网络的输入维度,也就是策略网络可以看到的观测状态数量num_critic_obs:Critic 网络的输入维度,也就是价值网络看到的观测状态数量num_actions:动作空间维度actor_hidden_dims:Actor 网络每一隐藏层的神经元个数critic_hidden_dims:Critic 网络每一隐藏层神经元个数activation:隐藏层激活函数,对应get_activation()函数提供的几个激活函数,"elu","selu","relu","crelu","lrelu","tanh","sigmoid"init_noise_std:Actor 输出动作的初始噪声标准差 σ \sigma σ
1-3-2 构建 Actor MLP与 Critic MLP
python
# Policy
actor_layers = []
actor_layers.append(nn.Linear(mlp_input_dim_a, actor_hidden_dims[0]))
actor_layers.append(activation)
for l in range(len(actor_hidden_dims)):
if l == len(actor_hidden_dims) - 1:
actor_layers.append(nn.Linear(actor_hidden_dims[l], num_actions))
else:
actor_layers.append(nn.Linear(actor_hidden_dims[l], actor_hidden_dims[l + 1]))
actor_layers.append(activation)
self.actor = nn.Sequential(*actor_layers)
# Value function
critic_layers = []
critic_layers.append(nn.Linear(mlp_input_dim_c, critic_hidden_dims[0]))
critic_layers.append(activation)
for l in range(len(critic_hidden_dims)):
if l == len(critic_hidden_dims) - 1:
critic_layers.append(nn.Linear(critic_hidden_dims[l], 1))
else:
critic_layers.append(nn.Linear(critic_hidden_dims[l], critic_hidden_dims[l + 1]))
critic_layers.append(activation)
self.critic = nn.Sequential(*critic_layers)
print(f"Actor MLP: {self.actor}")
print(f"Critic MLP: {self.critic}")
- 我们直接用一个表格来描述这两个网络
| 特征 | Actor 网络 | Critic 网络 |
|---|---|---|
| 输入 | num_actor_obs |
num_critic_obs(可不同) |
| 输出 | 动作向量 μ θ ( s ) \mu_\theta(s) μθ(s) | 状态价值标量 V ϕ ( s ) V_\phi(s) Vϕ(s) |
| 输出层维度 | 动作空间维度 | 1 |
| 功能 | 选择动作(策略) | 评估状态(价值函数) |
| 使用 | 策略梯度更新 | 计算优势函数指导 Actor |
- 一个输出当前状态对应的动作向量 μ θ ( s ) \mu_\theta(s) μθ(s),一个输出评估当前状态的状态价值标量 V ϕ ( s ) V_\phi(s) Vϕ(s)
1-3-3 动作噪声初始化
python
# Action noise
self.std = nn.Parameter(init_noise_std * torch.ones(num_actions))
self.distribution = None
Normal.set_default_validate_args = False
- 在连续动作空间中,动作采样服从正态分布: a ∼ N ( μ ( s ) , σ 2 ) a \sim \mathcal{N}(\mu(s), \sigma^2) a∼N(μ(s),σ2)
- 其中:
self.std:动作分布的标准差 σ \sigma σ,控制动作探索self.distribution:动作分布对象Normal(mean, std),用于采样、计算 log_prob 和熵
1-4 损失函数计算辅助工具函数
python
@property
def action_mean(self):
return self.distribution.mean
@property
def action_std(self):
return self.distribution.stddev
@property
def entropy(self):
return self.distribution.entropy().sum(dim=-1)
- 这三个都是一些工具函数:
action_mean():返回当前状态下动作分布的均值 μ θ ( s ) \mu_\theta(s) μθ(s)action_std():返回当前状态下动作分布标准差 σ \sigma σentropy():返回动作分布熵 H [ π θ ] H[\pi_\theta] H[πθ]- 上面这三个均在
algorithms/ppo.py计算损失函数的时候被调用
python
mu_batch = self.actor_critic.action_mean
sigma_batch = self.actor_critic.action_std
entropy_batch = self.actor_critic.entropy
1-5 核心函数act()
python
def update_distribution(self, observations):
mean = self.actor(observations)
self.distribution = Normal(mean, mean*0. + self.std)
def act(self, observations, **kwargs):
self.update_distribution(observations)
return self.distribution.sample()
- 获取Actor 网络输出动作均值: μ θ ( s t ) = Actor ( s t ) \mu_\theta(s_t) = \text{Actor}(s_t) μθ(st)=Actor(st)
- 并和标准差组合成高斯动作分布: a t ∼ N ( μ θ ( s t ) , σ 2 I ) a_t \sim \mathcal{N}\big(\mu_\theta(s_t), \sigma^2 \mathbf{I}\big) at∼N(μθ(st),σ2I)
- 注意:这里
mean*0.只是生成与mean形状相同的零张量,保证广播正确。 algorithms/ppo.py中每一个batch训练的第一步就是Actor前向计算,重新计算 当前策略的动作分布
python
self.actor_critic.act(obs_batch, masks=masks_batch, hidden_states=hid_states_batch[0])
1-6 get_actions_log_prob()
python
def get_actions_log_prob(self, actions):
return self.distribution.log_prob(actions).sum(dim=-1)
- 作用 :计算动作的对数概率 log π θ ( a t ∣ s t ) \log \pi_\theta(a_t|s_t) logπθ(at∣st)
algorithms/ppo.py中直接调用:
python
actions_log_prob_batch = self.actor_critic.get_actions_log_prob(actions_batch)
1-7 核心函数evaluate
python
def evaluate(self, critic_observations, **kwargs):
value = self.critic(critic_observations)
return value
- Critic 前向计算状态价值 V ϕ ( s ) V_\phi(s) Vϕ(s)
- 计算优势函数 A t = R t − V ϕ ( s t ) A_t = R_t - V_\phi(s_t) At=Rt−Vϕ(st)
algorithms/ppo.py中直接调用:
python
value_batch = self.actor_critic.evaluate(critic_obs_batch)
1-8 网络结构
输入状态 s
MLP 隐藏层
Actor 输出 μ(s)
Critic 输出 V(s)
2 ActorCriticRecurrent
2-1 前置知识补充:RNN
- RNN 全称 :Recurrent Neural Network(循环神经网络)
- 功能:处理序列数据或时间序列任务,例如自然语言、机器人控制、股价预测等
- 核心思想:当前时刻的输出不仅依赖当前输入,还依赖上一时刻的隐藏状态,形成网络内部的"记忆"。
- 数学公式 : h t = f ( W x h x t + W h h h t − 1 + b h ) h_t = f(W_{xh} x_t + W_{hh} h_{t-1} + b_h) ht=f(Wxhxt+Whhht−1+bh)其中:
- x t x_t xt:当前输入(比如观测向量 o t o_t ot)
- h t − 1 h_{t-1} ht−1:上一时刻隐藏状态(记忆历史信息),其实就是是对历史信息的总结,相当于网络的"记忆"。
- h t h_t ht:当前隐藏状态
- f f f:非线性激活函数(
tanh,ReLU等)
- 主要优势 :
- 记忆历史信息:通过隐藏状态保留序列中重要的过去信息
- 处理部分可观测问题(POMDP):可以推断当前观测中缺失的状态信息
- 增强时间依赖建模能力:动作和价值可以依赖过去多步信息,而不仅仅是当前观测
2-2 RNN和经验池的区别
- **经验池(Replay Buffer)**是强化学习中 存储过去经验(state, action, reward, next_state)的容器。
- 它的作用是 离线采样历史经验进行训练,打破时间相关性,提高训练稳定性。
- 例子:在 DQN 或 PPO 中,经验池可以存储上千条轨迹,然后随机采样 mini-batch 更新网络参数。
| 特征 | RNN 记忆 | 经验池(Replay Buffer) |
|---|---|---|
| 数据类型 | 隐藏状态(隐藏的网络向量) | 完整的状态、动作、奖励轨迹 |
| 功能 | 捕捉时间依赖,生成连续决策 | 存储历史经验,用于训练网络 |
| 更新方式 | 随每个时间步前向传播动态更新 | 离线采样随机更新 |
| 生命周期 | 每条轨迹或 episode 内有效 | 可跨多个 episode 持久存在 |
- RNN 是网络内部的"记忆",帮助做连续决策;经验池是训练中的"记忆库",帮助算法学习更稳健的策略。
2-3 ActorCriticRecurrent
ActorCriticRecurrent是 ActorCritic 的 RNN 版本 ,引入循环神经网络使得 Actor 和 Critic 能够记住历史信息 ,适合处理 部分可观测环境(POMDP) 或时间依赖的任务。- 主要特点:
- Actor 和 Critic 都引入了"记忆模块" (RNN/LSTM/GRU),各自的职责仍不变:
- Actor 的记忆输出 → 决定动作分布( μ θ ( s t ) \mu_\theta(s_t) μθ(st))
- Critic 的记忆输出 → 估算当前状态的价值( V ϕ ( s t ) V_\phi(s_t) Vϕ(st))
- 网络会随着时间不断更新记忆
- 每个时间步:网络把当前观测 + 之前的记忆 → 更新隐藏状态 h t h_t ht
- 说人话就是网络不断把新信息加入记忆,让它对整个动作序列有上下文感知。
- 每次新 episode 会清空记忆
- 用
reset(dones)把对应位置的隐藏状态清零 - 说人话就是每新的一轮会把记忆清空,不让上一局的经验干扰当前决策。
- 用
- 训练和推理用法不同
- 训练(批量更新):RNN 会处理整个序列,利用 masks 和保存的隐藏状态做前向传播
- 推理(收集经验):每步用上一时刻的隐藏状态生成动作
- Actor 和 Critic 都引入了"记忆模块" (RNN/LSTM/GRU),各自的职责仍不变:
- 说人话:RNN 给强化学习网络增加了"记忆",让策略和价值函数能利用过去信息做决策,从而在复杂机器人控制或时间序列任务中表现更稳健。
2-4 完整代码一览
python
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# Copyright (c) 2021 ETH Zurich, Nikita Rudin
import numpy as np
import torch
import torch.nn as nn
from torch.distributions import Normal
from torch.nn.modules import rnn
from .actor_critic import ActorCritic, get_activation
from rsl_rl.utils import unpad_trajectories
class ActorCriticRecurrent(ActorCritic):
is_recurrent = True
def __init__(self, num_actor_obs,
num_critic_obs,
num_actions,
actor_hidden_dims=[256, 256, 256],
critic_hidden_dims=[256, 256, 256],
activation='elu',
rnn_type='lstm',
rnn_hidden_size=256,
rnn_num_layers=1,
init_noise_std=1.0,
**kwargs):
if kwargs:
print("ActorCriticRecurrent.__init__ got unexpected arguments, which will be ignored: " + str(kwargs.keys()),)
super().__init__(num_actor_obs=rnn_hidden_size,
num_critic_obs=rnn_hidden_size,
num_actions=num_actions,
actor_hidden_dims=actor_hidden_dims,
critic_hidden_dims=critic_hidden_dims,
activation=activation,
init_noise_std=init_noise_std)
activation = get_activation(activation)
self.memory_a = Memory(num_actor_obs, type=rnn_type, num_layers=rnn_num_layers, hidden_size=rnn_hidden_size)
self.memory_c = Memory(num_critic_obs, type=rnn_type, num_layers=rnn_num_layers, hidden_size=rnn_hidden_size)
print(f"Actor RNN: {self.memory_a}")
print(f"Critic RNN: {self.memory_c}")
def reset(self, dones=None):
self.memory_a.reset(dones)
self.memory_c.reset(dones)
def act(self, observations, masks=None, hidden_states=None):
input_a = self.memory_a(observations, masks, hidden_states)
return super().act(input_a.squeeze(0))
def act_inference(self, observations):
input_a = self.memory_a(observations)
return super().act_inference(input_a.squeeze(0))
def evaluate(self, critic_observations, masks=None, hidden_states=None):
input_c = self.memory_c(critic_observations, masks, hidden_states)
return super().evaluate(input_c.squeeze(0))
def get_hidden_states(self):
return self.memory_a.hidden_states, self.memory_c.hidden_states
class Memory(torch.nn.Module):
def __init__(self, input_size, type='lstm', num_layers=1, hidden_size=256):
super().__init__()
# RNN
rnn_cls = nn.GRU if type.lower() == 'gru' else nn.LSTM
self.rnn = rnn_cls(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers)
self.hidden_states = None
def forward(self, input, masks=None, hidden_states=None):
batch_mode = masks is not None
if batch_mode:
# batch mode (policy update): need saved hidden states
if hidden_states is None:
raise ValueError("Hidden states not passed to memory module during policy update")
out, _ = self.rnn(input, hidden_states)
out = unpad_trajectories(out, masks)
else:
# inference mode (collection): use hidden states of last step
out, self.hidden_states = self.rnn(input.unsqueeze(0), self.hidden_states)
return out
def reset(self, dones=None):
# When the RNN is an LSTM, self.hidden_states_a is a list with hidden_state and cell_state
for hidden_state in self.hidden_states:
hidden_state[..., dones, :] = 0.0
- 我们主要看的是和
ActorCritic不同的地方
2-5 Memory类
python
class Memory(torch.nn.Module):
def __init__(self, input_size, type='lstm', num_layers=1, hidden_size=256):
super().__init__()
# RNN
rnn_cls = nn.GRU if type.lower() == 'gru' else nn.LSTM
self.rnn = rnn_cls(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers)
self.hidden_states = None
def forward(self, input, masks=None, hidden_states=None):
batch_mode = masks is not None
if batch_mode:
# batch mode (policy update): need saved hidden states
if hidden_states is None:
raise ValueError("Hidden states not passed to memory module during policy update")
out, _ = self.rnn(input, hidden_states)
out = unpad_trajectories(out, masks)
else:
# inference mode (collection): use hidden states of last step
out, self.hidden_states = self.rnn(input.unsqueeze(0), self.hidden_states)
return out
def reset(self, dones=None):
# When the RNN is an LSTM, self.hidden_states_a is a list with hidden_state and cell_state
for hidden_state in self.hidden_states:
hidden_state[..., dones, :] = 0.0
Memory是 封装 RNN 的模块 ,主要目的是:- 在 ActorCriticRecurrent 中 给 Actor 和 Critic 提供隐藏状态记忆
- 区分 训练(batch mode) 和 推理(inference mode) 的前向计算
- 管理 隐藏状态的初始化和重置
2-5-1 构造函数-继承自torch.nn.Module
python
def __init__(self, input_size, type='lstm', num_layers=1, hidden_size=256):
super().__init__()
rnn_cls = nn.GRU if type.lower() == 'gru' else nn.LSTM
self.rnn = rnn_cls(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers)
self.hidden_states = None
input_size:RNN 每一步的输入维度(比如观测向量的维度)type:RNN 类型,支持'lstm'或'gru'num_layers:RNN 堆叠层数hidden_size:RNN 隐藏状态维度self.hidden_states:保存 RNN 的隐藏状态,用于推理模式下连续更新
2-5-2 前向函数 forward
python
def forward(self, input, masks=None, hidden_states=None):
batch_mode = masks is not None
if batch_mode:
# batch mode (policy update)
if hidden_states is None:
raise ValueError("Hidden states not passed to memory module during policy update")
out, _ = self.rnn(input, hidden_states)
out = unpad_trajectories(out, masks)
else:
# inference mode (collection)
out, self.hidden_states = self.rnn(input.unsqueeze(0), self.hidden_states)
return out
- batch_mode :当
masks不为空时,表示训练阶段(一次性处理一个 batch 的多条轨迹)- 需要传入
hidden_states(前一 batch 的隐藏状态) - RNN 处理序列 → 输出隐藏状态序列
unpad_trajectories:去掉填充的时间步,保证序列长度正确
- 需要传入
- inference mode :当
masks为空时,表示推理或收集经验- 输入是当前时间步的观测
input.unsqueeze(0):增加时间维度- 更新
self.hidden_states→ 下次继续使用 - 输出当前时间步的 RNN 输出
-
说人话:
- 训练模式 → 一次处理整个序列,batch 更新
- 推理模式 → 一步步走,用上一步的隐藏状态记忆历史
2-5-3 重置函数 `reset()
python
def reset(self, dones=None):
for hidden_state in self.hidden_states:
hidden_state[..., dones, :] = 0.0
dones是布尔向量,表示哪些环境/样本 episode 已经结束
- 将对应样本的隐藏状态置零
2-6 隐藏状态在训练模式的流程(批量训练)
- 训练一次通常会把一整个 batch(好几条轨迹)一次性送进网络。
- RNN 会处理整个序列:
- 每个时间步的输入:当前状态 + 上一时间步隐藏状态。
- 输出隐藏状态序列 h t h_t ht。
- 去掉 padding(因为不同轨迹长度不一样,短的轨迹要填充)。
- 输出送到 MLP:
- Actor 得到动作分布。
- Critic 得到状态价值。
- 用这些输出计算策略梯度和价值函数损失,更新网络参数。
- 注意:隐藏状态在训练中是一次性处理整段序列,下一 batch 可以继续用上一次训练的隐藏状态。
2-7 隐藏状态在推理模式的流程(收集经验 / 实际控制)
- 每步只处理 当前状态。
- RNN 用 上一步的隐藏状态 h t − 1 h_{t-1} ht−1 来更新当前隐藏状态 h t h_t ht。
- 送入 MLP:
- Actor 输出动作 μ ( s t ) μ(s_t) μ(st)。
- Critic 可选地输出状态价值 V ( s t ) V(s_t) V(st)。
- 执行动作,环境给奖励和下一个状态。
- reset(dones) 的作用 :
- 某些环境已经结束了(done=True),对应的隐藏状态要清零。
- 避免上一轮记忆影响下一轮决策。
2-6 网络结构
输入状态 s
RNN / Memory
MLP 隐藏层
Actor 输出 μ(s)
Critic 输出 V(s)
隐藏状态 h_t

小结
- 本期主要解析了
rsl_rl仓库中 ActorCritic 与 ActorCriticRecurrent 的 Python 实现,回顾了 Actor-Critic 的核心原理,重点讲解了 ActorCriticRecurrent 引入 RNN/Memory 模块以增强网络对历史信息的记忆能力,区分了训练和推理模式下隐藏状态的处理,并对网络构建、动作采样、价值评估等函数实现进行了详细剖析,为理解复杂机器人控制任务中的策略与价值网络打下基础。 - 如有错误,欢迎指出!感谢观看