GraphPro: Graph Pre-training and Prompt Learning for Recommendation
北京B区 / 032机
北京B区 / 224机
数据集介绍
本文使用了三个公开的数据集来进行实验和评估GraphPro框架的性能。这些数据集分别代表了不同的商业场景和用户交互模式,具有丰富的动态交互数据,非常适合用来评估推荐系统在处理时间变化数据时的有效性。以下是对这些数据集的详细介绍:
1. Taobao数据集
- 来源:Taobao是中国最大的电子商务平台之一,数据集收集了用户在该平台上的隐式反馈数据。
- 时间跨度:数据集覆盖了10天的时间段。
- 数据特点:包含了大量的用户点击和购买行为,这些行为被用作推荐系统的训练和测试数据。
- 规模:具有大量的用户和物品,以及相应的交互记录。
2. Koubei数据集
- 来源:Koubei是支付宝平台上的一个本地生活服务平台,数据集记录了用户与附近商家的互动。
- 时间跨度:数据集涵盖了9周的时间段。
- 数据特点:包含了用户对商家的评分、评论和交易等信息,反映了用户对本地服务的偏好和选择。
- 规模:包含了一定数量的用户和商家,以及用户对商家的各种互动。
3. Amazon数据集
- 来源:Amazon是全球知名的电子商务平台,数据集包含了用户对产品的评价信息。
- 时间跨度:数据集收集了13周内的用户评价数据。
- 数据特点:包含了用户对产品的评分和评论,这些数据对于理解用户偏好和产品质量非常重要。
- 规模:具有大量的用户和产品,以及丰富的评价信息。
这些数据集的共同特点是都包含了随时间变化的用户-物品交互信息,这对于评估推荐系统的动态适应性至关重要。通过在这些数据集上进行实验,作者能够验证GraphPro框架在不同场景和时间尺度下的推荐性能,以及其在处理动态数据变化时的有效性和鲁棒性。此外,这些数据集的公开可用性也使得其他研究人员可以复现和比较GraphPro与其他推荐系统模型的性能。
对论文中的pretrain.txt数据做分析:
html
<!--
从0之后的8个数字(16198, 39895, ..., 81015)代表用户0与特定物品的交互。
后面的数字(1453161600, 1453161600, ..., 1453161600)是对应的时间戳,表示这些交互发生的时间。由于时间戳重复,这意味着在这些特定时间点,用户0与多个物品进行了交互。后续类似
-->
0 16198 39895 49560 67677 68881 75476 81015 1453161600 1453161600 1453161600 1453161600 1453161600 1453161600 1453161600
1 30158 52823 76295 1453766400 1453766400 1453766400
2 1391 10625 11537 1452729600 1452729600 1452729600
3 89514 1452211200
4 5573 92552 1452988800 1452988800
5 10444 22618 1452729600 1452729600
6 67102 75422 13393 73402 1452902400 1453507200 1453680000 1453680000
7 17988 1452902400
Pre-training 预训练
要从头开始预训练模型,预训练的主文件是 pretrain.py
。您可以通过以下命令在数据集上运行图预训练:
python
# Taobao
python pretrain.py --data_path dataset/taobao --exp_name pretrain --phase pretrain --log 1 --device cuda:0 --model GraphPro --lr 1e-3 --edge_dropout 0.5
# Koubei
python pretrain.py --data_path dataset/koubei --exp_name pretrain --phase pretrain --log 1 --device cuda:0 --model GraphPro --lr 1e-3 --edge_dropout 0.2 --hour_interval_pre 24
# Amazon
python pretrain.py --data_path dataset/amazon --exp_name pretrain --phase pretrain --log 1 --device cuda:0 --model GraphPro --lr 1e-3 --edge_dropout 0.2 --hour_interval_pre 24
数据集的预处理 DataLoader
python
from utils.parse_args import args
from os import path
from tqdm import tqdm
import numpy as np
import scipy.sparse as sp
import torch
import logging
import networkx as nx
from copy import deepcopy
from collections import defaultdict
import pandas as pd
# from torch_geometric.data import Data
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger('train_logger')
logger.setLevel(logging.INFO)
class EdgeListData:
def __init__(self, train_file, test_file, phase='pretrain', pre_dataset=None, user_hist_files=[], has_time=True):
# 记录加载数据集的阶段
logger.info(f"Loading dataset for {phase}...")
self.phase = phase # 训练或微调阶段
self.has_time = has_time # 数据集是否包含时间戳
self.pre_dataset = pre_dataset # 预训练数据集,如果有的话
self.hour_interval = args.hour_interval_pre if phase == 'pretrain' else args.hour_interval_f
self.edgelist = []
self.edge_time = []
self.num_users = 0
self.num_items = 0
self.num_edges = 0
self.train_user_dict = {}
self.test_user_dict = {}
self._load_data(train_file, test_file, has_time)
# 根据阶段不同,设置用户历史交互字典
if phase == 'pretrain':
self.user_hist_dict = self.train_user_dict # 预训练阶段使用训练数据
elif phase == 'finetune':
self.user_hist_dict = deepcopy(self.train_user_dict) # 微调阶段复制训练数据
self._load_user_hist_from_files(user_hist_files) # 加载额外的用户历史文件
users_has_hist = set(list(self.user_hist_dict.keys())) # 117338
all_users = set(list(range(self.num_users)))
users_no_hist = all_users - users_has_hist # 112
logger.info(f"Number of users from all users with no history: {len(users_no_hist)}")
for u in users_no_hist:
self.user_hist_dict[u] = []
def _read_file(self, train_file, test_file, has_time=True):
with open(train_file, 'r') as f:
for line in f:
line = line.strip().split('\t') # 源节点 目标节点 交互的时间戳,以\t做切割
if not has_time:
user, items = line[:2]
times = " ".join(["0"] * len(items.split(" ")))
else:
user, items, times = line
# 将用户-物品交互添加到边列表和时间列表
for i in items.split(" "):
self.edgelist.append((int(user), int(i)))
for i in times.split(" "):
self.edge_time.append(int(i))
# 将用户和物品交互存储到字典中
self.train_user_dict[int(user)] = [int(i) for i in items.split(" ")] # 字典的形式存储{源节点1:[目标节点集合1],源节点2:[目标节点集合2],...}
self.test_edge_num = 0
with open(test_file, 'r') as f:
for line in f:
line = line.strip().split('\t')
user, items = line[:2]
self.test_user_dict[int(user)] = [int(i) for i in items.split(" ")]
self.test_edge_num += len(self.test_user_dict[int(user)])
logger.info('Number of test users: {}'.format(len(self.test_user_dict)))
def _read_pd(self, train_pd, test_pd, has_time=True):
for i in range(len(train_pd)):
line = train_pd.iloc[i]
if not has_time:
user, items = line[0], line[1]
times = " ".join(["0"] * len(items.split(" ")))
else:
user, items, times = line[0], line[1], line[2]
for i in items.split(" "):
self.edgelist.append((int(user), int(i)))
for i in times.split(" "):
self.edge_time.append(int(i))
self.train_user_dict[int(user)] = [int(i) for i in items.split(" ")]
self.test_edge_num = 0
for i in range(len(test_pd)):
line = test_pd.iloc[i]
user, items = line[0], line[1]
self.test_user_dict[int(user)] = [int(i) for i in items.split(" ")]
self.test_edge_num += len(self.test_user_dict[int(user)])
logger.info('Number of test users: {}'.format(len(self.test_user_dict)))
# 定义加载数据的私有方法
def _load_data(self, train_file, test_file, has_time=True):
if isinstance(train_file, pd.DataFrame):
self._read_pd(train_file, test_file, has_time)
else:
self._read_file(train_file, test_file, has_time)
self.edgelist = np.array(self.edgelist, dtype=np.int32)
# refine timestamp to predefined time steps at intervals
# 0 as padding for self-loop
self.edge_time = 1 + self.timestamp_to_time_step(np.array(self.edge_time, dtype=np.int32)) # 范围:1-120
self.num_edges = len(self.edgelist)
if self.pre_dataset is not None:
self.num_users = self.pre_dataset.num_users
self.num_items = self.pre_dataset.num_items
else:
self.num_users = max([np.max(self.edgelist[:, 0]) + 1, np.max(list(self.test_user_dict.keys())) + 1])
self.num_items = max([np.max(self.edgelist[:, 1]) + 1, np.max([np.max(self.test_user_dict[u]) for u in self.test_user_dict.keys()]) + 1])
logger.info('Number of users: {}'.format(self.num_users))
logger.info('Number of items: {}'.format(self.num_items))
logger.info('Number of edges: {}'.format(self.num_edges))
# 创建了一个稀疏矩阵来表示用户和物品之间的交互图
self.graph = sp.coo_matrix((np.ones(self.num_edges), (self.edgelist[:, 0], self.edgelist[:, 1])), shape=(self.num_users, self.num_items))
if self.has_time:
self.edge_time_dict = defaultdict(dict)
for i in range(len(self.edgelist)):
self.edge_time_dict[self.edgelist[i][0]][self.edgelist[i][1]+self.num_users] = self.edge_time[i]
self.edge_time_dict[self.edgelist[i][1]+self.num_users][self.edgelist[i][0]] = self.edge_time[i]
# self.edge_time_dict[self.edgelist[i][0]][self.edgelist[i][0]] = 0
# self.edge_time_dict[self.edgelist[i][1]][self.edgelist[i][1]] = 0
# homogenous edges generation
# if args.ab in ["homo", "full"]:
# self.ii_adj = self.graph.T.dot(self.graph).tocoo()
# # sort the values in the sparse matrix and get top x% values and corresponding edges
# percentage_ii = 0.01
# tmp_data = sorted(self.ii_adj.data, reverse=True)
# tmp_data_xpercent, tmp_data_xpercent_len = tmp_data[int(len(tmp_data) * percentage_ii)], int(len(tmp_data) * percentage_ii)
# self.ii_adj.data = np.where(self.ii_adj.data > tmp_data_xpercent, self.ii_adj.data, 0)
# self.ii_adj.eliminate_zeros()
# logger.info(f"Sampled {len(self.ii_adj.data)} i-i edges from all {len(tmp_data)} edges.")
# self.uu_adj = self.graph.dot(self.graph.T).tocoo()
# # same filtering for uu_adj
# percentage_uu = 0.01
# tmp_data = sorted(self.uu_adj.data, reverse=True)
# tmp_data_xpercent, tmp_data_xpercent_len = tmp_data[int(len(tmp_data) * percentage_uu)], int(len(tmp_data) * percentage_uu)
# self.uu_adj.data = np.where(self.uu_adj.data > tmp_data_xpercent, self.uu_adj.data, 0)
# self.uu_adj.eliminate_zeros()
# logger.info(f"Sampled {len(self.uu_adj.data)} u-u edges from all {len(tmp_data)} edges.")
# self.graph = nx.Graph()
# for i in range(len(self.edgelist)):
# self.graph.add_edge(self.edgelist[i][0], self.edgelist[i][1], time=self.edge_time[i])
# print(self.graph.number_of_nodes(), self.graph.number_of_edges())
# self.ui_adj = nx.adjacency_matrix(self.graph, weight=None)[:self.num_users, self.num_users:].tocoo()
# self.ii_adj = self.ui_adj.T.dot(self.ui_adj).tocoo()
# self.uu_adj = self.ui_adj.dot(self.ui_adj.T).tocoo()
# self.graph = Data(torch.zeros(self.num_users + self.num_items, 1), torch.tensor(self.edgelist).t(), torch.tensor(self.edge_time))
def _load_user_hist_from_files(self, user_hist_files):
for file in user_hist_files:
with open(file, 'r') as f:
for line in f:
line = line.strip().split('\t')
user, items = int(line[0]), [int(i) for i in line[1].split(" ")]
try:
self.user_hist_dict[user].extend(items)
except KeyError:
self.user_hist_dict[user] = items
def sample_subgraph(self):
pass
def get_train_batch(self, start, end):
def negative_sampling(user_item, train_user_set, n=1):
neg_items = []
for user, _ in user_item:
user = int(user)
for i in range(n):
while True:
neg_item = np.random.randint(low=0, high=self.num_items, size=1)[0]
if neg_item not in train_user_set[user]:
break
neg_items.append(neg_item)
return neg_items
ui_pairs = self.edgelist[start:end]
users = torch.LongTensor(ui_pairs[:, 0]).to(args.device)
pos_items = torch.LongTensor(ui_pairs[:, 1]).to(args.device)
if args.model == "MixGCF":
neg_items = negative_sampling(ui_pairs, self.train_user_dict, args.n_negs)
else:
neg_items = negative_sampling(ui_pairs, self.train_user_dict, 1)
neg_items = torch.LongTensor(neg_items).to(args.device)
return users, pos_items, neg_items
def shuffle(self):
random_idx = np.random.permutation(self.num_edges)
self.edgelist = self.edgelist[random_idx]
self.edge_time = self.edge_time[random_idx]
def _generate_binorm_adj(self, edgelist):
adj = sp.coo_matrix((np.ones(len(edgelist)), (edgelist[:, 0], edgelist[:, 1])),
shape=(self.num_users, self.num_items), dtype=np.float32)
a = sp.csr_matrix((self.num_users, self.num_users))
b = sp.csr_matrix((self.num_items, self.num_items))
adj = sp.vstack([sp.hstack([a, adj]), sp.hstack([adj.transpose(), b])])
adj = (adj != 0) * 1.0
degree = np.array(adj.sum(axis=-1))
d_inv_sqrt = np.reshape(np.power(degree, -0.5), [-1])
d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.0
d_inv_sqrt_mat = sp.diags(d_inv_sqrt)
adj = adj.dot(d_inv_sqrt_mat).transpose().dot(d_inv_sqrt_mat).tocoo()
ui_adj = adj.tocsr()[:self.num_users, self.num_users:].tocoo()
return adj
'''
将原始的时间戳转换为时间步,这是处理时间序列数据时常见的一种方法。
@param timestamp_arr: 原始的时间戳数组
方法步骤:
1、interval_hour: 从类的实例变量中获取时间间隔,这个时间间隔定义了每个时间步代表的小时数。
2、如果 least_time 未提供,方法将计算 timestamp_arr 中的最小值作为 least_time。这个最小值将用作后续转换的基准时间。
3、打印出时间戳数组中的最小值、排序后的第二小的值、第三小的值和最大值。这有助于了解数据集中时间戳的分布情况。
4、从每个时间戳中减去 least_time,这样可以将所有时间戳转换为相对于最小时间戳的偏移量。
5、将偏移量除以 interval_hour 和 3600(一小时的秒数),这样就将时间戳转换为了时间步。这个操作基于整数除法,结果将向下取整。
6、返回转换后的时间步数组。
原始的时间戳被标准化为一系列整数,每个整数代表一个时间步。这样,模型就可以使用这些时间步作为输入特征,而不需要直接处理原始的时间戳。
这种方法在处理具有时间动态性的数据时非常有用,例如在推荐系统中捕捉用户行为随时间的变化。
'''
def timestamp_to_time_step(self, timestamp_arr, least_time=None):
interval_hour = self.hour_interval
if least_time is None:
least_time = np.min(timestamp_arr)
print("Least time: ", least_time)
print("2nd least time: ", np.sort(timestamp_arr)[1]) # 倒数第2小
print("3rd least time: ", np.sort(timestamp_arr)[2]) # 倒数第3小
print("Max time: ", np.max(timestamp_arr))
timestamp_arr = timestamp_arr - least_time
timestamp_arr = timestamp_arr // (interval_hour * 3600)
return timestamp_arr
if __name__ == '__main__':
edgelist_dataset = EdgeListData("dataset/yelp_small/train.txt", "dataset/yelp_small/test.txt")
# edgelist_dataset = EdgeListData("D:\\技术学习资料\\综合技术学习笔记\\图神经网络_推荐\\图神经网络\\文献\\文献\\动态异质图推荐\\GraphPro(必看,很可能将来论文就是在这篇做改进)\\GraphPro-master\\dataset\\taobao\\pretrain.txt", "D:\\技术学习资料\\综合技术学习笔记\\图神经网络_推荐\\图神经网络\\文献\\文献\\动态异质图推荐\\GraphPro(必看,很可能将来论文就是在这篇做改进)\\GraphPro-master\\dataset\\taobao\\test_1.txt")
GraphPro模型构建代码
python
import torch
import torch.nn as nn
from modules.base_model import BaseModel
from utils.parse_args import args
import torch.nn.functional as F
from modules.utils import EdgelistDrop
import logging
from modules.utils import scatter_add, scatter_sum
import torch_scatter
# 初始化权重的函数
init = nn.init.xavier_uniform_
logger = logging.getLogger('train_logger') # 获取日志记录器实例
class GraphPro(BaseModel):
def __init__(self, dataset, pretrained_model=None, phase='pretrain'):
super().__init__(dataset) # 调用父类的构造函数
# 创建双归一化的邻接矩阵
self.adj = self._make_binorm_adj(dataset.graph)
# 获取邻接矩阵的索引和值
self.edges = self.adj._indices().t()
self.edge_norm = self.adj._values()
# 获取边的时间信息
self.edge_times = [dataset.edge_time_dict[e[0]][e[1]] for e in self.edges.cpu().tolist()]
self.edge_times = torch.LongTensor(self.edge_times).to(args.device)
self.phase = phase
# 随机生成门控权重
self.emb_gate = self.random_gate
# 预训练或为了微调的特定层的嵌入初始化
if self.phase == 'pretrain' or self.phase == 'for_tune':
self.user_embedding = nn.Parameter(init(torch.empty(self.num_users, self.emb_size)))
self.item_embedding = nn.Parameter(init(torch.empty(self.num_items, self.emb_size)))
if self.phase == 'for_tune':
self.emb_gate = self.random_gate
# load gating weights from pretrained model
# 微调阶段加载预训练模型的嵌入和门控权重
elif self.phase == 'finetune':
pre_user_emb, pre_item_emb = pretrained_model.generate()
self.user_embedding = nn.Parameter(pre_user_emb).requires_grad_(True)
self.item_embedding = nn.Parameter(pre_item_emb).requires_grad_(True)
self.gating_weight = nn.Parameter(init(torch.empty(args.emb_size, args.emb_size)))
self.gating_bias = nn.Parameter(init(torch.empty(1, args.emb_size)))
self.emb_dropout = nn.Dropout(args.emb_dropout)
# 定义门控机制
self.emb_gate = lambda x: self.emb_dropout(torch.mul(x, torch.sigmoid(torch.matmul(x, self.gating_weight) + self.gating_bias)))
# 定义边的dropout机制
self.edge_dropout = EdgelistDrop()
# 记录最大时间步长
logger.info(f"Max Time Step: {self.edge_times.max()}")
def random_gate(self, x):
# 随机生成并归一化门控权重和偏置
# 使用 F.normalize 函数对生成的权重参数进行归一化,确保它们的范数为 1。这有助于防止梯度消失或爆炸,并保持数值稳定性
gating_weight = F.normalize(torch.randn((args.emb_size, args.emb_size)).to(args.device))
gating_bias = F.normalize(torch.randn((1, args.emb_size)).to(args.device))
# 计算门控机制的输出。
# 首先,通过矩阵乘法将输入嵌入与权重相乘。
# 然后,加上偏置。
# 使用 sigmoid 函数将结果压缩到 0 和 1 之间,这样权重和偏置就控制了输入嵌入的激活程度。
gate = torch.sigmoid(torch.matmul(x, gating_weight) + gating_bias)
# 应用门控机制
# 将输入嵌入向量 x 与门控信号 gate 进行逐元素乘积。
# 这样,输入嵌入向量中的每个维度都将根据门控信号的相应值进行放大或缩小。
return torch.mul(x, gate)
def _agg(self, all_emb, edges, edge_norm):
# 聚合函数,用于GNN中的信息聚合
# all_emb: 包含所有节点(用户和物品)嵌入的张量
# edges: 一个二维数组,其中每一行代表一条边,列分别代表边的源节点和目标节点
# edge_norm: 边的归一化值,用于在聚合过程中进行加权
src_emb = all_emb[edges[:, 0]] # 源节点的嵌入
# 双归一化
# 提取所有嵌入中的源节点嵌入
# 对源节点嵌入进行双归一化处理
# 通过将源节点嵌入与对应的边的归一化值相乘来实现
# 这里使用unsqueeze(1)来给归一化值增加一个维度,使其可以与嵌入进行逐元素乘法
src_emb = src_emb * edge_norm.unsqueeze(1) # all_emb中索引为edges第一列值的行
# 使用scatter_sum函数进行聚合操作
# scatter_sum是一种在GNN中常用的操作,它将src_emb中的值根据edges中的索引分散到目标位置
# edges[:, 1]包含目标节点的索引
# dim=0表示沿着第一个维度(即行)进行scatter操作
# dim_size=self.num_users+self.num_items表示整个数据的维度大小,包括用户和物品
# dst_emb是聚合后的目标节点嵌入
# conv 通过scatter_sum函数进行聚合
dst_emb = scatter_sum(src_emb, edges[:, 1], dim=0, dim_size=self.num_users+self.num_items)
return dst_emb # 返回更新后的目标节点嵌入
def _edge_binorm(self, edges):
# 计算边的双归一化系数
# 使用 scatter_add 函数计算每个用户的度(即与该用户相连的边的数量)。
# torch.ones_like(edges[:, 0]) 创建一个与 edges[:, 0] 形状相同的全1张量。
# 将这个全1张量在第一个维度(dim=0)上累加,即对所有边的源节点索引进行累加。
# dim_size=self.num_users 指定了累加操作的维度大小,即用户的数量。
user_degs = scatter_add(torch.ones_like(edges[:, 0]), edges[:, 0], dim=0, dim_size=self.num_users)
# 从计算出的度张量中取出对应于边的源节点的度。
user_degs = user_degs[edges[:, 0]]
# 同样地,计算每个物品的度。
item_degs = scatter_add(torch.ones_like(edges[:, 1]), edges[:, 1], dim=0, dim_size=self.num_items)
# 从计算出的度张量中取出对应于边的目标节点的物品的度。
item_degs = item_degs[edges[:, 1]]
# 计算双归一化系数。
# 对每个用户的度取负0.5次方,对每个物品的度也取负0.5次方。
# 然后,将这两个结果相乘,得到每条边的双归一化系数。
norm = torch.pow(user_degs, -0.5) * torch.pow(item_degs, -0.5)
# 返回计算出的双归一化系数。
return norm
# 相对时间编码函数
# 该方法接受三个参数:edges(边的索引),edge_times(边的时间戳),max_step(可选,用于归一化的最大时间步长)
def _relative_edge_time_encoding(self, edges, edge_times, max_step=None):
# 相对时间编码函数,用于处理边的时间信息
# for each node, normalize edge_times according to its neighbors
# edge_times: [E]
# rescal to 0-1
edge_times = edge_times.float() # 转换为浮点数
# 如果没有提供max_step,则找出所有时间戳中的最大值作为max_step
if max_step is None:
max_step = edge_times.max() # 获取最大时间步长
# 将时间戳归一化到0和1之间,这样所有的时间戳都会被映射到一个统一的尺度上
# 这有助于模型更好地处理时间信息,避免时间戳的绝对值对模型造成影响
edge_times = (edge_times - edge_times.min()) / (max_step - edge_times.min()) # 归一化到0-1之间
# edge_times = torch.exp(edge_times)
# edge_times = torch.sigmoid(edge_times)
# 获取每条边的目标节点索引
dst_nodes = edges[:, 1] # 目标节点
# 使用torch_scatter库中的scatter_softmax函数计算每个目标节点的时间归一化值
# 这个函数会对每个目标节点对应的所有时间戳执行softmax操作,生成一个时间归一化向量
# dim_size参数指定了整个数据的维度大小,包括用户和物品
time_norm = torch_scatter.scatter_softmax(edge_times, dst_nodes, dim_size=self.num_users+self.num_items) # 计算时间的softmax归一化
# time_norm = time_norm[dst_nodes]
# 返回计算出的时间归一化向量
return time_norm
def forward(self, edges, edge_norm, edge_times, max_time_step=None):
time_norm = self._relative_edge_time_encoding(edges, edge_times, max_step=max_time_step)
edge_norm = edge_norm * 1/2 + time_norm * 1/2 # 合并时间归一化和边的归一化
all_emb = torch.cat([self.user_embedding, self.item_embedding], dim=0) # 拼接用户和物品的嵌入
all_emb = self.emb_gate(all_emb) # 应用门控机制
res_emb = [all_emb] # 存储每层的嵌入
for l in range(args.num_layers): # 遍历所有层
all_emb = self._agg(res_emb[-1], edges, edge_norm) # 聚合信息
res_emb.append(all_emb) # 存储当前层的嵌入
res_emb = sum(res_emb) # 合并所有层的嵌入
user_res_emb, item_res_emb = res_emb.split([self.num_users, self.num_items], dim=0) # 分离用户和物品的嵌入
return user_res_emb, item_res_emb
def cal_loss(self, batch_data):
# 应用边的dropout机制,以防止过拟合,并获取对应的掩码
# dropout_mask是一个布尔张量,表示哪些边应该被保留
edges, dropout_mask = self.edge_dropout(self.edges, 1-args.edge_dropout, return_mask=True)
# 从原始的边归一化值中选择被保留的边对应的值
edge_norm = self.edge_norm[dropout_mask]
# 从时间戳数组中选择被保留的边对应的时间戳
edge_times = self.edge_times[dropout_mask]
# forward # 获取批次数据中的用户、正样本和负样本
users, pos_items, neg_items = batch_data
# 执行模型的前向传播,获取用户和物品的嵌入表示
# 这里的edges是经过dropout处理后的边的索引
# edge_norm是经过dropout处理后的边的归一化系数
# edge_times是经过dropout处理后的边的时间戳
user_emb, item_emb = self.forward(edges, edge_norm, edge_times)
# 从嵌入表示中提取批次数据中用户和正样本的嵌入
batch_user_emb = user_emb[users]
pos_item_emb = item_emb[pos_items]
# 从嵌入表示中提取批次数据中负样本的嵌入
neg_item_emb = item_emb[neg_items]
# 计算推荐损失,这里使用_bpr_loss方法,它可能是一个基于BPR(Bayesian Personalized Ranking)的损失函数
rec_loss = self._bpr_loss(batch_user_emb, pos_item_emb, neg_item_emb)
# 计算正则化损失,以防止模型过拟合
# weight_decay是一个超参数,用于控制正则化的强度
reg_loss = args.weight_decay * self._reg_loss(users, pos_items, neg_items)
# 总损失是推荐损失和正则化损失的和
loss = rec_loss + reg_loss
# 创建一个字典,存储不同类型损失的值
loss_dict = {
"rec_loss": rec_loss.item(), # 推荐损失的值
"reg_loss": reg_loss.item(), # 正则化损失的值
}
# 返回总损失和损失字典
return loss, loss_dict
@torch.no_grad()
def generate(self, max_time_step=None):
# 生成用户和物品的嵌入表示,不计算梯度
# max_time_step参数用于指定时间步长的最大值,如果提供了该参数
return self.forward(self.edges, self.edge_norm, self.edge_times, max_time_step=max_time_step)
@torch.no_grad()
def rating(self, user_emb, item_emb):
# 计算用户和物品之间的评分预测
return torch.matmul(user_emb, item_emb.t())
def _reg_loss(self, users, pos_items, neg_items):
# 计算正则化损失
# 用于防止模型过拟合,通过惩罚大的嵌入表示
u_emb = self.user_embedding[users] # 获取用户的嵌入表示
pos_i_emb = self.item_embedding[pos_items] # 获取正样本物品的嵌入表示
neg_i_emb = self.item_embedding[neg_items] # 获取负样本物品的嵌入表示
# 计算正则化损失,这里是L2范数的平方
reg_loss = (1/2)*(u_emb.norm(2).pow(2) +
pos_i_emb.norm(2).pow(2) +
neg_i_emb.norm(2).pow(2))/float(len(users))
return reg_loss