参考资料:
- [2411.14499] Understanding World or Predicting Future? A Comprehensive Survey of World Models
- [1803.10122] World Models
- Learning Latent Dynamics for Planning from Pixels
- Kaixhin/PlaNet: Deep Planning Network: Control from pixels by latent planning with learned dynamics
世界模型
简介
世界模型 :一种理解世界当前状态 或预测其未来动态的工具。
世界模型的两个主要功能:
- 构建内部表征以理解世界运作机制。
- 预测未来状态以模拟和指导决策。
分类
图片来自[2411.14499] Understanding World or Predicting Future? A Comprehensive Survey of World Models.
- 作者按照模型的侧重点不同,将世界模型分成两个大类,即:
- Internal Representations.
- Future Predictions.
- 经常能在网上刷到的LeCun力推世界模型,说的是JEPA.
- 左边分支的世界模型也可以做"future prediction",作为学习模型参数过程的一个副产物吧 (视觉模块的reconstruction)。
这里讨论的是两篇world model for decision-making的文章。
World Models (2018)
AI社区中,首篇系统性介绍世界模型的文章。
人类的心理模型
简单可以概括成以下几点:
- 对于外部世界的大量信息流,人脑能够学习到外部世界时空信息的抽象表示,作为我们对外部世界的"建模"。
- 我们所看到一切都基于脑中模型对未来的预测。
- 我们能够基于这个预测模型本能地行动,在面对危险时做出快速的反射性行为。
打棒球的例子: 击球手需要在毫秒级别的时间内决定如何挥棒 ------ 这比视觉信号到大脑的时间还要短。
在之后的世界模型结构和实验中,都可以看到这个心理模型的影子。
模型结构
世界模型主要由两个模块组成:视觉模块、记忆模块。
-
视觉模块:将外部世界的高维观测,压缩成低维的特征。
-
记忆模块:整合历史信息,预测未来。
控制器会利用世界模型给出的信息进行决策。
视觉模块
作者在文章中使用VAE的Encoder部分作为视觉模块。
记忆模块
作者在文章中使用MDN-RNN作为记忆模块。
- MDN指的是mixture density networks,就是一个建模混合模型的网络,文中使用的是高斯混合模型 (GMM),此时神经网络除了输出每个高斯分布的均值和标准差,还需要输出用于选择高斯分布的类别分布。
- MDN会接受一个temperature参数\(\tau\),用于调整不确定性。
- 在图中,MDN-RNN建模的是\(P(z_{t+1}\mid a_t, z_t, h_t)\).
- 除了隐状态之外,记忆模块可能还需要建模其他东西,比如奖励\(P(r_{t+1} \mid a_t, z_t, h_t)\),游戏结束的信号\(P(\text{done}_{t+1} \mid a_t, z_t, h_t)\).
NOTE:为什么要使用混合模型,即使VAE的隐变量空间只是一个对角高斯?作者的解释是:混合模型中的离散部分 (选择哪一个高斯组分),有利于建模环境中的离散随机事件。比如说NPC在平静状态和警觉状态下的表现不同。
控制器
作者将整个模型的复杂性都集中到了视觉和记忆模块,有意使得控制器的结构尽可能简单:
\[a_t = W_c[z_t~~h_t] + b_c \]
就是单层的神经网络。
模型训练和实验
文章官网World Models,有gif演示,而且可以试玩模型"梦中"的游戏。
训练
两个实验都是先单独训练世界模型 (无监督):
- 使用随机策略收集一系列的游戏图像。
- 使用这些图像训练好VAE。
- 在训练好的VAE基础上,训练好MDN-RNN。
之后部署世界模型并训练控制器。两个实验的主要区别在部署:
-
Car Racing实验:直接在实际环境部署,训练好了控制器之后,又给出了在模型"梦中"的模拟。
-
VizDoom实验:先在"梦中"部署,训练好了控制器之后,再将整个模型转移到实际环境查看效果。
NOTE:在两个实验中,世界模型都没有建模环境的奖励。第一个实验中,奖励只在训练控制器的时候由实际环境给出;第二个实验中,指标是存活时间,不需要奖励。
REMARK:训练成功之后,模型实际上成为了游戏的"模拟器",学习到了游戏逻辑 (角色中弹后会重新开始)、敌人行为 (按一定时间间隔发射子弹)、物理机制 (子弹飞行速度)等。
实验
Car Racing:
VizDoom:
消融实验1 -- 视觉模块+记忆模块的优越性
在Car Racing中,消融实验显示,单独的视觉模块效果不如一整个的世界模型 (但是也已经超过了DQN和A3C)
消融实验2 -- 用tau调整随机性
在VizDoom实验中,由于模型并非完全精确,控制器可能会利用模型的缺陷来在模拟器中达到高分,一旦部署到实际环境,控制器就不行了。
为了防止这一点,MDN-RNN预测的是具有随机性的环境,并通过调整不确定性参数\(\tau\)来控制随机性。在实验中,\(\tau=1.15\)时效果最好。
当\(\tau=0.1\)时,模型几乎是确定性的,这时候敌人甚至无法发射子弹,所以出现了在模拟器中非常高分,实际环境中却非常低分的情况。
跑分对比实验
- Car Racing实验:取得的分数超过了先前的基于深度强化学习的方法,如DQN、A3C.
- VizDoom实验:在梦中学会了如何躲避怪物的子弹,部署到实际环境后的存活时长也超过了先前。
迭代训练过程
本文的实验环境简单,所以是使用随机策略采样,分别训练三个模块。面对更复杂的任务,可能需要三个模块一起训练,但是本文只是提了一下记忆模块和控制器一起训练的流程:
三个模块一起训练的好处是:
- 视觉模块会倾向于学习到有利于当前任务的特征。
- 记忆模块可以对控制器进行学习,控制器又可以基于记忆模块继续改进,如此往复。
- 可以使用训练中的控制器进行轨迹采样而不是随机策略。
Learning Latent Dynamics for Planning from Pixels (2019)
相对于上一篇,这篇的改进:
- 假定了环境是部分可观测马尔可夫决策过程 (POMDP),世界模型就是在学习这个POMDP.
- 给出了一套结合模型预测控制 (MPC) 方法的训练过程 ------ Deep Planning Network (PlaNet).
- 提出基于确定性和随机性结合的状态空间模型 (RSSM),而不是仅有确定性状态的RNN和仅有随机性状态的SSM.
- 给出了适用于多步预测的变分推断方法 ------ latent overshooting.
Problem setup
假定实际的环境是一个POMDP:
目标是学习到一个策略,能够最大化期望累积回报\(\mathbb E[\sum r_t]\)。
Deep planning network
这里先讲世界模型+MPC的学习和规划算法。
while循环内部,总体上分成三个部分:模型学习,实时规划+数据收集,更新数据库。
模型学习
从数据库中随机抽取观测序列的小批量,然后使用梯度方法学习。
实时规划+数据收集
总体上就是一个有限时间域的MPC框架,在每个time step按三步走:
- Observe:获得当前时刻的状态。由于这里在隐状态空间进行规划,所以需要从历史的观测数据中推断当前状态 (通过隐变量的后验概率)。
- Predict and plan:利用当前学习到的模型,解一个有限时间域的最优控制问题,获得一串动作序列。本文中的planner使用的是cross entropy method (CEM).
- Act :对环境使用这串动作序列的第一个动作\(a_t\),移动到下一个time step. 这里用了一个trick,把取得的动作\(a_t\)重复了\(R\)次 (用相同的action,连续走了\(R\)步),取reward的总和作为当前时刻的reward,取最终的第\(R\)观测\(o_{t+1}^R\)作为下一个时刻的观测\(o_{t+1}\)。
更新数据库
将上一个部分收集到的观测序列加入到数据库中,以供世界模型的进一步更新。
NOTE:相对于model-free RL算法,model-based planning的一大优势就是数据利用率提高了。体现在planning取得的观测序列可以反复用于世界模型的学习。
RSSM
这种模型也叫:Non-linear Kalman filter, sequential VAE, deep variational bayes filter,看了一眼相关的文章,好像要从头到尾讲明白 (像VAE那样) 比较复杂。
这里浅浅讲一下世界模型的结构以及训练的Loss。
Latent state-space model
使用下面的encoder来近似后验概率:
\[q(s_{\le t} \mid o_{\le t},a_{<t}) = \prod_{t=1}^T q(s_t\mid s_{t-1},a_{t-1},o_t) \]
都使用神经网络参数化的高斯分布表示,其中observation model和encoder用的是卷积网络。
Training Objective
通过最大化log Evidence来训练:
\[\arg\max \ln p(o_{\le t} \mid a_{<t}) \]
接下来推导ELBO.
先拆成边际化的形式
\[\ln p(o_{\le T} \mid a_{<T}) = \ln \int p(o_{\le T}, s_{\le T} \mid a_{<T}) \text{d}s \]
把联合概率拆开
\[\ln p(o_{\le T} \mid a_{<T}) = \ln \int p(o_{\le T} \mid s_{\le T}, a_{<T}) p(s_{\le T} \mid a_{<T}) \text{d}s \]
写成期望的形式
\[\ln p(o_{\le t} \mid a_{<t}) = \ln \mathbb E_{p(s_{\le t}\mid a_{<t})}[ p(o_{\le t} \mid s_{\le t}, a_{<t}) ] \]
利用重要性采样方法,转变成从encoder采样
\[\ln p(o_{\le t} \mid a_{<t}) = \ln \mathbb E_{q(s_{\le t}\mid o_{\le t},a_{<t})}[ p(o_{\le t} \mid s_{\le t}, a_{<t}) p(s_{\le t}\mid a_{<t}) / q(s_{\le t}\mid o_{\le t},a_{<t})] \]
链式分解,并利用模型的条件独立性化简 (概率图参考下面的)
\[\ln p(o_{\le t} \mid a_{<t}) = \ln \mathbb E_{q(s_{\le t}\mid o_{\le t},a_{<t})}[\prod p(o_t \mid s_t) p(s_t \mid s_{t-1},a_{t-1}) / q(s_{\le t}\mid o_{\le t},a_{<t})] \]
根据Jensen不等式,\(\ln \mathbb E[x] \ge \mathbb E[\ln(x)]\)
\[\ln p(o_{\le t} \mid a_{<t}) \ge \mathbb E_{q(s_{\le t}\mid o_{\le t},a_{<t})}[\sum_t \ln p(o_t \mid s_t) + \ln p(s_t \mid s_{t-1},a_{t-1}) - \ln q(s_{\le t}\mid o_{\le t},a_{<t})] \]
右边可以写成reconstruction + KL的形式,最后就是
确定性和随机性结合 - RSSM
- 纯确定性的世界模型:模型难以预测多种可能的未来情况;容易被planner利用模型缺陷 (在World Models中,通过MDN添加随机性来缓解这一点,但本质还是确定性的)
- 纯随机性的世界模型:模型难以记住信息,导致产生前后不一致的预测结果。
所以作者考虑将确定性和随机性结合,称这种结构为RSSM.
相对于上一篇,把记忆模块换成了RSSM。
Latent Overshooting
之前讨论的都是\(s_t \to s_{t+1}\)的单步预测,如果每次单步预测都准确无误,那多步预测肯定也没问题。但是由于模型本身有局限,所以不一定能很好的推广到多部预测。
于是作者考虑了直接进行跨步的预测,先通过对中间几步隐变量边际化得到了跨步预测的转移
并且推导了针对跨步预测的变分bound
把考虑不同的步幅\(d\),求和,就得到latent overshooting的目标函数
实验结果
DeepMind control suite环境:图像作为观测,连续动作空间。
消融实验
-
验证PlaNet的数据收集过程有优势。Random Collection指的是用随机策略收集数据而不是通过MPC;Random shooting指的是使用了MPC框架,但是不使用CEM,而是直接从1000条随机采的动作序列里选最好的那条。最后PlaNet在大部分情况都明显好于另外两种。
-
RSSM和SSM、GRU的对比。观察到RSSM明显好于后两者,表明了确定性+随机性结合的优势。
-
是否加入latent overshooting作为变分目标。观察到Latent overshooting使RSSM的表现轻微变差,但是在一些任务上让DRNN的表现变好了。
跑分对比实验
- PlaNet的分数能打败A3C。
- PlaNet的分数总体不如D4PG,但是大部分任务相差不多。
- PlaNet在所有任务上,数据利用率都好于D4PG.
- PlaNet (CEM + 世界模型) 和 CEM + true simulator对比只差了一些,体现出世界模型较好地学习到了环境。
六个任务一起训练
每次循环中,agent面对的可能是不同的环境,所以数据库中抽取出来的轨迹也是打乱的。
最后跑分不如单独训练,但是体现出了agent能够自己判断出面对的是哪个任务了。
代码选讲
代码来自:Kaixhin/PlaNet: Deep Planning Network: Control from pixels by latent planning with learned dynamics
主要是看看transition model和模型训练过程。解释都在注释里,有部分注释是代码库原有的。
Transition model
python
class TransitionModel(jit.ScriptModule):
__constants__ = ['min_std_dev']
def __init__(self, belief_size, state_size, action_size, hidden_size, embedding_size, activation_function='relu', min_std_dev=0.1):
super().__init__()
self.act_fn = getattr(F, activation_function)
self.min_std_dev = min_std_dev
self.fc_embed_state_action = nn.Linear(state_size + action_size, belief_size) # combine s_t and a_t to comb(s_t, a_t)
self.rnn = nn.GRUCell(belief_size, belief_size) # from comb(s_t, a_t), h_t to h_t+1
self.fc_embed_belief_prior = nn.Linear(belief_size, hidden_size) # from h_t to z_t
self.fc_state_prior = nn.Linear(hidden_size, 2 * state_size) # parameterized prior of s_t, from z_t to mean and std
self.fc_embed_belief_posterior = nn.Linear(belief_size + embedding_size, hidden_size) # from h_t and e_t to z_t
self.fc_state_posterior = nn.Linear(hidden_size, 2 * state_size) # parameterized posterior of s_t, from z_t to mean and std
# Operates over (previous) state, (previous) actions, (previous) belief, (previous) nonterminals (mask), and (current) observations
# Diagram of expected inputs and outputs for T = 5 (-x- signifying beginning of output belief/state that gets sliced off):
# t : 0 1 2 3 4 5
# o : -X--X--X--X--X- 设置了初始的隐状态是None,所以不考虑0时刻的obs
# a : -X--X--X--X--X- 不考虑最后一个action,因为最后一个action没有后续的obs
# n : -X--X--X--X--X-
# pb: -X-
# ps: -X-
# b : -x--X--X--X--X--X-
# s : -x--X--X--X--X--X-
# 输入的shape都是(time_step, batch_size, *)
@jit.script_method
def forward(self, prev_state:torch.Tensor, actions:torch.Tensor, prev_belief:torch.Tensor, observations:Optional[torch.Tensor]=None, nonterminals:Optional[torch.Tensor]=None) -> List[torch.Tensor]:
# 后面都是动态更新,为了保留grad,不能使用单个tensor作为buffer,所以创建了几个list
T = actions.size(0) + 1 # 实际需要的list长度,参考上面的图
beliefs, prior_states, prior_means, prior_std_devs, posterior_states, posterior_means, posterior_std_devs = \
[torch.empty(0)] * T, [torch.empty(0)] * T, [torch.empty(0)] * T, [torch.empty(0)] * T, [torch.empty(0)] * T, [torch.empty(0)] * T, [torch.empty(0)] * T
beliefs[0], prior_states[0], posterior_states[0] = prev_belief, prev_state, prev_state # 0时刻赋初值
# 每次循环开始,是已知t时刻的信息,进一步计算t+1时刻的信息
for t in range(T - 1):
# 根据情况合适的s,因为模型可以在脱离observations的情况下自己预测
# 如果observations为None,则使用先验状态 (模型一步步生成出来的),否则使用后验状态 (根据历史的obs和action推断出来的)
_state = prior_states[t] if observations is None else posterior_states[t]
# terminal则说明这段序列已经结束了,所以把状态mask掉 (就是0)
_state = _state if nonterminals is None else _state * nonterminals[t]
# 注意下面每一块的hidden是临时变量,表示的是不同的意思
# 计算确定性隐状态h = f(s_t, a_t, h_t)
hidden = self.act_fn(self.fc_embed_state_action(torch.cat([_state, actions[t]], dim=1))) # s和a先拼在一起
beliefs[t + 1] = self.rnn(hidden, beliefs[t]) # 对应概率图中从s,a,h到h的实线
# 计算隐状态s的先验 p(s_t|s_t-1,a_t-1)
hidden = self.act_fn(self.fc_embed_belief_prior(beliefs[t + 1])) # 对应概率图中从h到s的实线
prior_means[t + 1], _prior_std_dev = torch.chunk(self.fc_state_prior(hidden), 2, dim=1)
prior_std_devs[t + 1] = F.softplus(_prior_std_dev) + self.min_std_dev # Trick: 使用softplus来保证std_devs为正,并且使用min_std_dev来保证std_devs不会太小
prior_states[t + 1] = prior_means[t + 1] + prior_std_devs[t + 1] * torch.randn_like(prior_means[t + 1])
# 计算隐状态s的后验 q(s_t|o≤t,a<t)
if observations is not None: # 只有observations不为None时,才计算后验
t_ = t - 1 # 这是实现的问题,因为传进来的是obs[1:],所以应该用t_+1才能索引到对应的obs
hidden = self.act_fn(self.fc_embed_belief_posterior(torch.cat([beliefs[t + 1], observations[t_ + 1]], dim=1))) # 对应概率图中的两条虚线
posterior_means[t + 1], _posterior_std_dev = torch.chunk(self.fc_state_posterior(hidden), 2, dim=1)
posterior_std_devs[t + 1] = F.softplus(_posterior_std_dev) + self.min_std_dev
posterior_states[t + 1] = posterior_means[t + 1] + posterior_std_devs[t + 1] * torch.randn_like(posterior_means[t + 1])
# 返回h,s,以及先验和后验的均值和方差
hidden = [torch.stack(beliefs[1:], dim=0), torch.stack(prior_states[1:], dim=0), torch.stack(prior_means[1:], dim=0), torch.stack(prior_std_devs[1:], dim=0)]
if observations is not None:
hidden += [torch.stack(posterior_states[1:], dim=0), torch.stack(posterior_means[1:], dim=0), torch.stack(posterior_std_devs[1:], dim=0)]
return hidden
世界模型训练
只截取了一小部分,重点看loss func是如何计算的。
python
# Model fitting
losses = []
for s in tqdm(range(args.collect_interval)):
# Draw sequence chunks {(o_t, a_t, r_t+1, terminal_t+1)} ~ D uniformly at random from the dataset (including terminal flags)
observations, actions, rewards, nonterminals = D.sample(args.batch_size, args.chunk_size) # Transitions start at time t = 0
# Create initial belief and state for time t = 0
init_belief, init_state = torch.zeros(args.batch_size, args.belief_size, device=args.device), torch.zeros(args.batch_size, args.state_size, device=args.device)
# Update belief/state using posterior from previous belief/state, previous action and current observation (over entire sequence at once)
# 一次把整个隐状态序列全部计算出来
beliefs, prior_states, prior_means, prior_std_devs, posterior_states, posterior_means, posterior_std_devs =\
transition_model(init_state, actions[:-1], init_belief, bottle(encoder, (observations[1:], )), nonterminals[:-1])
# Calculate observation likelihood, reward likelihood and KL losses (for t = 0 only for latent overshooting); sum over final dims, average over batch and time (original implementation, though paper seems to miss 1/T scaling?)
# Reconstruction loss都使用MSE
# mean(dim=(0, 1))对batch和time进行平均
observation_loss =\
F.mse_loss(bottle(observation_model, (beliefs, posterior_states)), observations[1:], reduction='none').sum(dim=2 if args.symbolic_env else (2, 3, 4)).mean(dim=(0, 1))
reward_loss =\
F.mse_loss(bottle(reward_model, (beliefs, posterior_states)), rewards[:-1], reduction='none').mean(dim=(0, 1))
# KL loss, 计算了后验q(s_t|o≤t,a<t)和先验p(s_t|s_t-1,a_t-1)的KL散度
kl_loss =\
torch.max(kl_divergence(Normal(posterior_means, posterior_std_devs), Normal(prior_means, prior_std_devs)).sum(dim=2), free_nats).mean(dim=(0, 1)) # Note that normalisation by overshooting distance and weighting by overshooting distance cancel out
"""
后面的部分略
"""