0 论文信息
最近看看前人是怎么做强化学习中 visual transfer 的工作,在此基础之上看看有哪些可以将自己思路进行落地的地方。
1 相关工作
相关工作通过直接推广策略或学习广义状态表示来解决强化学习中的领域适应问题。
领域随机化是直接学习具有泛化能力的策略的最流行方法 (Tobin et al. 2017; Andrychowicz et al. 2020; Slaoui et al. 2020; Laskin et al. 2020)。通过在许多源领域上进行训练,强化学习主体学会忽略无关的变化因素,只关注共同的特征。然而,这种方法依赖于训练中多个源领域的可用性,并且该方法的复杂性随着变化数量的增加而增加。
与直接学习具有泛化能力的策略不同,其他工作关注状态表示的泛化。一些视觉领域适应工作使用图像到图像的转换将目标领域中基于像素的状态映射到源领域中的配对状态 (Pan et al. 2017; Tzeng et al. 2020; Gamrian and Goldberg 2019)。这通常通过对抗方法实现,例如生成对抗网络 (GANs) (Goodfellow et al. 2014),以及在图像对缺失的情况下的非对齐的 GANs (Liu, Breuel, and Kautz 2017; Zhu et al. 2017)。尽管这些方法提供了有希望的结果,但图像转换在推断时增加了额外的负担,这在实时应用中是不实际的。
其他工作进一步尝试通过将基于像素的状态映射到潜在空间来学习广义状态表示 (Higgins et al. 2017)。例如,变分自动编码器 (VAE) 的潜在嵌入可以作为强化学习中内部潜在状态表示的一部分。作者将这种方法称为 VAE-嵌入。DARLA 进一步扩展了 VAE 到 <math xmlns="http://www.w3.org/1998/Math/MathML"> β \beta </math>β-VAE,以鼓励潜在嵌入的解耦,并使用预训练去噪自动编码器 (DAE)(Vincent et al. 2010) 的一个内部层作为重建目标。尽管潜在状态表示中的解耦使强化学习主体能够更容易地忽略与领域特定特征无关的因素,但策略的转移性能并不保证,因为领域特定特征仍然存在于潜在状态表示中,它们对策略输出的贡献不能推广到其他领域。CURL 使用对比学习从原始像素中提取高级特征,并极大地提高了样本效率 (Laskin, Srinivas, and Abbeel 2020)。
在这项工作中,作者选择 VAE-嵌入、DARLA、CURL 和基于 CycleGAN 的图像到图像转换作为基准。为了更清楚地说明 LUSR 与它们的区别,作者使用图 1 来展示它们的框架。
2 强化学习中的迁移学习
强化学习是研究主体在环境中应如何采取行动以最大化其累积奖励的领域。环境通常以马尔可夫决策过程 (MDP) 的形式给出,其由元组 <math xmlns="http://www.w3.org/1998/Math/MathML"> ( S , A , T , R ) (S, A, T, R) </math>(S,A,T,R) 表示,其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> S S </math>S 是状态空间, <math xmlns="http://www.w3.org/1998/Math/MathML"> A A </math>A 是动作空间, <math xmlns="http://www.w3.org/1998/Math/MathML"> T T </math>T 是转移函数, <math xmlns="http://www.w3.org/1998/Math/MathML"> R R </math>R 是奖励函数。在 MDP 的每个时间步长 <math xmlns="http://www.w3.org/1998/Math/MathML"> t t </math>t 中,主体根据当前状态 <math xmlns="http://www.w3.org/1998/Math/MathML"> s t s_t </math>st 在环境中采取动作 <math xmlns="http://www.w3.org/1998/Math/MathML"> a t a_t </math>at,并获得奖励 <math xmlns="http://www.w3.org/1998/Math/MathML"> r t + 1 r_{t+1} </math>rt+1 和下一个状态 <math xmlns="http://www.w3.org/1998/Math/MathML"> s t + 1 s_{t+1} </math>st+1。主体的目标是找到一个策略 <math xmlns="http://www.w3.org/1998/Math/MathML"> π ( s ) \pi(s) </math>π(s),以选择最大化预期奖励 <math xmlns="http://www.w3.org/1998/Math/MathML"> r t + γ r t + 1 + γ 2 r t + 2 + ... r_t+\gamma r_{t+1}+\gamma^2r_{t+2}+\ldots </math>rt+γrt+1+γ2rt+2+... 的行动,其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> γ \gamma </math>γ 是介于 0 和 1 之间的折扣因子。
为了在强化学习的领域适应设置中形式化领域适应场景,作者将源域和目标域定义为 <math xmlns="http://www.w3.org/1998/Math/MathML"> D S D_S </math>DS 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> D T D_T </math>DT。每个域对应一个 MDP,定义为元组 <math xmlns="http://www.w3.org/1998/Math/MathML"> ( S , A , T , R ) (S, A, T, R) </math>(S,A,T,R),因此源域 <math xmlns="http://www.w3.org/1998/Math/MathML"> D S D_S </math>DS 和目标域 <math xmlns="http://www.w3.org/1998/Math/MathML"> D T D_T </math>DT 中的 MDPs 分别定义为 <math xmlns="http://www.w3.org/1998/Math/MathML"> ( S S , A S , T S , R S ) \left(S_S, A_S, T_S, R_S\right) </math>(SS,AS,TS,RS) 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> ( S T , A T , T T , R T ) \left(S_T, A_T, T_T, R_T\right) </math>(ST,AT,TT,RT)。源域和目标域的状态空间 <math xmlns="http://www.w3.org/1998/Math/MathML"> S S </math>S 可以不同,但它们的动作空间A应该相同,并且由于共享的内部动态,它们的转移函数 <math xmlns="http://www.w3.org/1998/Math/MathML"> T T </math>T 和奖励函数 <math xmlns="http://www.w3.org/1998/Math/MathML"> R R </math>R 应该具有相似性。换句话说,作者关注的是策略转移,其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> T S ≈ T T , R S ≈ R T , A S = A T T_S \approx T_T,R_S \approx R_T, A_S=A_T </math>TS≈TT,RS≈RT,AS=AT,但 <math xmlns="http://www.w3.org/1998/Math/MathML"> S S ≠ S T S_S \neq S_T </math>SS=ST。
文中的方法专注于在强化学习中学习不同领域状态的潜在统一状态表示 (LUSR)。在本节中,作者首先介绍 LUSR 的定义,然后介绍如何学习它。
图 1 . 文中的方法 (LUSR) 和本文中用于比较的其他基准 (DARLA、CURL 和基于 CycleGAN 的图像到图像转换)的架构。VAE-嵌入的架构可以被视为 DARLA 的一个特例,它将 <math xmlns="http://www.w3.org/1998/Math/MathML"> β \beta </math>β-VAE 替换为 VAE,并避免使用 DAE。所有这些方法的学习可以分为两个阶段。第一阶段是学习适合支持强化学习中领域适应的状态表示,第二阶段是进行强化学习训练。
2.1 LUSR 的定义
文中首先介绍强化学习中状态空间的两个概念,即主体的原始观察状态空间 <math xmlns="http://www.w3.org/1998/Math/MathML"> S o S^o </math>So 和主体的内部潜在状态空间 <math xmlns="http://www.w3.org/1998/Math/MathML"> S z S^z </math>Sz。原始观察状态 <math xmlns="http://www.w3.org/1998/Math/MathML"> S o S^o </math>So 由像素网格组成,而内部潜在状态 <math xmlns="http://www.w3.org/1998/Math/MathML"> S z S^z </math>Sz 中的每个单元表示高级语义特征。映射函数 <math xmlns="http://www.w3.org/1998/Math/MathML"> F : S o → S z \mathcal{F}: S^o \rightarrow S^z </math>F:So→Sz 将观察状态映射到相应的内部潜在状态。在作者的工作中, <math xmlns="http://www.w3.org/1998/Math/MathML"> S z S^z </math>Sz 中的高级语义特征进一步分为领域特定特征 (例如驾驶任务中的天气条件) 和领域通用特征 (例如车辆动力学)。这里作者表示 <math xmlns="http://www.w3.org/1998/Math/MathML"> S z = ( S z ^ , S z ‾ ) S^z=\left(\widehat{S^z}, \overline{S^z}\right) </math>Sz=(Sz ,Sz),其 <math xmlns="http://www.w3.org/1998/Math/MathML"> S z ^ \widehat{S^z} </math>Sz 表示领域特定特征, <math xmlns="http://www.w3.org/1998/Math/MathML"> S z ‾ \overline{S^z} </math>Sz 表示领域通用特征。关于源域和目标域的状态表示,可以参考图 1。
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> S S o ≠ S T o S S z = ( S S z ^ , S S z ‾ ) ; S T z = ( S T z ^ , S T z ‾ ) S S z ‾ = S T z ‾ ; S S z ^ ≠ S T z ^ (1) \begin{gathered} S_S^o \neq S_T^o \\ S_S^z=\left(\widehat{S_S^z}, \overline{S_S^z}\right) ; \quad S_T^z=\left(\widehat{S_T^z}, \overline{S_T^z}\right) \\ \overline{S_S^z}=\overline{S_T^z} ; \quad \widehat{S_S^z} \neq \widehat{S_T^z}\tag{1} \end{gathered} </math>SSo=SToSSz=(SSz ,SSz);STz=(STz ,STz)SSz=STz;SSz =STz (1)
在文中的域迁移设置中,转移函数 <math xmlns="http://www.w3.org/1998/Math/MathML"> T T </math>T 和奖励函数 <math xmlns="http://www.w3.org/1998/Math/MathML"> R R </math>R 仅依赖于在不同领域中保持一致的 <math xmlns="http://www.w3.org/1998/Math/MathML"> S z ‾ \overline{S^z} </math>Sz。在这里,作者定义了以 <math xmlns="http://www.w3.org/1998/Math/MathML"> s o s^o </math>so 为输入的奖励函数和转移函数分别为 <math xmlns="http://www.w3.org/1998/Math/MathML"> R o R^o </math>Ro 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> T o T^o </math>To,以及以 <math xmlns="http://www.w3.org/1998/Math/MathML"> s z s^z </math>sz 为输入的奖励函数和转移函数分别为 <math xmlns="http://www.w3.org/1998/Math/MathML"> R z R^z </math>Rz 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> T z T^z </math>Tz。那么文中有以下关系 :
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> T S o ≠ T T o ; R S o ≠ R T o T S z = T ( S S z ‾ ) = T ( S T z ‾ ) = T T z R S z = R ( S S z ‾ ) = R ( S T z ‾ ) = R T z (2) \begin{gathered} T_S^o \neq T_T^o ; \quad R_S^o \neq R_T^o \\ T_S^z=T\left(\overline{S_S^z}\right)=T\left(\overline{S_T^z}\right)=T_T^z \\ R_S^z=R\left(\overline{S_S^z}\right)=R\left(\overline{S_T^z}\right)=R_T^z\tag{2} \end{gathered} </math>TSo=TTo;RSo=RToTSz=T(SSz)=T(STz)=TTzRSz=R(SSz)=R(STz)=RTz(2)
由于 <math xmlns="http://www.w3.org/1998/Math/MathML"> S z ‾ \overline{S^z} </math>Sz 在不同领域中保持一致,并且奖励结构 ( <math xmlns="http://www.w3.org/1998/Math/MathML"> T T </math>T 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> R R </math>R) 仅依赖于这个表征 (而不依赖 <math xmlns="http://www.w3.org/1998/Math/MathML"> S z ^ \widehat{S^z} </math>Sz ),以 <math xmlns="http://www.w3.org/1998/Math/MathML"> S z ‾ \overline{S^z} </math>Sz 作为输入的强化学习主体将能够成功训练,并且经过训练的主体也具有从源域到目标域进行适应的能力。因此,文中方法的目标是学习将原始观察状态映射到作者称之为潜在统一状态表示 (LUSR) 的映射函数 <math xmlns="http://www.w3.org/1998/Math/MathML"> F : S o → S z ‾ \mathcal{F}: S^o \rightarrow \overline{S^z} </math>F:So→Sz。
2.2 LUSR 的整体学习流程
在这项工作中,作者选择使用循环一致性变分自编码器 (Cycle-Consistent VAE) 来学习映射函数 <math xmlns="http://www.w3.org/1998/Math/MathML"> F : S o → S z ‾ \mathcal{F}: S^o \rightarrow \overline{S^z} </math>F:So→Sz,该方法是一种非对抗性方法,用于解耦领域通用和领域特定的变化因素。类似于变分自编码器 (VAE),循环一致性变分自编码器也由编码器和解码器组成。然而,编码器的输出被分为领域通用和领域特定的嵌入。为了学习映射函数 <math xmlns="http://www.w3.org/1998/Math/MathML"> F \mathcal{F} </math>F,首先收集来自预定义领域集的一些随机观察状态,然后将其用作循环一致性变分自编码器模型训练的输入。一旦模型训练完成,编码器就能够将来自领域集中任何领域的观察状态 <math xmlns="http://www.w3.org/1998/Math/MathML"> s o s^o </math>so 映射到由 <math xmlns="http://www.w3.org/1998/Math/MathML"> s z ‾ \overline{s^z} </math>sz 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> s z ^ \widehat{s^z} </math>sz 组成的潜在状态表示。因此,作者使用经过训练的编码器作为映射函数 <math xmlns="http://www.w3.org/1998/Math/MathML"> F \mathcal{F} </math>F,并仅保留领域通用表示作为 LUSR。
循环一致性变分自编码器 (Cycle-Consistent VAE) 基于循环一致性的思想,其直观理解是两个训练良好的正向和反向转换按任意顺序组合在一起应该近似于一个恒等函数。例如,在变分自编码器 (VAE) 中,编码器是一个正向转换,将输入图像转换为潜在向量,而解码器是一个反向转换,将潜在向量转换回重建图像。在这里,作者将正向循环定义为 : <math xmlns="http://www.w3.org/1998/Math/MathML"> Dec ( Enc ( s ∘ ) ) = s ∘ \operatorname{Dec}\left(\operatorname{Enc}\left(s^{\circ}\right)\right)=s^{\circ} </math>Dec(Enc(s∘))=s∘,反向循环定义为 <math xmlns="http://www.w3.org/1998/Math/MathML"> Enc ( Dec ( s z ^ , s z ‾ ) ) = ( s z ′ ^ , s z ′ ‾ ) \operatorname{Enc}\left(\operatorname{Dec}\left(\widehat{s^z}, \overline{s^z}\right)\right)=\left(\widehat{s^z \prime}, \overline{s^{z \prime}}\right) </math>Enc(Dec(sz ,sz))=(sz′ ,sz′)。根据循环一致性的指示, <math xmlns="http://www.w3.org/1998/Math/MathML"> s o ′ s^o\prime </math>so′ 应该接近 <math xmlns="http://www.w3.org/1998/Math/MathML"> s o s^o </math>so,而 <math xmlns="http://www.w3.org/1998/Math/MathML"> ( s z ′ ^ , s z ′ ‾ ) \left(\widehat{s^z \prime}, \overline{s^{z \prime}}\right) </math>(sz′ ,sz′)也应该接近 <math xmlns="http://www.w3.org/1998/Math/MathML"> ( s z ^ , s z ‾ ) \left(\widehat{s^z}, \overline{s^z}\right) </math>(sz ,sz)。
在 Cycle-Consistent VAE 的正向循环中,对于来自同一领域的两个观察状态 <math xmlns="http://www.w3.org/1998/Math/MathML"> s 1 o s_1^o </math>s1o 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> s 2 o s_2^o </math>s2o, <math xmlns="http://www.w3.org/1998/Math/MathML"> Enc ( s 1 o ) = s 1 z ^ , s 1 z ‾ \operatorname{Enc}\left(s_1^o\right)=\widehat{s_1^z}, \overline{s_1^z} </math>Enc(s1o)=s1z ,s1z 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> Enc ( s 2 o ) = s 2 z ^ , s 2 z ‾ \operatorname{Enc}\left(s_2^o\right)=\widehat{s_2^z}, \overline{s_2^z} </math>Enc(s2o)=s2z ,s2z。由于两者都来自同一领域,并且 <math xmlns="http://www.w3.org/1998/Math/MathML"> s z ^ \widehat{s^z} </math>sz 仅包含领域特定信息,交换 <math xmlns="http://www.w3.org/1998/Math/MathML"> s 1 z ^ \widehat{s_1^z} </math>s1z 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> s 2 z ^ \widehat{s_2^z} </math>s2z 对重构损失没有影响,这意味着应该可以得到 <math xmlns="http://www.w3.org/1998/Math/MathML"> Dec ( s 2 z ^ , s 1 z ‾ ) ≈ s 1 o \operatorname{Dec}\left(\widehat{s_2^z}, \overline{s_1^z}\right) \approx s_1^o </math>Dec(s2z ,s1z)≈s1o 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> Dec ( s 1 z ^ , s 2 z ‾ ) ≈ s 2 o \operatorname{Dec}\left(\widehat{s_1^z}, \overline{s_2^z}\right) \approx s_2^o </math>Dec(s1z ,s2z)≈s2o。这个操作确保了领域特定信息和领域通用信息分别压缩到 <math xmlns="http://www.w3.org/1998/Math/MathML"> s z ^ \widehat{s^z} </math>sz 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> s z ‾ \overline{s^z} </math>sz 中。
在反向循环中,通过将随机采样的 <math xmlns="http://www.w3.org/1998/Math/MathML"> s z ‾ \overline{s^z} </math>sz 与两个领域特定嵌入 <math xmlns="http://www.w3.org/1998/Math/MathML"> s 1 z ^ \widehat{s_1^z} </math>s1z 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> s 2 z ^ \widehat{s_2^z} </math>s2z 结合,经过解码器生成两个重建图像 <math xmlns="http://www.w3.org/1998/Math/MathML"> s 1 o ′ s_1^o \prime </math>s1o′ 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> s 2 o ′ s_2^o\prime </math>s2o′。由于 <math xmlns="http://www.w3.org/1998/Math/MathML"> s 1 o s_1^o </math>s1o 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> s 2 o s_2^o </math>s2o 都是基于相同的 <math xmlns="http://www.w3.org/1998/Math/MathML"> s z ‾ \overline{s^z} </math>sz 生成的,它们对应的领域通用潜在嵌入 <math xmlns="http://www.w3.org/1998/Math/MathML"> s 1 z ‾ \overline{s_1^z} </math>s1z 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> s 2 z ‾ \overline{s_2^z} </math>s2z 也应该相同。
因此,Cycle-Consistent VAE 的目标是最小化以下损失函数 :
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> L cyclic = L forward + L reverse (3) \mathcal{L}{\text {cyclic }}=\mathcal{L}{\text {forward }}+\mathcal{L}_{\text {reverse }}\tag{3} </math>Lcyclic =Lforward +Lreverse (3)
其中 :
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> L forward = − E q ϕ ( s z ‾ , s z ^ ∣ s o ) [ log p θ ( s o ∣ s z ‾ , s z ∗ ^ ) ] + K L ( q ϕ ( s z ‾ ∣ s o ) ∥ p ( s z ‾ ) ) L reverse = E s z ‾ ∼ p ( s z ‾ ) [ ∥ q ϕ ‾ ( p θ ( s z ‾ , s 1 z ^ ) ) − q ϕ ‾ ( p θ ( s z ‾ , s 2 z ^ ) ) ∥ 1 ] \begin{aligned} \mathcal{L}{\text {forward }}= & -\mathbb{E}{q_\phi\left(\overline{s^z}, \widehat{s^z} \mid s^o\right)}\left[\log p_\theta\left(s^o \mid \overline{s^z}, \widehat{s^z *}\right)\right] \\ & +K L\left(q_\phi\left(\overline{s^z} \mid s^o\right) \| p\left(\overline{s^z}\right)\right) \\ \mathcal{L}{\text {reverse }}= & \mathbb{E}{\overline{s^z} \sim p\left(\overline{s^z}\right)}\left[\left\|\overline{q_\phi}\left(p_\theta\left(\overline{s^z}, \widehat{s_1^z}\right)\right)-\overline{q_\phi}\left(p_\theta\left(\overline{s^z}, \widehat{s_2^z}\right)\right)\right\|_1\right] \end{aligned} </math>Lforward =Lreverse =−Eqϕ(sz,sz ∣so)[logpθ(so∣sz,sz∗ )]+KL(qϕ(sz∣so)∥p(sz))Esz∼p(sz)[∥ ∥qϕ(pθ(sz,s1z ))−qϕ(pθ(sz,s2z ))∥ ∥1]
这里的 <math xmlns="http://www.w3.org/1998/Math/MathML"> L forward \mathcal{L}{\text{forward}} </math>Lforward 是一个修改后的变分上界,而 <math xmlns="http://www.w3.org/1998/Math/MathML"> L reverse \mathcal{L}{\text{reverse}} </math>Lreverse 是循环一致性的损失函数。 <math xmlns="http://www.w3.org/1998/Math/MathML"> q ϕ q_\phi </math>qϕ 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> p θ p_\theta </math>pθ 是编码器和解码器的参数化函数。文中将 <math xmlns="http://www.w3.org/1998/Math/MathML"> q ϕ ‾ \overline{q_\phi} </math>qϕ 定义为仅输出领域通用嵌入的 <math xmlns="http://www.w3.org/1998/Math/MathML"> q ϕ q_\phi </math>qϕ。潜在嵌入 <math xmlns="http://www.w3.org/1998/Math/MathML"> s z s^z </math>sz 由领域通用嵌入 <math xmlns="http://www.w3.org/1998/Math/MathML"> s z ‾ \overline{s^z} </math>sz 和领域特定嵌入 <math xmlns="http://www.w3.org/1998/Math/MathML"> s z ^ \widehat{s^z} </math>sz 组成,它们对应于观察状态 <math xmlns="http://www.w3.org/1998/Math/MathML"> s o s^o </math>so。 <math xmlns="http://www.w3.org/1998/Math/MathML"> s z ∗ ^ \widehat{s^z *} </math>sz∗ 表示来自同一领域的任意随机领域特定嵌入,而 <math xmlns="http://www.w3.org/1998/Math/MathML"> s 1 z ^ \widehat{s_1^z} </math>s1z 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> s 2 z ^ \widehat{s_2^z} </math>s2z 是两个不同的领域特定嵌入。
3 实验
大致放一个实验中的 CarRacing 游戏的实验设置图,剩下的不多赘述,感觉强化学习这边的实验结果基本都是图,看上去比有监督的数据量少 (据说还是挺难训练的)。
图 2 . CarRacing 游戏的变体。A . CarRacing 游戏的原始版本设置为源域。B . 收集观察状态来学习 LUSR 的 CarRacing 游戏的可见目标域。C. CarRacing 游戏的看不见的目标域。这两个域永远不会被主体观测到,不仅在强化学习训练期间,而且在潜在状态表示学习期间。
4 感言
中规中矩的一个方法,感觉主要还是建立在 Cycle-Consistent VAE 这一工作之上,利用了不变表征的特性进行了目标的改写,之后调研一下 Cycle-Consistent VAE,感觉可以做出比这个有意思一些的工作。