KGAT: Knowledge Graph Attention Network for Recommendation 论文笔记

KGAT: Knowledge Graph Attention Network for Recommendation 论文笔记

为了提供更准确、多样和可解释的推荐,必须超越对用户项交互的建模,并考虑辅助信息。

一、论文信息

  1. 原文地址:doi.org/10.1145/329...
  2. 内容介绍:基于知识图谱+图注意力网络的推荐系统(KG+GAT)
  3. 发布信息:2019年8月4日至8日,在美国阿拉斯加州安克雷奇举行的KDD'19。("KDD 是由ACM 主办的国际数据挖掘领域最顶级会议,同时被CCF(中国计算机学会)列为A类会议)
  4. 作者信息:
  • Xiang Wang、National University of Singapore(新加坡国立大学)
  • Xiangnan He、University of Science and Technology of China(中国科学技术大学)
  • Yixin Cao、National University of Singapore(新加坡国立大学)
  • Meng Liu、Shandong University(山东大学)
  • Tat-Seng Chua、National University of Singapore(新加坡国立大学)
  1. 主要内容:传统推荐算法的缺陷、协作知识图谱CKG、KGAT模型(嵌入层、注意力嵌入传播层、预测层)
  2. 前置知识:了解推荐系统基础概念、协作过滤推荐算法、Trans系列(TransE、TransH、TransR)、softmax函数等。

二、论文内容

1. 传统推荐算法的缺陷

图1:协作知识图谱示例
在电影场景下,有以下几种关系类型:

  • <math xmlns="http://www.w3.org/1998/Math/MathML"> r 1 r_1 </math>r1:观看过;
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> r 2 r_2 </math>r2:导演过;
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> r 3 r_3 </math>r3:扮演过;
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> r 4 r_4 </math>r4:类型(电影题材)。

如上图:用户 <math xmlns="http://www.w3.org/1998/Math/MathML"> u 1 u_1 </math>u1看了电影 <math xmlns="http://www.w3.org/1998/Math/MathML"> i 1 i_1 </math>i1,而这个电影是 <math xmlns="http://www.w3.org/1998/Math/MathML"> e 1 e_1 </math>e1导演的。

传统的CF(协同过滤)方法会着重去找那些也看了电影 <math xmlns="http://www.w3.org/1998/Math/MathML"> i 1 i_1 </math>i1的用户,比如 <math xmlns="http://www.w3.org/1998/Math/MathML"> u 4 u_4 </math>u4、 <math xmlns="http://www.w3.org/1998/Math/MathML"> u 5 u_5 </math>u5 。
CF 侧重于描述与item有交互的用户之间的相似性:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> u 1 → r 1 i 1 → − r 1 u 2 → r 1 i 2 \mathrm u_1\stackrel{\mathrm r_1}\to\mathrm i_1\stackrel{-\mathrm r_1}\to\mathrm u_2\stackrel{\mathrm r_1}\to\mathrm i_2 </math>u1→r1i1→−r1u2→r1i2

SL(监督学习)方法会重点关注那些有相同属性 <math xmlns="http://www.w3.org/1998/Math/MathML"> e 1 e_1 </math>e1的电影,比如 <math xmlns="http://www.w3.org/1998/Math/MathML"> i 2 i_2 </math>i2。
SL 侧重于描述item的属性之间的相似性(因为是对每种关系单独建模,因此无法跨域):
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> u 1 → r 1 i 1 → r 2 e 1 → − r 2 i 2 \mathrm u_1\stackrel{\mathrm r_1}{\rightarrow}\mathrm i_1\stackrel{\mathrm r_2}{\rightarrow}\mathrm e_1\stackrel{-\mathrm r_2}{\rightarrow}\mathrm i_2 </math>u1→r1i1→r2e1→−r2i2

仅靠上面其一都实现不了高阶连通
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> u 1 → r 1 i 1 → r 2 e 1 → − r 3 i 3 \mathrm u_1\stackrel{\mathrm r_1}{\rightarrow}\mathrm i_1\stackrel{\mathrm r^2}{\rightarrow}\mathrm e_1\stackrel{-\mathrm r_3}{\rightarrow}\mathrm i_3 </math>u1→r1i1→r2e1→−r3i3

CF 方法无法对侧信息建模,无法应对矩阵稀疏 的情况(解决:将侧信息建模为特征向量)
SL 把每种关系单独建模,没有考虑协同信号或者说SL方法没考虑到高阶连通性 (解决:知识图谱)

很显然这两类信息都可以作为推荐信息的补充,但是KGAT之前的模型不能做到上面两者信息的融合,而且这里的高阶关系也可以作为推荐信息的补充的。比如图中黄色框图里的用户 <math xmlns="http://www.w3.org/1998/Math/MathML"> u 2 u_2 </math>u2和 <math xmlns="http://www.w3.org/1998/Math/MathML"> u 3 u_3 </math>u3看了同样由 <math xmlns="http://www.w3.org/1998/Math/MathML"> e 1 e_1 </math>e1导演过的电影 <math xmlns="http://www.w3.org/1998/Math/MathML"> i 2 i_2 </math>i2,而灰色框图里用户 <math xmlns="http://www.w3.org/1998/Math/MathML"> u 2 u_2 </math>u2观看的电影 <math xmlns="http://www.w3.org/1998/Math/MathML"> i 3 i_3 </math>i3与电影 <math xmlns="http://www.w3.org/1998/Math/MathML"> i 4 i_4 </math>i4都是由 <math xmlns="http://www.w3.org/1998/Math/MathML"> e 1 e_1 </math>e1扮演过的,那么用户 <math xmlns="http://www.w3.org/1998/Math/MathML"> u 2 u_2 </math>u2是否会对电影 <math xmlns="http://www.w3.org/1998/Math/MathML"> i 4 i_4 </math>i4感兴趣?
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> u 1 → r 1 i 1 → − r 2 e 1 → r 2 i 2 → − r 1 { u 2 , u 3 } , u_1\xrightarrow{r_1}i_1\xrightarrow{-r_2}e_1\xrightarrow{r_2}i_2\xrightarrow{-r_1}\{u_2,u_3\}, </math>u1r1 i1−r2 e1r2 i2−r1 {u2,u3},
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> u 1 → r 1 i 1 → − r 2 e 1 → r 3 { i 3 , i 4 } , u_1\xrightarrow{r_1}i_1\xrightarrow{-r_2}e_1\xrightarrow{r_3}\{i_3,i_4\}, </math>u1r1 i1−r2 e1r3 {i3,i4},

2. 协作知识图谱CKG

为了解决上面提到的问题,作者提出 collaborative knowledge graph (CKG)方法,将图谱关系信息及用户user点击商品item的交互图融合到一个图空间里。这样就可以融合CF信息及KG信息,同时也可以通过CKG发现高阶的关系信息。

  • UI二部图(User-Item Bipartite Graph):用户物品交互图,有交互的用边相连。

<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> G 1 = { ( u , y u i , i ) ∣ u ∈ U , i ∈ I ) } \mathrm G_1=\{(\mathrm u,\mathrm y_{\mathrm ui},\mathrm i)|\mathrm u\in\mathrm U,\mathrm i\in\mathrm I)\} </math>G1={(u,yui,i)∣u∈U,i∈I)}
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> y u i = { 1 u、i有交互 0 u、i无交互 \left.\mathrm{y}_{\mathrm{ui}}=\left\{\begin{array}{ll}1&\text{u、i有交互}\\0&\text{u、i无交互}\end{array}\right.\right. </math>yui={10u、i有交互u、i无交互

  • 知识图谱:用于给item添加属性和外部知识,为有向图。

<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> G 2 = { ( h , r , t ) ∣ h , t ∈ E , r ∈ R } \mathrm G_2=\{(\mathrm h,\mathrm r,\mathrm t)|\mathrm h,\mathrm t\in\mathrm E,\mathrm r\in\mathrm R\} </math>G2={(h,r,t)∣h,t∈E,r∈R}

此外,还建立了一组UI对齐:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> A = { ( i , e ) ∣ i ∈ I , e ∈ E } \mathrm A=\{(\mathrm i,\mathrm e)|\mathrm i\in\mathrm I,\mathrm e\in\mathrm E\} </math>A={(i,e)∣i∈I,e∈E}

其中( <math xmlns="http://www.w3.org/1998/Math/MathML"> i , e i , e </math>i,e) 表示项目 <math xmlns="http://www.w3.org/1998/Math/MathML"> i i </math>i可以与KG中的实体 <math xmlns="http://www.w3.org/1998/Math/MathML"> e e </math>e对齐。

  • 协作知识图谱:这里将二部图和物品实体知识图谱结合成一张图。简而言之就是把UI二部图和知识图谱无缝衔接:

<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> G 1 ′ = { ( u , I n t e r a c t , i ) ∣ Interact ∈ y u i } \mathrm{G_1^{\prime}}=\{(\mathrm{u},\mathrm{Interact},\mathrm{i})|\text{Interact}\in\mathrm{y}_{\mathrm{ui}}\} </math>G1′={(u,Interact,i)∣Interact∈yui}
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> G 2 ′ = { ( h , r , t ) ∣ h , t ∈ E ′ , r ∈ R ′ } , \mathrm{G_2^{\prime}=\{(h,r,t)|h,t\in E^{\prime},r\in R^{\prime}\},} </math>G2′={(h,r,t)∣h,t∈E′,r∈R′},
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> E ′ = E ∪ U , \mathrm{E}^{\prime}=\mathrm{E}\cup\mathrm{U}, </math>E′=E∪U,
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> R ′ = R ∪ { I n t e r a c t } \mathrm{R'=R\cup\{Interact\}} </math>R′=R∪{Interact}

但新的问题又来了:

  1. 节点会随着阶数增加而暴增,计算过载;
  2. 高阶连通性会增加预测的不确定性,需要加权。

以往方法:

  1. 基于路径:提取携带高阶信息的路径并将其输入预测模型
    其中处理大量路径的方法:
    (1)选择路径------没有对推荐目标的优化;
    (2)约束路径------需要领域知识且费人力,难搞。
  2. 基于正则:设计额外损失隐式捕获KG结构,缺乏显式建模,既不能保证捕获远程连接,也不能解释高阶建模结果。

于是,问题转化为:既要考虑高阶连通,又要考虑邻居加权------KGAT诞生

3. 模型介绍

图2:KGAT模型的示意图。左边的子图显示了KGAT的模型框架,右边的子图显示了KGAT的注意嵌入传播层。

模型主要由三部分组成:

  1. 嵌入层 保持CKG结构,将节点embedding;
  2. 注意力嵌入传播层 递归的传播来自节点邻居的嵌入内容,更新表示,传播的过程中使用注意力机制;
  3. 预测层 传播层的用户表示和项表示聚合起来,并输出预测的匹配分数。

输入:由G1(用户物品交互图)和G2(知识图谱)衔接得来的CKG;

输出:一个预测U、I之间交互概率的预测函数。

3.1 Embedding Layer嵌入层

简要来说就是在CKG上用TransR模型(知识图谱嵌入方法)将实体和关系参数化为向量表示。
主要思路如下:

对于一个知识图谱的表示( <math xmlns="http://www.w3.org/1998/Math/MathML"> h , r , t h,r,t </math>h,r,t),使用下面的假设进行训练:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> e h r + e r ≈ e t r \mathbf{e}_h^r+\mathbf{e}_r\approx\mathbf{e}_t^r </math>ehr+er≈etr

其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> e h , e t ∈ R d e_h,e_t\in\mathbb{R}^d </math>eh,et∈Rd、 <math xmlns="http://www.w3.org/1998/Math/MathML"> e r ∈ R k e_r\in\mathbb{R}^k </math>er∈Rk分别为 <math xmlns="http://www.w3.org/1998/Math/MathML"> h 、 t 、 r h、t、r </math>h、t、r的embedding,而 <math xmlns="http://www.w3.org/1998/Math/MathML"> e h r \mathbf{e}_h^r </math>ehr, <math xmlns="http://www.w3.org/1998/Math/MathML"> e t r \mathbf{e}_t^r </math>etr为 <math xmlns="http://www.w3.org/1998/Math/MathML"> e h e_h </math>eh、 <math xmlns="http://www.w3.org/1998/Math/MathML"> e t e_t </math>et在 <math xmlns="http://www.w3.org/1998/Math/MathML"> r r </math>r所处空间中的投影。
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> g ( h , r , t ) = ∥ W r e h + e r − W r e t ∥ 2 2 , g(h,r,t)=\|\mathbf{W}_r\mathbf{e}_h+\mathbf{e}_r-\mathbf{W}_r\mathbf{e}_t\|_2^2, </math>g(h,r,t)=∥Wreh+er−Wret∥22,

对于每一个得到的三元组的嵌入,使用上面的公式来衡量可信度的得分,得分越低,说明三元组更正确。其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> R k × d \mathbb{R}^{k×d} </math>Rk×d是关系 <math xmlns="http://www.w3.org/1998/Math/MathML"> r r </math>r的变换矩阵,它将实体从 <math xmlns="http://www.w3.org/1998/Math/MathML"> d d </math>d维实体空间投影到 <math xmlns="http://www.w3.org/1998/Math/MathML"> k k </math>k维关系空间。
损失函数的设计: 损失函数基于的假设是: 有效的三元组和无效的三元组之间的置信得分差异最大化,无效三元组g(h,r,t')是通过随机替换有效三元组的一个实体得到。
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> L K G = ∑ ( h , r , t , t ′ ) ∈ T − l n σ ( g ( h , r , t ′ ) − g ( h , r , t ) \mathcal{L}\mathrm{KG}=\sum{(h,r,t,t^{\prime})\in\mathcal{T}}-ln\sigma(g(h,r,t^{\prime})-g(h,r,t) </math>LKG=(h,r,t,t′)∈T∑−lnσ(g(h,r,t′)−g(h,r,t)

该层在三元组的粒度上对实体和关系进行建模,作为正则化器并将直接连接注入到嵌入表示中,从而提高了模型表示能力。

输入:CKG;

输出:实体和关系参数化的向量表示( <math xmlns="http://www.w3.org/1998/Math/MathML"> e u i ( 0 ) e_{ui(0)} </math>eui(0) and <math xmlns="http://www.w3.org/1998/Math/MathML"> e t i ( 0 ) e_{ti(0)} </math>eti(0) and <math xmlns="http://www.w3.org/1998/Math/MathML"> r i r_i </math>ri)

3.2 Attentive Embedding Propagation Layers注意力嵌入传播层

在GCN的基础上,沿着高阶连通性,递归传播嵌入信息。同时使用图注意力网络的思想,生成注意力权值,来描述连接的重要性。 先看单层:

这一层的结构如上图所示,主要分为三个部分:

  • 信息传播(information propagation)
  • 知识意识注意(knowledge-aware attention)
  • 信息聚合(information aggregation)
  1. 信息传播
    同一个实体能够参与多个三元组,并且可以充当两个三元组的桥梁来传播信息,对于下面的一组示例:

<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> e 1 → ⁡ r 2 i 2 → ⁡ − r 1 u 2 \mathrm{e}_1\overset{\mathrm{r}_2}{\operatorname*{\to}}\mathrm{i}_2\overset{\mathrm{-r}_1}{\operatorname*{\to}}\mathrm{u}_2 </math>e1→r2i2→−r1u2
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> e 2 → ⁡ r 3 i 2 → ⁡ − r 1 u 2 \mathrm{e}_2\overset{\mathrm{r}3}{\operatorname*{\to}}\mathrm{i}_2\overset{\mathrm{-r}1}{\operatorname*{\to}}\mathrm{u}_2 </math>e2→r3i2→−r1u2

<math xmlns="http://www.w3.org/1998/Math/MathML"> i 2 i_2 </math>i2可以靠 <math xmlns="http://www.w3.org/1998/Math/MathML"> e 1 e_1 </math>e1、 <math xmlns="http://www.w3.org/1998/Math/MathML"> e 2 e_2 </math>e2两个邻居的属性作为输入来细化自身嵌入,从而递归的细化 <math xmlns="http://www.w3.org/1998/Math/MathML"> u 2 u_2 </math>u2的嵌入:通过 <math xmlns="http://www.w3.org/1998/Math/MathML"> i 2 i_2 </math>i2将信息从 <math xmlns="http://www.w3.org/1998/Math/MathML"> e 1 e_1 </math>e1\ <math xmlns="http://www.w3.org/1998/Math/MathML"> e 2 e_2 </math>e2传播到 <math xmlns="http://www.w3.org/1998/Math/MathML"> u 2 u_2 </math>u2。这样就可以实现邻居间的信息传播。

模拟消息传播过程,现给定一个实体 <math xmlns="http://www.w3.org/1998/Math/MathML"> h h </math>h, <math xmlns="http://www.w3.org/1998/Math/MathML"> N h = ( h , r , t ) ∣ ( h , r , t ) ∈ G N_h = {(h,r,t)|(h,r,t)∈G} </math>Nh=(h,r,t)∣(h,r,t)∈G是其作为头实体的全部三元组集合,为了刻画 <math xmlns="http://www.w3.org/1998/Math/MathML"> h h </math>h的一阶连通性结构,计算 <math xmlns="http://www.w3.org/1998/Math/MathML"> N h N_h </math>Nh的线性组合:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> e N h = ∑ ( h , r , t ) ∈ N π ( h , r , t ) e t e_{N_h}=\sum_{(h,r,t)\in N}\pi(h,r,t)e_t </math>eNh=(h,r,t)∈N∑π(h,r,t)et

这里的 <math xmlns="http://www.w3.org/1998/Math/MathML"> e N h e_{N_h} </math>eNh可以理解为 <math xmlns="http://www.w3.org/1998/Math/MathML"> h h </math>h一阶邻居信息的聚合。其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> π ( h , r , t ) \pi(h,r,t) </math>π(h,r,t)控制着边 <math xmlns="http://www.w3.org/1998/Math/MathML"> ( h , r , t ) (h,r,t) </math>(h,r,t)上每次传播衰减因子,表示有多少信息在 <math xmlns="http://www.w3.org/1998/Math/MathML"> r r </math>r关系下从 <math xmlns="http://www.w3.org/1998/Math/MathML"> t t </math>t传播到 <math xmlns="http://www.w3.org/1998/Math/MathML"> h h </math>h。(也就是权重)

输入:实体 <math xmlns="http://www.w3.org/1998/Math/MathML"> h h </math>h的一阶邻居、 <math xmlns="http://www.w3.org/1998/Math/MathML"> π ( h , r , t ) \pi(h,r,t) </math>π(h,r,t)

输出: <math xmlns="http://www.w3.org/1998/Math/MathML"> h h </math>h一阶邻居信息的聚合(一个向量)

  1. 知识意识注意
    衰减因子 <math xmlns="http://www.w3.org/1998/Math/MathML"> π ( h , r , t ) \pi(h,r,t) </math>π(h,r,t)的实现机制是通过注意力机制来计算的:

<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> π ( h , r , t ) = ( W r e t ) ⊤ tanh ⁡ ( ( W r e h + e r ) ) \pi(h,r,t)=(\mathbf{W}_r\mathbf{e}_t)^{\top}\tanh\biggl(({W}_r\mathbf{e}_h+\mathbf{e}_r)\biggr) </math>π(h,r,t)=(Wret)⊤tanh((Wreh+er))

选择 <math xmlns="http://www.w3.org/1998/Math/MathML"> t a n h ( ) tanh() </math>tanh()作为非线性激活函数。这使得注意力得分依赖于关系 <math xmlns="http://www.w3.org/1998/Math/MathML"> r r </math>r空间中 <math xmlns="http://www.w3.org/1998/Math/MathML"> e h e_h </math>eh和 <math xmlns="http://www.w3.org/1998/Math/MathML"> e t e_t </math>et之间的距离, <math xmlns="http://www.w3.org/1998/Math/MathML"> t t </math>t越靠近 <math xmlns="http://www.w3.org/1998/Math/MathML"> h h </math>h, <math xmlns="http://www.w3.org/1998/Math/MathML"> π \pi </math>π越大(二者离得越近,传的消息越多 )。

作者只在这些表示上使用内积 ,并将注意力模块的进一步探索留作以后的工作。(可进一步改进)

之后,通过采用softmax函数(归一化指数函数)对与 <math xmlns="http://www.w3.org/1998/Math/MathML"> h h </math>h相连的所有三元组的系数进行归一化:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> π ( h , r , t ) = exp ⁡ ( π ( h , r , t ) ) ∑ ( h , r ′ , t ′ ) ∈ N h exp ⁡ ( π ( h , r ′ , t ′ ) ) \pi(h,r,t)=\frac{\exp(\pi(h,r,t))}{\sum_{(h,r',t')\in\mathcal{N}_h}\exp(\pi(h,r',t'))} </math>π(h,r,t)=∑(h,r′,t′)∈Nhexp(π(h,r′,t′))exp(π(h,r,t))

最后的注意得分能够提示哪些邻居节点应该给予更多的权重,也能够作为建议的解释。

  1. 信息聚合
    三种,分别是GCN聚合器、GraphSage聚合器、双交互聚合器。
  • GCN Aggregator:
    GCN聚合器将两个表示相加并应用非线性变换。

<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> f G C N = LeakyReLU ( W ( e h + e N h ) ) f_{\mathrm{GCN}}=\text{LeakyReLU}\Big(\mathrm{W}(\mathbf{e}{h}+\mathbf{e}{\mathcal{N}_{h}})\Big) </math>fGCN=LeakyReLU(W(eh+eNh))

  • GraphSage Aggregator:
    GraphSage聚合器拼接两个表示,然后是非线性转换:

<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> f G r a p h S a g e = LeakyReLU ( W ( e h ∣ ∣ e N h ) ) f_{\mathrm{GraphSage}}=\text{LeakyReLU}\Big(\mathrm{W}(\mathrm{e}{h}||\mathrm{e}{\mathcal{N}_{h}})\Big) </math>fGraphSage=LeakyReLU(W(eh∣∣eNh))

  • Bi-Interaction Aggregator:
    双交互聚合器(作者提出),考虑 <math xmlns="http://www.w3.org/1998/Math/MathML"> e h e_h </math>eh和 <math xmlns="http://www.w3.org/1998/Math/MathML"> e N h e_{N_h} </math>eNh之间的两种特征交互:

<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> f Bi-Interaction = L e a k y R e L U ( W 1 ( e h + e N h ) ) + LeakyReLU ⁡ ( W 2 ( e h ⊙ e N h ) ) f_\text{Bi-Interaction }=LeakyReLU\Big(W_1(e_h+e_{N_h})\Big)+\operatorname{LeakyReLU}\bigl(\mathsf{W}_2(\mathbf{e}h\odot\mathbf{e}{\mathcal{N}_h})\bigr) </math>fBi-Interaction =LeakyReLU(W1(eh+eNh))+LeakyReLU(W2(eh⊙eNh))

<math xmlns="http://www.w3.org/1998/Math/MathML"> W 1 , W 2 ∈ R d ′ × d \begin{array}{rcl}\mathbf{W}_1,\mathbf{W}_2&\in&\mathbb{R}^{d^{\prime}\times d}\end{array} </math>W1,W2∈Rd′×d为可训练权重矩阵, <math xmlns="http://www.w3.org/1998/Math/MathML"> ⊙ ⊙ </math>⊙表示元素乘积。

与GCN和GraphSage聚合器不同,双交互聚合器对 <math xmlns="http://www.w3.org/1998/Math/MathML"> e h e_h </math>eh和 <math xmlns="http://www.w3.org/1998/Math/MathML"> e N h e_{N_h} </math>eNh之间的特征交互进行了编码,这使得传播的信息对 <math xmlns="http://www.w3.org/1998/Math/MathML"> e h e_h </math>eh和 <math xmlns="http://www.w3.org/1998/Math/MathML"> e N h e_{N_h} </math>eNh之间的关联性敏感。(更好)

再看多层:

可以进一步堆叠更多的传播层来探索高阶连接性信息,收集从更高跳邻居传播的信息。更正式地说,我们递归地将实体的表示形式表示为:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> e h ( l ) = f ( e h ( l − 1 ) , e N h ( l − 1 ) ) \mathbf{e}{h}^{(l)}=f\begin{pmatrix}\mathbf{e}{h}^{(l-1)},\mathbf{e}{N{h}}^{(l-1)}\end{pmatrix} </math>eh(l)=f(eh(l−1),eNh(l−1))

其中,在 <math xmlns="http://www.w3.org/1998/Math/MathML"> l − 1 l-1 </math>l−1层网络中,实体 <math xmlns="http://www.w3.org/1998/Math/MathML"> h h </math>h传播的信息定义:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> e N h ( l − 1 ) = ∑ ( h , r , t ) ∈ N h π ( h , r , t ) e t ( l − 1 ) \mathbf{e}{N{h}}^{(l-1)}=\sum_{(h,r,t)\in{\mathcal N}{h}}\pi(h,r,t)\mathbf{e}{t}^{(l-1)} </math>eNh(l−1)=(h,r,t)∈Nh∑π(h,r,t)et(l−1)

<math xmlns="http://www.w3.org/1998/Math/MathML"> e t ( l − 1 ) e_t^{(l-1)} </math>et(l−1)是从先前的信息传播步骤生成的实体 <math xmlns="http://www.w3.org/1998/Math/MathML"> t t </math>t的表示,它存储来自其 <math xmlns="http://www.w3.org/1998/Math/MathML"> ( l − 1 ) (l-1) </math>(l−1)跳邻居的信息; <math xmlns="http://www.w3.org/1998/Math/MathML"> e h ( 0 ) e_h(0) </math>eh(0)集作为初始信息传播迭代的 <math xmlns="http://www.w3.org/1998/Math/MathML"> e h e_h </math>eh。它进一步有助于实体 <math xmlns="http://www.w3.org/1998/Math/MathML"> h h </math>h在层 <math xmlns="http://www.w3.org/1998/Math/MathML"> l l </math>l上的表示。因此,在嵌入传播过程中,可以捕获高阶连接。高阶嵌入传播可以将基于属性的协作信号无缝地注入到表示学习的过程中.

3.3 Model Prediction预测层

在执行 <math xmlns="http://www.w3.org/1998/Math/MathML"> L L </math>L层之后,我们获得了用户节点 <math xmlns="http://www.w3.org/1998/Math/MathML"> u u </math>u的多层表示,即{ <math xmlns="http://www.w3.org/1998/Math/MathML"> e u 1 , ⋅ ⋅ , e u L e_{u_{1}},··,e_{u_{L}} </math>eu1,⋅⋅,euL};类似于项目节点 <math xmlns="http://www.w3.org/1998/Math/MathML"> i i </math>i,{ <math xmlns="http://www.w3.org/1998/Math/MathML"> e i 1 , ⋅ ⋅ ⋅ , e i L e_{i_1},···,e_{i_L} </math>ei1,⋅⋅⋅,eiL}。由于第 <math xmlns="http://www.w3.org/1998/Math/MathML"> l l </math>l层的输出是图1所示根在 <math xmlns="http://www.w3.org/1998/Math/MathML"> u u </math>u(或 <math xmlns="http://www.w3.org/1998/Math/MathML"> i i </math>i)处的 <math xmlns="http://www.w3.org/1998/Math/MathML"> l l </math>l层树结构深度的消息聚合,因此不同层的输出强调不同阶的连接信息。因此,我们采用层聚合机制将每一步的表示连接成一个向量:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> e u ∗ = e u ( 0 ) ∥ ⋯ ∥ e u ( L ) , e i ∗ = e i ( 0 ) ∥ ⋯ ∥ e i ( L ) \mathbf{e}{u}^{*}=\mathbf{e}{u}^{(0)}\|\cdots\|\mathbf{e}{u}^{(L)},\quad\mathbf{e}{i}^{*}=\mathbf{e}{i}^{(0)}\|\cdots\|\mathbf{e}{i}^{(L)} </math>eu∗=eu(0)∥⋯∥eu(L),ei∗=ei(0)∥⋯∥ei(L)

其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> ∣ ∣ || </math>∣∣是拼接操作,通过这样做,我们不仅可以通过执行嵌入传播操作来丰富初始嵌入,还可以通过调整 <math xmlns="http://www.w3.org/1998/Math/MathML"> L L </math>L来控制传播强度。

最后,我们对用户和项目表示进行内积,以预测它们的匹配得分:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> y ^ ( u , i ) = e u ∗ ⊤ e i ∗ . \hat{y}(u,i)=\mathbf{e}_u^*{}^\top\mathbf{e}_i^*. </math>y^(u,i)=eu∗⊤ei∗.

3.4 Optimization

模型的优化使用BPR损失:它假设观察到的交互比未观察到的交互具有更高的预测值,这些交互表明了更多的用户偏好:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> L C F = ∑ ( u , i , j ) ∈ O − ln ⁡ σ ( y ^ ( u , i ) − y ^ ( u , j ) ) \mathcal{L}{\mathrm{CF}}=\sum{(u,i,j)\in O}-\ln\sigma{\left(\hat{y}(u,i)-\hat{y}(u,j)\right)} </math>LCF=(u,i,j)∈O∑−lnσ(y^(u,i)−y^(u,j))

其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> O = { ( u , i , j ) ∣ ( u , i ) ∈ R + , ( u , j ) ∈ R − } O=\{(u,i,j)|(u,i)\in\mathbb{R}^+,(u,j)\in\mathbb{R}^-\} </math>O={(u,i,j)∣(u,i)∈R+,(u,j)∈R−}表示训练集, <math xmlns="http://www.w3.org/1998/Math/MathML"> R + \mathbb{R}^+ </math>R+表示用户 <math xmlns="http://www.w3.org/1998/Math/MathML"> u u </math>u与项目 <math xmlns="http://www.w3.org/1998/Math/MathML"> j j </math>j之间观察到的(正的)交互作用, <math xmlns="http://www.w3.org/1998/Math/MathML"> R − \mathbb{R}^- </math>R−为抽样未观察到的(负的)交互作用集; <math xmlns="http://www.w3.org/1998/Math/MathML"> σ ( ⋅ ) σ(·) </math>σ(⋅)是 <math xmlns="http://www.w3.org/1998/Math/MathML"> s i g m o i d sigmoid </math>sigmoid函数。
总的损失函数为:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> L K G A T = L K G + L C F + λ ∥ Θ ∥ 2 2 , \mathcal{L}{\mathrm{KGAT}}=\mathcal{L}{\mathrm{KG}}+\mathcal{L}{\mathrm{CF}}+\lambda\left\|\Theta\right\|{2}^{2}, </math>LKGAT=LKG+LCF+λ∥Θ∥22,

其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> Θ = { E , W r , ∀ l ∈ R , W 1 ( l ) , W 2 ( l ) , ∀ l ∈ { 1 , ⋯   , L } } \Theta=\{\mathrm{E},\mathbf{W}_r,\forall l\in\mathcal{R},\mathbf{W}1^{(l)},\mathbf{W}2^{(l)},\forall l\in\{1,\cdots,L\}\} </math>Θ={E,Wr,∀l∈R,W1(l),W2(l),∀l∈{1,⋯,L}}是模型参数集, <math xmlns="http://www.w3.org/1998/Math/MathML"> E E </math>E是所有实体和关系的嵌入表;进行由 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ θ </math>θ上的 <math xmlns="http://www.w3.org/1998/Math/MathML"> λ λ </math>λ参数化的L2正则化以防止过拟合。
训练方式: 交替优化 <math xmlns="http://www.w3.org/1998/Math/MathML"> L K G \mathcal{L}
\mathrm{KG} </math>LKG和 <math xmlns="http://www.w3.org/1998/Math/MathML"> L C F \mathcal{L}
{\mathrm{CF}} </math>LCF,优化器使用Adam。

4. 模型验证

从三部分对模型进行验证:
RQ1性能比较 ------与最先进的知识感知推荐方法相比,KGAT的表现如何?
RQ2消融研究 ------不同的组分(例如,知识图嵌入、注意机制和聚合器选择)如何影响KGAT的性能?
RQ3实例验证------KGAT能否合理解释用户对物品的偏好?

4.1 数据集

在三个数据集上进行评估:

  1. Amazon-book 书籍推荐
  2. Last-FM 音乐推荐
  3. Yelp2018 餐馆和酒吧等视为item

需要为每一个数据集构建知识: 对于Amazon-book和Last-FM,遵循[KB4Rec: A Data Set for Linking Knowledge Bases with Recommender Systems]中的方法 如果有可用的映射,则通过标题匹配将条目映射到Freebase实体。考虑与项目对其的实体直接相关的三元组。同时还考虑了两跳的邻居三元组。

对于Yelp2018,从本地商业信息网络中提取项目知识(如类别、位置和属性)作为KG数据。

对三个知识图谱,取出数据集中出现频率低的实体,对于每个数据集,随机选择80%作为训练集,其余是测试集。

对于交互数据,每一个交互过的实例,随机抽取一个未消费过的负物品配对。

4.2 问题1的性能比较

KGAT始终在所有数据集上产生最佳性能。特别是,KGAT比亚马逊图书、LastFM和Yelp2018的最强基线w.r.t. recall@20分别提高了8.95%、4.93%和7.18%。KGAT通过叠加多个注意力嵌入传播层,能够显式地探索高阶连接,从而有效地捕获协同信号。这验证了捕获协同信号传递知识的重要性。此外,与GC-MC相比,KGAT对注意机制的有效性进行了验证,指出注意权重是根据成分语义关系确定的,而不是GC-MC中使用的固定权重。

  • KGAT在大多数情况下优于其他模型,特别是在Amazon-Book和Yelp2018中两个最稀疏的用户群上。再次验证了高阶连接建模的重要性,高阶连接建模包含了基线中使用的低阶连接;高阶连接建模通过递归嵌入传播丰富了非活跃用户的表示。
  • 值得指出的是,KGAT在最密集的用户组(例如,Yelp2018的< 2057组)中表现略好于一些基线。一个可能的原因是,有太多交互的用户的偏好太普遍,难以捕获。高阶连接可能会给用户的偏好带来更多的干扰,从而导致负面影响。

4.3 KGAT的研究

分别从层数、聚合器、知识图谱嵌入和注意力机制的影响。 在不同的传播深度下,模型的效果验证:

聚合函数的影响

知识图谱嵌入效应与注意机制 消融实验:通过禁用KGAT模型的几个部分来进行研究

  1. w/o KGE 禁用TransR(去除知识图谱嵌入)
  2. w/o Att 把 <math xmlns="http://www.w3.org/1998/Math/MathML"> π ( h , r , t ) π(h, r, t) </math>π(h,r,t)设为 <math xmlns="http://www.w3.org/1998/Math/MathML"> 1 / ∣ N h ∣ 1/|N_h| </math>1/∣Nh∣
  3. w/o K&A 综合前两种
  • 去掉知识图嵌入和注意成分会降低模型的性能。KGAT-1w/o k&A 始终不如KGAT-1w/o kge 和KGAT-1w/o Att。这是有意义的,因为KGATw/o k&A未能显式地在三元组的粒度上建模表示关系。
  • 与KGAT-1w/o Att相比,KGAT-1w/o kgeat在大多数情况下性能更好。一个可能的原因是平等对待所有的邻居(即KGAT-1w/o Att)可能会引入噪声并误导嵌入传播过程。验证了图形注意机制的实质影响。

三、代码分析

1. 数据集文件

  • train.txt:userId,itemId(list)
  • test.txt:userId,itemId(list)
  • user_list.txt:org_id,remap_id(原数据集ID、本文使用ID)
  • item_list.txt:org_id,remap_id,freebase_id(原数据集ID、本文使用ID、freebaseID)
  • entity_list.txt:freebase_id,remap_id(freebaseID、本文使用ID)

item_list里面的数据包含在entity_list中 以Last-FM为例: Items有48123,那么item_list.txt有48123行(单纯的音乐item) Entitys有58266,那么entity_list.txt有48123+58266=106389行(音乐item+外部信息(如作词作曲?))

  • relation_list.txtfreebase_id,remap_id(freebaseID、本文使用ID)

freebase_id 是不重要的,只要管remap_id

  • kg_final.txtentityId、relationId、entityId

第一列和第三列都是entity的id(音乐、作词、作曲) 例如: 孤独之书 原唱 房东的猫 房东的猫 职业 歌手

2. 数据集的处理

  • 将kg添加逆关系,并对关系重新编号,做法是+2;将user-item交互图融入kg中,将user重新编号user id+实体总数,将user-item编码为0,将item-user编码为1
  • 采样一个batch_size的数据,包含bath_size的user,以及为每一个user采样user-item交互的正样例,负样例
  • 对产生的CKG图{h:(r,t)}进行采样生成负例正例。

loader_base.py

python 复制代码
 
    def load_cf(self, filename):
        """
        函数说明:对user-item交互矩阵进行处理
        Return:
            (user, item) - user和其作用的item
            user_dict - {user-id:[item1,item2,..],}
        """
        user = []
        item = []
        user_dict = dict()
 
        lines = open(filename, 'r').readlines()
        for l in lines:
            tmp = l.strip()
            inter = [int(i) for i in tmp.split()]
 
            if len(inter) > 1:
                user_id, item_ids = inter[0], inter[1:]
                item_ids = list(set(item_ids))
 
                for item_id in item_ids:
                    user.append(user_id)
                    item.append(item_id)
                user_dict[user_id] = item_ids
 
        user = np.array(user, dtype=np.int32)
        item = np.array(item, dtype=np.int32)
        return (user, item), user_dict
 
    def statistic_cf(self):
        """
        获取user、item、训练集、测试集总数
        """
        self.n_users = max(max(self.cf_train_data[0]), max(self.cf_test_data[0])) + 1
        self.n_items = max(max(self.cf_train_data[1]), max(self.cf_test_data[1])) + 1
        self.n_cf_train = len(self.cf_train_data[0])
        self.n_cf_test = len(self.cf_test_data[0])
 
 
    def load_kg(self, filename):
        """
        读取最后的CKG数据,返回dataframe形式
        """
        kg_data = pd.read_csv(filename, sep=' ', names=['h', 'r', 't'], engine='python')
        kg_data = kg_data.drop_duplicates()
        return kg_data
 
 
    def sample_pos_items_for_u(self, user_dict, user_id, n_sample_pos_items):
        """
        对user-item交互正样本进行采样
        """
        pos_items = user_dict[user_id]
        n_pos_items = len(pos_items)
 
        sample_pos_items = []
        while True:
            if len(sample_pos_items) == n_sample_pos_items:
                break
 
            pos_item_idx = np.random.randint(low=0, high=n_pos_items, size=1)[0]
            pos_item_id = pos_items[pos_item_idx]
            if pos_item_id not in sample_pos_items:
                sample_pos_items.append(pos_item_id)
        return sample_pos_items
 
 
    def sample_neg_items_for_u(self, user_dict, user_id, n_sample_neg_items):
        """
        为user-item交互采样负样例
        """
        pos_items = user_dict[user_id]
 
        sample_neg_items = []
        while True:
            if len(sample_neg_items) == n_sample_neg_items:
                break
 
            neg_item_id = np.random.randint(low=0, high=self.n_items, size=1)[0]
            if neg_item_id not in pos_items and neg_item_id not in sample_neg_items:
                sample_neg_items.append(neg_item_id)
        return sample_neg_items
 
 
    def generate_cf_batch(self, user_dict, batch_size):
        """
        采样batch_size的user,并对对这些user采样正样本,负样本
        """
        exist_users = user_dict.keys()
        if batch_size <= len(exist_users):
            batch_user = random.sample(exist_users, batch_size)
        else:
            batch_user = [random.choice(exist_users) for _ in range(batch_size)]
 
        batch_pos_item, batch_neg_item = [], []
        for u in batch_user:
            # 为每一个采样的user生成一个正样例和一个负样例
            batch_pos_item += self.sample_pos_items_for_u(user_dict, u, 1)
            batch_neg_item += self.sample_neg_items_for_u(user_dict, u, 1)
 
        batch_user = torch.LongTensor(batch_user)
        batch_pos_item = torch.LongTensor(batch_pos_item)
        batch_neg_item = torch.LongTensor(batch_neg_item)
        return batch_user, batch_pos_item, batch_neg_item
 
 
    def sample_pos_triples_for_h(self, kg_dict, head, n_sample_pos_triples):
        """
        为融合user-item交互的CKG图采样正例
        """
        pos_triples = kg_dict[head]
        n_pos_triples = len(pos_triples)
 
        sample_relations, sample_pos_tails = [], []
        while True:
            if len(sample_relations) == n_sample_pos_triples:
                break
 
            pos_triple_idx = np.random.randint(low=0, high=n_pos_triples, size=1)[0]
            tail = pos_triples[pos_triple_idx][0]
            relation = pos_triples[pos_triple_idx][1]
 
            if relation not in sample_relations and tail not in sample_pos_tails:
                sample_relations.append(relation)
                sample_pos_tails.append(tail)
        return sample_relations, sample_pos_tails
 
 
    def sample_neg_triples_for_h(self, kg_dict, head, relation, n_sample_neg_triples, highest_neg_idx):
        """
        为融合user-item交互的CKG图采样负例
        """
        pos_triples = kg_dict[head]
 
        sample_neg_tails = []
        while True:
            if len(sample_neg_tails) == n_sample_neg_triples:
                break
 
            tail = np.random.randint(low=0, high=highest_neg_idx, size=1)[0]
            if (tail, relation) not in pos_triples and tail not in sample_neg_tails:
                sample_neg_tails.append(tail)
        return sample_neg_tails
 
 
    def generate_kg_batch(self, kg_dict, batch_size, highest_neg_idx):
        """为训练集CKG中每一个头实体采样一个正例的(r,t),一个负例的t"""
        exist_heads = kg_dict.keys()
        if batch_size <= len(exist_heads):
            batch_head = random.sample(exist_heads, batch_size)
        else:
            batch_head = [random.choice(exist_heads) for _ in range(batch_size)]
 
        batch_relation, batch_pos_tail, batch_neg_tail = [], [], []
        for h in batch_head:
            relation, pos_tail = self.sample_pos_triples_for_h(kg_dict, h, 1)
            batch_relation += relation
            batch_pos_tail += pos_tail
 
            neg_tail = self.sample_neg_triples_for_h(kg_dict, h, relation[0], 1, highest_neg_idx)
            batch_neg_tail += neg_tail
 
        batch_head = torch.LongTensor(batch_head)
        batch_relation = torch.LongTensor(batch_relation)
        batch_pos_tail = torch.LongTensor(batch_pos_tail)
        batch_neg_tail = torch.LongTensor(batch_neg_tail)
        return batch_head, batch_relation, batch_pos_tail, batch_neg_tail

loader_kgat.py

python 复制代码
    def construct_data(self, kg_data):
        """
        函数说明:创建逆边,并把user-item交互图融入,创建CKG
        """
        # add inverse kg data
        n_relations = max(kg_data['r']) + 1
        inverse_kg_data = kg_data.copy()
        inverse_kg_data = inverse_kg_data.rename({'h': 't', 't': 'h'}, axis='columns')
        inverse_kg_data['r'] += n_relations
        kg_data = pd.concat([kg_data, inverse_kg_data], axis=0, ignore_index=True, sort=False)
 
        # re-map user id
        kg_data['r'] += 2
        self.n_relations = max(kg_data['r']) + 1
        self.n_entities = max(max(kg_data['h']), max(kg_data['t'])) + 1
        self.n_users_entities = self.n_users + self.n_entities
 
        # re-map user id = user-item中的id + num_entities
        self.cf_train_data = (np.array(list(map(lambda d: d + self.n_entities, self.cf_train_data[0]))).astype(np.int32), self.cf_train_data[1].astype(np.int32))
        self.cf_test_data = (np.array(list(map(lambda d: d + self.n_entities, self.cf_test_data[0]))).astype(np.int32), self.cf_test_data[1].astype(np.int32))
 
        self.train_user_dict = {k + self.n_entities: np.unique(v).astype(np.int32) for k, v in self.train_user_dict.items()}
        self.test_user_dict = {k + self.n_entities: np.unique(v).astype(np.int32) for k, v in self.test_user_dict.items()}
 
        # add interactions to kg data
        # 将user-item交互数据融入kg中user交互item的关系编码为0,item-user交互编码为1
        cf2kg_train_data = pd.DataFrame(np.zeros((self.n_cf_train, 3), dtype=np.int32), columns=['h', 'r', 't'])
        cf2kg_train_data['h'] = self.cf_train_data[0]
        cf2kg_train_data['t'] = self.cf_train_data[1]
 
        inverse_cf2kg_train_data = pd.DataFrame(np.ones((self.n_cf_train, 3), dtype=np.int32), columns=['h', 'r', 't'])
        inverse_cf2kg_train_data['h'] = self.cf_train_data[1]
        inverse_cf2kg_train_data['t'] = self.cf_train_data[0]
 
        self.kg_train_data = pd.concat([kg_data, cf2kg_train_data, inverse_cf2kg_train_data], ignore_index=True)
        self.n_kg_train = len(self.kg_train_data)
 
        # construct kg dict
        h_list = []
        t_list = []
        r_list = []
 
        self.train_kg_dict = collections.defaultdict(list)
        self.train_relation_dict = collections.defaultdict(list)
 
        for row in self.kg_train_data.iterrows():
            h, r, t = row[1]
            h_list.append(h)
            t_list.append(t)
            r_list.append(r)
 
            self.train_kg_dict[h].append((t, r))
            self.train_relation_dict[r].append((h, t))
 
        self.h_list = torch.LongTensor(h_list)
        self.t_list = torch.LongTensor(t_list)
        self.r_list = torch.LongTensor(r_list)

KGAT.py

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F
 
 
def _L2_loss_mean(x):
    return torch.mean(torch.sum(torch.pow(x, 2), dim=1, keepdim=False) / 2.)
 
 
class Aggregator(nn.Module):
 
    def __init__(self, in_dim, out_dim, dropout, aggregator_type):
        super(Aggregator, self).__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.dropout = dropout
        self.aggregator_type = aggregator_type
 
        self.message_dropout = nn.Dropout(dropout)
        self.activation = nn.LeakyReLU()
 
        if self.aggregator_type == 'gcn':
            self.linear = nn.Linear(self.in_dim, self.out_dim)       # W in Equation (6)
            nn.init.xavier_uniform_(self.linear.weight)
 
        elif self.aggregator_type == 'graphsage':
            self.linear = nn.Linear(self.in_dim * 2, self.out_dim)   # W in Equation (7)
            nn.init.xavier_uniform_(self.linear.weight)
 
        elif self.aggregator_type == 'bi-interaction':
            self.linear1 = nn.Linear(self.in_dim, self.out_dim)      # W1 in Equation (8)
            self.linear2 = nn.Linear(self.in_dim, self.out_dim)      # W2 in Equation (8)
            nn.init.xavier_uniform_(self.linear1.weight)
            nn.init.xavier_uniform_(self.linear2.weight)
 
        else:
            raise NotImplementedError
 
 
    def forward(self, ego_embeddings, A_in):
        """
        ego_embeddings:  (n_users + n_entities, in_dim)
        A_in:            (n_users + n_entities, n_users + n_entities), torch.sparse.FloatTensor
        """
        # Equation (3)
        side_embeddings = torch.matmul(A_in, ego_embeddings)
 
        if self.aggregator_type == 'gcn':
            # Equation (6) & (9)
            embeddings = ego_embeddings + side_embeddings
            embeddings = self.activation(self.linear(embeddings))
 
        elif self.aggregator_type == 'graphsage':
            # Equation (7) & (9)
            embeddings = torch.cat([ego_embeddings, side_embeddings], dim=1)
            embeddings = self.activation(self.linear(embeddings))
 
        elif self.aggregator_type == 'bi-interaction':
            # Equation (8) & (9)
            sum_embeddings = self.activation(self.linear1(ego_embeddings + side_embeddings))
            bi_embeddings = self.activation(self.linear2(ego_embeddings * side_embeddings))
            embeddings = bi_embeddings + sum_embeddings
 
        embeddings = self.message_dropout(embeddings)           # (n_users + n_entities, out_dim)
        return embeddings
 
 
class KGAT(nn.Module):
 
    def __init__(self, args,
                 n_users, n_entities, n_relations, A_in=None,
                 user_pre_embed=None, item_pre_embed=None):
 
        super(KGAT, self).__init__()
        self.use_pretrain = args.use_pretrain
 
        self.n_users = n_users
        self.n_entities = n_entities
        self.n_relations = n_relations
 
        self.embed_dim = args.embed_dim
        self.relation_dim = args.relation_dim
 
        self.aggregation_type = args.aggregation_type
        self.conv_dim_list = [args.embed_dim] + eval(args.conv_dim_list)
        self.mess_dropout = eval(args.mess_dropout)
        self.n_layers = len(eval(args.conv_dim_list))
 
        self.kg_l2loss_lambda = args.kg_l2loss_lambda
        self.cf_l2loss_lambda = args.cf_l2loss_lambda
 
        self.entity_user_embed = nn.Embedding(self.n_entities + self.n_users, self.embed_dim)
        self.relation_embed = nn.Embedding(self.n_relations, self.relation_dim)
        self.trans_M = nn.Parameter(torch.Tensor(self.n_relations, self.embed_dim, self.relation_dim))
 
        if (self.use_pretrain == 1) and (user_pre_embed is not None) and (item_pre_embed is not None):
            other_entity_embed = nn.Parameter(torch.Tensor(self.n_entities - item_pre_embed.shape[0], self.embed_dim))
            nn.init.xavier_uniform_(other_entity_embed)
            entity_user_embed = torch.cat([item_pre_embed, other_entity_embed, user_pre_embed], dim=0)
            self.entity_user_embed.weight = nn.Parameter(entity_user_embed)
        else:
            nn.init.xavier_uniform_(self.entity_user_embed.weight)
 
        nn.init.xavier_uniform_(self.relation_embed.weight)
        nn.init.xavier_uniform_(self.trans_M)
 
        self.aggregator_layers = nn.ModuleList()
        for k in range(self.n_layers):
            self.aggregator_layers.append(Aggregator(self.conv_dim_list[k], self.conv_dim_list[k + 1], self.mess_dropout[k], self.aggregation_type))
 
        # A是邻接矩阵
        self.A_in = nn.Parameter(torch.sparse.FloatTensor(self.n_users + self.n_entities, self.n_users + self.n_entities))
        if A_in is not None:
            self.A_in.data = A_in
        self.A_in.requires_grad = False
 
 
    def calc_cf_embeddings(self):
        """
        计算多层的消息传递和聚合
        """
        ego_embed = self.entity_user_embed.weight
        all_embed = [ego_embed]
 
        for idx, layer in enumerate(self.aggregator_layers):
            ego_embed = layer(ego_embed, self.A_in)
            norm_embed = F.normalize(ego_embed, p=2, dim=1)
            all_embed.append(norm_embed)
 
        # Equation (11)
        all_embed = torch.cat(all_embed, dim=1)         # (n_users + n_entities, concat_dim)
        return all_embed
 
 
    def calc_cf_loss(self, user_ids, item_pos_ids, item_neg_ids):
        """
        user_ids:       (cf_batch_size)
        item_pos_ids:   (cf_batch_size)
        item_neg_ids:   (cf_batch_size)
        """
        all_embed = self.calc_cf_embeddings()                       # (n_users + n_entities, concat_dim)
        user_embed = all_embed[user_ids]                            # (cf_batch_size, concat_dim)
        item_pos_embed = all_embed[item_pos_ids]                    # (cf_batch_size, concat_dim)
        item_neg_embed = all_embed[item_neg_ids]                    # (cf_batch_size, concat_dim)
 
        # Equation (12)
        pos_score = torch.sum(user_embed * item_pos_embed, dim=1)   # (cf_batch_size)
        neg_score = torch.sum(user_embed * item_neg_embed, dim=1)   # (cf_batch_size)
 
        # Equation (13)
        # cf_loss = F.softplus(neg_score - pos_score)
        cf_loss = (-1.0) * F.logsigmoid(pos_score - neg_score)
        cf_loss = torch.mean(cf_loss)
 
        l2_loss = _L2_loss_mean(user_embed) + _L2_loss_mean(item_pos_embed) + _L2_loss_mean(item_neg_embed)
        loss = cf_loss + self.cf_l2loss_lambda * l2_loss
        return loss
 
 
    def calc_kg_loss(self, h, r, pos_t, neg_t):
        """
        h:      (kg_batch_size)
        r:      (kg_batch_size)
        pos_t:  (kg_batch_size)
        neg_t:  (kg_batch_size)
        """
        r_embed = self.relation_embed(r)                                                # (kg_batch_size, relation_dim)
        W_r = self.trans_M[r]                                                           # (kg_batch_size, embed_dim, relation_dim)
 
        h_embed = self.entity_user_embed(h)                                             # (kg_batch_size, embed_dim)
        pos_t_embed = self.entity_user_embed(pos_t)                                     # (kg_batch_size, embed_dim)
        neg_t_embed = self.entity_user_embed(neg_t)                                     # (kg_batch_size, embed_dim)
 
        r_mul_h = torch.bmm(h_embed.unsqueeze(1), W_r).squeeze(1)                       # (kg_batch_size, relation_dim)
        r_mul_pos_t = torch.bmm(pos_t_embed.unsqueeze(1), W_r).squeeze(1)               # (kg_batch_size, relation_dim)
        r_mul_neg_t = torch.bmm(neg_t_embed.unsqueeze(1), W_r).squeeze(1)               # (kg_batch_size, relation_dim)
 
        # Equation (1)
        pos_score = torch.sum(torch.pow(r_mul_h + r_embed - r_mul_pos_t, 2), dim=1)     # (kg_batch_size)
        neg_score = torch.sum(torch.pow(r_mul_h + r_embed - r_mul_neg_t, 2), dim=1)     # (kg_batch_size)
 
        # Equation (2)
        # kg_loss = F.softplus(pos_score - neg_score)
        kg_loss = (-1.0) * F.logsigmoid(neg_score - pos_score)
        kg_loss = torch.mean(kg_loss)
 
        l2_loss = _L2_loss_mean(r_mul_h) + _L2_loss_mean(r_embed) + _L2_loss_mean(r_mul_pos_t) + _L2_loss_mean(r_mul_neg_t)
        loss = kg_loss + self.kg_l2loss_lambda * l2_loss
        return loss
 
 
    def update_attention_batch(self, h_list, t_list, r_idx):
        """
        更新注意力权重
        """
        r_embed = self.relation_embed.weight[r_idx]
        W_r = self.trans_M[r_idx]
 
        h_embed = self.entity_user_embed.weight[h_list]
        t_embed = self.entity_user_embed.weight[t_list]
 
        # Equation (4)
        r_mul_h = torch.matmul(h_embed, W_r)
        r_mul_t = torch.matmul(t_embed, W_r)
        v_list = torch.sum(r_mul_t * torch.tanh(r_mul_h + r_embed), dim=1)
        return v_list
 
 
    def update_attention(self, h_list, t_list, r_list, relations):
        device = self.A_in.device
 
        rows = []
        cols = []
        values = []
 
        for r_idx in relations:
            index_list = torch.where(r_list == r_idx)
            batch_h_list = h_list[index_list]
            batch_t_list = t_list[index_list]
 
            batch_v_list = self.update_attention_batch(batch_h_list, batch_t_list, r_idx)
            rows.append(batch_h_list)
            cols.append(batch_t_list)
            values.append(batch_v_list)
 
        rows = torch.cat(rows)
        cols = torch.cat(cols)
        values = torch.cat(values)
 
        indices = torch.stack([rows, cols])
        shape = self.A_in.shape
        A_in = torch.sparse.FloatTensor(indices, values, torch.Size(shape))
 
        # Equation (5)
        A_in = torch.sparse.softmax(A_in.cpu(), dim=1)
        self.A_in.data = A_in.to(device)
 
 
    def calc_score(self, user_ids, item_ids):
        """
        user_ids:  (n_users)
        item_ids:  (n_items)
        计算user点击item的得分
        """
        all_embed = self.calc_cf_embeddings()           # (n_users + n_entities, concat_dim)
        user_embed = all_embed[user_ids]                # (n_users, concat_dim)
        item_embed = all_embed[item_ids]                # (n_items, concat_dim)
 
        # Equation (12)
        cf_score = torch.matmul(user_embed, item_embed.transpose(0, 1))    # (n_users, n_items)
        return cf_score
 
 
    def forward(self, *input, mode):
        if mode == 'train_cf':
            return self.calc_cf_loss(*input)
        if mode == 'train_kg':
            return self.calc_kg_loss(*input)
        if mode == 'update_att':
            return self.update_attention(*input)
        if mode == 'predict':
            return self.calc_score(*input)

3. 模型训练

主要包括交替训练CF与KGC两个任务,并在每次交替训练后更新消息传递的权重。 main_kgat.py

ini 复制代码
def train(args):
    # seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
 
    log_save_id = create_log_id(args.save_dir)
    logging_config(folder=args.save_dir, name='log{:d}'.format(log_save_id), no_console=False)
    logging.info(args)
 
    # GPU / CPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
    # load data
    data = DataLoaderKGAT(args, logging)
    if args.use_pretrain == 1:
        user_pre_embed = torch.tensor(data.user_pre_embed)
        item_pre_embed = torch.tensor(data.item_pre_embed)
    else:
        user_pre_embed, item_pre_embed = None, None
 
    # construct model & optimizer
    model = KGAT(args, data.n_users, data.n_entities, data.n_relations, data.A_in, user_pre_embed, item_pre_embed)
    if args.use_pretrain == 2:
        model = load_model(model, args.pretrain_model_path)
 
    model.to(device)
    logging.info(model)
 
    cf_optimizer = optim.Adam(model.parameters(), lr=args.lr)
    kg_optimizer = optim.Adam(model.parameters(), lr=args.lr)
 
    # initialize metrics
    best_epoch = -1
    best_recall = 0
 
    Ks = eval(args.Ks)
    k_min = min(Ks)
    k_max = max(Ks)
 
    epoch_list = []
    metrics_list = {k: {'precision': [], 'recall': [], 'ndcg': []} for k in Ks}
 
    # train model
    for epoch in range(1, args.n_epoch + 1):
        time0 = time()
        model.train()
 
        # train cf
        time1 = time()
        cf_total_loss = 0
        n_cf_batch = data.n_cf_train // data.cf_batch_size + 1
        # 交替训练CF与KGC
        for iter in range(1, n_cf_batch + 1):
            time2 = time()
            # 采样一个cf_batch_size的user list,并为user list中的每一个user采样一个正样例和负样例。
            cf_batch_user, cf_batch_pos_item, cf_batch_neg_item = data.generate_cf_batch(data.train_user_dict, data.cf_batch_size)
            cf_batch_user = cf_batch_user.to(device)
            cf_batch_pos_item = cf_batch_pos_item.to(device)
            cf_batch_neg_item = cf_batch_neg_item.to(device)
 
 
            cf_batch_loss = model(cf_batch_user, cf_batch_pos_item, cf_batch_neg_item, mode='train_cf')
 
            if np.isnan(cf_batch_loss.cpu().detach().numpy()):
                logging.info('ERROR (CF Training): Epoch {:04d} Iter {:04d} / {:04d} Loss is nan.'.format(epoch, iter, n_cf_batch))
                sys.exit()
 
            cf_batch_loss.backward()
            cf_optimizer.step()
            cf_optimizer.zero_grad()
            cf_total_loss += cf_batch_loss.item()
 
            if (iter % args.cf_print_every) == 0:
                logging.info('CF Training: Epoch {:04d} Iter {:04d} / {:04d} | Time {:.1f}s | Iter Loss {:.4f} | Iter Mean Loss {:.4f}'.format(epoch, iter, n_cf_batch, time() - time2, cf_batch_loss.item(), cf_total_loss / iter))
        logging.info('CF Training: Epoch {:04d} Total Iter {:04d} | Total Time {:.1f}s | Iter Mean Loss {:.4f}'.format(epoch, n_cf_batch, time() - time1, cf_total_loss / n_cf_batch))
 
        # train kg
        time3 = time()
        kg_total_loss = 0
        n_kg_batch = data.n_kg_train // data.kg_batch_size + 1
 
        for iter in range(1, n_kg_batch + 1):
            time4 = time()
            kg_batch_head, kg_batch_relation, kg_batch_pos_tail, kg_batch_neg_tail = data.generate_kg_batch(data.train_kg_dict, data.kg_batch_size, data.n_users_entities)
            kg_batch_head = kg_batch_head.to(device)
            kg_batch_relation = kg_batch_relation.to(device)
            kg_batch_pos_tail = kg_batch_pos_tail.to(device)
            kg_batch_neg_tail = kg_batch_neg_tail.to(device)
 
            kg_batch_loss = model(kg_batch_head, kg_batch_relation, kg_batch_pos_tail, kg_batch_neg_tail, mode='train_kg')
 
            if np.isnan(kg_batch_loss.cpu().detach().numpy()):
                logging.info('ERROR (KG Training): Epoch {:04d} Iter {:04d} / {:04d} Loss is nan.'.format(epoch, iter, n_kg_batch))
                sys.exit()
 
            kg_batch_loss.backward()
            kg_optimizer.step()
            kg_optimizer.zero_grad()
            kg_total_loss += kg_batch_loss.item()
 
            if (iter % args.kg_print_every) == 0:
                logging.info('KG Training: Epoch {:04d} Iter {:04d} / {:04d} | Time {:.1f}s | Iter Loss {:.4f} | Iter Mean Loss {:.4f}'.format(epoch, iter, n_kg_batch, time() - time4, kg_batch_loss.item(), kg_total_loss / iter))
        logging.info('KG Training: Epoch {:04d} Total Iter {:04d} | Total Time {:.1f}s | Iter Mean Loss {:.4f}'.format(epoch, n_kg_batch, time() - time3, kg_total_loss / n_kg_batch))
        # 交替训练完一次更新注意力权重
        # update attention
        time5 = time()
        # h_list/t_list/r_list是CKG图中所有的头实体、关系、尾实体列表
        h_list = data.h_list.to(device)
        t_list = data.t_list.to(device)
        r_list = data.r_list.to(device)
        relations = list(data.laplacian_dict.keys())
        model(h_list, t_list, r_list, relations, mode='update_att')
        logging.info('Update Attention: Epoch {:04d} | Total Time {:.1f}s'.format(epoch, time() - time5))
 
        logging.info('CF + KG Training: Epoch {:04d} | Total Time {:.1f}s'.format(epoch, time() - time0))
 
        # evaluate cf
        if (epoch % args.evaluate_every) == 0 or epoch == args.n_epoch:
            time6 = time()
            _, metrics_dict = evaluate(model, data, Ks, device)
            logging.info('CF Evaluation: Epoch {:04d} | Total Time {:.1f}s | Precision [{:.4f}, {:.4f}], Recall [{:.4f}, {:.4f}], NDCG [{:.4f}, {:.4f}]'.format(
                epoch, time() - time6, metrics_dict[k_min]['precision'], metrics_dict[k_max]['precision'], metrics_dict[k_min]['recall'], metrics_dict[k_max]['recall'], metrics_dict[k_min]['ndcg'], metrics_dict[k_max]['ndcg']))
 
            epoch_list.append(epoch)
            for k in Ks:
                for m in ['precision', 'recall', 'ndcg']:
                    metrics_list[k][m].append(metrics_dict[k][m])
            best_recall, should_stop = early_stopping(metrics_list[k_min]['recall'], args.stopping_steps)
 
            if should_stop:
                break
 
            if metrics_list[k_min]['recall'].index(best_recall) == len(epoch_list) - 1:
                save_model(model, args.save_dir, epoch, best_epoch)
                logging.info('Save model on epoch {:04d}!'.format(epoch))
                best_epoch = epoch
 
    # save metrics
    metrics_df = [epoch_list]
    metrics_cols = ['epoch_idx']
    for k in Ks:
        for m in ['precision', 'recall', 'ndcg']:
            metrics_df.append(metrics_list[k][m])
            metrics_cols.append('{}@{}'.format(m, k))
    metrics_df = pd.DataFrame(metrics_df).transpose()
    metrics_df.columns = metrics_cols
    metrics_df.to_csv(args.save_dir + '/metrics.tsv', sep='\t', index=False)
 
    # print best metrics
    best_metrics = metrics_df.loc[metrics_df['epoch_idx'] == best_epoch].iloc[0].to_dict()
    logging.info('Best CF Evaluation: Epoch {:04d} | Precision [{:.4f}, {:.4f}], Recall [{:.4f}, {:.4f}], NDCG [{:.4f}, {:.4f}]'.format(
        int(best_metrics['epoch_idx']), best_metrics['precision@{}'.format(k_min)], best_metrics['precision@{}'.format(k_max)], best_metrics['recall@{}'.format(k_min)], best_metrics['recall@{}'.format(k_max)], best_metrics['ndcg@{}'.format(k_min)], best_metrics['ndcg@{}'.format(k_max)]))
 

Ref

  1. 推荐系统读书笔记(一):KGAT
相关推荐
云泽8084 分钟前
深入 AVL 树:原理剖析、旋转算法与性能评估
数据结构·c++·算法
CareyWYR6 分钟前
每周AI论文速递(260323-260327)
人工智能
guoji778828 分钟前
安全与对齐的深层博弈:Gemini 3.1 Pro 安全护栏与对抗测试深度拆解
人工智能·安全
实在智能RPA36 分钟前
实在 Agent 和通用大模型有什么不一样?深度拆解 AI Agent 的感知、决策与执行逻辑
人工智能·ai
独隅40 分钟前
PyTorch 模型部署的 Docker 配置与性能调优深入指南
人工智能·pytorch·docker
lihuayong1 小时前
OpenClaw 系统提示词
人工智能·prompt·提示词·openclaw
Wilber的技术分享1 小时前
【LeetCode高频手撕题 2】面试中常见的手撕算法题(小红书)
笔记·算法·leetcode·面试
邪神与厨二病1 小时前
Problem L. ZZUPC
c++·数学·算法·前缀和
黑客说1 小时前
AI驱动剧情,解锁无限可能——AI游戏发展解析
人工智能·游戏