强化学习第七课 —— 策略网络设计指南:赋予 Agent“大脑”的艺术

在强化学习(RL)中,我们总是把策略函数写成 πθ(a∣s)\pi_\theta(a|s)πθ(a∣s)。

这短短的一个 θ\thetaθ,掩盖了无数的工程细节。

策略网络(Policy Network)的本质,是一个 特征提取器(Feature Extractor)加上一个决策头(Decision Head)。它的核心任务非常明确:

  1. 读懂环境 :从原始的状态 sss 中提取有用的信息。
  2. 做出决策 :将这些信息映射为动作 aaa 的概率分布。

你的 Agent 是天才还是弱智,很大程度上取决于你给它装了一个什么样的"大脑"。是简单的线性感知机?还是复杂的视觉皮层?亦或是拥有长短期记忆的海马体?

今天,我们来聊聊卷积(CNN)、全连接(MLP)和循环(RNN)在 RL 中的生存法则。


一、MLP(多层感知机):简单粗暴的"直觉反射"

1. 适用场景

  • 状态(State):低维向量。例如:机器人的关节角度、速度、位置坐标、股票因子的数值。
  • 特点:状态特征之间没有明显的空间结构或时间依赖。

2. 设计哲学

当你的输入已经是由物理引擎计算好的特征向量(Feature Vector)时,你不需要花哨的结构。MLP 是最稳健的选择。

设计要点:

  • 不要太深 :不同于计算机视觉(CV)动辄上百层的 ResNet,RL 中的 MLP 通常很浅。2 到 3 层隐藏层(每层 64 到 256 个神经元)通常就足够解决 MuJoCo 或简单的控制任务。
  • 激活函数
    • ReLU:标准选择,但在 RL 中有时会导致"死神经元"问题。
    • Tanh:在 PPO/TRPO 等连续控制策略中,Tanh 往往比 ReLU 表现更好。因为它是有界的(-1 到 1),且梯度更平滑。
    • Layer Normalization:强烈建议加上。RL 的数据分布是非平稳的(Non-stationary),LayerNorm 能极大地稳定训练。

3. 代码直觉

python 复制代码
# 一个标准的 MLP 策略网络骨架
class MLPPolicy(nn.Module):
    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, 64),
            nn.Tanh(),  # RL中 Tanh 常常优于 ReLU
            nn.Linear(64, 64),
            nn.Tanh(),
            nn.Linear(64, action_dim)
        )

二、CNN(卷积神经网络):Agent 的"视觉皮层"

1. 适用场景

  • 状态(State):高维图像、网格地图。例如:Atari 游戏画面、星际争霸的小地图、自动驾驶的摄像头输入。
  • 特点 :输入具有空间局部相关性(Spatial Locality)。像素点 (0,0) 和 (0,1) 的关系比 (0,0) 和 (100,100) 的关系更紧密。

2. 设计哲学

如果把图像展平直接喂给 MLP,参数量会爆炸,且丢掉了空间信息。你需要 CNN 来提取特征。

RL 中的 CNN 与 CV 中的 CNN 有何不同?

  • 慎用 Max Pooling(池化层)
    在图像分类(如猫狗识别)中,不管是左上角的猫还是右下角的猫,都是猫。这就是平移不变性 ,Pooling 对此很有帮助。
    但在 RL 中,位置至关重要 !Pong 游戏中,球在左边你要往左跑,球在右边你要往右跑。Pooling 会丢失精确的位置信息。
    • 替代方案 :使用 Strided Convolution(步长卷积) 来降维,而不是 Pooling。
  • Nature CNN 架构
    DQN 论文中提出的经典架构至今仍是标杆:
    • Conv1: 32 filters, size 8x8, stride 4
    • Conv2: 64 filters, size 4x4, stride 2
    • Conv3: 64 filters, size 3x3, stride 1
    • Flatten -> MLP Head

3. 帧堆叠(Frame Stacking)

一张静态图片无法告诉 Agent 速度加速度

如果只给一张 Pong 的截图,你不知道球是往左飞还是往右飞。
解决方案 :将最近的 4 帧图像叠加在一起,作为一个通道数为 4 的输入 (4,84,84)(4, 84, 84)(4,84,84)。这样 CNN 就能在通道维度上"看见"运动。


三、RNN/LSTM/GRU:解决"失忆症"

1. 适用场景

  • 状态(State):部分可观测(POMDP)。例如:第一人称射击游戏(你看不到背后的敌人)、扑克牌(你需要记住对手之前的下注习惯)、即时战略游戏的迷雾。
  • 特点 :当前的观测 oto_tot 不足以包含决策所需的所有信息,决策依赖于历史轨迹

2. 设计哲学

当环境不仅是关于"现在",而是关于"过去"时,MLP 和 CNN 就失效了。你需要记忆。

GRU vs LSTM:

在 RL 中,GRU (Gated Recurrent Unit) 通常比 LSTM 更受欢迎。

  • 参数更少:训练更快。
  • 收敛更好:RL 的样本本来就噪声大,复杂的 LSTM 容易过拟合或难以训练。

训练难点:

RNN 在 RL 中的训练极其痛苦。

  • 序列长度:你需要把一整段 Episode(或截断的一段)喂进去。
  • 隐藏状态管理:在采样时,你需要维护每个 Environment 的 hidden state,并在 Episode 结束时重置它。
  • R2D2 / DRQN:这些算法专门优化了带 RNN 的训练过程。

四、网络输出头(The Head):决策的出口

无论前面的骨干网络(Backbone)是 MLP、CNN 还是 RNN,最后都需要一个"头"来输出动作。

1. 离散动作空间 (Discrete)

  • 比如:上、下、左、右。
  • 输出:Linear 层输出维度 = 动作数量。
  • 激活:Softmax(输出概率分布)。
  • 采样:根据 Categorical 分布采样。

2. 连续动作空间 (Continuous)

  • 比如:方向盘转动角度 (-1.0 到 1.0)。
  • 输出 :通常输出两个向量:
    1. 均值 (Mu):经过 Tanh 激活,映射到动作范围。
    2. 标准差 (Log_Std):通常是独立的可学习参数,或者由网络输出(经过 Softplus 保证为正)。
  • 采样 :构建高斯分布(Gaussian Distribution),从中采样,然后用 tanh 压缩。

五、终极缝合怪:多模态架构

现代复杂的 RL 任务(如机器人管家)往往是多模态的。

  • 眼睛:RGB 摄像头 (CNN)
  • 身体:关节传感器 (MLP)
  • 任务:自然语言指令 "去把蓝色的杯子拿来" (Transformer/Embedding)

设计思路:

  1. 分治(Divide):用 CNN 处理图像,用 MLP 处理传感器数据,用 Embedding 处理文本。
  2. 融合(Conquer):将它们提取出的特征向量(Embedding)拼接(Concatenate)在一起。
  3. 决策:将拼接后的向量喂给一个最终的 MLP 决策头。

Features=Concat[CNN(Image),MLP(Sensor)] \text{Features} = \text{Concat}[ \text{CNN}(\text{Image}), \text{MLP}(\text{Sensor}) ] Features=Concat[CNN(Image),MLP(Sensor)]
π(a∣s)=MLPHead(Features) \pi(a|s) = \text{MLP}_{\text{Head}}(\text{Features}) π(a∣s)=MLPHead(Features)


结语:没有最好的架构,只有最适合的

设计策略网络时,请遵循奥卡姆剃刀原则如无必要,勿增实体。

  1. 能用 MLP 解决的,绝不上 CNN。
  2. 能用 Frame Stacking 解决的(通过堆叠几帧感知速度),绝不上 RNN
  3. 只有当环境完全无法通过单帧信息判断状态时(严重的 POMDP),才考虑 LSTM/GRU 或 Transformer。

RL 的训练本身已经极其不稳定了,不要让复杂的网络架构成为压死骆驼的最后一根稻草。简单,往往就是最强。


给读者的总结表

场景 推荐架构 关键技巧
MuJoCo / 机器人关节控制 MLP 2-3层,Tanh 激活,LayerNorm
Atari / 像素游戏 CNN 慎用 Pooling,使用 stride 卷积,Frame Stacking
FPS / 迷雾环境 / 扑克 CNN + GRU 维护 Hidden State,处理变长序列
多模态输入 Hybrid 各自提取特征后 Concat,再过 MLP
星际争霸 / Dota 2 Transformer Attention 机制处理海量单位实体
相关推荐
胡闹542 小时前
海康和大华厂商的RTSP取流地址格式进行拉流直播
java·网络
志凌海纳SmartX2 小时前
AI知识科普丨什么是 AI Agent?
人工智能
RockHopper20252 小时前
认知导向即面向服务——规避未来AI发展路径上的拟人化陷阱
人工智能·认知导向·xai 可解释人工智能
神算大模型APi--天枢6462 小时前
全栈自主可控:国产算力平台重塑大模型后端开发与部署生态
大数据·前端·人工智能·架构·硬件架构
@鱼香肉丝没有鱼2 小时前
Transformer底层原理—位置编码
人工智能·深度学习·transformer·位置编码
yiersansiwu123d2 小时前
AI大模型的进化与平衡:在技术突破与伦理治理中前行
人工智能
木卫二号Coding2 小时前
第六十一篇-ComfyUI+V100-32G+GGUF+运行Flux Schnell GGUF
人工智能
青啊青斯2 小时前
二、PaddlePaddle seal_recognition印章内容提取
人工智能·r语言·paddlepaddle
深度学习实战训练营2 小时前
HRNet:深度高分辨率表示学习用于人体姿态估计-k学长深度学习专栏
人工智能·深度学习