[Unitree Robot Reinforcement Learning] (1): A Python Implementation and Walkthrough of the PPO Algorithm

Preface


0 Repository Setup

  • The official documentation already explains installation and environment configuration in detail, so it will not be repeated here.

  • Official tutorial

  • The repository can be fetched with the following commands.

shell
git clone https://github.com/leggedrobotics/rsl_rl.git
cd rsl_rl
git checkout v1.0.2

0-1 PPO Formula Review
  • As a quick refresher, PPO's objective function is: $L^{clip}(\theta)=\mathbb{E}[\min(r(\theta)A,\ \mathrm{clip}(r(\theta),1-\epsilon,1+\epsilon)A)]$ where:
    • $r(\theta)$: the probability ratio between the new and old policies
    • $A$: the advantage function
    • $\epsilon$: the clipping range, typically 0.1–0.2

  • Probability ratio: $r(\theta) = \frac{\pi_\theta(a|s)}{\pi_{\theta_{old}}(a|s)}$ It measures:
    • how likely the new policy is, relative to the old one, to take a given action.
    • If $r \approx 1$, the new and old policies are nearly identical.
    • If $r \gg 1$ or $r \ll 1$, the policy has changed too much.
  • PPO restricts $r(\theta)$ to the range $[1-\epsilon, 1+\epsilon]$; once the ratio leaves this range the objective is clipped and the gradient stops growing.
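To make the clipping concrete, here is a tiny standalone sketch of the per-sample clipped objective in plain Python (illustrative only, not code from the repository):

```python
def clipped_objective(ratio, adv, eps=0.2):
    # min( r*A, clip(r, 1-eps, 1+eps)*A ) -- the per-sample PPO objective
    clipped_ratio = max(1.0 - eps, min(1.0 + eps, ratio))
    return min(ratio * adv, clipped_ratio * adv)

# With A > 0, the objective is capped once r exceeds 1+eps, so there is
# no incentive to push the ratio further up:
print(clipped_objective(1.5, adv=1.0))   # capped at (1+eps)*A = 1.2
# With A < 0, the min keeps the more pessimistic (more negative) term:
print(clipped_objective(0.5, adv=-1.0))  # -0.8
```

Note that the clip is one-sided pessimism: for A > 0 and r < 1−ε the unclipped term is kept, since it is already the smaller one.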

  • GAE (Generalized Advantage Estimation):
    • introduces a parameter $\lambda$ that forms an exponentially weighted average of multi-step TD errors, yielding a more stable advantage estimate.
  • Formula: $A_t^{GAE} = \sum_{l=0}^{\infty} (\gamma \lambda)^l \delta_{t+l}$ where $\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)$
  • $\delta_t$ is the TD error (Temporal Difference error)
  • The parameter $\lambda$ balances bias against variance:
    • $\lambda = 0$: only the one-step TD error is used; low variance, higher bias
    • $\lambda = 1$: approaches the Monte Carlo return; low bias, higher variance
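In practice the infinite sum is computed with the backward recursion $A_t = \delta_t + \gamma\lambda A_{t+1}$; a simplified plain-Python sketch (it ignores episode terminations, which the real buffer handles):

```python
def gae(rewards, values, last_value, gamma=0.99, lam=0.95):
    # Backward pass: A_t = delta_t + gamma*lam*A_{t+1},
    # with delta_t = r_t + gamma*V(s_{t+1}) - V(s_t)
    advantages = [0.0] * len(rewards)
    next_value, next_adv = last_value, 0.0
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + gamma * next_value - values[t]
        next_adv = delta + gamma * lam * next_adv
        advantages[t] = next_adv
        next_value = values[t]
    return advantages

# lam = 0 reduces to the one-step TD error; lam = 1 sums all future deltas
print(gae([1.0, 1.0], [0.5, 0.5], last_value=0.5, gamma=1.0, lam=0.0))  # [1.0, 1.0]
print(gae([1.0, 1.0], [0.5, 0.5], last_value=0.5, gamma=1.0, lam=1.0))  # [2.0, 1.0]
```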

  • The policy entropy is: $H(\pi) = -\sum \pi(a|s)\log \pi(a|s)$
    • The more random the policy, the higher its entropy
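A quick numerical check of the entropy formula for a discrete distribution (illustrative; the continuous Gaussian case used later has an analogous closed form):

```python
import math

def entropy(probs):
    # H(pi) = -sum pi(a|s) * log pi(a|s)
    return -sum(p * math.log(p) for p in probs if p > 0)

print(entropy([0.5, 0.5]))    # log(2) ~ 0.693: maximally random over 2 actions
print(entropy([0.99, 0.01]))  # near 0: almost deterministic
```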

1 Repository Overview

  • After cloning the repository, we can use the tree command to get a quick look at the project structure.
  • rsl_rl directory structure
bash
rsl_rl/
├── algorithms/
├── env/
├── modules/
├── runners/
├── storage/
└── utils/

1-1 algorithms/ directory
bash
algorithms/
├── __init__.py
├── ppo.py
  • Purpose: holds the RL algorithm implementations; for example, ppo.py implements PPO (Proximal Policy Optimization).
  • Notes
    • More algorithms (e.g. DDPG, TD3, DPPO) can be added here.
    • Provides the core training logic (policy updates, loss computation, etc.).

1-2 env/ directory
bash
env/
├── __init__.py
└── vec_env.py
  • Purpose: wraps the environment interface.
    • vec_env.py implements a vectorized environment, supporting parallel training across many environments.
  • Role
    • Interfaces with simulators (e.g. PyBullet / MuJoCo).
    • Exposes a standard interface (step, reset, render, etc.) to the training code.

1-3 modules/ directory
bash
modules/
├── actor_critic.py
├── actor_critic_recurrent.py
  • Purpose: defines the policy network architectures.
    • actor_critic.py: the standard Actor-Critic network.
    • actor_critic_recurrent.py: an RNN / LSTM version of the Actor-Critic network.
  • Role
    • Supplies the policy and value networks used by PPO and other algorithms.
    • Supports modeling state sequences, well suited to time-dependent robot motion control.

1-4 runners/ directory
bash
runners/
└── on_policy_runner.py
  • Purpose: the training orchestrator.
    • on_policy_runner.py samples data on-policy and runs the training loop.
  • Role
    • Manages data collection, training steps, and model checkpointing.
    • Ties the algorithm, environment, and storage modules together into a complete training pipeline.

1-5 storage/ directory
bash
storage/
└── rollout_storage.py
  • Purpose: stores sampled trajectories (rollouts).
  • Role
    • PPO needs to save each step's state, action, reward, and so on.
    • Provides mini-batch iteration, normalization, and related utilities.

1-6 utils/ directory
bash
utils/
└── utils.py
  • Purpose: utility functions.
    • For example logging, model saving/loading, and tensor operations.
  • Role
    • Provides common helpers for training and deployment, keeping the main logic lean.

2 The PPO Algorithm in Python

2-1 The Code at a Glance
  • The code lives at:
bash
algorithms/
├── __init__.py
├── ppo.py
  • The full file is shown below; we will then analyze it piece by piece.
python
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# 
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# Copyright (c) 2021 ETH Zurich, Nikita Rudin

import torch
import torch.nn as nn
import torch.optim as optim

from rsl_rl.modules import ActorCritic
from rsl_rl.storage import RolloutStorage

class PPO:
    actor_critic: ActorCritic
    def __init__(self,
                 actor_critic,
                 num_learning_epochs=1,
                 num_mini_batches=1,
                 clip_param=0.2,
                 gamma=0.998,
                 lam=0.95,
                 value_loss_coef=1.0,
                 entropy_coef=0.0,
                 learning_rate=1e-3,
                 max_grad_norm=1.0,
                 use_clipped_value_loss=True,
                 schedule="fixed",
                 desired_kl=0.01,
                 device='cpu',
                 ):

        self.device = device

        self.desired_kl = desired_kl
        self.schedule = schedule
        self.learning_rate = learning_rate

        # PPO components
        self.actor_critic = actor_critic
        self.actor_critic.to(self.device)
        self.storage = None # initialized later
        self.optimizer = optim.Adam(self.actor_critic.parameters(), lr=learning_rate)
        self.transition = RolloutStorage.Transition()

        # PPO parameters
        self.clip_param = clip_param
        self.num_learning_epochs = num_learning_epochs
        self.num_mini_batches = num_mini_batches
        self.value_loss_coef = value_loss_coef
        self.entropy_coef = entropy_coef
        self.gamma = gamma
        self.lam = lam
        self.max_grad_norm = max_grad_norm
        self.use_clipped_value_loss = use_clipped_value_loss

    def init_storage(self, num_envs, num_transitions_per_env, actor_obs_shape, critic_obs_shape, action_shape):
        self.storage = RolloutStorage(num_envs, num_transitions_per_env, actor_obs_shape, critic_obs_shape, action_shape, self.device)

    def test_mode(self):
        self.actor_critic.test()
    
    def train_mode(self):
        self.actor_critic.train()

    def act(self, obs, critic_obs):
        if self.actor_critic.is_recurrent:
            self.transition.hidden_states = self.actor_critic.get_hidden_states()
        # Compute the actions and values
        self.transition.actions = self.actor_critic.act(obs).detach()
        self.transition.values = self.actor_critic.evaluate(critic_obs).detach()
        self.transition.actions_log_prob = self.actor_critic.get_actions_log_prob(self.transition.actions).detach()
        self.transition.action_mean = self.actor_critic.action_mean.detach()
        self.transition.action_sigma = self.actor_critic.action_std.detach()
        # need to record obs and critic_obs before env.step()
        self.transition.observations = obs
        self.transition.critic_observations = critic_obs
        return self.transition.actions
    
    def process_env_step(self, rewards, dones, infos):
        self.transition.rewards = rewards.clone()
        self.transition.dones = dones
        # Bootstrapping on time outs
        if 'time_outs' in infos:
            self.transition.rewards += self.gamma * torch.squeeze(self.transition.values * infos['time_outs'].unsqueeze(1).to(self.device), 1)

        # Record the transition
        self.storage.add_transitions(self.transition)
        self.transition.clear()
        self.actor_critic.reset(dones)
    
    def compute_returns(self, last_critic_obs):
        last_values= self.actor_critic.evaluate(last_critic_obs).detach()
        self.storage.compute_returns(last_values, self.gamma, self.lam)

    def update(self):
        mean_value_loss = 0
        mean_surrogate_loss = 0
        if self.actor_critic.is_recurrent:
            generator = self.storage.reccurent_mini_batch_generator(self.num_mini_batches, self.num_learning_epochs)
        else:
            generator = self.storage.mini_batch_generator(self.num_mini_batches, self.num_learning_epochs)
        for obs_batch, critic_obs_batch, actions_batch, target_values_batch, advantages_batch, returns_batch, old_actions_log_prob_batch, \
            old_mu_batch, old_sigma_batch, hid_states_batch, masks_batch in generator:


                self.actor_critic.act(obs_batch, masks=masks_batch, hidden_states=hid_states_batch[0])
                actions_log_prob_batch = self.actor_critic.get_actions_log_prob(actions_batch)
                value_batch = self.actor_critic.evaluate(critic_obs_batch, masks=masks_batch, hidden_states=hid_states_batch[1])
                mu_batch = self.actor_critic.action_mean
                sigma_batch = self.actor_critic.action_std
                entropy_batch = self.actor_critic.entropy

                # KL
                if self.desired_kl != None and self.schedule == 'adaptive':
                    with torch.inference_mode():
                        kl = torch.sum(
                            torch.log(sigma_batch / old_sigma_batch + 1.e-5) + (torch.square(old_sigma_batch) + torch.square(old_mu_batch - mu_batch)) / (2.0 * torch.square(sigma_batch)) - 0.5, axis=-1)
                        kl_mean = torch.mean(kl)

                        if kl_mean > self.desired_kl * 2.0:
                            self.learning_rate = max(1e-5, self.learning_rate / 1.5)
                        elif kl_mean < self.desired_kl / 2.0 and kl_mean > 0.0:
                            self.learning_rate = min(1e-2, self.learning_rate * 1.5)
                        
                        for param_group in self.optimizer.param_groups:
                            param_group['lr'] = self.learning_rate


                # Surrogate loss
                ratio = torch.exp(actions_log_prob_batch - torch.squeeze(old_actions_log_prob_batch))
                surrogate = -torch.squeeze(advantages_batch) * ratio
                surrogate_clipped = -torch.squeeze(advantages_batch) * torch.clamp(ratio, 1.0 - self.clip_param,
                                                                                1.0 + self.clip_param)
                surrogate_loss = torch.max(surrogate, surrogate_clipped).mean()

                # Value function loss
                if self.use_clipped_value_loss:
                    value_clipped = target_values_batch + (value_batch - target_values_batch).clamp(-self.clip_param,
                                                                                                    self.clip_param)
                    value_losses = (value_batch - returns_batch).pow(2)
                    value_losses_clipped = (value_clipped - returns_batch).pow(2)
                    value_loss = torch.max(value_losses, value_losses_clipped).mean()
                else:
                    value_loss = (returns_batch - value_batch).pow(2).mean()

                loss = surrogate_loss + self.value_loss_coef * value_loss - self.entropy_coef * entropy_batch.mean()

                # Gradient step
                self.optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(self.actor_critic.parameters(), self.max_grad_norm)
                self.optimizer.step()

                mean_value_loss += value_loss.item()
                mean_surrogate_loss += surrogate_loss.item()

        num_updates = self.num_learning_epochs * self.num_mini_batches
        mean_value_loss /= num_updates
        mean_surrogate_loss /= num_updates
        self.storage.clear()

        return mean_value_loss, mean_surrogate_loss

2-2 The Constructor
  • Let's look at the class's initialization:
python
class PPO:
    actor_critic: ActorCritic
    def __init__(self,
                 actor_critic,
                 num_learning_epochs=1,
                 num_mini_batches=1,
                 clip_param=0.2,
                 gamma=0.998,
                 lam=0.95,
                 value_loss_coef=1.0,
                 entropy_coef=0.0,
                 learning_rate=1e-3,
                 max_grad_norm=1.0,
                 use_clipped_value_loss=True,
                 schedule="fixed",
                 desired_kl=0.01,
                 device='cpu',
                 ):

        self.device = device

        self.desired_kl = desired_kl
        self.schedule = schedule
        self.learning_rate = learning_rate

        # PPO components
        self.actor_critic = actor_critic
        self.actor_critic.to(self.device)
        self.storage = None # initialized later
        self.optimizer = optim.Adam(self.actor_critic.parameters(), lr=learning_rate)
        self.transition = RolloutStorage.Transition()

        # PPO parameters
        self.clip_param = clip_param
        self.num_learning_epochs = num_learning_epochs
        self.num_mini_batches = num_mini_batches
        self.value_loss_coef = value_loss_coef
        self.entropy_coef = entropy_coef
        self.gamma = gamma
        self.lam = lam
        self.max_grad_norm = max_grad_norm
        self.use_clipped_value_loss = use_clipped_value_loss
  • The constructor takes a large number of PPO hyperparameters:
    • actor_critic: the Actor-Critic network PPO requires (defined in modules/actor_critic.py, which we will analyze in a later installment)
    • num_learning_epochs=1: how many epochs to train on each batch of rollout data
    • num_mini_batches=1: how many mini-batches to split the rollout data into (improving sample utilization)
    • clip_param=0.2: PPO's core $\epsilon$ parameter, used to clip the policy update
    • gamma=0.998: the reward discount factor, controlling the weight of long-term rewards
    • lam=0.95: GAE's $\lambda$ parameter, used to reduce variance when computing the advantage
    • value_loss_coef=1.0: the value-loss weight; higher values put more emphasis on the value network
    • entropy_coef=0.0: the entropy coefficient, an exploration bonus that encourages the policy to stay somewhat random
    • learning_rate=1e-3: the network learning rate
    • max_grad_norm=1.0: the gradient-clipping threshold; gradients above this norm are clipped to prevent gradient explosion
    • use_clipped_value_loss=True: whether to use value clipping, which keeps the critic from updating too far
    • schedule="fixed": keep the learning rate fixed during training rather than adapting it based on KL or training progress
    • desired_kl=0.01: the target KL divergence between old and new policies, used by the adaptive learning-rate schedule to control the update size
    • device='cpu': the compute device
  • It also defines a few member variables:
python
self.storage = None # initialized later
self.optimizer = optim.Adam(self.actor_critic.parameters(), lr=learning_rate)
self.transition = RolloutStorage.Transition()
  • self.storage: a placeholder for the rollout buffer
  • self.optimizer: an Adam optimizer that updates the Actor-Critic parameters
  • self.transition: a temporary per-step buffer

2-3 The Rollout-Buffer Initializer init_storage()
python
def init_storage(self, num_envs, num_transitions_per_env, actor_obs_shape, critic_obs_shape, action_shape):
    self.storage = RolloutStorage(
        num_envs,
        num_transitions_per_env,
        actor_obs_shape,
        critic_obs_shape,
        action_shape,
        self.device
    )
  • This function initializes the rollout buffer that caches experience data.
  • It is defined in storage/rollout_storage.py, which we will also analyze later.

2-4 Mode Switching
python
def test_mode(self):
    self.actor_critic.test()
def train_mode(self):
    self.actor_critic.train()
  • These two simply switch the actor_critic network between evaluation and training modes.
  • The network itself is defined in modules/actor_critic.py, which we will analyze later.

2-5 The Action Function act()
  • Purpose: compute an action from the current observations, return it, and record the data needed for training.
python
def act(self, obs, critic_obs):
        if self.actor_critic.is_recurrent:
            self.transition.hidden_states = self.actor_critic.get_hidden_states()
        # Compute the actions and values
        self.transition.actions = self.actor_critic.act(obs).detach()
        self.transition.values = self.actor_critic.evaluate(critic_obs).detach()
        self.transition.actions_log_prob = self.actor_critic.get_actions_log_prob(self.transition.actions).detach()
        self.transition.action_mean = self.actor_critic.action_mean.detach()
        self.transition.action_sigma = self.actor_critic.action_std.detach()
        # need to record obs and critic_obs before env.step()
        self.transition.observations = obs
        self.transition.critic_observations = critic_obs
        return self.transition.actions

  • Let's go step by step, starting with the inputs:
python
def act(self, obs, critic_obs):
  • obs: the policy (actor) network's input
  • critic_obs: the value (critic) network's input

python
if self.actor_critic.is_recurrent:
    self.transition.hidden_states = self.actor_critic.get_hidden_states()
  • This checks whether a recurrent (RNN / LSTM) network is in use; if so, the hidden state must be saved, or the sequence state cannot be restored later during training.

python
self.transition.actions = self.actor_critic.act(obs).detach()
self.transition.values = self.actor_critic.evaluate(critic_obs).detach()
  • The actor computes an action from the policy input; .detach() excludes it from gradient computation, since this is pure sampling.
  • The critic likewise computes the value estimate, detached for the same reason.

python
self.transition.actions_log_prob = self.actor_critic.get_actions_log_prob(self.transition.actions).detach()
  • This computes the action log-probability $\log \pi_\theta(a|s)$, used later when computing the probability ratio.
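For a diagonal Gaussian policy this log-probability has a simple closed form; a pure-Python sketch of it (illustrative only; the actual computation happens inside ActorCritic via a torch Normal distribution):

```python
import math

def gaussian_log_prob(action, mu, sigma):
    # log N(a | mu, sigma), summed over action dimensions
    return sum(
        -0.5 * ((a - m) / s) ** 2 - math.log(s) - 0.5 * math.log(2 * math.pi)
        for a, m, s in zip(action, mu, sigma)
    )

# The density, and hence the log-probability, peaks at the mean:
print(gaussian_log_prob([0.0], [0.0], [1.0]))  # -0.5*log(2*pi) ~ -0.919
```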

python
self.transition.action_mean = self.actor_critic.action_mean.detach()
self.transition.action_sigma = self.actor_critic.action_std.detach()
  • The policy distribution's parameters are saved for the later KL-divergence computation; actions are drawn from a Gaussian: $a \sim \mathcal{N}(\mu, \sigma)$

python
# need to record obs and critic_obs before env.step()
self.transition.observations = obs
self.transition.critic_observations = critic_obs
return self.transition.actions
  • Save the observations and return the action. This must happen before env.step(), or the recorded state would already have changed.

2-6 The Environment-Feedback Function process_env_step()
  • This function processes the values returned by the environment and stores the data.
python
def process_env_step(self, rewards, dones, infos):
        self.transition.rewards = rewards.clone()
        self.transition.dones = dones
        # Bootstrapping on time outs
        if 'time_outs' in infos:
            self.transition.rewards += self.gamma * torch.squeeze(self.transition.values * infos['time_outs'].unsqueeze(1).to(self.device), 1)

        # Record the transition
        self.storage.add_transitions(self.transition)
        self.transition.clear()
        self.actor_critic.reset(dones)

python
self.transition.rewards = rewards.clone()
self.transition.dones = dones
  • Save the reward $r_t$ along with the termination flags.

python
# Bootstrapping on time outs
if 'time_outs' in infos:
    self.transition.rewards += self.gamma * torch.squeeze(self.transition.values * infos['time_outs'].unsqueeze(1).to(self.device), 1)
  • Some episodes end not because of failure but because the maximum step count was reached; in that case the future value $V(s)$ should not be treated as 0.
  • The reward is therefore corrected with a bootstrap term: $r = r + \gamma V(s)$

python
# Record the transition
self.storage.add_transitions(self.transition)
self.transition.clear()
self.actor_critic.reset(dones)
  • Store the transition in the rollout buffer; each step holds $(s, a, r, V, \log\pi(a|s))$
  • Then clear the temporary transition and reset the RNN state for finished environments.

2-7 The Return-Computation Function compute_returns()
  • Computes the returns and the advantages that PPO training requires.
python
def compute_returns(self, last_critic_obs):
	last_values= self.actor_critic.evaluate(last_critic_obs).detach()
	self.storage.compute_returns(last_values, self.gamma, self.lam)
  • The first line computes the value of the final state, $V(s_T)$.
  • The second line runs the GAE computation, forming a weighted average of multi-step TD errors for a more stable advantage estimate.
  • Return: $R_t = r_t + \gamma R_{t+1}$
  • Advantage (GAE): $\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)$, $A_t = \delta_t + \gamma\lambda\delta_{t+1} + (\gamma\lambda)^2\delta_{t+2} + \dots$
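The return recursion is easy to verify in isolation; a plain-Python sketch (the real implementation in rollout_storage.py additionally handles episode terminations and parallel environments):

```python
def discounted_returns(rewards, last_value, gamma):
    # R_t = r_t + gamma * R_{t+1}, bootstrapped with V(s_T) at the end
    returns = [0.0] * len(rewards)
    acc = last_value
    for t in reversed(range(len(rewards))):
        acc = rewards[t] + gamma * acc
        returns[t] = acc
    return returns

print(discounted_returns([1.0, 1.0], last_value=0.0, gamma=0.5))  # [1.5, 1.0]
```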

3 The Core Function update()

3-1 Full Implementation
  • The code so far handles data collection and advantage computation; PPO's core training logic lives entirely in the update() function. It is responsible for:
    • recomputing the policy probabilities
    • computing the PPO loss
    • backpropagating and updating the networks
python
 def update(self):
        mean_value_loss = 0
        mean_surrogate_loss = 0
        if self.actor_critic.is_recurrent:
            generator = self.storage.reccurent_mini_batch_generator(self.num_mini_batches, self.num_learning_epochs)
        else:
            generator = self.storage.mini_batch_generator(self.num_mini_batches, self.num_learning_epochs)
        for obs_batch, critic_obs_batch, actions_batch, target_values_batch, advantages_batch, returns_batch, old_actions_log_prob_batch, \
            old_mu_batch, old_sigma_batch, hid_states_batch, masks_batch in generator:


                self.actor_critic.act(obs_batch, masks=masks_batch, hidden_states=hid_states_batch[0])
                actions_log_prob_batch = self.actor_critic.get_actions_log_prob(actions_batch)
                value_batch = self.actor_critic.evaluate(critic_obs_batch, masks=masks_batch, hidden_states=hid_states_batch[1])
                mu_batch = self.actor_critic.action_mean
                sigma_batch = self.actor_critic.action_std
                entropy_batch = self.actor_critic.entropy

                # KL
                if self.desired_kl != None and self.schedule == 'adaptive':
                    with torch.inference_mode():
                        kl = torch.sum(
                            torch.log(sigma_batch / old_sigma_batch + 1.e-5) + (torch.square(old_sigma_batch) + torch.square(old_mu_batch - mu_batch)) / (2.0 * torch.square(sigma_batch)) - 0.5, axis=-1)
                        kl_mean = torch.mean(kl)

                        if kl_mean > self.desired_kl * 2.0:
                            self.learning_rate = max(1e-5, self.learning_rate / 1.5)
                        elif kl_mean < self.desired_kl / 2.0 and kl_mean > 0.0:
                            self.learning_rate = min(1e-2, self.learning_rate * 1.5)
                        
                        for param_group in self.optimizer.param_groups:
                            param_group['lr'] = self.learning_rate


                # Surrogate loss
                ratio = torch.exp(actions_log_prob_batch - torch.squeeze(old_actions_log_prob_batch))
                surrogate = -torch.squeeze(advantages_batch) * ratio
                surrogate_clipped = -torch.squeeze(advantages_batch) * torch.clamp(ratio, 1.0 - self.clip_param,
                                                                                1.0 + self.clip_param)
                surrogate_loss = torch.max(surrogate, surrogate_clipped).mean()

                # Value function loss
                if self.use_clipped_value_loss:
                    value_clipped = target_values_batch + (value_batch - target_values_batch).clamp(-self.clip_param,
                                                                                                    self.clip_param)
                    value_losses = (value_batch - returns_batch).pow(2)
                    value_losses_clipped = (value_clipped - returns_batch).pow(2)
                    value_loss = torch.max(value_losses, value_losses_clipped).mean()
                else:
                    value_loss = (returns_batch - value_batch).pow(2).mean()

                loss = surrogate_loss + self.value_loss_coef * value_loss - self.entropy_coef * entropy_batch.mean()

                # Gradient step
                self.optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(self.actor_critic.parameters(), self.max_grad_norm)
                self.optimizer.step()

                mean_value_loss += value_loss.item()
                mean_surrogate_loss += surrogate_loss.item()

        num_updates = self.num_learning_epochs * self.num_mini_batches
        mean_value_loss /= num_updates
        mean_surrogate_loss /= num_updates
        self.storage.clear()

        return mean_value_loss, mean_surrogate_loss

3-2 Setup
  • Let's take it step by step:
python
mean_value_loss = 0
mean_surrogate_loss = 0
if self.actor_critic.is_recurrent:
	generator = self.storage.reccurent_mini_batch_generator(self.num_mini_batches, self.num_learning_epochs)
else:
	generator = self.storage.mini_batch_generator(self.num_mini_batches, self.num_learning_epochs)
  • mean_value_loss and mean_surrogate_loss accumulate the average losses over the whole update, for logging.
  • A mini-batch iterator is then constructed according to whether a recurrent (RNN / LSTM) network is in use.

3-3 Per-Batch Processing
python
for obs_batch, critic_obs_batch, actions_batch, target_values_batch, advantages_batch, returns_batch, old_actions_log_prob_batch, \
    old_mu_batch, old_sigma_batch, hid_states_batch, masks_batch in generator:
  • Each batch yields the following variables:
    • obs_batch: actor network input
    • critic_obs_batch: critic network input
    • actions_batch: the sampled actions
    • target_values_batch: the old value estimates $V(s)$
    • advantages_batch: the GAE advantages
    • returns_batch: the target values (returns)
    • old_actions_log_prob_batch: the old-policy log-probabilities $\log \pi_{\theta_{old}}(a|s)$
    • old_mu_batch: the old policy mean $\mu$
    • old_sigma_batch: the old policy standard deviation $\sigma$
    • hid_states_batch, masks_batch: saved hidden states and sequence masks, used only by recurrent policies

python
self.actor_critic.act(obs_batch, masks=masks_batch, hidden_states=hid_states_batch[0])
  • First, act() runs the actor forward pass, recomputing the current policy's action distribution $\pi_\theta(a|s) = \mathcal{N}(\mu_\theta(s), \sigma_\theta(s))$

python
actions_log_prob_batch = self.actor_critic.get_actions_log_prob(actions_batch)
  • Next, the log-probability of the stored actions under the current policy, $\log \pi_\theta(a_t|s_t)$, is retrieved for the ratio computation.

python
value_batch = self.actor_critic.evaluate(critic_obs_batch, masks=masks_batch, hidden_states=hid_states_batch[1])
  • The critic computes the value estimates.

python
mu_batch = self.actor_critic.action_mean
sigma_batch = self.actor_critic.action_std
entropy_batch = self.actor_critic.entropy
  • mu_batch: the policy mean $\mu$
  • sigma_batch: the policy standard deviation $\sigma$
  • entropy_batch: the policy entropy

3-4 KL-Divergence Control
  • The KL control does one thing: if the KL divergence is too large, it lowers the learning rate (and raises it again when the KL is too small).
  • This is a simple trust-region approximation, used to keep overly large policy updates from destabilizing training.
python
# KL
if self.desired_kl != None and self.schedule == 'adaptive':
	with torch.inference_mode():
		kl = torch.sum(
			torch.log(sigma_batch / old_sigma_batch + 1.e-5) + (torch.square(old_sigma_batch) + torch.square(old_mu_batch - mu_batch)) / (2.0 * torch.square(sigma_batch)) - 0.5, axis=-1)
		kl_mean = torch.mean(kl)

		if kl_mean > self.desired_kl * 2.0:
			self.learning_rate = max(1e-5, self.learning_rate / 1.5)
		elif kl_mean < self.desired_kl / 2.0 and kl_mean > 0.0:
			self.learning_rate = min(1e-2, self.learning_rate * 1.5)
		
		for param_group in self.optimizer.param_groups:
			param_group['lr'] = self.learning_rate
  • The kl expression implements the Gaussian KL divergence, summed per action dimension: $KL(\pi_{old}\|\pi_{new}) = \log\frac{\sigma}{\sigma_{old}} + \frac{\sigma_{old}^2 + (\mu_{old}-\mu)^2}{2\sigma^2} - \frac{1}{2}$ (the code adds 1e-5 to the ratio inside the log for numerical stability)
python
if kl_mean > self.desired_kl * 2.0:
	self.learning_rate = max(1e-5, self.learning_rate / 1.5)
elif kl_mean < self.desired_kl / 2.0 and kl_mean > 0.0:
	self.learning_rate = min(1e-2, self.learning_rate * 1.5)

for param_group in self.optimizer.param_groups:
	param_group['lr'] = self.learning_rate
  • The learning rate is adapted and written back into the optimizer's parameter groups:
    • KL too large → updates too aggressive → lower the learning rate
    • KL too small → updates too slow → raise the learning rate
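The per-dimension KL term can be sanity-checked against hand-computed cases with a small pure-Python version of the same formula:

```python
import math

def gauss_kl(mu_old, sigma_old, mu_new, sigma_new):
    # KL(old || new) for 1-D Gaussians, the same term the code sums per action dim
    return (math.log(sigma_new / sigma_old)
            + (sigma_old ** 2 + (mu_old - mu_new) ** 2) / (2.0 * sigma_new ** 2)
            - 0.5)

print(gauss_kl(0.0, 1.0, 0.0, 1.0))  # 0.0: identical distributions
print(gauss_kl(0.0, 1.0, 1.0, 1.0))  # 0.5: mean shifted by one sigma
```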

3-5 The Heart of PPO: the Probability Ratio
python
ratio = torch.exp(actions_log_prob_batch - old_actions_log_prob_batch)
  • This is PPO's core quantity: $r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)}$
  • It measures:
    • how likely the new policy is, relative to the old one, to take a given action.
    • If $r \approx 1$, the new and old policies are nearly identical.
    • If $r \gg 1$ or $r \ll 1$, the policy has changed too much.
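The code computes this ratio as exp(log π_new − log π_old) rather than dividing raw probabilities, which is numerically safer and mathematically equivalent:

```python
import math

p_new, p_old = 0.3, 0.2
ratio_direct = p_new / p_old
ratio_logspace = math.exp(math.log(p_new) - math.log(p_old))
print(ratio_direct, ratio_logspace)  # both 1.5 (up to float rounding)
```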

3-6 PPO's Second Core Piece: Clipping
python
surrogate = -torch.squeeze(advantages_batch) * ratio
surrogate_clipped = -torch.squeeze(advantages_batch) * torch.clamp(ratio, 1.0 - self.clip_param,
                                                                                1.0 + self.clip_param)
surrogate_loss = torch.max(surrogate, surrogate_clipped).mean()
  • This implements $L^{CLIP} = \mathbb{E}[\min(r_t(\theta)A_t,\ \mathrm{clip}(r_t(\theta),1-\epsilon,1+\epsilon)A_t)]$
  • The first line is the raw policy-gradient surrogate, corresponding to $L^{PG} = \mathbb{E}[r_t(\theta)A_t]$
  • Through this formula, PPO restricts $r(\theta)$ to $[1-\epsilon, 1+\epsilon]$; beyond this range the objective is clipped and the gradient stops growing.
  • Note the negative signs: PyTorch minimizes losses by default, so the max over the negated terms is the min over the original objective.

3-7 The Loss Function
python
 # Value function loss
if self.use_clipped_value_loss:
	value_clipped = target_values_batch + (value_batch - target_values_batch).clamp(-self.clip_param,
																					self.clip_param)
	value_losses = (value_batch - returns_batch).pow(2)
	value_losses_clipped = (value_clipped - returns_batch).pow(2)
	value_loss = torch.max(value_losses, value_losses_clipped).mean()
else:
	value_loss = (returns_batch - value_batch).pow(2).mean()

loss = surrogate_loss + self.value_loss_coef * value_loss - self.entropy_coef * entropy_batch.mean()
  • This assembles the full loss: $L = L_{policy} + c_1 L_{value} - c_2 H(\pi)$
  • where:
    • surrogate_loss: the policy (surrogate) loss
    • self.value_loss_coef * value_loss: the value-network loss
    • self.entropy_coef * entropy_batch.mean(): the entropy bonus
  • Depending on use_clipped_value_loss, value_loss is computed in one of two ways:
  1. Plain value loss: $L_V = (V_\theta(s) - R_t)^2$
  2. PPO value clipping: $V^{clip} = V_{old} + \mathrm{clip}(V_\theta - V_{old}, -\epsilon, \epsilon)$, $L_V = \max((V_\theta - R_t)^2,\ (V^{clip} - R_t)^2)$
  • Value clipping keeps the critic from updating too far in a single step.
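A per-sample sketch (plain Python, illustrative) shows what value clipping does: if the critic tries to jump far from its old estimate, the loss is evaluated as if the value had moved by at most clip_param, removing the incentive for large jumps:

```python
def clipped_value_loss(value, old_value, ret, clip=0.2):
    # V_clip = V_old + clip(V - V_old, -clip, +clip); keep the worse of the two errors
    v_clipped = old_value + max(-clip, min(clip, value - old_value))
    return max((value - ret) ** 2, (v_clipped - ret) ** 2)

# The new value exactly hits the return, but it moved 2.0 away from the old
# estimate, so the clipped branch dominates and the jump is still penalized:
print(clipped_value_loss(value=2.0, old_value=0.0, ret=2.0))  # ~3.24 rather than 0
```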

3-8 The Gradient Step
python
 # Gradient step
self.optimizer.zero_grad()
loss.backward()
nn.utils.clip_grad_norm_(self.actor_critic.parameters(), self.max_grad_norm)
self.optimizer.step()

mean_value_loss += value_loss.item()
mean_surrogate_loss += surrogate_loss.item()
  • The remainder is standard:
    • zero the gradients
    • backpropagate
    • clip the gradient norm: $\|g\| < \text{max\_grad\_norm}$
    • apply the optimizer step
    • accumulate the losses for logging
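Gradient-norm clipping rescales the entire gradient vector when its global norm exceeds the threshold; a minimal sketch of the rule that nn.utils.clip_grad_norm_ applies (the real function operates in place on the parameters' .grad tensors):

```python
import math

def clip_by_global_norm(grads, max_norm):
    # Scale all gradients by max_norm / ||g|| when the global norm exceeds max_norm
    total_norm = math.sqrt(sum(g * g for g in grads))
    scale = min(1.0, max_norm / (total_norm + 1e-6))
    return [g * scale for g in grads]

clipped = clip_by_global_norm([3.0, 4.0], max_norm=1.0)  # norm 5 -> rescaled to ~1
print(math.sqrt(sum(g * g for g in clipped)))            # ~1.0
```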

3-9 Wrapping Up the Outer Loop
  • Compute the average losses and clear the rollout buffer.
python
num_updates = self.num_learning_epochs * self.num_mini_batches
mean_value_loss /= num_updates
mean_surrogate_loss /= num_updates
self.storage.clear()

3-10 The PPO Training Loop
  • Each PPO training iteration:
  1. collects a rollout
  2. computes advantages (GAE)
  3. runs several epochs of mini-batch updates
  4. clips the policy ratio
  5. clips the value update
  6. applies entropy regularization
  • The mathematical objective: $L = \mathbb{E}[\min(r_t A_t,\ \mathrm{clip}(r_t,1-\epsilon,1+\epsilon)A_t)] + c_1(V-R)^2 - c_2 H(\pi)$
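Putting the pieces together, the surrounding training loop (driven by runners/on_policy_runner.py; the names below are illustrative stand-ins rather than the runner's exact variable names) looks roughly like this pseudocode:

```python
# Pseudocode sketch -- `env`, `ppo`, and the loop bounds are stand-ins.
for iteration in range(max_iterations):
    # 1. collect a rollout
    for step in range(num_steps_per_env):
        actions = ppo.act(obs, critic_obs)                 # sample + record
        obs, critic_obs, rewards, dones, infos = env.step(actions)
        ppo.process_env_step(rewards, dones, infos)        # store transition
    # 2. compute returns / advantages (GAE)
    ppo.compute_returns(critic_obs)
    # 3-6. clipped mini-batch updates
    mean_value_loss, mean_surrogate_loss = ppo.update()
```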

Summary

  • In this installment we walked through the Python implementation of PPO in the rsl_rl repository: hyperparameter initialization, the rollout buffer, action sampling, environment-feedback handling, advantage computation, and the policy update. The core mechanisms include probability-ratio clipping, GAE advantage estimation, value-function clipping, gradient-norm clipping, and an optional KL-based adaptive learning rate; combining the policy loss, value loss, and entropy bonus yields the full optimization objective for stable, efficient reinforcement-learning training of legged robots.
  • If you spot any mistakes, please point them out. Thanks for reading!