Offline RL : Efficient Planning in a Compact Latent Action Space

ICLR 2023
paper

Intro

采用Transformer架构的Planning方法对马尔可夫序列重构,(et. TT)在面对高维状态动作空间,容易面对计算复杂度高的问题。本文提出TAP算法,基于Transformer的VQ-VAE,利用提取的状态动作在隐空间的低微特征进行Planning,然后使用latent codes经过decoder得到重构序列,在Offline下取较好的结果。

Method

VQ-VAE

训练VQ-VAE使用离线数据 τ = ( s 1 , a 1 , r 1 , R 1 , s 2 , a 2 , r 2 , R 2 , ... , s T , a T , r T , R T ) \tau=(\boldsymbol{s}_1,\boldsymbol{a}_1,r_1,R_1,\boldsymbol{s}_2,\boldsymbol{a}_2,r_2,R_2,\ldots,\boldsymbol{s}_T,\boldsymbol{a}_T,r_T,R_T) τ=(s1,a1,r1,R1,s2,a2,r2,R2,...,sT,aT,rT,RT)。以上图为例,经过encoder得到T个特征(图中T=9),然后步长为L的一维卷积以及最大池化得到向量 ( x ˉ 1 , x ˉ 2 , x ˉ 3 ) (\bar{x}_1,\bar{x}_2,\bar{x}_3) (xˉ1,xˉ2,xˉ3)。在由最近邻找到对应的codebook中的 e i e_i ei作为latent code。
z i = e k , w h e r e k = a r g m i n j ∣ ∣ x i − e j ∣ ∣ 2 \boldsymbol{z}_i=\boldsymbol{e}_k,\mathrm{where~}k=\mathrm{argmin}_j||\boldsymbol{x}_i-\boldsymbol{e}_j||_2 zi=ek,where k=argminj∣∣xi−ej∣∣2

解码阶段,首先将latentcode扩展,与输入等维度。concat初始状态,经过decoder得到重构的序列。损失函数则是由原序列与重构序列的均方误差。除此外还最小化特征向量、latent code分别与codebook的距离: ∣ ∣ x i − e k ∣ ∣ 2 a n d ∣ ∣ z i − e k ∣ ∣ 2 ||\boldsymbol{x}_i-\boldsymbol{e}_k||_2\mathrm{~and~}||\boldsymbol{z}_i-\boldsymbol{e}_k||_2 ∣∣xi−ek∣∣2 and ∣∣zi−ek∣∣2

得到latent code后,还需要训练其先验分布用于后续的Planning过程。TAP采用Transformer架构的自回归模型 p ( z t ∣ z < t , s 1 ) = p ( z t ∣ s 1 , z 1 , z 2 , . . . , z t − 1 ) p(\boldsymbol{z}{t}|\boldsymbol{z}{<t},\boldsymbol{s}{1})=p(\boldsymbol{z}{t}|\boldsymbol{s}{1},\boldsymbol{z}{1},\boldsymbol{z}{2},...,\boldsymbol{z}{t-1}) p(zt∣z<t,s1)=p(zt∣s1,z1,z2,...,zt−1)构建更加紧凑的latent code

Planning

使用先验模型,生成当前state在隐空间中的latent code序列,再用decoder进行解码就得到预测的轨迹。对每条生成轨迹有如下评价函数
g ( s 1 , z 1 , z 2 , . . . , z M ) = ∑ t γ t r ^ t + γ T R ^ T + α ln ⁡ ( min ⁡ ( p ( z 1 , z 2 , . . . , z M ∣ s 1 ) , β M ) ) g(\boldsymbol{s}_1,\boldsymbol{z}_1,\boldsymbol{z}_2,...,\boldsymbol{z}_M)=\sum_t\gamma^t\hat{r}_t+\gamma^T\hat{R}_T+\alpha\ln\left(\min(p(\boldsymbol{z}_1,\boldsymbol{z}_2,...,\boldsymbol{z}_M|\boldsymbol{s}_1),\beta^M)\right) g(s1,z1,z2,...,zM)=t∑γtr^t+γTR^T+αln(min(p(z1,z2,...,zM∣s1),βM))

前两项衡量轨迹累计折扣奖励,后一项则是惩罚项,如果轨迹有概率大于阈值则相信累计奖励。而若小于阈值,则后项由于权重 α \alpha α取值为大于折扣回报的最大值,使得此时后项对评分的主导远大于累计奖励,即选择高概率的轨迹。

Beam Serach

基于初始状态以及先验模型,采样生成latent code z采用Beam search

可以看出,首先利用先验模型采样得到n个 z 1 z_1 z1,然后对每个 z 1 z_1 z1由先验模型得到概率最高的排序为前E个的 z 2 z_2 z2拼接,然后由decoder解码并根据评价函数得到轨迹分数,选取Top-B的序列的 z 1 , z 2 z_1, z_2 z1,z2, 重复上述过程选取最大评分的轨迹。

总结

对于高维复杂环境,通过encoder到低维度隐空间进行推理学习好的特征表示,在decoder生成是一个好的框架。对于下游任务,就可以直接采用学习好的特征以及decoder实现zero-shot。这是一个不错的思路。后续ICLR2024有个工作使用在隐空间进行Diffusion:《Efficient Planning with Latent Diffusion》

相关推荐
Erik_LinX1 小时前
day1-->day7| 机器学习(吴恩达)学习笔记
笔记·学习·机器学习
时间很奇妙!2 小时前
decison tree 决策树
算法·决策树·机器学习
liruiqiang052 小时前
机器学习 - 初学者需要弄懂的一些线性代数的概念
人工智能·线性代数·机器学习·线性回归
Icomi_2 小时前
【外文原版书阅读】《机器学习前置知识》1.线性代数的重要性,初识向量以及向量加法
c语言·c++·人工智能·深度学习·神经网络·机器学习·计算机视觉
羊小猪~~3 小时前
深度学习项目--基于LSTM的糖尿病预测探究(pytorch实现)
人工智能·pytorch·rnn·深度学习·神经网络·机器学习·lstm
东来梁蕴秀3 小时前
大语言模型之prompt工程
人工智能·机器学习
yi0315 小时前
文献阅读记录8--Enhanced Machine Learning Sketches for Network Measurements
人工智能·机器学习
金融OG9 小时前
99.16 金融难点通俗解释:营业总收入
大数据·数据库·python·机器学习·金融
两千连弹14 小时前
机器学习 ---逻辑回归
人工智能·python·机器学习·逻辑回归·numpy
Swift社区17 小时前
【前沿聚焦】机器学习的未来版图:从自动化到隐私保护的技术突破
人工智能·机器学习