用于决策的世界模型 -- 论文 World Models (2018) & PlaNet (2019) 讲解

参考资料:

世界模型

简介

世界模型 :一种理解世界当前状态预测其未来动态的工具。

世界模型的两个主要功能

  1. 构建内部表征以理解世界运作机制。
  2. 预测未来状态以模拟和指导决策。

分类

图片来自[2411.14499] Understanding World or Predicting Future? A Comprehensive Survey of World Models.

  • 作者按照模型的侧重点不同,将世界模型分成两个大类,即:
    • Internal Representations.
    • Future Predictions.
  • 经常能在网上刷到的LeCun力推世界模型,说的是JEPA.
  • 左边分支的世界模型也可以做"future prediction",作为学习模型参数过程的一个副产物吧 (视觉模块的reconstruction)。

这里讨论的是两篇world model for decision-making的文章。

World Models (2018)

AI社区中,首篇系统性介绍世界模型的文章。

人类的心理模型

简单可以概括成以下几点:

  • 对于外部世界的大量信息流,人脑能够学习到外部世界时空信息的抽象表示,作为我们对外部世界的"建模"。
  • 我们所看到一切都基于脑中模型对未来的预测。
  • 我们能够基于这个预测模型本能地行动,在面对危险时做出快速的反射性行为。

打棒球的例子: 击球手需要在毫秒级别的时间内决定如何挥棒 ------ 这比视觉信号到大脑的时间还要短。

在之后的世界模型结构和实验中,都可以看到这个心理模型的影子。

模型结构

世界模型主要由两个模块组成:视觉模块、记忆模块。

  1. 视觉模块:将外部世界的高维观测,压缩成低维的特征。

  2. 记忆模块:整合历史信息,预测未来。

控制器会利用世界模型给出的信息进行决策。

视觉模块

作者在文章中使用VAE的Encoder部分作为视觉模块。

记忆模块

作者在文章中使用MDN-RNN作为记忆模块。

  • MDN指的是mixture density networks,就是一个建模混合模型的网络,文中使用的是高斯混合模型 (GMM),此时神经网络除了输出每个高斯分布的均值和标准差,还需要输出用于选择高斯分布的类别分布。
  • MDN会接受一个temperature参数\(\tau\),用于调整不确定性。
  • 在图中,MDN-RNN建模的是\(P(z_{t+1}\mid a_t, z_t, h_t)\).
  • 除了隐状态之外,记忆模块可能还需要建模其他东西,比如奖励\(P(r_{t+1} \mid a_t, z_t, h_t)\),游戏结束的信号\(P(\text{done}_{t+1} \mid a_t, z_t, h_t)\).

NOTE:为什么要使用混合模型,即使VAE的隐变量空间只是一个对角高斯?作者的解释是:混合模型中的离散部分 (选择哪一个高斯组分),有利于建模环境中的离散随机事件。比如说NPC在平静状态和警觉状态下的表现不同。

控制器

作者将整个模型的复杂性都集中到了视觉和记忆模块,有意使得控制器的结构尽可能简单:

\[a_t = W_c[z_t~~h_t] + b_c \]

就是单层的神经网络。

模型训练和实验

文章官网World Models,有gif演示,而且可以试玩模型"梦中"的游戏。

训练

两个实验都是先单独训练世界模型 (无监督):

  1. 使用随机策略收集一系列的游戏图像。
  2. 使用这些图像训练好VAE。
  3. 在训练好的VAE基础上,训练好MDN-RNN。

之后部署世界模型并训练控制器。两个实验的主要区别在部署:

  • Car Racing实验:直接在实际环境部署,训练好了控制器之后,又给出了在模型"梦中"的模拟。

  • VizDoom实验:先在"梦中"部署,训练好了控制器之后,再将整个模型转移到实际环境查看效果。

NOTE:在两个实验中,世界模型都没有建模环境的奖励。第一个实验中,奖励只在训练控制器的时候由实际环境给出;第二个实验中,指标是存活时间,不需要奖励。

REMARK:训练成功之后,模型实际上成为了游戏的"模拟器",学习到了游戏逻辑 (角色中弹后会重新开始)、敌人行为 (按一定时间间隔发射子弹)、物理机制 (子弹飞行速度)等。

实验

Car Racing:

VizDoom:

消融实验1 -- 视觉模块+记忆模块的优越性

在Car Racing中,消融实验显示,单独的视觉模块效果不如一整个的世界模型 (但是也已经超过了DQN和A3C)

消融实验2 -- 用tau调整随机性

在VizDoom实验中,由于模型并非完全精确,控制器可能会利用模型的缺陷来在模拟器中达到高分,一旦部署到实际环境,控制器就不行了。

为了防止这一点,MDN-RNN预测的是具有随机性的环境,并通过调整不确定性参数\(\tau\)来控制随机性。在实验中,\(\tau=1.15\)时效果最好。

当\(\tau=0.1\)时,模型几乎是确定性的,这时候敌人甚至无法发射子弹,所以出现了在模拟器中非常高分,实际环境中却非常低分的情况。

跑分对比实验

  • Car Racing实验:取得的分数超过了先前的基于深度强化学习的方法,如DQN、A3C.
  • VizDoom实验:在梦中学会了如何躲避怪物的子弹,部署到实际环境后的存活时长也超过了先前。

迭代训练过程

本文的实验环境简单,所以是使用随机策略采样,分别训练三个模块。面对更复杂的任务,可能需要三个模块一起训练,但是本文只是提了一下记忆模块和控制器一起训练的流程:

三个模块一起训练的好处是:

  1. 视觉模块会倾向于学习到有利于当前任务的特征。
  2. 记忆模块可以对控制器进行学习,控制器又可以基于记忆模块继续改进,如此往复。
  3. 可以使用训练中的控制器进行轨迹采样而不是随机策略。

Learning Latent Dynamics for Planning from Pixels (2019)

相对于上一篇,这篇的改进:

  1. 假定了环境是部分可观测马尔可夫决策过程 (POMDP),世界模型就是在学习这个POMDP.
  2. 给出了一套结合模型预测控制 (MPC) 方法的训练过程 ------ Deep Planning Network (PlaNet).
  3. 提出基于确定性和随机性结合的状态空间模型 (RSSM),而不是仅有确定性状态的RNN和仅有随机性状态的SSM.
  4. 给出了适用于多步预测的变分推断方法 ------ latent overshooting.

Problem setup

假定实际的环境是一个POMDP:

目标是学习到一个策略,能够最大化期望累积回报\(\mathbb E[\sum r_t]\)。

Deep planning network

这里先讲世界模型+MPC的学习和规划算法。

while循环内部,总体上分成三个部分:模型学习,实时规划+数据收集,更新数据库。

模型学习

从数据库中随机抽取观测序列的小批量,然后使用梯度方法学习。

实时规划+数据收集

总体上就是一个有限时间域的MPC框架,在每个time step按三步走:

  1. Observe:获得当前时刻的状态。由于这里在隐状态空间进行规划,所以需要从历史的观测数据中推断当前状态 (通过隐变量的后验概率)。
  2. Predict and plan:利用当前学习到的模型,解一个有限时间域的最优控制问题,获得一串动作序列。本文中的planner使用的是cross entropy method (CEM).
  3. Act :对环境使用这串动作序列的第一个动作\(a_t\),移动到下一个time step. 这里用了一个trick,把取得的动作\(a_t\)重复了\(R\)次 (用相同的action,连续走了\(R\)步),取reward的总和作为当前时刻的reward,取最终的第\(R\)观测\(o_{t+1}^R\)作为下一个时刻的观测\(o_{t+1}\)。

更新数据库

将上一个部分收集到的观测序列加入到数据库中,以供世界模型的进一步更新。

NOTE:相对于model-free RL算法,model-based planning的一大优势就是数据利用率提高了。体现在planning取得的观测序列可以反复用于世界模型的学习。

RSSM

这种模型也叫:Non-linear Kalman filter, sequential VAE, deep variational bayes filter,看了一眼相关的文章,好像要从头到尾讲明白 (像VAE那样) 比较复杂。

这里浅浅讲一下世界模型的结构以及训练的Loss。

Latent state-space model

使用下面的encoder来近似后验概率:

\[q(s_{\le t} \mid o_{\le t},a_{<t}) = \prod_{t=1}^T q(s_t\mid s_{t-1},a_{t-1},o_t) \]

都使用神经网络参数化的高斯分布表示,其中observation model和encoder用的是卷积网络。

Training Objective

通过最大化log Evidence来训练:

\[\arg\max \ln p(o_{\le t} \mid a_{<t}) \]

接下来推导ELBO.

先拆成边际化的形式

\[\ln p(o_{\le T} \mid a_{<T}) = \ln \int p(o_{\le T}, s_{\le T} \mid a_{<T}) \text{d}s \]

把联合概率拆开

\[\ln p(o_{\le T} \mid a_{<T}) = \ln \int p(o_{\le T} \mid s_{\le T}, a_{<T}) p(s_{\le T} \mid a_{<T}) \text{d}s \]

写成期望的形式

\[\ln p(o_{\le t} \mid a_{<t}) = \ln \mathbb E_{p(s_{\le t}\mid a_{<t})}[ p(o_{\le t} \mid s_{\le t}, a_{<t}) ] \]

利用重要性采样方法,转变成从encoder采样

\[\ln p(o_{\le t} \mid a_{<t}) = \ln \mathbb E_{q(s_{\le t}\mid o_{\le t},a_{<t})}[ p(o_{\le t} \mid s_{\le t}, a_{<t}) p(s_{\le t}\mid a_{<t}) / q(s_{\le t}\mid o_{\le t},a_{<t})] \]

链式分解,并利用模型的条件独立性化简 (概率图参考下面的)

\[\ln p(o_{\le t} \mid a_{<t}) = \ln \mathbb E_{q(s_{\le t}\mid o_{\le t},a_{<t})}[\prod p(o_t \mid s_t) p(s_t \mid s_{t-1},a_{t-1}) / q(s_{\le t}\mid o_{\le t},a_{<t})] \]

根据Jensen不等式,\(\ln \mathbb E[x] \ge \mathbb E[\ln(x)]\)

\[\ln p(o_{\le t} \mid a_{<t}) \ge \mathbb E_{q(s_{\le t}\mid o_{\le t},a_{<t})}[\sum_t \ln p(o_t \mid s_t) + \ln p(s_t \mid s_{t-1},a_{t-1}) - \ln q(s_{\le t}\mid o_{\le t},a_{<t})] \]

右边可以写成reconstruction + KL的形式,最后就是

确定性和随机性结合 - RSSM

  • 纯确定性的世界模型:模型难以预测多种可能的未来情况;容易被planner利用模型缺陷 (在World Models中,通过MDN添加随机性来缓解这一点,但本质还是确定性的)
  • 纯随机性的世界模型:模型难以记住信息,导致产生前后不一致的预测结果。

所以作者考虑将确定性和随机性结合,称这种结构为RSSM.

相对于上一篇,把记忆模块换成了RSSM。

Latent Overshooting

之前讨论的都是\(s_t \to s_{t+1}\)的单步预测,如果每次单步预测都准确无误,那多步预测肯定也没问题。但是由于模型本身有局限,所以不一定能很好的推广到多部预测。

于是作者考虑了直接进行跨步的预测,先通过对中间几步隐变量边际化得到了跨步预测的转移

并且推导了针对跨步预测的变分bound

把考虑不同的步幅\(d\),求和,就得到latent overshooting的目标函数

实验结果

DeepMind control suite环境:图像作为观测,连续动作空间。

消融实验

  • 验证PlaNet的数据收集过程有优势。Random Collection指的是用随机策略收集数据而不是通过MPC;Random shooting指的是使用了MPC框架,但是不使用CEM,而是直接从1000条随机采的动作序列里选最好的那条。最后PlaNet在大部分情况都明显好于另外两种。

  • RSSM和SSM、GRU的对比。观察到RSSM明显好于后两者,表明了确定性+随机性结合的优势。

  • 是否加入latent overshooting作为变分目标。观察到Latent overshooting使RSSM的表现轻微变差,但是在一些任务上让DRNN的表现变好了。

跑分对比实验

  • PlaNet的分数能打败A3C。
  • PlaNet的分数总体不如D4PG,但是大部分任务相差不多。
  • PlaNet在所有任务上,数据利用率都好于D4PG.
  • PlaNet (CEM + 世界模型) 和 CEM + true simulator对比只差了一些,体现出世界模型较好地学习到了环境。

六个任务一起训练

每次循环中,agent面对的可能是不同的环境,所以数据库中抽取出来的轨迹也是打乱的。

最后跑分不如单独训练,但是体现出了agent能够自己判断出面对的是哪个任务了。

代码选讲

代码来自:Kaixhin/PlaNet: Deep Planning Network: Control from pixels by latent planning with learned dynamics

主要是看看transition model和模型训练过程。解释都在注释里,有部分注释是代码库原有的。

Transition model

python 复制代码
class TransitionModel(jit.ScriptModule):
  __constants__ = ['min_std_dev']

  def __init__(self, belief_size, state_size, action_size, hidden_size, embedding_size, activation_function='relu', min_std_dev=0.1):
    super().__init__()
    self.act_fn = getattr(F, activation_function)
    self.min_std_dev = min_std_dev
    self.fc_embed_state_action = nn.Linear(state_size + action_size, belief_size) # combine s_t and a_t to comb(s_t, a_t)
    self.rnn = nn.GRUCell(belief_size, belief_size) # from comb(s_t, a_t), h_t to h_t+1
    self.fc_embed_belief_prior = nn.Linear(belief_size, hidden_size) # from h_t to z_t
    self.fc_state_prior = nn.Linear(hidden_size, 2 * state_size) # parameterized prior of s_t, from z_t to mean and std
    self.fc_embed_belief_posterior = nn.Linear(belief_size + embedding_size, hidden_size) # from h_t and e_t to z_t
    self.fc_state_posterior = nn.Linear(hidden_size, 2 * state_size) # parameterized posterior of s_t, from z_t to mean and std

  # Operates over (previous) state, (previous) actions, (previous) belief, (previous) nonterminals (mask), and (current) observations
  # Diagram of expected inputs and outputs for T = 5 (-x- signifying beginning of output belief/state that gets sliced off):
  # t :  0  1  2  3  4  5
  # o :    -X--X--X--X--X-  设置了初始的隐状态是None,所以不考虑0时刻的obs
  # a : -X--X--X--X--X-     不考虑最后一个action,因为最后一个action没有后续的obs
  # n : -X--X--X--X--X-
  # pb: -X-
  # ps: -X-
  # b : -x--X--X--X--X--X-
  # s : -x--X--X--X--X--X-

  # 输入的shape都是(time_step, batch_size, *)
  @jit.script_method
  def forward(self, prev_state:torch.Tensor, actions:torch.Tensor, prev_belief:torch.Tensor, observations:Optional[torch.Tensor]=None, nonterminals:Optional[torch.Tensor]=None) -> List[torch.Tensor]:
    # 后面都是动态更新,为了保留grad,不能使用单个tensor作为buffer,所以创建了几个list
    T = actions.size(0) + 1 # 实际需要的list长度,参考上面的图
    beliefs, prior_states, prior_means, prior_std_devs, posterior_states, posterior_means, posterior_std_devs = \
      [torch.empty(0)] * T, [torch.empty(0)] * T, [torch.empty(0)] * T, [torch.empty(0)] * T, [torch.empty(0)] * T, [torch.empty(0)] * T, [torch.empty(0)] * T
    beliefs[0], prior_states[0], posterior_states[0] = prev_belief, prev_state, prev_state # 0时刻赋初值

    # 每次循环开始,是已知t时刻的信息,进一步计算t+1时刻的信息
    for t in range(T - 1):
      # 根据情况合适的s,因为模型可以在脱离observations的情况下自己预测
      # 如果observations为None,则使用先验状态 (模型一步步生成出来的),否则使用后验状态 (根据历史的obs和action推断出来的)
      _state = prior_states[t] if observations is None else posterior_states[t] 
      # terminal则说明这段序列已经结束了,所以把状态mask掉 (就是0)
      _state = _state if nonterminals is None else _state * nonterminals[t]  

      # 注意下面每一块的hidden是临时变量,表示的是不同的意思

      # 计算确定性隐状态h = f(s_t, a_t, h_t)
      hidden = self.act_fn(self.fc_embed_state_action(torch.cat([_state, actions[t]], dim=1))) # s和a先拼在一起
      beliefs[t + 1] = self.rnn(hidden, beliefs[t]) # 对应概率图中从s,a,h到h的实线

      # 计算隐状态s的先验 p(s_t|s_t-1,a_t-1)
      hidden = self.act_fn(self.fc_embed_belief_prior(beliefs[t + 1])) # 对应概率图中从h到s的实线
      prior_means[t + 1], _prior_std_dev = torch.chunk(self.fc_state_prior(hidden), 2, dim=1)
      prior_std_devs[t + 1] = F.softplus(_prior_std_dev) + self.min_std_dev # Trick: 使用softplus来保证std_devs为正,并且使用min_std_dev来保证std_devs不会太小
      prior_states[t + 1] = prior_means[t + 1] + prior_std_devs[t + 1] * torch.randn_like(prior_means[t + 1])     

      # 计算隐状态s的后验 q(s_t|o≤t,a<t)
      if observations is not None: # 只有observations不为None时,才计算后验
        t_ = t - 1  # 这是实现的问题,因为传进来的是obs[1:],所以应该用t_+1才能索引到对应的obs
        hidden = self.act_fn(self.fc_embed_belief_posterior(torch.cat([beliefs[t + 1], observations[t_ + 1]], dim=1))) # 对应概率图中的两条虚线
        posterior_means[t + 1], _posterior_std_dev = torch.chunk(self.fc_state_posterior(hidden), 2, dim=1)
        posterior_std_devs[t + 1] = F.softplus(_posterior_std_dev) + self.min_std_dev
        posterior_states[t + 1] = posterior_means[t + 1] + posterior_std_devs[t + 1] * torch.randn_like(posterior_means[t + 1])

    # 返回h,s,以及先验和后验的均值和方差
    hidden = [torch.stack(beliefs[1:], dim=0), torch.stack(prior_states[1:], dim=0), torch.stack(prior_means[1:], dim=0), torch.stack(prior_std_devs[1:], dim=0)]
    if observations is not None:
      hidden += [torch.stack(posterior_states[1:], dim=0), torch.stack(posterior_means[1:], dim=0), torch.stack(posterior_std_devs[1:], dim=0)]
    return hidden

世界模型训练

只截取了一小部分,重点看loss func是如何计算的。

python 复制代码
  # Model fitting
  losses = []
  for s in tqdm(range(args.collect_interval)):
    # Draw sequence chunks {(o_t, a_t, r_t+1, terminal_t+1)} ~ D uniformly at random from the dataset (including terminal flags)
    observations, actions, rewards, nonterminals = D.sample(args.batch_size, args.chunk_size)  # Transitions start at time t = 0

    # Create initial belief and state for time t = 0
    init_belief, init_state = torch.zeros(args.batch_size, args.belief_size, device=args.device), torch.zeros(args.batch_size, args.state_size, device=args.device)

    # Update belief/state using posterior from previous belief/state, previous action and current observation (over entire sequence at once)
    # 一次把整个隐状态序列全部计算出来
    beliefs, prior_states, prior_means, prior_std_devs, posterior_states, posterior_means, posterior_std_devs =\
      transition_model(init_state, actions[:-1], init_belief, bottle(encoder, (observations[1:], )), nonterminals[:-1])

    # Calculate observation likelihood, reward likelihood and KL losses (for t = 0 only for latent overshooting); sum over final dims, average over batch and time (original implementation, though paper seems to miss 1/T scaling?)
    # Reconstruction loss都使用MSE
    # mean(dim=(0, 1))对batch和time进行平均
    observation_loss =\
      F.mse_loss(bottle(observation_model, (beliefs, posterior_states)), observations[1:], reduction='none').sum(dim=2 if args.symbolic_env else (2, 3, 4)).mean(dim=(0, 1))
    reward_loss =\
      F.mse_loss(bottle(reward_model, (beliefs, posterior_states)), rewards[:-1], reduction='none').mean(dim=(0, 1))
    # KL loss, 计算了后验q(s_t|o≤t,a<t)和先验p(s_t|s_t-1,a_t-1)的KL散度
    kl_loss =\
      torch.max(kl_divergence(Normal(posterior_means, posterior_std_devs), Normal(prior_means, prior_std_devs)).sum(dim=2), free_nats).mean(dim=(0, 1))  # Note that normalisation by overshooting distance and weighting by overshooting distance cancel out

"""
后面的部分略
"""
相关推荐
盼小辉丶4 小时前
Transformer实战(4)——从零开始构建Transformer
pytorch·深度学习·transformer
paid槮9 小时前
机器学习总结
人工智能·深度学习·机器学习
失散1313 小时前
深度学习——02 PyTorch
人工智能·pytorch·深度学习
图灵学术计算机论文辅导13 小时前
傅里叶变换+attention机制,深耕深度学习领域
人工智能·python·深度学习·计算机网络·考研·机器学习·计算机视觉
楚韵天工17 小时前
基于多分类的工业异常声检测及应用
人工智能·深度学习·神经网络·目标检测·机器学习·分类·数据挖掘
老艾的AI世界19 小时前
AI去、穿、换装软件下载,无内容限制,偷偷收藏
图像处理·人工智能·深度学习·神经网络·目标检测·机器学习·ai·换装·虚拟试衣·ai换装·一键换装
软件测试-阿涛20 小时前
【AI绘画】Stable Diffusion webUI 常用功能使用技巧
人工智能·深度学习·计算机视觉·ai作画·stable diffusion
盼小辉丶21 小时前
Transformer实战(11)——从零开始构建GPT模型
gpt·深度学习·transformer
计算机sci论文精选1 天前
CVPR2025敲门砖丨机器人结合多模态+时空Transformer直冲高分,让你的论文不再灌水
人工智能·科技·深度学习·机器人·transformer·cvpr
华清远见成都中心1 天前
基于深度学习的异常检测算法在时间序列数据中的应用
人工智能·深度学习·算法