用于决策的世界模型 -- 论文 World Models (2018) & PlaNet (2019) 讲解

参考资料:

世界模型

简介

世界模型 :一种理解世界当前状态预测其未来动态的工具。

世界模型的两个主要功能

  1. 构建内部表征以理解世界运作机制。
  2. 预测未来状态以模拟和指导决策。

分类

图片来自[2411.14499] Understanding World or Predicting Future? A Comprehensive Survey of World Models.

  • 作者按照模型的侧重点不同,将世界模型分成两个大类,即:
    • Internal Representations.
    • Future Predictions.
  • 经常能在网上刷到的LeCun力推世界模型,说的是JEPA.
  • 左边分支的世界模型也可以做"future prediction",作为学习模型参数过程的一个副产物吧 (视觉模块的reconstruction)。

这里讨论的是两篇world model for decision-making的文章。

World Models (2018)

AI社区中,首篇系统性介绍世界模型的文章。

人类的心理模型

简单可以概括成以下几点:

  • 对于外部世界的大量信息流,人脑能够学习到外部世界时空信息的抽象表示,作为我们对外部世界的"建模"。
  • 我们所看到一切都基于脑中模型对未来的预测。
  • 我们能够基于这个预测模型本能地行动,在面对危险时做出快速的反射性行为。

打棒球的例子: 击球手需要在毫秒级别的时间内决定如何挥棒 ------ 这比视觉信号到大脑的时间还要短。

在之后的世界模型结构和实验中,都可以看到这个心理模型的影子。

模型结构

世界模型主要由两个模块组成:视觉模块、记忆模块。

  1. 视觉模块:将外部世界的高维观测,压缩成低维的特征。

  2. 记忆模块:整合历史信息,预测未来。

控制器会利用世界模型给出的信息进行决策。

视觉模块

作者在文章中使用VAE的Encoder部分作为视觉模块。

记忆模块

作者在文章中使用MDN-RNN作为记忆模块。

  • MDN指的是mixture density networks,就是一个建模混合模型的网络,文中使用的是高斯混合模型 (GMM),此时神经网络除了输出每个高斯分布的均值和标准差,还需要输出用于选择高斯分布的类别分布。
  • MDN会接受一个temperature参数\(\tau\),用于调整不确定性。
  • 在图中,MDN-RNN建模的是\(P(z_{t+1}\mid a_t, z_t, h_t)\).
  • 除了隐状态之外,记忆模块可能还需要建模其他东西,比如奖励\(P(r_{t+1} \mid a_t, z_t, h_t)\),游戏结束的信号\(P(\text{done}_{t+1} \mid a_t, z_t, h_t)\).

NOTE:为什么要使用混合模型,即使VAE的隐变量空间只是一个对角高斯?作者的解释是:混合模型中的离散部分 (选择哪一个高斯组分),有利于建模环境中的离散随机事件。比如说NPC在平静状态和警觉状态下的表现不同。

控制器

作者将整个模型的复杂性都集中到了视觉和记忆模块,有意使得控制器的结构尽可能简单:

\[a_t = W_c[z_t~~h_t] + b_c \]

就是单层的神经网络。

模型训练和实验

文章官网World Models,有gif演示,而且可以试玩模型"梦中"的游戏。

训练

两个实验都是先单独训练世界模型 (无监督):

  1. 使用随机策略收集一系列的游戏图像。
  2. 使用这些图像训练好VAE。
  3. 在训练好的VAE基础上,训练好MDN-RNN。

之后部署世界模型并训练控制器。两个实验的主要区别在部署:

  • Car Racing实验:直接在实际环境部署,训练好了控制器之后,又给出了在模型"梦中"的模拟。

  • VizDoom实验:先在"梦中"部署,训练好了控制器之后,再将整个模型转移到实际环境查看效果。

NOTE:在两个实验中,世界模型都没有建模环境的奖励。第一个实验中,奖励只在训练控制器的时候由实际环境给出;第二个实验中,指标是存活时间,不需要奖励。

REMARK:训练成功之后,模型实际上成为了游戏的"模拟器",学习到了游戏逻辑 (角色中弹后会重新开始)、敌人行为 (按一定时间间隔发射子弹)、物理机制 (子弹飞行速度)等。

实验

Car Racing:

VizDoom:

消融实验1 -- 视觉模块+记忆模块的优越性

在Car Racing中,消融实验显示,单独的视觉模块效果不如一整个的世界模型 (但是也已经超过了DQN和A3C)

消融实验2 -- 用tau调整随机性

在VizDoom实验中,由于模型并非完全精确,控制器可能会利用模型的缺陷来在模拟器中达到高分,一旦部署到实际环境,控制器就不行了。

为了防止这一点,MDN-RNN预测的是具有随机性的环境,并通过调整不确定性参数\(\tau\)来控制随机性。在实验中,\(\tau=1.15\)时效果最好。

当\(\tau=0.1\)时,模型几乎是确定性的,这时候敌人甚至无法发射子弹,所以出现了在模拟器中非常高分,实际环境中却非常低分的情况。

跑分对比实验

  • Car Racing实验:取得的分数超过了先前的基于深度强化学习的方法,如DQN、A3C.
  • VizDoom实验:在梦中学会了如何躲避怪物的子弹,部署到实际环境后的存活时长也超过了先前。

迭代训练过程

本文的实验环境简单,所以是使用随机策略采样,分别训练三个模块。面对更复杂的任务,可能需要三个模块一起训练,但是本文只是提了一下记忆模块和控制器一起训练的流程:

三个模块一起训练的好处是:

  1. 视觉模块会倾向于学习到有利于当前任务的特征。
  2. 记忆模块可以对控制器进行学习,控制器又可以基于记忆模块继续改进,如此往复。
  3. 可以使用训练中的控制器进行轨迹采样而不是随机策略。

Learning Latent Dynamics for Planning from Pixels (2019)

相对于上一篇,这篇的改进:

  1. 假定了环境是部分可观测马尔可夫决策过程 (POMDP),世界模型就是在学习这个POMDP.
  2. 给出了一套结合模型预测控制 (MPC) 方法的训练过程 ------ Deep Planning Network (PlaNet).
  3. 提出基于确定性和随机性结合的状态空间模型 (RSSM),而不是仅有确定性状态的RNN和仅有随机性状态的SSM.
  4. 给出了适用于多步预测的变分推断方法 ------ latent overshooting.

Problem setup

假定实际的环境是一个POMDP:

目标是学习到一个策略,能够最大化期望累积回报\(\mathbb E[\sum r_t]\)。

Deep planning network

这里先讲世界模型+MPC的学习和规划算法。

while循环内部,总体上分成三个部分:模型学习,实时规划+数据收集,更新数据库。

模型学习

从数据库中随机抽取观测序列的小批量,然后使用梯度方法学习。

实时规划+数据收集

总体上就是一个有限时间域的MPC框架,在每个time step按三步走:

  1. Observe:获得当前时刻的状态。由于这里在隐状态空间进行规划,所以需要从历史的观测数据中推断当前状态 (通过隐变量的后验概率)。
  2. Predict and plan:利用当前学习到的模型,解一个有限时间域的最优控制问题,获得一串动作序列。本文中的planner使用的是cross entropy method (CEM).
  3. Act :对环境使用这串动作序列的第一个动作\(a_t\),移动到下一个time step. 这里用了一个trick,把取得的动作\(a_t\)重复了\(R\)次 (用相同的action,连续走了\(R\)步),取reward的总和作为当前时刻的reward,取最终的第\(R\)观测\(o_{t+1}^R\)作为下一个时刻的观测\(o_{t+1}\)。

更新数据库

将上一个部分收集到的观测序列加入到数据库中,以供世界模型的进一步更新。

NOTE:相对于model-free RL算法,model-based planning的一大优势就是数据利用率提高了。体现在planning取得的观测序列可以反复用于世界模型的学习。

RSSM

这种模型也叫:Non-linear Kalman filter, sequential VAE, deep variational bayes filter,看了一眼相关的文章,好像要从头到尾讲明白 (像VAE那样) 比较复杂。

这里浅浅讲一下世界模型的结构以及训练的Loss。

Latent state-space model

使用下面的encoder来近似后验概率:

\[q(s_{\le t} \mid o_{\le t},a_{<t}) = \prod_{t=1}^T q(s_t\mid s_{t-1},a_{t-1},o_t) \]

都使用神经网络参数化的高斯分布表示,其中observation model和encoder用的是卷积网络。

Training Objective

通过最大化log Evidence来训练:

\[\arg\max \ln p(o_{\le t} \mid a_{<t}) \]

接下来推导ELBO.

先拆成边际化的形式

\[\ln p(o_{\le T} \mid a_{<T}) = \ln \int p(o_{\le T}, s_{\le T} \mid a_{<T}) \text{d}s \]

把联合概率拆开

\[\ln p(o_{\le T} \mid a_{<T}) = \ln \int p(o_{\le T} \mid s_{\le T}, a_{<T}) p(s_{\le T} \mid a_{<T}) \text{d}s \]

写成期望的形式

\[\ln p(o_{\le t} \mid a_{<t}) = \ln \mathbb E_{p(s_{\le t}\mid a_{<t})}[ p(o_{\le t} \mid s_{\le t}, a_{<t}) ] \]

利用重要性采样方法,转变成从encoder采样

\[\ln p(o_{\le t} \mid a_{<t}) = \ln \mathbb E_{q(s_{\le t}\mid o_{\le t},a_{<t})}[ p(o_{\le t} \mid s_{\le t}, a_{<t}) p(s_{\le t}\mid a_{<t}) / q(s_{\le t}\mid o_{\le t},a_{<t})] \]

链式分解,并利用模型的条件独立性化简 (概率图参考下面的)

\[\ln p(o_{\le t} \mid a_{<t}) = \ln \mathbb E_{q(s_{\le t}\mid o_{\le t},a_{<t})}[\prod p(o_t \mid s_t) p(s_t \mid s_{t-1},a_{t-1}) / q(s_{\le t}\mid o_{\le t},a_{<t})] \]

根据Jensen不等式,\(\ln \mathbb E[x] \ge \mathbb E[\ln(x)]\)

\[\ln p(o_{\le t} \mid a_{<t}) \ge \mathbb E_{q(s_{\le t}\mid o_{\le t},a_{<t})}[\sum_t \ln p(o_t \mid s_t) + \ln p(s_t \mid s_{t-1},a_{t-1}) - \ln q(s_{\le t}\mid o_{\le t},a_{<t})] \]

右边可以写成reconstruction + KL的形式,最后就是

确定性和随机性结合 - RSSM

  • 纯确定性的世界模型:模型难以预测多种可能的未来情况;容易被planner利用模型缺陷 (在World Models中,通过MDN添加随机性来缓解这一点,但本质还是确定性的)
  • 纯随机性的世界模型:模型难以记住信息,导致产生前后不一致的预测结果。

所以作者考虑将确定性和随机性结合,称这种结构为RSSM.

相对于上一篇,把记忆模块换成了RSSM。

Latent Overshooting

之前讨论的都是\(s_t \to s_{t+1}\)的单步预测,如果每次单步预测都准确无误,那多步预测肯定也没问题。但是由于模型本身有局限,所以不一定能很好的推广到多部预测。

于是作者考虑了直接进行跨步的预测,先通过对中间几步隐变量边际化得到了跨步预测的转移

并且推导了针对跨步预测的变分bound

把考虑不同的步幅\(d\),求和,就得到latent overshooting的目标函数

实验结果

DeepMind control suite环境:图像作为观测,连续动作空间。

消融实验

  • 验证PlaNet的数据收集过程有优势。Random Collection指的是用随机策略收集数据而不是通过MPC;Random shooting指的是使用了MPC框架,但是不使用CEM,而是直接从1000条随机采的动作序列里选最好的那条。最后PlaNet在大部分情况都明显好于另外两种。

  • RSSM和SSM、GRU的对比。观察到RSSM明显好于后两者,表明了确定性+随机性结合的优势。

  • 是否加入latent overshooting作为变分目标。观察到Latent overshooting使RSSM的表现轻微变差,但是在一些任务上让DRNN的表现变好了。

跑分对比实验

  • PlaNet的分数能打败A3C。
  • PlaNet的分数总体不如D4PG,但是大部分任务相差不多。
  • PlaNet在所有任务上,数据利用率都好于D4PG.
  • PlaNet (CEM + 世界模型) 和 CEM + true simulator对比只差了一些,体现出世界模型较好地学习到了环境。

六个任务一起训练

每次循环中,agent面对的可能是不同的环境,所以数据库中抽取出来的轨迹也是打乱的。

最后跑分不如单独训练,但是体现出了agent能够自己判断出面对的是哪个任务了。

代码选讲

代码来自:Kaixhin/PlaNet: Deep Planning Network: Control from pixels by latent planning with learned dynamics

主要是看看transition model和模型训练过程。解释都在注释里,有部分注释是代码库原有的。

Transition model

python 复制代码
class TransitionModel(jit.ScriptModule):
  __constants__ = ['min_std_dev']

  def __init__(self, belief_size, state_size, action_size, hidden_size, embedding_size, activation_function='relu', min_std_dev=0.1):
    super().__init__()
    self.act_fn = getattr(F, activation_function)
    self.min_std_dev = min_std_dev
    self.fc_embed_state_action = nn.Linear(state_size + action_size, belief_size) # combine s_t and a_t to comb(s_t, a_t)
    self.rnn = nn.GRUCell(belief_size, belief_size) # from comb(s_t, a_t), h_t to h_t+1
    self.fc_embed_belief_prior = nn.Linear(belief_size, hidden_size) # from h_t to z_t
    self.fc_state_prior = nn.Linear(hidden_size, 2 * state_size) # parameterized prior of s_t, from z_t to mean and std
    self.fc_embed_belief_posterior = nn.Linear(belief_size + embedding_size, hidden_size) # from h_t and e_t to z_t
    self.fc_state_posterior = nn.Linear(hidden_size, 2 * state_size) # parameterized posterior of s_t, from z_t to mean and std

  # Operates over (previous) state, (previous) actions, (previous) belief, (previous) nonterminals (mask), and (current) observations
  # Diagram of expected inputs and outputs for T = 5 (-x- signifying beginning of output belief/state that gets sliced off):
  # t :  0  1  2  3  4  5
  # o :    -X--X--X--X--X-  设置了初始的隐状态是None,所以不考虑0时刻的obs
  # a : -X--X--X--X--X-     不考虑最后一个action,因为最后一个action没有后续的obs
  # n : -X--X--X--X--X-
  # pb: -X-
  # ps: -X-
  # b : -x--X--X--X--X--X-
  # s : -x--X--X--X--X--X-

  # 输入的shape都是(time_step, batch_size, *)
  @jit.script_method
  def forward(self, prev_state:torch.Tensor, actions:torch.Tensor, prev_belief:torch.Tensor, observations:Optional[torch.Tensor]=None, nonterminals:Optional[torch.Tensor]=None) -> List[torch.Tensor]:
    # 后面都是动态更新,为了保留grad,不能使用单个tensor作为buffer,所以创建了几个list
    T = actions.size(0) + 1 # 实际需要的list长度,参考上面的图
    beliefs, prior_states, prior_means, prior_std_devs, posterior_states, posterior_means, posterior_std_devs = \
      [torch.empty(0)] * T, [torch.empty(0)] * T, [torch.empty(0)] * T, [torch.empty(0)] * T, [torch.empty(0)] * T, [torch.empty(0)] * T, [torch.empty(0)] * T
    beliefs[0], prior_states[0], posterior_states[0] = prev_belief, prev_state, prev_state # 0时刻赋初值

    # 每次循环开始,是已知t时刻的信息,进一步计算t+1时刻的信息
    for t in range(T - 1):
      # 根据情况合适的s,因为模型可以在脱离observations的情况下自己预测
      # 如果observations为None,则使用先验状态 (模型一步步生成出来的),否则使用后验状态 (根据历史的obs和action推断出来的)
      _state = prior_states[t] if observations is None else posterior_states[t] 
      # terminal则说明这段序列已经结束了,所以把状态mask掉 (就是0)
      _state = _state if nonterminals is None else _state * nonterminals[t]  

      # 注意下面每一块的hidden是临时变量,表示的是不同的意思

      # 计算确定性隐状态h = f(s_t, a_t, h_t)
      hidden = self.act_fn(self.fc_embed_state_action(torch.cat([_state, actions[t]], dim=1))) # s和a先拼在一起
      beliefs[t + 1] = self.rnn(hidden, beliefs[t]) # 对应概率图中从s,a,h到h的实线

      # 计算隐状态s的先验 p(s_t|s_t-1,a_t-1)
      hidden = self.act_fn(self.fc_embed_belief_prior(beliefs[t + 1])) # 对应概率图中从h到s的实线
      prior_means[t + 1], _prior_std_dev = torch.chunk(self.fc_state_prior(hidden), 2, dim=1)
      prior_std_devs[t + 1] = F.softplus(_prior_std_dev) + self.min_std_dev # Trick: 使用softplus来保证std_devs为正,并且使用min_std_dev来保证std_devs不会太小
      prior_states[t + 1] = prior_means[t + 1] + prior_std_devs[t + 1] * torch.randn_like(prior_means[t + 1])     

      # 计算隐状态s的后验 q(s_t|o≤t,a<t)
      if observations is not None: # 只有observations不为None时,才计算后验
        t_ = t - 1  # 这是实现的问题,因为传进来的是obs[1:],所以应该用t_+1才能索引到对应的obs
        hidden = self.act_fn(self.fc_embed_belief_posterior(torch.cat([beliefs[t + 1], observations[t_ + 1]], dim=1))) # 对应概率图中的两条虚线
        posterior_means[t + 1], _posterior_std_dev = torch.chunk(self.fc_state_posterior(hidden), 2, dim=1)
        posterior_std_devs[t + 1] = F.softplus(_posterior_std_dev) + self.min_std_dev
        posterior_states[t + 1] = posterior_means[t + 1] + posterior_std_devs[t + 1] * torch.randn_like(posterior_means[t + 1])

    # 返回h,s,以及先验和后验的均值和方差
    hidden = [torch.stack(beliefs[1:], dim=0), torch.stack(prior_states[1:], dim=0), torch.stack(prior_means[1:], dim=0), torch.stack(prior_std_devs[1:], dim=0)]
    if observations is not None:
      hidden += [torch.stack(posterior_states[1:], dim=0), torch.stack(posterior_means[1:], dim=0), torch.stack(posterior_std_devs[1:], dim=0)]
    return hidden

世界模型训练

只截取了一小部分,重点看loss func是如何计算的。

python 复制代码
  # Model fitting
  losses = []
  for s in tqdm(range(args.collect_interval)):
    # Draw sequence chunks {(o_t, a_t, r_t+1, terminal_t+1)} ~ D uniformly at random from the dataset (including terminal flags)
    observations, actions, rewards, nonterminals = D.sample(args.batch_size, args.chunk_size)  # Transitions start at time t = 0

    # Create initial belief and state for time t = 0
    init_belief, init_state = torch.zeros(args.batch_size, args.belief_size, device=args.device), torch.zeros(args.batch_size, args.state_size, device=args.device)

    # Update belief/state using posterior from previous belief/state, previous action and current observation (over entire sequence at once)
    # 一次把整个隐状态序列全部计算出来
    beliefs, prior_states, prior_means, prior_std_devs, posterior_states, posterior_means, posterior_std_devs =\
      transition_model(init_state, actions[:-1], init_belief, bottle(encoder, (observations[1:], )), nonterminals[:-1])

    # Calculate observation likelihood, reward likelihood and KL losses (for t = 0 only for latent overshooting); sum over final dims, average over batch and time (original implementation, though paper seems to miss 1/T scaling?)
    # Reconstruction loss都使用MSE
    # mean(dim=(0, 1))对batch和time进行平均
    observation_loss =\
      F.mse_loss(bottle(observation_model, (beliefs, posterior_states)), observations[1:], reduction='none').sum(dim=2 if args.symbolic_env else (2, 3, 4)).mean(dim=(0, 1))
    reward_loss =\
      F.mse_loss(bottle(reward_model, (beliefs, posterior_states)), rewards[:-1], reduction='none').mean(dim=(0, 1))
    # KL loss, 计算了后验q(s_t|o≤t,a<t)和先验p(s_t|s_t-1,a_t-1)的KL散度
    kl_loss =\
      torch.max(kl_divergence(Normal(posterior_means, posterior_std_devs), Normal(prior_means, prior_std_devs)).sum(dim=2), free_nats).mean(dim=(0, 1))  # Note that normalisation by overshooting distance and weighting by overshooting distance cancel out

"""
后面的部分略
"""
相关推荐
KeyPan6 小时前
【机器学习:十九、反向传播】
人工智能·深度学习·机器学习
m0_743106468 小时前
【论文笔记】多个大规模数据集上的SOTA绝对位姿回归方法:Reloc3r
论文阅读·深度学习·计算机视觉·3d·几何学
HyperAI超神经10 小时前
微软与腾讯技术交锋,TRELLIS引领3D生成领域多格式支持新方向
人工智能·深度学习·机器学习·计算机视觉·3d·大模型·数据集
goomind12 小时前
DeepFM模型介绍
深度学习·dnn·推荐系统·deepfm
shichaog13 小时前
第四章 神经网络声码器
人工智能·深度学习·神经网络·语音合成·声码器
KeyPan13 小时前
【Ubuntu与Linux操作系统:一、Ubuntu安装与基本使用】
linux·运维·服务器·人工智能·深度学习·ubuntu·机器学习
m0_6786933314 小时前
深度学习笔记11-优化器对比实验(Tensorflow)
笔记·深度学习·tensorflow
亲持红叶14 小时前
Chapter5.1 Evaluating generative text models
人工智能·python·gpt·深度学习·自然语言处理
雾隐隐o15 小时前
基于深度学习的滑块验证破解方法及模型训练过程
人工智能·深度学习
羊小猪~~15 小时前
EDA数据分析结合深度学习---基于EDA数据分析和MLP模型的天气预测(tensorflow实现)
pytorch·python·深度学习·机器学习·数据挖掘·数据分析·tensorflow