论文解读:MASS-EDITING MEMORY IN A TRANSFORMER(MEMIT)

论文发表于人工智能顶会ICLR(原文链接)。在模型编辑方法中,过去工作主要局限于更新单个事实。因此,基于ROME,本文开发了MEMIT,在大模型GPT-J(6B)和GPT-NeoX(20B)上实现了数千的批量编辑。

阅读本文请同时参考原始论文图表。

方法

模型定义为文中式(1),其中\[x_{\[1\]},...,x_{\[E\]}\]表示长度为E的输入句子,x_{\[t\]}表示模型输出单词。模型层之间状态的计算表示为式(2/3/4),将模型最后一层关于输入句子最后一个token的状态映射到词汇空间就是x_{\[t\]}。本文主要考虑GPT-J的架构来介绍方法,其中FFN和注意力模块并行,而不是使用注意力模块的输出输入FFN(当然后面介绍的MEMIT方法可以适用到其它LLM架构上)。

对于一个事实(s,r,o),模型输入包含头实体s和关系r的句子,输出头实体o。模型编辑就是让模型关于包含(s,r)的句子输出o变成另一个o' 。本文的目标是同时对多个事实进行编辑,对同时编辑的事实构成的集合\\mathcal{E}做了一个限制,如式(5)所示,即事实之间不能有冲突。

根据ROME论文的实验结果,对于某个prompt p_i,本文只考虑其中主体s的最后一个token的中间层状态h_i\^l、对应的FFN激活m_i\^l和注意力模块激活a_i\^l对模型输出的影响,此时i为prompt的编号。另外,如图3所示(ROME的实验),由于不止一个中间层对模型预测有影响,因此同时考虑多个中间层相应激活对预测的影响。比如对于GPT-J,l\\in \\mathcal{R}=\\{3,4,5,6,7,8\\}

模型推理机制

根据模型的状态计算式(2),可以得到式(6),即每一层的输出状态是初始状态加上其前面层的FFN和注意力模块激活。根据之前ROME实验(ROME论文图1e/f/g)的观察,作者认为模型的推理机制如图2所示:

(a)模型先使用注意力机制把主体s的信息汇集到s的最后一个token(Jordan)。

(b)通过模型各层FFN根据主体s的信息逐步读取相关的记忆并加入潜在表示。

(c)通过注意力模块使用读取的记忆来生成输出,也就是图2所示的信息通路。

批量参数更新

和ROME类似,对于第l层的FFN的第二层权重,在预训练后满足式(7),通过求导得到方程式(8)。其中K_0=\[k_1,...,k_n\],M_0=\[m_1,...,m_n\]。当要添加新知识K_1,M_1时,就是把它们拼接后进行优化,即式(10-13)。最终得到W_0的改变量\\Delta为式(14)。其中C_0=K_0K_0\^T定义为期望式(15),\\lambda=1.5\\times 10\^4。注意MEMIT的优化定义与ROME不同。

多层参数批量更新

1、根据之前的模型推理机制的分析,作者先通过式(16)优化得到主体s最后一个token在第L层关于待修改事实(s_i,r_i,o_i) 的表示z_i。其中L=\\max(\\mathcal{R})表示对预测有影响层的最大层数,h_i\^L表示模型关于(s_i,r_i)在该位置的原始表示。也就是优化一个残差值\\delta_i,使得z_i=h_i\^L+\\delta_ix_j表示prompt的前缀。

2、获得残差\\delta_i=z_i-h_i\^L后,就是修改\\mathcal{R}中每层FFN的权重W_{out}\^l,使得模型关于(s_i,r_i)的表示\\hat{h}_i\^L尽可能接近z_i,也就是优化式(17/18)。修改权重需要获取每个权重对应的新的键k_i\^l和值m_i\^l,并且由于前一层的权重修改会影响后层的输入,因此需要从\\mathcal{R}的第1层到第最后一层按顺序更新权重。每层的键可以直接通过前向传播得到,即式(19)。 值则是键k_i\^l经过权重W_{out}\^l映射后加上残差 r_i\^l,如式(20)所示。作者将第L层的残差\\delta_i=z_i-h_i\^L分配给\\mathcal{R}中的每一层,那为什么分母是L-l+1,而不是L呢?这是因为MLP的输出m_i\^{l}改了,会导致下一层的注意力输出a_i\^{l+1}也改了,所以总体改变量并不是直接对m_i\^l的改变量求和的结果。

总的编辑算法如算法1所示,看起来L,\\mathcal{R}对于一个批次中的每一个待更新事实都是固定的,具体细节还要看代码。

实验

表1:在GPT-J模型上修改zsRE数据集的10000个事实的对比结果,其中MEND基于元学习超网络可以并行编辑,ROME是安顺序编辑,这两个方法比正则化微调效果还差。

图5:各方法关于编辑事实的数量的指标变化图。ES为编辑准确率;PS为编辑后对同义句的准确率;NS为对不相关事实的准确率;RS是编辑后模型生成关于s的句子与参考句子的相似度;GE是生成关于s的句子的流畅度;CS是ES/PS/NS的调和平均。NS和GE应该和虚线也就是编辑前的模型相近。

表2:在GPT-J和GPT-NeoX上10000次编辑后的对比。

图6a:三个方法在反事实数据集中的不同关系对应事实的编辑得分。

图6b:通用性(同义句子的准确性)和特异性(无关事实的保持)的权衡。

图7:在并行编辑时混合包含不同关系的事实,对性能的影响。每个图都对两个关系事实的混合进行了编辑。可以看出混合编辑的结果和两个关系的事实分别单独进行编辑的结果的平均相近。

缺陷

文中指出,MEMIT只局限于有向关系,并且无法对空间时间推理、数学知识、语言知识、程序知识进行编辑,甚至无法泛化对称关系。例如,"库克是苹果首席执行官"必须与"苹果首席执行官是库克"分开处理。