ppo爬坡代码及解释

代码来源参考:PPO-小车爬坡-迭代500次收敛

https://blog.csdn.net/qq_42670001/article/details/135025811

PPO(Proximal Policy Optimization,近端策略优化)是强化学习中一种高效且稳定的算法,属于actor-critic框架。它通过限制新策略与旧策略的差异("近端"约束)解决了传统策略梯度方法中更新步长难以控制的问题。下面结合你提供的代码,从核心组件、算法流程和关键细节三个方面讲解PPO。

ppo.py代码

csharp 复制代码
import gym
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import rl_utils


class PolicyNet(torch.nn.Module):
    def __init__(self, state_dim, hidden_dim, action_dim):
        super(PolicyNet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return F.softmax(self.fc2(x), dim=1)


class ValueNet(torch.nn.Module):
    def __init__(self, state_dim, hidden_dim):
        super(ValueNet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, 1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return self.fc2(x)



class PPO:
    ''' PPO算法,采用截断方式 '''
    def __init__(self, state_dim, hidden_dim, action_dim, actor_lr, critic_lr,
                 lmbda, epochs, eps, gamma, device):
        self.actor = PolicyNet(state_dim, hidden_dim, action_dim).to(device)
        self.critic = ValueNet(state_dim, hidden_dim).to(device)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(),
                                                lr=actor_lr)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(),
                                                 lr=critic_lr)
        self.gamma = gamma
        self.lmbda = lmbda
        self.epochs = epochs  # 一条序列的数据用来训练轮数
        self.eps = eps  # PPO中截断范围的参数
        self.device = device

    def take_action(self, state):
        state = torch.tensor([state], dtype=torch.float).to(self.device)
        probs = self.actor(state)
        action_dist = torch.distributions.Categorical(probs)
        action = action_dist.sample()
        return action.item()

    def update(self, transition_dict):
        states = torch.tensor(transition_dict['states'],
                              dtype=torch.float).to(self.device)
        actions = torch.tensor(transition_dict['actions']).view(-1, 1).to(
            self.device)
        rewards = torch.tensor(transition_dict['rewards'],
                               dtype=torch.float).view(-1, 1).to(self.device)
        next_states = torch.tensor(transition_dict['next_states'],
                                   dtype=torch.float).to(self.device)
        dones = torch.tensor(transition_dict['dones'],
                             dtype=torch.float).view(-1, 1).to(self.device)
        td_target = rewards + self.gamma * self.critic(next_states) * (1 -
                                                                       dones)
        td_delta = td_target - self.critic(states)
        advantage = rl_utils.compute_advantage(self.gamma, self.lmbda,
                                               td_delta.cpu()).to(self.device)
        old_log_probs = torch.log(self.actor(states).gather(1,
                                                            actions)).detach()

        for _ in range(self.epochs):
            log_probs = torch.log(self.actor(states).gather(1, actions))
            ratio = torch.exp(log_probs - old_log_probs)
            surr1 = ratio * advantage
            surr2 = torch.clamp(ratio, 1 - self.eps,
                                1 + self.eps) * advantage  # 截断
            actor_loss = torch.mean(-torch.min(surr1, surr2))  # PPO损失函数
            critic_loss = torch.mean(
                F.mse_loss(self.critic(states), td_target.detach()))
            self.actor_optimizer.zero_grad()
            self.critic_optimizer.zero_grad()
            actor_loss.backward()
            critic_loss.backward()
            self.actor_optimizer.step()
            self.critic_optimizer.step()

if __name__ == "__main__":

    video_path = '/video'

    actor_lr = 1e-3
    critic_lr = 1e-4
    num_episodes = 1000
    hidden_dim = 128
    gamma = 0.98
    lmbda = 0.95
    epochs = 10
    eps = 0.2
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device(
        "cpu")

    env_name = 'MountainCar-v0'
    env = gym.make(env_name)
    # env.seed(0)
    torch.manual_seed(0)
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n
    agent = PPO(state_dim, hidden_dim, action_dim, actor_lr, critic_lr, lmbda,
                epochs, eps, gamma, device)
    return_list = []
    stateRecord = []
    for i_episode in range(num_episodes):

        episode_return = 0
        transition_dict = {'states': [], 'actions': [], 'next_states': [], 'rewards': [], 'dones': []}
        state = env.reset()

        done = False
        step = 0

        while not done:

            step += 1
            action = agent.take_action(state)
            next_state, reward, done, _ = env.step(action)
            x = next_state[0]
            v = next_state[1] / 0.07
            if action - 1 < 0 and next_state[1] < 0:
                reward = 1
            elif action - 1 > 0 and next_state[1] > 0:
                reward = 1
            else:
                reward = -1
            if done:
                if x >= 0.5:
                    # env.render()
                    stateRecord.append(1)
                    reward = 200 - step
                else:
                    stateRecord.append(0)
                    # reward = -100

            if i_episode % 1 == 0:
                env.render()

            transition_dict['states'].append(state)
            transition_dict['actions'].append(action)
            transition_dict['next_states'].append(next_state)

            transition_dict['rewards'].append(reward)
            transition_dict['dones'].append(done)
            state = next_state
            episode_return += reward

        return_list.append(episode_return)
        agent.update(transition_dict)

        print({'episode': '%d' % (i_episode + 1), 'return': '%.3f' %sum(return_list[-1:])}, step, "winRate: " , sum(stateRecord[-50:]) / 50)


    episodes_list = list(range(len(return_list)))
    plt.plot(episodes_list, return_list)
    plt.xlabel('Episodes')
    plt.ylabel('Returns')
    plt.title('PPO on {}'.format(env_name))
    plt.show()

    mv_return = rl_utils.moving_average(return_list, 9)
    plt.plot(episodes_list, mv_return)
    plt.xlabel('Episodes')
    plt.ylabel('Returns')
    plt.title('PPO on {}'.format(env_name))
    plt.show()

一、核心组件解析

代码中主要包含3个核心部分:策略网络(Actor)、价值网络(Critic)和PPO算法主体。

1. 策略网络(PolicyNet)------ Actor

策略网络的作用是根据当前状态输出动作的概率分布,用于指导智能体选择动作。

python 复制代码
class PolicyNet(torch.nn.Module):
    def __init__(self, state_dim, hidden_dim, action_dim):
        super(PolicyNet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)  # 输入层到隐藏层
        self.fc2 = torch.nn.Linear(hidden_dim, action_dim) # 隐藏层到输出层

    def forward(self, x):
        x = F.relu(self.fc1(x))  # 激活函数引入非线性
        return F.softmax(self.fc2(x), dim=1)  # 输出动作的概率分布(和为1)
  • 输入:状态(state_dim维度,如MountainCar的位置和速度,维度为2)。
  • 输出:每个动作的概率(action_dim维度,如MountainCar的3个动作:左移、不动、右移)。
  • 关键:通过softmax确保输出是合法的概率分布,后续可通过采样(如Categorical)选择动作。
2. 价值网络(ValueNet)------ Critic

价值网络的作用是估计当前状态的价值(即从当前状态出发的期望累积奖励),用于计算"优势函数"(衡量动作的好坏)。

python 复制代码
class ValueNet(torch.nn.Module):
    def __init__(self, state_dim, hidden_dim):
        super(ValueNet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, 1)  # 输出标量(状态价值)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return self.fc2(x)  # 输出状态s的价值V(s)
  • 输入:状态(同策略网络)。
  • 输出:标量V(s),表示状态s的价值估计。
3. PPO算法主体

PPO类封装了算法的核心逻辑:动作选择、数据收集和参数更新。

二、PPO核心流程(结合代码)

PPO的核心流程可分为数据收集参数更新两大步,下面结合代码细节说明。

1. 数据收集(与环境交互)

在每一轮episode中,智能体通过策略网络选择动作,与环境交互并记录轨迹数据(状态、动作、奖励等)。

python 复制代码
# 主循环中收集数据
transition_dict = {'states': [], 'actions': [], 'next_states': [], 'rewards': [], 'dones': []}
state = env.reset()
done = False
while not done:
    action = agent.take_action(state)  # 用策略网络选动作
    next_state, reward, done, _ = env.step(action)
    # 记录轨迹数据
    transition_dict['states'].append(state)
    transition_dict['actions'].append(action)
    ...  # 记录next_state、reward、done
    state = next_state
  • take_action方法:通过策略网络输出的概率分布采样动作(确保探索性)。

    python 复制代码
    def take_action(self, state):
        state = torch.tensor([state], dtype=torch.float).to(self.device)
        probs = self.actor(state)  # 得到动作概率分布
        action_dist = torch.distributions.Categorical(probs)  # 构建分类分布
        action = action_dist.sample()  # 从分布中采样动作
        return action.item()
2. 参数更新(核心!PPO的"近端"约束体现)

收集完一条轨迹后,PPO通过多轮迭代更新策略网络和价值网络,核心是限制新策略与旧策略的差异

步骤1:计算TD目标和优势函数
  • TD目标(td_target) :用于更新价值网络,定义为r + γ·V(s')(即时奖励+折扣后下一状态价值)。

  • 优势函数(advantage) :衡量动作的"额外价值"(即该动作比平均水平好多少),公式为A(s,a) = Q(s,a) - V(s),代码中用GAE(广义优势估计)计算以平衡偏差和方差。

    python 复制代码
    # 计算TD目标
    td_target = rewards + self.gamma * self.critic(next_states) * (1 - dones)
    # 计算TD误差(td_delta = td_target - V(s))
    td_delta = td_target - self.critic(states)
    # 用GAE计算优势函数(结合gamma和lmbda平滑)
    advantage = rl_utils.compute_advantage(self.gamma, self.lmbda, td_delta.cpu()).to(self.device)
步骤2:固定旧策略的概率(关键)

更新前先记录旧策略(未更新的策略)对当前轨迹中动作的概率,用于后续限制新策略的偏差。

python 复制代码
# 旧策略的对数概率(detach()固定,不参与梯度更新)
old_log_probs = torch.log(self.actor(states).gather(1, actions)).detach()
步骤3:多轮迭代更新(epochs)

PPO的一大特点是用同一批数据训练多轮(提高数据利用率),但每轮都通过"截断"限制策略更新幅度。

python 复制代码
for _ in range(self.epochs):  # 同一批数据训练epochs轮
    # 新策略的对数概率
    log_probs = torch.log(self.actor(states).gather(1, actions))
    # 计算新旧策略的概率比(ratio = 新策略概率 / 旧策略概率)
    ratio = torch.exp(log_probs - old_log_probs)  # 等价于exp(log(new) - log(old)) = new/old
    
    # 计算两个替代损失(PPO的核心!)
    surr1 = ratio * advantage  # 未截断的损失
    surr2 = torch.clamp(ratio, 1 - self.eps, 1 + self.eps) * advantage  # 截断的损失(限制ratio在[1-eps, 1+eps])
    
    # 策略损失:取两个损失的最小值(确保更新幅度不超过eps)
    actor_loss = torch.mean(-torch.min(surr1, surr2))
    # 价值损失:均方误差(让V(s)逼近td_target)
    critic_loss = torch.mean(F.mse_loss(self.critic(states), td_target.detach()))
    
    # 反向传播更新参数
    self.actor_optimizer.zero_grad()
    self.critic_optimizer.zero_grad()
    actor_loss.backward()
    critic_loss.backward()
    self.actor_optimizer.step()
    self.critic_optimizer.step()

为什么要截断?

如果新策略与旧策略差异过大(ratio超出[1-eps, 1+eps]),之前估计的优势函数可能不再准确(因为优势基于旧策略计算),此时强制用边界值限制更新,避免策略"跑偏",保证训练稳定。

三、关键超参数说明

代码中的超参数直接影响PPO的性能,需重点理解:

  • gamma:折扣因子(未来奖励的衰减系数,通常0.9~0.99)。
  • lmbda:GAE中的平滑系数(平衡优势估计的偏差和方差,通常0.9~0.95)。
  • epochs:同一批数据的训练轮数(通常5~20,太少利用率低,太多易过拟合)。
  • eps:截断参数(通常0.2,控制新旧策略的最大差异)。
  • actor_lr/critic_lr:策略/价值网络的学习率(策略通常略大于价值网络)。

四、总结

PPO通过以下设计实现高效稳定的强化学习:

  1. actor-critic框架:策略网络(选动作)和价值网络(评价值)协同工作。
  2. 截断损失 :限制新策略与旧策略的差异(ratio[1-eps, 1+eps]),避免更新不稳定。
  3. GAE优势估计:更准确地衡量动作价值,平衡偏差和方差。
  4. 多轮迭代:同一批数据训练多轮,提高数据利用率。

代码完整实现了PPO的核心逻辑,并在MountainCar环境中验证,通过自定义奖励函数(如根据动作与速度方向是否一致调整奖励)加速了训练过程。

rl_utils.py代码

csharp 复制代码
from tqdm import tqdm
import numpy as np
import torch
import collections
import random


class ReplayBuffer:
    def __init__(self, capacity):
        self.buffer = collections.deque(maxlen=capacity)

    def add(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        transitions = random.sample(self.buffer, batch_size)
        state, action, reward, next_state, done = zip(*transitions)
        return np.array(state), action, reward, np.array(next_state), done

    def size(self):
        return len(self.buffer)


def moving_average(a, window_size):
    cumulative_sum = np.cumsum(np.insert(a, 0, 0))
    middle = (cumulative_sum[window_size:] - cumulative_sum[:-window_size]) / window_size
    r = np.arange(1, window_size - 1, 2)
    begin = np.cumsum(a[:window_size - 1])[::2] / r
    end = (np.cumsum(a[:-window_size:-1])[::2] / r)[::-1]
    return np.concatenate((begin, middle, end))


def train_on_policy_agent(env, agent, num_episodes):
    return_list = []
    for i in range(10):
        with tqdm(total=int(num_episodes / 10), desc='Iteration %d' % i) as pbar:
            for i_episode in range(int(num_episodes / 10)):
                episode_return = 0
                transition_dict = {'states': [], 'actions': [], 'next_states': [], 'rewards': [], 'dones': []}
                state = env.reset()
                env.render()
                done = False
                while not done:
                    action = agent.take_action(state)
                    next_state, reward, done, _ = env.step(action)
                    transition_dict['states'].append(state)
                    transition_dict['actions'].append(action)
                    transition_dict['next_states'].append(next_state)
                    transition_dict['rewards'].append(reward)
                    transition_dict['dones'].append(done)
                    state = next_state
                    episode_return += reward
                return_list.append(episode_return)
                agent.update(transition_dict)
                if (i_episode + 1) % 10 == 0:
                    pbar.set_postfix({'episode': '%d' % (num_episodes / 10 * i + i_episode + 1),
                                      'return': '%.3f' % np.mean(return_list[-10:])})
                pbar.update(1)
    return return_list


def train_off_policy_agent(env, agent, num_episodes, replay_buffer, minimal_size, batch_size):
    return_list = []
    for i in range(10):
        with tqdm(total=int(num_episodes / 10), desc='Iteration %d' % i) as pbar:
            for i_episode in range(int(num_episodes / 10)):
                episode_return = 0
                state = env.reset()
                done = False
                while not done:
                    action = agent.take_action(state)
                    next_state, reward, done, _ = env.step(action)
                    replay_buffer.add(state, action, reward, next_state, done)
                    state = next_state
                    episode_return += reward
                    if replay_buffer.size() > minimal_size:
                        b_s, b_a, b_r, b_ns, b_d = replay_buffer.sample(batch_size)
                        transition_dict = {'states': b_s, 'actions': b_a, 'next_states': b_ns, 'rewards': b_r,
                                           'dones': b_d}
                        agent.update(transition_dict)
                return_list.append(episode_return)
                if (i_episode + 1) % 10 == 0:
                    pbar.set_postfix({'episode': '%d' % (num_episodes / 10 * i + i_episode + 1),
                                      'return': '%.3f' % np.mean(return_list[-10:])})
                pbar.update(1)
    return return_list


def compute_advantage(gamma, lmbda, td_delta):
    td_delta = td_delta.detach().numpy()
    advantage_list = []
    advantage = 0.0
    for delta in td_delta[::-1]:
        advantage = gamma * lmbda * advantage + delta
        advantage_list.append(advantage)
    advantage_list.reverse()
    return torch.tensor(advantage_list, dtype=torch.float)

这段代码实现了强化学习中常用的工具类和辅助函数,主要用于支持在线策略(On-Policy)离线策略(Off-Policy) 算法的训练,以及数据处理和性能评估。下面逐一解析各部分的功能和作用:

一、ReplayBuffer 经验回放池

经验回放池是离线策略(Off-Policy) 算法(如DQN、DDPG)的核心组件,用于存储智能体与环境交互产生的"经验"(状态、动作、奖励等),并通过随机采样打破样本间的相关性,提升训练稳定性。

python 复制代码
class ReplayBuffer:
    def __init__(self, capacity):
        # 用双向队列实现固定容量的缓冲区,超出容量时自动删除旧数据
        self.buffer = collections.deque(maxlen=capacity)

    def add(self, state, action, reward, next_state, done):
        # 向缓冲区添加一条经验 (s, a, r, s', done)
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        # 从缓冲区随机采样batch_size条经验,并按类型拆分返回
        transitions = random.sample(self.buffer, batch_size)
        # 拆分状态、动作、奖励、下一状态、终止标志
        state, action, reward, next_state, done = zip(*transitions)
        return np.array(state), action, reward, np.array(next_state), done

    def size(self):
        # 返回当前缓冲区的经验数量
        return len(self.buffer)

核心作用

  • 存储历史经验,解决离线策略中"当前策略"与"收集数据的策略"不一致的问题(例如DQN中用ε-贪心策略收集数据,用目标网络更新)。
  • 随机采样避免样本序列相关性(如连续状态高度相似)导致的训练波动,让梯度更新更稳定。

二、moving_average 移动平均函数

用于平滑奖励曲线,消除训练过程中的随机波动,更清晰地展示算法性能的趋势。

python 复制代码
def moving_average(a, window_size):
    # 计算累积和(在开头插入0,方便后续差分)
    cumulative_sum = np.cumsum(np.insert(a, 0, 0))
    # 中间部分:窗口内的平均值(核心平滑逻辑)
    middle = (cumulative_sum[window_size:] - cumulative_sum[:-window_size]) / window_size
    # 处理开头部分(窗口小于window_size时的平均)
    r = np.arange(1, window_size - 1, 2)
    begin = np.cumsum(a[:window_size - 1])[::2] / r
    # 处理结尾部分(对称于开头)
    end = (np.cumsum(a[:-window_size:-1])[::2] / r)[::-1]
    # 拼接开头、中间、结尾,得到完整的平滑曲线
    return np.concatenate((begin, middle, end))

举例

如果原始奖励序列是 [1,3,5,7,9],窗口大小为3,移动平均后会得到更平滑的序列(如 [1,3,5,7,9] → 中间部分为 [(1+3+5)/3, (3+5+7)/3, (5+7+9)/3])。
作用:在可视化训练曲线时,减少噪声干扰,更直观地判断算法是否收敛。

三、train_on_policy_agent 在线策略训练流程

在线策略算法(如PPO、REINFORCE)要求收集数据的策略与当前训练的策略完全一致,因此每轮交互的数据仅用于一次更新,之后丢弃。该函数封装了这类算法的通用训练逻辑。

python 复制代码
def train_on_policy_agent(env, agent, num_episodes):
    return_list = []  # 记录每轮的总奖励
    # 分10个阶段显示进度
    for i in range(10):
        with tqdm(total=int(num_episodes / 10), desc='Iteration %d' % i) as pbar:
            for i_episode in range(int(num_episodes / 10)):
                episode_return = 0  # 本轮的总奖励
                # 存储本轮的轨迹数据 (s, a, s', r, done)
                transition_dict = {'states': [], 'actions': [], 'next_states': [], 'rewards': [], 'dones': []}
                state = env.reset()  # 重置环境,获取初始状态
                env.render()  # 可视化环境(可选)
                done = False
                while not done:
                    action = agent.take_action(state)  # 用当前策略选动作
                    next_state, reward, done, _ = env.step(action)  # 与环境交互
                    # 记录轨迹数据
                    transition_dict['states'].append(state)
                    transition_dict['actions'].append(action)
                    transition_dict['next_states'].append(next_state)
                    transition_dict['rewards'].append(reward)
                    transition_dict['dones'].append(done)
                    state = next_state
                    episode_return += reward  # 累加奖励
                return_list.append(episode_return)  # 记录本轮总奖励
                agent.update(transition_dict)  # 用本轮轨迹更新智能体(核心:在线策略每轮更新一次)
                # 每10轮显示一次平均奖励
                if (i_episode + 1) % 10 == 0:
                    pbar.set_postfix({'episode': '%d' % (num_episodes / 10 * i + i_episode + 1),
                                      'return': '%.3f' % np.mean(return_list[-10:])})
                pbar.update(1)
    return return_list

核心特点

  • 每轮交互结束后,立即用本轮收集的轨迹数据(transition_dict)更新智能体(agent.update)。
  • 数据仅用一次(因为下一轮的策略已更新,与收集数据的策略不同),符合在线策略"策略与数据同步"的要求。
  • 典型应用:PPO、REINFORCE等算法。

四、train_off_policy_agent 离线策略训练流程

离线策略算法(如DQN、DDPG)允许收集数据的策略(行为策略)与训练的策略(目标策略)不同,因此数据可以重复使用(通过经验回放池)。该函数封装了这类算法的通用训练逻辑。

python 复制代码
def train_off_policy_agent(env, agent, num_episodes, replay_buffer, minimal_size, batch_size):
    return_list = []  # 记录每轮的总奖励
    for i in range(10):
        with tqdm(total=int(num_episodes / 10), desc='Iteration %d' % i) as pbar:
            for i_episode in range(int(num_episodes / 10)):
                episode_return = 0
                state = env.reset()
                done = False
                while not done:
                    action = agent.take_action(state)  # 用行为策略(如ε-贪心)选动作
                    next_state, reward, done, _ = env.step(action)
                    # 将经验添加到回放池
                    replay_buffer.add(state, action, reward, next_state, done)
                    state = next_state
                    episode_return += reward
                    # 当回放池数据量超过最小阈值后,开始采样训练
                    if replay_buffer.size() > minimal_size:
                        # 从回放池随机采样一批数据
                        b_s, b_a, b_r, b_ns, b_d = replay_buffer.sample(batch_size)
                        transition_dict = {'states': b_s, 'actions': b_a, 'next_states': b_ns, 'rewards': b_r,
                                           'dones': b_d}
                        agent.update(transition_dict)  # 用采样的数据更新智能体
                return_list.append(episode_return)
                # 每10轮显示一次平均奖励
                if (i_episode + 1) % 10 == 0:
                    pbar.set_postfix({'episode': '%d' % (num_episodes / 10 * i + i_episode + 1),
                                      'return': '%.3f' % np.mean(return_list[-10:])})
                pbar.update(1)
    return return_list

核心特点

  • 经验先存入replay_buffer,只有当池内数据足够(> minimal_size)时才开始采样更新。
  • 同一批数据可被多次采样使用(与在线策略的"一次性使用"不同),提高数据利用率。
  • 行为策略(如DQN中的ε-贪心)与目标策略(更新的Q网络)分离,符合离线策略的特性。
  • 典型应用:DQN、DDPG、SAC等算法。

五、compute_advantage 优势函数计算(GAE)

优势函数 A(s,a) = Q(s,a) - V(s) 衡量"动作a相对于当前状态s的平均价值的优劣",是策略梯度方法(如PPO)的核心组件。该函数用GAE(Generalized Advantage Estimation,广义优势估计) 计算优势,平衡偏差和方差。

python 复制代码
def compute_advantage(gamma, lmbda, td_delta):
    td_delta = td_delta.detach().numpy()  # TD误差:δ_t = r_t + γV(s_{t+1}) - V(s_t)
    advantage_list = []
    advantage = 0.0  # 初始化优势
    # 从后往前计算(因为优势依赖未来的TD误差)
    for delta in td_delta[::-1]:
        # GAE公式:A_t = δ_t + γλA_{t+1}
        advantage = gamma * lmbda * advantage + delta
        advantage_list.append(advantage)
    advantage_list.reverse()  # 反转,恢复时间顺序
    return torch.tensor(advantage_list, dtype=torch.float)

为什么用GAE?

  • 直接用TD误差 δ_t 估计优势(A_t ≈ δ_t)方差小但偏差大;
  • 用累积奖励估计(A_t = sum(γ^k δ_{t+k}))偏差小但方差大;
  • GAE通过 λ 平衡两者(λ=0 等价于TD误差,λ=1 等价于累积奖励),让优势估计更稳定。

总结

这段代码是强化学习算法的"基础设施",总结如下:

  • ReplayBuffer:支撑离线策略算法的数据存储与采样。
  • moving_average:平滑奖励曲线,辅助训练效果可视化。
  • train_on_policy_agent/train_off_policy_agent:分别封装在线/离线策略的通用训练流程,简化算法实现。
  • compute_advantage:用GAE计算优势函数,为策略梯度算法提供核心的"动作好坏"度量。

这些工具可以直接复用在各类强化学习算法中,例如之前的PPO(在线策略)可搭配train_on_policy_agentcompute_advantage,DQN(离线策略)可搭配ReplayBuffertrain_off_policy_agent

相关推荐
OpenBayes2 小时前
教程上新丨Deepseek-OCR 以极少视觉 token 数在端到端模型中实现 SOTA
人工智能·深度学习·机器学习·ocr·大语言模型·文本处理·deepseek
蓝海星梦2 小时前
【论文笔记】R-HORIZON:重塑长周期推理评估与训练范式
论文阅读·人工智能·深度学习·自然语言处理·大型推理模型
da_vinci_x2 小时前
Substance 3D 材质流:AI 快速生成与程序化精修
人工智能·游戏·3d·材质·设计师·技术美术·游戏美术
aneasystone本尊2 小时前
重温 Java 21 之密钥封装机制 API
人工智能
欢聚赢销CRM3 小时前
从“各自为战“到“数据协同“:销采一体化CRM正在重构供应链竞争力
大数据·人工智能·重构·数据分析
IT_陈寒3 小时前
Python 3.12 新特性实战:10个让你代码更优雅的隐藏技巧
前端·人工智能·后端
说私域3 小时前
基于开源AI智能名片链动2+1模式与S2B2C商城小程序的商家活动策略研究
人工智能·小程序
亚马逊云开发者3 小时前
Agentic AI基础设施实践经验系列(八):Agent应用的隐私和安全
人工智能
大模型真好玩3 小时前
GPT-5.1 核心特性深度解析,它会是模型性能的新标杆吗?
人工智能