rsl_rl中PPO算法详解

on_policy_runner.py

```python
import os
import time
import torch
import wandb
import statistics
from collections import deque
from datetime import datetime
from .ppo import PPO
from .actor_critic import ActorCritic
from humanoid.algo.vec_env import VecEnv
from torch.utils.tensorboard import SummaryWriter


class OnPolicyRunner:
    """
    在线策略(On-Policy)强化学习训练器,基于PPO算法实现
    核心功能:
    1. 环境交互收集轨迹数据
    2. PPO算法参数更新
    3. 训练过程日志记录与模型保存
    4. 推理模式下的策略获取
    """

    def __init__(self, env: VecEnv, train_cfg, log_dir=None, device="cpu"):
        """
        初始化训练器
        Args:
            env: 向量化环境(VecEnv),支持多环境并行
            train_cfg: 训练配置字典,包含runner/algorithm/policy等子配置
            log_dir: 日志和模型保存目录
            device: 计算设备(cpu/cuda)
        """
        # 解析配置参数
        self.cfg = train_cfg["runner"]          # Runner配置(步数、保存间隔等)
        self.alg_cfg = train_cfg["algorithm"]   # PPO算法配置(学习率、clip范围等)
        self.policy_cfg = train_cfg["policy"]   # 策略网络配置(网络结构、隐藏层等)
        self.all_cfg = train_cfg                # 完整配置,用于wandb日志
        # 生成唯一的实验名称(时间戳+实验名+运行名)
        self.wandb_run_name = (
            datetime.now().strftime("%b%d_%H-%M-%S")
            + "_"
            + train_cfg["runner"]["experiment_name"]
            + "_"
            + train_cfg["runner"]["run_name"]
        )
        self.device = device
        self.env = env
        
        # 确定Critic网络的输入维度:优先使用特权观测(privileged obs),无则使用普通观测
        if self.env.num_privileged_obs is not None:
            num_critic_obs = self.env.num_privileged_obs
        else:
            num_critic_obs = self.env.num_obs
        
        # 动态加载策略网络类(默认为ActorCritic)
        actor_critic_class = eval(self.cfg["policy_class_name"])
        # 初始化Actor-Critic网络并移至指定设备
        actor_critic: ActorCritic = actor_critic_class(
            self.env.num_obs, num_critic_obs, self.env.num_actions, **self.policy_cfg
        ).to(self.device)
        
        # 动态加载PPO算法类
        alg_class = eval(self.cfg["algorithm_class_name"])
        self.alg: PPO = alg_class(actor_critic, device=self.device, **self.alg_cfg)
        
        # 每个环境收集的步数
        self.num_steps_per_env = self.cfg["num_steps_per_env"]
        # 模型保存间隔
        self.save_interval = self.cfg["save_interval"]

        # 初始化PPO的数据存储缓冲区
        self.alg.init_storage(
            self.env.num_envs,                # 环境数量
            self.num_steps_per_env,           # 每个环境收集步数
            [self.env.num_obs],               # 观测维度
            [self.env.num_privileged_obs],    # 特权观测维度
            [self.env.num_actions],           # 动作维度
        )

        # 日志相关初始化
        self.log_dir = log_dir
        self.writer = None                   # TensorBoard写入器
        self.tot_timesteps = 0               # 总交互步数
        self.tot_time = 0                    # 总训练时间
        self.current_learning_iteration = 0   # 当前学习迭代次数

        # 重置环境
        _, _ = self.env.reset()

    def learn(self, num_learning_iterations, init_at_random_ep_len=False):
        """
        核心训练循环
        Args:
            num_learning_iterations: 学习迭代次数
            init_at_random_ep_len: 是否随机初始化episode长度缓冲区
        """
        # 初始化日志工具(wandb + tensorboard)
        if self.log_dir is not None and self.writer is None:
            wandb.init(
                project="XBot",
                sync_tensorboard=True,
                name=self.wandb_run_name,
                config=self.all_cfg,
            )
            self.writer = SummaryWriter(log_dir=self.log_dir, flush_secs=10)
        
        # 随机初始化episode长度(用于多样化训练)
        if init_at_random_ep_len:
            self.env.episode_length_buf = torch.randint_like(
                self.env.episode_length_buf, high=int(self.env.max_episode_length)
            )
        
        # 获取初始观测
        obs = self.env.get_observations()
        privileged_obs = self.env.get_privileged_observations()
        # Critic输入:优先使用特权观测
        critic_obs = privileged_obs if privileged_obs is not None else obs
        # 数据移至指定设备
        obs, critic_obs = obs.to(self.device), critic_obs.to(self.device)
        # 设置网络为训练模式(启用dropout/batchnorm等)
        self.alg.actor_critic.train()

        # 用于存储episode信息
        ep_infos = []
        # 滑动窗口存储最近100个episode的奖励和长度
        rewbuffer = deque(maxlen=100)
        lenbuffer = deque(maxlen=100)
        # 累计每个环境的当前episode奖励
        cur_reward_sum = torch.zeros(
            self.env.num_envs, dtype=torch.float, device=self.device
        )
        # 累计每个环境的当前episode长度
        cur_episode_length = torch.zeros(
            self.env.num_envs, dtype=torch.float, device=self.device
        )

        # 计算总迭代次数
        tot_iter = self.current_learning_iteration + num_learning_iterations
        
        # 主训练循环
        for it in range(self.current_learning_iteration, tot_iter):
            start = time.time()
            
            # ====================== 第一步:收集轨迹数据(Rollout) ======================
            # 推理模式:禁用梯度计算,提升速度
            with torch.inference_mode():
                # 每个迭代收集num_steps_per_env步数据
                for i in range(self.num_steps_per_env):
                    # 根据当前观测选择动作
                    # 数学原理:Actor网络输出动作分布,采样得到动作
                    # $a_t \sim \pi_\theta(a_t|s_t)$
                    actions = self.alg.act(obs, critic_obs)
                    
                    # 环境执行动作,获取反馈
                    obs, privileged_obs, rewards, dones, infos = self.env.step(actions)
                    critic_obs = privileged_obs if privileged_obs is not None else obs
                    
                    # 数据移至指定设备
                    obs, critic_obs, rewards, dones = (
                        obs.to(self.device),
                        critic_obs.to(self.device),
                        rewards.to(self.device),
                        dones.to(self.device),
                    )
                    
                    # 存储环境交互数据
                    self.alg.process_env_step(rewards, dones, infos)

                    # 日志记录:更新奖励和episode长度统计
                    if self.log_dir is not None:
                        if "episode" in infos:
                            ep_infos.append(infos["episode"])
                        # 累计奖励和步数
                        cur_reward_sum += rewards
                        cur_episode_length += 1
                        # 找出结束的episode
                        new_ids = (dones > 0).nonzero(as_tuple=False)
                        # 更新滑动窗口
                        rewbuffer.extend(
                            cur_reward_sum[new_ids][:, 0].cpu().numpy().tolist()
                        )
                        lenbuffer.extend(
                            cur_episode_length[new_ids][:, 0].cpu().numpy().tolist()
                        )
                        # 重置结束episode的累计值
                        cur_reward_sum[new_ids] = 0
                        cur_episode_length[new_ids] = 0

                # 计算数据收集时间
                stop = time.time()
                collection_time = stop - start

                # ====================== 第二步:计算回报(Returns) ======================
                # 数学原理:广义优势估计(GAE)
                # $\hat{A}_t^{GAE(\gamma,\lambda)} = \sum_{k=0}^{\infty} (\gamma\lambda)^k \delta_{t+k}$
                # 其中时序差分误差 $\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)$
                # 回报 $U_t = \hat{A}_t + V(s_t)$
                start = stop
                self.alg.compute_returns(critic_obs)

            # ====================== 第三步:PPO策略更新 ======================
            # 数学原理:PPO-Clip目标函数
            # $\mathcal{L}^{CLIP}(\theta) = \hat{\mathbb{E}}_t\left[ \min(r_t(\theta) \hat{A}_t, \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) \hat{A}_t) \right]$
            # 其中比率 $r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)}$
            mean_value_loss, mean_surrogate_loss = self.alg.update()
            
            # 计算学习时间
            stop = time.time()
            learn_time = stop - start
            
            # 记录日志
            if self.log_dir is not None:
                self.log(locals())
            
            # 定期保存模型
            if it % self.save_interval == 0:
                self.save(os.path.join(self.log_dir, "model_{}.pt".format(it)))
            
            # 清空episode信息
            ep_infos.clear()

        # 更新迭代次数
        self.current_learning_iteration += num_learning_iterations
        # 保存最终模型
        self.save(
            os.path.join(
                self.log_dir, "model_{}.pt".format(self.current_learning_iteration)
            )
        )

    def log(self, locs, width=80, pad=35):
        """
        训练日志记录函数
        Args:
            locs: 本地变量字典(包含当前迭代的所有信息)
            width: 日志打印宽度
            pad: 日志对齐填充长度
        """
        # 更新总步数和总时间
        self.tot_timesteps += self.num_steps_per_env * self.env.num_envs
        self.tot_time += locs["collection_time"] + locs["learn_time"]
        iteration_time = locs["collection_time"] + locs["learn_time"]

        # 构建episode信息字符串
        ep_string = f""
        if locs["ep_infos"]:
            for key in locs["ep_infos"][0]:
                infotensor = torch.tensor([], device=self.device)
                for ep_info in locs["ep_infos"]:
                    # 统一处理标量和0维张量
                    if not isinstance(ep_info[key], torch.Tensor):
                        ep_info[key] = torch.Tensor([ep_info[key]])
                    if len(ep_info[key].shape) == 0:
                        ep_info[key] = ep_info[key].unsqueeze(0)
                    infotensor = torch.cat((infotensor, ep_info[key].to(self.device)))
                # 计算平均值并记录到tensorboard
                value = torch.mean(infotensor)
                self.writer.add_scalar("Episode/" + key, value, locs["it"])
                ep_string += f"""{f'Mean episode {key}:':>{pad}} {value:.4f}\n"""
        
        # 计算策略输出的平均动作噪声标准差
        mean_std = self.alg.actor_critic.std.mean()
        
        # 计算每秒处理的步数(FPS)
        fps = int(
            self.num_steps_per_env
            * self.env.num_envs
            / (locs["collection_time"] + locs["learn_time"])
        )

        # 记录核心训练指标到tensorboard
        self.writer.add_scalar("Loss/value_function", locs["mean_value_loss"], locs["it"])
        self.writer.add_scalar("Loss/surrogate", locs["mean_surrogate_loss"], locs["it"])
        self.writer.add_scalar("Loss/learning_rate", self.alg.learning_rate, locs["it"])
        self.writer.add_scalar("Policy/mean_noise_std", mean_std.item(), locs["it"])
        self.writer.add_scalar("Perf/total_fps", fps, locs["it"])
        self.writer.add_scalar("Perf/collection time", locs["collection_time"], locs["it"])
        self.writer.add_scalar("Perf/learning_time", locs["learn_time"], locs["it"])
        
        # 记录奖励和episode长度统计
        if len(locs["rewbuffer"]) > 0:
            self.writer.add_scalar(
                "Train/mean_reward", statistics.mean(locs["rewbuffer"]), locs["it"]
            )
            self.writer.add_scalar(
                "Train/mean_episode_length",
                statistics.mean(locs["lenbuffer"]),
                locs["it"],
            )
            self.writer.add_scalar(
                "Train/mean_reward/time",
                statistics.mean(locs["rewbuffer"]),
                self.tot_time,
            )
            self.writer.add_scalar(
                "Train/mean_episode_length/time",
                statistics.mean(locs["lenbuffer"]),
                self.tot_time,
            )

        # 构建日志字符串(带格式)
        str = f" \033[1m Learning iteration {locs['it']}/{self.current_learning_iteration + locs['num_learning_iterations']} \033[0m "

        if len(locs["rewbuffer"]) > 0:
            log_string = (
                f"""{'#' * width}\n"""
                f"""{str.center(width, ' ')}\n\n"""
                f"""{'Computation:':>{pad}} {fps:.0f} steps/s (collection: {locs[
                            'collection_time']:.3f}s, learning {locs['learn_time']:.3f}s)\n"""
                f"""{'Value function loss:':>{pad}} {locs['mean_value_loss']:.4f}\n"""
                f"""{'Surrogate loss:':>{pad}} {locs['mean_surrogate_loss']:.4f}\n"""
                f"""{'Mean action noise std:':>{pad}} {mean_std.item():.2f}\n"""
                f"""{'Mean reward:':>{pad}} {statistics.mean(locs['rewbuffer']):.2f}\n"""
                f"""{'Mean episode length:':>{pad}} {statistics.mean(locs['lenbuffer']):.2f}\n"""
            )
        else:
            log_string = (
                f"""{'#' * width}\n"""
                f"""{str.center(width, ' ')}\n\n"""
                f"""{'Computation:':>{pad}} {fps:.0f} steps/s (collection: {locs[
                            'collection_time']:.3f}s, learning {locs['learn_time']:.3f}s)\n"""
                f"""{'Value function loss:':>{pad}} {locs['mean_value_loss']:.4f}\n"""
                f"""{'Surrogate loss:':>{pad}} {locs['mean_surrogate_loss']:.4f}\n"""
                f"""{'Mean action noise std:':>{pad}} {mean_std.item():.2f}\n"""
            )

        # 补充episode信息和总计信息
        log_string += ep_string
        log_string += (
            f"""{'-' * width}\n"""
            f"""{'Total timesteps:':>{pad}} {self.tot_timesteps}\n"""
            f"""{'Iteration time:':>{pad}} {iteration_time:.2f}s\n"""
            f"""{'Total time:':>{pad}} {self.tot_time:.2f}s\n"""
            f"""{'ETA:':>{pad}} {self.tot_time / (locs['it'] + 1) * (
                               locs['num_learning_iterations'] - locs['it']):.1f}s\n"""
        )
        # 打印日志
        print(log_string)

    def save(self, path, infos=None):
        """
        保存模型和优化器状态
        Args:
            path: 保存路径
            infos: 额外保存的信息
        """
        torch.save(
            {
                "model_state_dict": self.alg.actor_critic.state_dict(),  # 策略网络参数
                "optimizer_state_dict": self.alg.optimizer.state_dict(), # 优化器状态
                "iter": self.current_learning_iteration,                 # 当前迭代次数
                "infos": infos,                                          # 额外信息
            },
            path,
        )

    def load(self, path, load_optimizer=True):
        """
        加载模型和优化器状态
        Args:
            path: 模型路径
            load_optimizer: 是否加载优化器状态
        Returns:
            infos: 保存的额外信息
        """
        loaded_dict = torch.load(path)
        # 加载网络参数
        self.alg.actor_critic.load_state_dict(loaded_dict["model_state_dict"])
        # 加载优化器状态(可选)
        if load_optimizer:
            self.alg.optimizer.load_state_dict(loaded_dict["optimizer_state_dict"])
        # 恢复迭代次数
        self.current_learning_iteration = loaded_dict["iter"]
        return loaded_dict["infos"]

    def get_inference_policy(self, device=None):
        """
        获取推理模式的策略函数(用于部署/测试)
        Args:
            device: 推理设备
        Returns:
            act_inference: 推理模式的动作选择函数
        """
        # 设置网络为评估模式(禁用dropout/batchnorm)
        self.alg.actor_critic.eval()
        if device is not None:
            self.alg.actor_critic.to(device)
        # 返回推理用的动作选择函数
        return self.alg.actor_critic.act_inference

    def get_inference_critic(self, device=None):
        """
        获取推理模式的Critic函数(用于价值评估)
        Args:
            device: 推理设备
        Returns:
            evaluate: 状态价值评估函数
        """
        self.alg.actor_critic.eval()
        if device is not None:
            self.alg.actor_critic.to(device)
        return self.alg.actor_critic.evaluate
```
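
下面给出一个极简的 train_cfg 结构示意,帮助对照上面 __init__ 中读取的各个配置项。字段名取自上文代码,具体数值以及 max_iterations 这一项均为假设,实际字段请以所用仓库的配置类为准;OnPolicyRunner 的构造需要一个 VecEnv 实例,这里只用注释示意。

```python
# 一个最小的 train_cfg 结构示意(字段名来自上文代码,数值为假设)
train_cfg = {
    "runner": {
        "policy_class_name": "ActorCritic",   # eval() 动态加载的策略类名
        "algorithm_class_name": "PPO",        # eval() 动态加载的算法类名
        "num_steps_per_env": 24,              # 每次迭代、每个环境收集的步数
        "save_interval": 100,                 # 每隔多少次迭代保存一次模型
        "experiment_name": "xbot_walk",       # 与 run_name 一起拼进 wandb run 名称
        "run_name": "baseline",
        "max_iterations": 3000,               # (假设字段)总迭代数,通常由外部训练脚本读取
    },
    "algorithm": {                            # 对应 PPO.__init__ 的关键字参数
        "clip_param": 0.2,
        "gamma": 0.998,
        "lam": 0.95,
        "learning_rate": 1e-3,
        "num_learning_epochs": 5,
        "num_mini_batches": 4,
        "entropy_coef": 0.01,
        "value_loss_coef": 1.0,
        "max_grad_norm": 1.0,
        "use_clipped_value_loss": True,
        "schedule": "adaptive",
        "desired_kl": 0.01,
    },
    "policy": {                               # 对应 ActorCritic.__init__ 的关键字参数
        "actor_hidden_dims": [256, 256, 256],
        "critic_hidden_dims": [256, 256, 256],
        "init_noise_std": 1.0,
    },
}

# 实际使用(伪代码,env 需为 VecEnv 实例):
# runner = OnPolicyRunner(env, train_cfg, log_dir="logs/xbot", device="cuda:0")
# runner.learn(num_learning_iterations=train_cfg["runner"]["max_iterations"])
```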

ppo.py

```python
import torch
import torch.nn as nn
import torch.optim as optim

from .actor_critic import ActorCritic
from .rollout_storage import RolloutStorage

class PPO:
    """
    近端策略优化(PPO)算法实现类
    PPO是一种基于信任区域的策略梯度算法,通过裁剪目标函数来保证策略更新的幅度在可控范围内
    核心论文:https://arxiv.org/abs/1707.06347
    """
    actor_critic: ActorCritic  # 类型注解:Actor-Critic网络
    
    def __init__(self,
                 actor_critic,
                 num_learning_epochs=1,        # 每个更新周期的学习轮数
                 num_mini_batches=1,           # 每次更新的小批次数量
                 clip_param=0.2,               # PPO裁剪参数 ε
                 gamma=0.998,                  # 折扣因子 γ
                 lam=0.95,                     # GAE的lambda参数 λ
                 value_loss_coef=1.0,          # 价值损失系数
                 entropy_coef=0.0,             # 熵奖励系数(鼓励探索)
                 learning_rate=1e-3,           # 学习率
                 max_grad_norm=1.0,            # 梯度裁剪最大值
                 use_clipped_value_loss=True,  # 是否使用裁剪的价值损失
                 schedule="fixed",             # 学习率调度方式:fixed/adaptive
                 desired_kl=0.01,              # 期望的KL散度(自适应学习率用)
                 device='cpu',
                 ):

        self.device = device

        # 学习率自适应相关参数
        self.desired_kl = desired_kl
        self.schedule = schedule
        self.learning_rate = learning_rate

        # PPO核心组件初始化
        self.actor_critic = actor_critic
        self.actor_critic.to(self.device)
        self.storage = None  # 经验存储池(后续初始化)
        # 使用Adam优化器,默认betas=(0.9, 0.999)
        self.optimizer = optim.Adam(self.actor_critic.parameters(), lr=learning_rate)
        self.transition = RolloutStorage.Transition()  # 单步转移数据容器

        # PPO超参数
        self.clip_param = clip_param
        self.num_learning_epochs = num_learning_epochs
        self.num_mini_batches = num_mini_batches
        self.value_loss_coef = value_loss_coef
        self.entropy_coef = entropy_coef
        self.gamma = gamma      # 折扣因子
        self.lam = lam          # GAE参数
        self.max_grad_norm = max_grad_norm  # 梯度最大范数
        self.use_clipped_value_loss = use_clipped_value_loss  # 价值损失是否裁剪

    def init_storage(self, num_envs, num_transitions_per_env, actor_obs_shape, critic_obs_shape, action_shape):
        """
        初始化经验存储池
        Args:
            num_envs: 环境数量
            num_transitions_per_env: 每个环境存储的转移步数
            actor_obs_shape: 策略网络观测维度
            critic_obs_shape: 价值网络观测维度
            action_shape: 动作维度
        """
        self.storage = RolloutStorage(
            num_envs, num_transitions_per_env, 
            actor_obs_shape, critic_obs_shape, 
            action_shape, self.device
        )

    def test_mode(self):
        """设置模型为测试模式(关闭dropout/batchnorm等)"""
        self.actor_critic.test()
    
    def train_mode(self):
        """设置模型为训练模式"""
        self.actor_critic.train()

    def act(self, obs, critic_obs):
        """
        根据当前观测执行动作(收集经验阶段)
        Args:
            obs: 策略网络观测 [num_envs, obs_dim]
            critic_obs: 价值网络观测 [num_envs, critic_obs_dim]
        Returns:
            actions: 采样的动作 [num_envs, action_dim]
        """
        # 前向传播计算动作和价值
        self.transition.actions = self.actor_critic.act(obs).detach()  # 采样动作
        self.transition.values = self.actor_critic.evaluate(critic_obs).detach()  # 价值估计
        # 计算动作的对数概率 logπ(a|s)
        self.transition.actions_log_prob = self.actor_critic.get_actions_log_prob(self.transition.actions).detach()
        self.transition.action_mean = self.actor_critic.action_mean.detach()  # 动作分布均值
        self.transition.action_sigma = self.actor_critic.action_std.detach()  # 动作分布标准差
        
        # 记录当前观测(在环境step前)
        self.transition.observations = obs
        self.transition.critic_observations = critic_obs
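        # 注意:上面缓存的 values / actions_log_prob / action_mean / action_sigma
        # 都是采样时刻"旧策略" π_old 的量(已 detach),update() 中会用它们
        # 计算重要性比率 r_t(θ)=π_new/π_old、KL 散度以及裁剪的价值损失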
        
        return self.transition.actions
    
    def process_env_step(self, rewards, dones, infos):
        """
        处理环境step返回的结果,更新转移数据并存储
        Args:
            rewards: 环境奖励 [num_envs]
            dones: 环境结束标志 [num_envs]
            infos: 环境附加信息
        """
        self.transition.rewards = rewards.clone()
        self.transition.dones = dones
        
        # 对超时终止的情况进行bootstrapping(时间限制而非任务失败)
        if 'time_outs' in infos:
            # 超时奖励修正:r += γ * V(s) * timeout_flag
            self.transition.rewards += self.gamma * torch.squeeze(
                self.transition.values * infos['time_outs'].unsqueeze(1).to(self.device), 1
            )
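        # 说明:超时截断时 dones=1,GAE 反向递推会把下一状态的价值清零;
        # 但超时并非任务失败,所以这里先把 γ*V(s_t)(用 V(s_t) 近似 V(s_{t+1}))
        # 加回奖励,相当于对被时间截断的轨迹继续用价值函数做 bootstrap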

        # 将当前转移数据添加到存储池
        self.storage.add_transitions(self.transition)
        self.transition.clear()  # 清空临时转移数据
        self.actor_critic.reset(dones)  # 重置RNN等状态(如果有)
    
    def compute_returns(self, last_critic_obs):
        """
        计算折扣回报和优势函数(使用GAE算法)
        Args:
            last_critic_obs: 最后一步的价值网络观测
        """
        # 获取最后一步的价值估计 V(T)
        last_values = self.actor_critic.evaluate(last_critic_obs).detach()
        
        # 核心公式:GAE (Generalized Advantage Estimation)
        # A^GAE_t = δ_t + γλδ_{t+1} + (γλ)^2δ_{t+2} + ... + (γλ)^{T-t-1}δ_{T-1}
        # 其中 δ_t = r_t + γV(s_{t+1}) - V(s_t)  (时序差分误差)
        # 
        # 折扣回报: U_t = A^GAE_t + V(s_t)
        self.storage.compute_returns(last_values, self.gamma, self.lam)
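        # 说明:rollout 每 num_steps_per_env 步就截断一次,并不对齐 episode 边界,
        # 因此末尾用 last_values = V(s_T) 做 bootstrap,
        # 避免把"轨迹被截断"误当成"后续回报为 0"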

    def update(self):
        """
        PPO核心更新过程(策略和价值网络)
        Returns:
            mean_value_loss: 平均价值损失
            mean_surrogate_loss: 平均策略损失
        """
        mean_value_loss = 0  # 价值损失累计
        mean_surrogate_loss = 0  # 策略损失累计

        # 生成小批次数据迭代器(打乱数据,分批次)
        generator = self.storage.mini_batch_generator(self.num_mini_batches, self.num_learning_epochs)
        
        # 遍历所有小批次数据
        for (obs_batch, critic_obs_batch, actions_batch, target_values_batch, 
             advantages_batch, returns_batch, old_actions_log_prob_batch, 
             old_mu_batch, old_sigma_batch, hid_states_batch, masks_batch) in generator:

            # 前向传播:计算当前策略的动作分布和价值
            self.actor_critic.act(obs_batch, masks=masks_batch, hidden_states=hid_states_batch[0])
            # 当前策略的动作对数概率 logπ_new(a|s)
            actions_log_prob_batch = self.actor_critic.get_actions_log_prob(actions_batch)
            # 当前价值估计 V_new(s)
            value_batch = self.actor_critic.evaluate(critic_obs_batch, masks=masks_batch, hidden_states=hid_states_batch[1])
            mu_batch = self.actor_critic.action_mean  # 当前动作分布均值
            sigma_batch = self.actor_critic.action_std  # 当前动作分布标准差
            entropy_batch = self.actor_critic.entropy  # 动作分布的熵(鼓励探索)

            # -------------------------- 自适应学习率(基于KL散度) --------------------------
            if self.desired_kl is not None and self.schedule == 'adaptive':
                with torch.inference_mode():  # 推理模式,不计算梯度
                    # 高斯分布的KL散度计算公式:
                    # KL(π_old || π_new) = 1/2 * [ (μ_old-μ_new)²/σ_new² + (σ_old²/σ_new²) - 1 + 2ln(σ_new/σ_old) ]
                    # 简化版(代码实现):
                    kl = torch.sum(
                        torch.log(sigma_batch / old_sigma_batch + 1.e-5) +  # +1e-5 保证数值稳定,避免 log(0)
                        (torch.square(old_sigma_batch) + torch.square(old_mu_batch - mu_batch)) / (2.0 * torch.square(sigma_batch)) 
                        - 0.5, axis=-1
                    )
                    kl_mean = torch.mean(kl)  # 平均KL散度

                    # 根据KL散度调整学习率
                    if kl_mean > self.desired_kl * 2.0:  # KL过大,减小学习率
                        self.learning_rate = max(1e-5, self.learning_rate / 1.5)
                    elif kl_mean < self.desired_kl / 2.0 and kl_mean > 0.0:  # KL过小,增大学习率
                        self.learning_rate = min(1e-2, self.learning_rate * 1.5)
                    
                    # 更新优化器学习率
                    for param_group in self.optimizer.param_groups:
                        param_group['lr'] = self.learning_rate

            # -------------------------- PPO裁剪目标函数(核心) --------------------------
            # 重要性采样比率:r_t(θ) = π_θ(a_t|s_t) / π_θ_old(a_t|s_t)
            # 对数域计算:exp(logπ_new - logπ_old) = π_new / π_old
            ratio = torch.exp(actions_log_prob_batch - torch.squeeze(old_actions_log_prob_batch))
            
            # 原始策略目标:-A_t * r_t(θ) (负号因为要最小化损失)
            surrogate = -torch.squeeze(advantages_batch) * ratio
            
            # 裁剪后的策略目标:-A_t * clip(r_t(θ), 1-ε, 1+ε)
            surrogate_clipped = -torch.squeeze(advantages_batch) * torch.clamp(
                ratio, 1.0 - self.clip_param, 1.0 + self.clip_param
            )
            
            # PPO核心损失:取两者的最大值(保守更新)
            # L^CLIP(θ) = E[ min(r_t(θ)A_t, clip(r_t(θ), 1-ε, 1+ε)A_t) ]
            # 代码中取负号转为最小化问题:-min(...) → max(...)
            surrogate_loss = torch.max(surrogate, surrogate_clipped).mean()
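            # 数值示例:设 ε=0.2、A_t=+1、ratio=1.5,
            # 则未裁剪项=-1.5,裁剪项=-1.2,max 取 -1.2,对 θ 的梯度为 0,
            # 即策略不会被继续推离旧策略;而当比率向不利方向偏离
            # (例如 A_t>0 但 ratio<1-ε)时,max 取未裁剪项,梯度仍会把策略拉回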

            # -------------------------- 价值函数损失 --------------------------
            if self.use_clipped_value_loss:
                # 裁剪的价值损失(PPO2做法)
                # V_clipped = V_old + clip(V_new - V_old, -ε, ε)
                value_clipped = target_values_batch + (value_batch - target_values_batch).clamp(
                    -self.clip_param, self.clip_param
                )
                # 原始价值损失:(V_new - U_t)²
                value_losses = (value_batch - returns_batch).pow(2)
                # 裁剪后价值损失:(V_clipped - U_t)²
                value_losses_clipped = (value_clipped - returns_batch).pow(2)
                # 取最大值作为价值损失(保守更新)
                value_loss = torch.max(value_losses, value_losses_clipped).mean()
            else:
                # 普通MSE价值损失:L_VF(θ) = E[(V_θ(s_t) - U_t)²]
                value_loss = (returns_batch - value_batch).pow(2).mean()

            # -------------------------- 总损失计算 --------------------------
            # 总损失 = 策略损失 + 价值损失系数×价值损失 - 熵系数×熵(负号鼓励探索)
            # L_total = L^CLIP(θ) + c1*L_VF(θ) - c2*H(π_θ(s_t))
            loss = surrogate_loss + self.value_loss_coef * value_loss - self.entropy_coef * entropy_batch.mean()

            # -------------------------- 梯度更新 --------------------------
            self.optimizer.zero_grad()  # 清空梯度
            loss.backward()  # 反向传播计算梯度
            # 梯度裁剪(防止梯度爆炸)
            nn.utils.clip_grad_norm_(self.actor_critic.parameters(), self.max_grad_norm)
            self.optimizer.step()  # 参数更新

            # 累计损失值
            mean_value_loss += value_loss.item()
            mean_surrogate_loss += surrogate_loss.item()

        # 计算平均损失(除以总更新次数)
        num_updates = self.num_learning_epochs * self.num_mini_batches
        mean_value_loss /= num_updates
        mean_surrogate_loss /= num_updates
        
        # 清空存储池,准备下一轮收集经验
        self.storage.clear()

        return mean_value_loss, mean_surrogate_loss
```
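
为了更直观地感受上面 update() 中裁剪目标的行为,下面是一段可独立运行的玩具示例:手工构造几组(优势、新旧 log 概率),比较 clip 前后的代理损失。数值纯属假设,只用于说明 max(·,·) 的取舍逻辑,与 rsl_rl 的实现本身无关。

```python
import torch

clip_param = 0.2

# 手工构造 3 个样本:优势 A_t 与新旧策略的 log 概率
advantages = torch.tensor([1.0, 1.0, -1.0])
old_log_prob = torch.tensor([-1.0, -1.0, -1.0])
new_log_prob = torch.tensor([-0.5, -1.5, -0.5])   # 对应 ratio ≈ 1.65 / 0.61 / 1.65

# 重要性采样比率 r_t(θ) = exp(logπ_new - logπ_old)
ratio = torch.exp(new_log_prob - old_log_prob)

# 未裁剪 / 裁剪后的代理目标(取负号转成最小化)
surrogate = -advantages * ratio
surrogate_clipped = -advantages * torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param)

# PPO 损失:逐样本取 max(更悲观的那一项),再求均值
surrogate_loss = torch.max(surrogate, surrogate_clipped).mean()

print("ratio            :", ratio)
print("surrogate        :", surrogate)
print("surrogate_clipped:", surrogate_clipped)
print("surrogate_loss   :", surrogate_loss.item())
# 可以看到:当 A_t>0 且 ratio>1+ε 时取的是 -A_t*(1+ε),
# 梯度不再鼓励 ratio 继续增大,从而限制单次更新的幅度
```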

actor_critic.py

```python
import torch
import torch.nn as nn
from torch.distributions import Normal  # 高斯分布(正态分布)

class ActorCritic(nn.Module):
    """
    Actor-Critic网络(PPO算法核心组件)
    - Actor:策略网络,输出动作分布的均值,结合可学习的标准差构建高斯分布,用于采样动作
    - Critic:价值网络,输出状态价值V(s),用于估计状态的期望回报
    核心原理:策略网络参数化动作分布π(a|s;θ),价值网络参数化状态价值V(s;φ)
    """
    def __init__(self,  
                 num_actor_obs,          # Actor网络输入维度(策略观测空间维度)
                 num_critic_obs,         # Critic网络输入维度(价值观测空间维度)
                 num_actions,            # 动作空间维度
                 actor_hidden_dims=[256, 256, 256],  # Actor网络隐藏层维度
                 critic_hidden_dims=[256, 256, 256], # Critic网络隐藏层维度
                 init_noise_std=1.0,     # 动作分布标准差的初始值
                 activation=nn.ELU(),    # 激活函数(ELU 在负半轴仍有非零梯度,可缓解 ReLU 的"神经元死亡"问题)
                 **kwargs):
        # 处理多余的关键字参数
        if kwargs:
            print("ActorCritic.__init__ got unexpected arguments, which will be ignored: " + str([key for key in kwargs.keys()]))
        super(ActorCritic, self).__init__()

        # 输入维度定义
        mlp_input_dim_a = num_actor_obs  # Actor MLP输入维度
        mlp_input_dim_c = num_critic_obs # Critic MLP输入维度
        
        # ====================== 1. 构建Actor(策略)网络 ======================
        # Actor网络:输入状态s,输出动作分布的均值μ(s;θ)
        # 网络结构:全连接层 + 激活函数,最终输出维度=动作空间维度
        actor_layers = []
        # 第一层:输入层 → 第一个隐藏层
        actor_layers.append(nn.Linear(mlp_input_dim_a, actor_hidden_dims[0]))
        actor_layers.append(activation)
        # 构建隐藏层
        for l in range(len(actor_hidden_dims)):
            if l == len(actor_hidden_dims) - 1:
                # 最后一层:隐藏层 → 动作均值输出层(无激活函数,因为动作可以是任意实数)
                actor_layers.append(nn.Linear(actor_hidden_dims[l], num_actions))
            else:
                # 中间层:隐藏层 → 下一个隐藏层
                actor_layers.append(nn.Linear(actor_hidden_dims[l], actor_hidden_dims[l + 1]))
                actor_layers.append(activation)
        # 组合成Sequential网络
        self.actor = nn.Sequential(*actor_layers)

        # ====================== 2. 构建Critic(价值)网络 ======================
        # Critic网络:输入状态s,输出状态价值V(s;φ)
        # 网络结构:全连接层 + 激活函数,最终输出维度=1(标量价值)
        critic_layers = []
        # 第一层:输入层 → 第一个隐藏层
        critic_layers.append(nn.Linear(mlp_input_dim_c, critic_hidden_dims[0]))
        critic_layers.append(activation)
        # 构建隐藏层
        for l in range(len(critic_hidden_dims)):
            if l == len(critic_hidden_dims) - 1:
                # 最后一层:隐藏层 → 价值输出层(无激活函数,价值可以是任意实数)
                critic_layers.append(nn.Linear(critic_hidden_dims[l], 1))
            else:
                # 中间层:隐藏层 → 下一个隐藏层
                critic_layers.append(nn.Linear(critic_hidden_dims[l], critic_hidden_dims[l + 1]))
                critic_layers.append(activation)
        # 组合成Sequential网络
        self.critic = nn.Sequential(*critic_layers)

        # 打印网络结构
        print(f"Actor MLP: {self.actor}")
        print(f"Critic MLP: {self.critic}")

        # ====================== 3. 动作分布参数初始化 ======================
        # 动作分布的标准差σ:与状态无关的可学习参数
        # 初始值为init_noise_std,形状为[num_actions](每个动作维度有独立的标准差)
        self.std = nn.Parameter(init_noise_std * torch.ones(num_actions))
        self.distribution = None  # 存储当前的动作分布(高斯分布)
        # 禁用Normal分布的参数验证(加速计算,因为我们确保输入合法)
        Normal.set_default_validate_args = False
        

    @staticmethod
    # 权重初始化函数(当前未使用)
    def init_weights(sequential, scales):
        """
        正交初始化网络权重(强化学习常用初始化方式,保持梯度稳定)
        Args:
            sequential: 待初始化的Sequential网络
            scales: 每层的增益系数
        """
        [torch.nn.init.orthogonal_(module.weight, gain=scales[idx]) for idx, module in
         enumerate(mod for mod in sequential if isinstance(mod, nn.Linear))]

    def reset(self, dones=None):
        """重置网络状态(用于RNN等有状态网络,当前MLP无状态,故空实现)"""
        pass

    def forward(self):
        """禁止直接调用forward,因为Actor-Critic有两个独立的前向路径"""
        raise NotImplementedError
    
    # ====================== 4. 分布属性获取(便捷属性) ======================
    @property
    def action_mean(self):
        """获取当前动作分布的均值μ"""
        return self.distribution.mean

    @property
    def action_std(self):
        """获取当前动作分布的标准差σ"""
        return self.distribution.stddev
    
    @property
    def entropy(self):
        """
        计算动作分布的熵H(π),用于鼓励探索
        高斯分布熵公式:H(N(μ,σ²)) = (1/2) * log(2πeσ²)  (单维度)
        多维度熵:sum(H_i) (假设动作维度独立)
        """
        return self.distribution.entropy().sum(dim=-1)

    def update_distribution(self, observations):
        """
        更新动作分布:根据当前观测计算动作均值,构建高斯分布
        核心公式:π(a|s) = N(μ(s;θ), σ²I)
        其中μ是Actor网络输出,σ是可学习参数,I是单位矩阵(动作维度独立)
        
        Args:
            observations: 观测张量 [batch_size, num_actor_obs]
        """
        # Actor网络前向传播:计算动作均值μ(s;θ)
        mean = self.actor(observations)
        # 构建高斯分布:均值=mean,标准差=self.std(广播到batch维度)
        # mean*0. + self.std:将标准差广播成与均值相同的形状 [batch, num_actions],数值上仍等于 self.std
        self.distribution = Normal(mean, mean*0. + self.std)

    def act(self, observations, **kwargs):
        """
        根据当前观测采样动作(训练阶段)
        Args:
            observations: 观测张量 [batch_size, num_actor_obs]
        Returns:
            actions: 采样的动作 [batch_size, num_actions]
        """
        # 更新动作分布
        self.update_distribution(observations)
        # 从高斯分布中采样动作:a ~ N(μ(s;θ), σ²)
        return self.distribution.sample()
    
    def get_actions_log_prob(self, actions):
        """
        计算给定动作的对数概率logπ(a|s)
        高斯分布对数概率公式:
        logπ(a|s) = -1/2 * [ (a-μ)²/σ² + log(2πσ²) ]  (单维度)
        多维度:sum(logπ_i) (动作维度独立,联合概率=各维度概率乘积,对数和)
        
        Args:
            actions: 动作张量 [batch_size, num_actions]
        Returns:
            log_prob: 对数概率 [batch_size]
        """
        # 计算每个动作维度的对数概率,然后求和(多维度联合概率)
        return self.distribution.log_prob(actions).sum(dim=-1)

    def act_inference(self, observations):
        """
        推理阶段获取动作(测试/部署阶段):直接返回动作均值(无探索噪声)
        Args:
            observations: 观测张量 [batch_size, num_actor_obs]
        Returns:
            actions_mean: 最优动作(均值) [batch_size, num_actions]
        """
        actions_mean = self.actor(observations)
        return actions_mean

    def evaluate(self, critic_observations, **kwargs):
        """
        评估状态价值V(s;φ)
        Args:
            critic_observations: 价值网络观测 [batch_size, num_critic_obs]
        Returns:
            value: 状态价值 [batch_size, 1]
        """
        # Critic网络前向传播:V(s;φ) = critic_mlp(s)
        value = self.critic(critic_observations)
        return value
```
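
下面用一段独立的小脚本演示上面 ActorCritic 所采用的"均值来自网络、标准差为可学习参数"的对角高斯策略:采样、按维度求和的 log_prob 与熵的计算方式都与上文一致。其中的网络结构与维度均为随意假设,仅用于展示张量形状。

```python
import torch
import torch.nn as nn
from torch.distributions import Normal

batch_size, num_obs, num_actions = 4, 10, 3

# 假设的"Actor":任意一个输出动作均值的 MLP
actor = nn.Sequential(nn.Linear(num_obs, 64), nn.ELU(), nn.Linear(64, num_actions))
# 与状态无关的可学习标准差,初始为 1.0(对应 init_noise_std)
std = nn.Parameter(torch.ones(num_actions))

obs = torch.randn(batch_size, num_obs)
mean = actor(obs)                         # [batch, num_actions]
dist = Normal(mean, mean * 0.0 + std)     # std 广播后与 mean 形状一致

actions = dist.sample()                            # a ~ N(μ(s), σ²)
log_prob = dist.log_prob(actions).sum(dim=-1)      # 各维独立,联合 log 概率 = 各维求和
entropy = dist.entropy().sum(dim=-1)               # 熵同样按动作维度求和

print(actions.shape, log_prob.shape, entropy.shape)   # [4, 3] / [4] / [4]
```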

rollout_storage.py

```python
import torch

class RolloutStorage:
    """
    经验存储池(Rollout Buffer)- PPO算法核心组件
    功能:
    1. 存储多环境并行收集的轨迹数据(观测、动作、奖励、价值等)
    2. 计算GAE(广义优势估计)和折扣回报
    3. 生成打乱的小批次数据用于PPO更新
    """
    class Transition:
        """单步转移数据容器:存储每一步的完整交互信息"""
        def __init__(self):
            self.observations = None          # 策略网络观测 s_t
            self.critic_observations = None   # 价值网络观测(可与策略观测不同)
            self.actions = None               # 执行的动作 a_t
            self.rewards = None               # 获得的奖励 r_t
            self.dones = None                 # 环境结束标志(是否到达终止状态)
            self.values = None                # 价值估计 V(s_t)
            self.actions_log_prob = None      # 动作对数概率 logπ(a_t|s_t)
            self.action_mean = None           # 动作分布均值 μ_t
            self.action_sigma = None          # 动作分布标准差 σ_t
            self.hidden_states = None         # RNN隐藏状态(如果使用循环网络)
        
        def clear(self):
            """重置所有属性为初始状态"""
            self.__init__()

    def __init__(self, 
                 num_envs,                    # 并行环境数量
                 num_transitions_per_env,     # 每个环境存储的转移步数
                 obs_shape,                   # 策略观测维度
                 privileged_obs_shape,        # 价值观测维度(特权观测)
                 actions_shape,               # 动作维度
                 device='cpu'):

        self.device = device  # 数据存储设备(CPU/GPU)

        # 维度记录
        self.obs_shape = obs_shape
        self.privileged_obs_shape = privileged_obs_shape
        self.actions_shape = actions_shape

        # ====================== 核心数据存储张量 ======================
        # 形状说明:[num_transitions_per_env, num_envs, *dim]
        # 第一维:时间步,第二维:环境编号,后续维度:数据本身维度
        self.observations = torch.zeros(
            num_transitions_per_env, num_envs, *obs_shape, device=self.device
        )
        # 特权观测(价值网络专用):可选,如果为None则和策略观测共用
        if privileged_obs_shape[0] is not None:
            self.privileged_observations = torch.zeros(
                num_transitions_per_env, num_envs, *privileged_obs_shape, device=self.device
            )
        else:
            self.privileged_observations = None
        
        self.rewards = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device)  # 奖励 [T, N, 1]
        self.actions = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=self.device)  # 动作 [T, N, A]
        self.dones = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device).byte()  # 结束标志 [T, N, 1]

        # ====================== PPO专用数据 ======================
        self.actions_log_prob = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device)  # 动作对数概率 [T, N, 1]
        self.values = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device)  # 价值估计 [T, N, 1]
        self.returns = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device)  # 折扣回报 U_t [T, N, 1]
        self.advantages = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device)  # 优势函数 A_t [T, N, 1]
        self.mu = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=self.device)  # 动作均值 [T, N, A]
        self.sigma = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=self.device)  # 动作标准差 [T, N, A]

        # 存储池配置参数
        self.num_transitions_per_env = num_transitions_per_env  # 每个环境的最大存储步数
        self.num_envs = num_envs  # 并行环境数

        # RNN相关(循环网络隐藏状态存储)
        self.saved_hidden_states_a = None  # Actor网络隐藏状态
        self.saved_hidden_states_c = None  # Critic网络隐藏状态

        self.step = 0  # 当前存储的步数指针

    def add_transitions(self, transition: Transition):
        """
        添加单步转移数据到存储池
        Args:
            transition: 单步转移数据对象
        """
        # 检查存储池是否溢出
        if self.step >= self.num_transitions_per_env:
            raise AssertionError("Rollout buffer overflow")
        
        # 复制数据到对应位置(使用copy_避免梯度传递)
        self.observations[self.step].copy_(transition.observations)
        if self.privileged_observations is not None: 
            self.privileged_observations[self.step].copy_(transition.critic_observations)
        self.actions[self.step].copy_(transition.actions)
        self.rewards[self.step].copy_(transition.rewards.view(-1, 1))  # 确保形状为[N, 1]
        self.dones[self.step].copy_(transition.dones.view(-1, 1))      # 确保形状为[N, 1]
        self.values[self.step].copy_(transition.values)
        self.actions_log_prob[self.step].copy_(transition.actions_log_prob.view(-1, 1))
        self.mu[self.step].copy_(transition.action_mean)
        self.sigma[self.step].copy_(transition.action_sigma)
        
        # 保存RNN隐藏状态
        self._save_hidden_states(transition.hidden_states)
        
        # 步数指针+1
        self.step += 1

    def _save_hidden_states(self, hidden_states):
        """
        保存RNN(GRU/LSTM)的隐藏状态(内部使用)
        Args:
            hidden_states: 包含Actor和Critic隐藏状态的元组 (hid_a, hid_c)
        """
        if hidden_states is None or hidden_states==(None, None):
            return
        
        # 统一GRU/LSTM格式:GRU只有hidden state,LSTM有(hidden, cell)
        hid_a = hidden_states[0] if isinstance(hidden_states[0], tuple) else (hidden_states[0],)
        hid_c = hidden_states[1] if isinstance(hidden_states[1], tuple) else (hidden_states[1],)

        # 初始化隐藏状态存储张量(首次调用时)
        if self.saved_hidden_states_a is None:
            self.saved_hidden_states_a = [
                torch.zeros(self.observations.shape[0], *hid_a[i].shape, device=self.device) 
                for i in range(len(hid_a))
            ]
            self.saved_hidden_states_c = [
                torch.zeros(self.observations.shape[0], *hid_c[i].shape, device=self.device) 
                for i in range(len(hid_c))
            ]
        
        # 复制当前步的隐藏状态
        for i in range(len(hid_a)):
            self.saved_hidden_states_a[i][self.step].copy_(hid_a[i])
            self.saved_hidden_states_c[i][self.step].copy_(hid_c[i])

    def clear(self):
        """重置存储池:仅重置步数指针,不释放内存(复用张量)"""
        self.step = 0

    def compute_returns(self, last_values, gamma, lam):
        """
        计算折扣回报(Returns)和GAE(广义优势估计)
        核心公式:
        1. 时序差分误差:δ_t = r_t + γ * (1-d_t) * V(s_{t+1}) - V(s_t)
        2. GAE优势函数:A_t^GAE = δ_t + γλ(1-d_t)δ_{t+1} + (γλ)^2(1-d_t)(1-d_{t+1})δ_{t+2} + ...
        3. 折扣回报:U_t = A_t^GAE + V(s_t) = r_t + γ(1-d_t)U_{t+1}
        
        Args:
            last_values: 最后一步的价值估计 V(s_T) [N, 1]
            gamma: 折扣因子 γ
            lam: GAE参数 λ
        """
        advantage = 0  # 反向计算时的优势累积值
        # 从最后一步反向遍历计算
        for step in reversed(range(self.num_transitions_per_env)):
            # 获取下一步的价值估计 V(s_{t+1})
            if step == self.num_transitions_per_env - 1:
                next_values = last_values  # 最后一步使用外部传入的V(s_T)
            else:
                next_values = self.values[step + 1]  # 非最后一步使用存储的V(s_{t+1})
            
            # 1-d_t:如果t步是终止状态,则下一步无价值(乘以0)
            next_is_not_terminal = 1.0 - self.dones[step].float()
            
            # 计算时序差分误差 δ_t
            delta = self.rewards[step] + next_is_not_terminal * gamma * next_values - self.values[step]
            
            # 递归计算GAE优势:A_t = δ_t + γλ(1-d_t)A_{t+1}
            advantage = delta + next_is_not_terminal * gamma * lam * advantage
            
            # 折扣回报 = 优势 + 价值估计
            self.returns[step] = advantage + self.values[step]

        # ====================== 优势函数标准化 ======================
        # 标准化优势可以提升PPO训练稳定性(减少不同批次间的尺度差异)
        self.advantages = self.returns - self.values  # 优势 = 回报 - 价值估计
        # 标准化:(A - μ_A) / (σ_A + ε),ε防止除零
        self.advantages = (self.advantages - self.advantages.mean()) / (self.advantages.std() + 1e-8)

    def get_statistics(self):
        """
        获取轨迹统计信息:平均轨迹长度、平均奖励
        Returns:
            mean_trajectory_length: 平均轨迹长度
            mean_reward: 平均奖励
        """
        done = self.dones
        done[-1] = 1  # 强制最后一步为结束(确保所有轨迹都被计算)
        # 维度变换:[T, N, 1] → [N*T, 1]
        flat_dones = done.permute(1, 0, 2).reshape(-1, 1)
        # 找到所有结束标志的索引(+起始索引-1)
        done_indices = torch.cat((flat_dones.new_tensor([-1], dtype=torch.int64), flat_dones.nonzero(as_tuple=False)[:, 0]))
        # 计算每个轨迹的长度
        trajectory_lengths = (done_indices[1:] - done_indices[:-1])
        
        return trajectory_lengths.float().mean(), self.rewards.mean()

    def mini_batch_generator(self, num_mini_batches, num_epochs=8):
        """
        生成小批次数据迭代器(用于PPO多轮更新)
        核心逻辑:
        1. 将所有数据展平为 [N*T, ...] 形状
        2. 生成随机打乱的索引
        3. 按批次切分索引,生成小批次数据
        
        Args:
            num_mini_batches: 每轮的小批次数量
            num_epochs: 更新轮数
        Yields:
            小批次数据元组:包含观测、动作、优势、回报等
        """
        # 总数据量 = 环境数 × 每个环境的步数
        batch_size = self.num_envs * self.num_transitions_per_env
        # 每个小批次的大小
        mini_batch_size = batch_size // num_mini_batches
        # 生成随机打乱的索引(确保每次epoch数据顺序不同)
        indices = torch.randperm(num_mini_batches*mini_batch_size, requires_grad=False, device=self.device)
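        # 注意:randperm 的长度是 num_mini_batches * mini_batch_size,
        # 若 batch_size 不能被 num_mini_batches 整除,多余的样本在本轮更新中会被直接丢弃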

        # ====================== 数据展平 ======================
        # 从 [T, N, ...] 展平为 [N*T, ...]
        observations = self.observations.flatten(0, 1)
        # 价值观测:使用特权观测或策略观测
        if self.privileged_observations is not None:
            critic_observations = self.privileged_observations.flatten(0, 1)
        else:
            critic_observations = observations

        actions = self.actions.flatten(0, 1)
        values = self.values.flatten(0, 1)
        returns = self.returns.flatten(0, 1)
        old_actions_log_prob = self.actions_log_prob.flatten(0, 1)
        advantages = self.advantages.flatten(0, 1)
        old_mu = self.mu.flatten(0, 1)
        old_sigma = self.sigma.flatten(0, 1)

        # 多轮更新
        for epoch in range(num_epochs):
            # 生成每一轮的小批次
            for i in range(num_mini_batches):
                # 计算当前批次的索引范围
                start = i * mini_batch_size
                end = (i + 1) * mini_batch_size
                batch_idx = indices[start:end]

                # 根据索引获取小批次数据
                obs_batch = observations[batch_idx]
                critic_observations_batch = critic_observations[batch_idx]
                actions_batch = actions[batch_idx]
                target_values_batch = values[batch_idx]
                returns_batch = returns[batch_idx]
                old_actions_log_prob_batch = old_actions_log_prob[batch_idx]
                advantages_batch = advantages[batch_idx]
                old_mu_batch = old_mu[batch_idx]
                old_sigma_batch = old_sigma[batch_idx]
                
                # 生成小批次数据(RNN隐藏状态暂时返回None)
                yield (obs_batch, critic_observations_batch, actions_batch, target_values_batch, 
                       advantages_batch, returns_batch, old_actions_log_prob_batch, 
                       old_mu_batch, old_sigma_batch, (None, None), None)
```
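
最后用一段独立的小脚本,在一条长度为 3 的玩具轨迹上手工执行与 compute_returns 相同的反向递推,并检查 returns = advantages + values(这里不做标准化,便于对照)。其中的奖励与价值数值均为随意假设。

```python
import torch

gamma, lam = 0.99, 0.95

# 玩具数据:T=3 步、1 个环境,全程无终止(dones 全 0)
rewards = torch.tensor([1.0, 0.0, 2.0])
values = torch.tensor([0.5, 0.4, 0.3])   # V(s_0), V(s_1), V(s_2)
last_value = 0.2                         # V(s_3),由 critic 对最后一帧观测估计得到
dones = torch.tensor([0.0, 0.0, 0.0])

# 反向递推(与 RolloutStorage.compute_returns 相同的公式)
advantages = torch.zeros(3)
returns = torch.zeros(3)
adv = 0.0
for t in reversed(range(3)):
    next_value = last_value if t == 2 else values[t + 1]
    not_terminal = 1.0 - dones[t]
    delta = rewards[t] + not_terminal * gamma * next_value - values[t]   # TD 误差 δ_t
    adv = delta + not_terminal * gamma * lam * adv                       # A_t = δ_t + γλ A_{t+1}
    advantages[t] = adv
    returns[t] = adv + values[t]                                         # U_t = A_t + V(s_t)

print("advantages:", advantages)
print("returns   :", returns)
# 校验:returns - values 应当恰好等于 advantages
print(torch.allclose(returns - values, advantages))
```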