📢本篇文章是博主强化学习(RL)领域学习时,用于个人学习、研究或者欣赏使用,并基于博主对相关等领域的一些理解而记录的学习摘录和笔记,若有不当和侵权之处,指出后将会立即改正,还望谅解。文章分类在👉强化学习专栏:
【强化学习】- 【单智能体强化学习】(6)---《策略梯度---REINFORCE算法》
策略梯度---REINFORCE算法
目录
[1.REINFORCE 算法](#1.REINFORCE 算法)
[4.REINFORCE 算法流程](#4.REINFORCE 算法流程)
[[Notice] 注意事项](#[Notice] 注意事项)
[7.Policy Gradient算法和REINFORCE算法的对比](#7.Policy Gradient算法和REINFORCE算法的对比)
[8.REINFORCE 的优点和缺点](#8.REINFORCE 的优点和缺点)
1.REINFORCE 算法
REINFORCE 是一种策略梯度算法,用于强化学习中的策略优化问题。它的核心思想是直接优化策略,通过采样环境中的轨迹来估计梯度并更新策略。
2.基本概念
2.1 策略 (Policy)
策略 :表示在状态 下选择动作的概率,其中是策略的参数。
2.2 目标
最大化策略的期望累计奖励,即:
通过调整参数来提升。
2.3 策略梯度
REINFORCE 是一种基于梯度的方法,通过梯度上升优化 。
3.算法的关键思想
3.1 梯度公式
利用强化学习的公式推导出梯度:
是从状态出发后的累计奖励,作为对策略 好坏的衡量。
3.2 梯度估计
使用蒙特卡洛采样的方法,从环境中生成轨迹,估计梯度:
其中 是采样轨迹的数量。
4.REINFORCE 算法流程
初始化策略参数 (通常是神经网络的权重)。
重复以下步骤,直到收敛:
采样轨迹 :从环境中采样多条轨迹,基于当前策略 。
计算奖励 :计算每条轨迹的累积奖励 。
策略更新 :,其中 是学习率。
5.公式推导
策略梯度推导
强化学习的目标是最大化期望奖励:
使用采样分布的导数性质,得到:
蒙特卡洛估计
直接使用轨迹 ( \tau ) 中的采样点替代期望,得到无偏估计:
梯度更新
根据上述梯度进行更新,即完成策略的优化。
6.算法的改进点
基线函数 (Baseline)
为了减少累积奖励 ( R ) 的方差,引入基线函数 ( b(s) ),即:
常用的基线是状态值函数
。
优势函数
将基线函数扩展为优势函数,即:
减少方差
使用 Actor-Critic 算法,用一个 Critic 网络来估计 ,从而进一步降低方差。
[Python]REINFORCE算法实现
项目代码我已经放入GitCode里面,可以通过下面链接跳转:🔥
后续相关单智能体强化学习算法也会不断在**【强化学习】**项目里更新,如果该项目对你有所帮助,请帮我点一个星星 ✨✨✨✨✨,鼓励分享,十分感谢!!!
若是下面代码复现困难或者有问题,也欢迎评论区留言。
python
"""《REINFORCE算法项目》
时间:2024.012
作者:不去幼儿园
"""
import argparse # 用于处理命令行参数
import gym # 引入OpenAI Gym库,提供强化学习环境
import numpy as np # 用于数值计算
from itertools import count # 用于创建无限计数器
import torch # 引入PyTorch库,用于深度学习
import torch.nn as nn # 引入神经网络模块
import torch.nn.functional as F # 引入常用的神经网络函数,如激活函数
import torch.optim as optim # 引入优化器模块
from torch.distributions import Categorical # 引入分类分布,用于策略采样
定义策略网络
python
# 定义策略网络
class Policy(nn.Module):
def __init__(self):
super(Policy, self).__init__()
self.affine1 = nn.Linear(4, 128) # 输入层:状态维度为4,隐层维度为128
self.affine2 = nn.Linear(128, 2) # 输出层:动作维度为2(左右移动)
self.saved_log_probs = [] # 保存动作对应的log概率,用于后续梯度计算
self.rewards = [] # 保存回合奖励
def forward(self, x):
x = F.relu(self.affine1(x)) # 隐层使用ReLU激活函数
action_scores = self.affine2(x) # 输出动作得分
return F.softmax(action_scores, dim=1) # 使用Softmax将动作得分转换为概率分布
更新策略
python
# 完成一个回合并更新策略
def finish_episode():
R = 0 # 初始化累计折扣奖励
policy_loss = [] # 用于保存策略损失
rewards = [] # 保存折扣奖励
# 计算折扣奖励
for r in policy.rewards[::-1]: # 倒序遍历每一步的奖励
R = r + args.gamma * R # 计算累计奖励
rewards.insert(0, R) # 将累计奖励插入到列表开头
rewards = torch.tensor(rewards) # 将奖励转换为张量
rewards = (rewards - rewards.mean()) / (rewards.std() + eps) # 标准化奖励
# 计算策略损失
for log_prob, reward in zip(policy.saved_log_probs, rewards):
policy_loss.append(-log_prob * reward) # 损失是负的log概率乘以奖励
optimizer.zero_grad() # 清零梯度
policy_loss = torch.cat(policy_loss).sum() # 合并所有损失并求和
policy_loss.backward() # 反向传播计算梯度
optimizer.step() # 使用优化器更新网络参数
del policy.rewards[:] # 清空回合奖励
del policy.saved_log_probs[:] # 清空log概率
主函数
python
# 主函数
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='PyTorch REINFORCE example') # 创建命令行参数解析器
parser.add_argument('--gamma', type=float, default=0.99, metavar='G',
help='discount factor (default: 0.99)') # 折扣因子
parser.add_argument('--seed', type=int, default=543, metavar='N',
help='random seed (default: 543)') # 随机种子
parser.add_argument('--render', action='store_true',
help='render the environment') # 是否渲染环境
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
help='interval between training status logs (default: 10)') # 日志间隔
args = parser.parse_args() # 解析命令行参数
policy = Policy() # 创建策略网络
optimizer = optim.Adam(policy.parameters(), lr=1e-2) # 定义优化器,使用Adam,学习率为0.01
eps = np.finfo(np.float32).eps.item() # 防止浮点数精度问题的小数值
env = gym.make('CartPole-v1') # 创建CartPole-v1环境
torch.manual_seed(args.seed) # 设置随机种子
running_reward = 10 # 初始化运行奖励
for i_episode in count(1): # 无限循环,直到达到停止条件
state, _ = env.reset() # 重置环境,获取初始状态
for t in range(10000): # 限制最大时间步,防止无限循环
action = select_action(state) # 选择动作
state, reward, done, _, _ = env.step(action) # 执行动作,返回下一状态和奖励
if args.render: # 如果启用了渲染
env.render() # 渲染环境
policy.rewards.append(reward) # 保存奖励
if done: # 如果当前回合结束
break # 跳出循环
running_reward = running_reward * 0.99 + t * 0.01 # 更新运行奖励
finish_episode() # 完成回合并更新策略
if i_episode % args.log_interval == 0: # 每隔一定回合打印日志
print('Episode {}\tLast length: {:5d}\tAverage length: {:.2f}'.format(
i_episode, t, running_reward))
if running_reward > env.spec.reward_threshold: # 如果运行奖励超过环境要求的阈值
print("Solved! Running reward is now {} and "
"the last episode runs to {} time steps!".format(running_reward, t))
break # 训练结束
定义策略网络
python
# 定义策略网络
class Policy(nn.Module):
def __init__(self):
super(Policy, self).__init__()
self.affine1 = nn.Linear(4, 128) # 输入层:状态维度为4,隐层维度为128
self.affine2 = nn.Linear(128, 2) # 输出层:动作维度为2(左右移动)
self.saved_log_probs = [] # 保存动作对应的log概率,用于后续梯度计算
self.rewards = [] # 保存回合奖励
def forward(self, x):
x = F.relu(self.affine1(x)) # 隐层使用ReLU激活函数
action_scores = self.affine2(x) # 输出动作得分
return F.softmax(action_scores, dim=1) # 使用Softmax将动作得分转换为概率分布
[Notice] 注意事项
策略网络 :Policy
类定义了一个简单的全连接神经网络,用于预测动作的概率分布。
REINFORCE算法:
使用策略梯度方法直接优化策略。
每一回合计算折扣奖励,并基于梯度上升更新策略。
环境交互:
使用OpenAI Gym的CartPole-v1
环境。
通过策略网络与环境交互并采集经验。
可以通过调整--gamma
、--seed
和--log-interval
等参数来测试不同的训练效果。
bash
# 环境配置
Python 3.11.5
torch 2.1.0
torchvision 0.16.0
gym 0.26.2
7.Policy Gradient算法和REINFORCE算法的对比
PG(Policy Gradient)算法是一个更大的算法框架,而 REINFORCE 是 PG 算法的一种具体实现。因此,比较两者的关键在于 PG 的普适性和 REINFORCE 的具体特性。
优缺点对比
特性 | PG 算法 | REINFORCE |
---|---|---|
普适性 | 是一个框架,包含多个算法实现。 | 是 PG 框架下的一种具体实现。 |
梯度估计方式 | 支持多种方法,如时间差分、优势函数、基线等。 | 仅使用蒙特卡洛方法,直接计算累积奖励。 |
方差 | 使用基线函数和优势函数,可以有效降低梯度估计的方差。 | 累积奖励估计方差较大,收敛速度较慢。 |
偏差 | 根据估计方法,可能引入小偏差。 | 梯度估计无偏,但训练噪声较大。 |
计算效率 | Actor-Critic 等变种可以逐步更新,计算高效。 | 需要完整的轨迹,不能实时更新,效率较低。 |
适用场景 | 高效适用于连续动作空间或复杂策略优化问题。 | 适用于小型问题或作为基线方法用于研究和对比。 |
- PG 是更灵活和高效的框架,适用于复杂任务,算法设计可扩展性强。
- REINFORCE 是最基础的策略梯度算法,适合作为入门学习和小型问题的参考实现,但在实际应用中效果可能受限于方差较高的问题。
8.REINFORCE 的优点和缺点
优点
算法简单易实现,适用于多种环境。
不需要建模环境的动态或奖励函数。
缺点
收敛速度较慢,尤其是在高维动作空间中。
奖励的方差可能较大,影响梯度估计。
9.总结
REINFORCE 是强化学习中的经典策略梯度方法,通过直接优化策略来解决问题。尽管存在一些缺陷(如方差较高),但它为后续的改进算法(如 Actor-Critic、PPO)奠定了理论基础。
更多强化学习 文章,请前往:【强化学习(RL)】专栏
博客都是给自己看的笔记,如有误导深表抱歉。文章若有不当和不正确之处,还望理解与指出。由于部分文字、图片等来源于互联网,无法核实真实出处,如涉及相关争议,请联系博主删除。如有错误、疑问和侵权,欢迎评论留言联系作者,或者添加VX:**Rainbook_2,**联系作者。✨