[Unitree Robot Reinforcement Learning] (Part 2): Python Implementation and Analysis of the ActorCritic and ActorCriticRecurrent Networks

Preface


1 The ActorCritic Class

1-1 Recap of the Actor-Critic Network
  • We covered the Actor-Critic network in the fifth installment of the beginner-friendly reinforcement learning series; here is a quick recap of the core ideas:
  • In reinforcement learning, Actor-Critic is a framework that learns a policy and a value function at the same time: $\pi_\theta(a|s)$ (the policy network) and $V_\phi(s)$ (the value network), where:
    • Actor (policy network): tells you "which action $a$ to take in the current state $s$": $a_t \sim \pi_\theta(a_t|s_t)$
      • Outputs an action distribution (for continuous actions, usually a Gaussian $\mathcal{N}(\mu, \sigma^2)$)
      • Actions are sampled during training to encourage exploration; at inference time the mean action is used to reduce randomness
    • Critic (value network): tells you "how good the current state $s$ is", i.e. the state value $V_\phi(s_t) \approx \mathbb{E}[R_t \mid s_t]$
      • Used to compute the advantage function $A_t = R_t - V_\phi(s_t)$
      • The advantage measures how much better an action is than average, and is what the Actor's policy update relies on
  • The policy update depends on the value estimate
    • The Actor's gradient comes from the policy gradient theorem: $\nabla_\theta J(\theta) = \mathbb{E}_t \big[ A_t \nabla_\theta \log \pi_\theta(a_t|s_t) \big]$
    • The Critic's objective is to minimize a mean squared error: $L_\text{critic} = \mathbb{E}_t \big[ (V_\phi(s_t) - R_t)^2 \big]$
  • The role of the advantage function
    • Improves training stability: only actions better than average are reinforced
    • Reduces the variance of the policy gradient, so training converges faster
  • Summary:
    • The Actor decides, the Critic evaluates, and the advantage function bridges the two, giving the policy both a direction and a baseline during training. A minimal sketch of these two update rules follows.
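  • As a minimal sketch of the two losses above (not rsl_rl's actual PPO code; the batch values are made up for illustration), the computation looks like this:
python
import torch

# Made-up per-step quantities for a tiny batch
returns = torch.tensor([1.0, 0.5, 2.0])        # R_t
values = torch.tensor([0.8, 0.6, 1.5])         # V_phi(s_t) from the Critic
log_probs = torch.tensor([-1.2, -0.9, -1.5])   # log pi_theta(a_t | s_t) from the Actor

advantages = returns - values                   # A_t = R_t - V_phi(s_t)
actor_loss = -(advantages * log_probs).mean()   # policy-gradient surrogate (reinforce better-than-average actions)
critic_loss = ((values - returns) ** 2).mean()  # L_critic = E[(V_phi(s_t) - R_t)^2]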

1-2 The Complete Code at a Glance
  • Clone the rsl_rl repository:
shell
git clone https://github.com/leggedrobotics/rsl_rl.git
cd rsl_rl
git checkout v1.0.2
  • actor_critic.py can be found in the modules folder under the project root
python
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# 
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# Copyright (c) 2021 ETH Zurich, Nikita Rudin

import numpy as np

import torch
import torch.nn as nn
from torch.distributions import Normal
from torch.nn.modules import rnn

class ActorCritic(nn.Module):
    is_recurrent = False
    def __init__(self,  num_actor_obs,
                        num_critic_obs,
                        num_actions,
                        actor_hidden_dims=[256, 256, 256],
                        critic_hidden_dims=[256, 256, 256],
                        activation='elu',
                        init_noise_std=1.0,
                        **kwargs):
        if kwargs:
            print("ActorCritic.__init__ got unexpected arguments, which will be ignored: " + str([key for key in kwargs.keys()]))
        super(ActorCritic, self).__init__()

        activation = get_activation(activation)

        mlp_input_dim_a = num_actor_obs
        mlp_input_dim_c = num_critic_obs

        # Policy
        actor_layers = []
        actor_layers.append(nn.Linear(mlp_input_dim_a, actor_hidden_dims[0]))
        actor_layers.append(activation)
        for l in range(len(actor_hidden_dims)):
            if l == len(actor_hidden_dims) - 1:
                actor_layers.append(nn.Linear(actor_hidden_dims[l], num_actions))
            else:
                actor_layers.append(nn.Linear(actor_hidden_dims[l], actor_hidden_dims[l + 1]))
                actor_layers.append(activation)
        self.actor = nn.Sequential(*actor_layers)

        # Value function
        critic_layers = []
        critic_layers.append(nn.Linear(mlp_input_dim_c, critic_hidden_dims[0]))
        critic_layers.append(activation)
        for l in range(len(critic_hidden_dims)):
            if l == len(critic_hidden_dims) - 1:
                critic_layers.append(nn.Linear(critic_hidden_dims[l], 1))
            else:
                critic_layers.append(nn.Linear(critic_hidden_dims[l], critic_hidden_dims[l + 1]))
                critic_layers.append(activation)
        self.critic = nn.Sequential(*critic_layers)

        print(f"Actor MLP: {self.actor}")
        print(f"Critic MLP: {self.critic}")

        # Action noise
        self.std = nn.Parameter(init_noise_std * torch.ones(num_actions))
        self.distribution = None
        # disable args validation for speedup
        Normal.set_default_validate_args = False
        
        # seems that we get better performance without init
        # self.init_memory_weights(self.memory_a, 0.001, 0.)
        # self.init_memory_weights(self.memory_c, 0.001, 0.)

    @staticmethod
    # not used at the moment
    def init_weights(sequential, scales):
        [torch.nn.init.orthogonal_(module.weight, gain=scales[idx]) for idx, module in
         enumerate(mod for mod in sequential if isinstance(mod, nn.Linear))]


    def reset(self, dones=None):
        pass

    def forward(self):
        raise NotImplementedError
    
    @property
    def action_mean(self):
        return self.distribution.mean

    @property
    def action_std(self):
        return self.distribution.stddev
    
    @property
    def entropy(self):
        return self.distribution.entropy().sum(dim=-1)

    def update_distribution(self, observations):
        mean = self.actor(observations)
        self.distribution = Normal(mean, mean*0. + self.std)

    def act(self, observations, **kwargs):
        self.update_distribution(observations)
        return self.distribution.sample()
    
    def get_actions_log_prob(self, actions):
        return self.distribution.log_prob(actions).sum(dim=-1)

    def act_inference(self, observations):
        actions_mean = self.actor(observations)
        return actions_mean

    def evaluate(self, critic_observations, **kwargs):
        value = self.critic(critic_observations)
        return value

def get_activation(act_name):
    if act_name == "elu":
        return nn.ELU()
    elif act_name == "selu":
        return nn.SELU()
    elif act_name == "relu":
        return nn.ReLU()
    elif act_name == "crelu":
        return nn.ReLU()
    elif act_name == "lrelu":
        return nn.LeakyReLU()
    elif act_name == "tanh":
        return nn.Tanh()
    elif act_name == "sigmoid":
        return nn.Sigmoid()
    else:
        print("invalid activation function!")
        return None
  • Next, let's look at what each function does

1-3 The Initialization Function __init__
python
class ActorCritic(nn.Module):
    is_recurrent = False
    def __init__(self,  num_actor_obs,
                        num_critic_obs,
                        num_actions,
                        actor_hidden_dims=[256, 256, 256],
                        critic_hidden_dims=[256, 256, 256],
                        activation='elu',
                        init_noise_std=1.0,
                        **kwargs):
        if kwargs:
            print("ActorCritic.__init__ got unexpected arguments, which will be ignored: " + str([key for key in kwargs.keys()]))
        super(ActorCritic, self).__init__()

        activation = get_activation(activation)

        mlp_input_dim_a = num_actor_obs
        mlp_input_dim_c = num_critic_obs

        # Policy
        actor_layers = []
        actor_layers.append(nn.Linear(mlp_input_dim_a, actor_hidden_dims[0]))
        actor_layers.append(activation)
        for l in range(len(actor_hidden_dims)):
            if l == len(actor_hidden_dims) - 1:
                actor_layers.append(nn.Linear(actor_hidden_dims[l], num_actions))
            else:
                actor_layers.append(nn.Linear(actor_hidden_dims[l], actor_hidden_dims[l + 1]))
                actor_layers.append(activation)
        self.actor = nn.Sequential(*actor_layers)

        # Value function
        critic_layers = []
        critic_layers.append(nn.Linear(mlp_input_dim_c, critic_hidden_dims[0]))
        critic_layers.append(activation)
        for l in range(len(critic_hidden_dims)):
            if l == len(critic_hidden_dims) - 1:
                critic_layers.append(nn.Linear(critic_hidden_dims[l], 1))
            else:
                critic_layers.append(nn.Linear(critic_hidden_dims[l], critic_hidden_dims[l + 1]))
                critic_layers.append(activation)
        self.critic = nn.Sequential(*critic_layers)

        print(f"Actor MLP: {self.actor}")
        print(f"Critic MLP: {self.critic}")

        # Action noise
        self.std = nn.Parameter(init_noise_std * torch.ones(num_actions))
        self.distribution = None
        # disable args validation for speedup
        Normal.set_default_validate_args = False
        
        # seems that we get better performance without init
        # self.init_memory_weights(self.memory_a, 0.001, 0.)
        # self.init_memory_weights(self.memory_c, 0.001, 0.)

1-3-1 Hyperparameters
python
def __init__(self,  
             num_actor_obs,
             num_critic_obs,
             num_actions,
             actor_hidden_dims=[256, 256, 256],
             critic_hidden_dims=[256, 256, 256],
             activation='elu',
             init_noise_std=1.0,
             **kwargs):
  • Let's go through the hyperparameters:
    • num_actor_obs: input dimension of the Actor network, i.e. the number of observations the policy network can see
    • num_critic_obs: input dimension of the Critic network, i.e. the number of observations the value network can see
    • num_actions: dimension of the action space
    • actor_hidden_dims: number of neurons in each hidden layer of the Actor network
    • critic_hidden_dims: number of neurons in each hidden layer of the Critic network
    • activation: hidden-layer activation function, one of the options supported by get_activation(): "elu", "selu", "relu", "crelu", "lrelu", "tanh", "sigmoid"
    • init_noise_std: initial standard deviation $\sigma$ of the noise on the Actor's output actions
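  • For example, the constructor can be called like this (the observation and action dimensions below are made up, roughly matching a 12-joint quadruped; ActorCritic is assumed to be imported from rsl_rl's modules package):
python
model = ActorCritic(num_actor_obs=48,
                    num_critic_obs=48,
                    num_actions=12,
                    actor_hidden_dims=[256, 256, 256],
                    critic_hidden_dims=[256, 256, 256],
                    activation='elu',
                    init_noise_std=1.0)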

1-3-2 Building the Actor MLP and the Critic MLP
python
# Policy
actor_layers = []
actor_layers.append(nn.Linear(mlp_input_dim_a, actor_hidden_dims[0]))
actor_layers.append(activation)
for l in range(len(actor_hidden_dims)):
	if l == len(actor_hidden_dims) - 1:
		actor_layers.append(nn.Linear(actor_hidden_dims[l], num_actions))
	else:
		actor_layers.append(nn.Linear(actor_hidden_dims[l], actor_hidden_dims[l + 1]))
		actor_layers.append(activation)
self.actor = nn.Sequential(*actor_layers)

# Value function
critic_layers = []
critic_layers.append(nn.Linear(mlp_input_dim_c, critic_hidden_dims[0]))
critic_layers.append(activation)
for l in range(len(critic_hidden_dims)):
	if l == len(critic_hidden_dims) - 1:
		critic_layers.append(nn.Linear(critic_hidden_dims[l], 1))
	else:
		critic_layers.append(nn.Linear(critic_hidden_dims[l], critic_hidden_dims[l + 1]))
		critic_layers.append(activation)
self.critic = nn.Sequential(*critic_layers)

print(f"Actor MLP: {self.actor}")
print(f"Critic MLP: {self.critic}")
  • The two networks are easiest to compare in a table:

| Feature | Actor network | Critic network |
| --- | --- | --- |
| Input | num_actor_obs | num_critic_obs (may differ) |
| Output | action vector $\mu_\theta(s)$ | scalar state value $V_\phi(s)$ |
| Output layer size | action-space dimension | 1 |
| Role | chooses actions (the policy) | evaluates states (the value function) |
| Used for | policy-gradient updates | computing the advantage that guides the Actor |

  • One outputs the action vector $\mu_\theta(s)$ for the current state; the other outputs the scalar state value $V_\phi(s)$ that evaluates the current state. A quick shape check follows.
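  • A quick shape check (the dimensions are made up for illustration; ActorCritic is the class from actor_critic.py above): the Actor maps observations to num_actions means, and the Critic maps them to a single scalar value:
python
import torch

net = ActorCritic(num_actor_obs=48, num_critic_obs=48, num_actions=12)
obs = torch.randn(4, 48)        # a batch of 4 observations
print(net.actor(obs).shape)     # torch.Size([4, 12]) -> action means mu(s)
print(net.critic(obs).shape)    # torch.Size([4, 1])  -> state values V(s)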

1-3-3 Action Noise Initialization
python
 # Action noise
self.std = nn.Parameter(init_noise_std * torch.ones(num_actions))
self.distribution = None
Normal.set_default_validate_args = False
  • In a continuous action space, actions are sampled from a normal distribution: $a \sim \mathcal{N}(\mu(s), \sigma^2)$
  • Where:
    • self.std: the standard deviation $\sigma$ of the action distribution, a learnable parameter that controls exploration
    • self.distribution: the action distribution object Normal(mean, std), used for sampling and for computing log_prob and entropy (a small sketch of this construction follows)
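  • A small sketch of that construction (the shapes are made up): the same learnable std vector is shared across all states, and broadcasting expands it to the batch:
python
import torch
import torch.nn as nn
from torch.distributions import Normal

num_actions = 12
std = nn.Parameter(1.0 * torch.ones(num_actions))   # learnable sigma, as in self.std
mean = torch.zeros(4, num_actions)                   # made-up action means for a batch of 4
dist = Normal(mean, mean * 0. + std)                 # same pattern as update_distribution()
actions = dist.sample()                              # exploration noise around the mean
print(dist.log_prob(actions).shape)                  # torch.Size([4, 12]): per-dimension log-probs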

1-4 Helper Functions for the Loss Computation
python
@property
def action_mean(self):
	return self.distribution.mean

@property
def action_std(self):
	return self.distribution.stddev

@property
def entropy(self):
	return self.distribution.entropy().sum(dim=-1)
  • These are three small utility properties:
  • action_mean: returns the mean $\mu_\theta(s)$ of the action distribution for the current state
  • action_std: returns the standard deviation $\sigma$ of the action distribution for the current state
  • entropy: returns the entropy $H[\pi_\theta]$ of the action distribution
  • All three are used in algorithms/ppo.py when the loss is computed:
python
mu_batch = self.actor_critic.action_mean
sigma_batch = self.actor_critic.action_std
entropy_batch = self.actor_critic.entropy

1-5 The Core Function act()
python
def update_distribution(self, observations):
    mean = self.actor(observations)
    self.distribution = Normal(mean, mean*0. + self.std)

def act(self, observations, **kwargs):
    self.update_distribution(observations)
    return self.distribution.sample()
  • The Actor network outputs the action mean: $\mu_\theta(s_t) = \text{Actor}(s_t)$
  • The mean is then combined with the standard deviation into a Gaussian action distribution: $a_t \sim \mathcal{N}\big(\mu_\theta(s_t), \sigma^2 \mathbf{I}\big)$
  • Note: mean*0. simply produces a zero tensor with the same shape as mean, ensuring the standard deviation broadcasts correctly.
  • In algorithms/ppo.py, the first step of every training batch is this Actor forward pass, which recomputes the action distribution of the current policy (a sampling vs. inference example follows the snippet below):
python
self.actor_critic.act(obs_batch, masks=masks_batch, hidden_states=hid_states_batch[0])
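  • To see the difference between the sampled action used during training and the deterministic mean action used at deployment time, a small experiment (the dimensions are made up; ActorCritic is the class from above):
python
import torch

model = ActorCritic(num_actor_obs=48, num_critic_obs=48, num_actions=12)
obs = torch.randn(1, 48)
a1 = model.act(obs)                # sampled from N(mu(s), sigma^2): changes from call to call
a2 = model.act(obs)
a_mean = model.act_inference(obs)  # deterministic mu(s), used when deploying the policy
print(torch.allclose(a1, a2))                     # almost certainly False
print(torch.allclose(a_mean, model.action_mean))  # True: the mean of the stored distribution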

1-6 get_actions_log_prob()
python
def get_actions_log_prob(self, actions):
    return self.distribution.log_prob(actions).sum(dim=-1)
  • Purpose: compute the log-probability of an action, $\log \pi_\theta(a_t|s_t)$ (a quick sanity check follows the snippet below)
  • Called directly in algorithms/ppo.py:
python
actions_log_prob_batch = self.actor_critic.get_actions_log_prob(actions_batch)
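  • A quick sanity check with toy numbers: summing the per-dimension log-probabilities gives the log-probability of the joint action under a diagonal Gaussian, which is exactly what the PPO probability ratio needs:
python
import torch
from torch.distributions import Normal

dist = Normal(torch.zeros(3), torch.ones(3))   # 3 independent action dimensions
a = dist.sample()
print(dist.log_prob(a))           # per-dimension log-probs, shape (3,)
print(dist.log_prob(a).sum(-1))   # scalar log pi(a|s), as returned by get_actions_log_prob()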

1-7 The Core Function evaluate()
python
def evaluate(self, critic_observations, **kwargs):
    value = self.critic(critic_observations)
    return value
  • The Critic forward pass computes the state value $V_\phi(s)$
    • which is used to compute the advantage $A_t = R_t - V_\phi(s_t)$ (a minimal sketch follows the snippet below)
  • Called directly in algorithms/ppo.py:
python
value_batch = self.actor_critic.evaluate(critic_obs_batch)
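  • A minimal sketch of the advantage formula quoted above (toy numbers; the actual rollout code typically uses a discounted/GAE-style estimator rather than plain Monte-Carlo returns):
python
import torch

def discounted_returns(rewards, gamma=0.99):
    """Plain Monte-Carlo returns R_t = r_t + gamma * R_{t+1} (illustrative only)."""
    R, out = 0.0, []
    for r in reversed(rewards):
        R = r + gamma * R
        out.insert(0, R)
    return torch.tensor(out)

rewards = [1.0, 0.0, 1.0]                   # toy reward sequence
returns = discounted_returns(rewards)       # R_t
values = torch.tensor([0.8, 0.5, 0.9])      # V_phi(s_t), e.g. from evaluate() (made up)
advantages = returns - values               # A_t = R_t - V_phi(s_t)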

1-8 Network Structure

[Diagram: input state s → MLP hidden layers → Actor output μ(s) and Critic output V(s)]


2 ActorCriticRecurrent

2-1 Background Knowledge: RNNs
  • RNN stands for Recurrent Neural Network
  • Purpose: handling sequential data and time-series tasks, e.g. natural language, robot control, or stock-price prediction
  • Core idea: the output at the current time step depends not only on the current input but also on the previous hidden state, giving the network an internal "memory".
  • Mathematically: $h_t = f(W_{xh} x_t + W_{hh} h_{t-1} + b_h)$, where:
    • $x_t$: the current input (e.g. the observation vector $o_t$)
    • $h_{t-1}$: the hidden state from the previous step, a summary of the history, i.e. the network's "memory"
    • $h_t$: the current hidden state
    • $f$: a nonlinear activation function (tanh, ReLU, ...)
  • Main advantages (a from-scratch sketch of the update rule follows this list):
    1. Remembers history: the hidden state keeps the important past information of the sequence
    2. Handles partial observability (POMDPs): it can infer state information that is missing from the current observation
    3. Stronger temporal modeling: actions and values can depend on many past steps rather than only the current observation
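  • A from-scratch sketch of this update rule (the dimensions and weights are made up; the actual Memory class shown later uses nn.LSTM/nn.GRU instead of a hand-written cell):
python
import torch

def rnn_step(x_t, h_prev, W_xh, W_hh, b_h):
    # h_t = f(W_xh x_t + W_hh h_{t-1} + b_h), with f = tanh
    return torch.tanh(x_t @ W_xh + h_prev @ W_hh + b_h)

obs_dim, hidden_dim = 8, 16
W_xh = 0.1 * torch.randn(obs_dim, hidden_dim)
W_hh = 0.1 * torch.randn(hidden_dim, hidden_dim)
b_h = torch.zeros(hidden_dim)

h = torch.zeros(hidden_dim)                    # empty memory at the start of an episode
for t in range(5):                             # unroll over a short observation sequence
    x_t = torch.randn(obs_dim)
    h = rnn_step(x_t, h, W_xh, W_hh, b_h)      # the hidden state summarizes the history so far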

2-2 RNN Memory vs. Replay Buffer
  • A replay buffer is a container in reinforcement learning that stores past experience (state, action, reward, next_state)
  • Its role is to sample historical experience offline for training, which breaks temporal correlations and improves training stability.
  • Example: DQN keeps a large off-policy replay buffer of past transitions, while PPO keeps an on-policy rollout storage of recent trajectories; in both cases mini-batches are drawn from it to update the network parameters (a minimal buffer sketch follows this table).

| Feature | RNN memory | Replay buffer |
| --- | --- | --- |
| Data type | hidden state (an internal network vector) | complete state/action/reward trajectories |
| Role | captures temporal dependencies for sequential decisions | stores past experience for training the network |
| Update | updated dynamically at every forward step | sampled offline at random |
| Lifetime | valid within one trajectory/episode | can persist across many episodes |

  • In short: the RNN is the network's internal "memory" that helps with sequential decision making, while the replay buffer is the training-time "memory bank" that helps the algorithm learn a more robust policy
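  • A minimal replay-buffer sketch (illustrative only; note that rsl_rl's PPO keeps an on-policy rollout storage rather than a buffer like this):
python
import random
from collections import deque

buffer = deque(maxlen=10_000)                            # the "memory bank" lives outside the network
for t in range(200):                                     # fake transitions just for illustration
    buffer.append((f"s{t}", 0, 1.0, f"s{t+1}", False))   # (state, action, reward, next_state, done)

batch = random.sample(list(buffer), k=64)                # random sampling breaks temporal correlation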

2-3 ActorCriticRecurrent
  • ActorCriticRecurrent is the RNN version of ActorCritic: it introduces a recurrent network so that the Actor and the Critic can remember history, which suits partially observable environments (POMDPs) and tasks with temporal dependencies.
  • Main characteristics:
    1. Both the Actor and the Critic get a "memory module" (RNN/LSTM/GRU); their responsibilities stay the same:
      • The Actor's memory output → determines the action distribution ($\mu_\theta(s_t)$)
      • The Critic's memory output → estimates the value of the current state ($V_\phi(s_t)$)
    2. The network keeps updating its memory over time
      • At each time step: the network combines the current observation with its previous memory → updates the hidden state $h_t$
      • In plain terms: the network keeps folding new information into its memory, giving it context over the entire action sequence.
    3. The memory is cleared at the start of each new episode
      • reset(dones) zeroes the hidden states at the corresponding positions
      • In plain terms: every new episode starts with a blank memory, so the experience of the previous episode cannot interfere with the current decisions.
    4. Training and inference use the module differently
      • Training (batch updates): the RNN processes whole sequences, using masks and the saved hidden states for the forward pass
      • Inference (experience collection): each step uses the hidden state of the previous step to generate an action
  • In plain terms: the RNN gives the reinforcement learning network a "memory", letting the policy and the value function use past information when making decisions, which makes them more robust in complex robot-control and time-series tasks.

2-4 The Complete Code at a Glance
python
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# 
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# Copyright (c) 2021 ETH Zurich, Nikita Rudin

import numpy as np

import torch
import torch.nn as nn
from torch.distributions import Normal
from torch.nn.modules import rnn
from .actor_critic import ActorCritic, get_activation
from rsl_rl.utils import unpad_trajectories

class ActorCriticRecurrent(ActorCritic):
    is_recurrent = True
    def __init__(self,  num_actor_obs,
                        num_critic_obs,
                        num_actions,
                        actor_hidden_dims=[256, 256, 256],
                        critic_hidden_dims=[256, 256, 256],
                        activation='elu',
                        rnn_type='lstm',
                        rnn_hidden_size=256,
                        rnn_num_layers=1,
                        init_noise_std=1.0,
                        **kwargs):
        if kwargs:
            print("ActorCriticRecurrent.__init__ got unexpected arguments, which will be ignored: " + str(kwargs.keys()),)

        super().__init__(num_actor_obs=rnn_hidden_size,
                         num_critic_obs=rnn_hidden_size,
                         num_actions=num_actions,
                         actor_hidden_dims=actor_hidden_dims,
                         critic_hidden_dims=critic_hidden_dims,
                         activation=activation,
                         init_noise_std=init_noise_std)

        activation = get_activation(activation)

        self.memory_a = Memory(num_actor_obs, type=rnn_type, num_layers=rnn_num_layers, hidden_size=rnn_hidden_size)
        self.memory_c = Memory(num_critic_obs, type=rnn_type, num_layers=rnn_num_layers, hidden_size=rnn_hidden_size)

        print(f"Actor RNN: {self.memory_a}")
        print(f"Critic RNN: {self.memory_c}")

    def reset(self, dones=None):
        self.memory_a.reset(dones)
        self.memory_c.reset(dones)

    def act(self, observations, masks=None, hidden_states=None):
        input_a = self.memory_a(observations, masks, hidden_states)
        return super().act(input_a.squeeze(0))

    def act_inference(self, observations):
        input_a = self.memory_a(observations)
        return super().act_inference(input_a.squeeze(0))

    def evaluate(self, critic_observations, masks=None, hidden_states=None):
        input_c = self.memory_c(critic_observations, masks, hidden_states)
        return super().evaluate(input_c.squeeze(0))
    
    def get_hidden_states(self):
        return self.memory_a.hidden_states, self.memory_c.hidden_states


class Memory(torch.nn.Module):
    def __init__(self, input_size, type='lstm', num_layers=1, hidden_size=256):
        super().__init__()
        # RNN
        rnn_cls = nn.GRU if type.lower() == 'gru' else nn.LSTM
        self.rnn = rnn_cls(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers)
        self.hidden_states = None
    
    def forward(self, input, masks=None, hidden_states=None):
        batch_mode = masks is not None
        if batch_mode:
            # batch mode (policy update): need saved hidden states
            if hidden_states is None:
                raise ValueError("Hidden states not passed to memory module during policy update")
            out, _ = self.rnn(input, hidden_states)
            out = unpad_trajectories(out, masks)
        else:
            # inference mode (collection): use hidden states of last step
            out, self.hidden_states = self.rnn(input.unsqueeze(0), self.hidden_states)
        return out

    def reset(self, dones=None):
        # When the RNN is an LSTM, self.hidden_states_a is a list with hidden_state and cell_state
        for hidden_state in self.hidden_states:
            hidden_state[..., dones, :] = 0.0
  • We will focus mainly on the parts that differ from ActorCritic

2-5 The Memory Class
python
class Memory(torch.nn.Module):
    def __init__(self, input_size, type='lstm', num_layers=1, hidden_size=256):
        super().__init__()
        # RNN
        rnn_cls = nn.GRU if type.lower() == 'gru' else nn.LSTM
        self.rnn = rnn_cls(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers)
        self.hidden_states = None
    
    def forward(self, input, masks=None, hidden_states=None):
        batch_mode = masks is not None
        if batch_mode:
            # batch mode (policy update): need saved hidden states
            if hidden_states is None:
                raise ValueError("Hidden states not passed to memory module during policy update")
            out, _ = self.rnn(input, hidden_states)
            out = unpad_trajectories(out, masks)
        else:
            # inference mode (collection): use hidden states of last step
            out, self.hidden_states = self.rnn(input.unsqueeze(0), self.hidden_states)
        return out

    def reset(self, dones=None):
        # When the RNN is an LSTM, self.hidden_states_a is a list with hidden_state and cell_state
        for hidden_state in self.hidden_states:
            hidden_state[..., dones, :] = 0.0
  • Memory is the module that wraps the RNN. Its main purposes are:
    1. Providing hidden-state memory to the Actor and the Critic inside ActorCriticRecurrent
    2. Distinguishing the forward pass used for training (batch mode) from the one used for inference (inference mode)
    3. Managing the initialization and resetting of the hidden states
2-5-1 The Constructor (inherits from torch.nn.Module)
python
def __init__(self, input_size, type='lstm', num_layers=1, hidden_size=256):
    super().__init__()
    rnn_cls = nn.GRU if type.lower() == 'gru' else nn.LSTM
    self.rnn = rnn_cls(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers)
    self.hidden_states = None
  • input_size: the RNN input dimension per time step (e.g. the dimension of the observation vector)
  • type: the RNN type, either 'lstm' or 'gru'
  • num_layers: number of stacked RNN layers
  • hidden_size: dimension of the RNN hidden state
  • self.hidden_states: stores the RNN hidden states so they can be carried forward step by step in inference mode

2-5-2 The Forward Function forward()
python
def forward(self, input, masks=None, hidden_states=None):
    batch_mode = masks is not None
    if batch_mode:
        # batch mode (policy update)
        if hidden_states is None:
            raise ValueError("Hidden states not passed to memory module during policy update")
        out, _ = self.rnn(input, hidden_states)
        out = unpad_trajectories(out, masks)
    else:
        # inference mode (collection)
        out, self.hidden_states = self.rnn(input.unsqueeze(0), self.hidden_states)
    return out
  • batch_mode: when masks is not None, we are in the training phase (processing all trajectories of a batch at once)
    • The hidden_states saved during collection must be passed in
    • The RNN processes whole sequences → outputs a sequence of hidden states
    • unpad_trajectories removes the padded time steps so the sequence lengths are correct again
  • inference mode: when masks is None, we are running inference / collecting experience
    • The input is the observation of the current time step
    • input.unsqueeze(0) adds the time dimension
    • self.hidden_states is updated → reused at the next step
    • The output is the RNN output for the current time step
  • In plain terms:

    • Training mode → process the whole sequence at once, update in batches
    • Inference mode → step forward one step at a time, using the previous hidden state as memory (see the shape check below)
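  • A shape check for inference mode (the dimensions are made up; batch mode additionally needs masks, the saved hidden states and unpad_trajectories, so it is omitted here; Memory is the class defined above):
python
import torch

mem = Memory(input_size=48, type='lstm', num_layers=1, hidden_size=256)

obs = torch.randn(4, 48)               # one time step for 4 parallel environments
out = mem(obs)                         # internally unsqueezed to (1, 4, 48); hidden state stored in mem
print(out.shape)                       # torch.Size([1, 4, 256])
print(mem.hidden_states[0].shape)      # torch.Size([1, 4, 256]) -> (num_layers, num_envs, hidden_size)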

2-5-3 The Reset Function reset()
python
def reset(self, dones=None):
    for hidden_state in self.hidden_states:
        hidden_state[..., dones, :] = 0.0
  • dones is a boolean vector indicating which environments/samples have finished their episode
  • The hidden states of those samples are set to zero (a small example follows)
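  • A small example of this boolean indexing (the shapes are made up):
python
import torch

h = torch.randn(1, 4, 256)                          # (num_layers, num_envs, hidden_size)
dones = torch.tensor([False, True, False, True])    # episodes 1 and 3 just finished
h[..., dones, :] = 0.0                              # only those environments lose their memory
print(h[0, 1].abs().sum(), h[0, 0].abs().sum())     # tensor(0.) vs. a non-zero value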

2-6 Hidden States in Training Mode (Batch Updates)
  1. A training update usually feeds an entire batch (several trajectories) through the network at once.
  2. The RNN processes the whole sequence:
    • Input at each time step: the current state plus the hidden state of the previous step.
    • Output: the sequence of hidden states $h_t$.
  3. Padding is removed (trajectories have different lengths, so the short ones were padded).
  4. The outputs are fed into the MLPs:
    • The Actor produces the action distribution.
    • The Critic produces the state value.
  5. These outputs are used to compute the policy-gradient and value-function losses and update the network parameters.
  6. Note: during training the hidden states for the whole sequence are processed in a single pass; the next batch can continue from the hidden states saved previously.

2-7 Hidden States in Inference Mode (Experience Collection / Actual Control)
  1. Each step processes only the current state.
  2. The RNN uses the previous hidden state $h_{t-1}$ to compute the current hidden state $h_t$.
  3. The result is fed into the MLPs:
    • The Actor outputs the action $\mu(s_t)$.
    • The Critic optionally outputs the state value $V(s_t)$.
  4. The action is executed; the environment returns a reward and the next state.
  5. The role of reset(dones) (a minimal rollout sketch follows this list):
    • Some environments have finished (done=True), and their hidden states must be zeroed.
    • This keeps the memory of the previous episode from influencing the decisions of the next one.
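  • Putting it together, a minimal rollout/control-loop sketch (the observations and done flags are faked here; in practice they come from the simulator; ActorCriticRecurrent is the class from above):
python
import torch

policy = ActorCriticRecurrent(num_actor_obs=48, num_critic_obs=48, num_actions=12)

obs = torch.randn(4, 48)                      # 4 parallel environments, made-up observations
for step in range(10):
    with torch.no_grad():
        actions = policy.act_inference(obs)   # uses memory_a's stored hidden state h_{t-1}
        values = policy.evaluate(obs)         # optional Critic estimate, keeps memory_c in sync
    obs = torch.randn(4, 48)                  # the next observations would come from env.step(actions)
    dones = torch.rand(4) < 0.1               # pretend some episodes ended
    policy.reset(dones)                       # clear the memory of the finished environments only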

2-8 Network Structure

[Diagram: input state s → RNN / Memory (hidden state h_t) → MLP hidden layers → Actor output μ(s) and Critic output V(s)]



Summary

  • This installment analyzed the Python implementations of ActorCritic and ActorCriticRecurrent in the rsl_rl repository: it reviewed the core ideas of Actor-Critic, focused on how ActorCriticRecurrent adds an RNN/Memory module so the network can remember past information, distinguished how hidden states are handled in training versus inference, and walked through the functions for building the networks, sampling actions, and evaluating values, laying a foundation for understanding the policy and value networks used in complex robot-control tasks.
  • If you spot any mistakes, please point them out! Thanks for reading.