【机器学习】机器学习的基本分类-强化学习-REINFORCE 算法

REINFORCE 算法

REINFORCE 是一种基于策略梯度的强化学习算法,直接通过采样环境中的轨迹来优化策略。它是策略梯度方法的基础实现,具有简单直观的优点。


核心思想

  1. 目标函数

    • 最大化策略的期望回报:

    • 通过优化策略参数 θ,使累积回报 J(θ) 最大化。

  2. 策略梯度定理

    • 策略梯度为:

    • 其中 是从时间步 t 开始的累积回报。

  3. 梯度估计

    • 使用采样方法估计梯度:

    • 其中 N 是采样的轨迹数量。


算法流程

  1. 初始化

    • 随机初始化策略参数 θ。
  2. 采样轨迹

    • 使用当前策略 与环境交互,生成 N 条轨迹。
  3. 计算回报

    • 对每条轨迹计算累积回报
  4. 计算梯度

    • 根据策略梯度定理计算梯度
  5. 更新策略参数

    • 使用梯度上升更新策略参数:

  6. 迭代

    • 重复上述步骤,直至策略收敛。

伪代码

python 复制代码
Initialize policy network with random weights θ
for episode in range(max_episodes):
    Generate a trajectory using πθ
    Compute returns G_t for each step in the trajectory
    for each step in the trajectory:
        Compute policy gradient:
            ∇θ J(θ) = ∇θ log πθ(a_t | s_t) * G_t
        Update policy network parameters:
            θ ← θ + α * ∇θ J(θ)

关键特性

  1. 无基线版本

    • 直接使用累积回报 作为更新目标。
    • 高方差:每条轨迹的回报差异可能很大,导致梯度估计的不稳定性。
  2. 基线改进

    • 减少方差的常用方法是在梯度中引入基线 b(s),更新规则变为:

    • 其中 b(st)b(s_t)b(st) 通常是状态值函数 的估计值。


优缺点

优点
  1. 实现简单

    • 通过采样轨迹即可直接优化策略。
  2. 适用于复杂策略

    • 可以学习高维连续动作或多样化策略。
  3. 灵活性

    • 可结合多种改进(如基线、Actor-Critic 方法)。
缺点
  1. 高方差

    • 回报 的方差较高,导致策略收敛较慢。
  2. 数据利用效率低

    • 每次更新仅使用一次采样的轨迹。
  3. 不稳定

    • 需要仔细调整学习率和其他超参数以确保收敛。

应用场景

  1. 游戏 AI

    • 用于优化游戏智能体的策略。
  2. 机器人控制

    • 优化机械臂或移动机器人在连续动作空间中的行为。
  3. 推荐系统

    • 动态优化用户推荐的长期回报。
  4. 金融交易

    • 在复杂的交易环境中设计交易策略。

改进方法

  1. 基线函数

    • 减少策略梯度的方差,提高更新的稳定性。
  2. Actor-Critic

    • 结合值函数的 Actor-Critic 方法,通过同时学习值函数和策略,进一步提高效率。
  3. Trust Region Policy Optimization (TRPO)

    • 限制策略更新幅度,确保每次更新的稳定性。
  4. Proximal Policy Optimization (PPO)

    • 通过裁剪策略更新的范围,兼顾效率和稳定性。

代码示例(简化版)

以下是一个 Python 示例,使用 NumPy 实现 REINFORCE:

python 复制代码
import numpy as np

# 环境接口
class Environment:
    def reset(self):
        # 返回初始状态
        pass

    def step(self, action):
        # 执行动作,返回 (下一状态, 奖励, 是否终止)
        pass

# 策略网络 (简单线性模型)
class PolicyNetwork:
    def __init__(self, state_dim, action_dim):
        self.weights = np.random.randn(state_dim, action_dim)
    
    def predict(self, state):
        logits = np.dot(state, self.weights)
        return np.exp(logits) / np.sum(np.exp(logits))  # Softmax
    
    def update(self, grads, learning_rate):
        self.weights += learning_rate * grads

# REINFORCE 算法
def reinforce(env, policy, episodes, learning_rate):
    for episode in range(episodes):
        state = env.reset()
        trajectory = []
        
        # 采样轨迹
        while True:
            probs = policy.predict(state)
            action = np.random.choice(len(probs), p=probs)
            next_state, reward, done = env.step(action)
            trajectory.append((state, action, reward))
            state = next_state
            if done:
                break
        
        # 计算回报
        G = 0
        grads = np.zeros_like(policy.weights)
        for t, (state, action, reward) in enumerate(reversed(trajectory)):
            G = reward + 0.99 * G
            grad = np.zeros_like(policy.weights)
            grad[:, action] = state
            grads += grad * (G - np.mean([x[2] for x in trajectory]))  # 使用基线
        
        # 更新策略
        policy.update(grads, learning_rate)
相关推荐
volcanical13 分钟前
MoCo 对比自监督学习
人工智能·学习·机器学习
yuanbenshidiaos23 分钟前
linux----文件访问(c语言)
linux·服务器·算法
四口鲸鱼爱吃盐26 分钟前
Pytorch | 从零构建Vgg对CIFAR10进行分类
人工智能·pytorch·分类
不想当程序猿_38 分钟前
【蓝桥杯每日一题】扫雷——暴力搜索
算法·蓝桥杯
Best_Me071 小时前
最短路径C++
java·c++·算法
程序猿(雷霆之王)2 小时前
优选算法——链表
数据结构·算法·链表
野風_199602012 小时前
代码随想录第52天
算法
图学习的小张2 小时前
论文笔记:是什么让多模态学习变得困难?
论文阅读·神经网络·机器学习
关关钧2 小时前
【Linux】结构化命令:控制循环
linux·运维·算法
赵钰老师3 小时前
遥感影像目标检测:从CNN(Faster-RCNN)到Transformer(DETR
pytorch·python·深度学习·目标检测·机器学习·cnn·transformer